File size: 1,355 Bytes
8d015d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import os
import cv2
from glob import glob

from datasets import register

import torch
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset


@register('ucf101')
class UCF101(Dataset):
    """Frame-triplet dataset registered under the name ``'ucf101'``.

    Every directory directly under ``root_path`` is treated as one sample and
    is expected to contain ``im1.png``, ``im2.png`` and ``im3.png``. The first
    and last images are returned as ``img0`` / ``img1`` and the middle one as
    ``imgt``, together with a fixed ``time_step`` of 0.5.
    """

    _FRAME_NAMES = ('im1.png', 'im2.png', 'im3.png')

    def __init__(self, root_path, **kwargs):
        """Index the sample directories found under ``root_path``.

        Args:
            root_path: Directory whose immediate sub-directories each hold
                one image triplet.
            **kwargs: Accepted and ignored, so extra configuration keys do
                not break construction.
        """
        super().__init__()
        self.data_root = root_path
        self.load_data()

    def __len__(self):
        """Return the number of indexed sample directories."""
        return len(self.meta_data)

    def load_data(self):
        """Populate ``self.meta_data`` with the sorted sample directories."""
        entries = glob(os.path.join(self.data_root, "*"))
        # glob() yields entries in a filesystem-dependent order; sorting makes
        # the index -> sample mapping reproducible. Stray regular files in the
        # root are skipped because they cannot contain a triplet.
        self.meta_data = sorted(p for p in entries if os.path.isdir(p))

    @staticmethod
    def _read_rgb(path):
        """Read ``path`` with OpenCV and return it with channels reversed.

        Raises:
            OSError: If OpenCV cannot read the file (it is missing or cannot
                be decoded), in which case ``cv2.imread`` returns ``None``.
        """
        img = cv2.imread(path)
        if img is None:
            raise OSError(f"cv2.imread could not read image: {path}")
        # OpenCV decodes colour images as BGR; flipping the last axis gives
        # RGB. The result is a negatively-strided view, not a copy.
        return img[:, :, ::-1]

    def get_img(self, index):
        """Load the three frames of sample ``index``.

        Args:
            index: Position of the sample in ``self.meta_data``.

        Returns:
            Tuple ``(img0, imgt, img1)`` of arrays read from ``im1.png``,
            ``im2.png`` and ``im3.png`` respectively.

        Raises:
            OSError: If any of the three images cannot be read.
        """
        sample_dir = self.meta_data[index]
        img0, imgt, img1 = (
            self._read_rgb(os.path.join(sample_dir, name))
            for name in self._FRAME_NAMES
        )
        return img0, imgt, img1

    def __getitem__(self, index):
        """Return sample ``index`` as a dict of tensors.

        Returns:
            Dict with keys ``'img0'``, ``'imgt'``, ``'img1'`` (image tensors
            produced by ``TF.to_tensor``), ``'time_step'`` (a tensor of shape
            ``(1, 1, 1)`` holding 0.5) and ``'scene_name'`` (the sample
            directory path).
        """
        img0, imgt, img1 = self.get_img(index)
        # .copy() materialises the reversed-channel views as contiguous
        # arrays before they are handed to to_tensor.
        img0 = TF.to_tensor(img0.copy())
        img1 = TF.to_tensor(img1.copy())
        imgt = TF.to_tensor(imgt.copy())
        time_step = torch.Tensor([0.5]).reshape(1, 1, 1)
        return {
            'img0': img0,
            'imgt': imgt,
            'img1': img1,
            'time_step': time_step,
            'scene_name': self.meta_data[index],
        }