# Copyright (c) OpenMMLab. All rights reserved.
import collections
import copy
from typing import List, Optional, Sequence, Union

from mmengine.dataset import ConcatDataset, force_full_init

from mmseg.registry import DATASETS, TRANSFORMS


@DATASETS.register_module()
class MultiImageMixDataset:
"""A wrapper of multiple images mixed dataset.
Suitable for training on multiple images mixed data augmentation like
mosaic and mixup.
Args:
dataset (ConcatDataset or dict): The dataset to be mixed.
pipeline (Sequence[dict]): Sequence of transform object or
config dict to be composed.
skip_type_keys (list[str], optional): Sequence of type string to
be skip pipeline. Default to None.
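
    Examples:
        A minimal config sketch; the base dataset entries and the
        ``RandomMosaic``/``PackSegInputs`` transform choices below are
        illustrative assumptions, not requirements of this wrapper:

        >>> train_dataset = dict(
        >>>     type='MultiImageMixDataset',
        >>>     dataset=dict(
        >>>         type='ConcatDataset',
        >>>         datasets=[...]),  # base dataset configs go here
        >>>     pipeline=[
        >>>         dict(type='RandomMosaic', prob=1.0, img_scale=(512, 512)),
        >>>         dict(type='PackSegInputs')
        >>>     ])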
"""

    def __init__(self,
dataset: Union[ConcatDataset, dict],
pipeline: Sequence[dict],
skip_type_keys: Optional[List[str]] = None,
lazy_init: bool = False) -> None:
assert isinstance(pipeline, collections.abc.Sequence)
if isinstance(dataset, dict):
self.dataset = DATASETS.build(dataset)
elif isinstance(dataset, ConcatDataset):
self.dataset = dataset
else:
            raise TypeError(
                'dataset must be a config dict or a `ConcatDataset` '
                f'instance, but got {type(dataset)}')
if skip_type_keys is not None:
assert all([
isinstance(skip_type_key, str)
for skip_type_key in skip_type_keys
])
self._skip_type_keys = skip_type_keys
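
        # Build each transform from its config, recording the type string so
        # that transforms can later be skipped by type name.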
self.pipeline = []
self.pipeline_types = []
for transform in pipeline:
if isinstance(transform, dict):
self.pipeline_types.append(transform['type'])
transform = TRANSFORMS.build(transform)
self.pipeline.append(transform)
else:
                raise TypeError('each transform in pipeline must be a dict, '
                                f'but got {type(transform)}')

        self._metainfo = self.dataset.metainfo
self.num_samples = len(self.dataset)
self._fully_initialized = False
if not lazy_init:
self.full_init()

    @property
    def metainfo(self) -> dict:
        """Get the meta information of the multi-image-mixed dataset.

        Returns:
            dict: The meta information of the multi-image-mixed dataset.
        """
        return copy.deepcopy(self._metainfo)

    def full_init(self):
"""Loop to ``full_init`` each dataset."""
if self._fully_initialized:
return
self.dataset.full_init()
self._ori_len = len(self.dataset)
self._fully_initialized = True

    @force_full_init
    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index.

        Args:
            idx (int): Global index of ``ConcatDataset``.

        Returns:
            dict: The idx-th annotation of the dataset.
        """
        return self.dataset.get_data_info(idx)

    @force_full_init
def __len__(self):
return self.num_samples

    def __getitem__(self, idx):
results = copy.deepcopy(self.dataset[idx])
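        # Apply the pipeline one transform at a time so that any transform
        # whose type name is listed in ``_skip_type_keys`` can be skipped.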
for (transform, transform_type) in zip(self.pipeline,
self.pipeline_types):
if self._skip_type_keys is not None and \
transform_type in self._skip_type_keys:
continue
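            # Multi-image transforms (e.g. mosaic or mixup) expose
            # ``get_indices`` to request extra samples; those samples are
            # passed in through ``results['mix_results']``.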
if hasattr(transform, 'get_indices'):
indices = transform.get_indices(self.dataset)
if not isinstance(indices, collections.abc.Sequence):
indices = [indices]
mix_results = [
copy.deepcopy(self.dataset[index]) for index in indices
]
results['mix_results'] = mix_results
results = transform(results)
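            # Drop the auxiliary samples once the transform has consumed them.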
if 'mix_results' in results:
results.pop('mix_results')
return results

    def update_skip_type_keys(self, skip_type_keys):
        """Update skip_type_keys.

        It is called by an external hook.

        Args:
            skip_type_keys (list[str], optional): Sequence of transform
                type strings to be skipped in the pipeline.
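
        Examples:
            >>> # A hedged sketch: an external hook might disable mosaic-type
            >>> # transforms near the end of training (name is illustrative).
            >>> dataset.update_skip_type_keys(['RandomMosaic'])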
"""
assert all([
isinstance(skip_type_key, str) for skip_type_key in skip_type_keys
])
self._skip_type_keys = skip_type_keys