|
import sys |
|
import tqdm |
|
import torch |
|
import argparse |
|
import numpy as np |
|
import os.path as osp |
|
from omegaconf import OmegaConf |
|
|
|
sys.path.append('.') |
|
from utils.utils import read, img2tensor |
|
from utils.build_utils import build_from_cfg |
|
from metrics.psnr_ssim import calculate_psnr, calculate_ssim |
|
|
|
# Command-line interface for the evaluation script.
parser = argparse.ArgumentParser(
    prog='AMT',
    description='Vimeo90K evaluation (with Test-Time Augmentation)',
)
parser.add_argument('-c', '--config', default='cfgs/AMT-S.yaml',
                    help='OmegaConf YAML config containing a `network` section')
# The short flag must start with '-'; a bare 'p' makes argparse raise
# ValueError ("invalid option string") before any argument is parsed.
parser.add_argument('-p', '--ckpt', default='pretrained/amt-s.pth',
                    help='checkpoint file containing a `state_dict` entry')
parser.add_argument('-r', '--root', default='data/vimeo_triplet',
                    help='Vimeo90K triplet root (holds tri_testlist.txt and sequences/)')
args = parser.parse_args()
|
|
|
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cfg_path = args.config
ckpt_path = args.ckpt
root = args.root

# Build the network described by the config's `network` section.
network_cfg = OmegaConf.load(cfg_path).network
network_name = network_cfg.name
model = build_from_cfg(network_cfg)

# map_location remaps stored tensors onto the selected device, so a checkpoint
# saved from GPU tensors still loads on the CPU fallback chosen above.
# NOTE(review): torch.load may unpickle arbitrary objects (depending on the
# torch version / weights_only default) -- only load checkpoints you trust.
ckpt = torch.load(ckpt_path, map_location=device)
model.load_state_dict(ckpt['state_dict'])
model = model.to(device)
# Switch modules such as dropout/batch-norm to inference behaviour.
model.eval()
|
|
|
# Load the test split listing; each entry keeps its trailing newline, which is
# stripped later inside the evaluation loop.
test_list_path = osp.join(root, 'tri_testlist.txt')
with open(test_list_path, 'r') as list_file:
    file_list = list(list_file)

# Per-sample metric values collected during evaluation.
psnr_list, ssim_list = [], []
|
|
|
# Interpolation timestep t = 0.5 shaped (1, 1, 1, 1); it is the same for every
# triplet, so build it once instead of on every iteration.
embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(device)

pbar = tqdm.tqdm(file_list, total=len(file_list))
# Pure evaluation: disabling autograd avoids building a computation graph for
# each forward pass, which otherwise wastes memory and time.
with torch.no_grad():
    for name in pbar:
        name = str(name).strip()
        # Skip blank (or single-character) lines, e.g. a trailing empty line.
        if len(name) <= 1:
            continue
        dir_path = osp.join(root, 'sequences', name)
        # im1 and im3 are fed to the model; im2 is the reference middle frame.
        I0 = img2tensor(read(osp.join(dir_path, 'im1.png'))).to(device)
        I1 = img2tensor(read(osp.join(dir_path, 'im2.png'))).to(device)
        I2 = img2tensor(read(osp.join(dir_path, 'im3.png'))).to(device)

        # Test-time augmentation: predict on the original pair and on the pair
        # flipped along dim 2 (presumably the height axis if tensors are NCHW --
        # TODO confirm), flip the second prediction back, and average the two.
        I1_pred1 = model(I0, I2, embt,
                         scale_factor=1.0, eval=True)['imgt_pred']
        I1_pred2 = model(torch.flip(I0, [2]), torch.flip(I2, [2]), embt,
                         scale_factor=1.0, eval=True)['imgt_pred']
        I1_pred = I1_pred1 / 2 + torch.flip(I1_pred2, [2]) / 2

        psnr = calculate_psnr(I1_pred, I1).detach().cpu().numpy()
        ssim = calculate_ssim(I1_pred, I1).detach().cpu().numpy()
        psnr_list.append(psnr)
        ssim_list.append(ssim)

        # Running averages over every sample evaluated so far.
        avg_psnr = np.mean(psnr_list)
        avg_ssim = np.mean(ssim_list)
        desc_str = f'[{network_name}/Vimeo90K] psnr: {avg_psnr:.02f}, ssim: {avg_ssim:.04f}'
        pbar.set_description_str(desc_str)

# Otherwise the final averages only appear in the progress-bar description;
# print them so they survive in plain (non-TTY) logs.
if psnr_list:
    print(desc_str)
|
|
|
|