|
import torch |
|
from torch.utils.data import DataLoader |
|
from torchvision import transforms |
|
from torchvision.transforms.functional import InterpolationMode |
|
from data.facecaption_dataset import facecaption_train, facecaption_test |
|
from data.randaugment import RandomAugment |
|
|
|
def create_dataset(args, dataset, min_scale=0.5):
    """Build the training and evaluation datasets for ``dataset``.

    Args:
        args: Namespace-like object; ``args.img_root`` and ``args.ann_root``
            are forwarded to the dataset constructors.
        dataset: Name of the dataset to build. Only ``'facecaption'`` is
            currently supported.
        min_scale: Currently unused by this function; kept in the signature
            so existing callers that pass it keep working.

    Returns:
        A ``(train_dataset, eval_dataset)`` tuple.

    Raises:
        ValueError: If ``dataset`` is not a supported dataset name.
    """
    supported = ('facecaption',)
    # Fail fast with a clear message instead of falling through and
    # implicitly returning None, which would only surface later as an
    # unpacking TypeError at the call site.
    if dataset not in supported:
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected one of {supported}"
        )

    # Per-channel mean/std normalization (values appear to be the CLIP image
    # statistics -- confirm against the model being trained).
    normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))

    # Training pipeline: fixed-size bicubic resize, random horizontal flip,
    # then RandomAugment (2 ops at magnitude 5 from the listed set) on the
    # PIL image before tensor conversion.
    transform_train = transforms.Compose([
        transforms.Resize((224, 224), interpolation=InterpolationMode.BICUBIC),
        transforms.RandomHorizontalFlip(),
        RandomAugment(2, 5, isPIL=True, augs=['Identity', 'Brightness', 'Sharpness', 'Equalize',
                                              'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
        transforms.ToTensor(),
        normalize,
    ])

    # Evaluation pipeline: deterministic resize + normalization only.
    transform_test = transforms.Compose([
        transforms.Resize((224, 224), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = facecaption_train(transform_train, args.img_root, args.ann_root)
    eval_dataset = facecaption_test(transform_test, args.img_root, args.ann_root)
    return train_dataset, eval_dataset
|
|
|
|
|
|
|
def create_sampler(datasets, shuffles, num_tasks, global_rank):
    """Wrap each dataset in a ``DistributedSampler`` for the current process.

    Args:
        datasets: Iterable of datasets; one sampler is built per entry.
        shuffles: Iterable of flags paired positionally with ``datasets``;
            each one is passed as that sampler's ``shuffle`` argument.
        num_tasks: Total number of replicas (forwarded as ``num_replicas``).
        global_rank: Rank of this process (forwarded as ``rank``).

    Returns:
        A list of samplers in the same order as ``datasets``. Like ``zip``,
        pairing stops at the shorter of ``datasets`` and ``shuffles``.
    """
    make_sampler = torch.utils.data.DistributedSampler
    return [
        make_sampler(ds, num_replicas=num_tasks, rank=global_rank, shuffle=do_shuffle)
        for ds, do_shuffle in zip(datasets, shuffles)
    ]
|
|
|
|
|
def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
    """Build one ``DataLoader`` per dataset from positionally paired settings.

    Every argument is a parallel sequence; ``batch_size`` and ``num_workers``
    hold one value per loader, not a single scalar. Loaders are built for as
    many positions as the shortest sequence provides.

    Training loaders drop the trailing incomplete batch and shuffle only when
    no sampler is supplied, because ``DataLoader`` does not accept
    ``shuffle=True`` together with an explicit sampler. Evaluation loaders
    never shuffle and keep every sample.

    Returns:
        A list of ``DataLoader`` objects in input order.
    """
    settings = zip(datasets, samplers, batch_size, num_workers, is_trains, collate_fns)
    loaders = []
    for ds, smp, bs, workers, training, collate in settings:
        training = bool(training)
        loaders.append(
            DataLoader(
                ds,
                batch_size=bs,
                num_workers=workers,
                pin_memory=True,
                sampler=smp,
                # When a sampler is given, it owns the ordering.
                shuffle=training and smp is None,
                collate_fn=collate,
                drop_last=training,
            )
        )
    return loaders
|
|
|
|