import os
import random

import cv2
import numpy as np

from datasets import register
from .data_utils import *

import torch
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset

# Color-channel permutations and cv2 rotation codes for the augmentations
# (defined here but not referenced elsewhere in this module).
perm = [(0, 1, 2), (0, 2, 1), (1, 0, 2), (1, 2, 0), (2, 0, 1), (2, 1, 0)]
rotate = [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_180, cv2.ROTATE_90_COUNTERCLOCKWISE]


@register('vimeo')
class Vimeo(Dataset):
    """Vimeo-90K triplet dataset: input frames im1/im3 with the middle frame im2
    as the interpolation target, plus optional precomputed flow and distance maps."""

    def __init__(self, root_path, patch_size=(224, 224), split='train', flow='none', flow_root=None,
                 use_distance=False, distance_root=None, tri_trainlist='tri_trainlist.txt'):
        super().__init__()
        self.data_root = root_path
        self.mode = split
        self.patch_size = patch_size
        train_fn = os.path.join(self.data_root, tri_trainlist)
        test_fn = os.path.join(self.data_root, 'tri_testlist.txt')

        self.flow = flow
        self.flow_root = flow_root if flow != 'none' else None
        self.use_distance = use_distance
        self.distance_root = distance_root

        with open(train_fn, 'r') as f:
            self.trainlist = [line.strip() for line in f if line.strip()]
        with open(test_fn, 'r') as f:
            self.testlist = [line.strip() for line in f if line.strip()]

        # Any split other than 'train' falls back to the test list.
        self.img_list = self.trainlist if self.mode == 'train' else self.testlist

    def get_img(self, index):
        img_path = os.path.join(self.data_root, 'sequences', self.img_list[index])
        if os.path.exists(os.path.join(img_path, 'im1.png')):
            # cv2 loads BGR; [:, :, ::-1] converts to RGB. The .png sequences are
            # additionally clipped to a width of 896 px.
            img0 = cv2.imread(os.path.join(img_path, 'im1.png'))[:, :896, ::-1]
            imgt = cv2.imread(os.path.join(img_path, 'im2.png'))[:, :896, ::-1]
            img1 = cv2.imread(os.path.join(img_path, 'im3.png'))[:, :896, ::-1]
        elif os.path.exists(os.path.join(img_path, 'im1.jpg')):
            img0 = cv2.imread(os.path.join(img_path, 'im1.jpg'))[:, :, ::-1]
            imgt = cv2.imread(os.path.join(img_path, 'im2.jpg'))[:, :, ::-1]
            img1 = cv2.imread(os.path.join(img_path, 'im3.jpg'))[:, :, ::-1]
        else:
            raise FileNotFoundError(f'No im1/im2/im3 frames found under {img_path}')

        if self.flow == 't0':
            flow_dir = os.path.join(self.flow_root, 'sequences', self.img_list[index])
            if not os.path.exists(os.path.join(flow_dir, 'flowt0.npy')):
                print(f'Missing precomputed flow for {self.img_list[index]}')
            flowt0 = np.load(os.path.join(flow_dir, 'flowt0.npy')).astype(np.float32)
            flowt1 = np.load(os.path.join(flow_dir, 'flowt1.npy')).astype(np.float32)
        elif self.flow == '01':
            flowt0 = read_flow(os.path.join(self.data_root, 'flow_flow_former', self.img_list[index], 'flow_01.flo'))
            flowt1 = read_flow(os.path.join(self.data_root, 'flow_flow_former', self.img_list[index], 'flow_10.flo'))
        elif self.flow == '0t':
            flowt0 = read_flow(os.path.join(self.data_root, 'flow_flow_former', self.img_list[index], 'flow_0t.flo'))
            flowt1 = read_flow(os.path.join(self.data_root, 'flow_flow_former', self.img_list[index], 'flow_1t.flo'))
        else:
            flowt0 = None
            flowt1 = None
        return img0, imgt, img1, flowt0, flowt1

    def __getitem__(self, item):
        img0, imgt, img1, flowt0, flowt1 = self.get_img(item)
        distance = None
        H, W, _ = img0.shape
        time_step = torch.Tensor([0.5]).reshape(1, 1, 1)
        if self.mode == 'train':
            # A temporal flip swaps img0/img1 (and the flows), so the matching
            # distance map is the reversed one ('distance_rev' vs 'distance_for').
            if random.random() > 0.5:
                img0, imgt, img1, time_step, flowt0, flowt1 = random_temporal_flip(img0, imgt, img1, time_step, flowt0, flowt1)
                if self.use_distance and self.distance_root is not None:
                    distance_path = os.path.join(self.distance_root, 'sequences', self.img_list[item])
                    distance_fn = 'distance_rev.npy'
                    distance = np.load(os.path.join(distance_path, distance_fn)).astype(np.float32).reshape(H, W, 1)
            else:
                if self.use_distance and self.distance_root is not None:
                    distance_path = os.path.join(self.distance_root, 'sequences', self.img_list[item])
                    distance_fn = 'distance_for.npy'
                    distance = np.load(os.path.join(distance_path, distance_fn)).astype(np.float32).reshape(H, W, 1)
            if random.random() > 0.9:
                img0, imgt, img1, flowt0, flowt1, distance = random_resize(img0, imgt, img1, flowt0, flowt1, distance)
            img0, imgt, img1, flowt0, flowt1, distance = random_crop(img0, imgt, img1, self.patch_size, flowt0, flowt1, distance)
            if random.random() > 0.5:
                img0, imgt, img1, flowt0, flowt1, distance = random_hor_flip(img0, imgt, img1, flowt0, flowt1, distance)
            if random.random() > 0.5:
                img0, imgt, img1, flowt0, flowt1, distance = random_ver_flip(img0, imgt, img1, flowt0, flowt1, distance)
            if random.random() > 0.5:
                img0, imgt, img1 = random_color_permutation(img0, imgt, img1)
            degree = random.randint(0, 3)
            img0, imgt, img1, flowt0, flowt1, distance = random_rotation(img0, imgt, img1, degree, flowt0, flowt1, distance)
        else:
            if self.distance_root is not None:
                distance_fn = 'distance_for.npy'
                distance = np.load(os.path.join(self.distance_root, 'sequences', self.img_list[item], distance_fn)).astype(np.float32).reshape(H, W, 1)
        img0, imgt, img1 = TF.to_tensor(img0.copy()), TF.to_tensor(imgt.copy()), TF.to_tensor(img1.copy())
        input_dict = {
            'img0': img0, 'imgt': imgt, 'img1': img1, 'time_step': time_step, 'scene_name': self.img_list[item]
        }
        if flowt0 is not None and flowt1 is not None:
            flowt0 = torch.from_numpy(flowt0).type(torch.float32).permute(2, 0, 1)
            flowt1 = torch.from_numpy(flowt1).type(torch.float32).permute(2, 0, 1)
            input_dict['flowt0'] = flowt0
            input_dict['flowt1'] = flowt1
        if self.use_distance:
            if self.distance_root is not None:
                distance = TF.to_tensor(distance.copy())
                if torch.any(torch.isnan(distance)):
                    print(f'NaN values in distance map {distance_fn} for {self.img_list[item]}')
            else:
                # No precomputed maps available: fall back to a constant 0.5 field
                # matching the (possibly cropped) frame size.
                h, w = img0.shape[-2:]
                distance = torch.full((1, h, w), 0.5, dtype=torch.float32)
            input_dict['distance'] = distance
        return input_dict

    def __len__(self):
        return len(self.img_list)
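

# A minimal usage sketch (illustrative only, not part of the dataset class).
# The dataset root below is a placeholder, and the standard Vimeo-90K triplet
# layout (<root>/sequences/<clip>/im{1,2,3}.png plus tri_trainlist.txt and
# tri_testlist.txt) is assumed. Run it from within the package (e.g. via
# `python -m ...`), since this module uses a relative import.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    dataset = Vimeo(root_path='/path/to/vimeo_triplet', split='train')
    loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)
    batch = next(iter(loader))
    # img0/imgt/img1 are 3xHxW float tensors in [0, 1]; time_step is the
    # (possibly temporally flipped) interpolation position.
    print(batch['img0'].shape, batch['imgt'].shape, batch['time_step'].shape)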