|
from ..datasets.modis_dataset import MODISDataset |
|
from ..datasets.modis_lc_five_dataset import MODISLCFiveDataset |
|
from ..datasets.modis_lc_nine_dataset import MODISLCNineDataset |
|
|
|
from ..transforms import TensorResizeTransform |
|
|
|
import torch.distributed as dist |
|
from torch.utils.data import DataLoader, DistributedSampler |
|
|
|
|
|
# Registry of available fine-tuning datasets, keyed by the lowercase name
# expected in config.DATA.DATASET. Resolved via get_dataset_from_dict().
DATASETS = {

    'modis': MODISDataset,

    'modislc9': MODISLCNineDataset,  # 9-class MODIS land-cover variant

    'modislc5': MODISLCFiveDataset,  # 5-class MODIS land-cover variant



}
|
|
|
|
|
def get_dataset_from_dict(dataset_name: str):
    """Gets the proper dataset class given a dataset name.

    The lookup is case-insensitive: the name is lowercased before being
    matched against the DATASETS registry.

    Args:
        dataset_name (str): name of the dataset

    Raises:
        KeyError: thrown if dataset key is not present in DATASETS

    Returns:
        dataset: pytorch dataset class registered under ``dataset_name``
    """

    dataset_name = dataset_name.lower()

    try:

        dataset_to_use = DATASETS[dataset_name]

    except KeyError:

        # Join the key names explicitly; interpolating DATASETS.keys()
        # directly would print the raw ``dict_keys([...])`` repr.
        available = ', '.join(sorted(DATASETS))

        # ``from None`` suppresses the redundant chained KeyError, so the
        # user sees only the informative message.
        raise KeyError(
            f"{dataset_name} is not an existing dataset."
            f" Available datasets: {available}") from None

    return dataset_to_use
|
|
|
|
|
def _build_split_dataloader(dataset, config, *, shuffle, drop_last):
    """Builds a DistributedSampler + DataLoader pair for a single split.

    Args:
        dataset: pytorch dataset for this split
        config: config object (reads DATA.BATCH_SIZE, DATA.NUM_WORKERS)
        shuffle (bool): whether the sampler shuffles (True for train)
        drop_last (bool): whether the loader drops the last partial batch

    Returns:
        DataLoader: distributed-aware dataloader for the split
    """

    # Each rank sees a disjoint shard of the dataset.
    sampler = DistributedSampler(
        dataset,
        num_replicas=dist.get_world_size(),
        rank=dist.get_rank(),
        shuffle=shuffle)

    return DataLoader(dataset,
                      batch_size=config.DATA.BATCH_SIZE,
                      sampler=sampler,
                      num_workers=config.DATA.NUM_WORKERS,
                      pin_memory=True,
                      drop_last=drop_last)


def build_finetune_dataloaders(config, logger):
    """Builds the dataloaders and datasets for a fine-tuning task.

    Requires torch.distributed to be initialized, since samplers are
    sharded by rank/world size.

    Args:
        config: config object
        logger: logging logger

    Returns:
        dataloader_train: training dataloader
        dataloader_val: validation dataloader
    """

    # Same transform is applied to both splits.
    transform = TensorResizeTransform(config)

    logger.info(f'Finetuning data transform:\n{transform}')

    dataset_name = config.DATA.DATASET

    logger.info(f'Dataset: {dataset_name}')

    logger.info(f'Data Paths: {config.DATA.DATA_PATHS}')

    dataset_to_use = get_dataset_from_dict(dataset_name)

    logger.info(f'Dataset obj: {dataset_to_use}')

    dataset_train = dataset_to_use(data_paths=config.DATA.DATA_PATHS,
                                   split="train",
                                   img_size=config.DATA.IMG_SIZE,
                                   transform=transform)

    dataset_val = dataset_to_use(data_paths=config.DATA.DATA_PATHS,
                                 split="val",
                                 img_size=config.DATA.IMG_SIZE,
                                 transform=transform)

    logger.info(f'Build dataset: train images = {len(dataset_train)}')

    logger.info(f'Build dataset: val images = {len(dataset_val)}')

    # Train: shuffle each epoch and drop the ragged final batch so every
    # step sees a full batch. Val: keep order and evaluate every sample.
    dataloader_train = _build_split_dataloader(
        dataset_train, config, shuffle=True, drop_last=True)

    dataloader_val = _build_split_dataloader(
        dataset_val, config, shuffle=False, drop_last=False)

    return dataloader_train, dataloader_val
|
|