File size: 10,214 Bytes
cc0dd3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
# Copyright (c) OpenMMLab. All rights reserved.
from collections import abc
from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union

import numpy as np
import torch
from mmengine.structures import BaseDataElement, PixelData
from mmengine.utils import is_list_of

# Every kind of object accepted as an index into a data element: field
# names, Python slices/ints/lists, integer or boolean index tensors (CPU and
# CUDA legacy tensor types) and numpy arrays.
IndexType = Union[
    str,
    slice,
    int,
    list,
    torch.LongTensor,
    torch.cuda.LongTensor,
    torch.BoolTensor,
    torch.cuda.BoolTensor,
    np.ndarray,
]


class MultilevelPixelData(BaseDataElement):
    """Data structure for multi-level pixel-wise annotations or predictions.

    All data items in ``data_fields`` of ``MultilevelPixelData`` are lists
    of np.ndarray or torch.Tensor, and should meet the following requirements:

    - Have the same length, which is the number of levels
    - At each level, the data should have 3 dimensions in order of channel,
        height and width
    - At each level, the data should have the same height and width

    Internally, each level is a :class:`PixelData` stored under the reserved
    attribute name ``_level_{n}``. Every data field is written item-by-item
    into the levels and the whole sequence is also kept on this object.

    Examples:
        >>> metainfo = dict(num_keypoints=17)
        >>> sizes = [(64, 48), (128, 96), (256, 192)]
        >>> heatmaps = [np.random.rand(17, h, w) for h, w in sizes]
        >>> masks = [torch.rand(1, h, w) for h, w in sizes]
        >>> data = MultilevelPixelData(metainfo=metainfo,
        ...                            heatmaps=heatmaps,
        ...                            masks=masks)

        >>> # get data item
        >>> heatmaps = data.heatmaps  # A list of 3 numpy.ndarrays
        >>> masks = data.masks  # A list of 3 torch.Tensors

        >>> # get level
        >>> data_l0 = data[0]  # PixelData with fields 'heatmaps' and 'masks'
        >>> data_last = data[-1]  # negative indices count from the end
        >>> data.nlevel
        3

        >>> # get shape
        >>> data.shape
        ((64, 48), (128, 96), (256, 192))

        >>> # set
        >>> offset_maps = [torch.rand(2, h, w) for h, w in sizes]
        >>> data.offset_maps = offset_maps
    """

    def __init__(self, *, metainfo: Optional[dict] = None, **kwargs) -> None:
        # Bypass ``BaseDataElement.__setattr__`` (which would route through
        # ``set_field``) so the level counter exists before any data is set.
        object.__setattr__(self, '_nlevel', None)
        super().__init__(metainfo=metainfo, **kwargs)

    @property
    def nlevel(self) -> Optional[int]:
        """Return the level number.

        Returns:
            Optional[int]: The level number, or ``None`` if the data has not
            been assigned.
        """
        return self._nlevel

    def __getitem__(self, item: Union[int, str, list,
                                      slice]) -> Union[PixelData, Sequence]:
        """Get a level by index or a field by name.

        Args:
            item (int | str): An integer level index (negative values count
                back from the last level, like Python sequences) or a field
                name.

        Returns:
            PixelData | Sequence: The :class:`PixelData` of the level when
            ``item`` is an int, otherwise the value of the named field.

        Raises:
            IndexError: If the level index is out of range, or no data has
                been assigned yet.
            KeyError: If ``item`` is a str that is not a field.
            NotImplementedError: For any other index type.
        """
        if isinstance(item, int):
            nlevel = self.nlevel
            if nlevel is None or not (-nlevel <= item < nlevel):
                raise IndexError(
                    f'Level index {item} out of range ({nlevel})')
            if item < 0:
                # Normalize so the reserved ``_level_{n}`` name exists.
                item += nlevel
            return getattr(self, f'_level_{item}')

        if isinstance(item, str):
            if item not in self:
                raise KeyError(item)
            return getattr(self, item)

        # TODO: support indexing by list and slice over levels
        raise NotImplementedError(
            f'{self.__class__.__name__} does not support index type '
            f'{type(item)}')

    def levels(self) -> List[PixelData]:
        """Return the :class:`PixelData` of every level, in level order.

        Returns:
            List[PixelData]: One element per level, or an empty list if no
            data has been assigned.
        """
        if self.nlevel:
            return [self[i] for i in range(self.nlevel)]
        return []

    @property
    def shape(self) -> Optional[Tuple[Tuple]]:
        """Get the shape of multi-level pixel data.

        Returns:
            Optional[tuple]: A tuple of data shape at each level, or ``None``
            if the data has not been assigned.
        """
        if self.nlevel is None:
            return None

        return tuple(level.shape for level in self.levels())

    def set_data(self, data: dict) -> None:
        """Set or change key-value pairs in ``data_field`` by parameter
        ``data``.

        Every item is forwarded to :meth:`set_field` as a data field.

        Args:
            data (dict): A dict contains annotations of image or
                model predictions.
        """
        assert isinstance(data,
                          dict), f'data should be a `dict` but got {data}'
        for k, v in data.items():
            self.set_field(v, k, field_type='data')

    def set_field(self,
                  value: Any,
                  name: str,
                  dtype: Optional[Union[Type, Tuple[Type, ...]]] = None,
                  field_type: str = 'data') -> None:
        """Special method for set union field, used as property.setter
        functions.

        A data field must be a non-empty sequence whose items are all
        np.ndarray or torch.Tensor. The first data field fixes the number of
        levels and creates one :class:`PixelData` per level; later data
        fields must have the same length. Item ``i`` is written to level
        ``i`` and the whole sequence is also stored on this object.

        Args:
            value (Any): The value to set.
            name (str): The field name. Names starting with ``_level_`` are
                reserved for the internal per-level storage.
            dtype (type or tuple[type], optional): If given, ``value`` must
                be an instance of it. Defaults to ``None``.
            field_type (str): ``'metainfo'`` or ``'data'``. Defaults to
                ``'data'``.

        Raises:
            AttributeError: If ``name`` is reserved, or already used by the
                other field type.
            TypeError: If a data value is not a sequence of np.ndarray or
                torch.Tensor.
            ValueError: If a data value is empty.
        """
        assert field_type in ['metainfo', 'data']
        if dtype is not None:
            assert isinstance(
                value,
                dtype), f'{value} should be a {dtype} but got {type(value)}'

        if name.startswith('_level_'):
            raise AttributeError(
                f'Cannot set {name} to be a field because the pattern '
                '<_level_{n}> is reserved for inner data field')

        if field_type == 'metainfo':
            if name in self._data_fields:
                raise AttributeError(
                    f'Cannot set {name} to be a field of metainfo '
                    f'because {name} is already a data field')
            self._metainfo_fields.add(name)

        else:
            if name in self._metainfo_fields:
                raise AttributeError(
                    f'Cannot set {name} to be a field of data '
                    f'because {name} is already a metainfo field')

            if not isinstance(value, abc.Sequence):
                raise TypeError(
                    'The value should be a sequence (of numpy.ndarray or '
                    f'torch.Tensor), but got a {type(value)}')

            if len(value) == 0:
                raise ValueError('Setting empty value is not allowed')

            # Validate every level, not only the first one, before any
            # state (level count, per-level storage) is modified.
            for item in value:
                if not isinstance(item, (torch.Tensor, np.ndarray)):
                    raise TypeError(
                        'The value should be a sequence of numpy.ndarray or '
                        f'torch.Tensor, but got an item of {type(item)}')

            if self.nlevel is not None:
                assert len(value) == self.nlevel, (
                    f'The length of the value ({len(value)}) should match the '
                    f'number of the levels ({self.nlevel})')
            else:
                # First data field: it determines the number of levels.
                object.__setattr__(self, '_nlevel', len(value))
                for i in range(self.nlevel):
                    object.__setattr__(self, f'_level_{i}', PixelData())

            for i, v in enumerate(value):
                self[i].set_field(v, name, field_type='data')

            self._data_fields.add(name)

        object.__setattr__(self, name, value)

    def __delattr__(self, item: str):
        """delete the item in dataelement.

        Metainfo fields are deleted by the base class. Data fields are
        deleted from every level, from ``_data_fields`` and from the copy
        cached on this object by :meth:`set_field`.

        Args:
            item (str): The key to delete.

        Raises:
            AttributeError: If ``item`` is a private field container or is
                not a field of this data element.
        """
        if item in ('_metainfo_fields', '_data_fields'):
            raise AttributeError(f'{item} has been used as a '
                                 'private attribute, which is immutable. ')

        if item in self._metainfo_fields:
            super().__delattr__(item)
        elif item in self._data_fields:
            for level in self.levels():
                delattr(level, item)
            self._data_fields.remove(item)
            # ``set_field`` also stores the whole sequence on this object;
            # without removing it, the deleted field would still be readable.
            self.__dict__.pop(item, None)
        else:
            raise AttributeError(
                f'\'{self.__class__.__name__}\' object has no attribute '
                f'\'{item}\'')

    def __getattr__(self, name):
        """Fallback lookup, called only when normal attribute lookup fails.

        For a registered data field, gathers its value from every level.

        Raises:
            AttributeError: For the private field containers and for any
                name that is not a data field.
        """
        # Guarding the private containers avoids infinite recursion when
        # they are looked up before ``__init__`` has created them.
        if name in {'_data_fields', '_metainfo_fields'
                    } or name not in self._data_fields:
            raise AttributeError(
                f'\'{self.__class__.__name__}\' object has no attribute '
                f'\'{name}\'')

        return [getattr(level, name) for level in self.levels()]

    def pop(self, *args) -> Any:
        """pop property in data and metainfo as the same as python.

        Args:
            *args: The field name, optionally followed by a default value
                that is returned when the field does not exist.

        Returns:
            Any: The popped metainfo value, or a list holding the popped
            value of every level for a data field.

        Raises:
            KeyError: If the field does not exist and no default is given.
        """
        assert len(args) < 3, '``pop`` get more than 2 arguments'
        name = args[0]
        if name in self._metainfo_fields:
            self._metainfo_fields.remove(name)
            return self.__dict__.pop(*args)

        if name in self._data_fields:
            self._data_fields.remove(name)
            # Drop the sequence cached by ``set_field`` so the popped field
            # can no longer be read back from this object.
            self.__dict__.pop(name, None)
            return [level.pop(*args) for level in self.levels()]

        # with default value
        if len(args) == 2:
            return args[1]
        # don't just use 'self.__dict__.pop(*args)' for only popping key in
        # metainfo or data
        raise KeyError(f'{args[0]} is not contained in metainfo or data')

    @staticmethod
    def _is_seq_of(value: Any, expected_type: Type) -> bool:
        """Return whether ``value`` is a sequence (list, tuple, ...) whose
        items are all instances of ``expected_type``."""
        return isinstance(value, abc.Sequence) and all(
            isinstance(item, expected_type) for item in value)

    def _convert(self, apply_to: Type,
                 func: Callable[[Any], Any]) -> 'MultilevelPixelData':
        """Convert data items with the given function.

        Only data fields whose items are all instances of ``apply_to`` are
        converted (any sequence type is accepted, since :meth:`set_field`
        accepts any sequence); converted fields are stored as lists. Other
        fields are left as they are in the element returned by
        ``self.new()``.

        Args:
            apply_to (Type): The type of data items to apply the conversion
            func (Callable): The conversion function that takes a data item
                as the input and return the converted result

        Returns:
            MultilevelPixelData: the converted data element.
        """
        new_data = self.new()
        for key, value in self.items():
            if self._is_seq_of(value, apply_to):
                new_data.set_data({key: [func(item) for item in value]})
        return new_data

    def cpu(self) -> 'MultilevelPixelData':
        """Convert all tensors to CPU in data."""
        return self._convert(apply_to=torch.Tensor, func=lambda x: x.cpu())

    def cuda(self) -> 'MultilevelPixelData':
        """Convert all tensors to GPU in data."""
        return self._convert(apply_to=torch.Tensor, func=lambda x: x.cuda())

    def detach(self) -> 'MultilevelPixelData':
        """Detach all tensors in data."""
        return self._convert(apply_to=torch.Tensor, func=lambda x: x.detach())

    def numpy(self) -> 'MultilevelPixelData':
        """Convert all tensors to np.ndarray in data."""
        return self._convert(
            apply_to=torch.Tensor, func=lambda x: x.detach().cpu().numpy())

    def to_tensor(self) -> 'MultilevelPixelData':
        """Convert all np.ndarray to torch.Tensor in data."""
        return self._convert(apply_to=np.ndarray, func=torch.from_numpy)

    # Tensor-like methods
    def to(self, *args, **kwargs) -> 'MultilevelPixelData':
        """Apply same name function to all tensors in data_fields.

        Only data fields whose items all have a ``to`` method are converted;
        the rest (e.g. np.ndarray fields, or mixed fields) are left as they
        are in the element returned by ``self.new()``.
        """
        new_data = self.new()
        for key, value in self.items():
            if all(hasattr(item, 'to') for item in value):
                new_data.set_data(
                    {key: [item.to(*args, **kwargs) for item in value]})
        return new_data