# LINC-BIT's picture
# Upload 1912 files
# b84549f verified
# raw
# history blame
# 1.93 kB
from ..data_aug import cifar_like_image_test_aug, cifar_like_image_train_aug
from ..ab_dataset import ABDataset
from ..dataset_split import train_val_test_split
from torchvision.datasets import ImageFolder
import numpy as np
from typing import Dict, List, Optional
from torchvision import transforms
from torchvision.transforms import Compose
from utils.common.others import HiddenPrints
from ..registery import dataset_register
@dataset_register(
    name='SVHN-single',
    classes=[str(digit) for digit in range(10)],
    task_type='Image Classification',
    object_type='Digit and Letter',
    class_aliases=[],
    shift_type=None
)
class SVHNSingle(ABDataset):
    """Ten-class digit image dataset loaded from an ``ImageFolder`` directory tree.

    The class is passed to ``dataset_register`` under the name ``'SVHN-single'``
    with the class names ``'0'`` through ``'9'``.
    """

    def create_dataset(self, root_dir: str, split: str, transform: Optional[Compose],
                       classes: List[str], ignore_classes: List[str], idx_map: Optional[Dict[int, int]]):
        """Build the dataset for one split.

        Args:
            root_dir: Directory handed to ``ImageFolder``.
            split: Split name. ``'train'`` selects the random-crop default
                pipeline; any other value selects the resize pipeline. The
                value is also forwarded to ``train_val_test_split``.
            transform: Transform applied to every image. When ``None`` a
                default pipeline is built here.
            classes: Class names; the position of a name in this list is used
                as the label index of that class when dropping samples.
            ignore_classes: Class names whose samples are removed.
            idx_map: Optional mapping from original label index to new label
                index. Labels without an entry are left unchanged.

        Returns:
            The result of ``train_val_test_split(dataset, split)``.

        Raises:
            ValueError: If a name in ``ignore_classes`` is not in ``classes``.
        """
        if transform is None:
            channel_mean = [0.5] * 3
            channel_std = [0.5] * 3

            # Only the first step differs between the train and eval pipelines.
            if split == 'train':
                first_step = transforms.RandomCrop(32, padding=4)
            else:
                first_step = transforms.Resize(32)

            transform = transforms.Compose([
                first_step,
                transforms.ToTensor(),
                transforms.Normalize(channel_mean, channel_std),
            ])
            # NOTE(review): the source lost its indentation; this assignment is
            # taken to belong to the `transform is None` branch - confirm.
            self.transform = transform

        dataset = ImageFolder(root_dir, transform=transform)

        if len(ignore_classes) > 0:
            # Assumes positions in `classes` match ImageFolder's folder-derived
            # label indices - TODO confirm against the on-disk layout.
            dropped_labels = {classes.index(name) for name in ignore_classes}
            dataset.samples = [
                sample for sample in dataset.samples
                if sample[1] not in dropped_labels
            ]

        if idx_map is not None:
            # Remap labels that have an entry; keep the others as they are.
            dataset.samples = [
                (path, idx_map.get(label, label))
                for path, label in dataset.samples
            ]

        return train_val_test_split(dataset, split)