|
from typing import Any, List |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from torch import nn |
|
from pytorch_lightning import LightningModule |
|
from torchmetrics import MaxMetric, MeanAbsoluteError, MinMetric |
|
from torchmetrics.classification.accuracy import Accuracy |
|
|
|
|
|
class SimpleConvNet(nn.Module):
    """Small convolutional network: two conv/pool stages and three linear layers.

    Layer order is ``Conv2d -> MaxPool2d -> Conv2d -> MaxPool2d -> Flatten ->
    Linear -> Linear -> Linear``. The first convolution takes 3 input
    channels, and the flattened feature size is computed for a square input
    of side ``image_size``.

    Args:
        hparams: Mapping providing the following entries:

            - ``"image_size"``: side length of the (square) input image.
            - ``"pool_size"``: kernel size and stride of both max-pool layers.
            - ``"conv1_size"`` / ``"conv2_size"``: kernel sizes of the first
              and second convolution.
            - ``"conv1_channels"`` / ``"conv2_channels"``: output channels of
              the first and second convolution.
            - ``"lin1_size"`` / ``"lin2_size"``: widths of the two hidden
              linear layers.
            - ``"output_size"``: number of outputs of the final linear layer.

    Raises:
        ValueError: If the configured kernel/pool sizes shrink the feature
            map to zero or below for the given ``image_size``.
    """

    def __init__(self, hparams):
        super().__init__()

        pool_size = hparams["pool_size"]
        conv1_size = hparams["conv1_size"]
        conv1_out = hparams["conv1_channels"]
        # Kernel size of the second convolution (a kernel size, not a
        # channel count).
        conv2_size = hparams["conv2_size"]
        conv2_out = hparams["conv2_channels"]

        lin1_size = hparams["lin1_size"]
        lin2_size = hparams["lin2_size"]
        output_size = hparams["output_size"]

        # Track the spatial side length through both conv/pool stages so the
        # first Linear layer can be sized to the flattened feature map.
        # An unpadded, stride-1 convolution removes (kernel - 1) pixels; the
        # pooling layer (kernel == stride) then divides the side, rounding
        # down.
        size_img = hparams["image_size"]
        for kernel_size in (conv1_size, conv2_size):
            size_img = (size_img - (kernel_size - 1)) // pool_size
            if size_img <= 0:
                raise ValueError(
                    "image_size is too small for the configured "
                    "convolution/pooling sizes: the feature map is empty"
                )

        # NOTE(review): there is no non-linearity between the Linear layers,
        # so they compose to a single affine map — confirm this is intended.
        self.model = nn.Sequential(
            nn.Conv2d(3, conv1_out, conv1_size),
            nn.MaxPool2d(pool_size, pool_size),
            nn.Conv2d(conv1_out, conv2_out, conv2_size),
            nn.MaxPool2d(pool_size, pool_size),
            nn.Flatten(),
            nn.Linear(conv2_out * size_img * size_img, lin1_size),
            nn.Linear(lin1_size, lin2_size),
            nn.Linear(lin2_size, output_size),
        )

    def forward(self, x):
        """Run ``x`` through the sequential stack and return its output."""
        return self.model(x)
|
|
|
|
|
class FocusConvLitModule(LightningModule):
    """LightningModule that regresses a focus height from an image.

    The module wraps ``SimpleConvNet``, optimises a mean-squared-error loss
    and tracks the mean absolute error (MAE) separately for the train,
    validation and test loops, plus the lowest validation MAE seen so far.

    Each batch is indexed with the keys ``"image"`` (network input) and
    ``"focus_height"`` (regression target).

    A LightningModule organizes your PyTorch code into 5 sections:
        - Computations (init).
        - Train loop (training_step)
        - Validation loop (validation_step)
        - Test loop (test_step)
        - Optimizers (configure_optimizers)

    Read the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html

    Args:
        image_size: Side length of the square input image.
        pool_size: Kernel size and stride requested for the max-pool layers.
        conv1_size: Kernel size requested for the first convolution.
        conv1_channels: Output channels requested for the first convolution.
        conv2_size: Kernel size requested for the second convolution.
        conv2_channels: Output channels requested for the second convolution.
        lin1_size: Width of the first hidden linear layer.
        lin2_size: Width of the second hidden linear layer.
        output_size: Number of network outputs.
        lr: Learning rate passed to Adam.
        weight_decay: Weight decay passed to Adam.
    """

    def __init__(
        self,
        image_size: int = 150,
        pool_size: int = 2,
        conv1_size: int = 5,
        conv1_channels: int = 6,
        conv2_size: int = 5,
        conv2_channels: int = 16,
        lin1_size: int = 100,
        lin2_size: int = 80,
        output_size: int = 1,
        lr: float = 0.001,
        weight_decay: float = 0.0005,
    ):
        super().__init__()

        # Stores every __init__ argument in self.hparams; logger=False keeps
        # them from being sent to the experiment logger.
        self.save_hyperparameters(logger=False)

        self.model = SimpleConvNet(hparams=self.hparams)

        # Regression loss.
        self.criterion = torch.nn.MSELoss()

        # One metric object per loop so their accumulated states never mix.
        self.train_mae = MeanAbsoluteError()
        self.val_mae = MeanAbsoluteError()
        self.test_mae = MeanAbsoluteError()

        # Tracks the lowest validation MAE observed across epochs.
        self.val_mae_best = MinMetric()

    def forward(self, x: torch.Tensor):
        """Return the network output for the input batch ``x``."""
        return self.model(x)

    def on_train_start(self):
        """Forget any best-MAE value recorded before training began.

        Lightning may run validation sanity-check batches before training;
        without this reset their MAE could be reported as the best score.
        """
        self.val_mae_best.reset()

    def step(self, batch: Any):
        """Run the forward pass and loss computation shared by all loops.

        Args:
            batch: Mapping with ``"image"`` and ``"focus_height"`` entries.

        Returns:
            Tuple ``(loss, preds, targets)``.
        """
        x = batch["image"]
        y = batch["focus_height"]
        logits = self.forward(x)
        # The target gets a trailing dimension to line up with the network
        # output before the loss is computed.
        loss = self.criterion(logits, y.unsqueeze(1))
        # Squeeze only the output dimension: a bare squeeze() would also drop
        # the batch dimension for a batch of one and no longer match ``y``.
        preds = logits.squeeze(dim=1)
        return loss, preds, y

    def training_step(self, batch: Any, batch_idx: int):
        """Compute and log the training loss and MAE for one batch."""
        loss, preds, targets = self.step(batch)

        # Log train metrics, aggregated per epoch.
        mae = self.train_mae(preds, targets)
        self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
        self.log("train/mae", mae, on_step=False, on_epoch=True, prog_bar=True)

        # The "loss" entry is what Lightning backpropagates.
        return {"loss": loss, "preds": preds, "targets": targets}

    def training_epoch_end(self, outputs: List[Any]):
        """No-op hook; ``outputs`` holds the dicts returned by training_step."""
        pass

    def validation_step(self, batch: Any, batch_idx: int):
        """Compute and log the validation loss and MAE for one batch."""
        loss, preds, targets = self.step(batch)

        # Log val metrics, aggregated per epoch.
        mae = self.val_mae(preds, targets)
        self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
        self.log("val/mae", mae, on_step=False, on_epoch=True, prog_bar=True)

        return {"loss": loss, "preds": preds, "targets": targets}

    def validation_epoch_end(self, outputs: List[Any]):
        """Update and log the best (lowest) validation MAE so far."""
        mae = self.val_mae.compute()  # MAE accumulated over the whole epoch
        self.val_mae_best.update(mae)
        self.log(
            "val/mae_best", self.val_mae_best.compute(), on_epoch=True, prog_bar=True
        )

    def test_step(self, batch: Any, batch_idx: int):
        """Compute and log the test loss and MAE for one batch."""
        loss, preds, targets = self.step(batch)

        # Log test metrics, aggregated per epoch.
        mae = self.test_mae(preds, targets)
        self.log("test/loss", loss, on_step=False, on_epoch=True)
        self.log("test/mae", mae, on_step=False, on_epoch=True)

        return {"loss": loss, "preds": preds, "targets": targets}

    def test_epoch_end(self, outputs: List[Any]):
        """No-op hook; ``outputs`` holds the dicts returned by test_step."""
        pass

    def on_epoch_end(self):
        """Reset the per-loop MAE metrics so epochs do not accumulate."""
        self.train_mae.reset()
        self.test_mae.reset()
        self.val_mae.reset()

    def configure_optimizers(self):
        """Choose what optimizers and learning-rate schedulers.

        Normally you'd need one. But in the case of GANs or similar you might
        have multiple.

        See examples here:
            https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers

        Returns:
            An Adam optimizer over all module parameters, configured with the
            ``lr`` and ``weight_decay`` hyperparameters.
        """
        return torch.optim.Adam(
            params=self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )
|
|