# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Generic utilities
"""

from collections import OrderedDict
from dataclasses import fields, is_dataclass
from typing import Any, Tuple

import numpy as np

from .import_utils import is_torch_available, is_torch_version


def is_tensor(x) -> bool:
    """
    Tests if `x` is a `torch.Tensor` or `np.ndarray`.

    `torch` is only imported (and only checked against) when `is_torch_available()` reports it as installed.
    """
    if is_torch_available():
        import torch

        if isinstance(x, torch.Tensor):
            return True

    return isinstance(x, np.ndarray)


class BaseOutput(OrderedDict):
    """
    Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a
    tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular
    Python dictionary.

    <Tip warning={true}>

    You can't unpack a [`BaseOutput`] directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple
    first.

    </Tip>
    """

    def __init_subclass__(cls) -> None:
        """Register subclasses as pytree nodes.

        This is necessary to synchronize gradients when using `torch.nn.parallel.DistributedDataParallel` with
        `static_graph=True` with modules that output `BaseOutput` subclasses.
        """
        if is_torch_available():
            import torch.utils._pytree

            pytree = torch.utils._pytree

            def _unflatten(values, context):
                # Rebuild an instance of the subclass from the flattened dict representation.
                return cls(**pytree._dict_unflatten(values, context))

            # The registration function is private before torch 2.2 and public afterwards. Only the selected
            # attribute is looked up, so the name missing on the installed version is never touched.
            register = (
                pytree._register_pytree_node if is_torch_version("<", "2.2") else pytree.register_pytree_node
            )
            register(cls, pytree._dict_flatten, _unflatten)

    def __post_init__(self) -> None:
        """
        Mirror the dataclass fields into the underlying mapping.

        If the first field holds a `dict` and every other field is `None`, the entries of that `dict` become the
        entries of this output. Otherwise every field whose value is not `None` is stored under its field name.

        Raises:
            ValueError: If the dataclass declares no fields.
        """
        class_fields = fields(self)

        # Safety and consistency checks
        if not class_fields:
            raise ValueError(f"{self.__class__.__name__} has no fields.")

        first_field = getattr(self, class_fields[0].name)
        other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])

        if other_fields_are_none and isinstance(first_field, dict):
            for key, value in first_field.items():
                self[key] = value
        else:
            for field in class_fields:
                v = getattr(self, field.name)
                if v is not None:
                    self[field.name] = v

    def __delitem__(self, *args, **kwargs):
        """Always raises: deleting entries is not supported."""
        raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")

    def setdefault(self, *args, **kwargs):
        """Always raises: `setdefault` is not supported."""
        raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")

    def pop(self, *args, **kwargs):
        """Always raises: `pop` is not supported."""
        raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")

    def update(self, *args, **kwargs):
        """Always raises: `update` is not supported."""
        raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")

    def __getitem__(self, k: Any) -> Any:
        """
        Return an entry by name or by position.

        String keys are looked up in the mapping (raising `KeyError` when absent); any other key (integer or slice)
        indexes the tuple returned by `to_tuple()`.
        """
        if isinstance(k, str):
            # Plain dict lookup on the parent class; no need to copy the whole mapping for a single access.
            return super().__getitem__(k)
        return self.to_tuple()[k]

    def __setattr__(self, name: Any, value: Any) -> None:
        """Set an attribute, also updating the mapping entry when `name` is already a key and `value` is not `None`."""
        if name in self and value is not None:
            # Don't call self.__setitem__ to avoid recursion errors
            super().__setitem__(name, value)
        super().__setattr__(name, value)

    def __setitem__(self, key, value):
        """Store `value` under `key` in the mapping and expose it as an attribute of the same name."""
        # Will raise a KeyException if needed
        super().__setitem__(key, value)
        # Don't call self.__setattr__ to avoid recursion errors
        super().__setattr__(key, value)

    def __reduce__(self):
        """
        Pickle support.

        For dataclass instances, the constructor arguments returned by the parent `__reduce__` are replaced with the
        current values of the dataclass fields, in declaration order. Non-dataclass instances use the parent
        implementation unchanged.
        """
        if not is_dataclass(self):
            return super().__reduce__()
        constructor, _args, *remaining = super().__reduce__()
        args = tuple(getattr(self, field.name) for field in fields(self))
        return constructor, args, *remaining

    def to_tuple(self) -> Tuple[Any, ...]:
        """
        Convert self to a tuple containing all the attributes/keys that are not `None`.
        """
        return tuple(self[k] for k in self.keys())