Spaces:
Sleeping
Sleeping
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp | |
from copy import deepcopy | |
from unittest import TestCase | |
from unittest.mock import MagicMock | |
from mmengine.registry import init_default_scope | |
from mmocr.datasets import ConcatDataset, OCRDataset | |
from mmocr.registry import TRANSFORMS | |
class TestConcatDataset(TestCase):
    """Unit tests for ``ConcatDataset`` built on toy ``OCRDataset`` objects.

    Every test works on deep copies of the fixtures created in ``setUp`` so
    that one test cannot influence another through shared dataset state.
    """

    # The pipeline configs below refer to this class by name
    # (``type='MockTransform'``), so it has to be registered in TRANSFORMS.
    @TRANSFORMS.register_module()
    class MockTransform:
        """Transform stub that ignores its input and returns a fixed value."""

        def __init__(self, return_value):
            self.return_value = return_value

        def __call__(self, *args, **kwargs):
            return self.return_value

    @staticmethod
    def _pipeline(value):
        """Return a fresh one-step pipeline config yielding ``value``."""
        return [dict(type='MockTransform', return_value=value)]

    @staticmethod
    def _toy_dataset(**kwargs):
        """Build an ``OCRDataset`` over the toy detection annotation file.

        Args:
            **kwargs: Extra keyword arguments forwarded to ``OCRDataset``
                (e.g. ``pipeline`` or ``metainfo``).
        """
        return OCRDataset(
            data_root=osp.join(
                osp.dirname(__file__), '../data/det_toy_dataset'),
            data_prefix=dict(img_path='imgs'),
            ann_file='textdet_test.json',
            **kwargs)

    def setUp(self):
        """Create datasets ``a`` and ``b``, each with and without pipeline."""
        init_default_scope('mmocr')

        # ``parse_data_info`` is replaced on the class itself, so undo the
        # patch after each test; otherwise the stub leaks into other tests
        # running in the same process. Restore whatever OCRDataset declared
        # itself, or drop the override if the method was only inherited.
        declared = vars(OCRDataset)
        if 'parse_data_info' in declared:
            self.addCleanup(setattr, OCRDataset, 'parse_data_info',
                            declared['parse_data_info'])
        else:
            self.addCleanup(delattr, OCRDataset, 'parse_data_info')

        # create dataset_a
        data_info = dict(filename='img_1.jpg', height=720, width=1280)
        OCRDataset.parse_data_info = MagicMock(return_value=data_info)
        self.dataset_a = self._toy_dataset()
        self.dataset_a_with_pipeline = self._toy_dataset(
            pipeline=self._pipeline(1))

        # create dataset_b
        data_info = dict(filename='img_2.jpg', height=720, width=1280)
        OCRDataset.parse_data_info = MagicMock(return_value=data_info)
        self.dataset_b = self._toy_dataset()
        self.dataset_b_with_pipeline = self._toy_dataset(
            pipeline=self._pipeline(2))

    def test_init(self):
        """Check the argument validation performed on construction."""
        # Entries of ``datasets`` that are not datasets are rejected.
        with self.assertRaises(TypeError):
            ConcatDataset(datasets=[0])

        # A wrapper-level pipeline combined with a member dataset that
        # already has its own pipeline is rejected, whichever member it is.
        # Fixtures are built outside the ``assertRaises`` body so that only
        # the ``ConcatDataset`` call can satisfy the assertion.
        members = [
            deepcopy(self.dataset_a_with_pipeline),
            deepcopy(self.dataset_b)
        ]
        with self.assertRaises(ValueError):
            ConcatDataset(datasets=members, pipeline=self._pipeline(3))

        members = [
            deepcopy(self.dataset_a),
            deepcopy(self.dataset_b_with_pipeline)
        ]
        with self.assertRaises(ValueError):
            ConcatDataset(datasets=members, pipeline=self._pipeline(3))

        # Member datasets constructed with different ``metainfo`` are
        # rejected.
        dataset_a = deepcopy(self.dataset_a)
        dataset_b = self._toy_dataset(metainfo=dict(dummy='dummy'))
        with self.assertRaises(ValueError):
            ConcatDataset(datasets=[dataset_a, dataset_b])

        # test lazy init
        ConcatDataset(
            datasets=[deepcopy(self.dataset_a),
                      deepcopy(self.dataset_b)],
            pipeline=self._pipeline(3),
            lazy_init=True)

    def test_getitem(self):
        """Check which pipeline produces the items of the wrapper."""
        # Members without pipelines: the wrapper pipeline yields each item.
        cat_datasets = ConcatDataset(
            datasets=[deepcopy(self.dataset_a),
                      deepcopy(self.dataset_b)],
            pipeline=self._pipeline(3))
        for datum in cat_datasets:
            self.assertEqual(datum, 3)

        # With ``force_apply=True`` the wrapper pipeline is used even though
        # one member dataset defines its own pipeline.
        cat_datasets = ConcatDataset(
            datasets=[
                deepcopy(self.dataset_a_with_pipeline),
                deepcopy(self.dataset_b)
            ],
            pipeline=self._pipeline(3),
            force_apply=True)
        for datum in cat_datasets:
            self.assertEqual(datum, 3)

        # Without a wrapper pipeline each member's own pipeline is used.
        cat_datasets = ConcatDataset(datasets=[
            deepcopy(self.dataset_a_with_pipeline),
            deepcopy(self.dataset_b_with_pipeline)
        ])
        self.assertEqual(cat_datasets[0], 1)
        self.assertEqual(cat_datasets[-1], 2)