"""Benchmark the RIFE model on UCF101 frame triplets, reporting PSNR and SSIM.

For every sample directory under ``path`` the script loads ``frame_00.png``
and ``frame_02.png``, asks the model to predict the frame in between, and
compares the prediction against ``frame_01_gt.png``.  The averages over all
samples are printed at the end.
"""
import os
import sys

# Must run before the `model.*` imports below: it puts the current working
# directory on the import path so the local `model` package can be found.
sys.path.append('.')

import cv2
import math
import torch
import argparse
import numpy as np
from torch.nn import functional as F
from model.pytorch_msssim import ssim_matlab
from model.RIFE import Model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Model()
model.load_model('train_log')
model.eval()
# NOTE(review): presumably moves the network onto the selected device;
# confirm against model.RIFE.Model.
model.device()

path = 'UCF101/ucf101_interp_ours/'
# os.listdir returns entries in arbitrary order; sort so runs are reproducible.
dirs = sorted(os.listdir(path))
psnr_list = []
ssim_list = []
print(len(dirs))


def _load_frame(filename):
    """Read an image file into a float tensor of shape (1, C, H, W) in [0, 1].

    Raises:
        FileNotFoundError: if OpenCV cannot read ``filename``.
    """
    image = cv2.imread(filename)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than on `.transpose`.
        raise FileNotFoundError("Could not read image: {}".format(filename))
    # Channels-last -> channels-first, scale 8-bit values to [0, 1], then add
    # a leading batch dimension.
    return torch.tensor(image.transpose(2, 0, 1) / 255.).to(device).float().unsqueeze(0)


# Evaluation only: disable autograd so no graph is kept for each sample.
with torch.no_grad():
    for d in dirs:
        sample_dir = os.path.join(path, d)
        img0 = _load_frame(os.path.join(sample_dir, 'frame_00.png'))
        img1 = _load_frame(os.path.join(sample_dir, 'frame_02.png'))
        gt = _load_frame(os.path.join(sample_dir, 'frame_01_gt.png'))

        # Drop the batch dimension of the model output.
        pred = model.inference(img0, img1)[0]

        # Both metrics are computed on the prediction snapped to the 256
        # levels of an 8-bit image.
        ssim = ssim_matlab(gt, torch.round(pred * 255).unsqueeze(0) / 255.).detach().cpu().numpy()

        out = pred.detach().cpu().numpy().transpose(1, 2, 0)
        out = np.round(out * 255) / 255.
        gt = gt[0].cpu().numpy().transpose(1, 2, 0)

        # PSNR with a peak value of 1.0, since pixel values lie in [0, 1].
        mse = ((gt - out) * (gt - out)).mean()
        # A perfect match has zero error; log10(0) would raise ValueError.
        psnr = -10 * math.log10(mse) if mse > 0 else float('inf')

        psnr_list.append(psnr)
        ssim_list.append(ssim)

print("Avg PSNR: {} SSIM: {}".format(np.mean(psnr_list), np.mean(ssim_list)))