File size: 7,155 Bytes
e94100d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import logging
import inspect
import typing
from logging import WARNING

from toolbox.to_markdown.common.registrable import Registrable

# Module-level logger; note that it uses the fixed name "toolbox" rather than __name__.
logger = logging.getLogger("toolbox")


class Params(Registrable):
    """
    Modelled on the AllenNLP framework: ``Params`` is used together with
    ``Registrable`` so that a subclass can be instantiated from plain data via

    .from_json({
        **parameters
    })

    driven by the type annotations on the subclass ``__init__`` signature.

    Notes:
    1. Every ``__init__`` parameter of a subclass, except ``self``, must carry
       a type annotation.
    2. A subclass without its own ``__init__`` falls back to the inherited one
       (whose default implementation has no annotations), so every subclass
       should define ``__init__`` - even if it is only ``def __init__(self): pass``.
    """

    @classmethod
    def from_json(cls, params: dict = None, global_params: dict = None):
        """
        Build an instance from a (possibly nested) parameter mapping.

        When ``params`` contains a ``"type"`` key the concrete class is first
        resolved with ``cls.by_name`` (inherited; presumably a registry lookup
        provided by ``Registrable`` - confirm against that class). Each
        annotated ``__init__`` parameter is then looked up in ``params`` and,
        failing that, in ``global_params``; parameters found in neither are
        omitted so that ``__init__`` defaults apply.

        :param params: parameters of the object to build; ``None`` means empty.
        :param global_params: fallback mapping consulted when a parameter is
            missing from ``params``; it is also forwarded to nested objects.
        :return: the constructed instance.
        :raises NotImplementedError: when a parameter has no annotation or a
            string (forward-reference) annotation.
        :raises TypeError: re-raised (after logging) when ``__init__`` rejects
            the collected keyword arguments.
        """
        if params is None:
            params = dict()
        if global_params is None:
            global_params = dict()

        if "type" in params:
            cls = cls.by_name(params["type"])

        signature = inspect.signature(cls.__init__)
        # Absent on Python 3.6, where typing generics expose `_subs_tree` instead.
        typing_alias = getattr(typing, "_GenericAlias", None)

        kwargs = dict()
        for k, v in signature.parameters.items():
            if k in ("self",):
                continue
            if k in ("args", "kwargs"):
                # Usually means the (un-annotated) inherited __init__ was picked up.
                msg = (
                    f"parameter: args or kwargs is not expected. "
                    f"you may need to override the __init__ method of cls: {cls.__name__}."
                )
                logger.warning(msg)
                print(msg)
                continue

            annotation = v.annotation
            if annotation is inspect.Parameter.empty:
                raise NotImplementedError(
                    "all parameter should have a annotation. "
                    "parameter `{}` of {} have not annotation".format(k, cls)
                )

            if v.name in params:
                sub_params = params[v.name]
            elif v.name in global_params:
                sub_params = global_params[v.name]
            else:
                # Not supplied anywhere: leave it to the __init__ default.
                continue

            if isinstance(annotation, str):
                raise NotImplementedError("string annotation not supported.")

            if hasattr(annotation, "_subs_tree"):
                # typing annotation (Python 3.6 style).
                subs_tree = annotation._subs_tree()
                kwargs[v.name] = cls.from_annotation(sub_params, global_params, subs_tree)
            elif typing_alias is not None and isinstance(annotation, typing_alias):
                # typing annotation (Python 3.7+ style).
                subs_tree = (annotation.__origin__, *annotation.__args__)
                kwargs[v.name] = cls.from_annotation(sub_params, global_params, subs_tree)
            elif isinstance(cls._alias_to_subs_tree(annotation), tuple):
                # Other subscripted generics, e.g. the builtin `list[int]` form.
                subs_tree = cls._alias_to_subs_tree(annotation)
                kwargs[v.name] = cls.from_annotation(sub_params, global_params, subs_tree)
            elif issubclass(annotation, Params):
                # A Params subclass: build it recursively.
                kwargs[v.name] = annotation.from_json(
                    sub_params, global_params
                )
            elif isinstance(sub_params, annotation):
                # The caller passed an already-instantiated value.
                kwargs[v.name] = sub_params
            else:
                # Basic types such as str, int, list, dict.
                value = sub_params
                if isinstance(value, dict):
                    value = annotation(**value)
                else:
                    value = annotation(value)

                kwargs[v.name] = value

        obj = cls.__new__(cls, **kwargs)
        try:
            obj.__init__(**kwargs)
        except TypeError as e:
            print(e)
            print("cls: {}, obj: {}, kwargs: {}".format(cls, obj, kwargs))
            logger.error(e)
            logger.error("cls: {}, obj: {}, kwargs: {}".format(cls, obj, kwargs))
            # Bare raise keeps the original traceback intact.
            raise

        return obj

    @staticmethod
    def _alias_to_subs_tree(annotation):
        """
        Flatten a subscripted generic alias into ``(origin, *args)``.

        Tuples are returned unchanged (they already are a subs tree), and so is
        anything that lacks ``__origin__`` or a non-empty ``__args__`` tuple
        (plain classes, ``typing.Any``, ...).
        """
        if isinstance(annotation, tuple):
            return annotation
        origin = getattr(annotation, "__origin__", None)
        args = getattr(annotation, "__args__", None)
        if origin is None or not isinstance(args, tuple) or not args:
            return annotation
        return (origin, *args)

    @classmethod
    def from_annotation(cls, params, global_params: dict, subs_tree=None):
        """
        Convert ``params`` according to a typing "subs tree".

        :param params: raw value to convert; ``None`` is returned unchanged.
        :param global_params: fallback mapping forwarded to nested ``from_json`` calls.
        :param subs_tree: a single type, a subscripted generic alias, or a tuple
            ``(container_type, *argument_types)`` such as ``(Dict, str, int)``;
            ``None`` means "no constraint" and returns ``params`` unchanged.
        :return: the converted value.
        :raises AssertionError: when a fixed-length Tuple annotation and
            ``params`` differ in length.
        :raises ValueError: when no member of a Union accepts ``params``.
        :raises NotImplementedError: for generic types that cannot be expanded.
        """
        if params is None:
            return params
        if subs_tree is None:
            return params

        if isinstance(subs_tree, tuple) and len(subs_tree) > 1:
            # such as: (Dict, str, int) in List[Dict[str, int]]
            args_type = subs_tree[0]
            annotation = subs_tree[1:]
        elif isinstance(subs_tree, tuple) and len(subs_tree) == 1:
            args_type = subs_tree[0]
            annotation = None
        else:
            args_type = subs_tree
            annotation = None

        if args_type is typing.List or args_type is list:
            return [
                cls.from_annotation(param, global_params, annotation)
                for param in params
            ]

        if args_type is typing.Dict or args_type is dict:
            # A bare Dict carries no key/value types: copy the items unchanged.
            if annotation is not None and len(annotation) == 2:
                key_type, value_type = annotation
            else:
                key_type, value_type = None, None
            result = dict()
            for k, v in params.items():
                key = cls.from_annotation(k, global_params, key_type)
                value = cls.from_annotation(v, global_params, value_type)
                result[key] = value
            return result

        if args_type is typing.Tuple or args_type is tuple:
            if annotation is None:
                # Bare Tuple: no element types to apply.
                return tuple(params)
            if len(annotation) == 2 and annotation[1] is Ellipsis:
                # Tuple[T, ...]: homogeneous, any length.
                return tuple(
                    cls.from_annotation(param, global_params, annotation[0])
                    for param in params
                )
            if len(annotation) != len(params):
                raise AssertionError(
                    "number of params not match the annotation. "
                    "{}, annotation: {}, params: {}".format(cls, annotation, params)
                )
            return tuple(
                cls.from_annotation(param, global_params, sub_annotation)
                for param, sub_annotation in zip(params, annotation)
            )

        if args_type is typing.Union:
            # Deliberately best-effort: the first member type that converts
            # without raising wins, so member order matters.
            for option in annotation or ():
                try:
                    return cls.from_annotation(params, global_params, option)
                except Exception:
                    continue
            raise ValueError("no type of Union match the params {}".format(params))

        if args_type is typing.Any or isinstance(args_type, typing.TypeVar):
            # Unconstrained (Any, or the unbound type variable of a bare generic).
            return params

        if annotation is None:
            # A nested generic such as the Dict[str, int] inside
            # List[Dict[str, int]]: expand it and convert recursively.
            expanded = cls._alias_to_subs_tree(args_type)
            if isinstance(expanded, tuple) and expanded:
                return cls.from_annotation(params, global_params, expanded)

        # Which of these exists depends on the Python version; a generic that
        # is still unexpanded at this point cannot be instantiated sensibly.
        generic_type = (
            getattr(typing, "GenericMeta", None)
            or getattr(typing, "GenericAlias", None)
            or getattr(typing, "_GenericAlias", None)
        )
        if generic_type is not None and isinstance(args_type, generic_type):
            raise NotImplementedError(
                "{}, params: {}, subs_tree: {}".format(cls, params, subs_tree)
            )

        if hasattr(args_type, "from_json"):
            return args_type.from_json(params, global_params)
        if isinstance(params, dict):
            return args_type(**params)
        return args_type(params)


if __name__ == "__main__":
    # Nothing to run: executing this module directly is a no-op.
    pass