# create dataloaders from csv files
## ---------- imports ----------
import os
import torch
import shutil
import numpy as np
import pandas as pd
from typing import Union
from monai.utils import first
from functools import partial
from collections import namedtuple
from monai.data import DataLoader as MonaiDataLoader
from . import transforms
from .utils import num_workers


def import_dataset(config: dict):
    "select the Dataset class to use, based on config.data.dataset_type"
    if config.data.dataset_type == 'persistent':
        from monai.data import PersistentDataset
        if os.path.exists(config.data.cache_dir):
            shutil.rmtree(config.data.cache_dir)  # rm previous cache DS
        os.makedirs(config.data.cache_dir, exist_ok=True)
        Dataset = partial(PersistentDataset, cache_dir=config.data.cache_dir)
    elif config.data.dataset_type == 'cache':
        from monai.data import CacheDataset
        raise NotImplementedError('CacheDataset not yet implemented')
    else:
        from monai.data import Dataset
    return Dataset
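
# A minimal usage sketch for `import_dataset` (hypothetical config values; in this
# project the config is built elsewhere and only needs attribute access such as
# `config.data.dataset_type` and `config.data.cache_dir`):
#
#     from types import SimpleNamespace
#     config = SimpleNamespace(data=SimpleNamespace(dataset_type='persistent', cache_dir='cache/'))
#     Dataset = import_dataset(config)   # -> partial(PersistentDataset, cache_dir='cache/')
#     ds = Dataset(data=list_of_data_dicts, transform=some_transform)  # placeholder arguments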


class DataLoader(MonaiDataLoader):
    "overwrites the monai DataLoader for enhanced viewing capabilities"

    def show_batch(self,
                   image_key: str = 'image',
                   label_key: str = 'label',
                   image_transform=lambda x: x.squeeze().transpose(0, 2).flip(-2),
                   label_transform=lambda x: x.squeeze().transpose(0, 2).flip(-2)):
        """Args:
            image_key: dict key name for the image to view
            label_key: dict key name for the corresponding label. Can be a tensor or str
            image_transform: transform applied to the input before it is passed to the viewer,
                to ensure the ndim of the image is equal to 3 and the image is oriented correctly
            label_transform: transform applied to the labels before they are passed to the viewer,
                to ensure segmentation masks have the same shape and orientation as the images.
                Should be the identity function if labels are str.
        """
        from .viewer import ListViewer

        batch = first(self)
        image = torch.unbind(batch[image_key], 0)
        label = torch.unbind(batch[label_key], 0)
        ListViewer([image_transform(im) for im in image],
                   [label_transform(im) for im in label]).show()
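
# Usage sketch for `DataLoader.show_batch` (assumes the dataset yields dicts with
# 'image' and 'label' keys, as produced by the data dicts below; `train_ds` is a placeholder):
#
#     loader = DataLoader(train_ds, batch_size=4, shuffle=True)
#     loader.show_batch(image_key='image', label_key='label')
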
# TODO
## Work with 3 dataloaders
def segmentation_dataloaders(config: dict,
                             train: Union[bool, None] = None,
                             valid: Union[bool, None] = None,
                             test: Union[bool, None] = None,
                             ):
    """Create segmentation dataloaders
    Args:
        config: config file
        train: whether to return a train DataLoader
        valid: whether to return a valid DataLoader
        test: whether to return a test DataLoader
    Args from config:
        data_dir: base directory for the data
        train_csv, valid_csv, test_csv: paths to csv files containing filenames and paths
        image_cols: columns in csv containing paths to images
        label_cols: columns in csv containing paths to label files
        dataset_type: PersistentDataset, CacheDataset and Dataset are supported
        cache_dir: cache directory to be used by PersistentDataset
        batch_size: batch size for training. Valid and test are always 1
        debug: run with a reduced number of images
    Returns:
        list of:
            train_loader: DataLoader (optional, if train == True)
            valid_loader: DataLoader (optional, if valid == True)
            test_loader: DataLoader (optional, if test == True)
    """

    ## parse needed arguments from config
    if train is None: train = config.data.train
    if valid is None: valid = config.data.valid
    if test is None: test = config.data.test

    data_dir = config.data.data_dir
    train_csv = config.data.train_csv
    valid_csv = config.data.valid_csv
    test_csv = config.data.test_csv
    image_cols = config.data.image_cols
    label_cols = config.data.label_cols
    dataset_type = config.data.dataset_type
    cache_dir = config.data.cache_dir
    batch_size = config.data.batch_size
    debug = config.debug

    ## ---------- data dicts ----------
    # first a global data dict, containing only the filepaths from image_cols and label_cols, is created.
    # For this, the dataframe is reduced to only the relevant columns. Then the rows are iterated,
    # converting each row into an individual dict, as expected by monai

    if not isinstance(image_cols, (tuple, list)): image_cols = [image_cols]
    if not isinstance(label_cols, (tuple, list)): label_cols = [label_cols]

    train_df = pd.read_csv(train_csv)
    valid_df = pd.read_csv(valid_csv)
    test_df = pd.read_csv(test_csv)

    if debug:
        train_df = train_df.sample(25)
        valid_df = valid_df.sample(5)

    train_df['split'] = 'train'
    valid_df['split'] = 'valid'
    test_df['split'] = 'test'

    whole_df = []
    if train: whole_df += [train_df]
    if valid: whole_df += [valid_df]
    if test: whole_df += [test_df]
    df = pd.concat(whole_df)

    cols = image_cols + label_cols
    for col in cols:
        # create absolute file names from the relative fn in df and data_dir
        df[col] = [os.path.join(data_dir, fn) for fn in df[col]]
        if not os.path.exists(list(df[col])[0]):
            raise FileNotFoundError(list(df[col])[0])

    data_dict = [dict(row[1]) for row in df[cols].iterrows()]
    # data_dict is not the correct name, list_of_data_dicts would be more accurate, but also longer.
    # The data_dict looks like this:
    # [
    #  {'image_col_1': 'data_dir/path/to/image1',
    #   'image_col_2': 'data_dir/path/to/image2',
    #   'label_col_1': 'data_dir/path/to/label1'},
    #  {'image_col_1': 'data_dir/path/to/image1',
    #   'image_col_2': 'data_dir/path/to/image2',
    #   'label_col_1': 'data_dir/path/to/label1'},
    #  ...]
    # Filenames should now be absolute or relative to the working directory

    # now we create separate data dicts for train, valid and test data respectively
    assert train or test or valid, 'No dataset split selected (train, valid or test)'
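    # np.where returns the integer indices of rows belonging to a split; mapping them
    # through data_dict.__getitem__ picks the corresponding dicts out of the global list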
    if test:
        test_files = list(map(data_dict.__getitem__, *np.where(df.split == 'test')))
    if valid:
        val_files = list(map(data_dict.__getitem__, *np.where(df.split == 'valid')))
    if train:
        train_files = list(map(data_dict.__getitem__, *np.where(df.split == 'train')))

    # transforms are specified in transforms.py and are just loaded here
    if train: train_transforms = transforms.get_train_transforms(config)
    if valid: val_transforms = transforms.get_val_transforms(config)
    if test: test_transforms = transforms.get_test_transforms(config)

    ## ---------- construct dataloaders ----------
    Dataset = import_dataset(config)
    data_loaders = []

    if train:
        train_ds = Dataset(
            data=train_files,
            transform=train_transforms
        )
        train_loader = DataLoader(
            train_ds,
            batch_size=batch_size,
            num_workers=num_workers(),
            shuffle=True
        )
        data_loaders.append(train_loader)

    if valid:
        val_ds = Dataset(
            data=val_files,
            transform=val_transforms
        )
        val_loader = DataLoader(
            val_ds,
            batch_size=1,
            num_workers=num_workers(),
            shuffle=False
        )
        data_loaders.append(val_loader)

    if test:
        test_ds = Dataset(
            data=test_files,
            transform=test_transforms
        )
        test_loader = DataLoader(
            test_ds,
            batch_size=1,
            num_workers=num_workers(),
            shuffle=False
        )
        data_loaders.append(test_loader)

    # if only one dataloader is constructed, return just that dataloader, else return a
    # named tuple of dataloaders, so it is clear which DataLoader is train, valid or test
    if len(data_loaders) == 1:
        return data_loaders[0]
    else:
        DataLoaders = namedtuple(
            'DataLoaders',
            # create a str specifying the loader types: if train and test are True but
            # valid is False, the string will be 'train  test' (namedtuple ignores the extra whitespace)
            ' '.join(
                [
                    'train' if train else '',
                    'valid' if valid else '',
                    'test' if test else ''
                ]
            ).strip()
        )
        return DataLoaders(*data_loaders)
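

# Usage sketch (hypothetical; assumes a config object with attribute access, e.g. built
# from a YAML/OmegaConf file, providing the keys parsed above):
#
#     dls = segmentation_dataloaders(config, train=True, valid=True, test=False)
#     train_loader, valid_loader = dls.train, dls.valid
#     train_loader.show_batch()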