File size: 1,918 Bytes
241adf2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import copy

from mmcv.utils import build_from_cfg
from mmpose.datasets.builder import DATASETS
from mmpose.datasets.dataset_wrappers import RepeatDataset
from torch.utils.data.dataset import ConcatDataset


def _concat_cfg(cfg):
    replace = ['ann_file', 'img_prefix']
    channels = ['num_joints', 'dataset_channel']
    concat_cfg = []
    for i in range(len(cfg['type'])):
        cfg_tmp = cfg.deepcopy()
        cfg_tmp['type'] = cfg['type'][i]
        for item in replace:
            assert item in cfg_tmp
            assert len(cfg['type']) == len(cfg[item]), (cfg[item])
            cfg_tmp[item] = cfg[item][i]
        for item in channels:
            assert item in cfg_tmp['data_cfg']
            assert len(cfg['type']) == len(cfg['data_cfg'][item])
            cfg_tmp['data_cfg'][item] = cfg['data_cfg'][item][i]
        concat_cfg.append(cfg_tmp)
    return concat_cfg


def _check_vaild(cfg):
    replace = ['num_joints', 'dataset_channel']
    if isinstance(cfg['data_cfg'][replace[0]], (list, tuple)):
        for item in replace:
            cfg['data_cfg'][item] = cfg['data_cfg'][item][0]
    return cfg


def build_dataset(cfg, default_args=None):
    """Construct a dataset (possibly nested) from a config dict.

    Three forms of ``cfg['type']`` are handled:

    * a list or tuple: ``cfg`` is split with ``_concat_cfg`` and the
      resulting datasets are wrapped in a ``ConcatDataset``;
    * the string ``'RepeatDataset'``: ``cfg['dataset']`` is built
      recursively and wrapped in a ``RepeatDataset`` with ``cfg['times']``;
    * anything else: ``cfg`` is passed through ``_check_vaild`` and then
      built from the ``DATASETS`` registry.

    Args:
        cfg (dict): Config dict. It must contain the key "type".
        default_args (dict, optional): Default initialization arguments,
            forwarded to every nested build. Default: None.

    Returns:
        Dataset: The constructed dataset.
    """
    dataset_type = cfg['type']

    # NOTE(review): the previous inline comment tied this list/tuple form to
    # training with TransformerPoseDataset -- confirm against the configs.
    if isinstance(dataset_type, (list, tuple)):
        parts = []
        for part_cfg in _concat_cfg(cfg):
            parts.append(build_dataset(part_cfg, default_args))
        return ConcatDataset(parts)

    if dataset_type == 'RepeatDataset':
        inner = build_dataset(cfg['dataset'], default_args)
        return RepeatDataset(inner, cfg['times'])

    checked_cfg = _check_vaild(cfg)
    return build_from_cfg(checked_cfg, DATASETS, default_args)