import os

import cv2
import torch
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset

from datasets import register


@register('xiph')
class Xiph(Dataset):
    """Xiph test set for video frame interpolation: (img0, imgt, img1) triplets per scene."""

    def __init__(self, root_path, split="resized-2k"):
        assert split in ["resized-2k", "cropped-4k"]
        self.data_root = root_path
        self.split = split
        self.load_data()

    def __len__(self):
        return len(self.imgt_path_list)
    def load_data(self):
        self.img0_path_list = []
        self.imgt_path_list = []
        self.img1_path_list = []
        # Each scene directory contributes triplets centred on every even frame
        # from 002 to 098: (img0, imgt, img1) = (t - 1, t, t + 1).
        for file_name in os.listdir(self.data_root):
            for intFrame in range(2, 99, 2):
                self.img0_path_list.append(f'{file_name}/{intFrame - 1:03d}.png')
                self.imgt_path_list.append(f'{file_name}/{intFrame:03d}.png')
                self.img1_path_list.append(f'{file_name}/{intFrame + 1:03d}.png')
    def get_img(self, index):
        img0_path = os.path.join(self.data_root, self.img0_path_list[index])
        imgt_path = os.path.join(self.data_root, self.imgt_path_list[index])
        img1_path = os.path.join(self.data_root, self.img1_path_list[index])
        # Load with OpenCV (BGR) and flip the channel axis to get RGB.
        img0 = cv2.imread(img0_path)[:, :, ::-1]
        imgt = cv2.imread(imgt_path)[:, :, ::-1]
        img1 = cv2.imread(img1_path)[:, :, ::-1]
        return img0, imgt, img1
    def __getitem__(self, index):
        img0, imgt, img1 = self.get_img(index)
        if self.split == 'resized-2k':
            # Downscale the source frames to 2048x1080 for the "2K" protocol.
            img0 = cv2.resize(img0, dsize=(2048, 1080), interpolation=cv2.INTER_AREA)
            img1 = cv2.resize(img1, dsize=(2048, 1080), interpolation=cv2.INTER_AREA)
            imgt = cv2.resize(imgt, dsize=(2048, 1080), interpolation=cv2.INTER_AREA)
        elif self.split == 'cropped-4k':
            # Keep the native resolution but take a central crop
            # (2048x1080 for 4096x2160 source frames).
            img0 = img0[540:-540, 1024:-1024, :]
            img1 = img1[540:-540, 1024:-1024, :]
            imgt = imgt[540:-540, 1024:-1024, :]
        # to_tensor converts HWC uint8 arrays to CHW float tensors in [0, 1];
        # .copy() is required because the BGR->RGB slice yields a negative-stride view.
        img0 = TF.to_tensor(img0.copy())
        img1 = TF.to_tensor(img1.copy())
        imgt = TF.to_tensor(imgt.copy())
        # The target frame sits halfway between img0 and img1.
        time_step = torch.Tensor([0.5]).reshape(1, 1, 1)
        return {
            'img0': img0, 'imgt': imgt, 'img1': img1,
            'time_step': time_step, 'scene_name': self.imgt_path_list[index]
        }
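

# Minimal usage sketch, assuming `root_path` points at a directory laid out as
# <scene>/<frame 001-099>.png, which is the layout load_data() above expects.
# The path below is a placeholder, not a real location.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = Xiph(root_path="/path/to/xiph", split="resized-2k")  # placeholder path
    loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=2)

    sample = next(iter(loader))
    # Each image tensor is (B, 3, H, W) in [0, 1]; time_step becomes (B, 1, 1, 1)
    # after batching, and scene_name is collated into a list of strings.
    print(sample['img0'].shape, sample['imgt'].shape, sample['img1'].shape)
    print(sample['time_step'].reshape(-1), sample['scene_name'])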