Spaces:
Running
Running
File size: 1,918 Bytes
241adf2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
import copy

from mmcv.utils import build_from_cfg
from mmpose.datasets.builder import DATASETS
from mmpose.datasets.dataset_wrappers import RepeatDataset
from torch.utils.data.dataset import ConcatDataset
def _concat_cfg(cfg):
replace = ['ann_file', 'img_prefix']
channels = ['num_joints', 'dataset_channel']
concat_cfg = []
for i in range(len(cfg['type'])):
cfg_tmp = cfg.deepcopy()
cfg_tmp['type'] = cfg['type'][i]
for item in replace:
assert item in cfg_tmp
assert len(cfg['type']) == len(cfg[item]), (cfg[item])
cfg_tmp[item] = cfg[item][i]
for item in channels:
assert item in cfg_tmp['data_cfg']
assert len(cfg['type']) == len(cfg['data_cfg'][item])
cfg_tmp['data_cfg'][item] = cfg['data_cfg'][item][i]
concat_cfg.append(cfg_tmp)
return concat_cfg
def _check_vaild(cfg):
replace = ['num_joints', 'dataset_channel']
if isinstance(cfg['data_cfg'][replace[0]], (list, tuple)):
for item in replace:
cfg['data_cfg'][item] = cfg['data_cfg'][item][0]
return cfg
def build_dataset(cfg, default_args=None):
    """Build a dataset from config dict.

    Three kinds of config are handled:

    * ``cfg['type']`` is a list/tuple: the config is split into one config
      per type by ``_concat_cfg``. Each is built recursively, and the results
      are wrapped in a ``torch.utils.data.ConcatDataset``.
    * ``cfg['type'] == 'RepeatDataset'``: ``cfg['dataset']`` is built
      recursively and wrapped in ``RepeatDataset`` with ``cfg['times']``.
    * Anything else: ``_check_vaild`` first collapses list-valued
      ``data_cfg.num_joints`` / ``data_cfg.dataset_channel`` to their first
      entry. This modifies ``cfg`` in place. The config is then built from
      the ``DATASETS`` registry.

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        default_args (dict, optional): Default initialization arguments.
            Default: None.

    Returns:
        Dataset: The constructed dataset.
    """
    dataset_type = cfg['type']

    # Multi-dataset config. The original author's note says this is the
    # training path (type=TransformerPoseDataset).
    if isinstance(dataset_type, (list, tuple)):
        sub_datasets = [
            build_dataset(sub_cfg, default_args)
            for sub_cfg in _concat_cfg(cfg)
        ]
        return ConcatDataset(sub_datasets)

    if dataset_type == 'RepeatDataset':
        inner_dataset = build_dataset(cfg['dataset'], default_args)
        return RepeatDataset(inner_dataset, cfg['times'])

    return build_from_cfg(_check_vaild(cfg), DATASETS, default_args)
|