chenjgtea commited on
Commit
7766575
·
1 Parent(s): 7027c47

新增gpu模式下chattts代码

Browse files
Files changed (1) hide show
  1. Chat2TTS/core.py +10 -1
Chat2TTS/core.py CHANGED
@@ -100,8 +100,17 @@ class Chat:
100
  assert gpt_ckpt_path, 'gpt_ckpt_path should not be None'
101
  gpt.load_state_dict(torch.load(gpt_ckpt_path, map_location='cpu'))
102
  self.pretrain_models['gpt'] = gpt
 
103
  self.logger.log(logging.INFO, 'gpt loaded.')
104
-
 
 
 
 
 
 
 
 
105
  if decoder_config_path:
106
  cfg = OmegaConf.load(decoder_config_path)
107
  decoder = DVAE(**cfg).to(device).eval()
 
100
  assert gpt_ckpt_path, 'gpt_ckpt_path should not be None'
101
  gpt.load_state_dict(torch.load(gpt_ckpt_path, map_location='cpu'))
102
  self.pretrain_models['gpt'] = gpt
103
+ self.gpt = gpt
104
  self.logger.log(logging.INFO, 'gpt loaded.')
105
+ spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), "spk_stat.pt")
106
+ assert os.path.exists(
107
+ spk_stat_path
108
+ ), f"Missing spk_stat.pt: {spk_stat_path}"
109
+ self.pretrain_models["spk_stat"] = torch.load(
110
+ spk_stat_path, weights_only=True, mmap=True, map_location='cpu'
111
+ ).to(device)
112
+
113
+
114
  if decoder_config_path:
115
  cfg = OmegaConf.load(decoder_config_path)
116
  decoder = DVAE(**cfg).to(device).eval()