from typing import Any, List, Mapping

import torch
import torch.nn.functional as F
from torch import nn
from pytorch_lightning import LightningModule
from torchmetrics import MaxMetric, MeanAbsoluteError, MinMetric
from torchmetrics.classification.accuracy import Accuracy


def _conv_pool_output_size(size: int, kernel_size: int, pool_size: int) -> int:
    """Return the side length after an unpadded stride-1 conv and a max-pool.

    Matches ``nn.Conv2d(..., kernel_size)`` (default padding 0, stride 1)
    followed by ``nn.MaxPool2d(pool_size, pool_size)``, which floors.

    Raises:
        ValueError: if the resulting side length would be smaller than 1.
    """
    out = (size - (kernel_size - 1)) // pool_size
    if out < 1:
        raise ValueError(
            f"input side {size} is too small for a conv of kernel {kernel_size} "
            f"followed by a {pool_size}x{pool_size} max-pool"
        )
    return out


class SimpleConvNet(nn.Module):
    """Two conv/pool stages followed by three linear layers.

    Architecture: Conv2d(3 -> conv1_channels) -> MaxPool2d ->
    Conv2d(conv1_channels -> conv2_channels) -> MaxPool2d -> Flatten ->
    Linear -> Linear -> Linear(-> output_size).

    Args:
        hparams: mapping providing ``pool_size``, ``conv1_size``,
            ``conv1_channels``, ``conv2_size``, ``conv2_channels``,
            ``image_size``, ``lin1_size``, ``lin2_size`` and ``output_size``.

    Raises:
        ValueError: if ``image_size`` is too small for the conv/pool stack.
    """

    def __init__(self, hparams: Mapping[str, Any]):
        super().__init__()
        pool_size = hparams["pool_size"]
        conv1_size = hparams["conv1_size"]
        conv1_out = hparams["conv1_channels"]
        # Previously read "conv1_channels", which ignored the configured
        # conv2_size. Checkpoints trained with that version have a conv2 kernel
        # of size conv1_channels and will not load if it differs from conv2_size.
        conv2_size = hparams["conv2_size"]
        conv2_out = hparams["conv2_channels"]
        lin1_size = hparams["lin1_size"]
        lin2_size = hparams["lin2_size"]
        output_size = hparams["output_size"]

        # Side length of the feature map reaching nn.Flatten. The first Linear
        # layer uses this value squared, so inputs are expected to be square
        # (image_size x image_size).
        size_img = _conv_pool_output_size(hparams["image_size"], conv1_size, pool_size)
        size_img = _conv_pool_output_size(size_img, conv2_size, pool_size)

        # NOTE(review): there are no activation functions between layers, so the
        # three Linear layers compose to a single affine map. Inserting them
        # would change the nn.Sequential indices (and thus state_dict keys), so
        # the architecture is left unchanged here.
        self.model = nn.Sequential(
            nn.Conv2d(3, conv1_out, conv1_size),
            nn.MaxPool2d(pool_size, pool_size),
            nn.Conv2d(conv1_out, conv2_out, conv2_size),
            nn.MaxPool2d(pool_size, pool_size),
            nn.Flatten(),
            nn.Linear(conv2_out * size_img * size_img, lin1_size),
            nn.Linear(lin1_size, lin2_size),
            nn.Linear(lin2_size, output_size),
        )

    def forward(self, x):
        """Apply the conv/linear stack to a batch of 3-channel images."""
        return self.model(x)


class FocusConvLitModule(LightningModule):
    """LightningModule regressing ``batch["focus_height"]`` from ``batch["image"]``.

    Wraps :class:`SimpleConvNet`, trains it with an MSE loss, and tracks the
    mean absolute error (MAE) separately for the train, validation and test
    splits, plus the best (lowest) validation MAE seen so far.

    A LightningModule organizes your PyTorch code into 5 sections:
        - Computations (init).
        - Train loop (training_step)
        - Validation loop (validation_step)
        - Test loop (test_step)
        - Optimizers (configure_optimizers)

    Read the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html
    """

    def __init__(
        self,
        image_size: int = 150,
        pool_size: int = 2,
        conv1_size: int = 5,
        conv1_channels: int = 6,
        conv2_size: int = 5,
        conv2_channels: int = 16,
        lin1_size: int = 100,
        lin2_size: int = 80,
        output_size: int = 1,
        lr: float = 0.001,
        weight_decay: float = 0.0005,
    ):
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # it also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)

        self.model = SimpleConvNet(hparams=self.hparams)

        # loss function
        self.criterion = torch.nn.MSELoss()

        # use separate metric instance for train, val and test step
        # to ensure a proper reduction over the epoch
        self.train_mae = MeanAbsoluteError()
        self.val_mae = MeanAbsoluteError()
        self.test_mae = MeanAbsoluteError()

        # for logging the best (lowest) validation MAE so far
        self.val_mae_best = MinMetric()

    def forward(self, x: torch.Tensor):
        return self.model(x)

    def on_train_start(self):
        """Discard validation state left over from the sanity-check run.

        Lightning runs a few validation batches before training by default.
        Without this reset, ``val_mae_best`` could retain the MAE from that run.
        """
        self.val_mae.reset()
        self.val_mae_best.reset()

    def step(self, batch: Any):
        """Run the model on one batch and compute the MSE loss.

        Args:
            batch: mapping with ``"image"`` (model input) and ``"focus_height"``
                (one target per sample -- presumably shape ``(batch,)``;
                verify against the datamodule).

        Returns:
            ``(loss, preds, targets)``: the MSE loss, the predictions with the
            output dimension removed, and the targets as given.
        """
        x = batch["image"]
        y = batch["focus_height"]
        logits = self(x)
        loss = self.criterion(logits, y.unsqueeze(1))
        # Squeeze only the output dimension. A bare torch.squeeze would also
        # drop the batch dimension for a batch of one sample, producing a 0-d
        # tensor whose shape no longer matches ``y`` in the MAE metrics.
        preds = logits.squeeze(1)
        return loss, preds, y

    def training_step(self, batch: Any, batch_idx: int):
        loss, preds, targets = self.step(batch)

        # log train metrics
        mae = self.train_mae(preds, targets)
        self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
        self.log("train/mae", mae, on_step=False, on_epoch=True, prog_bar=True)

        # we can return here dict with any tensors
        # and then read it in some callback or in `training_epoch_end()` below
        # remember to always return loss from `training_step()` or else
        # backpropagation will fail!
        return {"loss": loss, "preds": preds, "targets": targets}

    def training_epoch_end(self, outputs: List[Any]):
        # `outputs` is a list of dicts returned from `training_step()`
        pass

    def validation_step(self, batch: Any, batch_idx: int):
        loss, preds, targets = self.step(batch)

        # log val metrics
        mae = self.val_mae(preds, targets)
        self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
        self.log("val/mae", mae, on_step=False, on_epoch=True, prog_bar=True)

        return {"loss": loss, "preds": preds, "targets": targets}

    def validation_epoch_end(self, outputs: List[Any]):
        mae = self.val_mae.compute()  # get val MAE from current epoch
        self.val_mae_best.update(mae)
        self.log(
            "val/mae_best", self.val_mae_best.compute(), on_epoch=True, prog_bar=True
        )

    def test_step(self, batch: Any, batch_idx: int):
        loss, preds, targets = self.step(batch)

        # log test metrics
        mae = self.test_mae(preds, targets)
        self.log("test/loss", loss, on_step=False, on_epoch=True)
        self.log("test/mae", mae, on_step=False, on_epoch=True)

        return {"loss": loss, "preds": preds, "targets": targets}

    def test_epoch_end(self, outputs: List[Any]):
        # `outputs` is a list of dicts returned from `test_step()`. A debug
        # print(outputs) that dumped every batch's tensors was removed.
        pass

    def on_epoch_end(self):
        # reset metrics at the end of every epoch
        self.train_mae.reset()
        self.test_mae.reset()
        self.val_mae.reset()

    def configure_optimizers(self):
        """Choose what optimizers and learning-rate schedulers.

        Normally you'd need one. But in the case of GANs or similar you might have multiple.

        See examples here:
            https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers
        """
        return torch.optim.Adam(
            params=self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )