Spaces:
Running
Running
""" Compose multiple datasets in a single loader. """ | |
import numpy as np | |
from copy import deepcopy | |
from torch.utils.data import Dataset | |
from .wireframe_dataset import WireframeDataset | |
from .holicity_dataset import HolicityDataset | |
class MergeDataset(Dataset):
    """Compose several datasets into one randomly-sampled dataset.

    Sub-datasets are built from a shared ``config`` whose per-dataset entries
    are parallel lists indexed like ``config['datasets']``:
    ``gt_source_train``, ``gt_source_test``, ``weights`` and (read only for
    ``"holicity"`` entries) ``train_splits``.

    Indexing ignores the requested index: every access picks a sub-dataset
    with probability proportional to its weight, then returns a uniformly
    random sample from that sub-dataset.
    """

    def __init__(self, mode, config=None):
        """Build every sub-dataset listed in ``config['datasets']``.

        Args:
            mode: Forwarded unchanged to each sub-dataset constructor.
            config: Dict with the keys described in the class docstring.
                Required despite the ``None`` default.

        Raises:
            TypeError: If ``config`` is None.
            ValueError: On an unknown dataset name, or when ``weights`` does
                not provide one finite, non-negative weight per dataset with
                a positive sum.
        """
        super().__init__()
        if config is None:
            raise TypeError("MergeDataset requires a config dict")

        self._datasets = []
        for i, name in enumerate(config['datasets']):
            # Copy the config per dataset: a sub-dataset may keep a reference
            # to the dict it receives, so reusing one mutable copy would let
            # later iterations overwrite earlier settings (and leak
            # holicity-only keys such as 'train_split' into other entries).
            spec_config = deepcopy(config)
            spec_config['dataset_name'] = name
            spec_config['gt_source_train'] = config['gt_source_train'][i]
            spec_config['gt_source_test'] = config['gt_source_test'][i]
            if name == "wireframe":
                self._datasets.append(WireframeDataset(mode, spec_config))
            elif name == "holicity":
                spec_config['train_split'] = config['train_splits'][i]
                self._datasets.append(HolicityDataset(mode, spec_config))
            else:
                raise ValueError("Unknown dataset: " + name)

        # Stored as normalized probabilities so relative weights (e.g. [1, 3])
        # are accepted and bad configs fail here rather than at sampling time.
        self._weights = self._normalize_weights(
            config['weights'], len(self._datasets))

    @staticmethod
    def _normalize_weights(weights, num_datasets):
        """Return ``weights`` as a float64 probability vector summing to 1.

        Raises:
            ValueError: If there is not exactly one weight per dataset, or any
                weight is negative or non-finite, or all weights are zero.
        """
        probs = np.asarray(weights, dtype=np.float64)
        if probs.ndim != 1 or probs.shape[0] != num_datasets:
            raise ValueError(
                "Expected %d weights, got %r" % (num_datasets, weights))
        if not np.all(np.isfinite(probs)) or np.any(probs < 0):
            raise ValueError("Weights must be finite and non-negative")
        total = probs.sum()
        if total <= 0:
            raise ValueError("Weights must have a positive sum")
        return probs / total

    def __getitem__(self, item):
        """Return a random sample; ``item`` is ignored.

        A sub-dataset is chosen according to the normalized weights, then a
        uniformly random index within that sub-dataset is returned.
        """
        # NOTE(review): this uses numpy's global RNG; with multiple DataLoader
        # workers, confirm numpy is reseeded per worker (e.g. worker_init_fn),
        # otherwise workers may produce identical draws.
        dataset = self._datasets[
            np.random.choice(len(self._datasets), p=self._weights)]
        return dataset[np.random.randint(len(dataset))]

    def __len__(self):
        """Return the total number of samples across all sub-datasets."""
        # Built-in sum keeps the result an int; np.sum([]) returns 0.0, which
        # len() rejects with a TypeError.
        return sum(len(d) for d in self._datasets)