File size: 3,897 Bytes
5e8e534
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import json
import os
import tempfile

import gradio as gr
import TTS
from TTS.utils.synthesizer import Synthesizer
import numpy as np
from huggingface_hub import snapshot_download
from omegaconf import OmegaConf

from ipa.ipa import get_ipa, parse_ipa
from replace.tts import ChangedVitsConfig

# Replace the VitsConfig attribute of TTS.tts.configs.vits_config with the
# project's ChangedVitsConfig. Only lookups made through that module
# attribute after this line see the replacement; names already bound
# elsewhere via `from ... import VitsConfig` are unaffected. This must run
# before models_config is resolved below, because that is where models load.
# NOTE(review): presumably needed so the downloaded models' config.json
# files load with project-specific fields — confirm against replace/tts.py.
TTS.tts.configs.vits_config.VitsConfig = ChangedVitsConfig

def load_model(model_id):
    """Download a TTS model snapshot from the Hugging Face Hub and build a Synthesizer.

    Registered below as the OmegaConf resolver ``load_model``, so it runs
    while ``configs/models.yaml`` is being resolved.

    The snapshot's ``config.json`` mentions ``speakers.pth`` and
    ``language_ids.json`` by name. Every occurrence of those names is
    rewritten to the absolute path inside the downloaded snapshot. The
    patched config is then written to a private temporary file that the
    ``Synthesizer`` loads.

    Args:
        model_id: Hub repository id passed to ``snapshot_download``.

    Returns:
        A ``Synthesizer`` built from the snapshot's ``model.pth`` and the
        patched config.
    """
    model_dir = snapshot_download(model_id)
    config_file_path = os.path.join(model_dir, "config.json")
    model_ckpt_path = os.path.join(model_dir, "model.pth")
    speaker_file_path = os.path.join(model_dir, "speakers.pth")
    language_file_path = os.path.join(model_dir, "language_ids.json")

    with open(config_file_path, "r", encoding="utf-8") as f:
        content = f.read()
    # The paths are spliced into raw JSON text, so insert them JSON-escaped
    # (json.dumps output minus its surrounding quotes). Otherwise a Windows
    # path's backslashes would become invalid JSON escape sequences.
    content = content.replace("speakers.pth", json.dumps(speaker_file_path)[1:-1])
    content = content.replace(
        "language_ids.json", json.dumps(language_file_path)[1:-1]
    )

    # Use a unique temporary file rather than a fixed "temp_config.json" in
    # the current working directory. Separate loads can then never clobber
    # each other, and nothing is left behind in the CWD.
    fd, temp_config_path = tempfile.mkstemp(suffix=".json")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            f.write(content)
        return Synthesizer(
            tts_checkpoint=model_ckpt_path, tts_config_path=temp_config_path
        )
    finally:
        # NOTE(review): assumes Synthesizer reads the config file only while
        # it is being constructed (the path is not used again in this file).
        # Cleanup is best-effort.
        try:
            os.remove(temp_config_path)
        except OSError:
            pass

# Expose load_model to OmegaConf as the "load_model" resolver, so any
# ${load_model:...} interpolation in configs/models.yaml calls load_model.
# NOTE(review): presumably each model entry's "model" field is written as
# ${load_model:<hub repo id>} — confirm against configs/models.yaml.
OmegaConf.register_new_resolver("load_model", load_model)

# OmegaConf.to_object resolves all interpolations eagerly, so every model is
# downloaded and loaded here, at import time, before the UI is built. The
# path is relative to the current working directory. The rest of this file
# reads these per-model keys: "model", "speaker_mapping", and
# "avalible_dialect" (spelling must match the YAML).
models_config = OmegaConf.to_object(OmegaConf.load("configs/models.yaml"))

def text_to_speech(model_id: str, speaker: str, dialect, text: str):
    """Convert ``text`` to IPA and synthesize speech with the selected model.

    Args:
        model_id: Key into ``models_config`` selecting the loaded Synthesizer.
        speaker: Speaker name forwarded to ``Synthesizer.tts``.
        dialect: Dialect passed to ``get_ipa`` and as ``language_name``.
        text: Input sentence.

    Returns:
        ``(words, pinyin, (sample_rate, waveform))``, matching the
        word-segment Textbox, the pinyin Textbox, and the Audio output.

    Raises:
        gr.Error: If ``text`` is empty or whitespace-only, or if some words
            cannot be converted to IPA.
    """
    # Reject whitespace-only input too; it would otherwise reach the model
    # with nothing to synthesize.
    if not text or not text.strip():
        raise gr.Error("請勿輸入空字串。")
    model = models_config[model_id]["model"]

    words, ipa, pinyin, missing_words = get_ipa(text, dialect=dialect)
    if missing_words:
        raise gr.Error(
            f"句子中的[{','.join(missing_words)}]目前無法轉成 ipa。請嘗試其他句子。"
        )

    wav = model.tts(
        parse_ipa(ipa),
        speaker_name=speaker,
        language_name=dialect,
        split_sentences=False,
    )

    # Report the rate the synthesizer actually produces instead of assuming
    # 16 kHz for every model. Fall back to 16000 if the attribute is absent.
    sample_rate = getattr(model, "output_sample_rate", None) or 16000
    return words, pinyin, (sample_rate, np.array(wav))

def when_model_selected(model_id):
    """Refresh the speaker and dialect dropdowns for a newly selected model.

    Besides replacing the choices, each dropdown's value is reset to the new
    model's first option, just as the initial UI does for the default model.
    This prevents a selection from the previous model, which may not exist
    for this one, from staying selected.

    Args:
        model_id: Key into ``models_config``.

    Returns:
        ``gr.update`` objects for the speaker and dialect dropdowns, in
        that order.
    """
    model_config = models_config[model_id]
    speaker_mapping = model_config["speaker_mapping"]
    # (label, value) pairs taken from the mapping's items.
    speaker_drop_down_choices = list(speaker_mapping.items())
    # Key spelling ("avalible") must match configs/models.yaml.
    dialect_drop_down_choices = model_config["avalible_dialect"]
    return (
        gr.update(
            choices=speaker_drop_down_choices,
            value=next(iter(speaker_mapping.values()), None),
        ),
        gr.update(
            choices=dialect_drop_down_choices,
            value=next(iter(dialect_drop_down_choices), None),
        ),
    )


# Top-level Blocks container for the app. The css @import loads a stylesheet
# from tauhu.tw (presumably defining the "tauhu-oo" font family — confirm).
# "tauhu-oo" is listed first in the theme's font stack, followed by Source
# Sans Pro (Google Fonts) and generic system sans-serif fallbacks.
demo = gr.Blocks(
    title="臺灣客語語音生成系統",
    css="@import url(https://tauhu.tw/tauhu-oo.css);",
    theme=gr.themes.Default(
        font=(
            "tauhu-oo",
            gr.themes.GoogleFont("Source Sans Pro"),
            "ui-sans-serif",
            "system-ui",
            "sans-serif",
        )
    ),
)

with demo:

    # The first key of models_config is the model selected on startup. The
    # speaker and dialect dropdowns are pre-filled from that model's entry.
    default_model_id = list(models_config.keys())[0]
    model_drop_down = gr.Dropdown(
        models_config.keys(),
        value=default_model_id,
    )
    # Choices are (label, value) pairs from speaker_mapping; the initial
    # value is the mapping's first value.
    speaker_drop_down = gr.Dropdown(
        choices=[(k,v) for k, v in models_config[default_model_id]["speaker_mapping"].items()],
        value=list(models_config[default_model_id]["speaker_mapping"].values())[0]
    )
    # Key spelling ("avalible") must match configs/models.yaml.
    dialect_drop_down = gr.Dropdown(
        choices=models_config[default_model_id]["avalible_dialect"],
        value=models_config[default_model_id]["avalible_dialect"][0]
    )

    # When the model dropdown's input event fires, repopulate the speaker and
    # dialect dropdowns for the chosen model.
    model_drop_down.input(
        when_model_selected,
        inputs=[model_drop_down],
        outputs=[speaker_drop_down, dialect_drop_down]
    )



    gr.Markdown(
        """
        # 臺灣客語語音生成系統
        """
    )
    # Wire text_to_speech to the three dropdowns plus a free-text box. The
    # three outputs match its (words, pinyin, (rate, waveform)) return value.
    gr.Interface(
        text_to_speech,
        inputs=[
            model_drop_down,
            speaker_drop_down,
            dialect_drop_down,
            gr.Textbox(),
        ],
        outputs=[
            gr.Textbox(interactive=False, label="word segment"),
            gr.Textbox(interactive=False, label="pinyin"),
            gr.Audio(
                interactive=False, label="generated speech", show_download_button=True
            ),
        ],
        allow_flagging="auto",
    )

# Start the web server only when run as a script, so importing this module
# (e.g. by tooling that just needs the `demo` object) does not block on
# launch().
if __name__ == "__main__":
    demo.launch()