guetLzy commited on
Commit
edb5ebf
·
verified ·
1 Parent(s): 89d7e0b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +159 -0
app.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ from scipy.io.wavfile import write
5
+ import os
6
+ import commons
7
+ import utils
8
+ from models import SynthesizerTrn
9
+ from text.symbols import symbols
10
+ from text import get_bert, cleaned_text_to_sequence
11
+ from text.cleaner import clean_text
12
+ from huggingface_hub import hf_hub_download, snapshot_download
13
+
14
+ # 模型配置
15
+ MODEL_CONFIG = {
16
+ "roberta": {
17
+ "repo_id": "hfl/chinese-roberta-wwm-ext-large"
18
+ },
19
+ "vits": {
20
+ "repo_id": "guetLzy/BERT-ISTFT-VITS-Model",
21
+ "files": ["G_1000.pth", "config.json"] # 根据实际文件名调整
22
+ }
23
+ }
24
+
25
+ # 设备设置
26
+ device = "cuda" if torch.cuda.is_available() else "cpu"
27
+
28
+ # 可用的模型选项
29
+ MODEL_OPTIONS = {
30
+ "VITS_Model": "models/G_1000.pth",
31
+ }
32
+
33
+ def download_models():
34
+ """下载所有需要的模型文件"""
35
+ os.makedirs("./bert/chinese-roberta-wwm-ext-large", exist_ok=True)
36
+ os.makedirs("./models", exist_ok=True)
37
+
38
+ # 下载 RoBERTa 模型
39
+ roberta_paths = {}
40
+ for file in MODEL_CONFIG["roberta"]["files"]:
41
+ path = hf_hub_download(
42
+ repo_id=MODEL_CONFIG["roberta"]["repo_id"],
43
+ filename=file,
44
+ local_dir="./bert/chinese-roberta-wwm-ext-large",
45
+ resume_download=True
46
+ )
47
+ roberta_paths[file] = path
48
+
49
+ # 下载 VITS 模型
50
+ vits_paths = {}
51
+ for model_name, model_path in MODEL_OPTIONS.items():
52
+ path = hf_hub_download(
53
+ repo_id=MODEL_CONFIG["vits"]["repo_id"],
54
+ filename=os.path.basename(model_path),
55
+ local_dir="./models",
56
+ resume_download=True
57
+ )
58
+ vits_paths[model_name] = path
59
+
60
+ return {
61
+ "roberta": roberta_paths,
62
+ "vits": vits_paths
63
+ }
64
+
65
+ # 在程序启动时下载模型
66
+ model_paths = download_models()
67
+
68
+ # 加载配置和模型
69
+ hps = utils.get_hparams_from_file("configs/1.json")
70
+
71
+ net_g = SynthesizerTrn(
72
+ len(symbols),
73
+ hps.data.filter_length // 2 + 1,
74
+ hps.train.segment_size // hps.data.hop_length,
75
+ n_speakers=hps.data.n_speakers,
76
+ **hps.model,
77
+ ).to(device)
78
+ _ = net_g.eval()
79
+
80
+ # 加载下载的 VITS 模型权重
81
+ _ = utils.load_checkpoint(model_paths["vits"]["VITS_Model"], net_g, None)
82
+
83
+ def get_text(text, hps, language_str="ZH"):
84
+ norm_text, phone, tone, word2ph = clean_text(text, language_str)
85
+ phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
86
+
87
+ if hps.data.add_blank:
88
+ phone = commons.intersperse(phone, 0)
89
+ tone = commons.intersperse(tone, 0)
90
+ language = commons.intersperse(language, 0)
91
+ for i in range(len(word2ph)):
92
+ word2ph[i] = word2ph[i] * 2
93
+ word2ph[0] += 1
94
+
95
+ if hps.data.use_bert:
96
+ bert = get_bert(norm_text, word2ph, language_str, device)
97
+ del word2ph
98
+ assert bert.shape[-1] == len(phone)
99
+ if language_str == "ZH":
100
+ bert = bert
101
+ else:
102
+ bert = torch.zeros(1024, len(phone))
103
+ else:
104
+ bert = torch.zeros(1024, len(phone))
105
+
106
+ phone = torch.LongTensor(phone)
107
+ tone = torch.LongTensor(tone)
108
+ language = torch.LongTensor(language)
109
+ return bert, phone, tone, language
110
+
111
+ def generate_audio(text, noise_scale=1.0, noise_scale_w=0.8, length_scale=1.0):
112
+ bert, phones, tones, language_id = get_text(text, hps)
113
+
114
+ with torch.no_grad():
115
+ x_tst = phones.to(device).unsqueeze(0)
116
+ tones = tones.to(device).unsqueeze(0)
117
+ language_id = language_id.to(device).unsqueeze(0)
118
+ bert = bert.to(device).unsqueeze(0)
119
+ x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
120
+ sid = torch.LongTensor([hps.data.spk2id["single"]]).to(device)
121
+
122
+ audio = (
123
+ net_g.infer(
124
+ x_tst,
125
+ x_tst_lengths,
126
+ sid,
127
+ tones,
128
+ language_id,
129
+ bert,
130
+ noise_scale=noise_scale,
131
+ noise_scale_w=noise_scale_w,
132
+ length_scale=length_scale,
133
+ )[0][0, 0]
134
+ .data.cpu()
135
+ .float()
136
+ .numpy()
137
+ )
138
+
139
+ output_path = "output.wav"
140
+ write(output_path, 22050, (audio * 32767.0).astype(np.int16))
141
+ return output_path
142
+
143
+ # 创建 Gradio 界面
144
+ interface = gr.Interface(
145
+ fn=generate_audio,
146
+ inputs=[
147
+ gr.Textbox(label="输入文本", placeholder="请输入中文文本..."),
148
+ gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=1.0, label="Noise Scale"),
149
+ gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=0.8, label="Noise Scale W"),
150
+ gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Length Scale")
151
+ ],
152
+ outputs=gr.Audio(label="生成的音频"),
153
+ title="中文文本转语音",
154
+ description="输入中文文本并调整参数以生成语音。支持调整噪声和语速参数。",
155
+ theme="default"
156
+ )
157
+
158
+ # 启动界面
159
+ interface.launch()