# ResNet50_replicate / data_utils.py
# Last commit 373be07 (author: ubuntu): "Added one cycle lr and lr_finder and reduced jitter"
from torchvision import datasets
from torch.utils.data import DataLoader
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
def get_train_transform():
    """Return the albumentations augmentation pipeline for training images.

    The pipeline is invoked as ``pipeline(image=arr)["image"]``. ``arr`` is the
    ``np.array`` of a PIL image, as produced by ``get_data_loaders``. The
    pipeline performs the following steps in order:

    1. Random-resized crop to 224x224. The crop covers 8%-100% of the image
       area, with aspect ratio in [3/4, 4/3].
    2. Horizontal flip with probability 0.5.
    3. Colour jitter with probability 0.8.
    4. Normalization with the ImageNet per-channel mean/std.
    5. Conversion to a tensor via ``ToTensorV2``.
    """
    imagenet_mean = (0.485, 0.456, 0.406)
    imagenet_std = (0.229, 0.224, 0.225)

    steps = [
        A.RandomResizedCrop(
            height=224,
            width=224,
            scale=(0.08, 1.0),
            ratio=(3 / 4, 4 / 3),
            p=1.0,
        ),
        A.HorizontalFlip(p=0.5),
        # Mild photometric augmentation, applied to 80% of the samples.
        A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05, p=0.8),
        A.Normalize(mean=imagenet_mean, std=imagenet_std),
        ToTensorV2(),
    ]
    return A.Compose(steps)
def get_test_transform(resize_size=256, crop_size=224):
    """Return the deterministic albumentations pipeline for evaluation images.

    The pipeline follows the standard ImageNet evaluation protocol:

    1. Rescale so the *shorter* side equals ``resize_size``, keeping the
       aspect ratio.
    2. Take the central ``crop_size`` x ``crop_size`` square.
    3. Normalize with the ImageNet per-channel mean/std.
    4. Convert to a tensor via ``ToTensorV2``.

    Args:
        resize_size: Target length of the shorter image side before cropping.
            It should be >= ``crop_size``.
        crop_size: Side length of the square centre crop.

    Returns:
        An ``A.Compose`` pipeline, invoked as ``pipeline(image=arr)["image"]``.
    """
    return A.Compose([
        # SmallestMaxSize rescales only the shorter side (aspect ratio kept),
        # matching torchvision's ``Resize(256)``. A fixed ``Resize(256, 256)``
        # would squash non-square images before the centre crop.
        A.SmallestMaxSize(max_size=resize_size),
        A.CenterCrop(height=crop_size, width=crop_size),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ])
class AlbumentationsTransform:
    """Adapt an albumentations pipeline to torchvision's ``transform`` callable.

    A torchvision dataset calls ``transform(img)`` with a single image. An
    albumentations pipeline is instead called as ``pipeline(image=arr)`` and
    returns a dict. This adapter bridges the two.

    It is a module-level class rather than a lambda so that it can be pickled.
    ``DataLoader`` needs that when worker processes are started with the
    ``spawn``/``forkserver`` methods.
    """

    def __init__(self, pipeline):
        # Albumentations pipeline (or any callable accepting ``image=``).
        self.pipeline = pipeline

    def __call__(self, img):
        return self.pipeline(image=np.array(img))["image"]


def get_data_loaders(
    train_transform,
    test_transform,
    batch_size_train=128,
    batch_size_test=500,
    train_root="/mnt/imagenet/ILSVRC/Data/CLS-LOC/train",
    val_root="/mnt/imagenet/ILSVRC/Data/CLS-LOC/val",
    num_workers=8,
):
    """Build the ImageNet training and validation DataLoaders.

    Both splits are read with ``torchvision.datasets.ImageFolder``. ImageFolder
    expects one sub-directory per class and derives label indices from the
    sorted sub-directory names. ``train_root`` and ``val_root`` must therefore
    contain the same set of class folders, or the labels will not line up.

    NOTE(review): the raw ILSVRC ``CLS-LOC/val`` split is distributed as a
    flat directory. This function assumes it has already been reorganized into
    per-class folders; confirm before use.

    Args:
        train_transform: Albumentations pipeline applied to training images
            (e.g. ``get_train_transform()``).
        test_transform: Albumentations pipeline applied to validation images.
        batch_size_train: Batch size of the shuffled training loader.
        batch_size_test: Batch size of the unshuffled validation loader.
        train_root: Directory containing the training split.
        val_root: Directory containing the validation split.
        num_workers: Number of worker processes used by each loader.

    Returns:
        A ``(trainloader, testloader)`` tuple.
    """
    trainset = datasets.ImageFolder(
        root=train_root, transform=AlbumentationsTransform(train_transform)
    )
    trainloader = DataLoader(
        trainset,
        batch_size=batch_size_train,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )

    testset = datasets.ImageFolder(
        root=val_root, transform=AlbumentationsTransform(test_transform)
    )
    testloader = DataLoader(
        testset,
        batch_size=batch_size_test,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )
    return trainloader, testloader