chenjgtea
commited on
Commit
·
7766575
1
Parent(s):
7027c47
新增gpu模式下chattts代码
Browse files- Chat2TTS/core.py +10 -1
Chat2TTS/core.py
CHANGED
@@ -100,8 +100,17 @@ class Chat:
|
|
100 |
assert gpt_ckpt_path, 'gpt_ckpt_path should not be None'
|
101 |
gpt.load_state_dict(torch.load(gpt_ckpt_path, map_location='cpu'))
|
102 |
self.pretrain_models['gpt'] = gpt
|
|
|
103 |
self.logger.log(logging.INFO, 'gpt loaded.')
|
104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
if decoder_config_path:
|
106 |
cfg = OmegaConf.load(decoder_config_path)
|
107 |
decoder = DVAE(**cfg).to(device).eval()
|
|
|
100 |
assert gpt_ckpt_path, 'gpt_ckpt_path should not be None'
|
101 |
gpt.load_state_dict(torch.load(gpt_ckpt_path, map_location='cpu'))
|
102 |
self.pretrain_models['gpt'] = gpt
|
103 |
+
self.gpt = gpt
|
104 |
self.logger.log(logging.INFO, 'gpt loaded.')
|
105 |
+
spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), "spk_stat.pt")
|
106 |
+
assert os.path.exists(
|
107 |
+
spk_stat_path
|
108 |
+
), f"Missing spk_stat.pt: {spk_stat_path}"
|
109 |
+
self.pretrain_models["spk_stat"] = torch.load(
|
110 |
+
spk_stat_path, weights_only=True, mmap=True, map_location='cpu'
|
111 |
+
).to(device)
|
112 |
+
|
113 |
+
|
114 |
if decoder_config_path:
|
115 |
cfg = OmegaConf.load(decoder_config_path)
|
116 |
decoder = DVAE(**cfg).to(device).eval()
|