File size: 4,324 Bytes
0b4516f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from copy import deepcopy
from unittest import TestCase
from unittest.mock import MagicMock

from mmengine.registry import init_default_scope

from mmocr.datasets import ConcatDataset, OCRDataset
from mmocr.registry import TRANSFORMS


class TestConcatDataset(TestCase):
    """Unit tests for ``ConcatDataset`` built on top of toy ``OCRDataset``s."""

    @TRANSFORMS.register_module()
    class MockTransform:
        """Transform stub that ignores its input and returns a fixed value."""

        def __init__(self, return_value):
            self.return_value = return_value

        def __call__(self, *args, **kwargs):
            return self.return_value

    @staticmethod
    def _build_dataset(**kwargs):
        """Build an ``OCRDataset`` over the toy detection annotation file.

        Args:
            **kwargs: Extra keyword arguments forwarded to ``OCRDataset``
                (e.g. ``pipeline`` or ``metainfo``).

        Returns:
            OCRDataset: The constructed dataset.
        """
        return OCRDataset(
            data_root=osp.join(
                osp.dirname(__file__), '../data/det_toy_dataset'),
            data_prefix=dict(img_path='imgs'),
            ann_file='textdet_test.json',
            **kwargs)

    def setUp(self):
        """Create two datasets (with and without pipeline) per mocked info."""
        # Imported locally so this block does not depend on the module-level
        # import list being changed.
        from unittest.mock import patch

        init_default_scope('mmocr')

        # Patch ``parse_data_info`` for the duration of a single test only.
        # Assigning the mock directly onto the class would leave it in place
        # after the test finishes and leak into unrelated tests that use
        # ``OCRDataset``.
        data_info_a = dict(filename='img_1.jpg', height=720, width=1280)
        patcher = patch.object(
            OCRDataset,
            'parse_data_info',
            new=MagicMock(return_value=data_info_a))
        parse_mock = patcher.start()
        self.addCleanup(patcher.stop)

        # create dataset_a
        self.dataset_a = self._build_dataset()
        self.dataset_a_with_pipeline = self._build_dataset(
            pipeline=[dict(type='MockTransform', return_value=1)])

        # create dataset_b; the mock stays active so that datasets built
        # later inside the tests keep seeing the mocked parser as well.
        data_info_b = dict(filename='img_2.jpg', height=720, width=1280)
        parse_mock.return_value = data_info_b
        self.dataset_b = self._build_dataset()
        self.dataset_b_with_pipeline = self._build_dataset(
            pipeline=[dict(type='MockTransform', return_value=2)])

    def test_init(self):
        """Check the argument validation done by ``ConcatDataset``."""
        with self.assertRaises(TypeError):
            ConcatDataset(datasets=[0])

        # A ConcatDataset-level pipeline combined with a child dataset that
        # already has its own pipeline is expected to be rejected (unless
        # ``force_apply`` is set, see ``test_getitem``).
        with self.assertRaises(ValueError):
            ConcatDataset(
                datasets=[
                    deepcopy(self.dataset_a_with_pipeline),
                    deepcopy(self.dataset_b)
                ],
                pipeline=[dict(type='MockTransform', return_value=3)])

        with self.assertRaises(ValueError):
            ConcatDataset(
                datasets=[
                    deepcopy(self.dataset_a),
                    deepcopy(self.dataset_b_with_pipeline)
                ],
                pipeline=[dict(type='MockTransform', return_value=3)])

        # Build the fixtures outside ``assertRaises`` so that only the
        # ``ConcatDataset`` call itself can satisfy the expectation; an
        # error raised while preparing the inputs must fail the test.
        dataset_a = deepcopy(self.dataset_a)
        dataset_b = self._build_dataset(metainfo=dict(dummy='dummy'))
        with self.assertRaises(ValueError):
            ConcatDataset(datasets=[dataset_a, dataset_b])

        # test lazy init
        ConcatDataset(
            datasets=[deepcopy(self.dataset_a),
                      deepcopy(self.dataset_b)],
            pipeline=[dict(type='MockTransform', return_value=3)],
            lazy_init=True)

    def test_getitem(self):
        """Check which pipeline produces the items of a ``ConcatDataset``."""
        # Only the ConcatDataset-level pipeline is configured.
        cat_datasets = ConcatDataset(
            datasets=[deepcopy(self.dataset_a),
                      deepcopy(self.dataset_b)],
            pipeline=[dict(type='MockTransform', return_value=3)])
        for datum in cat_datasets:
            self.assertEqual(datum, 3)

        # With ``force_apply`` the ConcatDataset-level pipeline is used even
        # though one child dataset carries its own pipeline.
        cat_datasets = ConcatDataset(
            datasets=[
                deepcopy(self.dataset_a_with_pipeline),
                deepcopy(self.dataset_b)
            ],
            pipeline=[dict(type='MockTransform', return_value=3)],
            force_apply=True)
        for datum in cat_datasets:
            self.assertEqual(datum, 3)

        # Without a ConcatDataset-level pipeline each child's own pipeline
        # decides the output.
        cat_datasets = ConcatDataset(datasets=[
            deepcopy(self.dataset_a_with_pipeline),
            deepcopy(self.dataset_b_with_pipeline)
        ])
        self.assertEqual(cat_datasets[0], 1)
        self.assertEqual(cat_datasets[-1], 2)