Upload 303 files
This view is limited to 50 files because it contains too many changes.
- .gitattributes +2 -0
- mmpretrain/__init__.py +28 -0
- mmpretrain/annotations/WHU_building_test.json +3 -0
- mmpretrain/annotations/WHU_building_train.json +3 -0
- mmpretrain/annotations/WHU_building_val.json +0 -0
- mmpretrain/apis/__init__.py +22 -0
- mmpretrain/apis/base.py +390 -0
- mmpretrain/apis/feature_extractor.py +128 -0
- mmpretrain/apis/image_caption.py +164 -0
- mmpretrain/apis/image_classification.py +221 -0
- mmpretrain/apis/image_retrieval.py +285 -0
- mmpretrain/apis/model.py +408 -0
- mmpretrain/apis/multimodal_retrieval.py +603 -0
- mmpretrain/apis/nlvr.py +150 -0
- mmpretrain/apis/utils.py +270 -0
- mmpretrain/apis/visual_grounding.py +180 -0
- mmpretrain/apis/visual_question_answering.py +181 -0
- mmpretrain/datasets/__init__.py +54 -0
- mmpretrain/datasets/base_dataset.py +219 -0
- mmpretrain/datasets/builder.py +25 -0
- mmpretrain/datasets/caltech101.py +113 -0
- mmpretrain/datasets/categories.py +1440 -0
- mmpretrain/datasets/cifar.py +210 -0
- mmpretrain/datasets/coco_caption.py +42 -0
- mmpretrain/datasets/coco_retrieval.py +77 -0
- mmpretrain/datasets/coco_vqa.py +114 -0
- mmpretrain/datasets/cub.py +142 -0
- mmpretrain/datasets/custom.py +287 -0
- mmpretrain/datasets/dataset_wrappers.py +176 -0
- mmpretrain/datasets/dtd.py +116 -0
- mmpretrain/datasets/fgvcaircraft.py +98 -0
- mmpretrain/datasets/flamingo.py +295 -0
- mmpretrain/datasets/flowers102.py +104 -0
- mmpretrain/datasets/food101.py +102 -0
- mmpretrain/datasets/imagenet.py +102 -0
- mmpretrain/datasets/inshop.py +157 -0
- mmpretrain/datasets/mnist.py +220 -0
- mmpretrain/datasets/multi_label.py +85 -0
- mmpretrain/datasets/multi_task.py +337 -0
- mmpretrain/datasets/nlvr2.py +36 -0
- mmpretrain/datasets/oxfordiiitpet.py +97 -0
- mmpretrain/datasets/places205.py +40 -0
- mmpretrain/datasets/refcoco.py +81 -0
- mmpretrain/datasets/samplers/__init__.py +5 -0
- mmpretrain/datasets/samplers/repeat_aug.py +101 -0
- mmpretrain/datasets/samplers/sequential.py +56 -0
- mmpretrain/datasets/scienceqa.py +104 -0
- mmpretrain/datasets/stanfordcars.py +148 -0
- mmpretrain/datasets/sun397.py +225 -0
- mmpretrain/datasets/transforms/__init__.py +36 -0
.gitattributes
CHANGED
@@ -35,3 +35,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 data/WHU/annotations/WHU_building_test.json filter=lfs diff=lfs merge=lfs -text
 data/WHU/annotations/WHU_building_train.json filter=lfs diff=lfs merge=lfs -text
+mmpretrain/annotations/WHU_building_test.json filter=lfs diff=lfs merge=lfs -text
+mmpretrain/annotations/WHU_building_train.json filter=lfs diff=lfs merge=lfs -text
mmpretrain/__init__.py
ADDED
@@ -0,0 +1,28 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import mmengine
from mmengine.utils import digit_version

from .apis import *  # noqa: F401, F403
from .version import __version__

mmcv_minimum_version = '2.0.0rc4'
mmcv_maximum_version = '2.1.0'
mmcv_version = digit_version(mmcv.__version__)

mmengine_minimum_version = '0.7.1'
mmengine_maximum_version = '1.0.0'
mmengine_version = digit_version(mmengine.__version__)

assert (mmcv_version >= digit_version(mmcv_minimum_version)
        and mmcv_version < digit_version(mmcv_maximum_version)), \
    f'MMCV=={mmcv.__version__} is used but incompatible. ' \
    f'Please install mmcv>={mmcv_minimum_version}, <{mmcv_maximum_version}.'

assert (mmengine_version >= digit_version(mmengine_minimum_version)
        and mmengine_version < digit_version(mmengine_maximum_version)), \
    f'MMEngine=={mmengine.__version__} is used but incompatible. ' \
    f'Please install mmengine>={mmengine_minimum_version}, ' \
    f'<{mmengine_maximum_version}.'

__all__ = ['__version__']
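Importing the package runs the two version guards above. A minimal sketch of the comparison behavior they rely on, not part of this commit (the version strings below are illustrative):

# Hedged sketch: how `digit_version` orders release candidates against
# final releases, which is what makes the asserts above work.
from mmengine.utils import digit_version

assert digit_version('2.0.0rc4') < digit_version('2.0.0')   # rc precedes final
assert digit_version('2.0.0') >= digit_version('2.0.0rc4')  # passes the lower bound
assert digit_version('2.0.0') < digit_version('2.1.0')      # passes the upper bound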
mmpretrain/annotations/WHU_building_test.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c5845dd19a3ec84aa3bc978ad5dc8066b43569c4ac9ff12c954d96208ec13432
size 13511169
mmpretrain/annotations/WHU_building_train.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:28c490b7c80e6900a5b4da522faee91c6251589a0c9ebb258e79221c2586d2fa
size 42910976
mmpretrain/annotations/WHU_building_val.json
ADDED
The diff for this file is too large to render.
mmpretrain/apis/__init__.py
ADDED
@@ -0,0 +1,22 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base import BaseInferencer
from .feature_extractor import FeatureExtractor
from .image_caption import ImageCaptionInferencer
from .image_classification import ImageClassificationInferencer
from .image_retrieval import ImageRetrievalInferencer
from .model import (ModelHub, get_model, inference_model, init_model,
                    list_models)
from .multimodal_retrieval import (ImageToTextRetrievalInferencer,
                                   TextToImageRetrievalInferencer)
from .nlvr import NLVRInferencer
from .visual_grounding import VisualGroundingInferencer
from .visual_question_answering import VisualQuestionAnsweringInferencer

__all__ = [
    'init_model', 'inference_model', 'list_models', 'get_model', 'ModelHub',
    'ImageClassificationInferencer', 'ImageRetrievalInferencer',
    'FeatureExtractor', 'ImageCaptionInferencer',
    'TextToImageRetrievalInferencer', 'VisualGroundingInferencer',
    'VisualQuestionAnsweringInferencer', 'ImageToTextRetrievalInferencer',
    'BaseInferencer', 'NLVRInferencer'
]
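The re-exported helpers give a functional entry point on top of the inferencer classes. A hedged usage sketch, not part of this diff (the model name and image path are illustrative placeholders):

# Hedged sketch: using the functional helpers exported above.
from mmpretrain.apis import get_model, inference_model, list_models

print(list_models('resnet*')[:3])  # wildcard query over registered metafiles
# 'resnet18_8xb32_in1k' and 'demo/demo.JPEG' are assumed placeholders;
# for a classification model the result should carry a 'pred_class' key.
result = inference_model('resnet18_8xb32_in1k', 'demo/demo.JPEG')
print(result['pred_class'])
model = get_model('resnet18_8xb32_in1k', pretrained=True)  # bare nn.Module access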
mmpretrain/apis/base.py
ADDED
@@ -0,0 +1,390 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import abstractmethod
from math import ceil
from typing import Callable, Iterable, List, Optional, Tuple, Union

import numpy as np
import torch
from mmengine.config import Config
from mmengine.dataset import default_collate
from mmengine.fileio import get_file_backend
from mmengine.model import BaseModel
from mmengine.runner import load_checkpoint

from mmpretrain.structures import DataSample
from mmpretrain.utils import track
from .model import get_model, list_models

ModelType = Union[BaseModel, str, Config]
InputType = Union[str, np.ndarray, list]


class BaseInferencer:
    """Base inferencer for various tasks.

    The BaseInferencer provides the standard workflow for inference as follows:

    1. Preprocess the input data by :meth:`preprocess`.
    2. Forward the data to the model by :meth:`forward`. ``BaseInferencer``
       assumes the model inherits from :class:`mmengine.models.BaseModel` and
       will call `model.test_step` in :meth:`forward` by default.
    3. Visualize the results by :meth:`visualize`.
    4. Postprocess and return the results by :meth:`postprocess`.

    When we call the subclasses inherited from BaseInferencer (not overriding
    ``__call__``), the workflow will be executed in order.

    All subclasses of BaseInferencer could define the following class
    attributes for customization:

    - ``preprocess_kwargs``: The keys of the kwargs that will be passed to
      :meth:`preprocess`.
    - ``forward_kwargs``: The keys of the kwargs that will be passed to
      :meth:`forward`
    - ``visualize_kwargs``: The keys of the kwargs that will be passed to
      :meth:`visualize`
    - ``postprocess_kwargs``: The keys of the kwargs that will be passed to
      :meth:`postprocess`

    All attributes mentioned above should be a ``set`` of keys (strings),
    and each key should not be duplicated. Actually, :meth:`__call__` will
    dispatch all the arguments to the corresponding methods according to the
    ``xxx_kwargs`` mentioned above.

    Subclasses inherited from ``BaseInferencer`` should implement
    :meth:`_init_pipeline`, :meth:`visualize` and :meth:`postprocess`:

    - _init_pipeline: Return a callable object to preprocess the input data.
    - visualize: Visualize the results returned by :meth:`forward`.
    - postprocess: Postprocess the results returned by :meth:`forward` and
      :meth:`visualize`.

    Args:
        model (BaseModel | str | Config): A model name or a path to the config
            file, or a :obj:`BaseModel` object. The model name can be found
            by ``cls.list_models()`` and you can also query it in
            :doc:`/modelzoo_statistics`.
        pretrained (str, optional): Path to the checkpoint. If None, it will
            try to find a pre-defined weight from the model you specified
            (only work if the ``model`` is a model name). Defaults to None.
        device (str | torch.device | None): Transfer the model to the target
            device. Defaults to None.
        device_map (str | dict | None): A map that specifies where each
            submodule should go. It doesn't need to be refined to each
            parameter/buffer name, once a given module name is inside, every
            submodule of it will be sent to the same device. You can use
            `device_map="auto"` to automatically generate the device map.
            Defaults to None.
        offload_folder (str | None): If the `device_map` contains any value
            `"disk"`, the folder where we will offload weights.
        **kwargs: Other keyword arguments to initialize the model (only work if
            the ``model`` is a model name).
    """

    preprocess_kwargs: set = set()
    forward_kwargs: set = set()
    visualize_kwargs: set = set()
    postprocess_kwargs: set = set()

    def __init__(self,
                 model: ModelType,
                 pretrained: Union[bool, str] = True,
                 device: Union[str, torch.device, None] = None,
                 device_map=None,
                 offload_folder=None,
                 **kwargs) -> None:

        if isinstance(model, BaseModel):
            if isinstance(pretrained, str):
                load_checkpoint(model, pretrained, map_location='cpu')
            if device_map is not None:
                from .utils import dispatch_model
                model = dispatch_model(
                    model,
                    device_map=device_map,
                    offload_folder=offload_folder)
            elif device is not None:
                model.to(device)
        else:
            model = get_model(
                model,
                pretrained,
                device=device,
                device_map=device_map,
                offload_folder=offload_folder,
                **kwargs)

        model.eval()

        self.config = model._config
        self.model = model
        self.pipeline = self._init_pipeline(self.config)
        self.visualizer = None

    def __call__(
        self,
        inputs,
        return_datasamples: bool = False,
        batch_size: int = 1,
        **kwargs,
    ) -> dict:
        """Call the inferencer.

        Args:
            inputs (InputsType): Inputs for the inferencer.
            return_datasamples (bool): Whether to return results as
                :obj:`BaseDataElement`. Defaults to False.
            batch_size (int): Batch size. Defaults to 1.
            **kwargs: Key words arguments passed to :meth:`preprocess`,
                :meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
                Each key in kwargs should be in the corresponding set of
                ``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs``
                and ``postprocess_kwargs``.

        Returns:
            dict: Inference and visualization results.
        """
        (
            preprocess_kwargs,
            forward_kwargs,
            visualize_kwargs,
            postprocess_kwargs,
        ) = self._dispatch_kwargs(**kwargs)

        ori_inputs = self._inputs_to_list(inputs)
        inputs = self.preprocess(
            ori_inputs, batch_size=batch_size, **preprocess_kwargs)
        preds = []
        for data in track(
                inputs, 'Inference', total=ceil(len(ori_inputs) / batch_size)):
            preds.extend(self.forward(data, **forward_kwargs))
        visualization = self.visualize(ori_inputs, preds, **visualize_kwargs)
        results = self.postprocess(preds, visualization, return_datasamples,
                                   **postprocess_kwargs)
        return results

    def _inputs_to_list(self, inputs: InputType) -> list:
        """Preprocess the inputs to a list.

        Cast the input data to a list of data.

        - list or tuple: return inputs
        - str:
            - Directory path: return all files in the directory
            - other cases: return a list containing the string. The string
              could be a path to file, a url or other types of string according
              to the task.
        - other: return a list with one item.

        Args:
            inputs (str | array | list): Inputs for the inferencer.

        Returns:
            list: List of input for the :meth:`preprocess`.
        """
        if isinstance(inputs, str):
            backend = get_file_backend(inputs)
            if hasattr(backend, 'isdir') and backend.isdir(inputs):
                # Backends like HttpsBackend do not implement `isdir`, so only
                # those backends that implement `isdir` could accept the inputs
                # as a directory
                file_list = backend.list_dir_or_file(inputs, list_dir=False)
                inputs = [
                    backend.join_path(inputs, file) for file in file_list
                ]

        if not isinstance(inputs, (list, tuple)):
            inputs = [inputs]

        return list(inputs)

    def preprocess(self, inputs: InputType, batch_size: int = 1, **kwargs):
        """Process the inputs into a model-feedable format.

        Customize your preprocess by overriding this method. Preprocess should
        return an iterable object, of which each item will be used as the
        input of ``model.test_step``.

        ``BaseInferencer.preprocess`` will return an iterable chunked data,
        which will be used in __call__ like this:

        .. code-block:: python

            def __call__(self, inputs, batch_size=1, **kwargs):
                chunked_data = self.preprocess(inputs, batch_size, **kwargs)
                for batch in chunked_data:
                    preds = self.forward(batch, **kwargs)

        Args:
            inputs (InputsType): Inputs given by user.
            batch_size (int): batch size. Defaults to 1.

        Yields:
            Any: Data processed by the ``pipeline`` and ``default_collate``.
        """
        chunked_data = self._get_chunk_data(
            map(self.pipeline, inputs), batch_size)
        yield from map(default_collate, chunked_data)

    @torch.no_grad()
    def forward(self, inputs: Union[dict, tuple], **kwargs):
        """Feed the inputs to the model."""
        return self.model.test_step(inputs)

    def visualize(self,
                  inputs: list,
                  preds: List[DataSample],
                  show: bool = False,
                  **kwargs) -> List[np.ndarray]:
        """Visualize predictions.

        Customize your visualization by overriding this method. visualize
        should return visualization results, which could be np.ndarray or any
        other objects.

        Args:
            inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`.
            preds (Any): Predictions of the model.
            show (bool): Whether to display the image in a popup window.
                Defaults to False.

        Returns:
            List[np.ndarray]: Visualization results.
        """
        if show:
            raise NotImplementedError(
                f'The `visualize` method of {self.__class__.__name__} '
                'is not implemented.')

    @abstractmethod
    def postprocess(
        self,
        preds: List[DataSample],
        visualization: List[np.ndarray],
        return_datasample=False,
        **kwargs,
    ) -> dict:
        """Process the predictions and visualization results from ``forward``
        and ``visualize``.

        This method should be responsible for the following tasks:

        1. Convert datasamples into a json-serializable dict if needed.
        2. Pack the predictions and visualization results and return them.
        3. Dump or log the predictions.

        Customize your postprocess by overriding this method. Make sure
        ``postprocess`` will return a dict with visualization results and
        inference results.

        Args:
            preds (List[Dict]): Predictions of the model.
            visualization (np.ndarray): Visualized predictions.
            return_datasample (bool): Whether to return results as datasamples.
                Defaults to False.

        Returns:
            dict: Inference and visualization results with key ``predictions``
            and ``visualization``

            - ``visualization (Any)``: Returned by :meth:`visualize`
            - ``predictions`` (dict or DataSample): Returned by
              :meth:`forward` and processed in :meth:`postprocess`.
              If ``return_datasample=False``, it usually should be a
              json-serializable dict containing only basic data elements such
              as strings and numbers.
        """

    @abstractmethod
    def _init_pipeline(self, cfg: Config) -> Callable:
        """Initialize the test pipeline.

        Return a pipeline to handle various input data, such as ``str``,
        ``np.ndarray``. It is an abstract method in BaseInferencer, and should
        be implemented in subclasses.

        The returned pipeline will be used to process a single data.
        It will be used in :meth:`preprocess` like this:

        .. code-block:: python
            def preprocess(self, inputs, batch_size, **kwargs):
                ...
                dataset = map(self.pipeline, dataset)
                ...
        """

    def _get_chunk_data(self, inputs: Iterable, chunk_size: int):
        """Get batch data from dataset.

        Args:
            inputs (Iterable): An iterable dataset.
            chunk_size (int): Equivalent to batch size.

        Yields:
            list: batch data.
        """
        inputs_iter = iter(inputs)
        while True:
            try:
                chunk_data = []
                for _ in range(chunk_size):
                    processed_data = next(inputs_iter)
                    chunk_data.append(processed_data)
                yield chunk_data
            except StopIteration:
                if chunk_data:
                    yield chunk_data
                break

    def _dispatch_kwargs(self, **kwargs) -> Tuple[dict, dict, dict, dict]:
        """Dispatch kwargs to preprocess(), forward(), visualize() and
        postprocess() according to the actual demands.

        Returns:
            Tuple[Dict, Dict, Dict, Dict]: kwargs passed to preprocess,
            forward, visualize and postprocess respectively.
        """
        # Ensure each argument only matches one function
        method_kwargs = self.preprocess_kwargs | self.forward_kwargs | \
            self.visualize_kwargs | self.postprocess_kwargs

        union_kwargs = method_kwargs | set(kwargs.keys())
        if union_kwargs != method_kwargs:
            unknown_kwargs = union_kwargs - method_kwargs
            raise ValueError(
                f'unknown argument {unknown_kwargs} for `preprocess`, '
                '`forward`, `visualize` and `postprocess`')

        preprocess_kwargs = {}
        forward_kwargs = {}
        visualize_kwargs = {}
        postprocess_kwargs = {}

        for key, value in kwargs.items():
            if key in self.preprocess_kwargs:
                preprocess_kwargs[key] = value
            if key in self.forward_kwargs:
                forward_kwargs[key] = value
            if key in self.visualize_kwargs:
                visualize_kwargs[key] = value
            if key in self.postprocess_kwargs:
                postprocess_kwargs[key] = value

        return (
            preprocess_kwargs,
            forward_kwargs,
            visualize_kwargs,
            postprocess_kwargs,
        )

    @staticmethod
    def list_models(pattern: Optional[str] = None):
        """List models defined in metafile of corresponding packages.

        Args:
            pattern (str | None): A wildcard pattern to match model names.

        Returns:
            List[str]: a list of model names.
        """
        return list_models(pattern=pattern)
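Since `_init_pipeline` and `postprocess` are the only abstract hooks, the smallest working subclass needs just those two; chunking, kwarg dispatch, and `test_step` forwarding are inherited. A hedged sketch of a hypothetical subclass, not part of this commit:

# Hypothetical subclass illustrating the extension points of BaseInferencer.
from mmengine.config import Config
from mmengine.dataset import Compose
from mmpretrain.apis import BaseInferencer
from mmpretrain.registry import TRANSFORMS

class MinimalInferencer(BaseInferencer):

    def _init_pipeline(self, cfg: Config):
        # Reuse the test-time pipeline from the model config; the subclasses
        # in the files that follow additionally strip `LoadImageFromFile`.
        return Compose(
            [TRANSFORMS.build(t) for t in cfg.test_dataloader.dataset.pipeline])

    def postprocess(self, preds, visualization, return_datasamples=False,
                    **kwargs):
        # Either hand back raw data samples or a plain list of label indices.
        if return_datasamples:
            return preds
        return [p.pred_label.item() for p in preds]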
mmpretrain/apis/feature_extractor.py
ADDED
@@ -0,0 +1,128 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable, List, Optional, Union

import torch
from mmcv.image import imread
from mmengine.config import Config
from mmengine.dataset import Compose, default_collate

from mmpretrain.registry import TRANSFORMS
from .base import BaseInferencer, InputType
from .model import list_models


class FeatureExtractor(BaseInferencer):
    """The inferencer for extract features.

    Args:
        model (BaseModel | str | Config): A model name or a path to the config
            file, or a :obj:`BaseModel` object. The model name can be found
            by ``FeatureExtractor.list_models()`` and you can also query it in
            :doc:`/modelzoo_statistics`.
        pretrained (str, optional): Path to the checkpoint. If None, it will
            try to find a pre-defined weight from the model you specified
            (only work if the ``model`` is a model name). Defaults to None.
        device (str, optional): Device to run inference. If None, the available
            device will be automatically used. Defaults to None.
        **kwargs: Other keyword arguments to initialize the model (only work if
            the ``model`` is a model name).

    Example:
        >>> from mmpretrain import FeatureExtractor
        >>> inferencer = FeatureExtractor('resnet50_8xb32_in1k', backbone=dict(out_indices=(0, 1, 2, 3)))
        >>> feats = inferencer('demo/demo.JPEG', stage='backbone')[0]
        >>> for feat in feats:
        >>>     print(feat.shape)
        torch.Size([256, 56, 56])
        torch.Size([512, 28, 28])
        torch.Size([1024, 14, 14])
        torch.Size([2048, 7, 7])
    """  # noqa: E501

    def __call__(self,
                 inputs: InputType,
                 batch_size: int = 1,
                 **kwargs) -> dict:
        """Call the inferencer.

        Args:
            inputs (str | array | list): The image path or array, or a list of
                images.
            batch_size (int): Batch size. Defaults to 1.
            **kwargs: Other keyword arguments accepted by the `extract_feat`
                method of the model.

        Returns:
            tensor | Tuple[tensor]: The extracted features.
        """
        ori_inputs = self._inputs_to_list(inputs)
        inputs = self.preprocess(ori_inputs, batch_size=batch_size)
        preds = []
        for data in inputs:
            preds.extend(self.forward(data, **kwargs))

        return preds

    @torch.no_grad()
    def forward(self, inputs: Union[dict, tuple], **kwargs):
        inputs = self.model.data_preprocessor(inputs, False)['inputs']
        outputs = self.model.extract_feat(inputs, **kwargs)

        def scatter(feats, index):
            if isinstance(feats, torch.Tensor):
                return feats[index]
            else:
                # Sequence of tensor
                return type(feats)([scatter(item, index) for item in feats])

        results = []
        for i in range(inputs.shape[0]):
            results.append(scatter(outputs, i))

        return results

    def _init_pipeline(self, cfg: Config) -> Callable:
        test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
        if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile':
            # Image loading is finished in `self.preprocess`.
            test_pipeline_cfg = test_pipeline_cfg[1:]
        test_pipeline = Compose(
            [TRANSFORMS.build(t) for t in test_pipeline_cfg])
        return test_pipeline

    def preprocess(self, inputs: List[InputType], batch_size: int = 1):

        def load_image(input_):
            img = imread(input_)
            if img is None:
                raise ValueError(f'Failed to read image {input_}.')
            return dict(
                img=img,
                img_shape=img.shape[:2],
                ori_shape=img.shape[:2],
            )

        pipeline = Compose([load_image, self.pipeline])

        chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size)
        yield from map(default_collate, chunked_data)

    def visualize(self):
        raise NotImplementedError(
            "The FeatureExtractor doesn't support visualization.")

    def postprocess(self):
        raise NotImplementedError(
            "The FeatureExtractor doesn't need postprocessing.")

    @staticmethod
    def list_models(pattern: Optional[str] = None):
        """List all available model names.

        Args:
            pattern (str | None): A wildcard pattern to match model names.

        Returns:
            List[str]: a list of model names.
        """
        return list_models(pattern=pattern)
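The inner `scatter` helper is the subtle part of `forward`: it recursively splits a batched output, whether a single tensor or a nested tuple of per-stage features, into per-sample entries. A standalone sketch of the same recursion, not part of this diff (the shapes are assumptions):

# What `scatter` does to a batch of 2 samples with two backbone stages.
import torch

def scatter(feats, index):
    if isinstance(feats, torch.Tensor):
        return feats[index]
    # Sequence of tensors: rebuild the same container type per sample.
    return type(feats)([scatter(item, index) for item in feats])

feats = (torch.zeros(2, 256, 56, 56), torch.zeros(2, 512, 28, 28))
sample0 = scatter(feats, 0)
print(sample0[0].shape, sample0[1].shape)
# torch.Size([256, 56, 56]) torch.Size([512, 28, 28])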
mmpretrain/apis/image_caption.py
ADDED
@@ -0,0 +1,164 @@
# Copyright (c) OpenMMLab. All rights reserved.
from pathlib import Path
from typing import Callable, List, Optional

import numpy as np
from mmcv.image import imread
from mmengine.config import Config
from mmengine.dataset import Compose, default_collate

from mmpretrain.registry import TRANSFORMS
from mmpretrain.structures import DataSample
from .base import BaseInferencer, InputType
from .model import list_models


class ImageCaptionInferencer(BaseInferencer):
    """The inferencer for image caption.

    Args:
        model (BaseModel | str | Config): A model name or a path to the config
            file, or a :obj:`BaseModel` object. The model name can be found
            by ``ImageCaptionInferencer.list_models()`` and you can also
            query it in :doc:`/modelzoo_statistics`.
        pretrained (str, optional): Path to the checkpoint. If None, it will
            try to find a pre-defined weight from the model you specified
            (only work if the ``model`` is a model name). Defaults to None.
        device (str, optional): Device to run inference. If None, the available
            device will be automatically used. Defaults to None.
        **kwargs: Other keyword arguments to initialize the model (only work if
            the ``model`` is a model name).

    Example:
        >>> from mmpretrain import ImageCaptionInferencer
        >>> inferencer = ImageCaptionInferencer('blip-base_3rdparty_caption')
        >>> inferencer('demo/cat-dog.png')[0]
        {'pred_caption': 'a puppy and a cat sitting on a blanket'}
    """  # noqa: E501

    visualize_kwargs: set = {'resize', 'show', 'show_dir', 'wait_time'}

    def __call__(self,
                 images: InputType,
                 return_datasamples: bool = False,
                 batch_size: int = 1,
                 **kwargs) -> dict:
        """Call the inferencer.

        Args:
            images (str | array | list): The image path or array, or a list of
                images.
            return_datasamples (bool): Whether to return results as
                :obj:`DataSample`. Defaults to False.
            batch_size (int): Batch size. Defaults to 1.
            resize (int, optional): Resize the short edge of the image to the
                specified length before visualization. Defaults to None.
            draw_score (bool): Whether to draw the prediction scores
                of prediction categories. Defaults to True.
            show (bool): Whether to display the visualization result in a
                window. Defaults to False.
            wait_time (float): The display time (s). Defaults to 0, which means
                "forever".
            show_dir (str, optional): If not None, save the visualization
                results in the specified directory. Defaults to None.

        Returns:
            list: The inference results.
        """
        return super().__call__(images, return_datasamples, batch_size,
                                **kwargs)

    def _init_pipeline(self, cfg: Config) -> Callable:
        test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
        if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile':
            # Image loading is finished in `self.preprocess`.
            test_pipeline_cfg = test_pipeline_cfg[1:]
        test_pipeline = Compose(
            [TRANSFORMS.build(t) for t in test_pipeline_cfg])
        return test_pipeline

    def preprocess(self, inputs: List[InputType], batch_size: int = 1):

        def load_image(input_):
            img = imread(input_)
            if img is None:
                raise ValueError(f'Failed to read image {input_}.')
            return dict(
                img=img,
                img_shape=img.shape[:2],
                ori_shape=img.shape[:2],
            )

        pipeline = Compose([load_image, self.pipeline])

        chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size)
        yield from map(default_collate, chunked_data)

    def visualize(self,
                  ori_inputs: List[InputType],
                  preds: List[DataSample],
                  show: bool = False,
                  wait_time: int = 0,
                  resize: Optional[int] = None,
                  show_dir=None):
        if not show and show_dir is None:
            return None

        if self.visualizer is None:
            from mmpretrain.visualization import UniversalVisualizer
            self.visualizer = UniversalVisualizer()

        visualization = []
        for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)):
            image = imread(input_)
            if isinstance(input_, str):
                # The image loaded from path is BGR format.
                image = image[..., ::-1]
                name = Path(input_).stem
            else:
                name = str(i)

            if show_dir is not None:
                show_dir = Path(show_dir)
                show_dir.mkdir(exist_ok=True)
                out_file = str((show_dir / name).with_suffix('.png'))
            else:
                out_file = None

            self.visualizer.visualize_image_caption(
                image,
                data_sample,
                resize=resize,
                show=show,
                wait_time=wait_time,
                name=name,
                out_file=out_file)
            visualization.append(self.visualizer.get_image())
        if show:
            self.visualizer.close()
        return visualization

    def postprocess(self,
                    preds: List[DataSample],
                    visualization: List[np.ndarray],
                    return_datasamples=False) -> dict:
        if return_datasamples:
            return preds

        results = []
        for data_sample in preds:
            results.append({'pred_caption': data_sample.get('pred_caption')})

        return results

    @staticmethod
    def list_models(pattern: Optional[str] = None):
        """List all available model names.

        Args:
            pattern (str | None): A wildcard pattern to match model names.

        Returns:
            List[str]: a list of model names.
        """
        return list_models(pattern=pattern, task='Image Caption')
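The docstring example covers single-image captioning; batching and saved visualizations come from the inherited base workflow. A hedged usage sketch, not part of this diff (the image paths and output directory are placeholders):

# Batched captioning with visualizations saved to disk.
from mmpretrain import ImageCaptionInferencer

inferencer = ImageCaptionInferencer('blip-base_3rdparty_caption')
results = inferencer(['demo/cat-dog.png', 'demo/demo.JPEG'],
                     batch_size=2, show_dir='./caption_vis/')
for r in results:
    print(r['pred_caption'])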
mmpretrain/apis/image_classification.py
ADDED
@@ -0,0 +1,221 @@
# Copyright (c) OpenMMLab. All rights reserved.
from pathlib import Path
from typing import Callable, List, Optional, Union

import numpy as np
import torch
from mmcv.image import imread
from mmengine.config import Config
from mmengine.dataset import Compose, default_collate

from mmpretrain.registry import TRANSFORMS
from mmpretrain.structures import DataSample
from .base import BaseInferencer, InputType, ModelType
from .model import list_models


class ImageClassificationInferencer(BaseInferencer):
    """The inferencer for image classification.

    Args:
        model (BaseModel | str | Config): A model name or a path to the config
            file, or a :obj:`BaseModel` object. The model name can be found
            by ``ImageClassificationInferencer.list_models()`` and you can also
            query it in :doc:`/modelzoo_statistics`.
        pretrained (str, optional): Path to the checkpoint. If None, it will
            try to find a pre-defined weight from the model you specified
            (only work if the ``model`` is a model name). Defaults to None.
        device (str, optional): Device to run inference. If None, the available
            device will be automatically used. Defaults to None.
        **kwargs: Other keyword arguments to initialize the model (only work if
            the ``model`` is a model name).

    Example:
        1. Use a pre-trained model in MMPreTrain to inference an image.

        >>> from mmpretrain import ImageClassificationInferencer
        >>> inferencer = ImageClassificationInferencer('resnet50_8xb32_in1k')
        >>> inferencer('demo/demo.JPEG')
        [{'pred_scores': array([...]),
          'pred_label': 65,
          'pred_score': 0.6649367809295654,
          'pred_class': 'sea snake'}]

        2. Use a config file and checkpoint to inference multiple images on GPU,
           and save the visualization results in a folder.

        >>> from mmpretrain import ImageClassificationInferencer
        >>> inferencer = ImageClassificationInferencer(
                model='configs/resnet/resnet50_8xb32_in1k.py',
                pretrained='https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth',
                device='cuda')
        >>> inferencer(['demo/dog.jpg', 'demo/bird.JPEG'], show_dir="./visualize/")
    """  # noqa: E501

    visualize_kwargs: set = {
        'resize', 'rescale_factor', 'draw_score', 'show', 'show_dir',
        'wait_time'
    }

    def __init__(self,
                 model: ModelType,
                 pretrained: Union[bool, str] = True,
                 device: Union[str, torch.device, None] = None,
                 classes=None,
                 **kwargs) -> None:
        super().__init__(
            model=model, pretrained=pretrained, device=device, **kwargs)

        if classes is not None:
            self.classes = classes
        else:
            self.classes = getattr(self.model, '_dataset_meta',
                                   {}).get('classes')

    def __call__(self,
                 inputs: InputType,
                 return_datasamples: bool = False,
                 batch_size: int = 1,
                 **kwargs) -> dict:
        """Call the inferencer.

        Args:
            inputs (str | array | list): The image path or array, or a list of
                images.
            return_datasamples (bool): Whether to return results as
                :obj:`DataSample`. Defaults to False.
            batch_size (int): Batch size. Defaults to 1.
            resize (int, optional): Resize the short edge of the image to the
                specified length before visualization. Defaults to None.
            rescale_factor (float, optional): Rescale the image by the rescale
                factor for visualization. This is helpful when the image is too
                large or too small for visualization. Defaults to None.
            draw_score (bool): Whether to draw the prediction scores
                of prediction categories. Defaults to True.
            show (bool): Whether to display the visualization result in a
                window. Defaults to False.
            wait_time (float): The display time (s). Defaults to 0, which means
                "forever".
            show_dir (str, optional): If not None, save the visualization
                results in the specified directory. Defaults to None.

        Returns:
            list: The inference results.
        """
        return super().__call__(
            inputs,
            return_datasamples=return_datasamples,
            batch_size=batch_size,
            **kwargs)

    def _init_pipeline(self, cfg: Config) -> Callable:
        test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
        if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile':
            # Image loading is finished in `self.preprocess`.
            test_pipeline_cfg = test_pipeline_cfg[1:]
        test_pipeline = Compose(
            [TRANSFORMS.build(t) for t in test_pipeline_cfg])
        return test_pipeline

    def preprocess(self, inputs: List[InputType], batch_size: int = 1):

        def load_image(input_):
            img = imread(input_)
            if img is None:
                raise ValueError(f'Failed to read image {input_}.')
            return dict(
                img=img,
                img_shape=img.shape[:2],
                ori_shape=img.shape[:2],
            )

        pipeline = Compose([load_image, self.pipeline])

        chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size)
        yield from map(default_collate, chunked_data)

    def visualize(self,
                  ori_inputs: List[InputType],
                  preds: List[DataSample],
                  show: bool = False,
                  wait_time: int = 0,
                  resize: Optional[int] = None,
                  rescale_factor: Optional[float] = None,
                  draw_score=True,
                  show_dir=None):
        if not show and show_dir is None:
            return None

        if self.visualizer is None:
            from mmpretrain.visualization import UniversalVisualizer
            self.visualizer = UniversalVisualizer()

        visualization = []
        for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)):
            image = imread(input_)
            if isinstance(input_, str):
                # The image loaded from path is BGR format.
                image = image[..., ::-1]
                name = Path(input_).stem
            else:
                name = str(i)

            if show_dir is not None:
                show_dir = Path(show_dir)
                show_dir.mkdir(exist_ok=True)
                out_file = str((show_dir / name).with_suffix('.png'))
            else:
                out_file = None

            self.visualizer.visualize_cls(
                image,
                data_sample,
                classes=self.classes,
                resize=resize,
                show=show,
                wait_time=wait_time,
                rescale_factor=rescale_factor,
                draw_gt=False,
                draw_pred=True,
                draw_score=draw_score,
                name=name,
                out_file=out_file)
            visualization.append(self.visualizer.get_image())
        if show:
            self.visualizer.close()
        return visualization

    def postprocess(self,
                    preds: List[DataSample],
                    visualization: List[np.ndarray],
                    return_datasamples=False) -> dict:
        if return_datasamples:
            return preds

        results = []
        for data_sample in preds:
            pred_scores = data_sample.pred_score
            pred_score = float(torch.max(pred_scores).item())
            pred_label = torch.argmax(pred_scores).item()
            result = {
                'pred_scores': pred_scores.detach().cpu().numpy(),
                'pred_label': pred_label,
                'pred_score': pred_score,
            }
            if self.classes is not None:
                result['pred_class'] = self.classes[pred_label]
            results.append(result)

        return results

    @staticmethod
    def list_models(pattern: Optional[str] = None):
        """List all available model names.

        Args:
            pattern (str | None): A wildcard pattern to match model names.

        Returns:
            List[str]: a list of model names.
        """
        return list_models(pattern=pattern, task='Image Classification')
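Because `classes` falls back to the checkpoint's dataset meta, `pred_class` is only populated when class names are available, and you can supply them explicitly. A hedged sketch, not part of this diff (the config path, checkpoint, and label names are hypothetical):

# Supplying class names when the checkpoint carries no dataset meta,
# e.g. a custom fine-tuned model. All paths/names here are assumptions.
from mmpretrain import ImageClassificationInferencer

inferencer = ImageClassificationInferencer(
    'configs/custom/resnet50_2class.py',          # hypothetical config
    pretrained='work_dirs/custom/epoch_100.pth',  # hypothetical checkpoint
    classes=['cat', 'dog'])                       # hypothetical label names
print(inferencer('demo/dog.jpg')[0]['pred_class'])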
mmpretrain/apis/image_retrieval.py
ADDED
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
from pathlib import Path
|
3 |
+
from typing import Callable, List, Optional, Union
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from mmcv.image import imread
|
8 |
+
from mmengine.config import Config
|
9 |
+
from mmengine.dataset import BaseDataset, Compose, default_collate
|
10 |
+
|
11 |
+
from mmpretrain.registry import TRANSFORMS
|
12 |
+
from mmpretrain.structures import DataSample
|
13 |
+
from .base import BaseInferencer, InputType, ModelType
|
14 |
+
from .model import list_models
|
15 |
+
|
16 |
+
|
17 |
+
class ImageRetrievalInferencer(BaseInferencer):
|
18 |
+
"""The inferencer for image to image retrieval.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
model (BaseModel | str | Config): A model name or a path to the config
|
22 |
+
file, or a :obj:`BaseModel` object. The model name can be found
|
23 |
+
by ``ImageRetrievalInferencer.list_models()`` and you can also
|
24 |
+
query it in :doc:`/modelzoo_statistics`.
|
25 |
+
prototype (str | list | dict | DataLoader, BaseDataset): The images to
|
26 |
+
be retrieved. It can be the following types:
|
27 |
+
|
28 |
+
- str: The directory of the the images.
|
29 |
+
- list: A list of path of the images.
|
30 |
+
- dict: A config dict of the a prototype dataset.
|
31 |
+
- BaseDataset: A prototype dataset.
|
32 |
+
- DataLoader: A data loader to load the prototype data.
|
33 |
+
|
34 |
+
prototype_cache (str, optional): The path of the generated prototype
|
35 |
+
features. If exists, directly load the cache instead of re-generate
|
36 |
+
the prototype features. If not exists, save the generated features
|
37 |
+
to the path. Defaults to None.
|
38 |
+
pretrained (str, optional): Path to the checkpoint. If None, it will
|
39 |
+
try to find a pre-defined weight from the model you specified
|
40 |
+
(only work if the ``model`` is a model name). Defaults to None.
|
41 |
+
device (str, optional): Device to run inference. If None, the available
|
42 |
+
device will be automatically used. Defaults to None.
|
43 |
+
**kwargs: Other keyword arguments to initialize the model (only work if
|
44 |
+
the ``model`` is a model name).
|
45 |
+
|
46 |
+
Example:
|
47 |
+
>>> from mmpretrain import ImageRetrievalInferencer
|
48 |
+
>>> inferencer = ImageRetrievalInferencer(
|
49 |
+
... 'resnet50-arcface_8xb32_inshop',
|
50 |
+
... prototype='./demo/',
|
51 |
+
... prototype_cache='img_retri.pth')
|
52 |
+
>>> inferencer('demo/cat-dog.png', topk=2)[0][1]
|
53 |
+
{'match_score': tensor(0.4088, device='cuda:0'),
|
54 |
+
'sample_idx': 3,
|
55 |
+
'sample': {'img_path': './demo/dog.jpg'}}
|
56 |
+
""" # noqa: E501
|
57 |
+
|
58 |
+
visualize_kwargs: set = {
|
59 |
+
'draw_score', 'resize', 'show_dir', 'show', 'wait_time', 'topk'
|
60 |
+
}
|
61 |
+
postprocess_kwargs: set = {'topk'}
|
62 |
+
|
63 |
+
def __init__(
|
64 |
+
self,
|
65 |
+
model: ModelType,
|
66 |
+
prototype,
|
67 |
+
prototype_cache=None,
|
68 |
+
prepare_batch_size=8,
|
69 |
+
pretrained: Union[bool, str] = True,
|
70 |
+
device: Union[str, torch.device, None] = None,
|
71 |
+
**kwargs,
|
72 |
+
) -> None:
|
73 |
+
super().__init__(
|
74 |
+
model=model, pretrained=pretrained, device=device, **kwargs)
|
75 |
+
|
76 |
+
self.prototype_dataset = self._prepare_prototype(
|
77 |
+
prototype, prototype_cache, prepare_batch_size)
|
78 |
+
|
79 |
+
def _prepare_prototype(self, prototype, cache=None, batch_size=8):
|
80 |
+
from mmengine.dataset import DefaultSampler
|
81 |
+
from torch.utils.data import DataLoader
|
82 |
+
|
83 |
+
def build_dataloader(dataset):
|
84 |
+
return DataLoader(
|
85 |
+
dataset,
|
86 |
+
batch_size=batch_size,
|
87 |
+
collate_fn=default_collate,
|
88 |
+
sampler=DefaultSampler(dataset, shuffle=False),
|
89 |
+
persistent_workers=False,
|
90 |
+
)
|
91 |
+
|
92 |
+
if isinstance(prototype, str):
|
93 |
+
# A directory path of images
|
94 |
+
prototype = dict(
|
95 |
+
type='CustomDataset', with_label=False, data_root=prototype)
|
96 |
+
|
97 |
+
if isinstance(prototype, list):
|
98 |
+
test_pipeline = [dict(type='LoadImageFromFile'), self.pipeline]
|
99 |
+
dataset = BaseDataset(
|
100 |
+
lazy_init=True, serialize_data=False, pipeline=test_pipeline)
|
101 |
+
dataset.data_list = [{
|
102 |
+
'sample_idx': i,
|
103 |
+
'img_path': file
|
104 |
+
} for i, file in enumerate(prototype)]
|
105 |
+
dataset._fully_initialized = True
|
106 |
+
dataloader = build_dataloader(dataset)
|
107 |
+
elif isinstance(prototype, dict):
|
108 |
+
# A config of dataset
|
109 |
+
from mmpretrain.registry import DATASETS
|
110 |
+
test_pipeline = [dict(type='LoadImageFromFile'), self.pipeline]
|
111 |
+
dataset = DATASETS.build(prototype)
|
112 |
+
dataloader = build_dataloader(dataset)
|
113 |
+
elif isinstance(prototype, DataLoader):
|
114 |
+
dataset = prototype.dataset
|
115 |
+
dataloader = prototype
|
116 |
+
elif isinstance(prototype, BaseDataset):
|
117 |
+
dataset = prototype
|
118 |
+
dataloader = build_dataloader(dataset)
|
119 |
+
else:
|
120 |
+
raise TypeError(f'Unsupported prototype type {type(prototype)}.')
|
121 |
+
|
122 |
+
if cache is not None and Path(cache).exists():
|
123 |
+
self.model.prototype = cache
|
124 |
+
else:
|
125 |
+
self.model.prototype = dataloader
|
126 |
+
self.model.prepare_prototype()
|
127 |
+
|
128 |
+
from mmengine.logging import MMLogger
|
129 |
+
logger = MMLogger.get_current_instance()
|
130 |
+
if cache is None:
|
131 |
+
logger.info('The prototype has been prepared, you can use '
|
132 |
+
'`save_prototype` to dump it into a pickle '
|
133 |
+
'file for the future usage.')
|
134 |
+
elif not Path(cache).exists():
|
135 |
+
self.save_prototype(cache)
|
136 |
+
logger.info(f'The prototype has been saved at {cache}.')
|
137 |
+
|
138 |
+
return dataset
|
139 |
+
|
140 |
+
def save_prototype(self, path):
|
141 |
+
self.model.dump_prototype(path)
|
142 |
+
|
143 |
+
def __call__(self,
|
144 |
+
inputs: InputType,
|
145 |
+
return_datasamples: bool = False,
|
146 |
+
batch_size: int = 1,
|
147 |
+
**kwargs) -> dict:
|
148 |
+
"""Call the inferencer.
|
149 |
+
|
150 |
+
Args:
|
151 |
+
inputs (str | array | list): The image path or array, or a list of
|
152 |
+
images.
|
153 |
+
return_datasamples (bool): Whether to return results as
|
154 |
+
:obj:`DataSample`. Defaults to False.
|
155 |
+
batch_size (int): Batch size. Defaults to 1.
|
156 |
+
resize (int, optional): Resize the long edge of the image to the
|
157 |
+
specified length before visualization. Defaults to None.
|
158 |
+
draw_score (bool): Whether to draw the match scores.
|
159 |
+
Defaults to True.
|
160 |
+
            show (bool): Whether to display the visualization result in a
                window. Defaults to False.
            wait_time (float): The display time (s). Defaults to 0, which means
                "forever".
            show_dir (str, optional): If not None, save the visualization
                results in the specified directory. Defaults to None.

        Returns:
            list: The inference results.
        """
        return super().__call__(inputs, return_datasamples, batch_size,
                                **kwargs)

    def _init_pipeline(self, cfg: Config) -> Callable:
        test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
        if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile':
            # Image loading is finished in `self.preprocess`.
            test_pipeline_cfg = test_pipeline_cfg[1:]
        test_pipeline = Compose(
            [TRANSFORMS.build(t) for t in test_pipeline_cfg])
        return test_pipeline

    def preprocess(self, inputs: List[InputType], batch_size: int = 1):

        def load_image(input_):
            img = imread(input_)
            if img is None:
                raise ValueError(f'Failed to read image {input_}.')
            return dict(
                img=img,
                img_shape=img.shape[:2],
                ori_shape=img.shape[:2],
            )

        pipeline = Compose([load_image, self.pipeline])

        chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size)
        yield from map(default_collate, chunked_data)

    def visualize(self,
                  ori_inputs: List[InputType],
                  preds: List[DataSample],
                  topk: int = 3,
                  resize: Optional[int] = 224,
                  show: bool = False,
                  wait_time: int = 0,
                  draw_score=True,
                  show_dir=None):
        if not show and show_dir is None:
            return None

        if self.visualizer is None:
            from mmpretrain.visualization import UniversalVisualizer
            self.visualizer = UniversalVisualizer()

        visualization = []
        for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)):
            image = imread(input_)
            if isinstance(input_, str):
                # The image loaded from path is BGR format.
                image = image[..., ::-1]
                name = Path(input_).stem
            else:
                name = str(i)

            if show_dir is not None:
                show_dir = Path(show_dir)
                show_dir.mkdir(exist_ok=True)
                out_file = str((show_dir / name).with_suffix('.png'))
            else:
                out_file = None

            self.visualizer.visualize_image_retrieval(
                image,
                data_sample,
                self.prototype_dataset,
                topk=topk,
                resize=resize,
                draw_score=draw_score,
                show=show,
                wait_time=wait_time,
                name=name,
                out_file=out_file)
            visualization.append(self.visualizer.get_image())
        if show:
            self.visualizer.close()
        return visualization

    def postprocess(
        self,
        preds: List[DataSample],
        visualization: List[np.ndarray],
        return_datasamples=False,
        topk=1,
    ) -> list:
        if return_datasamples:
            return preds

        results = []
        for data_sample in preds:
            match_scores, indices = torch.topk(data_sample.pred_score, k=topk)
            matches = []
            for match_score, sample_idx in zip(match_scores, indices):
                sample = self.prototype_dataset.get_data_info(
                    sample_idx.item())
                sample_idx = sample.pop('sample_idx')
                matches.append({
                    'match_score': match_score,
                    'sample_idx': sample_idx,
                    'sample': sample
                })
            results.append(matches)

        return results

    @staticmethod
    def list_models(pattern: Optional[str] = None):
        """List all available model names.

        Args:
            pattern (str | None): A wildcard pattern to match model names.

        Returns:
            List[str]: a list of model names.
        """
        return list_models(pattern=pattern, task='Image Retrieval')
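For reference, a minimal usage sketch of the inferencer methods above. The model name, gallery directory and query path are illustrative assumptions, not values from this diff:

from mmpretrain import ImageRetrievalInferencer

# Build the inferencer; `prototype` is the image gallery to retrieve from,
# and `prototype_cache` stores the gallery features for later runs.
inferencer = ImageRetrievalInferencer(
    'resnet50-arcface_inshop',          # assumed model name
    prototype='./demo/',                # assumed gallery directory
    prototype_cache='img_retri.pth')

# `postprocess` (above) returns, per query, a top-k list of matches.
matches = inferencer('./demo/query.jpg', topk=3)[0]  # assumed query image
for match in matches:
    print(float(match['match_score']), match['sample']['img_path'])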
mmpretrain/apis/model.py
ADDED
@@ -0,0 +1,408 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import fnmatch
import os.path as osp
import re
import warnings
from os import PathLike
from pathlib import Path
from typing import List, Tuple, Union

from mmengine.config import Config
from modelindex.load_model_index import load
from modelindex.models.Model import Model


class ModelHub:
    """A hub to host the meta information of all pre-defined models."""
    _models_dict = {}
    __mmpretrain_registered = False

    @classmethod
    def register_model_index(cls,
                             model_index_path: Union[str, PathLike],
                             config_prefix: Union[str, PathLike, None] = None):
        """Parse the model-index file and register all models.

        Args:
            model_index_path (str | PathLike): The path of the model-index
                file.
            config_prefix (str | PathLike | None): The prefix of all config
                file paths in the model-index file.
        """
        model_index = load(str(model_index_path))
        model_index.build_models_with_collections()

        for metainfo in model_index.models:
            model_name = metainfo.name.lower()
            if model_name in cls._models_dict:
                raise ValueError(
                    'The model name {} conflicts in {} and {}.'.format(
                        model_name, osp.abspath(metainfo.filepath),
                        osp.abspath(cls._models_dict[model_name].filepath)))
            metainfo.config = cls._expand_config_path(metainfo, config_prefix)
            cls._models_dict[model_name] = metainfo

    @classmethod
    def get(cls, model_name):
        """Get the model's metainfo by the model name.

        Args:
            model_name (str): The name of model.

        Returns:
            modelindex.models.Model: The metainfo of the specified model.
        """
        cls._register_mmpretrain_models()
        # lazy load config
        metainfo = copy.deepcopy(cls._models_dict.get(model_name.lower()))
        if metainfo is None:
            raise ValueError(
                f'Failed to find model "{model_name}". Please use '
                '`mmpretrain.list_models` to get all available names.')
        if isinstance(metainfo.config, str):
            metainfo.config = Config.fromfile(metainfo.config)
        return metainfo

    @staticmethod
    def _expand_config_path(metainfo: Model,
                            config_prefix: Union[str, PathLike, None] = None):
        if config_prefix is None:
            config_prefix = osp.dirname(metainfo.filepath)

        if metainfo.config is None or osp.isabs(metainfo.config):
            config_path: str = metainfo.config
        else:
            config_path = osp.abspath(osp.join(config_prefix, metainfo.config))

        return config_path

    @classmethod
    def _register_mmpretrain_models(cls):
        # register models in mmpretrain
        if not cls.__mmpretrain_registered:
            from importlib_metadata import distribution
            root = distribution('mmpretrain').locate_file('mmpretrain')
            model_index_path = root / '.mim' / 'model-index.yml'
            ModelHub.register_model_index(
                model_index_path, config_prefix=root / '.mim')
            cls.__mmpretrain_registered = True

    @classmethod
    def has(cls, model_name):
        """Whether a model name is in the ModelHub."""
        return model_name in cls._models_dict


def get_model(model: Union[str, Config],
              pretrained: Union[str, bool] = False,
              device=None,
              device_map=None,
              offload_folder=None,
              url_mapping: Tuple[str, str] = None,
              **kwargs):
    """Get a pre-defined model or create a model from config.

    Args:
        model (str | Config): The name of model, the config file path or a
            config instance.
        pretrained (bool | str): When you use a name to specify the model, you
            can use ``True`` to load the pre-defined pretrained weights. And
            you can also use a string to specify the path or link of weights
            to load. Defaults to False.
        device (str | torch.device | None): Transfer the model to the target
            device. Defaults to None.
        device_map (str | dict | None): A map that specifies where each
            submodule should go. It doesn't need to be refined to each
            parameter/buffer name; once a given module name is inside, every
            submodule of it will be sent to the same device. You can use
            `device_map="auto"` to automatically generate the device map.
            Defaults to None.
        offload_folder (str | None): If the `device_map` contains any value
            `"disk"`, the folder where we will offload weights.
        url_mapping (Tuple[str, str], optional): The mapping of pretrained
            checkpoint link. For example, load checkpoint from a local dir
            instead of download by ``('https://.*/', './checkpoint')``.
            Defaults to None.
        **kwargs: Other keyword arguments of the model config.

    Returns:
        mmengine.model.BaseModel: The result model.

    Examples:
        Get a ResNet-50 model and extract image features:

        >>> import torch
        >>> from mmpretrain import get_model
        >>> inputs = torch.rand(16, 3, 224, 224)
        >>> model = get_model('resnet50_8xb32_in1k', pretrained=True, backbone=dict(out_indices=(0, 1, 2, 3)))
        >>> feats = model.extract_feat(inputs)
        >>> for feat in feats:
        ...     print(feat.shape)
        torch.Size([16, 256])
        torch.Size([16, 512])
        torch.Size([16, 1024])
        torch.Size([16, 2048])

        Get a Swin-Transformer model with pre-trained weights and inference:

        >>> from mmpretrain import get_model, inference_model
        >>> model = get_model('swin-base_16xb64_in1k', pretrained=True)
        >>> result = inference_model(model, 'demo/demo.JPEG')
        >>> print(result['pred_class'])
        'sea snake'
    """  # noqa: E501
    if device_map is not None:
        from .utils import dispatch_model
        dispatch_model._verify_require()

    metainfo = None
    if isinstance(model, Config):
        config = copy.deepcopy(model)
        if pretrained is True and 'load_from' in config:
            pretrained = config.load_from
    elif isinstance(model, (str, PathLike)) and Path(model).suffix == '.py':
        config = Config.fromfile(model)
        if pretrained is True and 'load_from' in config:
            pretrained = config.load_from
    elif isinstance(model, str):
        metainfo = ModelHub.get(model)
        config = metainfo.config
        if pretrained is True and metainfo.weights is not None:
            pretrained = metainfo.weights
    else:
        raise TypeError('model must be a name, a path or a Config object, '
                        f'but got {type(model)}')

    if pretrained is True:
        warnings.warn('Unable to find pre-defined checkpoint of the model.')
        pretrained = None
    elif pretrained is False:
        pretrained = None

    if kwargs:
        config.merge_from_dict({'model': kwargs})
    config.model.setdefault('data_preprocessor',
                            config.get('data_preprocessor', None))

    from mmengine.registry import DefaultScope

    from mmpretrain.registry import MODELS
    with DefaultScope.overwrite_default_scope('mmpretrain'):
        model = MODELS.build(config.model)

    dataset_meta = {}
    if pretrained:
        # Mapping the weights to GPU may cause unexpected video memory leak
        # which refers to https://github.com/open-mmlab/mmdetection/pull/6405
        from mmengine.runner import load_checkpoint
        if url_mapping is not None:
            pretrained = re.sub(url_mapping[0], url_mapping[1], pretrained)
        checkpoint = load_checkpoint(model, pretrained, map_location='cpu')
        if 'dataset_meta' in checkpoint.get('meta', {}):
            # mmpretrain 1.x
            dataset_meta = checkpoint['meta']['dataset_meta']
        elif 'CLASSES' in checkpoint.get('meta', {}):
            # mmcls 0.x
            dataset_meta = {'classes': checkpoint['meta']['CLASSES']}

    if len(dataset_meta) == 0 and 'test_dataloader' in config:
        from mmpretrain.registry import DATASETS
        dataset_class = DATASETS.get(config.test_dataloader.dataset.type)
        dataset_meta = getattr(dataset_class, 'METAINFO', {})

    if device_map is not None:
        model = dispatch_model(
            model, device_map=device_map, offload_folder=offload_folder)
    elif device is not None:
        model.to(device)

    model._dataset_meta = dataset_meta  # save the dataset meta
    model._config = config  # save the config in the model
    model._metainfo = metainfo  # save the metainfo in the model
    model.eval()
    return model


def init_model(config, checkpoint=None, device=None, **kwargs):
    """Initialize a classifier from config file (deprecated).

    It's only for compatibility, please use :func:`get_model` instead.

    Args:
        config (str | :obj:`mmengine.Config`): Config file path or the config
            object.
        checkpoint (str, optional): Checkpoint path. If left as None, the model
            will not load any weights.
        device (str | torch.device | None): Transfer the model to the target
            device. Defaults to None.
        **kwargs: Other keyword arguments of the model config.

    Returns:
        nn.Module: The constructed model.
    """
    return get_model(config, checkpoint, device, **kwargs)


def list_models(pattern=None, exclude_patterns=None, task=None) -> List[str]:
    """List all models available in MMPretrain.

    Args:
        pattern (str | None): A wildcard pattern to match model names.
            Defaults to None.
        exclude_patterns (list | None): A list of wildcard patterns to
            exclude names from the matched names. Defaults to None.
        task (str | None): The evaluation task of the model.

    Returns:
        List[str]: a list of model names.

    Examples:
        List all models:

        >>> from mmpretrain import list_models
        >>> list_models()

        List ResNet-50 models on ImageNet-1k dataset:

        >>> from mmpretrain import list_models
        >>> list_models('resnet*in1k')
        ['resnet50_8xb32_in1k',
         'resnet50_8xb32-fp16_in1k',
         'resnet50_8xb256-rsb-a1-600e_in1k',
         'resnet50_8xb256-rsb-a2-300e_in1k',
         'resnet50_8xb256-rsb-a3-100e_in1k']

        List Swin-Transformer models trained from scratch and exclude
        Swin-Transformer-V2 models:

        >>> from mmpretrain import list_models
        >>> list_models('swin', exclude_patterns=['swinv2', '*-pre'])
        ['swin-base_16xb64_in1k',
         'swin-base_3rdparty_in1k',
         'swin-base_3rdparty_in1k-384',
         'swin-large_8xb8_cub-384px',
         'swin-small_16xb64_in1k',
         'swin-small_3rdparty_in1k',
         'swin-tiny_16xb64_in1k',
         'swin-tiny_3rdparty_in1k']

        List all EVA models for the image classification task:

        >>> from mmpretrain import list_models
        >>> list_models('eva', task='Image Classification')
        ['eva-g-p14_30m-in21k-pre_3rdparty_in1k-336px',
         'eva-g-p14_30m-in21k-pre_3rdparty_in1k-560px',
         'eva-l-p14_mim-in21k-pre_3rdparty_in1k-196px',
         'eva-l-p14_mim-in21k-pre_3rdparty_in1k-336px',
         'eva-l-p14_mim-pre_3rdparty_in1k-196px',
         'eva-l-p14_mim-pre_3rdparty_in1k-336px']
    """
    ModelHub._register_mmpretrain_models()
    matches = set(ModelHub._models_dict.keys())

    if pattern is not None:
        # Always match keys with any postfix.
        matches = set(fnmatch.filter(matches, pattern + '*'))

    exclude_patterns = exclude_patterns or []
    for exclude_pattern in exclude_patterns:
        exclude = set(fnmatch.filter(matches, exclude_pattern + '*'))
        matches = matches - exclude

    if task is not None:
        task_matches = []
        for key in matches:
            metainfo = ModelHub._models_dict[key]
            if metainfo.results is None and task == 'null':
                task_matches.append(key)
            elif metainfo.results is None:
                continue
            elif task in [result.task for result in metainfo.results]:
                task_matches.append(key)
        matches = task_matches

    return sorted(list(matches))


def inference_model(model, *args, **kwargs):
    """Inference an image with the inferencer.

    Automatically select an inferencer according to the type of the model.
    It's a shortcut for a quick start; for advanced usage, please use the
    corresponding inferencer class.

    Here is the mapping from task to inferencer:

    - Image Classification: :class:`ImageClassificationInferencer`
    - Image Retrieval: :class:`ImageRetrievalInferencer`
    - Image Caption: :class:`ImageCaptionInferencer`
    - Visual Question Answering: :class:`VisualQuestionAnsweringInferencer`
    - Visual Grounding: :class:`VisualGroundingInferencer`
    - Text-To-Image Retrieval: :class:`TextToImageRetrievalInferencer`
    - Image-To-Text Retrieval: :class:`ImageToTextRetrievalInferencer`
    - NLVR: :class:`NLVRInferencer`

    Args:
        model (BaseModel | str | Config): The loaded model, the model
            name or the config of the model.
        *args: Positional arguments to call the inferencer.
        **kwargs: Other keyword arguments to initialize and call the
            corresponding inferencer.

    Returns:
        result (dict): The inference results.
    """  # noqa: E501
    from mmengine.model import BaseModel

    if isinstance(model, BaseModel):
        metainfo = getattr(model, '_metainfo', None)
    else:
        metainfo = ModelHub.get(model)

    from inspect import signature

    from .image_caption import ImageCaptionInferencer
    from .image_classification import ImageClassificationInferencer
    from .image_retrieval import ImageRetrievalInferencer
    from .multimodal_retrieval import (ImageToTextRetrievalInferencer,
                                       TextToImageRetrievalInferencer)
    from .nlvr import NLVRInferencer
    from .visual_grounding import VisualGroundingInferencer
    from .visual_question_answering import VisualQuestionAnsweringInferencer
    task_mapping = {
        'Image Classification': ImageClassificationInferencer,
        'Image Retrieval': ImageRetrievalInferencer,
        'Image Caption': ImageCaptionInferencer,
        'Visual Question Answering': VisualQuestionAnsweringInferencer,
        'Visual Grounding': VisualGroundingInferencer,
        'Text-To-Image Retrieval': TextToImageRetrievalInferencer,
        'Image-To-Text Retrieval': ImageToTextRetrievalInferencer,
        'NLVR': NLVRInferencer,
    }

    inferencer_type = None

    if metainfo is not None and metainfo.results is not None:
        tasks = set(result.task for result in metainfo.results)
        inferencer_type = [
            task_mapping.get(task) for task in tasks if task in task_mapping
        ]
        if len(inferencer_type) > 1:
            inferencer_names = [cls.__name__ for cls in inferencer_type]
            warnings.warn('The model supports multiple tasks, auto select '
                          f'{inferencer_names[0]}, you can also use other '
                          f'inferencer {inferencer_names} directly.')
        inferencer_type = inferencer_type[0]

    if inferencer_type is None:
        raise NotImplementedError('No available inferencer for the model')

    init_kwargs = {
        k: kwargs.pop(k)
        for k in list(kwargs)
        if k in signature(inferencer_type).parameters.keys()
    }

    inferencer = inferencer_type(model, **init_kwargs)
    return inferencer(*args, **kwargs)[0]
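As a quick illustration of the helpers above, a hedged sketch that chains `list_models`, `get_model` and `url_mapping`; the wildcard, task name and local mirror directory are assumptions:

from mmpretrain import get_model, list_models

# Filter the registered model index by wildcard pattern and task.
names = list_models('resnet*in1k', task='Image Classification')

# Build the first match. `url_mapping` rewrites the checkpoint link, e.g. to
# load weights from a local mirror instead of downloading them.
model = get_model(
    names[0],
    pretrained=True,
    device='cpu',
    url_mapping=('https://.*/', './checkpoint/'))  # assumed local mirror dir
print(type(model).__name__, sum(p.numel() for p in model.parameters()))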
mmpretrain/apis/multimodal_retrieval.py
ADDED
@@ -0,0 +1,603 @@
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from pathlib import Path
from typing import Callable, List, Optional, Tuple, Union

import mmengine
import numpy as np
import torch
from mmcv.image import imread
from mmengine.config import Config
from mmengine.dataset import BaseDataset, Compose, default_collate

from mmpretrain.registry import TRANSFORMS
from mmpretrain.structures import DataSample
from mmpretrain.utils import track
from .base import BaseInferencer
from .base import InputType as ImageType
from .base import ModelType
from .model import list_models


def filter_transforms(transforms: list, data_info: dict):
    """Filter pipeline to avoid KeyError with partial data info."""
    data_info = deepcopy(data_info)
    filtered_transforms = []
    for t in transforms:
        try:
            data_info = t(data_info)
            filtered_transforms.append(t)
        except KeyError:
            pass
    return filtered_transforms


class TextToImageRetrievalInferencer(BaseInferencer):
    """The inferencer for text to image retrieval.

    Args:
        model (BaseModel | str | Config): A model name or a path to the config
            file, or a :obj:`BaseModel` object. The model name can be found
            by ``TextToImageRetrievalInferencer.list_models()`` and you can
            also query it in :doc:`/modelzoo_statistics`.
        prototype (str | list | dict | DataLoader | BaseDataset): The images to
            be retrieved. It can be the following types:

            - str: The directory of the images.
            - list: A list of paths of the images.
            - dict: A config dict of the prototype dataset.
            - BaseDataset: A prototype dataset.
            - DataLoader: A data loader to load the prototype data.

        prototype_cache (str, optional): The path of the generated prototype
            features. If it exists, directly load the cache instead of
            re-generating the prototype features. If it does not exist, save
            the generated features to the path. Defaults to None.
        fast_match (bool): Some algorithms will record extra image features for
            further matching, which may consume large memory; set True to avoid
            this behavior. Defaults to True.
        pretrained (str, optional): Path to the checkpoint. If None, it will
            try to find a pre-defined weight from the model you specified
            (only works if the ``model`` is a model name). Defaults to None.
        device (str, optional): Device to run inference. If None, the available
            device will be automatically used. Defaults to None.
        **kwargs: Other keyword arguments to initialize the model (only work if
            the ``model`` is a model name).

    Example:
        >>> from mmpretrain import TextToImageRetrievalInferencer
        >>> inferencer = TextToImageRetrievalInferencer(
        ...     'blip-base_3rdparty_retrieval',
        ...     prototype='./demo/',
        ...     prototype_cache='t2i_retri.pth')
        >>> inferencer('A cat and a dog.')[0]
        {'match_score': tensor(0.3855, device='cuda:0'),
         'sample_idx': 1,
         'sample': {'img_path': './demo/cat-dog.png'}}
    """  # noqa: E501

    visualize_kwargs: set = {
        'draw_score', 'show_dir', 'show', 'wait_time', 'figsize', 'topk'
    }
    postprocess_kwargs: set = {'topk'}

    def __init__(self,
                 model: ModelType,
                 prototype,
                 prototype_cache=None,
                 fast_match=True,
                 prepare_batch_size=8,
                 pretrained: Union[bool, str] = True,
                 device: Union[str, torch.device, None] = None,
                 **kwargs) -> None:
        super().__init__(
            model=model, pretrained=pretrained, device=device, **kwargs)

        self.img_pipeline, self.text_pipeline = self.pipeline

        if hasattr(self.model, 'fast_match'):
            self.model.fast_match = fast_match

        self.prototype_dataset = self._prepare_prototype(
            prototype, prototype_cache, batch_size=prepare_batch_size)

    def _prepare_prototype(self, prototype, cache=None, batch_size=8):
        from mmengine.dataset import DefaultSampler
        from torch.utils.data import DataLoader

        def build_dataloader(dataset):
            return DataLoader(
                dataset,
                batch_size=batch_size,
                collate_fn=default_collate,
                sampler=DefaultSampler(dataset, shuffle=False),
                persistent_workers=False,
            )

        if isinstance(prototype, str):
            # A directory path of images
            prototype = dict(
                type='CustomDataset', with_label=False, data_root=prototype)

        if isinstance(prototype, list):
            test_pipeline = [dict(type='LoadImageFromFile'), self.img_pipeline]
            dataset = BaseDataset(
                lazy_init=True, serialize_data=False, pipeline=test_pipeline)
            dataset.data_list = [{
                'sample_idx': i,
                'img_path': file
            } for i, file in enumerate(prototype)]
            dataset._fully_initialized = True
            dataloader = build_dataloader(dataset)
        elif isinstance(prototype, dict):
            # A config of dataset
            from mmpretrain.registry import DATASETS
            test_pipeline = [dict(type='LoadImageFromFile'), self.img_pipeline]
            prototype.setdefault('pipeline', test_pipeline)
            dataset = DATASETS.build(prototype)
            dataloader = build_dataloader(dataset)
        elif isinstance(prototype, DataLoader):
            dataset = prototype.dataset
            dataloader = prototype
        elif isinstance(prototype, BaseDataset):
            dataset = prototype
            dataloader = build_dataloader(dataset)
        else:
            raise TypeError(f'Unsupported prototype type {type(prototype)}.')

        if cache is not None and Path(cache).exists():
            self.prototype = torch.load(cache)
        else:
            prototype = []
            for data_batch in track(dataloader, 'Prepare prototype...'):
                with torch.no_grad():
                    data_batch = self.model.data_preprocessor(
                        data_batch, False)
                    feats = self.model._run_forward(data_batch, mode='tensor')
                    prototype.append(feats)
            prototype = {
                k: torch.cat([d[k] for d in prototype])
                for k in prototype[0]
            }
            self.prototype = prototype

        from mmengine.logging import MMLogger
        logger = MMLogger.get_current_instance()
        if cache is None:
            logger.info('The prototype has been prepared, you can use '
                        '`save_prototype` to dump it into a pickle '
                        'file for the future usage.')
        elif not Path(cache).exists():
            self.save_prototype(cache)
            logger.info(f'The prototype has been saved at {cache}.')

        return dataset

    def save_prototype(self, path):
        torch.save(self.prototype, path)

    def __call__(self,
                 inputs: Union[str, List[str]],
                 return_datasamples: bool = False,
                 batch_size: int = 1,
                 **kwargs) -> list:
        """Call the inferencer.

        Args:
            inputs (str | list): The query text or a list of texts.
            return_datasamples (bool): Whether to return results as
                :obj:`DataSample`. Defaults to False.
            batch_size (int): Batch size. Defaults to 1.
            figsize (Tuple[int, int]): The size of the visualization figure.
                Defaults to (16, 9).
            draw_score (bool): Whether to draw the match scores.
                Defaults to True.
            show (bool): Whether to display the visualization result in a
                window. Defaults to False.
            wait_time (float): The display time (s). Defaults to 0, which means
                "forever".
            show_dir (str, optional): If not None, save the visualization
                results in the specified directory. Defaults to None.

        Returns:
            list: The inference results.
        """
        return super().__call__(inputs, return_datasamples, batch_size,
                                **kwargs)

    @torch.no_grad()
    def forward(self, data: dict, **kwargs):
        """Feed the inputs to the model."""
        data = self.model.data_preprocessor(data, False)
        data_samples = data['data_samples']
        feats = self.prototype.copy()
        feats.update(self.model.extract_feat(data_samples=data_samples))
        return self.model.predict_all(feats, data_samples, cal_i2t=False)[0]

    def _init_pipeline(self, cfg: Config) -> Callable:
        test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
        test_transforms = [TRANSFORMS.build(t) for t in test_pipeline_cfg]
        img_info = {'img': np.zeros((224, 224, 3), dtype=np.uint8)}
        text_info = {'text': 'example'}
        img_pipeline = Compose(filter_transforms(test_transforms, img_info))
        text_pipeline = Compose(filter_transforms(test_transforms, text_info))
        return img_pipeline, text_pipeline

    def preprocess(self, inputs: List[str], batch_size: int = 1):

        def process_text(input_: str):
            return self.text_pipeline({'text': input_})

        chunked_data = self._get_chunk_data(
            map(process_text, inputs), batch_size)
        yield from map(default_collate, chunked_data)

    def visualize(self,
                  ori_inputs: List[str],
                  preds: List[DataSample],
                  topk: int = 3,
                  figsize: Tuple[int, int] = (16, 9),
                  show: bool = False,
                  wait_time: int = 0,
                  draw_score=True,
                  show_dir=None):
        if not show and show_dir is None:
            return None

        if self.visualizer is None:
            from mmpretrain.visualization import UniversalVisualizer
            self.visualizer = UniversalVisualizer()

        visualization = []
        for i, (text, data_sample) in enumerate(zip(ori_inputs, preds)):
            name = str(i)

            if show_dir is not None:
                show_dir = Path(show_dir)
                show_dir.mkdir(exist_ok=True)
                out_file = str((show_dir / name).with_suffix('.png'))
            else:
                out_file = None

            self.visualizer.visualize_t2i_retrieval(
                text,
                data_sample,
                self.prototype_dataset,
                topk=topk,
                fig_cfg=dict(figsize=figsize),
                draw_score=draw_score,
                show=show,
                wait_time=wait_time,
                name=name,
                out_file=out_file)
            visualization.append(self.visualizer.get_image())
        if show:
            self.visualizer.close()
        return visualization

    def postprocess(
        self,
        preds: List[DataSample],
        visualization: List[np.ndarray],
        return_datasamples=False,
        topk=1,
    ) -> list:
        if return_datasamples:
            return preds

        results = []
        for data_sample in preds:
            match_scores, indices = torch.topk(data_sample.pred_score, k=topk)
            matches = []
            for match_score, sample_idx in zip(match_scores, indices):
                sample = self.prototype_dataset.get_data_info(
                    sample_idx.item())
                sample_idx = sample.pop('sample_idx')
                matches.append({
                    'match_score': match_score,
                    'sample_idx': sample_idx,
                    'sample': sample
                })
            results.append(matches)

        return results

    @staticmethod
    def list_models(pattern: Optional[str] = None):
        """List all available model names.

        Args:
            pattern (str | None): A wildcard pattern to match model names.

        Returns:
            List[str]: a list of model names.
        """
        return list_models(pattern=pattern, task='Text-To-Image Retrieval')


class ImageToTextRetrievalInferencer(BaseInferencer):
    """The inferencer for image to text retrieval.

    Args:
        model (BaseModel | str | Config): A model name or a path to the config
            file, or a :obj:`BaseModel` object. The model name can be found
            by ``ImageToTextRetrievalInferencer.list_models()`` and you can
            also query it in :doc:`/modelzoo_statistics`.
        prototype (str | list): The texts to be retrieved. It can be the
            following types:

            - str: The file path to load the string list.
            - list: A list of strings.

        prototype_cache (str, optional): The path of the generated prototype
            features. If it exists, directly load the cache instead of
            re-generating the prototype features. If it does not exist, save
            the generated features to the path. Defaults to None.
        fast_match (bool): Some algorithms will record extra image features for
            further matching, which may consume large memory; set True to avoid
            this behavior. Defaults to True.
        pretrained (str, optional): Path to the checkpoint. If None, it will
            try to find a pre-defined weight from the model you specified
            (only works if the ``model`` is a model name). Defaults to None.
        device (str, optional): Device to run inference. If None, the available
            device will be automatically used. Defaults to None.
        **kwargs: Other keyword arguments to initialize the model (only work if
            the ``model`` is a model name).

    Example:
        >>> from mmpretrain import ImageToTextRetrievalInferencer
        >>> inferencer = ImageToTextRetrievalInferencer(
        ...     'blip-base_3rdparty_retrieval',
        ...     prototype=['cat', 'dog', 'snake', 'bird'],
        ...     prototype_cache='i2t_retri.pth')
        >>> inferencer('demo/bird.JPEG')[0]
        {'match_score': tensor(0.3855, device='cuda:0'),
         'sample_idx': 3,
         'text': 'bird'}
    """  # noqa: E501

    visualize_kwargs: set = {
        'draw_score', 'resize', 'show_dir', 'show', 'wait_time', 'topk'
    }
    postprocess_kwargs: set = {'topk'}

    def __init__(self,
                 model: ModelType,
                 prototype,
                 prototype_cache=None,
                 fast_match=True,
                 prepare_batch_size=8,
                 pretrained: Union[bool, str] = True,
                 device: Union[str, torch.device, None] = None,
                 **kwargs) -> None:
        super().__init__(
            model=model, pretrained=pretrained, device=device, **kwargs)

        self.img_pipeline, self.text_pipeline = self.pipeline

        if hasattr(self.model, 'fast_match'):
            self.model.fast_match = fast_match

        self.prototype_dataset = self._prepare_prototype(
            prototype, cache=prototype_cache, batch_size=prepare_batch_size)

    def _prepare_prototype(self, prototype, cache=None, batch_size=8):
        from mmengine.dataset import DefaultSampler
        from torch.utils.data import DataLoader

        def build_dataloader(dataset):
            return DataLoader(
                [
                    self.text_pipeline({
                        'sample_idx': i,
                        'text': text
                    }) for i, text in enumerate(dataset)
                ],
                batch_size=batch_size,
                collate_fn=default_collate,
                sampler=DefaultSampler(dataset, shuffle=False),
                persistent_workers=False,
            )

        if isinstance(prototype, str):
            # A file path of a list of string
            dataset = mmengine.list_from_file(prototype)
        elif mmengine.utils.is_seq_of(prototype, str):
            dataset = prototype
        else:
            raise TypeError(f'Unsupported prototype type {type(prototype)}.')

        dataloader = build_dataloader(dataset)

        if cache is not None and Path(cache).exists():
            self.prototype = torch.load(cache)
        else:
            prototype = []
            for data_batch in track(dataloader, 'Prepare prototype...'):
                with torch.no_grad():
                    data_batch = self.model.data_preprocessor(
                        data_batch, False)
                    feats = self.model._run_forward(data_batch, mode='tensor')
                    prototype.append(feats)
            prototype = {
                k: torch.cat([d[k] for d in prototype])
                for k in prototype[0]
            }
            self.prototype = prototype

        from mmengine.logging import MMLogger
        logger = MMLogger.get_current_instance()
        if cache is None:
            logger.info('The prototype has been prepared, you can use '
                        '`save_prototype` to dump it into a pickle '
                        'file for the future usage.')
        elif not Path(cache).exists():
            self.save_prototype(cache)
            logger.info(f'The prototype has been saved at {cache}.')

        return dataset

    def save_prototype(self, path):
        torch.save(self.prototype, path)

    def __call__(self,
                 inputs: ImageType,
                 return_datasamples: bool = False,
                 batch_size: int = 1,
                 **kwargs) -> list:
        """Call the inferencer.

        Args:
            inputs (str | array | list): The image path or array, or a list of
                images.
            return_datasamples (bool): Whether to return results as
                :obj:`DataSample`. Defaults to False.
            batch_size (int): Batch size. Defaults to 1.
            resize (int, optional): Resize the long edge of the image to the
                specified length before visualization. Defaults to None.
            draw_score (bool): Whether to draw the match scores.
                Defaults to True.
            show (bool): Whether to display the visualization result in a
                window. Defaults to False.
            wait_time (float): The display time (s). Defaults to 0, which means
                "forever".
            show_dir (str, optional): If not None, save the visualization
                results in the specified directory. Defaults to None.

        Returns:
            list: The inference results.
        """
        return super().__call__(inputs, return_datasamples, batch_size,
                                **kwargs)

    @torch.no_grad()
    def forward(self, data: dict, **kwargs):
        """Feed the inputs to the model."""
        data = self.model.data_preprocessor(data, False)
        feats = self.prototype.copy()
        feats.update(self.model.extract_feat(images=data['images']))
        return self.model.predict_all(
            feats, data['data_samples'], cal_t2i=False)[0]

    def _init_pipeline(self, cfg: Config) -> Callable:
        test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
        test_transforms = [TRANSFORMS.build(t) for t in test_pipeline_cfg]
        img_info = {'img': np.zeros((224, 224, 3), dtype=np.uint8)}
        text_info = {'text': 'example'}
        img_pipeline = Compose(filter_transforms(test_transforms, img_info))
        text_pipeline = Compose(filter_transforms(test_transforms, text_info))
        return img_pipeline, text_pipeline

    def preprocess(self, inputs: List[ImageType], batch_size: int = 1):

        def load_image(input_):
            img = imread(input_)
            if img is None:
                raise ValueError(f'Failed to read image {input_}.')
            return dict(
                img=img,
                img_shape=img.shape[:2],
                ori_shape=img.shape[:2],
            )

        pipeline = Compose([load_image, self.img_pipeline])

        chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size)
        yield from map(default_collate, chunked_data)

    def visualize(self,
                  ori_inputs: List[ImageType],
                  preds: List[DataSample],
                  topk: int = 3,
                  resize: Optional[int] = 224,
                  show: bool = False,
                  wait_time: int = 0,
                  draw_score=True,
                  show_dir=None):
        if not show and show_dir is None:
            return None

        if self.visualizer is None:
            from mmpretrain.visualization import UniversalVisualizer
            self.visualizer = UniversalVisualizer()

        visualization = []
        for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)):
            image = imread(input_)
            if isinstance(input_, str):
                # The image loaded from path is BGR format.
                image = image[..., ::-1]
                name = Path(input_).stem
            else:
                name = str(i)

            if show_dir is not None:
                show_dir = Path(show_dir)
                show_dir.mkdir(exist_ok=True)
                out_file = str((show_dir / name).with_suffix('.png'))
            else:
                out_file = None

            self.visualizer.visualize_i2t_retrieval(
                image,
                data_sample,
                self.prototype_dataset,
                topk=topk,
                resize=resize,
                draw_score=draw_score,
                show=show,
                wait_time=wait_time,
                name=name,
                out_file=out_file)
            visualization.append(self.visualizer.get_image())
        if show:
            self.visualizer.close()
        return visualization

    def postprocess(
        self,
        preds: List[DataSample],
        visualization: List[np.ndarray],
        return_datasamples=False,
        topk=1,
    ) -> list:
        if return_datasamples:
            return preds

        results = []
        for data_sample in preds:
            match_scores, indices = torch.topk(data_sample.pred_score, k=topk)
            matches = []
            for match_score, sample_idx in zip(match_scores, indices):
                text = self.prototype_dataset[sample_idx.item()]
                matches.append({
                    'match_score': match_score,
                    'sample_idx': sample_idx,
                    'text': text
                })
            results.append(matches)

        return results

    @staticmethod
    def list_models(pattern: Optional[str] = None):
        """List all available model names.

        Args:
            pattern (str | None): A wildcard pattern to match model names.

        Returns:
            List[str]: a list of model names.
        """
        return list_models(pattern=pattern, task='Image-To-Text Retrieval')
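The prototype caching that both retrieval inferencers share can also be driven explicitly through `save_prototype`; a minimal sketch, reusing the model name from the docstring examples above (the query image path is a placeholder):

from mmpretrain import ImageToTextRetrievalInferencer

# Text features for the prototype are computed once in the constructor.
inferencer = ImageToTextRetrievalInferencer(
    'blip-base_3rdparty_retrieval',
    prototype=['cat', 'dog', 'snake', 'bird'])

# Dump the computed features so the next construction can pass
# `prototype_cache='i2t_retri.pth'` and skip re-encoding the prototype.
inferencer.save_prototype('i2t_retri.pth')

matches = inferencer('demo/bird.JPEG', topk=2)[0]  # assumed query image
for match in matches:
    print(match['text'], float(match['match_score']))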
mmpretrain/apis/nlvr.py
ADDED
@@ -0,0 +1,150 @@
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from typing import Callable, List, Optional, Tuple, Union

import numpy as np
import torch
from mmcv.image import imread
from mmengine.config import Config
from mmengine.dataset import Compose, default_collate

from mmpretrain.registry import TRANSFORMS
from mmpretrain.structures import DataSample
from .base import BaseInferencer
from .model import list_models

InputType = Tuple[Union[str, np.ndarray], Union[str, np.ndarray], str]
InputsType = Union[List[InputType], InputType]


class NLVRInferencer(BaseInferencer):
    """The inferencer for Natural Language for Visual Reasoning.

    Args:
        model (BaseModel | str | Config): A model name or a path to the config
            file, or a :obj:`BaseModel` object. The model name can be found
            by ``NLVRInferencer.list_models()`` and you can also
            query it in :doc:`/modelzoo_statistics`.
        pretrained (str, optional): Path to the checkpoint. If None, it will
            try to find a pre-defined weight from the model you specified
            (only works if the ``model`` is a model name). Defaults to None.
        device (str, optional): Device to run inference. If None, the available
            device will be automatically used. Defaults to None.
        **kwargs: Other keyword arguments to initialize the model (only work if
            the ``model`` is a model name).
    """

    visualize_kwargs: set = {
        'resize', 'draw_score', 'show', 'show_dir', 'wait_time'
    }

    def __call__(self,
                 inputs: InputsType,
                 return_datasamples: bool = False,
                 batch_size: int = 1,
                 **kwargs) -> list:
        """Call the inferencer.

        Args:
            inputs (tuple, List[tuple]): The input data tuples, every tuple
                should include three items (left image, right image, text).
                The image can be a path or numpy array.
            return_datasamples (bool): Whether to return results as
                :obj:`DataSample`. Defaults to False.
            batch_size (int): Batch size. Defaults to 1.
            resize (int, optional): Resize the short edge of the image to the
                specified length before visualization. Defaults to None.
            draw_score (bool): Whether to draw the prediction scores
                of prediction categories. Defaults to True.
            show (bool): Whether to display the visualization result in a
                window. Defaults to False.
            wait_time (float): The display time (s). Defaults to 0, which means
                "forever".
            show_dir (str, optional): If not None, save the visualization
                results in the specified directory. Defaults to None.

        Returns:
            list: The inference results.
        """
        assert isinstance(inputs, (tuple, list))
        if isinstance(inputs, tuple):
            inputs = [inputs]
        for input_ in inputs:
            assert isinstance(input_, tuple)
            assert len(input_) == 3

        return super().__call__(
            inputs,
            return_datasamples=return_datasamples,
            batch_size=batch_size,
            **kwargs)

    def _init_pipeline(self, cfg: Config) -> Callable:
        test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
        assert test_pipeline_cfg[0]['type'] == 'ApplyToList'

        list_pipeline = deepcopy(test_pipeline_cfg[0])
        if list_pipeline.scatter_key == 'img_path':
            # Remove `LoadImageFromFile`
            list_pipeline.transforms.pop(0)
            list_pipeline.scatter_key = 'img'

        test_pipeline = Compose(
            [TRANSFORMS.build(list_pipeline)] +
            [TRANSFORMS.build(t) for t in test_pipeline_cfg[1:]])
        return test_pipeline

    def preprocess(self, inputs: InputsType, batch_size: int = 1):

        def load_image(input_):
            img1 = imread(input_[0])
            img2 = imread(input_[1])
            text = input_[2]
            if img1 is None:
                raise ValueError(f'Failed to read image {input_[0]}.')
            if img2 is None:
                raise ValueError(f'Failed to read image {input_[1]}.')
            return dict(
                img=[img1, img2],
                img_shape=[img1.shape[:2], img2.shape[:2]],
                ori_shape=[img1.shape[:2], img2.shape[:2]],
                text=text,
            )

        pipeline = Compose([load_image, self.pipeline])

        chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size)
        yield from map(default_collate, chunked_data)

    def postprocess(self,
                    preds: List[DataSample],
                    visualization: List[np.ndarray],
                    return_datasamples=False) -> list:
        if return_datasamples:
            return preds

        results = []
        for data_sample in preds:
            pred_scores = data_sample.pred_score
            pred_score = float(torch.max(pred_scores).item())
            pred_label = torch.argmax(pred_scores).item()
            result = {
                'pred_scores': pred_scores.detach().cpu().numpy(),
                'pred_label': pred_label,
                'pred_score': pred_score,
            }
            results.append(result)

        return results

    @staticmethod
    def list_models(pattern: Optional[str] = None):
        """List all available model names.

        Args:
            pattern (str | None): A wildcard pattern to match model names.

        Returns:
            List[str]: a list of model names.
        """
        return list_models(pattern=pattern, task='NLVR')
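A minimal call sketch for `NLVRInferencer`; the checkpoint name is an assumption that follows the BLIP naming pattern used above, and the image paths and statement are placeholders:

from mmpretrain.apis.nlvr import NLVRInferencer

inferencer = NLVRInferencer('blip-base_3rdparty_nlvr')  # assumed model name

# Each input is a (left image, right image, statement) triple; the result
# carries the predicted label and score as built in `postprocess` above.
result = inferencer(
    ('demo/left.png', 'demo/right.png',
     'There are more dogs in the left image than in the right image.'))[0]
print(result['pred_label'], result['pred_score'])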
mmpretrain/apis/utils.py
ADDED
@@ -0,0 +1,270 @@
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import os
|
3 |
+
from collections import defaultdict
|
4 |
+
from contextlib import contextmanager
|
5 |
+
from itertools import chain
|
6 |
+
from typing import Dict, List, Optional, Union
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
|
11 |
+
from mmpretrain.utils import require
|
12 |
+
|
13 |
+
|
14 |
+
@require('torch>=1.9.0', 'https://pytorch.org/get-started/locally/')
|
15 |
+
@require('accelerate')
|
16 |
+
def dispatch_model(
|
17 |
+
model,
|
18 |
+
device_map: Union[str, dict],
|
19 |
+
max_memory: Optional[dict] = None,
|
20 |
+
no_split_module_classes: Optional[List[str]] = None,
|
21 |
+
offload_folder: str = None,
|
22 |
+
offload_buffers: bool = False,
|
23 |
+
preload_module_classes: Optional[List[str]] = None,
|
24 |
+
):
|
25 |
+
"""Split and dispatch a model across devices.
|
26 |
+
|
27 |
+
The function depends on the `accelerate` package. Refers to
|
28 |
+
https://huggingface.co/docs/accelerate/main/en/usage_guides/big_modeling
|
29 |
+
|
30 |
+
Args:
|
31 |
+
model (torch.nn.Module): The model to dispatch.
|
32 |
+
device_map (str | dict | None): A map that specifies where each
|
33 |
+
submodule should go. It doesn't need to be refined to each
|
34 |
+
parameter/buffer name, once a given module name is inside, every
|
35 |
+
submodule of it will be sent to the same device. You can use
|
36 |
+
`device_map="auto"` to automatically generate the device map.
|
37 |
+
Defaults to None.
|
38 |
+
max_memory (dict | None): A dictionary device identifier to maximum
|
39 |
+
memory. Will default to the maximum memory available for each GPU
|
40 |
+
and the available CPU RAM if unset. Defaults to None.
|
41 |
+
no_split_module_classes (List[str] | None): A list of layer class names
|
42 |
+
that should never be split across device (for instance any layer
|
43 |
+
that has a residual connection). If None, try to get the settings
|
44 |
+
from the model class. Defaults to None.
|
45 |
+
offload_folder (str | None): If the `device_map` contains any value
|
46 |
+
`"disk"`, the folder where we will offload weights.
|
47 |
+
offload_buffers (bool): In the layers that are offloaded on the CPU
|
48 |
+
or the hard drive, whether or not to offload the buffers as
|
49 |
+
well as the parameters. Defaults to False.
|
50 |
+
preload_module_classes (List[str] | None): A list of classes whose
|
51 |
+
instances should load all their weights (even in the submodules) at
|
52 |
+
the beginning of the forward. This should only be used for classes
|
53 |
+
that have submodules which are registered but not called directly
|
54 |
+
during the forward, for instance if a `dense` linear layer is
|
55 |
+
registered, but at forward, `dense.weight` and `dense.bias` are
|
56 |
+
used in some operations instead of calling `dense` directly.
|
57 |
+
Defaults to None.
|
58 |
+
"""
|
59 |
+
from accelerate import dispatch_model, infer_auto_device_map
|
60 |
+
|
61 |
+
# Check valid device_map string.
|
62 |
+
valid_map_option = ['auto', 'balanced', 'balanced_low_0', 'sequential']
|
63 |
+
if isinstance(device_map, str) and device_map not in valid_map_option:
|
64 |
+
raise ValueError('If passing a string for `device_map`, please choose '
|
65 |
+
f'from {valid_map_option}.')
|
66 |
+
|
67 |
+
# Generate device map automatically
|
68 |
+
if isinstance(device_map, str):
|
69 |
+
if no_split_module_classes is None:
|
70 |
+
no_split_module_classes = getattr(model, '_no_split_modules', None)
|
71 |
+
if no_split_module_classes is None:
|
72 |
+
raise ValueError(f'{model.__class__.__name__} does not support '
|
73 |
+
f"`device_map='{device_map}'` yet.")
|
74 |
+
|
75 |
+
if device_map != 'sequential':
|
76 |
+
from accelerate.utils import get_balanced_memory
|
77 |
+
max_memory = get_balanced_memory(
|
78 |
+
model,
|
79 |
+
max_memory=max_memory,
|
80 |
+
no_split_module_classes=no_split_module_classes,
|
81 |
+
dtype=None,
|
82 |
+
low_zero=(device_map == 'balanced_low_0'),
|
83 |
+
)
|
84 |
+
max_memory[0] *= 0.9
|
85 |
+
device_map = infer_auto_device_map(
|
86 |
+
model,
|
87 |
+
max_memory=max_memory,
|
88 |
+
no_split_module_classes=no_split_module_classes,
|
89 |
+
dtype=None,
|
90 |
+
)
|
91 |
+
|
92 |
+
if 'disk' in device_map.values():
|
93 |
+
if offload_folder is None:
|
94 |
+
raise ValueError(
|
95 |
+
'The current `device_map` had weights offloaded to the disk. '
|
96 |
+
'Please provide an `offload_folder` for them.')
|
97 |
+
os.makedirs(offload_folder, exist_ok=True)
|
98 |
+
|
99 |
+
main_device = next(
|
100 |
+
(d for d in device_map.values() if d not in ['cpu', 'disk']), 'cpu')
|
101 |
+
|
102 |
+
model = dispatch_model(
|
103 |
+
model,
|
104 |
+
device_map=device_map,
|
105 |
+
main_device=main_device,
|
106 |
+
offload_dir=offload_folder,
|
107 |
+
offload_buffers=offload_buffers,
|
108 |
+
preload_module_classes=preload_module_classes,
|
109 |
+
)
|
110 |
+
if hasattr(model, 'data_preprocessor'):
|
111 |
+
model.data_preprocessor._device = torch.device(main_device)
|
112 |
+
return model
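
A minimal usage sketch for the `dispatch_model` wrapper above; the tiny model and the hand-written device map are illustrative assumptions. Passing an explicit dict skips the automatic `infer_auto_device_map` path, which only works for models that define `_no_split_modules`:

    import torch.nn as nn

    from mmpretrain.apis.utils import dispatch_model

    # Submodules of an nn.Sequential are named '0', '1', '2'.
    model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10))
    # Keep the first block on GPU 0 and offload the rest to CPU.
    model = dispatch_model(model, device_map={'0': 0, '1': 'cpu', '2': 'cpu'})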
@contextmanager
def init_empty_weights(include_buffers: bool = False):
    """A context manager under which models are initialized with all
    parameters on the meta device.

    With this context manager, we can create an empty model. Useful when just
    initializing the model would blow the available RAM.

    Besides moving the parameters to the meta device, this context manager
    also skips checkpoint loading in `mmengine.runner.load_checkpoint` and
    `transformers.PreTrainedModel.from_pretrained`.

    Modified from https://github.com/huggingface/accelerate

    Args:
        include_buffers (bool): Whether to put all buffers on the meta device
            during initialization.
    """
    device = torch.device('meta')

    # move parameter and buffer to meta device
    old_register_parameter = nn.Module.register_parameter
    if include_buffers:
        old_register_buffer = nn.Module.register_buffer
        # See https://github.com/huggingface/accelerate/pull/699
        tensor_constructors_to_patch = {
            torch_function_name: getattr(torch, torch_function_name)
            for torch_function_name in ['empty', 'zeros', 'ones', 'full']
        }

    def register_parameter(module, name, param):
        old_register_parameter(module, name, param)
        if param is not None:
            param_cls = type(module._parameters[name])
            kwargs = module._parameters[name].__dict__
            module._parameters[name] = param_cls(
                module._parameters[name].to(device), **kwargs)

    def register_buffer(module, name, buffer, *args, **kwargs):
        old_register_buffer(module, name, buffer, *args, **kwargs)
        if buffer is not None:
            module._buffers[name] = module._buffers[name].to(device)

    def patch_tensor_constructor(fn):

        def wrapper(*args, **kwargs):
            kwargs['device'] = device
            return fn(*args, **kwargs)

        return wrapper

    # Patch load_checkpoint
    import mmengine.runner.checkpoint as mmengine_load
    old_load_checkpoint = mmengine_load.load_checkpoint

    def patch_load_checkpoint(*args, **kwargs):
        return {}

    # Patch transformers from pretrained
    try:
        from transformers import PreTrainedModel
        from transformers.models.auto.auto_factory import (AutoConfig,
                                                           _BaseAutoModelClass)
        with_transformers = True
    except ImportError:
        with_transformers = False

    @classmethod
    def patch_auto_model(cls, pretrained_model_name_or_path, *model_args,
                         **kwargs):
        cfg = AutoConfig.from_pretrained(pretrained_model_name_or_path,
                                         *model_args, **kwargs)
        return cls.from_config(cfg)

    @classmethod
    def patch_pretrained_model(cls, pretrained_model_name_or_path,
                               *model_args, **kwargs):
        cfg = cls.config_class.from_pretrained(pretrained_model_name_or_path,
                                               *model_args, **kwargs)
        return cls(cfg)

    if with_transformers:
        old_pretrained_model = PreTrainedModel.from_pretrained
        old_auto_model = _BaseAutoModelClass.from_pretrained

    try:
        nn.Module.register_parameter = register_parameter
        mmengine_load.load_checkpoint = patch_load_checkpoint
        if with_transformers:
            PreTrainedModel.from_pretrained = patch_pretrained_model
            _BaseAutoModelClass.from_pretrained = patch_auto_model
        if include_buffers:
            nn.Module.register_buffer = register_buffer
            for func in tensor_constructors_to_patch.keys():
                tensor_constructor = patch_tensor_constructor(
                    getattr(torch, func))
                setattr(torch, func, tensor_constructor)
        yield
    finally:
        nn.Module.register_parameter = old_register_parameter
        mmengine_load.load_checkpoint = old_load_checkpoint
        if with_transformers:
            PreTrainedModel.from_pretrained = old_pretrained_model
            _BaseAutoModelClass.from_pretrained = old_auto_model
        if include_buffers:
            nn.Module.register_buffer = old_register_buffer
            for func, ori in tensor_constructors_to_patch.items():
                setattr(torch, func, ori)
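
A sketch of the context manager above: inside the `with` block, newly created parameters (and, with `include_buffers=True`, buffers) live on the meta device, so no real storage is allocated. The `nn.Linear` stand-in is an illustrative assumption:

    import torch.nn as nn

    from mmpretrain.apis.utils import init_empty_weights

    with init_empty_weights(include_buffers=True):
        skeleton = nn.Linear(4096, 4096)  # no weight storage is allocated

    print(skeleton.weight.device)  # device(type='meta')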
def compute_module_sizes(
        model: nn.Module,
        dtype: Union[str, torch.dtype, None] = None,
        special_dtypes: Optional[Dict[str, Union[str, torch.dtype]]] = None):
    """Compute the size of each submodule of a given model."""

    def get_dtype(dtype):
        if isinstance(dtype, str):
            dtype = getattr(torch, dtype)
        if dtype is not None:
            # torch dtypes are instances of `torch.dtype`, not subclasses.
            assert isinstance(dtype, torch.dtype)
        return dtype

    def dtype_bytes(dtype: torch.dtype):
        if dtype is torch.bool:
            return 1
        if dtype.is_floating_point:
            return torch.finfo(dtype).bits / 8
        else:
            return torch.iinfo(dtype).bits / 8

    if dtype is not None:
        dtype = get_dtype(dtype)
        dtype_size = dtype_bytes(dtype)

    if special_dtypes is not None:
        # Normalize string dtypes before measuring their byte width.
        special_dtypes = {
            key: dtype_bytes(get_dtype(dtype))
            for key, dtype in special_dtypes.items()
        }

    module_sizes = defaultdict(int)
    for name, tensor in chain(
            model.named_parameters(recurse=True),
            model.named_buffers(recurse=True)):
        if special_dtypes is not None and name in special_dtypes:
            size = tensor.numel() * special_dtypes[name]
        elif dtype is None:
            size = tensor.numel() * tensor.element_size()
        else:
            size = tensor.numel() * min(dtype_size, tensor.element_size())
        name_parts = name.split('.')
        for idx in range(len(name_parts) + 1):
            module_sizes['.'.join(name_parts[:idx])] += size

    return module_sizes
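
A sketch of `compute_module_sizes`: the empty-string key holds the total for the whole model, and each dotted prefix accumulates the bytes of its subtree. The toy model and the `float16` cast are assumptions:

    import torch.nn as nn

    from mmpretrain.apis.utils import compute_module_sizes

    model = nn.Sequential(nn.Linear(10, 20), nn.Linear(20, 5))
    sizes = compute_module_sizes(model, dtype='float16')
    # sizes[''] is the total; sizes['0'] and sizes['1'] are the Linear layers.
    print(sizes[''], sizes['0'], sizes['1'])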
mmpretrain/apis/visual_grounding.py
ADDED
@@ -0,0 +1,180 @@
# Copyright (c) OpenMMLab. All rights reserved.
from pathlib import Path
from typing import Callable, List, Optional, Union

import numpy as np
from mmcv.image import imread
from mmengine.config import Config
from mmengine.dataset import Compose, default_collate

from mmpretrain.registry import TRANSFORMS
from mmpretrain.structures import DataSample
from .base import BaseInferencer
from .model import list_models


class VisualGroundingInferencer(BaseInferencer):
    """The inferencer for visual grounding.

    Args:
        model (BaseModel | str | Config): A model name or a path to the config
            file, or a :obj:`BaseModel` object. The model name can be found
            by ``VisualGroundingInferencer.list_models()`` and you can also
            query it in :doc:`/modelzoo_statistics`.
        pretrained (str, optional): Path to the checkpoint. If None, it will
            try to find a pre-defined weight from the model you specified
            (only work if the ``model`` is a model name). Defaults to None.
        device (str, optional): Device to run inference. If None, the available
            device will be automatically used. Defaults to None.
        **kwargs: Other keyword arguments to initialize the model (only work if
            the ``model`` is a model name).

    Example:
        >>> from mmpretrain import VisualGroundingInferencer
        >>> inferencer = VisualGroundingInferencer('ofa-base_3rdparty_refcoco')
        >>> inferencer('demo/cat-dog.png', 'dog')[0]
        {'pred_bboxes': tensor([[ 36.6000,  29.6000, 355.8000, 395.2000]])}
    """  # noqa: E501

    visualize_kwargs: set = {
        'resize', 'show', 'show_dir', 'wait_time', 'line_width', 'bbox_color'
    }

    def __call__(self,
                 images: Union[str, np.ndarray, list],
                 texts: Union[str, list],
                 return_datasamples: bool = False,
                 batch_size: int = 1,
                 **kwargs) -> list:
        """Call the inferencer.

        Args:
            images (str | array | list): The image path or array, or a list of
                images.
            texts (str | list): The text to do visual grounding.
            return_datasamples (bool): Whether to return results as
                :obj:`DataSample`. Defaults to False.
            batch_size (int): Batch size. Defaults to 1.
            resize (int, optional): Resize the short edge of the image to the
                specified length before visualization. Defaults to None.
            show (bool): Whether to display the visualization result in a
                window. Defaults to False.
            wait_time (float): The display time (s). Defaults to 0, which means
                "forever".
            show_dir (str, optional): If not None, save the visualization
                results in the specified directory. Defaults to None.
            line_width (int): The line width of the bbox. Defaults to 3.
            bbox_color (str | tuple): The color of the bbox.
                Defaults to 'green'.

        Returns:
            list: The inference results.
        """
        if not isinstance(images, (list, tuple)):
            assert isinstance(texts, str)
            inputs = [{'img': images, 'text': texts}]
        else:
            inputs = []
            for i in range(len(images)):
                input_ = {'img': images[i], 'text': texts[i]}
                inputs.append(input_)

        return super().__call__(inputs, return_datasamples, batch_size,
                                **kwargs)

    def _init_pipeline(self, cfg: Config) -> Callable:
        test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
        if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile':
            # Image loading is finished in `self.preprocess`.
            test_pipeline_cfg = test_pipeline_cfg[1:]
        test_pipeline = Compose(
            [TRANSFORMS.build(t) for t in test_pipeline_cfg])
        return test_pipeline

    def preprocess(self, inputs: List[dict], batch_size: int = 1):

        def load_image(input_: dict):
            img = imread(input_['img'])
            if img is None:
                raise ValueError(f'Failed to read image {input_}.')
            return {**input_, 'img': img}

        pipeline = Compose([load_image, self.pipeline])

        chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size)
        yield from map(default_collate, chunked_data)

    def visualize(self,
                  ori_inputs: List[dict],
                  preds: List[DataSample],
                  show: bool = False,
                  wait_time: int = 0,
                  resize: Optional[int] = None,
                  line_width: int = 3,
                  bbox_color: Union[str, tuple] = 'green',
                  show_dir=None):
        if not show and show_dir is None:
            return None

        if self.visualizer is None:
            from mmpretrain.visualization import UniversalVisualizer
            self.visualizer = UniversalVisualizer()

        visualization = []
        for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)):
            image = imread(input_['img'])
            if isinstance(input_['img'], str):
                # The image loaded from path is BGR format.
                image = image[..., ::-1]
                name = Path(input_['img']).stem
            else:
                name = str(i)

            if show_dir is not None:
                show_dir = Path(show_dir)
                show_dir.mkdir(exist_ok=True)
                out_file = str((show_dir / name).with_suffix('.png'))
            else:
                out_file = None

            self.visualizer.visualize_visual_grounding(
                image,
                data_sample,
                resize=resize,
                show=show,
                wait_time=wait_time,
                line_width=line_width,
                bbox_color=bbox_color,
                name=name,
                out_file=out_file)
            visualization.append(self.visualizer.get_image())
        if show:
            self.visualizer.close()
        return visualization

    def postprocess(self,
                    preds: List[DataSample],
                    visualization: List[np.ndarray],
                    return_datasamples=False) -> list:
        if return_datasamples:
            return preds

        results = []
        for data_sample in preds:
            results.append({'pred_bboxes': data_sample.get('pred_bboxes')})

        return results

    @staticmethod
    def list_models(pattern: Optional[str] = None):
        """List all available model names.

        Args:
            pattern (str | None): A wildcard pattern to match model names.

        Returns:
            List[str]: a list of model names.
        """
        return list_models(pattern=pattern, task='Visual Grounding')
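
A minimal usage sketch for the inferencer above; the model name follows the class docstring and the image path is an assumption. The batched form takes one grounding text per image:

    from mmpretrain import VisualGroundingInferencer

    inferencer = VisualGroundingInferencer('ofa-base_3rdparty_refcoco')
    results = inferencer(['demo/cat-dog.png'], ['dog'], batch_size=1)
    print(results[0]['pred_bboxes'])  # the predicted box tensor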
mmpretrain/apis/visual_question_answering.py
ADDED
@@ -0,0 +1,181 @@
# Copyright (c) OpenMMLab. All rights reserved.
from pathlib import Path
from typing import Callable, List, Optional, Union

import numpy as np
from mmcv.image import imread
from mmengine.config import Config
from mmengine.dataset import Compose, default_collate

from mmpretrain.registry import TRANSFORMS
from mmpretrain.structures import DataSample
from .base import BaseInferencer
from .model import list_models


class VisualQuestionAnsweringInferencer(BaseInferencer):
    """The inferencer for visual question answering.

    Args:
        model (BaseModel | str | Config): A model name or a path to the config
            file, or a :obj:`BaseModel` object. The model name can be found
            by ``VisualQuestionAnsweringInferencer.list_models()`` and you can
            also query it in :doc:`/modelzoo_statistics`.
        pretrained (str, optional): Path to the checkpoint. If None, it will
            try to find a pre-defined weight from the model you specified
            (only work if the ``model`` is a model name). Defaults to None.
        device (str, optional): Device to run inference. If None, the available
            device will be automatically used. Defaults to None.
        **kwargs: Other keyword arguments to initialize the model (only work if
            the ``model`` is a model name).

    Example:
        >>> from mmpretrain import VisualQuestionAnsweringInferencer
        >>> inferencer = VisualQuestionAnsweringInferencer('ofa-base_3rdparty-zeroshot_vqa')
        >>> inferencer('demo/cat-dog.png', "What's the animal next to the dog?")[0]
        {'question': "What's the animal next to the dog?", 'pred_answer': 'cat'}
    """  # noqa: E501

    visualize_kwargs: set = {'resize', 'show', 'show_dir', 'wait_time'}

    def __call__(self,
                 images: Union[str, np.ndarray, list],
                 questions: Union[str, list],
                 return_datasamples: bool = False,
                 batch_size: int = 1,
                 objects: Optional[List[str]] = None,
                 **kwargs) -> list:
        """Call the inferencer.

        Args:
            images (str | array | list): The image path or array, or a list of
                images.
            questions (str | list): The question to the corresponding image.
            return_datasamples (bool): Whether to return results as
                :obj:`DataSample`. Defaults to False.
            batch_size (int): Batch size. Defaults to 1.
            objects (List[List[str]], optional): Some algorithms, like OFA
                fine-tuned VQA models, require an extra list of object
                descriptions for every image. Defaults to None.
            resize (int, optional): Resize the short edge of the image to the
                specified length before visualization. Defaults to None.
            show (bool): Whether to display the visualization result in a
                window. Defaults to False.
            wait_time (float): The display time (s). Defaults to 0, which means
                "forever".
            show_dir (str, optional): If not None, save the visualization
                results in the specified directory. Defaults to None.

        Returns:
            list: The inference results.
        """
        if not isinstance(images, (list, tuple)):
            assert isinstance(questions, str)
            inputs = [{'img': images, 'question': questions}]
            if objects is not None:
                assert isinstance(objects[0], str)
                inputs[0]['objects'] = objects
        else:
            inputs = []
            for i in range(len(images)):
                input_ = {'img': images[i], 'question': questions[i]}
                if objects is not None:
                    input_['objects'] = objects[i]
                inputs.append(input_)

        return super().__call__(inputs, return_datasamples, batch_size,
                                **kwargs)

    def _init_pipeline(self, cfg: Config) -> Callable:
        test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
        if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile':
            # Image loading is finished in `self.preprocess`.
            test_pipeline_cfg = test_pipeline_cfg[1:]
        test_pipeline = Compose(
            [TRANSFORMS.build(t) for t in test_pipeline_cfg])
        return test_pipeline

    def preprocess(self, inputs: List[dict], batch_size: int = 1):

        def load_image(input_: dict):
            img = imread(input_['img'])
            if img is None:
                raise ValueError(f'Failed to read image {input_}.')
            return {**input_, 'img': img}

        pipeline = Compose([load_image, self.pipeline])

        chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size)
        yield from map(default_collate, chunked_data)

    def visualize(self,
                  ori_inputs: List[dict],
                  preds: List[DataSample],
                  show: bool = False,
                  wait_time: int = 0,
                  resize: Optional[int] = None,
                  show_dir=None):
        if not show and show_dir is None:
            return None

        if self.visualizer is None:
            from mmpretrain.visualization import UniversalVisualizer
            self.visualizer = UniversalVisualizer()

        visualization = []
        for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)):
            image = imread(input_['img'])
            if isinstance(input_['img'], str):
                # The image loaded from path is BGR format.
                image = image[..., ::-1]
                name = Path(input_['img']).stem
            else:
                name = str(i)

            if show_dir is not None:
                show_dir = Path(show_dir)
                show_dir.mkdir(exist_ok=True)
                out_file = str((show_dir / name).with_suffix('.png'))
            else:
                out_file = None

            self.visualizer.visualize_vqa(
                image,
                data_sample,
                resize=resize,
                show=show,
                wait_time=wait_time,
                name=name,
                out_file=out_file)
            visualization.append(self.visualizer.get_image())
        if show:
            self.visualizer.close()
        return visualization

    def postprocess(self,
                    preds: List[DataSample],
                    visualization: List[np.ndarray],
                    return_datasamples=False) -> list:
        if return_datasamples:
            return preds

        results = []
        for data_sample in preds:
            results.append({
                'question': data_sample.get('question'),
                'pred_answer': data_sample.get('pred_answer'),
            })

        return results

    @staticmethod
    def list_models(pattern: Optional[str] = None):
        """List all available model names.

        Args:
            pattern (str | None): A wildcard pattern to match model names.

        Returns:
            List[str]: a list of model names.
        """
        return list_models(pattern=pattern, task='Visual Question Answering')
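
A minimal usage sketch, following the docstring example above (the image path is an assumption):

    from mmpretrain import VisualQuestionAnsweringInferencer

    inferencer = VisualQuestionAnsweringInferencer(
        'ofa-base_3rdparty-zeroshot_vqa')
    result = inferencer('demo/cat-dog.png',
                        "What's the animal next to the dog?")[0]
    print(result['pred_answer'])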
mmpretrain/datasets/__init__.py
ADDED
@@ -0,0 +1,54 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmpretrain.utils.dependency import WITH_MULTIMODAL
from .base_dataset import BaseDataset
from .builder import build_dataset
from .caltech101 import Caltech101
from .cifar import CIFAR10, CIFAR100
from .cub import CUB
from .custom import CustomDataset
from .dataset_wrappers import KFoldDataset
from .dtd import DTD
from .fgvcaircraft import FGVCAircraft
from .flowers102 import Flowers102
from .food101 import Food101
from .imagenet import ImageNet, ImageNet21k
from .inshop import InShop
from .mnist import MNIST, FashionMNIST
from .multi_label import MultiLabelDataset
from .multi_task import MultiTaskDataset
from .nlvr2 import NLVR2
from .oxfordiiitpet import OxfordIIITPet
from .places205 import Places205
from .samplers import *  # noqa: F401,F403
from .stanfordcars import StanfordCars
from .sun397 import SUN397
from .transforms import *  # noqa: F401,F403
from .voc import VOC

__all__ = [
    'BaseDataset', 'CIFAR10', 'CIFAR100', 'CUB', 'Caltech101', 'CustomDataset',
    'DTD', 'FGVCAircraft', 'FashionMNIST', 'Flowers102', 'Food101', 'ImageNet',
    'ImageNet21k', 'InShop', 'KFoldDataset', 'MNIST', 'MultiLabelDataset',
    'MultiTaskDataset', 'NLVR2', 'OxfordIIITPet', 'Places205', 'SUN397',
    'StanfordCars', 'VOC', 'build_dataset'
]

if WITH_MULTIMODAL:
    from .coco_caption import COCOCaption
    from .coco_retrieval import COCORetrieval
    from .coco_vqa import COCOVQA
    from .flamingo import FlamingoEvalCOCOCaption, FlamingoEvalCOCOVQA
    from .refcoco import RefCOCO
    from .scienceqa import ScienceQA
    from .visual_genome import VisualGenomeQA

    __all__.extend([
        'COCOCaption',
        'COCORetrieval',
        'COCOVQA',
        'FlamingoEvalCOCOCaption',
        'FlamingoEvalCOCOVQA',
        'RefCOCO',
        'VisualGenomeQA',
        'ScienceQA',
    ])
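
Because the multimodal datasets above are exported only when the optional dependencies are installed, downstream code can guard its imports on the same flag. A small sketch of that pattern:

    from mmpretrain.utils.dependency import WITH_MULTIMODAL

    if WITH_MULTIMODAL:
        from mmpretrain.datasets import COCOVQA  # noqa: F401
    else:
        print('Multimodal datasets such as COCOVQA are unavailable; '
              'install the multimodal extras to enable them.')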
mmpretrain/datasets/base_dataset.py
ADDED
@@ -0,0 +1,219 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from os import PathLike
from typing import List, Optional, Sequence, Union

import mmengine
import numpy as np
from mmengine.dataset import BaseDataset as _BaseDataset

from mmpretrain.registry import DATASETS, TRANSFORMS


def expanduser(path):
    """Expand ~ and ~user constructions.

    If user or $HOME is unknown, do nothing.
    """
    if isinstance(path, (str, PathLike)):
        return osp.expanduser(path)
    else:
        return path


@DATASETS.register_module()
class BaseDataset(_BaseDataset):
    """Base dataset for image classification task.

    This dataset supports annotation files in the `OpenMMLab 2.0 style
    annotation format`_.

    .. _OpenMMLab 2.0 style annotation format:
        https://github.com/open-mmlab/mmengine/blob/main/docs/zh_cn/tutorials/basedataset.md

    Compared with :class:`mmengine.BaseDataset`, this class implements
    several useful methods.

    Args:
        ann_file (str): Annotation file path.
        metainfo (dict, optional): Meta information for dataset, such as class
            information. Defaults to None.
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``. Defaults to ''.
        data_prefix (str | dict): Prefix for training data. Defaults to ''.
        filter_cfg (dict, optional): Config for filter data. Defaults to None.
        indices (int or Sequence[int], optional): Support using first few
            data in annotation file to facilitate training/testing on a smaller
            dataset. Defaults to None, which means using all ``data_infos``.
        serialize_data (bool): Whether to hold memory using serialized objects,
            when enabled, data loader workers can use shared RAM from master
            process instead of making a copy. Defaults to True.
        pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
        test_mode (bool): ``test_mode=True`` means in test phase, and an error
            will be raised when getting an item fails; ``test_mode=False``
            means in training phase, and another item will be returned
            randomly. Defaults to False.
        lazy_init (bool): Whether to load annotation during instantiation.
            In some cases, such as visualization, only the meta information of
            the dataset is needed, which does not require loading the
            annotation file. ``BaseDataset`` can skip loading annotations to
            save time by setting ``lazy_init=True``. Defaults to False.
        max_refetch (int): If ``BaseDataset.prepare_data`` gets a None image,
            the maximum number of extra cycles to fetch a valid image.
            Defaults to 1000.
        classes (str | Sequence[str], optional): Specify names of classes.

            - If is a string, it should be a file path, and every line of
              the file is the name of a class.
            - If is a sequence of strings, every item is the name of a class.
            - If is None, use categories information in ``metainfo`` argument,
              annotation file or the class attribute ``METAINFO``.

            Defaults to None.
    """  # noqa: E501

    def __init__(self,
                 ann_file: str,
                 metainfo: Optional[dict] = None,
                 data_root: str = '',
                 data_prefix: Union[str, dict] = '',
                 filter_cfg: Optional[dict] = None,
                 indices: Optional[Union[int, Sequence[int]]] = None,
                 serialize_data: bool = True,
                 pipeline: Sequence = (),
                 test_mode: bool = False,
                 lazy_init: bool = False,
                 max_refetch: int = 1000,
                 classes: Union[str, Sequence[str], None] = None):
        if isinstance(data_prefix, str):
            data_prefix = dict(img_path=expanduser(data_prefix))

        ann_file = expanduser(ann_file)
        metainfo = self._compat_classes(metainfo, classes)

        transforms = []
        for transform in pipeline:
            if isinstance(transform, dict):
                transforms.append(TRANSFORMS.build(transform))
            else:
                transforms.append(transform)

        super().__init__(
            ann_file=ann_file,
            metainfo=metainfo,
            data_root=data_root,
            data_prefix=data_prefix,
            filter_cfg=filter_cfg,
            indices=indices,
            serialize_data=serialize_data,
            pipeline=transforms,
            test_mode=test_mode,
            lazy_init=lazy_init,
            max_refetch=max_refetch)

    @property
    def img_prefix(self):
        """The prefix of images."""
        return self.data_prefix['img_path']

    @property
    def CLASSES(self):
        """Return all categories names."""
        return self._metainfo.get('classes', None)

    @property
    def class_to_idx(self):
        """Map class names to class indices.

        Returns:
            dict: mapping from class name to class index.
        """

        return {cat: i for i, cat in enumerate(self.CLASSES)}

    def get_gt_labels(self):
        """Get all ground-truth labels (categories).

        Returns:
            np.ndarray: categories for all images.
        """

        gt_labels = np.array(
            [self.get_data_info(i)['gt_label'] for i in range(len(self))])
        return gt_labels

    def get_cat_ids(self, idx: int) -> List[int]:
        """Get category id by index.

        Args:
            idx (int): Index of data.

        Returns:
            cat_ids (List[int]): Image category of specified index.
        """

        return [int(self.get_data_info(idx)['gt_label'])]

    def _compat_classes(self, metainfo, classes):
        """Merge the old style ``classes`` arguments to ``metainfo``."""
        if isinstance(classes, str):
            # take it as a file path
            class_names = mmengine.list_from_file(expanduser(classes))
        elif isinstance(classes, (tuple, list)):
            class_names = classes
        elif classes is not None:
            raise ValueError(f'Unsupported type {type(classes)} of classes.')

        if metainfo is None:
            metainfo = {}

        if classes is not None:
            metainfo = {'classes': tuple(class_names), **metainfo}

        return metainfo

    def full_init(self):
        """Load annotation file and set ``BaseDataset._fully_initialized`` to
        True."""
        super().full_init()

        # To support the standard OpenMMLab 2.0 annotation format. Generate
        # metainfo in internal format from standard metainfo format.
        if 'categories' in self._metainfo and 'classes' not in self._metainfo:
            categories = sorted(
                self._metainfo['categories'], key=lambda x: x['id'])
            self._metainfo['classes'] = tuple(
                [cat['category_name'] for cat in categories])

    def __repr__(self):
        """Print the basic information of the dataset.

        Returns:
            str: Formatted string.
        """
        head = 'Dataset ' + self.__class__.__name__
        body = []
        if self._fully_initialized:
            body.append(f'Number of samples: \t{self.__len__()}')
        else:
            body.append("Haven't been initialized")

        if self.CLASSES is not None:
            body.append(f'Number of categories: \t{len(self.CLASSES)}')

        body.extend(self.extra_repr())

        if len(self.pipeline.transforms) > 0:
            body.append('With transforms:')
            for t in self.pipeline.transforms:
                body.append(f'    {t}')

        lines = [head] + [' ' * 4 + line for line in body]
        return '\n'.join(lines)

    def extra_repr(self) -> List[str]:
        """The extra repr information of the dataset."""
        body = []
        body.append(f'Annotation file: \t{self.ann_file}')
        body.append(f'Prefix of images: \t{self.img_prefix}')
        return body
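
A sketch of the ``classes``/``metainfo`` plumbing in the class above; the annotation path and class names are assumptions. With ``lazy_init=True`` the metainfo is available without reading the annotation file:

    from mmpretrain.datasets import BaseDataset

    dataset = BaseDataset(
        ann_file='meta/train.txt',      # assumed annotation file
        data_root='data/my_dataset',    # assumed layout
        classes=['cat', 'dog'],         # merged into metainfo as 'classes'
        lazy_init=True)                 # skip annotation loading for now
    print(dataset.class_to_idx)         # {'cat': 0, 'dog': 1}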
mmpretrain/datasets/builder.py
ADDED
@@ -0,0 +1,25 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmpretrain.registry import DATASETS


def build_dataset(cfg):
    """Build dataset.

    Examples:
        >>> from mmpretrain.datasets import build_dataset
        >>> mnist_train = build_dataset(
        ...     dict(type='MNIST', data_prefix='data/mnist/', test_mode=False))
        >>> print(mnist_train)
        Dataset MNIST
            Number of samples:  60000
            Number of categories:       10
            Prefix of data:     data/mnist/
        >>> mnist_test = build_dataset(
        ...     dict(type='MNIST', data_prefix='data/mnist/', test_mode=True))
        >>> print(mnist_test)
        Dataset MNIST
            Number of samples:  10000
            Number of categories:       10
            Prefix of data:     data/mnist/
    """
    return DATASETS.build(cfg)
mmpretrain/datasets/caltech101.py
ADDED
@@ -0,0 +1,113 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

from mmengine import get_file_backend, list_from_file

from mmpretrain.registry import DATASETS
from .base_dataset import BaseDataset
from .categories import CALTECH101_CATEGORIES


@DATASETS.register_module()
class Caltech101(BaseDataset):
    """The Caltech101 Dataset.

    Support the `Caltech101 <https://data.caltech.edu/records/mzrjq-6wc02>`_ Dataset.
    After downloading and decompression, the dataset directory structure is as follows.

    Caltech101 dataset directory: ::

        caltech-101
        ├── 101_ObjectCategories
        │   ├── class_x
        │   │   ├── xx1.jpg
        │   │   ├── xx2.jpg
        │   │   └── ...
        │   ├── class_y
        │   │   ├── yy1.jpg
        │   │   ├── yy2.jpg
        │   │   └── ...
        │   └── ...
        ├── Annotations
        │   ├── class_x
        │   │   ├── xx1.mat
        │   │   └── ...
        │   └── ...
        ├── meta
        │   ├── train.txt
        │   └── test.txt
        └── ....

    Please note that since there is no official split for training and
    test set, you can use the train.txt and test.txt provided by us or
    create your own annotation files. Here is the download
    `link <https://download.openmmlab.com/mmpretrain/datasets/caltech_meta.zip>`_
    for the annotations.

    Args:
        data_root (str): The root directory for the Caltech101 dataset.
        split (str, optional): The dataset split, supports "train" and "test".
            Defaults to "train".

    Examples:
        >>> from mmpretrain.datasets import Caltech101
        >>> train_dataset = Caltech101(data_root='data/caltech-101', split='train')
        >>> train_dataset
        Dataset Caltech101
            Number of samples:  3060
            Number of categories:       102
            Root of dataset:    data/caltech-101
        >>> test_dataset = Caltech101(data_root='data/caltech-101', split='test')
        >>> test_dataset
        Dataset Caltech101
            Number of samples:  6728
            Number of categories:       102
            Root of dataset:    data/caltech-101
    """  # noqa: E501

    METAINFO = {'classes': CALTECH101_CATEGORIES}

    def __init__(self, data_root: str, split: str = 'train', **kwargs):

        splits = ['train', 'test']
        assert split in splits, \
            f"The split must be one of {splits}, but got '{split}'"
        self.split = split

        self.backend = get_file_backend(data_root, enable_singleton=True)

        if split == 'train':
            ann_file = self.backend.join_path('meta', 'train.txt')
        else:
            ann_file = self.backend.join_path('meta', 'test.txt')

        data_prefix = '101_ObjectCategories'
        test_mode = split == 'test'

        super(Caltech101, self).__init__(
            ann_file=ann_file,
            data_root=data_root,
            data_prefix=data_prefix,
            test_mode=test_mode,
            **kwargs)

    def load_data_list(self):
        """Load images and ground truth labels."""

        pairs = list_from_file(self.ann_file)
        data_list = []

        for pair in pairs:
            path, gt_label = pair.split()
            img_path = self.backend.join_path(self.img_prefix, path)
            info = dict(img_path=img_path, gt_label=int(gt_label))
            data_list.append(info)

        return data_list

    def extra_repr(self) -> List[str]:
        """The extra repr information of the dataset."""
        body = [
            f'Root of dataset: \t{self.data_root}',
        ]
        return body
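
From the ``pair.split()`` parsing in ``load_data_list`` above, each line of meta/train.txt pairs a path under 101_ObjectCategories with an integer label, e.g. ``accordion/image_0001.jpg 0`` (the filenames here are assumptions). A minimal usage sketch:

    from mmpretrain.datasets import Caltech101

    train_set = Caltech101(data_root='data/caltech-101', split='train')
    print(len(train_set), train_set.CLASSES[:3])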
mmpretrain/datasets/categories.py
ADDED
@@ -0,0 +1,1440 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Pre-defined categories names of various datasets.

VOC2007_CATEGORIES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
                      'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
                      'horse', 'motorbike', 'person', 'pottedplant', 'sheep',
                      'sofa', 'train', 'tvmonitor')

CUB_CATEGORIES = (
    'Black_footed_Albatross', 'Laysan_Albatross', 'Sooty_Albatross',
    'Groove_billed_Ani', 'Crested_Auklet', 'Least_Auklet', 'Parakeet_Auklet',
    'Rhinoceros_Auklet', 'Brewer_Blackbird', 'Red_winged_Blackbird',
    'Rusty_Blackbird', 'Yellow_headed_Blackbird', 'Bobolink', 'Indigo_Bunting',
    'Lazuli_Bunting', 'Painted_Bunting', 'Cardinal', 'Spotted_Catbird',
    'Gray_Catbird', 'Yellow_breasted_Chat', 'Eastern_Towhee',
    'Chuck_will_Widow', 'Brandt_Cormorant', 'Red_faced_Cormorant',
    'Pelagic_Cormorant', 'Bronzed_Cowbird', 'Shiny_Cowbird', 'Brown_Creeper',
    'American_Crow', 'Fish_Crow', 'Black_billed_Cuckoo', 'Mangrove_Cuckoo',
    'Yellow_billed_Cuckoo', 'Gray_crowned_Rosy_Finch', 'Purple_Finch',
    'Northern_Flicker', 'Acadian_Flycatcher', 'Great_Crested_Flycatcher',
    'Least_Flycatcher', 'Olive_sided_Flycatcher', 'Scissor_tailed_Flycatcher',
    'Vermilion_Flycatcher', 'Yellow_bellied_Flycatcher', 'Frigatebird',
    'Northern_Fulmar', 'Gadwall', 'American_Goldfinch', 'European_Goldfinch',
    'Boat_tailed_Grackle', 'Eared_Grebe', 'Horned_Grebe', 'Pied_billed_Grebe',
    'Western_Grebe', 'Blue_Grosbeak', 'Evening_Grosbeak', 'Pine_Grosbeak',
    'Rose_breasted_Grosbeak', 'Pigeon_Guillemot', 'California_Gull',
    'Glaucous_winged_Gull', 'Heermann_Gull', 'Herring_Gull', 'Ivory_Gull',
    'Ring_billed_Gull', 'Slaty_backed_Gull', 'Western_Gull',
    'Anna_Hummingbird', 'Ruby_throated_Hummingbird', 'Rufous_Hummingbird',
    'Green_Violetear', 'Long_tailed_Jaeger', 'Pomarine_Jaeger', 'Blue_Jay',
    'Florida_Jay', 'Green_Jay', 'Dark_eyed_Junco', 'Tropical_Kingbird',
    'Gray_Kingbird', 'Belted_Kingfisher', 'Green_Kingfisher',
    'Pied_Kingfisher', 'Ringed_Kingfisher', 'White_breasted_Kingfisher',
    'Red_legged_Kittiwake', 'Horned_Lark', 'Pacific_Loon', 'Mallard',
    'Western_Meadowlark', 'Hooded_Merganser', 'Red_breasted_Merganser',
    'Mockingbird', 'Nighthawk', 'Clark_Nutcracker', 'White_breasted_Nuthatch',
    'Baltimore_Oriole', 'Hooded_Oriole', 'Orchard_Oriole', 'Scott_Oriole',
    'Ovenbird', 'Brown_Pelican', 'White_Pelican', 'Western_Wood_Pewee',
    'Sayornis', 'American_Pipit', 'Whip_poor_Will', 'Horned_Puffin',
    'Common_Raven', 'White_necked_Raven', 'American_Redstart', 'Geococcyx',
    'Loggerhead_Shrike', 'Great_Grey_Shrike', 'Baird_Sparrow',
    'Black_throated_Sparrow', 'Brewer_Sparrow', 'Chipping_Sparrow',
    'Clay_colored_Sparrow', 'House_Sparrow', 'Field_Sparrow', 'Fox_Sparrow',
    'Grasshopper_Sparrow', 'Harris_Sparrow', 'Henslow_Sparrow',
    'Le_Conte_Sparrow', 'Lincoln_Sparrow', 'Nelson_Sharp_tailed_Sparrow',
    'Savannah_Sparrow', 'Seaside_Sparrow', 'Song_Sparrow', 'Tree_Sparrow',
    'Vesper_Sparrow', 'White_crowned_Sparrow', 'White_throated_Sparrow',
    'Cape_Glossy_Starling', 'Bank_Swallow', 'Barn_Swallow', 'Cliff_Swallow',
    'Tree_Swallow', 'Scarlet_Tanager', 'Summer_Tanager', 'Artic_Tern',
    'Black_Tern', 'Caspian_Tern', 'Common_Tern', 'Elegant_Tern',
    'Forsters_Tern', 'Least_Tern', 'Green_tailed_Towhee', 'Brown_Thrasher',
    'Sage_Thrasher', 'Black_capped_Vireo', 'Blue_headed_Vireo',
    'Philadelphia_Vireo', 'Red_eyed_Vireo', 'Warbling_Vireo',
    'White_eyed_Vireo', 'Yellow_throated_Vireo', 'Bay_breasted_Warbler',
    'Black_and_white_Warbler', 'Black_throated_Blue_Warbler',
    'Blue_winged_Warbler', 'Canada_Warbler', 'Cape_May_Warbler',
    'Cerulean_Warbler', 'Chestnut_sided_Warbler', 'Golden_winged_Warbler',
    'Hooded_Warbler', 'Kentucky_Warbler', 'Magnolia_Warbler',
    'Mourning_Warbler', 'Myrtle_Warbler', 'Nashville_Warbler',
    'Orange_crowned_Warbler', 'Palm_Warbler', 'Pine_Warbler',
    'Prairie_Warbler', 'Prothonotary_Warbler', 'Swainson_Warbler',
    'Tennessee_Warbler', 'Wilson_Warbler', 'Worm_eating_Warbler',
    'Yellow_Warbler', 'Northern_Waterthrush', 'Louisiana_Waterthrush',
    'Bohemian_Waxwing', 'Cedar_Waxwing', 'American_Three_toed_Woodpecker',
    'Pileated_Woodpecker', 'Red_bellied_Woodpecker', 'Red_cockaded_Woodpecker',
    'Red_headed_Woodpecker', 'Downy_Woodpecker', 'Bewick_Wren', 'Cactus_Wren',
    'Carolina_Wren', 'House_Wren', 'Marsh_Wren', 'Rock_Wren', 'Winter_Wren',
    'Common_Yellowthroat')

IMAGENET_CATEGORIES = (
    'tench, Tinca tinca',
    'goldfish, Carassius auratus',
    'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias',  # noqa: E501
    'tiger shark, Galeocerdo cuvieri',
    'hammerhead, hammerhead shark',
    'electric ray, crampfish, numbfish, torpedo',
    'stingray',
    'cock',
    'hen',
    'ostrich, Struthio camelus',
    'brambling, Fringilla montifringilla',
    'goldfinch, Carduelis carduelis',
    'house finch, linnet, Carpodacus mexicanus',
    'junco, snowbird',
    'indigo bunting, indigo finch, indigo bird, Passerina cyanea',
    'robin, American robin, Turdus migratorius',
    'bulbul',
    'jay',
    'magpie',
    'chickadee',
    'water ouzel, dipper',
    'kite',
    'bald eagle, American eagle, Haliaeetus leucocephalus',
    'vulture',
    'great grey owl, great gray owl, Strix nebulosa',
    'European fire salamander, Salamandra salamandra',
    'common newt, Triturus vulgaris',
    'eft',
    'spotted salamander, Ambystoma maculatum',
    'axolotl, mud puppy, Ambystoma mexicanum',
    'bullfrog, Rana catesbeiana',
    'tree frog, tree-frog',
    'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui',
    'loggerhead, loggerhead turtle, Caretta caretta',
    'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea',  # noqa: E501
    'mud turtle',
    'terrapin',
    'box turtle, box tortoise',
    'banded gecko',
    'common iguana, iguana, Iguana iguana',
    'American chameleon, anole, Anolis carolinensis',
    'whiptail, whiptail lizard',
    'agama',
    'frilled lizard, Chlamydosaurus kingi',
    'alligator lizard',
    'Gila monster, Heloderma suspectum',
    'green lizard, Lacerta viridis',
    'African chameleon, Chamaeleo chamaeleon',
    'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis',  # noqa: E501
    'African crocodile, Nile crocodile, Crocodylus niloticus',
    'American alligator, Alligator mississipiensis',
    'triceratops',
    'thunder snake, worm snake, Carphophis amoenus',
    'ringneck snake, ring-necked snake, ring snake',
    'hognose snake, puff adder, sand viper',
    'green snake, grass snake',
    'king snake, kingsnake',
    'garter snake, grass snake',
    'water snake',
    'vine snake',
    'night snake, Hypsiglena torquata',
    'boa constrictor, Constrictor constrictor',
    'rock python, rock snake, Python sebae',
    'Indian cobra, Naja naja',
    'green mamba',
    'sea snake',
    'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus',
    'diamondback, diamondback rattlesnake, Crotalus adamanteus',
    'sidewinder, horned rattlesnake, Crotalus cerastes',
    'trilobite',
    'harvestman, daddy longlegs, Phalangium opilio',
    'scorpion',
    'black and gold garden spider, Argiope aurantia',
    'barn spider, Araneus cavaticus',
    'garden spider, Aranea diademata',
    'black widow, Latrodectus mactans',
    'tarantula',
    'wolf spider, hunting spider',
    'tick',
    'centipede',
    'black grouse',
    'ptarmigan',
    'ruffed grouse, partridge, Bonasa umbellus',
    'prairie chicken, prairie grouse, prairie fowl',
    'peacock',
    'quail',
    'partridge',
    'African grey, African gray, Psittacus erithacus',
    'macaw',
    'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita',
    'lorikeet',
    'coucal',
    'bee eater',
    'hornbill',
    'hummingbird',
    'jacamar',
    'toucan',
    'drake',
    'red-breasted merganser, Mergus serrator',
    'goose',
    'black swan, Cygnus atratus',
    'tusker',
    'echidna, spiny anteater, anteater',
    'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus',  # noqa: E501
    'wallaby, brush kangaroo',
    'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus',  # noqa: E501
    'wombat',
    'jellyfish',
    'sea anemone, anemone',
    'brain coral',
    'flatworm, platyhelminth',
    'nematode, nematode worm, roundworm',
    'conch',
    'snail',
    'slug',
    'sea slug, nudibranch',
    'chiton, coat-of-mail shell, sea cradle, polyplacophore',
    'chambered nautilus, pearly nautilus, nautilus',
    'Dungeness crab, Cancer magister',
    'rock crab, Cancer irroratus',
    'fiddler crab',
    'king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica',  # noqa: E501
    'American lobster, Northern lobster, Maine lobster, Homarus americanus',  # noqa: E501
    'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish',  # noqa: E501
    'crayfish, crawfish, crawdad, crawdaddy',
    'hermit crab',
    'isopod',
    'white stork, Ciconia ciconia',
    'black stork, Ciconia nigra',
    'spoonbill',
    'flamingo',
    'little blue heron, Egretta caerulea',
    'American egret, great white heron, Egretta albus',
    'bittern',
    'crane',
    'limpkin, Aramus pictus',
    'European gallinule, Porphyrio porphyrio',
    'American coot, marsh hen, mud hen, water hen, Fulica americana',
    'bustard',
    'ruddy turnstone, Arenaria interpres',
    'red-backed sandpiper, dunlin, Erolia alpina',
    'redshank, Tringa totanus',
    'dowitcher',
    'oystercatcher, oyster catcher',
    'pelican',
    'king penguin, Aptenodytes patagonica',
    'albatross, mollymawk',
    'grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus',  # noqa: E501
    'killer whale, killer, orca, grampus, sea wolf, Orcinus orca',
    'dugong, Dugong dugon',
    'sea lion',
    'Chihuahua',
    'Japanese spaniel',
    'Maltese dog, Maltese terrier, Maltese',
    'Pekinese, Pekingese, Peke',
    'Shih-Tzu',
    'Blenheim spaniel',
    'papillon',
    'toy terrier',
    'Rhodesian ridgeback',
    'Afghan hound, Afghan',
    'basset, basset hound',
    'beagle',
    'bloodhound, sleuthhound',
    'bluetick',
    'black-and-tan coonhound',
    'Walker hound, Walker foxhound',
    'English foxhound',
    'redbone',
    'borzoi, Russian wolfhound',
    'Irish wolfhound',
    'Italian greyhound',
    'whippet',
    'Ibizan hound, Ibizan Podenco',
    'Norwegian elkhound, elkhound',
    'otterhound, otter hound',
    'Saluki, gazelle hound',
    'Scottish deerhound, deerhound',
    'Weimaraner',
    'Staffordshire bullterrier, Staffordshire bull terrier',
    'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier',  # noqa: E501
    'Bedlington terrier',
    'Border terrier',
    'Kerry blue terrier',
    'Irish terrier',
    'Norfolk terrier',
    'Norwich terrier',
    'Yorkshire terrier',
    'wire-haired fox terrier',
    'Lakeland terrier',
    'Sealyham terrier, Sealyham',
    'Airedale, Airedale terrier',
    'cairn, cairn terrier',
    'Australian terrier',
    'Dandie Dinmont, Dandie Dinmont terrier',
    'Boston bull, Boston terrier',
    'miniature schnauzer',
    'giant schnauzer',
    'standard schnauzer',
    'Scotch terrier, Scottish terrier, Scottie',
    'Tibetan terrier, chrysanthemum dog',
    'silky terrier, Sydney silky',
    'soft-coated wheaten terrier',
    'West Highland white terrier',
    'Lhasa, Lhasa apso',
    'flat-coated retriever',
    'curly-coated retriever',
    'golden retriever',
    'Labrador retriever',
    'Chesapeake Bay retriever',
    'German short-haired pointer',
    'vizsla, Hungarian pointer',
    'English setter',
    'Irish setter, red setter',
    'Gordon setter',
    'Brittany spaniel',
    'clumber, clumber spaniel',
|
288 |
+
'English springer, English springer spaniel',
|
289 |
+
'Welsh springer spaniel',
|
290 |
+
'cocker spaniel, English cocker spaniel, cocker',
|
291 |
+
'Sussex spaniel',
|
292 |
+
'Irish water spaniel',
|
293 |
+
'kuvasz',
|
294 |
+
'schipperke',
|
295 |
+
'groenendael',
|
296 |
+
'malinois',
|
297 |
+
'briard',
|
298 |
+
'kelpie',
|
299 |
+
'komondor',
|
300 |
+
'Old English sheepdog, bobtail',
|
301 |
+
'Shetland sheepdog, Shetland sheep dog, Shetland',
|
302 |
+
'collie',
|
303 |
+
'Border collie',
|
304 |
+
'Bouvier des Flandres, Bouviers des Flandres',
|
305 |
+
'Rottweiler',
|
306 |
+
'German shepherd, German shepherd dog, German police dog, alsatian',
|
307 |
+
'Doberman, Doberman pinscher',
|
308 |
+
'miniature pinscher',
|
309 |
+
'Greater Swiss Mountain dog',
|
310 |
+
'Bernese mountain dog',
|
311 |
+
'Appenzeller',
|
312 |
+
'EntleBucher',
|
313 |
+
'boxer',
|
314 |
+
'bull mastiff',
|
315 |
+
'Tibetan mastiff',
|
316 |
+
'French bulldog',
|
317 |
+
'Great Dane',
|
318 |
+
'Saint Bernard, St Bernard',
|
319 |
+
'Eskimo dog, husky',
|
320 |
+
'malamute, malemute, Alaskan malamute',
|
321 |
+
'Siberian husky',
|
322 |
+
'dalmatian, coach dog, carriage dog',
|
323 |
+
'affenpinscher, monkey pinscher, monkey dog',
|
324 |
+
'basenji',
|
325 |
+
'pug, pug-dog',
|
326 |
+
'Leonberg',
|
327 |
+
'Newfoundland, Newfoundland dog',
|
328 |
+
'Great Pyrenees',
|
329 |
+
'Samoyed, Samoyede',
|
330 |
+
'Pomeranian',
|
331 |
+
'chow, chow chow',
|
332 |
+
'keeshond',
|
333 |
+
'Brabancon griffon',
|
334 |
+
'Pembroke, Pembroke Welsh corgi',
|
335 |
+
'Cardigan, Cardigan Welsh corgi',
|
336 |
+
'toy poodle',
|
337 |
+
'miniature poodle',
|
338 |
+
'standard poodle',
|
339 |
+
'Mexican hairless',
|
340 |
+
'timber wolf, grey wolf, gray wolf, Canis lupus',
|
341 |
+
'white wolf, Arctic wolf, Canis lupus tundrarum',
|
342 |
+
'red wolf, maned wolf, Canis rufus, Canis niger',
|
343 |
+
'coyote, prairie wolf, brush wolf, Canis latrans',
|
344 |
+
'dingo, warrigal, warragal, Canis dingo',
|
345 |
+
'dhole, Cuon alpinus',
|
346 |
+
'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus',
|
347 |
+
'hyena, hyaena',
|
348 |
+
'red fox, Vulpes vulpes',
|
349 |
+
'kit fox, Vulpes macrotis',
|
350 |
+
'Arctic fox, white fox, Alopex lagopus',
|
351 |
+
'grey fox, gray fox, Urocyon cinereoargenteus',
|
352 |
+
'tabby, tabby cat',
|
353 |
+
'tiger cat',
|
354 |
+
'Persian cat',
|
355 |
+
'Siamese cat, Siamese',
|
356 |
+
'Egyptian cat',
|
357 |
+
'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor', # noqa: E501
|
358 |
+
'lynx, catamount',
|
359 |
+
'leopard, Panthera pardus',
|
360 |
+
'snow leopard, ounce, Panthera uncia',
|
361 |
+
'jaguar, panther, Panthera onca, Felis onca',
|
362 |
+
'lion, king of beasts, Panthera leo',
|
363 |
+
'tiger, Panthera tigris',
|
364 |
+
'cheetah, chetah, Acinonyx jubatus',
|
365 |
+
'brown bear, bruin, Ursus arctos',
|
366 |
+
'American black bear, black bear, Ursus americanus, Euarctos americanus', # noqa: E501
|
367 |
+
'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus',
|
368 |
+
'sloth bear, Melursus ursinus, Ursus ursinus',
|
369 |
+
'mongoose',
|
370 |
+
'meerkat, mierkat',
|
371 |
+
'tiger beetle',
|
372 |
+
'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle',
|
373 |
+
'ground beetle, carabid beetle',
|
374 |
+
'long-horned beetle, longicorn, longicorn beetle',
|
375 |
+
'leaf beetle, chrysomelid',
|
376 |
+
'dung beetle',
|
377 |
+
'rhinoceros beetle',
|
378 |
+
'weevil',
|
379 |
+
'fly',
|
380 |
+
'bee',
|
381 |
+
'ant, emmet, pismire',
|
382 |
+
'grasshopper, hopper',
|
383 |
+
'cricket',
|
384 |
+
'walking stick, walkingstick, stick insect',
|
385 |
+
'cockroach, roach',
|
386 |
+
'mantis, mantid',
|
387 |
+
'cicada, cicala',
|
388 |
+
'leafhopper',
|
389 |
+
'lacewing, lacewing fly',
|
390 |
+
"dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk", # noqa: E501
|
391 |
+
'damselfly',
|
392 |
+
'admiral',
|
393 |
+
'ringlet, ringlet butterfly',
|
394 |
+
'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus',
|
395 |
+
'cabbage butterfly',
|
396 |
+
'sulphur butterfly, sulfur butterfly',
|
397 |
+
'lycaenid, lycaenid butterfly',
|
398 |
+
'starfish, sea star',
|
399 |
+
'sea urchin',
|
400 |
+
'sea cucumber, holothurian',
|
401 |
+
'wood rabbit, cottontail, cottontail rabbit',
|
402 |
+
'hare',
|
403 |
+
'Angora, Angora rabbit',
|
404 |
+
'hamster',
|
405 |
+
'porcupine, hedgehog',
|
406 |
+
'fox squirrel, eastern fox squirrel, Sciurus niger',
|
407 |
+
'marmot',
|
408 |
+
'beaver',
|
409 |
+
'guinea pig, Cavia cobaya',
|
410 |
+
'sorrel',
|
411 |
+
'zebra',
|
412 |
+
'hog, pig, grunter, squealer, Sus scrofa',
|
413 |
+
'wild boar, boar, Sus scrofa',
|
414 |
+
'warthog',
|
415 |
+
'hippopotamus, hippo, river horse, Hippopotamus amphibius',
|
416 |
+
'ox',
|
417 |
+
'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis',
|
418 |
+
'bison',
|
419 |
+
'ram, tup',
|
420 |
+
'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis', # noqa: E501
|
421 |
+
'ibex, Capra ibex',
|
422 |
+
'hartebeest',
|
423 |
+
'impala, Aepyceros melampus',
|
424 |
+
'gazelle',
|
425 |
+
'Arabian camel, dromedary, Camelus dromedarius',
|
426 |
+
'llama',
|
427 |
+
'weasel',
|
428 |
+
'mink',
|
429 |
+
'polecat, fitch, foulmart, foumart, Mustela putorius',
|
430 |
+
'black-footed ferret, ferret, Mustela nigripes',
|
431 |
+
'otter',
|
432 |
+
'skunk, polecat, wood pussy',
|
433 |
+
'badger',
|
434 |
+
'armadillo',
|
435 |
+
'three-toed sloth, ai, Bradypus tridactylus',
|
436 |
+
'orangutan, orang, orangutang, Pongo pygmaeus',
|
437 |
+
'gorilla, Gorilla gorilla',
|
438 |
+
'chimpanzee, chimp, Pan troglodytes',
|
439 |
+
'gibbon, Hylobates lar',
|
440 |
+
'siamang, Hylobates syndactylus, Symphalangus syndactylus',
|
441 |
+
'guenon, guenon monkey',
|
442 |
+
'patas, hussar monkey, Erythrocebus patas',
|
443 |
+
'baboon',
|
444 |
+
'macaque',
|
445 |
+
'langur',
|
446 |
+
'colobus, colobus monkey',
|
447 |
+
'proboscis monkey, Nasalis larvatus',
|
448 |
+
'marmoset',
|
449 |
+
'capuchin, ringtail, Cebus capucinus',
|
450 |
+
'howler monkey, howler',
|
451 |
+
'titi, titi monkey',
|
452 |
+
'spider monkey, Ateles geoffroyi',
|
453 |
+
'squirrel monkey, Saimiri sciureus',
|
454 |
+
'Madagascar cat, ring-tailed lemur, Lemur catta',
|
455 |
+
'indri, indris, Indri indri, Indri brevicaudatus',
|
456 |
+
'Indian elephant, Elephas maximus',
|
457 |
+
'African elephant, Loxodonta africana',
|
458 |
+
'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens',
|
459 |
+
'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca',
|
460 |
+
'barracouta, snoek',
|
461 |
+
'eel',
|
462 |
+
'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch', # noqa: E501
|
463 |
+
'rock beauty, Holocanthus tricolor',
|
464 |
+
'anemone fish',
|
465 |
+
'sturgeon',
|
466 |
+
'gar, garfish, garpike, billfish, Lepisosteus osseus',
|
467 |
+
'lionfish',
|
468 |
+
'puffer, pufferfish, blowfish, globefish',
|
469 |
+
'abacus',
|
470 |
+
'abaya',
|
471 |
+
"academic gown, academic robe, judge's robe",
|
472 |
+
'accordion, piano accordion, squeeze box',
|
473 |
+
'acoustic guitar',
|
474 |
+
'aircraft carrier, carrier, flattop, attack aircraft carrier',
|
475 |
+
'airliner',
|
476 |
+
'airship, dirigible',
|
477 |
+
'altar',
|
478 |
+
'ambulance',
|
479 |
+
'amphibian, amphibious vehicle',
|
480 |
+
'analog clock',
|
481 |
+
'apiary, bee house',
|
482 |
+
'apron',
|
483 |
+
'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin', # noqa: E501
|
484 |
+
'assault rifle, assault gun',
|
485 |
+
'backpack, back pack, knapsack, packsack, rucksack, haversack',
|
486 |
+
'bakery, bakeshop, bakehouse',
|
487 |
+
'balance beam, beam',
|
488 |
+
'balloon',
|
489 |
+
'ballpoint, ballpoint pen, ballpen, Biro',
|
490 |
+
'Band Aid',
|
491 |
+
'banjo',
|
492 |
+
'bannister, banister, balustrade, balusters, handrail',
|
493 |
+
'barbell',
|
494 |
+
'barber chair',
|
495 |
+
'barbershop',
|
496 |
+
'barn',
|
497 |
+
'barometer',
|
498 |
+
'barrel, cask',
|
499 |
+
'barrow, garden cart, lawn cart, wheelbarrow',
|
500 |
+
'baseball',
|
501 |
+
'basketball',
|
502 |
+
'bassinet',
|
503 |
+
'bassoon',
|
504 |
+
'bathing cap, swimming cap',
|
505 |
+
'bath towel',
|
506 |
+
'bathtub, bathing tub, bath, tub',
|
507 |
+
'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon', # noqa: E501
|
508 |
+
'beacon, lighthouse, beacon light, pharos',
|
509 |
+
'beaker',
|
510 |
+
'bearskin, busby, shako',
|
511 |
+
'beer bottle',
|
512 |
+
'beer glass',
|
513 |
+
'bell cote, bell cot',
|
514 |
+
'bib',
|
515 |
+
'bicycle-built-for-two, tandem bicycle, tandem',
|
516 |
+
'bikini, two-piece',
|
517 |
+
'binder, ring-binder',
|
518 |
+
'binoculars, field glasses, opera glasses',
|
519 |
+
'birdhouse',
|
520 |
+
'boathouse',
|
521 |
+
'bobsled, bobsleigh, bob',
|
522 |
+
'bolo tie, bolo, bola tie, bola',
|
523 |
+
'bonnet, poke bonnet',
|
524 |
+
'bookcase',
|
525 |
+
'bookshop, bookstore, bookstall',
|
526 |
+
'bottlecap',
|
527 |
+
'bow',
|
528 |
+
'bow tie, bow-tie, bowtie',
|
529 |
+
'brass, memorial tablet, plaque',
|
530 |
+
'brassiere, bra, bandeau',
|
531 |
+
'breakwater, groin, groyne, mole, bulwark, seawall, jetty',
|
532 |
+
'breastplate, aegis, egis',
|
533 |
+
'broom',
|
534 |
+
'bucket, pail',
|
535 |
+
'buckle',
|
536 |
+
'bulletproof vest',
|
537 |
+
'bullet train, bullet',
|
538 |
+
'butcher shop, meat market',
|
539 |
+
'cab, hack, taxi, taxicab',
|
540 |
+
'caldron, cauldron',
|
541 |
+
'candle, taper, wax light',
|
542 |
+
'cannon',
|
543 |
+
'canoe',
|
544 |
+
'can opener, tin opener',
|
545 |
+
'cardigan',
|
546 |
+
'car mirror',
|
547 |
+
'carousel, carrousel, merry-go-round, roundabout, whirligig',
|
548 |
+
"carpenter's kit, tool kit",
|
549 |
+
'carton',
|
550 |
+
'car wheel',
|
551 |
+
'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM', # noqa: E501
|
552 |
+
'cassette',
|
553 |
+
'cassette player',
|
554 |
+
'castle',
|
555 |
+
'catamaran',
|
556 |
+
'CD player',
|
557 |
+
'cello, violoncello',
|
558 |
+
'cellular telephone, cellular phone, cellphone, cell, mobile phone',
|
559 |
+
'chain',
|
560 |
+
'chainlink fence',
|
561 |
+
'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour', # noqa: E501
|
562 |
+
'chain saw, chainsaw',
|
563 |
+
'chest',
|
564 |
+
'chiffonier, commode',
|
565 |
+
'chime, bell, gong',
|
566 |
+
'china cabinet, china closet',
|
567 |
+
'Christmas stocking',
|
568 |
+
'church, church building',
|
569 |
+
'cinema, movie theater, movie theatre, movie house, picture palace',
|
570 |
+
'cleaver, meat cleaver, chopper',
|
571 |
+
'cliff dwelling',
|
572 |
+
'cloak',
|
573 |
+
'clog, geta, patten, sabot',
|
574 |
+
'cocktail shaker',
|
575 |
+
'coffee mug',
|
576 |
+
'coffeepot',
|
577 |
+
'coil, spiral, volute, whorl, helix',
|
578 |
+
'combination lock',
|
579 |
+
'computer keyboard, keypad',
|
580 |
+
'confectionery, confectionary, candy store',
|
581 |
+
'container ship, containership, container vessel',
|
582 |
+
'convertible',
|
583 |
+
'corkscrew, bottle screw',
|
584 |
+
'cornet, horn, trumpet, trump',
|
585 |
+
'cowboy boot',
|
586 |
+
'cowboy hat, ten-gallon hat',
|
587 |
+
'cradle',
|
588 |
+
'crane',
|
589 |
+
'crash helmet',
|
590 |
+
'crate',
|
591 |
+
'crib, cot',
|
592 |
+
'Crock Pot',
|
593 |
+
'croquet ball',
|
594 |
+
'crutch',
|
595 |
+
'cuirass',
|
596 |
+
'dam, dike, dyke',
|
597 |
+
'desk',
|
598 |
+
'desktop computer',
|
599 |
+
'dial telephone, dial phone',
|
600 |
+
'diaper, nappy, napkin',
|
601 |
+
'digital clock',
|
602 |
+
'digital watch',
|
603 |
+
'dining table, board',
|
604 |
+
'dishrag, dishcloth',
|
605 |
+
'dishwasher, dish washer, dishwashing machine',
|
606 |
+
'disk brake, disc brake',
|
607 |
+
'dock, dockage, docking facility',
|
608 |
+
'dogsled, dog sled, dog sleigh',
|
609 |
+
'dome',
|
610 |
+
'doormat, welcome mat',
|
611 |
+
'drilling platform, offshore rig',
|
612 |
+
'drum, membranophone, tympan',
|
613 |
+
'drumstick',
|
614 |
+
'dumbbell',
|
615 |
+
'Dutch oven',
|
616 |
+
'electric fan, blower',
|
617 |
+
'electric guitar',
|
618 |
+
'electric locomotive',
|
619 |
+
'entertainment center',
|
620 |
+
'envelope',
|
621 |
+
'espresso maker',
|
622 |
+
'face powder',
|
623 |
+
'feather boa, boa',
|
624 |
+
'file, file cabinet, filing cabinet',
|
625 |
+
'fireboat',
|
626 |
+
'fire engine, fire truck',
|
627 |
+
'fire screen, fireguard',
|
628 |
+
'flagpole, flagstaff',
|
629 |
+
'flute, transverse flute',
|
630 |
+
'folding chair',
|
631 |
+
'football helmet',
|
632 |
+
'forklift',
|
633 |
+
'fountain',
|
634 |
+
'fountain pen',
|
635 |
+
'four-poster',
|
636 |
+
'freight car',
|
637 |
+
'French horn, horn',
|
638 |
+
'frying pan, frypan, skillet',
|
639 |
+
'fur coat',
|
640 |
+
'garbage truck, dustcart',
|
641 |
+
'gasmask, respirator, gas helmet',
|
642 |
+
'gas pump, gasoline pump, petrol pump, island dispenser',
|
643 |
+
'goblet',
|
644 |
+
'go-kart',
|
645 |
+
'golf ball',
|
646 |
+
'golfcart, golf cart',
|
647 |
+
'gondola',
|
648 |
+
'gong, tam-tam',
|
649 |
+
'gown',
|
650 |
+
'grand piano, grand',
|
651 |
+
'greenhouse, nursery, glasshouse',
|
652 |
+
'grille, radiator grille',
|
653 |
+
'grocery store, grocery, food market, market',
|
654 |
+
'guillotine',
|
655 |
+
'hair slide',
|
656 |
+
'hair spray',
|
657 |
+
'half track',
|
658 |
+
'hammer',
|
659 |
+
'hamper',
|
660 |
+
'hand blower, blow dryer, blow drier, hair dryer, hair drier',
|
661 |
+
'hand-held computer, hand-held microcomputer',
|
662 |
+
'handkerchief, hankie, hanky, hankey',
|
663 |
+
'hard disc, hard disk, fixed disk',
|
664 |
+
'harmonica, mouth organ, harp, mouth harp',
|
665 |
+
'harp',
|
666 |
+
'harvester, reaper',
|
667 |
+
'hatchet',
|
668 |
+
'holster',
|
669 |
+
'home theater, home theatre',
|
670 |
+
'honeycomb',
|
671 |
+
'hook, claw',
|
672 |
+
'hoopskirt, crinoline',
|
673 |
+
'horizontal bar, high bar',
|
674 |
+
'horse cart, horse-cart',
|
675 |
+
'hourglass',
|
676 |
+
'iPod',
|
677 |
+
'iron, smoothing iron',
|
678 |
+
"jack-o'-lantern",
|
679 |
+
'jean, blue jean, denim',
|
680 |
+
'jeep, landrover',
|
681 |
+
'jersey, T-shirt, tee shirt',
|
682 |
+
'jigsaw puzzle',
|
683 |
+
'jinrikisha, ricksha, rickshaw',
|
684 |
+
'joystick',
|
685 |
+
'kimono',
|
686 |
+
'knee pad',
|
687 |
+
'knot',
|
688 |
+
'lab coat, laboratory coat',
|
689 |
+
'ladle',
|
690 |
+
'lampshade, lamp shade',
|
691 |
+
'laptop, laptop computer',
|
692 |
+
'lawn mower, mower',
|
693 |
+
'lens cap, lens cover',
|
694 |
+
'letter opener, paper knife, paperknife',
|
695 |
+
'library',
|
696 |
+
'lifeboat',
|
697 |
+
'lighter, light, igniter, ignitor',
|
698 |
+
'limousine, limo',
|
699 |
+
'liner, ocean liner',
|
700 |
+
'lipstick, lip rouge',
|
701 |
+
'Loafer',
|
702 |
+
'lotion',
|
703 |
+
'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system', # noqa: E501
|
704 |
+
"loupe, jeweler's loupe",
|
705 |
+
'lumbermill, sawmill',
|
706 |
+
'magnetic compass',
|
707 |
+
'mailbag, postbag',
|
708 |
+
'mailbox, letter box',
|
709 |
+
'maillot',
|
710 |
+
'maillot, tank suit',
|
711 |
+
'manhole cover',
|
712 |
+
'maraca',
|
713 |
+
'marimba, xylophone',
|
714 |
+
'mask',
|
715 |
+
'matchstick',
|
716 |
+
'maypole',
|
717 |
+
'maze, labyrinth',
|
718 |
+
'measuring cup',
|
719 |
+
'medicine chest, medicine cabinet',
|
720 |
+
'megalith, megalithic structure',
|
721 |
+
'microphone, mike',
|
722 |
+
'microwave, microwave oven',
|
723 |
+
'military uniform',
|
724 |
+
'milk can',
|
725 |
+
'minibus',
|
726 |
+
'miniskirt, mini',
|
727 |
+
'minivan',
|
728 |
+
'missile',
|
729 |
+
'mitten',
|
730 |
+
'mixing bowl',
|
731 |
+
'mobile home, manufactured home',
|
732 |
+
'Model T',
|
733 |
+
'modem',
|
734 |
+
'monastery',
|
735 |
+
'monitor',
|
736 |
+
'moped',
|
737 |
+
'mortar',
|
738 |
+
'mortarboard',
|
739 |
+
'mosque',
|
740 |
+
'mosquito net',
|
741 |
+
'motor scooter, scooter',
|
742 |
+
'mountain bike, all-terrain bike, off-roader',
|
743 |
+
'mountain tent',
|
744 |
+
'mouse, computer mouse',
|
745 |
+
'mousetrap',
|
746 |
+
'moving van',
|
747 |
+
'muzzle',
|
748 |
+
'nail',
|
749 |
+
'neck brace',
|
750 |
+
'necklace',
|
751 |
+
'nipple',
|
752 |
+
'notebook, notebook computer',
|
753 |
+
'obelisk',
|
754 |
+
'oboe, hautboy, hautbois',
|
755 |
+
'ocarina, sweet potato',
|
756 |
+
'odometer, hodometer, mileometer, milometer',
|
757 |
+
'oil filter',
|
758 |
+
'organ, pipe organ',
|
759 |
+
'oscilloscope, scope, cathode-ray oscilloscope, CRO',
|
760 |
+
'overskirt',
|
761 |
+
'oxcart',
|
762 |
+
'oxygen mask',
|
763 |
+
'packet',
|
764 |
+
'paddle, boat paddle',
|
765 |
+
'paddlewheel, paddle wheel',
|
766 |
+
'padlock',
|
767 |
+
'paintbrush',
|
768 |
+
"pajama, pyjama, pj's, jammies",
|
769 |
+
'palace',
|
770 |
+
'panpipe, pandean pipe, syrinx',
|
771 |
+
'paper towel',
|
772 |
+
'parachute, chute',
|
773 |
+
'parallel bars, bars',
|
774 |
+
'park bench',
|
775 |
+
'parking meter',
|
776 |
+
'passenger car, coach, carriage',
|
777 |
+
'patio, terrace',
|
778 |
+
'pay-phone, pay-station',
|
779 |
+
'pedestal, plinth, footstall',
|
780 |
+
'pencil box, pencil case',
|
781 |
+
'pencil sharpener',
|
782 |
+
'perfume, essence',
|
783 |
+
'Petri dish',
|
784 |
+
'photocopier',
|
785 |
+
'pick, plectrum, plectron',
|
786 |
+
'pickelhaube',
|
787 |
+
'picket fence, paling',
|
788 |
+
'pickup, pickup truck',
|
789 |
+
'pier',
|
790 |
+
'piggy bank, penny bank',
|
791 |
+
'pill bottle',
|
792 |
+
'pillow',
|
793 |
+
'ping-pong ball',
|
794 |
+
'pinwheel',
|
795 |
+
'pirate, pirate ship',
|
796 |
+
'pitcher, ewer',
|
797 |
+
"plane, carpenter's plane, woodworking plane",
|
798 |
+
'planetarium',
|
799 |
+
'plastic bag',
|
800 |
+
'plate rack',
|
801 |
+
'plow, plough',
|
802 |
+
"plunger, plumber's helper",
|
803 |
+
'Polaroid camera, Polaroid Land camera',
|
804 |
+
'pole',
|
805 |
+
'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria', # noqa: E501
|
806 |
+
'poncho',
|
807 |
+
'pool table, billiard table, snooker table',
|
808 |
+
'pop bottle, soda bottle',
|
809 |
+
'pot, flowerpot',
|
810 |
+
"potter's wheel",
|
811 |
+
'power drill',
|
812 |
+
'prayer rug, prayer mat',
|
813 |
+
'printer',
|
814 |
+
'prison, prison house',
|
815 |
+
'projectile, missile',
|
816 |
+
'projector',
|
817 |
+
'puck, hockey puck',
|
818 |
+
'punching bag, punch bag, punching ball, punchball',
|
819 |
+
'purse',
|
820 |
+
'quill, quill pen',
|
821 |
+
'quilt, comforter, comfort, puff',
|
822 |
+
'racer, race car, racing car',
|
823 |
+
'racket, racquet',
|
824 |
+
'radiator',
|
825 |
+
'radio, wireless',
|
826 |
+
'radio telescope, radio reflector',
|
827 |
+
'rain barrel',
|
828 |
+
'recreational vehicle, RV, R.V.',
|
829 |
+
'reel',
|
830 |
+
'reflex camera',
|
831 |
+
'refrigerator, icebox',
|
832 |
+
'remote control, remote',
|
833 |
+
'restaurant, eating house, eating place, eatery',
|
834 |
+
'revolver, six-gun, six-shooter',
|
835 |
+
'rifle',
|
836 |
+
'rocking chair, rocker',
|
837 |
+
'rotisserie',
|
838 |
+
'rubber eraser, rubber, pencil eraser',
|
839 |
+
'rugby ball',
|
840 |
+
'rule, ruler',
|
841 |
+
'running shoe',
|
842 |
+
'safe',
|
843 |
+
'safety pin',
|
844 |
+
'saltshaker, salt shaker',
|
845 |
+
'sandal',
|
846 |
+
'sarong',
|
847 |
+
'sax, saxophone',
|
848 |
+
'scabbard',
|
849 |
+
'scale, weighing machine',
|
850 |
+
'school bus',
|
851 |
+
'schooner',
|
852 |
+
'scoreboard',
|
853 |
+
'screen, CRT screen',
|
854 |
+
'screw',
|
855 |
+
'screwdriver',
|
856 |
+
'seat belt, seatbelt',
|
857 |
+
'sewing machine',
|
858 |
+
'shield, buckler',
|
859 |
+
'shoe shop, shoe-shop, shoe store',
|
860 |
+
'shoji',
|
861 |
+
'shopping basket',
|
862 |
+
'shopping cart',
|
863 |
+
'shovel',
|
864 |
+
'shower cap',
|
865 |
+
'shower curtain',
|
866 |
+
'ski',
|
867 |
+
'ski mask',
|
868 |
+
'sleeping bag',
|
869 |
+
'slide rule, slipstick',
|
870 |
+
'sliding door',
|
871 |
+
'slot, one-armed bandit',
|
872 |
+
'snorkel',
|
873 |
+
'snowmobile',
|
874 |
+
'snowplow, snowplough',
|
875 |
+
'soap dispenser',
|
876 |
+
'soccer ball',
|
877 |
+
'sock',
|
878 |
+
'solar dish, solar collector, solar furnace',
|
879 |
+
'sombrero',
|
880 |
+
'soup bowl',
|
881 |
+
'space bar',
|
882 |
+
'space heater',
|
883 |
+
'space shuttle',
|
884 |
+
'spatula',
|
885 |
+
'speedboat',
|
886 |
+
"spider web, spider's web",
|
887 |
+
'spindle',
|
888 |
+
'sports car, sport car',
|
889 |
+
'spotlight, spot',
|
890 |
+
'stage',
|
891 |
+
'steam locomotive',
|
892 |
+
'steel arch bridge',
|
893 |
+
'steel drum',
|
894 |
+
'stethoscope',
|
895 |
+
'stole',
|
896 |
+
'stone wall',
|
897 |
+
'stopwatch, stop watch',
|
898 |
+
'stove',
|
899 |
+
'strainer',
|
900 |
+
'streetcar, tram, tramcar, trolley, trolley car',
|
901 |
+
'stretcher',
|
902 |
+
'studio couch, day bed',
|
903 |
+
'stupa, tope',
|
904 |
+
'submarine, pigboat, sub, U-boat',
|
905 |
+
'suit, suit of clothes',
|
906 |
+
'sundial',
|
907 |
+
'sunglass',
|
908 |
+
'sunglasses, dark glasses, shades',
|
909 |
+
'sunscreen, sunblock, sun blocker',
|
910 |
+
'suspension bridge',
|
911 |
+
'swab, swob, mop',
|
912 |
+
'sweatshirt',
|
913 |
+
'swimming trunks, bathing trunks',
|
914 |
+
'swing',
|
915 |
+
'switch, electric switch, electrical switch',
|
916 |
+
'syringe',
|
917 |
+
'table lamp',
|
918 |
+
'tank, army tank, armored combat vehicle, armoured combat vehicle',
|
919 |
+
'tape player',
|
920 |
+
'teapot',
|
921 |
+
'teddy, teddy bear',
|
922 |
+
'television, television system',
|
923 |
+
'tennis ball',
|
924 |
+
'thatch, thatched roof',
|
925 |
+
'theater curtain, theatre curtain',
|
926 |
+
'thimble',
|
927 |
+
'thresher, thrasher, threshing machine',
|
928 |
+
'throne',
|
929 |
+
'tile roof',
|
930 |
+
'toaster',
|
931 |
+
'tobacco shop, tobacconist shop, tobacconist',
|
932 |
+
'toilet seat',
|
933 |
+
'torch',
|
934 |
+
'totem pole',
|
935 |
+
'tow truck, tow car, wrecker',
|
936 |
+
'toyshop',
|
937 |
+
'tractor',
|
938 |
+
'trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi', # noqa: E501
|
939 |
+
'tray',
|
940 |
+
'trench coat',
|
941 |
+
'tricycle, trike, velocipede',
|
942 |
+
'trimaran',
|
943 |
+
'tripod',
|
944 |
+
'triumphal arch',
|
945 |
+
'trolleybus, trolley coach, trackless trolley',
|
946 |
+
'trombone',
|
947 |
+
'tub, vat',
|
948 |
+
'turnstile',
|
949 |
+
'typewriter keyboard',
|
950 |
+
'umbrella',
|
951 |
+
'unicycle, monocycle',
|
952 |
+
'upright, upright piano',
|
953 |
+
'vacuum, vacuum cleaner',
|
954 |
+
'vase',
|
955 |
+
'vault',
|
956 |
+
'velvet',
|
957 |
+
'vending machine',
|
958 |
+
'vestment',
|
959 |
+
'viaduct',
|
960 |
+
'violin, fiddle',
|
961 |
+
'volleyball',
|
962 |
+
'waffle iron',
|
963 |
+
'wall clock',
|
964 |
+
'wallet, billfold, notecase, pocketbook',
|
965 |
+
'wardrobe, closet, press',
|
966 |
+
'warplane, military plane',
|
967 |
+
'washbasin, handbasin, washbowl, lavabo, wash-hand basin',
|
968 |
+
'washer, automatic washer, washing machine',
|
969 |
+
'water bottle',
|
970 |
+
'water jug',
|
971 |
+
'water tower',
|
972 |
+
'whiskey jug',
|
973 |
+
'whistle',
|
974 |
+
'wig',
|
975 |
+
'window screen',
|
976 |
+
'window shade',
|
977 |
+
'Windsor tie',
|
978 |
+
'wine bottle',
|
979 |
+
'wing',
|
980 |
+
'wok',
|
981 |
+
'wooden spoon',
|
982 |
+
'wool, woolen, woollen',
|
983 |
+
'worm fence, snake fence, snake-rail fence, Virginia fence',
|
984 |
+
'wreck',
|
985 |
+
'yawl',
|
986 |
+
'yurt',
|
987 |
+
'web site, website, internet site, site',
|
988 |
+
'comic book',
|
989 |
+
'crossword puzzle, crossword',
|
990 |
+
'street sign',
|
991 |
+
'traffic light, traffic signal, stoplight',
|
992 |
+
'book jacket, dust cover, dust jacket, dust wrapper',
|
993 |
+
'menu',
|
994 |
+
'plate',
|
995 |
+
'guacamole',
|
996 |
+
'consomme',
|
997 |
+
'hot pot, hotpot',
|
998 |
+
'trifle',
|
999 |
+
'ice cream, icecream',
|
1000 |
+
'ice lolly, lolly, lollipop, popsicle',
|
1001 |
+
'French loaf',
|
1002 |
+
'bagel, beigel',
|
1003 |
+
'pretzel',
|
1004 |
+
'cheeseburger',
|
1005 |
+
'hotdog, hot dog, red hot',
|
1006 |
+
'mashed potato',
|
1007 |
+
'head cabbage',
|
1008 |
+
'broccoli',
|
1009 |
+
'cauliflower',
|
1010 |
+
'zucchini, courgette',
|
1011 |
+
'spaghetti squash',
|
1012 |
+
'acorn squash',
|
1013 |
+
'butternut squash',
|
1014 |
+
'cucumber, cuke',
|
1015 |
+
'artichoke, globe artichoke',
|
1016 |
+
'bell pepper',
|
1017 |
+
'cardoon',
|
1018 |
+
'mushroom',
|
1019 |
+
'Granny Smith',
|
1020 |
+
'strawberry',
|
1021 |
+
'orange',
|
1022 |
+
'lemon',
|
1023 |
+
'fig',
|
1024 |
+
'pineapple, ananas',
|
1025 |
+
'banana',
|
1026 |
+
'jackfruit, jak, jack',
|
1027 |
+
'custard apple',
|
1028 |
+
'pomegranate',
|
1029 |
+
'hay',
|
1030 |
+
'carbonara',
|
1031 |
+
'chocolate sauce, chocolate syrup',
|
1032 |
+
'dough',
|
1033 |
+
'meat loaf, meatloaf',
|
1034 |
+
'pizza, pizza pie',
|
1035 |
+
'potpie',
|
1036 |
+
'burrito',
|
1037 |
+
'red wine',
|
1038 |
+
'espresso',
|
1039 |
+
'cup',
|
1040 |
+
'eggnog',
|
1041 |
+
'alp',
|
1042 |
+
'bubble',
|
1043 |
+
'cliff, drop, drop-off',
|
1044 |
+
'coral reef',
|
1045 |
+
'geyser',
|
1046 |
+
'lakeside, lakeshore',
|
1047 |
+
'promontory, headland, head, foreland',
|
1048 |
+
'sandbar, sand bar',
|
1049 |
+
'seashore, coast, seacoast, sea-coast',
|
1050 |
+
'valley, vale',
|
1051 |
+
'volcano',
|
1052 |
+
'ballplayer, baseball player',
|
1053 |
+
'groom, bridegroom',
|
1054 |
+
'scuba diver',
|
1055 |
+
'rapeseed',
|
1056 |
+
'daisy',
|
1057 |
+
"yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum", # noqa: E501
|
1058 |
+
'corn',
|
1059 |
+
'acorn',
|
1060 |
+
'hip, rose hip, rosehip',
|
1061 |
+
'buckeye, horse chestnut, conker',
|
1062 |
+
'coral fungus',
|
1063 |
+
'agaric',
|
1064 |
+
'gyromitra',
|
1065 |
+
'stinkhorn, carrion fungus',
|
1066 |
+
'earthstar',
|
1067 |
+
'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa', # noqa: E501
|
1068 |
+
'bolete',
|
1069 |
+
'ear, spike, capitulum',
|
1070 |
+
'toilet tissue, toilet paper, bathroom tissue')
|
1071 |
+
|
1072 |
+
CIFAR10_CATEGORIES = ('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog',
|
1073 |
+
'frog', 'horse', 'ship', 'truck')
|
1074 |
+
|
1075 |
+
CIFAR100_CATEGORIES = (
|
1076 |
+
'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
|
1077 |
+
'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
|
1078 |
+
'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
|
1079 |
+
'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
|
1080 |
+
'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
|
1081 |
+
'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
|
1082 |
+
'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain',
|
1083 |
+
'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree',
|
1084 |
+
'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy',
|
1085 |
+
'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket',
|
1086 |
+
'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail',
|
1087 |
+
'snake', 'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper',
|
1088 |
+
'table', 'tank', 'telephone', 'television', 'tiger', 'tractor', 'train',
|
1089 |
+
'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf',
|
1090 |
+
'woman', 'worm')
|
1091 |
+
|
1092 |
+
MNIST_CATEGORITES = ('0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
|
1093 |
+
'5 - five', '6 - six', '7 - seven', '8 - eight',
|
1094 |
+
'9 - nine')
|
1095 |
+
|
1096 |
+
FASHIONMNIST_CATEGORITES = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress',
|
1097 |
+
'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag',
|
1098 |
+
'Ankle boot')
|
1099 |
+
|
1100 |
+
PLACES205_CATEGORIES = (
|
1101 |
+
'abbey', 'airport_terminal', 'alley', 'amphitheater', 'amusement_park',
|
1102 |
+
'aquarium', 'aqueduct', 'arch', 'art_gallery', 'art_studio',
|
1103 |
+
'assembly_line', 'attic', 'auditorium', 'apartment_building/outdoor',
|
1104 |
+
'badlands', 'ballroom', 'bamboo_forest', 'banquet_hall', 'bar',
|
1105 |
+
'baseball_field', 'basement', 'basilica', 'bayou', 'beauty_salon',
|
1106 |
+
'bedroom', 'boardwalk', 'boat_deck', 'bookstore', 'botanical_garden',
|
1107 |
+
'bowling_alley', 'boxing_ring', 'bridge', 'building_facade',
|
1108 |
+
'bus_interior', 'butchers_shop', 'butte', 'bakery/shop', 'cafeteria',
|
1109 |
+
'campsite', 'candy_store', 'canyon', 'castle', 'cemetery', 'chalet',
|
1110 |
+
'classroom', 'closet', 'clothing_store', 'coast', 'cockpit', 'coffee_shop',
|
1111 |
+
'conference_center', 'conference_room', 'construction_site', 'corn_field',
|
1112 |
+
'corridor', 'cottage_garden', 'courthouse', 'courtyard', 'creek',
|
1113 |
+
'crevasse', 'crosswalk', 'cathedral/outdoor', 'church/outdoor', 'dam',
|
1114 |
+
'dining_room', 'dock', 'dorm_room', 'driveway', 'desert/sand',
|
1115 |
+
'desert/vegetation', 'dinette/home', 'doorway/outdoor', 'engine_room',
|
1116 |
+
'excavation', 'fairway', 'fire_escape', 'fire_station', 'food_court',
|
1117 |
+
'forest_path', 'forest_road', 'formal_garden', 'fountain',
|
1118 |
+
'field/cultivated', 'field/wild', 'galley', 'game_room', 'garbage_dump',
|
1119 |
+
'gas_station', 'gift_shop', 'golf_course', 'harbor', 'herb_garden',
|
1120 |
+
'highway', 'home_office', 'hospital', 'hospital_room', 'hot_spring',
|
1121 |
+
'hotel_room', 'hotel/outdoor', 'ice_cream_parlor', 'iceberg', 'igloo',
|
1122 |
+
'islet', 'ice_skating_rink/outdoor', 'inn/outdoor', 'jail_cell', 'kasbah',
|
1123 |
+
'kindergarden_classroom', 'kitchen', 'kitchenette', 'laundromat',
|
1124 |
+
'lighthouse', 'living_room', 'lobby', 'locker_room', 'mansion', 'marsh',
|
1125 |
+
'martial_arts_gym', 'mausoleum', 'medina', 'motel', 'mountain',
|
1126 |
+
'mountain_snowy', 'music_studio', 'market/outdoor', 'monastery/outdoor',
|
1127 |
+
'museum/indoor', 'nursery', 'ocean', 'office', 'office_building',
|
1128 |
+
'orchard', 'pagoda', 'palace', 'pantry', 'parking_lot', 'parlor',
|
1129 |
+
'pasture', 'patio', 'pavilion', 'phone_booth', 'picnic_area', 'playground',
|
1130 |
+
'plaza', 'pond', 'pulpit', 'racecourse', 'raft', 'railroad_track',
|
1131 |
+
'rainforest', 'reception', 'residential_neighborhood', 'restaurant',
|
1132 |
+
'restaurant_kitchen', 'restaurant_patio', 'rice_paddy', 'river',
|
1133 |
+
'rock_arch', 'rope_bridge', 'ruin', 'runway', 'sandbar', 'schoolhouse',
|
1134 |
+
'sea_cliff', 'shed', 'shoe_shop', 'shopfront', 'shower', 'ski_resort',
|
1135 |
+
'ski_slope', 'sky', 'skyscraper', 'slum', 'snowfield', 'staircase',
|
1136 |
+
'supermarket', 'swamp', 'stadium/baseball', 'stadium/football',
|
1137 |
+
'stage/indoor', 'subway_station/platform', 'swimming_pool/outdoor',
|
1138 |
+
'television_studio', 'topiary_garden', 'tower', 'train_railway',
|
1139 |
+
'tree_farm', 'trench', 'temple/east_asia', 'temple/south_asia',
|
1140 |
+
'track/outdoor', 'train_station/platform', 'underwater/coral_reef',
|
1141 |
+
'valley', 'vegetable_garden', 'veranda', 'viaduct', 'volcano',
|
1142 |
+
'waiting_room', 'water_tower', 'watering_hole', 'wheat_field', 'wind_farm',
|
1143 |
+
'windmill', 'yard')
|
1144 |
+
|
1145 |
+
OxfordIIITPet_CATEGORIES = (
|
1146 |
+
'Abyssinian', 'american_bulldog', 'american_pit_bull_terrier',
|
1147 |
+
'basset_hound', 'beagle', 'Bengal', 'Birman', 'Bombay', 'boxer',
|
1148 |
+
'British_Shorthair', 'chihuahua', 'Egyptian_Mau', 'english_cocker_spaniel',
|
1149 |
+
'english_setter', 'german_shorthaired', 'great_pyrenees', 'havanese',
|
1150 |
+
'japanese_chin', 'keeshond', 'leonberger', 'Maine_Coon',
|
1151 |
+
'miniature_pinscher', 'newfoundland', 'Persian', 'pomeranian', 'pug',
|
1152 |
+
'Ragdoll', 'Russian_Blue', 'saint_bernard', 'samoyed', 'scottish_terrier',
|
1153 |
+
'shiba_inu', 'Siamese', 'Sphynx', 'staffordshire_bull_terrier',
|
1154 |
+
'wheaten_terrier', 'yorkshire_terrier')
|
1155 |
+
|
1156 |
+
DTD_CATEGORIES = ('banded', 'blotchy', 'braided', 'bubbly', 'bumpy',
|
1157 |
+
'chequered', 'cobwebbed', 'cracked', 'crosshatched',
|
1158 |
+
'crystalline', 'dotted', 'fibrous', 'flecked', 'freckled',
|
1159 |
+
'frilly', 'gauzy', 'grid', 'grooved', 'honeycombed',
|
1160 |
+
'interlaced', 'knitted', 'lacelike', 'lined', 'marbled',
|
1161 |
+
'matted', 'meshed', 'paisley', 'perforated', 'pitted',
|
1162 |
+
'pleated', 'polka-dotted', 'porous', 'potholed', 'scaly',
|
1163 |
+
'smeared', 'spiralled', 'sprinkled', 'stained', 'stratified',
|
1164 |
+
'striped', 'studded', 'swirly', 'veined', 'waffled', 'woven',
|
1165 |
+
'wrinkled', 'zigzagged')
|
1166 |
+
|
1167 |
+
FGVCAIRCRAFT_CATEGORIES = (
|
1168 |
+
'707-320', '727-200', '737-200', '737-300', '737-400', '737-500',
|
1169 |
+
'737-600', '737-700', '737-800', '737-900', '747-100', '747-200',
|
1170 |
+
'747-300', '747-400', '757-200', '757-300', '767-200', '767-300',
|
1171 |
+
'767-400', '777-200', '777-300', 'A300B4', 'A310', 'A318', 'A319', 'A320',
|
1172 |
+
'A321', 'A330-200', 'A330-300', 'A340-200', 'A340-300', 'A340-500',
|
1173 |
+
'A340-600', 'A380', 'ATR-42', 'ATR-72', 'An-12', 'BAE 146-200',
|
1174 |
+
'BAE 146-300', 'BAE-125', 'Beechcraft 1900', 'Boeing 717', 'C-130', 'C-47',
|
1175 |
+
'CRJ-200', 'CRJ-700', 'CRJ-900', 'Cessna 172', 'Cessna 208', 'Cessna 525',
|
1176 |
+
'Cessna 560', 'Challenger 600', 'DC-10', 'DC-3', 'DC-6', 'DC-8', 'DC-9-30',
|
1177 |
+
'DH-82', 'DHC-1', 'DHC-6', 'DHC-8-100', 'DHC-8-300', 'DR-400',
|
1178 |
+
'Dornier 328', 'E-170', 'E-190', 'E-195', 'EMB-120', 'ERJ 135', 'ERJ 145',
|
1179 |
+
'Embraer Legacy 600', 'Eurofighter Typhoon', 'F-16A/B', 'F/A-18',
|
1180 |
+
'Falcon 2000', 'Falcon 900', 'Fokker 100', 'Fokker 50', 'Fokker 70',
|
1181 |
+
'Global Express', 'Gulfstream IV', 'Gulfstream V', 'Hawk T1', 'Il-76',
|
1182 |
+
'L-1011', 'MD-11', 'MD-80', 'MD-87', 'MD-90', 'Metroliner', 'Model B200',
|
1183 |
+
'PA-28', 'SR-20', 'Saab 2000', 'Saab 340', 'Spitfire', 'Tornado', 'Tu-134',
|
1184 |
+
'Tu-154', 'Yak-42')
|
1185 |
+
|
1186 |
+
STANFORDCARS_CATEGORIES = (
|
1187 |
+
'AM General Hummer SUV 2000', 'Acura RL Sedan 2012', 'Acura TL Sedan 2012',
|
1188 |
+
'Acura TL Type-S 2008', 'Acura TSX Sedan 2012',
|
1189 |
+
'Acura Integra Type R 2001', 'Acura ZDX Hatchback 2012',
|
1190 |
+
'Aston Martin V8 Vantage Convertible 2012',
|
1191 |
+
'Aston Martin V8 Vantage Coupe 2012',
|
1192 |
+
'Aston Martin Virage Convertible 2012', 'Aston Martin Virage Coupe 2012',
|
1193 |
+
'Audi RS 4 Convertible 2008', 'Audi A5 Coupe 2012', 'Audi TTS Coupe 2012',
|
1194 |
+
'Audi R8 Coupe 2012', 'Audi V8 Sedan 1994', 'Audi 100 Sedan 1994',
|
1195 |
+
'Audi 100 Wagon 1994', 'Audi TT Hatchback 2011', 'Audi S6 Sedan 2011',
|
1196 |
+
'Audi S5 Convertible 2012', 'Audi S5 Coupe 2012', 'Audi S4 Sedan 2012',
|
1197 |
+
'Audi S4 Sedan 2007', 'Audi TT RS Coupe 2012',
|
1198 |
+
'BMW ActiveHybrid 5 Sedan 2012', 'BMW 1 Series Convertible 2012',
|
1199 |
+
'BMW 1 Series Coupe 2012', 'BMW 3 Series Sedan 2012',
|
1200 |
+
'BMW 3 Series Wagon 2012', 'BMW 6 Series Convertible 2007',
|
1201 |
+
'BMW X5 SUV 2007', 'BMW X6 SUV 2012', 'BMW M3 Coupe 2012',
|
1202 |
+
'BMW M5 Sedan 2010', 'BMW M6 Convertible 2010', 'BMW X3 SUV 2012',
|
1203 |
+
'BMW Z4 Convertible 2012',
|
1204 |
+
'Bentley Continental Supersports Conv. Convertible 2012',
|
1205 |
+
'Bentley Arnage Sedan 2009', 'Bentley Mulsanne Sedan 2011',
|
1206 |
+
'Bentley Continental GT Coupe 2012', 'Bentley Continental GT Coupe 2007',
|
1207 |
+
'Bentley Continental Flying Spur Sedan 2007',
|
1208 |
+
'Bugatti Veyron 16.4 Convertible 2009', 'Bugatti Veyron 16.4 Coupe 2009',
|
1209 |
+
'Buick Regal GS 2012', 'Buick Rainier SUV 2007', 'Buick Verano Sedan 2012',
|
1210 |
+
'Buick Enclave SUV 2012', 'Cadillac CTS-V Sedan 2012',
|
1211 |
+
'Cadillac SRX SUV 2012', 'Cadillac Escalade EXT Crew Cab 2007',
|
1212 |
+
'Chevrolet Silverado 1500 Hybrid Crew Cab 2012',
|
1213 |
+
'Chevrolet Corvette Convertible 2012', 'Chevrolet Corvette ZR1 2012',
|
1214 |
+
'Chevrolet Corvette Ron Fellows Edition Z06 2007',
|
1215 |
+
'Chevrolet Traverse SUV 2012', 'Chevrolet Camaro Convertible 2012',
|
1216 |
+
'Chevrolet HHR SS 2010', 'Chevrolet Impala Sedan 2007',
|
1217 |
+
'Chevrolet Tahoe Hybrid SUV 2012', 'Chevrolet Sonic Sedan 2012',
|
1218 |
+
'Chevrolet Express Cargo Van 2007', 'Chevrolet Avalanche Crew Cab 2012',
|
1219 |
+
'Chevrolet Cobalt SS 2010', 'Chevrolet Malibu Hybrid Sedan 2010',
|
1220 |
+
'Chevrolet TrailBlazer SS 2009',
|
1221 |
+
'Chevrolet Silverado 2500HD Regular Cab 2012',
|
1222 |
+
'Chevrolet Silverado 1500 Classic Extended Cab 2007',
|
1223 |
+
'Chevrolet Express Van 2007', 'Chevrolet Monte Carlo Coupe 2007',
|
1224 |
+
'Chevrolet Malibu Sedan 2007',
|
1225 |
+
'Chevrolet Silverado 1500 Extended Cab 2012',
|
1226 |
+
'Chevrolet Silverado 1500 Regular Cab 2012', 'Chrysler Aspen SUV 2009',
|
1227 |
+
'Chrysler Sebring Convertible 2010',
|
1228 |
+
'Chrysler Town and Country Minivan 2012', 'Chrysler 300 SRT-8 2010',
|
1229 |
+
'Chrysler Crossfire Convertible 2008',
|
1230 |
+
'Chrysler PT Cruiser Convertible 2008', 'Daewoo Nubira Wagon 2002',
|
1231 |
+
'Dodge Caliber Wagon 2012', 'Dodge Caliber Wagon 2007',
|
1232 |
+
'Dodge Caravan Minivan 1997', 'Dodge Ram Pickup 3500 Crew Cab 2010',
|
1233 |
+
'Dodge Ram Pickup 3500 Quad Cab 2009', 'Dodge Sprinter Cargo Van 2009',
|
1234 |
+
'Dodge Journey SUV 2012', 'Dodge Dakota Crew Cab 2010',
|
1235 |
+
'Dodge Dakota Club Cab 2007', 'Dodge Magnum Wagon 2008',
|
1236 |
+
'Dodge Challenger SRT8 2011', 'Dodge Durango SUV 2012',
|
1237 |
+
'Dodge Durango SUV 2007', 'Dodge Charger Sedan 2012',
|
1238 |
+
'Dodge Charger SRT-8 2009', 'Eagle Talon Hatchback 1998',
|
1239 |
+
'FIAT 500 Abarth 2012', 'FIAT 500 Convertible 2012',
|
1240 |
+
'Ferrari FF Coupe 2012', 'Ferrari California Convertible 2012',
|
1241 |
+
'Ferrari 458 Italia Convertible 2012', 'Ferrari 458 Italia Coupe 2012',
|
1242 |
+
'Fisker Karma Sedan 2012', 'Ford F-450 Super Duty Crew Cab 2012',
|
1243 |
+
'Ford Mustang Convertible 2007', 'Ford Freestar Minivan 2007',
|
1244 |
+
'Ford Expedition EL SUV 2009', 'Ford Edge SUV 2012',
|
1245 |
+
'Ford Ranger SuperCab 2011', 'Ford GT Coupe 2006',
|
1246 |
+
'Ford F-150 Regular Cab 2012', 'Ford F-150 Regular Cab 2007',
|
1247 |
+
'Ford Focus Sedan 2007', 'Ford E-Series Wagon Van 2012',
|
1248 |
+
'Ford Fiesta Sedan 2012', 'GMC Terrain SUV 2012', 'GMC Savana Van 2012',
|
1249 |
+
'GMC Yukon Hybrid SUV 2012', 'GMC Acadia SUV 2012',
|
1250 |
+
'GMC Canyon Extended Cab 2012', 'Geo Metro Convertible 1993',
|
1251 |
+
'HUMMER H3T Crew Cab 2010', 'HUMMER H2 SUT Crew Cab 2009',
|
1252 |
+
'Honda Odyssey Minivan 2012', 'Honda Odyssey Minivan 2007',
|
1253 |
+
'Honda Accord Coupe 2012', 'Honda Accord Sedan 2012',
|
1254 |
+
'Hyundai Veloster Hatchback 2012', 'Hyundai Santa Fe SUV 2012',
|
1255 |
+
'Hyundai Tucson SUV 2012', 'Hyundai Veracruz SUV 2012',
|
1256 |
+
'Hyundai Sonata Hybrid Sedan 2012', 'Hyundai Elantra Sedan 2007',
|
1257 |
+
'Hyundai Accent Sedan 2012', 'Hyundai Genesis Sedan 2012',
|
1258 |
+
'Hyundai Sonata Sedan 2012', 'Hyundai Elantra Touring Hatchback 2012',
|
1259 |
+
'Hyundai Azera Sedan 2012', 'Infiniti G Coupe IPL 2012',
|
1260 |
+
'Infiniti QX56 SUV 2011', 'Isuzu Ascender SUV 2008', 'Jaguar XK XKR 2012',
|
1261 |
+
'Jeep Patriot SUV 2012', 'Jeep Wrangler SUV 2012', 'Jeep Liberty SUV 2012',
|
1262 |
+
'Jeep Grand Cherokee SUV 2012', 'Jeep Compass SUV 2012',
|
1263 |
+
'Lamborghini Reventon Coupe 2008', 'Lamborghini Aventador Coupe 2012',
|
1264 |
+
'Lamborghini Gallardo LP 570-4 Superleggera 2012',
|
1265 |
+
'Lamborghini Diablo Coupe 2001', 'Land Rover Range Rover SUV 2012',
|
1266 |
+
'Land Rover LR2 SUV 2012', 'Lincoln Town Car Sedan 2011',
|
1267 |
+
'MINI Cooper Roadster Convertible 2012',
|
1268 |
+
'Maybach Landaulet Convertible 2012', 'Mazda Tribute SUV 2011',
|
1269 |
+
'McLaren MP4-12C Coupe 2012', 'Mercedes-Benz 300-Class Convertible 1993',
|
1270 |
+
'Mercedes-Benz C-Class Sedan 2012', 'Mercedes-Benz SL-Class Coupe 2009',
|
1271 |
+
'Mercedes-Benz E-Class Sedan 2012', 'Mercedes-Benz S-Class Sedan 2012',
|
1272 |
+
'Mercedes-Benz Sprinter Van 2012', 'Mitsubishi Lancer Sedan 2012',
|
1273 |
+
'Nissan Leaf Hatchback 2012', 'Nissan NV Passenger Van 2012',
|
1274 |
+
'Nissan Juke Hatchback 2012', 'Nissan 240SX Coupe 1998',
|
1275 |
+
'Plymouth Neon Coupe 1999', 'Porsche Panamera Sedan 2012',
|
1276 |
+
'Ram C/V Cargo Van Minivan 2012',
|
1277 |
+
'Rolls-Royce Phantom Drophead Coupe Convertible 2012',
|
1278 |
+
'Rolls-Royce Ghost Sedan 2012', 'Rolls-Royce Phantom Sedan 2012',
|
1279 |
+
'Scion xD Hatchback 2012', 'Spyker C8 Convertible 2009',
|
1280 |
+
'Spyker C8 Coupe 2009', 'Suzuki Aerio Sedan 2007',
|
1281 |
+
'Suzuki Kizashi Sedan 2012', 'Suzuki SX4 Hatchback 2012',
|
1282 |
+
'Suzuki SX4 Sedan 2012', 'Tesla Model S Sedan 2012',
|
1283 |
+
'Toyota Sequoia SUV 2012', 'Toyota Camry Sedan 2012',
|
1284 |
+
'Toyota Corolla Sedan 2012', 'Toyota 4Runner SUV 2012',
|
1285 |
+
'Volkswagen Golf Hatchback 2012', 'Volkswagen Golf Hatchback 1991',
|
1286 |
+
'Volkswagen Beetle Hatchback 2012', 'Volvo C30 Hatchback 2012',
|
1287 |
+
'Volvo 240 Sedan 1993', 'Volvo XC90 SUV 2007',
|
1288 |
+
'smart fortwo Convertible 2012')
|
1289 |
+
|
1290 |
+
SUN397_CATEGORIES = (
|
1291 |
+
'abbey', 'airplane_cabin', 'airport_terminal', 'alley', 'amphitheater',
|
1292 |
+
'amusement_arcade', 'amusement_park', 'anechoic_chamber',
|
1293 |
+
'apartment_building_outdoor', 'apse_indoor', 'aquarium', 'aqueduct',
|
1294 |
+
'arch', 'archive', 'arrival_gate_outdoor', 'art_gallery', 'art_school',
|
1295 |
+
'art_studio', 'assembly_line', 'athletic_field_outdoor', 'atrium_public',
|
1296 |
+
'attic', 'auditorium', 'auto_factory', 'badlands',
|
1297 |
+
'badminton_court_indoor', 'baggage_claim', 'bakery_shop',
|
1298 |
+
'balcony_exterior', 'balcony_interior', 'ball_pit', 'ballroom',
|
1299 |
+
'bamboo_forest', 'banquet_hall', 'bar', 'barn', 'barndoor',
|
1300 |
+
'baseball_field', 'basement', 'basilica', 'basketball_court_outdoor',
|
1301 |
+
'bathroom', 'batters_box', 'bayou', 'bazaar_indoor', 'bazaar_outdoor',
|
1302 |
+
'beach', 'beauty_salon', 'bedroom', 'berth', 'biology_laboratory',
|
1303 |
+
'bistro_indoor', 'boardwalk', 'boat_deck', 'boathouse', 'bookstore',
|
1304 |
+
'booth_indoor', 'botanical_garden', 'bow_window_indoor',
|
1305 |
+
'bow_window_outdoor', 'bowling_alley', 'boxing_ring', 'brewery_indoor',
|
1306 |
+
'bridge', 'building_facade', 'bullring', 'burial_chamber', 'bus_interior',
|
1307 |
+
'butchers_shop', 'butte', 'cabin_outdoor', 'cafeteria', 'campsite',
|
1308 |
+
'campus', 'canal_natural', 'canal_urban', 'candy_store', 'canyon',
|
1309 |
+
'car_interior_backseat', 'car_interior_frontseat', 'carrousel',
|
1310 |
+
'casino_indoor', 'castle', 'catacomb', 'cathedral_indoor',
|
1311 |
+
'cathedral_outdoor', 'cavern_indoor', 'cemetery', 'chalet',
|
1312 |
+
'cheese_factory', 'chemistry_lab', 'chicken_coop_indoor',
|
1313 |
+
'chicken_coop_outdoor', 'childs_room', 'church_indoor', 'church_outdoor',
|
1314 |
+
'classroom', 'clean_room', 'cliff', 'cloister_indoor', 'closet',
|
1315 |
+
'clothing_store', 'coast', 'cockpit', 'coffee_shop', 'computer_room',
|
1316 |
+
'conference_center', 'conference_room', 'construction_site',
|
1317 |
+
'control_room', 'control_tower_outdoor', 'corn_field', 'corral',
|
1318 |
+
'corridor', 'cottage_garden', 'courthouse', 'courtroom', 'courtyard',
|
1319 |
+
'covered_bridge_exterior', 'creek', 'crevasse', 'crosswalk',
|
1320 |
+
'cubicle_office', 'dam', 'delicatessen', 'dentists_office', 'desert_sand',
|
1321 |
+
'desert_vegetation', 'diner_indoor', 'diner_outdoor', 'dinette_home',
|
1322 |
+
'dinette_vehicle', 'dining_car', 'dining_room', 'discotheque', 'dock',
|
1323 |
+
'doorway_outdoor', 'dorm_room', 'driveway', 'driving_range_outdoor',
|
1324 |
+
'drugstore', 'electrical_substation', 'elevator_door', 'elevator_interior',
|
1325 |
+
'elevator_shaft', 'engine_room', 'escalator_indoor', 'excavation',
|
1326 |
+
'factory_indoor', 'fairway', 'fastfood_restaurant', 'field_cultivated',
|
1327 |
+
'field_wild', 'fire_escape', 'fire_station', 'firing_range_indoor',
|
1328 |
+
'fishpond', 'florist_shop_indoor', 'food_court', 'forest_broadleaf',
|
1329 |
+
'forest_needleleaf', 'forest_path', 'forest_road', 'formal_garden',
|
1330 |
+
'fountain', 'galley', 'game_room', 'garage_indoor', 'garbage_dump',
|
1331 |
+
'gas_station', 'gazebo_exterior', 'general_store_indoor',
|
1332 |
+
'general_store_outdoor', 'gift_shop', 'golf_course', 'greenhouse_indoor',
|
1333 |
+
'greenhouse_outdoor', 'gymnasium_indoor', 'hangar_indoor',
|
1334 |
+
'hangar_outdoor', 'harbor', 'hayfield', 'heliport', 'herb_garden',
|
1335 |
+
'highway', 'hill', 'home_office', 'hospital', 'hospital_room',
|
1336 |
+
'hot_spring', 'hot_tub_outdoor', 'hotel_outdoor', 'hotel_room', 'house',
|
1337 |
+
'hunting_lodge_outdoor', 'ice_cream_parlor', 'ice_floe', 'ice_shelf',
|
1338 |
+
'ice_skating_rink_indoor', 'ice_skating_rink_outdoor', 'iceberg', 'igloo',
|
1339 |
+
'industrial_area', 'inn_outdoor', 'islet', 'jacuzzi_indoor', 'jail_indoor',
|
1340 |
+
'jail_cell', 'jewelry_shop', 'kasbah', 'kennel_indoor', 'kennel_outdoor',
|
1341 |
+
'kindergarden_classroom', 'kitchen', 'kitchenette', 'labyrinth_outdoor',
|
1342 |
+
'lake_natural', 'landfill', 'landing_deck', 'laundromat', 'lecture_room',
|
1343 |
+
'library_indoor', 'library_outdoor', 'lido_deck_outdoor', 'lift_bridge',
|
1344 |
+
'lighthouse', 'limousine_interior', 'living_room', 'lobby', 'lock_chamber',
|
1345 |
+
'locker_room', 'mansion', 'manufactured_home', 'market_indoor',
|
1346 |
+
'market_outdoor', 'marsh', 'martial_arts_gym', 'mausoleum', 'medina',
|
1347 |
+
'moat_water', 'monastery_outdoor', 'mosque_indoor', 'mosque_outdoor',
|
1348 |
+
'motel', 'mountain', 'mountain_snowy', 'movie_theater_indoor',
|
1349 |
+
'museum_indoor', 'music_store', 'music_studio',
|
1350 |
+
'nuclear_power_plant_outdoor', 'nursery', 'oast_house',
|
1351 |
+
'observatory_outdoor', 'ocean', 'office', 'office_building',
|
1352 |
+
'oil_refinery_outdoor', 'oilrig', 'operating_room', 'orchard',
|
1353 |
+
'outhouse_outdoor', 'pagoda', 'palace', 'pantry', 'park',
|
1354 |
+
'parking_garage_indoor', 'parking_garage_outdoor', 'parking_lot', 'parlor',
|
1355 |
+
'pasture', 'patio', 'pavilion', 'pharmacy', 'phone_booth',
|
1356 |
+
'physics_laboratory', 'picnic_area', 'pilothouse_indoor',
|
1357 |
+
'planetarium_outdoor', 'playground', 'playroom', 'plaza', 'podium_indoor',
|
1358 |
+
'podium_outdoor', 'pond', 'poolroom_establishment', 'poolroom_home',
|
1359 |
+
'power_plant_outdoor', 'promenade_deck', 'pub_indoor', 'pulpit',
|
1360 |
+
'putting_green', 'racecourse', 'raceway', 'raft', 'railroad_track',
|
1361 |
+
'rainforest', 'reception', 'recreation_room', 'residential_neighborhood',
|
1362 |
+
'restaurant', 'restaurant_kitchen', 'restaurant_patio', 'rice_paddy',
|
1363 |
+
'riding_arena', 'river', 'rock_arch', 'rope_bridge', 'ruin', 'runway',
|
1364 |
+
'sandbar', 'sandbox', 'sauna', 'schoolhouse', 'sea_cliff', 'server_room',
|
1365 |
+
'shed', 'shoe_shop', 'shopfront', 'shopping_mall_indoor', 'shower',
|
1366 |
+
'skatepark', 'ski_lodge', 'ski_resort', 'ski_slope', 'sky', 'skyscraper',
|
1367 |
+
'slum', 'snowfield', 'squash_court', 'stable', 'stadium_baseball',
|
1368 |
+
'stadium_football', 'stage_indoor', 'staircase', 'street',
|
1369 |
+
'subway_interior', 'subway_station_platform', 'supermarket', 'sushi_bar',
|
1370 |
+
'swamp', 'swimming_pool_indoor', 'swimming_pool_outdoor',
|
1371 |
+
'synagogue_indoor', 'synagogue_outdoor', 'television_studio',
|
1372 |
+
'temple_east_asia', 'temple_south_asia', 'tennis_court_indoor',
|
1373 |
+
'tennis_court_outdoor', 'tent_outdoor', 'theater_indoor_procenium',
|
1374 |
+
'theater_indoor_seats', 'thriftshop', 'throne_room', 'ticket_booth',
|
1375 |
+
'toll_plaza', 'topiary_garden', 'tower', 'toyshop', 'track_outdoor',
|
1376 |
+
'train_railway', 'train_station_platform', 'tree_farm', 'tree_house',
|
1377 |
+
'trench', 'underwater_coral_reef', 'utility_room', 'valley',
|
1378 |
+
'van_interior', 'vegetable_garden', 'veranda', 'veterinarians_office',
|
1379 |
+
'viaduct', 'videostore', 'village', 'vineyard', 'volcano',
|
1380 |
+
'volleyball_court_indoor', 'volleyball_court_outdoor', 'waiting_room',
|
1381 |
+
'warehouse_indoor', 'water_tower', 'waterfall_block', 'waterfall_fan',
|
1382 |
+
'waterfall_plunge', 'watering_hole', 'wave', 'wet_bar', 'wheat_field',
|
1383 |
+
'wind_farm', 'windmill', 'wine_cellar_barrel_storage',
|
1384 |
+
'wine_cellar_bottle_storage', 'wrestling_ring_indoor', 'yard',
|
1385 |
+
'youth_hostel')
|
1386 |
+
|
1387 |
+
CALTECH101_CATEGORIES = (
|
1388 |
+
'BACKGROUND_Google', 'Faces', 'Faces_easy', 'Leopards', 'Motorbikes',
|
1389 |
+
'accordion', 'airplanes', 'anchor', 'ant', 'barrel', 'bass', 'beaver',
|
1390 |
+
'binocular', 'bonsai', 'brain', 'brontosaurus', 'buddha', 'butterfly',
|
1391 |
+
'camera', 'cannon', 'car_side', 'ceiling_fan', 'cellphone', 'chair',
|
1392 |
+
'chandelier', 'cougar_body', 'cougar_face', 'crab', 'crayfish',
|
1393 |
+
'crocodile', 'crocodile_head', 'cup', 'dalmatian', 'dollar_bill',
|
1394 |
+
'dolphin', 'dragonfly', 'electric_guitar', 'elephant', 'emu', 'euphonium',
|
1395 |
+
'ewer', 'ferry', 'flamingo', 'flamingo_head', 'garfield', 'gerenuk',
|
1396 |
+
'gramophone', 'grand_piano', 'hawksbill', 'headphone', 'hedgehog',
|
1397 |
+
'helicopter', 'ibis', 'inline_skate', 'joshua_tree', 'kangaroo', 'ketch',
|
1398 |
+
'lamp', 'laptop', 'llama', 'lobster', 'lotus', 'mandolin', 'mayfly',
|
1399 |
+
'menorah', 'metronome', 'minaret', 'nautilus', 'octopus', 'okapi',
|
1400 |
+
'pagoda', 'panda', 'pigeon', 'pizza', 'platypus', 'pyramid', 'revolver',
|
1401 |
+
'rhino', 'rooster', 'saxophone', 'schooner', 'scissors', 'scorpion',
|
1402 |
+
'sea_horse', 'snoopy', 'soccer_ball', 'stapler', 'starfish', 'stegosaurus',
|
1403 |
+
'stop_sign', 'strawberry', 'sunflower', 'tick', 'trilobite', 'umbrella',
|
1404 |
+
'watch', 'water_lilly', 'wheelchair', 'wild_cat', 'windsor_chair',
|
1405 |
+
'wrench', 'yin_yang')
|
1406 |
+
|
1407 |
+
FOOD101_CATEGORIES = (
|
1408 |
+
'apple_pie', 'baby_back_ribs', 'baklava', 'beef_carpaccio', 'beef_tartare',
|
1409 |
+
'beet_salad', 'beignets', 'bibimbap', 'bread_pudding', 'breakfast_burrito',
|
1410 |
+
'bruschetta', 'caesar_salad', 'cannoli', 'caprese_salad', 'carrot_cake',
|
1411 |
+
'ceviche', 'cheesecake', 'cheese_plate', 'chicken_curry',
|
1412 |
+
'chicken_quesadilla', 'chicken_wings', 'chocolate_cake',
|
1413 |
+
'chocolate_mousse', 'churros', 'clam_chowder', 'club_sandwich',
|
1414 |
+
'crab_cakes', 'creme_brulee', 'croque_madame', 'cup_cakes', 'deviled_eggs',
|
1415 |
+
'donuts', 'dumplings', 'edamame', 'eggs_benedict', 'escargots', 'falafel',
|
1416 |
+
'filet_mignon', 'fish_and_chips', 'foie_gras', 'french_fries',
|
1417 |
+
'french_onion_soup', 'french_toast', 'fried_calamari', 'fried_rice',
|
1418 |
+
'frozen_yogurt', 'garlic_bread', 'gnocchi', 'greek_salad',
|
1419 |
+
'grilled_cheese_sandwich', 'grilled_salmon', 'guacamole', 'gyoza',
|
1420 |
+
'hamburger', 'hot_and_sour_soup', 'hot_dog', 'huevos_rancheros', 'hummus',
|
1421 |
+
'ice_cream', 'lasagna', 'lobster_bisque', 'lobster_roll_sandwich',
|
1422 |
+
'macaroni_and_cheese', 'macarons', 'miso_soup', 'mussels', 'nachos',
|
1423 |
+
'omelette', 'onion_rings', 'oysters', 'pad_thai', 'paella', 'pancakes',
|
1424 |
+
'panna_cotta', 'peking_duck', 'pho', 'pizza', 'pork_chop', 'poutine',
|
1425 |
+
'prime_rib', 'pulled_pork_sandwich', 'ramen', 'ravioli', 'red_velvet_cake',
|
1426 |
+
'risotto', 'samosa', 'sashimi', 'scallops', 'seaweed_salad',
|
1427 |
+
'shrimp_and_grits', 'spaghetti_bolognese', 'spaghetti_carbonara',
|
1428 |
+
'spring_rolls', 'steak', 'strawberry_shortcake', 'sushi', 'tacos',
|
1429 |
+
'takoyaki', 'tiramisu', 'tuna_tartare', 'waffles')
|
1430 |
+
|
1431 |
+
CIFAR100_CATEGORIES_CN = (
|
1432 |
+
'苹果', '水族馆鱼', '婴儿', '熊', '河狸', '床', '蜜蜂', '甲虫', '自行车', '瓶子', '碗', '小男孩',
|
1433 |
+
'桥', '公共汽车', '蝴蝶', '骆驼', '易拉罐', '城堡', '毛毛虫', '牛', '椅子', '猩猩', '钟', '白云',
|
1434 |
+
'蟑螂', '沙发', '螃蟹', '鳄鱼', '杯子', '恐龙', '海豚', '大象', '比目鱼', '森林', '狐狸', '小女孩',
|
1435 |
+
'仓鼠', '屋子', '袋鼠', '键盘', '台灯', '割草机', '猎豹', '狮子', '蜥蜴', '龙虾', '男人', '枫树',
|
1436 |
+
'摩托车', '山', '老鼠', '蘑菇', '橡树', '橙子橘子', '兰花', '水獭', '棕榈树', '梨', '皮卡车', '松树',
|
1437 |
+
'田野', '盘子', '罂粟', '豪猪', '负鼠', '兔子', '浣熊', '鳐鱼', '公路', '火箭', '玫瑰', '大海',
|
1438 |
+
'海豹', '鲨鱼', '尖嘴小鼠', '臭鼬', '摩天大楼', '蜗牛', '蛇', '蜘蛛', '松鼠', '电车', '向日葵', '甜椒',
|
1439 |
+
'桌子', '坦克', '电话', '电视', '老虎', '拖拉机', '火车', '鳟鱼', '郁金香', '乌龟', '衣柜', '鲸鱼',
|
1440 |
+
'柳树', '狼', '女人', '蠕虫')
|
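These category tuples are plain, index-aligned data: entry i names class i of a model trained on the corresponding dataset (cifar.py below wires CIFAR10_CATEGORIES into METAINFO['classes']). A minimal lookup sketch, assuming only that this module is importable as uploaded; `pred_idx` is a hypothetical predicted class index, not something defined in this commit:

# Minimal sketch: map a predicted class index to a human-readable name.
# The tuples above are index-aligned with the training labels, so entry i
# names class i. `pred_idx` is a hypothetical model output.
from mmpretrain.datasets.categories import CIFAR10_CATEGORIES

pred_idx = 3  # e.g. the argmax over a 10-way logit vector
print(CIFAR10_CATEGORIES[pred_idx])  # -> 'cat'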
mmpretrain/datasets/cifar.py
ADDED
@@ -0,0 +1,210 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pickle
from typing import List, Optional

import mmengine.dist as dist
import numpy as np
from mmengine.fileio import (LocalBackend, exists, get, get_file_backend,
                             join_path)
from mmengine.logging import MMLogger

from mmpretrain.registry import DATASETS
from .base_dataset import BaseDataset
from .categories import CIFAR10_CATEGORIES, CIFAR100_CATEGORIES
from .utils import check_md5, download_and_extract_archive


@DATASETS.register_module()
class CIFAR10(BaseDataset):
    """`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.

    This implementation is modified from
    https://github.com/pytorch/vision/blob/master/torchvision/datasets/cifar.py

    Args:
        data_root (str): The root directory of the CIFAR Dataset.
        split (str, optional): The dataset split, supports "train" and "test".
            Defaults to "train".
        metainfo (dict, optional): Meta information for dataset, such as
            categories information. Defaults to None.
        download (bool): Whether to download the dataset if it does not exist.
            Defaults to True.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """  # noqa: E501

    base_folder = 'cifar-10-batches-py'
    url = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
    filename = 'cifar-10-python.tar.gz'
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
    ]

    test_list = [
        ['test_batch', '40351d587109b95175f43aff81a1287e'],
    ]
    meta = {
        'filename': 'batches.meta',
        'key': 'label_names',
        'md5': '5ff9c542aee3614f3951f8cda6e48888',
    }
    METAINFO = {'classes': CIFAR10_CATEGORIES}

    def __init__(self,
                 data_root: str = '',
                 split: str = 'train',
                 metainfo: Optional[dict] = None,
                 download: bool = True,
                 data_prefix: str = '',
                 test_mode: bool = False,
                 **kwargs):

        splits = ['train', 'test']
        assert split in splits, \
            f"The split must be one of {splits}, but got '{split}'"
        self.split = split

        # To handle the BC-breaking
        if split == 'train' and test_mode:
            logger = MMLogger.get_current_instance()
            logger.warning('split="train" but test_mode=True. '
                           'The training set will be used.')

        if not data_root and not data_prefix:
            raise RuntimeError('Please set ``data_root`` to '
                               'specify the dataset path')

        self.download = download
        super().__init__(
            # The CIFAR dataset doesn't need a specified annotation file
            ann_file='',
            metainfo=metainfo,
            data_root=data_root,
            data_prefix=dict(root=data_prefix),
            test_mode=test_mode,
            **kwargs)

    def load_data_list(self):
        """Load images and ground truth labels."""
        root = self.data_prefix['root']
        backend = get_file_backend(root, enable_singleton=True)

        if dist.is_main_process() and not self._check_integrity():
            if not isinstance(backend, LocalBackend):
                raise RuntimeError(f'The dataset on {root} is not integrated, '
                                   f'please manually handle it.')

            if self.download:
                download_and_extract_archive(
                    self.url, root, filename=self.filename, md5=self.tgz_md5)
            else:
                raise RuntimeError(
                    f'Cannot find {self.__class__.__name__} dataset in '
                    f"{self.data_prefix['root']}, you can specify "
                    '`download=True` to download automatically.')

        dist.barrier()
        assert self._check_integrity(), \
            'Download failed or shared storage is unavailable. Please ' \
            f'download the dataset manually through {self.url}.'

        if self.split == 'train':
            downloaded_list = self.train_list
        else:
            downloaded_list = self.test_list

        imgs = []
        gt_labels = []

        # load the pickled numpy arrays
        for file_name, _ in downloaded_list:
            file_path = join_path(root, self.base_folder, file_name)
            entry = pickle.loads(get(file_path), encoding='latin1')
            imgs.append(entry['data'])
            if 'labels' in entry:
                gt_labels.extend(entry['labels'])
            else:
                gt_labels.extend(entry['fine_labels'])

        imgs = np.vstack(imgs).reshape(-1, 3, 32, 32)
        imgs = imgs.transpose((0, 2, 3, 1))  # convert to HWC

        if self.CLASSES is None:
            # The metainfo in the file has the lowest priority, therefore
            # we only need to load it if classes is not specified.
            self._load_meta()

        data_list = []
        for img, gt_label in zip(imgs, gt_labels):
            info = {'img': img, 'gt_label': int(gt_label)}
            data_list.append(info)
        return data_list

    def _load_meta(self):
        """Load categories information from metafile."""
        root = self.data_prefix['root']

        path = join_path(root, self.base_folder, self.meta['filename'])
        md5 = self.meta.get('md5', None)
        if not exists(path) or (md5 is not None and not check_md5(path, md5)):
            raise RuntimeError(
                'Dataset metadata file not found or corrupted.' +
                ' You can use `download=True` to download it')
        data = pickle.loads(get(path), encoding='latin1')
        self._metainfo.setdefault('classes', data[self.meta['key']])

    def _check_integrity(self):
        """Check the integrity of data files."""
        root = self.data_prefix['root']

        for fentry in (self.train_list + self.test_list):
            filename, md5 = fentry[0], fentry[1]
            fpath = join_path(root, self.base_folder, filename)
            if not exists(fpath):
                return False
            if md5 is not None and not check_md5(fpath, md5):
                return False
        return True

    def extra_repr(self) -> List[str]:
        """The extra repr information of the dataset."""
        body = [f"Prefix of data: \t{self.data_prefix['root']}"]
        return body


@DATASETS.register_module()
class CIFAR100(CIFAR10):
    """`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.

    Args:
        data_root (str): The root directory of the CIFAR Dataset.
        split (str, optional): The dataset split, supports "train" and "test".
            Defaults to "train".
        metainfo (dict, optional): Meta information for dataset, such as
            categories information. Defaults to None.
        download (bool): Whether to download the dataset if it does not exist.
            Defaults to True.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    base_folder = 'cifar-100-python'
    url = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
    filename = 'cifar-100-python.tar.gz'
    tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
    train_list = [
        ['train', '16019d7e3df5f24257cddd939b257f8d'],
    ]

    test_list = [
        ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
    ]
    meta = {
        'filename': 'meta',
        'key': 'fine_label_names',
        'md5': '7973b15100ade9c7d40fb424638fde48',
    }
    METAINFO = {'classes': CIFAR100_CATEGORIES}
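As a usage sketch (assuming the archive can be downloaded into the given data_root, or a verified copy already sits there), the class above can be instantiated directly:

from mmpretrain.datasets.cifar import CIFAR10

# Downloads and MD5-checks cifar-10-python.tar.gz on first use.
train_set = CIFAR10(data_root='data/cifar10', split='train', pipeline=[])
sample = train_set[0]
print(len(train_set))        # 50000 training samples
print(sample['img'].shape)   # (32, 32, 3), HWC as produced in load_data_list()
print(sample['gt_label'])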
mmpretrain/datasets/coco_caption.py
ADDED
@@ -0,0 +1,42 @@
# Copyright (c) OpenMMLab. All rights reserved.
from pathlib import Path
from typing import List

import mmengine
from mmengine.dataset import BaseDataset
from mmengine.fileio import get_file_backend

from mmpretrain.registry import DATASETS


@DATASETS.register_module()
class COCOCaption(BaseDataset):
    """COCO Caption dataset.

    Args:
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``.
        ann_file (str): Annotation file path.
        data_prefix (dict): Prefix for data field. Defaults to
            ``dict(img_path='')``.
        pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def load_data_list(self) -> List[dict]:
        """Load data list."""
        img_prefix = self.data_prefix['img_path']
        annotations = mmengine.load(self.ann_file)
        file_backend = get_file_backend(img_prefix)

        data_list = []
        for ann in annotations:
            data_info = {
                'image_id': Path(ann['image']).stem.split('_')[-1],
                'img_path': file_backend.join_path(img_prefix, ann['image']),
                'gt_caption': ann['caption'],
            }

            data_list.append(data_info)

        return data_list
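A minimal instantiation sketch (the paths are placeholders; the annotation file is a JSON list of {'image': ..., 'caption': ...} records, which is what the loop above expects):

from mmpretrain.datasets.coco_caption import COCOCaption

# Hypothetical paths; each annotation record looks like
# {'image': 'val2014/COCO_val2014_000000391895.jpg', 'caption': '...'}.
dataset = COCOCaption(
    data_root='data/coco',
    ann_file='annotations/coco_karpathy_val.json',
    data_prefix=dict(img_path='images'),
    pipeline=[],
)
print(dataset[0]['image_id'], dataset[0]['gt_caption'])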
mmpretrain/datasets/coco_retrieval.py
ADDED
@@ -0,0 +1,77 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
from collections import OrderedDict
from typing import List

from mmengine import get_file_backend

from mmpretrain.registry import DATASETS
from .base_dataset import BaseDataset


@DATASETS.register_module()
class COCORetrieval(BaseDataset):
    """COCO Retrieval dataset.

    Args:
        ann_file (str): Annotation file path.
        test_mode (bool): Whether the dataset is used for evaluation. This
            decides the annotation format of the data list.
            Defaults to False.
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``. Defaults to ''.
        data_prefix (str | dict): Prefix for training data. Defaults to ''.
        pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def load_data_list(self) -> List[dict]:
        """Load data list."""
        # get file backend
        img_prefix = self.data_prefix['img_path']
        file_backend = get_file_backend(img_prefix)

        with open(self.ann_file, 'r') as f:
            anno_info = json.load(f)
        # mapping img_id to img filename
        img_dict = OrderedDict()
        for idx, img in enumerate(anno_info['images']):
            if img['id'] not in img_dict:
                img_rel_path = img['coco_url'].rsplit('/', 2)[-2:]
                img_path = file_backend.join_path(img_prefix, *img_rel_path)

                # create new idx for image
                img_dict[img['id']] = dict(
                    ori_id=img['id'],
                    image_id=idx,  # will be used for evaluation
                    img_path=img_path,
                    text=[],
                    gt_text_id=[],
                    gt_image_id=[],
                )

        train_list = []
        for idx, anno in enumerate(anno_info['annotations']):
            anno['text'] = anno.pop('caption')
            anno['ori_id'] = anno.pop('id')
            anno['text_id'] = idx  # will be used for evaluation
            # 1. prepare train data list item
            train_data = anno.copy()
            train_image = img_dict[train_data['image_id']]
            train_data['img_path'] = train_image['img_path']
            train_data['image_ori_id'] = train_image['ori_id']
            train_data['image_id'] = train_image['image_id']
            train_data['is_matched'] = True
            train_list.append(train_data)
            # 2. prepare eval data list item based on img dict
            img_dict[anno['image_id']]['gt_text_id'].append(anno['text_id'])
            img_dict[anno['image_id']]['text'].append(anno['text'])
            img_dict[anno['image_id']]['gt_image_id'].append(
                train_image['image_id'])

        self.img_size = len(img_dict)
        self.text_size = len(anno_info['annotations'])

        # return needed format data list
        if self.test_mode:
            return list(img_dict.values())
        return train_list
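The two return branches above produce differently shaped items. A sketch of what a consumer sees in each mode (all field values below are illustrative, not real COCO records):

# Training mode (test_mode=False): one item per caption.
train_item = {
    'text': 'A cat sits on a mat.',   # renamed from the 'caption' field
    'text_id': 0,                     # running caption index
    'image_id': 12,                   # re-indexed image id
    'image_ori_id': 391895,           # original COCO image id
    'img_path': 'data/coco/val2014/COCO_val2014_000000391895.jpg',
    'is_matched': True,
}

# Evaluation mode (test_mode=True): one item per image, carrying all its
# captions plus the ground-truth id lists used by the retrieval metric.
eval_item = {
    'ori_id': 391895,
    'image_id': 12,
    'img_path': 'data/coco/val2014/COCO_val2014_000000391895.jpg',
    'text': ['A cat sits on a mat.', 'A kitten resting indoors.'],
    'gt_text_id': [0, 1],
    'gt_image_id': [12, 12],
}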
mmpretrain/datasets/coco_vqa.py
ADDED
@@ -0,0 +1,114 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import re
from collections import Counter
from typing import List

import mmengine
from mmengine.dataset import BaseDataset

from mmpretrain.registry import DATASETS


@DATASETS.register_module()
class COCOVQA(BaseDataset):
    """VQAv2 dataset.

    Args:
        data_root (str): The root directory for ``data_prefix``, ``ann_file``
            and ``question_file``.
        data_prefix (str): The directory of images.
        question_file (str): Question file path.
        ann_file (str, optional): Annotation file path for training and
            validation. Defaults to an empty string.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self,
                 data_root: str,
                 data_prefix: str,
                 question_file: str,
                 ann_file: str = '',
                 **kwarg):
        self.question_file = question_file
        super().__init__(
            data_root=data_root,
            data_prefix=dict(img_path=data_prefix),
            ann_file=ann_file,
            **kwarg,
        )

    def _join_prefix(self):
        if not mmengine.is_abs(self.question_file) and self.question_file:
            self.question_file = osp.join(self.data_root, self.question_file)

        return super()._join_prefix()

    def _create_image_index(self):
        img_prefix = self.data_prefix['img_path']

        files = mmengine.list_dir_or_file(img_prefix, list_dir=False)
        image_index = {}
        for file in files:
            image_id = re.findall(r'\d{12}', file)
            if len(image_id) > 0:
                image_id = int(image_id[-1])
                image_index[image_id] = mmengine.join_path(img_prefix, file)

        return image_index

    def load_data_list(self) -> List[dict]:
        """Load data list."""
        questions = mmengine.load(self.question_file)['questions']
        if self.ann_file:
            annotations = mmengine.load(self.ann_file)['annotations']
            assert len(questions) == len(annotations)
        else:
            annotations = [None] * len(questions)

        # The original VQAv2 annotation file and question file include
        # only image ids but no image file paths.
        self.image_index = self._create_image_index()

        data_list = []
        for question, ann in zip(questions, annotations):
            # question example
            # {
            #     'image_id': 262144,
            #     'question': "Is the ball flying towards the batter?",
            #     'question_id': 262144000
            # }
            #
            # ann example
            # {
            #     'question_type': "what are the",
            #     'answer_type': "other",
            #     'answers': [
            #         {'answer': 'watching',
            #          'answer_id': 1,
            #          'answer_confidence': 'yes'},
            #         ...
            #     ],
            #     'image_id': 262148,
            #     'question_id': 262148000,
            #     'multiple_choice_answer': 'watching',
            # }

            data_info = question
            data_info['img_path'] = self.image_index[question['image_id']]

            if ann is not None:
                assert ann['question_id'] == question['question_id']

                # add answer_weight & answer_count, delete duplicate answers
                answers = [item['answer'] for item in ann.pop('answers')]
                count = Counter(answers)
                answer_weight = [i / len(answers) for i in count.values()]
                data_info['gt_answer'] = list(count.keys())
                data_info['gt_answer_weight'] = answer_weight
                data_info.update(ann)

            data_list.append(data_info)

        return data_list
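The answer de-duplication above turns the list of human answers into unique answers with soft weights proportional to how often each answer was given. A small self-contained check of that logic, using made-up answers:

from collections import Counter

# Mirror of the answer-weight computation in load_data_list() above.
answers = ['watching', 'watching', 'looking', 'watching', 'staring']
count = Counter(answers)
answer_weight = [i / len(answers) for i in count.values()]
print(list(count.keys()))   # ['watching', 'looking', 'staring']
print(answer_weight)        # [0.6, 0.2, 0.2]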
mmpretrain/datasets/cub.py
ADDED
@@ -0,0 +1,142 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

from mmengine import get_file_backend, list_from_file
from mmengine.logging import MMLogger

from mmpretrain.registry import DATASETS
from .base_dataset import BaseDataset
from .categories import CUB_CATEGORIES


@DATASETS.register_module()
class CUB(BaseDataset):
    """The CUB-200-2011 Dataset.

    Support the `CUB-200-2011 <http://www.vision.caltech.edu/visipedia/CUB-200-2011.html>`_ Dataset.
    Compared with the `CUB-200 <http://www.vision.caltech.edu/visipedia/CUB-200.html>`_ Dataset,
    `CUB-200-2011` contains many more images. After downloading and decompression, the dataset
    directory structure is as follows.

    CUB dataset directory: ::

        CUB_200_2011
        ├── images
        │   ├── class_x
        │   │   ├── xx1.jpg
        │   │   ├── xx2.jpg
        │   │   └── ...
        │   ├── class_y
        │   │   ├── yy1.jpg
        │   │   ├── yy2.jpg
        │   │   └── ...
        │   └── ...
        ├── images.txt
        ├── image_class_labels.txt
        ├── train_test_split.txt
        └── ...

    Args:
        data_root (str): The root directory for the CUB-200-2011 dataset.
        split (str, optional): The dataset split, supports "train" and "test".
            Defaults to "train".

    Examples:
        >>> from mmpretrain.datasets import CUB
        >>> train_dataset = CUB(data_root='data/CUB_200_2011', split='train')
        >>> train_dataset
        Dataset CUB
            Number of samples:  5994
            Number of categories:       200
            Root of dataset:    data/CUB_200_2011
        >>> test_dataset = CUB(data_root='data/CUB_200_2011', split='test')
        >>> test_dataset
        Dataset CUB
            Number of samples:  5794
            Number of categories:       200
            Root of dataset:    data/CUB_200_2011
    """  # noqa: E501

    METAINFO = {'classes': CUB_CATEGORIES}

    def __init__(self,
                 data_root: str,
                 split: str = 'train',
                 test_mode: bool = False,
                 **kwargs):

        splits = ['train', 'test']
        assert split in splits, \
            f"The split must be one of {splits}, but got '{split}'"
        self.split = split

        # To handle the BC-breaking
        if split == 'train' and test_mode:
            logger = MMLogger.get_current_instance()
            logger.warning('split="train" but test_mode=True. '
                           'The training set will be used.')

        ann_file = 'images.txt'
        data_prefix = 'images'
        image_class_labels_file = 'image_class_labels.txt'
        train_test_split_file = 'train_test_split.txt'

        self.backend = get_file_backend(data_root, enable_singleton=True)
        self.image_class_labels_file = self.backend.join_path(
            data_root, image_class_labels_file)
        self.train_test_split_file = self.backend.join_path(
            data_root, train_test_split_file)
        super().__init__(
            ann_file=ann_file,
            data_root=data_root,
            data_prefix=data_prefix,
            test_mode=test_mode,
            **kwargs)

    def _load_data_from_txt(self, filepath):
        """Load data from a CUB txt file, where every line contains an index
        and a data item."""
        pairs = list_from_file(filepath)
        data_dict = dict()
        for pair in pairs:
            idx, data_item = pair.split()
            # All indices start from 1 in CUB files;
            # subtract 1 here so they start from 0.
            data_dict[int(idx) - 1] = data_item
        return data_dict

    def load_data_list(self):
        """Load images and ground truth labels."""
        sample_dict = self._load_data_from_txt(self.ann_file)

        label_dict = self._load_data_from_txt(self.image_class_labels_file)

        split_dict = self._load_data_from_txt(self.train_test_split_file)

        assert sample_dict.keys() == label_dict.keys() == split_dict.keys(),\
            f'sample_ids should be the same in files {self.ann_file}, ' \
            f'{self.image_class_labels_file} and {self.train_test_split_file}'

        data_list = []
        for sample_id in sample_dict.keys():
            if split_dict[sample_id] == '1' and self.split == 'test':
                # skip train samples when split='test'
                continue
            elif split_dict[sample_id] == '0' and self.split == 'train':
                # skip test samples when split='train'
                continue

            img_path = self.backend.join_path(self.img_prefix,
                                              sample_dict[sample_id])
            gt_label = int(label_dict[sample_id]) - 1
            info = dict(img_path=img_path, gt_label=gt_label)
            data_list.append(info)

        return data_list

    def extra_repr(self) -> List[str]:
        """The extra repr information of the dataset."""
        body = [
            f'Root of dataset: \t{self.data_root}',
        ]
        return body
mmpretrain/datasets/custom.py
ADDED
@@ -0,0 +1,287 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union

from mmengine.fileio import (BaseStorageBackend, get_file_backend,
                             list_from_file)
from mmengine.logging import MMLogger

from mmpretrain.registry import DATASETS
from .base_dataset import BaseDataset


def find_folders(
    root: str,
    backend: Optional[BaseStorageBackend] = None
) -> Tuple[List[str], Dict[str, int]]:
    """Find classes by folders under a root.

    Args:
        root (string): root directory of folders
        backend (BaseStorageBackend | None): The file backend of the root.
            If None, auto infer backend from the root path. Defaults to None.

    Returns:
        Tuple[List[str], Dict[str, int]]:

        - folders: The names of the sub folders under the root.
        - folder_to_idx: The map from folder name to class idx.
    """
    # Pre-build file backend to prevent verbose file backend inference.
    backend = backend or get_file_backend(root, enable_singleton=True)
    folders = list(
        backend.list_dir_or_file(
            root,
            list_dir=True,
            list_file=False,
            recursive=False,
        ))
    folders.sort()
    folder_to_idx = {folders[i]: i for i in range(len(folders))}
    return folders, folder_to_idx


def get_samples(
    root: str,
    folder_to_idx: Dict[str, int],
    is_valid_file: Callable,
    backend: Optional[BaseStorageBackend] = None,
):
    """Make dataset by walking all images under a root.

    Args:
        root (string): root directory of folders
        folder_to_idx (dict): the map from class name to class idx
        is_valid_file (Callable): A function that takes the path of a file
            and checks whether the file is a valid sample file.
        backend (BaseStorageBackend | None): The file backend of the root.
            If None, auto infer backend from the root path. Defaults to None.

    Returns:
        Tuple[list, set]:

        - samples: a list of tuples where each element is (image, class_idx)
        - empty_folders: the folders that don't have any valid files.
    """
    samples = []
    available_classes = set()
    # Pre-build file backend to prevent verbose file backend inference.
    backend = backend or get_file_backend(root, enable_singleton=True)

    if folder_to_idx is not None:
        for folder_name in sorted(list(folder_to_idx.keys())):
            _dir = backend.join_path(root, folder_name)
            files = backend.list_dir_or_file(
                _dir,
                list_dir=False,
                list_file=True,
                recursive=True,
            )
            for file in sorted(list(files)):
                if is_valid_file(file):
                    path = backend.join_path(folder_name, file)
                    item = (path, folder_to_idx[folder_name])
                    samples.append(item)
                    available_classes.add(folder_name)
        empty_folders = set(folder_to_idx.keys()) - available_classes
    else:
        files = backend.list_dir_or_file(
            root,
            list_dir=False,
            list_file=True,
            recursive=True,
        )
        samples = [file for file in sorted(list(files)) if is_valid_file(file)]
        empty_folders = None

    return samples, empty_folders


@DATASETS.register_module()
class CustomDataset(BaseDataset):
    """A generic dataset for multiple tasks.

    The dataset supports two kinds of style.

    1. Use an annotation file to specify all samples, and each line indicates
       a sample:

    The annotation file (for ``with_label=True``, supervised tasks): ::

        folder_1/xxx.png 0
        folder_1/xxy.png 1
        123.png 4
        nsdf3.png 3
        ...

    The annotation file (for ``with_label=False``, unsupervised tasks): ::

        folder_1/xxx.png
        folder_1/xxy.png
        123.png
        nsdf3.png
        ...

    Sample files: ::

        data_prefix/
        ├── folder_1
        │   ├── xxx.png
        │   ├── xxy.png
        │   └── ...
        ├── 123.png
        ├── nsdf3.png
        └── ...

    Please use the argument ``metainfo`` to specify extra information for
    the task, like ``{'classes': ('bird', 'cat', 'deer', 'dog', 'frog')}``.

    2. Place all samples in one folder as below:

    Sample files (for ``with_label=True``, supervised tasks, we use the name
    of sub-folders as the categories names): ::

        data_prefix/
        ├── class_x
        │   ├── xxx.png
        │   ├── xxy.png
        │   ├── ...
        │   └── xxz.png
        └── class_y
            ├── 123.png
            ├── nsdf3.png
            ├── ...
            └── asd932_.png

    Sample files (for ``with_label=False``, unsupervised tasks, we use all
    sample files under the specified folder): ::

        data_prefix/
        ├── folder_1
        │   ├── xxx.png
        │   ├── xxy.png
        │   └── ...
        ├── 123.png
        ├── nsdf3.png
        └── ...

    If the ``ann_file`` is specified, the dataset will be generated by the
    first way, otherwise, try the second way.

    Args:
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``. Defaults to ''.
        data_prefix (str | dict): Prefix for the data. Defaults to ''.
        ann_file (str): Annotation file path. Defaults to ''.
        with_label (bool): Whether the annotation file includes ground truth
            labels, or use sub-folders to specify categories.
            Defaults to True.
        extensions (Sequence[str]): A sequence of allowed extensions. Defaults
            to ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif').
        metainfo (dict, optional): Meta information for the dataset, such as
            class information. Defaults to None.
        lazy_init (bool): Whether to skip loading annotations during
            instantiation. In some cases, such as visualization, only the
            meta information of the dataset is needed and it is not necessary
            to load the annotation file. ``BaseDataset`` can skip loading
            annotations to save time by setting ``lazy_init=True``.
            Defaults to False.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self,
                 data_root: str = '',
                 data_prefix: Union[str, dict] = '',
                 ann_file: str = '',
                 with_label=True,
                 extensions: Sequence[str] = ('.jpg', '.jpeg', '.png', '.ppm',
                                              '.bmp', '.pgm', '.tif'),
                 metainfo: Optional[dict] = None,
                 lazy_init: bool = False,
                 **kwargs):
        assert (ann_file or data_prefix or data_root), \
            'One of `ann_file`, `data_root` and `data_prefix` must '\
            'be specified.'

        self.extensions = tuple(set([i.lower() for i in extensions]))
        self.with_label = with_label

        super().__init__(
            # The base class requires a string ann_file, but this class
            # allows it to be empty.
            ann_file=ann_file,
            metainfo=metainfo,
            data_root=data_root,
            data_prefix=data_prefix,
            # Force to lazy_init for some modification before loading data.
            lazy_init=True,
            **kwargs)

        # Fully initialize the dataset.
        if not lazy_init:
            self.full_init()

    def _find_samples(self):
        """Find samples from ``data_prefix``."""
        if self.with_label:
            classes, folder_to_idx = find_folders(self.img_prefix)
            samples, empty_classes = get_samples(
                self.img_prefix,
                folder_to_idx,
                is_valid_file=self.is_valid_file,
            )

            self.folder_to_idx = folder_to_idx

            if self.CLASSES is not None:
                assert len(self.CLASSES) == len(classes), \
                    f"The number of subfolders ({len(classes)}) doesn't " \
                    f'match the number of specified classes ' \
                    f'({len(self.CLASSES)}). Please check the data folder.'
            else:
                self._metainfo['classes'] = tuple(classes)
        else:
            samples, empty_classes = get_samples(
                self.img_prefix,
                None,
                is_valid_file=self.is_valid_file,
            )

        if len(samples) == 0:
            raise RuntimeError(
                f'Found 0 files in subfolders of: {self.data_prefix}. '
                f'Supported extensions are: {",".join(self.extensions)}')

        if empty_classes:
            logger = MMLogger.get_current_instance()
            logger.warning(
                'Found no valid file in the folder '
                f'{", ".join(empty_classes)}. '
                f"Supported extensions are: {', '.join(self.extensions)}")

        return samples

    def load_data_list(self):
        """Load image paths and gt_labels."""
        if not self.ann_file:
            samples = self._find_samples()
        elif self.with_label:
            lines = list_from_file(self.ann_file)
            samples = [x.strip().rsplit(' ', 1) for x in lines]
        else:
            samples = list_from_file(self.ann_file)

        # Pre-build file backend to prevent verbose file backend inference.
        backend = get_file_backend(self.img_prefix, enable_singleton=True)
        data_list = []
        for sample in samples:
            if self.with_label:
                filename, gt_label = sample
                img_path = backend.join_path(self.img_prefix, filename)
                info = {'img_path': img_path, 'gt_label': int(gt_label)}
            else:
                img_path = backend.join_path(self.img_prefix, sample)
                info = {'img_path': img_path}
            data_list.append(info)
        return data_list

    def is_valid_file(self, filename: str) -> bool:
        """Check if a file is a valid sample."""
        return filename.lower().endswith(self.extensions)
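A usage sketch of the subfolder style (the directory is a placeholder; when no metainfo is given, the classes are inferred from the sorted subfolder names, as find_folders() above does):

from mmpretrain.datasets.custom import CustomDataset

# Subfolder style: data/my_dataset/<class_name>/<image>. Class indices
# follow the sorted subfolder names.
dataset = CustomDataset(data_root='data/my_dataset', pipeline=[])
print(dataset.CLASSES)                               # e.g. ('class_x', 'class_y')
print(dataset[0]['img_path'], dataset[0]['gt_label'])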
mmpretrain/datasets/dataset_wrappers.py
ADDED
@@ -0,0 +1,176 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy

import numpy as np
from mmengine.dataset import BaseDataset, force_full_init

from mmpretrain.registry import DATASETS


@DATASETS.register_module()
class KFoldDataset:
    """A wrapper of dataset for K-Fold cross-validation.

    K-Fold cross-validation divides all the samples into k groups of almost
    equal size, called folds. k-1 of the folds are used for training, and the
    remaining fold is used for validation.

    Args:
        dataset (:obj:`mmengine.dataset.BaseDataset` | dict): The dataset to
            be divided.
        fold (int): The fold used to do validation. Defaults to 0.
        num_splits (int): The number of all folds. Defaults to 5.
        test_mode (bool): Use the training dataset or validation dataset.
            Defaults to False.
        seed (int, optional): The seed to shuffle the dataset before
            splitting. If None, the dataset is not shuffled.
            Defaults to None.
    """

    def __init__(self,
                 dataset,
                 fold=0,
                 num_splits=5,
                 test_mode=False,
                 seed=None):
        if isinstance(dataset, dict):
            self.dataset = DATASETS.build(dataset)
            # Init the dataset wrapper lazily according to the dataset
            # setting.
            lazy_init = dataset.get('lazy_init', False)
        elif isinstance(dataset, BaseDataset):
            self.dataset = dataset
            # An already built dataset is initialized eagerly.
            lazy_init = False
        else:
            raise TypeError(f'Unsupported dataset type {type(dataset)}.')

        self._metainfo = getattr(self.dataset, 'metainfo', {})
        self.fold = fold
        self.num_splits = num_splits
        self.test_mode = test_mode
        self.seed = seed

        self._fully_initialized = False
        if not lazy_init:
            self.full_init()

    @property
    def metainfo(self) -> dict:
        """Get the meta information of ``self.dataset``.

        Returns:
            dict: Meta information of the dataset.
        """
        # Prevent `self._metainfo` from being modified by outside.
        return copy.deepcopy(self._metainfo)

    def full_init(self):
        """Fully initialize the dataset."""
        if self._fully_initialized:
            return

        self.dataset.full_init()
        ori_len = len(self.dataset)
        indices = list(range(ori_len))
        if self.seed is not None:
            rng = np.random.default_rng(self.seed)
            rng.shuffle(indices)

        test_start = ori_len * self.fold // self.num_splits
        test_end = ori_len * (self.fold + 1) // self.num_splits
        if self.test_mode:
            indices = indices[test_start:test_end]
        else:
            indices = indices[:test_start] + indices[test_end:]

        self._ori_indices = indices
        self.dataset = self.dataset.get_subset(indices)

        self._fully_initialized = True

    @force_full_init
    def _get_ori_dataset_idx(self, idx: int) -> int:
        """Convert an index in the wrapper to the index in the original
        dataset.

        Args:
            idx (int): Global index of ``KFoldDataset``.

        Returns:
            int: The original index in the whole dataset.
        """
        return self._ori_indices[idx]

    @force_full_init
    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index.

        Args:
            idx (int): Global index of ``KFoldDataset``.

        Returns:
            dict: The idx-th annotation of the dataset.
        """
        return self.dataset.get_data_info(idx)

    @force_full_init
    def __len__(self):
        return len(self.dataset)

    @force_full_init
    def __getitem__(self, idx):
        return self.dataset[idx]

    @force_full_init
    def get_cat_ids(self, idx):
        return self.dataset.get_cat_ids(idx)

    @force_full_init
    def get_gt_labels(self):
        return self.dataset.get_gt_labels()

    @property
    def CLASSES(self):
        """Return all categories names."""
        return self._metainfo.get('classes', None)

    @property
    def class_to_idx(self):
        """Map from class name to class index.

        Returns:
            dict: mapping from class name to class index.
        """

        return {cat: i for i, cat in enumerate(self.CLASSES)}

    def __repr__(self):
        """Print the basic information of the dataset.

        Returns:
            str: Formatted string.
        """
        head = 'Dataset ' + self.__class__.__name__
        body = []
        type_ = 'test' if self.test_mode else 'training'
        body.append(f'Type: \t{type_}')
        body.append(f'Seed: \t{self.seed}')

        def ordinal(n):
            # Copy from https://codegolf.stackexchange.com/a/74047
            suffix = 'tsnrhtdd'[(n // 10 % 10 != 1) * (n % 10 < 4) * n % 10::4]
            return f'{n}{suffix}'

        body.append(
            f'Fold: \t{ordinal(self.fold+1)} of {self.num_splits}-fold')
        if self._fully_initialized:
            body.append(f'Number of samples: \t{self.__len__()}')
        else:
            body.append("Haven't been initialized")

        if self.CLASSES is not None:
            body.append(f'Number of categories: \t{len(self.CLASSES)}')
        else:
            body.append('The `CLASSES` meta info is not set.')

        body.append(
            f'Original dataset type:\t{self.dataset.__class__.__name__}')

        lines = [head] + [' ' * 4 + line for line in body]
        return '\n'.join(lines)
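A sketch of a full 5-fold loop with the wrapper above (the inner dataset config is a placeholder; the same seed must be shared between the train and test wrappers so the shuffled folds line up):

from mmpretrain.datasets.dataset_wrappers import KFoldDataset

base = dict(type='CustomDataset', data_root='data/my_dataset', pipeline=[])

for fold in range(5):
    # With a shared seed, the shuffled indices are identical in both
    # wrappers, so the train and test folds are complementary.
    train_set = KFoldDataset(base, fold=fold, num_splits=5,
                             test_mode=False, seed=0)
    test_set = KFoldDataset(base, fold=fold, num_splits=5,
                            test_mode=True, seed=0)
    print(fold, len(train_set), len(test_set))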
mmpretrain/datasets/dtd.py
ADDED
@@ -0,0 +1,116 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import mat4py
from mmengine import get_file_backend

from mmpretrain.registry import DATASETS
from .base_dataset import BaseDataset
from .categories import DTD_CATEGORIES


@DATASETS.register_module()
class DTD(BaseDataset):
    """The Describable Texture Dataset (DTD).

    Support the `Describable Texture Dataset <https://www.robots.ox.ac.uk/~vgg/data/dtd/>`_ Dataset.
    After downloading and decompression, the dataset directory structure is as follows.

    DTD dataset directory: ::

        dtd
        ├── images
        │   ├── banded
        │   │   ├── banded_0002.jpg
        │   │   ├── banded_0004.jpg
        │   │   └── ...
        │   └── ...
        ├── imdb
        │   └── imdb.mat
        ├── labels
        │   ├── labels_joint_anno.txt
        │   ├── test1.txt
        │   ├── test2.txt
        │   └── ...
        └── ...

    Args:
        data_root (str): The root directory for the Describable Texture
            dataset.
        split (str, optional): The dataset split, supports "train",
            "val", "trainval", and "test". Defaults to "trainval".

    Examples:
        >>> from mmpretrain.datasets import DTD
        >>> train_dataset = DTD(data_root='data/dtd', split='trainval')
        >>> train_dataset
        Dataset DTD
            Number of samples:  3760
            Number of categories:       47
            Root of dataset:    data/dtd
        >>> test_dataset = DTD(data_root='data/dtd', split='test')
        >>> test_dataset
        Dataset DTD
            Number of samples:  1880
            Number of categories:       47
            Root of dataset:    data/dtd
    """  # noqa: E501

    METAINFO = {'classes': DTD_CATEGORIES}

    def __init__(self, data_root: str, split: str = 'trainval', **kwargs):

        splits = ['train', 'val', 'trainval', 'test']
        assert split in splits, \
            f"The split must be one of {splits}, but got '{split}'"
        self.split = split

        data_prefix = 'images'
        test_mode = split == 'test'

        self.backend = get_file_backend(data_root, enable_singleton=True)
        ann_file = self.backend.join_path('imdb', 'imdb.mat')

        super().__init__(
            ann_file=ann_file,
            data_root=data_root,
            data_prefix=data_prefix,
            test_mode=test_mode,
            **kwargs)

    def load_data_list(self):
        """Load images and ground truth labels."""

        data = mat4py.loadmat(self.ann_file)['images']
        names = data['name']
        labels = data['class']
        parts = data['set']
        num = len(names)
        assert num == len(labels) == len(parts), 'Invalid annotation file'

        # In imdb.mat, the 'set' field marks each image as
        # train (1), val (2) or test (3).
        if self.split == 'train':
            target_set = {1}
        elif self.split == 'val':
            target_set = {2}
        elif self.split == 'test':
            target_set = {3}
        else:
            target_set = {1, 2}

        data_list = []
        for i in range(num):
            if parts[i] in target_set:
                img_name = names[i]
                img_path = self.backend.join_path(self.img_prefix, img_name)
                gt_label = labels[i] - 1
                info = dict(img_path=img_path, gt_label=gt_label)
                data_list.append(info)

        return data_list

    def extra_repr(self) -> List[str]:
        """The extra repr information of the dataset."""
        body = [
            f'Root of dataset: \t{self.data_root}',
        ]
        return body
mmpretrain/datasets/fgvcaircraft.py
ADDED
@@ -0,0 +1,98 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

from mmengine import get_file_backend, list_from_file

from mmpretrain.registry import DATASETS
from .base_dataset import BaseDataset
from .categories import FGVCAIRCRAFT_CATEGORIES


@DATASETS.register_module()
class FGVCAircraft(BaseDataset):
    """The FGVC_Aircraft Dataset.

    Support the `FGVC_Aircraft Dataset <https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/>`_ Dataset.
    After downloading and decompression, the dataset directory structure is as follows.

    FGVC_Aircraft dataset directory: ::

        fgvc-aircraft-2013b
        └── data
            ├── images
            │   ├── 1.jpg
            │   ├── 2.jpg
            │   └── ...
            ├── images_variant_train.txt
            ├── images_variant_test.txt
            ├── images_variant_trainval.txt
            ├── images_variant_val.txt
            ├── variants.txt
            └── ...

    Args:
        data_root (str): The root directory for the FGVC_Aircraft dataset.
        split (str, optional): The dataset split, supports "train",
            "val", "trainval", and "test". Defaults to "trainval".

    Examples:
        >>> from mmpretrain.datasets import FGVCAircraft
        >>> train_dataset = FGVCAircraft(data_root='data/fgvc-aircraft-2013b', split='trainval')
        >>> train_dataset
        Dataset FGVCAircraft
            Number of samples:  6667
            Number of categories:       100
            Root of dataset:    data/fgvc-aircraft-2013b
        >>> test_dataset = FGVCAircraft(data_root='data/fgvc-aircraft-2013b', split='test')
        >>> test_dataset
        Dataset FGVCAircraft
            Number of samples:  3333
            Number of categories:       100
            Root of dataset:    data/fgvc-aircraft-2013b
    """  # noqa: E501

    METAINFO = {'classes': FGVCAIRCRAFT_CATEGORIES}

    def __init__(self, data_root: str, split: str = 'trainval', **kwargs):

        splits = ['train', 'val', 'trainval', 'test']
        assert split in splits, \
            f"The split must be one of {splits}, but got '{split}'"
        self.split = split

        self.backend = get_file_backend(data_root, enable_singleton=True)
        ann_file = self.backend.join_path('data',
                                          f'images_variant_{split}.txt')
        data_prefix = self.backend.join_path('data', 'images')
        test_mode = split == 'test'

        super().__init__(
            ann_file=ann_file,
            data_root=data_root,
            test_mode=test_mode,
            data_prefix=data_prefix,
            **kwargs)

    def load_data_list(self):
        """Load images and ground truth labels."""

        pairs = list_from_file(self.ann_file)
        data_list = []
        for pair in pairs:
            pair = pair.split()
            img_name = pair[0]
            class_name = ' '.join(pair[1:])
            img_name = f'{img_name}.jpg'
            img_path = self.backend.join_path(self.img_prefix, img_name)
            gt_label = self.METAINFO['classes'].index(class_name)
            info = dict(img_path=img_path, gt_label=gt_label)
            data_list.append(info)

        return data_list

    def extra_repr(self) -> List[str]:
        """The extra repr information of the dataset."""
        body = [
            f'Root of dataset: \t{self.data_root}',
        ]
        return body
mmpretrain/datasets/flamingo.py
ADDED
@@ -0,0 +1,295 @@
# Copyright (c) OpenMMLab. All rights reserved.
import random
from abc import abstractmethod
from collections import Counter
from typing import List

import mmengine
import numpy as np
from mmengine.dataset import BaseDataset
from pycocotools.coco import COCO

from mmpretrain.registry import DATASETS
from .coco_vqa import COCOVQA


class FlamingoFewShotMixin:
    """Flamingo few-shot evaluation dataset mixin.

    Args:
        num_shots (int): Number of shots to perform evaluation.
            Defaults to 0.
            Note: 0 does not mean a strict zero-shot in the Flamingo setting.
            It will use two text-only prompts without in-context images.
        num_support_examples (int): Number of support examples to get the
            few shots from. Defaults to 2048.
        num_query_examples (int): Number of query examples to perform the
            final evaluation. Defaults to 5000.
        incontext_prompt_temp (str): In-context prompt template for few-shot
            examples. Defaults to ''.
        final_prompt_temp (str): Final query prompt template. Defaults to ''.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self,
                 num_shots: int = 0,
                 num_support_examples: int = 2048,
                 num_query_examples: int = 5000,
                 incontext_prompt_temp: str = '',
                 final_prompt_temp: str = '',
                 **kwarg):
        self.num_shots = num_shots
        self.num_support_examples = num_support_examples
        self.num_query_examples = num_query_examples
        self.incontext_prompt_temp = incontext_prompt_temp
        self.final_prompt_temp = final_prompt_temp
        super().__init__(**kwarg)

    def get_subset_idx(self, total_num):
        random_idx = np.random.choice(
            total_num,
            self.num_support_examples + self.num_query_examples,
            replace=False)

        support_idx = random_idx[:self.num_support_examples]
        query_idx = random_idx[self.num_support_examples:]
        return support_idx, query_idx

    @abstractmethod
    def parse_basic_anno(self, anno: dict) -> dict:
        """Parse basic annotation for support and query set."""
        pass

    @abstractmethod
    def parse_fewshot_anno(self, anno: dict, support_list: List) -> dict:
        """Parse few-shot annotation for query set with support list."""
        pass


@DATASETS.register_module()
class FlamingoEvalCOCOVQA(FlamingoFewShotMixin, COCOVQA):
    """Flamingo few-shot VQAv2 dataset.

    Args:
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``.
        ann_file (str): Annotation file path.
        question_file (str): Question file path.
        num_shots (int): Number of shots to perform evaluation.
            Defaults to 0.
            Note: 0 does not mean a strict zero-shot in the Flamingo setting.
            It will use two text-only prompts without in-context images.
        num_support_examples (int): Number of support examples to get the
            few shots from. Defaults to 2048.
        num_query_examples (int): Number of query examples to perform the
            final evaluation. Defaults to 5000.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self,
                 data_root: str,
                 question_file: str,
                 ann_file: str = '',
                 num_shots: int = 0,
                 num_support_examples: int = 2048,
                 num_query_examples: int = 5000,
                 **kwarg):
        super().__init__(
            data_root=data_root,
            question_file=question_file,
            ann_file=ann_file,
            num_shots=num_shots,
            num_support_examples=num_support_examples,
            num_query_examples=num_query_examples,
            **kwarg)

    def parse_basic_anno(self, ann: dict) -> dict:
        """Parse basic annotation for support and query set.

        Args:
            ann (dict): Annotation for a single example.

        Return:
            dict: Parsed annotation for the single example.
        """
        if ann is None:
            return {}

        answers = [a['answer'] for a in ann['answers']]
        count = Counter(answers)
        answer_weight = [i / len(answers) for i in count.values()]
        answer_info = {
            'gt_answer': list(count.keys()),
            'gt_answer_weight': answer_weight
        }
        return answer_info

    def parse_fewshot_anno(self, query: dict, support_list: List) -> dict:
        """Parse few-shot annotation for query set with support list.

        Args:
            query (dict): Annotation for a single query example.
            support_list (List): List of the support subset to subsample few
                shots from.

        Return:
            dict: Parsed annotation for the single example.
        """
        # prepare n-shot examples
        shots = random.sample(support_list, self.num_shots)

        # append image paths for the n shots
        img_path = [shot['img_path'] for shot in shots]
        img_path.append(query['img_path'])
        query['img_path'] = img_path

        query['shots'] = [
            dict(
                question=item['question'],
                answer=item['gt_answer'][0],
            ) for item in shots
        ]
        return query

    def load_data_list(self) -> List[dict]:
        """Load data list."""
        questions = mmengine.load(self.question_file)['questions']
        if self.ann_file:
            annotations = mmengine.load(self.ann_file)['annotations']
            assert len(questions) == len(annotations)
        else:
            annotations = [None] * len(questions)
            if self.num_shots > 0:
                raise ValueError('Unable to construct few-shot examples '
                                 'since there is no annotation file.')

        # The original VQAv2 annotation file and question file include
        # only image ids but no image file paths.
        self.image_index = self._create_image_index()

        num_data = len(questions)
        support_idx, query_idx = self.get_subset_idx(num_data)

        # prepare support subset
        if self.num_shots > 0:
            support_list = []
            for idx in support_idx:
                question = questions[idx]
                ann = annotations[idx]
                support = {**question, **self.parse_basic_anno(ann)}
                support['img_path'] = self.image_index[question['image_id']]
                support_list.append(support)

        # prepare query subset
        data_list = []
        for idx in query_idx:
            question = questions[idx]
            ann = annotations[idx]
            data_info = {**question, **self.parse_basic_anno(ann)}
            data_info['img_path'] = self.image_index[question['image_id']]
            if self.num_shots > 0:
                data_info = self.parse_fewshot_anno(data_info, support_list)
            data_list.append(data_info)

        return data_list


@DATASETS.register_module()
class FlamingoEvalCOCOCaption(FlamingoFewShotMixin, BaseDataset):
    """Flamingo few-shot COCO Caption dataset.

    Args:
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``.
        ann_file (str): Annotation file path.
        data_prefix (dict): Prefix for data field. Defaults to
            ``dict(img_path='')``.
        num_shots (int): Number of shots to perform evaluation.
            Defaults to 0.
        num_support_examples (int): Number of support examples to get the
            few shots from. Defaults to 2048.
        num_query_examples (int): Number of query examples to perform the
            final evaluation. Defaults to 5000.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self,
                 data_root: str,
                 ann_file: str,
218 |
+
num_shots: int = 0,
|
219 |
+
num_support_examples: int = 2048,
|
220 |
+
num_query_examples: int = 5000,
|
221 |
+
**kwarg):
|
222 |
+
super().__init__(
|
223 |
+
data_root=data_root,
|
224 |
+
ann_file=ann_file,
|
225 |
+
num_shots=num_shots,
|
226 |
+
num_support_examples=num_support_examples,
|
227 |
+
num_query_examples=num_query_examples,
|
228 |
+
**kwarg)
|
229 |
+
|
230 |
+
def parse_basic_anno(self, ann: dict, coco: COCO) -> dict:
|
231 |
+
"""Parse basic annotation for support and query set.
|
232 |
+
|
233 |
+
Args:
|
234 |
+
anno (dict): Annotation for single example.
|
235 |
+
coco (COCO): The coco dataset.
|
236 |
+
|
237 |
+
Return:
|
238 |
+
dict: Parsed annotation for single example.
|
239 |
+
"""
|
240 |
+
img_prefix = self.data_prefix['img_path']
|
241 |
+
img = coco.imgs[ann['image_id']]
|
242 |
+
data_info = dict(
|
243 |
+
img_path=mmengine.join_path(img_prefix, img['file_name']),
|
244 |
+
gt_caption=ann['caption'],
|
245 |
+
image_id=ann['image_id'],
|
246 |
+
)
|
247 |
+
return data_info
|
248 |
+
|
249 |
+
def parse_fewshot_anno(self, query: dict, support_list: List) -> dict:
|
250 |
+
"""Parse fewshot related annotation for query set with support list.
|
251 |
+
|
252 |
+
Args:
|
253 |
+
query (dict): Annotation for single example.
|
254 |
+
support_list (List): List of support subset to subsample few shots.
|
255 |
+
coco (COCO): The coco dataset.
|
256 |
+
|
257 |
+
Return:
|
258 |
+
dict: Parsed annotation for single example.
|
259 |
+
"""
|
260 |
+
# prepare n shots examples
|
261 |
+
shots = random.sample(support_list, self.num_shots)
|
262 |
+
|
263 |
+
# append image path for n shots
|
264 |
+
img_path = [shot['img_path'] for shot in shots]
|
265 |
+
img_path.append(query['img_path'])
|
266 |
+
query['img_path'] = img_path
|
267 |
+
|
268 |
+
query['shots'] = [dict(caption=item['gt_caption']) for item in shots]
|
269 |
+
return query
|
270 |
+
|
271 |
+
def load_data_list(self) -> List[dict]:
|
272 |
+
"""Load data list."""
|
273 |
+
with mmengine.get_local_path(self.ann_file) as ann_file:
|
274 |
+
coco = COCO(ann_file)
|
275 |
+
|
276 |
+
num_data = len(coco.anns)
|
277 |
+
support_idx, query_idx = self.get_subset_idx(num_data)
|
278 |
+
ann_ids = list(coco.anns)
|
279 |
+
|
280 |
+
# prepare support subset
|
281 |
+
if self.num_shots > 0:
|
282 |
+
support_list = []
|
283 |
+
for idx in support_idx:
|
284 |
+
support = self.parse_basic_anno(coco.anns[ann_ids[idx]], coco)
|
285 |
+
support_list.append(support)
|
286 |
+
|
287 |
+
# prepare query subset
|
288 |
+
query_list = []
|
289 |
+
for idx in query_idx:
|
290 |
+
data_info = self.parse_basic_anno(coco.anns[ann_ids[idx]], coco)
|
291 |
+
if self.num_shots > 0:
|
292 |
+
data_info = self.parse_fewshot_anno(data_info, support_list)
|
293 |
+
query_list.append(data_info)
|
294 |
+
|
295 |
+
return query_list
|
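
To make the few-shot structure concrete, here is a minimal sketch (not part of the file above) of how a query dict produced by `parse_fewshot_anno` could be rendered into an interleaved prompt. The `<image>` placeholder and the Question/Answer wording are illustrative assumptions, not the templates used by the model configs.

# Hypothetical helper: turn a parsed FlamingoEvalCOCOVQA query dict into a
# Flamingo-style few-shot prompt. The '<image>' token and wording below are
# assumptions for illustration only.
def build_vqa_prompt(query: dict) -> str:
    segments = []
    for shot in query.get('shots', []):
        segments.append(
            f"<image>Question: {shot['question']} Answer: {shot['answer']}")
    # the query itself comes last, with the answer left open
    segments.append(f"<image>Question: {query['question']} Answer:")
    return ' '.join(segments)

# With num_shots=2, query['img_path'] holds 3 paths (2 shots + the query
# image), matching the 3 '<image>' placeholders emitted above.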
mmpretrain/datasets/flowers102.py
ADDED
@@ -0,0 +1,104 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import mat4py
from mmengine import get_file_backend

from mmpretrain.registry import DATASETS
from .base_dataset import BaseDataset


@DATASETS.register_module()
class Flowers102(BaseDataset):
    """The Oxford 102 Flower Dataset.

    Support the `Oxford 102 Flowers Dataset <https://www.robots.ox.ac.uk/~vgg/data/flowers/102/>`_ Dataset.
    After downloading and decompression, the dataset directory structure is as follows.

    Flowers102 dataset directory: ::

        Flowers102
        ├── jpg
        │   ├── image_00001.jpg
        │   ├── image_00002.jpg
        │   └── ...
        ├── imagelabels.mat
        ├── setid.mat
        └── ...

    Args:
        data_root (str): The root directory for the Oxford 102 Flowers dataset.
        split (str, optional): The dataset split, supports "train",
            "val", "trainval", and "test". Defaults to "trainval".

    Examples:
        >>> from mmpretrain.datasets import Flowers102
        >>> train_dataset = Flowers102(data_root='data/Flowers102', split='trainval')
        >>> train_dataset
        Dataset Flowers102
            Number of samples:  2040
            Root of dataset:    data/Flowers102
        >>> test_dataset = Flowers102(data_root='data/Flowers102', split='test')
        >>> test_dataset
        Dataset Flowers102
            Number of samples:  6149
            Root of dataset:    data/Flowers102
    """  # noqa: E501

    def __init__(self, data_root: str, split: str = 'trainval', **kwargs):
        splits = ['train', 'val', 'trainval', 'test']
        assert split in splits, \
            f"The split must be one of {splits}, but got '{split}'"
        self.split = split

        ann_file = 'imagelabels.mat'
        data_prefix = 'jpg'
        train_test_split_file = 'setid.mat'
        test_mode = split == 'test'

        self.backend = get_file_backend(data_root, enable_singleton=True)

        self.train_test_split_file = self.backend.join_path(
            data_root, train_test_split_file)

        super(Flowers102, self).__init__(
            ann_file=ann_file,
            data_root=data_root,
            data_prefix=data_prefix,
            test_mode=test_mode,
            **kwargs)

    def load_data_list(self):
        """Load images and ground truth labels."""

        label_dict = mat4py.loadmat(self.ann_file)['labels']
        split_list = mat4py.loadmat(self.train_test_split_file)

        if self.split == 'train':
            split_list = split_list['trnid']
        elif self.split == 'val':
            split_list = split_list['valid']
        elif self.split == 'test':
            split_list = split_list['tstid']
        else:
            train_ids = split_list['trnid']
            val_ids = split_list['valid']
            train_ids.extend(val_ids)
            split_list = train_ids

        data_list = []
        for sample_id in split_list:
            img_name = 'image_%05d.jpg' % (sample_id)
            img_path = self.backend.join_path(self.img_prefix, img_name)
            gt_label = int(label_dict[sample_id - 1]) - 1
            info = dict(img_path=img_path, gt_label=gt_label)
            data_list.append(info)

        return data_list

    def extra_repr(self) -> List[str]:
        """The extra repr information of the dataset."""
        body = [
            f'Root of dataset: \t{self.data_root}',
        ]
        return body
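
The index arithmetic in `Flowers102.load_data_list` is worth spelling out, since both `setid.mat` and `imagelabels.mat` use MATLAB's 1-based indexing. A short sketch, using a toy label list in place of the real .mat contents:

# Sketch of the 1-based -> 0-based conversion above; `label_dict` stands in
# for mat4py.loadmat(self.ann_file)['labels'] and is made up here.
label_dict = [1, 1, 2, 102]   # 1-based class labels, one per sample
sample_id = 3                 # a 1-based sample id from setid.mat

img_name = 'image_%05d.jpg' % sample_id        # -> 'image_00003.jpg'
gt_label = int(label_dict[sample_id - 1]) - 1  # shift both id and label
assert (img_name, gt_label) == ('image_00003.jpg', 1)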
mmpretrain/datasets/food101.py
ADDED
@@ -0,0 +1,102 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

from mmengine import get_file_backend, list_from_file

from mmpretrain.registry import DATASETS
from .base_dataset import BaseDataset
from .categories import FOOD101_CATEGORIES


@DATASETS.register_module()
class Food101(BaseDataset):
    """The Food101 Dataset.

    Support the `Food101 Dataset <https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/>`_ Dataset.
    After downloading and decompression, the dataset directory structure is as follows.

    Food101 dataset directory: ::

        food-101
        ├── images
        │   ├── class_x
        │   │   ├── xx1.jpg
        │   │   ├── xx2.jpg
        │   │   └── ...
        │   ├── class_y
        │   │   ├── yy1.jpg
        │   │   ├── yy2.jpg
        │   │   └── ...
        │   └── ...
        ├── meta
        │   ├── train.txt
        │   └── test.txt
        └── ....

    Args:
        data_root (str): The root directory for the Food101 dataset.
        split (str, optional): The dataset split, supports "train" and "test".
            Defaults to "train".

    Examples:
        >>> from mmpretrain.datasets import Food101
        >>> train_dataset = Food101(data_root='data/food-101', split='train')
        >>> train_dataset
        Dataset Food101
            Number of samples:  75750
            Number of categories:       101
            Root of dataset:    data/food-101
        >>> test_dataset = Food101(data_root='data/food-101', split='test')
        >>> test_dataset
        Dataset Food101
            Number of samples:  25250
            Number of categories:       101
            Root of dataset:    data/food-101
    """  # noqa: E501

    METAINFO = {'classes': FOOD101_CATEGORIES}

    def __init__(self, data_root: str, split: str = 'train', **kwargs):

        splits = ['train', 'test']
        assert split in splits, \
            f"The split must be one of {splits}, but got '{split}'"
        self.split = split

        self.backend = get_file_backend(data_root, enable_singleton=True)
        if split == 'train':
            ann_file = self.backend.join_path('meta', 'train.txt')
        else:
            ann_file = self.backend.join_path('meta', 'test.txt')

        test_mode = split == 'test'
        data_prefix = 'images'

        super(Food101, self).__init__(
            ann_file=ann_file,
            data_root=data_root,
            test_mode=test_mode,
            data_prefix=data_prefix,
            **kwargs)

    def load_data_list(self):
        """Load images and ground truth labels."""

        pairs = list_from_file(self.ann_file)
        data_list = []
        for pair in pairs:
            class_name, img_name = pair.split('/')
            img_name = f'{img_name}.jpg'
            img_path = self.backend.join_path(self.img_prefix, class_name,
                                              img_name)
            gt_label = self.METAINFO['classes'].index(class_name)
            info = dict(img_path=img_path, gt_label=gt_label)
            data_list.append(info)
        return data_list

    def extra_repr(self) -> List[str]:
        """The extra repr information of the dataset."""
        body = [
            f'Root of dataset: \t{self.data_root}',
        ]
        return body
mmpretrain/datasets/imagenet.py
ADDED
@@ -0,0 +1,102 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Union

from mmengine.logging import MMLogger

from mmpretrain.registry import DATASETS
from .categories import IMAGENET_CATEGORIES
from .custom import CustomDataset


@DATASETS.register_module()
class ImageNet(CustomDataset):
    """`ImageNet <http://www.image-net.org>`_ Dataset.

    The dataset supports two kinds of annotation format. More details can be
    found in :class:`CustomDataset`.

    Args:
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``. Defaults to ''.
        data_prefix (str | dict): Prefix for training data. Defaults to ''.
        ann_file (str): Annotation file path. Defaults to ''.
        metainfo (dict, optional): Meta information for dataset, such as class
            information. Defaults to None.
        **kwargs: Other keyword arguments in :class:`CustomDataset` and
            :class:`BaseDataset`.
    """  # noqa: E501

    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif')
    METAINFO = {'classes': IMAGENET_CATEGORIES}

    def __init__(self,
                 data_root: str = '',
                 data_prefix: Union[str, dict] = '',
                 ann_file: str = '',
                 metainfo: Optional[dict] = None,
                 **kwargs):
        kwargs = {'extensions': self.IMG_EXTENSIONS, **kwargs}
        super().__init__(
            data_root=data_root,
            data_prefix=data_prefix,
            ann_file=ann_file,
            metainfo=metainfo,
            **kwargs)


@DATASETS.register_module()
class ImageNet21k(CustomDataset):
    """ImageNet21k Dataset.

    Since the ImageNet21k dataset is extremely big, containing 21k+ classes
    and 1.4B files, we won't provide a default categories list. Please
    specify it via the ``classes`` argument.

    Args:
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``. Defaults to ''.
        data_prefix (str | dict): Prefix for training data. Defaults to ''.
        ann_file (str): Annotation file path. Defaults to ''.
        metainfo (dict, optional): Meta information for dataset, such as class
            information. Defaults to None.
        multi_label (bool): Use multi-label or not. Not implemented yet.
            Defaults to False.
        **kwargs: Other keyword arguments in :class:`CustomDataset` and
            :class:`BaseDataset`.
    """

    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif')

    def __init__(self,
                 data_root: str = '',
                 data_prefix: Union[str, dict] = '',
                 ann_file: str = '',
                 metainfo: Optional[dict] = None,
                 multi_label: bool = False,
                 **kwargs):
        if multi_label:
            raise NotImplementedError(
                'The `multi_label` option is not supported by now.')
        self.multi_label = multi_label

        logger = MMLogger.get_current_instance()

        if not ann_file:
            logger.warning(
                'The ImageNet21k dataset is large, and scanning the directory '
                'may take a long time. Consider specifying the `ann_file` to '
                'accelerate initialization.')

        kwargs = {'extensions': self.IMG_EXTENSIONS, **kwargs}
        super().__init__(
            data_root=data_root,
            data_prefix=data_prefix,
            ann_file=ann_file,
            metainfo=metainfo,
            **kwargs)

        if self.CLASSES is None:
            logger.warning(
                'The CLASSES is not stored in the `ImageNet21k` class. '
                'Consider specifying the `classes` argument if you need to '
                'do inference on the ImageNet-21k dataset.')
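
Since `ImageNet21k` ships no default class list, a config has to provide one. A minimal sketch follows; the paths and the classes file are placeholder assumptions, not files shipped with the repository (`classes` is the argument documented for the base dataset classes above).

# Hypothetical ImageNet21k dataset config; all paths below are assumed.
train_dataset = dict(
    type='ImageNet21k',
    data_root='data/imagenet21k',   # assumed local layout
    data_prefix='train',
    ann_file='meta/train.txt',      # avoids the slow directory scan
    classes='meta/classes.txt',     # one class name per line
    pipeline=[],
)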
mmpretrain/datasets/inshop.py
ADDED
@@ -0,0 +1,157 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine import get_file_backend, list_from_file

from mmpretrain.registry import DATASETS
from .base_dataset import BaseDataset


@DATASETS.register_module()
class InShop(BaseDataset):
    """InShop Dataset for Image Retrieval.

    Please download the images from the homepage
    'https://mmlab.ie.cuhk.edu.hk/projects/DeepFashion/InShopRetrieval.html'
    (In-shop Clothes Retrieval Benchmark -> Img -> img.zip,
    Eval/list_eval_partition.txt), and organize them as follows: ::

        In-shop Clothes Retrieval Benchmark (data_root)/
        ├── Eval /
        │   └── list_eval_partition.txt (ann_file)
        ├── Img (img_prefix)
        │   └── img/
        ├── README.txt
        └── .....

    Args:
        data_root (str): The root directory for the dataset.
        split (str): Choose from 'train', 'query' and 'gallery'.
            Defaults to 'train'.
        data_prefix (str | dict): Prefix for training data.
            Defaults to 'Img'.
        ann_file (str): Annotation file path, relative to
            ``data_root``. Defaults to 'Eval/list_eval_partition.txt'.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.

    Examples:
        >>> from mmpretrain.datasets import InShop
        >>>
        >>> # build train InShop dataset
        >>> inshop_train_cfg = dict(data_root='data/inshop', split='train')
        >>> inshop_train = InShop(**inshop_train_cfg)
        >>> inshop_train
        Dataset InShop
            Number of samples:  25882
            The `CLASSES` meta info is not set.
            Root of dataset:    data/inshop
        >>>
        >>> # build query InShop dataset
        >>> inshop_query_cfg = dict(data_root='data/inshop', split='query')
        >>> inshop_query = InShop(**inshop_query_cfg)
        >>> inshop_query
        Dataset InShop
            Number of samples:  14218
            The `CLASSES` meta info is not set.
            Root of dataset:    data/inshop
        >>>
        >>> # build gallery InShop dataset
        >>> inshop_gallery_cfg = dict(data_root='data/inshop', split='gallery')
        >>> inshop_gallery = InShop(**inshop_gallery_cfg)
        >>> inshop_gallery
        Dataset InShop
            Number of samples:  12612
            The `CLASSES` meta info is not set.
            Root of dataset:    data/inshop
    """

    def __init__(self,
                 data_root: str,
                 split: str = 'train',
                 data_prefix: str = 'Img',
                 ann_file: str = 'Eval/list_eval_partition.txt',
                 **kwargs):

        assert split in ('train', 'query', 'gallery'), "'split' of `InShop`" \
            f" must be one of ['train', 'query', 'gallery'], but got '{split}'"
        self.backend = get_file_backend(data_root, enable_singleton=True)
        self.split = split
        super().__init__(
            data_root=data_root,
            data_prefix=data_prefix,
            ann_file=ann_file,
            **kwargs)

    def _process_annotations(self):
        lines = list_from_file(self.ann_file)

        anno_train = dict(metainfo=dict(), data_list=list())
        anno_gallery = dict(metainfo=dict(), data_list=list())

        # item_id to label, each item corresponds to one class label
        class_num = 0
        gt_label_train = {}

        # item_id to label, each label corresponds to several items
        gallery_num = 0
        gt_label_gallery = {}

        # (lines[0], lines[1]) are the image count and the field names;
        # each following line is formatted as
        # 'image_name, item_id, evaluation_status'
        for line in lines[2:]:
            img_name, item_id, status = line.split()
            img_path = self.backend.join_path(self.img_prefix, img_name)
            if status == 'train':
                if item_id not in gt_label_train:
                    gt_label_train[item_id] = class_num
                    class_num += 1
                # item_id to class_id (for the training set)
                anno_train['data_list'].append(
                    dict(img_path=img_path, gt_label=gt_label_train[item_id]))
            elif status == 'gallery':
                if item_id not in gt_label_gallery:
                    gt_label_gallery[item_id] = []
                # Since there are multiple images for each item,
                # record the corresponding item for each image.
                gt_label_gallery[item_id].append(gallery_num)
                anno_gallery['data_list'].append(
                    dict(img_path=img_path, sample_idx=gallery_num))
                gallery_num += 1

        if self.split == 'train':
            anno_train['metainfo']['class_number'] = class_num
            anno_train['metainfo']['sample_number'] = \
                len(anno_train['data_list'])
            return anno_train
        elif self.split == 'gallery':
            anno_gallery['metainfo']['sample_number'] = gallery_num
            return anno_gallery

        # Generate the label for the query (val) set
        anno_query = dict(metainfo=dict(), data_list=list())
        query_num = 0
        for line in lines[2:]:
            img_name, item_id, status = line.split()
            img_path = self.backend.join_path(self.img_prefix, img_name)
            if status == 'query':
                anno_query['data_list'].append(
                    dict(
                        img_path=img_path,
                        gt_label=gt_label_gallery[item_id]))
                query_num += 1

        anno_query['metainfo']['sample_number'] = query_num
        return anno_query

    def load_data_list(self):
        """Load data list.

        For the train set, return image and ground truth label. For the query
        set, return image and ids of images in gallery. For the gallery set,
        return image and its id.
        """
        data_info = self._process_annotations()
        data_list = data_info['data_list']
        return data_list

    def extra_repr(self):
        """The extra repr information of the dataset."""
        body = [f'Root of dataset: \t{self.data_root}']
        return body
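
For reference, `_process_annotations` expects `list_eval_partition.txt` to look roughly like the sketch below; the image count and file names here are made up. The first two lines are the image count and the column header, which is why parsing starts at `lines[2:]`.

# Illustrative (made-up) lines from Eval/list_eval_partition.txt and the
# parsing step used above.
lines = [
    '52712',
    'image_name item_id evaluation_status',
    'img/WOMEN/Dresses/id_00000002/02_1_front.jpg id_00000002 train',
    'img/WOMEN/Dresses/id_00000002/02_2_side.jpg id_00000002 query',
    'img/WOMEN/Dresses/id_00000002/02_4_full.jpg id_00000002 gallery',
]
for line in lines[2:]:
    img_name, item_id, status = line.split()
    print(status, item_id, img_name)  # routes each image to its subset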
mmpretrain/datasets/mnist.py
ADDED
@@ -0,0 +1,220 @@
# Copyright (c) OpenMMLab. All rights reserved.
import codecs
from typing import List, Optional
from urllib.parse import urljoin

import mmengine.dist as dist
import numpy as np
import torch
from mmengine.fileio import LocalBackend, exists, get_file_backend, join_path

from mmpretrain.registry import DATASETS
from .base_dataset import BaseDataset
from .categories import FASHIONMNIST_CATEGORITES, MNIST_CATEGORITES
from .utils import (download_and_extract_archive, open_maybe_compressed_file,
                    rm_suffix)


@DATASETS.register_module()
class MNIST(BaseDataset):
    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.

    This implementation is modified from
    https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py

    Args:
        data_prefix (str): Prefix for data.
        test_mode (bool): ``test_mode=True`` means in test phase.
            It determines whether to use the training set or the test set.
        metainfo (dict, optional): Meta information for dataset, such as
            categories information. Defaults to None.
        data_root (str): The root directory for ``data_prefix``.
            Defaults to ''.
        download (bool): Whether to download the dataset if it does not exist.
            Defaults to True.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """  # noqa: E501

    url_prefix = 'http://yann.lecun.com/exdb/mnist/'
    # train images and labels
    train_list = [
        ['train-images-idx3-ubyte.gz', 'f68b3c2dcbeaaa9fbdd348bbdeb94873'],
        ['train-labels-idx1-ubyte.gz', 'd53e105ee54ea40749a09fcbcd1e9432'],
    ]
    # test images and labels
    test_list = [
        ['t10k-images-idx3-ubyte.gz', '9fb629c4189551a2d022fa330f9573f3'],
        ['t10k-labels-idx1-ubyte.gz', 'ec29112dd5afa0611ce80d1b7f02629c'],
    ]
    METAINFO = {'classes': MNIST_CATEGORITES}

    def __init__(self,
                 data_prefix: str,
                 test_mode: bool,
                 metainfo: Optional[dict] = None,
                 data_root: str = '',
                 download: bool = True,
                 **kwargs):
        self.download = download
        super().__init__(
            # The MNIST dataset doesn't need to specify an annotation file
            ann_file='',
            metainfo=metainfo,
            data_root=data_root,
            data_prefix=dict(root=data_prefix),
            test_mode=test_mode,
            **kwargs)

    def load_data_list(self):
        """Load images and ground truth labels."""
        root = self.data_prefix['root']
        backend = get_file_backend(root, enable_singleton=True)

        if dist.is_main_process() and not self._check_exists():
            if not isinstance(backend, LocalBackend):
                raise RuntimeError(f'The dataset on {root} is not integrated, '
                                   f'please manually handle it.')

            if self.download:
                self._download()
            else:
                raise RuntimeError(
                    f'Cannot find {self.__class__.__name__} dataset in '
                    f"{self.data_prefix['root']}, you can specify "
                    '`download=True` to download automatically.')

        dist.barrier()
        assert self._check_exists(), \
            'Download failed or shared storage is unavailable. Please ' \
            f'download the dataset manually through {self.url_prefix}.'

        if not self.test_mode:
            file_list = self.train_list
        else:
            file_list = self.test_list

        # load data from SN3 files
        imgs = read_image_file(join_path(root, rm_suffix(file_list[0][0])))
        gt_labels = read_label_file(
            join_path(root, rm_suffix(file_list[1][0])))

        data_infos = []
        for img, gt_label in zip(imgs, gt_labels):
            gt_label = np.array(gt_label, dtype=np.int64)
            info = {'img': img.numpy(), 'gt_label': gt_label}
            data_infos.append(info)
        return data_infos

    def _check_exists(self):
        """Check the existence of the data files."""
        root = self.data_prefix['root']

        for filename, _ in (self.train_list + self.test_list):
            # get extracted filename of data
            extract_filename = rm_suffix(filename)
            fpath = join_path(root, extract_filename)
            if not exists(fpath):
                return False
        return True

    def _download(self):
        """Download and extract data files."""
        root = self.data_prefix['root']

        for filename, md5 in (self.train_list + self.test_list):
            url = urljoin(self.url_prefix, filename)
            download_and_extract_archive(
                url, download_root=root, filename=filename, md5=md5)

    def extra_repr(self) -> List[str]:
        """The extra repr information of the dataset."""
        body = [f"Prefix of data: \t{self.data_prefix['root']}"]
        return body


@DATASETS.register_module()
class FashionMNIST(MNIST):
    """`Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_
    Dataset.

    Args:
        data_prefix (str): Prefix for data.
        test_mode (bool): ``test_mode=True`` means in test phase.
            It determines whether to use the training set or the test set.
        metainfo (dict, optional): Meta information for dataset, such as
            categories information. Defaults to None.
        data_root (str): The root directory for ``data_prefix``.
            Defaults to ''.
        download (bool): Whether to download the dataset if it does not exist.
            Defaults to True.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    url_prefix = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'
    # train images and labels
    train_list = [
        ['train-images-idx3-ubyte.gz', '8d4fb7e6c68d591d4c3dfef9ec88bf0d'],
        ['train-labels-idx1-ubyte.gz', '25c81989df183df01b3e8a0aad5dffbe'],
    ]
    # test images and labels
    test_list = [
        ['t10k-images-idx3-ubyte.gz', 'bef4ecab320f06d8554ea6380940ec79'],
        ['t10k-labels-idx1-ubyte.gz', 'bb300cfdad3c16e7a12a480ee83cd310'],
    ]
    METAINFO = {'classes': FASHIONMNIST_CATEGORITES}


def get_int(b: bytes) -> int:
    """Convert bytes to int."""
    return int(codecs.encode(b, 'hex'), 16)


def read_sn3_pascalvincent_tensor(path: str,
                                  strict: bool = True) -> torch.Tensor:
    """Read a SN3 file in "Pascal Vincent" format (Lush file
    'libidx/idx-io.lsh').

    The argument may be a filename, a compressed filename, or a file object.
    """
    # typemap
    if not hasattr(read_sn3_pascalvincent_tensor, 'typemap'):
        read_sn3_pascalvincent_tensor.typemap = {
            8: (torch.uint8, np.uint8, np.uint8),
            9: (torch.int8, np.int8, np.int8),
            11: (torch.int16, np.dtype('>i2'), 'i2'),
            12: (torch.int32, np.dtype('>i4'), 'i4'),
            13: (torch.float32, np.dtype('>f4'), 'f4'),
            14: (torch.float64, np.dtype('>f8'), 'f8')
        }
    # read
    with open_maybe_compressed_file(path) as f:
        data = f.read()
    # parse
    magic = get_int(data[0:4])
    nd = magic % 256
    ty = magic // 256
    assert nd >= 1 and nd <= 3
    assert ty >= 8 and ty <= 14
    m = read_sn3_pascalvincent_tensor.typemap[ty]
    s = [get_int(data[4 * (i + 1):4 * (i + 2)]) for i in range(nd)]
    parsed = np.frombuffer(data, dtype=m[1], offset=(4 * (nd + 1)))
    assert parsed.shape[0] == np.prod(s) or not strict
    return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


def read_label_file(path: str) -> torch.Tensor:
    """Read labels from an SN3 label file."""
    with open(path, 'rb') as f:
        x = read_sn3_pascalvincent_tensor(f, strict=False)
    assert (x.dtype == torch.uint8)
    assert (x.ndimension() == 1)
    return x.long()


def read_image_file(path: str) -> torch.Tensor:
    """Read images from an SN3 image file."""
    with open(path, 'rb') as f:
        x = read_sn3_pascalvincent_tensor(f, strict=False)
    assert (x.dtype == torch.uint8)
    assert (x.ndimension() == 3)
    return x
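
A worked example of the IDX header arithmetic in `read_sn3_pascalvincent_tensor`, using the standard magic number of an MNIST image file (0x00000803, i.e. 2051):

# The magic number encodes both the element type and the tensor rank.
magic = 2051
nd = magic % 256    # -> 3, a rank-3 tensor: (num_images, rows, cols)
ty = magic // 256   # -> 8, which the typemap above resolves to uint8
assert (nd, ty) == (3, 8)
# A label file instead has magic 2049: nd == 1 (a vector), ty == 8 (uint8).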
mmpretrain/datasets/multi_label.py
ADDED
@@ -0,0 +1,85 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

from mmpretrain.registry import DATASETS
from .base_dataset import BaseDataset


@DATASETS.register_module()
class MultiLabelDataset(BaseDataset):
    """Multi-label Dataset.

    This dataset supports annotation files in the `OpenMMLab 2.0 style
    annotation format`.

    The annotation format is shown as follows.

    .. code-block:: none

        {
            "metainfo":
            {
              "classes": ['A', 'B', 'C'....]
            },
            "data_list":
            [
              {
                "img_path": "test_img1.jpg",
                "gt_label": [0, 1],
              },
              {
                "img_path": "test_img2.jpg",
                "gt_label": [2],
              },
            ]
            ....
        }

    Args:
        ann_file (str): Annotation file path.
        metainfo (dict, optional): Meta information for dataset, such as class
            information. Defaults to None.
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``. Defaults to ''.
        data_prefix (str | dict): Prefix for training data. Defaults to ''.
        filter_cfg (dict, optional): Config for filtering data. Defaults to
            None.
        indices (int or Sequence[int], optional): Support using only the first
            few samples in the annotation file to facilitate training/testing
            on a smaller dataset. Defaults to None, which means using all
            ``data_infos``.
        serialize_data (bool, optional): Whether to hold memory using
            serialized objects. When enabled, data loader workers can use
            shared RAM from the master process instead of making a copy.
            Defaults to True.
        pipeline (list, optional): Processing pipeline. Defaults to [].
        test_mode (bool, optional): ``test_mode=True`` means in test phase.
            Defaults to False.
        lazy_init (bool, optional): Whether to load annotations during
            instantiation. In some cases, such as visualization, only the meta
            information of the dataset is needed, so loading the annotation
            file is not necessary. ``BaseDataset`` can skip loading
            annotations to save time by setting ``lazy_init=True``. Defaults
            to False.
        max_refetch (int, optional): If ``BaseDataset.prepare_data`` gets a
            None img, the maximum extra number of cycles to get a valid
            image. Defaults to 1000.
        classes (str | Sequence[str], optional): Specify names of classes.

            - If is string, it should be a file path, and every line of
              the file is the name of a class.
            - If is a sequence of strings, every item is the name of a class.
            - If is None, use categories information in the ``metainfo``
              argument, annotation file or the class attribute ``METAINFO``.

            Defaults to None.
    """

    def get_cat_ids(self, idx: int) -> List[int]:
        """Get category ids by index.

        Args:
            idx (int): Index of data.

        Returns:
            cat_ids (List[int]): Image categories of specified index.
        """
        return self.get_data_info(idx)['gt_label']
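
A small sketch (not part of the file above) of how the index-style `gt_label` in this annotation format maps to a multi-hot vector; in practice the training pipeline performs this conversion through its own label transforms.

# Illustration only: converting gt_label indices to a multi-hot vector.
import numpy as np

num_classes = 3                 # classes ['A', 'B', 'C'] as in the example
gt_label = [0, 1]               # sample labelled with both 'A' and 'B'
one_hot = np.zeros(num_classes, dtype=np.int64)
one_hot[gt_label] = 1           # -> array([1, 1, 0])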
mmpretrain/datasets/multi_task.py
ADDED
@@ -0,0 +1,337 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
from os import PathLike
from typing import Optional, Sequence

import mmengine
from mmcv.transforms import Compose
from mmengine.fileio import get_file_backend

from .builder import DATASETS


def expanduser(path):
    if isinstance(path, (str, PathLike)):
        return osp.expanduser(path)
    else:
        return path


def isabs(uri):
    return osp.isabs(uri) or ('://' in uri)


@DATASETS.register_module()
class MultiTaskDataset:
    """Custom dataset for multi-task learning.

    To use the dataset, please generate and provide an annotation file in the
    below format:

    .. code-block:: json

        {
          "metainfo": {
            "tasks": ["gender", "wear"]
          },
          "data_list": [
            {
              "img_path": "a.jpg",
              "gt_label": {
                "gender": 0,
                "wear": [1, 0, 1, 0]
              }
            },
            {
              "img_path": "b.jpg",
              "gt_label": {
                "gender": 1,
                "wear": [1, 0, 1, 0]
              }
            }
          ]
        }

    Assume we put our dataset in the ``data/mydataset`` folder in the
    repository and organize it as the below format: ::

        mmpretrain/
        └── data
            └── mydataset
                ├── annotation
                │   ├── train.json
                │   ├── test.json
                │   └── val.json
                ├── train
                │   ├── a.jpg
                │   └── ...
                ├── test
                │   ├── b.jpg
                │   └── ...
                └── val
                    ├── c.jpg
                    └── ...

    We can use the below config to build datasets:

    .. code:: python

        >>> from mmpretrain.datasets import build_dataset
        >>> train_cfg = dict(
        ...     type="MultiTaskDataset",
        ...     ann_file="annotation/train.json",
        ...     data_root="data/mydataset",
        ...     # The `img_path` field in the train annotation file is relative
        ...     # to the `train` folder.
        ...     data_prefix='train',
        ... )
        >>> train_dataset = build_dataset(train_cfg)

    Or we can put all files in the same folder: ::

        mmpretrain/
        └── data
            └── mydataset
                ├── train.json
                ├── test.json
                ├── val.json
                ├── a.jpg
                ├── b.jpg
                ├── c.jpg
                └── ...

    And we can use the below config to build datasets:

    .. code:: python

        >>> from mmpretrain.datasets import build_dataset
        >>> train_cfg = dict(
        ...     type="MultiTaskDataset",
        ...     ann_file="train.json",
        ...     data_root="data/mydataset",
        ...     # the `data_prefix` is not required since all paths are
        ...     # relative to the `data_root`.
        ... )
        >>> train_dataset = build_dataset(train_cfg)


    Args:
        ann_file (str): The annotation file path. It can be either an absolute
            path or a path relative to the ``data_root``.
        metainfo (dict, optional): The extra meta information. It should be
            a dict with the same format as the ``"metainfo"`` field in the
            annotation file. Defaults to None.
        data_root (str, optional): The root path of the data directory. It's
            the prefix of the ``data_prefix`` and the ``ann_file``. And it can
            be a remote path like "s3://openmmlab/xxx/". Defaults to None.
        data_prefix (str, optional): The base folder relative to the
            ``data_root`` for the ``"img_path"`` field in the annotation file.
            Defaults to None.
        pipeline (Sequence[dict]): A list of dicts, where each element
            represents an operation defined in
            :mod:`mmpretrain.datasets.pipelines`. Defaults to an empty tuple.
        test_mode (bool): Whether in train mode or test mode. Defaults to
            False.
    """
    METAINFO = dict()

    def __init__(self,
                 ann_file: str,
                 metainfo: Optional[dict] = None,
                 data_root: Optional[str] = None,
                 data_prefix: Optional[str] = None,
                 pipeline: Sequence = (),
                 test_mode: bool = False):

        self.data_root = expanduser(data_root)

        # Infer the file backend
        if self.data_root is not None:
            self.file_backend = get_file_backend(uri=self.data_root)
        else:
            self.file_backend = None

        self.ann_file = self._join_root(expanduser(ann_file))
        self.data_prefix = self._join_root(data_prefix)

        self.test_mode = test_mode
        self.pipeline = Compose(pipeline)
        self.data_list = self.load_data_list(self.ann_file, metainfo)

    def _join_root(self, path):
        """Join ``self.data_root`` with the specified path.

        If the path is an absolute path, just return the path. And if the
        path is None, return ``self.data_root``.

        Examples:
            >>> self.data_root = 'a/b/c'
            >>> self._join_root('d/e/')
            'a/b/c/d/e'
            >>> self._join_root('https://openmmlab.com')
            'https://openmmlab.com'
            >>> self._join_root(None)
            'a/b/c'
        """
        if path is None:
            return self.data_root
        if isabs(path):
            return path

        joined_path = self.file_backend.join_path(self.data_root, path)
        return joined_path

    @classmethod
    def _get_meta_info(cls, in_metainfo: dict = None) -> dict:
        """Collect meta information from the dictionary of meta.

        Args:
            in_metainfo (dict): Meta information dict.

        Returns:
            dict: Parsed meta information.
        """
        # `cls.METAINFO` will be overwritten by `in_metainfo`
        metainfo = copy.deepcopy(cls.METAINFO)
        if in_metainfo is None:
            return metainfo

        metainfo.update(in_metainfo)

        return metainfo

    def load_data_list(self, ann_file, metainfo_override=None):
        """Load annotations from an annotation file.

        Args:
            ann_file (str): Absolute annotation file path if
                ``self.data_root=None`` or relative path if
                ``self.data_root=/path/to/data/``.

        Returns:
            list[dict]: A list of annotations.
        """
        annotations = mmengine.load(ann_file)
        if not isinstance(annotations, dict):
            raise TypeError(f'The annotations loaded from annotation file '
                            f'should be a dict, but got {type(annotations)}!')
        if 'data_list' not in annotations:
            raise ValueError('The annotation file must have the `data_list` '
                             'field.')
        metainfo = annotations.get('metainfo', {})
        raw_data_list = annotations['data_list']

        # Set meta information.
        assert isinstance(metainfo, dict), 'The `metainfo` field in the '\
            f'annotation file should be a dict, but got {type(metainfo)}'
        if metainfo_override is not None:
            assert isinstance(metainfo_override, dict), 'The `metainfo` ' \
                f'argument should be a dict, but got {type(metainfo_override)}'
            metainfo.update(metainfo_override)
        self._metainfo = self._get_meta_info(metainfo)

        data_list = []
        for i, raw_data in enumerate(raw_data_list):
            try:
                data_list.append(self.parse_data_info(raw_data))
            except AssertionError as e:
                raise RuntimeError(
                    f'The format check failed while parsing item {i} of '
                    f'the annotation file, with error: {e}')
        return data_list

    def parse_data_info(self, raw_data):
        """Parse raw annotation to target format.

        This method will return a dict which contains the data information of
        a sample.

        Args:
            raw_data (dict): Raw data information loaded from ``ann_file``.

        Returns:
            dict: Parsed annotation.
        """
        assert isinstance(raw_data, dict), \
            f'The item should be a dict, but got {type(raw_data)}'
        assert 'img_path' in raw_data, \
            "The item doesn't have `img_path` field."
        data = dict(
            img_path=self._join_root(raw_data['img_path']),
            gt_label=raw_data['gt_label'],
        )
        return data

    @property
    def metainfo(self) -> dict:
        """Get meta information of dataset.

        Returns:
            dict: meta information collected from ``cls.METAINFO``,
            annotation file and the metainfo argument during instantiation.
        """
        return copy.deepcopy(self._metainfo)

    def prepare_data(self, idx):
        """Get data processed by ``self.pipeline``.

        Args:
            idx (int): The index of ``data_info``.

        Returns:
            Any: Depends on ``self.pipeline``.
        """
        results = copy.deepcopy(self.data_list[idx])
        return self.pipeline(results)

    def __len__(self):
        """Get the length of the whole dataset.

        Returns:
            int: The length of the filtered dataset.
        """
        return len(self.data_list)

    def __getitem__(self, idx):
        """Get the idx-th image and data information of the dataset after
        ``self.pipeline``.

        Args:
            idx (int): The index of the data.

        Returns:
            dict: The idx-th image and data information after
            ``self.pipeline``.
        """
        return self.prepare_data(idx)

    def __repr__(self):
        """Print the basic information of the dataset.

        Returns:
            str: Formatted string.
        """
        head = 'Dataset ' + self.__class__.__name__
        body = [f'Number of samples: \t{self.__len__()}']
        if self.data_root is not None:
            body.append(f'Root location: \t{self.data_root}')
        body.append(f'Annotation file: \t{self.ann_file}')
        if self.data_prefix is not None:
            body.append(f'Prefix of images: \t{self.data_prefix}')
        # -------------------- extra repr --------------------
        tasks = self.metainfo['tasks']
        body.append(f'For {len(tasks)} tasks')
        for task in tasks:
            body.append(f' {task} ')
        # ----------------------------------------------------

        if len(self.pipeline.transforms) > 0:
            body.append('With transforms:')
            for t in self.pipeline.transforms:
                body.append(f'    {t}')

        lines = [head] + [' ' * 4 + line for line in body]
        return '\n'.join(lines)
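
To make `parse_data_info` concrete, the sketch below traces one entry of the annotation example above through `_join_root`, assuming `data_root='data/mydataset'` and `data_prefix='train'`; the dict literals are illustrative, not produced output.

# One raw entry from "data_list" in the annotation example above.
raw_data = {
    'img_path': 'a.jpg',
    'gt_label': {'gender': 0, 'wear': [1, 0, 1, 0]},
}
# Roughly what parse_data_info yields for it under the assumed config:
parsed = {
    'img_path': 'data/mydataset/train/a.jpg',  # data_prefix joined by _join_root
    'gt_label': raw_data['gt_label'],          # kept as a per-task dict
}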
mmpretrain/datasets/nlvr2.py
ADDED
@@ -0,0 +1,36 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
from typing import List

from mmengine.fileio import get_file_backend, list_from_file

from mmpretrain.registry import DATASETS
from .base_dataset import BaseDataset


@DATASETS.register_module()
class NLVR2(BaseDataset):
    """NLVR2 dataset."""

    def load_data_list(self) -> List[dict]:
        """Load data list."""

        data_list = []
        img_prefix = self.data_prefix['img_path']
        file_backend = get_file_backend(img_prefix)
        examples = list_from_file(self.ann_file)

        for example in examples:
            example = json.loads(example)
            prefix = example['identifier'].rsplit('-', 1)[0]
            train_data = {}
            train_data['text'] = example['sentence']
            train_data['gt_label'] = {'True': 1, 'False': 0}[example['label']]
            train_data['img_path'] = [
                file_backend.join_path(img_prefix, prefix + f'-img{i}.png')
                for i in range(2)
            ]

            data_list.append(train_data)

        return data_list
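
The annotation file is expected to hold one JSON object per line. The sketch below (with a made-up identifier and sentence) shows how `load_data_list` derives the two image paths and the label from such a line:

# Illustrative NLVR2 annotation line; the trailing sentence index of the
# identifier is stripped and replaced with '-img0'/'-img1', as above.
import json

example = json.loads(
    '{"identifier": "test1-42-0-0", "sentence": "There are two dogs.", '
    '"label": "True"}')
prefix = example['identifier'].rsplit('-', 1)[0]    # -> 'test1-42-0'
img_paths = [prefix + f'-img{i}.png' for i in range(2)]
gt_label = {'True': 1, 'False': 0}[example['label']]  # -> 1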
mmpretrain/datasets/oxfordiiitpet.py
ADDED
@@ -0,0 +1,97 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

from mmengine import get_file_backend, list_from_file

from mmpretrain.registry import DATASETS
from .base_dataset import BaseDataset
from .categories import OxfordIIITPet_CATEGORIES


@DATASETS.register_module()
class OxfordIIITPet(BaseDataset):
    """The Oxford-IIIT Pets Dataset.

    Support the `Oxford-IIIT Pets Dataset <https://www.robots.ox.ac.uk/~vgg/data/pets/>`_ Dataset.
    After downloading and decompression, the dataset directory structure is as follows.

    Oxford-IIIT_Pets dataset directory: ::

        Oxford-IIIT_Pets
        ├── images
        │   ├── Abyssinian_1.jpg
        │   ├── Abyssinian_2.jpg
        │   └── ...
        ├── annotations
        │   ├── trainval.txt
        │   ├── test.txt
        │   ├── list.txt
        │   └── ...
        └── ....

    Args:
        data_root (str): The root directory for the Oxford-IIIT Pets dataset.
        split (str, optional): The dataset split, supports "trainval" and "test".
            Defaults to "trainval".

    Examples:
        >>> from mmpretrain.datasets import OxfordIIITPet
        >>> train_dataset = OxfordIIITPet(data_root='data/Oxford-IIIT_Pets', split='trainval')
        >>> train_dataset
        Dataset OxfordIIITPet
            Number of samples:  3680
            Number of categories:       37
            Root of dataset:    data/Oxford-IIIT_Pets
        >>> test_dataset = OxfordIIITPet(data_root='data/Oxford-IIIT_Pets', split='test')
        >>> test_dataset
        Dataset OxfordIIITPet
            Number of samples:  3669
            Number of categories:       37
            Root of dataset:    data/Oxford-IIIT_Pets
    """  # noqa: E501

    METAINFO = {'classes': OxfordIIITPet_CATEGORIES}

    def __init__(self, data_root: str, split: str = 'trainval', **kwargs):

        splits = ['trainval', 'test']
        assert split in splits, \
            f"The split must be one of {splits}, but got '{split}'"
        self.split = split

        self.backend = get_file_backend(data_root, enable_singleton=True)
        if split == 'trainval':
            ann_file = self.backend.join_path('annotations', 'trainval.txt')
        else:
            ann_file = self.backend.join_path('annotations', 'test.txt')

        data_prefix = 'images'
        test_mode = split == 'test'

        super(OxfordIIITPet, self).__init__(
            ann_file=ann_file,
            data_root=data_root,
            data_prefix=data_prefix,
            test_mode=test_mode,
            **kwargs)

    def load_data_list(self):
        """Load images and ground truth labels."""

        pairs = list_from_file(self.ann_file)
        data_list = []
        for pair in pairs:
            img_name, class_id, _, _ = pair.split()
            img_name = f'{img_name}.jpg'
            img_path = self.backend.join_path(self.img_prefix, img_name)
            gt_label = int(class_id) - 1
            info = dict(img_path=img_path, gt_label=gt_label)
            data_list.append(info)
        return data_list

    def extra_repr(self) -> List[str]:
        """The extra repr information of the dataset."""
        body = [
            f'Root of dataset: \t{self.data_root}',
        ]
        return body
mmpretrain/datasets/places205.py
ADDED
@@ -0,0 +1,40 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Union

from mmpretrain.registry import DATASETS
from .categories import PLACES205_CATEGORIES
from .custom import CustomDataset


@DATASETS.register_module()
class Places205(CustomDataset):
    """`Places205 <http://places.csail.mit.edu/downloadData.html>`_ Dataset.

    Args:
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``. Defaults to ''.
        data_prefix (str | dict): Prefix for training data. Defaults to ''.
        ann_file (str): Annotation file path. Defaults to ''.
        metainfo (dict, optional): Meta information for dataset, such as class
            information. Defaults to None.
        **kwargs: Other keyword arguments in :class:`CustomDataset` and
            :class:`BaseDataset`.
    """

    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif')
    METAINFO = {'classes': PLACES205_CATEGORIES}

    def __init__(self,
                 data_root: str = '',
                 data_prefix: Union[str, dict] = '',
                 ann_file: str = '',
                 metainfo: Optional[dict] = None,
                 **kwargs):
        kwargs = {'extensions': self.IMG_EXTENSIONS, **kwargs}
        super().__init__(
            data_root=data_root,
            data_prefix=data_prefix,
            ann_file=ann_file,
            metainfo=metainfo,
            **kwargs)
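Since Places205 is a CustomDataset subclass with the 205 category names baked into METAINFO, it can scan a folder-per-class layout without an annotation file. A minimal sketch, assuming a hypothetical data/places205/train directory whose subfolder names match the Places205 category names:

    from mmpretrain.datasets import Places205

    # Hypothetical layout: data/places205/train/<category_name>/<image>.jpg
    train_dataset = Places205(data_root='data/places205', data_prefix='train')
    print(len(train_dataset))  # number of images discovered by folder scanning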
mmpretrain/datasets/refcoco.py
ADDED
@@ -0,0 +1,81 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import List

import mmengine
import numpy as np
from mmengine.dataset import BaseDataset
from pycocotools.coco import COCO

from mmpretrain.registry import DATASETS


@DATASETS.register_module()
class RefCOCO(BaseDataset):
    """RefCOCO dataset.

    Args:
        ann_file (str): Annotation file path.
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``. Defaults to ''.
        data_prefix (str): Prefix for training data.
        split_file (str): The file recording the split of every referring
            expression.
        split (str): The split of the dataset to load. Defaults to 'train'.
        pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self,
                 data_root,
                 ann_file,
                 data_prefix,
                 split_file,
                 split='train',
                 **kwargs):
        self.split_file = split_file
        self.split = split

        super().__init__(
            data_root=data_root,
            data_prefix=dict(img_path=data_prefix),
            ann_file=ann_file,
            **kwargs,
        )

    def _join_prefix(self):
        if not mmengine.is_abs(self.split_file) and self.split_file:
            self.split_file = osp.join(self.data_root, self.split_file)

        return super()._join_prefix()

    def load_data_list(self) -> List[dict]:
        """Load data list."""
        with mmengine.get_local_path(self.ann_file) as ann_file:
            coco = COCO(ann_file)
        splits = mmengine.load(self.split_file, file_format='pkl')
        img_prefix = self.data_prefix['img_path']

        data_list = []
        join_path = mmengine.fileio.get_file_backend(img_prefix).join_path
        for refer in splits:
            if refer['split'] != self.split:
                continue

            ann = coco.anns[refer['ann_id']]
            img = coco.imgs[ann['image_id']]
            sentences = refer['sentences']
            bbox = np.array(ann['bbox'], dtype=np.float32)
            bbox[2:4] = bbox[0:2] + bbox[2:4]  # XYWH -> XYXY

            for sent in sentences:
                data_info = {
                    'img_path': join_path(img_prefix, img['file_name']),
                    'image_id': ann['image_id'],
                    'ann_id': ann['id'],
                    'text': sent['sent'],
                    'gt_bboxes': bbox[None, :],
                }
                data_list.append(data_info)

        if len(data_list) == 0:
            raise ValueError(f'No sample in split "{self.split}".')

        return data_list
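A minimal usage sketch; every path below is a hypothetical placeholder, since the actual file names depend on how the RefCOCO release was unpacked:

    from mmpretrain.datasets import RefCOCO

    val_dataset = RefCOCO(
        data_root='data/refcoco',      # hypothetical root directory
        ann_file='instances.json',     # COCO-format annotation file
        data_prefix='train2014',       # image folder
        split_file='refs(unc).p',      # pickled list of refer records
        split='val')
    print(val_dataset[0]['text'])      # one referring expression per sample

Note that each referring sentence becomes its own sample, so a single image/box pair can appear several times in the data list.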
mmpretrain/datasets/samplers/__init__.py
ADDED
@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .repeat_aug import RepeatAugSampler
from .sequential import SequentialSampler

__all__ = ['RepeatAugSampler', 'SequentialSampler']
mmpretrain/datasets/samplers/repeat_aug.py
ADDED
@@ -0,0 +1,101 @@
import math
from typing import Iterator, Optional, Sized

import torch
from mmengine.dist import get_dist_info, is_main_process, sync_random_seed
from torch.utils.data import Sampler

from mmpretrain.registry import DATA_SAMPLERS


@DATA_SAMPLERS.register_module()
class RepeatAugSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset for
    distributed training, with repeated augmentation. It ensures that each
    augmented version of a sample is visible to a different process (GPU).
    Heavily based on torch.utils.data.DistributedSampler.

    This sampler was taken from
    https://github.com/facebookresearch/deit/blob/0c4b8f60/samplers.py,
    Copyright (c) 2015-present, Facebook, Inc.

    Args:
        dataset (Sized): The dataset.
        shuffle (bool): Whether to shuffle the dataset or not. Defaults to
            True.
        num_repeats (int): The repeat times of every sample. Defaults to 3.
        seed (int, optional): Random seed used to shuffle the sampler if
            :attr:`shuffle=True`. This number should be identical across all
            processes in the distributed group. Defaults to None.
    """

    def __init__(self,
                 dataset: Sized,
                 shuffle: bool = True,
                 num_repeats: int = 3,
                 seed: Optional[int] = None):
        rank, world_size = get_dist_info()
        self.rank = rank
        self.world_size = world_size

        self.dataset = dataset
        self.shuffle = shuffle
        if not self.shuffle and is_main_process():
            from mmengine.logging import MMLogger
            logger = MMLogger.get_current_instance()
            logger.warning('The RepeatAugSampler always picks a '
                           'fixed part of data if `shuffle=False`.')

        if seed is None:
            seed = sync_random_seed()
        self.seed = seed
        self.epoch = 0
        self.num_repeats = num_repeats

        # The number of repeated samples in the rank
        self.num_samples = math.ceil(
            len(self.dataset) * num_repeats / world_size)
        # The total number of repeated samples in all ranks.
        self.total_size = self.num_samples * world_size
        # The number of selected samples in the rank
        self.num_selected_samples = math.ceil(len(self.dataset) / world_size)

    def __iter__(self) -> Iterator[int]:
        """Iterate the indices."""
        # deterministically shuffle based on epoch and seed
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        # produce repeats e.g. [0, 0, 0, 1, 1, 1, 2, 2, 2, ...]
        indices = [x for x in indices for _ in range(self.num_repeats)]
        # add extra samples to make it evenly divisible
        padding_size = self.total_size - len(indices)
        indices += indices[:padding_size]
        assert len(indices) == self.total_size

        # subsample per rank
        indices = indices[self.rank:self.total_size:self.world_size]
        assert len(indices) == self.num_samples

        # return up to num selected samples
        return iter(indices[:self.num_selected_samples])

    def __len__(self) -> int:
        """The number of samples in this rank."""
        return self.num_selected_samples

    def set_epoch(self, epoch: int) -> None:
        """Sets the epoch for this sampler.

        When :attr:`shuffle=True`, this ensures all replicas use a different
        random ordering for each epoch. Otherwise, the next iteration of this
        sampler will yield the same ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch
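A minimal sketch of wiring the sampler into a plain PyTorch DataLoader; in an mmengine config it would normally be selected with sampler=dict(type='RepeatAugSampler', num_repeats=3) instead, and the toy dataset here is just a placeholder:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from mmpretrain.datasets.samplers import RepeatAugSampler

    dataset = TensorDataset(torch.arange(16))   # toy stand-in dataset
    sampler = RepeatAugSampler(dataset, num_repeats=3, seed=0)
    loader = DataLoader(dataset, batch_size=4, sampler=sampler)

    for epoch in range(2):
        sampler.set_epoch(epoch)   # re-shuffle deterministically each epoch
        for batch in loader:
            pass

On a single process this still yields len(dataset) indices per epoch; across N processes, the repeats of each sample are spread over different ranks.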
mmpretrain/datasets/samplers/sequential.py
ADDED
@@ -0,0 +1,56 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Iterator

import torch
from mmengine.dataset import DefaultSampler

from mmpretrain.registry import DATA_SAMPLERS


@DATA_SAMPLERS.register_module()
class SequentialSampler(DefaultSampler):
    """Sequential sampler which supports different subsample policies.

    Args:
        dataset (Sized): The dataset.
        round_up (bool): Whether to add extra samples to make the number of
            samples evenly divisible by the world size. Defaults to True.
        subsample_type (str): The method to subsample data on different rank.
            Supported types:

            - ``'default'``: Original torch behavior. Sample the examples one
              by one for each GPU in turn. For instance, 8 examples on 2 GPUs:
              GPU0: [0,2,4,6], GPU1: [1,3,5,7]
            - ``'sequential'``: Subsample all examples into n chunks
              sequentially. For instance, 8 examples on 2 GPUs:
              GPU0: [0,1,2,3], GPU1: [4,5,6,7]
    """

    def __init__(self, subsample_type: str = 'default', **kwargs) -> None:
        super().__init__(shuffle=False, **kwargs)

        if subsample_type not in ['default', 'sequential']:
            raise ValueError(f'Unsupported subsample type "{subsample_type}",'
                             ' please choose from ["default", "sequential"]')
        self.subsample_type = subsample_type

    def __iter__(self) -> Iterator[int]:
        """Iterate the indices."""
        indices = torch.arange(len(self.dataset)).tolist()

        # add extra samples to make it evenly divisible
        if self.round_up:
            indices = (
                indices *
                int(self.total_size / len(indices) + 1))[:self.total_size]

        # subsample
        if self.subsample_type == 'default':
            indices = indices[self.rank:self.total_size:self.world_size]
        elif self.subsample_type == 'sequential':
            num_samples_per_rank = self.total_size // self.world_size
            indices = indices[self.rank * num_samples_per_rank:
                              (self.rank + 1) * num_samples_per_rank]

        return iter(indices)
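The two policies differ only in how the padded index list is sliced per rank; a plain-Python illustration for 8 examples and a hypothetical world size of 2, mirroring the indexing in __iter__ above:

    indices = list(range(8))
    world_size, total_size = 2, 8

    for rank in range(world_size):
        default = indices[rank:total_size:world_size]                # strided
        per_rank = total_size // world_size
        sequential = indices[rank * per_rank:(rank + 1) * per_rank]  # chunked
        print(rank, default, sequential)
    # 0 [0, 2, 4, 6] [0, 1, 2, 3]
    # 1 [1, 3, 5, 7] [4, 5, 6, 7]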
mmpretrain/datasets/scienceqa.py
ADDED
@@ -0,0 +1,104 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import Callable, List, Sequence

import mmengine
from mmengine.dataset import BaseDataset
from mmengine.fileio import get_file_backend

from mmpretrain.registry import DATASETS


@DATASETS.register_module()
class ScienceQA(BaseDataset):
    """ScienceQA dataset.

    This dataset is used to load the multimodal data of the ScienceQA dataset.

    Args:
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``.
        split (str): The split of dataset. Options: ``train``, ``val``,
            ``test``, ``trainval``, ``minival``, and ``minitest``.
        split_file (str): The split file of dataset, which contains the
            ids of data samples in the split.
        ann_file (str): Annotation file path.
        data_prefix (dict): Prefix for data field. Defaults to
            ``dict(img_path='')``.
        pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self,
                 data_root: str,
                 split: str,
                 split_file: str,
                 ann_file: str,
                 data_prefix: dict = dict(img_path=''),
                 pipeline: Sequence[Callable] = (),
                 **kwargs):

        assert split in [
            'train', 'val', 'test', 'trainval', 'minival', 'minitest'
        ], f'Invalid split {split}'
        self.split = split
        self.split_file = os.path.join(data_root, split_file)

        super().__init__(
            data_root=data_root,
            ann_file=ann_file,
            data_prefix=data_prefix,
            pipeline=pipeline,
            **kwargs)

    def load_data_list(self) -> List[dict]:
        """Load data list."""
        img_prefix = self.data_prefix['img_path']
        annotations = mmengine.load(self.ann_file)
        current_data_split = mmengine.load(self.split_file)[self.split]

        file_backend = get_file_backend(img_prefix)

        data_list = []
        for data_id in current_data_split:
            ann = annotations[data_id]
            data_info = {
                'image_id': data_id,
                'question': ann['question'],
                'choices': ann['choices'],
                'gt_answer': ann['answer'],
                'hint': ann['hint'],
                'image_name': ann['image'],
                'task': ann['task'],
                'grade': ann['grade'],
                'subject': ann['subject'],
                'topic': ann['topic'],
                'category': ann['category'],
                'skill': ann['skill'],
                'lecture': ann['lecture'],
                'solution': ann['solution'],
                'split': ann['split'],
                'img_path':
                file_backend.join_path(img_prefix, data_id, ann['image'])
                if ann['image'] is not None else None,
                'has_image': ann['image'] is not None,
            }
            data_list.append(data_info)

        return data_list
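A minimal construction sketch; the split and annotation file names below are hypothetical placeholders for whatever the local ScienceQA download uses:

    from mmpretrain.datasets import ScienceQA

    dataset = ScienceQA(
        data_root='data/scienceqa',     # hypothetical root
        split='val',
        split_file='pid_splits.json',   # split name -> list of problem ids
        ann_file='problems.json',       # problem id -> annotation dict
        data_prefix=dict(img_path='val'))
    sample = dataset[0]
    print(sample['question'], sample['choices'], sample['has_image'])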
mmpretrain/datasets/stanfordcars.py
ADDED
@@ -0,0 +1,148 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import mat4py
from mmengine import get_file_backend

from mmpretrain.registry import DATASETS
from .base_dataset import BaseDataset
from .categories import STANFORDCARS_CATEGORIES


@DATASETS.register_module()
class StanfordCars(BaseDataset):
    """The Stanford Cars Dataset.

    Support the `Stanford Cars Dataset <https://ai.stanford.edu/~jkrause/cars/car_dataset.html>`_.
    The official website provides two ways to organize the dataset.
    Therefore, after downloading and decompression, the dataset directory structure is as follows.

    Stanford Cars dataset directory: ::

        Stanford_Cars
        ├── car_ims
        │   ├── 00001.jpg
        │   ├── 00002.jpg
        │   └── ...
        └── cars_annos.mat

    or ::

        Stanford_Cars
        ├── cars_train
        │   ├── 00001.jpg
        │   ├── 00002.jpg
        │   └── ...
        ├── cars_test
        │   ├── 00001.jpg
        │   ├── 00002.jpg
        │   └── ...
        └── devkit
            ├── cars_meta.mat
            ├── cars_train_annos.mat
            ├── cars_test_annos.mat
            ├── cars_test_annos_withlabels.mat
            ├── eval_train.m
            └── train_perfect_preds.txt

    Args:
        data_root (str): The root directory for the Stanford Cars dataset.
        split (str, optional): The dataset split, supports "train"
            and "test". Defaults to "train".

    Examples:
        >>> from mmpretrain.datasets import StanfordCars
        >>> train_dataset = StanfordCars(data_root='data/Stanford_Cars', split='train')
        >>> train_dataset
        Dataset StanfordCars
            Number of samples:  8144
            Number of categories:       196
            Root of dataset:    data/Stanford_Cars
        >>> test_dataset = StanfordCars(data_root='data/Stanford_Cars', split='test')
        >>> test_dataset
        Dataset StanfordCars
            Number of samples:  8041
            Number of categories:       196
            Root of dataset:    data/Stanford_Cars
    """  # noqa: E501

    METAINFO = {'classes': STANFORDCARS_CATEGORIES}

    def __init__(self, data_root: str, split: str = 'train', **kwargs):

        splits = ['train', 'test']
        assert split in splits, \
            f"The split must be one of {splits}, but got '{split}'"
        self.split = split

        test_mode = split == 'test'
        self.backend = get_file_backend(data_root, enable_singleton=True)

        anno_file_path = self.backend.join_path(data_root, 'cars_annos.mat')
        if self.backend.exists(anno_file_path):
            ann_file = 'cars_annos.mat'
            data_prefix = ''
        else:
            if test_mode:
                ann_file = self.backend.join_path(
                    'devkit', 'cars_test_annos_withlabels.mat')
                data_prefix = 'cars_test'
            else:
                ann_file = self.backend.join_path('devkit',
                                                  'cars_train_annos.mat')
                data_prefix = 'cars_train'

            if not self.backend.exists(
                    self.backend.join_path(data_root, ann_file)):
                doc_url = 'https://mmpretrain.readthedocs.io/en/latest/api/datasets.html#stanfordcars'  # noqa: E501
                raise RuntimeError(
                    'The dataset is incorrectly organized, please refer '
                    f'to {doc_url} and reorganize your folders.')

        super(StanfordCars, self).__init__(
            ann_file=ann_file,
            data_root=data_root,
            data_prefix=data_prefix,
            test_mode=test_mode,
            **kwargs)

    def load_data_list(self):
        """Load images and ground truth labels."""
        data = mat4py.loadmat(self.ann_file)['annotations']

        data_list = []
        if 'test' in data.keys():
            # first way: a single cars_annos.mat with a per-sample test flag
            img_paths, labels, test = data['relative_im_path'], data[
                'class'], data['test']
            num = len(img_paths)
            assert num == len(labels) == len(test), 'Inconsistent ann file'
            for i in range(num):
                if not self.test_mode and test[i] == 1:
                    continue
                if self.test_mode and test[i] == 0:
                    continue
                img_path = self.backend.join_path(self.img_prefix,
                                                  img_paths[i])
                gt_label = labels[i] - 1
                info = dict(img_path=img_path, gt_label=gt_label)
                data_list.append(info)
        else:
            # second way: separate per-split annotation files in devkit
            img_names, labels = data['fname'], data['class']
            num = len(img_names)
            assert num == len(labels), 'Inconsistent ann file'
            for i in range(num):
                img_path = self.backend.join_path(self.img_prefix,
                                                  img_names[i])
                gt_label = labels[i] - 1
                info = dict(img_path=img_path, gt_label=gt_label)
                data_list.append(info)

        return data_list

    def extra_repr(self) -> List[str]:
        """The extra repr information of the dataset."""
        body = [
            f'Root of dataset: \t{self.data_root}',
        ]
        return body
mmpretrain/datasets/sun397.py
ADDED
@@ -0,0 +1,225 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

from mmengine import get_file_backend, list_from_file

from mmpretrain.registry import DATASETS
from .base_dataset import BaseDataset
from .categories import SUN397_CATEGORIES

# Note that some images are not actual JPEG files although their names end
# with jpg and therefore cannot be read properly. So we provide a list to
# skip these files.
INVALID = [
    '/a/assembly_line/sun_ajckcfldgdrdjogj.jpg',
    '/a/auto_factory/sun_apfsprenzdnzbhmt.jpg',
    '/b/baggage_claim/sun_avittiqqaiibgcau.jpg',
    '/b/batters_box/sun_alqlfpgtbgggezyr.jpg',
    '/b/bow_window/indoor/sun_ahsholsagvlrsboa.jpg',
    '/b/bow_window/indoor/sun_aioomcoujmmcxkkx.jpg',
    '/b/bow_window/outdoor/sun_atgtjdpqikjmllth.jpg',
    '/c/carrousel/sun_atsgphqympojgxnc.jpg',
    '/c/carrousel/sun_auzitjuirwolazns.jpg',
    '/c/church/outdoor/sun_boagasgfltequmal.jpg',
    '/c/church/outdoor/sun_brhmnwzzbkphcvfo.jpg',
    '/c/church/outdoor/sun_byjkqzybxpjnuofa.jpg',
    '/c/corridor/sun_aznefxvocwpgimko.jpg',
    '/d/dentists_office/sun_aaefsoauqlcsihou.jpg',
    '/d/diner/indoor/sun_apswilaujhntrybg.jpg',
    '/e/elevator/door/sun_aaudobqlphijkjdv.jpg',
    '/f/fastfood_restaurant/sun_axeniwtesffxqedr.jpg',
    '/f/fire_station/sun_bjyapttwilyyuxqm.jpg',
    '/f/fountain/sun_axgmpbdyvqhtkhee.jpg',
    '/h/hospital_room/sun_ahokhhxjiclpxqqa.jpg',
    '/o/oast_house/sun_bqsrrygxyrutgjve.jpg',
    '/r/restaurant_patio/sun_aurwypviprwycame.jpg',
    '/s/ski_resort/sun_bplmntyzoiobcqhp.jpg',
    '/w/wine_cellar/bottle_storage/sun_afmzwxkzmxkbamqi.jpg',
    '/w/wine_cellar/bottle_storage/sun_ahyymswdjejrbhyb.jpg',
    '/w/wine_cellar/bottle_storage/sun_avnttpxamufejbfe.jpg',
    '/a/archive/sun_awgsrbljlsvhqjij.jpg',
    '/a/art_school/sun_aabogqsjulyvmcse.jpg',
    '/a/art_school/sun_apnzojafyvkariue.jpg',
    '/b/ball_pit/sun_atjhwqngtoeuwhso.jpg',
    '/b/bow_window/indoor/sun_asxvsqbexmmtqmht.jpg',
    '/b/bow_window/indoor/sun_abeugxecxrwzmffp.jpg',
    '/b/bow_window/outdoor/sun_auwcqhrtzkgihvlv.jpg',
    '/b/bow_window/outdoor/sun_apnvdyecnjjmcuhi.jpg',
    '/c/childs_room/sun_alggivksjwwiklmt.jpg',
    '/c/control_tower/outdoor/sun_avbcxakrvpomqdgr.jpg',
    '/d/diner/indoor/sun_ajmzozstvsxisvgx.jpg',
    '/e/elevator/door/sun_aaqsyluqbluugqgy.jpg',
    '/f/fastfood_restaurant/sun_aevchxlxoruhxgrb.jpg',
    '/f/firing_range/indoor/sun_affrzvahwjorpalo.jpg',
    '/f/formal_garden/sun_bjvrlaeatjufekft.jpg',
    '/g/garage/indoor/sun_akbocuwclkxqlofx.jpg',
    '/g/greenhouse/indoor/sun_addirvgtxfbndlwf.jpg',
    '/k/kindergarden_classroom/sun_ajtpaahilrqzarri.jpg',
    '/l/laundromat/sun_afrrjykuhhlwiwun.jpg',
    '/m/music_studio/sun_bsntklkmwqgnjrjj.jpg',
    '/t/track/outdoor/sun_aophkoiosslinihb.jpg',
    '/a/archive/sun_aegmzltkiwyevpwa.jpg',
    '/a/auto_factory/sun_aybymzvbxgvcrwgn.jpg',
    '/b/baggage_claim/sun_atpmiqmnxjpgqsxi.jpg',
    '/b/baggage_claim/sun_ajffcdpsvgqfzoxx.jpg',
    '/b/bamboo_forest/sun_ausmxphosyahoyjo.jpg',
    '/b/batters_box/sun_aaeheulsicxtxnbu.jpg',
    '/c/carrousel/sun_arjrjcxemhttubqz.jpg',
    '/c/chicken_coop/outdoor/sun_abcegmmdbizqkpgh.jpg',
    '/c/control_tower/outdoor/sun_axhjfpkxdvqdfkyr.jpg',
    '/d/diner/indoor/sun_apaotiublwqeowck.jpg',
    '/f/fastfood_restaurant/sun_anexashcgmxdbmxq.jpg',
    '/l/landing_deck/sun_aizahnjfkuurjibw.jpg',
    '/n/nuclear_power_plant/outdoor/sun_aoblfvgyleweqanr.jpg',
    '/w/waiting_room/sun_aicytusmthfvqcwc.jpg',
    '/b/bow_window/indoor/sun_asmvdfnjlulewkpr.jpg',
    '/b/bus_interior/sun_adhktvidwzmodeou.jpg',
    '/c/catacomb/sun_algnawesgjzzmcqd.jpg',
    '/c/church/outdoor/sun_baihxlseimcsdhdx.jpg',
    '/d/diner/indoor/sun_agoyalzcawgxodbm.jpg',
    '/e/elevator_shaft/sun_awaitimkinrjaybl.jpg',
    '/f/fastfood_restaurant/sun_aplvzfbmtqtbsvbx.jpg',
    '/g/greenhouse/indoor/sun_bkccvyfpwetwjuhk.jpg',
    '/c/car_interior/backseat/sun_adexwfoqdyhowxpu.jpg',
    '/c/church/outdoor/sun_blmmweiumednscuf.jpg',
    '/f/fire_station/sun_bibntbsuunbsdrum.jpg',
    '/g/game_room/sun_aopfaqlllpvzhrak.jpg',
    '/u/underwater/coral_reef/sun_biiueajvszaxqopo.jpg',
    '/a/airplane_cabin/sun_arqyikigkyfpegug.jpg',
    '/b/badminton_court/indoor/sun_amppvxecgtjpfold.jpg',
    '/c/carrousel/sun_anxtrtieimkpmhvk.jpg',
    '/c/computer_room/sun_aebgvpgtwoqbfyvl.jpg',
    '/f/fire_escape/sun_atbraxuwwlvdoolv.jpg',
    '/k/kasbah/sun_abxkkoielpavsouu.jpg',
    '/t/tower/sun_bccqnzcvqkiwicjt.jpg',
    '/a/archive/sun_afngadshxudodkct.jpg',
    '/b/bow_window/indoor/sun_awnrlipyxpgxxgxz.jpg',
    '/c/control_tower/outdoor/sun_arohngcbtsvbthho.jpg',
    '/f/fire_station/sun_brbskkfgghbfvgkk.jpg',
    '/r/restaurant_patio/sun_amjfbqzfgxarrpec.jpg',
    '/v/vineyard/sun_bdxhnbgbnolddswz.jpg',
    '/b/baggage_claim/sun_axrtsmillrglugia.jpg',
    '/d/diner/indoor/sun_alaqevbwpjaqqdqz.jpg',
    '/l/landing_deck/sun_acodgoamhgnnbmvr.jpg',
    '/c/carrousel/sun_adsafgyrinnekycc.jpg',
    '/c/church/outdoor/sun_bzqhuwshtdgakkay.jpg',
    '/c/closet/sun_absahzamlrylkxyn.jpg',
    '/f/fire_escape/sun_acdthenaosuqcoqn.jpg',
    '/b/butchers_shop/sun_asrdgbefoszenfex.jpg',
    '/c/church/outdoor/sun_bzfyucfrdigaqneg.jpg',
    '/c/church/outdoor/sun_byzxhknqrejdajxi.jpg',
    '/c/cockpit/sun_ajkulpqauavrmxae.jpg',
    '/l/living_room/sun_aefoqbeatyufobtx.jpg',
    '/s/supermarket/sun_attvxbzocurnddbz.jpg',
    '/c/closet/sun_aqnutmwfkypmrnfy.jpg',
    '/f/fire_station/sun_bttrtzktpbymxkmf.jpg',
    '/s/shopping_mall/indoor/sun_avwzjsijaxnwuzjx.jpg',
    '/w/windmill/sun_blvczkyqbmabzeej.jpg',
    '/c/chicken_coop/outdoor/sun_amaonsnnkskxwmrj.jpg',
    '/s/swimming_pool/outdoor/sun_bslaihiqlhfewtzn.jpg',
    '/u/underwater/coral_reef/sun_bhcrnmvbgnkvcvkr.jpg',
    '/d/dining_room/sun_azlxdhiajwrhaivq.jpg',
    '/c/church/outdoor/sun_bnunxbznqnvgeykx.jpg',
    '/c/corridor/sun_aspwpqqlcwzfanvl.jpg',
    '/r/restaurant_patio/sun_awcbpizjbudjvrhs.jpg',
    '/b/ball_pit/sun_avdnmemjrgrbkwjm.jpg',
]


@DATASETS.register_module()
class SUN397(BaseDataset):
    """The SUN397 Dataset.

    Support the `SUN397 Dataset <https://vision.princeton.edu/projects/2010/SUN/>`_.
    After downloading and decompression, the dataset directory structure is as follows.

    SUN397 dataset directory: ::

        SUN397
        ├── SUN397
        │   ├── a
        │   │   ├── abbey
        │   │   │   ├── sun_aaalbzqrimafwbiv.jpg
        │   │   │   └── ...
        │   │   ├── airplane_cabin
        │   │   │   ├── sun_aadqdkqaslqqoblu.jpg
        │   │   │   └── ...
        │   │   └── ...
        │   ├── b
        │   │   └── ...
        │   ├── c
        │   │   └── ...
        │   └── ...
        └── Partitions
            ├── ClassName.txt
            ├── Training_01.txt
            ├── Testing_01.txt
            └── ...

    Args:
        data_root (str): The root directory for the SUN397 dataset.
        split (str, optional): The dataset split, supports "train" and "test".
            Defaults to "train".

    Examples:
        >>> from mmpretrain.datasets import SUN397
        >>> train_dataset = SUN397(data_root='data/SUN397', split='train')
        >>> train_dataset
        Dataset SUN397
            Number of samples:  19824
            Number of categories:       397
            Root of dataset:    data/SUN397
        >>> test_dataset = SUN397(data_root='data/SUN397', split='test')
        >>> test_dataset
        Dataset SUN397
            Number of samples:  19829
            Number of categories:       397
            Root of dataset:    data/SUN397
    """  # noqa: E501

    METAINFO = {'classes': SUN397_CATEGORIES}

    def __init__(self, data_root: str, split: str = 'train', **kwargs):

        splits = ['train', 'test']
        assert split in splits, \
            f"The split must be one of {splits}, but got '{split}'"
        self.split = split

        self.backend = get_file_backend(data_root, enable_singleton=True)
        if split == 'train':
            ann_file = self.backend.join_path('Partitions', 'Training_01.txt')
        else:
            ann_file = self.backend.join_path('Partitions', 'Testing_01.txt')

        data_prefix = 'SUN397'
        test_mode = split == 'test'

        super(SUN397, self).__init__(
            ann_file=ann_file,
            data_root=data_root,
            test_mode=test_mode,
            data_prefix=data_prefix,
            **kwargs)

    def load_data_list(self):
        """Load images and ground truth labels."""
        pairs = list_from_file(self.ann_file)
        data_list = []
        for pair in pairs:
            if pair in INVALID:
                continue
            img_path = self.backend.join_path(self.img_prefix, pair[1:])
            items = pair.split('/')
            class_name = '_'.join(items[2:-1])
            gt_label = self.METAINFO['classes'].index(class_name)
            info = dict(img_path=img_path, gt_label=gt_label)
            data_list.append(info)

        return data_list

    def extra_repr(self) -> List[str]:
        """The extra repr information of the dataset."""
        body = [
            f'Root of dataset: \t{self.data_root}',
        ]
        return body
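Each line of Training_01.txt / Testing_01.txt is a path such as /a/abbey/sun_aaalbzqrimafwbiv.jpg; a quick plain-Python illustration of the class-name parsing done in load_data_list above (file names taken from the annotation format shown in the docstring):

    for pair in ['/a/abbey/sun_aaalbzqrimafwbiv.jpg',
                 '/b/bow_window/indoor/sun_ahsholsagvlrsboa.jpg']:
        items = pair.split('/')        # leading '' then letter, class, file
        print('_'.join(items[2:-1]))
    # abbey
    # bow_window_indoor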
mmpretrain/datasets/transforms/__init__.py
ADDED
@@ -0,0 +1,36 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.transforms import (CenterCrop, LoadImageFromFile, Normalize,
                             RandomFlip, RandomGrayscale, RandomResize, Resize)

from mmpretrain.registry import TRANSFORMS
from .auto_augment import (AutoAugment, AutoContrast, BaseAugTransform,
                           Brightness, ColorTransform, Contrast, Cutout,
                           Equalize, GaussianBlur, Invert, Posterize,
                           RandAugment, Rotate, Sharpness, Shear, Solarize,
                           SolarizeAdd, Translate)
from .formatting import (Collect, NumpyToPIL, PackInputs, PackMultiTaskInputs,
                         PILToNumpy, Transpose)
from .processing import (Albumentations, BEiTMaskGenerator, CleanCaption,
                         ColorJitter, EfficientNetCenterCrop,
                         EfficientNetRandomCrop, Lighting, RandomCrop,
                         RandomErasing, RandomResizedCrop, RandomTranslatePad,
                         ResizeEdge, SimMIMMaskGenerator)
from .wrappers import ApplyToList, MultiView

for t in (CenterCrop, LoadImageFromFile, Normalize, RandomFlip,
          RandomGrayscale, RandomResize, Resize):
    TRANSFORMS.register_module(module=t)

__all__ = [
    'NumpyToPIL', 'PILToNumpy', 'Transpose', 'Collect', 'RandomCrop',
    'RandomResizedCrop', 'Shear', 'Translate', 'Rotate', 'Invert',
    'ColorTransform', 'Solarize', 'Posterize', 'AutoContrast', 'Equalize',
    'Contrast', 'Brightness', 'Sharpness', 'AutoAugment', 'SolarizeAdd',
    'Cutout', 'RandAugment', 'Lighting', 'ColorJitter', 'RandomErasing',
    'PackInputs', 'Albumentations', 'EfficientNetRandomCrop',
    'EfficientNetCenterCrop', 'ResizeEdge', 'BaseAugTransform',
    'PackMultiTaskInputs', 'GaussianBlur', 'BEiTMaskGenerator',
    'SimMIMMaskGenerator', 'CenterCrop', 'LoadImageFromFile', 'Normalize',
    'RandomFlip', 'RandomGrayscale', 'RandomResize', 'Resize', 'MultiView',
    'ApplyToList', 'CleanCaption', 'RandomTranslatePad'
]
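Because the mmcv transforms are re-registered into mmpretrain's TRANSFORMS registry in the loop above, pipeline configs can refer to them by type name. A small sketch of building two of the registered transforms directly from config dicts:

    from mmpretrain.registry import TRANSFORMS

    resize = TRANSFORMS.build(dict(type='Resize', scale=(224, 224)))
    flip = TRANSFORMS.build(dict(type='RandomFlip', prob=0.5))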