File size: 5,605 Bytes
16c8067
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os
import subprocess
from pathlib import Path

import gradio as gr

from config import hparams as hp
from config import hparams_gradio as hp_gradio
from nota_wav2lip import Wav2LipModelComparisonGradio

# The inference device is taken from the Gradio hyper-parameter config rather
# than being auto-detected here.
device = hp_gradio.device
print(f'Using {device} for inference.')

# Sample tables from the same config; the entry point below indexes them by
# sample name to obtain each sample's file path.
_sample_config = hp_gradio.sample
video_label_dict = _sample_config.video
audio_label_dict = _sample_config.audio

# Optional download sources, read from the environment. Each one may be unset
# (None), in which case the matching download step is skipped.
LRS_ORIGINAL_URL = os.getenv('LRS_ORIGINAL_URL', None)
LRS_COMPRESSED_URL = os.getenv('LRS_COMPRESSED_URL', None)
LRS_INFERENCE_SAMPLE = os.getenv('LRS_INFERENCE_SAMPLE', None)


def _run_command(command):
    """Run *command* without a shell and return its exit status.

    Args:
        command: Argument list; the first item is the executable name.

    Returns:
        The exit status of the process, or 127 when the executable could not
        be started at all.
    """
    try:
        return subprocess.call(command)
    except OSError as err:
        # Keep start-up best-effort: a missing tool is reported, not raised.
        print(f'Could not run {command[0]}: {err}')
        return 127


def _download_if_missing(destination, url):
    """Download *url* to *destination* with ``wget`` unless the file exists.

    Nothing happens when *destination* is already present or *url* is None.
    The command is passed as an argument list, so characters such as ``&``,
    ``;`` or spaces in the URL or path are never interpreted by a shell.

    Args:
        destination: Path of the file to create.
        url: Address to fetch, or None to skip the download.
    """
    target = Path(destination)
    if target.exists() or url is None:
        return

    # NOTE(review): --no-check-certificate disables TLS verification. It is
    # kept because the original command used it; drop it if the hosts serve
    # valid certificates.
    status = _run_command(
        ['wget', '--no-check-certificate', '-O', str(destination), url]
    )

    if status != 0 and target.exists():
        # ``wget -O`` opens the output file before the transfer succeeds, so a
        # failed download leaves a stub behind. Remove it; otherwise the
        # existence check above would skip the download on every later start.
        print(f'Download of {destination} failed (exit status {status}).')
        try:
            target.unlink()
        except OSError as err:
            print(f'Could not remove incomplete file {destination}: {err}')


_download_if_missing(hp.inference.model.wav2lip.checkpoint, LRS_ORIGINAL_URL)
_download_if_missing(hp.inference.model.nota_wav2lip.checkpoint, LRS_COMPRESSED_URL)

path_inference_sample = "sample.tar.gz"
_download_if_missing(path_inference_sample, LRS_INFERENCE_SAMPLE)
# Unpack the sample archive into the current directory; skipped when there is
# no archive, instead of invoking tar on a missing file.
if Path(path_inference_sample).exists():
    _run_command(['tar', '-zxvf', path_inference_sample])


if __name__ == "__main__":
    # Backend object whose methods are wired to the UI events further down.
    servicer = Wav2LipModelComparisonGradio(
        device=device,
        video_label_dict=video_label_dict,
        audio_label_list=audio_label_dict,
        default_video='v1',
        default_audio='a1',
    )

    # Register every sample video (in sorted name order) together with the
    # .json file that shares its path stem.
    for video_key in sorted(video_label_dict):
        video_file = Path(video_label_dict[video_key])
        json_file = video_file.with_suffix('.json')
        servicer.update_video(video_file, json_file, name=video_key)

    # Register every sample audio clip, again in sorted name order.
    for audio_key in sorted(audio_label_dict):
        servicer.update_audio(Path(audio_label_dict[audio_key]), name=audio_key)

    def _build_result_panel(heading, video_label, params_key):
        """Create one result column and return its event output components.

        Must be called inside an active ``gr.Blocks`` layout context.

        Args:
            heading: Markdown heading shown at the top of the panel.
            video_label: Label of the output video component.
            params_key: Key into ``servicer.params`` for the "# Parameters" box.

        Returns:
            ``[video, inference_time_box, fps_box]`` for use as event outputs.
        """
        with gr.Column(variant='panel'):
            gr.Markdown(heading)
            result_video = gr.Video(label=video_label, interactive=False)
            with gr.Column():
                with gr.Row(equal_height=True):
                    elapsed_box = gr.Textbox(value="", label="Total inference time (sec)")
                    fps_box = gr.Textbox(value="", label="FPS")
                # Display-only value; it is not used as an event output.
                gr.Textbox(value=servicer.params[params_key], label="# Parameters")
        return [result_video, elapsed_box, fps_box]

    page_css = Path('docs/main.css').read_text()
    with gr.Blocks(theme='nota-ai/theme', css=page_css) as demo:
        gr.Markdown(Path('docs/header.md').read_text())
        gr.Markdown(Path('docs/description.md').read_text())

        with gr.Row():
            # Input panel: sample previews, sample pickers and the two
            # generate buttons.
            with gr.Column(variant='panel'):

                gr.Markdown('## Select input video and audio', sanitize_html=False)

                # Read-only previews of the currently selected samples.
                preview_video = gr.Video(interactive=False, label="Input Video")
                preview_audio = gr.Audio(interactive=False, label="Input Audio")

                # Radio pickers for the sample to use.
                video_choice = gr.components.Radio(
                    video_label_dict,
                    type='value',
                    label="Select an input video:",
                )
                audio_choice = gr.components.Radio(
                    audio_label_dict,
                    type='value',
                    label="Select an input audio:",
                )

                # One trigger button per model.
                with gr.Row(equal_height=True):
                    run_original = gr.Button(value="Generate with Original Model", variant="primary")
                    run_compressed = gr.Button(value="Generate with Compressed Model", variant="primary")

            # Result panels: original model first, compressed model second.
            original_outputs = _build_result_panel(
                '## Original Wav2Lip', "Original Model", 'wav2lip'
            )
            compressed_outputs = _build_result_panel(
                '## Compressed Wav2Lip (Ours)', "Compressed Model", 'nota_wav2lip'
            )

        # Refresh the preview whenever a different radio option is selected.
        preview_bindings = (
            (video_choice, servicer.switch_video_samples, preview_video),
            (audio_choice, servicer.switch_audio_samples, preview_audio),
        )
        for picker, switch_fn, preview in preview_bindings:
            picker.change(fn=switch_fn, inputs=picker, outputs=preview)

        # Each generate button feeds the two selections to its model callback
        # and fills that model's result panel.
        generate_bindings = (
            (run_original, servicer.generate_original_model, original_outputs),
            (run_compressed, servicer.generate_compressed_model, compressed_outputs),
        )
        for button, generate_fn, result_components in generate_bindings:
            button.click(generate_fn,
                         inputs=[video_choice, audio_choice],
                         outputs=result_components)

        gr.Markdown(Path('docs/footer.md').read_text())

    demo.queue().launch()