# Provenance: uploaded by ychenhq via huggingface_hub (commit 04fbff5, verified; 4.53 kB).
import argparse
import glob
import os
import os.path as osp
import subprocess
import sys

import cv2
import numpy as np
import torch
import tqdm
from omegaconf import OmegaConf

# Put the current working directory on the import path so the repo-local `utils` / `metrics`
# packages resolve (presumably the script is meant to be run from the repo root).
sys.path.append('.')
from utils.utils import InputPadder, read, img2tensor
from utils.build_utils import build_from_cfg
from metrics.psnr_ssim import calculate_psnr, calculate_ssim
# Command-line interface: which network config / checkpoint to evaluate and where the Xiph data lives.
parser = argparse.ArgumentParser(
    prog='AMT',
    description='Xiph evaluation',
)
parser.add_argument('-c', '--config', default='cfgs/AMT-S.yaml',
                    help='YAML config whose `network` section describes the model to build')
parser.add_argument('-p', '--ckpt', default='pretrained/amt-s.pth',
                    help='checkpoint file; its "state_dict" entry is loaded into the model')
parser.add_argument('-r', '--root', default='data/xiph',
                    help='directory for the Xiph sequences; missing frames are extracted there with ffmpeg')
args = parser.parse_args()
# Build the network described by the config and load the pretrained weights for evaluation.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cfg_path = args.config
ckpt_path = args.ckpt
root = args.root
network_cfg = OmegaConf.load(cfg_path).network
network_name = network_cfg.name
model = build_from_cfg(network_cfg)
# Deserialize onto the CPU first: without map_location, a checkpoint saved from a GPU run
# cannot be loaded on a CPU-only host, even though `device` has already fallen back to 'cpu'.
# The weights reach `device` via model.to(device) below.
ckpt = torch.load(ckpt_path, map_location='cpu')
# strict=False: missing/unexpected keys in the state dict are tolerated instead of raising.
model.load_state_dict(ckpt['state_dict'], strict=False)
model = model.to(device)
model.eval()
############################################# Prepare Dataset #############################################
# Each Netflix 4096x2160 sequence below is decoded into its first 100 frames as RGB PNGs
# (001.png ... 100.png; ffmpeg's image2 muxer numbers from 1 by default) under <root>/<name>/.
# Directories that already hold at least 100 PNGs are left untouched.
download_links = [
    'https://media.xiph.org/video/derf/ElFuente/Netflix_BoxingPractice_4096x2160_60fps_10bit_420.y4m',
    'https://media.xiph.org/video/derf/ElFuente/Netflix_Crosswalk_4096x2160_60fps_10bit_420.y4m',
    'https://media.xiph.org/video/derf/Chimera/Netflix_DrivingPOV_4096x2160_60fps_10bit_420.y4m',
    'https://media.xiph.org/video/derf/ElFuente/Netflix_FoodMarket_4096x2160_60fps_10bit_420.y4m',
    'https://media.xiph.org/video/derf/ElFuente/Netflix_FoodMarket2_4096x2160_60fps_10bit_420.y4m',
    'https://media.xiph.org/video/derf/ElFuente/Netflix_RitualDance_4096x2160_60fps_10bit_420.y4m',
    'https://media.xiph.org/video/derf/ElFuente/Netflix_SquareAndTimelapse_4096x2160_60fps_10bit_420.y4m',
    'https://media.xiph.org/video/derf/ElFuente/Netflix_Tango_4096x2160_60fps_10bit_420.y4m',
]
# Directory names, in the same order as download_links.
file_list = ['BoxingPractice', 'Crosswalk', 'DrivingPOV', 'FoodMarket', 'FoodMarket2', 'RitualDance',
             'SquareAndTimelapse', 'Tango']
for file_name, link in zip(file_list, download_links):
    data_dir = osp.join(root, file_name)
    os.makedirs(data_dir, exist_ok=True)
    # Escape the directory part so glob metacharacters in the user-supplied root are taken literally.
    if len(glob.glob(osp.join(glob.escape(data_dir), '*.png'))) < 100:
        # Argument list (no shell) so paths with spaces or shell metacharacters are passed safely.
        # -y: overwrite frames left by an interrupted run instead of blocking on ffmpeg's
        #     interactive "Overwrite? [y/N]" prompt.
        # check=True: stop here if the download/decode fails rather than evaluating on missing frames.
        subprocess.run(
            ['ffmpeg', '-y', '-i', link, '-pix_fmt', 'rgb24', '-vframes', '100',
             osp.join(data_dir, '%03d.png')],
            check=True,
        )
############################################### Prepare End ###############################################
# Triplet evaluation: for every even frame index i in [2, 98], frames i-1 and i+1 are the model
# inputs and frame i is the ground-truth target (49 triplets per sequence). Each sequence is
# scored at two resolutions: 'resized-2k' (whole frame downsampled) and 'cropped-4k' (central
# native-resolution window). The tqdm bar shows the running mean PSNR/SSIM per category.
divisor = 32        # passed to InputPadder; presumably inputs are padded to a multiple of this
scale_factor = 0.5  # forwarded to model(...) as `scale_factor`
# Interpolation time of the target: frame i lies halfway between frames i-1 and i+1.
# Loop-invariant, so it is built once.
embt = torch.tensor(1 / 2).float().view(1, 1, 1, 1).to(device)


def _to_category(img, category):
    """Bring one 4096x2160 frame (an H x W x C array) to the resolution of `category`.

    'resized-2k' area-downsamples the whole frame to 2048x1080; 'cropped-4k' keeps the central
    2048x1080 window at native resolution (540 rows / 1024 columns trimmed from each side).
    Any other category returns the frame unchanged.
    """
    if category == 'resized-2k':
        # cv2 dsize is (width, height).
        return cv2.resize(src=img, dsize=(2048, 1080), fx=0.0, fy=0.0, interpolation=cv2.INTER_AREA)
    if category == 'cropped-4k':
        return img[540:-540, 1024:-1024, :]
    return img


for category in ['resized-2k', 'cropped-4k']:
    psnr_list = []
    ssim_list = []
    pbar = tqdm.tqdm(file_list, total=len(file_list))
    for file_name in pbar:
        dir_name = osp.join(root, file_name)
        for frame_idx in range(2, 99, 2):
            img0 = _to_category(read(f'{dir_name}/{frame_idx - 1:03d}.png'), category)
            img1 = _to_category(read(f'{dir_name}/{frame_idx + 1:03d}.png'), category)
            imgt = _to_category(read(f'{dir_name}/{frame_idx:03d}.png'), category)
            img0 = img2tensor(img0).to(device)
            img1 = img2tensor(img1).to(device)
            imgt = img2tensor(imgt).to(device)

            # Only the inputs are padded; the prediction is unpadded back before scoring.
            padder = InputPadder(img0.shape, divisor)
            img0, img1 = padder.pad(img0, img1)
            with torch.no_grad():
                imgt_pred = model(img0, img1, embt, scale_factor=scale_factor, eval=True)['imgt_pred']
            imgt_pred = padder.unpad(imgt_pred)

            psnr_list.append(calculate_psnr(imgt_pred, imgt))
            ssim_list.append(calculate_ssim(imgt_pred, imgt))
            # Average only after appending. The running mean then includes the current frame,
            # and the first iteration never takes np.mean of an empty list (nan + RuntimeWarning).
            avg_psnr = np.mean(psnr_list)
            avg_ssim = np.mean(ssim_list)
            desc_str = f'[{network_name}/Xiph] [{category}/{file_name}] psnr: {avg_psnr:.02f}, ssim: {avg_ssim:.04f}'
            pbar.set_description_str(desc_str)
    # Persist the final per-category result; the progress-bar text alone is easy to lose.
    print(f'[{network_name}/Xiph] [{category}] psnr: {np.mean(psnr_list):.02f}, ssim: {np.mean(ssim_list):.04f}')