# create dataloaders from csv file
## ---------- imports ----------
import os
import torch
import shutil
import numpy as np
import pandas as pd
from typing import Union
from monai.utils import first
from functools import partial
from collections import namedtuple
from monai.data import DataLoader as MonaiDataLoader

from . import transforms
from .utils import num_workers

def import_dataset(config: dict):
    if config.data.dataset_type == 'persistent':
        from monai.data import PersistentDataset
        if os.path.exists(config.data.cache_dir):
            shutil.rmtree(config.data.cache_dir)  # rm previous cache DS
        os.makedirs(config.data.cache_dir, exist_ok=True)
        Dataset = partial(PersistentDataset, cache_dir=config.data.cache_dir)
    elif config.data.dataset_type == 'cache':
        from monai.data import CacheDataset
        raise NotImplementedError('CacheDataset not yet implemented')
    else:
        from monai.data import Dataset
    return Dataset
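
# Usage sketch (illustrative, not part of the module): `import_dataset` only
# selects the Dataset class; instantiation happens later with data dicts and
# transforms. The attribute-style access suggests an OmegaConf-like config,
# which is an assumption here, as are `data_dicts` and `train_transforms`:
#
#   from omegaconf import OmegaConf
#   config = OmegaConf.create(
#       {'data': {'dataset_type': 'persistent', 'cache_dir': '.cache'}}
#   )
#   Dataset = import_dataset(config)  # partial(PersistentDataset, cache_dir=...)
#   ds = Dataset(data=data_dicts, transform=train_transforms)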

class DataLoader(MonaiDataLoader):
    "overwrite monai DataLoader for enhanced viewing capabilities"

    def show_batch(self,
                   image_key: str = 'image',
                   label_key: str = 'label',
                   image_transform=lambda x: x.squeeze().transpose(0, 2).flip(-2),
                   label_transform=lambda x: x.squeeze().transpose(0, 2).flip(-2)):
        """Args:
            image_key: dict key name for image to view
            label_key: dict key name for corresponding label. Can be a tensor or str
            image_transform: transform input before it is passed to the viewer to ensure
                ndim of the image is equal to 3 and the image is oriented correctly
            label_transform: transform labels before they are passed to the viewer to ensure
                segmentation masks have the same shape and orientation as the images. Should be
                the identity function if labels are str.
        """
        from .viewer import ListViewer

        batch = first(self)
        image = torch.unbind(batch[image_key], 0)
        label = torch.unbind(batch[label_key], 0)
        ListViewer([image_transform(im) for im in image],
                   [label_transform(im) for im in label]).show()
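
# Example (illustrative): view the first batch of a loader. The default
# transforms assume a single-channel 3D volume, so the squeezed tensor has
# ndim == 3 as the docstring requires; `train_ds` is a placeholder dataset:
#
#   loader = DataLoader(train_ds, batch_size=2, shuffle=True)
#   loader.show_batch(image_key='image', label_key='label')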

# TODO
## Work with 3 dataloaders

def segmentation_dataloaders(config: dict,
                             train: bool = None,
                             valid: bool = None,
                             test: bool = None,
                             ):
    """Create segmentation dataloaders
    Args:
        config: config file
        train: whether to return a train DataLoader
        valid: whether to return a valid DataLoader
        test: whether to return a test DataLoader
    Args from config:
        data_dir: base directory for the data
        train_csv, valid_csv, test_csv: paths to csv files containing filenames and paths
        image_cols: columns in csv containing paths to images
        label_cols: columns in csv containing paths to label files
        dataset_type: PersistentDataset, CacheDataset and Dataset are supported
        cache_dir: cache directory to be used by PersistentDataset
        batch_size: batch size for training. Valid and test are always 1
        debug: run with a reduced number of images
    Returns:
        a single DataLoader if only one of train/valid/test is requested,
        otherwise a namedtuple of:
            train_loader: DataLoader (optional, if train == True)
            valid_loader: DataLoader (optional, if valid == True)
            test_loader: DataLoader (optional, if test == True)
    """
    ## parse needed arguments from config
    if train is None: train = config.data.train
    if valid is None: valid = config.data.valid
    if test is None: test = config.data.test

    data_dir = config.data.data_dir
    train_csv = config.data.train_csv
    valid_csv = config.data.valid_csv
    test_csv = config.data.test_csv
    image_cols = config.data.image_cols
    label_cols = config.data.label_cols
    dataset_type = config.data.dataset_type
    cache_dir = config.data.cache_dir
    batch_size = config.data.batch_size
    debug = config.debug

    ## ---------- data dicts ----------
    # first a global data dict, containing only the filepaths from image_cols and
    # label_cols, is created. For this, the dataframe is reduced to only the
    # relevant columns. Then the rows are iterated, converting each row into an
    # individual dict, as expected by monai
    if not isinstance(image_cols, (tuple, list)): image_cols = [image_cols]
    if not isinstance(label_cols, (tuple, list)): label_cols = [label_cols]

    train_df = pd.read_csv(train_csv)
    valid_df = pd.read_csv(valid_csv)
    test_df = pd.read_csv(test_csv)
    if debug:
        train_df = train_df.sample(25)
        valid_df = valid_df.sample(5)

    train_df['split'] = 'train'
    valid_df['split'] = 'valid'
    test_df['split'] = 'test'

    whole_df = []
    if train: whole_df += [train_df]
    if valid: whole_df += [valid_df]
    if test: whole_df += [test_df]
    df = pd.concat(whole_df)

    cols = image_cols + label_cols
    for col in cols:
        # create absolute file names from relative fns in df and data_dir
        df[col] = [os.path.join(data_dir, fn) for fn in df[col]]
        if not os.path.exists(list(df[col])[0]):
            raise FileNotFoundError(list(df[col])[0])
    data_dict = [dict(row[1]) for row in df[cols].iterrows()]
    # data_dict is not the correct name, list_of_data_dicts would be more accurate,
    # but also longer. The data_dict looks like this:
    # [
    #  {'image_col_1': 'data_dir/path/to/image1',
    #   'image_col_2': 'data_dir/path/to/image2',
    #   'label_col_1': 'data_dir/path/to/label1'},
    #  {'image_col_1': 'data_dir/path/to/image1',
    #   'image_col_2': 'data_dir/path/to/image2',
    #   'label_col_1': 'data_dir/path/to/label1'},
    #  ...]
    # Filenames should now be absolute or relative to the working directory

    # now we create separate data dicts for train, valid and test data respectively
    assert train or test or valid, 'No dataset type is specified (train/valid or test)'
    # np.where returns a 1-tuple of positional indices, which is unpacked and
    # used to select the matching entries from data_dict
    if test:
        test_files = list(map(data_dict.__getitem__, *np.where(df.split == 'test')))
    if valid:
        val_files = list(map(data_dict.__getitem__, *np.where(df.split == 'valid')))
    if train:
        train_files = list(map(data_dict.__getitem__, *np.where(df.split == 'train')))

    # transforms are specified in transforms.py and are just loaded here
    if train: train_transforms = transforms.get_train_transforms(config)
    if valid: val_transforms = transforms.get_val_transforms(config)
    if test: test_transforms = transforms.get_test_transforms(config)

    ## ---------- construct dataloaders ----------
    Dataset = import_dataset(config)
    data_loaders = []

    if train:
        train_ds = Dataset(
            data=train_files,
            transform=train_transforms
        )
        train_loader = DataLoader(
            train_ds,
            batch_size=batch_size,
            num_workers=num_workers(),
            shuffle=True
        )
        data_loaders.append(train_loader)
    if valid:
        val_ds = Dataset(
            data=val_files,
            transform=val_transforms
        )
        val_loader = DataLoader(
            val_ds,
            batch_size=1,
            num_workers=num_workers(),
            shuffle=False
        )
        data_loaders.append(val_loader)
    if test:
        test_ds = Dataset(
            data=test_files,
            transform=test_transforms
        )
        test_loader = DataLoader(
            test_ds,
            batch_size=1,
            num_workers=num_workers(),
            shuffle=False
        )
        data_loaders.append(test_loader)

    # if only one dataloader is constructed, return only this dataloader; else
    # return a namedtuple of dataloaders, so it is clear which DataLoader is
    # train/valid or test
    if len(data_loaders) == 1:
        return data_loaders[0]
    else:
        DataLoaders = namedtuple(
            'DataLoaders',
            # create a str specifying the loader types: if train and test are True
            # but valid is False, the string will be 'train test' (namedtuple
            # ignores the extra whitespace)
            ' '.join(
                [
                    'train' if train else '',
                    'valid' if valid else '',
                    'test' if test else ''
                ]
            ).strip()
        )
        return DataLoaders(*data_loaders)
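
# Usage sketch (illustrative): the namedtuple fields mirror the requested splits,
# so with train and valid enabled but test disabled one could write (assuming the
# csv columns are named 'image' and 'label'):
#
#   loaders = segmentation_dataloaders(config, train=True, valid=True, test=False)
#   batch = next(iter(loaders.train))
#   print(batch['image'].shape, batch['label'].shape)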