from typing import Any, List

import torch
import torch.nn.functional as F
from torch import nn
from pytorch_lightning import LightningModule
from torchmetrics import MaxMetric, MeanAbsoluteError, MinMetric
from torchmetrics.classification.accuracy import Accuracy
import torchvision.models as models


class ResNetLitModule(LightningModule):
    """Regress a single scalar from an image using a torchvision ResNet backbone.

    The backbone's final ``fc`` layer is removed. The remaining layers serve as a
    feature extractor, followed by a new ``nn.Linear(num_filters, 1)`` head.
    Training minimizes MSE against ``batch["focus_height"]``. MAE is tracked
    separately for the train, val and test stages.
    """

    # Accepted ``resnet_type`` values mapped to torchvision constructor functions.
    # NOTE(review): the default "ResNet" used to call ``models.ResNet()`` with no
    # arguments. That raises TypeError, because ResNet requires ``block`` and
    # ``layers``. It is aliased to resnet18 here so the default configuration
    # builds. Confirm this matches the intended architecture.
    _CONSTRUCTORS = {
        "ResNet": models.resnet18,
        "resnet18": models.resnet18,
        "resnet34": models.resnet34,
        "resnet50": models.resnet50,
        "resnet101": models.resnet101,
        "resnet152": models.resnet152,
        "resnext50_32x4d": models.resnext50_32x4d,
        "resnext101_32x8d": models.resnext101_32x8d,
        "wide_resnet50_2": models.wide_resnet50_2,
        "wide_resnet101_2": models.wide_resnet101_2,
    }

    def __init__(
        self,
        resnet_type: str = "ResNet",
        pretrained: bool = False,
        lr: float = 0.001,
        weight_decay: float = 0.0005,
    ):
        """Initialize function for a resnet module.

        Args:
            resnet_type (str, optional): Type of the used resnet network. Defaults to "ResNet".
                Can be one of the following values: "ResNet", "resnet18", "resnet34",
                "resnet50", "resnet101", "resnet152", "resnext50_32x4d",
                "resnext101_32x8d", "wide_resnet50_2", "wide_resnet101_2".
                "ResNet" is an alias for "resnet18".
            pretrained (bool, optional): if True loads pytorch pretrained models.
                Defaults to False.
            lr (float, optional): Adam learning rate. Defaults to 0.001.
            weight_decay (float, optional): Adam weight decay. Defaults to 0.0005.

        Raises:
            ValueError: if ``resnet_type`` is not one of the supported values.
        """
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # it also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)

        # loss function
        self.criterion = torch.nn.MSELoss()

        # use separate metric instance for train, val and test step
        # to ensure a proper reduction over the epoch
        self.train_mae = MeanAbsoluteError()
        self.val_mae = MeanAbsoluteError()
        self.test_mae = MeanAbsoluteError()

        # for logging the best (lowest) validation MAE seen so far
        self.val_mae_best = MinMetric()
        self.pretrained = pretrained

        try:
            resnet_constructor = self._CONSTRUCTORS[resnet_type]
        except KeyError:
            # ValueError subclasses Exception, so existing `except Exception` callers still match
            raise ValueError(f"did not find model type: {resnet_type}") from None

        # init a (optionally pretrained) resnet
        backbone = resnet_constructor(pretrained=pretrained)
        num_filters = backbone.fc.in_features
        # drop the classification head; keep everything up to and including pooling
        layers = list(backbone.children())[:-1]
        self.feature_extractor = nn.Sequential(*layers)
        self.fc = nn.Linear(num_filters, 1)

    def forward(self, x):
        """Return one regression output per sample, shape (batch, 1)."""
        representations = self.feature_extractor(x).flatten(1)
        x = self.fc(representations)
        return x

    def step(self, batch: Any):
        """Run the model on a batch and compute the MSE loss.

        Returns:
            (loss, preds, targets): ``preds`` has the head dimension removed so
            it lines up with ``targets`` for the MAE metrics.
        """
        x = batch["image"]
        # presumably a 1-D tensor of per-sample targets (batch,); verify against the datamodule
        y = batch["focus_height"]
        outputs = self.forward(x)  # (batch, 1) from the linear head
        # match the target's shape and dtype to the prediction so MSELoss compares like with like
        loss = self.criterion(outputs, y.unsqueeze(1).to(outputs.dtype))
        # squeeze only the head dim: a bare squeeze() also drops the batch dim when
        # batch == 1, yielding a 0-d tensor that no longer matches `y` in the metrics
        preds = outputs.squeeze(1)
        return loss, preds, y

    def training_step(self, batch: Any, batch_idx: int):
        """Compute loss and log epoch-level train loss/MAE."""
        loss, preds, targets = self.step(batch)

        # log train metrics
        mae = self.train_mae(preds, targets)
        self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
        self.log("train/mae", mae, on_step=False, on_epoch=True, prog_bar=True)

        # remember to always return loss from `training_step()` or else
        # backpropagation will fail!
        return {"loss": loss, "preds": preds, "targets": targets}

    def training_epoch_end(self, outputs: List[Any]):
        # `outputs` is a list of dicts returned from `training_step()`
        pass

    def validation_step(self, batch: Any, batch_idx: int):
        """Compute loss and log epoch-level val loss/MAE."""
        loss, preds, targets = self.step(batch)

        # log val metrics
        mae = self.val_mae(preds, targets)
        self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
        self.log("val/mae", mae, on_step=False, on_epoch=True, prog_bar=True)

        return {"loss": loss, "preds": preds, "targets": targets}

    def validation_epoch_end(self, outputs: List[Any]):
        """Track and log the lowest validation MAE seen so far."""
        mae = self.val_mae.compute()  # get val MAE from current epoch
        self.val_mae_best.update(mae)
        self.log(
            "val/mae_best", self.val_mae_best.compute(), on_epoch=True, prog_bar=True
        )

    def test_step(self, batch: Any, batch_idx: int):
        """Compute loss and log epoch-level test loss/MAE."""
        loss, preds, targets = self.step(batch)

        # log test metrics
        mae = self.test_mae(preds, targets)
        self.log("test/loss", loss, on_step=False, on_epoch=True)
        self.log("test/mae", mae, on_step=False, on_epoch=True)

    def test_epoch_end(self, outputs: List[Any]):
        # test_step returns nothing, so `outputs` carries no useful data here
        pass

    def on_epoch_end(self):
        # reset metrics at the end of every epoch
        self.train_mae.reset()
        self.test_mae.reset()
        self.val_mae.reset()

    def configure_optimizers(self):
        """Return an Adam optimizer using the stored ``lr`` and ``weight_decay``.

        See examples here:
            https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers
        """
        return torch.optim.Adam(
            params=self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )