Spaces:
Starting
on
L40S
Starting
on
L40S
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
import json | |
import dataclasses | |
import numpy as np | |
from dataclasses import Field, MISSING | |
from typing import IO, TypeVar, Type, get_args, get_origin, Union, Any, Tuple | |
_X = TypeVar("_X")


def load_dataclass(f: IO, cls: Type[_X], binary: bool = False) -> _X:
    """
    Loads to a @dataclass or collection hierarchy including dataclasses
    from a json recursively.
    Call it like load_dataclass(f, typing.List[FrameAnnotationAnnotation]).

    JSON keys that do not correspond to a dataclass field are ignored; fields
    whose key is missing from the JSON get their declared default, or None when
    the field declares no default.

    Args:
        f: A file object opened for reading: in text mode, or in binary mode
            when `binary` is True (the bytes are then decoded as UTF-8).
        cls: The type to load. When the JSON top level is a list, `cls` is
            expected to be `List[T]`; its first type argument is used for the
            elements (a bare `list` leaves the elements unconverted). Any other
            top level (object, scalar) is converted directly to `cls`.
        binary: Set to True if `f` was opened in binary mode, else False.

    Returns:
        The converted object; a list when the JSON top level is a list.
    """
    if binary:
        asdict = json.loads(f.read().decode("utf8"))
    else:
        asdict = json.load(f)

    if isinstance(asdict, list):
        # in the list case, run a faster "vectorized" version over all elements;
        # a bare `list` annotation carries no element type, so fall back to Any
        args = get_args(cls)
        elem_type = args[0] if args else Any
        return list(_dataclass_list_from_dict_list(asdict, elem_type))

    # a single top-level object (dataclass, dict, ...): convert it as a batch of
    # one and unwrap, instead of failing on get_args(cls)[0]
    return list(_dataclass_list_from_dict_list([asdict], cls))[0]
def _resolve_optional(type_: Any) -> Tuple[bool, Any]: | |
"""Check whether `type_` is equivalent to `typing.Optional[T]` for some T.""" | |
if get_origin(type_) is Union: | |
args = get_args(type_) | |
if len(args) == 2 and args[1] == type(None): # noqa E721 | |
return True, args[0] | |
if type_ is Any: | |
return True, Any | |
return False, type_ | |
def _unwrap_type(tp): | |
# strips Optional wrapper, if any | |
if get_origin(tp) is Union: | |
args = get_args(tp) | |
if len(args) == 2 and any(a is type(None) for a in args): # noqa: E721 | |
# this is typing.Optional | |
return args[0] if args[1] is type(None) else args[1] # noqa: E721 | |
return tp | |
def _get_dataclass_field_default(field: Field) -> Any: | |
if field.default_factory is not MISSING: | |
# pyre-fixme[29]: `Union[dataclasses._MISSING_TYPE, | |
# dataclasses._DefaultFactory[typing.Any]]` is not a function. | |
return field.default_factory() | |
elif field.default is not MISSING: | |
return field.default | |
else: | |
return None | |
def _split_by_lengths(flat, lengths):
    """Cut `flat` into consecutive lists whose sizes are given by `lengths`."""
    flat = list(flat)
    chunks = []
    start = 0
    for length in lengths:
        chunks.append(flat[start : start + length])
        start += length
    return chunks


def _convert_columns(rows, column_types):
    """Transpose position-aligned `rows`, convert each column, transpose back.

    Returns one tuple per row. As with `zip`, output is truncated to the
    shortest row and to `len(column_types)`, so rows are expected to have
    exactly one entry per column type.
    """
    column_types = list(column_types)
    if not column_types:
        # nothing to convert (e.g. a namedtuple without fields): one empty row each
        return [() for _ in rows]
    columns = [
        _dataclass_list_from_dict_list(list(column), col_type)
        for column, col_type in zip(zip(*rows), column_types)
    ]
    return list(zip(*columns))


def _dataclass_list_from_dict_list(dlist, typeannot):
    """
    Vectorised version of `_dataclass_from_dict`.

    The output should be equivalent to
    `[_dataclass_from_dict(d, typeannot) for d in dlist]`: every element of
    `dlist` (a JSON-decoded value) is converted to `typeannot`. Elements are
    converted together, with one recursive call per field / tuple position /
    key or value column rather than one call per element.

    Args:
        dlist: list (or tuple) of objects to convert.
        typeannot: type of each of those objects.
    Returns:
        list of converted objects of the same length as `dlist`; `dlist`
        itself is returned when no conversion applies (e.g. `Any`, scalars,
        non-class annotations such as `Union[A, B]`).
    Notes:
        - None entries stay None at their positions and are not converted.
        - `List[T]` / `Tuple[T, ...]` rows may differ in length.
        - Rows of a fixed-length `Tuple[A, B, ...]` or a namedtuple must have
          matching lengths; `zip` truncates longer rows.
        - JSON keys without a matching dataclass init field are ignored; a
          missing key takes the field's default (a fresh `default_factory()`
          result per object) or None when the field has no default.
    """
    if typeannot is Any:
        return dlist
    if all(obj is None for obj in dlist):  # 1st recursion base: all None nodes
        return dlist
    if any(obj is None for obj in dlist):
        # convert only the non-None entries, then put the results back in place
        positions = [i for i, obj in enumerate(dlist) if obj is not None]
        converted = _dataclass_list_from_dict_list(
            [dlist[i] for i in positions], typeannot
        )
        res = [None] * len(dlist)
        for i, obj in zip(positions, converted):
            res[i] = obj
        return res

    is_optional, contained_type = _resolve_optional(typeannot)
    if is_optional:
        return _dataclass_list_from_dict_list(dlist, contained_type)

    # otherwise, we dispatch by the type of the provided annotation to convert to
    cls = get_origin(typeannot) or typeannot
    if not isinstance(cls, type):
        # e.g. Union[A, B] or Literal[...]: there is no class to dispatch on
        # (issubclass would raise TypeError), so keep the values as they are
        return dlist

    if issubclass(cls, tuple) and hasattr(cls, "_fields"):  # namedtuple
        # JSON stores a namedtuple as a list; convert each field position as a
        # column. collections.namedtuple has no annotations -> fields are Any.
        hints = getattr(cls, "__annotations__", {})
        field_types = [hints.get(name, Any) for name in cls._fields]
        return [cls(*row) for row in _convert_columns(dlist, field_types)]

    if issubclass(cls, (list, tuple)):
        args = get_args(typeannot)
        is_tuple = issubclass(cls, tuple)
        if is_tuple and len(args) >= 2 and args[-1] is not Ellipsis:
            # fixed-length Tuple[A, B, ...]: convert position by position
            return _convert_columns(dlist, args)
        # homogeneous List[T] / Tuple[T, ...] / bare list or tuple. Rows may have
        # different lengths, so convert all items in one go and cut them apart
        # again (transposing with zip would drop items of longer rows).
        elem_type = args[0] if args else Any
        lengths = [len(row) for row in dlist]
        flat = [item for row in dlist for item in row]
        converted = _dataclass_list_from_dict_list(flat, elem_type)
        return [
            tuple(chunk) if is_tuple else cls(chunk)
            for chunk in _split_by_lengths(converted, lengths)
        ]

    if issubclass(cls, dict):
        # convert the concatenated keys and values, then regroup per dict;
        # a bare dict annotation carries no key/value types, so use Any
        args = get_args(typeannot)
        key_type, val_type = args if len(args) == 2 else (Any, Any)
        lengths = [len(obj) for obj in dlist]
        all_keys = _dataclass_list_from_dict_list(
            [k for obj in dlist for k in obj.keys()], key_type
        )
        all_vals = _dataclass_list_from_dict_list(
            [v for obj in dlist for v in obj.values()], val_type
        )
        # plain list slicing (instead of np.split) keeps keys as Python objects
        return [
            cls(zip(key_chunk, val_chunk))
            for key_chunk, val_chunk in zip(
                _split_by_lengths(all_keys, lengths),
                _split_by_lengths(all_vals, lengths),
            )
        ]

    if not dataclasses.is_dataclass(cls):
        return dlist

    # dataclass node: 2nd recursion base; call the function recursively on the
    # lists of the corresponding fields. Only init=True fields can be passed to
    # the constructor.
    init_fields = [fld for fld in dataclasses.fields(cls) if fld.init]
    if not init_fields:
        return [cls() for _ in dlist]
    columns = []
    for fld in init_fields:
        name = fld.name
        # the default is computed per missing object, so a default_factory
        # result is never shared between instances
        values = [
            obj[name] if name in obj else _get_dataclass_field_default(fld)
            for obj in dlist
        ]
        columns.append(_dataclass_list_from_dict_list(values, _unwrap_type(fld.type)))
    names = [fld.name for fld in init_fields]
    # keyword construction also works for kw_only fields
    return [cls(**dict(zip(names, row))) for row in zip(*columns)]