chenjgtea
commited on
Commit
·
8dce793
1
Parent(s):
7766575
新增gpu模式下chattts代码
Browse files- Chat2TTS/core.py +2 -0
- web/{app_sc.py → app_cpu.py} +0 -0
- web/app_gpu.py +5 -5
Chat2TTS/core.py
CHANGED
@@ -25,6 +25,7 @@ class Chat:
|
|
25 |
def __init__(self, ):
|
26 |
self.pretrain_models = {}
|
27 |
self.logger = logging.getLogger(__name__)
|
|
|
28 |
|
29 |
def check_model(self, level = logging.INFO, use_decoder = False):
|
30 |
not_finish = False
|
@@ -201,6 +202,7 @@ class Chat:
|
|
201 |
return s
|
202 |
|
203 |
def _sample_random_speaker(self) -> torch.Tensor:
|
|
|
204 |
with torch.no_grad():
|
205 |
dim: int = self.gpt.gpt.layers[0].mlp.gate_proj.in_features
|
206 |
out: torch.Tensor = self.pretrain_models["spk_stat"]
|
|
|
25 |
def __init__(self, ):
|
26 |
self.pretrain_models = {}
|
27 |
self.logger = logging.getLogger(__name__)
|
28 |
+
self.gpt=None
|
29 |
|
30 |
def check_model(self, level = logging.INFO, use_decoder = False):
|
31 |
not_finish = False
|
|
|
202 |
return s
|
203 |
|
204 |
def _sample_random_speaker(self) -> torch.Tensor:
|
205 |
+
|
206 |
with torch.no_grad():
|
207 |
dim: int = self.gpt.gpt.layers[0].mlp.gate_proj.in_features
|
208 |
out: torch.Tensor = self.pretrain_models["spk_stat"]
|
web/{app_sc.py → app_cpu.py}
RENAMED
File without changes
|
web/app_gpu.py
CHANGED
@@ -26,21 +26,21 @@ chat = Chat2TTS.Chat()
|
|
26 |
|
27 |
def init_chat(args):
|
28 |
global chat
|
29 |
-
source = "
|
30 |
# 获取启动模式
|
31 |
MODEL = os.getenv('MODEL')
|
32 |
-
logger.info("loading Chat2TTS model..., start MODEL:" + str(MODEL))
|
33 |
# huggingface 部署模式下,模型则直接使用hf的模型数据
|
34 |
if MODEL == "HF":
|
35 |
source = "huggingface"
|
36 |
|
37 |
-
|
38 |
|
|
|
39 |
logger.info("loading ChatTTS device :" + str(device))
|
40 |
|
41 |
if chat.load_models(source=source, local_path="D:\\chenjgspace\\ai-model\\chattts"):
|
42 |
print("Models loaded successfully.")
|
43 |
-
|
44 |
# else:
|
45 |
# logger.error("=========Models load failed.")
|
46 |
# sys.exit(1)
|
@@ -212,7 +212,7 @@ def get_chat_infer_audio(chat_txt,
|
|
212 |
# torch.manual_seed(audio_seed_input)
|
213 |
# rand_spk = torch.randn(768)
|
214 |
params_infer_code = {
|
215 |
-
'spk_emb':
|
216 |
'temperature': temperature_slider,
|
217 |
'top_P': top_p_slider,
|
218 |
'top_K': top_k_slider,
|
|
|
26 |
|
27 |
def init_chat(args):
|
28 |
global chat
|
29 |
+
source = "local"
|
30 |
# 获取启动模式
|
31 |
MODEL = os.getenv('MODEL')
|
|
|
32 |
# huggingface 部署模式下,模型则直接使用hf的模型数据
|
33 |
if MODEL == "HF":
|
34 |
source = "huggingface"
|
35 |
|
36 |
+
logger.info("loading Chat2TTS model..., start source:" + source)
|
37 |
|
38 |
+
device=select_device()
|
39 |
logger.info("loading ChatTTS device :" + str(device))
|
40 |
|
41 |
if chat.load_models(source=source, local_path="D:\\chenjgspace\\ai-model\\chattts"):
|
42 |
print("Models loaded successfully.")
|
43 |
+
logger.info("Models loaded end.")
|
44 |
# else:
|
45 |
# logger.error("=========Models load failed.")
|
46 |
# sys.exit(1)
|
|
|
212 |
# torch.manual_seed(audio_seed_input)
|
213 |
# rand_spk = torch.randn(768)
|
214 |
params_infer_code = {
|
215 |
+
'spk_emb': None,
|
216 |
'temperature': temperature_slider,
|
217 |
'top_P': top_p_slider,
|
218 |
'top_K': top_k_slider,
|