# BERT-ISTFT-VITS / app.py
import gradio as gr
import torch
import numpy as np
from scipy.io.wavfile import write
import os
import commons
import utils
from models import SynthesizerTrn
from text.symbols import symbols
from text import get_bert, cleaned_text_to_sequence
from text.cleaner import clean_text
from huggingface_hub import hf_hub_download, snapshot_download
# Model configuration
MODEL_CONFIG = {
    "roberta": {
        "repo_id": "hfl/chinese-roberta-wwm-ext-large"
    },
    "vits": {
        "repo_id": "guetLzy/BERT-ISTFT-VITS-Model",
        "files": ["G_1000.pth", "config.json"]  # adjust to the actual file names
    }
}
# Device setup
device = "cuda" if torch.cuda.is_available() else "cpu"
# Available model options
MODEL_OPTIONS = {
    "VITS_Model": "models/G_1000.pth",
}
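# Additional checkpoints from the same Hub repo could be registered here, e.g.
# (hypothetical filename, shown only as an illustration):
# MODEL_OPTIONS["VITS_Model_2000"] = "models/G_2000.pth"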
def download_models():
    """Download all required model files from the Hugging Face Hub."""
    os.makedirs("./bert/chinese-roberta-wwm-ext-large", exist_ok=True)
    os.makedirs("./models", exist_ok=True)

    # Download the RoBERTa model as a full repository snapshot
    # (config, tokenizer and weights)
    roberta_path = snapshot_download(
        repo_id=MODEL_CONFIG["roberta"]["repo_id"],
        local_dir="./bert/chinese-roberta-wwm-ext-large",
    )

    # Download the VITS checkpoints listed in MODEL_OPTIONS
    vits_paths = {}
    for model_name, model_path in MODEL_OPTIONS.items():
        path = hf_hub_download(
            repo_id=MODEL_CONFIG["vits"]["repo_id"],
            filename=os.path.basename(model_path),
            local_dir="./models",
        )
        vits_paths[model_name] = path

    return {
        "roberta": roberta_path,
        "vits": vits_paths,
    }
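# Illustrative shape of the value download_models() returns, assuming the
# snapshot/checkpoint locations configured above (actual paths depend on the
# local cache layout):
# {
#     "roberta": "./bert/chinese-roberta-wwm-ext-large",
#     "vits": {"VITS_Model": "./models/G_1000.pth"},
# }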
# Download the models at program startup
model_paths = download_models()

# Load the hyperparameter config and build the synthesizer
hps = utils.get_hparams_from_file("configs/1.json")
net_g = SynthesizerTrn(
    len(symbols),
    hps.data.filter_length // 2 + 1,
    hps.train.segment_size // hps.data.hop_length,
    n_speakers=hps.data.n_speakers,
    **hps.model,
).to(device)
_ = net_g.eval()

# Load the downloaded VITS checkpoint weights
_ = utils.load_checkpoint(model_paths["vits"]["VITS_Model"], net_g, None)
def get_text(text, hps, language_str="ZH"):
    """Convert raw text into BERT features, phone, tone and language-id sequences."""
    norm_text, phone, tone, word2ph = clean_text(text, language_str)
    phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)

    if hps.data.add_blank:
        # Intersperse blank tokens between symbols and keep word2ph aligned
        phone = commons.intersperse(phone, 0)
        tone = commons.intersperse(tone, 0)
        language = commons.intersperse(language, 0)
        for i in range(len(word2ph)):
            word2ph[i] = word2ph[i] * 2
        word2ph[0] += 1

    if hps.data.use_bert:
        bert = get_bert(norm_text, word2ph, language_str, device)
        del word2ph
        assert bert.shape[-1] == len(phone)
        if language_str != "ZH":
            # Only Chinese BERT features are available; fall back to zeros otherwise
            bert = torch.zeros(1024, len(phone))
    else:
        bert = torch.zeros(1024, len(phone))

    phone = torch.LongTensor(phone)
    tone = torch.LongTensor(tone)
    language = torch.LongTensor(language)
    return bert, phone, tone, language
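# Illustrative return shapes for get_text (T = number of phone tokens after any
# blank insertion): bert is typically a FloatTensor of shape [1024, T]; phone,
# tone and language are LongTensors of length T.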
def generate_audio(text, noise_scale=1.0, noise_scale_w=0.8, length_scale=1.0):
    """Synthesize speech for the given text and write it to output.wav."""
    bert, phones, tones, language_id = get_text(text, hps)
    with torch.no_grad():
        x_tst = phones.to(device).unsqueeze(0)
        tones = tones.to(device).unsqueeze(0)
        language_id = language_id.to(device).unsqueeze(0)
        bert = bert.to(device).unsqueeze(0)
        x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
        sid = torch.LongTensor([hps.data.spk2id["single"]]).to(device)
        audio = (
            net_g.infer(
                x_tst,
                x_tst_lengths,
                sid,
                tones,
                language_id,
                bert,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale,
            )[0][0, 0]
            .data.cpu()
            .float()
            .numpy()
        )
    output_path = "output.wav"
    # Write 16-bit PCM at the sampling rate defined in the config
    write(output_path, hps.data.sampling_rate, (audio * 32767.0).astype(np.int16))
    return output_path
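# Minimal local smoke test, independent of the Gradio UI (uncomment to use;
# the sample sentence is only an illustration):
# if __name__ == "__main__":
#     print(generate_audio("今天天气真好。"))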
# Build the Gradio interface
interface = gr.Interface(
    fn=generate_audio,
    inputs=[
        gr.Textbox(label="Input text", placeholder="Enter Chinese text..."),
        gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=1.0, label="Noise Scale"),
        gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=0.8, label="Noise Scale W"),
        gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Length Scale")
    ],
    outputs=gr.Audio(label="Generated audio"),
    title="Chinese Text-to-Speech",
    description="Enter Chinese text and adjust the parameters to generate speech. Noise and speaking-rate (length) scales are adjustable.",
    theme="default"
)

# Launch the interface
interface.launch()
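# On a Hugging Face Space the default launch() settings are sufficient; for a
# self-hosted deployment the call above could instead be, for example:
# interface.launch(server_name="0.0.0.0", server_port=7860)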