DmitriiKhizbullin committed
Commit 699342a
1 Parent(s): 8751a91

Initial implementation of trainer and gradio app

Files changed (11)
  1. .gitattributes +3 -0
  2. .gitignore +6 -156
  3. app.py +123 -0
  4. labelmap.py +7 -0
  5. mypy.ini +3 -0
  6. tag.sh +2 -0
  7. train.py +563 -0
  8. train_a100_x1.sh +27 -0
  9. train_a100_x4.sh +27 -0
  10. train_v100_x1.sh +27 -0
  11. train_v100_x4.sh +27 -0
.gitattributes ADDED
@@ -0,0 +1,3 @@
+ media/* filter=lfs diff=lfs merge=lfs -text
+ release_ckpts/** filter=lfs diff=lfs merge=lfs -text
+ demo_data/** filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -1,160 +1,10 @@
- # Byte-compiled / optimized / DLL files
- __pycache__/
- *.py[cod]
- *$py.class
-
- # C extensions
- *.so
-
- # Distribution / packaging
- .Python
- build/
- develop-eggs/
- dist/
- downloads/
- eggs/
- .eggs/
- lib/
- lib64/
- parts/
- sdist/
- var/
- wheels/
- share/python-wheels/
- *.egg-info/
- .installed.cfg
- *.egg
- MANIFEST
-
- # PyInstaller
- # Usually these files are written by a python script from a template
- # before PyInstaller builds the exe, so as to inject date/other infos into it.
- *.manifest
- *.spec
-
- # Installer logs
- pip-log.txt
- pip-delete-this-directory.txt
-
- # Unit test / coverage reports
- htmlcov/
- .tox/
- .nox/
- .coverage
- .coverage.*
- .cache
- nosetests.xml
- coverage.xml
- *.cover
- *.py,cover
- .hypothesis/
- .pytest_cache/
- cover/
-
- # Translations
- *.mo
- *.pot
-
- # Django stuff:
- *.log
- local_settings.py
- db.sqlite3
- db.sqlite3-journal
-
- # Flask stuff:
- instance/
- .webassets-cache
-
- # Scrapy stuff:
- .scrapy
-
- # Sphinx documentation
- docs/_build/
-
- # PyBuilder
- .pybuilder/
- target/
-
- # Jupyter Notebook
- .ipynb_checkpoints
-
- # IPython
- profile_default/
- ipython_config.py
-
- # pyenv
- # For a library or package, you might want to ignore these files since the code is
- # intended to run in multiple environments; otherwise, check them in:
- # .python-version
-
- # pipenv
- # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
- # However, in case of collaboration, if having platform-specific dependencies or dependencies
- # having no cross-platform support, pipenv may install dependencies that don't work, or not
- # install all needed dependencies.
- #Pipfile.lock
-
- # poetry
- # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
- # This is especially recommended for binary packages to ensure reproducibility, and is more
- # commonly ignored for libraries.
- # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
- #poetry.lock
-
- # pdm
- # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
- #pdm.lock
- # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
- # in version control.
- # https://pdm.fming.dev/#use-with-ide
- .pdm.toml
-
- # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
- __pypackages__/
-
- # Celery stuff
- celerybeat-schedule
- celerybeat.pid
-
- # SageMath parsed files
- *.sage.py
-
- # Environments
- .env
- .venv
- env/
- venv/
- ENV/
- env.bak/
- venv.bak/
-
- # Spyder project settings
- .spyderproject
- .spyproject
-
- # Rope project settings
- .ropeproject
-
- # mkdocs documentation
- /site
-
- # mypy
- .mypy_cache/
- .dmypy.json
- dmypy.json
-
- # Pyre type checker
- .pyre/
-
- # pytype static type analyzer
- .pytype/
-
- # Cython debug symbols
- cython_debug/
-
- # PyCharm
- # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
- # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
- # and can be added to the global gitignore or merged into this file. For a more nuclear
- # option (not recommended) you can uncomment the following to ignore the entire idea folder.
- #.idea/
+ .vscode/
+
+ /retinopathy_data
+ /slurm_logs
+
+ /lightning_logs/
+
+ __pycache__
+
+ /dcgm/
app.py ADDED
@@ -0,0 +1,123 @@
+ import os
+ import gradio as gr
+ import numpy as np
+ import torch
+ from typing import Tuple, Optional, Dict, List
+ import glob
+ from collections import defaultdict
+
+ from transformers import (AutoImageProcessor,
+                           ResNetForImageClassification)
+
+ from labelmap import DR_LABELMAP
+
+
+ class App:
+     def __init__(self) -> None:
+
+         ckpt_name = "2023-12-24_20-02-18_30345221_V100_x4_resnet34"
+
+         path = f"release_ckpts/{ckpt_name}/inference/"
+
+         self.image_processor = AutoImageProcessor.from_pretrained(path)
+
+         self.model = ResNetForImageClassification.from_pretrained(path)
+
+         example_lists = self._load_example_lists()
+
+         device = 'GPU' if torch.cuda.is_available() else 'CPU'
+
+         css = ".output-image, .input-image, .image-preview {height: 600px !important}"
+
+         with gr.Blocks(css=css) as ui:
+             with gr.Row():
+                 with gr.Column(scale=1):
+                     with gr.Row():
+                         predict_btn = gr.Button("Predict", size="lg")
+                     with gr.Row():
+                         gr.Markdown(f"Running on {device}")
+                 with gr.Column(scale=4):
+                     # output = gr.Textbox(label="Retinopathy level prediction")
+                     output = gr.Label(num_top_classes=len(DR_LABELMAP),
+                                       label="Retinopathy level prediction")
+                 with gr.Column(scale=4):
+                     gr.Markdown("![](https://media.githubusercontent.com/media/Obs01ete/retinopathy/master/media/logo1.png)")
+             with gr.Row():
+                 with gr.Column(scale=9, min_width=100):
+                     image = gr.Image(label="Retina scan")
+                 with gr.Column(scale=1, min_width=150):
+                     for cls_id in range(len(example_lists)):
+                         label = DR_LABELMAP[cls_id]
+                         with gr.Tab(f"{cls_id} : {label}"):
+                             gr.Examples(
+                                 example_lists[cls_id],
+                                 inputs=[image],
+                                 outputs=[output],
+                                 fn=self.predict,
+                                 examples_per_page=10,
+                                 run_on_click=True)
+
+             predict_btn.click(
+                 fn=self.predict,
+                 inputs=image,
+                 outputs=output,
+                 api_name="predict")
+
+         self.ui = ui
+
+     def launch(self) -> None:
+         self.ui.queue().launch(share=True)
+
+     def predict(self, image: Optional[np.ndarray]) -> Dict[str, float]:
+         if image is None:
+             return dict()
+         cls_name, prob, probs = self._infer(image)
+         message = f"Predicted class={cls_name}, prob={prob:.3f}"
+         print(message)
+         probs_dict = {f"{i} - {DR_LABELMAP[i]}": float(v)
+                       for i, v in enumerate(probs)}
+         return probs_dict
+
+     def _infer(self, image_chw: np.ndarray) -> Tuple[str, float, np.ndarray]:
+         assert isinstance(self.model, ResNetForImageClassification)
+
+         inputs = self.image_processor(image_chw, return_tensors="pt")
+
+         with torch.no_grad():
+             output = self.model(**inputs)
+
+         logits_batch = output.logits
+         assert len(logits_batch.shape) == 2
+         assert logits_batch.shape[0] == 1
+         logits = logits_batch[0]
+         probs = torch.softmax(logits, dim=-1)
+         predicted_label = int(probs.argmax(-1).item())
+         prob = probs[predicted_label].item()
+         cls_name = self.model.config.id2label[predicted_label]
+         return cls_name, prob, probs.numpy()
+
+     @staticmethod
+     def _load_example_lists() -> Dict[int, List[str]]:
+
+         example_flat_list = glob.glob("demo_data/train/**/*.jpeg")
+
+         example_lists: Dict[int, List[str]] = defaultdict(list)
+         for path in example_flat_list:
+             dir_path, _ = os.path.split(path)
+             _, subdir = os.path.split(dir_path)
+             try:
+                 cls_id = int(subdir)
+             except ValueError:
+                 print(f"Cannot parse path {path}")
+                 continue
+             example_lists[cls_id].append(path)
+         return example_lists
+
+
+ def main():
+     app = App()
+     app.launch()
+
+
+ if __name__ == "__main__":
+     main()
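
Because the Predict button is wired with api_name="predict", the running app also exposes a programmatic endpoint. A minimal client-side sketch, assuming the app is serving locally on the default port and that a sample image exists at the path shown (both the URL and the file path are assumptions, not part of this commit; recent gradio_client versions additionally require wrapping file paths with handle_file):

# Hypothetical client for the /predict endpoint exposed by app.py.
from gradio_client import Client

client = Client("http://127.0.0.1:7860")  # assumed local URL
result = client.predict(
    "demo_data/train/0/sample.jpeg",  # hypothetical example image path
    api_name="/predict")
print(result)  # per-class confidences as returned by App.predict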
labelmap.py ADDED
@@ -0,0 +1,7 @@
+ DR_LABELMAP = {
+     0: 'No DR',
+     1: 'Mild',
+     2: 'Moderate',
+     3: 'Severe',
+     4: 'Proliferative DR',
+ }
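
train.py inverts this mapping to build the label2id dictionary that the Hugging Face model config expects alongside id2label. A quick illustration of the round trip:

# Round trip between class ids and human-readable labels.
from labelmap import DR_LABELMAP

label2id = {label: id for id, label in DR_LABELMAP.items()}
assert label2id['Mild'] == 1
assert DR_LABELMAP[label2id['Proliferative DR']] == 'Proliferative DR'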
mypy.ini ADDED
@@ -0,0 +1,3 @@
+ [mypy]
+ ignore_missing_imports = True
+ check_untyped_defs = True
tag.sh ADDED
@@ -0,0 +1,2 @@
+ TAG=stratval_24h
+ EXTRA_KEY=
train.py ADDED
@@ -0,0 +1,563 @@
+ import os
+ from typing import (Any, List, Dict, Optional, Tuple,
+                     Union, Callable, Iterable, Iterator)
+ import pandas as pd
+ from PIL import Image
+ import datetime
+ from argparse import ArgumentParser
+ from enum import Enum
+ import numpy as np
+ from numpy.random import RandomState
+ import collections.abc
+ from collections import Counter, defaultdict
+ import math
+
+ import torch
+ import torch.nn as nn
+ import torch.utils.data as data
+ from torch.utils.data import DataLoader
+
+ from torchvision.transforms import (
+     CenterCrop,
+     Compose,
+     Normalize,
+     RandomHorizontalFlip,
+     RandomResizedCrop,
+     RandomRotation,
+     RandomAffine,
+     Resize,
+     ToTensor)
+
+ from transformers import ViTImageProcessor
+ from transformers import ViTForImageClassification
+ from transformers import AdamW
+
+ from transformers import AutoImageProcessor, ResNetForImageClassification
+
+ import lightning as L
+ from lightning import Trainer
+ from lightning.pytorch.loggers import TensorBoardLogger
+ from lightning.pytorch.callbacks import ModelSummary
+ from torchmetrics.aggregation import MeanMetric
+ from torchmetrics.classification.accuracy import MulticlassAccuracy
+ from torchmetrics.classification import MulticlassCohenKappa
+
+ from labelmap import DR_LABELMAP
+
+
+ DataRecord = Tuple[Image.Image, int]
+
+
+ class RetinopathyDataset(data.Dataset[DataRecord]):
+     def __init__(self, data_path: str) -> None:
+         super().__init__()
+
+         self.data_path = data_path
+
+         self.ext = ".jpeg"
+
+         anno_path = os.path.join(data_path, "trainLabels.csv")
+         self.anno_df = pd.read_csv(anno_path)  # ['image', 'level']
+         anno_name_set = set(self.anno_df['image'])
+
+         if True:  # sanity check: annotation rows match image files on disk
+             train_path = os.path.join(data_path, "train")
+             img_path_list = os.listdir(train_path)
+             img_name_set = set([os.path.splitext(p)[0] for p in img_path_list])
+             assert anno_name_set == img_name_set
+
+         self.label_map = DR_LABELMAP
+
+     def __getitem__(self, index: Union[int, slice]) -> DataRecord:
+         assert isinstance(index, int)
+         img_path = self.get_path_at(index)
+         img = Image.open(img_path)
+         label = self.get_label_at(index)
+         return img, label
+
+     def __len__(self) -> int:
+         return len(self.anno_df)
+
+     def get_label_at(self, index: int) -> int:
+         label = self.anno_df['level'].iloc[index].item()
+         return label
+
+     def get_path_at(self, index: int) -> str:
+         img_name = self.anno_df['image'].iloc[index]
+         img_path = os.path.join(self.data_path, "train", img_name + self.ext)
+         return img_path
+
+
+ class Purpose(Enum):
+     Train = 0
+     Val = 1
+
+
+ FeatureAndTargetTransforms = Tuple[Callable[..., torch.Tensor],
+                                    Callable[..., torch.Tensor]]
+
+ TensorRecord = Tuple[torch.Tensor, torch.Tensor]
+
+ def normalize(arr: np.ndarray) -> np.ndarray:
+     return arr / np.sum(arr)
+
+
+ class Split(data.Dataset[TensorRecord], collections.abc.Sequence[TensorRecord]):
+     def __init__(self, dataset: RetinopathyDataset,
+                  indices: np.ndarray,
+                  purpose: Purpose,
+                  transforms: FeatureAndTargetTransforms,
+                  oversample_factor: int = 1,
+                  stratify_classes: bool = False,
+                  use_log_frequencies: bool = False,
+                  ):
+
+         self.dataset = dataset
+         self.indices = indices
+         self.purpose = purpose
+         self.feature_transform = transforms[0]
+         self.target_transform = transforms[1]
+         self.oversample_factor = oversample_factor
+         self.stratify_classes = stratify_classes
+         self.use_log_frequencies = use_log_frequencies
+
+         self.per_class_indices: Optional[Dict[int, np.ndarray]] = None
+         self.frequencies: Optional[Dict[int, float]] = None
+         if self.stratify_classes:
+             self.bucketize_indices()
+             if self.use_log_frequencies:
+                 self.calc_frequencies()
+
+     def calc_frequencies(self):
+         assert self.per_class_indices is not None
+         counts_dict = {lbl: len(arr) for lbl, arr in self.per_class_indices.items()}
+         counts = np.array(list(counts_dict.values()))
+         counts_nrm = normalize(counts)
+         temperature = 50.0  # > 1 to even out frequencies
+         freqs = normalize(np.log1p(counts_nrm * temperature))
+         self.frequencies = {k: freq.item() for k, freq
+                             in zip(self.per_class_indices.keys(), freqs)}
+         print(self.frequencies)
+
+     def bucketize_indices(self):
+         buckets = defaultdict(list)
+         for index in self.indices:
+             label = self.dataset.get_label_at(index)
+             buckets[label].append(index)
+         self.per_class_indices = {k: np.array(v)
+                                   for k, v in buckets.items()}
+
+     def __getitem__(self, index: Union[int, slice]) -> TensorRecord:  # type: ignore[override]
+         assert isinstance(index, int)
+         if self.purpose == Purpose.Train:
+             index_rem = index % len(self.indices)
+             idx = self.indices[index_rem].item()
+         else:
+             idx = self.indices[index].item()
+         if self.per_class_indices:
+             if self.frequencies is not None:
+                 arange = np.arange(len(self.per_class_indices))
+                 frequencies = np.zeros(len(self.per_class_indices), dtype=float)
+                 for k, v in self.frequencies.items():
+                     frequencies[k] = v
+                 random_key = np.random.choice(
+                     arange,
+                     p=frequencies)
+             else:
+                 random_key = np.random.randint(len(self.per_class_indices))
+
+             indices = self.per_class_indices[random_key]
+             actual_index = np.random.choice(indices).item()
+         else:
+             actual_index = idx
+         feature, target = self.dataset[actual_index]
+         feature_tensor = self.feature_transform(feature)
+         target_tensor = self.target_transform(target)
+         return feature_tensor, target_tensor
+
+     def __len__(self):
+         if self.purpose == Purpose.Train:
+             return len(self.indices) * self.oversample_factor
+         else:
+             return len(self.indices)
+
+     @staticmethod
+     def make_splits(all_data: RetinopathyDataset,
+                     train_transforms: FeatureAndTargetTransforms,
+                     val_transforms: FeatureAndTargetTransforms,
+                     train_fraction: float,
+                     stratify_train: bool,
+                     stratify_val: bool,
+                     seed: int = 54,
+                     ) -> Tuple['Split', 'Split']:
+
+         prng = RandomState(seed)
+
+         num_train = int(len(all_data) * train_fraction)
+         all_indices = prng.permutation(len(all_data))
+         train_indices = all_indices[:num_train]
+         val_indices = all_indices[num_train:]
+         train_data = Split(all_data, train_indices, Purpose.Train,
+                            train_transforms, stratify_classes=stratify_train)
+         val_data = Split(all_data, val_indices, Purpose.Val,
+                          val_transforms, stratify_classes=stratify_val)
+         return train_data, val_data
+
+
+ def print_data_stats(dataset: Union[Iterable[DataRecord], DataLoader], split_name: str) -> None:
+     labels = []
+     for _, label in dataset:
+         if isinstance(label, torch.Tensor):
+             label = label.cpu().numpy()
+         labels.append(np.atleast_1d(label))  # atleast_1d also handles plain int labels
+     all_labels = np.concatenate(labels)
+     cnt = Counter(all_labels.tolist())
+     print(split_name, cnt)
+
+
+ class Metrics:
+     def __init__(self,
+                  num_classes: int,
+                  labelmap: Dict[int, str],
+                  split: str,
+                  log_fn: Callable[..., None]) -> None:
+         self.labelmap = labelmap
+         self.loss = MeanMetric(nan_strategy='ignore')
+         self.accuracy = MulticlassAccuracy(num_classes=num_classes)
+         self.per_class_accuracies = MulticlassAccuracy(
+             num_classes=num_classes, average=None)
+         self.kappa = MulticlassCohenKappa(num_classes)
+         self.split = split
+         self.log_fn = log_fn
+
+     def update(self,
+                loss: torch.Tensor,
+                preds: torch.Tensor,
+                labels: torch.Tensor) -> None:
+         self.loss.update(loss)
+         self.accuracy.update(preds, labels)
+         self.per_class_accuracies.update(preds, labels)
+         self.kappa.update(preds, labels)
+
+     def log(self) -> None:
+         loss = self.loss.compute()
+         accuracy = self.accuracy.compute()
+         accuracies = self.per_class_accuracies.compute()
+         kappa = self.kappa.compute()
+         mean_accuracy = torch.nanmean(accuracies)
+         self.log_fn(f"{self.split}/loss", loss, sync_dist=True)
+         self.log_fn(f"{self.split}/accuracy", accuracy, sync_dist=True)
+         self.log_fn(f"{self.split}/mean_accuracy", mean_accuracy, sync_dist=True)
+         for i_class, acc in enumerate(accuracies):
+             name = self.labelmap[i_class]
+             self.log_fn(f"{self.split}/acc/{i_class} {name}", acc, sync_dist=True)
+         self.log_fn(f"{self.split}/kappa", kappa, sync_dist=True)
+
+     def to(self, device) -> 'Metrics':
+         self.loss.to(device)  # torchmetrics metrics are nn.Modules: .to() moves state in place, no reassignment needed
+         self.accuracy.to(device)
+         self.per_class_accuracies.to(device)
+         self.kappa.to(device)
+         return self
+
+
+ def worker_init_fn(worker_id):
+     state = np.random.get_state()
+     assert isinstance(state, tuple)
+     assert isinstance(state[1], np.ndarray)
+     seed_arr = state[1]
+     seed_np = seed_arr[0] + worker_id
+     np.random.seed(seed_np)
+     seed_pt = seed_np + 1111
+     torch.manual_seed(seed_pt)
+     print(f"Setting numpy seed to {seed_np} and pytorch seed to {seed_pt} in worker {worker_id}")
+
+
+ class ViTLightningModule(L.LightningModule):
+     def __init__(self, debug: bool) -> None:
+         super().__init__()
+
+         self.save_hyperparameters()
+
+         np.random.seed(53)
+
+         # pretrained_name = 'google/vit-base-patch16-224-in21k'
+         # pretrained_name = 'google/vit-base-patch16-384-in21k'
+
+         # pretrained_name = "microsoft/resnet-50"
+         pretrained_name = "microsoft/resnet-34"
+
+         # processor = ViTImageProcessor.from_pretrained(pretrained_name)
+         processor = AutoImageProcessor.from_pretrained(pretrained_name)
+
+         image_mean = processor.image_mean  # type: ignore
+         image_std = processor.image_std  # type: ignore
+         # size = processor.size["height"]  # type: ignore
+         # size = processor.size["shortest_edge"]  # type: ignore
+         size = 896  # 448
+
+         normalize = Normalize(mean=image_mean, std=image_std)
+         train_transforms = Compose(
+             [
+                 # RandomRotation((-180, 180)),
+                 RandomAffine((-180, 180), shear=10),
+                 RandomResizedCrop(size, scale=(0.5, 1.0)),
+                 RandomHorizontalFlip(),
+                 ToTensor(),
+                 normalize,
+             ]
+         )
+         val_transforms = Compose(
+             [
+                 Resize(size),
+                 CenterCrop(size),
+                 ToTensor(),
+                 normalize,
+             ]
+         )
+
+         self.dataset = RetinopathyDataset("retinopathy_data")
+
+         # print_data_stats(self.dataset, "all_data")
+
+         train_data, val_data = Split.make_splits(
+             self.dataset,
+             train_transforms=(train_transforms, torch.tensor),
+             val_transforms=(val_transforms, torch.tensor),
+             train_fraction=0.9,
+             stratify_train=True,
+             stratify_val=True,
+         )
+
+         assert len(set(train_data.indices).intersection(set(val_data.indices))) == 0
+
+         label2id = {label: id for id, label in self.dataset.label_map.items()}
+
+         num_classes = len(self.dataset.label_map)
+         labelmap = self.dataset.label_map
+         assert len(labelmap) == num_classes
+         assert set(labelmap.keys()) == set(range(num_classes))
+
+         train_batch_size = 4 if debug else 20
+         val_batch_size = 4 if debug else 20
+
+         num_gpus = torch.cuda.device_count()
+         print(f"{num_gpus=}")
+
+         num_cores = torch.get_num_threads()
+         print(f"{num_cores=}")
+
+         num_threads_per_gpu = max(1, int(math.ceil(num_cores / num_gpus))) \
+             if num_gpus > 0 else 1
+
+         num_workers = 1 if debug else num_threads_per_gpu
+         print(f"{num_workers=}")
+
+         self._train_dataloader = DataLoader(
+             train_data,
+             shuffle=True,
+             num_workers=num_workers,
+             persistent_workers=num_workers > 0,
+             pin_memory=True,
+             batch_size=train_batch_size,
+             worker_init_fn=worker_init_fn,
+         )
+         self._val_dataloader = DataLoader(
+             val_data,
+             shuffle=False,
+             num_workers=num_workers,
+             persistent_workers=num_workers > 0,
+             pin_memory=True,
+             batch_size=val_batch_size,
+         )
+
+         # print_data_stats(self._val_dataloader, "val")
+         # print_data_stats(self._train_dataloader, "train")
+
+         img_batch, label_batch = next(iter(self._train_dataloader))
+         assert isinstance(img_batch, torch.Tensor)
+         assert isinstance(label_batch, torch.Tensor)
+         print(f"{img_batch.shape=} {label_batch.shape=}")
+
+         assert img_batch.shape == (train_batch_size, 3, size, size)
+         assert label_batch.shape == (train_batch_size,)
+
+         self.example_input_array = torch.randn_like(img_batch)
+
+         # self._model = ViTForImageClassification.from_pretrained(
+         #     pretrained_name,
+         #     num_labels=len(self.dataset.label_map),
+         #     id2label=self.dataset.label_map,
+         #     label2id=label2id)
+
+         self._model = ResNetForImageClassification.from_pretrained(
+             pretrained_name,
+             num_labels=len(self.dataset.label_map),
+             id2label=self.dataset.label_map,
+             label2id=label2id,
+             ignore_mismatched_sizes=True)
+
+         assert isinstance(self._model, nn.Module)
+
+         self.train_metrics: Optional[Metrics] = None
+         self.val_metrics: Optional[Metrics] = None
+
+     @property
+     def num_classes(self):
+         return len(self.dataset.label_map)
+
+     @property
+     def labelmap(self):
+         return self.dataset.label_map
+
+     def forward(self, img_batch):
+         outputs = self._model(img_batch)  # type: ignore
+         return outputs.logits
+
+     def common_step(self, batch, batch_idx):
+         img_batch, label_batch = batch
+
+         logits = self(img_batch)
+
+         criterion = nn.CrossEntropyLoss()
+         loss = criterion(logits, label_batch)
+         preds_batch = logits.argmax(-1)
+
+         return loss, preds_batch, label_batch
+
+     def on_train_epoch_start(self) -> None:
+         self.train_metrics = Metrics(
+             self.num_classes,
+             self.labelmap,
+             "train",
+             self.log).to(self.device)
+
+     def training_step(self, batch, batch_idx):
+         loss, preds, labels = self.common_step(batch, batch_idx)
+         assert self.train_metrics is not None
+         self.train_metrics.update(loss, preds, labels)
+
+         if False and batch_idx == 0:  # flip to True to dump augmented inputs
+             self._dump_train_images()
+
+         return loss
+
+     def _dump_train_images(self) -> None:
+         img_batch, label_batch = next(iter(self._train_dataloader))
+         for i_img, (img, label) in enumerate(zip(img_batch, label_batch)):
+             img_np = img.cpu().numpy()
+             denorm_np = (img_np - img_np.min()) / (img_np.max() - img_np.min())
+             img_uint8 = (255 * denorm_np).astype(np.uint8)
+             pil_img = Image.fromarray(np.transpose(img_uint8, (1, 2, 0)))
+             if self.logger is not None and self.logger.log_dir is not None:
+                 assert isinstance(self.logger.log_dir, str)
+                 os.makedirs(self.logger.log_dir, exist_ok=True)
+                 path = os.path.join(self.logger.log_dir,
+                                     f"img_{i_img:02d}_{label.item()}.png")
+                 pil_img.save(path)
+
+     def on_train_epoch_end(self) -> None:
+         assert self.train_metrics is not None
+         self.train_metrics.log()
+         assert self.logger is not None
+         if self.logger.log_dir is not None:
+             path = os.path.join(self.logger.log_dir, "inference")
+             self.save_checkpoint_dk(path)
+
+     def save_checkpoint_dk(self, dirpath: str) -> None:
+         if self.global_rank == 0:
+             self._model.save_pretrained(dirpath)
+
+     def validation_step(self, batch, batch_idx):
+         loss, preds, labels = self.common_step(batch, batch_idx)
+         assert self.val_metrics is not None
+         self.val_metrics.update(loss, preds, labels)
+         return loss
+
+     def on_validation_epoch_start(self) -> None:
+         self.val_metrics = Metrics(
+             self.num_classes,
+             self.labelmap,
+             "val",
+             self.log).to(self.device)
+
+     def on_validation_epoch_end(self) -> None:
+         assert self.val_metrics is not None
+         self.val_metrics.log()
+
+     def configure_optimizers(self):
+         # No WD is the same as 1e-3 and better than 1e-2
+         # LR 1e-3 is worse than 1e-4 (without LR scheduler)
+         return AdamW(self.parameters(),
+                      lr=1e-4,
+                      )
+
+
+ def main():
+
+     parser = ArgumentParser(description='KAUST-SDAIA Diabetic Retinopathy')
+     parser.add_argument('--tag', action='store', type=str,
+                         help='Extra suffix to put on the artefact dir name')
+     parser.add_argument('--debug', action='store_true')
+     parser.add_argument('--convert-checkpoint', action='store', type=str,
+                         help='Convert a checkpoint from training to pickle-independent '
+                              'predictor-compatible directory')
+
+     args = parser.parse_args()
+
+
+     torch.set_float32_matmul_precision('high')  # for V100/A100
+
+     if args.convert_checkpoint is not None:
+
+         print("Converting checkpoint", args.convert_checkpoint)
+
+         checkpoint = torch.load(args.convert_checkpoint, map_location="cpu")
+         print(list(checkpoint.keys()))
+
+         model = ViTLightningModule.load_from_checkpoint(
+             args.convert_checkpoint,
+             map_location="cpu",
+             hparams_file="tmp_ckpt_deleteme.yaml")
+
+         model.save_checkpoint_dk("tmp_checkp_path_deleteme")
+
+         print("Saved checkpoint. Done.")
+
+     else:
+
+         print("Start training")
+
+         fast_dev_run = args.debug
+
+         model = ViTLightningModule(fast_dev_run)
+
+         datetime_str = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+         art_dir_name = (f"{datetime_str}" +
+                         (f"_{args.tag}" if args.tag is not None else ""))
+         logger = TensorBoardLogger(save_dir=".", name="lightning_logs", version=art_dir_name)
+
+         trainer = Trainer(
+             logger=logger,
+             benchmark=True,
+             devices="auto",
+             accelerator="auto",
+             max_epochs=-1,
+             callbacks=[
+                 ModelSummary(max_depth=-1),
+             ],
+             fast_dev_run=fast_dev_run,
+             log_every_n_steps=10,
+         )
+
+         trainer.fit(
+             model,
+             train_dataloaders=model._train_dataloader,
+             val_dataloaders=model._val_dataloader,
+         )
+
+         print("Training done")
+
+
+ if __name__ == "__main__":
+     main()
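
The use_log_frequencies path in Split.calc_frequencies is the heart of the class-balancing sampler: raw class counts are normalized, scaled by a temperature, and squashed with log1p, which compresses the dominant class far more than the rare ones. A standalone sketch of the effect, using made-up class counts (not the real dataset statistics):

import numpy as np

def normalize(arr: np.ndarray) -> np.ndarray:
    return arr / np.sum(arr)

# Hypothetical per-class counts for a skewed 5-class dataset.
counts = np.array([25000, 2400, 5300, 870, 700])

counts_nrm = normalize(counts)   # class 0 dominates at ~0.73
temperature = 50.0               # > 1 to even out frequencies
freqs = normalize(np.log1p(counts_nrm * temperature))

print(freqs.round(3))  # roughly [0.41 0.17 0.25 0.09 0.08]: far flatter

These smoothed frequencies are then used as the p argument of np.random.choice in Split.__getitem__ when picking which class bucket to sample from.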
train_a100_x1.sh ADDED
@@ -0,0 +1,27 @@
+ #!/bin/bash
+ sbatch <<EOT
+ #!/bin/bash
+
+ #SBATCH -N 1
+ #SBATCH -J SDAIA_DR
+ #SBATCH -o slurm_logs/output.%J.out
+ #SBATCH -e slurm_logs/output.%J.err
+ #SBATCH --mail-user=${USER}@kaust.edu.sa
+ #SBATCH --mail-type=END,FAIL
+ #SBATCH --time=24:00:00
+ #SBATCH --mem=100G
+ #SBATCH --ntasks=1
+ #SBATCH --gres=gpu:a100:1
+ #SBATCH --cpus-per-task=16
+
+ module purge
+ source activate retinopathy
+
+ echo Running for user "${USER}"
+
+ . ./tag.sh
+
+ PYTHONPATH=. python train.py --tag=\${SLURM_JOB_ID}_A100_x1_\${TAG}
+
+ exit 0
+ EOT
train_a100_x4.sh ADDED
@@ -0,0 +1,27 @@
+ #!/bin/bash
+ sbatch <<EOT
+ #!/bin/bash
+
+ #SBATCH -N 1
+ #SBATCH -J SDAIA_DR
+ #SBATCH -o slurm_logs/output.%J.out
+ #SBATCH -e slurm_logs/output.%J.err
+ #SBATCH --mail-user=${USER}@kaust.edu.sa
+ #SBATCH --mail-type=END,FAIL
+ #SBATCH --time=24:00:00
+ #SBATCH --mem=200G
+ #SBATCH --ntasks=1
+ #SBATCH --gres=gpu:a100:4
+ #SBATCH --cpus-per-task=64
+
+ module purge
+ source activate retinopathy
+
+ echo Running for user "${USER}"
+
+ . ./tag.sh
+
+ PYTHONPATH=. python train.py --tag=\${SLURM_JOB_ID}_A100_x4_\${TAG}
+
+ exit 0
+ EOT
train_v100_x1.sh ADDED
@@ -0,0 +1,27 @@
+ #!/bin/bash
+ sbatch <<EOT
+ #!/bin/bash
+
+ #SBATCH -N 1
+ #SBATCH -J SDAIA_DR
+ #SBATCH -o slurm_logs/output.%J.out
+ #SBATCH -e slurm_logs/output.%J.err
+ #SBATCH --mail-user=${USER}@kaust.edu.sa
+ #SBATCH --mail-type=END,FAIL
+ #SBATCH --time=24:00:00
+ #SBATCH --mem=100G
+ #SBATCH --ntasks=1
+ #SBATCH --gres=gpu:v100:1
+ #SBATCH --cpus-per-task=10
+
+ module purge
+ source activate retinopathy
+
+ echo Running for user "${USER}"
+
+ . ./tag.sh
+
+ PYTHONPATH=. python train.py --tag=\${SLURM_JOB_ID}_V100_x1_\${TAG}
+
+ exit 0
+ EOT
train_v100_x4.sh ADDED
@@ -0,0 +1,27 @@
+ #!/bin/bash
+ sbatch <<EOT
+ #!/bin/bash
+
+ #SBATCH -N 1
+ #SBATCH -J SDAIA_DR
+ #SBATCH -o slurm_logs/output.%J.out
+ #SBATCH -e slurm_logs/output.%J.err
+ #SBATCH --mail-user=${USER}@kaust.edu.sa
+ #SBATCH --mail-type=END,FAIL
+ #SBATCH --time=24:00:00
+ #SBATCH --mem=200G
+ #SBATCH --ntasks=1
+ #SBATCH --gres=gpu:v100:4
+ #SBATCH --cpus-per-task=40
+
+ module purge
+ source activate retinopathy
+
+ echo Running for user "${USER}"
+
+ . ./tag.sh
+
+ PYTHONPATH=. python train.py --tag=\${SLURM_JOB_ID}_V100_x4_\${TAG}
+
+ exit 0
+ EOT