|
import os |
|
import cv2 |
|
from glob import glob |
|
|
|
from datasets import register |
|
|
|
import torch |
|
import torchvision.transforms.functional as TF |
|
from torch.utils.data import Dataset |
|
|
|
|
|
@register('ucf101')
class UCF101(Dataset):
    """UCF101 triplet dataset.

    Every sub-directory of ``root_path`` is one sample holding three frames:
    ``im1.png``, ``im2.png`` and ``im3.png``. They are returned as ``img0``,
    ``imgt`` and ``img1`` together with a fixed ``time_step`` of 0.5.
    """

    # File names of the (img0, imgt, img1) frames inside each sample directory.
    _FRAME_NAMES = ('im1.png', 'im2.png', 'im3.png')

    def __init__(self, root_path, **kwargs):
        """Index all sample directories under ``root_path``.

        ``**kwargs`` is accepted but unused (presumably so a shared
        registry/config can pass extra options -- confirm against callers).
        """
        self.data_root = root_path
        self.load_data()

    def __len__(self):
        """Return the number of sample directories found."""
        return len(self.meta_data)

    def load_data(self):
        """Populate ``self.meta_data`` with the sample directory paths.

        ``glob`` yields entries in arbitrary, filesystem-dependent order.
        The list is therefore sorted so that a given index always refers to
        the same sample. Non-directory entries are skipped because each
        sample must be a directory containing the three frame files.
        """
        candidates = glob(os.path.join(self.data_root, "*"))
        self.meta_data = sorted(p for p in candidates if os.path.isdir(p))

    def get_img(self, index):
        """Load the three frames of sample ``index`` in RGB channel order.

        Returns:
            ``(img0, imgt, img1)``, read from ``im1.png``, ``im2.png`` and
            ``im3.png`` respectively. Each is flipped from OpenCV's BGR
            order to RGB via a reversed-channel view (not a copy).

        Raises:
            FileNotFoundError: if any frame is missing or cannot be decoded.
        """
        scene_dir = self.meta_data[index]
        frames = []
        for name in self._FRAME_NAMES:
            path = os.path.join(scene_dir, name)
            img = cv2.imread(path)
            # cv2.imread reports failure by returning None rather than raising;
            # without this check the slice below fails with an opaque TypeError.
            if img is None:
                raise FileNotFoundError(f"Could not read image: {path}")
            frames.append(img[:, :, ::-1])  # BGR -> RGB
        img0, imgt, img1 = frames
        return img0, imgt, img1

    def __getitem__(self, index):
        """Return sample ``index`` as a dict.

        Keys:
            ``img0``, ``imgt``, ``img1``: frames converted with
            ``TF.to_tensor``.
            ``time_step``: tensor of shape ``(1, 1, 1)`` holding 0.5.
            ``scene_name``: the sample's directory path.
        """
        img0, imgt, img1 = self.get_img(index)
        # .copy() is required: the BGR->RGB flip in get_img is a view with a
        # negative stride, which torch.from_numpy (used by to_tensor) rejects.
        img0 = TF.to_tensor(img0.copy())
        img1 = TF.to_tensor(img1.copy())
        imgt = TF.to_tensor(imgt.copy())
        time_step = torch.Tensor([0.5]).reshape(1, 1, 1)
        return {
            'img0': img0,
            'imgt': imgt,
            'img1': img1,
            'time_step': time_step,
            'scene_name': self.meta_data[index],
        }
|
|
|
|