# chat-tts / web / app.py
# Author: chenjgtea — commit 214ea91 ("提交代码")
# (Header reconstructed from the HuggingFace file-viewer page this file
# was extracted from; "raw / history / blame / 8.53 kB" were page chrome.)
import os, sys

# On macOS, let torch fall back to CPU for ops the MPS backend lacks.
# (Restored indentation: the guard's body belongs inside the `if`.)
if sys.platform == "darwin":
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# Make project-local packages (tool/, ChatTTS/) importable when the app
# is launched from the repository root.
now_dir = os.getcwd()
sys.path.append(now_dir)

from tool.logger import get_logger

import ChatTTS
import argparse
import gradio as gr

# Star imports are expected to provide helpers used below:
# presumably voices / on_voice_change / generate_seed / seed_min /
# seed_max / float_to_int16 — TODO confirm against tool.func and tool.np.
from tool.func import *
from tool.ctx import TorchSeedContext
from tool.np import *

logger = get_logger("app")

# Initialize the (not yet loaded) model; init_chat() loads the weights.
chat = ChatTTS.Chat()
def init_chat(args):
    """Load weights into the global ChatTTS model.

    The MODEL environment variable selects the weight source: "HF" loads
    from the huggingface hub, anything else loads from a local path.

    Args:
        args: parsed CLI namespace; ``custom_path`` and ``coef`` are used
            when loading from a local path.

    Exits the process with status 1 if loading fails.
    """
    global chat
    # Deployment mode: "HF" means we run on huggingface and use hub weights.
    MODEL = os.getenv('MODEL')
    logger.info("loading ChatTTS model..., start MODEL:" + str(MODEL))
    source = "huggingface" if MODEL == "HF" else "custom"
    # Fix: honor the --custom_path/--coef CLI options instead of a
    # hard-coded Windows path, and report through the module logger
    # instead of bare print().
    if chat.load(source=source, custom_path=args.custom_path, coef=args.coef):
        logger.info("Models loaded successfully.")
    else:
        logger.error("Models load failed.")
        sys.exit(1)
def main(args):
    """Build the Gradio UI for the ChatTTS demo and start the web server.

    Wires the refine-text step (get_chat_infer_text) and the synthesis
    step (get_chat_infer_audio) to the generate button, then serves on
    args.server_name:args.server_port.

    NOTE(review): indentation was reconstructed from the flattened
    source; the widget nesting below follows the with-statement order.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# ChatTTS demo")
        with gr.Row():
            with gr.Column(scale=1):
                # Source text to convert to speech.
                text_input = gr.Textbox(
                    label="转换内容",
                    lines=4,
                    max_lines=4,
                    placeholder="Please Input Text...",
                    value="柔柔的,浓浓的,痴痴的风,牵引起心底灵动的思潮;情愫悠悠,思情绵绵,风里默坐,红尘中的浅醉,诗词中的优柔,任那自在飞花轻似梦的情怀,裁一束霓衣,织就清浅淡薄的安寂。",
                    interactive=True,
                )
                with gr.Row():
                    # Whether to run the refine-text pass before synthesis.
                    refine_text_checkBox = gr.Checkbox(
                        label="是否优化文本,如是则先对文本内容做优化分词",
                        interactive=True,
                        value=True
                    )
                    temperature_slider = gr.Slider(
                        minimum=0.00001,
                        maximum=1.0,
                        step=0.00001,
                        value=0.3,
                        interactive=True,
                        label="模型 Temperature 参数设置"
                    )
                    top_p_slider = gr.Slider(
                        minimum=0.1,
                        maximum=0.9,
                        step=0.05,
                        value=0.7,
                        label="模型 top_P 参数设置",
                        interactive=True,
                    )
                    top_k_slider = gr.Slider(
                        minimum=1,
                        maximum=20,
                        step=1,
                        value=20,
                        label="模型 top_K 参数设置",
                        interactive=True,
                    )
                with gr.Row():
                    # Preset timbre picker; `voices` comes from a star
                    # import — presumably a name -> seed mapping (TODO
                    # confirm against tool.func).
                    voice_selection = gr.Dropdown(
                        label="Timbre",
                        choices=voices.keys(),
                        value="旁白",
                        interactive=True,
                        show_label=True
                    )
                    # Seed controlling the sampled speaker embedding.
                    audio_seed_input = gr.Number(
                        value=2,
                        label="音色种子",
                        interactive=True,
                        minimum=seed_min,
                        maximum=seed_max,
                    )
                    generate_audio_seed = gr.Button("随机生成音色种子", interactive=True)
                    # Seed controlling text-refinement sampling.
                    text_seed_input = gr.Number(
                        value=42,
                        label="文本种子",
                        interactive=True,
                        minimum=seed_min,
                        maximum=seed_max,
                    )
                    generate_text_seed = gr.Button("随机生成文本种子", interactive=True)
                with gr.Row():
                    # Read-only display of the speaker embedding derived
                    # from the audio seed (kept in sync by the .change
                    # listener below).
                    spk_emb_text = gr.Textbox(
                        label="Speaker Embedding",
                        max_lines=3,
                        show_copy_button=True,
                        interactive=False,
                        scale=2,
                    )
                    # NOTE(review): no click handler is wired for this
                    # button — see the commented-out line below.
                    reload_chat_button = gr.Button("Reload", scale=1, interactive=True)
                with gr.Row():
                    generate_button = gr.Button("生成音频文件", scale=1, interactive=True)
                with gr.Row():
                    text_output = gr.Textbox(
                        label="输出文本",
                        interactive=False,
                        show_copy_button=True,
                    )
                    audio_output = gr.Audio(
                        label="输出音频",
                        value=None,
                        format="wav",
                        autoplay=False,
                        streaming=False,
                        interactive=False,
                        show_label=True,
                        waveform_options=gr.WaveformOptions(
                            sample_rate=24000,
                        ),
                    )
        # Register event listeners for the page elements.
        voice_selection.change(fn=on_voice_change, inputs=voice_selection, outputs=audio_seed_input)
        audio_seed_input.change(fn=on_audio_seed_change, inputs=audio_seed_input, outputs=spk_emb_text)
        generate_audio_seed.click(fn=generate_seed, outputs=audio_seed_input)
        generate_text_seed.click(fn=generate_seed,outputs=text_seed_input)
        # reload_chat_button.click()
        # Generate: first refine the text, then synthesize audio from the
        # refined text via the chained .then().
        generate_button.click(fn=get_chat_infer_text,
                              inputs=[text_input,
                                      text_seed_input,
                                      refine_text_checkBox
                                      ],
                              outputs=[text_output]
                              ).then(fn=get_chat_infer_audio,
                                     inputs=[text_output,
                                             temperature_slider,
                                             top_p_slider,
                                             top_k_slider,
                                             audio_seed_input,
                                             spk_emb_text
                                             ],
                                     outputs=[audio_output])
        # Initialize the spk_emb_text value from the default audio seed.
        spk_emb_text.value = on_audio_seed_change(audio_seed_input.value)
    logger.info("元素初始化完成,启动gradio服务=======")
    # Run the gradio server (blocks until shutdown).
    demo.launch(
        server_name=args.server_name,
        server_port=args.server_port,
        inbrowser=True,
        show_api=False)
def get_chat_infer_audio(chat_txt,
                         temperature_slider,
                         top_p_slider,
                         top_k_slider,
                         audio_seed_input,
                         spk_emb_text):
    """Synthesize audio for already-refined text.

    Generator yielding a single (sample_rate, int16 array) tuple in the
    shape gr.Audio expects; refinement is skipped because the text has
    already been processed by get_chat_infer_text.
    """
    logger.info("========开始生成音频文件=====")
    # Sampling configuration for the audio decode pass.
    infer_params = ChatTTS.Chat.InferCodeParams(
        spk_emb=spk_emb_text,
        temperature=temperature_slider,
        top_P=top_p_slider,
        top_K=top_k_slider,
    )
    # Pin the torch RNG so the same seed reproduces the same output.
    with TorchSeedContext(audio_seed_input):
        wavs = chat.infer(
            text=chat_txt,
            skip_refine_text=True,  # text was refined in the previous step
            params_infer_code=infer_params,
        )
    yield 24000, float_to_int16(wavs[0]).T
def get_chat_infer_text(text,seed,refine_text_checkBox):
    """Optionally run the ChatTTS refine-text pass over *text*.

    Returns the input unchanged when the checkbox is off; otherwise
    returns the refined text string, seeded for reproducibility.
    """
    logger.info("========开始优化文本内容=====")
    global chat
    # Guard clause: refinement disabled, pass the text through untouched.
    if not refine_text_checkBox:
        logger.info("========文本内容无需优化=====")
        return text
    refine_params = ChatTTS.Chat.RefineTextParams(
        prompt='[oral_2][laugh_0][break_6]',
    )
    # Fix the RNG seed so the refined wording is reproducible.
    with TorchSeedContext(seed):
        refined = chat.infer(
            text=text,
            skip_refine_text=False,
            refine_text_only=True,  # return refined text only, no audio
            params_refine_text=refine_params,
        )
    # chat.infer may hand back a list of strings; unwrap the first one.
    if isinstance(refined, list):
        return refined[0]
    return refined
def on_audio_seed_change(audio_seed_input):
    """Derive a speaker embedding deterministically from the audio seed."""
    global chat
    # Same seed -> same sampled speaker, since the RNG state is pinned.
    with TorchSeedContext(audio_seed_input):
        speaker = chat.sample_random_speaker()
    return speaker
if __name__ == "__main__":
    # CLI entry point: parse options, load the model, then launch the UI.
    parser = argparse.ArgumentParser(description="ChatTTS demo Launch")
    parser.add_argument("--server_name", type=str, default="0.0.0.0", help="server name")
    parser.add_argument("--server_port", type=int, default=8080, help="server port")
    parser.add_argument("--custom_path", type=str, default="D:\\chenjgspace\\ai-model\\chattts", help="custom model path")
    parser.add_argument("--coef", type=str, default=None, help="custom dvae coefficient")
    args = parser.parse_args()

    init_chat(args)
    main(args)