Spaces:
Runtime error
Runtime error
import os | |
import subprocess | |
from pathlib import Path | |
import gradio as gr | |
from config import hparams as hp | |
from config import hparams_gradio as hp_gradio | |
from nota_wav2lip import Wav2LipModelComparisonGradio | |
# The inference device comes from the Gradio hparams rather than being
# auto-detected through torch.
device = hp_gradio.device
print(f'Using {device} for inference.')

# Mappings from a sample's name to its file path, as declared in the hparams.
video_label_dict = hp_gradio.sample.video
audio_label_dict = hp_gradio.sample.audio

# Optional download locations supplied through the environment; a value of
# None disables the corresponding download.
LRS_ORIGINAL_URL = os.getenv('LRS_ORIGINAL_URL', None)
LRS_COMPRESSED_URL = os.getenv('LRS_COMPRESSED_URL', None)
LRS_INFERENCE_SAMPLE = os.getenv('LRS_INFERENCE_SAMPLE', None)


def _run_command(args):
    """Run *args* without a shell and return the exit status.

    A missing executable is reported and mapped to status 127 (the value a
    shell would report) instead of raising, so start-up stays best-effort.
    """
    try:
        return subprocess.call(args)
    except FileNotFoundError:
        print(f'Command not found: {args[0]}')
        return 127


def _download_if_missing(destination, url):
    """Fetch *url* into *destination* with wget unless the file already exists.

    Nothing happens when *url* is ``None`` or *destination* is already present.
    The command is passed as an argument list, so characters such as ``&`` or
    ``?`` in the URL and spaces in the path reach wget unchanged instead of
    being interpreted by a shell.

    Returns:
        The wget exit status, or ``None`` when no download was attempted.
    """
    if url is None or Path(destination).exists():
        return None
    status = _run_command(
        ["wget", "--no-check-certificate", "-O", str(destination), url]
    )
    if status != 0:
        # `wget -O` creates/truncates the output file before the transfer, so a
        # failed download would otherwise leave a stub that blocks every retry.
        print(f'Download failed (exit status {status}): {destination}')
        try:
            Path(destination).unlink()
        except FileNotFoundError:
            pass
    return status


_download_if_missing(hp.inference.model.wav2lip.checkpoint, LRS_ORIGINAL_URL)
_download_if_missing(hp.inference.model.nota_wav2lip.checkpoint, LRS_COMPRESSED_URL)

path_inference_sample = "sample.tar.gz"
_download_if_missing(path_inference_sample, LRS_INFERENCE_SAMPLE)
# NOTE(review): the original indentation was lost; extraction is assumed to run
# on every start whenever the archive is available -- confirm against the
# upstream file. tar is skipped when there is no archive to unpack.
if Path(path_inference_sample).exists():
    _run_command(["tar", "-zxvf", path_inference_sample])
def _read_doc(filename):
    """Return the text of *filename* from the ``docs`` directory."""
    return Path(f'docs/{filename}').read_text()


def _build_result_panel(heading, video_label, servicer, model_key):
    """Lay out one model's result panel inside the currently open container.

    The panel holds a heading, the generated video, the inference-time and FPS
    textboxes, and a textbox showing ``servicer.params[model_key]``.

    Returns:
        A tuple ``(video, inference_time_box, fps_box)`` for wiring to events.
    """
    with gr.Column(variant='panel'):
        gr.Markdown(heading)
        result_video = gr.Video(label=video_label, interactive=False)
        with gr.Column():
            with gr.Row(equal_height=True):
                elapsed_box = gr.Textbox(value="", label="Total inference time (sec)")
                fps_box = gr.Textbox(value="", label="FPS")
            # NOTE(review): indentation was lost in the source; this textbox is
            # assumed to sit below the row, inside the inner column -- confirm.
            gr.Textbox(value=servicer.params[model_key], label="# Parameters")
    return result_video, elapsed_box, fps_box


def main():
    """Register the sample media, build the comparison UI and launch it."""
    servicer = Wav2LipModelComparisonGradio(
        device=device,
        video_label_dict=video_label_dict,
        audio_label_list=audio_label_dict,
        default_video='v1',
        default_audio='a1'
    )

    # Register every sample in name order; each video is paired with the
    # .json file that shares its stem.
    for label in sorted(video_label_dict):
        video_path = Path(video_label_dict[label])
        servicer.update_video(video_path, video_path.with_suffix('.json'), name=label)
    for label in sorted(audio_label_dict):
        servicer.update_audio(Path(audio_label_dict[label]), name=label)

    with gr.Blocks(theme='nota-ai/theme', css=_read_doc('main.css')) as demo:
        gr.Markdown(_read_doc('header.md'))
        gr.Markdown(_read_doc('description.md'))

        with gr.Row():
            # Left panel: sample previews, sample pickers and generate buttons.
            with gr.Column(variant='panel'):
                gr.Markdown('## Select input video and audio', sanitize_html=False)
                sample_video = gr.Video(interactive=False, label="Input Video")
                sample_audio = gr.Audio(interactive=False, label="Input Audio")
                video_selection = gr.components.Radio(
                    video_label_dict, type='value', label="Select an input video:")
                audio_selection = gr.components.Radio(
                    audio_label_dict, type='value', label="Select an input audio:")
                with gr.Row(equal_height=True):
                    generate_original_button = gr.Button(
                        value="Generate with Original Model", variant="primary")
                    generate_compressed_button = gr.Button(
                        value="Generate with Compressed Model", variant="primary")

            # Middle and right panels: one result panel per model.
            original_outputs = _build_result_panel(
                '## Original Wav2Lip', "Original Model", servicer, 'wav2lip')
            compressed_outputs = _build_result_panel(
                '## Compressed Wav2Lip (Ours)', "Compressed Model", servicer, 'nota_wav2lip')

        # Swap the previewed sample whenever a radio selection changes.
        video_selection.change(fn=servicer.switch_video_samples,
                               inputs=video_selection, outputs=sample_video)
        audio_selection.change(fn=servicer.switch_audio_samples,
                               inputs=audio_selection, outputs=sample_audio)

        # Each generate button fills its own model's video, time and FPS fields.
        generate_original_button.click(servicer.generate_original_model,
                                       inputs=[video_selection, audio_selection],
                                       outputs=[*original_outputs])
        generate_compressed_button.click(servicer.generate_compressed_model,
                                         inputs=[video_selection, audio_selection],
                                         outputs=[*compressed_outputs])

        gr.Markdown(_read_doc('footer.md'))

    demo.queue().launch()


if __name__ == "__main__":
    main()