Mahiruoshi committed
Commit 3ddbd74
1 Parent(s): f8a0cc5
Delete app.py

app.py DELETED
@@ -1,251 +0,0 @@
import logging
logging.getLogger('numba').setLevel(logging.WARNING)
logging.getLogger('matplotlib').setLevel(logging.WARNING)
logging.getLogger('urllib3').setLevel(logging.WARNING)
from text import text_to_sequence
import numpy as np
from scipy.io import wavfile
import torch
import json
import commons
import utils
import sys
import pathlib
import onnxruntime as ort
import gradio as gr
import argparse
import time
import os
import io
from scipy.io.wavfile import write
from flask import Flask, request
from threading import Thread
import openai
import requests
class VitsGradio:
    def __init__(self):
        self.lan = ["中文","日文","自动"]
        self.chatapi = ["gpt-3.5-turbo","gpt3"]
        self.modelPaths = []
        for root, dirs, files in os.walk("checkpoints"):
            for dir in dirs:
                self.modelPaths.append(dir)
        with gr.Blocks() as self.Vits:
            with gr.Tab("调试用"):
                with gr.Row():
                    with gr.Column():
                        with gr.Row():
                            with gr.Column():
                                self.text = gr.TextArea(label="Text", value="你好")
                                with gr.Accordion(label="测试api", open=False):
                                    self.local_chat1 = gr.Checkbox(value=False, label="使用网址+文本进行模拟")
                                    self.url_input = gr.TextArea(label="键入测试", value="http://127.0.0.1:8080/chat?Text=")
                                    butto = gr.Button("测试从网页端获取文本")
                                btnVC = gr.Button("测试tts+对话程序")
                            with gr.Column():
                                output2 = gr.TextArea(label="回复")
                                output1 = gr.Audio(label="采样率22050")
                                output3 = gr.outputs.File(label="44100hz: output.wav")
                        butto.click(self.Simul, inputs=[self.text, self.url_input], outputs=[output2, output3])
                        btnVC.click(self.tts_fn, inputs=[self.text], outputs=[output1, output2])
            with gr.Tab("控制面板"):
                with gr.Row():
                    with gr.Column():
                        with gr.Row():
                            with gr.Column():
                                self.api_input1 = gr.TextArea(label="输入api-key或ChATGLM模型的路径", value="https://platform.openai.com/account/api-keys")
                                with gr.Accordion(label="chatbot选择", open=False):
                                    self.api_input2 = gr.Checkbox(value=True, label="采用gpt3.5")
                                    self.local_chat1 = gr.Checkbox(value=False, label="启动本地chatbot")
                                    self.local_chat2 = gr.Checkbox(value=True, label="是否量化")
                                res = gr.TextArea()
                                Botselection = gr.Button("聊天机器人选择")
                                Botselection.click(self.check_bot, inputs=[self.api_input1, self.api_input2, self.local_chat1, self.local_chat2], outputs=[res])
                                self.input1 = gr.Dropdown(label="vits模型加载", choices=self.modelPaths, value=self.modelPaths[0], type="value")
                                self.input2 = gr.Dropdown(label="Language", choices=self.lan, value="自动", interactive=True)
                            with gr.Column():
                                btnVC = gr.Button("Submit")
                                self.input3 = gr.Dropdown(label="Speaker", choices=list(range(101)), value=0, interactive=True)
                                self.input4 = gr.Slider(minimum=0, maximum=1.0, label="更改噪声比例(noise scale),以控制情感", value=0.267)
                                self.input5 = gr.Slider(minimum=0, maximum=1.0, label="更改噪声偏差(noise scale w),以控制音素长短", value=0.7)
                                self.input6 = gr.Slider(minimum=0.1, maximum=10, label="duration", value=1)
                                statusa = gr.TextArea()
                                btnVC.click(self.create_tts_fn, inputs=[self.input1, self.input2, self.input3, self.input4, self.input5, self.input6], outputs=[statusa])

    def Simul(self, text, url_input):
        web = url_input + text
        res = requests.get(web)
        music = res.content
        with open('output.wav', 'wb') as code:
            code.write(music)
        file_path = "output.wav"
        return web, file_path

    def chatgpt(self, text):
        self.messages.append({"role": "user", "content": text})
        chat = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=self.messages)
        reply = chat.choices[0].message.content
        return reply

    def ChATGLM(self, text):
        if text == 'clear':
            self.history = []
        response, new_history = self.model.chat(self.tokenizer, text, self.history)
        response = response.replace(" ", '').replace("\n", '.')
        self.history = new_history
        return response

    def gpt3_chat(self, text):
        call_name = "Waifu"
        openai.api_key = args.key
        identity = ""
        start_sequence = '\n' + str(call_name) + ':'
        restart_sequence = "\nYou: "
        if 1 == 1:
            prompt0 = text  # current prompt
            if text == 'quit':
                return prompt0
            prompt = identity + prompt0 + start_sequence
            response = openai.Completion.create(
                model="text-davinci-003",
                prompt=prompt,
                temperature=0.5,
                max_tokens=1000,
                top_p=1.0,
                frequency_penalty=0.5,
                presence_penalty=0.0,
                stop=["\nYou:"]
            )
            return response['choices'][0]['text'].strip()

    def check_bot(self, api_input1, api_input2, local_chat1, local_chat2):
        if local_chat1:
            from transformers import AutoTokenizer, AutoModel
            self.tokenizer = AutoTokenizer.from_pretrained(api_input1, trust_remote_code=True)
            if local_chat2:
                self.model = AutoModel.from_pretrained(api_input1, trust_remote_code=True).half().quantize(4).cuda()
            else:
                self.model = AutoModel.from_pretrained(api_input1, trust_remote_code=True)
            self.history = []
        else:
            self.messages = []
            openai.api_key = api_input1
        return "Finished"

    def is_japanese(self, string):
        for ch in string:
            if ord(ch) > 0x3040 and ord(ch) < 0x30FF:
                return True
        return False

    def is_english(self, string):
        import re
        pattern = re.compile('^[A-Za-z0-9.,:;!?()_*"\' ]+$')
        if pattern.fullmatch(string):
            return True
        else:
            return False

    def get_symbols_from_json(self, path):
        assert os.path.isfile(path)
        with open(path, 'r') as f:
            data = json.load(f)
        return data['symbols']

    def sle(self, language, text):
        text = text.replace('\n', '。').replace(' ', ',')
        if language == "中文":
            tts_input1 = "[ZH]" + text + "[ZH]"
            return tts_input1
        elif language == "自动":
            tts_input1 = f"[JA]{text}[JA]" if self.is_japanese(text) else f"[ZH]{text}[ZH]"
            return tts_input1
        elif language == "日文":
            tts_input1 = "[JA]" + text + "[JA]"
            return tts_input1

    def get_text(self, text, hps_ms):
        text_norm = text_to_sequence(text, hps_ms.data.text_cleaners)
        if hps_ms.data.add_blank:
            text_norm = commons.intersperse(text_norm, 0)
        text_norm = torch.LongTensor(text_norm)
        return text_norm

    def create_tts_fn(self, path, input2, input3, n_scale=0.667, n_scale_w=0.8, l_scale=1):
        self.symbols = self.get_symbols_from_json(f"checkpoints/{path}/config.json")
        self.hps = utils.get_hparams_from_file(f"checkpoints/{path}/config.json")
        phone_dict = {
            symbol: i for i, symbol in enumerate(self.symbols)
        }
        self.ort_sess = ort.InferenceSession(f"checkpoints/{path}/model.onnx")
        self.language = input2
        self.speaker_id = input3
        self.n_scale = n_scale
        self.n_scale_w = n_scale_w
        self.l_scale = l_scale
        print(self.language, self.speaker_id, self.n_scale)
        return 'success'

    def tts_fn(self, text):
        if self.local_chat1:
            text = self.chatgpt(text)
        elif self.api_input2:
            text = self.ChATGLM(text)
        else:
            text = self.gpt3_chat(text)
        print(text)
        text = self.sle(self.language, text)
        seq = text_to_sequence(text, cleaner_names=self.hps.data.text_cleaners)
        if self.hps.data.add_blank:
            seq = commons.intersperse(seq, 0)
        with torch.no_grad():
            x = np.array([seq], dtype=np.int64)
            x_len = np.array([x.shape[1]], dtype=np.int64)
            sid = np.array([self.speaker_id], dtype=np.int64)
            scales = np.array([self.n_scale, self.n_scale_w, self.l_scale], dtype=np.float32)
            scales.resize(1, 3)
            ort_inputs = {
                'input': x,
                'input_lengths': x_len,
                'scales': scales,
                'sid': sid
            }
            t1 = time.time()
            audio = np.squeeze(self.ort_sess.run(None, ort_inputs))
            audio *= 32767.0 / max(0.01, np.max(np.abs(audio))) * 0.6
            audio = np.clip(audio, -32767.0, 32767.0)
            t2 = time.time()
            spending_time = "推理时间:" + str(t2 - t1) + "s"
            print(spending_time)
        bytes_wav = bytes()
        byte_io = io.BytesIO(bytes_wav)
        wavfile.write('moe/temp1.wav', self.hps.data.sampling_rate, audio.astype(np.int16))
        cmd = 'ffmpeg -y -i ' + 'moe/temp1.wav' + ' -ar 44100 ' + 'moe/temp2.wav'
        os.system(cmd)
        return (self.hps.data.sampling_rate, audio), text.replace('[JA]', '').replace('[ZH]', '')

app = Flask(__name__)
print("开始部署")
grVits = VitsGradio()

@app.route('/chat')
def text_api():
    message = request.args.get('Text', '')
    audio, text = grVits.tts_fn(message)
    text = text.replace('[JA]', '').replace('[ZH]', '')
    with open('moe/temp2.wav', 'rb') as bit:
        wav_bytes = bit.read()
    headers = {
        'Content-Type': 'audio/wav',
        'Text': text.encode('utf-8')}
    return wav_bytes, 200, headers

def gradio_interface():
    return grVits.Vits.launch()

if __name__ == '__main__':
    api_thread = Thread(target=app.run, args=("0.0.0.0", 8080))
    gradio_thread = Thread(target=gradio_interface)
    api_thread.start()
    gradio_thread.start()