# Copyright (c) OpenMMLab. All rights reserved.
from collections import abc
from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union

import numpy as np
import torch
from mmengine.structures import BaseDataElement, PixelData
from mmengine.utils import is_list_of

IndexType = Union[str, slice, int, list, torch.LongTensor,
                  torch.cuda.LongTensor, torch.BoolTensor,
                  torch.cuda.BoolTensor, np.ndarray]


class MultilevelPixelData(BaseDataElement):
    """Data structure for multi-level pixel-wise annotations or predictions.

    All data items in ``data_fields`` of ``MultilevelPixelData`` are lists
    of np.ndarray or torch.Tensor, and should meet the following
    requirements:

    - Have the same length, which is the number of levels
    - At each level, the data should have 3 dimensions in order of channel,
      height and width
    - At each level, the data should have the same height and width

    Internally, one ``PixelData`` is kept per level (stored under the
    reserved attribute names ``_level_{i}``), and every data field is
    written to each of them.

    Examples:
        >>> metainfo = dict(num_keypoints=17)
        >>> sizes = [(64, 48), (128, 96), (256, 192)]
        >>> heatmaps = [np.random.rand(17, h, w) for h, w in sizes]
        >>> masks = [torch.rand(1, h, w) for h, w in sizes]
        >>> data = MultilevelPixelData(metainfo=metainfo,
        ...                            heatmaps=heatmaps,
        ...                            masks=masks)

        >>> # get data item
        >>> heatmaps = data.heatmaps  # A list of 3 numpy.ndarrays
        >>> masks = data.masks  # A list of 3 torch.Tensors

        >>> # get level
        >>> data_l0 = data[0]  # PixelData with fields 'heatmaps' and 'masks'
        >>> data.nlevel
        3

        >>> # get shape
        >>> data.shape
        ((64, 48), (128, 96), (256, 192))

        >>> # set
        >>> offset_maps = [torch.rand(2, h, w) for h, w in sizes]
        >>> data.offset_maps = offset_maps
    """

    def __init__(self, *, metainfo: Optional[dict] = None, **kwargs) -> None:
        # ``_nlevel`` is written with ``object.__setattr__`` so that it is a
        # plain attribute and does not pass through ``set_field``. It has to
        # exist before the base initializer runs because ``set_field`` reads
        # ``self.nlevel`` (presumably invoked while ``kwargs`` are stored).
        object.__setattr__(self, '_nlevel', None)
        super().__init__(metainfo=metainfo, **kwargs)

    @property
    def nlevel(self):
        """Return the level number.

        Returns:
            Optional[int]: The level number, or ``None`` if the data has not
            been assigned.
        """
        return self._nlevel

    def __getitem__(self, item: Union[int, str, list,
                                      slice]) -> Union[PixelData, Sequence]:
        """Index by level (``int``) or by field name (``str``).

        Args:
            item (int | str): A level index (negative values count from the
                last level) or the name of a data/metainfo field.

        Returns:
            PixelData | Sequence: The ``PixelData`` of the requested level,
            or the value stored under the requested field name.

        Raises:
            IndexError: If the level index is out of range or no data has
                been assigned yet.
            KeyError: If the field name is not contained in the element.
            NotImplementedError: For any other index type.
        """
        if isinstance(item, int):
            nlevel = self.nlevel
            if nlevel is None or not -nlevel <= item < nlevel:
                raise IndexError(
                    f'Level index {item} out of range ({nlevel})')
            if item < 0:
                # Normalize so that the lookup below hits an existing
                # ``_level_{i}`` attribute instead of ``_level_-1``.
                item += nlevel
            return self.get(f'_level_{item}')

        if isinstance(item, str):
            if item not in self:
                raise KeyError(item)
            return getattr(self, item)

        # TODO: support indexing by list and slice over levels
        raise NotImplementedError(
            f'{self.__class__.__name__} does not support index type '
            f'{type(item)}')

    def levels(self) -> List[PixelData]:
        """Return the per-level ``PixelData`` objects.

        Returns:
            List[PixelData]: One element per level, or an empty list if no
            data has been assigned.
        """
        if self.nlevel:
            return [self[i] for i in range(self.nlevel)]
        return []

    @property
    def shape(self) -> Optional[Tuple[Tuple]]:
        """Get the shape of multi-level pixel data.

        Returns:
            Optional[tuple]: A tuple of data shape at each level, or ``None``
            if the data has not been assigned.
        """
        if self.nlevel is None:
            return None

        return tuple(level.shape for level in self.levels())

    def set_data(self, data: dict) -> None:
        """Set or change key-value pairs in ``data_field`` by parameter
        ``data``.

        Args:
            data (dict): A dict contains annotations of image or
                model predictions.
        """
        assert isinstance(data,
                          dict), f'data should be a `dict` but got {data}'
        for k, v in data.items():
            self.set_field(v, k, field_type='data')

    def set_field(self,
                  value: Any,
                  name: str,
                  dtype: Optional[Union[Type, Tuple[Type, ...]]] = None,
                  field_type: str = 'data') -> None:
        """Special method for set union field, used as property.setter
        functions.

        Args:
            value (Any): The value to store. For a data field it must be a
                non-empty sequence of ``np.ndarray`` / ``torch.Tensor`` with
                one element per level.
            name (str): The field name. Names starting with ``_level_`` are
                reserved.
            dtype (type | tuple[type], optional): If given, ``value`` must
                be an instance of it.
            field_type (str): Either ``'metainfo'`` or ``'data'``.

        Raises:
            AttributeError: If ``name`` is reserved or already registered
                as the other field type.
            TypeError: If a data value is not a sequence of arrays/tensors.
            ValueError: If a data value is empty.
        """
        assert field_type in ['metainfo', 'data']
        if dtype is not None:
            assert isinstance(
                value,
                dtype), f'{value} should be a {dtype} but got {type(value)}'

        if name.startswith('_level_'):
            raise AttributeError(
                f'Cannot set {name} to be a field because the pattern '
                '<_level_{n}> is reserved for inner data field')

        if field_type == 'metainfo':
            if name in self._data_fields:
                raise AttributeError(
                    f'Cannot set {name} to be a field of metainfo '
                    f'because {name} is already a data field')
            self._metainfo_fields.add(name)

        else:
            if name in self._metainfo_fields:
                raise AttributeError(
                    f'Cannot set {name} to be a field of data '
                    f'because {name} is already a metainfo field')

            if not isinstance(value, abc.Sequence):
                raise TypeError(
                    'The value should be a sequence (of numpy.ndarray or '
                    f'torch.Tensor), but got a {type(value)}')

            if len(value) == 0:
                raise ValueError('Setting empty value is not allowed')

            if not isinstance(value[0], (torch.Tensor, np.ndarray)):
                raise TypeError(
                    'The value should be a sequence of numpy.ndarray or '
                    f'torch.Tensor, but got a sequence of {type(value[0])}')

            if self.nlevel is not None:
                assert len(value) == self.nlevel, (
                    f'The length of the value ({len(value)}) should match '
                    f'the number of the levels ({self.nlevel})')
            else:
                # First data field: it fixes the number of levels and
                # creates one empty PixelData container per level.
                object.__setattr__(self, '_nlevel', len(value))
                for i in range(self.nlevel):
                    object.__setattr__(self, f'_level_{i}', PixelData())

            for i, v in enumerate(value):
                self[i].set_field(v, name, field_type='data')

            self._data_fields.add(name)

        # Keep the value on the instance as well so that plain attribute
        # access returns it directly.
        object.__setattr__(self, name, value)

    def __delattr__(self, item: str):
        """Delete the item in dataelement.

        Args:
            item (str): The key to delete.

        Raises:
            AttributeError: If ``item`` is one of the immutable private
                attributes, or is neither a metainfo nor a data field.
        """
        if item in ('_metainfo_fields', '_data_fields'):
            raise AttributeError(f'{item} has been used as a '
                                 'private attribute, which is immutable. ')

        if item in self._metainfo_fields:
            super().__delattr__(item)
            return

        if item not in self._data_fields:
            raise AttributeError(
                f'\'{self.__class__.__name__}\' object has no attribute '
                f'\'{item}\'')

        for level in self.levels():
            level.__delattr__(item)
        self._data_fields.remove(item)
        # ``set_field`` also stored the list on the instance; drop it so the
        # deleted field is no longer reachable via attribute access.
        self.__dict__.pop(item, None)

    def __getattr__(self, name):
        """Fallback lookup that gathers a data field from every level.

        Python calls this only when normal attribute lookup fails.

        Raises:
            AttributeError: If ``name`` is not a registered data field.
        """
        # The explicit check on the two bookkeeping names avoids recursing
        # into this method when they are not yet set on the instance.
        if name in {'_data_fields', '_metainfo_fields'
                    } or name not in self._data_fields:
            raise AttributeError(
                f'\'{self.__class__.__name__}\' object has no attribute '
                f'\'{name}\'')

        return [getattr(level, name) for level in self.levels()]

    def pop(self, *args) -> Any:
        """Pop property in data and metainfo as the same as python.

        Args:
            *args: The field name, optionally followed by a default value
                returned when the name is not contained.

        Returns:
            Any: The metainfo value, the list of per-level values of a data
            field, or the default.

        Raises:
            KeyError: If the name is not contained and no default is given.
        """
        assert len(args) < 3, '``pop`` get more than 2 arguments'
        name = args[0]

        if name in self._metainfo_fields:
            self._metainfo_fields.remove(name)
            return self.__dict__.pop(*args)

        elif name in self._data_fields:
            self._data_fields.remove(name)
            # Drop the copy kept on the instance by ``set_field`` so the
            # popped field does not remain accessible as an attribute.
            self.__dict__.pop(name, None)
            return [level.pop(*args) for level in self.levels()]

        # with default value
        elif len(args) == 2:
            return args[1]
        else:
            # don't just use 'self.__dict__.pop(*args)' for only popping key
            # in metainfo or data
            raise KeyError(f'{args[0]} is not contained in metainfo or data')

    def _convert(self, apply_to: Type,
                 func: Callable[[Any], Any]) -> 'MultilevelPixelData':
        """Convert data items with the given function.

        Args:
            apply_to (Type): The type of data items to apply the conversion
            func (Callable): The conversion function that takes a data item
                as the input and return the converted result

        Returns:
            MultilevelPixelData: the converted data element.
        """
        new_data = self.new()
        for k, v in self.items():
            if is_list_of(v, apply_to):
                # NOTE(review): assumes ``self.new()`` already carries over
                # the fields that are not converted here - confirm.
                v = [func(_v) for _v in v]
                data = {k: v}
                new_data.set_data(data)
        return new_data

    def cpu(self) -> 'MultilevelPixelData':
        """Convert all tensors to CPU in data."""
        return self._convert(apply_to=torch.Tensor, func=lambda x: x.cpu())

    def cuda(self) -> 'MultilevelPixelData':
        """Convert all tensors to GPU in data."""
        return self._convert(apply_to=torch.Tensor, func=lambda x: x.cuda())

    def detach(self) -> 'MultilevelPixelData':
        """Detach all tensors in data."""
        return self._convert(apply_to=torch.Tensor, func=lambda x: x.detach())

    def numpy(self) -> 'MultilevelPixelData':
        """Convert all tensors to np.ndarray in data."""
        return self._convert(
            apply_to=torch.Tensor, func=lambda x: x.detach().cpu().numpy())

    def to_tensor(self) -> 'MultilevelPixelData':
        """Convert all np.ndarray to tensors in data."""
        return self._convert(
            apply_to=np.ndarray, func=lambda x: torch.from_numpy(x))

    # Tensor-like methods
    def to(self, *args, **kwargs) -> 'MultilevelPixelData':
        """Apply same name function to all tensors in data_fields."""
        new_data = self.new()
        for k, v in self.items():
            # Only the first level is probed for a ``to`` method.
            if hasattr(v[0], 'to'):
                v = [v_.to(*args, **kwargs) for v_ in v]
                data = {k: v}
                new_data.set_data(data)
        return new_data