Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
from scipy.io.wavfile import write
|
5 |
+
import os
|
6 |
+
import commons
|
7 |
+
import utils
|
8 |
+
from models import SynthesizerTrn
|
9 |
+
from text.symbols import symbols
|
10 |
+
from text import get_bert, cleaned_text_to_sequence
|
11 |
+
from text.cleaner import clean_text
|
12 |
+
from huggingface_hub import hf_hub_download, snapshot_download
|
13 |
+
|
14 |
+
# Model sources on the Hugging Face Hub.
MODEL_CONFIG = {
    "roberta": {
        "repo_id": "hfl/chinese-roberta-wwm-ext-large",
        # download_models() iterates MODEL_CONFIG["roberta"]["files"]; without
        # this key the app dies with KeyError before the UI is built.
        # NOTE(review): file names assumed from the usual layout of this repo
        # (PyTorch weights + tokenizer) -- confirm against the repo listing.
        "files": [
            "config.json",
            "pytorch_model.bin",
            "tokenizer.json",
            "tokenizer_config.json",
            "vocab.txt",
        ],
    },
    "vits": {
        "repo_id": "guetLzy/BERT-ISTFT-VITS-Model",
        "files": ["G_1000.pth", "config.json"],  # adjust to the actual file names
    },
}

# Run on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Selectable models: display name -> checkpoint path. Only the base name of
# the path is used as the file name when downloading from the VITS repo.
MODEL_OPTIONS = {
    "VITS_Model": "models/G_1000.pth",
}
|
32 |
+
|
33 |
+
def download_models():
    """Download every model file the app needs and return their local paths.

    RoBERTa files are stored under ``./bert/chinese-roberta-wwm-ext-large``
    and VITS checkpoints under ``./models``; both directories are created if
    missing.

    Returns:
        A dict with two entries:
        ``"roberta"`` maps each downloaded file name to its local path, and
        ``"vits"`` maps each key of ``MODEL_OPTIONS`` to the local path of
        its checkpoint.
    """
    bert_dir = "./bert/chinese-roberta-wwm-ext-large"
    vits_dir = "./models"
    os.makedirs(bert_dir, exist_ok=True)
    os.makedirs(vits_dir, exist_ok=True)

    # ``resume_download`` is intentionally not passed: recent huggingface_hub
    # releases always resume and flag the argument as deprecated.

    # RoBERTa model.
    roberta_repo = MODEL_CONFIG["roberta"]["repo_id"]
    # The "files" key is optional: indexing it unconditionally raised
    # KeyError at startup when the config only named the repo.
    roberta_files = MODEL_CONFIG["roberta"].get("files")
    roberta_paths = {}
    if roberta_files:
        for filename in roberta_files:
            roberta_paths[filename] = hf_hub_download(
                repo_id=roberta_repo,
                filename=filename,
                local_dir=bert_dir,
            )
    else:
        # No explicit file list: mirror the whole repository instead.
        snapshot_dir = snapshot_download(repo_id=roberta_repo, local_dir=bert_dir)
        for filename in sorted(os.listdir(snapshot_dir)):
            candidate = os.path.join(snapshot_dir, filename)
            if os.path.isfile(candidate):  # skip cache/metadata directories
                roberta_paths[filename] = candidate

    # VITS checkpoints: one download per selectable model; the repo file name
    # is the base name of the configured checkpoint path.
    vits_repo = MODEL_CONFIG["vits"]["repo_id"]
    vits_paths = {}
    for model_name, model_path in MODEL_OPTIONS.items():
        vits_paths[model_name] = hf_hub_download(
            repo_id=vits_repo,
            filename=os.path.basename(model_path),
            local_dir=vits_dir,
        )

    return {
        "roberta": roberta_paths,
        "vits": vits_paths,
    }
|
64 |
+
|
65 |
+
# Fetch the model files once, when the program starts.
model_paths = download_models()

# Read the hyper-parameters used to build the synthesizer.
hps = utils.get_hparams_from_file("configs/1.json")

# Positional sizes passed to the synthesizer, derived from the config.
spec_channels = hps.data.filter_length // 2 + 1
segment_frames = hps.train.segment_size // hps.data.hop_length

net_g = SynthesizerTrn(
    len(symbols),
    spec_channels,
    segment_frames,
    n_speakers=hps.data.n_speakers,
    **hps.model,
)
net_g = net_g.to(device)
_ = net_g.eval()  # inference only

# Load the downloaded VITS weights into the network (no optimizer is passed).
_ = utils.load_checkpoint(model_paths["vits"]["VITS_Model"], net_g, None)
|
82 |
+
|
83 |
+
def get_text(text, hps, language_str="ZH"):
    """Turn raw text into the tensors consumed by the synthesizer.

    Args:
        text: Input text, passed to ``clean_text``.
        hps: Hyper-parameters; ``hps.data.add_blank`` and
            ``hps.data.use_bert`` are read here.
        language_str: Language tag forwarded to the text front-end.

    Returns:
        A tuple ``(bert, phone, tone, language)``. ``bert`` is either the
        output of ``get_bert`` or a ``(1024, len(phone))`` zero tensor;
        the other three are ``torch.LongTensor`` sequences.
    """
    norm_text, phone, tone, word2ph = clean_text(text, language_str)
    phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)

    if hps.data.add_blank:
        phone, tone, language = (
            commons.intersperse(seq, 0) for seq in (phone, tone, language)
        )
        # Keep word2ph in step with the interspersed sequences: every count
        # is doubled in place and the first entry gets one extra slot.
        word2ph[:] = [count * 2 for count in word2ph]
        word2ph[0] += 1

    if not hps.data.use_bert:
        bert = torch.zeros(1024, len(phone))
    else:
        bert = get_bert(norm_text, word2ph, language_str, device)
        assert bert.shape[-1] == len(phone)
        # Only Chinese input keeps its BERT features; any other language is
        # replaced by zeros of the same layout.
        if language_str != "ZH":
            bert = torch.zeros(1024, len(phone))

    return (
        bert,
        torch.LongTensor(phone),
        torch.LongTensor(tone),
        torch.LongTensor(language),
    )
|
110 |
+
|
111 |
+
def generate_audio(text, noise_scale=1.0, noise_scale_w=0.8, length_scale=1.0):
    """Synthesize speech for *text* and return the path of the written WAV.

    Args:
        text: Text to synthesize; must contain non-whitespace characters.
        noise_scale: Forwarded to ``net_g.infer`` as ``noise_scale``.
        noise_scale_w: Forwarded to ``net_g.infer`` as ``noise_scale_w``.
        length_scale: Forwarded to ``net_g.infer`` as ``length_scale``.

    Returns:
        The path of the 16-bit PCM WAV file (``"output.wav"``).

    Raises:
        gr.Error: If *text* is empty or whitespace only.
    """
    # Reject empty input up front instead of failing deep inside the text
    # front-end with an opaque error.
    if not text or not text.strip():
        raise gr.Error("请输入中文文本")

    bert, phones, tones, language_id = get_text(text, hps)

    with torch.no_grad():
        x_tst = phones.to(device).unsqueeze(0)
        tones = tones.to(device).unsqueeze(0)
        language_id = language_id.to(device).unsqueeze(0)
        bert = bert.to(device).unsqueeze(0)
        x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
        sid = torch.LongTensor([hps.data.spk2id["single"]]).to(device)

        output = net_g.infer(
            x_tst,
            x_tst_lengths,
            sid,
            tones,
            language_id,
            bert,
            noise_scale=noise_scale,
            noise_scale_w=noise_scale_w,
            length_scale=length_scale,
        )
        audio = output[0][0, 0].data.cpu().float().numpy()

    # Use the sample rate from the loaded config when it provides one, so the
    # file plays at the right speed for non-22.05 kHz models; 22050 remains
    # the fallback.
    sample_rate = getattr(hps.data, "sampling_rate", 22050)

    # Clip before scaling: samples outside [-1, 1] would otherwise overflow
    # in the int16 cast and come out as loud clicks.
    pcm = (np.clip(audio, -1.0, 1.0) * 32767.0).astype(np.int16)

    output_path = "output.wav"
    write(output_path, sample_rate, pcm)
    return output_path
|
142 |
+
|
143 |
+
# Build the Gradio UI: one text box plus three sliders, in the same order as
# the parameters of generate_audio.
text_input = gr.Textbox(label="输入文本", placeholder="请输入中文文本...")
noise_scale_slider = gr.Slider(
    minimum=0.1, maximum=2.0, step=0.1, value=1.0, label="Noise Scale"
)
noise_scale_w_slider = gr.Slider(
    minimum=0.1, maximum=2.0, step=0.1, value=0.8, label="Noise Scale W"
)
length_scale_slider = gr.Slider(
    minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Length Scale"
)

interface = gr.Interface(
    fn=generate_audio,
    inputs=[
        text_input,
        noise_scale_slider,
        noise_scale_w_slider,
        length_scale_slider,
    ],
    outputs=gr.Audio(label="生成的音频"),
    title="中文文本转语音",
    description="输入中文文本并调整参数以生成语音。支持调整噪声和语速参数。",
    theme="default",
)

# Start the web UI.
interface.launch()
|