|
import os |
|
import sys |
|
import cv2 |
|
import tqdm |
|
import glob |
|
import torch |
|
import argparse |
|
import numpy as np |
|
import os.path as osp |
|
from omegaconf import OmegaConf |
|
|
|
sys.path.append('.') |
|
from utils.utils import InputPadder, read, img2tensor |
|
from utils.build_utils import build_from_cfg |
|
from metrics.psnr_ssim import calculate_psnr, calculate_ssim |
|
|
|
parser = argparse.ArgumentParser( |
|
prog = 'AMT', |
|
description = 'Xiph evaluation', |
|
) |
|
parser.add_argument('-c', '--config', default='cfgs/AMT-S.yaml') |
|
parser.add_argument('-p', '--ckpt', default='pretrained/amt-s.pth') |
|
parser.add_argument('-r', '--root', default='data/xiph') |
|
args = parser.parse_args() |
|
|
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
|
cfg_path = args.config |
|
ckpt_path = args.ckpt |
|
root = args.root |
|
|
|
network_cfg = OmegaConf.load(cfg_path).network |
|
network_name = network_cfg.name |
|
model = build_from_cfg(network_cfg) |
|
ckpt = torch.load(ckpt_path) |
|
model.load_state_dict(ckpt['state_dict'], False) |
|
model = model.to(device) |
|
model.eval() |
|
|
|
|
|
download_links = [ |
|
'https://media.xiph.org/video/derf/ElFuente/Netflix_BoxingPractice_4096x2160_60fps_10bit_420.y4m', |
|
'https://media.xiph.org/video/derf/ElFuente/Netflix_Crosswalk_4096x2160_60fps_10bit_420.y4m', |
|
'https://media.xiph.org/video/derf/Chimera/Netflix_DrivingPOV_4096x2160_60fps_10bit_420.y4m', |
|
'https://media.xiph.org/video/derf/ElFuente/Netflix_FoodMarket_4096x2160_60fps_10bit_420.y4m', |
|
'https://media.xiph.org/video/derf/ElFuente/Netflix_FoodMarket2_4096x2160_60fps_10bit_420.y4m', |
|
'https://media.xiph.org/video/derf/ElFuente/Netflix_RitualDance_4096x2160_60fps_10bit_420.y4m', |
|
'https://media.xiph.org/video/derf/ElFuente/Netflix_SquareAndTimelapse_4096x2160_60fps_10bit_420.y4m', |
|
'https://media.xiph.org/video/derf/ElFuente/Netflix_Tango_4096x2160_60fps_10bit_420.y4m', |
|
] |
|
file_list = ['BoxingPractice', 'Crosswalk', 'DrivingPOV', 'FoodMarket', 'FoodMarket2', 'RitualDance', |
|
'SquareAndTimelapse', 'Tango'] |
|
|
|
for file_name, link in zip(file_list, download_links): |
|
data_dir = osp.join(root, file_name) |
|
if osp.exists(data_dir) is False: |
|
os.makedirs(data_dir) |
|
if len(glob.glob(f'{data_dir}/*.png')) < 100: |
|
os.system(f'ffmpeg -i {link} -pix_fmt rgb24 -vframes 100 {data_dir}/%03d.png') |
|
|
|
|
|
|
|
divisor = 32; scale_factor = 0.5 |
|
for category in ['resized-2k', 'cropped-4k']: |
|
psnr_list = [] |
|
ssim_list = [] |
|
pbar = tqdm.tqdm(file_list, total=len(file_list)) |
|
for flie_name in pbar: |
|
dir_name = osp.join(root, flie_name) |
|
for intFrame in range(2, 99, 2): |
|
img0 = read(f'{dir_name}/{intFrame - 1:03d}.png') |
|
img1 = read(f'{dir_name}/{intFrame + 1:03d}.png') |
|
imgt = read(f'{dir_name}/{intFrame:03d}.png') |
|
|
|
if category == 'resized-2k': |
|
img0 = cv2.resize(src=img0, dsize=(2048, 1080), fx=0.0, fy=0.0, interpolation=cv2.INTER_AREA) |
|
img1 = cv2.resize(src=img1, dsize=(2048, 1080), fx=0.0, fy=0.0, interpolation=cv2.INTER_AREA) |
|
imgt = cv2.resize(src=imgt, dsize=(2048, 1080), fx=0.0, fy=0.0, interpolation=cv2.INTER_AREA) |
|
|
|
elif category == 'cropped-4k': |
|
img0 = img0[540:-540, 1024:-1024, :] |
|
img1 = img1[540:-540, 1024:-1024, :] |
|
imgt = imgt[540:-540, 1024:-1024, :] |
|
img0 = img2tensor(img0).to(device) |
|
imgt = img2tensor(imgt).to(device) |
|
img1 = img2tensor(img1).to(device) |
|
embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(device) |
|
|
|
padder = InputPadder(img0.shape, divisor) |
|
img0, img1 = padder.pad(img0, img1) |
|
|
|
with torch.no_grad(): |
|
imgt_pred = model(img0, img1, embt, scale_factor=scale_factor, eval=True)['imgt_pred'] |
|
imgt_pred = padder.unpad(imgt_pred) |
|
|
|
psnr = calculate_psnr(imgt_pred, imgt) |
|
ssim = calculate_ssim(imgt_pred, imgt) |
|
|
|
avg_psnr = np.mean(psnr_list) |
|
avg_ssim = np.mean(ssim_list) |
|
psnr_list.append(psnr) |
|
ssim_list.append(ssim) |
|
desc_str = f'[{network_name}/Xiph] [{category}/{flie_name}] psnr: {avg_psnr:.02f}, ssim: {avg_ssim:.04f}' |
|
|
|
pbar.set_description_str(desc_str) |