# Originally published on the Hugging Face Hub by ychenhq
# ("Upload folder using huggingface_hub", commit 04fbff5, verified; 1.71 kB).
import sys
import tqdm
import torch
import argparse
import numpy as np
from omegaconf import OmegaConf
sys.path.append('.')
from utils.build_utils import build_from_cfg
from datasets.gopro_datasets import GoPro_Test_Dataset
from metrics.psnr_ssim import calculate_psnr, calculate_ssim
# Command-line interface for the GoPro benchmark: model config file, trained
# checkpoint, and dataset root directory. Every option has a default; the
# defaults are relative paths, so they resolve against the current working
# directory.
parser = argparse.ArgumentParser(prog='AMT', description='GOPRO evaluation')
_CLI_OPTIONS = (
    (('-c', '--config'), 'cfgs/AMT-S_gopro.yaml'),
    (('-p', '--ckpt'), 'pretrained/gopro_amt-s.pth'),
    (('-r', '--root'), 'data/GOPRO'),
)
for _flags, _default in _CLI_OPTIONS:
    parser.add_argument(*_flags, default=_default)
args = parser.parse_args()
# Select the compute device, falling back to the CPU when CUDA is unavailable.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cfg_path = args.config
ckpt_path = args.ckpt
root = args.root
# Build the network described by the `network` section of the config file.
network_cfg = OmegaConf.load(cfg_path).network
network_name = network_cfg.name
model = build_from_cfg(network_cfg)
# Deserialize onto the CPU first. Without map_location, tensors saved from a
# GPU run are restored onto CUDA, which raises on a CPU-only machine even
# though `device` above deliberately falls back to 'cpu'. The model is moved
# to `device` right after the weights are loaded.
ckpt = torch.load(ckpt_path, map_location='cpu')
# The checkpoint is expected to hold the model weights under 'state_dict'.
model.load_state_dict(ckpt['state_dict'])
model = model.to(device)
model.eval()
# Evaluate the model on the GoPro test set, showing the running-average
# PSNR / SSIM over all samples seen so far in the progress-bar description.
dataset = GoPro_Test_Dataset(dataset_dir=root)
psnr_list = []
ssim_list = []
pbar = tqdm.tqdm(dataset, total=len(dataset))
for data in pbar:
    # Move every field of the sample to `device` and prepend a batch
    # dimension of size 1; samples are evaluated one at a time.
    # NOTE(review): assumes every value in `data` is a torch.Tensor --
    # confirm against GoPro_Test_Dataset.
    input_dict = {k: v.to(device).unsqueeze(0) for k, v in data.items()}
    with torch.no_grad():
        # The whole sample dict (including the ground-truth 'imgt') is
        # passed as keyword arguments; the prediction is read back from the
        # output dict under 'imgt_pred'.
        imgt_pred = model(**input_dict)['imgt_pred']
        psnr = calculate_psnr(imgt_pred, input_dict['imgt'])
        ssim = calculate_ssim(imgt_pred, input_dict['imgt'])
    psnr_list.append(psnr)
    ssim_list.append(ssim)
    avg_psnr = np.mean(psnr_list)
    avg_ssim = np.mean(ssim_list)
    desc_str = f'[{network_name}/GOPRO] psnr: {avg_psnr:.02f}, ssim: {avg_ssim:.04f}'
    pbar.set_description_str(desc_str)

# tqdm redraws its description in place, so also print the final averages
# as a plain line that survives in logs. Skipped for an empty dataset, where
# the loop never ran and `desc_str` was never assigned.
if psnr_list:
    print(desc_str)