chenz53 commited on
Commit
0425418
1 Parent(s): 58f458c

Create dataload.py

Browse files
Files changed (1) hide show
  1. dataload.py +235 -0
dataload.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from typing import Optional, Sequence
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.distributed as ptdist
7
+ from monai.data import (
8
+ CacheDataset,
9
+ PersistentDataset,
10
+ partition_dataset,
11
+ )
12
+ from monai.data.utils import pad_list_data_collate
13
+ from monai.transforms import (
14
+ Compose,
15
+ CropForegroundd,
16
+ EnsureChannelFirstd,
17
+ LoadImaged,
18
+ Orientationd,
19
+ RandSpatialCropSamplesd,
20
+ ScaleIntensityRanged,
21
+ Spacingd,
22
+ SpatialPadd,
23
+ ToTensord,
24
+ Transform,
25
+ )
26
+
27
+
28
+ class PermuteImage(Transform):
29
+ """Permute the dimensions of the image"""
30
+
31
+ def __call__(self, data):
32
+ data["image"] = data["image"].permute(
33
+ 3, 0, 1, 2
34
+ ) # Adjust permutation order as needed
35
+ return data
36
+
37
+
38
+ class CTDataset:
39
+ def __init__(
40
+ self,
41
+ json_path: str,
42
+ img_size: int,
43
+ depth: int,
44
+ mask_patch_size: int,
45
+ patch_size: int,
46
+ downsample_ratio: Sequence[float],
47
+ cache_dir: str,
48
+ batch_size: int = 1,
49
+ val_batch_size: int = 1,
50
+ num_workers: int = 4,
51
+ cache_num: int = 0,
52
+ cache_rate: float = 0.0,
53
+ dist: bool = False,
54
+ ):
55
+ super().__init__()
56
+ self.json_path = json_path
57
+ self.img_size = img_size
58
+ self.depth = depth
59
+ self.mask_patch_size = mask_patch_size
60
+ self.patch_size = patch_size
61
+ self.cache_dir = cache_dir
62
+ self.downsample_ratio = downsample_ratio
63
+ self.batch_size = batch_size
64
+ self.val_batch_size = val_batch_size
65
+ self.num_workers = num_workers
66
+ self.cache_num = cache_num
67
+ self.cache_rate = cache_rate
68
+ self.dist = dist
69
+
70
+ data_list = json.load(open(json_path, "r"))
71
+
72
+ if "train" in data_list.keys():
73
+ self.train_list = data_list["train"]
74
+ if "validation" in data_list.keys():
75
+ self.val_list = data_list["validation"]
76
+
77
+ def val_transforms(
78
+ self,
79
+ ):
80
+ return self.train_transforms()
81
+
82
+ def train_transforms(
83
+ self,
84
+ ):
85
+ transforms = Compose(
86
+ [
87
+ LoadImaged(keys=["image"]),
88
+ EnsureChannelFirstd(keys=["image"]),
89
+ Orientationd(keys=["image"], axcodes="RAS"),
90
+ Spacingd(
91
+ keys=["image"],
92
+ pixdim=self.downsample_ratio,
93
+ mode=("bilinear"),
94
+ ),
95
+ ScaleIntensityRanged(
96
+ keys=["image"],
97
+ a_min=-175,
98
+ a_max=250,
99
+ b_min=0.0,
100
+ b_max=1.0,
101
+ clip=True,
102
+ ),
103
+ CropForegroundd(keys=["image"], source_key="image"),
104
+ RandSpatialCropSamplesd(
105
+ keys=["image"],
106
+ roi_size=(self.img_size, self.img_size, self.depth),
107
+ random_size=False,
108
+ num_samples=1,
109
+ ),
110
+ SpatialPadd(
111
+ keys=["image"],
112
+ spatial_size=(self.img_size, self.img_size, self.depth),
113
+ ),
114
+ # RandScaleIntensityd(keys="image", factors=0.1, prob=0.5),
115
+ # RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
116
+ ToTensord(keys=["image"]),
117
+ PermuteImage(),
118
+ ]
119
+ )
120
+
121
+ return transforms
122
+
123
+ def setup(self, stage: Optional[str] = None):
124
+ # Assign Train split(s) for use in Dataloaders
125
+ if stage in [None, "train"]:
126
+ if self.dist:
127
+ train_partition = partition_dataset(
128
+ data=self.train_list,
129
+ num_partitions=ptdist.get_world_size(),
130
+ shuffle=True,
131
+ even_divisible=True,
132
+ drop_last=False,
133
+ )[ptdist.get_rank()]
134
+ valid_partition = partition_dataset(
135
+ data=self.val_list,
136
+ num_partitions=ptdist.get_world_size(),
137
+ shuffle=False,
138
+ even_divisible=True,
139
+ drop_last=False,
140
+ )[ptdist.get_rank()]
141
+ # self.cache_num //= ptdist.get_world_size()
142
+ else:
143
+ train_partition = self.train_list
144
+ valid_partition = self.val_list
145
+
146
+ if any([self.cache_num, self.cache_rate]) > 0:
147
+ train_ds = CacheDataset(
148
+ train_partition,
149
+ cache_num=self.cache_num,
150
+ cache_rate=self.cache_rate,
151
+ num_workers=self.num_workers,
152
+ transform=self.train_transforms(),
153
+ )
154
+ valid_ds = CacheDataset(
155
+ valid_partition,
156
+ cache_num=self.cache_num // 4,
157
+ cache_rate=self.cache_rate,
158
+ num_workers=self.num_workers,
159
+ transform=self.val_transforms(),
160
+ )
161
+ else:
162
+ train_ds = PersistentDataset(
163
+ train_partition,
164
+ transform=self.train_transforms(),
165
+ cache_dir=self.cache_dir,
166
+ )
167
+ valid_ds = PersistentDataset(
168
+ valid_partition,
169
+ transform=self.val_transforms(),
170
+ cache_dir=self.cache_dir,
171
+ )
172
+
173
+ return {"train": train_ds, "validation": valid_ds}
174
+
175
+ if stage in [None, "test"]:
176
+ if any([self.cache_num, self.cache_rate]) > 0:
177
+ test_ds = CacheDataset(
178
+ self.val_list,
179
+ cache_num=self.cache_num // 4,
180
+ cache_rate=self.cache_rate,
181
+ num_workers=self.num_workers,
182
+ transform=self.val_transforms(),
183
+ )
184
+ else:
185
+ test_ds = PersistentDataset(
186
+ self.val_list,
187
+ transform=self.val_transforms(),
188
+ cache_dir=self.cache_dir,
189
+ )
190
+
191
+ return {"test": test_ds}
192
+
193
+ return {"train": None, "validation": None}
194
+
195
+ def train_dataloader(self, train_ds):
196
+ # def collate_fn(examples):
197
+ # pixel_values = torch.stack([example["image"] for example in examples])
198
+ # mask = torch.stack([example["mask"] for example in examples])
199
+ # return {"pixel_values": pixel_values, "bool_masked_pos": mask}
200
+
201
+ return torch.utils.data.DataLoader(
202
+ train_ds,
203
+ batch_size=self.batch_size,
204
+ num_workers=self.num_workers,
205
+ pin_memory=True,
206
+ shuffle=True,
207
+ collate_fn=pad_list_data_collate,
208
+ # collate_fn=collate_fn
209
+ # drop_last=False,
210
+ # prefetch_factor=4,
211
+ )
212
+
213
+ def val_dataloader(self, valid_ds):
214
+ return torch.utils.data.DataLoader(
215
+ valid_ds,
216
+ batch_size=self.val_batch_size,
217
+ num_workers=self.num_workers,
218
+ pin_memory=True,
219
+ shuffle=False,
220
+ # drop_last=False,
221
+ collate_fn=pad_list_data_collate,
222
+ # prefetch_factor=4,
223
+ )
224
+
225
+ def test_dataloader(self, test_ds):
226
+ return torch.utils.data.DataLoader(
227
+ test_ds,
228
+ batch_size=self.val_batch_size,
229
+ num_workers=self.num_workers,
230
+ pin_memory=True,
231
+ shuffle=False,
232
+ # drop_last=False,
233
+ collate_fn=pad_list_data_collate,
234
+ # prefetch_factor=4,
235
+ )