|
from torch.utils.data import ConcatDataset, Dataset |
|
from .catalog import DatasetCatalog |
|
from .utils import instantiate_from_config |
|
|
|
|
|
class MyConcatDataset(Dataset):
    """Concatenation of several datasets looked up by name in ``DatasetCatalog``.

    Each name in ``dataset_name_list`` is read as an attribute of a fresh
    ``DatasetCatalog`` instance. The resulting entry is indexed with the keys
    ``'target'`` and ``'params'``, which are forwarded to
    ``instantiate_from_config`` to build the dataset. The built datasets are
    wrapped in a single ``ConcatDataset`` stored on ``self.datasets``.
    """

    def __init__(self, dataset_name_list):
        """Build every named dataset and concatenate them in the given order.

        Args:
            dataset_name_list: Iterable of attribute names present on
                ``DatasetCatalog``. A name missing from the catalog raises
                ``AttributeError`` from the ``getattr`` lookup.
        """
        super().__init__()

        catalog = DatasetCatalog()
        built = []

        for dataset_name in dataset_name_list:
            dataset_dict = getattr(catalog, dataset_name)

            target = dataset_dict['target']
            params = dataset_dict['params']

            # Echo what is about to be instantiated (kept from the original
            # implementation; this output goes to stdout).
            print(target)
            print(params)

            built.append(
                instantiate_from_config(dict(target=target, params=params))
            )

        self.datasets = ConcatDataset(built)

    def __len__(self) -> int:
        """Return the length reported by the wrapped ``ConcatDataset``."""
        return len(self.datasets)

    def __getitem__(self, item):
        """Return the sample at index ``item`` of the wrapped ``ConcatDataset``."""
        return self.datasets[item]

    def collate(self, instances):
        """Regroup a sequence of mapping samples into a dict of lists.

        Args:
            instances: Sequence of mappings (one per sample).

        Returns:
            A dict mapping each key to the list of values collected for it, in
            sample order. An empty ``instances`` yields ``{}``. Keys appear in
            first-seen order. Lists are not padded: a key absent from some
            samples simply ends up with a shorter list.
        """
        data = {}

        for instance in instances:
            for key, value in instance.items():
                # setdefault lets a key that first shows up in a later sample
                # start its own list instead of raising KeyError.
                data.setdefault(key, []).append(value)

        return data
|
|