DmitriiKhizbullin commited on
Commit
f036ad4
1 Parent(s): 9d09bdd

Partial docstrings

Browse files
Files changed (6) hide show
  1. README.md +26 -4
  2. app.py +34 -2
  3. environment.yml +0 -2
  4. labelmap.py +2 -0
  5. requirements.txt +2 -0
  6. train.py +69 -15
README.md CHANGED
@@ -1,8 +1,30 @@
1
- # diabetic-retinopathy-detection
2
 
3
- # Installataion
 
 
 
 
 
 
4
 
5
  Create conda environment from YAML:
6
  ```bash
7
- mamba env create -n retinopathy_restore -f environment.yml
8
- ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Diabetic Retinopathy Detection with AI
2
 
3
+ ## Setup
4
+
5
+ ### Gradio app environment
6
+
7
+ TODO
8
+
9
+ ### Training environment
10
 
11
  Create conda environment from YAML:
12
  ```bash
13
+ mamba env create -n retinopathy_train -f environment.yml
14
+ ```
15
+
16
+ Download the data from [Kaggle](https://www.kaggle.com/competitions/diabetic-retinopathy-detection/data) or use kaggle API:
17
+
18
+ ```bash
19
+ pip install kaggle
20
+ kaggle competitions download -c diabetic-retinopathy-detection
21
+ mkdir retinopathy_data/
22
+ unzip diabetic-retinopathy-detection.zip -d retinopathy_data/
23
+ ```
24
+
25
+ Launch training:
26
+ ```bash
27
+ conda activate retinopathy_train
28
+ python train.py
29
+ ```
30
+ The trained model will be put into `lightning_logs/`.
app.py CHANGED
@@ -2,7 +2,7 @@ import os
2
  import gradio as gr
3
  import numpy as np
4
  import torch
5
- from typing import Tuple, Optional, Dict, List
6
  import glob
7
  from collections import defaultdict
8
 
@@ -13,7 +13,10 @@ from labelmap import DR_LABELMAP
13
 
14
 
15
  class App:
 
 
16
  def __init__(self) -> None:
 
17
 
18
  ckpt_name = "2023-12-24_20-02-18_30345221_V100_x4_resnet34/"
19
 
@@ -66,9 +69,19 @@ class App:
66
  self.ui = ui
67
 
68
  def launch(self) -> None:
 
69
  self.ui.queue().launch(share=True)
70
 
71
- def predict(self, image: Optional[np.ndarray]):
 
 
 
 
 
 
 
 
 
72
  if image is None:
73
  return dict()
74
  cls_name, prob, probs = self._infer(image)
@@ -79,6 +92,19 @@ class App:
79
  return probs_dict
80
 
81
  def _infer(self, image_chw: np.ndarray) -> Tuple[str, float, np.ndarray]:
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  assert isinstance(self.model, ResNetForImageClassification)
83
 
84
  inputs = self.image_processor(image_chw, return_tensors="pt")
@@ -98,6 +124,11 @@ class App:
98
 
99
  @staticmethod
100
  def _load_example_lists() -> Dict[int, List[str]]:
 
 
 
 
 
101
 
102
  example_flat_list = glob.glob("demo_data/train/**/*.jpeg")
103
 
@@ -115,6 +146,7 @@ class App:
115
 
116
 
117
  def main():
 
118
  app = App()
119
  app.launch()
120
 
 
2
  import gradio as gr
3
  import numpy as np
4
  import torch
5
+ from typing import Tuple, Optional, Dict, List, Dict
6
  import glob
7
  from collections import defaultdict
8
 
 
13
 
14
 
15
  class App:
16
+ """ Demonstration of the Diabetic Retinopathy model as a Gradio app. """
17
+
18
  def __init__(self) -> None:
19
+ """ Constructor. """
20
 
21
  ckpt_name = "2023-12-24_20-02-18_30345221_V100_x4_resnet34/"
22
 
 
69
  self.ui = ui
70
 
71
  def launch(self) -> None:
72
+ """ Launch the application, blocking. """
73
  self.ui.queue().launch(share=True)
74
 
75
+ def predict(self, image: Optional[np.ndarray]) -> Dict[str, float]:
76
+ """ Gradio callback for pricessing of an image.
77
+
78
+ Args:
79
+ image (Optional[np.ndarray]): Provided image.
80
+
81
+ Returns:
82
+ Dict[str, float]: Label-compatible dict.
83
+ """
84
+
85
  if image is None:
86
  return dict()
87
  cls_name, prob, probs = self._infer(image)
 
92
  return probs_dict
93
 
94
  def _infer(self, image_chw: np.ndarray) -> Tuple[str, float, np.ndarray]:
95
+ """ Low-level method to perform neural network inference.
96
+
97
+ Args:
98
+ image_chw (np.ndarray): Provided image.
99
+
100
+ Returns:
101
+ Tuple[str, float, np.ndarray]:
102
+ - Most probable class name
103
+ - Probability of the most probable class name.
104
+ - Probablilities of all classes in the order of
105
+ being listed in the label map.
106
+ """
107
+
108
  assert isinstance(self.model, ResNetForImageClassification)
109
 
110
  inputs = self.image_processor(image_chw, return_tensors="pt")
 
124
 
125
  @staticmethod
126
  def _load_example_lists() -> Dict[int, List[str]]:
127
+ """ Load example retina images from disk.
128
+
129
+ Returns:
130
+ Dict[int, List[str]]: Dictionary of cls_id -> list of images paths.
131
+ """
132
 
133
  example_flat_list = glob.glob("demo_data/train/**/*.jpeg")
134
 
 
146
 
147
 
148
  def main():
149
+ """ App entry point. """
150
  app = App()
151
  app.launch()
152
 
environment.yml CHANGED
@@ -69,7 +69,6 @@ dependencies:
69
  - parso=0.8.3=pyhd3eb1b0_0
70
  - pexpect=4.8.0=pyhd3eb1b0_3
71
  - pickleshare=0.7.5=pyhd3eb1b0_1003
72
- - pip=23.3.1=py310h06a4308_0
73
  - platformdirs=3.10.0=py310h06a4308_0
74
  - prometheus_client=0.14.1=py310h06a4308_0
75
  - prompt-toolkit=3.0.36=py310h06a4308_0
@@ -104,7 +103,6 @@ dependencies:
104
  - tornado=6.3.3=py310h5eee18b_0
105
  - webencodings=0.5.1=py310h06a4308_1
106
  - wheel=0.41.2=py310h06a4308_0
107
- - xz=5.4.5=h5eee18b_0
108
  - y-py=0.5.9=py310h52d8a92_0
109
  - yaml=0.2.5=h7b6447c_0
110
  - ypy-websocket=0.8.2=py310h06a4308_0
 
69
  - parso=0.8.3=pyhd3eb1b0_0
70
  - pexpect=4.8.0=pyhd3eb1b0_3
71
  - pickleshare=0.7.5=pyhd3eb1b0_1003
 
72
  - platformdirs=3.10.0=py310h06a4308_0
73
  - prometheus_client=0.14.1=py310h06a4308_0
74
  - prompt-toolkit=3.0.36=py310h06a4308_0
 
103
  - tornado=6.3.3=py310h5eee18b_0
104
  - webencodings=0.5.1=py310h06a4308_1
105
  - wheel=0.41.2=py310h06a4308_0
 
106
  - y-py=0.5.9=py310h52d8a92_0
107
  - yaml=0.2.5=h7b6447c_0
108
  - ypy-websocket=0.8.2=py310h06a4308_0
labelmap.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  DR_LABELMAP = {
2
  0: 'No DR',
3
  1: 'Mild',
 
1
+ """ Mapping of class IDs to lables. """
2
+
3
  DR_LABELMAP = {
4
  0: 'No DR',
5
  1: 'Mild',
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cpu
2
+ torch==2.1.2+cpu
train.py CHANGED
@@ -49,7 +49,15 @@ DataRecord = Tuple[Image.Image, int]
49
 
50
 
51
  class RetinopathyDataset(data.Dataset[DataRecord]):
 
 
52
  def __init__(self, data_path: str) -> None:
 
 
 
 
 
 
53
  super().__init__()
54
 
55
  self.data_path = data_path
@@ -88,21 +96,25 @@ class RetinopathyDataset(data.Dataset[DataRecord]):
88
  return img_path
89
 
90
 
 
91
  class Purpose(Enum):
92
  Train = 0
93
  Val = 1
94
 
95
-
96
  FeatureAndTargetTransforms = Tuple[Callable[..., torch.Tensor],
97
  Callable[..., torch.Tensor]]
98
 
 
99
  TensorRecord = Tuple[torch.Tensor, torch.Tensor]
100
 
101
- def normalize(arr: np.ndarray) -> np.ndarray:
102
- return arr / np.sum(arr)
103
-
104
 
105
  class Split(data.Dataset[TensorRecord], collections.abc.Sequence[TensorRecord]):
 
 
 
 
 
106
  def __init__(self, dataset: RetinopathyDataset,
107
  indices: np.ndarray,
108
  purpose: Purpose,
@@ -111,7 +123,24 @@ class Split(data.Dataset[TensorRecord], collections.abc.Sequence[TensorRecord]):
111
  stratify_classes: bool = False,
112
  use_log_frequencies: bool = False,
113
  ):
114
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  self.dataset = dataset
116
  self.indices = indices
117
  self.purpose = purpose
@@ -124,22 +153,26 @@ class Split(data.Dataset[TensorRecord], collections.abc.Sequence[TensorRecord]):
124
  self.per_class_indices: Optional[Dict[int, np.ndarray]] = None
125
  self.frequencies: Optional[Dict[int, float]] = None
126
  if self.stratify_classes:
127
- self.bucketize_indices()
128
  if self.use_log_frequencies:
129
- self.calc_frequencies()
130
 
131
- def calc_frequencies(self):
132
  assert self.per_class_indices is not None
133
  counts_dict = {lbl: len(arr) for lbl, arr in self.per_class_indices.items()}
134
  counts = np.array(list(counts_dict.values()))
135
- counts_nrm = normalize(counts)
136
  temperature = 50.0 # > 1 to even-out frequencies
137
- freqs = normalize(np.log1p(counts_nrm * temperature))
138
  self.frequencies = {k: freq.item() for k, freq
139
  in zip(self.per_class_indices.keys(), freqs)}
140
  print(self.frequencies)
141
 
142
- def bucketize_indices(self):
 
 
 
 
143
  buckets = defaultdict(list)
144
  for index in self.indices:
145
  label = self.dataset.get_label_at(index)
@@ -191,6 +224,14 @@ class Split(data.Dataset[TensorRecord], collections.abc.Sequence[TensorRecord]):
191
  seed: int = 54,
192
  ) -> Tuple['Split', 'Split']:
193
 
 
 
 
 
 
 
 
 
194
  prng = RandomState(seed)
195
 
196
  num_train = int(len(all_data) * train_fraction)
@@ -204,7 +245,8 @@ class Split(data.Dataset[TensorRecord], collections.abc.Sequence[TensorRecord]):
204
  return train_data, val_data
205
 
206
 
207
- def print_data_stats(dataset: Union[Iterable[DataRecord], DataLoader], split_name: str) -> None:
 
208
  labels = []
209
  for _, label in dataset:
210
  if isinstance(label, torch.Tensor):
@@ -261,7 +303,16 @@ class Metrics:
261
  return self
262
 
263
 
264
- def worker_init_fn(worker_id):
 
 
 
 
 
 
 
 
 
265
  state = np.random.get_state()
266
  assert isinstance(state, tuple)
267
  assert isinstance(state[1], np.ndarray)
@@ -274,6 +325,7 @@ def worker_init_fn(worker_id):
274
 
275
 
276
  class ViTLightningModule(L.LightningModule):
 
277
  def __init__(self, debug: bool) -> None:
278
  super().__init__()
279
 
@@ -443,6 +495,7 @@ class ViTLightningModule(L.LightningModule):
443
  return loss
444
 
445
  def _dump_train_images(self) -> None:
 
446
  img_batch, label_batch = next(iter(self._train_dataloader))
447
  for i_img, (img, label) in enumerate(zip(img_batch, label_batch)):
448
  img_np = img.cpu().numpy()
@@ -494,18 +547,19 @@ class ViTLightningModule(L.LightningModule):
494
 
495
 
496
  def main():
 
497
 
498
  parser = ArgumentParser(description='KAUST-SDAIA Diabetic Retinopathy')
499
  parser.add_argument('--tag', action='store', type=str,
500
  help='Extra suffix to put on the artefact dir name')
501
- parser.add_argument('--debug', action='store_true')
 
502
  parser.add_argument('--convert-checkpoint', action='store', type=str,
503
  help='Convert a checkpoint from training to pickle-independent '
504
  'predictor-compatible directory')
505
 
506
  args = parser.parse_args()
507
 
508
-
509
  torch.set_float32_matmul_precision('high') # for V100/A100
510
 
511
  if args.convert_checkpoint is not None:
 
49
 
50
 
51
  class RetinopathyDataset(data.Dataset[DataRecord]):
52
+ """ A class to access the pre-downloaded Diabetic Retinopathy dataset. """
53
+
54
  def __init__(self, data_path: str) -> None:
55
+ """ Constructor.
56
+
57
+ Args:
58
+ data_path (str): path to the dataset, ex: "retinopathy_data"
59
+ containing "trainLabels.csv" and "train/".
60
+ """
61
  super().__init__()
62
 
63
  self.data_path = data_path
 
96
  return img_path
97
 
98
 
99
+ """ Purpose of a split: training or validation. """
100
  class Purpose(Enum):
101
  Train = 0
102
  Val = 1
103
 
104
+ """ Augmentation transformations for an image and a label. """
105
  FeatureAndTargetTransforms = Tuple[Callable[..., torch.Tensor],
106
  Callable[..., torch.Tensor]]
107
 
108
+ """ Feature (image) and target (label) tensors. """
109
  TensorRecord = Tuple[torch.Tensor, torch.Tensor]
110
 
 
 
 
111
 
112
  class Split(data.Dataset[TensorRecord], collections.abc.Sequence[TensorRecord]):
113
+ """ Split is a class that keep a view on a part of a dataset.
114
+ Split is used to hold the imormation about which samples go to training
115
+ and which to validation without a need to put these groups of files into
116
+ separate folders.
117
+ """
118
  def __init__(self, dataset: RetinopathyDataset,
119
  indices: np.ndarray,
120
  purpose: Purpose,
 
123
  stratify_classes: bool = False,
124
  use_log_frequencies: bool = False,
125
  ):
126
+ """ Constructor.
127
+
128
+ Args:
129
+ dataset (RetinopathyDataset): The dataset on which the Split "views".
130
+ indices (np.ndarray): Externally provided indices of samples that
131
+ are "viewed" on.
132
+ purpose (Purpose): Either train or val, to be able to replicate
133
+ the data for train split for effecient workers utilization.
134
+ transforms (FeatureAndTargetTransforms): Functors of feature and
135
+ target transforms.
136
+ oversample_factor (int, optional): Expand the training dataset by
137
+ replication to avoid dataloader stalls on epoch ends. Defaults to 1.
138
+ stratify_classes (bool, optional): Whether to apply stratified sampling.
139
+ Defaults to False.
140
+ use_log_frequencies (bool, optional): If stratify_classes=True,
141
+ whether to use logarithmic sampling strategy. If False, apply
142
+ regular even sampling. Defaults to False.
143
+ """
144
  self.dataset = dataset
145
  self.indices = indices
146
  self.purpose = purpose
 
153
  self.per_class_indices: Optional[Dict[int, np.ndarray]] = None
154
  self.frequencies: Optional[Dict[int, float]] = None
155
  if self.stratify_classes:
156
+ self._bucketize_indices()
157
  if self.use_log_frequencies:
158
+ self._calc_frequencies()
159
 
160
+ def _calc_frequencies(self):
161
  assert self.per_class_indices is not None
162
  counts_dict = {lbl: len(arr) for lbl, arr in self.per_class_indices.items()}
163
  counts = np.array(list(counts_dict.values()))
164
+ counts_nrm = self._normalize(counts)
165
  temperature = 50.0 # > 1 to even-out frequencies
166
+ freqs = self._normalize(np.log1p(counts_nrm * temperature))
167
  self.frequencies = {k: freq.item() for k, freq
168
  in zip(self.per_class_indices.keys(), freqs)}
169
  print(self.frequencies)
170
 
171
+ @staticmethod
172
+ def _normalize(arr: np.ndarray) -> np.ndarray:
173
+ return arr / np.sum(arr)
174
+
175
+ def _bucketize_indices(self):
176
  buckets = defaultdict(list)
177
  for index in self.indices:
178
  label = self.dataset.get_label_at(index)
 
224
  seed: int = 54,
225
  ) -> Tuple['Split', 'Split']:
226
 
227
+ """ Prepare train and val splits deterministically.
228
+
229
+ Returns:
230
+ Tuple[Split, Split]:
231
+ - Train split
232
+ - Val split
233
+ """
234
+
235
  prng = RandomState(seed)
236
 
237
  num_train = int(len(all_data) * train_fraction)
 
245
  return train_data, val_data
246
 
247
 
248
+ def print_data_stats(dataset: Union[Iterable[DataRecord], DataLoader],
249
+ split_name: str) -> None:
250
  labels = []
251
  for _, label in dataset:
252
  if isinstance(label, torch.Tensor):
 
303
  return self
304
 
305
 
306
+ def worker_init_fn(worker_id: int) -> None:
307
+ """ Initialize workers in a way that they draw different
308
+ random samples and do not repeat identical pseudorandom
309
+ sequences of each other, which may be the case with Fork
310
+ multiprocessing.
311
+
312
+ Args:
313
+ worker_id (int): id of a preprocessing worker process launched
314
+ by one DDP training process.
315
+ """
316
  state = np.random.get_state()
317
  assert isinstance(state, tuple)
318
  assert isinstance(state[1], np.ndarray)
 
325
 
326
 
327
  class ViTLightningModule(L.LightningModule):
328
+ """ Lightning Module that implements neural network training hooks. """
329
  def __init__(self, debug: bool) -> None:
330
  super().__init__()
331
 
 
495
  return loss
496
 
497
  def _dump_train_images(self) -> None:
498
+ """ Save augmented images to disk for inspection. """
499
  img_batch, label_batch = next(iter(self._train_dataloader))
500
  for i_img, (img, label) in enumerate(zip(img_batch, label_batch)):
501
  img_np = img.cpu().numpy()
 
547
 
548
 
549
  def main():
550
+ """ Neural network trainer entry point. """
551
 
552
  parser = ArgumentParser(description='KAUST-SDAIA Diabetic Retinopathy')
553
  parser.add_argument('--tag', action='store', type=str,
554
  help='Extra suffix to put on the artefact dir name')
555
+ parser.add_argument('--debug', action='store_true',
556
+ help="Dummy training cycle for testing purposes")
557
  parser.add_argument('--convert-checkpoint', action='store', type=str,
558
  help='Convert a checkpoint from training to pickle-independent '
559
  'predictor-compatible directory')
560
 
561
  args = parser.parse_args()
562
 
 
563
  torch.set_float32_matmul_precision('high') # for V100/A100
564
 
565
  if args.convert_checkpoint is not None: