chenjgtea
commited on
Commit
·
638a294
1
Parent(s):
1898711
gpu模型下代码更新,已定型
Browse files- Chat2TTS/core.py +31 -7
- Chat2TTS/res/__init__.py +0 -0
- Chat2TTS/res/homophones_map.json +0 -0
- web/app_cpu.py +1 -5
- web/app_gpu.py +34 -52
Chat2TTS/core.py
CHANGED
@@ -12,20 +12,21 @@ from .utils.io_utils import get_latest_modified_file
|
|
12 |
from .infer.api import refine_text, infer_code
|
13 |
from dataclasses import dataclass
|
14 |
from typing import Literal, Optional, List, Tuple, Dict
|
15 |
-
|
16 |
-
import pybase16384 as b14
|
17 |
-
import lzma
|
18 |
|
19 |
-
from
|
20 |
|
21 |
-
|
22 |
|
23 |
|
24 |
class Chat:
|
25 |
def __init__(self, ):
|
26 |
self.pretrain_models = {}
|
27 |
-
self.logger = logging.
|
28 |
-
self.
|
|
|
|
|
|
|
29 |
|
30 |
def check_model(self, level = logging.INFO, use_decoder = False):
|
31 |
not_finish = False
|
@@ -159,6 +160,21 @@ class Chat:
|
|
159 |
):
|
160 |
|
161 |
assert self.check_model(use_decoder=use_decoder)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
162 |
|
163 |
if not skip_refine_text:
|
164 |
text_tokens = refine_text(self.pretrain_models, text, **params_refine_text)['ids']
|
@@ -179,6 +195,14 @@ class Chat:
|
|
179 |
wav = [self.pretrain_models['vocos'].decode(i).cpu().numpy() for i in mel_spec]
|
180 |
|
181 |
return wav
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
182 |
|
183 |
|
184 |
# def sample_random_speaker(self) -> str:
|
|
|
12 |
from .infer.api import refine_text, infer_code
|
13 |
from dataclasses import dataclass
|
14 |
from typing import Literal, Optional, List, Tuple, Dict
|
15 |
+
from tool.logger import get_logger
|
|
|
|
|
16 |
|
17 |
+
from ChatTTS.norm import Normalizer
|
18 |
|
19 |
+
from huggingface_hub import snapshot_download
|
20 |
|
21 |
|
22 |
class Chat:
|
23 |
def __init__(self, ):
|
24 |
self.pretrain_models = {}
|
25 |
+
self.logger = get_logger(__name__,lv=logging.INFO)
|
26 |
+
self.normalizer = Normalizer(
|
27 |
+
os.path.join(os.path.dirname(__file__), "res", "homophones_map.json"),
|
28 |
+
self.logger,
|
29 |
+
)
|
30 |
|
31 |
def check_model(self, level = logging.INFO, use_decoder = False):
|
32 |
not_finish = False
|
|
|
160 |
):
|
161 |
|
162 |
assert self.check_model(use_decoder=use_decoder)
|
163 |
+
|
164 |
+
if skip_refine_text:
|
165 |
+
self.logger.info("========对文本内容不做优化处理,仅做规则处理======")
|
166 |
+
if not isinstance(text, list):
|
167 |
+
text = [text]
|
168 |
+
|
169 |
+
text = [
|
170 |
+
self.normalizer(
|
171 |
+
text=t,
|
172 |
+
do_text_normalization=True,
|
173 |
+
do_homophone_replacement=True,
|
174 |
+
lang=None,
|
175 |
+
)
|
176 |
+
for t in text
|
177 |
+
]
|
178 |
|
179 |
if not skip_refine_text:
|
180 |
text_tokens = refine_text(self.pretrain_models, text, **params_refine_text)['ids']
|
|
|
195 |
wav = [self.pretrain_models['vocos'].decode(i).cpu().numpy() for i in mel_spec]
|
196 |
|
197 |
return wav
|
198 |
+
|
199 |
+
def emptpy_audio(self):
|
200 |
+
return self.infer(" ",
|
201 |
+
skip_refine_text=True,
|
202 |
+
refine_text_only=False,
|
203 |
+
params_refine_text={},
|
204 |
+
params_infer_code={},
|
205 |
+
use_decoder=False)
|
206 |
|
207 |
|
208 |
# def sample_random_speaker(self) -> str:
|
Chat2TTS/res/__init__.py
ADDED
File without changes
|
Chat2TTS/res/homophones_map.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
web/app_cpu.py
CHANGED
@@ -14,17 +14,13 @@ from tool.ctx import TorchSeedContext
|
|
14 |
import ChatTTS
|
15 |
import argparse
|
16 |
import torch._dynamo
|
17 |
-
|
18 |
torch._dynamo.config.suppress_errors = True
|
19 |
|
20 |
|
21 |
-
|
22 |
logger = get_logger("app")
|
23 |
-
|
24 |
# Initialize and load the model:
|
25 |
chat = ChatTTS.Chat()
|
26 |
-
|
27 |
-
|
28 |
def init_chat(args):
|
29 |
global chat
|
30 |
source = "custom"
|
|
|
14 |
import ChatTTS
|
15 |
import argparse
|
16 |
import torch._dynamo
|
|
|
17 |
torch._dynamo.config.suppress_errors = True
|
18 |
|
19 |
|
20 |
+
#HF空间中,GPU模式 运行 大模型代码
|
21 |
logger = get_logger("app")
|
|
|
22 |
# Initialize and load the model:
|
23 |
chat = ChatTTS.Chat()
|
|
|
|
|
24 |
def init_chat(args):
|
25 |
global chat
|
26 |
source = "custom"
|
web/app_gpu.py
CHANGED
@@ -18,10 +18,8 @@ import torch._dynamo
|
|
18 |
|
19 |
torch._dynamo.config.suppress_errors = True
|
20 |
|
21 |
-
|
22 |
-
|
23 |
logger = get_logger("app")
|
24 |
-
|
25 |
# Initialize and load the model:
|
26 |
chat = Chat2TTS.Chat()
|
27 |
|
@@ -67,6 +65,11 @@ def main(args):
|
|
67 |
interactive=True,
|
68 |
value=True
|
69 |
)
|
|
|
|
|
|
|
|
|
|
|
70 |
temperature_slider = gr.Slider(
|
71 |
minimum=0.00001,
|
72 |
maximum=1.0,
|
@@ -164,6 +167,7 @@ def main(args):
|
|
164 |
inputs=[text_input,
|
165 |
text_seed_input,
|
166 |
refine_text_checkBox,
|
|
|
167 |
temperature_slider,
|
168 |
top_p_slider,
|
169 |
top_k_slider,
|
@@ -195,6 +199,7 @@ def main(args):
|
|
195 |
def general_chat_infer_audio(text,
|
196 |
text_seed_input,
|
197 |
refine_text_checkBox,
|
|
|
198 |
temperature_slider,
|
199 |
top_p_slider,
|
200 |
top_k_slider,
|
@@ -223,30 +228,35 @@ def general_chat_infer_audio(text,
|
|
223 |
params_refine_text=params_refine_text,
|
224 |
)
|
225 |
|
226 |
-
logger.info("========开始生成音频文件=====")
|
227 |
-
#torch.manual_seed(audio_seed_input)
|
228 |
|
229 |
with TorchSeedContext(audio_seed_input):
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
245 |
#yield 24000, float_to_int16(wav[0]).T
|
246 |
audio_data = np.array(wav[0]).flatten()
|
247 |
-
sample_rate = 24000
|
248 |
text_data = chat_txt[0] if isinstance(chat_txt, list) else chat_txt
|
249 |
-
|
250 |
return [text_data,(sample_rate, audio_data)]
|
251 |
|
252 |
|
@@ -283,36 +293,8 @@ def general_chat_infer_audio(text,
|
|
283 |
# rand_spk = torch.randn(audio_seed_input)
|
284 |
# return encode_spk_emb(rand_spk)
|
285 |
|
286 |
-
|
287 |
-
|
288 |
-
# import lzma
|
289 |
-
# with torch.no_grad():
|
290 |
-
# arr: np.ndarray = spk_emb.to(dtype=torch.float16, device="cpu").numpy()
|
291 |
-
# s = b14.encode_to_string(
|
292 |
-
# lzma.compress(
|
293 |
-
# arr.tobytes(),
|
294 |
-
# format=lzma.FORMAT_RAW,
|
295 |
-
# filters=[
|
296 |
-
# {"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}
|
297 |
-
# ],
|
298 |
-
# ),
|
299 |
-
# )
|
300 |
-
# del arr
|
301 |
-
# return s
|
302 |
-
|
303 |
-
|
304 |
-
# def _sample_random_speaker(self) -> torch.Tensor:
|
305 |
-
# with torch.no_grad():
|
306 |
-
# dim: int = self.gpt.gpt.layers[0].mlp.gate_proj.in_features
|
307 |
-
# out: torch.Tensor = self.pretrain_models["spk_stat"]
|
308 |
-
# std, mean = out.chunk(2)
|
309 |
-
# spk = (
|
310 |
-
# torch.randn(dim, device=std.device, dtype=torch.float16)
|
311 |
-
# .mul_(std)
|
312 |
-
# .add_(mean)
|
313 |
-
# )
|
314 |
-
# del out, std, mean
|
315 |
-
# return spk
|
316 |
|
317 |
if __name__ == "__main__":
|
318 |
parser = argparse.ArgumentParser(description="ChatTTS demo Launch")
|
|
|
18 |
|
19 |
torch._dynamo.config.suppress_errors = True
|
20 |
|
21 |
+
#HF空间中,GPU模式 运行 大模型代码
|
|
|
22 |
logger = get_logger("app")
|
|
|
23 |
# Initialize and load the model:
|
24 |
chat = Chat2TTS.Chat()
|
25 |
|
|
|
65 |
interactive=True,
|
66 |
value=True
|
67 |
)
|
68 |
+
refine_audio_checkBox = gr.Checkbox(
|
69 |
+
label="是否生成音频文件,如是才会生成音频文件",
|
70 |
+
interactive=True,
|
71 |
+
value=True
|
72 |
+
)
|
73 |
temperature_slider = gr.Slider(
|
74 |
minimum=0.00001,
|
75 |
maximum=1.0,
|
|
|
167 |
inputs=[text_input,
|
168 |
text_seed_input,
|
169 |
refine_text_checkBox,
|
170 |
+
refine_audio_checkBox,
|
171 |
temperature_slider,
|
172 |
top_p_slider,
|
173 |
top_k_slider,
|
|
|
199 |
def general_chat_infer_audio(text,
|
200 |
text_seed_input,
|
201 |
refine_text_checkBox,
|
202 |
+
refine_audio_checkBox,
|
203 |
temperature_slider,
|
204 |
top_p_slider,
|
205 |
top_k_slider,
|
|
|
228 |
params_refine_text=params_refine_text,
|
229 |
)
|
230 |
|
|
|
|
|
231 |
|
232 |
with TorchSeedContext(audio_seed_input):
|
233 |
+
if not refine_audio_checkBox:
|
234 |
+
logger.info("========无需生成音频文件=====")
|
235 |
+
#创建一个空的音频文件
|
236 |
+
wav = chat.emptpy_audio()
|
237 |
+
else:
|
238 |
+
logger.info("========开始生成音频文件=====")
|
239 |
+
#torch.manual_seed(audio_seed_input)
|
240 |
+
#rand_spk = torch.randn(768)
|
241 |
+
rand_spk = chat.sample_random_speaker_tensor()
|
242 |
+
logger.info("========生成音频spk_emb参数完成=====")
|
243 |
+
params_infer_code = {
|
244 |
+
'spk_emb': rand_spk,
|
245 |
+
'temperature': temperature_slider,
|
246 |
+
'top_P': top_p_slider,
|
247 |
+
'top_K': top_k_slider,
|
248 |
+
}
|
249 |
+
wav = chat.infer(
|
250 |
+
text=chat_txt,
|
251 |
+
skip_refine_text=True, #跳过文本优化
|
252 |
+
params_refine_text=params_refine_text,
|
253 |
+
params_infer_code=params_infer_code,
|
254 |
+
)
|
255 |
+
|
256 |
#yield 24000, float_to_int16(wav[0]).T
|
257 |
audio_data = np.array(wav[0]).flatten()
|
|
|
258 |
text_data = chat_txt[0] if isinstance(chat_txt, list) else chat_txt
|
259 |
+
sample_rate = 24000
|
260 |
return [text_data,(sample_rate, audio_data)]
|
261 |
|
262 |
|
|
|
293 |
# rand_spk = torch.randn(audio_seed_input)
|
294 |
# return encode_spk_emb(rand_spk)
|
295 |
|
296 |
+
|
297 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
298 |
|
299 |
if __name__ == "__main__":
|
300 |
parser = argparse.ArgumentParser(description="ChatTTS demo Launch")
|