XzJosh committed on
Commit 0f2e028
1 parent: 902b99d

Delete inference_webui.py

Files changed (1)
  1. inference_webui.py +0 -363
inference_webui.py DELETED
@@ -1,363 +0,0 @@
- import os
-
- gpt_path = os.environ.get(
-     "gpt_path", "pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
- )
- sovits_path = os.environ.get("sovits_path", "pretrained_models/s2G488k.pth")
- cnhubert_base_path = os.environ.get(
-     "cnhubert_base_path", "pretrained_models/chinese-hubert-base"
- )
- bert_path = os.environ.get(
-     "bert_path", "pretrained_models/chinese-roberta-wwm-ext-large"
- )
- infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
- infer_ttswebui = int(infer_ttswebui)
- if "_CUDA_VISIBLE_DEVICES" in os.environ:
-     os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
- is_half = eval(os.environ.get("is_half", "True"))
- import gradio as gr
- from transformers import AutoModelForMaskedLM, AutoTokenizer
- import numpy as np
- import librosa,torch
- from feature_extractor import cnhubert
- cnhubert.cnhubert_base_path=cnhubert_base_path
-
- from module.models import SynthesizerTrn
- from AR.models.t2s_lightning_module import Text2SemanticLightningModule
- from text import cleaned_text_to_sequence
- from text.cleaner import clean_text
- from time import time as ttime
- from module.mel_processing import spectrogram_torch
- from my_utils import load_audio
-
- device = "cuda"
- tokenizer = AutoTokenizer.from_pretrained(bert_path)
- bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
- if is_half == True:
-     bert_model = bert_model.half().to(device)
- else:
-     bert_model = bert_model.to(device)
-
-
- # bert_model=bert_model.to(device)
- def get_bert_feature(text, word2ph):
-     with torch.no_grad():
-         inputs = tokenizer(text, return_tensors="pt")
-         for i in inputs:
-             inputs[i] = inputs[i].to(device)  ##### inputs are long tensors, so no precision handling needed; precision follows bert_model
-         res = bert_model(**inputs, output_hidden_states=True)
-         res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
-     assert len(word2ph) == len(text)
-     phone_level_feature = []
-     for i in range(len(word2ph)):
-         repeat_feature = res[i].repeat(word2ph[i], 1)
-         phone_level_feature.append(repeat_feature)
-     phone_level_feature = torch.cat(phone_level_feature, dim=0)
-     # if(is_half==True):phone_level_feature=phone_level_feature.half()
-     return phone_level_feature.T
-
-
- n_semantic = 1024
-
- dict_s2=torch.load(sovits_path,map_location="cpu")
- hps=dict_s2["config"]
-
- class DictToAttrRecursive(dict):
-     def __init__(self, input_dict):
-         super().__init__(input_dict)
-         for key, value in input_dict.items():
-             if isinstance(value, dict):
-                 value = DictToAttrRecursive(value)
-             self[key] = value
-             setattr(self, key, value)
-
-     def __getattr__(self, item):
-         try:
-             return self[item]
-         except KeyError:
-             raise AttributeError(f"Attribute {item} not found")
-
-     def __setattr__(self, key, value):
-         if isinstance(value, dict):
-             value = DictToAttrRecursive(value)
-         super(DictToAttrRecursive, self).__setitem__(key, value)
-         super().__setattr__(key, value)
-
-     def __delattr__(self, item):
-         try:
-             del self[item]
-         except KeyError:
-             raise AttributeError(f"Attribute {item} not found")
-
-
- hps = DictToAttrRecursive(hps)
-
- hps.model.semantic_frame_rate = "25hz"
- dict_s1 = torch.load(gpt_path, map_location="cpu")
- config = dict_s1["config"]
- ssl_model = cnhubert.get_model()
- if is_half == True:
-     ssl_model = ssl_model.half().to(device)
- else:
-     ssl_model = ssl_model.to(device)
-
- vq_model = SynthesizerTrn(
-     hps.data.filter_length // 2 + 1,
-     hps.train.segment_size // hps.data.hop_length,
-     n_speakers=hps.data.n_speakers,
-     **hps.model
- )
- if is_half == True:
-     vq_model = vq_model.half().to(device)
- else:
-     vq_model = vq_model.to(device)
- vq_model.eval()
- print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
- hz = 50
- max_sec = config["data"]["max_sec"]
- # t2s_model = Text2SemanticLightningModule.load_from_checkpoint(checkpoint_path=gpt_path, config=config, map_location="cpu")#########todo
- t2s_model = Text2SemanticLightningModule(config, "ojbk", is_train=False)
- t2s_model.load_state_dict(dict_s1["weight"])
- if is_half == True:
-     t2s_model = t2s_model.half()
- t2s_model = t2s_model.to(device)
- t2s_model.eval()
- total = sum([param.nelement() for param in t2s_model.parameters()])
- print("Number of parameter: %.2fM" % (total / 1e6))
-
-
- def get_spepc(hps, filename):
-     audio = load_audio(filename, int(hps.data.sampling_rate))
-     audio = torch.FloatTensor(audio)
-     audio_norm = audio
-     audio_norm = audio_norm.unsqueeze(0)
-     spec = spectrogram_torch(
-         audio_norm,
-         hps.data.filter_length,
-         hps.data.sampling_rate,
-         hps.data.hop_length,
-         hps.data.win_length,
-         center=False,
-     )
-     return spec
-
-
- dict_language = {"中文": "zh", "英文": "en", "日文": "ja"}
-
-
- def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language):
-     t0 = ttime()
-     prompt_text = prompt_text.strip("\n")
-     prompt_language, text = prompt_language, text.strip("\n")
-     with torch.no_grad():
-         wav16k, sr = librosa.load(ref_wav_path, sr=16000)  # Paimon
-         wav16k = torch.from_numpy(wav16k)
-         if is_half == True:
-             wav16k = wav16k.half().to(device)
-         else:
-             wav16k = wav16k.to(device)
-         ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
-             "last_hidden_state"
-         ].transpose(
-             1, 2
-         )  # .float()
-         codes = vq_model.extract_latent(ssl_content)
-         prompt_semantic = codes[0, 0]
-     t1 = ttime()
-     prompt_language = dict_language[prompt_language]
-     text_language = dict_language[text_language]
-     phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
-     phones1 = cleaned_text_to_sequence(phones1)
-     texts = text.split("\n")
-     audio_opt = []
-     zero_wav = np.zeros(
-         int(hps.data.sampling_rate * 0.3),
-         dtype=np.float16 if is_half == True else np.float32,
-     )
-     for text in texts:
-         # skip blank lines in the target text, which would otherwise cause errors
-         if (len(text.strip()) == 0):
-             continue
-         phones2, word2ph2, norm_text2 = clean_text(text, text_language)
-         phones2 = cleaned_text_to_sequence(phones2)
-         if prompt_language == "zh":
-             bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
-         else:
-             bert1 = torch.zeros(
-                 (1024, len(phones1)),
-                 dtype=torch.float16 if is_half == True else torch.float32,
-             ).to(device)
-         if text_language == "zh":
-             bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
-         else:
-             bert2 = torch.zeros((1024, len(phones2))).to(bert1)
-         bert = torch.cat([bert1, bert2], 1)
-
-         all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
-         bert = bert.to(device).unsqueeze(0)
-         all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
-         prompt = prompt_semantic.unsqueeze(0).to(device)
-         t2 = ttime()
-         with torch.no_grad():
-             # pred_semantic = t2s_model.model.infer(
-             pred_semantic, idx = t2s_model.model.infer_panel(
-                 all_phoneme_ids,
-                 all_phoneme_len,
-                 prompt,
-                 bert,
-                 # prompt_phone_len=ph_offset,
-                 top_k=config["inference"]["top_k"],
-                 early_stop_num=hz * max_sec,
-             )
-         t3 = ttime()
-         # print(pred_semantic.shape,idx)
-         pred_semantic = pred_semantic[:, -idx:].unsqueeze(
-             0
-         )  # .unsqueeze(0)  # mq: one extra unsqueeze is needed
-         refer = get_spepc(hps, ref_wav_path)  # .to(device)
-         if is_half == True:
-             refer = refer.half().to(device)
-         else:
-             refer = refer.to(device)
-         # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
-         audio = (
-             vq_model.decode(
-                 pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer
-             )
-             .detach()
-             .cpu()
-             .numpy()[0, 0]
-         )  ### try decoding without including the prompt part
-         audio_opt.append(audio)
-         audio_opt.append(zero_wav)
-         t4 = ttime()
-         print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
-     yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(
-         np.int16
-     )
-
-
- splits = {
-     "，",
-     "。",
-     "？",
-     "！",
-     ",",
-     ".",
-     "?",
-     "!",
-     "~",
-     ":",
-     "：",
-     "—",
-     "…",
- }  # the ellipsis is not specially handled
-
-
- def split(todo_text):
-     todo_text = todo_text.replace("……", "。").replace("——", "，")
-     if todo_text[-1] not in splits:
-         todo_text += "。"
-     i_split_head = i_split_tail = 0
-     len_text = len(todo_text)
-     todo_texts = []
-     while 1:
-         if i_split_head >= len_text:
-             break  # the text always ends with punctuation, so just break; the last segment was already appended
-         if todo_text[i_split_head] in splits:
-             i_split_head += 1
-             todo_texts.append(todo_text[i_split_tail:i_split_head])
-             i_split_tail = i_split_head
-         else:
-             i_split_head += 1
-     return todo_texts
-
-
- def cut1(inp):
-     inp = inp.strip("\n")
-     inps = split(inp)
-     split_idx = list(range(0, len(inps), 5))
-     split_idx[-1] = None
-     if len(split_idx) > 1:
-         opts = []
-         for idx in range(len(split_idx) - 1):
-             opts.append("".join(inps[split_idx[idx] : split_idx[idx + 1]]))
-     else:
-         opts = [inp]
-     return "\n".join(opts)
-
-
- def cut2(inp):
-     inp = inp.strip("\n")
-     inps = split(inp)
-     if len(inps) < 2:
-         return [inp]
-     opts = []
-     summ = 0
-     tmp_str = ""
-     for i in range(len(inps)):
-         summ += len(inps[i])
-         tmp_str += inps[i]
-         if summ > 50:
-             summ = 0
-             opts.append(tmp_str)
-             tmp_str = ""
-     if tmp_str != "":
-         opts.append(tmp_str)
-     if len(opts[-1]) < 50:  ## if the last segment is too short, merge it with the previous one
-         opts[-2] = opts[-2] + opts[-1]
-         opts = opts[:-1]
-     return "\n".join(opts)
-
-
- def cut3(inp):
-     inp = inp.strip("\n")
-     return "\n".join(["%s。" % item for item in inp.strip("。").split("。")])
-
-
- with gr.Blocks(title="GPT-SoVITS WebUI") as app:
-     gr.Markdown(
-         value="本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. <br>如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录<b>LICENSE</b>."  # "Open-sourced under the MIT license; the author has no control over the software, and users and distributors of exported audio bear full responsibility. If you do not accept these terms, you may not use or reference any code or files in this package. See LICENSE in the root directory."
-     )
-     # with gr.Tabs():
-     #     with gr.TabItem(i18n("伴奏人声分离&去混响&去回声")):
-     with gr.Group():
-         gr.Markdown(value="*请上传并填写参考信息")  # "*Please upload and fill in the reference information"
-         with gr.Row():
-             inp_ref = gr.Audio(label="请上传参考音频", type="filepath")  # "Please upload the reference audio"
-             prompt_text = gr.Textbox(label="参考音频的文本", value="")  # "Text of the reference audio"
-             prompt_language = gr.Dropdown(
-                 label="参考音频的语种", choices=["中文", "英文", "日文"], value="中文"  # "Language of the reference audio": Chinese / English / Japanese
-             )
-         gr.Markdown(value="*请填写需要合成的目标文本")  # "*Please enter the target text to synthesize"
-         with gr.Row():
-             text = gr.Textbox(label="需要合成的文本", value="")  # "Text to synthesize"
-             text_language = gr.Dropdown(
-                 label="需要合成的语种", choices=["中文", "英文", "日文"], value="中文"  # "Language of the text to synthesize"
-             )
-             inference_button = gr.Button("合成语音", variant="primary")  # "Synthesize speech"
-             output = gr.Audio(label="输出的语音")  # "Output audio"
-         inference_button.click(
-             get_tts_wav,
-             [inp_ref, prompt_text, prompt_language, text, text_language],
-             [output],
-         )
-
-         gr.Markdown(value="文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。")  # "Text splitting tool: very long text may not synthesize well, so split it first. Synthesis is done per line and the results are concatenated."
-         with gr.Row():
-             text_inp = gr.Textbox(label="需要合成的切分前文本", value="")  # "Text to split before synthesis"
-             button1 = gr.Button("凑五句一切", variant="primary")  # "Split every 5 sentences"
-             button2 = gr.Button("凑50字一切", variant="primary")  # "Split every 50 characters"
-             button3 = gr.Button("按中文句号。切", variant="primary")  # "Split on the Chinese period 。"
-             text_opt = gr.Textbox(label="切分后文本", value="")  # "Text after splitting"
-         button1.click(cut1, [text_inp], [text_opt])
-         button2.click(cut2, [text_inp], [text_opt])
-         button3.click(cut3, [text_inp], [text_opt])
-     gr.Markdown(value="后续将支持混合语种编码文本输入。")  # "Mixed-language text input will be supported in the future."
-
- app.queue(concurrency_count=511, max_size=1022).launch(
-     server_name="0.0.0.0",
-     inbrowser=True,
-     server_port=infer_ttswebui,
-     quiet=True,
- )