import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

from data.facecaption_dataset import facecaption_train, facecaption_test
from data.randaugment import RandomAugment


def create_dataset(args, dataset, min_scale=0.5):
    """Build the train and eval datasets for the requested dataset name.

    Args:
        args: Object exposing ``img_root`` and ``ann_root`` attributes, which
            are forwarded to the dataset constructors.
        dataset: Name of the dataset to build. Only ``'facecaption'`` is
            supported.
        min_scale: Not used by this function; kept so existing callers that
            pass it keep working.

    Returns:
        A ``(train_dataset, eval_dataset)`` tuple.

    Raises:
        ValueError: If ``dataset`` is not a supported dataset name.
    """
    # Per-channel mean / std applied after ToTensor().
    normalize = transforms.Normalize(
        (0.48145466, 0.4578275, 0.40821073),
        (0.26862954, 0.26130258, 0.27577711),
    )

    # Training pipeline: fixed 224x224 bicubic resize, random flip and
    # RandomAugment, then tensor conversion and normalization.
    transform_train = transforms.Compose([
        transforms.Resize((224, 224), interpolation=InterpolationMode.BICUBIC),
        transforms.RandomHorizontalFlip(),
        # NOTE(review): the positional 2 and 5 are presumably the number of
        # ops and their magnitude -- confirm against data.randaugment.
        RandomAugment(
            2,
            5,
            isPIL=True,
            augs=[
                'Identity', 'Brightness', 'Sharpness', 'Equalize',
                'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate',
            ],
        ),
        transforms.ToTensor(),
        normalize,
    ])

    # Evaluation pipeline: same resize and normalization, no random ops.
    transform_test = transforms.Compose([
        transforms.Resize((224, 224), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        normalize,
    ])

    if dataset == 'facecaption':
        train_dataset = facecaption_train(
            transform_train, args.img_root, args.ann_root
        )
        eval_dataset = facecaption_test(
            transform_test, args.img_root, args.ann_root
        )
        return train_dataset, eval_dataset

    # Fail loudly on an unknown name instead of leaving the caller with an
    # unusable result and a confusing error further down the line.
    raise ValueError(
        f"Unsupported dataset {dataset!r}; expected 'facecaption'."
    )


def create_sampler(datasets, shuffles, num_tasks, global_rank):
    """Create one ``DistributedSampler`` per dataset.

    Args:
        datasets: Iterable of datasets.
        shuffles: Iterable of shuffle flags, paired positionally with
            ``datasets``.
        num_tasks: Passed to each sampler as ``num_replicas``.
        global_rank: Passed to each sampler as ``rank``.

    Returns:
        A list of samplers in the same order as ``datasets``. ``zip`` stops
        at the shorter of ``datasets`` and ``shuffles``.
    """
    return [
        torch.utils.data.DistributedSampler(
            ds,
            num_replicas=num_tasks,
            rank=global_rank,
            shuffle=do_shuffle,
        )
        for ds, do_shuffle in zip(datasets, shuffles)
    ]


def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
    """Create one ``DataLoader`` per dataset.

    All arguments are parallel iterables; entry *i* of each one configures
    loader *i*. ``zip`` stops at the shortest of them.

    Args:
        datasets: Datasets to wrap.
        samplers: Sampler for each dataset, or ``None`` for no sampler.
        batch_size: Batch size for each loader.
        num_workers: Worker count for each loader.
        is_trains: Truthy for a training loader, falsy for an eval loader.
        collate_fns: ``collate_fn`` for each loader (may be ``None``).

    Returns:
        A list of ``DataLoader`` objects in the same order as ``datasets``.
    """
    loaders = []
    per_loader_cfg = zip(
        datasets, samplers, batch_size, num_workers, is_trains, collate_fns
    )
    for ds, sampler, bs, n_worker, is_train, collate_fn in per_loader_cfg:
        if is_train:
            # Shuffle in the loader only when no sampler is supplied, so
            # `shuffle` and `sampler` are never both set.
            shuffle = sampler is None
            drop_last = True
        else:
            shuffle = False
            drop_last = False

        loaders.append(
            DataLoader(
                ds,
                batch_size=bs,
                num_workers=n_worker,
                pin_memory=True,
                sampler=sampler,
                shuffle=shuffle,
                collate_fn=collate_fn,
                drop_last=drop_last,
            )
        )
    return loaders