Spaces:
Running
on
Zero
Running
on
Zero
### | |
# Author: Kai Li | |
# Date: 2021-06-20 00:36:46 | |
# LastEditors: Please set LastEditors | |
# LastEditTime: 2024-01-22 03:02:57 | |
### | |
import sys | |
import argparse | |
import importlib | |
from omegaconf import DictConfig | |
def prepare_parser_from_dict(dic, parser=None):
    """Prepare an argparser from a dictionary.

    Every top-level key becomes an argument group. If its value is a mapping
    (a ``dict`` or an omegaconf ``DictConfig``), each second-level key becomes
    a ``--key`` option of that group. Otherwise the top-level key itself
    becomes a ``--key`` option. The value is always used as the default.

    Option types are derived from the default value:

    * ``None`` -> :func:`str_int_float` (int, then float, else str).
    * bool, or a string :func:`str2bool` recognises -> :func:`str2bool_arg`.
    * list/tuple (any non-string sequence) -> ``nargs="*"``. Elements use the
      type a scalar default would get when all elements agree on it, and
      :func:`str_int_float` otherwise (or when the list is empty).
    * anything else -> ``type(value)``.

    Args:
        dic (Mapping): One- or two-level config dictionary with unique
            bottom-level keys.
        parser (argparse.ArgumentParser, optional): If a parser already
            exists, add the keys from the dictionary on top of it.

    Returns:
        argparse.ArgumentParser:
            Parser instance with groups corresponding to the first-level keys
            and arguments corresponding to the second-level keys (or to the
            first-level key for non-mapping entries), with default values
            given by the values.
    """
    from collections.abc import Mapping, Sequence

    def standardized_entry_type(value):
        """If the default value is None, replace NoneType by str_int_float.
        If the default value is boolean, look for boolean strings."""
        if value is None:
            return str_int_float
        if isinstance(str2bool(value), bool):
            return str2bool_arg
        return type(value)

    def list_entry_type(values):
        """Element type for a list default: the shared standardized type of all
        elements, or str_int_float when they disagree or the list is empty."""
        entry_types = {standardized_entry_type(v) for v in values}
        if len(entry_types) == 1:
            return entry_types.pop()
        return str_int_float

    def add_entry(group, name, value):
        """Register ``--name`` on ``group`` with ``value`` as its default."""
        if isinstance(value, Sequence) and not isinstance(value, (str, bytes)):
            # List defaults accept zero or more CLI values. ``type=list`` would
            # instead split a single CLI string into its characters.
            group.add_argument(
                f"--{name}", default=value, nargs="*", type=list_entry_type(value)
            )
        else:
            group.add_argument(
                f"--{name}", default=value, type=standardized_entry_type(value)
            )

    if parser is None:
        parser = argparse.ArgumentParser()
    for group_name, group_value in dic.items():
        group = parser.add_argument_group(group_name)
        if isinstance(group_value, Mapping):
            # Two-level entry: one option per second-level key.
            for name, value in group_value.items():
                add_entry(group, name, value)
        else:
            # One-level entry: the top-level key is itself the option.
            add_entry(group, group_name, group_value)
    return parser
def str_int_float(value):
    """Type to convert strings to int, float (in this order) if possible.

    Intended as an argparse ``type=`` callable for options whose default is
    ``None``.

    Args:
        value (str): Value to convert.

    Returns:
        int, float, str: ``int(value)`` if that succeeds, else
        ``float(value)`` if that succeeds, else ``value`` unchanged.
        Non-string inputs are returned unchanged rather than coerced.
    """
    if not isinstance(value, str):
        # Only strings are parsed; e.g. a float 3.7 must not become int 3.
        return value
    for cast in (int, float):
        try:
            return cast(value)
        except ValueError:
            continue
    return value
def str2bool(value):
    """Map common boolean spellings to ``True``/``False``.

    The comparison is case-insensitive and accepts ``yes``/``true``/``y``/``1``
    and ``no``/``false``/``n``/``0``. Any other string, and any non-string
    input, is returned as-is.
    """
    if isinstance(value, str):
        lowered = value.lower()
        if lowered in ("yes", "true", "y", "1"):
            return True
        if lowered in ("no", "false", "n", "0"):
            return False
    return value
def str2bool_arg(value):
    """Argparse ``type=`` callable that only accepts boolean spellings.

    Returns the bool produced by ``str2bool``. Raises
    ``argparse.ArgumentTypeError`` when ``value`` is not a recognised boolean.
    """
    parsed = str2bool(value)
    if not isinstance(parsed, bool):
        raise argparse.ArgumentTypeError("Boolean value expected.")
    return parsed
def isfloat(value):
    """Computes whether `value` can be cast to a float.

    Args:
        value: Value to check (typically a str).

    Returns:
        bool: Whether ``float(value)`` succeeds. Values of unsupported types
        (e.g. ``None`` or a list) yield ``False`` instead of raising.
    """
    try:
        float(value)
    except (TypeError, ValueError):
        # TypeError covers inputs float() cannot handle at all, such as None.
        return False
    return True
def isint(value):
    """Computes whether `value` can be cast to an int.

    Args:
        value: Value to check (typically a str).

    Returns:
        bool: Whether ``int(value)`` succeeds. Values of unsupported types
        (e.g. ``None`` or a list) yield ``False`` instead of raising.
    """
    try:
        int(value)
    except (TypeError, ValueError):
        # TypeError covers inputs int() cannot handle at all, such as None.
        return False
    return True
def parse_args_as_dict(parser, return_plain_args=False, args=None):
    """Get a dict of dicts out of process `parser.parse_args()`.

    Top-level keys correspond to argument groups and bottom-level keys to
    argument destinations. Arguments that don't belong to a user-created
    argparse group (i.e. main arguments defined before parsing from a dict)
    are found under ``'main_args'``. This covers both the parser's default
    optional and positional groups.

    Args:
        parser (argparse.ArgumentParser): ArgumentParser instance containing
            groups. Output of `prepare_parser_from_dict`.
        return_plain_args (bool): Whether to also return the output of
            `parser.parse_args()`.
        args (list): List of arguments as read from the command line.
            Used for unit testing.

    Returns:
        dict:
            Dictionary of dictionaries containing the arguments. If
            ``return_plain_args`` is True, a ``(dict, argparse.Namespace)``
            tuple with the direct output of `parser.parse_args()`.
    """
    args = parser.parse_args(args=args)
    args_dic = {}
    for group in parser._action_groups:
        args_dic[group.title] = {
            action.dest: getattr(args, action.dest, None)
            for action in group._group_actions
        }
    # The default optional group is titled "optional arguments" before
    # Python 3.10 and "options" from 3.10 on. Read the titles from the
    # parser itself rather than guessing them from the interpreter version.
    optionals_title = parser._optionals.title
    positionals_title = parser._positionals.title
    main_args = dict(args_dic.get(positionals_title, {}))
    main_args.update(args_dic.pop(optionals_title, {}))
    args_dic["main_args"] = main_args
    if return_plain_args:
        return args_dic, args
    return args_dic
def instantiate(config, **kwargs):
    """Recursively build objects from a config that names classes via ``__target__``.

    If ``config`` has a ``'__target__'`` entry (a dotted path such as
    ``"package.module.ClassName"``), that attribute is imported from its module
    and called with the remaining entries as keyword arguments. An entry that
    is itself a ``DictConfig`` containing ``'__target__'`` is instantiated
    first, without ``kwargs``. All other entries are passed through unchanged.
    ``kwargs`` are merged last, so they override config entries of the same
    name.

    Without a ``'__target__'`` entry, a plain ``dict`` is returned. Every
    ``DictConfig`` value in it is passed recursively through ``instantiate``
    with the same ``kwargs``; other values are kept as-is.
    """
    if '__target__' not in config:
        # No class at this level: descend into DictConfig children only.
        result = {}
        for key, value in config.items():
            if isinstance(value, DictConfig):
                result[key] = instantiate(value, **kwargs)
            else:
                result[key] = value
        return result

    module_path, class_name = config['__target__'].rsplit('.', 1)
    target_cls = getattr(importlib.import_module(module_path), class_name)

    # Build nested targets first so the parent receives ready-made objects.
    params = {
        key: (
            instantiate(value)
            if isinstance(value, DictConfig) and '__target__' in value
            else value
        )
        for key, value in config.items()
        if key != '__target__'
    }
    # Extra keyword arguments from the caller take precedence over config entries.
    params.update(kwargs)
    return target_cls(**params)