from typing import Any, List

import torch
from torch import nn
from pytorch_lightning import LightningModule
from torchmetrics import MeanAbsoluteError, MinMetric
import torchvision.models as models


class ResNetLitModule(LightningModule):
    """LightningModule that regresses a single scalar value from an input
    image using a torchvision ResNet backbone."""

    def __init__(
        self,
        resnet_type: str = "resnet18",
        pretrained: bool = False,
        lr: float = 0.001,
        weight_decay: float = 0.0005,
    ):
        """Initialize the ResNet regression module.

        Args:
            resnet_type (str, optional): Name of the torchvision ResNet variant
                to use as the backbone. Defaults to "resnet18". Can be one of:
                "resnet18", "resnet34", "resnet50", "resnet101", "resnet152",
                "resnext50_32x4d", "resnext101_32x8d", "wide_resnet50_2",
                "wide_resnet101_2".
            pretrained (bool, optional): If True, loads torchvision's
                pretrained weights. Defaults to False.
            lr (float, optional): Learning rate for the optimizer. Defaults to
                0.001.
            weight_decay (float, optional): Weight decay for the optimizer.
                Defaults to 0.0005.
        """
        super().__init__()

        # this line allows access to init params with the `self.hparams`
        # attribute and stores them in the checkpoint
        self.save_hyperparameters(logger=False)

        # loss function for scalar regression
        self.criterion = torch.nn.MSELoss()

        # metric objects for calculating and averaging MAE across batches
        self.train_mae = MeanAbsoluteError()
        self.val_mae = MeanAbsoluteError()
        self.test_mae = MeanAbsoluteError()

        # for tracking the best validation MAE across epochs
        self.val_mae_best = MinMetric()

        # resolve the requested torchvision ResNet constructor
        if resnet_type == "resnet18":
            resnet_constructor = models.resnet18
        elif resnet_type == "resnet34":
            resnet_constructor = models.resnet34
        elif resnet_type == "resnet50":
            resnet_constructor = models.resnet50
        elif resnet_type == "resnet101":
            resnet_constructor = models.resnet101
        elif resnet_type == "resnet152":
            resnet_constructor = models.resnet152
        elif resnet_type == "resnext50_32x4d":
            resnet_constructor = models.resnext50_32x4d
        elif resnet_type == "resnext101_32x8d":
            resnet_constructor = models.resnext101_32x8d
        elif resnet_type == "wide_resnet50_2":
            resnet_constructor = models.wide_resnet50_2
        elif resnet_type == "wide_resnet101_2":
            resnet_constructor = models.wide_resnet101_2
        else:
            raise ValueError(f"Unknown resnet_type: {resnet_type}")

        backbone = resnet_constructor(pretrained=pretrained)

        # drop the final fully connected layer; keep the convolutional
        # feature extractor (which ends with global average pooling)
        num_filters = backbone.fc.in_features
        layers = list(backbone.children())[:-1]
        self.feature_extractor = nn.Sequential(*layers)

        # regression head mapping pooled features to a single scalar
        self.fc = nn.Linear(num_filters, 1)

    def forward(self, x):
        # (batch, channels, height, width) -> (batch, num_filters)
        representations = self.feature_extractor(x).flatten(1)
        # (batch, num_filters) -> (batch, 1)
        x = self.fc(representations)
        return x

    def step(self, batch: Any):
        # batches are dicts with an "image" tensor and a per-sample scalar
        # "focus_height" regression target
        x = batch["image"]
        y = batch["focus_height"]
        logits = self.forward(x)
        loss = self.criterion(logits, y.unsqueeze(1))
        preds = torch.squeeze(logits)
        return loss, preds, y

    def training_step(self, batch: Any, batch_idx: int):
        loss, preds, targets = self.step(batch)

        # log train metrics
        mae = self.train_mae(preds, targets)
        self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
        self.log("train/mae", mae, on_step=False, on_epoch=True, prog_bar=True)

        # return the loss so Lightning can run backpropagation; preds and
        # targets are included for any epoch-end hooks that need them
        return {"loss": loss, "preds": preds, "targets": targets}

    def training_epoch_end(self, outputs: List[Any]):
        pass

    def validation_step(self, batch: Any, batch_idx: int):
        loss, preds, targets = self.step(batch)

        # log val metrics
        mae = self.val_mae(preds, targets)
        self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
        self.log("val/mae", mae, on_step=False, on_epoch=True, prog_bar=True)

        return {"loss": loss, "preds": preds, "targets": targets}

    def validation_epoch_end(self, outputs: List[Any]):
        # track and log the best MAE observed across validation epochs
        mae = self.val_mae.compute()
        self.val_mae_best.update(mae)
        self.log(
            "val/mae_best", self.val_mae_best.compute(), on_epoch=True, prog_bar=True
        )

    def test_step(self, batch: Any, batch_idx: int):
        loss, preds, targets = self.step(batch)

        # log test metrics
        mae = self.test_mae(preds, targets)
        self.log("test/loss", loss, on_step=False, on_epoch=True)
        self.log("test/mae", mae, on_step=False, on_epoch=True)

    def test_epoch_end(self, outputs: List[Any]):
        pass

    def on_epoch_end(self):
        # reset metrics at the end of every epoch
        self.train_mae.reset()
        self.test_mae.reset()
        self.val_mae.reset()

    def configure_optimizers(self):
        """Choose which optimizers and learning-rate schedulers to use.

        Normally you'd need one, but in the case of GANs or similar you might
        have multiple.

        See examples here:
        https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers
        """
        return torch.optim.Adam(
            params=self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )
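
# Note: `configure_optimizers` can also attach a learning-rate scheduler by
# returning a dict. A sketch of that variant (the StepLR schedule and its
# step_size are illustrative assumptions, not part of the original module):
#
#     optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
#     scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
#     return {"optimizer": optimizer, "lr_scheduler": scheduler}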
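

# A minimal smoke test (a sketch, not part of the training pipeline). It
# assumes batches shaped like those consumed by `step` above: a dict with an
# "image" tensor of shape (B, 3, H, W) and a scalar "focus_height" target per
# sample. The dummy sizes below are illustrative.
if __name__ == "__main__":
    model = ResNetLitModule(resnet_type="resnet18", pretrained=False)
    batch = {
        "image": torch.randn(4, 3, 224, 224),  # four dummy RGB images
        "focus_height": torch.randn(4),  # four dummy regression targets
    }
    loss, preds, targets = model.step(batch)
    print(f"loss={loss.item():.4f}, preds shape={tuple(preds.shape)}")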