import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torchmetrics import Accuracy, Precision, Recall, F1Score
from transformers import Wav2Vec2Model, Wav2Vec2ForSequenceClassification

from models.lora import LinearWithLoRA, LoRALayer
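
# NOTE: models/lora.py is not shown here. The reference classes below are a minimal
# sketch of the interface apply_lora() assumes: a LoRALayer holding the low-rank A/B
# matrices and a LinearWithLoRA wrapper that adds the low-rank update to a frozen
# nn.Linear. They are illustrative only and use underscored names so they do not
# shadow the real imports above.
class _ReferenceLoRALayer(torch.nn.Module):
    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        # Low-rank factors: A projects down to `rank`, B projects back up.
        # B starts at zero so the adapter initially leaves the base layer unchanged.
        self.A = torch.nn.Parameter(torch.randn(in_dim, rank) * (1.0 / rank ** 0.5))
        self.B = torch.nn.Parameter(torch.zeros(rank, out_dim))
        self.alpha = alpha

    def forward(self, x):
        # Scaled low-rank update: alpha * (x @ A @ B)
        return self.alpha * (x @ self.A @ self.B)


class _ReferenceLinearWithLoRA(torch.nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear  # frozen pretrained projection
        self.lora = _ReferenceLoRALayer(linear.in_features, linear.out_features, rank, alpha)

    def forward(self, x):
        # Frozen base output plus trainable low-rank correction
        return self.linear(x) + self.lora(x)
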

class Wav2Vec2Classifier(pl.LightningModule):
    def __init__(self, num_classes, optimizer_cfg=None, l1_lambda=0.0):
        super().__init__()
        self.save_hyperparameters()

        # Wav2Vec2 backbone (kept frozen; only the classification head is trained)
        # self.wav2vec2 = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
        self.wav2vec2 = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-large-xlsr-53")
        for param in self.wav2vec2.parameters():
            param.requires_grad = False

        # Classification head
        self.classifier = torch.nn.Linear(self.wav2vec2.config.hidden_size, num_classes)

        # Metrics
        self.accuracy = Accuracy(task="multiclass", num_classes=num_classes)
        self.precision = Precision(task="multiclass", num_classes=num_classes)
        self.recall = Recall(task="multiclass", num_classes=num_classes)
        self.f1 = F1Score(task="multiclass", num_classes=num_classes)

        self.l1_lambda = l1_lambda

        # Build the optimizer from a config object exposing .name, .lr and .weight_decay
        if optimizer_cfg is not None:
            optimizer_name = optimizer_cfg.name
            optimizer_lr = optimizer_cfg.lr
            optimizer_weight_decay = optimizer_cfg.weight_decay
            if optimizer_name == "Adam":
                self.optimizer = torch.optim.Adam(
                    self.parameters(), lr=optimizer_lr, weight_decay=optimizer_weight_decay
                )
            elif optimizer_name == "SGD":
                self.optimizer = torch.optim.SGD(
                    self.parameters(), lr=optimizer_lr, weight_decay=optimizer_weight_decay
                )
            else:
                raise ValueError(f"Unsupported optimizer: {optimizer_name}")
        else:
            self.optimizer = None

    def forward(self, x, attention_mask=None):
        # Ensure input shape is [batch_size, sequence_length]
        if x.dim() > 2:
            x = x.squeeze(-1)  # Remove a trailing singleton dimension if present

        # Pass through the Wav2Vec2 backbone
        output = self.wav2vec2(x, attention_mask=attention_mask)
        x = output.last_hidden_state

        # Mean-pool over time, then classify
        x = torch.mean(x, dim=1)
        logits = self.classifier(x)
        return logits

    def training_step(self, batch, batch_idx):
        x, attention_mask, y = batch

        # Forward pass and cross-entropy loss
        logits = self(x, attention_mask=attention_mask)
        loss = F.cross_entropy(logits, y)

        # Add L1 regularization if specified
        if self.l1_lambda > 0:
            l1_norm = sum(param.abs().sum() for param in self.parameters())
            loss = loss + self.l1_lambda * l1_norm

        self.log("train_loss", loss, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, attention_mask, y = batch

        # Forward pass, loss and metrics
        logits = self(x, attention_mask=attention_mask)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        accuracy = self.accuracy(preds, y)
        precision = self.precision(preds, y)
        recall = self.recall(preds, y)
        f1 = self.f1(preds, y)

        # Log metrics
        self.log("val_loss", loss, prog_bar=True, logger=True)
        self.log("val_acc", accuracy, prog_bar=True, logger=True)
        self.log("val_precision", precision, prog_bar=True, logger=True)
        self.log("val_recall", recall, prog_bar=True, logger=True)
        self.log("val_f1", f1, prog_bar=True, logger=True)
        return loss

    def test_step(self, batch, batch_idx):
        x, attention_mask, y = batch

        # Forward pass, loss and metrics
        logits = self(x, attention_mask=attention_mask)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        accuracy = self.accuracy(preds, y)
        precision = self.precision(preds, y)
        recall = self.recall(preds, y)
        f1 = self.f1(preds, y)

        # Log metrics
        self.log("test_loss", loss, prog_bar=True, logger=True)
        self.log("test_acc", accuracy, prog_bar=True, logger=True)
        self.log("test_precision", precision, prog_bar=True, logger=True)
        self.log("test_recall", recall, prog_bar=True, logger=True)
        self.log("test_f1", f1, prog_bar=True, logger=True)
        return {"test_loss": loss, "test_accuracy": accuracy}

    def configure_optimizers(self):
        # Assumes an optimizer_cfg was provided at construction time
        optimizer = self.optimizer
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.2, patience=20, min_lr=5e-5
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"},
        }

class Wav2Vec2EmotionClassifier(pl.LightningModule):
    def __init__(self, num_classes, learning_rate=1e-4, freeze_base=False, optimizer_cfg=None):
        super().__init__()
        self.save_hyperparameters()

        # Load a pre-trained Wav2Vec2 model fine-tuned for emotion recognition
        self.model = Wav2Vec2ForSequenceClassification.from_pretrained(
            "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim",
            num_labels=num_classes,
        )

        # Optionally freeze the Wav2Vec2 base layers
        if freeze_base:
            for param in self.model.wav2vec2.parameters():
                param.requires_grad = False

        # Metrics
        self.accuracy = Accuracy(task="multiclass", num_classes=num_classes)
        self.precision = Precision(task="multiclass", num_classes=num_classes)
        self.recall = Recall(task="multiclass", num_classes=num_classes)
        self.f1 = F1Score(task="multiclass", num_classes=num_classes)

        self.learning_rate = learning_rate

        # Apply LoRA before building the optimizer so the adapter parameters are optimized
        low_rank = 8
        lora_alpha = 16
        self.apply_lora(low_rank, lora_alpha)

        # Build the optimizer from a dict-like config with 'name', 'lr' and 'weight_decay'
        if optimizer_cfg is not None:
            optimizer_name = optimizer_cfg["name"]
            optimizer_lr = optimizer_cfg["lr"]
            optimizer_weight_decay = optimizer_cfg["weight_decay"]
            if optimizer_name == "Adam":
                self.optimizer = torch.optim.Adam(
                    self.parameters(), lr=optimizer_lr, weight_decay=optimizer_weight_decay
                )
            elif optimizer_name == "SGD":
                self.optimizer = torch.optim.SGD(
                    self.parameters(), lr=optimizer_lr, weight_decay=optimizer_weight_decay
                )
            elif optimizer_name == "AdamW":
                self.optimizer = torch.optim.AdamW(
                    self.parameters(), lr=optimizer_lr, weight_decay=optimizer_weight_decay
                )
            else:
                raise ValueError(f"Unsupported optimizer: {optimizer_name}")
        else:
            self.optimizer = None

    def apply_lora(self, rank, alpha):
        # Wrap the attention and feed-forward projections of every encoder layer with LoRA adapters
        for layer in self.model.wav2vec2.encoder.layers:
            layer.attention.q_proj = LinearWithLoRA(layer.attention.q_proj, rank, alpha)
            layer.attention.k_proj = LinearWithLoRA(layer.attention.k_proj, rank, alpha)
            layer.attention.v_proj = LinearWithLoRA(layer.attention.v_proj, rank, alpha)
            layer.attention.out_proj = LinearWithLoRA(layer.attention.out_proj, rank, alpha)
            layer.feed_forward.intermediate_dense = LinearWithLoRA(layer.feed_forward.intermediate_dense, rank, alpha)
            layer.feed_forward.output_dense = LinearWithLoRA(layer.feed_forward.output_dense, rank, alpha)

    def state_dict(self, *args, **kwargs):
        # Save only the LoRA and classifier/projector parameters
        state = super().state_dict(*args, **kwargs)
        return {k: v for k, v in state.items() if "lora" in k or "classifier" in k or "projector" in k}

    def load_state_dict(self, state_dict, strict=True):
        # Checkpoints contain only the LoRA/head parameters, so always load non-strictly
        missing_keys, unexpected_keys = super().load_state_dict(state_dict, strict=False)
        if missing_keys or unexpected_keys:
            print(f"Missing keys: {missing_keys}")
            print(f"Unexpected keys: {unexpected_keys}")
        return missing_keys, unexpected_keys

    def forward(self, x, attention_mask=None):
        return self.model(x, attention_mask=attention_mask).logits

    def training_step(self, batch, batch_idx):
        x, attention_mask, y = batch

        # Forward pass and cross-entropy loss
        logits = self(x, attention_mask=attention_mask)
        loss = F.cross_entropy(logits, y)

        self.log("train_loss", loss, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, attention_mask, y = batch

        # Forward pass, loss and metrics
        logits = self(x, attention_mask=attention_mask)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        accuracy = self.accuracy(preds, y)
        precision = self.precision(preds, y)
        recall = self.recall(preds, y)
        f1 = self.f1(preds, y)

        # Log metrics
        self.log("val_loss", loss, prog_bar=True, logger=True)
        self.log("val_acc", accuracy, prog_bar=True, logger=True)
        self.log("val_precision", precision, prog_bar=True, logger=True)
        self.log("val_recall", recall, prog_bar=True, logger=True)
        self.log("val_f1", f1, prog_bar=True, logger=True)
        return loss

    def test_step(self, batch, batch_idx):
        x, attention_mask, y = batch

        # Forward pass, loss and metrics
        logits = self(x, attention_mask=attention_mask)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        accuracy = self.accuracy(preds, y)
        precision = self.precision(preds, y)
        recall = self.recall(preds, y)
        f1 = self.f1(preds, y)

        # Log metrics
        self.log("test_loss", loss, prog_bar=True, logger=True)
        self.log("test_acc", accuracy, prog_bar=True, logger=True)
        self.log("test_precision", precision, prog_bar=True, logger=True)
        self.log("test_recall", recall, prog_bar=True, logger=True)
        self.log("test_f1", f1, prog_bar=True, logger=True)
        return {"test_loss": loss, "test_accuracy": accuracy}

    def configure_optimizers(self):
        # Assumes an optimizer_cfg was provided at construction time
        optimizer = self.optimizer
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.2, patience=20, min_lr=5e-5
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"},
        }