# NOTE: this file was recovered from a Hugging Face Spaces capture that showed
# "Runtime error" — the default optimizer_cfg string was dereferenced as an
# object/dict (see the optimizer-building fixes below).
| import pytorch_lightning as pl | |
| import torch | |
| from torchmetrics import Accuracy, Precision, Recall, F1Score | |
| from transformers import Wav2Vec2Model, Wav2Vec2ForSequenceClassification | |
| import torch.nn.functional as F | |
| from models.lora import LinearWithLoRA, LoRALayer | |
class Wav2Vec2Classifier(pl.LightningModule):
    """Frozen Wav2Vec2 backbone with a trainable linear classification head.

    The backbone ("facebook/wav2vec2-large-xlsr-53") is fully frozen in
    ``__init__``; only ``self.classifier`` receives gradients.

    Args:
        num_classes: number of target classes.
        optimizer_cfg: optimizer configuration. Accepts an object exposing
            ``name`` / ``lr`` / ``weight_decay`` attributes, a dict with the
            same keys, or a plain string naming the optimizer ("Adam" or
            "SGD") with default lr/weight_decay. ``None`` skips optimizer
            creation (``configure_optimizers`` will then raise).
        l1_lambda: weight of the optional L1 penalty on trainable parameters.
    """

    # Fallbacks used when optimizer_cfg is just a name string.
    _DEFAULT_LR = 1e-3
    _DEFAULT_WEIGHT_DECAY = 0.0

    def __init__(self, num_classes, optimizer_cfg="Adam", l1_lambda=0.0):
        super().__init__()
        self.save_hyperparameters()

        # Wav2Vec2 backbone, frozen: we only train the classification head.
        # (An earlier experiment used "facebook/wav2vec2-base-960h".)
        self.wav2vec2 = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-large-xlsr-53")
        for param in self.wav2vec2.parameters():
            param.requires_grad = False

        # Classification head on top of mean-pooled hidden states.
        self.classifier = torch.nn.Linear(self.wav2vec2.config.hidden_size, num_classes)

        # Metrics
        self.accuracy = Accuracy(task="multiclass", num_classes=num_classes)
        self.precision = Precision(task="multiclass", num_classes=num_classes)
        self.recall = Recall(task="multiclass", num_classes=num_classes)
        self.f1 = F1Score(task="multiclass", num_classes=num_classes)

        self.l1_lambda = l1_lambda
        self.optimizer = self._build_optimizer(optimizer_cfg)

    def _build_optimizer(self, optimizer_cfg):
        """Build the optimizer from a cfg object, dict, or plain name string.

        BUG FIX: the original unconditionally read ``optimizer_cfg.name``,
        which raised AttributeError for the default string value "Adam".

        Raises:
            ValueError: if the optimizer name is not "Adam" or "SGD".
        """
        if optimizer_cfg is None:
            return None
        if isinstance(optimizer_cfg, str):
            name = optimizer_cfg
            lr = self._DEFAULT_LR
            weight_decay = self._DEFAULT_WEIGHT_DECAY
        elif isinstance(optimizer_cfg, dict):
            name = optimizer_cfg["name"]
            lr = optimizer_cfg["lr"]
            weight_decay = optimizer_cfg["weight_decay"]
        else:  # attribute-style cfg (e.g. OmegaConf / argparse namespace)
            name = optimizer_cfg.name
            lr = optimizer_cfg.lr
            weight_decay = optimizer_cfg.weight_decay

        if name == "Adam":
            return torch.optim.Adam(self.parameters(), lr=lr, weight_decay=weight_decay)
        if name == "SGD":
            return torch.optim.SGD(self.parameters(), lr=lr, weight_decay=weight_decay)
        raise ValueError(f"Unsupported optimizer: {name}")

    def forward(self, x, attention_mask=None):
        """Return classification logits for raw waveform input.

        Args:
            x: waveform batch; extra trailing singleton dims are squeezed so
               the backbone sees [batch_size, sequence_length].
            attention_mask: optional mask passed through to Wav2Vec2.
        """
        if x.dim() > 2:
            x = x.squeeze(-1)  # drop stray channel/feature dimension
        output = self.wav2vec2(x, attention_mask=attention_mask)
        hidden = output.last_hidden_state
        pooled = torch.mean(hidden, dim=1)  # mean-pool over time
        return self.classifier(pooled)

    def training_step(self, batch, batch_idx):
        """Cross-entropy loss plus optional L1 penalty; logs train_loss."""
        x, attention_mask, y = batch
        logits = self(x, attention_mask=attention_mask)
        loss = F.cross_entropy(logits, y)
        # L1 regularization only when requested, and only over trainable
        # parameters: the frozen backbone contributes a constant (no gradient)
        # and the original wasted time summing it every step.
        if self.l1_lambda:
            l1_norm = sum(
                p.abs().sum() for p in self.parameters() if p.requires_grad
            )
            loss = loss + self.l1_lambda * l1_norm
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log validation loss and classification metrics."""
        x, attention_mask, y = batch
        logits = self(x, attention_mask=attention_mask)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.log("val_loss", loss, prog_bar=True, logger=True)
        self.log("val_acc", self.accuracy(preds, y), prog_bar=True, logger=True)
        self.log("val_precision", self.precision(preds, y), prog_bar=True, logger=True)
        self.log("val_recall", self.recall(preds, y), prog_bar=True, logger=True)
        self.log("val_f1", self.f1(preds, y), prog_bar=True, logger=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Compute and log test loss and classification metrics."""
        x, attention_mask, y = batch
        logits = self(x, attention_mask=attention_mask)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        accuracy = self.accuracy(preds, y)
        self.log("test_loss", loss, prog_bar=True, logger=True)
        self.log("test_acc", accuracy, prog_bar=True, logger=True)
        self.log("test_precision", self.precision(preds, y), prog_bar=True, logger=True)
        self.log("test_recall", self.recall(preds, y), prog_bar=True, logger=True)
        self.log("test_f1", self.f1(preds, y), prog_bar=True, logger=True)
        return {"test_loss": loss, "test_accuracy": accuracy}

    def configure_optimizers(self):
        """Return the pre-built optimizer with a plateau LR scheduler.

        Raises:
            ValueError: if the model was constructed with optimizer_cfg=None
                (the original silently handed None to the scheduler).
        """
        if self.optimizer is None:
            raise ValueError(
                "No optimizer configured: pass optimizer_cfg to __init__."
            )
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode="min", factor=0.2, patience=20, min_lr=5e-5
        )
        return {
            "optimizer": self.optimizer,
            "lr_scheduler": scheduler,
            "monitor": "val_loss",
        }
class Wav2Vec2EmotionClassifier(pl.LightningModule):
    """Wav2Vec2 emotion classifier fine-tuned with LoRA adapters.

    Wraps "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
    (Wav2Vec2ForSequenceClassification) and injects LoRA into the encoder's
    attention and feed-forward linear layers. Checkpoints persist only the
    LoRA / classifier / projector weights (see ``state_dict``).

    Args:
        num_classes: number of emotion classes.
        learning_rate: LR used when optimizer_cfg is a plain name string.
        freeze_base: if True, freeze the Wav2Vec2 base (LoRA layers added
            afterwards remain trainable).
        optimizer_cfg: object with ``name``/``lr``/``weight_decay``
            attributes, a dict with the same keys, or a plain string naming
            the optimizer ("Adam", "SGD", "AdamW"). ``None`` skips optimizer
            creation.
    """

    # LoRA hyperparameters for the encoder linear layers.
    _LORA_RANK = 8
    _LORA_ALPHA = 16

    def __init__(self, num_classes, learning_rate=1e-4, freeze_base=False, optimizer_cfg="AdamW"):
        super().__init__()
        self.save_hyperparameters()

        # Pre-trained Wav2Vec2 model targeted at emotion recognition.
        self.model = Wav2Vec2ForSequenceClassification.from_pretrained(
            "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim",
            num_labels=num_classes,
        )

        # Optionally freeze the Wav2Vec2 base layers (done before LoRA so the
        # LoRA parameters created below stay trainable).
        if freeze_base:
            for param in self.model.wav2vec2.parameters():
                param.requires_grad = False

        # Metrics
        self.accuracy = Accuracy(task="multiclass", num_classes=num_classes)
        self.precision = Precision(task="multiclass", num_classes=num_classes)
        self.recall = Recall(task="multiclass", num_classes=num_classes)
        self.f1 = F1Score(task="multiclass", num_classes=num_classes)

        self.learning_rate = learning_rate

        # BUG FIX: LoRA must be applied BEFORE the optimizer is built —
        # the original created the optimizer over self.parameters() first,
        # so the LoRA parameters were never registered and never optimized.
        self.apply_lora(self._LORA_RANK, self._LORA_ALPHA)
        self.optimizer = self._build_optimizer(optimizer_cfg)

    def _build_optimizer(self, optimizer_cfg):
        """Build the optimizer from a cfg object, dict, or plain name string.

        BUG FIX: the original indexed ``optimizer_cfg['name']``, which raised
        TypeError for the default string value "AdamW". A bare string now
        falls back to ``self.learning_rate`` (previously saved but unused).

        Raises:
            ValueError: if the optimizer name is not "Adam", "SGD", or "AdamW".
        """
        if optimizer_cfg is None:
            return None
        if isinstance(optimizer_cfg, str):
            name = optimizer_cfg
            lr = self.learning_rate
            weight_decay = 0.0
        elif isinstance(optimizer_cfg, dict):
            name = optimizer_cfg["name"]
            lr = optimizer_cfg["lr"]
            weight_decay = optimizer_cfg["weight_decay"]
        else:  # attribute-style cfg (e.g. OmegaConf / argparse namespace)
            name = optimizer_cfg.name
            lr = optimizer_cfg.lr
            weight_decay = optimizer_cfg.weight_decay

        factories = {
            "Adam": torch.optim.Adam,
            "SGD": torch.optim.SGD,
            "AdamW": torch.optim.AdamW,
        }
        try:
            factory = factories[name]
        except KeyError:
            raise ValueError(f"Unsupported optimizer: {name}") from None
        return factory(self.parameters(), lr=lr, weight_decay=weight_decay)

    def apply_lora(self, rank, alpha):
        """Wrap each encoder attention/feed-forward linear in LinearWithLoRA."""
        for layer in self.model.wav2vec2.encoder.layers:
            attn = layer.attention
            attn.q_proj = LinearWithLoRA(attn.q_proj, rank, alpha)
            attn.k_proj = LinearWithLoRA(attn.k_proj, rank, alpha)
            attn.v_proj = LinearWithLoRA(attn.v_proj, rank, alpha)
            attn.out_proj = LinearWithLoRA(attn.out_proj, rank, alpha)
            ff = layer.feed_forward
            ff.intermediate_dense = LinearWithLoRA(ff.intermediate_dense, rank, alpha)
            ff.output_dense = LinearWithLoRA(ff.output_dense, rank, alpha)

    def state_dict(self, *args, **kwargs):
        """Persist only LoRA, classifier, and projector parameters.

        The frozen pre-trained weights are reloadable from the hub, so they
        are deliberately excluded to keep checkpoints small.
        """
        state = super().state_dict(*args, **kwargs)
        return {
            k: v
            for k, v in state.items()
            if "lora" in k or "classifier" in k or "projector" in k
        }

    def load_state_dict(self, state_dict, strict=True):
        """Load a (partial) checkpoint with strict=False.

        Checkpoints contain only LoRA/head weights, so missing keys for the
        backbone are expected. BUG FIX: now returns the
        (missing_keys, unexpected_keys) result instead of discarding it.
        """
        result = super().load_state_dict(state_dict, strict=False)
        if result.missing_keys or result.unexpected_keys:
            print(f"Missing keys: {result.missing_keys}")
            print(f"Unexpected keys: {result.unexpected_keys}")
        return result

    def forward(self, x, attention_mask=None):
        """Return classification logits for raw waveform input."""
        return self.model(x, attention_mask=attention_mask).logits

    def training_step(self, batch, batch_idx):
        """Cross-entropy loss; logs train_loss."""
        x, attention_mask, y = batch
        logits = self(x, attention_mask=attention_mask)
        loss = F.cross_entropy(logits, y)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log validation loss and classification metrics."""
        x, attention_mask, y = batch
        logits = self(x, attention_mask=attention_mask)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.log("val_loss", loss, prog_bar=True, logger=True)
        self.log("val_acc", self.accuracy(preds, y), prog_bar=True, logger=True)
        self.log("val_precision", self.precision(preds, y), prog_bar=True, logger=True)
        self.log("val_recall", self.recall(preds, y), prog_bar=True, logger=True)
        self.log("val_f1", self.f1(preds, y), prog_bar=True, logger=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Compute and log test loss and classification metrics."""
        x, attention_mask, y = batch
        logits = self(x, attention_mask=attention_mask)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        accuracy = self.accuracy(preds, y)
        self.log("test_loss", loss, prog_bar=True, logger=True)
        self.log("test_acc", accuracy, prog_bar=True, logger=True)
        self.log("test_precision", self.precision(preds, y), prog_bar=True, logger=True)
        self.log("test_recall", self.recall(preds, y), prog_bar=True, logger=True)
        self.log("test_f1", self.f1(preds, y), prog_bar=True, logger=True)
        return {"test_loss": loss, "test_accuracy": accuracy}

    def configure_optimizers(self):
        """Return the pre-built optimizer with a plateau LR scheduler.

        Raises:
            ValueError: if the model was constructed with optimizer_cfg=None
                (the original silently handed None to the scheduler).
        """
        if self.optimizer is None:
            raise ValueError(
                "No optimizer configured: pass optimizer_cfg to __init__."
            )
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode="min", factor=0.2, patience=20, min_lr=5e-5
        )
        return {
            "optimizer": self.optimizer,
            "lr_scheduler": scheduler,
            "monitor": "val_loss",
        }