LINC-BIT's picture
Upload 1912 files
b84549f verified
raw
history blame
2.28 kB
from ..data_aug import cifar_like_image_test_aug, cifar_like_image_train_aug
from ..ab_dataset import ABDataset
from ..dataset_split import train_val_split
from torchvision.datasets import SVHN as RawSVHN
import numpy as np
from typing import Dict, List, Optional
from torchvision import transforms
from torchvision.transforms import Compose
from utils.common.others import HiddenPrints
from ..registery import dataset_register
@dataset_register(
    name='SVHN',
    classes=[str(i) for i in range(10)],
    task_type='Image Classification',
    object_type='Digit and Letter',
    class_aliases=[],
    shift_type=None
)
class SVHN(ABDataset):
    """SVHN dataset wrapper, registered as 'SVHN' with the ten classes '0'..'9'."""

    def create_dataset(self, root_dir: str, split: str, transform: Optional[Compose],
                       classes: List[str], ignore_classes: List[str], idx_map: Optional[Dict[int, int]]):
        """Build the torchvision SVHN dataset for the requested split.

        Args:
            root_dir: Directory handed to torchvision's SVHN (downloaded there
                if missing, since ``download=True`` is passed).
            split: Split name. ``'test'`` loads the SVHN test set; any other
                value loads the SVHN train set and is then forwarded to
                ``train_val_split``.
            transform: Image transform. When ``None``, a default is built:
                random 32-pixel crop with padding 4 for ``'train'``, resize to
                32 otherwise, each followed by ``ToTensor`` and
                ``Normalize`` with mean/std of 0.5 per channel.
            classes: Class names; the position of a name is its label index.
            ignore_classes: Class names whose samples are removed.
            idx_map: Optional mapping from original label to new label,
                applied to every remaining label.

        Returns:
            The SVHN dataset for ``'test'``, or whatever ``train_val_split``
            returns for the other splits.
        """
        if transform is None:
            mean, std = [0.5] * 3, [0.5] * 3
            if split == 'train':
                steps = [
                    transforms.RandomCrop(32, padding=4),
                    transforms.ToTensor(),
                    transforms.Normalize(mean, std),
                ]
            else:
                steps = [
                    transforms.Resize(32),
                    transforms.ToTensor(),
                    transforms.Normalize(mean, std),
                ]
            transform = transforms.Compose(steps)
            # NOTE(review): the scraped source lost its indentation; storing the
            # default on self only in this branch is the assumed original
            # layout -- confirm against the repository.
            self.transform = transform

        # SVHN has no validation set of its own: everything except 'test' is
        # drawn from the train set and carved up further below.
        raw_split = 'test' if split == 'test' else 'train'
        # HiddenPrints presumably silences the download/verification output.
        with HiddenPrints():
            dataset = RawSVHN(root_dir, raw_split, transform=transform, download=True)

        # Drop the samples of every ignored class, one class at a time. The
        # mask is computed from the labels before either array is filtered so
        # images and labels stay aligned.
        for ignored_name in ignore_classes:
            keep_mask = dataset.labels != classes.index(ignored_name)
            dataset.data = dataset.data[keep_mask]
            dataset.labels = dataset.labels[keep_mask]

        if idx_map is not None:
            # Remap label by label. Replacing via successive boolean masks
            # (labels[labels == old] = new for each pair) is NOT equivalent:
            # a label already rewritten to `new` would be rewritten again if
            # `new` is itself a key that is processed later.
            for position, old_label in enumerate(dataset.labels):
                dataset.labels[position] = idx_map[old_label]

        if split != 'test':
            dataset = train_val_split(dataset, split)
        return dataset