File size: 3,427 Bytes
b84549f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
from ..ab_dataset import ABDataset
from ..dataset_split import train_val_split, train_val_test_split
from typing import Dict, List, Optional
from torchvision.transforms import Compose
from .yolox_data_util.api import get_default_yolox_coco_dataset, get_yolox_coco_dataset_with_caption, remap_dataset, ensure_index_start_from_1_and_successive, coco_train_val_test_split
import os
from ..registery import dataset_register
@dataset_register(
    name='CityscapesDet',
    classes=[
        'car', 'bus'
    ],
    task_type='Object Detection',
    object_type='Driving',
    class_aliases=[],
    shift_type=None
)
class CityscapesDet(ABDataset):
    """Cityscapes object-detection dataset (classes: car, bus) built on YOLOX's COCO loader.

    Expects a COCO-format annotation file named ``coco_ann.json`` directly
    under ``root_dir``.
    """

    def create_dataset(self, root_dir: str, split: str, transform: Optional[Compose],
                       classes: List[str], ignore_classes: List[str], idx_map: Optional[Dict[int, int]]):
        """Build the YOLOX COCO dataset for ``split``.

        Args:
            root_dir: Directory passed to the YOLOX loader; must contain ``coco_ann.json``.
            split: Split name forwarded to ``coco_train_val_test_split``; the
                loader is built with ``train=True`` only when it equals ``'train'``.
            transform: Must be ``None``; the YOLOX-based loader uses its own
                data augmentation instead of ``torchvision.transforms``.
            classes: Not used by this implementation; accepted for interface
                compatibility.
            ignore_classes: Forwarded to ``remap_dataset``.
            idx_map: Forwarded to ``remap_dataset``.

        Returns:
            The dataset returned by ``get_default_yolox_coco_dataset``.

        Raises:
            ValueError: If ``transform`` is not ``None``.
            FileNotFoundError: If ``root_dir/coco_ann.json`` does not exist.
        """
        # Explicit raises rather than ``assert``: assertions are stripped under
        # ``python -O``, which would silently ignore a user transform and defer a
        # missing annotation file to an obscure downstream failure.
        if transform is not None:
            raise ValueError(
                'The implementation of object detection datasets is based on YOLOX (https://github.com/Megvii-BaseDetection/YOLOX) '
                'where normal `torchvision.transforms` is not supported. You can re-implement the dataset to override default data aug.'
            )

        ann_json_file_path = os.path.join(root_dir, 'coco_ann.json')
        if not os.path.exists(ann_json_file_path):
            raise FileNotFoundError(
                f'Please put the COCO annotation JSON file in root_dir: `{root_dir}/coco_ann.json`.'
            )

        # Each helper returns an annotation-file path that the next step consumes.
        ann_json_file_path = ensure_index_start_from_1_and_successive(ann_json_file_path)
        ann_json_file_path = remap_dataset(ann_json_file_path, ignore_classes, idx_map)
        ann_json_file_path = coco_train_val_test_split(ann_json_file_path, split)
        # Kept on the instance so the per-split annotation file can be located
        # later (presumably for COCO-style evaluation — TODO confirm with callers).
        self.ann_json_file_path_for_split = ann_json_file_path

        return get_default_yolox_coco_dataset(root_dir, ann_json_file_path, train=(split == 'train'))
@dataset_register(
    name='MM-CityscapesDet',
    classes=[
        'car', 'bus'
    ],
    task_type='MM Object Detection',
    object_type='Driving',
    class_aliases=[],
    shift_type=None
)
class MM_CityscapesDet(ABDataset):
    """Caption-augmented ('MM') variant of the Cityscapes detection dataset (classes: car, bus).

    Built with ``get_yolox_coco_dataset_with_caption`` and expects a
    COCO-format annotation file named ``coco_ann.json`` directly under
    ``root_dir``.
    """

    def create_dataset(self, root_dir: str, split: str, transform: Optional[Compose],
                       classes: List[str], ignore_classes: List[str], idx_map: Optional[Dict[int, int]]):
        """Build the captioned YOLOX COCO dataset for ``split``.

        Args:
            root_dir: Directory passed to the loader; must contain ``coco_ann.json``.
            split: Split name forwarded to ``coco_train_val_test_split``; the
                loader is built with ``train=True`` only when it equals ``'train'``.
            transform: Forwarded unchanged to ``get_yolox_coco_dataset_with_caption``
                (unlike the plain detection variant, it is not required to be ``None``).
            classes: Forwarded to ``get_yolox_coco_dataset_with_caption``.
            ignore_classes: Forwarded to ``remap_dataset``.
            idx_map: Forwarded to ``remap_dataset``.

        Returns:
            The dataset returned by ``get_yolox_coco_dataset_with_caption``.

        Raises:
            FileNotFoundError: If ``root_dir/coco_ann.json`` does not exist.
        """
        ann_json_file_path = os.path.join(root_dir, 'coco_ann.json')
        # Explicit raise rather than ``assert``: assertions are stripped under
        # ``python -O`` and the missing file would surface later as an obscure error.
        if not os.path.exists(ann_json_file_path):
            raise FileNotFoundError(
                f'Please put the COCO annotation JSON file in root_dir: `{root_dir}/coco_ann.json`.'
            )

        # Each helper returns an annotation-file path that the next step consumes.
        ann_json_file_path = ensure_index_start_from_1_and_successive(ann_json_file_path)
        ann_json_file_path = remap_dataset(ann_json_file_path, ignore_classes, idx_map)
        ann_json_file_path = coco_train_val_test_split(ann_json_file_path, split)
        # Kept on the instance so the per-split annotation file can be located
        # later (presumably for COCO-style evaluation — TODO confirm with callers).
        self.ann_json_file_path_for_split = ann_json_file_path

        return get_yolox_coco_dataset_with_caption(
            root_dir, ann_json_file_path,
            transform=transform, train=(split == 'train'), classes=classes,
        )