import os
from glob import glob

import cv2
import torch
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset

from datasets import register


@register('ucf101')
class UCF101(Dataset):
    """UCF101 frame-triplet dataset registered under the name ``'ucf101'``.

    Every sub-directory of ``root_path`` is one sample containing three
    frames: ``im1.png`` (first frame), ``im2.png`` (target frame between
    them) and ``im3.png`` (last frame).
    """

    # Frame file names inside each triplet directory, in temporal order.
    FRAME_NAMES = ('im1.png', 'im2.png', 'im3.png')

    def __init__(self, root_path, **kwargs):
        """Index all triplet directories under ``root_path``.

        Args:
            root_path: directory whose sub-directories are the triplets.
            **kwargs: accepted (presumably for a uniform registry constructor
                interface) but unused by this dataset.
        """
        self.data_root = root_path
        self.load_data()

    def __len__(self):
        """Return the number of triplet directories found."""
        return len(self.meta_data)

    def load_data(self):
        """Populate ``self.meta_data`` with the triplet directory paths.

        ``glob`` returns entries in arbitrary, filesystem-dependent order, so
        the list is sorted to make index -> sample mapping reproducible.
        Plain files in the root are skipped; only directories are samples.
        """
        candidates = glob(os.path.join(self.data_root, "*"))
        self.meta_data = sorted(p for p in candidates if os.path.isdir(p))

    @staticmethod
    def _read_rgb(path):
        """Read the image at ``path`` and return it with channels in RGB order.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file.
        """
        img = cv2.imread(path)
        if img is None:
            # cv2.imread reports failure by returning None rather than raising.
            raise FileNotFoundError(f"Could not read image: {path}")
        # OpenCV loads channels as BGR; reverse the last axis to get RGB.
        return img[:, :, ::-1]

    def get_img(self, index):
        """Load the three frames of sample ``index``.

        Returns:
            ``(img0, imgt, img1)``: first, target and last frame as RGB
            numpy arrays (channel-reversed views of the decoded images).

        Raises:
            FileNotFoundError: if any of the three frames cannot be read.
        """
        triplet_dir = self.meta_data[index]
        img0, imgt, img1 = (
            self._read_rgb(os.path.join(triplet_dir, name))
            for name in self.FRAME_NAMES
        )
        return img0, imgt, img1

    def __getitem__(self, index):
        """Return sample ``index`` as a dict of tensors.

        Keys:
            ``img0``, ``imgt``, ``img1``: frames converted with
                ``TF.to_tensor`` (CHW float tensors; scaled to [0, 1] for
                8-bit images).
            ``time_step``: tensor of shape (1, 1, 1) holding 0.5.
            ``scene_name``: path of the triplet directory.
        """
        img0, imgt, img1 = self.get_img(index)

        # The RGB flip produces negative-stride views, which cannot be wrapped
        # by torch.from_numpy inside to_tensor; .copy() makes them contiguous.
        img0 = TF.to_tensor(img0.copy())
        img1 = TF.to_tensor(img1.copy())
        imgt = TF.to_tensor(imgt.copy())

        # The target frame is treated as lying halfway between img0 and img1.
        time_step = torch.Tensor([0.5]).reshape(1, 1, 1)

        return {
            'img0': img0,
            'imgt': imgt,
            'img1': img1,
            'time_step': time_step,
            'scene_name': self.meta_data[index],
        }