Spaces:
Running
on
Zero
Running
on
Zero
开始部署
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- app.py +414 -0
- common_utils/__init__.py +0 -0
- common_utils/convert_ckpt_dir_to_pt.py +27 -0
- common_utils/load_combine_type_yaml.py +59 -0
- common_utils/utils4infer.py +163 -0
- conf/ct_config.yaml +153 -0
- conf/ct_config_sft.yaml +152 -0
- conf/data_s2s.yaml +226 -0
- conf/data_s2t.yaml +402 -0
- conf/data_t2s.yaml +28 -0
- conf/data_t2t.yaml +159 -0
- conf/data_tmp.yaml +6 -0
- conf/ds_stage2.json +34 -0
- conf/empty.yaml +0 -0
- conf/prompt_config.yaml +0 -0
- conf/system_prompt.yaml +27 -0
- patches/cumstom_stop_criteria.py +85 -0
- patches/custom_speech_ngram_blocking.py +129 -0
- patches/custom_speech_repetition_penalty.py +22 -0
- patches/modelling_fm_infer_gpu.py +18 -0
- patches/modelling_qwen2_infer_gpu.py +416 -0
- patches/utils.py +4 -0
- requirements.txt +41 -0
- tts/__init__.py +0 -0
- tts/assert//345/256/236/351/252/214/345/256/244.png +0 -0
- tts/cosyvoice/__init__.py +0 -0
- tts/cosyvoice/bin/average_model.py +92 -0
- tts/cosyvoice/bin/export_jit.py +91 -0
- tts/cosyvoice/bin/export_onnx.py +116 -0
- tts/cosyvoice/bin/export_trt.sh +10 -0
- tts/cosyvoice/bin/inference.py +115 -0
- tts/cosyvoice/bin/train.py +170 -0
- tts/cosyvoice/cli/__init__.py +0 -0
- tts/cosyvoice/cli/cosyvoice.py +197 -0
- tts/cosyvoice/cli/frontend.py +240 -0
- tts/cosyvoice/cli/model.py +480 -0
- tts/cosyvoice/dataset/__init__.py +0 -0
- tts/cosyvoice/dataset/dataset.py +164 -0
- tts/cosyvoice/dataset/processor.py +435 -0
- tts/cosyvoice/flow/decoder.py +301 -0
- tts/cosyvoice/flow/flow.py +239 -0
- tts/cosyvoice/flow/flow_matching.py +217 -0
- tts/cosyvoice/flow/length_regulator.py +69 -0
- tts/cosyvoice/hifigan/discriminator.py +140 -0
- tts/cosyvoice/hifigan/f0_predictor.py +56 -0
- tts/cosyvoice/hifigan/generator.py +412 -0
- tts/cosyvoice/hifigan/hifigan.py +67 -0
- tts/cosyvoice/llm/llm.py +434 -0
- tts/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +0 -0
- tts/cosyvoice/tokenizer/tokenizer.py +279 -0
app.py
ADDED
@@ -0,0 +1,414 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import ast
|
2 |
+
import base64
|
3 |
+
import datetime
|
4 |
+
import json
|
5 |
+
import logging
|
6 |
+
import os
|
7 |
+
import spaces
|
8 |
+
|
9 |
+
import gradio as gr
|
10 |
+
import sys
|
11 |
+
import time
|
12 |
+
import traceback
|
13 |
+
|
14 |
+
import torch
|
15 |
+
|
16 |
+
from common_utils.utils4infer import get_feat_from_wav_path, load_model_and_tokenizer, token_list2wav
|
17 |
+
|
18 |
+
sys.path.insert(0, '.')
|
19 |
+
sys.path.insert(0, './tts')
|
20 |
+
sys.path.insert(0, './tts/third_party/Matcha-TTS')
|
21 |
+
from patches import modelling_qwen2_infer_gpu # 打patch
|
22 |
+
from tts.cosyvoice.cli.cosyvoice import CosyVoice
|
23 |
+
from tts.cosyvoice.utils.file_utils import load_wav
|
24 |
+
|
25 |
+
is_npu = False
|
26 |
+
try:
|
27 |
+
import torch_npu
|
28 |
+
except ImportError:
|
29 |
+
is_npu = False
|
30 |
+
print("torch_npu is not available. if you want to use npu, please install it.")
|
31 |
+
|
32 |
+
|
33 |
+
|
34 |
+
|
35 |
+
from huggingface_hub import hf_hub_download
|
36 |
+
# 从Hugging Face下载.pt文件
|
37 |
+
CHECKPOINT_PATH_A = hf_hub_download(repo_id="ASLP-lab/OSUM-EChat", filename="language_think_final.pt")
|
38 |
+
CHECKPOINT_PATH_B= hf_hub_download(repo_id="ASLP-lab/OSUM-EChat", filename="tag_think_final.pt")
|
39 |
+
CONFIG_PATH = "./conf/ct_config.yaml"
|
40 |
+
NAME_A="language_think"
|
41 |
+
NAME_B="tag_think"
|
42 |
+
cosyvoice_model_path = hf_hub_download(repo_id="ASLP-lab/OSUM-EChat", filename="CosyVoice-300M-25Hz.tar")
|
43 |
+
# 将tar包解压到当前目录
|
44 |
+
os.system(f"tar -xvf {cosyvoice_model_path}")
|
45 |
+
print("解压cosyvoice模型pt文件完成")
|
46 |
+
cosyvoice_model_path="./CosyVoice-300M-25Hz"
|
47 |
+
|
48 |
+
|
49 |
+
|
50 |
+
device = torch.device("cuda")
|
51 |
+
print("开始加载模型 A...")
|
52 |
+
model_a, tokenizer_a = load_model_and_tokenizer(CHECKPOINT_PATH_A, CONFIG_PATH)
|
53 |
+
|
54 |
+
print("\n开始加载模型 B...")
|
55 |
+
if CHECKPOINT_PATH_B is not None:
|
56 |
+
model_b, tokenizer_b = load_model_and_tokenizer(CHECKPOINT_PATH_B, CONFIG_PATH)
|
57 |
+
else:
|
58 |
+
model_b, tokenizer_b = None, None
|
59 |
+
loaded_models = {
|
60 |
+
NAME_A: {"model": model_a, "tokenizer": tokenizer_a},
|
61 |
+
NAME_B: {"model": model_b, "tokenizer": tokenizer_b},
|
62 |
+
} if model_b is not None else {
|
63 |
+
NAME_A: {"model": model_a, "tokenizer": tokenizer_a},
|
64 |
+
}
|
65 |
+
print("\n所有模型已加载完毕。")
|
66 |
+
|
67 |
+
cosyvoice = CosyVoice(cosyvoice_model_path, gpu_id=0)
|
68 |
+
|
69 |
+
# 将图片转换为 Base64
|
70 |
+
with open("./tts/assert/实验室.png", "rb") as image_file:
|
71 |
+
encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
|
72 |
+
|
73 |
+
# 任务映射
|
74 |
+
TASK_PROMPT_MAPPING = {
|
75 |
+
"empathetic_s2s_dialogue with think": "THINK",
|
76 |
+
"empathetic_s2s_dialogue no think": "s2s_no_think",
|
77 |
+
"empathetic_s2t_dialogue with think": "s2t_think",
|
78 |
+
"empathetic_s2t_dialogue no think": "s2t_no_think",
|
79 |
+
"ASR (Automatic Speech Recognition)": "转录这段音频中的语音内容为文字。",
|
80 |
+
"SRWT (Speech Recognition with Timestamps)": "请识别音频内容,并对所有英文词和中文字进行时间对齐,标注格式为<>,时间精度0.1秒。",
|
81 |
+
"VED (Vocal Event Detection)(类别:laugh,cough,cry,screaming,sigh,throat clearing,sneeze,other)": "请将音频转化为文字,并在末尾添加相关音频事件标签,标签格式为<>。",
|
82 |
+
"SER (Speech Emotion Recognition)(类别:sad,anger,neutral,happy,surprise,fear,disgust,和other)": "请将音频内容转录成文字记录,并在记录末尾标注情感标签,以<>表示。",
|
83 |
+
"SSR (Speaking Style Recognition)(类别:新闻科普,恐怖故事,童话故事,客服,诗歌散文,有声书,日常口语,其他)": "请将音频中的讲话内容转化为文字,并在结尾处注明风格标签,用<>表示。",
|
84 |
+
"SGC (Speaker Gender Classification)(类别:female,male)": "请将音频转录为文字,并在文本末尾标注性别标签,标签格式为<>。",
|
85 |
+
"SAP (Speaker Age Prediction)(类别:child、adult和old)": "请将这段音频转录成文字,并在末尾加上年龄标签,格式为<>。",
|
86 |
+
"STTC (Speech to Text Chat)": "首先将语音转录为文字,然后对语音内容进行回复,转录和文字之间使用<开始回答>分割。",
|
87 |
+
"Only Age Prediction(类别:child、adult和old)": "请根据音频分析发言者的年龄并输出年龄标签,标签格式为<>。",
|
88 |
+
"Only Gender Classification(类别:female,male)": "根据下述音频内容判断说话者性别,返回性别标签,格式为<>.",
|
89 |
+
"Only Style Recognition(类别:新闻科普,恐怖故事,童话故事,客服,诗歌散文,有声书,日常口语,其他)": "对于以下音频,请直接判断风格并返回风格标签,标签格式为<>。",
|
90 |
+
"Only Emotion Recognition(类别:sad,anger,neutral,happy,surprise,fear,disgust,和other)": "请鉴别音频中的发言者情感并标出,标签格式为<>。",
|
91 |
+
"Only Event Detection(类别:laugh,cough,cry,screaming,sigh,throat clearing,sneeze,other)": "对音频进行标签化,返回音频事件标签,标签格式为<>。",
|
92 |
+
"ASR+AGE+GENDER": '请将这段音频进行转录,并在转录完成的文本末尾附加<年龄> <性别>标签。',
|
93 |
+
"AGE+GENDER": "请识别以下音频发言者的年龄和性别.",
|
94 |
+
"ASR+STYLE+AGE+GENDER": "请对以下音频内容进行转录,并在文本结尾分别���加<风格>、<年龄>、<性别>标签。",
|
95 |
+
"STYLE+AGE+GENDER": "请对以下音频进行分析,识别说话风格、说话者年龄和性别。",
|
96 |
+
"ASR with punctuations": "需对提供的语音文件执行文本转换,同时为转换结果补充必要的标点。",
|
97 |
+
"ASR EVENT AGE GENDER": "请将以下音频内容进行转录,并在转录完成的文本末尾分别附加<音频事件>、<年龄>、<性别>标签。",
|
98 |
+
"ASR EMOTION AGE GENDER": "请将下列音频内容进行转录,并在转录文本的末尾分别添加<情感>、<年龄>、<性别>标签。",
|
99 |
+
}
|
100 |
+
prompt_path = hf_hub_download(repo_id="ASLP-lab/OSUM-EChat", filename="prompt.wav")
|
101 |
+
prompt_audio_choices = [
|
102 |
+
{"name": "拟人",
|
103 |
+
"value": prompt_path},
|
104 |
+
]
|
105 |
+
|
106 |
+
prompt_audio_cache = {}
|
107 |
+
for item in prompt_audio_choices:
|
108 |
+
prompt_audio_cache[item["value"]] = load_wav(item["value"], 22050)
|
109 |
+
|
110 |
+
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(message)s')
|
111 |
+
|
112 |
+
|
113 |
+
|
114 |
+
def do_s2t(model, input_wav_path, input_prompt, profile=False): # 增加 model 参数
|
115 |
+
model.eval()
|
116 |
+
feat, feat_lens = get_feat_from_wav_path(input_wav_path)
|
117 |
+
print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
|
118 |
+
if is_npu: torch_npu.npu.synchronize()
|
119 |
+
start_time = time.time()
|
120 |
+
res_text = model.generate(wavs=feat, wavs_len=feat_lens, prompt=input_prompt, cache_implementation="static")[0]
|
121 |
+
if is_npu: torch_npu.npu.synchronize()
|
122 |
+
end_time = time.time()
|
123 |
+
print(f"S2T 推理消耗时间: {end_time - start_time:.2f} 秒")
|
124 |
+
return res_text
|
125 |
+
|
126 |
+
|
127 |
+
def do_s2t4chat(model, input_wav_path, input_prompt, profile=False): # 增加 model 参数
|
128 |
+
model.eval()
|
129 |
+
feat, feat_lens = get_feat_from_wav_path(input_wav_path)
|
130 |
+
print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
|
131 |
+
if is_npu: torch_npu.npu.synchronize()
|
132 |
+
start_time = time.time()
|
133 |
+
res_text = model.generate4chat(wavs=feat, wavs_len=feat_lens, cache_implementation="static")[0]
|
134 |
+
if is_npu: torch_npu.npu.synchronize()
|
135 |
+
end_time = time.time()
|
136 |
+
print(f"S2T4Chat 推理消耗时间: {end_time - start_time:.2f} 秒")
|
137 |
+
return res_text
|
138 |
+
|
139 |
+
def do_s2t4chat_think(model, input_wav_path, input_prompt, profile=False): # 增加 model 参数
|
140 |
+
model.eval()
|
141 |
+
feat, feat_lens = get_feat_from_wav_path(input_wav_path)
|
142 |
+
print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
|
143 |
+
if is_npu: torch_npu.npu.synchronize()
|
144 |
+
start_time = time.time()
|
145 |
+
res_text = model.generate4chat_think(wavs=feat, wavs_len=feat_lens, cache_implementation="static")[0]
|
146 |
+
if is_npu: torch_npu.npu.synchronize()
|
147 |
+
end_time = time.time()
|
148 |
+
print(f"S2T4Chat 推理消耗时间: {end_time - start_time:.2f} 秒")
|
149 |
+
return res_text
|
150 |
+
|
151 |
+
|
152 |
+
def do_t2s(model, input_prompt, text_for_tts, profile=False): # 增加 model 参数
|
153 |
+
model.eval()
|
154 |
+
if is_npu: torch_npu.npu.synchronize()
|
155 |
+
start_time = time.time()
|
156 |
+
res_tensor = model.generate_tts(device=device, text=text_for_tts, )[0]
|
157 |
+
res_token_list = res_tensor.tolist()
|
158 |
+
res_text = res_token_list[:-1]
|
159 |
+
if is_npu: torch_npu.npu.synchronize()
|
160 |
+
end_time = time.time()
|
161 |
+
print(f"T2S 推理消耗时间: {end_time - start_time:.2f} 秒")
|
162 |
+
return res_text
|
163 |
+
|
164 |
+
|
165 |
+
def do_t2t(model, question_txt, profile=False): # 增加 model 参数
|
166 |
+
model.eval()
|
167 |
+
if is_npu: torch_npu.npu.synchronize()
|
168 |
+
start_time = time.time()
|
169 |
+
print(f'开始t2t推理, question_txt: {question_txt}')
|
170 |
+
res_text = model.generate_text2text(device=device, text=question_txt)[0]
|
171 |
+
if is_npu: torch_npu.npu.synchronize()
|
172 |
+
end_time = time.time()
|
173 |
+
print(f"T2T 推理消耗时间: {end_time - start_time:.2f} 秒")
|
174 |
+
return res_text
|
175 |
+
|
176 |
+
|
177 |
+
def do_s2s(model, input_wav_path, input_prompt, profile=False): # 增加 model 参数
|
178 |
+
model.eval()
|
179 |
+
feat, feat_lens = get_feat_from_wav_path(input_wav_path)
|
180 |
+
print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
|
181 |
+
if is_npu: torch_npu.npu.synchronize()
|
182 |
+
start_time = time.time()
|
183 |
+
output_text, text_res, speech_res = model.generate_s2s_no_stream_with_repetition_penalty(wavs=feat, wavs_len=feat_lens,)
|
184 |
+
if is_npu: torch_npu.npu.synchronize()
|
185 |
+
end_time = time.time()
|
186 |
+
print(f"S2S 推理消耗时间: {end_time - start_time:.2f} 秒")
|
187 |
+
return f'{output_text[0]}|{str(speech_res[0].tolist()[1:])}'
|
188 |
+
|
189 |
+
def do_s2s_think(model, input_wav_path, input_prompt, profile=False): # 增加 model 参数
|
190 |
+
model.eval()
|
191 |
+
feat, feat_lens = get_feat_from_wav_path(input_wav_path)
|
192 |
+
print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
|
193 |
+
if is_npu: torch_npu.npu.synchronize()
|
194 |
+
start_time = time.time()
|
195 |
+
output_text, text_res, speech_res = model.generate_s2s_no_stream_think_with_repetition_penalty(wavs=feat, wavs_len=feat_lens,)
|
196 |
+
if is_npu: torch_npu.npu.synchronize()
|
197 |
+
end_time = time.time()
|
198 |
+
print(f"S2S 推理消耗时间: {end_time - start_time:.2f} 秒")
|
199 |
+
return f'{output_text[0]}|{str(speech_res[0].tolist()[1:])}'
|
200 |
+
|
201 |
+
@spaces.GPU
|
202 |
+
def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt): # 增加 model 和 tokenizer 参数
|
203 |
+
print(f"wav_path: {input_wav_path}, prompt:{input_prompt}")
|
204 |
+
if input_wav_path is None and not input_prompt.endswith(("_TTS", "_T2T")):
|
205 |
+
print("音频信息未输入,且不是T2S或T2T任务")
|
206 |
+
return "错误:需要音频输入"
|
207 |
+
|
208 |
+
if input_prompt.endswith("_TTS"):
|
209 |
+
text_for_tts = input_prompt.replace("_TTS", "")
|
210 |
+
prompt = "恳请将如下文本转换为其对应的语音token,力求生成最为流畅、自然的语音。"
|
211 |
+
res_text = do_t2s(model, prompt, text_for_tts)
|
212 |
+
elif input_prompt.endswith("_self_prompt"):
|
213 |
+
prompt = input_prompt.replace("_self_prompt", "")
|
214 |
+
res_text = do_s2t(model, input_wav_path, prompt)
|
215 |
+
elif input_prompt.endswith("_T2T"):
|
216 |
+
question_txt = input_prompt.replace("_T2T", "")
|
217 |
+
res_text = do_t2t(model, question_txt)
|
218 |
+
elif input_prompt in ["识别语音内容,并以文字方式作出回答。",
|
219 |
+
"请推断对这段语音回答时的情感,标注情感类型,撰写流畅自然的聊天回复,并生成情感语音token。",
|
220 |
+
"s2s_no_think"]:
|
221 |
+
res_text = do_s2s(model, input_wav_path, input_prompt)
|
222 |
+
elif input_prompt == "THINK":
|
223 |
+
res_text = do_s2s_think(model, input_wav_path, input_prompt)
|
224 |
+
elif input_prompt == "s2t_no_think":
|
225 |
+
res_text = do_s2t4chat(model, input_wav_path, input_prompt)
|
226 |
+
elif input_prompt == "s2t_think":
|
227 |
+
res_text = do_s2t4chat_think(model, input_wav_path, input_prompt)
|
228 |
+
else:
|
229 |
+
res_text = do_s2t(model, input_wav_path, input_prompt)
|
230 |
+
res_text = res_text.replace("<youth>", "<adult>").replace("<middle_age>", "<adult>").replace("<middle>",
|
231 |
+
"<adult>")
|
232 |
+
|
233 |
+
print("识别结果为:", res_text)
|
234 |
+
return res_text
|
235 |
+
|
236 |
+
|
237 |
+
def do_decode(model, tokenizer, input_wav_path, input_prompt): # 增加 model 和 tokenizer 参数
|
238 |
+
print(f'使用模型进行推理: input_wav_path={input_wav_path}, input_prompt={input_prompt}')
|
239 |
+
output_res = true_decode_fuc(model, tokenizer, input_wav_path, input_prompt)
|
240 |
+
return output_res
|
241 |
+
|
242 |
+
|
243 |
+
def save_to_jsonl(if_correct, wav, prompt, res):
|
244 |
+
data = {
|
245 |
+
"if_correct": if_correct,
|
246 |
+
"wav": wav,
|
247 |
+
"task": prompt,
|
248 |
+
"res": res
|
249 |
+
}
|
250 |
+
with open("results.jsonl", "a", encoding="utf-8") as f:
|
251 |
+
f.write(json.dumps(data, ensure_ascii=False) + "\n")
|
252 |
+
|
253 |
+
|
254 |
+
def download_audio(input_wav_path):
|
255 |
+
return input_wav_path if input_wav_path else None
|
256 |
+
|
257 |
+
|
258 |
+
def get_wav_from_token_list(input_list, prompt_speech):
|
259 |
+
time_str = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
260 |
+
wav_path = f"./tmp/{time_str}.wav"
|
261 |
+
return token_list2wav(input_list, prompt_speech, wav_path, cosyvoice)
|
262 |
+
|
263 |
+
|
264 |
+
# --- Gradio 界面 ---
|
265 |
+
with gr.Blocks() as demo:
|
266 |
+
gr.Markdown(
|
267 |
+
f"""
|
268 |
+
<div style="display: flex; align-items: center; justify-content: center; text-align: center;">
|
269 |
+
<h1 style="font-family: 'Arial', sans-serif; color: #014377; font-size: 32px; margin-bottom: 0; display: inline-block; vertical-align: middle;">
|
270 |
+
OSUM Speech Understanding Model Test
|
271 |
+
</h1>
|
272 |
+
</div>
|
273 |
+
"""
|
274 |
+
)
|
275 |
+
|
276 |
+
# ### --- 关键修改:添加模型选择器 --- ###
|
277 |
+
with gr.Row():
|
278 |
+
model_selector = gr.Radio(
|
279 |
+
choices=list(loaded_models.keys()), # 从加载的模型字典中获取选项
|
280 |
+
value=NAME_A, # 默认值
|
281 |
+
label="选择推理模型",
|
282 |
+
interactive=True
|
283 |
+
)
|
284 |
+
|
285 |
+
with gr.Row():
|
286 |
+
with gr.Column(scale=1, min_width=300):
|
287 |
+
audio_input = gr.Audio(label="录音", sources=["microphone", "upload"], type="filepath", visible=True)
|
288 |
+
with gr.Column(scale=1, min_width=300):
|
289 |
+
output_text = gr.Textbox(label="输出结果", lines=6, placeholder="生成的结果将显示在这里...",
|
290 |
+
interactive=False)
|
291 |
+
|
292 |
+
with gr.Row():
|
293 |
+
task_dropdown = gr.Dropdown(label="任务",
|
294 |
+
choices=list(TASK_PROMPT_MAPPING.keys()) + ["自主输入文本", "TTS任务", "T2T任务"],
|
295 |
+
value="empathetic_s2s_dialogue with think")
|
296 |
+
prompt_speech_dropdown = gr.Dropdown(label="参考音频(prompt_speech)",
|
297 |
+
choices=[(item["name"], item["value"]) for item in prompt_audio_choices],
|
298 |
+
value=prompt_audio_choices[0]["value"], visible=True)
|
299 |
+
custom_prompt_input = gr.Textbox(label="自定义任务提示", placeholder="请输入自定义任务提示...", visible=False)
|
300 |
+
tts_input = gr.Textbox(label="TTS输入文本", placeholder="请输入TTS任务的文本...", visible=False)
|
301 |
+
t2t_input = gr.Textbox(label="T2T输入文本", placeholder="请输入T2T任务的文本...", visible=False)
|
302 |
+
|
303 |
+
audio_player = gr.Audio(label="播放音频", type="filepath", interactive=False)
|
304 |
+
|
305 |
+
with gr.Row():
|
306 |
+
download_button = gr.DownloadButton("下载音频", variant="secondary",
|
307 |
+
elem_classes=["button-height", "download-button"])
|
308 |
+
submit_button = gr.Button("开始处理", variant="primary", elem_classes=["button-height", "submit-button"])
|
309 |
+
|
310 |
+
with gr.Row(visible=False) as confirmation_row:
|
311 |
+
# ... (确认组件不变)
|
312 |
+
gr.Markdown("请判断结果是否正确:")
|
313 |
+
confirmation_buttons = gr.Radio(choices=["正确", "错误"], label="", interactive=True, container=False,
|
314 |
+
elem_classes="confirmation-buttons")
|
315 |
+
save_button = gr.Button("提交反馈", variant="secondary")
|
316 |
+
|
317 |
+
# ... (底部内容不变)
|
318 |
+
with gr.Row():
|
319 |
+
with gr.Column(scale=1, min_width=800):
|
320 |
+
gr.Markdown(f"""...""") # 省略底部HTML
|
321 |
+
|
322 |
+
|
323 |
+
def show_confirmation(output_res, input_wav_path, input_prompt):
|
324 |
+
return gr.update(visible=True), output_res, input_wav_path, input_prompt
|
325 |
+
|
326 |
+
|
327 |
+
def save_result(if_correct, wav, prompt, res):
|
328 |
+
save_to_jsonl(if_correct, wav, prompt, res)
|
329 |
+
return gr.update(visible=False)
|
330 |
+
|
331 |
+
|
332 |
+
# handle_submit 函数现在接收 `selected_model_name` 参数
|
333 |
+
def handle_submit(selected_model_name, input_wav_path, task_choice, custom_prompt, tts_text, t2t_text,
|
334 |
+
prompt_speech):
|
335 |
+
# 1. 根据选择的模型名称,从字典中获取对应的模型和分词器
|
336 |
+
print(f"用户选择了: {selected_model_name}")
|
337 |
+
model_info = loaded_models[selected_model_name]
|
338 |
+
model_to_use = model_info["model"]
|
339 |
+
tokenizer_to_use = model_info["tokenizer"]
|
340 |
+
|
341 |
+
# 2. 准备 prompt
|
342 |
+
prompt_speech_data = prompt_audio_cache.get(prompt_speech)
|
343 |
+
if task_choice == "自主输入文本":
|
344 |
+
input_prompt = custom_prompt + "_self_prompt"
|
345 |
+
elif task_choice == "TTS任务":
|
346 |
+
input_prompt = tts_text + "_TTS"
|
347 |
+
elif task_choice == "T2T任务":
|
348 |
+
input_prompt = t2t_text + "_T2T"
|
349 |
+
else:
|
350 |
+
input_prompt = TASK_PROMPT_MAPPING.get(task_choice, "未知任务类型")
|
351 |
+
|
352 |
+
# 3. 调用重构后的推理函数,传入选择的模型
|
353 |
+
output_res = do_decode(model_to_use, tokenizer_to_use, input_wav_path, input_prompt)
|
354 |
+
|
355 |
+
# 4. 处理输出 (逻辑不变)
|
356 |
+
wav_path_output = input_wav_path
|
357 |
+
if task_choice == "TTS任务" or "empathetic_s2s_dialogue" in task_choice:
|
358 |
+
if isinstance(output_res, list): # TTS case
|
359 |
+
wav_path_output = get_wav_from_token_list(output_res, prompt_speech_data)
|
360 |
+
output_res = "生成的token: " + str(output_res)
|
361 |
+
elif isinstance(output_res, str) and "|" in output_res: # S2S case
|
362 |
+
try:
|
363 |
+
text_res, token_list_str = output_res.split("|")
|
364 |
+
token_list = json.loads(token_list_str)
|
365 |
+
wav_path_output = get_wav_from_token_list(token_list, prompt_speech_data)
|
366 |
+
output_res = text_res
|
367 |
+
except (ValueError, json.JSONDecodeError) as e:
|
368 |
+
print(f"处理S2S输出时出错: {e}")
|
369 |
+
output_res = f"错误:无法解析模型输出 - {output_res}"
|
370 |
+
|
371 |
+
return output_res, wav_path_output
|
372 |
+
|
373 |
+
|
374 |
+
# --- 绑定事件 (下拉框逻辑不变) ---
|
375 |
+
task_dropdown.change(fn=lambda choice: gr.update(visible=choice == "自主输入文本"), inputs=task_dropdown,
|
376 |
+
outputs=custom_prompt_input)
|
377 |
+
task_dropdown.change(fn=lambda choice: gr.update(visible=choice == "TTS任务"), inputs=task_dropdown,
|
378 |
+
outputs=tts_input)
|
379 |
+
task_dropdown.change(fn=lambda choice: gr.update(visible=choice == "T2T任务"), inputs=task_dropdown,
|
380 |
+
outputs=t2t_input)
|
381 |
+
|
382 |
+
submit_button.click(
|
383 |
+
fn=handle_submit,
|
384 |
+
# 在 inputs 列表中添加模型选择器 `model_selector`
|
385 |
+
inputs=[model_selector, audio_input, task_dropdown, custom_prompt_input, tts_input, t2t_input,
|
386 |
+
prompt_speech_dropdown],
|
387 |
+
outputs=[output_text, audio_player]
|
388 |
+
).then(
|
389 |
+
fn=show_confirmation,
|
390 |
+
inputs=[output_text, audio_input, task_dropdown],
|
391 |
+
outputs=[confirmation_row, output_text, audio_input, task_dropdown]
|
392 |
+
)
|
393 |
+
|
394 |
+
download_button.click(fn=download_audio, inputs=[audio_input], outputs=[download_button])
|
395 |
+
save_button.click(fn=save_result, inputs=[confirmation_buttons, audio_input, task_dropdown, output_text],
|
396 |
+
outputs=confirmation_row)
|
397 |
+
|
398 |
+
# --- 关键修改:为两个模型分别进行预热 ---
|
399 |
+
print("开始预热模型...")
|
400 |
+
warmup_wav_path = "./tts/assert/hq_1.wav"
|
401 |
+
warmup_prompt = "将这段音频的语音内容详细记录为文字稿。"
|
402 |
+
|
403 |
+
for model_name, model_info in loaded_models.items():
|
404 |
+
print(f"正在预热 {model_name}...")
|
405 |
+
try:
|
406 |
+
# 使用重构后的 do_s2t 函数进行预热,传入对应的模型
|
407 |
+
res_text = do_s2t(model_info["model"], warmup_wav_path, warmup_prompt, profile=False)
|
408 |
+
print(f'{model_name} 预热完成。ASR推理结果: {res_text}')
|
409 |
+
except Exception as e:
|
410 |
+
print(f"预热 {model_name} 时发生错误: {e}")
|
411 |
+
|
412 |
+
# 启动Gradio界面
|
413 |
+
print("\nGradio 界面启动中...")
|
414 |
+
demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
|
common_utils/__init__.py
ADDED
File without changes
|
common_utils/convert_ckpt_dir_to_pt.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from gxl_ai_utils.utils import utils_file
|
2 |
+
import torch
|
3 |
+
try:
|
4 |
+
import torch_npu
|
5 |
+
except:
|
6 |
+
pass
|
7 |
+
import os
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
def convert_ckpt_to_pt(pt_dir_path):
|
12 |
+
exp_dir = os.path.dirname(pt_dir_path)
|
13 |
+
pt_name = os.path.basename(pt_dir_path)
|
14 |
+
weight_dict = torch.load(f"{exp_dir}/{pt_name}/mp_rank_00_model_states.pt", map_location=torch.device('cpu'))[
|
15 |
+
'module']
|
16 |
+
print(weight_dict.keys())
|
17 |
+
torch.save(weight_dict, f"{exp_dir}/{pt_name}.pt")
|
18 |
+
|
19 |
+
if __name__ == '__main__':
|
20 |
+
pt_dir_path, = utils_file.do_get_commandline_param(1, ["pt_dir_path"])
|
21 |
+
exp_dir = os.path.dirname(pt_dir_path)
|
22 |
+
pt_name = os.path.basename(pt_dir_path)
|
23 |
+
weight_dict = torch.load(f"{exp_dir}/{pt_name}/mp_rank_00_model_states.pt", map_location=torch.device('cpu'))[
|
24 |
+
'module']
|
25 |
+
print(weight_dict.keys())
|
26 |
+
torch.save(weight_dict, f"{exp_dir}/{pt_name}.pt")
|
27 |
+
# weigth_dict = torch.load("/mnt/sfs/asr/code/wenet_undersdand_and_speech_xlgeng/examples/wenetspeech/whisper/exp/epoch24_cosyvoice1_new-set_token_1w_plus-multi_task_new/step_24999.pt")
|
common_utils/load_combine_type_yaml.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
import time
|
4 |
+
|
5 |
+
from gxl_ai_utils.utils import utils_file
|
6 |
+
|
7 |
+
data_config_path, tmp_file_path = utils_file.do_get_commandline_param(2)
|
8 |
+
# random.seed(10086)# 老的
|
9 |
+
# 把当前时间戳作为随机种子
|
10 |
+
random.seed(int(time.time()))
|
11 |
+
# random.seed(7891)# 尝试一下新的顺序 #7890
|
12 |
+
data_info_dict = utils_file.load_dict_from_yaml(data_config_path)
|
13 |
+
if data_info_dict is None:
|
14 |
+
data_info_dict = {}
|
15 |
+
total_list = []
|
16 |
+
for data_info in data_info_dict.values():
|
17 |
+
if "path" not in data_info:
|
18 |
+
print(f"path or weight not in data_info:{data_info}")
|
19 |
+
continue
|
20 |
+
if "weight" not in data_info:
|
21 |
+
data_weight = 1
|
22 |
+
else:
|
23 |
+
data_weight = int(float(data_info['weight']))
|
24 |
+
data_path_i = data_info['path']
|
25 |
+
utils_file.logging_info(f'path:{data_path_i} ')
|
26 |
+
|
27 |
+
if data_weight == 0:
|
28 |
+
data_weight = float(data_info['weight'])
|
29 |
+
if data_weight >= 0:
|
30 |
+
utils_file.logging_info(f'data {data_path_i} weight is {data_weight}, will be used as a list')
|
31 |
+
final_data_list_i_tmp = utils_file.load_list_file_clean(data_path_i)
|
32 |
+
true_num = int(len(final_data_list_i_tmp)*data_weight)
|
33 |
+
final_data_list_i = utils_file.do_get_random_sublist(final_data_list_i_tmp, true_num)
|
34 |
+
else:
|
35 |
+
final_data_list_i = utils_file.load_list_file_clean(data_path_i) * data_weight
|
36 |
+
# 判断数据类型
|
37 |
+
if "combines_list.txt" in data_path_i:
|
38 |
+
print(f'是 combine类型的数据')
|
39 |
+
tar_root_path = data_path_i.replace('combines_list.txt', 'combines_tar_root.txt')
|
40 |
+
if not os.path.exists(tar_root_path):
|
41 |
+
utils_file.logging_info(f'combine_list.txt:{data_path_i} 对应的 combines_tar_root.txt:{tar_root_path} 不存在')
|
42 |
+
continue
|
43 |
+
tar_root = utils_file.load_first_row_clean(tar_root_path)
|
44 |
+
if tar_root.endswith('/'):
|
45 |
+
tar_root = tar_root[:-1]
|
46 |
+
utils_file.logging_info(f' tar_root:{tar_root}')
|
47 |
+
new_final_data_list_i = []
|
48 |
+
for data_path_j in final_data_list_i:
|
49 |
+
# "combine_path|shard_path"
|
50 |
+
tmp_lines = f'{data_path_j}|{tar_root}/{utils_file.do_get_file_pure_name_from_path(data_path_j)}.tar'
|
51 |
+
new_final_data_list_i.append(tmp_lines)
|
52 |
+
else:
|
53 |
+
print(f'不是 combine类型的数据,是传统shard类型的数据')
|
54 |
+
new_final_data_list_i = [f'-|{data_path_j}' for data_path_j in final_data_list_i]
|
55 |
+
|
56 |
+
utils_file.logging_info(f'true load num is : {len(new_final_data_list_i)}')
|
57 |
+
total_list.extend(new_final_data_list_i)
|
58 |
+
random.shuffle(total_list)
|
59 |
+
utils_file.write_list_to_file(total_list, tmp_file_path)
|
common_utils/utils4infer.py
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
import re
|
5 |
+
|
6 |
+
import yaml
|
7 |
+
from cn2an import an2cn
|
8 |
+
from gxl_ai_utils.utils import utils_file
|
9 |
+
from wenet.utils.init_tokenizer import init_tokenizer
|
10 |
+
from gxl_ai_utils.config.gxl_config import GxlNode
|
11 |
+
from wenet.utils.init_model import init_model
|
12 |
+
import logging
|
13 |
+
import librosa
|
14 |
+
import torch
|
15 |
+
import torchaudio
|
16 |
+
|
17 |
+
|
18 |
+
|
19 |
+
def load_model_and_tokenizer(checkpoint_path, config_path, device:torch.device=torch.device('cuda')):
|
20 |
+
"""
|
21 |
+
封装了加载模型和分词器的逻辑
|
22 |
+
Args:
|
23 |
+
checkpoint_path (str): 模型权重文件路径
|
24 |
+
config_path (str): 模型配置文件路径
|
25 |
+
device (torch.device): 加载模型的设备
|
26 |
+
Returns:
|
27 |
+
model: 加载好的模型
|
28 |
+
tokenizer: 加载好的分词器
|
29 |
+
"""
|
30 |
+
print(f"正在从以下路径加载模型: {checkpoint_path}")
|
31 |
+
args = GxlNode({"checkpoint": checkpoint_path})
|
32 |
+
configs = utils_file.load_dict_from_yaml(config_path)
|
33 |
+
model, configs = init_model(args, configs)
|
34 |
+
model = model.to(device).to(torch.bfloat16)
|
35 |
+
model.eval() # 设置为评估模式
|
36 |
+
tokenizer = init_tokenizer(configs)
|
37 |
+
print(f"模型 {checkpoint_path} 加载完成并移动到 {device}")
|
38 |
+
return model, tokenizer
|
39 |
+
|
40 |
+
def token_list2wav(token_list, prompt_speech, wav_path, cosyvoice):
|
41 |
+
token_list = [int(i) for i in token_list]
|
42 |
+
j = cosyvoice.inference_zero_shot_gz_22k(
|
43 |
+
'收到好友从远方寄来的生日礼物。',
|
44 |
+
'希望你以后能够做的比我还好呦。', prompt_speech, stream=False, token_list=token_list)
|
45 |
+
utils_file.makedir_for_file(wav_path)
|
46 |
+
torchaudio.save(wav_path, j['tts_speech'],cosyvoice.sample_rate)
|
47 |
+
print(f'语音合成完成,保存到 {wav_path}')
|
48 |
+
return wav_path
|
49 |
+
|
50 |
+
def do_resample(input_wav_path):
|
51 |
+
"""..."""
|
52 |
+
waveform, sample_rate = torchaudio.load(input_wav_path)
|
53 |
+
if waveform.shape[0] > 1:
|
54 |
+
waveform = torch.mean(waveform, dim=0, keepdim=True)
|
55 |
+
resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
|
56 |
+
waveform = resampler(waveform)
|
57 |
+
return waveform, 16000
|
58 |
+
|
59 |
+
|
60 |
+
def get_feat_from_wav_path(input_wav_path, device:torch.device=torch.device('cuda')):
|
61 |
+
"""..."""
|
62 |
+
waveform, sample_rate = do_resample(input_wav_path)
|
63 |
+
waveform = waveform.squeeze(0)
|
64 |
+
window = torch.hann_window(400)
|
65 |
+
stft = torch.stft(waveform, 400, 160, window=window, return_complex=True)
|
66 |
+
magnitudes = stft[..., :-1].abs() ** 2
|
67 |
+
filters = torch.from_numpy(librosa.filters.mel(sr=sample_rate, n_fft=400, n_mels=80))
|
68 |
+
mel_spec = filters @ magnitudes
|
69 |
+
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
70 |
+
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
71 |
+
log_spec = (log_spec + 4.0) / 4.0
|
72 |
+
feat = log_spec.transpose(0, 1)
|
73 |
+
feat_lens = torch.tensor([feat.shape[0]], dtype=torch.int64).to(device)
|
74 |
+
feat = feat.unsqueeze(0).to(device)
|
75 |
+
feat = feat.to(torch.bfloat16)
|
76 |
+
return feat, feat_lens
|
77 |
+
|
78 |
+
|
79 |
+
|
80 |
+
def do_format_shard_manifest4one(input_shards_path, tmp_file_path=None):
|
81 |
+
if tmp_file_path is None:
|
82 |
+
tmp_file_path = f'~/.cache/.temp/{random.randint(10000, 99999)}.txt'
|
83 |
+
data_path_i = input_shards_path
|
84 |
+
utils_file.logging_info(f'path:{data_path_i} ')
|
85 |
+
final_data_list_i = utils_file.load_list_file_clean(data_path_i)
|
86 |
+
# 判断数据类型
|
87 |
+
if "combines_list.txt" in data_path_i:
|
88 |
+
print(f'是 combine类型的数据')
|
89 |
+
tar_root_path = data_path_i.replace('combines_list.txt', 'combines_tar_root.txt')
|
90 |
+
if not os.path.exists(tar_root_path):
|
91 |
+
utils_file.logging_error(
|
92 |
+
f'combine_list.txt:{data_path_i} 对应的 combines_tar_root.txt:{tar_root_path} 不存在')
|
93 |
+
return
|
94 |
+
tar_root = utils_file.load_first_row_clean(tar_root_path)
|
95 |
+
if tar_root.endswith('/'):
|
96 |
+
tar_root = tar_root[:-1]
|
97 |
+
utils_file.logging_info(f' tar_root:{tar_root}')
|
98 |
+
new_final_data_list_i = []
|
99 |
+
for data_path_j in final_data_list_i:
|
100 |
+
# "combine_path|shard_path"
|
101 |
+
tmp_lines = f'{data_path_j}|{tar_root}/{utils_file.do_get_file_pure_name_from_path(data_path_j)}.tar'
|
102 |
+
new_final_data_list_i.append(tmp_lines)
|
103 |
+
else:
|
104 |
+
print(f'不是 combine类型的数据,是传统shard类型的数据')
|
105 |
+
new_final_data_list_i = [f'-|{data_path_j}' for data_path_j in final_data_list_i]
|
106 |
+
|
107 |
+
utils_file.logging_info(f'true load num is : {len(new_final_data_list_i)}')
|
108 |
+
utils_file.write_list_to_file(new_final_data_list_i, tmp_file_path)
|
109 |
+
return tmp_file_path
|
110 |
+
|
111 |
+
|
112 |
+
|
113 |
+
def convert_numbers_in_string(s):
|
114 |
+
# 正则表达式匹配数字(支持整数、小数、负数)
|
115 |
+
pattern = r'-?\d+\.?\d*'
|
116 |
+
|
117 |
+
def replace_func(match):
|
118 |
+
num_str = match.group()
|
119 |
+
try:
|
120 |
+
# 尝试转换数字
|
121 |
+
return an2cn(num_str)
|
122 |
+
except ValueError:
|
123 |
+
# 若转换失败(如非有效数字),返回原内容
|
124 |
+
return num_str
|
125 |
+
# 替换字符串中所有匹配的数字
|
126 |
+
return re.sub(pattern, replace_func, s)
|
127 |
+
|
128 |
+
def get_test_conf(config_path):
|
129 |
+
with open(config_path, 'r', encoding='utf-8') as fin:
|
130 |
+
print(f"加载配置文件 {config_path}")
|
131 |
+
configs = yaml.load(fin, Loader=yaml.FullLoader)
|
132 |
+
configs['dataset_conf']['filter_conf']['filter_no_extra_info'] = False
|
133 |
+
test_conf = copy.deepcopy(configs['dataset_conf'])
|
134 |
+
|
135 |
+
# test_conf['filter_conf']['max_length'] = 3000 # whisper最长处理30s 102400
|
136 |
+
test_conf['filter_conf']['min_length'] = 10
|
137 |
+
test_conf['filter_conf']['token_max_length'] = 102400
|
138 |
+
test_conf['filter_conf']['token_min_length'] = 1
|
139 |
+
test_conf['filter_conf']['max_output_input_ratio'] = 102400
|
140 |
+
test_conf['filter_conf']['min_output_input_ratio'] = 0
|
141 |
+
test_conf['filter_conf']['filter_no_extra_info'] = False
|
142 |
+
test_conf['filter_conf']['max_seq_len'] = 102400
|
143 |
+
test_conf['speed_perturb'] = False
|
144 |
+
test_conf['spec_aug'] = False
|
145 |
+
test_conf['spec_sub'] = False
|
146 |
+
test_conf['spec_trim'] = False
|
147 |
+
test_conf['shuffle'] = False
|
148 |
+
test_conf['sort'] = False
|
149 |
+
test_conf['cycle'] = 1
|
150 |
+
test_conf['list_shuffle'] = True
|
151 |
+
if 'fbank_conf' in test_conf:
|
152 |
+
test_conf['fbank_conf']['dither'] = 0.0
|
153 |
+
elif 'mfcc_conf' in test_conf:
|
154 |
+
test_conf['mfcc_conf']['dither'] = 0.0
|
155 |
+
test_conf['batch_conf']['batch_type'] = "static"
|
156 |
+
test_conf['batch_conf']['batch_size'] = 1
|
157 |
+
test_conf['split_num'] = 1
|
158 |
+
test_conf['multi_num'] = 1
|
159 |
+
test_conf['other_filter_conf'] = {}
|
160 |
+
test_conf['data_recover'] = False
|
161 |
+
return configs, test_conf
|
162 |
+
|
163 |
+
|
conf/ct_config.yaml
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model: osum_echat
|
2 |
+
|
3 |
+
# llm_path
|
4 |
+
llm_path: &llm_path "Qwen/Qwen2.5-3B-Instruct"
|
5 |
+
|
6 |
+
#
|
7 |
+
# model config
|
8 |
+
downsample_rate: 4 # 1 2 4 8
|
9 |
+
adapter_type: osum_echat
|
10 |
+
if_instruct: true
|
11 |
+
input_dim: 80
|
12 |
+
|
13 |
+
# tokenizer ,gxl
|
14 |
+
tokenizer: huggingface
|
15 |
+
tokenizer_conf:
|
16 |
+
llm_path: *llm_path
|
17 |
+
|
18 |
+
# lora config
|
19 |
+
use_lora: false
|
20 |
+
lora_alpha: 32
|
21 |
+
lora_rank: 64 # 3B -> 85M
|
22 |
+
lora_dropout: 0.1
|
23 |
+
|
24 |
+
# speech generate config
|
25 |
+
speech_token_num: &token_num 4097 #4097
|
26 |
+
|
27 |
+
|
28 |
+
# Configuration of parameters for training
|
29 |
+
fire_module: link_and_encoder_and_lora # link encoder llm link_and_encoder link_and_encoder_and_lora, llm需要配合use_lora为true
|
30 |
+
|
31 |
+
# other config
|
32 |
+
grad_clip: 5
|
33 |
+
accum_grad: 8
|
34 |
+
log_interval: 10
|
35 |
+
save_interval: 1250 #1250 #2500
|
36 |
+
max_epoch: 1
|
37 |
+
init_step: true
|
38 |
+
|
39 |
+
# training config
|
40 |
+
optim: adamw
|
41 |
+
optim_conf:
|
42 |
+
betas:
|
43 |
+
- 0.9
|
44 |
+
- 0.99
|
45 |
+
eps: 1.0e-06
|
46 |
+
lr: 1.0e-06
|
47 |
+
weight_decay: 0.01
|
48 |
+
scheduler: warmuplr
|
49 |
+
scheduler_conf:
|
50 |
+
warmup_steps: 2000
|
51 |
+
|
52 |
+
|
53 |
+
dataset: asr
|
54 |
+
dataset_conf:
|
55 |
+
speech_token_num: *token_num
|
56 |
+
batch_conf:
|
57 |
+
batch_size: 26
|
58 |
+
batch_type: dynamic
|
59 |
+
max_frames_in_batch: 28000000 #3000 #9000 #3000 #3300 # 3900
|
60 |
+
max_seq_in_batch: 3700 #1500 #4000 #1100 #1600 # 1900
|
61 |
+
feats_type: log_mel_spectrogram
|
62 |
+
filter_conf:
|
63 |
+
max_length: 20000
|
64 |
+
min_length: 20
|
65 |
+
token_max_length: 1200
|
66 |
+
token_min_length: 1
|
67 |
+
filter_no_extra_info: true # 如果没有task lang 等信息,直接过滤掉, 适用于通用多任务训练, 推理时应该关掉
|
68 |
+
max_seq_len: 2000 #、1100 #1000
|
69 |
+
other_filter_conf:
|
70 |
+
only_s2s: false # 只针对与s2s dataloader的过滤
|
71 |
+
only_s2t: false # 只针对与s2t dataloader的过滤
|
72 |
+
only_t2t: false # 只针对与t2t dataloader的过滤
|
73 |
+
only_t2s: false # 只针对与t2s dataloader的过滤
|
74 |
+
language_conf:
|
75 |
+
limited_langs:
|
76 |
+
- zh
|
77 |
+
log_mel_spectrogram_conf:
|
78 |
+
hop_length: 160
|
79 |
+
n_fft: 400
|
80 |
+
num_mel_bins: 80
|
81 |
+
padding: 0
|
82 |
+
resample_conf:
|
83 |
+
resample_rate: 16000
|
84 |
+
shuffle: true
|
85 |
+
shuffle_conf:
|
86 |
+
shuffle_size: 1500
|
87 |
+
sort: true
|
88 |
+
sort_conf:
|
89 |
+
sort_size: 500
|
90 |
+
spec_aug: true
|
91 |
+
spec_aug_conf:
|
92 |
+
max_f: 10
|
93 |
+
max_t: 50
|
94 |
+
num_f_mask: 2
|
95 |
+
num_t_mask: 2
|
96 |
+
spec_sub: true
|
97 |
+
spec_sub_conf:
|
98 |
+
max_t: 30
|
99 |
+
num_t_sub: 3
|
100 |
+
spec_trim: false
|
101 |
+
speed_perturb: false
|
102 |
+
eod_id: 151645
|
103 |
+
split_num: 1
|
104 |
+
multi_num: 2
|
105 |
+
prompt_conf_path: conf/prompt_config.yaml
|
106 |
+
data_recover: false
|
107 |
+
data_recover_conf:
|
108 |
+
start_idx: 0 # 删除前面start_idx个item(tar包)
|
109 |
+
other_tokenze_conf: # 一些对数据额外操作的可控按钮,这些操作一般来说再test时都得为false
|
110 |
+
only_info:
|
111 |
+
only_s2s: false # 只针对与s2s dataloader的过滤
|
112 |
+
only_s2t: false # 只针对与s2t dataloader的过滤
|
113 |
+
only_t2t: false # 只针对与t2t dataloader的过滤
|
114 |
+
only_t2s: false # 只针对与t2s dataloader的过滤
|
115 |
+
use_50_per_change_if_only_X: true # 50%的句子随机替换为其only X
|
116 |
+
use_s2s_streaming_random:
|
117 |
+
enable: false
|
118 |
+
rate: 0.5 # 1.0 表示100%的句子随机替换为其only X
|
119 |
+
natural_language_convert:
|
120 |
+
enable: false
|
121 |
+
rate: 0.00 # 1.0 表示100%的转换成自然语言模式
|
122 |
+
use_s2s_convert_s2t:
|
123 |
+
enable: false # 单独为s2t dataloader 开启s2s convert
|
124 |
+
rate: 1.0 # 1.0 表示100%的句子随机替换为其only X
|
125 |
+
use_streaming_tts:
|
126 |
+
enable: false
|
127 |
+
rate: 0.5 # 1.0 表示100%的句子随机替换为其only X
|
128 |
+
use_think_mode:
|
129 |
+
enable: false # 开启think 模式, 即随机替换为think模式的句子
|
130 |
+
rate: 0.8
|
131 |
+
other_filter_conf:
|
132 |
+
fiter_txt_is_None: true # 过滤掉text is "<NONE>"的语音数据,适配由于gender数据部分含有<NONE>标签而设计。但仅train起作用
|
133 |
+
|
134 |
+
# model config for encoder
|
135 |
+
encoder: transformer
|
136 |
+
encoder_conf:
|
137 |
+
activation_type: gelu
|
138 |
+
attention_dropout_rate: 0.0
|
139 |
+
attention_heads: 16
|
140 |
+
dropout_rate: 0.1
|
141 |
+
gradient_checkpointing: true
|
142 |
+
input_layer: conv1d2
|
143 |
+
key_bias: false
|
144 |
+
linear_units: 4096
|
145 |
+
normalize_before: true
|
146 |
+
num_blocks: 24
|
147 |
+
output_size: 1024
|
148 |
+
pos_enc_layer_type: abs_pos_whisper
|
149 |
+
positional_dropout_rate: 0.1
|
150 |
+
static_chunk_size: -1
|
151 |
+
use_dynamic_chunk: false
|
152 |
+
use_dynamic_left_chunk: false
|
153 |
+
|
conf/ct_config_sft.yaml
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model: llmasr
|
2 |
+
|
3 |
+
# llm_path
|
4 |
+
llm_path: &llm_path "/home/A02_tmpdata3/ckpt/Qwen2.5-3B-Instruct"
|
5 |
+
#
|
6 |
+
# model config
|
7 |
+
downsample_rate: 4 # 1 2 4 8
|
8 |
+
adapter_type: osum_echat
|
9 |
+
if_instruct: true
|
10 |
+
input_dim: 80
|
11 |
+
|
12 |
+
# tokenizer ,gxl
|
13 |
+
tokenizer: huggingface
|
14 |
+
tokenizer_conf:
|
15 |
+
llm_path: *llm_path
|
16 |
+
|
17 |
+
# lora config
|
18 |
+
use_lora: false
|
19 |
+
lora_alpha: 32
|
20 |
+
lora_rank: 64 # 3B -> 85M
|
21 |
+
lora_dropout: 0.1
|
22 |
+
|
23 |
+
# speech generate config
|
24 |
+
speech_token_num: &token_num 4097 #4097
|
25 |
+
|
26 |
+
|
27 |
+
# Configuration of parameters for training
|
28 |
+
fire_module: link_and_encoder_and_lora # link encoder llm link_and_encoder link_and_encoder_and_lora, llm需要配合use_lora为true
|
29 |
+
|
30 |
+
# other config
|
31 |
+
grad_clip: 5
|
32 |
+
accum_grad: 8
|
33 |
+
log_interval: 10
|
34 |
+
save_interval: 125 #1250 #2500
|
35 |
+
max_epoch: 1
|
36 |
+
init_step: true
|
37 |
+
|
38 |
+
# training config
|
39 |
+
optim: adamw
|
40 |
+
optim_conf:
|
41 |
+
betas:
|
42 |
+
- 0.9
|
43 |
+
- 0.99
|
44 |
+
eps: 1.0e-06
|
45 |
+
lr: 1.0e-06
|
46 |
+
weight_decay: 0.01
|
47 |
+
scheduler: warmuplr
|
48 |
+
scheduler_conf:
|
49 |
+
warmup_steps: 400
|
50 |
+
|
51 |
+
|
52 |
+
dataset: asr
|
53 |
+
dataset_conf:
|
54 |
+
speech_token_num: *token_num
|
55 |
+
batch_conf:
|
56 |
+
batch_size: 26
|
57 |
+
batch_type: dynamic
|
58 |
+
max_frames_in_batch: 28000000 #3000 #9000 #3000 #3300 # 3900
|
59 |
+
max_seq_in_batch: 3700 #1500 #4000 #1100 #1600 # 1900
|
60 |
+
feats_type: log_mel_spectrogram
|
61 |
+
filter_conf:
|
62 |
+
max_length: 20000
|
63 |
+
min_length: 20
|
64 |
+
token_max_length: 1200
|
65 |
+
token_min_length: 1
|
66 |
+
filter_no_extra_info: true # 如果没有task lang 等信息,直接过滤掉, 适用于通用多任务训练, 推理时应该关掉
|
67 |
+
max_seq_len: 2000 #、1100 #1000
|
68 |
+
other_filter_conf:
|
69 |
+
only_s2s: false # 只针对与s2s dataloader的过滤
|
70 |
+
only_s2t: false # 只针对与s2t dataloader的过滤
|
71 |
+
only_t2t: false # 只针对与t2t dataloader的过滤
|
72 |
+
only_t2s: false # 只针对与t2s dataloader的过滤
|
73 |
+
language_conf:
|
74 |
+
limited_langs:
|
75 |
+
- zh
|
76 |
+
log_mel_spectrogram_conf:
|
77 |
+
hop_length: 160
|
78 |
+
n_fft: 400
|
79 |
+
num_mel_bins: 80
|
80 |
+
padding: 0
|
81 |
+
resample_conf:
|
82 |
+
resample_rate: 16000
|
83 |
+
shuffle: true
|
84 |
+
shuffle_conf:
|
85 |
+
shuffle_size: 1500
|
86 |
+
sort: true
|
87 |
+
sort_conf:
|
88 |
+
sort_size: 500
|
89 |
+
spec_aug: true
|
90 |
+
spec_aug_conf:
|
91 |
+
max_f: 10
|
92 |
+
max_t: 50
|
93 |
+
num_f_mask: 2
|
94 |
+
num_t_mask: 2
|
95 |
+
spec_sub: true
|
96 |
+
spec_sub_conf:
|
97 |
+
max_t: 30
|
98 |
+
num_t_sub: 3
|
99 |
+
spec_trim: false
|
100 |
+
speed_perturb: false
|
101 |
+
eod_id: 151645
|
102 |
+
split_num: 1
|
103 |
+
multi_num: 2
|
104 |
+
prompt_conf_path: conf/prompt_config.yaml
|
105 |
+
data_recover: false
|
106 |
+
data_recover_conf:
|
107 |
+
start_idx: 0 # 删除前面start_idx个item(tar包)
|
108 |
+
other_tokenze_conf: # 一些对数据额外操作的可控按钮,这些操作一般来说再test时都得为false
|
109 |
+
only_info:
|
110 |
+
only_s2s: false # 只针对与s2s dataloader的过滤
|
111 |
+
only_s2t: false # 只针对与s2t dataloader的过滤
|
112 |
+
only_t2t: false # 只针对与t2t dataloader的过滤
|
113 |
+
only_t2s: false # 只针对与t2s dataloader的过滤
|
114 |
+
use_50_per_change_if_only_X: true # 50%的句子随机替换为其only X
|
115 |
+
use_s2s_streaming_random:
|
116 |
+
enable: false
|
117 |
+
rate: 0.5 # 1.0 表示100%的句子随机替换为其only X
|
118 |
+
natural_language_convert:
|
119 |
+
enable: false
|
120 |
+
rate: 0.00 # 1.0 表示100%的转换成自然语言模式
|
121 |
+
use_s2s_convert_s2t:
|
122 |
+
enable: false # 单独为s2t dataloader 开启s2s convert
|
123 |
+
rate: 1.0 # 1.0 表示100%的句子随机替换为其only X
|
124 |
+
use_streaming_tts:
|
125 |
+
enable: false
|
126 |
+
rate: 0.5 # 1.0 表示100%的句子随机替换为其only X
|
127 |
+
use_think_mode:
|
128 |
+
enable: false # 开启think 模式, 即随机替换为think模式的句子
|
129 |
+
rate: 0.8
|
130 |
+
other_filter_conf:
|
131 |
+
fiter_txt_is_None: true # 过滤掉text is "<NONE>"的语音数据,适配由于gender数据部分含有<NONE>标签而设计。但仅train起作用
|
132 |
+
|
133 |
+
# model config for encoder
|
134 |
+
encoder: transformer
|
135 |
+
encoder_conf:
|
136 |
+
activation_type: gelu
|
137 |
+
attention_dropout_rate: 0.0
|
138 |
+
attention_heads: 16
|
139 |
+
dropout_rate: 0.1
|
140 |
+
gradient_checkpointing: true
|
141 |
+
input_layer: conv1d2
|
142 |
+
key_bias: false
|
143 |
+
linear_units: 4096
|
144 |
+
normalize_before: true
|
145 |
+
num_blocks: 24
|
146 |
+
output_size: 1024
|
147 |
+
pos_enc_layer_type: abs_pos_whisper
|
148 |
+
positional_dropout_rate: 0.1
|
149 |
+
static_chunk_size: -1
|
150 |
+
use_dynamic_chunk: false
|
151 |
+
use_dynamic_left_chunk: false
|
152 |
+
|
conf/data_s2s.yaml
ADDED
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ===========================副语言 s2s thinking ===================================
|
2 |
+
|
3 |
+
# age gender,
|
4 |
+
age_gender_common:
|
5 |
+
path: /home/A02_tmpdata3/osum_s2s/gender/xlgeng_new_data/s2s_thinking/doubao/combines_list.txt
|
6 |
+
tar_num: 1511
|
7 |
+
weight: 2
|
8 |
+
|
9 |
+
gender_xianshi:
|
10 |
+
path: /home/A02_tmpdata3/osum_s2s/sex_xianshi_cosyvoice2_by_cywang_added_by_20250625/raw_data/s2s_handle/xlgeng_new_data/s2s_thinking/doubao/combines_list.txt
|
11 |
+
tar_num: 30
|
12 |
+
weight: 2
|
13 |
+
|
14 |
+
|
15 |
+
gender_yinshi_3k:
|
16 |
+
path: /home/A02_tmpdata3/osum_s2s/S2SChat_sex_yinshi_7_4_osum_by_cywang_added_by_20250708/raw_data/s2s_handle/xlgeng_new_data/s2s_thinking/doubao/combines_list.txt
|
17 |
+
tar_num: 3
|
18 |
+
weight: 2
|
19 |
+
|
20 |
+
gender_yinshi_5k:
|
21 |
+
path: /home/A02_tmpdata3/osum_s2s/S2SChat_sex_yinshi_5000_6_13_data_by_gjli_added_by_20250622/raw_data/s2s_handle/xlgeng_new_data/s2s_thinking/doubao/combines_list.txt
|
22 |
+
tar_num: 6
|
23 |
+
weight: 2
|
24 |
+
|
25 |
+
age_xianshi:
|
26 |
+
path: /home/A02_tmpdata3/osum_s2s/age_xianshi_cosyvoice2_by_cywang_added_by_20250625/raw_data/s2s_handle/xlgeng_new_data/s2s_thinking/doubao/combines_list.txt
|
27 |
+
tar_num: 25
|
28 |
+
weight: 2
|
29 |
+
|
30 |
+
|
31 |
+
# caption
|
32 |
+
caption_common_7label:
|
33 |
+
path: /home/A02_tmpdata3/osum_s2s/caption/raw_data/xlgeng_new_data/s2s_thinking/doubao/combines_list.txt
|
34 |
+
tar_num: 162
|
35 |
+
weight: 2
|
36 |
+
|
37 |
+
caption_common_50_label:
|
38 |
+
path: /home/A02_tmpdata3/osum_s2s/caption_add_2025_1_6/raw_data/s2s_data_with_gender/xlgeng_new_data/s2s_thinking/doubao/combines_list.txt
|
39 |
+
tar_num: 395 # 实际是196k
|
40 |
+
weight: 2
|
41 |
+
|
42 |
+
caption_xianshi:
|
43 |
+
path: /home/A02_tmpdata3/osum_s2s/caption_s2s_xianshi_20250806/raw_data/s2s_data/xlgeng_new_data/s2s_thinking/doubao/combines_list.txt
|
44 |
+
tar_num: 6
|
45 |
+
weight: 10
|
46 |
+
|
47 |
+
|
48 |
+
|
49 |
+
# emotion
|
50 |
+
emotion_100K_sensevoice:
|
51 |
+
path: /home/A02_tmpdata3/osum_s2s/emotion_yinshi_zxzhao_with_q_emo_by_cywang_added_by_20250701/handle_data/s2s_handle/xlgeng_new_data/s2s_thinking/doubao/combines_list.txt
|
52 |
+
tar_num: 107
|
53 |
+
weight: 10
|
54 |
+
|
55 |
+
emotion_30K_sensevoice:
|
56 |
+
path: /home/A02_tmpdata3/emotion/中英混多音色情感数据库/s2s_handle/xlgeng_new_data/s2s_thinking/doubao/combines_list.txt
|
57 |
+
tar_num: 33
|
58 |
+
weight: 10
|
59 |
+
|
60 |
+
S2SChat_osum_setting_qa_527_updated_by_cywang_added_by_20250616_think:
|
61 |
+
path: /home/A02_tmpdata3/osum_s2s/S2SChat_osum_setting_qa_527_updated_by_cywang_added_by_20250616/raw_data/s2s_handle/xlgeng_new_data/s2s_thinking/doubao/combines_list.txt
|
62 |
+
shard_num: 8
|
63 |
+
weight: 10
|
64 |
+
|
65 |
+
# ======================================s2s 副语言 thinking end=====================================
|
66 |
+
|
67 |
+
|
68 |
+
|
69 |
+
|
70 |
+
# ===========================副语言 s2s no thinking ===================================
|
71 |
+
|
72 |
+
# age gender,
|
73 |
+
age_gender_common_no_thinking:
|
74 |
+
path: /home/A02_tmpdata3/osum_s2s/gender/xlgeng_new_data/s2s_no_thinking/doubao/combines_list.txt
|
75 |
+
tar_num: 1511
|
76 |
+
weight: 2
|
77 |
+
|
78 |
+
gender_xianshi_no_thinking:
|
79 |
+
path: /home/A02_tmpdata3/osum_s2s/sex_xianshi_cosyvoice2_by_cywang_added_by_20250625/raw_data/s2s_handle/xlgeng_new_data/s2s_no_thinking/doubao/combines_list.txt
|
80 |
+
tar_num: 30
|
81 |
+
weight: 2
|
82 |
+
|
83 |
+
gender_yinshi_3k_no_thinking:
|
84 |
+
path: /home/A02_tmpdata3/osum_s2s/S2SChat_sex_yinshi_7_4_osum_by_cywang_added_by_20250708/raw_data/s2s_handle/xlgeng_new_data/s2s_no_thinking/doubao/combines_list.txt
|
85 |
+
tar_num: 3
|
86 |
+
weight: 2
|
87 |
+
|
88 |
+
gender_yinshi_5k_no_thinking:
|
89 |
+
path: /home/A02_tmpdata3/osum_s2s/S2SChat_sex_yinshi_5000_6_13_data_by_gjli_added_by_20250622/raw_data/s2s_handle/xlgeng_new_data/s2s_no_thinking/doubao/combines_list.txt
|
90 |
+
tar_num: 6
|
91 |
+
weight: 2
|
92 |
+
|
93 |
+
age_xianshi_no_thinking:
|
94 |
+
path: /home/A02_tmpdata3/osum_s2s/age_xianshi_cosyvoice2_by_cywang_added_by_20250625/raw_data/s2s_handle/xlgeng_new_data/s2s_no_thinking/doubao/combines_list.txt
|
95 |
+
tar_num: 25
|
96 |
+
weight: 2
|
97 |
+
|
98 |
+
|
99 |
+
# caption
|
100 |
+
caption_common_7label_no_thinking:
|
101 |
+
path: /home/A02_tmpdata3/osum_s2s/caption/raw_data/xlgeng_new_data/s2s_no_thinking/doubao/combines_list.txt
|
102 |
+
tar_num: 162
|
103 |
+
weight: 2
|
104 |
+
|
105 |
+
caption_common_50_label_no_thinking:
|
106 |
+
path: /home/A02_tmpdata3/osum_s2s/caption_add_2025_1_6/raw_data/s2s_data_with_gender/xlgeng_new_data/s2s_no_thinking/doubao/combines_list.txt
|
107 |
+
tar_num: 395 # 实际是196k
|
108 |
+
weight: 2
|
109 |
+
|
110 |
+
caption_xianshi_no_thinking:
|
111 |
+
path: /home/A02_tmpdata3/osum_s2s/caption_s2s_xianshi_20250806/raw_data/s2s_data/xlgeng_new_data/s2s_no_thinking/doubao/combines_list.txt
|
112 |
+
tar_num: 6
|
113 |
+
weight: 2
|
114 |
+
|
115 |
+
|
116 |
+
# emotion
|
117 |
+
emotion_100K_sensevoice_no_thinking:
|
118 |
+
path: /home/A02_tmpdata3/osum_s2s/emotion_yinshi_zxzhao_with_q_emo_by_cywang_added_by_20250701/handle_data/s2s_handle/xlgeng_new_data/s2s_no_thinking/doubao/combines_list.txt
|
119 |
+
tar_num: 107
|
120 |
+
weight: 10
|
121 |
+
|
122 |
+
emotion_30K_sensevoice_no_thinking:
|
123 |
+
path: /home/A02_tmpdata3/emotion/中英混多音色情感数据库/s2s_handle/xlgeng_new_data/s2s_no_thinking/doubao/combines_list.txt
|
124 |
+
tar_num: 33
|
125 |
+
weight: 10
|
126 |
+
|
127 |
+
S2SChat_osum_setting_qa_527_updated_by_cywang_added_by_20250616:
|
128 |
+
path: /home/A02_tmpdata3/osum_s2s/S2SChat_osum_setting_qa_527_updated_by_cywang_added_by_20250616/raw_data/s2s_handle/xlgeng_new_data/s2s_no_thinking/doubao/combines_list.txt
|
129 |
+
shard_num: 8
|
130 |
+
weight: 10
|
131 |
+
|
132 |
+
# -------------------------------------------s2s 副语言 no thinking end-------------------------------------------
|
133 |
+
|
134 |
+
S2SChat_syndata_merged_by_300W_zhguo_added_by_20250616:
|
135 |
+
path: /home/A02_tmpdata3/osum_s2s/S2SChat_syndata_merged_by_300W_zhguo_added_by_20250616/combines_data_s2s/combines_list.txt
|
136 |
+
tar_num: 3000
|
137 |
+
weight: 1
|
138 |
+
S2SChat_osum_total_data_lst_check_final_100W_by_zhguo_added_by_20250616:
|
139 |
+
path: /home/A02_tmpdata3/osum_s2s/S2SChat_osum_total_data_lst_check_final_100W_by_zhguo_added_by_20250616/combines_data_s2s/combines_list.txt
|
140 |
+
tar_num: 1000
|
141 |
+
weight: 1
|
142 |
+
|
143 |
+
gaozhiliang_gbma:
|
144 |
+
path: /home/A02_tmpdata3/osum_s2s/gaozhiliang_gbma/shards_list.txt
|
145 |
+
new_data_list: /home/node44_tmpdata3/netease/gbma/workspace/osum/data/process/0803/all_data_info.jsonl
|
146 |
+
new_lab_path: /home/work_nfs23/asr_data/data/osum_chat/s2s/gaozhiliang_gbma/shards_list.txt
|
147 |
+
shard_num: 24
|
148 |
+
weight: 1
|
149 |
+
|
150 |
+
# ======================================s2s no thinking end=====================================
|
151 |
+
|
152 |
+
# emotion explicit;
|
153 |
+
S2SChat_0628_E1_shard_by_kxxia_added_by_20250630:
|
154 |
+
huawei_path: /mnt/sfs/asr/update_data/S2SChat_0628_E1_shard_by_kxxia_added_by_20250630/shards_list.txt
|
155 |
+
new_data_list: /home/work_nfs16/cywang/workspace/OSUM/E1/0628_E1_shard.jsonl
|
156 |
+
new_lab_path: /home/work_nfs11/cywang/data/shard/S2Chat/0628_E1_shard/shards_list.txt
|
157 |
+
path: /home/A02_tmpdata3/osum_s2s/S2SChat_0628_E1_shard_by_kxxia_added_by_20250630/shards_list.txt
|
158 |
+
shard_num: 147
|
159 |
+
description: "E1 shard, 情感显示数据"
|
160 |
+
weight: 1
|
161 |
+
S2SChat_eng_e1_by_cywang_added_by_20250711:
|
162 |
+
huawei_path: /mnt/sfs/asr/update_data/S2SChat_eng_e1_by_cywang_added_by_20250711/shards_list.txt
|
163 |
+
new_data_list: /home/work_nfs16/kxxia/work/common/eng_e1.jsonl1752154262.3374825
|
164 |
+
new_lab_path: /home/work_nfs11/cywang/data/shard/S2Chat/eng_e1/shards_list.txt
|
165 |
+
path: /home/A02_tmpdata3/osum_s2s/S2SChat_eng_e1_by_cywang_added_by_20250711/shards_list.txt
|
166 |
+
shard_num: 50
|
167 |
+
weight: 2
|
168 |
+
|
169 |
+
# 下面一共才200多个
|
170 |
+
S2SChat_0630_trans_en2zh_by_cywang_added_by_20250704:
|
171 |
+
huawei_path: /mnt/sfs/asr/update_data/S2SChat_0630_trans_en2zh_by_cywang_added_by_20250704/shards_list.txt
|
172 |
+
new_data_list: /home/work_nfs16/cywang/workspace/OSUM/trans_emotion/0630_trans_en2zh.jsonl
|
173 |
+
new_lab_path: /home/work_nfs11/cywang/data/shard/S2Chat/0630_trans_en2zh/shards_list.txt
|
174 |
+
path: /home/A02_tmpdata3/osum_s2s/S2SChat_0630_trans_en2zh_by_cywang_added_by_20250704/shards_list.txt
|
175 |
+
shard_num: 128
|
176 |
+
weight: 0.5
|
177 |
+
S2SChat_0630_trans_zh2en_by_cywang_added_by_20250704:
|
178 |
+
huawei_path: /mnt/sfs/asr/update_data/S2SChat_0630_trans_zh2en_by_cywang_added_by_20250704/shards_list.txt
|
179 |
+
new_data_list: /home/work_nfs16/cywang/workspace/OSUM/trans_emotion/0630_trans_zh2en.jsonl
|
180 |
+
new_lab_path: /home/work_nfs11/cywang/data/shard/S2Chat/0630_trans_zh2en/shards_list.txt
|
181 |
+
path: /home/A02_tmpdata3/osum_s2s/S2SChat_0630_trans_zh2en_by_cywang_added_by_20250704/shards_list.txt
|
182 |
+
shard_num: 128
|
183 |
+
weight: 0.5
|
184 |
+
S2SChat_pachong_part1_filter_author_data_by_gjli_added_by_20250622:
|
185 |
+
shard_num: 28
|
186 |
+
huawei_path: /mnt/sfs/asr/update_data/S2SChat_pachong_part1_filter_author_data_by_gjli_added_by_20250622/shards_list.txt
|
187 |
+
new_data_list: /home/work_nfs16/gjli/workspaces/poem/6-16_shigepachong/pachong_part1_filter_content_data.list
|
188 |
+
new_lab_path: /home/work_nfs11/cywang/data/shard/S2Chat/pachong_part1_filter_author_data/shards_list.txt
|
189 |
+
path: /home/A02_tmpdata3/osum_s2s/S2SChat_pachong_part1_filter_author_data_by_gjli_added_by_20250622/shards_list.txt
|
190 |
+
weight: 1
|
191 |
+
S2SChat_pachong_part1_filter_content_data_by_gjli_added_by_20250622:
|
192 |
+
huawei_path: /mnt/sfs/asr/update_data/S2SChat_pachong_part1_filter_content_data_by_gjli_added_by_20250622/shards_list.txt
|
193 |
+
new_data_list: /home/work_nfs16/gjli/workspaces/poem/6-16_shigepachong/pachong_part1_filter_author_data.list
|
194 |
+
new_lab_path: /home/work_nfs11/cywang/data/shard/S2Chat/pachong_part1_filter_content_data/shards_list.txt
|
195 |
+
path: /home/A02_tmpdata3/osum_s2s/S2SChat_pachong_part1_filter_content_data_by_gjli_added_by_20250622/shards_list.txt
|
196 |
+
shard_num: 68
|
197 |
+
weight: 1
|
198 |
+
S2SChat_poem_1_2_6_3_author_data_150num_by_gjli_added_by_20250622:
|
199 |
+
huawei_path: /mnt/sfs/asr/update_data/S2SChat_poem_1_2_6_3_author_data_150num_by_gjli_added_by_20250622/shards_list.txt
|
200 |
+
new_data_list: /home/work_nfs16/gjli/workspaces/poem/6.3/poem_1_2_6-3_author_data.list
|
201 |
+
new_lab_path: /home/work_nfs11/cywang/data/shard/S2Chat/poem_1_2_6-3_author_data/shards_list.txt
|
202 |
+
path: /home/A02_tmpdata3/osum_s2s/S2SChat_poem_1_2_6_3_author_data_150num_by_gjli_added_by_20250622/shards_list.txt
|
203 |
+
shard_num: 2
|
204 |
+
weight: 1
|
205 |
+
S2SChat_poem_1_2_6_3_content_data_150num_by_gjli_added_by_20250622:
|
206 |
+
huawei_path: /mnt/sfs/asr/update_data/S2SChat_poem_1_2_6_3_content_data_150num_by_gjli_added_by_20250622/shards_list.txt
|
207 |
+
new_data_list: /home/work_nfs16/gjli/workspaces/poem/6.3/poem_1_2_6-3_content_data.list
|
208 |
+
new_lab_path: /home/work_nfs11/cywang/data/shard/S2Chat/poem_1_2_6-3_content_data/shards_list.txt
|
209 |
+
path: /home/A02_tmpdata3/osum_s2s/S2SChat_poem_1_2_6_3_content_data_150num_by_gjli_added_by_20250622/shards_list.txt
|
210 |
+
shard_num: 9
|
211 |
+
weight: 1
|
212 |
+
S2SChat_poem_500_author_data_new_by_gjli_added_by_20250622:
|
213 |
+
huawei_path: /mnt/sfs/asr/update_data/S2SChat_poem_500_author_data_new_by_gjli_added_by_20250622/shards_list.txt
|
214 |
+
new_data_list: /home/work_nfs16/gjli/workspaces/poem/poem_500_author_data_new.list
|
215 |
+
new_lab_path: /home/work_nfs11/cywang/data/shard/S2Chat/poem_500_author_data_new/shards_list.txt
|
216 |
+
path: /home/A02_tmpdata3/osum_s2s/S2SChat_poem_500_author_data_new_by_gjli_added_by_20250622/shards_list.txt
|
217 |
+
shard_num: 4
|
218 |
+
weight: 1
|
219 |
+
S2SChat_poem_500_content_data_new_by_gjli_added_by_20250622:
|
220 |
+
huawei_path: /mnt/sfs/asr/update_data/S2SChat_poem_500_content_data_new_by_gjli_added_by_20250622/shards_list.txt
|
221 |
+
new_data_list: /home/work_nfs16/gjli/workspaces/poem/poem_500_content_data_new.list
|
222 |
+
new_lab_path: /home/work_nfs11/cywang/data/shard/S2Chat/poem_500_content_data_new/shards_list.txt
|
223 |
+
path: /home/A02_tmpdata3/osum_s2s/S2SChat_poem_500_content_data_new_by_gjli_added_by_20250622/shards_list.txt
|
224 |
+
shard_num: 4
|
225 |
+
weight: 1
|
226 |
+
|
conf/data_s2t.yaml
ADDED
@@ -0,0 +1,402 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# age gender,
|
3 |
+
age_gender_common:
|
4 |
+
path: /home/A02_tmpdata3/osum_s2s/gender/xlgeng_new_data/s2t_thinking/doubao/combines_list.txt
|
5 |
+
tar_num: 1511
|
6 |
+
|
7 |
+
gender_xianshi:
|
8 |
+
path: /home/A02_tmpdata3/osum_s2s/sex_xianshi_cosyvoice2_by_cywang_added_by_20250625/raw_data/s2s_handle/xlgeng_new_data/s2t_thinking/doubao/combines_list.txt
|
9 |
+
tar_num: 30
|
10 |
+
|
11 |
+
gender_yinshi_3k:
|
12 |
+
path: /home/A02_tmpdata3/osum_s2s/S2SChat_sex_yinshi_7_4_osum_by_cywang_added_by_20250708/raw_data/s2s_handle/xlgeng_new_data/s2t_thinking/doubao/combines_list.txt
|
13 |
+
tar_num: 3
|
14 |
+
|
15 |
+
gender_yinshi_5k:
|
16 |
+
path: /home/A02_tmpdata3/osum_s2s/S2SChat_sex_yinshi_5000_6_13_data_by_gjli_added_by_20250622/raw_data/s2s_handle/xlgeng_new_data/s2t_thinking/doubao/combines_list.txt
|
17 |
+
tar_num: 6
|
18 |
+
|
19 |
+
age_xianshi:
|
20 |
+
path: /home/A02_tmpdata3/osum_s2s/age_xianshi_cosyvoice2_by_cywang_added_by_20250625/raw_data/s2s_handle/xlgeng_new_data/s2t_thinking/doubao/combines_list.txt
|
21 |
+
tar_num: 25
|
22 |
+
|
23 |
+
|
24 |
+
# caption
|
25 |
+
caption_common_7label:
|
26 |
+
path: /home/A02_tmpdata3/osum_s2s/caption/raw_data/xlgeng_new_data/s2t_thinking/doubao/combines_list.txt
|
27 |
+
tar_num: 162
|
28 |
+
|
29 |
+
caption_common_50_label:
|
30 |
+
path: /home/A02_tmpdata3/osum_s2s/caption_add_2025_1_6/raw_data/s2s_data_with_gender/xlgeng_new_data/s2t_thinking/doubao/combines_list.txt
|
31 |
+
tar_num: 395 # 实际是196k
|
32 |
+
|
33 |
+
caption_xianshi:
|
34 |
+
path: /home/A02_tmpdata3/osum_s2s/caption_s2s_xianshi_20250806/raw_data/s2s_data/xlgeng_new_data/s2t_thinking/doubao/combines_list.txt
|
35 |
+
tar_num: 6
|
36 |
+
weight: 10
|
37 |
+
|
38 |
+
|
39 |
+
# emotion
|
40 |
+
emotion_100K_sensevoice:
|
41 |
+
path: /home/A02_tmpdata3/osum_s2s/emotion_yinshi_zxzhao_with_q_emo_by_cywang_added_by_20250701/handle_data/s2s_handle/xlgeng_new_data/s2t_thinking/doubao/combines_list.txt
|
42 |
+
tar_num: 107
|
43 |
+
weight: 10
|
44 |
+
|
45 |
+
emotion_30K_sensevoice:
|
46 |
+
path: /home/A02_tmpdata3/emotion/中英混多音色情感数据库/s2s_handle/xlgeng_new_data/s2t_thinking/doubao/combines_list.txt
|
47 |
+
tar_num: 33
|
48 |
+
weight: 10
|
49 |
+
|
50 |
+
S2SChat_osum_setting_qa_527_updated_by_cywang_added_by_20250616_think:
|
51 |
+
path: /home/A02_tmpdata3/osum_s2s/S2SChat_osum_setting_qa_527_updated_by_cywang_added_by_20250616/raw_data/s2s_handle/xlgeng_new_data/s2t_thinking/doubao/combines_list.txt
|
52 |
+
shard_num: 8
|
53 |
+
weight: 10
|
54 |
+
|
55 |
+
# ======================================s2s 副语言 thinking end=====================================
|
56 |
+
|
57 |
+
|
58 |
+
|
59 |
+
|
60 |
+
# ===========================副语言 s2s no thinking ===================================
|
61 |
+
|
62 |
+
# age gender,
|
63 |
+
age_gender_common_no_thinking:
|
64 |
+
path: /home/A02_tmpdata3/osum_s2s/gender/xlgeng_new_data/s2t_no_thinking/doubao/combines_list.txt
|
65 |
+
tar_num: 1511
|
66 |
+
|
67 |
+
gender_xianshi_no_thinking:
|
68 |
+
path: /home/A02_tmpdata3/osum_s2s/sex_xianshi_cosyvoice2_by_cywang_added_by_20250625/raw_data/s2s_handle/xlgeng_new_data/s2t_no_thinking/doubao/combines_list.txt
|
69 |
+
tar_num: 30
|
70 |
+
|
71 |
+
gender_yinshi_3k_no_thinking:
|
72 |
+
path: /home/A02_tmpdata3/osum_s2s/S2SChat_sex_yinshi_7_4_osum_by_cywang_added_by_20250708/raw_data/s2s_handle/xlgeng_new_data/s2t_no_thinking/doubao/combines_list.txt
|
73 |
+
tar_num: 3
|
74 |
+
|
75 |
+
gender_yinshi_5k_no_thinking:
|
76 |
+
path: /home/A02_tmpdata3/osum_s2s/S2SChat_sex_yinshi_5000_6_13_data_by_gjli_added_by_20250622/raw_data/s2s_handle/xlgeng_new_data/s2t_no_thinking/doubao/combines_list.txt
|
77 |
+
tar_num: 6
|
78 |
+
|
79 |
+
age_xianshi_no_thinking:
|
80 |
+
path: /home/A02_tmpdata3/osum_s2s/age_xianshi_cosyvoice2_by_cywang_added_by_20250625/raw_data/s2s_handle/xlgeng_new_data/s2t_no_thinking/doubao/combines_list.txt
|
81 |
+
tar_num: 25
|
82 |
+
|
83 |
+
|
84 |
+
# caption
|
85 |
+
caption_common_7label_no_thinking:
|
86 |
+
path: /home/A02_tmpdata3/osum_s2s/caption/raw_data/xlgeng_new_data/s2t_no_thinking/doubao/combines_list.txt
|
87 |
+
tar_num: 162
|
88 |
+
|
89 |
+
caption_common_50_label_no_thinking:
|
90 |
+
path: /home/A02_tmpdata3/osum_s2s/caption_add_2025_1_6/raw_data/s2s_data_with_gender/xlgeng_new_data/s2t_no_thinking/doubao/combines_list.txt
|
91 |
+
tar_num: 395 # 实际是196k
|
92 |
+
|
93 |
+
caption_xianshi_no_thinking:
|
94 |
+
path: /home/A02_tmpdata3/osum_s2s/caption_s2s_xianshi_20250806/raw_data/s2s_data/xlgeng_new_data/s2t_no_thinking/doubao/combines_list.txt
|
95 |
+
tar_num: 6
|
96 |
+
|
97 |
+
|
98 |
+
# emotion
|
99 |
+
emotion_100K_sensevoice_no_thinking:
|
100 |
+
path: /home/A02_tmpdata3/osum_s2s/emotion_yinshi_zxzhao_with_q_emo_by_cywang_added_by_20250701/handle_data/s2s_handle/xlgeng_new_data/s2t_no_thinking/doubao/combines_list.txt
|
101 |
+
tar_num: 107
|
102 |
+
weight: 10
|
103 |
+
|
104 |
+
emotion_30K_sensevoice_no_thinking:
|
105 |
+
path: /home/A02_tmpdata3/emotion/中英混多音色情感数据库/s2s_handle/xlgeng_new_data/s2t_no_thinking/doubao/combines_list.txt
|
106 |
+
tar_num: 33
|
107 |
+
weight: 10
|
108 |
+
|
109 |
+
S2SChat_osum_setting_qa_527_updated_by_cywang_added_by_20250616:
|
110 |
+
path: /home/A02_tmpdata3/osum_s2s/S2SChat_osum_setting_qa_527_updated_by_cywang_added_by_20250616/raw_data/s2s_handle/xlgeng_new_data/s2t_no_thinking/doubao/combines_list.txt
|
111 |
+
shard_num: 8
|
112 |
+
weight: 10
|
113 |
+
|
114 |
+
# -------------------------------------------s2s 副语言 no thinking end-------------------------------------------
|
115 |
+
|
116 |
+
|
117 |
+
|
118 |
+
# 知识问答
|
119 |
+
S2SChat_syndata_merged_by_300W_zhguo_added_by_20250616:
|
120 |
+
path: /home/A02_tmpdata3/osum_s2s/S2SChat_syndata_merged_by_300W_zhguo_added_by_20250616/combines_data_s2t/combines_list.txt
|
121 |
+
tar_num: 3000
|
122 |
+
S2SChat_osum_total_data_lst_check_final_100W_by_zhguo_added_by_20250616:
|
123 |
+
path: /home/A02_tmpdata3/osum_s2s/S2SChat_osum_total_data_lst_check_final_100W_by_zhguo_added_by_20250616/combines_data_s2t/combines_list.txt
|
124 |
+
tar_num: 1000
|
125 |
+
# ======================================s2t 副语言 no thinking end==========================
|
126 |
+
|
127 |
+
|
128 |
+
|
129 |
+
# 语音理解==========================================
|
130 |
+
asr:
|
131 |
+
huawei_path: "/mnt/sfs/asr/asr/shards_list.txt" # 2.4
|
132 |
+
lab_path: "/home/node54_tmpdata/xlgeng/asr_data_2w/shards_list.txt"
|
133 |
+
path: "/home/A03_tmpdata1/s2s/asr_data_2.4w/asr_data_2w/shards_list.txt"
|
134 |
+
shard_num: 15477
|
135 |
+
weight: 0.1 # ~10000h
|
136 |
+
|
137 |
+
|
138 |
+
# ===========理解任务 ==============================================
|
139 |
+
librispeech:
|
140 |
+
huawei_path: "/mnt/sfs/asr/update_data/LibriSpeech_shard_common/shards_list.txt" #1000h
|
141 |
+
lab_path: "/home/work_nfs15/asr_data/data/LibriSpeech/LibriSpeech_shard_common/shards_list.txt"
|
142 |
+
path: "/home/A03_tmpdata3/asr_data/librispeech/shards_list.txt"
|
143 |
+
shard_num: 282
|
144 |
+
weight: 1
|
145 |
+
mix_asru200_add_2025_2_14:
|
146 |
+
huawei_path: "/mnt/sfs/asr/update_data/mix_asru200_add_2025_2_14/shards_list.txt" # 200
|
147 |
+
path: "/home/A03_tmpdata1/s2s/asru700/train/shards_list.txt"
|
148 |
+
lab_path: "/home/work_nfs15/asr_data/data/ASRU700/train/shards_list.txt" # 中英混单词之间是有空格的
|
149 |
+
shard_num: 187
|
150 |
+
weight: 1
|
151 |
+
|
152 |
+
|
153 |
+
|
154 |
+
|
155 |
+
|
156 |
+
caption:
|
157 |
+
path: "/home/A02_tmpdata3/osum_s2s/caption/shards_list.txt"
|
158 |
+
huawei_path: "/mnt/sfs/asr/update_data/caption/shards_list.txt" # 319h
|
159 |
+
lab_path: "/home/node54_tmpdata2/data4understand/update_data/caption/shards_list.txt"# 是cap audio set+aishell2的拼接
|
160 |
+
shard_num: 319
|
161 |
+
weight: 0.5
|
162 |
+
caption_add_2025_1_6:
|
163 |
+
path: "/home/A02_tmpdata3/osum_s2s/caption_add_2025_1_6/shards_list.txt"
|
164 |
+
lab_path: "/home/work_nfs7/yacao/0106_twj_shard/shards_0306/add_label/shards_list.txt"
|
165 |
+
huawei_path: "/mnt/sfs/asr/update_data/caption_2025_1_6_newadd/shards_list.txt" # 130h
|
166 |
+
shard_num: 392
|
167 |
+
weight: 0.5
|
168 |
+
caption_aslp_add_2025_1_15:
|
169 |
+
path: "/home/A02_tmpdata3/osum_s2s/caption_aslp_add_2025_1_15/shards_list.txt"
|
170 |
+
huawei_path: "/mnt/sfs/asr/update_data/caption_aslp_add_2025_1_15/shards_list.txt" # 5h
|
171 |
+
shard_num: 5
|
172 |
+
lab_path: "/home/work_nfs9/yacao/nfs7_copy/yacao/shard/0114_wjtian_simu2/aslp_caption_train/shards_list.txt"
|
173 |
+
weight: 5
|
174 |
+
|
175 |
+
|
176 |
+
|
177 |
+
# 50类别的caption
|
178 |
+
s2t_caption_50label:
|
179 |
+
shard_num: 392
|
180 |
+
path: "/home/A02_tmpdata3/osum_s2s/s2t_caption_50label/shards_list.txt"
|
181 |
+
lab_path: "/home/work_nfs7/yacao/0106_twj_shard/shards_0306/add_label/shards_list.txt"
|
182 |
+
huawei_path: "/mnt/sfs/asr/update_data/0106_twj_shard_caption_50label_add_by_2025_3_10/shards_list.txt" # 392tar
|
183 |
+
weight: 0.5 # 10
|
184 |
+
|
185 |
+
|
186 |
+
|
187 |
+
|
188 |
+
emotion: # 不全, 312tar
|
189 |
+
path: "/home/A02_tmpdata3/osum_s2s/emotion/shards_list.txt"
|
190 |
+
lab_path: "/home/xlgeng/sdb2/emotion/shards_list.txt"
|
191 |
+
huawei_path: "/mnt/sfs/asr/emotion/shards_list.txt"
|
192 |
+
shard_num: 370
|
193 |
+
weight: 0.5 # 538h
|
194 |
+
emotion_stage2_add:
|
195 |
+
path: "/home/A02_tmpdata3/osum_s2s/emotion_stage2_add/shards_list.txt"
|
196 |
+
lab_path: "/home/xlgeng/sdb2/emotion_stage2_add/shards_list.txt"
|
197 |
+
huawei_path: "/mnt/sfs/asr/emotion_stage2_add/shards_list.txt"
|
198 |
+
shard_num: 44
|
199 |
+
weight: 0.1 # 150h
|
200 |
+
emotion_stage3_add:
|
201 |
+
path: "/home/A02_tmpdata3/osum_s2s/emotion_stage3_add/shards_list.txt"
|
202 |
+
lab_path: "/home/xlgeng/sdb2/emotion_stage3_add/shards_list.txt"
|
203 |
+
huawei_path: "/mnt/sfs/asr/emotion_stage3_add/shards_list.txt"
|
204 |
+
shard_num: 53
|
205 |
+
weight: 0.1 # 138h
|
206 |
+
emotion_stage4_add:
|
207 |
+
path: "/home/A02_tmpdata3/osum_s2s/emotion_stage4_add/shards_list.txt"
|
208 |
+
lab_path: "/home/xlgeng/sdb2/emotion_stage4_add/shards_list.txt"
|
209 |
+
huawei_path: "/mnt/sfs/asr/emotion_stage4_add/shards_list.txt"
|
210 |
+
shard_num: 54
|
211 |
+
weight: 0.1 #100h
|
212 |
+
emotion_stage5_add:
|
213 |
+
path: "/home/A02_tmpdata3/osum_s2s/emotion_stage5_add/shards_list.txt"
|
214 |
+
lab_path: "/home/xlgeng/sdb2/emotion_stage5_add/shards_list.txt"
|
215 |
+
shard_num: 53
|
216 |
+
huawei_path: "/mnt/sfs/asr/emotion_stage5_add/shards_list.txt"
|
217 |
+
weight: 0.1
|
218 |
+
|
219 |
+
emotion_meld:
|
220 |
+
path: "/home/A02_tmpdata3/osum_s2s/emotion_meld/shards_list.txt"
|
221 |
+
lab_path: "/home/xlgeng/sdb2/emotion_meld/shards_list.txt"
|
222 |
+
huawei_path: "/mnt/sfs/asr/update_data/emotion_meld/shards_list.txt" # 8h
|
223 |
+
shard_num: 9
|
224 |
+
weight: 1
|
225 |
+
#emotion_dis_fear_add_2025_1_15:
|
226 |
+
# huawei_path: "/mnt/sfs/asr/update_data/emotion_dis_fear_add_2025_1_15/shards_list.txt"
|
227 |
+
# weight: 0
|
228 |
+
|
229 |
+
emotion_lucy_Q_added_2025_4_9:
|
230 |
+
path: "/home/A02_tmpdata3/osum_s2s/s2s_lucy_Q_emotion/shards_list.txt"
|
231 |
+
shard_num: 121
|
232 |
+
lab_path: "/home/work_nfs11/cywang/data/shard/emotion/QEmo_Q_train/shards_list.txt"
|
233 |
+
huawei_path: "/mnt/sfs/asr/update_data/emotion_lucy_Q_added_2025_4_9/shards_list.txt"
|
234 |
+
weight: 0.5
|
235 |
+
|
236 |
+
|
237 |
+
|
238 |
+
Age_with_noize_add_2025_2_4: # 不全,才245个
|
239 |
+
path: "/home/A02_tmpdata3/osum_s2s/age_3000_noize/shards_list.txt"
|
240 |
+
lab_path: "/home/work_nfs6/syliu/for_gxl/Age/simu_age/shards_list.txt"
|
241 |
+
shard_num: 2720
|
242 |
+
huawei_path: "/mnt/sfs/asr/update_data/Age_with_noize_add_2025_2_4/shards_list.txt"
|
243 |
+
weight: 0.1
|
244 |
+
age:
|
245 |
+
path: "/home/A02_tmpdata3/osum_s2s/age_3000/shards_list.txt"
|
246 |
+
lab_path: "/home/work_nfs3/syliu/for_gxl/Age/age/shards_list.txt"
|
247 |
+
huawei_path: "/mnt/sfs/asr/update_data/age/shards_list.txt"
|
248 |
+
shard_num: 2820
|
249 |
+
weight: 0.1 #1.5 # 3000h
|
250 |
+
|
251 |
+
|
252 |
+
gender: # 不全,目前310个
|
253 |
+
shard_num: 1738
|
254 |
+
lab_path: "/home/xlgeng/sdb2/gender/shards_list.txt"
|
255 |
+
huawei_path: "/mnt/sfs/asr/update_data/sex/shards_list.txt" # 3000
|
256 |
+
path: "/home/A02_tmpdata3/osum_s2s/gender/shards_list.txt"
|
257 |
+
weight: 0.1 #1.5
|
258 |
+
gender_add_2025_1_6_kaggle: # 全了
|
259 |
+
shard_num: 116
|
260 |
+
path: "/home/A02_tmpdata3/osum_s2s/gender_kaggle/shards_list.txt"
|
261 |
+
lab_path: "/home/work_nfs3/syliu/for_gxl/new_gender/Sex/sex/shards_list.txt"
|
262 |
+
huawei_path: "/mnt/sfs/asr/update_data/sex_2025_1_6_newadd/shards_list.txt" # 107h, kaggle
|
263 |
+
weight: 0.1 #3
|
264 |
+
gender_add_2025_2_4_fix: # 2100tar # 不全,365个
|
265 |
+
path: "/home/A02_tmpdata3/osum_s2s/gender_add_2025_2_4_fix/shards_list.txt"
|
266 |
+
shard_num: 2140
|
267 |
+
lab_path: "/home/work_nfs6/xlgeng/for_gxl/gender_add_2025_2_4_fix/shards_list.txt"
|
268 |
+
huawei_path: "/mnt/sfs/asr/update_data/gender_add_2025_2_4_fix/shards_list.txt"
|
269 |
+
weight: 0.1
|
270 |
+
gender_with_noize_add_2025_2_4: # 1500h ,780tar # 不全,266个
|
271 |
+
path: "/home/A02_tmpdata3/osum_s2s/gender_with_noize_add_2025_2_4/shards_list.txt"
|
272 |
+
lab_path: "/home/work_nfs6/xlgeng/for_gxl/gender_with_noize_add_2025_2_4/shards_list.txt"
|
273 |
+
huawei_path: "/mnt/sfs/asr/update_data/gender_with_noize_add_2025_2_4/shards_list.txt"
|
274 |
+
shard_num: 780
|
275 |
+
weight: 0.1
|
276 |
+
|
277 |
+
|
278 |
+
|
279 |
+
|
280 |
+
age_gender_stage2_add:
|
281 |
+
path: "/home/A02_tmpdata3/osum_s2s/age_gender_stage2_add/shards_list.txt"
|
282 |
+
lab_path: "/home/xlgeng/sdb2/age_gender_stage2_add/shards_list.txt"
|
283 |
+
huawei_path: "/mnt/sfs/asr/update_data/Speech_Age_Sex/shards_list.txt"
|
284 |
+
weight: 0.1 # 174h
|
285 |
+
|
286 |
+
age_gender_add_2025_1_13:
|
287 |
+
path: "/home/A02_tmpdata3/osum_s2s/age_gender_add_2025_1_13/shards_list.txt"
|
288 |
+
lab_path: "/home/work_nfs3/syliu/for_gxl/Age_Sex/age_sex/shards_list.txt"
|
289 |
+
huawei_path: "/mnt/sfs/asr/update_data/Speech_Age_Sex_add_2025_1_13/shards_list.txt"
|
290 |
+
weight: 0.1 #2571h
|
291 |
+
|
292 |
+
style_age_gender_stage3_add:
|
293 |
+
path: "/home/A02_tmpdata3/osum_s2s/style_age_gender_stage3_add/shards_list.txt"
|
294 |
+
lab_path: "/home/xlgeng/sdb2/style_age_gender_stage3_add/shards_list.txt"
|
295 |
+
huawei_path: "/mnt/sfs/asr/update_data/Speech_Style_Age_Sex/shards_list.txt"
|
296 |
+
weight: 0.1 # 85h
|
297 |
+
|
298 |
+
|
299 |
+
age_gender_pure_stage3_add:
|
300 |
+
path: "/home/A02_tmpdata3/osum_s2s/age_gender_pure_stage3_add/shards_list.txt"
|
301 |
+
lab_path: "/home/xlgeng/sdb2/age_gender_pure_stage3_add/shards_list.txt"
|
302 |
+
huawei_path: "/mnt/sfs/asr/update_data/Age_Sex/shards_list.txt"
|
303 |
+
weight: 0.1 # 174h
|
304 |
+
|
305 |
+
|
306 |
+
style_age_gender_pure_stage3_add:
|
307 |
+
path: "/home/A02_tmpdata3/osum_s2s/style_age_gender_pure_stage3_add/shards_list.txt"
|
308 |
+
lab_path: "/home/xlgeng/sdb2/style_age_gender_pure_stage3_add/shards_list.txt"
|
309 |
+
huawei_path: "/mnt/sfs/asr/update_data/Style_Age_Sex/shards_list.txt"
|
310 |
+
weight: 0.1 # 85h
|
311 |
+
|
312 |
+
|
313 |
+
|
314 |
+
|
315 |
+
# 多任务, caption
|
316 |
+
merged_output_caption_age_gender_add_2025_2_26:
|
317 |
+
path: "/home/A02_tmpdata3/osum_s2s/merged_output_caption_age_gender_add_2025_2_26/shards_list.txt"
|
318 |
+
lab_path: "/home/work_nfs7/yacao/0106_twj_shard/shards_0226/merged_output/shards_list.txt"
|
319 |
+
huawei_path: "/mnt/sfs/asr/update_data/multi_task/caption_new/merged_output/shards_list.txt"
|
320 |
+
weight: 0.1
|
321 |
+
nfs10_time1_output_caption_age_gender_add_2025_2_26:
|
322 |
+
path: "/home/A02_tmpdata3/osum_s2s/nfs10_time1_output_caption_age_gender_add_2025_2_26/shards_list.txt"
|
323 |
+
lab_path: "/home/work_nfs7/yacao/0106_twj_shard/shards_0226/nfs10_time1/shards_list.txt"
|
324 |
+
huawei_path: "/mnt/sfs/asr/update_data/multi_task/caption_new/nfs10_time1/shards_list.txt"
|
325 |
+
weight: 0.1
|
326 |
+
other_20000_caption_age_gender_add_2025_2_26:
|
327 |
+
path: "/home/A02_tmpdata3/osum_s2s/other_20000_caption_age_gender_add_2025_2_26/shards_list.txt"
|
328 |
+
lab_path: "/home/work_nfs7/yacao/0106_twj_shard/shards_0226/other_20000/shards_list.txt"
|
329 |
+
huawei_path: "/mnt/sfs/asr/update_data/multi_task/caption_new/other_20000/shards_list.txt"
|
330 |
+
weight: 0.1
|
331 |
+
simu9_1227_caption_age_gender_add_2025_2_26:
|
332 |
+
path: "/home/A02_tmpdata3/osum_s2s/simu9_1227_caption_age_gender_add_2025_2_26/shards_list.txt"
|
333 |
+
lab_path: "/home/work_nfs7/yacao/0106_twj_shard/shards_0226/simu9_1227/shards_list.txt"
|
334 |
+
huawei_path: "/mnt/sfs/asr/update_data/multi_task/caption_new/simu9_1227/shards_list.txt"
|
335 |
+
weight: 0.1
|
336 |
+
|
337 |
+
|
338 |
+
# 多任务, emotion
|
339 |
+
merged_output_emotion_age_gender_add_2025_3_2:
|
340 |
+
path: "/home/A02_tmpdata3/osum_s2s/merged_output_emotion_age_gender_add_2025_3_2/shards_list.txt"
|
341 |
+
lab_path: "/home/work_nfs16/emotion_data/OSUM_age_gender/emotion_age_gender1/shards_list.txt"
|
342 |
+
huawei_path: "/mnt/sfs/asr/update_data/multi_task/emotion_age_gender1/shards_list.txt"
|
343 |
+
weight: 0.1
|
344 |
+
|
345 |
+
merged_output_emotion_age_gender_add_2025_3_2_di2pi:
|
346 |
+
path: "/home/A02_tmpdata3/osum_s2s/merged_output_emotion_age_gender_add_2025_3_2_di2pi/shards_list.txt"
|
347 |
+
shard_num: 181
|
348 |
+
lab_path: "/home/work_nfs16/emotion_data/OSUM_age_gender/emotion_age_gender2/shards_list.txt"
|
349 |
+
huawei_path: ""
|
350 |
+
weight: 0.1
|
351 |
+
|
352 |
+
|
353 |
+
# 多任务, style
|
354 |
+
merged_output_style_age_gender_add_2025_3_2:
|
355 |
+
path: "/home/A02_tmpdata3/osum_s2s/merged_output_style_age_gender_add_2025_3_2/shards_list.txt"
|
356 |
+
lab_path: "/home/node54_tmpdata2/gjli/style_age_gender_data/style_labeling_100wto200w_part1_age_gender/shards_list.txt"
|
357 |
+
shard_num: 107
|
358 |
+
huawei_path: "/mnt/sfs/asr/update_data/multi_task/style_labeling_100wto200w_part1_age_gender/shards_list.txt"
|
359 |
+
weight: 0.1
|
360 |
+
merged_output_style_origin_tts_age_gender_add_2025_3_2:
|
361 |
+
path: "/home/A02_tmpdata3/osum_s2s/merged_output_style_origin_tts_age_gender_add_2025_3_2/shards_list.txt"
|
362 |
+
lab_path: "/home/node54_tmpdata2/gjli/style_age_gender_data/style_origin_tts_age_gender/shards_list.txt"
|
363 |
+
huawei_path: "/mnt/sfs/asr/update_data/multi_task/style_origin_tts_age_gender/shards_list.txt"
|
364 |
+
weight: 0.1
|
365 |
+
style_labeling_100wto200w_part1_age_gender_emotion_gjli:
|
366 |
+
path: "/home/A02_tmpdata3/osum_s2s/style_labeling_100wto200w_part1_age_gender_emotion_gjli/shards_list.txt"
|
367 |
+
lab_path: "/home/node54_tmpdata2/gjli/style_labeling_100wto200w_part1_age_gender_emotion/shards_list.txt" # 107
|
368 |
+
huawei_path: "/mnt/sfs/asr/update_data/style_labeling_100wto200w_part1_age_gender_emotion/shards_list.txt" #107tar
|
369 |
+
weight: 0.5
|
370 |
+
style_labeling_200wto300w_part1_age_gender_emotion_gjli:
|
371 |
+
path: "/home/A02_tmpdata3/osum_s2s/style_labeling_200wto300w_part1_age_gender_emotion_gjli/shards_list.txt"
|
372 |
+
lab_path: "/home/node54_tmpdata2/gjli/style_labeling_200wto300w_part2/shards_list.txt"
|
373 |
+
shard_num: 236
|
374 |
+
huawei_path: "_"
|
375 |
+
|
376 |
+
age_gender_style_emotion1_add_2025_3_29_zxzhao:
|
377 |
+
path: "/home/A02_tmpdata3/osum_s2s/age_gender_style_emotion1_add_2025_3_29_zxzhao/shards_list.txt"
|
378 |
+
lab_path: "/home/work_nfs16/emotion_data/OSUM_age_gender/age_gender_style_emotion1/shards_list.txt"
|
379 |
+
huawei_path: "/mnt/sfs/asr/update_data/age_gender_style_emotion1_add_2025_3_29_zxzhao/shards_list.txt" # 256tar
|
380 |
+
weight: 0.5
|
381 |
+
|
382 |
+
|
383 |
+
5_label_caption_age_gender_style_emotion_added_2025_3_29_yacao:
|
384 |
+
path: "/home/A02_tmpdata3/osum_s2s/5_label_caption_age_gender_style_emotion_added_2025_3_29_yacao/shards_list.txt"
|
385 |
+
huawei_path: "/mnt/sfs/asr/update_data/5_label_caption_age_gender_style_emotion_added_2025_3_29_yacao/shards_list.txt" #270tar
|
386 |
+
lab_path: "/home/work_nfs7/yacao/0320_multilabel_2/shard/5_label/shards_list.txt"
|
387 |
+
weight: 0.5
|
388 |
+
|
389 |
+
# audio description 数据
|
390 |
+
audio_caption_by_wjtian_added_by_20250414: # 其实是 20250411 ,写日期的时候由于自动补全写错了
|
391 |
+
path: "/home/A02_tmpdata3/osum_s2s/audio_caption_by_wjtian_added_by_20250414/shards_list.txt"
|
392 |
+
lab_path: "/home/work_nfs7/cywang/OSUM/OSUM_data/shard/audio_caption/audio_caption/shards_list.txt"
|
393 |
+
huawei_path: "/mnt/sfs/asr/update_data/audio_caption_by_wjtian_added_by_20250414/shards_list.txt" # 2155 tar
|
394 |
+
weight: 0.15 # 开始上传天文杰准备的audio_caption数据,音频描述数据
|
395 |
+
|
396 |
+
|
397 |
+
S2SChat_MMAU_training_all_by_wjtian_added_by_20250708:
|
398 |
+
path: "/home/A02_tmpdata3/osum_s2s/S2SChat_MMAU_training_all_by_wjtian_added_by_20250708/shards_list.txt"
|
399 |
+
lab_path: "/home/work_nfs11/cywang/data/shard/S2Chat/MMAU-training-all/shards_list.txt"
|
400 |
+
huawei_path: "/mnt/sfs/asr/update_data/S2SChat_MMAU_training_all_by_wjtian_added_by_20250708/shards_list.txt" # 1000 tar
|
401 |
+
shard_num: 22
|
402 |
+
weight: 5
|
conf/data_t2s.yaml
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
TEXT2TOKEN_hq_add_2025_3_17:
|
2 |
+
lab_path: "/home/node48_tmpdata/hkxie/4O/speech_data_final/10wh_token_data/TEXT2TOKEN_hq/shards_list.txt"
|
3 |
+
path_huawei: "/mnt/sfs/asr/update_data/TEXT2TOKEN_hq_add_2025_3_17/shards_list.txt"
|
4 |
+
path: /home/A03_tmpdata2/s2s/2000_hq_S2Chat_by_hkxie_added_by_20250411/combine_tts/combines_list.txt
|
5 |
+
weight: 5
|
6 |
+
|
7 |
+
#english_text_token_add_2025_3_26:
|
8 |
+
# path_huawei: "/mnt/sfs/asr/update_data/english_speech_data_final_TEXT2TOKEN_part_1_added_2025_3_26/shards_list.txt" # 2000
|
9 |
+
# data_list_path: "/home/work_nfs14/code/hkxie/ASR/understanding_LLM_task/english/speech_data_final/data_libriheavy_part_1.list"
|
10 |
+
# lab_path: "?"
|
11 |
+
# weight: 0.5 #1 #1 #10
|
12 |
+
#english_TEXT2TOKEN_part_2_added_by_20250402:
|
13 |
+
# path_huawei: "/mnt/sfs/asr/update_data/english_TEXT2TOKEN_part_2_added_by_20250402/shards_list.txt" #8050
|
14 |
+
# data_list_path: "/home/work_nfs14/code/hkxie/ASR/understanding_LLM_task/english/speech_data_final/data_libriheavy_part_2.list"
|
15 |
+
# lab_path: "?"
|
16 |
+
# weight: 0.5 #1 #1 #10
|
17 |
+
#
|
18 |
+
#zh_en_mix_tts_added_by_20250402: # tts
|
19 |
+
# path_huawei: "/mnt/sfs/asr/update_data/zh_en_mix_s2s_added_by_20250402/shards_list.txt" # 7
|
20 |
+
# weight: 1
|
21 |
+
#poly_tts_added_by_20250402: # tts
|
22 |
+
# path: "/mnt/sfs/asr/update_data/poly_s2s_added_by_20250402/shards_list.txt" # 295
|
23 |
+
# weight: 0.5 #10
|
24 |
+
#
|
25 |
+
#text2token_itn_by_cywang_added_by_20250428: # zyzhang 负责,cywang打包
|
26 |
+
# lab_path: "/home/work_nfs7/cywang/OSUM/OSUM_data/shard/text2token/tn/shards_list.txt"
|
27 |
+
# path_huawei: "/mnt/sfs/asr/update_data/text2token_itn_by_cywang_added_by_20250428/shards_list.txt" # 1100
|
28 |
+
# weight: 0.5
|
conf/data_t2t.yaml
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# 文本到文本
|
2 |
+
#text2text_added_2025_4_4:
|
3 |
+
# path_huawei: "/mnt/sfs/asr/update_data/text2text_added_by_20250404/shards_list.txt" # 1850
|
4 |
+
# weight: 1
|
5 |
+
#text2text_2_added_by_20250409:
|
6 |
+
# path_huawei: "/mnt/sfs/asr/update_data/text2text_2_added_by_20250409/shards_list.txt" # 2000
|
7 |
+
# weight: 1
|
8 |
+
#
|
9 |
+
#text2text_3_en_added_by_20250411:
|
10 |
+
# path_huawei: "/mnt/sfs/asr/update_data/text2text_3_en_added_by_20250411/shards_list.txt" # 185
|
11 |
+
# weight: 1
|
12 |
+
#
|
13 |
+
#text2text_4_en_added_by_20250416:
|
14 |
+
# path_huawei: "/mnt/sfs/asr/update_data/text2text_4_en_added_by_20250416/shards_list.txt"
|
15 |
+
# weight: 1
|
16 |
+
|
17 |
+
#text2text_5_lucy_audioQA_1M_by_cywang_added_by_20250426:
|
18 |
+
# shard_num: 10000
|
19 |
+
# path_huawei: "/mnt/sfs/asr/update_data/text2text_5_lucy_audioQA_1M_by_cywang_added_by_20250426/shards_list.txt"
|
20 |
+
# weight: 0.1
|
21 |
+
|
22 |
+
t2t_8772K_by_xlgeng_added_by_20250513:
|
23 |
+
path: "/home/A03_tmpdata1/text2text_data_xlgeng/t2t_8772K/shards_list.txt"
|
24 |
+
path_huawei: "/mnt/sfs/asr/update_data/t2t_8772K_by_xlgeng_added_by_20250513/shards_list.txt"
|
25 |
+
weight: 0.1
|
26 |
+
|
27 |
+
#t2t_math_poetry_self_by_xlgeng_added_by_20250513:
|
28 |
+
# path: "/home/A03_tmpdata1/text2text_data_xlgeng/t2t_math_poetry_self/shards_list.txt"
|
29 |
+
# path_huawei: "/mnt/sfs/asr/update_data/t2t_math_poetry_self_by_xlgeng_added_by_20250513/shards_list.txt" # 75
|
30 |
+
# weight: 1
|
31 |
+
|
32 |
+
Alpaca_CoT_3000W_by_wjt_added_by_20250605:
|
33 |
+
lab_path_huawei: ""
|
34 |
+
shard_num: 30000
|
35 |
+
path_huawei: "/mnt/sfs/asr/update_data/Alpaca_CoT_3000W_by_wjt_added_by_20250605/shards_list.txt"
|
36 |
+
path: "/home/A03_tmpdata1/text2text_data_xlgeng/Alpaca-CoT_3000W/shards_list.txt"
|
37 |
+
weight: 0.15
|
38 |
+
|
39 |
+
|
40 |
+
qwenomni_bench_data:
|
41 |
+
path: "/home/A02_tmpdata3/osum_t2t/qwenomni_bench_data/shards_list.txt"
|
42 |
+
weight: 3
|
43 |
+
|
44 |
+
three_kingdoms:
|
45 |
+
path: "/home/A02_tmpdata3/osum_t2t/three_kingdoms/shards_list.txt"
|
46 |
+
weight: 3
|
47 |
+
|
48 |
+
voicebench_data:
|
49 |
+
path: "/home/A02_tmpdata3/osum_t2t/voicebench_data/shards_list.txt"
|
50 |
+
weight: 3
|
51 |
+
|
52 |
+
t2t_osum_self_instruction_8K:
|
53 |
+
path: "/home/A02_tmpdata3/osum_t2t/t2t_osum_self_instruction_8K/shards_list.txt"
|
54 |
+
weight: 3
|
55 |
+
|
56 |
+
|
57 |
+
#t2t_osum_self_instruction_8K_by_xlgeng_added_by_20250529:
|
58 |
+
# path: "/home/A02_tmpdata3/t2t_osum_self_instruction_8K_by_xlgeng_added_by_20250529/shards_list.txt"
|
59 |
+
# path_huawei: "/mnt/sfs/asr/update_data/t2t_osum_self_instruction_8K_by_xlgeng_added_by_20250529/shards_list.txt"
|
60 |
+
# weight: 5
|
61 |
+
|
62 |
+
#kouyu_t2t_data_by_xlgeng_added_by_20250622:
|
63 |
+
# path: ""
|
64 |
+
# shard_num: 1758
|
65 |
+
# path_huawei: "/mnt/sfs/asr/update_data/kouyu_t2t_data_by_xlgeng_added_by_20250622s/shards_list.txt"
|
66 |
+
# weight: 1
|
67 |
+
# 4653
|
68 |
+
|
69 |
+
#text2text_data_xlgeng_three_kingdoms_by_xlgeng_added_by_20250701:
|
70 |
+
# path: "/home/A02_tmpdata3/text2text_data_xlgeng_three_kingdoms_by_xlgeng_added_by_20250701/shards_list.txt"
|
71 |
+
# path_huawei: "/mnt/sfs/asr/update_data/text2text_data_xlgeng_three_kingdoms_by_xlgeng_added_by_20250701/shards_list.txt"
|
72 |
+
# weight: 1
|
73 |
+
# lab_path_huawei: "/home/work_nfs11/asr_data/data/text2text_data_xlgeng/shard/benchdata/three_kingdoms/shard/shards_list.txt"
|
74 |
+
# shard_num: 24
|
75 |
+
#
|
76 |
+
#text2text_data_xlgeng_qwenomni_bench_data_by_xlgeng_added_by_20250701:
|
77 |
+
# path: "/home/A02_tmpdata3/text2text_data_xlgeng_qwenomni_bench_data_by_xlgeng_added_by_20250701/shards_list.txt"
|
78 |
+
# path_huawei: "/mnt/sfs/asr/update_data/text2text_data_xlgeng_qwenomni_bench_data_by_xlgeng_added_by_20250701/shards_list.txt"
|
79 |
+
# weight: 1
|
80 |
+
# lab_path_huawei: "/home/work_nfs11/asr_data/data/text2text_data_xlgeng/shard/benchdata/qwenomni_bench_data/shard/shards_list.txt"
|
81 |
+
# shard_num: 113
|
82 |
+
|
83 |
+
|
84 |
+
#text2text_data_xlgeng_voicebench_data_by_xlgeng_added_by_20250701:
|
85 |
+
# path: "/home/A02_tmpdata3/text2text_data_xlgeng_voicebench_data_by_xlgeng_added_by_20250701/shards_list.txt"
|
86 |
+
# path_huawei: "/mnt/sfs/asr/update_data/text2text_data_xlgeng_voicebench_data_by_xlgeng_added_by_20250701/shards_list.txt"
|
87 |
+
# weight: 1
|
88 |
+
# lab_path_huawei: "/home/work_nfs11/asr_data/data/text2text_data_xlgeng/shard/benchdata/voicebench_data/shard/shards_list.txt"
|
89 |
+
# shard_num: 65
|
90 |
+
#
|
91 |
+
#t2t_age_chat_by_cywang_added_by_20250708: # have
|
92 |
+
# path: "/home/A02_tmpdata3/osum_s2s/t2t_age_chat_by_cywang_added_by_20250708/shards_list.txt"
|
93 |
+
# path_huawei: "/mnt/sfs/asr/update_data/t2t_age_chat_by_cywang_added_by_20250708/shards_list.txt"
|
94 |
+
# lab_path_huawei: "/home/work_nfs11/asr_data/data/osum_data/t2t_paralanguage_chat/age_chat/shard_dir/shards_list.txt"
|
95 |
+
# shard_num: 50
|
96 |
+
# weight: 1
|
97 |
+
#
|
98 |
+
#t2t_caption_chat_by_cywang_added_by_20250708: # have
|
99 |
+
# path: "/home/A02_tmpdata3/osum_s2s/t2t_caption_chat_by_cywang_added_by_20250708/shards_list.txt"
|
100 |
+
# lab_path_huawei: "/home/work_nfs11/asr_data/data/osum_data/t2t_paralanguage_chat/caption_chat/shard_dir/shards_list.txt"
|
101 |
+
# path_huawei: "/mnt/sfs/asr/update_data/t2t_caption_chat_by_cywang_added_by_20250708/shards_list.txt"
|
102 |
+
# shard_num: 100
|
103 |
+
# weight: 1
|
104 |
+
#
|
105 |
+
#t2t_emotion_chat_by_cywang_added_by_20250708: # have
|
106 |
+
# path: "/home/A02_tmpdata3/osum_s2s/t2t_emotion_chat_by_cywang_added_by_20250708/shards_list.txt"
|
107 |
+
# lab_path_huawei: "/home/work_nfs11/asr_data/data/osum_data/t2t_paralanguage_chat/emotion_chat/shard_dir/shards_list.txt"
|
108 |
+
# path_huawei: "/mnt/sfs/asr/update_data/t2t_emotion_chat_by_cywang_added_by_20250708/shards_list.txt"
|
109 |
+
# shard_num: 50
|
110 |
+
# weight: 1
|
111 |
+
#
|
112 |
+
#t2t_sex_chat_by_cywang_added_by_20250708: # have
|
113 |
+
# path: "/home/A02_tmpdata3/osum_s2s/t2t_sex_chat_by_cywang_added_by_20250708/shards_list.txt"
|
114 |
+
# lab_path_huawei: "/home/work_nfs11/asr_data/data/osum_data/t2t_paralanguage_chat/sex_chat/shard_dir/shards_list.txt"
|
115 |
+
# path_huawei: "/mnt/sfs/asr/update_data/t2t_sex_chat_by_cywang_added_by_20250708/shards_list.txt"
|
116 |
+
# shard_num: 50
|
117 |
+
# weight: 1
|
118 |
+
#
|
119 |
+
#t2t_xianshi_emotion_chat_by_cywang_added_by_20250711: # no
|
120 |
+
# path: "/home/A02_tmpdata3/osum_t2t/t2t_xianshi_emotion_chat/shards_list.txt"
|
121 |
+
# lab_path_huawei: "/home/work_nfs11/asr_data/data/osum_data/t2t_paralanguage_chat/xianshi_emotion_chat/shard_dir/shards_list.txt|/home/work_nfs23/asr_data/data/osum_chat/t2t_data/t2t_paralanguage_chat/xianshi_emotion_chat/shards_list.txt"
|
122 |
+
# path_huawei: "/mnt/sfs/asr/update_data/t2t_xianshi_emotion_chat_by_cywang_added_by_20250711/shards_list.txt"
|
123 |
+
# shard_num: 50
|
124 |
+
# weight: 1
|
125 |
+
#
|
126 |
+
#t2t_sex_chat_2_by_cywang_added_by_20250711: # no
|
127 |
+
# path: "/home/A02_tmpdata3/osum_t2t/t2t_sex_chat_2_by_cywang_added_by_20250711/shards_list.txt"
|
128 |
+
# lab_path_huawei: "/home/work_nfs11/asr_data/data/osum_data/t2t_paralanguage_chat/sex_chat_2/shard_dir/shards_list.txt|/home/work_nfs23/asr_data/data/osum_chat/t2t_data/t2t_paralanguage_chat/t2t_sex_chat_2_by_cywang_added_by_20250711/shards_list.txt"
|
129 |
+
# path_huawei: "/mnt/sfs/asr/update_data/t2t_sex_chat_2_by_cywang_added_by_20250711/shards_list.txt"
|
130 |
+
# shard_num: 27
|
131 |
+
# weight: 1
|
132 |
+
#
|
133 |
+
#t2t_age_chat_2_by_cywang_added_by_20250711: # no
|
134 |
+
# path: "/home/A02_tmpdata3/osum_t2t/t2t_age_chat_2_by_cywang_added_by_20250711/shards_list.txt"
|
135 |
+
# path_huawei: "/mnt/sfs/asr/update_data/t2t_age_chat_2_by_cywang_added_by_20250711/shards_list.txt"
|
136 |
+
# lab_path_huawei: "/home/work_nfs11/asr_data/data/osum_data/t2t_paralanguage_chat/age_chat_2/shard_dir/shards_list.txt|/home/work_nfs23/asr_data/data/osum_chat/t2t_data/t2t_paralanguage_chat/t2t_age_chat_2_by_cywang_added_by_20250711/shards_list.txt"
|
137 |
+
# shard_num: 27
|
138 |
+
# weight: 1
|
139 |
+
#
|
140 |
+
#t2t_sex_chat_2_by_cywang_added_by_20250715: # no
|
141 |
+
# path: "/home/A02_tmpdata3/osum_t2t/t2t_sex_chat_2_by_cywang_added_by_20250715/shards_list.txt"
|
142 |
+
# path_huawei: "/mnt/sfs/asr/update_data/t2t_sex_chat_2_by_cywang_added_by_20250715/shards_list.txt"
|
143 |
+
# lab_path_huawei: "/home/work_nfs14/asr_data/data/osum_data/t2t_paralanguage_chat/sex_chat_2/shard_dir/shards_list.txt"
|
144 |
+
# shard_num: 10
|
145 |
+
# weight: 1
|
146 |
+
#
|
147 |
+
#t2t_age_chat_3_by_cywang_added_by_20250716: # no
|
148 |
+
# path: "/home/A02_tmpdata3/osum_t2t/t2t_age_chat_3_by_cywang_added_by_20250716/shards_list.txt"
|
149 |
+
# path_huawei: "/mnt/sfs/asr/update_data/t2t_age_chat_3_by_cywang_added_by_20250716/shards_list.txt"
|
150 |
+
# lab_path_huawei: "/home/work_nfs14/asr_data/data/osum_data/t2t_paralanguage_chat/age_chat_3/shard_dir/shards_list.txt"
|
151 |
+
# shard_num: 10
|
152 |
+
# weight: 1
|
153 |
+
#
|
154 |
+
#t2t_caption_chat_3_by_cywang_added_by_20250716: # have
|
155 |
+
# path: "/home/A02_tmpdata3/osum_t2t/t2t_caption_chat_3_by_cywang_added_by_20250716/shards_list.txt"
|
156 |
+
# path_huawei: "/mnt/sfs/asr/update_data/t2t_caption_chat_3_by_cywang_added_by_20250716/shards_list.txt"
|
157 |
+
# lab_path_huawei: "/home/work_nfs14/asr_data/data/osum_data/t2t_paralanguage_chat/caption_chat_3/shard_dir/shards_list.txt"
|
158 |
+
# shard_num: 10
|
159 |
+
# weight: 1
|
conf/data_tmp.yaml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gaozhiliang_gbma:
|
2 |
+
path: /home/A02_tmpdata3/osum_s2s/gaozhiliang_gbma/shards_list.txt
|
3 |
+
new_data_list: /home/node44_tmpdata3/netease/gbma/workspace/osum/data/process/0803/all_data_info.jsonl
|
4 |
+
new_lab_path: /home/work_nfs23/asr_data/data/osum_chat/s2s/gaozhiliang_gbma/shards_list.txt
|
5 |
+
shard_num: 24
|
6 |
+
weight: 10
|
conf/ds_stage2.json
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"train_micro_batch_size_per_gpu": 1,
|
3 |
+
"gradient_accumulation_steps": 8,
|
4 |
+
"steps_per_print": 10,
|
5 |
+
"gradient_clipping": 5,
|
6 |
+
"fp16": {
|
7 |
+
"enabled": false,
|
8 |
+
"auto_cast": true,
|
9 |
+
"loss_scale": 0,
|
10 |
+
"initial_scale_power": 16,
|
11 |
+
"loss_scale_window": 1000,
|
12 |
+
"hysteresis": 2,
|
13 |
+
"consecutive_hysteresis": false,
|
14 |
+
"min_loss_scale": 1
|
15 |
+
},
|
16 |
+
"bf16": {
|
17 |
+
"enabled": true
|
18 |
+
},
|
19 |
+
"zero_force_ds_cpu_optimizer": false,
|
20 |
+
"zero_optimization": {
|
21 |
+
"stage": 2,
|
22 |
+
"offload_optimizer": {
|
23 |
+
"device": "none",
|
24 |
+
"pin_memory": true
|
25 |
+
},
|
26 |
+
"allgather_partitions": true,
|
27 |
+
"allgather_bucket_size": 2e8,
|
28 |
+
"reduce_scatter": true,
|
29 |
+
"reduce_bucket_size": 2e8,
|
30 |
+
"contiguous_gradients": false,
|
31 |
+
"overlap_comm": false
|
32 |
+
},
|
33 |
+
"find_unused_parameters": true
|
34 |
+
}
|
conf/empty.yaml
ADDED
File without changes
|
conf/prompt_config.yaml
ADDED
The diff for this file is too large to render.
See raw diff
|
|
conf/system_prompt.yaml
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# qwen_instruct_prompt_pattern_chat_s2t = "<|im_start|>system\nYou are OSUM-chat, a speech-to-text dialogue. You understand both the meaning and paralinguistic cues in speech then respond exclusively with appropriate text.<|im_end|>\n"
|
2 |
+
# qwen_instruct_prompt_pattern__chat_t2t = "<|im_start|>system\n<|im_end|>\n"
|
3 |
+
# qwen_instruct_prompt_pattern_chat_s2s = "<|im_start|>system\nYou are OSUM-chat, a speech-to-speech dialogue assistant by ASLP Lab. You understand both the meaning and paralinguistic cues in speech then respond with appropriate text and emotionally matching synthetic speech.<|im_end|>\n"
|
4 |
+
# qwen_instruct_prompt_pattern_chat_s2s_think = "<|im_start|>system\nYou are OSUM-chat, a speech-to-speech dialogue assistant by ASLP Lab. You understand both the meaning and paralinguistic cues in speech. Before responding, first output your reasoning inside <think>...</think end>, analyzing the user’s words and vocal cues. Then generate a reply with appropriate text and emotionally matched synthetic speech.<|im_end|>\n"
|
5 |
+
# qwen_instruct_prompt_pattern_chat_s2s_streaming = "<|im_start|>system\nYou are OSUM-chat, a speech-to-speech dialogue assistant by ASLP Lab. You analyze speech (content + paralinguistic cues) and respond with interleaved text and emotionally-matched synthetic speech.<|im_end|>\n"
|
6 |
+
# qwen_instruct_prompt_pattern_chat_s2s_streaming_think = "<|im_start|>system\nYou are OSUM-chat, a speech-to-speech dialogue assistant by ASLP Lab. You analyze speech (both content and paralinguistic cues). Before responding, output your reasoning in <think>...</think end>. Then reply with interleaved text and emotionally matched synthetic speech.<|im_end|>\n"
|
7 |
+
# qwen_instruct_prompt_pattern__chat_t2t = "<|im_start|>system\n
|
8 |
+
|
9 |
+
# qwen_instruct_prompt_pattern_1_understand = "<|im_start|>system\nYou are OSUM-chat, an audio understanding assistant by ASLP Lab. You can transcribe speech accurately and analyze paralinguistic cues to provide precise text responses.<|im_end|>\n"
|
10 |
+
# qwen_instruct_prompt_pattern_1_tts = "<|im_start|>system\nYou are OSUM-chat, a speech synthesis assistant by ASLP Lab. You generate natural and fluent speech from text input.<|im_end|>\n"
|
11 |
+
# qwen_instruct_prompt_pattern_1_tts_streaming = "<|im_start|>system\nYou are OSUM-chat, a speech synthesis assistant by ASLP Lab. You generate natural speech from text input and output both audio and the original text in interleaved format.<|im_end|>\n"
|
12 |
+
# qwen_instruct_prompt_pattern_1_old = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n"
|
13 |
+
# # user_start = "<|im_start|>user\n"
|
14 |
+
t2t_chat: # <TEXT2TEXT>
|
15 |
+
prompt: You are OSUM-chat, a text-to-text dialogue assistant by ASLP Lab. You understand user input in text then respond exclusively with appropriate text.
|
16 |
+
|
17 |
+
s2t_chat: # <S2TCHAT>
|
18 |
+
prompt: You are OSUM-chat, a speech-to-text dialogue assistant by ASLP Lab. You understand both the meaning and paralinguistic cues in speech then respond exclusively with appropriate text.
|
19 |
+
|
20 |
+
s2t_chat_thinker: # <S2TCHAT> <THINKER>
|
21 |
+
prompt: You are OSUM-chat, a thinking-enabled speech-to-text dialogue assistant by ASLP Lab. You not only comprehend the semantic meaning and paralinguistic cues in speech but also engage in deliberate reasoning to process such information. Based on this thinking process, you then respond exclusively with appropriate text.
|
22 |
+
|
23 |
+
t2s: # <TEXT2TOKEN>
|
24 |
+
prompt: You are OSUM-chat, a speech synthesis assistant by ASLP Lab. You generate natural and fluent speech from text input.
|
25 |
+
|
26 |
+
speech_understanding: # <TRANSCRIBE> <CAPTION> 。。
|
27 |
+
prompt: You are OSUM-chat, an audio understanding assistant by ASLP Lab. You can transcribe speech accurately and analyze paralinguistic cues to provide precise text responses.
|
patches/cumstom_stop_criteria.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers.generation.logits_process import LogitsProcessor
|
3 |
+
from transformers.generation.stopping_criteria import StoppingCriteria
|
4 |
+
|
5 |
+
class ASRLogitsProcessor(LogitsProcessor):
|
6 |
+
def __init__(self, text_token_num: int):
|
7 |
+
self.text_token_num = text_token_num
|
8 |
+
|
9 |
+
def __call__(self, input_ids, scores):
|
10 |
+
scores[..., self.text_token_num:] = torch.finfo(scores.dtype).min
|
11 |
+
return scores
|
12 |
+
|
13 |
+
class TTSLogitsProcessor(LogitsProcessor):
|
14 |
+
"""
|
15 |
+
TTS 任务使用的LogitsProcessor,把所有text位置的logits设置为负无穷
|
16 |
+
"""
|
17 |
+
def __init__(self, text_token_num: int):
|
18 |
+
self.text_token_num = text_token_num
|
19 |
+
|
20 |
+
def __call__(self, input_ids, scores):
|
21 |
+
scores[..., :self.text_token_num] = torch.finfo(scores.dtype).min
|
22 |
+
return scores
|
23 |
+
|
24 |
+
class S2SLogitsProcessor(LogitsProcessor):
|
25 |
+
"""Speech 2 Speech 任务使用的 LogitsProcessor,当前只适用于batch_size=1
|
26 |
+
|
27 |
+
Args:
|
28 |
+
LogitsProcessor (_type_): _description_
|
29 |
+
"""
|
30 |
+
def __init__(self, text_token_num: int, text_eos_id: int):
|
31 |
+
self.text_token_num = text_token_num
|
32 |
+
self.text_eos_id = text_eos_id
|
33 |
+
self.text_phase = True
|
34 |
+
def __call__(self, input_ids, scores):
|
35 |
+
print(input_ids.shape)
|
36 |
+
assert input_ids.size(0) == 1, "ERROR: S2SSpeechLogitsProcessor only support bs=1 now"
|
37 |
+
if self.text_phase:
|
38 |
+
scores[..., self.text_token_num:] = torch.finfo(scores.dtype).min
|
39 |
+
else:
|
40 |
+
scores[..., :self.text_token_num] = torch.finfo(scores.dtype).min
|
41 |
+
|
42 |
+
if self.text_phase and torch.isin(input_ids, self.text_eos_id):
|
43 |
+
self.text_phase = False
|
44 |
+
|
45 |
+
return scores
|
46 |
+
|
47 |
+
class S2SStopCriteria(StoppingCriteria):
|
48 |
+
"""Speech 2 Speech 任务使用的 停止条件,当前只适用于batch_size=1
|
49 |
+
|
50 |
+
Args:
|
51 |
+
LogitsProcessor (_type_): _description_
|
52 |
+
"""
|
53 |
+
def __init__(self, text_eos_id: int, speech_eos_id: int):
|
54 |
+
self.text_eos_id = text_eos_id
|
55 |
+
self.speech_eos_id = speech_eos_id
|
56 |
+
|
57 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
|
58 |
+
_input_ids = input_ids.flatten().view(-1)
|
59 |
+
if torch.isin(_input_ids, self.text_eos_id).any():
|
60 |
+
text_eos_idx = (_input_ids == self.text_eos_id).nonzero(as_tuple=True)[0][0].item()
|
61 |
+
if torch.sum(_input_ids[text_eos_idx:] == self.speech_eos_id) > 1:
|
62 |
+
return True
|
63 |
+
return False
|
64 |
+
|
65 |
+
class MaxTokenStopper(StoppingCriteria):
|
66 |
+
def __init__(self, max_tokens):
|
67 |
+
self.max_tokens = max_tokens
|
68 |
+
|
69 |
+
# TODO@wsy:期望能够修改max_tokens,但好像没用,后续注意
|
70 |
+
def change_max_tokens(self, max_tokens):
|
71 |
+
self.max_tokens = max_tokens
|
72 |
+
|
73 |
+
def __call__(self, input_ids, scores, **kwargs):
|
74 |
+
return input_ids.shape[1] >= self.max_tokens # 检查当前序列长度
|
75 |
+
|
76 |
+
class InterruptStopper(StoppingCriteria):
|
77 |
+
def __init__(self):
|
78 |
+
self.stop = False
|
79 |
+
|
80 |
+
def __call__(self, input_ids, scores, **kwargs):
|
81 |
+
if self.stop == True:
|
82 |
+
# self.stop == False # reset
|
83 |
+
return True
|
84 |
+
else:
|
85 |
+
return False
|
patches/custom_speech_ngram_blocking.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers.generation.logits_process import LogitsProcessor
|
2 |
+
import torch
|
3 |
+
|
4 |
+
class SpeechOnlyNGramBlockingLogitsProcessor(LogitsProcessor):
|
5 |
+
def __init__(
|
6 |
+
self,
|
7 |
+
speech_token_num,
|
8 |
+
repeat_times=5,
|
9 |
+
special_token_repeat_times_dict=None,
|
10 |
+
window_size=8,
|
11 |
+
window_repeat=5,
|
12 |
+
special_token_window_dict=None
|
13 |
+
):
|
14 |
+
"""
|
15 |
+
speech_token_num: int, speech token 的数量(token_id in [0, speech_token_num) 视为 speech token)
|
16 |
+
repeat_times: int, 普通 speech token 的最大允许连续重复次数
|
17 |
+
special_token_repeat_times_dict: dict, {token_id: repeat_times},为特殊 speech token 单独指定最大连续重复次数
|
18 |
+
window_size: int, 默认滑动窗口大小
|
19 |
+
window_repeat: int, 默认窗口内最大允许出现次数
|
20 |
+
special_token_window_dict: dict, {token_id: (window_size, window_repeat)},为特殊 token 单独指定窗口参数
|
21 |
+
"""
|
22 |
+
self.speech_token_num = speech_token_num
|
23 |
+
self.repeat_times = repeat_times
|
24 |
+
self.special_token_repeat_times_dict = special_token_repeat_times_dict or {}
|
25 |
+
self.speech_phase = False # 你需要在外部控制这个变量
|
26 |
+
self.window_size = window_size
|
27 |
+
self.window_repeat = window_repeat
|
28 |
+
self.special_token_window_dict = special_token_window_dict or {1446: (13, 10)}
|
29 |
+
|
30 |
+
def set_phase(self, speech_phase: bool):
|
31 |
+
self.speech_phase = speech_phase
|
32 |
+
|
33 |
+
def __call__(self, input_ids, scores):
|
34 |
+
if not self.speech_phase:
|
35 |
+
# text 阶段,什么都不做
|
36 |
+
return scores
|
37 |
+
batch_size, seq_len = input_ids.size()
|
38 |
+
for batch_idx in range(batch_size):
|
39 |
+
generated = input_ids[batch_idx].tolist()
|
40 |
+
if seq_len == 0:
|
41 |
+
continue
|
42 |
+
last_token = generated[-1]
|
43 |
+
if last_token >= self.speech_token_num:
|
44 |
+
continue # 不是 speech token
|
45 |
+
|
46 |
+
# 统计最近的 token 连续重复了多少次
|
47 |
+
repeat_count = 1
|
48 |
+
for i in range(seq_len-2, -1, -1):
|
49 |
+
if generated[i] == last_token:
|
50 |
+
repeat_count += 1
|
51 |
+
else:
|
52 |
+
break
|
53 |
+
# 获取该 token 的最大允许重复次数
|
54 |
+
max_repeat = self.special_token_repeat_times_dict.get(last_token, self.repeat_times)
|
55 |
+
if repeat_count >= max_repeat:
|
56 |
+
scores[batch_idx, last_token] = -float('inf') # 阻止生成
|
57 |
+
|
58 |
+
# ====== 滑动窗口内频率抑制 ======
|
59 |
+
# 对窗口内所有 speech token 检查
|
60 |
+
window_tokens = set(generated[-max(self.window_size, max([v[0] for v in self.special_token_window_dict.values()], default=0)):])
|
61 |
+
for token in window_tokens:
|
62 |
+
if token >= self.speech_token_num:
|
63 |
+
continue
|
64 |
+
# 获取该 token 的窗口参数
|
65 |
+
window_size, window_repeat = self.special_token_window_dict.get(
|
66 |
+
token, (self.window_size, self.window_repeat)
|
67 |
+
)
|
68 |
+
window = generated[-window_size:]
|
69 |
+
if window.count(token) >= window_repeat:
|
70 |
+
scores[batch_idx, token] = -float('inf')
|
71 |
+
# ====== 滑动窗口内频率抑制结束 ======
|
72 |
+
return scores
|
73 |
+
|
74 |
+
|
75 |
+
|
76 |
+
|
77 |
+
class OSUM_chat_LogitsProcessor(LogitsProcessor):
|
78 |
+
def __init__(self, allowed_tokens, sequence_to_match):
|
79 |
+
"""
|
80 |
+
初始化OSUM_chat_LogitsProcessor。
|
81 |
+
|
82 |
+
参数:
|
83 |
+
allowed_tokens (list): 允许出现在当前时间步的token的ID列表
|
84 |
+
sequence_to_match (list): 用来判断当前时间步允许token的前置序列
|
85 |
+
"""
|
86 |
+
self.allowed_tokens = allowed_tokens
|
87 |
+
self.sequence_to_match = sequence_to_match
|
88 |
+
self.match_found = False # 添加一个标志,表示是否已经找到匹配的序列
|
89 |
+
|
90 |
+
def init_match_found(self):
|
91 |
+
"""
|
92 |
+
初始化match_found标志。
|
93 |
+
"""
|
94 |
+
self.match_found = False
|
95 |
+
|
96 |
+
def __call__(self, input_ids: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
|
97 |
+
"""
|
98 |
+
在每个时间步处理logits,对不符合条件的token设置极小的概率。
|
99 |
+
|
100 |
+
参数:
|
101 |
+
input_ids (torch.Tensor): 当前输入的token ID序列
|
102 |
+
logits (torch.Tensor): 当前时间步的logits (shape: [batch_size, vocab_size])
|
103 |
+
|
104 |
+
返回:
|
105 |
+
torch.Tensor: 被处理过的logits
|
106 |
+
"""
|
107 |
+
# 如果已经匹配过一次,就跳过匹配检测,直接返回logits
|
108 |
+
# print("recent_tokens:!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") # 打印当前生成的序列
|
109 |
+
if self.match_found:
|
110 |
+
return logits
|
111 |
+
|
112 |
+
# 获取当前生成的序列的最后几个token(假设生成的长度大于等于序列长度)
|
113 |
+
sequence_length = len(self.sequence_to_match)
|
114 |
+
if input_ids.shape[-1] >= sequence_length:
|
115 |
+
recent_tokens = input_ids[:, -sequence_length:].tolist()
|
116 |
+
# print("recent_tokens:", recent_tokens) # 打印当前生成的序列
|
117 |
+
|
118 |
+
# 检查前面生成的token是否匹配我们需要的序列
|
119 |
+
if all(recent_tokens[0][i] == self.sequence_to_match[i] for i in range(sequence_length)):
|
120 |
+
# Create a mask for allowed tokens while preserving original logits
|
121 |
+
mask = torch.zeros_like(logits, dtype=torch.bool) # Initialize mask as False
|
122 |
+
mask[:, self.allowed_tokens] = True # Mark allowed tokens as True
|
123 |
+
# Apply mask: keep original logits for allowed tokens, set others to -inf
|
124 |
+
logits = torch.where(mask, logits, -float('inf'))
|
125 |
+
# 设置标志,表示匹配已成功
|
126 |
+
self.match_found = True
|
127 |
+
print("match found!!!!!!!!!!!!!!!!!!!!!!!")
|
128 |
+
|
129 |
+
return logits
|
patches/custom_speech_repetition_penalty.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers.generation.logits_process import LogitsProcessor
|
2 |
+
|
3 |
+
class SpeechOnlyRepetitionPenaltyLogitsProcessor(LogitsProcessor):
|
4 |
+
def __init__(self, speech_token_num, penalty=1.2):
|
5 |
+
self.speech_token_num = speech_token_num
|
6 |
+
self.penalty = penalty
|
7 |
+
self.speech_phase = False # 你需要在外部控制这个变量
|
8 |
+
|
9 |
+
def set_phase(self, speech_phase: bool):
|
10 |
+
self.speech_phase = speech_phase
|
11 |
+
|
12 |
+
def __call__(self, input_ids, scores):
|
13 |
+
if not self.speech_phase:
|
14 |
+
# text阶段,什么都不做
|
15 |
+
return scores
|
16 |
+
# speech阶段,只对speech token做重复抑制
|
17 |
+
for batch_idx in range(input_ids.size(0)):
|
18 |
+
generated = input_ids[batch_idx].tolist()
|
19 |
+
for token_id in set(generated):
|
20 |
+
if 0 <= token_id < self.speech_token_num:
|
21 |
+
scores[batch_idx, token_id] /= self.penalty
|
22 |
+
return scores
|
patches/modelling_fm_infer_gpu.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from typing import List, Optional, Tuple, Union
|
4 |
+
import f5_tts
|
5 |
+
from f5_tts.model.backbones.dit_mask import DiT as DiT_
|
6 |
+
|
7 |
+
_GPU_FM_TORCH_COMPILE = True
|
8 |
+
|
9 |
+
class GPUDiT(DiT_):
|
10 |
+
def __init__(self, *args, **kwargs):
|
11 |
+
super().__init__(*args, **kwargs)
|
12 |
+
self.fast_forward = torch.compile(self.fast_forward, dynamic=False, fullgraph=True) \
|
13 |
+
if _GPU_FM_TORCH_COMPILE else self.fast_forward
|
14 |
+
|
15 |
+
# ===================================================================
|
16 |
+
print("========================= DO FM PATCH ============================")
|
17 |
+
# ===================================================================
|
18 |
+
f5_tts.model.backbones.dit_mask.DiT = GPUDiT
|
patches/modelling_qwen2_infer_gpu.py
ADDED
@@ -0,0 +1,416 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from typing import List, Optional, Tuple, Union
|
5 |
+
import transformers.models
|
6 |
+
from transformers.models.qwen2.modeling_qwen2 import (
|
7 |
+
Qwen2RotaryEmbedding,
|
8 |
+
Qwen2ForCausalLM,
|
9 |
+
Qwen2MLP,
|
10 |
+
Qwen2RMSNorm,
|
11 |
+
apply_rotary_pos_emb,
|
12 |
+
repeat_kv,
|
13 |
+
_prepare_4d_causal_attention_mask_with_cache_position,
|
14 |
+
)
|
15 |
+
from transformers.utils import logging
|
16 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
17 |
+
from transformers.cache_utils import Cache, StaticCache, SlidingWindowCache
|
18 |
+
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
|
19 |
+
from .utils import InferTaskCode
|
20 |
+
|
21 |
+
logger = logging.get_logger(__name__)
|
22 |
+
|
23 |
+
_GPU_QWEN_TORCH_COMPILE = True
|
24 |
+
|
25 |
+
# ===================================================================
|
26 |
+
# =============================Attention=============================
|
27 |
+
# ===================================================================
|
28 |
+
class GPUQwen2Attention(nn.Module):
|
29 |
+
"""
|
30 |
+
Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
|
31 |
+
and "Generating Long Sequences with Sparse Transformers".
|
32 |
+
"""
|
33 |
+
def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
|
34 |
+
super().__init__()
|
35 |
+
self.config = config
|
36 |
+
self.layer_idx = layer_idx
|
37 |
+
if layer_idx is None:
|
38 |
+
logger.warning_once(
|
39 |
+
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
|
40 |
+
"to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
|
41 |
+
"when creating this class."
|
42 |
+
)
|
43 |
+
|
44 |
+
self.hidden_size = config.hidden_size
|
45 |
+
self.num_heads = config.num_attention_heads
|
46 |
+
self.head_dim = self.hidden_size // self.num_heads
|
47 |
+
self.num_key_value_heads = config.num_key_value_heads
|
48 |
+
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
49 |
+
self.max_position_embeddings = config.max_position_embeddings
|
50 |
+
self.rope_theta = config.rope_theta
|
51 |
+
self.is_causal = True
|
52 |
+
self.attention_dropout = config.attention_dropout
|
53 |
+
|
54 |
+
if (self.head_dim * self.num_heads) != self.hidden_size:
|
55 |
+
raise ValueError(
|
56 |
+
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
57 |
+
f" and `num_heads`: {self.num_heads})."
|
58 |
+
)
|
59 |
+
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
|
60 |
+
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
61 |
+
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
62 |
+
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
63 |
+
|
64 |
+
self.rotary_emb = Qwen2RotaryEmbedding(
|
65 |
+
self.head_dim,
|
66 |
+
max_position_embeddings=self.max_position_embeddings,
|
67 |
+
base=self.rope_theta,
|
68 |
+
)
|
69 |
+
|
70 |
+
# Adapted from Qwen2Attention.forward
|
71 |
+
def forward(
|
72 |
+
self,
|
73 |
+
hidden_states: torch.Tensor,
|
74 |
+
attention_mask: Optional[torch.Tensor] = None,
|
75 |
+
position_ids: Optional[torch.LongTensor] = None,
|
76 |
+
past_key_value: Optional[Cache] = None,
|
77 |
+
output_attentions: bool = False,
|
78 |
+
use_cache: bool = False,
|
79 |
+
cache_position: Optional[torch.LongTensor] = None,
|
80 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
81 |
+
bsz, q_len, _ = hidden_states.size()
|
82 |
+
|
83 |
+
query_states = self.q_proj(hidden_states)
|
84 |
+
key_states = self.k_proj(hidden_states)
|
85 |
+
value_states = self.v_proj(hidden_states)
|
86 |
+
|
87 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
88 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
89 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
90 |
+
|
91 |
+
# NOTE: RoPE return all embedding (to satisfy torch compile)
|
92 |
+
cos, sin = self.rotary_emb(value_states, seq_len=past_key_value.get_max_length())
|
93 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
94 |
+
|
95 |
+
if past_key_value is not None:
|
96 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
97 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
98 |
+
|
99 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
100 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
101 |
+
|
102 |
+
causal_mask = attention_mask
|
103 |
+
if attention_mask is not None: # no matter the length, we just slice it
|
104 |
+
causal_mask = attention_mask[:, :, :, : past_key_value.get_max_length()]
|
105 |
+
|
106 |
+
query_states = query_states.contiguous()
|
107 |
+
key_states = key_states.contiguous()
|
108 |
+
value_states = value_states.contiguous()
|
109 |
+
|
110 |
+
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
111 |
+
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
112 |
+
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
|
113 |
+
is_causal = True if causal_mask is None and q_len > 1 else False
|
114 |
+
|
115 |
+
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
116 |
+
query_states,
|
117 |
+
key_states,
|
118 |
+
value_states,
|
119 |
+
attn_mask=causal_mask,
|
120 |
+
dropout_p=0.0,
|
121 |
+
is_causal=is_causal,
|
122 |
+
)
|
123 |
+
|
124 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
125 |
+
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
|
126 |
+
|
127 |
+
attn_output = self.o_proj(attn_output)
|
128 |
+
|
129 |
+
return attn_output, None, past_key_value
|
130 |
+
|
131 |
+
|
132 |
+
# ===================================================================
|
133 |
+
# =============================Layer=================================
|
134 |
+
# ===================================================================
|
135 |
+
class GPUQwen2DecoderLayer(nn.Module):
|
136 |
+
def __init__(self, config: Qwen2Config, layer_idx: int):
|
137 |
+
super().__init__()
|
138 |
+
self.hidden_size = config.hidden_size
|
139 |
+
|
140 |
+
if config.sliding_window and config._attn_implementation != "flash_attention_2":
|
141 |
+
logger.warning_once(
|
142 |
+
f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
|
143 |
+
"unexpected results may be encountered."
|
144 |
+
)
|
145 |
+
self.self_attn = GPUQwen2Attention(config, layer_idx)
|
146 |
+
|
147 |
+
self.mlp = Qwen2MLP(config)
|
148 |
+
self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
149 |
+
self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
150 |
+
|
151 |
+
def forward(
|
152 |
+
self,
|
153 |
+
hidden_states: torch.Tensor,
|
154 |
+
attention_mask: Optional[torch.Tensor] = None,
|
155 |
+
position_ids: Optional[torch.LongTensor] = None,
|
156 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
157 |
+
output_attentions: Optional[bool] = False,
|
158 |
+
use_cache: Optional[bool] = False,
|
159 |
+
cache_position: Optional[torch.LongTensor] = None,
|
160 |
+
**kwargs,
|
161 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
162 |
+
"""
|
163 |
+
Args:
|
164 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
165 |
+
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
166 |
+
`(batch, sequence_length)` where padding elements are indicated by 0.
|
167 |
+
output_attentions (`bool`, *optional*):
|
168 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
169 |
+
returned tensors for more detail.
|
170 |
+
use_cache (`bool`, *optional*):
|
171 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
172 |
+
(see `past_key_values`).
|
173 |
+
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
174 |
+
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
175 |
+
Indices depicting the position of the input sequence tokens in the sequence.
|
176 |
+
kwargs (`dict`, *optional*):
|
177 |
+
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
178 |
+
into the model
|
179 |
+
"""
|
180 |
+
|
181 |
+
residual = hidden_states
|
182 |
+
|
183 |
+
hidden_states = self.input_layernorm(hidden_states)
|
184 |
+
|
185 |
+
# Self Attention
|
186 |
+
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
187 |
+
hidden_states=hidden_states,
|
188 |
+
attention_mask=attention_mask,
|
189 |
+
position_ids=position_ids,
|
190 |
+
past_key_value=past_key_value,
|
191 |
+
output_attentions=output_attentions,
|
192 |
+
use_cache=use_cache,
|
193 |
+
cache_position=cache_position,
|
194 |
+
)
|
195 |
+
hidden_states = residual + hidden_states
|
196 |
+
|
197 |
+
# Fully Connected
|
198 |
+
residual = hidden_states
|
199 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
200 |
+
hidden_states = self.mlp(hidden_states)
|
201 |
+
hidden_states = residual + hidden_states
|
202 |
+
|
203 |
+
outputs = (hidden_states,)
|
204 |
+
|
205 |
+
if output_attentions:
|
206 |
+
outputs += (self_attn_weights,)
|
207 |
+
|
208 |
+
if use_cache:
|
209 |
+
outputs += (present_key_value,)
|
210 |
+
|
211 |
+
return outputs
|
212 |
+
|
213 |
+
# ===================================================================
|
214 |
+
# ========================Qwen2ForCausalLM===========================
|
215 |
+
# ===================================================================
|
216 |
+
class InferQwen2ForCausalLM(Qwen2ForCausalLM):
|
217 |
+
def __init__(self, config):
|
218 |
+
super().__init__(config)
|
219 |
+
self.compile_forward = torch.compile(self.simplify_forward, dynamic=False, fullgraph=True) \
|
220 |
+
if _GPU_QWEN_TORCH_COMPILE else self.simplify_forward
|
221 |
+
self.text_phase = True
|
222 |
+
'''
|
223 |
+
NOTE: 重写原Qwen2ForCausalLM forward函数,torchair直接编译原函数在返回CausalLMOutputWithPast时会出现编译错误
|
224 |
+
'''
|
225 |
+
def simplify_forward(self,
|
226 |
+
input_ids: torch.LongTensor = None,
|
227 |
+
attention_mask: Optional[torch.Tensor] = None,
|
228 |
+
position_ids: Optional[torch.LongTensor] = None,
|
229 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
230 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
231 |
+
labels: Optional[torch.LongTensor] = None,
|
232 |
+
use_cache: Optional[bool] = None,
|
233 |
+
output_attentions: Optional[bool] = None,
|
234 |
+
output_hidden_states: Optional[bool] = None,
|
235 |
+
return_dict: Optional[bool] = None,
|
236 |
+
cache_position: Optional[torch.LongTensor] = None,
|
237 |
+
):
|
238 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
239 |
+
output_hidden_states = (
|
240 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
241 |
+
)
|
242 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
243 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
244 |
+
outputs = self.model(
|
245 |
+
input_ids=input_ids,
|
246 |
+
attention_mask=attention_mask,
|
247 |
+
position_ids=position_ids,
|
248 |
+
past_key_values=past_key_values,
|
249 |
+
inputs_embeds=inputs_embeds,
|
250 |
+
use_cache=use_cache,
|
251 |
+
output_attentions=output_attentions,
|
252 |
+
output_hidden_states=output_hidden_states,
|
253 |
+
return_dict=return_dict,
|
254 |
+
cache_position=cache_position,
|
255 |
+
)
|
256 |
+
|
257 |
+
return outputs
|
258 |
+
|
259 |
+
def forward(self,
|
260 |
+
input_ids: torch.LongTensor = None,
|
261 |
+
attention_mask: Optional[torch.Tensor] = None,
|
262 |
+
position_ids: Optional[torch.LongTensor] = None,
|
263 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
264 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
265 |
+
labels: Optional[torch.LongTensor] = None,
|
266 |
+
use_cache: Optional[bool] = None,
|
267 |
+
output_attentions: Optional[bool] = None,
|
268 |
+
output_hidden_states: Optional[bool] = None,
|
269 |
+
return_dict: Optional[bool] = None,
|
270 |
+
cache_position: Optional[torch.LongTensor] = None,
|
271 |
+
do_compile = True
|
272 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
273 |
+
if past_key_values is not None:
|
274 |
+
past_key_values.training = False
|
275 |
+
# print(self.text_phase)
|
276 |
+
if input_ids is not None:
|
277 |
+
if self.text_phase:
|
278 |
+
inputs_embeds = self.model.embed_tokens(input_ids)
|
279 |
+
else:
|
280 |
+
inputs_embeds = self.speech_token_emded(input_ids)
|
281 |
+
if torch.isin(input_ids, 151645).any():
|
282 |
+
self.text_phase = False
|
283 |
+
input_ids = None
|
284 |
+
|
285 |
+
if (inputs_embeds is not None and cache_position[0] == 0) or do_compile==False :
|
286 |
+
# prefill branch
|
287 |
+
outputs = self.simplify_forward(input_ids,
|
288 |
+
attention_mask,
|
289 |
+
position_ids,
|
290 |
+
past_key_values,
|
291 |
+
inputs_embeds,
|
292 |
+
labels,
|
293 |
+
use_cache,
|
294 |
+
output_attentions,
|
295 |
+
output_hidden_states,
|
296 |
+
return_dict,
|
297 |
+
cache_position)
|
298 |
+
else:
|
299 |
+
# decoding
|
300 |
+
outputs = self.compile_forward(input_ids,
|
301 |
+
attention_mask,
|
302 |
+
position_ids,
|
303 |
+
past_key_values,
|
304 |
+
inputs_embeds,
|
305 |
+
labels,
|
306 |
+
use_cache,
|
307 |
+
output_attentions,
|
308 |
+
output_hidden_states,
|
309 |
+
return_dict,
|
310 |
+
cache_position)
|
311 |
+
|
312 |
+
last_hidden_states = outputs.last_hidden_state
|
313 |
+
|
314 |
+
if self.text_phase:
|
315 |
+
logits = self.lm_head(last_hidden_states)
|
316 |
+
else:
|
317 |
+
logits = self.speech_head(last_hidden_states)
|
318 |
+
|
319 |
+
logits = logits.float()
|
320 |
+
|
321 |
+
return CausalLMOutputWithPast(
|
322 |
+
loss=None,
|
323 |
+
logits=logits,
|
324 |
+
past_key_values=outputs.past_key_values,
|
325 |
+
hidden_states=outputs.hidden_states,
|
326 |
+
attentions=outputs.attentions,
|
327 |
+
)
|
328 |
+
|
329 |
+
|
330 |
+
def prepare_inputs_for_generation(
|
331 |
+
self,
|
332 |
+
input_ids,
|
333 |
+
past_key_values=None,
|
334 |
+
attention_mask=None,
|
335 |
+
inputs_embeds=None,
|
336 |
+
cache_position=None,
|
337 |
+
position_ids=None,
|
338 |
+
use_cache=True,
|
339 |
+
**kwargs,
|
340 |
+
):
|
341 |
+
"""
|
342 |
+
Mainly add static cache support
|
343 |
+
"""
|
344 |
+
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
345 |
+
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
346 |
+
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
347 |
+
if past_key_values is not None:
|
348 |
+
if inputs_embeds is not None: # Exception 1
|
349 |
+
input_ids = input_ids[:, -cache_position.shape[0] :]
|
350 |
+
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
351 |
+
input_ids = input_ids[:, cache_position]
|
352 |
+
|
353 |
+
if attention_mask is not None and position_ids is None:
|
354 |
+
# create position_ids on the fly for batch generation
|
355 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
356 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
357 |
+
if past_key_values:
|
358 |
+
position_ids = position_ids[:, -input_ids.shape[1] :]
|
359 |
+
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`,
|
360 |
+
# as otherwise the input `position_ids` would have various stride during the decoding.
|
361 |
+
# Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case,
|
362 |
+
# `position_ids` is already contiguous but with varying stride which retriggers a capture.
|
363 |
+
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
|
364 |
+
|
365 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
366 |
+
if inputs_embeds is not None and cache_position[0] == 0:
|
367 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
368 |
+
else:
|
369 |
+
# NOTE: 与上述的position_ids相同,same as position_ids, for torch.compile and cuda graph
|
370 |
+
input_ids = input_ids.clone(memory_format=torch.contiguous_format)
|
371 |
+
model_inputs = {"input_ids": input_ids}
|
372 |
+
|
373 |
+
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
|
374 |
+
if inputs_embeds is not None and cache_position[0] == 0:
|
375 |
+
# prefill phase, inputs_embeds has shape (B,S,H)
|
376 |
+
batch_size, sequence_length = inputs_embeds.shape[0], inputs_embeds.shape[1]
|
377 |
+
device = inputs_embeds.device
|
378 |
+
else:
|
379 |
+
# decdoing phase, input_ids has shape (B,S)
|
380 |
+
batch_size, sequence_length = input_ids.shape
|
381 |
+
device = input_ids.device
|
382 |
+
|
383 |
+
dtype = self.lm_head.weight.dtype
|
384 |
+
min_dtype = torch.finfo(dtype).min
|
385 |
+
|
386 |
+
if inputs_embeds is not None and inputs_embeds.ndim == 2 or input_ids is not None and input_ids.size(-1) == 1:
|
387 |
+
# we only expand attention mask in docoding mode
|
388 |
+
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
|
389 |
+
attention_mask,
|
390 |
+
sequence_length=sequence_length,
|
391 |
+
target_length=past_key_values.get_max_length(),
|
392 |
+
dtype=dtype,
|
393 |
+
device=device,
|
394 |
+
min_dtype=min_dtype,
|
395 |
+
cache_position=cache_position,
|
396 |
+
batch_size=batch_size,
|
397 |
+
)
|
398 |
+
|
399 |
+
model_inputs.update(
|
400 |
+
{
|
401 |
+
"position_ids": position_ids,
|
402 |
+
"cache_position": cache_position,
|
403 |
+
"past_key_values": past_key_values,
|
404 |
+
"use_cache": use_cache,
|
405 |
+
"attention_mask": attention_mask,
|
406 |
+
"do_compile": kwargs['do_compile'],
|
407 |
+
}
|
408 |
+
)
|
409 |
+
return model_inputs
|
410 |
+
|
411 |
+
# ===================================================================
|
412 |
+
print("========================= DO Qwen2 PATCH ===========================")
|
413 |
+
# ===================================================================
|
414 |
+
transformers.models.qwen2.modeling_qwen2.Qwen2PreTrainedModel._supports_static_cache = True # enable static cache
|
415 |
+
transformers.models.qwen2.modeling_qwen2.Qwen2DecoderLayer = GPUQwen2DecoderLayer
|
416 |
+
transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM = InferQwen2ForCausalLM
|
patches/utils.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class InferTaskCode:
|
2 |
+
_ASR = 0
|
3 |
+
_TTS = 1
|
4 |
+
_S2S = 2
|
requirements.txt
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
numpy==1.24
|
2 |
+
jsonlines==4.0.0
|
3 |
+
torch==2.1.0
|
4 |
+
transformers==4.44.0
|
5 |
+
torchaudio==2.1.0
|
6 |
+
accelerate==1.7.0
|
7 |
+
peft==0.17.0
|
8 |
+
librosa
|
9 |
+
tensorboardX>=2.5
|
10 |
+
# torch_npu==2.1.0.post8
|
11 |
+
tqdm
|
12 |
+
absl-py
|
13 |
+
psutil
|
14 |
+
cloudpickle
|
15 |
+
ml-dtypes
|
16 |
+
tornado
|
17 |
+
openai-whisper
|
18 |
+
colorama
|
19 |
+
sox
|
20 |
+
deepspeed
|
21 |
+
librosa
|
22 |
+
gxl_ai_utils
|
23 |
+
|
24 |
+
|
25 |
+
hyperpyyaml
|
26 |
+
modelscope
|
27 |
+
onnxruntime
|
28 |
+
inflect
|
29 |
+
omegaconf
|
30 |
+
conformer
|
31 |
+
diffusers
|
32 |
+
hydra-core
|
33 |
+
lightning
|
34 |
+
|
35 |
+
gradio
|
36 |
+
cn2an
|
37 |
+
gdown
|
38 |
+
matplotlib
|
39 |
+
wget
|
40 |
+
pyarrow
|
41 |
+
pyworld
|
tts/__init__.py
ADDED
File without changes
|
tts/assert//345/256/236/351/252/214/345/256/244.png
ADDED
![]() |
tts/cosyvoice/__init__.py
ADDED
File without changes
|
tts/cosyvoice/bin/average_model.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 Mobvoi Inc (Di Wu)
|
2 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import os
|
17 |
+
import argparse
|
18 |
+
import glob
|
19 |
+
|
20 |
+
import yaml
|
21 |
+
import torch
|
22 |
+
|
23 |
+
|
24 |
+
def get_args():
|
25 |
+
parser = argparse.ArgumentParser(description='average model')
|
26 |
+
parser.add_argument('--dst_model', required=True, help='averaged model')
|
27 |
+
parser.add_argument('--src_path',
|
28 |
+
required=True,
|
29 |
+
help='src model path for average')
|
30 |
+
parser.add_argument('--val_best',
|
31 |
+
action="store_true",
|
32 |
+
help='averaged model')
|
33 |
+
parser.add_argument('--num',
|
34 |
+
default=5,
|
35 |
+
type=int,
|
36 |
+
help='nums for averaged model')
|
37 |
+
|
38 |
+
args = parser.parse_args()
|
39 |
+
print(args)
|
40 |
+
return args
|
41 |
+
|
42 |
+
|
43 |
+
def main():
|
44 |
+
args = get_args()
|
45 |
+
val_scores = []
|
46 |
+
if args.val_best:
|
47 |
+
yamls = glob.glob('{}/*.yaml'.format(args.src_path))
|
48 |
+
yamls = [
|
49 |
+
f for f in yamls
|
50 |
+
if not (os.path.basename(f).startswith('train')
|
51 |
+
or os.path.basename(f).startswith('init'))
|
52 |
+
]
|
53 |
+
for y in yamls:
|
54 |
+
with open(y, 'r') as f:
|
55 |
+
dic_yaml = yaml.load(f, Loader=yaml.BaseLoader)
|
56 |
+
loss = float(dic_yaml['loss_dict']['loss'])
|
57 |
+
epoch = int(dic_yaml['epoch'])
|
58 |
+
step = int(dic_yaml['step'])
|
59 |
+
tag = dic_yaml['tag']
|
60 |
+
val_scores += [[epoch, step, loss, tag]]
|
61 |
+
sorted_val_scores = sorted(val_scores,
|
62 |
+
key=lambda x: x[2],
|
63 |
+
reverse=False)
|
64 |
+
print("best val (epoch, step, loss, tag) = " +
|
65 |
+
str(sorted_val_scores[:args.num]))
|
66 |
+
path_list = [
|
67 |
+
args.src_path + '/epoch_{}_whole.pt'.format(score[0])
|
68 |
+
for score in sorted_val_scores[:args.num]
|
69 |
+
]
|
70 |
+
print(path_list)
|
71 |
+
avg = {}
|
72 |
+
num = args.num
|
73 |
+
assert num == len(path_list)
|
74 |
+
for path in path_list:
|
75 |
+
print('Processing {}'.format(path))
|
76 |
+
states = torch.load(path, map_location=torch.device('cpu'))
|
77 |
+
for k in states.keys():
|
78 |
+
if k not in avg.keys():
|
79 |
+
avg[k] = states[k].clone()
|
80 |
+
else:
|
81 |
+
avg[k] += states[k]
|
82 |
+
# average
|
83 |
+
for k in avg.keys():
|
84 |
+
if avg[k] is not None:
|
85 |
+
# pytorch 1.6 use true_divide instead of /=
|
86 |
+
avg[k] = torch.true_divide(avg[k], num)
|
87 |
+
print('Saving to {}'.format(args.dst_model))
|
88 |
+
torch.save(avg, args.dst_model)
|
89 |
+
|
90 |
+
|
91 |
+
if __name__ == '__main__':
|
92 |
+
main()
|
tts/cosyvoice/bin/export_jit.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from __future__ import print_function
|
16 |
+
|
17 |
+
import argparse
|
18 |
+
import logging
|
19 |
+
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
20 |
+
import os
|
21 |
+
import sys
|
22 |
+
import torch
|
23 |
+
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
24 |
+
sys.path.append('{}/../..'.format(ROOT_DIR))
|
25 |
+
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
|
26 |
+
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
|
27 |
+
|
28 |
+
|
29 |
+
def get_args():
|
30 |
+
parser = argparse.ArgumentParser(description='export your model for deployment')
|
31 |
+
parser.add_argument('--model_dir',
|
32 |
+
type=str,
|
33 |
+
default='pretrained_models/CosyVoice-300M',
|
34 |
+
help='local path')
|
35 |
+
args = parser.parse_args()
|
36 |
+
print(args)
|
37 |
+
return args
|
38 |
+
|
39 |
+
|
40 |
+
def get_optimized_script(model, preserved_attrs=[]):
|
41 |
+
script = torch.jit.script(model)
|
42 |
+
if preserved_attrs != []:
|
43 |
+
script = torch.jit.freeze(script, preserved_attrs=preserved_attrs)
|
44 |
+
else:
|
45 |
+
script = torch.jit.freeze(script)
|
46 |
+
script = torch.jit.optimize_for_inference(script)
|
47 |
+
return script
|
48 |
+
|
49 |
+
|
50 |
+
def main():
|
51 |
+
args = get_args()
|
52 |
+
logging.basicConfig(level=logging.DEBUG,
|
53 |
+
format='%(asctime)s %(levelname)s %(message)s')
|
54 |
+
|
55 |
+
torch._C._jit_set_fusion_strategy([('STATIC', 1)])
|
56 |
+
torch._C._jit_set_profiling_mode(False)
|
57 |
+
torch._C._jit_set_profiling_executor(False)
|
58 |
+
|
59 |
+
try:
|
60 |
+
model = CosyVoice(args.model_dir)
|
61 |
+
except Exception:
|
62 |
+
try:
|
63 |
+
model = CosyVoice2(args.model_dir)
|
64 |
+
except Exception:
|
65 |
+
raise TypeError('no valid model_type!')
|
66 |
+
|
67 |
+
if not isinstance(model, CosyVoice2):
|
68 |
+
# 1. export llm text_encoder
|
69 |
+
llm_text_encoder = model.model.llm.text_encoder
|
70 |
+
script = get_optimized_script(llm_text_encoder)
|
71 |
+
script.save('{}/llm.text_encoder.fp32.zip'.format(args.model_dir))
|
72 |
+
script = get_optimized_script(llm_text_encoder.half())
|
73 |
+
script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir))
|
74 |
+
|
75 |
+
# 2. export llm llm
|
76 |
+
llm_llm = model.model.llm.llm
|
77 |
+
script = get_optimized_script(llm_llm, ['forward_chunk'])
|
78 |
+
script.save('{}/llm.llm.fp32.zip'.format(args.model_dir))
|
79 |
+
script = get_optimized_script(llm_llm.half(), ['forward_chunk'])
|
80 |
+
script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
|
81 |
+
|
82 |
+
# 3. export flow encoder
|
83 |
+
flow_encoder = model.model.flow.encoder
|
84 |
+
script = get_optimized_script(flow_encoder)
|
85 |
+
script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
|
86 |
+
script = get_optimized_script(flow_encoder.half())
|
87 |
+
script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
|
88 |
+
|
89 |
+
|
90 |
+
if __name__ == '__main__':
|
91 |
+
main()
|
tts/cosyvoice/bin/export_onnx.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, [email protected])
|
2 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from __future__ import print_function
|
17 |
+
|
18 |
+
import argparse
|
19 |
+
import logging
|
20 |
+
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
21 |
+
import os
|
22 |
+
import sys
|
23 |
+
import onnxruntime
|
24 |
+
import random
|
25 |
+
import torch
|
26 |
+
from tqdm import tqdm
|
27 |
+
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
28 |
+
sys.path.append('{}/../..'.format(ROOT_DIR))
|
29 |
+
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
|
30 |
+
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
|
31 |
+
|
32 |
+
|
33 |
+
def get_dummy_input(batch_size, seq_len, out_channels, device):
|
34 |
+
x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
|
35 |
+
mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device)
|
36 |
+
mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
|
37 |
+
t = torch.rand((batch_size), dtype=torch.float32, device=device)
|
38 |
+
spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device)
|
39 |
+
cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
|
40 |
+
return x, mask, mu, t, spks, cond
|
41 |
+
|
42 |
+
|
43 |
+
def get_args():
|
44 |
+
parser = argparse.ArgumentParser(description='export your model for deployment')
|
45 |
+
parser.add_argument('--model_dir',
|
46 |
+
type=str,
|
47 |
+
default='pretrained_models/CosyVoice-300M',
|
48 |
+
help='local path')
|
49 |
+
args = parser.parse_args()
|
50 |
+
print(args)
|
51 |
+
return args
|
52 |
+
|
53 |
+
|
54 |
+
def main():
|
55 |
+
args = get_args()
|
56 |
+
logging.basicConfig(level=logging.DEBUG,
|
57 |
+
format='%(asctime)s %(levelname)s %(message)s')
|
58 |
+
|
59 |
+
try:
|
60 |
+
model = CosyVoice(args.model_dir)
|
61 |
+
except Exception:
|
62 |
+
try:
|
63 |
+
model = CosyVoice2(args.model_dir)
|
64 |
+
except Exception:
|
65 |
+
raise TypeError('no valid model_type!')
|
66 |
+
|
67 |
+
# 1. export flow decoder estimator
|
68 |
+
estimator = model.model.flow.decoder.estimator
|
69 |
+
|
70 |
+
device = model.model.device
|
71 |
+
batch_size, seq_len = 2, 256
|
72 |
+
out_channels = model.model.flow.decoder.estimator.out_channels
|
73 |
+
x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
|
74 |
+
torch.onnx.export(
|
75 |
+
estimator,
|
76 |
+
(x, mask, mu, t, spks, cond),
|
77 |
+
'{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
|
78 |
+
export_params=True,
|
79 |
+
opset_version=18,
|
80 |
+
do_constant_folding=True,
|
81 |
+
input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
|
82 |
+
output_names=['estimator_out'],
|
83 |
+
dynamic_axes={
|
84 |
+
'x': {2: 'seq_len'},
|
85 |
+
'mask': {2: 'seq_len'},
|
86 |
+
'mu': {2: 'seq_len'},
|
87 |
+
'cond': {2: 'seq_len'},
|
88 |
+
'estimator_out': {2: 'seq_len'},
|
89 |
+
}
|
90 |
+
)
|
91 |
+
|
92 |
+
# 2. test computation consistency
|
93 |
+
option = onnxruntime.SessionOptions()
|
94 |
+
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
95 |
+
option.intra_op_num_threads = 1
|
96 |
+
providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
|
97 |
+
estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
|
98 |
+
sess_options=option, providers=providers)
|
99 |
+
|
100 |
+
for _ in tqdm(range(10)):
|
101 |
+
x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
|
102 |
+
output_pytorch = estimator(x, mask, mu, t, spks, cond)
|
103 |
+
ort_inputs = {
|
104 |
+
'x': x.cpu().numpy(),
|
105 |
+
'mask': mask.cpu().numpy(),
|
106 |
+
'mu': mu.cpu().numpy(),
|
107 |
+
't': t.cpu().numpy(),
|
108 |
+
'spks': spks.cpu().numpy(),
|
109 |
+
'cond': cond.cpu().numpy()
|
110 |
+
}
|
111 |
+
output_onnx = estimator_onnx.run(None, ort_inputs)[0]
|
112 |
+
torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
|
113 |
+
|
114 |
+
|
115 |
+
if __name__ == "__main__":
|
116 |
+
main()
|
tts/cosyvoice/bin/export_trt.sh
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
# Copyright 2024 Alibaba Inc. All Rights Reserved.
|
3 |
+
# download tensorrt from https://developer.nvidia.com/tensorrt/download/10x, check your system and cuda for compatibability
|
4 |
+
# for example for linux + cuda12.4, you can download https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.0.1/tars/TensorRT-10.0.1.6.Linux.x86_64-gnu.cuda-12.4.tar.gz
|
5 |
+
TRT_DIR=<YOUR_TRT_DIR>
|
6 |
+
MODEL_DIR=<COSYVOICE2_MODEL_DIR>
|
7 |
+
|
8 |
+
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$TRT_DIR/lib:/usr/local/cuda/lib64
|
9 |
+
$TRT_DIR/bin/trtexec --onnx=$MODEL_DIR/flow.decoder.estimator.fp32.onnx --saveEngine=$MODEL_DIR/flow.decoder.estimator.fp32.mygpu.plan --minShapes=x:2x80x4,mask:2x1x4,mu:2x80x4,cond:2x80x4 --optShapes=x:2x80x193,mask:2x1x193,mu:2x80x193,cond:2x80x193 --maxShapes=x:2x80x6800,mask:2x1x6800,mu:2x80x6800,cond:2x80x6800 --inputIOFormats=fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw --outputIOFormats=fp32:chw
|
10 |
+
$TRT_DIR/bin/trtexec --onnx=$MODEL_DIR/flow.decoder.estimator.fp32.onnx --saveEngine=$MODEL_DIR/flow.decoder.estimator.fp16.mygpu.plan --fp16 --minShapes=x:2x80x4,mask:2x1x4,mu:2x80x4,cond:2x80x4 --optShapes=x:2x80x193,mask:2x1x193,mu:2x80x193,cond:2x80x193 --maxShapes=x:2x80x6800,mask:2x1x6800,mu:2x80x6800,cond:2x80x6800 --inputIOFormats=fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw --outputIOFormats=fp16:chw
|
tts/cosyvoice/bin/inference.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from __future__ import print_function
|
16 |
+
|
17 |
+
import argparse
|
18 |
+
import logging
|
19 |
+
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
20 |
+
import os
|
21 |
+
import torch
|
22 |
+
from torch.utils.data import DataLoader
|
23 |
+
import torchaudio
|
24 |
+
from hyperpyyaml import load_hyperpyyaml
|
25 |
+
from tqdm import tqdm
|
26 |
+
from cosyvoice.cli.model import CosyVoiceModel
|
27 |
+
from cosyvoice.dataset.dataset import Dataset
|
28 |
+
|
29 |
+
|
30 |
+
def get_args():
|
31 |
+
parser = argparse.ArgumentParser(description='inference with your model')
|
32 |
+
parser.add_argument('--config', required=True, help='config file')
|
33 |
+
parser.add_argument('--prompt_data', required=True, help='prompt data file')
|
34 |
+
parser.add_argument('--prompt_utt2data', required=True, help='prompt data file')
|
35 |
+
parser.add_argument('--tts_text', required=True, help='tts input file')
|
36 |
+
parser.add_argument('--llm_model', required=True, help='llm model file')
|
37 |
+
parser.add_argument('--flow_model', required=True, help='flow model file')
|
38 |
+
parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
|
39 |
+
parser.add_argument('--gpu',
|
40 |
+
type=int,
|
41 |
+
default=-1,
|
42 |
+
help='gpu id for this rank, -1 for cpu')
|
43 |
+
parser.add_argument('--mode',
|
44 |
+
default='sft',
|
45 |
+
choices=['sft', 'zero_shot'],
|
46 |
+
help='inference mode')
|
47 |
+
parser.add_argument('--result_dir', required=True, help='asr result file')
|
48 |
+
args = parser.parse_args()
|
49 |
+
print(args)
|
50 |
+
return args
|
51 |
+
|
52 |
+
|
53 |
+
def main():
|
54 |
+
args = get_args()
|
55 |
+
logging.basicConfig(level=logging.DEBUG,
|
56 |
+
format='%(asctime)s %(levelname)s %(message)s')
|
57 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
|
58 |
+
|
59 |
+
# Init cosyvoice models from configs
|
60 |
+
use_cuda = args.gpu >= 0 and torch.cuda.is_available()
|
61 |
+
device = torch.device('cuda' if use_cuda else 'cpu')
|
62 |
+
with open(args.config, 'r') as f:
|
63 |
+
configs = load_hyperpyyaml(f)
|
64 |
+
|
65 |
+
model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
|
66 |
+
model.load(args.llm_model, args.flow_model, args.hifigan_model)
|
67 |
+
|
68 |
+
test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
|
69 |
+
tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
|
70 |
+
test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
|
71 |
+
|
72 |
+
del configs
|
73 |
+
os.makedirs(args.result_dir, exist_ok=True)
|
74 |
+
fn = os.path.join(args.result_dir, 'wav.scp')
|
75 |
+
f = open(fn, 'w')
|
76 |
+
with torch.no_grad():
|
77 |
+
for _, batch in tqdm(enumerate(test_data_loader)):
|
78 |
+
utts = batch["utts"]
|
79 |
+
assert len(utts) == 1, "inference mode only support batchsize 1"
|
80 |
+
text_token = batch["text_token"].to(device)
|
81 |
+
text_token_len = batch["text_token_len"].to(device)
|
82 |
+
tts_index = batch["tts_index"]
|
83 |
+
tts_text_token = batch["tts_text_token"].to(device)
|
84 |
+
tts_text_token_len = batch["tts_text_token_len"].to(device)
|
85 |
+
speech_token = batch["speech_token"].to(device)
|
86 |
+
speech_token_len = batch["speech_token_len"].to(device)
|
87 |
+
speech_feat = batch["speech_feat"].to(device)
|
88 |
+
speech_feat_len = batch["speech_feat_len"].to(device)
|
89 |
+
utt_embedding = batch["utt_embedding"].to(device)
|
90 |
+
spk_embedding = batch["spk_embedding"].to(device)
|
91 |
+
if args.mode == 'sft':
|
92 |
+
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
|
93 |
+
'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
|
94 |
+
else:
|
95 |
+
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
|
96 |
+
'prompt_text': text_token, 'prompt_text_len': text_token_len,
|
97 |
+
'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
|
98 |
+
'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
|
99 |
+
'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
|
100 |
+
'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
|
101 |
+
tts_speeches = []
|
102 |
+
for model_output in model.tts(**model_input):
|
103 |
+
tts_speeches.append(model_output['tts_speech'])
|
104 |
+
tts_speeches = torch.concat(tts_speeches, dim=1)
|
105 |
+
tts_key = '{}_{}'.format(utts[0], tts_index[0])
|
106 |
+
tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
|
107 |
+
torchaudio.save(tts_fn, tts_speeches, sample_rate=22050)
|
108 |
+
f.write('{} {}\n'.format(tts_key, tts_fn))
|
109 |
+
f.flush()
|
110 |
+
f.close()
|
111 |
+
logging.info('Result wav.scp saved in {}'.format(fn))
|
112 |
+
|
113 |
+
|
114 |
+
if __name__ == '__main__':
|
115 |
+
main()
|
tts/cosyvoice/bin/train.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from __future__ import print_function
|
16 |
+
import argparse
|
17 |
+
import datetime
|
18 |
+
import logging
|
19 |
+
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
20 |
+
from copy import deepcopy
|
21 |
+
import os
|
22 |
+
import torch
|
23 |
+
import torch.distributed as dist
|
24 |
+
import deepspeed
|
25 |
+
|
26 |
+
from hyperpyyaml import load_hyperpyyaml
|
27 |
+
|
28 |
+
from torch.distributed.elastic.multiprocessing.errors import record
|
29 |
+
|
30 |
+
from cosyvoice.utils.executor import Executor
|
31 |
+
from cosyvoice.utils.train_utils import (
|
32 |
+
init_distributed,
|
33 |
+
init_dataset_and_dataloader,
|
34 |
+
init_optimizer_and_scheduler,
|
35 |
+
init_summarywriter, save_model,
|
36 |
+
wrap_cuda_model, check_modify_and_save_config)
|
37 |
+
|
38 |
+
|
39 |
+
def get_args():
|
40 |
+
parser = argparse.ArgumentParser(description='training your network')
|
41 |
+
parser.add_argument('--train_engine',
|
42 |
+
default='torch_ddp',
|
43 |
+
choices=['torch_ddp', 'deepspeed'],
|
44 |
+
help='Engine for paralleled training')
|
45 |
+
parser.add_argument('--model', required=True, help='model which will be trained')
|
46 |
+
parser.add_argument('--config', required=True, help='config file')
|
47 |
+
parser.add_argument('--train_data', required=True, help='train data file')
|
48 |
+
parser.add_argument('--cv_data', required=True, help='cv data file')
|
49 |
+
parser.add_argument('--checkpoint', help='checkpoint model')
|
50 |
+
parser.add_argument('--model_dir', required=True, help='save model dir')
|
51 |
+
parser.add_argument('--tensorboard_dir',
|
52 |
+
default='tensorboard',
|
53 |
+
help='tensorboard log dir')
|
54 |
+
parser.add_argument('--ddp.dist_backend',
|
55 |
+
dest='dist_backend',
|
56 |
+
default='nccl',
|
57 |
+
choices=['nccl', 'gloo'],
|
58 |
+
help='distributed backend')
|
59 |
+
parser.add_argument('--num_workers',
|
60 |
+
default=0,
|
61 |
+
type=int,
|
62 |
+
help='num of subprocess workers for reading')
|
63 |
+
parser.add_argument('--prefetch',
|
64 |
+
default=100,
|
65 |
+
type=int,
|
66 |
+
help='prefetch number')
|
67 |
+
parser.add_argument('--pin_memory',
|
68 |
+
action='store_true',
|
69 |
+
default=False,
|
70 |
+
help='Use pinned memory buffers used for reading')
|
71 |
+
parser.add_argument('--use_amp',
|
72 |
+
action='store_true',
|
73 |
+
default=False,
|
74 |
+
help='Use automatic mixed precision training')
|
75 |
+
parser.add_argument('--deepspeed.save_states',
|
76 |
+
dest='save_states',
|
77 |
+
default='model_only',
|
78 |
+
choices=['model_only', 'model+optimizer'],
|
79 |
+
help='save model/optimizer states')
|
80 |
+
parser.add_argument('--timeout',
|
81 |
+
default=60,
|
82 |
+
type=int,
|
83 |
+
help='timeout (in seconds) of cosyvoice_join.')
|
84 |
+
parser = deepspeed.add_config_arguments(parser)
|
85 |
+
args = parser.parse_args()
|
86 |
+
return args
|
87 |
+
|
88 |
+
|
89 |
+
@record
|
90 |
+
def main():
|
91 |
+
args = get_args()
|
92 |
+
logging.basicConfig(level=logging.DEBUG,
|
93 |
+
format='%(asctime)s %(levelname)s %(message)s')
|
94 |
+
# gan train has some special initialization logic
|
95 |
+
gan = True if args.model == 'hifigan' else False
|
96 |
+
|
97 |
+
override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
|
98 |
+
if gan is True:
|
99 |
+
override_dict.pop('hift')
|
100 |
+
with open(args.config, 'r') as f:
|
101 |
+
configs = load_hyperpyyaml(f, overrides=override_dict)
|
102 |
+
if gan is True:
|
103 |
+
configs['train_conf'] = configs['train_conf_gan']
|
104 |
+
configs['train_conf'].update(vars(args))
|
105 |
+
|
106 |
+
# Init env for ddp
|
107 |
+
init_distributed(args)
|
108 |
+
|
109 |
+
# Get dataset & dataloader
|
110 |
+
train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
|
111 |
+
init_dataset_and_dataloader(args, configs, gan)
|
112 |
+
|
113 |
+
# Do some sanity checks and save config to arsg.model_dir
|
114 |
+
configs = check_modify_and_save_config(args, configs)
|
115 |
+
|
116 |
+
# Tensorboard summary
|
117 |
+
writer = init_summarywriter(args)
|
118 |
+
|
119 |
+
# load checkpoint
|
120 |
+
model = configs[args.model]
|
121 |
+
start_step, start_epoch = 0, -1
|
122 |
+
if args.checkpoint is not None:
|
123 |
+
if os.path.exists(args.checkpoint):
|
124 |
+
state_dict = torch.load(args.checkpoint, map_location='cpu')
|
125 |
+
model.load_state_dict(state_dict, strict=False)
|
126 |
+
if 'step' in state_dict:
|
127 |
+
start_step = state_dict['step']
|
128 |
+
if 'epoch' in state_dict:
|
129 |
+
start_epoch = state_dict['epoch']
|
130 |
+
else:
|
131 |
+
logging.warning('checkpoint {} do not exsist!'.format(args.checkpoint))
|
132 |
+
|
133 |
+
# Dispatch model from cpu to gpu
|
134 |
+
model = wrap_cuda_model(args, model)
|
135 |
+
|
136 |
+
# Get optimizer & scheduler
|
137 |
+
model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan)
|
138 |
+
scheduler.set_step(start_step)
|
139 |
+
if scheduler_d is not None:
|
140 |
+
scheduler_d.set_step(start_step)
|
141 |
+
|
142 |
+
# Save init checkpoints
|
143 |
+
info_dict = deepcopy(configs['train_conf'])
|
144 |
+
info_dict['step'] = start_step
|
145 |
+
info_dict['epoch'] = start_epoch
|
146 |
+
save_model(model, 'init', info_dict)
|
147 |
+
|
148 |
+
# Get executor
|
149 |
+
executor = Executor(gan=gan)
|
150 |
+
executor.step = start_step
|
151 |
+
|
152 |
+
# Init scaler, used for pytorch amp mixed precision training
|
153 |
+
scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
|
154 |
+
print('start step {} start epoch {}'.format(start_step, start_epoch))
|
155 |
+
# Start training loop
|
156 |
+
for epoch in range(start_epoch + 1, info_dict['max_epoch']):
|
157 |
+
executor.epoch = epoch
|
158 |
+
train_dataset.set_epoch(epoch)
|
159 |
+
dist.barrier()
|
160 |
+
group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
|
161 |
+
if gan is True:
|
162 |
+
executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
|
163 |
+
writer, info_dict, scaler, group_join)
|
164 |
+
else:
|
165 |
+
executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join)
|
166 |
+
dist.destroy_process_group(group_join)
|
167 |
+
|
168 |
+
|
169 |
+
if __name__ == '__main__':
|
170 |
+
main()
|
tts/cosyvoice/cli/__init__.py
ADDED
File without changes
|
tts/cosyvoice/cli/cosyvoice.py
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import os
|
15 |
+
import time
|
16 |
+
from typing import Generator
|
17 |
+
from tqdm import tqdm
|
18 |
+
from hyperpyyaml import load_hyperpyyaml
|
19 |
+
from modelscope import snapshot_download
|
20 |
+
import torch
|
21 |
+
from cosyvoice.cli.frontend import CosyVoiceFrontEnd
|
22 |
+
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
|
23 |
+
from cosyvoice.utils.file_utils import logging
|
24 |
+
from cosyvoice.utils.class_utils import get_model_type
|
25 |
+
|
26 |
+
|
27 |
+
class CosyVoice:
|
28 |
+
|
29 |
+
def __init__(self, model_dir,gpu_id=0, load_jit=False, load_trt=False, fp16=False):
|
30 |
+
self.instruct = True if '-Instruct' in model_dir else False
|
31 |
+
self.model_dir = model_dir
|
32 |
+
self.fp16 = fp16
|
33 |
+
if not os.path.exists(model_dir):
|
34 |
+
model_dir = snapshot_download(model_dir)
|
35 |
+
with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
|
36 |
+
configs = load_hyperpyyaml(f)
|
37 |
+
assert get_model_type(configs) != CosyVoice2Model, 'do not use {} for CosyVoice initialization!'.format(model_dir)
|
38 |
+
self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
|
39 |
+
configs['feat_extractor'],
|
40 |
+
'{}/campplus.onnx'.format(model_dir),
|
41 |
+
'{}/speech_tokenizer_v1.onnx'.format(model_dir),
|
42 |
+
'{}/spk2info.pt'.format(model_dir),
|
43 |
+
configs['allowed_special'],
|
44 |
+
gpu_id=gpu_id)
|
45 |
+
self.sample_rate = configs['sample_rate']
|
46 |
+
if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
|
47 |
+
load_jit, load_trt, fp16 = False, False, False
|
48 |
+
logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
|
49 |
+
self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16, gpu_id=gpu_id)
|
50 |
+
self.model.load('{}/llm.pt'.format(model_dir),
|
51 |
+
'{}/flow.pt'.format(model_dir),
|
52 |
+
'{}/hift.pt'.format(model_dir))
|
53 |
+
if load_jit:
|
54 |
+
self.model.load_jit('{}/llm.text_encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
55 |
+
'{}/llm.llm.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
56 |
+
'{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
|
57 |
+
if load_trt:
|
58 |
+
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
59 |
+
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
|
60 |
+
self.fp16)
|
61 |
+
del configs
|
62 |
+
|
63 |
+
def list_available_spks(self):
|
64 |
+
spks = list(self.frontend.spk2info.keys())
|
65 |
+
return spks
|
66 |
+
|
67 |
+
def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
|
68 |
+
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
69 |
+
model_input = self.frontend.frontend_sft(i, spk_id)
|
70 |
+
start_time = time.time()
|
71 |
+
logging.info('synthesis text {}'.format(i))
|
72 |
+
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
73 |
+
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
74 |
+
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
75 |
+
yield model_output
|
76 |
+
start_time = time.time()
|
77 |
+
|
78 |
+
def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True, token_list=None):
|
79 |
+
prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend)
|
80 |
+
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
81 |
+
if (not isinstance(i, Generator)) and len(i) < 0.5 * len(prompt_text):
|
82 |
+
logging.warning('synthesis text {} too short than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
|
83 |
+
model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate)
|
84 |
+
start_time = time.time()
|
85 |
+
logging.info('synthesis text {}'.format(i))
|
86 |
+
# import pdb;pdb.set_trace()
|
87 |
+
for model_output in self.model.tts(**model_input, stream=stream, speed=speed, token_list=token_list):
|
88 |
+
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
89 |
+
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
90 |
+
return model_output
|
91 |
+
|
92 |
+
|
93 |
+
def inference_zero_shot_gxl(self,tts_text, prompt_text,prompt_speech_16k, stream=False, speed=1.0, text_frontend=True, token_list=None):
|
94 |
+
prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend)
|
95 |
+
input_text = self.frontend.text_normalize(tts_text, split=False, text_frontend=text_frontend)
|
96 |
+
model_input = self.frontend.frontend_zero_shot(input_text, prompt_text, prompt_speech_16k, self.sample_rate)
|
97 |
+
start_time = time.time()
|
98 |
+
model_output = self.model.tts_gxl(**model_input, stream=stream, speed=speed, token_list=token_list)
|
99 |
+
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
100 |
+
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
101 |
+
return model_output
|
102 |
+
|
103 |
+
|
104 |
+
def inference_zero_shot_gz_22k(self,tts_text, prompt_text,prompt_speech_22k, stream=False, speed=1.0, text_frontend=True, token_list=None):
|
105 |
+
prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend)
|
106 |
+
input_text = self.frontend.text_normalize(tts_text, split=False, text_frontend=text_frontend)
|
107 |
+
model_input = self.frontend.frontend_zero_shot_22k(input_text, prompt_text, prompt_speech_22k, self.sample_rate)
|
108 |
+
start_time = time.time()
|
109 |
+
model_output = self.model.tts_gxl(**model_input, stream=stream, speed=speed, token_list=token_list)
|
110 |
+
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
111 |
+
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
112 |
+
return model_output
|
113 |
+
|
114 |
+
|
115 |
+
def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
|
116 |
+
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
117 |
+
model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate)
|
118 |
+
start_time = time.time()
|
119 |
+
logging.info('synthesis text {}'.format(i))
|
120 |
+
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
121 |
+
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
122 |
+
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
123 |
+
yield model_output
|
124 |
+
start_time = time.time()
|
125 |
+
|
126 |
+
def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0, text_frontend=True):
|
127 |
+
assert isinstance(self.model, CosyVoiceModel), 'inference_instruct is only implemented for CosyVoice!'
|
128 |
+
if self.instruct is False:
|
129 |
+
raise ValueError('{} do not support instruct inference'.format(self.model_dir))
|
130 |
+
instruct_text = self.frontend.text_normalize(instruct_text, split=False, text_frontend=text_frontend)
|
131 |
+
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
132 |
+
model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
|
133 |
+
start_time = time.time()
|
134 |
+
logging.info('synthesis text {}'.format(i))
|
135 |
+
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
136 |
+
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
137 |
+
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
138 |
+
yield model_output
|
139 |
+
start_time = time.time()
|
140 |
+
|
141 |
+
def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
|
142 |
+
model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate)
|
143 |
+
start_time = time.time()
|
144 |
+
for model_output in self.model.vc(**model_input, stream=stream, speed=speed):
|
145 |
+
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
146 |
+
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
147 |
+
yield model_output
|
148 |
+
start_time = time.time()
|
149 |
+
|
150 |
+
|
151 |
+
class CosyVoice2(CosyVoice):
|
152 |
+
|
153 |
+
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
|
154 |
+
self.instruct = True if '-Instruct' in model_dir else False
|
155 |
+
self.model_dir = model_dir
|
156 |
+
self.fp16 = fp16
|
157 |
+
if not os.path.exists(model_dir):
|
158 |
+
model_dir = snapshot_download(model_dir)
|
159 |
+
with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
|
160 |
+
configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
|
161 |
+
assert get_model_type(configs) == CosyVoice2Model, 'do not use {} for CosyVoice2 initialization!'.format(model_dir)
|
162 |
+
self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
|
163 |
+
configs['feat_extractor'],
|
164 |
+
'{}/campplus.onnx'.format(model_dir),
|
165 |
+
'{}/speech_tokenizer_v2.onnx'.format(model_dir),
|
166 |
+
'{}/spk2info.pt'.format(model_dir),
|
167 |
+
configs['allowed_special'])
|
168 |
+
self.sample_rate = configs['sample_rate']
|
169 |
+
if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
|
170 |
+
load_jit, load_trt, fp16 = False, False, False
|
171 |
+
logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
|
172 |
+
self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
|
173 |
+
self.model.load('{}/llm.pt'.format(model_dir),
|
174 |
+
'{}/flow.pt'.format(model_dir),
|
175 |
+
'{}/hift.pt'.format(model_dir))
|
176 |
+
if load_jit:
|
177 |
+
self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
|
178 |
+
if load_trt:
|
179 |
+
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
180 |
+
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
|
181 |
+
self.fp16)
|
182 |
+
del configs
|
183 |
+
|
184 |
+
def inference_instruct(self, *args, **kwargs):
|
185 |
+
raise NotImplementedError('inference_instruct is not implemented for CosyVoice2!')
|
186 |
+
|
187 |
+
def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
|
188 |
+
assert isinstance(self.model, CosyVoice2Model), 'inference_instruct2 is only implemented for CosyVoice2!'
|
189 |
+
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
190 |
+
model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate)
|
191 |
+
start_time = time.time()
|
192 |
+
logging.info('synthesis text {}'.format(i))
|
193 |
+
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
194 |
+
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
195 |
+
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
196 |
+
yield model_output
|
197 |
+
start_time = time.time()
|
tts/cosyvoice/cli/frontend.py
ADDED
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from functools import partial
|
15 |
+
from typing import Generator
|
16 |
+
import json
|
17 |
+
import onnxruntime
|
18 |
+
import torch
|
19 |
+
import numpy as np
|
20 |
+
import whisper
|
21 |
+
from typing import Callable
|
22 |
+
import torchaudio.compliance.kaldi as kaldi
|
23 |
+
import torchaudio
|
24 |
+
import os
|
25 |
+
import re
|
26 |
+
import inflect
|
27 |
+
try:
|
28 |
+
import ttsfrd
|
29 |
+
use_ttsfrd = True
|
30 |
+
except ImportError:
|
31 |
+
print("failed to import ttsfrd, use WeTextProcessing instead")
|
32 |
+
# from tn.chinese.normalizer import Normalizer as ZhNormalizer
|
33 |
+
# from tn.english.normalizer import Normalizer as EnNormalizer
|
34 |
+
use_ttsfrd = False
|
35 |
+
from cosyvoice.utils.file_utils import logging
|
36 |
+
from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation
|
37 |
+
is_npu = True
|
38 |
+
try:
|
39 |
+
import torch_npu
|
40 |
+
except ImportError:
|
41 |
+
is_npu = False
|
42 |
+
print("failed to import torch_npu")
|
43 |
+
|
44 |
+
class CosyVoiceFrontEnd:
|
45 |
+
|
46 |
+
def __init__(self,
|
47 |
+
get_tokenizer: Callable,
|
48 |
+
feat_extractor: Callable,
|
49 |
+
campplus_model: str,
|
50 |
+
speech_tokenizer_model: str,
|
51 |
+
spk2info: str = '',
|
52 |
+
allowed_special: str = 'all',
|
53 |
+
gpu_id: int = 0):
|
54 |
+
self.tokenizer = get_tokenizer()
|
55 |
+
self.feat_extractor = feat_extractor
|
56 |
+
if is_npu:
|
57 |
+
self.device = torch.device(f'npu:{gpu_id}')
|
58 |
+
else:
|
59 |
+
self.device = torch.device(f'cuda:{gpu_id}')
|
60 |
+
option = onnxruntime.SessionOptions()
|
61 |
+
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
62 |
+
option.intra_op_num_threads = 1
|
63 |
+
self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
|
64 |
+
self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
|
65 |
+
providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
|
66 |
+
"CPUExecutionProvider"])
|
67 |
+
if os.path.exists(spk2info):
|
68 |
+
self.spk2info = torch.load(spk2info, map_location=self.device)
|
69 |
+
else:
|
70 |
+
self.spk2info = {}
|
71 |
+
self.allowed_special = allowed_special
|
72 |
+
self.use_ttsfrd = use_ttsfrd
|
73 |
+
if self.use_ttsfrd:
|
74 |
+
self.frd = ttsfrd.TtsFrontendEngine()
|
75 |
+
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
76 |
+
assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
|
77 |
+
'failed to initialize ttsfrd resource'
|
78 |
+
self.frd.set_lang_type('pinyinvg')
|
79 |
+
else:
|
80 |
+
self.zh_tn_model = lambda x: x #ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True)
|
81 |
+
self.en_tn_model = lambda x: x #EnNormalizer()
|
82 |
+
self.inflect_parser = inflect.engine()
|
83 |
+
|
84 |
+
def _extract_text_token(self, text):
|
85 |
+
if isinstance(text, Generator):
|
86 |
+
logging.info('get tts_text generator, will return _extract_text_token_generator!')
|
87 |
+
# NOTE add a dummy text_token_len for compatibility
|
88 |
+
return self._extract_text_token_generator(text), torch.tensor([0], dtype=torch.int32).to(self.device)
|
89 |
+
else:
|
90 |
+
text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
|
91 |
+
text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
|
92 |
+
text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
|
93 |
+
return text_token, text_token_len
|
94 |
+
|
95 |
+
def _extract_text_token_generator(self, text_generator):
|
96 |
+
for text in text_generator:
|
97 |
+
text_token, _ = self._extract_text_token(text)
|
98 |
+
for i in range(text_token.shape[1]):
|
99 |
+
yield text_token[:, i: i + 1]
|
100 |
+
|
101 |
+
def _extract_speech_token(self, speech):
|
102 |
+
assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
|
103 |
+
feat = whisper.log_mel_spectrogram(speech, n_mels=128)
|
104 |
+
speech_token = self.speech_tokenizer_session.run(None,
|
105 |
+
{self.speech_tokenizer_session.get_inputs()[0].name:
|
106 |
+
feat.detach().cpu().numpy(),
|
107 |
+
self.speech_tokenizer_session.get_inputs()[1].name:
|
108 |
+
np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
|
109 |
+
speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
|
110 |
+
speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
|
111 |
+
return speech_token, speech_token_len
|
112 |
+
|
113 |
+
def _extract_spk_embedding(self, speech):
|
114 |
+
feat = kaldi.fbank(speech,
|
115 |
+
num_mel_bins=80,
|
116 |
+
dither=0,
|
117 |
+
sample_frequency=16000)
|
118 |
+
feat = feat - feat.mean(dim=0, keepdim=True)
|
119 |
+
embedding = self.campplus_session.run(None,
|
120 |
+
{self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
|
121 |
+
embedding = torch.tensor([embedding]).to(self.device)
|
122 |
+
return embedding
|
123 |
+
|
124 |
+
def _extract_speech_feat(self, speech):
|
125 |
+
speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
|
126 |
+
speech_feat = speech_feat.unsqueeze(dim=0)
|
127 |
+
speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
|
128 |
+
return speech_feat, speech_feat_len
|
129 |
+
|
130 |
+
def text_normalize(self, text, split=True, text_frontend=True):
|
131 |
+
if isinstance(text, Generator):
|
132 |
+
logging.info('get tts_text generator, will skip text_normalize!')
|
133 |
+
return [text]
|
134 |
+
if text_frontend is False:
|
135 |
+
return [text] if split is True else text
|
136 |
+
text = text.strip()
|
137 |
+
if self.use_ttsfrd:
|
138 |
+
texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
|
139 |
+
text = ''.join(texts)
|
140 |
+
else:
|
141 |
+
if contains_chinese(text):
|
142 |
+
# text = self.zh_tn_model.normalize(text)
|
143 |
+
text = text.replace("\n", "")
|
144 |
+
text = replace_blank(text)
|
145 |
+
text = replace_corner_mark(text)
|
146 |
+
text = text.replace(".", "。")
|
147 |
+
text = text.replace(" - ", ",")
|
148 |
+
text = remove_bracket(text)
|
149 |
+
text = re.sub(r'[,,、]+$', '。', text)
|
150 |
+
texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
|
151 |
+
token_min_n=60, merge_len=20, comma_split=False))
|
152 |
+
else:
|
153 |
+
# text = self.en_tn_model.normalize(text)
|
154 |
+
text = spell_out_number(text, self.inflect_parser)
|
155 |
+
texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
|
156 |
+
token_min_n=60, merge_len=20, comma_split=False))
|
157 |
+
texts = [i for i in texts if not is_only_punctuation(i)]
|
158 |
+
return texts if split is True else text
|
159 |
+
|
160 |
+
def frontend_sft(self, tts_text, spk_id):
|
161 |
+
tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
|
162 |
+
embedding = self.spk2info[spk_id]['embedding']
|
163 |
+
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
|
164 |
+
return model_input
|
165 |
+
|
166 |
+
def frontend_zero_shot_22k(self, tts_text, prompt_text, prompt_speech_22k, resample_rate=16000):
|
167 |
+
tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
|
168 |
+
prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
|
169 |
+
prompt_speech_16k = torchaudio.transforms.Resample(orig_freq=22050, new_freq=16000)(prompt_speech_22k)
|
170 |
+
speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_22k)
|
171 |
+
speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
|
172 |
+
# if resample_rate == 16000:
|
173 |
+
# # cosyvoice2, force speech_feat % speech_token = 2
|
174 |
+
# token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
|
175 |
+
# speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
|
176 |
+
# speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
|
177 |
+
embedding = self._extract_spk_embedding(prompt_speech_16k)
|
178 |
+
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
|
179 |
+
'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
|
180 |
+
'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
|
181 |
+
'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
|
182 |
+
'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
|
183 |
+
'llm_embedding': embedding, 'flow_embedding': embedding}
|
184 |
+
return model_input
|
185 |
+
|
186 |
+
def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, resample_rate):
|
187 |
+
tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
|
188 |
+
prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
|
189 |
+
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
|
190 |
+
speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
|
191 |
+
speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
|
192 |
+
if resample_rate == 24000:
|
193 |
+
# cosyvoice2, force speech_feat % speech_token = 2
|
194 |
+
token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
|
195 |
+
speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
|
196 |
+
speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
|
197 |
+
embedding = self._extract_spk_embedding(prompt_speech_16k)
|
198 |
+
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
|
199 |
+
'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
|
200 |
+
'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
|
201 |
+
'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
|
202 |
+
'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
|
203 |
+
'llm_embedding': embedding, 'flow_embedding': embedding}
|
204 |
+
return model_input
|
205 |
+
|
206 |
+
def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate):
|
207 |
+
model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate)
|
208 |
+
# in cross lingual mode, we remove prompt in llm
|
209 |
+
del model_input['prompt_text']
|
210 |
+
del model_input['prompt_text_len']
|
211 |
+
del model_input['llm_prompt_speech_token']
|
212 |
+
del model_input['llm_prompt_speech_token_len']
|
213 |
+
return model_input
|
214 |
+
|
215 |
+
def frontend_instruct(self, tts_text, spk_id, instruct_text):
|
216 |
+
model_input = self.frontend_sft(tts_text, spk_id)
|
217 |
+
# in instruct mode, we remove spk_embedding in llm due to information leakage
|
218 |
+
del model_input['llm_embedding']
|
219 |
+
instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text + '<endofprompt>')
|
220 |
+
model_input['prompt_text'] = instruct_text_token
|
221 |
+
model_input['prompt_text_len'] = instruct_text_token_len
|
222 |
+
return model_input
|
223 |
+
|
224 |
+
def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate):
|
225 |
+
model_input = self.frontend_zero_shot(tts_text, instruct_text + '<|endofprompt|>', prompt_speech_16k, resample_rate)
|
226 |
+
del model_input['llm_prompt_speech_token']
|
227 |
+
del model_input['llm_prompt_speech_token_len']
|
228 |
+
return model_input
|
229 |
+
|
230 |
+
def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate):
|
231 |
+
prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
|
232 |
+
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
|
233 |
+
prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
|
234 |
+
embedding = self._extract_spk_embedding(prompt_speech_16k)
|
235 |
+
source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
|
236 |
+
model_input = {'source_speech_token': source_speech_token, 'source_speech_token_len': source_speech_token_len,
|
237 |
+
'flow_prompt_speech_token': prompt_speech_token, 'flow_prompt_speech_token_len': prompt_speech_token_len,
|
238 |
+
'prompt_speech_feat': prompt_speech_feat, 'prompt_speech_feat_len': prompt_speech_feat_len,
|
239 |
+
'flow_embedding': embedding}
|
240 |
+
return model_input
|
tts/cosyvoice/cli/model.py
ADDED
@@ -0,0 +1,480 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import os
|
15 |
+
from typing import Generator
|
16 |
+
import torch
|
17 |
+
import numpy as np
|
18 |
+
import threading
|
19 |
+
import time
|
20 |
+
from torch.nn import functional as F
|
21 |
+
from contextlib import nullcontext
|
22 |
+
import uuid
|
23 |
+
from cosyvoice.utils.common import fade_in_out
|
24 |
+
from cosyvoice.utils.file_utils import convert_onnx_to_trt
|
25 |
+
is_npu = True
|
26 |
+
try:
|
27 |
+
import torch_npu
|
28 |
+
except ImportError:
|
29 |
+
is_npu = False
|
30 |
+
print(f'torch_npu not found, set is_npu to False')
|
31 |
+
|
32 |
+
class CosyVoiceModel:
|
33 |
+
|
34 |
+
def __init__(self,
|
35 |
+
llm: torch.nn.Module,
|
36 |
+
flow: torch.nn.Module,
|
37 |
+
hift: torch.nn.Module,
|
38 |
+
fp16: bool,
|
39 |
+
gpu_id: int = 0):
|
40 |
+
if is_npu:
|
41 |
+
self.device = torch.device(f'npu:{gpu_id}')
|
42 |
+
else:
|
43 |
+
self.device = torch.device(f'cuda:{gpu_id}')
|
44 |
+
self.llm = llm
|
45 |
+
self.flow = flow
|
46 |
+
self.hift = hift
|
47 |
+
self.fp16 = fp16
|
48 |
+
self.llm.fp16 = fp16
|
49 |
+
self.flow.fp16 = fp16
|
50 |
+
if self.fp16 is True:
|
51 |
+
self.llm.half()
|
52 |
+
self.flow.half()
|
53 |
+
self.token_min_hop_len = 2 * self.flow.input_frame_rate
|
54 |
+
self.token_max_hop_len = 4 * self.flow.input_frame_rate
|
55 |
+
self.token_overlap_len = 20
|
56 |
+
# here we fix set flow.decoder.estimator.static_chunk_size = 0 for compatibability
|
57 |
+
self.flow.decoder.estimator.static_chunk_size = 0
|
58 |
+
# mel fade in out
|
59 |
+
self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
|
60 |
+
self.mel_window = np.hamming(2 * self.mel_overlap_len)
|
61 |
+
# hift cache
|
62 |
+
self.mel_cache_len = 20
|
63 |
+
self.source_cache_len = int(self.mel_cache_len * 256)
|
64 |
+
# speech fade in out
|
65 |
+
self.speech_window = np.hamming(2 * self.source_cache_len)
|
66 |
+
# rtf and decoding related
|
67 |
+
self.stream_scale_factor = 1
|
68 |
+
assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
|
69 |
+
self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
|
70 |
+
self.lock = threading.Lock()
|
71 |
+
# dict used to store session related variable
|
72 |
+
self.tts_speech_token_dict = {}
|
73 |
+
self.llm_end_dict = {}
|
74 |
+
self.mel_overlap_dict = {}
|
75 |
+
self.flow_cache_dict = {}
|
76 |
+
self.hift_cache_dict = {}
|
77 |
+
|
78 |
+
def load(self, llm_model, flow_model, hift_model):
|
79 |
+
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
|
80 |
+
self.llm.to(self.device).eval()
|
81 |
+
self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
|
82 |
+
self.flow.to(self.device).eval()
|
83 |
+
# in case hift_model is a hifigan model
|
84 |
+
hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
|
85 |
+
self.hift.load_state_dict(hift_state_dict, strict=True)
|
86 |
+
self.hift.to(self.device).eval()
|
87 |
+
|
88 |
+
def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
|
89 |
+
llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
|
90 |
+
self.llm.text_encoder = llm_text_encoder
|
91 |
+
llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
|
92 |
+
self.llm.llm = llm_llm
|
93 |
+
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
|
94 |
+
self.flow.encoder = flow_encoder
|
95 |
+
|
96 |
+
def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16):
|
97 |
+
assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
|
98 |
+
if not os.path.exists(flow_decoder_estimator_model):
|
99 |
+
convert_onnx_to_trt(flow_decoder_estimator_model, flow_decoder_onnx_model, fp16)
|
100 |
+
if os.path.getsize(flow_decoder_estimator_model) == 0:
|
101 |
+
raise ValueError('{} is empty file, delete it and export again!'.format(flow_decoder_estimator_model))
|
102 |
+
del self.flow.decoder.estimator
|
103 |
+
import tensorrt as trt
|
104 |
+
with open(flow_decoder_estimator_model, 'rb') as f:
|
105 |
+
self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
|
106 |
+
if self.flow.decoder.estimator_engine is None:
|
107 |
+
raise ValueError('failed to load trt {}'.format(flow_decoder_estimator_model))
|
108 |
+
self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()
|
109 |
+
|
110 |
+
def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
|
111 |
+
with self.llm_context:
|
112 |
+
if isinstance(text, Generator):
|
113 |
+
assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!'
|
114 |
+
for i in self.llm.inference_bistream(text=text,
|
115 |
+
prompt_text=prompt_text.to(self.device),
|
116 |
+
prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
|
117 |
+
prompt_speech_token=llm_prompt_speech_token.to(self.device),
|
118 |
+
prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
|
119 |
+
embedding=llm_embedding.to(self.device)):
|
120 |
+
self.tts_speech_token_dict[uuid].append(i)
|
121 |
+
else:
|
122 |
+
for i in self.llm.inference(text=text.to(self.device),
|
123 |
+
text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
|
124 |
+
prompt_text=prompt_text.to(self.device),
|
125 |
+
prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
|
126 |
+
prompt_speech_token=llm_prompt_speech_token.to(self.device),
|
127 |
+
prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
|
128 |
+
embedding=llm_embedding.to(self.device)):
|
129 |
+
self.tts_speech_token_dict[uuid].append(i)
|
130 |
+
self.llm_end_dict[uuid] = True
|
131 |
+
|
132 |
+
def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
|
133 |
+
tts_mel, flow_cache = self.flow.inference(token=token.to(self.device),
|
134 |
+
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
|
135 |
+
prompt_token=prompt_token.to(self.device),
|
136 |
+
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
|
137 |
+
prompt_feat=prompt_feat.to(self.device),
|
138 |
+
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
|
139 |
+
embedding=embedding.to(self.device),
|
140 |
+
flow_cache=self.flow_cache_dict[uuid])
|
141 |
+
self.flow_cache_dict[uuid] = flow_cache
|
142 |
+
|
143 |
+
# mel overlap fade in out
|
144 |
+
if self.mel_overlap_dict[uuid].shape[2] != 0:
|
145 |
+
tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
|
146 |
+
# append hift cache
|
147 |
+
if self.hift_cache_dict[uuid] is not None:
|
148 |
+
hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
|
149 |
+
tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
|
150 |
+
else:
|
151 |
+
hift_cache_source = torch.zeros(1, 1, 0)
|
152 |
+
# keep overlap mel and hift cache
|
153 |
+
if finalize is False:
|
154 |
+
self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
|
155 |
+
tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
|
156 |
+
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
|
157 |
+
if self.hift_cache_dict[uuid] is not None:
|
158 |
+
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
|
159 |
+
self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
|
160 |
+
'source': tts_source[:, :, -self.source_cache_len:],
|
161 |
+
'speech': tts_speech[:, -self.source_cache_len:]}
|
162 |
+
tts_speech = tts_speech[:, :-self.source_cache_len]
|
163 |
+
else:
|
164 |
+
if speed != 1.0:
|
165 |
+
assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
|
166 |
+
tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
|
167 |
+
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
|
168 |
+
if self.hift_cache_dict[uuid] is not None:
|
169 |
+
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
|
170 |
+
return tts_speech
|
171 |
+
|
172 |
+
def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
|
173 |
+
prompt_text=torch.zeros(1, 0, dtype=torch.int32),
|
174 |
+
llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
|
175 |
+
flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
|
176 |
+
prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0,
|
177 |
+
token_list=None, **kwargs):
|
178 |
+
# this_uuid is used to track variables related to this inference thread
|
179 |
+
this_uuid = str(uuid.uuid1())
|
180 |
+
with self.lock:
|
181 |
+
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
|
182 |
+
self.hift_cache_dict[this_uuid] = None
|
183 |
+
self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
|
184 |
+
self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
|
185 |
+
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
|
186 |
+
p.start()
|
187 |
+
# import pdb;pdb.set_trace()
|
188 |
+
if stream is True:
|
189 |
+
token_hop_len = self.token_min_hop_len
|
190 |
+
while True:
|
191 |
+
time.sleep(0.1)
|
192 |
+
if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
|
193 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
|
194 |
+
.unsqueeze(dim=0)
|
195 |
+
# import pdb;pdb.set_trace()
|
196 |
+
gen_token = [1650, 2163, 3062, 41, 347, 754, 1705, 73, 38, 2583, 59, 1660, 1716, 28, 324, 1260, 1018, 254, 1650, 3552, 1804, 2515, 2368, 38, 1660, 3106, 848, 3250, 1611, 511, 1037, 2964, 2255, 1509, 890, 1494, 2250, 1349, 2621, 3420, 46, 2646, 2646, 3025, 2579, 393, 824, 1609, 2089, 2162, 24, 2, 3768, 1155, 343, 325, 2764, 814, 426, 1243, 2579, 3916, 20, 1611, 349, 701, 1346, 3768, 927, 3305, 8, 2099, 511, 3582, 8, 421, 1494, 2323, 2253, 3607, 692, 3929, 511, 3710, 3662, 3179, 1204, 7, 2579, 2579, 3025, 3025, 571, 540, 1509, 2786, 2548, 1404, 699, 1260, 2250, 202, 202, 84, 3458, 73, 3458, 1716, 302, 2105, 193, 974, 3761, 2893, 2250, 193, 754, 69, 69, 599, 2554, 890, 1608, 148, 1243, 480, 1, 489, 271, 1038, 1736, 1865, 3337, 569, 28, 2246, 2426, 2250, 3768, 569, 1027, 3305, 3106, 8, 3635, 269, 1854, 70, 1385, 1584, 1385, 2187, 3064, 3064, 2579, 3025, 3337, 2579, 3768]
|
197 |
+
token_list = [66, 2307, 599, 1602, 714, 1100, 1243, 2657, 349, 535, 3662, 1403, 2610, 669, 569, 49, 48, 1027, 2684, 373, 728, 728, 186, 186, 7, 2250, 754, 1346, 1289, 2691, 3740, 3082, 629, 2841, 432, 1513, 1716, 302, 3607, 3607, 692, 1609, 2579, 3025, 2513, 2513, 1043, 1043, 2704, 53, 2893, 1043, 2704, 1043, 2513, 2513, 1043, 1083, 3600, 421, 8, 8, 1256, 1243, 3278, 2932, 510, 2515, 2582, 1906, 4056, 1346, 1241, 2253, 1346, 1698, 962, 409, 1507, 1377, 2162, 10, 21, 396, 3649, 373, 728, 2513, 2513, 2513, 2513, 1865, 1893, 1712, 375, 4064, 3062, 41, 569, 3887, 1716, 472, 3830, 186, 408, 203, 3478, 3340, 800, 1243, 480, 271, 2162, 3240, 3238, 3193, 599, 2391, 1317, 1346, 269, 2253, 2209, 8, 1974, 2764, 1579, 421, 1073, 3929, 590, 31, 3898, 53, 53, 1043, 1957]
|
198 |
+
this_tts_speech_token = np.array(token_list)
|
199 |
+
this_tts_speech_token = torch.tensor(this_tts_speech_token)
|
200 |
+
# this_tts_speech_token = np.load("/home/node57_data/hkxie/4O/streaming_fm/data/s3token1/05343304771_EIjYa_VAD27_3.hubert_code.npy")
|
201 |
+
# this_tts_speech_token = torch.tensor(this_tts_speech_token)
|
202 |
+
|
203 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
204 |
+
prompt_token=flow_prompt_speech_token,
|
205 |
+
prompt_feat=prompt_speech_feat,
|
206 |
+
embedding=flow_embedding,
|
207 |
+
uuid=this_uuid,
|
208 |
+
finalize=False)
|
209 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
210 |
+
with self.lock:
|
211 |
+
self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
|
212 |
+
# increase token_hop_len for better speech quality
|
213 |
+
token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
|
214 |
+
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
|
215 |
+
break
|
216 |
+
p.join()
|
217 |
+
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
|
218 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
219 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
220 |
+
prompt_token=flow_prompt_speech_token,
|
221 |
+
prompt_feat=prompt_speech_feat,
|
222 |
+
embedding=flow_embedding,
|
223 |
+
uuid=this_uuid,
|
224 |
+
finalize=True)
|
225 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
226 |
+
else:
|
227 |
+
# deal with all tokens
|
228 |
+
p.join()
|
229 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
230 |
+
this_tts_speech_token = np.array(token_list)
|
231 |
+
this_tts_speech_token = torch.tensor(this_tts_speech_token)
|
232 |
+
this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0)
|
233 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
234 |
+
prompt_token=flow_prompt_speech_token,
|
235 |
+
prompt_feat=prompt_speech_feat,
|
236 |
+
embedding=flow_embedding,
|
237 |
+
uuid=this_uuid,
|
238 |
+
finalize=True,
|
239 |
+
speed=speed)
|
240 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
241 |
+
with self.lock:
|
242 |
+
self.tts_speech_token_dict.pop(this_uuid)
|
243 |
+
self.llm_end_dict.pop(this_uuid)
|
244 |
+
self.mel_overlap_dict.pop(this_uuid)
|
245 |
+
self.hift_cache_dict.pop(this_uuid)
|
246 |
+
self.flow_cache_dict.pop(this_uuid)
|
247 |
+
torch.cuda.empty_cache()
|
248 |
+
|
249 |
+
def tts_gxl(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
|
250 |
+
prompt_text=torch.zeros(1, 0, dtype=torch.int32),
|
251 |
+
llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
|
252 |
+
flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
|
253 |
+
prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0,
|
254 |
+
token_list=None, **kwargs):
|
255 |
+
# this_uuid is used to track variables related to this inference thread
|
256 |
+
this_uuid = str(uuid.uuid1())
|
257 |
+
with self.lock:
|
258 |
+
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
|
259 |
+
self.hift_cache_dict[this_uuid] = None
|
260 |
+
self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
|
261 |
+
self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
|
262 |
+
# p = threading.Thread(target=self.llm_job,
|
263 |
+
# args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
|
264 |
+
# p.start()
|
265 |
+
# p.join()
|
266 |
+
# this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
267 |
+
this_tts_speech_token = np.array(token_list)
|
268 |
+
this_tts_speech_token = torch.tensor(this_tts_speech_token)
|
269 |
+
this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0)
|
270 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
271 |
+
prompt_token=flow_prompt_speech_token,
|
272 |
+
prompt_feat=prompt_speech_feat,
|
273 |
+
embedding=flow_embedding,
|
274 |
+
uuid=this_uuid,
|
275 |
+
finalize=True,
|
276 |
+
speed=speed)
|
277 |
+
torch.cuda.empty_cache()
|
278 |
+
with self.lock:
|
279 |
+
self.tts_speech_token_dict.pop(this_uuid)
|
280 |
+
self.llm_end_dict.pop(this_uuid)
|
281 |
+
self.mel_overlap_dict.pop(this_uuid)
|
282 |
+
self.hift_cache_dict.pop(this_uuid)
|
283 |
+
self.flow_cache_dict.pop(this_uuid)
|
284 |
+
return {'tts_speech': this_tts_speech.cpu()}
|
285 |
+
|
286 |
+
def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs):
|
287 |
+
# this_uuid is used to track variables related to this inference thread
|
288 |
+
this_uuid = str(uuid.uuid1())
|
289 |
+
with self.lock:
|
290 |
+
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True
|
291 |
+
self.hift_cache_dict[this_uuid] = None
|
292 |
+
self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
|
293 |
+
self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
|
294 |
+
if stream is True:
|
295 |
+
token_hop_len = self.token_min_hop_len
|
296 |
+
while True:
|
297 |
+
if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
|
298 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
|
299 |
+
.unsqueeze(dim=0)
|
300 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
301 |
+
prompt_token=flow_prompt_speech_token,
|
302 |
+
prompt_feat=prompt_speech_feat,
|
303 |
+
embedding=flow_embedding,
|
304 |
+
uuid=this_uuid,
|
305 |
+
finalize=False)
|
306 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
307 |
+
with self.lock:
|
308 |
+
self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
|
309 |
+
# increase token_hop_len for better speech quality
|
310 |
+
token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
|
311 |
+
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
|
312 |
+
break
|
313 |
+
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
|
314 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
315 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
316 |
+
prompt_token=flow_prompt_speech_token,
|
317 |
+
prompt_feat=prompt_speech_feat,
|
318 |
+
embedding=flow_embedding,
|
319 |
+
uuid=this_uuid,
|
320 |
+
finalize=True)
|
321 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
322 |
+
else:
|
323 |
+
# deal with all tokens
|
324 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
325 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
326 |
+
prompt_token=flow_prompt_speech_token,
|
327 |
+
prompt_feat=prompt_speech_feat,
|
328 |
+
embedding=flow_embedding,
|
329 |
+
uuid=this_uuid,
|
330 |
+
finalize=True,
|
331 |
+
speed=speed)
|
332 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
333 |
+
with self.lock:
|
334 |
+
self.tts_speech_token_dict.pop(this_uuid)
|
335 |
+
self.llm_end_dict.pop(this_uuid)
|
336 |
+
self.mel_overlap_dict.pop(this_uuid)
|
337 |
+
self.hift_cache_dict.pop(this_uuid)
|
338 |
+
torch.cuda.empty_cache()
|
339 |
+
|
340 |
+
|
341 |
+
class CosyVoice2Model(CosyVoiceModel):
|
342 |
+
|
343 |
+
def __init__(self,
|
344 |
+
llm: torch.nn.Module,
|
345 |
+
flow: torch.nn.Module,
|
346 |
+
hift: torch.nn.Module,
|
347 |
+
fp16: bool):
|
348 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
349 |
+
self.llm = llm
|
350 |
+
self.flow = flow
|
351 |
+
self.hift = hift
|
352 |
+
self.fp16 = fp16
|
353 |
+
self.llm.fp16 = fp16
|
354 |
+
self.flow.fp16 = fp16
|
355 |
+
if self.fp16 is True:
|
356 |
+
self.llm.half()
|
357 |
+
self.flow.half()
|
358 |
+
self.token_hop_len = 2 * self.flow.input_frame_rate
|
359 |
+
# here we fix flow encoder/decoder decoding_chunk_size, in the future we will send it as arguments, or use cache
|
360 |
+
self.flow.encoder.static_chunk_size = 2 * self.flow.input_frame_rate
|
361 |
+
self.flow.decoder.estimator.static_chunk_size = 2 * self.flow.input_frame_rate * self.flow.token_mel_ratio
|
362 |
+
# hift cache
|
363 |
+
self.mel_cache_len = 8
|
364 |
+
self.source_cache_len = int(self.mel_cache_len * 480)
|
365 |
+
# speech fade in out
|
366 |
+
self.speech_window = np.hamming(2 * self.source_cache_len)
|
367 |
+
# rtf and decoding related
|
368 |
+
self.stream_scale_factor = 1
|
369 |
+
self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
|
370 |
+
self.lock = threading.Lock()
|
371 |
+
# dict used to store session related variable
|
372 |
+
self.tts_speech_token_dict = {}
|
373 |
+
self.llm_end_dict = {}
|
374 |
+
self.hift_cache_dict = {}
|
375 |
+
|
376 |
+
def load_jit(self, flow_encoder_model):
|
377 |
+
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
|
378 |
+
self.flow.encoder = flow_encoder
|
379 |
+
|
380 |
+
def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0):
|
381 |
+
tts_mel, _ = self.flow.inference(token=token.to(self.device),
|
382 |
+
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
|
383 |
+
prompt_token=prompt_token.to(self.device),
|
384 |
+
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
|
385 |
+
prompt_feat=prompt_feat.to(self.device),
|
386 |
+
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
|
387 |
+
embedding=embedding.to(self.device),
|
388 |
+
finalize=finalize)
|
389 |
+
tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
|
390 |
+
# append hift cache
|
391 |
+
if self.hift_cache_dict[uuid] is not None:
|
392 |
+
hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
|
393 |
+
tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
|
394 |
+
else:
|
395 |
+
hift_cache_source = torch.zeros(1, 1, 0)
|
396 |
+
# keep overlap mel and hift cache
|
397 |
+
if finalize is False:
|
398 |
+
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
|
399 |
+
if self.hift_cache_dict[uuid] is not None:
|
400 |
+
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
|
401 |
+
self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
|
402 |
+
'source': tts_source[:, :, -self.source_cache_len:],
|
403 |
+
'speech': tts_speech[:, -self.source_cache_len:]}
|
404 |
+
tts_speech = tts_speech[:, :-self.source_cache_len]
|
405 |
+
else:
|
406 |
+
if speed != 1.0:
|
407 |
+
assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
|
408 |
+
tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
|
409 |
+
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
|
410 |
+
if self.hift_cache_dict[uuid] is not None:
|
411 |
+
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
|
412 |
+
return tts_speech
|
413 |
+
|
414 |
+
def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
|
415 |
+
prompt_text=torch.zeros(1, 0, dtype=torch.int32),
|
416 |
+
llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
|
417 |
+
flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
|
418 |
+
prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
|
419 |
+
# this_uuid is used to track variables related to this inference thread
|
420 |
+
this_uuid = str(uuid.uuid1())
|
421 |
+
with self.lock:
|
422 |
+
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
|
423 |
+
self.hift_cache_dict[this_uuid] = None
|
424 |
+
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
|
425 |
+
p.start()
|
426 |
+
if stream is True:
|
427 |
+
token_offset = 0
|
428 |
+
while True:
|
429 |
+
time.sleep(0.1)
|
430 |
+
if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= self.token_hop_len + self.flow.pre_lookahead_len:
|
431 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
|
432 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
433 |
+
prompt_token=flow_prompt_speech_token,
|
434 |
+
prompt_feat=prompt_speech_feat,
|
435 |
+
embedding=flow_embedding,
|
436 |
+
uuid=this_uuid,
|
437 |
+
token_offset=token_offset,
|
438 |
+
finalize=False)
|
439 |
+
token_offset += self.token_hop_len
|
440 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
441 |
+
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < self.token_hop_len + self.flow.pre_lookahead_len:
|
442 |
+
break
|
443 |
+
p.join()
|
444 |
+
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
|
445 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
446 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
447 |
+
prompt_token=flow_prompt_speech_token,
|
448 |
+
prompt_feat=prompt_speech_feat,
|
449 |
+
embedding=flow_embedding,
|
450 |
+
uuid=this_uuid,
|
451 |
+
token_offset=token_offset,
|
452 |
+
finalize=True)
|
453 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
454 |
+
else:
|
455 |
+
# deal with all tokens
|
456 |
+
p.join()
|
457 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
458 |
+
# import pdb;pdb.set_trace()
|
459 |
+
# this_tts_speech_token = np.load("/home/node57_data/hkxie/4O/streaming_fm/data/s3token2/05343304771_EIjYa_VAD27_3.hubert_code.npy")
|
460 |
+
# this_tts_speech_token = np.load("/home/node57_data/hkxie/4O/streaming_fm/data/s3token2/05343304771_EIjYa_VAD41_6.hubert_code.npy")
|
461 |
+
# token2 = [2745, 860, 393, 393, 2579, 2926, 1842, 2136, 480, 205, 3910, 3251, 73, 42, 38, 1346, 2554, 368, 40, 1660, 1660, 1055, 2597, 1712, 28, 2246, 386, 122, 38, 3607, 3818, 1098, 980, 38, 1353, 1660, 426, 1694, 1406, 511, 511, 396, 671, 2571, 2809, 2385, 3947, 229, 2000, 773, 2786, 858, 2554, 701, 46, 2646, 1608, 2890, 393, 393, 393, 393, 393, 393, 393, 393, 393, 393, 393, 393, 3, 31, 758, 3438, 3438, 3438, 54, 269, 2246, 343, 1600, 1608, 3554, 3649, 60, 511, 701, 44, 3554, 3775, 20, 2099, 535, 2099, 3545, 3267, 1223, 1650, 3607, 3611, 2646, 3545, 3545, 802, 802, 393, 393, 393, 393, 393, 393, 393, 393, 393, 393, 393, 393, 393, 393, 393, 393, 3, 26, 1734, 571, 1240, 1509, 2786, 1509, 740, 890, 2426, 1241, 1241, 2399, 2, 3458, 2285, 25, 2105, 4082, 3761, 3121, 3121, 269, 4082, 1353, 2285, 463, 758, 1193, 421, 3662, 148, 1516, 101, 32, 615, 1660, 1038, 2597, 3554, 28, 2246, 2426, 1241, 22, 1406, 70, 2230, 2230, 3635, 302, 2537, 1385, 1385, 1385, 69, 754, 3489, 1055, 393, 393, 393, 393, 393, 393, 393, 393]
|
462 |
+
|
463 |
+
# token_list3 = [2745, 599, 3238, 2554, 84, 73, 42, 2582, 2583, 4082, 1660, 1584, 1469, 1712, 2243, 1260, 1688, 269, 409, 3552, 1584, 2646, 38, 2385, 1660, 1038, 1516, 85, 3250, 1611, 109, 3611, 2255, 3947, 229, 451, 2786, 1044, 2621, 4056, 2646, 2646, 2890, 31, 3898, 3898, 2893, 2893, 2893, 2893, 1043, 52, 52, 52, 52, 1504, 2307, 202, 229, 358, 358, 266, 2907, 1516, 2246, 343, 1030, 122, 2409, 1694, 1406, 511, 2209, 51, 927, 1185, 1256, 1879, 2890, 2858, 203, 2426, 2253, 69, 3011, 3611, 2515, 2646, 492, 3662, 1608, 7, 31, 1406, 1406, 2893, 1043, 728, 380, 380, 571, 2385, 229, 740, 3193, 358, 202, 3331, 2, 1796, 35, 2285, 1893, 1516, 329, 3761, 2859, 122, 1241, 329, 1906, 59, 460, 463, 2554, 740, 1608, 60, 1516, 101, 1, 489, 1038, 1038, 3337, 3768, 569, 32, 1494, 2250, 3768, 3649, 20, 351, 1404, 1193, 44, 59, 3607, 2174, 1584, 1584, 1584, 1655, 1736, 1043, 1043, 1469, 569, 28, 2000, 2426, 2250, 3768, 927, 3250, 8, 2099, 1716, 59, 792, 3106, 1385, 1385, 1385, 1385, 1385, 3947, 1507, 864, 52, 52, 52]
|
464 |
+
token_list3 = [997, 966, 3554, 1854, 714, 3761, 3741, 2426, 103, 103, 1260, 1260, 2306, 2306, 2307, 824, 792, 193, 1879, 3478, 48, 511, 3420, 1317, 1761, 599, 1002, 980, 2646, 2646, 2646, 2646, 2646, 3366, 1949, 575, 575, 26, 26, 29, 3929, 229, 3910, 568, 3265, 3768, 28, 2004, 3910, 568, 3265, 3062, 41, 927, 699, 304, 2859, 2537, 28, 3741, 2841, 1688, 3768, 28, 1155, 855, 1570, 1570, 1570, 1570, 1570, 2876, 2680, 3, 3, 3636, 1555, 2844, 409, 1040, 2515, 1640, 3121, 3153, 882, 2385, 1796, 1796, 1796, 2368, 1785, 49, 671, 3830, 3025, 2844, 2105, 1037, 1729, 2105, 3265, 103, 1346, 580, 3922, 2876, 42, 271, 59, 3106, 2680, 3830, 2704, 2105, 2815, 59, 1698, 1223, 1342, 3267, 2786, 2250, 2250, 2208, 3, 1446, 1446, 1446, 1446, 1446, 1446, 1446, 1446, 1446, 1446, 1688, 1688, 1446, 1446, 1688, 1688, 1688, 1688, 1688]
|
465 |
+
this_tts_speech_token = np.array(token_list3)
|
466 |
+
this_tts_speech_token = torch.tensor(this_tts_speech_token)
|
467 |
+
this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0)
|
468 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
469 |
+
prompt_token=flow_prompt_speech_token,
|
470 |
+
prompt_feat=prompt_speech_feat,
|
471 |
+
embedding=flow_embedding,
|
472 |
+
uuid=this_uuid,
|
473 |
+
token_offset=0,
|
474 |
+
finalize=True,
|
475 |
+
speed=speed)
|
476 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
477 |
+
with self.lock:
|
478 |
+
self.tts_speech_token_dict.pop(this_uuid)
|
479 |
+
self.llm_end_dict.pop(this_uuid)
|
480 |
+
torch.cuda.empty_cache()
|
tts/cosyvoice/dataset/__init__.py
ADDED
File without changes
|
tts/cosyvoice/dataset/dataset.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
|
2 |
+
# 2024 Alibaba Inc (authors: Xiang Lyu)
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import random
|
17 |
+
import json
|
18 |
+
import math
|
19 |
+
from functools import partial
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import torch.distributed as dist
|
23 |
+
from torch.utils.data import IterableDataset
|
24 |
+
from cosyvoice.utils.file_utils import read_lists, read_json_lists
|
25 |
+
|
26 |
+
|
27 |
+
class Processor(IterableDataset):
|
28 |
+
|
29 |
+
def __init__(self, source, f, *args, **kw):
|
30 |
+
assert callable(f)
|
31 |
+
self.source = source
|
32 |
+
self.f = f
|
33 |
+
self.args = args
|
34 |
+
self.kw = kw
|
35 |
+
|
36 |
+
def set_epoch(self, epoch):
|
37 |
+
self.source.set_epoch(epoch)
|
38 |
+
|
39 |
+
def __iter__(self):
|
40 |
+
""" Return an iterator over the source dataset processed by the
|
41 |
+
given processor.
|
42 |
+
"""
|
43 |
+
assert self.source is not None
|
44 |
+
assert callable(self.f)
|
45 |
+
return self.f(iter(self.source), *self.args, **self.kw)
|
46 |
+
|
47 |
+
def apply(self, f):
|
48 |
+
assert callable(f)
|
49 |
+
return Processor(self, f, *self.args, **self.kw)
|
50 |
+
|
51 |
+
|
52 |
+
class DistributedSampler:
|
53 |
+
|
54 |
+
def __init__(self, shuffle=True, partition=True):
|
55 |
+
self.epoch = -1
|
56 |
+
self.update()
|
57 |
+
self.shuffle = shuffle
|
58 |
+
self.partition = partition
|
59 |
+
|
60 |
+
def update(self):
|
61 |
+
assert dist.is_available()
|
62 |
+
if dist.is_initialized():
|
63 |
+
self.rank = dist.get_rank()
|
64 |
+
self.world_size = dist.get_world_size()
|
65 |
+
else:
|
66 |
+
self.rank = 0
|
67 |
+
self.world_size = 1
|
68 |
+
worker_info = torch.utils.data.get_worker_info()
|
69 |
+
if worker_info is None:
|
70 |
+
self.worker_id = 0
|
71 |
+
self.num_workers = 1
|
72 |
+
else:
|
73 |
+
self.worker_id = worker_info.id
|
74 |
+
self.num_workers = worker_info.num_workers
|
75 |
+
return dict(rank=self.rank,
|
76 |
+
world_size=self.world_size,
|
77 |
+
worker_id=self.worker_id,
|
78 |
+
num_workers=self.num_workers)
|
79 |
+
|
80 |
+
def set_epoch(self, epoch):
|
81 |
+
self.epoch = epoch
|
82 |
+
|
83 |
+
def sample(self, data):
|
84 |
+
""" Sample data according to rank/world_size/num_workers
|
85 |
+
|
86 |
+
Args:
|
87 |
+
data(List): input data list
|
88 |
+
|
89 |
+
Returns:
|
90 |
+
List: data list after sample
|
91 |
+
"""
|
92 |
+
data = list(range(len(data)))
|
93 |
+
# force datalist even
|
94 |
+
if self.partition:
|
95 |
+
if self.shuffle:
|
96 |
+
random.Random(self.epoch).shuffle(data)
|
97 |
+
if len(data) < self.world_size:
|
98 |
+
data = data * math.ceil(self.world_size / len(data))
|
99 |
+
data = data[:self.world_size]
|
100 |
+
data = data[self.rank::self.world_size]
|
101 |
+
if len(data) < self.num_workers:
|
102 |
+
data = data * math.ceil(self.num_workers / len(data))
|
103 |
+
data = data[:self.num_workers]
|
104 |
+
data = data[self.worker_id::self.num_workers]
|
105 |
+
return data
|
106 |
+
|
107 |
+
|
108 |
+
class DataList(IterableDataset):
|
109 |
+
|
110 |
+
def __init__(self, lists, shuffle=True, partition=True):
|
111 |
+
self.lists = lists
|
112 |
+
self.sampler = DistributedSampler(shuffle, partition)
|
113 |
+
|
114 |
+
def set_epoch(self, epoch):
|
115 |
+
self.sampler.set_epoch(epoch)
|
116 |
+
|
117 |
+
def __iter__(self):
|
118 |
+
sampler_info = self.sampler.update()
|
119 |
+
indexes = self.sampler.sample(self.lists)
|
120 |
+
for index in indexes:
|
121 |
+
data = dict(src=self.lists[index])
|
122 |
+
data.update(sampler_info)
|
123 |
+
yield data
|
124 |
+
|
125 |
+
|
126 |
+
def Dataset(data_list_file,
|
127 |
+
data_pipeline,
|
128 |
+
mode='train',
|
129 |
+
gan=False,
|
130 |
+
shuffle=True,
|
131 |
+
partition=True,
|
132 |
+
tts_file='',
|
133 |
+
prompt_utt2data=''):
|
134 |
+
""" Construct dataset from arguments
|
135 |
+
|
136 |
+
We have two shuffle stage in the Dataset. The first is global
|
137 |
+
shuffle at shard tar/raw file level. The second is global shuffle
|
138 |
+
at training samples level.
|
139 |
+
|
140 |
+
Args:
|
141 |
+
data_type(str): raw/shard
|
142 |
+
tokenizer (BaseTokenizer): tokenizer to tokenize
|
143 |
+
partition(bool): whether to do data partition in terms of rank
|
144 |
+
"""
|
145 |
+
assert mode in ['train', 'inference']
|
146 |
+
lists = read_lists(data_list_file)
|
147 |
+
if mode == 'inference':
|
148 |
+
with open(tts_file) as f:
|
149 |
+
tts_data = json.load(f)
|
150 |
+
utt2lists = read_json_lists(prompt_utt2data)
|
151 |
+
# filter unnecessary file in inference mode
|
152 |
+
lists = list({utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists})
|
153 |
+
dataset = DataList(lists,
|
154 |
+
shuffle=shuffle,
|
155 |
+
partition=partition)
|
156 |
+
if mode == 'inference':
|
157 |
+
# map partial arg to parquet_opener func in inference mode
|
158 |
+
data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data)
|
159 |
+
if gan is True:
|
160 |
+
# map partial arg to padding func in gan mode
|
161 |
+
data_pipeline[-1] = partial(data_pipeline[-1], gan=gan)
|
162 |
+
for func in data_pipeline:
|
163 |
+
dataset = Processor(dataset, func, mode=mode)
|
164 |
+
return dataset
|
tts/cosyvoice/dataset/processor.py
ADDED
@@ -0,0 +1,435 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import logging
|
15 |
+
import random
|
16 |
+
|
17 |
+
import pyarrow.parquet as pq
|
18 |
+
from io import BytesIO
|
19 |
+
import torch
|
20 |
+
import torchaudio
|
21 |
+
from torch.nn.utils.rnn import pad_sequence
|
22 |
+
import torch.nn.functional as F
|
23 |
+
import pyworld as pw
|
24 |
+
|
25 |
+
|
26 |
+
AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
|
27 |
+
|
28 |
+
|
29 |
+
def parquet_opener(data, mode='train', tts_data={}):
|
30 |
+
""" Give url or local file, return file descriptor
|
31 |
+
Inplace operation.
|
32 |
+
|
33 |
+
Args:
|
34 |
+
data(Iterable[str]): url or local file list
|
35 |
+
|
36 |
+
Returns:
|
37 |
+
Iterable[{src, stream}]
|
38 |
+
"""
|
39 |
+
for sample in data:
|
40 |
+
assert 'src' in sample
|
41 |
+
url = sample['src']
|
42 |
+
try:
|
43 |
+
for df in pq.ParquetFile(url).iter_batches(batch_size=64):
|
44 |
+
df = df.to_pandas()
|
45 |
+
for i in range(len(df)):
|
46 |
+
if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
|
47 |
+
continue
|
48 |
+
sample.update(dict(df.loc[i]))
|
49 |
+
if mode == 'train':
|
50 |
+
# NOTE do not return sample directly, must initialize a new dict
|
51 |
+
yield {**sample}
|
52 |
+
else:
|
53 |
+
for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
|
54 |
+
yield {**sample, 'tts_index': index, 'tts_text': text}
|
55 |
+
except Exception as ex:
|
56 |
+
logging.warning('Failed to open {}, ex info {}'.format(url, ex))
|
57 |
+
|
58 |
+
|
59 |
+
def filter(data,
|
60 |
+
max_length=10240,
|
61 |
+
min_length=10,
|
62 |
+
token_max_length=200,
|
63 |
+
token_min_length=1,
|
64 |
+
min_output_input_ratio=0.0005,
|
65 |
+
max_output_input_ratio=1,
|
66 |
+
mode='train'):
|
67 |
+
""" Filter sample according to feature and label length
|
68 |
+
Inplace operation.
|
69 |
+
|
70 |
+
Args::
|
71 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
72 |
+
max_length: drop utterance which is greater than max_length(10ms)
|
73 |
+
min_length: drop utterance which is less than min_length(10ms)
|
74 |
+
token_max_length: drop utterance which is greater than
|
75 |
+
token_max_length, especially when use char unit for
|
76 |
+
english modeling
|
77 |
+
token_min_length: drop utterance which is
|
78 |
+
less than token_max_length
|
79 |
+
min_output_input_ratio: minimal ration of
|
80 |
+
token_length / feats_length(10ms)
|
81 |
+
max_output_input_ratio: maximum ration of
|
82 |
+
token_length / feats_length(10ms)
|
83 |
+
|
84 |
+
Returns:
|
85 |
+
Iterable[{key, wav, label, sample_rate}]
|
86 |
+
"""
|
87 |
+
for sample in data:
|
88 |
+
sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
|
89 |
+
sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
|
90 |
+
del sample['audio_data']
|
91 |
+
# sample['wav'] is torch.Tensor, we have 100 frames every second
|
92 |
+
num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
|
93 |
+
if num_frames < min_length:
|
94 |
+
continue
|
95 |
+
if num_frames > max_length:
|
96 |
+
continue
|
97 |
+
if len(sample['text_token']) < token_min_length:
|
98 |
+
continue
|
99 |
+
if len(sample['text_token']) > token_max_length:
|
100 |
+
continue
|
101 |
+
if len(sample['speech_token']) == 0:
|
102 |
+
continue
|
103 |
+
if num_frames != 0:
|
104 |
+
if len(sample['text_token']) / num_frames < min_output_input_ratio:
|
105 |
+
continue
|
106 |
+
if len(sample['text_token']) / num_frames > max_output_input_ratio:
|
107 |
+
continue
|
108 |
+
yield sample
|
109 |
+
|
110 |
+
|
111 |
+
def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
|
112 |
+
""" Resample data.
|
113 |
+
Inplace operation.
|
114 |
+
|
115 |
+
Args:
|
116 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
117 |
+
resample_rate: target resample rate
|
118 |
+
|
119 |
+
Returns:
|
120 |
+
Iterable[{key, wav, label, sample_rate}]
|
121 |
+
"""
|
122 |
+
for sample in data:
|
123 |
+
assert 'sample_rate' in sample
|
124 |
+
assert 'speech' in sample
|
125 |
+
sample_rate = sample['sample_rate']
|
126 |
+
waveform = sample['speech']
|
127 |
+
if sample_rate != resample_rate:
|
128 |
+
if sample_rate < min_sample_rate:
|
129 |
+
continue
|
130 |
+
sample['sample_rate'] = resample_rate
|
131 |
+
sample['speech'] = torchaudio.transforms.Resample(
|
132 |
+
orig_freq=sample_rate, new_freq=resample_rate)(waveform)
|
133 |
+
max_val = sample['speech'].abs().max()
|
134 |
+
if max_val > 1:
|
135 |
+
sample['speech'] /= max_val
|
136 |
+
yield sample
|
137 |
+
|
138 |
+
|
139 |
+
def truncate(data, truncate_length=24576, mode='train'):
|
140 |
+
""" Truncate data.
|
141 |
+
|
142 |
+
Args:
|
143 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
144 |
+
truncate_length: truncate length
|
145 |
+
|
146 |
+
Returns:
|
147 |
+
Iterable[{key, wav, label, sample_rate}]
|
148 |
+
"""
|
149 |
+
for sample in data:
|
150 |
+
waveform = sample['speech']
|
151 |
+
if waveform.shape[1] > truncate_length:
|
152 |
+
start = random.randint(0, waveform.shape[1] - truncate_length)
|
153 |
+
waveform = waveform[:, start: start + truncate_length]
|
154 |
+
else:
|
155 |
+
waveform = torch.concat([waveform, torch.zeros(1, truncate_length - waveform.shape[1])], dim=1)
|
156 |
+
sample['speech'] = waveform
|
157 |
+
yield sample
|
158 |
+
|
159 |
+
|
160 |
+
def compute_fbank(data,
|
161 |
+
feat_extractor,
|
162 |
+
mode='train'):
|
163 |
+
""" Extract fbank
|
164 |
+
|
165 |
+
Args:
|
166 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
167 |
+
|
168 |
+
Returns:
|
169 |
+
Iterable[{key, feat, label}]
|
170 |
+
"""
|
171 |
+
for sample in data:
|
172 |
+
assert 'sample_rate' in sample
|
173 |
+
assert 'speech' in sample
|
174 |
+
assert 'utt' in sample
|
175 |
+
assert 'text_token' in sample
|
176 |
+
waveform = sample['speech']
|
177 |
+
mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
|
178 |
+
sample['speech_feat'] = mat
|
179 |
+
yield sample
|
180 |
+
|
181 |
+
|
182 |
+
def compute_f0(data, sample_rate, hop_size, mode='train'):
|
183 |
+
""" Extract f0
|
184 |
+
|
185 |
+
Args:
|
186 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
187 |
+
|
188 |
+
Returns:
|
189 |
+
Iterable[{key, feat, label}]
|
190 |
+
"""
|
191 |
+
frame_period = hop_size * 1000 / sample_rate
|
192 |
+
for sample in data:
|
193 |
+
assert 'sample_rate' in sample
|
194 |
+
assert 'speech' in sample
|
195 |
+
assert 'utt' in sample
|
196 |
+
assert 'text_token' in sample
|
197 |
+
waveform = sample['speech']
|
198 |
+
_f0, t = pw.harvest(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period)
|
199 |
+
if sum(_f0 != 0) < 5: # this happens when the algorithm fails
|
200 |
+
_f0, t = pw.dio(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period) # if harvest fails, try dio
|
201 |
+
f0 = pw.stonemask(waveform.squeeze(dim=0).numpy().astype('double'), _f0, t, sample_rate)
|
202 |
+
f0 = F.interpolate(torch.from_numpy(f0).view(1, 1, -1), size=sample['speech_feat'].shape[0], mode='linear').view(-1)
|
203 |
+
sample['pitch_feat'] = f0
|
204 |
+
yield sample
|
205 |
+
|
206 |
+
|
207 |
+
def parse_embedding(data, normalize, mode='train'):
|
208 |
+
""" Parse utt_embedding/spk_embedding
|
209 |
+
|
210 |
+
Args:
|
211 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
212 |
+
|
213 |
+
Returns:
|
214 |
+
Iterable[{key, feat, label}]
|
215 |
+
"""
|
216 |
+
for sample in data:
|
217 |
+
sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
|
218 |
+
sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
|
219 |
+
if normalize:
|
220 |
+
sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
|
221 |
+
sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
|
222 |
+
yield sample
|
223 |
+
|
224 |
+
|
225 |
+
def tokenize(data, get_tokenizer, allowed_special, mode='train'):
|
226 |
+
""" Decode text to chars or BPE
|
227 |
+
Inplace operation
|
228 |
+
|
229 |
+
Args:
|
230 |
+
data: Iterable[{key, wav, txt, sample_rate}]
|
231 |
+
|
232 |
+
Returns:
|
233 |
+
Iterable[{key, wav, txt, tokens, label, sample_rate}]
|
234 |
+
"""
|
235 |
+
tokenizer = get_tokenizer()
|
236 |
+
for sample in data:
|
237 |
+
assert 'text' in sample
|
238 |
+
sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
|
239 |
+
if mode == 'inference':
|
240 |
+
sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special)
|
241 |
+
yield sample
|
242 |
+
|
243 |
+
|
244 |
+
def shuffle(data, shuffle_size=10000, mode='train'):
|
245 |
+
""" Local shuffle the data
|
246 |
+
|
247 |
+
Args:
|
248 |
+
data: Iterable[{key, feat, label}]
|
249 |
+
shuffle_size: buffer size for shuffle
|
250 |
+
|
251 |
+
Returns:
|
252 |
+
Iterable[{key, feat, label}]
|
253 |
+
"""
|
254 |
+
buf = []
|
255 |
+
for sample in data:
|
256 |
+
buf.append(sample)
|
257 |
+
if len(buf) >= shuffle_size:
|
258 |
+
random.shuffle(buf)
|
259 |
+
for x in buf:
|
260 |
+
yield x
|
261 |
+
buf = []
|
262 |
+
# The sample left over
|
263 |
+
random.shuffle(buf)
|
264 |
+
for x in buf:
|
265 |
+
yield x
|
266 |
+
|
267 |
+
|
268 |
+
def sort(data, sort_size=500, mode='train'):
|
269 |
+
""" Sort the data by feature length.
|
270 |
+
Sort is used after shuffle and before batch, so we can group
|
271 |
+
utts with similar lengths into a batch, and `sort_size` should
|
272 |
+
be less than `shuffle_size`
|
273 |
+
|
274 |
+
Args:
|
275 |
+
data: Iterable[{key, feat, label}]
|
276 |
+
sort_size: buffer size for sort
|
277 |
+
|
278 |
+
Returns:
|
279 |
+
Iterable[{key, feat, label}]
|
280 |
+
"""
|
281 |
+
|
282 |
+
buf = []
|
283 |
+
for sample in data:
|
284 |
+
buf.append(sample)
|
285 |
+
if len(buf) >= sort_size:
|
286 |
+
buf.sort(key=lambda x: x['speech_feat'].size(0))
|
287 |
+
for x in buf:
|
288 |
+
yield x
|
289 |
+
buf = []
|
290 |
+
# The sample left over
|
291 |
+
buf.sort(key=lambda x: x['speech_feat'].size(0))
|
292 |
+
for x in buf:
|
293 |
+
yield x
|
294 |
+
|
295 |
+
|
296 |
+
def static_batch(data, batch_size=16):
|
297 |
+
""" Static batch the data by `batch_size`
|
298 |
+
|
299 |
+
Args:
|
300 |
+
data: Iterable[{key, feat, label}]
|
301 |
+
batch_size: batch size
|
302 |
+
|
303 |
+
Returns:
|
304 |
+
Iterable[List[{key, feat, label}]]
|
305 |
+
"""
|
306 |
+
buf = []
|
307 |
+
for sample in data:
|
308 |
+
buf.append(sample)
|
309 |
+
if len(buf) >= batch_size:
|
310 |
+
yield buf
|
311 |
+
buf = []
|
312 |
+
if len(buf) > 0:
|
313 |
+
yield buf
|
314 |
+
|
315 |
+
|
316 |
+
def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
|
317 |
+
""" Dynamic batch the data until the total frames in batch
|
318 |
+
reach `max_frames_in_batch`
|
319 |
+
|
320 |
+
Args:
|
321 |
+
data: Iterable[{key, feat, label}]
|
322 |
+
max_frames_in_batch: max_frames in one batch
|
323 |
+
|
324 |
+
Returns:
|
325 |
+
Iterable[List[{key, feat, label}]]
|
326 |
+
"""
|
327 |
+
buf = []
|
328 |
+
longest_frames = 0
|
329 |
+
for sample in data:
|
330 |
+
assert 'speech_feat' in sample
|
331 |
+
assert isinstance(sample['speech_feat'], torch.Tensor)
|
332 |
+
new_sample_frames = sample['speech_feat'].size(0)
|
333 |
+
longest_frames = max(longest_frames, new_sample_frames)
|
334 |
+
frames_after_padding = longest_frames * (len(buf) + 1)
|
335 |
+
if frames_after_padding > max_frames_in_batch:
|
336 |
+
yield buf
|
337 |
+
buf = [sample]
|
338 |
+
longest_frames = new_sample_frames
|
339 |
+
else:
|
340 |
+
buf.append(sample)
|
341 |
+
if len(buf) > 0:
|
342 |
+
yield buf
|
343 |
+
|
344 |
+
|
345 |
+
def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
|
346 |
+
""" Wrapper for static/dynamic batch
|
347 |
+
"""
|
348 |
+
if mode == 'inference':
|
349 |
+
return static_batch(data, 1)
|
350 |
+
else:
|
351 |
+
if batch_type == 'static':
|
352 |
+
return static_batch(data, batch_size)
|
353 |
+
elif batch_type == 'dynamic':
|
354 |
+
return dynamic_batch(data, max_frames_in_batch)
|
355 |
+
else:
|
356 |
+
logging.fatal('Unsupported batch type {}'.format(batch_type))
|
357 |
+
|
358 |
+
|
359 |
+
def padding(data, use_spk_embedding, mode='train', gan=False):
|
360 |
+
""" Padding the data into training data
|
361 |
+
|
362 |
+
Args:
|
363 |
+
data: Iterable[List[{key, feat, label}]]
|
364 |
+
|
365 |
+
Returns:
|
366 |
+
Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
|
367 |
+
"""
|
368 |
+
for sample in data:
|
369 |
+
assert isinstance(sample, list)
|
370 |
+
speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
|
371 |
+
dtype=torch.int32)
|
372 |
+
order = torch.argsort(speech_feat_len, descending=True)
|
373 |
+
|
374 |
+
utts = [sample[i]['utt'] for i in order]
|
375 |
+
speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
|
376 |
+
speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
|
377 |
+
speech = pad_sequence(speech, batch_first=True, padding_value=0)
|
378 |
+
speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
|
379 |
+
speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
|
380 |
+
speech_token = pad_sequence(speech_token,
|
381 |
+
batch_first=True,
|
382 |
+
padding_value=0)
|
383 |
+
speech_feat = [sample[i]['speech_feat'] for i in order]
|
384 |
+
speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
|
385 |
+
speech_feat = pad_sequence(speech_feat,
|
386 |
+
batch_first=True,
|
387 |
+
padding_value=0)
|
388 |
+
text = [sample[i]['text'] for i in order]
|
389 |
+
text_token = [torch.tensor(sample[i]['text_token']) for i in order]
|
390 |
+
text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
|
391 |
+
text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
|
392 |
+
utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
|
393 |
+
spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
|
394 |
+
batch = {
|
395 |
+
"utts": utts,
|
396 |
+
"speech": speech,
|
397 |
+
"speech_len": speech_len,
|
398 |
+
"speech_token": speech_token,
|
399 |
+
"speech_token_len": speech_token_len,
|
400 |
+
"speech_feat": speech_feat,
|
401 |
+
"speech_feat_len": speech_feat_len,
|
402 |
+
"text": text,
|
403 |
+
"text_token": text_token,
|
404 |
+
"text_token_len": text_token_len,
|
405 |
+
"utt_embedding": utt_embedding,
|
406 |
+
"spk_embedding": spk_embedding,
|
407 |
+
}
|
408 |
+
if gan is True:
|
409 |
+
# in gan train, we need pitch_feat
|
410 |
+
pitch_feat = [sample[i]['pitch_feat'] for i in order]
|
411 |
+
pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
|
412 |
+
pitch_feat = pad_sequence(pitch_feat,
|
413 |
+
batch_first=True,
|
414 |
+
padding_value=0)
|
415 |
+
batch["pitch_feat"] = pitch_feat
|
416 |
+
batch["pitch_feat_len"] = pitch_feat_len
|
417 |
+
else:
|
418 |
+
# only gan train needs speech, delete it to save memory
|
419 |
+
del batch["speech"]
|
420 |
+
del batch["speech_len"]
|
421 |
+
if mode == 'inference':
|
422 |
+
tts_text = [sample[i]['tts_text'] for i in order]
|
423 |
+
tts_index = [sample[i]['tts_index'] for i in order]
|
424 |
+
tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
|
425 |
+
tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
|
426 |
+
tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
|
427 |
+
batch.update({'tts_text': tts_text,
|
428 |
+
'tts_index': tts_index,
|
429 |
+
'tts_text_token': tts_text_token,
|
430 |
+
'tts_text_token_len': tts_text_token_len})
|
431 |
+
if use_spk_embedding is True:
|
432 |
+
batch["embedding"] = batch["spk_embedding"]
|
433 |
+
else:
|
434 |
+
batch["embedding"] = batch["utt_embedding"]
|
435 |
+
yield batch
|
tts/cosyvoice/flow/decoder.py
ADDED
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
import torch.nn.functional as F
|
17 |
+
from einops import pack, rearrange, repeat
|
18 |
+
from cosyvoice.utils.common import mask_to_bias
|
19 |
+
from cosyvoice.utils.mask import add_optional_chunk_mask
|
20 |
+
from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
|
21 |
+
from matcha.models.components.transformer import BasicTransformerBlock
|
22 |
+
|
23 |
+
|
24 |
+
class Transpose(torch.nn.Module):
|
25 |
+
def __init__(self, dim0: int, dim1: int):
|
26 |
+
super().__init__()
|
27 |
+
self.dim0 = dim0
|
28 |
+
self.dim1 = dim1
|
29 |
+
|
30 |
+
def forward(self, x: torch.Tensor):
|
31 |
+
x = torch.transpose(x, self.dim0, self.dim1)
|
32 |
+
return x
|
33 |
+
|
34 |
+
|
35 |
+
class CausalBlock1D(Block1D):
|
36 |
+
def __init__(self, dim: int, dim_out: int):
|
37 |
+
super(CausalBlock1D, self).__init__(dim, dim_out)
|
38 |
+
self.block = torch.nn.Sequential(
|
39 |
+
CausalConv1d(dim, dim_out, 3),
|
40 |
+
Transpose(1, 2),
|
41 |
+
nn.LayerNorm(dim_out),
|
42 |
+
Transpose(1, 2),
|
43 |
+
nn.Mish(),
|
44 |
+
)
|
45 |
+
|
46 |
+
def forward(self, x: torch.Tensor, mask: torch.Tensor):
|
47 |
+
output = self.block(x * mask)
|
48 |
+
return output * mask
|
49 |
+
|
50 |
+
|
51 |
+
class CausalResnetBlock1D(ResnetBlock1D):
|
52 |
+
def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
|
53 |
+
super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
|
54 |
+
self.block1 = CausalBlock1D(dim, dim_out)
|
55 |
+
self.block2 = CausalBlock1D(dim_out, dim_out)
|
56 |
+
|
57 |
+
|
58 |
+
class CausalConv1d(torch.nn.Conv1d):
|
59 |
+
def __init__(
|
60 |
+
self,
|
61 |
+
in_channels: int,
|
62 |
+
out_channels: int,
|
63 |
+
kernel_size: int,
|
64 |
+
stride: int = 1,
|
65 |
+
dilation: int = 1,
|
66 |
+
groups: int = 1,
|
67 |
+
bias: bool = True,
|
68 |
+
padding_mode: str = 'zeros',
|
69 |
+
device=None,
|
70 |
+
dtype=None
|
71 |
+
) -> None:
|
72 |
+
super(CausalConv1d, self).__init__(in_channels, out_channels,
|
73 |
+
kernel_size, stride,
|
74 |
+
padding=0, dilation=dilation,
|
75 |
+
groups=groups, bias=bias,
|
76 |
+
padding_mode=padding_mode,
|
77 |
+
device=device, dtype=dtype)
|
78 |
+
assert stride == 1
|
79 |
+
self.causal_padding = (kernel_size - 1, 0)
|
80 |
+
|
81 |
+
def forward(self, x: torch.Tensor):
|
82 |
+
x = F.pad(x, self.causal_padding)
|
83 |
+
x = super(CausalConv1d, self).forward(x)
|
84 |
+
return x
|
85 |
+
|
86 |
+
|
87 |
+
class ConditionalDecoder(nn.Module):
|
88 |
+
def __init__(
|
89 |
+
self,
|
90 |
+
in_channels,
|
91 |
+
out_channels,
|
92 |
+
causal=False,
|
93 |
+
channels=(256, 256),
|
94 |
+
dropout=0.05,
|
95 |
+
attention_head_dim=64,
|
96 |
+
n_blocks=1,
|
97 |
+
num_mid_blocks=2,
|
98 |
+
num_heads=4,
|
99 |
+
act_fn="snake",
|
100 |
+
):
|
101 |
+
"""
|
102 |
+
This decoder requires an input with the same shape of the target. So, if your text content
|
103 |
+
is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
|
104 |
+
"""
|
105 |
+
super().__init__()
|
106 |
+
channels = tuple(channels)
|
107 |
+
self.in_channels = in_channels
|
108 |
+
self.out_channels = out_channels
|
109 |
+
self.causal = causal
|
110 |
+
self.time_embeddings = SinusoidalPosEmb(in_channels)
|
111 |
+
time_embed_dim = channels[0] * 4
|
112 |
+
self.time_mlp = TimestepEmbedding(
|
113 |
+
in_channels=in_channels,
|
114 |
+
time_embed_dim=time_embed_dim,
|
115 |
+
act_fn="silu",
|
116 |
+
)
|
117 |
+
self.down_blocks = nn.ModuleList([])
|
118 |
+
self.mid_blocks = nn.ModuleList([])
|
119 |
+
self.up_blocks = nn.ModuleList([])
|
120 |
+
|
121 |
+
output_channel = in_channels
|
122 |
+
for i in range(len(channels)): # pylint: disable=consider-using-enumerate
|
123 |
+
input_channel = output_channel
|
124 |
+
output_channel = channels[i]
|
125 |
+
is_last = i == len(channels) - 1
|
126 |
+
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
|
127 |
+
ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
128 |
+
transformer_blocks = nn.ModuleList(
|
129 |
+
[
|
130 |
+
BasicTransformerBlock(
|
131 |
+
dim=output_channel,
|
132 |
+
num_attention_heads=num_heads,
|
133 |
+
attention_head_dim=attention_head_dim,
|
134 |
+
dropout=dropout,
|
135 |
+
activation_fn=act_fn,
|
136 |
+
)
|
137 |
+
for _ in range(n_blocks)
|
138 |
+
]
|
139 |
+
)
|
140 |
+
downsample = (
|
141 |
+
Downsample1D(output_channel) if not is_last else
|
142 |
+
CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
143 |
+
)
|
144 |
+
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
|
145 |
+
|
146 |
+
for _ in range(num_mid_blocks):
|
147 |
+
input_channel = channels[-1]
|
148 |
+
out_channels = channels[-1]
|
149 |
+
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
|
150 |
+
ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
151 |
+
|
152 |
+
transformer_blocks = nn.ModuleList(
|
153 |
+
[
|
154 |
+
BasicTransformerBlock(
|
155 |
+
dim=output_channel,
|
156 |
+
num_attention_heads=num_heads,
|
157 |
+
attention_head_dim=attention_head_dim,
|
158 |
+
dropout=dropout,
|
159 |
+
activation_fn=act_fn,
|
160 |
+
)
|
161 |
+
for _ in range(n_blocks)
|
162 |
+
]
|
163 |
+
)
|
164 |
+
|
165 |
+
self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
|
166 |
+
|
167 |
+
channels = channels[::-1] + (channels[0],)
|
168 |
+
for i in range(len(channels) - 1):
|
169 |
+
input_channel = channels[i] * 2
|
170 |
+
output_channel = channels[i + 1]
|
171 |
+
is_last = i == len(channels) - 2
|
172 |
+
resnet = CausalResnetBlock1D(
|
173 |
+
dim=input_channel,
|
174 |
+
dim_out=output_channel,
|
175 |
+
time_emb_dim=time_embed_dim,
|
176 |
+
) if self.causal else ResnetBlock1D(
|
177 |
+
dim=input_channel,
|
178 |
+
dim_out=output_channel,
|
179 |
+
time_emb_dim=time_embed_dim,
|
180 |
+
)
|
181 |
+
transformer_blocks = nn.ModuleList(
|
182 |
+
[
|
183 |
+
BasicTransformerBlock(
|
184 |
+
dim=output_channel,
|
185 |
+
num_attention_heads=num_heads,
|
186 |
+
attention_head_dim=attention_head_dim,
|
187 |
+
dropout=dropout,
|
188 |
+
activation_fn=act_fn,
|
189 |
+
)
|
190 |
+
for _ in range(n_blocks)
|
191 |
+
]
|
192 |
+
)
|
193 |
+
upsample = (
|
194 |
+
Upsample1D(output_channel, use_conv_transpose=True)
|
195 |
+
if not is_last
|
196 |
+
else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
197 |
+
)
|
198 |
+
self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
|
199 |
+
self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
|
200 |
+
self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
|
201 |
+
self.initialize_weights()
|
202 |
+
|
203 |
+
def initialize_weights(self):
|
204 |
+
for m in self.modules():
|
205 |
+
if isinstance(m, nn.Conv1d):
|
206 |
+
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
|
207 |
+
if m.bias is not None:
|
208 |
+
nn.init.constant_(m.bias, 0)
|
209 |
+
elif isinstance(m, nn.GroupNorm):
|
210 |
+
nn.init.constant_(m.weight, 1)
|
211 |
+
nn.init.constant_(m.bias, 0)
|
212 |
+
elif isinstance(m, nn.Linear):
|
213 |
+
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
|
214 |
+
if m.bias is not None:
|
215 |
+
nn.init.constant_(m.bias, 0)
|
216 |
+
|
217 |
+
def forward(self, x, mask, mu, t, spks=None, cond=None):
|
218 |
+
"""Forward pass of the UNet1DConditional model.
|
219 |
+
|
220 |
+
Args:
|
221 |
+
x (torch.Tensor): shape (batch_size, in_channels, time)
|
222 |
+
mask (_type_): shape (batch_size, 1, time)
|
223 |
+
t (_type_): shape (batch_size)
|
224 |
+
spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
|
225 |
+
cond (_type_, optional): placeholder for future use. Defaults to None.
|
226 |
+
|
227 |
+
Raises:
|
228 |
+
ValueError: _description_
|
229 |
+
ValueError: _description_
|
230 |
+
|
231 |
+
Returns:
|
232 |
+
_type_: _description_
|
233 |
+
"""
|
234 |
+
|
235 |
+
t = self.time_embeddings(t).to(t.dtype)
|
236 |
+
t = self.time_mlp(t)
|
237 |
+
|
238 |
+
x = pack([x, mu], "b * t")[0]
|
239 |
+
|
240 |
+
if spks is not None:
|
241 |
+
spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
|
242 |
+
x = pack([x, spks], "b * t")[0]
|
243 |
+
if cond is not None:
|
244 |
+
x = pack([x, cond], "b * t")[0]
|
245 |
+
|
246 |
+
hiddens = []
|
247 |
+
masks = [mask]
|
248 |
+
for resnet, transformer_blocks, downsample in self.down_blocks:
|
249 |
+
mask_down = masks[-1]
|
250 |
+
x = resnet(x, mask_down, t)
|
251 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
252 |
+
# attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
|
253 |
+
attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
|
254 |
+
attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
|
255 |
+
for transformer_block in transformer_blocks:
|
256 |
+
x = transformer_block(
|
257 |
+
hidden_states=x,
|
258 |
+
attention_mask=attn_mask,
|
259 |
+
timestep=t,
|
260 |
+
)
|
261 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
262 |
+
hiddens.append(x) # Save hidden states for skip connections
|
263 |
+
x = downsample(x * mask_down)
|
264 |
+
masks.append(mask_down[:, :, ::2])
|
265 |
+
masks = masks[:-1]
|
266 |
+
mask_mid = masks[-1]
|
267 |
+
|
268 |
+
for resnet, transformer_blocks in self.mid_blocks:
|
269 |
+
x = resnet(x, mask_mid, t)
|
270 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
271 |
+
# attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
|
272 |
+
attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
|
273 |
+
attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
|
274 |
+
for transformer_block in transformer_blocks:
|
275 |
+
x = transformer_block(
|
276 |
+
hidden_states=x,
|
277 |
+
attention_mask=attn_mask,
|
278 |
+
timestep=t,
|
279 |
+
)
|
280 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
281 |
+
|
282 |
+
for resnet, transformer_blocks, upsample in self.up_blocks:
|
283 |
+
mask_up = masks.pop()
|
284 |
+
skip = hiddens.pop()
|
285 |
+
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
|
286 |
+
x = resnet(x, mask_up, t)
|
287 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
288 |
+
# attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
|
289 |
+
attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
|
290 |
+
attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
|
291 |
+
for transformer_block in transformer_blocks:
|
292 |
+
x = transformer_block(
|
293 |
+
hidden_states=x,
|
294 |
+
attention_mask=attn_mask,
|
295 |
+
timestep=t,
|
296 |
+
)
|
297 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
298 |
+
x = upsample(x * mask_up)
|
299 |
+
x = self.final_block(x, mask_up)
|
300 |
+
output = self.final_proj(x * mask_up)
|
301 |
+
return output * mask
|
tts/cosyvoice/flow/flow.py
ADDED
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import logging
|
15 |
+
import random
|
16 |
+
from typing import Dict, Optional
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
from torch.nn import functional as F
|
20 |
+
from omegaconf import DictConfig
|
21 |
+
from cosyvoice.utils.mask import make_pad_mask
|
22 |
+
|
23 |
+
|
24 |
+
class MaskedDiffWithXvec(torch.nn.Module):
|
25 |
+
def __init__(self,
|
26 |
+
input_size: int = 512,
|
27 |
+
output_size: int = 80,
|
28 |
+
spk_embed_dim: int = 192,
|
29 |
+
output_type: str = "mel",
|
30 |
+
vocab_size: int = 4096,
|
31 |
+
input_frame_rate: int = 50,
|
32 |
+
only_mask_loss: bool = True,
|
33 |
+
encoder: torch.nn.Module = None,
|
34 |
+
length_regulator: torch.nn.Module = None,
|
35 |
+
decoder: torch.nn.Module = None,
|
36 |
+
decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
|
37 |
+
'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
|
38 |
+
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
|
39 |
+
'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
|
40 |
+
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
|
41 |
+
mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
|
42 |
+
'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
|
43 |
+
super().__init__()
|
44 |
+
self.input_size = input_size
|
45 |
+
self.output_size = output_size
|
46 |
+
self.decoder_conf = decoder_conf
|
47 |
+
self.mel_feat_conf = mel_feat_conf
|
48 |
+
self.vocab_size = vocab_size
|
49 |
+
self.output_type = output_type
|
50 |
+
self.input_frame_rate = input_frame_rate
|
51 |
+
logging.info(f"input frame rate={self.input_frame_rate}")
|
52 |
+
self.input_embedding = nn.Embedding(vocab_size, input_size)
|
53 |
+
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
|
54 |
+
self.encoder = encoder
|
55 |
+
self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
|
56 |
+
self.decoder = decoder
|
57 |
+
self.length_regulator = length_regulator
|
58 |
+
self.only_mask_loss = only_mask_loss
|
59 |
+
|
60 |
+
def forward(
|
61 |
+
self,
|
62 |
+
batch: dict,
|
63 |
+
device: torch.device,
|
64 |
+
) -> Dict[str, Optional[torch.Tensor]]:
|
65 |
+
token = batch['speech_token'].to(device)
|
66 |
+
token_len = batch['speech_token_len'].to(device)
|
67 |
+
feat = batch['speech_feat'].to(device)
|
68 |
+
feat_len = batch['speech_feat_len'].to(device)
|
69 |
+
embedding = batch['embedding'].to(device)
|
70 |
+
|
71 |
+
# xvec projection
|
72 |
+
embedding = F.normalize(embedding, dim=1)
|
73 |
+
embedding = self.spk_embed_affine_layer(embedding)
|
74 |
+
|
75 |
+
# concat text and prompt_text
|
76 |
+
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
|
77 |
+
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
78 |
+
|
79 |
+
# text encode
|
80 |
+
h, h_lengths = self.encoder(token, token_len)
|
81 |
+
h = self.encoder_proj(h)
|
82 |
+
h, h_lengths = self.length_regulator(h, feat_len)
|
83 |
+
|
84 |
+
# get conditions
|
85 |
+
conds = torch.zeros(feat.shape, device=token.device)
|
86 |
+
for i, j in enumerate(feat_len):
|
87 |
+
if random.random() < 0.5:
|
88 |
+
continue
|
89 |
+
index = random.randint(0, int(0.3 * j))
|
90 |
+
conds[i, :index] = feat[i, :index]
|
91 |
+
conds = conds.transpose(1, 2)
|
92 |
+
|
93 |
+
mask = (~make_pad_mask(feat_len)).to(h)
|
94 |
+
feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
|
95 |
+
loss, _ = self.decoder.compute_loss(
|
96 |
+
feat.transpose(1, 2).contiguous(),
|
97 |
+
mask.unsqueeze(1),
|
98 |
+
h.transpose(1, 2).contiguous(),
|
99 |
+
embedding,
|
100 |
+
cond=conds
|
101 |
+
)
|
102 |
+
return {'loss': loss}
|
103 |
+
|
104 |
+
@torch.inference_mode()
|
105 |
+
def inference(self,
|
106 |
+
token,
|
107 |
+
token_len,
|
108 |
+
prompt_token,
|
109 |
+
prompt_token_len,
|
110 |
+
prompt_feat,
|
111 |
+
prompt_feat_len,
|
112 |
+
embedding,
|
113 |
+
flow_cache):
|
114 |
+
if self.fp16 is True:
|
115 |
+
prompt_feat = prompt_feat.half()
|
116 |
+
embedding = embedding.half()
|
117 |
+
|
118 |
+
assert token.shape[0] == 1
|
119 |
+
# xvec projection
|
120 |
+
embedding = F.normalize(embedding, dim=1)
|
121 |
+
embedding = self.spk_embed_affine_layer(embedding)
|
122 |
+
|
123 |
+
# concat text and prompt_text
|
124 |
+
token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
|
125 |
+
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
|
126 |
+
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
|
127 |
+
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
128 |
+
|
129 |
+
# text encode
|
130 |
+
h, h_lengths = self.encoder(token, token_len)
|
131 |
+
h = self.encoder_proj(h)
|
132 |
+
mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
|
133 |
+
h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
|
134 |
+
|
135 |
+
# get conditions
|
136 |
+
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
|
137 |
+
conds[:, :mel_len1] = prompt_feat
|
138 |
+
conds = conds.transpose(1, 2)
|
139 |
+
|
140 |
+
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
|
141 |
+
feat, flow_cache = self.decoder(
|
142 |
+
mu=h.transpose(1, 2).contiguous(),
|
143 |
+
mask=mask.unsqueeze(1),
|
144 |
+
spks=embedding,
|
145 |
+
cond=conds,
|
146 |
+
n_timesteps=10,
|
147 |
+
prompt_len=mel_len1,
|
148 |
+
flow_cache=flow_cache
|
149 |
+
)
|
150 |
+
feat = feat[:, :, mel_len1:]
|
151 |
+
assert feat.shape[2] == mel_len2
|
152 |
+
return feat.float(), flow_cache
|
153 |
+
|
154 |
+
|
155 |
+
class CausalMaskedDiffWithXvec(torch.nn.Module):
|
156 |
+
def __init__(self,
|
157 |
+
input_size: int = 512,
|
158 |
+
output_size: int = 80,
|
159 |
+
spk_embed_dim: int = 192,
|
160 |
+
output_type: str = "mel",
|
161 |
+
vocab_size: int = 4096,
|
162 |
+
input_frame_rate: int = 50,
|
163 |
+
only_mask_loss: bool = True,
|
164 |
+
token_mel_ratio: int = 2,
|
165 |
+
pre_lookahead_len: int = 3,
|
166 |
+
encoder: torch.nn.Module = None,
|
167 |
+
decoder: torch.nn.Module = None,
|
168 |
+
decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
|
169 |
+
'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
|
170 |
+
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
|
171 |
+
'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
|
172 |
+
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
|
173 |
+
mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
|
174 |
+
'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
|
175 |
+
super().__init__()
|
176 |
+
self.input_size = input_size
|
177 |
+
self.output_size = output_size
|
178 |
+
self.decoder_conf = decoder_conf
|
179 |
+
self.mel_feat_conf = mel_feat_conf
|
180 |
+
self.vocab_size = vocab_size
|
181 |
+
self.output_type = output_type
|
182 |
+
self.input_frame_rate = input_frame_rate
|
183 |
+
logging.info(f"input frame rate={self.input_frame_rate}")
|
184 |
+
self.input_embedding = nn.Embedding(vocab_size, input_size)
|
185 |
+
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
|
186 |
+
self.encoder = encoder
|
187 |
+
self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
|
188 |
+
self.decoder = decoder
|
189 |
+
self.only_mask_loss = only_mask_loss
|
190 |
+
self.token_mel_ratio = token_mel_ratio
|
191 |
+
self.pre_lookahead_len = pre_lookahead_len
|
192 |
+
|
193 |
+
@torch.inference_mode()
|
194 |
+
def inference(self,
|
195 |
+
token,
|
196 |
+
token_len,
|
197 |
+
prompt_token,
|
198 |
+
prompt_token_len,
|
199 |
+
prompt_feat,
|
200 |
+
prompt_feat_len,
|
201 |
+
embedding,
|
202 |
+
finalize):
|
203 |
+
if self.fp16 is True:
|
204 |
+
prompt_feat = prompt_feat.half()
|
205 |
+
embedding = embedding.half()
|
206 |
+
|
207 |
+
assert token.shape[0] == 1
|
208 |
+
# xvec projection
|
209 |
+
embedding = F.normalize(embedding, dim=1)
|
210 |
+
embedding = self.spk_embed_affine_layer(embedding)
|
211 |
+
|
212 |
+
# concat text and prompt_text
|
213 |
+
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
|
214 |
+
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
|
215 |
+
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
216 |
+
|
217 |
+
# text encode
|
218 |
+
h, h_lengths = self.encoder(token, token_len)
|
219 |
+
if finalize is False:
|
220 |
+
h = h[:, :-self.pre_lookahead_len * self.token_mel_ratio]
|
221 |
+
mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
|
222 |
+
h = self.encoder_proj(h)
|
223 |
+
|
224 |
+
# get conditions
|
225 |
+
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
|
226 |
+
conds[:, :mel_len1] = prompt_feat
|
227 |
+
conds = conds.transpose(1, 2)
|
228 |
+
|
229 |
+
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
|
230 |
+
feat, _ = self.decoder(
|
231 |
+
mu=h.transpose(1, 2).contiguous(),
|
232 |
+
mask=mask.unsqueeze(1),
|
233 |
+
spks=embedding,
|
234 |
+
cond=conds,
|
235 |
+
n_timesteps=10
|
236 |
+
)
|
237 |
+
feat = feat[:, :, mel_len1:]
|
238 |
+
assert feat.shape[2] == mel_len2
|
239 |
+
return feat.float(), None
|
tts/cosyvoice/flow/flow_matching.py
ADDED
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import threading
|
15 |
+
import torch
|
16 |
+
import torch.nn.functional as F
|
17 |
+
from matcha.models.components.flow_matching import BASECFM
|
18 |
+
|
19 |
+
|
20 |
+
class ConditionalCFM(BASECFM):
|
21 |
+
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
|
22 |
+
super().__init__(
|
23 |
+
n_feats=in_channels,
|
24 |
+
cfm_params=cfm_params,
|
25 |
+
n_spks=n_spks,
|
26 |
+
spk_emb_dim=spk_emb_dim,
|
27 |
+
)
|
28 |
+
self.t_scheduler = cfm_params.t_scheduler
|
29 |
+
self.training_cfg_rate = cfm_params.training_cfg_rate
|
30 |
+
self.inference_cfg_rate = cfm_params.inference_cfg_rate
|
31 |
+
in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
|
32 |
+
# Just change the architecture of the estimator here
|
33 |
+
self.estimator = estimator
|
34 |
+
self.lock = threading.Lock()
|
35 |
+
|
36 |
+
@torch.inference_mode()
|
37 |
+
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
|
38 |
+
"""Forward diffusion
|
39 |
+
|
40 |
+
Args:
|
41 |
+
mu (torch.Tensor): output of encoder
|
42 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
43 |
+
mask (torch.Tensor): output_mask
|
44 |
+
shape: (batch_size, 1, mel_timesteps)
|
45 |
+
n_timesteps (int): number of diffusion steps
|
46 |
+
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
47 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
48 |
+
shape: (batch_size, spk_emb_dim)
|
49 |
+
cond: Not used but kept for future purposes
|
50 |
+
|
51 |
+
Returns:
|
52 |
+
sample: generated mel-spectrogram
|
53 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
54 |
+
"""
|
55 |
+
|
56 |
+
z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
|
57 |
+
cache_size = flow_cache.shape[2]
|
58 |
+
# fix prompt and overlap part mu and z
|
59 |
+
if cache_size != 0:
|
60 |
+
z[:, :, :cache_size] = flow_cache[:, :, :, 0]
|
61 |
+
mu[:, :, :cache_size] = flow_cache[:, :, :, 1]
|
62 |
+
z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
|
63 |
+
mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
|
64 |
+
flow_cache = torch.stack([z_cache, mu_cache], dim=-1)
|
65 |
+
|
66 |
+
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
|
67 |
+
if self.t_scheduler == 'cosine':
|
68 |
+
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
|
69 |
+
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache
|
70 |
+
|
71 |
+
def solve_euler(self, x, t_span, mu, mask, spks, cond):
|
72 |
+
"""
|
73 |
+
Fixed euler solver for ODEs.
|
74 |
+
Args:
|
75 |
+
x (torch.Tensor): random noise
|
76 |
+
t_span (torch.Tensor): n_timesteps interpolated
|
77 |
+
shape: (n_timesteps + 1,)
|
78 |
+
mu (torch.Tensor): output of encoder
|
79 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
80 |
+
mask (torch.Tensor): output_mask
|
81 |
+
shape: (batch_size, 1, mel_timesteps)
|
82 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
83 |
+
shape: (batch_size, spk_emb_dim)
|
84 |
+
cond: Not used but kept for future purposes
|
85 |
+
"""
|
86 |
+
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
87 |
+
t = t.unsqueeze(dim=0)
|
88 |
+
|
89 |
+
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
|
90 |
+
# Or in future might add like a return_all_steps flag
|
91 |
+
sol = []
|
92 |
+
|
93 |
+
# Do not use concat, it may cause memory format changed and trt infer with wrong results!
|
94 |
+
x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
95 |
+
mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
|
96 |
+
mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
97 |
+
t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
|
98 |
+
spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
|
99 |
+
cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
100 |
+
for step in range(1, len(t_span)):
|
101 |
+
# Classifier-Free Guidance inference introduced in VoiceBox
|
102 |
+
x_in[:] = x
|
103 |
+
mask_in[:] = mask
|
104 |
+
mu_in[0] = mu
|
105 |
+
t_in[:] = t.unsqueeze(0)
|
106 |
+
spks_in[0] = spks
|
107 |
+
cond_in[0] = cond
|
108 |
+
dphi_dt = self.forward_estimator(
|
109 |
+
x_in, mask_in,
|
110 |
+
mu_in, t_in,
|
111 |
+
spks_in,
|
112 |
+
cond_in
|
113 |
+
)
|
114 |
+
dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
|
115 |
+
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
|
116 |
+
x = x + dt * dphi_dt
|
117 |
+
t = t + dt
|
118 |
+
sol.append(x)
|
119 |
+
if step < len(t_span) - 1:
|
120 |
+
dt = t_span[step + 1] - t
|
121 |
+
|
122 |
+
return sol[-1].float()
|
123 |
+
|
124 |
+
def forward_estimator(self, x, mask, mu, t, spks, cond):
|
125 |
+
if isinstance(self.estimator, torch.nn.Module):
|
126 |
+
return self.estimator.forward(x, mask, mu, t, spks, cond)
|
127 |
+
else:
|
128 |
+
with self.lock:
|
129 |
+
self.estimator.set_input_shape('x', (2, 80, x.size(2)))
|
130 |
+
self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
|
131 |
+
self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
|
132 |
+
self.estimator.set_input_shape('t', (2,))
|
133 |
+
self.estimator.set_input_shape('spks', (2, 80))
|
134 |
+
self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
|
135 |
+
# run trt engine
|
136 |
+
self.estimator.execute_v2([x.contiguous().data_ptr(),
|
137 |
+
mask.contiguous().data_ptr(),
|
138 |
+
mu.contiguous().data_ptr(),
|
139 |
+
t.contiguous().data_ptr(),
|
140 |
+
spks.contiguous().data_ptr(),
|
141 |
+
cond.contiguous().data_ptr(),
|
142 |
+
x.data_ptr()])
|
143 |
+
return x
|
144 |
+
|
145 |
+
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
|
146 |
+
"""Computes diffusion loss
|
147 |
+
|
148 |
+
Args:
|
149 |
+
x1 (torch.Tensor): Target
|
150 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
151 |
+
mask (torch.Tensor): target mask
|
152 |
+
shape: (batch_size, 1, mel_timesteps)
|
153 |
+
mu (torch.Tensor): output of encoder
|
154 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
155 |
+
spks (torch.Tensor, optional): speaker embedding. Defaults to None.
|
156 |
+
shape: (batch_size, spk_emb_dim)
|
157 |
+
|
158 |
+
Returns:
|
159 |
+
loss: conditional flow matching loss
|
160 |
+
y: conditional flow
|
161 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
162 |
+
"""
|
163 |
+
b, _, t = mu.shape
|
164 |
+
|
165 |
+
# random timestep
|
166 |
+
t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
|
167 |
+
if self.t_scheduler == 'cosine':
|
168 |
+
t = 1 - torch.cos(t * 0.5 * torch.pi)
|
169 |
+
# sample noise p(x_0)
|
170 |
+
z = torch.randn_like(x1)
|
171 |
+
|
172 |
+
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
|
173 |
+
u = x1 - (1 - self.sigma_min) * z
|
174 |
+
|
175 |
+
# during training, we randomly drop condition to trade off mode coverage and sample fidelity
|
176 |
+
if self.training_cfg_rate > 0:
|
177 |
+
cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
|
178 |
+
mu = mu * cfg_mask.view(-1, 1, 1)
|
179 |
+
spks = spks * cfg_mask.view(-1, 1)
|
180 |
+
cond = cond * cfg_mask.view(-1, 1, 1)
|
181 |
+
|
182 |
+
pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
|
183 |
+
loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
|
184 |
+
return loss, y
|
185 |
+
|
186 |
+
|
187 |
+
class CausalConditionalCFM(ConditionalCFM):
|
188 |
+
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
|
189 |
+
super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
|
190 |
+
self.rand_noise = torch.randn([1, 80, 50 * 300])
|
191 |
+
|
192 |
+
@torch.inference_mode()
|
193 |
+
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
|
194 |
+
"""Forward diffusion
|
195 |
+
|
196 |
+
Args:
|
197 |
+
mu (torch.Tensor): output of encoder
|
198 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
199 |
+
mask (torch.Tensor): output_mask
|
200 |
+
shape: (batch_size, 1, mel_timesteps)
|
201 |
+
n_timesteps (int): number of diffusion steps
|
202 |
+
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
203 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
204 |
+
shape: (batch_size, spk_emb_dim)
|
205 |
+
cond: Not used but kept for future purposes
|
206 |
+
|
207 |
+
Returns:
|
208 |
+
sample: generated mel-spectrogram
|
209 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
210 |
+
"""
|
211 |
+
|
212 |
+
z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature
|
213 |
+
# fix prompt and overlap part mu and z
|
214 |
+
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
|
215 |
+
if self.t_scheduler == 'cosine':
|
216 |
+
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
|
217 |
+
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None
|
tts/cosyvoice/flow/length_regulator.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import Tuple
|
15 |
+
import torch.nn as nn
|
16 |
+
import torch
|
17 |
+
from torch.nn import functional as F
|
18 |
+
from cosyvoice.utils.mask import make_pad_mask
|
19 |
+
|
20 |
+
|
21 |
+
class InterpolateRegulator(nn.Module):
|
22 |
+
def __init__(
|
23 |
+
self,
|
24 |
+
channels: int,
|
25 |
+
sampling_ratios: Tuple,
|
26 |
+
out_channels: int = None,
|
27 |
+
groups: int = 1,
|
28 |
+
):
|
29 |
+
super().__init__()
|
30 |
+
self.sampling_ratios = sampling_ratios
|
31 |
+
out_channels = out_channels or channels
|
32 |
+
model = nn.ModuleList([])
|
33 |
+
if len(sampling_ratios) > 0:
|
34 |
+
for _ in sampling_ratios:
|
35 |
+
module = nn.Conv1d(channels, channels, 3, 1, 1)
|
36 |
+
norm = nn.GroupNorm(groups, channels)
|
37 |
+
act = nn.Mish()
|
38 |
+
model.extend([module, norm, act])
|
39 |
+
model.append(
|
40 |
+
nn.Conv1d(channels, out_channels, 1, 1)
|
41 |
+
)
|
42 |
+
self.model = nn.Sequential(*model)
|
43 |
+
|
44 |
+
def forward(self, x, ylens=None):
|
45 |
+
# x in (B, T, D)
|
46 |
+
mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
|
47 |
+
x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear')
|
48 |
+
out = self.model(x).transpose(1, 2).contiguous()
|
49 |
+
olens = ylens
|
50 |
+
return out * mask, olens
|
51 |
+
|
52 |
+
def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
|
53 |
+
# in inference mode, interploate prompt token and token(head/mid/tail) seprately, so we can get a clear separation point of mel
|
54 |
+
# x in (B, T, D)
|
55 |
+
if x2.shape[1] > 40:
|
56 |
+
x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
|
57 |
+
x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2,
|
58 |
+
mode='linear')
|
59 |
+
x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
|
60 |
+
x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
|
61 |
+
else:
|
62 |
+
x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
|
63 |
+
if x1.shape[1] != 0:
|
64 |
+
x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear')
|
65 |
+
x = torch.concat([x1, x2], dim=2)
|
66 |
+
else:
|
67 |
+
x = x2
|
68 |
+
out = self.model(x).transpose(1, 2).contiguous()
|
69 |
+
return out, mel_len1 + mel_len2
|
tts/cosyvoice/hifigan/discriminator.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from torch.nn.utils.parametrizations import weight_norm
|
4 |
+
from typing import List, Optional, Tuple
|
5 |
+
from einops import rearrange
|
6 |
+
from torchaudio.transforms import Spectrogram
|
7 |
+
|
8 |
+
|
9 |
+
class MultipleDiscriminator(nn.Module):
|
10 |
+
def __init__(
|
11 |
+
self, mpd: nn.Module, mrd: nn.Module
|
12 |
+
):
|
13 |
+
super().__init__()
|
14 |
+
self.mpd = mpd
|
15 |
+
self.mrd = mrd
|
16 |
+
|
17 |
+
def forward(self, y: torch.Tensor, y_hat: torch.Tensor):
|
18 |
+
y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
|
19 |
+
this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mpd(y.unsqueeze(dim=1), y_hat.unsqueeze(dim=1))
|
20 |
+
y_d_rs += this_y_d_rs
|
21 |
+
y_d_gs += this_y_d_gs
|
22 |
+
fmap_rs += this_fmap_rs
|
23 |
+
fmap_gs += this_fmap_gs
|
24 |
+
this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mrd(y, y_hat)
|
25 |
+
y_d_rs += this_y_d_rs
|
26 |
+
y_d_gs += this_y_d_gs
|
27 |
+
fmap_rs += this_fmap_rs
|
28 |
+
fmap_gs += this_fmap_gs
|
29 |
+
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
30 |
+
|
31 |
+
|
32 |
+
class MultiResolutionDiscriminator(nn.Module):
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
fft_sizes: Tuple[int, ...] = (2048, 1024, 512),
|
36 |
+
num_embeddings: Optional[int] = None,
|
37 |
+
):
|
38 |
+
"""
|
39 |
+
Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec.
|
40 |
+
Additionally, it allows incorporating conditional information with a learned embeddings table.
|
41 |
+
|
42 |
+
Args:
|
43 |
+
fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512).
|
44 |
+
num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
|
45 |
+
Defaults to None.
|
46 |
+
"""
|
47 |
+
|
48 |
+
super().__init__()
|
49 |
+
self.discriminators = nn.ModuleList(
|
50 |
+
[DiscriminatorR(window_length=w, num_embeddings=num_embeddings) for w in fft_sizes]
|
51 |
+
)
|
52 |
+
|
53 |
+
def forward(
|
54 |
+
self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
|
55 |
+
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
|
56 |
+
y_d_rs = []
|
57 |
+
y_d_gs = []
|
58 |
+
fmap_rs = []
|
59 |
+
fmap_gs = []
|
60 |
+
|
61 |
+
for d in self.discriminators:
|
62 |
+
y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
|
63 |
+
y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
|
64 |
+
y_d_rs.append(y_d_r)
|
65 |
+
fmap_rs.append(fmap_r)
|
66 |
+
y_d_gs.append(y_d_g)
|
67 |
+
fmap_gs.append(fmap_g)
|
68 |
+
|
69 |
+
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
70 |
+
|
71 |
+
|
72 |
+
class DiscriminatorR(nn.Module):
|
73 |
+
def __init__(
|
74 |
+
self,
|
75 |
+
window_length: int,
|
76 |
+
num_embeddings: Optional[int] = None,
|
77 |
+
channels: int = 32,
|
78 |
+
hop_factor: float = 0.25,
|
79 |
+
bands: Tuple[Tuple[float, float], ...] = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)),
|
80 |
+
):
|
81 |
+
super().__init__()
|
82 |
+
self.window_length = window_length
|
83 |
+
self.hop_factor = hop_factor
|
84 |
+
self.spec_fn = Spectrogram(
|
85 |
+
n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, power=None
|
86 |
+
)
|
87 |
+
n_fft = window_length // 2 + 1
|
88 |
+
bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
|
89 |
+
self.bands = bands
|
90 |
+
convs = lambda: nn.ModuleList(
|
91 |
+
[
|
92 |
+
weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
|
93 |
+
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
|
94 |
+
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
|
95 |
+
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
|
96 |
+
weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
|
97 |
+
]
|
98 |
+
)
|
99 |
+
self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
|
100 |
+
|
101 |
+
if num_embeddings is not None:
|
102 |
+
self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
|
103 |
+
torch.nn.init.zeros_(self.emb.weight)
|
104 |
+
|
105 |
+
self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))
|
106 |
+
|
107 |
+
def spectrogram(self, x):
|
108 |
+
# Remove DC offset
|
109 |
+
x = x - x.mean(dim=-1, keepdims=True)
|
110 |
+
# Peak normalize the volume of input audio
|
111 |
+
x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
|
112 |
+
x = self.spec_fn(x)
|
113 |
+
x = torch.view_as_real(x)
|
114 |
+
x = rearrange(x, "b f t c -> b c t f")
|
115 |
+
# Split into bands
|
116 |
+
x_bands = [x[..., b[0]: b[1]] for b in self.bands]
|
117 |
+
return x_bands
|
118 |
+
|
119 |
+
def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None):
|
120 |
+
x_bands = self.spectrogram(x)
|
121 |
+
fmap = []
|
122 |
+
x = []
|
123 |
+
for band, stack in zip(x_bands, self.band_convs):
|
124 |
+
for i, layer in enumerate(stack):
|
125 |
+
band = layer(band)
|
126 |
+
band = torch.nn.functional.leaky_relu(band, 0.1)
|
127 |
+
if i > 0:
|
128 |
+
fmap.append(band)
|
129 |
+
x.append(band)
|
130 |
+
x = torch.cat(x, dim=-1)
|
131 |
+
if cond_embedding_id is not None:
|
132 |
+
emb = self.emb(cond_embedding_id)
|
133 |
+
h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
|
134 |
+
else:
|
135 |
+
h = 0
|
136 |
+
x = self.conv_post(x)
|
137 |
+
fmap.append(x)
|
138 |
+
x += h
|
139 |
+
|
140 |
+
return x, fmap
|
tts/cosyvoice/hifigan/f0_predictor.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
# from torch.nn.utils import weight_norm
|
17 |
+
from torch.nn.utils.parametrizations import weight_norm
|
18 |
+
|
19 |
+
|
20 |
+
class ConvRNNF0Predictor(nn.Module):
|
21 |
+
def __init__(self,
|
22 |
+
num_class: int = 1,
|
23 |
+
in_channels: int = 80,
|
24 |
+
cond_channels: int = 512
|
25 |
+
):
|
26 |
+
super().__init__()
|
27 |
+
|
28 |
+
self.num_class = num_class
|
29 |
+
self.condnet = nn.Sequential(
|
30 |
+
weight_norm(
|
31 |
+
nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
|
32 |
+
),
|
33 |
+
nn.ELU(),
|
34 |
+
weight_norm(
|
35 |
+
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
|
36 |
+
),
|
37 |
+
nn.ELU(),
|
38 |
+
weight_norm(
|
39 |
+
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
|
40 |
+
),
|
41 |
+
nn.ELU(),
|
42 |
+
weight_norm(
|
43 |
+
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
|
44 |
+
),
|
45 |
+
nn.ELU(),
|
46 |
+
weight_norm(
|
47 |
+
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
|
48 |
+
),
|
49 |
+
nn.ELU(),
|
50 |
+
)
|
51 |
+
self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
|
52 |
+
|
53 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
54 |
+
x = self.condnet(x)
|
55 |
+
x = x.transpose(1, 2)
|
56 |
+
return torch.abs(self.classifier(x).squeeze(-1))
|
tts/cosyvoice/hifigan/generator.py
ADDED
@@ -0,0 +1,412 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""HIFI-GAN"""
|
16 |
+
|
17 |
+
from typing import Dict, Optional, List
|
18 |
+
import numpy as np
|
19 |
+
from scipy.signal import get_window
|
20 |
+
import torch
|
21 |
+
import torch.nn as nn
|
22 |
+
import torch.nn.functional as F
|
23 |
+
from torch.nn import Conv1d
|
24 |
+
from torch.nn import ConvTranspose1d
|
25 |
+
from torch.nn.utils import remove_weight_norm
|
26 |
+
# from torch.nn.utils import weight_norm
|
27 |
+
from torch.nn.utils.parametrizations import weight_norm
|
28 |
+
from torch.distributions.uniform import Uniform
|
29 |
+
|
30 |
+
from cosyvoice.transformer.activation import Snake
|
31 |
+
from cosyvoice.utils.common import get_padding
|
32 |
+
from cosyvoice.utils.common import init_weights
|
33 |
+
|
34 |
+
|
35 |
+
"""hifigan based generator implementation.
|
36 |
+
|
37 |
+
This code is modified from https://github.com/jik876/hifi-gan
|
38 |
+
,https://github.com/kan-bayashi/ParallelWaveGAN and
|
39 |
+
https://github.com/NVIDIA/BigVGAN
|
40 |
+
|
41 |
+
"""
|
42 |
+
|
43 |
+
|
44 |
+
class ResBlock(torch.nn.Module):
|
45 |
+
"""Residual block module in HiFiGAN/BigVGAN."""
|
46 |
+
def __init__(
|
47 |
+
self,
|
48 |
+
channels: int = 512,
|
49 |
+
kernel_size: int = 3,
|
50 |
+
dilations: List[int] = [1, 3, 5],
|
51 |
+
):
|
52 |
+
super(ResBlock, self).__init__()
|
53 |
+
self.convs1 = nn.ModuleList()
|
54 |
+
self.convs2 = nn.ModuleList()
|
55 |
+
|
56 |
+
for dilation in dilations:
|
57 |
+
self.convs1.append(
|
58 |
+
weight_norm(
|
59 |
+
Conv1d(
|
60 |
+
channels,
|
61 |
+
channels,
|
62 |
+
kernel_size,
|
63 |
+
1,
|
64 |
+
dilation=dilation,
|
65 |
+
padding=get_padding(kernel_size, dilation)
|
66 |
+
)
|
67 |
+
)
|
68 |
+
)
|
69 |
+
self.convs2.append(
|
70 |
+
weight_norm(
|
71 |
+
Conv1d(
|
72 |
+
channels,
|
73 |
+
channels,
|
74 |
+
kernel_size,
|
75 |
+
1,
|
76 |
+
dilation=1,
|
77 |
+
padding=get_padding(kernel_size, 1)
|
78 |
+
)
|
79 |
+
)
|
80 |
+
)
|
81 |
+
self.convs1.apply(init_weights)
|
82 |
+
self.convs2.apply(init_weights)
|
83 |
+
self.activations1 = nn.ModuleList([
|
84 |
+
Snake(channels, alpha_logscale=False)
|
85 |
+
for _ in range(len(self.convs1))
|
86 |
+
])
|
87 |
+
self.activations2 = nn.ModuleList([
|
88 |
+
Snake(channels, alpha_logscale=False)
|
89 |
+
for _ in range(len(self.convs2))
|
90 |
+
])
|
91 |
+
|
92 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
93 |
+
for idx in range(len(self.convs1)):
|
94 |
+
xt = self.activations1[idx](x)
|
95 |
+
xt = self.convs1[idx](xt)
|
96 |
+
xt = self.activations2[idx](xt)
|
97 |
+
xt = self.convs2[idx](xt)
|
98 |
+
x = xt + x
|
99 |
+
return x
|
100 |
+
|
101 |
+
def remove_weight_norm(self):
|
102 |
+
for idx in range(len(self.convs1)):
|
103 |
+
remove_weight_norm(self.convs1[idx])
|
104 |
+
remove_weight_norm(self.convs2[idx])
|
105 |
+
|
106 |
+
|
107 |
+
class SineGen(torch.nn.Module):
|
108 |
+
""" Definition of sine generator
|
109 |
+
SineGen(samp_rate, harmonic_num = 0,
|
110 |
+
sine_amp = 0.1, noise_std = 0.003,
|
111 |
+
voiced_threshold = 0,
|
112 |
+
flag_for_pulse=False)
|
113 |
+
samp_rate: sampling rate in Hz
|
114 |
+
harmonic_num: number of harmonic overtones (default 0)
|
115 |
+
sine_amp: amplitude of sine-wavefrom (default 0.1)
|
116 |
+
noise_std: std of Gaussian noise (default 0.003)
|
117 |
+
voiced_thoreshold: F0 threshold for U/V classification (default 0)
|
118 |
+
flag_for_pulse: this SinGen is used inside PulseGen (default False)
|
119 |
+
Note: when flag_for_pulse is True, the first time step of a voiced
|
120 |
+
segment is always sin(np.pi) or cos(0)
|
121 |
+
"""
|
122 |
+
|
123 |
+
def __init__(self, samp_rate, harmonic_num=0,
|
124 |
+
sine_amp=0.1, noise_std=0.003,
|
125 |
+
voiced_threshold=0):
|
126 |
+
super(SineGen, self).__init__()
|
127 |
+
self.sine_amp = sine_amp
|
128 |
+
self.noise_std = noise_std
|
129 |
+
self.harmonic_num = harmonic_num
|
130 |
+
self.sampling_rate = samp_rate
|
131 |
+
self.voiced_threshold = voiced_threshold
|
132 |
+
|
133 |
+
def _f02uv(self, f0):
|
134 |
+
# generate uv signal
|
135 |
+
uv = (f0 > self.voiced_threshold).type(torch.float32)
|
136 |
+
return uv
|
137 |
+
|
138 |
+
@torch.no_grad()
|
139 |
+
def forward(self, f0):
|
140 |
+
"""
|
141 |
+
:param f0: [B, 1, sample_len], Hz
|
142 |
+
:return: [B, 1, sample_len]
|
143 |
+
"""
|
144 |
+
|
145 |
+
F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
|
146 |
+
for i in range(self.harmonic_num + 1):
|
147 |
+
F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
|
148 |
+
|
149 |
+
theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
|
150 |
+
u_dist = Uniform(low=-np.pi, high=np.pi)
|
151 |
+
phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
|
152 |
+
phase_vec[:, 0, :] = 0
|
153 |
+
|
154 |
+
# generate sine waveforms
|
155 |
+
sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)
|
156 |
+
|
157 |
+
# generate uv signal
|
158 |
+
uv = self._f02uv(f0)
|
159 |
+
|
160 |
+
# noise: for unvoiced should be similar to sine_amp
|
161 |
+
# std = self.sine_amp/3 -> max value ~ self.sine_amp
|
162 |
+
# . for voiced regions is self.noise_std
|
163 |
+
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
|
164 |
+
noise = noise_amp * torch.randn_like(sine_waves)
|
165 |
+
|
166 |
+
# first: set the unvoiced part to 0 by uv
|
167 |
+
# then: additive noise
|
168 |
+
sine_waves = sine_waves * uv + noise
|
169 |
+
return sine_waves, uv, noise
|
170 |
+
|
171 |
+
|
172 |
+
class SourceModuleHnNSF(torch.nn.Module):
|
173 |
+
""" SourceModule for hn-nsf
|
174 |
+
SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
|
175 |
+
add_noise_std=0.003, voiced_threshod=0)
|
176 |
+
sampling_rate: sampling_rate in Hz
|
177 |
+
harmonic_num: number of harmonic above F0 (default: 0)
|
178 |
+
sine_amp: amplitude of sine source signal (default: 0.1)
|
179 |
+
add_noise_std: std of additive Gaussian noise (default: 0.003)
|
180 |
+
note that amplitude of noise in unvoiced is decided
|
181 |
+
by sine_amp
|
182 |
+
voiced_threshold: threhold to set U/V given F0 (default: 0)
|
183 |
+
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
184 |
+
F0_sampled (batchsize, length, 1)
|
185 |
+
Sine_source (batchsize, length, 1)
|
186 |
+
noise_source (batchsize, length 1)
|
187 |
+
uv (batchsize, length, 1)
|
188 |
+
"""
|
189 |
+
|
190 |
+
def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
|
191 |
+
add_noise_std=0.003, voiced_threshod=0):
|
192 |
+
super(SourceModuleHnNSF, self).__init__()
|
193 |
+
|
194 |
+
self.sine_amp = sine_amp
|
195 |
+
self.noise_std = add_noise_std
|
196 |
+
|
197 |
+
# to produce sine waveforms
|
198 |
+
self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
|
199 |
+
sine_amp, add_noise_std, voiced_threshod)
|
200 |
+
|
201 |
+
# to merge source harmonics into a single excitation
|
202 |
+
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
|
203 |
+
self.l_tanh = torch.nn.Tanh()
|
204 |
+
|
205 |
+
def forward(self, x):
|
206 |
+
"""
|
207 |
+
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
208 |
+
F0_sampled (batchsize, length, 1)
|
209 |
+
Sine_source (batchsize, length, 1)
|
210 |
+
noise_source (batchsize, length 1)
|
211 |
+
"""
|
212 |
+
# source for harmonic branch
|
213 |
+
with torch.no_grad():
|
214 |
+
sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
|
215 |
+
sine_wavs = sine_wavs.transpose(1, 2)
|
216 |
+
uv = uv.transpose(1, 2)
|
217 |
+
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
|
218 |
+
|
219 |
+
# source for noise branch, in the same shape as uv
|
220 |
+
noise = torch.randn_like(uv) * self.sine_amp / 3
|
221 |
+
return sine_merge, noise, uv
|
222 |
+
|
223 |
+
|
224 |
+
class HiFTGenerator(nn.Module):
|
225 |
+
"""
|
226 |
+
HiFTNet Generator: Neural Source Filter + ISTFTNet
|
227 |
+
https://arxiv.org/abs/2309.09493
|
228 |
+
"""
|
229 |
+
def __init__(
|
230 |
+
self,
|
231 |
+
in_channels: int = 80,
|
232 |
+
base_channels: int = 512,
|
233 |
+
nb_harmonics: int = 8,
|
234 |
+
sampling_rate: int = 22050,
|
235 |
+
nsf_alpha: float = 0.1,
|
236 |
+
nsf_sigma: float = 0.003,
|
237 |
+
nsf_voiced_threshold: float = 10,
|
238 |
+
upsample_rates: List[int] = [8, 8],
|
239 |
+
upsample_kernel_sizes: List[int] = [16, 16],
|
240 |
+
istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
|
241 |
+
resblock_kernel_sizes: List[int] = [3, 7, 11],
|
242 |
+
resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
243 |
+
source_resblock_kernel_sizes: List[int] = [7, 11],
|
244 |
+
source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
|
245 |
+
lrelu_slope: float = 0.1,
|
246 |
+
audio_limit: float = 0.99,
|
247 |
+
f0_predictor: torch.nn.Module = None,
|
248 |
+
):
|
249 |
+
super(HiFTGenerator, self).__init__()
|
250 |
+
|
251 |
+
self.out_channels = 1
|
252 |
+
self.nb_harmonics = nb_harmonics
|
253 |
+
self.sampling_rate = sampling_rate
|
254 |
+
self.istft_params = istft_params
|
255 |
+
self.lrelu_slope = lrelu_slope
|
256 |
+
self.audio_limit = audio_limit
|
257 |
+
|
258 |
+
self.num_kernels = len(resblock_kernel_sizes)
|
259 |
+
self.num_upsamples = len(upsample_rates)
|
260 |
+
self.m_source = SourceModuleHnNSF(
|
261 |
+
sampling_rate=sampling_rate,
|
262 |
+
upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
|
263 |
+
harmonic_num=nb_harmonics,
|
264 |
+
sine_amp=nsf_alpha,
|
265 |
+
add_noise_std=nsf_sigma,
|
266 |
+
voiced_threshod=nsf_voiced_threshold)
|
267 |
+
self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
|
268 |
+
|
269 |
+
self.conv_pre = weight_norm(
|
270 |
+
Conv1d(in_channels, base_channels, 7, 1, padding=3)
|
271 |
+
)
|
272 |
+
|
273 |
+
# Up
|
274 |
+
self.ups = nn.ModuleList()
|
275 |
+
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
276 |
+
self.ups.append(
|
277 |
+
weight_norm(
|
278 |
+
ConvTranspose1d(
|
279 |
+
base_channels // (2**i),
|
280 |
+
base_channels // (2**(i + 1)),
|
281 |
+
k,
|
282 |
+
u,
|
283 |
+
padding=(k - u) // 2,
|
284 |
+
)
|
285 |
+
)
|
286 |
+
)
|
287 |
+
|
288 |
+
# Down
|
289 |
+
self.source_downs = nn.ModuleList()
|
290 |
+
self.source_resblocks = nn.ModuleList()
|
291 |
+
downsample_rates = [1] + upsample_rates[::-1][:-1]
|
292 |
+
downsample_cum_rates = np.cumprod(downsample_rates)
|
293 |
+
for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
|
294 |
+
if u == 1:
|
295 |
+
self.source_downs.append(
|
296 |
+
Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
|
297 |
+
)
|
298 |
+
else:
|
299 |
+
self.source_downs.append(
|
300 |
+
Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
|
301 |
+
)
|
302 |
+
|
303 |
+
self.source_resblocks.append(
|
304 |
+
ResBlock(base_channels // (2 ** (i + 1)), k, d)
|
305 |
+
)
|
306 |
+
|
307 |
+
self.resblocks = nn.ModuleList()
|
308 |
+
for i in range(len(self.ups)):
|
309 |
+
ch = base_channels // (2**(i + 1))
|
310 |
+
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
311 |
+
self.resblocks.append(ResBlock(ch, k, d))
|
312 |
+
|
313 |
+
self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
|
314 |
+
self.ups.apply(init_weights)
|
315 |
+
self.conv_post.apply(init_weights)
|
316 |
+
self.reflection_pad = nn.ReflectionPad1d((1, 0))
|
317 |
+
self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
|
318 |
+
self.f0_predictor = f0_predictor
|
319 |
+
|
320 |
+
def remove_weight_norm(self):
|
321 |
+
print('Removing weight norm...')
|
322 |
+
for l in self.ups:
|
323 |
+
remove_weight_norm(l)
|
324 |
+
for l in self.resblocks:
|
325 |
+
l.remove_weight_norm()
|
326 |
+
remove_weight_norm(self.conv_pre)
|
327 |
+
remove_weight_norm(self.conv_post)
|
328 |
+
self.m_source.remove_weight_norm()
|
329 |
+
for l in self.source_downs:
|
330 |
+
remove_weight_norm(l)
|
331 |
+
for l in self.source_resblocks:
|
332 |
+
l.remove_weight_norm()
|
333 |
+
|
334 |
+
def _stft(self, x):
|
335 |
+
spec = torch.stft(
|
336 |
+
x,
|
337 |
+
self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
|
338 |
+
return_complex=True)
|
339 |
+
spec = torch.view_as_real(spec) # [B, F, TT, 2]
|
340 |
+
return spec[..., 0], spec[..., 1]
|
341 |
+
|
342 |
+
def _istft(self, magnitude, phase):
|
343 |
+
magnitude = torch.clip(magnitude, max=1e2)
|
344 |
+
real = magnitude * torch.cos(phase)
|
345 |
+
img = magnitude * torch.sin(phase)
|
346 |
+
inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
|
347 |
+
self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
|
348 |
+
return inverse_transform
|
349 |
+
|
350 |
+
def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
|
351 |
+
s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
|
352 |
+
s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
|
353 |
+
|
354 |
+
x = self.conv_pre(x)
|
355 |
+
for i in range(self.num_upsamples):
|
356 |
+
x = F.leaky_relu(x, self.lrelu_slope)
|
357 |
+
x = self.ups[i](x)
|
358 |
+
|
359 |
+
if i == self.num_upsamples - 1:
|
360 |
+
x = self.reflection_pad(x)
|
361 |
+
|
362 |
+
# fusion
|
363 |
+
si = self.source_downs[i](s_stft)
|
364 |
+
si = self.source_resblocks[i](si)
|
365 |
+
x = x + si
|
366 |
+
|
367 |
+
xs = None
|
368 |
+
for j in range(self.num_kernels):
|
369 |
+
if xs is None:
|
370 |
+
xs = self.resblocks[i * self.num_kernels + j](x)
|
371 |
+
else:
|
372 |
+
xs += self.resblocks[i * self.num_kernels + j](x)
|
373 |
+
x = xs / self.num_kernels
|
374 |
+
|
375 |
+
x = F.leaky_relu(x)
|
376 |
+
x = self.conv_post(x)
|
377 |
+
magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
|
378 |
+
phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy
|
379 |
+
|
380 |
+
x = self._istft(magnitude, phase)
|
381 |
+
x = torch.clamp(x, -self.audio_limit, self.audio_limit)
|
382 |
+
return x
|
383 |
+
|
384 |
+
def forward(
|
385 |
+
self,
|
386 |
+
batch: dict,
|
387 |
+
device: torch.device,
|
388 |
+
) -> Dict[str, Optional[torch.Tensor]]:
|
389 |
+
speech_feat = batch['speech_feat'].transpose(1, 2).to(device)
|
390 |
+
# mel->f0
|
391 |
+
f0 = self.f0_predictor(speech_feat)
|
392 |
+
# f0->source
|
393 |
+
s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
394 |
+
s, _, _ = self.m_source(s)
|
395 |
+
s = s.transpose(1, 2)
|
396 |
+
# mel+source->speech
|
397 |
+
generated_speech = self.decode(x=speech_feat, s=s)
|
398 |
+
return generated_speech, f0
|
399 |
+
|
400 |
+
@torch.inference_mode()
|
401 |
+
def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
|
402 |
+
# mel->f0
|
403 |
+
f0 = self.f0_predictor(speech_feat)
|
404 |
+
# f0->source
|
405 |
+
s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
406 |
+
s, _, _ = self.m_source(s)
|
407 |
+
s = s.transpose(1, 2)
|
408 |
+
# use cache_source to avoid glitch
|
409 |
+
if cache_source.shape[2] != 0:
|
410 |
+
s[:, :, :cache_source.shape[2]] = cache_source
|
411 |
+
generated_speech = self.decode(x=speech_feat, s=s)
|
412 |
+
return generated_speech, s
|
tts/cosyvoice/hifigan/hifigan.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, Optional
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from matcha.hifigan.models import feature_loss, generator_loss, discriminator_loss
|
6 |
+
from cosyvoice.utils.losses import tpr_loss, mel_loss
|
7 |
+
|
8 |
+
|
9 |
+
class HiFiGan(nn.Module):
|
10 |
+
def __init__(self, generator, discriminator, mel_spec_transform,
|
11 |
+
multi_mel_spectral_recon_loss_weight=45, feat_match_loss_weight=2.0,
|
12 |
+
tpr_loss_weight=1.0, tpr_loss_tau=0.04):
|
13 |
+
super(HiFiGan, self).__init__()
|
14 |
+
self.generator = generator
|
15 |
+
self.discriminator = discriminator
|
16 |
+
self.mel_spec_transform = mel_spec_transform
|
17 |
+
self.multi_mel_spectral_recon_loss_weight = multi_mel_spectral_recon_loss_weight
|
18 |
+
self.feat_match_loss_weight = feat_match_loss_weight
|
19 |
+
self.tpr_loss_weight = tpr_loss_weight
|
20 |
+
self.tpr_loss_tau = tpr_loss_tau
|
21 |
+
|
22 |
+
def forward(
|
23 |
+
self,
|
24 |
+
batch: dict,
|
25 |
+
device: torch.device,
|
26 |
+
) -> Dict[str, Optional[torch.Tensor]]:
|
27 |
+
if batch['turn'] == 'generator':
|
28 |
+
return self.forward_generator(batch, device)
|
29 |
+
else:
|
30 |
+
return self.forward_discriminator(batch, device)
|
31 |
+
|
32 |
+
def forward_generator(self, batch, device):
|
33 |
+
real_speech = batch['speech'].to(device)
|
34 |
+
pitch_feat = batch['pitch_feat'].to(device)
|
35 |
+
# 1. calculate generator outputs
|
36 |
+
generated_speech, generated_f0 = self.generator(batch, device)
|
37 |
+
# 2. calculate discriminator outputs
|
38 |
+
y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech)
|
39 |
+
# 3. calculate generator losses, feature loss, mel loss, tpr losses [Optional]
|
40 |
+
loss_gen, _ = generator_loss(y_d_gs)
|
41 |
+
loss_fm = feature_loss(fmap_rs, fmap_gs)
|
42 |
+
loss_mel = mel_loss(real_speech, generated_speech, self.mel_spec_transform)
|
43 |
+
if self.tpr_loss_weight != 0:
|
44 |
+
loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau)
|
45 |
+
else:
|
46 |
+
loss_tpr = torch.zeros(1).to(device)
|
47 |
+
loss_f0 = F.l1_loss(generated_f0, pitch_feat)
|
48 |
+
loss = loss_gen + self.feat_match_loss_weight * loss_fm + \
|
49 |
+
self.multi_mel_spectral_recon_loss_weight * loss_mel + \
|
50 |
+
self.tpr_loss_weight * loss_tpr + loss_f0
|
51 |
+
return {'loss': loss, 'loss_gen': loss_gen, 'loss_fm': loss_fm, 'loss_mel': loss_mel, 'loss_tpr': loss_tpr, 'loss_f0': loss_f0}
|
52 |
+
|
53 |
+
def forward_discriminator(self, batch, device):
|
54 |
+
real_speech = batch['speech'].to(device)
|
55 |
+
# 1. calculate generator outputs
|
56 |
+
with torch.no_grad():
|
57 |
+
generated_speech, generated_f0 = self.generator(batch, device)
|
58 |
+
# 2. calculate discriminator outputs
|
59 |
+
y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech)
|
60 |
+
# 3. calculate discriminator losses, tpr losses [Optional]
|
61 |
+
loss_disc, _, _ = discriminator_loss(y_d_rs, y_d_gs)
|
62 |
+
if self.tpr_loss_weight != 0:
|
63 |
+
loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau)
|
64 |
+
else:
|
65 |
+
loss_tpr = torch.zeros(1).to(device)
|
66 |
+
loss = loss_disc + self.tpr_loss_weight * loss_tpr
|
67 |
+
return {'loss': loss, 'loss_disc': loss_disc, 'loss_tpr': loss_tpr}
|
tts/cosyvoice/llm/llm.py
ADDED
@@ -0,0 +1,434 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import Dict, Optional, Callable, List, Generator
|
15 |
+
import torch
|
16 |
+
from torch import nn
|
17 |
+
import torch.nn.functional as F
|
18 |
+
from transformers import Qwen2ForCausalLM
|
19 |
+
from torch.nn.utils.rnn import pad_sequence, unpad_sequence
|
20 |
+
from cosyvoice.utils.common import IGNORE_ID
|
21 |
+
from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
|
22 |
+
from cosyvoice.utils.common import th_accuracy
|
23 |
+
from cosyvoice.utils.file_utils import logging
|
24 |
+
|
25 |
+
|
26 |
+
class TransformerLM(torch.nn.Module):
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
text_encoder_input_size: int,
|
30 |
+
llm_input_size: int,
|
31 |
+
llm_output_size: int,
|
32 |
+
text_token_size: int,
|
33 |
+
speech_token_size: int,
|
34 |
+
text_encoder: torch.nn.Module,
|
35 |
+
llm: torch.nn.Module,
|
36 |
+
sampling: Callable,
|
37 |
+
length_normalized_loss: bool = True,
|
38 |
+
lsm_weight: float = 0.0,
|
39 |
+
spk_embed_dim: int = 192,
|
40 |
+
):
|
41 |
+
super().__init__()
|
42 |
+
self.llm_input_size = llm_input_size
|
43 |
+
self.speech_token_size = speech_token_size
|
44 |
+
# 1. build text token inputs related modules
|
45 |
+
self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size)
|
46 |
+
self.text_encoder = text_encoder
|
47 |
+
self.text_encoder_affine_layer = nn.Linear(
|
48 |
+
self.text_encoder.output_size(),
|
49 |
+
llm_input_size
|
50 |
+
)
|
51 |
+
|
52 |
+
# 2. build speech token language model related modules
|
53 |
+
self.sos_eos = 0
|
54 |
+
self.task_id = 1
|
55 |
+
self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
|
56 |
+
self.llm = llm
|
57 |
+
self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
|
58 |
+
self.criterion_ce = LabelSmoothingLoss(
|
59 |
+
size=speech_token_size + 1,
|
60 |
+
padding_idx=IGNORE_ID,
|
61 |
+
smoothing=lsm_weight,
|
62 |
+
normalize_length=length_normalized_loss,
|
63 |
+
)
|
64 |
+
|
65 |
+
# 3. [Optional] build speech token related modules
|
66 |
+
self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
|
67 |
+
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)
|
68 |
+
|
69 |
+
# 4. sampling method
|
70 |
+
self.sampling = sampling
|
71 |
+
|
72 |
+
def encode(
|
73 |
+
self,
|
74 |
+
text: torch.Tensor,
|
75 |
+
text_lengths: torch.Tensor,
|
76 |
+
):
|
77 |
+
encoder_out, encoder_mask = self.text_encoder(text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1)
|
78 |
+
encoder_out_lens = encoder_mask.squeeze(1).sum(1)
|
79 |
+
encoder_out = self.text_encoder_affine_layer(encoder_out)
|
80 |
+
return encoder_out, encoder_out_lens
|
81 |
+
|
82 |
+
def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
|
83 |
+
text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
|
84 |
+
speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
|
85 |
+
lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
|
86 |
+
for i in range(len(text_token))]
|
87 |
+
lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
|
88 |
+
lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
|
89 |
+
return lm_input, lm_input_len
|
90 |
+
|
91 |
+
def forward(
|
92 |
+
self,
|
93 |
+
batch: dict,
|
94 |
+
device: torch.device,
|
95 |
+
) -> Dict[str, Optional[torch.Tensor]]:
|
96 |
+
"""
|
97 |
+
Args:
|
98 |
+
text: (B, L, D)
|
99 |
+
text_lengths: (B,)
|
100 |
+
audio: (B, T, N) or (B, T)
|
101 |
+
audio_lengths: (B,)
|
102 |
+
"""
|
103 |
+
text_token = batch['text_token'].to(device)
|
104 |
+
text_token_len = batch['text_token_len'].to(device)
|
105 |
+
speech_token = batch['speech_token'].to(device)
|
106 |
+
speech_token_len = batch['speech_token_len'].to(device)
|
107 |
+
embedding = batch['embedding'].to(device)
|
108 |
+
|
109 |
+
# 1. prepare llm_target
|
110 |
+
lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
|
111 |
+
[self.speech_token_size]) for i in range(text_token.size(0))]
|
112 |
+
lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)
|
113 |
+
|
114 |
+
# 1. encode text_token
|
115 |
+
text_token = self.text_embedding(text_token)
|
116 |
+
text_token, text_token_len = self.encode(text_token, text_token_len)
|
117 |
+
|
118 |
+
# 2. embedding projection
|
119 |
+
embedding = F.normalize(embedding, dim=1)
|
120 |
+
embedding = self.spk_embed_affine_layer(embedding)
|
121 |
+
embedding = embedding.unsqueeze(1)
|
122 |
+
|
123 |
+
# 3. eos and task_id
|
124 |
+
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
|
125 |
+
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
|
126 |
+
|
127 |
+
# 4. encode speech_token
|
128 |
+
speech_token = self.speech_embedding(speech_token)
|
129 |
+
|
130 |
+
# 5. unpad and pad
|
131 |
+
lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len,
|
132 |
+
task_id_emb, speech_token, speech_token_len)
|
133 |
+
|
134 |
+
# 6. run lm forward
|
135 |
+
lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
|
136 |
+
logits = self.llm_decoder(lm_output)
|
137 |
+
loss = self.criterion_ce(logits, lm_target)
|
138 |
+
acc = th_accuracy(logits.view(-1, self.speech_token_size + 1), lm_target, ignore_label=IGNORE_ID)
|
139 |
+
return {'loss': loss, 'acc': acc}
|
140 |
+
|
141 |
+
def sampling_ids(
|
142 |
+
self,
|
143 |
+
weighted_scores: torch.Tensor,
|
144 |
+
decoded_tokens: List,
|
145 |
+
sampling: int,
|
146 |
+
ignore_eos: bool = True,
|
147 |
+
):
|
148 |
+
num_trials, max_trials = 0, 100
|
149 |
+
while True:
|
150 |
+
top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
|
151 |
+
if (not ignore_eos) or (self.speech_token_size not in top_ids):
|
152 |
+
break
|
153 |
+
num_trials += 1
|
154 |
+
if num_trials > max_trials:
|
155 |
+
raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials))
|
156 |
+
return top_ids
|
157 |
+
|
158 |
+
@torch.inference_mode()
|
159 |
+
def inference(
|
160 |
+
self,
|
161 |
+
text: torch.Tensor,
|
162 |
+
text_len: torch.Tensor,
|
163 |
+
prompt_text: torch.Tensor,
|
164 |
+
prompt_text_len: torch.Tensor,
|
165 |
+
prompt_speech_token: torch.Tensor,
|
166 |
+
prompt_speech_token_len: torch.Tensor,
|
167 |
+
embedding: torch.Tensor,
|
168 |
+
sampling: int = 25,
|
169 |
+
max_token_text_ratio: float = 20,
|
170 |
+
min_token_text_ratio: float = 2,
|
171 |
+
) -> Generator[torch.Tensor, None, None]:
|
172 |
+
if self.fp16 is True:
|
173 |
+
embedding = embedding.half()
|
174 |
+
|
175 |
+
device = text.device
|
176 |
+
text = torch.concat([prompt_text, text], dim=1)
|
177 |
+
text_len += prompt_text_len
|
178 |
+
text = self.text_embedding(text)
|
179 |
+
|
180 |
+
# 1. encode text
|
181 |
+
text, text_len = self.encode(text, text_len)
|
182 |
+
|
183 |
+
# 2. encode embedding
|
184 |
+
if embedding.shape[0] != 0:
|
185 |
+
embedding = F.normalize(embedding, dim=1)
|
186 |
+
embedding = self.spk_embed_affine_layer(embedding)
|
187 |
+
embedding = embedding.unsqueeze(dim=1)
|
188 |
+
else:
|
189 |
+
embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype)
|
190 |
+
|
191 |
+
# 3. concat llm_input
|
192 |
+
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
|
193 |
+
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
|
194 |
+
if prompt_speech_token_len != 0:
|
195 |
+
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
|
196 |
+
else:
|
197 |
+
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
|
198 |
+
lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
|
199 |
+
|
200 |
+
# 4. cal min/max_length
|
201 |
+
min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
|
202 |
+
max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
|
203 |
+
|
204 |
+
# 5. step by step decode
|
205 |
+
out_tokens = []
|
206 |
+
offset = 0
|
207 |
+
att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
|
208 |
+
for i in range(max_len):
|
209 |
+
y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1,
|
210 |
+
att_cache=att_cache, cnn_cache=cnn_cache,
|
211 |
+
att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
|
212 |
+
device=lm_input.device)).to(torch.bool))
|
213 |
+
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
|
214 |
+
# force continue decode first token
|
215 |
+
if i == 0:
|
216 |
+
logp[:, self.speech_token_size] = -float('inf')
|
217 |
+
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
|
218 |
+
if top_ids == self.speech_token_size:
|
219 |
+
break
|
220 |
+
# in stream mode, yield token one by one
|
221 |
+
yield top_ids
|
222 |
+
out_tokens.append(top_ids)
|
223 |
+
offset += lm_input.size(1)
|
224 |
+
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
|
225 |
+
|
226 |
+
|
227 |
+
class Qwen2Encoder(torch.nn.Module):
|
228 |
+
def __init__(self, pretrain_path):
|
229 |
+
super().__init__()
|
230 |
+
self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path)
|
231 |
+
|
232 |
+
def forward_one_step(self, xs, masks, cache=None):
|
233 |
+
input_masks = masks[:, -1, :]
|
234 |
+
outs = self.model(
|
235 |
+
inputs_embeds=xs,
|
236 |
+
attention_mask=input_masks,
|
237 |
+
output_hidden_states=True,
|
238 |
+
return_dict=True,
|
239 |
+
use_cache=True,
|
240 |
+
past_key_values=cache,
|
241 |
+
)
|
242 |
+
xs = outs.hidden_states[-1]
|
243 |
+
new_cache = outs.past_key_values
|
244 |
+
return xs, new_cache
|
245 |
+
|
246 |
+
|
247 |
+
class Qwen2LM(TransformerLM):
|
248 |
+
def __init__(
|
249 |
+
self,
|
250 |
+
llm_input_size: int,
|
251 |
+
llm_output_size: int,
|
252 |
+
speech_token_size: int,
|
253 |
+
llm: torch.nn.Module,
|
254 |
+
sampling: Callable,
|
255 |
+
length_normalized_loss: bool = True,
|
256 |
+
lsm_weight: float = 0.0,
|
257 |
+
mix_ratio: List[int] = [5, 15],
|
258 |
+
):
|
259 |
+
torch.nn.Module.__init__(self)
|
260 |
+
self.llm_input_size = llm_input_size
|
261 |
+
self.llm_output_size = llm_output_size
|
262 |
+
self.speech_token_size = speech_token_size
|
263 |
+
|
264 |
+
# 2. build speech token language model related modules
|
265 |
+
self.sos_eos = 0
|
266 |
+
self.task_id = 1
|
267 |
+
self.fill_token = 2
|
268 |
+
|
269 |
+
self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
|
270 |
+
self.llm = llm
|
271 |
+
self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 3)
|
272 |
+
self.criterion_ce = LabelSmoothingLoss(
|
273 |
+
size=speech_token_size + 3,
|
274 |
+
padding_idx=IGNORE_ID,
|
275 |
+
smoothing=lsm_weight,
|
276 |
+
normalize_length=length_normalized_loss,
|
277 |
+
)
|
278 |
+
|
279 |
+
# 3. [Optional] build speech token related modules
|
280 |
+
self.speech_embedding = torch.nn.Embedding(speech_token_size + 3, llm_input_size)
|
281 |
+
|
282 |
+
# 4. sampling method
|
283 |
+
self.sampling = sampling
|
284 |
+
self.mix_ratio = mix_ratio
|
285 |
+
|
286 |
+
@torch.inference_mode()
|
287 |
+
def inference(
|
288 |
+
self,
|
289 |
+
text: torch.Tensor,
|
290 |
+
text_len: torch.Tensor,
|
291 |
+
prompt_text: torch.Tensor,
|
292 |
+
prompt_text_len: torch.Tensor,
|
293 |
+
prompt_speech_token: torch.Tensor,
|
294 |
+
prompt_speech_token_len: torch.Tensor,
|
295 |
+
embedding: torch.Tensor,
|
296 |
+
sampling: int = 25,
|
297 |
+
max_token_text_ratio: float = 20,
|
298 |
+
min_token_text_ratio: float = 2,
|
299 |
+
) -> Generator[torch.Tensor, None, None]:
|
300 |
+
device = text.device
|
301 |
+
text = torch.concat([prompt_text, text], dim=1)
|
302 |
+
text_len += prompt_text_len
|
303 |
+
text = self.llm.model.model.embed_tokens(text)
|
304 |
+
|
305 |
+
# 3. concat llm_input
|
306 |
+
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
|
307 |
+
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
|
308 |
+
if prompt_speech_token_len != 0:
|
309 |
+
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
|
310 |
+
else:
|
311 |
+
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
|
312 |
+
lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
|
313 |
+
|
314 |
+
# 4. cal min/max_length
|
315 |
+
min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
|
316 |
+
max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
|
317 |
+
|
318 |
+
# 5. step by step decode
|
319 |
+
out_tokens = []
|
320 |
+
cache = None
|
321 |
+
for i in range(max_len):
|
322 |
+
y_pred, cache = self.llm.forward_one_step(lm_input,
|
323 |
+
masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
|
324 |
+
cache=cache)
|
325 |
+
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
|
326 |
+
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
|
327 |
+
if top_ids == self.speech_token_size:
|
328 |
+
break
|
329 |
+
if top_ids > self.speech_token_size:
|
330 |
+
continue
|
331 |
+
# in stream mode, yield token one by one
|
332 |
+
yield top_ids
|
333 |
+
out_tokens.append(top_ids)
|
334 |
+
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
|
335 |
+
|
336 |
+
@torch.inference_mode()
|
337 |
+
def inference_bistream(
|
338 |
+
self,
|
339 |
+
text: Generator,
|
340 |
+
prompt_text: torch.Tensor,
|
341 |
+
prompt_text_len: torch.Tensor,
|
342 |
+
prompt_speech_token: torch.Tensor,
|
343 |
+
prompt_speech_token_len: torch.Tensor,
|
344 |
+
embedding: torch.Tensor,
|
345 |
+
sampling: int = 25,
|
346 |
+
max_token_text_ratio: float = 20,
|
347 |
+
min_token_text_ratio: float = 2,
|
348 |
+
) -> Generator[torch.Tensor, None, None]:
|
349 |
+
|
350 |
+
device = prompt_text.device
|
351 |
+
# 1. prepare input
|
352 |
+
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
|
353 |
+
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
|
354 |
+
if prompt_speech_token_len != 0:
|
355 |
+
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
|
356 |
+
else:
|
357 |
+
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
|
358 |
+
lm_input = torch.concat([sos_eos_emb], dim=1)
|
359 |
+
|
360 |
+
# 2. iterate text
|
361 |
+
out_tokens = []
|
362 |
+
cache = None
|
363 |
+
# NOTE init prompt_text as text_cache as it is basically impossible prompt_speech_token/prompt_text < 15/5
|
364 |
+
text_cache = self.llm.model.model.embed_tokens(prompt_text)
|
365 |
+
next_fill_index = -1
|
366 |
+
for this_text in text:
|
367 |
+
text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
|
368 |
+
# prompt_speech_token_emb not empty, try append to lm_input
|
369 |
+
while prompt_speech_token_emb.size(1) != 0:
|
370 |
+
if text_cache.size(1) >= self.mix_ratio[0]:
|
371 |
+
lm_input_text, lm_input_speech = text_cache[:, :self.mix_ratio[0]], prompt_speech_token_emb[:, :self.mix_ratio[1]]
|
372 |
+
logging.info('append {} text token {} speech token'.format(lm_input_text.size(1), lm_input_speech.size(1)))
|
373 |
+
lm_input = torch.concat([lm_input, lm_input_text, lm_input_speech], dim=1)
|
374 |
+
text_cache, prompt_speech_token_emb = text_cache[:, self.mix_ratio[0]:], prompt_speech_token_emb[:, self.mix_ratio[1]:]
|
375 |
+
else:
|
376 |
+
logging.info('not enough text token to decode, wait for more')
|
377 |
+
break
|
378 |
+
# no prompt_speech_token_emb remain, can decode some speech token
|
379 |
+
if prompt_speech_token_emb.size(1) == 0:
|
380 |
+
if (len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
|
381 |
+
logging.info('get fill token, need to append more text token')
|
382 |
+
if text_cache.size(1) >= self.mix_ratio[0]:
|
383 |
+
lm_input_text = text_cache[:, :self.mix_ratio[0]]
|
384 |
+
logging.info('append {} text token'.format(lm_input_text.size(1)))
|
385 |
+
if len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2:
|
386 |
+
lm_input = lm_input_text
|
387 |
+
else:
|
388 |
+
lm_input = torch.concat([lm_input, lm_input_text], dim=1)
|
389 |
+
text_cache = text_cache[:, self.mix_ratio[0]:]
|
390 |
+
else:
|
391 |
+
logging.info('not enough text token to decode, wait for more')
|
392 |
+
continue
|
393 |
+
while True:
|
394 |
+
seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
|
395 |
+
y_pred, cache = self.llm.forward_one_step(lm_input,
|
396 |
+
masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
|
397 |
+
cache=cache)
|
398 |
+
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
|
399 |
+
if next_fill_index != -1 and len(out_tokens) == next_fill_index:
|
400 |
+
top_ids = self.speech_token_size + 2
|
401 |
+
next_fill_index += (self.mix_ratio[1] + 1)
|
402 |
+
else:
|
403 |
+
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item()
|
404 |
+
if top_ids == self.speech_token_size + 2:
|
405 |
+
next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
|
406 |
+
logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
|
407 |
+
out_tokens.append(top_ids)
|
408 |
+
if top_ids >= self.speech_token_size:
|
409 |
+
if top_ids == self.speech_token_size + 2:
|
410 |
+
break
|
411 |
+
else:
|
412 |
+
raise ValueError('should not get token {}'.format(top_ids))
|
413 |
+
yield top_ids
|
414 |
+
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
|
415 |
+
|
416 |
+
# 3. final decode
|
417 |
+
lm_input = torch.concat([lm_input, text_cache, task_id_emb], dim=1)
|
418 |
+
logging.info('no more text token, decode until met eos')
|
419 |
+
while True:
|
420 |
+
seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
|
421 |
+
y_pred, cache = self.llm.forward_one_step(lm_input,
|
422 |
+
masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
|
423 |
+
cache=cache)
|
424 |
+
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
|
425 |
+
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False).item()
|
426 |
+
out_tokens.append(top_ids)
|
427 |
+
if top_ids >= self.speech_token_size:
|
428 |
+
if top_ids == self.speech_token_size:
|
429 |
+
break
|
430 |
+
else:
|
431 |
+
raise ValueError('should not get token {}'.format(top_ids))
|
432 |
+
# in stream mode, yield token one by one
|
433 |
+
yield top_ids
|
434 |
+
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
|
tts/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken
ADDED
The diff for this file is too large to render.
See raw diff
|
|
tts/cosyvoice/tokenizer/tokenizer.py
ADDED
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import os
|
3 |
+
from functools import lru_cache
|
4 |
+
from typing import Optional
|
5 |
+
import torch
|
6 |
+
from transformers import AutoTokenizer
|
7 |
+
from whisper.tokenizer import Tokenizer
|
8 |
+
|
9 |
+
import tiktoken
|
10 |
+
|
11 |
+
LANGUAGES = {
|
12 |
+
"en": "english",
|
13 |
+
"zh": "chinese",
|
14 |
+
"de": "german",
|
15 |
+
"es": "spanish",
|
16 |
+
"ru": "russian",
|
17 |
+
"ko": "korean",
|
18 |
+
"fr": "french",
|
19 |
+
"ja": "japanese",
|
20 |
+
"pt": "portuguese",
|
21 |
+
"tr": "turkish",
|
22 |
+
"pl": "polish",
|
23 |
+
"ca": "catalan",
|
24 |
+
"nl": "dutch",
|
25 |
+
"ar": "arabic",
|
26 |
+
"sv": "swedish",
|
27 |
+
"it": "italian",
|
28 |
+
"id": "indonesian",
|
29 |
+
"hi": "hindi",
|
30 |
+
"fi": "finnish",
|
31 |
+
"vi": "vietnamese",
|
32 |
+
"he": "hebrew",
|
33 |
+
"uk": "ukrainian",
|
34 |
+
"el": "greek",
|
35 |
+
"ms": "malay",
|
36 |
+
"cs": "czech",
|
37 |
+
"ro": "romanian",
|
38 |
+
"da": "danish",
|
39 |
+
"hu": "hungarian",
|
40 |
+
"ta": "tamil",
|
41 |
+
"no": "norwegian",
|
42 |
+
"th": "thai",
|
43 |
+
"ur": "urdu",
|
44 |
+
"hr": "croatian",
|
45 |
+
"bg": "bulgarian",
|
46 |
+
"lt": "lithuanian",
|
47 |
+
"la": "latin",
|
48 |
+
"mi": "maori",
|
49 |
+
"ml": "malayalam",
|
50 |
+
"cy": "welsh",
|
51 |
+
"sk": "slovak",
|
52 |
+
"te": "telugu",
|
53 |
+
"fa": "persian",
|
54 |
+
"lv": "latvian",
|
55 |
+
"bn": "bengali",
|
56 |
+
"sr": "serbian",
|
57 |
+
"az": "azerbaijani",
|
58 |
+
"sl": "slovenian",
|
59 |
+
"kn": "kannada",
|
60 |
+
"et": "estonian",
|
61 |
+
"mk": "macedonian",
|
62 |
+
"br": "breton",
|
63 |
+
"eu": "basque",
|
64 |
+
"is": "icelandic",
|
65 |
+
"hy": "armenian",
|
66 |
+
"ne": "nepali",
|
67 |
+
"mn": "mongolian",
|
68 |
+
"bs": "bosnian",
|
69 |
+
"kk": "kazakh",
|
70 |
+
"sq": "albanian",
|
71 |
+
"sw": "swahili",
|
72 |
+
"gl": "galician",
|
73 |
+
"mr": "marathi",
|
74 |
+
"pa": "punjabi",
|
75 |
+
"si": "sinhala",
|
76 |
+
"km": "khmer",
|
77 |
+
"sn": "shona",
|
78 |
+
"yo": "yoruba",
|
79 |
+
"so": "somali",
|
80 |
+
"af": "afrikaans",
|
81 |
+
"oc": "occitan",
|
82 |
+
"ka": "georgian",
|
83 |
+
"be": "belarusian",
|
84 |
+
"tg": "tajik",
|
85 |
+
"sd": "sindhi",
|
86 |
+
"gu": "gujarati",
|
87 |
+
"am": "amharic",
|
88 |
+
"yi": "yiddish",
|
89 |
+
"lo": "lao",
|
90 |
+
"uz": "uzbek",
|
91 |
+
"fo": "faroese",
|
92 |
+
"ht": "haitian creole",
|
93 |
+
"ps": "pashto",
|
94 |
+
"tk": "turkmen",
|
95 |
+
"nn": "nynorsk",
|
96 |
+
"mt": "maltese",
|
97 |
+
"sa": "sanskrit",
|
98 |
+
"lb": "luxembourgish",
|
99 |
+
"my": "myanmar",
|
100 |
+
"bo": "tibetan",
|
101 |
+
"tl": "tagalog",
|
102 |
+
"mg": "malagasy",
|
103 |
+
"as": "assamese",
|
104 |
+
"tt": "tatar",
|
105 |
+
"haw": "hawaiian",
|
106 |
+
"ln": "lingala",
|
107 |
+
"ha": "hausa",
|
108 |
+
"ba": "bashkir",
|
109 |
+
"jw": "javanese",
|
110 |
+
"su": "sundanese",
|
111 |
+
"yue": "cantonese",
|
112 |
+
"minnan": "minnan",
|
113 |
+
"wuyu": "wuyu",
|
114 |
+
"dialect": "dialect",
|
115 |
+
"zh/en": "zh/en",
|
116 |
+
"en/zh": "en/zh",
|
117 |
+
}
|
118 |
+
|
119 |
+
# language code lookup by name, with a few language aliases
|
120 |
+
TO_LANGUAGE_CODE = {
|
121 |
+
**{language: code for code, language in LANGUAGES.items()},
|
122 |
+
"burmese": "my",
|
123 |
+
"valencian": "ca",
|
124 |
+
"flemish": "nl",
|
125 |
+
"haitian": "ht",
|
126 |
+
"letzeburgesch": "lb",
|
127 |
+
"pushto": "ps",
|
128 |
+
"panjabi": "pa",
|
129 |
+
"moldavian": "ro",
|
130 |
+
"moldovan": "ro",
|
131 |
+
"sinhalese": "si",
|
132 |
+
"castilian": "es",
|
133 |
+
"mandarin": "zh",
|
134 |
+
}
|
135 |
+
|
136 |
+
AUDIO_EVENT = {
|
137 |
+
"ASR": "ASR",
|
138 |
+
"AED": "AED",
|
139 |
+
"SER": "SER",
|
140 |
+
"Speech": "Speech",
|
141 |
+
"/Speech": "/Speech",
|
142 |
+
"BGM": "BGM",
|
143 |
+
"/BGM": "/BGM",
|
144 |
+
"Laughter": "Laughter",
|
145 |
+
"/Laughter": "/Laughter",
|
146 |
+
"Applause": "Applause",
|
147 |
+
"/Applause": "/Applause",
|
148 |
+
}
|
149 |
+
|
150 |
+
EMOTION = {
|
151 |
+
"HAPPY": "HAPPY",
|
152 |
+
"SAD": "SAD",
|
153 |
+
"ANGRY": "ANGRY",
|
154 |
+
"NEUTRAL": "NEUTRAL",
|
155 |
+
}
|
156 |
+
|
157 |
+
TTS_Vocal_Token = {
|
158 |
+
"TTS/B": "TTS/B",
|
159 |
+
"TTS/O": "TTS/O",
|
160 |
+
"TTS/Q": "TTS/Q",
|
161 |
+
"TTS/A": "TTS/A",
|
162 |
+
"TTS/CO": "TTS/CO",
|
163 |
+
"TTS/CL": "TTS/CL",
|
164 |
+
"TTS/H": "TTS/H",
|
165 |
+
**{f"TTS/SP{i:02d}": f"TTS/SP{i:02d}" for i in range(1, 14)}
|
166 |
+
}
|
167 |
+
|
168 |
+
|
169 |
+
@lru_cache(maxsize=None)
|
170 |
+
def get_encoding(name: str = "gpt2", num_languages: int = 99):
|
171 |
+
vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
|
172 |
+
ranks = {
|
173 |
+
base64.b64decode(token): int(rank)
|
174 |
+
for token, rank in (line.split() for line in open(vocab_path) if line)
|
175 |
+
}
|
176 |
+
n_vocab = len(ranks)
|
177 |
+
special_tokens = {}
|
178 |
+
|
179 |
+
specials = [
|
180 |
+
"<|endoftext|>",
|
181 |
+
"<|startoftranscript|>",
|
182 |
+
*[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
|
183 |
+
*[f"<|{audio_event}|>" for audio_event in list(AUDIO_EVENT.keys())],
|
184 |
+
*[f"<|{emotion}|>" for emotion in list(EMOTION.keys())],
|
185 |
+
"<|translate|>",
|
186 |
+
"<|transcribe|>",
|
187 |
+
"<|startoflm|>",
|
188 |
+
"<|startofprev|>",
|
189 |
+
"<|nospeech|>",
|
190 |
+
"<|notimestamps|>",
|
191 |
+
*[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 31)], # register special tokens for ASR
|
192 |
+
*[f"<|{tts}|>" for tts in list(TTS_Vocal_Token.keys())], # register special tokens for TTS
|
193 |
+
*[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
|
194 |
+
]
|
195 |
+
|
196 |
+
for token in specials:
|
197 |
+
special_tokens[token] = n_vocab
|
198 |
+
n_vocab += 1
|
199 |
+
|
200 |
+
return tiktoken.Encoding(
|
201 |
+
name=os.path.basename(vocab_path),
|
202 |
+
explicit_n_vocab=n_vocab,
|
203 |
+
pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
|
204 |
+
mergeable_ranks=ranks,
|
205 |
+
special_tokens=special_tokens,
|
206 |
+
)
|
207 |
+
|
208 |
+
|
209 |
+
@lru_cache(maxsize=None)
|
210 |
+
def get_tokenizer(
|
211 |
+
multilingual: bool,
|
212 |
+
*,
|
213 |
+
num_languages: int = 99,
|
214 |
+
language: Optional[str] = None,
|
215 |
+
task: Optional[str] = None, # Literal["transcribe", "translate", None]
|
216 |
+
) -> Tokenizer:
|
217 |
+
if language is not None:
|
218 |
+
language = language.lower()
|
219 |
+
if language not in LANGUAGES:
|
220 |
+
if language in TO_LANGUAGE_CODE:
|
221 |
+
language = TO_LANGUAGE_CODE[language]
|
222 |
+
else:
|
223 |
+
raise ValueError(f"Unsupported language: {language}")
|
224 |
+
|
225 |
+
if multilingual:
|
226 |
+
encoding_name = "multilingual_zh_ja_yue_char_del"
|
227 |
+
language = language or "en"
|
228 |
+
task = task or "transcribe"
|
229 |
+
else:
|
230 |
+
encoding_name = "gpt2"
|
231 |
+
language = None
|
232 |
+
task = None
|
233 |
+
|
234 |
+
encoding = get_encoding(name=encoding_name, num_languages=num_languages)
|
235 |
+
|
236 |
+
return Tokenizer(
|
237 |
+
encoding=encoding, num_languages=num_languages, language=language, task=task
|
238 |
+
)
|
239 |
+
|
240 |
+
|
241 |
+
class QwenTokenizer():
|
242 |
+
def __init__(self, token_path, skip_special_tokens=True):
|
243 |
+
super().__init__()
|
244 |
+
# NOTE: non-chat model, all these special tokens keep randomly initialized.
|
245 |
+
special_tokens = {
|
246 |
+
'eos_token': '<|endoftext|>',
|
247 |
+
'pad_token': '<|endoftext|>',
|
248 |
+
'additional_special_tokens': [
|
249 |
+
'<|im_start|>', '<|im_end|>', '<|endofprompt|>',
|
250 |
+
'[breath]', '<strong>', '</strong>', '[noise]',
|
251 |
+
'[laughter]', '[cough]', '[clucking]', '[accent]',
|
252 |
+
'[quick_breath]',
|
253 |
+
"<laughter>", "</laughter>",
|
254 |
+
"[hissing]", "[sigh]", "[vocalized-noise]",
|
255 |
+
"[lipsmack]", "[mn]"
|
256 |
+
]
|
257 |
+
}
|
258 |
+
self.special_tokens = special_tokens
|
259 |
+
self.tokenizer = AutoTokenizer.from_pretrained(token_path)
|
260 |
+
self.tokenizer.add_special_tokens(special_tokens)
|
261 |
+
self.skip_special_tokens = skip_special_tokens
|
262 |
+
|
263 |
+
def encode(self, text, **kwargs):
|
264 |
+
tokens = self.tokenizer([text], return_tensors="pt")
|
265 |
+
tokens = tokens["input_ids"][0].cpu().tolist()
|
266 |
+
return tokens
|
267 |
+
|
268 |
+
def decode(self, tokens):
|
269 |
+
tokens = torch.tensor(tokens, dtype=torch.int64)
|
270 |
+
text = self.tokenizer.batch_decode([tokens], skip_special_tokens=self.skip_special_tokens)[0]
|
271 |
+
return text
|
272 |
+
|
273 |
+
|
274 |
+
@lru_cache(maxsize=None)
|
275 |
+
def get_qwen_tokenizer(
|
276 |
+
token_path: str,
|
277 |
+
skip_special_tokens: bool
|
278 |
+
) -> QwenTokenizer:
|
279 |
+
return QwenTokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
|