Hannes Kuchelmeister
commited on
Commit
·
91867af
1
Parent(s):
ceb9ded
readd cod modifications
Browse files- configs/datamodule/focus150.yaml +4 -2
- configs/experiment/focusConvMSE_besthyp_150.yaml +40 -0
- configs/experiment/focusResNet101_pretrained_150.yaml +33 -0
- configs/experiment/focusResNet_150.yaml +33 -0
- configs/hparams_search/focusResNetMSE_150.yaml +70 -0
- configs/model/focusResNet_150.yaml +6 -0
- data/focus150/data_ascaris.json +3 -0
- data/focus150/data_hookworm.json +3 -0
- data/focus150/data_schistosoma.json +3 -0
- data/focus150/data_trichuris.json +3 -0
- data/focus150/test_metadata.csv +3 -0
- data/focus150/train_metadata.csv +3 -0
- data/focus150/validation_metadata.csv +3 -0
- notebooks/3.0-hfk-model-performance-visualisation.ipynb +0 -0
- src/datamodules/focus_datamodule.py +73 -23
- src/models/focus_conv_module.py +1 -1
- src/models/focus_module.py +2 -2
- src/models/focus_resnet_module.py +162 -0
configs/datamodule/focus150.yaml
CHANGED
@@ -1,8 +1,10 @@
|
|
1 |
_target_: src.datamodules.focus_datamodule.FocusDataModule
|
2 |
|
3 |
data_dir: ${data_dir}/focus150 # data_dir is specified in config.yaml
|
4 |
-
|
|
|
|
|
|
|
5 |
batch_size: 64
|
6 |
-
train_val_test_split_percentage: [0.7, 0.15, 0.15]
|
7 |
num_workers: 0
|
8 |
pin_memory: False
|
|
|
1 |
_target_: src.datamodules.focus_datamodule.FocusDataModule
|
2 |
|
3 |
data_dir: ${data_dir}/focus150 # data_dir is specified in config.yaml
|
4 |
+
csv_train_file: ${data_dir}/focus150/train_metadata.csv
|
5 |
+
csv_val_file: ${data_dir}/focus150/validation_metadata.csv
|
6 |
+
csv_test_file: ${data_dir}/focus150/test_metadata.csv
|
7 |
+
|
8 |
batch_size: 64
|
|
|
9 |
num_workers: 0
|
10 |
pin_memory: False
|
configs/experiment/focusConvMSE_besthyp_150.yaml
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
# to execute this experiment run:
|
4 |
+
# python train.py experiment=example
|
5 |
+
|
6 |
+
defaults:
|
7 |
+
- override /datamodule: focus150.yaml
|
8 |
+
- override /model: focusConv_150.yaml
|
9 |
+
- override /callbacks: default.yaml
|
10 |
+
- override /logger: many_loggers
|
11 |
+
- override /trainer: default.yaml
|
12 |
+
|
13 |
+
# all parameters below will be merged with parameters from default configurations set above
|
14 |
+
# this allows you to overwrite only specified parameters
|
15 |
+
|
16 |
+
# name of the run determines folder name in logs
|
17 |
+
name: "focusConvMSE_150"
|
18 |
+
seed: 12345
|
19 |
+
|
20 |
+
trainer:
|
21 |
+
min_epochs: 1
|
22 |
+
max_epochs: 100
|
23 |
+
|
24 |
+
model:
|
25 |
+
image_size: 150
|
26 |
+
pool_size: 2
|
27 |
+
conv1_size: 3
|
28 |
+
conv1_channels: 9
|
29 |
+
conv2_size: 7
|
30 |
+
conv2_channels: 6
|
31 |
+
lin1_size: 32
|
32 |
+
lin2_size: 72
|
33 |
+
output_size: 1
|
34 |
+
lr: 0.001
|
35 |
+
weight_decay: 0.0005
|
36 |
+
|
37 |
+
datamodule:
|
38 |
+
batch_size: 64
|
39 |
+
augmentation: True
|
40 |
+
|
configs/experiment/focusResNet101_pretrained_150.yaml
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
# to execute this experiment run:
|
4 |
+
# python train.py experiment=example
|
5 |
+
|
6 |
+
defaults:
|
7 |
+
- override /datamodule: focus150.yaml
|
8 |
+
- override /model: focusResNet_150.yaml
|
9 |
+
- override /callbacks: default.yaml
|
10 |
+
- override /logger: many_loggers
|
11 |
+
- override /trainer: default.yaml
|
12 |
+
|
13 |
+
# all parameters below will be merged with parameters from default configurations set above
|
14 |
+
# this allows you to overwrite only specified parameters
|
15 |
+
|
16 |
+
# name of the run determines folder name in logs
|
17 |
+
name: "focusResNet101pretrained_150"
|
18 |
+
seed: 12345
|
19 |
+
|
20 |
+
trainer:
|
21 |
+
min_epochs: 1
|
22 |
+
max_epochs: 100
|
23 |
+
|
24 |
+
model:
|
25 |
+
resnet_type: "resnet101"
|
26 |
+
pretrained: True
|
27 |
+
lr: 0.0011538
|
28 |
+
weight_decay: 0.0005
|
29 |
+
|
30 |
+
datamodule:
|
31 |
+
batch_size: 64
|
32 |
+
augmentation: True
|
33 |
+
|
configs/experiment/focusResNet_150.yaml
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
# to execute this experiment run:
|
4 |
+
# python train.py experiment=example
|
5 |
+
|
6 |
+
defaults:
|
7 |
+
- override /datamodule: focus150.yaml
|
8 |
+
- override /model: focusResNet_150.yaml
|
9 |
+
- override /callbacks: default.yaml
|
10 |
+
- override /logger: many_loggers
|
11 |
+
- override /trainer: default.yaml
|
12 |
+
|
13 |
+
# all parameters below will be merged with parameters from default configurations set above
|
14 |
+
# this allows you to overwrite only specified parameters
|
15 |
+
|
16 |
+
# name of the run determines folder name in logs
|
17 |
+
name: "focusResNet_150"
|
18 |
+
seed: 12345
|
19 |
+
|
20 |
+
trainer:
|
21 |
+
min_epochs: 1
|
22 |
+
max_epochs: 100
|
23 |
+
|
24 |
+
model:
|
25 |
+
resnet_type: "resnet50"
|
26 |
+
pretrained: false
|
27 |
+
lr: 0.001
|
28 |
+
weight_decay: 0.0005
|
29 |
+
|
30 |
+
datamodule:
|
31 |
+
batch_size: 128
|
32 |
+
augmentation: True
|
33 |
+
|
configs/hparams_search/focusResNetMSE_150.yaml
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
# example hyperparameter optimization of some experiment with Optuna:
|
4 |
+
# python train.py -m hparams_search=mnist_optuna experiment=example
|
5 |
+
|
6 |
+
defaults:
|
7 |
+
- override /datamodule: focus150.yaml
|
8 |
+
- override /model: focusResNet_150.yaml
|
9 |
+
- override /hydra/sweeper: optuna
|
10 |
+
|
11 |
+
# choose metric which will be optimized by Optuna
|
12 |
+
# make sure this is the correct name of some metric logged in lightning module!
|
13 |
+
optimized_metric: "val/mae_best"
|
14 |
+
|
15 |
+
datamodule:
|
16 |
+
batch_size: 64
|
17 |
+
augmentation: True
|
18 |
+
|
19 |
+
name: "focusResNet_150_hyperparameter_search"
|
20 |
+
|
21 |
+
# here we define Optuna hyperparameter search
|
22 |
+
# it optimizes for value returned from function with @hydra.main decorator
|
23 |
+
# docs: https://hydra.cc/docs/next/plugins/optuna_sweeper
|
24 |
+
hydra:
|
25 |
+
sweeper:
|
26 |
+
_target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper
|
27 |
+
|
28 |
+
# storage URL to persist optimization results
|
29 |
+
# for example, you can use SQLite if you set 'sqlite:///example.db'
|
30 |
+
storage: null
|
31 |
+
|
32 |
+
# name of the study to persist optimization results
|
33 |
+
study_name: focusResNet_150_hyperparameter
|
34 |
+
|
35 |
+
# number of parallel workers
|
36 |
+
n_jobs: 1
|
37 |
+
|
38 |
+
# 'minimize' or 'maximize' the objective
|
39 |
+
direction: minimize
|
40 |
+
|
41 |
+
# total number of runs that will be executed
|
42 |
+
n_trials: 20
|
43 |
+
|
44 |
+
# choose Optuna hyperparameter sampler
|
45 |
+
# docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html
|
46 |
+
sampler:
|
47 |
+
_target_: optuna.samplers.TPESampler
|
48 |
+
seed: 12345
|
49 |
+
n_startup_trials: 10 # number of random sampling runs before optimization starts
|
50 |
+
|
51 |
+
# define range of hyperparameters
|
52 |
+
search_space:
|
53 |
+
model.pretrained:
|
54 |
+
type: categorical
|
55 |
+
choices: [true, false]
|
56 |
+
model.lr:
|
57 |
+
type: float
|
58 |
+
low: 0.0001
|
59 |
+
high: 0.01
|
60 |
+
model.resnet_type:
|
61 |
+
type: categorical
|
62 |
+
choices: [
|
63 |
+
"ResNet",
|
64 |
+
"resnet18",
|
65 |
+
"resnet34",
|
66 |
+
"resnet50",
|
67 |
+
"resnet101",
|
68 |
+
"resnext50_32x4d",
|
69 |
+
"wide_resnet50_2",
|
70 |
+
]
|
configs/model/focusResNet_150.yaml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_target_: src.models.focus_resnet_module.ResNetLitModule
|
2 |
+
|
3 |
+
resnet_type: "resnet50"
|
4 |
+
pretrained: false
|
5 |
+
lr: 0.001
|
6 |
+
weight_decay: 0.0005
|
data/focus150/data_ascaris.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9a130ae8a06978312f912f52ac1483cb1d2b278e3719bc10328cc39f8371107d
|
3 |
+
size 53392355
|
data/focus150/data_hookworm.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:dd2e638e9f8074b0060391f33dc1cefad94fdd6072b58e686292295a54db6633
|
3 |
+
size 748261
|
data/focus150/data_schistosoma.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0c3e8c404cfdd63b1d001c7aecafb164b3ae38517fe81562dc80c0dbfcccb7a2
|
3 |
+
size 7372453
|
data/focus150/data_trichuris.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a7e471d0dfe587f88f12caf37e18b70589af5e244fdee5aa20d9664afb51d434
|
3 |
+
size 1413835
|
data/focus150/test_metadata.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:da25484eafc82b1e6a1307a2f017fee7e153820f09572504e8eabd32f7f72672
|
3 |
+
size 119718
|
data/focus150/train_metadata.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1c2c5abcda4cb0c11f033715004ec6c03c580ddeb3364bddecf6c3278b40e584
|
3 |
+
size 560735
|
data/focus150/validation_metadata.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9bc5da3f5fa7ed919d5728594df752f1d1cd2dd3a506c4fe507639bfdf9a2a1d
|
3 |
+
size 119551
|
notebooks/3.0-hfk-model-performance-visualisation.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
src/datamodules/focus_datamodule.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
import os
|
2 |
-
from typing import Optional, Tuple
|
3 |
import pandas as pd
|
4 |
from skimage import io
|
5 |
|
@@ -14,7 +14,9 @@ from torchvision.transforms import transforms
|
|
14 |
class FocusDataSet(Dataset):
|
15 |
"""Dataset for z-stacked images of neglected tropical diseaeses."""
|
16 |
|
17 |
-
def __init__(
|
|
|
|
|
18 |
"""Initialize focus satck dataset.
|
19 |
|
20 |
Args:
|
@@ -25,8 +27,17 @@ class FocusDataSet(Dataset):
|
|
25 |
"""
|
26 |
self.metadata = pd.read_csv(csv_file)
|
27 |
self.in_memory = in_memory
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
self.col_index_path = self.metadata.columns.get_loc("image_path")
|
29 |
-
self.col_index_focus = self.metadata.columns.get_loc("
|
30 |
self.root_dir = root_dir
|
31 |
self.transform = transform
|
32 |
|
@@ -56,7 +67,7 @@ class FocusDataSet(Dataset):
|
|
56 |
idx (int) The index of the sample that is to be retrieved
|
57 |
|
58 |
Returns:
|
59 |
-
Item/Items which is a dictionary containing "image" and "
|
60 |
"""
|
61 |
if torch.is_tensor(idx):
|
62 |
idx = idx.tolist()
|
@@ -69,11 +80,14 @@ class FocusDataSet(Dataset):
|
|
69 |
if self.transform:
|
70 |
image = self.transform(image)
|
71 |
|
72 |
-
|
73 |
np.asarray(self.metadata.iloc[idx, self.col_index_focus])
|
74 |
).float()
|
75 |
|
76 |
-
sample = {"image": image, "
|
|
|
|
|
|
|
77 |
|
78 |
return sample
|
79 |
|
@@ -86,27 +100,59 @@ class FocusDataModule(LightningDataModule):
|
|
86 |
def __init__(
|
87 |
self,
|
88 |
data_dir: str = "data/",
|
89 |
-
|
90 |
-
|
|
|
91 |
batch_size: int = 64,
|
92 |
num_workers: int = 0,
|
93 |
pin_memory: bool = False,
|
94 |
in_memory: bool = True,
|
|
|
|
|
95 |
):
|
96 |
super().__init__()
|
97 |
|
98 |
# this line allows to access init params with 'self.hparams' attribute
|
99 |
self.save_hyperparameters(logger=False)
|
100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
# data transformations
|
102 |
-
self.transforms = transforms.Compose(
|
103 |
-
[transforms.ToTensor(), transforms.ConvertImageDtype(torch.float)]
|
104 |
-
)
|
105 |
|
106 |
self.data_train: Optional[Dataset] = None
|
107 |
self.data_val: Optional[Dataset] = None
|
108 |
self.data_test: Optional[Dataset] = None
|
109 |
self.in_memory = in_memory
|
|
|
110 |
|
111 |
def prepare_data(self):
|
112 |
"""This method is not implemented as of yet.
|
@@ -123,24 +169,28 @@ class FocusDataModule(LightningDataModule):
|
|
123 |
|
124 |
# load datasets only if they're not loaded already
|
125 |
if not self.data_train and not self.data_val and not self.data_test:
|
126 |
-
|
127 |
-
self.hparams.
|
128 |
self.hparams.data_dir,
|
129 |
transform=self.transforms,
|
130 |
in_memory=self.in_memory,
|
|
|
131 |
)
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
|
|
|
|
137 |
)
|
138 |
-
test_length = len(dataset) - val_length - train_length
|
139 |
|
140 |
-
self.
|
141 |
-
|
142 |
-
|
143 |
-
|
|
|
|
|
144 |
)
|
145 |
|
146 |
def train_dataloader(self):
|
|
|
1 |
import os
|
2 |
+
from typing import List, Optional, Tuple
|
3 |
import pandas as pd
|
4 |
from skimage import io
|
5 |
|
|
|
14 |
class FocusDataSet(Dataset):
|
15 |
"""Dataset for z-stacked images of neglected tropical diseaeses."""
|
16 |
|
17 |
+
def __init__(
|
18 |
+
self, csv_file, root_dir, transform=None, in_memory=True, additional_col_list=[]
|
19 |
+
):
|
20 |
"""Initialize focus satck dataset.
|
21 |
|
22 |
Args:
|
|
|
27 |
"""
|
28 |
self.metadata = pd.read_csv(csv_file)
|
29 |
self.in_memory = in_memory
|
30 |
+
|
31 |
+
self.additional_col_index = {}
|
32 |
+
|
33 |
+
_col_list = list(additional_col_list) # clone list to avoid modifying default
|
34 |
+
for attribute in _col_list:
|
35 |
+
self.additional_col_index[attribute] = self.metadata.columns.get_loc(
|
36 |
+
attribute
|
37 |
+
)
|
38 |
+
|
39 |
self.col_index_path = self.metadata.columns.get_loc("image_path")
|
40 |
+
self.col_index_focus = self.metadata.columns.get_loc("focus_height")
|
41 |
self.root_dir = root_dir
|
42 |
self.transform = transform
|
43 |
|
|
|
67 |
idx (int) The index of the sample that is to be retrieved
|
68 |
|
69 |
Returns:
|
70 |
+
Item/Items which is a dictionary containing "image" and "focus_height"
|
71 |
"""
|
72 |
if torch.is_tensor(idx):
|
73 |
idx = idx.tolist()
|
|
|
80 |
if self.transform:
|
81 |
image = self.transform(image)
|
82 |
|
83 |
+
focus_height = torch.from_numpy(
|
84 |
np.asarray(self.metadata.iloc[idx, self.col_index_focus])
|
85 |
).float()
|
86 |
|
87 |
+
sample = {"image": image, "focus_height": focus_height}
|
88 |
+
|
89 |
+
for attr, col_idx in self.additional_col_index.items():
|
90 |
+
sample[attr] = self.metadata.iloc[idx, col_idx]
|
91 |
|
92 |
return sample
|
93 |
|
|
|
100 |
def __init__(
|
101 |
self,
|
102 |
data_dir: str = "data/",
|
103 |
+
csv_train_file: str = "data/train_metadata.csv",
|
104 |
+
csv_val_file: str = "data/validation_metadata.csv",
|
105 |
+
csv_test_file: str = "data/test_metadata.csv",
|
106 |
batch_size: int = 64,
|
107 |
num_workers: int = 0,
|
108 |
pin_memory: bool = False,
|
109 |
in_memory: bool = True,
|
110 |
+
augmentation: bool = False,
|
111 |
+
additional_col_list: List[str] = [],
|
112 |
):
|
113 |
super().__init__()
|
114 |
|
115 |
# this line allows to access init params with 'self.hparams' attribute
|
116 |
self.save_hyperparameters(logger=False)
|
117 |
|
118 |
+
transform_list = [
|
119 |
+
transforms.ToTensor(),
|
120 |
+
transforms.ConvertImageDtype(torch.float),
|
121 |
+
]
|
122 |
+
|
123 |
+
self.base_transforms = []
|
124 |
+
self.base_transforms.extend(transform_list)
|
125 |
+
self.base_transforms = transforms.Compose(self.base_transforms)
|
126 |
+
|
127 |
+
if augmentation:
|
128 |
+
transform_list.extend(
|
129 |
+
[
|
130 |
+
transforms.RandomHorizontalFlip(p=0.5),
|
131 |
+
transforms.RandomVerticalFlip(p=0.5),
|
132 |
+
transforms.RandomChoice(
|
133 |
+
[
|
134 |
+
transforms.RandomApply(
|
135 |
+
[transforms.RandomRotation((90, 90))], p=0.5
|
136 |
+
),
|
137 |
+
transforms.RandomApply(
|
138 |
+
[transforms.RandomRotation((180, 180))], p=0.5
|
139 |
+
),
|
140 |
+
transforms.RandomApply(
|
141 |
+
[transforms.RandomRotation((270, 270))], p=0.5
|
142 |
+
),
|
143 |
+
]
|
144 |
+
),
|
145 |
+
]
|
146 |
+
)
|
147 |
+
|
148 |
# data transformations
|
149 |
+
self.transforms = transforms.Compose(transform_list)
|
|
|
|
|
150 |
|
151 |
self.data_train: Optional[Dataset] = None
|
152 |
self.data_val: Optional[Dataset] = None
|
153 |
self.data_test: Optional[Dataset] = None
|
154 |
self.in_memory = in_memory
|
155 |
+
self.additional_col_list = additional_col_list
|
156 |
|
157 |
def prepare_data(self):
|
158 |
"""This method is not implemented as of yet.
|
|
|
169 |
|
170 |
# load datasets only if they're not loaded already
|
171 |
if not self.data_train and not self.data_val and not self.data_test:
|
172 |
+
self.data_train = FocusDataSet(
|
173 |
+
self.hparams.csv_train_file,
|
174 |
self.hparams.data_dir,
|
175 |
transform=self.transforms,
|
176 |
in_memory=self.in_memory,
|
177 |
+
additional_col_list=self.additional_col_list,
|
178 |
)
|
179 |
+
|
180 |
+
self.data_val = FocusDataSet(
|
181 |
+
self.hparams.csv_val_file,
|
182 |
+
self.hparams.data_dir,
|
183 |
+
transform=self.base_transforms,
|
184 |
+
in_memory=self.in_memory,
|
185 |
+
additional_col_list=self.additional_col_list,
|
186 |
)
|
|
|
187 |
|
188 |
+
self.data_test = FocusDataSet(
|
189 |
+
self.hparams.csv_test_file,
|
190 |
+
self.hparams.data_dir,
|
191 |
+
transform=self.base_transforms,
|
192 |
+
in_memory=self.in_memory,
|
193 |
+
additional_col_list=self.additional_col_list,
|
194 |
)
|
195 |
|
196 |
def train_dataloader(self):
|
src/models/focus_conv_module.py
CHANGED
@@ -98,7 +98,7 @@ class FocusConvLitModule(LightningModule):
|
|
98 |
|
99 |
def step(self, batch: Any):
|
100 |
x = batch["image"]
|
101 |
-
y = batch["
|
102 |
logits = self.forward(x)
|
103 |
loss = self.criterion(logits, y.unsqueeze(1))
|
104 |
preds = torch.squeeze(logits)
|
|
|
98 |
|
99 |
def step(self, batch: Any):
|
100 |
x = batch["image"]
|
101 |
+
y = batch["focus_height"]
|
102 |
logits = self.forward(x)
|
103 |
loss = self.criterion(logits, y.unsqueeze(1))
|
104 |
preds = torch.squeeze(logits)
|
src/models/focus_module.py
CHANGED
@@ -83,7 +83,7 @@ class FocusLitModule(LightningModule):
|
|
83 |
|
84 |
def step(self, batch: Any):
|
85 |
x = batch["image"]
|
86 |
-
y = batch["
|
87 |
logits = self.forward(x)
|
88 |
loss = self.criterion(logits, y.unsqueeze(1))
|
89 |
preds = torch.squeeze(logits)
|
@@ -208,7 +208,7 @@ class FocusMSELitModule(LightningModule):
|
|
208 |
|
209 |
def step(self, batch: Any):
|
210 |
x = batch["image"]
|
211 |
-
y = batch["
|
212 |
logits = self.forward(x)
|
213 |
loss = self.criterion(logits, y.unsqueeze(1))
|
214 |
preds = torch.squeeze(logits)
|
|
|
83 |
|
84 |
def step(self, batch: Any):
|
85 |
x = batch["image"]
|
86 |
+
y = batch["focus_height"]
|
87 |
logits = self.forward(x)
|
88 |
loss = self.criterion(logits, y.unsqueeze(1))
|
89 |
preds = torch.squeeze(logits)
|
|
|
208 |
|
209 |
def step(self, batch: Any):
|
210 |
x = batch["image"]
|
211 |
+
y = batch["focus_height"]
|
212 |
logits = self.forward(x)
|
213 |
loss = self.criterion(logits, y.unsqueeze(1))
|
214 |
preds = torch.squeeze(logits)
|
src/models/focus_resnet_module.py
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, List
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch import nn
|
6 |
+
from pytorch_lightning import LightningModule
|
7 |
+
from torchmetrics import MaxMetric, MeanAbsoluteError, MinMetric
|
8 |
+
from torchmetrics.classification.accuracy import Accuracy
|
9 |
+
import torchvision.models as models
|
10 |
+
|
11 |
+
|
12 |
+
class ResNetLitModule(LightningModule):
|
13 |
+
def __init__(
|
14 |
+
self,
|
15 |
+
resnet_type: str = "ResNet",
|
16 |
+
pretrained=False,
|
17 |
+
lr: float = 0.001,
|
18 |
+
weight_decay: float = 0.0005,
|
19 |
+
):
|
20 |
+
"""Initialize function for a resnet module.
|
21 |
+
|
22 |
+
Args:
|
23 |
+
resnet_type (str, optional): Type of the used resnet network. Defaults to
|
24 |
+
"ResNet".
|
25 |
+
Can be one of the following values: "ResNet", "resnet18",
|
26 |
+
"resnet34", "resnet50", "resnet101", "resnet152", "resnext50_32x4d",
|
27 |
+
"resnext101_32x8d", "wide_resnet50_2", "wide_resnet101_2"
|
28 |
+
pretrained (bool, optional): if True loads pytorch pretrained models.
|
29 |
+
Defaults to False.
|
30 |
+
"""
|
31 |
+
super().__init__()
|
32 |
+
|
33 |
+
# this line allows to access init params with 'self.hparams' attribute
|
34 |
+
# it also ensures init params will be stored in ckpt
|
35 |
+
self.save_hyperparameters(logger=False)
|
36 |
+
|
37 |
+
# loss function
|
38 |
+
self.criterion = torch.nn.MSELoss()
|
39 |
+
|
40 |
+
# use separate metric instance for train, val and test step
|
41 |
+
# to ensure a proper reduction over the epoch
|
42 |
+
self.train_mae = MeanAbsoluteError()
|
43 |
+
self.val_mae = MeanAbsoluteError()
|
44 |
+
self.test_mae = MeanAbsoluteError()
|
45 |
+
|
46 |
+
# for logging best so far validation accuracy
|
47 |
+
self.val_mae_best = MinMetric()
|
48 |
+
|
49 |
+
self.pretrained = pretrained
|
50 |
+
|
51 |
+
if resnet_type == "ResNet":
|
52 |
+
resnet_constructor = models.ResNet()
|
53 |
+
elif resnet_type == "resnet18":
|
54 |
+
resnet_constructor = models.resnet18
|
55 |
+
elif resnet_type == "resnet34":
|
56 |
+
resnet_constructor = models.resnet34
|
57 |
+
elif resnet_type == "resnet50":
|
58 |
+
resnet_constructor = models.resnet50
|
59 |
+
elif resnet_type == "resnet101":
|
60 |
+
resnet_constructor = models.resnet101
|
61 |
+
elif resnet_type == "resnet152":
|
62 |
+
resnet_constructor = models.resnet152
|
63 |
+
elif resnet_type == "resnext50_32x4d":
|
64 |
+
resnet_constructor = models.resnext50_32x4d
|
65 |
+
elif resnet_type == "resnext101_32x8d":
|
66 |
+
resnet_constructor = models.resnext101_32x8d
|
67 |
+
elif resnet_type == "wide_resnet50_2":
|
68 |
+
resnet_constructor = models.wide_resnet50_2
|
69 |
+
elif resnet_type == "wide_resnet101_2":
|
70 |
+
resnet_constructor = models.wide_resnet101_2
|
71 |
+
else:
|
72 |
+
raise Exception(f"did not find model type: {resnet_type}")
|
73 |
+
|
74 |
+
backbone = resnet_constructor(pretrained=pretrained)
|
75 |
+
# init a pretrained resnet
|
76 |
+
|
77 |
+
num_filters = backbone.fc.in_features
|
78 |
+
layers = list(backbone.children())[:-1]
|
79 |
+
self.feature_extractor = nn.Sequential(*layers)
|
80 |
+
|
81 |
+
self.fc = nn.Linear(num_filters, 1)
|
82 |
+
|
83 |
+
def forward(self, x):
|
84 |
+
representations = self.feature_extractor(x).flatten(1)
|
85 |
+
x = self.fc(representations)
|
86 |
+
return x
|
87 |
+
|
88 |
+
def step(self, batch: Any):
|
89 |
+
x = batch["image"]
|
90 |
+
y = batch["focus_height"]
|
91 |
+
logits = self.forward(x)
|
92 |
+
loss = self.criterion(logits, y.unsqueeze(1))
|
93 |
+
preds = torch.squeeze(logits)
|
94 |
+
return loss, preds, y
|
95 |
+
|
96 |
+
def training_step(self, batch: Any, batch_idx: int):
|
97 |
+
loss, preds, targets = self.step(batch)
|
98 |
+
|
99 |
+
# log train metrics
|
100 |
+
mae = self.train_mae(preds, targets)
|
101 |
+
self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
|
102 |
+
self.log("train/mae", mae, on_step=False, on_epoch=True, prog_bar=True)
|
103 |
+
|
104 |
+
# we can return here dict with any tensors
|
105 |
+
# and then read it in some callback or in `training_epoch_end()`` below
|
106 |
+
# remember to always return loss from `training_step()` or else
|
107 |
+
# backpropagation will fail!
|
108 |
+
return {"loss": loss, "preds": preds, "targets": targets}
|
109 |
+
|
110 |
+
def training_epoch_end(self, outputs: List[Any]):
|
111 |
+
# `outputs` is a list of dicts returned from `training_step()`
|
112 |
+
pass
|
113 |
+
|
114 |
+
def validation_step(self, batch: Any, batch_idx: int):
|
115 |
+
loss, preds, targets = self.step(batch)
|
116 |
+
|
117 |
+
# log val metrics
|
118 |
+
mae = self.val_mae(preds, targets)
|
119 |
+
self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
|
120 |
+
self.log("val/mae", mae, on_step=False, on_epoch=True, prog_bar=True)
|
121 |
+
|
122 |
+
return {"loss": loss, "preds": preds, "targets": targets}
|
123 |
+
|
124 |
+
def validation_epoch_end(self, outputs: List[Any]):
|
125 |
+
mae = self.val_mae.compute() # get val accuracy from current epoch
|
126 |
+
self.val_mae_best.update(mae)
|
127 |
+
self.log(
|
128 |
+
"val/mae_best", self.val_mae_best.compute(), on_epoch=True, prog_bar=True
|
129 |
+
)
|
130 |
+
|
131 |
+
def test_step(self, batch: Any, batch_idx: int):
|
132 |
+
loss, preds, targets = self.step(batch)
|
133 |
+
|
134 |
+
# log test metrics
|
135 |
+
mae = self.test_mae(preds, targets)
|
136 |
+
self.log("test/loss", loss, on_step=False, on_epoch=True)
|
137 |
+
self.log("test/mae", mae, on_step=False, on_epoch=True)
|
138 |
+
|
139 |
+
def test_epoch_end(self, outputs: List[Any]):
|
140 |
+
print(outputs)
|
141 |
+
pass
|
142 |
+
|
143 |
+
def on_epoch_end(self):
|
144 |
+
# reset metrics at the end of every epoch
|
145 |
+
self.train_mae.reset()
|
146 |
+
self.test_mae.reset()
|
147 |
+
self.val_mae.reset()
|
148 |
+
|
149 |
+
def configure_optimizers(self):
|
150 |
+
"""Choose what optimizers and learning-rate schedulers.
|
151 |
+
|
152 |
+
Normally you'd need one. But in the case of GANs or similar you might
|
153 |
+
have multiple.
|
154 |
+
|
155 |
+
See examples here:
|
156 |
+
https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers
|
157 |
+
"""
|
158 |
+
return torch.optim.Adam(
|
159 |
+
params=self.parameters(),
|
160 |
+
lr=self.hparams.lr,
|
161 |
+
weight_decay=self.hparams.weight_decay,
|
162 |
+
)
|