File size: 1,157 Bytes
7cdf421 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
from torch.utils.data import ConcatDataset, Dataset
from .catalog import DatasetCatalog
from .utils import instantiate_from_config
class MyConcatDataset(Dataset):
    """Concatenation of datasets looked up by name in ``DatasetCatalog``.

    Each name is resolved to a ``{'target': ..., 'params': ...}`` entry on the
    catalog, instantiated via ``instantiate_from_config``, and the results are
    wrapped in a single ``torch.utils.data.ConcatDataset``.
    """

    def __init__(self, dataset_name_list):
        """Instantiate and concatenate the catalog datasets named in *dataset_name_list*.

        Args:
            dataset_name_list: iterable of attribute names on ``DatasetCatalog``.
                Each attribute must be subscriptable with ``'target'`` and
                ``'params'`` keys.

        Raises:
            AttributeError: if a name is not an attribute of the catalog.
            ValueError: if *dataset_name_list* is empty.
        """
        import logging  # local import keeps the module's import block unchanged

        super().__init__()
        logger = logging.getLogger(__name__)
        catalog = DatasetCatalog()
        datasets = []
        for dataset_name in dataset_name_list:
            try:
                dataset_dict = getattr(catalog, dataset_name)
            except AttributeError as err:
                raise AttributeError(
                    f"Unknown dataset {dataset_name!r}: not defined in DatasetCatalog"
                ) from err
            target = dataset_dict['target']
            params = dataset_dict['params']
            # Debug-level logging replaces unconditional print() diagnostics.
            logger.debug("Instantiating dataset %r: target=%s params=%s",
                         dataset_name, target, params)
            datasets.append(instantiate_from_config(dict(target=target, params=params)))
        if not datasets:
            # ConcatDataset only guards this with an assert, which -O strips.
            raise ValueError("dataset_name_list must name at least one dataset")
        self.datasets = ConcatDataset(datasets)

    def __len__(self):
        """Return the total number of samples across all concatenated datasets."""
        return len(self.datasets)

    def __getitem__(self, item):
        """Return sample *item* by delegating to the underlying ConcatDataset."""
        return self.datasets[item]

    def collate(self, instances):
        """Merge a list of sample dicts into one dict of per-key lists.

        The output key order follows the first instance, and each list holds
        values in the order of *instances*. An empty input yields ``{}``.

        Raises:
            KeyError: if any instance's key set differs from the first
                instance's (otherwise the per-key lists would be misaligned).
        """
        if not instances:
            return {}
        expected_keys = set(instances[0].keys())
        data = {key: [] for key in instances[0].keys()}
        for index, instance in enumerate(instances):
            actual_keys = set(instance.keys())
            if actual_keys != expected_keys:
                raise KeyError(
                    f"collate: instance {index} keys differ from instance 0 "
                    f"(missing={expected_keys - actual_keys}, "
                    f"extra={actual_keys - expected_keys})"
                )
            for key, value in instance.items():
                data[key].append(value)
        return data
|