Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
import numpy as np | |
from scipy.io.wavfile import write | |
import os | |
import commons | |
import utils | |
from models import SynthesizerTrn | |
from text.symbols import symbols | |
from text import get_bert, cleaned_text_to_sequence | |
from text.cleaner import clean_text | |
from huggingface_hub import hf_hub_download, snapshot_download | |
# Model configuration: Hugging Face Hub repositories (and files) fetched at startup.
MODEL_CONFIG = {
    "roberta": {
        "repo_id": "hfl/chinese-roberta-wwm-ext-large",
        # download_models() iterates this list, so it must exist (without it the
        # app died with KeyError at startup). Only the tokenizer/config files and
        # the PyTorch weights are needed; the TF/Flax checkpoints are skipped.
        "files": [
            "config.json",
            "pytorch_model.bin",
            "vocab.txt",
            "tokenizer.json",
            "tokenizer_config.json",
            "special_tokens_map.json",
            "added_tokens.json",
        ],
    },
    "vits": {
        "repo_id": "guetLzy/BERT-ISTFT-VITS-Model",
        "files": ["G_1000.pth", "config.json"],  # adjust to the actual file names in the repo
    },
}

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Selectable generator checkpoints: display name -> local path under ./models.
MODEL_OPTIONS = {
    "VITS_Model": "models/G_1000.pth",
}
def download_models():
    """Download every model file the app needs from the Hugging Face Hub.

    RoBERTa files are placed in ``./bert/chinese-roberta-wwm-ext-large`` and the
    VITS generator checkpoints in ``./models``.

    Returns:
        dict: ``{"roberta": {filename: local_path}, "vits": {model_name: local_path}}``.
    """
    roberta_dir = "./bert/chinese-roberta-wwm-ext-large"
    vits_dir = "./models"
    os.makedirs(roberta_dir, exist_ok=True)
    os.makedirs(vits_dir, exist_ok=True)

    roberta_cfg = MODEL_CONFIG["roberta"]
    vits_repo = MODEL_CONFIG["vits"]["repo_id"]

    # Note: `resume_download` is no longer passed; it is deprecated in
    # huggingface_hub, where interrupted downloads resume automatically.

    # RoBERTa: fetch the explicitly configured files. If the config lists none,
    # fall back to a repo snapshot restricted to the small JSON/TXT files and the
    # PyTorch weights instead of crashing with KeyError.
    roberta_paths = {}
    roberta_files = roberta_cfg.get("files")
    if roberta_files:
        for file in roberta_files:
            roberta_paths[file] = hf_hub_download(
                repo_id=roberta_cfg["repo_id"],
                filename=file,
                local_dir=roberta_dir,
            )
    else:
        snapshot_dir = snapshot_download(
            repo_id=roberta_cfg["repo_id"],
            local_dir=roberta_dir,
            allow_patterns=["*.json", "*.txt", "pytorch_model.bin"],
        )
        for name in sorted(os.listdir(snapshot_dir)):
            full_path = os.path.join(snapshot_dir, name)
            if os.path.isfile(full_path):
                roberta_paths[name] = full_path

    # VITS: the file name inside the repo is the basename of each configured path.
    vits_paths = {}
    for model_name, model_path in MODEL_OPTIONS.items():
        vits_paths[model_name] = hf_hub_download(
            repo_id=vits_repo,
            filename=os.path.basename(model_path),
            local_dir=vits_dir,
        )

    return {
        "roberta": roberta_paths,
        "vits": vits_paths,
    }
# Fetch all model files up front; everything below reads them from disk.
model_paths = download_models()

# Hyper-parameters are read from the local configs/1.json.
# NOTE(review): MODEL_CONFIG["vits"]["files"] lists a config.json in the model
# repo, but it is never downloaded or used here - confirm configs/1.json matches
# the checkpoint being loaded.
hps = utils.get_hparams_from_file("configs/1.json")

# Instantiate the generator on the selected device, then switch it to eval mode.
net_g = SynthesizerTrn(
    len(symbols),                                   # size of the symbol vocabulary
    hps.data.filter_length // 2 + 1,                # one-sided FFT bin count
    hps.train.segment_size // hps.data.hop_length,  # segment size in hop-length frames
    n_speakers=hps.data.n_speakers,
    **hps.model,
).to(device)
net_g.eval()

# Load the downloaded generator weights into net_g; the optimizer slot is None
# because only inference is performed.
_ = utils.load_checkpoint(model_paths["vits"]["VITS_Model"], net_g, None)
def get_text(text, hps, language_str="ZH"):
    """Convert raw text into the tensors consumed by the VITS generator.

    Args:
        text: Input text to synthesize.
        hps: Hyper-parameter object; ``hps.data.add_blank`` and
            ``hps.data.use_bert`` are read.
        language_str: Language code for the text front-end (default ``"ZH"``).

    Returns:
        tuple: ``(bert, phone, tone, language)``. ``bert`` holds BERT features
        whose last dimension equals ``len(phone)`` (zeros of shape
        ``(1024, len(phone))`` when BERT is disabled or the language is not
        ZH); the other three are ``torch.LongTensor`` sequences.

    Raises:
        ValueError: if the BERT features are not aligned with the phone sequence.
    """
    norm_text, phone, tone, word2ph = clean_text(text, language_str)
    phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)

    if hps.data.add_blank:
        # Blank id 0 is interspersed into every sequence. Each word then covers
        # twice as many phones, and the extra leading blank is credited to the
        # first word so that the word-to-phone counts still line up with `phone`.
        phone = commons.intersperse(phone, 0)
        tone = commons.intersperse(tone, 0)
        language = commons.intersperse(language, 0)
        word2ph = [count * 2 for count in word2ph]
        word2ph[0] += 1

    # BERT features are only used for Chinese; other languages receive zeros, so
    # the BERT call is skipped for them instead of being computed and discarded.
    if hps.data.use_bert and language_str == "ZH":
        bert = get_bert(norm_text, word2ph, language_str, device)
        # Explicit check (not `assert`) so it still runs under `python -O`.
        if bert.shape[-1] != len(phone):
            raise ValueError(
                f"BERT features cover {bert.shape[-1]} phones, "
                f"but the phone sequence has {len(phone)}"
            )
    else:
        bert = torch.zeros(1024, len(phone))

    phone = torch.LongTensor(phone)
    tone = torch.LongTensor(tone)
    language = torch.LongTensor(language)
    return bert, phone, tone, language
def generate_audio(text, noise_scale=1.0, noise_scale_w=0.8, length_scale=1.0):
    """Synthesize *text* with the loaded VITS generator and save it as a WAV file.

    Args:
        text: Chinese text to synthesize.
        noise_scale: Forwarded to ``net_g.infer`` as ``noise_scale``.
        noise_scale_w: Forwarded to ``net_g.infer`` as ``noise_scale_w``.
        length_scale: Forwarded to ``net_g.infer``; adjusts the speaking rate.

    Returns:
        str: Path to a newly created 16-bit PCM WAV file.
    """
    import tempfile

    bert, phones, tones, language_id = get_text(text, hps)
    with torch.no_grad():
        # Add a batch dimension of 1 to every model input.
        x_tst = phones.to(device).unsqueeze(0)
        x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
        tones = tones.to(device).unsqueeze(0)
        language_id = language_id.to(device).unsqueeze(0)
        bert = bert.to(device).unsqueeze(0)
        sid = torch.LongTensor([hps.data.spk2id["single"]]).to(device)
        audio = (
            net_g.infer(
                x_tst,
                x_tst_lengths,
                sid,
                tones,
                language_id,
                bert,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale,
            )[0][0, 0]
            .data.cpu()
            .float()
            .numpy()
        )

    # Use the model's configured sample rate. A hard-coded rate plays the audio
    # too fast or too slow whenever the config differs; 22050 is kept as the
    # fallback when the config does not define one.
    sampling_rate = int(getattr(hps.data, "sampling_rate", 22050))

    # Clip before scaling so out-of-range samples saturate instead of wrapping
    # around in int16, which would produce loud clicks.
    pcm = (np.clip(audio, -1.0, 1.0) * 32767.0).astype(np.int16)

    # Write to a unique file per request. A shared "output.wav" can be
    # overwritten by a concurrent request before Gradio serves it.
    fd, output_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    write(output_path, sampling_rate, pcm)
    return output_path
# Build the Gradio UI. The four input components map positionally onto
# generate_audio(text, noise_scale, noise_scale_w, length_scale).
text_box = gr.Textbox(label="输入文本", placeholder="请输入中文文本...")
noise_slider = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=1.0, label="Noise Scale")
noise_w_slider = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=0.8, label="Noise Scale W")
length_slider = gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Length Scale")
audio_out = gr.Audio(label="生成的音频")

interface = gr.Interface(
    fn=generate_audio,
    inputs=[text_box, noise_slider, noise_w_slider, length_slider],
    outputs=audio_out,
    title="中文文本转语音",
    description="输入中文文本并调整参数以生成语音。支持调整噪声和语速参数。",
    theme="default",
)

# Start the web server.
interface.launch()