### # Author: Kai Li # Date: 2021-06-20 00:36:46 # LastEditors: Please set LastEditors # LastEditTime: 2024-01-22 03:02:57 ### import sys import argparse import importlib from omegaconf import DictConfig def prepare_parser_from_dict(dic, parser=None): """Prepare an argparser from a dictionary. Args: dic (dict): Two-level config dictionary with unique bottom-level keys. parser (argparse.ArgumentParser, optional): If a parser already exists, add the keys from the dictionary on the top of it. Returns: argparse.ArgumentParser: Parser instance with groups corresponding to the first level keys and arguments corresponding to the second level keys with default values given by the values. """ def standardized_entry_type(value): """If the default value is None, replace NoneType by str_int_float. If the default value is boolean, look for boolean strings.""" if value is None: return str_int_float if isinstance(str2bool(value), bool): return str2bool_arg return type(value) if parser is None: parser = argparse.ArgumentParser() for k in dic.keys(): group = parser.add_argument_group(k) if isinstance(dic[k], list): entry_type = standardized_entry_type(dic[k]) group.add_argument("--" + k, default=dic[k], type=entry_type) elif isinstance(dic[k], dict): for kk in dic[k].keys(): entry_type = standardized_entry_type(dic[k][kk]) group.add_argument("--" + kk, default=dic[k][kk], type=entry_type) elif isinstance(dic[k], str): entry_type = standardized_entry_type(dic[k]) group.add_argument("--" + k, default=dic[k], type=entry_type) return parser def str_int_float(value): """Type to convert strings to int, float (in this order) if possible. Args: value (str): Value to convert. Returns: int, float, str: Converted value. """ if isint(value): return int(value) if isfloat(value): return float(value) elif isinstance(value, str): return value def str2bool(value): """Type to convert strings to Boolean (returns input if not boolean)""" if not isinstance(value, str): return value if value.lower() in ("yes", "true", "y", "1"): return True elif value.lower() in ("no", "false", "n", "0"): return False else: return value def str2bool_arg(value): """Argparse type to convert strings to Boolean""" value = str2bool(value) if isinstance(value, bool): return value raise argparse.ArgumentTypeError("Boolean value expected.") def isfloat(value): """Computes whether `value` can be cast to a float. Args: value (str): Value to check. Returns: bool: Whether `value` can be cast to a float. """ try: float(value) return True except ValueError: return False def isint(value): """Computes whether `value` can be cast to an int Args: value (str): Value to check. Returns: bool: Whether `value` can be cast to an int. """ try: int(value) return True except ValueError: return False def parse_args_as_dict(parser, return_plain_args=False, args=None): """Get a dict of dicts out of process `parser.parse_args()` Top-level keys corresponding to groups and bottom-level keys corresponding to arguments. Under `'main_args'`, the arguments which don't belong to a argparse group (i.e main arguments defined before parsing from a dict) can be found. Args: parser (argparse.ArgumentParser): ArgumentParser instance containing groups. Output of `prepare_parser_from_dict`. return_plain_args (bool): Whether to return the output or `parser.parse_args()`. args (list): List of arguments as read from the command line. Used for unit testing. Returns: dict: Dictionary of dictionaries containing the arguments. Optionally the direct output `parser.parse_args()`. """ args = parser.parse_args(args=args) args_dic = {} for group in parser._action_groups: group_dict = {a.dest: getattr(args, a.dest, None) for a in group._group_actions} args_dic[group.title] = group_dict if sys.version_info.minor == 10: args_dic["main_args"] = args_dic["positional arguments"] del args_dic["positional arguments"] else: args_dic["main_args"] = args_dic["optional arguments"] del args_dic["optional arguments"] if return_plain_args: return args_dic, args return args_dic def instantiate(config, **kwargs): if '__target__' in config: module_path, class_name = config['__target__'].rsplit('.', 1) module = importlib.import_module(module_path) cls = getattr(module, class_name) # 先处理嵌套的配置 params = {} for key, value in config.items(): if key != '__target__': if isinstance(value, DictConfig) and '__target__' in value: params[key] = instantiate(value) else: params[key] = value # 添加额外的关键字参数 params.update(kwargs) return cls(**params) else: # 对于不包含 '__target__' 的字典,递归处理其每个值 return {k: instantiate(v, **kwargs) if isinstance(v, DictConfig) else v for k, v in config.items()}