txya900619's picture
feat: init upload
5e8e534
raw
history blame
3.9 kB
import json
import os
import tempfile
import gradio as gr
import TTS
from TTS.utils.synthesizer import Synthesizer
import numpy as np
from huggingface_hub import snapshot_download
from omegaconf import OmegaConf
from ipa.ipa import get_ipa, parse_ipa
from replace.tts import ChangedVitsConfig
# Monkeypatch: replace Coqui TTS's VitsConfig with the project's
# ChangedVitsConfig (imported from replace.tts). This runs at import time,
# before any model is built through the OmegaConf "load_model" resolver below.
# NOTE(review): presumably TTS resolves the config class through this module
# attribute when parsing config.json, so the patch takes effect there — confirm.
TTS.tts.configs.vits_config.VitsConfig = ChangedVitsConfig
def load_model(model_id):
    """Download a TTS model from the Hugging Face Hub and wrap it in a Synthesizer.

    The repo's ``config.json`` refers to the speaker and language files by their
    bare names ("speakers.pth", "language_ids.json"). Every occurrence of those
    names is rewritten to the absolute path inside the downloaded snapshot, so
    the files are found regardless of the working directory. The rewritten
    config is written to a private temporary directory, which is removed once
    the Synthesizer has been constructed.

    Args:
        model_id: Hugging Face Hub repo id passed to ``snapshot_download``.

    Returns:
        A ``Synthesizer`` for ``model.pth`` in the downloaded snapshot.
    """
    model_dir = snapshot_download(model_id)
    config_file_path = os.path.join(model_dir, "config.json")
    model_ckpt_path = os.path.join(model_dir, "model.pth")
    speaker_file_path = os.path.join(model_dir, "speakers.pth")
    language_file_path = os.path.join(model_dir, "language_ids.json")

    with open(config_file_path, "r", encoding="utf-8") as f:
        content = f.read()
    # Text substitution (rather than json.load/json.dump) leaves the rest of the
    # file untouched. json.dumps(...)[1:-1] yields the path escaped as the body
    # of a JSON string, so backslashes or quotes in it cannot break the JSON.
    content = content.replace("speakers.pth", json.dumps(speaker_file_path)[1:-1])
    content = content.replace(
        "language_ids.json", json.dumps(language_file_path)[1:-1]
    )

    # A per-call temp dir avoids writing a shared "temp_config.json" into the
    # CWD, where it would linger and could collide between model loads.
    with tempfile.TemporaryDirectory() as tmp_dir:
        temp_config_path = os.path.join(tmp_dir, "config.json")
        with open(temp_config_path, "w", encoding="utf-8") as f:
            f.write(content)
        # NOTE(review): assumes Synthesizer reads the config only while it is
        # being constructed; the temp dir is deleted as soon as this returns.
        return Synthesizer(
            tts_checkpoint=model_ckpt_path, tts_config_path=temp_config_path
        )
# Make load_model callable from configs/models.yaml through OmegaConf's
# "load_model" resolver (presumably each entry's "model" field uses it).
OmegaConf.register_new_resolver("load_model", load_model)

# to_object() resolves all interpolations, so every resolver call in the YAML
# (and therefore every model download/load) happens here, at import time.
_models_yaml = OmegaConf.load("configs/models.yaml")
models_config = OmegaConf.to_object(_models_yaml)
def text_to_speech(model_id: str, speaker: str, dialect: str, text: str):
    """Convert ``text`` to IPA and synthesise it with the selected model.

    Args:
        model_id: key into ``models_config``; its "model" entry is used for synthesis.
        speaker: speaker name forwarded to ``model.tts`` as ``speaker_name``.
        dialect: dialect used for the IPA lookup and forwarded as ``language_name``.
        text: the input sentence.

    Returns:
        ``(words, pinyin, (sample_rate, waveform))``. The last item is the
        ``(rate, ndarray)`` pair that ``gr.Audio`` accepts.

    Raises:
        gr.Error: if ``text`` is empty or whitespace-only, or if some words
            cannot be converted to IPA.
    """
    model = models_config[model_id]["model"]
    # Whitespace-only input would otherwise fall through to get_ipa/model.tts.
    if not text or not text.strip():
        raise gr.Error("請勿輸入空字串。")
    words, ipa, pinyin, missing_words = get_ipa(text, dialect=dialect)
    if len(missing_words) > 0:
        raise gr.Error(
            f"句子中的[{','.join(missing_words)}]目前無法轉成 ipa。請嘗試其他句子。"
        )
    wav = model.tts(
        parse_ipa(ipa),
        speaker_name=speaker,
        language_name=dialect,
        split_sentences=False,
    )
    # Report the model's real output rate rather than assuming 16 kHz. A model
    # trained at another rate would otherwise play back at the wrong speed.
    # Fall back to the previous hard-coded value if the attribute is missing.
    sample_rate = getattr(model, "output_sample_rate", 16000)
    return words, pinyin, (sample_rate, np.array(wav))
def when_model_selected(model_id):
    """Refresh the speaker and dialect dropdowns after another model is chosen.

    Both choice lists are replaced with the new model's options, and each
    dropdown's value is reset to the model's first option. This mirrors how
    the dropdowns are initialised for the default model. Without the reset, a
    value selected for the previous model could stay selected even though it
    is no longer one of the choices.

    Args:
        model_id: key into ``models_config`` coming from the model dropdown.

    Returns:
        A pair of ``gr.update`` objects: (speaker dropdown, dialect dropdown).
    """
    model_config = models_config[model_id]
    # (label, value) pairs: the mapping key is displayed, the value is submitted.
    speaker_drop_down_choices = list(model_config["speaker_mapping"].items())
    dialect_drop_down_choices = model_config["avalible_dialect"]
    speaker_value = (
        speaker_drop_down_choices[0][1] if speaker_drop_down_choices else None
    )
    dialect_value = (
        dialect_drop_down_choices[0] if dialect_drop_down_choices else None
    )
    return (
        gr.update(choices=speaker_drop_down_choices, value=speaker_value),
        gr.update(choices=dialect_drop_down_choices, value=dialect_value),
    )
# Font stack for the whole page. "tauhu-oo" comes first; it is presumably
# defined by the stylesheet imported via the css argument below. Next comes
# Source Sans Pro from Google Fonts, then generic system fallbacks.
_FONT_STACK = (
    "tauhu-oo",
    gr.themes.GoogleFont("Source Sans Pro"),
    "ui-sans-serif",
    "system-ui",
    "sans-serif",
)

demo = gr.Blocks(
    theme=gr.themes.Default(font=_FONT_STACK),
    css="@import url(https://tauhu.tw/tauhu-oo.css);",
    title="臺灣客語語音生成系統",
)
with demo:
    # The first model in models_config is preselected, and the speaker and
    # dialect dropdowns start out populated from that model's entry.
    default_model_id = list(models_config)[0]
    default_model_config = models_config[default_model_id]
    default_speaker_choices = list(default_model_config["speaker_mapping"].items())
    default_dialects = default_model_config["avalible_dialect"]

    model_drop_down = gr.Dropdown(models_config.keys(), value=default_model_id)
    # Choices are (label, value) pairs; the first mapping value is preselected.
    speaker_drop_down = gr.Dropdown(
        choices=default_speaker_choices,
        value=default_speaker_choices[0][1],
    )
    dialect_drop_down = gr.Dropdown(
        choices=default_dialects,
        value=default_dialects[0],
    )

    # Re-populate the speaker/dialect options whenever the user picks a model.
    model_drop_down.input(
        when_model_selected,
        inputs=[model_drop_down],
        outputs=[speaker_drop_down, dialect_drop_down],
    )

    gr.Markdown("\n# 臺灣客語語音生成系統\n")

    gr.Interface(
        fn=text_to_speech,
        inputs=[model_drop_down, speaker_drop_down, dialect_drop_down, gr.Textbox()],
        outputs=[
            gr.Textbox(interactive=False, label="word segment"),
            gr.Textbox(interactive=False, label="pinyin"),
            gr.Audio(
                interactive=False,
                label="generated speech",
                show_download_button=True,
            ),
        ],
        allow_flagging="auto",
    )

demo.launch()