chenjgtea committed on
Commit
3fd53ff
·
1 Parent(s): 8825975

新增gpu模式下chattts代码

Browse files
Files changed (2) hide show
  1. README.md +1 -1
  2. web/{app.py → app_gpu.py} +51 -12
README.md CHANGED
@@ -6,7 +6,7 @@ colorTo: purple
6
  sdk: gradio
7
  sdk_version: 4.41.0
8
  #app_port: 8080
9
- app_file: web/app.py
10
  pinned: false
11
  ---
12
 
 
6
  sdk: gradio
7
  sdk_version: 4.41.0
8
  #app_port: 8080
9
+ app_file: web/app_gpu.py
10
  pinned: false
11
  ---
12
 
web/{app.py → app_gpu.py} RENAMED
@@ -203,13 +203,20 @@ def get_chat_infer_audio(chat_txt,
203
  spk_emb_text):
204
  logger.info("========开始生成音频文件=====")
205
  #音频参数设置
206
- params_infer_code = Chat2TTS.Chat.InferCodeParams(
207
- spk_emb=spk_emb_text, # add sampled speaker
208
- temperature=temperature_slider, # using custom temperature
209
- top_P=top_p_slider, # top P decode
210
- top_K=top_k_slider, # top K decode
211
- )
212
-
 
 
 
 
 
 
 
213
  torch.manual_seed(audio_seed_input)
214
  wav = chat.infer(
215
  text=chat_txt,
@@ -227,10 +234,11 @@ def get_chat_infer_text(text,seed,refine_text_checkBox):
227
  logger.info("========文本内容无需优化=====")
228
  return text
229
 
230
- params_refine_text = Chat2TTS.Chat.RefineTextParams(
231
- prompt='[oral_2][laugh_0][break_6]',
232
- )
233
 
 
234
  torch.manual_seed(seed)
235
  chat_text = chat.infer(
236
  text=text,
@@ -245,9 +253,40 @@ def get_chat_infer_text(text,seed,refine_text_checkBox):
245
  def on_audio_seed_change(audio_seed_input):
246
  global chat
247
  torch.manual_seed(audio_seed_input)
248
- rand_spk = chat.sample_random_speaker()
249
  return rand_spk
250
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
 
252
  if __name__ == "__main__":
253
  parser = argparse.ArgumentParser(description="ChatTTS demo Launch")
 
203
  spk_emb_text):
204
  logger.info("========开始生成音频文件=====")
205
  #音频参数设置
206
+ # params_infer_code = Chat2TTS.Chat.InferCodeParams(
207
+ # spk_emb=spk_emb_text, # add sampled speaker
208
+ # temperature=temperature_slider, # using custom temperature
209
+ # top_P=top_p_slider, # top P decode
210
+ # top_K=top_k_slider, # top K decode
211
+ # )
212
+ torch.manual_seed(audio_seed_input)
213
+ rand_spk = torch.randn(768)
214
+ params_infer_code = {
215
+ 'spk_emb': rand_spk,
216
+ 'temperature': temperature_slider,
217
+ 'top_P': top_p_slider,
218
+ 'top_K': top_k_slider,
219
+ }
220
  torch.manual_seed(audio_seed_input)
221
  wav = chat.infer(
222
  text=chat_txt,
 
234
  logger.info("========文本内容无需优化=====")
235
  return text
236
 
237
+ # params_refine_text = Chat2TTS.Chat.RefineTextParams(
238
+ # prompt='[oral_2][laugh_0][break_6]',
239
+ # )
240
 
241
+ params_refine_text = {'prompt': '[oral_2][laugh_0][break_6]'}
242
  torch.manual_seed(seed)
243
  chat_text = chat.infer(
244
  text=text,
 
253
def on_audio_seed_change(audio_seed_input):
    """Sample a deterministic random speaker embedding for the given seed.

    Args:
        audio_seed_input: integer seed controlling the RNG state, so the
            same seed always yields the same speaker embedding.

    Returns:
        torch.Tensor: a 768-dimensional speaker-embedding tensor.
    """
    global chat
    torch.manual_seed(audio_seed_input)
    # BUG FIX: the original called torch.randn(audio_seed_input), which made
    # the tensor SIZE equal to the seed value (e.g. seed 42 -> 42-dim vector).
    # The embedding dimension is fixed at 768, matching the sampling done in
    # get_chat_infer_audio; only the RNG state should depend on the seed.
    rand_spk = torch.randn(768)
    return rand_spk
    #return encode_spk_emb(rand_spk)
259
+
260
def encode_spk_emb(spk_emb: torch.Tensor) -> str:
    """Serialize a speaker-embedding tensor into a compact text string.

    The tensor is moved to CPU as float16, its raw bytes are LZMA-compressed
    (raw container, LZMA2 filter at maximum preset), and the compressed
    payload is rendered as a base16384 string.

    Args:
        spk_emb: speaker-embedding tensor (any device/dtype; converted here).

    Returns:
        str: base16384 text encoding of the compressed embedding bytes.
    """
    import lzma

    import pybase16384 as b14

    # Detach from autograd while extracting the raw values.
    with torch.no_grad():
        raw_bytes = spk_emb.to(dtype=torch.float16, device="cpu").numpy().tobytes()

    # Raw LZMA2 stream at the strongest preset: smallest possible payload,
    # at the cost of compression time (fine for a one-off embedding export).
    compressed = lzma.compress(
        raw_bytes,
        format=lzma.FORMAT_RAW,
        filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}],
    )
    return b14.encode_to_string(compressed)
276
+
277
+
278
+ # def _sample_random_speaker(self) -> torch.Tensor:
279
+ # with torch.no_grad():
280
+ # dim: int = self.gpt.gpt.layers[0].mlp.gate_proj.in_features
281
+ # out: torch.Tensor = self.pretrain_models["spk_stat"]
282
+ # std, mean = out.chunk(2)
283
+ # spk = (
284
+ # torch.randn(dim, device=std.device, dtype=torch.float16)
285
+ # .mul_(std)
286
+ # .add_(mean)
287
+ # )
288
+ # del out, std, mean
289
+ # return spk
290
 
291
  if __name__ == "__main__":
292
  parser = argparse.ArgumentParser(description="ChatTTS demo Launch")