|
''' |
|
This code is partially borrowed from IFRNet (https://github.com/ltkong218/IFRNet). |
|
''' |
|
import os |
|
import cv2 |
|
import torch |
|
import random |
|
import numpy as np |
|
from torch.utils.data import Dataset |
|
from utils.utils import read |
|
|
|
|
|
def random_resize(img0, imgt, img1, flow, p=0.1):
    """With probability ``p``, enlarge the three frames and the flow by 2x.

    Each array is resized with ``cv2.resize`` using ``fx=fy=2.0`` and
    bilinear interpolation (``cv2.INTER_LINEAR``).  The resized flow is
    additionally multiplied by 2.0, so its values are scaled together with
    the pixel grid.

    Args:
        img0, imgt, img1: the three frames of a triplet.
        flow: the flow array belonging to the triplet.
        p: probability of applying the resize.

    Returns:
        ``(img0, imgt, img1, flow)`` -- resized, or the untouched inputs when
        the random draw is not below ``p``.
    """
    if not random.uniform(0, 1) < p:
        return img0, imgt, img1, flow

    def _double(arr):
        # One shared call so all four arrays use the same factor/interpolation.
        return cv2.resize(arr, dsize=None, fx=2.0, fy=2.0, interpolation=cv2.INTER_LINEAR)

    img0, imgt, img1 = (_double(frame) for frame in (img0, imgt, img1))
    flow = _double(flow) * 2.0
    return img0, imgt, img1, flow
|
|
|
def random_crop(img0, imgt, img1, flow, crop_size=(224, 224)):
    """Cut the same randomly placed window out of all frames and the flow.

    The spatial size is taken from ``img0`` only; the other arrays are sliced
    with the same row/column window (assumes they share ``img0``'s spatial
    size -- this is not checked here).

    Args:
        img0, imgt, img1: arrays indexed as (row, column, channel).
        flow: array indexed as (row, column, channel).
        crop_size: ``(height, width)`` of the window to extract.

    Returns:
        ``(img0, imgt, img1, flow)`` cropped to ``crop_size`` (views of the
        inputs, not copies).

    Raises:
        ValueError: if the requested window is larger than ``img0``.
    """
    crop_h, crop_w = crop_size[0], crop_size[1]
    img_h, img_w, _ = img0.shape

    # np.random.randint would also raise ValueError here, but only with an
    # opaque "low >= high" message; fail early with something actionable.
    if crop_h > img_h or crop_w > img_w:
        raise ValueError(
            'crop_size (%d, %d) does not fit inside image of size (%d, %d)'
            % (crop_h, crop_w, img_h, img_w)
        )

    # Upper bounds are exclusive, hence the +1 so the last valid offset is reachable.
    top = np.random.randint(0, img_h - crop_h + 1)
    left = np.random.randint(0, img_w - crop_w + 1)

    rows = slice(top, top + crop_h)
    cols = slice(left, left + crop_w)
    img0 = img0[rows, cols, :]
    imgt = imgt[rows, cols, :]
    img1 = img1[rows, cols, :]
    flow = flow[rows, cols, :]
    return img0, imgt, img1, flow
|
|
|
def random_reverse_channel(img0, imgt, img1, flow, p=0.5):
    """With probability ``p``, reverse the order of the last (third) axis of each frame.

    Only the three frames are affected; ``flow`` is passed through as-is.
    The reversed frames are negative-stride views of the inputs.

    Returns:
        ``(img0, imgt, img1, flow)``.
    """
    if not random.uniform(0, 1) < p:
        return img0, imgt, img1, flow

    img0, imgt, img1 = (frame[:, :, ::-1] for frame in (img0, imgt, img1))
    return img0, imgt, img1, flow
|
|
|
def random_vertical_flip(img0, imgt, img1, flow, p=0.3):
    """With probability ``p``, flip the frames and the flow along the first (row) axis.

    After the flip, flow channels 1 and 3 are negated while channels 0 and 2
    are kept (presumably 1 and 3 are the vertical displacement components --
    confirm against the flow file layout).  Only the first four flow channels
    are carried over.

    Returns:
        ``(img0, imgt, img1, flow)``.
    """
    if not random.uniform(0, 1) < p:
        return img0, imgt, img1, flow

    img0, imgt, img1, flow = (arr[::-1] for arr in (img0, imgt, img1, flow))

    # Rebuild the flow channel by channel, flipping the sign of the odd ones.
    planes = []
    for idx, negate in enumerate((False, True, False, True)):
        plane = flow[:, :, idx:idx + 1]
        planes.append(-plane if negate else plane)
    flow = np.concatenate(planes, 2)
    return img0, imgt, img1, flow
|
|
|
def random_horizontal_flip(img0, imgt, img1, flow, p=0.5):
    """With probability ``p``, flip the frames and the flow along the second (column) axis.

    After the flip, flow channels 0 and 2 are negated while channels 1 and 3
    are kept (presumably 0 and 2 are the horizontal displacement components --
    confirm against the flow file layout).  Only the first four flow channels
    are carried over.

    Returns:
        ``(img0, imgt, img1, flow)``.
    """
    if not random.uniform(0, 1) < p:
        return img0, imgt, img1, flow

    img0, imgt, img1, flow = (arr[:, ::-1] for arr in (img0, imgt, img1, flow))

    # Rebuild the flow channel by channel, flipping the sign of the even ones.
    planes = []
    for idx, negate in enumerate((True, False, True, False)):
        plane = flow[:, :, idx:idx + 1]
        planes.append(-plane if negate else plane)
    flow = np.concatenate(planes, 2)
    return img0, imgt, img1, flow
|
|
|
def random_rotate(img0, imgt, img1, flow, p=0.05):
    """With probability ``p``, swap the first two axes of the frames and the flow.

    Despite the name, the operation is a transpose (``transpose((1, 0, 2))``),
    so an input of spatial size (H, W) comes out as (W, H).  The flow channels
    are reordered to ``1, 0, 3, 2`` so that each pair of channels is swapped
    along with the axes.

    Returns:
        ``(img0, imgt, img1, flow)``.
    """
    if not random.uniform(0, 1) < p:
        return img0, imgt, img1, flow

    img0, imgt, img1, flow = (
        arr.transpose((1, 0, 2)) for arr in (img0, imgt, img1, flow)
    )
    # Swap the two components inside each channel pair: (0, 1) and (2, 3).
    flow = np.concatenate([flow[:, :, ch:ch + 1] for ch in (1, 0, 3, 2)], 2)
    return img0, imgt, img1, flow
|
|
|
def random_reverse_time(img0, imgt, img1, flow, p=0.5):
    """With probability ``p``, exchange the two outer frames and the two flow halves.

    ``img0`` and ``img1`` trade places, ``imgt`` stays where it is, and flow
    channels ``2:4`` are moved in front of channels ``0:2``.

    Returns:
        ``(img0, imgt, img1, flow)``.
    """
    if not random.uniform(0, 1) < p:
        return img0, imgt, img1, flow

    img0, img1 = img1, img0
    first_half = flow[:, :, 0:2]
    second_half = flow[:, :, 2:4]
    flow = np.concatenate((second_half, first_half), 2)
    return img0, imgt, img1, flow
|
|
|
|
|
class Vimeo90K_Train_Dataset(Dataset):
    """Training triplets (with flow files) listed in ``tri_trainlist.txt``.

    The constructor reads ``<dataset_dir>/tri_trainlist.txt``.  Every entry
    longer than one character (after stripping whitespace) is treated as a
    sequence name and expands to five paths:

    * ``<dataset_dir>/sequences/<name>/im1.png``  -> ``img0``
    * ``<dataset_dir>/sequences/<name>/im2.png``  -> ``imgt``
    * ``<dataset_dir>/sequences/<name>/im3.png``  -> ``img1``
    * ``<dataset_dir>/<flow_dir>/<name>/flow_t0.flo``
    * ``<dataset_dir>/<flow_dir>/<name>/flow_t1.flo``

    Args:
        dataset_dir: root directory of the dataset.
        flow_dir: sub-directory of ``dataset_dir`` holding the flow files;
            ``None`` selects ``'flow'``.
        augment: when truthy, ``__getitem__`` runs the random augmentation
            chain (resize, crop, channel reversal, flips, transpose, time
            reversal).
        crop_size: ``(height, width)`` handed to ``random_crop``.

    NOTE(review): ``random_rotate`` transposes the first two axes, so with a
    non-square ``crop_size`` some samples come out as (W, H) instead of
    (H, W) -- confirm the collate function can cope before using one.
    """

    def __init__(self,
                 dataset_dir='data/vimeo_triplet',
                 flow_dir=None,
                 augment=True,
                 crop_size=(224, 224)):
        self.dataset_dir = dataset_dir
        self.augment = augment
        self.crop_size = crop_size
        self.img0_list = []
        self.imgt_list = []
        self.img1_list = []
        self.flow_t0_list = []
        self.flow_t1_list = []

        if flow_dir is None:
            flow_dir = 'flow'

        with open(os.path.join(dataset_dir, 'tri_trainlist.txt'), 'r') as f:
            for line in f:
                name = line.strip()
                # Skip blank lines (and any stray single-character entries).
                if len(name) <= 1:
                    continue
                self.img0_list.append(os.path.join(dataset_dir, 'sequences', name, 'im1.png'))
                self.imgt_list.append(os.path.join(dataset_dir, 'sequences', name, 'im2.png'))
                self.img1_list.append(os.path.join(dataset_dir, 'sequences', name, 'im3.png'))
                self.flow_t0_list.append(os.path.join(dataset_dir, flow_dir, name, 'flow_t0.flo'))
                self.flow_t1_list.append(os.path.join(dataset_dir, flow_dir, name, 'flow_t1.flo'))

    def __len__(self):
        """Return the number of triplets found in the list file."""
        return len(self.imgt_list)

    def __getitem__(self, idx):
        """Load, optionally augment, and tensorize the triplet at ``idx``.

        Returns:
            dict with float32 tensors under the keys ``'img0'``, ``'imgt'``,
            ``'img1'`` (channel-first, divided by 255), ``'flow'``
            (channel-first) and ``'embt'`` (constant 0.5, shape (1, 1, 1)).
        """
        # `read` is a project helper; presumably it returns (H, W, C) arrays --
        # the axis-2 concatenate and the (2, 0, 1) transposes below rely on that.
        img0 = read(self.img0_list[idx])
        imgt = read(self.imgt_list[idx])
        img1 = read(self.img1_list[idx])
        flow_t0 = read(self.flow_t0_list[idx])
        flow_t1 = read(self.flow_t1_list[idx])
        # Stack both flow fields along the channel axis: t0 first, then t1.
        flow = np.concatenate((flow_t0, flow_t1), 2).astype(np.float64)

        if self.augment:
            img0, imgt, img1, flow = random_resize(img0, imgt, img1, flow, p=0.1)
            img0, imgt, img1, flow = random_crop(img0, imgt, img1, flow, crop_size=self.crop_size)
            img0, imgt, img1, flow = random_reverse_channel(img0, imgt, img1, flow, p=0.5)
            img0, imgt, img1, flow = random_vertical_flip(img0, imgt, img1, flow, p=0.3)
            img0, imgt, img1, flow = random_horizontal_flip(img0, imgt, img1, flow, p=0.5)
            img0, imgt, img1, flow = random_rotate(img0, imgt, img1, flow, p=0.05)
            img0, imgt, img1, flow = random_reverse_time(img0, imgt, img1, flow, p=0.5)

        # astype() copies, which also materializes any negative-stride views
        # produced by the flips before torch.from_numpy sees them.
        img0 = torch.from_numpy(img0.transpose((2, 0, 1)).astype(np.float32) / 255.0)
        imgt = torch.from_numpy(imgt.transpose((2, 0, 1)).astype(np.float32) / 255.0)
        img1 = torch.from_numpy(img1.transpose((2, 0, 1)).astype(np.float32) / 255.0)
        flow = torch.from_numpy(flow.transpose((2, 0, 1)).astype(np.float32))
        # The middle frame sits halfway between img0 and img1.
        embt = torch.from_numpy(np.array(1/2).reshape(1, 1, 1).astype(np.float32))

        # All tensors are already float32, so no further .float() cast is needed.
        return {'img0': img0, 'imgt': imgt, 'img1': img1, 'flow': flow, 'embt': embt}
|
|
|
|
|
class Vimeo90K_Test_Dataset(Dataset):
    """Test triplets (with flow files) listed in ``tri_testlist.txt``.

    The constructor reads ``<dataset_dir>/tri_testlist.txt``.  Every entry
    longer than one character (after stripping whitespace) is treated as a
    sequence name and expands to five paths:

    * ``<dataset_dir>/sequences/<name>/im1.png``  -> ``img0``
    * ``<dataset_dir>/sequences/<name>/im2.png``  -> ``imgt``
    * ``<dataset_dir>/sequences/<name>/im3.png``  -> ``img1``
    * ``<dataset_dir>/<flow_dir>/<name>/flow_t0.flo``
    * ``<dataset_dir>/<flow_dir>/<name>/flow_t1.flo``

    No augmentation is applied to test samples.

    Args:
        dataset_dir: root directory of the dataset.
        flow_dir: sub-directory of ``dataset_dir`` holding the flow files;
            ``None`` (the default) selects ``'flow'``, matching the training
            dataset's parameter of the same name.
    """

    def __init__(self, dataset_dir='data/vimeo_triplet', flow_dir=None):
        self.dataset_dir = dataset_dir
        self.img0_list = []
        self.imgt_list = []
        self.img1_list = []
        self.flow_t0_list = []
        self.flow_t1_list = []

        if flow_dir is None:
            flow_dir = 'flow'

        with open(os.path.join(dataset_dir, 'tri_testlist.txt'), 'r') as f:
            for line in f:
                name = line.strip()
                # Skip blank lines (and any stray single-character entries).
                if len(name) <= 1:
                    continue
                self.img0_list.append(os.path.join(dataset_dir, 'sequences', name, 'im1.png'))
                self.imgt_list.append(os.path.join(dataset_dir, 'sequences', name, 'im2.png'))
                self.img1_list.append(os.path.join(dataset_dir, 'sequences', name, 'im3.png'))
                self.flow_t0_list.append(os.path.join(dataset_dir, flow_dir, name, 'flow_t0.flo'))
                self.flow_t1_list.append(os.path.join(dataset_dir, flow_dir, name, 'flow_t1.flo'))

    def __len__(self):
        """Return the number of triplets found in the list file."""
        return len(self.imgt_list)

    def __getitem__(self, idx):
        """Load and tensorize the triplet at ``idx``.

        Returns:
            dict with float32 tensors under the keys ``'img0'``, ``'imgt'``,
            ``'img1'`` (channel-first, divided by 255), ``'flow'``
            (channel-first) and ``'embt'`` (constant 0.5, shape (1, 1, 1)).
        """
        # `read` is a project helper; presumably it returns (H, W, C) arrays --
        # the axis-2 concatenate and the (2, 0, 1) transposes below rely on that.
        img0 = read(self.img0_list[idx])
        imgt = read(self.imgt_list[idx])
        img1 = read(self.img1_list[idx])
        flow_t0 = read(self.flow_t0_list[idx])
        flow_t1 = read(self.flow_t1_list[idx])
        # Stack both flow fields along the channel axis: t0 first, then t1.
        flow = np.concatenate((flow_t0, flow_t1), 2)

        img0 = torch.from_numpy(img0.transpose((2, 0, 1)).astype(np.float32) / 255.0)
        imgt = torch.from_numpy(imgt.transpose((2, 0, 1)).astype(np.float32) / 255.0)
        img1 = torch.from_numpy(img1.transpose((2, 0, 1)).astype(np.float32) / 255.0)
        flow = torch.from_numpy(flow.transpose((2, 0, 1)).astype(np.float32))
        # The middle frame sits halfway between img0 and img1.
        embt = torch.from_numpy(np.array(1/2).reshape(1, 1, 1).astype(np.float32))

        # All tensors are already float32, so no further .float() cast is needed.
        return {'img0': img0,
                'imgt': imgt,
                'img1': img1,
                'flow': flow,
                'embt': embt}
|
|
|
|
|
|
|
|
|
|