|
from pathlib import Path |
|
import os |
|
from PIL import Image |
|
import random |
|
import numpy as np |
|
import cv2 |
|
|
|
from datasets import register |
|
|
|
import torch |
|
import torchvision |
|
import torchvision.transforms as T |
|
import torchvision.transforms.functional as TF |
|
from torch.utils.data import Dataset |
|
|
|
|
|
@register('xiph')
class Xiph(Dataset):
    """Xiph frame-interpolation benchmark dataset.

    Every sample is a frame triplet ``(t-1, t, t+1)`` taken from one sequence
    directory under ``root_path``; the middle frame is the target and the
    interpolation time is fixed at 0.5.

    Args:
        root_path: Directory whose sub-directories each hold one sequence of
            frames named ``001.png``, ``002.png``, ... (frames 001-099 are
            referenced).
        split: ``"resized-2k"`` resizes every frame to 2048x1080 (W x H) with
            area interpolation; ``"cropped-4k"`` keeps native resolution and
            trims 540 rows / 1024 columns from each border (for a 4096x2160
            frame this is the central 2048x1080 region).

    Raises:
        ValueError: If ``split`` is not one of ``SPLITS``.
    """

    SPLITS = ("resized-2k", "cropped-4k")
    # Target (width, height) passed to cv2.resize for the "resized-2k" split.
    RESIZE_DSIZE = (2048, 1080)
    # Pixels trimmed from each (top/bottom, left/right) border for "cropped-4k".
    CROP_MARGIN = (540, 1024)
    # Middle-frame indices 2, 4, ..., 98 -> triplets (1,2,3), ..., (97,98,99).
    FRAME_INDICES = range(2, 99, 2)

    def __init__(self, root_path, split="resized-2k"):
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if split not in self.SPLITS:
            raise ValueError(f"split must be one of {self.SPLITS}, got {split!r}")
        self.data_root = root_path
        self.split = split
        self.load_data()

    def __len__(self):
        return len(self.imgt_path_list)

    def load_data(self):
        """Build parallel lists of relative frame paths, one entry per triplet.

        Sequence directories are visited in sorted order because
        ``os.listdir`` returns entries in arbitrary order, which would make the
        index -> triplet mapping differ between machines. Entries that are not
        directories are skipped.
        """
        self.img0_path_list = []
        self.imgt_path_list = []
        self.img1_path_list = []
        for seq_name in sorted(os.listdir(self.data_root)):
            if not os.path.isdir(os.path.join(self.data_root, seq_name)):
                continue
            for frame in self.FRAME_INDICES:
                self.img0_path_list.append(f'{seq_name}/{frame - 1:03d}.png')
                self.imgt_path_list.append(f'{seq_name}/{frame:03d}.png')
                self.img1_path_list.append(f'{seq_name}/{frame + 1:03d}.png')

    @staticmethod
    def _read_rgb(path):
        """Read one image with OpenCV and return it with channels reversed (BGR -> RGB).

        Raises:
            OSError: If OpenCV cannot read the file (``cv2.imread`` returns None).
        """
        img = cv2.imread(path)
        if img is None:
            raise OSError(f"cv2.imread failed to read image: {path!r}")
        return img[:, :, ::-1]

    def get_img(self, index):
        """Return the ``(img0, imgt, img1)`` RGB numpy frames for sample ``index``."""
        img0 = self._read_rgb(os.path.join(self.data_root, self.img0_path_list[index]))
        imgt = self._read_rgb(os.path.join(self.data_root, self.imgt_path_list[index]))
        img1 = self._read_rgb(os.path.join(self.data_root, self.img1_path_list[index]))
        return img0, imgt, img1

    def _preprocess(self, img):
        """Apply the split-specific resize or border crop to a single frame."""
        if self.split == 'resized-2k':
            return cv2.resize(src=img, dsize=self.RESIZE_DSIZE, fx=0.0, fy=0.0,
                              interpolation=cv2.INTER_AREA)
        if self.split == 'cropped-4k':
            margin_h, margin_w = self.CROP_MARGIN
            height, width = img.shape[:2]
            # Explicit end indices (rather than `-margin`) stay correct for a 0 margin.
            return img[margin_h:height - margin_h, margin_w:width - margin_w, :]
        return img

    def __getitem__(self, index):
        """Return a dict with ``img0``/``imgt``/``img1`` tensors, ``time_step`` and ``scene_name``."""
        frames = [self._preprocess(img) for img in self.get_img(index)]
        # .copy(): the BGR->RGB flip in _read_rgb leaves a negative-stride view,
        # which torch.from_numpy (used by to_tensor) cannot wrap.
        img0, imgt, img1 = (TF.to_tensor(img.copy()) for img in frames)
        time_step = torch.Tensor([0.5]).reshape(1, 1, 1)
        return {
            'img0': img0, 'imgt': imgt, 'img1': img1, 'time_step': time_step,
            'scene_name': self.imgt_path_list[index],
        }
|
|