Apollo / look2hear /utils /parser_utils.py
Serhiy Stetskovych
Initial code
78e32cc
raw
history blame
5.68 kB
###
# Author: Kai Li
# Date: 2021-06-20 00:36:46
# LastEditors: Please set LastEditors
# LastEditTime: 2024-01-22 03:02:57
###
import sys
import argparse
import importlib
from omegaconf import DictConfig
def prepare_parser_from_dict(dic, parser=None):
    """Build (or extend) an argparse parser from a config dictionary.

    Every top-level key becomes an argument group. What is registered in
    that group depends on the type of the associated value:

    * ``dict``: one ``--<sub_key>`` option per entry, whose default is the
      entry's value.
    * ``list`` or ``str``: a single ``--<key>`` option whose default is the
      value itself.
    * anything else: the group is created but no option is added to it.

    Args:
        dic (dict): Config dictionary with unique bottom-level keys.
        parser (argparse.ArgumentParser, optional): Existing parser to add
            the options to. A fresh parser is created when omitted.

    Returns:
        argparse.ArgumentParser:
            Parser with one group per top-level key and options whose
            defaults are taken from the dictionary values.
    """

    def standardized_entry_type(value):
        """Choose the argparse ``type`` callable for a default value.

        ``None`` defaults are parsed with ``str_int_float``, defaults that
        ``str2bool`` turns into a boolean are parsed with ``str2bool_arg``,
        and every other default is parsed with its own type.
        """
        if value is None:
            return str_int_float
        return str2bool_arg if isinstance(str2bool(value), bool) else type(value)

    def register(group, name, default):
        """Add ``--<name>`` to ``group`` with the given default."""
        entry_type = standardized_entry_type(default)
        group.add_argument("--" + name, default=default, type=entry_type)

    parser = argparse.ArgumentParser() if parser is None else parser
    for section, content in dic.items():
        group = parser.add_argument_group(section)
        if isinstance(content, dict):
            for name, default in content.items():
                register(group, name, default)
        elif isinstance(content, (list, str)):
            # Flat entries are exposed under the top-level key itself.
            register(group, section, content)
    return parser
def str_int_float(value):
    """Argparse type casting a string to int or float when possible.

    The int conversion is tried first, then the float conversion.

    Args:
        value (str): Value to convert.

    Returns:
        int, float, str: The converted value, or the string itself when it
        is neither an int nor a float. A non-string value that passes
        neither check yields ``None``.
    """
    for accepts, cast in ((isint, int), (isfloat, float)):
        if accepts(value):
            return cast(value)
    return value if isinstance(value, str) else None
def str2bool(value):
    """Convert a boolean-looking string to ``bool``.

    The comparison is case-insensitive. ``"yes"``, ``"true"``, ``"y"`` and
    ``"1"`` map to ``True``; ``"no"``, ``"false"``, ``"n"`` and ``"0"`` map
    to ``False``.

    Args:
        value: Value to convert.

    Returns:
        bool for a recognised string; otherwise the input, unchanged
        (this includes every non-string input).
    """
    if isinstance(value, str):
        lowered = value.lower()
        if lowered in ("yes", "true", "y", "1"):
            return True
        if lowered in ("no", "false", "n", "0"):
            return False
    return value
def str2bool_arg(value):
    """Argparse type that only accepts boolean-looking values.

    Args:
        value: Raw value, passed through ``str2bool``.

    Returns:
        bool: The converted value.

    Raises:
        argparse.ArgumentTypeError: If ``str2bool`` does not produce a bool.
    """
    converted = str2bool(value)
    if not isinstance(converted, bool):
        raise argparse.ArgumentTypeError("Boolean value expected.")
    return converted
def isfloat(value):
    """Computes whether `value` can be cast to a float.

    Args:
        value: Value to check (typically a str).

    Returns:
        bool: Whether `value` can be cast to a float.
    """
    try:
        float(value)
    except (TypeError, ValueError, OverflowError):
        # TypeError: unsupported type such as None or a list.
        # ValueError: string that does not spell a number.
        # OverflowError: int too large to be represented as a float.
        # In all three cases the answer to "can it be cast?" is simply no.
        return False
    return True
def isint(value):
    """Computes whether `value` can be cast to an int.

    Args:
        value: Value to check (typically a str).

    Returns:
        bool: Whether `value` can be cast to an int.
    """
    try:
        int(value)
    except (TypeError, ValueError, OverflowError):
        # TypeError: unsupported type such as None or a list.
        # ValueError: string that does not spell an integer (or float NaN).
        # OverflowError: float infinity.
        # In all three cases the answer to "can it be cast?" is simply no.
        return False
    return True
def parse_args_as_dict(parser, return_plain_args=False, args=None):
    """Get a dict of dicts out of process `parser.parse_args()`

    Top-level keys correspond to argument groups and bottom-level keys to
    arguments. Under `'main_args'`, the arguments which don't belong to a
    user-defined argparse group (i.e. main arguments defined before parsing
    from a dict) can be found; this covers both of argparse's built-in
    groups (positional and optional arguments).

    Args:
        parser (argparse.ArgumentParser): ArgumentParser instance containing
            groups. Output of `prepare_parser_from_dict`.
        return_plain_args (bool): Whether to also return the output of
            `parser.parse_args()`.
        args (list): List of arguments as read from the command line.
            Used for unit testing.

    Returns:
        dict:
            Dictionary of dictionaries containing the arguments. When
            `return_plain_args` is true, a tuple of that dictionary and the
            direct output of `parser.parse_args()`.
    """
    args = parser.parse_args(args=args)
    # argparse titles its built-in groups differently across Python versions
    # ("optional arguments" before 3.10, "options" from 3.10 on), so they are
    # recognised by identity instead of by a hard-coded title.
    builtin_groups = (parser._positionals, parser._optionals)
    main_args = {}
    args_dic = {}
    for group in parser._action_groups:
        group_dict = {a.dest: getattr(args, a.dest, None) for a in group._group_actions}
        if any(group is builtin for builtin in builtin_groups):
            main_args.update(group_dict)
        else:
            args_dic[group.title] = group_dict
    args_dic["main_args"] = main_args
    if return_plain_args:
        return args_dic, args
    return args_dic
def instantiate(config, **kwargs):
    """Recursively build the objects described by ``__target__`` entries.

    When ``config`` has a ``'__target__'`` key, its value is split at the
    last dot into a module path and an attribute name; the module is
    imported and the attribute is called with the remaining entries of
    ``config`` as keyword arguments. Nested ``DictConfig`` values that carry
    their own ``'__target__'`` are instantiated first (without ``kwargs``).

    When ``config`` has no ``'__target__'`` key, a plain dict is returned in
    which every ``DictConfig`` value has been processed recursively (with
    ``kwargs`` forwarded) and every other value is kept as is.

    Args:
        config: Mapping-like configuration (e.g. a ``DictConfig``).
        **kwargs: Extra keyword arguments for the target; they take
            precedence over entries of ``config`` with the same name.

    Returns:
        The instantiated target, or a dict of processed values.
    """
    if '__target__' not in config:
        # No target at this level: descend into nested DictConfig values.
        result = {}
        for name, entry in config.items():
            if isinstance(entry, DictConfig):
                result[name] = instantiate(entry, **kwargs)
            else:
                result[name] = entry
        return result

    module_path, class_name = config['__target__'].rsplit('.', 1)
    cls = getattr(importlib.import_module(module_path), class_name)

    # Resolve nested target configs before building the outer object.
    params = {
        name: instantiate(entry) if isinstance(entry, DictConfig) and '__target__' in entry else entry
        for name, entry in config.items()
        if name != '__target__'
    }
    # Extra keyword arguments are applied last so they win over the config.
    params.update(kwargs)
    return cls(**params)