# VfiTest / vfi_inference.py
# Uploaded by SuyeonJ via huggingface_hub (commit 8d015d4, verified; 8.11 kB)
# [START_COMMAND]
# python3 -m vfi_inference --cuda_index 0 \
# --use_video --root ../VFI_Inference/thistest/test_video.mp4 --save_root ../VFI_Inference/thistest/results --source_frame_ext png \
# --pretrain_path ./pretrained/upr_freq002.pth \
# --pyr_level 7 --nr_lvl_skipped 0 --splat_mode average --down_scale 1 \
# --make_video --fps 0 --new_video_name test_video_vfi.mp4
# python3 -m vfi_inference --cuda_index 0 \
# --root ../VFI_Inference/thistest/frames --save_root ../VFI_Inference/thistest/results --source_frame_ext png \
# --pretrain_path ./pretrained/upr_freq002.pth \
# --pyr_level 7 --nr_lvl_skipped 0 --splat_mode average --down_scale 1 \
# --make_video --fps 0 --new_video_name test_video_vfi.mp4
# [FILE SYSTEM] frame-sequence mode
# the args.root folder must contain the source frame sequence
# args.save_root must share the same parent folder as args.root; the resulting frame sequence is saved into args.save_root
# [FILE SYSTEM] video mode
# args.root is the path to the video file
# the frame sequence is extracted and saved automatically
# args.save_root must share the same parent folder as the args.root file; the resulting frame sequence is saved into args.save_root
import argparse
import os
import cv2
import glob
import torch
import numpy as np
from PIL import Image
from tqdm import tqdm
from torchvision.transforms import functional as TF
from modules.components.upr_net_freq import upr_freq as upr_freq002
from modules.components.upr_basic import upr as upr_basic
def parent_folder(path):
    """Return the parent directory of *path* ('' when there is none)."""
    return os.path.dirname(path)
print('์ธํผ๋Ÿฐ์Šค ์‹œ\n1. utils.pad.py replicate->constant๋กœ ๋ณ€๊ฒฝํ•˜๊ณ \n2. components upr Model ์ตœ์ดˆ์ธํ’‹์—์„œ normalization๊ณผ padding ์œ„์น˜ ๋ฐ”๊ฟจ๋Š”์ง€ ํ™•์ธํ•  ๊ฒƒ (padding์ด ์œ„์— ์žˆ์–ด์•ผ๋จ)')
def main():
    """Run 2x video frame interpolation over a frame sequence or a video file.

    Parses CLI arguments, optionally extracts frames from a video into a
    sibling 'frames' folder, loads a pretrained UPR-Net(frequency) model,
    writes each source frame to an even output index and each model-predicted
    midpoint frame to the odd index between its two sources, and optionally
    re-encodes the result into a new video with ffmpeg.
    """
    parser = argparse.ArgumentParser('Video Frame Interpolation Inference',add_help=True)
    parser.add_argument('--cuda_index', default=0, type=int, help='CUDA GPU index')
    parser.add_argument('--use_video', action='store_true', help='whether using video file')
    parser.add_argument('--root', default='', type=str, help='root containing frames [./videoname/frames] (or video [./videoname/videoname.mp4])')
    parser.add_argument('--save_root', default='', type=str, help='root to save result frames [./videoname/results_expname]')
    parser.add_argument('--source_frame_ext', default='png', type=str, help='source frames extension name')
    parser.add_argument('--pretrain_path', default='', type=str, help='path containing pretrained model')
    parser.add_argument('--pyr_level', default=5, type=int, help='UPR-Net pyramid level')
    parser.add_argument('--nr_lvl_skipped', default=0, type=int, help='UPR-Net pyramid skip number')
    parser.add_argument('--splat_mode', default='average', type=str, help='UPR-Net warping splat mode')
    parser.add_argument('--down_scale', default=1, type=int, help='frame down-scaling factor (due to GPU memory issue)')
    parser.add_argument('--make_video', action='store_true', help='whether merging frames and making video file')
    parser.add_argument('--fps', default=0, type=int, help='FPS before VFI')
    parser.add_argument('--new_video_name', default='newvideo', type=str, help='new video name [new_video_name.mp4]')
    args = parser.parse_args()
    # Input and output roots must share the same parent directory (see header notes).
    assert parent_folder(args.root)==parent_folder(args.save_root),\
        f"the parents of 'root' ({parent_folder(args.root)}) and save_root ({parent_folder(args.save_root)}) should be same!!"
    if args.make_video:
        # Output video name must carry an extension and must be a bare filename.
        assert os.path.splitext(args.new_video_name)[1]!='', f"'new_video_name' ({args.new_video_name}) should have extension name!!"
        assert parent_folder(args.new_video_name)=='', f"'new_video_name' should not contain directory path"
    if args.use_video:
        # Sanity check: the video must open and report at least one frame.
        temp1 = cv2.VideoCapture(args.root)
        temp2 = int(temp1.get(cv2.CAP_PROP_FRAME_COUNT))
        assert temp2>0, f"number of frames in video ({args.root}) must be larger than 0!! !!check file name!!"
        temp1.release()
        del temp1, temp2
    # Integer CUDA device index, used directly as a torch device below.
    DEVICE = args.cuda_index
    torch.cuda.set_device(DEVICE)
    VIDEO_ROOT = args.root if args.use_video else None
    # In video mode, frames are extracted into a 'frames' folder next to the video.
    FRAME_ROOT = args.root if VIDEO_ROOT is None else parent_folder(VIDEO_ROOT)+'/frames'
    SAVE_ROOT = args.save_root
    EXT = args.source_frame_ext
    SCALE = args.down_scale
    if VIDEO_ROOT is not None:
        print('@@@@@@@@@@@@@@@@@@@@Extracting frames from video@@@@@@@@@@@@@@@@@@@@')
        os.makedirs(FRAME_ROOT, exist_ok=True)
        video = cv2.VideoCapture(VIDEO_ROOT)
        this_fps = video.get(cv2.CAP_PROP_FPS)
        # Dump every frame as a zero-padded numbered image (0000.EXT, 0001.EXT, ...).
        for index in tqdm(range(int(video.get(cv2.CAP_PROP_FRAME_COUNT)))):
            _, frame = video.read()
            newfile = os.path.join(FRAME_ROOT, str(index).zfill(4)+f'.{EXT}')
            cv2.imwrite(newfile, frame)
        video.release()
    model = upr_freq002.Model(pyr_level=args.pyr_level,
                              nr_lvl_skipped=args.nr_lvl_skipped,
                              splat_mode=args.splat_mode)
    sd = torch.load(args.pretrain_path, map_location='cpu')
    # Some checkpoints wrap the weights under a 'model' key.
    sd = sd['model'] if 'model' in sd.keys() else sd
    print(model.load_state_dict(sd))
    model = model.to(DEVICE)
    file_list = sorted(glob.glob(os.path.join(FRAME_ROOT, f'*.{EXT}')))
    # Normalize source frame names to zero-padded sequential indices.
    # NOTE(review): file_list still holds the OLD paths after this rename; the
    # read loop below only works if the names were already in 0000-style order
    # (then each rename is a no-op) — confirm for arbitrarily named inputs.
    for i, file in enumerate(file_list):
        newfile = os.path.join(FRAME_ROOT, str(i).zfill(4)+f'.{EXT}')
        os.rename(file, newfile)
    if args.make_video:
        num_frame_before = len(file_list)
        # NOTE(review): in frame-sequence mode with --fps left at its default 0,
        # fps_before is 0 and the prints below divide by zero — confirm intent.
        fps_before = args.fps if not args.use_video else this_fps
        # 2x interpolation: one new frame between each consecutive pair.
        num_frame_after = 2*num_frame_before-1
        # Scale FPS so total duration stays (approximately) the same.
        fps_after = fps_before*num_frame_after/num_frame_before
        print(f'num_frame_before: {num_frame_before}, fps_before: {fps_before:.6f}, time_before: {num_frame_before/fps_before:.6f}')
        print(f'num_frame_after: {num_frame_after}, fps_after: {fps_after:.6f}, time_after: {num_frame_after/fps_after:.6f}')
        print()
    print('@@@@@@@@@@@@@@@@@@@@Staring VFI@@@@@@@@@@@@@@@@@@@@')
    os.makedirs(SAVE_ROOT, exist_ok=True)
    for frame_num, file in enumerate(tqdm(file_list)):
        # Rolling pair: previous iteration's tensor becomes the left frame.
        img0 = img1 if frame_num!=0 else None
        # Copy the source frame to its even output index (2*frame_num).
        aaa = os.path.join(SAVE_ROOT, str(frame_num*2).zfill(4)+f'.{EXT}')
        if EXT not in ['tga', 'TGA']:
            img1 = cv2.imread(file)
            cv2.imwrite(aaa, img1)
        else:
            # TGA goes through PIL; flip RGB->BGR so the channel swap below
            # matches the cv2.imread (BGR) branch.
            img1 = Image.open(file)
            img1.save(aaa)
            img1 = np.array(img1)[:,:,[2,1,0]]
        H,W,_ = img1.shape
        if SCALE==1:
            # BGR -> RGB, scale to [0,1], NCHW layout, move to the GPU.
            img1 = (torch.from_numpy(img1[:,:,[2,1,0]])/255).permute(2,0,1).unsqueeze(0).to(DEVICE)
        else:
            # Same as above but downscaled by SCALE to fit GPU memory.
            img1 = (torch.from_numpy(cv2.resize(img1, (W//SCALE,H//SCALE), interpolation=cv2.INTER_CUBIC)[:,:,[2,1,0]])/255).permute(2,0,1).unsqueeze(0).to(DEVICE)
        if img0 is None: continue  # first frame has no left neighbour yet
        with torch.no_grad():
            # Predict the temporal midpoint (time_step=0.5) between img0 and img1.
            result_dict, extra_dict = model(img0, img1, pyr_level=args.pyr_level, nr_lvl_skipped=args.nr_lvl_skipped, time_step=0.5)
        out = result_dict['imgt_pred']
        # Interpolated frame goes to the odd index between its two sources.
        bbb = os.path.join(SAVE_ROOT, str(2*frame_num-1).zfill(4)+f'.{EXT}')
        if EXT not in ['tga', 'TGA']:
            if SCALE==1:
                out = (out[0].cpu().permute(1,2,0).clamp(0,1).numpy()*255).astype(np.uint8)[:,:,[2,1,0]]
            else:
                # Upscale back to the original (W,H) before saving.
                out = cv2.resize((out[0].cpu().permute(1,2,0).clamp(0,1).numpy()*255).astype(np.uint8)[:,:,[2,1,0]], (W,H), interpolation=cv2.INTER_CUBIC)
            cv2.imwrite(bbb, out)
        else:
            if SCALE==1:
                out = TF.to_pil_image(out[0].clamp(0,1).cpu())
            else:
                out = TF.to_pil_image(TF.resize(out[0].clamp(0,1).cpu(), (H,W), interpolation=TF.InterpolationMode.BICUBIC))
            out.save(bbb)
    if args.make_video:
        # Re-encode the doubled sequence with ffmpeg at the adjusted FPS,
        # writing next to SAVE_ROOT.
        cmd = f'ffmpeg -framerate {fps_after} -i {SAVE_ROOT}/%04d.{EXT} -c:v libx264 -preset veryslow -crf 10 {parent_folder(SAVE_ROOT)}/{args.new_video_name}'
        os.system(cmd)
# Script entry point.
if __name__ == '__main__':
    main()