import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler

from ..datasets.modis_dataset import MODISDataset
from ..datasets.modis_lc_five_dataset import MODISLCFiveDataset
from ..datasets.modis_lc_nine_dataset import MODISLCNineDataset
from ..transforms import TensorResizeTransform

DATASETS = {
'modis': MODISDataset,
'modislc9': MODISLCNineDataset,
'modislc5': MODISLCFiveDataset,
# 'modis tree': MODISTree,
}
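

# Example (hypothetical): adding a new dataset is a one-line change to the
# DATASETS registry, e.g. DATASETS['modistree'] = MODISTree, assuming the
# class follows the same constructor signature (data_paths, split, img_size,
# transform) as the datasets already listed.
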
def get_dataset_from_dict(dataset_name: str):
    """Gets the proper dataset class given a dataset name.

    Args:
        dataset_name (str): name of the dataset

    Raises:
        KeyError: raised if the dataset key is not present in DATASETS

    Returns:
        dataset: the PyTorch dataset class to instantiate
    """
dataset_name = dataset_name.lower()
try:
dataset_to_use = DATASETS[dataset_name]
    except KeyError:
        error_msg = (f"'{dataset_name}' is not a known dataset. "
                     f"Available datasets: {list(DATASETS.keys())}")
        raise KeyError(error_msg)
return dataset_to_use
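

# Minimal usage sketch (hypothetical arguments; the lookup is
# case-insensitive because of the .lower() call above):
#   dataset_cls = get_dataset_from_dict('MODIS')
#   dataset = dataset_cls(data_paths=['/path/to/modis'], split='train',
#                         img_size=192, transform=None)
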
def build_finetune_dataloaders(config, logger):
    """Builds the datasets and dataloaders for a fine-tuning task.

    Args:
        config: config object
        logger: logging.Logger instance

    Returns:
        dataloader_train: training dataloader
        dataloader_val: validation dataloader
    """
transform = TensorResizeTransform(config)
logger.info(f'Finetuning data transform:\n{transform}')
dataset_name = config.DATA.DATASET
logger.info(f'Dataset: {dataset_name}')
logger.info(f'Data Paths: {config.DATA.DATA_PATHS}')
dataset_to_use = get_dataset_from_dict(dataset_name)
logger.info(f'Dataset obj: {dataset_to_use}')
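    # Instantiate the train and val splits with the same resize transform.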
dataset_train = dataset_to_use(data_paths=config.DATA.DATA_PATHS,
split="train",
img_size=config.DATA.IMG_SIZE,
transform=transform)
dataset_val = dataset_to_use(data_paths=config.DATA.DATA_PATHS,
split="val",
img_size=config.DATA.IMG_SIZE,
transform=transform)
    logger.info(f'Built dataset: train images = {len(dataset_train)}')
    logger.info(f'Built dataset: val images = {len(dataset_val)}')
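    # Shard each dataset across ranks so every process draws from a
    # disjoint subset; only the training split is shuffled.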
sampler_train = DistributedSampler(
dataset_train,
num_replicas=dist.get_world_size(),
rank=dist.get_rank(),
shuffle=True)
sampler_val = DistributedSampler(
dataset_val,
num_replicas=dist.get_world_size(),
rank=dist.get_rank(),
shuffle=False)
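    # drop_last=True keeps every training batch the same size on each rank;
    # the validation loader keeps all samples (drop_last=False).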
    dataloader_train = DataLoader(dataset_train,
                                  batch_size=config.DATA.BATCH_SIZE,
                                  sampler=sampler_train,
                                  num_workers=config.DATA.NUM_WORKERS,
                                  pin_memory=True,
                                  drop_last=True)
    dataloader_val = DataLoader(dataset_val,
                                batch_size=config.DATA.BATCH_SIZE,
                                sampler=sampler_val,
                                num_workers=config.DATA.NUM_WORKERS,
                                pin_memory=True,
                                drop_last=False)
return dataloader_train, dataloader_val
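

# Minimal usage sketch (hypothetical config/logger objects; assumes the
# process group is already initialized, e.g. when launched via torchrun):
#
#   import logging
#   dist.init_process_group(backend='nccl')
#   logger = logging.getLogger(__name__)
#   train_loader, val_loader = build_finetune_dataloaders(config, logger)
#   for epoch in range(epochs):
#       # set_epoch() reseeds the DistributedSampler so each epoch
#       # produces a different shuffle on every rank.
#       train_loader.sampler.set_epoch(epoch)
#       for batch in train_loader:
#           ...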