VfiTest / vfi_inference_triplet.py
SuyeonJ's picture
Upload folder using huggingface_hub
8d015d4 verified
raw
history blame
6.54 kB
# [START_COMMAND]
# python3 -m vfi_inference --cuda_index 0 \
# --use_video --root ../VFI_Inference/thistest/test_video.mp4 --save_root ../VFI_Inference/thistest/results --source_frame_ext png \
# --pretrain_path ./pretrained/upr_freq002.pth \
# --pyr_level 7 --nr_lvl_skipped 0 --splat_mode average --down_scale 1 \
# --make_video --fps 0 --new_video_name test_video_vfi.mp4
# python3 -m vfi_inference_triplet --cuda_index 0 \
# --root ../VFI_Inference/thistriplet_notarget --pretrain_path ./pretrained/upr_freq002.pth \
# --pyr_level 7 --nr_lvl_skipped 0 --splat_mode average --down_scale 1
# [FILE SYSTEM]
# Under the folder given by args.root,
# inside nested sub-folders (the file search descends at most 9 levels below the root),
# each folder holds one triplet: 3 images (with GT, run with --exist_gt) or 2 images (without GT)
import argparse
import os
import cv2
import glob
import torch
import datetime
import numpy as np
from PIL import Image
from tqdm import tqdm
from torch.nn import functional as F
from torchvision.transforms import functional as TF
from utils.metrics import calculate_batch_psnr, calculate_batch_ssim
from modules.components.upr_net_freq import upr_freq as upr_freq002
from modules.components.upr_basic import upr as upr_basic
def multiple_pad(image, multiple):
    """Pad `image` on the right and bottom so H and W are multiples of `multiple`.

    `image` must report exactly four dimensions from `.size()`; the last two
    are treated as height and width. No padding is added along a dimension
    that is already an exact multiple. Padding uses `TF.pad` with its default
    fill and mode.
    """
    _batch, _channels, height, width = image.size()
    # (-x) % m is the distance from x up to the next multiple of m (0 if exact).
    extra_rows = (-height) % multiple
    extra_cols = (-width) % multiple
    # torchvision orders a 4-value padding as (left, top, right, bottom).
    return TF.pad(image, (0, 0, extra_cols, extra_rows))
print('์ธํผ๋Ÿฐ์Šค ์‹œ\n1. utils.pad.py replicate->constant๋กœ ๋ณ€๊ฒฝํ•˜๊ณ \n2. components upr Model ์ตœ์ดˆ์ธํ’‹์—์„œ normalization๊ณผ padding ์œ„์น˜ ๋ฐ”๊ฟจ๋Š”์ง€ ํ™•์ธํ•  ๊ฒƒ (padding์ด ์œ„์— ์žˆ์–ด์•ผ๋จ)')
# Frame extensions searched in every folder; matches are concatenated in this
# order, each extension's matches sorted by name.
_FRAME_EXTS = ('tif', 'TIF', 'jpg', 'png', 'tga', 'TGA')
# Extensions that are read and written through PIL instead of OpenCV.
_PIL_EXTS = ('tga', 'TGA')


def _find_frame_folders(root):
    """Return the sorted, de-duplicated folders that directly contain a file under `root`."""
    step = '/*'
    files = [
        path
        for depth in range(10)
        for path in glob.glob(f'{root}{step * depth}')
        if os.path.isfile(path)
    ]
    return sorted({os.path.split(path)[0] for path in files})


def _list_frames(folder):
    """Return the frame files of `folder`, grouped by extension and sorted within each group."""
    found = []
    for ext in _FRAME_EXTS:
        found += sorted(glob.glob(os.path.join(folder, f'*.{ext}')))
    # Where glob matching ignores case (e.g. Windows), '*.tif' and '*.TIF'
    # return the same files; drop repeats while keeping first-seen order.
    return list(dict.fromkeys(found))


def _load_frame(path, use_pil, device):
    """Read one image file as a float RGB tensor of shape (1, C, H, W) on `device`.

    Raises:
        IOError: if OpenCV cannot decode the file.
    """
    if use_pil:
        # Context manager closes the underlying file handle once decoded.
        with Image.open(path) as pil_image:
            frame = TF.to_tensor(pil_image)
        # Keep at most the first three channels (drops e.g. an alpha plane).
        return frame[:3].unsqueeze(0).to(device)
    bgr = cv2.imread(path)
    if bgr is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise IOError(f'failed to read image: {path}')
    # OpenCV yields BGR; reorder to RGB, scale to [0, 1], HWC -> 1xCxHxW.
    rgb = torch.from_numpy(bgr[:, :, [2, 1, 0]]) / 255
    return rgb.permute(2, 0, 1).unsqueeze(0).to(device)


def _save_frame(frame, path, use_pil):
    """Write a CPU float tensor of shape (C, H, W), values in [0, 1], to `path`.

    Raises:
        IOError: if OpenCV reports that the file was not written.
    """
    if use_pil:
        TF.to_pil_image(frame).save(path)
        return
    rgb = (frame.permute(1, 2, 0) * 255).numpy().astype(np.uint8)
    # OpenCV expects BGR channel order; imwrite returns False on failure.
    if not cv2.imwrite(path, rgb[:, :, [2, 1, 0]]):
        raise IOError(f'failed to write image: {path}')


def main():
    """Interpolate the middle frame for every image triplet found under --root.

    Parses the command line, loads the pretrained UPR model onto the selected
    CUDA device, then walks every folder below --root that contains files.
    Each folder must hold 2 frames (first, last), or 3 frames (first, GT
    middle, last) when --exist_gt is given. The predicted middle frame is
    written to a sibling tree named `<root>_<timestamp>`. With --exist_gt,
    per-folder PSNR/SSIM are appended to `record.txt` in that tree and the
    averages are printed at the end.

    Folders whose frame count does not match are skipped with a message.
    """
    NOW = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    parser = argparse.ArgumentParser('Video Frame Interpolation Inference', add_help=True)
    parser.add_argument('--cuda_index', default=0, type=int, help='CUDA GPU index')
    parser.add_argument('--exist_gt', action='store_true', help='whether ground-truth existing')
    parser.add_argument('--root', default='', type=str, help='root containing frames [./triplet_name]')
    parser.add_argument('--pretrain_path', default='', type=str, help='path containing pretrained model')
    parser.add_argument('--pyr_level', default=5, type=int, help='UPR-Net pyramid level')
    parser.add_argument('--nr_lvl_skipped', default=0, type=int, help='UPR-Net pyramid skip number')
    parser.add_argument('--splat_mode', default='average', type=str, help='UPR-Net warping splat mode')
    parser.add_argument('--down_scale', default=1, type=int, help='frame down-scaling factor (due to GPU memory issue)')
    args = parser.parse_args()

    # Validate through argparse instead of `assert` (asserts vanish under -O).
    # A trailing separator is tolerated and stripped so that the output folder
    # `<root>_<timestamp>` is created next to the root, not inside it.
    ROOT = args.root.rstrip('/' + os.sep)
    if not ROOT:
        parser.error("'--root' must name the folder that contains the triplets")
    if args.down_scale < 1:
        parser.error(f"'--down_scale' ({args.down_scale}) must be a positive integer")

    DEVICE = args.cuda_index
    torch.cuda.set_device(DEVICE)
    SAVE_ROOT = f'{ROOT}_{NOW}'
    os.makedirs(SAVE_ROOT, exist_ok=True)
    SCALE = args.down_scale
    expected_frames = 3 if args.exist_gt else 2

    model = upr_freq002.Model(pyr_level=args.pyr_level,
                              nr_lvl_skipped=args.nr_lvl_skipped,
                              splat_mode=args.splat_mode)
    # NOTE(review): torch.load unpickles the checkpoint; only load files from
    # a trusted source (consider weights_only=True where the torch version
    # in use supports it).
    sd = torch.load(args.pretrain_path, map_location='cpu')
    sd = sd['model'] if 'model' in sd else sd
    print(model.load_state_dict(sd))
    model = model.to(DEVICE)
    # Inference only: switch layers such as dropout/batch-norm to eval mode.
    model.eval()

    folder_list = _find_frame_folders(ROOT)
    record_path = os.path.join(SAVE_ROOT, 'record.txt')
    if args.exist_gt:
        # Start from an empty record; one row per folder is appended as soon
        # as it is processed, so partial results survive an interrupted run.
        with open(record_path, 'w', encoding='utf8'):
            pass
    psnr_list = []
    ssim_list = []
    print('@@@@@@@@@@@@@@@@@@@@Staring VFI@@@@@@@@@@@@@@@@@@@@')
    for folder in tqdm(folder_list):
        file_list = _list_frames(folder)
        if len(file_list) != expected_frames:
            # Covers folders with no images at all and incomplete triplets.
            tqdm.write(f"skipping '{folder}': expected {expected_frames} frames, found {len(file_list)}")
            continue

        # The first file's extension decides the I/O backend for the folder.
        cur_ext = os.path.splitext(file_list[0])[1][1:]
        use_pil = cur_ext in _PIL_EXTS
        img_list = [_load_frame(file, use_pil, DEVICE) for file in file_list]
        _, _, Hori, Wori = img_list[0].size()

        if args.exist_gt:
            img0, imgt, img1 = img_list
        else:
            img0, img1 = img_list
        # Only the two input frames are padded and down-scaled; the GT frame
        # stays at full resolution for the metric computation.
        img0, img1 = (
            F.interpolate(multiple_pad(img, SCALE), scale_factor=1/SCALE, mode='bicubic')
            for img in (img0, img1)
        )

        with torch.no_grad():
            result_dict, _ = model(img0, img1, pyr_level=args.pyr_level,
                                   nr_lvl_skipped=args.nr_lvl_skipped, time_step=0.5)
        # Back to the source resolution, cropping away the padding.
        out = F.interpolate(result_dict['imgt_pred'], scale_factor=SCALE, mode='bicubic')[:, :, :Hori, :Wori].clamp(0, 1)

        if args.exist_gt:
            psnr, _ = calculate_batch_psnr(imgt, out)
            ssim, _ = calculate_batch_ssim(imgt, out)
            psnr_list.append(psnr)
            ssim_list.append(ssim)

        # Mirror the location below ROOT under SAVE_ROOT. A relative path is
        # used because str.replace would rewrite every occurrence of the root
        # string, not just the leading one.
        stem, ext = os.path.splitext(file_list[1])
        newfilename = os.path.join(SAVE_ROOT, os.path.relpath(stem, ROOT))
        if args.exist_gt:
            newfile = newfilename + '_pred' + ext
        else:
            newfile = os.path.join(os.path.split(newfilename)[0], 'im_pred' + ext)
        os.makedirs(os.path.split(newfile)[0], exist_ok=True)
        _save_frame(out[0].cpu(), newfile, use_pil)

        if args.exist_gt:
            # Label each row by its path starting at the root folder's name,
            # independent of how many components the --root argument has.
            foldername = os.path.relpath(folder, os.path.dirname(ROOT) or os.curdir)
            with open(record_path, 'a', encoding='utf8') as f:
                f.write(f'{foldername:45}PSNR: {psnr:.4f} SSIM: {ssim:.4f}\n')

    # Skip the summary when nothing was evaluated (np.mean([]) would be nan).
    if args.exist_gt and psnr_list:
        print(f'PSNR: {np.mean(psnr_list):.4f}, SSIM: {np.mean(ssim_list):.6f}')
# Run the inference CLI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()