''' This code is partially borrowed from IFRNet (https://github.com/ltkong218/IFRNet). ''' import os import cv2 import torch import random import numpy as np from torch.utils.data import Dataset from utils.utils import read def random_resize(img0, imgt, img1, flow, p=0.1): if random.uniform(0, 1) < p: img0 = cv2.resize(img0, dsize=None, fx=2.0, fy=2.0, interpolation=cv2.INTER_LINEAR) imgt = cv2.resize(imgt, dsize=None, fx=2.0, fy=2.0, interpolation=cv2.INTER_LINEAR) img1 = cv2.resize(img1, dsize=None, fx=2.0, fy=2.0, interpolation=cv2.INTER_LINEAR) flow = cv2.resize(flow, dsize=None, fx=2.0, fy=2.0, interpolation=cv2.INTER_LINEAR) * 2.0 return img0, imgt, img1, flow def random_crop(img0, imgt, img1, flow, crop_size=(224, 224)): h, w = crop_size[0], crop_size[1] ih, iw, _ = img0.shape x = np.random.randint(0, ih-h+1) y = np.random.randint(0, iw-w+1) img0 = img0[x:x+h, y:y+w, :] imgt = imgt[x:x+h, y:y+w, :] img1 = img1[x:x+h, y:y+w, :] flow = flow[x:x+h, y:y+w, :] return img0, imgt, img1, flow def random_reverse_channel(img0, imgt, img1, flow, p=0.5): if random.uniform(0, 1) < p: img0 = img0[:, :, ::-1] imgt = imgt[:, :, ::-1] img1 = img1[:, :, ::-1] return img0, imgt, img1, flow def random_vertical_flip(img0, imgt, img1, flow, p=0.3): if random.uniform(0, 1) < p: img0 = img0[::-1] imgt = imgt[::-1] img1 = img1[::-1] flow = flow[::-1] flow = np.concatenate((flow[:, :, 0:1], -flow[:, :, 1:2], flow[:, :, 2:3], -flow[:, :, 3:4]), 2) return img0, imgt, img1, flow def random_horizontal_flip(img0, imgt, img1, flow, p=0.5): if random.uniform(0, 1) < p: img0 = img0[:, ::-1] imgt = imgt[:, ::-1] img1 = img1[:, ::-1] flow = flow[:, ::-1] flow = np.concatenate((-flow[:, :, 0:1], flow[:, :, 1:2], -flow[:, :, 2:3], flow[:, :, 3:4]), 2) return img0, imgt, img1, flow def random_rotate(img0, imgt, img1, flow, p=0.05): if random.uniform(0, 1) < p: img0 = img0.transpose((1, 0, 2)) imgt = imgt.transpose((1, 0, 2)) img1 = img1.transpose((1, 0, 2)) flow = flow.transpose((1, 0, 2)) flow = np.concatenate((flow[:, :, 1:2], flow[:, :, 0:1], flow[:, :, 3:4], flow[:, :, 2:3]), 2) return img0, imgt, img1, flow def random_reverse_time(img0, imgt, img1, flow, p=0.5): if random.uniform(0, 1) < p: tmp = img1 img1 = img0 img0 = tmp flow = np.concatenate((flow[:, :, 2:4], flow[:, :, 0:2]), 2) return img0, imgt, img1, flow class Vimeo90K_Train_Dataset(Dataset): def __init__(self, dataset_dir='data/vimeo_triplet', flow_dir=None, augment=True, crop_size=(224, 224)): self.dataset_dir = dataset_dir self.augment = augment self.crop_size = crop_size self.img0_list = [] self.imgt_list = [] self.img1_list = [] self.flow_t0_list = [] self.flow_t1_list = [] if flow_dir is None: flow_dir = 'flow' with open(os.path.join(dataset_dir, 'tri_trainlist.txt'), 'r') as f: for i in f: name = str(i).strip() if(len(name) <= 1): continue self.img0_list.append(os.path.join(dataset_dir, 'sequences', name, 'im1.png')) self.imgt_list.append(os.path.join(dataset_dir, 'sequences', name, 'im2.png')) self.img1_list.append(os.path.join(dataset_dir, 'sequences', name, 'im3.png')) self.flow_t0_list.append(os.path.join(dataset_dir, flow_dir, name, 'flow_t0.flo')) self.flow_t1_list.append(os.path.join(dataset_dir, flow_dir, name, 'flow_t1.flo')) def __len__(self): return len(self.imgt_list) def __getitem__(self, idx): img0 = read(self.img0_list[idx]) imgt = read(self.imgt_list[idx]) img1 = read(self.img1_list[idx]) flow_t0 = read(self.flow_t0_list[idx]) flow_t1 = read(self.flow_t1_list[idx]) flow = np.concatenate((flow_t0, flow_t1), 2).astype(np.float64) if self.augment == True: img0, imgt, img1, flow = random_resize(img0, imgt, img1, flow, p=0.1) img0, imgt, img1, flow = random_crop(img0, imgt, img1, flow, crop_size=self.crop_size) img0, imgt, img1, flow = random_reverse_channel(img0, imgt, img1, flow, p=0.5) img0, imgt, img1, flow = random_vertical_flip(img0, imgt, img1, flow, p=0.3) img0, imgt, img1, flow = random_horizontal_flip(img0, imgt, img1, flow, p=0.5) img0, imgt, img1, flow = random_rotate(img0, imgt, img1, flow, p=0.05) img0, imgt, img1, flow = random_reverse_time(img0, imgt, img1, flow, p=0.5) img0 = torch.from_numpy(img0.transpose((2, 0, 1)).astype(np.float32) / 255.0) imgt = torch.from_numpy(imgt.transpose((2, 0, 1)).astype(np.float32) / 255.0) img1 = torch.from_numpy(img1.transpose((2, 0, 1)).astype(np.float32) / 255.0) flow = torch.from_numpy(flow.transpose((2, 0, 1)).astype(np.float32)) embt = torch.from_numpy(np.array(1/2).reshape(1, 1, 1).astype(np.float32)) return {'img0': img0.float(), 'imgt': imgt.float(), 'img1': img1.float(), 'flow': flow.float(), 'embt': embt} class Vimeo90K_Test_Dataset(Dataset): def __init__(self, dataset_dir='data/vimeo_triplet'): self.dataset_dir = dataset_dir self.img0_list = [] self.imgt_list = [] self.img1_list = [] self.flow_t0_list = [] self.flow_t1_list = [] with open(os.path.join(dataset_dir, 'tri_testlist.txt'), 'r') as f: for i in f: name = str(i).strip() if(len(name) <= 1): continue self.img0_list.append(os.path.join(dataset_dir, 'sequences', name, 'im1.png')) self.imgt_list.append(os.path.join(dataset_dir, 'sequences', name, 'im2.png')) self.img1_list.append(os.path.join(dataset_dir, 'sequences', name, 'im3.png')) self.flow_t0_list.append(os.path.join(dataset_dir, 'flow', name, 'flow_t0.flo')) self.flow_t1_list.append(os.path.join(dataset_dir, 'flow', name, 'flow_t1.flo')) def __len__(self): return len(self.imgt_list) def __getitem__(self, idx): img0 = read(self.img0_list[idx]) imgt = read(self.imgt_list[idx]) img1 = read(self.img1_list[idx]) flow_t0 = read(self.flow_t0_list[idx]) flow_t1 = read(self.flow_t1_list[idx]) flow = np.concatenate((flow_t0, flow_t1), 2) img0 = torch.from_numpy(img0.transpose((2, 0, 1)).astype(np.float32) / 255.0) imgt = torch.from_numpy(imgt.transpose((2, 0, 1)).astype(np.float32) / 255.0) img1 = torch.from_numpy(img1.transpose((2, 0, 1)).astype(np.float32) / 255.0) flow = torch.from_numpy(flow.transpose((2, 0, 1)).astype(np.float32)) embt = torch.from_numpy(np.array(1/2).reshape(1, 1, 1).astype(np.float32)) return {'img0': img0.float(), 'imgt': imgt.float(), 'img1': img1.float(), 'flow': flow.float(), 'embt': embt}