import gradio as gr
import numpy as np
import cv2
import os
import glob
import torch
import shutil
from PIL import Image
from tqdm import tqdm
from torch.nn import functional as F
from torchvision.transforms import functional as TF
from matplotlib import pyplot as plt
from modules.components.upr_net_freq import upr_freq as upr_freq002
from modules.components.upr_basic import upr as upr_basic
import datetime
import zipfile

os.system('python -m pip install --upgrade pip')
#from scipy.interpolate import make_interp_spline

# python3 -m vfi_inference_triplet --cuda_index 0 \
# --root ../VFI_Inference/thistriplet_notarget --pretrain_path ./pretrained/upr_freq002.pth \
# --pyr_level 7 --nr_lvl_skipped 0 --splat_mode average --down_scale 1

# (user id, password) tuples; add more tuples to the list to allow several users.
# Keeping these in a separate file is recommended.
KEY = [("test", "test"), ]


# Called on login.
# Modify this function if you need login information, client IPs, etc.
def check_valid_login(user_name, password):
    """Return True when the (user_name, password) pair is listed in KEY."""
    #client_ip = request.client.host
    #print(client_ip)
    credentials = (user_name, password)
    return credentials in KEY


# Number of leading frames taken from an uploaded video.
MAX_FRAME = 24

# Taken as-is from the VFI inference code.
DEVICE = 0#"cuda"
torch.cuda.set_device(DEVICE)

#ROOT = args.root
#SAVE_ROOT = f'output'
SCALE = 1
pyr_level = 7
nr_lvl_skipped = 0
splat_mode = "average"
pretrain_path = "./pretrained/upr_freq002.pth"

model = upr_freq002.Model(
    pyr_level=pyr_level,
    nr_lvl_skipped=nr_lvl_skipped,
    splat_mode=splat_mode,
)
sd = torch.load(pretrain_path, map_location='cpu')
# The checkpoint may wrap the weights under a 'model' key.
if 'model' in sd.keys():
    sd = sd['model']
print(model.load_state_dict(sd))
model = model.to(DEVICE)


def get_sorted_img(file_path):
    """Return the ``*.png`` paths under *file_path*, ordered numerically.

    The sort key is the text after the last underscore of the path with
    its final four characters removed, parsed as a float (so
    ``img_003.5.png`` sorts between ``img_003.png`` and ``img_004.png``).
    """
    def frame_number(png_path):
        return float(png_path.split("_")[-1][:-4])

    png_paths = glob.glob(os.path.join(file_path, "*.png"))
    png_paths.sort(key=frame_number)
    return png_paths


def multiple_pad(image, multiple):
    """Pad *image* on the right/bottom so H and W become multiples of *multiple*.

    *image* must unpack into four size entries; the last two are read as
    height and width.
    """
    _, _, height, width = image.size()
    # (-n) % m is the amount missing to reach the next multiple (0 if aligned).
    pad_bottom = (-height) % multiple
    pad_right = (-width) % multiple
    return TF.pad(image, (0, 0, pad_right, pad_bottom))


# Runs VFI on image 1 (path1) and image 2 to generate the middle image.
def multiple_VFIx2(path1, path2, output_name):
    """Interpolate the frame halfway between two images and write it to disk.

    Args:
        path1: Path of the first (earlier) image.
        path2: Path of the second (later) image.
        output_name: Path the interpolated image is written to.

    Raises:
        FileNotFoundError: If either input image cannot be read by OpenCV.
    """
    frames = []
    for file in (path1, path2):
        bgr = cv2.imread(file)
        if bgr is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"cannot read image: {file}")
        # BGR -> RGB, scale to [0, 1], HWC -> 1xCxHxW on the inference device.
        rgb = torch.from_numpy(bgr[:, :, [2, 1, 0]]) / 255
        frames.append(rgb.permute(2, 0, 1).unsqueeze(0).to(DEVICE))

    # Remember the size BEFORE padding/downscaling so the final crop restores
    # the input resolution (identical to the old behaviour when SCALE == 1).
    _, _, Hori, Wori = frames[0].size()

    frames = [multiple_pad(img, SCALE) for img in frames]
    frames = [F.interpolate(img, scale_factor=1 / SCALE, mode='bicubic') for img in frames]
    img0, img1 = frames

    with torch.no_grad():
        result_dict, _ = model(img0, img1, pyr_level=pyr_level,
                               nr_lvl_skipped=nr_lvl_skipped, time_step=0.5)
        out = F.interpolate(result_dict['imgt_pred'], scale_factor=SCALE, mode='bicubic')
        out = out[:, :, :Hori, :Wori].clamp(0, 1)

    # 1xCxHxW RGB float -> HWC BGR uint8 for OpenCV.
    cv2.imwrite(output_name, (out[0].cpu().permute(1, 2, 0) * 255).numpy().astype(np.uint8)[:, :, [2, 1, 0]])
    torch.cuda.empty_cache()


# Runs VFI three times between images 1 and 2 to produce 3 new frames.
def multiple_VFIx4(path1, path2):
    """Create the .25, .5 and .75 in-between frames of *path1* and *path2*.

    Outputs are written next to *path1* as ``<stem>.25.png``, ``<stem>.5.png``
    and ``<stem>.75.png``, where ``<stem>`` is *path1* without its extension.
    """
    stem = os.path.splitext(path1)[0]
    name1, name2, name3 = [f"{stem}{f}.png" for f in (".25", ".5", ".75")]
    # The midpoint must exist first; the quarter frames are built from it.
    multiple_VFIx2(path1, path2, name2)
    multiple_VFIx2(path1, name2, name1)
    multiple_VFIx2(name2, path2, name3)


# 0, 0.125, 0.25, 0.5, 0.75, 0.875, 1 -> 5 new frames are generated.
def multiple_VFIx6(path1, path2):
    """Create the .125, .25, .5, .75 and .875 in-between frames.

    Builds the three quarter frames via ``multiple_VFIx4`` and then adds one
    extra frame next to each end (between *path1* and .25, and between .75
    and *path2*).
    """
    stem = os.path.splitext(path1)[0]
    name_inf1, name1, name2, name_inf2 = [
        f"{stem}{f}.png" for f in (".125", ".25", ".75", ".875")
    ]
    multiple_VFIx4(path1, path2)
    multiple_VFIx2(path1, name1, name_inf1)
    multiple_VFIx2(name2, path2, name_inf2)


# Fixes a video frame by replacing its image, then outputs the result.
def _no_fix_update():
    """Build the update dict used when a fix request changes nothing."""
    return {
        fix_result_gallery : gr.Gallery(),
        fix_result_group : gr.Group(),
        fixed_frame : gr.Text()
    }


def fix_img(idx, fixed_list, input_dir = "input", output_dir = "output"):
    """Replace frame *idx* with a frame interpolated from its two neighbours.

    Args:
        idx: Index of the frame to fix (may be None when nothing was selected).
        fixed_list: Per-frame flags (1 = already fixed); updated in place.
        input_dir: Directory holding the extracted ``img_XXX.png`` frames.
        output_dir: Root directory; fixed frames go to ``fix_<basename of
            input_dir>`` below it.

    Returns:
        Update dict for the result gallery, its group and the fixed-frame text.
    """
    # Nothing selected yet: the hidden Number component yields None.
    if idx is None:
        return _no_fix_update()
    idx = int(idx)
    # Invalid index (first/last frame have no two neighbours) or already
    # fixed -> no change.
    if idx < 1 or idx > MAX_FRAME - 2 or fixed_list[idx] == 1:
        return _no_fix_update()

    now_time = os.path.basename(input_dir)
    output_dir = os.path.join(output_dir, f"fix_{now_time}")
    os.makedirs(output_dir, exist_ok = True)
    output_name = os.path.join(output_dir, f"img_{idx:03d}.png")
    multiple_VFIx2(os.path.join(input_dir, f"img_{idx - 1:03d}.png"),
                   os.path.join(input_dir, f"img_{idx + 1:03d}.png"),
                   output_name)
    fixed_list[idx] = 1

    # Rebuild the gallery in frame order: fixed frames come from output_dir,
    # untouched frames from input_dir.
    gallery_items = []
    fixed_indices = []
    for i in range(MAX_FRAME):
        if fixed_list[i] == 1:
            gallery_items.append((os.path.join(output_dir, f"img_{i:03d}.png"), f"(fixed) frame {i}"))
            fixed_indices.append(str(i))
        else:
            gallery_items.append((os.path.join(input_dir, f"img_{i:03d}.png"), f"frame {i}"))

    return {
        fix_result_gallery : gr.Gallery(value = gallery_items, selected_index = idx),
        fix_result_group : gr.Group(visible=True),
        fixed_frame : gr.Text(visible=True, value = ", ".join(fixed_indices)),
    }


# Runs the easing according to the values in the ease_val list.
def ease_frames(ease_val, input_dir = "input", output_dir = "output", progress=gr.Progress(track_tqdm=False)):
    """Insert interpolated frames between consecutive frames as requested.

    Args:
        ease_val: One entry per frame gap; 1 = nothing, 2 = one new frame,
            3 = three new frames.
        input_dir: Directory holding ``img_000.png`` .. ``img_<MAX_FRAME-1>.png``.
        output_dir: Root directory; results go to ``ease_<timestamp>`` below it.
        progress: Gradio progress tracker.

    Returns:
        Update dict for the result gallery, the result accordion, the last
        output directory and the downloadable zip file.
    """
    #now = os.path.basename(input_dir)
    now = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    output_dir = os.path.join(output_dir, f"ease_{now}")
    os.makedirs(output_dir, exist_ok = True)

    # Work on copies so the extracted input frames stay untouched.
    out_frame_list = [os.path.join(output_dir, f"img_{i:03d}.png") for i in range(MAX_FRAME)]
    for i, dst in enumerate(out_frame_list):
        shutil.copyfile(os.path.join(input_dir, f"img_{i:03d}.png"), dst)

    img_name = []
    for i in progress.tqdm(range(MAX_FRAME - 1), desc = "VFI frames..."):
        img_name.append(f"frame {i}")
        # x1 adds nothing, x2 adds one frame, x4 adds three frames.
        # The appended strings are the gallery captions of the new images.
        if ease_val[i] == 2:
            multiple_VFIx2(out_frame_list[i], out_frame_list[i + 1],
                           os.path.join(output_dir, f"img_{i:03d}.5.png"))
            img_name.append(f"(new) frame {i + 0.5}")
        elif ease_val[i] == 3:
            multiple_VFIx4(out_frame_list[i], out_frame_list[i + 1])
            img_name.append(f"(new) frame {i + 0.25}")
            img_name.append(f"(new) frame {i + 0.5}")
            img_name.append(f"(new) frame {i + 0.75}")
    img_name.append(f"frame {MAX_FRAME - 1}")

    files = get_sorted_img(output_dir)

    # Zip file for download.
    zip_name = os.path.join(output_dir, "frame_list.zip")
    with zipfile.ZipFile(zip_name, 'w', compression=zipfile.ZIP_DEFLATED) as new_zip:
        for x in progress.tqdm(files, desc ="compress file..."):
            new_zip.write(x, os.path.basename(x))

    return {
        ease_result_gallery : list(zip(files, img_name)),
        ease_make_video : gr.Accordion(visible = True),
        last_ease_dir : output_dir,
        ease_zip : gr.File(value = zip_name)
    }


# Takes two images and runs VFI on them.
def VFI_two(l, r, flag ,output_dir = "output"):
    """Interpolate between two uploaded images and show all frames.

    Args:
        l: Start image as an array accepted by ``Image.fromarray``.
        r: End image as an array accepted by ``Image.fromarray``.
        flag: "x2" for one new frame, "x4" for three, anything else for the
            five-frame side-ease variant.
        output_dir: Root directory; results go to ``fix_img_<timestamp>``.

    Returns:
        Update dict showing the sorted frames in the result gallery.

    Raises:
        gr.Error: If either image is missing.
    """
    if l is None or r is None:
        raise gr.Error("Please upload both a start image and an end image.")

    now = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    output_dir = os.path.join(output_dir, f"fix_img_{now}")
    os.makedirs(output_dir, exist_ok = True)

    l = Image.fromarray(l)
    r = Image.fromarray(r)

    # Keep the pixel count under a budget to avoid running out of memory.
    W, H = l.size
    # 1920 * 1080 * 1.2 * 1.2 is roughly 3e6, which is used as the budget.
    mul = ((3e6) / (W * H)) ** (1/2)
    if mul < 1:
        # Too large: shrink both images to the same reduced size.
        target = (int(W * mul), int(H * mul))
        l = l.resize(target)
        r = r.resize(target)
    elif r.size != l.size:
        # Both frames are fed to the model as a pair, so match the end image
        # to the start image (the shrink branch above already does this).
        r = r.resize(l.size)

    l_name, r_name = f"{output_dir}/img_000.png", f"{output_dir}/img_001.png"
    l.save(l_name)
    r.save(r_name)

    if flag == "x4":
        multiple_VFIx4(l_name, r_name)
    elif flag == "x2":
        output_name = f"{output_dir}/img_000.5.png"
        multiple_VFIx2(l_name, r_name, output_name)
    else:
        multiple_VFIx6(l_name, r_name)

    return {
        frame_gen_result_gallery : gr.Gallery(visible=True, value=get_sorted_img(output_dir))
    }


# Clears the uploaded images so that different images can be entered.
def clear_fix():
    """Reset the Mid Frame Generator inputs and result gallery."""
    return {
        img_0 : gr.Image(label="start image", sources =["upload"], value = None),
        img_1 : gr.Image(label="end image", sources =["upload"], value = None),
        frame_gen_result_gallery : gr.Gallery(visible=True, value=None)
    }


# NOTE(review): the source file was flattened, so the exact nesting of the
# layout containers below was reconstructed from context — confirm the layout.
with gr.Blocks(theme=gr.themes.Default(), title = "Inshorts Animator V. 0.5") as demo:
    def info(request: gr.Request):
        """Print the client address chain of the incoming request."""
        # Obtains the client IP.
        # Refer to this spot if specific IPs must be allowed or blocked later.
        headers = request.headers
        # The header is only present behind a proxy; skip it when absent
        # instead of raising KeyError.
        forwarded = headers.get("x-forwarded-for", "")
        if forwarded:
            print(forwarded.split(","))

    demo.load(info, None)
    gr.Markdown(f"""# Inshorts Animator V. 0.5 WebUI (Permitted User Only)""")

    with gr.Tab("Mid Frame Generator"):
        with gr.Column():
            with gr.Row():
                img_0 = gr.Image(label="start image", sources =["upload"])
                img_1 = gr.Image(label="end image", sources =["upload"])
            with gr.Row():
                VFI_flag = gr.Radio(["x2", "x4", "x6(side ease)"], label="VFI ratio", value = "x2", interactive = True)
                image_button = gr.Button("Run model")
            frame_gen_result_gallery = gr.Gallery(visible=True, label="result", columns=[5], rows=[1],
                                                  object_fit="contain", height="auto", preview = True,
                                                  interactive = False)
            image_button.click(VFI_two, inputs=[img_0, img_1, VFI_flag], outputs=[frame_gen_result_gallery])
            clear_button = gr.Button("Clear images")
            clear_button.click(clear_fix, inputs=[], outputs=[img_0, img_1, frame_gen_result_gallery])

    with gr.Tab("Video"):
        with gr.Group(visible=True) as video_input_group:
            gr.Markdown(f"""#### only can handle {MAX_FRAME} frames""")
            with gr.Column():
                input_dir = gr.State("")
                fps = gr.Number(visible=False)
                video_input = gr.Video(label="Input Video", interactive=True, sources=['upload'])
                gr.Markdown(f"""If video frame size is big, it will be resized""")
                upload_button = gr.Button("upload video")

        with gr.Group(visible=False) as image_edit_group:
            with gr.Row():
                with gr.Column():
                    with gr.Tab("Original Frame (for Monitoring)"):
                        fixed_list = gr.State([0] * (MAX_FRAME))
                        selected = gr.Number(visible=False, label = "selected frame", interactive = False)
                        image_gallery = gr.Gallery(
                            label="inputs", columns=[MAX_FRAME], rows=[1], object_fit="contain",
                            height="auto", preview = True, show_download_button=False)
                        clear_video = gr.Button("clear video")

                    with gr.Tab("Frame Fixer"):
                        with gr.Column():
                            #with gr.Row():
                            #with gr.Row():
                            with gr.Group(visible=True) as fix_result_group:
                                fix_result_gallery = gr.Gallery(
                                    label="result", columns=[MAX_FRAME], rows=[1], object_fit="contain",
                                    height="auto", preview = True, interactive = False)
                            fix_button = gr.Button(visible = True)
                            with gr.Row():
                                fixed_frame = gr.Text(visible=False, label = "fixed frame", interactive = False)
                            fix_button.click(fix_img, inputs=[selected, fixed_list, input_dir],
                                             outputs=[fix_result_gallery, fix_result_group, fixed_frame])

                            def update_fix_button_visible(evt: gr.SelectData):
                                """Relabel the fix button for the frame selected in the input gallery."""
                                # Only frames with both neighbours can be fixed.
                                flag = 0 < evt.index < MAX_FRAME - 1
                                msg = f"fix frame {evt.index}" if flag else f"can only fix 1 ~ {MAX_FRAME - 2}"
                                return {
                                    fix_button:gr.Button(msg, visible=True),
                                    selected : evt.index,
                                    fix_result_gallery : gr.Gallery(selected_index = evt.index)
                                }

                            image_gallery.select(update_fix_button_visible, None,
                                                 [fix_button, selected, fix_result_gallery])

                    with gr.Tab("Motion easer"):
                        with gr.Column():
                            with gr.Column():
                                #with gr.Row():
                                with gr.Group(visible=True) as ease_result_group:
                                    last_ease_dir = gr.State("")
                                    ease_result_gallery = gr.Gallery(
                                        label="result", columns=[MAX_FRAME], rows=[4], object_fit="contain",
                                        height="auto", preview = True, interactive = False)
                                ease_button = gr.Button("ease", visible = True)
                                plt_data = gr.State([1] * (MAX_FRAME - 1))
                                VFI_x = gr.Radio([("x1", 1), ("x2", 2), ("x4", 3)], value = 1, label="Slow ratio",
                                                 info="adjust Slow ratio", interactive = True)
                                with gr.Row():
                                    edit_one_button = gr.Button("edit one scale", visible = True)
                                    edit_all_button = gr.Button("edit all scale", visible = True)
                                now_frame = gr.Slider(0, MAX_FRAME - 1 - 1, step=1, label="Start frame",
                                                      info="Choose Start frame to make slow. Interpolation will apply to (frame ~ frame + 1)")

                                def plt_edit(data):
                                    """Plot the per-gap slow ratio and return the matplotlib figure."""
                                    fig, ax = plt.subplots()
                                    # Each value belongs to the gap between frame i and i + 1.
                                    x = np.arange(0, MAX_FRAME - 1) + 0.5
                                    y = np.array(data)
                                    ax.plot(x, y, color = 'black', marker = "o", linewidth = 2.5)
                                    ax.set_xticks(np.arange(0, MAX_FRAME))
                                    ax.set_yticks([1, 2, 3])
                                    ax.set_yticklabels(["x1", "x2\nslow", "x4\nslow"])
                                    ax.invert_yaxis()
                                    ax.grid(True)
                                    # Put the x ticks on top for THIS figure; the old rcParams
                                    # assignment only took effect from the next figure on.
                                    ax.tick_params(axis="x", bottom=False, labelbottom=False,
                                                   top=True, labeltop=True)
                                    # Detach from pyplot so figures do not pile up across calls.
                                    plt.close(fig)
                                    return fig

                                ease_plot = gr.Plot(value = plt_edit(plt_data.value), show_label=False)
                                with gr.Accordion("get result", visible = False) as ease_make_video:
                                    ease_zip = gr.File(label = "Download all image frames in Zip", interactive = False)
                                    make_video_button = gr.Button("make video")
                                    result_video = gr.Video(interactive = False)

    def make_video(frame_dir, fps):
        """Encode the PNG frames of *frame_dir* into ``<frame_dir>/<basename>.mp4``.

        Raises:
            gr.Error: If ffmpeg is not installed or exits with an error.
        """
        import subprocess

        t = os.path.basename(frame_dir)
        output_name = f"{frame_dir}/{t}.mp4"
        if os.path.exists(output_name):
            os.remove(output_name)
        frame_list = get_sorted_img(frame_dir)
        list_path = f"{frame_dir}/input.txt"
        with open(list_path, "w") as f:
            f.writelines(f"file '{os.path.basename(line)}'\n" for line in frame_list)
        # Argument list instead of a shell string: no quoting/injection issues.
        cmd = ["ffmpeg", "-r", str(fps), "-f", "concat", "-safe", "0", "-i", list_path,
               "-c:v", "libx264", "-preset", "veryslow", "-crf", "10", output_name]
        try:
            completed = subprocess.run(cmd, check=False)
        except FileNotFoundError as err:
            raise gr.Error("ffmpeg is not installed on the server.") from err
        if completed.returncode != 0:
            raise gr.Error("ffmpeg failed to create the video.")
        return output_name

    make_video_button.click(make_video, inputs = [last_ease_dir, fps], outputs = [result_video])
    ease_button.click(ease_frames, inputs=[plt_data, input_dir],
                      outputs=[ease_result_gallery, ease_make_video, last_ease_dir, ease_zip])

    def edit_one_scale(data, idx, x):
        """Set the slow ratio of gap *idx* (in place) and redraw the plot."""
        idx = int(idx)  # the slider may deliver a float
        if 0 <= idx < len(data):
            data[idx] = x if x else 1
        return plt_edit(data)

    edit_one_button.click(edit_one_scale, inputs=[plt_data, now_frame, VFI_x] , outputs=[ease_plot])

    def edit_all_scale(data, x):
        """Set every gap to slow ratio *x* (in place) and redraw the plot."""
        data[:] = [x if x else 1] * len(data)
        return plt_edit(data)

    edit_all_button.click(edit_all_scale, inputs=[plt_data, VFI_x], outputs=[ease_plot])

    def clear_vd(plt_data, fixed_list):
        """Reset the Video tab: clear state lists in place and reset widgets."""
        plt_data[:] = [1] * len(plt_data)
        fixed_list[:] = [0] * len(fixed_list)
        return {
            video_input:gr.Video(label="Input Video", interactive=True, sources=['upload'], value = None),
            ease_result_gallery : gr.Gallery(
                label="result", columns=[MAX_FRAME], rows=[4], object_fit="contain",
                height="auto", preview = True, interactive = False, value = None),
            fix_result_gallery : gr.Gallery(
                label="result", columns=[MAX_FRAME], rows=[1], object_fit="contain",
                height="auto", preview = True, interactive = False, value = None),
            fixed_frame : gr.Text(visible=False, label = "fixed frame", interactive = False, value = None),
            # NOTE(review): the accordion starts hidden; confirm that showing
            # it on clear is intended.
            ease_make_video : gr.Accordion(visible = True),
            video_input_group:gr.Group(visible=True),
            image_edit_group:gr.Group(visible=False),
            ease_plot : gr.Plot(value = plt_edit(plt_data))
        }

    clear_video.click(clear_vd, inputs=[plt_data, fixed_list],
                      outputs=[video_input, ease_result_gallery, fix_result_gallery, fixed_frame,
                               ease_make_video, video_input_group, image_edit_group, ease_plot])

    def update_video_visible(video):
        """Extract up to MAX_FRAME frames from *video* and switch to the editor.

        Frames are written as ``input/<timestamp>/img_XXX.png``. When no
        usable video is given, the upload view stays visible.
        """
        def no_video():
            return {
                video_input_group:gr.Group(visible=True),
                image_edit_group:gr.Group(visible=False),
                image_gallery:[],
                input_dir : "",
                fps : 0
            }

        if not video:
            return no_video()

        cap = cv2.VideoCapture(video)
        try:
            video_fps = cap.get(cv2.CAP_PROP_FPS)
            #print('video fps:', video_fps)
            H = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
            W = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
            # An unreadable file would otherwise divide by zero below.
            if not cap.isOpened() or W * H <= 0:
                return no_video()

            now = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
            input_now = os.path.join("input", now)
            os.makedirs(input_now, exist_ok = True)

            # Same ~3e6 pixel budget as the image tab.
            mul = ((3e6) / (W * H)) ** (1/2)
            H, W = int(H * mul), int(W * mul)

            frame_name_list = []
            for frame_count in range(MAX_FRAME):
                ret, frame = cap.read()
                if not ret:
                    break
                img_name = os.path.join(input_now, f"img_{frame_count:03d}.png")
                if mul < 1:
                    frame = cv2.resize(frame, (W, H), interpolation=cv2.INTER_CUBIC)
                cv2.imwrite(img_name, frame)
                frame_name_list.append((img_name, f"frame {frame_count}"))
        finally:
            cap.release()

        if not frame_name_list:
            return no_video()

        return {
            video_input_group:gr.Group(visible=False),
            image_edit_group:gr.Group(visible=True),
            image_gallery:frame_name_list,
            input_dir : input_now,
            fps : video_fps
        }

    upload_button.click(update_video_visible, [video_input],
                        [video_input_group, image_edit_group, image_gallery, input_dir, fps])


if __name__ == '__main__':
    demo.launch(allowed_paths=["./input", "./output"], auth = check_valid_login,
                auth_message = "Inshorts Animator V. 0.5 WebUI (Permitted User Only)", share = True)