Hannes Kuchelmeister committed
Commit a10a7fc
Parent: b67f297

implement data loader

models/notebooks/1.0-hfk-datamodules-exploration.ipynb CHANGED
@@ -292,65 +292,72 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 6,
    "metadata": {},
    "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "570\n"
+     ]
+    },
     {
      "data": {
       "text/plain": [
-       "{'image': array([[[181, 190, 171],\n",
-       "        [180, 189, 170],\n",
-       "        [180, 186, 172],\n",
+       "{'image': array([[[172, 173, 159],\n",
+       "        [166, 167, 153],\n",
+       "        [171, 173, 160],\n",
        "        ...,\n",
-       "        [172, 176, 177],\n",
-       "        [171, 176, 179],\n",
-       "        [170, 178, 180]],\n",
+       "        [199, 202, 173],\n",
+       "        [199, 202, 173],\n",
+       "        [200, 201, 170]],\n",
        "\n",
-       "       [[181, 190, 173],\n",
-       "        [181, 190, 173],\n",
-       "        [180, 188, 175],\n",
+       "       [[167, 169, 155],\n",
+       "        [164, 166, 152],\n",
+       "        [171, 175, 160],\n",
        "        ...,\n",
-       "        [169, 173, 174],\n",
-       "        [169, 175, 175],\n",
-       "        [170, 176, 176]],\n",
+       "        [194, 197, 168],\n",
+       "        [195, 198, 169],\n",
+       "        [199, 200, 169]],\n",
        "\n",
-       "       [[179, 190, 176],\n",
-       "        [179, 190, 176],\n",
-       "        [179, 189, 180],\n",
+       "       [[146, 153, 135],\n",
+       "        [149, 156, 138],\n",
+       "        [163, 172, 153],\n",
        "        ...,\n",
-       "        [169, 169, 167],\n",
-       "        [169, 171, 170],\n",
-       "        [169, 171, 170]],\n",
+       "        [189, 192, 163],\n",
+       "        [191, 194, 165],\n",
+       "        [197, 198, 167]],\n",
        "\n",
        "       ...,\n",
        "\n",
-       "       [[195, 201, 197],\n",
-       "        [195, 201, 197],\n",
-       "        [195, 201, 197],\n",
+       "       [[ 57,  62,  68],\n",
+       "        [ 41,  46,  52],\n",
+       "        [ 24,  31,  39],\n",
        "        ...,\n",
-       "        [198, 195, 188],\n",
-       "        [199, 198, 196],\n",
-       "        [202, 200, 205]],\n",
+       "        [198, 189, 180],\n",
+       "        [188, 179, 170],\n",
+       "        [180, 171, 164]],\n",
        "\n",
-       "       [[195, 201, 197],\n",
-       "        [195, 201, 197],\n",
-       "        [195, 201, 197],\n",
+       "       [[ 46,  51,  57],\n",
+       "        [ 34,  39,  45],\n",
+       "        [ 21,  28,  36],\n",
        "        ...,\n",
-       "        [198, 195, 188],\n",
-       "        [199, 198, 196],\n",
-       "        [202, 200, 205]],\n",
+       "        [208, 200, 189],\n",
+       "        [197, 190, 180],\n",
+       "        [188, 181, 173]],\n",
        "\n",
-       "       [[195, 201, 197],\n",
-       "        [195, 201, 197],\n",
-       "        [195, 201, 197],\n",
+       "       [[ 31,  39,  42],\n",
+       "        [ 23,  31,  34],\n",
+       "        [ 18,  25,  31],\n",
        "        ...,\n",
-       "        [198, 195, 188],\n",
-       "        [199, 198, 196],\n",
-       "        [202, 200, 203]]], dtype=uint8),\n",
-       " 'focus_value': -2.70408}"
+       "        [215, 209, 197],\n",
+       "        [205, 199, 187],\n",
+       "        [197, 190, 180]]], dtype=uint8),\n",
+       " 'focus_value': 0.0}"
       ]
      },
-     "execution_count": 7,
+     "execution_count": 6,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -362,7 +369,49 @@
     "from focus_datamodule import FocusDataSet\n",
     "\n",
     "ds = FocusDataSet(\"../data/focus/metadata.csv\", \"../data/focus/\")\n",
-    "ds[1]"
+    "\n",
+    "counter = 0\n",
+    "for d in ds:\n",
+    "    counter += 1\n",
+    "\n",
+    "print(counter)\n",
+    "\n",
+    "d"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from focus_datamodule import FocusDataModule\n",
+    "\n",
+    "datamodule = FocusDataModule(data_dir=\"../data/focus\", csv_file=\"../data/focus/metadata.csv\")\n",
+    "datamodule.setup()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "64"
+      ]
+     },
+     "execution_count": 15,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "for data in datamodule.test_dataloader():\n",
+    "    break\n",
+    "\n",
+    "len(data[\"focus_value\"])"
+   ]
+  }
 ],
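
Since random_split in the companion FocusDataModule requires FocusDataSet to implement __len__, the counting loop in the new notebook cell could likely be replaced by a direct len() call. A minimal sketch, assuming __len__ is consistent with iteration (not part of this commit):

    from focus_datamodule import FocusDataSet

    ds = FocusDataSet("../data/focus/metadata.csv", "../data/focus/")
    print(len(ds))            # should print 570, matching the manual counter
    sample = ds[len(ds) - 1]  # one sample: a dict with 'image' and 'focus_value'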
models/src/datamodules/focus_datamodule.py CHANGED
@@ -1,14 +1,11 @@
 import os
-from typing import Any, Optional, Tuple, Union
-from typing_extensions import Self
-import numpy as np
+from typing import Optional, Tuple
 import pandas as pd
 from skimage import io
 
 import torch
 from pytorch_lightning import LightningDataModule
-from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
-from torchvision.datasets import MNIST
+from torch.utils.data import DataLoader, Dataset, random_split
 from torchvision.transforms import transforms
 
 
@@ -58,3 +55,88 @@ class FocusDataSet(Dataset):
 
         return sample
 
+
+class FocusDataModule(LightningDataModule):
+    """
+    LightningDataModule for the FocusStack dataset.
+    """
+
+    def __init__(
+        self,
+        data_dir: str = "data/",
+        csv_file: str = "data/metadata.csv",
+        train_val_test_split_percentage: Tuple[float, float, float] = (0.75, 0.15, 0.15),
+        batch_size: int = 64,
+        num_workers: int = 0,
+        pin_memory: bool = False,
+    ):
+        super().__init__()
+
+        # this line allows accessing init params with the 'self.hparams' attribute
+        self.save_hyperparameters(logger=False)
+
+        # data transformations
+        self.transforms = transforms.Compose([])
+
+        self.data_train: Optional[Dataset] = None
+        self.data_val: Optional[Dataset] = None
+        self.data_test: Optional[Dataset] = None
+
+    def prepare_data(self):
+        """This method is not implemented yet.
+
+        Download data if needed. This method is called only from a single GPU.
+        Do not use it to assign state (self.x = y).
+        """
+        pass
+
+    def setup(self, stage: Optional[str] = None):
+        """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.
+        Lightning calls this method both for `trainer.fit()` and `trainer.test()`, so be careful if you do a random split!
+        The `stage` argument can be used to differentiate whether it is called before `trainer.fit()` or `trainer.test()`."""
+
+        # load datasets only if they're not loaded already
+        if not self.data_train and not self.data_val and not self.data_test:
+            dataset = FocusDataSet(
+                self.hparams.csv_file, self.hparams.data_dir, transform=self.transforms
+            )
+            train_length = int(
+                len(dataset) * self.hparams.train_val_test_split_percentage[0]
+            )
+            val_length = int(
+                len(dataset) * self.hparams.train_val_test_split_percentage[1]
+            )
+            test_length = len(dataset) - val_length - train_length
+
+            self.data_train, self.data_val, self.data_test = random_split(
+                dataset=dataset,
+                lengths=(train_length, val_length, test_length),
+                generator=torch.Generator().manual_seed(42),
+            )
+
+    def train_dataloader(self):
+        return DataLoader(
+            dataset=self.data_train,
+            batch_size=self.hparams.batch_size,
+            num_workers=self.hparams.num_workers,
+            pin_memory=self.hparams.pin_memory,
+            shuffle=True,
+        )
+
+    def val_dataloader(self):
+        return DataLoader(
+            dataset=self.data_val,
+            batch_size=self.hparams.batch_size,
+            num_workers=self.hparams.num_workers,
+            pin_memory=self.hparams.pin_memory,
+            shuffle=False,
+        )
+
+    def test_dataloader(self):
+        return DataLoader(
+            dataset=self.data_test,
+            batch_size=self.hparams.batch_size,
+            num_workers=self.hparams.num_workers,
+            pin_memory=self.hparams.pin_memory,
+            shuffle=False,
+        )
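
With the 570-sample dataset from the notebook, setup() yields int(570 * 0.75) = 427 training, int(570 * 0.15) = 85 validation, and 570 - 427 - 85 = 58 test samples. A minimal end-to-end sketch of the new datamodule (paths taken from the notebook; FocusModel is an assumed LightningModule and not part of this commit):

    from pytorch_lightning import Trainer

    from focus_datamodule import FocusDataModule

    datamodule = FocusDataModule(
        data_dir="../data/focus", csv_file="../data/focus/metadata.csv", batch_size=64
    )
    datamodule.setup()

    # the default collate_fn batches each sample dict into a dict of tensors,
    # assuming all images share the same shape
    batch = next(iter(datamodule.train_dataloader()))
    print(batch["image"].shape, len(batch["focus_value"]))

    # training with the assumed LightningModule:
    # Trainer(max_epochs=10).fit(model=FocusModel(), datamodule=datamodule)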