Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,683 Bytes
78e32cc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 |
###
# Author: Kai Li
# Date: 2021-06-20 00:36:46
# LastEditors: Please set LastEditors
# LastEditTime: 2024-01-22 03:02:57
###
import sys
import argparse
import importlib
from omegaconf import DictConfig
def prepare_parser_from_dict(dic, parser=None):
    """Prepare an argparser from a dictionary.

    Every top-level key becomes an argument group. A mapping value adds one
    ``--<sub_key>`` argument per second-level key; any other value (string,
    number, boolean, None or list) adds a single ``--<key>`` argument to its
    own group.

    Args:
        dic (dict): Two-level config dictionary with unique bottom-level keys.
            Any ``collections.abc.Mapping`` is accepted (e.g. omegaconf
            ``DictConfig``).
        parser (argparse.ArgumentParser, optional): If a parser already
            exists, add the keys from the dictionary on the top of it.

    Returns:
        argparse.ArgumentParser:
            Parser instance with groups corresponding to the first level keys
            and arguments corresponding to the second level keys with default
            values given by the values.
    """
    from collections.abc import Mapping, Sequence

    def standardized_entry_type(value):
        """If the default value is None, replace NoneType by str_int_float.
        If the default value is boolean, look for boolean strings."""
        if value is None:
            return str_int_float
        if isinstance(str2bool(value), bool):
            return str2bool_arg
        return type(value)

    def add_entry(group, name, value):
        """Register ``--name`` on ``group`` with ``value`` as its default."""
        if isinstance(value, Sequence) and not isinstance(value, (str, bytes)):
            # ``type=list`` would split a single CLI string into characters.
            # Accept space-separated values instead, converted like the
            # default's elements when they all share one type.
            elem_types = {type(item) for item in value}
            if len(elem_types) == 1:
                elem_type = standardized_entry_type(value[0])
            else:
                elem_type = str_int_float
            group.add_argument("--" + name, default=value, type=elem_type, nargs="*")
        else:
            group.add_argument(
                "--" + name, default=value, type=standardized_entry_type(value)
            )

    if parser is None:
        parser = argparse.ArgumentParser()
    for key, value in dic.items():
        group = parser.add_argument_group(key)
        if isinstance(value, Mapping):
            for sub_key, sub_value in value.items():
                add_entry(group, sub_key, sub_value)
        else:
            # Scalars and lists at the top level get their own argument
            # instead of being silently skipped.
            add_entry(group, key, value)
    return parser
def str_int_float(value):
    """Type to convert strings to int, float (in this order) if possible.

    Args:
        value (str): Value to convert.

    Returns:
        int, float, str: Converted value. A string that is neither an int nor
        a float literal is returned unchanged, as is any non-string input.
    """
    if not isinstance(value, str):
        # Only strings are parsed: int(2.5) would silently truncate, and
        # unconvertible objects previously fell through and returned None.
        return value
    for convert in (int, float):
        try:
            return convert(value)
        except ValueError:
            continue
    return value
def str2bool(value):
    """Interpret common boolean spellings found in strings.

    Case-insensitively maps "yes"/"true"/"y"/"1" to True and
    "no"/"false"/"n"/"0" to False. Any other input, including every
    non-string, is handed back unchanged.
    """
    if isinstance(value, str):
        lowered = value.lower()
        if lowered in ("yes", "true", "y", "1"):
            return True
        if lowered in ("no", "false", "n", "0"):
            return False
    return value
def str2bool_arg(value):
    """Argparse ``type=`` callable that accepts only boolean spellings.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    parsed = str2bool(value)
    if not isinstance(parsed, bool):
        raise argparse.ArgumentTypeError("Boolean value expected.")
    return parsed
def isfloat(value):
    """Computes whether `value` can be cast to a float.

    Args:
        value: Value to check (typically a str).

    Returns:
        bool: Whether ``float(value)`` succeeds.
    """
    try:
        float(value)
    except (TypeError, ValueError, OverflowError):
        # TypeError: e.g. None or a list; OverflowError: ints too large for a float.
        return False
    return True
def isint(value):
    """Computes whether `value` can be cast to an int.

    Args:
        value: Value to check (typically a str).

    Returns:
        bool: Whether ``int(value)`` succeeds.
    """
    try:
        int(value)
    except (TypeError, ValueError, OverflowError):
        # TypeError: e.g. None or a list; OverflowError: float('inf').
        return False
    return True
def parse_args_as_dict(parser, return_plain_args=False, args=None):
    """Get a dict of dicts out of `parser.parse_args()`.

    Top-level keys correspond to groups and bottom-level keys correspond
    to arguments. Under `'main_args'`, the arguments which don't belong to an
    argparse group (i.e. main arguments defined before parsing from a dict)
    can be found.

    Args:
        parser (argparse.ArgumentParser): ArgumentParser instance containing
            groups. Output of `prepare_parser_from_dict`.
        return_plain_args (bool): Whether to also return the output of
            `parser.parse_args()`.
        args (list): List of arguments as read from the command line.
            Used for unit testing.

    Returns:
        dict:
            Dictionary of dictionaries containing the arguments. Optionally the
            direct output `parser.parse_args()`.
    """
    args = parser.parse_args(args=args)
    args_dic = {}
    for group in parser._action_groups:
        args_dic[group.title] = {
            action.dest: getattr(args, action.dest, None)
            for action in group._group_actions
        }
    # argparse titles its built-in optionals group "optional arguments" before
    # Python 3.10 and "options" from 3.10 on (and the title may be localised),
    # so take the title from the group object rather than hard-coding it.
    args_dic["main_args"] = args_dic.pop(parser._optionals.title)
    if return_plain_args:
        return args_dic, args
    return args_dic
def instantiate(config, **kwargs):
    """Recursively build objects from a config that uses ``__target__`` keys.

    If ``config`` has a ``__target__`` entry (a dotted path such as
    ``"package.module.ClassName"``), that attribute is imported and called
    with the remaining entries as keyword arguments. Entries that are
    themselves mappings with a ``__target__`` are instantiated first (without
    ``kwargs``). ``kwargs`` are then merged in, overriding config entries of
    the same name.

    Without ``__target__``, a plain dict is returned whose mapping values are
    processed recursively (forwarding ``kwargs``) and whose other values are
    kept as-is.

    Args:
        config: A mapping, e.g. an omegaconf ``DictConfig`` or a plain dict.
        **kwargs: Extra keyword arguments passed to the target callable.

    Returns:
        The constructed object, or a dict of recursively processed entries.

    Raises:
        ValueError: If ``__target__`` is not a dotted ``module.attr`` path.
    """
    from collections.abc import Mapping  # DictConfig is a Mapping subclass

    if '__target__' not in config:
        # No target at this level: recurse into nested mappings, keep the rest.
        return {
            k: instantiate(v, **kwargs) if isinstance(v, Mapping) else v
            for k, v in config.items()
        }

    target = config['__target__']
    module_path, _, class_name = target.rpartition('.')
    if not module_path or not class_name:
        raise ValueError(
            f"'__target__' must be a dotted path 'module.attr', got {target!r}"
        )
    cls = getattr(importlib.import_module(module_path), class_name)

    # Build nested target configs first; they do not receive kwargs.
    params = {}
    for key, value in config.items():
        if key == '__target__':
            continue
        if isinstance(value, Mapping) and '__target__' in value:
            params[key] = instantiate(value)
        else:
            params[key] = value
    # Caller-supplied keyword arguments take precedence over config entries.
    params.update(kwargs)
    return cls(**params)