Spaces:
Running
Running
import os | |
import subprocess | |
# 编译 monotonic_align | |
def compile_monotonic_align(): | |
# 检查是否已编译 | |
if not os.path.exists("monotonic_align/monotonic_align/core.cpython-*.so"): | |
print("正在编译 monotonic_align...") | |
# 假设 monotonic_align 文件夹已存在 | |
if not os.path.exists("monotonic_align"): | |
raise FileNotFoundError("monotonic_align 文件夹未找到!请确保它存在于根目录中。") | |
os.chdir("monotonic_align") | |
os.makedirs("monotonic_align", exist_ok=True) # 创建 monotonic_align 子目录 | |
subprocess.run(["python", "setup.py", "build_ext", "--inplace"], check=True) | |
os.chdir("..") | |
print("monotonic_align 编译成功!") | |
else: | |
print("monotonic_align 已编译,跳过...") | |
# 在程序启动时编译 | |
compile_monotonic_align() | |
import gradio as gr | |
import torch | |
import numpy as np | |
from scipy.io.wavfile import write | |
import commons | |
import utils | |
from models import SynthesizerTrn | |
from text.symbols import symbols | |
from text import get_bert, cleaned_text_to_sequence | |
from text.cleaner import clean_text | |
from huggingface_hub import hf_hub_download, snapshot_download | |
# 模型配置 | |
MODEL_CONFIG = { | |
"roberta": { | |
"repo_id": "hfl/chinese-roberta-wwm-ext-large" | |
}, | |
"vits": { | |
"repo_id": "guetLzy/BERT-ISTFT-VITS-Model", | |
"files": ["G_1000.pth"] # 根据实际文件名调整 | |
} | |
} | |
# 设备设置 | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
# 可用的模型选项 | |
MODEL_OPTIONS = { | |
"VITS_Model": "models/G_1000.pth", | |
} | |
def download_models(): | |
"""下载所有需要的模型文件""" | |
os.makedirs("./bert/chinese-roberta-wwm-ext-large", exist_ok=True) # 创建 RoBERTa 模型存储目录 | |
os.makedirs("./models", exist_ok=True) # 创建 VITS 模型存储目录 | |
# 下载 RoBERTa 模型(所有文件) | |
roberta_path = snapshot_download( | |
repo_id=MODEL_CONFIG["roberta"]["repo_id"], | |
local_dir="./bert/chinese-roberta-wwm-ext-large", | |
resume_download=True # 支持断点续传 | |
) | |
roberta_paths = {"repo_dir": roberta_path} # 返回整个文件夹路径 | |
# 下载 VITS 模型(指定文件) | |
vits_paths = {} | |
for model_name, model_path in MODEL_OPTIONS.items(): | |
path = hf_hub_download( | |
repo_id=MODEL_CONFIG["vits"]["repo_id"], | |
filename=os.path.basename(model_path), | |
local_dir="./models", | |
resume_download=True # 支持断点续传 | |
) | |
vits_paths[model_name] = path | |
return { | |
"roberta": roberta_paths, | |
"vits": vits_paths | |
} | |
# 在程序启动时下载模型 | |
model_paths = download_models() | |
# 加载配置和模型 | |
hps = utils.get_hparams_from_file("configs/1.json") # 从配置文件加载超参数 | |
net_g = SynthesizerTrn( | |
len(symbols), | |
hps.data.filter_length // 2 + 1, | |
hps.train.segment_size // hps.data.hop_length, | |
n_speakers=hps.data.n_speakers, | |
**hps.model, | |
).to(device) # 初始化 SynthesizerTrn 模型并移到指定设备 | |
_ = net_g.eval() # 设置模型为评估模式 | |
# 加载下载的 VITS 模型权重 | |
_ = utils.load_checkpoint(model_paths["vits"]["VITS_Model"], net_g, None) | |
def get_text(text, hps, language_str="ZH"): | |
"""处理输入文本,生成语音所需的序列""" | |
norm_text, phone, tone, word2ph = clean_text(text, language_str) # 清理文本 | |
phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str) # 转换为序列 | |
if hps.data.add_blank: | |
phone = commons.intersperse(phone, 0) # 在序列中插入空白 | |
tone = commons.intersperse(tone, 0) | |
language = commons.intersperse(language, 0) | |
for i in range(len(word2ph)): | |
word2ph[i] = word2ph[i] * 2 | |
word2ph[0] += 1 | |
if hps.data.use_bert: | |
bert = get_bert(norm_text, word2ph, language_str, device) # 获取 BERT 特征 | |
del word2ph | |
assert bert.shape[-1] == len(phone) # 确保 BERT 特征长度与 phone 一致 | |
if language_str == "ZH": | |
bert = bert | |
else: | |
bert = torch.zeros(1024, len(phone)) # 非中文时使用零填充 | |
else: | |
bert = torch.zeros(1024, len(phone)) # 不使用 BERT 时填充零 | |
phone = torch.LongTensor(phone) # 转换为张量 | |
tone = torch.LongTensor(tone) | |
language = torch.LongTensor(language) | |
return bert, phone, tone, language | |
def generate_audio(text, noise_scale=1.0, noise_scale_w=0.8, length_scale=1.0): | |
"""生成音频文件""" | |
bert, phones, tones, language_id = get_text(text, hps) # 获取处理后的文本数据 | |
with torch.no_grad(): # 不计算梯度 | |
x_tst = phones.to(device).unsqueeze(0) # 输入序列 | |
tones = tones.to(device).unsqueeze(0) # 音调 | |
language_id = language_id.to(device).unsqueeze(0) # 语言标识 | |
bert = bert.to(device).unsqueeze(0) # BERT 特征 | |
x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device) # 输入长度 | |
sid = torch.LongTensor([hps.data.spk2id["SSB0005"]]).to(device) # 说话者 ID | |
audio = ( | |
net_g.infer( | |
x_tst, | |
x_tst_lengths, | |
sid, | |
tones, | |
language_id, | |
bert, | |
noise_scale=noise_scale, | |
noise_scale_w=noise_scale_w, | |
length_scale=length_scale, | |
)[0][0, 0] | |
.data.cpu() # 将结果移到 CPU | |
.float() | |
.numpy() # 转换为 numpy 数组 | |
) | |
output_path = "output.wav" # 输出音频文件路径 | |
write(output_path, 22050, (audio * 32767.0).astype(np.int16)) # 保存为 WAV 文件 | |
return output_path | |
with gr.Blocks( | |
title="BERT-ISTFT-VITS中文语音合成系统", | |
theme="NoCrypt/miku" | |
) as interface: | |
# 标题和描述 | |
gr.Markdown("# BERT-ISTFT-VITS中文语音合成系统") | |
gr.Markdown("输入中文文本并调整参数以生成语音。支持调整噪声和语速参数。") | |
# 主布局:两列设计 | |
with gr.Row(): | |
# 左侧:输入区域 | |
with gr.Column(scale=1): | |
# 文本输入框 | |
text_input = gr.Textbox( | |
label="输入文本", | |
value="桂林电子科技大学", | |
placeholder="请输入中文文本...", | |
lines=5, # 增加行数,便于输入长文本 | |
) | |
# 参数调整分组 | |
with gr.Group(): | |
gr.Markdown("### 参数调整") # 分组标题 | |
noise_scale = gr.Slider( | |
minimum=0.1, | |
maximum=1, | |
step=0.1, | |
value=0.667, | |
label="噪声比例", | |
info="控制生成音频的噪声水平" | |
) | |
noise_scale_w = gr.Slider( | |
minimum=0.1, | |
maximum=1, | |
step=0.1, | |
value=1.0, | |
label="噪声比例 W", | |
info="控制音调的噪声影响" | |
) | |
length_scale = gr.Slider( | |
minimum=0.1, | |
maximum=1, | |
step=0.1, | |
value=1.0, | |
label="语速比例", | |
info="调整语音的播放速度" | |
) | |
# 右侧:输出区域 | |
with gr.Column(scale=1): | |
audio_output = gr.Audio( | |
label="生成的音频", | |
type="filepath", # 返回文件路径 | |
interactive=False # 禁止用户编辑音频 | |
) | |
# 生成按钮 | |
generate_btn = gr.Button("生成语音", variant="primary") | |
# 绑定生成函数 | |
generate_btn.click( | |
fn=generate_audio, | |
inputs=[text_input, noise_scale, noise_scale_w, length_scale], | |
outputs=audio_output | |
) | |
# 启动界面 | |
interface.launch() |