metric / dataclass.py
Elron's picture
Upload folder using huggingface_hub
88c61d3 verified
import copy
import dataclasses
import functools
import inspect
from abc import ABCMeta
from inspect import Parameter, Signature
from typing import Any, Dict, List, Optional, final
_FIELDS = "__fields__"
class Undefined:
pass
@dataclasses.dataclass
class Field:
"""An alternative to dataclasses.dataclass decorator for a more flexible field definition.
Args:
default (Any, optional):
Default value for the field. Defaults to None.
name (str, optional):
Name of the field. Defaults to None.
type (type, optional):
Type of the field. Defaults to None.
default_factory (Any, optional):
A function that returns the default value. Defaults to None.
final (bool, optional):
A boolean indicating if the field is final (cannot be overridden). Defaults to False.
abstract (bool, optional):
A boolean indicating if the field is abstract (must be implemented by subclasses). Defaults to False.
required (bool, optional):
A boolean indicating if the field is required. Defaults to False.
origin_cls (type, optional):
The original class that defined the field. Defaults to None.
"""
default: Any = Undefined
name: str = None
type: type = None
init: bool = True
also_positional: bool = True
default_factory: Any = None
final: bool = False
abstract: bool = False
required: bool = False
internal: bool = False
origin_cls: type = None
metadata: Dict[str, str] = dataclasses.field(default_factory=dict)
def get_default(self):
if self.default_factory is not None:
return self.default_factory()
return self.default
@dataclasses.dataclass
class FinalField(Field):
def __post_init__(self):
self.final = True
@dataclasses.dataclass
class RequiredField(Field):
def __post_init__(self):
self.required = True
class MissingDefaultError(TypeError):
pass
@dataclasses.dataclass
class OptionalField(Field):
def __post_init__(self):
self.required = False
if self.default is Undefined and self.default_factory is None:
raise MissingDefaultError(
"OptionalField must have default or default_factory"
)
@dataclasses.dataclass
class AbstractField(Field):
def __post_init__(self):
self.abstract = True
@dataclasses.dataclass
class NonPositionalField(Field):
def __post_init__(self):
self.also_positional = False
@dataclasses.dataclass
class InternalField(Field):
def __post_init__(self):
self.internal = True
self.init = False
self.also_positional = False
class FinalFieldError(TypeError):
pass
class RequiredFieldError(TypeError):
pass
class AbstractFieldError(TypeError):
pass
class TypeMismatchError(TypeError):
pass
class UnexpectedArgumentError(TypeError):
pass
standard_variables = dir(object)
def is_class_method(func):
if inspect.ismethod(func):
return True
if inspect.isfunction(func):
sig = inspect.signature(func)
params = list(sig.parameters.values())
if len(params) > 0 and params[0].name in ["self", "cls"]:
return True
return False
def is_possible_field(field_name, field_value):
"""Check if a name-value pair can potentially represent a field.
Args:
field_name (str): The name of the field.
field_value: The value of the field.
Returns:
bool: True if the name-value pair can represent a field, False otherwise.
"""
if field_name in standard_variables:
return False
if is_class_method(field_value):
return False
return True
def get_fields(cls, attrs):
"""Get the fields for a class based on its attributes.
Args:
cls (type): The class to get the fields for.
attrs (dict): The attributes of the class.
Returns:
dict: A dictionary mapping field names to Field instances.
"""
fields = {}
for base in cls.__bases__:
fields = {**getattr(base, _FIELDS, {}), **fields}
annotations = {**attrs.get("__annotations__", {})}
for attr_name, attr_value in attrs.items():
if attr_name not in annotations and is_possible_field(attr_name, attr_value):
if attr_name in fields:
try:
if not isinstance(attr_value, fields[attr_name].type):
raise TypeMismatchError(
f"Type mismatch for field '{attr_name}' of class '{fields[attr_name].origin_cls}'. Expected {fields[attr_name].type}, got {type(attr_value)}"
)
except TypeError:
pass
annotations[attr_name] = fields[attr_name].type
for field_name, field_type in annotations.items():
if field_name in fields and fields[field_name].final:
raise FinalFieldError(
f"Final field {field_name} defined in {fields[field_name].origin_cls} overridden in {cls}"
)
args = {
"name": field_name,
"type": field_type,
"origin_cls": attrs["__qualname__"],
}
if field_name in attrs:
field_value = attrs[field_name]
if isinstance(field_value, Field):
args = {**dataclasses.asdict(field_value), **args}
elif isinstance(field_value, dataclasses.Field):
args = {
"default": field_value.default,
"name": field_value.name,
"type": field_value.type,
"init": field_value.init,
"default_factory": field_value.default_factory,
**args,
}
else:
args["default"] = field_value
args["default_factory"] = None
else:
args["default"] = dataclasses.MISSING
args["default_factory"] = None
args["required"] = True
field_instance = Field(**args)
fields[field_name] = field_instance
if cls.__allow_unexpected_arguments__:
fields["_argv"] = InternalField(name="_argv", type=tuple, default=())
fields["_kwargs"] = InternalField(name="_kwargs", type=dict, default={})
return fields
def is_dataclass(obj):
"""Returns True if obj is a dataclass or an instance of a dataclass."""
cls = obj if isinstance(obj, type) else type(obj)
return hasattr(cls, _FIELDS)
def class_fields(obj):
all_fields = fields(obj)
return [
field for field in all_fields if field.origin_cls == obj.__class__.__qualname__
]
def fields(cls):
return list(getattr(cls, _FIELDS).values())
def fields_names(cls):
return list(getattr(cls, _FIELDS).keys())
def external_fields_names(cls):
return [field.name for field in fields(cls) if not field.internal]
def final_fields(cls):
return [field for field in fields(cls) if field.final]
def required_fields(cls):
return [field for field in fields(cls) if field.required]
def abstract_fields(cls):
return [field for field in fields(cls) if field.abstract]
def is_abstract_field(field):
return field.abstract
def is_final_field(field):
return field.final
def get_field_default(field):
if field.default_factory is not None:
return field.default_factory()
return field.default
def asdict(obj):
assert is_dataclass(
obj
), f"{obj} must be a dataclass, got {type(obj)} with bases {obj.__class__.__bases__}"
return _asdict_inner(obj)
def _asdict_inner(obj):
if is_dataclass(obj):
return obj.to_dict()
if isinstance(obj, tuple) and hasattr(obj, "_fields"): # named tuple
return type(obj)(*[_asdict_inner(v) for v in obj])
if isinstance(obj, (list, tuple)):
return type(obj)([_asdict_inner(v) for v in obj])
if isinstance(obj, dict):
return type(obj)({_asdict_inner(k): _asdict_inner(v) for k, v in obj.items()})
return copy.deepcopy(obj)
class DataclassMeta(ABCMeta):
"""Metaclass for Dataclass.
Checks for final fields when a subclass is created.
"""
@final
def __init__(cls, name, bases, attrs):
super().__init__(name, bases, attrs)
fields = get_fields(cls, attrs)
setattr(cls, _FIELDS, fields)
cls.update_init_signature()
def update_init_signature(cls):
parameters = []
for name, field in getattr(cls, _FIELDS).items():
if field.init and not field.internal:
if field.default is not Undefined:
default_value = field.default
elif field.default_factory is not None:
default_value = field.default_factory()
else:
default_value = Parameter.empty
if isinstance(default_value, dataclasses._MISSING_TYPE):
default_value = Parameter.empty
param = Parameter(
name,
Parameter.POSITIONAL_OR_KEYWORD,
default=default_value,
annotation=field.type,
)
parameters.append(param)
if getattr(cls, "__allow_unexpected_arguments__", False):
parameters.append(Parameter("_argv", Parameter.VAR_POSITIONAL))
parameters.append(Parameter("_kwargs", Parameter.VAR_KEYWORD))
signature = Signature(parameters, __validate_parameters__=False)
original_init = cls.__init__
@functools.wraps(original_init)
def custom_cls_init(self, *args, **kwargs):
original_init(self, *args, **kwargs)
custom_cls_init.__signature__ = signature
cls.__init__ = custom_cls_init
class Dataclass(metaclass=DataclassMeta):
"""Base class for data-like classes that provides additional functionality and control.
Base class for data-like classes that provides additional functionality and control
over Python's built-in @dataclasses.dataclass decorator. Other classes can inherit from
this class to get the benefits of this implementation. As a base class, it ensures that
all subclasses will automatically be data classes.
The usage and field definitions are similar to Python's built-in @dataclasses.dataclass decorator.
However, this implementation provides additional classes for defining "final", "required",
and "abstract" fields.
Key enhancements of this custom implementation:
1. Automatic Data Class Creation: All subclasses automatically become data classes,
without needing to use the @dataclasses.dataclass decorator.
2. Field Immutability: Supports creation of "final" fields (using FinalField class) that
cannot be overridden by subclasses. This functionality is not natively supported in
Python or in the built-in dataclasses module.
3. Required Fields: Supports creation of "required" fields (using RequiredField class) that
must be provided when creating an instance of the class, adding a level of validation
not present in the built-in dataclasses module.
4. Abstract Fields: Supports creation of "abstract" fields (using AbstractField class) that
must be overridden by any non-abstract subclass. This is similar to abstract methods in
an abc.ABC class, but applied to fields.
5. Type Checking: Performs type checking to ensure that if a field is redefined in a subclass,
the type of the field remains consistent, adding static type checking not natively supported
in Python.
6. Error Definitions: Defines specific error types (FinalFieldError, RequiredFieldError,
AbstractFieldError, TypeMismatchError) for providing detailed error information during debugging.
7. MetaClass Usage: Uses a metaclass (DataclassMeta) for customization of class creation,
allowing checks and alterations to be made at the time of class creation, providing more control.
:Example:
.. code-block:: python
class Parent(Dataclass):
final_field: int = FinalField(1) # this field cannot be overridden
required_field: str = RequiredField()
also_required_field: float
abstract_field: int = AbstractField()
class Child(Parent):
abstract_field = 3 # now once overridden, this is no longer abstract
required_field = Field(name="required_field", default="provided", type=str)
class Mixin(Dataclass):
mixin_field = Field(name="mixin_field", default="mixin", type=str)
class GrandChild(Child, Mixin):
pass
grand_child = GrandChild()
logger.info(grand_child.to_dict())
...
"""
__allow_unexpected_arguments__ = False
@final
def __init__(self, *argv, **kwargs):
"""Initialize fields based on kwargs.
Checks for abstract fields when an instance is created.
"""
super().__init__()
_init_fields = [field for field in fields(self) if field.init]
_init_fields_names = [field.name for field in _init_fields]
_init_positional_fields_names = [
field.name for field in _init_fields if field.also_positional
]
for name in _init_positional_fields_names[: len(argv)]:
if name in kwargs:
raise TypeError(
f"{self.__class__.__name__} got multiple values for argument '{name}'"
)
expected_unexpected_argv = kwargs.pop("_argv", None)
if len(argv) <= len(_init_positional_fields_names):
unexpected_argv = []
else:
unexpected_argv = argv[len(_init_positional_fields_names) :]
if expected_unexpected_argv is not None:
assert (
len(unexpected_argv) == 0
), f"Cannot specify both _argv and unexpected positional arguments. Got {unexpected_argv}"
unexpected_argv = tuple(expected_unexpected_argv)
expected_unexpected_kwargs = kwargs.pop("_kwargs", None)
unexpected_kwargs = {
k: v
for k, v in kwargs.items()
if k not in _init_fields_names and k not in ["_argv", "_kwargs"]
}
if expected_unexpected_kwargs is not None:
intersection = set(unexpected_kwargs.keys()) & set(
expected_unexpected_kwargs.keys()
)
assert (
len(intersection) == 0
), f"Cannot specify the same arguments in both _kwargs and in unexpected keyword arguments. Got {intersection} in both."
unexpected_kwargs = {**unexpected_kwargs, **expected_unexpected_kwargs}
if self.__allow_unexpected_arguments__:
if len(unexpected_argv) > 0:
kwargs["_argv"] = unexpected_argv
if len(unexpected_kwargs) > 0:
kwargs["_kwargs"] = unexpected_kwargs
else:
if len(unexpected_argv) > 0:
raise UnexpectedArgumentError(
f"Too many positional arguments {unexpected_argv} for class {self.__class__.__name__}.\nShould be only {len(_init_positional_fields_names)} positional arguments: {', '.join(_init_positional_fields_names)}"
)
if len(unexpected_kwargs) > 0:
raise UnexpectedArgumentError(
f"Unexpected keyword argument(s) {unexpected_kwargs} for class {self.__class__.__name__}.\nShould be one of: {external_fields_names(self)}"
)
for name, arg in zip(_init_positional_fields_names, argv):
kwargs[name] = arg
for field in abstract_fields(self):
raise AbstractFieldError(
f"Abstract field '{field.name}' of class {field.origin_cls} not implemented in {self.__class__.__name__}"
)
for field in required_fields(self):
if field.name not in kwargs:
raise RequiredFieldError(
f"Required field '{field.name}' of class {field.origin_cls} not set in {self.__class__.__name__}"
)
self.__pre_init__(**kwargs)
for field in fields(self):
if field.name in kwargs:
setattr(self, field.name, kwargs[field.name])
else:
setattr(self, field.name, get_field_default(field))
self.__post_init__()
@property
def __is_dataclass__(self) -> bool:
return True
def __pre_init__(self, **kwargs):
"""Pre initialization hook."""
pass
def __post_init__(self):
"""Post initialization hook."""
pass
def _to_raw_dict(self):
"""Convert to raw dict."""
return {field.name: getattr(self, field.name) for field in fields(self)}
def to_dict(self, classes: Optional[List] = None, keep_empty: bool = True):
"""Convert to dict.
Args:
classes (List, optional): List of parent classes which attributes should
be returned. If set to None, then all class' attributes are returned.
keep_empty (bool): If True, then parameters are returned regardless if
their values are None or not.
"""
if not classes:
attributes_dict = _asdict_inner(self._to_raw_dict())
else:
attributes = []
for cls in classes:
attributes += list(cls.__annotations__.keys())
attributes_dict = {
attribute: getattr(self, attribute) for attribute in attributes
}
return {
attribute: value
for attribute, value in attributes_dict.items()
if keep_empty or value is not None
}
def get_repr_dict(self):
result = {}
for field in fields(self):
if not field.internal:
result[field.name] = getattr(self, field.name)
return result
def __repr__(self) -> str:
"""String representation."""
return f"{self.__class__.__name__}({', '.join([f'{key}={val!r}' for key, val in self.get_repr_dict().items()])})"