Spaces:
Sleeping
Sleeping
#!/usr/bin/python3 | |
# -*- coding: utf-8 -*- | |
import logging | |
import inspect | |
import typing | |
from logging import WARNING | |
from toolbox.to_markdown.common.registrable import Registrable | |
logger = logging.getLogger("toolbox") | |
class Params(Registrable): | |
""" | |
仿照 AllenNLP 框架, Params 与 Registrable 配合使用, 通过子类的 Annotation 类型标注使得子类可以通过 | |
.from_json({ | |
**parameters | |
}) | |
方式实例化. | |
注意事项: | |
1. 子类除了 self, 每个参数都必须有类型标注. | |
2. 当子类没有 __init__ 方法时, 会调用到基类的该方法 (基类的默认实现没有类型标注), | |
因此, 都应实现 __init__ 方法. | |
""" | |
# def __init__(self): | |
# # Subclasses should override this method, even if it is def __init__(self): pass | |
# pass | |
def from_json(cls, params: dict = None, global_params: dict = None): | |
""" | |
:param params: | |
:param global_params: 当缺少某参数时, 尝试从 global_params 中查找. | |
:return: | |
""" | |
if params is None: | |
params = dict() | |
if global_params is None: | |
global_params = dict() | |
if "type" in params: | |
cls = cls.by_name(params["type"]) | |
signature = inspect.signature(cls.__init__) | |
kwargs = dict() | |
for k, v in signature.parameters.items(): | |
if k in ("self",): | |
continue | |
if k in ("args", "kwargs"): | |
msg = ( | |
f"parameter: args or kwargs is not expected. " | |
f"you may need to override the __init__ method of cls: {cls.__name__}." | |
) | |
logger.warning(msg) | |
print(msg) | |
continue | |
if v.annotation is inspect._empty: | |
raise NotImplementedError( | |
"all parameter should have a annotation. " | |
"parameter `{}` of {} have not annotation".format(k, cls) | |
) | |
if v.name in params: | |
sub_params = params[v.name] | |
elif v.name in global_params: | |
sub_params = global_params[v.name] | |
else: | |
continue | |
if isinstance(v.annotation, str): | |
raise NotImplementedError("string annotation not supported.") | |
if hasattr(v.annotation, "_subs_tree"): | |
# typing 标注类型. | |
subs_tree = v.annotation._subs_tree() | |
kwargs[v.name] = cls.from_annotation(sub_params, global_params, subs_tree) | |
elif isinstance(v.annotation, typing._GenericAlias): | |
# typing 标注类型. | |
subs_tree = (v.annotation.__origin__, *v.annotation.__args__) | |
kwargs[v.name] = cls.from_annotation(sub_params, global_params, subs_tree) | |
elif issubclass(v.annotation, Params): | |
# Params 子类. | |
kwargs[v.name] = v.annotation.from_json( | |
sub_params, global_params | |
) | |
elif isinstance(sub_params, v.annotation): | |
# 传入的是已实例化好的值. | |
kwargs[v.name] = sub_params | |
else: | |
# str, int, list, dict 等基本类型. | |
value = sub_params | |
if isinstance(value, dict): | |
value = v.annotation(**value) | |
else: | |
value = v.annotation(value) | |
kwargs[v.name] = value | |
obj = cls.__new__(cls, **kwargs) | |
try: | |
obj.__init__(**kwargs) | |
except TypeError as e: | |
print(e) | |
print("cls: {}, obj: {}, kwargs: {}".format(cls, obj, kwargs)) | |
logger.error(e) | |
logger.error("cls: {}, obj: {}, kwargs: {}".format(cls, obj, kwargs)) | |
raise e | |
return obj | |
def from_annotation(cls, params, global_params: dict, subs_tree=None): | |
if params is None: | |
return params | |
if subs_tree is None: | |
return params | |
if isinstance(subs_tree, tuple) and len(subs_tree) > 1: | |
# such as: (Dict, str, int) in List[Dict[str, int]] | |
args_type = subs_tree[0] | |
annotation = subs_tree[1:] | |
elif isinstance(subs_tree, tuple) and len(subs_tree) == 1: | |
args_type = subs_tree[0] | |
annotation = None | |
else: | |
args_type = subs_tree | |
annotation = None | |
if args_type is typing.List or args_type is list: | |
result = list() | |
for param in params: | |
result.append(cls.from_annotation(param, global_params, annotation)) | |
return result | |
elif args_type is typing.Dict or args_type is list: | |
result = dict() | |
for k, v in params.items(): | |
key = cls.from_annotation(k, global_params, annotation[0]) | |
value = cls.from_annotation(v, global_params, annotation[1]) | |
result[key] = value | |
return result | |
elif args_type is typing.Tuple or args_type is tuple: | |
if len(annotation) != len(params): | |
raise AssertionError( | |
"number of params not match the annotation. " | |
"{}, annotation: {}, params: {}".format(cls, annotation, params) | |
) | |
result = list() | |
for param, sub_annotation in zip(params, annotation): | |
result.append(cls.from_annotation(param, global_params, sub_annotation)) | |
return tuple(result) | |
elif args_type is typing.Union: | |
for option in annotation: | |
try: | |
result = cls.from_annotation(params, global_params, option) | |
break | |
except Exception: | |
continue | |
else: | |
raise ValueError("no type of Union match the params {}".format(params)) | |
return result | |
elif args_type is typing.Any: | |
result = params | |
return result | |
if hasattr(typing, "GenericMeta"): | |
built_in_type = typing.GenericMeta | |
elif hasattr(typing, "GenericAlias"): | |
built_in_type = typing.GenericAlias | |
else: | |
raise NotImplementedError | |
if not isinstance(args_type, built_in_type): | |
if hasattr(args_type, "from_json"): | |
result = args_type.from_json(params, global_params) | |
elif isinstance(args_type, tuple) and len(args_type) > 0 and isinstance(args_type[0], built_in_type): | |
# List[Dict[str, List[str]]] | |
result = cls.from_annotation(params, global_params, args_type) | |
else: | |
if isinstance(params, dict): | |
result = args_type(**params) | |
else: | |
result = args_type(params) | |
return result | |
raise NotImplementedError( | |
"{}, params: {}, subs_tree: {}".format(cls, params, subs_tree) | |
) | |
if __name__ == "__main__": | |
pass | |