File size: 7,713 Bytes
8d015d4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
from pathlib import Path
import os
from PIL import Image
import random
import numpy as np
import cv2
from datasets import register
from .data_utils import *
import torch
import torchvision
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset
# All six orderings of the three indices 0, 1, 2.
# NOTE(review): not referenced in the visible part of this file; presumably a
# lookup table for color-channel permutation augmentation -- confirm or remove.
perm = [(0, 1, 2), (0, 2, 1), (1, 0, 2), (1, 2, 0), (2, 0, 1), (2, 1, 0)]
# OpenCV rotation codes: 90 degrees clockwise, 180 degrees, 90 degrees
# counter-clockwise.
# NOTE(review): also unreferenced in the visible code; presumably paired with
# a rotation augmentation -- confirm or remove.
rotate = [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_180, cv2.ROTATE_90_COUNTERCLOCKWISE]
@register('vimeo')
class Vimeo(Dataset):
    """Frame-triplet dataset registered under the name ``'vimeo'``.

    Every sample lives in ``<root_path>/sequences/<entry>/`` where ``<entry>``
    is one non-empty line of the train or test list file. A sample consists of
    three frames (``im1``, ``im2``, ``im3`` as PNG or JPG) and, optionally,
    precomputed flow fields and a per-pixel distance map.

    Args:
        root_path: Dataset root holding the list files and ``sequences/``.
        patch_size: Crop size passed to ``random_crop`` in train mode.
        split: ``'train'`` selects the train list and enables augmentation;
            any other value selects the test list without augmentation.
        flow: Which flow files to load. ``'t0'`` reads ``flowt0.npy`` /
            ``flowt1.npy`` below ``flow_root``; ``'01'`` and ``'0t'`` read
            ``.flo`` files below ``<root_path>/flow_flow_former``; any other
            value loads no flow.
        flow_root: Root of the ``'t0'`` flow files. Stored as ``None`` when
            ``flow == 'none'``.
        use_distance: When true, every sample carries a ``'distance'`` entry.
        distance_root: Root of the ``distance_for.npy`` / ``distance_rev.npy``
            files. When ``None`` and ``use_distance`` is true, a constant 0.5
            map is returned instead.
        tri_trainlist: File name of the train list, relative to ``root_path``.
    """

    # PNG frames are cut to their first 896 columns; JPG frames are used at
    # full width.
    # NOTE(review): the reason for 896 is not visible in this file -- confirm.
    _PNG_MAX_WIDTH = 896

    def __init__(self, root_path, patch_size=(224, 224), split='train', flow="none", flow_root=None,
                 use_distance=False, distance_root=None, tri_trainlist='tri_trainlist.txt'):
        super().__init__()
        self.data_root = root_path
        self.mode = split
        self.patch_size = patch_size
        self.flow = flow
        self.flow_root = flow_root if flow != 'none' else None
        self.use_distance = use_distance
        self.distance_root = distance_root

        # Both lists are always read (and must therefore both exist), whatever
        # the requested split.
        self.trainlist = self._read_list(os.path.join(self.data_root, tri_trainlist))
        self.testlist = self._read_list(os.path.join(self.data_root, 'tri_testlist.txt'))

        # Every split other than 'train' falls back to the test list.
        self.img_list = self.trainlist if self.mode == "train" else self.testlist

    @staticmethod
    def _read_list(path):
        """Return the stripped, non-empty lines of the text file at ``path``."""
        with open(path, "r") as f:
            return [line.strip() for line in f if line.strip()]

    @staticmethod
    def _read_frame(path, max_width=None):
        """Read one image with its channel axis reversed (OpenCV BGR -> RGB).

        Only the first ``max_width`` columns are kept; ``None`` keeps all.

        Raises:
            OSError: If OpenCV cannot read or decode the file.
        """
        frame = cv2.imread(path)
        if frame is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f"could not read image: {path}")
        return frame[:, :max_width, ::-1]

    def _load_distance(self, item, filename, height, width):
        """Load a distance map as a float32 array of shape (height, width, 1)."""
        path = os.path.join(self.distance_root, "sequences", self.img_list[item], filename)
        return np.load(path).astype(np.float32).reshape(height, width, 1)

    def get_img(self, index):
        """Load the three frames and the configured flow fields of one sample.

        Returns:
            Tuple ``(img0, imgt, img1, flowt0, flowt1)``. The images are
            channel-reversed arrays as returned by OpenCV; the flows are
            ``None`` unless ``self.flow`` is ``'t0'``, ``'01'`` or ``'0t'``.

        Raises:
            FileNotFoundError: If the sample directory contains neither
                ``im1.png`` nor ``im1.jpg``.
            OSError: If a frame exists but cannot be decoded.
        """
        entry = self.img_list[index]
        img_path = os.path.join(self.data_root, "sequences", entry)

        # The presence of im1 decides the extension used for all three frames.
        if os.path.exists(os.path.join(img_path, "im1.png")):
            ext, max_width = "png", self._PNG_MAX_WIDTH
        elif os.path.exists(os.path.join(img_path, "im1.jpg")):
            ext, max_width = "jpg", None
        else:
            raise FileNotFoundError(f"no im1.png or im1.jpg found in {img_path}")

        img0, imgt, img1 = (
            self._read_frame(os.path.join(img_path, f"im{i}.{ext}"), max_width)
            for i in (1, 2, 3)
        )

        if self.flow == 't0':
            flow_dir = os.path.join(self.flow_root, 'sequences', entry)
            if not os.path.exists(os.path.join(flow_dir, 'flowt0.npy')):
                # Name the offending sample before np.load raises below.
                print(entry)
            flowt0 = np.load(os.path.join(flow_dir, 'flowt0.npy')).astype(np.float32)
            flowt1 = np.load(os.path.join(flow_dir, 'flowt1.npy')).astype(np.float32)
        elif self.flow in ('01', '0t'):
            flow_dir = os.path.join(self.data_root, "flow_flow_former", entry)
            if self.flow == '01':
                first, second = "flow_01.flo", "flow_10.flo"
            else:
                first, second = "flow_0t.flo", "flow_1t.flo"
            flowt0 = read_flow(os.path.join(flow_dir, first))
            flowt1 = read_flow(os.path.join(flow_dir, second))
        else:
            flowt0 = None
            flowt1 = None
        return img0, imgt, img1, flowt0, flowt1

    def __getitem__(self, item):
        """Return one sample as a dict of tensors.

        Keys: ``'img0'``, ``'imgt'``, ``'img1'`` (image tensors),
        ``'time_step'`` (tensor of shape (1, 1, 1), initialised to 0.5),
        ``'scene_name'`` (the list entry), plus ``'flowt0'`` / ``'flowt1'``
        when flows were loaded and ``'distance'`` when ``use_distance`` is set.

        In train mode the sample is passed through the augmentation helpers
        (temporal flip, resize, crop, horizontal/vertical flip, color
        permutation, rotation) in that order.
        """
        img0, imgt, img1, flowt0, flowt1 = self.get_img(item)
        distance = None
        H, W, _ = img0.shape
        time_step = torch.Tensor([0.5]).reshape(1, 1, 1)

        is_train = self.mode == "train"
        load_distance = self.use_distance and self.distance_root is not None
        # The forward map is the default; a temporal flip switches to the
        # reversed one. Always defined, so the NaN report below cannot fail.
        distance_file = 'distance_for.npy'

        if is_train and random.random() > 0.5:
            img0, imgt, img1, time_step, flowt0, flowt1 = random_temporal_flip(
                img0, imgt, img1, time_step, flowt0, flowt1)
            distance_file = 'distance_rev.npy'

        # Load before the spatial augmentations so the map is transformed
        # together with the frames.
        if load_distance:
            distance = self._load_distance(item, distance_file, H, W)

        if is_train:
            if random.random() > 0.9:
                img0, imgt, img1, flowt0, flowt1, distance = random_resize(
                    img0, imgt, img1, flowt0, flowt1, distance)
            img0, imgt, img1, flowt0, flowt1, distance = random_crop(
                img0, imgt, img1, self.patch_size, flowt0, flowt1, distance)
            if random.random() > 0.5:
                img0, imgt, img1, flowt0, flowt1, distance = random_hor_flip(
                    img0, imgt, img1, flowt0, flowt1, distance)
            if random.random() > 0.5:
                img0, imgt, img1, flowt0, flowt1, distance = random_ver_flip(
                    img0, imgt, img1, flowt0, flowt1, distance)
            if random.random() > 0.5:
                img0, imgt, img1 = random_color_permutation(img0, imgt, img1)
            degree = random.randint(0, 3)
            img0, imgt, img1, flowt0, flowt1, distance = random_rotation(
                img0, imgt, img1, degree, flowt0, flowt1, distance)

        # .copy() materialises the reversed-channel views before conversion.
        img0, imgt, img1 = TF.to_tensor(img0.copy()), TF.to_tensor(imgt.copy()), TF.to_tensor(img1.copy())
        input_dict = {
            'img0': img0, 'imgt': imgt, 'img1': img1, 'time_step': time_step, 'scene_name': self.img_list[item]
        }

        if flowt0 is not None and flowt1 is not None:
            # (H, W, C) arrays -> (C, H, W) float tensors.
            input_dict['flowt0'] = torch.from_numpy(flowt0).type(torch.float32).permute(2, 0, 1)
            input_dict['flowt1'] = torch.from_numpy(flowt1).type(torch.float32).permute(2, 0, 1)

        if self.use_distance:
            if self.distance_root is not None:
                distance = TF.to_tensor(distance.copy())
                if torch.any(torch.isnan(distance)):
                    print(f'@@@@@@@@@@@@@@@@@@@@@@{self.img_list[item]}, {distance_file}@@@@@@@@@@@@@@@@@@@@@@')
            else:
                # No distance files available: fall back to a constant 0.5 map.
                # NOTE(review): H and W are the pre-augmentation frame size, so
                # in train mode this map does not match the cropped patch --
                # confirm whether that is intended.
                distance = torch.full((1, H, W), 0.5, dtype=torch.float32)
            input_dict['distance'] = distance
        return input_dict

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.img_list)
|