import json
from typing import Optional, Sequence

import torch
import torch.distributed as ptdist
from monai.data import (
    CacheDataset,
    PersistentDataset,
    partition_dataset,
)
from monai.data.utils import pad_list_data_collate
from monai.transforms import (
    Compose,
    CropForegroundd,
    EnsureChannelFirstd,
    LoadImaged,
    Orientationd,
    RandSpatialCropSamplesd,
    ScaleIntensityRanged,
    Spacingd,
    SpatialPadd,
    ToTensord,
    Transform,
)


class PermuteImage(Transform):
    """Permute the image dimensions from (C, H, W, D) to (D, C, H, W)."""

    def __call__(self, data):
        # Adjust the permutation order as needed for the model's input layout.
        data["image"] = data["image"].permute(3, 0, 1, 2)
        return data


class CTDataset:
    def __init__(
        self,
        json_path: str,
        img_size: int,
        depth: int,
        mask_patch_size: int,
        patch_size: int,
        downsample_ratio: Sequence[float],
        cache_dir: str,
        batch_size: int = 1,
        val_batch_size: int = 1,
        num_workers: int = 4,
        cache_num: int = 0,
        cache_rate: float = 0.0,
        dist: bool = False,
    ):
        super().__init__()
        self.json_path = json_path
        self.img_size = img_size
        self.depth = depth
        self.mask_patch_size = mask_patch_size
        self.patch_size = patch_size
        self.cache_dir = cache_dir
        self.downsample_ratio = downsample_ratio
        self.batch_size = batch_size
        self.val_batch_size = val_batch_size
        self.num_workers = num_workers
        self.cache_num = cache_num
        self.cache_rate = cache_rate
        self.dist = dist

        with open(json_path, "r") as f:
            data_list = json.load(f)
        if "train" in data_list:
            self.train_list = data_list["train"]
        if "validation" in data_list:
            self.val_list = data_list["validation"]

    def val_transforms(self):
        # Validation reuses the training pipeline, so validation samples are
        # also randomly cropped.
        return self.train_transforms()

    def train_transforms(self):
        return Compose(
            [
                LoadImaged(keys=["image"]),
                EnsureChannelFirstd(keys=["image"]),
                Orientationd(keys=["image"], axcodes="RAS"),
                # Resample to the target voxel spacing.
                Spacingd(
                    keys=["image"],
                    pixdim=self.downsample_ratio,
                    mode="bilinear",
                ),
                # Clip intensities to [-175, 250] HU and rescale to [0, 1].
                ScaleIntensityRanged(
                    keys=["image"],
                    a_min=-175,
                    a_max=250,
                    b_min=0.0,
                    b_max=1.0,
                    clip=True,
                ),
                CropForegroundd(keys=["image"], source_key="image"),
                RandSpatialCropSamplesd(
                    keys=["image"],
                    roi_size=(self.img_size, self.img_size, self.depth),
                    random_size=False,
                    num_samples=1,
                ),
                # Pad back up in case the cropped volume is smaller than the ROI.
                SpatialPadd(
                    keys=["image"],
                    spatial_size=(self.img_size, self.img_size, self.depth),
                ),
                # RandScaleIntensityd(keys="image", factors=0.1, prob=0.5),
                # RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
                ToTensord(keys=["image"]),
                PermuteImage(),
            ]
        )

    def setup(self, stage: Optional[str] = None):
        # Assign train/validation split(s) for use in the dataloaders. Note
        # that with stage=None this branch returns immediately, so the test
        # branch below is only reached when stage == "test".
        if stage in [None, "train"]:
            if self.dist:
                # Shard the data lists across ranks for distributed training.
                train_partition = partition_dataset(
                    data=self.train_list,
                    num_partitions=ptdist.get_world_size(),
                    shuffle=True,
                    even_divisible=True,
                    drop_last=False,
                )[ptdist.get_rank()]
                valid_partition = partition_dataset(
                    data=self.val_list,
                    num_partitions=ptdist.get_world_size(),
                    shuffle=False,
                    even_divisible=True,
                    drop_last=False,
                )[ptdist.get_rank()]
                # self.cache_num //= ptdist.get_world_size()
            else:
                train_partition = self.train_list
                valid_partition = self.val_list

            # Cache in memory when a cache budget is given; otherwise cache
            # the preprocessed tensors on disk.
            if self.cache_num > 0 or self.cache_rate > 0:
                train_ds = CacheDataset(
                    train_partition,
                    cache_num=self.cache_num,
                    cache_rate=self.cache_rate,
                    num_workers=self.num_workers,
                    transform=self.train_transforms(),
                )
                valid_ds = CacheDataset(
                    valid_partition,
                    cache_num=self.cache_num // 4,
                    cache_rate=self.cache_rate,
                    num_workers=self.num_workers,
                    transform=self.val_transforms(),
                )
            else:
                train_ds = PersistentDataset(
                    train_partition,
                    transform=self.train_transforms(),
                    cache_dir=self.cache_dir,
                )
                valid_ds = PersistentDataset(
                    valid_partition,
                    transform=self.val_transforms(),
                    cache_dir=self.cache_dir,
                )
            return {"train": train_ds, "validation": valid_ds}

        if stage in [None, "test"]:
            if self.cache_num > 0 or self.cache_rate > 0:
                test_ds = CacheDataset(
                    self.val_list,
                    cache_num=self.cache_num // 4,
                    cache_rate=self.cache_rate,
                    num_workers=self.num_workers,
                    transform=self.val_transforms(),
                )
            else:
                test_ds = PersistentDataset(
                    self.val_list,
                    transform=self.val_transforms(),
                    cache_dir=self.cache_dir,
                )
            return {"test": test_ds}

        return {"train": None, "validation": None}

    def train_dataloader(self, train_ds):
        # def collate_fn(examples):
        #     pixel_values = torch.stack([example["image"] for example in examples])
        #     mask = torch.stack([example["mask"] for example in examples])
        #     return {"pixel_values": pixel_values, "bool_masked_pos": mask}
        return torch.utils.data.DataLoader(
            train_ds,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=True,
            # Pads list-style samples (from RandSpatialCropSamplesd) to a
            # common spatial size before default collation.
            collate_fn=pad_list_data_collate,
            # collate_fn=collate_fn
            # drop_last=False,
            # prefetch_factor=4,
        )

    def val_dataloader(self, valid_ds):
        return torch.utils.data.DataLoader(
            valid_ds,
            batch_size=self.val_batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=False,
            # drop_last=False,
            collate_fn=pad_list_data_collate,
            # prefetch_factor=4,
        )

    def test_dataloader(self, test_ds):
        return torch.utils.data.DataLoader(
            test_ds,
            batch_size=self.val_batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=False,
            # drop_last=False,
            collate_fn=pad_list_data_collate,
            # prefetch_factor=4,
        )
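

# ---------------------------------------------------------------------------
# Example usage: a minimal sketch, not part of the original module. The JSON
# path, cache directory, and parameter values are illustrative assumptions;
# the split file is expected to look like
# {"train": [{"image": "/path/a.nii.gz"}, ...], "validation": [...]}.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    dataset = CTDataset(
        json_path="dataset.json",  # hypothetical split file
        img_size=96,
        depth=96,
        mask_patch_size=16,  # assumed; stored for a masking pipeline elsewhere
        patch_size=2,  # assumed value
        downsample_ratio=(1.5, 1.5, 2.0),  # assumed target spacing in mm
        cache_dir="./persistent_cache",  # hypothetical PersistentDataset cache
        batch_size=2,
        num_workers=4,
    )
    splits = dataset.setup(stage="train")
    train_loader = dataset.train_dataloader(splits["train"])

    # With num_samples=1 and PermuteImage, each batch["image"] should have
    # shape (batch_size, depth, 1, img_size, img_size).
    batch = next(iter(train_loader))
    print(batch["image"].shape)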