|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
|
|
import os |
|
import cv2 |
|
import glob |
|
import torch |
|
import numpy as np |
|
from PIL import Image |
|
from tqdm import tqdm |
|
from torchvision.transforms import functional as TF |
|
|
|
from modules.components.upr_net_freq import upr_freq as upr_freq002 |
|
from modules.components.upr_basic import upr as upr_basic |
|
|
|
def parent_folder(path):
    """Return the directory that contains *path*.

    Trailing path separators are ignored, so ``'vid/frames/'`` and
    ``'vid/frames'`` both yield ``'vid'``.  A path with no directory part
    yields ``''``; a path made only of separators (e.g. ``'/'``) is passed
    to ``os.path.dirname`` unchanged.
    """
    separators = os.sep + (os.altsep or '')
    trimmed = path.rstrip(separators)
    # keep the original string when stripping would leave nothing (root path)
    return os.path.dirname(trimmed if trimmed else path)
|
|
|
# Reminder printed every time this module is loaded (import or script run).
# NOTE(review): the non-ASCII parts of the message look mis-encoded; the text
# is emitted unchanged.  Judging by the readable parts it seems to concern
# (1) switching replicate->constant in utils.pad.py and (2) the order of
# normalization and padding in the upr Model -- confirm with the author.
print(
    '์ธํผ๋ฐ์ค ์\n'
    '1. utils.pad.py replicate->constant๋ก ๋ณ๊ฒฝํ๊ณ \n'
    '2. components upr Model ์ต์ด์ธํ์์ normalization๊ณผ padding ์์น ๋ฐ๊ฟจ๋์ง ํ์ธํ ๊ฒ (padding์ด ์์ ์์ด์ผ๋จ)'
)
|
def _count_video_frames(path):
    """Return the frame count OpenCV reports for the video at *path*."""
    capture = cv2.VideoCapture(path)
    try:
        return int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        capture.release()


def _extract_frames(video_path, frame_root, ext):
    """Write the frames of *video_path* into *frame_root* and return the FPS OpenCV reports."""
    os.makedirs(frame_root, exist_ok=True)
    video = cv2.VideoCapture(video_path)
    try:
        fps = video.get(cv2.CAP_PROP_FPS)
        for index in tqdm(range(int(video.get(cv2.CAP_PROP_FRAME_COUNT)))):
            ok, frame = video.read()
            if not ok:
                # stop instead of passing an empty frame to imwrite when
                # fewer frames can be decoded than the reported count
                break
            cv2.imwrite(os.path.join(frame_root, str(index).zfill(4) + f'.{ext}'), frame)
    finally:
        video.release()
    return fps


def _renumber_frames(frame_root, ext):
    """Rename the ``*.ext`` files in *frame_root* to ``0000.ext, 0001.ext, ...`` in sorted order.

    Returns the list of paths as they exist on disk after renaming.
    """
    renamed = []
    for index, source in enumerate(sorted(glob.glob(os.path.join(frame_root, f'*.{ext}')))):
        target = os.path.join(frame_root, str(index).zfill(4) + f'.{ext}')
        if target != source:
            if os.path.exists(target) and not os.path.samefile(source, target):
                # renaming would overwrite a frame that has not been processed yet
                raise FileExistsError(f"cannot rename '{source}' to '{target}': target already exists")
            os.rename(source, target)
        renamed.append(target)
    return renamed


def _copy_and_load_bgr(src, dst, ext):
    """Write a copy of frame *src* to *dst* and return its pixels in BGR channel order."""
    if ext not in ('tga', 'TGA'):
        frame = cv2.imread(src)
        if frame is None:
            raise OSError(f"could not read frame '{src}'")
        cv2.imwrite(dst, frame)
        return frame
    # TGA frames go through Pillow instead of OpenCV
    with Image.open(src) as picture:
        picture.save(dst)
        rgb = np.array(picture.convert('RGB'))
    return rgb[:, :, [2, 1, 0]]


def _to_tensor(bgr, scale, device):
    """Turn a BGR frame into a 1 x C x H x W RGB tensor (values divided by 255) on *device*.

    When *scale* is not 1 the frame is first shrunk by that integer factor.
    """
    if scale != 1:
        height, width = bgr.shape[:2]
        bgr = cv2.resize(bgr, (width // scale, height // scale), interpolation=cv2.INTER_CUBIC)
    rgb = bgr[:, :, [2, 1, 0]]
    return (torch.from_numpy(rgb) / 255).permute(2, 0, 1).unsqueeze(0).to(device)


def _save_prediction(pred, dst, ext, scale, height, width):
    """Write the first item of the predicted batch *pred* to *dst*.

    When *scale* is not 1 the prediction is resized back to *height* x *width*.
    """
    if ext not in ('tga', 'TGA'):
        bgr = (pred[0].cpu().permute(1, 2, 0).clamp(0, 1).numpy() * 255).astype(np.uint8)[:, :, [2, 1, 0]]
        if scale != 1:
            bgr = cv2.resize(bgr, (width, height), interpolation=cv2.INTER_CUBIC)
        cv2.imwrite(dst, bgr)
        return
    chw = pred[0].clamp(0, 1).cpu()
    if scale != 1:
        # bicubic resampling can overshoot [0, 1]; clamp again so the
        # conversion to 8-bit does not wrap around
        chw = TF.resize(chw, (height, width), interpolation=TF.InterpolationMode.BICUBIC).clamp(0, 1)
    TF.to_pil_image(chw).save(dst)


def main():
    """Command-line entry point: double the frame count of a frame folder or video.

    Steps, in order:
      1. parse and validate the command-line arguments;
      2. with ``--use_video``, dump the video's frames into ``<video dir>/frames``;
      3. rename the source frames to ``0000.ext, 0001.ext, ...``;
      4. load the UPR model weights from ``--pretrain_path``;
      5. copy every source frame to ``save_root`` at even indices and write the
         model's ``time_step=0.5`` prediction between neighbours at odd indices;
      6. with ``--make_video``, call ``ffmpeg`` to encode the result frames.

    Raises:
        SystemExit: invalid command-line arguments (reported via ``parser.error``).
        FileNotFoundError: no source frames were found.
        FileExistsError: renumbering would overwrite an existing frame.
        OSError: a source frame could not be read.
        ValueError: the source FPS is not positive when a video is requested.
        subprocess.CalledProcessError: ``ffmpeg`` exited with a non-zero status.
    """
    import subprocess  # local import: only the optional ffmpeg step needs it

    parser = argparse.ArgumentParser('Video Frame Interpolation Inference', add_help=True)
    parser.add_argument('--cuda_index', default=0, type=int, help='CUDA GPU index')

    parser.add_argument('--use_video', action='store_true', help='whether using video file')
    parser.add_argument('--root', default='', type=str, help='root containing frames [./videoname/frames] (or video [./videoname/videoname.mp4])')
    parser.add_argument('--save_root', default='', type=str, help='root to save result frames [./videoname/results_expname]')
    parser.add_argument('--source_frame_ext', default='png', type=str, help='source frames extension name')

    parser.add_argument('--pretrain_path', default='', type=str, help='path containing pretrained model')

    parser.add_argument('--pyr_level', default=5, type=int, help='UPR-Net pyramid level')
    parser.add_argument('--nr_lvl_skipped', default=0, type=int, help='UPR-Net pyramid skip number')
    parser.add_argument('--splat_mode', default='average', type=str, help='UPR-Net warping splat mode')
    parser.add_argument('--down_scale', default=1, type=int, help='frame down-scaling factor (due to GPU memory issue)')

    parser.add_argument('--make_video', action='store_true', help='whether merging frames and making video file')
    parser.add_argument('--fps', default=0, type=int, help='FPS before VFI')
    parser.add_argument('--new_video_name', default='newvideo', type=str, help='new video name [new_video_name.mp4]')

    args = parser.parse_args()

    # Explicit checks instead of assert, which is stripped under ``python -O``.
    if parent_folder(args.root) != parent_folder(args.save_root):
        parser.error(f"the parents of 'root' ({parent_folder(args.root)}) and save_root ({parent_folder(args.save_root)}) should be same!!")
    if args.down_scale < 1:
        parser.error(f"'down_scale' ({args.down_scale}) must be 1 or larger")
    if args.make_video:
        if os.path.splitext(args.new_video_name)[1] == '':
            parser.error(f"'new_video_name' ({args.new_video_name}) should have extension name!!")
        if parent_folder(args.new_video_name) != '':
            parser.error("'new_video_name' should not contain directory path")
        if not args.use_video and args.fps <= 0:
            # the output frame rate is derived from this value
            parser.error("'fps' must be positive when 'make_video' is used without 'use_video'")
    if args.use_video and _count_video_frames(args.root) <= 0:
        parser.error(f"number of frames in video ({args.root}) must be larger than 0!! !!check file name!!")

    device = args.cuda_index
    torch.cuda.set_device(device)
    video_path = args.root if args.use_video else None
    # os.path.join keeps the path relative when the video sits in the current directory
    frame_root = args.root if video_path is None else os.path.join(parent_folder(video_path), 'frames')
    save_root = args.save_root
    ext = args.source_frame_ext
    scale = args.down_scale

    source_fps = None
    if video_path is not None:
        print('@@@@@@@@@@@@@@@@@@@@Extracting frames from video@@@@@@@@@@@@@@@@@@@@')
        source_fps = _extract_frames(video_path, frame_root, ext)

    # Use the post-rename paths from here on; the pre-rename names no longer exist.
    file_list = _renumber_frames(frame_root, ext)
    if not file_list:
        raise FileNotFoundError(f"no '*.{ext}' frames found in '{frame_root}'")

    model = upr_freq002.Model(pyr_level=args.pyr_level,
                              nr_lvl_skipped=args.nr_lvl_skipped,
                              splat_mode=args.splat_mode)
    # NOTE: torch.load unpickles the checkpoint -- only load files from a trusted source.
    sd = torch.load(args.pretrain_path, map_location='cpu')
    sd = sd['model'] if 'model' in sd else sd
    print(model.load_state_dict(sd))
    model = model.to(device)
    # NOTE(review): the model is left in its default mode, as before; confirm
    # whether model.eval() should be called for inference.

    fps_after = None
    if args.make_video:
        num_frame_before = len(file_list)
        fps_before = source_fps if args.use_video else args.fps
        if not fps_before > 0:
            raise ValueError(f"source FPS ({fps_before}) must be positive to build a video")
        num_frame_after = 2 * num_frame_before - 1
        # keep the clip duration unchanged after inserting the new frames
        fps_after = fps_before * num_frame_after / num_frame_before
        print(f'num_frame_before: {num_frame_before}, fps_before: {fps_before:.6f}, time_before: {num_frame_before/fps_before:.6f}')
        print(f'num_frame_after: {num_frame_after}, fps_after: {fps_after:.6f}, time_after: {num_frame_after/fps_after:.6f}')
        print()

    print('@@@@@@@@@@@@@@@@@@@@Staring VFI@@@@@@@@@@@@@@@@@@@@')
    os.makedirs(save_root, exist_ok=True)
    previous = None
    for frame_num, file in enumerate(tqdm(file_list)):
        # source frames land on even indices of the output sequence
        kept = os.path.join(save_root, str(frame_num * 2).zfill(4) + f'.{ext}')
        bgr = _copy_and_load_bgr(file, kept, ext)
        height, width = bgr.shape[:2]
        current = _to_tensor(bgr, scale, device)

        if previous is not None:
            with torch.no_grad():
                result_dict, _ = model(previous, current, pyr_level=args.pyr_level,
                                       nr_lvl_skipped=args.nr_lvl_skipped, time_step=0.5)
            # the interpolated frame goes on the odd index between its two neighbours
            middle = os.path.join(save_root, str(2 * frame_num - 1).zfill(4) + f'.{ext}')
            _save_prediction(result_dict['imgt_pred'], middle, ext, scale, height, width)
        previous = current

    if args.make_video:
        # argument list (no shell) so paths containing spaces or shell
        # metacharacters are passed to ffmpeg verbatim
        cmd = [
            'ffmpeg',
            '-framerate', str(fps_after),
            '-i', os.path.join(save_root, f'%04d.{ext}'),
            '-c:v', 'libx264',
            '-preset', 'veryslow',
            '-crf', '10',
            os.path.join(parent_folder(save_root), args.new_video_name),
        ]
        subprocess.run(cmd, check=True)
|
|
|
# Run the command-line tool only when this file is executed as a script.
if __name__ == "__main__":
    main()