Hannes Kuchelmeister committed
Commit 91867af · 1 Parent(s): ceb9ded

re-add code modifications

configs/datamodule/focus150.yaml CHANGED
@@ -1,8 +1,10 @@
  _target_: src.datamodules.focus_datamodule.FocusDataModule

  data_dir: ${data_dir}/focus150 # data_dir is specified in config.yaml
- csv_file: ${data_dir}/focus150/metadata.csv
+ csv_train_file: ${data_dir}/focus150/train_metadata.csv
+ csv_val_file: ${data_dir}/focus150/validation_metadata.csv
+ csv_test_file: ${data_dir}/focus150/test_metadata.csv
+
  batch_size: 64
- train_val_test_split_percentage: [0.7, 0.15, 0.15]
  num_workers: 0
  pin_memory: False
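The split is no longer computed from percentages at setup time; each split now points at its own metadata CSV. A minimal sketch of what Hydra effectively instantiates from this config (assuming `${data_dir}` resolves to `data`, and using the defaults shown above):

    from src.datamodules.focus_datamodule import FocusDataModule

    # Mirrors configs/datamodule/focus150.yaml after this commit.
    dm = FocusDataModule(
        data_dir="data/focus150",
        csv_train_file="data/focus150/train_metadata.csv",
        csv_val_file="data/focus150/validation_metadata.csv",
        csv_test_file="data/focus150/test_metadata.csv",
        batch_size=64,
        num_workers=0,
        pin_memory=False,
    )
    dm.setup()  # builds data_train/data_val/data_test from the three CSVs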
configs/experiment/focusConvMSE_besthyp_150.yaml ADDED
@@ -0,0 +1,40 @@
+ # @package _global_
+
+ # to execute this experiment run:
+ # python train.py experiment=example
+
+ defaults:
+   - override /datamodule: focus150.yaml
+   - override /model: focusConv_150.yaml
+   - override /callbacks: default.yaml
+   - override /logger: many_loggers
+   - override /trainer: default.yaml
+
+ # all parameters below will be merged with parameters from default configurations set above
+ # this allows you to overwrite only specified parameters
+
+ # name of the run determines folder name in logs
+ name: "focusConvMSE_150"
+ seed: 12345
+
+ trainer:
+   min_epochs: 1
+   max_epochs: 100
+
+ model:
+   image_size: 150
+   pool_size: 2
+   conv1_size: 3
+   conv1_channels: 9
+   conv2_size: 7
+   conv2_channels: 6
+   lin1_size: 32
+   lin2_size: 72
+   output_size: 1
+   lr: 0.001
+   weight_decay: 0.0005
+
+ datamodule:
+   batch_size: 64
+   augmentation: True
+
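Presumably this file is selected by name rather than via the `experiment=example` placeholder in its header comment, i.e. `python train.py experiment=focusConvMSE_besthyp_150`; the same pattern applies to the two ResNet experiment files below. Note that the run `name` inside the file is still "focusConvMSE_150", so its logs land in that folder.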
configs/experiment/focusResNet101_pretrained_150.yaml ADDED
@@ -0,0 +1,33 @@
+ # @package _global_
+
+ # to execute this experiment run:
+ # python train.py experiment=example
+
+ defaults:
+   - override /datamodule: focus150.yaml
+   - override /model: focusResNet_150.yaml
+   - override /callbacks: default.yaml
+   - override /logger: many_loggers
+   - override /trainer: default.yaml
+
+ # all parameters below will be merged with parameters from default configurations set above
+ # this allows you to overwrite only specified parameters
+
+ # name of the run determines folder name in logs
+ name: "focusResNet101pretrained_150"
+ seed: 12345
+
+ trainer:
+   min_epochs: 1
+   max_epochs: 100
+
+ model:
+   resnet_type: "resnet101"
+   pretrained: True
+   lr: 0.0011538
+   weight_decay: 0.0005
+
+ datamodule:
+   batch_size: 64
+   augmentation: True
+
configs/experiment/focusResNet_150.yaml ADDED
@@ -0,0 +1,33 @@
+ # @package _global_
+
+ # to execute this experiment run:
+ # python train.py experiment=example
+
+ defaults:
+   - override /datamodule: focus150.yaml
+   - override /model: focusResNet_150.yaml
+   - override /callbacks: default.yaml
+   - override /logger: many_loggers
+   - override /trainer: default.yaml
+
+ # all parameters below will be merged with parameters from default configurations set above
+ # this allows you to overwrite only specified parameters
+
+ # name of the run determines folder name in logs
+ name: "focusResNet_150"
+ seed: 12345
+
+ trainer:
+   min_epochs: 1
+   max_epochs: 100
+
+ model:
+   resnet_type: "resnet50"
+   pretrained: false
+   lr: 0.001
+   weight_decay: 0.0005
+
+ datamodule:
+   batch_size: 128
+   augmentation: True
+
configs/hparams_search/focusResNetMSE_150.yaml ADDED
@@ -0,0 +1,70 @@
+ # @package _global_
+
+ # example hyperparameter optimization of some experiment with Optuna:
+ # python train.py -m hparams_search=mnist_optuna experiment=example
+
+ defaults:
+   - override /datamodule: focus150.yaml
+   - override /model: focusResNet_150.yaml
+   - override /hydra/sweeper: optuna
+
+ # choose metric which will be optimized by Optuna
+ # make sure this is the correct name of some metric logged in lightning module!
+ optimized_metric: "val/mae_best"
+
+ datamodule:
+   batch_size: 64
+   augmentation: True
+
+ name: "focusResNet_150_hyperparameter_search"
+
+ # here we define Optuna hyperparameter search
+ # it optimizes for value returned from function with @hydra.main decorator
+ # docs: https://hydra.cc/docs/next/plugins/optuna_sweeper
+ hydra:
+   sweeper:
+     _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper
+
+     # storage URL to persist optimization results
+     # for example, you can use SQLite if you set 'sqlite:///example.db'
+     storage: null
+
+     # name of the study to persist optimization results
+     study_name: focusResNet_150_hyperparameter
+
+     # number of parallel workers
+     n_jobs: 1
+
+     # 'minimize' or 'maximize' the objective
+     direction: minimize
+
+     # total number of runs that will be executed
+     n_trials: 20
+
+     # choose Optuna hyperparameter sampler
+     # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html
+     sampler:
+       _target_: optuna.samplers.TPESampler
+       seed: 12345
+       n_startup_trials: 10 # number of random sampling runs before optimization starts
+
+     # define range of hyperparameters
+     search_space:
+       model.pretrained:
+         type: categorical
+         choices: [true, false]
+       model.lr:
+         type: float
+         low: 0.0001
+         high: 0.01
+       model.resnet_type:
+         type: categorical
+         choices: [
+           "ResNet",
+           "resnet18",
+           "resnet34",
+           "resnet50",
+           "resnet101",
+           "resnext50_32x4d",
+           "wide_resnet50_2",
+         ]
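As with the MNIST example in the header comment, this sweep is presumably launched in multirun mode, e.g. `python train.py -m hparams_search=focusResNetMSE_150`. Optuna's TPE sampler then draws `model.lr` from [0.0001, 0.01], picks `model.pretrained` and `model.resnet_type` from the categorical choices, and minimizes `val/mae_best` over 20 trials, the first 10 of which are random startup trials.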
configs/model/focusResNet_150.yaml ADDED
@@ -0,0 +1,6 @@
+ _target_: src.models.focus_resnet_module.ResNetLitModule
+
+ resnet_type: "resnet50"
+ pretrained: false
+ lr: 0.001
+ weight_decay: 0.0005
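These are the defaults that the experiment and sweep configs above override per run; Hydra instantiates `ResNetLitModule` (added below in `src/models/focus_resnet_module.py`) with exactly these keyword arguments.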
data/focus150/data_ascaris.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9a130ae8a06978312f912f52ac1483cb1d2b278e3719bc10328cc39f8371107d
+ size 53392355
data/focus150/data_hookworm.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dd2e638e9f8074b0060391f33dc1cefad94fdd6072b58e686292295a54db6633
+ size 748261
data/focus150/data_schistosoma.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0c3e8c404cfdd63b1d001c7aecafb164b3ae38517fe81562dc80c0dbfcccb7a2
+ size 7372453
data/focus150/data_trichuris.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a7e471d0dfe587f88f12caf37e18b70589af5e244fdee5aa20d9664afb51d434
+ size 1413835
data/focus150/test_metadata.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:da25484eafc82b1e6a1307a2f017fee7e153820f09572504e8eabd32f7f72672
+ size 119718
data/focus150/train_metadata.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1c2c5abcda4cb0c11f033715004ec6c03c580ddeb3364bddecf6c3278b40e584
+ size 560735
data/focus150/validation_metadata.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9bc5da3f5fa7ed919d5728594df752f1d1cd2dd3a506c4fe507639bfdf9a2a1d
+ size 119551
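The data files above are Git LFS pointers (`version`/`oid`/`size` triples), not the actual JSON/CSV contents; after cloning, the real payloads (e.g. the ~53 MB `data_ascaris.json`) are fetched with `git lfs pull`.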
notebooks/3.0-hfk-model-performance-visualisation.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
src/datamodules/focus_datamodule.py CHANGED
@@ -1,5 +1,5 @@
  import os
- from typing import Optional, Tuple
+ from typing import List, Optional, Tuple
  import pandas as pd
  from skimage import io

@@ -14,7 +14,9 @@ from torchvision.transforms import transforms
  class FocusDataSet(Dataset):
      """Dataset for z-stacked images of neglected tropical diseases."""

-     def __init__(self, csv_file, root_dir, transform=None, in_memory=True):
+     def __init__(
+         self, csv_file, root_dir, transform=None, in_memory=True, additional_col_list=[]
+     ):
          """Initialize focus stack dataset.

          Args:
@@ -25,8 +27,17 @@ class FocusDataSet(Dataset):
          """
          self.metadata = pd.read_csv(csv_file)
          self.in_memory = in_memory
+
+         self.additional_col_index = {}
+
+         _col_list = list(additional_col_list)  # clone list to avoid modifying default
+         for attribute in _col_list:
+             self.additional_col_index[attribute] = self.metadata.columns.get_loc(
+                 attribute
+             )
+
          self.col_index_path = self.metadata.columns.get_loc("image_path")
-         self.col_index_focus = self.metadata.columns.get_loc("focus_value")
+         self.col_index_focus = self.metadata.columns.get_loc("focus_height")
          self.root_dir = root_dir
          self.transform = transform

@@ -56,7 +67,7 @@ class FocusDataSet(Dataset):
              idx (int) The index of the sample that is to be retrieved

          Returns:
-             Item/Items which is a dictionary containing "image" and "focus_value"
+             Item/Items which is a dictionary containing "image" and "focus_height"
          """
          if torch.is_tensor(idx):
              idx = idx.tolist()
@@ -69,11 +80,14 @@ class FocusDataSet(Dataset):
          if self.transform:
              image = self.transform(image)

-         focus_value = torch.from_numpy(
+         focus_height = torch.from_numpy(
              np.asarray(self.metadata.iloc[idx, self.col_index_focus])
          ).float()

-         sample = {"image": image, "focus_value": focus_value}
+         sample = {"image": image, "focus_height": focus_height}
+
+         for attr, col_idx in self.additional_col_index.items():
+             sample[attr] = self.metadata.iloc[idx, col_idx]

          return sample

@@ -86,27 +100,59 @@ class FocusDataModule(LightningDataModule):
      def __init__(
          self,
          data_dir: str = "data/",
-         csv_file: str = "data/metadata.csv",
-         train_val_test_split_percentage: Tuple[int, int, int] = (0.75, 0.15, 0.15),
+         csv_train_file: str = "data/train_metadata.csv",
+         csv_val_file: str = "data/validation_metadata.csv",
+         csv_test_file: str = "data/test_metadata.csv",
          batch_size: int = 64,
          num_workers: int = 0,
          pin_memory: bool = False,
          in_memory: bool = True,
+         augmentation: bool = False,
+         additional_col_list: List[str] = [],
      ):
          super().__init__()

          # this line allows to access init params with 'self.hparams' attribute
          self.save_hyperparameters(logger=False)

+         transform_list = [
+             transforms.ToTensor(),
+             transforms.ConvertImageDtype(torch.float),
+         ]
+
+         self.base_transforms = []
+         self.base_transforms.extend(transform_list)
+         self.base_transforms = transforms.Compose(self.base_transforms)
+
+         if augmentation:
+             transform_list.extend(
+                 [
+                     transforms.RandomHorizontalFlip(p=0.5),
+                     transforms.RandomVerticalFlip(p=0.5),
+                     transforms.RandomChoice(
+                         [
+                             transforms.RandomApply(
+                                 [transforms.RandomRotation((90, 90))], p=0.5
+                             ),
+                             transforms.RandomApply(
+                                 [transforms.RandomRotation((180, 180))], p=0.5
+                             ),
+                             transforms.RandomApply(
+                                 [transforms.RandomRotation((270, 270))], p=0.5
+                             ),
+                         ]
+                     ),
+                 ]
+             )
+
          # data transformations
-         self.transforms = transforms.Compose(
-             [transforms.ToTensor(), transforms.ConvertImageDtype(torch.float)]
-         )
+         self.transforms = transforms.Compose(transform_list)

          self.data_train: Optional[Dataset] = None
          self.data_val: Optional[Dataset] = None
          self.data_test: Optional[Dataset] = None
          self.in_memory = in_memory
+         self.additional_col_list = additional_col_list

      def prepare_data(self):
          """This method is not implemented as of yet.
@@ -123,24 +169,28 @@ class FocusDataModule(LightningDataModule):

          # load datasets only if they're not loaded already
          if not self.data_train and not self.data_val and not self.data_test:
-             dataset = FocusDataSet(
-                 self.hparams.csv_file,
+             self.data_train = FocusDataSet(
+                 self.hparams.csv_train_file,
                  self.hparams.data_dir,
                  transform=self.transforms,
                  in_memory=self.in_memory,
+                 additional_col_list=self.additional_col_list,
              )
-             train_length = int(
-                 len(dataset) * self.hparams.train_val_test_split_percentage[0]
-             )
-             val_length = int(
-                 len(dataset) * self.hparams.train_val_test_split_percentage[1]
+
+             self.data_val = FocusDataSet(
+                 self.hparams.csv_val_file,
+                 self.hparams.data_dir,
+                 transform=self.base_transforms,
+                 in_memory=self.in_memory,
+                 additional_col_list=self.additional_col_list,
              )
-             test_length = len(dataset) - val_length - train_length

-             self.data_train, self.data_val, self.data_test = random_split(
-                 dataset=dataset,
-                 lengths=(train_length, test_length, val_length),
-                 generator=torch.Generator().manual_seed(42),
+             self.data_test = FocusDataSet(
+                 self.hparams.csv_test_file,
+                 self.hparams.data_dir,
+                 transform=self.base_transforms,
+                 in_memory=self.in_memory,
+                 additional_col_list=self.additional_col_list,
              )

      def train_dataloader(self):
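The new `additional_col_list` argument lets each sample carry extra metadata columns next to `"image"` and `"focus_height"`, while the validation and test sets now skip the augmentation pipeline by using `base_transforms`. A minimal sketch of the dataset API, assuming a hypothetical extra column named `study_id` exists in the metadata CSV:

    from src.datamodules.focus_datamodule import FocusDataSet

    ds = FocusDataSet(
        "data/focus150/validation_metadata.csv",
        "data/focus150",
        transform=None,
        in_memory=False,
        additional_col_list=["study_id"],  # hypothetical column, for illustration
    )
    sample = ds[0]
    # sample is a dict: {"image": ..., "focus_height": ..., "study_id": ...}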
src/models/focus_conv_module.py CHANGED
@@ -98,7 +98,7 @@ class FocusConvLitModule(LightningModule):

      def step(self, batch: Any):
          x = batch["image"]
-         y = batch["focus_value"]
+         y = batch["focus_height"]
          logits = self.forward(x)
          loss = self.criterion(logits, y.unsqueeze(1))
          preds = torch.squeeze(logits)
src/models/focus_module.py CHANGED
@@ -83,7 +83,7 @@ class FocusLitModule(LightningModule):

      def step(self, batch: Any):
          x = batch["image"]
-         y = batch["focus_value"]
+         y = batch["focus_height"]
          logits = self.forward(x)
          loss = self.criterion(logits, y.unsqueeze(1))
          preds = torch.squeeze(logits)
@@ -208,7 +208,7 @@ class FocusMSELitModule(LightningModule):

      def step(self, batch: Any):
          x = batch["image"]
-         y = batch["focus_value"]
+         y = batch["focus_height"]
          logits = self.forward(x)
          loss = self.criterion(logits, y.unsqueeze(1))
          preds = torch.squeeze(logits)
src/models/focus_resnet_module.py ADDED
@@ -0,0 +1,162 @@
+ from typing import Any, List
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ from pytorch_lightning import LightningModule
+ from torchmetrics import MaxMetric, MeanAbsoluteError, MinMetric
+ from torchmetrics.classification.accuracy import Accuracy
+ import torchvision.models as models
+
+
+ class ResNetLitModule(LightningModule):
+     def __init__(
+         self,
+         resnet_type: str = "ResNet",
+         pretrained=False,
+         lr: float = 0.001,
+         weight_decay: float = 0.0005,
+     ):
+         """Initialize function for a resnet module.
+
+         Args:
+             resnet_type (str, optional): Type of the used resnet network. Defaults to
+                 "ResNet".
+                 Can be one of the following values: "ResNet", "resnet18",
+                 "resnet34", "resnet50", "resnet101", "resnet152", "resnext50_32x4d",
+                 "resnext101_32x8d", "wide_resnet50_2", "wide_resnet101_2"
+             pretrained (bool, optional): if True loads pytorch pretrained models.
+                 Defaults to False.
+         """
+         super().__init__()
+
+         # this line allows to access init params with 'self.hparams' attribute
+         # it also ensures init params will be stored in ckpt
+         self.save_hyperparameters(logger=False)
+
+         # loss function
+         self.criterion = torch.nn.MSELoss()
+
+         # use separate metric instance for train, val and test step
+         # to ensure a proper reduction over the epoch
+         self.train_mae = MeanAbsoluteError()
+         self.val_mae = MeanAbsoluteError()
+         self.test_mae = MeanAbsoluteError()
+
+         # for logging best so far validation mae
+         self.val_mae_best = MinMetric()
+
+         self.pretrained = pretrained
+
+         if resnet_type == "ResNet":
+             resnet_constructor = models.ResNet
+         elif resnet_type == "resnet18":
+             resnet_constructor = models.resnet18
+         elif resnet_type == "resnet34":
+             resnet_constructor = models.resnet34
+         elif resnet_type == "resnet50":
+             resnet_constructor = models.resnet50
+         elif resnet_type == "resnet101":
+             resnet_constructor = models.resnet101
+         elif resnet_type == "resnet152":
+             resnet_constructor = models.resnet152
+         elif resnet_type == "resnext50_32x4d":
+             resnet_constructor = models.resnext50_32x4d
+         elif resnet_type == "resnext101_32x8d":
+             resnet_constructor = models.resnext101_32x8d
+         elif resnet_type == "wide_resnet50_2":
+             resnet_constructor = models.wide_resnet50_2
+         elif resnet_type == "wide_resnet101_2":
+             resnet_constructor = models.wide_resnet101_2
+         else:
+             raise Exception(f"did not find model type: {resnet_type}")
+
+         # init a (possibly pretrained) resnet backbone
+         backbone = resnet_constructor(pretrained=pretrained)
+
+         num_filters = backbone.fc.in_features
+         layers = list(backbone.children())[:-1]
+         self.feature_extractor = nn.Sequential(*layers)
+
+         self.fc = nn.Linear(num_filters, 1)
+
+     def forward(self, x):
+         representations = self.feature_extractor(x).flatten(1)
+         x = self.fc(representations)
+         return x
+
+     def step(self, batch: Any):
+         x = batch["image"]
+         y = batch["focus_height"]
+         logits = self.forward(x)
+         loss = self.criterion(logits, y.unsqueeze(1))
+         preds = torch.squeeze(logits)
+         return loss, preds, y
+
+     def training_step(self, batch: Any, batch_idx: int):
+         loss, preds, targets = self.step(batch)
+
+         # log train metrics
+         mae = self.train_mae(preds, targets)
+         self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
+         self.log("train/mae", mae, on_step=False, on_epoch=True, prog_bar=True)
+
+         # we can return here dict with any tensors
+         # and then read it in some callback or in `training_epoch_end()` below
+         # remember to always return loss from `training_step()` or else
+         # backpropagation will fail!
+         return {"loss": loss, "preds": preds, "targets": targets}
+
+     def training_epoch_end(self, outputs: List[Any]):
+         # `outputs` is a list of dicts returned from `training_step()`
+         pass
+
+     def validation_step(self, batch: Any, batch_idx: int):
+         loss, preds, targets = self.step(batch)
+
+         # log val metrics
+         mae = self.val_mae(preds, targets)
+         self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
+         self.log("val/mae", mae, on_step=False, on_epoch=True, prog_bar=True)
+
+         return {"loss": loss, "preds": preds, "targets": targets}
+
+     def validation_epoch_end(self, outputs: List[Any]):
+         mae = self.val_mae.compute()  # get val mae from current epoch
+         self.val_mae_best.update(mae)
+         self.log(
+             "val/mae_best", self.val_mae_best.compute(), on_epoch=True, prog_bar=True
+         )
+
+     def test_step(self, batch: Any, batch_idx: int):
+         loss, preds, targets = self.step(batch)
+
+         # log test metrics
+         mae = self.test_mae(preds, targets)
+         self.log("test/loss", loss, on_step=False, on_epoch=True)
+         self.log("test/mae", mae, on_step=False, on_epoch=True)
+
+     def test_epoch_end(self, outputs: List[Any]):
+         print(outputs)
+         pass
+
+     def on_epoch_end(self):
+         # reset metrics at the end of every epoch
+         self.train_mae.reset()
+         self.test_mae.reset()
+         self.val_mae.reset()
+
+     def configure_optimizers(self):
+         """Choose what optimizers and learning-rate schedulers to use.
+
+         Normally you'd need one. But in the case of GANs or similar you might
+         have multiple.
+
+         See examples here:
+         https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers
+         """
+         return torch.optim.Adam(
+             params=self.parameters(),
+             lr=self.hparams.lr,
+             weight_decay=self.hparams.weight_decay,
+         )
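A minimal smoke test of the new module, assuming 3-channel 150x150 focus crops as used by the configs above (shapes and values are illustrative only):

    import torch
    from src.models.focus_resnet_module import ResNetLitModule

    model = ResNetLitModule(resnet_type="resnet18", pretrained=False)
    x = torch.randn(2, 3, 150, 150)  # batch of two image crops
    with torch.no_grad():
        out = model(x)
    print(out.shape)  # torch.Size([2, 1]): one regressed focus height per image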