# [START_COMMAND]
# python3 -m vfi_inference --cuda_index 0 \
#     --use_video --root ../VFI_Inference/thistest/test_video.mp4 --save_root ../VFI_Inference/thistest/results --source_frame_ext png \
#     --pretrain_path ./pretrained/upr_freq002.pth \
#     --pyr_level 7 --nr_lvl_skipped 0 --splat_mode average --down_scale 1 \
#     --make_video --fps 0 --new_video_name test_video_vfi.mp4
# python3 -m vfi_inference --cuda_index 0 \
#     --root ../VFI_Inference/thistest/frames --save_root ../VFI_Inference/thistest/results --source_frame_ext png \
#     --pretrain_path ./pretrained/upr_freq002.pth \
#     --pyr_level 7 --nr_lvl_skipped 0 --splat_mode average --down_scale 1 \
#     --make_video --fps 0 --new_video_name test_video_vfi.mp4
#
# [FILE SYSTEM] frame-sequence mode
#   - args.root must be a folder containing the source frame sequence
#   - args.save_root must share the same parent folder as args.root;
#     the interpolated frame sequence is written into args.save_root
# [FILE SYSTEM] video mode
#   - args.root is the path of a video file
#   - the frame sequence is extracted automatically (into <parent>/frames)
#   - args.save_root must share the same parent folder as the args.root file;
#     the interpolated frame sequence is written into args.save_root

import argparse
import glob
import os
import subprocess

import cv2
import numpy as np
import torch
from PIL import Image
from torchvision.transforms import functional as TF
from tqdm import tqdm

from modules.components.upr_net_freq import upr_freq as upr_freq002
from modules.components.upr_basic import upr as upr_basic  # kept: alternative (basic) UPR model


def parent_folder(path):
    """Return the parent directory of *path* ('' when path has no directory part)."""
    return os.path.split(path)[0]


# NOTE: runtime reminder printed on import (kept verbatim from the original script).
print('인퍼런스 시\n1. utils.pad.py replicate->constant로 변경하고\n2. components upr Model 최초인풋에서 normalization과 padding 위치 바꿨는지 확인할 것 (padding이 위에 있어야됨)')


def main():
    """Run 2x video frame interpolation (UPR-Net) over a frame sequence or a video file.

    Pipeline:
      1. (video mode) extract source frames into <parent>/frames
      2. load the pretrained UPR-Net-freq model
      3. normalize source frame names to zero-padded indices
      4. for every consecutive pair, predict the middle frame (time_step=0.5)
         and write even-indexed copies / odd-indexed interpolations to save_root
      5. (optional) re-encode the result sequence to a video with ffmpeg
    """
    parser = argparse.ArgumentParser('Video Frame Interpolation Inference', add_help=True)
    parser.add_argument('--cuda_index', default=0, type=int, help='CUDA GPU index')
    parser.add_argument('--use_video', action='store_true', help='whether using video file')
    parser.add_argument('--root', default='', type=str,
                        help='root containing frames [./videoname/frames] (or video [./videoname/videoname.mp4])')
    parser.add_argument('--save_root', default='', type=str,
                        help='root to save result frames [./videoname/results_expname]')
    parser.add_argument('--source_frame_ext', default='png', type=str, help='source frames extension name')
    parser.add_argument('--pretrain_path', default='', type=str, help='path containing pretrained model')
    parser.add_argument('--pyr_level', default=5, type=int, help='UPR-Net pyramid level')
    parser.add_argument('--nr_lvl_skipped', default=0, type=int, help='UPR-Net pyramid skip number')
    parser.add_argument('--splat_mode', default='average', type=str, help='UPR-Net warping splat mode')
    parser.add_argument('--down_scale', default=1, type=int,
                        help='frame down-scaling factor (due to GPU memory issue)')
    parser.add_argument('--make_video', action='store_true',
                        help='whether merging frames and making video file')
    parser.add_argument('--fps', default=0, type=int, help='FPS before VFI')
    parser.add_argument('--new_video_name', default='newvideo', type=str,
                        help='new video name [new_video_name.mp4]')
    args = parser.parse_args()

    # --- argument sanity checks -------------------------------------------------
    assert parent_folder(args.root) == parent_folder(args.save_root),\
        f"the parents of 'root' ({parent_folder(args.root)}) and save_root ({parent_folder(args.save_root)}) should be same!!"
    if args.make_video:
        assert os.path.splitext(args.new_video_name)[1] != '',\
            f"'new_video_name' ({args.new_video_name}) should have extension name!!"
        assert parent_folder(args.new_video_name) == '',\
            f"'new_video_name' should not contain directory path"
        # BUGFIX: the original crashed later with ZeroDivisionError when
        # --make_video was requested without --use_video and --fps was left at 0.
        if not args.use_video:
            assert args.fps > 0, "'--fps' must be > 0 when using '--make_video' without '--use_video'!!"
    if args.use_video:
        probe = cv2.VideoCapture(args.root)
        n_source_frames = int(probe.get(cv2.CAP_PROP_FRAME_COUNT))
        probe.release()
        assert n_source_frames > 0,\
            f"number of frames in video ({args.root}) must be larger than 0!! !!check file name!!"

    DEVICE = args.cuda_index
    torch.cuda.set_device(DEVICE)
    VIDEO_ROOT = args.root if args.use_video else None
    FRAME_ROOT = args.root if VIDEO_ROOT is None else parent_folder(VIDEO_ROOT) + '/frames'
    SAVE_ROOT = args.save_root
    EXT = args.source_frame_ext
    SCALE = args.down_scale

    # --- (video mode) extract source frames ------------------------------------
    this_fps = None  # defined only in video mode; guarded below
    if VIDEO_ROOT is not None:
        print('@@@@@@@@@@@@@@@@@@@@Extracting frames from video@@@@@@@@@@@@@@@@@@@@')
        os.makedirs(FRAME_ROOT, exist_ok=True)
        video = cv2.VideoCapture(VIDEO_ROOT)
        this_fps = video.get(cv2.CAP_PROP_FPS)
        for index in tqdm(range(int(video.get(cv2.CAP_PROP_FRAME_COUNT)))):
            ok, frame = video.read()
            if not ok:
                # BUGFIX: the original ignored the read status and would pass
                # None to cv2.imwrite on a decode failure; stop cleanly instead.
                break
            newfile = os.path.join(FRAME_ROOT, str(index).zfill(4) + f'.{EXT}')
            cv2.imwrite(newfile, frame)
        video.release()

    # --- load pretrained model --------------------------------------------------
    model = upr_freq002.Model(pyr_level=args.pyr_level,
                              nr_lvl_skipped=args.nr_lvl_skipped,
                              splat_mode=args.splat_mode)
    sd = torch.load(args.pretrain_path, map_location='cpu')
    sd = sd['model'] if 'model' in sd.keys() else sd
    print(model.load_state_dict(sd))
    model = model.to(DEVICE)

    # --- normalize source frame names to zero-padded indices --------------------
    # BUGFIX: the original renamed the files but kept iterating the stale
    # (pre-rename) paths, so any sequence not already zero-padded failed to load.
    file_list = sorted(glob.glob(os.path.join(FRAME_ROOT, f'*.{EXT}')))
    renamed_list = []
    for i, file in enumerate(file_list):
        newfile = os.path.join(FRAME_ROOT, str(i).zfill(4) + f'.{EXT}')
        if file != newfile:
            os.rename(file, newfile)
        renamed_list.append(newfile)
    file_list = renamed_list
    assert len(file_list) > 0, f"no '*.{EXT}' frames found in {FRAME_ROOT}!!"

    if args.make_video:
        num_frame_before = len(file_list)
        fps_before = args.fps if not args.use_video else this_fps
        num_frame_after = 2 * num_frame_before - 1
        # keep total duration identical: fps scales with the frame count
        fps_after = fps_before * num_frame_after / num_frame_before
        print(f'num_frame_before: {num_frame_before}, fps_before: {fps_before:.6f}, time_before: {num_frame_before/fps_before:.6f}')
        print(f'num_frame_after: {num_frame_after}, fps_after: {fps_after:.6f}, time_after: {num_frame_after/fps_after:.6f}')
        print()

    # --- interpolate ------------------------------------------------------------
    print('@@@@@@@@@@@@@@@@@@@@Starting VFI@@@@@@@@@@@@@@@@@@@@')
    os.makedirs(SAVE_ROOT, exist_ok=True)
    img1 = None
    for frame_num, file in enumerate(tqdm(file_list)):
        img0 = img1 if frame_num != 0 else None  # previous frame tensor (already on DEVICE)
        even_path = os.path.join(SAVE_ROOT, str(frame_num * 2).zfill(4) + f'.{EXT}')
        # Load the current frame, copy it to its even output slot, and keep an
        # RGB float tensor on the GPU. TGA is handled through PIL (OpenCV cannot
        # write it); in that branch the array is flipped to BGR so both branches
        # hold BGR before the shared [2,1,0] flip back to RGB below.
        if EXT not in ['tga', 'TGA']:
            img1 = cv2.imread(file)
            cv2.imwrite(even_path, img1)
        else:
            img1 = Image.open(file)
            img1.save(even_path)
            img1 = np.array(img1)[:, :, [2, 1, 0]]
        H, W, _ = img1.shape
        if SCALE == 1:
            img1 = (torch.from_numpy(img1[:, :, [2, 1, 0]]) / 255).permute(2, 0, 1).unsqueeze(0).to(DEVICE)
        else:
            # down-scale before inference to fit GPU memory; output is up-scaled back
            img1 = (torch.from_numpy(
                cv2.resize(img1, (W // SCALE, H // SCALE), interpolation=cv2.INTER_CUBIC)[:, :, [2, 1, 0]]
            ) / 255).permute(2, 0, 1).unsqueeze(0).to(DEVICE)
        if img0 is None:
            continue  # first frame: nothing to interpolate against yet

        with torch.no_grad():
            result_dict, extra_dict = model(img0, img1,
                                            pyr_level=args.pyr_level,
                                            nr_lvl_skipped=args.nr_lvl_skipped,
                                            time_step=0.5)
        out = result_dict['imgt_pred']
        odd_path = os.path.join(SAVE_ROOT, str(2 * frame_num - 1).zfill(4) + f'.{EXT}')
        if EXT not in ['tga', 'TGA']:
            if SCALE == 1:
                out = (out[0].cpu().permute(1, 2, 0).clamp(0, 1).numpy() * 255).astype(np.uint8)[:, :, [2, 1, 0]]
            else:
                out = cv2.resize(
                    (out[0].cpu().permute(1, 2, 0).clamp(0, 1).numpy() * 255).astype(np.uint8)[:, :, [2, 1, 0]],
                    (W, H), interpolation=cv2.INTER_CUBIC)
            cv2.imwrite(odd_path, out)
        else:
            if SCALE == 1:
                out = TF.to_pil_image(out[0].clamp(0, 1).cpu())
            else:
                out = TF.to_pil_image(
                    TF.resize(out[0].clamp(0, 1).cpu(), (H, W), interpolation=TF.InterpolationMode.BICUBIC))
            out.save(odd_path)

    # --- (optional) encode result frames to a video -----------------------------
    if args.make_video:
        # subprocess with an argument list instead of os.system: safe with paths
        # containing spaces and not vulnerable to shell injection.
        cmd = ['ffmpeg',
               '-framerate', str(fps_after),
               '-i', f'{SAVE_ROOT}/%04d.{EXT}',
               '-c:v', 'libx264', '-preset', 'veryslow', '-crf', '10',
               os.path.join(parent_folder(SAVE_ROOT), args.new_video_name)]
        subprocess.run(cmd)


if __name__ == '__main__':
    main()