import argparse
import os
import sys

# Let PyTorch fall back to the CPU for operators the Apple MPS backend does not
# implement. Set before ChatTTS (and therefore torch) is imported below.
if sys.platform == "darwin":
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# Make the project-local packages (``tool``, ``ChatTTS``) importable when the
# script is launched from the repository root.
now_dir = os.getcwd()
sys.path.append(now_dir)

# These imports must come after the sys.path / environment setup above.
from tool.logger import get_logger  # noqa: E402
import ChatTTS  # noqa: E402
import gradio as gr  # noqa: E402

# The star imports supply names used below that are not defined in this file
# (voices, seed_min, seed_max, on_voice_change, generate_seed, float_to_int16);
# which module provides which is not visible here.
from tool.func import *  # noqa: E402,F401,F403
from tool.ctx import TorchSeedContext  # noqa: E402
from tool.np import *  # noqa: E402,F401,F403

logger = get_logger("app")

# Sample rate (Hz) used both by the waveform display and the returned clip.
SAMPLE_RATE = 24000
# Seeds shown in the UI on start-up.
DEFAULT_AUDIO_SEED = 2
DEFAULT_TEXT_SEED = 42
# Default local model directory (used unless the MODEL env var is "HF").
DEFAULT_MODEL_PATH = "D:\\chenjgspace\\ai-model\\chattts"

# Single shared ChatTTS instance; its weights are loaded by init_chat().
chat = ChatTTS.Chat()


def init_chat(args):
    """Load the ChatTTS model into the module-level ``chat`` instance.

    The source is chosen by the ``MODEL`` environment variable: ``"HF"``
    loads from Hugging Face; any other value (or unset) loads the local
    model directory given by ``args.custom_path``.

    Args:
        args: Parsed CLI namespace. ``custom_path`` and ``coef`` are
            forwarded to ``chat.load`` (falling back to the historical
            defaults if the namespace lacks them).

    Exits the process with status 1 if loading fails.
    """
    model_mode = os.getenv("MODEL")
    logger.info("loading ChatTTS model..., start MODEL:%s", model_mode)
    # In a Hugging Face deployment, use the model data hosted on HF directly.
    source = "huggingface" if model_mode == "HF" else "custom"
    custom_path = getattr(args, "custom_path", DEFAULT_MODEL_PATH)
    coef = getattr(args, "coef", None)
    if chat.load(source=source, custom_path=custom_path, coef=coef):
        logger.info("Models loaded successfully.")
    else:
        logger.error("Models load failed.")
        sys.exit(1)


def main(args):
    """Build the ChatTTS Gradio UI, wire its events and start the server.

    Args:
        args: Parsed CLI namespace; ``server_name`` and ``server_port`` are
            passed to ``demo.launch``.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# ChatTTS demo")
        with gr.Row():
            with gr.Column(scale=1):
                text_input = gr.Textbox(
                    label="转换内容",
                    lines=4,
                    max_lines=4,
                    placeholder="Please Input Text...",
                    value="柔柔的,浓浓的,痴痴的风,牵引起心底灵动的思潮;情愫悠悠,思情绵绵,风里默坐,红尘中的浅醉,诗词中的优柔,任那自在飞花轻似梦的情怀,裁一束霓衣,织就清浅淡薄的安寂。",
                    interactive=True,
                )

        with gr.Row():
            refine_text_checkbox = gr.Checkbox(
                label="是否优化文本,如是则先对文本内容做优化分词",
                interactive=True,
                value=True,
            )
            temperature_slider = gr.Slider(
                minimum=0.00001,
                maximum=1.0,
                step=0.00001,
                value=0.3,
                interactive=True,
                label="模型 Temperature 参数设置",
            )
            top_p_slider = gr.Slider(
                minimum=0.1,
                maximum=0.9,
                step=0.05,
                value=0.7,
                label="模型 top_P 参数设置",
                interactive=True,
            )
            top_k_slider = gr.Slider(
                minimum=1,
                maximum=20,
                step=1,
                value=20,
                label="模型 top_K 参数设置",
                interactive=True,
            )

        with gr.Row():
            voice_selection = gr.Dropdown(
                label="Timbre",
                choices=list(voices.keys()),
                value="旁白",
                interactive=True,
                show_label=True,
            )
            # precision=0 makes the component return ints, which is what the
            # seed context expects.
            audio_seed_input = gr.Number(
                value=DEFAULT_AUDIO_SEED,
                label="音色种子",
                interactive=True,
                minimum=seed_min,
                maximum=seed_max,
                precision=0,
            )
            generate_audio_seed = gr.Button("随机生成音色种子", interactive=True)
            text_seed_input = gr.Number(
                value=DEFAULT_TEXT_SEED,
                label="文本种子",
                interactive=True,
                minimum=seed_min,
                maximum=seed_max,
                precision=0,
            )
            generate_text_seed = gr.Button("随机生成文本种子", interactive=True)

        with gr.Row():
            # Initial value passed through the constructor (rather than set
            # afterwards) so the component's normal value handling applies.
            spk_emb_text = gr.Textbox(
                label="Speaker Embedding",
                max_lines=3,
                show_copy_button=True,
                interactive=False,
                scale=2,
                value=on_audio_seed_change(DEFAULT_AUDIO_SEED),
            )
            # TODO: reload_chat_button is not wired to any handler yet.
            reload_chat_button = gr.Button("Reload", scale=1, interactive=True)

        with gr.Row():
            generate_button = gr.Button("生成音频文件", scale=1, interactive=True)

        with gr.Row():
            text_output = gr.Textbox(
                label="输出文本",
                interactive=False,
                show_copy_button=True,
            )
            audio_output = gr.Audio(
                label="输出音频",
                value=None,
                format="wav",
                autoplay=False,
                streaming=False,
                interactive=False,
                show_label=True,
                waveform_options=gr.WaveformOptions(
                    sample_rate=SAMPLE_RATE,
                ),
            )

        # Event listeners for the page elements.
        voice_selection.change(
            fn=on_voice_change, inputs=voice_selection, outputs=audio_seed_input
        )
        audio_seed_input.change(
            fn=on_audio_seed_change, inputs=audio_seed_input, outputs=spk_emb_text
        )
        generate_audio_seed.click(fn=generate_seed, outputs=audio_seed_input)
        generate_text_seed.click(fn=generate_seed, outputs=text_seed_input)

        # Refine text first, then synthesize audio from the refined text.
        # .success() (not .then()) so a failed refine step does not trigger
        # synthesis of whatever stale text is in text_output.
        generate_button.click(
            fn=get_chat_infer_text,
            inputs=[text_input, text_seed_input, refine_text_checkbox],
            outputs=[text_output],
        ).success(
            fn=get_chat_infer_audio,
            inputs=[
                text_output,
                temperature_slider,
                top_p_slider,
                top_k_slider,
                audio_seed_input,
                spk_emb_text,
            ],
            outputs=[audio_output],
        )

        logger.info("元素初始化完成,启动gradio服务=======")

    # Start the Gradio server.
    demo.launch(
        server_name=args.server_name,
        server_port=args.server_port,
        inbrowser=True,
        show_api=False,
    )


def get_chat_infer_audio(chat_txt, temperature_slider, top_p_slider, top_k_slider, audio_seed_input, spk_emb_text):
    """Synthesize speech for ``chat_txt`` with the shared ChatTTS model.

    Text refinement is skipped here because the UI refines text in a
    preceding step.

    Args:
        chat_txt: Text to synthesize.
        temperature_slider: Sampling temperature for ``InferCodeParams``.
        top_p_slider: Top-P value for ``InferCodeParams``.
        top_k_slider: Top-K value for ``InferCodeParams`` (cast to int).
        audio_seed_input: Seed applied to torch's RNG during inference.
        spk_emb_text: Speaker embedding, as produced by
            ``chat.sample_random_speaker``.

    Yields:
        A single ``(SAMPLE_RATE, samples)`` tuple for ``gr.Audio``, where the
        samples are converted by ``float_to_int16``.
    """
    logger.info("========开始生成音频文件=====")
    # Audio (code-generation) parameters.
    params_infer_code = ChatTTS.Chat.InferCodeParams(
        spk_emb=spk_emb_text,  # sampled speaker
        temperature=temperature_slider,
        top_P=top_p_slider,
        # Slider values may arrive as floats; top-K is a token count.
        top_K=int(top_k_slider),
    )
    with TorchSeedContext(audio_seed_input):
        wav = chat.infer(
            text=chat_txt,
            skip_refine_text=True,  # text was already refined upstream
            params_infer_code=params_infer_code,
        )
    yield SAMPLE_RATE, float_to_int16(wav[0]).T


def get_chat_infer_text(text, seed, refine_text_checkBox):
    """Optionally refine ``text`` with ChatTTS before synthesis.

    Args:
        text: Raw input text.
        seed: Seed applied to torch's RNG while refining.
        refine_text_checkBox: When falsy, ``text`` is returned unchanged.

    Returns:
        The refined text (first element if ``chat.infer`` returns a list),
        or the original ``text`` when refinement is disabled.
    """
    logger.info("========开始优化文本内容=====")
    if not refine_text_checkBox:
        logger.info("========文本内容无需优化=====")
        return text

    params_refine_text = ChatTTS.Chat.RefineTextParams(
        prompt='[oral_2][laugh_0][break_6]',
    )
    with TorchSeedContext(seed):
        chat_text = chat.infer(
            text=text,
            skip_refine_text=False,
            refine_text_only=True,  # return only the refined text
            params_refine_text=params_refine_text,
        )
    return chat_text[0] if isinstance(chat_text, list) else chat_text


def on_audio_seed_change(audio_seed_input):
    """Return a speaker embedding sampled deterministically from a seed.

    Args:
        audio_seed_input: Seed applied to torch's RNG while sampling.

    Returns:
        Whatever ``chat.sample_random_speaker()`` returns; it is shown in the
        "Speaker Embedding" textbox.
    """
    with TorchSeedContext(audio_seed_input):
        rand_spk = chat.sample_random_speaker()
    return rand_spk


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="ChatTTS demo Launch")
    parser.add_argument(
        "--server_name", type=str, default="0.0.0.0", help="server name"
    )
    parser.add_argument("--server_port", type=int, default=8080, help="server port")
    parser.add_argument(
        "--custom_path", type=str, default=DEFAULT_MODEL_PATH, help="custom model path"
    )
    parser.add_argument(
        "--coef", type=str, default=None, help="custom dvae coefficient"
    )
    args = parser.parse_args()
    init_chat(args)
    main(args)