Mahiruoshi commited on
Commit
3ddbd74
1 Parent(s): f8a0cc5

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -251
app.py DELETED
@@ -1,251 +0,0 @@
1
- import logging
2
- logging.getLogger('numba').setLevel(logging.WARNING)
3
- logging.getLogger('matplotlib').setLevel(logging.WARNING)
4
- logging.getLogger('urllib3').setLevel(logging.WARNING)
5
- from text import text_to_sequence
6
- import numpy as np
7
- from scipy.io import wavfile
8
- import torch
9
- import json
10
- import commons
11
- import utils
12
- import sys
13
- import pathlib
14
- import onnxruntime as ort
15
- import gradio as gr
16
- import argparse
17
- import time
18
- import os
19
- import io
20
- from scipy.io.wavfile import write
21
- from flask import Flask, request
22
- from threading import Thread
23
- import openai
24
- import requests
25
- class VitsGradio:
26
- def __init__(self):
27
- self.lan = ["中文","日文","自动"]
28
- self.chatapi = ["gpt-3.5-turbo","gpt3"]
29
- self.modelPaths = []
30
- for root,dirs,files in os.walk("checkpoints"):
31
- for dir in dirs:
32
- self.modelPaths.append(dir)
33
- with gr.Blocks() as self.Vits:
34
- with gr.Tab("调试用"):
35
- with gr.Row():
36
- with gr.Column():
37
- with gr.Row():
38
- with gr.Column():
39
- self.text = gr.TextArea(label="Text", value="你好")
40
- with gr.Accordion(label="测试api", open=False):
41
- self.local_chat1 = gr.Checkbox(value=False, label="使用网址+文本进行模拟")
42
- self.url_input = gr.TextArea(label="键入测试", value="http://127.0.0.1:8080/chat?Text=")
43
- butto = gr.Button("测试从网页端获取文本")
44
- btnVC = gr.Button("测试tts+对话程序")
45
- with gr.Column():
46
- output2 = gr.TextArea(label="回复")
47
- output1 = gr.Audio(label="采样率22050")
48
- output3 = gr.outputs.File(label="44100hz: output.wav")
49
- butto.click(self.Simul, inputs=[self.text, self.url_input], outputs=[output2,output3])
50
- btnVC.click(self.tts_fn, inputs=[self.text], outputs=[output1,output2])
51
- with gr.Tab("控制面板"):
52
- with gr.Row():
53
- with gr.Column():
54
- with gr.Row():
55
- with gr.Column():
56
- self.api_input1 = gr.TextArea(label="输入api-key或ChATGLM模型的路径", value="https://platform.openai.com/account/api-keys")
57
- with gr.Accordion(label="chatbot选择", open=False):
58
- self.api_input2 = gr.Checkbox(value=True, label="采用gpt3.5")
59
- self.local_chat1 = gr.Checkbox(value=False, label="启动本地chatbot")
60
- self.local_chat2 = gr.Checkbox(value=True, label="是否量化")
61
- res = gr.TextArea()
62
- Botselection = gr.Button("聊天机器人选择")
63
- Botselection.click(self.check_bot, inputs=[self.api_input1,self.api_input2,self.local_chat1,self.local_chat2], outputs = [res])
64
- self.input1 = gr.Dropdown(label = "vits模型加载", choices = self.modelPaths, value = self.modelPaths[0], type = "value")
65
- self.input2 = gr.Dropdown(label="Language", choices=self.lan, value="自动", interactive=True)
66
- with gr.Column():
67
- btnVC = gr.Button("Submit")
68
- self.input3 = gr.Dropdown(label="Speaker", choices=list(range(101)), value=0, interactive=True)
69
- self.input4 = gr.Slider(minimum=0, maximum=1.0, label="更改噪声比例(noise scale),以控制情感", value=0.267)
70
- self.input5 = gr.Slider(minimum=0, maximum=1.0, label="更改噪声偏差(noise scale w),以控制音素长短", value=0.7)
71
- self.input6 = gr.Slider(minimum=0.1, maximum=10, label="duration", value=1)
72
- statusa = gr.TextArea()
73
- btnVC.click(self.create_tts_fn, inputs=[self.input1, self.input2, self.input3, self.input4, self.input5, self.input6], outputs = [statusa])
74
-
75
- def Simul(self,text,url_input):
76
- web = url_input + text
77
- res = requests.get(web)
78
- music = res.content
79
- with open('output.wav', 'wb') as code:
80
- code.write(music)
81
- file_path = "output.wav"
82
- return web,file_path
83
-
84
-
85
- def chatgpt(self,text):
86
- self.messages.append({"role": "user", "content": text},)
87
- chat = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages= self.messages)
88
- reply = chat.choices[0].message.content
89
- return reply
90
-
91
- def ChATGLM(self,text):
92
- if text == 'clear':
93
- self.history = []
94
- response, new_history = self.model.chat(self.tokenizer, text, self.history)
95
- response = response.replace(" ",'').replace("\n",'.')
96
- self.history = new_history
97
- return response
98
-
99
- def gpt3_chat(self,text):
100
- call_name = "Waifu"
101
- openai.api_key = args.key
102
- identity = ""
103
- start_sequence = '\n'+str(call_name)+':'
104
- restart_sequence = "\nYou: "
105
- if 1 == 1:
106
- prompt0 = text #当期prompt
107
- if text == 'quit':
108
- return prompt0
109
- prompt = identity + prompt0 + start_sequence
110
- response = openai.Completion.create(
111
- model="text-davinci-003",
112
- prompt=prompt,
113
- temperature=0.5,
114
- max_tokens=1000,
115
- top_p=1.0,
116
- frequency_penalty=0.5,
117
- presence_penalty=0.0,
118
- stop=["\nYou:"]
119
- )
120
- return response['choices'][0]['text'].strip()
121
-
122
- def check_bot(self,api_input1,api_input2,local_chat1,local_chat2):
123
- if local_chat1:
124
- from transformers import AutoTokenizer, AutoModel
125
- self.tokenizer = AutoTokenizer.from_pretrained(api_input1, trust_remote_code=True)
126
- if local_chat2:
127
- self.model = AutoModel.from_pretrained(api_input1, trust_remote_code=True).half().quantize(4).cuda()
128
- else:
129
- self.model = AutoModel.from_pretrained(api_input1, trust_remote_code=True)
130
- self.history = []
131
- else:
132
- self.messages = []
133
- openai.api_key = api_input1
134
- return "Finished"
135
-
136
- def is_japanese(self,string):
137
- for ch in string:
138
- if ord(ch) > 0x3040 and ord(ch) < 0x30FF:
139
- return True
140
- return False
141
-
142
- def is_english(self,string):
143
- import re
144
- pattern = re.compile('^[A-Za-z0-9.,:;!?()_*"\' ]+$')
145
- if pattern.fullmatch(string):
146
- return True
147
- else:
148
- return False
149
-
150
- def get_symbols_from_json(self,path):
151
- assert os.path.isfile(path)
152
- with open(path, 'r') as f:
153
- data = json.load(f)
154
- return data['symbols']
155
-
156
- def sle(self,language,text):
157
- text = text.replace('\n','。').replace(' ',',')
158
- if language == "中文":
159
- tts_input1 = "[ZH]" + text + "[ZH]"
160
- return tts_input1
161
- elif language == "自动":
162
- tts_input1 = f"[JA]{text}[JA]" if self.is_japanese(text) else f"[ZH]{text}[ZH]"
163
- return tts_input1
164
- elif language == "日文":
165
- tts_input1 = "[JA]" + text + "[JA]"
166
- return tts_input1
167
-
168
- def get_text(self,text,hps_ms):
169
- text_norm = text_to_sequence(text,hps_ms.data.text_cleaners)
170
- if hps_ms.data.add_blank:
171
- text_norm = commons.intersperse(text_norm, 0)
172
- text_norm = torch.LongTensor(text_norm)
173
- return text_norm
174
-
175
- def create_tts_fn(self,path, input2, input3, n_scale= 0.667,n_scale_w = 0.8, l_scale = 1 ):
176
- self.symbols = self.get_symbols_from_json(f"checkpoints/{path}/config.json")
177
- self.hps = utils.get_hparams_from_file(f"checkpoints/{path}/config.json")
178
- phone_dict = {
179
- symbol: i for i, symbol in enumerate(self.symbols)
180
- }
181
- self.ort_sess = ort.InferenceSession(f"checkpoints/{path}/model.onnx")
182
- self.language = input2
183
- self.speaker_id = input3
184
- self.n_scale = n_scale
185
- self.n_scale_w = n_scale_w
186
- self.l_scale = l_scale
187
- print(self.language,self.speaker_id,self.n_scale)
188
- return 'success'
189
-
190
- def tts_fn(self,text):
191
- if self.local_chat1:
192
- text = self.chatgpt(text)
193
- elif self.api_input2:
194
- text = self.ChATGLM(text)
195
- else:
196
- text = self.gpt3_chat(text)
197
- print(text)
198
- text =self.sle(self.language,text)
199
- seq = text_to_sequence(text, cleaner_names=self.hps.data.text_cleaners)
200
- if self.hps.data.add_blank:
201
- seq = commons.intersperse(seq, 0)
202
- with torch.no_grad():
203
- x = np.array([seq], dtype=np.int64)
204
- x_len = np.array([x.shape[1]], dtype=np.int64)
205
- sid = np.array([self.speaker_id], dtype=np.int64)
206
- scales = np.array([self.n_scale, self.n_scale_w, self.l_scale], dtype=np.float32)
207
- scales.resize(1, 3)
208
- ort_inputs = {
209
- 'input': x,
210
- 'input_lengths': x_len,
211
- 'scales': scales,
212
- 'sid': sid
213
- }
214
- t1 = time.time()
215
- audio = np.squeeze(self.ort_sess.run(None, ort_inputs))
216
- audio *= 32767.0 / max(0.01, np.max(np.abs(audio))) * 0.6
217
- audio = np.clip(audio, -32767.0, 32767.0)
218
- t2 = time.time()
219
- spending_time = "推理时间:"+str(t2-t1)+"s"
220
- print(spending_time)
221
- bytes_wav = bytes()
222
- byte_io = io.BytesIO(bytes_wav)
223
- wavfile.write('moe/temp1.wav',self.hps.data.sampling_rate, audio.astype(np.int16))
224
- cmd = 'ffmpeg -y -i ' + 'moe/temp1.wav' + ' -ar 44100 ' + 'moe/temp2.wav'
225
- os.system(cmd)
226
- return (self.hps.data.sampling_rate, audio),text.replace('[JA]','').replace('[ZH]','')
227
-
228
- app = Flask(__name__)
229
- print("开始部署")
230
- grVits = VitsGradio()
231
-
232
- @app.route('/chat')
233
- def text_api():
234
- message = request.args.get('Text','')
235
- audio,text = grVits.tts_fn(message)
236
- text = text.replace('[JA]','').replace('[ZH]','')
237
- with open('moe/temp2.wav','rb') as bit:
238
- wav_bytes = bit.read()
239
- headers = {
240
- 'Content-Type': 'audio/wav',
241
- 'Text': text.encode('utf-8')}
242
- return wav_bytes, 200, headers
243
-
244
- def gradio_interface():
245
- return grVits.Vits.launch()
246
-
247
- if __name__ == '__main__':
248
- api_thread = Thread(target=app.run, args=("0.0.0.0", 8080))
249
- gradio_thread = Thread(target=gradio_interface)
250
- api_thread.start()
251
- gradio_thread.start()