# Copyright (c) OpenMMLab. All rights reserved.
from collections import abc
from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union
import numpy as np
import torch
from mmengine.structures import BaseDataElement, PixelData
from mmengine.utils import is_list_of
IndexType = Union[str, slice, int, list, torch.LongTensor,
torch.cuda.LongTensor, torch.BoolTensor,
torch.cuda.BoolTensor, np.ndarray]
class MultilevelPixelData(BaseDataElement):
"""Data structure for multi-level pixel-wise annotations or predictions.
All data items in ``data_fields`` of ``MultilevelPixelData`` are lists
of np.ndarray or torch.Tensor, and should meet the following requirements:
- Have the same length, which is the number of levels
    - At each level, the data should have 3 dimensions in order of channel,
      height and width
    - At each level, the data should have the same height and width
Examples:
>>> metainfo = dict(num_keypoints=17)
>>> sizes = [(64, 48), (128, 96), (256, 192)]
>>> heatmaps = [np.random.rand(17, h, w) for h, w in sizes]
>>> masks = [torch.rand(1, h, w) for h, w in sizes]
>>> data = MultilevelPixelData(metainfo=metainfo,
... heatmaps=heatmaps,
... masks=masks)
>>> # get data item
>>> heatmaps = data.heatmaps # A list of 3 numpy.ndarrays
>>> masks = data.masks # A list of 3 torch.Tensors
>>> # get level
>>> data_l0 = data[0] # PixelData with fields 'heatmaps' and 'masks'
>>> data.nlevel
3
>>> # get shape
>>> data.shape
((64, 48), (128, 96), (256, 192))
>>> # set
>>> offset_maps = [torch.rand(2, h, w) for h, w in sizes]
>>> data.offset_maps = offset_maps
"""
def __init__(self, *, metainfo: Optional[dict] = None, **kwargs) -> None:
object.__setattr__(self, '_nlevel', None)
super().__init__(metainfo=metainfo, **kwargs)
@property
def nlevel(self):
"""Return the level number.
Returns:
Optional[int]: The level number, or ``None`` if the data has not
been assigned.
"""
return self._nlevel
def __getitem__(self, item: Union[int, str, list,
slice]) -> Union[PixelData, Sequence]:
if isinstance(item, int):
if self.nlevel is None or item >= self.nlevel:
raise IndexError(
                    f'Level index {item} out of range ({self.nlevel})')
return self.get(f'_level_{item}')
if isinstance(item, str):
if item not in self:
raise KeyError(item)
return getattr(self, item)
# TODO: support indexing by list and slice over levels
raise NotImplementedError(
f'{self.__class__.__name__} does not support index type '
f'{type(item)}')
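    # Example (illustrative, following the class docstring): integer indexing
    # returns the PixelData of a single level, while string indexing returns
    # the per-level list stored under that field name:
    #   >>> data[0]           # PixelData with 'heatmaps' and 'masks'
    #   >>> data['heatmaps']  # list of heatmaps, one per level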
    def levels(self) -> List[PixelData]:
        """Return a list of PixelData, one for each level."""
        if self.nlevel:
            return list(self[i] for i in range(self.nlevel))
        return []
@property
def shape(self) -> Optional[Tuple[Tuple]]:
"""Get the shape of multi-level pixel data.
Returns:
Optional[tuple]: A tuple of data shape at each level, or ``None``
if the data has not been assigned.
"""
if self.nlevel is None:
return None
return tuple(level.shape for level in self.levels())
def set_data(self, data: dict) -> None:
"""Set or change key-value pairs in ``data_field`` by parameter
``data``.
Args:
data (dict): A dict contains annotations of image or
model predictions.
"""
        assert isinstance(data,
                          dict), f'data should be a `dict` but got {data}'
for k, v in data.items():
self.set_field(v, k, field_type='data')
def set_field(self,
value: Any,
name: str,
dtype: Optional[Union[Type, Tuple[Type, ...]]] = None,
field_type: str = 'data') -> None:
"""Special method for set union field, used as property.setter
functions."""
assert field_type in ['metainfo', 'data']
if dtype is not None:
assert isinstance(
value,
dtype), f'{value} should be a {dtype} but got {type(value)}'
if name.startswith('_level_'):
raise AttributeError(
f'Cannot set {name} to be a field because the pattern '
'<_level_{n}> is reserved for inner data field')
if field_type == 'metainfo':
if name in self._data_fields:
raise AttributeError(
f'Cannot set {name} to be a field of metainfo '
f'because {name} is already a data field')
self._metainfo_fields.add(name)
else:
if name in self._metainfo_fields:
raise AttributeError(
f'Cannot set {name} to be a field of data '
f'because {name} is already a metainfo field')
if not isinstance(value, abc.Sequence):
raise TypeError(
                    'The value should be a sequence (of numpy.ndarray or '
                    f'torch.Tensor), but got a {type(value)}')
if len(value) == 0:
raise ValueError('Setting empty value is not allowed')
if not isinstance(value[0], (torch.Tensor, np.ndarray)):
raise TypeError(
                    'The value should be a sequence of numpy.ndarray or '
                    f'torch.Tensor, but got a sequence of {type(value[0])}')
if self.nlevel is not None:
assert len(value) == self.nlevel, (
                    f'The length of the value ({len(value)}) should match '
                    f'the number of levels ({self.nlevel})')
else:
object.__setattr__(self, '_nlevel', len(value))
for i in range(self.nlevel):
object.__setattr__(self, f'_level_{i}', PixelData())
for i, v in enumerate(value):
self[i].set_field(v, name, field_type='data')
self._data_fields.add(name)
object.__setattr__(self, name, value)
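    # Example (illustrative, following the class docstring): assigning a new
    # field distributes the sequence across the per-level PixelData
    # containers, so its length must equal ``nlevel``:
    #   >>> data.offset_maps = [torch.rand(2, h, w) for h, w in sizes]
    #   >>> data[0].offset_maps.shape  # first-level map, e.g. (2, 64, 48)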
def __delattr__(self, item: str):
"""delete the item in dataelement.
Args:
item (str): The key to delete.
"""
if item in ('_metainfo_fields', '_data_fields'):
raise AttributeError(f'{item} has been used as a '
'private attribute, which is immutable. ')
if item in self._metainfo_fields:
super().__delattr__(item)
else:
for level in self.levels():
level.__delattr__(item)
self._data_fields.remove(item)
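    # Example (illustrative): deleting a data field removes it from every
    # level at once:
    #   >>> del data.masks
    #   >>> 'masks' in data  # False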
    def __getattr__(self, name):
        """Return the values of a data field at all levels as a list."""
if name in {'_data_fields', '_metainfo_fields'
} or name not in self._data_fields:
raise AttributeError(
f'\'{self.__class__.__name__}\' object has no attribute '
f'\'{name}\'')
return [getattr(level, name) for level in self.levels()]
def pop(self, *args) -> Any:
"""pop property in data and metainfo as the same as python."""
assert len(args) < 3, '``pop`` get more than 2 arguments'
name = args[0]
if name in self._metainfo_fields:
self._metainfo_fields.remove(name)
return self.__dict__.pop(*args)
elif name in self._data_fields:
self._data_fields.remove(name)
return [level.pop(*args) for level in self.levels()]
# with default value
elif len(args) == 2:
return args[1]
else:
            # Don't simply use `self.__dict__.pop(*args)` here, so that only
            # keys in metainfo or data can be popped
raise KeyError(f'{args[0]} is not contained in metainfo or data')
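    # Example (illustrative): ``pop`` mirrors ``dict.pop``, removing a data
    # field from every level or returning a default for a missing key:
    #   >>> masks = data.pop('masks')    # list of per-level masks
    #   >>> data.pop('unknown', None)    # returns the default value None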
def _convert(self, apply_to: Type,
func: Callable[[Any], Any]) -> 'MultilevelPixelData':
"""Convert data items with the given function.
Args:
            apply_to (Type): The type of data items to which the conversion
                is applied
            func (Callable): The conversion function that takes a data item
                as the input and returns the converted result
Returns:
MultilevelPixelData: the converted data element.
"""
new_data = self.new()
for k, v in self.items():
if is_list_of(v, apply_to):
v = [func(_v) for _v in v]
data = {k: v}
new_data.set_data(data)
return new_data
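    # Note (illustrative): ``_convert`` is the shared backend of the helpers
    # below; e.g. ``data.cpu()`` is equivalent to
    #   >>> data._convert(apply_to=torch.Tensor, func=lambda x: x.cpu())
    # which applies ``func`` to every per-level item of the matching type.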
def cpu(self) -> 'MultilevelPixelData':
"""Convert all tensors to CPU in data."""
return self._convert(apply_to=torch.Tensor, func=lambda x: x.cpu())
def cuda(self) -> 'MultilevelPixelData':
"""Convert all tensors to GPU in data."""
return self._convert(apply_to=torch.Tensor, func=lambda x: x.cuda())
def detach(self) -> 'MultilevelPixelData':
"""Detach all tensors in data."""
return self._convert(apply_to=torch.Tensor, func=lambda x: x.detach())
def numpy(self) -> 'MultilevelPixelData':
"""Convert all tensor to np.narray in data."""
return self._convert(
apply_to=torch.Tensor, func=lambda x: x.detach().cpu().numpy())
def to_tensor(self) -> 'MultilevelPixelData':
"""Convert all tensor to np.narray in data."""
return self._convert(
apply_to=np.ndarray, func=lambda x: torch.from_numpy(x))
# Tensor-like methods
def to(self, *args, **kwargs) -> 'MultilevelPixelData':
"""Apply same name function to all tensors in data_fields."""
new_data = self.new()
for k, v in self.items():
if hasattr(v[0], 'to'):
v = [v_.to(*args, **kwargs) for v_ in v]
data = {k: v}
new_data.set_data(data)
return new_data
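

if __name__ == '__main__':
    # Illustrative usage sketch (not part of the original module): build a
    # small MultilevelPixelData following the class docstring example and
    # exercise indexing, shape query and tensor conversion. Requires mmengine
    # (and the surrounding package) to be installed.
    sizes = [(64, 48), (128, 96), (256, 192)]
    data = MultilevelPixelData(
        metainfo=dict(num_keypoints=17),
        heatmaps=[np.random.rand(17, h, w) for h, w in sizes],
        masks=[torch.rand(1, h, w) for h, w in sizes])
    assert data.nlevel == 3
    print(data.shape)             # ((64, 48), (128, 96), (256, 192))
    level0 = data[0]              # PixelData with 'heatmaps' and 'masks'
    as_tensor = data.to_tensor()  # np.ndarray fields converted to tensors
    as_numpy = data.numpy()       # torch.Tensor fields converted to ndarrays
    print(type(level0).__name__, type(as_tensor.heatmaps[0]).__name__)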