import torch
import torch.optim as optim
import pytorch_lightning as pl
import timm
from torchmetrics import Accuracy, Precision, Recall, F1Score


class timm_backbones(pl.LightningModule):
    """
    PyTorch Lightning module for image classification built on a pre-trained timm backbone.

    The backbone named by ``encoder`` is loaded with pre-trained weights, its first
    convolution is adapted to single-channel input (for ResNet-style stems), and its
    classification head is replaced to match ``num_classes``.

    Args:
        encoder (str, optional): Name of the timm model to load. Defaults to 'resnet18'.
        num_classes (int, optional): The number of classes in the dataset. Defaults to 2.
        optimizer_cfg (DictConfig, optional): A Hydra configuration object for the optimizer.
        l1_lambda (float, optional): Weight of the L1 penalty added to the training loss.
            Defaults to 0.0.

    Methods:
        forward(x): Computes the forward pass of the model.
        configure_optimizers(): Configures the optimizer for the model.
        training_step(batch, batch_idx): Performs a training step on the model.
        validation_step(batch, batch_idx): Performs a validation step on the model.
        on_validation_epoch_end(): Called at the end of each validation epoch.
        test_step(batch, batch_idx): Performs a test step on the model.

    Example:
        model = timm_backbones(encoder='resnet18', num_classes=2, optimizer_cfg=cfg.model.optimizer)
        trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
        trainer.test(model, dataloaders=test_dataloader)
    """

    def __init__(self, encoder='resnet18', num_classes=2, optimizer_cfg=None, l1_lambda=0.0):
        super().__init__()
        self.encoder = encoder
        self.model = timm.create_model(encoder, pretrained=True)

        # If the backbone expects 3-channel input, swap its first convolution for a
        # single-channel one. timm's default_cfg["input_size"] is (C, H, W), so the
        # channel count is index 0; this assumes a ResNet-style stem named `conv1`.
        if self.model.default_cfg["input_size"][0] == 3:
            self.model.conv1 = torch.nn.Conv2d(
                in_channels=1,  # change to single channel
                out_channels=self.model.conv1.out_channels,
                kernel_size=self.model.conv1.kernel_size,
                stride=self.model.conv1.stride,
                padding=self.model.conv1.padding,
                bias=False,
            )

        self.accuracy = Accuracy(task="multiclass", num_classes=num_classes)
        self.precision = Precision(task="multiclass", num_classes=num_classes)
        self.recall = Recall(task="multiclass", num_classes=num_classes)
        self.f1 = F1Score(task="multiclass", num_classes=num_classes)
        self.l1_lambda = l1_lambda

        # Replace the classification head, whichever attribute timm uses for this architecture.
        if hasattr(self.model, 'fc'):  # models with 'fc' as the classification layer
            in_features = self.model.fc.in_features
            self.model.fc = torch.nn.Linear(in_features, num_classes)
        elif hasattr(self.model, 'classifier'):  # models with 'classifier'
            in_features = self.model.classifier.in_features
            self.model.classifier = torch.nn.Linear(in_features, num_classes)
        elif hasattr(self.model, 'head'):  # models with 'head'
            in_features = self.model.head.in_features
            self.model.head = torch.nn.Linear(in_features, num_classes)
        else:
            raise ValueError(f"Unsupported model architecture for encoder: {encoder}")

        if optimizer_cfg is not None:
            optimizer_name = optimizer_cfg.name
            optimizer_lr = optimizer_cfg.lr
            optimizer_weight_decay = optimizer_cfg.weight_decay
            if optimizer_name == 'Adam':
                self.optimizer = optim.Adam(self.parameters(), lr=optimizer_lr,
                                            weight_decay=optimizer_weight_decay)
            elif optimizer_name == 'SGD':
                self.optimizer = optim.SGD(self.parameters(), lr=optimizer_lr,
                                           weight_decay=optimizer_weight_decay)
            else:
                raise ValueError(f"Unsupported optimizer: {optimizer_name}")
        else:
            self.optimizer = None

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        optimizer = self.optimizer
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.2, patience=20, min_lr=5e-5
        )
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}

    def training_step(self, batch, batch_idx):
        x, y = batch
        y = y.long()
        # Compute predictions and loss
        logits = self(x)
        loss = torch.nn.functional.cross_entropy(logits, y)
        # Add L1 regularization
        l1_norm = sum(param.abs().sum() for param in self.parameters())
        loss += self.l1_lambda * l1_norm
        self.log('train_loss', loss, prog_bar=True, on_epoch=True, on_step=False, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y = y.long()
        logits = self(x)
        loss = torch.nn.functional.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        # torchmetrics expects (preds, target)
        accuracy = self.accuracy(preds, y)
        precision = self.precision(preds, y)
        recall = self.recall(preds, y)
        f1 = self.f1(preds, y)
        self.log('val_loss', loss, prog_bar=True, on_epoch=True, on_step=True)
        self.log('val_acc', accuracy, prog_bar=True, on_epoch=True, on_step=True)
        self.log('val_precision', precision, prog_bar=True, on_epoch=True, on_step=True)
        self.log('val_recall', recall, prog_bar=True, on_epoch=True, on_step=True)
        self.log('val_f1', f1, prog_bar=True, on_epoch=True, on_step=True)
        return loss

    def on_validation_epoch_end(self):
        avg_loss = self.trainer.logged_metrics['val_loss_epoch']
        accuracy = self.trainer.logged_metrics['val_acc_epoch']
        self.log('val_loss', avg_loss, prog_bar=True, on_epoch=True)
        self.log('val_acc', accuracy, prog_bar=True, on_epoch=True)
        return {'Average Loss:': avg_loss, 'Accuracy:': accuracy}

    def test_step(self, batch, batch_idx):
        x, y = batch
        y = y.long()
        logits = self(x)
        loss = torch.nn.functional.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        accuracy = self.accuracy(preds, y)
        precision = self.precision(preds, y)
        recall = self.recall(preds, y)
        f1 = self.f1(preds, y)
        # Log test metrics
        self.log('test_loss', loss, prog_bar=True, logger=True)
        self.log('test_acc', accuracy, prog_bar=True, logger=True)
        self.log('test_precision', precision, prog_bar=True, logger=True)
        self.log('test_recall', recall, prog_bar=True, logger=True)
        self.log('test_f1', f1, prog_bar=True, logger=True)
        return {'test_loss': loss, 'test_accuracy': accuracy,
                'test_precision': precision, 'test_recall': recall, 'test_f1': f1}
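# Minimal usage sketch for `timm_backbones` (illustrative only): the model only needs
# an object with `.name`, `.lr` and `.weight_decay` attributes as `optimizer_cfg`, so a
# plain SimpleNamespace can stand in for the usual Hydra DictConfig. The dataloaders
# below are hypothetical names assumed to be defined elsewhere in the project.
#
#   from types import SimpleNamespace
#   opt_cfg = SimpleNamespace(name="Adam", lr=1e-3, weight_decay=1e-4)
#   model = timm_backbones(encoder="resnet18", num_classes=2, optimizer_cfg=opt_cfg, l1_lambda=1e-5)
#   trainer = pl.Trainer(max_epochs=20)
#   trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
#   trainer.test(model, dataloaders=test_dataloader)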
class CTCEncoderPL(pl.LightningModule):
    """
    PyTorch Lightning wrapper around a CTC encoder for sequence recognition.

    The wrapped ``ctc_encoder`` is expected to take ``(x, input_lengths)`` and return
    ``(logits, input_lengths)`` in the layout required by ``torch.nn.CTCLoss``.
    """

    def __init__(self, ctc_encoder, num_classes, optimizer_cfg):
        super().__init__()
        self.ctc_encoder = ctc_encoder
        self.ctc_loss = torch.nn.CTCLoss(blank=0, zero_infinity=True)
        self.optimizer_cfg = optimizer_cfg
        self.accuracy = Accuracy(task="multiclass", num_classes=num_classes)
        self.precision = Precision(task="multiclass", num_classes=num_classes)
        self.recall = Recall(task="multiclass", num_classes=num_classes)
        self.f1 = F1Score(task="multiclass", num_classes=num_classes)
        if optimizer_cfg is not None:
            optimizer_name = optimizer_cfg.name
            optimizer_lr = optimizer_cfg.lr
            optimizer_weight_decay = optimizer_cfg.weight_decay
            if optimizer_name == 'Adam':
                self.optimizer = optim.Adam(self.parameters(), lr=optimizer_lr,
                                            weight_decay=optimizer_weight_decay)
            elif optimizer_name == 'SGD':
                self.optimizer = optim.SGD(self.parameters(), lr=optimizer_lr,
                                           weight_decay=optimizer_weight_decay)
            else:
                raise ValueError(f"Unsupported optimizer: {optimizer_name}")
        else:
            self.optimizer = None

    def forward(self, x):
        return self.ctc_encoder(x)

    def training_step(self, batch, batch_idx):
        x, y, input_lengths, target_lengths = batch
        logits, input_lengths = self.ctc_encoder(x, input_lengths)
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
        loss = self.ctc_loss(log_probs, y, input_lengths, target_lengths)
        assert input_lengths.size(0) == x.size(0), \
            f"input_lengths size ({input_lengths.size(0)}) must match batch size ({x.size(0)})"
        self.log("train_loss", loss, on_epoch=True)
        return loss
    def validation_step(self, batch, batch_idx):
        x, y, input_lengths, target_lengths = batch
        # Compute logits and adjusted input lengths
        logits, input_lengths = self.ctc_encoder(x, input_lengths)
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
        # Validate input_lengths size
        assert input_lengths.size(0) == logits.size(0), "Mismatch between input_lengths and batch size"
        # Compute CTC loss
        loss = self.ctc_loss(log_probs, y, input_lengths, target_lengths)
        # Compute metrics (torchmetrics expects (preds, target))
        preds = torch.argmax(log_probs, dim=-1)
        accuracy = self.accuracy(preds, y)
        precision = self.precision(preds, y)
        recall = self.recall(preds, y)
        f1 = self.f1(preds, y)
        # Log metrics
        self.log('val_loss', loss, prog_bar=True, on_epoch=True, on_step=True)
        self.log('val_acc', accuracy, prog_bar=True, on_epoch=True, on_step=True)
        self.log('val_precision', precision, prog_bar=True, on_epoch=True, on_step=True)
        self.log('val_recall', recall, prog_bar=True, on_epoch=True, on_step=True)
        self.log('val_f1', f1, prog_bar=True, on_epoch=True, on_step=True)
        return loss

    def on_validation_epoch_end(self):
        avg_loss = self.trainer.logged_metrics['val_loss_epoch']
        accuracy = self.trainer.logged_metrics['val_acc_epoch']
        self.log('val_loss', avg_loss, prog_bar=True, on_epoch=True)
        self.log('val_acc', accuracy, prog_bar=True, on_epoch=True)
        return {'Average Loss:': avg_loss, 'Accuracy:': accuracy}

    def test_step(self, batch, batch_idx):
        x, y, input_lengths, target_lengths = batch
        # Call the encoder the same way as in training/validation so input_lengths stay consistent
        logits, input_lengths = self.ctc_encoder(x, input_lengths)
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
        loss = self.ctc_loss(log_probs, y, input_lengths, target_lengths)
        preds = torch.argmax(log_probs, dim=-1)
        accuracy = self.accuracy(preds, y)
        precision = self.precision(preds, y)
        recall = self.recall(preds, y)
        f1 = self.f1(preds, y)
        self.log('test_loss', loss, prog_bar=True, logger=True)
        self.log('test_acc', accuracy, prog_bar=True, logger=True)
        self.log('test_precision', precision, prog_bar=True, logger=True)
        self.log('test_recall', recall, prog_bar=True, logger=True)
        self.log('test_f1', f1, prog_bar=True, logger=True)
        return {'test_loss': loss, 'test_accuracy': accuracy,
                'test_precision': precision, 'test_recall': recall, 'test_f1': f1}

    def configure_optimizers(self):
        optimizer = self.optimizer
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.2, patience=20, min_lr=5e-5
        )
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}

    def greedy_decode(self, log_probs):
        """
        Perform greedy decoding to get predictions from log probabilities.
        """
        preds = torch.argmax(log_probs, dim=-1)
        return preds
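

if __name__ == "__main__":
    # Smoke-test sketch for CTCEncoderPL (illustrative only). The real CTC encoder lives
    # elsewhere in the project; `DummyCTCEncoder` below is a hypothetical stand-in that
    # only mimics the interface this module relies on:
    # encoder(x, input_lengths) -> (logits, input_lengths), with logits shaped
    # (T, N, num_classes) as expected by torch.nn.CTCLoss.
    from types import SimpleNamespace

    class DummyCTCEncoder(torch.nn.Module):
        def __init__(self, num_features=32, num_classes=10):
            super().__init__()
            self.proj = torch.nn.Linear(num_features, num_classes)

        def forward(self, x, input_lengths=None):
            # x: (N, T, num_features) -> logits: (T, N, num_classes)
            logits = self.proj(x).transpose(0, 1)
            return logits, input_lengths

    opt_cfg = SimpleNamespace(name="Adam", lr=1e-3, weight_decay=1e-4)
    model = CTCEncoderPL(DummyCTCEncoder(), num_classes=10, optimizer_cfg=opt_cfg)

    x = torch.randn(2, 50, 32)                         # (batch, time, features)
    y = torch.randint(1, 10, (2, 20))                  # label sequences (0 is the CTC blank)
    input_lengths = torch.full((2,), 50, dtype=torch.long)
    target_lengths = torch.full((2,), 20, dtype=torch.long)

    # Running the step without a Trainer only emits a logging warning from self.log().
    loss = model.training_step((x, y, input_lengths, target_lengths), batch_idx=0)
    print("CTC training loss:", float(loss))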