Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import collections | |
import copy | |
from typing import List, Optional, Sequence, Union | |
from mmengine.dataset import ConcatDataset, force_full_init | |
from mmseg.registry import DATASETS, TRANSFORMS | |
@DATASETS.register_module()
class MultiImageMixDataset:
    """A wrapper of multiple images mixed dataset.

    Suitable for training on multiple images mixed data augmentation like
    mosaic and mixup.

    Args:
        dataset (ConcatDataset or dict): The dataset to be mixed. A dict is
            treated as a config and built through the ``DATASETS`` registry.
        pipeline (Sequence[dict]): Sequence of transform config dicts. Each
            one is built through the ``TRANSFORMS`` registry; non-dict
            elements raise ``TypeError``.
        skip_type_keys (list[str], optional): Transform ``type`` names to be
            skipped in the pipeline. Defaults to None.
        lazy_init (bool): If True, do not call ``full_init`` in the
            constructor; it then runs on the first call to ``full_init``,
            ``get_data_info`` or ``__len__``. Defaults to False.
    """

    def __init__(self,
                 dataset: Union[ConcatDataset, dict],
                 pipeline: Sequence[dict],
                 skip_type_keys: Optional[List[str]] = None,
                 lazy_init: bool = False) -> None:
        assert isinstance(pipeline, collections.abc.Sequence)
        if isinstance(dataset, dict):
            self.dataset = DATASETS.build(dataset)
        elif isinstance(dataset, ConcatDataset):
            self.dataset = dataset
        else:
            raise TypeError(
                'elements in datasets sequence should be config or '
                f'`ConcatDataset` instance, but got {type(dataset)}')

        if skip_type_keys is not None:
            assert all(
                isinstance(skip_type_key, str)
                for skip_type_key in skip_type_keys)
        self._skip_type_keys = skip_type_keys

        self.pipeline = []
        # ``type`` names kept index-aligned with ``self.pipeline`` so that
        # ``__getitem__`` can match them against ``skip_type_keys``.
        self.pipeline_types = []
        for transform in pipeline:
            if isinstance(transform, dict):
                # Read the type name before building the transform.
                self.pipeline_types.append(transform['type'])
                transform = TRANSFORMS.build(transform)
                self.pipeline.append(transform)
            else:
                raise TypeError('pipeline must be a dict')

        self._metainfo = self.dataset.metainfo
        self.num_samples = len(self.dataset)

        self._fully_initialized = False
        if not lazy_init:
            self.full_init()

    @property
    def metainfo(self) -> dict:
        """Get the meta information of the multi-image-mixed dataset.

        Exposed as a property, matching how ``self.dataset.metainfo`` is read
        in ``__init__``.

        Returns:
            dict: A deep copy of the wrapped dataset's meta information, so
            callers cannot mutate the cached copy.
        """
        return copy.deepcopy(self._metainfo)

    def full_init(self):
        """Loop to ``full_init`` each dataset.

        Calls ``full_init`` on the wrapped dataset once and records its
        length; later calls return immediately.
        """
        if self._fully_initialized:
            return
        self.dataset.full_init()
        self._ori_len = len(self.dataset)
        self._fully_initialized = True

    @force_full_init
    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index.

        Args:
            idx (int): Global index of ``ConcatDataset``.

        Returns:
            dict: The idx-th annotation of the datasets.
        """
        return self.dataset.get_data_info(idx)

    @force_full_init
    def __len__(self):
        """Return the number of samples recorded at construction time."""
        return self.num_samples

    def __getitem__(self, idx):
        """Fetch sample ``idx`` and run it through the mixing pipeline.

        Transforms whose ``type`` is in ``skip_type_keys`` are skipped. A
        transform exposing ``get_indices`` is asked for the indices of extra
        samples; deep copies of those samples are passed in under
        ``results['mix_results']``. That key is removed again after each
        transform runs.

        Args:
            idx (int): Index of the base sample in the wrapped dataset.

        Returns:
            dict: The transformed sample.
        """
        results = copy.deepcopy(self.dataset[idx])
        for (transform, transform_type) in zip(self.pipeline,
                                               self.pipeline_types):
            if self._skip_type_keys is not None and \
                    transform_type in self._skip_type_keys:
                continue

            if hasattr(transform, 'get_indices'):
                indices = transform.get_indices(self.dataset)
                # A single index may be returned; normalise to a sequence.
                if not isinstance(indices, collections.abc.Sequence):
                    indices = [indices]
                # Deep copies so in-place edits by the transform cannot
                # reach the wrapped dataset's samples.
                mix_results = [
                    copy.deepcopy(self.dataset[index]) for index in indices
                ]
                results['mix_results'] = mix_results

            results = transform(results)
            # Drop the auxiliary samples so they do not leak downstream.
            results.pop('mix_results', None)

        return results

    def update_skip_type_keys(self, skip_type_keys):
        """Update skip_type_keys.

        It is called by an external hook.

        Args:
            skip_type_keys (list[str], optional): Transform ``type`` names
                to be skipped in the pipeline.
        """
        assert all(
            isinstance(skip_type_key, str) for skip_type_key in skip_type_keys)
        self._skip_type_keys = skip_type_keys