|
import gradio as gr |
|
import numpy as np |
|
import cv2 |
|
import os |
|
import glob |
|
import torch |
|
import shutil |
|
from PIL import Image |
|
from tqdm import tqdm |
|
from torch.nn import functional as F |
|
from torchvision.transforms import functional as TF |
|
from matplotlib import pyplot as plt |
|
from modules.components.upr_net_freq import upr_freq as upr_freq002 |
|
from modules.components.upr_basic import upr as upr_basic |
|
import datetime |
|
import zipfile |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Accepted (user_name, password) pairs for the login screen (see check_valid_login).
# Credentials can be supplied without editing the source through the INSHORTS_AUTH
# environment variable, formatted "user:pass[,user2:pass2...]" (the password may
# contain ":" but not ","). When the variable is unset or holds no valid pair,
# the built-in test account is used, exactly as before.
# NOTE(review): the hard-coded fallback account ships with the code; consider
# removing it for public (share=True) deployments.
KEY = [
    tuple(entry.split(":", 1))
    for entry in os.environ.get("INSHORTS_AUTH", "").split(",")
    if ":" in entry
] or [("test", "test")]
|
|
|
|
|
|
|
def check_valid_login(user_name, password):
    """Gradio ``auth`` callback: accept the login iff the pair is listed in KEY.

    Args:
        user_name: name typed on the login form.
        password: password typed on the login form.

    Returns:
        True when ``(user_name, password)`` is one of the KEY entries.
    """
    attempt = (user_name, password)
    return attempt in KEY
|
|
|
|
|
# Maximum number of frames extracted from an uploaded video; also the fixed
# length of the per-frame UI state lists.
MAX_FRAME = 24

# Torch device used for all inference.
DEVICE = "cpu"

# multiple_VFIx2 pads inputs to a multiple of SCALE, downsamples them by
# 1/SCALE before the model and upsamples the prediction by SCALE afterwards.
SCALE = 1

# Model hyper-parameters; pyr_level / nr_lvl_skipped are also passed on every
# forward call in multiple_VFIx2.
pyr_level = 7
nr_lvl_skipped = 0
splat_mode = "average"
pretrain_path = "./pretrained/upr_freq002.pth"

model = upr_freq002.Model(pyr_level=pyr_level,
                          nr_lvl_skipped=nr_lvl_skipped,
                          splat_mode=splat_mode)
# NOTE(review): torch.load unpickles the file; only point this at trusted checkpoints.
sd = torch.load(pretrain_path, map_location='cpu')
# Accept both a bare state_dict and a checkpoint dict that wraps it under 'model'.
sd = sd['model'] if 'model' in sd else sd
print(model.load_state_dict(sd))
# The model is only used for inference, so switch it to eval mode. Any
# dropout / batch-norm layers (none are visible from this file) then behave
# deterministically instead of in training mode.
model = model.to(DEVICE).eval()
|
|
|
def get_sorted_img(file_path):
    """Return the ``*.png`` files in ``file_path`` ordered by frame position.

    The sort key is the text after the last "_" in the path, minus its final
    four characters (the ".png" extension), parsed as a float. So
    ``img_000.5.png`` (0.5) sorts between ``img_000.png`` (0.0) and
    ``img_001.png`` (1.0).
    """
    def frame_position(path):
        suffix = path.split("_")[-1]
        return float(suffix[:-4])

    candidates = glob.glob(os.path.join(file_path, "*.png"))
    candidates.sort(key=frame_position)
    return candidates
|
|
|
def multiple_pad(image, multiple):
    """Zero-pad an (N, C, H, W) tensor so that H and W become multiples of ``multiple``.

    Padding is added only on the bottom (height) and right (width) edges, so
    the original pixels stay at the top-left and can be cropped back with
    ``[..., :H, :W]``.
    """
    _, _, height, width = image.size()
    # Distance up to the next multiple; 0 when already aligned.
    pad_bottom = (-height) % multiple
    pad_right = (-width) % multiple
    # torch F.pad orders the last two dims as (left, right, top, bottom).
    return F.pad(image, (0, pad_right, 0, pad_bottom))
|
|
|
|
|
def multiple_VFIx2(path1, path2, output_name):
    """Synthesise the temporal midpoint (t=0.5) of two image files and write it to disk.

    Args:
        path1: image at t=0 (anything ``cv2.imread`` can decode).
        path2: image at t=1. Presumably it must have the same size as path1,
            since both go to the model as one pair (TODO confirm).
        output_name: destination path, written with ``cv2.imwrite``.

    Raises:
        FileNotFoundError: if either input image cannot be read.
        OSError: if the result cannot be written.
    """
    img_list = []
    for file in (path1, path2):
        bgr = cv2.imread(file)
        # cv2.imread signals failure by returning None rather than raising.
        if bgr is None:
            raise FileNotFoundError(f"could not read image: {file}")
        # BGR uint8 HWC -> RGB float in [0, 1], shape (1, 3, H, W), on DEVICE.
        rgb = torch.from_numpy(bgr[:, :, [2, 1, 0]]) / 255
        img_list.append(rgb.permute(2, 0, 1).unsqueeze(0).to(DEVICE))

    # Record the true size *before* padding / downscaling so the prediction can
    # be cropped back to it (measuring afterwards breaks for SCALE > 1).
    _, _, Hori, Wori = img_list[0].size()

    img_list = [multiple_pad(img, SCALE) for img in img_list]
    img_list = [F.interpolate(img, scale_factor=1/SCALE, mode='bicubic') for img in img_list]
    img0, img1 = img_list

    with torch.no_grad():
        result_dict, _ = model(img0, img1, pyr_level=pyr_level, nr_lvl_skipped=nr_lvl_skipped, time_step=0.5)
        out = F.interpolate(result_dict['imgt_pred'], scale_factor=SCALE, mode='bicubic')[:, :, :Hori, :Wori].clamp(0, 1)

    # Round (instead of truncating) when quantising back to 8 bits, then RGB -> BGR for OpenCV.
    out_bgr = (out[0].cpu().permute(1, 2, 0) * 255).round().numpy().astype(np.uint8)[:, :, [2, 1, 0]]
    if not cv2.imwrite(output_name, out_bgr):
        raise OSError(f"could not write image: {output_name}")

    torch.cuda.empty_cache()
|
def multiple_VFIx4(path1, path2):
    """Insert three frames between ``path1`` and ``path2`` at t = .25, .5 and .75.

    The midpoint is generated first from the two endpoints. Each quarter point
    is then interpolated between an endpoint and that midpoint. Outputs are
    written as ``<stem>.25.png``, ``<stem>.5.png`` and ``<stem>.75.png``, where
    ``<stem>`` is ``path1`` minus its last four characters (a ".png"
    extension is assumed).
    """
    stem = path1[:-4]
    quarter, half, three_quarter = (f"{stem}{frac}.png" for frac in (".25", ".5", ".75"))

    multiple_VFIx2(path1, path2, half)
    multiple_VFIx2(path1, half, quarter)
    multiple_VFIx2(half, path2, three_quarter)
|
|
|
|
|
def multiple_VFIx6(path1, path2):
    """Insert five frames between ``path1`` and ``path2`` at t = .125, .25, .5, .75, .875.

    The x4 pass supplies .25/.5/.75. Two more midpoint passes then densify the
    intervals next to each endpoint ("side ease"). Output names follow the
    ``<stem><t>.png`` pattern used by multiple_VFIx4, where ``<stem>`` is
    ``path1`` minus its last four characters.
    """
    stem = path1[:-4]

    multiple_VFIx4(path1, path2)
    # [0, .25] -> .125
    multiple_VFIx2(path1, f"{stem}.25.png", f"{stem}.125.png")
    # [.75, 1] -> .875
    multiple_VFIx2(f"{stem}.75.png", path2, f"{stem}.875.png")
|
|
|
|
|
def fix_img(idx, fixed_list, input_dir = "input", output_dir = "output"):
    """Replace frame ``idx`` with a frame interpolated from its two neighbours.

    Args:
        idx: frame index from the hidden ``selected`` Number. It is None
            until a frame has been clicked and may arrive as a float.
        fixed_list: per-session list of 0/1 flags (one per frame slot);
            mutated in place to mark ``idx`` as fixed.
        input_dir: directory holding the extracted frames ``img_000.png``, ...
            (empty string when no video has been uploaded).
        output_dir: root directory; regenerated frames are written to
            ``output_dir/fix_<basename(input_dir)>/img_XXX.png``.

    Returns:
        Dict of Gradio updates for ``fix_result_gallery``, ``fix_result_group``
        and ``fixed_frame``. The updates are no-ops when ``idx`` cannot be
        fixed (nothing selected, first/last frame, out of range, or already
        fixed).
    """
    def _unchanged():
        return {
            fix_result_gallery : gr.Gallery(),
            fix_result_group : gr.Group(),
            fixed_frame : gr.Text()
        }

    if idx is None or not input_dir:
        return _unchanged()
    idx = int(idx)

    # The uploaded clip may be shorter than MAX_FRAME; count the frames that
    # actually exist so neither the neighbour lookup nor the gallery
    # references missing files.
    n_frames = 0
    while n_frames < MAX_FRAME and os.path.exists(os.path.join(input_dir, f"img_{n_frames:03d}.png")):
        n_frames += 1

    # The first and last frames lack a neighbour on one side.
    if idx < 1 or idx > n_frames - 2 or fixed_list[idx] == 1:
        return _unchanged()

    now_time = os.path.basename(input_dir)
    output_dir = os.path.join(output_dir, f"fix_{now_time}")
    os.makedirs(output_dir, exist_ok = True)
    output_name = os.path.join(output_dir, f"img_{idx:03d}.png")

    multiple_VFIx2(os.path.join(input_dir, f"img_{idx - 1:03d}.png"),
                   os.path.join(input_dir, f"img_{idx + 1:03d}.png"),
                   output_name)
    fixed_list[idx] = 1

    # Show fixed frames from output_dir and untouched ones from input_dir.
    gallery = []
    fixed_ids = []
    for i in range(n_frames):
        if fixed_list[i] == 1:
            gallery.append((os.path.join(output_dir, f"img_{i:03d}.png"), f"(fixed) frame {i}"))
            fixed_ids.append(str(i))
        else:
            gallery.append((os.path.join(input_dir, f"img_{i:03d}.png"), f"frame {i}"))

    return {
        fix_result_gallery : gr.Gallery(value = gallery, selected_index = idx),
        fix_result_group : gr.Group(visible=True),
        fixed_frame : gr.Text(visible=True, value = ", ".join(fixed_ids)),
    }
|
|
|
|
|
def ease_frames(ease_val, input_dir = "input", output_dir = "output", progress=gr.Progress(track_tqdm=False)):
    """Slow down selected intervals of the uploaded clip by inserting frames.

    Args:
        ease_val: per-interval codes (interval i is frame i -> i + 1):
            2 inserts one frame (x2 via multiple_VFIx2), 3 inserts three
            frames (x4 via multiple_VFIx4), and any other value leaves the
            interval unchanged.
        input_dir: directory holding the extracted frames ``img_000.png``, ...
        output_dir: root directory; results go to ``output_dir/ease_<timestamp>/``.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        Dict of Gradio updates: the result gallery (all frames in temporal
        order with captions), the revealed download accordion, the run
        directory (stored in ``last_ease_dir``) and a zip of all frames.

    Raises:
        gr.Error: if no input frames are available.

    NOTE(review): frames are read from ``input_dir``, so frames regenerated
    with the Frame Fixer (written under ``output/fix_*``) are not used here.
    Confirm that this is intended.
    """
    # The uploaded clip may be shorter than MAX_FRAME; only use frames that exist.
    n_frames = 0
    while input_dir and n_frames < MAX_FRAME and os.path.exists(os.path.join(input_dir, f"img_{n_frames:03d}.png")):
        n_frames += 1
    if n_frames == 0:
        raise gr.Error("No input frames found - upload a video first.")

    now = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    output_dir = os.path.join(output_dir, f"ease_{now}")
    os.makedirs(output_dir, exist_ok = True)

    # Start from a copy of the originals; interpolated frames are added beside them.
    out_frame_list = [os.path.join(output_dir, f"img_{i:03d}.png") for i in range(n_frames)]
    for i, dst in enumerate(out_frame_list):
        shutil.copyfile(os.path.join(input_dir, f"img_{i:03d}.png"), dst)

    # Captions must follow the same order as get_sorted_img(output_dir).
    img_name = []
    for i in progress.tqdm(range(n_frames - 1), desc = "VFI frames..."):
        img_name.append(f"frame {i}")
        if ease_val[i] == 2:
            multiple_VFIx2(out_frame_list[i], out_frame_list[i + 1],
                           os.path.join(output_dir, f"img_{i:03d}.5.png"))
            img_name.append(f"(new) frame {i + 0.5}")
        elif ease_val[i] == 3:
            multiple_VFIx4(out_frame_list[i], out_frame_list[i + 1])
            img_name.extend(f"(new) frame {i + frac}" for frac in (0.25, 0.5, 0.75))
    img_name.append(f"frame {n_frames - 1}")

    files = get_sorted_img(output_dir)

    zip_name = os.path.join(output_dir, "frame_list.zip")
    with zipfile.ZipFile(zip_name, 'w', compression=zipfile.ZIP_DEFLATED) as new_zip:
        for x in progress.tqdm(files, desc ="compress file..."):
            new_zip.write(x, os.path.basename(x))

    return {
        ease_result_gallery : list(zip(files, img_name)),
        ease_make_video : gr.Accordion(visible = True),
        last_ease_dir : output_dir,
        ease_zip : gr.File(value = zip_name)
    }
|
|
|
def VFI_two(l, r, flag ,output_dir = "output"):
    """Interpolate frames between two uploaded still images ("Mid Frame Generator" tab).

    Args:
        l: start image as delivered by ``gr.Image`` (an array), or None when
            nothing has been uploaded.
        r: end image, same conventions as ``l``.
        flag: "x2" inserts one frame, "x4" three; any other value (the UI's
            "x6(side ease)") runs multiple_VFIx6.
        output_dir: root directory; results go to ``output_dir/fix_img_<timestamp>/``.

    Returns:
        Dict updating ``frame_gen_result_gallery`` with every PNG of the run,
        in temporal order.

    Raises:
        gr.Error: if either image is missing.
    """
    # Clicking "Run model" before both uploads passes None for the missing image.
    if l is None or r is None:
        raise gr.Error("Please upload both a start image and an end image.")

    now = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    output_dir = os.path.join(output_dir, f"fix_img_{now}")
    os.makedirs(output_dir, exist_ok = True)

    l = Image.fromarray(l)
    r = Image.fromarray(r)

    # Cap the working resolution at about 3 megapixels, keeping the start
    # image's aspect ratio. Smaller images keep their native size.
    W, H = l.size
    mul = ((3e6) / (W * H)) ** (1/2)
    if mul < 1:
        W, H = int(W * mul), int(H * mul)
        l = l.resize((W, H))
    # Both frames go to the model as one pair, so they presumably need the same
    # size. Bring the end image to the start image's (possibly reduced) size
    # whenever the two differ.
    if r.size != (W, H):
        r = r.resize((W, H))

    l_name = os.path.join(output_dir, "img_000.png")
    r_name = os.path.join(output_dir, "img_001.png")
    l.save(l_name)
    r.save(r_name)

    if flag == "x4":
        multiple_VFIx4(l_name, r_name)
    elif flag == "x2":
        multiple_VFIx2(l_name, r_name, os.path.join(output_dir, "img_000.5.png"))
    else:
        multiple_VFIx6(l_name, r_name)

    return {
        frame_gen_result_gallery : gr.Gallery(visible=True, value=get_sorted_img(output_dir))
    }
|
|
|
def clear_fix():
    """Reset the Mid Frame Generator tab: empty both image inputs and the result gallery."""
    start_image = gr.Image(label="start image", sources =["upload"], value = None)
    end_image = gr.Image(label="end image", sources =["upload"], value = None)
    empty_gallery = gr.Gallery(visible=True, value=None)
    return {img_0: start_image, img_1: end_image, frame_gen_result_gallery: empty_gallery}
|
|
|
with gr.Blocks(theme=gr.themes.Default(), title = "Inshorts Animator V. 0.5") as demo:
    # NOTE(review): this block arrived with its indentation stripped. The
    # container nesting below is a best-effort reconstruction; check it against
    # the rendered UI.

    def info(request: gr.Request):
        """Print the client address chain on every page load.

        Proxied requests carry ``x-forwarded-for``. Direct connections may not,
        so fall back to the socket peer instead of raising KeyError on every load.
        """
        try:
            forwarded = request.headers["x-forwarded-for"]
        except KeyError:
            forwarded = None
        if forwarded:
            print(forwarded.split(","))
        else:
            client = getattr(request, "client", None)
            print([getattr(client, "host", None)])

    demo.load(info, None)

    gr.Markdown("""# Inshorts Animator V. 0.5 WebUI (Permitted User Only)""")

    # --- Tab 1: interpolate between two uploaded still images ----------------
    with gr.Tab("Mid Frame Generator"):
        with gr.Column():
            with gr.Row():
                img_0 = gr.Image(label="start image", sources =["upload"])
                img_1 = gr.Image(label="end image", sources =["upload"])
            with gr.Row():
                VFI_flag = gr.Radio(["x2", "x4", "x6(side ease)"], label="VFI ratio", value = "x2", interactive = True)
                image_button = gr.Button("Run model")
            frame_gen_result_gallery = gr.Gallery(
                visible=True,
                label="result", columns=[5], rows=[1], object_fit="contain", height="auto", preview = True,
                interactive = False)
            image_button.click(VFI_two, inputs=[img_0, img_1, VFI_flag],
                               outputs=[frame_gen_result_gallery])
            clear_button = gr.Button("Clear images")
            clear_button.click(clear_fix, inputs=[],
                               outputs=[img_0, img_1, frame_gen_result_gallery])

    # --- Tab 2: per-video tools (upload, fix single frames, slow-motion ease) -
    with gr.Tab("Video"):
        # Shown until a video has been uploaded and split into frames.
        with gr.Group(visible=True) as video_input_group:
            gr.Markdown(f"""#### only can handle {MAX_FRAME} frames""")
            with gr.Column():
                # Per-session directory holding the extracted img_XXX.png frames.
                input_dir = gr.State("")
                # FPS of the uploaded video, reused when encoding the eased result.
                fps = gr.Number(visible=False)
                video_input = gr.Video(label="Input Video", interactive=True, sources=['upload'])
                gr.Markdown("""If video frame size is big, it will be resized""")
                upload_button = gr.Button("upload video")

        # Shown once frames have been extracted.
        with gr.Group(visible=False) as image_edit_group:
            with gr.Row():
                with gr.Column():
                    with gr.Tab("Original Frame (for Monitoring)"):
                        # 1 for every frame index that fix_img has regenerated.
                        fixed_list = gr.State([0] * (MAX_FRAME))
                        selected = gr.Number(visible=False, label = "selected frame", interactive = False)
                        image_gallery = gr.Gallery(
                            label="inputs", columns=[MAX_FRAME], rows=[1], object_fit="contain", height="auto", preview = True,
                            show_download_button=False)
                        clear_video = gr.Button("clear video")

                    with gr.Tab("Frame Fixer"):
                        with gr.Column():
                            with gr.Group(visible=True) as fix_result_group:
                                fix_result_gallery = gr.Gallery(
                                    label="result", columns=[MAX_FRAME], rows=[1], object_fit="contain", height="auto", preview = True,
                                    interactive = False)
                                fix_button = gr.Button(visible = True)
                                with gr.Row():
                                    fixed_frame = gr.Text(visible=False, label = "fixed frame", interactive = False)
                            fix_button.click(fix_img, inputs=[selected, fixed_list, input_dir],
                                             outputs=[fix_result_gallery, fix_result_group, fixed_frame])

                            def update_fix_button_visible(evt: gr.SelectData):
                                """Remember the frame clicked in the input gallery and relabel the fix button.

                                NOTE(review): the range check uses MAX_FRAME, not the number of
                                frames actually extracted from a shorter clip.
                                """
                                flag = 0 < evt.index < MAX_FRAME - 1
                                msg = f"fix frame {evt.index}" if flag else f"can only fix 1 ~ {MAX_FRAME - 2}"
                                return {
                                    fix_button:gr.Button(msg, visible=True),
                                    selected : evt.index,
                                    fix_result_gallery : gr.Gallery(selected_index = evt.index)
                                }

                            image_gallery.select(update_fix_button_visible, None, [fix_button, selected, fix_result_gallery])

                    with gr.Tab("Motion easer"):
                        with gr.Column():
                            with gr.Column():
                                with gr.Group(visible=True) as ease_result_group:
                                    # Output directory of the most recent ease run (input to make_video).
                                    last_ease_dir = gr.State("")
                                    ease_result_gallery = gr.Gallery(
                                        label="result", columns=[MAX_FRAME], rows=[4], object_fit="contain", height="auto", preview = True,
                                        interactive = False)
                                    ease_button = gr.Button("ease", visible = True)

                                # Slow-down code per interval (frame i -> i + 1):
                                # 1 = unchanged, 2 = x2 (one new frame), 3 = x4 (three new frames).
                                plt_data = gr.State([1] * (MAX_FRAME - 1))
                                VFI_x = gr.Radio([("x1", 1), ("x2", 2), ("x4", 3)], value = 1, label="Slow ratio", info="adjust Slow ratio", interactive = True)
                                with gr.Row():
                                    edit_one_button = gr.Button("edit one scale", visible = True)
                                    edit_all_button = gr.Button("edit all scale", visible = True)
                                now_frame = gr.Slider(0, MAX_FRAME - 1 - 1, step=1, label="Start frame", info="Choose Start frame to make slow. Interpolation will apply to (frame ~ frame + 1)")

                                def plt_edit(data):
                                    """Plot the per-interval slow-down codes in ``data`` and return the Figure.

                                    Tick placement is set on this Axes only, rather than by mutating
                                    global rcParams after plotting (which left the first plot different
                                    from later ones). The figure is also released from pyplot's registry
                                    so repeated redraws do not accumulate open figures.
                                    """
                                    fig, ax = plt.subplots()
                                    x = np.arange(0, MAX_FRAME - 1) + 0.5
                                    y = np.array(data)
                                    ax.plot(x, y, color='black', marker="o", linewidth=2.5)
                                    ax.set_xticks(np.arange(0, MAX_FRAME))
                                    ax.set_yticks([1, 2, 3])
                                    ax.set_yticklabels(["x1", "x2\nslow", "x4\nslow"])
                                    ax.invert_yaxis()
                                    ax.grid(True)
                                    ax.tick_params(axis="x", bottom=False, labelbottom=False, top=True, labeltop=True)
                                    # Closing only detaches it from pyplot; the Figure can still be rendered.
                                    plt.close(fig)
                                    return fig

                                ease_plot = gr.Plot(value = plt_edit(plt_data.value), show_label=False)

                                with gr.Accordion("get result", visible = False) as ease_make_video:
                                    ease_zip = gr.File(label = "Download all image frames in Zip", interactive = False)
                                    make_video_button = gr.Button("make video")
                                    result_video = gr.Video(interactive = False)

                                def make_video(frame_dir, fps):
                                    """Encode the PNG frames of ``frame_dir`` into ``<frame_dir>/<basename>.mp4`` with ffmpeg.

                                    Raises:
                                        gr.Error: if no ease run exists yet, ffmpeg is missing, or encoding fails.
                                    """
                                    import subprocess

                                    # Without a previous ease run frame_dir is "", and the paths below
                                    # would resolve to the filesystem root ("/.mp4", "/input.txt").
                                    if not frame_dir:
                                        raise gr.Error("Run 'ease' before making a video.")
                                    t = os.path.basename(frame_dir)
                                    output_name = os.path.join(frame_dir, f"{t}.mp4")
                                    if os.path.exists(output_name):
                                        os.remove(output_name)

                                    # ffmpeg concat-demuxer list; entries are relative to the list file.
                                    list_file = os.path.join(frame_dir, "input.txt")
                                    with open(list_file, "w", encoding="utf-8") as f:
                                        f.writelines(f"file '{os.path.basename(line)}'\n" for line in get_sorted_img(frame_dir))

                                    # Pass an argument list instead of a shell string, so spaces or shell
                                    # metacharacters in paths cannot break or inject into the command.
                                    cmd = ["ffmpeg", "-r", str(fps), "-f", "concat", "-safe", "0", "-i", list_file,
                                           "-c:v", "libx264", "-preset", "veryslow", "-crf", "10", output_name]
                                    try:
                                        result = subprocess.run(cmd)
                                    except FileNotFoundError as err:
                                        raise gr.Error("ffmpeg was not found on PATH.") from err
                                    if result.returncode != 0 or not os.path.exists(output_name):
                                        raise gr.Error("ffmpeg failed to encode the video.")
                                    return output_name

                                make_video_button.click(make_video, inputs = [last_ease_dir, fps], outputs = [result_video])
                                ease_button.click(ease_frames, inputs=[plt_data, input_dir], outputs=[ease_result_gallery, ease_make_video, last_ease_dir, ease_zip])

                                def edit_one_scale(data, idx, x):
                                    """Set interval ``idx`` to code ``x`` (falsy -> 1) in place and redraw the plot."""
                                    idx = int(idx)
                                    if 0 <= idx < MAX_FRAME - 1:
                                        data[idx] = x if x else 1
                                    return plt_edit(data)

                                edit_one_button.click(edit_one_scale, inputs=[plt_data, now_frame, VFI_x] , outputs=[ease_plot])

                                def edit_all_scale(data, x):
                                    """Set every interval to code ``x`` (falsy -> 1) in place and redraw the plot."""
                                    data[:] = [x if x else 1] * len(data)
                                    return plt_edit(data)

                                edit_all_button.click(edit_all_scale, inputs=[plt_data, VFI_x], outputs=[ease_plot])

    def clear_vd(plt_data, fixed_list):
        """Return the Video tab to its initial state so a new video can be uploaded."""
        # The per-session State lists are reset in place (they are inputs, not outputs).
        plt_data[:] = [1] * len(plt_data)
        fixed_list[:] = [0] * len(fixed_list)
        return {video_input:gr.Video(label="Input Video", interactive=True, sources=['upload'], value = None),
                ease_result_gallery : gr.Gallery(
                    label="result", columns=[MAX_FRAME], rows=[4], object_fit="contain", height="auto", preview = True,
                    interactive = False, value = None),
                fix_result_gallery : gr.Gallery(
                    label="result", columns=[MAX_FRAME], rows=[1], object_fit="contain", height="auto", preview = True,
                    interactive = False, value = None),
                fixed_frame : gr.Text(visible=False, label = "fixed frame", interactive = False, value = None),
                # Hidden again, as at start-up, so results of the previous video are
                # not offered after the next upload.
                ease_make_video : gr.Accordion(visible = False),
                video_input_group:gr.Group(visible=True),
                image_edit_group:gr.Group(visible=False),
                ease_plot : gr.Plot(value = plt_edit(plt_data))}

    clear_video.click(clear_vd, inputs=[plt_data, fixed_list], outputs=[video_input, ease_result_gallery, fix_result_gallery, fixed_frame, ease_make_video, video_input_group, image_edit_group, ease_plot])

    def update_video_visible(video):
        """Extract up to MAX_FRAME frames of ``video`` into ``input/<timestamp>/img_XXX.png``.

        Frames larger than about 3 megapixels are downscaled (aspect preserved).

        Raises:
            gr.Error: if the video cannot be opened or yields no frames.
        """
        if not video:
            return {video_input_group:gr.Group(visible=True),
                    image_edit_group:gr.Group(visible=False),
                    image_gallery:[],
                    input_dir : "",
                    fps : 0
                    }

        cap = cv2.VideoCapture(video)
        try:
            if not cap.isOpened():
                raise gr.Error("Could not open the uploaded video.")
            video_fps = cap.get(cv2.CAP_PROP_FPS)
            H = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
            W = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
            # Factor that brings a frame to ~3e6 pixels; applied only when it
            # shrinks. Guard against containers that report a 0x0 size.
            mul = ((3e6) / (W * H)) ** (1/2) if W * H > 0 else 1.0
            H, W = int(H * mul), int(W * mul)

            now = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
            input_now = os.path.join("input", now)
            os.makedirs(input_now, exist_ok = True)

            frame_name_list = []
            while len(frame_name_list) < MAX_FRAME:
                ret, frame = cap.read()
                if not ret:
                    break
                frame_count = len(frame_name_list)
                img_name = os.path.join(input_now, f"img_{frame_count:03d}.png")
                if mul < 1:
                    frame = cv2.resize(frame, (W, H), interpolation=cv2.INTER_CUBIC)
                cv2.imwrite(img_name, frame)
                frame_name_list.append((img_name, f"frame {frame_count}"))
        finally:
            cap.release()

        if not frame_name_list:
            raise gr.Error("No frames could be read from the uploaded video.")

        return {video_input_group:gr.Group(visible=False),
                image_edit_group:gr.Group(visible=True),
                image_gallery:frame_name_list,
                input_dir : input_now,
                fps : video_fps
                }

    upload_button.click(update_video_visible,
                        [video_input],
                        [video_input_group, image_edit_group, image_gallery, input_dir, fps])
|
|
|
if __name__ == '__main__':
    # Launch the UI with a public share link; every visitor must pass
    # check_valid_login, and only ./input and ./output may be served as files.
    launch_options = dict(
        allowed_paths=["./input", "./output"],
        auth=check_valid_login,
        auth_message="Inshorts Animator V. 0.5 WebUI (Permitted User Only)",
        share=True,
    )
    demo.launch(**launch_options)
|
|