# [START_COMMAND]
# python3 -m vfi_inference --cuda_index 0 \
# --use_video --root ../VFI_Inference/thistest/test_video.mp4 --save_root ../VFI_Inference/thistest/results --source_frame_ext png \
# --pretrain_path ./pretrained/upr_freq002.pth \
# --pyr_level 7 --nr_lvl_skipped 0 --splat_mode average --down_scale 1 \
# --make_video --fps 0 --new_video_name test_video_vfi.mp4
# python3 -m vfi_inference --cuda_index 0 \
# --root ../VFI_Inference/thistest/frames --save_root ../VFI_Inference/thistest/results --source_frame_ext png \
# --pretrain_path ./pretrained/upr_freq002.pth \
# --pyr_level 7 --nr_lvl_skipped 0 --splat_mode average --down_scale 1 \
# --make_video --fps 0 --new_video_name test_video_vfi.mp4
# [FILE SYSTEM] frame-sequence input
# the frame sequence must be collected in the args.root folder
# the args.save_root folder must share its parent folder with args.root; the resulting frame sequence is saved in args.save_root
# [FILE SYSTEM] video input
# args.root is the path to the video file
# the extracted frame sequence is saved automatically (to <parent of args.root>/frames)
# the args.save_root folder must share its parent folder with the args.root file; the resulting frame sequence is saved in args.save_root
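# Example layout (matching the sample commands above):
#   VFI_Inference/thistest/
#       test_video.mp4  <- args.root when --use_video is set
#       frames/         <- args.root for frame-sequence input (created automatically in video mode)
#       results/        <- args.save_root; the interpolated frame sequence is written here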
import argparse
import os
import cv2
import glob
import torch
import numpy as np
from PIL import Image
from tqdm import tqdm
from torchvision.transforms import functional as TF
from modules.components.upr_net_freq import upr_freq as upr_freq002
from modules.components.upr_basic import upr as upr_basic
def parent_folder(path):
    return os.path.split(path)[0]


print('Inference checklist:\n1. confirm utils.pad.py was changed from replicate to constant padding\n2. confirm that normalization and padding were swapped in the initial setup of the components upr Model (padding must come first)')


def main():
    parser = argparse.ArgumentParser('Video Frame Interpolation Inference', add_help=True)
    parser.add_argument('--cuda_index', default=0, type=int, help='CUDA GPU index')
    parser.add_argument('--use_video', action='store_true', help='whether the input is a video file')
    parser.add_argument('--root', default='', type=str, help='root containing frames [./videoname/frames] (or video [./videoname/videoname.mp4])')
    parser.add_argument('--save_root', default='', type=str, help='root to save result frames [./videoname/results_expname]')
    parser.add_argument('--source_frame_ext', default='png', type=str, help='extension of the source frames')
    parser.add_argument('--pretrain_path', default='', type=str, help='path to the pretrained model')
    parser.add_argument('--pyr_level', default=5, type=int, help='UPR-Net pyramid level')
    parser.add_argument('--nr_lvl_skipped', default=0, type=int, help='number of UPR-Net pyramid levels to skip')
    parser.add_argument('--splat_mode', default='average', type=str, help='UPR-Net warping splat mode')
    parser.add_argument('--down_scale', default=1, type=int, help='frame down-scaling factor (to reduce GPU memory usage)')
    parser.add_argument('--make_video', action='store_true', help='whether to merge the result frames into a video file')
    parser.add_argument('--fps', default=0, type=int, help='FPS of the source frames before VFI (required for frame-sequence input; ignored when --use_video is set)')
    parser.add_argument('--new_video_name', default='newvideo', type=str, help='new video name [new_video_name.mp4]')
    args = parser.parse_args()

    assert parent_folder(args.root) == parent_folder(args.save_root), \
        f"the parent folders of 'root' ({parent_folder(args.root)}) and 'save_root' ({parent_folder(args.save_root)}) must be the same!!"
    if args.make_video:
        assert os.path.splitext(args.new_video_name)[1] != '', f"'new_video_name' ({args.new_video_name}) must have an extension!!"
        assert parent_folder(args.new_video_name) == '', f"'new_video_name' must not contain a directory path"
    if args.use_video:
        temp1 = cv2.VideoCapture(args.root)
        temp2 = int(temp1.get(cv2.CAP_PROP_FRAME_COUNT))
        assert temp2 > 0, f"number of frames in video ({args.root}) must be larger than 0!! check the file name!!"
        temp1.release()
        del temp1, temp2
    DEVICE = args.cuda_index
    torch.cuda.set_device(DEVICE)
    VIDEO_ROOT = args.root if args.use_video else None
    FRAME_ROOT = args.root if VIDEO_ROOT is None else parent_folder(VIDEO_ROOT) + '/frames'
    SAVE_ROOT = args.save_root
    EXT = args.source_frame_ext
    SCALE = args.down_scale
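    # With --down_scale k > 1, each frame is fed to the model at (W//k, H//k) resolution to save GPU memory,
    # and the interpolated result is resized back to the original (W, H) before it is saved.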
    if VIDEO_ROOT is not None:
        print('@@@@@@@@@@@@@@@@@@@@Extracting frames from video@@@@@@@@@@@@@@@@@@@@')
        os.makedirs(FRAME_ROOT, exist_ok=True)
        video = cv2.VideoCapture(VIDEO_ROOT)
        this_fps = video.get(cv2.CAP_PROP_FPS)
        for index in tqdm(range(int(video.get(cv2.CAP_PROP_FRAME_COUNT)))):
            _, frame = video.read()
            newfile = os.path.join(FRAME_ROOT, str(index).zfill(4) + f'.{EXT}')
            cv2.imwrite(newfile, frame)
        video.release()
    model = upr_freq002.Model(pyr_level=args.pyr_level,
                              nr_lvl_skipped=args.nr_lvl_skipped,
                              splat_mode=args.splat_mode)
    sd = torch.load(args.pretrain_path, map_location='cpu')
    sd = sd['model'] if 'model' in sd.keys() else sd
    print(model.load_state_dict(sd))
    model = model.to(DEVICE)

    file_list = sorted(glob.glob(os.path.join(FRAME_ROOT, f'*.{EXT}')))
    for i, file in enumerate(file_list):
        newfile = os.path.join(FRAME_ROOT, str(i).zfill(4) + f'.{EXT}')
        os.rename(file, newfile)
    file_list = sorted(glob.glob(os.path.join(FRAME_ROOT, f'*.{EXT}')))  # re-list after renaming so the paths stay valid
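    # After renaming, the source frames are numbered 0000.EXT, 0001.EXT, ... so that the interleaved
    # output naming below and ffmpeg's %04d input pattern line up.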
    if args.make_video:
        num_frame_before = len(file_list)
        fps_before = args.fps if not args.use_video else this_fps  # for frame-sequence input the FPS must be given via --fps
        num_frame_after = 2*num_frame_before - 1
        fps_after = fps_before*num_frame_after/num_frame_before
        print(f'num_frame_before: {num_frame_before}, fps_before: {fps_before:.6f}, time_before: {num_frame_before/fps_before:.6f}')
        print(f'num_frame_after: {num_frame_after}, fps_after: {fps_after:.6f}, time_after: {num_frame_after/fps_after:.6f}')
        print()
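    # Doubling example: 100 source frames at 24 fps (4.1667 s) become 2*100-1 = 199 frames at
    # 24*199/100 = 47.76 fps, so the clip duration stays 199/47.76 = 4.1667 s.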
    print('@@@@@@@@@@@@@@@@@@@@Starting VFI@@@@@@@@@@@@@@@@@@@@')
    os.makedirs(SAVE_ROOT, exist_ok=True)
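    # Output naming: source frame i is copied to SAVE_ROOT as index 2*i, and the frame interpolated between
    # source frames i-1 and i is written as index 2*i-1, giving an alternating original/interpolated sequence.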
    for frame_num, file in enumerate(tqdm(file_list)):
        img0 = img1 if frame_num != 0 else None  # the previously loaded frame becomes the first input
        save_path = os.path.join(SAVE_ROOT, str(frame_num*2).zfill(4) + f'.{EXT}')
        if EXT not in ['tga', 'TGA']:
            img1 = cv2.imread(file)
            cv2.imwrite(save_path, img1)
        else:
            img1 = Image.open(file)
            img1.save(save_path)
            img1 = np.array(img1)[:, :, [2, 1, 0]]
        H, W, _ = img1.shape
        if SCALE == 1:
            img1 = (torch.from_numpy(img1[:, :, [2, 1, 0]])/255).permute(2, 0, 1).unsqueeze(0).to(DEVICE)
        else:
            img1 = (torch.from_numpy(cv2.resize(img1, (W//SCALE, H//SCALE), interpolation=cv2.INTER_CUBIC)[:, :, [2, 1, 0]])/255).permute(2, 0, 1).unsqueeze(0).to(DEVICE)
        if img0 is None:
            continue
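        # Interpolate the temporal midpoint (time_step=0.5) between the previous frame (img0) and the current frame (img1).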
        with torch.no_grad():
            result_dict, extra_dict = model(img0, img1, pyr_level=args.pyr_level, nr_lvl_skipped=args.nr_lvl_skipped, time_step=0.5)
            out = result_dict['imgt_pred']
        interp_path = os.path.join(SAVE_ROOT, str(2*frame_num - 1).zfill(4) + f'.{EXT}')
        if EXT not in ['tga', 'TGA']:
            if SCALE == 1:
                out = (out[0].cpu().permute(1, 2, 0).clamp(0, 1).numpy()*255).astype(np.uint8)[:, :, [2, 1, 0]]
            else:
                out = cv2.resize((out[0].cpu().permute(1, 2, 0).clamp(0, 1).numpy()*255).astype(np.uint8)[:, :, [2, 1, 0]], (W, H), interpolation=cv2.INTER_CUBIC)
            cv2.imwrite(interp_path, out)
        else:
            if SCALE == 1:
                out = TF.to_pil_image(out[0].clamp(0, 1).cpu())
            else:
                out = TF.to_pil_image(TF.resize(out[0].clamp(0, 1).cpu(), (H, W), interpolation=TF.InterpolationMode.BICUBIC))
            out.save(interp_path)
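
    # Optionally re-encode the interpolated frame sequence into a video: libx264, near-lossless (-crf 10),
    # written to the parent of SAVE_ROOT under --new_video_name (e.g. ../VFI_Inference/thistest/test_video_vfi.mp4
    # for the sample commands above).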
    if args.make_video:
        cmd = f'ffmpeg -framerate {fps_after} -i {SAVE_ROOT}/%04d.{EXT} -c:v libx264 -preset veryslow -crf 10 {parent_folder(SAVE_ROOT)}/{args.new_video_name}'
        os.system(cmd)


if __name__ == '__main__':
    main()