Hannes Kuchelmeister committed
Commit d908fd7
1 Parent(s): a10a7fc

implement simple model
models/notebooks/1.0-hfk-datamodules-exploration.ipynb CHANGED
@@ -393,7 +393,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 15,
+   "execution_count": 8,
    "metadata": {},
    "outputs": [
     {
@@ -402,7 +402,7 @@
        "64"
       ]
      },
-     "execution_count": 15,
+     "execution_count": 8,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -413,6 +413,53 @@
    "\n",
    "len(data[\"focus_value\"])"
   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/hku/.local/lib/python3.8/site-packages/torch/nn/modules/loss.py:96: UserWarning: Using a target size (torch.Size([64])) that is different to the input size (torch.Size([64, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n",
+      "  return F.l1_loss(input, target, reduction=self.reduction)\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "(tensor(2.5787, grad_fn=<L1LossBackward0>),\n",
+       " tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),\n",
+       " tensor([-1.2805, -0.0943, -2.3645, 0.8542, -0.8047, -6.0020, 0.0000, -4.3352,\n",
+       "         -1.8066, -2.7189, -6.4697, -3.2557, -4.2778, -5.0264, -3.4891, 0.0000,\n",
+       "         -1.7181, -2.7314, 0.3324, -0.0943, -0.8991, 0.0000, -4.4178, 1.9723,\n",
+       "         -3.0026, -5.5685, 3.8374, 3.8625, -0.4125, -4.1936, -1.5781, -1.6393,\n",
+       "         -2.9583, -5.4933, -1.7807, -3.3135, -5.3423, -0.7978, -5.3971, -4.9412,\n",
+       "         0.0000, -4.4128, -5.7744, -5.2755, -1.0996, -5.7482, 0.0000, -0.1737,\n",
+       "         -3.5851, -6.1429, -6.3642, -3.9653, -0.2081, -0.9539, -0.4159, -0.5388,\n",
+       "         -1.3643, -4.4441, -1.5161, 0.6395, -5.4710, -2.6482, 0.0000, -2.6257],\n",
+       "        dtype=torch.float64))"
+      ]
+     },
+     "execution_count": 9,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "import types\n",
+    "import importlib.machinery\n",
+    "focus_module = SourceFileLoader('focus_module', '../src/models/focus_module.py').load_module()\n",
+    "from focus_module import FocusLitModule\n",
+    "\n",
+    "model = FocusLitModule()\n",
+    "\n",
+    "model.step(data)"
+   ]
   }
  ],
  "metadata": {
models/requirements.txt CHANGED
@@ -6,6 +6,7 @@ torchmetrics>=0.7.0
 
 # --------- data and model dependencies --------- #
 scikit-image
+pandas
 
 # --------- hydra --------- #
 hydra-core>=1.1.0
models/src/datamodules/focus_datamodule.py CHANGED
@@ -51,7 +51,7 @@ class FocusDataSet(Dataset):
         sample = {"image": image, "focus_value": focus_value}
 
         if self.transform:
-            sample = self.transform(sample)
+            sample["image"] = self.transform(sample["image"])
 
         return sample
 
@@ -76,7 +76,9 @@ class FocusDataModule(LightningDataModule):
         self.save_hyperparameters(logger=False)
 
         # data transformations
-        self.transforms = transforms.Compose([])
+        self.transforms = transforms.Compose(
+            [transforms.ToTensor(), transforms.ConvertImageDtype(torch.float)]
+        )
 
         self.data_train: Optional[Dataset] = None
         self.data_val: Optional[Dataset] = None
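Note: with this change the transform pipeline is applied to the image only, leaving focus_value untouched. A minimal sketch of what the new Compose does, assuming a scikit-image-style uint8 array sized to match the model's 75 * 75 * 3 input:

import numpy as np
import torch
from torchvision import transforms

# Stand-in for an image loaded with scikit-image: uint8, H x W x C.
image = np.random.randint(0, 256, size=(75, 75, 3), dtype=np.uint8)

pipeline = transforms.Compose(
    [transforms.ToTensor(), transforms.ConvertImageDtype(torch.float)]
)

# ToTensor already converts a uint8 HWC array to a float32 CHW tensor
# scaled to [0, 1]; ConvertImageDtype(torch.float) is then a no-op and
# only matters for inputs that arrive as integer tensors.
tensor = pipeline(image)
print(tensor.shape, tensor.dtype)  # torch.Size([3, 75, 75]) torch.float32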
models/src/models/focus_module.py ADDED
@@ -0,0 +1,157 @@
+from typing import Any, List
+
+import torch
+from torch import nn
+from pytorch_lightning import LightningModule
+from torchmetrics import MaxMetric, MeanAbsoluteError, MinMetric
+from torchmetrics.classification.accuracy import Accuracy
+
+
+class SimpleDenseNet(nn.Module):
+    def __init__(self, hparams: dict):
+        super().__init__()
+
+        self.model = nn.Sequential(
+            nn.Linear(hparams["input_size"], hparams["lin1_size"]),
+            nn.BatchNorm1d(hparams["lin1_size"]),
+            nn.ReLU(),
+            nn.Linear(hparams["lin1_size"], hparams["lin2_size"]),
+            nn.BatchNorm1d(hparams["lin2_size"]),
+            nn.ReLU(),
+            nn.Linear(hparams["lin2_size"], hparams["lin3_size"]),
+            nn.BatchNorm1d(hparams["lin3_size"]),
+            nn.ReLU(),
+            nn.Linear(hparams["lin3_size"], hparams["output_size"]),
+        )
+
+    def forward(self, x):
+        batch_size, channels, width, height = x.size()
+
+        # (batch, 1, width, height) -> (batch, 1*width*height)
+        x = x.view(batch_size, -1)
+
+        return self.model(x)
+
+
+class FocusLitModule(LightningModule):
+    """
+    Example of LightningModule for MNIST classification.
+
+    A LightningModule organizes your PyTorch code into 5 sections:
+        - Computations (init).
+        - Train loop (training_step)
+        - Validation loop (validation_step)
+        - Test loop (test_step)
+        - Optimizers (configure_optimizers)
+
+    Read the docs:
+        https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html
+    """
+
+    def __init__(
+        self,
+        input_size: int = 75 * 75 * 3,
+        lin1_size: int = 256,
+        lin2_size: int = 256,
+        lin3_size: int = 256,
+        output_size: int = 1,
+        lr: float = 0.001,
+        weight_decay: float = 0.0005,
+    ):
+        super().__init__()
+
+        # this line allows to access init params with 'self.hparams' attribute
+        # it also ensures init params will be stored in ckpt
+        self.save_hyperparameters(logger=False)
+
+        self.model = SimpleDenseNet(hparams=self.hparams)
+
+        # loss function
+        self.criterion = torch.nn.L1Loss()
+
+        # use separate metric instance for train, val and test step
+        # to ensure a proper reduction over the epoch
+        self.train_acc = MeanAbsoluteError()
+        self.val_acc = MeanAbsoluteError()
+        self.test_acc = MeanAbsoluteError()
+
+        # for logging best so far validation accuracy
+        self.val_acc_best = MinMetric()
+
+    def forward(self, x: torch.Tensor):
+        return self.model(x)
+
+    def step(self, batch: Any):
+        x = batch["image"]
+        y = batch["focus_value"]
+        logits = self.forward(x)
+        loss = self.criterion(logits, y)
+        preds = torch.argmax(logits, dim=1)
+        return loss, preds, y
+
+    def training_step(self, batch: Any, batch_idx: int):
+        loss, preds, targets = self.step(batch)
+
+        # log train metrics
+        acc = self.train_acc(preds, targets)
+        self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
+        self.log("train/acc", acc, on_step=False, on_epoch=True, prog_bar=True)
+
+        # we can return here dict with any tensors
+        # and then read it in some callback or in `training_epoch_end()` below
+        # remember to always return loss from `training_step()` or else backpropagation will fail!
+        return {"loss": loss, "preds": preds, "targets": targets}
+
+    def training_epoch_end(self, outputs: List[Any]):
+        # `outputs` is a list of dicts returned from `training_step()`
+        pass
+
+    def validation_step(self, batch: Any, batch_idx: int):
+        loss, preds, targets = self.step(batch)
+
+        # log val metrics
+        acc = self.val_acc(preds, targets)
+        self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
+        self.log("val/acc", acc, on_step=False, on_epoch=True, prog_bar=True)
+
+        return {"loss": loss, "preds": preds, "targets": targets}
+
+    def validation_epoch_end(self, outputs: List[Any]):
+        acc = self.val_acc.compute()  # get val accuracy from current epoch
+        self.val_acc_best.update(acc)
+        self.log(
+            "val/acc_best", self.val_acc_best.compute(), on_epoch=True, prog_bar=True
+        )
+
+    def test_step(self, batch: Any, batch_idx: int):
+        loss, preds, targets = self.step(batch)
+
+        # log test metrics
+        acc = self.test_acc(preds, targets)
+        self.log("test/loss", loss, on_step=False, on_epoch=True)
+        self.log("test/acc", acc, on_step=False, on_epoch=True)
+
+        return {"loss": loss, "preds": preds, "targets": targets}
+
+    def test_epoch_end(self, outputs: List[Any]):
+        pass
+
+    def on_epoch_end(self):
+        # reset metrics at the end of every epoch
+        self.train_acc.reset()
+        self.test_acc.reset()
+        self.val_acc.reset()
+
+    def configure_optimizers(self):
+        """Choose what optimizers and learning-rate schedulers.
+
+        Normally you'd need one. But in the case of GANs or similar you might have multiple.
+
+        See examples here:
+            https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers
+        """
+        return torch.optim.Adam(
+            params=self.parameters(),
+            lr=self.hparams.lr,
+            weight_decay=self.hparams.weight_decay,
+        )
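Note: a minimal smoke test mirroring the notebook cell above, using dummy data. The batch shapes and dtypes are assumptions inferred from the notebook output, and the import path assumes the module is loadable as in the notebook; none of this is part of the commit.

import torch
from focus_module import FocusLitModule  # loaded via SourceFileLoader in the notebook

model = FocusLitModule()

batch = {
    "image": torch.rand(64, 3, 75, 75),      # flattens to input_size = 75 * 75 * 3
    "focus_value": torch.rand(64).double(),  # float64 targets, as in the notebook
}

loss, preds, targets = model.step(batch)
print(loss)   # L1 loss; triggers the [64, 1] vs [64] broadcast warning seen above
print(preds)  # argmax over a single-unit output is always 0, hence the zeros above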