File size: 4,616 Bytes
786f6a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import logging
from .constants import *


# Module-level logger; resolve_data_config() uses it to report the resolved values when verbose=True.
_logger = logging.getLogger(__name__)


def _resolve_norm_stat(args, pretrained_cfg, key, in_chans):
    """ Look up one per-channel normalization statistic ('mean' or 'std').

    A value in `args` takes priority over one in `pretrained_cfg`. An `args` value with a single
    element is repeated `in_chans` times; any other length must equal `in_chans`. A value taken
    from `pretrained_cfg` is returned as-is, without a length check.

    Returns:
        the resolved statistic, or None when neither source supplies one (caller applies its default)
    """
    if args.get(key, None) is not None:
        stat = tuple(args[key])
        if len(stat) == 1:
            stat = tuple(list(stat) * in_chans)
        else:
            assert len(stat) == in_chans, \
                "args['%s'] must have 1 or %s values, got %s." % (key, in_chans, len(stat))
        return stat
    if pretrained_cfg.get(key, None):
        return pretrained_cfg[key]
    return None


def resolve_data_config(
        args=None,
        pretrained_cfg=None,
        model=None,
        use_test_size=False,
        verbose=False
):
    """ Resolve Data Config
    Build the data pre-processing config by merging, in priority order, explicit `args`, a
    pretrained config, and built-in defaults.

    Args:
        args (dict): overrides in dict form. Keys read: 'in_chans' (or 'chans'), 'input_size', 'img_size',
            'interpolation', 'mean', 'std', 'crop_pct', 'crop_mode'. 'input_size' wins over 'img_size' and
            also overrides the channel count. 'interpolation', 'crop_pct' and 'crop_mode' are ignored when falsy.
        pretrained_cfg (dict): pretrained model config; when empty or None, `model.pretrained_cfg` is used instead
        model: optional object that may carry a `pretrained_cfg` attribute
        use_test_size (bool): prefer 'test_input_size' / 'test_crop_pct' from the pretrained config when present
        verbose (bool): log every resolved value at INFO level

    Returns:
        dict with keys 'input_size', 'interpolation', 'mean', 'std', 'crop_pct', 'crop_mode'

    Raises:
        AssertionError: if none of model / args / pretrained_cfg is supplied, or if an `args` value
            fails its type / length check.
    """
    # Identity check for the model rather than truthiness: an object that defines __len__ and is
    # empty evaluates as falsy, yet is still a valid source for a pretrained config.
    assert model is not None or args or pretrained_cfg, \
        "At least one of model, args, or pretrained_cfg required for data config."
    args = args or {}
    pretrained_cfg = pretrained_cfg or {}
    if not pretrained_cfg and model is not None:
        # The attribute may be missing or set to None; fall back to {} so the .get() lookups below stay valid.
        pretrained_cfg = getattr(model, 'pretrained_cfg', None) or {}
    data_config = {}

    # Resolve input/image size
    in_chans = 3
    if args.get('in_chans', None) is not None:
        in_chans = args['in_chans']
    elif args.get('chans', None) is not None:
        in_chans = args['chans']

    input_size = (in_chans, 224, 224)
    if args.get('input_size', None) is not None:
        assert isinstance(args['input_size'], (tuple, list)), "args['input_size'] must be a tuple or list."
        assert len(args['input_size']) == 3, "args['input_size'] must have 3 elements (channels first)."
        input_size = tuple(args['input_size'])
        in_chans = input_size[0]  # input_size overrides in_chans
    elif args.get('img_size', None) is not None:
        assert isinstance(args['img_size'], int), "args['img_size'] must be an int."
        input_size = (in_chans, args['img_size'], args['img_size'])
    elif use_test_size and pretrained_cfg.get('test_input_size', None) is not None:
        input_size = pretrained_cfg['test_input_size']
    elif pretrained_cfg.get('input_size', None) is not None:
        input_size = pretrained_cfg['input_size']
    data_config['input_size'] = input_size

    # resolve interpolation method
    if args.get('interpolation', None):
        interpolation = args['interpolation']
    elif pretrained_cfg.get('interpolation', None):
        interpolation = pretrained_cfg['interpolation']
    else:
        interpolation = 'bicubic'
    data_config['interpolation'] = interpolation

    # resolve dataset + model mean for normalization
    mean = _resolve_norm_stat(args, pretrained_cfg, 'mean', in_chans)
    data_config['mean'] = IMAGENET_DEFAULT_MEAN if mean is None else mean

    # resolve dataset + model std deviation for normalization
    std = _resolve_norm_stat(args, pretrained_cfg, 'std', in_chans)
    data_config['std'] = IMAGENET_DEFAULT_STD if std is None else std

    # resolve default inference crop percentage
    if args.get('crop_pct', None):
        crop_pct = args['crop_pct']
    elif use_test_size and pretrained_cfg.get('test_crop_pct', None):
        crop_pct = pretrained_cfg['test_crop_pct']
    elif pretrained_cfg.get('crop_pct', None):
        crop_pct = pretrained_cfg['crop_pct']
    else:
        crop_pct = DEFAULT_CROP_PCT
    data_config['crop_pct'] = crop_pct

    # resolve default crop mode
    if args.get('crop_mode', None):
        crop_mode = args['crop_mode']
    elif pretrained_cfg.get('crop_mode', None):
        crop_mode = pretrained_cfg['crop_mode']
    else:
        crop_mode = DEFAULT_CROP_MODE
    data_config['crop_mode'] = crop_mode

    if verbose:
        _logger.info('Data processing configuration for current model + dataset:')
        for n, v in data_config.items():
            # Lazy %-style args: formatting is skipped when INFO is disabled.
            _logger.info('\t%s: %s', n, v)

    return data_config


def resolve_model_data_config(
        model,
        args=None,
        pretrained_cfg=None,
        use_test_size=False,
        verbose=False,
):
    """ Resolve Model Data Config
    Model-first convenience wrapper: forwards every argument unchanged to resolve_data_config()
    and returns its result.

    Args:
        model (nn.Module): model instance, passed through as the `model` argument
        args (dict): command line arguments / configuration in dict form (takes priority over pretrained_cfg)
        pretrained_cfg (dict): pretrained model config (takes priority over the pretrained_cfg attached to model)
        use_test_size (bool): prefer the test time input resolution (if one exists) over the default one
        verbose (bool): log the resolved values

    Returns:
        dictionary of config
    """
    data_config = resolve_data_config(
        model=model,
        args=args,
        pretrained_cfg=pretrained_cfg,
        use_test_size=use_test_size,
        verbose=verbose,
    )
    return data_config