File size: 2,014 Bytes
233b0e2
bf3bbe6
 
c1cd135
efac2d0
 
 
3de3290
efac2d0
bf3bbe6
 
 
c1cd135
bf3bbe6
c1cd135
efac2d0
 
 
 
 
 
 
 
 
 
 
c1cd135
 
 
 
df00b4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf3bbe6
d7e26ed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import os
import torch
import gradio as gr

# Both the network and the conversion helper come from the upstream
# repository's torch.hub entry points.
_RVM_REPO = "PeterL1n/RobustVideoMatting"

# Matting network from the repo's "mobilenetv3" hub entry point.
model = torch.hub.load(_RVM_REPO, "mobilenetv3")

# Move the network onto the GPU when CUDA is available; otherwise it stays on
# whatever device torch.hub.load returned it on.
_cuda_ok = torch.cuda.is_available()
if _cuda_ok:
    print("Using GPU")
    model = model.cuda()

# Conversion routine (the repo's "converter" entry point) that `inference`
# calls with `model` to process an input video into output files.
convert_video = torch.hub.load(_RVM_REPO, "converter")


def inference(video):
    """Run Robust Video Matting on an uploaded video and return the composite.

    Args:
        video: Value of the ``gr.Video`` input component. This is presumably a
            path to the uploaded file. It is ``None`` when "Run" is pressed
            without uploading anything.

    Returns:
        The path ``"com.mp4"`` (relative to the working directory) where the
        composited output video is written.

    Raises:
        gr.Error: If no video was supplied. The UI then shows a readable message
            instead of ``None`` being forwarded to ``convert_video``.
    """
    if not video:
        raise gr.Error("Please upload a video (or pick an example) before clicking Run.")

    output_path = "com.mp4"
    # NOTE(review): every request writes this same file in the working
    # directory. That is only safe while the queue processes one job at a
    # time — confirm before raising queue concurrency.
    convert_video(
        model,  # The loaded model, can be on any device (cpu or cuda).
        input_source=video,  # A video file or an image sequence directory.
        downsample_ratio=0.25,  # [Optional] If None, make downsampled max size be 512px.
        output_type="video",  # Choose "video" or "png_sequence"
        output_composition=output_path,  # File path if video; directory path if png sequence.
        output_alpha=None,  # [Optional] Output the raw alpha prediction.
        output_foreground=None,  # [Optional] Output the raw foreground prediction.
        output_video_mbps=4,  # Output video mbps. Not needed for png sequence.
        seq_chunk=12,  # Process n frames at once for better parallelism.
        num_workers=1,  # Only for image sequence input. Reader threads.
        progress=True,  # Print conversion progress.
    )
    return output_path


_TITLE = "Robust Video Matting"
_DESCRIPTION = (
    "Gradio demo for Robust Video Matting. To use it, simply upload your video, "
    "or click one of the examples to load them. Read more at the links below."
)
_FOOTER_HTML = (
    "<p style='text-align: center'>"
    "<a href='https://arxiv.org/abs/2108.11515'>"
    "Robust High-Resolution Video Matting with Temporal Guidance</a>"
    " | "
    "<a href='https://github.com/PeterL1n/RobustVideoMatting'>Github Repo</a>"
    "</p>"
)

# Page layout, top to bottom:
# - title and blurb;
# - side-by-side input/output video players;
# - a Run button wired to `inference`;
# - one example clip;
# - a footer with links.
with gr.Blocks(title=_TITLE) as block:
    gr.Markdown(f"# {_TITLE}")
    gr.Markdown(_DESCRIPTION)

    with gr.Row():
        input_video = gr.Video(label="Input Video")
        output_video = gr.Video(label="Output Video")

    run_button = gr.Button("Run")
    run_button.click(fn=inference, inputs=input_video, outputs=output_video)

    # No `fn` is given here, so selecting the example presumably just loads it
    # into the input player; the user still presses "Run".
    gr.Examples(examples=[["example.mp4"]], inputs=[input_video])

    gr.HTML(_FOOTER_HTML)

# Enable the request queue with these settings:
# - at most 5 pending events;
# - REST routes closed, so calls can't skip the queue.
# Then start the server.
block.queue(max_size=5, api_open=False)
block.launch()