vad_go / main.py
HoneyTian's picture
update
25144ed
raw
history blame
4.89 kB
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import json
import logging
import platform
import re
import shlex
from pathlib import Path
from project_settings import project_path, log_directory
import log
log.setup(log_directory=log_directory)
import gradio as gr
from toolbox.os.command import Command
main_logger = logging.getLogger("main")
def get_args():
    """
    Parse the command-line options of the demo.

    Supported options:
        --example_wav_dir: directory that ``main`` scans for example ``*.wav``
            files. Defaults to ``<project_path>/data/examples``.

    :return: the ``argparse.Namespace`` produced by ``parse_args()``.
    """
    default_example_dir = (project_path / "data/examples").as_posix()

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--example_wav_dir", default=default_example_dir, type=str)
    return arg_parser.parse_args()
def process_uploaded_file(vad_engine: str, filename: str, silence_time: float = 0.3, longest_activate: float = 3.0) -> str:
    """
    Dispatch an audio file to the VAD backend selected by ``vad_engine``.

    :param vad_engine: ``"nx_vad"`` or ``"silero-vad"``; any other value is
        reported back as invalid instead of raising.
    :param filename: path of the audio file to analyse.
    :param silence_time: forwarded unchanged to the selected backend.
    :param longest_activate: forwarded unchanged to the selected backend.
    :return: the backend's text output, or ``"vad engine invalid: <engine>"``
        when the engine name is not recognised.
    """
    # Reject unknown engines first so the happy path below stays flat.
    if vad_engine not in ("nx_vad", "silero-vad"):
        return f"vad engine invalid: {vad_engine}"

    backend = run_nx_vad if vad_engine == "nx_vad" else run_silero_vad
    return backend(filename, silence_time, longest_activate)
def run_nx_vad(filename: str, silence_time: float = 0.3, longest_activate: float = 3.0) -> str:
    """
    Run the ``vad_bins/nx_vad`` binary on an audio file and extract its segments.

    The command output is searched for the line sequence
    ``<n> / VadFlagPrepare / <n> / VadFlagSpeaking / [one optional
    VadFlagPause -> VadFlagSpeaking cycle] / <n> / VadFlagNoSpeech`` and the
    three captured numbers of every match are collected.

    :param filename: path of the audio file; normalised to POSIX separators.
    :param silence_time: value passed as ``--silence_time``.
    :param longest_activate: value passed as ``--longest_activate``.
    :return: JSON text (``indent=2``) holding one
        ``[prepare, speaking, no_speech]`` triple of digit strings per match;
        ``"[]"`` when the output contains no match.
    """
    filename = Path(filename).as_posix()
    main_logger.info("do nx vad: %s", filename)

    # The path comes from an uploaded file, so it may contain spaces or shell
    # metacharacters. Quote every interpolated value so it reaches the binary
    # as a single argument. Values made only of shell-safe characters are left
    # untouched by shlex.quote.
    # NOTE(review): assumes Command.popen hands the string to a POSIX shell - confirm.
    cmd = "vad_bins/nx_vad --filename {} --silence_time {} --longest_activate {}".format(
        shlex.quote(filename),
        shlex.quote(str(silence_time)),
        shlex.quote(str(longest_activate)),
    )
    vad_result = Command.popen(cmd)

    # Each "[\r\n]" consumes exactly one line-break character between the
    # number lines and the flag lines. Only the numbers before Prepare,
    # Speaking and NoSpeech are captured; the pause cycle is matched but not
    # captured.
    pattern = (
        "(\\d+)[\r\n]VadFlagPrepare"
        "[\r\n](\\d+)[\r\n]VadFlagSpeaking"
        "(?:[\r\n](?:\\d+)[\r\n]VadFlagPause[\r\n](?:\\d+)[\r\n]VadFlagSpeaking)?"
        "[\r\n](\\d+)[\r\n]VadFlagNoSpeech"
    )
    segments = re.findall(pattern, vad_result)
    return json.dumps(segments, ensure_ascii=False, indent=2)
def run_silero_vad(filename: str, silence_time: float = 0.3, longest_activate: float = 3.0) -> str:
    """
    Run the ``vad_bins/silero`` binary on an audio file and return its raw output.

    :param filename: path of the audio file; normalised to POSIX separators.
    :param silence_time: accepted but not forwarded to the binary.
    :param longest_activate: accepted but not forwarded to the binary.
    :return: whatever ``Command.popen`` returns for the command, unmodified.
    """
    filename = Path(filename).as_posix()
    main_logger.info("do silero vad: %s", filename)

    # The path comes from an uploaded file, so quote it to keep spaces and
    # shell metacharacters from splitting or altering the command.
    # NOTE(review): assumes Command.popen hands the string to a POSIX shell - confirm.
    cmd = "vad_bins/silero {}".format(shlex.quote(filename))
    return Command.popen(cmd)
def shell(cmd: str):
    """
    Execute ``cmd`` through ``Command.popen`` and return that call's result.

    NOTE(review): ``main`` binds this function to a free-text box in the
    "shell" tab, so anyone who can reach the UI can run arbitrary commands on
    the host - confirm that this tab is meant to be exposed.

    :param cmd: command line to execute.
    :return: the value returned by ``Command.popen``.
    """
    output = Command.popen(cmd)
    return output
def main():
    """
    Build the Gradio demo and serve it on port 7860.

    The UI has two tabs:
      * "Upload from disk": upload an audio file, pick a VAD engine and its
        parameters, and run ``process_uploaded_file``.
      * "shell": send a command line to ``shell`` and show its output.

    The server binds to 127.0.0.1 on Windows and to 0.0.0.0 elsewhere; public
    sharing is always disabled.
    """
    args = get_args()

    # One example row per *.wav file, in the same column order as vad_inputs:
    # engine, audio path, silence time, longest activate.
    examples = [
        ["nx_vad", wav_path.as_posix(), 0.3, 3.0]
        for wav_path in Path(args.example_wav_dir).glob("*.wav")
    ]

    with gr.Blocks() as blocks:
        gr.Markdown(value="## GO语言实现的VAD.")

        with gr.Tabs():
            with gr.TabItem("Upload from disk"):
                uploaded_file = gr.Audio(
                    sources=["upload"],
                    type="filepath",
                    label="Upload from disk",
                )
                uploaded_vad_engine = gr.Dropdown(
                    choices=["nx_vad", "silero-vad"],
                    value="nx_vad",
                    label="vad_engine",
                )
                uploaded_silence_time = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.3, step=0.01,
                    label="silence time",
                )
                uploaded_longest_activate = gr.Slider(
                    minimum=0.0, maximum=20.0, value=3.0, step=0.1,
                    label="longest activate",
                )
                upload_button = gr.Button("Run VAD")
                uploaded_output = gr.Textbox(label="outputs")

                # The examples table and the button feed the same components
                # into the same handler, so name the wiring once.
                vad_inputs = [
                    uploaded_vad_engine,
                    uploaded_file,
                    uploaded_silence_time,
                    uploaded_longest_activate,
                ]
                vad_outputs = [uploaded_output]

                gr.Examples(
                    examples=examples,
                    inputs=vad_inputs,
                    outputs=vad_outputs,
                    fn=process_uploaded_file,
                )
                upload_button.click(
                    process_uploaded_file,
                    inputs=vad_inputs,
                    outputs=vad_outputs,
                )

            with gr.TabItem("shell"):
                shell_text = gr.Textbox(label="cmd")
                shell_button = gr.Button("run")
                shell_output = gr.Textbox(label="output")

                shell_button.click(
                    shell,
                    inputs=[shell_text],
                    outputs=[shell_output],
                )

    # Loopback only on Windows; every interface on other platforms.
    on_windows = platform.system() == "Windows"
    blocks.queue().launch(
        share=False,
        server_name="127.0.0.1" if on_windows else "0.0.0.0",
        server_port=7860,
    )
# Start the demo only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()