|
try:
    import spaces
except Exception:
    # The ``spaces`` package (Hugging Face ZeroGPU helpers) is optional.
    # Any failure to import it falls back to the no-op stand-in below;
    # ``except Exception`` still lets KeyboardInterrupt / SystemExit through.

    class NoneSpaces:
        """No-op replacement for the ``spaces`` module's ``GPU`` decorator."""

        def __init__(self):
            pass

        def GPU(self, fn=None, **kwargs):
            """Return the decorated function unchanged.

            Works both as a bare decorator (``@spaces.GPU``) and as a
            decorator factory (``@spaces.GPU(...)``); keyword arguments are
            accepted and ignored.
            """
            if fn is None:
                return lambda wrapped: wrapped
            return fn

    spaces = NoneSpaces()
|
|
|
import os |
|
import logging |
|
|
|
import numpy as np |
|
|
|
from modules.devices import devices |
|
from modules.synthesize_audio import synthesize_audio |
|
from modules.utils.cache import conditional_cache |
|
|
|
# LOG_LEVEL may be a level name in any case ("info", "DEBUG") or a numeric
# level ("10"). logging.basicConfig only understands upper-case level names
# or ints, so normalise first; an unset/blank value means INFO.
_log_level = os.getenv("LOG_LEVEL", "").strip().upper() or "INFO"
logging.basicConfig(
    level=int(_log_level) if _log_level.isdigit() else _log_level,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
|
|
|
|
|
import gradio as gr |
|
|
|
import torch |
|
|
|
from modules.ssml import parse_ssml |
|
from modules.SynthesizeSegments import SynthesizeSegments, combine_audio_segments |
|
|
|
from modules.speaker import speaker_mgr |
|
from modules.data import styles_mgr |
|
|
|
from modules.api.utils import calc_spk_style |
|
import modules.generate_audio as generate |
|
|
|
from modules.normalization import text_normalize |
|
from modules import refiner, config |
|
|
|
from modules.utils import env, audio |
|
from modules.SentenceSplitter import SentenceSplitter |
|
|
|
torch._dynamo.config.cache_size_limit = 64 |
|
torch._dynamo.config.suppress_errors = True |
|
torch.set_float32_matmul_precision("high") |
|
|
|
# Runtime-tunable WebUI limits; the __main__ block overwrites tts_max,
# ssml_max and max_batch_size from CLI args / environment.
webui_config = dict(
    tts_max=1000,  # max characters accepted by the TTS tab
    ssml_max=5000,  # max total segment text length for the SSML tab
    spliter_threshold=100,  # threshold passed to SentenceSplitter
    max_batch_size=8,  # upper bound of the batch-size sliders
)
|
|
|
|
|
def get_speakers():
    """Return the speakers listed by the shared ``speaker_mgr``."""
    speakers = speaker_mgr.list_speakers()
    return speakers
|
|
|
|
|
def get_styles():
    """Return the style entries listed by the shared ``styles_mgr``."""
    styles = styles_mgr.list_items()
    return styles
|
|
|
|
|
def segments_length_limit(segments, total_max: int):
    """Keep the leading segments whose cumulative text length fits the budget.

    Segments are consumed in order; as soon as adding a segment's
    ``len(segment["text"])`` would push the running total above
    ``total_max``, that segment and everything after it are dropped.

    Args:
        segments: iterable of mappings with a ``"text"`` key.
        total_max: maximum total number of characters to keep.

    Returns:
        A new list with the kept segments (same objects, original order).
    """
    kept = []
    remaining = total_max
    for segment in segments:
        remaining -= len(segment["text"])
        if remaining < 0:
            break
        kept.append(segment)
    return kept
|
|
|
|
|
@torch.inference_mode()
@spaces.GPU
def synthesize_ssml(ssml: str, batch_size=4):
    """Synthesize an SSML document into one combined audio clip.

    The markup is parsed with ``parse_ssml``. Segments are kept in order
    while their cumulative text length stays within
    ``webui_config["ssml_max"]`` (see ``segments_length_limit``). They are
    synthesized by ``SynthesizeSegments`` and concatenated with
    ``combine_audio_segments``.

    Args:
        ssml: SSML markup from the UI; ``None`` or blank input yields ``None``.
        batch_size: batch size for ``SynthesizeSegments``. Values that
            ``int()`` cannot convert fall back to 4, the declared default.

    Returns:
        The combined audio converted with ``audio.pydub_to_np``, or ``None``
        when there is nothing to synthesize.
    """
    try:
        batch_size = int(batch_size)
    except (TypeError, ValueError, OverflowError):
        # Use the same fallback as the declared default (and tts_generate).
        batch_size = 4

    ssml = (ssml or "").strip()
    if ssml == "":
        return None

    segments = parse_ssml(ssml)
    segments = segments_length_limit(segments, webui_config["ssml_max"])
    if not segments:
        return None

    synthesize = SynthesizeSegments(batch_size=batch_size)
    audio_segments = synthesize.synthesize_segments(segments)
    combined_audio = combine_audio_segments(audio_segments)

    return audio.pydub_to_np(combined_audio)
|
|
|
|
|
@torch.inference_mode()
@spaces.GPU
def tts_generate(
    text,
    temperature,
    top_p,
    top_k,
    spk,
    infer_seed,
    use_decoder,
    prompt1,
    prompt2,
    prefix,
    style,
    disable_normalize=False,
    batch_size=4,
):
    """Synthesize ``text`` with the settings chosen in the TTS tab.

    The text is stripped and truncated to ``webui_config["tts_max"]``
    characters. ``style == "*auto"`` means "no explicit style". Falsy UI
    values for seed, temperature, prefix and prompts (0, "", None) defer to
    whatever ``calc_spk_style`` returns for the speaker/style. Unless
    ``disable_normalize`` is set, the text goes through ``text_normalize``
    before synthesis.

    Returns:
        ``(sample_rate, audio)``, with the audio converted by
        ``audio.audio_to_int16``, or ``None`` when the text is empty.
    """
    try:
        batch_size = int(batch_size)
    except (TypeError, ValueError, OverflowError):
        # Unparseable UI value: use the declared default.
        batch_size = 4

    max_len = webui_config["tts_max"]
    text = (text or "").strip()[:max_len]
    if text == "":
        return None

    if style == "*auto":
        style = None

    # Sliders may deliver floats; top_k must be an integer.
    if isinstance(top_k, float):
        top_k = int(top_k)

    params = calc_spk_style(spk=spk, style=style)
    spk = params.get("spk", spk)

    infer_seed = infer_seed or params.get("seed", infer_seed)
    temperature = temperature or params.get("temperature", temperature)
    prefix = prefix or params.get("prefix", prefix)
    prompt1 = prompt1 or params.get("prompt1", "")
    prompt2 = prompt2 or params.get("prompt2", "")

    if infer_seed is None:
        # A cleared number input presumably arrives as None, which np.clip
        # cannot handle; use -1, the lower bound accepted just below.
        infer_seed = -1
    infer_seed = int(np.clip(infer_seed, -1, 2**32 - 1))

    if not disable_normalize:
        text = text_normalize(text)

    sample_rate, audio_data = synthesize_audio(
        text=text,
        temperature=temperature,
        top_P=top_p,
        top_K=top_k,
        spk=spk,
        infer_seed=infer_seed,
        use_decoder=use_decoder,
        prompt1=prompt1,
        prompt2=prompt2,
        prefix=prefix,
        batch_size=batch_size,
    )

    audio_data = audio.audio_to_int16(audio_data)
    return sample_rate, audio_data
|
|
|
|
|
@torch.inference_mode()
@spaces.GPU
def refine_text(text: str, prompt: str):
    """Normalize ``text`` and pass it through the refiner with ``prompt``."""
    normalized = text_normalize(text)
    refined = refiner.refine_text(normalized, prompt=prompt)
    return refined
|
|
|
|
|
def read_local_readme(path="README.md"):
    """Return the README text starting at the project heading.

    Args:
        path: README location; defaults to ``README.md`` in the current
            working directory.

    Returns:
        The file content from the first occurrence of the
        ``# 🗣️ ChatTTS-Forge`` heading onward. If the heading is not present,
        the whole file is returned instead of raising, so a customised README
        cannot break UI construction.
    """
    heading = "# 🗣️ ChatTTS-Forge"
    with open(path, "r", encoding="utf-8") as file:
        content = file.read()
    start = content.find(heading)
    return content[start:] if start != -1 else content
|
|
|
|
|
|
|
# Example sentences offered in the TTS tab's "Examples" dropdown. Many of them
# exercise text normalization (numbers, dates, times, units, phone numbers).
sample_texts = [
    {"text": sample}
    for sample in (
        "大🍌,一条大🍌,嘿,你的感觉真的很奇妙 [lbreak]",
        "天气预报显示,今天会有小雨,请大家出门时记得带伞。降温的天气也提醒我们要适时添衣保暖 [lbreak]",
        "公司的年度总结会议将在下周三举行,请各部门提前准备好相关材料,确保会议顺利进行 [lbreak]",
        "今天的午餐菜单包括烤鸡、沙拉和蔬菜汤,大家可以根据自己的口味选择适合的菜品 [lbreak]",
        "请注意,电梯将在下午两点进行例行维护,预计需要一个小时的时间,请大家在此期间使用楼梯 [lbreak]",
        "图书馆新到了一批书籍,涵盖了文学、科学和历史等多个领域,欢迎大家前来借阅 [lbreak]",
        "电影中梁朝伟扮演的陈永仁的编号27149 [lbreak]",
        "这块黄金重达324.75克 [lbreak]",
        "我们班的最高总分为583分 [lbreak]",
        "12~23 [lbreak]",
        "-1.5~2 [lbreak]",
        "她出生于86年8月18日,她弟弟出生于1995年3月1日 [lbreak]",
        "等会请在12:05请通知我 [lbreak]",
        "今天的最低气温达到-10°C [lbreak]",
        "现场有7/12的观众投出了赞成票 [lbreak]",
        "明天有62%的概率降雨 [lbreak]",
        "随便来几个价格12块5,34.5元,20.1万 [lbreak]",
        "这是固话0421-33441122 [lbreak]",
        "这是手机+86 18544139121 [lbreak]",
    )
]
|
|
|
# SSML example shown in the SSML tab's Examples: a multi-speaker,
# multi-style audiobook excerpt ("Bob" narrates, "female2"/"Alice" speak).
ssml_example1 = """

<speak version="0.1">

<voice spk="Bob" seed="42" style="narration-relaxed">

下面是一个 ChatTTS 用于合成多角色多情感的有声书示例[lbreak]

</voice>

<voice spk="Bob" seed="42" style="narration-relaxed">

黛玉冷笑道:[lbreak]

</voice>

<voice spk="female2" seed="42" style="angry">

我说呢 [uv_break] ,亏了绊住,不然,早就飞起来了[lbreak]

</voice>

<voice spk="Bob" seed="42" style="narration-relaxed">

宝玉道:[lbreak]

</voice>

<voice spk="Alice" seed="42" style="unfriendly">

“只许和你玩 [uv_break] ,替你解闷。不过偶然到他那里,就说这些闲话。”[lbreak]

</voice>

<voice spk="female2" seed="42" style="angry">

“好没意思的话![uv_break] 去不去,关我什么事儿? 又没叫你替我解闷儿 [uv_break],还许你不理我呢” [lbreak]

</voice>

<voice spk="Bob" seed="42" style="narration-relaxed">

说着,便赌气回房去了 [lbreak]

</voice>

</speak>

"""
|
# SSML example demonstrating <prosody> with its rate, pitch and volume
# attributes (and a bare <prosody> that inherits the parent voice settings).
ssml_example2 = """

<speak version="0.1">

<voice spk="Bob" seed="42" style="narration-relaxed">

使用 prosody 控制生成文本的语速语调和音量,示例如下 [lbreak]



<prosody>

无任何限制将会继承父级voice配置进行生成 [lbreak]

</prosody>

<prosody rate="1.5">

设置 rate 大于1表示加速,小于1为减速 [lbreak]

</prosody>

<prosody pitch="6">

设置 pitch 调整音调,设置为6表示提高6个半音 [lbreak]

</prosody>

<prosody volume="2">

设置 volume 调整音量,设置为2表示提高2个分贝 [lbreak]

</prosody>



在 voice 中无prosody包裹的文本即为默认生成状态下的语音 [lbreak]

</voice>

</speak>

"""
|
# SSML example demonstrating <break time="..."/> to insert silence.
ssml_example3 = """

<speak version="0.1">

<voice spk="Bob" seed="42" style="narration-relaxed">

使用 break 标签将会简单的 [lbreak]



<break time="500" />



插入一段空白到生成结果中 [lbreak]

</voice>

</speak>

"""
|
|
|
# SSML example mixing English and Chinese sentences, separated by breaks.
ssml_example4 = """

<speak version="0.1">

<voice spk="Bob" seed="42" style="excited">

temperature for sampling (may be overridden by style or speaker) [lbreak]

<break time="500" />

温度值用于采样,这个值有可能被 style 或者 speaker 覆盖 [lbreak]

<break time="500" />

temperature for sampling ,这个值有可能被 style 或者 speaker 覆盖 [lbreak]

<break time="500" />

温度值用于采样,(may be overridden by style or speaker) [lbreak]

</voice>

</speak>

"""
|
|
|
# Initial content of the SSML tab's input textbox.
default_ssml = """

<speak version="0.1">

<voice spk="Bob" seed="42" style="narration-relaxed">

这里是一个简单的 SSML 示例 [lbreak]

</voice>

</speak>

"""
|
|
|
|
|
def create_tts_interface():
    """Build the "TTS" tab inside the current ``gr.Blocks`` context.

    Left column: sampling sliders, style and speaker pickers, inference seed
    and prompt-engineering fields. Right column: text input with control-token
    buttons, the refiner, the generate button, example texts and the audio
    output. "Refine Text" runs ``refine_text`` and writes the result back to
    the text input; "Generate Audio" runs ``tts_generate``.
    """
    speakers = get_speakers()

    def get_speaker_show_name(spk):
        # Dropdown label: "<gender> : <name>", or just the name when the
        # gender is "" or "*".
        if spk.gender == "*" or spk.gender == "":
            return spk.name
        return f"{spk.gender} : {spk.name}"

    def speaker_from_choice(choice):
        # Inverse of the label above: "*"-prefixed entries ("*random") map to
        # "-1"; otherwise use the text after the last ":" (the speaker name).
        return "-1" if choice.startswith("*") else choice.split(":")[-1].strip()

    def random_speaker_seed(_current):
        # Random seed rendered as text for the speaker textbox.
        return str(torch.randint(0, 2**32 - 1, (1,)).item())

    def random_infer_seed(_current):
        return int(torch.randint(0, 2**32 - 1, (1,)).item())

    speaker_names = ["*random"] + [
        get_speaker_show_name(speaker) for speaker in speakers
    ]

    styles = ["*auto"] + [s.get("name") for s in get_styles()]

    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("🎛️Sampling")
                temperature_input = gr.Slider(
                    0.01, 2.0, value=0.3, step=0.01, label="Temperature"
                )
                top_p_input = gr.Slider(0.1, 1.0, value=0.7, step=0.1, label="Top P")
                top_k_input = gr.Slider(1, 50, value=20, step=1, label="Top K")
                batch_size_input = gr.Slider(
                    1,
                    webui_config["max_batch_size"],
                    value=4,
                    step=1,
                    label="Batch Size",
                )

            with gr.Row():
                with gr.Group():
                    gr.Markdown("🎭Style")
                    gr.Markdown("- 后缀为 `_p` 表示带prompt,效果更强但是影响质量")
                    style_input_dropdown = gr.Dropdown(
                        choices=styles,
                        interactive=True,
                        show_label=False,
                        value="*auto",
                    )
            with gr.Row():
                with gr.Group():
                    gr.Markdown("🗣️Speaker (Name or Seed)")
                    spk_input_text = gr.Textbox(
                        label="Speaker (Text or Seed)",
                        value="female2",
                        show_label=False,
                    )
                    spk_input_dropdown = gr.Dropdown(
                        choices=speaker_names,
                        interactive=True,
                        value="female : female2",
                        show_label=False,
                    )
                    spk_rand_button = gr.Button(
                        value="🎲",
                        variant="secondary",
                    )
                    spk_input_dropdown.change(
                        fn=speaker_from_choice,
                        inputs=[spk_input_dropdown],
                        outputs=[spk_input_text],
                    )
                    spk_rand_button.click(
                        random_speaker_seed,
                        inputs=[spk_input_text],
                        outputs=[spk_input_text],
                    )
            with gr.Group():
                gr.Markdown("💃Inference Seed")
                infer_seed_input = gr.Number(
                    value=42,
                    label="Inference Seed",
                    show_label=False,
                    minimum=-1,
                    maximum=2**32 - 1,
                )
                infer_seed_rand_button = gr.Button(
                    value="🎲",
                    variant="secondary",
                )
            use_decoder_input = gr.Checkbox(
                value=True, label="Use Decoder", visible=False
            )
            with gr.Group():
                gr.Markdown("🔧Prompt engineering")
                prompt1_input = gr.Textbox(label="Prompt 1")
                prompt2_input = gr.Textbox(label="Prompt 2")
                prefix_input = gr.Textbox(label="Prefix")

            infer_seed_rand_button.click(
                random_infer_seed,
                inputs=[infer_seed_input],
                outputs=[infer_seed_input],
            )
        with gr.Column(scale=3):
            with gr.Row():
                with gr.Column(scale=4):
                    with gr.Group():
                        gr.Markdown(
                            "📝Text Input",
                            elem_id="input-title",
                        )
                        gr.Markdown(
                            f"- 字数限制{webui_config['tts_max']:,}字,超过部分截断"
                        )
                        gr.Markdown("- 如果尾字吞字不读,可以试试结尾加上 `[lbreak]`")
                        gr.Markdown(
                            "- If the input text is all in English, it is recommended to check disable_normalize"
                        )
                        text_input = gr.Textbox(
                            show_label=False,
                            label="Text to Speech",
                            lines=10,
                            placeholder="输入文本或选择示例",
                            elem_id="text-input",
                        )

                        with gr.Row():
                            control_tokens = [
                                "[laugh]",
                                "[uv_break]",
                                "[v_break]",
                                "[lbreak]",
                            ]

                            for tk in control_tokens:
                                t_btn = gr.Button(tk)
                                # ``tk=tk`` binds the current token; a plain
                                # closure would see only the last one.
                                t_btn.click(
                                    lambda text, tk=tk: text + " " + tk,
                                    inputs=[text_input],
                                    outputs=[text_input],
                                )
                with gr.Column(scale=1):
                    with gr.Group():
                        gr.Markdown("🎶Refiner")
                        refine_prompt_input = gr.Textbox(
                            label="Refine Prompt",
                            value="[oral_2][laugh_0][break_6]",
                        )
                        refine_button = gr.Button("✍️Refine Text")

                    with gr.Group():
                        gr.Markdown("🔊Generate")
                        disable_normalize_input = gr.Checkbox(
                            value=False, label="Disable Normalize"
                        )
                        tts_button = gr.Button(
                            "🔊Generate Audio",
                            variant="primary",
                            elem_classes="big-button",
                        )

            with gr.Group():
                gr.Markdown("🎄Examples")
                sample_dropdown = gr.Dropdown(
                    choices=[sample["text"] for sample in sample_texts],
                    show_label=False,
                    value=None,
                    interactive=True,
                )
                sample_dropdown.change(
                    fn=lambda x: x,
                    inputs=[sample_dropdown],
                    outputs=[text_input],
                )

            with gr.Group():
                gr.Markdown("🎨Output")
                tts_output = gr.Audio(label="Generated Audio")

    refine_button.click(
        refine_text,
        inputs=[text_input, refine_prompt_input],
        outputs=[text_input],
    )

    tts_button.click(
        tts_generate,
        inputs=[
            text_input,
            temperature_input,
            top_p_input,
            top_k_input,
            spk_input_text,
            infer_seed_input,
            use_decoder_input,
            prompt1_input,
            prompt2_input,
            prefix_input,
            style_input_dropdown,
            disable_normalize_input,
            batch_size_input,
        ],
        outputs=tts_output,
    )
|
|
|
|
|
def create_ssml_interface():
    """Build the "SSML" tab and return its SSML input textbox.

    The left column holds the SSML editor and the synthesize button; the
    right column holds the batch-size slider and clickable examples. Pressing
    the button runs ``synthesize_ssml`` and plays the result below.
    The returned textbox is the target of the Long Text tab's
    "Send to SSML" button.
    """
    with gr.Row():
        with gr.Column(scale=3):
            with gr.Group():
                gr.Markdown("📝SSML Input")
                gr.Markdown(f"- 最长{webui_config['ssml_max']:,}字符,超过会被截断")
                gr.Markdown("- 尽量保证使用相同的 seed")
                gr.Markdown(
                    "- 关于SSML可以看这个 [文档](https://github.com/lenML/ChatTTS-Forge/blob/main/docs/SSML.md)"
                )
                editor = gr.Textbox(
                    label="SSML Input",
                    lines=10,
                    value=default_ssml,
                    placeholder="输入 SSML 或选择示例",
                    elem_id="ssml_input",
                    show_label=False,
                )
                synth_button = gr.Button("🔊Synthesize SSML", variant="primary")
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("🎛️Parameters")
                batch_slider = gr.Slider(
                    label="Batch Size",
                    value=4,
                    minimum=1,
                    maximum=webui_config["max_batch_size"],
                    step=1,
                )
            with gr.Group():
                gr.Markdown("🎄Examples")
                gr.Examples(
                    examples=[
                        ssml_example1,
                        ssml_example2,
                        ssml_example3,
                        ssml_example4,
                    ],
                    inputs=[editor],
                )

    audio_out = gr.Audio(label="Generated Audio")

    synth_button.click(
        synthesize_ssml,
        inputs=[editor, batch_slider],
        outputs=audio_out,
    )

    return editor
|
|
|
|
|
def split_long_text(long_text_input):
    """Split long text into normalized sentences for the Long Text table.

    Uses ``SentenceSplitter`` with ``webui_config["spliter_threshold"]`` and
    runs each sentence through ``text_normalize``.

    Returns:
        Rows of ``[index, sentence, len(sentence)]``, one per sentence.
    """
    splitter = SentenceSplitter(webui_config["spliter_threshold"])
    normalized = [text_normalize(piece) for piece in splitter.parse(long_text_input)]
    return [[idx, sentence, len(sentence)] for idx, sentence in enumerate(normalized)]
|
|
|
|
|
def merge_dataframe_to_ssml(dataframe, spk, style, seed):
    """Merge the split sentences of the Long Text tab into one SSML document.

    Every row becomes its own ``<voice>`` element with the same speaker,
    style and seed attributes. Each row's text goes through
    ``text_normalize`` and is then XML-escaped, so characters such as ``&``
    or ``<`` in the source text cannot break the markup.

    Args:
        dataframe: table whose second column (position 1) holds the sentence
            text, e.g. the rows produced by ``split_long_text``.
        spk: speaker name or seed; ``"-1"``/``-1`` omits the ``spk`` attribute.
        style: style name; ``"*auto"`` omits the ``style`` attribute.
        seed: inference seed; ``-1``/``"-1"``/``None`` omits the ``seed``
            attribute. ``0`` is kept, since the seed input accepts it.

    Returns:
        The SSML document as a string.
    """
    from xml.sax.saxutils import escape

    if style == "*auto":
        style = None
    if spk == "-1" or spk == -1:
        spk = None
    if seed == -1 or seed == "-1":
        seed = None

    def quote_attr(value):
        # Double-quoted XML attribute value with &, <, > and " escaped.
        return '"' + escape(str(value), {'"': "&quot;"}) + '"'

    # The attributes are identical for every row, so build them once.
    attrs = ""
    if spk:
        attrs += f" spk={quote_attr(spk)}"
    if style:
        attrs += f" style={quote_attr(style)}"
    if seed not in (None, ""):
        attrs += f" seed={quote_attr(seed)}"

    indent = " " * 2
    parts = []
    for _, row in dataframe.iterrows():
        # ``iloc`` makes the positional access explicit; plain ``row[1]`` on a
        # label-indexed Series relies on a deprecated pandas fallback.
        text = escape(text_normalize(row.iloc[1]))
        parts.append(f"{indent}<voice{attrs}>\n")
        parts.append(f"{indent}{indent}{text}\n")
        parts.append(f"{indent}</voice>\n")

    body = "".join(parts)
    return f"<speak version='0.1'>\n{body}</speak>"
|
|
|
|
|
|
|
|
|
|
|
def create_long_content_tab(ssml_input, tabs):
    """Build the "Long Text" tab.

    Long input is split into normalized sentences (``split_long_text``) and
    shown in a table. "Send to SSML" merges the table with the chosen
    speaker/style/seed (``merge_dataframe_to_ssml``), writes the result into
    ``ssml_input`` and switches ``tabs`` to the tab with id ``"ssml"``.

    Args:
        ssml_input: the SSML tab's input textbox.
        tabs: the enclosing ``gr.Tabs`` component.
    """
    speakers = get_speakers()

    def get_speaker_show_name(spk):
        # Dropdown label: "<gender> : <name>", or just the name when the
        # gender is "" or "*".
        if spk.gender == "*" or spk.gender == "":
            return spk.name
        return f"{spk.gender} : {spk.name}"

    def speaker_from_choice(choice):
        # "*"-prefixed entries ("*random") map to "-1"; otherwise use the
        # text after the last ":" (the speaker name).
        return "-1" if choice.startswith("*") else choice.split(":")[-1].strip()

    def random_seed(_current):
        return int(torch.randint(0, 2**32 - 1, (1,)).item())

    def random_speaker_seed(current):
        # The speaker input is a textbox, so return text (as the TTS tab does).
        return str(random_seed(current))

    speaker_names = ["*random"] + [
        get_speaker_show_name(speaker) for speaker in speakers
    ]

    styles = ["*auto"] + [s.get("name") for s in get_styles()]

    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("🗣️Speaker")
                spk_input_text = gr.Textbox(
                    label="Speaker (Text or Seed)",
                    value="female2",
                    show_label=False,
                )
                spk_input_dropdown = gr.Dropdown(
                    choices=speaker_names,
                    interactive=True,
                    value="female : female2",
                    show_label=False,
                )
                spk_rand_button = gr.Button(
                    value="🎲",
                    variant="secondary",
                )
            with gr.Group():
                gr.Markdown("🎭Style")
                style_input_dropdown = gr.Dropdown(
                    choices=styles,
                    interactive=True,
                    show_label=False,
                    value="*auto",
                )
            with gr.Group():
                gr.Markdown("🗣️Seed")
                infer_seed_input = gr.Number(
                    value=42,
                    label="Inference Seed",
                    show_label=False,
                    minimum=-1,
                    maximum=2**32 - 1,
                )
                infer_seed_rand_button = gr.Button(
                    value="🎲",
                    variant="secondary",
                )

            send_btn = gr.Button("📩Send to SSML", variant="primary")

        with gr.Column(scale=3):
            with gr.Group():
                gr.Markdown("📝Long Text Input")
                gr.Markdown("- 此页面用于处理超长文本")
                gr.Markdown("- 切割后,可以选择说话人、风格、seed,然后发送到SSML")
                long_text_input = gr.Textbox(
                    label="Long Text Input",
                    lines=10,
                    placeholder="输入长文本",
                    elem_id="long-text-input",
                    show_label=False,
                )
                long_text_split_button = gr.Button("🔪Split Text")

    with gr.Row():
        with gr.Column(scale=3):
            with gr.Group():
                gr.Markdown("🎨Output")
                long_text_output = gr.DataFrame(
                    headers=["index", "text", "length"],
                    datatype=["number", "str", "number"],
                    elem_id="long-text-output",
                    interactive=False,
                    wrap=True,
                    value=[],
                )

    spk_input_dropdown.change(
        fn=speaker_from_choice,
        inputs=[spk_input_dropdown],
        outputs=[spk_input_text],
    )
    spk_rand_button.click(
        random_speaker_seed,
        inputs=[spk_input_text],
        outputs=[spk_input_text],
    )
    # Registered exactly once: a second listener would fire two handlers per
    # click, racing to write the same output.
    infer_seed_rand_button.click(
        random_seed,
        inputs=[infer_seed_input],
        outputs=[infer_seed_input],
    )
    long_text_split_button.click(
        split_long_text,
        inputs=[long_text_input],
        outputs=[long_text_output],
    )

    send_btn.click(
        merge_dataframe_to_ssml,
        inputs=[
            long_text_output,
            spk_input_text,
            style_input_dropdown,
            infer_seed_input,
        ],
        outputs=[ssml_input],
    )

    def change_tab():
        return gr.Tabs(selected="ssml")

    send_btn.click(change_tab, inputs=[], outputs=[tabs])
|
|
|
|
|
def create_readme_tab():
    """Render the local README (from its project heading) as Markdown."""
    gr.Markdown(read_local_readme())
|
|
|
|
|
def create_interface():
    """Assemble the full WebUI and return the (not yet launched) Blocks app.

    Tabs: "TTS", "SSML" (id ``"ssml"``, targeted by the Long Text tab),
    "Long Text" and "README". ``js_func`` reloads the page with
    ``__theme=dark`` when that query parameter is not already set.
    """
    js_func = """

function refresh() {

const url = new URL(window.location);



if (url.searchParams.get('__theme') !== 'dark') {

url.searchParams.set('__theme', 'dark');

window.location.href = url.href;

}

}

"""

    head_js = """

<script>

</script>

"""

    with gr.Blocks(js=js_func, head=head_js, title="ChatTTS Forge WebUI") as demo:
        # The eta-bar rule targets the TTS tab's title Markdown, whose
        # elem_id is "input-title" (hyphen), so the selector must match it.
        css = """

<style>

.big-button {

height: 80px;

}

#input-title div.eta-bar {

display: none !important; transform: none !important;

}

</style>

"""

        gr.HTML(css)
        with gr.Tabs() as tabs:
            with gr.TabItem("TTS"):
                create_tts_interface()

            with gr.TabItem("SSML", id="ssml"):
                ssml_input = create_ssml_interface()

            with gr.TabItem("Long Text"):
                create_long_content_tab(ssml_input, tabs=tabs)

            with gr.TabItem("README"):
                create_readme_tab()

        gr.Markdown(
            "此项目基于 [ChatTTS-Forge](https://github.com/lenML/ChatTTS-Forge) "
        )
    return demo
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    import dotenv

    dotenv.load_dotenv(
        dotenv_path=os.getenv("ENV_FILE", ".env.webui"),
    )

    parser = argparse.ArgumentParser(description="Gradio App")
    parser.add_argument("--server_name", type=str, help="server name")
    parser.add_argument("--server_port", type=int, help="server port")
    parser.add_argument(
        "--share", action="store_true", help="share the gradio interface"
    )
    parser.add_argument("--debug", action="store_true", help="enable debug mode")
    parser.add_argument("--auth", type=str, help="username:password for authentication")
    parser.add_argument(
        "--half",
        action="store_true",
        help="Enable half precision for model inference",
    )
    parser.add_argument(
        "--off_tqdm",
        action="store_true",
        help="Disable tqdm progress bar",
    )
    parser.add_argument(
        "--tts_max_len",
        type=int,
        help="Max length of text for TTS",
    )
    parser.add_argument(
        "--ssml_max_len",
        type=int,
        help="Max length of text for SSML",
    )
    parser.add_argument(
        "--max_batch_size",
        type=int,
        help="Max batch size for TTS",
    )
    parser.add_argument(
        "--lru_size",
        type=int,
        default=64,
        help="Set the size of the request cache pool, set it to 0 will disable lru_cache",
    )
    parser.add_argument(
        "--device_id",
        type=str,
        help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)",
        default=None,
    )
    parser.add_argument(
        "--use_cpu",
        nargs="+",
        help="use CPU as torch device for specified modules",
        default=[],
        type=str.lower,
    )
    parser.add_argument("--compile", action="store_true", help="Enable model compile")

    args = parser.parse_args()

    def get_and_update_env(*env_args):
        """Resolve a setting via ``env.get_env_or_arg`` and record it.

        ``env_args`` is forwarded unchanged; its second element is the
        setting name, used as the key in ``config.runtime_env_vars``.
        """
        val = env.get_env_or_arg(*env_args)
        key = env_args[1]
        config.runtime_env_vars[key] = val
        return val

    server_name = get_and_update_env(args, "server_name", "0.0.0.0", str)
    server_port = get_and_update_env(args, "server_port", 7860, int)
    share = get_and_update_env(args, "share", False, bool)
    debug = get_and_update_env(args, "debug", False, bool)
    auth = get_and_update_env(args, "auth", None, str)
    half = get_and_update_env(args, "half", False, bool)
    off_tqdm = get_and_update_env(args, "off_tqdm", False, bool)
    lru_size = get_and_update_env(args, "lru_size", 64, int)
    device_id = get_and_update_env(args, "device_id", None, str)
    use_cpu = get_and_update_env(args, "use_cpu", [], list)
    # Named so it does not shadow the ``compile`` builtin.
    compile_model = get_and_update_env(args, "compile", False, bool)

    webui_config["tts_max"] = get_and_update_env(args, "tts_max_len", 1000, int)
    webui_config["ssml_max"] = get_and_update_env(args, "ssml_max_len", 5000, int)
    webui_config["max_batch_size"] = get_and_update_env(args, "max_batch_size", 8, int)

    if auth:
        # Split on the first ":" only, so passwords may themselves contain
        # ":"; reject values without a separator instead of passing a
        # malformed 1-tuple to Gradio.
        username, sep, password = auth.partition(":")
        if not sep:
            parser.error("--auth must be in the form username:password")
        auth = (username, password)

    demo = create_interface()

    generate.setup_lru_cache()
    devices.reset_device()
    devices.first_time_calculation()

    demo.queue().launch(
        server_name=server_name,
        server_port=server_port,
        share=share,
        debug=debug,
        auth=auth,
        show_api=False,
    )
|
|