File size: 2,280 Bytes
b84549f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from ..data_aug import cifar_like_image_test_aug, cifar_like_image_train_aug
from ..ab_dataset import ABDataset
from ..dataset_split import train_val_split
from torchvision.datasets import SVHN as RawSVHN
import numpy as np
from typing import Dict, List, Optional
from torchvision import transforms
from torchvision.transforms import Compose
from utils.common.others import HiddenPrints

from ..registery import dataset_register


@dataset_register(
    name='SVHN', 
    classes=[str(i) for i in range(10)], 
    task_type='Image Classification',
    object_type='Digit and Letter',
    class_aliases=[],
    shift_type=None
)
class SVHN(ABDataset):
    """SVHN dataset registered under the name ``'SVHN'``.

    Wraps ``torchvision.datasets.SVHN``; the registered classes are the digit
    strings ``'0'`` through ``'9'``.
    """

    def create_dataset(self, root_dir: str, split: str, transform: Optional[Compose], 
                       classes: List[str], ignore_classes: List[str], idx_map: Optional[Dict[int, int]]):
        """Build the SVHN dataset for ``split``.

        Args:
            root_dir: Directory handed to ``torchvision.datasets.SVHN``; the data
                is downloaded there if missing (``download=True``).
            split: ``'test'`` loads the raw SVHN test split. Any other value loads
                the raw train split and passes it through
                ``train_val_split(dataset, split)``.
            transform: Image transform. When ``None``, a default pipeline is
                built (``RandomCrop(32, padding=4)`` for ``split == 'train'``,
                ``Resize(32)`` otherwise, then ``ToTensor`` and per-channel
                ``Normalize`` with mean/std 0.5) and stored on ``self.transform``.
            classes: Class names; ``classes.index(name)`` is a class's label id.
            ignore_classes: Names of classes whose samples are dropped.
            idx_map: Optional mapping from original label id to new label id,
                applied after filtering. Every remaining label must be a key.

        Returns:
            The loaded dataset, passed through ``train_val_split`` unless
            ``split == 'test'``.

        Raises:
            ValueError: If an entry of ``ignore_classes`` is not in ``classes``.
            KeyError: If a remaining label has no entry in ``idx_map``.
        """
        if transform is None:
            mean, std = [0.5] * 3, [0.5] * 3
            # Random-crop augmentation only for training; every other split gets
            # a deterministic pipeline.
            first_step = transforms.RandomCrop(32, padding=4) if split == 'train' else transforms.Resize(32)
            transform = transforms.Compose([
                first_step,
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ])
            self.transform = transform

        # Both 'train' and 'val' are carved out of the raw SVHN train split.
        raw_split = 'test' if split == 'test' else 'train'
        with HiddenPrints():
            dataset = RawSVHN(root_dir, raw_split, transform=transform, download=True)

        if ignore_classes:
            # One mask over all ignored ids instead of two masks per class.
            ignored_ids = [classes.index(name) for name in ignore_classes]
            keep = ~np.isin(dataset.labels, ignored_ids)
            dataset.data = dataset.data[keep]
            dataset.labels = dataset.labels[keep]

        if idx_map is not None:
            # Look up every label once into a fresh array. Replacing values
            # class-by-class (labels[labels == old] = new) is wrong when a new
            # id is also an old id: already-remapped labels get remapped again.
            dataset.labels = np.array(
                [idx_map[int(label)] for label in dataset.labels],
                dtype=dataset.labels.dtype,
            )

        if split != 'test':
            dataset = train_val_split(dataset, split)
        return dataset