KyanChen committed
Commit 4d0eb62 · 1 Parent(s): 1c3eb47

Upload 303 files

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. .gitattributes +2 -0
  2. mmpretrain/__init__.py +28 -0
  3. mmpretrain/annotations/WHU_building_test.json +3 -0
  4. mmpretrain/annotations/WHU_building_train.json +3 -0
  5. mmpretrain/annotations/WHU_building_val.json +0 -0
  6. mmpretrain/apis/__init__.py +22 -0
  7. mmpretrain/apis/base.py +390 -0
  8. mmpretrain/apis/feature_extractor.py +128 -0
  9. mmpretrain/apis/image_caption.py +164 -0
  10. mmpretrain/apis/image_classification.py +221 -0
  11. mmpretrain/apis/image_retrieval.py +285 -0
  12. mmpretrain/apis/model.py +408 -0
  13. mmpretrain/apis/multimodal_retrieval.py +603 -0
  14. mmpretrain/apis/nlvr.py +150 -0
  15. mmpretrain/apis/utils.py +270 -0
  16. mmpretrain/apis/visual_grounding.py +180 -0
  17. mmpretrain/apis/visual_question_answering.py +181 -0
  18. mmpretrain/datasets/__init__.py +54 -0
  19. mmpretrain/datasets/base_dataset.py +219 -0
  20. mmpretrain/datasets/builder.py +25 -0
  21. mmpretrain/datasets/caltech101.py +113 -0
  22. mmpretrain/datasets/categories.py +1440 -0
  23. mmpretrain/datasets/cifar.py +210 -0
  24. mmpretrain/datasets/coco_caption.py +42 -0
  25. mmpretrain/datasets/coco_retrieval.py +77 -0
  26. mmpretrain/datasets/coco_vqa.py +114 -0
  27. mmpretrain/datasets/cub.py +142 -0
  28. mmpretrain/datasets/custom.py +287 -0
  29. mmpretrain/datasets/dataset_wrappers.py +176 -0
  30. mmpretrain/datasets/dtd.py +116 -0
  31. mmpretrain/datasets/fgvcaircraft.py +98 -0
  32. mmpretrain/datasets/flamingo.py +295 -0
  33. mmpretrain/datasets/flowers102.py +104 -0
  34. mmpretrain/datasets/food101.py +102 -0
  35. mmpretrain/datasets/imagenet.py +102 -0
  36. mmpretrain/datasets/inshop.py +157 -0
  37. mmpretrain/datasets/mnist.py +220 -0
  38. mmpretrain/datasets/multi_label.py +85 -0
  39. mmpretrain/datasets/multi_task.py +337 -0
  40. mmpretrain/datasets/nlvr2.py +36 -0
  41. mmpretrain/datasets/oxfordiiitpet.py +97 -0
  42. mmpretrain/datasets/places205.py +40 -0
  43. mmpretrain/datasets/refcoco.py +81 -0
  44. mmpretrain/datasets/samplers/__init__.py +5 -0
  45. mmpretrain/datasets/samplers/repeat_aug.py +101 -0
  46. mmpretrain/datasets/samplers/sequential.py +56 -0
  47. mmpretrain/datasets/scienceqa.py +104 -0
  48. mmpretrain/datasets/stanfordcars.py +148 -0
  49. mmpretrain/datasets/sun397.py +225 -0
  50. mmpretrain/datasets/transforms/__init__.py +36 -0
.gitattributes CHANGED
@@ -35,3 +35,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
  data/WHU/annotations/WHU_building_test.json filter=lfs diff=lfs merge=lfs -text
  data/WHU/annotations/WHU_building_train.json filter=lfs diff=lfs merge=lfs -text
+ mmpretrain/annotations/WHU_building_test.json filter=lfs diff=lfs merge=lfs -text
+ mmpretrain/annotations/WHU_building_train.json filter=lfs diff=lfs merge=lfs -text
mmpretrain/__init__.py ADDED
@@ -0,0 +1,28 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import mmcv
+ import mmengine
+ from mmengine.utils import digit_version
+
+ from .apis import *  # noqa: F401, F403
+ from .version import __version__
+
+ mmcv_minimum_version = '2.0.0rc4'
+ mmcv_maximum_version = '2.1.0'
+ mmcv_version = digit_version(mmcv.__version__)
+
+ mmengine_minimum_version = '0.7.1'
+ mmengine_maximum_version = '1.0.0'
+ mmengine_version = digit_version(mmengine.__version__)
+
+ assert (mmcv_version >= digit_version(mmcv_minimum_version)
+         and mmcv_version < digit_version(mmcv_maximum_version)), \
+     f'MMCV=={mmcv.__version__} is used but incompatible. ' \
+     f'Please install mmcv>={mmcv_minimum_version}, <{mmcv_maximum_version}.'
+
+ assert (mmengine_version >= digit_version(mmengine_minimum_version)
+         and mmengine_version < digit_version(mmengine_maximum_version)), \
+     f'MMEngine=={mmengine.__version__} is used but incompatible. ' \
+     f'Please install mmengine>={mmengine_minimum_version}, ' \
+     f'<{mmengine_maximum_version}.'
+
+ __all__ = ['__version__']
mmpretrain/annotations/WHU_building_test.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c5845dd19a3ec84aa3bc978ad5dc8066b43569c4ac9ff12c954d96208ec13432
+ size 13511169
mmpretrain/annotations/WHU_building_train.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:28c490b7c80e6900a5b4da522faee91c6251589a0c9ebb258e79221c2586d2fa
+ size 42910976
mmpretrain/annotations/WHU_building_val.json ADDED
The diff for this file is too large to render. See raw diff
 
mmpretrain/apis/__init__.py ADDED
@@ -0,0 +1,22 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ from .base import BaseInferencer
+ from .feature_extractor import FeatureExtractor
+ from .image_caption import ImageCaptionInferencer
+ from .image_classification import ImageClassificationInferencer
+ from .image_retrieval import ImageRetrievalInferencer
+ from .model import (ModelHub, get_model, inference_model, init_model,
+                     list_models)
+ from .multimodal_retrieval import (ImageToTextRetrievalInferencer,
+                                    TextToImageRetrievalInferencer)
+ from .nlvr import NLVRInferencer
+ from .visual_grounding import VisualGroundingInferencer
+ from .visual_question_answering import VisualQuestionAnsweringInferencer
+
+ __all__ = [
+     'init_model', 'inference_model', 'list_models', 'get_model', 'ModelHub',
+     'ImageClassificationInferencer', 'ImageRetrievalInferencer',
+     'FeatureExtractor', 'ImageCaptionInferencer',
+     'TextToImageRetrievalInferencer', 'VisualGroundingInferencer',
+     'VisualQuestionAnsweringInferencer', 'ImageToTextRetrievalInferencer',
+     'BaseInferencer', 'NLVRInferencer'
+ ]
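
The re-exports above define the package's public inference API. A minimal usage sketch (not part of this commit, assuming mmpretrain and its bundled metafiles are installed; the docstrings below describe `pattern` as a wildcard match):

    # Sketch of the public API re-exported above (illustrative only).
    from mmpretrain.apis import get_model, list_models

    # List pre-defined model names matching a wildcard pattern.
    print(list_models('*resnet50*'))

    # Build one of the listed models without downloading weights.
    model = get_model('resnet50_8xb32_in1k', pretrained=False)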
mmpretrain/apis/base.py ADDED
@@ -0,0 +1,390 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ from abc import abstractmethod
+ from math import ceil
+ from typing import Callable, Iterable, List, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+ from mmengine.config import Config
+ from mmengine.dataset import default_collate
+ from mmengine.fileio import get_file_backend
+ from mmengine.model import BaseModel
+ from mmengine.runner import load_checkpoint
+
+ from mmpretrain.structures import DataSample
+ from mmpretrain.utils import track
+ from .model import get_model, list_models
+
+ ModelType = Union[BaseModel, str, Config]
+ InputType = Union[str, np.ndarray, list]
+
+
+ class BaseInferencer:
+     """Base inferencer for various tasks.
+
+     The BaseInferencer provides the standard workflow for inference as follows:
+
+     1. Preprocess the input data by :meth:`preprocess`.
+     2. Forward the data to the model by :meth:`forward`. ``BaseInferencer``
+        assumes the model inherits from :class:`mmengine.models.BaseModel` and
+        will call `model.test_step` in :meth:`forward` by default.
+     3. Visualize the results by :meth:`visualize`.
+     4. Postprocess and return the results by :meth:`postprocess`.
+
+     When we call the subclasses inherited from BaseInferencer (not overriding
+     ``__call__``), the workflow will be executed in order.
+
+     All subclasses of BaseInferencer could define the following class
+     attributes for customization:
+
+     - ``preprocess_kwargs``: The keys of the kwargs that will be passed to
+       :meth:`preprocess`.
+     - ``forward_kwargs``: The keys of the kwargs that will be passed to
+       :meth:`forward`.
+     - ``visualize_kwargs``: The keys of the kwargs that will be passed to
+       :meth:`visualize`.
+     - ``postprocess_kwargs``: The keys of the kwargs that will be passed to
+       :meth:`postprocess`.
+
+     All attributes mentioned above should be a ``set`` of keys (strings),
+     and each key should not be duplicated. Actually, :meth:`__call__` will
+     dispatch all the arguments to the corresponding methods according to the
+     ``xxx_kwargs`` mentioned above.
+
+     Subclasses inherited from ``BaseInferencer`` should implement
+     :meth:`_init_pipeline`, :meth:`visualize` and :meth:`postprocess`:
+
+     - _init_pipeline: Return a callable object to preprocess the input data.
+     - visualize: Visualize the results returned by :meth:`forward`.
+     - postprocess: Postprocess the results returned by :meth:`forward` and
+       :meth:`visualize`.
+
+     Args:
+         model (BaseModel | str | Config): A model name or a path to the config
+             file, or a :obj:`BaseModel` object. The model name can be found
+             by ``cls.list_models()`` and you can also query it in
+             :doc:`/modelzoo_statistics`.
+         pretrained (str, optional): Path to the checkpoint. If None, it will
+             try to find a pre-defined weight from the model you specified
+             (only works if the ``model`` is a model name). Defaults to None.
+         device (str | torch.device | None): Transfer the model to the target
+             device. Defaults to None.
+         device_map (str | dict | None): A map that specifies where each
+             submodule should go. It doesn't need to be refined to each
+             parameter/buffer name; once a given module name is inside, every
+             submodule of it will be sent to the same device. You can use
+             `device_map="auto"` to automatically generate the device map.
+             Defaults to None.
+         offload_folder (str | None): If the `device_map` contains any value
+             `"disk"`, the folder where we will offload weights.
+         **kwargs: Other keyword arguments to initialize the model (only works
+             if the ``model`` is a model name).
+     """
+
+     preprocess_kwargs: set = set()
+     forward_kwargs: set = set()
+     visualize_kwargs: set = set()
+     postprocess_kwargs: set = set()
+
+     def __init__(self,
+                  model: ModelType,
+                  pretrained: Union[bool, str] = True,
+                  device: Union[str, torch.device, None] = None,
+                  device_map=None,
+                  offload_folder=None,
+                  **kwargs) -> None:
+
+         if isinstance(model, BaseModel):
+             if isinstance(pretrained, str):
+                 load_checkpoint(model, pretrained, map_location='cpu')
+             if device_map is not None:
+                 from .utils import dispatch_model
+                 model = dispatch_model(
+                     model,
+                     device_map=device_map,
+                     offload_folder=offload_folder)
+             elif device is not None:
+                 model.to(device)
+         else:
+             model = get_model(
+                 model,
+                 pretrained,
+                 device=device,
+                 device_map=device_map,
+                 offload_folder=offload_folder,
+                 **kwargs)
+
+         model.eval()
+
+         self.config = model._config
+         self.model = model
+         self.pipeline = self._init_pipeline(self.config)
+         self.visualizer = None
+
+     def __call__(
+         self,
+         inputs,
+         return_datasamples: bool = False,
+         batch_size: int = 1,
+         **kwargs,
+     ) -> dict:
+         """Call the inferencer.
+
+         Args:
+             inputs (InputsType): Inputs for the inferencer.
+             return_datasamples (bool): Whether to return results as
+                 :obj:`BaseDataElement`. Defaults to False.
+             batch_size (int): Batch size. Defaults to 1.
+             **kwargs: Keyword arguments passed to :meth:`preprocess`,
+                 :meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
+                 Each key in kwargs should be in the corresponding set of
+                 ``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs``
+                 and ``postprocess_kwargs``.
+
+         Returns:
+             dict: Inference and visualization results.
+         """
+         (
+             preprocess_kwargs,
+             forward_kwargs,
+             visualize_kwargs,
+             postprocess_kwargs,
+         ) = self._dispatch_kwargs(**kwargs)
+
+         ori_inputs = self._inputs_to_list(inputs)
+         inputs = self.preprocess(
+             ori_inputs, batch_size=batch_size, **preprocess_kwargs)
+         preds = []
+         for data in track(
+                 inputs, 'Inference', total=ceil(len(ori_inputs) / batch_size)):
+             preds.extend(self.forward(data, **forward_kwargs))
+         visualization = self.visualize(ori_inputs, preds, **visualize_kwargs)
+         results = self.postprocess(preds, visualization, return_datasamples,
+                                    **postprocess_kwargs)
+         return results
+
+     def _inputs_to_list(self, inputs: InputType) -> list:
+         """Preprocess the inputs to a list.
+
+         Cast the input data to a list of data.
+
+         - list or tuple: return inputs
+         - str:
+             - Directory path: return all files in the directory
+             - other cases: return a list containing the string. The string
+               could be a path to a file, a URL or another type of string
+               according to the task.
+         - other: return a list with one item.
+
+         Args:
+             inputs (str | array | list): Inputs for the inferencer.
+
+         Returns:
+             list: List of inputs for :meth:`preprocess`.
+         """
+         if isinstance(inputs, str):
+             backend = get_file_backend(inputs)
+             if hasattr(backend, 'isdir') and backend.isdir(inputs):
+                 # Backends like HttpsBackend do not implement `isdir`, so only
+                 # those backends that implement `isdir` could accept the inputs
+                 # as a directory
+                 file_list = backend.list_dir_or_file(inputs, list_dir=False)
+                 inputs = [
+                     backend.join_path(inputs, file) for file in file_list
+                 ]
+
+         if not isinstance(inputs, (list, tuple)):
+             inputs = [inputs]
+
+         return list(inputs)
+
+     def preprocess(self, inputs: InputType, batch_size: int = 1, **kwargs):
+         """Process the inputs into a model-feedable format.
+
+         Customize your preprocess by overriding this method. Preprocess should
+         return an iterable object, of which each item will be used as the
+         input of ``model.test_step``.
+
+         ``BaseInferencer.preprocess`` will return an iterable of chunked data,
+         which will be used in ``__call__`` like this:
+
+         .. code-block:: python
+
+             def __call__(self, inputs, batch_size=1, **kwargs):
+                 chunked_data = self.preprocess(inputs, batch_size, **kwargs)
+                 for batch in chunked_data:
+                     preds = self.forward(batch, **kwargs)
+
+         Args:
+             inputs (InputsType): Inputs given by the user.
+             batch_size (int): Batch size. Defaults to 1.
+
+         Yields:
+             Any: Data processed by the ``pipeline`` and ``default_collate``.
+         """
+         chunked_data = self._get_chunk_data(
+             map(self.pipeline, inputs), batch_size)
+         yield from map(default_collate, chunked_data)
+
+     @torch.no_grad()
+     def forward(self, inputs: Union[dict, tuple], **kwargs):
+         """Feed the inputs to the model."""
+         return self.model.test_step(inputs)
+
+     def visualize(self,
+                   inputs: list,
+                   preds: List[DataSample],
+                   show: bool = False,
+                   **kwargs) -> List[np.ndarray]:
+         """Visualize predictions.
+
+         Customize your visualization by overriding this method. visualize
+         should return visualization results, which could be np.ndarray or any
+         other objects.
+
+         Args:
+             inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`.
+             preds (Any): Predictions of the model.
+             show (bool): Whether to display the image in a popup window.
+                 Defaults to False.
+
+         Returns:
+             List[np.ndarray]: Visualization results.
+         """
+         if show:
+             raise NotImplementedError(
+                 f'The `visualize` method of {self.__class__.__name__} '
+                 'is not implemented.')
+
+     @abstractmethod
+     def postprocess(
+         self,
+         preds: List[DataSample],
+         visualization: List[np.ndarray],
+         return_datasample=False,
+         **kwargs,
+     ) -> dict:
+         """Process the predictions and visualization results from ``forward``
+         and ``visualize``.
+
+         This method should be responsible for the following tasks:
+
+         1. Convert datasamples into a json-serializable dict if needed.
+         2. Pack the predictions and visualization results and return them.
+         3. Dump or log the predictions.
+
+         Customize your postprocess by overriding this method. Make sure
+         ``postprocess`` will return a dict with visualization results and
+         inference results.
+
+         Args:
+             preds (List[Dict]): Predictions of the model.
+             visualization (np.ndarray): Visualized predictions.
+             return_datasample (bool): Whether to return results as
+                 datasamples. Defaults to False.
+
+         Returns:
+             dict: Inference and visualization results with key ``predictions``
+             and ``visualization``.
+
+             - ``visualization`` (Any): Returned by :meth:`visualize`.
+             - ``predictions`` (dict or DataSample): Returned by
+               :meth:`forward` and processed in :meth:`postprocess`.
+               If ``return_datasample=False``, it usually should be a
+               json-serializable dict containing only basic data elements such
+               as strings and numbers.
+         """
+
+     @abstractmethod
+     def _init_pipeline(self, cfg: Config) -> Callable:
+         """Initialize the test pipeline.
+
+         Return a pipeline to handle various input data, such as ``str`` or
+         ``np.ndarray``. It is an abstract method in BaseInferencer, and should
+         be implemented in subclasses.
+
+         The returned pipeline will be used to process a single piece of data.
+         It will be used in :meth:`preprocess` like this:
+
+         .. code-block:: python
+
+             def preprocess(self, inputs, batch_size, **kwargs):
+                 ...
+                 dataset = map(self.pipeline, dataset)
+                 ...
+         """
+
+     def _get_chunk_data(self, inputs: Iterable, chunk_size: int):
+         """Get batch data from the dataset.
+
+         Args:
+             inputs (Iterable): An iterable dataset.
+             chunk_size (int): Equivalent to batch size.
+
+         Yields:
+             list: batch data.
+         """
+         inputs_iter = iter(inputs)
+         while True:
+             try:
+                 chunk_data = []
+                 for _ in range(chunk_size):
+                     processed_data = next(inputs_iter)
+                     chunk_data.append(processed_data)
+                 yield chunk_data
+             except StopIteration:
+                 if chunk_data:
+                     yield chunk_data
+                 break
+
+     def _dispatch_kwargs(self, **kwargs) -> Tuple[dict, dict, dict, dict]:
+         """Dispatch kwargs to preprocess(), forward(), visualize() and
+         postprocess() according to the actual demands.
+
+         Returns:
+             Tuple[Dict, Dict, Dict, Dict]: kwargs passed to preprocess,
+             forward, visualize and postprocess respectively.
+         """
+         # Ensure each argument only matches one function
+         method_kwargs = self.preprocess_kwargs | self.forward_kwargs | \
+             self.visualize_kwargs | self.postprocess_kwargs
+
+         union_kwargs = method_kwargs | set(kwargs.keys())
+         if union_kwargs != method_kwargs:
+             unknown_kwargs = union_kwargs - method_kwargs
+             raise ValueError(
+                 f'unknown argument {unknown_kwargs} for `preprocess`, '
+                 '`forward`, `visualize` and `postprocess`')
+
+         preprocess_kwargs = {}
+         forward_kwargs = {}
+         visualize_kwargs = {}
+         postprocess_kwargs = {}
+
+         for key, value in kwargs.items():
+             if key in self.preprocess_kwargs:
+                 preprocess_kwargs[key] = value
+             if key in self.forward_kwargs:
+                 forward_kwargs[key] = value
+             if key in self.visualize_kwargs:
+                 visualize_kwargs[key] = value
+             if key in self.postprocess_kwargs:
+                 postprocess_kwargs[key] = value
+
+         return (
+             preprocess_kwargs,
+             forward_kwargs,
+             visualize_kwargs,
+             postprocess_kwargs,
+         )
+
+     @staticmethod
+     def list_models(pattern: Optional[str] = None):
+         """List models defined in the metafiles of the corresponding packages.
+
+         Args:
+             pattern (str | None): A wildcard pattern to match model names.
+
+         Returns:
+             List[str]: a list of model names.
+         """
+         return list_models(pattern=pattern)
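
The docstring above prescribes what a concrete inferencer must supply: `_init_pipeline` and `postprocess` are abstract, and any extra keyword is routed by `_dispatch_kwargs` through the `xxx_kwargs` class attributes. A minimal sketch of a subclass (not part of this commit; `MinimalInferencer` and `round_to` are invented for illustration, and it assumes a classification-style model whose DataSamples carry `pred_score`):

    # Illustrative sketch of the BaseInferencer contract described above.
    from typing import Callable, List

    import numpy as np
    from mmengine.config import Config
    from mmengine.dataset import Compose

    from mmpretrain.apis import BaseInferencer
    from mmpretrain.registry import TRANSFORMS
    from mmpretrain.structures import DataSample


    class MinimalInferencer(BaseInferencer):
        # Keys listed here are routed to `postprocess` by `_dispatch_kwargs`.
        postprocess_kwargs: set = {'round_to'}

        def _init_pipeline(self, cfg: Config) -> Callable:
            # Build the config's full test pipeline, wrapping a path string
            # into the dict that 'LoadImageFromFile' expects.
            pipeline = Compose([
                TRANSFORMS.build(t)
                for t in cfg.test_dataloader.dataset.pipeline
            ])
            return lambda img_path: pipeline(dict(img_path=img_path))

        def postprocess(self,
                        preds: List[DataSample],
                        visualization: List[np.ndarray],
                        return_datasample=False,
                        round_to: int = 4) -> list:
            if return_datasample:
                return preds
            return [{
                'scores': [round(float(s), round_to) for s in p.pred_score]
            } for p in preds]


    # `round_to` is dispatched to `postprocess`; an unknown keyword would
    # raise a ValueError in `_dispatch_kwargs`:
    # results = MinimalInferencer('resnet50_8xb32_in1k')('demo.jpg', round_to=2)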
mmpretrain/apis/feature_extractor.py ADDED
@@ -0,0 +1,128 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ from typing import Callable, List, Optional, Union
+
+ import torch
+ from mmcv.image import imread
+ from mmengine.config import Config
+ from mmengine.dataset import Compose, default_collate
+
+ from mmpretrain.registry import TRANSFORMS
+ from .base import BaseInferencer, InputType
+ from .model import list_models
+
+
+ class FeatureExtractor(BaseInferencer):
+     """The inferencer for extracting features.
+
+     Args:
+         model (BaseModel | str | Config): A model name or a path to the config
+             file, or a :obj:`BaseModel` object. The model name can be found
+             by ``FeatureExtractor.list_models()`` and you can also query it in
+             :doc:`/modelzoo_statistics`.
+         pretrained (str, optional): Path to the checkpoint. If None, it will
+             try to find a pre-defined weight from the model you specified
+             (only works if the ``model`` is a model name). Defaults to None.
+         device (str, optional): Device to run inference. If None, the
+             available device will be automatically used. Defaults to None.
+         **kwargs: Other keyword arguments to initialize the model (only works
+             if the ``model`` is a model name).
+
+     Example:
+         >>> from mmpretrain import FeatureExtractor
+         >>> inferencer = FeatureExtractor('resnet50_8xb32_in1k', backbone=dict(out_indices=(0, 1, 2, 3)))
+         >>> feats = inferencer('demo/demo.JPEG', stage='backbone')[0]
+         >>> for feat in feats:
+         ...     print(feat.shape)
+         torch.Size([256, 56, 56])
+         torch.Size([512, 28, 28])
+         torch.Size([1024, 14, 14])
+         torch.Size([2048, 7, 7])
+     """  # noqa: E501
+
+     def __call__(self,
+                  inputs: InputType,
+                  batch_size: int = 1,
+                  **kwargs) -> dict:
+         """Call the inferencer.
+
+         Args:
+             inputs (str | array | list): The image path or array, or a list of
+                 images.
+             batch_size (int): Batch size. Defaults to 1.
+             **kwargs: Other keyword arguments accepted by the `extract_feat`
+                 method of the model.
+
+         Returns:
+             tensor | Tuple[tensor]: The extracted features.
+         """
+         ori_inputs = self._inputs_to_list(inputs)
+         inputs = self.preprocess(ori_inputs, batch_size=batch_size)
+         preds = []
+         for data in inputs:
+             preds.extend(self.forward(data, **kwargs))
+
+         return preds
+
+     @torch.no_grad()
+     def forward(self, inputs: Union[dict, tuple], **kwargs):
+         inputs = self.model.data_preprocessor(inputs, False)['inputs']
+         outputs = self.model.extract_feat(inputs, **kwargs)
+
+         def scatter(feats, index):
+             if isinstance(feats, torch.Tensor):
+                 return feats[index]
+             else:
+                 # Sequence of tensors
+                 return type(feats)([scatter(item, index) for item in feats])
+
+         results = []
+         for i in range(inputs.shape[0]):
+             results.append(scatter(outputs, i))
+
+         return results
+
+     def _init_pipeline(self, cfg: Config) -> Callable:
+         test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
+         if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile':
+             # Image loading is finished in `self.preprocess`.
+             test_pipeline_cfg = test_pipeline_cfg[1:]
+         test_pipeline = Compose(
+             [TRANSFORMS.build(t) for t in test_pipeline_cfg])
+         return test_pipeline
+
+     def preprocess(self, inputs: List[InputType], batch_size: int = 1):
+
+         def load_image(input_):
+             img = imread(input_)
+             if img is None:
+                 raise ValueError(f'Failed to read image {input_}.')
+             return dict(
+                 img=img,
+                 img_shape=img.shape[:2],
+                 ori_shape=img.shape[:2],
+             )
+
+         pipeline = Compose([load_image, self.pipeline])
+
+         chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size)
+         yield from map(default_collate, chunked_data)
+
+     def visualize(self):
+         raise NotImplementedError(
+             "The FeatureExtractor doesn't support visualization.")
+
+     def postprocess(self):
+         raise NotImplementedError(
+             "The FeatureExtractor doesn't need postprocessing.")
+
+     @staticmethod
+     def list_models(pattern: Optional[str] = None):
+         """List all available model names.
+
+         Args:
+             pattern (str | None): A wildcard pattern to match model names.
+
+         Returns:
+             List[str]: a list of model names.
+         """
+         return list_models(pattern=pattern)
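
The recursive `scatter` helper in `forward` above is what lets the extractor return per-sample features even when `extract_feat` yields a tuple of stage tensors. A self-contained sketch of that logic on dummy tensors (illustrative, not part of the commit):

    # Same `scatter` logic as above, demonstrated on dummy data.
    import torch

    def scatter(feats, index):
        if isinstance(feats, torch.Tensor):
            return feats[index]
        # Sequence of tensors: rebuild the same container type per sample.
        return type(feats)([scatter(item, index) for item in feats])

    # With out_indices=(0, ..., 3), `extract_feat` returns a tuple of stage
    # features, each batched along dim 0 (batch size 2 here).
    batched = (torch.rand(2, 256, 56, 56), torch.rand(2, 2048, 7, 7))
    per_sample = scatter(batched, 0)
    print([t.shape for t in per_sample])
    # [torch.Size([256, 56, 56]), torch.Size([2048, 7, 7])]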
mmpretrain/apis/image_caption.py ADDED
@@ -0,0 +1,164 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ from pathlib import Path
+ from typing import Callable, List, Optional
+
+ import numpy as np
+ from mmcv.image import imread
+ from mmengine.config import Config
+ from mmengine.dataset import Compose, default_collate
+
+ from mmpretrain.registry import TRANSFORMS
+ from mmpretrain.structures import DataSample
+ from .base import BaseInferencer, InputType
+ from .model import list_models
+
+
+ class ImageCaptionInferencer(BaseInferencer):
+     """The inferencer for image captioning.
+
+     Args:
+         model (BaseModel | str | Config): A model name or a path to the config
+             file, or a :obj:`BaseModel` object. The model name can be found
+             by ``ImageCaptionInferencer.list_models()`` and you can also
+             query it in :doc:`/modelzoo_statistics`.
+         pretrained (str, optional): Path to the checkpoint. If None, it will
+             try to find a pre-defined weight from the model you specified
+             (only works if the ``model`` is a model name). Defaults to None.
+         device (str, optional): Device to run inference. If None, the
+             available device will be automatically used. Defaults to None.
+         **kwargs: Other keyword arguments to initialize the model (only works
+             if the ``model`` is a model name).
+
+     Example:
+         >>> from mmpretrain import ImageCaptionInferencer
+         >>> inferencer = ImageCaptionInferencer('blip-base_3rdparty_caption')
+         >>> inferencer('demo/cat-dog.png')[0]
+         {'pred_caption': 'a puppy and a cat sitting on a blanket'}
+     """  # noqa: E501
+
+     visualize_kwargs: set = {'resize', 'show', 'show_dir', 'wait_time'}
+
+     def __call__(self,
+                  images: InputType,
+                  return_datasamples: bool = False,
+                  batch_size: int = 1,
+                  **kwargs) -> dict:
+         """Call the inferencer.
+
+         Args:
+             images (str | array | list): The image path or array, or a list of
+                 images.
+             return_datasamples (bool): Whether to return results as
+                 :obj:`DataSample`. Defaults to False.
+             batch_size (int): Batch size. Defaults to 1.
+             resize (int, optional): Resize the short edge of the image to the
+                 specified length before visualization. Defaults to None.
+             draw_score (bool): Whether to draw the prediction scores
+                 of prediction categories. Defaults to True.
+             show (bool): Whether to display the visualization result in a
+                 window. Defaults to False.
+             wait_time (float): The display time (s). Defaults to 0, which
+                 means "forever".
+             show_dir (str, optional): If not None, save the visualization
+                 results in the specified directory. Defaults to None.
+
+         Returns:
+             list: The inference results.
+         """
+         return super().__call__(images, return_datasamples, batch_size,
+                                 **kwargs)
+
+     def _init_pipeline(self, cfg: Config) -> Callable:
+         test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
+         if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile':
+             # Image loading is finished in `self.preprocess`.
+             test_pipeline_cfg = test_pipeline_cfg[1:]
+         test_pipeline = Compose(
+             [TRANSFORMS.build(t) for t in test_pipeline_cfg])
+         return test_pipeline
+
+     def preprocess(self, inputs: List[InputType], batch_size: int = 1):
+
+         def load_image(input_):
+             img = imread(input_)
+             if img is None:
+                 raise ValueError(f'Failed to read image {input_}.')
+             return dict(
+                 img=img,
+                 img_shape=img.shape[:2],
+                 ori_shape=img.shape[:2],
+             )
+
+         pipeline = Compose([load_image, self.pipeline])
+
+         chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size)
+         yield from map(default_collate, chunked_data)
+
+     def visualize(self,
+                   ori_inputs: List[InputType],
+                   preds: List[DataSample],
+                   show: bool = False,
+                   wait_time: int = 0,
+                   resize: Optional[int] = None,
+                   show_dir=None):
+         if not show and show_dir is None:
+             return None
+
+         if self.visualizer is None:
+             from mmpretrain.visualization import UniversalVisualizer
+             self.visualizer = UniversalVisualizer()
+
+         visualization = []
+         for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)):
+             image = imread(input_)
+             if isinstance(input_, str):
+                 # The image loaded from path is in BGR format.
+                 image = image[..., ::-1]
+                 name = Path(input_).stem
+             else:
+                 name = str(i)
+
+             if show_dir is not None:
+                 show_dir = Path(show_dir)
+                 show_dir.mkdir(exist_ok=True)
+                 out_file = str((show_dir / name).with_suffix('.png'))
+             else:
+                 out_file = None
+
+             self.visualizer.visualize_image_caption(
+                 image,
+                 data_sample,
+                 resize=resize,
+                 show=show,
+                 wait_time=wait_time,
+                 name=name,
+                 out_file=out_file)
+             visualization.append(self.visualizer.get_image())
+             if show:
+                 self.visualizer.close()
+         return visualization
+
+     def postprocess(self,
+                     preds: List[DataSample],
+                     visualization: List[np.ndarray],
+                     return_datasamples=False) -> dict:
+         if return_datasamples:
+             return preds
+
+         results = []
+         for data_sample in preds:
+             results.append({'pred_caption': data_sample.get('pred_caption')})
+
+         return results
+
+     @staticmethod
+     def list_models(pattern: Optional[str] = None):
+         """List all available model names.
+
+         Args:
+             pattern (str | None): A wildcard pattern to match model names.
+
+         Returns:
+             List[str]: a list of model names.
+         """
+         return list_models(pattern=pattern, task='Image Caption')
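
Beyond the docstring example, the `visualize` path above writes one PNG per input when `show_dir` is set, while the return value stays the list of caption dicts. A usage sketch (paths illustrative, not part of the commit):

    # Sketch: captioning with visualization output saved to disk.
    from mmpretrain.apis import ImageCaptionInferencer

    inferencer = ImageCaptionInferencer('blip-base_3rdparty_caption')

    # `show_dir` is dispatched to `visualize`, which saves '<stem>.png'
    # per input image under './vis/'.
    results = inferencer('demo/cat-dog.png', show_dir='./vis/')
    print(results[0]['pred_caption'])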
mmpretrain/apis/image_classification.py ADDED
@@ -0,0 +1,221 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ from pathlib import Path
+ from typing import Callable, List, Optional, Union
+
+ import numpy as np
+ import torch
+ from mmcv.image import imread
+ from mmengine.config import Config
+ from mmengine.dataset import Compose, default_collate
+
+ from mmpretrain.registry import TRANSFORMS
+ from mmpretrain.structures import DataSample
+ from .base import BaseInferencer, InputType, ModelType
+ from .model import list_models
+
+
+ class ImageClassificationInferencer(BaseInferencer):
+     """The inferencer for image classification.
+
+     Args:
+         model (BaseModel | str | Config): A model name or a path to the config
+             file, or a :obj:`BaseModel` object. The model name can be found
+             by ``ImageClassificationInferencer.list_models()`` and you can
+             also query it in :doc:`/modelzoo_statistics`.
+         pretrained (str, optional): Path to the checkpoint. If None, it will
+             try to find a pre-defined weight from the model you specified
+             (only works if the ``model`` is a model name). Defaults to None.
+         device (str, optional): Device to run inference. If None, the
+             available device will be automatically used. Defaults to None.
+         **kwargs: Other keyword arguments to initialize the model (only works
+             if the ``model`` is a model name).
+
+     Example:
+         1. Use a pre-trained model in MMPreTrain to inference an image.
+
+         >>> from mmpretrain import ImageClassificationInferencer
+         >>> inferencer = ImageClassificationInferencer('resnet50_8xb32_in1k')
+         >>> inferencer('demo/demo.JPEG')
+         [{'pred_scores': array([...]),
+           'pred_label': 65,
+           'pred_score': 0.6649367809295654,
+           'pred_class': 'sea snake'}]
+
+         2. Use a config file and checkpoint to inference multiple images on
+            GPU, and save the visualization results in a folder.
+
+         >>> from mmpretrain import ImageClassificationInferencer
+         >>> inferencer = ImageClassificationInferencer(
+         ...     model='configs/resnet/resnet50_8xb32_in1k.py',
+         ...     pretrained='https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth',
+         ...     device='cuda')
+         >>> inferencer(['demo/dog.jpg', 'demo/bird.JPEG'], show_dir="./visualize/")
+     """  # noqa: E501
+
+     visualize_kwargs: set = {
+         'resize', 'rescale_factor', 'draw_score', 'show', 'show_dir',
+         'wait_time'
+     }
+
+     def __init__(self,
+                  model: ModelType,
+                  pretrained: Union[bool, str] = True,
+                  device: Union[str, torch.device, None] = None,
+                  classes=None,
+                  **kwargs) -> None:
+         super().__init__(
+             model=model, pretrained=pretrained, device=device, **kwargs)
+
+         if classes is not None:
+             self.classes = classes
+         else:
+             self.classes = getattr(self.model, '_dataset_meta',
+                                    {}).get('classes')
+
+     def __call__(self,
+                  inputs: InputType,
+                  return_datasamples: bool = False,
+                  batch_size: int = 1,
+                  **kwargs) -> dict:
+         """Call the inferencer.
+
+         Args:
+             inputs (str | array | list): The image path or array, or a list of
+                 images.
+             return_datasamples (bool): Whether to return results as
+                 :obj:`DataSample`. Defaults to False.
+             batch_size (int): Batch size. Defaults to 1.
+             resize (int, optional): Resize the short edge of the image to the
+                 specified length before visualization. Defaults to None.
+             rescale_factor (float, optional): Rescale the image by the rescale
+                 factor for visualization. This is helpful when the image is
+                 too large or too small for visualization. Defaults to None.
+             draw_score (bool): Whether to draw the prediction scores
+                 of prediction categories. Defaults to True.
+             show (bool): Whether to display the visualization result in a
+                 window. Defaults to False.
+             wait_time (float): The display time (s). Defaults to 0, which
+                 means "forever".
+             show_dir (str, optional): If not None, save the visualization
+                 results in the specified directory. Defaults to None.
+
+         Returns:
+             list: The inference results.
+         """
+         return super().__call__(
+             inputs,
+             return_datasamples=return_datasamples,
+             batch_size=batch_size,
+             **kwargs)
+
+     def _init_pipeline(self, cfg: Config) -> Callable:
+         test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
+         if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile':
+             # Image loading is finished in `self.preprocess`.
+             test_pipeline_cfg = test_pipeline_cfg[1:]
+         test_pipeline = Compose(
+             [TRANSFORMS.build(t) for t in test_pipeline_cfg])
+         return test_pipeline
+
+     def preprocess(self, inputs: List[InputType], batch_size: int = 1):
+
+         def load_image(input_):
+             img = imread(input_)
+             if img is None:
+                 raise ValueError(f'Failed to read image {input_}.')
+             return dict(
+                 img=img,
+                 img_shape=img.shape[:2],
+                 ori_shape=img.shape[:2],
+             )
+
+         pipeline = Compose([load_image, self.pipeline])
+
+         chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size)
+         yield from map(default_collate, chunked_data)
+
+     def visualize(self,
+                   ori_inputs: List[InputType],
+                   preds: List[DataSample],
+                   show: bool = False,
+                   wait_time: int = 0,
+                   resize: Optional[int] = None,
+                   rescale_factor: Optional[float] = None,
+                   draw_score=True,
+                   show_dir=None):
+         if not show and show_dir is None:
+             return None
+
+         if self.visualizer is None:
+             from mmpretrain.visualization import UniversalVisualizer
+             self.visualizer = UniversalVisualizer()
+
+         visualization = []
+         for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)):
+             image = imread(input_)
+             if isinstance(input_, str):
+                 # The image loaded from path is in BGR format.
+                 image = image[..., ::-1]
+                 name = Path(input_).stem
+             else:
+                 name = str(i)
+
+             if show_dir is not None:
+                 show_dir = Path(show_dir)
+                 show_dir.mkdir(exist_ok=True)
+                 out_file = str((show_dir / name).with_suffix('.png'))
+             else:
+                 out_file = None
+
+             self.visualizer.visualize_cls(
+                 image,
+                 data_sample,
+                 classes=self.classes,
+                 resize=resize,
+                 show=show,
+                 wait_time=wait_time,
+                 rescale_factor=rescale_factor,
+                 draw_gt=False,
+                 draw_pred=True,
+                 draw_score=draw_score,
+                 name=name,
+                 out_file=out_file)
+             visualization.append(self.visualizer.get_image())
+             if show:
+                 self.visualizer.close()
+         return visualization
+
+     def postprocess(self,
+                     preds: List[DataSample],
+                     visualization: List[np.ndarray],
+                     return_datasamples=False) -> dict:
+         if return_datasamples:
+             return preds
+
+         results = []
+         for data_sample in preds:
+             pred_scores = data_sample.pred_score
+             pred_score = float(torch.max(pred_scores).item())
+             pred_label = torch.argmax(pred_scores).item()
+             result = {
+                 'pred_scores': pred_scores.detach().cpu().numpy(),
+                 'pred_label': pred_label,
+                 'pred_score': pred_score,
+             }
+             if self.classes is not None:
+                 result['pred_class'] = self.classes[pred_label]
+             results.append(result)
+
+         return results
+
+     @staticmethod
+     def list_models(pattern: Optional[str] = None):
+         """List all available model names.
+
+         Args:
+             pattern (str | None): A wildcard pattern to match model names.
+
+         Returns:
+             List[str]: a list of model names.
+         """
+         return list_models(pattern=pattern, task='Image Classification')
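
The `classes` argument this subclass adds in `__init__` overrides the checkpoint's `_dataset_meta`, and `postprocess` uses it to fill `pred_class`. A sketch of supplying class names for a custom checkpoint (config and weight paths are hypothetical, not part of the commit):

    # Sketch: overriding class names (paths hypothetical).
    from mmpretrain.apis import ImageClassificationInferencer

    inferencer = ImageClassificationInferencer(
        'configs/my_model.py',          # hypothetical config
        pretrained='work_dirs/my.pth',  # hypothetical checkpoint
        classes=['cat', 'dog'])         # used by postprocess for 'pred_class'

    result = inferencer('demo/cat-dog.png')[0]
    print(result['pred_label'], result.get('pred_class'))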
mmpretrain/apis/image_retrieval.py ADDED
@@ -0,0 +1,285 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ from pathlib import Path
+ from typing import Callable, List, Optional, Union
+
+ import numpy as np
+ import torch
+ from mmcv.image import imread
+ from mmengine.config import Config
+ from mmengine.dataset import BaseDataset, Compose, default_collate
+
+ from mmpretrain.registry import TRANSFORMS
+ from mmpretrain.structures import DataSample
+ from .base import BaseInferencer, InputType, ModelType
+ from .model import list_models
+
+
+ class ImageRetrievalInferencer(BaseInferencer):
+     """The inferencer for image-to-image retrieval.
+
+     Args:
+         model (BaseModel | str | Config): A model name or a path to the config
+             file, or a :obj:`BaseModel` object. The model name can be found
+             by ``ImageRetrievalInferencer.list_models()`` and you can also
+             query it in :doc:`/modelzoo_statistics`.
+         prototype (str | list | dict | DataLoader | BaseDataset): The images
+             to be retrieved. It can be the following types:
+
+             - str: The directory of the images.
+             - list: A list of paths of the images.
+             - dict: A config dict of a prototype dataset.
+             - BaseDataset: A prototype dataset.
+             - DataLoader: A data loader to load the prototype data.
+
+         prototype_cache (str, optional): The path of the generated prototype
+             features. If it exists, directly load the cache instead of
+             re-generating the prototype features. If it does not exist, save
+             the generated features to this path. Defaults to None.
+         pretrained (str, optional): Path to the checkpoint. If None, it will
+             try to find a pre-defined weight from the model you specified
+             (only works if the ``model`` is a model name). Defaults to None.
+         device (str, optional): Device to run inference. If None, the
+             available device will be automatically used. Defaults to None.
+         **kwargs: Other keyword arguments to initialize the model (only works
+             if the ``model`` is a model name).
+
+     Example:
+         >>> from mmpretrain import ImageRetrievalInferencer
+         >>> inferencer = ImageRetrievalInferencer(
+         ...     'resnet50-arcface_8xb32_inshop',
+         ...     prototype='./demo/',
+         ...     prototype_cache='img_retri.pth')
+         >>> inferencer('demo/cat-dog.png', topk=2)[0][1]
+         {'match_score': tensor(0.4088, device='cuda:0'),
+          'sample_idx': 3,
+          'sample': {'img_path': './demo/dog.jpg'}}
+     """  # noqa: E501
+
+     visualize_kwargs: set = {
+         'draw_score', 'resize', 'show_dir', 'show', 'wait_time', 'topk'
+     }
+     postprocess_kwargs: set = {'topk'}
+
+     def __init__(
+         self,
+         model: ModelType,
+         prototype,
+         prototype_cache=None,
+         prepare_batch_size=8,
+         pretrained: Union[bool, str] = True,
+         device: Union[str, torch.device, None] = None,
+         **kwargs,
+     ) -> None:
+         super().__init__(
+             model=model, pretrained=pretrained, device=device, **kwargs)
+
+         self.prototype_dataset = self._prepare_prototype(
+             prototype, prototype_cache, prepare_batch_size)
+
+     def _prepare_prototype(self, prototype, cache=None, batch_size=8):
+         from mmengine.dataset import DefaultSampler
+         from torch.utils.data import DataLoader
+
+         def build_dataloader(dataset):
+             return DataLoader(
+                 dataset,
+                 batch_size=batch_size,
+                 collate_fn=default_collate,
+                 sampler=DefaultSampler(dataset, shuffle=False),
+                 persistent_workers=False,
+             )
+
+         if isinstance(prototype, str):
+             # A directory path of images
+             prototype = dict(
+                 type='CustomDataset', with_label=False, data_root=prototype)
+
+         if isinstance(prototype, list):
+             test_pipeline = [dict(type='LoadImageFromFile'), self.pipeline]
+             dataset = BaseDataset(
+                 lazy_init=True, serialize_data=False, pipeline=test_pipeline)
+             dataset.data_list = [{
+                 'sample_idx': i,
+                 'img_path': file
+             } for i, file in enumerate(prototype)]
+             dataset._fully_initialized = True
+             dataloader = build_dataloader(dataset)
+         elif isinstance(prototype, dict):
+             # A config of a dataset
+             from mmpretrain.registry import DATASETS
+             test_pipeline = [dict(type='LoadImageFromFile'), self.pipeline]
+             dataset = DATASETS.build(prototype)
+             dataloader = build_dataloader(dataset)
+         elif isinstance(prototype, DataLoader):
+             dataset = prototype.dataset
+             dataloader = prototype
+         elif isinstance(prototype, BaseDataset):
+             dataset = prototype
+             dataloader = build_dataloader(dataset)
+         else:
+             raise TypeError(f'Unsupported prototype type {type(prototype)}.')
+
+         if cache is not None and Path(cache).exists():
+             self.model.prototype = cache
+         else:
+             self.model.prototype = dataloader
+         self.model.prepare_prototype()
+
+         from mmengine.logging import MMLogger
+         logger = MMLogger.get_current_instance()
+         if cache is None:
+             logger.info('The prototype has been prepared, you can use '
+                         '`save_prototype` to dump it into a pickle '
+                         'file for future usage.')
+         elif not Path(cache).exists():
+             self.save_prototype(cache)
+             logger.info(f'The prototype has been saved at {cache}.')
+
+         return dataset
+
+     def save_prototype(self, path):
+         self.model.dump_prototype(path)
+
+     def __call__(self,
+                  inputs: InputType,
+                  return_datasamples: bool = False,
+                  batch_size: int = 1,
+                  **kwargs) -> dict:
+         """Call the inferencer.
+
+         Args:
+             inputs (str | array | list): The image path or array, or a list of
+                 images.
+             return_datasamples (bool): Whether to return results as
+                 :obj:`DataSample`. Defaults to False.
+             batch_size (int): Batch size. Defaults to 1.
+             resize (int, optional): Resize the long edge of the image to the
+                 specified length before visualization. Defaults to None.
+             draw_score (bool): Whether to draw the match scores.
+                 Defaults to True.
+             show (bool): Whether to display the visualization result in a
+                 window. Defaults to False.
+             wait_time (float): The display time (s). Defaults to 0, which
+                 means "forever".
+             show_dir (str, optional): If not None, save the visualization
+                 results in the specified directory. Defaults to None.
+
+         Returns:
+             list: The inference results.
+         """
+         return super().__call__(inputs, return_datasamples, batch_size,
+                                 **kwargs)
+
+     def _init_pipeline(self, cfg: Config) -> Callable:
+         test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
+         if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile':
+             # Image loading is finished in `self.preprocess`.
+             test_pipeline_cfg = test_pipeline_cfg[1:]
+         test_pipeline = Compose(
+             [TRANSFORMS.build(t) for t in test_pipeline_cfg])
+         return test_pipeline
+
+     def preprocess(self, inputs: List[InputType], batch_size: int = 1):
+
+         def load_image(input_):
+             img = imread(input_)
+             if img is None:
+                 raise ValueError(f'Failed to read image {input_}.')
+             return dict(
+                 img=img,
+                 img_shape=img.shape[:2],
+                 ori_shape=img.shape[:2],
+             )
+
+         pipeline = Compose([load_image, self.pipeline])
+
+         chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size)
+         yield from map(default_collate, chunked_data)
+
+     def visualize(self,
+                   ori_inputs: List[InputType],
+                   preds: List[DataSample],
+                   topk: int = 3,
+                   resize: Optional[int] = 224,
+                   show: bool = False,
+                   wait_time: int = 0,
+                   draw_score=True,
+                   show_dir=None):
+         if not show and show_dir is None:
+             return None
+
+         if self.visualizer is None:
+             from mmpretrain.visualization import UniversalVisualizer
+             self.visualizer = UniversalVisualizer()
+
+         visualization = []
+         for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)):
+             image = imread(input_)
+             if isinstance(input_, str):
+                 # The image loaded from path is in BGR format.
+                 image = image[..., ::-1]
+                 name = Path(input_).stem
+             else:
+                 name = str(i)
+
+             if show_dir is not None:
+                 show_dir = Path(show_dir)
+                 show_dir.mkdir(exist_ok=True)
+                 out_file = str((show_dir / name).with_suffix('.png'))
+             else:
+                 out_file = None
+
+             self.visualizer.visualize_image_retrieval(
+                 image,
+                 data_sample,
+                 self.prototype_dataset,
+                 topk=topk,
+                 resize=resize,
+                 draw_score=draw_score,
+                 show=show,
+                 wait_time=wait_time,
+                 name=name,
+                 out_file=out_file)
+             visualization.append(self.visualizer.get_image())
+             if show:
+                 self.visualizer.close()
+         return visualization
+
+     def postprocess(
+         self,
+         preds: List[DataSample],
+         visualization: List[np.ndarray],
+         return_datasamples=False,
+         topk=1,
+     ) -> dict:
+         if return_datasamples:
+             return preds
+
+         results = []
+         for data_sample in preds:
+             match_scores, indices = torch.topk(data_sample.pred_score, k=topk)
+             matches = []
+             for match_score, sample_idx in zip(match_scores, indices):
+                 sample = self.prototype_dataset.get_data_info(
+                     sample_idx.item())
+                 sample_idx = sample.pop('sample_idx')
+                 matches.append({
+                     'match_score': match_score,
+                     'sample_idx': sample_idx,
+                     'sample': sample
+                 })
+             results.append(matches)
+
+         return results
+
+     @staticmethod
+     def list_models(pattern: Optional[str] = None):
+         """List all available model names.
+
+         Args:
+             pattern (str | None): A wildcard pattern to match model names.
+
+         Returns:
+             List[str]: a list of model names.
+         """
+         return list_models(pattern=pattern, task='Image Retrieval')
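
`_prepare_prototype` above also accepts a plain list of image paths, from which it builds an in-memory `BaseDataset`, and `prototype_cache` persists the computed gallery features across runs. A sketch with hypothetical file paths (not part of the commit):

    # Sketch: a list prototype instead of a directory (paths hypothetical).
    from mmpretrain.apis import ImageRetrievalInferencer

    gallery = ['gallery/0001.jpg', 'gallery/0002.jpg']  # hypothetical files
    inferencer = ImageRetrievalInferencer(
        'resnet50-arcface_8xb32_inshop',
        prototype=gallery,
        prototype_cache='gallery_feats.pth')  # saved on first run, reused after

    matches = inferencer('query.jpg', topk=1)[0]
    print(matches[0]['sample']['img_path'])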
mmpretrain/apis/model.py ADDED
@@ -0,0 +1,408 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import copy
3
+ import fnmatch
4
+ import os.path as osp
5
+ import re
6
+ import warnings
7
+ from os import PathLike
8
+ from pathlib import Path
9
+ from typing import List, Tuple, Union
10
+
11
+ from mmengine.config import Config
12
+ from modelindex.load_model_index import load
13
+ from modelindex.models.Model import Model
14
+
15
+
16
+ class ModelHub:
17
+ """A hub to host the meta information of all pre-defined models."""
18
+ _models_dict = {}
19
+ __mmpretrain_registered = False
20
+
21
+ @classmethod
22
+ def register_model_index(cls,
23
+ model_index_path: Union[str, PathLike],
24
+ config_prefix: Union[str, PathLike, None] = None):
25
+ """Parse the model-index file and register all models.
26
+
27
+ Args:
28
+ model_index_path (str | PathLike): The path of the model-index
29
+ file.
30
+ config_prefix (str | PathLike | None): The prefix of all config
31
+ file paths in the model-index file.
32
+ """
33
+ model_index = load(str(model_index_path))
34
+ model_index.build_models_with_collections()
35
+
36
+ for metainfo in model_index.models:
37
+ model_name = metainfo.name.lower()
38
+ if metainfo.name in cls._models_dict:
39
+ raise ValueError(
40
+ 'The model name {} is conflict in {} and {}.'.format(
41
+ model_name, osp.abspath(metainfo.filepath),
42
+ osp.abspath(cls._models_dict[model_name].filepath)))
43
+ metainfo.config = cls._expand_config_path(metainfo, config_prefix)
44
+ cls._models_dict[model_name] = metainfo
45
+
46
+ @classmethod
47
+ def get(cls, model_name):
48
+ """Get the model's metainfo by the model name.
49
+
50
+ Args:
51
+ model_name (str): The name of model.
52
+
53
+ Returns:
54
+ modelindex.models.Model: The metainfo of the specified model.
55
+ """
56
+ cls._register_mmpretrain_models()
57
+ # lazy load config
58
+ metainfo = copy.deepcopy(cls._models_dict.get(model_name.lower()))
59
+ if metainfo is None:
60
+ raise ValueError(
61
+ f'Failed to find model "{model_name}". please use '
62
+ '`mmpretrain.list_models` to get all available names.')
63
+ if isinstance(metainfo.config, str):
64
+ metainfo.config = Config.fromfile(metainfo.config)
65
+ return metainfo
66
+
67
+ @staticmethod
68
+ def _expand_config_path(metainfo: Model,
69
+ config_prefix: Union[str, PathLike] = None):
70
+ if config_prefix is None:
71
+ config_prefix = osp.dirname(metainfo.filepath)
72
+
73
+ if metainfo.config is None or osp.isabs(metainfo.config):
74
+ config_path: str = metainfo.config
75
+ else:
76
+ config_path = osp.abspath(osp.join(config_prefix, metainfo.config))
77
+
78
+ return config_path
79
+
80
+ @classmethod
81
+ def _register_mmpretrain_models(cls):
82
+ # register models in mmpretrain
83
+ if not cls.__mmpretrain_registered:
84
+ from importlib_metadata import distribution
85
+ root = distribution('mmpretrain').locate_file('mmpretrain')
86
+ model_index_path = root / '.mim' / 'model-index.yml'
87
+ ModelHub.register_model_index(
88
+ model_index_path, config_prefix=root / '.mim')
89
+ cls.__mmpretrain_registered = True
90
+
91
+ @classmethod
92
+ def has(cls, model_name):
93
+ """Whether a model name is in the ModelHub."""
94
+ return model_name in cls._models_dict
95
+
96
+
97
+ def get_model(model: Union[str, Config],
+               pretrained: Union[str, bool] = False,
+               device=None,
+               device_map=None,
+               offload_folder=None,
+               url_mapping: Tuple[str, str] = None,
+               **kwargs):
+     """Get a pre-defined model or create a model from config.
+
+     Args:
+         model (str | Config): The name of the model, a config file path or a
+             config instance.
+         pretrained (bool | str): When using a model name to specify the model,
+             you can set ``True`` to load its pre-defined pretrained weights.
+             You can also use a string to specify the path or link of the
+             weights to load. Defaults to False.
+         device (str | torch.device | None): Transfer the model to the target
+             device. Defaults to None.
+         device_map (str | dict | None): A map that specifies where each
+             submodule should go. It doesn't need to be refined to each
+             parameter/buffer name; once a given module name is included, all
+             of its submodules will be sent to the same device. You can use
+             `device_map="auto"` to automatically generate the device map.
+             Defaults to None.
+         offload_folder (str | None): If the `device_map` contains the value
+             `"disk"`, the folder where the weights will be offloaded.
+         url_mapping (Tuple[str, str], optional): The mapping of the pretrained
+             checkpoint link. For example, load the checkpoint from a local
+             directory instead of downloading it by passing
+             ``('https://.*/', './checkpoint')``. Defaults to None.
+         **kwargs: Other keyword arguments of the model config.
+
+     Returns:
+         mmengine.model.BaseModel: The result model.
+
+     Examples:
+         Get a ResNet-50 model and extract image features:
+
+         >>> import torch
+         >>> from mmpretrain import get_model
+         >>> inputs = torch.rand(16, 3, 224, 224)
+         >>> model = get_model('resnet50_8xb32_in1k', pretrained=True, backbone=dict(out_indices=(0, 1, 2, 3)))
+         >>> feats = model.extract_feat(inputs)
+         >>> for feat in feats:
+         ...     print(feat.shape)
+         torch.Size([16, 256])
+         torch.Size([16, 512])
+         torch.Size([16, 1024])
+         torch.Size([16, 2048])
+
+         Get a Swin Transformer model with pre-trained weights and run inference:
+
+         >>> from mmpretrain import get_model, inference_model
+         >>> model = get_model('swin-base_16xb64_in1k', pretrained=True)
+         >>> result = inference_model(model, 'demo/demo.JPEG')
+         >>> print(result['pred_class'])
+         'sea snake'
+     """  # noqa: E501
+     if device_map is not None:
+         from .utils import dispatch_model
+         dispatch_model._verify_require()
+
+     metainfo = None
+     if isinstance(model, Config):
+         config = copy.deepcopy(model)
+         if pretrained is True and 'load_from' in config:
+             pretrained = config.load_from
+     elif isinstance(model, (str, PathLike)) and Path(model).suffix == '.py':
+         config = Config.fromfile(model)
+         if pretrained is True and 'load_from' in config:
+             pretrained = config.load_from
+     elif isinstance(model, str):
+         metainfo = ModelHub.get(model)
+         config = metainfo.config
+         if pretrained is True and metainfo.weights is not None:
+             pretrained = metainfo.weights
+     else:
+         raise TypeError('model must be a name, a path or a Config object, '
+                         f'but got {type(model)}')
+
+     if pretrained is True:
+         warnings.warn('Unable to find pre-defined checkpoint of the model.')
+         pretrained = None
+     elif pretrained is False:
+         pretrained = None
+
+     if kwargs:
+         config.merge_from_dict({'model': kwargs})
+     config.model.setdefault('data_preprocessor',
+                             config.get('data_preprocessor', None))
+
+     from mmengine.registry import DefaultScope
+
+     from mmpretrain.registry import MODELS
+     with DefaultScope.overwrite_default_scope('mmpretrain'):
+         model = MODELS.build(config.model)
+
+     dataset_meta = {}
+     if pretrained:
+         # Mapping the weights to GPU may cause an unexpected video memory
+         # leak; see https://github.com/open-mmlab/mmdetection/pull/6405
+         from mmengine.runner import load_checkpoint
+         if url_mapping is not None:
+             pretrained = re.sub(url_mapping[0], url_mapping[1], pretrained)
+         checkpoint = load_checkpoint(model, pretrained, map_location='cpu')
+         if 'dataset_meta' in checkpoint.get('meta', {}):
+             # mmpretrain 1.x
+             dataset_meta = checkpoint['meta']['dataset_meta']
+         elif 'CLASSES' in checkpoint.get('meta', {}):
+             # mmcls 0.x
+             dataset_meta = {'classes': checkpoint['meta']['CLASSES']}
+
+     if len(dataset_meta) == 0 and 'test_dataloader' in config:
+         from mmpretrain.registry import DATASETS
+         dataset_class = DATASETS.get(config.test_dataloader.dataset.type)
+         dataset_meta = getattr(dataset_class, 'METAINFO', {})
+
+     if device_map is not None:
+         model = dispatch_model(
+             model, device_map=device_map, offload_folder=offload_folder)
+     elif device is not None:
+         model.to(device)
+
+     model._dataset_meta = dataset_meta  # save the dataset meta
+     model._config = config  # save the config in the model
+     model._metainfo = metainfo  # save the metainfo in the model
+     model.eval()
+     return model
+
+
+ def init_model(config, checkpoint=None, device=None, **kwargs):
+     """Initialize a classifier from a config file (deprecated).
+
+     It's only for compatibility; please use :func:`get_model` instead.
+
+     Args:
+         config (str | :obj:`mmengine.Config`): Config file path or the config
+             object.
+         checkpoint (str, optional): Checkpoint path. If left as None, the model
+             will not load any weights.
+         device (str | torch.device | None): Transfer the model to the target
+             device. Defaults to None.
+         **kwargs: Other keyword arguments of the model config.
+
+     Returns:
+         nn.Module: The constructed model.
+     """
+     return get_model(config, checkpoint, device, **kwargs)
+
+
+ def list_models(pattern=None, exclude_patterns=None, task=None) -> List[str]:
+     """List all models available in MMPretrain.
+
+     Args:
+         pattern (str | None): A wildcard pattern to match model names.
+             Defaults to None.
+         exclude_patterns (list | None): A list of wildcard patterns to
+             exclude names from the matched names. Defaults to None.
+         task (str | None): The evaluation task of the model. Defaults to None.
+
+     Returns:
+         List[str]: a list of model names.
+
+     Examples:
+         List all models:
+
+         >>> from mmpretrain import list_models
+         >>> list_models()
+
+         List ResNet-50 models on the ImageNet-1k dataset:
+
+         >>> from mmpretrain import list_models
+         >>> list_models('resnet*in1k')
+         ['resnet50_8xb32_in1k',
+          'resnet50_8xb32-fp16_in1k',
+          'resnet50_8xb256-rsb-a1-600e_in1k',
+          'resnet50_8xb256-rsb-a2-300e_in1k',
+          'resnet50_8xb256-rsb-a3-100e_in1k']
+
+         List Swin Transformer models trained from scratch and exclude
+         Swin-Transformer-V2 models:
+
+         >>> from mmpretrain import list_models
+         >>> list_models('swin', exclude_patterns=['swinv2', '*-pre'])
+         ['swin-base_16xb64_in1k',
+          'swin-base_3rdparty_in1k',
+          'swin-base_3rdparty_in1k-384',
+          'swin-large_8xb8_cub-384px',
+          'swin-small_16xb64_in1k',
+          'swin-small_3rdparty_in1k',
+          'swin-tiny_16xb64_in1k',
+          'swin-tiny_3rdparty_in1k']
+
+         List all EVA models for the image classification task:
+
+         >>> from mmpretrain import list_models
+         >>> list_models('eva', task='Image Classification')
+         ['eva-g-p14_30m-in21k-pre_3rdparty_in1k-336px',
+          'eva-g-p14_30m-in21k-pre_3rdparty_in1k-560px',
+          'eva-l-p14_mim-in21k-pre_3rdparty_in1k-196px',
+          'eva-l-p14_mim-in21k-pre_3rdparty_in1k-336px',
+          'eva-l-p14_mim-pre_3rdparty_in1k-196px',
+          'eva-l-p14_mim-pre_3rdparty_in1k-336px']
+     """
+     ModelHub._register_mmpretrain_models()
+     matches = set(ModelHub._models_dict.keys())
+
+     if pattern is not None:
+         # Always match keys with any postfix.
+         matches = set(fnmatch.filter(matches, pattern + '*'))
+
+     exclude_patterns = exclude_patterns or []
+     for exclude_pattern in exclude_patterns:
+         exclude = set(fnmatch.filter(matches, exclude_pattern + '*'))
+         matches = matches - exclude
+
+     if task is not None:
+         task_matches = []
+         for key in matches:
+             metainfo = ModelHub._models_dict[key]
+             if metainfo.results is None and task == 'null':
+                 task_matches.append(key)
+             elif metainfo.results is None:
+                 continue
+             elif task in [result.task for result in metainfo.results]:
+                 task_matches.append(key)
+         matches = task_matches
+
+     return sorted(list(matches))
+
+
+ def inference_model(model, *args, **kwargs):
+     """Run inference on an image with an automatically selected inferencer.
+
+     The inferencer is selected according to the task of the model. This is a
+     shortcut for a quick start; for advanced usage, please use the
+     corresponding inferencer class.
+
+     Here is the mapping from task to inferencer:
+
+     - Image Classification: :class:`ImageClassificationInferencer`
+     - Image Retrieval: :class:`ImageRetrievalInferencer`
+     - Image Caption: :class:`ImageCaptionInferencer`
+     - Visual Question Answering: :class:`VisualQuestionAnsweringInferencer`
+     - Visual Grounding: :class:`VisualGroundingInferencer`
+     - Text-To-Image Retrieval: :class:`TextToImageRetrievalInferencer`
+     - Image-To-Text Retrieval: :class:`ImageToTextRetrievalInferencer`
+     - NLVR: :class:`NLVRInferencer`
+
+     Args:
+         model (BaseModel | str | Config): The loaded model, the model
+             name or the config of the model.
+         *args: Positional arguments to call the inferencer.
+         **kwargs: Other keyword arguments to initialize and call the
+             corresponding inferencer.
+
+     Returns:
+         result (dict): The inference results.
+     """  # noqa: E501
+     from mmengine.model import BaseModel
+
+     if isinstance(model, BaseModel):
+         metainfo = getattr(model, '_metainfo', None)
+     else:
+         metainfo = ModelHub.get(model)
+
+     from inspect import signature
+
+     from .image_caption import ImageCaptionInferencer
+     from .image_classification import ImageClassificationInferencer
+     from .image_retrieval import ImageRetrievalInferencer
+     from .multimodal_retrieval import (ImageToTextRetrievalInferencer,
+                                        TextToImageRetrievalInferencer)
+     from .nlvr import NLVRInferencer
+     from .visual_grounding import VisualGroundingInferencer
+     from .visual_question_answering import VisualQuestionAnsweringInferencer
+     task_mapping = {
+         'Image Classification': ImageClassificationInferencer,
+         'Image Retrieval': ImageRetrievalInferencer,
+         'Image Caption': ImageCaptionInferencer,
+         'Visual Question Answering': VisualQuestionAnsweringInferencer,
+         'Visual Grounding': VisualGroundingInferencer,
+         'Text-To-Image Retrieval': TextToImageRetrievalInferencer,
+         'Image-To-Text Retrieval': ImageToTextRetrievalInferencer,
+         'NLVR': NLVRInferencer,
+     }
+
+     inferencer_type = None
+
+     if metainfo is not None and metainfo.results is not None:
+         tasks = set(result.task for result in metainfo.results)
+         inferencer_type = [
+             task_mapping.get(task) for task in tasks if task in task_mapping
+         ]
+         if len(inferencer_type) > 1:
+             inferencer_names = [cls.__name__ for cls in inferencer_type]
+             warnings.warn('The model supports multiple tasks, auto select '
+                           f'{inferencer_names[0]}, you can also use other '
+                           f'inferencer {inferencer_names} directly.')
+         # Guard against models whose tasks have no matching inferencer.
+         inferencer_type = inferencer_type[0] if inferencer_type else None
+
+     if inferencer_type is None:
+         raise NotImplementedError('No available inferencer for the model')
+
+     init_kwargs = {
+         k: kwargs.pop(k)
+         for k in list(kwargs)
+         if k in signature(inferencer_type).parameters.keys()
+     }
+
+     inferencer = inferencer_type(model, **init_kwargs)
+     return inferencer(*args, **kwargs)[0]
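# Illustration (not part of the uploaded diff): a hedged end-to-end sketch of
# the three helpers above. The model name follows the docstring examples;
# 'demo/demo.JPEG' is a placeholder path and weights download on first use.
from mmpretrain import get_model, inference_model, list_models

print(list_models('resnet*in1k')[:3])                    # browse model names
model = get_model('resnet50_8xb32_in1k', pretrained=True, device='cpu')
print(inference_model(model, 'demo/demo.JPEG')['pred_class'])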
mmpretrain/apis/multimodal_retrieval.py ADDED
@@ -0,0 +1,603 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ from copy import deepcopy
+ from pathlib import Path
+ from typing import Callable, List, Optional, Tuple, Union
+
+ import mmengine
+ import numpy as np
+ import torch
+ from mmcv.image import imread
+ from mmengine.config import Config
+ from mmengine.dataset import BaseDataset, Compose, default_collate
+
+ from mmpretrain.registry import TRANSFORMS
+ from mmpretrain.structures import DataSample
+ from mmpretrain.utils import track
+ from .base import BaseInferencer
+ from .base import InputType as ImageType
+ from .base import ModelType
+ from .model import list_models
+
+
+ def filter_transforms(transforms: list, data_info: dict):
+     """Filter the pipeline to avoid KeyError with partial data info."""
+     data_info = deepcopy(data_info)
+     filtered_transforms = []
+     for t in transforms:
+         try:
+             data_info = t(data_info)
+             filtered_transforms.append(t)
+         except KeyError:
+             pass
+     return filtered_transforms
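# Illustration (not part of the uploaded diff): a minimal, self-contained
# sketch of `filter_transforms`. The two toy callables are hypothetical
# stand-ins for real pipeline transforms; any step that raises KeyError on
# the partial data info is dropped, which is how the inferencers below split
# one shared test pipeline into an image pipeline and a text pipeline.
def resize_img(info):
    info['img'] = info['img'][:224]  # raises KeyError if there is no image
    return info

def tokenize_text(info):
    info['tokens'] = info['text'].split()  # raises KeyError if there is no text
    return info

shared = [resize_img, tokenize_text]
img_pipeline = filter_transforms(shared, {'img': [[0, 0, 0]] * 256})
text_pipeline = filter_transforms(shared, {'text': 'a photo of a cat'})
assert img_pipeline == [resize_img] and text_pipeline == [tokenize_text]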
33
+
34
+
35
+ class TextToImageRetrievalInferencer(BaseInferencer):
36
+ """The inferencer for text to image retrieval.
37
+
38
+ Args:
39
+ model (BaseModel | str | Config): A model name or a path to the config
40
+ file, or a :obj:`BaseModel` object. The model name can be found
41
+ by ``TextToImageRetrievalInferencer.list_models()`` and you can also
42
+ query it in :doc:`/modelzoo_statistics`.
43
+ prototype (str | list | dict | DataLoader | BaseDataset): The images to
44
+ be retrieved. It can be the following types:
45
+
46
+ - str: The directory of the the images.
47
+ - list: A list of path of the images.
48
+ - dict: A config dict of the a prototype dataset.
49
+ - BaseDataset: A prototype dataset.
50
+ - DataLoader: A data loader to load the prototype data.
51
+
52
+ prototype_cache (str, optional): The path of the generated prototype
53
+ features. If exists, directly load the cache instead of re-generate
54
+ the prototype features. If not exists, save the generated features
55
+ to the path. Defaults to None.
56
+ fast_match (bool): Some algorithms will record extra image features for
57
+ further matching, which may consume large memory, set True to avoid
58
+ this behavior. Defaults to True.
59
+ pretrained (str, optional): Path to the checkpoint. If None, it will
60
+ try to find a pre-defined weight from the model you specified
61
+ (only work if the ``model`` is a model name). Defaults to None.
62
+ device (str, optional): Device to run inference. If None, the available
63
+ device will be automatically used. Defaults to None.
64
+ **kwargs: Other keyword arguments to initialize the model (only work if
65
+ the ``model`` is a model name).
66
+
67
+ Example:
68
+ >>> from mmpretrain import TextToImageRetrievalInferencer
69
+ >>> inferencer = TextToImageRetrievalInferencer(
70
+ ... 'blip-base_3rdparty_retrieval',
71
+ ... prototype='./demo/',
72
+ ... prototype_cache='t2i_retri.pth')
73
+ >>> inferencer('A cat and a dog.')[0]
74
+ {'match_score': tensor(0.3855, device='cuda:0'),
75
+ 'sample_idx': 1,
76
+ 'sample': {'img_path': './demo/cat-dog.png'}}
77
+ """ # noqa: E501
78
+
79
+ visualize_kwargs: set = {
80
+ 'draw_score', 'show_dir', 'show', 'wait_time', 'figsize', 'topk'
81
+ }
82
+ postprocess_kwargs: set = {'topk'}
83
+
84
+ def __init__(self,
85
+ model: ModelType,
86
+ prototype,
87
+ prototype_cache=None,
88
+ fast_match=True,
89
+ prepare_batch_size=8,
90
+ pretrained: Union[bool, str] = True,
91
+ device: Union[str, torch.device, None] = None,
92
+ **kwargs) -> None:
93
+ super().__init__(
94
+ model=model, pretrained=pretrained, device=device, **kwargs)
95
+
96
+ self.img_pipeline, self.text_pipeline = self.pipeline
97
+
98
+ if hasattr(self.model, 'fast_match'):
99
+ self.model.fast_match = fast_match
100
+
101
+ self.prototype_dataset = self._prepare_prototype(
102
+ prototype, prototype_cache, batch_size=prepare_batch_size)
103
+
104
+ def _prepare_prototype(self, prototype, cache=None, batch_size=8):
105
+ from mmengine.dataset import DefaultSampler
106
+ from torch.utils.data import DataLoader
107
+
108
+ def build_dataloader(dataset):
109
+ return DataLoader(
110
+ dataset,
111
+ batch_size=batch_size,
112
+ collate_fn=default_collate,
113
+ sampler=DefaultSampler(dataset, shuffle=False),
114
+ persistent_workers=False,
115
+ )
116
+
117
+ if isinstance(prototype, str):
118
+ # A directory path of images
119
+ prototype = dict(
120
+ type='CustomDataset', with_label=False, data_root=prototype)
121
+
122
+ if isinstance(prototype, list):
123
+ test_pipeline = [dict(type='LoadImageFromFile'), self.img_pipeline]
124
+ dataset = BaseDataset(
125
+ lazy_init=True, serialize_data=False, pipeline=test_pipeline)
126
+ dataset.data_list = [{
127
+ 'sample_idx': i,
128
+ 'img_path': file
129
+ } for i, file in enumerate(prototype)]
130
+ dataset._fully_initialized = True
131
+ dataloader = build_dataloader(dataset)
132
+ elif isinstance(prototype, dict):
133
+ # A config of dataset
134
+ from mmpretrain.registry import DATASETS
135
+ test_pipeline = [dict(type='LoadImageFromFile'), self.img_pipeline]
136
+ prototype.setdefault('pipeline', test_pipeline)
137
+ dataset = DATASETS.build(prototype)
138
+ dataloader = build_dataloader(dataset)
139
+ elif isinstance(prototype, list):
140
+ test_pipeline = [dict(type='LoadImageFromFile'), self.img_pipeline]
141
+ dataset = BaseDataset(
142
+ lazy_init=True, serialize_data=False, pipeline=test_pipeline)
143
+ dataset.data_list = [{
144
+ 'sample_idx': i,
145
+ 'img_path': file
146
+ } for i, file in enumerate(prototype)]
147
+ dataset._fully_initialized = True
148
+ dataloader = build_dataloader(dataset)
149
+ elif isinstance(prototype, DataLoader):
150
+ dataset = prototype.dataset
151
+ dataloader = prototype
152
+ elif isinstance(prototype, BaseDataset):
153
+ dataset = prototype
154
+ dataloader = build_dataloader(dataset)
155
+ else:
156
+ raise TypeError(f'Unsupported prototype type {type(prototype)}.')
157
+
158
+ if cache is not None and Path(cache).exists():
159
+ self.prototype = torch.load(cache)
160
+ else:
161
+ prototype = []
162
+ for data_batch in track(dataloader, 'Prepare prototype...'):
163
+ with torch.no_grad():
164
+ data_batch = self.model.data_preprocessor(
165
+ data_batch, False)
166
+ feats = self.model._run_forward(data_batch, mode='tensor')
167
+ prototype.append(feats)
168
+ prototype = {
169
+ k: torch.cat([d[k] for d in prototype])
170
+ for k in prototype[0]
171
+ }
172
+ self.prototype = prototype
173
+
174
+ from mmengine.logging import MMLogger
175
+ logger = MMLogger.get_current_instance()
176
+ if cache is None:
177
+ logger.info('The prototype has been prepared, you can use '
178
+ '`save_prototype` to dump it into a pickle '
179
+ 'file for the future usage.')
180
+ elif not Path(cache).exists():
181
+ self.save_prototype(cache)
182
+ logger.info(f'The prototype has been saved at {cache}.')
183
+
184
+ return dataset
185
+
186
+ def save_prototype(self, path):
187
+ torch.save(self.prototype, path)
188
+
189
+ def __call__(self,
190
+ inputs: ImageType,
191
+ return_datasamples: bool = False,
192
+ batch_size: int = 1,
193
+ **kwargs) -> dict:
194
+ """Call the inferencer.
195
+
196
+ Args:
197
+ inputs (str | array | list): The image path or array, or a list of
198
+ images.
199
+ return_datasamples (bool): Whether to return results as
200
+ :obj:`DataSample`. Defaults to False.
201
+ batch_size (int): Batch size. Defaults to 1.
202
+ resize (int, optional): Resize the long edge of the image to the
203
+ specified length before visualization. Defaults to None.
204
+ draw_score (bool): Whether to draw the match scores.
205
+ Defaults to True.
206
+ show (bool): Whether to display the visualization result in a
207
+ window. Defaults to False.
208
+ wait_time (float): The display time (s). Defaults to 0, which means
209
+ "forever".
210
+ show_dir (str, optional): If not None, save the visualization
211
+ results in the specified directory. Defaults to None.
212
+
213
+ Returns:
214
+ list: The inference results.
215
+ """
216
+ return super().__call__(inputs, return_datasamples, batch_size,
217
+ **kwargs)
218
+
219
+ @torch.no_grad()
220
+ def forward(self, data: dict, **kwargs):
221
+ """Feed the inputs to the model."""
222
+ data = self.model.data_preprocessor(data, False)
223
+ data_samples = data['data_samples']
224
+ feats = self.prototype.copy()
225
+ feats.update(self.model.extract_feat(data_samples=data_samples))
226
+ return self.model.predict_all(feats, data_samples, cal_i2t=False)[0]
227
+
228
+ def _init_pipeline(self, cfg: Config) -> Callable:
229
+ test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
230
+ test_transfroms = [TRANSFORMS.build(t) for t in test_pipeline_cfg]
231
+ img_info = {'img': np.zeros((224, 224, 3), dtype=np.uint8)}
232
+ text_info = {'text': 'example'}
233
+ img_pipeline = Compose(filter_transforms(test_transfroms, img_info))
234
+ text_pipeline = Compose(filter_transforms(test_transfroms, text_info))
235
+ return img_pipeline, text_pipeline
236
+
237
+ def preprocess(self, inputs: List[str], batch_size: int = 1):
238
+
239
+ def process_text(input_: str):
240
+ return self.text_pipeline({'text': input_})
241
+
242
+ chunked_data = self._get_chunk_data(
243
+ map(process_text, inputs), batch_size)
244
+ yield from map(default_collate, chunked_data)
245
+
246
+ def visualize(self,
247
+ ori_inputs: List[str],
248
+ preds: List[DataSample],
249
+ topk: int = 3,
250
+ figsize: Tuple[int, int] = (16, 9),
251
+ show: bool = False,
252
+ wait_time: int = 0,
253
+ draw_score=True,
254
+ show_dir=None):
255
+ if not show and show_dir is None:
256
+ return None
257
+
258
+ if self.visualizer is None:
259
+ from mmpretrain.visualization import UniversalVisualizer
260
+ self.visualizer = UniversalVisualizer()
261
+
262
+ visualization = []
263
+ for i, (text, data_sample) in enumerate(zip(ori_inputs, preds)):
264
+ name = str(i)
265
+
266
+ if show_dir is not None:
267
+ show_dir = Path(show_dir)
268
+ show_dir.mkdir(exist_ok=True)
269
+ out_file = str((show_dir / name).with_suffix('.png'))
270
+ else:
271
+ out_file = None
272
+
273
+ self.visualizer.visualize_t2i_retrieval(
274
+ text,
275
+ data_sample,
276
+ self.prototype_dataset,
277
+ topk=topk,
278
+ fig_cfg=dict(figsize=figsize),
279
+ draw_score=draw_score,
280
+ show=show,
281
+ wait_time=wait_time,
282
+ name=name,
283
+ out_file=out_file)
284
+ visualization.append(self.visualizer.get_image())
285
+ if show:
286
+ self.visualizer.close()
287
+ return visualization
288
+
289
+ def postprocess(
290
+ self,
291
+ preds: List[DataSample],
292
+ visualization: List[np.ndarray],
293
+ return_datasamples=False,
294
+ topk=1,
295
+ ) -> dict:
296
+ if return_datasamples:
297
+ return preds
298
+
299
+ results = []
300
+ for data_sample in preds:
301
+ match_scores, indices = torch.topk(data_sample.pred_score, k=topk)
302
+ matches = []
303
+ for match_score, sample_idx in zip(match_scores, indices):
304
+ sample = self.prototype_dataset.get_data_info(
305
+ sample_idx.item())
306
+ sample_idx = sample.pop('sample_idx')
307
+ matches.append({
308
+ 'match_score': match_score,
309
+ 'sample_idx': sample_idx,
310
+ 'sample': sample
311
+ })
312
+ results.append(matches)
313
+
314
+ return results
315
+
316
+ @staticmethod
317
+ def list_models(pattern: Optional[str] = None):
318
+ """List all available model names.
319
+
320
+ Args:
321
+ pattern (str | None): A wildcard pattern to match model names.
322
+
323
+ Returns:
324
+ List[str]: a list of model names.
325
+ """
326
+ return list_models(pattern=pattern, task='Text-To-Image Retrieval')
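# Illustration (not part of the uploaded diff): a hedged usage sketch of the
# text-to-image inferencer above, following its docstring example. './demo/'
# is a placeholder image folder; the prototype features are written to the
# cache file on the first run and reloaded on later runs.
from mmpretrain import TextToImageRetrievalInferencer

inferencer = TextToImageRetrievalInferencer(
    'blip-base_3rdparty_retrieval',
    prototype='./demo/',
    prototype_cache='t2i_retri.pth')
for match in inferencer('A cat and a dog.', topk=3)[0]:
    print(match['match_score'].item(), match['sample']['img_path'])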
327
+
328
+
329
+ class ImageToTextRetrievalInferencer(BaseInferencer):
330
+ """The inferencer for image to text retrieval.
331
+
332
+ Args:
333
+ model (BaseModel | str | Config): A model name or a path to the config
334
+ file, or a :obj:`BaseModel` object. The model name can be found
335
+ by ``ImageToTextRetrievalInferencer.list_models()`` and you can
336
+ also query it in :doc:`/modelzoo_statistics`.
337
+ prototype (str | list | dict | DataLoader, BaseDataset): The images to
338
+ be retrieved. It can be the following types:
339
+
340
+ - str: The file path to load the string list.
341
+ - list: A list of string.
342
+
343
+ prototype_cache (str, optional): The path of the generated prototype
344
+ features. If exists, directly load the cache instead of re-generate
345
+ the prototype features. If not exists, save the generated features
346
+ to the path. Defaults to None.
347
+ fast_match (bool): Some algorithms will record extra image features for
348
+ further matching, which may consume large memory, set True to avoid
349
+ this behavior. Defaults to True.
350
+ pretrained (str, optional): Path to the checkpoint. If None, it will
351
+ try to find a pre-defined weight from the model you specified
352
+ (only work if the ``model`` is a model name). Defaults to None.
353
+ device (str, optional): Device to run inference. If None, the available
354
+ device will be automatically used. Defaults to None.
355
+ **kwargs: Other keyword arguments to initialize the model (only work if
356
+ the ``model`` is a model name).
357
+
358
+ Example:
359
+ >>> from mmpretrain import ImageToTextRetrievalInferencer
360
+ >>> inferencer = ImageToTextRetrievalInferencer(
361
+ ... 'blip-base_3rdparty_retrieval',
362
+ ... prototype=['cat', 'dog', 'snake', 'bird'],
363
+ ... prototype_cache='i2t_retri.pth')
364
+ >>> inferencer('demo/bird.JPEG')[0]
365
+ {'match_score': tensor(0.3855, device='cuda:0'),
366
+ 'sample_idx': 1,
367
+ 'sample': {'img_path': './demo/cat-dog.png'}}
368
+ """ # noqa: E501
369
+
370
+ visualize_kwargs: set = {
371
+ 'draw_score', 'resize', 'show_dir', 'show', 'wait_time', 'topk'
372
+ }
373
+ postprocess_kwargs: set = {'topk'}
374
+
375
+ def __init__(self,
376
+ model: ModelType,
377
+ prototype,
378
+ prototype_cache=None,
379
+ fast_match=True,
380
+ prepare_batch_size=8,
381
+ pretrained: Union[bool, str] = True,
382
+ device: Union[str, torch.device, None] = None,
383
+ **kwargs) -> None:
384
+ super().__init__(
385
+ model=model, pretrained=pretrained, device=device, **kwargs)
386
+
387
+ self.img_pipeline, self.text_pipeline = self.pipeline
388
+
389
+ if hasattr(self.model, 'fast_match'):
390
+ self.model.fast_match = fast_match
391
+
392
+ self.prototype_dataset = self._prepare_prototype(
393
+ prototype, cache=prototype_cache, batch_size=prepare_batch_size)
394
+
395
+ def _prepare_prototype(self, prototype, cache=None, batch_size=8):
396
+ from mmengine.dataset import DefaultSampler
397
+ from torch.utils.data import DataLoader
398
+
399
+ def build_dataloader(dataset):
400
+ return DataLoader(
401
+ [
402
+ self.text_pipeline({
403
+ 'sample_idx': i,
404
+ 'text': text
405
+ }) for i, text in enumerate(dataset)
406
+ ],
407
+ batch_size=batch_size,
408
+ collate_fn=default_collate,
409
+ sampler=DefaultSampler(dataset, shuffle=False),
410
+ persistent_workers=False,
411
+ )
412
+
413
+ if isinstance(prototype, str):
414
+ # A file path of a list of string
415
+ dataset = mmengine.list_from_file(prototype)
416
+ elif mmengine.utils.is_seq_of(prototype, str):
417
+ dataset = prototype
418
+ else:
419
+ raise TypeError(f'Unsupported prototype type {type(prototype)}.')
420
+
421
+ dataloader = build_dataloader(dataset)
422
+
423
+ if cache is not None and Path(cache).exists():
424
+ self.prototype = torch.load(cache)
425
+ else:
426
+ prototype = []
427
+ for data_batch in track(dataloader, 'Prepare prototype...'):
428
+ with torch.no_grad():
429
+ data_batch = self.model.data_preprocessor(
430
+ data_batch, False)
431
+ feats = self.model._run_forward(data_batch, mode='tensor')
432
+ prototype.append(feats)
433
+ prototype = {
434
+ k: torch.cat([d[k] for d in prototype])
435
+ for k in prototype[0]
436
+ }
437
+ self.prototype = prototype
438
+
439
+ from mmengine.logging import MMLogger
440
+ logger = MMLogger.get_current_instance()
441
+ if cache is None:
442
+ logger.info('The prototype has been prepared, you can use '
443
+ '`save_prototype` to dump it into a pickle '
444
+ 'file for the future usage.')
445
+ elif not Path(cache).exists():
446
+ self.save_prototype(cache)
447
+ logger.info(f'The prototype has been saved at {cache}.')
448
+
449
+ return dataset
450
+
451
+ def save_prototype(self, path):
452
+ torch.save(self.prototype, path)
453
+
454
+ def __call__(self,
455
+ inputs: ImageType,
456
+ return_datasamples: bool = False,
457
+ batch_size: int = 1,
458
+ **kwargs) -> dict:
459
+ """Call the inferencer.
460
+
461
+ Args:
462
+ inputs (str | array | list): The image path or array, or a list of
463
+ images.
464
+ return_datasamples (bool): Whether to return results as
465
+ :obj:`DataSample`. Defaults to False.
466
+ batch_size (int): Batch size. Defaults to 1.
467
+ resize (int, optional): Resize the long edge of the image to the
468
+ specified length before visualization. Defaults to None.
469
+ draw_score (bool): Whether to draw the match scores.
470
+ Defaults to True.
471
+ show (bool): Whether to display the visualization result in a
472
+ window. Defaults to False.
473
+ wait_time (float): The display time (s). Defaults to 0, which means
474
+ "forever".
475
+ show_dir (str, optional): If not None, save the visualization
476
+ results in the specified directory. Defaults to None.
477
+
478
+ Returns:
479
+ list: The inference results.
480
+ """
481
+ return super().__call__(inputs, return_datasamples, batch_size,
482
+ **kwargs)
483
+
484
+ @torch.no_grad()
485
+ def forward(self, data: dict, **kwargs):
486
+ """Feed the inputs to the model."""
487
+ data = self.model.data_preprocessor(data, False)
488
+ feats = self.prototype.copy()
489
+ feats.update(self.model.extract_feat(images=data['images']))
490
+ return self.model.predict_all(
491
+ feats, data['data_samples'], cal_t2i=False)[0]
492
+
493
+ def _init_pipeline(self, cfg: Config) -> Callable:
494
+ test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
495
+ test_transfroms = [TRANSFORMS.build(t) for t in test_pipeline_cfg]
496
+ img_info = {'img': np.zeros((224, 224, 3), dtype=np.uint8)}
497
+ text_info = {'text': 'example'}
498
+ img_pipeline = Compose(filter_transforms(test_transfroms, img_info))
499
+ text_pipeline = Compose(filter_transforms(test_transfroms, text_info))
500
+ return img_pipeline, text_pipeline
501
+
502
+ def preprocess(self, inputs: List[ImageType], batch_size: int = 1):
503
+
504
+ def load_image(input_):
505
+ img = imread(input_)
506
+ if img is None:
507
+ raise ValueError(f'Failed to read image {input_}.')
508
+ return dict(
509
+ img=img,
510
+ img_shape=img.shape[:2],
511
+ ori_shape=img.shape[:2],
512
+ )
513
+
514
+ pipeline = Compose([load_image, self.img_pipeline])
515
+
516
+ chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size)
517
+ yield from map(default_collate, chunked_data)
518
+
519
+ def visualize(self,
520
+ ori_inputs: List[ImageType],
521
+ preds: List[DataSample],
522
+ topk: int = 3,
523
+ resize: Optional[int] = 224,
524
+ show: bool = False,
525
+ wait_time: int = 0,
526
+ draw_score=True,
527
+ show_dir=None):
528
+ if not show and show_dir is None:
529
+ return None
530
+
531
+ if self.visualizer is None:
532
+ from mmpretrain.visualization import UniversalVisualizer
533
+ self.visualizer = UniversalVisualizer()
534
+
535
+ visualization = []
536
+ for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)):
537
+ image = imread(input_)
538
+ if isinstance(input_, str):
539
+ # The image loaded from path is BGR format.
540
+ image = image[..., ::-1]
541
+ name = Path(input_).stem
542
+ else:
543
+ name = str(i)
544
+
545
+ if show_dir is not None:
546
+ show_dir = Path(show_dir)
547
+ show_dir.mkdir(exist_ok=True)
548
+ out_file = str((show_dir / name).with_suffix('.png'))
549
+ else:
550
+ out_file = None
551
+
552
+ self.visualizer.visualize_i2t_retrieval(
553
+ image,
554
+ data_sample,
555
+ self.prototype_dataset,
556
+ topk=topk,
557
+ resize=resize,
558
+ draw_score=draw_score,
559
+ show=show,
560
+ wait_time=wait_time,
561
+ name=name,
562
+ out_file=out_file)
563
+ visualization.append(self.visualizer.get_image())
564
+ if show:
565
+ self.visualizer.close()
566
+ return visualization
567
+
568
+ def postprocess(
569
+ self,
570
+ preds: List[DataSample],
571
+ visualization: List[np.ndarray],
572
+ return_datasamples=False,
573
+ topk=1,
574
+ ) -> dict:
575
+ if return_datasamples:
576
+ return preds
577
+
578
+ results = []
579
+ for data_sample in preds:
580
+ match_scores, indices = torch.topk(data_sample.pred_score, k=topk)
581
+ matches = []
582
+ for match_score, sample_idx in zip(match_scores, indices):
583
+ text = self.prototype_dataset[sample_idx.item()]
584
+ matches.append({
585
+ 'match_score': match_score,
586
+ 'sample_idx': sample_idx,
587
+ 'text': text
588
+ })
589
+ results.append(matches)
590
+
591
+ return results
592
+
593
+ @staticmethod
594
+ def list_models(pattern: Optional[str] = None):
595
+ """List all available model names.
596
+
597
+ Args:
598
+ pattern (str | None): A wildcard pattern to match model names.
599
+
600
+ Returns:
601
+ List[str]: a list of model names.
602
+ """
603
+ return list_models(pattern=pattern, task='Image-To-Text Retrieval')
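# Illustration (not part of the uploaded diff): a hedged usage sketch of the
# image-to-text inferencer above, following its docstring example; the image
# path is a placeholder.
from mmpretrain import ImageToTextRetrievalInferencer

inferencer = ImageToTextRetrievalInferencer(
    'blip-base_3rdparty_retrieval',
    prototype=['cat', 'dog', 'snake', 'bird'])
for match in inferencer('demo/bird.JPEG', topk=2)[0]:
    print(match['match_score'].item(), match['text'])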
mmpretrain/apis/nlvr.py ADDED
@@ -0,0 +1,150 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ from copy import deepcopy
+ from typing import Callable, List, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+ from mmcv.image import imread
+ from mmengine.config import Config
+ from mmengine.dataset import Compose, default_collate
+
+ from mmpretrain.registry import TRANSFORMS
+ from mmpretrain.structures import DataSample
+ from .base import BaseInferencer
+ from .model import list_models
+
+ InputType = Tuple[Union[str, np.ndarray], Union[str, np.ndarray], str]
+ InputsType = Union[List[InputType], InputType]
+
+
+ class NLVRInferencer(BaseInferencer):
+     """The inferencer for Natural Language for Visual Reasoning.
+
+     Args:
+         model (BaseModel | str | Config): A model name or a path to the config
+             file, or a :obj:`BaseModel` object. The model name can be found
+             by ``NLVRInferencer.list_models()`` and you can also
+             query it in :doc:`/modelzoo_statistics`.
+         pretrained (str, optional): Path to the checkpoint. If None, it will
+             try to find a pre-defined weight from the model you specified
+             (only work if the ``model`` is a model name). Defaults to None.
+         device (str, optional): Device to run inference. If None, the available
+             device will be automatically used. Defaults to None.
+         **kwargs: Other keyword arguments to initialize the model (only work if
+             the ``model`` is a model name).
+     """
+
+     visualize_kwargs: set = {
+         'resize', 'draw_score', 'show', 'show_dir', 'wait_time'
+     }
+
+     def __call__(self,
+                  inputs: InputsType,
+                  return_datasamples: bool = False,
+                  batch_size: int = 1,
+                  **kwargs) -> dict:
+         """Call the inferencer.
+
+         Args:
+             inputs (tuple, List[tuple]): The input data tuples, every tuple
+                 should include three items (left image, right image, text).
+                 The image can be a path or numpy array.
+             return_datasamples (bool): Whether to return results as
+                 :obj:`DataSample`. Defaults to False.
+             batch_size (int): Batch size. Defaults to 1.
+             resize (int, optional): Resize the short edge of the image to the
+                 specified length before visualization. Defaults to None.
+             draw_score (bool): Whether to draw the prediction scores
+                 of prediction categories. Defaults to True.
+             show (bool): Whether to display the visualization result in a
+                 window. Defaults to False.
+             wait_time (float): The display time (s). Defaults to 0, which means
+                 "forever".
+             show_dir (str, optional): If not None, save the visualization
+                 results in the specified directory. Defaults to None.
+
+         Returns:
+             list: The inference results.
+         """
+         assert isinstance(inputs, (tuple, list))
+         if isinstance(inputs, tuple):
+             inputs = [inputs]
+         for input_ in inputs:
+             assert isinstance(input_, tuple)
+             assert len(input_) == 3
+
+         return super().__call__(
+             inputs,
+             return_datasamples=return_datasamples,
+             batch_size=batch_size,
+             **kwargs)
+
+     def _init_pipeline(self, cfg: Config) -> Callable:
+         test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
+         assert test_pipeline_cfg[0]['type'] == 'ApplyToList'
+
+         list_pipeline = deepcopy(test_pipeline_cfg[0])
+         if list_pipeline.scatter_key == 'img_path':
+             # Remove `LoadImageFromFile`
+             list_pipeline.transforms.pop(0)
+             list_pipeline.scatter_key = 'img'
+
+         test_pipeline = Compose(
+             [TRANSFORMS.build(list_pipeline)] +
+             [TRANSFORMS.build(t) for t in test_pipeline_cfg[1:]])
+         return test_pipeline
+
+     def preprocess(self, inputs: InputsType, batch_size: int = 1):
+
+         def load_image(input_):
+             img1 = imread(input_[0])
+             img2 = imread(input_[1])
+             text = input_[2]
+             if img1 is None:
+                 raise ValueError(f'Failed to read image {input_[0]}.')
+             if img2 is None:
+                 raise ValueError(f'Failed to read image {input_[1]}.')
+             return dict(
+                 img=[img1, img2],
+                 img_shape=[img1.shape[:2], img2.shape[:2]],
+                 ori_shape=[img1.shape[:2], img2.shape[:2]],
+                 text=text,
+             )
+
+         pipeline = Compose([load_image, self.pipeline])
+
+         chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size)
+         yield from map(default_collate, chunked_data)
+
+     def postprocess(self,
+                     preds: List[DataSample],
+                     visualization: List[np.ndarray],
+                     return_datasamples=False) -> dict:
+         if return_datasamples:
+             return preds
+
+         results = []
+         for data_sample in preds:
+             pred_scores = data_sample.pred_score
+             pred_score = float(torch.max(pred_scores).item())
+             pred_label = torch.argmax(pred_scores).item()
+             result = {
+                 'pred_scores': pred_scores.detach().cpu().numpy(),
+                 'pred_label': pred_label,
+                 'pred_score': pred_score,
+             }
+             results.append(result)
+
+         return results
+
+     @staticmethod
+     def list_models(pattern: Optional[str] = None):
+         """List all available model names.
+
+         Args:
+             pattern (str | None): A wildcard pattern to match model names.
+
+         Returns:
+             List[str]: a list of model names.
+         """
+         return list_models(pattern=pattern, task='NLVR')
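# Illustration (not part of the uploaded diff): a hedged sketch of the NLVR
# inferencer above. The image paths and the statement are placeholders; pick
# a concrete model name from NLVRInferencer.list_models().
from mmpretrain import NLVRInferencer

inferencer = NLVRInferencer(NLVRInferencer.list_models()[0])
result = inferencer(('demo/left.png', 'demo/right.png',
                     'There is a dog in at least one of the images.'))[0]
print(result['pred_label'], result['pred_score'])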
mmpretrain/apis/utils.py ADDED
@@ -0,0 +1,270 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import os
+ from collections import defaultdict
+ from contextlib import contextmanager
+ from itertools import chain
+ from typing import Dict, List, Optional, Union
+
+ import torch
+ import torch.nn as nn
+
+ from mmpretrain.utils import require
+
+
+ @require('torch>=1.9.0', 'https://pytorch.org/get-started/locally/')
+ @require('accelerate')
+ def dispatch_model(
+     model,
+     device_map: Union[str, dict],
+     max_memory: Optional[dict] = None,
+     no_split_module_classes: Optional[List[str]] = None,
+     offload_folder: str = None,
+     offload_buffers: bool = False,
+     preload_module_classes: Optional[List[str]] = None,
+ ):
+     """Split and dispatch a model across devices.
+
+     This function depends on the `accelerate` package. Refer to
+     https://huggingface.co/docs/accelerate/main/en/usage_guides/big_modeling
+
+     Args:
+         model (torch.nn.Module): The model to dispatch.
+         device_map (str | dict | None): A map that specifies where each
+             submodule should go. It doesn't need to be refined to each
+             parameter/buffer name; once a given module name is included, all
+             of its submodules will be sent to the same device. You can use
+             `device_map="auto"` to automatically generate the device map.
+             Defaults to None.
+         max_memory (dict | None): A dictionary mapping device identifiers to
+             maximum memory. Defaults to the maximum memory available on each
+             GPU and the available CPU RAM if unset. Defaults to None.
+         no_split_module_classes (List[str] | None): A list of layer class names
+             that should never be split across devices (for instance, any layer
+             that has a residual connection). If None, try to get the settings
+             from the model class. Defaults to None.
+         offload_folder (str | None): If the `device_map` contains the value
+             `"disk"`, the folder where the weights will be offloaded.
+         offload_buffers (bool): In the layers that are offloaded on the CPU
+             or the hard drive, whether or not to offload the buffers as
+             well as the parameters. Defaults to False.
+         preload_module_classes (List[str] | None): A list of classes whose
+             instances should load all their weights (even in the submodules) at
+             the beginning of the forward. This should only be used for classes
+             that have submodules which are registered but not called directly
+             during the forward, for instance if a `dense` linear layer is
+             registered, but at forward, `dense.weight` and `dense.bias` are
+             used in some operations instead of calling `dense` directly.
+             Defaults to None.
+     """
+     from accelerate import dispatch_model, infer_auto_device_map
+
+     # Check valid device_map string.
+     valid_map_option = ['auto', 'balanced', 'balanced_low_0', 'sequential']
+     if isinstance(device_map, str) and device_map not in valid_map_option:
+         raise ValueError('If passing a string for `device_map`, please choose '
+                          f'from {valid_map_option}.')
+
+     # Generate device map automatically
+     if isinstance(device_map, str):
+         if no_split_module_classes is None:
+             no_split_module_classes = getattr(model, '_no_split_modules', None)
+         if no_split_module_classes is None:
+             raise ValueError(f'{model.__class__.__name__} does not support '
+                              f"`device_map='{device_map}'` yet.")
+
+         if device_map != 'sequential':
+             from accelerate.utils import get_balanced_memory
+             max_memory = get_balanced_memory(
+                 model,
+                 max_memory=max_memory,
+                 no_split_module_classes=no_split_module_classes,
+                 dtype=None,
+                 low_zero=(device_map == 'balanced_low_0'),
+             )
+             max_memory[0] *= 0.9
+         device_map = infer_auto_device_map(
+             model,
+             max_memory=max_memory,
+             no_split_module_classes=no_split_module_classes,
+             dtype=None,
+         )
+
+     if 'disk' in device_map.values():
+         if offload_folder is None:
+             raise ValueError(
+                 'The current `device_map` had weights offloaded to the disk. '
+                 'Please provide an `offload_folder` for them.')
+         os.makedirs(offload_folder, exist_ok=True)
+
+     main_device = next(
+         (d for d in device_map.values() if d not in ['cpu', 'disk']), 'cpu')
+
+     model = dispatch_model(
+         model,
+         device_map=device_map,
+         main_device=main_device,
+         offload_dir=offload_folder,
+         offload_buffers=offload_buffers,
+         preload_module_classes=preload_module_classes,
+     )
+     if hasattr(model, 'data_preprocessor'):
+         model.data_preprocessor._device = torch.device(main_device)
+     return model
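# Illustration (not part of the uploaded diff): a hedged sketch of sharding a
# large model across the available devices. Requires the `accelerate`
# package; the model name is a placeholder from the multimodal model zoo, and
# `get_model` routes `device_map`/`offload_folder` to `dispatch_model`.
from mmpretrain import get_model

model = get_model('blip2-opt2.7b_3rdparty-zeroshot_caption',
                  pretrained=True,
                  device_map='auto',
                  offload_folder='./offload')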
+
+
+ @contextmanager
+ def init_empty_weights(include_buffers: bool = False):
+     """A context manager under which models are initialized with all
+     parameters on the meta device.
+
+     With this context manager, we can create an empty model. Useful when just
+     initializing the model would blow the available RAM.
+
+     Besides moving the parameters to the meta device, this context manager
+     also skips loading checkpoints via `mmengine.runner.load_checkpoint` and
+     `transformers.PreTrainedModel.from_pretrained`.
+
+     Modified from https://github.com/huggingface/accelerate
+
+     Args:
+         include_buffers (bool): Whether to also put all buffers on the meta
+             device during initialization.
+     """
+     device = torch.device('meta')
+
+     # move parameters and buffers to the meta device
+     old_register_parameter = nn.Module.register_parameter
+     if include_buffers:
+         old_register_buffer = nn.Module.register_buffer
+         # See https://github.com/huggingface/accelerate/pull/699
+         tensor_constructors_to_patch = {
+             torch_function_name: getattr(torch, torch_function_name)
+             for torch_function_name in ['empty', 'zeros', 'ones', 'full']
+         }
+
+     def register_parameter(module, name, param):
+         old_register_parameter(module, name, param)
+         if param is not None:
+             param_cls = type(module._parameters[name])
+             kwargs = module._parameters[name].__dict__
+             module._parameters[name] = param_cls(
+                 module._parameters[name].to(device), **kwargs)
+
+     def register_buffer(module, name, buffer, *args, **kwargs):
+         old_register_buffer(module, name, buffer, *args, **kwargs)
+         if buffer is not None:
+             module._buffers[name] = module._buffers[name].to(device)
+
+     def patch_tensor_constructor(fn):
+
+         def wrapper(*args, **kwargs):
+             kwargs['device'] = device
+             return fn(*args, **kwargs)
+
+         return wrapper
+
+     # Patch load_checkpoint
+     import mmengine.runner.checkpoint as mmengine_load
+     old_load_checkpoint = mmengine_load.load_checkpoint
+
+     def patch_load_checkpoint(*args, **kwargs):
+         return {}
+
+     # Patch transformers `from_pretrained`
+     try:
+         from transformers import PreTrainedModel
+         from transformers.models.auto.auto_factory import (AutoConfig,
+                                                            _BaseAutoModelClass)
+         with_transformers = True
+     except ImportError:
+         with_transformers = False
+
+     @classmethod
+     def patch_auto_model(cls, pretrained_model_name_or_path, *model_args,
+                          **kwargs):
+         cfg = AutoConfig.from_pretrained(pretrained_model_name_or_path,
+                                          *model_args, **kwargs)
+         return cls.from_config(cfg)
+
+     @classmethod
+     def patch_pretrained_model(cls, pretrained_model_name_or_path, *model_args,
+                                **kwargs):
+         cfg = cls.config_class.from_pretrained(pretrained_model_name_or_path,
+                                                *model_args, **kwargs)
+         return cls(cfg)
+
+     if with_transformers:
+         old_pretrained_model = PreTrainedModel.from_pretrained
+         old_auto_model = _BaseAutoModelClass.from_pretrained
+
+     try:
+         nn.Module.register_parameter = register_parameter
+         mmengine_load.load_checkpoint = patch_load_checkpoint
+         if with_transformers:
+             PreTrainedModel.from_pretrained = patch_pretrained_model
+             _BaseAutoModelClass.from_pretrained = patch_auto_model
+         if include_buffers:
+             nn.Module.register_buffer = register_buffer
+             for func in tensor_constructors_to_patch.keys():
+                 tensor_constructor = patch_tensor_constructor(
+                     getattr(torch, func))
+                 setattr(torch, func, tensor_constructor)
+         yield
+     finally:
+         nn.Module.register_parameter = old_register_parameter
+         mmengine_load.load_checkpoint = old_load_checkpoint
+         if with_transformers:
+             PreTrainedModel.from_pretrained = old_pretrained_model
+             _BaseAutoModelClass.from_pretrained = old_auto_model
+         if include_buffers:
+             nn.Module.register_buffer = old_register_buffer
+             for func, ori in tensor_constructors_to_patch.items():
+                 setattr(torch, func, ori)
+
+
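# Illustration (not part of the uploaded diff): a hedged sketch of building a
# weight-less model skeleton with the context manager above. Parameters land
# on the meta device and checkpoint loading is skipped, so even very large
# models can be instantiated without much RAM; the model name is a placeholder.
from mmpretrain import get_model
from mmpretrain.apis.utils import init_empty_weights

with init_empty_weights():
    skeleton = get_model('resnet50_8xb32_in1k')
print(next(skeleton.parameters()).device)  # meta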
+ def compute_module_sizes(
+         model: nn.Module,
+         dtype: Union[str, torch.dtype, None] = None,
+         special_dtypes: Optional[Dict[str, Union[str, torch.dtype]]] = None):
+     """Compute the size of each submodule of a given model."""
+
+     def get_dtype(dtype):
+         if isinstance(dtype, str):
+             dtype = getattr(torch, dtype)
+         if dtype is not None:
+             assert isinstance(dtype, torch.dtype)
+         return dtype
+
+     def dtype_bytes(dtype: torch.dtype):
+         if dtype is torch.bool:
+             return 1
+         if dtype.is_floating_point:
+             return torch.finfo(dtype).bits / 8
+         else:
+             return torch.iinfo(dtype).bits / 8
+
+     if dtype is not None:
+         dtype = get_dtype(dtype)
+         dtype_size = dtype_bytes(dtype)
+
+     if special_dtypes is not None:
+         special_dtypes = {
+             key: dtype_bytes(get_dtype(dtype))
+             for key, dtype in special_dtypes.items()
+         }
+
+     module_sizes = defaultdict(int)
+     for name, tensor in chain(
+             model.named_parameters(recurse=True),
+             model.named_buffers(recurse=True)):
+         if special_dtypes is not None and name in special_dtypes:
+             size = tensor.numel() * special_dtypes[name]
+         elif dtype is None:
+             size = tensor.numel() * tensor.element_size()
+         else:
+             size = tensor.numel() * min(dtype_size, tensor.element_size())
+         name_parts = name.split('.')
+         for idx in range(len(name_parts) + 1):
+             module_sizes['.'.join(name_parts[:idx])] += size
+
+     return module_sizes
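# Illustration (not part of the uploaded diff): a hedged, self-contained
# sketch of `compute_module_sizes` above on a toy module. Sizes are in bytes
# and accumulate up the module hierarchy (the '' key is the whole model).
import torch.nn as nn
from mmpretrain.apis.utils import compute_module_sizes

toy = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
sizes = compute_module_sizes(toy)
print(sizes[''], sizes['0'], sizes['0.weight'])
print(compute_module_sizes(toy, dtype='float16')[''])  # size as if cast to fp16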
mmpretrain/apis/visual_grounding.py ADDED
@@ -0,0 +1,180 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ from pathlib import Path
+ from typing import Callable, List, Optional, Union
+
+ import numpy as np
+ from mmcv.image import imread
+ from mmengine.config import Config
+ from mmengine.dataset import Compose, default_collate
+
+ from mmpretrain.registry import TRANSFORMS
+ from mmpretrain.structures import DataSample
+ from .base import BaseInferencer
+ from .model import list_models
+
+
+ class VisualGroundingInferencer(BaseInferencer):
+     """The inferencer for visual grounding.
+
+     Args:
+         model (BaseModel | str | Config): A model name or a path to the config
+             file, or a :obj:`BaseModel` object. The model name can be found
+             by ``VisualGroundingInferencer.list_models()`` and you can also
+             query it in :doc:`/modelzoo_statistics`.
+         pretrained (str, optional): Path to the checkpoint. If None, it will
+             try to find a pre-defined weight from the model you specified
+             (only work if the ``model`` is a model name). Defaults to None.
+         device (str, optional): Device to run inference. If None, the available
+             device will be automatically used. Defaults to None.
+         **kwargs: Other keyword arguments to initialize the model (only work if
+             the ``model`` is a model name).
+
+     Example:
+         >>> from mmpretrain import VisualGroundingInferencer
+         >>> inferencer = VisualGroundingInferencer('ofa-base_3rdparty_refcoco')
+         >>> inferencer('demo/cat-dog.png', 'dog')[0]
+         {'pred_bboxes': tensor([[ 36.6000,  29.6000, 355.8000, 395.2000]])}
+     """  # noqa: E501
+
+     visualize_kwargs: set = {
+         'resize', 'show', 'show_dir', 'wait_time', 'line_width', 'bbox_color'
+     }
+
+     def __call__(self,
+                  images: Union[str, np.ndarray, list],
+                  texts: Union[str, list],
+                  return_datasamples: bool = False,
+                  batch_size: int = 1,
+                  **kwargs) -> dict:
+         """Call the inferencer.
+
+         Args:
+             images (str | array | list): The image path or array, or a list of
+                 images.
+             texts (str | list): The text to do visual grounding.
+             return_datasamples (bool): Whether to return results as
+                 :obj:`DataSample`. Defaults to False.
+             batch_size (int): Batch size. Defaults to 1.
+             resize (int, optional): Resize the short edge of the image to the
+                 specified length before visualization. Defaults to None.
+             show (bool): Whether to display the visualization result in a
+                 window. Defaults to False.
+             wait_time (float): The display time (s). Defaults to 0, which means
+                 "forever".
+             show_dir (str, optional): If not None, save the visualization
+                 results in the specified directory. Defaults to None.
+             line_width (int): The line width of the bbox. Defaults to 3.
+             bbox_color (str | tuple): The color of the bbox.
+                 Defaults to 'green'.
+
+         Returns:
+             list: The inference results.
+         """
+         if not isinstance(images, (list, tuple)):
+             assert isinstance(texts, str)
+             inputs = [{'img': images, 'text': texts}]
+         else:
+             inputs = []
+             for i in range(len(images)):
+                 input_ = {'img': images[i], 'text': texts[i]}
+                 inputs.append(input_)
+
+         return super().__call__(inputs, return_datasamples, batch_size,
+                                 **kwargs)
+
+     def _init_pipeline(self, cfg: Config) -> Callable:
+         test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
+         if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile':
+             # Image loading is finished in `self.preprocess`.
+             test_pipeline_cfg = test_pipeline_cfg[1:]
+         test_pipeline = Compose(
+             [TRANSFORMS.build(t) for t in test_pipeline_cfg])
+         return test_pipeline
+
+     def preprocess(self, inputs: List[dict], batch_size: int = 1):
+
+         def load_image(input_: dict):
+             img = imread(input_['img'])
+             if img is None:
+                 raise ValueError(f'Failed to read image {input_}.')
+             return {**input_, 'img': img}
+
+         pipeline = Compose([load_image, self.pipeline])
+
+         chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size)
+         yield from map(default_collate, chunked_data)
+
+     def visualize(self,
+                   ori_inputs: List[dict],
+                   preds: List[DataSample],
+                   show: bool = False,
+                   wait_time: int = 0,
+                   resize: Optional[int] = None,
+                   line_width: int = 3,
+                   bbox_color: Union[str, tuple] = 'green',
+                   show_dir=None):
+         if not show and show_dir is None:
+             return None
+
+         if self.visualizer is None:
+             from mmpretrain.visualization import UniversalVisualizer
+             self.visualizer = UniversalVisualizer()
+
+         visualization = []
+         for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)):
+             image = imread(input_['img'])
+             if isinstance(input_['img'], str):
+                 # The image loaded from a path is in BGR format.
+                 image = image[..., ::-1]
+                 name = Path(input_['img']).stem
+             else:
+                 name = str(i)
+
+             if show_dir is not None:
+                 show_dir = Path(show_dir)
+                 show_dir.mkdir(exist_ok=True)
+                 out_file = str((show_dir / name).with_suffix('.png'))
+             else:
+                 out_file = None
+
+             self.visualizer.visualize_visual_grounding(
+                 image,
+                 data_sample,
+                 resize=resize,
+                 show=show,
+                 wait_time=wait_time,
+                 line_width=line_width,
+                 bbox_color=bbox_color,
+                 name=name,
+                 out_file=out_file)
+             visualization.append(self.visualizer.get_image())
+             if show:
+                 self.visualizer.close()
+         return visualization
+
+     def postprocess(self,
+                     preds: List[DataSample],
+                     visualization: List[np.ndarray],
+                     return_datasamples=False) -> dict:
+         if return_datasamples:
+             return preds
+
+         results = []
+         for data_sample in preds:
+             results.append({'pred_bboxes': data_sample.get('pred_bboxes')})
+
+         return results
+
+     @staticmethod
+     def list_models(pattern: Optional[str] = None):
+         """List all available model names.
+
+         Args:
+             pattern (str | None): A wildcard pattern to match model names.
+
+         Returns:
+             List[str]: a list of model names.
+         """
+         return list_models(pattern=pattern, task='Visual Grounding')
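# Illustration (not part of the uploaded diff): a hedged sketch matching the
# docstring example above; the image path is a placeholder and the predicted
# box is returned in (x1, y1, x2, y2) pixel coordinates.
from mmpretrain import VisualGroundingInferencer

inferencer = VisualGroundingInferencer('ofa-base_3rdparty_refcoco')
result = inferencer('demo/cat-dog.png', 'dog', show_dir='./vg_out/')[0]
print(result['pred_bboxes'])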
mmpretrain/apis/visual_question_answering.py ADDED
@@ -0,0 +1,181 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from pathlib import Path
3
+ from typing import Callable, List, Optional, Union
4
+
5
+ import numpy as np
6
+ from mmcv.image import imread
7
+ from mmengine.config import Config
8
+ from mmengine.dataset import Compose, default_collate
9
+
10
+ from mmpretrain.registry import TRANSFORMS
11
+ from mmpretrain.structures import DataSample
12
+ from .base import BaseInferencer
13
+ from .model import list_models
14
+
15
+
16
+ class VisualQuestionAnsweringInferencer(BaseInferencer):
17
+ """The inferencer for visual question answering.
18
+
19
+ Args:
20
+ model (BaseModel | str | Config): A model name or a path to the config
21
+ file, or a :obj:`BaseModel` object. The model name can be found
22
+ by ``VisualQuestionAnsweringInferencer.list_models()`` and you can
23
+ also query it in :doc:`/modelzoo_statistics`.
24
+ pretrained (str, optional): Path to the checkpoint. If None, it will
25
+ try to find a pre-defined weight from the model you specified
26
+ (only work if the ``model`` is a model name). Defaults to None.
27
+ device (str, optional): Device to run inference. If None, the available
28
+ device will be automatically used. Defaults to None.
29
+ **kwargs: Other keyword arguments to initialize the model (only work if
30
+ the ``model`` is a model name).
31
+
32
+ Example:
33
+ >>> from mmpretrain import VisualQuestionAnsweringInferencer
34
+ >>> inferencer = VisualQuestionAnsweringInferencer('ofa-base_3rdparty-zeroshot_vqa')
35
+ >>> inferencer('demo/cat-dog.png', "What's the animal next to the dog?")[0]
36
+ {'question': "What's the animal next to the dog?", 'pred_answer': 'cat'}
37
+ """ # noqa: E501
38
+
39
+ visualize_kwargs: set = {'resize', 'show', 'show_dir', 'wait_time'}
40
+
41
+ def __call__(self,
42
+ images: Union[str, np.ndarray, list],
43
+ questions: Union[str, list],
44
+ return_datasamples: bool = False,
45
+ batch_size: int = 1,
46
+ objects: Optional[List[str]] = None,
47
+ **kwargs) -> dict:
48
+ """Call the inferencer.
49
+
50
+ Args:
51
+ images (str | array | list): The image path or array, or a list of
52
+ images.
53
+ questions (str | list): The question to the corresponding image.
54
+ return_datasamples (bool): Whether to return results as
55
+ :obj:`DataSample`. Defaults to False.
56
+ batch_size (int): Batch size. Defaults to 1.
57
+ objects (List[List[str]], optional): Some algorithms like OFA
58
+ fine-tuned VQA models require an extra object description list
59
+ for every image. Defaults to None.
60
+ resize (int, optional): Resize the short edge of the image to the
61
+ specified length before visualization. Defaults to None.
62
+ show (bool): Whether to display the visualization result in a
63
+ window. Defaults to False.
64
+ wait_time (float): The display time (s). Defaults to 0, which means
65
+ "forever".
66
+ show_dir (str, optional): If not None, save the visualization
67
+ results in the specified directory. Defaults to None.
68
+
69
+ Returns:
70
+ list: The inference results.
71
+ """
72
+ if not isinstance(images, (list, tuple)):
73
+ assert isinstance(questions, str)
74
+ inputs = [{'img': images, 'question': questions}]
75
+ if objects is not None:
76
+ assert isinstance(objects[0], str)
77
+ inputs[0]['objects'] = objects
78
+ else:
79
+ inputs = []
80
+ for i in range(len(images)):
81
+ input_ = {'img': images[i], 'question': questions[i]}
82
+ if objects is not None:
83
+ input_['objects'] = objects[i]
84
+ inputs.append(input_)
85
+
86
+ return super().__call__(inputs, return_datasamples, batch_size,
87
+ **kwargs)
88
+
89
+ def _init_pipeline(self, cfg: Config) -> Callable:
90
+ test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
91
+ if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile':
92
+ # Image loading is finished in `self.preprocess`.
93
+ test_pipeline_cfg = test_pipeline_cfg[1:]
94
+ test_pipeline = Compose(
95
+ [TRANSFORMS.build(t) for t in test_pipeline_cfg])
96
+ return test_pipeline
97
+
98
+ def preprocess(self, inputs: List[dict], batch_size: int = 1):
99
+
100
+ def load_image(input_: dict):
101
+ img = imread(input_['img'])
102
+ if img is None:
103
+ raise ValueError(f'Failed to read image {input_}.')
104
+ return {**input_, 'img': img}
105
+
106
+ pipeline = Compose([load_image, self.pipeline])
107
+
108
+ chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size)
109
+ yield from map(default_collate, chunked_data)
110
+
111
+ def visualize(self,
112
+ ori_inputs: List[dict],
113
+ preds: List[DataSample],
114
+ show: bool = False,
115
+ wait_time: int = 0,
116
+ resize: Optional[int] = None,
117
+ show_dir=None):
118
+ if not show and show_dir is None:
119
+ return None
120
+
121
+ if self.visualizer is None:
122
+ from mmpretrain.visualization import UniversalVisualizer
123
+ self.visualizer = UniversalVisualizer()
124
+
125
+ visualization = []
126
+ for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)):
127
+ image = imread(input_['img'])
128
+ if isinstance(input_['img'], str):
129
+ # The image loaded from path is BGR format.
130
+ image = image[..., ::-1]
131
+ name = Path(input_['img']).stem
132
+ else:
133
+ name = str(i)
134
+
135
+ if show_dir is not None:
136
+ show_dir = Path(show_dir)
137
+ show_dir.mkdir(exist_ok=True)
138
+ out_file = str((show_dir / name).with_suffix('.png'))
139
+ else:
140
+ out_file = None
141
+
142
+ self.visualizer.visualize_vqa(
143
+ image,
144
+ data_sample,
145
+ resize=resize,
146
+ show=show,
147
+ wait_time=wait_time,
148
+ name=name,
149
+ out_file=out_file)
150
+ visualization.append(self.visualizer.get_image())
151
+ if show:
152
+ self.visualizer.close()
153
+ return visualization
154
+
155
+ def postprocess(self,
156
+ preds: List[DataSample],
157
+ visualization: List[np.ndarray],
158
+ return_datasamples=False) -> list:
159
+ if return_datasamples:
160
+ return preds
161
+
162
+ results = []
163
+ for data_sample in preds:
164
+ results.append({
165
+ 'question': data_sample.get('question'),
166
+ 'pred_answer': data_sample.get('pred_answer'),
167
+ })
168
+
169
+ return results
170
+
171
+ @staticmethod
172
+ def list_models(pattern: Optional[str] = None):
173
+ """List all available model names.
174
+
175
+ Args:
176
+ pattern (str | None): A wildcard pattern to match model names.
177
+
178
+ Returns:
179
+ List[str]: a list of model names.
180
+ """
181
+ return list_models(pattern=pattern, task='Visual Question Answering')
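Building on the single-image example in the class docstring, a minimal batched-usage sketch; the second image path and its question are placeholders:

    from mmpretrain.apis import VisualQuestionAnsweringInferencer

    inferencer = VisualQuestionAnsweringInferencer(
        'ofa-base_3rdparty-zeroshot_vqa')
    # One question per image; results are returned in the same order.
    results = inferencer(
        ['demo/cat-dog.png', 'demo/bird.JPEG'],  # second path is a placeholder
        ["What's the animal next to the dog?", 'What is the bird doing?'],
        batch_size=2)
    for res in results:
        print(res['question'], '->', res['pred_answer'])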
mmpretrain/datasets/__init__.py ADDED
@@ -0,0 +1,54 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from mmpretrain.utils.dependency import WITH_MULTIMODAL
3
+ from .base_dataset import BaseDataset
4
+ from .builder import build_dataset
5
+ from .caltech101 import Caltech101
6
+ from .cifar import CIFAR10, CIFAR100
7
+ from .cub import CUB
8
+ from .custom import CustomDataset
9
+ from .dataset_wrappers import KFoldDataset
10
+ from .dtd import DTD
11
+ from .fgvcaircraft import FGVCAircraft
12
+ from .flowers102 import Flowers102
13
+ from .food101 import Food101
14
+ from .imagenet import ImageNet, ImageNet21k
15
+ from .inshop import InShop
16
+ from .mnist import MNIST, FashionMNIST
17
+ from .multi_label import MultiLabelDataset
18
+ from .multi_task import MultiTaskDataset
19
+ from .nlvr2 import NLVR2
20
+ from .oxfordiiitpet import OxfordIIITPet
21
+ from .places205 import Places205
22
+ from .samplers import * # noqa: F401,F403
23
+ from .stanfordcars import StanfordCars
24
+ from .sun397 import SUN397
25
+ from .transforms import * # noqa: F401,F403
26
+ from .voc import VOC
27
+
28
+ __all__ = [
29
+ 'BaseDataset', 'CIFAR10', 'CIFAR100', 'CUB', 'Caltech101', 'CustomDataset',
30
+ 'DTD', 'FGVCAircraft', 'FashionMNIST', 'Flowers102', 'Food101', 'ImageNet',
31
+ 'ImageNet21k', 'InShop', 'KFoldDataset', 'MNIST', 'MultiLabelDataset',
32
+ 'MultiTaskDataset', 'NLVR2', 'OxfordIIITPet', 'Places205', 'SUN397',
33
+ 'StanfordCars', 'VOC', 'build_dataset'
34
+ ]
35
+
36
+ if WITH_MULTIMODAL:
37
+ from .coco_caption import COCOCaption
38
+ from .coco_retrieval import COCORetrieval
39
+ from .coco_vqa import COCOVQA
40
+ from .flamingo import FlamingoEvalCOCOCaption, FlamingoEvalCOCOVQA
41
+ from .refcoco import RefCOCO
42
+ from .scienceqa import ScienceQA
43
+ from .visual_genome import VisualGenomeQA
44
+
45
+ __all__.extend([
46
+ 'COCOCaption',
47
+ 'COCORetrieval',
48
+ 'COCOVQA',
49
+ 'FlamingoEvalCOCOCaption',
50
+ 'FlamingoEvalCOCOVQA',
51
+ 'RefCOCO',
52
+ 'VisualGenomeQA',
53
+ 'ScienceQA',
54
+ ])
mmpretrain/datasets/base_dataset.py ADDED
@@ -0,0 +1,219 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import os.path as osp
3
+ from os import PathLike
4
+ from typing import List, Optional, Sequence, Union
5
+
6
+ import mmengine
7
+ import numpy as np
8
+ from mmengine.dataset import BaseDataset as _BaseDataset
9
+
10
+ from mmpretrain.registry import DATASETS, TRANSFORMS
11
+
12
+
13
+ def expanduser(path):
14
+ """Expand ~ and ~user constructions.
15
+
16
+ If user or $HOME is unknown, do nothing.
17
+ """
18
+ if isinstance(path, (str, PathLike)):
19
+ return osp.expanduser(path)
20
+ else:
21
+ return path
22
+
23
+
24
+ @DATASETS.register_module()
25
+ class BaseDataset(_BaseDataset):
26
+ """Base dataset for image classification task.
27
+
28
+ This dataset supports annotation files in `OpenMMLab 2.0 style annotation
29
+ format`_.
30
+
31
+ .. _OpenMMLab 2.0 style annotation format:
32
+ https://github.com/open-mmlab/mmengine/blob/main/docs/zh_cn/tutorials/basedataset.md
33
+
34
+ Compared with :class:`mmengine.BaseDataset`, this class implements
35
+ several useful methods.
36
+
37
+ Args:
38
+ ann_file (str): Annotation file path.
39
+ metainfo (dict, optional): Meta information for dataset, such as class
40
+ information. Defaults to None.
41
+ data_root (str): The root directory for ``data_prefix`` and
42
+ ``ann_file``. Defaults to ''.
43
+ data_prefix (str | dict): Prefix for training data. Defaults to ''.
44
+ filter_cfg (dict, optional): Config for filter data. Defaults to None.
45
+ indices (int or Sequence[int], optional): Support using the first few
46
+ data in the annotation file to facilitate training/testing on a smaller
47
+ dataset. Defaults to None, which means using all ``data_infos``.
48
+ serialize_data (bool): Whether to hold memory using serialized objects.
49
+ When enabled, data loader workers can use shared RAM from the master
50
+ process instead of making a copy. Defaults to True.
51
+ pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
52
+ test_mode (bool, optional): ``test_mode=True`` means in the test phase,
53
+ where an error will be raised when getting an item fails; ``test_mode=False``
54
+ means in the training phase, where another item will be returned randomly.
55
+ Defaults to False.
56
+ lazy_init (bool): Whether to load annotation during instantiation.
57
+ In some cases, such as visualization, only the meta information of
58
+ the dataset is needed, so it is not necessary to load the annotation
59
+ file. ``BaseDataset`` can skip loading annotations to save time by
60
+ setting ``lazy_init=True``. Defaults to False.
61
+ max_refetch (int): If ``BaseDataset.prepare_data`` gets a None image,
63
+ the maximum number of extra cycles to fetch a valid one.
63
+ Defaults to 1000.
64
+ classes (str | Sequence[str], optional): Specify names of classes.
65
+
66
+ - If it is a string, it should be a file path, and every line of
67
+ the file is the name of a class.
68
+ - If it is a sequence of strings, every item is the name of a class.
69
+ - If it is None, use the categories information in the ``metainfo`` argument,
70
+ annotation file or the class attribute ``METAINFO``.
71
+
72
+ Defaults to None.
73
+ """ # noqa: E501
74
+
75
+ def __init__(self,
76
+ ann_file: str,
77
+ metainfo: Optional[dict] = None,
78
+ data_root: str = '',
79
+ data_prefix: Union[str, dict] = '',
80
+ filter_cfg: Optional[dict] = None,
81
+ indices: Optional[Union[int, Sequence[int]]] = None,
82
+ serialize_data: bool = True,
83
+ pipeline: Sequence = (),
84
+ test_mode: bool = False,
85
+ lazy_init: bool = False,
86
+ max_refetch: int = 1000,
87
+ classes: Union[str, Sequence[str], None] = None):
88
+ if isinstance(data_prefix, str):
89
+ data_prefix = dict(img_path=expanduser(data_prefix))
90
+
91
+ ann_file = expanduser(ann_file)
92
+ metainfo = self._compat_classes(metainfo, classes)
93
+
94
+ transforms = []
95
+ for transform in pipeline:
96
+ if isinstance(transform, dict):
97
+ transforms.append(TRANSFORMS.build(transform))
98
+ else:
99
+ transforms.append(transform)
100
+
101
+ super().__init__(
102
+ ann_file=ann_file,
103
+ metainfo=metainfo,
104
+ data_root=data_root,
105
+ data_prefix=data_prefix,
106
+ filter_cfg=filter_cfg,
107
+ indices=indices,
108
+ serialize_data=serialize_data,
109
+ pipeline=transforms,
110
+ test_mode=test_mode,
111
+ lazy_init=lazy_init,
112
+ max_refetch=max_refetch)
113
+
114
+ @property
115
+ def img_prefix(self):
116
+ """The prefix of images."""
117
+ return self.data_prefix['img_path']
118
+
119
+ @property
120
+ def CLASSES(self):
121
+ """Return all categories names."""
122
+ return self._metainfo.get('classes', None)
123
+
124
+ @property
125
+ def class_to_idx(self):
126
+ """Map mapping class name to class index.
127
+
128
+ Returns:
129
+ dict: mapping from class name to class index.
130
+ """
131
+
132
+ return {cat: i for i, cat in enumerate(self.CLASSES)}
133
+
134
+ def get_gt_labels(self):
135
+ """Get all ground-truth labels (categories).
136
+
137
+ Returns:
138
+ np.ndarray: categories for all images.
139
+ """
140
+
141
+ gt_labels = np.array(
142
+ [self.get_data_info(i)['gt_label'] for i in range(len(self))])
143
+ return gt_labels
144
+
145
+ def get_cat_ids(self, idx: int) -> List[int]:
146
+ """Get category id by index.
147
+
148
+ Args:
149
+ idx (int): Index of data.
150
+
151
+ Returns:
152
+ cat_ids (List[int]): Image category of specified index.
153
+ """
154
+
155
+ return [int(self.get_data_info(idx)['gt_label'])]
156
+
157
+ def _compat_classes(self, metainfo, classes):
158
+ """Merge the old style ``classes`` arguments to ``metainfo``."""
159
+ if isinstance(classes, str):
160
+ # take it as a file path
161
+ class_names = mmengine.list_from_file(expanduser(classes))
162
+ elif isinstance(classes, (tuple, list)):
163
+ class_names = classes
164
+ elif classes is not None:
165
+ raise ValueError(f'Unsupported type {type(classes)} of classes.')
166
+
167
+ if metainfo is None:
168
+ metainfo = {}
169
+
170
+ if classes is not None:
171
+ metainfo = {'classes': tuple(class_names), **metainfo}
172
+
173
+ return metainfo
174
+
175
+ def full_init(self):
176
+ """Load annotation file and set ``BaseDataset._fully_initialized`` to
177
+ True."""
178
+ super().full_init()
179
+
180
+ # To support the standard OpenMMLab 2.0 annotation format. Generate
181
+ # metainfo in internal format from standard metainfo format.
182
+ if 'categories' in self._metainfo and 'classes' not in self._metainfo:
183
+ categories = sorted(
184
+ self._metainfo['categories'], key=lambda x: x['id'])
185
+ self._metainfo['classes'] = tuple(
186
+ [cat['category_name'] for cat in categories])
187
+
188
+ def __repr__(self):
189
+ """Print the basic information of the dataset.
190
+
191
+ Returns:
192
+ str: Formatted string.
193
+ """
194
+ head = 'Dataset ' + self.__class__.__name__
195
+ body = []
196
+ if self._fully_initialized:
197
+ body.append(f'Number of samples: \t{self.__len__()}')
198
+ else:
199
+ body.append("Haven't been initialized")
200
+
201
+ if self.CLASSES is not None:
202
+ body.append(f'Number of categories: \t{len(self.CLASSES)}')
203
+
204
+ body.extend(self.extra_repr())
205
+
206
+ if len(self.pipeline.transforms) > 0:
207
+ body.append('With transforms:')
208
+ for t in self.pipeline.transforms:
209
+ body.append(f' {t}')
210
+
211
+ lines = [head] + [' ' * 4 + line for line in body]
212
+ return '\n'.join(lines)
213
+
214
+ def extra_repr(self) -> List[str]:
215
+ """The extra repr information of the dataset."""
216
+ body = []
217
+ body.append(f'Annotation file: \t{self.ann_file}')
218
+ body.append(f'Prefix of images: \t{self.img_prefix}')
219
+ return body
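To illustrate the contract subclasses implement, here is a hypothetical minimal dataset built on the base class above; the annotation format (one ``path label`` pair per line) and the class names are assumptions made for the example:

    import os.path as osp

    from mmengine import list_from_file

    from mmpretrain.datasets import BaseDataset
    from mmpretrain.registry import DATASETS


    @DATASETS.register_module()
    class ToyDataset(BaseDataset):
        """Hypothetical dataset whose annotation lines are '<path> <label>'."""

        METAINFO = {'classes': ('cat', 'dog')}  # assumed class names

        def load_data_list(self):
            data_list = []
            for line in list_from_file(self.ann_file):
                path, gt_label = line.split()
                data_list.append(
                    dict(img_path=osp.join(self.img_prefix, path),
                         gt_label=int(gt_label)))
            return data_list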
mmpretrain/datasets/builder.py ADDED
@@ -0,0 +1,25 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from mmpretrain.registry import DATASETS
3
+
4
+
5
+ def build_dataset(cfg):
6
+ """Build dataset.
7
+
8
+ Examples:
9
+ >>> from mmpretrain.datasets import build_dataset
10
+ >>> mnist_train = build_dataset(
11
+ ... dict(type='MNIST', data_prefix='data/mnist/', test_mode=False))
12
+ >>> print(mnist_train)
13
+ Dataset MNIST
14
+ Number of samples: 60000
15
+ Number of categories: 10
16
+ Prefix of data: data/mnist/
17
+ >>> mnist_test = build_dataset(
18
+ ... dict(type='MNIST', data_prefix='data/mnist/', test_mode=True))
19
+ >>> print(mnist_test)
20
+ Dataset MNIST
21
+ Number of samples: 10000
22
+ Number of categories: 10
23
+ Prefix of data: data/mnist/
24
+ """
25
+ return DATASETS.build(cfg)
mmpretrain/datasets/caltech101.py ADDED
@@ -0,0 +1,113 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from typing import List
3
+
4
+ from mmengine import get_file_backend, list_from_file
5
+
6
+ from mmpretrain.registry import DATASETS
7
+ from .base_dataset import BaseDataset
8
+ from .categories import CALTECH101_CATEGORIES
9
+
10
+
11
+ @DATASETS.register_module()
12
+ class Caltech101(BaseDataset):
13
+ """The Caltech101 Dataset.
14
+
15
+ Support the `Caltech101 <https://data.caltech.edu/records/mzrjq-6wc02>`_ Dataset.
16
+ After downloading and decompression, the dataset directory structure is as follows.
17
+
18
+ Caltech101 dataset directory: ::
19
+
20
+ caltech-101
21
+ ├── 101_ObjectCategories
22
+ │ ├── class_x
23
+ │ │ ├── xx1.jpg
24
+ │ │ ├── xx2.jpg
25
+ │ │ └── ...
26
+ │ ├── class_y
27
+ │ │ ├── yy1.jpg
28
+ │ │ ├── yy2.jpg
29
+ │ │ └── ...
30
+ │ └── ...
31
+ ├── Annotations
32
+ │ ├── class_x
33
+ │ │ ├── xx1.mat
34
+ │ │ └── ...
35
+ │ └── ...
36
+ ├── meta
37
+ │ ├── train.txt
38
+ │ └── test.txt
39
+ └── ....
40
+
41
+ Please note that since there is no official split for the training and
42
+ test sets, you can use the train.txt and test.txt provided by us or
43
+ create your own annotation files. Here is the download
44
+ `link <https://download.openmmlab.com/mmpretrain/datasets/caltech_meta.zip>`_
45
+ for the annotations.
46
+
47
+ Args:
48
+ data_root (str): The root directory for the Caltech101 dataset.
49
+ split (str, optional): The dataset split, supports "train" and "test".
50
+ Default to "train".
51
+
52
+ Examples:
53
+ >>> from mmpretrain.datasets import Caltech101
54
+ >>> train_dataset = Caltech101(data_root='data/caltech-101', split='train')
55
+ >>> train_dataset
56
+ Dataset Caltech101
57
+ Number of samples: 3060
58
+ Number of categories: 102
59
+ Root of dataset: data/caltech-101
60
+ >>> test_dataset = Caltech101(data_root='data/caltech-101', split='test')
61
+ >>> test_dataset
62
+ Dataset Caltech101
63
+ Number of samples: 6728
64
+ Number of categories: 102
65
+ Root of dataset: data/caltech-101
66
+ """ # noqa: E501
67
+
68
+ METAINFO = {'classes': CALTECH101_CATEGORIES}
69
+
70
+ def __init__(self, data_root: str, split: str = 'train', **kwargs):
71
+
72
+ splits = ['train', 'test']
73
+ assert split in splits, \
74
+ f"The split must be one of {splits}, but get '{split}'"
75
+ self.split = split
76
+
77
+ self.backend = get_file_backend(data_root, enable_singleton=True)
78
+
79
+ if split == 'train':
80
+ ann_file = self.backend.join_path('meta', 'train.txt')
81
+ else:
82
+ ann_file = self.backend.join_path('meta', 'test.txt')
83
+
84
+ data_prefix = '101_ObjectCategories'
85
+ test_mode = split == 'test'
86
+
87
+ super(Caltech101, self).__init__(
88
+ ann_file=ann_file,
89
+ data_root=data_root,
90
+ data_prefix=data_prefix,
91
+ test_mode=test_mode,
92
+ **kwargs)
93
+
94
+ def load_data_list(self):
95
+ """Load images and ground truth labels."""
96
+
97
+ pairs = list_from_file(self.ann_file)
98
+ data_list = []
99
+
100
+ for pair in pairs:
101
+ path, gt_label = pair.split()
102
+ img_path = self.backend.join_path(self.img_prefix, path)
103
+ info = dict(img_path=img_path, gt_label=int(gt_label))
104
+ data_list.append(info)
105
+
106
+ return data_list
107
+
108
+ def extra_repr(self) -> List[str]:
109
+ """The extra repr information of the dataset."""
110
+ body = [
111
+ f'Root of dataset: \t{self.data_root}',
112
+ ]
113
+ return body
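Since ``load_data_list`` above splits each annotation line into a relative image path and an integer label, a hypothetical excerpt of meta/train.txt would look like the following (class folders and label indices are illustrative):

    accordion/image_0001.jpg 0
    accordion/image_0002.jpg 0
    airplanes/image_0004.jpg 1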
mmpretrain/datasets/categories.py ADDED
@@ -0,0 +1,1440 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ # Pre-defined categories names of various datasets.
3
+
4
+ VOC2007_CATEGORIES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
5
+ 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
6
+ 'horse', 'motorbike', 'person', 'pottedplant', 'sheep',
7
+ 'sofa', 'train', 'tvmonitor')
8
+
9
+ CUB_CATEGORIES = (
10
+ 'Black_footed_Albatross', 'Laysan_Albatross', 'Sooty_Albatross',
11
+ 'Groove_billed_Ani', 'Crested_Auklet', 'Least_Auklet', 'Parakeet_Auklet',
12
+ 'Rhinoceros_Auklet', 'Brewer_Blackbird', 'Red_winged_Blackbird',
13
+ 'Rusty_Blackbird', 'Yellow_headed_Blackbird', 'Bobolink', 'Indigo_Bunting',
14
+ 'Lazuli_Bunting', 'Painted_Bunting', 'Cardinal', 'Spotted_Catbird',
15
+ 'Gray_Catbird', 'Yellow_breasted_Chat', 'Eastern_Towhee',
16
+ 'Chuck_will_Widow', 'Brandt_Cormorant', 'Red_faced_Cormorant',
17
+ 'Pelagic_Cormorant', 'Bronzed_Cowbird', 'Shiny_Cowbird', 'Brown_Creeper',
18
+ 'American_Crow', 'Fish_Crow', 'Black_billed_Cuckoo', 'Mangrove_Cuckoo',
19
+ 'Yellow_billed_Cuckoo', 'Gray_crowned_Rosy_Finch', 'Purple_Finch',
20
+ 'Northern_Flicker', 'Acadian_Flycatcher', 'Great_Crested_Flycatcher',
21
+ 'Least_Flycatcher', 'Olive_sided_Flycatcher', 'Scissor_tailed_Flycatcher',
22
+ 'Vermilion_Flycatcher', 'Yellow_bellied_Flycatcher', 'Frigatebird',
23
+ 'Northern_Fulmar', 'Gadwall', 'American_Goldfinch', 'European_Goldfinch',
24
+ 'Boat_tailed_Grackle', 'Eared_Grebe', 'Horned_Grebe', 'Pied_billed_Grebe',
25
+ 'Western_Grebe', 'Blue_Grosbeak', 'Evening_Grosbeak', 'Pine_Grosbeak',
26
+ 'Rose_breasted_Grosbeak', 'Pigeon_Guillemot', 'California_Gull',
27
+ 'Glaucous_winged_Gull', 'Heermann_Gull', 'Herring_Gull', 'Ivory_Gull',
28
+ 'Ring_billed_Gull', 'Slaty_backed_Gull', 'Western_Gull',
29
+ 'Anna_Hummingbird', 'Ruby_throated_Hummingbird', 'Rufous_Hummingbird',
30
+ 'Green_Violetear', 'Long_tailed_Jaeger', 'Pomarine_Jaeger', 'Blue_Jay',
31
+ 'Florida_Jay', 'Green_Jay', 'Dark_eyed_Junco', 'Tropical_Kingbird',
32
+ 'Gray_Kingbird', 'Belted_Kingfisher', 'Green_Kingfisher',
33
+ 'Pied_Kingfisher', 'Ringed_Kingfisher', 'White_breasted_Kingfisher',
34
+ 'Red_legged_Kittiwake', 'Horned_Lark', 'Pacific_Loon', 'Mallard',
35
+ 'Western_Meadowlark', 'Hooded_Merganser', 'Red_breasted_Merganser',
36
+ 'Mockingbird', 'Nighthawk', 'Clark_Nutcracker', 'White_breasted_Nuthatch',
37
+ 'Baltimore_Oriole', 'Hooded_Oriole', 'Orchard_Oriole', 'Scott_Oriole',
38
+ 'Ovenbird', 'Brown_Pelican', 'White_Pelican', 'Western_Wood_Pewee',
39
+ 'Sayornis', 'American_Pipit', 'Whip_poor_Will', 'Horned_Puffin',
40
+ 'Common_Raven', 'White_necked_Raven', 'American_Redstart', 'Geococcyx',
41
+ 'Loggerhead_Shrike', 'Great_Grey_Shrike', 'Baird_Sparrow',
42
+ 'Black_throated_Sparrow', 'Brewer_Sparrow', 'Chipping_Sparrow',
43
+ 'Clay_colored_Sparrow', 'House_Sparrow', 'Field_Sparrow', 'Fox_Sparrow',
44
+ 'Grasshopper_Sparrow', 'Harris_Sparrow', 'Henslow_Sparrow',
45
+ 'Le_Conte_Sparrow', 'Lincoln_Sparrow', 'Nelson_Sharp_tailed_Sparrow',
46
+ 'Savannah_Sparrow', 'Seaside_Sparrow', 'Song_Sparrow', 'Tree_Sparrow',
47
+ 'Vesper_Sparrow', 'White_crowned_Sparrow', 'White_throated_Sparrow',
48
+ 'Cape_Glossy_Starling', 'Bank_Swallow', 'Barn_Swallow', 'Cliff_Swallow',
49
+ 'Tree_Swallow', 'Scarlet_Tanager', 'Summer_Tanager', 'Artic_Tern',
50
+ 'Black_Tern', 'Caspian_Tern', 'Common_Tern', 'Elegant_Tern',
51
+ 'Forsters_Tern', 'Least_Tern', 'Green_tailed_Towhee', 'Brown_Thrasher',
52
+ 'Sage_Thrasher', 'Black_capped_Vireo', 'Blue_headed_Vireo',
53
+ 'Philadelphia_Vireo', 'Red_eyed_Vireo', 'Warbling_Vireo',
54
+ 'White_eyed_Vireo', 'Yellow_throated_Vireo', 'Bay_breasted_Warbler',
55
+ 'Black_and_white_Warbler', 'Black_throated_Blue_Warbler',
56
+ 'Blue_winged_Warbler', 'Canada_Warbler', 'Cape_May_Warbler',
57
+ 'Cerulean_Warbler', 'Chestnut_sided_Warbler', 'Golden_winged_Warbler',
58
+ 'Hooded_Warbler', 'Kentucky_Warbler', 'Magnolia_Warbler',
59
+ 'Mourning_Warbler', 'Myrtle_Warbler', 'Nashville_Warbler',
60
+ 'Orange_crowned_Warbler', 'Palm_Warbler', 'Pine_Warbler',
61
+ 'Prairie_Warbler', 'Prothonotary_Warbler', 'Swainson_Warbler',
62
+ 'Tennessee_Warbler', 'Wilson_Warbler', 'Worm_eating_Warbler',
63
+ 'Yellow_Warbler', 'Northern_Waterthrush', 'Louisiana_Waterthrush',
64
+ 'Bohemian_Waxwing', 'Cedar_Waxwing', 'American_Three_toed_Woodpecker',
65
+ 'Pileated_Woodpecker', 'Red_bellied_Woodpecker', 'Red_cockaded_Woodpecker',
66
+ 'Red_headed_Woodpecker', 'Downy_Woodpecker', 'Bewick_Wren', 'Cactus_Wren',
67
+ 'Carolina_Wren', 'House_Wren', 'Marsh_Wren', 'Rock_Wren', 'Winter_Wren',
68
+ 'Common_Yellowthroat')
69
+
70
+ IMAGENET_CATEGORIES = (
71
+ 'tench, Tinca tinca',
72
+ 'goldfish, Carassius auratus',
73
+ 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias', # noqa: E501
74
+ 'tiger shark, Galeocerdo cuvieri',
75
+ 'hammerhead, hammerhead shark',
76
+ 'electric ray, crampfish, numbfish, torpedo',
77
+ 'stingray',
78
+ 'cock',
79
+ 'hen',
80
+ 'ostrich, Struthio camelus',
81
+ 'brambling, Fringilla montifringilla',
82
+ 'goldfinch, Carduelis carduelis',
83
+ 'house finch, linnet, Carpodacus mexicanus',
84
+ 'junco, snowbird',
85
+ 'indigo bunting, indigo finch, indigo bird, Passerina cyanea',
86
+ 'robin, American robin, Turdus migratorius',
87
+ 'bulbul',
88
+ 'jay',
89
+ 'magpie',
90
+ 'chickadee',
91
+ 'water ouzel, dipper',
92
+ 'kite',
93
+ 'bald eagle, American eagle, Haliaeetus leucocephalus',
94
+ 'vulture',
95
+ 'great grey owl, great gray owl, Strix nebulosa',
96
+ 'European fire salamander, Salamandra salamandra',
97
+ 'common newt, Triturus vulgaris',
98
+ 'eft',
99
+ 'spotted salamander, Ambystoma maculatum',
100
+ 'axolotl, mud puppy, Ambystoma mexicanum',
101
+ 'bullfrog, Rana catesbeiana',
102
+ 'tree frog, tree-frog',
103
+ 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui',
104
+ 'loggerhead, loggerhead turtle, Caretta caretta',
105
+ 'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea', # noqa: E501
106
+ 'mud turtle',
107
+ 'terrapin',
108
+ 'box turtle, box tortoise',
109
+ 'banded gecko',
110
+ 'common iguana, iguana, Iguana iguana',
111
+ 'American chameleon, anole, Anolis carolinensis',
112
+ 'whiptail, whiptail lizard',
113
+ 'agama',
114
+ 'frilled lizard, Chlamydosaurus kingi',
115
+ 'alligator lizard',
116
+ 'Gila monster, Heloderma suspectum',
117
+ 'green lizard, Lacerta viridis',
118
+ 'African chameleon, Chamaeleo chamaeleon',
119
+ 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis', # noqa: E501
120
+ 'African crocodile, Nile crocodile, Crocodylus niloticus',
121
+ 'American alligator, Alligator mississipiensis',
122
+ 'triceratops',
123
+ 'thunder snake, worm snake, Carphophis amoenus',
124
+ 'ringneck snake, ring-necked snake, ring snake',
125
+ 'hognose snake, puff adder, sand viper',
126
+ 'green snake, grass snake',
127
+ 'king snake, kingsnake',
128
+ 'garter snake, grass snake',
129
+ 'water snake',
130
+ 'vine snake',
131
+ 'night snake, Hypsiglena torquata',
132
+ 'boa constrictor, Constrictor constrictor',
133
+ 'rock python, rock snake, Python sebae',
134
+ 'Indian cobra, Naja naja',
135
+ 'green mamba',
136
+ 'sea snake',
137
+ 'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus',
138
+ 'diamondback, diamondback rattlesnake, Crotalus adamanteus',
139
+ 'sidewinder, horned rattlesnake, Crotalus cerastes',
140
+ 'trilobite',
141
+ 'harvestman, daddy longlegs, Phalangium opilio',
142
+ 'scorpion',
143
+ 'black and gold garden spider, Argiope aurantia',
144
+ 'barn spider, Araneus cavaticus',
145
+ 'garden spider, Aranea diademata',
146
+ 'black widow, Latrodectus mactans',
147
+ 'tarantula',
148
+ 'wolf spider, hunting spider',
149
+ 'tick',
150
+ 'centipede',
151
+ 'black grouse',
152
+ 'ptarmigan',
153
+ 'ruffed grouse, partridge, Bonasa umbellus',
154
+ 'prairie chicken, prairie grouse, prairie fowl',
155
+ 'peacock',
156
+ 'quail',
157
+ 'partridge',
158
+ 'African grey, African gray, Psittacus erithacus',
159
+ 'macaw',
160
+ 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita',
161
+ 'lorikeet',
162
+ 'coucal',
163
+ 'bee eater',
164
+ 'hornbill',
165
+ 'hummingbird',
166
+ 'jacamar',
167
+ 'toucan',
168
+ 'drake',
169
+ 'red-breasted merganser, Mergus serrator',
170
+ 'goose',
171
+ 'black swan, Cygnus atratus',
172
+ 'tusker',
173
+ 'echidna, spiny anteater, anteater',
174
+ 'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus', # noqa: E501
175
+ 'wallaby, brush kangaroo',
176
+ 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus', # noqa: E501
177
+ 'wombat',
178
+ 'jellyfish',
179
+ 'sea anemone, anemone',
180
+ 'brain coral',
181
+ 'flatworm, platyhelminth',
182
+ 'nematode, nematode worm, roundworm',
183
+ 'conch',
184
+ 'snail',
185
+ 'slug',
186
+ 'sea slug, nudibranch',
187
+ 'chiton, coat-of-mail shell, sea cradle, polyplacophore',
188
+ 'chambered nautilus, pearly nautilus, nautilus',
189
+ 'Dungeness crab, Cancer magister',
190
+ 'rock crab, Cancer irroratus',
191
+ 'fiddler crab',
192
+ 'king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica', # noqa: E501
193
+ 'American lobster, Northern lobster, Maine lobster, Homarus americanus', # noqa: E501
194
+ 'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish', # noqa: E501
195
+ 'crayfish, crawfish, crawdad, crawdaddy',
196
+ 'hermit crab',
197
+ 'isopod',
198
+ 'white stork, Ciconia ciconia',
199
+ 'black stork, Ciconia nigra',
200
+ 'spoonbill',
201
+ 'flamingo',
202
+ 'little blue heron, Egretta caerulea',
203
+ 'American egret, great white heron, Egretta albus',
204
+ 'bittern',
205
+ 'crane',
206
+ 'limpkin, Aramus pictus',
207
+ 'European gallinule, Porphyrio porphyrio',
208
+ 'American coot, marsh hen, mud hen, water hen, Fulica americana',
209
+ 'bustard',
210
+ 'ruddy turnstone, Arenaria interpres',
211
+ 'red-backed sandpiper, dunlin, Erolia alpina',
212
+ 'redshank, Tringa totanus',
213
+ 'dowitcher',
214
+ 'oystercatcher, oyster catcher',
215
+ 'pelican',
216
+ 'king penguin, Aptenodytes patagonica',
217
+ 'albatross, mollymawk',
218
+ 'grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus', # noqa: E501
219
+ 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca',
220
+ 'dugong, Dugong dugon',
221
+ 'sea lion',
222
+ 'Chihuahua',
223
+ 'Japanese spaniel',
224
+ 'Maltese dog, Maltese terrier, Maltese',
225
+ 'Pekinese, Pekingese, Peke',
226
+ 'Shih-Tzu',
227
+ 'Blenheim spaniel',
228
+ 'papillon',
229
+ 'toy terrier',
230
+ 'Rhodesian ridgeback',
231
+ 'Afghan hound, Afghan',
232
+ 'basset, basset hound',
233
+ 'beagle',
234
+ 'bloodhound, sleuthhound',
235
+ 'bluetick',
236
+ 'black-and-tan coonhound',
237
+ 'Walker hound, Walker foxhound',
238
+ 'English foxhound',
239
+ 'redbone',
240
+ 'borzoi, Russian wolfhound',
241
+ 'Irish wolfhound',
242
+ 'Italian greyhound',
243
+ 'whippet',
244
+ 'Ibizan hound, Ibizan Podenco',
245
+ 'Norwegian elkhound, elkhound',
246
+ 'otterhound, otter hound',
247
+ 'Saluki, gazelle hound',
248
+ 'Scottish deerhound, deerhound',
249
+ 'Weimaraner',
250
+ 'Staffordshire bullterrier, Staffordshire bull terrier',
251
+ 'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier', # noqa: E501
252
+ 'Bedlington terrier',
253
+ 'Border terrier',
254
+ 'Kerry blue terrier',
255
+ 'Irish terrier',
256
+ 'Norfolk terrier',
257
+ 'Norwich terrier',
258
+ 'Yorkshire terrier',
259
+ 'wire-haired fox terrier',
260
+ 'Lakeland terrier',
261
+ 'Sealyham terrier, Sealyham',
262
+ 'Airedale, Airedale terrier',
263
+ 'cairn, cairn terrier',
264
+ 'Australian terrier',
265
+ 'Dandie Dinmont, Dandie Dinmont terrier',
266
+ 'Boston bull, Boston terrier',
267
+ 'miniature schnauzer',
268
+ 'giant schnauzer',
269
+ 'standard schnauzer',
270
+ 'Scotch terrier, Scottish terrier, Scottie',
271
+ 'Tibetan terrier, chrysanthemum dog',
272
+ 'silky terrier, Sydney silky',
273
+ 'soft-coated wheaten terrier',
274
+ 'West Highland white terrier',
275
+ 'Lhasa, Lhasa apso',
276
+ 'flat-coated retriever',
277
+ 'curly-coated retriever',
278
+ 'golden retriever',
279
+ 'Labrador retriever',
280
+ 'Chesapeake Bay retriever',
281
+ 'German short-haired pointer',
282
+ 'vizsla, Hungarian pointer',
283
+ 'English setter',
284
+ 'Irish setter, red setter',
285
+ 'Gordon setter',
286
+ 'Brittany spaniel',
287
+ 'clumber, clumber spaniel',
288
+ 'English springer, English springer spaniel',
289
+ 'Welsh springer spaniel',
290
+ 'cocker spaniel, English cocker spaniel, cocker',
291
+ 'Sussex spaniel',
292
+ 'Irish water spaniel',
293
+ 'kuvasz',
294
+ 'schipperke',
295
+ 'groenendael',
296
+ 'malinois',
297
+ 'briard',
298
+ 'kelpie',
299
+ 'komondor',
300
+ 'Old English sheepdog, bobtail',
301
+ 'Shetland sheepdog, Shetland sheep dog, Shetland',
302
+ 'collie',
303
+ 'Border collie',
304
+ 'Bouvier des Flandres, Bouviers des Flandres',
305
+ 'Rottweiler',
306
+ 'German shepherd, German shepherd dog, German police dog, alsatian',
307
+ 'Doberman, Doberman pinscher',
308
+ 'miniature pinscher',
309
+ 'Greater Swiss Mountain dog',
310
+ 'Bernese mountain dog',
311
+ 'Appenzeller',
312
+ 'EntleBucher',
313
+ 'boxer',
314
+ 'bull mastiff',
315
+ 'Tibetan mastiff',
316
+ 'French bulldog',
317
+ 'Great Dane',
318
+ 'Saint Bernard, St Bernard',
319
+ 'Eskimo dog, husky',
320
+ 'malamute, malemute, Alaskan malamute',
321
+ 'Siberian husky',
322
+ 'dalmatian, coach dog, carriage dog',
323
+ 'affenpinscher, monkey pinscher, monkey dog',
324
+ 'basenji',
325
+ 'pug, pug-dog',
326
+ 'Leonberg',
327
+ 'Newfoundland, Newfoundland dog',
328
+ 'Great Pyrenees',
329
+ 'Samoyed, Samoyede',
330
+ 'Pomeranian',
331
+ 'chow, chow chow',
332
+ 'keeshond',
333
+ 'Brabancon griffon',
334
+ 'Pembroke, Pembroke Welsh corgi',
335
+ 'Cardigan, Cardigan Welsh corgi',
336
+ 'toy poodle',
337
+ 'miniature poodle',
338
+ 'standard poodle',
339
+ 'Mexican hairless',
340
+ 'timber wolf, grey wolf, gray wolf, Canis lupus',
341
+ 'white wolf, Arctic wolf, Canis lupus tundrarum',
342
+ 'red wolf, maned wolf, Canis rufus, Canis niger',
343
+ 'coyote, prairie wolf, brush wolf, Canis latrans',
344
+ 'dingo, warrigal, warragal, Canis dingo',
345
+ 'dhole, Cuon alpinus',
346
+ 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus',
347
+ 'hyena, hyaena',
348
+ 'red fox, Vulpes vulpes',
349
+ 'kit fox, Vulpes macrotis',
350
+ 'Arctic fox, white fox, Alopex lagopus',
351
+ 'grey fox, gray fox, Urocyon cinereoargenteus',
352
+ 'tabby, tabby cat',
353
+ 'tiger cat',
354
+ 'Persian cat',
355
+ 'Siamese cat, Siamese',
356
+ 'Egyptian cat',
357
+ 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor', # noqa: E501
358
+ 'lynx, catamount',
359
+ 'leopard, Panthera pardus',
360
+ 'snow leopard, ounce, Panthera uncia',
361
+ 'jaguar, panther, Panthera onca, Felis onca',
362
+ 'lion, king of beasts, Panthera leo',
363
+ 'tiger, Panthera tigris',
364
+ 'cheetah, chetah, Acinonyx jubatus',
365
+ 'brown bear, bruin, Ursus arctos',
366
+ 'American black bear, black bear, Ursus americanus, Euarctos americanus', # noqa: E501
367
+ 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus',
368
+ 'sloth bear, Melursus ursinus, Ursus ursinus',
369
+ 'mongoose',
370
+ 'meerkat, mierkat',
371
+ 'tiger beetle',
372
+ 'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle',
373
+ 'ground beetle, carabid beetle',
374
+ 'long-horned beetle, longicorn, longicorn beetle',
375
+ 'leaf beetle, chrysomelid',
376
+ 'dung beetle',
377
+ 'rhinoceros beetle',
378
+ 'weevil',
379
+ 'fly',
380
+ 'bee',
381
+ 'ant, emmet, pismire',
382
+ 'grasshopper, hopper',
383
+ 'cricket',
384
+ 'walking stick, walkingstick, stick insect',
385
+ 'cockroach, roach',
386
+ 'mantis, mantid',
387
+ 'cicada, cicala',
388
+ 'leafhopper',
389
+ 'lacewing, lacewing fly',
390
+ "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk", # noqa: E501
391
+ 'damselfly',
392
+ 'admiral',
393
+ 'ringlet, ringlet butterfly',
394
+ 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus',
395
+ 'cabbage butterfly',
396
+ 'sulphur butterfly, sulfur butterfly',
397
+ 'lycaenid, lycaenid butterfly',
398
+ 'starfish, sea star',
399
+ 'sea urchin',
400
+ 'sea cucumber, holothurian',
401
+ 'wood rabbit, cottontail, cottontail rabbit',
402
+ 'hare',
403
+ 'Angora, Angora rabbit',
404
+ 'hamster',
405
+ 'porcupine, hedgehog',
406
+ 'fox squirrel, eastern fox squirrel, Sciurus niger',
407
+ 'marmot',
408
+ 'beaver',
409
+ 'guinea pig, Cavia cobaya',
410
+ 'sorrel',
411
+ 'zebra',
412
+ 'hog, pig, grunter, squealer, Sus scrofa',
413
+ 'wild boar, boar, Sus scrofa',
414
+ 'warthog',
415
+ 'hippopotamus, hippo, river horse, Hippopotamus amphibius',
416
+ 'ox',
417
+ 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis',
418
+ 'bison',
419
+ 'ram, tup',
420
+ 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis', # noqa: E501
421
+ 'ibex, Capra ibex',
422
+ 'hartebeest',
423
+ 'impala, Aepyceros melampus',
424
+ 'gazelle',
425
+ 'Arabian camel, dromedary, Camelus dromedarius',
426
+ 'llama',
427
+ 'weasel',
428
+ 'mink',
429
+ 'polecat, fitch, foulmart, foumart, Mustela putorius',
430
+ 'black-footed ferret, ferret, Mustela nigripes',
431
+ 'otter',
432
+ 'skunk, polecat, wood pussy',
433
+ 'badger',
434
+ 'armadillo',
435
+ 'three-toed sloth, ai, Bradypus tridactylus',
436
+ 'orangutan, orang, orangutang, Pongo pygmaeus',
437
+ 'gorilla, Gorilla gorilla',
438
+ 'chimpanzee, chimp, Pan troglodytes',
439
+ 'gibbon, Hylobates lar',
440
+ 'siamang, Hylobates syndactylus, Symphalangus syndactylus',
441
+ 'guenon, guenon monkey',
442
+ 'patas, hussar monkey, Erythrocebus patas',
443
+ 'baboon',
444
+ 'macaque',
445
+ 'langur',
446
+ 'colobus, colobus monkey',
447
+ 'proboscis monkey, Nasalis larvatus',
448
+ 'marmoset',
449
+ 'capuchin, ringtail, Cebus capucinus',
450
+ 'howler monkey, howler',
451
+ 'titi, titi monkey',
452
+ 'spider monkey, Ateles geoffroyi',
453
+ 'squirrel monkey, Saimiri sciureus',
454
+ 'Madagascar cat, ring-tailed lemur, Lemur catta',
455
+ 'indri, indris, Indri indri, Indri brevicaudatus',
456
+ 'Indian elephant, Elephas maximus',
457
+ 'African elephant, Loxodonta africana',
458
+ 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens',
459
+ 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca',
460
+ 'barracouta, snoek',
461
+ 'eel',
462
+ 'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch', # noqa: E501
463
+ 'rock beauty, Holocanthus tricolor',
464
+ 'anemone fish',
465
+ 'sturgeon',
466
+ 'gar, garfish, garpike, billfish, Lepisosteus osseus',
467
+ 'lionfish',
468
+ 'puffer, pufferfish, blowfish, globefish',
469
+ 'abacus',
470
+ 'abaya',
471
+ "academic gown, academic robe, judge's robe",
472
+ 'accordion, piano accordion, squeeze box',
473
+ 'acoustic guitar',
474
+ 'aircraft carrier, carrier, flattop, attack aircraft carrier',
475
+ 'airliner',
476
+ 'airship, dirigible',
477
+ 'altar',
478
+ 'ambulance',
479
+ 'amphibian, amphibious vehicle',
480
+ 'analog clock',
481
+ 'apiary, bee house',
482
+ 'apron',
483
+ 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin', # noqa: E501
484
+ 'assault rifle, assault gun',
485
+ 'backpack, back pack, knapsack, packsack, rucksack, haversack',
486
+ 'bakery, bakeshop, bakehouse',
487
+ 'balance beam, beam',
488
+ 'balloon',
489
+ 'ballpoint, ballpoint pen, ballpen, Biro',
490
+ 'Band Aid',
491
+ 'banjo',
492
+ 'bannister, banister, balustrade, balusters, handrail',
493
+ 'barbell',
494
+ 'barber chair',
495
+ 'barbershop',
496
+ 'barn',
497
+ 'barometer',
498
+ 'barrel, cask',
499
+ 'barrow, garden cart, lawn cart, wheelbarrow',
500
+ 'baseball',
501
+ 'basketball',
502
+ 'bassinet',
503
+ 'bassoon',
504
+ 'bathing cap, swimming cap',
505
+ 'bath towel',
506
+ 'bathtub, bathing tub, bath, tub',
507
+ 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon', # noqa: E501
508
+ 'beacon, lighthouse, beacon light, pharos',
509
+ 'beaker',
510
+ 'bearskin, busby, shako',
511
+ 'beer bottle',
512
+ 'beer glass',
513
+ 'bell cote, bell cot',
514
+ 'bib',
515
+ 'bicycle-built-for-two, tandem bicycle, tandem',
516
+ 'bikini, two-piece',
517
+ 'binder, ring-binder',
518
+ 'binoculars, field glasses, opera glasses',
519
+ 'birdhouse',
520
+ 'boathouse',
521
+ 'bobsled, bobsleigh, bob',
522
+ 'bolo tie, bolo, bola tie, bola',
523
+ 'bonnet, poke bonnet',
524
+ 'bookcase',
525
+ 'bookshop, bookstore, bookstall',
526
+ 'bottlecap',
527
+ 'bow',
528
+ 'bow tie, bow-tie, bowtie',
529
+ 'brass, memorial tablet, plaque',
530
+ 'brassiere, bra, bandeau',
531
+ 'breakwater, groin, groyne, mole, bulwark, seawall, jetty',
532
+ 'breastplate, aegis, egis',
533
+ 'broom',
534
+ 'bucket, pail',
535
+ 'buckle',
536
+ 'bulletproof vest',
537
+ 'bullet train, bullet',
538
+ 'butcher shop, meat market',
539
+ 'cab, hack, taxi, taxicab',
540
+ 'caldron, cauldron',
541
+ 'candle, taper, wax light',
542
+ 'cannon',
543
+ 'canoe',
544
+ 'can opener, tin opener',
545
+ 'cardigan',
546
+ 'car mirror',
547
+ 'carousel, carrousel, merry-go-round, roundabout, whirligig',
548
+ "carpenter's kit, tool kit",
549
+ 'carton',
550
+ 'car wheel',
551
+ 'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM', # noqa: E501
552
+ 'cassette',
553
+ 'cassette player',
554
+ 'castle',
555
+ 'catamaran',
556
+ 'CD player',
557
+ 'cello, violoncello',
558
+ 'cellular telephone, cellular phone, cellphone, cell, mobile phone',
559
+ 'chain',
560
+ 'chainlink fence',
561
+ 'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour', # noqa: E501
562
+ 'chain saw, chainsaw',
563
+ 'chest',
564
+ 'chiffonier, commode',
565
+ 'chime, bell, gong',
566
+ 'china cabinet, china closet',
567
+ 'Christmas stocking',
568
+ 'church, church building',
569
+ 'cinema, movie theater, movie theatre, movie house, picture palace',
570
+ 'cleaver, meat cleaver, chopper',
571
+ 'cliff dwelling',
572
+ 'cloak',
573
+ 'clog, geta, patten, sabot',
574
+ 'cocktail shaker',
575
+ 'coffee mug',
576
+ 'coffeepot',
577
+ 'coil, spiral, volute, whorl, helix',
578
+ 'combination lock',
579
+ 'computer keyboard, keypad',
580
+ 'confectionery, confectionary, candy store',
581
+ 'container ship, containership, container vessel',
582
+ 'convertible',
583
+ 'corkscrew, bottle screw',
584
+ 'cornet, horn, trumpet, trump',
585
+ 'cowboy boot',
586
+ 'cowboy hat, ten-gallon hat',
587
+ 'cradle',
588
+ 'crane',
589
+ 'crash helmet',
590
+ 'crate',
591
+ 'crib, cot',
592
+ 'Crock Pot',
593
+ 'croquet ball',
594
+ 'crutch',
595
+ 'cuirass',
596
+ 'dam, dike, dyke',
597
+ 'desk',
598
+ 'desktop computer',
599
+ 'dial telephone, dial phone',
600
+ 'diaper, nappy, napkin',
601
+ 'digital clock',
602
+ 'digital watch',
603
+ 'dining table, board',
604
+ 'dishrag, dishcloth',
605
+ 'dishwasher, dish washer, dishwashing machine',
606
+ 'disk brake, disc brake',
607
+ 'dock, dockage, docking facility',
608
+ 'dogsled, dog sled, dog sleigh',
609
+ 'dome',
610
+ 'doormat, welcome mat',
611
+ 'drilling platform, offshore rig',
612
+ 'drum, membranophone, tympan',
613
+ 'drumstick',
614
+ 'dumbbell',
615
+ 'Dutch oven',
616
+ 'electric fan, blower',
617
+ 'electric guitar',
618
+ 'electric locomotive',
619
+ 'entertainment center',
620
+ 'envelope',
621
+ 'espresso maker',
622
+ 'face powder',
623
+ 'feather boa, boa',
624
+ 'file, file cabinet, filing cabinet',
625
+ 'fireboat',
626
+ 'fire engine, fire truck',
627
+ 'fire screen, fireguard',
628
+ 'flagpole, flagstaff',
629
+ 'flute, transverse flute',
630
+ 'folding chair',
631
+ 'football helmet',
632
+ 'forklift',
633
+ 'fountain',
634
+ 'fountain pen',
635
+ 'four-poster',
636
+ 'freight car',
637
+ 'French horn, horn',
638
+ 'frying pan, frypan, skillet',
639
+ 'fur coat',
640
+ 'garbage truck, dustcart',
641
+ 'gasmask, respirator, gas helmet',
642
+ 'gas pump, gasoline pump, petrol pump, island dispenser',
643
+ 'goblet',
644
+ 'go-kart',
645
+ 'golf ball',
646
+ 'golfcart, golf cart',
647
+ 'gondola',
648
+ 'gong, tam-tam',
649
+ 'gown',
650
+ 'grand piano, grand',
651
+ 'greenhouse, nursery, glasshouse',
652
+ 'grille, radiator grille',
653
+ 'grocery store, grocery, food market, market',
654
+ 'guillotine',
655
+ 'hair slide',
656
+ 'hair spray',
657
+ 'half track',
658
+ 'hammer',
659
+ 'hamper',
660
+ 'hand blower, blow dryer, blow drier, hair dryer, hair drier',
661
+ 'hand-held computer, hand-held microcomputer',
662
+ 'handkerchief, hankie, hanky, hankey',
663
+ 'hard disc, hard disk, fixed disk',
664
+ 'harmonica, mouth organ, harp, mouth harp',
665
+ 'harp',
666
+ 'harvester, reaper',
667
+ 'hatchet',
668
+ 'holster',
669
+ 'home theater, home theatre',
670
+ 'honeycomb',
671
+ 'hook, claw',
672
+ 'hoopskirt, crinoline',
673
+ 'horizontal bar, high bar',
674
+ 'horse cart, horse-cart',
675
+ 'hourglass',
676
+ 'iPod',
677
+ 'iron, smoothing iron',
678
+ "jack-o'-lantern",
679
+ 'jean, blue jean, denim',
680
+ 'jeep, landrover',
681
+ 'jersey, T-shirt, tee shirt',
682
+ 'jigsaw puzzle',
683
+ 'jinrikisha, ricksha, rickshaw',
684
+ 'joystick',
685
+ 'kimono',
686
+ 'knee pad',
687
+ 'knot',
688
+ 'lab coat, laboratory coat',
689
+ 'ladle',
690
+ 'lampshade, lamp shade',
691
+ 'laptop, laptop computer',
692
+ 'lawn mower, mower',
693
+ 'lens cap, lens cover',
694
+ 'letter opener, paper knife, paperknife',
695
+ 'library',
696
+ 'lifeboat',
697
+ 'lighter, light, igniter, ignitor',
698
+ 'limousine, limo',
699
+ 'liner, ocean liner',
700
+ 'lipstick, lip rouge',
701
+ 'Loafer',
702
+ 'lotion',
703
+ 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system', # noqa: E501
704
+ "loupe, jeweler's loupe",
705
+ 'lumbermill, sawmill',
706
+ 'magnetic compass',
707
+ 'mailbag, postbag',
708
+ 'mailbox, letter box',
709
+ 'maillot',
710
+ 'maillot, tank suit',
711
+ 'manhole cover',
712
+ 'maraca',
713
+ 'marimba, xylophone',
714
+ 'mask',
715
+ 'matchstick',
716
+ 'maypole',
717
+ 'maze, labyrinth',
718
+ 'measuring cup',
719
+ 'medicine chest, medicine cabinet',
720
+ 'megalith, megalithic structure',
721
+ 'microphone, mike',
722
+ 'microwave, microwave oven',
723
+ 'military uniform',
724
+ 'milk can',
725
+ 'minibus',
726
+ 'miniskirt, mini',
727
+ 'minivan',
728
+ 'missile',
729
+ 'mitten',
730
+ 'mixing bowl',
731
+ 'mobile home, manufactured home',
732
+ 'Model T',
733
+ 'modem',
734
+ 'monastery',
735
+ 'monitor',
736
+ 'moped',
737
+ 'mortar',
738
+ 'mortarboard',
739
+ 'mosque',
740
+ 'mosquito net',
741
+ 'motor scooter, scooter',
742
+ 'mountain bike, all-terrain bike, off-roader',
743
+ 'mountain tent',
744
+ 'mouse, computer mouse',
745
+ 'mousetrap',
746
+ 'moving van',
747
+ 'muzzle',
748
+ 'nail',
749
+ 'neck brace',
750
+ 'necklace',
751
+ 'nipple',
752
+ 'notebook, notebook computer',
753
+ 'obelisk',
754
+ 'oboe, hautboy, hautbois',
755
+ 'ocarina, sweet potato',
756
+ 'odometer, hodometer, mileometer, milometer',
757
+ 'oil filter',
758
+ 'organ, pipe organ',
759
+ 'oscilloscope, scope, cathode-ray oscilloscope, CRO',
760
+ 'overskirt',
761
+ 'oxcart',
762
+ 'oxygen mask',
763
+ 'packet',
764
+ 'paddle, boat paddle',
765
+ 'paddlewheel, paddle wheel',
766
+ 'padlock',
767
+ 'paintbrush',
768
+ "pajama, pyjama, pj's, jammies",
769
+ 'palace',
770
+ 'panpipe, pandean pipe, syrinx',
771
+ 'paper towel',
772
+ 'parachute, chute',
773
+ 'parallel bars, bars',
774
+ 'park bench',
775
+ 'parking meter',
776
+ 'passenger car, coach, carriage',
777
+ 'patio, terrace',
778
+ 'pay-phone, pay-station',
779
+ 'pedestal, plinth, footstall',
780
+ 'pencil box, pencil case',
781
+ 'pencil sharpener',
782
+ 'perfume, essence',
783
+ 'Petri dish',
784
+ 'photocopier',
785
+ 'pick, plectrum, plectron',
786
+ 'pickelhaube',
787
+ 'picket fence, paling',
788
+ 'pickup, pickup truck',
789
+ 'pier',
790
+ 'piggy bank, penny bank',
791
+ 'pill bottle',
792
+ 'pillow',
793
+ 'ping-pong ball',
794
+ 'pinwheel',
795
+ 'pirate, pirate ship',
796
+ 'pitcher, ewer',
797
+ "plane, carpenter's plane, woodworking plane",
798
+ 'planetarium',
799
+ 'plastic bag',
800
+ 'plate rack',
801
+ 'plow, plough',
802
+ "plunger, plumber's helper",
803
+ 'Polaroid camera, Polaroid Land camera',
804
+ 'pole',
805
+ 'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria', # noqa: E501
806
+ 'poncho',
807
+ 'pool table, billiard table, snooker table',
808
+ 'pop bottle, soda bottle',
809
+ 'pot, flowerpot',
810
+ "potter's wheel",
811
+ 'power drill',
812
+ 'prayer rug, prayer mat',
813
+ 'printer',
814
+ 'prison, prison house',
815
+ 'projectile, missile',
816
+ 'projector',
817
+ 'puck, hockey puck',
818
+ 'punching bag, punch bag, punching ball, punchball',
819
+ 'purse',
820
+ 'quill, quill pen',
821
+ 'quilt, comforter, comfort, puff',
822
+ 'racer, race car, racing car',
823
+ 'racket, racquet',
824
+ 'radiator',
825
+ 'radio, wireless',
826
+ 'radio telescope, radio reflector',
827
+ 'rain barrel',
828
+ 'recreational vehicle, RV, R.V.',
829
+ 'reel',
830
+ 'reflex camera',
831
+ 'refrigerator, icebox',
832
+ 'remote control, remote',
833
+ 'restaurant, eating house, eating place, eatery',
834
+ 'revolver, six-gun, six-shooter',
835
+ 'rifle',
836
+ 'rocking chair, rocker',
837
+ 'rotisserie',
838
+ 'rubber eraser, rubber, pencil eraser',
839
+ 'rugby ball',
840
+ 'rule, ruler',
841
+ 'running shoe',
842
+ 'safe',
843
+ 'safety pin',
844
+ 'saltshaker, salt shaker',
845
+ 'sandal',
846
+ 'sarong',
847
+ 'sax, saxophone',
848
+ 'scabbard',
849
+ 'scale, weighing machine',
850
+ 'school bus',
851
+ 'schooner',
852
+ 'scoreboard',
853
+ 'screen, CRT screen',
854
+ 'screw',
855
+ 'screwdriver',
856
+ 'seat belt, seatbelt',
857
+ 'sewing machine',
858
+ 'shield, buckler',
859
+ 'shoe shop, shoe-shop, shoe store',
860
+ 'shoji',
861
+ 'shopping basket',
862
+ 'shopping cart',
863
+ 'shovel',
864
+ 'shower cap',
865
+ 'shower curtain',
866
+ 'ski',
867
+ 'ski mask',
868
+ 'sleeping bag',
869
+ 'slide rule, slipstick',
870
+ 'sliding door',
871
+ 'slot, one-armed bandit',
872
+ 'snorkel',
873
+ 'snowmobile',
874
+ 'snowplow, snowplough',
875
+ 'soap dispenser',
876
+ 'soccer ball',
877
+ 'sock',
878
+ 'solar dish, solar collector, solar furnace',
879
+ 'sombrero',
880
+ 'soup bowl',
881
+ 'space bar',
882
+ 'space heater',
883
+ 'space shuttle',
884
+ 'spatula',
885
+ 'speedboat',
886
+ "spider web, spider's web",
887
+ 'spindle',
888
+ 'sports car, sport car',
889
+ 'spotlight, spot',
890
+ 'stage',
891
+ 'steam locomotive',
892
+ 'steel arch bridge',
893
+ 'steel drum',
894
+ 'stethoscope',
895
+ 'stole',
896
+ 'stone wall',
897
+ 'stopwatch, stop watch',
898
+ 'stove',
899
+ 'strainer',
900
+ 'streetcar, tram, tramcar, trolley, trolley car',
901
+ 'stretcher',
902
+ 'studio couch, day bed',
903
+ 'stupa, tope',
904
+ 'submarine, pigboat, sub, U-boat',
905
+ 'suit, suit of clothes',
906
+ 'sundial',
907
+ 'sunglass',
908
+ 'sunglasses, dark glasses, shades',
909
+ 'sunscreen, sunblock, sun blocker',
910
+ 'suspension bridge',
911
+ 'swab, swob, mop',
912
+ 'sweatshirt',
913
+ 'swimming trunks, bathing trunks',
914
+ 'swing',
915
+ 'switch, electric switch, electrical switch',
916
+ 'syringe',
917
+ 'table lamp',
918
+ 'tank, army tank, armored combat vehicle, armoured combat vehicle',
919
+ 'tape player',
920
+ 'teapot',
921
+ 'teddy, teddy bear',
922
+ 'television, television system',
923
+ 'tennis ball',
924
+ 'thatch, thatched roof',
925
+ 'theater curtain, theatre curtain',
926
+ 'thimble',
927
+ 'thresher, thrasher, threshing machine',
928
+ 'throne',
929
+ 'tile roof',
930
+ 'toaster',
931
+ 'tobacco shop, tobacconist shop, tobacconist',
932
+ 'toilet seat',
933
+ 'torch',
934
+ 'totem pole',
935
+ 'tow truck, tow car, wrecker',
936
+ 'toyshop',
937
+ 'tractor',
938
+ 'trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi', # noqa: E501
939
+ 'tray',
940
+ 'trench coat',
941
+ 'tricycle, trike, velocipede',
942
+ 'trimaran',
943
+ 'tripod',
944
+ 'triumphal arch',
945
+ 'trolleybus, trolley coach, trackless trolley',
946
+ 'trombone',
947
+ 'tub, vat',
948
+ 'turnstile',
949
+ 'typewriter keyboard',
950
+ 'umbrella',
951
+ 'unicycle, monocycle',
952
+ 'upright, upright piano',
953
+ 'vacuum, vacuum cleaner',
954
+ 'vase',
955
+ 'vault',
956
+ 'velvet',
957
+ 'vending machine',
958
+ 'vestment',
959
+ 'viaduct',
960
+ 'violin, fiddle',
961
+ 'volleyball',
962
+ 'waffle iron',
963
+ 'wall clock',
964
+ 'wallet, billfold, notecase, pocketbook',
965
+ 'wardrobe, closet, press',
966
+ 'warplane, military plane',
967
+ 'washbasin, handbasin, washbowl, lavabo, wash-hand basin',
968
+ 'washer, automatic washer, washing machine',
969
+ 'water bottle',
970
+ 'water jug',
971
+ 'water tower',
972
+ 'whiskey jug',
973
+ 'whistle',
974
+ 'wig',
975
+ 'window screen',
976
+ 'window shade',
977
+ 'Windsor tie',
978
+ 'wine bottle',
979
+ 'wing',
980
+ 'wok',
981
+ 'wooden spoon',
982
+ 'wool, woolen, woollen',
983
+ 'worm fence, snake fence, snake-rail fence, Virginia fence',
984
+ 'wreck',
985
+ 'yawl',
986
+ 'yurt',
987
+ 'web site, website, internet site, site',
988
+ 'comic book',
989
+ 'crossword puzzle, crossword',
990
+ 'street sign',
991
+ 'traffic light, traffic signal, stoplight',
992
+ 'book jacket, dust cover, dust jacket, dust wrapper',
993
+ 'menu',
994
+ 'plate',
995
+ 'guacamole',
996
+ 'consomme',
997
+ 'hot pot, hotpot',
998
+ 'trifle',
999
+ 'ice cream, icecream',
1000
+ 'ice lolly, lolly, lollipop, popsicle',
1001
+ 'French loaf',
1002
+ 'bagel, beigel',
1003
+ 'pretzel',
1004
+ 'cheeseburger',
1005
+ 'hotdog, hot dog, red hot',
1006
+ 'mashed potato',
1007
+ 'head cabbage',
1008
+ 'broccoli',
1009
+ 'cauliflower',
1010
+ 'zucchini, courgette',
1011
+ 'spaghetti squash',
1012
+ 'acorn squash',
1013
+ 'butternut squash',
1014
+ 'cucumber, cuke',
1015
+ 'artichoke, globe artichoke',
1016
+ 'bell pepper',
1017
+ 'cardoon',
1018
+ 'mushroom',
1019
+ 'Granny Smith',
1020
+ 'strawberry',
1021
+ 'orange',
1022
+ 'lemon',
1023
+ 'fig',
1024
+ 'pineapple, ananas',
1025
+ 'banana',
1026
+ 'jackfruit, jak, jack',
1027
+ 'custard apple',
1028
+ 'pomegranate',
1029
+ 'hay',
1030
+ 'carbonara',
1031
+ 'chocolate sauce, chocolate syrup',
1032
+ 'dough',
1033
+ 'meat loaf, meatloaf',
1034
+ 'pizza, pizza pie',
1035
+ 'potpie',
1036
+ 'burrito',
1037
+ 'red wine',
1038
+ 'espresso',
1039
+ 'cup',
1040
+ 'eggnog',
1041
+ 'alp',
1042
+ 'bubble',
1043
+ 'cliff, drop, drop-off',
1044
+ 'coral reef',
1045
+ 'geyser',
1046
+ 'lakeside, lakeshore',
1047
+ 'promontory, headland, head, foreland',
1048
+ 'sandbar, sand bar',
1049
+ 'seashore, coast, seacoast, sea-coast',
1050
+ 'valley, vale',
1051
+ 'volcano',
1052
+ 'ballplayer, baseball player',
1053
+ 'groom, bridegroom',
1054
+ 'scuba diver',
1055
+ 'rapeseed',
1056
+ 'daisy',
1057
+ "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum", # noqa: E501
1058
+ 'corn',
1059
+ 'acorn',
1060
+ 'hip, rose hip, rosehip',
1061
+ 'buckeye, horse chestnut, conker',
1062
+ 'coral fungus',
1063
+ 'agaric',
1064
+ 'gyromitra',
1065
+ 'stinkhorn, carrion fungus',
1066
+ 'earthstar',
1067
+ 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa', # noqa: E501
1068
+ 'bolete',
1069
+ 'ear, spike, capitulum',
1070
+ 'toilet tissue, toilet paper, bathroom tissue')
1071
+
1072
+ CIFAR10_CATEGORIES = ('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog',
1073
+ 'frog', 'horse', 'ship', 'truck')
1074
+
1075
+ CIFAR100_CATEGORIES = (
1076
+ 'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
1077
+ 'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
1078
+ 'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
1079
+ 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
1080
+ 'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
1081
+ 'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
1082
+ 'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain',
1083
+ 'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree',
1084
+ 'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy',
1085
+ 'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket',
1086
+ 'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail',
1087
+ 'snake', 'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper',
1088
+ 'table', 'tank', 'telephone', 'television', 'tiger', 'tractor', 'train',
1089
+ 'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf',
1090
+ 'woman', 'worm')
1091
+
1092
+ MNIST_CATEGORITES = ('0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
1093
+ '5 - five', '6 - six', '7 - seven', '8 - eight',
1094
+ '9 - nine')
1095
+
1096
+ FASHIONMNIST_CATEGORITES = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress',
1097
+ 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag',
1098
+ 'Ankle boot')
1099
+
1100
+ PLACES205_CATEGORIES = (
1101
+ 'abbey', 'airport_terminal', 'alley', 'amphitheater', 'amusement_park',
1102
+ 'aquarium', 'aqueduct', 'arch', 'art_gallery', 'art_studio',
1103
+ 'assembly_line', 'attic', 'auditorium', 'apartment_building/outdoor',
1104
+ 'badlands', 'ballroom', 'bamboo_forest', 'banquet_hall', 'bar',
1105
+ 'baseball_field', 'basement', 'basilica', 'bayou', 'beauty_salon',
1106
+ 'bedroom', 'boardwalk', 'boat_deck', 'bookstore', 'botanical_garden',
1107
+ 'bowling_alley', 'boxing_ring', 'bridge', 'building_facade',
1108
+ 'bus_interior', 'butchers_shop', 'butte', 'bakery/shop', 'cafeteria',
1109
+ 'campsite', 'candy_store', 'canyon', 'castle', 'cemetery', 'chalet',
1110
+ 'classroom', 'closet', 'clothing_store', 'coast', 'cockpit', 'coffee_shop',
1111
+ 'conference_center', 'conference_room', 'construction_site', 'corn_field',
1112
+ 'corridor', 'cottage_garden', 'courthouse', 'courtyard', 'creek',
1113
+ 'crevasse', 'crosswalk', 'cathedral/outdoor', 'church/outdoor', 'dam',
1114
+ 'dining_room', 'dock', 'dorm_room', 'driveway', 'desert/sand',
1115
+ 'desert/vegetation', 'dinette/home', 'doorway/outdoor', 'engine_room',
1116
+ 'excavation', 'fairway', 'fire_escape', 'fire_station', 'food_court',
1117
+ 'forest_path', 'forest_road', 'formal_garden', 'fountain',
1118
+ 'field/cultivated', 'field/wild', 'galley', 'game_room', 'garbage_dump',
1119
+ 'gas_station', 'gift_shop', 'golf_course', 'harbor', 'herb_garden',
1120
+ 'highway', 'home_office', 'hospital', 'hospital_room', 'hot_spring',
1121
+ 'hotel_room', 'hotel/outdoor', 'ice_cream_parlor', 'iceberg', 'igloo',
1122
+ 'islet', 'ice_skating_rink/outdoor', 'inn/outdoor', 'jail_cell', 'kasbah',
1123
+ 'kindergarden_classroom', 'kitchen', 'kitchenette', 'laundromat',
1124
+ 'lighthouse', 'living_room', 'lobby', 'locker_room', 'mansion', 'marsh',
1125
+ 'martial_arts_gym', 'mausoleum', 'medina', 'motel', 'mountain',
1126
+ 'mountain_snowy', 'music_studio', 'market/outdoor', 'monastery/outdoor',
1127
+ 'museum/indoor', 'nursery', 'ocean', 'office', 'office_building',
1128
+ 'orchard', 'pagoda', 'palace', 'pantry', 'parking_lot', 'parlor',
1129
+ 'pasture', 'patio', 'pavilion', 'phone_booth', 'picnic_area', 'playground',
1130
+ 'plaza', 'pond', 'pulpit', 'racecourse', 'raft', 'railroad_track',
1131
+ 'rainforest', 'reception', 'residential_neighborhood', 'restaurant',
1132
+ 'restaurant_kitchen', 'restaurant_patio', 'rice_paddy', 'river',
1133
+ 'rock_arch', 'rope_bridge', 'ruin', 'runway', 'sandbar', 'schoolhouse',
1134
+ 'sea_cliff', 'shed', 'shoe_shop', 'shopfront', 'shower', 'ski_resort',
1135
+ 'ski_slope', 'sky', 'skyscraper', 'slum', 'snowfield', 'staircase',
1136
+ 'supermarket', 'swamp', 'stadium/baseball', 'stadium/football',
1137
+ 'stage/indoor', 'subway_station/platform', 'swimming_pool/outdoor',
1138
+ 'television_studio', 'topiary_garden', 'tower', 'train_railway',
1139
+ 'tree_farm', 'trench', 'temple/east_asia', 'temple/south_asia',
1140
+ 'track/outdoor', 'train_station/platform', 'underwater/coral_reef',
1141
+ 'valley', 'vegetable_garden', 'veranda', 'viaduct', 'volcano',
1142
+ 'waiting_room', 'water_tower', 'watering_hole', 'wheat_field', 'wind_farm',
1143
+ 'windmill', 'yard')
1144
+
1145
+ OxfordIIITPet_CATEGORIES = (
1146
+ 'Abyssinian', 'american_bulldog', 'american_pit_bull_terrier',
1147
+ 'basset_hound', 'beagle', 'Bengal', 'Birman', 'Bombay', 'boxer',
1148
+ 'British_Shorthair', 'chihuahua', 'Egyptian_Mau', 'english_cocker_spaniel',
1149
+ 'english_setter', 'german_shorthaired', 'great_pyrenees', 'havanese',
1150
+ 'japanese_chin', 'keeshond', 'leonberger', 'Maine_Coon',
1151
+ 'miniature_pinscher', 'newfoundland', 'Persian', 'pomeranian', 'pug',
1152
+ 'Ragdoll', 'Russian_Blue', 'saint_bernard', 'samoyed', 'scottish_terrier',
1153
+ 'shiba_inu', 'Siamese', 'Sphynx', 'staffordshire_bull_terrier',
1154
+ 'wheaten_terrier', 'yorkshire_terrier')
1155
+
1156
+ DTD_CATEGORIES = ('banded', 'blotchy', 'braided', 'bubbly', 'bumpy',
1157
+ 'chequered', 'cobwebbed', 'cracked', 'crosshatched',
1158
+ 'crystalline', 'dotted', 'fibrous', 'flecked', 'freckled',
1159
+ 'frilly', 'gauzy', 'grid', 'grooved', 'honeycombed',
1160
+ 'interlaced', 'knitted', 'lacelike', 'lined', 'marbled',
1161
+ 'matted', 'meshed', 'paisley', 'perforated', 'pitted',
1162
+ 'pleated', 'polka-dotted', 'porous', 'potholed', 'scaly',
1163
+ 'smeared', 'spiralled', 'sprinkled', 'stained', 'stratified',
1164
+ 'striped', 'studded', 'swirly', 'veined', 'waffled', 'woven',
1165
+ 'wrinkled', 'zigzagged')
1166
+
1167
+ FGVCAIRCRAFT_CATEGORIES = (
1168
+ '707-320', '727-200', '737-200', '737-300', '737-400', '737-500',
1169
+ '737-600', '737-700', '737-800', '737-900', '747-100', '747-200',
1170
+ '747-300', '747-400', '757-200', '757-300', '767-200', '767-300',
1171
+ '767-400', '777-200', '777-300', 'A300B4', 'A310', 'A318', 'A319', 'A320',
1172
+ 'A321', 'A330-200', 'A330-300', 'A340-200', 'A340-300', 'A340-500',
1173
+ 'A340-600', 'A380', 'ATR-42', 'ATR-72', 'An-12', 'BAE 146-200',
1174
+ 'BAE 146-300', 'BAE-125', 'Beechcraft 1900', 'Boeing 717', 'C-130', 'C-47',
1175
+ 'CRJ-200', 'CRJ-700', 'CRJ-900', 'Cessna 172', 'Cessna 208', 'Cessna 525',
1176
+ 'Cessna 560', 'Challenger 600', 'DC-10', 'DC-3', 'DC-6', 'DC-8', 'DC-9-30',
1177
+ 'DH-82', 'DHC-1', 'DHC-6', 'DHC-8-100', 'DHC-8-300', 'DR-400',
1178
+ 'Dornier 328', 'E-170', 'E-190', 'E-195', 'EMB-120', 'ERJ 135', 'ERJ 145',
1179
+ 'Embraer Legacy 600', 'Eurofighter Typhoon', 'F-16A/B', 'F/A-18',
1180
+ 'Falcon 2000', 'Falcon 900', 'Fokker 100', 'Fokker 50', 'Fokker 70',
1181
+ 'Global Express', 'Gulfstream IV', 'Gulfstream V', 'Hawk T1', 'Il-76',
1182
+ 'L-1011', 'MD-11', 'MD-80', 'MD-87', 'MD-90', 'Metroliner', 'Model B200',
1183
+ 'PA-28', 'SR-20', 'Saab 2000', 'Saab 340', 'Spitfire', 'Tornado', 'Tu-134',
1184
+ 'Tu-154', 'Yak-42')
1185
+
1186
+ STANFORDCARS_CATEGORIES = (
1187
+ 'AM General Hummer SUV 2000', 'Acura RL Sedan 2012', 'Acura TL Sedan 2012',
1188
+ 'Acura TL Type-S 2008', 'Acura TSX Sedan 2012',
1189
+ 'Acura Integra Type R 2001', 'Acura ZDX Hatchback 2012',
1190
+ 'Aston Martin V8 Vantage Convertible 2012',
1191
+ 'Aston Martin V8 Vantage Coupe 2012',
1192
+ 'Aston Martin Virage Convertible 2012', 'Aston Martin Virage Coupe 2012',
1193
+ 'Audi RS 4 Convertible 2008', 'Audi A5 Coupe 2012', 'Audi TTS Coupe 2012',
1194
+ 'Audi R8 Coupe 2012', 'Audi V8 Sedan 1994', 'Audi 100 Sedan 1994',
1195
+ 'Audi 100 Wagon 1994', 'Audi TT Hatchback 2011', 'Audi S6 Sedan 2011',
1196
+ 'Audi S5 Convertible 2012', 'Audi S5 Coupe 2012', 'Audi S4 Sedan 2012',
1197
+ 'Audi S4 Sedan 2007', 'Audi TT RS Coupe 2012',
1198
+ 'BMW ActiveHybrid 5 Sedan 2012', 'BMW 1 Series Convertible 2012',
1199
+ 'BMW 1 Series Coupe 2012', 'BMW 3 Series Sedan 2012',
1200
+ 'BMW 3 Series Wagon 2012', 'BMW 6 Series Convertible 2007',
1201
+ 'BMW X5 SUV 2007', 'BMW X6 SUV 2012', 'BMW M3 Coupe 2012',
1202
+ 'BMW M5 Sedan 2010', 'BMW M6 Convertible 2010', 'BMW X3 SUV 2012',
1203
+ 'BMW Z4 Convertible 2012',
1204
+ 'Bentley Continental Supersports Conv. Convertible 2012',
1205
+ 'Bentley Arnage Sedan 2009', 'Bentley Mulsanne Sedan 2011',
1206
+ 'Bentley Continental GT Coupe 2012', 'Bentley Continental GT Coupe 2007',
1207
+ 'Bentley Continental Flying Spur Sedan 2007',
1208
+ 'Bugatti Veyron 16.4 Convertible 2009', 'Bugatti Veyron 16.4 Coupe 2009',
1209
+ 'Buick Regal GS 2012', 'Buick Rainier SUV 2007', 'Buick Verano Sedan 2012',
1210
+ 'Buick Enclave SUV 2012', 'Cadillac CTS-V Sedan 2012',
1211
+ 'Cadillac SRX SUV 2012', 'Cadillac Escalade EXT Crew Cab 2007',
1212
+ 'Chevrolet Silverado 1500 Hybrid Crew Cab 2012',
1213
+ 'Chevrolet Corvette Convertible 2012', 'Chevrolet Corvette ZR1 2012',
1214
+ 'Chevrolet Corvette Ron Fellows Edition Z06 2007',
1215
+ 'Chevrolet Traverse SUV 2012', 'Chevrolet Camaro Convertible 2012',
1216
+ 'Chevrolet HHR SS 2010', 'Chevrolet Impala Sedan 2007',
1217
+ 'Chevrolet Tahoe Hybrid SUV 2012', 'Chevrolet Sonic Sedan 2012',
1218
+ 'Chevrolet Express Cargo Van 2007', 'Chevrolet Avalanche Crew Cab 2012',
1219
+ 'Chevrolet Cobalt SS 2010', 'Chevrolet Malibu Hybrid Sedan 2010',
1220
+ 'Chevrolet TrailBlazer SS 2009',
1221
+ 'Chevrolet Silverado 2500HD Regular Cab 2012',
1222
+ 'Chevrolet Silverado 1500 Classic Extended Cab 2007',
1223
+ 'Chevrolet Express Van 2007', 'Chevrolet Monte Carlo Coupe 2007',
1224
+ 'Chevrolet Malibu Sedan 2007',
1225
+ 'Chevrolet Silverado 1500 Extended Cab 2012',
1226
+ 'Chevrolet Silverado 1500 Regular Cab 2012', 'Chrysler Aspen SUV 2009',
1227
+ 'Chrysler Sebring Convertible 2010',
1228
+ 'Chrysler Town and Country Minivan 2012', 'Chrysler 300 SRT-8 2010',
1229
+ 'Chrysler Crossfire Convertible 2008',
1230
+ 'Chrysler PT Cruiser Convertible 2008', 'Daewoo Nubira Wagon 2002',
1231
+ 'Dodge Caliber Wagon 2012', 'Dodge Caliber Wagon 2007',
1232
+ 'Dodge Caravan Minivan 1997', 'Dodge Ram Pickup 3500 Crew Cab 2010',
1233
+ 'Dodge Ram Pickup 3500 Quad Cab 2009', 'Dodge Sprinter Cargo Van 2009',
1234
+ 'Dodge Journey SUV 2012', 'Dodge Dakota Crew Cab 2010',
1235
+ 'Dodge Dakota Club Cab 2007', 'Dodge Magnum Wagon 2008',
1236
+ 'Dodge Challenger SRT8 2011', 'Dodge Durango SUV 2012',
1237
+ 'Dodge Durango SUV 2007', 'Dodge Charger Sedan 2012',
1238
+ 'Dodge Charger SRT-8 2009', 'Eagle Talon Hatchback 1998',
1239
+ 'FIAT 500 Abarth 2012', 'FIAT 500 Convertible 2012',
1240
+ 'Ferrari FF Coupe 2012', 'Ferrari California Convertible 2012',
1241
+ 'Ferrari 458 Italia Convertible 2012', 'Ferrari 458 Italia Coupe 2012',
1242
+ 'Fisker Karma Sedan 2012', 'Ford F-450 Super Duty Crew Cab 2012',
1243
+ 'Ford Mustang Convertible 2007', 'Ford Freestar Minivan 2007',
1244
+ 'Ford Expedition EL SUV 2009', 'Ford Edge SUV 2012',
1245
+ 'Ford Ranger SuperCab 2011', 'Ford GT Coupe 2006',
1246
+ 'Ford F-150 Regular Cab 2012', 'Ford F-150 Regular Cab 2007',
1247
+ 'Ford Focus Sedan 2007', 'Ford E-Series Wagon Van 2012',
1248
+ 'Ford Fiesta Sedan 2012', 'GMC Terrain SUV 2012', 'GMC Savana Van 2012',
1249
+ 'GMC Yukon Hybrid SUV 2012', 'GMC Acadia SUV 2012',
1250
+ 'GMC Canyon Extended Cab 2012', 'Geo Metro Convertible 1993',
1251
+ 'HUMMER H3T Crew Cab 2010', 'HUMMER H2 SUT Crew Cab 2009',
1252
+ 'Honda Odyssey Minivan 2012', 'Honda Odyssey Minivan 2007',
1253
+ 'Honda Accord Coupe 2012', 'Honda Accord Sedan 2012',
1254
+ 'Hyundai Veloster Hatchback 2012', 'Hyundai Santa Fe SUV 2012',
1255
+ 'Hyundai Tucson SUV 2012', 'Hyundai Veracruz SUV 2012',
1256
+ 'Hyundai Sonata Hybrid Sedan 2012', 'Hyundai Elantra Sedan 2007',
1257
+ 'Hyundai Accent Sedan 2012', 'Hyundai Genesis Sedan 2012',
1258
+ 'Hyundai Sonata Sedan 2012', 'Hyundai Elantra Touring Hatchback 2012',
1259
+ 'Hyundai Azera Sedan 2012', 'Infiniti G Coupe IPL 2012',
1260
+ 'Infiniti QX56 SUV 2011', 'Isuzu Ascender SUV 2008', 'Jaguar XK XKR 2012',
1261
+ 'Jeep Patriot SUV 2012', 'Jeep Wrangler SUV 2012', 'Jeep Liberty SUV 2012',
1262
+ 'Jeep Grand Cherokee SUV 2012', 'Jeep Compass SUV 2012',
1263
+ 'Lamborghini Reventon Coupe 2008', 'Lamborghini Aventador Coupe 2012',
1264
+ 'Lamborghini Gallardo LP 570-4 Superleggera 2012',
1265
+ 'Lamborghini Diablo Coupe 2001', 'Land Rover Range Rover SUV 2012',
1266
+ 'Land Rover LR2 SUV 2012', 'Lincoln Town Car Sedan 2011',
1267
+ 'MINI Cooper Roadster Convertible 2012',
1268
+ 'Maybach Landaulet Convertible 2012', 'Mazda Tribute SUV 2011',
1269
+ 'McLaren MP4-12C Coupe 2012', 'Mercedes-Benz 300-Class Convertible 1993',
1270
+ 'Mercedes-Benz C-Class Sedan 2012', 'Mercedes-Benz SL-Class Coupe 2009',
1271
+ 'Mercedes-Benz E-Class Sedan 2012', 'Mercedes-Benz S-Class Sedan 2012',
1272
+ 'Mercedes-Benz Sprinter Van 2012', 'Mitsubishi Lancer Sedan 2012',
1273
+ 'Nissan Leaf Hatchback 2012', 'Nissan NV Passenger Van 2012',
1274
+ 'Nissan Juke Hatchback 2012', 'Nissan 240SX Coupe 1998',
1275
+ 'Plymouth Neon Coupe 1999', 'Porsche Panamera Sedan 2012',
1276
+ 'Ram C/V Cargo Van Minivan 2012',
1277
+ 'Rolls-Royce Phantom Drophead Coupe Convertible 2012',
1278
+ 'Rolls-Royce Ghost Sedan 2012', 'Rolls-Royce Phantom Sedan 2012',
1279
+ 'Scion xD Hatchback 2012', 'Spyker C8 Convertible 2009',
1280
+ 'Spyker C8 Coupe 2009', 'Suzuki Aerio Sedan 2007',
1281
+ 'Suzuki Kizashi Sedan 2012', 'Suzuki SX4 Hatchback 2012',
1282
+ 'Suzuki SX4 Sedan 2012', 'Tesla Model S Sedan 2012',
1283
+ 'Toyota Sequoia SUV 2012', 'Toyota Camry Sedan 2012',
1284
+ 'Toyota Corolla Sedan 2012', 'Toyota 4Runner SUV 2012',
1285
+ 'Volkswagen Golf Hatchback 2012', 'Volkswagen Golf Hatchback 1991',
1286
+ 'Volkswagen Beetle Hatchback 2012', 'Volvo C30 Hatchback 2012',
1287
+ 'Volvo 240 Sedan 1993', 'Volvo XC90 SUV 2007',
1288
+ 'smart fortwo Convertible 2012')
1289
+
1290
+ SUN397_CATEGORIES = (
1291
+ 'abbey', 'airplane_cabin', 'airport_terminal', 'alley', 'amphitheater',
1292
+ 'amusement_arcade', 'amusement_park', 'anechoic_chamber',
1293
+ 'apartment_building_outdoor', 'apse_indoor', 'aquarium', 'aqueduct',
1294
+ 'arch', 'archive', 'arrival_gate_outdoor', 'art_gallery', 'art_school',
1295
+ 'art_studio', 'assembly_line', 'athletic_field_outdoor', 'atrium_public',
1296
+ 'attic', 'auditorium', 'auto_factory', 'badlands',
1297
+ 'badminton_court_indoor', 'baggage_claim', 'bakery_shop',
1298
+ 'balcony_exterior', 'balcony_interior', 'ball_pit', 'ballroom',
1299
+ 'bamboo_forest', 'banquet_hall', 'bar', 'barn', 'barndoor',
1300
+ 'baseball_field', 'basement', 'basilica', 'basketball_court_outdoor',
1301
+ 'bathroom', 'batters_box', 'bayou', 'bazaar_indoor', 'bazaar_outdoor',
1302
+ 'beach', 'beauty_salon', 'bedroom', 'berth', 'biology_laboratory',
1303
+ 'bistro_indoor', 'boardwalk', 'boat_deck', 'boathouse', 'bookstore',
1304
+ 'booth_indoor', 'botanical_garden', 'bow_window_indoor',
1305
+ 'bow_window_outdoor', 'bowling_alley', 'boxing_ring', 'brewery_indoor',
1306
+ 'bridge', 'building_facade', 'bullring', 'burial_chamber', 'bus_interior',
1307
+ 'butchers_shop', 'butte', 'cabin_outdoor', 'cafeteria', 'campsite',
1308
+ 'campus', 'canal_natural', 'canal_urban', 'candy_store', 'canyon',
1309
+ 'car_interior_backseat', 'car_interior_frontseat', 'carrousel',
1310
+ 'casino_indoor', 'castle', 'catacomb', 'cathedral_indoor',
1311
+ 'cathedral_outdoor', 'cavern_indoor', 'cemetery', 'chalet',
1312
+ 'cheese_factory', 'chemistry_lab', 'chicken_coop_indoor',
1313
+ 'chicken_coop_outdoor', 'childs_room', 'church_indoor', 'church_outdoor',
1314
+ 'classroom', 'clean_room', 'cliff', 'cloister_indoor', 'closet',
1315
+ 'clothing_store', 'coast', 'cockpit', 'coffee_shop', 'computer_room',
1316
+ 'conference_center', 'conference_room', 'construction_site',
1317
+ 'control_room', 'control_tower_outdoor', 'corn_field', 'corral',
1318
+ 'corridor', 'cottage_garden', 'courthouse', 'courtroom', 'courtyard',
1319
+ 'covered_bridge_exterior', 'creek', 'crevasse', 'crosswalk',
1320
+ 'cubicle_office', 'dam', 'delicatessen', 'dentists_office', 'desert_sand',
1321
+ 'desert_vegetation', 'diner_indoor', 'diner_outdoor', 'dinette_home',
1322
+ 'dinette_vehicle', 'dining_car', 'dining_room', 'discotheque', 'dock',
1323
+ 'doorway_outdoor', 'dorm_room', 'driveway', 'driving_range_outdoor',
1324
+ 'drugstore', 'electrical_substation', 'elevator_door', 'elevator_interior',
1325
+ 'elevator_shaft', 'engine_room', 'escalator_indoor', 'excavation',
1326
+ 'factory_indoor', 'fairway', 'fastfood_restaurant', 'field_cultivated',
1327
+ 'field_wild', 'fire_escape', 'fire_station', 'firing_range_indoor',
1328
+ 'fishpond', 'florist_shop_indoor', 'food_court', 'forest_broadleaf',
1329
+ 'forest_needleleaf', 'forest_path', 'forest_road', 'formal_garden',
1330
+ 'fountain', 'galley', 'game_room', 'garage_indoor', 'garbage_dump',
1331
+ 'gas_station', 'gazebo_exterior', 'general_store_indoor',
1332
+ 'general_store_outdoor', 'gift_shop', 'golf_course', 'greenhouse_indoor',
1333
+ 'greenhouse_outdoor', 'gymnasium_indoor', 'hangar_indoor',
1334
+ 'hangar_outdoor', 'harbor', 'hayfield', 'heliport', 'herb_garden',
1335
+ 'highway', 'hill', 'home_office', 'hospital', 'hospital_room',
1336
+ 'hot_spring', 'hot_tub_outdoor', 'hotel_outdoor', 'hotel_room', 'house',
1337
+ 'hunting_lodge_outdoor', 'ice_cream_parlor', 'ice_floe', 'ice_shelf',
1338
+ 'ice_skating_rink_indoor', 'ice_skating_rink_outdoor', 'iceberg', 'igloo',
1339
+ 'industrial_area', 'inn_outdoor', 'islet', 'jacuzzi_indoor', 'jail_indoor',
1340
+ 'jail_cell', 'jewelry_shop', 'kasbah', 'kennel_indoor', 'kennel_outdoor',
1341
+ 'kindergarden_classroom', 'kitchen', 'kitchenette', 'labyrinth_outdoor',
1342
+ 'lake_natural', 'landfill', 'landing_deck', 'laundromat', 'lecture_room',
1343
+ 'library_indoor', 'library_outdoor', 'lido_deck_outdoor', 'lift_bridge',
1344
+ 'lighthouse', 'limousine_interior', 'living_room', 'lobby', 'lock_chamber',
1345
+ 'locker_room', 'mansion', 'manufactured_home', 'market_indoor',
1346
+ 'market_outdoor', 'marsh', 'martial_arts_gym', 'mausoleum', 'medina',
1347
+ 'moat_water', 'monastery_outdoor', 'mosque_indoor', 'mosque_outdoor',
1348
+ 'motel', 'mountain', 'mountain_snowy', 'movie_theater_indoor',
1349
+ 'museum_indoor', 'music_store', 'music_studio',
1350
+ 'nuclear_power_plant_outdoor', 'nursery', 'oast_house',
1351
+ 'observatory_outdoor', 'ocean', 'office', 'office_building',
1352
+ 'oil_refinery_outdoor', 'oilrig', 'operating_room', 'orchard',
1353
+ 'outhouse_outdoor', 'pagoda', 'palace', 'pantry', 'park',
1354
+ 'parking_garage_indoor', 'parking_garage_outdoor', 'parking_lot', 'parlor',
1355
+ 'pasture', 'patio', 'pavilion', 'pharmacy', 'phone_booth',
1356
+ 'physics_laboratory', 'picnic_area', 'pilothouse_indoor',
1357
+ 'planetarium_outdoor', 'playground', 'playroom', 'plaza', 'podium_indoor',
1358
+ 'podium_outdoor', 'pond', 'poolroom_establishment', 'poolroom_home',
1359
+ 'power_plant_outdoor', 'promenade_deck', 'pub_indoor', 'pulpit',
1360
+ 'putting_green', 'racecourse', 'raceway', 'raft', 'railroad_track',
1361
+ 'rainforest', 'reception', 'recreation_room', 'residential_neighborhood',
1362
+ 'restaurant', 'restaurant_kitchen', 'restaurant_patio', 'rice_paddy',
1363
+ 'riding_arena', 'river', 'rock_arch', 'rope_bridge', 'ruin', 'runway',
1364
+ 'sandbar', 'sandbox', 'sauna', 'schoolhouse', 'sea_cliff', 'server_room',
1365
+ 'shed', 'shoe_shop', 'shopfront', 'shopping_mall_indoor', 'shower',
1366
+ 'skatepark', 'ski_lodge', 'ski_resort', 'ski_slope', 'sky', 'skyscraper',
1367
+ 'slum', 'snowfield', 'squash_court', 'stable', 'stadium_baseball',
1368
+ 'stadium_football', 'stage_indoor', 'staircase', 'street',
1369
+ 'subway_interior', 'subway_station_platform', 'supermarket', 'sushi_bar',
1370
+ 'swamp', 'swimming_pool_indoor', 'swimming_pool_outdoor',
1371
+ 'synagogue_indoor', 'synagogue_outdoor', 'television_studio',
1372
+ 'temple_east_asia', 'temple_south_asia', 'tennis_court_indoor',
1373
+ 'tennis_court_outdoor', 'tent_outdoor', 'theater_indoor_procenium',
1374
+ 'theater_indoor_seats', 'thriftshop', 'throne_room', 'ticket_booth',
1375
+ 'toll_plaza', 'topiary_garden', 'tower', 'toyshop', 'track_outdoor',
1376
+ 'train_railway', 'train_station_platform', 'tree_farm', 'tree_house',
1377
+ 'trench', 'underwater_coral_reef', 'utility_room', 'valley',
1378
+ 'van_interior', 'vegetable_garden', 'veranda', 'veterinarians_office',
1379
+ 'viaduct', 'videostore', 'village', 'vineyard', 'volcano',
1380
+ 'volleyball_court_indoor', 'volleyball_court_outdoor', 'waiting_room',
1381
+ 'warehouse_indoor', 'water_tower', 'waterfall_block', 'waterfall_fan',
1382
+ 'waterfall_plunge', 'watering_hole', 'wave', 'wet_bar', 'wheat_field',
1383
+ 'wind_farm', 'windmill', 'wine_cellar_barrel_storage',
1384
+ 'wine_cellar_bottle_storage', 'wrestling_ring_indoor', 'yard',
1385
+ 'youth_hostel')
1386
+
1387
+ CALTECH101_CATEGORIES = (
1388
+ 'BACKGROUND_Google', 'Faces', 'Faces_easy', 'Leopards', 'Motorbikes',
1389
+ 'accordion', 'airplanes', 'anchor', 'ant', 'barrel', 'bass', 'beaver',
1390
+ 'binocular', 'bonsai', 'brain', 'brontosaurus', 'buddha', 'butterfly',
1391
+ 'camera', 'cannon', 'car_side', 'ceiling_fan', 'cellphone', 'chair',
1392
+ 'chandelier', 'cougar_body', 'cougar_face', 'crab', 'crayfish',
1393
+ 'crocodile', 'crocodile_head', 'cup', 'dalmatian', 'dollar_bill',
1394
+ 'dolphin', 'dragonfly', 'electric_guitar', 'elephant', 'emu', 'euphonium',
1395
+ 'ewer', 'ferry', 'flamingo', 'flamingo_head', 'garfield', 'gerenuk',
1396
+ 'gramophone', 'grand_piano', 'hawksbill', 'headphone', 'hedgehog',
1397
+ 'helicopter', 'ibis', 'inline_skate', 'joshua_tree', 'kangaroo', 'ketch',
1398
+ 'lamp', 'laptop', 'llama', 'lobster', 'lotus', 'mandolin', 'mayfly',
1399
+ 'menorah', 'metronome', 'minaret', 'nautilus', 'octopus', 'okapi',
1400
+ 'pagoda', 'panda', 'pigeon', 'pizza', 'platypus', 'pyramid', 'revolver',
1401
+ 'rhino', 'rooster', 'saxophone', 'schooner', 'scissors', 'scorpion',
1402
+ 'sea_horse', 'snoopy', 'soccer_ball', 'stapler', 'starfish', 'stegosaurus',
1403
+ 'stop_sign', 'strawberry', 'sunflower', 'tick', 'trilobite', 'umbrella',
1404
+ 'watch', 'water_lilly', 'wheelchair', 'wild_cat', 'windsor_chair',
1405
+ 'wrench', 'yin_yang')
1406
+
1407
+ FOOD101_CATEGORIES = (
1408
+ 'apple_pie', 'baby_back_ribs', 'baklava', 'beef_carpaccio', 'beef_tartare',
1409
+ 'beet_salad', 'beignets', 'bibimbap', 'bread_pudding', 'breakfast_burrito',
1410
+ 'bruschetta', 'caesar_salad', 'cannoli', 'caprese_salad', 'carrot_cake',
1411
+ 'ceviche', 'cheesecake', 'cheese_plate', 'chicken_curry',
1412
+ 'chicken_quesadilla', 'chicken_wings', 'chocolate_cake',
1413
+ 'chocolate_mousse', 'churros', 'clam_chowder', 'club_sandwich',
1414
+ 'crab_cakes', 'creme_brulee', 'croque_madame', 'cup_cakes', 'deviled_eggs',
1415
+ 'donuts', 'dumplings', 'edamame', 'eggs_benedict', 'escargots', 'falafel',
1416
+ 'filet_mignon', 'fish_and_chips', 'foie_gras', 'french_fries',
1417
+ 'french_onion_soup', 'french_toast', 'fried_calamari', 'fried_rice',
1418
+ 'frozen_yogurt', 'garlic_bread', 'gnocchi', 'greek_salad',
1419
+ 'grilled_cheese_sandwich', 'grilled_salmon', 'guacamole', 'gyoza',
1420
+ 'hamburger', 'hot_and_sour_soup', 'hot_dog', 'huevos_rancheros', 'hummus',
1421
+ 'ice_cream', 'lasagna', 'lobster_bisque', 'lobster_roll_sandwich',
1422
+ 'macaroni_and_cheese', 'macarons', 'miso_soup', 'mussels', 'nachos',
1423
+ 'omelette', 'onion_rings', 'oysters', 'pad_thai', 'paella', 'pancakes',
1424
+ 'panna_cotta', 'peking_duck', 'pho', 'pizza', 'pork_chop', 'poutine',
1425
+ 'prime_rib', 'pulled_pork_sandwich', 'ramen', 'ravioli', 'red_velvet_cake',
1426
+ 'risotto', 'samosa', 'sashimi', 'scallops', 'seaweed_salad',
1427
+ 'shrimp_and_grits', 'spaghetti_bolognese', 'spaghetti_carbonara',
1428
+ 'spring_rolls', 'steak', 'strawberry_shortcake', 'sushi', 'tacos',
1429
+ 'takoyaki', 'tiramisu', 'tuna_tartare', 'waffles')
1430
+
1431
+ CIFAR100_CATEGORIES_CN = (
1432
+ '苹果', '水族馆鱼', '婴儿', '熊', '河狸', '床', '蜜蜂', '甲虫', '自行车', '瓶子', '碗', '小男孩',
1433
+ '桥', '公共汽车', '蝴蝶', '骆驼', '易拉罐', '城堡', '毛毛虫', '牛', '椅子', '猩猩', '钟', '白云',
1434
+ '蟑螂', '沙发', '螃蟹', '鳄鱼', '杯子', '恐龙', '海豚', '大象', '比目鱼', '森林', '狐狸', '小女孩',
1435
+ '仓鼠', '屋子', '袋鼠', '键盘', '台灯', '割草机', '猎豹', '狮子', '蜥蜴', '龙虾', '男人', '枫树',
1436
+ '摩托车', '山', '老鼠', '蘑菇', '橡树', '橙子橘子', '兰花', '水獭', '棕榈树', '梨', '皮卡车', '松树',
1437
+ '田野', '盘子', '罂粟', '豪猪', '负鼠', '兔子', '浣熊', '鳐鱼', '公路', '火箭', '玫瑰', '大海',
1438
+ '海豹', '鲨鱼', '尖嘴小鼠', '臭鼬', '摩天大楼', '蜗牛', '蛇', '蜘蛛', '松鼠', '电车', '向日葵', '甜椒',
1439
+ '桌子', '坦克', '电话', '电视', '老虎', '拖拉机', '火车', '鳟鱼', '郁金香', '乌龟', '衣柜', '鲸鱼',
1440
+ '柳树', '狼', '女人', '蠕虫')
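
Each tuple above is wired into a dataset class through its METAINFO dict and maps predicted class indices back to readable names. A minimal sketch (illustrative only, reusing the CIFAR10 tuple defined above):

    from mmpretrain.datasets.categories import CIFAR10_CATEGORIES

    # Datasets expose the tuple as METAINFO['classes']; a predicted index
    # is turned back into a readable name by plain tuple indexing.
    METAINFO = {'classes': CIFAR10_CATEGORIES}
    pred_idx = 3
    print(METAINFO['classes'][pred_idx])  # -> 'cat'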
mmpretrain/datasets/cifar.py ADDED
@@ -0,0 +1,210 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import pickle
3
+ from typing import List, Optional
4
+
5
+ import mmengine.dist as dist
6
+ import numpy as np
7
+ from mmengine.fileio import (LocalBackend, exists, get, get_file_backend,
8
+ join_path)
9
+ from mmengine.logging import MMLogger
10
+
11
+ from mmpretrain.registry import DATASETS
12
+ from .base_dataset import BaseDataset
13
+ from .categories import CIFAR10_CATEGORIES, CIFAR100_CATEGORIES
14
+ from .utils import check_md5, download_and_extract_archive
15
+
16
+
17
+ @DATASETS.register_module()
18
+ class CIFAR10(BaseDataset):
19
+ """`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
20
+
21
+ This implementation is modified from
22
+ https://github.com/pytorch/vision/blob/master/torchvision/datasets/cifar.py
23
+
24
+ Args:
25
+ data_root (str): The root directory of the CIFAR Dataset.
26
+ split (str, optional): The dataset split, supports "train" and "test".
27
+ Defaults to "train".
28
+ metainfo (dict, optional): Meta information for dataset, such as
29
+ categories information. Defaults to None.
30
+ download (bool): Whether to download the dataset if not exists.
31
+ Defaults to True.
32
+ **kwargs: Other keyword arguments in :class:`BaseDataset`.
33
+ """ # noqa: E501
34
+
35
+ base_folder = 'cifar-10-batches-py'
36
+ url = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
37
+ filename = 'cifar-10-python.tar.gz'
38
+ tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
39
+ train_list = [
40
+ ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
41
+ ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
42
+ ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
43
+ ['data_batch_4', '634d18415352ddfa80567beed471001a'],
44
+ ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
45
+ ]
46
+
47
+ test_list = [
48
+ ['test_batch', '40351d587109b95175f43aff81a1287e'],
49
+ ]
50
+ meta = {
51
+ 'filename': 'batches.meta',
52
+ 'key': 'label_names',
53
+ 'md5': '5ff9c542aee3614f3951f8cda6e48888',
54
+ }
55
+ METAINFO = {'classes': CIFAR10_CATEGORIES}
56
+
57
+ def __init__(self,
58
+ data_root: str = '',
59
+ split: str = 'train',
60
+ metainfo: Optional[dict] = None,
61
+ download: bool = True,
62
+ data_prefix: str = '',
63
+ test_mode: bool = False,
64
+ **kwargs):
65
+
66
+ splits = ['train', 'test']
67
+ assert split in splits, \
68
+ f"The split must be one of {splits}, but got '{split}'"
69
+ self.split = split
70
+
71
+ # To handle the BC-breaking
72
+ if split == 'train' and test_mode:
73
+ logger = MMLogger.get_current_instance()
74
+ logger.warning('split="train" but test_mode=True. '
75
+ 'The training set will be used.')
76
+
77
+ if not data_root and not data_prefix:
78
+ raise RuntimeError('Please set ``data_root`` to '
79
+ 'specify the dataset path')
80
+
81
+ self.download = download
82
+ super().__init__(
83
+ # The CIFAR dataset doesn't need specify annotation file
84
+ ann_file='',
85
+ metainfo=metainfo,
86
+ data_root=data_root,
87
+ data_prefix=dict(root=data_prefix),
88
+ test_mode=test_mode,
89
+ **kwargs)
90
+
91
+ def load_data_list(self):
92
+ """Load images and ground truth labels."""
93
+ root = self.data_prefix['root']
94
+ backend = get_file_backend(root, enable_singleton=True)
95
+
96
+ if dist.is_main_process() and not self._check_integrity():
97
+ if not isinstance(backend, LocalBackend):
98
+ raise RuntimeError(f'The dataset on {root} is incomplete, '
99
+ f'please handle it manually.')
100
+
101
+ if self.download:
102
+ download_and_extract_archive(
103
+ self.url, root, filename=self.filename, md5=self.tgz_md5)
104
+ else:
105
+ raise RuntimeError(
106
+ f'Cannot find {self.__class__.__name__} dataset in '
107
+ f"{self.data_prefix['root']}, you can specify "
108
+ '`download=True` to download automatically.')
109
+
110
+ dist.barrier()
111
+ assert self._check_integrity(), \
112
+ 'Download failed or shared storage is unavailable. Please ' \
113
+ f'download the dataset manually through {self.url}.'
114
+
115
+ if self.split == 'train':
116
+ downloaded_list = self.train_list
117
+ else:
118
+ downloaded_list = self.test_list
119
+
120
+ imgs = []
121
+ gt_labels = []
122
+
123
+ # load the picked numpy arrays
124
+ for file_name, _ in downloaded_list:
125
+ file_path = join_path(root, self.base_folder, file_name)
126
+ entry = pickle.loads(get(file_path), encoding='latin1')
127
+ imgs.append(entry['data'])
128
+ if 'labels' in entry:
129
+ gt_labels.extend(entry['labels'])
130
+ else:
131
+ gt_labels.extend(entry['fine_labels'])
132
+
133
+ imgs = np.vstack(imgs).reshape(-1, 3, 32, 32)
134
+ imgs = imgs.transpose((0, 2, 3, 1)) # convert to HWC
135
+
136
+ if self.CLASSES is None:
137
+ # The metainfo in the file has the lowest priority, therefore
138
+ # we only need to load it if classes is not specified.
139
+ self._load_meta()
140
+
141
+ data_list = []
142
+ for img, gt_label in zip(imgs, gt_labels):
143
+ info = {'img': img, 'gt_label': int(gt_label)}
144
+ data_list.append(info)
145
+ return data_list
146
+
147
+ def _load_meta(self):
148
+ """Load categories information from metafile."""
149
+ root = self.data_prefix['root']
150
+
151
+ path = join_path(root, self.base_folder, self.meta['filename'])
152
+ md5 = self.meta.get('md5', None)
153
+ if not exists(path) or (md5 is not None and not check_md5(path, md5)):
154
+ raise RuntimeError(
155
+ 'Dataset metadata file not found or corrupted.' +
156
+ ' You can use `download=True` to download it')
157
+ data = pickle.loads(get(path), encoding='latin1')
158
+ self._metainfo.setdefault('classes', data[self.meta['key']])
159
+
160
+ def _check_integrity(self):
161
+ """Check the integrity of data files."""
162
+ root = self.data_prefix['root']
163
+
164
+ for fentry in (self.train_list + self.test_list):
165
+ filename, md5 = fentry[0], fentry[1]
166
+ fpath = join_path(root, self.base_folder, filename)
167
+ if not exists(fpath):
168
+ return False
169
+ if md5 is not None and not check_md5(fpath, md5):
170
+ return False
171
+ return True
172
+
173
+ def extra_repr(self) -> List[str]:
174
+ """The extra repr information of the dataset."""
175
+ body = [f"Prefix of data: \t{self.data_prefix['root']}"]
176
+ return body
177
+
178
+
179
+ @DATASETS.register_module()
180
+ class CIFAR100(CIFAR10):
181
+ """`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
182
+
183
+ Args:
184
+ data_root (str): The root directory of the CIFAR Dataset.
185
+ split (str, optional): The dataset split, supports "train" and "test".
186
+ Defaults to "train".
187
+ metainfo (dict, optional): Meta information for dataset, such as
188
+ categories information. Defaults to None.
189
+ download (bool): Whether to download the dataset if not exists.
190
+ Defaults to True.
191
+ **kwargs: Other keyword arguments in :class:`BaseDataset`.
192
+ """
193
+
194
+ base_folder = 'cifar-100-python'
195
+ url = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
196
+ filename = 'cifar-100-python.tar.gz'
197
+ tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
198
+ train_list = [
199
+ ['train', '16019d7e3df5f24257cddd939b257f8d'],
200
+ ]
201
+
202
+ test_list = [
203
+ ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
204
+ ]
205
+ meta = {
206
+ 'filename': 'meta',
207
+ 'key': 'fine_label_names',
208
+ 'md5': '7973b15100ade9c7d40fb424638fde48',
209
+ }
210
+ METAINFO = {'classes': CIFAR100_CATEGORIES}
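
As a usage sketch of the class above (the data_root path is hypothetical; with download=True the archive is fetched and verified against the md5 sums listed in train_list/test_list):

    from mmpretrain.datasets import CIFAR10

    # 'data/cifar10' is a placeholder; the archive is downloaded there on
    # first use and each batch file is md5-checked by _check_integrity().
    train_set = CIFAR10(data_root='data/cifar10', split='train')
    sample = train_set[0]
    print(sample['img'].shape, sample['gt_label'])  # (32, 32, 3) and an int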
mmpretrain/datasets/coco_caption.py ADDED
@@ -0,0 +1,42 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from pathlib import Path
3
+ from typing import List
4
+
5
+ import mmengine
6
+ from mmengine.dataset import BaseDataset
7
+ from mmengine.fileio import get_file_backend
8
+
9
+ from mmpretrain.registry import DATASETS
10
+
11
+
12
+ @DATASETS.register_module()
13
+ class COCOCaption(BaseDataset):
14
+ """COCO Caption dataset.
15
+
16
+ Args:
17
+ data_root (str): The root directory for ``data_prefix`` and
18
+ ``ann_file``.
19
+ ann_file (str): Annotation file path.
20
+ data_prefix (dict): Prefix for data field. Defaults to
21
+ ``dict(img_path='')``.
22
+ pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
23
+ **kwargs: Other keyword arguments in :class:`BaseDataset`.
24
+ """
25
+
26
+ def load_data_list(self) -> List[dict]:
27
+ """Load data list."""
28
+ img_prefix = self.data_prefix['img_path']
29
+ annotations = mmengine.load(self.ann_file)
30
+ file_backend = get_file_backend(img_prefix)
31
+
32
+ data_list = []
33
+ for ann in annotations:
34
+ data_info = {
35
+ 'image_id': Path(ann['image']).stem.split('_')[-1],
36
+ 'img_path': file_backend.join_path(img_prefix, ann['image']),
37
+ 'gt_caption': ann['caption'],
38
+ }
39
+
40
+ data_list.append(data_info)
41
+
42
+ return data_list
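
The loader above assumes the annotation file is a JSON list of image/caption pairs. A hypothetical entry, together with the image_id derivation used in load_data_list:

    from pathlib import Path

    # Hypothetical annotation entry; field names follow the loader above.
    ann = {'image': 'val2014/COCO_val2014_000000184613.jpg',
           'caption': 'A child holding a flowered umbrella.'}
    # image_id is the last '_'-separated chunk of the file stem:
    print(Path(ann['image']).stem.split('_')[-1])  # -> '000000184613'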
mmpretrain/datasets/coco_retrieval.py ADDED
@@ -0,0 +1,77 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import json
3
+ from collections import OrderedDict
4
+ from typing import List
5
+
6
+ from mmengine import get_file_backend
7
+
8
+ from mmpretrain.registry import DATASETS
9
+ from .base_dataset import BaseDataset
10
+
11
+
12
+ @DATASETS.register_module()
13
+ class COCORetrieval(BaseDataset):
14
+ """COCO Retrieval dataset.
15
+
16
+ Args:
17
+ ann_file (str): Annotation file path.
18
+ test_mode (bool): Whether dataset is used for evaluation. This will
19
+ decide the annotation format in data list annotations.
20
+ Defaults to False.
21
+ data_root (str): The root directory for ``data_prefix`` and
22
+ ``ann_file``. Defaults to ''.
23
+ data_prefix (str | dict): Prefix for training data. Defaults to ''.
24
+ pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
25
+ **kwargs: Other keyword arguments in :class:`BaseDataset`.
26
+ """
27
+
28
+ def load_data_list(self) -> List[dict]:
29
+ """Load data list."""
30
+ # get file backend
31
+ img_prefix = self.data_prefix['img_path']
32
+ file_backend = get_file_backend(img_prefix)
33
+
34
+ with open(self.ann_file, 'r') as f:
+ anno_info = json.load(f)
35
+ # mapping img_id to img filename
36
+ img_dict = OrderedDict()
37
+ for idx, img in enumerate(anno_info['images']):
38
+ if img['id'] not in img_dict:
39
+ img_rel_path = img['coco_url'].rsplit('/', 2)[-2:]
40
+ img_path = file_backend.join_path(img_prefix, *img_rel_path)
41
+
42
+ # create new idx for image
43
+ img_dict[img['id']] = dict(
44
+ ori_id=img['id'],
45
+ image_id=idx, # will be used for evaluation
46
+ img_path=img_path,
47
+ text=[],
48
+ gt_text_id=[],
49
+ gt_image_id=[],
50
+ )
51
+
52
+ train_list = []
53
+ for idx, anno in enumerate(anno_info['annotations']):
54
+ anno['text'] = anno.pop('caption')
55
+ anno['ori_id'] = anno.pop('id')
56
+ anno['text_id'] = idx # will be used for evaluation
57
+ # 1. prepare train data list item
58
+ train_data = anno.copy()
59
+ train_image = img_dict[train_data['image_id']]
60
+ train_data['img_path'] = train_image['img_path']
61
+ train_data['image_ori_id'] = train_image['ori_id']
62
+ train_data['image_id'] = train_image['image_id']
63
+ train_data['is_matched'] = True
64
+ train_list.append(train_data)
65
+ # 2. prepare eval data list item based on img dict
66
+ img_dict[anno['image_id']]['gt_text_id'].append(anno['text_id'])
67
+ img_dict[anno['image_id']]['text'].append(anno['text'])
68
+ img_dict[anno['image_id']]['gt_image_id'].append(
69
+ train_image['image_id'])
70
+
71
+ self.img_size = len(img_dict)
72
+ self.text_size = len(anno_info['annotations'])
73
+
74
+ # return needed format data list
75
+ if self.test_mode:
76
+ return list(img_dict.values())
77
+ return train_list
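
For reference, a minimal (made-up) slice of the COCO-style annotation structure that load_data_list consumes. Each caption becomes one matched training item; in test mode the per-image entries with their gt_text_id/gt_image_id lists are returned instead:

    # Field names follow the loader above; the values are hypothetical.
    anno_info = {
        'images': [
            {'id': 42,
             'coco_url': 'http://images.cocodataset.org/val2014/000042.jpg'},
        ],
        'annotations': [
            {'id': 7, 'image_id': 42, 'caption': 'Two dogs on a couch.'},
        ],
    }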
mmpretrain/datasets/coco_vqa.py ADDED
@@ -0,0 +1,114 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import os.path as osp
3
+ import re
4
+ from collections import Counter
5
+ from typing import List
6
+
7
+ import mmengine
8
+ from mmengine.dataset import BaseDataset
9
+
10
+ from mmpretrain.registry import DATASETS
11
+
12
+
13
+ @DATASETS.register_module()
14
+ class COCOVQA(BaseDataset):
15
+ """VQAv2 dataset.
16
+
17
+ Args:
18
+ data_root (str): The root directory for ``data_prefix``, ``ann_file``
19
+ and ``question_file``.
20
+ data_prefix (str): The directory of images.
21
+ question_file (str): Question file path.
22
+ ann_file (str, optional): Annotation file path for training and
23
+ validation. Defaults to an empty string.
24
+ **kwargs: Other keyword arguments in :class:`BaseDataset`.
25
+ """
26
+
27
+ def __init__(self,
28
+ data_root: str,
29
+ data_prefix: str,
30
+ question_file: str,
31
+ ann_file: str = '',
32
+ **kwarg):
33
+ self.question_file = question_file
34
+ super().__init__(
35
+ data_root=data_root,
36
+ data_prefix=dict(img_path=data_prefix),
37
+ ann_file=ann_file,
38
+ **kwarg,
39
+ )
40
+
41
+ def _join_prefix(self):
42
+ if not mmengine.is_abs(self.question_file) and self.question_file:
43
+ self.question_file = osp.join(self.data_root, self.question_file)
44
+
45
+ return super()._join_prefix()
46
+
47
+ def _create_image_index(self):
48
+ img_prefix = self.data_prefix['img_path']
49
+
50
+ files = mmengine.list_dir_or_file(img_prefix, list_dir=False)
51
+ image_index = {}
52
+ for file in files:
53
+ image_id = re.findall(r'\d{12}', file)
54
+ if len(image_id) > 0:
55
+ image_id = int(image_id[-1])
56
+ image_index[image_id] = mmengine.join_path(img_prefix, file)
57
+
58
+ return image_index
59
+
60
+ def load_data_list(self) -> List[dict]:
61
+ """Load data list."""
62
+ questions = mmengine.load(self.question_file)['questions']
63
+ if self.ann_file:
64
+ annotations = mmengine.load(self.ann_file)['annotations']
65
+ assert len(questions) == len(annotations)
66
+ else:
67
+ annotations = [None] * len(questions)
68
+
69
+ # The original VQAv2 annotation and question files include
70
+ # only image ids, not image file paths.
71
+ self.image_index = self._create_image_index()
72
+
73
+ data_list = []
74
+ for question, ann in zip(questions, annotations):
75
+ # question example
76
+ # {
77
+ # 'image_id': 262144,
78
+ # 'question': "Is the ball flying towards the batter?",
79
+ # 'question_id': 262144000
80
+ # }
81
+ #
82
+ # ann example
83
+ # {
84
+ # 'question_type': "what are the",
85
+ # 'answer_type': "other",
86
+ # 'answers': [
87
+ # {'answer': 'watching',
88
+ # 'answer_id': 1,
89
+ # 'answer_confidence': 'yes'},
90
+ # ...
91
+ # ],
92
+ # 'image_id': 262148,
93
+ # 'question_id': 262148000,
94
+ # 'multiple_choice_answer': 'watching',
95
+ # 'answer_type': 'other',
96
+ # }
97
+
98
+ data_info = question
99
+ data_info['img_path'] = self.image_index[question['image_id']]
100
+
101
+ if ann is not None:
102
+ assert ann['question_id'] == question['question_id']
103
+
104
+ # add answer_weight & answer_count, delete duplicate answer
105
+ answers = [item['answer'] for item in ann.pop('answers')]
106
+ count = Counter(answers)
107
+ answer_weight = [i / len(answers) for i in count.values()]
108
+ data_info['gt_answer'] = list(count.keys())
109
+ data_info['gt_answer_weight'] = answer_weight
110
+ data_info.update(ann)
111
+
112
+ data_list.append(data_info)
113
+
114
+ return data_list
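
The de-duplication step above reduces the (typically ten) human answers to unique answers with soft weights. A self-contained sketch of that computation:

    from collections import Counter

    answers = ['watching', 'watching', 'looking', 'watching', 'looking']
    count = Counter(answers)  # preserves first-seen order
    answer_weight = [n / len(answers) for n in count.values()]
    print(list(count.keys()), answer_weight)
    # -> ['watching', 'looking'] [0.6, 0.4]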
mmpretrain/datasets/cub.py ADDED
@@ -0,0 +1,142 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from typing import List
3
+
4
+ from mmengine import get_file_backend, list_from_file
5
+ from mmengine.logging import MMLogger
6
+
7
+ from mmpretrain.registry import DATASETS
8
+ from .base_dataset import BaseDataset
9
+ from .categories import CUB_CATEGORIES
10
+
11
+
12
+ @DATASETS.register_module()
13
+ class CUB(BaseDataset):
14
+ """The CUB-200-2011 Dataset.
15
+
16
+ Support the `CUB-200-2011 <http://www.vision.caltech.edu/visipedia/CUB-200-2011.html>`_ Dataset.
17
+ Comparing with the `CUB-200 <http://www.vision.caltech.edu/visipedia/CUB-200.html>`_ Dataset,
18
+ there are much more pictures in `CUB-200-2011`. After downloading and decompression, the dataset
19
+ directory structure is as follows.
20
+
21
+ CUB dataset directory: ::
22
+
23
+ CUB_200_2011
24
+ ├── images
25
+ │ ├── class_x
26
+ │ │ ├── xx1.jpg
27
+ │ │ ├── xx2.jpg
28
+ │ │ └── ...
29
+ │ ├── class_y
30
+ │ │ ├── yy1.jpg
31
+ │ │ ├── yy2.jpg
32
+ │ │ └── ...
33
+ │ └── ...
34
+ ├── images.txt
35
+ ├── image_class_labels.txt
36
+ ├── train_test_split.txt
37
+ └── ....
38
+
39
+ Args:
40
+ data_root (str): The root directory for CUB-200-2011 dataset.
41
+ split (str, optional): The dataset split, supports "train" and "test".
42
+ Defaults to "train".
43
+
44
+ Examples:
45
+ >>> from mmpretrain.datasets import CUB
46
+ >>> train_dataset = CUB(data_root='data/CUB_200_2011', split='train')
47
+ >>> train_dataset
48
+ Dataset CUB
49
+ Number of samples: 5994
50
+ Number of categories: 200
51
+ Root of dataset: data/CUB_200_2011
52
+ >>> test_dataset = CUB(data_root='data/CUB_200_2011', split='test')
53
+ >>> test_dataset
54
+ Dataset CUB
55
+ Number of samples: 5794
56
+ Number of categories: 200
57
+ Root of dataset: data/CUB_200_2011
58
+ """ # noqa: E501
59
+
60
+ METAINFO = {'classes': CUB_CATEGORIES}
61
+
62
+ def __init__(self,
63
+ data_root: str,
64
+ split: str = 'train',
65
+ test_mode: bool = False,
66
+ **kwargs):
67
+
68
+ splits = ['train', 'test']
69
+ assert split in splits, \
70
+ f"The split must be one of {splits}, but got '{split}'"
71
+ self.split = split
72
+
73
+ # To handle the BC-breaking
74
+ if split == 'train' and test_mode:
75
+ logger = MMLogger.get_current_instance()
76
+ logger.warning('split="train" but test_mode=True. '
77
+ 'The training set will be used.')
78
+
79
+ ann_file = 'images.txt'
80
+ data_prefix = 'images'
81
+ image_class_labels_file = 'image_class_labels.txt'
82
+ train_test_split_file = 'train_test_split.txt'
83
+
84
+ self.backend = get_file_backend(data_root, enable_singleton=True)
85
+ self.image_class_labels_file = self.backend.join_path(
86
+ data_root, image_class_labels_file)
87
+ self.train_test_split_file = self.backend.join_path(
88
+ data_root, train_test_split_file)
89
+ super(CUB, self).__init__(
90
+ ann_file=ann_file,
91
+ data_root=data_root,
92
+ data_prefix=data_prefix,
93
+ test_mode=test_mode,
94
+ **kwargs)
95
+
96
+ def _load_data_from_txt(self, filepath):
97
+ """Load data from a CUB txt file, where every line holds an index
98
+ and a data item."""
99
+ pairs = list_from_file(filepath)
100
+ data_dict = dict()
101
+ for pair in pairs:
102
+ idx, data_item = pair.split()
103
+ # all the index starts from 1 in CUB files,
104
+ # here we need to '- 1' to let them start from 0.
105
+ data_dict[int(idx) - 1] = data_item
106
+ return data_dict
107
+
108
+ def load_data_list(self):
109
+ """Load images and ground truth labels."""
110
+ sample_dict = self._load_data_from_txt(self.ann_file)
111
+
112
+ label_dict = self._load_data_from_txt(self.image_class_labels_file)
113
+
114
+ split_dict = self._load_data_from_txt(self.train_test_split_file)
115
+
116
+ assert sample_dict.keys() == label_dict.keys() == split_dict.keys(),\
117
+ f'sample_ids should be same in files {self.ann_file}, ' \
118
+ f'{self.image_class_labels_file} and {self.train_test_split_file}'
119
+
120
+ data_list = []
121
+ for sample_id in sample_dict.keys():
122
+ if split_dict[sample_id] == '1' and self.split == 'test':
123
+ # skip train samples when split='test'
124
+ continue
125
+ elif split_dict[sample_id] == '0' and self.split == 'train':
126
+ # skip test samples when split='train'
127
+ continue
128
+
129
+ img_path = self.backend.join_path(self.img_prefix,
130
+ sample_dict[sample_id])
131
+ gt_label = int(label_dict[sample_id]) - 1
132
+ info = dict(img_path=img_path, gt_label=gt_label)
133
+ data_list.append(info)
134
+
135
+ return data_list
136
+
137
+ def extra_repr(self) -> List[str]:
138
+ """The extra repr information of the dataset."""
139
+ body = [
140
+ f'Root of dataset: \t{self.data_root}',
141
+ ]
142
+ return body
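
For reference, the three CUB txt files parsed above all share the same '<1-based index> <item>' layout, which _load_data_from_txt shifts to 0-based keys. A sketch with a made-up line:

    # One line of images.txt (hypothetical content):
    line = '1 001.Black_footed_Albatross/Black_Footed_Albatross_0046.jpg'
    idx, data_item = line.split()
    print(int(idx) - 1, data_item)  # -> 0 and the relative image path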
mmpretrain/datasets/custom.py ADDED
@@ -0,0 +1,287 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
3
+
4
+ from mmengine.fileio import (BaseStorageBackend, get_file_backend,
5
+ list_from_file)
6
+ from mmengine.logging import MMLogger
7
+
8
+ from mmpretrain.registry import DATASETS
9
+ from .base_dataset import BaseDataset
10
+
11
+
12
+ def find_folders(
13
+ root: str,
14
+ backend: Optional[BaseStorageBackend] = None
15
+ ) -> Tuple[List[str], Dict[str, int]]:
16
+ """Find classes by folders under a root.
17
+
18
+ Args:
19
+ root (string): root directory of folders
20
+ backend (BaseStorageBackend | None): The file backend of the root.
21
+ If None, auto infer backend from the root path. Defaults to None.
22
+
23
+ Returns:
24
+ Tuple[List[str], Dict[str, int]]:
25
+
26
+ - folders: The name of sub folders under the root.
27
+ - folder_to_idx: The map from folder name to class idx.
28
+ """
29
+ # Pre-build file backend to prevent verbose file backend inference.
30
+ backend = backend or get_file_backend(root, enable_singleton=True)
31
+ folders = list(
32
+ backend.list_dir_or_file(
33
+ root,
34
+ list_dir=True,
35
+ list_file=False,
36
+ recursive=False,
37
+ ))
38
+ folders.sort()
39
+ folder_to_idx = {folders[i]: i for i in range(len(folders))}
40
+ return folders, folder_to_idx
41
+
42
+
43
+ def get_samples(
44
+ root: str,
45
+ folder_to_idx: Dict[str, int],
46
+ is_valid_file: Callable,
47
+ backend: Optional[BaseStorageBackend] = None,
48
+ ):
49
+ """Make dataset by walking all images under a root.
50
+
51
+ Args:
52
+ root (string): root directory of folders
53
+ folder_to_idx (dict): the map from class name to class idx
54
+ is_valid_file (Callable): A function that takes path of a file
55
+ and check if the file is a valid sample file.
56
+ backend (BaseStorageBackend | None): The file backend of the root.
57
+ If None, auto infer backend from the root path. Defaults to None.
58
+
59
+ Returns:
60
+ Tuple[list, set]:
61
+
62
+ - samples: a list of tuple where each element is (image, class_idx)
63
+ - empty_folders: The folders don't have any valid files.
64
+ """
65
+ samples = []
66
+ available_classes = set()
67
+ # Pre-build file backend to prevent verbose file backend inference.
68
+ backend = backend or get_file_backend(root, enable_singleton=True)
69
+
70
+ if folder_to_idx is not None:
71
+ for folder_name in sorted(list(folder_to_idx.keys())):
72
+ _dir = backend.join_path(root, folder_name)
73
+ files = backend.list_dir_or_file(
74
+ _dir,
75
+ list_dir=False,
76
+ list_file=True,
77
+ recursive=True,
78
+ )
79
+ for file in sorted(list(files)):
80
+ if is_valid_file(file):
81
+ path = backend.join_path(folder_name, file)
82
+ item = (path, folder_to_idx[folder_name])
83
+ samples.append(item)
84
+ available_classes.add(folder_name)
85
+ empty_folders = set(folder_to_idx.keys()) - available_classes
86
+ else:
87
+ files = backend.list_dir_or_file(
88
+ root,
89
+ list_dir=False,
90
+ list_file=True,
91
+ recursive=True,
92
+ )
93
+ samples = [file for file in sorted(list(files)) if is_valid_file(file)]
94
+ empty_folders = None
95
+
96
+ return samples, empty_folders
97
+
98
+
99
+ @DATASETS.register_module()
100
+ class CustomDataset(BaseDataset):
101
+ """A generic dataset for multiple tasks.
102
+
103
+ The dataset supports two kinds of style.
104
+
105
+ 1. Use an annotation file to specify all samples, and each line indicates a
106
+ sample:
107
+
108
+ The annotation file (for ``with_label=True``, supervised tasks): ::
109
+
110
+ folder_1/xxx.png 0
111
+ folder_1/xxy.png 1
112
+ 123.png 4
113
+ nsdf3.png 3
114
+ ...
115
+
116
+ The annotation file (for ``with_label=False``, unsupervised tasks): ::
117
+
118
+ folder_1/xxx.png
119
+ folder_1/xxy.png
120
+ 123.png
121
+ nsdf3.png
122
+ ...
123
+
124
+ Sample files: ::
125
+
126
+ data_prefix/
127
+ ├── folder_1
128
+ │ ├── xxx.png
129
+ │ ├── xxy.png
130
+ │ └── ...
131
+ ├── 123.png
132
+ ├── nsdf3.png
133
+ └── ...
134
+
135
+ Please use the argument ``metainfo`` to specify extra information for
136
+ the task, like ``{'classes': ('bird', 'cat', 'deer', 'dog', 'frog')}``.
137
+
138
+ 2. Place all samples in one folder as below:
139
+
140
+ Sample files (for ``with_label=True``, supervised tasks, we use the name
141
+ of sub-folders as the category names): ::
142
+
143
+ data_prefix/
144
+ ├── class_x
145
+ │ ├── xxx.png
146
+ │ ├── xxy.png
147
+ │ └── ...
148
+ │ └── xxz.png
149
+ └── class_y
150
+ ├── 123.png
151
+ ├── nsdf3.png
152
+ ├── ...
153
+ └── asd932_.png
154
+
155
+ Sample files (for ``with_label=False``, unsupervised tasks, we use all
156
+ sample files under the specified folder): ::
157
+
158
+ data_prefix/
159
+ ├── folder_1
160
+ │ ├── xxx.png
161
+ │ ├── xxy.png
162
+ │ └── ...
163
+ ├── 123.png
164
+ ├── nsdf3.png
165
+ └── ...
166
+
167
+ If the ``ann_file`` is specified, the dataset will be generated by the
168
+ first way, otherwise, try the second way.
169
+
170
+ Args:
171
+ data_root (str): The root directory for ``data_prefix`` and
172
+ ``ann_file``. Defaults to ''.
173
+ data_prefix (str | dict): Prefix for the data. Defaults to ''.
174
+ ann_file (str): Annotation file path. Defaults to ''.
175
+ with_label (bool): Whether the annotation file includes ground truth
176
+ labels, or use sub-folders to specify categories.
177
+ Defaults to True.
178
+ extensions (Sequence[str]): A sequence of allowed extensions. Defaults
179
+ to ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif').
180
+ metainfo (dict, optional): Meta information for dataset, such as class
181
+ information. Defaults to None.
182
+ lazy_init (bool): Whether to load annotation during instantiation.
183
+ In some cases, such as visualization, only the meta information of
184
+ the dataset is needed, which is not necessary to load annotation
185
+ file. ``BaseDataset`` can skip loading annotations to save time by
186
+ setting ``lazy_init=True``. Defaults to False.
187
+ **kwargs: Other keyword arguments in :class:`BaseDataset`.
188
+ """
189
+
190
+ def __init__(self,
191
+ data_root: str = '',
192
+ data_prefix: Union[str, dict] = '',
193
+ ann_file: str = '',
194
+ with_label=True,
195
+ extensions: Sequence[str] = ('.jpg', '.jpeg', '.png', '.ppm',
196
+ '.bmp', '.pgm', '.tif'),
197
+ metainfo: Optional[dict] = None,
198
+ lazy_init: bool = False,
199
+ **kwargs):
200
+ assert (ann_file or data_prefix or data_root), \
201
+ 'One of `ann_file`, `data_root` and `data_prefix` must '\
202
+ 'be specified.'
203
+
204
+ self.extensions = tuple(set([i.lower() for i in extensions]))
205
+ self.with_label = with_label
206
+
207
+ super().__init__(
208
+ # The base class requires string ann_file but this class doesn't
209
+ ann_file=ann_file,
210
+ metainfo=metainfo,
211
+ data_root=data_root,
212
+ data_prefix=data_prefix,
213
+ # Force to lazy_init for some modification before loading data.
214
+ lazy_init=True,
215
+ **kwargs)
216
+
217
+ # Full initialize the dataset.
218
+ if not lazy_init:
219
+ self.full_init()
220
+
221
+ def _find_samples(self):
222
+ """find samples from ``data_prefix``."""
223
+ if self.with_label:
224
+ classes, folder_to_idx = find_folders(self.img_prefix)
225
+ samples, empty_classes = get_samples(
226
+ self.img_prefix,
227
+ folder_to_idx,
228
+ is_valid_file=self.is_valid_file,
229
+ )
230
+
231
+ self.folder_to_idx = folder_to_idx
232
+
233
+ if self.CLASSES is not None:
234
+ assert len(self.CLASSES) == len(classes), \
235
+ f"The number of subfolders ({len(classes)}) doesn't " \
236
+ f'match the number of specified classes ' \
237
+ f'({len(self.CLASSES)}). Please check the data folder.'
238
+ else:
239
+ self._metainfo['classes'] = tuple(classes)
240
+ else:
241
+ samples, empty_classes = get_samples(
242
+ self.img_prefix,
243
+ None,
244
+ is_valid_file=self.is_valid_file,
245
+ )
246
+
247
+ if len(samples) == 0:
248
+ raise RuntimeError(
249
+ f'Found 0 files in subfolders of: {self.data_prefix}. '
250
+ f'Supported extensions are: {",".join(self.extensions)}')
251
+
252
+ if empty_classes:
253
+ logger = MMLogger.get_current_instance()
254
+ logger.warning(
255
+ 'Found no valid file in the folder '
256
+ f'{", ".join(empty_classes)}. '
257
+ f"Supported extensions are: {', '.join(self.extensions)}")
258
+
259
+ return samples
260
+
261
+ def load_data_list(self):
262
+ """Load image paths and gt_labels."""
263
+ if not self.ann_file:
264
+ samples = self._find_samples()
265
+ elif self.with_label:
266
+ lines = list_from_file(self.ann_file)
267
+ samples = [x.strip().rsplit(' ', 1) for x in lines]
268
+ else:
269
+ samples = list_from_file(self.ann_file)
270
+
271
+ # Pre-build file backend to prevent verbose file backend inference.
272
+ backend = get_file_backend(self.img_prefix, enable_singleton=True)
273
+ data_list = []
274
+ for sample in samples:
275
+ if self.with_label:
276
+ filename, gt_label = sample
277
+ img_path = backend.join_path(self.img_prefix, filename)
278
+ info = {'img_path': img_path, 'gt_label': int(gt_label)}
279
+ else:
280
+ img_path = backend.join_path(self.img_prefix, sample)
281
+ info = {'img_path': img_path}
282
+ data_list.append(info)
283
+ return data_list
284
+
285
+ def is_valid_file(self, filename: str) -> bool:
286
+ """Check if a file is a valid sample."""
287
+ return filename.lower().endswith(self.extensions)
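Note: a minimal usage sketch for the two annotation styles described in the ``CustomDataset`` docstring above; the paths are hypothetical placeholders: ::

    from mmpretrain.datasets import CustomDataset

    # Way 1: an annotation file whose lines look like "class_x/xxx.png 0"
    # (hypothetical layout under data/mydataset).
    with_ann = CustomDataset(
        data_root='data/mydataset',
        data_prefix='images',
        ann_file='meta/train.txt')

    # Way 2: no annotation file; the sub-folder names under `data_prefix`
    # become the classes.
    from_folders = CustomDataset(
        data_root='data/mydataset',
        data_prefix='images')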
mmpretrain/datasets/dataset_wrappers.py ADDED
@@ -0,0 +1,176 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import copy
3
+
4
+ import numpy as np
5
+ from mmengine.dataset import BaseDataset, force_full_init
6
+
7
+ from mmpretrain.registry import DATASETS
8
+
9
+
10
+ @DATASETS.register_module()
11
+ class KFoldDataset:
12
+ """A wrapper of dataset for K-Fold cross-validation.
13
+
14
+ K-Fold cross-validation divides all the samples into k groups of
15
+ samples, called folds, of almost equal size. We use k-1 folds for
16
+ training and the remaining fold for validation.
17
+
18
+ Args:
19
+ dataset (:obj:`mmengine.dataset.BaseDataset` | dict): The dataset to be
20
+ divided
21
+ fold (int): The fold used to do validation. Defaults to 0.
22
+ num_splits (int): The number of all folds. Defaults to 5.
23
+ test_mode (bool): Use the training dataset or validation dataset.
24
+ Defaults to False.
25
+ seed (int, optional): The seed to shuffle the dataset before splitting.
26
+ If None, the dataset will not be shuffled. Defaults to None.
27
+ """
28
+
29
+ def __init__(self,
30
+ dataset,
31
+ fold=0,
32
+ num_splits=5,
33
+ test_mode=False,
34
+ seed=None):
35
+ if isinstance(dataset, dict):
36
+ self.dataset = DATASETS.build(dataset)
37
+ # Init the dataset wrapper lazily according to the dataset setting.
38
+ lazy_init = dataset.get('lazy_init', False)
39
+ elif isinstance(dataset, BaseDataset):
40
+ self.dataset = dataset
+ lazy_init = False  # an already-built dataset can be initialized now
41
+ else:
42
+ raise TypeError(f'Unsupported dataset type {type(dataset)}.')
43
+
44
+ self._metainfo = getattr(self.dataset, 'metainfo', {})
45
+ self.fold = fold
46
+ self.num_splits = num_splits
47
+ self.test_mode = test_mode
48
+ self.seed = seed
49
+
50
+ self._fully_initialized = False
51
+ if not lazy_init:
52
+ self.full_init()
53
+
54
+ @property
55
+ def metainfo(self) -> dict:
56
+ """Get the meta information of ``self.dataset``.
57
+
58
+ Returns:
59
+ dict: Meta information of the dataset.
60
+ """
61
+ # Prevent `self._metainfo` from being modified by outside.
62
+ return copy.deepcopy(self._metainfo)
63
+
64
+ def full_init(self):
65
+ """fully initialize the dataset."""
66
+ if self._fully_initialized:
67
+ return
68
+
69
+ self.dataset.full_init()
70
+ ori_len = len(self.dataset)
71
+ indices = list(range(ori_len))
72
+ if self.seed is not None:
73
+ rng = np.random.default_rng(self.seed)
74
+ rng.shuffle(indices)
75
+
76
+ test_start = ori_len * self.fold // self.num_splits
77
+ test_end = ori_len * (self.fold + 1) // self.num_splits
78
+ if self.test_mode:
79
+ indices = indices[test_start:test_end]
80
+ else:
81
+ indices = indices[:test_start] + indices[test_end:]
82
+
83
+ self._ori_indices = indices
84
+ self.dataset = self.dataset.get_subset(indices)
85
+
86
+ self._fully_initialized = True
87
+
88
+ @force_full_init
89
+ def _get_ori_dataset_idx(self, idx: int) -> int:
90
+ """Convert global idx to local index.
91
+
92
+ Args:
93
+ idx (int): Global index of ``KFoldDataset``.
94
+
95
+ Returns:
96
+ int: The original index in the whole dataset.
97
+ """
98
+ return self._ori_indices[idx]
99
+
100
+ @force_full_init
101
+ def get_data_info(self, idx: int) -> dict:
102
+ """Get annotation by index.
103
+
104
+ Args:
105
+ idx (int): Global index of ``KFoldDataset``.
106
+
107
+ Returns:
108
+ dict: The idx-th annotation of the datasets.
109
+ """
110
+ return self.dataset.get_data_info(idx)
111
+
112
+ @force_full_init
113
+ def __len__(self):
114
+ return len(self.dataset)
115
+
116
+ @force_full_init
117
+ def __getitem__(self, idx):
118
+ return self.dataset[idx]
119
+
120
+ @force_full_init
121
+ def get_cat_ids(self, idx):
122
+ return self.dataset.get_cat_ids(idx)
123
+
124
+ @force_full_init
125
+ def get_gt_labels(self):
126
+ return self.dataset.get_gt_labels()
127
+
128
+ @property
129
+ def CLASSES(self):
130
+ """Return all categories names."""
131
+ return self._metainfo.get('classes', None)
132
+
133
+ @property
134
+ def class_to_idx(self):
135
+ """Map mapping class name to class index.
136
+
137
+ Returns:
138
+ dict: mapping from class name to class index.
139
+ """
140
+
141
+ return {cat: i for i, cat in enumerate(self.CLASSES)}
142
+
143
+ def __repr__(self):
144
+ """Print the basic information of the dataset.
145
+
146
+ Returns:
147
+ str: Formatted string.
148
+ """
149
+ head = 'Dataset ' + self.__class__.__name__
150
+ body = []
151
+ type_ = 'test' if self.test_mode else 'training'
152
+ body.append(f'Type: \t{type_}')
153
+ body.append(f'Seed: \t{self.seed}')
154
+
155
+ def ordinal(n):
156
+ # Copy from https://codegolf.stackexchange.com/a/74047
157
+ suffix = 'tsnrhtdd'[(n // 10 % 10 != 1) * (n % 10 < 4) * n % 10::4]
158
+ return f'{n}{suffix}'
159
+
160
+ body.append(
161
+ f'Fold: \t{ordinal(self.fold+1)} of {self.num_splits}-fold')
162
+ if self._fully_initialized:
163
+ body.append(f'Number of samples: \t{self.__len__()}')
164
+ else:
165
+ body.append("Haven't been initialized")
166
+
167
+ if self.CLASSES is not None:
168
+ body.append(f'Number of categories: \t{len(self.CLASSES)}')
169
+ else:
170
+ body.append('The `CLASSES` meta info is not set.')
171
+
172
+ body.append(
173
+ f'Original dataset type:\t{self.dataset.__class__.__name__}')
174
+
175
+ lines = [head] + [' ' * 4 + line for line in body]
176
+ return '\n'.join(lines)
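Note: a sketch of building one fold of a 5-fold split with the wrapper above; the inner dataset config is a hypothetical placeholder. Using the same ``seed`` for the training and validation wrappers is what makes the two subsets complementary and disjoint: ::

    from mmpretrain.registry import DATASETS

    inner = dict(type='CustomDataset', data_root='data/mydataset')  # hypothetical
    train_set = DATASETS.build(dict(
        type='KFoldDataset', dataset=inner,
        fold=0, num_splits=5, test_mode=False, seed=42))
    val_set = DATASETS.build(dict(
        type='KFoldDataset', dataset=inner,
        fold=0, num_splits=5, test_mode=True, seed=42))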
mmpretrain/datasets/dtd.py ADDED
@@ -0,0 +1,116 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from typing import List
3
+
4
+ import mat4py
5
+ from mmengine import get_file_backend
6
+
7
+ from mmpretrain.registry import DATASETS
8
+ from .base_dataset import BaseDataset
9
+ from .categories import DTD_CATEGORIES
10
+
11
+
12
+ @DATASETS.register_module()
13
+ class DTD(BaseDataset):
14
+ """The Describable Texture Dataset (DTD).
15
+
16
+ Support the `Describable Texture Dataset <https://www.robots.ox.ac.uk/~vgg/data/dtd/>`_.
17
+ After downloading and decompression, the dataset directory structure is as follows.
18
+
19
+ DTD dataset directory: ::
20
+
21
+ dtd
22
+ ├── images
23
+ │ ├── banded
24
+ │ │ ├── banded_0002.jpg
25
+ │ │ ├── banded_0004.jpg
26
+ │ │ └── ...
27
+ │ └── ...
28
+ ├── imdb
29
+ │ └── imdb.mat
30
+ ├── labels
31
+ │ ├── labels_joint_anno.txt
32
+ │ ├── test1.txt
33
+ │ ├── test2.txt
34
+ │ ├── ...
35
+ │ └── ...
36
+ └── ...
37
+
38
+ Args:
39
+ data_root (str): The root directory for Describable Texture dataset.
40
+ split (str, optional): The dataset split, supports "train",
41
+ "val", "trainval", and "test". Default to "trainval".
42
+
43
+ Examples:
44
+ >>> from mmpretrain.datasets import DTD
45
+ >>> train_dataset = DTD(data_root='data/dtd', split='trainval')
46
+ >>> train_dataset
47
+ Dataset DTD
48
+ Number of samples: 3760
49
+ Number of categories: 47
50
+ Root of dataset: data/dtd
51
+ >>> test_dataset = DTD(data_root='data/dtd', split='test')
52
+ >>> test_dataset
53
+ Dataset DTD
54
+ Number of samples: 1880
55
+ Number of categories: 47
56
+ Root of dataset: data/dtd
57
+ """ # noqa: E501
58
+
59
+ METAINFO = {'classes': DTD_CATEGORIES}
60
+
61
+ def __init__(self, data_root: str, split: str = 'trainval', **kwargs):
62
+
63
+ splits = ['train', 'val', 'trainval', 'test']
64
+ assert split in splits, \
65
+ f"The split must be one of {splits}, but get '{split}'"
66
+ self.split = split
67
+
68
+ data_prefix = 'images'
69
+ test_mode = split == 'test'
70
+
71
+ self.backend = get_file_backend(data_root, enable_singleton=True)
72
+ ann_file = self.backend.join_path('imdb', 'imdb.mat')
73
+
74
+ super(DTD, self).__init__(
75
+ ann_file=ann_file,
76
+ data_root=data_root,
77
+ data_prefix=data_prefix,
78
+ test_mode=test_mode,
79
+ **kwargs)
80
+
81
+ def load_data_list(self):
82
+ """Load images and ground truth labels."""
83
+
84
+ data = mat4py.loadmat(self.ann_file)['images']
85
+ names = data['name']
86
+ labels = data['class']
87
+ parts = data['set']
88
+ num = len(names)
89
+ assert num == len(labels) == len(parts), 'Invalid annotation file'
90
+
91
+ if self.split == 'train':
92
+ target_set = {1}
93
+ elif self.split == 'val':
94
+ target_set = {2}
95
+ elif self.split == 'test':
96
+ target_set = {3}
97
+ else:
98
+ target_set = {1, 2}
99
+
100
+ data_list = []
101
+ for i in range(num):
102
+ if parts[i] in target_set:
103
+ img_name = names[i]
104
+ img_path = self.backend.join_path(self.img_prefix, img_name)
105
+ gt_label = labels[i] - 1
106
+ info = dict(img_path=img_path, gt_label=gt_label)
107
+ data_list.append(info)
108
+
109
+ return data_list
110
+
111
+ def extra_repr(self) -> List[str]:
112
+ """The extra repr information of the dataset."""
113
+ body = [
114
+ f'Root of dataset: \t{self.data_root}',
115
+ ]
116
+ return body
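Note: ``mat4py.loadmat`` returns plain Python containers, so the ``images`` record read in ``load_data_list`` above is a dict of parallel lists; the ``set`` field encodes the official split (1 = train, 2 = val, 3 = test) and ``class`` is 1-based, hence the ``- 1``. A sketch of the assumed shape (values are illustrative only): ::

    import mat4py

    data = mat4py.loadmat('data/dtd/imdb/imdb.mat')['images']
    # data['name']  -> ['banded/banded_0002.jpg', ...]
    # data['class'] -> [1, ...]   # 1-based class ids
    # data['set']   -> [1, ...]   # 1: train, 2: val, 3: test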
mmpretrain/datasets/fgvcaircraft.py ADDED
@@ -0,0 +1,98 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from typing import List
3
+
4
+ from mmengine import get_file_backend, list_from_file
5
+
6
+ from mmpretrain.registry import DATASETS
7
+ from .base_dataset import BaseDataset
8
+ from .categories import FGVCAIRCRAFT_CATEGORIES
9
+
10
+
11
+ @DATASETS.register_module()
12
+ class FGVCAircraft(BaseDataset):
13
+ """The FGVC_Aircraft Dataset.
14
+
15
+ Support the `FGVC_Aircraft Dataset <https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/>`_.
16
+ After downloading and decompression, the dataset directory structure is as follows.
17
+
18
+ FGVC_Aircraft dataset directory: ::
19
+
20
+ fgvc-aircraft-2013b
21
+ └── data
22
+ ├── images
23
+ │ ├── 1.jpg
24
+ │ ├── 2.jpg
25
+ │ └── ...
26
+ ├── images_variant_train.txt
27
+ ├── images_variant_test.txt
28
+ ├── images_variant_trainval.txt
29
+ ├── images_variant_val.txt
30
+ ├── variants.txt
31
+ └── ....
32
+
33
+ Args:
34
+ data_root (str): The root directory for FGVC_Aircraft dataset.
35
+ split (str, optional): The dataset split, supports "train",
36
+ "val", "trainval", and "test". Default to "trainval".
37
+
38
+ Examples:
39
+ >>> from mmpretrain.datasets import FGVCAircraft
40
+ >>> train_dataset = FGVCAircraft(data_root='data/fgvc-aircraft-2013b', split='trainval')
41
+ >>> train_dataset
42
+ Dataset FGVCAircraft
43
+ Number of samples: 6667
44
+ Number of categories: 100
45
+ Root of dataset: data/fgvc-aircraft-2013b
46
+ >>> test_dataset = FGVCAircraft(data_root='data/fgvc-aircraft-2013b', split='test')
47
+ >>> test_dataset
48
+ Dataset FGVCAircraft
49
+ Number of samples: 3333
50
+ Number of categories: 100
51
+ Root of dataset: data/fgvc-aircraft-2013b
52
+ """ # noqa: E501
53
+
54
+ METAINFO = {'classes': FGVCAIRCRAFT_CATEGORIES}
55
+
56
+ def __init__(self, data_root: str, split: str = 'trainval', **kwargs):
57
+
58
+ splits = ['train', 'val', 'trainval', 'test']
59
+ assert split in splits, \
60
+ f"The split must be one of {splits}, but get '{split}'"
61
+ self.split = split
62
+
63
+ self.backend = get_file_backend(data_root, enable_singleton=True)
64
+ ann_file = self.backend.join_path('data',
65
+ f'images_variant_{split}.txt')
66
+ data_prefix = self.backend.join_path('data', 'images')
67
+ test_mode = split == 'test'
68
+
69
+ super(FGVCAircraft, self).__init__(
70
+ ann_file=ann_file,
71
+ data_root=data_root,
72
+ test_mode=test_mode,
73
+ data_prefix=data_prefix,
74
+ **kwargs)
75
+
76
+ def load_data_list(self):
77
+ """Load images and ground truth labels."""
78
+
79
+ pairs = list_from_file(self.ann_file)
80
+ data_list = []
81
+ for pair in pairs:
82
+ pair = pair.split()
83
+ img_name = pair[0]
84
+ class_name = ' '.join(pair[1:])
85
+ img_name = f'{img_name}.jpg'
86
+ img_path = self.backend.join_path(self.img_prefix, img_name)
87
+ gt_label = self.METAINFO['classes'].index(class_name)
88
+ info = dict(img_path=img_path, gt_label=gt_label)
89
+ data_list.append(info)
90
+
91
+ return data_list
92
+
93
+ def extra_repr(self) -> List[str]:
94
+ """The extra repr information of the dataset."""
95
+ body = [
96
+ f'Root of dataset: \t{self.data_root}',
97
+ ]
98
+ return body
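Note: each line of ``images_variant_*.txt`` starts with the image id, and the rest of the line is the variant name, which may itself contain spaces; that is why ``load_data_list`` above splits on whitespace and re-joins the tail. A small illustration with a hypothetical annotation line: ::

    line = '1025794 Falcon 900'        # hypothetical annotation line
    pair = line.split()
    img_name = f'{pair[0]}.jpg'        # '1025794.jpg'
    class_name = ' '.join(pair[1:])    # 'Falcon 900'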
mmpretrain/datasets/flamingo.py ADDED
@@ -0,0 +1,295 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import random
3
+ from abc import abstractmethod
4
+ from collections import Counter
5
+ from typing import List
6
+
7
+ import mmengine
8
+ import numpy as np
9
+ from mmengine.dataset import BaseDataset
10
+ from pycocotools.coco import COCO
11
+
12
+ from mmpretrain.registry import DATASETS
13
+ from .coco_vqa import COCOVQA
14
+
15
+
16
+ class FlamingoFewShotMixin:
17
+ """Flamingo fewshot eval dataset minin.
18
+
19
+ Args:
20
+ num_shots (int): Number of shots to perform evaluation.
21
+ Defaults to 0.
22
+ Note: 0 does not mean a strict zero-shot in the Flamingo setting.
23
+ It will use 2 text-only prompts without in-context images.
24
+ num_support_examples (int): Number of support examples to get the
25
+ few shots from. Defaults to 2048.
26
+ num_query_examples (int): Number of query examples to perform the
27
+ final evaluation. Defaults to 5000.
28
+ incontext_prompt_temp (str): In context prompt template for few shot
29
+ examples. Defaults to ''.
30
+ final_prompt_temp (str): Final query prompt template. Defaults to ''.
31
+ **kwargs: Other keyword arguments in :class:`BaseDataset`.
32
+ """
33
+
34
+ def __init__(self,
35
+ num_shots: int = 0,
36
+ num_support_examples: int = 2048,
37
+ num_query_examples: int = 5000,
38
+ incontext_prompt_temp: str = '',
39
+ final_prompt_temp: str = '',
40
+ **kwarg):
41
+ self.num_shots = num_shots
42
+ self.num_support_examples = num_support_examples
43
+ self.num_query_examples = num_query_examples
44
+ self.incontext_prompt_temp = incontext_prompt_temp
45
+ self.final_prompt_temp = final_prompt_temp
46
+ super().__init__(**kwarg)
47
+
48
+ def get_subset_idx(self, total_num):
49
+ random_idx = np.random.choice(
50
+ total_num,
51
+ self.num_support_examples + self.num_query_examples,
52
+ replace=False)
53
+
54
+ support_idx = random_idx[:self.num_support_examples]
55
+ query_idx = random_idx[self.num_support_examples:]
56
+ return support_idx, query_idx
57
+
58
+ @abstractmethod
59
+ def parse_basic_anno(self, anno: dict) -> dict:
60
+ """Parse basic annotation for support and query set."""
61
+ pass
62
+
63
+ @abstractmethod
64
+ def parse_fewshot_anno(self, anno: dict, support_list: List) -> dict:
65
+ """Parse fewshot related annotation for query set with support list."""
66
+ pass
67
+
68
+
69
+ @DATASETS.register_module()
70
+ class FlamingoEvalCOCOVQA(FlamingoFewShotMixin, COCOVQA):
71
+ """Flamingo few shot VQAv2 dataset.
72
+
73
+ Args:
74
+ data_root (str): The root directory for ``data_prefix`` and
75
+ ``ann_file``.
76
+ ann_file (str): Annotation file path.
77
+ question_file (str): Question file path.
78
+ num_shots (int): Number of shots to perform evaluation.
79
+ Defaults to 0.
80
+ Note: 0 does not mean a strict zero-shot in the Flamingo setting.
81
+ It will use 2 text-only prompts without in-context images.
82
+ num_support_examples (int): Number of support examples to get the
83
+ few shots from. Defaults to 2048.
84
+ num_query_examples (int): Number of query examples to perform the
85
+ final evaluation. Defaults to 5000.
86
+ **kwargs: Other keyword arguments in :class:`BaseDataset`.
87
+ """
88
+
89
+ def __init__(self,
90
+ data_root: str,
91
+ question_file: str,
92
+ ann_file: str = '',
93
+ num_shots: int = 0,
94
+ num_support_examples: int = 2048,
95
+ num_query_examples: int = 5000,
96
+ **kwarg):
97
+ super().__init__(
98
+ data_root=data_root,
99
+ question_file=question_file,
100
+ ann_file=ann_file,
101
+ num_shots=num_shots,
102
+ num_support_examples=num_support_examples,
103
+ num_query_examples=num_query_examples,
104
+ **kwarg)
105
+
106
+ def parse_basic_anno(self, ann: dict) -> dict:
107
+ """Parse basic annotation for support and query set.
108
+
109
+ Args:
110
+ ann (dict): Annotation for a single example.
111
+
112
+ Return:
113
+ dict: Parsed annotation for single example.
114
+ """
115
+ if ann is None:
116
+ return {}
117
+
118
+ answers = [a['answer'] for a in ann['answers']]
119
+ count = Counter(answers)
120
+ answer_weight = [i / len(answers) for i in count.values()]
121
+ answer_info = {
122
+ 'gt_answer': list(count.keys()),
123
+ 'gt_answer_weight': answer_weight
124
+ }
125
+ return answer_info
126
+
127
+ def parse_fewshot_anno(self, query: dict, support_list: List) -> dict:
128
+ """Parse fewshot related annotation for query set with support list.
129
+
130
+ Args:
131
+ query (dict): Annotation for a single example.
132
+ support_list (List): List of support subset to subsample few shots.
133
+
134
+ Return:
135
+ dict: Parsed annotation for single example.
136
+ """
137
+ # prepare n shots examples
138
+ shots = random.sample(support_list, self.num_shots)
139
+
140
+ # append image path for n shots
141
+ img_path = [shot['img_path'] for shot in shots]
142
+ img_path.append(query['img_path'])
143
+ query['img_path'] = img_path
144
+
145
+ query['shots'] = [
146
+ dict(
147
+ question=item['question'],
148
+ answer=item['gt_answer'][0],
149
+ ) for item in shots
150
+ ]
151
+ return query
152
+
153
+ def load_data_list(self) -> List[dict]:
154
+ """Load data list."""
155
+ questions = mmengine.load(self.question_file)['questions']
156
+ if self.ann_file:
157
+ annotations = mmengine.load(self.ann_file)['annotations']
158
+ assert len(questions) == len(annotations)
159
+ else:
160
+ annotations = [None] * len(questions)
161
+ if self.num_shots > 0:
162
+ raise ValueError('Unable to construct few-shot examples '
163
+ 'since no annotation file is provided.')
164
+
165
+ # The original VQAv2 annotation file and question file include
166
+ # only image ids but no image file paths.
167
+ self.image_index = self._create_image_index()
168
+
169
+ num_data = len(questions)
170
+ support_idx, query_idx = self.get_subset_idx(num_data)
171
+
172
+ # prepare support subset
173
+ if self.num_shots > 0:
174
+ support_list = []
175
+ for idx in support_idx:
176
+ question = questions[idx]
177
+ ann = annotations[idx]
178
+ support = {**question, **self.parse_basic_anno(ann)}
179
+ support['img_path'] = self.image_index[question['image_id']]
180
+ support_list.append(support)
181
+
182
+ # prepare query subset
183
+ data_list = []
184
+ for idx in query_idx:
185
+ question = questions[idx]
186
+ ann = annotations[idx]
187
+ data_info = {**question, **self.parse_basic_anno(ann)}
188
+ data_info['img_path'] = self.image_index[question['image_id']]
189
+ if self.num_shots > 0:
190
+ data_info = self.parse_fewshot_anno(data_info, support_list)
191
+ data_list.append(data_info)
192
+
193
+ return data_list
194
+
195
+
196
+ @DATASETS.register_module()
197
+ class FlamingoEvalCOCOCaption(FlamingoFewShotMixin, BaseDataset):
198
+ """Flamingo few shot COCO Caption dataset.
199
+
200
+ Args:
201
+ data_root (str): The root directory for ``data_prefix`` and
202
+ ``ann_file``.
203
+ ann_file (str): Annotation file path.
204
+ data_prefix (dict): Prefix for data field. Defaults to
205
+ ``dict(img_path='')``.
206
+ num_shots (int): Number of shots to perform evaluation.
207
+ Defaults to 0.
208
+ num_support_examples (int): Number of support examples to get the
209
+ few shots from. Defaults to 2048.
210
+ num_query_examples (int): Number of query examples to perform the
211
+ final evaluation. Defaults to 5000.
212
+ **kwargs: Other keyword arguments in :class:`BaseDataset`.
213
+ """
214
+
215
+ def __init__(self,
216
+ data_root: str,
217
+ ann_file: str,
218
+ num_shots: int = 0,
219
+ num_support_examples: int = 2048,
220
+ num_query_examples: int = 5000,
221
+ **kwarg):
222
+ super().__init__(
223
+ data_root=data_root,
224
+ ann_file=ann_file,
225
+ num_shots=num_shots,
226
+ num_support_examples=num_support_examples,
227
+ num_query_examples=num_query_examples,
228
+ **kwarg)
229
+
230
+ def parse_basic_anno(self, ann: dict, coco: COCO) -> dict:
231
+ """Parse basic annotation for support and query set.
232
+
233
+ Args:
234
+ anno (dict): Annotation for single example.
235
+ coco (COCO): The coco dataset.
236
+
237
+ Return:
238
+ dict: Parsed annotation for single example.
239
+ """
240
+ img_prefix = self.data_prefix['img_path']
241
+ img = coco.imgs[ann['image_id']]
242
+ data_info = dict(
243
+ img_path=mmengine.join_path(img_prefix, img['file_name']),
244
+ gt_caption=ann['caption'],
245
+ image_id=ann['image_id'],
246
+ )
247
+ return data_info
248
+
249
+ def parse_fewshot_anno(self, query: dict, support_list: List) -> dict:
250
+ """Parse fewshot related annotation for query set with support list.
251
+
252
+ Args:
253
+ query (dict): Annotation for single example.
254
+ support_list (List): List of support subset to subsample few shots.
256
+
257
+ Return:
258
+ dict: Parsed annotation for single example.
259
+ """
260
+ # prepare n shots examples
261
+ shots = random.sample(support_list, self.num_shots)
262
+
263
+ # append image path for n shots
264
+ img_path = [shot['img_path'] for shot in shots]
265
+ img_path.append(query['img_path'])
266
+ query['img_path'] = img_path
267
+
268
+ query['shots'] = [dict(caption=item['gt_caption']) for item in shots]
269
+ return query
270
+
271
+ def load_data_list(self) -> List[dict]:
272
+ """Load data list."""
273
+ with mmengine.get_local_path(self.ann_file) as ann_file:
274
+ coco = COCO(ann_file)
275
+
276
+ num_data = len(coco.anns)
277
+ support_idx, query_idx = self.get_subset_idx(num_data)
278
+ ann_ids = list(coco.anns)
279
+
280
+ # prepare support subset
281
+ if self.num_shots > 0:
282
+ support_list = []
283
+ for idx in support_idx:
284
+ support = self.parse_basic_anno(coco.anns[ann_ids[idx]], coco)
285
+ support_list.append(support)
286
+
287
+ # prepare query subset
288
+ query_list = []
289
+ for idx in query_idx:
290
+ data_info = self.parse_basic_anno(coco.anns[ann_ids[idx]], coco)
291
+ if self.num_shots > 0:
292
+ data_info = self.parse_fewshot_anno(data_info, support_list)
293
+ query_list.append(data_info)
294
+
295
+ return query_list
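Note: ``get_subset_idx`` above draws a single permutation without replacement and slices it, so the support and query subsets can never overlap. A toy check under assumed sizes: ::

    import numpy as np

    num_support, num_query, total = 4, 3, 100   # toy sizes
    idx = np.random.choice(total, num_support + num_query, replace=False)
    support_idx, query_idx = idx[:num_support], idx[num_support:]
    assert not set(support_idx) & set(query_idx)   # disjoint by construction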
mmpretrain/datasets/flowers102.py ADDED
@@ -0,0 +1,104 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from typing import List
3
+
4
+ import mat4py
5
+ from mmengine import get_file_backend
6
+
7
+ from mmpretrain.registry import DATASETS
8
+ from .base_dataset import BaseDataset
9
+
10
+
11
+ @DATASETS.register_module()
12
+ class Flowers102(BaseDataset):
13
+ """The Oxford 102 Flower Dataset.
14
+
15
+ Support the `Oxford 102 Flowers Dataset <https://www.robots.ox.ac.uk/~vgg/data/flowers/102/>`_.
16
+ After downloading and decompression, the dataset directory structure is as follows.
17
+
18
+ Flowers102 dataset directory: ::
19
+
20
+ Flowers102
21
+ ├── jpg
22
+ │ ├── image_00001.jpg
23
+ │ ├── image_00002.jpg
24
+ │ └── ...
25
+ ├── imagelabels.mat
26
+ ├── setid.mat
27
+ └── ...
28
+
29
+ Args:
30
+ data_root (str): The root directory for Oxford 102 Flowers dataset.
31
+ split (str, optional): The dataset split, supports "train",
32
+ "val", "trainval", and "test". Default to "trainval".
33
+
34
+ Examples:
35
+ >>> from mmpretrain.datasets import Flowers102
36
+ >>> train_dataset = Flowers102(data_root='data/Flowers102', split='trainval')
37
+ >>> train_dataset
38
+ Dataset Flowers102
39
+ Number of samples: 2040
40
+ Root of dataset: data/Flowers102
41
+ >>> test_dataset = Flowers102(data_root='data/Flowers102', split='test')
42
+ >>> test_dataset
43
+ Dataset Flowers102
44
+ Number of samples: 6149
45
+ Root of dataset: data/Flowers102
46
+ """ # noqa: E501
47
+
48
+ def __init__(self, data_root: str, split: str = 'trainval', **kwargs):
49
+ splits = ['train', 'val', 'trainval', 'test']
50
+ assert split in splits, \
51
+ f"The split must be one of {splits}, but get '{split}'"
52
+ self.split = split
53
+
54
+ ann_file = 'imagelabels.mat'
55
+ data_prefix = 'jpg'
56
+ train_test_split_file = 'setid.mat'
57
+ test_mode = split == 'test'
58
+
59
+ self.backend = get_file_backend(data_root, enable_singleton=True)
60
+
61
+ self.train_test_split_file = self.backend.join_path(
62
+ data_root, train_test_split_file)
63
+
64
+ super(Flowers102, self).__init__(
65
+ ann_file=ann_file,
66
+ data_root=data_root,
67
+ data_prefix=data_prefix,
68
+ test_mode=test_mode,
69
+ **kwargs)
70
+
71
+ def load_data_list(self):
72
+ """Load images and ground truth labels."""
73
+
74
+ label_dict = mat4py.loadmat(self.ann_file)['labels']
75
+ split_list = mat4py.loadmat(self.train_test_split_file)
76
+
77
+ if self.split == 'train':
78
+ split_list = split_list['trnid']
79
+ elif self.split == 'val':
80
+ split_list = split_list['valid']
81
+ elif self.split == 'test':
82
+ split_list = split_list['tstid']
83
+ else:
84
+ train_ids = split_list['trnid']
85
+ val_ids = split_list['valid']
86
+ train_ids.extend(val_ids)
87
+ split_list = train_ids
88
+
89
+ data_list = []
90
+ for sample_id in split_list:
91
+ img_name = 'image_%05d.jpg' % (sample_id)
92
+ img_path = self.backend.join_path(self.img_prefix, img_name)
93
+ gt_label = int(label_dict[sample_id - 1]) - 1
94
+ info = dict(img_path=img_path, gt_label=gt_label)
95
+ data_list.append(info)
96
+
97
+ return data_list
98
+
99
+ def extra_repr(self) -> List[str]:
100
+ """The extra repr information of the dataset."""
101
+ body = [
102
+ f'Root of dataset: \t{self.data_root}',
103
+ ]
104
+ return body
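Note: both ``setid.mat`` and ``imagelabels.mat`` use MATLAB's 1-based indexing, which is why ``load_data_list`` above subtracts one twice: once to index the label list and once to get a 0-based class id. A sketch with an illustrative value: ::

    sample_id = 42                           # 1-based image id from setid.mat
    img_name = 'image_%05d.jpg' % sample_id  # -> 'image_00042.jpg'
    # label_dict comes from imagelabels.mat and is also 1-based:
    # gt_label = int(label_dict[sample_id - 1]) - 1   # 0-based class id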
mmpretrain/datasets/food101.py ADDED
@@ -0,0 +1,102 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from typing import List
3
+
4
+ from mmengine import get_file_backend, list_from_file
5
+
6
+ from mmpretrain.registry import DATASETS
7
+ from .base_dataset import BaseDataset
8
+ from .categories import FOOD101_CATEGORIES
9
+
10
+
11
+ @DATASETS.register_module()
12
+ class Food101(BaseDataset):
13
+ """The Food101 Dataset.
14
+
15
+ Support the `Food101 Dataset <https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/>`_.
16
+ After downloading and decompression, the dataset directory structure is as follows.
17
+
18
+ Food101 dataset directory: ::
19
+
20
+ food-101
21
+ ├── images
22
+ │ ├── class_x
23
+ │ │ ├── xx1.jpg
24
+ │ │ ├── xx2.jpg
25
+ │ │ └── ...
26
+ │ ├── class_y
27
+ │ │ ├── yy1.jpg
28
+ │ │ ├── yy2.jpg
29
+ │ │ └── ...
30
+ │ └── ...
31
+ ├── meta
32
+ │ ├── train.txt
33
+ │ └── test.txt
34
+ └── ....
35
+
36
+ Args:
37
+ data_root (str): The root directory for Food101 dataset.
38
+ split (str, optional): The dataset split, supports "train" and "test".
39
+ Defaults to "train".
40
+
41
+ Examples:
42
+ >>> from mmpretrain.datasets import Food101
43
+ >>> train_dataset = Food101(data_root='data/food-101', split='train')
44
+ >>> train_dataset
45
+ Dataset Food101
46
+ Number of samples: 75750
47
+ Number of categories: 101
48
+ Root of dataset: data/food-101
49
+ >>> test_dataset = Food101(data_root='data/food-101', split='test')
50
+ >>> test_dataset
51
+ Dataset Food101
52
+ Number of samples: 25250
53
+ Number of categories: 101
54
+ Root of dataset: data/food-101
55
+ """ # noqa: E501
56
+
57
+ METAINFO = {'classes': FOOD101_CATEGORIES}
58
+
59
+ def __init__(self, data_root: str, split: str = 'train', **kwargs):
60
+
61
+ splits = ['train', 'test']
62
+ assert split in splits, \
63
+ f"The split must be one of {splits}, but get '{split}'"
64
+ self.split = split
65
+
66
+ self.backend = get_file_backend(data_root, enable_singleton=True)
67
+ if split == 'train':
68
+ ann_file = self.backend.join_path('meta', 'train.txt')
69
+ else:
70
+ ann_file = self.backend.join_path('meta', 'test.txt')
71
+
72
+ test_mode = split == 'test'
73
+ data_prefix = 'images'
74
+
75
+ super(Food101, self).__init__(
76
+ ann_file=ann_file,
77
+ data_root=data_root,
78
+ test_mode=test_mode,
79
+ data_prefix=data_prefix,
80
+ **kwargs)
81
+
82
+ def load_data_list(self):
83
+ """Load images and ground truth labels."""
84
+
85
+ pairs = list_from_file(self.ann_file)
86
+ data_list = []
87
+ for pair in pairs:
88
+ class_name, img_name = pair.split('/')
89
+ img_name = f'{img_name}.jpg'
90
+ img_path = self.backend.join_path(self.img_prefix, class_name,
91
+ img_name)
92
+ gt_label = self.METAINFO['classes'].index(class_name)
93
+ info = dict(img_path=img_path, gt_label=gt_label)
94
+ data_list.append(info)
95
+ return data_list
96
+
97
+ def extra_repr(self) -> List[str]:
98
+ """The extra repr information of the dataset."""
99
+ body = [
100
+ f'Root of dataset: \t{self.data_root}',
101
+ ]
102
+ return body
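Note: lines in ``meta/train.txt`` and ``meta/test.txt`` take the form ``class_name/image_stem`` without the ``.jpg`` suffix, so a single split recovers both the label and the image path, as ``load_data_list`` above does. An illustration with a hypothetical line: ::

    pair = 'apple_pie/1005649'        # hypothetical line from meta/train.txt
    class_name, img_name = pair.split('/')
    img_path = f'images/{class_name}/{img_name}.jpg'
    # gt_label = FOOD101_CATEGORIES.index(class_name)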
mmpretrain/datasets/imagenet.py ADDED
@@ -0,0 +1,102 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from typing import Optional, Union
3
+
4
+ from mmengine.logging import MMLogger
5
+
6
+ from mmpretrain.registry import DATASETS
7
+ from .categories import IMAGENET_CATEGORIES
8
+ from .custom import CustomDataset
9
+
10
+
11
+ @DATASETS.register_module()
12
+ class ImageNet(CustomDataset):
13
+ """`ImageNet <http://www.image-net.org>`_ Dataset.
14
+
15
+ The dataset supports two kinds of annotation format. More details can be
16
+ found in :class:`CustomDataset`.
17
+
18
+ Args:
19
+ data_root (str): The root directory for ``data_prefix`` and
20
+ ``ann_file``. Defaults to ''.
21
+ data_prefix (str | dict): Prefix for training data. Defaults to ''.
22
+ ann_file (str): Annotation file path. Defaults to ''.
23
+ metainfo (dict, optional): Meta information for dataset, such as class
24
+ information. Defaults to None.
25
+ **kwargs: Other keyword arguments in :class:`CustomDataset` and
26
+ :class:`BaseDataset`.
27
+ """ # noqa: E501
28
+
29
+ IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif')
30
+ METAINFO = {'classes': IMAGENET_CATEGORIES}
31
+
32
+ def __init__(self,
33
+ data_root: str = '',
34
+ data_prefix: Union[str, dict] = '',
35
+ ann_file: str = '',
36
+ metainfo: Optional[dict] = None,
37
+ **kwargs):
38
+ kwargs = {'extensions': self.IMG_EXTENSIONS, **kwargs}
39
+ super().__init__(
40
+ data_root=data_root,
41
+ data_prefix=data_prefix,
42
+ ann_file=ann_file,
43
+ metainfo=metainfo,
44
+ **kwargs)
45
+
46
+
47
+ @DATASETS.register_module()
48
+ class ImageNet21k(CustomDataset):
49
+ """ImageNet21k Dataset.
50
+
51
+ Since the ImageNet21k dataset is extremely big, containing more than
52
+ 21k classes and about 14M images, we don't provide the default
53
+ categories list. Please specify it via the ``classes`` argument.
54
+
55
+ Args:
56
+ data_root (str): The root directory for ``data_prefix`` and
57
+ ``ann_file``. Defaults to ''.
58
+ data_prefix (str | dict): Prefix for training data. Defaults to ''.
59
+ ann_file (str): Annotation file path. Defaults to ''.
60
+ metainfo (dict, optional): Meta information for dataset, such as class
61
+ information. Defaults to None.
62
+ multi_label (bool): Whether to use multi-label annotations. Not implemented yet.
63
+ Defaults to False.
64
+ **kwargs: Other keyword arguments in :class:`CustomDataset` and
65
+ :class:`BaseDataset`.
66
+ """
67
+
68
+ IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif')
69
+
70
+ def __init__(self,
71
+ data_root: str = '',
72
+ data_prefix: Union[str, dict] = '',
73
+ ann_file: str = '',
74
+ metainfo: Optional[dict] = None,
75
+ multi_label: bool = False,
76
+ **kwargs):
77
+ if multi_label:
78
+ raise NotImplementedError(
79
+ 'The `multi_label` option is not supported by now.')
80
+ self.multi_label = multi_label
81
+
82
+ logger = MMLogger.get_current_instance()
83
+
84
+ if not ann_file:
85
+ logger.warning(
86
+ 'The ImageNet21k dataset is large, and scanning the directory may '
87
+ 'take a long time. Consider specifying the `ann_file` to '
88
+ 'accelerate the initialization.')
89
+
90
+ kwargs = {'extensions': self.IMG_EXTENSIONS, **kwargs}
91
+ super().__init__(
92
+ data_root=data_root,
93
+ data_prefix=data_prefix,
94
+ ann_file=ann_file,
95
+ metainfo=metainfo,
96
+ **kwargs)
97
+
98
+ if self.CLASSES is None:
99
+ logger.warning(
100
+ 'The CLASSES is not stored in the `ImageNet21k` class. '
101
+ 'Consider specifying the `classes` argument if you need to '
102
+ 'do inference on the ImageNet-21k dataset.')
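Note: because ``ImageNet21k`` ships no default category list, both warnings above can be avoided by passing an annotation file and an explicit class list; a minimal sketch with placeholder paths: ::

    from mmpretrain.datasets import ImageNet21k

    dataset = ImageNet21k(
        data_root='data/imagenet21k',     # hypothetical layout
        ann_file='meta/train.txt',        # skips the slow directory scan
        classes='meta/classes.txt')       # one class name per line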
mmpretrain/datasets/inshop.py ADDED
@@ -0,0 +1,157 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from mmengine import get_file_backend, list_from_file
3
+
4
+ from mmpretrain.registry import DATASETS
5
+ from .base_dataset import BaseDataset
6
+
7
+
8
+ @DATASETS.register_module()
9
+ class InShop(BaseDataset):
10
+ """InShop Dataset for Image Retrieval.
11
+
12
+ Please download the images from the homepage
13
+ 'https://mmlab.ie.cuhk.edu.hk/projects/DeepFashion/InShopRetrieval.html'
14
+ (In-shop Clothes Retrieval Benchmark -> Img -> img.zip,
15
+ Eval/list_eval_partition.txt), and organize them in the following way: ::
16
+
17
+ In-shop Clothes Retrieval Benchmark (data_root)/
18
+ ├── Eval /
19
+ │ └── list_eval_partition.txt (ann_file)
20
+ ├── Img (img_prefix)
21
+ │ └── img/
22
+ ├── README.txt
23
+ └── .....
24
+
25
+ Args:
26
+ data_root (str): The root directory for dataset.
27
+ split (str): Choose from 'train', 'query' and 'gallery'.
28
+ Defaults to 'train'.
29
+ data_prefix (str | dict): Prefix for training data.
30
+ Defaults to 'Img'.
31
+ ann_file (str): Annotation file path, path relative to
32
+ ``data_root``. Defaults to 'Eval/list_eval_partition.txt'.
33
+ **kwargs: Other keyword arguments in :class:`BaseDataset`.
34
+
35
+ Examples:
36
+ >>> from mmpretrain.datasets import InShop
37
+ >>>
38
+ >>> # build train InShop dataset
39
+ >>> inshop_train_cfg = dict(data_root='data/inshop', split='train')
40
+ >>> inshop_train = InShop(**inshop_train_cfg)
41
+ >>> inshop_train
42
+ Dataset InShop
43
+ Number of samples: 25882
44
+ The `CLASSES` meta info is not set.
45
+ Root of dataset: data/inshop
46
+ >>>
47
+ >>> # build query InShop dataset
48
+ >>> inshop_query_cfg = dict(data_root='data/inshop', split='query')
49
+ >>> inshop_query = InShop(**inshop_query_cfg)
50
+ >>> inshop_query
51
+ Dataset InShop
52
+ Number of samples: 14218
53
+ The `CLASSES` meta info is not set.
54
+ Root of dataset: data/inshop
55
+ >>>
56
+ >>> # build gallery InShop dataset
57
+ >>> inshop_gallery_cfg = dict(data_root='data/inshop', split='gallery')
58
+ >>> inshop_gallery = InShop(**inshop_gallery_cfg)
59
+ >>> inshop_gallery
60
+ Dataset InShop
61
+ Number of samples: 12612
62
+ The `CLASSES` meta info is not set.
63
+ Root of dataset: data/inshop
64
+ """
65
+
66
+ def __init__(self,
67
+ data_root: str,
68
+ split: str = 'train',
69
+ data_prefix: str = 'Img',
70
+ ann_file: str = 'Eval/list_eval_partition.txt',
71
+ **kwargs):
72
+
73
+ assert split in ('train', 'query', 'gallery'), "'split' of `InShop`" \
74
+ f" must be one of ['train', 'query', 'gallery'], bu get '{split}'"
75
+ self.backend = get_file_backend(data_root, enable_singleton=True)
76
+ self.split = split
77
+ super().__init__(
78
+ data_root=data_root,
79
+ data_prefix=data_prefix,
80
+ ann_file=ann_file,
81
+ **kwargs)
82
+
83
+ def _process_annotations(self):
84
+ lines = list_from_file(self.ann_file)
85
+
86
+ anno_train = dict(metainfo=dict(), data_list=list())
87
+ anno_gallery = dict(metainfo=dict(), data_list=list())
88
+
89
+ # item_id to label, each item corresponds to one class label
90
+ class_num = 0
91
+ gt_label_train = {}
92
+
93
+ # item_id to label, each label corresponds to several items
94
+ gallery_num = 0
95
+ gt_label_gallery = {}
96
+
97
+ # (lines[0], lines[1]) is the image number and the field name;
98
+ # Each line format as 'image_name, item_id, evaluation_status'
99
+ for line in lines[2:]:
100
+ img_name, item_id, status = line.split()
101
+ img_path = self.backend.join_path(self.img_prefix, img_name)
102
+ if status == 'train':
103
+ if item_id not in gt_label_train:
104
+ gt_label_train[item_id] = class_num
105
+ class_num += 1
106
+ # item_id to class_id (for the training set)
107
+ anno_train['data_list'].append(
108
+ dict(img_path=img_path, gt_label=gt_label_train[item_id]))
109
+ elif status == 'gallery':
110
+ if item_id not in gt_label_gallery:
111
+ gt_label_gallery[item_id] = []
112
+ # Since there are multiple images for each item,
113
+ # record the corresponding item for each image.
114
+ gt_label_gallery[item_id].append(gallery_num)
115
+ anno_gallery['data_list'].append(
116
+ dict(img_path=img_path, sample_idx=gallery_num))
117
+ gallery_num += 1
118
+
119
+ if self.split == 'train':
120
+ anno_train['metainfo']['class_number'] = class_num
121
+ anno_train['metainfo']['sample_number'] = \
122
+ len(anno_train['data_list'])
123
+ return anno_train
124
+ elif self.split == 'gallery':
125
+ anno_gallery['metainfo']['sample_number'] = gallery_num
126
+ return anno_gallery
127
+
128
+ # Generate the labels for the query (val) set
129
+ anno_query = dict(metainfo=dict(), data_list=list())
130
+ query_num = 0
131
+ for line in lines[2:]:
132
+ img_name, item_id, status = line.split()
133
+ img_path = self.backend.join_path(self.img_prefix, img_name)
134
+ if status == 'query':
135
+ anno_query['data_list'].append(
136
+ dict(
137
+ img_path=img_path, gt_label=gt_label_gallery[item_id]))
138
+ query_num += 1
139
+
140
+ anno_query['metainfo']['sample_number'] = query_num
141
+ return anno_query
142
+
143
+ def load_data_list(self):
144
+ """load data list.
145
+
146
+ For the train set, return image and ground truth label. For the query
147
+ set, return image and ids of images in gallery. For the gallery set,
148
+ return image and its id.
149
+ """
150
+ data_info = self._process_annotations()
151
+ data_list = data_info['data_list']
152
+ return data_list
153
+
154
+ def extra_repr(self):
155
+ """The extra repr information of the dataset."""
156
+ body = [f'Root of dataset: \t{self.data_root}']
157
+ return body
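Note: after the two header lines, every line of ``list_eval_partition.txt`` carries an image name, an item id and an evaluation status, which is what drives the three branches in ``_process_annotations`` above. An illustrative (hypothetical) line: ::

    line = 'img/WOMEN/Dresses/id_00000002/02_1_front.jpg id_00000002 train'
    img_name, item_id, status = line.split()
    # status is one of 'train', 'query' or 'gallery'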
mmpretrain/datasets/mnist.py ADDED
@@ -0,0 +1,220 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import codecs
3
+ from typing import List, Optional
4
+ from urllib.parse import urljoin
5
+
6
+ import mmengine.dist as dist
7
+ import numpy as np
8
+ import torch
9
+ from mmengine.fileio import LocalBackend, exists, get_file_backend, join_path
10
+
11
+ from mmpretrain.registry import DATASETS
12
+ from .base_dataset import BaseDataset
13
+ from .categories import FASHIONMNIST_CATEGORITES, MNIST_CATEGORITES
14
+ from .utils import (download_and_extract_archive, open_maybe_compressed_file,
15
+ rm_suffix)
16
+
17
+
18
+ @DATASETS.register_module()
19
+ class MNIST(BaseDataset):
20
+ """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.
21
+
22
+ This implementation is modified from
23
+ https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py
24
+
25
+ Args:
26
+ data_prefix (str): Prefix for data.
27
+ test_mode (bool): ``test_mode=True`` means in test phase.
28
+ It determines whether to use the training set or the test set.
29
+ metainfo (dict, optional): Meta information for dataset, such as
30
+ categories information. Defaults to None.
31
+ data_root (str): The root directory for ``data_prefix``.
32
+ Defaults to ''.
33
+ download (bool): Whether to download the dataset if it does not exist.
34
+ Defaults to True.
35
+ **kwargs: Other keyword arguments in :class:`BaseDataset`.
36
+ """ # noqa: E501
37
+
38
+ url_prefix = 'http://yann.lecun.com/exdb/mnist/'
39
+ # train images and labels
40
+ train_list = [
41
+ ['train-images-idx3-ubyte.gz', 'f68b3c2dcbeaaa9fbdd348bbdeb94873'],
42
+ ['train-labels-idx1-ubyte.gz', 'd53e105ee54ea40749a09fcbcd1e9432'],
43
+ ]
44
+ # test images and labels
45
+ test_list = [
46
+ ['t10k-images-idx3-ubyte.gz', '9fb629c4189551a2d022fa330f9573f3'],
47
+ ['t10k-labels-idx1-ubyte.gz', 'ec29112dd5afa0611ce80d1b7f02629c'],
48
+ ]
49
+ METAINFO = {'classes': MNIST_CATEGORITES}
50
+
51
+ def __init__(self,
52
+ data_prefix: str,
53
+ test_mode: bool,
54
+ metainfo: Optional[dict] = None,
55
+ data_root: str = '',
56
+ download: bool = True,
57
+ **kwargs):
58
+ self.download = download
59
+ super().__init__(
60
+ # The MNIST dataset doesn't need specify annotation file
61
+ ann_file='',
62
+ metainfo=metainfo,
63
+ data_root=data_root,
64
+ data_prefix=dict(root=data_prefix),
65
+ test_mode=test_mode,
66
+ **kwargs)
67
+
68
+ def load_data_list(self):
69
+ """Load images and ground truth labels."""
70
+ root = self.data_prefix['root']
71
+ backend = get_file_backend(root, enable_singleton=True)
72
+
73
+ if dist.is_main_process() and not self._check_exists():
74
+ if not isinstance(backend, LocalBackend):
75
+ raise RuntimeError(f'The dataset on {root} is not complete, '
76
+ f'please handle it manually.')
77
+
78
+ if self.download:
79
+ self._download()
80
+ else:
81
+ raise RuntimeError(
82
+ f'Cannot find {self.__class__.__name__} dataset in '
83
+ f"{self.data_prefix['root']}, you can specify "
84
+ '`download=True` to download automatically.')
85
+
86
+ dist.barrier()
87
+ assert self._check_exists(), \
88
+ 'Download failed or shared storage is unavailable. Please ' \
89
+ f'download the dataset manually through {self.url_prefix}.'
90
+
91
+ if not self.test_mode:
92
+ file_list = self.train_list
93
+ else:
94
+ file_list = self.test_list
95
+
96
+ # load data from SN3 files
97
+ imgs = read_image_file(join_path(root, rm_suffix(file_list[0][0])))
98
+ gt_labels = read_label_file(
99
+ join_path(root, rm_suffix(file_list[1][0])))
100
+
101
+ data_infos = []
102
+ for img, gt_label in zip(imgs, gt_labels):
103
+ gt_label = np.array(gt_label, dtype=np.int64)
104
+ info = {'img': img.numpy(), 'gt_label': gt_label}
105
+ data_infos.append(info)
106
+ return data_infos
107
+
108
+ def _check_exists(self):
109
+ """Check the exists of data files."""
110
+ root = self.data_prefix['root']
111
+
112
+ for filename, _ in (self.train_list + self.test_list):
113
+ # get extracted filename of data
114
+ extract_filename = rm_suffix(filename)
115
+ fpath = join_path(root, extract_filename)
116
+ if not exists(fpath):
117
+ return False
118
+ return True
119
+
120
+ def _download(self):
121
+ """Download and extract data files."""
122
+ root = self.data_prefix['root']
123
+
124
+ for filename, md5 in (self.train_list + self.test_list):
125
+ url = urljoin(self.url_prefix, filename)
126
+ download_and_extract_archive(
127
+ url, download_root=root, filename=filename, md5=md5)
128
+
129
+ def extra_repr(self) -> List[str]:
130
+ """The extra repr information of the dataset."""
131
+ body = [f"Prefix of data: \t{self.data_prefix['root']}"]
132
+ return body
133
+
134
+
135
+ @DATASETS.register_module()
136
+ class FashionMNIST(MNIST):
137
+ """`Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_
138
+ Dataset.
139
+
140
+ Args:
141
+ data_prefix (str): Prefix for data.
142
+ test_mode (bool): ``test_mode=True`` means in test phase.
143
+ It determines whether to use the training set or the test set.
144
+ metainfo (dict, optional): Meta information for dataset, such as
145
+ categories information. Defaults to None.
146
+ data_root (str): The root directory for ``data_prefix``.
147
+ Defaults to ''.
148
+ download (bool): Whether to download the dataset if it does not exist.
149
+ Defaults to True.
150
+ **kwargs: Other keyword arguments in :class:`BaseDataset`.
151
+ """
152
+
153
+ url_prefix = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'
154
+ # train images and labels
155
+ train_list = [
156
+ ['train-images-idx3-ubyte.gz', '8d4fb7e6c68d591d4c3dfef9ec88bf0d'],
157
+ ['train-labels-idx1-ubyte.gz', '25c81989df183df01b3e8a0aad5dffbe'],
158
+ ]
159
+ # test images and labels
160
+ test_list = [
161
+ ['t10k-images-idx3-ubyte.gz', 'bef4ecab320f06d8554ea6380940ec79'],
162
+ ['t10k-labels-idx1-ubyte.gz', 'bb300cfdad3c16e7a12a480ee83cd310'],
163
+ ]
164
+ METAINFO = {'classes': FASHIONMNIST_CATEGORITES}
165
+
166
+
167
+ def get_int(b: bytes) -> int:
168
+ """Convert bytes to int."""
169
+ return int(codecs.encode(b, 'hex'), 16)
170
+
171
+
172
+ def read_sn3_pascalvincent_tensor(path: str,
173
+ strict: bool = True) -> torch.Tensor:
174
+ """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-
175
+ io.lsh').
176
+
177
+ Argument may be a filename, compressed filename, or file object.
178
+ """
179
+ # typemap
180
+ if not hasattr(read_sn3_pascalvincent_tensor, 'typemap'):
181
+ read_sn3_pascalvincent_tensor.typemap = {
182
+ 8: (torch.uint8, np.uint8, np.uint8),
183
+ 9: (torch.int8, np.int8, np.int8),
184
+ 11: (torch.int16, np.dtype('>i2'), 'i2'),
185
+ 12: (torch.int32, np.dtype('>i4'), 'i4'),
186
+ 13: (torch.float32, np.dtype('>f4'), 'f4'),
187
+ 14: (torch.float64, np.dtype('>f8'), 'f8')
188
+ }
189
+ # read
190
+ with open_maybe_compressed_file(path) as f:
191
+ data = f.read()
192
+ # parse
193
+ magic = get_int(data[0:4])
194
+ nd = magic % 256
195
+ ty = magic // 256
196
+ assert nd >= 1 and nd <= 3
197
+ assert ty >= 8 and ty <= 14
198
+ m = read_sn3_pascalvincent_tensor.typemap[ty]
199
+ s = [get_int(data[4 * (i + 1):4 * (i + 2)]) for i in range(nd)]
200
+ parsed = np.frombuffer(data, dtype=m[1], offset=(4 * (nd + 1)))
201
+ assert parsed.shape[0] == np.prod(s) or not strict
202
+ return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
203
+
204
+
205
+ def read_label_file(path: str) -> torch.Tensor:
206
+ """Read labels from SN3 label file."""
207
+ with open(path, 'rb') as f:
208
+ x = read_sn3_pascalvincent_tensor(f, strict=False)
209
+ assert (x.dtype == torch.uint8)
210
+ assert (x.ndimension() == 1)
211
+ return x.long()
212
+
213
+
214
+ def read_image_file(path: str) -> torch.Tensor:
215
+ """Read images from SN3 image file."""
216
+ with open(path, 'rb') as f:
217
+ x = read_sn3_pascalvincent_tensor(f, strict=False)
218
+ assert (x.dtype == torch.uint8)
219
+ assert (x.ndimension() == 3)
220
+ return x
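Note: the IDX/SN3 magic number decoded above packs the element type and the number of dimensions: the low byte is ``nd`` and the next byte is the type code (8 means uint8). A worked example for a standard MNIST image file, whose header starts with 2051: ::

    magic = 2051                # 0x00000803, IDX3 uint8 image file
    nd = magic % 256            # 3 -> (num_images, rows, cols)
    ty = magic // 256           # 8 -> uint8 elements
    assert (nd, ty) == (3, 8)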
mmpretrain/datasets/multi_label.py ADDED
@@ -0,0 +1,85 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from typing import List
3
+
4
+ from mmpretrain.registry import DATASETS
5
+ from .base_dataset import BaseDataset
6
+
7
+
8
+ @DATASETS.register_module()
9
+ class MultiLabelDataset(BaseDataset):
10
+ """Multi-label Dataset.
11
+
12
+ This dataset support annotation file in `OpenMMLab 2.0 style annotation
13
+ format`.
14
+
15
+ The annotation format is shown as follows.
16
+
17
+ .. code-block:: none
18
+
19
+ {
20
+ "metainfo":
21
+ {
22
+ "classes":['A', 'B', 'C'....]
23
+ },
24
+ "data_list":
25
+ [
26
+ {
27
+ "img_path": "test_img1.jpg",
28
+ 'gt_label': [0, 1],
29
+ },
30
+ {
31
+ "img_path": "test_img2.jpg",
32
+ 'gt_label': [2],
33
+ },
34
+ ]
35
+ ....
36
+ }
37
+
38
+
39
+ Args:
40
+ ann_file (str): Annotation file path.
41
+ metainfo (dict, optional): Meta information for dataset, such as class
42
+ information. Defaults to None.
43
+ data_root (str): The root directory for ``data_prefix`` and
44
+ ``ann_file``. Defaults to ''.
45
+ data_prefix (str | dict): Prefix for training data. Defaults to ''.
46
+ filter_cfg (dict, optional): Config for filter data. Defaults to None.
47
+ indices (int or Sequence[int], optional): Support using first few
48
+ data in annotation file to facilitate training/testing on a smaller
49
+ dataset. Defaults to None which means using all ``data_infos``.
50
+ serialize_data (bool, optional): Whether to hold memory using
51
+ serialized objects, when enabled, data loader workers can use
52
+ shared RAM from master process instead of making a copy. Defaults
53
+ to True.
54
+ pipeline (list, optional): Processing pipeline. Defaults to [].
55
+ test_mode (bool, optional): ``test_mode=True`` means in test phase.
56
+ Defaults to False.
57
+ lazy_init (bool, optional): Whether to load annotation during
58
+ instantiation. In some cases, such as visualization, only the meta
59
+ information of the dataset is needed, and it is not necessary to
60
+ load the annotation file. ``BaseDataset`` can skip loading
61
+ annotations to save time by setting ``lazy_init=True``. Defaults to False.
62
+ max_refetch (int, optional): The maximum number of extra cycles to
63
+ get a valid image if ``BaseDataset.prepare_data`` gets a None
64
+ image. Defaults to 1000.
65
+ classes (str | Sequence[str], optional): Specify names of classes.
66
+
67
+ - If is string, it should be a file path, and the every line of
68
+ the file is a name of a class.
69
+ - If is a sequence of string, every item is a name of class.
70
+ - If is None, use categories information in ``metainfo`` argument,
71
+ annotation file or the class attribute ``METAINFO``.
72
+
73
+ Defaults to None.
74
+ """
75
+
76
+ def get_cat_ids(self, idx: int) -> List[int]:
77
+ """Get category ids by index.
78
+
79
+ Args:
80
+ idx (int): Index of data.
81
+
82
+ Returns:
83
+ cat_ids (List[int]): Image categories of specified index.
84
+ """
85
+ return self.get_data_info(idx)['gt_label']
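Note: a minimal sketch of consuming the annotation format documented above; the file name is a placeholder: ::

    from mmpretrain.datasets import MultiLabelDataset

    dataset = MultiLabelDataset(ann_file='anno.json')   # hypothetical path
    cat_ids = dataset.get_cat_ids(0)   # e.g. [0, 1] for the first sample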
mmpretrain/datasets/multi_task.py ADDED
@@ -0,0 +1,337 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import copy
3
+ import os.path as osp
4
+ from os import PathLike
5
+ from typing import Optional, Sequence
6
+
7
+ import mmengine
8
+ from mmcv.transforms import Compose
9
+ from mmengine.fileio import get_file_backend
10
+
11
+ from .builder import DATASETS
12
+
13
+
14
+ def expanduser(path):
15
+ if isinstance(path, (str, PathLike)):
16
+ return osp.expanduser(path)
17
+ else:
18
+ return path
19
+
20
+
21
+ def isabs(uri):
22
+ return osp.isabs(uri) or ('://' in uri)
23
+
24
+
25
+ @DATASETS.register_module()
26
+ class MultiTaskDataset:
27
+ """Custom dataset for multi-task dataset.
28
+
29
+ To use the dataset, please generate and provide an annotation file in the
30
+ below format:
31
+
32
+ .. code-block:: json
33
+
34
+ {
35
+ "metainfo": {
36
+ "tasks":
37
+ [
38
+ "gender",
39
+ "wear"
40
+ ]
41
+ },
42
+ "data_list": [
43
+ {
44
+ "img_path": "a.jpg",
45
+ "gt_label": {
46
+ "gender": 0,
47
+ "wear": [1, 0, 1, 0]
48
+ }
49
+ },
50
+ {
51
+ "img_path": "b.jpg",
52
+ "gt_label": {
53
+ "gender": 1,
54
+ "wear": [1, 0, 1, 0]
55
+ }
56
+ }
57
+ ]
58
+ }
59
+
60
+ Assume we put our dataset in the ``data/mydataset`` folder in the
61
+ repository and organize it as the below format: ::
62
+
63
+ mmpretrain/
64
+ └── data
65
+ └── mydataset
66
+ ├── annotation
67
+ │   ├── train.json
68
+ │   ├── test.json
69
+ │   └── val.json
70
+ ├── train
71
+ │   ├── a.jpg
72
+ │   └── ...
73
+ ├── test
74
+ │   ├── b.jpg
75
+ │   └── ...
76
+ └── val
77
+ ├── c.jpg
78
+ └── ...
79
+
80
+ We can use the below config to build datasets:
81
+
82
+ .. code:: python
83
+
84
+ >>> from mmpretrain.datasets import build_dataset
85
+ >>> train_cfg = dict(
86
+ ... type="MultiTaskDataset",
87
+ ... ann_file="annotation/train.json",
88
+ ... data_root="data/mydataset",
89
+ ... # The `img_path` field in the train annotation file is relative
90
+ ... # to the `train` folder.
91
+ ... data_prefix='train',
92
+ ... )
93
+ >>> train_dataset = build_dataset(train_cfg)
94
+
95
+ Or we can put all files in the same folder: ::
96
+
97
+ mmpretrain/
98
+ └── data
99
+ └── mydataset
100
+ ├── train.json
101
+ ├── test.json
102
+ ├── val.json
103
+ ├── a.jpg
104
+ ├── b.jpg
105
+ ├── c.jpg
106
+ └── ...
107
+
108
+ And we can use the below config to build datasets:
109
+
110
+ .. code:: python
111
+
112
+ >>> from mmpretrain.datasets import build_dataset
113
+ >>> train_cfg = dict(
114
+ ... type="MultiTaskDataset",
115
+ ... ann_file="train.json",
116
+ ... data_root="data/mydataset",
117
+ ... # the `data_prefix` is not required since all paths are
118
+ ... # relative to the `data_root`.
119
+ ... )
120
+ >>> train_dataset = build_dataset(train_cfg)
121
+
122
+
123
+ Args:
124
+ ann_file (str): The annotation file path. It can be either absolute
125
+ path or relative path to the ``data_root``.
126
+ metainfo (dict, optional): The extra meta information. It should be
127
+ a dict with the same format as the ``"metainfo"`` field in the
128
+ annotation file. Defaults to None.
129
+ data_root (str, optional): The root path of the data directory. It's
130
+ the prefix of the ``data_prefix`` and the ``ann_file``. And it can
131
+ be a remote path like "s3://openmmlab/xxx/". Defaults to None.
132
+ data_prefix (str, optional): The base folder relative to the
133
+ ``data_root`` for the ``"img_path"`` field in the annotation file.
134
+ Defaults to None.
135
+ pipeline (Sequence[dict]): A list of dict, where each element
136
+ represents an operation defined in
137
+ :mod:`mmpretrain.datasets.pipelines`. Defaults to an empty tuple.
138
+ test_mode (bool): Whether the dataset is in test mode. Defaults to False.
139
+ """
140
+ METAINFO = dict()
141
+
142
+ def __init__(self,
143
+ ann_file: str,
144
+ metainfo: Optional[dict] = None,
145
+ data_root: Optional[str] = None,
146
+ data_prefix: Optional[str] = None,
147
+ pipeline: Sequence = (),
148
+ test_mode: bool = False):
149
+
150
+ self.data_root = expanduser(data_root)
151
+
152
+ # Infer the file backend
153
+ if self.data_root is not None:
154
+ self.file_backend = get_file_backend(uri=self.data_root)
155
+ else:
156
+ self.file_backend = None
157
+
158
+ self.ann_file = self._join_root(expanduser(ann_file))
159
+ self.data_prefix = self._join_root(data_prefix)
160
+
161
+ self.test_mode = test_mode
162
+ self.pipeline = Compose(pipeline)
163
+ self.data_list = self.load_data_list(self.ann_file, metainfo)
164
+
165
+ def _join_root(self, path):
166
+ """Join ``self.data_root`` with the specified path.
167
+
168
+ If the path is an absolute path, just return the path. And if the
169
+ path is None, return ``self.data_root``.
170
+
171
+ Examples:
172
+ >>> self.data_root = 'a/b/c'
173
+ >>> self._join_root('d/e/')
174
+ 'a/b/c/d/e'
175
+ >>> self._join_root('https://openmmlab.com')
176
+ 'https://openmmlab.com'
177
+ >>> self._join_root(None)
178
+ 'a/b/c'
179
+ """
180
+ if path is None:
181
+ return self.data_root
182
+ if isabs(path):
183
+ return path
184
+
185
+ joined_path = self.file_backend.join_path(self.data_root, path)
186
+ return joined_path
187
+
188
+ @classmethod
189
+ def _get_meta_info(cls, in_metainfo: dict = None) -> dict:
190
+ """Collect meta information from the dictionary of meta.
191
+
192
+ Args:
193
+ in_metainfo (dict): Meta information dict.
194
+
195
+ Returns:
196
+ dict: Parsed meta information.
197
+ """
198
+ # `cls.METAINFO` will be overwritten by in_metainfo
199
+ metainfo = copy.deepcopy(cls.METAINFO)
200
+ if in_metainfo is None:
201
+ return metainfo
202
+
203
+ metainfo.update(in_metainfo)
204
+
205
+ return metainfo
206
+
207
+ def load_data_list(self, ann_file, metainfo_override=None):
208
+ """Load annotations from an annotation file.
209
+
210
+ Args:
211
+ ann_file (str): Absolute annotation file path if ``self.data_root`` is
212
+ ``None``, or a path relative to ``self.data_root`` otherwise.
213
+
214
+ Returns:
215
+ list[dict]: A list of annotation.
216
+ """
217
+ annotations = mmengine.load(ann_file)
218
+ if not isinstance(annotations, dict):
219
+ raise TypeError(f'The annotations loaded from annotation file '
220
+ f'should be a dict, but got {type(annotations)}!')
221
+ if 'data_list' not in annotations:
222
+ raise ValueError('The annotation file must have the `data_list` '
223
+ 'field.')
224
+ metainfo = annotations.get('metainfo', {})
225
+ raw_data_list = annotations['data_list']
226
+
227
+ # Set meta information.
228
+ assert isinstance(metainfo, dict), 'The `metainfo` field in the '\
229
+ f'annotation file should be a dict, but got {type(metainfo)}'
230
+ if metainfo_override is not None:
231
+ assert isinstance(metainfo_override, dict), 'The `metainfo` ' \
232
+ f'argument should be a dict, but got {type(metainfo_override)}'
233
+ metainfo.update(metainfo_override)
234
+ self._metainfo = self._get_meta_info(metainfo)
235
+
236
+ data_list = []
237
+ for i, raw_data in enumerate(raw_data_list):
238
+ try:
239
+ data_list.append(self.parse_data_info(raw_data))
240
+ except AssertionError as e:
241
+ raise RuntimeError(
242
+ f'The format check failed while parsing item {i} of '
243
+ f'the annotation file with error: {e}')
244
+ return data_list
245
+
246
+ def parse_data_info(self, raw_data):
247
+ """Parse raw annotation to target format.
248
+
249
+ This method will return a dict which contains the data information of a
250
+ sample.
251
+
252
+ Args:
253
+ raw_data (dict): Raw data information load from ``ann_file``
254
+
255
+ Returns:
256
+ dict: Parsed annotation.
257
+ """
258
+ assert isinstance(raw_data, dict), \
259
+ f'The item should be a dict, but got {type(raw_data)}'
260
+ assert 'img_path' in raw_data, \
261
+ "The item doesn't have `img_path` field."
262
+ data = dict(
263
+ img_path=self._join_root(raw_data['img_path']),
264
+ gt_label=raw_data['gt_label'],
265
+ )
266
+ return data
267
+
268
+ @property
269
+ def metainfo(self) -> dict:
270
+ """Get meta information of dataset.
271
+
272
+ Returns:
273
+ dict: meta information collected from ``cls.METAINFO``,
274
+ annotation file and metainfo argument during instantiation.
275
+ """
276
+ return copy.deepcopy(self._metainfo)
277
+
278
+ def prepare_data(self, idx):
279
+ """Get data processed by ``self.pipeline``.
280
+
281
+ Args:
282
+ idx (int): The index of ``data_info``.
283
+
284
+ Returns:
285
+ Any: Depends on ``self.pipeline``.
286
+ """
287
+ results = copy.deepcopy(self.data_list[idx])
288
+ return self.pipeline(results)
289
+
290
+ def __len__(self):
291
+ """Get the length of the whole dataset.
292
+
293
+ Returns:
294
+ int: The length of the dataset.
295
+ """
296
+ return len(self.data_list)
297
+
298
+ def __getitem__(self, idx):
299
+ """Get the idx-th image and data information of dataset after
300
+ ``self.pipeline``.
301
+
302
+ Args:
303
+ idx (int): The index of the data.
304
+
305
+ Returns:
306
+ dict: The idx-th image and data information after
307
+ ``self.pipeline``.
308
+ """
309
+ return self.prepare_data(idx)
310
+
311
+ def __repr__(self):
312
+ """Print the basic information of the dataset.
313
+
314
+ Returns:
315
+ str: Formatted string.
316
+ """
317
+ head = 'Dataset ' + self.__class__.__name__
318
+ body = [f'Number of samples: \t{self.__len__()}']
319
+ if self.data_root is not None:
320
+ body.append(f'Root location: \t{self.data_root}')
321
+ body.append(f'Annotation file: \t{self.ann_file}')
322
+ if self.data_prefix is not None:
323
+ body.append(f'Prefix of images: \t{self.data_prefix}')
324
+ # -------------------- extra repr --------------------
325
+ tasks = self.metainfo['tasks']
326
+ body.append(f'For {len(tasks)} tasks')
327
+ for task in tasks:
328
+ body.append(f' {task} ')
329
+ # ----------------------------------------------------
330
+
331
+ if len(self.pipeline.transforms) > 0:
332
+ body.append('With transforms:')
333
+ for t in self.pipeline.transforms:
334
+ body.append(f' {t}')
335
+
336
+ lines = [head] + [' ' * 4 + line for line in body]
337
+ return '\n'.join(lines)
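A note on the annotation format documented above: such a file can be generated with ``mmengine.dump``. A minimal sketch, assuming JSON output; the task names, image paths and labels are hypothetical placeholders:

    # Minimal sketch of generating a MultiTaskDataset annotation file.
    # Task names, image paths and labels are hypothetical placeholders.
    import mmengine

    annotation = {
        'metainfo': {'tasks': ['gender', 'wear']},
        'data_list': [
            {'img_path': 'a.jpg',
             'gt_label': {'gender': 0, 'wear': [1, 0, 1, 0]}},
            {'img_path': 'b.jpg',
             'gt_label': {'gender': 1, 'wear': [1, 0, 1, 0]}},
        ],
    }
    mmengine.dump(annotation, 'data/mydataset/annotation/train.json')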
mmpretrain/datasets/nlvr2.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import json
3
+ from typing import List
4
+
5
+ from mmengine.fileio import get_file_backend, list_from_file
6
+
7
+ from mmpretrain.registry import DATASETS
8
+ from .base_dataset import BaseDataset
9
+
10
+
11
+ @DATASETS.register_module()
12
+ class NLVR2(BaseDataset):
13
+ """COCO Caption dataset."""
14
+
15
+ def load_data_list(self) -> List[dict]:
16
+ """Load data list."""
17
+
18
+ data_list = []
19
+ img_prefix = self.data_prefix['img_path']
20
+ file_backend = get_file_backend(img_prefix)
21
+ examples = list_from_file(self.ann_file)
22
+
23
+ for example in examples:
24
+ example = json.loads(example)
25
+ prefix = example['identifier'].rsplit('-', 1)[0]
26
+ train_data = {}
27
+ train_data['text'] = example['sentence']
28
+ train_data['gt_label'] = {'True': 1, 'False': 0}[example['label']]
29
+ train_data['img_path'] = [
30
+ file_backend.join_path(img_prefix, prefix + f'-img{i}.png')
31
+ for i in range(2)
32
+ ]
33
+
34
+ data_list.append(train_data)
35
+
36
+ return data_list
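For reference, ``load_data_list`` above reads one JSON object per line. A minimal sketch of how a single line is parsed; the identifier, sentence and label values are hypothetical:

    # Parse one hypothetical NLVR2 annotation line as load_data_list does.
    import json

    line = ('{"identifier": "dev-850-0-0", '
            '"sentence": "There are two dogs.", "label": "True"}')
    example = json.loads(line)
    prefix = example['identifier'].rsplit('-', 1)[0]  # 'dev-850-0'
    img_paths = [f'{prefix}-img{i}.png' for i in range(2)]
    gt_label = {'True': 1, 'False': 0}[example['label']]
    print(img_paths, gt_label)
    # ['dev-850-0-img0.png', 'dev-850-0-img1.png'] 1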
mmpretrain/datasets/oxfordiiitpet.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from typing import List
3
+
4
+ from mmengine import get_file_backend, list_from_file
5
+
6
+ from mmpretrain.registry import DATASETS
7
+ from .base_dataset import BaseDataset
8
+ from .categories import OxfordIIITPet_CATEGORIES
9
+
10
+
11
+ @DATASETS.register_module()
12
+ class OxfordIIITPet(BaseDataset):
13
+ """The Oxford-IIIT Pets Dataset.
14
+
15
+ Support the `Oxford-IIIT Pets <https://www.robots.ox.ac.uk/~vgg/data/pets/>`_ dataset.
16
+ After downloading and decompression, the dataset directory structure is as follows.
17
+
18
+ Oxford-IIIT_Pets dataset directory: ::
19
+
20
+ Oxford-IIIT_Pets
21
+ ├── images
22
+ │ ├── Abyssinian_1.jpg
23
+ │ ├── Abyssinian_2.jpg
24
+ │ └── ...
25
+ ├── annotations
26
+ │ ├── trainval.txt
27
+ │ ├── test.txt
28
+ │ ├── list.txt
29
+ │ └── ...
30
+ └── ....
31
+
32
+ Args:
33
+ data_root (str): The root directory for Oxford-IIIT Pets dataset.
34
+ split (str, optional): The dataset split, supports "trainval" and "test".
35
+ Default to "trainval".
36
+
37
+ Examples:
38
+ >>> from mmpretrain.datasets import OxfordIIITPet
39
+ >>> train_dataset = OxfordIIITPet(data_root='data/Oxford-IIIT_Pets', split='trainval')
40
+ >>> train_dataset
41
+ Dataset OxfordIIITPet
42
+ Number of samples: 3680
43
+ Number of categories: 37
44
+ Root of dataset: data/Oxford-IIIT_Pets
45
+ >>> test_dataset = OxfordIIITPet(data_root='data/Oxford-IIIT_Pets', split='test')
46
+ >>> test_dataset
47
+ Dataset OxfordIIITPet
48
+ Number of samples: 3669
49
+ Number of categories: 37
50
+ Root of dataset: data/Oxford-IIIT_Pets
51
+ """ # noqa: E501
52
+
53
+ METAINFO = {'classes': OxfordIIITPet_CATEGORIES}
54
+
55
+ def __init__(self, data_root: str, split: str = 'trainval', **kwargs):
56
+
57
+ splits = ['trainval', 'test']
58
+ assert split in splits, \
59
+ f"The split must be one of {splits}, but get '{split}'"
60
+ self.split = split
61
+
62
+ self.backend = get_file_backend(data_root, enable_singleton=True)
63
+ if split == 'trainval':
64
+ ann_file = self.backend.join_path('annotations', 'trainval.txt')
65
+ else:
66
+ ann_file = self.backend.join_path('annotations', 'test.txt')
67
+
68
+ data_prefix = 'images'
69
+ test_mode = split == 'test'
70
+
71
+ super(OxfordIIITPet, self).__init__(
72
+ ann_file=ann_file,
73
+ data_root=data_root,
74
+ data_prefix=data_prefix,
75
+ test_mode=test_mode,
76
+ **kwargs)
77
+
78
+ def load_data_list(self):
79
+ """Load images and ground truth labels."""
80
+
81
+ pairs = list_from_file(self.ann_file)
82
+ data_list = []
83
+ for pair in pairs:
84
+ img_name, class_id, _, _ = pair.split()
85
+ img_name = f'{img_name}.jpg'
86
+ img_path = self.backend.join_path(self.img_prefix, img_name)
87
+ gt_label = int(class_id) - 1
88
+ info = dict(img_path=img_path, gt_label=gt_label)
89
+ data_list.append(info)
90
+ return data_list
91
+
92
+ def extra_repr(self) -> List[str]:
93
+ """The extra repr information of the dataset."""
94
+ body = [
95
+ f'Root of dataset: \t{self.data_root}',
96
+ ]
97
+ return body
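Each annotation line above holds four space-separated fields, of which only the image name and the 1-based class id are used. A standalone replay of the parsing; the sample line is hypothetical:

    # Replay of the per-line parsing with a hypothetical annotation line.
    pair = 'Abyssinian_100 1 1 1'  # <image> <class-id> <species> <breed-id>
    img_name, class_id, _, _ = pair.split()
    img_path = f'images/{img_name}.jpg'
    gt_label = int(class_id) - 1   # class ids in the file are 1-based
    print(img_path, gt_label)      # images/Abyssinian_100.jpg 0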
mmpretrain/datasets/places205.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from typing import Optional, Union
3
+
4
+ from mmpretrain.registry import DATASETS
5
+ from .categories import PLACES205_CATEGORIES
6
+ from .custom import CustomDataset
7
+
8
+
9
+ @DATASETS.register_module()
10
+ class Places205(CustomDataset):
11
+ """`Places205 <http://places.csail.mit.edu/downloadData.html>`_ Dataset.
12
+
13
+ Args:
14
+ data_root (str): The root directory for ``data_prefix`` and
15
+ ``ann_file``. Defaults to ''.
16
+ data_prefix (str | dict): Prefix for training data. Defaults
17
+ to ''.
18
+ ann_file (str): Annotation file path. Defaults to ''.
19
+ metainfo (dict, optional): Meta information for dataset, such as class
20
+ information. Defaults to None.
21
+ **kwargs: Other keyword arguments in :class:`CustomDataset` and
22
+ :class:`BaseDataset`.
23
+ """
24
+
25
+ IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif')
26
+ METAINFO = {'classes': PLACES205_CATEGORIES}
27
+
28
+ def __init__(self,
29
+ data_root: str = '',
30
+ data_prefix: Union[str, dict] = '',
31
+ ann_file: str = '',
32
+ metainfo: Optional[dict] = None,
33
+ **kwargs):
34
+ kwargs = {'extensions': self.IMG_EXTENSIONS, **kwargs}
35
+ super().__init__(
36
+ data_root=data_root,
37
+ data_prefix=data_prefix,
38
+ ann_file=ann_file,
39
+ metainfo=metainfo,
40
+ **kwargs)
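Since this class only pins the category list and image extensions onto ``CustomDataset``, a config dict suffices to build it. A hedged sketch; the paths and annotation file name are hypothetical, and the dataset must exist on disk:

    # Build Places205 through the DATASETS registry (hypothetical paths).
    from mmpretrain.registry import DATASETS

    dataset = DATASETS.build(
        dict(
            type='Places205',
            data_root='data/places205',  # hypothetical root directory
            data_prefix='images',
            ann_file='meta/train.txt',   # hypothetical annotation file
        ))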
mmpretrain/datasets/refcoco.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import os.path as osp
3
+ from typing import List
4
+
5
+ import mmengine
6
+ import numpy as np
7
+ from mmengine.dataset import BaseDataset
8
+ from pycocotools.coco import COCO
9
+
10
+ from mmpretrain.registry import DATASETS
11
+
12
+
13
+ @DATASETS.register_module()
14
+ class RefCOCO(BaseDataset):
15
+ """RefCOCO dataset.
16
+
17
+ Args:
18
+ ann_file (str): Annotation file path.
19
+ data_root (str): The root directory for ``data_prefix`` and
20
+ ``ann_file``. Defaults to ''.
21
+ data_prefix (str): Prefix for training data.
+ split_file (str): The split file path, a pickle file that stores the
+ referring expressions together with their split tags.
+ split (str): Which split of the referring expressions to load.
+ Defaults to 'train'.
22
+ pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
23
+ **kwargs: Other keyword arguments in :class:`BaseDataset`.
24
+ """
25
+
26
+ def __init__(self,
27
+ data_root,
28
+ ann_file,
29
+ data_prefix,
30
+ split_file,
31
+ split='train',
32
+ **kwargs):
33
+ self.split_file = split_file
34
+ self.split = split
35
+
36
+ super().__init__(
37
+ data_root=data_root,
38
+ data_prefix=dict(img_path=data_prefix),
39
+ ann_file=ann_file,
40
+ **kwargs,
41
+ )
42
+
43
+ def _join_prefix(self):
44
+ if not mmengine.is_abs(self.split_file) and self.split_file:
45
+ self.split_file = osp.join(self.data_root, self.split_file)
46
+
47
+ return super()._join_prefix()
48
+
49
+ def load_data_list(self) -> List[dict]:
50
+ """Load data list."""
51
+ with mmengine.get_local_path(self.ann_file) as ann_file:
52
+ coco = COCO(ann_file)
53
+ splits = mmengine.load(self.split_file, file_format='pkl')
54
+ img_prefix = self.data_prefix['img_path']
55
+
56
+ data_list = []
57
+ join_path = mmengine.fileio.get_file_backend(img_prefix).join_path
58
+ for refer in splits:
59
+ if refer['split'] != self.split:
60
+ continue
61
+
62
+ ann = coco.anns[refer['ann_id']]
63
+ img = coco.imgs[ann['image_id']]
64
+ sentences = refer['sentences']
65
+ bbox = np.array(ann['bbox'], dtype=np.float32)
66
+ bbox[2:4] = bbox[0:2] + bbox[2:4] # XYWH -> XYXY
67
+
68
+ for sent in sentences:
69
+ data_info = {
70
+ 'img_path': join_path(img_prefix, img['file_name']),
71
+ 'image_id': ann['image_id'],
72
+ 'ann_id': ann['id'],
73
+ 'text': sent['sent'],
74
+ 'gt_bboxes': bbox[None, :],
75
+ }
76
+ data_list.append(data_info)
77
+
78
+ if len(data_list) == 0:
79
+ raise ValueError(f'No sample in split "{self.split}".')
80
+
81
+ return data_list
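COCO stores boxes as XYWH while the loader above emits XYXY; a standalone check of the in-place conversion:

    # Standalone check of the XYWH -> XYXY conversion used above.
    import numpy as np

    bbox = np.array([10., 20., 30., 40.], dtype=np.float32)  # x, y, w, h
    bbox[2:4] = bbox[0:2] + bbox[2:4]
    print(bbox)  # [10. 20. 40. 60.], i.e. x1, y1, x2, y2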
mmpretrain/datasets/samplers/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .repeat_aug import RepeatAugSampler
3
+ from .sequential import SequentialSampler
4
+
5
+ __all__ = ['RepeatAugSampler', 'SequentialSampler']
mmpretrain/datasets/samplers/repeat_aug.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Iterator, Optional, Sized
3
+
4
+ import torch
5
+ from mmengine.dist import get_dist_info, is_main_process, sync_random_seed
6
+ from torch.utils.data import Sampler
7
+
8
+ from mmpretrain.registry import DATA_SAMPLERS
9
+
10
+
11
+ @DATA_SAMPLERS.register_module()
12
+ class RepeatAugSampler(Sampler):
13
+ """Sampler that restricts data loading to a subset of the dataset for
14
+ distributed training, with repeated augmentation. It ensures that each
15
+ augmented version of a sample is visible to a different process (GPU).
16
+ Heavily based on torch.utils.data.DistributedSampler.
17
+
18
+ This sampler was taken from
19
+ https://github.com/facebookresearch/deit/blob/0c4b8f60/samplers.py
20
+
21
+ Copyright (c) 2015-present, Facebook, Inc.
22
+
23
+ Args:
24
+ dataset (Sized): The dataset.
25
+ shuffle (bool): Whether shuffle the dataset or not. Defaults to True.
26
+ num_repeats (int): The repeat times of every sample. Defaults to 3.
27
+ seed (int, optional): Random seed used to shuffle the sampler if
28
+ :attr:`shuffle=True`. This number should be identical across all
29
+ processes in the distributed group. Defaults to None.
30
+ """
31
+
32
+ def __init__(self,
33
+ dataset: Sized,
34
+ shuffle: bool = True,
35
+ num_repeats: int = 3,
36
+ seed: Optional[int] = None):
37
+ rank, world_size = get_dist_info()
38
+ self.rank = rank
39
+ self.world_size = world_size
40
+
41
+ self.dataset = dataset
42
+ self.shuffle = shuffle
43
+ if not self.shuffle and is_main_process():
44
+ from mmengine.logging import MMLogger
45
+ logger = MMLogger.get_current_instance()
46
+ logger.warning('The RepeatAugSampler always picks a '
47
+ 'fixed part of data if `shuffle=False`.')
48
+
49
+ if seed is None:
50
+ seed = sync_random_seed()
51
+ self.seed = seed
52
+ self.epoch = 0
53
+ self.num_repeats = num_repeats
54
+
55
+ # The number of repeated samples in the rank
56
+ self.num_samples = math.ceil(
57
+ len(self.dataset) * num_repeats / world_size)
58
+ # The total number of repeated samples in all ranks.
59
+ self.total_size = self.num_samples * world_size
60
+ # The number of selected samples in the rank
61
+ self.num_selected_samples = math.ceil(len(self.dataset) / world_size)
62
+
63
+ def __iter__(self) -> Iterator[int]:
64
+ """Iterate the indices."""
65
+ # deterministically shuffle based on epoch and seed
66
+ if self.shuffle:
67
+ g = torch.Generator()
68
+ g.manual_seed(self.seed + self.epoch)
69
+ indices = torch.randperm(len(self.dataset), generator=g).tolist()
70
+ else:
71
+ indices = list(range(len(self.dataset)))
72
+
73
+ # produce repeats e.g. [0, 0, 0, 1, 1, 1, 2, 2, 2....]
74
+ indices = [x for x in indices for _ in range(self.num_repeats)]
75
+ # add extra samples to make it evenly divisible
76
+ padding_size = self.total_size - len(indices)
77
+ indices += indices[:padding_size]
78
+ assert len(indices) == self.total_size
79
+
80
+ # subsample per rank
81
+ indices = indices[self.rank:self.total_size:self.world_size]
82
+ assert len(indices) == self.num_samples
83
+
84
+ # return up to num selected samples
85
+ return iter(indices[:self.num_selected_samples])
86
+
87
+ def __len__(self) -> int:
88
+ """The number of samples in this rank."""
89
+ return self.num_selected_samples
90
+
91
+ def set_epoch(self, epoch: int) -> None:
92
+ """Sets the epoch for this sampler.
93
+
94
+ When :attr:`shuffle=True`, this ensures all replicas use a different
95
+ random ordering for each epoch. Otherwise, the next iteration of this
96
+ sampler will yield the same ordering.
97
+
98
+ Args:
99
+ epoch (int): Epoch number.
100
+ """
101
+ self.epoch = epoch
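The index arithmetic above is easiest to see on a toy case. A sketch replaying it for 4 samples, 3 repeats and 3 ranks, with shuffling disabled for readability; no distributed launch is needed:

    # Toy replay of RepeatAugSampler's index arithmetic.
    import math

    dataset_len, num_repeats, world_size = 4, 3, 3
    num_samples = math.ceil(dataset_len * num_repeats / world_size)  # 4
    total_size = num_samples * world_size                            # 12
    num_selected = math.ceil(dataset_len / world_size)               # 2

    indices = [x for x in range(dataset_len) for _ in range(num_repeats)]
    indices += indices[:total_size - len(indices)]  # pad (no-op here)
    for rank in range(world_size):
        per_rank = indices[rank:total_size:world_size]
        print(rank, per_rank[:num_selected])
    # 0 [0, 1]
    # 1 [0, 1]
    # 2 [0, 1]
    # Every rank sees the same sample ids, but each rank draws its own
    # random augmentation, so each sample gets num_repeats augmented views.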
mmpretrain/datasets/samplers/sequential.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from typing import Iterator
3
+
4
+ import torch
5
+ from mmengine.dataset import DefaultSampler
6
+
7
+ from mmpretrain.registry import DATA_SAMPLERS
8
+
9
+
10
+ @DATA_SAMPLERS.register_module()
11
+ class SequentialSampler(DefaultSampler):
12
+ """Sequential sampler which supports different subsample policy.
13
+
14
+ Args:
15
+ dataset (Sized): The dataset.
16
+ round_up (bool): Whether to add extra samples to make the number of
17
+ samples evenly divisible by the world size. Defaults to True.
18
+ subsample_type (str): The method to subsample data on different rank.
19
+ Supported type:
20
+
21
+ - ``'default'``: Original torch behavior. Sample the examples one
22
+ by one for each GPU in turn. For instance, 8 examples on 2 GPUs,
23
+ GPU0: [0,2,4,6], GPU1: [1,3,5,7]
24
+ - ``'sequential'``: Subsample all examples into n chunks sequentially.
25
+ For instance, 8 examples on 2 GPUs,
26
+ GPU0: [0,1,2,3], GPU1: [4,5,6,7]
27
+ """
28
+
29
+ def __init__(self, subsample_type: str = 'default', **kwargs) -> None:
30
+ super().__init__(shuffle=False, **kwargs)
31
+
32
+ if subsample_type not in ['default', 'sequential']:
33
+ raise ValueError(f'Unsupported subsample type "{subsample_type}",'
34
+ ' please choose from ["default", "sequential"]')
35
+ self.subsample_type = subsample_type
36
+
37
+ def __iter__(self) -> Iterator[int]:
38
+ """Iterate the indices."""
39
+ indices = torch.arange(len(self.dataset)).tolist()
40
+
41
+ # add extra samples to make it evenly divisible
42
+ if self.round_up:
43
+ indices = (
44
+ indices *
45
+ int(self.total_size / len(indices) + 1))[:self.total_size]
46
+
47
+ # subsample
48
+ if self.subsample_type == 'default':
49
+ indices = indices[self.rank:self.total_size:self.world_size]
50
+ elif self.subsample_type == 'sequential':
51
+ num_samples_per_rank = self.total_size // self.world_size
52
+ indices = indices[self.rank *
53
+ num_samples_per_rank:(self.rank + 1) *
54
+ num_samples_per_rank]
55
+
56
+ return iter(indices)
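The two subsample policies, replayed for 8 examples on 2 ranks without a distributed launch:

    # Replay of both subsample policies: 8 examples, 2 ranks.
    total_size, world_size = 8, 2
    indices = list(range(total_size))

    for rank in range(world_size):
        default = indices[rank:total_size:world_size]
        n = total_size // world_size
        sequential = indices[rank * n:(rank + 1) * n]
        print(rank, default, sequential)
    # 0 [0, 2, 4, 6] [0, 1, 2, 3]
    # 1 [1, 3, 5, 7] [4, 5, 6, 7]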
mmpretrain/datasets/scienceqa.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import os
3
+ from typing import Callable, List, Sequence
4
+
5
+ import mmengine
6
+ from mmengine.dataset import BaseDataset
7
+ from mmengine.fileio import get_file_backend
8
+
9
+ from mmpretrain.registry import DATASETS
10
+
11
+
12
+ @DATASETS.register_module()
13
+ class ScienceQA(BaseDataset):
14
+ """ScienceQA dataset.
15
+
16
+ This dataset is used to load the multimodal data of ScienceQA dataset.
17
+
18
+ Args:
19
+ data_root (str): The root directory for ``data_prefix`` and
20
+ ``ann_file``.
21
+ split (str): The split of dataset. Options: ``train``, ``val``,
22
+ ``test``, ``trainval``, ``minival``, and ``minitest``.
23
+ split_file (str): The split file of dataset, which contains the
24
+ ids of data samples in the split.
25
+ ann_file (str): Annotation file path.
26
+ data_prefix (dict): Prefix for data field. Defaults to
27
+ ``dict(img_path='')``.
28
+ pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
29
+ **kwargs: Other keyword arguments in :class:`BaseDataset`.
30
+ """
31
+
32
+ def __init__(self,
33
+ data_root: str,
34
+ split: str,
35
+ split_file: str,
36
+ ann_file: str,
37
+ data_prefix: dict = dict(img_path=''),
38
+ pipeline: Sequence[Callable] = (),
39
+ **kwargs):
40
+
41
+ assert split in [
42
+ 'train', 'val', 'test', 'trainval', 'minival', 'minitest'
43
+ ], f'Invalid split {split}'
44
+ self.split = split
45
+ self.split_file = os.path.join(data_root, split_file)
46
+
47
+ super().__init__(
48
+ data_root=data_root,
49
+ ann_file=ann_file,
50
+ data_prefix=data_prefix,
51
+ pipeline=pipeline,
52
+ **kwargs)
53
+
54
+ def load_data_list(self) -> List[dict]:
55
+ """Load data list."""
56
+ img_prefix = self.data_prefix['img_path']
57
+ annotations = mmengine.load(self.ann_file)
58
+ current_data_split = mmengine.load(self.split_file)[self.split] # noqa
59
+
60
+ file_backend = get_file_backend(img_prefix)
61
+
62
+ data_list = []
63
+ for data_id in current_data_split:
64
+ ann = annotations[data_id]
65
+ data_info = {
66
+ 'image_id':
67
+ data_id,
68
+ 'question':
69
+ ann['question'],
70
+ 'choices':
71
+ ann['choices'],
72
+ 'gt_answer':
73
+ ann['answer'],
74
+ 'hint':
75
+ ann['hint'],
76
+ 'image_name':
77
+ ann['image'],
78
+ 'task':
79
+ ann['task'],
80
+ 'grade':
81
+ ann['grade'],
82
+ 'subject':
83
+ ann['subject'],
84
+ 'topic':
85
+ ann['topic'],
86
+ 'category':
87
+ ann['category'],
88
+ 'skill':
89
+ ann['skill'],
90
+ 'lecture':
91
+ ann['lecture'],
92
+ 'solution':
93
+ ann['solution'],
94
+ 'split':
95
+ ann['split'],
96
+ 'img_path':
97
+ file_backend.join_path(img_prefix, data_id, ann['image'])
98
+ if ann['image'] is not None else None,
99
+ 'has_image':
100
+ True if ann['image'] is not None else False,
101
+ }
102
+ data_list.append(data_info)
103
+
104
+ return data_list
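The loader above implies a two-file layout: the split file maps split names to sample ids, and the annotation file maps each id to a record with the fields read in ``load_data_list``. An abbreviated, hypothetical illustration; the field values are made up:

    # Hypothetical, abbreviated content of the two files ScienceQA reads.
    split_file = {'train': ['1', '2'], 'val': ['3']}
    annotations = {
        '1': {
            'question': 'Which property do these objects share?',
            'choices': ['hard', 'soft'],
            'answer': 0,
            'hint': '',
            'image': 'image.png',  # None for text-only samples
            'task': 'closed choice',
            'grade': 'grade2',
            'subject': 'natural science',
            'topic': 'physics',
            'category': 'Materials',
            'skill': 'Compare properties of objects',
            'lecture': '...',
            'solution': '...',
            'split': 'train',
        },
        # ...
    }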
mmpretrain/datasets/stanfordcars.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from typing import List
3
+
4
+ import mat4py
5
+ from mmengine import get_file_backend
6
+
7
+ from mmpretrain.registry import DATASETS
8
+ from .base_dataset import BaseDataset
9
+ from .categories import STANFORDCARS_CATEGORIES
10
+
11
+
12
+ @DATASETS.register_module()
13
+ class StanfordCars(BaseDataset):
14
+ """The Stanford Cars Dataset.
15
+
16
+ Support the `Stanford Cars <https://ai.stanford.edu/~jkrause/cars/car_dataset.html>`_ dataset.
17
+ The official website provides two ways to organize the dataset.
18
+ Therefore, after downloading and decompression, the dataset directory structure is as follows.
19
+
20
+ Stanford Cars dataset directory: ::
21
+
22
+ Stanford_Cars
23
+ ├── car_ims
24
+ │ ├── 00001.jpg
25
+ │ ├── 00002.jpg
26
+ │ └── ...
27
+ └── cars_annos.mat
28
+
29
+ or ::
30
+
31
+ Stanford_Cars
32
+ ├── cars_train
33
+ │ ├── 00001.jpg
34
+ │ ├── 00002.jpg
35
+ │ └── ...
36
+ ├── cars_test
37
+ │ ├── 00001.jpg
38
+ │ ├── 00002.jpg
39
+ │ └── ...
40
+ └── devkit
41
+ ├── cars_meta.mat
42
+ ├── cars_train_annos.mat
43
+ ├── cars_test_annos.mat
44
+ ├── cars_test_annos_withlabels.mat
45
+ ├── eval_train.m
46
+ └── train_perfect_preds.txt
47
+
48
+ Args:
49
+ data_root (str): The root directory for Stanford Cars dataset.
50
+ split (str, optional): The dataset split, supports "train"
51
+ and "test". Default to "train".
52
+
53
+ Examples:
54
+ >>> from mmpretrain.datasets import StanfordCars
55
+ >>> train_dataset = StanfordCars(data_root='data/Stanford_Cars', split='train')
56
+ >>> train_dataset
57
+ Dataset StanfordCars
58
+ Number of samples: 8144
59
+ Number of categories: 196
60
+ Root of dataset: data/Stanford_Cars
61
+ >>> test_dataset = StanfordCars(data_root='data/Stanford_Cars', split='test')
62
+ >>> test_dataset
63
+ Dataset StanfordCars
64
+ Number of samples: 8041
65
+ Number of categories: 196
66
+ Root of dataset: data/Stanford_Cars
67
+ """ # noqa: E501
68
+
69
+ METAINFO = {'classes': STANFORDCARS_CATEGORIES}
70
+
71
+ def __init__(self, data_root: str, split: str = 'train', **kwargs):
72
+
73
+ splits = ['train', 'test']
74
+ assert split in splits, \
75
+ f"The split must be one of {splits}, but get '{split}'"
76
+ self.split = split
77
+
78
+ test_mode = split == 'test'
79
+ self.backend = get_file_backend(data_root, enable_singleton=True)
80
+
81
+ anno_file_path = self.backend.join_path(data_root, 'cars_annos.mat')
82
+ if self.backend.exists(anno_file_path):
83
+ ann_file = 'cars_annos.mat'
84
+ data_prefix = ''
85
+ else:
86
+ if test_mode:
87
+ ann_file = self.backend.join_path(
88
+ 'devkit', 'cars_test_annos_withlabels.mat')
89
+ data_prefix = 'cars_test'
90
+ else:
91
+ ann_file = self.backend.join_path('devkit',
92
+ 'cars_train_annos.mat')
93
+ data_prefix = 'cars_train'
94
+
95
+ if not self.backend.exists(
96
+ self.backend.join_path(data_root, ann_file)):
97
+ doc_url = 'https://mmpretrain.readthedocs.io/en/latest/api/datasets.html#stanfordcars' # noqa: E501
98
+ raise RuntimeError(
99
+ 'The dataset is incorrectly organized, please '
100
+ f'refer to {doc_url} and reorganize your folders.')
101
+
102
+ super(StanfordCars, self).__init__(
103
+ ann_file=ann_file,
104
+ data_root=data_root,
105
+ data_prefix=data_prefix,
106
+ test_mode=test_mode,
107
+ **kwargs)
108
+
109
+ def load_data_list(self):
110
+ data = mat4py.loadmat(self.ann_file)['annotations']
111
+
112
+ data_list = []
113
+ if 'test' in data.keys():
114
+ # first way
115
+ img_paths, labels, test = data['relative_im_path'], data[
116
+ 'class'], data['test']
117
+ num = len(img_paths)
118
+ assert num == len(labels) == len(test), 'Invalid ann file.'
119
+ for i in range(num):
120
+ if not self.test_mode and test[i] == 1:
121
+ continue
122
+ if self.test_mode and test[i] == 0:
123
+ continue
124
+ img_path = self.backend.join_path(self.img_prefix,
125
+ img_paths[i])
126
+ gt_label = labels[i] - 1
127
+ info = dict(img_path=img_path, gt_label=gt_label)
128
+ data_list.append(info)
129
+ else:
130
+ # second way
131
+ img_names, labels = data['fname'], data['class']
132
+ num = len(img_names)
133
+ assert num == len(labels), 'Invalid ann file.'
134
+ for i in range(num):
135
+ img_path = self.backend.join_path(self.img_prefix,
136
+ img_names[i])
137
+ gt_label = labels[i] - 1
138
+ info = dict(img_path=img_path, gt_label=gt_label)
139
+ data_list.append(info)
140
+
141
+ return data_list
142
+
143
+ def extra_repr(self) -> List[str]:
144
+ """The extra repr information of the dataset."""
145
+ body = [
146
+ f'Root of dataset: \t{self.data_root}',
147
+ ]
148
+ return body
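``mat4py.loadmat`` returns plain Python lists, so the first-layout filtering above reduces to matching the per-sample ``test`` flag against the mode. A toy replay with fabricated values:

    # Toy replay of the first-layout filtering; values are fabricated.
    data = {
        'relative_im_path': ['car_ims/00001.jpg', 'car_ims/00002.jpg'],
        'class': [1, 14],
        'test': [0, 1],
    }
    test_mode = False  # keep rows whose `test` flag matches the mode
    for path, label, is_test in zip(data['relative_im_path'],
                                    data['class'], data['test']):
        if is_test != int(test_mode):
            continue
        print(dict(img_path=path, gt_label=label - 1))
    # {'img_path': 'car_ims/00001.jpg', 'gt_label': 0}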
mmpretrain/datasets/sun397.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from typing import List
3
+
4
+ from mmengine import get_file_backend, list_from_file
5
+
6
+ from mmpretrain.registry import DATASETS
7
+ from .base_dataset import BaseDataset
8
+ from .categories import SUN397_CATEGORIES
9
+
10
+ # Note that some images are not JPEG files although their names end
11
+ # with .jpg and therefore cannot be read properly. So we provide
12
+ # a list to skip these files.
13
+ INVALID = [
14
+ '/a/assembly_line/sun_ajckcfldgdrdjogj.jpg',
15
+ '/a/auto_factory/sun_apfsprenzdnzbhmt.jpg',
16
+ '/b/baggage_claim/sun_avittiqqaiibgcau.jpg',
17
+ '/b/batters_box/sun_alqlfpgtbgggezyr.jpg',
18
+ '/b/bow_window/indoor/sun_ahsholsagvlrsboa.jpg',
19
+ '/b/bow_window/indoor/sun_aioomcoujmmcxkkx.jpg',
20
+ '/b/bow_window/outdoor/sun_atgtjdpqikjmllth.jpg',
21
+ '/c/carrousel/sun_atsgphqympojgxnc.jpg',
22
+ '/c/carrousel/sun_auzitjuirwolazns.jpg',
23
+ '/c/church/outdoor/sun_boagasgfltequmal.jpg',
24
+ '/c/church/outdoor/sun_brhmnwzzbkphcvfo.jpg',
25
+ '/c/church/outdoor/sun_byjkqzybxpjnuofa.jpg',
26
+ '/c/corridor/sun_aznefxvocwpgimko.jpg',
27
+ '/d/dentists_office/sun_aaefsoauqlcsihou.jpg',
28
+ '/d/diner/indoor/sun_apswilaujhntrybg.jpg',
29
+ '/e/elevator/door/sun_aaudobqlphijkjdv.jpg',
30
+ '/f/fastfood_restaurant/sun_axeniwtesffxqedr.jpg',
31
+ '/f/fire_station/sun_bjyapttwilyyuxqm.jpg',
32
+ '/f/fountain/sun_axgmpbdyvqhtkhee.jpg',
33
+ '/h/hospital_room/sun_ahokhhxjiclpxqqa.jpg',
34
+ '/o/oast_house/sun_bqsrrygxyrutgjve.jpg',
35
+ '/r/restaurant_patio/sun_aurwypviprwycame.jpg',
36
+ '/s/ski_resort/sun_bplmntyzoiobcqhp.jpg',
37
+ '/w/wine_cellar/bottle_storage/sun_afmzwxkzmxkbamqi.jpg',
38
+ '/w/wine_cellar/bottle_storage/sun_ahyymswdjejrbhyb.jpg',
39
+ '/w/wine_cellar/bottle_storage/sun_avnttpxamufejbfe.jpg',
40
+ '/a/archive/sun_awgsrbljlsvhqjij.jpg',
41
+ '/a/art_school/sun_aabogqsjulyvmcse.jpg',
42
+ '/a/art_school/sun_apnzojafyvkariue.jpg',
43
+ '/b/ball_pit/sun_atjhwqngtoeuwhso.jpg',
44
+ '/b/bow_window/indoor/sun_asxvsqbexmmtqmht.jpg',
45
+ '/b/bow_window/indoor/sun_abeugxecxrwzmffp.jpg',
46
+ '/b/bow_window/outdoor/sun_auwcqhrtzkgihvlv.jpg',
47
+ '/b/bow_window/outdoor/sun_apnvdyecnjjmcuhi.jpg',
48
+ '/c/childs_room/sun_alggivksjwwiklmt.jpg',
49
+ '/c/control_tower/outdoor/sun_avbcxakrvpomqdgr.jpg',
50
+ '/d/diner/indoor/sun_ajmzozstvsxisvgx.jpg',
51
+ '/e/elevator/door/sun_aaqsyluqbluugqgy.jpg',
52
+ '/f/fastfood_restaurant/sun_aevchxlxoruhxgrb.jpg',
53
+ '/f/firing_range/indoor/sun_affrzvahwjorpalo.jpg',
54
+ '/f/formal_garden/sun_bjvrlaeatjufekft.jpg',
55
+ '/g/garage/indoor/sun_akbocuwclkxqlofx.jpg',
56
+ '/g/greenhouse/indoor/sun_addirvgtxfbndlwf.jpg',
57
+ '/k/kindergarden_classroom/sun_ajtpaahilrqzarri.jpg',
58
+ '/l/laundromat/sun_afrrjykuhhlwiwun.jpg',
59
+ '/m/music_studio/sun_bsntklkmwqgnjrjj.jpg',
60
+ '/t/track/outdoor/sun_aophkoiosslinihb.jpg',
61
+ '/a/archive/sun_aegmzltkiwyevpwa.jpg',
62
+ '/a/auto_factory/sun_aybymzvbxgvcrwgn.jpg',
63
+ '/b/baggage_claim/sun_atpmiqmnxjpgqsxi.jpg',
64
+ '/b/baggage_claim/sun_ajffcdpsvgqfzoxx.jpg',
65
+ '/b/bamboo_forest/sun_ausmxphosyahoyjo.jpg',
66
+ '/b/batters_box/sun_aaeheulsicxtxnbu.jpg',
67
+ '/c/carrousel/sun_arjrjcxemhttubqz.jpg',
68
+ '/c/chicken_coop/outdoor/sun_abcegmmdbizqkpgh.jpg',
69
+ '/c/control_tower/outdoor/sun_axhjfpkxdvqdfkyr.jpg',
70
+ '/d/diner/indoor/sun_apaotiublwqeowck.jpg',
71
+ '/f/fastfood_restaurant/sun_anexashcgmxdbmxq.jpg',
72
+ '/l/landing_deck/sun_aizahnjfkuurjibw.jpg',
73
+ '/n/nuclear_power_plant/outdoor/sun_aoblfvgyleweqanr.jpg',
74
+ '/w/waiting_room/sun_aicytusmthfvqcwc.jpg',
75
+ '/b/bow_window/indoor/sun_asmvdfnjlulewkpr.jpg',
76
+ '/b/bus_interior/sun_adhktvidwzmodeou.jpg',
77
+ '/c/catacomb/sun_algnawesgjzzmcqd.jpg',
78
+ '/c/church/outdoor/sun_baihxlseimcsdhdx.jpg',
79
+ '/d/diner/indoor/sun_agoyalzcawgxodbm.jpg',
80
+ '/e/elevator_shaft/sun_awaitimkinrjaybl.jpg',
81
+ '/f/fastfood_restaurant/sun_aplvzfbmtqtbsvbx.jpg',
82
+ '/g/greenhouse/indoor/sun_bkccvyfpwetwjuhk.jpg',
83
+ '/c/car_interior/backseat/sun_adexwfoqdyhowxpu.jpg',
84
+ '/c/church/outdoor/sun_blmmweiumednscuf.jpg',
85
+ '/f/fire_station/sun_bibntbsuunbsdrum.jpg',
86
+ '/g/game_room/sun_aopfaqlllpvzhrak.jpg',
87
+ '/u/underwater/coral_reef/sun_biiueajvszaxqopo.jpg',
88
+ '/a/airplane_cabin/sun_arqyikigkyfpegug.jpg',
89
+ '/b/badminton_court/indoor/sun_amppvxecgtjpfold.jpg',
90
+ '/c/carrousel/sun_anxtrtieimkpmhvk.jpg',
91
+ '/c/computer_room/sun_aebgvpgtwoqbfyvl.jpg',
92
+ '/f/fire_escape/sun_atbraxuwwlvdoolv.jpg',
93
+ '/k/kasbah/sun_abxkkoielpavsouu.jpg',
94
+ '/t/tower/sun_bccqnzcvqkiwicjt.jpg',
95
+ '/a/archive/sun_afngadshxudodkct.jpg',
96
+ '/b/bow_window/indoor/sun_awnrlipyxpgxxgxz.jpg',
97
+ '/c/control_tower/outdoor/sun_arohngcbtsvbthho.jpg',
98
+ '/f/fire_station/sun_brbskkfgghbfvgkk.jpg',
99
+ '/r/restaurant_patio/sun_amjfbqzfgxarrpec.jpg',
100
+ '/v/vineyard/sun_bdxhnbgbnolddswz.jpg',
101
+ '/b/baggage_claim/sun_axrtsmillrglugia.jpg',
102
+ '/d/diner/indoor/sun_alaqevbwpjaqqdqz.jpg',
103
+ '/l/landing_deck/sun_acodgoamhgnnbmvr.jpg',
104
+ '/c/carrousel/sun_adsafgyrinnekycc.jpg',
105
+ '/c/church/outdoor/sun_bzqhuwshtdgakkay.jpg',
106
+ '/c/closet/sun_absahzamlrylkxyn.jpg',
107
+ '/f/fire_escape/sun_acdthenaosuqcoqn.jpg',
108
+ '/b/butchers_shop/sun_asrdgbefoszenfex.jpg',
109
+ '/c/church/outdoor/sun_bzfyucfrdigaqneg.jpg',
110
+ '/c/church/outdoor/sun_byzxhknqrejdajxi.jpg',
111
+ '/c/cockpit/sun_ajkulpqauavrmxae.jpg',
112
+ '/l/living_room/sun_aefoqbeatyufobtx.jpg',
113
+ '/s/supermarket/sun_attvxbzocurnddbz.jpg',
114
+ '/c/closet/sun_aqnutmwfkypmrnfy.jpg',
115
+ '/f/fire_station/sun_bttrtzktpbymxkmf.jpg',
116
+ '/s/shopping_mall/indoor/sun_avwzjsijaxnwuzjx.jpg',
117
+ '/w/windmill/sun_blvczkyqbmabzeej.jpg',
118
+ '/c/chicken_coop/outdoor/sun_amaonsnnkskxwmrj.jpg',
119
+ '/s/swimming_pool/outdoor/sun_bslaihiqlhfewtzn.jpg',
120
+ '/u/underwater/coral_reef/sun_bhcrnmvbgnkvcvkr.jpg',
121
+ '/d/dining_room/sun_azlxdhiajwrhaivq.jpg',
122
+ '/c/church/outdoor/sun_bnunxbznqnvgeykx.jpg',
123
+ '/c/corridor/sun_aspwpqqlcwzfanvl.jpg',
124
+ '/r/restaurant_patio/sun_awcbpizjbudjvrhs.jpg',
125
+ '/b/ball_pit/sun_avdnmemjrgrbkwjm.jpg',
126
+ ]
127
+
128
+
129
+ @DATASETS.register_module()
130
+ class SUN397(BaseDataset):
131
+ """The SUN397 Dataset.
132
+
133
+ Support the `SUN397 <https://vision.princeton.edu/projects/2010/SUN/>`_ dataset.
134
+ After downloading and decompression, the dataset directory structure is as follows.
135
+
136
+ SUN397 dataset directory: ::
137
+
138
+ SUN397
139
+ ├── SUN397
140
+ │ ├── a
141
+ │ │ ├── abbey
142
+ │ | | ├── sun_aaalbzqrimafwbiv.jpg
143
+ │ | | └── ...
144
+ │ │ ├── airplane_cabin
145
+ │ | | ├── sun_aadqdkqaslqqoblu.jpg
146
+ │ | | └── ...
147
+ │ | └── ...
148
+ │ ├── b
149
+ │ │ └── ...
150
+ │ ├── c
151
+ │ │ └── ...
152
+ │ └── ...
153
+ └── Partitions
154
+ ├── ClassName.txt
155
+ ├── Training_01.txt
156
+ ├── Testing_01.txt
157
+ └── ...
158
+
159
+ Args:
160
+ data_root (str): The root directory for SUN397 dataset.
161
+ split (str, optional): The dataset split, supports "train" and "test".
162
+ Default to "train".
163
+
164
+ Examples:
165
+ >>> from mmpretrain.datasets import SUN397
166
+ >>> train_dataset = SUN397(data_root='data/SUN397', split='train')
167
+ >>> train_dataset
168
+ Dataset SUN397
169
+ Number of samples: 19824
170
+ Number of categories: 397
171
+ Root of dataset: data/SUN397
172
+ >>> test_dataset = SUN397(data_root='data/SUN397', split='test')
173
+ >>> test_dataset
174
+ Dataset SUN397
175
+ Number of samples: 19829
176
+ Number of categories: 397
177
+ Root of dataset: data/SUN397
178
+ """ # noqa: E501
179
+
180
+ METAINFO = {'classes': SUN397_CATEGORIES}
181
+
182
+ def __init__(self, data_root: str, split: str = 'train', **kwargs):
183
+
184
+ splits = ['train', 'test']
185
+ assert split in splits, \
186
+ f"The split must be one of {splits}, but get '{split}'"
187
+ self.split = split
188
+
189
+ self.backend = get_file_backend(data_root, enable_singleton=True)
190
+ if split == 'train':
191
+ ann_file = self.backend.join_path('Partitions', 'Training_01.txt')
192
+ else:
193
+ ann_file = self.backend.join_path('Partitions', 'Testing_01.txt')
194
+
195
+ data_prefix = 'SUN397'
196
+ test_mode = split == 'test'
197
+
198
+ super(SUN397, self).__init__(
199
+ ann_file=ann_file,
200
+ data_root=data_root,
201
+ test_mode=test_mode,
202
+ data_prefix=data_prefix,
203
+ **kwargs)
204
+
205
+ def load_data_list(self):
206
+ pairs = list_from_file(self.ann_file)
207
+ data_list = []
208
+ for pair in pairs:
209
+ if pair in INVALID:
210
+ continue
211
+ img_path = self.backend.join_path(self.img_prefix, pair[1:])
212
+ items = pair.split('/')
213
+ class_name = '_'.join(items[2:-1])
214
+ gt_label = self.METAINFO['classes'].index(class_name)
215
+ info = dict(img_path=img_path, gt_label=gt_label)
216
+ data_list.append(info)
217
+
218
+ return data_list
219
+
220
+ def extra_repr(self) -> List[str]:
221
+ """The extra repr information of the dataset."""
222
+ body = [
223
+ f'Root of dataset: \t{self.data_root}',
224
+ ]
225
+ return body
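The class name is recovered from the relative path itself. A standalone check of the split-and-join logic for a nested category, using a path from the skip list above:

    # Standalone check of the path -> class-name mapping used above.
    pair = '/b/bow_window/indoor/sun_ahsholsagvlrsboa.jpg'
    items = pair.split('/')  # ['', 'b', 'bow_window', 'indoor', 'sun_...jpg']
    class_name = '_'.join(items[2:-1])
    print(class_name)  # bow_window_indoor
    print(pair[1:])    # relative image path joined under the SUN397 prefix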
mmpretrain/datasets/transforms/__init__.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from mmcv.transforms import (CenterCrop, LoadImageFromFile, Normalize,
3
+ RandomFlip, RandomGrayscale, RandomResize, Resize)
4
+
5
+ from mmpretrain.registry import TRANSFORMS
6
+ from .auto_augment import (AutoAugment, AutoContrast, BaseAugTransform,
7
+ Brightness, ColorTransform, Contrast, Cutout,
8
+ Equalize, GaussianBlur, Invert, Posterize,
9
+ RandAugment, Rotate, Sharpness, Shear, Solarize,
10
+ SolarizeAdd, Translate)
11
+ from .formatting import (Collect, NumpyToPIL, PackInputs, PackMultiTaskInputs,
12
+ PILToNumpy, Transpose)
13
+ from .processing import (Albumentations, BEiTMaskGenerator, CleanCaption,
14
+ ColorJitter, EfficientNetCenterCrop,
15
+ EfficientNetRandomCrop, Lighting, RandomCrop,
16
+ RandomErasing, RandomResizedCrop, RandomTranslatePad,
17
+ ResizeEdge, SimMIMMaskGenerator)
18
+ from .wrappers import ApplyToList, MultiView
19
+
20
+ for t in (CenterCrop, LoadImageFromFile, Normalize, RandomFlip,
21
+ RandomGrayscale, RandomResize, Resize):
22
+ TRANSFORMS.register_module(module=t)
23
+
24
+ __all__ = [
25
+ 'NumpyToPIL', 'PILToNumpy', 'Transpose', 'Collect', 'RandomCrop',
26
+ 'RandomResizedCrop', 'Shear', 'Translate', 'Rotate', 'Invert',
27
+ 'ColorTransform', 'Solarize', 'Posterize', 'AutoContrast', 'Equalize',
28
+ 'Contrast', 'Brightness', 'Sharpness', 'AutoAugment', 'SolarizeAdd',
29
+ 'Cutout', 'RandAugment', 'Lighting', 'ColorJitter', 'RandomErasing',
30
+ 'PackInputs', 'Albumentations', 'EfficientNetRandomCrop',
31
+ 'EfficientNetCenterCrop', 'ResizeEdge', 'BaseAugTransform',
32
+ 'PackMultiTaskInputs', 'GaussianBlur', 'BEiTMaskGenerator',
33
+ 'SimMIMMaskGenerator', 'CenterCrop', 'LoadImageFromFile', 'Normalize',
34
+ 'RandomFlip', 'RandomGrayscale', 'RandomResize', 'Resize', 'MultiView',
35
+ 'ApplyToList', 'CleanCaption', 'RandomTranslatePad'
36
+ ]
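With the mmcv transforms registered above, a whole pipeline can be built from config dicts through the shared registry. A hedged usage sketch; the parameter values are illustrative:

    # Build registered transforms from config dicts.
    from mmpretrain.registry import TRANSFORMS

    pipeline_cfg = [
        dict(type='LoadImageFromFile'),
        dict(type='RandomResizedCrop', scale=224),
        dict(type='RandomFlip', prob=0.5, direction='horizontal'),
        dict(type='PackInputs'),
    ]
    pipeline = [TRANSFORMS.build(cfg) for cfg in pipeline_cfg]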