from typing import Any, List

import torch
from pytorch_lightning import LightningModule
from torchmetrics import MaxMetric
from torchmetrics.classification.accuracy import Accuracy

from src.models.components.simple_dense_net import SimpleDenseNet


class MNISTLitModule(LightningModule):
    """Example of a LightningModule for MNIST classification.

    A LightningModule organizes your PyTorch code into 5 sections:
        - Computations (init).
        - Train loop (training_step)
        - Validation loop (validation_step)
        - Test loop (test_step)
        - Optimizers (configure_optimizers)

    Read the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html
    """

    def __init__(
        self,
        input_size: int = 784,
        lin1_size: int = 256,
        lin2_size: int = 256,
        lin3_size: int = 256,
        output_size: int = 10,
        lr: float = 0.001,
        weight_decay: float = 0.0005,
    ):
        """Build the network, the loss and the per-stage metrics.

        All arguments are captured by `save_hyperparameters()` and handed to
        `SimpleDenseNet` through `self.hparams`; `lr` and `weight_decay` are
        additionally read by `configure_optimizers()`.
        """
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # it also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)

        self.model = SimpleDenseNet(hparams=self.hparams)

        # loss function
        self.criterion = torch.nn.CrossEntropyLoss()

        # use separate metric instance for train, val and test step
        # to ensure a proper reduction over the epoch
        self.train_acc = Accuracy()
        self.val_acc = Accuracy()
        self.test_acc = Accuracy()

        # for logging best so far validation accuracy
        self.val_acc_best = MaxMetric()

    def forward(self, x: torch.Tensor):
        """Run `x` through the wrapped model and return its output."""
        return self.model(x)

    def on_train_start(self) -> None:
        """Clear the best-validation-accuracy tracker before training begins."""
        # by default lightning executes validation sanity-check batches before
        # training starts, and `validation_epoch_end()` feeds their accuracy
        # into `val_acc_best`; reset it so "val/acc_best" only reflects real
        # validation epochs
        self.val_acc_best.reset()

    def step(self, batch: Any):
        """Shared train/val/test logic for a single batch.

        Args:
            batch: A pair `(x, y)` of model inputs and targets.

        Returns:
            A tuple `(loss, preds, targets)`, where `preds` is the argmax of
            the model output over dim 1.
        """
        x, y = batch
        logits = self.forward(x)
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=1)
        return loss, preds, y

    def training_step(self, batch: Any, batch_idx: int):
        """Compute the loss for one training batch and log train metrics."""
        loss, preds, targets = self.step(batch)

        # log train metrics
        acc = self.train_acc(preds, targets)
        self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
        self.log("train/acc", acc, on_step=False, on_epoch=True, prog_bar=True)

        # we can return here dict with any tensors
        # and then read it in some callback or in `training_epoch_end()` below
        # remember to always return loss from `training_step()` or else
        # backpropagation will fail!
        return {"loss": loss, "preds": preds, "targets": targets}

    def training_epoch_end(self, outputs: List[Any]):
        """Hook called with the outputs of every `training_step()`; unused."""
        # `outputs` is a list of dicts returned from `training_step()`
        pass

    def validation_step(self, batch: Any, batch_idx: int):
        """Compute the loss for one validation batch and log val metrics."""
        loss, preds, targets = self.step(batch)

        # log val metrics
        acc = self.val_acc(preds, targets)
        self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
        self.log("val/acc", acc, on_step=False, on_epoch=True, prog_bar=True)

        return {"loss": loss, "preds": preds, "targets": targets}

    def validation_epoch_end(self, outputs: List[Any]):
        """Update and log the best validation accuracy seen so far."""
        acc = self.val_acc.compute()  # get val accuracy from current epoch
        self.val_acc_best.update(acc)
        self.log("val/acc_best", self.val_acc_best.compute(), on_epoch=True, prog_bar=True)

    def test_step(self, batch: Any, batch_idx: int):
        """Compute the loss for one test batch and log test metrics."""
        loss, preds, targets = self.step(batch)

        # log test metrics
        acc = self.test_acc(preds, targets)
        self.log("test/loss", loss, on_step=False, on_epoch=True)
        self.log("test/acc", acc, on_step=False, on_epoch=True)

        return {"loss": loss, "preds": preds, "targets": targets}

    def test_epoch_end(self, outputs: List[Any]):
        """Hook called with the outputs of every `test_step()`; unused."""
        pass

    def on_epoch_end(self):
        """Reset the per-stage accuracy metrics."""
        # reset metrics at the end of every epoch
        self.train_acc.reset()
        self.test_acc.reset()
        self.val_acc.reset()

    def configure_optimizers(self):
        """Choose what optimizers and learning-rate schedulers to use in your optimization.

        Normally you'd need one. But in the case of GANs or similar you might have multiple.

        See examples here:
            https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers

        Returns:
            An Adam optimizer over all module parameters, configured with the
            `lr` and `weight_decay` hyperparameters.
        """
        return torch.optim.Adam(
            params=self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay
        )