File size: 5,683 Bytes
78e32cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
###
# Author: Kai Li
# Date: 2021-06-20 00:36:46
# LastEditors: Please set LastEditors
# LastEditTime: 2024-01-22 03:02:57
###
import sys
import argparse
import importlib
from omegaconf import DictConfig

def prepare_parser_from_dict(dic, parser=None):
    """Build (or extend) an argparse parser from a config dictionary.

    One argument group is created per top-level key. What gets registered
    inside that group depends on the type of the top-level value:

    * ``dict``: one ``--<sub_key>`` option per entry, whose default is the
      entry's value.
    * ``list`` or ``str``: a single ``--<key>`` option whose default is the
      value itself.
    * any other type: the group is created but no option is added to it.

    Args:
        dic (dict): Config dictionary. Option names must be unique across
            the whole dictionary since they all land in the same parser.
        parser (argparse.ArgumentParser, optional): Existing parser to
            extend. A fresh parser is created when omitted.

    Returns:
        argparse.ArgumentParser:
            The parser, with one group per top-level key and options whose
            defaults come from the dictionary values.
    """

    def pick_converter(default):
        """Choose the argparse ``type`` callable matching a default value.

        ``None`` defaults are parsed with ``str_int_float``; booleans and
        boolean-looking strings with ``str2bool_arg``; everything else with
        the default's own type.
        """
        if default is None:
            return str_int_float
        if isinstance(str2bool(default), bool):
            return str2bool_arg
        return type(default)

    def register(group, name, default):
        """Add ``--name`` to ``group`` with the given default value."""
        converter = pick_converter(default)
        group.add_argument("--" + name, default=default, type=converter)

    parser = argparse.ArgumentParser() if parser is None else parser
    for section, content in dic.items():
        group = parser.add_argument_group(section)
        if isinstance(content, dict):
            for name, default in content.items():
                register(group, name, default)
        elif isinstance(content, (list, str)):
            register(group, section, content)
    return parser


def str_int_float(value):
    """Argparse type that converts to int, then float, when possible.

    Args:
        value (str): Value to convert.

    Returns:
        int, float, str: ``int(value)`` if ``isint`` accepts it, otherwise
        ``float(value)`` if ``isfloat`` accepts it, otherwise ``value``
        unchanged when it is a string. ``None`` is returned for a
        non-string value that is neither.
    """
    # Try the numeric conversions in order of strictness.
    for accepts, convert in ((isint, int), (isfloat, float)):
        if accepts(value):
            return convert(value)
    return value if isinstance(value, str) else None


def str2bool(value):
    """Convert a boolean-looking string to ``bool``.

    Matching is case-insensitive. Anything that is not a string, or a
    string that is not one of the recognised spellings, is returned
    unchanged.
    """
    if not isinstance(value, str):
        return value
    spellings = {
        "yes": True,
        "true": True,
        "y": True,
        "1": True,
        "no": False,
        "false": False,
        "n": False,
        "0": False,
    }
    return spellings.get(value.lower(), value)


def str2bool_arg(value):
    """Argparse type that converts a string to ``bool``.

    Raises:
        argparse.ArgumentTypeError: If ``str2bool`` does not turn ``value``
            into a boolean.
    """
    converted = str2bool(value)
    if not isinstance(converted, bool):
        raise argparse.ArgumentTypeError("Boolean value expected.")
    return converted


def isfloat(value):
    """Computes whether `value` can be cast to a float.

    Args:
        value (str): Value to check.

    Returns:
        bool: Whether `value` can be cast to a float. Values that `float`
        rejects outright (e.g. ``None`` or a list) or that overflow yield
        ``False`` rather than raising.

    """
    try:
        float(value)
    except (TypeError, ValueError, OverflowError):
        # ValueError: unparsable string; TypeError: unsupported type such as
        # None; OverflowError: int too large to represent as a float.
        return False
    return True


def isint(value):
    """Computes whether `value` can be cast to an int

    Args:
        value (str): Value to check.

    Returns:
        bool: Whether `value` can be cast to an int. Values that `int`
        rejects outright (e.g. ``None`` or a list) or that cannot be
        represented (e.g. an infinite float) yield ``False`` rather than
        raising.

    """
    try:
        int(value)
    except (TypeError, ValueError, OverflowError):
        # ValueError: unparsable string or NaN; TypeError: unsupported type
        # such as None; OverflowError: infinite float.
        return False
    return True


def parse_args_as_dict(parser, return_plain_args=False, args=None):
    """Get a dict of dicts out of process `parser.parse_args()`

    Top-level keys corresponding to groups and bottom-level keys corresponding
    to arguments. Under `'main_args'`, the arguments which don't belong to a
    user-defined argparse group (i.e main arguments defined before parsing
    from a dict) can be found.

    The parser's two built-in groups are recognised by identity rather than
    by title, because the title of the built-in optional group differs
    between Python versions ("optional arguments" before 3.10, "options"
    from 3.10 on). Both built-in groups are folded into `'main_args'` and do
    not appear under their own titles.

    Args:
        parser (argparse.ArgumentParser): ArgumentParser instance containing
            groups. Output of `prepare_parser_from_dict`.
        return_plain_args (bool): Whether to return the output or
            `parser.parse_args()`.
        args (list): List of arguments as read from the command line.
            Used for unit testing.

    Returns:
        dict:
            Dictionary of dictionaries containing the arguments. Optionally the
            direct output `parser.parse_args()`.
    """
    args = parser.parse_args(args=args)
    # Groups argparse creates itself; everything else was added by the user.
    builtin_groups = (parser._positionals, parser._optionals)
    args_dic = {}
    main_args = {}
    for group in parser._action_groups:
        group_dict = {a.dest: getattr(args, a.dest, None) for a in group._group_actions}
        if any(group is builtin for builtin in builtin_groups):
            main_args.update(group_dict)
        else:
            args_dic[group.title] = group_dict
    # Assigned last so that it wins over a user group with the same title.
    args_dic["main_args"] = main_args
    if return_plain_args:
        return args_dic, args
    return args_dic

def instantiate(config, **kwargs):
    """Build objects described by ``'__target__'`` entries in a config.

    When ``config`` has a ``'__target__'`` key, its value is read as a dotted
    ``module.attribute`` path; the attribute is imported and called with the
    remaining entries as keyword arguments. Entries that are ``DictConfig``
    instances carrying their own ``'__target__'`` are instantiated first
    (without ``kwargs``). ``kwargs`` are then applied on top, overriding
    entries of the same name.

    When ``config`` has no ``'__target__'`` key, a plain ``dict`` is returned
    in which every ``DictConfig`` value has been processed recursively (with
    the same ``kwargs``) and every other value is kept as is.

    Args:
        config: Mapping supporting ``in`` and ``.items()``.
        **kwargs: Extra keyword arguments for the target call.

    Returns:
        The object returned by the target, or a ``dict`` when ``config`` has
        no ``'__target__'`` key.
    """
    if '__target__' not in config:
        # No target at this level: walk the values and recurse into
        # nested configs only.
        return {
            name: instantiate(entry, **kwargs) if isinstance(entry, DictConfig) else entry
            for name, entry in config.items()
        }

    module_name, attr_name = config['__target__'].rsplit('.', 1)
    target = getattr(importlib.import_module(module_name), attr_name)

    # Resolve nested target configs before calling the target itself.
    call_args = {
        name: instantiate(entry)
        if isinstance(entry, DictConfig) and '__target__' in entry
        else entry
        for name, entry in config.items()
        if name != '__target__'
    }
    # Explicit keyword arguments take precedence over config entries.
    call_args.update(kwargs)
    return target(**call_args)