# datasets/vimeo.py
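"""Vimeo-90K triplet dataset for video frame interpolation.

Loads (img0, imgt, img1) frame triplets, optionally with precomputed optical
flow and per-pixel distance maps, and applies random augmentations when
split == 'train'.
"""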
import os
import random

import cv2
import numpy as np
import torch
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset

from datasets import register
from .data_utils import *
# Lookup tables for channel permutations and 90-degree cv2 rotations used by
# the augmentation step.
perm = [(0, 1, 2), (0, 2, 1), (1, 0, 2), (1, 2, 0), (2, 0, 1), (2, 1, 0)]
rotate = [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_180, cv2.ROTATE_90_COUNTERCLOCKWISE]
@register('vimeo')
class Vimeo(Dataset):
    def __init__(self, root_path, patch_size=(224, 224), split='train', flow="none", flow_root=None,
                 use_distance=False, distance_root=None, tri_trainlist='tri_trainlist.txt'):
        super(Vimeo, self).__init__()
        self.data_root = root_path
        self.mode = split
        self.patch_size = patch_size
        train_fn = os.path.join(self.data_root, tri_trainlist)
        test_fn = os.path.join(self.data_root, 'tri_testlist.txt')
        self.flow = flow
        self.flow_root = flow_root if flow != 'none' else None
        self.use_distance = use_distance
        self.distance_root = distance_root
        with open(train_fn, "r") as f:
            self.trainlist = [line.strip() for line in f if len(line.strip()) > 0]
        with open(test_fn, "r") as f:
            self.testlist = [line.strip() for line in f if len(line.strip()) > 0]
        # Any split other than "train" falls back to the test list.
        if self.mode == "train":
            self.img_list = self.trainlist
        else:
            self.img_list = self.testlist

    def get_img(self, index):
        img_path = os.path.join(self.data_root, "sequences", self.img_list[index])
        # cv2 loads BGR; [..., ::-1] reverses the channel axis to get RGB.
        # Vimeo triplet frames are 448 px wide, so the :896 width clamp is a
        # safe upper bound.
        if os.path.exists(os.path.join(img_path, "im1.png")):
            img0 = cv2.imread(os.path.join(img_path, "im1.png"))[:, :896, ::-1]
            imgt = cv2.imread(os.path.join(img_path, "im2.png"))[:, :896, ::-1]
            img1 = cv2.imread(os.path.join(img_path, "im3.png"))[:, :896, ::-1]
        elif os.path.exists(os.path.join(img_path, "im1.jpg")):
            img0 = cv2.imread(os.path.join(img_path, "im1.jpg"))[:, :, ::-1]
            imgt = cv2.imread(os.path.join(img_path, "im2.jpg"))[:, :, ::-1]
            img1 = cv2.imread(os.path.join(img_path, "im3.jpg"))[:, :, ::-1]
        else:
            raise FileNotFoundError(f"{img_path}: neither im1.png nor im1.jpg exists")
        if self.flow == 't0':
            flow_dir = os.path.join(self.flow_root, 'sequences', self.img_list[index])
            if not os.path.exists(os.path.join(flow_dir, 'flowt0.npy')):
                print(f"missing precomputed flow: {self.img_list[index]}")
            flowt0 = np.load(os.path.join(flow_dir, 'flowt0.npy')).astype(np.float32)
            flowt1 = np.load(os.path.join(flow_dir, 'flowt1.npy')).astype(np.float32)
        elif self.flow == '01':
            flowt0 = read_flow(os.path.join(self.data_root, "flow_flow_former", self.img_list[index], "flow_01.flo"))
            flowt1 = read_flow(os.path.join(self.data_root, "flow_flow_former", self.img_list[index], "flow_10.flo"))
        elif self.flow == '0t':
            flowt0 = read_flow(os.path.join(self.data_root, "flow_flow_former", self.img_list[index], "flow_0t.flo"))
            flowt1 = read_flow(os.path.join(self.data_root, "flow_flow_former", self.img_list[index], "flow_1t.flo"))
        else:
            flowt0 = None
            flowt1 = None
        return img0, imgt, img1, flowt0, flowt1

    def __getitem__(self, item):
        img0, imgt, img1, flowt0, flowt1 = self.get_img(item)
        distance = None
        distance_file = None
        H, W, _ = img0.shape
        time_step = torch.Tensor([0.5]).reshape(1, 1, 1)
        if self.mode == "train":
            if random.random() > 0.5:
                # A temporal flip swaps img0/img1 (and the flows), so the
                # matching distance map is the reversed one.
                img0, imgt, img1, time_step, flowt0, flowt1 = random_temporal_flip(img0, imgt, img1, time_step, flowt0, flowt1)
                if self.use_distance and self.distance_root is not None:
                    distance_path = os.path.join(self.distance_root, "sequences", self.img_list[item])
                    distance_file = 'distance_rev.npy'
                    distance = np.load(os.path.join(distance_path, distance_file)).astype(np.float32).reshape(H, W, 1)
            else:
                if self.use_distance and self.distance_root is not None:
                    distance_path = os.path.join(self.distance_root, "sequences", self.img_list[item])
                    distance_file = 'distance_for.npy'
                    distance = np.load(os.path.join(distance_path, distance_file)).astype(np.float32).reshape(H, W, 1)
            if random.random() > 0.9:
                img0, imgt, img1, flowt0, flowt1, distance = random_resize(img0, imgt, img1, flowt0, flowt1, distance)
            img0, imgt, img1, flowt0, flowt1, distance = random_crop(img0, imgt, img1, self.patch_size, flowt0, flowt1, distance)
            if random.random() > 0.5:
                img0, imgt, img1, flowt0, flowt1, distance = random_hor_flip(img0, imgt, img1, flowt0, flowt1, distance)
            if random.random() > 0.5:
                img0, imgt, img1, flowt0, flowt1, distance = random_ver_flip(img0, imgt, img1, flowt0, flowt1, distance)
            if random.random() > 0.5:
                img0, imgt, img1 = random_color_permutation(img0, imgt, img1)
            degree = random.randint(0, 3)
            img0, imgt, img1, flowt0, flowt1, distance = random_rotation(img0, imgt, img1, degree, flowt0, flowt1, distance)
        else:
            if self.distance_root is not None:
                distance_file = 'distance_for.npy'
                distance_path = os.path.join(self.distance_root, "sequences", self.img_list[item], distance_file)
                distance = np.load(distance_path).astype(np.float32).reshape(H, W, 1)
        img0, imgt, img1 = TF.to_tensor(img0.copy()), TF.to_tensor(imgt.copy()), TF.to_tensor(img1.copy())
        input_dict = {
            'img0': img0, 'imgt': imgt, 'img1': img1, 'time_step': time_step, 'scene_name': self.img_list[item]
        }
        if flowt0 is not None and flowt1 is not None:
            flowt0 = torch.from_numpy(flowt0).type(torch.float32).permute(2, 0, 1)
            flowt1 = torch.from_numpy(flowt1).type(torch.float32).permute(2, 0, 1)
            input_dict['flowt0'] = flowt0
            input_dict['flowt1'] = flowt1
        if self.use_distance:
            if self.distance_root is not None:
                distance = TF.to_tensor(distance.copy())
                if torch.any(torch.isnan(distance)):
                    print(f"NaN in distance map: {self.img_list[item]} ({distance_file})")
            else:
                # No precomputed maps: fall back to a constant 0.5 distance.
                distance = np.full((H, W, 1), 0.5)
                distance = torch.from_numpy(distance).type(torch.float32).permute(2, 0, 1)
            input_dict['distance'] = distance
        return input_dict

    def __len__(self):
        return len(self.img_list)
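
# Minimal usage sketch (not part of the loader): builds the dataset and a
# DataLoader. The root path below is an assumption; point it at a local
# Vimeo-90K triplet directory containing sequences/ and tri_trainlist.txt.
# This also assumes the @register decorator returns the class unchanged.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = Vimeo(root_path='/path/to/vimeo_triplet', patch_size=(224, 224), split='train')
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)
    batch = next(iter(loader))
    # Each batch holds [B, 3, 224, 224] frame tensors and a [B, 1, 1, 1] time step.
    print(batch['img0'].shape, batch['time_step'].shape)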