import argparse
import glob
import math
import os
import sys

import cv2
import numpy
import PIL.Image
import torch
import torch.nn.functional as F
import tqdm
from moviepy.editor import VideoFileClip
from torchvision.utils import save_image

from utils.flowvis import flow2img
from utils.padder import InputPadder


def inference_demo(model, ratio, video_path, out_path):
    # Gather input frames either from a video file or from a directory of images.
    videogen = []
    is_video = video_path.endswith((".mkv", ".webm", ".mp4", ".avi"))
    if is_video:
        clip = VideoFileClip(video_path)
        videogen = clip.iter_frames()
        # Video input is always interpolated to double the original frame rate.
        ratio = 2
        fps = clip.fps
    else:
        # Keep only png/jpg frames; files are expected to be named by integer index.
        for f in os.listdir(video_path):
            if 'png' in f or 'jpg' in f:
                videogen.append(f)
        videogen.sort(key=lambda x: int(x[:-4]))

    # Create output directories for interpolated frames and flow visualizations.
    if not os.path.exists(out_path):
        os.mkdir(out_path)
    if not os.path.exists(out_path + "_flow"):
        os.mkdir(out_path + '_flow')

    img0 = None
    idx = 0
    name_idx = 0
    # Intermediate time steps in (0, 1): [0.5] for ratio=2, [1/3, 2/3] for ratio=3, ...
    time_range = torch.arange(1, ratio).view(ratio - 1, 1, 1, 1).cuda() / ratio
    for curfile_name in videogen:
        if not is_video:
            curframe = os.path.join(video_path, curfile_name)
            # cv2 loads BGR; reverse the channel axis to get RGB.
            img4_np = cv2.imread(curframe)[:, :, ::-1]
        else:
            # moviepy already yields RGB frames as numpy arrays.
            img4_np = curfile_name
        img4 = (torch.tensor(img4_np.transpose(2, 0, 1).copy()).float() / 255.0).unsqueeze(0).cuda()
        if img0 is None:
            # First frame: write it out unchanged and choose pyramid settings from the resolution.
            img0 = img4
            cv2.imwrite(out_path + '/{:0>7d}.png'.format(name_idx),
                        (img0[0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:, :, ::-1])
            _, _, h, w = img0.shape
            if h >= 2160:
                scale_factor = 0.25
                pyr_level = 8
                nr_lvl_skipped = 4
            elif h >= 1080:
                scale_factor = 0.5
                pyr_level = 7
                nr_lvl_skipped = 0
            else:
                scale_factor = 1
                pyr_level = 5
                nr_lvl_skipped = 0
            idx += 1
            name_idx += 1
            continue

        # Interpolate between the previous frame (img0) and the current frame (img4).
        results_dict = model(img0=img0, img1=img4, time_step=time_range, scale_factor=scale_factor,
                             ratio=(1 / scale_factor), pyr_level=pyr_level, nr_lvl_skipped=nr_lvl_skipped)
        imgt_pred = results_dict['imgt_pred']
        imgt_pred = torch.clip(imgt_pred, 0, 1)
        # Save forward/backward flow visualizations for this frame pair.
        save_image(flow2img(results_dict['flowfwd']),
                   os.path.join(out_path + '_flow', "{:0>7d}ff.png".format(name_idx - 1)))
        save_image(flow2img(results_dict['flowbwd']),
                   os.path.join(out_path + '_flow', "{:0>7d}bb.png".format(name_idx - 1)))
        if "flowfwd_pre" in results_dict:
            # Optional intermediate outputs; these debug images are overwritten every iteration.
            save_image(flow2img(results_dict['flowfwd_pre']),
                       os.path.join(out_path + '_flow', "pre_{:0>7d}ff.png".format(name_idx - 1)))
            save_image(results_dict['refine_res'], os.path.join(out_path, "refine_res.png"))
            save_image(results_dict['refine_mask'], os.path.join(out_path, "refine_mask.png"))
            save_image(results_dict['warped_img0'], os.path.join(out_path, "warped_img0.png"))
            save_image(results_dict['warped_img1'], os.path.join(out_path, "warped_img1.png"))
            save_image(results_dict['merged_img'], os.path.join(out_path, "merged_img.png"))

        img_pred = imgt_pred
        # Write the interpolated frame, then the current input frame.
        cv2.imwrite(out_path + '/{:0>7d}.png'.format(name_idx),
                    (img_pred[0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:, :, ::-1])
        name_idx += 1

        cv2.imwrite(out_path + '/{:0>7d}.png'.format(name_idx),
                    (img4[0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:, :, ::-1])
        name_idx += 1
        idx += 1
        img0 = img4

    if is_video:
        # Re-encode the interpolated frames into a video at double the original frame rate.
        os.system(
            f'ffmpeg -framerate {fps * 2} -pattern_type glob -i "{out_path}/*.png" '
            f'-c:v libx265 -qp 8 -pix_fmt yuv420p {out_path}_{fps * 2}fps.mp4')
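

# --- Usage sketch (illustrative only) ------------------------------------
# A minimal command-line entry point showing how inference_demo could be
# invoked. The `build_model` import and its checkpoint handling are
# hypothetical placeholders, not part of the original script; substitute the
# repository's actual model constructor and weight-loading code.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Frame-interpolation inference demo')
    parser.add_argument('--video_path', type=str, required=True,
                        help='input video file or directory of numbered png/jpg frames')
    parser.add_argument('--out_path', type=str, required=True,
                        help='output directory for interpolated frames')
    parser.add_argument('--ratio', type=int, default=2,
                        help='frame-rate multiplication factor (forced to 2 for video input)')
    args = parser.parse_args()

    # Hypothetical helper; adjust to however this codebase builds and restores its model.
    from model import build_model
    model = build_model().cuda().eval()

    with torch.no_grad():
        inference_demo(model, args.ratio, args.video_path, args.out_path)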