import pytorch_lightning as pl import torch from torchmetrics import Accuracy, Precision, Recall, F1Score from transformers import Wav2Vec2Model, Wav2Vec2ForSequenceClassification import torch.nn.functional as F from models.lora import LinearWithLoRA, LoRALayer class Wav2Vec2Classifier(pl.LightningModule): def __init__(self, num_classes, optimizer_cfg = "Adam", l1_lambda=0.0): super(Wav2Vec2Classifier, self).__init__() self.save_hyperparameters() # Wav2Vec2 backbone # self.wav2vec2 = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h") self.wav2vec2 = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-large-xlsr-53") # trying without the need to fine tune it for param in self.wav2vec2.parameters(): param.requires_grad = False # Classification head self.classifier = torch.nn.Linear(self.wav2vec2.config.hidden_size, num_classes) # Metrics self.accuracy = Accuracy(task="multiclass", num_classes=num_classes) self.precision = Precision(task="multiclass", num_classes=num_classes) self.recall = Recall(task="multiclass", num_classes=num_classes) self.f1 = F1Score(task="multiclass", num_classes=num_classes) self.l1_lambda = l1_lambda if optimizer_cfg is not None: optimizer_name = optimizer_cfg.name optimizer_lr = optimizer_cfg.lr optimizer_weight_decay = optimizer_cfg.weight_decay if optimizer_name == 'Adam': self.optimizer = torch.optim.Adam(self.parameters(), lr=optimizer_lr, weight_decay=optimizer_weight_decay) elif optimizer_name == 'SGD': self.optimizer = torch.optim.SGD(self.parameters(), lr=optimizer_lr, weight_decay=optimizer_weight_decay) else: raise ValueError(f"Unsupported optimizer: {optimizer_name}") else: self.optimizer = None def forward(self, x, attention_mask=None): # Debug input shape # Ensure input shape is [batch_size, sequence_length] if x.dim() > 2: x = x.squeeze(-1) # Remove unnecessary dimensions if present # Pass through Wav2Vec2 backbone output = self.wav2vec2(x, attention_mask=attention_mask) x = output.last_hidden_state # Classification head x = torch.mean(x, dim=1) # Pooling logits = self.classifier(x) return logits def training_step(self, batch, batch_idx): x, attention_mask, y = batch # Forward pass logits = self(x, attention_mask=attention_mask) # Compute loss loss = F.cross_entropy(logits, y) # Add L1 regularization if specified l1_norm = sum(param.abs().sum() for param in self.parameters()) loss += self.l1_lambda * l1_norm # Log metrics self.log("train_loss", loss, prog_bar=True, logger=True) return loss def validation_step(self, batch, batch_idx): x, attention_mask, y = batch # Unpack batch # Forward pass logits = self(x, attention_mask=attention_mask) # Compute loss and metrics loss = F.cross_entropy(logits, y) preds = torch.argmax(logits, dim=1) accuracy = self.accuracy(preds, y) precision = self.precision(preds, y) recall = self.recall(preds, y) f1 = self.f1(preds, y) # Log metrics self.log("val_loss", loss, prog_bar=True, logger=True) self.log("val_acc", accuracy, prog_bar=True, logger=True) self.log("val_precision", precision, prog_bar=True, logger=True) self.log("val_recall", recall, prog_bar=True, logger=True) self.log("val_f1", f1, prog_bar=True, logger=True) return loss def test_step(self, batch, batch_idx): x, attention_mask, y = batch # Unpack batch # Forward pass logits = self(x, attention_mask=attention_mask) # Compute loss and metrics loss = F.cross_entropy(logits, y) preds = torch.argmax(logits, dim=1) accuracy = self.accuracy(preds, y) precision = self.precision(preds, y) recall = self.recall(preds, y) f1 = self.f1(preds, y) # Log metrics self.log("test_loss", loss, prog_bar=True, logger=True) self.log("test_acc", accuracy, prog_bar=True, logger=True) self.log("test_precision", precision, prog_bar=True, logger=True) self.log("test_recall", recall, prog_bar=True, logger=True) self.log("test_f1", f1, prog_bar=True, logger=True) return {"test_loss": loss, "test_accuracy": accuracy} def configure_optimizers(self): optimizer = self.optimizer scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.2, patience=20, min_lr=5e-5) return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"} class Wav2Vec2EmotionClassifier(pl.LightningModule): def __init__(self, num_classes, learning_rate=1e-4, freeze_base=False, optimizer_cfg="AdamW"): super(Wav2Vec2EmotionClassifier, self).__init__() self.save_hyperparameters() # Load a pre-trained Wav2Vec2 model optimized for emotion recognition self.model = Wav2Vec2ForSequenceClassification.from_pretrained( "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim", num_labels=num_classes, ) # Optionally freeze the Wav2Vec2 base layers if freeze_base: for param in self.model.wav2vec2.parameters(): param.requires_grad = False # Metrics self.accuracy = Accuracy(task="multiclass", num_classes=num_classes) self.precision = Precision(task="multiclass", num_classes=num_classes) self.recall = Recall(task="multiclass", num_classes=num_classes) self.f1 = F1Score(task="multiclass", num_classes=num_classes) self.learning_rate = learning_rate if optimizer_cfg is not None: optimizer_name = optimizer_cfg['name'] optimizer_lr = optimizer_cfg['lr'] optimizer_weight_decay = optimizer_cfg['weight_decay'] if optimizer_name == 'Adam': self.optimizer = torch.optim.Adam(self.parameters(), lr=optimizer_lr, weight_decay=optimizer_weight_decay) elif optimizer_name == 'SGD': self.optimizer = torch.optim.SGD(self.parameters(), lr=optimizer_lr, weight_decay=optimizer_weight_decay) elif optimizer_name == 'AdamW': self.optimizer = torch.optim.AdamW(self.parameters(), lr=optimizer_lr, weight_decay=optimizer_weight_decay) else: raise ValueError(f"Unsupported optimizer: {optimizer_name}") else: self.optimizer = None # Apply LoRA low_rank = 8 lora_alpha = 16 self.apply_lora(low_rank, lora_alpha) def apply_lora(self, rank, alpha): # Replace specific linear layers with LinearWithLoRA for layer in self.model.wav2vec2.encoder.layers: layer.attention.q_proj = LinearWithLoRA(layer.attention.q_proj, rank, alpha) layer.attention.k_proj = LinearWithLoRA(layer.attention.k_proj, rank, alpha) layer.attention.v_proj = LinearWithLoRA(layer.attention.v_proj, rank, alpha) layer.attention.out_proj = LinearWithLoRA(layer.attention.out_proj, rank, alpha) layer.feed_forward.intermediate_dense = LinearWithLoRA(layer.feed_forward.intermediate_dense, rank, alpha) layer.feed_forward.output_dense = LinearWithLoRA(layer.feed_forward.output_dense, rank, alpha) def state_dict(self, *args, **kwargs): # Save only LoRA and classifier/projector parameters state = super().state_dict(*args, **kwargs) return {k: v for k, v in state.items() if "lora" in k or "classifier" in k or "projector" in k} def load_state_dict(self, state_dict, strict=True): missing_keys, unexpected_keys = super().load_state_dict(state_dict, strict=False) if missing_keys or unexpected_keys: print(f"Missing keys: {missing_keys}") print(f"Unexpected keys: {unexpected_keys}") def forward(self, x, attention_mask=None): return self.model(x, attention_mask=attention_mask).logits def training_step(self, batch, batch_idx): x, attention_mask, y = batch # Forward pass logits = self(x, attention_mask=attention_mask) # Compute loss loss = F.cross_entropy(logits, y) # Log training loss self.log("train_loss", loss, prog_bar=True, logger=True) return loss def validation_step(self, batch, batch_idx): x, attention_mask, y = batch # Forward pass logits = self(x, attention_mask=attention_mask) # Compute loss and metrics loss = F.cross_entropy(logits, y) preds = torch.argmax(logits, dim=1) accuracy = self.accuracy(preds, y) precision = self.precision(preds, y) recall = self.recall(preds, y) f1 = self.f1(preds, y) # Log metrics self.log("val_loss", loss, prog_bar=True, logger=True) self.log("val_acc", accuracy, prog_bar=True, logger=True) self.log("val_precision", precision, prog_bar=True, logger=True) self.log("val_recall", recall, prog_bar=True, logger=True) self.log("val_f1", f1, prog_bar=True, logger=True) return loss def test_step(self, batch, batch_idx): x, attention_mask, y = batch # Forward pass logits = self(x, attention_mask=attention_mask) # Compute loss and metrics loss = F.cross_entropy(logits, y) preds = torch.argmax(logits, dim=1) accuracy = self.accuracy(preds, y) precision = self.precision(preds, y) recall = self.recall(preds, y) f1 = self.f1(preds, y) # Log metrics self.log("test_loss", loss, prog_bar=True, logger=True) self.log("test_acc", accuracy, prog_bar=True, logger=True) self.log("test_precision", precision, prog_bar=True, logger=True) self.log("test_recall", recall, prog_bar=True, logger=True) self.log("test_f1", f1, prog_bar=True, logger=True) return {"test_loss": loss, "test_accuracy": accuracy} def configure_optimizers(self): optimizer = self.optimizer scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.2, patience=20, min_lr=5e-5) return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}