# File size: 1,486 Bytes
# commit 7cdf421
from header import *
from .samplers import DistributedBatchSampler, DistributedMultiDatasetBatchSampler
from .catalog import DatasetCatalog
from .utils import instantiate_from_config
import torch
from torch.utils.data import ConcatDataset
from .concat_dataset import MyConcatDataset
def load_dataset(args, dataset_name_list):
    """Build the concatenated training dataset and its distributed loader.

    Args:
        args: config mapping; must provide ``args['world_size']`` and the
            DeepSpeed config at ``args['dschf'].config`` (reads
            ``train_micro_batch_size_per_gpu``).
        dataset_name_list: List[str] of dataset names to concatenate.

    Returns:
        Tuple of (concatenated dataset, DataLoader, RandomSampler).
    """
    dataset = MyConcatDataset(dataset_name_list)

    # Distributed context for the batch sampler.
    # NOTE(review): world size is taken from torch.distributed here, but the
    # batch size below uses args['world_size'] — confirm the two always agree.
    num_replicas = torch.distributed.get_world_size()
    global_rank = torch.distributed.get_rank()

    # Global batch size = per-GPU micro batch * number of GPUs.
    per_gpu_batch = args['dschf'].config['train_micro_batch_size_per_gpu']
    global_batch_size = args['world_size'] * per_gpu_batch

    base_sampler = torch.utils.data.RandomSampler(dataset)
    dist_batch_sampler = DistributedMultiDatasetBatchSampler(
        dataset=dataset,
        sampler=base_sampler,
        batch_size=global_batch_size,
        drop_last=True,
        rank=global_rank,
        world_size=num_replicas,
    )

    loader = DataLoader(
        dataset,
        batch_sampler=dist_batch_sampler,
        num_workers=1,
        collate_fn=dataset.collate,
        pin_memory=True,
    )
    return dataset, loader, base_sampler
|