|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
|
|
import os |
|
import cv2 |
|
import glob |
|
import torch |
|
import datetime |
|
import numpy as np |
|
from PIL import Image |
|
from tqdm import tqdm |
|
from torch.nn import functional as F |
|
from torchvision.transforms import functional as TF |
|
from utils.metrics import calculate_batch_psnr, calculate_batch_ssim |
|
|
|
from modules.components.upr_net_freq import upr_freq as upr_freq002 |
|
from modules.components.upr_basic import upr as upr_basic |
|
|
|
def multiple_pad(image, multiple):
    """Zero-pad the bottom/right of *image* so its height and width become multiples of *multiple*.

    Args:
        image: tensor whose last two dimensions are (H, W), e.g. (N, C, H, W)
            or (C, H, W).
        multiple: positive integer the padded H and W must be divisible by.

    Returns:
        The padded tensor. Original content stays at the top-left, so
        ``padded[..., :H, :W]`` recovers the input; padded pixels are 0.

    Raises:
        ValueError: if ``multiple`` is smaller than 1.
    """
    if multiple < 1:
        raise ValueError(f'multiple must be >= 1, got {multiple}')
    height, width = image.shape[-2:]
    # (-n) % m is the distance from n up to the next multiple of m (0 when already aligned).
    pad_h = -height % multiple
    pad_w = -width % multiple
    # F.pad takes (left, right, top, bottom) for the last two dims; constant mode fills with 0.
    return F.pad(image, (0, pad_w, 0, pad_h), mode='constant', value=0.0)
|
|
|
# NOTE(review): the message below appears to be mis-encoded Korean (UTF-8 bytes
# rendered through a Thai code page). Decoded, it seems to read roughly:
#   "At inference time:
#    1. change replicate -> constant in utils.pad.py, and
#    2. check whether the positions of normalization and padding were swapped
#       at the very first input of the components' upr Model (padding must
#       come first)."
# It is a reminder to the operator, not program logic. Because it sits at module
# level (outside main()), it is printed whenever this module is imported or run.
print('์ธํผ๋ฐ์ค ์\n1. utils.pad.py replicate->constant๋ก ๋ณ๊ฒฝํ๊ณ \n2. components upr Model ์ต์ด์ธํ์์ normalization๊ณผ padding ์์น ๋ฐ๊ฟจ๋์ง ํ์ธํ ๊ฒ (padding์ด ์์ ์์ด์ผ๋จ)')
|
# Image extensions collected inside each frame folder. glob is case-sensitive on
# most filesystems, hence the explicit upper-case variants.
_FRAME_EXTS = ('tif', 'TIF', 'jpg', 'png', 'tga', 'TGA')


def _parse_args():
    """Build the CLI parser, parse ``sys.argv`` and validate the arguments.

    ``parser.error`` is used instead of ``assert`` so validation survives
    ``python -O`` and exits with a usage message.
    """
    parser = argparse.ArgumentParser('Video Frame Interpolation Inference', add_help=True)
    parser.add_argument('--cuda_index', default=0, type=int, help='CUDA GPU index')
    parser.add_argument('--exist_gt', action='store_true', help='whether ground-truth existing')
    parser.add_argument('--root', default='', type=str, help='root containing frames [./triplet_name]')
    parser.add_argument('--pretrain_path', default='', type=str, help='path containing pretrained model')
    parser.add_argument('--pyr_level', default=5, type=int, help='UPR-Net pyramid level')
    parser.add_argument('--nr_lvl_skipped', default=0, type=int, help='UPR-Net pyramid skip number')
    parser.add_argument('--splat_mode', default='average', type=str, help='UPR-Net warping splat mode')
    parser.add_argument('--down_scale', default=1, type=int, help='frame down-scaling factor (due to GPU memory issue)')
    args = parser.parse_args()

    if not args.root:
        # An empty root would turn the folder scan into '/*', '/*/*', ... over the whole filesystem.
        parser.error("'root' must be given")
    if args.root.endswith('/'):
        parser.error(f"'root' ({args.root}) must not end with '/'")
    if not os.path.isdir(args.root):
        parser.error(f"'root' ({args.root}) is not a directory")
    if not os.path.isfile(args.pretrain_path):
        parser.error(f"'pretrain_path' ({args.pretrain_path}) is not a file")
    if args.down_scale < 1:
        parser.error(f"'down_scale' must be >= 1 (got {args.down_scale})")
    return args


def _load_model(args, device):
    """Build the upr_freq002 model, load ``args.pretrain_path`` into it and move it to *device*."""
    model = upr_freq002.Model(pyr_level=args.pyr_level,
                              nr_lvl_skipped=args.nr_lvl_skipped,
                              splat_mode=args.splat_mode)
    # NOTE(review): torch.load unpickles the checkpoint -- only load trusted files.
    sd = torch.load(args.pretrain_path, map_location='cpu')
    # Checkpoints may wrap the weights under a 'model' key.
    sd = sd['model'] if 'model' in sd.keys() else sd
    print(model.load_state_dict(sd))
    model = model.to(device)
    # Inference only: switch train-time layers (dropout / batch-norm, if any) to eval behaviour.
    model.eval()
    return model


def _find_frame_folders(root, max_depth=9):
    """Return the sorted, unique parent folders of every file at depth <= *max_depth* under *root*."""
    # Escape so glob metacharacters in the root path itself ('[', '*', '?') are matched literally.
    pattern_root = glob.escape(root)
    files = [path
             for depth in range(max_depth + 1)
             for path in glob.glob(pattern_root + '/*' * depth)
             if os.path.isfile(path)]
    return sorted({os.path.dirname(path) for path in files})


def _list_frames(folder):
    """Return the supported image files directly inside *folder*, sorted by path."""
    escaped = glob.escape(folder)
    # A set removes duplicates on case-insensitive filesystems, where '*.tif' and '*.TIF' hit the same files.
    return sorted({path
                   for ext in _FRAME_EXTS
                   for path in glob.glob(os.path.join(escaped, f'*.{ext}'))})


def _read_frame(path, device, use_pil):
    """Load one image as a (1, 3, H, W) float tensor in [0, 1] on *device*.

    Raises:
        OSError: if OpenCV cannot decode the file.
    """
    if use_pil:
        # PIL path (used for TGA); convert('RGB') drops alpha and expands grayscale/palette images.
        with Image.open(path) as pil_img:
            tensor = TF.to_tensor(pil_img.convert('RGB'))
    else:
        bgr = cv2.imread(path)
        if bgr is None:
            raise OSError(f'cv2 could not read image: {path}')
        # cv2.imread returns HWC channels in BGR order; reorder to RGB, scale to [0, 1], make CHW.
        tensor = (torch.from_numpy(bgr[:, :, [2, 1, 0]]) / 255).permute(2, 0, 1)
    return tensor.unsqueeze(0).to(device)


def _save_frame(out, path, use_pil):
    """Write the first image of *out* (N, 3, H, W floats in [0, 1]) to *path* as 8-bit RGB.

    Raises:
        OSError: if OpenCV fails to write the file.
    """
    # Round (not truncate) to 8 bits so pixels are not biased downward by up to one level.
    img_u8 = (out[0] * 255).round().to(torch.uint8).cpu()
    if use_pil:
        TF.to_pil_image(img_u8).save(path)
    else:
        # CHW -> HWC and RGB -> BGR for OpenCV; fancy indexing yields a contiguous copy.
        if not cv2.imwrite(path, img_u8.permute(1, 2, 0).numpy()[:, :, [2, 1, 0]]):
            raise OSError(f'cv2 could not write image: {path}')


def _resize(img, size):
    """Bicubically resize an (N, C, H, W) tensor to *size* = (H, W).

    Returns *img* unchanged when it already has that size (bicubic resampling
    at scale 1 is the identity). Passing an explicit size avoids the
    float ``scale_factor=1/s`` flooring the output one pixel short.
    """
    if tuple(img.shape[-2:]) == tuple(size):
        return img
    return F.interpolate(img, size=tuple(size), mode='bicubic')


def main():
    """Interpolate the middle frame for every frame folder found under ``--root``.

    Predictions are written under ``<root>_<timestamp>``, mirroring the input
    tree. With ``--exist_gt`` each folder must hold three frames (the middle one
    is ground truth), the prediction is saved as ``<gt name>_pred<ext>``, and
    per-folder PSNR/SSIM are appended to ``record.txt``. Otherwise each folder
    must hold two frames and the prediction is saved as ``im_pred<ext>``.

    Raises:
        ValueError: if a folder contains a number of frames other than expected.
    """
    now = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    args = _parse_args()

    device = args.cuda_index
    torch.cuda.set_device(device)
    root = args.root
    save_root = f'{root}_{now}'
    os.makedirs(save_root, exist_ok=True)
    scale = args.down_scale

    model = _load_model(args, device)

    folder_list = _find_frame_folders(root)
    record_path = os.path.join(save_root, 'record.txt')
    if args.exist_gt:
        # Start an empty record file; one line per folder is appended below.
        with open(record_path, 'w', encoding='utf8'):
            pass
    psnr_list = []
    ssim_list = []
    n_expected = 3 if args.exist_gt else 2

    print('@@@@@@@@@@@@@@@@@@@@Starting VFI@@@@@@@@@@@@@@@@@@@@')
    for folder in tqdm(folder_list):
        file_list = _list_frames(folder)
        if not file_list:
            # The folder holds files, but none with a supported image extension.
            continue
        if len(file_list) != n_expected:
            raise ValueError(f'{folder}: expected {n_expected} frames, found {len(file_list)}')
        # The first file's extension picks the reader/writer for the whole folder.
        use_pil = os.path.splitext(file_list[0])[1][1:].lower() == 'tga'
        img_list = [_read_frame(path, device, use_pil) for path in file_list]

        _, _, h_ori, w_ori = img_list[0].size()
        if args.exist_gt:
            # The ground-truth middle frame stays at full resolution for the metrics.
            img0, imgt, img1 = img_list
        else:
            img0, img1 = img_list

        # Pad to a multiple of `scale` so the down-scaled size is exact, then down-scale.
        img0 = multiple_pad(img0, scale)
        img1 = multiple_pad(img1, scale)
        h_pad, w_pad = img0.shape[-2:]
        small_size = (h_pad // scale, w_pad // scale)
        img0 = _resize(img0, small_size)
        img1 = _resize(img1, small_size)

        with torch.no_grad():
            result_dict, _extra = model(img0, img1, pyr_level=args.pyr_level,
                                        nr_lvl_skipped=args.nr_lvl_skipped, time_step=0.5)
        pred = result_dict['imgt_pred']
        # Up-scale by `scale`, crop the padding away, and clamp to the valid range.
        out = _resize(pred, (pred.shape[-2] * scale, pred.shape[-1] * scale))
        out = out[:, :, :h_ori, :w_ori].clamp(0, 1)

        if args.exist_gt:
            psnr, _ = calculate_batch_psnr(imgt, out)
            ssim, _ = calculate_batch_ssim(imgt, out)
            psnr_list.append(psnr)
            ssim_list.append(ssim)

        filepath, ext = os.path.splitext(file_list[1])
        # Mirror the input tree under save_root. str.replace would also rewrite any
        # later occurrence of `root` inside the path.
        newfilename = os.path.join(save_root, os.path.relpath(filepath, root))
        if args.exist_gt:
            newfile = newfilename + '_pred' + ext
        else:
            newfile = os.path.join(os.path.dirname(newfilename), 'im_pred' + ext)
        os.makedirs(os.path.dirname(newfile), exist_ok=True)
        _save_frame(out, newfile, use_pil)

        if args.exist_gt:
            foldername = os.path.relpath(folder, root)
            with open(record_path, 'a', encoding='utf8') as f:
                f.write(f'{foldername:45}PSNR: {psnr:.4f} SSIM: {ssim:.4f}\n')

    if args.exist_gt:
        if psnr_list:
            print(f'PSNR: {np.mean(psnr_list):.4f}, SSIM: {np.mean(ssim_list):.6f}')
        else:
            print('No frame folders were evaluated; nothing to average.')


if __name__ == '__main__':
    main()