AutoLink / app.py
xingzhehe's picture
cache examples
f226284
raw
history blame
2.73 kB
from models.model import Model as AutoLink
import gradio as gr
import PIL
import torch
import os
import imageio
import numpy as np
# Prefer the GPU whenever CUDA is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Relative path to the pretrained AutoLink checkpoint used by this demo.
_CHECKPOINT_PATH = os.path.join(
    "checkpoints",
    "celeba_wild_k32_m0.8_b16_t0.00075_sklr512",
    "model.ckpt",
)
autolink = AutoLink.load_from_checkpoint(_CHECKPOINT_PATH)
# Move the loaded model onto the selected device.
# NOTE(review): the model is never explicitly switched to eval mode here;
# confirm that the model's forward handles inference mode itself.
autolink.to(device)
def predict_image(image_in: PIL.Image.Image) -> PIL.Image.Image:
    """Run the AutoLink model on a single uploaded image.

    Args:
        image_in: The image from the Gradio ``Image`` input (``type="pil"``),
            or ``None`` when nothing was uploaded.

    Returns:
        The model's output for the image, displayed in the output
        ``gr.Image`` component.

    Raises:
        gr.Error: If no image was uploaded.
    """
    # Identity check: ``== None`` would dispatch to the image's ``__eq__``.
    if image_in is None:
        raise gr.Error("Please upload a video or image.")
    edge_map = autolink(image_in)
    return edge_map
def predict_video(video_in: str) -> str:
    """Run the AutoLink model on every frame of a video and save the result.

    Args:
        video_in: Path to the uploaded video file from the Gradio ``Video``
            input, or ``None`` when nothing was uploaded.

    Returns:
        Path to the rendered video. It is written in the same directory as the
        input, named ``<input path without extension>_out.mp4``, at the
        input's frame rate.

    Raises:
        gr.Error: If no video was uploaded.
    """
    if video_in is None:
        raise gr.Error("Please upload a video or image.")
    # splitext strips an extension of any length (".webm", ".mov", ...),
    # unlike slicing off a fixed four characters.
    video_out = os.path.splitext(video_in)[0] + '_out.mp4'
    # Context managers close both files even if the model raises part-way
    # through the video (the original leaked the reader unconditionally).
    with imageio.get_reader(video_in) as reader:
        fps = reader.get_meta_data()['fps']
        with imageio.get_writer(video_out, mode='I', fps=fps) as writer:
            for frame in reader:
                # The model is fed PIL images, as in predict_image.
                edge_map = autolink(PIL.Image.fromarray(frame))
                writer.append_data(np.array(edge_map))
    return video_out
# Two-tab Gradio UI: one tab runs predict_image on a still image, the other
# runs predict_video on a video. Each tab registers one bundled example with
# cache_examples=True so Gradio precomputes its output.
with gr.Blocks() as blocks:
    gr.Markdown("""
# AutoLink
## Self-supervised Learning of Human Skeletons and Object Outlines by Linking Keypoints
## This demo is specifically for self-supervised facial landmark detection
#### Note that there is no face detection in this demo, so please make sure that the face is center-around in the image.
* [Paper](https://arxiv.org/abs/2205.10636)
* [Project Page](https://xingzhehe.github.io/autolink/)
* [GitHub](https://github.com/xingzhehe/AutoLink-Self-supervised-Learning-of-Human-Skeletons-and-Object-Outlines-by-Linking-Keypoints)
""")
    with gr.Tab("Image"):
        with gr.Row():
            with gr.Column():
                image_in = gr.Image(source="upload", type="pil", visible=True)
            with gr.Column():
                image_out = gr.Image()
        run_btn = gr.Button("Run")
        run_btn.click(fn=predict_image, inputs=[image_in], outputs=[image_out])
        # One value per input component: this tab has a single input, so the
        # stray trailing ``None`` in the original example row is dropped.
        gr.Examples(fn=predict_image, examples=[["assets/jackie_chan.jpg"]],
                    inputs=[image_in], outputs=[image_out],
                    cache_examples=True)
    with gr.Tab("Video"):
        with gr.Row():
            with gr.Column():
                # NOTE(review): Gradio 3.x documents the Video output-format
                # option as ``format``; confirm ``type="mp4"`` has the intended
                # effect with the pinned Gradio version.
                video_in = gr.Video(source="upload", type="mp4")
            with gr.Column():
                video_out = gr.Video()
        run_btn = gr.Button("Run")
        run_btn.click(fn=predict_video, inputs=[video_in], outputs=[video_out])
        gr.Examples(fn=predict_video, examples=[["assets/00344.mp4"]],
                    inputs=[video_in], outputs=[video_out],
                    cache_examples=True)
blocks.launch()