File size: 1,323 Bytes
2efd69c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from torchvision import datasets
import albumentations as A
from albumentations.pytorch import ToTensorV2


# Per-channel mean / standard deviation handed to ``A.Normalize`` below.
# The values are below 1, consistent with statistics expressed on a [0, 1]
# pixel scale; presumably computed over the CIFAR-10 training split —
# confirm before reusing them for a different dataset.
NORM_DATA_MEAN = (0.49139968, 0.48215841, 0.44653091)
NORM_DATA_STD = (0.24703223, 0.24348513, 0.26158784)

# Names of the ten CIFAR-10 classes; presumably entry i is the name of
# integer target i — verify against the dataset's own class ordering.
CIFAR_CLASS_LABELS = (
    'airplane automobile bird cat deer '
    'dog frog horse ship truck'
).split()

# Training-time augmentation pipeline.
#
# Normalize runs FIRST, so every later step sees a float image that has
# already been mean-subtracted and divided by the std. Any constant written
# into the image afterwards must therefore be expressed in normalized units.
TRAIN_TRANSFORM = A.Compose([
    A.Normalize(
        mean=NORM_DATA_MEAN,
        std=NORM_DATA_STD,
    ),
    # Random left-right flip (albumentations' default probability).
    A.HorizontalFlip(),
    # Always-applied sub-pipeline: pad to 40x40, cut out one 16x16 hole,
    # then randomly crop back to 32x32 (so the hole may be partly or fully
    # cropped away, and the crop acts as a random translation).
    A.Compose([
        A.PadIfNeeded(min_height=40, min_width=40, p=1.0),
        # After Normalize the per-channel dataset mean maps to 0.0, so 0 is
        # the "mean colour" fill. Passing NORM_DATA_MEAN here (raw [0, 1]
        # means) would instead paint a value roughly half a std above the
        # mean into the hole.
        A.CoarseDropout(max_holes=1, max_height=16, max_width=16,
                        min_holes=1, min_height=16, min_width=16,
                        fill_value=0, mask_fill_value=None, p=1.0),
        A.RandomCrop(p=1.0, height=32, width=32),
    ]),
    ToTensorV2(),
])

# Evaluation-time pipeline: no augmentation at all — only the same
# normalization used for training, then conversion to a tensor.
TEST_TRANSFORM = A.Compose(
    [
        A.Normalize(mean=NORM_DATA_MEAN, std=NORM_DATA_STD),
        ToTensorV2(),
    ]
)

class CifarAlbumentationsDataset(datasets.CIFAR10):
    """CIFAR-10 dataset whose ``transform`` is an albumentations pipeline.

    Only ``__getitem__`` is overridden: it reads the raw sample from
    ``self.data`` and calls ``self.transform`` with the albumentations
    keyword convention ``transform(image=...)``, taking the ``'image'``
    entry of the returned mapping. Construction is inherited unchanged.
    """

    def __getitem__(self, idx):
        """Return ``(image, target)`` for sample *idx*.

        Args:
            idx: Index into ``self.data`` / ``self.targets``.

        Returns:
            ``image`` is the transformed image when ``self.transform`` is
            set, otherwise the untouched entry of ``self.data``. ``target``
            is passed through ``self.target_transform`` when one is set.
        """
        image, target = self.data[idx], self.targets[idx]
        # `is not None` rather than truthiness: a pipeline object with no
        # steps may define __len__ and evaluate falsy, yet was still
        # explicitly configured. Without a transform the raw image is
        # returned instead of hitting an unbound local.
        if self.transform is not None:
            image = self.transform(image=image)['image']
        # Honour the base dataset's target_transform instead of dropping it.
        if self.target_transform is not None:
            target = self.target_transform(target)
        return image, target