File size: 4,845 Bytes
8d015d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import glob
import numpy
import os
import cv2
import math
import PIL.Image
import torch
import torch.nn.functional as F
import tqdm
import argparse
from moviepy.editor import VideoFileClip
import sys

from torchvision.utils import save_image
from utils.flowvis import flow2img
from utils.padder import InputPadder


##########################################################

##########################################################
_VIDEO_EXTENSIONS = (".mkv", ".webm", ".mp4", ".avi")
_IMAGE_EXTENSIONS = (".png", ".jpg")
_DEBUG_OUTPUT_KEYS = ("refine_res", "refine_mask", "warped_img0", "warped_img1", "merged_img")


def _list_image_frames(folder):
    """Return the .png/.jpg file names in *folder*, sorted by their integer stem.

    Frames are expected to be named ``<integer>.<ext>`` (e.g. ``12.png``); a matching
    file whose stem is not an integer makes ``int()`` raise ``ValueError``.
    """
    names = [name for name in os.listdir(folder) if name.lower().endswith(_IMAGE_EXTENSIONS)]
    # Sort once, numerically, so that "10.png" comes after "9.png".
    names.sort(key=lambda name: int(os.path.splitext(name)[0]))
    return names


def _read_rgb(path):
    """Read an image with cv2 and return it as an RGB (H, W, 3) array.

    Raises:
        OSError: if cv2 cannot decode the file (``cv2.imread`` returns None).
    """
    img = cv2.imread(path)
    if img is None:
        raise OSError(f"could not read image: {path}")
    return img[:, :, ::-1]  # cv2 loads BGR; flip to RGB


def _pyramid_config(height):
    """Return (scale_factor, pyr_level, nr_lvl_skipped) chosen from the frame height."""
    if height >= 2160:
        return 0.25, 8, 4
    if height >= 1080:
        return 0.5, 7, 0
    return 1, 5, 0


def _frame_path(out_path, index):
    """Return the path of output frame *index*, zero-padded to 7 digits."""
    return os.path.join(out_path, "{:0>7d}.png".format(index))


def _write_frame(path, img):
    """Write a (C, H, W) RGB tensor with values in [0, 1] to *path* as a BGR image via cv2."""
    cv2.imwrite(path, (img * 255).byte().cpu().numpy().transpose(1, 2, 0)[:, :, ::-1])


def _save_debug_outputs(results_dict, out_path, flow_dir, pair_idx):
    """Save flow visualisations and intermediate model outputs for one frame pair.

    Flow images are numbered by *pair_idx* (the output index of the pair's first frame).
    The other intermediates use fixed names in *out_path* and are overwritten on every
    call, so only the last pair's versions remain on disk.
    """
    save_image(flow2img(results_dict['flowfwd']),
               os.path.join(flow_dir, "{:0>7d}ff.png".format(pair_idx)))
    save_image(flow2img(results_dict['flowbwd']),
               os.path.join(flow_dir, "{:0>7d}bb.png".format(pair_idx)))
    if 'flowfwd_pre' in results_dict:
        save_image(flow2img(results_dict['flowfwd_pre']),
                   os.path.join(flow_dir, "pre_{:0>7d}ff.png".format(pair_idx)))
    for key in _DEBUG_OUTPUT_KEYS:
        save_image(results_dict[key], os.path.join(out_path, key + ".png"))


def _encode_video(frames_dir, out_fps):
    """Encode ``frames_dir/0000000.png, 0000001.png, ...`` into ``<frames_dir>_<out_fps>fps.mp4``.

    Only the zero-padded numbered frames are used (ffmpeg ``%07d`` image sequence), so the
    debug images that are also saved in *frames_dir* are not encoded into the video.
    Failures are reported on stderr rather than raised, matching the previous best-effort
    ``os.system`` call.
    """
    import subprocess

    cmd = [
        "ffmpeg", "-framerate", str(out_fps),
        "-i", os.path.join(frames_dir, "%07d.png"),
        "-c:v", "libx265", "-qp", "8", "-pix_fmt", "yuv420p",
        "{}_{}fps.mp4".format(frames_dir, out_fps),
    ]
    try:
        # Argument list (no shell): paths with spaces or shell metacharacters are passed verbatim.
        result = subprocess.run(cmd)
    except FileNotFoundError:
        print("ffmpeg not found; skipping video encoding", file=sys.stderr)
        return
    if result.returncode != 0:
        print(f"ffmpeg exited with status {result.returncode}", file=sys.stderr)


def inference_demo(model, ratio, video_path, out_path):
    """Interpolate between consecutive frames of a video or image folder and save the result.

    Each output pair is written as: the first frame, the predicted intermediate frame(s),
    then the second frame. Output frames are numbered ``0000000.png``, ``0000001.png``, ...
    in *out_path*. Tensors are moved to the GPU with ``.cuda()``, so CUDA is required.

    Args:
        model: Callable invoked as ``model(img0=..., img1=..., time_step=..., scale_factor=...,
            ratio=..., pyr_level=..., nr_lvl_skipped=...)``. It must return a dict containing
            ``imgt_pred``, ``flowfwd``, ``flowbwd``, ``refine_res``, ``refine_mask``,
            ``warped_img0``, ``warped_img1`` and ``merged_img``; ``flowfwd_pre`` is optional.
        ratio: Interpolation factor for image-folder input; ``ratio - 1`` time steps
            (``1/ratio, ..., (ratio-1)/ratio``) are requested per frame pair. It is ignored
            for video input, which is always interpolated 2x.
        video_path: A .mkv/.webm/.mp4/.avi file, or a folder of ``<int>.png`` / ``<int>.jpg``
            frames.
        out_path: Directory for the numbered output frames. Flow visualisations are written
            to ``<out_path>_flow``. For video input the frames are also encoded with ffmpeg
            into ``<out_path>_<fps*2>fps.mp4``.

    Raises:
        ValueError: If ``ratio < 2`` for image-folder input.
        OSError: If an image in the folder cannot be decoded.
    """
    is_video = video_path.lower().endswith(_VIDEO_EXTENSIONS)
    clip = None
    fps = None
    if is_video:
        clip = VideoFileClip(video_path)
        frames = clip.iter_frames()
        # Video input is always interpolated 2x regardless of the ratio argument.
        ratio = 2
        fps = clip.fps
    else:
        if ratio < 2:
            raise ValueError(f"ratio must be >= 2, got {ratio}")
        # The folder is listed immediately (the generator's outermost iterable is evaluated now).
        frames = (_read_rgb(os.path.join(video_path, name)) for name in _list_image_frames(video_path))

    flow_dir = out_path + "_flow"
    os.makedirs(out_path, exist_ok=True)
    os.makedirs(flow_dir, exist_ok=True)

    # Shape (ratio - 1, 1, 1, 1): one normalised time step per intermediate frame.
    time_range = torch.arange(1, ratio).view(ratio - 1, 1, 1, 1).cuda() / ratio
    img0 = None
    name_idx = 0
    try:
        for frame_np in frames:
            # (H, W, 3) uint8 RGB -> (1, 3, H, W) float in [0, 1] on the GPU.
            img1 = (torch.tensor(frame_np.transpose(2, 0, 1).copy()).float() / 255.0).unsqueeze(0).cuda()
            if img0 is None:
                # First frame: write it unchanged and pick the pyramid settings from its height.
                img0 = img1
                _write_frame(_frame_path(out_path, name_idx), img0[0])
                scale_factor, pyr_level, nr_lvl_skipped = _pyramid_config(img0.shape[2])
                name_idx += 1
                continue

            results_dict = model(img0=img0, img1=img1, time_step=time_range, scale_factor=scale_factor,
                                 ratio=(1 / scale_factor), pyr_level=pyr_level, nr_lvl_skipped=nr_lvl_skipped)
            _save_debug_outputs(results_dict, out_path, flow_dir, name_idx - 1)

            imgt_pred = torch.clip(results_dict['imgt_pred'], 0, 1)
            # Presumably one predicted frame per time step along dim 0 (TODO confirm against the
            # model). Writing every entry, not just [0], keeps ratio > 2 from dropping frames;
            # with a single time step this matches the previous behaviour exactly.
            for pred in imgt_pred:
                _write_frame(_frame_path(out_path, name_idx), pred)
                name_idx += 1

            _write_frame(_frame_path(out_path, name_idx), img1[0])
            name_idx += 1
            img0 = img1
    finally:
        if clip is not None:
            clip.close()

    if is_video:
        _encode_video(out_path, fps * ratio)