BERT-ISTFT-VITS / app.py
guetLzy's picture
Update app.py
1d529a1 verified
import os
import subprocess
# 编译 monotonic_align
def compile_monotonic_align():
# 检查是否已编译
if not os.path.exists("monotonic_align/monotonic_align/core.cpython-*.so"):
print("正在编译 monotonic_align...")
# 假设 monotonic_align 文件夹已存在
if not os.path.exists("monotonic_align"):
raise FileNotFoundError("monotonic_align 文件夹未找到!请确保它存在于根目录中。")
os.chdir("monotonic_align")
os.makedirs("monotonic_align", exist_ok=True) # 创建 monotonic_align 子目录
subprocess.run(["python", "setup.py", "build_ext", "--inplace"], check=True)
os.chdir("..")
print("monotonic_align 编译成功!")
else:
print("monotonic_align 已编译,跳过...")
# 在程序启动时编译
compile_monotonic_align()
import gradio as gr
import torch
import numpy as np
from scipy.io.wavfile import write
import commons
import utils
from models import SynthesizerTrn
from text.symbols import symbols
from text import get_bert, cleaned_text_to_sequence
from text.cleaner import clean_text
from huggingface_hub import hf_hub_download, snapshot_download
# 模型配置
MODEL_CONFIG = {
"roberta": {
"repo_id": "hfl/chinese-roberta-wwm-ext-large"
},
"vits": {
"repo_id": "guetLzy/BERT-ISTFT-VITS-Model",
"files": ["G_1000.pth"] # 根据实际文件名调整
}
}
# 设备设置
device = "cuda" if torch.cuda.is_available() else "cpu"
# 可用的模型选项
MODEL_OPTIONS = {
"VITS_Model": "models/G_1000.pth",
}
def download_models():
"""下载所有需要的模型文件"""
os.makedirs("./bert/chinese-roberta-wwm-ext-large", exist_ok=True) # 创建 RoBERTa 模型存储目录
os.makedirs("./models", exist_ok=True) # 创建 VITS 模型存储目录
# 下载 RoBERTa 模型(所有文件)
roberta_path = snapshot_download(
repo_id=MODEL_CONFIG["roberta"]["repo_id"],
local_dir="./bert/chinese-roberta-wwm-ext-large",
resume_download=True # 支持断点续传
)
roberta_paths = {"repo_dir": roberta_path} # 返回整个文件夹路径
# 下载 VITS 模型(指定文件)
vits_paths = {}
for model_name, model_path in MODEL_OPTIONS.items():
path = hf_hub_download(
repo_id=MODEL_CONFIG["vits"]["repo_id"],
filename=os.path.basename(model_path),
local_dir="./models",
resume_download=True # 支持断点续传
)
vits_paths[model_name] = path
return {
"roberta": roberta_paths,
"vits": vits_paths
}
# 在程序启动时下载模型
model_paths = download_models()
# 加载配置和模型
hps = utils.get_hparams_from_file("configs/1.json") # 从配置文件加载超参数
net_g = SynthesizerTrn(
len(symbols),
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**hps.model,
).to(device) # 初始化 SynthesizerTrn 模型并移到指定设备
_ = net_g.eval() # 设置模型为评估模式
# 加载下载的 VITS 模型权重
_ = utils.load_checkpoint(model_paths["vits"]["VITS_Model"], net_g, None)
def get_text(text, hps, language_str="ZH"):
"""处理输入文本,生成语音所需的序列"""
norm_text, phone, tone, word2ph = clean_text(text, language_str) # 清理文本
phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str) # 转换为序列
if hps.data.add_blank:
phone = commons.intersperse(phone, 0) # 在序列中插入空白
tone = commons.intersperse(tone, 0)
language = commons.intersperse(language, 0)
for i in range(len(word2ph)):
word2ph[i] = word2ph[i] * 2
word2ph[0] += 1
if hps.data.use_bert:
bert = get_bert(norm_text, word2ph, language_str, device) # 获取 BERT 特征
del word2ph
assert bert.shape[-1] == len(phone) # 确保 BERT 特征长度与 phone 一致
if language_str == "ZH":
bert = bert
else:
bert = torch.zeros(1024, len(phone)) # 非中文时使用零填充
else:
bert = torch.zeros(1024, len(phone)) # 不使用 BERT 时填充零
phone = torch.LongTensor(phone) # 转换为张量
tone = torch.LongTensor(tone)
language = torch.LongTensor(language)
return bert, phone, tone, language
def generate_audio(text, noise_scale=1.0, noise_scale_w=0.8, length_scale=1.0):
"""生成音频文件"""
bert, phones, tones, language_id = get_text(text, hps) # 获取处理后的文本数据
with torch.no_grad(): # 不计算梯度
x_tst = phones.to(device).unsqueeze(0) # 输入序列
tones = tones.to(device).unsqueeze(0) # 音调
language_id = language_id.to(device).unsqueeze(0) # 语言标识
bert = bert.to(device).unsqueeze(0) # BERT 特征
x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device) # 输入长度
sid = torch.LongTensor([hps.data.spk2id["SSB0005"]]).to(device) # 说话者 ID
audio = (
net_g.infer(
x_tst,
x_tst_lengths,
sid,
tones,
language_id,
bert,
noise_scale=noise_scale,
noise_scale_w=noise_scale_w,
length_scale=length_scale,
)[0][0, 0]
.data.cpu() # 将结果移到 CPU
.float()
.numpy() # 转换为 numpy 数组
)
output_path = "output.wav" # 输出音频文件路径
write(output_path, 22050, (audio * 32767.0).astype(np.int16)) # 保存为 WAV 文件
return output_path
with gr.Blocks(
title="BERT-ISTFT-VITS中文语音合成系统",
theme="NoCrypt/miku"
) as interface:
# 标题和描述
gr.Markdown("# BERT-ISTFT-VITS中文语音合成系统")
gr.Markdown("输入中文文本并调整参数以生成语音。支持调整噪声和语速参数。")
# 主布局:两列设计
with gr.Row():
# 左侧:输入区域
with gr.Column(scale=1):
# 文本输入框
text_input = gr.Textbox(
label="输入文本",
value="桂林电子科技大学",
placeholder="请输入中文文本...",
lines=5, # 增加行数,便于输入长文本
)
# 参数调整分组
with gr.Group():
gr.Markdown("### 参数调整") # 分组标题
noise_scale = gr.Slider(
minimum=0.1,
maximum=1,
step=0.1,
value=0.667,
label="噪声比例",
info="控制生成音频的噪声水平"
)
noise_scale_w = gr.Slider(
minimum=0.1,
maximum=1,
step=0.1,
value=1.0,
label="噪声比例 W",
info="控制音调的噪声影响"
)
length_scale = gr.Slider(
minimum=0.1,
maximum=1,
step=0.1,
value=1.0,
label="语速比例",
info="调整语音的播放速度"
)
# 右侧:输出区域
with gr.Column(scale=1):
audio_output = gr.Audio(
label="生成的音频",
type="filepath", # 返回文件路径
interactive=False # 禁止用户编辑音频
)
# 生成按钮
generate_btn = gr.Button("生成语音", variant="primary")
# 绑定生成函数
generate_btn.click(
fn=generate_audio,
inputs=[text_input, noise_scale, noise_scale_w, length_scale],
outputs=audio_output
)
# 启动界面
interface.launch()