import copy

from mmcv.utils import build_from_cfg
from mmpose.datasets.builder import DATASETS
from mmpose.datasets.dataset_wrappers import RepeatDataset
from torch.utils.data.dataset import ConcatDataset


def _concat_cfg(cfg):
    """Split a multi-dataset config into one config per dataset.

    ``cfg['type']``, ``cfg['ann_file']``, ``cfg['img_prefix']``,
    ``cfg['data_cfg']['num_joints']`` and
    ``cfg['data_cfg']['dataset_channel']`` are treated as parallel
    sequences. The i-th returned config is a deep copy of ``cfg`` in which
    each of those entries is replaced by its i-th element.

    Args:
        cfg (dict): Config whose ``type`` entry is a list or tuple.

    Returns:
        list: One config per element of ``cfg['type']``.

    Raises:
        AssertionError: If one of the per-dataset keys is missing or its
            length differs from ``len(cfg['type'])``.
    """
    top_level_keys = ['ann_file', 'img_prefix']
    data_cfg_keys = ['num_joints', 'dataset_channel']
    num_datasets = len(cfg['type'])

    split_cfgs = []
    for i, dataset_type in enumerate(cfg['type']):
        # copy.deepcopy works for plain dicts as well as config objects that
        # expose their own ``deepcopy()`` method, so both kinds are accepted.
        single_cfg = copy.deepcopy(cfg)
        single_cfg['type'] = dataset_type

        for key in top_level_keys:
            assert key in single_cfg
            assert num_datasets == len(cfg[key]), (cfg[key])
            single_cfg[key] = cfg[key][i]

        for key in data_cfg_keys:
            assert key in single_cfg['data_cfg']
            assert num_datasets == len(cfg['data_cfg'][key])
            single_cfg['data_cfg'][key] = cfg['data_cfg'][key][i]

        split_cfgs.append(single_cfg)
    return split_cfgs


def _check_vaild(cfg):
    """Collapse list-valued channel settings to their first element.

    If ``cfg['data_cfg']['num_joints']`` is a list or tuple, both
    ``num_joints`` and ``dataset_channel`` in ``cfg['data_cfg']`` are
    replaced by their first element. Otherwise ``cfg`` is left untouched.

    Note:
        ``cfg`` is modified in place; the same object is returned.

    Args:
        cfg (dict): Single-dataset config containing a ``data_cfg`` entry.

    Returns:
        dict: The (possibly modified) input ``cfg``.
    """
    channel_keys = ['num_joints', 'dataset_channel']
    data_cfg = cfg['data_cfg']
    # Only the first key is inspected; the second is assumed to have the
    # same list/scalar form.
    if isinstance(data_cfg[channel_keys[0]], (list, tuple)):
        for key in channel_keys:
            data_cfg[key] = data_cfg[key][0]
    return cfg


def build_dataset(cfg, default_args=None):
    """Build a dataset from config dict.

    Three config shapes are handled:

    * ``cfg['type']`` is a list or tuple: the config is split into one
      config per entry, each is built recursively, and the results are
      wrapped in a ``ConcatDataset``.
    * ``cfg['type'] == 'RepeatDataset'``: ``cfg['dataset']`` is built
      recursively and wrapped in a ``RepeatDataset`` repeated
      ``cfg['times']`` times.
    * Anything else: list-valued channel settings in ``cfg['data_cfg']``
      are collapsed to their first element (modifying ``cfg`` in place)
      and the dataset is built from the ``DATASETS`` registry.

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        default_args (dict, optional): Default initialization arguments.
            Default: None.

    Returns:
        Dataset: The constructed dataset.
    """
    if isinstance(cfg['type'], (list, tuple)):
        # In training, type=TransformerPoseDataset
        dataset = ConcatDataset(
            [build_dataset(c, default_args) for c in _concat_cfg(cfg)])
    elif cfg['type'] == 'RepeatDataset':
        dataset = RepeatDataset(
            build_dataset(cfg['dataset'], default_args), cfg['times'])
    else:
        cfg = _check_vaild(cfg)
        dataset = build_from_cfg(cfg, DATASETS, default_args)
    return dataset