# Source: VfiTest/datasets/ucf101.py
# Uploaded by SuyeonJ using huggingface_hub (commit 8d015d4, verified).
import os
import cv2
from glob import glob
from datasets import register
import torch
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset
@register('ucf101')
class UCF101(Dataset):
    """Dataset of frame triplets stored one triplet per sub-directory.

    Each sub-directory of ``root_path`` is expected to contain ``im1.png``,
    ``im2.png`` and ``im3.png``. ``im1``/``im3`` are returned as the input
    frames (``img0``/``img1``) and ``im2`` as the target frame (``imgt``),
    together with a fixed time step of 0.5.
    """

    # File names inside each triplet directory, in (img0, imgt, img1) order.
    _FRAME_NAMES = ('im1.png', 'im2.png', 'im3.png')

    def __init__(self, root_path, **kwargs):
        """
        Args:
            root_path: directory whose immediate sub-directories are triplets.
            **kwargs: accepted and ignored.
        """
        self.data_root = root_path
        self.load_data()

    def __len__(self):
        """Return the number of triplet directories found."""
        return len(self.meta_data)

    def load_data(self):
        """Collect the triplet directories under ``self.data_root``.

        ``glob`` returns entries in arbitrary, filesystem-dependent order, so
        they are sorted to make sample indices deterministic. Plain files are
        skipped because only a directory can hold a triplet.
        """
        candidates = glob(os.path.join(self.data_root, "*"))
        self.meta_data = sorted(p for p in candidates if os.path.isdir(p))

    @staticmethod
    def _read_rgb(path):
        """Read ``path`` with OpenCV and return it with channels in RGB order.

        Raises:
            OSError: if OpenCV cannot read the file (missing or unreadable).
        """
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f"Failed to read image: {path}")
        # OpenCV loads colour images as BGR; reverse the last axis for RGB.
        return img[:, :, ::-1]

    def get_img(self, index):
        """Load the three frames of triplet ``index``.

        Returns:
            Tuple ``(img0, imgt, img1)`` of RGB numpy arrays. They are
            channel-reversed views (negative stride), so callers that hand
            them to torch need to copy them first.
        """
        img_dir = self.meta_data[index]
        img0, imgt, img1 = (self._read_rgb(os.path.join(img_dir, name))
                            for name in self._FRAME_NAMES)
        return img0, imgt, img1

    def __getitem__(self, index):
        """Return one sample as a dict.

        Keys: ``img0``, ``imgt``, ``img1`` (tensors produced by
        ``TF.to_tensor``), ``time_step`` (a tensor of shape (1, 1, 1)
        holding 0.5) and ``scene_name`` (the triplet directory path).
        """
        img0, imgt, img1 = self.get_img(index)
        # .copy() makes the channel-reversed views contiguous; torch cannot
        # wrap numpy arrays that have negative strides.
        img0 = TF.to_tensor(img0.copy())
        img1 = TF.to_tensor(img1.copy())
        imgt = TF.to_tensor(imgt.copy())
        # Fixed time step: imgt is treated as the midpoint between img0/img1.
        time_step = torch.Tensor([0.5]).reshape(1, 1, 1)
        return {
            'img0': img0, 'imgt': imgt, 'img1': img1, 'time_step': time_step, 'scene_name': self.meta_data[index]
        }