Andybeyond committed on
Commit 3e0d8bd · verified · 1 Parent(s): d9c7e7e

Update notebooks/melody_development.ipynb


Updated the melody_development notebook's sample code to serve as a starting template.

Files changed (1)
  1. notebooks/melody_development.ipynb +1160 -1
notebooks/melody_development.ipynb CHANGED
@@ -135,4 +135,1163 @@ class MelodyDataset(torch.utils.data.Dataset):
        }

    def __len__(self):
-       return len(self.midi_files)
+        return len(self.midi_files)
+
+
+ # =====================================
+ # 2. Model Architecture Development
+ # =====================================
+
+ class MelodyTransformer(nn.Module):
+     """
+     Transformer-based model for melody generation.
+
+     Architecture Overview:
+     1. Embedding layers for notes, durations, and positions
+     2. Transformer encoder for sequence processing
+     3. Separate prediction heads for notes and durations
+
+     Args:
+         num_notes (int): Size of note vocabulary (default: 128 for MIDI range)
+         max_duration (int): Number of possible duration values (default: 32)
+         d_model (int): Dimension of the model (default: 512)
+         nhead (int): Number of attention heads (default: 8)
+         num_layers (int): Number of transformer layers (default: 6)
+
+     Forward Pass:
+     - Input: note sequence, duration sequence, position indices
+     - Output: predictions for next note and duration
+     """
+
+     def __init__(self,
+                  num_notes=128,    # MIDI note range (0-127)
+                  max_duration=32,  # Quantized duration values
+                  d_model=512,      # Model dimension (as in original Transformer)
+                  nhead=8,          # Multi-head attention
+                  num_layers=6):    # Number of Transformer layers
+         super().__init__()
+
+         # Embedding layers
+         self.note_embedding = nn.Embedding(
+             num_embeddings=num_notes,
+             embedding_dim=d_model,
+             padding_idx=0  # Use 0 for padding
+         )
+
+         self.duration_embedding = nn.Embedding(
+             num_embeddings=max_duration,
+             embedding_dim=d_model,
+             padding_idx=0
+         )
+
+         self.position_embedding = nn.Embedding(
+             num_embeddings=1024,  # Maximum sequence length
+             embedding_dim=d_model
+         )
+
+         # Transformer architecture
+         encoder_layer = nn.TransformerEncoderLayer(
+             d_model=d_model,
+             nhead=nhead,
+             dim_feedforward=4 * d_model,  # As per original Transformer paper
+             dropout=0.1,
+             activation='gelu'  # Modern activation function
+         )
+
+         self.transformer = nn.TransformerEncoder(
+             encoder_layer=encoder_layer,
+             num_layers=num_layers,
+             norm=nn.LayerNorm(d_model)
+         )
+
+         # Output heads
+         self.note_head = nn.Sequential(
+             nn.Linear(d_model, d_model),
+             nn.ReLU(),
+             nn.Dropout(0.1),
+             nn.Linear(d_model, num_notes)
+         )
+
+         self.duration_head = nn.Sequential(
+             nn.Linear(d_model, d_model),
+             nn.ReLU(),
+             nn.Dropout(0.1),
+             nn.Linear(d_model, max_duration)
+         )
+
+     def forward(self, notes, durations, positions):
+         """
+         Forward pass through the model.
+
+         Args:
+             notes (torch.Tensor): Shape [batch_size, seq_length].
+                 Contains MIDI note numbers.
+             durations (torch.Tensor): Shape [batch_size, seq_length].
+                 Contains quantized duration values.
+             positions (torch.Tensor): Shape [batch_size, seq_length].
+                 Contains position indices.
+
+         Returns:
+             tuple: (note_logits, duration_logits)
+                 - note_logits: Shape [batch_size, seq_length, num_notes]
+                 - duration_logits: Shape [batch_size, seq_length, max_duration]
+
+         Note:
+             The model predicts both the next note and its duration
+             simultaneously, allowing for coherent melody generation.
+         """
+         # Get embeddings for each component
+         note_emb = self.note_embedding(notes)              # [B, S, D]
+         duration_emb = self.duration_embedding(durations)  # [B, S, D]
+         pos_emb = self.position_embedding(positions)       # [B, S, D]
+
+         # Combine embeddings
+         # Sum embeddings as in the original Transformer paper
+         x = note_emb + duration_emb + pos_emb  # [B, S, D]
+
+         # Apply Transformer
+         # Note: need to reshape for the Transformer, which expects [S, B, D]
+         x = x.transpose(0, 1)
+         x = self.transformer(x)
+         x = x.transpose(0, 1)  # Back to [B, S, D]
+
+         # Generate predictions
+         note_logits = self.note_head(x)          # [B, S, num_notes]
+         duration_logits = self.duration_head(x)  # [B, S, max_duration]
+
+         return note_logits, duration_logits
+
+     def generate(self, prompt, max_length=512, temperature=1.0):
+         """
+         Generate a melody from a starting prompt.
+
+         Args:
+             prompt (dict): Initial notes and durations
+             max_length (int): Maximum sequence length to generate
+             temperature (float): Sampling temperature (higher = more random)
+
+         Returns:
+             tuple: (generated_notes, generated_durations)
+
+         Example:
+             >>> model = MelodyTransformer()
+             >>> prompt = {'notes': [60, 64, 67], 'durations': [1, 1, 1]}
+             >>> notes, durations = model.generate(prompt)
+         """
+         self.eval()  # Set to evaluation mode
+
+         with torch.no_grad():
+             # Initialize with prompt (cast to long: notes and durations are embedding indices)
+             current_notes = torch.tensor(prompt['notes'], dtype=torch.long).unsqueeze(0)
+             current_durations = torch.tensor(prompt['durations'], dtype=torch.long).unsqueeze(0)
+
+             generated_notes = list(prompt['notes'])
+             generated_durations = list(prompt['durations'])
+
+             # Generate one note at a time
+             for i in range(len(prompt['notes']), max_length):
+                 # Create position tensor
+                 positions = torch.arange(len(generated_notes)).unsqueeze(0)
+
+                 # Get predictions
+                 note_logits, duration_logits = self(
+                     current_notes,
+                     current_durations,
+                     positions
+                 )
+
+                 # Sample from logits using temperature
+                 note_probs = F.softmax(note_logits[:, -1] / temperature, dim=-1)
+                 duration_probs = F.softmax(duration_logits[:, -1] / temperature, dim=-1)
+
+                 next_note = torch.multinomial(note_probs, 1)
+                 next_duration = torch.multinomial(duration_probs, 1)
+
+                 # Append to generated sequence
+                 generated_notes.append(next_note.item())
+                 generated_durations.append(next_duration.item())
+
+                 # Update current sequence
+                 current_notes = torch.tensor(generated_notes, dtype=torch.long).unsqueeze(0)
+                 current_durations = torch.tensor(generated_durations, dtype=torch.long).unsqueeze(0)
+
+         return generated_notes, generated_durations
+
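# ---------------------------------------------------------------------
# Illustrative smoke test (editor's sketch, not part of the committed cell).
# It runs a dummy batch through MelodyTransformer and checks the output
# shapes; it assumes torch was imported in an earlier notebook cell, and
# the small hyperparameters below are placeholder values.
def _smoke_test_melody_transformer():
    model = MelodyTransformer(num_notes=128, max_duration=32,
                              d_model=64, nhead=4, num_layers=2)
    batch_size, seq_len = 2, 16
    notes = torch.randint(1, 128, (batch_size, seq_len))
    durations = torch.randint(1, 32, (batch_size, seq_len))
    positions = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1)
    note_logits, duration_logits = model(notes, durations, positions)
    assert note_logits.shape == (batch_size, seq_len, 128)
    assert duration_logits.shape == (batch_size, seq_len, 32)
# ---------------------------------------------------------------------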
+ # =====================================
+ # 3. Training Pipeline
+ # =====================================
+
+ class MelodyTrainer:
+     """
+     Custom training pipeline for the melody generation model.
+
+     Features:
+     - Automated training loop
+     - Validation monitoring
+     - Checkpoint saving
+     - Logging and metrics tracking
+
+     Args:
+         model (MelodyTransformer): The model to train
+         config (dict): Training configuration
+         device (str): Device to train on ('cuda' or 'cpu')
+     """
+
+     def __init__(self, model, config, device='cuda'):
+         self.model = model.to(device)
+         self.config = config
+         self.device = device
+
+         # Initialize training components
+         self.criterion = nn.CrossEntropyLoss(ignore_index=0)  # Ignore padding
+         self.optimizer = torch.optim.AdamW(
+             self.model.parameters(),
+             lr=config['learning_rate'],
+             weight_decay=config.get('weight_decay', 0.01)
+         )
+
+         # Learning rate scheduler
+         self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
+             self.optimizer,
+             max_lr=config['learning_rate'],
+             epochs=config['epochs'],
+             steps_per_epoch=config['steps_per_epoch']
+         )
+
+         # Initialize wandb for experiment tracking
+         if config.get('use_wandb', False):
+             wandb.init(
+                 project="opentunes-melody",
+                 config=config,
+                 name=f"melody_training_{datetime.now().strftime('%Y%m%d_%H%M')}"
+             )
+
+     def train_epoch(self, train_loader):
+         """
+         Train for one epoch.
+
+         Args:
+             train_loader (DataLoader): Training data loader
+
+         Returns:
+             dict: Training metrics for this epoch
+         """
+         self.model.train()
+         epoch_loss = 0
+         epoch_note_acc = 0
+         epoch_dur_acc = 0
+         num_batches = 0
+
+         for batch in tqdm(train_loader, desc="Training"):
+             # Move batch to device
+             notes = batch['notes'].to(self.device)
+             durations = batch['durations'].to(self.device)
+             positions = torch.arange(notes.size(1)).unsqueeze(0).expand(
+                 notes.size(0), -1).to(self.device)
+
+             # Forward pass
+             note_logits, duration_logits = self.model(notes, durations, positions)
+
+             # Calculate loss
+             # Shift sequences for next-token prediction
+             note_loss = self.criterion(
+                 note_logits[:, :-1].reshape(-1, note_logits.size(-1)),
+                 notes[:, 1:].reshape(-1)
+             )
+             duration_loss = self.criterion(
+                 duration_logits[:, :-1].reshape(-1, duration_logits.size(-1)),
+                 durations[:, 1:].reshape(-1)
+             )
+             loss = note_loss + duration_loss
+
+             # Backward pass
+             self.optimizer.zero_grad()
+             loss.backward()
+             torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
+             self.optimizer.step()
+             self.scheduler.step()
+
+             # Calculate metrics
+             with torch.no_grad():
+                 note_preds = note_logits.argmax(dim=-1)
+                 dur_preds = duration_logits.argmax(dim=-1)
+                 note_acc = (note_preds[:, :-1] == notes[:, 1:]).float().mean()
+                 dur_acc = (dur_preds[:, :-1] == durations[:, 1:]).float().mean()
+
+             # Update running metrics
+             epoch_loss += loss.item()
+             epoch_note_acc += note_acc.item()
+             epoch_dur_acc += dur_acc.item()
+             num_batches += 1
+
+             # Log batch metrics
+             if self.config.get('use_wandb', False):
+                 wandb.log({
+                     'batch_loss': loss.item(),
+                     'note_accuracy': note_acc.item(),
+                     'duration_accuracy': dur_acc.item(),
+                     'learning_rate': self.scheduler.get_last_lr()[0]
+                 })
+
+         # Calculate epoch metrics
+         metrics = {
+             'loss': epoch_loss / num_batches,
+             'note_accuracy': epoch_note_acc / num_batches,
+             'duration_accuracy': epoch_dur_acc / num_batches
+         }
+
+         return metrics
+
+     def validate(self, val_loader):
+         """
+         Validate the model.
+
+         Args:
+             val_loader (DataLoader): Validation data loader
+
+         Returns:
+             dict: Validation metrics
+         """
+         self.model.eval()
+         val_loss = 0
+         val_note_acc = 0
+         val_dur_acc = 0
+         num_batches = 0
+
+         with torch.no_grad():
+             for batch in tqdm(val_loader, desc="Validation"):
+                 notes = batch['notes'].to(self.device)
+                 durations = batch['durations'].to(self.device)
+                 positions = torch.arange(notes.size(1)).unsqueeze(0).expand(
+                     notes.size(0), -1).to(self.device)
+
+                 # Forward pass
+                 note_logits, duration_logits = self.model(notes, durations, positions)
+
+                 # Calculate metrics (similar to training)
+                 note_loss = self.criterion(
+                     note_logits[:, :-1].reshape(-1, note_logits.size(-1)),
+                     notes[:, 1:].reshape(-1)
+                 )
+                 duration_loss = self.criterion(
+                     duration_logits[:, :-1].reshape(-1, duration_logits.size(-1)),
+                     durations[:, 1:].reshape(-1)
+                 )
+                 loss = note_loss + duration_loss
+
+                 note_preds = note_logits.argmax(dim=-1)
+                 dur_preds = duration_logits.argmax(dim=-1)
+                 note_acc = (note_preds[:, :-1] == notes[:, 1:]).float().mean()
+                 dur_acc = (dur_preds[:, :-1] == durations[:, 1:]).float().mean()
+
+                 val_loss += loss.item()
+                 val_note_acc += note_acc.item()
+                 val_dur_acc += dur_acc.item()
+                 num_batches += 1
+
+         metrics = {
+             'val_loss': val_loss / num_batches,
+             'val_note_accuracy': val_note_acc / num_batches,
+             'val_duration_accuracy': val_dur_acc / num_batches
+         }
+
+         return metrics
+
+     def train(self, train_loader, val_loader):
+         """
+         Full training loop.
+
+         Args:
+             train_loader (DataLoader): Training data loader
+             val_loader (DataLoader): Validation data loader
+         """
+         best_val_loss = float('inf')
+
+         for epoch in range(self.config['epochs']):
+             print(f"\nEpoch {epoch+1}/{self.config['epochs']}")
+
+             # Training phase
+             train_metrics = self.train_epoch(train_loader)
+             print(f"Training metrics: {train_metrics}")
+
+             # Validation phase
+             val_metrics = self.validate(val_loader)
+             print(f"Validation metrics: {val_metrics}")
+
+             # Save checkpoint if best so far
+             if val_metrics['val_loss'] < best_val_loss:
+                 best_val_loss = val_metrics['val_loss']
+                 self.save_checkpoint(
+                     "models/melody-gen/weights/v0.1.0/best_model.pth",
+                     epoch,
+                     train_metrics,
+                     val_metrics
+                 )
+
+             # Log epoch metrics
+             if self.config.get('use_wandb', False):
+                 wandb.log({
+                     'epoch': epoch,
+                     **train_metrics,
+                     **val_metrics
+                 })
+
+     def save_checkpoint(self, path, epoch, train_metrics, val_metrics):
+         """
+         Save model checkpoint.
+
+         Args:
+             path (str): Path to save checkpoint
+             epoch (int): Current epoch
+             train_metrics (dict): Training metrics
+             val_metrics (dict): Validation metrics
+         """
+         checkpoint = {
+             'epoch': epoch,
+             'model_state_dict': self.model.state_dict(),
+             'optimizer_state_dict': self.optimizer.state_dict(),
+             'scheduler_state_dict': self.scheduler.state_dict(),
+             'train_metrics': train_metrics,
+             'val_metrics': val_metrics,
+             'config': self.config
+         }
+
+         torch.save(checkpoint, path)
+         print(f"Checkpoint saved to {path}")
+
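# ---------------------------------------------------------------------
# Illustrative wiring sketch (editor's sketch, not part of the committed cell).
# It shows one plausible way to configure MelodyTrainer; the config keys mirror
# the lookups in __init__ (learning_rate, weight_decay, epochs, steps_per_epoch,
# use_wandb). The dataset paths follow section 9 below; batch size, epoch count,
# and learning rate are placeholder values, and MelodyDataset is assumed to yield
# fixed-length 'notes'/'durations' tensors that default collation can batch.
def _example_training_setup():
    train_dataset = MelodyDataset('datasets/train')
    val_dataset = MelodyDataset('datasets/val')
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32)

    training_config = {
        'learning_rate': 3e-4,
        'weight_decay': 0.01,
        'epochs': 10,
        'steps_per_epoch': len(train_loader),
        'use_wandb': False,
    }
    trainer = MelodyTrainer(MelodyTransformer(), training_config, device='cpu')
    trainer.train(train_loader, val_loader)
# ---------------------------------------------------------------------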
+ # =====================================
+ # 4. Evaluation Functions
+ # =====================================
+
+ class MelodyEvaluator:
+     """
+     Comprehensive evaluation suite for melody generation models.
+
+     Features:
+     - Note accuracy metrics
+     - Musical quality assessment
+     - Style consistency checking
+     - Sample generation and analysis
+
+     Args:
+         model (MelodyTransformer): Trained model to evaluate
+         device (str): Device to run evaluation on
+     """
+
+     def __init__(self, model, device='cuda'):
+         self.model = model.to(device)
+         self.device = device
+         self.model.eval()  # Set model to evaluation mode
+
+     def evaluate_metrics(self, test_loader):
+         """
+         Compute quantitative metrics on the test set.
+
+         Args:
+             test_loader (DataLoader): Test data loader
+
+         Returns:
+             dict: Dictionary of evaluation metrics
+         """
+         metrics = {
+             'note_accuracy': 0,
+             'rhythm_accuracy': 0,
+             'sequence_coherence': 0,
+             'scale_consistency': 0
+         }
+
+         num_batches = 0
+
+         with torch.no_grad():
+             for batch in tqdm(test_loader, desc="Evaluating"):
+                 notes = batch['notes'].to(self.device)
+                 durations = batch['durations'].to(self.device)
+                 positions = torch.arange(notes.size(1)).unsqueeze(0).expand(
+                     notes.size(0), -1).to(self.device)
+
+                 # Get model predictions
+                 note_logits, duration_logits = self.model(notes, durations, positions)
+
+                 # Calculate basic accuracy
+                 note_preds = note_logits.argmax(dim=-1)
+                 dur_preds = duration_logits.argmax(dim=-1)
+
+                 metrics['note_accuracy'] += (note_preds[:, :-1] == notes[:, 1:]).float().mean().item()
+                 metrics['rhythm_accuracy'] += (dur_preds[:, :-1] == durations[:, 1:]).float().mean().item()
+
+                 # Calculate musical coherence metrics
+                 metrics['sequence_coherence'] += self._calculate_coherence(note_preds)
+                 metrics['scale_consistency'] += self._check_scale_consistency(note_preds)
+
+                 num_batches += 1
+
+         # Average metrics
+         for key in metrics:
+             metrics[key] /= num_batches
+
+         return metrics
+
+     def _calculate_coherence(self, note_sequence):
+         """
+         Calculate musical coherence score.
+
+         Checks for:
+         - Melodic intervals (steps vs leaps)
+         - Phrase structure
+         - Repetition patterns
+
+         Args:
+             note_sequence (torch.Tensor): Predicted note sequence
+
+         Returns:
+             float: Coherence score between 0 and 1
+         """
+         # Convert to numpy for analysis
+         notes = note_sequence.cpu().numpy()
+
+         # Calculate interval distribution
+         intervals = np.diff(notes, axis=1)
+         step_ratio = np.mean(np.abs(intervals) <= 2)  # Proportion of stepwise motion
+
+         # Check for phrase repetition (_analyze_phrases is left as a template stub)
+         phrase_score = self._analyze_phrases(notes)
+
+         # Combine metrics
+         coherence_score = 0.6 * step_ratio + 0.4 * phrase_score
+         return coherence_score
+
+     def _check_scale_consistency(self, note_sequence):
+         """
+         Check if generated notes follow consistent scale patterns.
+
+         Args:
+             note_sequence (torch.Tensor): Predicted note sequence
+
+         Returns:
+             float: Scale consistency score between 0 and 1
+         """
+         notes = note_sequence.cpu().numpy()
+
+         # Create pitch class histogram
+         pitch_classes = notes % 12
+         histogram = np.bincount(pitch_classes.flatten(), minlength=12)
+
+         # Check against common scales
+         major_scale = np.array([1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1])
+         minor_scale = np.array([1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0])
+
+         # Calculate consistency scores
+         major_score = np.sum((histogram > 0) == major_scale) / 12
+         minor_score = np.sum((histogram > 0) == minor_scale) / 12
+
+         return max(major_score, minor_score)
+
+     def generate_and_evaluate_samples(self, num_samples=10, max_length=512):
+         """
+         Generate and evaluate multiple melody samples.
+
+         Args:
+             num_samples (int): Number of samples to generate
+             max_length (int): Maximum length of each sample
+
+         Returns:
+             tuple: (generated_samples, evaluation_results)
+         """
+         samples = []
+         results = []
+
+         for i in range(num_samples):
+             # Generate sample
+             prompt = {
+                 'notes': [60],      # Start with middle C
+                 'durations': [1.0]  # Quarter note
+             }
+
+             notes, durations = self.model.generate(
+                 prompt,
+                 max_length=max_length,
+                 temperature=0.8
+             )
+
+             # Evaluate sample
+             sample_metrics = {
+                 'melodic_range': self._calculate_melodic_range(notes),
+                 'rhythm_variety': self._calculate_rhythm_variety(durations),
+                 'musical_coherence': self._evaluate_musical_qualities(notes, durations)
+             }
+
+             samples.append({'notes': notes, 'durations': durations})
+             results.append(sample_metrics)
+
+             # Save generated sample
+             self._save_sample(
+                 notes,
+                 durations,
+                 f"models/melody-gen/examples/generated_samples/sample_{i+1}.mid"
+             )
+
+         return samples, results
+
+     def _calculate_melodic_range(self, notes):
+         """
+         Calculate the melodic range and distribution.
+
+         Args:
+             notes (list): List of MIDI note numbers
+
+         Returns:
+             dict: Melodic range statistics
+         """
+         return {
+             'range': max(notes) - min(notes),
+             'mean': np.mean(notes),
+             'std': np.std(notes)
+         }
+
+     def _calculate_rhythm_variety(self, durations):
+         """
+         Analyze rhythm patterns and variety.
+
+         Args:
+             durations (list): List of note durations
+
+         Returns:
+             dict: Rhythm statistics
+         """
+         return {
+             'unique_values': len(set(durations)),
+             'variance': np.var(durations),
+             'pattern_complexity': len(set(zip(durations[:-1], durations[1:])))
+         }
+
+     def _evaluate_musical_qualities(self, notes, durations):
+         """
+         Evaluate musical qualities of the generated melody.
+
+         Checks for:
+         - Phrase structure
+         - Melodic contour
+         - Rhythmic patterns
+         - Musical tension and resolution
+
+         Args:
+             notes (list): List of MIDI note numbers
+             durations (list): List of note durations
+
+         Returns:
+             dict: Musical quality metrics
+         """
+         # Convert to a music21 stream for analysis
+         # (the analysis helpers below are left as template stubs)
+         stream = self._create_music21_stream(notes, durations)
+
+         return {
+             'phrase_structure': self._analyze_phrases(stream),
+             'melodic_contour': self._analyze_contour(notes),
+             'rhythmic_complexity': self._analyze_rhythm(durations),
+             'tension_resolution': self._analyze_tension(notes)
+         }
+
+     def _save_sample(self, notes, durations, filepath):
+         """
+         Save a generated sample as a MIDI file.
+
+         Args:
+             notes (list): List of MIDI note numbers
+             durations (list): List of note durations
+             filepath (str): Path to save MIDI file
+         """
+         stream = music21.stream.Stream()
+
+         for note, duration in zip(notes, durations):
+             n = music21.note.Note(note)
+             n.duration = music21.duration.Duration(duration)
+             stream.append(n)
+
+         stream.write('midi', fp=filepath)
+
+     def generate_evaluation_report(self, test_loader):
+         """
+         Generate a comprehensive evaluation report.
+
+         Args:
+             test_loader (DataLoader): Test data loader
+
+         Returns:
+             dict: Complete evaluation report
+         """
+         # Basic metrics
+         metrics = self.evaluate_metrics(test_loader)
+
+         # Generate and evaluate samples
+         samples, sample_results = self.generate_and_evaluate_samples()
+
+         # Compile complete report
+         report = {
+             'quantitative_metrics': metrics,
+             'sample_evaluations': sample_results,
+             'generation_timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
+             'model_version': '0.1.0'
+         }
+
+         # Save report
+         with open('models/melody-gen/examples/evaluation_report.json', 'w') as f:
+             json.dump(report, f, indent=2)
+
+         return report
+
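# ---------------------------------------------------------------------
# Illustrative check of the scale-consistency idea (editor's sketch, not part
# of the committed cell). A toy C-major fragment activates exactly the seven
# pitch classes in the major_scale template used by _check_scale_consistency,
# so the score comes out as 1.0. Assumes numpy is imported as np in an earlier cell.
def _example_scale_histogram():
    toy_notes = np.array([[60, 62, 64, 65, 67, 69, 71, 72]])  # C major scale
    pitch_classes = toy_notes % 12
    histogram = np.bincount(pitch_classes.flatten(), minlength=12)
    major_scale = np.array([1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1])
    score = np.sum((histogram > 0) == major_scale) / 12
    print(score)  # -> 1.0 for this fragment
# ---------------------------------------------------------------------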
+ # =====================================
+ # 5. Generation and Inference
+ # =====================================
+
+ class MelodyGenerator:
+     """
+     High-level interface for generating melodies with a trained model.
+
+     Features:
+     - Text-to-melody generation
+     - Style conditioning
+     - Batch generation
+     - Format conversion and export
+
+     Args:
+         model (MelodyTransformer): Trained model
+         device (str): Device to run generation on
+         config (dict): Generation parameters
+     """
+
+     def __init__(self, model, device='cuda', config=None):
+         self.model = model.to(device)
+         self.device = device
+         self.model.eval()
+
+         # Default generation config
+         self.config = {
+             'temperature': 0.8,
+             'max_length': 512,
+             'top_k': 50,
+             'top_p': 0.95,
+             'repetition_penalty': 1.2
+         }
+         if config:
+             self.config.update(config)
+
+     def generate_from_prompt(self, prompt, style=None):
+         """
+         Generate a melody from a text prompt.
+
+         Args:
+             prompt (str): Text description of desired melody
+             style (dict, optional): Style parameters
+                 {
+                     'genre': 'pop/jazz/classical',
+                     'tempo': beats per minute,
+                     'mood': 'happy/sad/energetic'
+                 }
+
+         Returns:
+             dict: Generated melody information
+                 {
+                     'notes': List of MIDI notes,
+                     'durations': List of note durations,
+                     'midi_path': Path to saved MIDI file,
+                     'metadata': Generation metadata
+                 }
+         """
+         # Process prompt and style
+         generation_params = self._prepare_generation_params(prompt, style)
+
+         with torch.no_grad():
+             # Initialize sequence with start token (integer indices for the embeddings)
+             current_notes = torch.tensor([[60]]).to(self.device)     # Middle C
+             current_durations = torch.tensor([[1]]).to(self.device)  # Quarter note
+
+             generated_notes = []
+             generated_durations = []
+
+             # Generate sequence
+             for i in range(self.config['max_length']):
+                 # Get position encoding
+                 position = torch.arange(current_notes.size(1)).unsqueeze(0).to(self.device)
+
+                 # Get predictions
+                 note_logits, duration_logits = self.model(
+                     current_notes,
+                     current_durations,
+                     position
+                 )
+
+                 # Apply temperature and sampling strategies
+                 next_note = self._sample_from_logits(
+                     note_logits[:, -1],
+                     temperature=generation_params['temperature'],
+                     top_k=generation_params['top_k'],
+                     top_p=generation_params['top_p']
+                 )
+
+                 next_duration = self._sample_from_logits(
+                     duration_logits[:, -1],
+                     temperature=generation_params['temperature']
+                 )
+
+                 # Apply repetition penalty (helper left as a template stub)
+                 if len(generated_notes) > 0:
+                     next_note = self._apply_repetition_penalty(
+                         next_note,
+                         generated_notes,
+                         generation_params['repetition_penalty']
+                     )
+
+                 # Append to sequences
+                 generated_notes.append(next_note.item())
+                 generated_durations.append(next_duration.item())
+
+                 # Update input sequences
+                 current_notes = torch.tensor([generated_notes]).to(self.device)
+                 current_durations = torch.tensor([generated_durations]).to(self.device)
+
+                 # Check for end condition (helper left as a template stub)
+                 if self._check_end_condition(generated_notes, generated_durations):
+                     break
+
+         # Post-process and save
+         return self._post_process_and_save(
+             generated_notes,
+             generated_durations,
+             prompt,
+             style
+         )
+
+     def batch_generate(self, prompts, styles=None):
+         """
+         Generate multiple melodies in a batch.
+
+         Args:
+             prompts (list): List of text prompts
+             styles (list, optional): List of style parameters
+
+         Returns:
+             list: List of generated melodies
+         """
+         results = []
+         for i, prompt in enumerate(prompts):
+             style = styles[i] if styles else None
+             result = self.generate_from_prompt(prompt, style)
+             results.append(result)
+         return results
+
+     def _prepare_generation_params(self, prompt, style):
+         """
+         Prepare generation parameters based on prompt and style.
+
+         Args:
+             prompt (str): Text prompt
+             style (dict): Style parameters
+
+         Returns:
+             dict: Generation parameters
+         """
+         params = self.config.copy()
+
+         # Adjust parameters based on style
+         if style:
+             if style.get('genre') == 'classical':
+                 params['temperature'] *= 0.9  # More conservative
+                 params['repetition_penalty'] *= 1.1
+             elif style.get('genre') == 'jazz':
+                 params['temperature'] *= 1.1  # More experimental
+                 params['top_k'] = int(params['top_k'] * 1.2)  # Keep top_k an integer
+
+             if style.get('mood') == 'energetic':
+                 params['temperature'] *= 1.1
+             elif style.get('mood') == 'calm':
+                 params['temperature'] *= 0.9
+
+         return params
+
+     def _sample_from_logits(self, logits, temperature=1.0, top_k=None, top_p=None):
+         """
+         Sample from logits with temperature and optional top-k/top-p filtering.
+
+         Args:
+             logits (torch.Tensor): Raw logits
+             temperature (float): Sampling temperature
+             top_k (int, optional): Top-k filtering parameter
+             top_p (float, optional): Nucleus sampling parameter
+
+         Returns:
+             torch.Tensor: Sampled token
+         """
+         logits = logits / temperature
+
+         # Top-k filtering
+         if top_k is not None:
+             indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+             logits[indices_to_remove] = float('-inf')
+
+         # Top-p filtering (nucleus sampling)
+         if top_p is not None:
+             sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+             cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+             sorted_indices_to_remove = cumulative_probs > top_p
+             sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+             sorted_indices_to_remove[..., 0] = 0
+
+             indices_to_remove = sorted_indices_to_remove.scatter(
+                 dim=-1,
+                 index=sorted_indices,
+                 src=sorted_indices_to_remove
+             )
+             logits[indices_to_remove] = float('-inf')
+
+         # Sample
+         probs = F.softmax(logits, dim=-1)
+         return torch.multinomial(probs, 1)
+
+     def _post_process_and_save(self, notes, durations, prompt, style):
+         """
+         Post-process and save the generated melody.
+
+         Args:
+             notes (list): Generated notes
+             durations (list): Generated durations
+             prompt (str): Original prompt
+             style (dict): Style parameters
+
+         Returns:
+             dict: Generation results and metadata
+         """
+         # Create timestamp
+         timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
+
+         # Create MIDI file (via the shared helper defined in MelodyUtils below)
+         midi_path = f"models/melody-gen/examples/generated_samples/melody_{timestamp}.mid"
+         MelodyUtils.save_to_midi(notes, durations, midi_path)
+
+         # Prepare metadata
+         metadata = {
+             'timestamp': timestamp,
+             'prompt': prompt,
+             'style': style,
+             'generation_params': self.config,
+             'stats': {
+                 'length': len(notes),
+                 'pitch_range': max(notes) - min(notes),
+                 'unique_pitches': len(set(notes)),
+                 'unique_durations': len(set(durations))
+             }
+         }
+
+         # Save metadata
+         metadata_path = f"models/melody-gen/examples/generated_samples/melody_{timestamp}.json"
+         with open(metadata_path, 'w') as f:
+             json.dump(metadata, f, indent=2)
+
+         return {
+             'notes': notes,
+             'durations': durations,
+             'midi_path': midi_path,
+             'metadata': metadata
+         }
+
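# ---------------------------------------------------------------------
# Illustrative top-k filtering on toy logits (editor's sketch, not part of the
# committed cell). It mirrors the first filtering step in _sample_from_logits:
# everything outside the k most likely tokens is masked to -inf before the
# temperature-scaled softmax and multinomial draw. The toy values are placeholders.
def _example_top_k_sampling():
    logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])
    top_k = 2
    cutoff = torch.topk(logits, top_k)[0][..., -1, None]       # smallest kept logit
    filtered = logits.masked_fill(logits < cutoff, float('-inf'))
    probs = F.softmax(filtered / 0.8, dim=-1)                   # temperature 0.8
    return torch.multinomial(probs, 1)                          # index 0 or 1 only
# ---------------------------------------------------------------------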
+ # =====================================
+ # 6. Utility Functions and Helpers
+ # =====================================
+
+ class MelodyUtils:
+     """
+     Utility functions for melody processing and manipulation.
+     """
+
+     @staticmethod
+     def save_to_midi(notes, durations, path):
+         """
+         Save a melody to a MIDI file with enhanced musical properties.
+
+         Args:
+             notes (list): MIDI note numbers
+             durations (list): Note durations
+             path (str): Output path
+         """
+         stream = music21.stream.Stream()
+
+         # Add time signature and tempo
+         stream.append(music21.meter.TimeSignature('4/4'))
+         stream.append(music21.tempo.MetronomeMark(number=120))
+
+         # Add notes with velocity for dynamics
+         for note, duration in zip(notes, durations):
+             n = music21.note.Note(note)
+             n.duration = music21.duration.Duration(duration)
+             # Add velocity (dynamics) based on position in phrase
+             n.volume.velocity = MelodyUtils._calculate_velocity(note, notes)
+             stream.append(n)
+
+         stream.write('midi', fp=path)
+
+     @staticmethod
+     def _calculate_velocity(note, notes_sequence):
+         """Calculate an appropriate velocity for musical expression."""
+         base_velocity = 64
+         # Emphasize phrase beginnings and high points
+         if note == max(notes_sequence):
+             return min(base_velocity + 32, 127)
+         return base_velocity
+
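# ---------------------------------------------------------------------
# Illustrative MIDI export (editor's sketch, not part of the committed cell).
# It writes a short toy phrase through MelodyUtils.save_to_midi above; the
# output filename is a placeholder and music21 is assumed to be imported in
# an earlier cell.
def _example_midi_export():
    toy_notes = [60, 62, 64, 67]          # MIDI note numbers
    toy_durations = [1.0, 1.0, 0.5, 1.5]  # quarter-note lengths
    MelodyUtils.save_to_midi(toy_notes, toy_durations, 'example_phrase.mid')
# ---------------------------------------------------------------------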
+ # =====================================
+ # 7. Enhanced Generation Features
+ # =====================================
+
+ class EnhancedMelodyGenerator(MelodyGenerator):
+     """
+     Extended melody generator with additional features.
+     """
+
+     def generate_with_structure(self, prompt, form="AABA"):
+         """
+         Generate a melody with a specific musical form.
+
+         Args:
+             prompt (str): Text prompt
+             form (str): Musical form (e.g., "AABA", "ABAC")
+
+         Returns:
+             dict: Generated melody with structural sections
+         """
+         sections = {}
+         full_melody = []
+
+         for section in form:
+             if section not in sections:
+                 # Generate new section
+                 result = self.generate_from_prompt(
+                     prompt + f" for section {section}",
+                     {'section': section}
+                 )
+                 sections[section] = (result['notes'], result['durations'])
+
+             # Add section to full melody
+             notes, durations = sections[section]
+             full_melody.extend(zip(notes, durations))
+
+         # Post-processing helper left as a template stub
+         return self._post_process_structured_melody(full_melody, form)
+
+     def generate_with_harmony(self, prompt, chord_progression=None):
+         """
+         Generate a melody with harmonic constraints.
+
+         Args:
+             prompt (str): Text prompt
+             chord_progression (list): Optional chord progression
+
+         Returns:
+             dict: Generated melody with harmonic context
+         """
+         if chord_progression is None:
+             # Progression helper left as a template stub
+             chord_progression = self._generate_chord_progression()
+
+         # Generate melody considering harmony
+         generation_params = self._prepare_generation_params(prompt, {
+             'harmony': chord_progression
+         })
+
+         return self.generate_from_prompt(prompt, generation_params)
+
+ # =====================================
+ # 8. Example Usage Scenarios
+ # =====================================
+
+ def example_usage():
+     """Example usage of the melody generation system (assumes a trained `model` is in scope)."""
+
+     # 1. Basic melody generation
+     generator = MelodyGenerator(model)
+     result = generator.generate_from_prompt(
+         "Create an upbeat pop melody in C major"
+     )
+
+     # 2. Style-conditional generation
+     styled_result = generator.generate_from_prompt(
+         "Create a jazz melody",
+         style={
+             'genre': 'jazz',
+             'tempo': 120,
+             'mood': 'energetic'
+         }
+     )
+
+     # 3. Structured generation
+     enhanced_generator = EnhancedMelodyGenerator(model)
+     structured_result = enhanced_generator.generate_with_structure(
+         "Create a memorable melody",
+         form="AABA"
+     )
+
+     # 4. Batch generation
+     prompts = [
+         "Happy birthday song style",
+         "Sad emotional melody",
+         "Energetic dance tune"
+     ]
+     batch_results = generator.batch_generate(prompts)
+
+     # 5. Generation with harmony
+     harmonic_result = enhanced_generator.generate_with_harmony(
+         "Create a melody",
+         chord_progression=["C", "Am", "F", "G"]
+     )
+
+     return {
+         'basic': result,
+         'styled': styled_result,
+         'structured': structured_result,
+         'batch': batch_results,
+         'harmonic': harmonic_result
+     }
+
+ # =====================================
+ # 9. Integration Example
+ # =====================================
+
+ def run_complete_pipeline():
+     """
+     Complete pipeline from training to generation.
+     """
+     # 1. Load configuration
+     with open('models/melody-gen/config/model_config.json') as f:
+         model_config = json.load(f)
+
+     # 2. Initialize model
+     model = MelodyTransformer(**model_config)
+
+     # 3. Load datasets and wrap them in data loaders
+     #    (the trainer and evaluator expect DataLoaders; batch_size here is a placeholder)
+     train_dataset = MelodyDataset('datasets/train')
+     val_dataset = MelodyDataset('datasets/val')
+     test_dataset = MelodyDataset('datasets/test')
+     train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
+     val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32)
+     test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32)
+
+     # 4. Training
+     trainer = MelodyTrainer(model, model_config)
+     trainer.train(train_loader, val_loader)
+
+     # 5. Evaluation
+     evaluator = MelodyEvaluator(model)
+     eval_results = evaluator.generate_evaluation_report(test_loader)
+
+     # 6. Generation
+     generator = MelodyGenerator(model)
+     samples = generator.generate_from_prompt(
+         "Create an original melody",
+         style={'genre': 'pop', 'mood': 'happy'}
+     )
+
+     return {
+         'evaluation': eval_results,
+         'samples': samples
+     }
+
+ if __name__ == "__main__":
+     # Run example usage
+     results = example_usage()
+
+     # Run complete pipeline
+     pipeline_results = run_complete_pipeline()