chenjgtea committed on
Commit
8dce793
·
1 Parent(s): 7766575

新增gpu模式下chattts代码

Browse files
Chat2TTS/core.py CHANGED
@@ -25,6 +25,7 @@ class Chat:
25
  def __init__(self, ):
26
  self.pretrain_models = {}
27
  self.logger = logging.getLogger(__name__)
 
28
 
29
  def check_model(self, level = logging.INFO, use_decoder = False):
30
  not_finish = False
@@ -201,6 +202,7 @@ class Chat:
201
  return s
202
 
203
  def _sample_random_speaker(self) -> torch.Tensor:
 
204
  with torch.no_grad():
205
  dim: int = self.gpt.gpt.layers[0].mlp.gate_proj.in_features
206
  out: torch.Tensor = self.pretrain_models["spk_stat"]
 
25
  def __init__(self, ):
26
  self.pretrain_models = {}
27
  self.logger = logging.getLogger(__name__)
28
+ self.gpt=None
29
 
30
  def check_model(self, level = logging.INFO, use_decoder = False):
31
  not_finish = False
 
202
  return s
203
 
204
  def _sample_random_speaker(self) -> torch.Tensor:
205
+
206
  with torch.no_grad():
207
  dim: int = self.gpt.gpt.layers[0].mlp.gate_proj.in_features
208
  out: torch.Tensor = self.pretrain_models["spk_stat"]
web/{app_sc.py → app_cpu.py} RENAMED
File without changes
web/app_gpu.py CHANGED
@@ -26,21 +26,21 @@ chat = Chat2TTS.Chat()
26
 
27
  def init_chat(args):
28
  global chat
29
- source = "custom"
30
  # 获取启动模式
31
  MODEL = os.getenv('MODEL')
32
- logger.info("loading Chat2TTS model..., start MODEL:" + str(MODEL))
33
  # huggingface 部署模式下,模型则直接使用hf的模型数据
34
  if MODEL == "HF":
35
  source = "huggingface"
36
 
37
- device=select_device()
38
 
 
39
  logger.info("loading ChatTTS device :" + str(device))
40
 
41
  if chat.load_models(source=source, local_path="D:\\chenjgspace\\ai-model\\chattts"):
42
  print("Models loaded successfully.")
43
- logger.info("Models loaded successfully.")
44
  # else:
45
  # logger.error("=========Models load failed.")
46
  # sys.exit(1)
@@ -212,7 +212,7 @@ def get_chat_infer_audio(chat_txt,
212
  # torch.manual_seed(audio_seed_input)
213
  # rand_spk = torch.randn(768)
214
  params_infer_code = {
215
- 'spk_emb': spk_emb_text,
216
  'temperature': temperature_slider,
217
  'top_P': top_p_slider,
218
  'top_K': top_k_slider,
 
26
 
27
  def init_chat(args):
28
  global chat
29
+ source = "local"
30
  # 获取启动模式
31
  MODEL = os.getenv('MODEL')
 
32
  # huggingface 部署模式下,模型则直接使用hf的模型数据
33
  if MODEL == "HF":
34
  source = "huggingface"
35
 
36
+ logger.info("loading Chat2TTS model..., start source:" + source)
37
 
38
+ device=select_device()
39
  logger.info("loading ChatTTS device :" + str(device))
40
 
41
  if chat.load_models(source=source, local_path="D:\\chenjgspace\\ai-model\\chattts"):
42
  print("Models loaded successfully.")
43
+ logger.info("Models loaded end.")
44
  # else:
45
  # logger.error("=========Models load failed.")
46
  # sys.exit(1)
 
212
  # torch.manual_seed(audio_seed_input)
213
  # rand_spk = torch.randn(768)
214
  params_infer_code = {
215
+ 'spk_emb': None,
216
  'temperature': temperature_slider,
217
  'top_P': top_p_slider,
218
  'top_K': top_k_slider,