# [START_COMMAND]
# python3 -m vfi_inference --cuda_index 0 \
# --use_video --root ../VFI_Inference/thistest/test_video.mp4 --save_root ../VFI_Inference/thistest/results --source_frame_ext png \
# --pretrain_path ./pretrained/upr_freq002.pth \
# --pyr_level 7 --nr_lvl_skipped 0 --splat_mode average --down_scale 1 \
# --make_video --fps 0 --new_video_name test_video_vfi.mp4

# python3 -m vfi_inference --cuda_index 0 \
# --root ../VFI_Inference/thistest/frames --save_root ../VFI_Inference/thistest/results --source_frame_ext png \
# --pretrain_path ./pretrained/upr_freq002.pth \
# --pyr_level 7 --nr_lvl_skipped 0 --splat_mode average --down_scale 1 \
# --make_video --fps 0 --new_video_name test_video_vfi.mp4

# [FILE SYSTEM] Frame sequence
# The frame sequence must be gathered in the args.root folder.
# The args.save_root folder must have the same parent folder as args.root; the resulting frame sequence is saved into args.save_root.

# [FILE SYSTEM] Video
# args.root is the path to the video file.
# The extracted frame sequence is saved automatically.
# The args.save_root folder must have the same parent folder as the args.root file; the resulting frame sequence is saved into args.save_root.

import argparse

import os
import cv2
import glob
import torch
import numpy as np
from PIL import Image
from tqdm import tqdm
from torchvision.transforms import functional as TF

from modules.components.upr_net_freq import upr_freq as upr_freq002
from modules.components.upr_basic import upr as upr_basic

def parent_folder(path):
    return os.path.split(path)[0]

print('์ธํผ๋Ÿฐ์Šค ์‹œ\n1. utils.pad.py replicate->constant๋กœ ๋ณ€๊ฒฝํ•˜๊ณ \n2. components upr Model ์ตœ์ดˆ์ธํ’‹์—์„œ normalization๊ณผ padding ์œ„์น˜ ๋ฐ”๊ฟจ๋Š”์ง€ ํ™•์ธํ•  ๊ฒƒ (padding์ด ์œ„์— ์žˆ์–ด์•ผ๋จ)')
def main():
    parser = argparse.ArgumentParser('Video Frame Interpolation Inference',add_help=True)
    parser.add_argument('--cuda_index', default=0, type=int, help='CUDA GPU index')
    
    parser.add_argument('--use_video', action='store_true',  help='whether using video file')
    parser.add_argument('--root', default='', type=str, help='root containing frames [./videoname/frames] (or video [./videoname/videoname.mp4])')
    parser.add_argument('--save_root', default='', type=str, help='root to save result frames [./videoname/results_expname]')
    parser.add_argument('--source_frame_ext', default='png', type=str, help='source frames extension name')
    
    parser.add_argument('--pretrain_path', default='', type=str, help='path containing pretrained model')
    
    parser.add_argument('--pyr_level', default=5, type=int, help='UPR-Net pyramid level')
    parser.add_argument('--nr_lvl_skipped', default=0, type=int, help='UPR-Net pyramid skip number')
    parser.add_argument('--splat_mode', default='average', type=str, help='UPR-Net warping splat mode')
    parser.add_argument('--down_scale', default=1, type=int, help='frame down-scaling factor (due to GPU memory issue)')
    
    parser.add_argument('--make_video', action='store_true', help='whether merging frames and making video file')
    parser.add_argument('--fps', default=0, type=int, help='FPS before VFI')
    parser.add_argument('--new_video_name', default='newvideo', type=str, help='new video name [new_video_name.mp4]')
    
    args = parser.parse_args()
    assert parent_folder(args.root)==parent_folder(args.save_root),\
        f"the parent folders of 'root' ({parent_folder(args.root)}) and 'save_root' ({parent_folder(args.save_root)}) must be the same!"
    if args.make_video:
        assert os.path.splitext(args.new_video_name)[1]!='', f"'new_video_name' ({args.new_video_name}) must include a file extension!"
        assert parent_folder(args.new_video_name)=='', f"'new_video_name' ({args.new_video_name}) must be a bare file name without a directory path!"
    if args.use_video:
        # sanity-check that the input video can be opened and actually contains frames
        cap = cv2.VideoCapture(args.root)
        num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        assert num_frames>0, f"the video ({args.root}) must contain at least one frame; check the file path!"
        cap.release()
        del cap, num_frames
    
    DEVICE = args.cuda_index
    torch.cuda.set_device(DEVICE)
    VIDEO_ROOT = args.root if args.use_video else None
    FRAME_ROOT = args.root if VIDEO_ROOT is None else parent_folder(VIDEO_ROOT)+'/frames'
    SAVE_ROOT = args.save_root
    EXT = args.source_frame_ext
    SCALE = args.down_scale
    
    if VIDEO_ROOT is not None:
        print('@@@@@@@@@@@@@@@@@@@@Extracting frames from video@@@@@@@@@@@@@@@@@@@@')
        os.makedirs(FRAME_ROOT, exist_ok=True)
        video = cv2.VideoCapture(VIDEO_ROOT)
        this_fps = video.get(cv2.CAP_PROP_FPS)
        for index in tqdm(range(int(video.get(cv2.CAP_PROP_FRAME_COUNT)))):
            ret, frame = video.read()
            if not ret:  # stop early if the decoder returns fewer frames than reported
                break
            newfile = os.path.join(FRAME_ROOT, str(index).zfill(4)+f'.{EXT}')
            cv2.imwrite(newfile, frame)
        video.release()
    
    # build the UPR-Net frequency variant and load the pretrained weights
    model = upr_freq002.Model(pyr_level=args.pyr_level,
                              nr_lvl_skipped=args.nr_lvl_skipped,
                              splat_mode=args.splat_mode)
    sd = torch.load(args.pretrain_path, map_location='cpu')
    sd = sd['model'] if 'model' in sd.keys() else sd
    print(model.load_state_dict(sd))
    model = model.to(DEVICE)
    
    # normalize source frame file names to zero-padded sequential indices (0000, 0001, ...)
    file_list = sorted(glob.glob(os.path.join(FRAME_ROOT, f'*.{EXT}')))
    for i, file in enumerate(file_list):
        newfile = os.path.join(FRAME_ROOT, str(i).zfill(4)+f'.{EXT}')
        if file != newfile:
            os.rename(file, newfile)
        file_list[i] = newfile  # keep the list in sync with the renamed files
        
    if args.make_video:
        # every adjacent frame pair gains one interpolated frame, so N frames become 2N-1;
        # the FPS is scaled by the same ratio so the total duration stays unchanged
        num_frame_before = len(file_list)
        fps_before = args.fps if not args.use_video else this_fps
        assert fps_before>0, "provide --fps > 0 when interpolating a frame sequence (there is no source video to read the FPS from)"
        num_frame_after = 2*num_frame_before-1
        fps_after = fps_before*num_frame_after/num_frame_before
        print(f'num_frame_before: {num_frame_before}, fps_before: {fps_before:.6f}, time_before: {num_frame_before/fps_before:.6f}')
        print(f'num_frame_after: {num_frame_after}, fps_after: {fps_after:.6f}, time_after: {num_frame_after/fps_after:.6f}')
        print()
    
    print('@@@@@@@@@@@@@@@@@@@@Starting VFI@@@@@@@@@@@@@@@@@@@@')
    os.makedirs(SAVE_ROOT, exist_ok=True)
    for frame_num, file in enumerate(tqdm(file_list)):
        # the previous source frame (already a normalized GPU tensor) becomes img0; None on the first iteration
        img0 = img1 if frame_num!=0 else None
        # original frames are copied to even output indices; interpolated frames fill the odd ones
        orig_path = os.path.join(SAVE_ROOT, str(frame_num*2).zfill(4)+f'.{EXT}')
        if EXT not in ['tga', 'TGA']:
            img1 = cv2.imread(file)
            cv2.imwrite(orig_path, img1)
        else:
            img1 = Image.open(file)
            img1.save(orig_path)
            img1 = np.array(img1)[:,:,[2,1,0]]  # RGB(A) -> BGR to match the cv2 branch
        H,W,_ = img1.shape

        # BGR uint8 image -> normalized RGB tensor of shape (1, 3, H, W) on the target GPU
        if SCALE==1:
            img1 = (torch.from_numpy(img1[:,:,[2,1,0]])/255).permute(2,0,1).unsqueeze(0).to(DEVICE)
        else:
            img1 = (torch.from_numpy(cv2.resize(img1, (W//SCALE,H//SCALE), interpolation=cv2.INTER_CUBIC)[:,:,[2,1,0]])/255).permute(2,0,1).unsqueeze(0).to(DEVICE)
        if img0 is None: continue

        # predict the middle frame (time_step=0.5) between img0 and img1
        with torch.no_grad():
            result_dict, extra_dict = model(img0, img1, pyr_level=args.pyr_level, nr_lvl_skipped=args.nr_lvl_skipped, time_step=0.5)
            out = result_dict['imgt_pred']

        interp_path = os.path.join(SAVE_ROOT, str(2*frame_num-1).zfill(4)+f'.{EXT}')
        if EXT not in ['tga', 'TGA']:
            if SCALE==1:
                out = (out[0].cpu().permute(1,2,0).clamp(0,1).numpy()*255).astype(np.uint8)[:,:,[2,1,0]]
            else:
                # scale the prediction back up to the original resolution before saving
                out = cv2.resize((out[0].cpu().permute(1,2,0).clamp(0,1).numpy()*255).astype(np.uint8)[:,:,[2,1,0]], (W,H), interpolation=cv2.INTER_CUBIC)
            cv2.imwrite(interp_path, out)
        else:
            if SCALE==1:
                out = TF.to_pil_image(out[0].clamp(0,1).cpu())
            else:
                out = TF.to_pil_image(TF.resize(out[0].clamp(0,1).cpu(), (H,W), interpolation=TF.InterpolationMode.BICUBIC))
            out.save(interp_path)
            
    if args.make_video:
        # merge the saved frames into an H.264 video next to SAVE_ROOT (requires ffmpeg on the PATH)
        cmd = f'ffmpeg -framerate {fps_after} -i {SAVE_ROOT}/%04d.{EXT} -c:v libx264 -preset veryslow -crf 10 {parent_folder(SAVE_ROOT)}/{args.new_video_name}'
        os.system(cmd)
        
if __name__ == '__main__':
    main()