chenjgtea committed on
Commit
638a294
·
1 Parent(s): 1898711

gpu模型下代码更新,已定型

Browse files
Chat2TTS/core.py CHANGED
@@ -12,20 +12,21 @@ from .utils.io_utils import get_latest_modified_file
12
  from .infer.api import refine_text, infer_code
13
  from dataclasses import dataclass
14
  from typing import Literal, Optional, List, Tuple, Dict
15
- import numpy as np
16
- import pybase16384 as b14
17
- import lzma
18
 
19
- from huggingface_hub import snapshot_download
20
 
21
- logging.basicConfig(level = logging.INFO)
22
 
23
 
24
  class Chat:
25
  def __init__(self, ):
26
  self.pretrain_models = {}
27
- self.logger = logging.getLogger(__name__)
28
- self.gpt=None
 
 
 
29
 
30
  def check_model(self, level = logging.INFO, use_decoder = False):
31
  not_finish = False
@@ -159,6 +160,21 @@ class Chat:
159
  ):
160
 
161
  assert self.check_model(use_decoder=use_decoder)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
 
163
  if not skip_refine_text:
164
  text_tokens = refine_text(self.pretrain_models, text, **params_refine_text)['ids']
@@ -179,6 +195,14 @@ class Chat:
179
  wav = [self.pretrain_models['vocos'].decode(i).cpu().numpy() for i in mel_spec]
180
 
181
  return wav
 
 
 
 
 
 
 
 
182
 
183
 
184
  # def sample_random_speaker(self) -> str:
 
12
  from .infer.api import refine_text, infer_code
13
  from dataclasses import dataclass
14
  from typing import Literal, Optional, List, Tuple, Dict
15
+ from tool.logger import get_logger
 
 
16
 
17
+ from ChatTTS.norm import Normalizer
18
 
19
+ from huggingface_hub import snapshot_download
20
 
21
 
22
  class Chat:
23
  def __init__(self, ):
24
  self.pretrain_models = {}
25
+ self.logger = get_logger(__name__,lv=logging.INFO)
26
+ self.normalizer = Normalizer(
27
+ os.path.join(os.path.dirname(__file__), "res", "homophones_map.json"),
28
+ self.logger,
29
+ )
30
 
31
  def check_model(self, level = logging.INFO, use_decoder = False):
32
  not_finish = False
 
160
  ):
161
 
162
  assert self.check_model(use_decoder=use_decoder)
163
+
164
+ if skip_refine_text:
165
+ self.logger.info("========对文本内容不做优化处理,仅做规则处理======")
166
+ if not isinstance(text, list):
167
+ text = [text]
168
+
169
+ text = [
170
+ self.normalizer(
171
+ text=t,
172
+ do_text_normalization=True,
173
+ do_homophone_replacement=True,
174
+ lang=None,
175
+ )
176
+ for t in text
177
+ ]
178
 
179
  if not skip_refine_text:
180
  text_tokens = refine_text(self.pretrain_models, text, **params_refine_text)['ids']
 
195
  wav = [self.pretrain_models['vocos'].decode(i).cpu().numpy() for i in mel_spec]
196
 
197
  return wav
198
+
199
+ def emptpy_audio(self):
200
+ return self.infer(" ",
201
+ skip_refine_text=True,
202
+ refine_text_only=False,
203
+ params_refine_text={},
204
+ params_infer_code={},
205
+ use_decoder=False)
206
 
207
 
208
  # def sample_random_speaker(self) -> str:
Chat2TTS/res/__init__.py ADDED
File without changes
Chat2TTS/res/homophones_map.json ADDED
The diff for this file is too large to render. See raw diff
 
web/app_cpu.py CHANGED
@@ -14,17 +14,13 @@ from tool.ctx import TorchSeedContext
14
  import ChatTTS
15
  import argparse
16
  import torch._dynamo
17
-
18
  torch._dynamo.config.suppress_errors = True
19
 
20
 
21
-
22
  logger = get_logger("app")
23
-
24
  # Initialize and load the model:
25
  chat = ChatTTS.Chat()
26
-
27
-
28
  def init_chat(args):
29
  global chat
30
  source = "custom"
 
14
  import ChatTTS
15
  import argparse
16
  import torch._dynamo
 
17
  torch._dynamo.config.suppress_errors = True
18
 
19
 
20
+ #HF空间中,GPU模式 运行 大模型代码
21
  logger = get_logger("app")
 
22
  # Initialize and load the model:
23
  chat = ChatTTS.Chat()
 
 
24
  def init_chat(args):
25
  global chat
26
  source = "custom"
web/app_gpu.py CHANGED
@@ -18,10 +18,8 @@ import torch._dynamo
18
 
19
  torch._dynamo.config.suppress_errors = True
20
 
21
-
22
-
23
  logger = get_logger("app")
24
-
25
  # Initialize and load the model:
26
  chat = Chat2TTS.Chat()
27
 
@@ -67,6 +65,11 @@ def main(args):
67
  interactive=True,
68
  value=True
69
  )
 
 
 
 
 
70
  temperature_slider = gr.Slider(
71
  minimum=0.00001,
72
  maximum=1.0,
@@ -164,6 +167,7 @@ def main(args):
164
  inputs=[text_input,
165
  text_seed_input,
166
  refine_text_checkBox,
 
167
  temperature_slider,
168
  top_p_slider,
169
  top_k_slider,
@@ -195,6 +199,7 @@ def main(args):
195
  def general_chat_infer_audio(text,
196
  text_seed_input,
197
  refine_text_checkBox,
 
198
  temperature_slider,
199
  top_p_slider,
200
  top_k_slider,
@@ -223,30 +228,35 @@ def general_chat_infer_audio(text,
223
  params_refine_text=params_refine_text,
224
  )
225
 
226
- logger.info("========开始生成音频文件=====")
227
- #torch.manual_seed(audio_seed_input)
228
 
229
  with TorchSeedContext(audio_seed_input):
230
- #rand_spk = torch.randn(768)
231
- rand_spk = chat.sample_random_speaker_tensor()
232
- logger.info("========生成音频spk_emb参数完成=====")
233
- params_infer_code = {
234
- 'spk_emb': rand_spk,
235
- 'temperature': temperature_slider,
236
- 'top_P': top_p_slider,
237
- 'top_K': top_k_slider,
238
- }
239
- wav = chat.infer(
240
- text=chat_txt,
241
- skip_refine_text=True, #跳过文本优化
242
- params_refine_text=params_refine_text,
243
- params_infer_code=params_infer_code,
244
- )
 
 
 
 
 
 
 
 
245
  #yield 24000, float_to_int16(wav[0]).T
246
  audio_data = np.array(wav[0]).flatten()
247
- sample_rate = 24000
248
  text_data = chat_txt[0] if isinstance(chat_txt, list) else chat_txt
249
-
250
  return [text_data,(sample_rate, audio_data)]
251
 
252
 
@@ -283,36 +293,8 @@ def general_chat_infer_audio(text,
283
  # rand_spk = torch.randn(audio_seed_input)
284
  # return encode_spk_emb(rand_spk)
285
 
286
- # def encode_spk_emb(spk_emb: torch.Tensor) -> str:
287
- # import pybase16384 as b14
288
- # import lzma
289
- # with torch.no_grad():
290
- # arr: np.ndarray = spk_emb.to(dtype=torch.float16, device="cpu").numpy()
291
- # s = b14.encode_to_string(
292
- # lzma.compress(
293
- # arr.tobytes(),
294
- # format=lzma.FORMAT_RAW,
295
- # filters=[
296
- # {"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}
297
- # ],
298
- # ),
299
- # )
300
- # del arr
301
- # return s
302
-
303
-
304
- # def _sample_random_speaker(self) -> torch.Tensor:
305
- # with torch.no_grad():
306
- # dim: int = self.gpt.gpt.layers[0].mlp.gate_proj.in_features
307
- # out: torch.Tensor = self.pretrain_models["spk_stat"]
308
- # std, mean = out.chunk(2)
309
- # spk = (
310
- # torch.randn(dim, device=std.device, dtype=torch.float16)
311
- # .mul_(std)
312
- # .add_(mean)
313
- # )
314
- # del out, std, mean
315
- # return spk
316
 
317
  if __name__ == "__main__":
318
  parser = argparse.ArgumentParser(description="ChatTTS demo Launch")
 
18
 
19
  torch._dynamo.config.suppress_errors = True
20
 
21
+ #HF空间中,GPU模式 运行 大模型代码
 
22
  logger = get_logger("app")
 
23
  # Initialize and load the model:
24
  chat = Chat2TTS.Chat()
25
 
 
65
  interactive=True,
66
  value=True
67
  )
68
+ refine_audio_checkBox = gr.Checkbox(
69
+ label="是否生成音频文件,如是才会生成音频文件",
70
+ interactive=True,
71
+ value=True
72
+ )
73
  temperature_slider = gr.Slider(
74
  minimum=0.00001,
75
  maximum=1.0,
 
167
  inputs=[text_input,
168
  text_seed_input,
169
  refine_text_checkBox,
170
+ refine_audio_checkBox,
171
  temperature_slider,
172
  top_p_slider,
173
  top_k_slider,
 
199
  def general_chat_infer_audio(text,
200
  text_seed_input,
201
  refine_text_checkBox,
202
+ refine_audio_checkBox,
203
  temperature_slider,
204
  top_p_slider,
205
  top_k_slider,
 
228
  params_refine_text=params_refine_text,
229
  )
230
 
 
 
231
 
232
  with TorchSeedContext(audio_seed_input):
233
+ if not refine_audio_checkBox:
234
+ logger.info("========无需生成音频文件=====")
235
+ #创建一个空的音频文件
236
+ wav = chat.emptpy_audio()
237
+ else:
238
+ logger.info("========开始生成音频文件=====")
239
+ #torch.manual_seed(audio_seed_input)
240
+ #rand_spk = torch.randn(768)
241
+ rand_spk = chat.sample_random_speaker_tensor()
242
+ logger.info("========生成音频spk_emb参数完成=====")
243
+ params_infer_code = {
244
+ 'spk_emb': rand_spk,
245
+ 'temperature': temperature_slider,
246
+ 'top_P': top_p_slider,
247
+ 'top_K': top_k_slider,
248
+ }
249
+ wav = chat.infer(
250
+ text=chat_txt,
251
+ skip_refine_text=True, #跳过文本优化
252
+ params_refine_text=params_refine_text,
253
+ params_infer_code=params_infer_code,
254
+ )
255
+
256
  #yield 24000, float_to_int16(wav[0]).T
257
  audio_data = np.array(wav[0]).flatten()
 
258
  text_data = chat_txt[0] if isinstance(chat_txt, list) else chat_txt
259
+ sample_rate = 24000
260
  return [text_data,(sample_rate, audio_data)]
261
 
262
 
 
293
  # rand_spk = torch.randn(audio_seed_input)
294
  # return encode_spk_emb(rand_spk)
295
 
296
+
297
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
298
 
299
  if __name__ == "__main__":
300
  parser = argparse.ArgumentParser(description="ChatTTS demo Launch")