# Provenance (Hugging Face Hub page residue): uploaded by ychenhq via
# huggingface_hub, commit 04fbff5 (verified), 2.3 kB.
import sys
import tqdm
import torch
import argparse
import numpy as np
import os.path as osp
from omegaconf import OmegaConf
sys.path.append('.')
from utils.utils import read, img2tensor
from utils.build_utils import build_from_cfg
from metrics.psnr_ssim import calculate_psnr, calculate_ssim
# Command-line interface for the Vimeo90K test-time-augmentation benchmark.
#   -c/--config : OmegaConf YAML whose `network` section describes the model.
#   -p/--ckpt   : checkpoint file containing a 'state_dict' entry.
#   -r/--root   : Vimeo90K triplet root (tri_testlist.txt and sequences/ are
#                 read from it).
parser = argparse.ArgumentParser(
    prog='AMT',
    description='Vimeo90K evaluation (with Test-Time Augmentation)',
)
parser.add_argument('-c', '--config', default='cfgs/AMT-S.yaml')
# The short flag needs its leading '-': argparse rejects a bare 'p' mixed with
# '--ckpt' ("invalid option string") as soon as the argument is defined.
parser.add_argument('-p', '--ckpt', default='pretrained/amt-s.pth')
parser.add_argument('-r', '--root', default='data/vimeo_triplet')
args = parser.parse_args()
# Pick the compute device, build the network described by the config and load
# its pretrained weights.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cfg_path = args.config
ckpt_path = args.ckpt
root = args.root

network_cfg = OmegaConf.load(cfg_path).network
network_name = network_cfg.name
model = build_from_cfg(network_cfg)
# Deserialize onto the CPU first: a checkpoint saved from a GPU would otherwise
# try to restore its tensors onto CUDA and fail on CPU-only machines. The model
# is moved to `device` right after the weights are loaded.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
ckpt = torch.load(ckpt_path, map_location='cpu')
model.load_state_dict(ckpt['state_dict'])
model = model.to(device)
model.eval()
# Load the test split: each line names a sequence directory under
# <root>/sequences (short/blank lines are skipped by the evaluation loop).
with open(osp.join(root, 'tri_testlist.txt'), 'r') as list_file:
    file_list = list(list_file)

# Per-sample metric histories; their means drive the progress-bar readout.
psnr_list, ssim_list = [], []

pbar = tqdm.tqdm(file_list, total=len(file_list))
# Score every triplet: predict the middle frame (im2) from its neighbours (im1,
# im3), average with a prediction made on flipped inputs (test-time
# augmentation), and report running PSNR/SSIM means on the progress bar.
#
# embt is identical for every sample, so it is built once. Presumably it is the
# normalized target time between I0 (t=0) and I2 (t=1), i.e. the midpoint.
embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(device)
# Inference only: disabling autograd avoids building a graph per forward pass.
with torch.no_grad():
    for name in pbar:
        name = name.strip()
        # Skip blank / near-empty lines (e.g. a trailing newline at EOF).
        if len(name) <= 1:
            continue
        dir_path = osp.join(root, 'sequences', name)
        I0 = img2tensor(read(osp.join(dir_path, 'im1.png'))).to(device)
        I1 = img2tensor(read(osp.join(dir_path, 'im2.png'))).to(device)
        I2 = img2tensor(read(osp.join(dir_path, 'im3.png'))).to(device)

        # Plain prediction.
        I1_pred1 = model(I0, I2, embt,
                         scale_factor=1.0, eval=True)['imgt_pred']
        # Prediction on inputs flipped along dim 2 (the height axis if the
        # tensors are (N, C, H, W) -- TODO confirm img2tensor's layout); the
        # result is flipped back before the two predictions are averaged.
        I1_pred2 = model(torch.flip(I0, [2]), torch.flip(I2, [2]), embt,
                         scale_factor=1.0, eval=True)['imgt_pred']
        I1_pred = I1_pred1 / 2 + torch.flip(I1_pred2, [2]) / 2

        psnr = calculate_psnr(I1_pred, I1).detach().cpu().numpy()
        ssim = calculate_ssim(I1_pred, I1).detach().cpu().numpy()
        psnr_list.append(psnr)
        ssim_list.append(ssim)

        # Running means over all samples scored so far.
        avg_psnr = np.mean(psnr_list)
        avg_ssim = np.mean(ssim_list)
        desc_str = f'[{network_name}/Vimeo90K] psnr: {avg_psnr:.02f}, ssim: {avg_ssim:.04f}'
        pbar.set_description_str(desc_str)