xlgeng committed on
Commit
841f290
·
1 Parent(s): 3efab1d

Begin deployment (开始部署)

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the complete change set.
Files changed (50)
  1. app.py +414 -0
  2. common_utils/__init__.py +0 -0
  3. common_utils/convert_ckpt_dir_to_pt.py +27 -0
  4. common_utils/load_combine_type_yaml.py +59 -0
  5. common_utils/utils4infer.py +163 -0
  6. conf/ct_config.yaml +153 -0
  7. conf/ct_config_sft.yaml +152 -0
  8. conf/data_s2s.yaml +226 -0
  9. conf/data_s2t.yaml +402 -0
  10. conf/data_t2s.yaml +28 -0
  11. conf/data_t2t.yaml +159 -0
  12. conf/data_tmp.yaml +6 -0
  13. conf/ds_stage2.json +34 -0
  14. conf/empty.yaml +0 -0
  15. conf/prompt_config.yaml +0 -0
  16. conf/system_prompt.yaml +27 -0
  17. patches/cumstom_stop_criteria.py +85 -0
  18. patches/custom_speech_ngram_blocking.py +129 -0
  19. patches/custom_speech_repetition_penalty.py +22 -0
  20. patches/modelling_fm_infer_gpu.py +18 -0
  21. patches/modelling_qwen2_infer_gpu.py +416 -0
  22. patches/utils.py +4 -0
  23. requirements.txt +41 -0
  24. tts/__init__.py +0 -0
  25. tts/assert/实验室.png +0 -0
  26. tts/cosyvoice/__init__.py +0 -0
  27. tts/cosyvoice/bin/average_model.py +92 -0
  28. tts/cosyvoice/bin/export_jit.py +91 -0
  29. tts/cosyvoice/bin/export_onnx.py +116 -0
  30. tts/cosyvoice/bin/export_trt.sh +10 -0
  31. tts/cosyvoice/bin/inference.py +115 -0
  32. tts/cosyvoice/bin/train.py +170 -0
  33. tts/cosyvoice/cli/__init__.py +0 -0
  34. tts/cosyvoice/cli/cosyvoice.py +197 -0
  35. tts/cosyvoice/cli/frontend.py +240 -0
  36. tts/cosyvoice/cli/model.py +480 -0
  37. tts/cosyvoice/dataset/__init__.py +0 -0
  38. tts/cosyvoice/dataset/dataset.py +164 -0
  39. tts/cosyvoice/dataset/processor.py +435 -0
  40. tts/cosyvoice/flow/decoder.py +301 -0
  41. tts/cosyvoice/flow/flow.py +239 -0
  42. tts/cosyvoice/flow/flow_matching.py +217 -0
  43. tts/cosyvoice/flow/length_regulator.py +69 -0
  44. tts/cosyvoice/hifigan/discriminator.py +140 -0
  45. tts/cosyvoice/hifigan/f0_predictor.py +56 -0
  46. tts/cosyvoice/hifigan/generator.py +412 -0
  47. tts/cosyvoice/hifigan/hifigan.py +67 -0
  48. tts/cosyvoice/llm/llm.py +434 -0
  49. tts/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +0 -0
  50. tts/cosyvoice/tokenizer/tokenizer.py +279 -0
app.py ADDED
@@ -0,0 +1,414 @@
1
+ import ast
2
+ import base64
3
+ import datetime
4
+ import json
5
+ import logging
6
+ import os
7
+ import spaces
8
+
9
+ import gradio as gr
10
+ import sys
11
+ import time
12
+ import traceback
13
+
14
+ import torch
15
+
16
+ from common_utils.utils4infer import get_feat_from_wav_path, load_model_and_tokenizer, token_list2wav
17
+
18
+ sys.path.insert(0, '.')
19
+ sys.path.insert(0, './tts')
20
+ sys.path.insert(0, './tts/third_party/Matcha-TTS')
21
+ from patches import modelling_qwen2_infer_gpu # 打patch
22
+ from tts.cosyvoice.cli.cosyvoice import CosyVoice
23
+ from tts.cosyvoice.utils.file_utils import load_wav
24
+
25
+ is_npu = False
26
+ try:
27
+ import torch_npu
28
+ except ImportError:
29
+ is_npu = False
30
+ print("torch_npu is not available. if you want to use npu, please install it.")
31
+
32
+
33
+
34
+
35
+ from huggingface_hub import hf_hub_download
36
+ # 从Hugging Face下载.pt文件
37
+ CHECKPOINT_PATH_A = hf_hub_download(repo_id="ASLP-lab/OSUM-EChat", filename="language_think_final.pt")
38
+ CHECKPOINT_PATH_B= hf_hub_download(repo_id="ASLP-lab/OSUM-EChat", filename="tag_think_final.pt")
39
+ CONFIG_PATH = "./conf/ct_config.yaml"
40
+ NAME_A="language_think"
41
+ NAME_B="tag_think"
42
+ cosyvoice_model_path = hf_hub_download(repo_id="ASLP-lab/OSUM-EChat", filename="CosyVoice-300M-25Hz.tar")
43
+ # 将tar包解压到当前目录
44
+ os.system(f"tar -xvf {cosyvoice_model_path}")
45
+ print("解压cosyvoice模型pt文件完成")
46
+ cosyvoice_model_path="./CosyVoice-300M-25Hz"
47
+
48
+
49
+
50
+ device = torch.device("cuda")
51
+ print("开始加载模型 A...")
52
+ model_a, tokenizer_a = load_model_and_tokenizer(CHECKPOINT_PATH_A, CONFIG_PATH)
53
+
54
+ print("\n开始加载模型 B...")
55
+ if CHECKPOINT_PATH_B is not None:
56
+ model_b, tokenizer_b = load_model_and_tokenizer(CHECKPOINT_PATH_B, CONFIG_PATH)
57
+ else:
58
+ model_b, tokenizer_b = None, None
59
+ loaded_models = {
60
+ NAME_A: {"model": model_a, "tokenizer": tokenizer_a},
61
+ NAME_B: {"model": model_b, "tokenizer": tokenizer_b},
62
+ } if model_b is not None else {
63
+ NAME_A: {"model": model_a, "tokenizer": tokenizer_a},
64
+ }
65
+ print("\n所有模型已加载完毕。")
66
+
67
+ cosyvoice = CosyVoice(cosyvoice_model_path, gpu_id=0)
68
+
69
+ # 将图片转换为 Base64
70
+ with open("./tts/assert/实验室.png", "rb") as image_file:
71
+ encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
72
+
73
+ # 任务映射
74
+ TASK_PROMPT_MAPPING = {
75
+ "empathetic_s2s_dialogue with think": "THINK",
76
+ "empathetic_s2s_dialogue no think": "s2s_no_think",
77
+ "empathetic_s2t_dialogue with think": "s2t_think",
78
+ "empathetic_s2t_dialogue no think": "s2t_no_think",
79
+ "ASR (Automatic Speech Recognition)": "转录这段音频中的语音内容为文字。",
80
+ "SRWT (Speech Recognition with Timestamps)": "请识别音频内容,并对所有英文词和中文字进行时间对齐,标注格式为<>,时间精度0.1秒。",
81
+ "VED (Vocal Event Detection)(类别:laugh,cough,cry,screaming,sigh,throat clearing,sneeze,other)": "请将音频转化为文字,并在末尾添加相关音频事件标签,标签格式为<>。",
82
+ "SER (Speech Emotion Recognition)(类别:sad,anger,neutral,happy,surprise,fear,disgust,和other)": "请将音频内容转录成文字记录,并在记录末尾标注情感标签,以<>表示。",
83
+ "SSR (Speaking Style Recognition)(类别:新闻科普,恐怖故事,童话故事,客服,诗歌散文,有声书,日常口语,其他)": "请将音频中的讲话内容转化为文字,并在结尾处注明风格标签,用<>表示。",
84
+ "SGC (Speaker Gender Classification)(类别:female,male)": "请将音频转录为文字,并在文本末尾标注性别标签,标签格式为<>。",
85
+ "SAP (Speaker Age Prediction)(类别:child、adult和old)": "请将这段音频转录成文字,并在末尾加上年龄标签,格式为<>。",
86
+ "STTC (Speech to Text Chat)": "首先将语音转录为文字,然后对语音内容进行回复,转录和文字之间使用<开始回答>分割。",
87
+ "Only Age Prediction(类别:child、adult和old)": "请根据音频分析发言者的年龄并输出年龄标签,标签格式为<>。",
88
+ "Only Gender Classification(类别:female,male)": "根据下述音频内容判断说话者性别,返回性别标签,格式为<>.",
89
+ "Only Style Recognition(类别:新闻科普,恐怖故事,童话故事,客服,诗歌散文,有声书,日常口语,其他)": "对于以下音频,请直接判断风格并返回风格标签,标签格式为<>。",
90
+ "Only Emotion Recognition(类别:sad,anger,neutral,happy,surprise,fear,disgust,和other)": "请鉴别音频中的发言者情感并标出,标签格式为<>。",
91
+ "Only Event Detection(类别:laugh,cough,cry,screaming,sigh,throat clearing,sneeze,other)": "对音频进行标签化,返回音频事件标签,标签格式为<>。",
92
+ "ASR+AGE+GENDER": '请将这段音频进行转录,并在转录完成的文本末尾附加<年龄> <性别>标签。',
93
+ "AGE+GENDER": "请识别以下音频发言者的年龄和性别.",
94
+ "ASR+STYLE+AGE+GENDER": "请对以下音频内容进行转录,并在文本结尾分别���加<风格>、<年龄>、<性别>标签。",
95
+ "STYLE+AGE+GENDER": "请对以下音频进行分析,识别说话风格、说话者年龄和性别。",
96
+ "ASR with punctuations": "需对提供的语音文件执行文本转换,同时为转换结果补充必要的标点。",
97
+ "ASR EVENT AGE GENDER": "请将以下音频内容进行转录,并在转录完成的文本末尾分别附加<音频事件>、<年龄>、<性别>标签。",
98
+ "ASR EMOTION AGE GENDER": "请将下列音频内容进行转录,并在转录文本的末尾分别添加<情感>、<年龄>、<性别>标签。",
99
+ }
100
+ prompt_path = hf_hub_download(repo_id="ASLP-lab/OSUM-EChat", filename="prompt.wav")
101
+ prompt_audio_choices = [
102
+ {"name": "拟人",
103
+ "value": prompt_path},
104
+ ]
105
+
106
+ prompt_audio_cache = {}
107
+ for item in prompt_audio_choices:
108
+ prompt_audio_cache[item["value"]] = load_wav(item["value"], 22050)
109
+
110
+ logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(message)s')
111
+
112
+
113
+
114
+ def do_s2t(model, input_wav_path, input_prompt, profile=False): # 增加 model 参数
115
+ model.eval()
116
+ feat, feat_lens = get_feat_from_wav_path(input_wav_path)
117
+ print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
118
+ if is_npu: torch_npu.npu.synchronize()
119
+ start_time = time.time()
120
+ res_text = model.generate(wavs=feat, wavs_len=feat_lens, prompt=input_prompt, cache_implementation="static")[0]
121
+ if is_npu: torch_npu.npu.synchronize()
122
+ end_time = time.time()
123
+ print(f"S2T 推理消耗时间: {end_time - start_time:.2f} 秒")
124
+ return res_text
125
+
126
+
127
+ def do_s2t4chat(model, input_wav_path, input_prompt, profile=False): # 增加 model 参数
128
+ model.eval()
129
+ feat, feat_lens = get_feat_from_wav_path(input_wav_path)
130
+ print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
131
+ if is_npu: torch_npu.npu.synchronize()
132
+ start_time = time.time()
133
+ res_text = model.generate4chat(wavs=feat, wavs_len=feat_lens, cache_implementation="static")[0]
134
+ if is_npu: torch_npu.npu.synchronize()
135
+ end_time = time.time()
136
+ print(f"S2T4Chat 推理消耗时间: {end_time - start_time:.2f} 秒")
137
+ return res_text
138
+
139
+ def do_s2t4chat_think(model, input_wav_path, input_prompt, profile=False): # 增加 model 参数
140
+ model.eval()
141
+ feat, feat_lens = get_feat_from_wav_path(input_wav_path)
142
+ print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
143
+ if is_npu: torch_npu.npu.synchronize()
144
+ start_time = time.time()
145
+ res_text = model.generate4chat_think(wavs=feat, wavs_len=feat_lens, cache_implementation="static")[0]
146
+ if is_npu: torch_npu.npu.synchronize()
147
+ end_time = time.time()
148
+ print(f"S2T4Chat 推理消耗时间: {end_time - start_time:.2f} 秒")
149
+ return res_text
150
+
151
+
152
+ def do_t2s(model, input_prompt, text_for_tts, profile=False): # 增加 model 参数
153
+ model.eval()
154
+ if is_npu: torch_npu.npu.synchronize()
155
+ start_time = time.time()
156
+ res_tensor = model.generate_tts(device=device, text=text_for_tts, )[0]
157
+ res_token_list = res_tensor.tolist()
158
+ res_text = res_token_list[:-1]
159
+ if is_npu: torch_npu.npu.synchronize()
160
+ end_time = time.time()
161
+ print(f"T2S 推理消耗时间: {end_time - start_time:.2f} 秒")
162
+ return res_text
163
+
164
+
165
+ def do_t2t(model, question_txt, profile=False): # 增加 model 参数
166
+ model.eval()
167
+ if is_npu: torch_npu.npu.synchronize()
168
+ start_time = time.time()
169
+ print(f'开始t2t推理, question_txt: {question_txt}')
170
+ res_text = model.generate_text2text(device=device, text=question_txt)[0]
171
+ if is_npu: torch_npu.npu.synchronize()
172
+ end_time = time.time()
173
+ print(f"T2T 推理消耗时间: {end_time - start_time:.2f} 秒")
174
+ return res_text
175
+
176
+
177
+ def do_s2s(model, input_wav_path, input_prompt, profile=False): # 增加 model 参数
178
+ model.eval()
179
+ feat, feat_lens = get_feat_from_wav_path(input_wav_path)
180
+ print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
181
+ if is_npu: torch_npu.npu.synchronize()
182
+ start_time = time.time()
183
+ output_text, text_res, speech_res = model.generate_s2s_no_stream_with_repetition_penalty(wavs=feat, wavs_len=feat_lens,)
184
+ if is_npu: torch_npu.npu.synchronize()
185
+ end_time = time.time()
186
+ print(f"S2S 推理消耗时间: {end_time - start_time:.2f} 秒")
187
+ return f'{output_text[0]}|{str(speech_res[0].tolist()[1:])}'
188
+
189
+ def do_s2s_think(model, input_wav_path, input_prompt, profile=False): # 增加 model 参数
190
+ model.eval()
191
+ feat, feat_lens = get_feat_from_wav_path(input_wav_path)
192
+ print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
193
+ if is_npu: torch_npu.npu.synchronize()
194
+ start_time = time.time()
195
+ output_text, text_res, speech_res = model.generate_s2s_no_stream_think_with_repetition_penalty(wavs=feat, wavs_len=feat_lens,)
196
+ if is_npu: torch_npu.npu.synchronize()
197
+ end_time = time.time()
198
+ print(f"S2S 推理消耗时间: {end_time - start_time:.2f} 秒")
199
+ return f'{output_text[0]}|{str(speech_res[0].tolist()[1:])}'
200
+
201
+ @spaces.GPU
202
+ def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt): # 增加 model 和 tokenizer 参数
203
+ print(f"wav_path: {input_wav_path}, prompt:{input_prompt}")
204
+ if input_wav_path is None and not input_prompt.endswith(("_TTS", "_T2T")):
205
+ print("音频信息未输入,且不是T2S或T2T任务")
206
+ return "错误:需要音频输入"
207
+
208
+ if input_prompt.endswith("_TTS"):
209
+ text_for_tts = input_prompt.replace("_TTS", "")
210
+ prompt = "恳请将如下文本转换为其对应的语音token,力求生成最为流畅、自然的语音。"
211
+ res_text = do_t2s(model, prompt, text_for_tts)
212
+ elif input_prompt.endswith("_self_prompt"):
213
+ prompt = input_prompt.replace("_self_prompt", "")
214
+ res_text = do_s2t(model, input_wav_path, prompt)
215
+ elif input_prompt.endswith("_T2T"):
216
+ question_txt = input_prompt.replace("_T2T", "")
217
+ res_text = do_t2t(model, question_txt)
218
+ elif input_prompt in ["识别语音内容,并以文字方式作出回答。",
219
+ "请推断对这段语音回答时的情感,标注情感类型,撰写流畅自然的聊天回复,并生成情感语音token。",
220
+ "s2s_no_think"]:
221
+ res_text = do_s2s(model, input_wav_path, input_prompt)
222
+ elif input_prompt == "THINK":
223
+ res_text = do_s2s_think(model, input_wav_path, input_prompt)
224
+ elif input_prompt == "s2t_no_think":
225
+ res_text = do_s2t4chat(model, input_wav_path, input_prompt)
226
+ elif input_prompt == "s2t_think":
227
+ res_text = do_s2t4chat_think(model, input_wav_path, input_prompt)
228
+ else:
229
+ res_text = do_s2t(model, input_wav_path, input_prompt)
230
+ res_text = res_text.replace("<youth>", "<adult>").replace("<middle_age>", "<adult>").replace("<middle>",
231
+ "<adult>")
232
+
233
+ print("识别结果为:", res_text)
234
+ return res_text
235
+
236
+
237
+ def do_decode(model, tokenizer, input_wav_path, input_prompt): # 增加 model 和 tokenizer 参数
238
+ print(f'使用模型进行推理: input_wav_path={input_wav_path}, input_prompt={input_prompt}')
239
+ output_res = true_decode_fuc(model, tokenizer, input_wav_path, input_prompt)
240
+ return output_res
241
+
242
+
243
+ def save_to_jsonl(if_correct, wav, prompt, res):
244
+ data = {
245
+ "if_correct": if_correct,
246
+ "wav": wav,
247
+ "task": prompt,
248
+ "res": res
249
+ }
250
+ with open("results.jsonl", "a", encoding="utf-8") as f:
251
+ f.write(json.dumps(data, ensure_ascii=False) + "\n")
252
+
253
+
254
+ def download_audio(input_wav_path):
255
+ return input_wav_path if input_wav_path else None
256
+
257
+
258
+ def get_wav_from_token_list(input_list, prompt_speech):
259
+ time_str = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
260
+ wav_path = f"./tmp/{time_str}.wav"
261
+ return token_list2wav(input_list, prompt_speech, wav_path, cosyvoice)
262
+
263
+
264
+ # --- Gradio 界面 ---
265
+ with gr.Blocks() as demo:
266
+ gr.Markdown(
267
+ f"""
268
+ <div style="display: flex; align-items: center; justify-content: center; text-align: center;">
269
+ <h1 style="font-family: 'Arial', sans-serif; color: #014377; font-size: 32px; margin-bottom: 0; display: inline-block; vertical-align: middle;">
270
+ OSUM Speech Understanding Model Test
271
+ </h1>
272
+ </div>
273
+ """
274
+ )
275
+
276
+ # ### --- 关键修改:添加模型选择器 --- ###
277
+ with gr.Row():
278
+ model_selector = gr.Radio(
279
+ choices=list(loaded_models.keys()), # 从加载的模型字典中获取选项
280
+ value=NAME_A, # 默认值
281
+ label="选择推理模型",
282
+ interactive=True
283
+ )
284
+
285
+ with gr.Row():
286
+ with gr.Column(scale=1, min_width=300):
287
+ audio_input = gr.Audio(label="录音", sources=["microphone", "upload"], type="filepath", visible=True)
288
+ with gr.Column(scale=1, min_width=300):
289
+ output_text = gr.Textbox(label="输出结果", lines=6, placeholder="生成的结果将显示在这里...",
290
+ interactive=False)
291
+
292
+ with gr.Row():
293
+ task_dropdown = gr.Dropdown(label="任务",
294
+ choices=list(TASK_PROMPT_MAPPING.keys()) + ["自主输入文本", "TTS任务", "T2T任务"],
295
+ value="empathetic_s2s_dialogue with think")
296
+ prompt_speech_dropdown = gr.Dropdown(label="参考音频(prompt_speech)",
297
+ choices=[(item["name"], item["value"]) for item in prompt_audio_choices],
298
+ value=prompt_audio_choices[0]["value"], visible=True)
299
+ custom_prompt_input = gr.Textbox(label="自定义任务提示", placeholder="请输入自定义任务提示...", visible=False)
300
+ tts_input = gr.Textbox(label="TTS输入文本", placeholder="请输入TTS任务的文本...", visible=False)
301
+ t2t_input = gr.Textbox(label="T2T输入文本", placeholder="请输入T2T任务的文本...", visible=False)
302
+
303
+ audio_player = gr.Audio(label="播放音频", type="filepath", interactive=False)
304
+
305
+ with gr.Row():
306
+ download_button = gr.DownloadButton("下载音频", variant="secondary",
307
+ elem_classes=["button-height", "download-button"])
308
+ submit_button = gr.Button("开始处理", variant="primary", elem_classes=["button-height", "submit-button"])
309
+
310
+ with gr.Row(visible=False) as confirmation_row:
311
+ # ... (确认组件不变)
312
+ gr.Markdown("请判断结果是否正确:")
313
+ confirmation_buttons = gr.Radio(choices=["正确", "错误"], label="", interactive=True, container=False,
314
+ elem_classes="confirmation-buttons")
315
+ save_button = gr.Button("提交反馈", variant="secondary")
316
+
317
+ # ... (底部内容不变)
318
+ with gr.Row():
319
+ with gr.Column(scale=1, min_width=800):
320
+ gr.Markdown(f"""...""") # 省略底部HTML
321
+
322
+
323
+ def show_confirmation(output_res, input_wav_path, input_prompt):
324
+ return gr.update(visible=True), output_res, input_wav_path, input_prompt
325
+
326
+
327
+ def save_result(if_correct, wav, prompt, res):
328
+ save_to_jsonl(if_correct, wav, prompt, res)
329
+ return gr.update(visible=False)
330
+
331
+
332
+ # handle_submit 函数现在接收 `selected_model_name` 参数
333
+ def handle_submit(selected_model_name, input_wav_path, task_choice, custom_prompt, tts_text, t2t_text,
334
+ prompt_speech):
335
+ # 1. 根据选择的模型名称,从字典中获取对应的模型和分词器
336
+ print(f"用户选择了: {selected_model_name}")
337
+ model_info = loaded_models[selected_model_name]
338
+ model_to_use = model_info["model"]
339
+ tokenizer_to_use = model_info["tokenizer"]
340
+
341
+ # 2. 准备 prompt
342
+ prompt_speech_data = prompt_audio_cache.get(prompt_speech)
343
+ if task_choice == "自主输入文本":
344
+ input_prompt = custom_prompt + "_self_prompt"
345
+ elif task_choice == "TTS任务":
346
+ input_prompt = tts_text + "_TTS"
347
+ elif task_choice == "T2T任务":
348
+ input_prompt = t2t_text + "_T2T"
349
+ else:
350
+ input_prompt = TASK_PROMPT_MAPPING.get(task_choice, "未知任务类型")
351
+
352
+ # 3. 调用重构后的推理函数,传入选择的模型
353
+ output_res = do_decode(model_to_use, tokenizer_to_use, input_wav_path, input_prompt)
354
+
355
+ # 4. 处理输出 (逻辑不变)
356
+ wav_path_output = input_wav_path
357
+ if task_choice == "TTS任务" or "empathetic_s2s_dialogue" in task_choice:
358
+ if isinstance(output_res, list): # TTS case
359
+ wav_path_output = get_wav_from_token_list(output_res, prompt_speech_data)
360
+ output_res = "生成的token: " + str(output_res)
361
+ elif isinstance(output_res, str) and "|" in output_res: # S2S case
362
+ try:
363
+ text_res, token_list_str = output_res.split("|")
364
+ token_list = json.loads(token_list_str)
365
+ wav_path_output = get_wav_from_token_list(token_list, prompt_speech_data)
366
+ output_res = text_res
367
+ except (ValueError, json.JSONDecodeError) as e:
368
+ print(f"处理S2S输出时出错: {e}")
369
+ output_res = f"错误:无法解析模型输出 - {output_res}"
370
+
371
+ return output_res, wav_path_output
372
+
373
+
374
+ # --- 绑定事件 (下拉框逻辑不变) ---
375
+ task_dropdown.change(fn=lambda choice: gr.update(visible=choice == "自主输入文本"), inputs=task_dropdown,
376
+ outputs=custom_prompt_input)
377
+ task_dropdown.change(fn=lambda choice: gr.update(visible=choice == "TTS任务"), inputs=task_dropdown,
378
+ outputs=tts_input)
379
+ task_dropdown.change(fn=lambda choice: gr.update(visible=choice == "T2T任务"), inputs=task_dropdown,
380
+ outputs=t2t_input)
381
+
382
+ submit_button.click(
383
+ fn=handle_submit,
384
+ # 在 inputs 列表中添加模型选择器 `model_selector`
385
+ inputs=[model_selector, audio_input, task_dropdown, custom_prompt_input, tts_input, t2t_input,
386
+ prompt_speech_dropdown],
387
+ outputs=[output_text, audio_player]
388
+ ).then(
389
+ fn=show_confirmation,
390
+ inputs=[output_text, audio_input, task_dropdown],
391
+ outputs=[confirmation_row, output_text, audio_input, task_dropdown]
392
+ )
393
+
394
+ download_button.click(fn=download_audio, inputs=[audio_input], outputs=[download_button])
395
+ save_button.click(fn=save_result, inputs=[confirmation_buttons, audio_input, task_dropdown, output_text],
396
+ outputs=confirmation_row)
397
+
398
+ # --- 关键修改:为两个模型分别进行预热 ---
399
+ print("开始预热模型...")
400
+ warmup_wav_path = "./tts/assert/hq_1.wav"
401
+ warmup_prompt = "将这段音频的语音内容详细记录为文字稿。"
402
+
403
+ for model_name, model_info in loaded_models.items():
404
+ print(f"正在预热 {model_name}...")
405
+ try:
406
+ # 使用重构后的 do_s2t 函数进行预热,传入对应的模型
407
+ res_text = do_s2t(model_info["model"], warmup_wav_path, warmup_prompt, profile=False)
408
+ print(f'{model_name} 预热完成。ASR推理结果: {res_text}')
409
+ except Exception as e:
410
+ print(f"预热 {model_name} 时发生错误: {e}")
411
+
412
+ # 启动Gradio界面
413
+ print("\nGradio 界面启动中...")
414
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
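Note: the Gradio app above is a thin wrapper around the inference helpers. A minimal sketch of driving the same S2T path directly from Python, assuming the checkpoint and config are available locally and the patches imported at the top of app.py have been applied (file names below are placeholders):

    from common_utils.utils4infer import get_feat_from_wav_path, load_model_and_tokenizer

    # Load one OSUM-EChat checkpoint with the shipped config (placeholder paths).
    model, tokenizer = load_model_and_tokenizer("language_think_final.pt", "./conf/ct_config.yaml")
    model.eval()

    # Extract Whisper-style log-mel features and run the plain ASR prompt,
    # mirroring what do_s2t() does inside the app.
    feat, feat_lens = get_feat_from_wav_path("test.wav")
    text = model.generate(wavs=feat, wavs_len=feat_lens,
                          prompt="转录这段音频中的语音内容为文字。",
                          cache_implementation="static")[0]
    print(text)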
common_utils/__init__.py ADDED
File without changes
common_utils/convert_ckpt_dir_to_pt.py ADDED
@@ -0,0 +1,27 @@
1
+ from gxl_ai_utils.utils import utils_file
2
+ import torch
3
+ try:
4
+ import torch_npu
5
+ except:
6
+ pass
7
+ import os
8
+
9
+
10
+
11
+ def convert_ckpt_to_pt(pt_dir_path):
12
+ exp_dir = os.path.dirname(pt_dir_path)
13
+ pt_name = os.path.basename(pt_dir_path)
14
+ weight_dict = torch.load(f"{exp_dir}/{pt_name}/mp_rank_00_model_states.pt", map_location=torch.device('cpu'))[
15
+ 'module']
16
+ print(weight_dict.keys())
17
+ torch.save(weight_dict, f"{exp_dir}/{pt_name}.pt")
18
+
19
+ if __name__ == '__main__':
20
+ pt_dir_path, = utils_file.do_get_commandline_param(1, ["pt_dir_path"])
21
+ exp_dir = os.path.dirname(pt_dir_path)
22
+ pt_name = os.path.basename(pt_dir_path)
23
+ weight_dict = torch.load(f"{exp_dir}/{pt_name}/mp_rank_00_model_states.pt", map_location=torch.device('cpu'))[
24
+ 'module']
25
+ print(weight_dict.keys())
26
+ torch.save(weight_dict, f"{exp_dir}/{pt_name}.pt")
27
+ # weigth_dict = torch.load("/mnt/sfs/asr/code/wenet_undersdand_and_speech_xlgeng/examples/wenetspeech/whisper/exp/epoch24_cosyvoice1_new-set_token_1w_plus-multi_task_new/step_24999.pt")
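For reference, a minimal usage sketch of the helper above, assuming a DeepSpeed-style checkpoint directory that contains mp_rank_00_model_states.pt (the directory name is a placeholder):

    from common_utils.convert_ckpt_dir_to_pt import convert_ckpt_to_pt

    # exp/step_24999/mp_rank_00_model_states.pt  ->  exp/step_24999.pt
    convert_ckpt_to_pt("exp/step_24999")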
common_utils/load_combine_type_yaml.py ADDED
@@ -0,0 +1,59 @@
1
+ import os
2
+ import random
3
+ import time
4
+
5
+ from gxl_ai_utils.utils import utils_file
6
+
7
+ data_config_path, tmp_file_path = utils_file.do_get_commandline_param(2)
8
+ # random.seed(10086)# 老的
9
+ # 把当前时间戳作为随机种子
10
+ random.seed(int(time.time()))
11
+ # random.seed(7891)# 尝试一下新的顺序 #7890
12
+ data_info_dict = utils_file.load_dict_from_yaml(data_config_path)
13
+ if data_info_dict is None:
14
+ data_info_dict = {}
15
+ total_list = []
16
+ for data_info in data_info_dict.values():
17
+ if "path" not in data_info:
18
+ print(f"path not in data_info: {data_info}")
19
+ continue
20
+ if "weight" not in data_info:
21
+ data_weight = 1
22
+ else:
23
+ data_weight = int(float(data_info['weight']))
24
+ data_path_i = data_info['path']
25
+ utils_file.logging_info(f'path:{data_path_i} ')
26
+
27
+ if data_weight == 0:
28
+ data_weight = float(data_info['weight'])
29
+ if data_weight >= 0:
30
+ utils_file.logging_info(f'data {data_path_i} weight is {data_weight}, will be used as a list')
31
+ final_data_list_i_tmp = utils_file.load_list_file_clean(data_path_i)
32
+ true_num = int(len(final_data_list_i_tmp)*data_weight)
33
+ final_data_list_i = utils_file.do_get_random_sublist(final_data_list_i_tmp, true_num)
34
+ else:
35
+ final_data_list_i = utils_file.load_list_file_clean(data_path_i) * data_weight
36
+ # 判断数据类型
37
+ if "combines_list.txt" in data_path_i:
38
+ print(f'是 combine类型的数据')
39
+ tar_root_path = data_path_i.replace('combines_list.txt', 'combines_tar_root.txt')
40
+ if not os.path.exists(tar_root_path):
41
+ utils_file.logging_info(f'combine_list.txt:{data_path_i} 对应的 combines_tar_root.txt:{tar_root_path} 不存在')
42
+ continue
43
+ tar_root = utils_file.load_first_row_clean(tar_root_path)
44
+ if tar_root.endswith('/'):
45
+ tar_root = tar_root[:-1]
46
+ utils_file.logging_info(f' tar_root:{tar_root}')
47
+ new_final_data_list_i = []
48
+ for data_path_j in final_data_list_i:
49
+ # "combine_path|shard_path"
50
+ tmp_lines = f'{data_path_j}|{tar_root}/{utils_file.do_get_file_pure_name_from_path(data_path_j)}.tar'
51
+ new_final_data_list_i.append(tmp_lines)
52
+ else:
53
+ print(f'不是 combine类型的数据,是传统shard类型的数据')
54
+ new_final_data_list_i = [f'-|{data_path_j}' for data_path_j in final_data_list_i]
55
+
56
+ utils_file.logging_info(f'true load num is : {len(new_final_data_list_i)}')
57
+ total_list.extend(new_final_data_list_i)
58
+ random.shuffle(total_list)
59
+ utils_file.write_list_to_file(total_list, tmp_file_path)
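The data config consumed above is a flat YAML mapping of dataset entries, each holding a path to a combines_list.txt or shards_list.txt manifest and an optional weight (see conf/data_s2s.yaml below). A minimal sketch of writing such a config from Python, with placeholder names and paths:

    import yaml

    data_config = {
        "my_dataset": {                                    # placeholder entry name
            "path": "/data/my_dataset/combines_list.txt",  # placeholder manifest path
            "weight": 1,                                   # optional; scales how much of the list is used
        },
    }
    with open("my_data.yaml", "w", encoding="utf-8") as f:
        yaml.safe_dump(data_config, f, allow_unicode=True)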
common_utils/utils4infer.py ADDED
@@ -0,0 +1,163 @@
1
+ import copy
2
+ import os
3
+ import random
4
+ import re
5
+
6
+ import yaml
7
+ from cn2an import an2cn
8
+ from gxl_ai_utils.utils import utils_file
9
+ from wenet.utils.init_tokenizer import init_tokenizer
10
+ from gxl_ai_utils.config.gxl_config import GxlNode
11
+ from wenet.utils.init_model import init_model
12
+ import logging
13
+ import librosa
14
+ import torch
15
+ import torchaudio
16
+
17
+
18
+
19
+ def load_model_and_tokenizer(checkpoint_path, config_path, device:torch.device=torch.device('cuda')):
20
+ """
21
+ 封装了加载模型和分词器的逻辑
22
+ Args:
23
+ checkpoint_path (str): 模型权重文件路径
24
+ config_path (str): 模型配置文件路径
25
+ device (torch.device): 加载模型的设备
26
+ Returns:
27
+ model: 加载好的模型
28
+ tokenizer: 加载好的分词器
29
+ """
30
+ print(f"正在从以下路径加载模型: {checkpoint_path}")
31
+ args = GxlNode({"checkpoint": checkpoint_path})
32
+ configs = utils_file.load_dict_from_yaml(config_path)
33
+ model, configs = init_model(args, configs)
34
+ model = model.to(device).to(torch.bfloat16)
35
+ model.eval() # 设置为评估模式
36
+ tokenizer = init_tokenizer(configs)
37
+ print(f"模型 {checkpoint_path} 加载完成并移动到 {device}")
38
+ return model, tokenizer
39
+
40
+ def token_list2wav(token_list, prompt_speech, wav_path, cosyvoice):
41
+ token_list = [int(i) for i in token_list]
42
+ j = cosyvoice.inference_zero_shot_gz_22k(
43
+ '收到好友从远方寄来的生日礼物。',
44
+ '希望你以后能够做的比我还好呦。', prompt_speech, stream=False, token_list=token_list)
45
+ utils_file.makedir_for_file(wav_path)
46
+ torchaudio.save(wav_path, j['tts_speech'],cosyvoice.sample_rate)
47
+ print(f'语音合成完成,保存到 {wav_path}')
48
+ return wav_path
49
+
50
+ def do_resample(input_wav_path):
51
+ """..."""
52
+ waveform, sample_rate = torchaudio.load(input_wav_path)
53
+ if waveform.shape[0] > 1:
54
+ waveform = torch.mean(waveform, dim=0, keepdim=True)
55
+ resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
56
+ waveform = resampler(waveform)
57
+ return waveform, 16000
58
+
59
+
60
+ def get_feat_from_wav_path(input_wav_path, device:torch.device=torch.device('cuda')):
61
+ """..."""
62
+ waveform, sample_rate = do_resample(input_wav_path)
63
+ waveform = waveform.squeeze(0)
64
+ window = torch.hann_window(400)
65
+ stft = torch.stft(waveform, 400, 160, window=window, return_complex=True)
66
+ magnitudes = stft[..., :-1].abs() ** 2
67
+ filters = torch.from_numpy(librosa.filters.mel(sr=sample_rate, n_fft=400, n_mels=80))
68
+ mel_spec = filters @ magnitudes
69
+ log_spec = torch.clamp(mel_spec, min=1e-10).log10()
70
+ log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
71
+ log_spec = (log_spec + 4.0) / 4.0
72
+ feat = log_spec.transpose(0, 1)
73
+ feat_lens = torch.tensor([feat.shape[0]], dtype=torch.int64).to(device)
74
+ feat = feat.unsqueeze(0).to(device)
75
+ feat = feat.to(torch.bfloat16)
76
+ return feat, feat_lens
77
+
78
+
79
+
80
+ def do_format_shard_manifest4one(input_shards_path, tmp_file_path=None):
81
+ if tmp_file_path is None:
82
+ tmp_file_path = f'~/.cache/.temp/{random.randint(10000, 99999)}.txt'
83
+ data_path_i = input_shards_path
84
+ utils_file.logging_info(f'path:{data_path_i} ')
85
+ final_data_list_i = utils_file.load_list_file_clean(data_path_i)
86
+ # 判断数据类型
87
+ if "combines_list.txt" in data_path_i:
88
+ print(f'是 combine类型的数据')
89
+ tar_root_path = data_path_i.replace('combines_list.txt', 'combines_tar_root.txt')
90
+ if not os.path.exists(tar_root_path):
91
+ utils_file.logging_error(
92
+ f'combine_list.txt:{data_path_i} 对应的 combines_tar_root.txt:{tar_root_path} 不存在')
93
+ return
94
+ tar_root = utils_file.load_first_row_clean(tar_root_path)
95
+ if tar_root.endswith('/'):
96
+ tar_root = tar_root[:-1]
97
+ utils_file.logging_info(f' tar_root:{tar_root}')
98
+ new_final_data_list_i = []
99
+ for data_path_j in final_data_list_i:
100
+ # "combine_path|shard_path"
101
+ tmp_lines = f'{data_path_j}|{tar_root}/{utils_file.do_get_file_pure_name_from_path(data_path_j)}.tar'
102
+ new_final_data_list_i.append(tmp_lines)
103
+ else:
104
+ print(f'不是 combine类型的数据,是传统shard类型的数据')
105
+ new_final_data_list_i = [f'-|{data_path_j}' for data_path_j in final_data_list_i]
106
+
107
+ utils_file.logging_info(f'true load num is : {len(new_final_data_list_i)}')
108
+ utils_file.write_list_to_file(new_final_data_list_i, tmp_file_path)
109
+ return tmp_file_path
110
+
111
+
112
+
113
+ def convert_numbers_in_string(s):
114
+ # 正则表达式匹配数字(支持整数、小数、负数)
115
+ pattern = r'-?\d+\.?\d*'
116
+
117
+ def replace_func(match):
118
+ num_str = match.group()
119
+ try:
120
+ # 尝试转换数字
121
+ return an2cn(num_str)
122
+ except ValueError:
123
+ # 若转换失败(如非有效数字),返回原内容
124
+ return num_str
125
+ # 替换字符串中所有匹配的数字
126
+ return re.sub(pattern, replace_func, s)
127
+
128
+ def get_test_conf(config_path):
129
+ with open(config_path, 'r', encoding='utf-8') as fin:
130
+ print(f"加载配置文件 {config_path}")
131
+ configs = yaml.load(fin, Loader=yaml.FullLoader)
132
+ configs['dataset_conf']['filter_conf']['filter_no_extra_info'] = False
133
+ test_conf = copy.deepcopy(configs['dataset_conf'])
134
+
135
+ # test_conf['filter_conf']['max_length'] = 3000 # whisper最长处理30s 102400
136
+ test_conf['filter_conf']['min_length'] = 10
137
+ test_conf['filter_conf']['token_max_length'] = 102400
138
+ test_conf['filter_conf']['token_min_length'] = 1
139
+ test_conf['filter_conf']['max_output_input_ratio'] = 102400
140
+ test_conf['filter_conf']['min_output_input_ratio'] = 0
141
+ test_conf['filter_conf']['filter_no_extra_info'] = False
142
+ test_conf['filter_conf']['max_seq_len'] = 102400
143
+ test_conf['speed_perturb'] = False
144
+ test_conf['spec_aug'] = False
145
+ test_conf['spec_sub'] = False
146
+ test_conf['spec_trim'] = False
147
+ test_conf['shuffle'] = False
148
+ test_conf['sort'] = False
149
+ test_conf['cycle'] = 1
150
+ test_conf['list_shuffle'] = True
151
+ if 'fbank_conf' in test_conf:
152
+ test_conf['fbank_conf']['dither'] = 0.0
153
+ elif 'mfcc_conf' in test_conf:
154
+ test_conf['mfcc_conf']['dither'] = 0.0
155
+ test_conf['batch_conf']['batch_type'] = "static"
156
+ test_conf['batch_conf']['batch_size'] = 1
157
+ test_conf['split_num'] = 1
158
+ test_conf['multi_num'] = 1
159
+ test_conf['other_filter_conf'] = {}
160
+ test_conf['data_recover'] = False
161
+ return configs, test_conf
162
+
163
+
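As a quick shape check on the feature extraction above, a minimal sketch (placeholder wav path; it assumes a CUDA device, since get_feat_from_wav_path defaults to cuda):

    from common_utils.utils4infer import get_feat_from_wav_path

    # With n_fft=400 and hop_length=160 at 16 kHz, each frame covers 10 ms,
    # so a 2-second clip yields roughly 200 frames of 80-dim log-mel features.
    feat, feat_lens = get_feat_from_wav_path("test.wav")
    print(feat.shape, feat.dtype)   # torch.Size([1, num_frames, 80]) torch.bfloat16
    print(feat_lens)                # e.g. tensor([num_frames], device='cuda:0')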
conf/ct_config.yaml ADDED
@@ -0,0 +1,153 @@
1
+ model: osum_echat
2
+
3
+ # llm_path
4
+ llm_path: &llm_path "Qwen/Qwen2.5-3B-Instruct"
5
+
6
+ #
7
+ # model config
8
+ downsample_rate: 4 # 1 2 4 8
9
+ adapter_type: osum_echat
10
+ if_instruct: true
11
+ input_dim: 80
12
+
13
+ # tokenizer ,gxl
14
+ tokenizer: huggingface
15
+ tokenizer_conf:
16
+ llm_path: *llm_path
17
+
18
+ # lora config
19
+ use_lora: false
20
+ lora_alpha: 32
21
+ lora_rank: 64 # 3B -> 85M
22
+ lora_dropout: 0.1
23
+
24
+ # speech generate config
25
+ speech_token_num: &token_num 4097 #4097
26
+
27
+
28
+ # Configuration of parameters for training
29
+ fire_module: link_and_encoder_and_lora # link encoder llm link_and_encoder link_and_encoder_and_lora, llm需要配合use_lora为true
30
+
31
+ # other config
32
+ grad_clip: 5
33
+ accum_grad: 8
34
+ log_interval: 10
35
+ save_interval: 1250 #1250 #2500
36
+ max_epoch: 1
37
+ init_step: true
38
+
39
+ # training config
40
+ optim: adamw
41
+ optim_conf:
42
+ betas:
43
+ - 0.9
44
+ - 0.99
45
+ eps: 1.0e-06
46
+ lr: 1.0e-06
47
+ weight_decay: 0.01
48
+ scheduler: warmuplr
49
+ scheduler_conf:
50
+ warmup_steps: 2000
51
+
52
+
53
+ dataset: asr
54
+ dataset_conf:
55
+ speech_token_num: *token_num
56
+ batch_conf:
57
+ batch_size: 26
58
+ batch_type: dynamic
59
+ max_frames_in_batch: 28000000 #3000 #9000 #3000 #3300 # 3900
60
+ max_seq_in_batch: 3700 #1500 #4000 #1100 #1600 # 1900
61
+ feats_type: log_mel_spectrogram
62
+ filter_conf:
63
+ max_length: 20000
64
+ min_length: 20
65
+ token_max_length: 1200
66
+ token_min_length: 1
67
+ filter_no_extra_info: true # 如果没有task lang 等信息,直接过滤掉, 适用于通用多任务训练, 推理时应该关掉
68
+ max_seq_len: 2000 #、1100 #1000
69
+ other_filter_conf:
70
+ only_s2s: false # 只针对与s2s dataloader的过滤
71
+ only_s2t: false # 只针对与s2t dataloader的过滤
72
+ only_t2t: false # 只针对与t2t dataloader的过滤
73
+ only_t2s: false # 只针对与t2s dataloader的过滤
74
+ language_conf:
75
+ limited_langs:
76
+ - zh
77
+ log_mel_spectrogram_conf:
78
+ hop_length: 160
79
+ n_fft: 400
80
+ num_mel_bins: 80
81
+ padding: 0
82
+ resample_conf:
83
+ resample_rate: 16000
84
+ shuffle: true
85
+ shuffle_conf:
86
+ shuffle_size: 1500
87
+ sort: true
88
+ sort_conf:
89
+ sort_size: 500
90
+ spec_aug: true
91
+ spec_aug_conf:
92
+ max_f: 10
93
+ max_t: 50
94
+ num_f_mask: 2
95
+ num_t_mask: 2
96
+ spec_sub: true
97
+ spec_sub_conf:
98
+ max_t: 30
99
+ num_t_sub: 3
100
+ spec_trim: false
101
+ speed_perturb: false
102
+ eod_id: 151645
103
+ split_num: 1
104
+ multi_num: 2
105
+ prompt_conf_path: conf/prompt_config.yaml
106
+ data_recover: false
107
+ data_recover_conf:
108
+ start_idx: 0 # 删除前面start_idx个item(tar包)
109
+ other_tokenze_conf: # 一些对数据额外操作的可控按钮,这些操作一般来说再test时都得为false
110
+ only_info:
111
+ only_s2s: false # 只针对与s2s dataloader的过滤
112
+ only_s2t: false # 只针对与s2t dataloader的过滤
113
+ only_t2t: false # 只针对与t2t dataloader的过滤
114
+ only_t2s: false # 只针对与t2s dataloader的过滤
115
+ use_50_per_change_if_only_X: true # 50%的句子随机替换为其only X
116
+ use_s2s_streaming_random:
117
+ enable: false
118
+ rate: 0.5 # 1.0 表示100%的句子随机替换为其only X
119
+ natural_language_convert:
120
+ enable: false
121
+ rate: 0.00 # 1.0 表示100%的转换成自然语言模式
122
+ use_s2s_convert_s2t:
123
+ enable: false # 单独为s2t dataloader 开启s2s convert
124
+ rate: 1.0 # 1.0 表示100%的句子随机替换为其only X
125
+ use_streaming_tts:
126
+ enable: false
127
+ rate: 0.5 # 1.0 表示100%的句子随机替换为其only X
128
+ use_think_mode:
129
+ enable: false # 开启think 模式, 即随机替换为think模式的句子
130
+ rate: 0.8
131
+ other_filter_conf:
132
+ fiter_txt_is_None: true # 过滤掉text is "<NONE>"的语音数据,适配由于gender数据部分含有<NONE>标签而设计。但仅train起作用
133
+
134
+ # model config for encoder
135
+ encoder: transformer
136
+ encoder_conf:
137
+ activation_type: gelu
138
+ attention_dropout_rate: 0.0
139
+ attention_heads: 16
140
+ dropout_rate: 0.1
141
+ gradient_checkpointing: true
142
+ input_layer: conv1d2
143
+ key_bias: false
144
+ linear_units: 4096
145
+ normalize_before: true
146
+ num_blocks: 24
147
+ output_size: 1024
148
+ pos_enc_layer_type: abs_pos_whisper
149
+ positional_dropout_rate: 0.1
150
+ static_chunk_size: -1
151
+ use_dynamic_chunk: false
152
+ use_dynamic_left_chunk: false
153
+
conf/ct_config_sft.yaml ADDED
@@ -0,0 +1,152 @@
1
+ model: llmasr
2
+
3
+ # llm_path
4
+ llm_path: &llm_path "/home/A02_tmpdata3/ckpt/Qwen2.5-3B-Instruct"
5
+ #
6
+ # model config
7
+ downsample_rate: 4 # 1 2 4 8
8
+ adapter_type: osum_echat
9
+ if_instruct: true
10
+ input_dim: 80
11
+
12
+ # tokenizer ,gxl
13
+ tokenizer: huggingface
14
+ tokenizer_conf:
15
+ llm_path: *llm_path
16
+
17
+ # lora config
18
+ use_lora: false
19
+ lora_alpha: 32
20
+ lora_rank: 64 # 3B -> 85M
21
+ lora_dropout: 0.1
22
+
23
+ # speech generate config
24
+ speech_token_num: &token_num 4097 #4097
25
+
26
+
27
+ # Configuration of parameters for training
28
+ fire_module: link_and_encoder_and_lora # link encoder llm link_and_encoder link_and_encoder_and_lora, llm需要配合use_lora为true
29
+
30
+ # other config
31
+ grad_clip: 5
32
+ accum_grad: 8
33
+ log_interval: 10
34
+ save_interval: 125 #1250 #2500
35
+ max_epoch: 1
36
+ init_step: true
37
+
38
+ # training config
39
+ optim: adamw
40
+ optim_conf:
41
+ betas:
42
+ - 0.9
43
+ - 0.99
44
+ eps: 1.0e-06
45
+ lr: 1.0e-06
46
+ weight_decay: 0.01
47
+ scheduler: warmuplr
48
+ scheduler_conf:
49
+ warmup_steps: 400
50
+
51
+
52
+ dataset: asr
53
+ dataset_conf:
54
+ speech_token_num: *token_num
55
+ batch_conf:
56
+ batch_size: 26
57
+ batch_type: dynamic
58
+ max_frames_in_batch: 28000000 #3000 #9000 #3000 #3300 # 3900
59
+ max_seq_in_batch: 3700 #1500 #4000 #1100 #1600 # 1900
60
+ feats_type: log_mel_spectrogram
61
+ filter_conf:
62
+ max_length: 20000
63
+ min_length: 20
64
+ token_max_length: 1200
65
+ token_min_length: 1
66
+ filter_no_extra_info: true # 如果没有task lang 等信息,直接过滤掉, 适用于通用多任务训练, 推理时应该关掉
67
+ max_seq_len: 2000 #、1100 #1000
68
+ other_filter_conf:
69
+ only_s2s: false # 只针对与s2s dataloader的过滤
70
+ only_s2t: false # 只针对与s2t dataloader的过滤
71
+ only_t2t: false # 只针对与t2t dataloader的过滤
72
+ only_t2s: false # 只针对与t2s dataloader的过滤
73
+ language_conf:
74
+ limited_langs:
75
+ - zh
76
+ log_mel_spectrogram_conf:
77
+ hop_length: 160
78
+ n_fft: 400
79
+ num_mel_bins: 80
80
+ padding: 0
81
+ resample_conf:
82
+ resample_rate: 16000
83
+ shuffle: true
84
+ shuffle_conf:
85
+ shuffle_size: 1500
86
+ sort: true
87
+ sort_conf:
88
+ sort_size: 500
89
+ spec_aug: true
90
+ spec_aug_conf:
91
+ max_f: 10
92
+ max_t: 50
93
+ num_f_mask: 2
94
+ num_t_mask: 2
95
+ spec_sub: true
96
+ spec_sub_conf:
97
+ max_t: 30
98
+ num_t_sub: 3
99
+ spec_trim: false
100
+ speed_perturb: false
101
+ eod_id: 151645
102
+ split_num: 1
103
+ multi_num: 2
104
+ prompt_conf_path: conf/prompt_config.yaml
105
+ data_recover: false
106
+ data_recover_conf:
107
+ start_idx: 0 # 删除前面start_idx个item(tar包)
108
+ other_tokenze_conf: # 一些对数据额外操作的可控按钮,这些操作一般来说再test时都得为false
109
+ only_info:
110
+ only_s2s: false # 只针对与s2s dataloader的过滤
111
+ only_s2t: false # 只针对与s2t dataloader的过滤
112
+ only_t2t: false # 只针对与t2t dataloader的过滤
113
+ only_t2s: false # 只针对与t2s dataloader的过滤
114
+ use_50_per_change_if_only_X: true # 50%的句子随机替换为其only X
115
+ use_s2s_streaming_random:
116
+ enable: false
117
+ rate: 0.5 # 1.0 表示100%的句子随机替换为其only X
118
+ natural_language_convert:
119
+ enable: false
120
+ rate: 0.00 # 1.0 表示100%的转换成自然语言模式
121
+ use_s2s_convert_s2t:
122
+ enable: false # 单独为s2t dataloader 开启s2s convert
123
+ rate: 1.0 # 1.0 表示100%的句子随机替换为其only X
124
+ use_streaming_tts:
125
+ enable: false
126
+ rate: 0.5 # 1.0 表示100%的句子随机替换为其only X
127
+ use_think_mode:
128
+ enable: false # 开启think 模式, 即随机替换为think模式的句子
129
+ rate: 0.8
130
+ other_filter_conf:
131
+ fiter_txt_is_None: true # 过滤掉text is "<NONE>"的语音数据,适配由于gender数据部分含有<NONE>标签而设计。但仅train起作用
132
+
133
+ # model config for encoder
134
+ encoder: transformer
135
+ encoder_conf:
136
+ activation_type: gelu
137
+ attention_dropout_rate: 0.0
138
+ attention_heads: 16
139
+ dropout_rate: 0.1
140
+ gradient_checkpointing: true
141
+ input_layer: conv1d2
142
+ key_bias: false
143
+ linear_units: 4096
144
+ normalize_before: true
145
+ num_blocks: 24
146
+ output_size: 1024
147
+ pos_enc_layer_type: abs_pos_whisper
148
+ positional_dropout_rate: 0.1
149
+ static_chunk_size: -1
150
+ use_dynamic_chunk: false
151
+ use_dynamic_left_chunk: false
152
+
conf/data_s2s.yaml ADDED
@@ -0,0 +1,226 @@
1
+ # ===========================副语言 s2s thinking ===================================
2
+
3
+ # age gender,
4
+ age_gender_common:
5
+ path: /home/A02_tmpdata3/osum_s2s/gender/xlgeng_new_data/s2s_thinking/doubao/combines_list.txt
6
+ tar_num: 1511
7
+ weight: 2
8
+
9
+ gender_xianshi:
10
+ path: /home/A02_tmpdata3/osum_s2s/sex_xianshi_cosyvoice2_by_cywang_added_by_20250625/raw_data/s2s_handle/xlgeng_new_data/s2s_thinking/doubao/combines_list.txt
11
+ tar_num: 30
12
+ weight: 2
13
+
14
+
15
+ gender_yinshi_3k:
16
+ path: /home/A02_tmpdata3/osum_s2s/S2SChat_sex_yinshi_7_4_osum_by_cywang_added_by_20250708/raw_data/s2s_handle/xlgeng_new_data/s2s_thinking/doubao/combines_list.txt
17
+ tar_num: 3
18
+ weight: 2
19
+
20
+ gender_yinshi_5k:
21
+ path: /home/A02_tmpdata3/osum_s2s/S2SChat_sex_yinshi_5000_6_13_data_by_gjli_added_by_20250622/raw_data/s2s_handle/xlgeng_new_data/s2s_thinking/doubao/combines_list.txt
22
+ tar_num: 6
23
+ weight: 2
24
+
25
+ age_xianshi:
26
+ path: /home/A02_tmpdata3/osum_s2s/age_xianshi_cosyvoice2_by_cywang_added_by_20250625/raw_data/s2s_handle/xlgeng_new_data/s2s_thinking/doubao/combines_list.txt
27
+ tar_num: 25
28
+ weight: 2
29
+
30
+
31
+ # caption
32
+ caption_common_7label:
33
+ path: /home/A02_tmpdata3/osum_s2s/caption/raw_data/xlgeng_new_data/s2s_thinking/doubao/combines_list.txt
34
+ tar_num: 162
35
+ weight: 2
36
+
37
+ caption_common_50_label:
38
+ path: /home/A02_tmpdata3/osum_s2s/caption_add_2025_1_6/raw_data/s2s_data_with_gender/xlgeng_new_data/s2s_thinking/doubao/combines_list.txt
39
+ tar_num: 395 # 实际是196k
40
+ weight: 2
41
+
42
+ caption_xianshi:
43
+ path: /home/A02_tmpdata3/osum_s2s/caption_s2s_xianshi_20250806/raw_data/s2s_data/xlgeng_new_data/s2s_thinking/doubao/combines_list.txt
44
+ tar_num: 6
45
+ weight: 10
46
+
47
+
48
+
49
+ # emotion
50
+ emotion_100K_sensevoice:
51
+ path: /home/A02_tmpdata3/osum_s2s/emotion_yinshi_zxzhao_with_q_emo_by_cywang_added_by_20250701/handle_data/s2s_handle/xlgeng_new_data/s2s_thinking/doubao/combines_list.txt
52
+ tar_num: 107
53
+ weight: 10
54
+
55
+ emotion_30K_sensevoice:
56
+ path: /home/A02_tmpdata3/emotion/中英混多音色情感数据库/s2s_handle/xlgeng_new_data/s2s_thinking/doubao/combines_list.txt
57
+ tar_num: 33
58
+ weight: 10
59
+
60
+ S2SChat_osum_setting_qa_527_updated_by_cywang_added_by_20250616_think:
61
+ path: /home/A02_tmpdata3/osum_s2s/S2SChat_osum_setting_qa_527_updated_by_cywang_added_by_20250616/raw_data/s2s_handle/xlgeng_new_data/s2s_thinking/doubao/combines_list.txt
62
+ shard_num: 8
63
+ weight: 10
64
+
65
+ # ======================================s2s 副语言 thinking end=====================================
66
+
67
+
68
+
69
+
70
+ # ===========================副语言 s2s no thinking ===================================
71
+
72
+ # age gender,
73
+ age_gender_common_no_thinking:
74
+ path: /home/A02_tmpdata3/osum_s2s/gender/xlgeng_new_data/s2s_no_thinking/doubao/combines_list.txt
75
+ tar_num: 1511
76
+ weight: 2
77
+
78
+ gender_xianshi_no_thinking:
79
+ path: /home/A02_tmpdata3/osum_s2s/sex_xianshi_cosyvoice2_by_cywang_added_by_20250625/raw_data/s2s_handle/xlgeng_new_data/s2s_no_thinking/doubao/combines_list.txt
80
+ tar_num: 30
81
+ weight: 2
82
+
83
+ gender_yinshi_3k_no_thinking:
84
+ path: /home/A02_tmpdata3/osum_s2s/S2SChat_sex_yinshi_7_4_osum_by_cywang_added_by_20250708/raw_data/s2s_handle/xlgeng_new_data/s2s_no_thinking/doubao/combines_list.txt
85
+ tar_num: 3
86
+ weight: 2
87
+
88
+ gender_yinshi_5k_no_thinking:
89
+ path: /home/A02_tmpdata3/osum_s2s/S2SChat_sex_yinshi_5000_6_13_data_by_gjli_added_by_20250622/raw_data/s2s_handle/xlgeng_new_data/s2s_no_thinking/doubao/combines_list.txt
90
+ tar_num: 6
91
+ weight: 2
92
+
93
+ age_xianshi_no_thinking:
94
+ path: /home/A02_tmpdata3/osum_s2s/age_xianshi_cosyvoice2_by_cywang_added_by_20250625/raw_data/s2s_handle/xlgeng_new_data/s2s_no_thinking/doubao/combines_list.txt
95
+ tar_num: 25
96
+ weight: 2
97
+
98
+
99
+ # caption
100
+ caption_common_7label_no_thinking:
101
+ path: /home/A02_tmpdata3/osum_s2s/caption/raw_data/xlgeng_new_data/s2s_no_thinking/doubao/combines_list.txt
102
+ tar_num: 162
103
+ weight: 2
104
+
105
+ caption_common_50_label_no_thinking:
106
+ path: /home/A02_tmpdata3/osum_s2s/caption_add_2025_1_6/raw_data/s2s_data_with_gender/xlgeng_new_data/s2s_no_thinking/doubao/combines_list.txt
107
+ tar_num: 395 # 实际是196k
108
+ weight: 2
109
+
110
+ caption_xianshi_no_thinking:
111
+ path: /home/A02_tmpdata3/osum_s2s/caption_s2s_xianshi_20250806/raw_data/s2s_data/xlgeng_new_data/s2s_no_thinking/doubao/combines_list.txt
112
+ tar_num: 6
113
+ weight: 2
114
+
115
+
116
+ # emotion
117
+ emotion_100K_sensevoice_no_thinking:
118
+ path: /home/A02_tmpdata3/osum_s2s/emotion_yinshi_zxzhao_with_q_emo_by_cywang_added_by_20250701/handle_data/s2s_handle/xlgeng_new_data/s2s_no_thinking/doubao/combines_list.txt
119
+ tar_num: 107
120
+ weight: 10
121
+
122
+ emotion_30K_sensevoice_no_thinking:
123
+ path: /home/A02_tmpdata3/emotion/中英混多音色情感数据库/s2s_handle/xlgeng_new_data/s2s_no_thinking/doubao/combines_list.txt
124
+ tar_num: 33
125
+ weight: 10
126
+
127
+ S2SChat_osum_setting_qa_527_updated_by_cywang_added_by_20250616:
128
+ path: /home/A02_tmpdata3/osum_s2s/S2SChat_osum_setting_qa_527_updated_by_cywang_added_by_20250616/raw_data/s2s_handle/xlgeng_new_data/s2s_no_thinking/doubao/combines_list.txt
129
+ shard_num: 8
130
+ weight: 10
131
+
132
+ # -------------------------------------------s2s 副语言 no thinking end-------------------------------------------
133
+
134
+ S2SChat_syndata_merged_by_300W_zhguo_added_by_20250616:
135
+ path: /home/A02_tmpdata3/osum_s2s/S2SChat_syndata_merged_by_300W_zhguo_added_by_20250616/combines_data_s2s/combines_list.txt
136
+ tar_num: 3000
137
+ weight: 1
138
+ S2SChat_osum_total_data_lst_check_final_100W_by_zhguo_added_by_20250616:
139
+ path: /home/A02_tmpdata3/osum_s2s/S2SChat_osum_total_data_lst_check_final_100W_by_zhguo_added_by_20250616/combines_data_s2s/combines_list.txt
140
+ tar_num: 1000
141
+ weight: 1
142
+
143
+ gaozhiliang_gbma:
144
+ path: /home/A02_tmpdata3/osum_s2s/gaozhiliang_gbma/shards_list.txt
145
+ new_data_list: /home/node44_tmpdata3/netease/gbma/workspace/osum/data/process/0803/all_data_info.jsonl
146
+ new_lab_path: /home/work_nfs23/asr_data/data/osum_chat/s2s/gaozhiliang_gbma/shards_list.txt
147
+ shard_num: 24
148
+ weight: 1
149
+
150
+ # ======================================s2s no thinking end=====================================
151
+
152
+ # emotion explicit;
153
+ S2SChat_0628_E1_shard_by_kxxia_added_by_20250630:
154
+ huawei_path: /mnt/sfs/asr/update_data/S2SChat_0628_E1_shard_by_kxxia_added_by_20250630/shards_list.txt
155
+ new_data_list: /home/work_nfs16/cywang/workspace/OSUM/E1/0628_E1_shard.jsonl
156
+ new_lab_path: /home/work_nfs11/cywang/data/shard/S2Chat/0628_E1_shard/shards_list.txt
157
+ path: /home/A02_tmpdata3/osum_s2s/S2SChat_0628_E1_shard_by_kxxia_added_by_20250630/shards_list.txt
158
+ shard_num: 147
159
+ description: "E1 shard, 情感显示数据"
160
+ weight: 1
161
+ S2SChat_eng_e1_by_cywang_added_by_20250711:
162
+ huawei_path: /mnt/sfs/asr/update_data/S2SChat_eng_e1_by_cywang_added_by_20250711/shards_list.txt
163
+ new_data_list: /home/work_nfs16/kxxia/work/common/eng_e1.jsonl1752154262.3374825
164
+ new_lab_path: /home/work_nfs11/cywang/data/shard/S2Chat/eng_e1/shards_list.txt
165
+ path: /home/A02_tmpdata3/osum_s2s/S2SChat_eng_e1_by_cywang_added_by_20250711/shards_list.txt
166
+ shard_num: 50
167
+ weight: 2
168
+
169
+ # 下面一共才200多个
170
+ S2SChat_0630_trans_en2zh_by_cywang_added_by_20250704:
171
+ huawei_path: /mnt/sfs/asr/update_data/S2SChat_0630_trans_en2zh_by_cywang_added_by_20250704/shards_list.txt
172
+ new_data_list: /home/work_nfs16/cywang/workspace/OSUM/trans_emotion/0630_trans_en2zh.jsonl
173
+ new_lab_path: /home/work_nfs11/cywang/data/shard/S2Chat/0630_trans_en2zh/shards_list.txt
174
+ path: /home/A02_tmpdata3/osum_s2s/S2SChat_0630_trans_en2zh_by_cywang_added_by_20250704/shards_list.txt
175
+ shard_num: 128
176
+ weight: 0.5
177
+ S2SChat_0630_trans_zh2en_by_cywang_added_by_20250704:
178
+ huawei_path: /mnt/sfs/asr/update_data/S2SChat_0630_trans_zh2en_by_cywang_added_by_20250704/shards_list.txt
179
+ new_data_list: /home/work_nfs16/cywang/workspace/OSUM/trans_emotion/0630_trans_zh2en.jsonl
180
+ new_lab_path: /home/work_nfs11/cywang/data/shard/S2Chat/0630_trans_zh2en/shards_list.txt
181
+ path: /home/A02_tmpdata3/osum_s2s/S2SChat_0630_trans_zh2en_by_cywang_added_by_20250704/shards_list.txt
182
+ shard_num: 128
183
+ weight: 0.5
184
+ S2SChat_pachong_part1_filter_author_data_by_gjli_added_by_20250622:
185
+ shard_num: 28
186
+ huawei_path: /mnt/sfs/asr/update_data/S2SChat_pachong_part1_filter_author_data_by_gjli_added_by_20250622/shards_list.txt
187
+ new_data_list: /home/work_nfs16/gjli/workspaces/poem/6-16_shigepachong/pachong_part1_filter_content_data.list
188
+ new_lab_path: /home/work_nfs11/cywang/data/shard/S2Chat/pachong_part1_filter_author_data/shards_list.txt
189
+ path: /home/A02_tmpdata3/osum_s2s/S2SChat_pachong_part1_filter_author_data_by_gjli_added_by_20250622/shards_list.txt
190
+ weight: 1
191
+ S2SChat_pachong_part1_filter_content_data_by_gjli_added_by_20250622:
192
+ huawei_path: /mnt/sfs/asr/update_data/S2SChat_pachong_part1_filter_content_data_by_gjli_added_by_20250622/shards_list.txt
193
+ new_data_list: /home/work_nfs16/gjli/workspaces/poem/6-16_shigepachong/pachong_part1_filter_author_data.list
194
+ new_lab_path: /home/work_nfs11/cywang/data/shard/S2Chat/pachong_part1_filter_content_data/shards_list.txt
195
+ path: /home/A02_tmpdata3/osum_s2s/S2SChat_pachong_part1_filter_content_data_by_gjli_added_by_20250622/shards_list.txt
196
+ shard_num: 68
197
+ weight: 1
198
+ S2SChat_poem_1_2_6_3_author_data_150num_by_gjli_added_by_20250622:
199
+ huawei_path: /mnt/sfs/asr/update_data/S2SChat_poem_1_2_6_3_author_data_150num_by_gjli_added_by_20250622/shards_list.txt
200
+ new_data_list: /home/work_nfs16/gjli/workspaces/poem/6.3/poem_1_2_6-3_author_data.list
201
+ new_lab_path: /home/work_nfs11/cywang/data/shard/S2Chat/poem_1_2_6-3_author_data/shards_list.txt
202
+ path: /home/A02_tmpdata3/osum_s2s/S2SChat_poem_1_2_6_3_author_data_150num_by_gjli_added_by_20250622/shards_list.txt
203
+ shard_num: 2
204
+ weight: 1
205
+ S2SChat_poem_1_2_6_3_content_data_150num_by_gjli_added_by_20250622:
206
+ huawei_path: /mnt/sfs/asr/update_data/S2SChat_poem_1_2_6_3_content_data_150num_by_gjli_added_by_20250622/shards_list.txt
207
+ new_data_list: /home/work_nfs16/gjli/workspaces/poem/6.3/poem_1_2_6-3_content_data.list
208
+ new_lab_path: /home/work_nfs11/cywang/data/shard/S2Chat/poem_1_2_6-3_content_data/shards_list.txt
209
+ path: /home/A02_tmpdata3/osum_s2s/S2SChat_poem_1_2_6_3_content_data_150num_by_gjli_added_by_20250622/shards_list.txt
210
+ shard_num: 9
211
+ weight: 1
212
+ S2SChat_poem_500_author_data_new_by_gjli_added_by_20250622:
213
+ huawei_path: /mnt/sfs/asr/update_data/S2SChat_poem_500_author_data_new_by_gjli_added_by_20250622/shards_list.txt
214
+ new_data_list: /home/work_nfs16/gjli/workspaces/poem/poem_500_author_data_new.list
215
+ new_lab_path: /home/work_nfs11/cywang/data/shard/S2Chat/poem_500_author_data_new/shards_list.txt
216
+ path: /home/A02_tmpdata3/osum_s2s/S2SChat_poem_500_author_data_new_by_gjli_added_by_20250622/shards_list.txt
217
+ shard_num: 4
218
+ weight: 1
219
+ S2SChat_poem_500_content_data_new_by_gjli_added_by_20250622:
220
+ huawei_path: /mnt/sfs/asr/update_data/S2SChat_poem_500_content_data_new_by_gjli_added_by_20250622/shards_list.txt
221
+ new_data_list: /home/work_nfs16/gjli/workspaces/poem/poem_500_content_data_new.list
222
+ new_lab_path: /home/work_nfs11/cywang/data/shard/S2Chat/poem_500_content_data_new/shards_list.txt
223
+ path: /home/A02_tmpdata3/osum_s2s/S2SChat_poem_500_content_data_new_by_gjli_added_by_20250622/shards_list.txt
224
+ shard_num: 4
225
+ weight: 1
226
+
conf/data_s2t.yaml ADDED
@@ -0,0 +1,402 @@
1
+
2
+ # age gender,
3
+ age_gender_common:
4
+ path: /home/A02_tmpdata3/osum_s2s/gender/xlgeng_new_data/s2t_thinking/doubao/combines_list.txt
5
+ tar_num: 1511
6
+
7
+ gender_xianshi:
8
+ path: /home/A02_tmpdata3/osum_s2s/sex_xianshi_cosyvoice2_by_cywang_added_by_20250625/raw_data/s2s_handle/xlgeng_new_data/s2t_thinking/doubao/combines_list.txt
9
+ tar_num: 30
10
+
11
+ gender_yinshi_3k:
12
+ path: /home/A02_tmpdata3/osum_s2s/S2SChat_sex_yinshi_7_4_osum_by_cywang_added_by_20250708/raw_data/s2s_handle/xlgeng_new_data/s2t_thinking/doubao/combines_list.txt
13
+ tar_num: 3
14
+
15
+ gender_yinshi_5k:
16
+ path: /home/A02_tmpdata3/osum_s2s/S2SChat_sex_yinshi_5000_6_13_data_by_gjli_added_by_20250622/raw_data/s2s_handle/xlgeng_new_data/s2t_thinking/doubao/combines_list.txt
17
+ tar_num: 6
18
+
19
+ age_xianshi:
20
+ path: /home/A02_tmpdata3/osum_s2s/age_xianshi_cosyvoice2_by_cywang_added_by_20250625/raw_data/s2s_handle/xlgeng_new_data/s2t_thinking/doubao/combines_list.txt
21
+ tar_num: 25
22
+
23
+
24
+ # caption
25
+ caption_common_7label:
26
+ path: /home/A02_tmpdata3/osum_s2s/caption/raw_data/xlgeng_new_data/s2t_thinking/doubao/combines_list.txt
27
+ tar_num: 162
28
+
29
+ caption_common_50_label:
30
+ path: /home/A02_tmpdata3/osum_s2s/caption_add_2025_1_6/raw_data/s2s_data_with_gender/xlgeng_new_data/s2t_thinking/doubao/combines_list.txt
31
+ tar_num: 395 # 实际是196k
32
+
33
+ caption_xianshi:
34
+ path: /home/A02_tmpdata3/osum_s2s/caption_s2s_xianshi_20250806/raw_data/s2s_data/xlgeng_new_data/s2t_thinking/doubao/combines_list.txt
35
+ tar_num: 6
36
+ weight: 10
37
+
38
+
39
+ # emotion
40
+ emotion_100K_sensevoice:
41
+ path: /home/A02_tmpdata3/osum_s2s/emotion_yinshi_zxzhao_with_q_emo_by_cywang_added_by_20250701/handle_data/s2s_handle/xlgeng_new_data/s2t_thinking/doubao/combines_list.txt
42
+ tar_num: 107
43
+ weight: 10
44
+
45
+ emotion_30K_sensevoice:
46
+ path: /home/A02_tmpdata3/emotion/中英混多音色情感数据库/s2s_handle/xlgeng_new_data/s2t_thinking/doubao/combines_list.txt
47
+ tar_num: 33
48
+ weight: 10
49
+
50
+ S2SChat_osum_setting_qa_527_updated_by_cywang_added_by_20250616_think:
51
+ path: /home/A02_tmpdata3/osum_s2s/S2SChat_osum_setting_qa_527_updated_by_cywang_added_by_20250616/raw_data/s2s_handle/xlgeng_new_data/s2t_thinking/doubao/combines_list.txt
52
+ shard_num: 8
53
+ weight: 10
54
+
55
+ # ======================================s2s 副语言 thinking end=====================================
56
+
57
+
58
+
59
+
60
+ # ===========================副语言 s2s no thinking ===================================
61
+
62
+ # age gender,
63
+ age_gender_common_no_thinking:
64
+ path: /home/A02_tmpdata3/osum_s2s/gender/xlgeng_new_data/s2t_no_thinking/doubao/combines_list.txt
65
+ tar_num: 1511
66
+
67
+ gender_xianshi_no_thinking:
68
+ path: /home/A02_tmpdata3/osum_s2s/sex_xianshi_cosyvoice2_by_cywang_added_by_20250625/raw_data/s2s_handle/xlgeng_new_data/s2t_no_thinking/doubao/combines_list.txt
69
+ tar_num: 30
70
+
71
+ gender_yinshi_3k_no_thinking:
72
+ path: /home/A02_tmpdata3/osum_s2s/S2SChat_sex_yinshi_7_4_osum_by_cywang_added_by_20250708/raw_data/s2s_handle/xlgeng_new_data/s2t_no_thinking/doubao/combines_list.txt
73
+ tar_num: 3
74
+
75
+ gender_yinshi_5k_no_thinking:
76
+ path: /home/A02_tmpdata3/osum_s2s/S2SChat_sex_yinshi_5000_6_13_data_by_gjli_added_by_20250622/raw_data/s2s_handle/xlgeng_new_data/s2t_no_thinking/doubao/combines_list.txt
77
+ tar_num: 6
78
+
79
+ age_xianshi_no_thinking:
80
+ path: /home/A02_tmpdata3/osum_s2s/age_xianshi_cosyvoice2_by_cywang_added_by_20250625/raw_data/s2s_handle/xlgeng_new_data/s2t_no_thinking/doubao/combines_list.txt
81
+ tar_num: 25
82
+
83
+
84
+ # caption
85
+ caption_common_7label_no_thinking:
86
+ path: /home/A02_tmpdata3/osum_s2s/caption/raw_data/xlgeng_new_data/s2t_no_thinking/doubao/combines_list.txt
87
+ tar_num: 162
88
+
89
+ caption_common_50_label_no_thinking:
90
+ path: /home/A02_tmpdata3/osum_s2s/caption_add_2025_1_6/raw_data/s2s_data_with_gender/xlgeng_new_data/s2t_no_thinking/doubao/combines_list.txt
91
+ tar_num: 395 # actually 196k
92
+
93
+ caption_xianshi_no_thinking:
94
+ path: /home/A02_tmpdata3/osum_s2s/caption_s2s_xianshi_20250806/raw_data/s2s_data/xlgeng_new_data/s2t_no_thinking/doubao/combines_list.txt
95
+ tar_num: 6
96
+
97
+
98
+ # emotion
99
+ emotion_100K_sensevoice_no_thinking:
100
+ path: /home/A02_tmpdata3/osum_s2s/emotion_yinshi_zxzhao_with_q_emo_by_cywang_added_by_20250701/handle_data/s2s_handle/xlgeng_new_data/s2t_no_thinking/doubao/combines_list.txt
101
+ tar_num: 107
102
+ weight: 10
103
+
104
+ emotion_30K_sensevoice_no_thinking:
105
+ path: /home/A02_tmpdata3/emotion/中英混多音色情感数据库/s2s_handle/xlgeng_new_data/s2t_no_thinking/doubao/combines_list.txt
106
+ tar_num: 33
107
+ weight: 10
108
+
109
+ S2SChat_osum_setting_qa_527_updated_by_cywang_added_by_20250616:
110
+ path: /home/A02_tmpdata3/osum_s2s/S2SChat_osum_setting_qa_527_updated_by_cywang_added_by_20250616/raw_data/s2s_handle/xlgeng_new_data/s2t_no_thinking/doubao/combines_list.txt
111
+ shard_num: 8
112
+ weight: 10
113
+
114
+ # ------------------------------------------- s2s paralinguistic, no thinking end -------------------------------------------
115
+
116
+
117
+
118
+ # Knowledge QA
119
+ S2SChat_syndata_merged_by_300W_zhguo_added_by_20250616:
120
+ path: /home/A02_tmpdata3/osum_s2s/S2SChat_syndata_merged_by_300W_zhguo_added_by_20250616/combines_data_s2t/combines_list.txt
121
+ tar_num: 3000
122
+ S2SChat_osum_total_data_lst_check_final_100W_by_zhguo_added_by_20250616:
123
+ path: /home/A02_tmpdata3/osum_s2s/S2SChat_osum_total_data_lst_check_final_100W_by_zhguo_added_by_20250616/combines_data_s2t/combines_list.txt
124
+ tar_num: 1000
125
+ # ====================================== s2t paralinguistic, no thinking end ==========================
126
+
127
+
128
+
129
+ # Speech understanding ==========================================
130
+ asr:
131
+ huawei_path: "/mnt/sfs/asr/asr/shards_list.txt" # 2.4
132
+ lab_path: "/home/node54_tmpdata/xlgeng/asr_data_2w/shards_list.txt"
133
+ path: "/home/A03_tmpdata1/s2s/asr_data_2.4w/asr_data_2w/shards_list.txt"
134
+ shard_num: 15477
135
+ weight: 0.1 # ~10000h
136
+
137
+
138
+ # =========== Understanding tasks ==============================================
139
+ librispeech:
140
+ huawei_path: "/mnt/sfs/asr/update_data/LibriSpeech_shard_common/shards_list.txt" #1000h
141
+ lab_path: "/home/work_nfs15/asr_data/data/LibriSpeech/LibriSpeech_shard_common/shards_list.txt"
142
+ path: "/home/A03_tmpdata3/asr_data/librispeech/shards_list.txt"
143
+ shard_num: 282
144
+ weight: 1
145
+ mix_asru200_add_2025_2_14:
146
+ huawei_path: "/mnt/sfs/asr/update_data/mix_asru200_add_2025_2_14/shards_list.txt" # 200
147
+ path: "/home/A03_tmpdata1/s2s/asru700/train/shards_list.txt"
148
+ lab_path: "/home/work_nfs15/asr_data/data/ASRU700/train/shards_list.txt" # Chinese-English code-switched data; words are separated by spaces
149
+ shard_num: 187
150
+ weight: 1
151
+
152
+
153
+
154
+
155
+
156
+ caption:
157
+ path: "/home/A02_tmpdata3/osum_s2s/caption/shards_list.txt"
158
+ huawei_path: "/mnt/sfs/asr/update_data/caption/shards_list.txt" # 319h
159
+ lab_path: "/home/node54_tmpdata2/data4understand/update_data/caption/shards_list.txt" # a concatenation of the caption AudioSet data and AISHELL-2
160
+ shard_num: 319
161
+ weight: 0.5
162
+ caption_add_2025_1_6:
163
+ path: "/home/A02_tmpdata3/osum_s2s/caption_add_2025_1_6/shards_list.txt"
164
+ lab_path: "/home/work_nfs7/yacao/0106_twj_shard/shards_0306/add_label/shards_list.txt"
165
+ huawei_path: "/mnt/sfs/asr/update_data/caption_2025_1_6_newadd/shards_list.txt" # 130h
166
+ shard_num: 392
167
+ weight: 0.5
168
+ caption_aslp_add_2025_1_15:
169
+ path: "/home/A02_tmpdata3/osum_s2s/caption_aslp_add_2025_1_15/shards_list.txt"
170
+ huawei_path: "/mnt/sfs/asr/update_data/caption_aslp_add_2025_1_15/shards_list.txt" # 5h
171
+ shard_num: 5
172
+ lab_path: "/home/work_nfs9/yacao/nfs7_copy/yacao/shard/0114_wjtian_simu2/aslp_caption_train/shards_list.txt"
173
+ weight: 5
174
+
175
+
176
+
177
+ # 50-class caption
178
+ s2t_caption_50label:
179
+ shard_num: 392
180
+ path: "/home/A02_tmpdata3/osum_s2s/s2t_caption_50label/shards_list.txt"
181
+ lab_path: "/home/work_nfs7/yacao/0106_twj_shard/shards_0306/add_label/shards_list.txt"
182
+ huawei_path: "/mnt/sfs/asr/update_data/0106_twj_shard_caption_50label_add_by_2025_3_10/shards_list.txt" # 392tar
183
+ weight: 0.5 # 10
184
+
185
+
186
+
187
+
188
+ emotion: # incomplete, 312 tars
189
+ path: "/home/A02_tmpdata3/osum_s2s/emotion/shards_list.txt"
190
+ lab_path: "/home/xlgeng/sdb2/emotion/shards_list.txt"
191
+ huawei_path: "/mnt/sfs/asr/emotion/shards_list.txt"
192
+ shard_num: 370
193
+ weight: 0.5 # 538h
194
+ emotion_stage2_add:
195
+ path: "/home/A02_tmpdata3/osum_s2s/emotion_stage2_add/shards_list.txt"
196
+ lab_path: "/home/xlgeng/sdb2/emotion_stage2_add/shards_list.txt"
197
+ huawei_path: "/mnt/sfs/asr/emotion_stage2_add/shards_list.txt"
198
+ shard_num: 44
199
+ weight: 0.1 # 150h
200
+ emotion_stage3_add:
201
+ path: "/home/A02_tmpdata3/osum_s2s/emotion_stage3_add/shards_list.txt"
202
+ lab_path: "/home/xlgeng/sdb2/emotion_stage3_add/shards_list.txt"
203
+ huawei_path: "/mnt/sfs/asr/emotion_stage3_add/shards_list.txt"
204
+ shard_num: 53
205
+ weight: 0.1 # 138h
206
+ emotion_stage4_add:
207
+ path: "/home/A02_tmpdata3/osum_s2s/emotion_stage4_add/shards_list.txt"
208
+ lab_path: "/home/xlgeng/sdb2/emotion_stage4_add/shards_list.txt"
209
+ huawei_path: "/mnt/sfs/asr/emotion_stage4_add/shards_list.txt"
210
+ shard_num: 54
211
+ weight: 0.1 #100h
212
+ emotion_stage5_add:
213
+ path: "/home/A02_tmpdata3/osum_s2s/emotion_stage5_add/shards_list.txt"
214
+ lab_path: "/home/xlgeng/sdb2/emotion_stage5_add/shards_list.txt"
215
+ shard_num: 53
216
+ huawei_path: "/mnt/sfs/asr/emotion_stage5_add/shards_list.txt"
217
+ weight: 0.1
218
+
219
+ emotion_meld:
220
+ path: "/home/A02_tmpdata3/osum_s2s/emotion_meld/shards_list.txt"
221
+ lab_path: "/home/xlgeng/sdb2/emotion_meld/shards_list.txt"
222
+ huawei_path: "/mnt/sfs/asr/update_data/emotion_meld/shards_list.txt" # 8h
223
+ shard_num: 9
224
+ weight: 1
225
+ #emotion_dis_fear_add_2025_1_15:
226
+ # huawei_path: "/mnt/sfs/asr/update_data/emotion_dis_fear_add_2025_1_15/shards_list.txt"
227
+ # weight: 0
228
+
229
+ emotion_lucy_Q_added_2025_4_9:
230
+ path: "/home/A02_tmpdata3/osum_s2s/s2s_lucy_Q_emotion/shards_list.txt"
231
+ shard_num: 121
232
+ lab_path: "/home/work_nfs11/cywang/data/shard/emotion/QEmo_Q_train/shards_list.txt"
233
+ huawei_path: "/mnt/sfs/asr/update_data/emotion_lucy_Q_added_2025_4_9/shards_list.txt"
234
+ weight: 0.5
235
+
236
+
237
+
238
+ Age_with_noize_add_2025_2_4: # incomplete, only 245 so far
239
+ path: "/home/A02_tmpdata3/osum_s2s/age_3000_noize/shards_list.txt"
240
+ lab_path: "/home/work_nfs6/syliu/for_gxl/Age/simu_age/shards_list.txt"
241
+ shard_num: 2720
242
+ huawei_path: "/mnt/sfs/asr/update_data/Age_with_noize_add_2025_2_4/shards_list.txt"
243
+ weight: 0.1
244
+ age:
245
+ path: "/home/A02_tmpdata3/osum_s2s/age_3000/shards_list.txt"
246
+ lab_path: "/home/work_nfs3/syliu/for_gxl/Age/age/shards_list.txt"
247
+ huawei_path: "/mnt/sfs/asr/update_data/age/shards_list.txt"
248
+ shard_num: 2820
249
+ weight: 0.1 #1.5 # 3000h
250
+
251
+
252
+ gender: # incomplete, currently 310
253
+ shard_num: 1738
254
+ lab_path: "/home/xlgeng/sdb2/gender/shards_list.txt"
255
+ huawei_path: "/mnt/sfs/asr/update_data/sex/shards_list.txt" # 3000
256
+ path: "/home/A02_tmpdata3/osum_s2s/gender/shards_list.txt"
257
+ weight: 0.1 #1.5
258
+ gender_add_2025_1_6_kaggle: # complete
259
+ shard_num: 116
260
+ path: "/home/A02_tmpdata3/osum_s2s/gender_kaggle/shards_list.txt"
261
+ lab_path: "/home/work_nfs3/syliu/for_gxl/new_gender/Sex/sex/shards_list.txt"
262
+ huawei_path: "/mnt/sfs/asr/update_data/sex_2025_1_6_newadd/shards_list.txt" # 107h, kaggle
263
+ weight: 0.1 #3
264
+ gender_add_2025_2_4_fix: # 2100 tars # incomplete, 365
265
+ path: "/home/A02_tmpdata3/osum_s2s/gender_add_2025_2_4_fix/shards_list.txt"
266
+ shard_num: 2140
267
+ lab_path: "/home/work_nfs6/xlgeng/for_gxl/gender_add_2025_2_4_fix/shards_list.txt"
268
+ huawei_path: "/mnt/sfs/asr/update_data/gender_add_2025_2_4_fix/shards_list.txt"
269
+ weight: 0.1
270
+ gender_with_noize_add_2025_2_4: # 1500h, 780 tars # incomplete, 266
271
+ path: "/home/A02_tmpdata3/osum_s2s/gender_with_noize_add_2025_2_4/shards_list.txt"
272
+ lab_path: "/home/work_nfs6/xlgeng/for_gxl/gender_with_noize_add_2025_2_4/shards_list.txt"
273
+ huawei_path: "/mnt/sfs/asr/update_data/gender_with_noize_add_2025_2_4/shards_list.txt"
274
+ shard_num: 780
275
+ weight: 0.1
276
+
277
+
278
+
279
+
280
+ age_gender_stage2_add:
281
+ path: "/home/A02_tmpdata3/osum_s2s/age_gender_stage2_add/shards_list.txt"
282
+ lab_path: "/home/xlgeng/sdb2/age_gender_stage2_add/shards_list.txt"
283
+ huawei_path: "/mnt/sfs/asr/update_data/Speech_Age_Sex/shards_list.txt"
284
+ weight: 0.1 # 174h
285
+
286
+ age_gender_add_2025_1_13:
287
+ path: "/home/A02_tmpdata3/osum_s2s/age_gender_add_2025_1_13/shards_list.txt"
288
+ lab_path: "/home/work_nfs3/syliu/for_gxl/Age_Sex/age_sex/shards_list.txt"
289
+ huawei_path: "/mnt/sfs/asr/update_data/Speech_Age_Sex_add_2025_1_13/shards_list.txt"
290
+ weight: 0.1 #2571h
291
+
292
+ style_age_gender_stage3_add:
293
+ path: "/home/A02_tmpdata3/osum_s2s/style_age_gender_stage3_add/shards_list.txt"
294
+ lab_path: "/home/xlgeng/sdb2/style_age_gender_stage3_add/shards_list.txt"
295
+ huawei_path: "/mnt/sfs/asr/update_data/Speech_Style_Age_Sex/shards_list.txt"
296
+ weight: 0.1 # 85h
297
+
298
+
299
+ age_gender_pure_stage3_add:
300
+ path: "/home/A02_tmpdata3/osum_s2s/age_gender_pure_stage3_add/shards_list.txt"
301
+ lab_path: "/home/xlgeng/sdb2/age_gender_pure_stage3_add/shards_list.txt"
302
+ huawei_path: "/mnt/sfs/asr/update_data/Age_Sex/shards_list.txt"
303
+ weight: 0.1 # 174h
304
+
305
+
306
+ style_age_gender_pure_stage3_add:
307
+ path: "/home/A02_tmpdata3/osum_s2s/style_age_gender_pure_stage3_add/shards_list.txt"
308
+ lab_path: "/home/xlgeng/sdb2/style_age_gender_pure_stage3_add/shards_list.txt"
309
+ huawei_path: "/mnt/sfs/asr/update_data/Style_Age_Sex/shards_list.txt"
310
+ weight: 0.1 # 85h
311
+
312
+
313
+
314
+
315
+ # Multi-task, caption
316
+ merged_output_caption_age_gender_add_2025_2_26:
317
+ path: "/home/A02_tmpdata3/osum_s2s/merged_output_caption_age_gender_add_2025_2_26/shards_list.txt"
318
+ lab_path: "/home/work_nfs7/yacao/0106_twj_shard/shards_0226/merged_output/shards_list.txt"
319
+ huawei_path: "/mnt/sfs/asr/update_data/multi_task/caption_new/merged_output/shards_list.txt"
320
+ weight: 0.1
321
+ nfs10_time1_output_caption_age_gender_add_2025_2_26:
322
+ path: "/home/A02_tmpdata3/osum_s2s/nfs10_time1_output_caption_age_gender_add_2025_2_26/shards_list.txt"
323
+ lab_path: "/home/work_nfs7/yacao/0106_twj_shard/shards_0226/nfs10_time1/shards_list.txt"
324
+ huawei_path: "/mnt/sfs/asr/update_data/multi_task/caption_new/nfs10_time1/shards_list.txt"
325
+ weight: 0.1
326
+ other_20000_caption_age_gender_add_2025_2_26:
327
+ path: "/home/A02_tmpdata3/osum_s2s/other_20000_caption_age_gender_add_2025_2_26/shards_list.txt"
328
+ lab_path: "/home/work_nfs7/yacao/0106_twj_shard/shards_0226/other_20000/shards_list.txt"
329
+ huawei_path: "/mnt/sfs/asr/update_data/multi_task/caption_new/other_20000/shards_list.txt"
330
+ weight: 0.1
331
+ simu9_1227_caption_age_gender_add_2025_2_26:
332
+ path: "/home/A02_tmpdata3/osum_s2s/simu9_1227_caption_age_gender_add_2025_2_26/shards_list.txt"
333
+ lab_path: "/home/work_nfs7/yacao/0106_twj_shard/shards_0226/simu9_1227/shards_list.txt"
334
+ huawei_path: "/mnt/sfs/asr/update_data/multi_task/caption_new/simu9_1227/shards_list.txt"
335
+ weight: 0.1
336
+
337
+
338
+ # Multi-task, emotion
339
+ merged_output_emotion_age_gender_add_2025_3_2:
340
+ path: "/home/A02_tmpdata3/osum_s2s/merged_output_emotion_age_gender_add_2025_3_2/shards_list.txt"
341
+ lab_path: "/home/work_nfs16/emotion_data/OSUM_age_gender/emotion_age_gender1/shards_list.txt"
342
+ huawei_path: "/mnt/sfs/asr/update_data/multi_task/emotion_age_gender1/shards_list.txt"
343
+ weight: 0.1
344
+
345
+ merged_output_emotion_age_gender_add_2025_3_2_di2pi:
346
+ path: "/home/A02_tmpdata3/osum_s2s/merged_output_emotion_age_gender_add_2025_3_2_di2pi/shards_list.txt"
347
+ shard_num: 181
348
+ lab_path: "/home/work_nfs16/emotion_data/OSUM_age_gender/emotion_age_gender2/shards_list.txt"
349
+ huawei_path: ""
350
+ weight: 0.1
351
+
352
+
353
+ # Multi-task, style
354
+ merged_output_style_age_gender_add_2025_3_2:
355
+ path: "/home/A02_tmpdata3/osum_s2s/merged_output_style_age_gender_add_2025_3_2/shards_list.txt"
356
+ lab_path: "/home/node54_tmpdata2/gjli/style_age_gender_data/style_labeling_100wto200w_part1_age_gender/shards_list.txt"
357
+ shard_num: 107
358
+ huawei_path: "/mnt/sfs/asr/update_data/multi_task/style_labeling_100wto200w_part1_age_gender/shards_list.txt"
359
+ weight: 0.1
360
+ merged_output_style_origin_tts_age_gender_add_2025_3_2:
361
+ path: "/home/A02_tmpdata3/osum_s2s/merged_output_style_origin_tts_age_gender_add_2025_3_2/shards_list.txt"
362
+ lab_path: "/home/node54_tmpdata2/gjli/style_age_gender_data/style_origin_tts_age_gender/shards_list.txt"
363
+ huawei_path: "/mnt/sfs/asr/update_data/multi_task/style_origin_tts_age_gender/shards_list.txt"
364
+ weight: 0.1
365
+ style_labeling_100wto200w_part1_age_gender_emotion_gjli:
366
+ path: "/home/A02_tmpdata3/osum_s2s/style_labeling_100wto200w_part1_age_gender_emotion_gjli/shards_list.txt"
367
+ lab_path: "/home/node54_tmpdata2/gjli/style_labeling_100wto200w_part1_age_gender_emotion/shards_list.txt" # 107
368
+ huawei_path: "/mnt/sfs/asr/update_data/style_labeling_100wto200w_part1_age_gender_emotion/shards_list.txt" #107tar
369
+ weight: 0.5
370
+ style_labeling_200wto300w_part1_age_gender_emotion_gjli:
371
+ path: "/home/A02_tmpdata3/osum_s2s/style_labeling_200wto300w_part1_age_gender_emotion_gjli/shards_list.txt"
372
+ lab_path: "/home/node54_tmpdata2/gjli/style_labeling_200wto300w_part2/shards_list.txt"
373
+ shard_num: 236
374
+ huawei_path: "_"
375
+
376
+ age_gender_style_emotion1_add_2025_3_29_zxzhao:
377
+ path: "/home/A02_tmpdata3/osum_s2s/age_gender_style_emotion1_add_2025_3_29_zxzhao/shards_list.txt"
378
+ lab_path: "/home/work_nfs16/emotion_data/OSUM_age_gender/age_gender_style_emotion1/shards_list.txt"
379
+ huawei_path: "/mnt/sfs/asr/update_data/age_gender_style_emotion1_add_2025_3_29_zxzhao/shards_list.txt" # 256tar
380
+ weight: 0.5
381
+
382
+
383
+ 5_label_caption_age_gender_style_emotion_added_2025_3_29_yacao:
384
+ path: "/home/A02_tmpdata3/osum_s2s/5_label_caption_age_gender_style_emotion_added_2025_3_29_yacao/shards_list.txt"
385
+ huawei_path: "/mnt/sfs/asr/update_data/5_label_caption_age_gender_style_emotion_added_2025_3_29_yacao/shards_list.txt" #270tar
386
+ lab_path: "/home/work_nfs7/yacao/0320_multilabel_2/shard/5_label/shards_list.txt"
387
+ weight: 0.5
388
+
389
+ # audio description data
390
+ audio_caption_by_wjtian_added_by_20250414: # actually 20250411; the date was mistyped because of auto-completion
391
+ path: "/home/A02_tmpdata3/osum_s2s/audio_caption_by_wjtian_added_by_20250414/shards_list.txt"
392
+ lab_path: "/home/work_nfs7/cywang/OSUM/OSUM_data/shard/audio_caption/audio_caption/shards_list.txt"
393
+ huawei_path: "/mnt/sfs/asr/update_data/audio_caption_by_wjtian_added_by_20250414/shards_list.txt" # 2155 tar
394
+ weight: 0.15 # start of the audio_caption (audio description) data prepared by wjtian
395
+
396
+
397
+ S2SChat_MMAU_training_all_by_wjtian_added_by_20250708:
398
+ path: "/home/A02_tmpdata3/osum_s2s/S2SChat_MMAU_training_all_by_wjtian_added_by_20250708/shards_list.txt"
399
+ lab_path: "/home/work_nfs11/cywang/data/shard/S2Chat/MMAU-training-all/shards_list.txt"
400
+ huawei_path: "/mnt/sfs/asr/update_data/S2SChat_MMAU_training_all_by_wjtian_added_by_20250708/shards_list.txt" # 1000 tar
401
+ shard_num: 22
402
+ weight: 5
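Each entry above follows the same schema: `path` points at a `shards_list.txt`/`combines_list.txt`, `tar_num`/`shard_num` records how many tar shards the list contains, and `weight` scales how often the set is sampled during training. Below is a minimal sketch of how such a YAML could be expanded into a weighted shard list; the helper name is illustrative only, and the repo's real loader (common_utils/load_combine_type_yaml.py) may work differently.

# Hypothetical sketch: expand a data YAML like the one above into a weighted shard list.
import random
import yaml

def build_weighted_shard_list(yaml_path: str, seed: int = 0) -> list:
    with open(yaml_path, "r", encoding="utf-8") as f:
        datasets = yaml.safe_load(f)
    rng = random.Random(seed)
    sampled = []
    for name, cfg in datasets.items():
        list_path = cfg.get("path")
        weight = float(cfg.get("weight", 1.0))
        if not list_path:
            continue
        with open(list_path, "r", encoding="utf-8") as f:
            shards = [line.strip() for line in f if line.strip()]
        # weight 0.1 keeps roughly 10% of the shards; weight 10 repeats every shard 10 times
        repeats, frac = divmod(weight, 1.0)
        kept = shards * int(repeats) + [s for s in shards if rng.random() < frac]
        sampled.extend((name, s) for s in kept)
    rng.shuffle(sampled)
    return sampled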
conf/data_t2s.yaml ADDED
@@ -0,0 +1,28 @@
1
+ TEXT2TOKEN_hq_add_2025_3_17:
2
+ lab_path: "/home/node48_tmpdata/hkxie/4O/speech_data_final/10wh_token_data/TEXT2TOKEN_hq/shards_list.txt"
3
+ path_huawei: "/mnt/sfs/asr/update_data/TEXT2TOKEN_hq_add_2025_3_17/shards_list.txt"
4
+ path: /home/A03_tmpdata2/s2s/2000_hq_S2Chat_by_hkxie_added_by_20250411/combine_tts/combines_list.txt
5
+ weight: 5
6
+
7
+ #english_text_token_add_2025_3_26:
8
+ # path_huawei: "/mnt/sfs/asr/update_data/english_speech_data_final_TEXT2TOKEN_part_1_added_2025_3_26/shards_list.txt" # 2000
9
+ # data_list_path: "/home/work_nfs14/code/hkxie/ASR/understanding_LLM_task/english/speech_data_final/data_libriheavy_part_1.list"
10
+ # lab_path: "?"
11
+ # weight: 0.5 #1 #1 #10
12
+ #english_TEXT2TOKEN_part_2_added_by_20250402:
13
+ # path_huawei: "/mnt/sfs/asr/update_data/english_TEXT2TOKEN_part_2_added_by_20250402/shards_list.txt" #8050
14
+ # data_list_path: "/home/work_nfs14/code/hkxie/ASR/understanding_LLM_task/english/speech_data_final/data_libriheavy_part_2.list"
15
+ # lab_path: "?"
16
+ # weight: 0.5 #1 #1 #10
17
+ #
18
+ #zh_en_mix_tts_added_by_20250402: # tts
19
+ # path_huawei: "/mnt/sfs/asr/update_data/zh_en_mix_s2s_added_by_20250402/shards_list.txt" # 7
20
+ # weight: 1
21
+ #poly_tts_added_by_20250402: # tts
22
+ # path: "/mnt/sfs/asr/update_data/poly_s2s_added_by_20250402/shards_list.txt" # 295
23
+ # weight: 0.5 #10
24
+ #
25
+ #text2token_itn_by_cywang_added_by_20250428: # owned by zyzhang, packaged by cywang
26
+ # lab_path: "/home/work_nfs7/cywang/OSUM/OSUM_data/shard/text2token/tn/shards_list.txt"
27
+ # path_huawei: "/mnt/sfs/asr/update_data/text2token_itn_by_cywang_added_by_20250428/shards_list.txt" # 1100
28
+ # weight: 0.5
conf/data_t2t.yaml ADDED
@@ -0,0 +1,159 @@
1
+ # Text-to-text
2
+ #text2text_added_2025_4_4:
3
+ # path_huawei: "/mnt/sfs/asr/update_data/text2text_added_by_20250404/shards_list.txt" # 1850
4
+ # weight: 1
5
+ #text2text_2_added_by_20250409:
6
+ # path_huawei: "/mnt/sfs/asr/update_data/text2text_2_added_by_20250409/shards_list.txt" # 2000
7
+ # weight: 1
8
+ #
9
+ #text2text_3_en_added_by_20250411:
10
+ # path_huawei: "/mnt/sfs/asr/update_data/text2text_3_en_added_by_20250411/shards_list.txt" # 185
11
+ # weight: 1
12
+ #
13
+ #text2text_4_en_added_by_20250416:
14
+ # path_huawei: "/mnt/sfs/asr/update_data/text2text_4_en_added_by_20250416/shards_list.txt"
15
+ # weight: 1
16
+
17
+ #text2text_5_lucy_audioQA_1M_by_cywang_added_by_20250426:
18
+ # shard_num: 10000
19
+ # path_huawei: "/mnt/sfs/asr/update_data/text2text_5_lucy_audioQA_1M_by_cywang_added_by_20250426/shards_list.txt"
20
+ # weight: 0.1
21
+
22
+ t2t_8772K_by_xlgeng_added_by_20250513:
23
+ path: "/home/A03_tmpdata1/text2text_data_xlgeng/t2t_8772K/shards_list.txt"
24
+ path_huawei: "/mnt/sfs/asr/update_data/t2t_8772K_by_xlgeng_added_by_20250513/shards_list.txt"
25
+ weight: 0.1
26
+
27
+ #t2t_math_poetry_self_by_xlgeng_added_by_20250513:
28
+ # path: "/home/A03_tmpdata1/text2text_data_xlgeng/t2t_math_poetry_self/shards_list.txt"
29
+ # path_huawei: "/mnt/sfs/asr/update_data/t2t_math_poetry_self_by_xlgeng_added_by_20250513/shards_list.txt" # 75
30
+ # weight: 1
31
+
32
+ Alpaca_CoT_3000W_by_wjt_added_by_20250605:
33
+ lab_path_huawei: ""
34
+ shard_num: 30000
35
+ path_huawei: "/mnt/sfs/asr/update_data/Alpaca_CoT_3000W_by_wjt_added_by_20250605/shards_list.txt"
36
+ path: "/home/A03_tmpdata1/text2text_data_xlgeng/Alpaca-CoT_3000W/shards_list.txt"
37
+ weight: 0.15
38
+
39
+
40
+ qwenomni_bench_data:
41
+ path: "/home/A02_tmpdata3/osum_t2t/qwenomni_bench_data/shards_list.txt"
42
+ weight: 3
43
+
44
+ three_kingdoms:
45
+ path: "/home/A02_tmpdata3/osum_t2t/three_kingdoms/shards_list.txt"
46
+ weight: 3
47
+
48
+ voicebench_data:
49
+ path: "/home/A02_tmpdata3/osum_t2t/voicebench_data/shards_list.txt"
50
+ weight: 3
51
+
52
+ t2t_osum_self_instruction_8K:
53
+ path: "/home/A02_tmpdata3/osum_t2t/t2t_osum_self_instruction_8K/shards_list.txt"
54
+ weight: 3
55
+
56
+
57
+ #t2t_osum_self_instruction_8K_by_xlgeng_added_by_20250529:
58
+ # path: "/home/A02_tmpdata3/t2t_osum_self_instruction_8K_by_xlgeng_added_by_20250529/shards_list.txt"
59
+ # path_huawei: "/mnt/sfs/asr/update_data/t2t_osum_self_instruction_8K_by_xlgeng_added_by_20250529/shards_list.txt"
60
+ # weight: 5
61
+
62
+ #kouyu_t2t_data_by_xlgeng_added_by_20250622:
63
+ # path: ""
64
+ # shard_num: 1758
65
+ # path_huawei: "/mnt/sfs/asr/update_data/kouyu_t2t_data_by_xlgeng_added_by_20250622s/shards_list.txt"
66
+ # weight: 1
67
+ # 4653
68
+
69
+ #text2text_data_xlgeng_three_kingdoms_by_xlgeng_added_by_20250701:
70
+ # path: "/home/A02_tmpdata3/text2text_data_xlgeng_three_kingdoms_by_xlgeng_added_by_20250701/shards_list.txt"
71
+ # path_huawei: "/mnt/sfs/asr/update_data/text2text_data_xlgeng_three_kingdoms_by_xlgeng_added_by_20250701/shards_list.txt"
72
+ # weight: 1
73
+ # lab_path_huawei: "/home/work_nfs11/asr_data/data/text2text_data_xlgeng/shard/benchdata/three_kingdoms/shard/shards_list.txt"
74
+ # shard_num: 24
75
+ #
76
+ #text2text_data_xlgeng_qwenomni_bench_data_by_xlgeng_added_by_20250701:
77
+ # path: "/home/A02_tmpdata3/text2text_data_xlgeng_qwenomni_bench_data_by_xlgeng_added_by_20250701/shards_list.txt"
78
+ # path_huawei: "/mnt/sfs/asr/update_data/text2text_data_xlgeng_qwenomni_bench_data_by_xlgeng_added_by_20250701/shards_list.txt"
79
+ # weight: 1
80
+ # lab_path_huawei: "/home/work_nfs11/asr_data/data/text2text_data_xlgeng/shard/benchdata/qwenomni_bench_data/shard/shards_list.txt"
81
+ # shard_num: 113
82
+
83
+
84
+ #text2text_data_xlgeng_voicebench_data_by_xlgeng_added_by_20250701:
85
+ # path: "/home/A02_tmpdata3/text2text_data_xlgeng_voicebench_data_by_xlgeng_added_by_20250701/shards_list.txt"
86
+ # path_huawei: "/mnt/sfs/asr/update_data/text2text_data_xlgeng_voicebench_data_by_xlgeng_added_by_20250701/shards_list.txt"
87
+ # weight: 1
88
+ # lab_path_huawei: "/home/work_nfs11/asr_data/data/text2text_data_xlgeng/shard/benchdata/voicebench_data/shard/shards_list.txt"
89
+ # shard_num: 65
90
+ #
91
+ #t2t_age_chat_by_cywang_added_by_20250708: # have
92
+ # path: "/home/A02_tmpdata3/osum_s2s/t2t_age_chat_by_cywang_added_by_20250708/shards_list.txt"
93
+ # path_huawei: "/mnt/sfs/asr/update_data/t2t_age_chat_by_cywang_added_by_20250708/shards_list.txt"
94
+ # lab_path_huawei: "/home/work_nfs11/asr_data/data/osum_data/t2t_paralanguage_chat/age_chat/shard_dir/shards_list.txt"
95
+ # shard_num: 50
96
+ # weight: 1
97
+ #
98
+ #t2t_caption_chat_by_cywang_added_by_20250708: # have
99
+ # path: "/home/A02_tmpdata3/osum_s2s/t2t_caption_chat_by_cywang_added_by_20250708/shards_list.txt"
100
+ # lab_path_huawei: "/home/work_nfs11/asr_data/data/osum_data/t2t_paralanguage_chat/caption_chat/shard_dir/shards_list.txt"
101
+ # path_huawei: "/mnt/sfs/asr/update_data/t2t_caption_chat_by_cywang_added_by_20250708/shards_list.txt"
102
+ # shard_num: 100
103
+ # weight: 1
104
+ #
105
+ #t2t_emotion_chat_by_cywang_added_by_20250708: # have
106
+ # path: "/home/A02_tmpdata3/osum_s2s/t2t_emotion_chat_by_cywang_added_by_20250708/shards_list.txt"
107
+ # lab_path_huawei: "/home/work_nfs11/asr_data/data/osum_data/t2t_paralanguage_chat/emotion_chat/shard_dir/shards_list.txt"
108
+ # path_huawei: "/mnt/sfs/asr/update_data/t2t_emotion_chat_by_cywang_added_by_20250708/shards_list.txt"
109
+ # shard_num: 50
110
+ # weight: 1
111
+ #
112
+ #t2t_sex_chat_by_cywang_added_by_20250708: # have
113
+ # path: "/home/A02_tmpdata3/osum_s2s/t2t_sex_chat_by_cywang_added_by_20250708/shards_list.txt"
114
+ # lab_path_huawei: "/home/work_nfs11/asr_data/data/osum_data/t2t_paralanguage_chat/sex_chat/shard_dir/shards_list.txt"
115
+ # path_huawei: "/mnt/sfs/asr/update_data/t2t_sex_chat_by_cywang_added_by_20250708/shards_list.txt"
116
+ # shard_num: 50
117
+ # weight: 1
118
+ #
119
+ #t2t_xianshi_emotion_chat_by_cywang_added_by_20250711: # no
120
+ # path: "/home/A02_tmpdata3/osum_t2t/t2t_xianshi_emotion_chat/shards_list.txt"
121
+ # lab_path_huawei: "/home/work_nfs11/asr_data/data/osum_data/t2t_paralanguage_chat/xianshi_emotion_chat/shard_dir/shards_list.txt|/home/work_nfs23/asr_data/data/osum_chat/t2t_data/t2t_paralanguage_chat/xianshi_emotion_chat/shards_list.txt"
122
+ # path_huawei: "/mnt/sfs/asr/update_data/t2t_xianshi_emotion_chat_by_cywang_added_by_20250711/shards_list.txt"
123
+ # shard_num: 50
124
+ # weight: 1
125
+ #
126
+ #t2t_sex_chat_2_by_cywang_added_by_20250711: # no
127
+ # path: "/home/A02_tmpdata3/osum_t2t/t2t_sex_chat_2_by_cywang_added_by_20250711/shards_list.txt"
128
+ # lab_path_huawei: "/home/work_nfs11/asr_data/data/osum_data/t2t_paralanguage_chat/sex_chat_2/shard_dir/shards_list.txt|/home/work_nfs23/asr_data/data/osum_chat/t2t_data/t2t_paralanguage_chat/t2t_sex_chat_2_by_cywang_added_by_20250711/shards_list.txt"
129
+ # path_huawei: "/mnt/sfs/asr/update_data/t2t_sex_chat_2_by_cywang_added_by_20250711/shards_list.txt"
130
+ # shard_num: 27
131
+ # weight: 1
132
+ #
133
+ #t2t_age_chat_2_by_cywang_added_by_20250711: # no
134
+ # path: "/home/A02_tmpdata3/osum_t2t/t2t_age_chat_2_by_cywang_added_by_20250711/shards_list.txt"
135
+ # path_huawei: "/mnt/sfs/asr/update_data/t2t_age_chat_2_by_cywang_added_by_20250711/shards_list.txt"
136
+ # lab_path_huawei: "/home/work_nfs11/asr_data/data/osum_data/t2t_paralanguage_chat/age_chat_2/shard_dir/shards_list.txt|/home/work_nfs23/asr_data/data/osum_chat/t2t_data/t2t_paralanguage_chat/t2t_age_chat_2_by_cywang_added_by_20250711/shards_list.txt"
137
+ # shard_num: 27
138
+ # weight: 1
139
+ #
140
+ #t2t_sex_chat_2_by_cywang_added_by_20250715: # no
141
+ # path: "/home/A02_tmpdata3/osum_t2t/t2t_sex_chat_2_by_cywang_added_by_20250715/shards_list.txt"
142
+ # path_huawei: "/mnt/sfs/asr/update_data/t2t_sex_chat_2_by_cywang_added_by_20250715/shards_list.txt"
143
+ # lab_path_huawei: "/home/work_nfs14/asr_data/data/osum_data/t2t_paralanguage_chat/sex_chat_2/shard_dir/shards_list.txt"
144
+ # shard_num: 10
145
+ # weight: 1
146
+ #
147
+ #t2t_age_chat_3_by_cywang_added_by_20250716: # no
148
+ # path: "/home/A02_tmpdata3/osum_t2t/t2t_age_chat_3_by_cywang_added_by_20250716/shards_list.txt"
149
+ # path_huawei: "/mnt/sfs/asr/update_data/t2t_age_chat_3_by_cywang_added_by_20250716/shards_list.txt"
150
+ # lab_path_huawei: "/home/work_nfs14/asr_data/data/osum_data/t2t_paralanguage_chat/age_chat_3/shard_dir/shards_list.txt"
151
+ # shard_num: 10
152
+ # weight: 1
153
+ #
154
+ #t2t_caption_chat_3_by_cywang_added_by_20250716: # have
155
+ # path: "/home/A02_tmpdata3/osum_t2t/t2t_caption_chat_3_by_cywang_added_by_20250716/shards_list.txt"
156
+ # path_huawei: "/mnt/sfs/asr/update_data/t2t_caption_chat_3_by_cywang_added_by_20250716/shards_list.txt"
157
+ # lab_path_huawei: "/home/work_nfs14/asr_data/data/osum_data/t2t_paralanguage_chat/caption_chat_3/shard_dir/shards_list.txt"
158
+ # shard_num: 10
159
+ # weight: 1
conf/data_tmp.yaml ADDED
@@ -0,0 +1,6 @@
 
1
+ gaozhiliang_gbma:
2
+ path: /home/A02_tmpdata3/osum_s2s/gaozhiliang_gbma/shards_list.txt
3
+ new_data_list: /home/node44_tmpdata3/netease/gbma/workspace/osum/data/process/0803/all_data_info.jsonl
4
+ new_lab_path: /home/work_nfs23/asr_data/data/osum_chat/s2s/gaozhiliang_gbma/shards_list.txt
5
+ shard_num: 24
6
+ weight: 10
conf/ds_stage2.json ADDED
@@ -0,0 +1,34 @@
1
+ {
2
+ "train_micro_batch_size_per_gpu": 1,
3
+ "gradient_accumulation_steps": 8,
4
+ "steps_per_print": 10,
5
+ "gradient_clipping": 5,
6
+ "fp16": {
7
+ "enabled": false,
8
+ "auto_cast": true,
9
+ "loss_scale": 0,
10
+ "initial_scale_power": 16,
11
+ "loss_scale_window": 1000,
12
+ "hysteresis": 2,
13
+ "consecutive_hysteresis": false,
14
+ "min_loss_scale": 1
15
+ },
16
+ "bf16": {
17
+ "enabled": true
18
+ },
19
+ "zero_force_ds_cpu_optimizer": false,
20
+ "zero_optimization": {
21
+ "stage": 2,
22
+ "offload_optimizer": {
23
+ "device": "none",
24
+ "pin_memory": true
25
+ },
26
+ "allgather_partitions": true,
27
+ "allgather_bucket_size": 2e8,
28
+ "reduce_scatter": true,
29
+ "reduce_bucket_size": 2e8,
30
+ "contiguous_gradients": false,
31
+ "overlap_comm": false
32
+ },
33
+ "find_unused_parameters": true
34
+ }
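This config runs bf16 training with ZeRO stage 2, a micro-batch of 1 per GPU, and 8 gradient-accumulation steps. A minimal sketch of how a JSON config like this is usually handed to DeepSpeed is shown below; the model and optimizer are placeholders, not the repo's actual training entry point.

# Illustrative only: the model and optimizer below are stand-ins for the real ones.
import deepspeed
import torch
import torch.nn as nn

model = nn.Linear(1024, 1024)  # placeholder for the real speech LLM
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    optimizer=optimizer,
    config="conf/ds_stage2.json",  # the file shown above
)
# effective batch = train_micro_batch_size_per_gpu (1) x gradient_accumulation_steps (8) x world_size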
conf/empty.yaml ADDED
File without changes
conf/prompt_config.yaml ADDED
The diff for this file is too large to render. See raw diff
 
conf/system_prompt.yaml ADDED
@@ -0,0 +1,27 @@
1
+ # qwen_instruct_prompt_pattern_chat_s2t = "<|im_start|>system\nYou are OSUM-chat, a speech-to-text dialogue assistant. You understand both the meaning and paralinguistic cues in speech then respond exclusively with appropriate text.<|im_end|>\n"
2
+ # qwen_instruct_prompt_pattern__chat_t2t = "<|im_start|>system\n<|im_end|>\n"
3
+ # qwen_instruct_prompt_pattern_chat_s2s = "<|im_start|>system\nYou are OSUM-chat, a speech-to-speech dialogue assistant by ASLP Lab. You understand both the meaning and paralinguistic cues in speech then respond with appropriate text and emotionally matching synthetic speech.<|im_end|>\n"
4
+ # qwen_instruct_prompt_pattern_chat_s2s_think = "<|im_start|>system\nYou are OSUM-chat, a speech-to-speech dialogue assistant by ASLP Lab. You understand both the meaning and paralinguistic cues in speech. Before responding, first output your reasoning inside <think>...</think end>, analyzing the user’s words and vocal cues. Then generate a reply with appropriate text and emotionally matched synthetic speech.<|im_end|>\n"
5
+ # qwen_instruct_prompt_pattern_chat_s2s_streaming = "<|im_start|>system\nYou are OSUM-chat, a speech-to-speech dialogue assistant by ASLP Lab. You analyze speech (content + paralinguistic cues) and respond with interleaved text and emotionally-matched synthetic speech.<|im_end|>\n"
6
+ # qwen_instruct_prompt_pattern_chat_s2s_streaming_think = "<|im_start|>system\nYou are OSUM-chat, a speech-to-speech dialogue assistant by ASLP Lab. You analyze speech (both content and paralinguistic cues). Before responding, output your reasoning in <think>...</think end>. Then reply with interleaved text and emotionally matched synthetic speech.<|im_end|>\n"
7
+ # qwen_instruct_prompt_pattern__chat_t2t = "<|im_start|>system\n
8
+
9
+ # qwen_instruct_prompt_pattern_1_understand = "<|im_start|>system\nYou are OSUM-chat, an audio understanding assistant by ASLP Lab. You can transcribe speech accurately and analyze paralinguistic cues to provide precise text responses.<|im_end|>\n"
10
+ # qwen_instruct_prompt_pattern_1_tts = "<|im_start|>system\nYou are OSUM-chat, a speech synthesis assistant by ASLP Lab. You generate natural and fluent speech from text input.<|im_end|>\n"
11
+ # qwen_instruct_prompt_pattern_1_tts_streaming = "<|im_start|>system\nYou are OSUM-chat, a speech synthesis assistant by ASLP Lab. You generate natural speech from text input and output both audio and the original text in interleaved format.<|im_end|>\n"
12
+ # qwen_instruct_prompt_pattern_1_old = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n"
13
+ # # user_start = "<|im_start|>user\n"
14
+ t2t_chat: # <TEXT2TEXT>
15
+ prompt: You are OSUM-chat, a text-to-text dialogue assistant by ASLP Lab. You understand user input in text then respond exclusively with appropriate text.
16
+
17
+ s2t_chat: # <S2TCHAT>
18
+ prompt: You are OSUM-chat, a speech-to-text dialogue assistant by ASLP Lab. You understand both the meaning and paralinguistic cues in speech then respond exclusively with appropriate text.
19
+
20
+ s2t_chat_thinker: # <S2TCHAT> <THINKER>
21
+ prompt: You are OSUM-chat, a thinking-enabled speech-to-text dialogue assistant by ASLP Lab. You not only comprehend the semantic meaning and paralinguistic cues in speech but also engage in deliberate reasoning to process such information. Based on this thinking process, you then respond exclusively with appropriate text.
22
+
23
+ t2s: # <TEXT2TOKEN>
24
+ prompt: You are OSUM-chat, a speech synthesis assistant by ASLP Lab. You generate natural and fluent speech from text input.
25
+
26
+ speech_understanding: # <TRANSCRIBE> <CAPTION> ...
27
+ prompt: You are OSUM-chat, an audio understanding assistant by ASLP Lab. You can transcribe speech accurately and analyze paralinguistic cues to provide precise text responses.
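Each task key above maps to a plain system prompt; judging from the commented-out templates at the top of the file, it is wrapped into a Qwen-style "<|im_start|>system ... <|im_end|>" header at inference time. A small sketch of that expansion follows; the helper name is an assumption, not the repo's API.

# Illustrative sketch: turn a task key from conf/system_prompt.yaml into a Qwen-style system header.
import yaml

def build_system_header(task: str, yaml_path: str = "conf/system_prompt.yaml") -> str:
    with open(yaml_path, "r", encoding="utf-8") as f:
        prompts = yaml.safe_load(f)
    prompt = prompts[task]["prompt"]
    return f"<|im_start|>system\n{prompt}<|im_end|>\n"

# e.g. build_system_header("s2t_chat") reproduces the speech-to-text system prompt above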
patches/cumstom_stop_criteria.py ADDED
@@ -0,0 +1,85 @@
1
+ import torch
2
+ from transformers.generation.logits_process import LogitsProcessor
3
+ from transformers.generation.stopping_criteria import StoppingCriteria
4
+
5
+ class ASRLogitsProcessor(LogitsProcessor):
6
+ def __init__(self, text_token_num: int):
7
+ self.text_token_num = text_token_num
8
+
9
+ def __call__(self, input_ids, scores):
10
+ scores[..., self.text_token_num:] = torch.finfo(scores.dtype).min
11
+ return scores
12
+
13
+ class TTSLogitsProcessor(LogitsProcessor):
14
+ """
15
+ LogitsProcessor for the TTS task: sets the logits of all text positions to negative infinity.
16
+ """
17
+ def __init__(self, text_token_num: int):
18
+ self.text_token_num = text_token_num
19
+
20
+ def __call__(self, input_ids, scores):
21
+ scores[..., :self.text_token_num] = torch.finfo(scores.dtype).min
22
+ return scores
23
+
24
+ class S2SLogitsProcessor(LogitsProcessor):
25
+ """Speech 2 Speech 任务使用的 LogitsProcessor,当前只适用于batch_size=1
26
+
27
+ Args:
28
+ text_token_num (int): size of the text vocabulary; ids at or above it are speech tokens
+ text_eos_id (int): id of the text end-of-sequence token that ends the text phase
29
+ """
30
+ def __init__(self, text_token_num: int, text_eos_id: int):
31
+ self.text_token_num = text_token_num
32
+ self.text_eos_id = text_eos_id
33
+ self.text_phase = True
34
+ def __call__(self, input_ids, scores):
35
+ # print(input_ids.shape) # debug
36
+ assert input_ids.size(0) == 1, "ERROR: S2SSpeechLogitsProcessor only support bs=1 now"
37
+ if self.text_phase:
38
+ scores[..., self.text_token_num:] = torch.finfo(scores.dtype).min
39
+ else:
40
+ scores[..., :self.text_token_num] = torch.finfo(scores.dtype).min
41
+
42
+ if self.text_phase and torch.isin(input_ids, self.text_eos_id).any():
43
+ self.text_phase = False
44
+
45
+ return scores
46
+
47
+ class S2SStopCriteria(StoppingCriteria):
48
+ """Speech 2 Speech 任务使用的 停止条件,当前只适用于batch_size=1
49
+
50
+ Args:
51
+ text_eos_id (int): id of the text end-of-sequence token
+ speech_eos_id (int): id of the speech end-of-sequence token
52
+ """
53
+ def __init__(self, text_eos_id: int, speech_eos_id: int):
54
+ self.text_eos_id = text_eos_id
55
+ self.speech_eos_id = speech_eos_id
56
+
57
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
58
+ _input_ids = input_ids.flatten().view(-1)
59
+ if torch.isin(_input_ids, self.text_eos_id).any():
60
+ text_eos_idx = (_input_ids == self.text_eos_id).nonzero(as_tuple=True)[0][0].item()
61
+ if torch.sum(_input_ids[text_eos_idx:] == self.speech_eos_id) > 1:
62
+ return True
63
+ return False
64
+
65
+ class MaxTokenStopper(StoppingCriteria):
66
+ def __init__(self, max_tokens):
67
+ self.max_tokens = max_tokens
68
+
69
+ # TODO@wsy: intended to allow changing max_tokens on the fly, but it does not seem to take effect; revisit later
70
+ def change_max_tokens(self, max_tokens):
71
+ self.max_tokens = max_tokens
72
+
73
+ def __call__(self, input_ids, scores, **kwargs):
74
+ return input_ids.shape[1] >= self.max_tokens # 检查当前序列长度
75
+
76
+ class InterruptStopper(StoppingCriteria):
77
+ def __init__(self):
78
+ self.stop = False
79
+
80
+ def __call__(self, input_ids, scores, **kwargs):
81
+ if self.stop:
82
+ # self.stop = False # reset (intentionally disabled)
83
+ return True
84
+ else:
85
+ return False
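These classes follow the standard transformers LogitsProcessor / StoppingCriteria interfaces, so they can be dropped into model.generate via the usual list wrappers. A hedged sketch follows; all ids, sizes, and the model object are placeholders, and the actual wiring lives in this repo's inference code.

# Illustrative usage only; token ids and counts are placeholders.
from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteriaList

text_token_num = 151936                     # placeholder: size of the text vocabulary
text_eos_id, speech_eos_id = 151645, 4096   # placeholder ids

processors = LogitsProcessorList([S2SLogitsProcessor(text_token_num, text_eos_id)])
stoppers = StoppingCriteriaList([
    S2SStopCriteria(text_eos_id, speech_eos_id),
    MaxTokenStopper(max_tokens=2048),
])

# outputs = model.generate(
#     inputs_embeds=inputs_embeds,
#     logits_processor=processors,
#     stopping_criteria=stoppers,
# )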
patches/custom_speech_ngram_blocking.py ADDED
@@ -0,0 +1,129 @@
1
+ from transformers.generation.logits_process import LogitsProcessor
2
+ import torch
3
+
4
+ class SpeechOnlyNGramBlockingLogitsProcessor(LogitsProcessor):
5
+ def __init__(
6
+ self,
7
+ speech_token_num,
8
+ repeat_times=5,
9
+ special_token_repeat_times_dict=None,
10
+ window_size=8,
11
+ window_repeat=5,
12
+ special_token_window_dict=None
13
+ ):
14
+ """
15
+ speech_token_num: int, number of speech tokens (token_id in [0, speech_token_num) is treated as a speech token)
16
+ repeat_times: int, maximum number of consecutive repeats allowed for ordinary speech tokens
17
+ special_token_repeat_times_dict: dict, {token_id: repeat_times}, per-token maximum consecutive repeats for special speech tokens
18
+ window_size: int, default sliding-window size
19
+ window_repeat: int, default maximum number of occurrences allowed within the window
20
+ special_token_window_dict: dict, {token_id: (window_size, window_repeat)}, per-token window parameters for special tokens
21
+ """
22
+ self.speech_token_num = speech_token_num
23
+ self.repeat_times = repeat_times
24
+ self.special_token_repeat_times_dict = special_token_repeat_times_dict or {}
25
+ self.speech_phase = False # this flag must be toggled externally (e.g. via set_phase)
26
+ self.window_size = window_size
27
+ self.window_repeat = window_repeat
28
+ self.special_token_window_dict = special_token_window_dict or {1446: (13, 10)}
29
+
30
+ def set_phase(self, speech_phase: bool):
31
+ self.speech_phase = speech_phase
32
+
33
+ def __call__(self, input_ids, scores):
34
+ if not self.speech_phase:
35
+ # text phase: do nothing
36
+ return scores
37
+ batch_size, seq_len = input_ids.size()
38
+ for batch_idx in range(batch_size):
39
+ generated = input_ids[batch_idx].tolist()
40
+ if seq_len == 0:
41
+ continue
42
+ last_token = generated[-1]
43
+ if last_token >= self.speech_token_num:
44
+ continue # not a speech token
45
+
46
+ # count how many times the most recent token has been repeated consecutively
47
+ repeat_count = 1
48
+ for i in range(seq_len-2, -1, -1):
49
+ if generated[i] == last_token:
50
+ repeat_count += 1
51
+ else:
52
+ break
53
+ # look up the maximum allowed consecutive repeats for this token
54
+ max_repeat = self.special_token_repeat_times_dict.get(last_token, self.repeat_times)
55
+ if repeat_count >= max_repeat:
56
+ scores[batch_idx, last_token] = -float('inf') # block further generation of this token
57
+
58
+ # ====== frequency suppression within a sliding window ======
59
+ # check every speech token inside the window
60
+ window_tokens = set(generated[-max(self.window_size, max([v[0] for v in self.special_token_window_dict.values()], default=0)):])
61
+ for token in window_tokens:
62
+ if token >= self.speech_token_num:
63
+ continue
64
+ # look up the window parameters for this token
65
+ window_size, window_repeat = self.special_token_window_dict.get(
66
+ token, (self.window_size, self.window_repeat)
67
+ )
68
+ window = generated[-window_size:]
69
+ if window.count(token) >= window_repeat:
70
+ scores[batch_idx, token] = -float('inf')
71
+ # ====== end of sliding-window frequency suppression ======
72
+ return scores
73
+
74
+
75
+
76
+
77
+ class OSUM_chat_LogitsProcessor(LogitsProcessor):
78
+ def __init__(self, allowed_tokens, sequence_to_match):
79
+ """
80
+ Initialize the OSUM_chat_LogitsProcessor.
81
+
82
+ Args:
83
+ allowed_tokens (list): token ids allowed at the current step once the trigger sequence has appeared
84
+ sequence_to_match (list): preceding token sequence used to decide when the restriction applies
85
+ """
86
+ self.allowed_tokens = allowed_tokens
87
+ self.sequence_to_match = sequence_to_match
88
+ self.match_found = False # flag indicating whether the trigger sequence has already been matched
89
+
90
+ def init_match_found(self):
91
+ """
92
+ 初始化match_found标志。
93
+ """
94
+ self.match_found = False
95
+
96
+ def __call__(self, input_ids: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
97
+ """
98
+ 在每个时间步处理logits,对不符合条件的token设置极小的概率。
99
+
100
+ 参数:
101
+ input_ids (torch.Tensor): 当前输入的token ID序列
102
+ logits (torch.Tensor): 当前时间步的logits (shape: [batch_size, vocab_size])
103
+
104
+ 返回:
105
+ torch.Tensor: 被处理过的logits
106
+ """
107
+ # if the trigger sequence has already been matched once, skip detection and return the logits unchanged
108
+ # print("recent_tokens:!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") # 打印当前生成的序列
109
+ if self.match_found:
110
+ return logits
111
+
112
+ # take the last few tokens of the generated sequence (assuming it is at least as long as the trigger sequence)
113
+ sequence_length = len(self.sequence_to_match)
114
+ if input_ids.shape[-1] >= sequence_length:
115
+ recent_tokens = input_ids[:, -sequence_length:].tolist()
116
+ # print("recent_tokens:", recent_tokens) # 打印当前生成的序列
117
+
118
+ # check whether the preceding tokens match the required trigger sequence
119
+ if all(recent_tokens[0][i] == self.sequence_to_match[i] for i in range(sequence_length)):
120
+ # Create a mask for allowed tokens while preserving original logits
121
+ mask = torch.zeros_like(logits, dtype=torch.bool) # Initialize mask as False
122
+ mask[:, self.allowed_tokens] = True # Mark allowed tokens as True
123
+ # Apply mask: keep original logits for allowed tokens, set others to -inf
124
+ logits = torch.where(mask, logits, -float('inf'))
125
+ # set the flag indicating that the match succeeded
126
+ self.match_found = True
127
+ print("match found!!!!!!!!!!!!!!!!!!!!!!!")
128
+
129
+ return logits
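The n-gram blocker above is inert until set_phase(True) is called, so the caller is expected to flip it once the text segment (and its EOS) has been emitted. A short hedged sketch, with the token count as a placeholder:

# Illustrative usage of SpeechOnlyNGramBlockingLogitsProcessor; the size is a placeholder.
speech_token_num = 4097  # placeholder

blocker = SpeechOnlyNGramBlockingLogitsProcessor(
    speech_token_num,
    repeat_times=5,
    window_size=8,
    window_repeat=5,
)
blocker.set_phase(False)   # text phase: the processor is a no-op
# ... after the text EOS has been generated ...
blocker.set_phase(True)    # speech phase: consecutive and windowed repeats are blocked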
patches/custom_speech_repetition_penalty.py ADDED
@@ -0,0 +1,22 @@
 
1
+ from transformers.generation.logits_process import LogitsProcessor
2
+
3
+ class SpeechOnlyRepetitionPenaltyLogitsProcessor(LogitsProcessor):
4
+ def __init__(self, speech_token_num, penalty=1.2):
5
+ self.speech_token_num = speech_token_num
6
+ self.penalty = penalty
7
+ self.speech_phase = False # this flag must be toggled externally (e.g. via set_phase)
8
+
9
+ def set_phase(self, speech_phase: bool):
10
+ self.speech_phase = speech_phase
11
+
12
+ def __call__(self, input_ids, scores):
13
+ if not self.speech_phase:
14
+ # text phase: do nothing
15
+ return scores
16
+ # speech phase: apply repetition suppression to speech tokens only
17
+ for batch_idx in range(input_ids.size(0)):
18
+ generated = input_ids[batch_idx].tolist()
19
+ for token_id in set(generated):
20
+ if 0 <= token_id < self.speech_token_num:
21
+ scores[batch_idx, token_id] /= self.penalty
22
+ return scores
patches/modelling_fm_infer_gpu.py ADDED
@@ -0,0 +1,18 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import List, Optional, Tuple, Union
4
+ import f5_tts
5
+ from f5_tts.model.backbones.dit_mask import DiT as DiT_
6
+
7
+ _GPU_FM_TORCH_COMPILE = True
8
+
9
+ class GPUDiT(DiT_):
10
+ def __init__(self, *args, **kwargs):
11
+ super().__init__(*args, **kwargs)
12
+ self.fast_forward = torch.compile(self.fast_forward, dynamic=False, fullgraph=True) \
13
+ if _GPU_FM_TORCH_COMPILE else self.fast_forward
14
+
15
+ # ===================================================================
16
+ print("========================= DO FM PATCH ============================")
17
+ # ===================================================================
18
+ f5_tts.model.backbones.dit_mask.DiT = GPUDiT
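Because the patch rebinds f5_tts.model.backbones.dit_mask.DiT at import time, it only affects DiT objects constructed afterwards. A minimal sketch of the required import order, assuming patches is importable as a package:

# Illustrative import order: apply the patch before building the flow-matching model.
import patches.modelling_fm_infer_gpu  # rebinds DiT to GPUDiT as a side effect
from f5_tts.model.backbones.dit_mask import DiT  # now resolves to the compiled subclass

# model = DiT(...)  # instances created after the patch get the torch.compile'd fast_forward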
patches/modelling_qwen2_infer_gpu.py ADDED
@@ -0,0 +1,416 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ from typing import List, Optional, Tuple, Union
5
+ import transformers.models
6
+ from transformers.models.qwen2.modeling_qwen2 import (
7
+ Qwen2RotaryEmbedding,
8
+ Qwen2ForCausalLM,
9
+ Qwen2MLP,
10
+ Qwen2RMSNorm,
11
+ apply_rotary_pos_emb,
12
+ repeat_kv,
13
+ _prepare_4d_causal_attention_mask_with_cache_position,
14
+ )
15
+ from transformers.utils import logging
16
+ from transformers.modeling_outputs import CausalLMOutputWithPast
17
+ from transformers.cache_utils import Cache, StaticCache, SlidingWindowCache
18
+ from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
19
+ from .utils import InferTaskCode
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+ _GPU_QWEN_TORCH_COMPILE = True
24
+
25
+ # ===================================================================
26
+ # =============================Attention=============================
27
+ # ===================================================================
28
+ class GPUQwen2Attention(nn.Module):
29
+ """
30
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
31
+ and "Generating Long Sequences with Sparse Transformers".
32
+ """
33
+ def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
34
+ super().__init__()
35
+ self.config = config
36
+ self.layer_idx = layer_idx
37
+ if layer_idx is None:
38
+ logger.warning_once(
39
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
40
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
41
+ "when creating this class."
42
+ )
43
+
44
+ self.hidden_size = config.hidden_size
45
+ self.num_heads = config.num_attention_heads
46
+ self.head_dim = self.hidden_size // self.num_heads
47
+ self.num_key_value_heads = config.num_key_value_heads
48
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
49
+ self.max_position_embeddings = config.max_position_embeddings
50
+ self.rope_theta = config.rope_theta
51
+ self.is_causal = True
52
+ self.attention_dropout = config.attention_dropout
53
+
54
+ if (self.head_dim * self.num_heads) != self.hidden_size:
55
+ raise ValueError(
56
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
57
+ f" and `num_heads`: {self.num_heads})."
58
+ )
59
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
60
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
61
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
62
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
63
+
64
+ self.rotary_emb = Qwen2RotaryEmbedding(
65
+ self.head_dim,
66
+ max_position_embeddings=self.max_position_embeddings,
67
+ base=self.rope_theta,
68
+ )
69
+
70
+ # Adapted from Qwen2Attention.forward
71
+ def forward(
72
+ self,
73
+ hidden_states: torch.Tensor,
74
+ attention_mask: Optional[torch.Tensor] = None,
75
+ position_ids: Optional[torch.LongTensor] = None,
76
+ past_key_value: Optional[Cache] = None,
77
+ output_attentions: bool = False,
78
+ use_cache: bool = False,
79
+ cache_position: Optional[torch.LongTensor] = None,
80
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
81
+ bsz, q_len, _ = hidden_states.size()
82
+
83
+ query_states = self.q_proj(hidden_states)
84
+ key_states = self.k_proj(hidden_states)
85
+ value_states = self.v_proj(hidden_states)
86
+
87
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
88
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
89
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
90
+
91
+ # NOTE: RoPE return all embedding (to satisfy torch compile)
92
+ cos, sin = self.rotary_emb(value_states, seq_len=past_key_value.get_max_length())
93
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
94
+
95
+ if past_key_value is not None:
96
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
97
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
98
+
99
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
100
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
101
+
102
+ causal_mask = attention_mask
103
+ if attention_mask is not None: # no matter the length, we just slice it
104
+ causal_mask = attention_mask[:, :, :, : past_key_value.get_max_length()]
105
+
106
+ query_states = query_states.contiguous()
107
+ key_states = key_states.contiguous()
108
+ value_states = value_states.contiguous()
109
+
110
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
111
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
112
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
113
+ is_causal = True if causal_mask is None and q_len > 1 else False
114
+
115
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
116
+ query_states,
117
+ key_states,
118
+ value_states,
119
+ attn_mask=causal_mask,
120
+ dropout_p=0.0,
121
+ is_causal=is_causal,
122
+ )
123
+
124
+ attn_output = attn_output.transpose(1, 2).contiguous()
125
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
126
+
127
+ attn_output = self.o_proj(attn_output)
128
+
129
+ return attn_output, None, past_key_value
130
+
131
+
132
+ # ===================================================================
133
+ # =============================Layer=================================
134
+ # ===================================================================
135
+ class GPUQwen2DecoderLayer(nn.Module):
136
+ def __init__(self, config: Qwen2Config, layer_idx: int):
137
+ super().__init__()
138
+ self.hidden_size = config.hidden_size
139
+
140
+ if config.sliding_window and config._attn_implementation != "flash_attention_2":
141
+ logger.warning_once(
142
+ f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
143
+ "unexpected results may be encountered."
144
+ )
145
+ self.self_attn = GPUQwen2Attention(config, layer_idx)
146
+
147
+ self.mlp = Qwen2MLP(config)
148
+ self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
149
+ self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
150
+
151
+ def forward(
152
+ self,
153
+ hidden_states: torch.Tensor,
154
+ attention_mask: Optional[torch.Tensor] = None,
155
+ position_ids: Optional[torch.LongTensor] = None,
156
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
157
+ output_attentions: Optional[bool] = False,
158
+ use_cache: Optional[bool] = False,
159
+ cache_position: Optional[torch.LongTensor] = None,
160
+ **kwargs,
161
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
162
+ """
163
+ Args:
164
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
165
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
166
+ `(batch, sequence_length)` where padding elements are indicated by 0.
167
+ output_attentions (`bool`, *optional*):
168
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
169
+ returned tensors for more detail.
170
+ use_cache (`bool`, *optional*):
171
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
172
+ (see `past_key_values`).
173
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
174
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
175
+ Indices depicting the position of the input sequence tokens in the sequence.
176
+ kwargs (`dict`, *optional*):
177
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
178
+ into the model
179
+ """
180
+
181
+ residual = hidden_states
182
+
183
+ hidden_states = self.input_layernorm(hidden_states)
184
+
185
+ # Self Attention
186
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
187
+ hidden_states=hidden_states,
188
+ attention_mask=attention_mask,
189
+ position_ids=position_ids,
190
+ past_key_value=past_key_value,
191
+ output_attentions=output_attentions,
192
+ use_cache=use_cache,
193
+ cache_position=cache_position,
194
+ )
195
+ hidden_states = residual + hidden_states
196
+
197
+ # Fully Connected
198
+ residual = hidden_states
199
+ hidden_states = self.post_attention_layernorm(hidden_states)
200
+ hidden_states = self.mlp(hidden_states)
201
+ hidden_states = residual + hidden_states
202
+
203
+ outputs = (hidden_states,)
204
+
205
+ if output_attentions:
206
+ outputs += (self_attn_weights,)
207
+
208
+ if use_cache:
209
+ outputs += (present_key_value,)
210
+
211
+ return outputs
212
+
213
+ # ===================================================================
214
+ # ========================Qwen2ForCausalLM===========================
215
+ # ===================================================================
216
+ class InferQwen2ForCausalLM(Qwen2ForCausalLM):
217
+ def __init__(self, config):
218
+ super().__init__(config)
219
+ self.compile_forward = torch.compile(self.simplify_forward, dynamic=False, fullgraph=True) \
220
+ if _GPU_QWEN_TORCH_COMPILE else self.simplify_forward
221
+ self.text_phase = True
222
+ '''
223
+ NOTE: 重写原Qwen2ForCausalLM forward函数,torchair直接编译原函数在返回CausalLMOutputWithPast时会出现编译错误
224
+ '''
225
+ def simplify_forward(self,
226
+ input_ids: torch.LongTensor = None,
227
+ attention_mask: Optional[torch.Tensor] = None,
228
+ position_ids: Optional[torch.LongTensor] = None,
229
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
230
+ inputs_embeds: Optional[torch.FloatTensor] = None,
231
+ labels: Optional[torch.LongTensor] = None,
232
+ use_cache: Optional[bool] = None,
233
+ output_attentions: Optional[bool] = None,
234
+ output_hidden_states: Optional[bool] = None,
235
+ return_dict: Optional[bool] = None,
236
+ cache_position: Optional[torch.LongTensor] = None,
237
+ ):
238
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
239
+ output_hidden_states = (
240
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
241
+ )
242
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
243
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
244
+ outputs = self.model(
245
+ input_ids=input_ids,
246
+ attention_mask=attention_mask,
247
+ position_ids=position_ids,
248
+ past_key_values=past_key_values,
249
+ inputs_embeds=inputs_embeds,
250
+ use_cache=use_cache,
251
+ output_attentions=output_attentions,
252
+ output_hidden_states=output_hidden_states,
253
+ return_dict=return_dict,
254
+ cache_position=cache_position,
255
+ )
256
+
257
+ return outputs
258
+
259
+ def forward(self,
260
+ input_ids: torch.LongTensor = None,
261
+ attention_mask: Optional[torch.Tensor] = None,
262
+ position_ids: Optional[torch.LongTensor] = None,
263
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
264
+ inputs_embeds: Optional[torch.FloatTensor] = None,
265
+ labels: Optional[torch.LongTensor] = None,
266
+ use_cache: Optional[bool] = None,
267
+ output_attentions: Optional[bool] = None,
268
+ output_hidden_states: Optional[bool] = None,
269
+ return_dict: Optional[bool] = None,
270
+ cache_position: Optional[torch.LongTensor] = None,
271
+ do_compile = True
272
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
273
+ if past_key_values is not None:
274
+ past_key_values.training = False
275
+ # print(self.text_phase)
276
+ if input_ids is not None:
277
+ if self.text_phase:
278
+ inputs_embeds = self.model.embed_tokens(input_ids)
279
+ else:
280
+ inputs_embeds = self.speech_token_emded(input_ids)
281
+ if torch.isin(input_ids, 151645).any():
282
+ self.text_phase = False
283
+ input_ids = None
284
+
285
+ if (inputs_embeds is not None and cache_position[0] == 0) or do_compile==False :
286
+ # prefill branch
287
+ outputs = self.simplify_forward(input_ids,
288
+ attention_mask,
289
+ position_ids,
290
+ past_key_values,
291
+ inputs_embeds,
292
+ labels,
293
+ use_cache,
294
+ output_attentions,
295
+ output_hidden_states,
296
+ return_dict,
297
+ cache_position)
298
+ else:
299
+ # decoding
300
+ outputs = self.compile_forward(input_ids,
301
+ attention_mask,
302
+ position_ids,
303
+ past_key_values,
304
+ inputs_embeds,
305
+ labels,
306
+ use_cache,
307
+ output_attentions,
308
+ output_hidden_states,
309
+ return_dict,
310
+ cache_position)
311
+
312
+ last_hidden_states = outputs.last_hidden_state
313
+
314
+ if self.text_phase:
315
+ logits = self.lm_head(last_hidden_states)
316
+ else:
317
+ logits = self.speech_head(last_hidden_states)
318
+
319
+ logits = logits.float()
320
+
321
+ return CausalLMOutputWithPast(
322
+ loss=None,
323
+ logits=logits,
324
+ past_key_values=outputs.past_key_values,
325
+ hidden_states=outputs.hidden_states,
326
+ attentions=outputs.attentions,
327
+ )
328
+
329
+
330
+ def prepare_inputs_for_generation(
331
+ self,
332
+ input_ids,
333
+ past_key_values=None,
334
+ attention_mask=None,
335
+ inputs_embeds=None,
336
+ cache_position=None,
337
+ position_ids=None,
338
+ use_cache=True,
339
+ **kwargs,
340
+ ):
341
+ """
342
+ Mainly add static cache support
343
+ """
344
+ # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
345
+ # Exception 1: when passing input_embeds, input_ids may be missing entries
346
+ # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
347
+ if past_key_values is not None:
348
+ if inputs_embeds is not None: # Exception 1
349
+ input_ids = input_ids[:, -cache_position.shape[0] :]
350
+ elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
351
+ input_ids = input_ids[:, cache_position]
352
+
353
+ if attention_mask is not None and position_ids is None:
354
+ # create position_ids on the fly for batch generation
355
+ position_ids = attention_mask.long().cumsum(-1) - 1
356
+ position_ids.masked_fill_(attention_mask == 0, 1)
357
+ if past_key_values:
358
+ position_ids = position_ids[:, -input_ids.shape[1] :]
359
+ # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`,
360
+ # as otherwise the input `position_ids` would have various stride during the decoding.
361
+ # Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case,
362
+ # `position_ids` is already contiguous but with varying stride which retriggers a capture.
363
+ position_ids = position_ids.clone(memory_format=torch.contiguous_format)
364
+
365
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
366
+ if inputs_embeds is not None and cache_position[0] == 0:
367
+ model_inputs = {"inputs_embeds": inputs_embeds}
368
+ else:
369
+ # NOTE: same as position_ids above; cloned for torch.compile and CUDA graph stability
370
+ input_ids = input_ids.clone(memory_format=torch.contiguous_format)
371
+ model_inputs = {"input_ids": input_ids}
372
+
373
+ if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
374
+ if inputs_embeds is not None and cache_position[0] == 0:
375
+ # prefill phase, inputs_embeds has shape (B,S,H)
376
+ batch_size, sequence_length = inputs_embeds.shape[0], inputs_embeds.shape[1]
377
+ device = inputs_embeds.device
378
+ else:
379
+ # decoding phase, input_ids has shape (B,S)
380
+ batch_size, sequence_length = input_ids.shape
381
+ device = input_ids.device
382
+
383
+ dtype = self.lm_head.weight.dtype
384
+ min_dtype = torch.finfo(dtype).min
385
+
386
+ if inputs_embeds is not None and inputs_embeds.ndim == 2 or input_ids is not None and input_ids.size(-1) == 1:
387
+ # we only expand the attention mask in decoding mode
388
+ attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
389
+ attention_mask,
390
+ sequence_length=sequence_length,
391
+ target_length=past_key_values.get_max_length(),
392
+ dtype=dtype,
393
+ device=device,
394
+ min_dtype=min_dtype,
395
+ cache_position=cache_position,
396
+ batch_size=batch_size,
397
+ )
398
+
399
+ model_inputs.update(
400
+ {
401
+ "position_ids": position_ids,
402
+ "cache_position": cache_position,
403
+ "past_key_values": past_key_values,
404
+ "use_cache": use_cache,
405
+ "attention_mask": attention_mask,
406
+ "do_compile": kwargs['do_compile'],
407
+ }
408
+ )
409
+ return model_inputs
410
+
411
+ # ===================================================================
412
+ print("========================= DO Qwen2 PATCH ===========================")
413
+ # ===================================================================
414
+ transformers.models.qwen2.modeling_qwen2.Qwen2PreTrainedModel._supports_static_cache = True # enable static cache
415
+ transformers.models.qwen2.modeling_qwen2.Qwen2DecoderLayer = GPUQwen2DecoderLayer
416
+ transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM = InferQwen2ForCausalLM
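Editor's note (sketch, not part of the commit): the class replacements above only take effect if this patch module is imported before the Qwen2 checkpoint is constructed. A minimal usage sketch, with the checkpoint path as a placeholder:

import patches.modelling_qwen2_infer_gpu  # importing applies the monkey patch to transformers
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("path/to/osum_echat_checkpoint")  # placeholder path
model.text_phase = True  # reset to the text phase; forward() switches to the speech head once token id 151645 (<|im_end|>) appears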
patches/utils.py ADDED
@@ -0,0 +1,4 @@
1
+ class InferTaskCode:
2
+ _ASR = 0
3
+ _TTS = 1
4
+ _S2S = 2
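A hedged illustration of how these task codes might drive dispatch in the inference code (the helper below is hypothetical, not part of this commit):

from patches.utils import InferTaskCode

def describe_task(task_code: int) -> str:
    # hypothetical helper, for illustration only
    if task_code == InferTaskCode._ASR:
        return "speech-to-text (ASR)"
    if task_code == InferTaskCode._TTS:
        return "text-to-speech (TTS)"
    if task_code == InferTaskCode._S2S:
        return "speech-to-speech (S2S)"
    raise ValueError(f"unknown task code: {task_code}")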
requirements.txt ADDED
@@ -0,0 +1,41 @@
1
+ numpy==1.24
2
+ jsonlines==4.0.0
3
+ torch==2.1.0
4
+ transformers==4.44.0
5
+ torchaudio==2.1.0
6
+ accelerate==1.7.0
7
+ peft==0.17.0
8
+ librosa
9
+ tensorboardX>=2.5
10
+ # torch_npu==2.1.0.post8
11
+ tqdm
12
+ absl-py
13
+ psutil
14
+ cloudpickle
15
+ ml-dtypes
16
+ tornado
17
+ openai-whisper
18
+ colorama
19
+ sox
20
+ deepspeed
22
+ gxl_ai_utils
23
+
24
+
25
+ hyperpyyaml
26
+ modelscope
27
+ onnxruntime
28
+ inflect
29
+ omegaconf
30
+ conformer
31
+ diffusers
32
+ hydra-core
33
+ lightning
34
+
35
+ gradio
36
+ cn2an
37
+ gdown
38
+ matplotlib
39
+ wget
40
+ pyarrow
41
+ pyworld
tts/__init__.py ADDED
File without changes
tts/assert//345/256/236/351/252/214/345/256/244.png ADDED
tts/cosyvoice/__init__.py ADDED
File without changes
tts/cosyvoice/bin/average_model.py ADDED
@@ -0,0 +1,92 @@
1
+ # Copyright (c) 2020 Mobvoi Inc (Di Wu)
2
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import argparse
18
+ import glob
19
+
20
+ import yaml
21
+ import torch
22
+
23
+
24
+ def get_args():
25
+ parser = argparse.ArgumentParser(description='average model')
26
+ parser.add_argument('--dst_model', required=True, help='averaged model')
27
+ parser.add_argument('--src_path',
28
+ required=True,
29
+ help='src model path for average')
30
+ parser.add_argument('--val_best',
31
+ action="store_true",
32
+ help='averaged model')
33
+ parser.add_argument('--num',
34
+ default=5,
35
+ type=int,
36
+ help='nums for averaged model')
37
+
38
+ args = parser.parse_args()
39
+ print(args)
40
+ return args
41
+
42
+
43
+ def main():
44
+ args = get_args()
45
+ val_scores = []
46
+ if args.val_best:
47
+ yamls = glob.glob('{}/*.yaml'.format(args.src_path))
48
+ yamls = [
49
+ f for f in yamls
50
+ if not (os.path.basename(f).startswith('train')
51
+ or os.path.basename(f).startswith('init'))
52
+ ]
53
+ for y in yamls:
54
+ with open(y, 'r') as f:
55
+ dic_yaml = yaml.load(f, Loader=yaml.BaseLoader)
56
+ loss = float(dic_yaml['loss_dict']['loss'])
57
+ epoch = int(dic_yaml['epoch'])
58
+ step = int(dic_yaml['step'])
59
+ tag = dic_yaml['tag']
60
+ val_scores += [[epoch, step, loss, tag]]
61
+ sorted_val_scores = sorted(val_scores,
62
+ key=lambda x: x[2],
63
+ reverse=False)
64
+ print("best val (epoch, step, loss, tag) = " +
65
+ str(sorted_val_scores[:args.num]))
66
+ path_list = [
67
+ args.src_path + '/epoch_{}_whole.pt'.format(score[0])
68
+ for score in sorted_val_scores[:args.num]
69
+ ]
70
+ print(path_list)
71
+ avg = {}
72
+ num = args.num
73
+ assert num == len(path_list)
74
+ for path in path_list:
75
+ print('Processing {}'.format(path))
76
+ states = torch.load(path, map_location=torch.device('cpu'))
77
+ for k in states.keys():
78
+ if k not in avg.keys():
79
+ avg[k] = states[k].clone()
80
+ else:
81
+ avg[k] += states[k]
82
+ # average
83
+ for k in avg.keys():
84
+ if avg[k] is not None:
85
+ # pytorch 1.6 use true_divide instead of /=
86
+ avg[k] = torch.true_divide(avg[k], num)
87
+ print('Saving to {}'.format(args.dst_model))
88
+ torch.save(avg, args.dst_model)
89
+
90
+
91
+ if __name__ == '__main__':
92
+ main()
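The script writes a plain state_dict, so the averaged checkpoint can be loaded back directly. A minimal sketch, assuming the paths below (they are placeholders):

# e.g. python tts/cosyvoice/bin/average_model.py --src_path exp/llm --dst_model exp/llm/llm.avg5.pt --val_best --num 5
import torch

avg_state = torch.load("exp/llm/llm.avg5.pt", map_location="cpu")  # averaged weights
# some_llm_module.load_state_dict(avg_state, strict=False)  # load into the matching model definition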
tts/cosyvoice/bin/export_jit.py ADDED
@@ -0,0 +1,91 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import print_function
16
+
17
+ import argparse
18
+ import logging
19
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
20
+ import os
21
+ import sys
22
+ import torch
23
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
24
+ sys.path.append('{}/../..'.format(ROOT_DIR))
25
+ sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
26
+ from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
27
+
28
+
29
+ def get_args():
30
+ parser = argparse.ArgumentParser(description='export your model for deployment')
31
+ parser.add_argument('--model_dir',
32
+ type=str,
33
+ default='pretrained_models/CosyVoice-300M',
34
+ help='local path')
35
+ args = parser.parse_args()
36
+ print(args)
37
+ return args
38
+
39
+
40
+ def get_optimized_script(model, preserved_attrs=[]):
41
+ script = torch.jit.script(model)
42
+ if preserved_attrs != []:
43
+ script = torch.jit.freeze(script, preserved_attrs=preserved_attrs)
44
+ else:
45
+ script = torch.jit.freeze(script)
46
+ script = torch.jit.optimize_for_inference(script)
47
+ return script
48
+
49
+
50
+ def main():
51
+ args = get_args()
52
+ logging.basicConfig(level=logging.DEBUG,
53
+ format='%(asctime)s %(levelname)s %(message)s')
54
+
55
+ torch._C._jit_set_fusion_strategy([('STATIC', 1)])
56
+ torch._C._jit_set_profiling_mode(False)
57
+ torch._C._jit_set_profiling_executor(False)
58
+
59
+ try:
60
+ model = CosyVoice(args.model_dir)
61
+ except Exception:
62
+ try:
63
+ model = CosyVoice2(args.model_dir)
64
+ except Exception:
65
+ raise TypeError('no valid model_type!')
66
+
67
+ if not isinstance(model, CosyVoice2):
68
+ # 1. export llm text_encoder
69
+ llm_text_encoder = model.model.llm.text_encoder
70
+ script = get_optimized_script(llm_text_encoder)
71
+ script.save('{}/llm.text_encoder.fp32.zip'.format(args.model_dir))
72
+ script = get_optimized_script(llm_text_encoder.half())
73
+ script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir))
74
+
75
+ # 2. export llm llm
76
+ llm_llm = model.model.llm.llm
77
+ script = get_optimized_script(llm_llm, ['forward_chunk'])
78
+ script.save('{}/llm.llm.fp32.zip'.format(args.model_dir))
79
+ script = get_optimized_script(llm_llm.half(), ['forward_chunk'])
80
+ script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
81
+
82
+ # 3. export flow encoder
83
+ flow_encoder = model.model.flow.encoder
84
+ script = get_optimized_script(flow_encoder)
85
+ script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
86
+ script = get_optimized_script(flow_encoder.half())
87
+ script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
88
+
89
+
90
+ if __name__ == '__main__':
91
+ main()
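The exported .zip archives are ordinary TorchScript modules, so they can be loaded back without the Python class definitions. A minimal sketch, assuming the fp16 flow encoder was exported as above (paths are placeholders):

import torch

flow_encoder = torch.jit.load("pretrained_models/CosyVoice-300M/flow.encoder.fp16.zip", map_location="cuda")
flow_encoder.eval()  # frozen, inference-optimized module; used in place of model.model.flow.encoder when load_jit=True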
tts/cosyvoice/bin/export_onnx.py ADDED
@@ -0,0 +1,116 @@
1
+ # Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, [email protected])
2
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import print_function
17
+
18
+ import argparse
19
+ import logging
20
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
21
+ import os
22
+ import sys
23
+ import onnxruntime
24
+ import random
25
+ import torch
26
+ from tqdm import tqdm
27
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
28
+ sys.path.append('{}/../..'.format(ROOT_DIR))
29
+ sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
30
+ from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
31
+
32
+
33
+ def get_dummy_input(batch_size, seq_len, out_channels, device):
34
+ x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
35
+ mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device)
36
+ mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
37
+ t = torch.rand((batch_size), dtype=torch.float32, device=device)
38
+ spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device)
39
+ cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
40
+ return x, mask, mu, t, spks, cond
41
+
42
+
43
+ def get_args():
44
+ parser = argparse.ArgumentParser(description='export your model for deployment')
45
+ parser.add_argument('--model_dir',
46
+ type=str,
47
+ default='pretrained_models/CosyVoice-300M',
48
+ help='local path')
49
+ args = parser.parse_args()
50
+ print(args)
51
+ return args
52
+
53
+
54
+ def main():
55
+ args = get_args()
56
+ logging.basicConfig(level=logging.DEBUG,
57
+ format='%(asctime)s %(levelname)s %(message)s')
58
+
59
+ try:
60
+ model = CosyVoice(args.model_dir)
61
+ except Exception:
62
+ try:
63
+ model = CosyVoice2(args.model_dir)
64
+ except Exception:
65
+ raise TypeError('no valid model_type!')
66
+
67
+ # 1. export flow decoder estimator
68
+ estimator = model.model.flow.decoder.estimator
69
+
70
+ device = model.model.device
71
+ batch_size, seq_len = 2, 256
72
+ out_channels = model.model.flow.decoder.estimator.out_channels
73
+ x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
74
+ torch.onnx.export(
75
+ estimator,
76
+ (x, mask, mu, t, spks, cond),
77
+ '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
78
+ export_params=True,
79
+ opset_version=18,
80
+ do_constant_folding=True,
81
+ input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
82
+ output_names=['estimator_out'],
83
+ dynamic_axes={
84
+ 'x': {2: 'seq_len'},
85
+ 'mask': {2: 'seq_len'},
86
+ 'mu': {2: 'seq_len'},
87
+ 'cond': {2: 'seq_len'},
88
+ 'estimator_out': {2: 'seq_len'},
89
+ }
90
+ )
91
+
92
+ # 2. test computation consistency
93
+ option = onnxruntime.SessionOptions()
94
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
95
+ option.intra_op_num_threads = 1
96
+ providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
97
+ estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
98
+ sess_options=option, providers=providers)
99
+
100
+ for _ in tqdm(range(10)):
101
+ x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
102
+ output_pytorch = estimator(x, mask, mu, t, spks, cond)
103
+ ort_inputs = {
104
+ 'x': x.cpu().numpy(),
105
+ 'mask': mask.cpu().numpy(),
106
+ 'mu': mu.cpu().numpy(),
107
+ 't': t.cpu().numpy(),
108
+ 'spks': spks.cpu().numpy(),
109
+ 'cond': cond.cpu().numpy()
110
+ }
111
+ output_onnx = estimator_onnx.run(None, ort_inputs)[0]
112
+ torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
113
+
114
+
115
+ if __name__ == "__main__":
116
+ main()
tts/cosyvoice/bin/export_trt.sh ADDED
@@ -0,0 +1,10 @@
1
+ #!/bin/bash
2
+ # Copyright 2024 Alibaba Inc. All Rights Reserved.
3
+ # download tensorrt from https://developer.nvidia.com/tensorrt/download/10x, check your system and cuda version for compatibility
4
+ # for example for linux + cuda12.4, you can download https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.0.1/tars/TensorRT-10.0.1.6.Linux.x86_64-gnu.cuda-12.4.tar.gz
5
+ TRT_DIR=<YOUR_TRT_DIR>
6
+ MODEL_DIR=<COSYVOICE2_MODEL_DIR>
7
+
8
+ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$TRT_DIR/lib:/usr/local/cuda/lib64
9
+ $TRT_DIR/bin/trtexec --onnx=$MODEL_DIR/flow.decoder.estimator.fp32.onnx --saveEngine=$MODEL_DIR/flow.decoder.estimator.fp32.mygpu.plan --minShapes=x:2x80x4,mask:2x1x4,mu:2x80x4,cond:2x80x4 --optShapes=x:2x80x193,mask:2x1x193,mu:2x80x193,cond:2x80x193 --maxShapes=x:2x80x6800,mask:2x1x6800,mu:2x80x6800,cond:2x80x6800 --inputIOFormats=fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw --outputIOFormats=fp32:chw
10
+ $TRT_DIR/bin/trtexec --onnx=$MODEL_DIR/flow.decoder.estimator.fp32.onnx --saveEngine=$MODEL_DIR/flow.decoder.estimator.fp16.mygpu.plan --fp16 --minShapes=x:2x80x4,mask:2x1x4,mu:2x80x4,cond:2x80x4 --optShapes=x:2x80x193,mask:2x1x193,mu:2x80x193,cond:2x80x193 --maxShapes=x:2x80x6800,mask:2x1x6800,mu:2x80x6800,cond:2x80x6800 --inputIOFormats=fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw --outputIOFormats=fp16:chw
tts/cosyvoice/bin/inference.py ADDED
@@ -0,0 +1,115 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import print_function
16
+
17
+ import argparse
18
+ import logging
19
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
20
+ import os
21
+ import torch
22
+ from torch.utils.data import DataLoader
23
+ import torchaudio
24
+ from hyperpyyaml import load_hyperpyyaml
25
+ from tqdm import tqdm
26
+ from cosyvoice.cli.model import CosyVoiceModel
27
+ from cosyvoice.dataset.dataset import Dataset
28
+
29
+
30
+ def get_args():
31
+ parser = argparse.ArgumentParser(description='inference with your model')
32
+ parser.add_argument('--config', required=True, help='config file')
33
+ parser.add_argument('--prompt_data', required=True, help='prompt data file')
34
+ parser.add_argument('--prompt_utt2data', required=True, help='prompt data file')
35
+ parser.add_argument('--tts_text', required=True, help='tts input file')
36
+ parser.add_argument('--llm_model', required=True, help='llm model file')
37
+ parser.add_argument('--flow_model', required=True, help='flow model file')
38
+ parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
39
+ parser.add_argument('--gpu',
40
+ type=int,
41
+ default=-1,
42
+ help='gpu id for this rank, -1 for cpu')
43
+ parser.add_argument('--mode',
44
+ default='sft',
45
+ choices=['sft', 'zero_shot'],
46
+ help='inference mode')
47
+ parser.add_argument('--result_dir', required=True, help='asr result file')
48
+ args = parser.parse_args()
49
+ print(args)
50
+ return args
51
+
52
+
53
+ def main():
54
+ args = get_args()
55
+ logging.basicConfig(level=logging.DEBUG,
56
+ format='%(asctime)s %(levelname)s %(message)s')
57
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
58
+
59
+ # Init cosyvoice models from configs
60
+ use_cuda = args.gpu >= 0 and torch.cuda.is_available()
61
+ device = torch.device('cuda' if use_cuda else 'cpu')
62
+ with open(args.config, 'r') as f:
63
+ configs = load_hyperpyyaml(f)
64
+
65
+ model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
66
+ model.load(args.llm_model, args.flow_model, args.hifigan_model)
67
+
68
+ test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
69
+ tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
70
+ test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
71
+
72
+ del configs
73
+ os.makedirs(args.result_dir, exist_ok=True)
74
+ fn = os.path.join(args.result_dir, 'wav.scp')
75
+ f = open(fn, 'w')
76
+ with torch.no_grad():
77
+ for _, batch in tqdm(enumerate(test_data_loader)):
78
+ utts = batch["utts"]
79
+ assert len(utts) == 1, "inference mode only supports batch size 1"
80
+ text_token = batch["text_token"].to(device)
81
+ text_token_len = batch["text_token_len"].to(device)
82
+ tts_index = batch["tts_index"]
83
+ tts_text_token = batch["tts_text_token"].to(device)
84
+ tts_text_token_len = batch["tts_text_token_len"].to(device)
85
+ speech_token = batch["speech_token"].to(device)
86
+ speech_token_len = batch["speech_token_len"].to(device)
87
+ speech_feat = batch["speech_feat"].to(device)
88
+ speech_feat_len = batch["speech_feat_len"].to(device)
89
+ utt_embedding = batch["utt_embedding"].to(device)
90
+ spk_embedding = batch["spk_embedding"].to(device)
91
+ if args.mode == 'sft':
92
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
93
+ 'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
94
+ else:
95
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
96
+ 'prompt_text': text_token, 'prompt_text_len': text_token_len,
97
+ 'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
98
+ 'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
99
+ 'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
100
+ 'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
101
+ tts_speeches = []
102
+ for model_output in model.tts(**model_input):
103
+ tts_speeches.append(model_output['tts_speech'])
104
+ tts_speeches = torch.concat(tts_speeches, dim=1)
105
+ tts_key = '{}_{}'.format(utts[0], tts_index[0])
106
+ tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
107
+ torchaudio.save(tts_fn, tts_speeches, sample_rate=22050)
108
+ f.write('{} {}\n'.format(tts_key, tts_fn))
109
+ f.flush()
110
+ f.close()
111
+ logging.info('Result wav.scp saved in {}'.format(fn))
112
+
113
+
114
+ if __name__ == '__main__':
115
+ main()
tts/cosyvoice/bin/train.py ADDED
@@ -0,0 +1,170 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import print_function
16
+ import argparse
17
+ import datetime
18
+ import logging
19
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
20
+ from copy import deepcopy
21
+ import os
22
+ import torch
23
+ import torch.distributed as dist
24
+ import deepspeed
25
+
26
+ from hyperpyyaml import load_hyperpyyaml
27
+
28
+ from torch.distributed.elastic.multiprocessing.errors import record
29
+
30
+ from cosyvoice.utils.executor import Executor
31
+ from cosyvoice.utils.train_utils import (
32
+ init_distributed,
33
+ init_dataset_and_dataloader,
34
+ init_optimizer_and_scheduler,
35
+ init_summarywriter, save_model,
36
+ wrap_cuda_model, check_modify_and_save_config)
37
+
38
+
39
+ def get_args():
40
+ parser = argparse.ArgumentParser(description='training your network')
41
+ parser.add_argument('--train_engine',
42
+ default='torch_ddp',
43
+ choices=['torch_ddp', 'deepspeed'],
44
+ help='Engine for paralleled training')
45
+ parser.add_argument('--model', required=True, help='model which will be trained')
46
+ parser.add_argument('--config', required=True, help='config file')
47
+ parser.add_argument('--train_data', required=True, help='train data file')
48
+ parser.add_argument('--cv_data', required=True, help='cv data file')
49
+ parser.add_argument('--checkpoint', help='checkpoint model')
50
+ parser.add_argument('--model_dir', required=True, help='save model dir')
51
+ parser.add_argument('--tensorboard_dir',
52
+ default='tensorboard',
53
+ help='tensorboard log dir')
54
+ parser.add_argument('--ddp.dist_backend',
55
+ dest='dist_backend',
56
+ default='nccl',
57
+ choices=['nccl', 'gloo'],
58
+ help='distributed backend')
59
+ parser.add_argument('--num_workers',
60
+ default=0,
61
+ type=int,
62
+ help='num of subprocess workers for reading')
63
+ parser.add_argument('--prefetch',
64
+ default=100,
65
+ type=int,
66
+ help='prefetch number')
67
+ parser.add_argument('--pin_memory',
68
+ action='store_true',
69
+ default=False,
70
+ help='Use pinned memory buffers used for reading')
71
+ parser.add_argument('--use_amp',
72
+ action='store_true',
73
+ default=False,
74
+ help='Use automatic mixed precision training')
75
+ parser.add_argument('--deepspeed.save_states',
76
+ dest='save_states',
77
+ default='model_only',
78
+ choices=['model_only', 'model+optimizer'],
79
+ help='save model/optimizer states')
80
+ parser.add_argument('--timeout',
81
+ default=60,
82
+ type=int,
83
+ help='timeout (in seconds) of cosyvoice_join.')
84
+ parser = deepspeed.add_config_arguments(parser)
85
+ args = parser.parse_args()
86
+ return args
87
+
88
+
89
+ @record
90
+ def main():
91
+ args = get_args()
92
+ logging.basicConfig(level=logging.DEBUG,
93
+ format='%(asctime)s %(levelname)s %(message)s')
94
+ # gan train has some special initialization logic
95
+ gan = True if args.model == 'hifigan' else False
96
+
97
+ override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
98
+ if gan is True:
99
+ override_dict.pop('hift')
100
+ with open(args.config, 'r') as f:
101
+ configs = load_hyperpyyaml(f, overrides=override_dict)
102
+ if gan is True:
103
+ configs['train_conf'] = configs['train_conf_gan']
104
+ configs['train_conf'].update(vars(args))
105
+
106
+ # Init env for ddp
107
+ init_distributed(args)
108
+
109
+ # Get dataset & dataloader
110
+ train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
111
+ init_dataset_and_dataloader(args, configs, gan)
112
+
113
+ # Do some sanity checks and save config to arsg.model_dir
114
+ configs = check_modify_and_save_config(args, configs)
115
+
116
+ # Tensorboard summary
117
+ writer = init_summarywriter(args)
118
+
119
+ # load checkpoint
120
+ model = configs[args.model]
121
+ start_step, start_epoch = 0, -1
122
+ if args.checkpoint is not None:
123
+ if os.path.exists(args.checkpoint):
124
+ state_dict = torch.load(args.checkpoint, map_location='cpu')
125
+ model.load_state_dict(state_dict, strict=False)
126
+ if 'step' in state_dict:
127
+ start_step = state_dict['step']
128
+ if 'epoch' in state_dict:
129
+ start_epoch = state_dict['epoch']
130
+ else:
131
+ logging.warning('checkpoint {} do not exsist!'.format(args.checkpoint))
132
+
133
+ # Dispatch model from cpu to gpu
134
+ model = wrap_cuda_model(args, model)
135
+
136
+ # Get optimizer & scheduler
137
+ model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan)
138
+ scheduler.set_step(start_step)
139
+ if scheduler_d is not None:
140
+ scheduler_d.set_step(start_step)
141
+
142
+ # Save init checkpoints
143
+ info_dict = deepcopy(configs['train_conf'])
144
+ info_dict['step'] = start_step
145
+ info_dict['epoch'] = start_epoch
146
+ save_model(model, 'init', info_dict)
147
+
148
+ # Get executor
149
+ executor = Executor(gan=gan)
150
+ executor.step = start_step
151
+
152
+ # Init scaler, used for pytorch amp mixed precision training
153
+ scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
154
+ print('start step {} start epoch {}'.format(start_step, start_epoch))
155
+ # Start training loop
156
+ for epoch in range(start_epoch + 1, info_dict['max_epoch']):
157
+ executor.epoch = epoch
158
+ train_dataset.set_epoch(epoch)
159
+ dist.barrier()
160
+ group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
161
+ if gan is True:
162
+ executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
163
+ writer, info_dict, scaler, group_join)
164
+ else:
165
+ executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join)
166
+ dist.destroy_process_group(group_join)
167
+
168
+
169
+ if __name__ == '__main__':
170
+ main()
tts/cosyvoice/cli/__init__.py ADDED
File without changes
tts/cosyvoice/cli/cosyvoice.py ADDED
@@ -0,0 +1,197 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import time
16
+ from typing import Generator
17
+ from tqdm import tqdm
18
+ from hyperpyyaml import load_hyperpyyaml
19
+ from modelscope import snapshot_download
20
+ import torch
21
+ from cosyvoice.cli.frontend import CosyVoiceFrontEnd
22
+ from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
23
+ from cosyvoice.utils.file_utils import logging
24
+ from cosyvoice.utils.class_utils import get_model_type
25
+
26
+
27
+ class CosyVoice:
28
+
29
+ def __init__(self, model_dir,gpu_id=0, load_jit=False, load_trt=False, fp16=False):
30
+ self.instruct = True if '-Instruct' in model_dir else False
31
+ self.model_dir = model_dir
32
+ self.fp16 = fp16
33
+ if not os.path.exists(model_dir):
34
+ model_dir = snapshot_download(model_dir)
35
+ with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
36
+ configs = load_hyperpyyaml(f)
37
+ assert get_model_type(configs) != CosyVoice2Model, 'do not use {} for CosyVoice initialization!'.format(model_dir)
38
+ self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
39
+ configs['feat_extractor'],
40
+ '{}/campplus.onnx'.format(model_dir),
41
+ '{}/speech_tokenizer_v1.onnx'.format(model_dir),
42
+ '{}/spk2info.pt'.format(model_dir),
43
+ configs['allowed_special'],
44
+ gpu_id=gpu_id)
45
+ self.sample_rate = configs['sample_rate']
46
+ if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
47
+ load_jit, load_trt, fp16 = False, False, False
48
+ logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
49
+ self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16, gpu_id=gpu_id)
50
+ self.model.load('{}/llm.pt'.format(model_dir),
51
+ '{}/flow.pt'.format(model_dir),
52
+ '{}/hift.pt'.format(model_dir))
53
+ if load_jit:
54
+ self.model.load_jit('{}/llm.text_encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
55
+ '{}/llm.llm.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
56
+ '{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
57
+ if load_trt:
58
+ self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
59
+ '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
60
+ self.fp16)
61
+ del configs
62
+
63
+ def list_available_spks(self):
64
+ spks = list(self.frontend.spk2info.keys())
65
+ return spks
66
+
67
+ def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
68
+ for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
69
+ model_input = self.frontend.frontend_sft(i, spk_id)
70
+ start_time = time.time()
71
+ logging.info('synthesis text {}'.format(i))
72
+ for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
73
+ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
74
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
75
+ yield model_output
76
+ start_time = time.time()
77
+
78
+ def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True, token_list=None):
79
+ prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend)
80
+ for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
81
+ if (not isinstance(i, Generator)) and len(i) < 0.5 * len(prompt_text):
82
+ logging.warning('synthesis text {} is much shorter than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
83
+ model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate)
84
+ start_time = time.time()
85
+ logging.info('synthesis text {}'.format(i))
86
+ # import pdb;pdb.set_trace()
87
+ for model_output in self.model.tts(**model_input, stream=stream, speed=speed, token_list=token_list):
88
+ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
89
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
90
+ return model_output
91
+
92
+
93
+ def inference_zero_shot_gxl(self,tts_text, prompt_text,prompt_speech_16k, stream=False, speed=1.0, text_frontend=True, token_list=None):
94
+ prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend)
95
+ input_text = self.frontend.text_normalize(tts_text, split=False, text_frontend=text_frontend)
96
+ model_input = self.frontend.frontend_zero_shot(input_text, prompt_text, prompt_speech_16k, self.sample_rate)
97
+ start_time = time.time()
98
+ model_output = self.model.tts_gxl(**model_input, stream=stream, speed=speed, token_list=token_list)
99
+ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
100
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
101
+ return model_output
102
+
103
+
104
+ def inference_zero_shot_gz_22k(self,tts_text, prompt_text,prompt_speech_22k, stream=False, speed=1.0, text_frontend=True, token_list=None):
105
+ prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend)
106
+ input_text = self.frontend.text_normalize(tts_text, split=False, text_frontend=text_frontend)
107
+ model_input = self.frontend.frontend_zero_shot_22k(input_text, prompt_text, prompt_speech_22k, self.sample_rate)
108
+ start_time = time.time()
109
+ model_output = self.model.tts_gxl(**model_input, stream=stream, speed=speed, token_list=token_list)
110
+ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
111
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
112
+ return model_output
113
+
114
+
115
+ def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
116
+ for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
117
+ model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate)
118
+ start_time = time.time()
119
+ logging.info('synthesis text {}'.format(i))
120
+ for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
121
+ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
122
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
123
+ yield model_output
124
+ start_time = time.time()
125
+
126
+ def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0, text_frontend=True):
127
+ assert isinstance(self.model, CosyVoiceModel), 'inference_instruct is only implemented for CosyVoice!'
128
+ if self.instruct is False:
129
+ raise ValueError('{} do not support instruct inference'.format(self.model_dir))
130
+ instruct_text = self.frontend.text_normalize(instruct_text, split=False, text_frontend=text_frontend)
131
+ for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
132
+ model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
133
+ start_time = time.time()
134
+ logging.info('synthesis text {}'.format(i))
135
+ for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
136
+ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
137
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
138
+ yield model_output
139
+ start_time = time.time()
140
+
141
+ def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
142
+ model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate)
143
+ start_time = time.time()
144
+ for model_output in self.model.vc(**model_input, stream=stream, speed=speed):
145
+ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
146
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
147
+ yield model_output
148
+ start_time = time.time()
149
+
150
+
151
+ class CosyVoice2(CosyVoice):
152
+
153
+ def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
154
+ self.instruct = True if '-Instruct' in model_dir else False
155
+ self.model_dir = model_dir
156
+ self.fp16 = fp16
157
+ if not os.path.exists(model_dir):
158
+ model_dir = snapshot_download(model_dir)
159
+ with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
160
+ configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
161
+ assert get_model_type(configs) == CosyVoice2Model, 'do not use {} for CosyVoice2 initialization!'.format(model_dir)
162
+ self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
163
+ configs['feat_extractor'],
164
+ '{}/campplus.onnx'.format(model_dir),
165
+ '{}/speech_tokenizer_v2.onnx'.format(model_dir),
166
+ '{}/spk2info.pt'.format(model_dir),
167
+ configs['allowed_special'])
168
+ self.sample_rate = configs['sample_rate']
169
+ if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
170
+ load_jit, load_trt, fp16 = False, False, False
171
+ logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
172
+ self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
173
+ self.model.load('{}/llm.pt'.format(model_dir),
174
+ '{}/flow.pt'.format(model_dir),
175
+ '{}/hift.pt'.format(model_dir))
176
+ if load_jit:
177
+ self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
178
+ if load_trt:
179
+ self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
180
+ '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
181
+ self.fp16)
182
+ del configs
183
+
184
+ def inference_instruct(self, *args, **kwargs):
185
+ raise NotImplementedError('inference_instruct is not implemented for CosyVoice2!')
186
+
187
+ def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
188
+ assert isinstance(self.model, CosyVoice2Model), 'inference_instruct2 is only implemented for CosyVoice2!'
189
+ for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
190
+ model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate)
191
+ start_time = time.time()
192
+ logging.info('synthesis text {}'.format(i))
193
+ for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
194
+ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
195
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
196
+ yield model_output
197
+ start_time = time.time()
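A hedged usage sketch of the non-streaming zero-shot helper added in this file (model directory, prompt audio, and texts are placeholders; the prompt wav is assumed to already be 16 kHz mono):

import torchaudio
from cosyvoice.cli.cosyvoice import CosyVoice

cosyvoice = CosyVoice("pretrained_models/CosyVoice-300M", gpu_id=0)
prompt_speech_16k, _ = torchaudio.load("prompt_16k.wav")
out = cosyvoice.inference_zero_shot_gxl("Hello, nice to meet you.", "transcript of the prompt audio", prompt_speech_16k)
torchaudio.save("output.wav", out["tts_speech"], cosyvoice.sample_rate)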
tts/cosyvoice/cli/frontend.py ADDED
@@ -0,0 +1,240 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from functools import partial
15
+ from typing import Generator
16
+ import json
17
+ import onnxruntime
18
+ import torch
19
+ import numpy as np
20
+ import whisper
21
+ from typing import Callable
22
+ import torchaudio.compliance.kaldi as kaldi
23
+ import torchaudio
24
+ import os
25
+ import re
26
+ import inflect
27
+ try:
28
+ import ttsfrd
29
+ use_ttsfrd = True
30
+ except ImportError:
31
+ print("failed to import ttsfrd, use WeTextProcessing instead")
32
+ # from tn.chinese.normalizer import Normalizer as ZhNormalizer
33
+ # from tn.english.normalizer import Normalizer as EnNormalizer
34
+ use_ttsfrd = False
35
+ from cosyvoice.utils.file_utils import logging
36
+ from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation
37
+ is_npu = True
38
+ try:
39
+ import torch_npu
40
+ except ImportError:
41
+ is_npu = False
42
+ print("failed to import torch_npu")
43
+
44
+ class CosyVoiceFrontEnd:
45
+
46
+ def __init__(self,
47
+ get_tokenizer: Callable,
48
+ feat_extractor: Callable,
49
+ campplus_model: str,
50
+ speech_tokenizer_model: str,
51
+ spk2info: str = '',
52
+ allowed_special: str = 'all',
53
+ gpu_id: int = 0):
54
+ self.tokenizer = get_tokenizer()
55
+ self.feat_extractor = feat_extractor
56
+ if is_npu:
57
+ self.device = torch.device(f'npu:{gpu_id}')
58
+ else:
59
+ self.device = torch.device(f'cuda:{gpu_id}')
60
+ option = onnxruntime.SessionOptions()
61
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
62
+ option.intra_op_num_threads = 1
63
+ self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
64
+ self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
65
+ providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
66
+ "CPUExecutionProvider"])
67
+ if os.path.exists(spk2info):
68
+ self.spk2info = torch.load(spk2info, map_location=self.device)
69
+ else:
70
+ self.spk2info = {}
71
+ self.allowed_special = allowed_special
72
+ self.use_ttsfrd = use_ttsfrd
73
+ if self.use_ttsfrd:
74
+ self.frd = ttsfrd.TtsFrontendEngine()
75
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
76
+ assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
77
+ 'failed to initialize ttsfrd resource'
78
+ self.frd.set_lang_type('pinyinvg')
79
+ else:
80
+ self.zh_tn_model = lambda x: x #ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True)
81
+ self.en_tn_model = lambda x: x #EnNormalizer()
82
+ self.inflect_parser = inflect.engine()
83
+
84
+ def _extract_text_token(self, text):
85
+ if isinstance(text, Generator):
86
+ logging.info('get tts_text generator, will return _extract_text_token_generator!')
87
+ # NOTE add a dummy text_token_len for compatibility
88
+ return self._extract_text_token_generator(text), torch.tensor([0], dtype=torch.int32).to(self.device)
89
+ else:
90
+ text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
91
+ text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
92
+ text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
93
+ return text_token, text_token_len
94
+
95
+ def _extract_text_token_generator(self, text_generator):
96
+ for text in text_generator:
97
+ text_token, _ = self._extract_text_token(text)
98
+ for i in range(text_token.shape[1]):
99
+ yield text_token[:, i: i + 1]
100
+
101
+ def _extract_speech_token(self, speech):
102
+ assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
103
+ feat = whisper.log_mel_spectrogram(speech, n_mels=128)
104
+ speech_token = self.speech_tokenizer_session.run(None,
105
+ {self.speech_tokenizer_session.get_inputs()[0].name:
106
+ feat.detach().cpu().numpy(),
107
+ self.speech_tokenizer_session.get_inputs()[1].name:
108
+ np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
109
+ speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
110
+ speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
111
+ return speech_token, speech_token_len
112
+
113
+ def _extract_spk_embedding(self, speech):
114
+ feat = kaldi.fbank(speech,
115
+ num_mel_bins=80,
116
+ dither=0,
117
+ sample_frequency=16000)
118
+ feat = feat - feat.mean(dim=0, keepdim=True)
119
+ embedding = self.campplus_session.run(None,
120
+ {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
121
+ embedding = torch.tensor([embedding]).to(self.device)
122
+ return embedding
123
+
124
+ def _extract_speech_feat(self, speech):
125
+ speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
126
+ speech_feat = speech_feat.unsqueeze(dim=0)
127
+ speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
128
+ return speech_feat, speech_feat_len
129
+
130
+ def text_normalize(self, text, split=True, text_frontend=True):
131
+ if isinstance(text, Generator):
132
+ logging.info('get tts_text generator, will skip text_normalize!')
133
+ return [text]
134
+ if text_frontend is False:
135
+ return [text] if split is True else text
136
+ text = text.strip()
137
+ if self.use_ttsfrd:
138
+ texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
139
+ text = ''.join(texts)
140
+ else:
141
+ if contains_chinese(text):
142
+ # text = self.zh_tn_model.normalize(text)
143
+ text = text.replace("\n", "")
144
+ text = replace_blank(text)
145
+ text = replace_corner_mark(text)
146
+ text = text.replace(".", "。")
147
+ text = text.replace(" - ", ",")
148
+ text = remove_bracket(text)
149
+ text = re.sub(r'[,,、]+$', '。', text)
150
+ texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
151
+ token_min_n=60, merge_len=20, comma_split=False))
152
+ else:
153
+ # text = self.en_tn_model.normalize(text)
154
+ text = spell_out_number(text, self.inflect_parser)
155
+ texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
156
+ token_min_n=60, merge_len=20, comma_split=False))
157
+ texts = [i for i in texts if not is_only_punctuation(i)]
158
+ return texts if split is True else text
159
+
160
+ def frontend_sft(self, tts_text, spk_id):
161
+ tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
162
+ embedding = self.spk2info[spk_id]['embedding']
163
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
164
+ return model_input
165
+
166
+ def frontend_zero_shot_22k(self, tts_text, prompt_text, prompt_speech_22k, resample_rate=16000):
167
+ tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
168
+ prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
169
+ prompt_speech_16k = torchaudio.transforms.Resample(orig_freq=22050, new_freq=16000)(prompt_speech_22k)
170
+ speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_22k)
171
+ speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
172
+ # if resample_rate == 16000:
173
+ # # cosyvoice2, force speech_feat % speech_token = 2
174
+ # token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
175
+ # speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
176
+ # speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
177
+ embedding = self._extract_spk_embedding(prompt_speech_16k)
178
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
179
+ 'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
180
+ 'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
181
+ 'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
182
+ 'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
183
+ 'llm_embedding': embedding, 'flow_embedding': embedding}
184
+ return model_input
185
+
186
+ def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, resample_rate):
187
+ tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
188
+ prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
189
+ prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
190
+ speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
191
+ speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
192
+ if resample_rate == 24000:
193
+ # cosyvoice2, force speech_feat % speech_token = 2
194
+ token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
195
+ speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
196
+ speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
197
+ embedding = self._extract_spk_embedding(prompt_speech_16k)
198
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
199
+ 'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
200
+ 'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
201
+ 'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
202
+ 'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
203
+ 'llm_embedding': embedding, 'flow_embedding': embedding}
204
+ return model_input
205
+
206
+ def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate):
207
+ model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate)
208
+ # in cross lingual mode, we remove prompt in llm
209
+ del model_input['prompt_text']
210
+ del model_input['prompt_text_len']
211
+ del model_input['llm_prompt_speech_token']
212
+ del model_input['llm_prompt_speech_token_len']
213
+ return model_input
214
+
215
+ def frontend_instruct(self, tts_text, spk_id, instruct_text):
216
+ model_input = self.frontend_sft(tts_text, spk_id)
217
+ # in instruct mode, we remove spk_embedding in llm due to information leakage
218
+ del model_input['llm_embedding']
219
+ instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text + '<endofprompt>')
220
+ model_input['prompt_text'] = instruct_text_token
221
+ model_input['prompt_text_len'] = instruct_text_token_len
222
+ return model_input
223
+
224
+ def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate):
225
+ model_input = self.frontend_zero_shot(tts_text, instruct_text + '<|endofprompt|>', prompt_speech_16k, resample_rate)
226
+ del model_input['llm_prompt_speech_token']
227
+ del model_input['llm_prompt_speech_token_len']
228
+ return model_input
229
+
230
+ def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate):
231
+ prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
232
+ prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
233
+ prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
234
+ embedding = self._extract_spk_embedding(prompt_speech_16k)
235
+ source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
236
+ model_input = {'source_speech_token': source_speech_token, 'source_speech_token_len': source_speech_token_len,
237
+ 'flow_prompt_speech_token': prompt_speech_token, 'flow_prompt_speech_token_len': prompt_speech_token_len,
238
+ 'prompt_speech_feat': prompt_speech_feat, 'prompt_speech_feat_len': prompt_speech_feat_len,
239
+ 'flow_embedding': embedding}
240
+ return model_input
tts/cosyvoice/cli/model.py ADDED
@@ -0,0 +1,480 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ from typing import Generator
16
+ import torch
17
+ import numpy as np
18
+ import threading
19
+ import time
20
+ from torch.nn import functional as F
21
+ from contextlib import nullcontext
22
+ import uuid
23
+ from cosyvoice.utils.common import fade_in_out
24
+ from cosyvoice.utils.file_utils import convert_onnx_to_trt
25
+ is_npu = True
26
+ try:
27
+ import torch_npu
28
+ except ImportError:
29
+ is_npu = False
30
+    print('torch_npu not found, setting is_npu to False')
31
+
32
+ class CosyVoiceModel:
33
+
34
+ def __init__(self,
35
+ llm: torch.nn.Module,
36
+ flow: torch.nn.Module,
37
+ hift: torch.nn.Module,
38
+ fp16: bool,
39
+ gpu_id: int = 0):
40
+ if is_npu:
41
+ self.device = torch.device(f'npu:{gpu_id}')
42
+ else:
43
+ self.device = torch.device(f'cuda:{gpu_id}')
44
+ self.llm = llm
45
+ self.flow = flow
46
+ self.hift = hift
47
+ self.fp16 = fp16
48
+ self.llm.fp16 = fp16
49
+ self.flow.fp16 = fp16
50
+ if self.fp16 is True:
51
+ self.llm.half()
52
+ self.flow.half()
53
+ self.token_min_hop_len = 2 * self.flow.input_frame_rate
54
+ self.token_max_hop_len = 4 * self.flow.input_frame_rate
55
+ self.token_overlap_len = 20
56
+        # here we fix flow.decoder.estimator.static_chunk_size = 0 for compatibility
57
+ self.flow.decoder.estimator.static_chunk_size = 0
58
+ # mel fade in out
59
+ self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
60
+ self.mel_window = np.hamming(2 * self.mel_overlap_len)
61
+ # hift cache
62
+ self.mel_cache_len = 20
63
+ self.source_cache_len = int(self.mel_cache_len * 256)
64
+ # speech fade in out
65
+ self.speech_window = np.hamming(2 * self.source_cache_len)
66
+ # rtf and decoding related
67
+ self.stream_scale_factor = 1
68
+        assert self.stream_scale_factor >= 1, 'stream_scale_factor should be at least 1, increase it according to your actual rtf'
69
+ self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
70
+ self.lock = threading.Lock()
71
+ # dict used to store session related variable
72
+ self.tts_speech_token_dict = {}
73
+ self.llm_end_dict = {}
74
+ self.mel_overlap_dict = {}
75
+ self.flow_cache_dict = {}
76
+ self.hift_cache_dict = {}
77
+
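To make the streaming constants set in __init__ above concrete, a small worked example follows; the 25 Hz token rate and the 22.05 kHz / 256-sample mel hop are assumptions typical of CosyVoice flow configs, not values taken from this diff:

    # Assumed config: flow.input_frame_rate = 25 speech tokens per second.
    input_frame_rate = 25
    token_min_hop_len = 2 * input_frame_rate                                   # 50 tokens, about 2 s per hop
    token_max_hop_len = 4 * input_frame_rate                                   # 100 tokens, about 4 s per hop
    token_overlap_len = 20                                                     # tokens kept for the cross-fade
    mel_overlap_len = int(token_overlap_len / input_frame_rate * 22050 / 256)  # 68 mel frames of overlap
    source_cache_len = 20 * 256                                                # 5120 waveform samples cached for hift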
78
+ def load(self, llm_model, flow_model, hift_model):
79
+ self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
80
+ self.llm.to(self.device).eval()
81
+ self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
82
+ self.flow.to(self.device).eval()
83
+ # in case hift_model is a hifigan model
84
+ hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
85
+ self.hift.load_state_dict(hift_state_dict, strict=True)
86
+ self.hift.to(self.device).eval()
87
+
88
+ def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
89
+ llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
90
+ self.llm.text_encoder = llm_text_encoder
91
+ llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
92
+ self.llm.llm = llm_llm
93
+ flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
94
+ self.flow.encoder = flow_encoder
95
+
96
+ def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16):
97
+ assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
98
+ if not os.path.exists(flow_decoder_estimator_model):
99
+ convert_onnx_to_trt(flow_decoder_estimator_model, flow_decoder_onnx_model, fp16)
100
+ if os.path.getsize(flow_decoder_estimator_model) == 0:
101
+ raise ValueError('{} is empty file, delete it and export again!'.format(flow_decoder_estimator_model))
102
+ del self.flow.decoder.estimator
103
+ import tensorrt as trt
104
+ with open(flow_decoder_estimator_model, 'rb') as f:
105
+ self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
106
+ if self.flow.decoder.estimator_engine is None:
107
+ raise ValueError('failed to load trt {}'.format(flow_decoder_estimator_model))
108
+ self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()
109
+
110
+ def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
111
+ with self.llm_context:
112
+ if isinstance(text, Generator):
113
+ assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!'
114
+ for i in self.llm.inference_bistream(text=text,
115
+ prompt_text=prompt_text.to(self.device),
116
+ prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
117
+ prompt_speech_token=llm_prompt_speech_token.to(self.device),
118
+ prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
119
+ embedding=llm_embedding.to(self.device)):
120
+ self.tts_speech_token_dict[uuid].append(i)
121
+ else:
122
+ for i in self.llm.inference(text=text.to(self.device),
123
+ text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
124
+ prompt_text=prompt_text.to(self.device),
125
+ prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
126
+ prompt_speech_token=llm_prompt_speech_token.to(self.device),
127
+ prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
128
+ embedding=llm_embedding.to(self.device)):
129
+ self.tts_speech_token_dict[uuid].append(i)
130
+ self.llm_end_dict[uuid] = True
131
+
132
+ def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
133
+ tts_mel, flow_cache = self.flow.inference(token=token.to(self.device),
134
+ token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
135
+ prompt_token=prompt_token.to(self.device),
136
+ prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
137
+ prompt_feat=prompt_feat.to(self.device),
138
+ prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
139
+ embedding=embedding.to(self.device),
140
+ flow_cache=self.flow_cache_dict[uuid])
141
+ self.flow_cache_dict[uuid] = flow_cache
142
+
143
+ # mel overlap fade in out
144
+ if self.mel_overlap_dict[uuid].shape[2] != 0:
145
+ tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
146
+ # append hift cache
147
+ if self.hift_cache_dict[uuid] is not None:
148
+ hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
149
+ tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
150
+ else:
151
+ hift_cache_source = torch.zeros(1, 1, 0)
152
+ # keep overlap mel and hift cache
153
+ if finalize is False:
154
+ self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
155
+ tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
156
+ tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
157
+ if self.hift_cache_dict[uuid] is not None:
158
+ tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
159
+ self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
160
+ 'source': tts_source[:, :, -self.source_cache_len:],
161
+ 'speech': tts_speech[:, -self.source_cache_len:]}
162
+ tts_speech = tts_speech[:, :-self.source_cache_len]
163
+ else:
164
+ if speed != 1.0:
165
+ assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
166
+ tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
167
+ tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
168
+ if self.hift_cache_dict[uuid] is not None:
169
+ tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
170
+ return tts_speech
171
+
172
+ def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
173
+ prompt_text=torch.zeros(1, 0, dtype=torch.int32),
174
+ llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
175
+ flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
176
+ prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0,
177
+ token_list=None, **kwargs):
178
+ # this_uuid is used to track variables related to this inference thread
179
+ this_uuid = str(uuid.uuid1())
180
+ with self.lock:
181
+ self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
182
+ self.hift_cache_dict[this_uuid] = None
183
+ self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
184
+ self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
185
+ p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
186
+ p.start()
187
+ # import pdb;pdb.set_trace()
188
+ if stream is True:
189
+ token_hop_len = self.token_min_hop_len
190
+ while True:
191
+ time.sleep(0.1)
192
+ if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
193
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
194
+ .unsqueeze(dim=0)
195
+ # import pdb;pdb.set_trace()
196
+ gen_token = [1650, 2163, 3062, 41, 347, 754, 1705, 73, 38, 2583, 59, 1660, 1716, 28, 324, 1260, 1018, 254, 1650, 3552, 1804, 2515, 2368, 38, 1660, 3106, 848, 3250, 1611, 511, 1037, 2964, 2255, 1509, 890, 1494, 2250, 1349, 2621, 3420, 46, 2646, 2646, 3025, 2579, 393, 824, 1609, 2089, 2162, 24, 2, 3768, 1155, 343, 325, 2764, 814, 426, 1243, 2579, 3916, 20, 1611, 349, 701, 1346, 3768, 927, 3305, 8, 2099, 511, 3582, 8, 421, 1494, 2323, 2253, 3607, 692, 3929, 511, 3710, 3662, 3179, 1204, 7, 2579, 2579, 3025, 3025, 571, 540, 1509, 2786, 2548, 1404, 699, 1260, 2250, 202, 202, 84, 3458, 73, 3458, 1716, 302, 2105, 193, 974, 3761, 2893, 2250, 193, 754, 69, 69, 599, 2554, 890, 1608, 148, 1243, 480, 1, 489, 271, 1038, 1736, 1865, 3337, 569, 28, 2246, 2426, 2250, 3768, 569, 1027, 3305, 3106, 8, 3635, 269, 1854, 70, 1385, 1584, 1385, 2187, 3064, 3064, 2579, 3025, 3337, 2579, 3768]
197
+ token_list = [66, 2307, 599, 1602, 714, 1100, 1243, 2657, 349, 535, 3662, 1403, 2610, 669, 569, 49, 48, 1027, 2684, 373, 728, 728, 186, 186, 7, 2250, 754, 1346, 1289, 2691, 3740, 3082, 629, 2841, 432, 1513, 1716, 302, 3607, 3607, 692, 1609, 2579, 3025, 2513, 2513, 1043, 1043, 2704, 53, 2893, 1043, 2704, 1043, 2513, 2513, 1043, 1083, 3600, 421, 8, 8, 1256, 1243, 3278, 2932, 510, 2515, 2582, 1906, 4056, 1346, 1241, 2253, 1346, 1698, 962, 409, 1507, 1377, 2162, 10, 21, 396, 3649, 373, 728, 2513, 2513, 2513, 2513, 1865, 1893, 1712, 375, 4064, 3062, 41, 569, 3887, 1716, 472, 3830, 186, 408, 203, 3478, 3340, 800, 1243, 480, 271, 2162, 3240, 3238, 3193, 599, 2391, 1317, 1346, 269, 2253, 2209, 8, 1974, 2764, 1579, 421, 1073, 3929, 590, 31, 3898, 53, 53, 1043, 1957]
198
+ this_tts_speech_token = np.array(token_list)
199
+ this_tts_speech_token = torch.tensor(this_tts_speech_token)
200
+ # this_tts_speech_token = np.load("/home/node57_data/hkxie/4O/streaming_fm/data/s3token1/05343304771_EIjYa_VAD27_3.hubert_code.npy")
201
+ # this_tts_speech_token = torch.tensor(this_tts_speech_token)
202
+
203
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
204
+ prompt_token=flow_prompt_speech_token,
205
+ prompt_feat=prompt_speech_feat,
206
+ embedding=flow_embedding,
207
+ uuid=this_uuid,
208
+ finalize=False)
209
+ yield {'tts_speech': this_tts_speech.cpu()}
210
+ with self.lock:
211
+ self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
212
+ # increase token_hop_len for better speech quality
213
+ token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
214
+ if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
215
+ break
216
+ p.join()
217
+            # deal with the remaining tokens; make sure the remaining token length equals token_hop_len when cache_speech is not None
218
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
219
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
220
+ prompt_token=flow_prompt_speech_token,
221
+ prompt_feat=prompt_speech_feat,
222
+ embedding=flow_embedding,
223
+ uuid=this_uuid,
224
+ finalize=True)
225
+ yield {'tts_speech': this_tts_speech.cpu()}
226
+ else:
227
+ # deal with all tokens
228
+ p.join()
229
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
230
+ this_tts_speech_token = np.array(token_list)
231
+ this_tts_speech_token = torch.tensor(this_tts_speech_token)
232
+            this_tts_speech_token = this_tts_speech_token.unsqueeze(dim=0)
233
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
234
+ prompt_token=flow_prompt_speech_token,
235
+ prompt_feat=prompt_speech_feat,
236
+ embedding=flow_embedding,
237
+ uuid=this_uuid,
238
+ finalize=True,
239
+ speed=speed)
240
+ yield {'tts_speech': this_tts_speech.cpu()}
241
+ with self.lock:
242
+ self.tts_speech_token_dict.pop(this_uuid)
243
+ self.llm_end_dict.pop(this_uuid)
244
+ self.mel_overlap_dict.pop(this_uuid)
245
+ self.hift_cache_dict.pop(this_uuid)
246
+ self.flow_cache_dict.pop(this_uuid)
247
+ torch.cuda.empty_cache()
248
+
249
+ def tts_gxl(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
250
+ prompt_text=torch.zeros(1, 0, dtype=torch.int32),
251
+ llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
252
+ flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
253
+ prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0,
254
+ token_list=None, **kwargs):
255
+ # this_uuid is used to track variables related to this inference thread
256
+ this_uuid = str(uuid.uuid1())
257
+ with self.lock:
258
+ self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
259
+ self.hift_cache_dict[this_uuid] = None
260
+ self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
261
+ self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
262
+ # p = threading.Thread(target=self.llm_job,
263
+ # args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
264
+ # p.start()
265
+ # p.join()
266
+ # this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
267
+ this_tts_speech_token = np.array(token_list)
268
+ this_tts_speech_token = torch.tensor(this_tts_speech_token)
269
+        this_tts_speech_token = this_tts_speech_token.unsqueeze(dim=0)
270
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
271
+ prompt_token=flow_prompt_speech_token,
272
+ prompt_feat=prompt_speech_feat,
273
+ embedding=flow_embedding,
274
+ uuid=this_uuid,
275
+ finalize=True,
276
+ speed=speed)
277
+ torch.cuda.empty_cache()
278
+ with self.lock:
279
+ self.tts_speech_token_dict.pop(this_uuid)
280
+ self.llm_end_dict.pop(this_uuid)
281
+ self.mel_overlap_dict.pop(this_uuid)
282
+ self.hift_cache_dict.pop(this_uuid)
283
+ self.flow_cache_dict.pop(this_uuid)
284
+ return {'tts_speech': this_tts_speech.cpu()}
285
+
286
+ def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs):
287
+ # this_uuid is used to track variables related to this inference thread
288
+ this_uuid = str(uuid.uuid1())
289
+ with self.lock:
290
+ self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True
291
+ self.hift_cache_dict[this_uuid] = None
292
+ self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
293
+ self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
294
+ if stream is True:
295
+ token_hop_len = self.token_min_hop_len
296
+ while True:
297
+ if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
298
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
299
+ .unsqueeze(dim=0)
300
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
301
+ prompt_token=flow_prompt_speech_token,
302
+ prompt_feat=prompt_speech_feat,
303
+ embedding=flow_embedding,
304
+ uuid=this_uuid,
305
+ finalize=False)
306
+ yield {'tts_speech': this_tts_speech.cpu()}
307
+ with self.lock:
308
+ self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
309
+ # increase token_hop_len for better speech quality
310
+ token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
311
+ if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
312
+ break
313
+            # deal with the remaining tokens; make sure the remaining token length equals token_hop_len when cache_speech is not None
314
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
315
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
316
+ prompt_token=flow_prompt_speech_token,
317
+ prompt_feat=prompt_speech_feat,
318
+ embedding=flow_embedding,
319
+ uuid=this_uuid,
320
+ finalize=True)
321
+ yield {'tts_speech': this_tts_speech.cpu()}
322
+ else:
323
+ # deal with all tokens
324
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
325
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
326
+ prompt_token=flow_prompt_speech_token,
327
+ prompt_feat=prompt_speech_feat,
328
+ embedding=flow_embedding,
329
+ uuid=this_uuid,
330
+ finalize=True,
331
+ speed=speed)
332
+ yield {'tts_speech': this_tts_speech.cpu()}
333
+ with self.lock:
334
+ self.tts_speech_token_dict.pop(this_uuid)
335
+ self.llm_end_dict.pop(this_uuid)
336
+ self.mel_overlap_dict.pop(this_uuid)
337
+ self.hift_cache_dict.pop(this_uuid)
338
+ torch.cuda.empty_cache()
339
+
340
+
341
+ class CosyVoice2Model(CosyVoiceModel):
342
+
343
+ def __init__(self,
344
+ llm: torch.nn.Module,
345
+ flow: torch.nn.Module,
346
+ hift: torch.nn.Module,
347
+ fp16: bool):
348
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
349
+ self.llm = llm
350
+ self.flow = flow
351
+ self.hift = hift
352
+ self.fp16 = fp16
353
+ self.llm.fp16 = fp16
354
+ self.flow.fp16 = fp16
355
+ if self.fp16 is True:
356
+ self.llm.half()
357
+ self.flow.half()
358
+ self.token_hop_len = 2 * self.flow.input_frame_rate
359
+        # here we fix the flow encoder/decoder decoding_chunk_size; in the future we will pass it as an argument, or use a cache
360
+ self.flow.encoder.static_chunk_size = 2 * self.flow.input_frame_rate
361
+ self.flow.decoder.estimator.static_chunk_size = 2 * self.flow.input_frame_rate * self.flow.token_mel_ratio
362
+ # hift cache
363
+ self.mel_cache_len = 8
364
+ self.source_cache_len = int(self.mel_cache_len * 480)
365
+ # speech fade in out
366
+ self.speech_window = np.hamming(2 * self.source_cache_len)
367
+ # rtf and decoding related
368
+ self.stream_scale_factor = 1
369
+ self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
370
+ self.lock = threading.Lock()
371
+ # dict used to store session related variable
372
+ self.tts_speech_token_dict = {}
373
+ self.llm_end_dict = {}
374
+ self.hift_cache_dict = {}
375
+
376
+ def load_jit(self, flow_encoder_model):
377
+ flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
378
+ self.flow.encoder = flow_encoder
379
+
380
+ def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0):
381
+ tts_mel, _ = self.flow.inference(token=token.to(self.device),
382
+ token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
383
+ prompt_token=prompt_token.to(self.device),
384
+ prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
385
+ prompt_feat=prompt_feat.to(self.device),
386
+ prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
387
+ embedding=embedding.to(self.device),
388
+ finalize=finalize)
389
+ tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
390
+ # append hift cache
391
+ if self.hift_cache_dict[uuid] is not None:
392
+ hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
393
+ tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
394
+ else:
395
+ hift_cache_source = torch.zeros(1, 1, 0)
396
+ # keep overlap mel and hift cache
397
+ if finalize is False:
398
+ tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
399
+ if self.hift_cache_dict[uuid] is not None:
400
+ tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
401
+ self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
402
+ 'source': tts_source[:, :, -self.source_cache_len:],
403
+ 'speech': tts_speech[:, -self.source_cache_len:]}
404
+ tts_speech = tts_speech[:, :-self.source_cache_len]
405
+ else:
406
+ if speed != 1.0:
407
+ assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
408
+ tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
409
+ tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
410
+ if self.hift_cache_dict[uuid] is not None:
411
+ tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
412
+ return tts_speech
413
+
414
+ def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
415
+ prompt_text=torch.zeros(1, 0, dtype=torch.int32),
416
+ llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
417
+ flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
418
+ prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
419
+ # this_uuid is used to track variables related to this inference thread
420
+ this_uuid = str(uuid.uuid1())
421
+ with self.lock:
422
+ self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
423
+ self.hift_cache_dict[this_uuid] = None
424
+ p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
425
+ p.start()
426
+ if stream is True:
427
+ token_offset = 0
428
+ while True:
429
+ time.sleep(0.1)
430
+ if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= self.token_hop_len + self.flow.pre_lookahead_len:
431
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
432
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
433
+ prompt_token=flow_prompt_speech_token,
434
+ prompt_feat=prompt_speech_feat,
435
+ embedding=flow_embedding,
436
+ uuid=this_uuid,
437
+ token_offset=token_offset,
438
+ finalize=False)
439
+ token_offset += self.token_hop_len
440
+ yield {'tts_speech': this_tts_speech.cpu()}
441
+ if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < self.token_hop_len + self.flow.pre_lookahead_len:
442
+ break
443
+ p.join()
444
+            # deal with the remaining tokens; make sure the remaining token length equals token_hop_len when cache_speech is not None
445
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
446
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
447
+ prompt_token=flow_prompt_speech_token,
448
+ prompt_feat=prompt_speech_feat,
449
+ embedding=flow_embedding,
450
+ uuid=this_uuid,
451
+ token_offset=token_offset,
452
+ finalize=True)
453
+ yield {'tts_speech': this_tts_speech.cpu()}
454
+ else:
455
+ # deal with all tokens
456
+ p.join()
457
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
458
+ # import pdb;pdb.set_trace()
459
+ # this_tts_speech_token = np.load("/home/node57_data/hkxie/4O/streaming_fm/data/s3token2/05343304771_EIjYa_VAD27_3.hubert_code.npy")
460
+ # this_tts_speech_token = np.load("/home/node57_data/hkxie/4O/streaming_fm/data/s3token2/05343304771_EIjYa_VAD41_6.hubert_code.npy")
461
+ # token2 = [2745, 860, 393, 393, 2579, 2926, 1842, 2136, 480, 205, 3910, 3251, 73, 42, 38, 1346, 2554, 368, 40, 1660, 1660, 1055, 2597, 1712, 28, 2246, 386, 122, 38, 3607, 3818, 1098, 980, 38, 1353, 1660, 426, 1694, 1406, 511, 511, 396, 671, 2571, 2809, 2385, 3947, 229, 2000, 773, 2786, 858, 2554, 701, 46, 2646, 1608, 2890, 393, 393, 393, 393, 393, 393, 393, 393, 393, 393, 393, 393, 3, 31, 758, 3438, 3438, 3438, 54, 269, 2246, 343, 1600, 1608, 3554, 3649, 60, 511, 701, 44, 3554, 3775, 20, 2099, 535, 2099, 3545, 3267, 1223, 1650, 3607, 3611, 2646, 3545, 3545, 802, 802, 393, 393, 393, 393, 393, 393, 393, 393, 393, 393, 393, 393, 393, 393, 393, 393, 3, 26, 1734, 571, 1240, 1509, 2786, 1509, 740, 890, 2426, 1241, 1241, 2399, 2, 3458, 2285, 25, 2105, 4082, 3761, 3121, 3121, 269, 4082, 1353, 2285, 463, 758, 1193, 421, 3662, 148, 1516, 101, 32, 615, 1660, 1038, 2597, 3554, 28, 2246, 2426, 1241, 22, 1406, 70, 2230, 2230, 3635, 302, 2537, 1385, 1385, 1385, 69, 754, 3489, 1055, 393, 393, 393, 393, 393, 393, 393, 393]
462
+
463
+ # token_list3 = [2745, 599, 3238, 2554, 84, 73, 42, 2582, 2583, 4082, 1660, 1584, 1469, 1712, 2243, 1260, 1688, 269, 409, 3552, 1584, 2646, 38, 2385, 1660, 1038, 1516, 85, 3250, 1611, 109, 3611, 2255, 3947, 229, 451, 2786, 1044, 2621, 4056, 2646, 2646, 2890, 31, 3898, 3898, 2893, 2893, 2893, 2893, 1043, 52, 52, 52, 52, 1504, 2307, 202, 229, 358, 358, 266, 2907, 1516, 2246, 343, 1030, 122, 2409, 1694, 1406, 511, 2209, 51, 927, 1185, 1256, 1879, 2890, 2858, 203, 2426, 2253, 69, 3011, 3611, 2515, 2646, 492, 3662, 1608, 7, 31, 1406, 1406, 2893, 1043, 728, 380, 380, 571, 2385, 229, 740, 3193, 358, 202, 3331, 2, 1796, 35, 2285, 1893, 1516, 329, 3761, 2859, 122, 1241, 329, 1906, 59, 460, 463, 2554, 740, 1608, 60, 1516, 101, 1, 489, 1038, 1038, 3337, 3768, 569, 32, 1494, 2250, 3768, 3649, 20, 351, 1404, 1193, 44, 59, 3607, 2174, 1584, 1584, 1584, 1655, 1736, 1043, 1043, 1469, 569, 28, 2000, 2426, 2250, 3768, 927, 3250, 8, 2099, 1716, 59, 792, 3106, 1385, 1385, 1385, 1385, 1385, 3947, 1507, 864, 52, 52, 52]
464
+ token_list3 = [997, 966, 3554, 1854, 714, 3761, 3741, 2426, 103, 103, 1260, 1260, 2306, 2306, 2307, 824, 792, 193, 1879, 3478, 48, 511, 3420, 1317, 1761, 599, 1002, 980, 2646, 2646, 2646, 2646, 2646, 3366, 1949, 575, 575, 26, 26, 29, 3929, 229, 3910, 568, 3265, 3768, 28, 2004, 3910, 568, 3265, 3062, 41, 927, 699, 304, 2859, 2537, 28, 3741, 2841, 1688, 3768, 28, 1155, 855, 1570, 1570, 1570, 1570, 1570, 2876, 2680, 3, 3, 3636, 1555, 2844, 409, 1040, 2515, 1640, 3121, 3153, 882, 2385, 1796, 1796, 1796, 2368, 1785, 49, 671, 3830, 3025, 2844, 2105, 1037, 1729, 2105, 3265, 103, 1346, 580, 3922, 2876, 42, 271, 59, 3106, 2680, 3830, 2704, 2105, 2815, 59, 1698, 1223, 1342, 3267, 2786, 2250, 2250, 2208, 3, 1446, 1446, 1446, 1446, 1446, 1446, 1446, 1446, 1446, 1446, 1688, 1688, 1446, 1446, 1688, 1688, 1688, 1688, 1688]
465
+ this_tts_speech_token = np.array(token_list3)
466
+ this_tts_speech_token = torch.tensor(this_tts_speech_token)
467
+            this_tts_speech_token = this_tts_speech_token.unsqueeze(dim=0)
468
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
469
+ prompt_token=flow_prompt_speech_token,
470
+ prompt_feat=prompt_speech_feat,
471
+ embedding=flow_embedding,
472
+ uuid=this_uuid,
473
+ token_offset=0,
474
+ finalize=True,
475
+ speed=speed)
476
+ yield {'tts_speech': this_tts_speech.cpu()}
477
+ with self.lock:
478
+ self.tts_speech_token_dict.pop(this_uuid)
479
+ self.llm_end_dict.pop(this_uuid)
480
+ torch.cuda.empty_cache()
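A hedged sketch (not part of the diff) of how the streaming generator above is typically consumed; `model` is assumed to be a loaded CosyVoice2Model and `model_input` a dict produced by the frontend shown earlier:

    import torch

    chunks = []
    for out in model.tts(**model_input, stream=True):
        chunks.append(out['tts_speech'])     # each chunk is a (1, n_samples) CPU tensor
    speech = torch.cat(chunks, dim=1)        # concatenate the cross-faded chunks into one waveform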
tts/cosyvoice/dataset/__init__.py ADDED
File without changes
tts/cosyvoice/dataset/dataset.py ADDED
@@ -0,0 +1,164 @@
1
+ # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
2
+ # 2024 Alibaba Inc (authors: Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import random
17
+ import json
18
+ import math
19
+ from functools import partial
20
+
21
+ import torch
22
+ import torch.distributed as dist
23
+ from torch.utils.data import IterableDataset
24
+ from cosyvoice.utils.file_utils import read_lists, read_json_lists
25
+
26
+
27
+ class Processor(IterableDataset):
28
+
29
+ def __init__(self, source, f, *args, **kw):
30
+ assert callable(f)
31
+ self.source = source
32
+ self.f = f
33
+ self.args = args
34
+ self.kw = kw
35
+
36
+ def set_epoch(self, epoch):
37
+ self.source.set_epoch(epoch)
38
+
39
+ def __iter__(self):
40
+ """ Return an iterator over the source dataset processed by the
41
+ given processor.
42
+ """
43
+ assert self.source is not None
44
+ assert callable(self.f)
45
+ return self.f(iter(self.source), *self.args, **self.kw)
46
+
47
+ def apply(self, f):
48
+ assert callable(f)
49
+ return Processor(self, f, *self.args, **self.kw)
50
+
51
+
52
+ class DistributedSampler:
53
+
54
+ def __init__(self, shuffle=True, partition=True):
55
+ self.epoch = -1
56
+ self.update()
57
+ self.shuffle = shuffle
58
+ self.partition = partition
59
+
60
+ def update(self):
61
+ assert dist.is_available()
62
+ if dist.is_initialized():
63
+ self.rank = dist.get_rank()
64
+ self.world_size = dist.get_world_size()
65
+ else:
66
+ self.rank = 0
67
+ self.world_size = 1
68
+ worker_info = torch.utils.data.get_worker_info()
69
+ if worker_info is None:
70
+ self.worker_id = 0
71
+ self.num_workers = 1
72
+ else:
73
+ self.worker_id = worker_info.id
74
+ self.num_workers = worker_info.num_workers
75
+ return dict(rank=self.rank,
76
+ world_size=self.world_size,
77
+ worker_id=self.worker_id,
78
+ num_workers=self.num_workers)
79
+
80
+ def set_epoch(self, epoch):
81
+ self.epoch = epoch
82
+
83
+ def sample(self, data):
84
+ """ Sample data according to rank/world_size/num_workers
85
+
86
+ Args:
87
+ data(List): input data list
88
+
89
+ Returns:
90
+ List: data list after sample
91
+ """
92
+ data = list(range(len(data)))
93
+        # force the data list to split evenly across ranks and workers
94
+ if self.partition:
95
+ if self.shuffle:
96
+ random.Random(self.epoch).shuffle(data)
97
+ if len(data) < self.world_size:
98
+ data = data * math.ceil(self.world_size / len(data))
99
+ data = data[:self.world_size]
100
+ data = data[self.rank::self.world_size]
101
+ if len(data) < self.num_workers:
102
+ data = data * math.ceil(self.num_workers / len(data))
103
+ data = data[:self.num_workers]
104
+ data = data[self.worker_id::self.num_workers]
105
+ return data
106
+
107
+
108
+ class DataList(IterableDataset):
109
+
110
+ def __init__(self, lists, shuffle=True, partition=True):
111
+ self.lists = lists
112
+ self.sampler = DistributedSampler(shuffle, partition)
113
+
114
+ def set_epoch(self, epoch):
115
+ self.sampler.set_epoch(epoch)
116
+
117
+ def __iter__(self):
118
+ sampler_info = self.sampler.update()
119
+ indexes = self.sampler.sample(self.lists)
120
+ for index in indexes:
121
+ data = dict(src=self.lists[index])
122
+ data.update(sampler_info)
123
+ yield data
124
+
125
+
126
+ def Dataset(data_list_file,
127
+ data_pipeline,
128
+ mode='train',
129
+ gan=False,
130
+ shuffle=True,
131
+ partition=True,
132
+ tts_file='',
133
+ prompt_utt2data=''):
134
+ """ Construct dataset from arguments
135
+
136
+        We have two shuffle stages in the Dataset. The first is global
137
+ shuffle at shard tar/raw file level. The second is global shuffle
138
+ at training samples level.
139
+
140
+ Args:
141
+            data_list_file(str): file listing the parquet data shards
142
+            data_pipeline(list): processor functions applied to the data in order
143
+ partition(bool): whether to do data partition in terms of rank
144
+ """
145
+ assert mode in ['train', 'inference']
146
+ lists = read_lists(data_list_file)
147
+ if mode == 'inference':
148
+ with open(tts_file) as f:
149
+ tts_data = json.load(f)
150
+ utt2lists = read_json_lists(prompt_utt2data)
151
+ # filter unnecessary file in inference mode
152
+ lists = list({utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists})
153
+ dataset = DataList(lists,
154
+ shuffle=shuffle,
155
+ partition=partition)
156
+ if mode == 'inference':
157
+ # map partial arg to parquet_opener func in inference mode
158
+ data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data)
159
+ if gan is True:
160
+ # map partial arg to padding func in gan mode
161
+ data_pipeline[-1] = partial(data_pipeline[-1], gan=gan)
162
+ for func in data_pipeline:
163
+ dataset = Processor(dataset, func, mode=mode)
164
+ return dataset
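Since Dataset returns an IterableDataset whose batching happens inside the pipeline, it is usually wrapped in a DataLoader with automatic batching disabled. A minimal sketch, assuming `train_dataset` was built with the Dataset factory above:

    from torch.utils.data import DataLoader

    # batch_size=None: the pipeline's batch/padding processors already produce batches.
    train_loader = DataLoader(train_dataset, batch_size=None, num_workers=2, pin_memory=True)
    train_dataset.set_epoch(0)          # re-seeds the shard-level shuffle for this epoch
    for batch in train_loader:          # each `batch` is the padded dict yielded by processor.padding
        pass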
tts/cosyvoice/dataset/processor.py ADDED
@@ -0,0 +1,435 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import logging
15
+ import random
16
+
17
+ import pyarrow.parquet as pq
18
+ from io import BytesIO
19
+ import torch
20
+ import torchaudio
21
+ from torch.nn.utils.rnn import pad_sequence
22
+ import torch.nn.functional as F
23
+ import pyworld as pw
24
+
25
+
26
+ AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
27
+
28
+
29
+ def parquet_opener(data, mode='train', tts_data={}):
30
+    """ Given a url or local parquet file, yield parsed samples
31
+ Inplace operation.
32
+
33
+ Args:
34
+ data(Iterable[str]): url or local file list
35
+
36
+ Returns:
37
+ Iterable[{src, stream}]
38
+ """
39
+ for sample in data:
40
+ assert 'src' in sample
41
+ url = sample['src']
42
+ try:
43
+ for df in pq.ParquetFile(url).iter_batches(batch_size=64):
44
+ df = df.to_pandas()
45
+ for i in range(len(df)):
46
+ if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
47
+ continue
48
+ sample.update(dict(df.loc[i]))
49
+ if mode == 'train':
50
+ # NOTE do not return sample directly, must initialize a new dict
51
+ yield {**sample}
52
+ else:
53
+ for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
54
+ yield {**sample, 'tts_index': index, 'tts_text': text}
55
+ except Exception as ex:
56
+ logging.warning('Failed to open {}, ex info {}'.format(url, ex))
57
+
58
+
59
+ def filter(data,
60
+ max_length=10240,
61
+ min_length=10,
62
+ token_max_length=200,
63
+ token_min_length=1,
64
+ min_output_input_ratio=0.0005,
65
+ max_output_input_ratio=1,
66
+ mode='train'):
67
+ """ Filter sample according to feature and label length
68
+ Inplace operation.
69
+
70
+        Args:
71
+ data: Iterable[{key, wav, label, sample_rate}]
72
+ max_length: drop utterance which is greater than max_length(10ms)
73
+ min_length: drop utterance which is less than min_length(10ms)
74
+ token_max_length: drop utterance which is greater than
75
+                token_max_length, especially when using char units for
76
+ english modeling
77
+ token_min_length: drop utterance which is
78
+                less than token_min_length
79
+            min_output_input_ratio: minimal ratio of
80
+ token_length / feats_length(10ms)
81
+            max_output_input_ratio: maximum ratio of
82
+ token_length / feats_length(10ms)
83
+
84
+ Returns:
85
+ Iterable[{key, wav, label, sample_rate}]
86
+ """
87
+ for sample in data:
88
+ sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
89
+ sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
90
+ del sample['audio_data']
91
+ # sample['wav'] is torch.Tensor, we have 100 frames every second
92
+ num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
93
+ if num_frames < min_length:
94
+ continue
95
+ if num_frames > max_length:
96
+ continue
97
+ if len(sample['text_token']) < token_min_length:
98
+ continue
99
+ if len(sample['text_token']) > token_max_length:
100
+ continue
101
+ if len(sample['speech_token']) == 0:
102
+ continue
103
+ if num_frames != 0:
104
+ if len(sample['text_token']) / num_frames < min_output_input_ratio:
105
+ continue
106
+ if len(sample['text_token']) / num_frames > max_output_input_ratio:
107
+ continue
108
+ yield sample
109
+
110
+
111
+ def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
112
+ """ Resample data.
113
+ Inplace operation.
114
+
115
+ Args:
116
+ data: Iterable[{key, wav, label, sample_rate}]
117
+ resample_rate: target resample rate
118
+
119
+ Returns:
120
+ Iterable[{key, wav, label, sample_rate}]
121
+ """
122
+ for sample in data:
123
+ assert 'sample_rate' in sample
124
+ assert 'speech' in sample
125
+ sample_rate = sample['sample_rate']
126
+ waveform = sample['speech']
127
+ if sample_rate != resample_rate:
128
+ if sample_rate < min_sample_rate:
129
+ continue
130
+ sample['sample_rate'] = resample_rate
131
+ sample['speech'] = torchaudio.transforms.Resample(
132
+ orig_freq=sample_rate, new_freq=resample_rate)(waveform)
133
+ max_val = sample['speech'].abs().max()
134
+ if max_val > 1:
135
+ sample['speech'] /= max_val
136
+ yield sample
137
+
138
+
139
+ def truncate(data, truncate_length=24576, mode='train'):
140
+ """ Truncate data.
141
+
142
+ Args:
143
+ data: Iterable[{key, wav, label, sample_rate}]
144
+ truncate_length: truncate length
145
+
146
+ Returns:
147
+ Iterable[{key, wav, label, sample_rate}]
148
+ """
149
+ for sample in data:
150
+ waveform = sample['speech']
151
+ if waveform.shape[1] > truncate_length:
152
+ start = random.randint(0, waveform.shape[1] - truncate_length)
153
+ waveform = waveform[:, start: start + truncate_length]
154
+ else:
155
+ waveform = torch.concat([waveform, torch.zeros(1, truncate_length - waveform.shape[1])], dim=1)
156
+ sample['speech'] = waveform
157
+ yield sample
158
+
159
+
160
+ def compute_fbank(data,
161
+ feat_extractor,
162
+ mode='train'):
163
+ """ Extract fbank
164
+
165
+ Args:
166
+ data: Iterable[{key, wav, label, sample_rate}]
167
+
168
+ Returns:
169
+ Iterable[{key, feat, label}]
170
+ """
171
+ for sample in data:
172
+ assert 'sample_rate' in sample
173
+ assert 'speech' in sample
174
+ assert 'utt' in sample
175
+ assert 'text_token' in sample
176
+ waveform = sample['speech']
177
+ mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
178
+ sample['speech_feat'] = mat
179
+ yield sample
180
+
181
+
182
+ def compute_f0(data, sample_rate, hop_size, mode='train'):
183
+ """ Extract f0
184
+
185
+ Args:
186
+ data: Iterable[{key, wav, label, sample_rate}]
187
+
188
+ Returns:
189
+ Iterable[{key, feat, label}]
190
+ """
191
+ frame_period = hop_size * 1000 / sample_rate
192
+ for sample in data:
193
+ assert 'sample_rate' in sample
194
+ assert 'speech' in sample
195
+ assert 'utt' in sample
196
+ assert 'text_token' in sample
197
+ waveform = sample['speech']
198
+ _f0, t = pw.harvest(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period)
199
+ if sum(_f0 != 0) < 5: # this happens when the algorithm fails
200
+ _f0, t = pw.dio(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period) # if harvest fails, try dio
201
+ f0 = pw.stonemask(waveform.squeeze(dim=0).numpy().astype('double'), _f0, t, sample_rate)
202
+ f0 = F.interpolate(torch.from_numpy(f0).view(1, 1, -1), size=sample['speech_feat'].shape[0], mode='linear').view(-1)
203
+ sample['pitch_feat'] = f0
204
+ yield sample
205
+
206
+
207
+ def parse_embedding(data, normalize, mode='train'):
208
+ """ Parse utt_embedding/spk_embedding
209
+
210
+ Args:
211
+ data: Iterable[{key, wav, label, sample_rate}]
212
+
213
+ Returns:
214
+ Iterable[{key, feat, label}]
215
+ """
216
+ for sample in data:
217
+ sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
218
+ sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
219
+ if normalize:
220
+ sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
221
+ sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
222
+ yield sample
223
+
224
+
225
+ def tokenize(data, get_tokenizer, allowed_special, mode='train'):
226
+ """ Decode text to chars or BPE
227
+ Inplace operation
228
+
229
+ Args:
230
+ data: Iterable[{key, wav, txt, sample_rate}]
231
+
232
+ Returns:
233
+ Iterable[{key, wav, txt, tokens, label, sample_rate}]
234
+ """
235
+ tokenizer = get_tokenizer()
236
+ for sample in data:
237
+ assert 'text' in sample
238
+ sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
239
+ if mode == 'inference':
240
+ sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special)
241
+ yield sample
242
+
243
+
244
+ def shuffle(data, shuffle_size=10000, mode='train'):
245
+    """ Locally shuffle the data
246
+
247
+ Args:
248
+ data: Iterable[{key, feat, label}]
249
+ shuffle_size: buffer size for shuffle
250
+
251
+ Returns:
252
+ Iterable[{key, feat, label}]
253
+ """
254
+ buf = []
255
+ for sample in data:
256
+ buf.append(sample)
257
+ if len(buf) >= shuffle_size:
258
+ random.shuffle(buf)
259
+ for x in buf:
260
+ yield x
261
+ buf = []
262
+ # The sample left over
263
+ random.shuffle(buf)
264
+ for x in buf:
265
+ yield x
266
+
267
+
268
+ def sort(data, sort_size=500, mode='train'):
269
+ """ Sort the data by feature length.
270
+ Sort is used after shuffle and before batch, so we can group
271
+ utts with similar lengths into a batch, and `sort_size` should
272
+ be less than `shuffle_size`
273
+
274
+ Args:
275
+ data: Iterable[{key, feat, label}]
276
+ sort_size: buffer size for sort
277
+
278
+ Returns:
279
+ Iterable[{key, feat, label}]
280
+ """
281
+
282
+ buf = []
283
+ for sample in data:
284
+ buf.append(sample)
285
+ if len(buf) >= sort_size:
286
+ buf.sort(key=lambda x: x['speech_feat'].size(0))
287
+ for x in buf:
288
+ yield x
289
+ buf = []
290
+ # The sample left over
291
+ buf.sort(key=lambda x: x['speech_feat'].size(0))
292
+ for x in buf:
293
+ yield x
294
+
295
+
296
+ def static_batch(data, batch_size=16):
297
+ """ Static batch the data by `batch_size`
298
+
299
+ Args:
300
+ data: Iterable[{key, feat, label}]
301
+ batch_size: batch size
302
+
303
+ Returns:
304
+ Iterable[List[{key, feat, label}]]
305
+ """
306
+ buf = []
307
+ for sample in data:
308
+ buf.append(sample)
309
+ if len(buf) >= batch_size:
310
+ yield buf
311
+ buf = []
312
+ if len(buf) > 0:
313
+ yield buf
314
+
315
+
316
+ def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
317
+ """ Dynamic batch the data until the total frames in batch
318
+ reach `max_frames_in_batch`
319
+
320
+ Args:
321
+ data: Iterable[{key, feat, label}]
322
+ max_frames_in_batch: max_frames in one batch
323
+
324
+ Returns:
325
+ Iterable[List[{key, feat, label}]]
326
+ """
327
+ buf = []
328
+ longest_frames = 0
329
+ for sample in data:
330
+ assert 'speech_feat' in sample
331
+ assert isinstance(sample['speech_feat'], torch.Tensor)
332
+ new_sample_frames = sample['speech_feat'].size(0)
333
+ longest_frames = max(longest_frames, new_sample_frames)
334
+ frames_after_padding = longest_frames * (len(buf) + 1)
335
+ if frames_after_padding > max_frames_in_batch:
336
+ yield buf
337
+ buf = [sample]
338
+ longest_frames = new_sample_frames
339
+ else:
340
+ buf.append(sample)
341
+ if len(buf) > 0:
342
+ yield buf
343
+
344
+
345
+ def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
346
+ """ Wrapper for static/dynamic batch
347
+ """
348
+ if mode == 'inference':
349
+ return static_batch(data, 1)
350
+ else:
351
+ if batch_type == 'static':
352
+ return static_batch(data, batch_size)
353
+ elif batch_type == 'dynamic':
354
+ return dynamic_batch(data, max_frames_in_batch)
355
+ else:
356
+ logging.fatal('Unsupported batch type {}'.format(batch_type))
357
+
358
+
359
+ def padding(data, use_spk_embedding, mode='train', gan=False):
360
+    """ Pad the data into training batches
361
+
362
+ Args:
363
+ data: Iterable[List[{key, feat, label}]]
364
+
365
+ Returns:
366
+ Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
367
+ """
368
+ for sample in data:
369
+ assert isinstance(sample, list)
370
+ speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
371
+ dtype=torch.int32)
372
+ order = torch.argsort(speech_feat_len, descending=True)
373
+
374
+ utts = [sample[i]['utt'] for i in order]
375
+ speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
376
+ speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
377
+ speech = pad_sequence(speech, batch_first=True, padding_value=0)
378
+ speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
379
+ speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
380
+ speech_token = pad_sequence(speech_token,
381
+ batch_first=True,
382
+ padding_value=0)
383
+ speech_feat = [sample[i]['speech_feat'] for i in order]
384
+ speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
385
+ speech_feat = pad_sequence(speech_feat,
386
+ batch_first=True,
387
+ padding_value=0)
388
+ text = [sample[i]['text'] for i in order]
389
+ text_token = [torch.tensor(sample[i]['text_token']) for i in order]
390
+ text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
391
+ text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
392
+ utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
393
+ spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
394
+ batch = {
395
+ "utts": utts,
396
+ "speech": speech,
397
+ "speech_len": speech_len,
398
+ "speech_token": speech_token,
399
+ "speech_token_len": speech_token_len,
400
+ "speech_feat": speech_feat,
401
+ "speech_feat_len": speech_feat_len,
402
+ "text": text,
403
+ "text_token": text_token,
404
+ "text_token_len": text_token_len,
405
+ "utt_embedding": utt_embedding,
406
+ "spk_embedding": spk_embedding,
407
+ }
408
+ if gan is True:
409
+ # in gan train, we need pitch_feat
410
+ pitch_feat = [sample[i]['pitch_feat'] for i in order]
411
+ pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
412
+ pitch_feat = pad_sequence(pitch_feat,
413
+ batch_first=True,
414
+ padding_value=0)
415
+ batch["pitch_feat"] = pitch_feat
416
+ batch["pitch_feat_len"] = pitch_feat_len
417
+ else:
418
+ # only gan train needs speech, delete it to save memory
419
+ del batch["speech"]
420
+ del batch["speech_len"]
421
+ if mode == 'inference':
422
+ tts_text = [sample[i]['tts_text'] for i in order]
423
+ tts_index = [sample[i]['tts_index'] for i in order]
424
+ tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
425
+ tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
426
+ tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
427
+ batch.update({'tts_text': tts_text,
428
+ 'tts_index': tts_index,
429
+ 'tts_text_token': tts_text_token,
430
+ 'tts_text_token_len': tts_text_token_len})
431
+ if use_spk_embedding is True:
432
+ batch["embedding"] = batch["spk_embedding"]
433
+ else:
434
+ batch["embedding"] = batch["utt_embedding"]
435
+ yield batch
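For reference, a hedged sketch of how these processors are composed into the `data_pipeline` list expected by Dataset in tts/cosyvoice/dataset/dataset.py. In CosyVoice-style setups this list is usually assembled from YAML; the `get_tokenizer` and `feat_extractor` factories below are placeholders, not names defined in this diff:

    from functools import partial

    from cosyvoice.dataset import processor
    from cosyvoice.dataset.dataset import Dataset

    data_pipeline = [
        processor.parquet_opener,
        partial(processor.tokenize, get_tokenizer=get_tokenizer, allowed_special='all'),  # get_tokenizer: assumed factory
        processor.filter,
        partial(processor.resample, resample_rate=22050),
        partial(processor.compute_fbank, feat_extractor=feat_extractor),                  # feat_extractor: assumed callable
        partial(processor.parse_embedding, normalize=True),
        processor.shuffle,
        processor.sort,
        partial(processor.batch, batch_type='dynamic', max_frames_in_batch=12000),
        partial(processor.padding, use_spk_embedding=False),
    ]
    train_dataset = Dataset('data/train.data.list',      # placeholder path
                            data_pipeline=data_pipeline,
                            mode='train', gan=False, shuffle=True, partition=True)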
tts/cosyvoice/flow/decoder.py ADDED
@@ -0,0 +1,301 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from einops import pack, rearrange, repeat
18
+ from cosyvoice.utils.common import mask_to_bias
19
+ from cosyvoice.utils.mask import add_optional_chunk_mask
20
+ from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
21
+ from matcha.models.components.transformer import BasicTransformerBlock
22
+
23
+
24
+ class Transpose(torch.nn.Module):
25
+ def __init__(self, dim0: int, dim1: int):
26
+ super().__init__()
27
+ self.dim0 = dim0
28
+ self.dim1 = dim1
29
+
30
+ def forward(self, x: torch.Tensor):
31
+ x = torch.transpose(x, self.dim0, self.dim1)
32
+ return x
33
+
34
+
35
+ class CausalBlock1D(Block1D):
36
+ def __init__(self, dim: int, dim_out: int):
37
+ super(CausalBlock1D, self).__init__(dim, dim_out)
38
+ self.block = torch.nn.Sequential(
39
+ CausalConv1d(dim, dim_out, 3),
40
+ Transpose(1, 2),
41
+ nn.LayerNorm(dim_out),
42
+ Transpose(1, 2),
43
+ nn.Mish(),
44
+ )
45
+
46
+ def forward(self, x: torch.Tensor, mask: torch.Tensor):
47
+ output = self.block(x * mask)
48
+ return output * mask
49
+
50
+
51
+ class CausalResnetBlock1D(ResnetBlock1D):
52
+ def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
53
+ super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
54
+ self.block1 = CausalBlock1D(dim, dim_out)
55
+ self.block2 = CausalBlock1D(dim_out, dim_out)
56
+
57
+
58
+ class CausalConv1d(torch.nn.Conv1d):
59
+ def __init__(
60
+ self,
61
+ in_channels: int,
62
+ out_channels: int,
63
+ kernel_size: int,
64
+ stride: int = 1,
65
+ dilation: int = 1,
66
+ groups: int = 1,
67
+ bias: bool = True,
68
+ padding_mode: str = 'zeros',
69
+ device=None,
70
+ dtype=None
71
+ ) -> None:
72
+ super(CausalConv1d, self).__init__(in_channels, out_channels,
73
+ kernel_size, stride,
74
+ padding=0, dilation=dilation,
75
+ groups=groups, bias=bias,
76
+ padding_mode=padding_mode,
77
+ device=device, dtype=dtype)
78
+ assert stride == 1
79
+ self.causal_padding = (kernel_size - 1, 0)
80
+
81
+ def forward(self, x: torch.Tensor):
82
+ x = F.pad(x, self.causal_padding)
83
+ x = super(CausalConv1d, self).forward(x)
84
+ return x
85
+
86
+
87
+ class ConditionalDecoder(nn.Module):
88
+ def __init__(
89
+ self,
90
+ in_channels,
91
+ out_channels,
92
+ causal=False,
93
+ channels=(256, 256),
94
+ dropout=0.05,
95
+ attention_head_dim=64,
96
+ n_blocks=1,
97
+ num_mid_blocks=2,
98
+ num_heads=4,
99
+ act_fn="snake",
100
+ ):
101
+ """
102
+        This decoder requires an input with the same shape as the target. So, if your text content
103
+        is shorter or longer than the output, please re-sample it before feeding it to the decoder.
104
+ """
105
+ super().__init__()
106
+ channels = tuple(channels)
107
+ self.in_channels = in_channels
108
+ self.out_channels = out_channels
109
+ self.causal = causal
110
+ self.time_embeddings = SinusoidalPosEmb(in_channels)
111
+ time_embed_dim = channels[0] * 4
112
+ self.time_mlp = TimestepEmbedding(
113
+ in_channels=in_channels,
114
+ time_embed_dim=time_embed_dim,
115
+ act_fn="silu",
116
+ )
117
+ self.down_blocks = nn.ModuleList([])
118
+ self.mid_blocks = nn.ModuleList([])
119
+ self.up_blocks = nn.ModuleList([])
120
+
121
+ output_channel = in_channels
122
+ for i in range(len(channels)): # pylint: disable=consider-using-enumerate
123
+ input_channel = output_channel
124
+ output_channel = channels[i]
125
+ is_last = i == len(channels) - 1
126
+ resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
127
+ ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
128
+ transformer_blocks = nn.ModuleList(
129
+ [
130
+ BasicTransformerBlock(
131
+ dim=output_channel,
132
+ num_attention_heads=num_heads,
133
+ attention_head_dim=attention_head_dim,
134
+ dropout=dropout,
135
+ activation_fn=act_fn,
136
+ )
137
+ for _ in range(n_blocks)
138
+ ]
139
+ )
140
+ downsample = (
141
+ Downsample1D(output_channel) if not is_last else
142
+ CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
143
+ )
144
+ self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
145
+
146
+ for _ in range(num_mid_blocks):
147
+ input_channel = channels[-1]
148
+ out_channels = channels[-1]
149
+ resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
150
+ ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
151
+
152
+ transformer_blocks = nn.ModuleList(
153
+ [
154
+ BasicTransformerBlock(
155
+ dim=output_channel,
156
+ num_attention_heads=num_heads,
157
+ attention_head_dim=attention_head_dim,
158
+ dropout=dropout,
159
+ activation_fn=act_fn,
160
+ )
161
+ for _ in range(n_blocks)
162
+ ]
163
+ )
164
+
165
+ self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
166
+
167
+ channels = channels[::-1] + (channels[0],)
168
+ for i in range(len(channels) - 1):
169
+ input_channel = channels[i] * 2
170
+ output_channel = channels[i + 1]
171
+ is_last = i == len(channels) - 2
172
+ resnet = CausalResnetBlock1D(
173
+ dim=input_channel,
174
+ dim_out=output_channel,
175
+ time_emb_dim=time_embed_dim,
176
+ ) if self.causal else ResnetBlock1D(
177
+ dim=input_channel,
178
+ dim_out=output_channel,
179
+ time_emb_dim=time_embed_dim,
180
+ )
181
+ transformer_blocks = nn.ModuleList(
182
+ [
183
+ BasicTransformerBlock(
184
+ dim=output_channel,
185
+ num_attention_heads=num_heads,
186
+ attention_head_dim=attention_head_dim,
187
+ dropout=dropout,
188
+ activation_fn=act_fn,
189
+ )
190
+ for _ in range(n_blocks)
191
+ ]
192
+ )
193
+ upsample = (
194
+ Upsample1D(output_channel, use_conv_transpose=True)
195
+ if not is_last
196
+ else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
197
+ )
198
+ self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
199
+ self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
200
+ self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
201
+ self.initialize_weights()
202
+
203
+ def initialize_weights(self):
204
+ for m in self.modules():
205
+ if isinstance(m, nn.Conv1d):
206
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
207
+ if m.bias is not None:
208
+ nn.init.constant_(m.bias, 0)
209
+ elif isinstance(m, nn.GroupNorm):
210
+ nn.init.constant_(m.weight, 1)
211
+ nn.init.constant_(m.bias, 0)
212
+ elif isinstance(m, nn.Linear):
213
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
214
+ if m.bias is not None:
215
+ nn.init.constant_(m.bias, 0)
216
+
217
+ def forward(self, x, mask, mu, t, spks=None, cond=None):
218
+ """Forward pass of the UNet1DConditional model.
219
+
220
+ Args:
221
+ x (torch.Tensor): shape (batch_size, in_channels, time)
222
+ mask (torch.Tensor): shape (batch_size, 1, time)
+ mu (torch.Tensor): encoder output, same shape as x
223
+ t (torch.Tensor): shape (batch_size)
224
+ spks (torch.Tensor, optional): shape (batch_size, condition_channels). Defaults to None.
225
+ cond (torch.Tensor, optional): conditioning features concatenated to the input. Defaults to None.
226
+
227
+ Returns:
+ torch.Tensor: shape (batch_size, out_channels, time)
233
+ """
234
+
235
+ t = self.time_embeddings(t).to(t.dtype)
236
+ t = self.time_mlp(t)
237
+
238
+ x = pack([x, mu], "b * t")[0]
239
+
240
+ if spks is not None:
241
+ spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
242
+ x = pack([x, spks], "b * t")[0]
243
+ if cond is not None:
244
+ x = pack([x, cond], "b * t")[0]
245
+
246
+ hiddens = []
247
+ masks = [mask]
248
+ for resnet, transformer_blocks, downsample in self.down_blocks:
249
+ mask_down = masks[-1]
250
+ x = resnet(x, mask_down, t)
251
+ x = rearrange(x, "b c t -> b t c").contiguous()
252
+ # attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
253
+ attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
254
+ attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
255
+ for transformer_block in transformer_blocks:
256
+ x = transformer_block(
257
+ hidden_states=x,
258
+ attention_mask=attn_mask,
259
+ timestep=t,
260
+ )
261
+ x = rearrange(x, "b t c -> b c t").contiguous()
262
+ hiddens.append(x) # Save hidden states for skip connections
263
+ x = downsample(x * mask_down)
264
+ masks.append(mask_down[:, :, ::2])
265
+ masks = masks[:-1]
266
+ mask_mid = masks[-1]
267
+
268
+ for resnet, transformer_blocks in self.mid_blocks:
269
+ x = resnet(x, mask_mid, t)
270
+ x = rearrange(x, "b c t -> b t c").contiguous()
271
+ # attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
272
+ attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
273
+ attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
274
+ for transformer_block in transformer_blocks:
275
+ x = transformer_block(
276
+ hidden_states=x,
277
+ attention_mask=attn_mask,
278
+ timestep=t,
279
+ )
280
+ x = rearrange(x, "b t c -> b c t").contiguous()
281
+
282
+ for resnet, transformer_blocks, upsample in self.up_blocks:
283
+ mask_up = masks.pop()
284
+ skip = hiddens.pop()
285
+ x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
286
+ x = resnet(x, mask_up, t)
287
+ x = rearrange(x, "b c t -> b t c").contiguous()
288
+ # attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
289
+ attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
290
+ attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
291
+ for transformer_block in transformer_blocks:
292
+ x = transformer_block(
293
+ hidden_states=x,
294
+ attention_mask=attn_mask,
295
+ timestep=t,
296
+ )
297
+ x = rearrange(x, "b t c -> b c t").contiguous()
298
+ x = upsample(x * mask_up)
299
+ x = self.final_block(x, mask_up)
300
+ output = self.final_proj(x * mask_up)
301
+ return output * mask
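A rough sketch (with made-up shapes, not the exact model dimensions) of how the forward above assembles its channel-wise conditioning before the U-Net: the noisy input x, the encoder output mu, a speaker embedding broadcast over time, and an optional cond are stacked along the channel axis with einops.pack.

import torch
from einops import pack, repeat

B, C, T = 2, 80, 100
x = torch.randn(B, C, T)        # noisy mel at the current flow step
mu = torch.randn(B, C, T)       # encoder output at the same time resolution
spks = torch.randn(B, C)        # projected speaker embedding
cond = torch.randn(B, C, T)     # prompt / inpainting condition

h, _ = pack([x, mu], "b * t")                                   # (B, 2*C, T)
h, _ = pack([h, repeat(spks, "b c -> b c t", t=T)], "b * t")    # (B, 3*C, T)
h, _ = pack([h, cond], "b * t")                                 # (B, 4*C, T)
assert h.shape == (B, 4 * C, T)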
tts/cosyvoice/flow/flow.py ADDED
@@ -0,0 +1,239 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import logging
15
+ import random
16
+ from typing import Dict, Optional
17
+ import torch
18
+ import torch.nn as nn
19
+ from torch.nn import functional as F
20
+ from omegaconf import DictConfig
21
+ from cosyvoice.utils.mask import make_pad_mask
22
+
23
+
24
+ class MaskedDiffWithXvec(torch.nn.Module):
25
+ def __init__(self,
26
+ input_size: int = 512,
27
+ output_size: int = 80,
28
+ spk_embed_dim: int = 192,
29
+ output_type: str = "mel",
30
+ vocab_size: int = 4096,
31
+ input_frame_rate: int = 50,
32
+ only_mask_loss: bool = True,
33
+ encoder: torch.nn.Module = None,
34
+ length_regulator: torch.nn.Module = None,
35
+ decoder: torch.nn.Module = None,
36
+ decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
37
+ 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
38
+ 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
39
+ 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
40
+ 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
41
+ mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
42
+ 'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
43
+ super().__init__()
44
+ self.input_size = input_size
45
+ self.output_size = output_size
46
+ self.decoder_conf = decoder_conf
47
+ self.mel_feat_conf = mel_feat_conf
48
+ self.vocab_size = vocab_size
49
+ self.output_type = output_type
50
+ self.input_frame_rate = input_frame_rate
51
+ logging.info(f"input frame rate={self.input_frame_rate}")
52
+ self.input_embedding = nn.Embedding(vocab_size, input_size)
53
+ self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
54
+ self.encoder = encoder
55
+ self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
56
+ self.decoder = decoder
57
+ self.length_regulator = length_regulator
58
+ self.only_mask_loss = only_mask_loss
59
+
60
+ def forward(
61
+ self,
62
+ batch: dict,
63
+ device: torch.device,
64
+ ) -> Dict[str, Optional[torch.Tensor]]:
65
+ token = batch['speech_token'].to(device)
66
+ token_len = batch['speech_token_len'].to(device)
67
+ feat = batch['speech_feat'].to(device)
68
+ feat_len = batch['speech_feat_len'].to(device)
69
+ embedding = batch['embedding'].to(device)
70
+
71
+ # xvec projection
72
+ embedding = F.normalize(embedding, dim=1)
73
+ embedding = self.spk_embed_affine_layer(embedding)
74
+
75
+ # concat text and prompt_text
76
+ mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
77
+ token = self.input_embedding(torch.clamp(token, min=0)) * mask
78
+
79
+ # text encode
80
+ h, h_lengths = self.encoder(token, token_len)
81
+ h = self.encoder_proj(h)
82
+ h, h_lengths = self.length_regulator(h, feat_len)
83
+
84
+ # get conditions
85
+ conds = torch.zeros(feat.shape, device=token.device)
86
+ for i, j in enumerate(feat_len):
87
+ if random.random() < 0.5:
88
+ continue
89
+ index = random.randint(0, int(0.3 * j))
90
+ conds[i, :index] = feat[i, :index]
91
+ conds = conds.transpose(1, 2)
92
+
93
+ mask = (~make_pad_mask(feat_len)).to(h)
94
+ feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
95
+ loss, _ = self.decoder.compute_loss(
96
+ feat.transpose(1, 2).contiguous(),
97
+ mask.unsqueeze(1),
98
+ h.transpose(1, 2).contiguous(),
99
+ embedding,
100
+ cond=conds
101
+ )
102
+ return {'loss': loss}
103
+
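For reference, a self-contained sketch of the conditioning trick used in the training forward above: with probability 0.5 an utterance gets an all-zero condition, otherwise a random prefix of up to 30% of its frames is copied in, which teaches the decoder to continue from a mel prompt. Shapes are illustrative.

import random
import torch

feat = torch.randn(4, 300, 80)                    # (batch, mel_frames, n_mels), toy values
feat_len = torch.tensor([300, 250, 280, 200])

conds = torch.zeros_like(feat)
for i, j in enumerate(feat_len):
    if random.random() < 0.5:
        continue                                  # half the batch trains unconditionally
    index = random.randint(0, int(0.3 * j))
    conds[i, :index] = feat[i, :index]            # expose a random mel prefix as the prompt
conds = conds.transpose(1, 2)                     # (batch, n_mels, mel_frames) for the decoder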
104
+ @torch.inference_mode()
105
+ def inference(self,
106
+ token,
107
+ token_len,
108
+ prompt_token,
109
+ prompt_token_len,
110
+ prompt_feat,
111
+ prompt_feat_len,
112
+ embedding,
113
+ flow_cache):
114
+ if self.fp16 is True:
115
+ prompt_feat = prompt_feat.half()
116
+ embedding = embedding.half()
117
+
118
+ assert token.shape[0] == 1
119
+ # xvec projection
120
+ embedding = F.normalize(embedding, dim=1)
121
+ embedding = self.spk_embed_affine_layer(embedding)
122
+
123
+ # concat text and prompt_text
124
+ token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
125
+ token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
126
+ mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
127
+ token = self.input_embedding(torch.clamp(token, min=0)) * mask
128
+
129
+ # text encode
130
+ h, h_lengths = self.encoder(token, token_len)
131
+ h = self.encoder_proj(h)
132
+ mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
133
+ h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
134
+
135
+ # get conditions
136
+ conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
137
+ conds[:, :mel_len1] = prompt_feat
138
+ conds = conds.transpose(1, 2)
139
+
140
+ mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
141
+ feat, flow_cache = self.decoder(
142
+ mu=h.transpose(1, 2).contiguous(),
143
+ mask=mask.unsqueeze(1),
144
+ spks=embedding,
145
+ cond=conds,
146
+ n_timesteps=10,
147
+ prompt_len=mel_len1,
148
+ flow_cache=flow_cache
149
+ )
150
+ feat = feat[:, :, mel_len1:]
151
+ assert feat.shape[2] == mel_len2
152
+ return feat.float(), flow_cache
153
+
154
+
155
+ class CausalMaskedDiffWithXvec(torch.nn.Module):
156
+ def __init__(self,
157
+ input_size: int = 512,
158
+ output_size: int = 80,
159
+ spk_embed_dim: int = 192,
160
+ output_type: str = "mel",
161
+ vocab_size: int = 4096,
162
+ input_frame_rate: int = 50,
163
+ only_mask_loss: bool = True,
164
+ token_mel_ratio: int = 2,
165
+ pre_lookahead_len: int = 3,
166
+ encoder: torch.nn.Module = None,
167
+ decoder: torch.nn.Module = None,
168
+ decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
169
+ 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
170
+ 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
171
+ 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
172
+ 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
173
+ mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
174
+ 'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
175
+ super().__init__()
176
+ self.input_size = input_size
177
+ self.output_size = output_size
178
+ self.decoder_conf = decoder_conf
179
+ self.mel_feat_conf = mel_feat_conf
180
+ self.vocab_size = vocab_size
181
+ self.output_type = output_type
182
+ self.input_frame_rate = input_frame_rate
183
+ logging.info(f"input frame rate={self.input_frame_rate}")
184
+ self.input_embedding = nn.Embedding(vocab_size, input_size)
185
+ self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
186
+ self.encoder = encoder
187
+ self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
188
+ self.decoder = decoder
189
+ self.only_mask_loss = only_mask_loss
190
+ self.token_mel_ratio = token_mel_ratio
191
+ self.pre_lookahead_len = pre_lookahead_len
192
+
193
+ @torch.inference_mode()
194
+ def inference(self,
195
+ token,
196
+ token_len,
197
+ prompt_token,
198
+ prompt_token_len,
199
+ prompt_feat,
200
+ prompt_feat_len,
201
+ embedding,
202
+ finalize):
203
+ if self.fp16 is True:
204
+ prompt_feat = prompt_feat.half()
205
+ embedding = embedding.half()
206
+
207
+ assert token.shape[0] == 1
208
+ # xvec projection
209
+ embedding = F.normalize(embedding, dim=1)
210
+ embedding = self.spk_embed_affine_layer(embedding)
211
+
212
+ # concat text and prompt_text
213
+ token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
214
+ mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
215
+ token = self.input_embedding(torch.clamp(token, min=0)) * mask
216
+
217
+ # text encode
218
+ h, h_lengths = self.encoder(token, token_len)
219
+ if finalize is False:
220
+ h = h[:, :-self.pre_lookahead_len * self.token_mel_ratio]
221
+ mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
222
+ h = self.encoder_proj(h)
223
+
224
+ # get conditions
225
+ conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
226
+ conds[:, :mel_len1] = prompt_feat
227
+ conds = conds.transpose(1, 2)
228
+
229
+ mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
230
+ feat, _ = self.decoder(
231
+ mu=h.transpose(1, 2).contiguous(),
232
+ mask=mask.unsqueeze(1),
233
+ spks=embedding,
234
+ cond=conds,
235
+ n_timesteps=10
236
+ )
237
+ feat = feat[:, :, mel_len1:]
238
+ assert feat.shape[2] == mel_len2
239
+ return feat.float(), None
tts/cosyvoice/flow/flow_matching.py ADDED
@@ -0,0 +1,217 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import threading
15
+ import torch
16
+ import torch.nn.functional as F
17
+ from matcha.models.components.flow_matching import BASECFM
18
+
19
+
20
+ class ConditionalCFM(BASECFM):
21
+ def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
22
+ super().__init__(
23
+ n_feats=in_channels,
24
+ cfm_params=cfm_params,
25
+ n_spks=n_spks,
26
+ spk_emb_dim=spk_emb_dim,
27
+ )
28
+ self.t_scheduler = cfm_params.t_scheduler
29
+ self.training_cfg_rate = cfm_params.training_cfg_rate
30
+ self.inference_cfg_rate = cfm_params.inference_cfg_rate
31
+ in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
32
+ # Just change the architecture of the estimator here
33
+ self.estimator = estimator
34
+ self.lock = threading.Lock()
35
+
36
+ @torch.inference_mode()
37
+ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
38
+ """Forward diffusion
39
+
40
+ Args:
41
+ mu (torch.Tensor): output of encoder
42
+ shape: (batch_size, n_feats, mel_timesteps)
43
+ mask (torch.Tensor): output_mask
44
+ shape: (batch_size, 1, mel_timesteps)
45
+ n_timesteps (int): number of diffusion steps
46
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
47
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
48
+ shape: (batch_size, spk_emb_dim)
49
+ cond (torch.Tensor, optional): conditioning features (e.g. prompt mel) passed to the estimator
50
+
51
+ Returns:
52
+ sample: generated mel-spectrogram
53
+ shape: (batch_size, n_feats, mel_timesteps)
54
+ """
55
+
56
+ z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
57
+ cache_size = flow_cache.shape[2]
58
+ # fix prompt and overlap part mu and z
59
+ if cache_size != 0:
60
+ z[:, :, :cache_size] = flow_cache[:, :, :, 0]
61
+ mu[:, :, :cache_size] = flow_cache[:, :, :, 1]
62
+ z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
63
+ mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
64
+ flow_cache = torch.stack([z_cache, mu_cache], dim=-1)
65
+
66
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
67
+ if self.t_scheduler == 'cosine':
68
+ t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
69
+ return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache
70
+
71
+ def solve_euler(self, x, t_span, mu, mask, spks, cond):
72
+ """
73
+ Fixed euler solver for ODEs.
74
+ Args:
75
+ x (torch.Tensor): random noise
76
+ t_span (torch.Tensor): n_timesteps interpolated
77
+ shape: (n_timesteps + 1,)
78
+ mu (torch.Tensor): output of encoder
79
+ shape: (batch_size, n_feats, mel_timesteps)
80
+ mask (torch.Tensor): output_mask
81
+ shape: (batch_size, 1, mel_timesteps)
82
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
83
+ shape: (batch_size, spk_emb_dim)
84
+ cond (torch.Tensor, optional): conditioning features (e.g. prompt mel) passed to the estimator
85
+ """
86
+ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
87
+ t = t.unsqueeze(dim=0)
88
+
89
+ # intermediate solutions are stored so they can be inspected or plotted while debugging
90
+ # (a return_all_steps flag could be added in the future)
91
+ sol = []
92
+
93
+ # Do not use concat here: it may change the memory format and make TensorRT inference return wrong results!
94
+ x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
95
+ mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
96
+ mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
97
+ t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
98
+ spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
99
+ cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
100
+ for step in range(1, len(t_span)):
101
+ # Classifier-Free Guidance inference introduced in VoiceBox
102
+ x_in[:] = x
103
+ mask_in[:] = mask
104
+ mu_in[0] = mu
105
+ t_in[:] = t.unsqueeze(0)
106
+ spks_in[0] = spks
107
+ cond_in[0] = cond
108
+ dphi_dt = self.forward_estimator(
109
+ x_in, mask_in,
110
+ mu_in, t_in,
111
+ spks_in,
112
+ cond_in
113
+ )
114
+ dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
115
+ dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
116
+ x = x + dt * dphi_dt
117
+ t = t + dt
118
+ sol.append(x)
119
+ if step < len(t_span) - 1:
120
+ dt = t_span[step + 1] - t
121
+
122
+ return sol[-1].float()
123
+
124
+ def forward_estimator(self, x, mask, mu, t, spks, cond):
125
+ if isinstance(self.estimator, torch.nn.Module):
126
+ return self.estimator.forward(x, mask, mu, t, spks, cond)
127
+ else:
128
+ with self.lock:
129
+ self.estimator.set_input_shape('x', (2, 80, x.size(2)))
130
+ self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
131
+ self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
132
+ self.estimator.set_input_shape('t', (2,))
133
+ self.estimator.set_input_shape('spks', (2, 80))
134
+ self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
135
+ # run trt engine
136
+ self.estimator.execute_v2([x.contiguous().data_ptr(),
137
+ mask.contiguous().data_ptr(),
138
+ mu.contiguous().data_ptr(),
139
+ t.contiguous().data_ptr(),
140
+ spks.contiguous().data_ptr(),
141
+ cond.contiguous().data_ptr(),
142
+ x.data_ptr()])
143
+ return x
144
+
145
+ def compute_loss(self, x1, mask, mu, spks=None, cond=None):
146
+ """Computes diffusion loss
147
+
148
+ Args:
149
+ x1 (torch.Tensor): Target
150
+ shape: (batch_size, n_feats, mel_timesteps)
151
+ mask (torch.Tensor): target mask
152
+ shape: (batch_size, 1, mel_timesteps)
153
+ mu (torch.Tensor): output of encoder
154
+ shape: (batch_size, n_feats, mel_timesteps)
155
+ spks (torch.Tensor, optional): speaker embedding. Defaults to None.
156
+ shape: (batch_size, spk_emb_dim)
157
+
158
+ Returns:
159
+ loss: conditional flow matching loss
160
+ y: conditional flow
161
+ shape: (batch_size, n_feats, mel_timesteps)
162
+ """
163
+ b, _, t = mu.shape
164
+
165
+ # random timestep
166
+ t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
167
+ if self.t_scheduler == 'cosine':
168
+ t = 1 - torch.cos(t * 0.5 * torch.pi)
169
+ # sample noise p(x_0)
170
+ z = torch.randn_like(x1)
171
+
172
+ y = (1 - (1 - self.sigma_min) * t) * z + t * x1
173
+ u = x1 - (1 - self.sigma_min) * z
174
+
175
+ # during training, we randomly drop condition to trade off mode coverage and sample fidelity
176
+ if self.training_cfg_rate > 0:
177
+ cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
178
+ mu = mu * cfg_mask.view(-1, 1, 1)
179
+ spks = spks * cfg_mask.view(-1, 1)
180
+ cond = cond * cfg_mask.view(-1, 1, 1)
181
+
182
+ pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
183
+ loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
184
+ return loss, y
185
+
186
+
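For readers skimming compute_loss above, a compact sketch of the flow-matching target it regresses against (values are toy tensors; sigma_min matches the default config):

import torch

sigma_min = 1e-6
x1 = torch.randn(2, 80, 100)                      # clean target mel
z = torch.randn_like(x1)                          # sampled noise
t = torch.rand(2, 1, 1)                           # one random timestep per utterance

y_t = (1 - (1 - sigma_min) * t) * z + t * x1      # point on the probability path at time t
u = x1 - (1 - sigma_min) * z                      # constant velocity target along that path
# the estimator is trained so that estimator(y_t, t, ...) ≈ u under a masked MSE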
187
+ class CausalConditionalCFM(ConditionalCFM):
188
+ def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
189
+ super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
190
+ self.rand_noise = torch.randn([1, 80, 50 * 300])
191
+
192
+ @torch.inference_mode()
193
+ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
194
+ """Forward diffusion
195
+
196
+ Args:
197
+ mu (torch.Tensor): output of encoder
198
+ shape: (batch_size, n_feats, mel_timesteps)
199
+ mask (torch.Tensor): output_mask
200
+ shape: (batch_size, 1, mel_timesteps)
201
+ n_timesteps (int): number of diffusion steps
202
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
203
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
204
+ shape: (batch_size, spk_emb_dim)
205
+ cond (torch.Tensor, optional): conditioning features (e.g. prompt mel) passed to the estimator
206
+
207
+ Returns:
208
+ sample: generated mel-spectrogram
209
+ shape: (batch_size, n_feats, mel_timesteps)
210
+ """
211
+
212
+ z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature
213
+ # fix prompt and overlap part mu and z
214
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
215
+ if self.t_scheduler == 'cosine':
216
+ t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
217
+ return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None
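A simplified, self-contained sketch of the classifier-free-guidance Euler update performed by solve_euler above; the real code batches the conditional and unconditional passes into one estimator call, while here a dummy callable stands in for the network:

import torch

def euler_cfg(x, estimator, n_timesteps, cfg_rate=0.7, t_scheduler="cosine"):
    # x: (B, n_feats, T) starting noise; estimator(x, t, conditional) -> velocity field
    t_span = torch.linspace(0, 1, n_timesteps + 1)
    if t_scheduler == "cosine":
        t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
    t = t_span[0]
    for step in range(1, len(t_span)):
        dt = t_span[step] - t
        v_cond = estimator(x, t, conditional=True)      # with mu / spks / cond
        v_uncond = estimator(x, t, conditional=False)   # conditions zeroed out
        v = (1.0 + cfg_rate) * v_cond - cfg_rate * v_uncond
        x = x + dt * v
        t = t_span[step]
    return x

# toy usage with a stand-in estimator
dummy = lambda x, t, conditional: -x if conditional else -0.5 * x
out = euler_cfg(torch.randn(1, 80, 50), dummy, n_timesteps=10)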
tts/cosyvoice/flow/length_regulator.py ADDED
@@ -0,0 +1,69 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Tuple
15
+ import torch.nn as nn
16
+ import torch
17
+ from torch.nn import functional as F
18
+ from cosyvoice.utils.mask import make_pad_mask
19
+
20
+
21
+ class InterpolateRegulator(nn.Module):
22
+ def __init__(
23
+ self,
24
+ channels: int,
25
+ sampling_ratios: Tuple,
26
+ out_channels: int = None,
27
+ groups: int = 1,
28
+ ):
29
+ super().__init__()
30
+ self.sampling_ratios = sampling_ratios
31
+ out_channels = out_channels or channels
32
+ model = nn.ModuleList([])
33
+ if len(sampling_ratios) > 0:
34
+ for _ in sampling_ratios:
35
+ module = nn.Conv1d(channels, channels, 3, 1, 1)
36
+ norm = nn.GroupNorm(groups, channels)
37
+ act = nn.Mish()
38
+ model.extend([module, norm, act])
39
+ model.append(
40
+ nn.Conv1d(channels, out_channels, 1, 1)
41
+ )
42
+ self.model = nn.Sequential(*model)
43
+
44
+ def forward(self, x, ylens=None):
45
+ # x in (B, T, D)
46
+ mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
47
+ x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear')
48
+ out = self.model(x).transpose(1, 2).contiguous()
49
+ olens = ylens
50
+ return out * mask, olens
51
+
52
+ def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
53
+ # in inference mode, interpolate the prompt token and the token (head/mid/tail) separately, so we can get a clear separation point in the mel
54
+ # x in (B, T, D)
55
+ if x2.shape[1] > 40:
56
+ x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
57
+ x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2,
58
+ mode='linear')
59
+ x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
60
+ x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
61
+ else:
62
+ x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
63
+ if x1.shape[1] != 0:
64
+ x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear')
65
+ x = torch.concat([x1, x2], dim=2)
66
+ else:
67
+ x = x2
68
+ out = self.model(x).transpose(1, 2).contiguous()
69
+ return out, mel_len1 + mel_len2
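A short sketch of the rate conversion this regulator performs: speech tokens arrive at input_frame_rate (50 Hz by default) while mels are produced at sampling_rate / hop_size = 22050 / 256 ≈ 86 frames per second, so token-rate features are linearly interpolated to the target mel length. The numbers below just restate those defaults.

import torch
import torch.nn.functional as F

input_frame_rate = 50                       # speech-token rate (Hz)
sampling_rate, hop_size = 22050, 256        # mel frame rate ≈ 86.13 Hz

token_feats = torch.randn(1, 120, 512)      # (B, n_tokens, D), toy values
n_tokens = token_feats.shape[1]
mel_len = int(n_tokens / input_frame_rate * sampling_rate / hop_size)   # ≈ 206 frames

stretched = F.interpolate(token_feats.transpose(1, 2), size=mel_len, mode="linear")
print(stretched.shape)                      # torch.Size([1, 512, 206])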
tts/cosyvoice/hifigan/discriminator.py ADDED
@@ -0,0 +1,140 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn.utils.parametrizations import weight_norm
4
+ from typing import List, Optional, Tuple
5
+ from einops import rearrange
6
+ from torchaudio.transforms import Spectrogram
7
+
8
+
9
+ class MultipleDiscriminator(nn.Module):
10
+ def __init__(
11
+ self, mpd: nn.Module, mrd: nn.Module
12
+ ):
13
+ super().__init__()
14
+ self.mpd = mpd
15
+ self.mrd = mrd
16
+
17
+ def forward(self, y: torch.Tensor, y_hat: torch.Tensor):
18
+ y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
19
+ this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mpd(y.unsqueeze(dim=1), y_hat.unsqueeze(dim=1))
20
+ y_d_rs += this_y_d_rs
21
+ y_d_gs += this_y_d_gs
22
+ fmap_rs += this_fmap_rs
23
+ fmap_gs += this_fmap_gs
24
+ this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mrd(y, y_hat)
25
+ y_d_rs += this_y_d_rs
26
+ y_d_gs += this_y_d_gs
27
+ fmap_rs += this_fmap_rs
28
+ fmap_gs += this_fmap_gs
29
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
30
+
31
+
32
+ class MultiResolutionDiscriminator(nn.Module):
33
+ def __init__(
34
+ self,
35
+ fft_sizes: Tuple[int, ...] = (2048, 1024, 512),
36
+ num_embeddings: Optional[int] = None,
37
+ ):
38
+ """
39
+ Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec.
40
+ Additionally, it allows incorporating conditional information with a learned embeddings table.
41
+
42
+ Args:
43
+ fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512).
44
+ num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
45
+ Defaults to None.
46
+ """
47
+
48
+ super().__init__()
49
+ self.discriminators = nn.ModuleList(
50
+ [DiscriminatorR(window_length=w, num_embeddings=num_embeddings) for w in fft_sizes]
51
+ )
52
+
53
+ def forward(
54
+ self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
55
+ ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
56
+ y_d_rs = []
57
+ y_d_gs = []
58
+ fmap_rs = []
59
+ fmap_gs = []
60
+
61
+ for d in self.discriminators:
62
+ y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
63
+ y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
64
+ y_d_rs.append(y_d_r)
65
+ fmap_rs.append(fmap_r)
66
+ y_d_gs.append(y_d_g)
67
+ fmap_gs.append(fmap_g)
68
+
69
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
70
+
71
+
72
+ class DiscriminatorR(nn.Module):
73
+ def __init__(
74
+ self,
75
+ window_length: int,
76
+ num_embeddings: Optional[int] = None,
77
+ channels: int = 32,
78
+ hop_factor: float = 0.25,
79
+ bands: Tuple[Tuple[float, float], ...] = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)),
80
+ ):
81
+ super().__init__()
82
+ self.window_length = window_length
83
+ self.hop_factor = hop_factor
84
+ self.spec_fn = Spectrogram(
85
+ n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, power=None
86
+ )
87
+ n_fft = window_length // 2 + 1
88
+ bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
89
+ self.bands = bands
90
+ convs = lambda: nn.ModuleList(
91
+ [
92
+ weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
93
+ weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
94
+ weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
95
+ weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
96
+ weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
97
+ ]
98
+ )
99
+ self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
100
+
101
+ if num_embeddings is not None:
102
+ self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
103
+ torch.nn.init.zeros_(self.emb.weight)
104
+
105
+ self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))
106
+
107
+ def spectrogram(self, x):
108
+ # Remove DC offset
109
+ x = x - x.mean(dim=-1, keepdims=True)
110
+ # Peak normalize the volume of input audio
111
+ x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
112
+ x = self.spec_fn(x)
113
+ x = torch.view_as_real(x)
114
+ x = rearrange(x, "b f t c -> b c t f")
115
+ # Split into bands
116
+ x_bands = [x[..., b[0]: b[1]] for b in self.bands]
117
+ return x_bands
118
+
119
+ def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None):
120
+ x_bands = self.spectrogram(x)
121
+ fmap = []
122
+ x = []
123
+ for band, stack in zip(x_bands, self.band_convs):
124
+ for i, layer in enumerate(stack):
125
+ band = layer(band)
126
+ band = torch.nn.functional.leaky_relu(band, 0.1)
127
+ if i > 0:
128
+ fmap.append(band)
129
+ x.append(band)
130
+ x = torch.cat(x, dim=-1)
131
+ if cond_embedding_id is not None:
132
+ emb = self.emb(cond_embedding_id)
133
+ h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
134
+ else:
135
+ h = 0
136
+ x = self.conv_post(x)
137
+ fmap.append(x)
138
+ x += h
139
+
140
+ return x, fmap
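A quick sketch of how the fractional band boundaries above map onto FFT bins; for window_length = 2048 there are 1025 bins, so for example the (0.0, 0.1) band covers bins 0–102:

window_length = 2048
n_fft_bins = window_length // 2 + 1         # 1025 frequency bins
bands = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0))
bin_ranges = [(int(lo * n_fft_bins), int(hi * n_fft_bins)) for lo, hi in bands]
print(bin_ranges)   # [(0, 102), (102, 256), (256, 512), (512, 768), (768, 1025)]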
tts/cosyvoice/hifigan/f0_predictor.py ADDED
@@ -0,0 +1,56 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+ import torch.nn as nn
16
+ # from torch.nn.utils import weight_norm
17
+ from torch.nn.utils.parametrizations import weight_norm
18
+
19
+
20
+ class ConvRNNF0Predictor(nn.Module):
21
+ def __init__(self,
22
+ num_class: int = 1,
23
+ in_channels: int = 80,
24
+ cond_channels: int = 512
25
+ ):
26
+ super().__init__()
27
+
28
+ self.num_class = num_class
29
+ self.condnet = nn.Sequential(
30
+ weight_norm(
31
+ nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
32
+ ),
33
+ nn.ELU(),
34
+ weight_norm(
35
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
36
+ ),
37
+ nn.ELU(),
38
+ weight_norm(
39
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
40
+ ),
41
+ nn.ELU(),
42
+ weight_norm(
43
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
44
+ ),
45
+ nn.ELU(),
46
+ weight_norm(
47
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
48
+ ),
49
+ nn.ELU(),
50
+ )
51
+ self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
52
+
53
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
54
+ x = self.condnet(x)
55
+ x = x.transpose(1, 2)
56
+ return torch.abs(self.classifier(x).squeeze(-1))
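A hedged usage sketch of the predictor above (assuming ConvRNNF0Predictor from this file is in scope): it consumes a mel spectrogram of shape (batch, in_channels, frames) and returns a non-negative per-frame F0 contour of shape (batch, frames).

import torch

predictor = ConvRNNF0Predictor(in_channels=80, cond_channels=512)
mel = torch.randn(2, 80, 120)      # (batch, n_mels, frames), toy input
f0 = predictor(mel)                # (batch, frames), values >= 0 thanks to torch.abs
print(f0.shape)                    # torch.Size([2, 120])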
tts/cosyvoice/hifigan/generator.py ADDED
@@ -0,0 +1,412 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """HIFI-GAN"""
16
+
17
+ from typing import Dict, Optional, List
18
+ import numpy as np
19
+ from scipy.signal import get_window
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ from torch.nn import Conv1d
24
+ from torch.nn import ConvTranspose1d
25
+ from torch.nn.utils import remove_weight_norm
26
+ # from torch.nn.utils import weight_norm
27
+ from torch.nn.utils.parametrizations import weight_norm
28
+ from torch.distributions.uniform import Uniform
29
+
30
+ from cosyvoice.transformer.activation import Snake
31
+ from cosyvoice.utils.common import get_padding
32
+ from cosyvoice.utils.common import init_weights
33
+
34
+
35
+ """hifigan based generator implementation.
36
+
37
+ This code is modified from https://github.com/jik876/hifi-gan
38
+ ,https://github.com/kan-bayashi/ParallelWaveGAN and
39
+ https://github.com/NVIDIA/BigVGAN
40
+
41
+ """
42
+
43
+
44
+ class ResBlock(torch.nn.Module):
45
+ """Residual block module in HiFiGAN/BigVGAN."""
46
+ def __init__(
47
+ self,
48
+ channels: int = 512,
49
+ kernel_size: int = 3,
50
+ dilations: List[int] = [1, 3, 5],
51
+ ):
52
+ super(ResBlock, self).__init__()
53
+ self.convs1 = nn.ModuleList()
54
+ self.convs2 = nn.ModuleList()
55
+
56
+ for dilation in dilations:
57
+ self.convs1.append(
58
+ weight_norm(
59
+ Conv1d(
60
+ channels,
61
+ channels,
62
+ kernel_size,
63
+ 1,
64
+ dilation=dilation,
65
+ padding=get_padding(kernel_size, dilation)
66
+ )
67
+ )
68
+ )
69
+ self.convs2.append(
70
+ weight_norm(
71
+ Conv1d(
72
+ channels,
73
+ channels,
74
+ kernel_size,
75
+ 1,
76
+ dilation=1,
77
+ padding=get_padding(kernel_size, 1)
78
+ )
79
+ )
80
+ )
81
+ self.convs1.apply(init_weights)
82
+ self.convs2.apply(init_weights)
83
+ self.activations1 = nn.ModuleList([
84
+ Snake(channels, alpha_logscale=False)
85
+ for _ in range(len(self.convs1))
86
+ ])
87
+ self.activations2 = nn.ModuleList([
88
+ Snake(channels, alpha_logscale=False)
89
+ for _ in range(len(self.convs2))
90
+ ])
91
+
92
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
93
+ for idx in range(len(self.convs1)):
94
+ xt = self.activations1[idx](x)
95
+ xt = self.convs1[idx](xt)
96
+ xt = self.activations2[idx](xt)
97
+ xt = self.convs2[idx](xt)
98
+ x = xt + x
99
+ return x
100
+
101
+ def remove_weight_norm(self):
102
+ for idx in range(len(self.convs1)):
103
+ remove_weight_norm(self.convs1[idx])
104
+ remove_weight_norm(self.convs2[idx])
105
+
106
+
107
+ class SineGen(torch.nn.Module):
108
+ """ Definition of sine generator
109
+ SineGen(samp_rate, harmonic_num = 0,
110
+ sine_amp = 0.1, noise_std = 0.003,
111
+ voiced_threshold = 0,
112
+ flag_for_pulse=False)
113
+ samp_rate: sampling rate in Hz
114
+ harmonic_num: number of harmonic overtones (default 0)
115
+ sine_amp: amplitude of sine-wavefrom (default 0.1)
116
+ noise_std: std of Gaussian noise (default 0.003)
117
+ voiced_thoreshold: F0 threshold for U/V classification (default 0)
118
+ flag_for_pulse: this SinGen is used inside PulseGen (default False)
119
+ Note: when flag_for_pulse is True, the first time step of a voiced
120
+ segment is always sin(np.pi) or cos(0)
121
+ """
122
+
123
+ def __init__(self, samp_rate, harmonic_num=0,
124
+ sine_amp=0.1, noise_std=0.003,
125
+ voiced_threshold=0):
126
+ super(SineGen, self).__init__()
127
+ self.sine_amp = sine_amp
128
+ self.noise_std = noise_std
129
+ self.harmonic_num = harmonic_num
130
+ self.sampling_rate = samp_rate
131
+ self.voiced_threshold = voiced_threshold
132
+
133
+ def _f02uv(self, f0):
134
+ # generate uv signal
135
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
136
+ return uv
137
+
138
+ @torch.no_grad()
139
+ def forward(self, f0):
140
+ """
141
+ :param f0: [B, 1, sample_len], Hz
142
+ :return: [B, 1, sample_len]
143
+ """
144
+
145
+ F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
146
+ for i in range(self.harmonic_num + 1):
147
+ F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
148
+
149
+ theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
150
+ u_dist = Uniform(low=-np.pi, high=np.pi)
151
+ phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
152
+ phase_vec[:, 0, :] = 0
153
+
154
+ # generate sine waveforms
155
+ sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)
156
+
157
+ # generate uv signal
158
+ uv = self._f02uv(f0)
159
+
160
+ # noise: for unvoiced should be similar to sine_amp
161
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
162
+ # . for voiced regions is self.noise_std
163
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
164
+ noise = noise_amp * torch.randn_like(sine_waves)
165
+
166
+ # first: set the unvoiced part to 0 by uv
167
+ # then: additive noise
168
+ sine_waves = sine_waves * uv + noise
169
+ return sine_waves, uv, noise
170
+
171
+
172
+ class SourceModuleHnNSF(torch.nn.Module):
173
+ """ SourceModule for hn-nsf
174
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
175
+ add_noise_std=0.003, voiced_threshod=0)
176
+ sampling_rate: sampling_rate in Hz
177
+ harmonic_num: number of harmonic above F0 (default: 0)
178
+ sine_amp: amplitude of sine source signal (default: 0.1)
179
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
180
+ note that amplitude of noise in unvoiced is decided
181
+ by sine_amp
182
+ voiced_threshold: threhold to set U/V given F0 (default: 0)
183
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
184
+ F0_sampled (batchsize, length, 1)
185
+ Sine_source (batchsize, length, 1)
186
+ noise_source (batchsize, length 1)
187
+ uv (batchsize, length, 1)
188
+ """
189
+
190
+ def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
191
+ add_noise_std=0.003, voiced_threshod=0):
192
+ super(SourceModuleHnNSF, self).__init__()
193
+
194
+ self.sine_amp = sine_amp
195
+ self.noise_std = add_noise_std
196
+
197
+ # to produce sine waveforms
198
+ self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
199
+ sine_amp, add_noise_std, voiced_threshod)
200
+
201
+ # to merge source harmonics into a single excitation
202
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
203
+ self.l_tanh = torch.nn.Tanh()
204
+
205
+ def forward(self, x):
206
+ """
207
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
208
+ F0_sampled (batchsize, length, 1)
209
+ Sine_source (batchsize, length, 1)
210
+ noise_source (batchsize, length 1)
211
+ """
212
+ # source for harmonic branch
213
+ with torch.no_grad():
214
+ sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
215
+ sine_wavs = sine_wavs.transpose(1, 2)
216
+ uv = uv.transpose(1, 2)
217
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
218
+
219
+ # source for noise branch, in the same shape as uv
220
+ noise = torch.randn_like(uv) * self.sine_amp / 3
221
+ return sine_merge, noise, uv
222
+
223
+
224
+ class HiFTGenerator(nn.Module):
225
+ """
226
+ HiFTNet Generator: Neural Source Filter + ISTFTNet
227
+ https://arxiv.org/abs/2309.09493
228
+ """
229
+ def __init__(
230
+ self,
231
+ in_channels: int = 80,
232
+ base_channels: int = 512,
233
+ nb_harmonics: int = 8,
234
+ sampling_rate: int = 22050,
235
+ nsf_alpha: float = 0.1,
236
+ nsf_sigma: float = 0.003,
237
+ nsf_voiced_threshold: float = 10,
238
+ upsample_rates: List[int] = [8, 8],
239
+ upsample_kernel_sizes: List[int] = [16, 16],
240
+ istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
241
+ resblock_kernel_sizes: List[int] = [3, 7, 11],
242
+ resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
243
+ source_resblock_kernel_sizes: List[int] = [7, 11],
244
+ source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
245
+ lrelu_slope: float = 0.1,
246
+ audio_limit: float = 0.99,
247
+ f0_predictor: torch.nn.Module = None,
248
+ ):
249
+ super(HiFTGenerator, self).__init__()
250
+
251
+ self.out_channels = 1
252
+ self.nb_harmonics = nb_harmonics
253
+ self.sampling_rate = sampling_rate
254
+ self.istft_params = istft_params
255
+ self.lrelu_slope = lrelu_slope
256
+ self.audio_limit = audio_limit
257
+
258
+ self.num_kernels = len(resblock_kernel_sizes)
259
+ self.num_upsamples = len(upsample_rates)
260
+ self.m_source = SourceModuleHnNSF(
261
+ sampling_rate=sampling_rate,
262
+ upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
263
+ harmonic_num=nb_harmonics,
264
+ sine_amp=nsf_alpha,
265
+ add_noise_std=nsf_sigma,
266
+ voiced_threshod=nsf_voiced_threshold)
267
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
268
+
269
+ self.conv_pre = weight_norm(
270
+ Conv1d(in_channels, base_channels, 7, 1, padding=3)
271
+ )
272
+
273
+ # Up
274
+ self.ups = nn.ModuleList()
275
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
276
+ self.ups.append(
277
+ weight_norm(
278
+ ConvTranspose1d(
279
+ base_channels // (2**i),
280
+ base_channels // (2**(i + 1)),
281
+ k,
282
+ u,
283
+ padding=(k - u) // 2,
284
+ )
285
+ )
286
+ )
287
+
288
+ # Down
289
+ self.source_downs = nn.ModuleList()
290
+ self.source_resblocks = nn.ModuleList()
291
+ downsample_rates = [1] + upsample_rates[::-1][:-1]
292
+ downsample_cum_rates = np.cumprod(downsample_rates)
293
+ for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
294
+ if u == 1:
295
+ self.source_downs.append(
296
+ Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
297
+ )
298
+ else:
299
+ self.source_downs.append(
300
+ Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
301
+ )
302
+
303
+ self.source_resblocks.append(
304
+ ResBlock(base_channels // (2 ** (i + 1)), k, d)
305
+ )
306
+
307
+ self.resblocks = nn.ModuleList()
308
+ for i in range(len(self.ups)):
309
+ ch = base_channels // (2**(i + 1))
310
+ for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
311
+ self.resblocks.append(ResBlock(ch, k, d))
312
+
313
+ self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
314
+ self.ups.apply(init_weights)
315
+ self.conv_post.apply(init_weights)
316
+ self.reflection_pad = nn.ReflectionPad1d((1, 0))
317
+ self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
318
+ self.f0_predictor = f0_predictor
319
+
320
+ def remove_weight_norm(self):
321
+ print('Removing weight norm...')
322
+ for l in self.ups:
323
+ remove_weight_norm(l)
324
+ for l in self.resblocks:
325
+ l.remove_weight_norm()
326
+ remove_weight_norm(self.conv_pre)
327
+ remove_weight_norm(self.conv_post)
328
+ self.m_source.remove_weight_norm()
329
+ for l in self.source_downs:
330
+ remove_weight_norm(l)
331
+ for l in self.source_resblocks:
332
+ l.remove_weight_norm()
333
+
334
+ def _stft(self, x):
335
+ spec = torch.stft(
336
+ x,
337
+ self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
338
+ return_complex=True)
339
+ spec = torch.view_as_real(spec) # [B, F, TT, 2]
340
+ return spec[..., 0], spec[..., 1]
341
+
342
+ def _istft(self, magnitude, phase):
343
+ magnitude = torch.clip(magnitude, max=1e2)
344
+ real = magnitude * torch.cos(phase)
345
+ img = magnitude * torch.sin(phase)
346
+ inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
347
+ self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
348
+ return inverse_transform
349
+
350
+ def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
351
+ s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
352
+ s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
353
+
354
+ x = self.conv_pre(x)
355
+ for i in range(self.num_upsamples):
356
+ x = F.leaky_relu(x, self.lrelu_slope)
357
+ x = self.ups[i](x)
358
+
359
+ if i == self.num_upsamples - 1:
360
+ x = self.reflection_pad(x)
361
+
362
+ # fusion
363
+ si = self.source_downs[i](s_stft)
364
+ si = self.source_resblocks[i](si)
365
+ x = x + si
366
+
367
+ xs = None
368
+ for j in range(self.num_kernels):
369
+ if xs is None:
370
+ xs = self.resblocks[i * self.num_kernels + j](x)
371
+ else:
372
+ xs += self.resblocks[i * self.num_kernels + j](x)
373
+ x = xs / self.num_kernels
374
+
375
+ x = F.leaky_relu(x)
376
+ x = self.conv_post(x)
377
+ magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
378
+ phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy
379
+
380
+ x = self._istft(magnitude, phase)
381
+ x = torch.clamp(x, -self.audio_limit, self.audio_limit)
382
+ return x
383
+
384
+ def forward(
385
+ self,
386
+ batch: dict,
387
+ device: torch.device,
388
+ ) -> Dict[str, Optional[torch.Tensor]]:
389
+ speech_feat = batch['speech_feat'].transpose(1, 2).to(device)
390
+ # mel->f0
391
+ f0 = self.f0_predictor(speech_feat)
392
+ # f0->source
393
+ s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
394
+ s, _, _ = self.m_source(s)
395
+ s = s.transpose(1, 2)
396
+ # mel+source->speech
397
+ generated_speech = self.decode(x=speech_feat, s=s)
398
+ return generated_speech, f0
399
+
400
+ @torch.inference_mode()
401
+ def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
402
+ # mel->f0
403
+ f0 = self.f0_predictor(speech_feat)
404
+ # f0->source
405
+ s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
406
+ s, _, _ = self.m_source(s)
407
+ s = s.transpose(1, 2)
408
+ # use cache_source to avoid glitch
409
+ if cache_source.shape[2] != 0:
410
+ s[:, :, :cache_source.shape[2]] = cache_source
411
+ generated_speech = self.decode(x=speech_feat, s=s)
412
+ return generated_speech, s
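A self-contained sketch of the magnitude/phase to waveform step used by _istft above, with the default istft_params (n_fft=16, hop_len=4); the spectrogram values are random and purely illustrative.

import numpy as np
import torch
from scipy.signal import get_window

n_fft, hop_len = 16, 4
window = torch.from_numpy(get_window("hann", n_fft, fftbins=True).astype(np.float32))

frames = 100
magnitude = torch.rand(1, n_fft // 2 + 1, frames)            # (B, n_fft//2 + 1, T)
phase = torch.rand(1, n_fft // 2 + 1, frames) * 2 * np.pi

real = magnitude * torch.cos(phase)
imag = magnitude * torch.sin(phase)
audio = torch.istft(torch.complex(real, imag), n_fft, hop_len, n_fft, window=window)
print(audio.shape)                                           # torch.Size([1, 396]) with center=True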
tts/cosyvoice/hifigan/hifigan.py ADDED
@@ -0,0 +1,67 @@
1
+ from typing import Dict, Optional
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from matcha.hifigan.models import feature_loss, generator_loss, discriminator_loss
6
+ from cosyvoice.utils.losses import tpr_loss, mel_loss
7
+
8
+
9
+ class HiFiGan(nn.Module):
10
+ def __init__(self, generator, discriminator, mel_spec_transform,
11
+ multi_mel_spectral_recon_loss_weight=45, feat_match_loss_weight=2.0,
12
+ tpr_loss_weight=1.0, tpr_loss_tau=0.04):
13
+ super(HiFiGan, self).__init__()
14
+ self.generator = generator
15
+ self.discriminator = discriminator
16
+ self.mel_spec_transform = mel_spec_transform
17
+ self.multi_mel_spectral_recon_loss_weight = multi_mel_spectral_recon_loss_weight
18
+ self.feat_match_loss_weight = feat_match_loss_weight
19
+ self.tpr_loss_weight = tpr_loss_weight
20
+ self.tpr_loss_tau = tpr_loss_tau
21
+
22
+ def forward(
23
+ self,
24
+ batch: dict,
25
+ device: torch.device,
26
+ ) -> Dict[str, Optional[torch.Tensor]]:
27
+ if batch['turn'] == 'generator':
28
+ return self.forward_generator(batch, device)
29
+ else:
30
+ return self.forward_discriminator(batch, device)
31
+
32
+ def forward_generator(self, batch, device):
33
+ real_speech = batch['speech'].to(device)
34
+ pitch_feat = batch['pitch_feat'].to(device)
35
+ # 1. calculate generator outputs
36
+ generated_speech, generated_f0 = self.generator(batch, device)
37
+ # 2. calculate discriminator outputs
38
+ y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech)
39
+ # 3. calculate generator losses, feature loss, mel loss, tpr losses [Optional]
40
+ loss_gen, _ = generator_loss(y_d_gs)
41
+ loss_fm = feature_loss(fmap_rs, fmap_gs)
42
+ loss_mel = mel_loss(real_speech, generated_speech, self.mel_spec_transform)
43
+ if self.tpr_loss_weight != 0:
44
+ loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau)
45
+ else:
46
+ loss_tpr = torch.zeros(1).to(device)
47
+ loss_f0 = F.l1_loss(generated_f0, pitch_feat)
48
+ loss = loss_gen + self.feat_match_loss_weight * loss_fm + \
49
+ self.multi_mel_spectral_recon_loss_weight * loss_mel + \
50
+ self.tpr_loss_weight * loss_tpr + loss_f0
51
+ return {'loss': loss, 'loss_gen': loss_gen, 'loss_fm': loss_fm, 'loss_mel': loss_mel, 'loss_tpr': loss_tpr, 'loss_f0': loss_f0}
52
+
53
+ def forward_discriminator(self, batch, device):
54
+ real_speech = batch['speech'].to(device)
55
+ # 1. calculate generator outputs
56
+ with torch.no_grad():
57
+ generated_speech, generated_f0 = self.generator(batch, device)
58
+ # 2. calculate discriminator outputs
59
+ y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech)
60
+ # 3. calculate discriminator losses, tpr losses [Optional]
61
+ loss_disc, _, _ = discriminator_loss(y_d_rs, y_d_gs)
62
+ if self.tpr_loss_weight != 0:
63
+ loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau)
64
+ else:
65
+ loss_tpr = torch.zeros(1).to(device)
66
+ loss = loss_disc + self.tpr_loss_weight * loss_tpr
67
+ return {'loss': loss, 'loss_disc': loss_disc, 'loss_tpr': loss_tpr}
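As a non-authoritative sketch of how a training loop might drive the two branches above by toggling batch['turn'] (the optimizer objects and batch contents are assumptions, not defined in this file):

def train_step(hifigan, batch, device, optim_g, optim_d):
    # discriminator turn: update the discriminator on real vs. generated speech
    batch['turn'] = 'discriminator'
    d_out = hifigan(batch, device)
    optim_d.zero_grad()
    d_out['loss'].backward()
    optim_d.step()

    # generator turn: update the generator with adversarial + feature + mel losses
    batch['turn'] = 'generator'
    g_out = hifigan(batch, device)
    optim_g.zero_grad()
    g_out['loss'].backward()
    optim_g.step()
    return {'disc': d_out['loss'].item(), 'gen': g_out['loss'].item()}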
tts/cosyvoice/llm/llm.py ADDED
@@ -0,0 +1,434 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Dict, Optional, Callable, List, Generator
15
+ import torch
16
+ from torch import nn
17
+ import torch.nn.functional as F
18
+ from transformers import Qwen2ForCausalLM
19
+ from torch.nn.utils.rnn import pad_sequence, unpad_sequence
20
+ from cosyvoice.utils.common import IGNORE_ID
21
+ from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
22
+ from cosyvoice.utils.common import th_accuracy
23
+ from cosyvoice.utils.file_utils import logging
24
+
25
+
26
+ class TransformerLM(torch.nn.Module):
27
+ def __init__(
28
+ self,
29
+ text_encoder_input_size: int,
30
+ llm_input_size: int,
31
+ llm_output_size: int,
32
+ text_token_size: int,
33
+ speech_token_size: int,
34
+ text_encoder: torch.nn.Module,
35
+ llm: torch.nn.Module,
36
+ sampling: Callable,
37
+ length_normalized_loss: bool = True,
38
+ lsm_weight: float = 0.0,
39
+ spk_embed_dim: int = 192,
40
+ ):
41
+ super().__init__()
42
+ self.llm_input_size = llm_input_size
43
+ self.speech_token_size = speech_token_size
44
+ # 1. build text token inputs related modules
45
+ self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size)
46
+ self.text_encoder = text_encoder
47
+ self.text_encoder_affine_layer = nn.Linear(
48
+ self.text_encoder.output_size(),
49
+ llm_input_size
50
+ )
51
+
52
+ # 2. build speech token language model related modules
53
+ self.sos_eos = 0
54
+ self.task_id = 1
55
+ self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
56
+ self.llm = llm
57
+ self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
58
+ self.criterion_ce = LabelSmoothingLoss(
59
+ size=speech_token_size + 1,
60
+ padding_idx=IGNORE_ID,
61
+ smoothing=lsm_weight,
62
+ normalize_length=length_normalized_loss,
63
+ )
64
+
65
+ # 3. [Optional] build speech token related modules
66
+ self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
67
+ self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)
68
+
69
+ # 4. sampling method
70
+ self.sampling = sampling
71
+
72
+ def encode(
73
+ self,
74
+ text: torch.Tensor,
75
+ text_lengths: torch.Tensor,
76
+ ):
77
+ encoder_out, encoder_mask = self.text_encoder(text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1)
78
+ encoder_out_lens = encoder_mask.squeeze(1).sum(1)
79
+ encoder_out = self.text_encoder_affine_layer(encoder_out)
80
+ return encoder_out, encoder_out_lens
81
+
82
+ def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
83
+ text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
84
+ speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
85
+ lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
86
+ for i in range(len(text_token))]
87
+ lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
88
+ lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
89
+ return lm_input, lm_input_len
90
+
91
+ def forward(
92
+ self,
93
+ batch: dict,
94
+ device: torch.device,
95
+ ) -> Dict[str, Optional[torch.Tensor]]:
96
+ """
97
+ Args:
98
+ text: (B, L, D)
99
+ text_lengths: (B,)
100
+ audio: (B, T, N) or (B, T)
101
+ audio_lengths: (B,)
102
+ """
103
+ text_token = batch['text_token'].to(device)
104
+ text_token_len = batch['text_token_len'].to(device)
105
+ speech_token = batch['speech_token'].to(device)
106
+ speech_token_len = batch['speech_token_len'].to(device)
107
+ embedding = batch['embedding'].to(device)
108
+
109
+ # 1. prepare llm_target
110
+ lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
111
+ [self.speech_token_size]) for i in range(text_token.size(0))]
112
+ lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)
113
+
114
+ # 1. encode text_token
115
+ text_token = self.text_embedding(text_token)
116
+ text_token, text_token_len = self.encode(text_token, text_token_len)
117
+
118
+ # 2. embedding projection
119
+ embedding = F.normalize(embedding, dim=1)
120
+ embedding = self.spk_embed_affine_layer(embedding)
121
+ embedding = embedding.unsqueeze(1)
122
+
123
+ # 3. eos and task_id
124
+ sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
125
+ task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
126
+
127
+ # 4. encode speech_token
128
+ speech_token = self.speech_embedding(speech_token)
129
+
130
+ # 5. unpad and pad
131
+ lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len,
132
+ task_id_emb, speech_token, speech_token_len)
133
+
134
+ # 6. run lm forward
135
+ lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
136
+ logits = self.llm_decoder(lm_output)
137
+ loss = self.criterion_ce(logits, lm_target)
138
+ acc = th_accuracy(logits.view(-1, self.speech_token_size + 1), lm_target, ignore_label=IGNORE_ID)
139
+ return {'loss': loss, 'acc': acc}
140
+
141
+ def sampling_ids(
142
+ self,
143
+ weighted_scores: torch.Tensor,
144
+ decoded_tokens: List,
145
+ sampling: int,
146
+ ignore_eos: bool = True,
147
+ ):
148
+ num_trials, max_trials = 0, 100
149
+ while True:
150
+ top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
151
+ if (not ignore_eos) or (self.speech_token_size not in top_ids):
152
+ break
153
+ num_trials += 1
154
+ if num_trials > max_trials:
155
+ raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials))
156
+ return top_ids
157
+
158
+ @torch.inference_mode()
159
+ def inference(
160
+ self,
161
+ text: torch.Tensor,
162
+ text_len: torch.Tensor,
163
+ prompt_text: torch.Tensor,
164
+ prompt_text_len: torch.Tensor,
165
+ prompt_speech_token: torch.Tensor,
166
+ prompt_speech_token_len: torch.Tensor,
167
+ embedding: torch.Tensor,
168
+ sampling: int = 25,
169
+ max_token_text_ratio: float = 20,
170
+ min_token_text_ratio: float = 2,
171
+ ) -> Generator[torch.Tensor, None, None]:
172
+ if self.fp16 is True:
173
+ embedding = embedding.half()
174
+
175
+ device = text.device
176
+ text = torch.concat([prompt_text, text], dim=1)
177
+ text_len += prompt_text_len
178
+ text = self.text_embedding(text)
179
+
180
+ # 1. encode text
181
+ text, text_len = self.encode(text, text_len)
182
+
183
+ # 2. encode embedding
184
+ if embedding.shape[0] != 0:
185
+ embedding = F.normalize(embedding, dim=1)
186
+ embedding = self.spk_embed_affine_layer(embedding)
187
+ embedding = embedding.unsqueeze(dim=1)
188
+ else:
189
+ embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype)
190
+
191
+ # 3. concat llm_input
192
+ sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
193
+ task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
194
+ if prompt_speech_token_len != 0:
195
+ prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
196
+ else:
197
+ prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
198
+ lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
199
+
200
+ # 4. cal min/max_length
201
+ min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
202
+ max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
203
+
204
+ # 5. step by step decode
205
+ out_tokens = []
206
+ offset = 0
207
+ att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
208
+ for i in range(max_len):
209
+ y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1,
210
+ att_cache=att_cache, cnn_cache=cnn_cache,
211
+ att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
212
+ device=lm_input.device)).to(torch.bool))
213
+ logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
214
+ # force continue decode first token
215
+ if i == 0:
216
+ logp[:, self.speech_token_size] = -float('inf')
217
+ top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
218
+ if top_ids == self.speech_token_size:
219
+ break
220
+ # in stream mode, yield token one by one
221
+ yield top_ids
222
+ out_tokens.append(top_ids)
223
+ offset += lm_input.size(1)
224
+ lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
225
+
226
+
227
+ class Qwen2Encoder(torch.nn.Module):
228
+ def __init__(self, pretrain_path):
229
+ super().__init__()
230
+ self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path)
231
+
232
+ def forward_one_step(self, xs, masks, cache=None):
233
+ input_masks = masks[:, -1, :]
234
+ outs = self.model(
235
+ inputs_embeds=xs,
236
+ attention_mask=input_masks,
237
+ output_hidden_states=True,
238
+ return_dict=True,
239
+ use_cache=True,
240
+ past_key_values=cache,
241
+ )
242
+ xs = outs.hidden_states[-1]
243
+ new_cache = outs.past_key_values
244
+ return xs, new_cache
245
+
246
+
247
+ class Qwen2LM(TransformerLM):
248
+ def __init__(
249
+ self,
250
+ llm_input_size: int,
251
+ llm_output_size: int,
252
+ speech_token_size: int,
253
+ llm: torch.nn.Module,
254
+ sampling: Callable,
255
+ length_normalized_loss: bool = True,
256
+ lsm_weight: float = 0.0,
257
+ mix_ratio: List[int] = [5, 15],
258
+ ):
259
+ torch.nn.Module.__init__(self)
260
+ self.llm_input_size = llm_input_size
261
+ self.llm_output_size = llm_output_size
262
+ self.speech_token_size = speech_token_size
263
+
264
+ # 2. build speech token language model related modules
265
+ self.sos_eos = 0
266
+ self.task_id = 1
267
+ self.fill_token = 2
268
+
269
+ self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
270
+ self.llm = llm
271
+ self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 3)
272
+ self.criterion_ce = LabelSmoothingLoss(
273
+ size=speech_token_size + 3,
274
+ padding_idx=IGNORE_ID,
275
+ smoothing=lsm_weight,
276
+ normalize_length=length_normalized_loss,
277
+ )
278
+
279
+ # 3. [Optional] build speech token related modules
280
+ self.speech_embedding = torch.nn.Embedding(speech_token_size + 3, llm_input_size)
281
+
282
+ # 4. sampling method
283
+ self.sampling = sampling
284
+ self.mix_ratio = mix_ratio
285
+
286
+ @torch.inference_mode()
287
+ def inference(
288
+ self,
289
+ text: torch.Tensor,
290
+ text_len: torch.Tensor,
291
+ prompt_text: torch.Tensor,
292
+ prompt_text_len: torch.Tensor,
293
+ prompt_speech_token: torch.Tensor,
294
+ prompt_speech_token_len: torch.Tensor,
295
+ embedding: torch.Tensor,
296
+ sampling: int = 25,
297
+ max_token_text_ratio: float = 20,
298
+ min_token_text_ratio: float = 2,
299
+ ) -> Generator[torch.Tensor, None, None]:
300
+ device = text.device
301
+ text = torch.concat([prompt_text, text], dim=1)
302
+ text_len += prompt_text_len
303
+ text = self.llm.model.model.embed_tokens(text)
304
+
305
+ # 3. concat llm_input
306
+ sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
307
+ task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
308
+ if prompt_speech_token_len != 0:
309
+ prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
310
+ else:
311
+ prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
312
+ lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
313
+
314
+ # 4. cal min/max_length
315
+ min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
316
+ max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
317
+
318
+ # 5. step by step decode
319
+ out_tokens = []
320
+ cache = None
321
+ for i in range(max_len):
322
+ y_pred, cache = self.llm.forward_one_step(lm_input,
323
+ masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
324
+ cache=cache)
325
+ logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
326
+ top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
327
+ if top_ids == self.speech_token_size:
328
+ break
329
+ if top_ids > self.speech_token_size:
330
+ continue
331
+ # in stream mode, yield token one by one
332
+ yield top_ids
333
+ out_tokens.append(top_ids)
334
+ lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
335
+
336
+ @torch.inference_mode()
337
+ def inference_bistream(
338
+ self,
339
+ text: Generator,
340
+ prompt_text: torch.Tensor,
341
+ prompt_text_len: torch.Tensor,
342
+ prompt_speech_token: torch.Tensor,
343
+ prompt_speech_token_len: torch.Tensor,
344
+ embedding: torch.Tensor,
345
+ sampling: int = 25,
346
+ max_token_text_ratio: float = 20,
347
+ min_token_text_ratio: float = 2,
348
+ ) -> Generator[torch.Tensor, None, None]:
349
+
350
+ device = prompt_text.device
351
+ # 1. prepare input
352
+ sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
353
+ task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
354
+ if prompt_speech_token_len != 0:
355
+ prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
356
+ else:
357
+ prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
358
+ lm_input = torch.concat([sos_eos_emb], dim=1)
359
+
360
+ # 2. iterate text
361
+ out_tokens = []
362
+ cache = None
363
+ # NOTE init prompt_text as text_cache as it is basically impossible prompt_speech_token/prompt_text < 15/5
364
+ text_cache = self.llm.model.model.embed_tokens(prompt_text)
365
+ next_fill_index = -1
366
+ for this_text in text:
367
+ text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
368
+ # prompt_speech_token_emb not empty, try append to lm_input
369
+ while prompt_speech_token_emb.size(1) != 0:
370
+ if text_cache.size(1) >= self.mix_ratio[0]:
371
+ lm_input_text, lm_input_speech = text_cache[:, :self.mix_ratio[0]], prompt_speech_token_emb[:, :self.mix_ratio[1]]
372
+ logging.info('append {} text token {} speech token'.format(lm_input_text.size(1), lm_input_speech.size(1)))
373
+ lm_input = torch.concat([lm_input, lm_input_text, lm_input_speech], dim=1)
374
+ text_cache, prompt_speech_token_emb = text_cache[:, self.mix_ratio[0]:], prompt_speech_token_emb[:, self.mix_ratio[1]:]
375
+ else:
376
+ logging.info('not enough text token to decode, wait for more')
377
+ break
378
+ # no prompt_speech_token_emb remain, can decode some speech token
379
+ if prompt_speech_token_emb.size(1) == 0:
380
+ if (len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
381
+ logging.info('get fill token, need to append more text token')
382
+ if text_cache.size(1) >= self.mix_ratio[0]:
383
+ lm_input_text = text_cache[:, :self.mix_ratio[0]]
384
+ logging.info('append {} text token'.format(lm_input_text.size(1)))
385
+ if len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2:
386
+ lm_input = lm_input_text
387
+ else:
388
+ lm_input = torch.concat([lm_input, lm_input_text], dim=1)
389
+ text_cache = text_cache[:, self.mix_ratio[0]:]
390
+ else:
391
+ logging.info('not enough text token to decode, wait for more')
392
+ continue
393
+ while True:
394
+ seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
395
+ y_pred, cache = self.llm.forward_one_step(lm_input,
396
+ masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
397
+ cache=cache)
398
+ logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
399
+ if next_fill_index != -1 and len(out_tokens) == next_fill_index:
400
+ top_ids = self.speech_token_size + 2
401
+ next_fill_index += (self.mix_ratio[1] + 1)
402
+ else:
403
+ top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item()
404
+ if top_ids == self.speech_token_size + 2:
405
+ next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
406
+ logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
407
+ out_tokens.append(top_ids)
408
+ if top_ids >= self.speech_token_size:
409
+ if top_ids == self.speech_token_size + 2:
410
+ break
411
+ else:
412
+ raise ValueError('should not get token {}'.format(top_ids))
413
+ yield top_ids
414
+ lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
415
+
416
+ # 3. final decode
417
+ lm_input = torch.concat([lm_input, text_cache, task_id_emb], dim=1)
418
+ logging.info('no more text token, decode until met eos')
419
+ while True:
420
+ seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
421
+ y_pred, cache = self.llm.forward_one_step(lm_input,
422
+ masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
423
+ cache=cache)
424
+ logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
425
+ top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False).item()
426
+ out_tokens.append(top_ids)
427
+ if top_ids >= self.speech_token_size:
428
+ if top_ids == self.speech_token_size:
429
+ break
430
+ else:
431
+ raise ValueError('should not get token {}'.format(top_ids))
432
+ # in stream mode, yield token one by one
433
+ yield top_ids
434
+ lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
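As a usage note, a minimal sketch of how the streaming `Qwen2LM.inference` generator above could be consumed; the model instance, tokenizer, and zero-length prompt tensors below are illustrative assumptions, not part of this commit:

    import torch

    def collect_speech_tokens(llm_model, text_ids: torch.Tensor) -> list:
        # llm_model is assumed to be an already-constructed, checkpoint-loaded Qwen2LM;
        # text_ids is a (1, T) LongTensor of text token ids from the matching text tokenizer.
        device = text_ids.device
        empty_ids = torch.zeros(1, 0, dtype=text_ids.dtype, device=device)
        speech_tokens = []
        for token_id in llm_model.inference(
                text=text_ids,
                text_len=torch.tensor([text_ids.size(1)], dtype=torch.int32, device=device),
                prompt_text=empty_ids,
                prompt_text_len=torch.tensor([0], dtype=torch.int32, device=device),
                prompt_speech_token=empty_ids,
                prompt_speech_token_len=torch.tensor([0], dtype=torch.int32, device=device),
                embedding=torch.zeros(0, 192, device=device),  # not used by Qwen2LM.inference
                sampling=25):
            # tokens are yielded one by one, so a streaming vocoder could start early
            speech_tokens.append(token_id)
        return speech_tokens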
tts/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken ADDED
The diff for this file is too large to render. See raw diff
 
tts/cosyvoice/tokenizer/tokenizer.py ADDED
@@ -0,0 +1,279 @@
+ import base64
+ import os
+ from functools import lru_cache
+ from typing import Optional
+ import torch
+ from transformers import AutoTokenizer
+ from whisper.tokenizer import Tokenizer
+
+ import tiktoken
+
+ LANGUAGES = {
+     "en": "english",
+     "zh": "chinese",
+     "de": "german",
+     "es": "spanish",
+     "ru": "russian",
+     "ko": "korean",
+     "fr": "french",
+     "ja": "japanese",
+     "pt": "portuguese",
+     "tr": "turkish",
+     "pl": "polish",
+     "ca": "catalan",
+     "nl": "dutch",
+     "ar": "arabic",
+     "sv": "swedish",
+     "it": "italian",
+     "id": "indonesian",
+     "hi": "hindi",
+     "fi": "finnish",
+     "vi": "vietnamese",
+     "he": "hebrew",
+     "uk": "ukrainian",
+     "el": "greek",
+     "ms": "malay",
+     "cs": "czech",
+     "ro": "romanian",
+     "da": "danish",
+     "hu": "hungarian",
+     "ta": "tamil",
+     "no": "norwegian",
+     "th": "thai",
+     "ur": "urdu",
+     "hr": "croatian",
+     "bg": "bulgarian",
+     "lt": "lithuanian",
+     "la": "latin",
+     "mi": "maori",
+     "ml": "malayalam",
+     "cy": "welsh",
+     "sk": "slovak",
+     "te": "telugu",
+     "fa": "persian",
+     "lv": "latvian",
+     "bn": "bengali",
+     "sr": "serbian",
+     "az": "azerbaijani",
+     "sl": "slovenian",
+     "kn": "kannada",
+     "et": "estonian",
+     "mk": "macedonian",
+     "br": "breton",
+     "eu": "basque",
+     "is": "icelandic",
+     "hy": "armenian",
+     "ne": "nepali",
+     "mn": "mongolian",
+     "bs": "bosnian",
+     "kk": "kazakh",
+     "sq": "albanian",
+     "sw": "swahili",
+     "gl": "galician",
+     "mr": "marathi",
+     "pa": "punjabi",
+     "si": "sinhala",
+     "km": "khmer",
+     "sn": "shona",
+     "yo": "yoruba",
+     "so": "somali",
+     "af": "afrikaans",
+     "oc": "occitan",
+     "ka": "georgian",
+     "be": "belarusian",
+     "tg": "tajik",
+     "sd": "sindhi",
+     "gu": "gujarati",
+     "am": "amharic",
+     "yi": "yiddish",
+     "lo": "lao",
+     "uz": "uzbek",
+     "fo": "faroese",
+     "ht": "haitian creole",
+     "ps": "pashto",
+     "tk": "turkmen",
+     "nn": "nynorsk",
+     "mt": "maltese",
+     "sa": "sanskrit",
+     "lb": "luxembourgish",
+     "my": "myanmar",
+     "bo": "tibetan",
+     "tl": "tagalog",
+     "mg": "malagasy",
+     "as": "assamese",
+     "tt": "tatar",
+     "haw": "hawaiian",
+     "ln": "lingala",
+     "ha": "hausa",
+     "ba": "bashkir",
+     "jw": "javanese",
+     "su": "sundanese",
+     "yue": "cantonese",
+     "minnan": "minnan",
+     "wuyu": "wuyu",
+     "dialect": "dialect",
+     "zh/en": "zh/en",
+     "en/zh": "en/zh",
+ }
+
+ # language code lookup by name, with a few language aliases
+ TO_LANGUAGE_CODE = {
+     **{language: code for code, language in LANGUAGES.items()},
+     "burmese": "my",
+     "valencian": "ca",
+     "flemish": "nl",
+     "haitian": "ht",
+     "letzeburgesch": "lb",
+     "pushto": "ps",
+     "panjabi": "pa",
+     "moldavian": "ro",
+     "moldovan": "ro",
+     "sinhalese": "si",
+     "castilian": "es",
+     "mandarin": "zh",
+ }
+
+ AUDIO_EVENT = {
+     "ASR": "ASR",
+     "AED": "AED",
+     "SER": "SER",
+     "Speech": "Speech",
+     "/Speech": "/Speech",
+     "BGM": "BGM",
+     "/BGM": "/BGM",
+     "Laughter": "Laughter",
+     "/Laughter": "/Laughter",
+     "Applause": "Applause",
+     "/Applause": "/Applause",
+ }
+
+ EMOTION = {
+     "HAPPY": "HAPPY",
+     "SAD": "SAD",
+     "ANGRY": "ANGRY",
+     "NEUTRAL": "NEUTRAL",
+ }
+
+ TTS_Vocal_Token = {
+     "TTS/B": "TTS/B",
+     "TTS/O": "TTS/O",
+     "TTS/Q": "TTS/Q",
+     "TTS/A": "TTS/A",
+     "TTS/CO": "TTS/CO",
+     "TTS/CL": "TTS/CL",
+     "TTS/H": "TTS/H",
+     **{f"TTS/SP{i:02d}": f"TTS/SP{i:02d}" for i in range(1, 14)}
+ }
+
+
+ @lru_cache(maxsize=None)
+ def get_encoding(name: str = "gpt2", num_languages: int = 99):
+     vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
+     ranks = {
+         base64.b64decode(token): int(rank)
+         for token, rank in (line.split() for line in open(vocab_path) if line)
+     }
+     n_vocab = len(ranks)
+     special_tokens = {}
+
+     specials = [
+         "<|endoftext|>",
+         "<|startoftranscript|>",
+         *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
+         *[f"<|{audio_event}|>" for audio_event in list(AUDIO_EVENT.keys())],
+         *[f"<|{emotion}|>" for emotion in list(EMOTION.keys())],
+         "<|translate|>",
+         "<|transcribe|>",
+         "<|startoflm|>",
+         "<|startofprev|>",
+         "<|nospeech|>",
+         "<|notimestamps|>",
+         *[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 31)],  # register special tokens for ASR
+         *[f"<|{tts}|>" for tts in list(TTS_Vocal_Token.keys())],  # register special tokens for TTS
+         *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
+     ]
+
+     for token in specials:
+         special_tokens[token] = n_vocab
+         n_vocab += 1
+
+     return tiktoken.Encoding(
+         name=os.path.basename(vocab_path),
+         explicit_n_vocab=n_vocab,
+         pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
+         mergeable_ranks=ranks,
+         special_tokens=special_tokens,
+     )
+
+
+ @lru_cache(maxsize=None)
+ def get_tokenizer(
+     multilingual: bool,
+     *,
+     num_languages: int = 99,
+     language: Optional[str] = None,
+     task: Optional[str] = None,  # Literal["transcribe", "translate", None]
+ ) -> Tokenizer:
+     if language is not None:
+         language = language.lower()
+         if language not in LANGUAGES:
+             if language in TO_LANGUAGE_CODE:
+                 language = TO_LANGUAGE_CODE[language]
+             else:
+                 raise ValueError(f"Unsupported language: {language}")
+
+     if multilingual:
+         encoding_name = "multilingual_zh_ja_yue_char_del"
+         language = language or "en"
+         task = task or "transcribe"
+     else:
+         encoding_name = "gpt2"
+         language = None
+         task = None
+
+     encoding = get_encoding(name=encoding_name, num_languages=num_languages)
+
+     return Tokenizer(
+         encoding=encoding, num_languages=num_languages, language=language, task=task
+     )
+
+
+ class QwenTokenizer():
+     def __init__(self, token_path, skip_special_tokens=True):
+         super().__init__()
+         # NOTE: non-chat model, all these special tokens keep randomly initialized.
+         special_tokens = {
+             'eos_token': '<|endoftext|>',
+             'pad_token': '<|endoftext|>',
+             'additional_special_tokens': [
+                 '<|im_start|>', '<|im_end|>', '<|endofprompt|>',
+                 '[breath]', '<strong>', '</strong>', '[noise]',
+                 '[laughter]', '[cough]', '[clucking]', '[accent]',
+                 '[quick_breath]',
+                 "<laughter>", "</laughter>",
+                 "[hissing]", "[sigh]", "[vocalized-noise]",
+                 "[lipsmack]", "[mn]"
+             ]
+         }
+         self.special_tokens = special_tokens
+         self.tokenizer = AutoTokenizer.from_pretrained(token_path)
+         self.tokenizer.add_special_tokens(special_tokens)
+         self.skip_special_tokens = skip_special_tokens
+
+     def encode(self, text, **kwargs):
+         tokens = self.tokenizer([text], return_tensors="pt")
+         tokens = tokens["input_ids"][0].cpu().tolist()
+         return tokens
+
+     def decode(self, tokens):
+         tokens = torch.tensor(tokens, dtype=torch.int64)
+         text = self.tokenizer.batch_decode([tokens], skip_special_tokens=self.skip_special_tokens)[0]
+         return text
+
+
+ @lru_cache(maxsize=None)
+ def get_qwen_tokenizer(
+     token_path: str,
+     skip_special_tokens: bool
+ ) -> QwenTokenizer:
+     return QwenTokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
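Finally, a small usage sketch for the tokenizer factories defined above; the Qwen tokenizer directory path is a placeholder, and the import assumes `tts/` is on `sys.path` so that the `cosyvoice` package resolves:

    from cosyvoice.tokenizer.tokenizer import get_encoding, get_qwen_tokenizer

    # Whisper-style tiktoken encoding built from the bundled vocab asset added above.
    encoding = get_encoding(name="multilingual_zh_ja_yue_char_del")
    print(encoding.n_vocab)  # base vocab plus the registered special tokens

    # Qwen text tokenizer with the extra paralinguistic tags registered as specials;
    # the path below is illustrative, not shipped with this commit.
    qwen_tokenizer = get_qwen_tokenizer(token_path="path/to/qwen_tokenizer_dir", skip_special_tokens=True)
    token_ids = qwen_tokenizer.encode("CosyVoice text to speech.")
    print(token_ids, qwen_tokenizer.decode(token_ids))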