VfiTest / modules /models /inference_video.py
SuyeonJ's picture
Upload folder using huggingface_hub
8d015d4 verified
raw
history blame
4.85 kB
import glob
import numpy
import os
import cv2
import math
import PIL.Image
import torch
import torch.nn.functional as F
import tqdm
import argparse
from moviepy.editor import VideoFileClip
import sys
from torchvision.utils import save_image
from utils.flowvis import flow2img
from utils.padder import InputPadder
##########################################################
##########################################################
def inference_demo(model, ratio, video_path, out_path):
videogen = []
is_video = video_path.endswith(".mkv") or video_path.endswith(".webm") or video_path.endswith(
".mp4") or video_path.endswith(".avi")
if is_video:
clip = VideoFileClip(video_path)
videogen = clip.iter_frames()
ratio = 2
fps = clip.fps
# if fps == 23 or fps == 25:
# fps = 24
# if fps == 29 or fps == 31:
# fps = 30
# if fps == 59:
# fps = 60
# ratio = 120 // fps
# if fps == 60:
# ratio = 120 // 24
else:
for f in os.listdir(video_path):
if 'png' or 'jpg' in f:
videogen.append(f)
videogen.sort(key=lambda x: int(x[:-4]))
if not os.path.exists(out_path):
os.mkdir(out_path)
if not os.path.exists(out_path + "_flow"):
os.mkdir(out_path + '_flow')
img0 = None
idx = 0
name_idx = 0
time_range = torch.arange(1, ratio).view(ratio - 1, 1, 1, 1).cuda() / ratio
for curfile_name in videogen:
if not is_video:
curframe = os.path.join(video_path, curfile_name)
img4_np = cv2.imread(curframe)[:, :, ::-1]
else:
img4_np = curfile_name
img4 = (torch.tensor(img4_np.transpose(2, 0, 1).copy()).float() / 255.0).unsqueeze(0).cuda()
if img0 is None:
img0 = img4
cv2.imwrite(out_path + '/{:0>7d}.png'.format(name_idx),
(img0[0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:, :, ::-1])
_, _, h, w = img0.shape
if h >= 2160:
scale_factor = 0.25
pyr_level = 8
nr_lvl_skipped = 4
elif h >= 1080:
scale_factor = 0.5
pyr_level = 7
nr_lvl_skipped = 0
else:
scale_factor = 1
pyr_level = 5
nr_lvl_skipped = 0
idx += 1
name_idx += 1
continue
# if is_video:
# if fps == 60:
# if idx % 5 != 0 and idx % 5 != 3:
# idx += 1
# continue
# img0_ = F.interpolate(img0, scale_factor=pre_down, mode='bilinear')
# img4_ = F.interpolate(img4, scale_factor=pre_down, mode='bilinear')
results_dict = model(img0=img0, img1=img4, time_step=time_range, scale_factor=scale_factor,
ratio=(1 / scale_factor), pyr_level=pyr_level, nr_lvl_skipped=nr_lvl_skipped)
imgt_pred = results_dict['imgt_pred']
imgt_pred = torch.clip(imgt_pred, 0, 1)
save_image(flow2img(results_dict['flowfwd']),
os.path.join(out_path + '_flow', "{:0>7d}ff.png".format(name_idx - 1)))
save_image(flow2img(results_dict['flowbwd']),
os.path.join(out_path + '_flow', "{:0>7d}bb.png".format(name_idx - 1)))
if "flowfwd_pre" in results_dict.keys():
save_image(flow2img(results_dict['flowfwd_pre']),
os.path.join(out_path + '_flow', "pre_{:0>7d}ff.png".format(name_idx - 1)))
save_image(results_dict['refine_res'], os.path.join(out_path, "refine_res.png"))
save_image(results_dict['refine_mask'], os.path.join(out_path, "refine_mask.png"))
save_image(results_dict['warped_img0'], os.path.join(out_path, "warped_img0.png"))
save_image(results_dict['warped_img1'], os.path.join(out_path, "warped_img1.png"))
save_image(results_dict['merged_img'], os.path.join(out_path, "merged_img.png"))
img_pred = imgt_pred
# img_pred = F.interpolate(img_pred, scale_factor=1 // pre_down, mode='bilinear')
cv2.imwrite(out_path + '/{:0>7d}.png'.format(name_idx),
(img_pred[0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:, :, ::-1])
name_idx += 1
# img4 = F.interpolate(img4, scale_factor=1 // pre_down, mode='bilinear')
cv2.imwrite(out_path + '/{:0>7d}.png'.format(name_idx),
(img4[0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:, :, ::-1])
name_idx += 1
idx += 1
img0 = img4
if is_video:
os.system(
f'ffmpeg -framerate {fps * 2} -pattern_type glob -i "{out_path}/*.png" -c:v libx265 -qp 8 -pix_fmt yuv420p {out_path}_{fps * 2}fps.mp4')