File size: 29,804 Bytes
84cfd61 01e655b 02e90e4 01e655b 02e90e4 d5b3cd8 02e90e4 d5b3cd8 01e655b 02e90e4 01e655b 84cfd61 01e655b 84cfd61 01e655b 84cfd61 01e655b d6fe286 01e655b 84cfd61 01e655b 84cfd61 01e655b d5b3cd8 01e655b 02e90e4 01e655b d5b3cd8 01e655b 84cfd61 01e655b 84cfd61 01e655b 84cfd61 01e655b 02e90e4 01e655b 02e90e4 01e655b 02e90e4 01e655b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 |
try:
    import spaces
except Exception:
    # Best effort: the optional `spaces` package is only used here for its
    # `GPU` decorator. If it is missing or fails to import, install a stub
    # whose decorator is a no-op. `Exception` rather than a bare `except:`
    # so KeyboardInterrupt / SystemExit still propagate.
    class NoneSpaces:
        """No-op stand-in for the ``spaces`` module."""

        def __init__(self):
            pass

        def GPU(self, fn=None, **kwargs):
            """Return *fn* unchanged.

            Works both as ``@spaces.GPU`` and as ``@spaces.GPU(...)``; in the
            called form any keyword arguments are ignored and an identity
            decorator is returned.
            """
            if fn is None:
                return lambda f: f
            return fn

    spaces = NoneSpaces()
import os
import logging
import numpy as np
from modules.devices import devices
from modules.synthesize_audio import synthesize_audio
from modules.utils.cache import conditional_cache
# Configure root logging at import time. The level is read from the
# LOG_LEVEL environment variable and defaults to INFO.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=os.getenv("LOG_LEVEL", "INFO"),
)
import gradio as gr
import torch
from modules.ssml import parse_ssml
from modules.SynthesizeSegments import SynthesizeSegments, combine_audio_segments
from modules.speaker import speaker_mgr
from modules.data import styles_mgr
from modules.api.utils import calc_spk_style
import modules.generate_audio as generate
from modules.normalization import text_normalize
from modules import refiner, config
from modules.utils import env, audio
from modules.SentenceSplitter import SentenceSplitter
torch._dynamo.config.cache_size_limit = 64
torch._dynamo.config.suppress_errors = True
torch.set_float32_matmul_precision("high")
webui_config = {
"tts_max": 1000,
"ssml_max": 5000,
"spliter_threshold": 100,
"max_batch_size": 8,
}
def get_speakers():
    """Return every speaker registered with the global speaker manager."""
    registered = speaker_mgr.list_speakers()
    return registered
def get_styles():
    """Return all style presets known to the global style manager."""
    style_items = styles_mgr.list_items()
    return style_items
def segments_length_limit(segments, total_max: int):
    """Keep the longest leading run of *segments* whose text fits the budget.

    Segments are consumed in order. As soon as the running sum of
    ``len(segment["text"])`` exceeds ``total_max``, that segment and every
    later one are dropped.

    Args:
        segments: iterable of mappings, each with a ``"text"`` entry.
        total_max: maximum combined text length to keep.

    Returns:
        A new list holding the kept segments.
    """
    kept = []
    remaining = total_max
    for segment in segments:
        remaining -= len(segment["text"])
        if remaining < 0:
            break
        kept.append(segment)
    return kept
@torch.inference_mode()
@spaces.GPU
def synthesize_ssml(ssml: str, batch_size=4):
    """Synthesize an SSML document into a single audio clip.

    The document is parsed into segments, and the segment list is cut to
    ``webui_config["ssml_max"]`` total characters via
    ``segments_length_limit``. The remaining segments are synthesized in
    batches and concatenated.

    Args:
        ssml: SSML source text. ``None`` and whitespace-only input yield ``None``.
        batch_size: segments per synthesis batch. Values that cannot be
            converted with ``int()`` fall back to 4 (the same default as the
            signature and as ``tts_generate``).

    Returns:
        ``None`` when there is nothing to synthesize. Otherwise the result of
        ``audio.pydub_to_np`` on the combined audio (presumably the
        ``(sample_rate, samples)`` pair gr.Audio expects; confirm in
        modules.utils.audio).
    """
    try:
        batch_size = int(batch_size)
    except (TypeError, ValueError, OverflowError):
        # Previously fell back to 8, disagreeing with the default of 4.
        batch_size = 4

    ssml = (ssml or "").strip()
    if not ssml:
        return None

    segments = parse_ssml(ssml)
    segments = segments_length_limit(segments, webui_config["ssml_max"])
    if not segments:
        return None

    synthesize = SynthesizeSegments(batch_size=batch_size)
    audio_segments = synthesize.synthesize_segments(segments)
    combined_audio = combine_audio_segments(audio_segments)
    return audio.pydub_to_np(combined_audio)
@torch.inference_mode()
@spaces.GPU
def tts_generate(
    text,
    temperature,
    top_p,
    top_k,
    spk,
    infer_seed,
    use_decoder,
    prompt1,
    prompt2,
    prefix,
    style,
    disable_normalize=False,
    batch_size=4,
):
    """Generate speech for one piece of text from the TTS tab.

    The text is stripped and truncated to ``webui_config["tts_max"]``
    characters. Falsy UI values for seed, temperature, prefix and prompts are
    filled from the speaker/style preset returned by ``calc_spk_style``; a
    style of ``"*auto"`` means "no style".

    Returns:
        ``None`` for empty text, otherwise ``(sample_rate, audio)`` where
        ``audio`` has been passed through ``audio.audio_to_int16``.
    """
    try:
        batch_size = int(batch_size)
    except Exception:
        batch_size = 4

    text = text.strip()[: webui_config["tts_max"]]
    if text == "":
        return None

    style = None if style == "*auto" else style
    top_k = int(top_k) if isinstance(top_k, float) else top_k

    # Preset values only fill in what the user left empty/falsy.
    preset = calc_spk_style(spk=spk, style=style)
    spk = preset.get("spk", spk)
    infer_seed = infer_seed or preset.get("seed", infer_seed)
    temperature = temperature or preset.get("temperature", temperature)
    prefix = prefix or preset.get("prefix", prefix)
    prompt1 = prompt1 or preset.get("prompt1", "")
    prompt2 = prompt2 or preset.get("prompt2", "")

    # Clamp into the seed range accepted by the UI, then make it a Python int.
    infer_seed = int(np.clip(infer_seed, -1, 2**32 - 1))

    if not disable_normalize:
        text = text_normalize(text)

    sample_rate, audio_data = synthesize_audio(
        text=text,
        temperature=temperature,
        top_P=top_p,
        top_K=top_k,
        spk=spk,
        infer_seed=infer_seed,
        use_decoder=use_decoder,
        prompt1=prompt1,
        prompt2=prompt2,
        prefix=prefix,
        batch_size=batch_size,
    )
    return sample_rate, audio.audio_to_int16(audio_data)
@torch.inference_mode()
@spaces.GPU
def refine_text(text: str, prompt: str):
    """Normalize *text* and run it through the refiner with *prompt*.

    Returns whatever ``refiner.refine_text`` produces; the TTS tab writes it
    back into the text input.
    """
    return refiner.refine_text(text_normalize(text), prompt=prompt)
def read_local_readme(path="README.md"):
    """Return the README text starting at the "# 🗣️ ChatTTS-Forge" heading.

    Everything before that heading is dropped. If the heading is not present,
    the whole file is returned instead of raising ``ValueError``, which would
    otherwise abort building the WebUI.

    Args:
        path: README location. Defaults to ``README.md`` relative to the
            current working directory.
    """
    with open(path, "r", encoding="utf-8") as file:
        content = file.read()
    start = content.find("# 🗣️ ChatTTS-Forge")
    return content[start:] if start >= 0 else content
# Demo sample texts offered in the TTS tab's "Examples" dropdown. Most of them
# contain numbers, dates, times, units, fractions, percentages, prices or
# phone numbers (presumably to showcase text normalization); each ends with
# the "[lbreak]" control token.
sample_texts = [
    {"text": sample}
    for sample in (
        "大🍌,一条大🍌,嘿,你的感觉真的很奇妙 [lbreak]",
        "天气预报显示,今天会有小雨,请大家出门时记得带伞。降温的天气也提醒我们要适时添衣保暖 [lbreak]",
        "公司的年度总结会议将在下周三举行,请各部门提前准备好相关材料,确保会议顺利进行 [lbreak]",
        "今天的午餐菜单包括烤鸡、沙拉和蔬菜汤,大家可以根据自己的口味选择适合的菜品 [lbreak]",
        "请注意,电梯将在下午两点进行例行维护,预计需要一个小时的时间,请大家在此期间使用楼梯 [lbreak]",
        "图书馆新到了一批书籍,涵盖了文学、科学和历史等多个领域,欢迎大家前来借阅 [lbreak]",
        "电影中梁朝伟扮演的陈永仁的编号27149 [lbreak]",
        "这块黄金重达324.75克 [lbreak]",
        "我们班的最高总分为583分 [lbreak]",
        "12~23 [lbreak]",
        "-1.5~2 [lbreak]",
        "她出生于86年8月18日,她弟弟出生于1995年3月1日 [lbreak]",
        "等会请在12:05请通知我 [lbreak]",
        "今天的最低气温达到-10°C [lbreak]",
        "现场有7/12的观众投出了赞成票 [lbreak]",
        "明天有62%的概率降雨 [lbreak]",
        "随便来几个价格12块5,34.5元,20.1万 [lbreak]",
        "这是固话0421-33441122 [lbreak]",
        "这是手机+86 18544139121 [lbreak]",
    )
]
# SSML example 1: multi-speaker, multi-style audiobook excerpt. Every <voice>
# uses seed="42"; narration, quoted speech and speakers alternate per element.
ssml_example1 = """
<speak version="0.1">
<voice spk="Bob" seed="42" style="narration-relaxed">
下面是一个 ChatTTS 用于合成多角色多情感的有声书示例[lbreak]
</voice>
<voice spk="Bob" seed="42" style="narration-relaxed">
黛玉冷笑道:[lbreak]
</voice>
<voice spk="female2" seed="42" style="angry">
我说呢 [uv_break] ,亏了绊住,不然,早就飞起来了[lbreak]
</voice>
<voice spk="Bob" seed="42" style="narration-relaxed">
宝玉道:[lbreak]
</voice>
<voice spk="Alice" seed="42" style="unfriendly">
“只许和你玩 [uv_break] ,替你解闷。不过偶然到他那里,就说这些闲话。”[lbreak]
</voice>
<voice spk="female2" seed="42" style="angry">
“好没意思的话![uv_break] 去不去,关我什么事儿? 又没叫你替我解闷儿 [uv_break],还许你不理我呢” [lbreak]
</voice>
<voice spk="Bob" seed="42" style="narration-relaxed">
说着,便赌气回房去了 [lbreak]
</voice>
</speak>
"""
# SSML example 2: <prosody> with no attributes, rate, pitch and volume. Text
# outside any <prosody> is generated with the enclosing <voice> settings.
ssml_example2 = """
<speak version="0.1">
<voice spk="Bob" seed="42" style="narration-relaxed">
使用 prosody 控制生成文本的语速语调和音量,示例如下 [lbreak]
<prosody>
无任何限制将会继承父级voice配置进行生成 [lbreak]
</prosody>
<prosody rate="1.5">
设置 rate 大于1表示加速,小于1为减速 [lbreak]
</prosody>
<prosody pitch="6">
设置 pitch 调整音调,设置为6表示提高6个半音 [lbreak]
</prosody>
<prosody volume="2">
设置 volume 调整音量,设置为2表示提高2个分贝 [lbreak]
</prosody>
在 voice 中无prosody包裹的文本即为默认生成状态下的语音 [lbreak]
</voice>
</speak>
"""
# SSML example 3: a <break time="500" /> element inserting a pause between
# two utterances.
ssml_example3 = """
<speak version="0.1">
<voice spk="Bob" seed="42" style="narration-relaxed">
使用 break 标签将会简单的 [lbreak]
<break time="500" />
插入一段空白到生成结果中 [lbreak]
</voice>
</speak>
"""
# SSML example 4: English, Chinese and mixed-language sentences in a single
# <voice>, separated by <break> elements.
ssml_example4 = """
<speak version="0.1">
<voice spk="Bob" seed="42" style="excited">
temperature for sampling (may be overridden by style or speaker) [lbreak]
<break time="500" />
温度值用于采样,这个值有可能被 style 或者 speaker 覆盖 [lbreak]
<break time="500" />
temperature for sampling ,这个值有可能被 style 或者 speaker 覆盖 [lbreak]
<break time="500" />
温度值用于采样,(may be overridden by style or speaker) [lbreak]
</voice>
</speak>
"""
# Initial content of the SSML tab's input textbox.
default_ssml = """
<speak version="0.1">
<voice spk="Bob" seed="42" style="narration-relaxed">
这里是一个简单的 SSML 示例 [lbreak]
</voice>
</speak>
"""
def create_tts_interface():
    """Build the "TTS" tab inside the current ``gr.Blocks`` context.

    Left column: sampling sliders, style / speaker / inference-seed pickers and
    prompt-engineering fields. Right column: text input with control-token
    buttons, the refiner, the generate button, example texts and the audio
    output. Generation is wired to ``tts_generate`` and refining to
    ``refine_text``.
    """
    speakers = get_speakers()

    def get_speaker_show_name(spk):
        # Speakers without a meaningful gender ("*" or "") show the name only.
        if spk.gender == "*" or spk.gender == "":
            return spk.name
        return f"{spk.gender} : {spk.name}"

    def speaker_from_choice(choice):
        # Dropdown label -> speaker textbox value:
        # "*random" -> "-1", "female : female2" -> "female2".
        return "-1" if choice.startswith("*") else choice.split(":")[-1].strip()

    def random_seed():
        # Uniform draw from [0, 2**32 - 1).
        return int(torch.randint(0, 2**32 - 1, (1,)).item())

    speaker_names = ["*random"] + [
        get_speaker_show_name(speaker) for speaker in speakers
    ]
    styles = ["*auto"] + [s.get("name") for s in get_styles()]
    # Each button appends its control token to the input text.
    control_tokens = ["[laugh]", "[uv_break]", "[v_break]", "[lbreak]"]

    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("🎛️Sampling")
                temperature_input = gr.Slider(
                    0.01, 2.0, value=0.3, step=0.01, label="Temperature"
                )
                top_p_input = gr.Slider(0.1, 1.0, value=0.7, step=0.1, label="Top P")
                top_k_input = gr.Slider(1, 50, value=20, step=1, label="Top K")
                batch_size_input = gr.Slider(
                    1,
                    webui_config["max_batch_size"],
                    value=4,
                    step=1,
                    label="Batch Size",
                )
            with gr.Row():
                with gr.Group():
                    gr.Markdown("🎭Style")
                    gr.Markdown("- 后缀为 `_p` 表示带prompt,效果更强但是影响质量")
                    style_input_dropdown = gr.Dropdown(
                        choices=styles,
                        interactive=True,
                        show_label=False,
                        value="*auto",
                    )
            with gr.Row():
                with gr.Group():
                    gr.Markdown("🗣️Speaker (Name or Seed)")
                    spk_input_text = gr.Textbox(
                        label="Speaker (Text or Seed)",
                        value="female2",
                        show_label=False,
                    )
                    spk_input_dropdown = gr.Dropdown(
                        choices=speaker_names,
                        interactive=True,
                        value="female : female2",
                        show_label=False,
                    )
                    spk_rand_button = gr.Button(value="🎲", variant="secondary")
            with gr.Group():
                gr.Markdown("💃Inference Seed")
                infer_seed_input = gr.Number(
                    value=42,
                    label="Inference Seed",
                    show_label=False,
                    minimum=-1,
                    maximum=2**32 - 1,
                )
                infer_seed_rand_button = gr.Button(value="🎲", variant="secondary")
            # Hidden checkbox: from the UI, use_decoder is always True.
            use_decoder_input = gr.Checkbox(
                value=True, label="Use Decoder", visible=False
            )
            with gr.Group():
                gr.Markdown("🔧Prompt engineering")
                prompt1_input = gr.Textbox(label="Prompt 1")
                prompt2_input = gr.Textbox(label="Prompt 2")
                prefix_input = gr.Textbox(label="Prefix")

        with gr.Column(scale=3):
            with gr.Row():
                with gr.Column(scale=4):
                    with gr.Group():
                        gr.Markdown("📝Text Input", elem_id="input-title")
                        gr.Markdown(
                            f"- 字数限制{webui_config['tts_max']:,}字,超过部分截断"
                        )
                        gr.Markdown("- 如果尾字吞字不读,可以试试结尾加上 `[lbreak]`")
                        gr.Markdown(
                            "- If the input text is all in English, it is recommended to check disable_normalize"
                        )
                        text_input = gr.Textbox(
                            show_label=False,
                            label="Text to Speech",
                            lines=10,
                            placeholder="输入文本或选择示例",
                            elem_id="text-input",
                        )
                        # TODO: show a live character count in the title. A
                        # text_input.change handler is easy, but each change
                        # would trigger the loading indicator and a backend
                        # round trip.
                        with gr.Row():
                            for tk in control_tokens:
                                t_btn = gr.Button(tk)
                                # `tk=tk` binds the current token (avoids the
                                # late-binding closure pitfall).
                                t_btn.click(
                                    lambda text, tk=tk: text + " " + tk,
                                    inputs=[text_input],
                                    outputs=[text_input],
                                )
                with gr.Column(scale=1):
                    with gr.Group():
                        gr.Markdown("🎶Refiner")
                        refine_prompt_input = gr.Textbox(
                            label="Refine Prompt",
                            value="[oral_2][laugh_0][break_6]",
                        )
                        refine_button = gr.Button("✍️Refine Text")
                        # TODO: split sentences, build SSML with the current
                        # settings and send it to the SSML tab.
                    with gr.Group():
                        gr.Markdown("🔊Generate")
                        disable_normalize_input = gr.Checkbox(
                            value=False, label="Disable Normalize"
                        )
                        tts_button = gr.Button(
                            "🔊Generate Audio",
                            variant="primary",
                            elem_classes="big-button",
                        )
            with gr.Group():
                gr.Markdown("🎄Examples")
                sample_dropdown = gr.Dropdown(
                    choices=[sample["text"] for sample in sample_texts],
                    show_label=False,
                    value=None,
                    interactive=True,
                )
            with gr.Group():
                gr.Markdown("🎨Output")
                tts_output = gr.Audio(label="Generated Audio")

    # ---- event wiring ----
    spk_input_dropdown.change(
        fn=speaker_from_choice,
        inputs=[spk_input_dropdown],
        outputs=[spk_input_text],
    )
    spk_rand_button.click(
        lambda _: str(random_seed()),
        inputs=[spk_input_text],
        outputs=[spk_input_text],
    )
    infer_seed_rand_button.click(
        lambda _: random_seed(),
        inputs=[infer_seed_input],
        outputs=[infer_seed_input],
    )
    sample_dropdown.change(
        fn=lambda x: x,
        inputs=[sample_dropdown],
        outputs=[text_input],
    )
    refine_button.click(
        refine_text,
        inputs=[text_input, refine_prompt_input],
        outputs=[text_input],
    )
    tts_button.click(
        tts_generate,
        inputs=[
            text_input,
            temperature_input,
            top_p_input,
            top_k_input,
            spk_input_text,
            infer_seed_input,
            use_decoder_input,
            prompt1_input,
            prompt2_input,
            prefix_input,
            style_input_dropdown,
            disable_normalize_input,
            batch_size_input,
        ],
        outputs=tts_output,
    )
def create_ssml_interface():
    """Build the "SSML" tab and return its SSML input textbox.

    The textbox is returned so the long-text tab can write generated SSML into
    it. Synthesis is wired to ``synthesize_ssml``.
    """
    ssml_examples = [ssml_example1, ssml_example2, ssml_example3, ssml_example4]

    with gr.Row():
        with gr.Column(scale=3):
            with gr.Group():
                gr.Markdown("📝SSML Input")
                gr.Markdown(f"- 最长{webui_config['ssml_max']:,}字符,超过会被截断")
                gr.Markdown("- 尽量保证使用相同的 seed")
                gr.Markdown(
                    "- 关于SSML可以看这个 [文档](https://github.com/lenML/ChatTTS-Forge/blob/main/docs/SSML.md)"
                )
                ssml_input = gr.Textbox(
                    label="SSML Input",
                    lines=10,
                    value=default_ssml,
                    placeholder="输入 SSML 或选择示例",
                    elem_id="ssml_input",
                    show_label=False,
                )
                ssml_button = gr.Button("🔊Synthesize SSML", variant="primary")
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("🎛️Parameters")
                batch_size_slider = gr.Slider(
                    label="Batch Size",
                    value=4,
                    minimum=1,
                    maximum=webui_config["max_batch_size"],
                    step=1,
                )
            with gr.Group():
                gr.Markdown("🎄Examples")
                gr.Examples(examples=ssml_examples, inputs=[ssml_input])

    ssml_output = gr.Audio(label="Generated Audio")
    ssml_button.click(
        synthesize_ssml,
        inputs=[ssml_input, batch_size_slider],
        outputs=ssml_output,
    )
    return ssml_input
def split_long_text(long_text_input):
    """Split long text into normalized sentences for the long-text table.

    Returns a list of ``[index, sentence, len(sentence)]`` rows, matching the
    table's ``index`` / ``text`` / ``length`` columns.
    """
    splitter = SentenceSplitter(webui_config["spliter_threshold"])
    normalized = [text_normalize(piece) for piece in splitter.parse(long_text_input)]
    return [[idx, sentence, len(sentence)] for idx, sentence in enumerate(normalized)]
def merge_dataframe_to_ssml(dataframe, spk, style, seed):
    """Merge split long-text rows into one SSML document.

    Each row becomes one ``<voice>`` element carrying the chosen speaker,
    style and seed. The row's text is normalized and XML-escaped, so
    characters such as ``&`` or ``<`` no longer produce invalid SSML.

    Args:
        dataframe: DataFrame from the long-text table (as produced by
            ``split_long_text``). The sentence is read from its second column
            by position. ``None`` yields an empty document.
        spk: speaker name/seed. ``"-1"``, ``-1`` or empty means "not set".
        style: style name. ``"*auto"`` or empty means "not set".
        seed: inference seed. ``-1``, ``"-1"``, ``None`` or ``""`` means "not
            set"; ``0`` is a valid seed and is emitted. Integral floats (the
            Number component may deliver floats) are written as ints.

    Returns:
        The SSML document as a string.
    """
    from xml.sax.saxutils import escape, quoteattr

    if style == "*auto":
        style = None
    if spk == "-1" or spk == -1:
        spk = None
    if seed == -1 or seed == "-1":
        seed = None
    if isinstance(seed, float) and seed.is_integer():
        seed = int(seed)

    # Attributes are identical for every row, so build them once.
    attrs = ""
    if spk:
        attrs += f" spk={quoteattr(str(spk))}"
    if style:
        attrs += f" style={quoteattr(str(style))}"
    if seed is not None and seed != "":
        attrs += f" seed={quoteattr(str(seed))}"

    indent = " " * 2
    parts = ["<speak version='0.1'>\n"]
    rows = dataframe.iterrows() if dataframe is not None else ()
    for _, row in rows:
        # .iloc: positional access on a label-indexed row (plain row[1] relies
        # on pandas' deprecated positional fallback).
        text = escape(text_normalize(row.iloc[1]))
        parts.append(f"{indent}<voice{attrs}>\n")
        parts.append(f"{indent}{indent}{text}\n")
        parts.append(f"{indent}</voice>\n")
    parts.append("</speak>")
    return "".join(parts)
# Long-text tab: paste a long text, split it into sentences shown in a data
# table, then merge them into SSML with the chosen speaker/style/seed and send
# the result to the SSML tab. Splitting is delegated to SentenceSplitter (the
# original note says it splits on the "。" full stop).
def create_long_content_tab(ssml_input, tabs):
    """Build the "Long Text" tab.

    Args:
        ssml_input: the SSML tab's textbox. "Send to SSML" writes the merged
            document into it.
        tabs: the enclosing ``gr.Tabs``. After sending, it switches to the
            tab whose id is ``"ssml"``.
    """
    speakers = get_speakers()

    def get_speaker_show_name(spk):
        # Speakers without a meaningful gender ("*" or "") show the name only.
        if spk.gender == "*" or spk.gender == "":
            return spk.name
        return f"{spk.gender} : {spk.name}"

    def speaker_from_choice(choice):
        # "*random" -> "-1"; "female : female2" -> "female2".
        return "-1" if choice.startswith("*") else choice.split(":")[-1].strip()

    def random_seed():
        # Uniform draw from [0, 2**32 - 1).
        return int(torch.randint(0, 2**32 - 1, (1,)).item())

    def change_tab():
        return gr.Tabs(selected="ssml")

    speaker_names = ["*random"] + [
        get_speaker_show_name(speaker) for speaker in speakers
    ]
    styles = ["*auto"] + [s.get("name") for s in get_styles()]

    with gr.Row():
        with gr.Column(scale=1):
            # Speaker / style / seed pickers.
            with gr.Group():
                gr.Markdown("🗣️Speaker")
                spk_input_text = gr.Textbox(
                    label="Speaker (Text or Seed)",
                    value="female2",
                    show_label=False,
                )
                spk_input_dropdown = gr.Dropdown(
                    choices=speaker_names,
                    interactive=True,
                    value="female : female2",
                    show_label=False,
                )
                spk_rand_button = gr.Button(value="🎲", variant="secondary")
            with gr.Group():
                gr.Markdown("🎭Style")
                style_input_dropdown = gr.Dropdown(
                    choices=styles,
                    interactive=True,
                    show_label=False,
                    value="*auto",
                )
            with gr.Group():
                gr.Markdown("🗣️Seed")
                infer_seed_input = gr.Number(
                    value=42,
                    label="Inference Seed",
                    show_label=False,
                    minimum=-1,
                    maximum=2**32 - 1,
                )
                infer_seed_rand_button = gr.Button(value="🎲", variant="secondary")
            send_btn = gr.Button("📩Send to SSML", variant="primary")

        with gr.Column(scale=3):
            with gr.Group():
                gr.Markdown("📝Long Text Input")
                gr.Markdown("- 此页面用于处理超长文本")
                gr.Markdown("- 切割后,可以选择说话人、风格、seed,然后发送到SSML")
                long_text_input = gr.Textbox(
                    label="Long Text Input",
                    lines=10,
                    placeholder="输入长文本",
                    elem_id="long-text-input",
                    show_label=False,
                )
                long_text_split_button = gr.Button("🔪Split Text")

    with gr.Row():
        with gr.Column(scale=3):
            with gr.Group():
                gr.Markdown("🎨Output")
                long_text_output = gr.DataFrame(
                    headers=["index", "text", "length"],
                    datatype=["number", "str", "number"],
                    elem_id="long-text-output",
                    interactive=False,
                    wrap=True,
                    value=[],
                )

    # ---- event wiring ----
    spk_input_dropdown.change(
        fn=speaker_from_choice,
        inputs=[spk_input_dropdown],
        outputs=[spk_input_text],
    )
    # The target is a Textbox, so emit a str (same as the TTS tab).
    spk_rand_button.click(
        lambda _: str(random_seed()),
        inputs=[spk_input_text],
        outputs=[spk_input_text],
    )
    # Register exactly once: a second registration would run two handlers per
    # click, racing to write different random seeds.
    infer_seed_rand_button.click(
        lambda _: random_seed(),
        inputs=[infer_seed_input],
        outputs=[infer_seed_input],
    )
    long_text_split_button.click(
        split_long_text,
        inputs=[long_text_input],
        outputs=[long_text_output],
    )
    # Fill the SSML textbox first, then switch tabs, so the SSML tab is never
    # shown before its content has been updated.
    send_btn.click(
        merge_dataframe_to_ssml,
        inputs=[
            long_text_output,
            spk_input_text,
            style_input_dropdown,
            infer_seed_input,
        ],
        outputs=[ssml_input],
    ).then(change_tab, inputs=[], outputs=[tabs])
def create_readme_tab():
    """Render the local README (via ``read_local_readme``) as Markdown."""
    gr.Markdown(read_local_readme())
def create_interface():
    """Assemble the full WebUI: TTS, SSML, Long Text and README tabs.

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """
    # Force the dark theme: reload with ?__theme=dark unless already set.
    js_func = """
function refresh() {
const url = new URL(window.location);
if (url.searchParams.get('__theme') !== 'dark') {
url.searchParams.set('__theme', 'dark');
window.location.href = url.href;
}
}
"""
    # Taller generate button (elem_classes="big-button"), and hide the
    # `.eta-bar` element inside the text-input title. The id selector must
    # match elem_id="input-title" set in create_tts_interface.
    css = """
.big-button {
height: 80px;
}
#input-title div.eta-bar {
display: none !important; transform: none !important;
}
"""
    with gr.Blocks(js=js_func, css=css, title="ChatTTS Forge WebUI") as demo:
        with gr.Tabs() as tabs:
            with gr.TabItem("TTS"):
                create_tts_interface()
            with gr.TabItem("SSML", id="ssml"):
                ssml_input = create_ssml_interface()
            with gr.TabItem("Long Text"):
                # Needs the SSML textbox and the Tabs to send text across.
                create_long_content_tab(ssml_input, tabs=tabs)
            with gr.TabItem("README"):
                create_readme_tab()
        gr.Markdown(
            "此项目基于 [ChatTTS-Forge](https://github.com/lenML/ChatTTS-Forge) "
        )
    return demo
if __name__ == "__main__":
    import argparse

    import dotenv

    # Load settings from ENV_FILE (default ".env.webui") before reading them.
    dotenv.load_dotenv(
        dotenv_path=os.getenv("ENV_FILE", ".env.webui"),
    )

    parser = argparse.ArgumentParser(description="Gradio App")
    parser.add_argument("--server_name", type=str, help="server name")
    parser.add_argument("--server_port", type=int, help="server port")
    parser.add_argument(
        "--share", action="store_true", help="share the gradio interface"
    )
    parser.add_argument("--debug", action="store_true", help="enable debug mode")
    parser.add_argument("--auth", type=str, help="username:password for authentication")
    parser.add_argument(
        "--half",
        action="store_true",
        help="Enable half precision for model inference",
    )
    parser.add_argument(
        "--off_tqdm",
        action="store_true",
        help="Disable tqdm progress bar",
    )
    parser.add_argument(
        "--tts_max_len",
        type=int,
        help="Max length of text for TTS",
    )
    parser.add_argument(
        "--ssml_max_len",
        type=int,
        help="Max length of text for SSML",
    )
    parser.add_argument(
        "--max_batch_size",
        type=int,
        help="Max batch size for TTS",
    )
    parser.add_argument(
        "--lru_size",
        type=int,
        default=64,
        help="Set the size of the request cache pool, set it to 0 will disable lru_cache",
    )
    parser.add_argument(
        "--device_id",
        type=str,
        help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)",
        default=None,
    )
    parser.add_argument(
        "--use_cpu",
        nargs="+",
        help="use CPU as torch device for specified modules",
        default=[],
        type=str.lower,
    )
    parser.add_argument("--compile", action="store_true", help="Enable model compile")
    args = parser.parse_args()

    def get_and_update_env(*env_args):
        """Resolve a setting via ``env.get_env_or_arg`` and record it.

        ``env_args`` is ``(namespace, key, default, type)``; the resolved value
        is stored in ``config.runtime_env_vars[key]`` and returned. (How CLI
        values and environment variables are prioritised is decided by
        ``env.get_env_or_arg``; see modules.utils.env.)
        """
        val = env.get_env_or_arg(*env_args)
        key = env_args[1]
        config.runtime_env_vars[key] = val
        return val

    server_name = get_and_update_env(args, "server_name", "0.0.0.0", str)
    server_port = get_and_update_env(args, "server_port", 7860, int)
    share = get_and_update_env(args, "share", False, bool)
    debug = get_and_update_env(args, "debug", False, bool)
    auth = get_and_update_env(args, "auth", None, str)
    # The following are only recorded in config.runtime_env_vars here; other
    # modules presumably read them from there.
    half = get_and_update_env(args, "half", False, bool)
    off_tqdm = get_and_update_env(args, "off_tqdm", False, bool)
    lru_size = get_and_update_env(args, "lru_size", 64, int)
    device_id = get_and_update_env(args, "device_id", None, str)
    use_cpu = get_and_update_env(args, "use_cpu", [], list)
    compile_model = get_and_update_env(args, "compile", False, bool)

    webui_config["tts_max"] = get_and_update_env(args, "tts_max_len", 1000, int)
    webui_config["ssml_max"] = get_and_update_env(args, "ssml_max_len", 5000, int)
    webui_config["max_batch_size"] = get_and_update_env(args, "max_batch_size", 8, int)

    demo = create_interface()

    if auth:
        # Split on the first ":" only so passwords may contain ":"; reject
        # values without a separator instead of passing a 1-tuple to Gradio.
        username, sep, password = auth.partition(":")
        if not sep:
            parser.error("auth must be in the form username:password")
        auth = (username, password)

    generate.setup_lru_cache()
    devices.reset_device()
    devices.first_time_calculation()

    demo.queue().launch(
        server_name=server_name,
        server_port=server_port,
        share=share,
        debug=debug,
        auth=auth,
        show_api=False,
    )
|