File size: 1,471 Bytes
3352589
 
 
 
 
 
 
 
 
 
373be07
3352589
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
373be07
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from torchvision import datasets
from torch.utils.data import DataLoader
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2

def get_train_transform():
    """Build the augmentation pipeline used for training images.

    The stages run in this order:

    1. random resized crop to 224x224 (area scale 0.08-1.0, aspect
       ratio 3/4-4/3), always applied;
    2. horizontal flip with probability 0.5;
    3. colour jitter (brightness/contrast/saturation 0.2, hue 0.05)
       with probability 0.8;
    4. per-channel normalisation with the fixed mean/std given below;
    5. conversion to a tensor via ``ToTensorV2``.

    Returns:
        An ``albumentations.Compose`` object; call it as
        ``pipeline(image=array)["image"]``.
    """
    random_crop = A.RandomResizedCrop(
        height=224,
        width=224,
        scale=(0.08, 1.0),
        ratio=(3 / 4, 4 / 3),
        p=1.0,
    )
    flip = A.HorizontalFlip(p=0.5)
    jitter = A.ColorJitter(
        brightness=0.2,
        contrast=0.2,
        saturation=0.2,
        hue=0.05,
        p=0.8,
    )
    normalize = A.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    )
    to_tensor = ToTensorV2()

    return A.Compose([random_crop, flip, jitter, normalize, to_tensor])

def get_test_transform():
    """Build the deterministic preprocessing pipeline for evaluation images.

    Every stage is applied unconditionally: resize to 256x256, centre
    crop to 224x224, per-channel normalisation with the fixed mean/std
    given below, then conversion to a tensor via ``ToTensorV2``.

    Returns:
        An ``albumentations.Compose`` object; call it as
        ``pipeline(image=array)["image"]``.
    """
    stages = [
        A.Resize(height=256, width=256),
        A.CenterCrop(height=224, width=224),
        A.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
        ToTensorV2(),
    ]
    return A.Compose(stages)

class _AlbumentationsAdapter:
    """Adapt an albumentations pipeline to torchvision's ``transform`` hook.

    ``ImageFolder`` calls ``transform(img)`` with a single positional
    argument, whereas an albumentations pipeline is called with
    ``image=<ndarray>`` and returns a dict.  This is a module-level class
    rather than a lambda so that the dataset holding it can be pickled,
    which ``DataLoader`` worker processes require when the multiprocessing
    start method is ``spawn`` or ``forkserver``.
    """

    def __init__(self, pipeline):
        self.pipeline = pipeline

    def __call__(self, img):
        # Convert the loaded image to an ndarray and unwrap the result dict.
        return self.pipeline(image=np.array(img))['image']


def get_data_loaders(
    train_transform,
    test_transform,
    batch_size_train=128,
    batch_size_test=500,
    *,
    train_root='/mnt/imagenet/ILSVRC/Data/CLS-LOC/train',
    test_root='/mnt/imagenet/ILSVRC/Data/CLS-LOC/val',
    num_workers=8,
):
    """Create the training and evaluation ``DataLoader`` objects.

    Both datasets are read with ``torchvision.datasets.ImageFolder``; each
    loaded image is converted to a NumPy array and passed through the
    corresponding albumentations pipeline.

    Args:
        train_transform: Albumentations pipeline applied to training images
            (called as ``train_transform(image=array)["image"]``).
        test_transform: Albumentations pipeline applied to evaluation images.
        batch_size_train: Batch size of the training loader.
        batch_size_test: Batch size of the evaluation loader.
        train_root: Root directory of the training ``ImageFolder``.
        test_root: Root directory of the evaluation ``ImageFolder``.
        num_workers: Number of worker processes used by each loader.

    Returns:
        A ``(trainloader, testloader)`` tuple.  The training loader shuffles
        its samples; the evaluation loader does not.  Both use pinned memory.
    """
    trainset = datasets.ImageFolder(
        root=train_root,
        transform=_AlbumentationsAdapter(train_transform),
    )
    trainloader = DataLoader(
        trainset,
        batch_size=batch_size_train,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )

    testset = datasets.ImageFolder(
        root=test_root,
        transform=_AlbumentationsAdapter(test_transform),
    )
    testloader = DataLoader(
        testset,
        batch_size=batch_size_test,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )

    return trainloader, testloader