chenjgtea
commited on
Commit
·
3fd53ff
1
Parent(s):
8825975
新增gpu模式下chattts代码
Browse files- README.md +1 -1
- web/{app.py → app_gpu.py} +51 -12
README.md
CHANGED
@@ -6,7 +6,7 @@ colorTo: purple
|
|
6 |
sdk: gradio
|
7 |
sdk_version: 4.41.0
|
8 |
#app_port: 8080
|
9 |
-
app_file: web/
|
10 |
pinned: false
|
11 |
---
|
12 |
|
|
|
6 |
sdk: gradio
|
7 |
sdk_version: 4.41.0
|
8 |
#app_port: 8080
|
9 |
+
app_file: web/app_gpu.py
|
10 |
pinned: false
|
11 |
---
|
12 |
|
web/{app.py → app_gpu.py}
RENAMED
@@ -203,13 +203,20 @@ def get_chat_infer_audio(chat_txt,
|
|
203 |
spk_emb_text):
|
204 |
logger.info("========开始生成音频文件=====")
|
205 |
#音频参数设置
|
206 |
-
params_infer_code = Chat2TTS.Chat.InferCodeParams(
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
)
|
212 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
213 |
torch.manual_seed(audio_seed_input)
|
214 |
wav = chat.infer(
|
215 |
text=chat_txt,
|
@@ -227,10 +234,11 @@ def get_chat_infer_text(text,seed,refine_text_checkBox):
|
|
227 |
logger.info("========文本内容无需优化=====")
|
228 |
return text
|
229 |
|
230 |
-
params_refine_text = Chat2TTS.Chat.RefineTextParams(
|
231 |
-
|
232 |
-
)
|
233 |
|
|
|
234 |
torch.manual_seed(seed)
|
235 |
chat_text = chat.infer(
|
236 |
text=text,
|
@@ -245,9 +253,40 @@ def get_chat_infer_text(text,seed,refine_text_checkBox):
|
|
245 |
def on_audio_seed_change(audio_seed_input):
|
246 |
global chat
|
247 |
torch.manual_seed(audio_seed_input)
|
248 |
-
rand_spk =
|
249 |
return rand_spk
|
250 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
251 |
|
252 |
if __name__ == "__main__":
|
253 |
parser = argparse.ArgumentParser(description="ChatTTS demo Launch")
|
|
|
203 |
spk_emb_text):
|
204 |
logger.info("========开始生成音频文件=====")
|
205 |
#音频参数设置
|
206 |
+
# params_infer_code = Chat2TTS.Chat.InferCodeParams(
|
207 |
+
# spk_emb=spk_emb_text, # add sampled speaker
|
208 |
+
# temperature=temperature_slider, # using custom temperature
|
209 |
+
# top_P=top_p_slider, # top P decode
|
210 |
+
# top_K=top_k_slider, # top K decode
|
211 |
+
# )
|
212 |
+
torch.manual_seed(audio_seed_input)
|
213 |
+
rand_spk = torch.randn(768)
|
214 |
+
params_infer_code = {
|
215 |
+
'spk_emb': rand_spk,
|
216 |
+
'temperature': temperature_slider,
|
217 |
+
'top_P': top_p_slider,
|
218 |
+
'top_K': top_k_slider,
|
219 |
+
}
|
220 |
torch.manual_seed(audio_seed_input)
|
221 |
wav = chat.infer(
|
222 |
text=chat_txt,
|
|
|
234 |
logger.info("========文本内容无需优化=====")
|
235 |
return text
|
236 |
|
237 |
+
# params_refine_text = Chat2TTS.Chat.RefineTextParams(
|
238 |
+
# prompt='[oral_2][laugh_0][break_6]',
|
239 |
+
# )
|
240 |
|
241 |
+
params_refine_text = {'prompt': '[oral_2][laugh_0][break_6]'}
|
242 |
torch.manual_seed(seed)
|
243 |
chat_text = chat.infer(
|
244 |
text=text,
|
|
|
253 |
def on_audio_seed_change(audio_seed_input):
|
254 |
global chat
|
255 |
torch.manual_seed(audio_seed_input)
|
256 |
+
rand_spk = torch.randn(audio_seed_input)
|
257 |
return rand_spk
|
258 |
+
#return encode_spk_emb(rand_spk)
|
259 |
+
|
260 |
+
def encode_spk_emb(spk_emb: torch.Tensor) -> str:
|
261 |
+
import pybase16384 as b14
|
262 |
+
import lzma
|
263 |
+
with torch.no_grad():
|
264 |
+
arr: np.ndarray = spk_emb.to(dtype=torch.float16, device="cpu").numpy()
|
265 |
+
s = b14.encode_to_string(
|
266 |
+
lzma.compress(
|
267 |
+
arr.tobytes(),
|
268 |
+
format=lzma.FORMAT_RAW,
|
269 |
+
filters=[
|
270 |
+
{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}
|
271 |
+
],
|
272 |
+
),
|
273 |
+
)
|
274 |
+
del arr
|
275 |
+
return s
|
276 |
+
|
277 |
+
|
278 |
+
# def _sample_random_speaker(self) -> torch.Tensor:
|
279 |
+
# with torch.no_grad():
|
280 |
+
# dim: int = self.gpt.gpt.layers[0].mlp.gate_proj.in_features
|
281 |
+
# out: torch.Tensor = self.pretrain_models["spk_stat"]
|
282 |
+
# std, mean = out.chunk(2)
|
283 |
+
# spk = (
|
284 |
+
# torch.randn(dim, device=std.device, dtype=torch.float16)
|
285 |
+
# .mul_(std)
|
286 |
+
# .add_(mean)
|
287 |
+
# )
|
288 |
+
# del out, std, mean
|
289 |
+
# return spk
|
290 |
|
291 |
if __name__ == "__main__":
|
292 |
parser = argparse.ArgumentParser(description="ChatTTS demo Launch")
|