File size: 4,777 Bytes
3bbb319 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import unittest
from unittest.mock import MagicMock, patch
import pytest
from mmdet.datasets import DATASETS
@patch('mmdet.datasets.CocoDataset.load_annotations', MagicMock())
@patch('mmdet.datasets.CustomDataset.load_annotations', MagicMock())
@patch('mmdet.datasets.XMLDataset.load_annotations', MagicMock())
@patch('mmdet.datasets.CityscapesDataset.load_annotations', MagicMock())
@patch('mmdet.datasets.CocoDataset._filter_imgs', MagicMock)
@patch('mmdet.datasets.CustomDataset._filter_imgs', MagicMock)
@patch('mmdet.datasets.XMLDataset._filter_imgs', MagicMock)
@patch('mmdet.datasets.CityscapesDataset._filter_imgs', MagicMock)
@pytest.mark.parametrize('dataset',
                         ['CocoDataset', 'VOCDataset', 'CityscapesDataset'])
def test_custom_classes_override_default(dataset):
    """Check that the ``classes`` argument overrides a dataset's ``CLASSES``.

    ``load_annotations`` and ``_filter_imgs`` are patched out by the
    decorators, so only the class-name handling of each constructor is
    exercised. Each constructed dataset is also printed to exercise its
    ``__repr__``.

    Args:
        dataset (str): Registry name of the dataset class under test.
    """
    import tempfile

    dataset_class = DATASETS.get(dataset)
    if dataset in ['CocoDataset', 'CityscapesDataset']:
        # NOTE(review): these attributes are assigned on the class itself,
        # so they persist after this test; presumably they stand in for COCO
        # API state that is normally created while loading annotations.
        dataset_class.coco = MagicMock()
        dataset_class.cat_ids = MagicMock()
    original_classes = dataset_class.CLASSES
    img_prefix = 'VOC2007' if dataset == 'VOCDataset' else ''

    def build(classes):
        """Instantiate ``dataset_class`` in test mode with ``classes``."""
        return dataset_class(
            ann_file=MagicMock(),
            pipeline=[],
            classes=classes,
            test_mode=True,
            img_prefix=img_prefix)

    # Explicit overrides: a tuple, a list, and names that are not a subset
    # of the defaults. Each must be stored exactly as given. Comparing with
    # ``classes`` itself keeps the tuple-vs-list distinction, because a
    # tuple never compares equal to a list.
    for classes in [('bus', 'car'), ['bus', 'car'], ['foo']]:
        custom_dataset = build(classes)
        assert custom_dataset.CLASSES != original_classes
        assert custom_dataset.CLASSES == classes
        print(custom_dataset)

    # Default behavior: ``classes=None`` keeps the class-level CLASSES.
    custom_dataset = build(None)
    assert custom_dataset.CLASSES == original_classes
    print(custom_dataset)

    # Passing a path to a text file with one class name per line.
    with tempfile.TemporaryDirectory() as tmpdir:
        # Use a real path join so the file lives inside ``tmpdir`` and is
        # removed with it. Plain string concatenation created the file next
        # to the directory instead, and it was never cleaned up.
        path = osp.join(tmpdir, 'classes.txt')
        with open(path, 'w') as f:
            f.write('bus\ncar\n')
        custom_dataset = build(path)
        assert custom_dataset.CLASSES != original_classes
        assert custom_dataset.CLASSES == ['bus', 'car']
        print(custom_dataset)
class CustomDatasetTests(unittest.TestCase):
    """Check ``XMLDataset`` image records read from on-disk test fixtures."""

    def setUp(self):
        super().setUp()
        # Fixtures live in the ``data`` directory two levels above the
        # directory containing this file.
        tests_root = osp.dirname(osp.dirname(osp.dirname(__file__)))
        self.data_dir = osp.join(tests_root, 'data')
        self.dataset_class = DATASETS.get('XMLDataset')

    def test_data_infos__default_db_directories(self):
        """Test correct data read having a Pascal-VOC directory structure."""
        voc_root = osp.join(self.data_dir, 'VOCdevkit', 'VOC2007')
        split_file = osp.join(voc_root, 'ImageSets', 'Main', 'trainval.txt')
        dataset = self.dataset_class(
            data_root=voc_root,
            ann_file=split_file,
            pipeline=[],
            classes=('person', 'dog'),
            test_mode=True)
        expected_infos = [
            dict(
                id='000001',
                filename=osp.join('JPEGImages', '000001.jpg'),
                width=353,
                height=500),
        ]
        self.assertListEqual(expected_infos, dataset.data_infos)

    def test_data_infos__overridden_db_subdirectories(self):
        """Test correct data read having a customized directory structure."""
        custom_root = osp.join(self.data_dir, 'custom_dataset')
        split_file = osp.join(custom_root, 'trainval.txt')
        # Images and annotations are both looked up under ``images`` here,
        # overriding the default sub-directory names.
        dataset = self.dataset_class(
            data_root=custom_root,
            ann_file=split_file,
            pipeline=[],
            classes=('person', 'dog'),
            test_mode=True,
            img_prefix='',
            img_subdir='images',
            ann_subdir='images')
        expected_infos = [
            dict(
                id='000001',
                filename=osp.join('images', '000001.jpg'),
                width=353,
                height=500),
        ]
        self.assertListEqual(expected_infos, dataset.data_infos)
|