R1ckShi committed on
Commit
5b4c852
1 Parent(s): d254666
Files changed (50)
  1. README.md +231 -14
  2. app.py +211 -0
  3. cert.pem +32 -0
  4. cosyvoice/__init__.py +0 -0
  5. cosyvoice/bin/average_model.py +92 -0
  6. cosyvoice/bin/convert.py +168 -0
  7. cosyvoice/bin/export_jit.py +74 -0
  8. cosyvoice/bin/export_jit_cosyvoice2.py +60 -0
  9. cosyvoice/bin/export_onnx.py +112 -0
  10. cosyvoice/bin/export_onnx_cosyvoice2.py +110 -0
  11. cosyvoice/bin/export_trt_cosyvoce2.sh +3 -0
  12. cosyvoice/bin/inference.py +115 -0
  13. cosyvoice/bin/train.py +170 -0
  14. cosyvoice/cli/__init__.py +0 -0
  15. cosyvoice/cli/cosyvoice.py +167 -0
  16. cosyvoice/cli/frontend.py +213 -0
  17. cosyvoice/cli/model.py +421 -0
  18. cosyvoice/dataset/__init__.py +0 -0
  19. cosyvoice/dataset/dataset.py +164 -0
  20. cosyvoice/dataset/processor.py +431 -0
  21. cosyvoice/flow/decoder.py +299 -0
  22. cosyvoice/flow/flow.py +232 -0
  23. cosyvoice/flow/flow_matching.py +235 -0
  24. cosyvoice/flow/length_regulator.py +69 -0
  25. cosyvoice/hifigan/discriminator.py +140 -0
  26. cosyvoice/hifigan/f0_predictor.py +55 -0
  27. cosyvoice/hifigan/generator.py +411 -0
  28. cosyvoice/hifigan/hifigan.py +67 -0
  29. cosyvoice/llm/llm.py +340 -0
  30. cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +0 -0
  31. cosyvoice/tokenizer/tokenizer.py +277 -0
  32. cosyvoice/transformer/__init__.py +0 -0
  33. cosyvoice/transformer/activation.py +84 -0
  34. cosyvoice/transformer/attention.py +330 -0
  35. cosyvoice/transformer/convolution.py +145 -0
  36. cosyvoice/transformer/decoder.py +396 -0
  37. cosyvoice/transformer/decoder_layer.py +132 -0
  38. cosyvoice/transformer/embedding.py +294 -0
  39. cosyvoice/transformer/encoder.py +474 -0
  40. cosyvoice/transformer/encoder_layer.py +236 -0
  41. cosyvoice/transformer/label_smoothing_loss.py +96 -0
  42. cosyvoice/transformer/positionwise_feed_forward.py +115 -0
  43. cosyvoice/transformer/subsampling.py +383 -0
  44. cosyvoice/transformer/upsample_encoder.py +322 -0
  45. cosyvoice/utils/__init__.py +0 -0
  46. cosyvoice/utils/class_utils.py +70 -0
  47. cosyvoice/utils/common.py +166 -0
  48. cosyvoice/utils/executor.py +172 -0
  49. cosyvoice/utils/file_utils.py +51 -0
  50. cosyvoice/utils/frontend_utils.py +129 -0
README.md CHANGED
@@ -1,14 +1,231 @@
1
- ---
2
- title: CosyVoice2.0
3
- emoji: 👀
4
- colorFrom: indigo
5
- colorTo: green
6
- sdk: gradio
7
- sdk_version: 5.9.0
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- short_description: Experience the open-source LLM based speech synthesis model!
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ [![SVG Banners](https://svg-banners.vercel.app/api?type=origin&text1=CosyVoice🤠&text2=Text-to-Speech%20💖%20Large%20Language%20Model&width=800&height=210)](https://github.com/Akshay090/svg-banners)
2
+
3
+ ## 👉🏻 CosyVoice 👈🏻
4
+ **CosyVoice 2.0**: [Demos](https://funaudiollm.github.io/cosyvoice2/); [Paper](https://funaudiollm.github.io/pdf/CosyVoice_2.pdf); [Modelscope](https://www.modelscope.cn/studios/iic/CosyVoice2-0.5B)
5
+
6
+ **CosyVoice 1.0**: [Demos](https://fun-audio-llm.github.io); [Paper](https://funaudiollm.github.io/pdf/CosyVoice_v1.pdf); [Modelscope](https://www.modelscope.cn/studios/iic/CosyVoice-300M)
7
+
8
+ ## Highlight🔥
9
+
10
+ **CosyVoice 2.0** has been released! Compared to version 1.0, the new version offers more accurate, more stable, faster, and better speech generation capabilities.
11
+ ### Multilingual
12
+ - **Supported Languages**: Chinese, English, Japanese, Korean, and Chinese dialects (Cantonese, Sichuanese, Shanghainese, Tianjinese, Wuhanese, etc.)
13
+ - **Cross-lingual & Mixed-lingual**: Supports zero-shot voice cloning for cross-lingual and code-switching scenarios.
14
+ ### Ultra-Low Latency
15
+ - **Bidirectional Streaming Support**: CosyVoice 2.0 integrates offline and streaming modeling technologies.
16
+ - **Rapid First Packet Synthesis**: Achieves latency as low as 150ms while maintaining high-quality audio output.
17
+ ### High Accuracy
18
+ - **Improved Pronunciation**: Reduces pronunciation errors by 30% to 50% compared to CosyVoice 1.0.
19
+ - **Benchmark Achievements**: Attains the lowest character error rate on the hard test set of the Seed-TTS evaluation set.
20
+ ### Strong Stability
21
+ - **Consistency in Timbre**: Ensures reliable voice consistency for zero-shot and cross-language speech synthesis.
22
+ - **Cross-language Synthesis**: Marked improvements compared to version 1.0.
23
+ ### Natural Experience
24
+ - **Enhanced Prosody and Sound Quality**: Improved alignment of synthesized audio, raising MOS evaluation scores from 5.4 to 5.53.
25
+ - **Emotional and Dialectal Flexibility**: Now supports more granular emotional controls and accent adjustments.
26
+
27
+ ## Roadmap
28
+
29
+ - [x] 2024/12
30
+
31
+ - [x] 25 Hz CosyVoice 2.0 released
32
+
33
+ - [x] 2024/09
34
+
35
+ - [x] 25 Hz CosyVoice base model
36
+ - [x] 25 Hz CosyVoice voice conversion model
37
+
38
+ - [x] 2024/08
39
+
40
+ - [x] Repetition-Aware Sampling (RAS) inference for LLM stability
41
+ - [x] Streaming inference mode support, including KV cache and SDPA for RTF optimization
42
+
43
+ - [x] 2024/07
44
+
45
+ - [x] Flow matching training support
46
+ - [x] WeTextProcessing support when ttsfrd is not available
47
+ - [x] FastAPI server and client
48
+
49
+
50
+ ## Install
51
+
52
+ **Clone and install**
53
+
54
+ - Clone the repo
55
+ ``` sh
56
+ git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
57
+ # If you failed to clone the submodule due to network failures, please run the following command until it succeeds
58
+ cd CosyVoice
59
+ git submodule update --init --recursive
60
+ ```
61
+
62
+ - Install Conda: please see https://docs.conda.io/en/latest/miniconda.html
63
+ - Create Conda env:
64
+
65
+ ``` sh
66
+ conda create -n cosyvoice python=3.8
67
+ conda activate cosyvoice
68
+ # pynini is required by WeTextProcessing; use conda to install it, as it can be executed on all platforms.
69
+ conda install -y -c conda-forge pynini==2.1.5
70
+ pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
71
+
72
+ # If you encounter sox compatibility issues
73
+ # ubuntu
74
+ sudo apt-get install sox libsox-dev
75
+ # centos
76
+ sudo yum install sox sox-devel
77
+ ```
78
+
79
+ **Model download**
80
+
81
+ We strongly recommend that you download our pretrained `CosyVoice2-0.5B`, `CosyVoice-300M`, `CosyVoice-300M-SFT`, and `CosyVoice-300M-Instruct` models and the `CosyVoice-ttsfrd` resource.
82
+
83
+ If you are an expert in this field and are only interested in training your own CosyVoice model from scratch, you can skip this step.
84
+
85
+ ``` python
86
+ # Model download via the ModelScope SDK
87
+ from modelscope import snapshot_download
88
+ snapshot_download('iic/CosyVoice2-0.5B', local_dir='pretrained_models/CosyVoice2-0.5B')
89
+ snapshot_download('iic/CosyVoice-300M', local_dir='pretrained_models/CosyVoice-300M')
90
+ snapshot_download('iic/CosyVoice-300M-25Hz', local_dir='pretrained_models/CosyVoice-300M-25Hz')
91
+ snapshot_download('iic/CosyVoice-300M-SFT', local_dir='pretrained_models/CosyVoice-300M-SFT')
92
+ snapshot_download('iic/CosyVoice-300M-Instruct', local_dir='pretrained_models/CosyVoice-300M-Instruct')
93
+ snapshot_download('iic/CosyVoice-ttsfrd', local_dir='pretrained_models/CosyVoice-ttsfrd')
94
+ ```
95
+
96
+ ``` sh
97
+ # Model download via git; make sure git lfs is installed
98
+ mkdir -p pretrained_models
99
+ git clone https://www.modelscope.cn/iic/CosyVoice2-0.5B.git pretrained_models/CosyVoice2-0.5B
100
+ git clone https://www.modelscope.cn/iic/CosyVoice-300M.git pretrained_models/CosyVoice-300M
101
+ git clone https://www.modelscope.cn/iic/CosyVoice-300M-25Hz.git pretrained_models/CosyVoice-300M-25Hz
102
+ git clone https://www.modelscope.cn/iic/CosyVoice-300M-SFT.git pretrained_models/CosyVoice-300M-SFT
103
+ git clone https://www.modelscope.cn/iic/CosyVoice-300M-Instruct.git pretrained_models/CosyVoice-300M-Instruct
104
+ git clone https://www.modelscope.cn/iic/CosyVoice-ttsfrd.git pretrained_models/CosyVoice-ttsfrd
105
+ ```
106
+
107
+ Optionally, you can unzip the `ttsfrd` resource and install the `ttsfrd` package for better text normalization performance.
108
+
109
+ Note that this step is not necessary. If you do not install the `ttsfrd` package, WeTextProcessing is used by default.
110
+
111
+ ``` sh
112
+ cd pretrained_models/CosyVoice-ttsfrd/
113
+ unzip resource.zip -d .
114
+ pip install ttsfrd_dependency-0.1-py3-none-any.whl
115
+ pip install ttsfrd-0.4.2-cp310-cp310-linux_x86_64.whl
116
+ ```
117
+
118
+ **Basic Usage**
119
+
120
+ We strongly recommend using `CosyVoice2-0.5B` for better performance.
121
+ For zero_shot/cross_lingual inference, please use the `CosyVoice-300M` model.
122
+ For sft inference, please use the `CosyVoice-300M-SFT` model.
123
+ For instruct inference, please use the `CosyVoice-300M-Instruct` model.
124
+ First, add `third_party/Matcha-TTS` to your `PYTHONPATH`.
125
+
126
+ ``` sh
127
+ export PYTHONPATH=third_party/Matcha-TTS
128
+ ```
129
+
130
+ ``` python
131
+ from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
132
+ from cosyvoice.utils.file_utils import load_wav
133
+ import torchaudio
134
+ ```
135
+
136
+ **CosyVoice2 Usage**
137
+ ```python
138
+ cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=True, load_onnx=False, load_trt=False)
139
+
140
+ # zero_shot usage
141
+ prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)
142
+ for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
143
+ torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
144
+
145
+ # instruct usage
146
+ for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '用四川话说这句话', prompt_speech_16k, stream=False)):
147
+ torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
148
+ ```
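The highlights above advertise bidirectional streaming with low first-packet latency; the snippet below is a minimal sketch of chunked streaming with CosyVoice2, assuming the same zero-shot inputs as the example above (the output file name is illustrative, not part of the repo).

```python
import torch
import torchaudio
from cosyvoice.cli.cosyvoice import CosyVoice2
from cosyvoice.utils.file_utils import load_wav

cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=True, load_onnx=False, load_trt=False)
prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)

# stream=True yields audio chunk by chunk instead of a single final waveform
chunks = []
for chunk in cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=True):
    # each yielded dict carries a [1, T] 'tts_speech' tensor at cosyvoice.sample_rate
    chunks.append(chunk['tts_speech'])
torchaudio.save('zero_shot_streaming.wav', torch.concat(chunks, dim=1), cosyvoice.sample_rate)
```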
149
+
150
+ **CosyVoice Usage**
151
+ ```python
152
+ cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT', load_jit=True, load_onnx=False, fp16=True)
153
+ # sft usage
154
+ print(cosyvoice.list_avaliable_spks())
155
+ # change stream=True for chunk stream inference
156
+ for i, j in enumerate(cosyvoice.inference_sft('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '中文女', stream=False)):
157
+ torchaudio.save('sft_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
158
+
159
+ cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-25Hz') # or change to pretrained_models/CosyVoice-300M for 50Hz inference
160
+ # zero_shot usage, <|zh|><|en|><|jp|><|yue|><|ko|> for Chinese/English/Japanese/Cantonese/Korean
161
+ prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)
162
+ for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
163
+ torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
164
+ # cross_lingual usage
165
+ prompt_speech_16k = load_wav('cross_lingual_prompt.wav', 16000)
166
+ for i, j in enumerate(cosyvoice.inference_cross_lingual('<|en|>And then later on, fully acquiring that company. So keeping management in line, interest in line with the asset that\'s coming into the family is a reason why sometimes we don\'t buy the whole thing.', prompt_speech_16k, stream=False)):
167
+ torchaudio.save('cross_lingual_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
168
+ # vc usage
169
+ prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)
170
+ source_speech_16k = load_wav('cross_lingual_prompt.wav', 16000)
171
+ for i, j in enumerate(cosyvoice.inference_vc(source_speech_16k, prompt_speech_16k, stream=False)):
172
+ torchaudio.save('vc_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
173
+
174
+ cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-Instruct')
175
+ # instruct usage, support <laughter></laughter><strong></strong>[laughter][breath]
176
+ for i, j in enumerate(cosyvoice.inference_instruct('在面对挑战时,他展现了非凡的<strong>勇气</strong>与<strong>智慧</strong>。', '中文男', 'Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.', stream=False)):
177
+ torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
178
+ ```
179
+
180
+ **Start web demo**
181
+
182
+ You can use our web demo page to get familiar with CosyVoice quickly.
183
+ We support sft/zero_shot/cross_lingual/instruct inference in the web demo.
184
+
185
+ Please see the demo website for details.
186
+
187
+ ``` sh
188
+ # change iic/CosyVoice-300M-SFT for sft inference, or iic/CosyVoice-300M-Instruct for instruct inference
189
+ python3 webui.py --port 50000 --model_dir pretrained_models/CosyVoice-300M
190
+ ```
191
+
192
+ **Advanced Usage**
193
+
194
+ For advanced users, we provide training and inference scripts in `examples/libritts/cosyvoice/run.sh`.
195
+ You can get familiar with CosyVoice by following this recipe.
196
+
197
+ **Build for deployment**
198
+
199
+ Optionally, if you want to use gRPC or FastAPI for service deployment,
200
+ you can run the following steps. Otherwise, you can skip this section.
201
+
202
+ ``` sh
203
+ cd runtime/python
204
+ docker build -t cosyvoice:v1.0 .
205
+ # change iic/CosyVoice-300M to iic/CosyVoice-300M-Instruct if you want to use instruct inference
206
+ # for grpc usage
207
+ docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/grpc && python3 server.py --port 50000 --max_conc 4 --model_dir iic/CosyVoice-300M && sleep infinity"
208
+ cd grpc && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
209
+ # for fastapi usage
210
+ docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/fastapi && python3 server.py --port 50000 --model_dir iic/CosyVoice-300M && sleep infinity"
211
+ cd fastapi && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
212
+ ```
213
+
214
+ ## Discussion & Communication
215
+
216
+ You can directly discuss on [Github Issues](https://github.com/FunAudioLLM/CosyVoice/issues).
217
+
218
+ You can also scan the QR code to join our official DingTalk chat group.
219
+
220
+ <img src="./asset/dingding.png" width="250px">
221
+
222
+ ## Acknowledgements
223
+
224
+ 1. We borrowed a lot of code from [FunASR](https://github.com/modelscope/FunASR).
225
+ 2. We borrowed a lot of code from [FunCodec](https://github.com/modelscope/FunCodec).
226
+ 3. We borrowed a lot of code from [Matcha-TTS](https://github.com/shivammehta25/Matcha-TTS).
227
+ 4. We borrowed a lot of code from [AcademiCodec](https://github.com/yangdongchao/AcademiCodec).
228
+ 5. We borrowed a lot of code from [WeNet](https://github.com/wenet-e2e/wenet).
229
+
230
+ ## Disclaimer
231
+ The content provided above is for academic purposes only and is intended to demonstrate technical capabilities. Some examples are sourced from the internet. If any content infringes on your rights, please contact us to request its removal.
app.py ADDED
@@ -0,0 +1,211 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Liu Yue)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import sys
16
+ import argparse
17
+ import gradio as gr
18
+ import numpy as np
19
+ import torch
20
+ import torchaudio
21
+ import random
22
+ import librosa
23
+ from funasr import AutoModel
24
+ from funasr.utils.postprocess_utils import rich_transcription_postprocess
25
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
26
+ sys.path.append('{}/third_party/Matcha-TTS'.format(ROOT_DIR))
27
+
28
+ from modelscope import snapshot_download
29
+ snapshot_download('iic/CosyVoice2-0.5B', local_dir='pretrained_models/CosyVoice2-0.5B')
30
+ snapshot_download('iic/CosyVoice-ttsfrd', local_dir='pretrained_models/CosyVoice-ttsfrd')
31
+ os.system('cd pretrained_models/CosyVoice-ttsfrd/ && pip install ttsfrd_dependency-0.1-py3-none-any.whl && pip install ttsfrd-0.4.2-cp310-cp310-linux_x86_64.whl && apt install -y unzip && unzip resource.zip -d .')
32
+
33
+ from cosyvoice.cli.cosyvoice import CosyVoice2
34
+ from cosyvoice.utils.file_utils import load_wav, logging
35
+ from cosyvoice.utils.common import set_all_random_seed
36
+
37
+ inference_mode_list = ['3s极速复刻', '自然语言控制']
38
+ instruct_dict = {'3s极速复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 输入prompt文本\n3. 点击生成音频按钮',
39
+ '自然语言控制': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 输入instruct文本\n3. 点击生成音频按钮'}
40
+ stream_mode_list = [('否', False), ('是', True)]
41
+ max_val = 0.8
42
+
43
+
44
+ def generate_seed():
45
+ seed = random.randint(1, 100000000)
46
+ return {
47
+ "__type__": "update",
48
+ "value": seed
49
+ }
50
+
51
+
52
+ def postprocess(speech, top_db=60, hop_length=220, win_length=440):
53
+ speech, _ = librosa.effects.trim(
54
+ speech, top_db=top_db,
55
+ frame_length=win_length,
56
+ hop_length=hop_length
57
+ )
58
+ if speech.abs().max() > max_val:
59
+ speech = speech / speech.abs().max() * max_val
60
+ speech = torch.concat([speech, torch.zeros(1, int(target_sr * 0.2))], dim=1)
61
+ return speech
62
+
63
+
64
+ def change_instruction(mode_checkbox_group):
65
+ return instruct_dict[mode_checkbox_group]
66
+
67
+ def prompt_wav_recognition(prompt_wav):
68
+ res = asr_model.generate(input=prompt_wav,
69
+ language="auto", # "zn", "en", "yue", "ja", "ko", "nospeech"
70
+ use_itn=True,
71
+ )
72
+ text = res[0]["text"].split('|>')[-1]
73
+ return text
74
+
75
+ def generate_audio(tts_text, mode_checkbox_group, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text,
76
+ seed, stream):
77
+ sft_dropdown, speed = '', 1.0
78
+ if prompt_wav_upload is not None:
79
+ prompt_wav = prompt_wav_upload
80
+ elif prompt_wav_record is not None:
81
+ prompt_wav = prompt_wav_record
82
+ else:
83
+ prompt_wav = None
84
+ # instruct mode (自然语言控制) requires both instruct text and a prompt audio
85
+ if mode_checkbox_group in ['自然语言控制']:
86
+ if instruct_text == '':
87
+ gr.Warning('您正在使用自然语言控制模式, 请输入instruct文本')
88
+ yield (target_sr, default_data)
89
+ if prompt_wav is None:
90
+ gr.Info('您正在使用自然语言控制模式, 请输入prompt音频')
91
+ # if cross_lingual mode, please make sure that the model is iic/CosyVoice-300M and that tts_text and prompt_text are in different languages
92
+ if mode_checkbox_group in ['跨语种复刻']:
93
+ if cosyvoice.frontend.instruct is True:
94
+ gr.Warning('您正在使用跨语种复刻模式, {}模型不支持此模式, 请使用iic/CosyVoice-300M模型'.format(args.model_dir))
95
+ yield (target_sr, default_data)
96
+ if instruct_text != '':
97
+ gr.Info('您正在使用跨语种复刻模式, instruct文本会被忽略')
98
+ if prompt_wav is None:
99
+ gr.Warning('您正在使用跨语种复刻模式, 请提供prompt音频')
100
+ yield (target_sr, default_data)
101
+ gr.Info('您正在使用跨语种复刻模式, 请确保合成文本和prompt文本为不同语言')
102
+ # in zero_shot/cross_lingual mode, please make sure that prompt_text and prompt_wav meet the requirements
103
+ if mode_checkbox_group in ['3s极速复刻', '跨语种复刻']:
104
+ if prompt_wav is None:
105
+ gr.Warning('prompt音频为空,您是否忘记输入prompt音频?')
106
+ yield (target_sr, default_data)
107
+ if torchaudio.info(prompt_wav).sample_rate < prompt_sr:
108
+ gr.Warning('prompt音频采样率{}低于{}'.format(torchaudio.info(prompt_wav).sample_rate, prompt_sr))
109
+ yield (target_sr, default_data)
110
+ # sft mode only uses sft_dropdown
111
+ if mode_checkbox_group in ['预训练音色']:
112
+ if instruct_text != '' or prompt_wav is not None or prompt_text != '':
113
+ gr.Info('您正在使用预训练音色模式,prompt文本/prompt音频/instruct文本会被忽略!')
114
+ # zero_shot mode only uses prompt_wav and prompt_text
115
+ if mode_checkbox_group in ['3s极速复刻']:
116
+ if prompt_text == '':
117
+ gr.Warning('prompt文本为空,您是否忘记输入prompt文本?')
118
+ yield (target_sr, default_data)
119
+ if instruct_text != '':
120
+ gr.Info('您正在使用3s极速复刻模式,预训练音色/instruct文本会被忽略!')
121
+ info = torchaudio.info(prompt_wav)
122
+ if info.num_frames / info.sample_rate > 10:
123
+ gr.Warning('请限制输入音频在10s内,避免推理效果过低')
124
+ yield (target_sr, default_data)
125
+
126
+ if mode_checkbox_group == '预训练音色':
127
+ logging.info('get sft inference request')
128
+ set_all_random_seed(seed)
129
+ for i in cosyvoice.inference_sft(tts_text, sft_dropdown, stream=stream, speed=speed):
130
+ yield (target_sr, i['tts_speech'].numpy().flatten())
131
+ elif mode_checkbox_group == '3s极速复刻':
132
+ logging.info('get zero_shot inference request')
133
+ prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
134
+ set_all_random_seed(seed)
135
+ for i in cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k, stream=stream, speed=speed):
136
+ yield (target_sr, i['tts_speech'].numpy().flatten())
137
+ elif mode_checkbox_group == '跨语种复刻':
138
+ logging.info('get cross_lingual inference request')
139
+ prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
140
+ set_all_random_seed(seed)
141
+ for i in cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k, stream=stream, speed=speed):
142
+ yield (target_sr, i['tts_speech'].numpy().flatten())
143
+ else:
144
+ logging.info('get instruct inference request')
146
+ prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
147
+ set_all_random_seed(seed)
148
+ for i in cosyvoice.inference_instruct2(tts_text, instruct_text, prompt_speech_16k, stream=stream, speed=speed):
149
+ yield (target_sr, i['tts_speech'].numpy().flatten())
150
+
151
+
152
+ def main():
153
+ with gr.Blocks() as demo:
154
+ gr.Markdown("### 代码库 [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) \
155
+ 预训练模型 [CosyVoice2-0.5B](https://www.modelscope.cn/models/iic/CosyVoice2-0.5B) \
156
+ [CosyVoice-300M](https://www.modelscope.cn/models/iic/CosyVoice-300M) \
157
+ [CosyVoice-300M-Instruct](https://www.modelscope.cn/models/iic/CosyVoice-300M-Instruct) \
158
+ [CosyVoice-300M-SFT](https://www.modelscope.cn/models/iic/CosyVoice-300M-SFT)")
159
+ gr.Markdown("#### 请输入需要合成的文本,选择推理模式,并按照提示步骤进行操作")
160
+
161
+ tts_text = gr.Textbox(label="输入合成文本", lines=1, value="CosyVoice迎来全面升级,提供更准、更稳、更快、 更好的语音生成能力。CosyVoice is undergoing a comprehensive upgrade, providing more accurate, stable, faster, and better voice generation capabilities.")
162
+ with gr.Row():
163
+ mode_checkbox_group = gr.Radio(choices=inference_mode_list, label='选择推理模式', value=inference_mode_list[0])
164
+ instruction_text = gr.Text(label="操作步骤", value=instruct_dict[inference_mode_list[0]], scale=0.5)
165
+ stream = gr.Radio(choices=stream_mode_list, label='是否流式推理', value=stream_mode_list[0][1])
166
+ with gr.Column(scale=0.25):
167
+ seed_button = gr.Button(value="\U0001F3B2")
168
+ seed = gr.Number(value=0, label="随机推理种子")
169
+
170
+ with gr.Row():
171
+ prompt_wav_upload = gr.Audio(sources='upload', type='filepath', label='选择prompt音频文件,注意采样率不低于16khz')
172
+ prompt_wav_record = gr.Audio(sources='microphone', type='filepath', label='录制prompt音频文件')
173
+ prompt_text = gr.Textbox(label="prompt文本", lines=1, placeholder="请输入prompt文本,支持自动识别,您可以自行修正识别结果...", value='')
174
+ instruct_text = gr.Textbox(label="输入instruct文本", lines=1, placeholder="请输入instruct文本.例如:用四川话说这句话。", value='')
175
+
176
+ generate_button = gr.Button("生成音频")
177
+
178
+ audio_output = gr.Audio(label="合成音频", autoplay=True, streaming=True)
179
+
180
+ seed_button.click(generate_seed, inputs=[], outputs=seed)
181
+ generate_button.click(generate_audio,
182
+ inputs=[tts_text, mode_checkbox_group, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text,
183
+ seed, stream],
184
+ outputs=[audio_output])
185
+ mode_checkbox_group.change(fn=change_instruction, inputs=[mode_checkbox_group], outputs=[instruction_text])
186
+ prompt_wav_upload.change(fn=prompt_wav_recognition, inputs=[prompt_wav_upload], outputs=[prompt_text])
187
+ prompt_wav_record.change(fn=prompt_wav_recognition, inputs=[prompt_wav_record], outputs=[prompt_text])
188
+ demo.queue(max_size=4, default_concurrency_limit=2).launch(server_port=50000)
189
+
190
+
191
+ if __name__ == '__main__':
192
+ load_jit = True if os.environ.get('jit') == '1' else False
193
+ load_onnx = True if os.environ.get('onnx') == '1' else False
194
+ load_trt = True if os.environ.get('trt') == '1' else False
195
+ logging.info('cosyvoice args load_jit {} load_onnx {} load_trt {}'.format(load_jit, load_onnx, load_trt))
196
+ cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=load_jit, load_onnx=load_onnx, load_trt=load_trt)
197
+ sft_spk = cosyvoice.list_avaliable_spks()
198
+ prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)
199
+ for stream in [True, False]:
200
+ for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=stream)):
201
+ continue
202
+ prompt_sr, target_sr = 16000, 24000
203
+ default_data = np.zeros(target_sr)
204
+
205
+ model_dir = "iic/SenseVoiceSmall"
206
+ asr_model = AutoModel(
207
+ model=model_dir,
208
+ disable_update=True,
209
+ log_level='DEBUG',
210
+ device="cuda:0")
211
+ main()
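A note on the entry point above: the optional acceleration paths in `app.py` are toggled purely through environment variables (`jit`, `onnx`, `trt`) read in `__main__`. A hypothetical local launch enabling the JIT and TensorRT paths would look like the following; this invocation is an illustration, not part of the commit.

``` sh
# enable TorchScript JIT and TensorRT acceleration when launching the demo locally
jit=1 trt=1 python3 app.py
```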
cert.pem ADDED
@@ -0,0 +1,32 @@
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFkTCCA3mgAwIBAgIUEO2zq0OQeuRFIFH4lfHLgcR5hTUwDQYJKoZIhvcNAQEL
3
+ BQAwWDELMAkGA1UEBhMCQ04xCzAJBgNVBAgMAlpKMQswCQYDVQQHDAJIWjEQMA4G
4
+ A1UECgwHY29tcGFueTEMMAoGA1UECwwDYWxpMQ8wDQYDVQQDDAZncmFkaW8wHhcN
5
+ MjQwNDI5MTIxNDQxWhcNMjUwNDI5MTIxNDQxWjBYMQswCQYDVQQGEwJDTjELMAkG
6
+ A1UECAwCWkoxCzAJBgNVBAcMAkhaMRAwDgYDVQQKDAdjb21wYW55MQwwCgYDVQQL
7
+ DANhbGkxDzANBgNVBAMMBmdyYWRpbzCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCC
8
+ AgoCggIBAKohzP3V7VdDyMgfRO4+xzh/mWFPapQWJIIrhnHj8GRJ9tgFVXf71vcU
9
+ PMo+/t+y0rjupw3WwWIj6kJP15t46xxmLzoJHZKHV7d1Y7XJTyN1hvRCzeGz6w/E
10
+ VX0y6U+0y9m1HG0kvsfLwCKZPxEN21RfPukGN3qOIpjaRvE6fxg8DCUQN8qEpjQ9
11
+ DQehq/g0B/wZFwIB2089+BeqesjaOinY2+z4YiMreIj2dy8XM6G59quS21oe0u5n
12
+ 6SW80ayf/yA6CHqblCHNfdi3vrzxMalNjT5EHKxQsLEDd2nWSndoPeXClXdSoIpE
13
+ 1+H86dWHZpzPLd6rOfa+FCZ3TQsZbL+p3ree2AIMIB7zWw59oKGE8UuZbtyCVWK6
14
+ hufMOs703ZT97WeBEoOA72itUwCBqAakYNoULvYSOuXZT0LvJN1Z4YLNTkJXDA0u
15
+ vMABPbRFXfFK67F/fLm/vges4dhhpQNeSxSuXEC7rMA5hCQRk3BccdEgxoBfNZcM
16
+ HKo8CaB3wxbK7inXZb3JD4sFK64H5VjfJE8ibFzoIhiPICuC+0bzSKfc0+dcUNMb
17
+ KsE5M3etmS1TcPKuebk9OTu8YUJiNMYgEInw7vCq004v4IOqQr0aX/LGRm21RB/i
18
+ M3qFKCSHSw5/Z+o9sZ/kw3AeNnx5r5dq4OAswx3RhScPJtd6qesZAgMBAAGjUzBR
19
+ MB0GA1UdDgQWBBSNZx2v1BNAGL4gGM4TUXIvn1OyFTAfBgNVHSMEGDAWgBSNZx2v
20
+ 1BNAGL4gGM4TUXIvn1OyFTAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUA
21
+ A4ICAQA/khg91VtI/tDLCLyQ6ZMulfOzJHuGmIs4cvG5fIOvzYjQpvAGSgNeivKp
22
+ +5RIkpUabcwdUCq6VXeXheo+SaGgVdwpxQy/p/E+i+AengRB5Qm/hJ5lLU6CdNBq
23
+ WCN/0Aa1GL/pM4HAzVQY81HeB46UaHWtW6J9hnBbVg2MF2GanAqfeODpZqIHEggt
24
+ Vw2ivElV47JTFZsNU+JYG5ECsfTjNQYpoA6Hyb/d5ZW8YsfOjr8oIBM4QyZWq1Ke
25
+ eAlytVwl9lj4AkAQIAgkrJHkLjj5yjZ7Hir5NjBuBx06FDAIFb2XWgNnq4ua/pSq
26
+ 9fL4cxx4cEJku1X/FYtUBbWsXe8uFGwTEGHuEZR3pj5VSFbuNlARLIsq8/gh8MRQ
27
+ NjKQIlTVINkuOFuVmSrLC5nIwTPhlpEFwIQPGzFD2DbVNor9EXQ2b89WtHqZAZik
28
+ qFDb76JM9jctf9n8l96oSKrwEaCoFmRojnyyYl9UByJxPRCeTJ//i2vxeTvLC3FT
29
+ Rw2jFi/pwoqSVmJtuAFLT96/x2qKpgk+M1zG3oFiDV1lxY8sw1RA3Mm4s3Cm8H5A
30
+ 3E+6R34XZLifqhxLVcyDsRWPcqte3Pt6v/xXWN+EuOigK4tr69p8aU7WR5mskmzO
31
+ tZFeEb0OxL1WjF/rmwCkd/SvSuWSiszMoX5hcOA7/GGw3pl3YQ==
32
+ -----END CERTIFICATE-----
cosyvoice/__init__.py ADDED
File without changes
cosyvoice/bin/average_model.py ADDED
@@ -0,0 +1,92 @@
1
+ # Copyright (c) 2020 Mobvoi Inc (Di Wu)
2
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import argparse
18
+ import glob
19
+
20
+ import yaml
21
+ import torch
22
+
23
+
24
+ def get_args():
25
+ parser = argparse.ArgumentParser(description='average model')
26
+ parser.add_argument('--dst_model', required=True, help='averaged model')
27
+ parser.add_argument('--src_path',
28
+ required=True,
29
+ help='src model path for average')
30
+ parser.add_argument('--val_best',
31
+ action="store_true",
32
+ help='averaged model')
33
+ parser.add_argument('--num',
34
+ default=5,
35
+ type=int,
36
+ help='nums for averaged model')
37
+
38
+ args = parser.parse_args()
39
+ print(args)
40
+ return args
41
+
42
+
43
+ def main():
44
+ args = get_args()
45
+ val_scores = []
46
+ if args.val_best:
47
+ yamls = glob.glob('{}/*.yaml'.format(args.src_path))
48
+ yamls = [
49
+ f for f in yamls
50
+ if not (os.path.basename(f).startswith('train')
51
+ or os.path.basename(f).startswith('init'))
52
+ ]
53
+ for y in yamls:
54
+ with open(y, 'r') as f:
55
+ dic_yaml = yaml.load(f, Loader=yaml.BaseLoader)
56
+ loss = float(dic_yaml['loss_dict']['loss'])
57
+ epoch = int(dic_yaml['epoch'])
58
+ step = int(dic_yaml['step'])
59
+ tag = dic_yaml['tag']
60
+ val_scores += [[epoch, step, loss, tag]]
61
+ sorted_val_scores = sorted(val_scores,
62
+ key=lambda x: x[2],
63
+ reverse=False)
64
+ print("best val (epoch, step, loss, tag) = " +
65
+ str(sorted_val_scores[:args.num]))
66
+ path_list = [
67
+ args.src_path + '/epoch_{}_whole.pt'.format(score[0])
68
+ for score in sorted_val_scores[:args.num]
69
+ ]
70
+ print(path_list)
71
+ avg = {}
72
+ num = args.num
73
+ assert num == len(path_list)
74
+ for path in path_list:
75
+ print('Processing {}'.format(path))
76
+ states = torch.load(path, map_location=torch.device('cpu'))
77
+ for k in states.keys():
78
+ if k not in avg.keys():
79
+ avg[k] = states[k].clone()
80
+ else:
81
+ avg[k] += states[k]
82
+ # average
83
+ for k in avg.keys():
84
+ if avg[k] is not None:
85
+ # pytorch 1.6 use true_divide instead of /=
86
+ avg[k] = torch.true_divide(avg[k], num)
87
+ print('Saving to {}'.format(args.dst_model))
88
+ torch.save(avg, args.dst_model)
89
+
90
+
91
+ if __name__ == '__main__':
92
+ main()
cosyvoice/bin/convert.py ADDED
@@ -0,0 +1,168 @@
1
+ import sys
2
+ import torch
3
+
4
+ def convert_llm(state_dict):
5
+ # The LM structure was reorganized: codec_lm.encoder becomes llm, and codec_lm.decoder becomes llm_decoder
6
+ keys = list(state_dict.keys())
7
+ for k in keys:
8
+ if k.startswith('codec_lm.encoder.'):
9
+ v = state_dict.pop(k)
10
+ k = k.replace('codec_lm.encoder.', 'llm.')
11
+ state_dict[k] = v
12
+ if k.startswith('codec_lm.decoder.'):
13
+ v = state_dict.pop(k)
14
+ k = k.replace('codec_lm.decoder.', 'llm_decoder.')
15
+ state_dict[k] = v
16
+ # Differences between the ESPnet and WeNet implementations
17
+ keys = list(state_dict.keys())
18
+ for k in keys:
19
+ if k.startswith('text_encoder.embed.'):
20
+ v = state_dict.pop(k)
21
+ k = k.replace('text_encoder.embed.', 'text_encoder.embed.out.')
22
+ state_dict[k] = v
23
+ if k.startswith('llm.embed.'):
24
+ v = state_dict.pop(k)
25
+ k = k.replace('llm.embed.', 'llm.embed.out.')
26
+ state_dict[k] = v
27
+ keys = list(state_dict.keys())
28
+ for k in keys:
29
+ if k.startswith('text_enc_out_layer.'):
30
+ v = state_dict.pop(k)
31
+ k = k.replace('text_enc_out_layer.', 'text_encoder_affine_layer.')
32
+ state_dict[k] = v
33
+ if k.startswith('token_embedding.'):
34
+ v = state_dict.pop(k)
35
+ k = k.replace('token_embedding.', 'text_embedding.')
36
+ state_dict[k] = v
37
+ if k.startswith('xvec_proj.'):
38
+ v = state_dict.pop(k)
39
+ k = k.replace('xvec_proj.', 'spk_embed_affine_layer.')
40
+ state_dict[k] = v
41
+ if k.startswith('lm_embedding.'):
42
+ v = state_dict.pop(k)
43
+ k = k.replace('lm_embedding.', 'llm_embedding.')
44
+ state_dict[k] = v
45
+ if k.startswith('codec_embedder.'):
46
+ v = state_dict.pop(k)
47
+ k = k.replace('codec_embedder.', 'speech_embedding.')
48
+ state_dict[k] = v
49
+ # The instruct model lacks the spk embedding parameters, so add all-zero tensors
50
+ keys = list(state_dict.keys())
51
+ if 'spk_embed_affine_layer.weight' not in keys:
52
+ print('no spk_embed_affine_layer.weight, should be instruct model')
53
+ state_dict['spk_embed_affine_layer.weight'] = torch.zeros(1024, 192)
54
+ if 'spk_embed_affine_layer.bias' not in keys:
55
+ print('no spk_embed_affine_layer.bias, should be instruct model')
56
+ state_dict['spk_embed_affine_layer.bias'] = torch.zeros(1024)
57
+ return state_dict
58
+
59
+ def convert_hift(state_dict):
60
+ # The HiFi-GAN structure in CosyVoice was reorganized: f0_predictor is moved into the generator
61
+ keys = list(state_dict.keys())
62
+ for k in keys:
63
+ if k.startswith('decoder.'):
64
+ v = state_dict.pop(k)
65
+ k = k.replace('decoder.', '')
66
+ state_dict[k] = v
67
+ if k.startswith('generator.'):
68
+ v = state_dict.pop(k)
69
+ k = k.replace('generator.', '')
70
+ state_dict[k] = v
71
+ return state_dict
72
+
73
+ def convert_flow(state_dict):
74
+ keys = list(state_dict.keys())
75
+ for k in keys:
76
+ if k.startswith('encoder.embed.'):
77
+ v = state_dict.pop(k)
78
+ k = k.replace('encoder.embed.', 'encoder.embed.out.')
79
+ state_dict[k] = v
80
+ for k in keys:
81
+ if k.startswith('xvec_proj.'):
82
+ v = state_dict.pop(k)
83
+ k = k.replace('xvec_proj.', 'spk_embed_affine_layer.')
84
+ state_dict[k] = v
85
+ return state_dict
86
+
87
+ def convert_llm2(state_dict):
88
+ # The LM structure was reorganized: codec_lm.encoder becomes llm, and codec_lm.decoder becomes llm_decoder
89
+ keys = list(state_dict.keys())
90
+ for k in keys:
91
+ if k.startswith('codec_lm.encoder.'):
92
+ v = state_dict.pop(k)
93
+ k = k.replace('codec_lm.encoder.', 'llm.')
94
+ state_dict[k] = v
95
+ if k.startswith('codec_lm.decoder.'):
96
+ v = state_dict.pop(k)
97
+ k = k.replace('codec_lm.decoder.', 'llm_decoder.')
98
+ state_dict[k] = v
99
+ if k.startswith('lm_embedding.'):
100
+ v = state_dict.pop(k)
101
+ k = k.replace('lm_embedding.', 'llm_embedding.')
102
+ state_dict[k] = v
103
+ if k.startswith('codec_embedder.'):
104
+ v = state_dict.pop(k)
105
+ k = k.replace('codec_embedder.', 'speech_embedding.')
106
+ state_dict[k] = v
107
+ if k.startswith('text_enc_out_layer.'):
108
+ state_dict.pop(k)
109
+ if k.startswith('token_embedding.weight'):
110
+ state_dict.pop(k)
111
+ return state_dict
112
+
113
+ def convert_flow2(state_dict):
114
+ keys = list(state_dict.keys())
115
+ for k in keys:
116
+ if k.startswith('encoder.embed.'):
117
+ v = state_dict.pop(k)
118
+ k = k.replace('encoder.embed.', 'encoder.embed.out.')
119
+ state_dict[k] = v
120
+ for k in keys:
121
+ if k.startswith('xvec_proj.'):
122
+ v = state_dict.pop(k)
123
+ k = k.replace('xvec_proj.', 'spk_embed_affine_layer.')
124
+ state_dict[k] = v
125
+ for k in keys:
126
+ if k.startswith('mel_extractor.'):
127
+ state_dict.pop(k)
128
+ for k in keys:
129
+ if k.startswith('encoder.upsample_blocks.0.0.'):
130
+ v = state_dict.pop(k)
131
+ k = k.replace('encoder.upsample_blocks.0.0.', 'encoder.up_layer.')
132
+ state_dict[k] = v
133
+ if k.startswith('encoder.upsample_blocks.0.1.'):
134
+ v = state_dict.pop(k)
135
+ k = k.replace('encoder.upsample_blocks.0.1.', 'encoder.up_embed.out.')
136
+ state_dict[k] = v
137
+ if k.startswith('encoder.upsample_blocks.0.2.'):
138
+ v = state_dict.pop(k)
139
+ k = k.replace('encoder.upsample_blocks.0.2.', 'encoder.up_encoders.')
140
+ state_dict[k] = v
141
+ # In CausalBlock1D, the Sequential index changes from 1 to 2
142
+ if k.startswith('decoder.estimator.') and k.endswith('block.1.weight'):
143
+ v = state_dict.pop(k)
144
+ k = k.replace('block.1.weight', 'block.2.weight')
145
+ state_dict[k] = v
146
+ if k.startswith('decoder.estimator.') and k.endswith('block.1.bias'):
147
+ v = state_dict.pop(k)
148
+ k = k.replace('block.1.bias', 'block.2.bias')
149
+ state_dict[k] = v
150
+ return state_dict
151
+
152
+ if __name__ == '__main__':
153
+ # Usage: python3 convert.py <original-format llm.pt> llm normalize <new-format llm.pt>
154
+ # or: python3 convert.py <new-format llm.pt> llm inverse_normalize <original-format llm.pt>
155
+ state_dict = torch.load(sys.argv[1], map_location='cpu')
156
+ if sys.argv[2] == 'llm':
157
+ state_dict = convert_llm(state_dict)
158
+ elif sys.argv[2] == 'flow':
159
+ state_dict = convert_flow(state_dict)
160
+ elif sys.argv[2] == 'hift':
161
+ state_dict = convert_hift(state_dict)
162
+ elif sys.argv[2] == 'llm2':
163
+ state_dict = convert_llm2(state_dict)
164
+ elif sys.argv[2] == 'flow2':
165
+ state_dict = convert_flow2(state_dict)
166
+ else:
167
+ raise ValueError
168
+ torch.save(state_dict, sys.argv[4])
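For reference, the usage comment in `convert.py`'s `__main__` corresponds to invocations of roughly this shape; the checkpoint file names below are placeholders, and note that the direction argument (`normalize`) is not actually inspected by the script, which always applies the forward conversion selected by the second argument.

``` sh
# convert original-layout checkpoints into the layout expected by this repo
python3 cosyvoice/bin/convert.py original_llm.pt llm normalize converted_llm.pt
python3 cosyvoice/bin/convert.py original_flow.pt flow normalize converted_flow.pt
python3 cosyvoice/bin/convert.py original_hift.pt hift normalize converted_hift.pt
```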
cosyvoice/bin/export_jit.py ADDED
@@ -0,0 +1,74 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import print_function
16
+
17
+ import argparse
18
+ import logging
19
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
20
+ import os
21
+ import sys
22
+ import torch
23
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
24
+ sys.path.append('{}/../..'.format(ROOT_DIR))
25
+ sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
26
+ from cosyvoice.cli.cosyvoice import CosyVoice
27
+
28
+
29
+ def get_args():
30
+ parser = argparse.ArgumentParser(description='export your model for deployment')
31
+ parser.add_argument('--model_dir',
32
+ type=str,
33
+ default='pretrained_models/CosyVoice-300M',
34
+ help='local path')
35
+ args = parser.parse_args()
36
+ print(args)
37
+ return args
38
+
39
+
40
+ def main():
41
+ args = get_args()
42
+ logging.basicConfig(level=logging.DEBUG,
43
+ format='%(asctime)s %(levelname)s %(message)s')
44
+
45
+ torch._C._jit_set_fusion_strategy([('STATIC', 1)])
46
+ torch._C._jit_set_profiling_mode(False)
47
+ torch._C._jit_set_profiling_executor(False)
48
+
49
+ cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_onnx=False)
50
+
51
+ # 1. export llm text_encoder
52
+ llm_text_encoder = cosyvoice.model.llm.text_encoder.half()
53
+ script = torch.jit.script(llm_text_encoder)
54
+ script = torch.jit.freeze(script)
55
+ script = torch.jit.optimize_for_inference(script)
56
+ script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir))
57
+
58
+ # 2. export llm llm
59
+ llm_llm = cosyvoice.model.llm.llm.half()
60
+ script = torch.jit.script(llm_llm)
61
+ script = torch.jit.freeze(script, preserved_attrs=['forward_chunk'])
62
+ script = torch.jit.optimize_for_inference(script)
63
+ script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
64
+
65
+ # 3. export flow encoder
66
+ flow_encoder = cosyvoice.model.flow.encoder
67
+ script = torch.jit.script(flow_encoder)
68
+ script = torch.jit.freeze(script)
69
+ script = torch.jit.optimize_for_inference(script)
70
+ script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
71
+
72
+
73
+ if __name__ == '__main__':
74
+ main()
cosyvoice/bin/export_jit_cosyvoice2.py ADDED
@@ -0,0 +1,60 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import print_function
16
+
17
+ import argparse
18
+ import logging
19
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
20
+ import os
21
+ import sys
22
+ import torch
23
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
24
+ sys.path.append('{}/../..'.format(ROOT_DIR))
25
+ sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
26
+ from cosyvoice.cli.cosyvoice import CosyVoice2
27
+
28
+
29
+ def get_args():
30
+ parser = argparse.ArgumentParser(description='export your model for deployment')
31
+ parser.add_argument('--model_dir',
32
+ type=str,
33
+ default='pretrained_models/CosyVoice-300M',
34
+ help='local path')
35
+ args = parser.parse_args()
36
+ print(args)
37
+ return args
38
+
39
+
40
+ def main():
41
+ args = get_args()
42
+ logging.basicConfig(level=logging.DEBUG,
43
+ format='%(asctime)s %(levelname)s %(message)s')
44
+
45
+ torch._C._jit_set_fusion_strategy([('STATIC', 1)])
46
+ torch._C._jit_set_profiling_mode(False)
47
+ torch._C._jit_set_profiling_executor(False)
48
+
49
+ cosyvoice = CosyVoice2(args.model_dir, load_jit=False, load_onnx=False)
50
+
51
+ # 3. export flow encoder
52
+ flow_encoder = cosyvoice.model.flow.encoder
53
+ script = torch.jit.script(flow_encoder)
54
+ script = torch.jit.freeze(script)
55
+ script = torch.jit.optimize_for_inference(script)
56
+ script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
57
+
58
+
59
+ if __name__ == '__main__':
60
+ main()
cosyvoice/bin/export_onnx.py ADDED
@@ -0,0 +1,112 @@
1
+ # Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, [email protected])
2
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import print_function
17
+
18
+ import argparse
19
+ import logging
20
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
21
+ import os
22
+ import sys
23
+ import onnxruntime
24
+ import random
25
+ import torch
26
+ from tqdm import tqdm
27
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
28
+ sys.path.append('{}/../..'.format(ROOT_DIR))
29
+ sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
30
+ from cosyvoice.cli.cosyvoice import CosyVoice
31
+
32
+
33
+ def get_dummy_input(batch_size, seq_len, out_channels, device):
34
+ x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
35
+ mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device)
36
+ mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
37
+ t = torch.rand((batch_size), dtype=torch.float32, device=device)
38
+ spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device)
39
+ cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
40
+ return x, mask, mu, t, spks, cond
41
+
42
+
43
+ def get_args():
44
+ parser = argparse.ArgumentParser(description='export your model for deployment')
45
+ parser.add_argument('--model_dir',
46
+ type=str,
47
+ default='pretrained_models/CosyVoice-300M',
48
+ help='local path')
49
+ args = parser.parse_args()
50
+ print(args)
51
+ return args
52
+
53
+
54
+ def main():
55
+ args = get_args()
56
+ logging.basicConfig(level=logging.DEBUG,
57
+ format='%(asctime)s %(levelname)s %(message)s')
58
+
59
+ cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_onnx=False)
60
+
61
+ # 1. export flow decoder estimator
62
+ estimator = cosyvoice.model.flow.decoder.estimator
63
+
64
+ device = cosyvoice.model.device
65
+ batch_size, seq_len = 1, 256
66
+ out_channels = cosyvoice.model.flow.decoder.estimator.out_channels
67
+ x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
68
+ torch.onnx.export(
69
+ estimator,
70
+ (x, mask, mu, t, spks, cond),
71
+ '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
72
+ export_params=True,
73
+ opset_version=18,
74
+ do_constant_folding=True,
75
+ input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
76
+ output_names=['estimator_out'],
77
+ dynamic_axes={
78
+ 'x': {0: 'batch_size', 2: 'seq_len'},
79
+ 'mask': {0: 'batch_size', 2: 'seq_len'},
80
+ 'mu': {0: 'batch_size', 2: 'seq_len'},
81
+ 'cond': {0: 'batch_size', 2: 'seq_len'},
82
+ 't': {0: 'batch_size'},
83
+ 'spks': {0: 'batch_size'},
84
+ 'estimator_out': {0: 'batch_size', 2: 'seq_len'},
85
+ }
86
+ )
87
+
88
+ # 2. test computation consistency
89
+ option = onnxruntime.SessionOptions()
90
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
91
+ option.intra_op_num_threads = 1
92
+ providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
93
+ estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
94
+ sess_options=option, providers=providers)
95
+
96
+ for _ in tqdm(range(10)):
97
+ x, mask, mu, t, spks, cond = get_dummy_input(random.randint(1, 6), random.randint(16, 512), out_channels, device)
98
+ output_pytorch = estimator(x, mask, mu, t, spks, cond)
99
+ ort_inputs = {
100
+ 'x': x.cpu().numpy(),
101
+ 'mask': mask.cpu().numpy(),
102
+ 'mu': mu.cpu().numpy(),
103
+ 't': t.cpu().numpy(),
104
+ 'spks': spks.cpu().numpy(),
105
+ 'cond': cond.cpu().numpy()
106
+ }
107
+ output_onnx = estimator_onnx.run(None, ort_inputs)[0]
108
+ torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
109
+
110
+
111
+ if __name__ == "__main__":
112
+ main()
cosyvoice/bin/export_onnx_cosyvoice2.py ADDED
@@ -0,0 +1,110 @@
1
+ # Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, [email protected])
2
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import print_function
17
+
18
+ import argparse
19
+ import logging
20
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
21
+ import os
22
+ import sys
23
+ import onnxruntime
24
+ import random
25
+ import torch
26
+ from tqdm import tqdm
27
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
28
+ sys.path.append('{}/../..'.format(ROOT_DIR))
29
+ sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
30
+ from cosyvoice.cli.cosyvoice import CosyVoice2
31
+
32
+
33
+ def get_dummy_input(batch_size, seq_len, out_channels, device):
34
+ x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
35
+ mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device)
36
+ mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
37
+ t = torch.rand((batch_size), dtype=torch.float32, device=device)
38
+ spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device)
39
+ cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
40
+ return x, mask, mu, t, spks, cond
41
+
42
+
43
+ def get_args():
44
+ parser = argparse.ArgumentParser(description='export your model for deployment')
45
+ parser.add_argument('--model_dir',
46
+ type=str,
47
+ default='pretrained_models/CosyVoice-300M',
48
+ help='local path')
49
+ args = parser.parse_args()
50
+ print(args)
51
+ return args
52
+
53
+
54
+ def main():
55
+ args = get_args()
56
+ logging.basicConfig(level=logging.DEBUG,
57
+ format='%(asctime)s %(levelname)s %(message)s')
58
+
59
+ cosyvoice = CosyVoice2(args.model_dir, load_jit=False, load_onnx=False)
60
+
61
+ # 1. export flow decoder estimator
62
+ estimator = cosyvoice.model.flow.decoder.estimator
63
+
64
+ device = cosyvoice.model.device
65
+ batch_size, seq_len = 2, 320
66
+ out_channels = cosyvoice.model.flow.decoder.estimator.out_channels
67
+ x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
68
+ torch.onnx.export(
69
+ estimator,
70
+ (x, mask, mu, t, spks, cond),
71
+ '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
72
+ export_params=True,
73
+ opset_version=18,
74
+ do_constant_folding=True,
75
+ input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
76
+ output_names=['estimator_out'],
77
+ dynamic_axes={
78
+ 'x': {2: 'seq_len'},
79
+ 'mask': {2: 'seq_len'},
80
+ 'mu': {2: 'seq_len'},
81
+ 'cond': {2: 'seq_len'},
82
+ 'estimator_out': {2: 'seq_len'},
83
+ }
84
+ )
85
+
86
+ # 2. test computation consistency
87
+ option = onnxruntime.SessionOptions()
88
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
89
+ option.intra_op_num_threads = 1
90
+ providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
91
+ estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
92
+ sess_options=option, providers=providers)
93
+
94
+ for _ in tqdm(range(10)):
95
+ x, mask, mu, t, spks, cond = get_dummy_input(random.randint(1, 6), random.randint(16, 512), out_channels, device)
96
+ output_pytorch = estimator(x, mask, mu, t, spks, cond)
97
+ ort_inputs = {
98
+ 'x': x.cpu().numpy(),
99
+ 'mask': mask.cpu().numpy(),
100
+ 'mu': mu.cpu().numpy(),
101
+ 't': t.cpu().numpy(),
102
+ 'spks': spks.cpu().numpy(),
103
+ 'cond': cond.cpu().numpy()
104
+ }
105
+ output_onnx = estimator_onnx.run(None, ort_inputs)[0]
106
+ torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
107
+
108
+
109
+ if __name__ == "__main__":
110
+ main()
cosyvoice/bin/export_trt_cosyvoce2.sh ADDED
@@ -0,0 +1,3 @@
1
+ #!/bin/bash
2
+ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/mnt/lyuxiang.lx/data/TensorRT-10.0.1.6-cu124/TensorRT-10.0.1.6/lib:/usr/local/cuda-12.4/lib64
3
+ /mnt/lyuxiang.lx/data/TensorRT-10.0.1.6-cu124/TensorRT-10.0.1.6/bin/trtexec --onnx=/mnt/lyuxiang.lx/CosyVoice_github/pretrained_models/CosyVoice2-0.5B/flow.decoder.estimator.fp32.onnx --saveEngine=/mnt/lyuxiang.lx/CosyVoice_github/pretrained_models/CosyVoice2-0.5B/flow.decoder.estimator.fp16.Volta.plan --fp16 --minShapes=x:2x80x4,mask:2x1x4,mu:2x80x4,cond:2x80x4 --optShapes=x:2x80x193,mask:2x1x193,mu:2x80x193,cond:2x80x193 --maxShapes=x:2x80x6800,mask:2x1x6800,mu:2x80x6800,cond:2x80x6800 --inputIOFormats=fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw --outputIOFormats=fp16:chw
cosyvoice/bin/inference.py ADDED
@@ -0,0 +1,115 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import print_function
16
+
17
+ import argparse
18
+ import logging
19
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
20
+ import os
21
+ import torch
22
+ from torch.utils.data import DataLoader
23
+ import torchaudio
24
+ from hyperpyyaml import load_hyperpyyaml
25
+ from tqdm import tqdm
26
+ from cosyvoice.cli.model import CosyVoiceModel
27
+ from cosyvoice.dataset.dataset import Dataset
28
+
29
+
30
+ def get_args():
31
+ parser = argparse.ArgumentParser(description='inference with your model')
32
+ parser.add_argument('--config', required=True, help='config file')
33
+ parser.add_argument('--prompt_data', required=True, help='prompt data file')
34
+ parser.add_argument('--prompt_utt2data', required=True, help='prompt data file')
35
+ parser.add_argument('--tts_text', required=True, help='tts input file')
36
+ parser.add_argument('--llm_model', required=True, help='llm model file')
37
+ parser.add_argument('--flow_model', required=True, help='flow model file')
38
+ parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
39
+ parser.add_argument('--gpu',
40
+ type=int,
41
+ default=-1,
42
+ help='gpu id for this rank, -1 for cpu')
43
+ parser.add_argument('--mode',
44
+ default='sft',
45
+ choices=['sft', 'zero_shot'],
46
+ help='inference mode')
47
+ parser.add_argument('--result_dir', required=True, help='asr result file')
48
+ args = parser.parse_args()
49
+ print(args)
50
+ return args
51
+
52
+
53
+ def main():
54
+ args = get_args()
55
+ logging.basicConfig(level=logging.DEBUG,
56
+ format='%(asctime)s %(levelname)s %(message)s')
57
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
58
+
59
+ # Init cosyvoice models from configs
60
+ use_cuda = args.gpu >= 0 and torch.cuda.is_available()
61
+ device = torch.device('cuda' if use_cuda else 'cpu')
62
+ with open(args.config, 'r') as f:
63
+ configs = load_hyperpyyaml(f)
64
+
65
+ model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
66
+ model.load(args.llm_model, args.flow_model, args.hifigan_model)
67
+
68
+ test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
69
+ tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
70
+ test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
71
+
72
+ del configs
73
+ os.makedirs(args.result_dir, exist_ok=True)
74
+ fn = os.path.join(args.result_dir, 'wav.scp')
75
+ f = open(fn, 'w')
76
+ with torch.no_grad():
77
+ for _, batch in tqdm(enumerate(test_data_loader)):
78
+ utts = batch["utts"]
79
+ assert len(utts) == 1, "inference mode only support batchsize 1"
80
+ text_token = batch["text_token"].to(device)
81
+ text_token_len = batch["text_token_len"].to(device)
82
+ tts_index = batch["tts_index"]
83
+ tts_text_token = batch["tts_text_token"].to(device)
84
+ tts_text_token_len = batch["tts_text_token_len"].to(device)
85
+ speech_token = batch["speech_token"].to(device)
86
+ speech_token_len = batch["speech_token_len"].to(device)
87
+ speech_feat = batch["speech_feat"].to(device)
88
+ speech_feat_len = batch["speech_feat_len"].to(device)
89
+ utt_embedding = batch["utt_embedding"].to(device)
90
+ spk_embedding = batch["spk_embedding"].to(device)
91
+ if args.mode == 'sft':
92
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
93
+ 'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
94
+ else:
95
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
96
+ 'prompt_text': text_token, 'prompt_text_len': text_token_len,
97
+ 'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
98
+ 'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
99
+ 'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
100
+ 'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
101
+ tts_speeches = []
102
+ for model_output in model.tts(**model_input):
103
+ tts_speeches.append(model_output['tts_speech'])
104
+ tts_speeches = torch.concat(tts_speeches, dim=1)
105
+ tts_key = '{}_{}'.format(utts[0], tts_index[0])
106
+ tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
107
+ torchaudio.save(tts_fn, tts_speeches, sample_rate=22050)
108
+ f.write('{} {}\n'.format(tts_key, tts_fn))
109
+ f.flush()
110
+ f.close()
111
+ logging.info('Result wav.scp saved in {}'.format(fn))
112
+
113
+
114
+ if __name__ == '__main__':
115
+ main()
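For reference, a small sketch (with a hypothetical `--result_dir` value) of consuming the `wav.scp` index this script writes; each line maps a tts key to the path of the synthesized wav.

```python
# Read back the wav.scp written by inference.py; the directory below is hypothetical.
import torchaudio

with open('exp/cosyvoice/test/sft/wav.scp') as f:
    for line in f:
        tts_key, tts_fn = line.strip().split(maxsplit=1)
        speech, sample_rate = torchaudio.load(tts_fn)
        print(tts_key, speech.shape, sample_rate)
```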
cosyvoice/bin/train.py ADDED
@@ -0,0 +1,170 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import print_function
16
+ import argparse
17
+ import datetime
18
+ import logging
19
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
20
+ from copy import deepcopy
21
+ import os
22
+ import torch
23
+ import torch.distributed as dist
24
+ import deepspeed
25
+
26
+ from hyperpyyaml import load_hyperpyyaml
27
+
28
+ from torch.distributed.elastic.multiprocessing.errors import record
29
+
30
+ from cosyvoice.utils.executor import Executor
31
+ from cosyvoice.utils.train_utils import (
32
+ init_distributed,
33
+ init_dataset_and_dataloader,
34
+ init_optimizer_and_scheduler,
35
+ init_summarywriter, save_model,
36
+ wrap_cuda_model, check_modify_and_save_config)
37
+
38
+
39
+ def get_args():
40
+ parser = argparse.ArgumentParser(description='training your network')
41
+ parser.add_argument('--train_engine',
42
+ default='torch_ddp',
43
+ choices=['torch_ddp', 'deepspeed'],
44
+ help='Engine for paralleled training')
45
+ parser.add_argument('--model', required=True, help='model which will be trained')
46
+ parser.add_argument('--config', required=True, help='config file')
47
+ parser.add_argument('--train_data', required=True, help='train data file')
48
+ parser.add_argument('--cv_data', required=True, help='cv data file')
49
+ parser.add_argument('--checkpoint', help='checkpoint model')
50
+ parser.add_argument('--model_dir', required=True, help='save model dir')
51
+ parser.add_argument('--tensorboard_dir',
52
+ default='tensorboard',
53
+ help='tensorboard log dir')
54
+ parser.add_argument('--ddp.dist_backend',
55
+ dest='dist_backend',
56
+ default='nccl',
57
+ choices=['nccl', 'gloo'],
58
+ help='distributed backend')
59
+ parser.add_argument('--num_workers',
60
+ default=0,
61
+ type=int,
62
+ help='num of subprocess workers for reading')
63
+ parser.add_argument('--prefetch',
64
+ default=100,
65
+ type=int,
66
+ help='prefetch number')
67
+ parser.add_argument('--pin_memory',
68
+ action='store_true',
69
+ default=False,
70
+ help='Use pinned memory buffers used for reading')
71
+ parser.add_argument('--use_amp',
72
+ action='store_true',
73
+ default=False,
74
+ help='Use automatic mixed precision training')
75
+ parser.add_argument('--deepspeed.save_states',
76
+ dest='save_states',
77
+ default='model_only',
78
+ choices=['model_only', 'model+optimizer'],
79
+ help='save model/optimizer states')
80
+ parser.add_argument('--timeout',
81
+ default=60,
82
+ type=int,
83
+ help='timeout (in seconds) of cosyvoice_join.')
84
+ parser = deepspeed.add_config_arguments(parser)
85
+ args = parser.parse_args()
86
+ return args
87
+
88
+
89
+ @record
90
+ def main():
91
+ args = get_args()
92
+ logging.basicConfig(level=logging.DEBUG,
93
+ format='%(asctime)s %(levelname)s %(message)s')
94
+ # gan training has some special initialization logic
95
+ gan = True if args.model == 'hifigan' else False
96
+
97
+ override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
98
+ if gan is True:
99
+ override_dict.pop('hift')
100
+ with open(args.config, 'r') as f:
101
+ configs = load_hyperpyyaml(f, overrides=override_dict)
102
+ if gan is True:
103
+ configs['train_conf'] = configs['train_conf_gan']
104
+ configs['train_conf'].update(vars(args))
105
+
106
+ # Init env for ddp
107
+ init_distributed(args)
108
+
109
+ # Get dataset & dataloader
110
+ train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
111
+ init_dataset_and_dataloader(args, configs, gan)
112
+
113
+ # Do some sanity checks and save config to args.model_dir
114
+ configs = check_modify_and_save_config(args, configs)
115
+
116
+ # Tensorboard summary
117
+ writer = init_summarywriter(args)
118
+
119
+ # load checkpoint
120
+ model = configs[args.model]
121
+ start_step, start_epoch = 0, -1
122
+ if args.checkpoint is not None:
123
+ if os.path.exists(args.checkpoint):
124
+ state_dict = torch.load(args.checkpoint, map_location='cpu')
125
+ model.load_state_dict(state_dict, strict=False)
126
+ if 'step' in state_dict:
127
+ start_step = state_dict['step']
128
+ if 'epoch' in state_dict:
129
+ start_epoch = state_dict['epoch']
130
+ else:
131
+ logging.warning('checkpoint {} does not exist!'.format(args.checkpoint))
132
+
133
+ # Dispatch model from cpu to gpu
134
+ model = wrap_cuda_model(args, model)
135
+
136
+ # Get optimizer & scheduler
137
+ model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan)
138
+ scheduler.set_step(start_step)
139
+ if scheduler_d is not None:
140
+ scheduler_d.set_step(start_step)
141
+
142
+ # Save init checkpoints
143
+ info_dict = deepcopy(configs['train_conf'])
144
+ info_dict['step'] = start_step
145
+ info_dict['epoch'] = start_epoch
146
+ save_model(model, 'init', info_dict)
147
+
148
+ # Get executor
149
+ executor = Executor(gan=gan)
150
+ executor.step = start_step
151
+
152
+ # Init scaler, used for pytorch amp mixed precision training
153
+ scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
154
+ print('start step {} start epoch {}'.format(start_step, start_epoch))
155
+ # Start training loop
156
+ for epoch in range(start_epoch + 1, info_dict['max_epoch']):
157
+ executor.epoch = epoch
158
+ train_dataset.set_epoch(epoch)
159
+ dist.barrier()
160
+ group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
161
+ if gan is True:
162
+ executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
163
+ writer, info_dict, scaler, group_join)
164
+ else:
165
+ executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join)
166
+ dist.destroy_process_group(group_join)
167
+
168
+
169
+ if __name__ == '__main__':
170
+ main()
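A short sketch of the resume logic above (the checkpoint path is hypothetical): besides module weights, the saved state dict may carry non-tensor `step` and `epoch` entries, which seed the scheduler and the epoch loop.

```python
# Inspect a checkpoint the same way train.py does; the path is a placeholder.
import torch

state_dict = torch.load('exp/cosyvoice/llm/torch_ddp/epoch_5_whole.pt', map_location='cpu')
print('resume from step {} epoch {}'.format(state_dict.get('step', 0), state_dict.get('epoch', -1)))
```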
cosyvoice/cli/__init__.py ADDED
File without changes
cosyvoice/cli/cosyvoice.py ADDED
@@ -0,0 +1,167 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import time
16
+ from tqdm import tqdm
17
+ from hyperpyyaml import load_hyperpyyaml
18
+ from modelscope import snapshot_download
19
+ import torch
20
+ from cosyvoice.cli.frontend import CosyVoiceFrontEnd
21
+ from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
22
+ from cosyvoice.utils.file_utils import logging
23
+
24
+
25
+ class CosyVoice:
26
+
27
+ def __init__(self, model_dir, load_jit=True, load_onnx=False, fp16=True):
28
+ instruct = True if '-Instruct' in model_dir else False
29
+ self.model_dir = model_dir
30
+ if not os.path.exists(model_dir):
31
+ model_dir = snapshot_download(model_dir)
32
+ with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
33
+ configs = load_hyperpyyaml(f)
34
+ self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
35
+ configs['feat_extractor'],
36
+ '{}/campplus.onnx'.format(model_dir),
37
+ '{}/speech_tokenizer_v1.onnx'.format(model_dir),
38
+ '{}/spk2info.pt'.format(model_dir),
39
+ instruct,
40
+ configs['allowed_special'])
41
+ self.sample_rate = configs['sample_rate']
42
+ if torch.cuda.is_available() is False and (fp16 is True or load_jit is True):
43
+ load_jit = False
44
+ fp16 = False
45
+ logging.warning('cpu does not support fp16 and jit, force set to False')
46
+ self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16)
47
+ self.model.load('{}/llm.pt'.format(model_dir),
48
+ '{}/flow.pt'.format(model_dir),
49
+ '{}/hift.pt'.format(model_dir))
50
+ if load_jit:
51
+ self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
52
+ '{}/llm.llm.fp16.zip'.format(model_dir),
53
+ '{}/flow.encoder.fp32.zip'.format(model_dir))
54
+ if load_onnx:
55
+ self.model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir))
56
+ del configs
57
+
58
+ def list_avaliable_spks(self):
59
+ spks = list(self.frontend.spk2info.keys())
60
+ return spks
61
+
62
+ def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0):
63
+ for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
64
+ model_input = self.frontend.frontend_sft(i, spk_id)
65
+ start_time = time.time()
66
+ logging.info('synthesis text {}'.format(i))
67
+ for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
68
+ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
69
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
70
+ yield model_output
71
+ start_time = time.time()
72
+
73
+ def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False, speed=1.0):
74
+ prompt_text = self.frontend.text_normalize(prompt_text, split=False)
75
+ for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
76
+ if len(i) < 0.5 * len(prompt_text):
77
+ logging.warning('synthesis text {} is much shorter than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
78
+ model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate)
79
+ start_time = time.time()
80
+ logging.info('synthesis text {}'.format(i))
81
+ for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
82
+ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
83
+ logging.info('yield speech len {}, rtf {}, abs mean {}, std {}'.format(speech_len, (time.time() - start_time) / speech_len, model_output['tts_speech'].abs().mean(), model_output['tts_speech'].std()))
84
+ yield model_output
85
+ start_time = time.time()
86
+
87
+ def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0):
88
+ if self.frontend.instruct is True:
89
+ raise ValueError('{} does not support cross_lingual inference'.format(self.model_dir))
90
+ for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
91
+ model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate)
92
+ start_time = time.time()
93
+ logging.info('synthesis text {}'.format(i))
94
+ for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
95
+ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
96
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
97
+ yield model_output
98
+ start_time = time.time()
99
+
100
+ def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0):
101
+ if self.frontend.instruct is False:
102
+ raise ValueError('{} does not support instruct inference'.format(self.model_dir))
103
+ instruct_text = self.frontend.text_normalize(instruct_text, split=False)
104
+ for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
105
+ model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
106
+ start_time = time.time()
107
+ logging.info('synthesis text {}'.format(i))
108
+ for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
109
+ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
110
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
111
+ yield model_output
112
+ start_time = time.time()
113
+
114
+ def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, stream=False, speed=1.0):
115
+ for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
116
+ model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate)
117
+ start_time = time.time()
118
+ logging.info('synthesis text {}'.format(i))
119
+ for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
120
+ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
121
+ logging.info('yield speech len {}, rtf {}, abs mean {}, std {}'.format(speech_len, (time.time() - start_time) / speech_len, model_output['tts_speech'].abs().mean(), model_output['tts_speech'].std()))
122
+ yield model_output
123
+ start_time = time.time()
124
+
125
+ def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
126
+ model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate)
127
+ start_time = time.time()
128
+ for model_output in self.model.vc(**model_input, stream=stream, speed=speed):
129
+ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
130
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
131
+ yield model_output
132
+ start_time = time.time()
133
+
134
+ class CosyVoice2(CosyVoice):
135
+
136
+ def __init__(self, model_dir, load_jit=False, load_onnx=False, load_trt=False):
137
+ instruct = True if '-Instruct' in model_dir else False
138
+ self.model_dir = model_dir
139
+ if not os.path.exists(model_dir):
140
+ model_dir = snapshot_download(model_dir)
141
+ with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
142
+ configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
143
+ self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
144
+ configs['feat_extractor'],
145
+ '{}/campplus.onnx'.format(model_dir),
146
+ '{}/speech_tokenizer_v2.onnx'.format(model_dir),
147
+ '{}/spk2info.pt'.format(model_dir),
148
+ instruct,
149
+ configs['allowed_special'])
150
+ self.sample_rate = configs['sample_rate']
151
+ if torch.cuda.is_available() is False and load_jit is True:
152
+ load_jit = False
153
+ logging.warning('cpu does not support jit, force set to False')
154
+ self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'])
155
+ self.model.load('{}/llm.pt'.format(model_dir),
156
+ '{}/flow.pt'.format(model_dir),
157
+ '{}/hift.pt'.format(model_dir))
158
+ if load_jit:
159
+ self.model.load_jit('{}/flow.encoder.fp32.zip'.format(model_dir))
160
+ if load_trt is True and load_onnx is True:
161
+ load_onnx = False
162
+ logging.warning('cannot set both load_trt and load_onnx to True, force set load_onnx to False')
163
+ if load_onnx:
164
+ self.model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir))
165
+ if load_trt:
166
+ self.model.load_trt('{}/flow.decoder.estimator.fp16.A10.plan'.format(model_dir))
167
+ del configs
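A minimal end-to-end sketch of the `CosyVoice2` API defined above. The model directory, prompt wav, and texts are placeholders, and `load_wav` is assumed to be the 16 kHz loader provided by `cosyvoice.utils.file_utils` in this commit.

```python
# Zero-shot cloning with CosyVoice2; paths and text are placeholders.
import torchaudio
from cosyvoice.cli.cosyvoice import CosyVoice2
from cosyvoice.utils.file_utils import load_wav  # assumed helper from this repo

cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_onnx=False, load_trt=False)
prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)  # 16 kHz prompt audio
for i, out in enumerate(cosyvoice.inference_zero_shot('Text to synthesize goes here.',
                                                      'Transcript of the prompt audio.',
                                                      prompt_speech_16k, stream=False)):
    torchaudio.save('zero_shot_{}.wav'.format(i), out['tts_speech'], cosyvoice.sample_rate)
```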
cosyvoice/cli/frontend.py ADDED
@@ -0,0 +1,213 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from functools import partial
15
+ import json
16
+ import onnxruntime
17
+ import torch
18
+ import numpy as np
19
+ import whisper
20
+ from typing import Callable
21
+ import torchaudio.compliance.kaldi as kaldi
22
+ import torchaudio
23
+ import os
24
+ import re
25
+ import inflect
26
+ try:
27
+ import ttsfrd
28
+ use_ttsfrd = True
29
+ except ImportError:
30
+ print("failed to import ttsfrd, use WeTextProcessing instead")
31
+ from tn.chinese.normalizer import Normalizer as ZhNormalizer
32
+ from tn.english.normalizer import Normalizer as EnNormalizer
33
+ use_ttsfrd = False
34
+ from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph
35
+
36
+
37
+ class CosyVoiceFrontEnd:
38
+
39
+ def __init__(self,
40
+ get_tokenizer: Callable,
41
+ feat_extractor: Callable,
42
+ campplus_model: str,
43
+ speech_tokenizer_model: str,
44
+ spk2info: str = '',
45
+ instruct: bool = False,
46
+ allowed_special: str = 'all'):
47
+ self.tokenizer = get_tokenizer()
48
+ self.feat_extractor = feat_extractor
49
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
50
+ option = onnxruntime.SessionOptions()
51
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
52
+ option.intra_op_num_threads = 1
53
+ self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
54
+ self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
55
+ providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
56
+ "CPUExecutionProvider"])
57
+ if os.path.exists(spk2info):
58
+ self.spk2info = torch.load(spk2info, map_location=self.device)
59
+ else:
60
+ self.spk2info = {}
61
+ self.instruct = instruct
62
+ self.allowed_special = allowed_special
63
+ self.inflect_parser = inflect.engine()
64
+ self.use_ttsfrd = use_ttsfrd
65
+ if self.use_ttsfrd:
66
+ self.frd = ttsfrd.TtsFrontendEngine()
67
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
68
+ assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
69
+ 'failed to initialize ttsfrd resource'
70
+ self.frd.set_lang_type('pinyinvg')
71
+ else:
72
+ self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False)
73
+ self.en_tn_model = EnNormalizer()
74
+
75
+ def _extract_text_token(self, text):
76
+ text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
77
+ text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
78
+ text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
79
+ return text_token, text_token_len
80
+
81
+ def _extract_speech_token(self, speech):
82
+ assert speech.shape[1] / 16000 <= 30, 'do not support extracting speech tokens for audio longer than 30s'
83
+ feat = whisper.log_mel_spectrogram(speech, n_mels=128)
84
+ speech_token = self.speech_tokenizer_session.run(None,
85
+ {self.speech_tokenizer_session.get_inputs()[0].name:
86
+ feat.detach().cpu().numpy(),
87
+ self.speech_tokenizer_session.get_inputs()[1].name:
88
+ np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
89
+ speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
90
+ speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
91
+ return speech_token, speech_token_len
92
+
93
+ def _extract_spk_embedding(self, speech):
94
+ feat = kaldi.fbank(speech,
95
+ num_mel_bins=80,
96
+ dither=0,
97
+ sample_frequency=16000)
98
+ feat = feat - feat.mean(dim=0, keepdim=True)
99
+ embedding = self.campplus_session.run(None,
100
+ {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
101
+ embedding = torch.tensor([embedding]).to(self.device)
102
+ return embedding
103
+
104
+ def _extract_speech_feat(self, speech):
105
+ speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
106
+ speech_feat = speech_feat.unsqueeze(dim=0)
107
+ speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
108
+ return speech_feat, speech_feat_len
109
+
110
+ def text_normalize(self, text, split=True):
111
+ text = text.strip()
112
+ if contains_chinese(text):
113
+ if self.use_ttsfrd:
114
+ texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
115
+ text = ''.join(texts)
116
+ else:
117
+ text = self.zh_tn_model.normalize(text)
118
+ text = text.replace("\n", "")
119
+ text = replace_blank(text)
120
+ text = replace_corner_mark(text)
121
+ text = text.replace(".", "。")
122
+ text = text.replace(" - ", ",")
123
+ text = remove_bracket(text)
124
+ text = re.sub(r'[,,、]+$', '。', text)
125
+ texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
126
+ token_min_n=60, merge_len=20, comma_split=False))
127
+ else:
128
+ if self.use_ttsfrd:
129
+ texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
130
+ text = ''.join(texts)
131
+ else:
132
+ text = self.en_tn_model.normalize(text)
133
+ text = spell_out_number(text, self.inflect_parser)
134
+ texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
135
+ token_min_n=60, merge_len=20, comma_split=False))
136
+ if split is False:
137
+ return text
138
+ return texts
139
+
140
+ def frontend_sft(self, tts_text, spk_id):
141
+ tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
142
+ embedding = self.spk2info[spk_id]['embedding']
143
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
144
+ return model_input
145
+
146
+ def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, resample_rate):
147
+ tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
148
+ prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
149
+ prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
150
+ speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
151
+ speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
152
+ if resample_rate == 24000:
153
+ # cosyvoice2: force speech_feat length to be exactly twice the speech_token length
154
+ token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
155
+ speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
156
+ speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
157
+ embedding = self._extract_spk_embedding(prompt_speech_16k)
158
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
159
+ 'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
160
+ 'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
161
+ 'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
162
+ 'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
163
+ 'llm_embedding': embedding, 'flow_embedding': embedding}
164
+ return model_input
165
+
166
+ def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate):
167
+ tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
168
+ prompt_text_token, prompt_text_token_len = self._extract_text_token(instruct_text + '<|endofprompt|>')
169
+ prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
170
+ speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
171
+ speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
172
+ if resample_rate == 24000:
173
+ # cosyvoice2: force speech_feat length to be exactly twice the speech_token length
174
+ token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
175
+ speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
176
+ speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
177
+ embedding = self._extract_spk_embedding(prompt_speech_16k)
178
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
179
+ 'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
180
+ 'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
181
+ 'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
182
+ 'llm_embedding': embedding, 'flow_embedding': embedding}
183
+ return model_input
184
+
185
+ def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate):
186
+ model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate)
187
+ # in cross lingual mode, we remove prompt in llm
188
+ del model_input['prompt_text']
189
+ del model_input['prompt_text_len']
190
+ del model_input['llm_prompt_speech_token']
191
+ del model_input['llm_prompt_speech_token_len']
192
+ return model_input
193
+
194
+ def frontend_instruct(self, tts_text, spk_id, instruct_text):
195
+ model_input = self.frontend_sft(tts_text, spk_id)
196
+ # in instruct mode, we remove spk_embedding in llm due to information leakage
197
+ del model_input['llm_embedding']
198
+ instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text + '<endofprompt>')
199
+ model_input['prompt_text'] = instruct_text_token
200
+ model_input['prompt_text_len'] = instruct_text_token_len
201
+ return model_input
202
+
203
+ def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate):
204
+ prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
205
+ prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
206
+ prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
207
+ embedding = self._extract_spk_embedding(prompt_speech_16k)
208
+ source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
209
+ model_input = {'source_speech_token': source_speech_token, 'source_speech_token_len': source_speech_token_len,
210
+ 'flow_prompt_speech_token': prompt_speech_token, 'flow_prompt_speech_token_len': prompt_speech_token_len,
211
+ 'prompt_speech_feat': prompt_speech_feat, 'prompt_speech_feat_len': prompt_speech_feat_len,
212
+ 'flow_embedding': embedding}
213
+ return model_input
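The 24 kHz (CosyVoice2) branch of `frontend_zero_shot`/`frontend_instruct2` trims the prompt features so that the mel length is exactly twice the speech-token length. A self-contained toy illustration of that arithmetic (shapes are made up):

```python
# Toy shapes only: trim speech_feat and speech_token to an exact 2:1 ratio.
import torch

speech_feat = torch.randn(1, 575, 80)                  # (batch, mel_frames, mel_bins)
speech_token = torch.zeros(1, 290, dtype=torch.int32)  # (batch, tokens)

token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])  # 287
speech_feat = speech_feat[:, :2 * token_len]                           # 574 frames
speech_token = speech_token[:, :token_len]                             # 287 tokens
assert speech_feat.shape[1] == 2 * speech_token.shape[1]
```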
cosyvoice/cli/model.py ADDED
@@ -0,0 +1,421 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+ import numpy as np
16
+ import threading
17
+ import time
18
+ from torch.nn import functional as F
19
+ from contextlib import nullcontext
20
+ import uuid
21
+ from cosyvoice.utils.common import fade_in_out
22
+
23
+
24
+ class CosyVoiceModel:
25
+
26
+ def __init__(self,
27
+ llm: torch.nn.Module,
28
+ flow: torch.nn.Module,
29
+ hift: torch.nn.Module,
30
+ fp16: bool):
31
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
32
+ self.llm = llm
33
+ self.flow = flow
34
+ self.hift = hift
35
+ self.fp16 = fp16
36
+ self.token_min_hop_len = 2 * self.flow.input_frame_rate
37
+ self.token_max_hop_len = 4 * self.flow.input_frame_rate
38
+ self.token_overlap_len = 20
39
+ # mel fade in out
40
+ self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
41
+ self.mel_window = np.hamming(2 * self.mel_overlap_len)
42
+ # hift cache
43
+ self.mel_cache_len = 20
44
+ self.source_cache_len = int(self.mel_cache_len * 256)
45
+ # speech fade in out
46
+ self.speech_window = np.hamming(2 * self.source_cache_len)
47
+ # rtf and decoding related
48
+ self.stream_scale_factor = 1
49
+ assert self.stream_scale_factor >= 1, 'stream_scale_factor should be no less than 1, change it according to your actual rtf'
50
+ self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
51
+ self.lock = threading.Lock()
52
+ # dict used to store session related variable
53
+ self.tts_speech_token_dict = {}
54
+ self.llm_end_dict = {}
55
+ self.mel_overlap_dict = {}
56
+ self.flow_cache_dict = {}
57
+ self.hift_cache_dict = {}
58
+
59
+ def load(self, llm_model, flow_model, hift_model):
60
+ self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
61
+ self.llm.to(self.device).eval()
62
+ if self.fp16 is True:
63
+ self.llm.half()
64
+ self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
65
+ self.flow.to(self.device).eval()
66
+ # in case hift_model is a hifigan model
67
+ hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
68
+ self.hift.load_state_dict(hift_state_dict, strict=True)
69
+ self.hift.to(self.device).eval()
70
+
71
+ def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
72
+ assert self.fp16 is True, "we only provide fp16 jit model, set fp16=True if you want to use jit model"
73
+ llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
74
+ self.llm.text_encoder = llm_text_encoder
75
+ llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
76
+ self.llm.llm = llm_llm
77
+ flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
78
+ self.flow.encoder = flow_encoder
79
+
80
+ def load_onnx(self, flow_decoder_estimator_model):
81
+ import onnxruntime
82
+ option = onnxruntime.SessionOptions()
83
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
84
+ option.intra_op_num_threads = 1
85
+ providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
86
+ del self.flow.decoder.estimator
87
+ self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers)
88
+
89
+ def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
90
+ if self.fp16 is True:
91
+ llm_embedding = llm_embedding.half()
92
+ with self.llm_context:
93
+ for i in self.llm.inference(text=text.to(self.device),
94
+ text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
95
+ prompt_text=prompt_text.to(self.device),
96
+ prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
97
+ prompt_speech_token=llm_prompt_speech_token.to(self.device),
98
+ prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
99
+ embedding=llm_embedding.to(self.device)):
100
+ self.tts_speech_token_dict[uuid].append(i)
101
+ self.llm_end_dict[uuid] = True
102
+
103
+ def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
104
+ tts_mel, flow_cache = self.flow.inference(token=token.to(self.device),
105
+ token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
106
+ prompt_token=prompt_token.to(self.device),
107
+ prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
108
+ prompt_feat=prompt_feat.to(self.device),
109
+ prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
110
+ embedding=embedding.to(self.device),
111
+ flow_cache=self.flow_cache_dict[uuid])
112
+ self.flow_cache_dict[uuid] = flow_cache
113
+
114
+ # mel overlap fade in out
115
+ if self.mel_overlap_dict[uuid].shape[2] != 0:
116
+ tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
117
+ # append hift cache
118
+ if self.hift_cache_dict[uuid] is not None:
119
+ hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
120
+ tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
121
+ else:
122
+ hift_cache_source = torch.zeros(1, 1, 0)
123
+ # keep overlap mel and hift cache
124
+ if finalize is False:
125
+ self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
126
+ tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
127
+ tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
128
+ if self.hift_cache_dict[uuid] is not None:
129
+ tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
130
+ self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
131
+ 'source': tts_source[:, :, -self.source_cache_len:],
132
+ 'speech': tts_speech[:, -self.source_cache_len:]}
133
+ tts_speech = tts_speech[:, :-self.source_cache_len]
134
+ else:
135
+ if speed != 1.0:
136
+ assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
137
+ tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
138
+ tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
139
+ if self.hift_cache_dict[uuid] is not None:
140
+ tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
141
+ return tts_speech
142
+
143
+ def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
144
+ prompt_text=torch.zeros(1, 0, dtype=torch.int32),
145
+ llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
146
+ flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
147
+ prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
148
+ # this_uuid is used to track variables related to this inference thread
149
+ this_uuid = str(uuid.uuid1())
150
+ with self.lock:
151
+ self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
152
+ self.hift_cache_dict[this_uuid] = None
153
+ self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
154
+ self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
155
+ p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
156
+ p.start()
157
+ if stream is True:
158
+ token_hop_len = self.token_min_hop_len
159
+ while True:
160
+ time.sleep(0.1)
161
+ if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
162
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
163
+ .unsqueeze(dim=0)
164
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
165
+ prompt_token=flow_prompt_speech_token,
166
+ prompt_feat=prompt_speech_feat,
167
+ embedding=flow_embedding,
168
+ uuid=this_uuid,
169
+ finalize=False)
170
+ yield {'tts_speech': this_tts_speech.cpu()}
171
+ with self.lock:
172
+ self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
173
+ # increase token_hop_len for better speech quality
174
+ token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
175
+ if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
176
+ break
177
+ p.join()
178
+ # deal with remaining tokens; make sure the remaining token length equals token_hop_len when cache_speech is not None
179
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
180
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
181
+ prompt_token=flow_prompt_speech_token,
182
+ prompt_feat=prompt_speech_feat,
183
+ embedding=flow_embedding,
184
+ uuid=this_uuid,
185
+ finalize=True)
186
+ yield {'tts_speech': this_tts_speech.cpu()}
187
+ else:
188
+ # deal with all tokens
189
+ p.join()
190
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
191
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
192
+ prompt_token=flow_prompt_speech_token,
193
+ prompt_feat=prompt_speech_feat,
194
+ embedding=flow_embedding,
195
+ uuid=this_uuid,
196
+ finalize=True,
197
+ speed=speed)
198
+ yield {'tts_speech': this_tts_speech.cpu()}
199
+ with self.lock:
200
+ self.tts_speech_token_dict.pop(this_uuid)
201
+ self.llm_end_dict.pop(this_uuid)
202
+ self.mel_overlap_dict.pop(this_uuid)
203
+ self.hift_cache_dict.pop(this_uuid)
204
+
205
+ def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs):
206
+ # this_uuid is used to track variables related to this inference thread
207
+ this_uuid = str(uuid.uuid1())
208
+ with self.lock:
209
+ self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True
210
+ self.hift_cache_dict[this_uuid] = None
211
+ self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
212
+ self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
213
+ if stream is True:
214
+ token_hop_len = self.token_min_hop_len
215
+ while True:
216
+ if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
217
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
218
+ .unsqueeze(dim=0)
219
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
220
+ prompt_token=flow_prompt_speech_token,
221
+ prompt_feat=prompt_speech_feat,
222
+ embedding=flow_embedding,
223
+ uuid=this_uuid,
224
+ finalize=False)
225
+ yield {'tts_speech': this_tts_speech.cpu()}
226
+ with self.lock:
227
+ self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
228
+ # increase token_hop_len for better speech quality
229
+ token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
230
+ if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
231
+ break
232
+ # deal with remaining tokens; make sure the remaining token length equals token_hop_len when cache_speech is not None
233
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
234
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
235
+ prompt_token=flow_prompt_speech_token,
236
+ prompt_feat=prompt_speech_feat,
237
+ embedding=flow_embedding,
238
+ uuid=this_uuid,
239
+ finalize=True)
240
+ yield {'tts_speech': this_tts_speech.cpu()}
241
+ else:
242
+ # deal with all tokens
243
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
244
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
245
+ prompt_token=flow_prompt_speech_token,
246
+ prompt_feat=prompt_speech_feat,
247
+ embedding=flow_embedding,
248
+ uuid=this_uuid,
249
+ finalize=True,
250
+ speed=speed)
251
+ yield {'tts_speech': this_tts_speech.cpu()}
252
+ with self.lock:
253
+ self.tts_speech_token_dict.pop(this_uuid)
254
+ self.llm_end_dict.pop(this_uuid)
255
+ self.mel_overlap_dict.pop(this_uuid)
256
+ self.hift_cache_dict.pop(this_uuid)
257
+
258
+
259
+ class CosyVoice2Model:
260
+
261
+ def __init__(self,
262
+ llm: torch.nn.Module,
263
+ flow: torch.nn.Module,
264
+ hift: torch.nn.Module):
265
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
266
+ self.llm = llm
267
+ self.flow = flow
268
+ self.hift = hift
269
+ self.token_hop_len = 2 * self.flow.input_frame_rate
270
+ # here we fix the flow encoder/decoder decoding_chunk_size; in the future we will pass it as an argument, or use a cache
271
+ self.flow.encoder.static_chunk_size = 2 * self.flow.input_frame_rate
272
+ self.flow.decoder.estimator.static_chunk_size = 2 * self.flow.input_frame_rate * self.flow.token_mel_ratio
273
+ # hift cache
274
+ self.mel_cache_len = 8
275
+ self.source_cache_len = int(self.mel_cache_len * 480)
276
+ # speech fade in out
277
+ self.speech_window = np.hamming(2 * self.source_cache_len)
278
+ # rtf and decoding related
279
+ self.stream_scale_factor = 1
280
+ self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
281
+ self.lock = threading.Lock()
282
+ # dict used to store session related variable
283
+ self.tts_speech_token_dict = {}
284
+ self.llm_end_dict = {}
285
+ self.hift_cache_dict = {}
286
+
287
+ def load(self, llm_model, flow_model, hift_model):
288
+ self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
289
+ self.llm.to(self.device).eval()
290
+ self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
291
+ self.flow.to(self.device).eval()
292
+ self.flow.decoder.fp16 = False
293
+ # in case hift_model is a hifigan model
294
+ hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
295
+ self.hift.load_state_dict(hift_state_dict, strict=True)
296
+ self.hift.to(self.device).eval()
297
+
298
+ def load_jit(self, flow_encoder_model):
299
+ flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
300
+ self.flow.encoder = flow_encoder
301
+
302
+ def load_onnx(self, flow_decoder_estimator_model):
303
+ import onnxruntime
304
+ option = onnxruntime.SessionOptions()
305
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
306
+ option.intra_op_num_threads = 1
307
+ providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
308
+ del self.flow.decoder.estimator
309
+ self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers)
310
+
311
+ def load_trt(self, flow_decoder_estimator_model):
312
+ del self.flow.decoder.estimator
313
+ import tensorrt as trt
314
+ with open(flow_decoder_estimator_model, 'rb') as f:
315
+ self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
316
+ self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()
317
+ self.flow.decoder.fp16 = True
318
+
319
+ def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
320
+ with self.llm_context:
321
+ for i in self.llm.inference(text=text.to(self.device),
322
+ text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
323
+ prompt_text=prompt_text.to(self.device),
324
+ prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
325
+ prompt_speech_token=llm_prompt_speech_token.to(self.device),
326
+ prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
327
+ embedding=llm_embedding.to(self.device)):
328
+ self.tts_speech_token_dict[uuid].append(i)
329
+ self.llm_end_dict[uuid] = True
330
+
331
+ def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0):
332
+ tts_mel, _ = self.flow.inference(token=token.to(self.device),
333
+ token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
334
+ prompt_token=prompt_token.to(self.device),
335
+ prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
336
+ prompt_feat=prompt_feat.to(self.device),
337
+ prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
338
+ embedding=embedding.to(self.device),
339
+ finalize=finalize)
340
+ tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
341
+ # append hift cache
342
+ if self.hift_cache_dict[uuid] is not None:
343
+ hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
344
+ tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
345
+ else:
346
+ hift_cache_source = torch.zeros(1, 1, 0)
347
+ # keep overlap mel and hift cache
348
+ if finalize is False:
349
+ tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
350
+ if self.hift_cache_dict[uuid] is not None:
351
+ tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
352
+ self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
353
+ 'source': tts_source[:, :, -self.source_cache_len:],
354
+ 'speech': tts_speech[:, -self.source_cache_len:]}
355
+ tts_speech = tts_speech[:, :-self.source_cache_len]
356
+ else:
357
+ if speed != 1.0:
358
+ assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
359
+ tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
360
+ tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
361
+ if self.hift_cache_dict[uuid] is not None:
362
+ tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
363
+ return tts_speech
364
+
365
+ def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
366
+ prompt_text=torch.zeros(1, 0, dtype=torch.int32),
367
+ llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
368
+ flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
369
+ prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
370
+ # this_uuid is used to track variables related to this inference thread
371
+ this_uuid = str(uuid.uuid1())
372
+ with self.lock:
373
+ self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
374
+ self.hift_cache_dict[this_uuid] = None
375
+ p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
376
+ p.start()
377
+ if stream is True:
378
+ token_offset = 0
379
+ while True:
380
+ time.sleep(0.1)
381
+ if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= self.token_hop_len + self.flow.pre_lookahead_len:
382
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + self.token_hop_len + self.flow.pre_lookahead_len]) \
383
+ .unsqueeze(dim=0)
384
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
385
+ prompt_token=flow_prompt_speech_token,
386
+ prompt_feat=prompt_speech_feat,
387
+ embedding=flow_embedding,
388
+ uuid=this_uuid,
389
+ token_offset=token_offset,
390
+ finalize=False)
391
+ token_offset += self.token_hop_len
392
+ yield {'tts_speech': this_tts_speech.cpu()}
393
+ if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < self.token_hop_len + self.flow.pre_lookahead_len:
394
+ break
395
+ p.join()
396
+ # deal with remaining tokens; make sure the remaining token length equals token_hop_len when cache_speech is not None
397
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
398
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
399
+ prompt_token=flow_prompt_speech_token,
400
+ prompt_feat=prompt_speech_feat,
401
+ embedding=flow_embedding,
402
+ uuid=this_uuid,
403
+ token_offset=token_offset,
404
+ finalize=True)
405
+ yield {'tts_speech': this_tts_speech.cpu()}
406
+ else:
407
+ # deal with all tokens
408
+ p.join()
409
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
410
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
411
+ prompt_token=flow_prompt_speech_token,
412
+ prompt_feat=prompt_speech_feat,
413
+ embedding=flow_embedding,
414
+ uuid=this_uuid,
415
+ token_offset=0,
416
+ finalize=True,
417
+ speed=speed)
418
+ yield {'tts_speech': this_tts_speech.cpu()}
419
+ with self.lock:
420
+ self.tts_speech_token_dict.pop(this_uuid)
421
+ self.llm_end_dict.pop(this_uuid)
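The streaming paths above cross-fade consecutive chunks with a Hamming window (`self.mel_window` / `self.speech_window`) via `fade_in_out` from `cosyvoice.utils.common`. A standalone sketch of the idea, re-implemented here for illustration rather than taken from that module:

```python
# Illustrative cross-fade of two overlapping 1-D chunks with a Hamming window.
import numpy as np

overlap = 256
window = np.hamming(2 * overlap)                 # first half fades in, second half fades out
fade_in, fade_out = window[:overlap], window[overlap:]

prev_tail = np.ones(overlap)                     # tail of the previous chunk (toy data)
curr_head = np.zeros(overlap)                    # head of the current chunk (toy data)
blended = curr_head * fade_in + prev_tail * fade_out
print(blended[:4], blended[-4:])                 # ramps from the old chunk into the new one
```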
cosyvoice/dataset/__init__.py ADDED
File without changes
cosyvoice/dataset/dataset.py ADDED
@@ -0,0 +1,164 @@
1
+ # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
2
+ # 2024 Alibaba Inc (authors: Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import random
17
+ import json
18
+ import math
19
+ from functools import partial
20
+
21
+ import torch
22
+ import torch.distributed as dist
23
+ from torch.utils.data import IterableDataset
24
+ from cosyvoice.utils.file_utils import read_lists, read_json_lists
25
+
26
+
27
+ class Processor(IterableDataset):
28
+
29
+ def __init__(self, source, f, *args, **kw):
30
+ assert callable(f)
31
+ self.source = source
32
+ self.f = f
33
+ self.args = args
34
+ self.kw = kw
35
+
36
+ def set_epoch(self, epoch):
37
+ self.source.set_epoch(epoch)
38
+
39
+ def __iter__(self):
40
+ """ Return an iterator over the source dataset processed by the
41
+ given processor.
42
+ """
43
+ assert self.source is not None
44
+ assert callable(self.f)
45
+ return self.f(iter(self.source), *self.args, **self.kw)
46
+
47
+ def apply(self, f):
48
+ assert callable(f)
49
+ return Processor(self, f, *self.args, **self.kw)
50
+
51
+
52
+ class DistributedSampler:
53
+
54
+ def __init__(self, shuffle=True, partition=True):
55
+ self.epoch = -1
56
+ self.update()
57
+ self.shuffle = shuffle
58
+ self.partition = partition
59
+
60
+ def update(self):
61
+ assert dist.is_available()
62
+ if dist.is_initialized():
63
+ self.rank = dist.get_rank()
64
+ self.world_size = dist.get_world_size()
65
+ else:
66
+ self.rank = 0
67
+ self.world_size = 1
68
+ worker_info = torch.utils.data.get_worker_info()
69
+ if worker_info is None:
70
+ self.worker_id = 0
71
+ self.num_workers = 1
72
+ else:
73
+ self.worker_id = worker_info.id
74
+ self.num_workers = worker_info.num_workers
75
+ return dict(rank=self.rank,
76
+ world_size=self.world_size,
77
+ worker_id=self.worker_id,
78
+ num_workers=self.num_workers)
79
+
80
+ def set_epoch(self, epoch):
81
+ self.epoch = epoch
82
+
83
+ def sample(self, data):
84
+ """ Sample data according to rank/world_size/num_workers
85
+
86
+ Args:
87
+ data(List): input data list
88
+
89
+ Returns:
90
+ List: data list after sample
91
+ """
92
+ data = list(range(len(data)))
93
+ # force datalist even
94
+ if self.partition:
95
+ if self.shuffle:
96
+ random.Random(self.epoch).shuffle(data)
97
+ if len(data) < self.world_size:
98
+ data = data * math.ceil(self.world_size / len(data))
99
+ data = data[:self.world_size]
100
+ data = data[self.rank::self.world_size]
101
+ if len(data) < self.num_workers:
102
+ data = data * math.ceil(self.num_workers / len(data))
103
+ data = data[:self.num_workers]
104
+ data = data[self.worker_id::self.num_workers]
105
+ return data
106
+
107
+
108
+ class DataList(IterableDataset):
109
+
110
+ def __init__(self, lists, shuffle=True, partition=True):
111
+ self.lists = lists
112
+ self.sampler = DistributedSampler(shuffle, partition)
113
+
114
+ def set_epoch(self, epoch):
115
+ self.sampler.set_epoch(epoch)
116
+
117
+ def __iter__(self):
118
+ sampler_info = self.sampler.update()
119
+ indexes = self.sampler.sample(self.lists)
120
+ for index in indexes:
121
+ data = dict(src=self.lists[index])
122
+ data.update(sampler_info)
123
+ yield data
124
+
125
+
126
+ def Dataset(data_list_file,
127
+ data_pipeline,
128
+ mode='train',
129
+ gan=False,
130
+ shuffle=True,
131
+ partition=True,
132
+ tts_file='',
133
+ prompt_utt2data=''):
134
+ """ Construct dataset from arguments
135
+
136
+ We have two shuffle stages in the Dataset. The first is a global
137
+ shuffle at the shard (tar/raw file) level. The second is a global shuffle
138
+ at the training-sample level.
139
+
140
+ Args:
141
+ data_type(str): raw/shard
142
+ tokenizer (BaseTokenizer): tokenizer to tokenize
143
+ partition(bool): whether to do data partition in terms of rank
144
+ """
145
+ assert mode in ['train', 'inference']
146
+ lists = read_lists(data_list_file)
147
+ if mode == 'inference':
148
+ with open(tts_file) as f:
149
+ tts_data = json.load(f)
150
+ utt2lists = read_json_lists(prompt_utt2data)
151
+ # filter unnecessary file in inference mode
152
+ lists = list({utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists})
153
+ dataset = DataList(lists,
154
+ shuffle=shuffle,
155
+ partition=partition)
156
+ if mode == 'inference':
157
+ # map partial arg to parquet_opener func in inference mode
158
+ data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data)
159
+ if gan is True:
160
+ # map partial arg to padding func in gan mode
161
+ data_pipeline[-1] = partial(data_pipeline[-1], gan=gan)
162
+ for func in data_pipeline:
163
+ dataset = Processor(dataset, func, mode=mode)
164
+ return dataset
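
A minimal usage sketch for the `Dataset` factory above; the file name, the `data_pipeline` list, and `num_epochs` are hypothetical placeholders, not part of this commit.

```python
# Hypothetical usage sketch (assumes the cosyvoice package is importable as added in this commit).
from torch.utils.data import DataLoader
from cosyvoice.dataset.dataset import Dataset

# `data_pipeline` is assumed to be a list of the processor generators defined in
# cosyvoice/dataset/processor.py (parquet_opener, tokenize, filter, resample, ..., padding).
train_dataset = Dataset('data/train.data.list', data_pipeline=data_pipeline, mode='train',
                        gan=False, shuffle=True, partition=True)
train_loader = DataLoader(train_dataset, batch_size=None, num_workers=2)  # batching happens inside the pipeline

for epoch in range(num_epochs):
    train_dataset.set_epoch(epoch)  # re-seeds the shard-level shuffle in DistributedSampler
    for batch in train_loader:
        ...  # dict of padded tensors: 'speech_feat', 'speech_token', 'text_token', 'embedding', ...
```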
cosyvoice/dataset/processor.py ADDED
@@ -0,0 +1,431 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import logging
15
+ import random
16
+
17
+ import pyarrow.parquet as pq
18
+ from io import BytesIO
19
+ import torch
20
+ import torchaudio
21
+ from torch.nn.utils.rnn import pad_sequence
22
+ import torch.nn.functional as F
23
+
24
+ torchaudio.set_audio_backend('soundfile')
25
+
26
+ AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
27
+
28
+
29
+ def parquet_opener(data, mode='train', tts_data={}):
30
+ """ Give url or local file, return file descriptor
31
+ Inplace operation.
32
+
33
+ Args:
34
+ data(Iterable[str]): url or local file list
35
+
36
+ Returns:
37
+ Iterable[{src, stream}]
38
+ """
39
+ for sample in data:
40
+ assert 'src' in sample
41
+ url = sample['src']
42
+ try:
43
+ for df in pq.ParquetFile(url).iter_batches(batch_size=64):
44
+ df = df.to_pandas()
45
+ for i in range(len(df)):
46
+ if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
47
+ continue
48
+ sample.update(dict(df.loc[i]))
49
+ if mode == 'train':
50
+ # NOTE do not return sample directly, must initialize a new dict
51
+ yield {**sample}
52
+ else:
53
+ for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
54
+ yield {**sample, 'tts_index': index, 'tts_text': text}
55
+ except Exception as ex:
56
+ logging.warning('Failed to open {}, ex info {}'.format(url, ex))
57
+
58
+
59
+ def filter(data,
60
+ max_length=10240,
61
+ min_length=10,
62
+ token_max_length=200,
63
+ token_min_length=1,
64
+ min_output_input_ratio=0.0005,
65
+ max_output_input_ratio=1,
66
+ mode='train'):
67
+ """ Filter sample according to feature and label length
68
+ Inplace operation.
69
+
70
+ Args:
71
+ data: Iterable[{key, wav, label, sample_rate}]
72
+ max_length: drop utterance which is greater than max_length(10ms)
73
+ min_length: drop utterance which is less than min_length(10ms)
74
+ token_max_length: drop utterance which is greater than
75
+ token_max_length, especially when use char unit for
76
+ english modeling
77
+ token_min_length: drop utterance which is
78
+ less than token_max_length
79
+ min_output_input_ratio: minimal ratio of
80
+ token_length / feats_length(10ms)
81
+ max_output_input_ratio: maximum ratio of
82
+ token_length / feats_length(10ms)
83
+
84
+ Returns:
85
+ Iterable[{key, wav, label, sample_rate}]
86
+ """
87
+ for sample in data:
88
+ sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
89
+ sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
90
+ del sample['audio_data']
91
+ # sample['speech'] is a torch.Tensor; we count 100 frames per second (10 ms each)
92
+ num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
93
+ if num_frames < min_length:
94
+ continue
95
+ if num_frames > max_length:
96
+ continue
97
+ if len(sample['text_token']) < token_min_length:
98
+ continue
99
+ if len(sample['text_token']) > token_max_length:
100
+ continue
101
+ if len(sample['speech_token']) == 0:
102
+ continue
103
+ if num_frames != 0:
104
+ if len(sample['text_token']) / num_frames < min_output_input_ratio:
105
+ continue
106
+ if len(sample['text_token']) / num_frames > max_output_input_ratio:
107
+ continue
108
+ yield sample
109
+
110
+
111
+ def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
112
+ """ Resample data.
113
+ Inplace operation.
114
+
115
+ Args:
116
+ data: Iterable[{key, wav, label, sample_rate}]
117
+ resample_rate: target resample rate
118
+
119
+ Returns:
120
+ Iterable[{key, wav, label, sample_rate}]
121
+ """
122
+ for sample in data:
123
+ assert 'sample_rate' in sample
124
+ assert 'speech' in sample
125
+ sample_rate = sample['sample_rate']
126
+ waveform = sample['speech']
127
+ if sample_rate != resample_rate:
128
+ if sample_rate < min_sample_rate:
129
+ continue
130
+ sample['sample_rate'] = resample_rate
131
+ sample['speech'] = torchaudio.transforms.Resample(
132
+ orig_freq=sample_rate, new_freq=resample_rate)(waveform)
133
+ max_val = sample['speech'].abs().max()
134
+ if max_val > 1:
135
+ sample['speech'] /= max_val
136
+ yield sample
137
+
138
+
139
+ def truncate(data, truncate_length=24576, mode='train'):
140
+ """ Truncate data.
141
+
142
+ Args:
143
+ data: Iterable[{key, wav, label, sample_rate}]
144
+ truncate_length: truncate length
145
+
146
+ Returns:
147
+ Iterable[{key, wav, label, sample_rate}]
148
+ """
149
+ for sample in data:
150
+ waveform = sample['speech']
151
+ if waveform.shape[1] > truncate_length:
152
+ start = random.randint(0, waveform.shape[1] - truncate_length)
153
+ waveform = waveform[:, start: start + truncate_length]
154
+ else:
155
+ waveform = torch.concat([waveform, torch.zeros(1, truncate_length - waveform.shape[1])], dim=1)
156
+ sample['speech'] = waveform
157
+ yield sample
158
+
159
+
160
+ def compute_fbank(data,
161
+ feat_extractor,
162
+ mode='train'):
163
+ """ Extract fbank
164
+
165
+ Args:
166
+ data: Iterable[{key, wav, label, sample_rate}]
167
+
168
+ Returns:
169
+ Iterable[{key, feat, label}]
170
+ """
171
+ for sample in data:
172
+ assert 'sample_rate' in sample
173
+ assert 'speech' in sample
174
+ assert 'utt' in sample
175
+ assert 'text_token' in sample
176
+ waveform = sample['speech']
177
+ mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
178
+ sample['speech_feat'] = mat
179
+ yield sample
180
+
181
+
182
+ def compute_f0(data, pitch_extractor, mode='train'):
183
+ """ Extract f0
184
+
185
+ Args:
186
+ data: Iterable[{key, wav, label, sample_rate}]
187
+
188
+ Returns:
189
+ Iterable[{key, feat, label}]
190
+ """
191
+ for sample in data:
192
+ assert 'sample_rate' in sample
193
+ assert 'speech' in sample
194
+ assert 'utt' in sample
195
+ assert 'text_token' in sample
196
+ waveform = sample['speech']
197
+ mat = pitch_extractor(waveform).transpose(1, 2)
198
+ mat = F.interpolate(mat, size=sample['speech_feat'].shape[0], mode='linear')
199
+ sample['pitch_feat'] = mat[0, 0]
200
+ yield sample
201
+
202
+
203
+ def parse_embedding(data, normalize, mode='train'):
204
+ """ Parse utt_embedding/spk_embedding
205
+
206
+ Args:
207
+ data: Iterable[{key, wav, label, sample_rate}]
208
+
209
+ Returns:
210
+ Iterable[{key, feat, label}]
211
+ """
212
+ for sample in data:
213
+ sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
214
+ sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
215
+ if normalize:
216
+ sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
217
+ sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
218
+ yield sample
219
+
220
+
221
+ def tokenize(data, get_tokenizer, allowed_special, mode='train'):
222
+ """ Decode text to chars or BPE
223
+ Inplace operation
224
+
225
+ Args:
226
+ data: Iterable[{key, wav, txt, sample_rate}]
227
+
228
+ Returns:
229
+ Iterable[{key, wav, txt, tokens, label, sample_rate}]
230
+ """
231
+ tokenizer = get_tokenizer()
232
+ for sample in data:
233
+ assert 'text' in sample
234
+ sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
235
+ if mode == 'inference':
236
+ sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special)
237
+ yield sample
238
+
239
+
240
+ def shuffle(data, shuffle_size=10000, mode='train'):
241
+ """ Local shuffle the data
242
+
243
+ Args:
244
+ data: Iterable[{key, feat, label}]
245
+ shuffle_size: buffer size for shuffle
246
+
247
+ Returns:
248
+ Iterable[{key, feat, label}]
249
+ """
250
+ buf = []
251
+ for sample in data:
252
+ buf.append(sample)
253
+ if len(buf) >= shuffle_size:
254
+ random.shuffle(buf)
255
+ for x in buf:
256
+ yield x
257
+ buf = []
258
+ # The sample left over
259
+ random.shuffle(buf)
260
+ for x in buf:
261
+ yield x
262
+
263
+
264
+ def sort(data, sort_size=500, mode='train'):
265
+ """ Sort the data by feature length.
266
+ Sort is used after shuffle and before batch, so we can group
267
+ utts with similar lengths into a batch, and `sort_size` should
268
+ be less than `shuffle_size`
269
+
270
+ Args:
271
+ data: Iterable[{key, feat, label}]
272
+ sort_size: buffer size for sort
273
+
274
+ Returns:
275
+ Iterable[{key, feat, label}]
276
+ """
277
+
278
+ buf = []
279
+ for sample in data:
280
+ buf.append(sample)
281
+ if len(buf) >= sort_size:
282
+ buf.sort(key=lambda x: x['speech_feat'].size(0))
283
+ for x in buf:
284
+ yield x
285
+ buf = []
286
+ # The sample left over
287
+ buf.sort(key=lambda x: x['speech_feat'].size(0))
288
+ for x in buf:
289
+ yield x
290
+
291
+
292
+ def static_batch(data, batch_size=16):
293
+ """ Static batch the data by `batch_size`
294
+
295
+ Args:
296
+ data: Iterable[{key, feat, label}]
297
+ batch_size: batch size
298
+
299
+ Returns:
300
+ Iterable[List[{key, feat, label}]]
301
+ """
302
+ buf = []
303
+ for sample in data:
304
+ buf.append(sample)
305
+ if len(buf) >= batch_size:
306
+ yield buf
307
+ buf = []
308
+ if len(buf) > 0:
309
+ yield buf
310
+
311
+
312
+ def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
313
+ """ Dynamic batch the data until the total frames in batch
314
+ reach `max_frames_in_batch`
315
+
316
+ Args:
317
+ data: Iterable[{key, feat, label}]
318
+ max_frames_in_batch: max_frames in one batch
319
+
320
+ Returns:
321
+ Iterable[List[{key, feat, label}]]
322
+ """
323
+ buf = []
324
+ longest_frames = 0
325
+ for sample in data:
326
+ assert 'speech_feat' in sample
327
+ assert isinstance(sample['speech_feat'], torch.Tensor)
328
+ new_sample_frames = sample['speech_feat'].size(0)
329
+ longest_frames = max(longest_frames, new_sample_frames)
330
+ frames_after_padding = longest_frames * (len(buf) + 1)
331
+ if frames_after_padding > max_frames_in_batch:
332
+ yield buf
333
+ buf = [sample]
334
+ longest_frames = new_sample_frames
335
+ else:
336
+ buf.append(sample)
337
+ if len(buf) > 0:
338
+ yield buf
339
+
340
+
341
+ def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
342
+ """ Wrapper for static/dynamic batch
343
+ """
344
+ if mode == 'inference':
345
+ return static_batch(data, 1)
346
+ else:
347
+ if batch_type == 'static':
348
+ return static_batch(data, batch_size)
349
+ elif batch_type == 'dynamic':
350
+ return dynamic_batch(data, max_frames_in_batch)
351
+ else:
352
+ logging.fatal('Unsupported batch type {}'.format(batch_type))
353
+
354
+
355
+ def padding(data, use_spk_embedding, mode='train', gan=False):
356
+ """ Padding the data into training data
357
+
358
+ Args:
359
+ data: Iterable[List[{key, feat, label}]]
360
+
361
+ Returns:
362
+ Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
363
+ """
364
+ for sample in data:
365
+ assert isinstance(sample, list)
366
+ speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
367
+ dtype=torch.int32)
368
+ order = torch.argsort(speech_feat_len, descending=True)
369
+
370
+ utts = [sample[i]['utt'] for i in order]
371
+ speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
372
+ speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
373
+ speech = pad_sequence(speech, batch_first=True, padding_value=0)
374
+ speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
375
+ speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
376
+ speech_token = pad_sequence(speech_token,
377
+ batch_first=True,
378
+ padding_value=0)
379
+ speech_feat = [sample[i]['speech_feat'] for i in order]
380
+ speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
381
+ speech_feat = pad_sequence(speech_feat,
382
+ batch_first=True,
383
+ padding_value=0)
384
+ text = [sample[i]['text'] for i in order]
385
+ text_token = [torch.tensor(sample[i]['text_token']) for i in order]
386
+ text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
387
+ text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
388
+ utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
389
+ spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
390
+ batch = {
391
+ "utts": utts,
392
+ "speech": speech,
393
+ "speech_len": speech_len,
394
+ "speech_token": speech_token,
395
+ "speech_token_len": speech_token_len,
396
+ "speech_feat": speech_feat,
397
+ "speech_feat_len": speech_feat_len,
398
+ "text": text,
399
+ "text_token": text_token,
400
+ "text_token_len": text_token_len,
401
+ "utt_embedding": utt_embedding,
402
+ "spk_embedding": spk_embedding,
403
+ }
404
+ if gan is True:
405
+ # in gan train, we need pitch_feat
406
+ pitch_feat = [sample[i]['pitch_feat'] for i in order]
407
+ pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
408
+ pitch_feat = pad_sequence(pitch_feat,
409
+ batch_first=True,
410
+ padding_value=0)
411
+ batch["pitch_feat"] = pitch_feat
412
+ batch["pitch_feat_len"] = pitch_feat_len
413
+ else:
414
+ # only gan train needs speech, delete it to save memory
415
+ del batch["speech"]
416
+ del batch["speech_len"]
417
+ if mode == 'inference':
418
+ tts_text = [sample[i]['tts_text'] for i in order]
419
+ tts_index = [sample[i]['tts_index'] for i in order]
420
+ tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
421
+ tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
422
+ tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
423
+ batch.update({'tts_text': tts_text,
424
+ 'tts_index': tts_index,
425
+ 'tts_text_token': tts_text_token,
426
+ 'tts_text_token_len': tts_text_token_len})
427
+ if use_spk_embedding is True:
428
+ batch["embedding"] = batch["spk_embedding"]
429
+ else:
430
+ batch["embedding"] = batch["utt_embedding"]
431
+ yield batch
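
A toy, self-contained check of the `dynamic_batch` rule above (synthetic samples; the lengths and threshold are made up for illustration): a new batch starts once the padded cost `longest_frames * (len(buf) + 1)` would exceed `max_frames_in_batch`.

```python
import torch
from cosyvoice.dataset.processor import dynamic_batch

# Fake samples whose 'speech_feat' lengths are 100, 120, 300 and 50 frames.
samples = [{'speech_feat': torch.zeros(n, 80)} for n in (100, 120, 300, 50)]

# With max_frames_in_batch=700: the first two fit (padded cost 2 * 120 = 240), adding the
# third would cost 3 * 300 = 900 > 700, so it opens a new batch; the fourth then fits there
# (padded cost 2 * 300 = 600).
for b in dynamic_batch(iter(samples), max_frames_in_batch=700):
    print([s['speech_feat'].size(0) for s in b])
# prints [100, 120] then [300, 50]
```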
cosyvoice/flow/decoder.py ADDED
@@ -0,0 +1,299 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from einops import pack, rearrange, repeat
18
+ from cosyvoice.utils.common import mask_to_bias
19
+ from cosyvoice.utils.mask import add_optional_chunk_mask
20
+ from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
21
+ from matcha.models.components.transformer import BasicTransformerBlock
22
+
23
+
24
+ class Transpose(torch.nn.Module):
25
+ def __init__(self, dim0: int, dim1: int):
26
+ super().__init__()
27
+ self.dim0 = dim0
28
+ self.dim1 = dim1
29
+
30
+ def forward(self, x: torch.Tensor):
31
+ x = torch.transpose(x, self.dim0, self.dim1)
32
+ return x
33
+
34
+
35
+ class CausalBlock1D(Block1D):
36
+ def __init__(self, dim: int, dim_out: int):
37
+ super(CausalBlock1D, self).__init__(dim, dim_out)
38
+ self.block = torch.nn.Sequential(
39
+ CausalConv1d(dim, dim_out, 3),
40
+ Transpose(1, 2),
41
+ nn.LayerNorm(dim_out),
42
+ Transpose(1, 2),
43
+ nn.Mish(),
44
+ )
45
+
46
+ def forward(self, x: torch.Tensor, mask: torch.Tensor):
47
+ output = self.block(x * mask)
48
+ return output * mask
49
+
50
+
51
+ class CausalResnetBlock1D(ResnetBlock1D):
52
+ def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int=8):
53
+ super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
54
+ self.block1 = CausalBlock1D(dim, dim_out)
55
+ self.block2 = CausalBlock1D(dim_out, dim_out)
56
+
57
+
58
+ class CausalConv1d(torch.nn.Conv1d):
59
+ def __init__(
60
+ self,
61
+ in_channels: int,
62
+ out_channels: int,
63
+ kernel_size: int,
64
+ stride: int = 1,
65
+ dilation: int = 1,
66
+ groups: int = 1,
67
+ bias: bool = True,
68
+ padding_mode: str = 'zeros',
69
+ device=None,
70
+ dtype=None
71
+ ) -> None:
72
+ super(CausalConv1d, self).__init__(in_channels, out_channels,
73
+ kernel_size, stride,
74
+ padding=0, dilation=dilation,
75
+ groups=groups, bias=bias,
76
+ padding_mode=padding_mode,
77
+ device=device, dtype=dtype
78
+ )
79
+ assert stride == 1
80
+ self.causal_padding = (kernel_size - 1, 0)
81
+
82
+ def forward(self, x: torch.Tensor):
83
+ x = F.pad(x, self.causal_padding)
84
+ x = super(CausalConv1d, self).forward(x)
85
+ return x
86
+
87
+
88
+ class ConditionalDecoder(nn.Module):
89
+ def __init__(
90
+ self,
91
+ in_channels,
92
+ out_channels,
93
+ causal=False,
94
+ channels=(256, 256),
95
+ dropout=0.05,
96
+ attention_head_dim=64,
97
+ n_blocks=1,
98
+ num_mid_blocks=2,
99
+ num_heads=4,
100
+ act_fn="snake",
101
+ ):
102
+ """
103
+ This decoder requires an input with the same shape as the target. So, if your text content
104
+ is shorter or longer than the outputs, please re-sample it before feeding it to the decoder.
105
+ """
106
+ super().__init__()
107
+ channels = tuple(channels)
108
+ self.in_channels = in_channels
109
+ self.out_channels = out_channels
110
+ self.causal = causal
111
+ self.time_embeddings = SinusoidalPosEmb(in_channels)
112
+ time_embed_dim = channels[0] * 4
113
+ self.time_mlp = TimestepEmbedding(
114
+ in_channels=in_channels,
115
+ time_embed_dim=time_embed_dim,
116
+ act_fn="silu",
117
+ )
118
+ self.down_blocks = nn.ModuleList([])
119
+ self.mid_blocks = nn.ModuleList([])
120
+ self.up_blocks = nn.ModuleList([])
121
+
122
+ output_channel = in_channels
123
+ for i in range(len(channels)): # pylint: disable=consider-using-enumerate
124
+ input_channel = output_channel
125
+ output_channel = channels[i]
126
+ is_last = i == len(channels) - 1
127
+ resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
128
+ transformer_blocks = nn.ModuleList(
129
+ [
130
+ BasicTransformerBlock(
131
+ dim=output_channel,
132
+ num_attention_heads=num_heads,
133
+ attention_head_dim=attention_head_dim,
134
+ dropout=dropout,
135
+ activation_fn=act_fn,
136
+ )
137
+ for _ in range(n_blocks)
138
+ ]
139
+ )
140
+ downsample = (
141
+ Downsample1D(output_channel) if not is_last else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
142
+ )
143
+ self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
144
+
145
+ for _ in range(num_mid_blocks):
146
+ input_channel = channels[-1]
147
+ out_channels = channels[-1]
148
+ resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
149
+
150
+ transformer_blocks = nn.ModuleList(
151
+ [
152
+ BasicTransformerBlock(
153
+ dim=output_channel,
154
+ num_attention_heads=num_heads,
155
+ attention_head_dim=attention_head_dim,
156
+ dropout=dropout,
157
+ activation_fn=act_fn,
158
+ )
159
+ for _ in range(n_blocks)
160
+ ]
161
+ )
162
+
163
+ self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
164
+
165
+ channels = channels[::-1] + (channels[0],)
166
+ for i in range(len(channels) - 1):
167
+ input_channel = channels[i] * 2
168
+ output_channel = channels[i + 1]
169
+ is_last = i == len(channels) - 2
170
+ resnet = CausalResnetBlock1D(
171
+ dim=input_channel,
172
+ dim_out=output_channel,
173
+ time_emb_dim=time_embed_dim,
174
+ ) if self.causal else ResnetBlock1D(
175
+ dim=input_channel,
176
+ dim_out=output_channel,
177
+ time_emb_dim=time_embed_dim,
178
+ )
179
+ transformer_blocks = nn.ModuleList(
180
+ [
181
+ BasicTransformerBlock(
182
+ dim=output_channel,
183
+ num_attention_heads=num_heads,
184
+ attention_head_dim=attention_head_dim,
185
+ dropout=dropout,
186
+ activation_fn=act_fn,
187
+ )
188
+ for _ in range(n_blocks)
189
+ ]
190
+ )
191
+ upsample = (
192
+ Upsample1D(output_channel, use_conv_transpose=True)
193
+ if not is_last
194
+ else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
195
+ )
196
+ self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
197
+ self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
198
+ self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
199
+ self.initialize_weights()
200
+
201
+ def initialize_weights(self):
202
+ for m in self.modules():
203
+ if isinstance(m, nn.Conv1d):
204
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
205
+ if m.bias is not None:
206
+ nn.init.constant_(m.bias, 0)
207
+ elif isinstance(m, nn.GroupNorm):
208
+ nn.init.constant_(m.weight, 1)
209
+ nn.init.constant_(m.bias, 0)
210
+ elif isinstance(m, nn.Linear):
211
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
212
+ if m.bias is not None:
213
+ nn.init.constant_(m.bias, 0)
214
+
215
+ def forward(self, x, mask, mu, t, spks=None, cond=None):
216
+ """Forward pass of the UNet1DConditional model.
217
+
218
+ Args:
219
+ x (torch.Tensor): shape (batch_size, in_channels, time)
220
+ mask (_type_): shape (batch_size, 1, time)
221
+ t (_type_): shape (batch_size)
222
+ spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
223
+ cond (_type_, optional): placeholder for future use. Defaults to None.
224
+
225
+ Raises:
226
+ ValueError: _description_
227
+ ValueError: _description_
228
+
229
+ Returns:
230
+ _type_: _description_
231
+ """
232
+
233
+ t = self.time_embeddings(t).to(t.dtype)
234
+ t = self.time_mlp(t)
235
+
236
+ x = pack([x, mu], "b * t")[0]
237
+
238
+ if spks is not None:
239
+ spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
240
+ x = pack([x, spks], "b * t")[0]
241
+ if cond is not None:
242
+ x = pack([x, cond], "b * t")[0]
243
+
244
+ hiddens = []
245
+ masks = [mask]
246
+ for resnet, transformer_blocks, downsample in self.down_blocks:
247
+ mask_down = masks[-1]
248
+ x = resnet(x, mask_down, t)
249
+ x = rearrange(x, "b c t -> b t c").contiguous()
250
+ # attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
251
+ attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
252
+ attn_mask = mask_to_bias(attn_mask==1, x.dtype)
253
+ for transformer_block in transformer_blocks:
254
+ x = transformer_block(
255
+ hidden_states=x,
256
+ attention_mask=attn_mask,
257
+ timestep=t,
258
+ )
259
+ x = rearrange(x, "b t c -> b c t").contiguous()
260
+ hiddens.append(x) # Save hidden states for skip connections
261
+ x = downsample(x * mask_down)
262
+ masks.append(mask_down[:, :, ::2])
263
+ masks = masks[:-1]
264
+ mask_mid = masks[-1]
265
+
266
+ for resnet, transformer_blocks in self.mid_blocks:
267
+ x = resnet(x, mask_mid, t)
268
+ x = rearrange(x, "b c t -> b t c").contiguous()
269
+ # attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
270
+ attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
271
+ attn_mask = mask_to_bias(attn_mask==1, x.dtype)
272
+ for transformer_block in transformer_blocks:
273
+ x = transformer_block(
274
+ hidden_states=x,
275
+ attention_mask=attn_mask,
276
+ timestep=t,
277
+ )
278
+ x = rearrange(x, "b t c -> b c t").contiguous()
279
+
280
+ for resnet, transformer_blocks, upsample in self.up_blocks:
281
+ mask_up = masks.pop()
282
+ skip = hiddens.pop()
283
+ x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
284
+ x = resnet(x, mask_up, t)
285
+ x = rearrange(x, "b c t -> b t c").contiguous()
286
+ # attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
287
+ attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
288
+ attn_mask = mask_to_bias(attn_mask==1, x.dtype)
289
+ for transformer_block in transformer_blocks:
290
+ x = transformer_block(
291
+ hidden_states=x,
292
+ attention_mask=attn_mask,
293
+ timestep=t,
294
+ )
295
+ x = rearrange(x, "b t c -> b c t").contiguous()
296
+ x = upsample(x * mask_up)
297
+ x = self.final_block(x, mask_up)
298
+ output = self.final_proj(x * mask_up)
299
+ return output * mask
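
A standalone sanity check (not part of this commit) of the property `CausalBlock1D`/`CausalConv1d` rely on: padding `kernel_size - 1` zeros on the left only means output frame `t` never depends on inputs after `t`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
conv = torch.nn.Conv1d(4, 4, kernel_size=3, padding=0)

def causal_conv(x):
    # left-only padding of (kernel_size - 1) zeros, as in CausalConv1d above
    return conv(F.pad(x, (2, 0)))

x = torch.randn(1, 4, 10)
x_future = x.clone()
x_future[:, :, 7:] = 0.0                                 # perturb only frames t >= 7

y, y_future = causal_conv(x), causal_conv(x_future)
print(torch.allclose(y[:, :, :7], y_future[:, :, :7]))   # True: past outputs are unaffected
```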
cosyvoice/flow/flow.py ADDED
@@ -0,0 +1,232 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import logging
15
+ import random
16
+ from typing import Dict, Optional
17
+ import torch
18
+ import torch.nn as nn
19
+ from torch.nn import functional as F
20
+ from omegaconf import DictConfig
21
+ from cosyvoice.utils.mask import make_pad_mask
22
+
23
+
24
+ class MaskedDiffWithXvec(torch.nn.Module):
25
+ def __init__(self,
26
+ input_size: int = 512,
27
+ output_size: int = 80,
28
+ spk_embed_dim: int = 192,
29
+ output_type: str = "mel",
30
+ vocab_size: int = 4096,
31
+ input_frame_rate: int = 50,
32
+ only_mask_loss: bool = True,
33
+ encoder: torch.nn.Module = None,
34
+ length_regulator: torch.nn.Module = None,
35
+ decoder: torch.nn.Module = None,
36
+ decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
37
+ 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
38
+ 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
39
+ 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
40
+ 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
41
+ mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
42
+ 'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
43
+ super().__init__()
44
+ self.input_size = input_size
45
+ self.output_size = output_size
46
+ self.decoder_conf = decoder_conf
47
+ self.mel_feat_conf = mel_feat_conf
48
+ self.vocab_size = vocab_size
49
+ self.output_type = output_type
50
+ self.input_frame_rate = input_frame_rate
51
+ logging.info(f"input frame rate={self.input_frame_rate}")
52
+ self.input_embedding = nn.Embedding(vocab_size, input_size)
53
+ self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
54
+ self.encoder = encoder
55
+ self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
56
+ self.decoder = decoder
57
+ self.length_regulator = length_regulator
58
+ self.only_mask_loss = only_mask_loss
59
+
60
+ def forward(
61
+ self,
62
+ batch: dict,
63
+ device: torch.device,
64
+ ) -> Dict[str, Optional[torch.Tensor]]:
65
+ token = batch['speech_token'].to(device)
66
+ token_len = batch['speech_token_len'].to(device)
67
+ feat = batch['speech_feat'].to(device)
68
+ feat_len = batch['speech_feat_len'].to(device)
69
+ embedding = batch['embedding'].to(device)
70
+
71
+ # xvec projection
72
+ embedding = F.normalize(embedding, dim=1)
73
+ embedding = self.spk_embed_affine_layer(embedding)
74
+
75
+ # concat text and prompt_text
76
+ mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
77
+ token = self.input_embedding(torch.clamp(token, min=0)) * mask
78
+
79
+ # text encode
80
+ h, h_lengths = self.encoder(token, token_len)
81
+ h = self.encoder_proj(h)
82
+ h, h_lengths = self.length_regulator(h, feat_len)
83
+
84
+ # get conditions
85
+ conds = torch.zeros(feat.shape, device=token.device)
86
+ for i, j in enumerate(feat_len):
87
+ if random.random() < 0.5:
88
+ continue
89
+ index = random.randint(0, int(0.3 * j))
90
+ conds[i, :index] = feat[i, :index]
91
+ conds = conds.transpose(1, 2)
92
+
93
+ mask = (~make_pad_mask(feat_len)).to(h)
94
+ feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
95
+ loss, _ = self.decoder.compute_loss(
96
+ feat.transpose(1, 2).contiguous(),
97
+ mask.unsqueeze(1),
98
+ h.transpose(1, 2).contiguous(),
99
+ embedding,
100
+ cond=conds
101
+ )
102
+ return {'loss': loss}
103
+
104
+ @torch.inference_mode()
105
+ def inference(self,
106
+ token,
107
+ token_len,
108
+ prompt_token,
109
+ prompt_token_len,
110
+ prompt_feat,
111
+ prompt_feat_len,
112
+ embedding,
113
+ flow_cache):
114
+ assert token.shape[0] == 1
115
+ # xvec projection
116
+ embedding = F.normalize(embedding, dim=1)
117
+ embedding = self.spk_embed_affine_layer(embedding)
118
+
119
+ # concat text and prompt_text
120
+ token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
121
+ token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
122
+ mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
123
+ token = self.input_embedding(torch.clamp(token, min=0)) * mask
124
+
125
+ # text encode
126
+ h, h_lengths = self.encoder(token, token_len)
127
+ h = self.encoder_proj(h)
128
+ mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
129
+ h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
130
+
131
+ # get conditions
132
+ conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
133
+ conds[:, :mel_len1] = prompt_feat
134
+ conds = conds.transpose(1, 2)
135
+
136
+ mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
137
+ feat, flow_cache = self.decoder(
138
+ mu=h.transpose(1, 2).contiguous(),
139
+ mask=mask.unsqueeze(1),
140
+ spks=embedding,
141
+ cond=conds,
142
+ n_timesteps=10,
143
+ prompt_len=mel_len1,
144
+ flow_cache=flow_cache
145
+ )
146
+ feat = feat[:, :, mel_len1:]
147
+ assert feat.shape[2] == mel_len2
148
+ return feat, flow_cache
149
+
150
+
151
+ class CausalMaskedDiffWithXvec(torch.nn.Module):
152
+ def __init__(self,
153
+ input_size: int = 512,
154
+ output_size: int = 80,
155
+ spk_embed_dim: int = 192,
156
+ output_type: str = "mel",
157
+ vocab_size: int = 4096,
158
+ input_frame_rate: int = 50,
159
+ only_mask_loss: bool = True,
160
+ token_mel_ratio: int = 2,
161
+ pre_lookahead_len: int = 3,
162
+ encoder: torch.nn.Module = None,
163
+ decoder: torch.nn.Module = None,
164
+ decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
165
+ 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
166
+ 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
167
+ 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
168
+ 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
169
+ mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
170
+ 'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
171
+ super().__init__()
172
+ self.input_size = input_size
173
+ self.output_size = output_size
174
+ self.decoder_conf = decoder_conf
175
+ self.mel_feat_conf = mel_feat_conf
176
+ self.vocab_size = vocab_size
177
+ self.output_type = output_type
178
+ self.input_frame_rate = input_frame_rate
179
+ logging.info(f"input frame rate={self.input_frame_rate}")
180
+ self.input_embedding = nn.Embedding(vocab_size, input_size)
181
+ self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
182
+ self.encoder = encoder
183
+ self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
184
+ self.decoder = decoder
185
+ self.only_mask_loss = only_mask_loss
186
+ self.token_mel_ratio = token_mel_ratio
187
+ self.pre_lookahead_len = pre_lookahead_len
188
+
189
+ @torch.inference_mode()
190
+ def inference(self,
191
+ token,
192
+ token_len,
193
+ prompt_token,
194
+ prompt_token_len,
195
+ prompt_feat,
196
+ prompt_feat_len,
197
+ embedding,
198
+ finalize):
199
+ assert token.shape[0] == 1
200
+ # xvec projection
201
+ embedding = F.normalize(embedding, dim=1)
202
+ embedding = self.spk_embed_affine_layer(embedding)
203
+
204
+ # concat text and prompt_text
205
+ token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
206
+ token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
207
+ mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
208
+ token = self.input_embedding(torch.clamp(token, min=0)) * mask
209
+
210
+ # text encode
211
+ h, h_lengths = self.encoder(token, token_len)
212
+ if finalize is False:
213
+ h = h[:, :-self.pre_lookahead_len * self.token_mel_ratio]
214
+ mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
215
+ h = self.encoder_proj(h)
216
+
217
+ # get conditions
218
+ conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
219
+ conds[:, :mel_len1] = prompt_feat
220
+ conds = conds.transpose(1, 2)
221
+
222
+ mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
223
+ feat, _ = self.decoder(
224
+ mu=h.transpose(1, 2).contiguous(),
225
+ mask=mask.unsqueeze(1),
226
+ spks=embedding,
227
+ cond=conds,
228
+ n_timesteps=10
229
+ )
230
+ feat = feat[:, :, mel_len1:]
231
+ assert feat.shape[2] == mel_len2
232
+ return feat, None
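
For reference, the mel-length arithmetic in `MaskedDiffWithXvec.inference` (`int(token_len2 / self.input_frame_rate * 22050 / 256)`) maps 50 Hz speech tokens onto 22.05 kHz / hop-256 mel frames (about 86.13 frames per second); a worked example with made-up lengths:

```python
input_frame_rate = 50                   # speech tokens per second
sampling_rate, hop_size = 22050, 256    # mel frame rate = 22050 / 256 ≈ 86.13 Hz

token_len2 = 100                        # 2 seconds of speech tokens
mel_len2 = int(token_len2 / input_frame_rate * sampling_rate / hop_size)
print(mel_len2)                         # 172 mel frames ≈ 2 s * 86.13 frames/s
```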
cosyvoice/flow/flow_matching.py ADDED
@@ -0,0 +1,235 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import onnxruntime
15
+ import torch
16
+ import torch.nn.functional as F
17
+ from matcha.models.components.flow_matching import BASECFM
18
+
19
+
20
+ class ConditionalCFM(BASECFM):
21
+ def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
22
+ super().__init__(
23
+ n_feats=in_channels,
24
+ cfm_params=cfm_params,
25
+ n_spks=n_spks,
26
+ spk_emb_dim=spk_emb_dim,
27
+ )
28
+ self.t_scheduler = cfm_params.t_scheduler
29
+ self.training_cfg_rate = cfm_params.training_cfg_rate
30
+ self.inference_cfg_rate = cfm_params.inference_cfg_rate
31
+ in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
32
+ # Just change the architecture of the estimator here
33
+ self.estimator = estimator
34
+
35
+ @torch.inference_mode()
36
+ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
37
+ """Forward diffusion
38
+
39
+ Args:
40
+ mu (torch.Tensor): output of encoder
41
+ shape: (batch_size, n_feats, mel_timesteps)
42
+ mask (torch.Tensor): output_mask
43
+ shape: (batch_size, 1, mel_timesteps)
44
+ n_timesteps (int): number of diffusion steps
45
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
46
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
47
+ shape: (batch_size, spk_emb_dim)
48
+ cond: Not used but kept for future purposes
49
+
50
+ Returns:
51
+ sample: generated mel-spectrogram
52
+ shape: (batch_size, n_feats, mel_timesteps)
53
+ """
54
+
55
+ z = torch.randn_like(mu) * temperature
56
+ cache_size = flow_cache.shape[2]
57
+ # fix prompt and overlap part mu and z
58
+ if cache_size != 0:
59
+ z[:, :, :cache_size] = flow_cache[:, :, :, 0]
60
+ mu[:, :, :cache_size] = flow_cache[:, :, :, 1]
61
+ z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
62
+ mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
63
+ flow_cache = torch.stack([z_cache, mu_cache], dim=-1)
64
+
65
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
66
+ if self.t_scheduler == 'cosine':
67
+ t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
68
+ return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache
69
+
70
+ def solve_euler(self, x, t_span, mu, mask, spks, cond):
71
+ """
72
+ Fixed euler solver for ODEs.
73
+ Args:
74
+ x (torch.Tensor): random noise
75
+ t_span (torch.Tensor): n_timesteps interpolated
76
+ shape: (n_timesteps + 1,)
77
+ mu (torch.Tensor): output of encoder
78
+ shape: (batch_size, n_feats, mel_timesteps)
79
+ mask (torch.Tensor): output_mask
80
+ shape: (batch_size, 1, mel_timesteps)
81
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
82
+ shape: (batch_size, spk_emb_dim)
83
+ cond: Not used but kept for future purposes
84
+ """
85
+ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
86
+ t = t.unsqueeze(dim=0)
87
+
88
+ # I am storing this because I can later plot it by putting a debugger here and saving it to a file
89
+ # Or in future might add like a return_all_steps flag
90
+ sol = []
91
+
92
+ if self.inference_cfg_rate > 0:
93
+ # Do not use concat; it may change the memory format and cause TensorRT inference to produce wrong results!
94
+ x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
95
+ mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
96
+ mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
97
+ t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
98
+ spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
99
+ cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
100
+ else:
101
+ x_in, mask_in, mu_in, t_in, spks_in, cond_in = x, mask, mu, t, spks, cond
102
+ for step in range(1, len(t_span)):
103
+ # Classifier-Free Guidance inference introduced in VoiceBox
104
+ if self.inference_cfg_rate > 0:
105
+ x_in[:] = x
106
+ mask_in[:] = mask
107
+ mu_in[0] = mu
108
+ t_in[:] = t.unsqueeze(0)
109
+ spks_in[0] = spks
110
+ cond_in[0] = cond
111
+ else:
112
+ x_in, mask_in, mu_in, t_in, spks_in, cond_in = x, mask, mu, t, spks, cond
113
+ dphi_dt = self.forward_estimator(
114
+ x_in, mask_in,
115
+ mu_in, t_in,
116
+ spks_in,
117
+ cond_in
118
+ )
119
+ if self.inference_cfg_rate > 0:
120
+ dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
121
+ dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
122
+ x = x + dt * dphi_dt
123
+ t = t + dt
124
+ sol.append(x)
125
+ if step < len(t_span) - 1:
126
+ dt = t_span[step + 1] - t
127
+
128
+ return sol[-1].float()
129
+
130
+ def forward_estimator(self, x, mask, mu, t, spks, cond):
131
+ if isinstance(self.estimator, torch.nn.Module):
132
+ return self.estimator.forward(x, mask, mu, t, spks, cond)
133
+ elif isinstance(self.estimator, onnxruntime.InferenceSession):
134
+ ort_inputs = {
135
+ 'x': x.cpu().numpy(),
136
+ 'mask': mask.cpu().numpy(),
137
+ 'mu': mu.cpu().numpy(),
138
+ 't': t.cpu().numpy(),
139
+ 'spks': spks.cpu().numpy(),
140
+ 'cond': cond.cpu().numpy()
141
+ }
142
+ output = self.estimator.run(None, ort_inputs)[0]
143
+ return torch.tensor(output, dtype=x.dtype, device=x.device)
144
+ else:
145
+ self.estimator.set_input_shape('x', (2, 80, x.size(2)))
146
+ self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
147
+ self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
148
+ self.estimator.set_input_shape('t', (2,))
149
+ self.estimator.set_input_shape('spks', (2, 80))
150
+ self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
151
+ # run trt engine
152
+ self.estimator.execute_v2([x.contiguous().data_ptr(),
153
+ mask.contiguous().data_ptr(),
154
+ mu.contiguous().data_ptr(),
155
+ t.contiguous().data_ptr(),
156
+ spks.contiguous().data_ptr(),
157
+ cond.contiguous().data_ptr(),
158
+ x.data_ptr()])
159
+ return x
160
+
161
+ def compute_loss(self, x1, mask, mu, spks=None, cond=None):
162
+ """Computes diffusion loss
163
+
164
+ Args:
165
+ x1 (torch.Tensor): Target
166
+ shape: (batch_size, n_feats, mel_timesteps)
167
+ mask (torch.Tensor): target mask
168
+ shape: (batch_size, 1, mel_timesteps)
169
+ mu (torch.Tensor): output of encoder
170
+ shape: (batch_size, n_feats, mel_timesteps)
171
+ spks (torch.Tensor, optional): speaker embedding. Defaults to None.
172
+ shape: (batch_size, spk_emb_dim)
173
+
174
+ Returns:
175
+ loss: conditional flow matching loss
176
+ y: conditional flow
177
+ shape: (batch_size, n_feats, mel_timesteps)
178
+ """
179
+ b, _, t = mu.shape
180
+
181
+ # random timestep
182
+ t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
183
+ if self.t_scheduler == 'cosine':
184
+ t = 1 - torch.cos(t * 0.5 * torch.pi)
185
+ # sample noise p(x_0)
186
+ z = torch.randn_like(x1)
187
+
188
+ y = (1 - (1 - self.sigma_min) * t) * z + t * x1
189
+ u = x1 - (1 - self.sigma_min) * z
190
+
191
+ # during training, we randomly drop condition to trade off mode coverage and sample fidelity
192
+ if self.training_cfg_rate > 0:
193
+ cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
194
+ mu = mu * cfg_mask.view(-1, 1, 1)
195
+ spks = spks * cfg_mask.view(-1, 1)
196
+ cond = cond * cfg_mask.view(-1, 1, 1)
197
+
198
+ pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
199
+ loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
200
+ return loss, y
201
+
202
+
203
+ class CausalConditionalCFM(ConditionalCFM):
204
+ def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
205
+ super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
206
+ self.rand_noise = torch.randn([1, 80, 50 * 300])
207
+
208
+ @torch.inference_mode()
209
+ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
210
+ """Forward diffusion
211
+
212
+ Args:
213
+ mu (torch.Tensor): output of encoder
214
+ shape: (batch_size, n_feats, mel_timesteps)
215
+ mask (torch.Tensor): output_mask
216
+ shape: (batch_size, 1, mel_timesteps)
217
+ n_timesteps (int): number of diffusion steps
218
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
219
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
220
+ shape: (batch_size, spk_emb_dim)
221
+ cond: Not used but kept for future purposes
222
+
223
+ Returns:
224
+ sample: generated mel-spectrogram
225
+ shape: (batch_size, n_feats, mel_timesteps)
226
+ """
227
+
228
+ z = self.rand_noise[:, :, :mu.size(2)].to(mu.device) * temperature
229
+ if self.fp16 is True:
230
+ z = z.half()
231
+ # fix prompt and overlap part mu and z
232
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
233
+ if self.t_scheduler == 'cosine':
234
+ t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
235
+ return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None
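
A minimal sketch of the Euler integration pattern used in `solve_euler` above, with a toy vector field standing in for the real estimator (the actual estimator is a neural network and is not reproduced here); only the cosine time grid and the update rule mirror the code above.

```python
import torch

def toy_estimator(x, t):
    # toy vector field pulling x towards 1.0 -- NOT the CosyVoice estimator
    return torch.ones_like(x) - x

n_timesteps = 10
t_span = torch.linspace(0, 1, n_timesteps + 1)
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)   # cosine t_scheduler, as in ConditionalCFM

x = torch.randn(1, 80, 20)                        # noise sample z
t, dt = t_span[0], t_span[1] - t_span[0]
for step in range(1, len(t_span)):
    x = x + dt * toy_estimator(x, t)              # Euler step: x_{k+1} = x_k + dt * v(x_k, t_k)
    t = t + dt
    if step < len(t_span) - 1:
        dt = t_span[step + 1] - t
print(float(x.mean()))                            # drifts towards 1.0 as t -> 1
```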
cosyvoice/flow/length_regulator.py ADDED
@@ -0,0 +1,69 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Tuple
15
+ import torch.nn as nn
16
+ import torch
17
+ from torch.nn import functional as F
18
+ from cosyvoice.utils.mask import make_pad_mask
19
+
20
+
21
+ class InterpolateRegulator(nn.Module):
22
+ def __init__(
23
+ self,
24
+ channels: int,
25
+ sampling_ratios: Tuple,
26
+ out_channels: int = None,
27
+ groups: int = 1,
28
+ ):
29
+ super().__init__()
30
+ self.sampling_ratios = sampling_ratios
31
+ out_channels = out_channels or channels
32
+ model = nn.ModuleList([])
33
+ if len(sampling_ratios) > 0:
34
+ for _ in sampling_ratios:
35
+ module = nn.Conv1d(channels, channels, 3, 1, 1)
36
+ norm = nn.GroupNorm(groups, channels)
37
+ act = nn.Mish()
38
+ model.extend([module, norm, act])
39
+ model.append(
40
+ nn.Conv1d(channels, out_channels, 1, 1)
41
+ )
42
+ self.model = nn.Sequential(*model)
43
+
44
+ def forward(self, x, ylens=None):
45
+ # x in (B, T, D)
46
+ mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
47
+ x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear')
48
+ out = self.model(x).transpose(1, 2).contiguous()
49
+ olens = ylens
50
+ return out * mask, olens
51
+
52
+ def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
53
+ # in inference mode, interpolate prompt token and token (head/mid/tail) separately, so we can get a clear separation point of mel
54
+ # x in (B, T, D)
55
+ if x2.shape[1] > 40:
56
+ x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
57
+ x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2,
58
+ mode='linear')
59
+ x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
60
+ x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
61
+ else:
62
+ x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
63
+ if x1.shape[1] != 0:
64
+ x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear')
65
+ x = torch.concat([x1, x2], dim=2)
66
+ else:
67
+ x = x2
68
+ out = self.model(x).transpose(1, 2).contiguous()
69
+ return out, mel_len1 + mel_len2
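
A shape-only illustration (not part of this commit) of the linear interpolation that `InterpolateRegulator` uses to stretch token-rate features to the mel frame rate; the Conv1d/GroupNorm/Mish stack is omitted.

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 100, 512)             # (B, T_token, D): 100 tokens at 50 Hz = 2 s
mel_len = int(100 / 50 * 22050 / 256)    # target mel length -> 172 frames

y = F.interpolate(x.transpose(1, 2).contiguous(), size=mel_len, mode='linear')
print(y.shape)                           # torch.Size([1, 512, 172]): (B, D, T_mel)
```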
cosyvoice/hifigan/discriminator.py ADDED
@@ -0,0 +1,140 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn.utils import weight_norm
4
+ from typing import List, Optional, Tuple
5
+ from einops import rearrange
6
+ from torchaudio.transforms import Spectrogram
7
+
8
+
9
+ class MultipleDiscriminator(nn.Module):
10
+ def __init__(
11
+ self, mpd: nn.Module, mrd: nn.Module
12
+ ):
13
+ super().__init__()
14
+ self.mpd = mpd
15
+ self.mrd = mrd
16
+
17
+ def forward(self, y: torch.Tensor, y_hat: torch.Tensor):
18
+ y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
19
+ this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mpd(y.unsqueeze(dim=1), y_hat.unsqueeze(dim=1))
20
+ y_d_rs += this_y_d_rs
21
+ y_d_gs += this_y_d_gs
22
+ fmap_rs += this_fmap_rs
23
+ fmap_gs += this_fmap_gs
24
+ this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mrd(y, y_hat)
25
+ y_d_rs += this_y_d_rs
26
+ y_d_gs += this_y_d_gs
27
+ fmap_rs += this_fmap_rs
28
+ fmap_gs += this_fmap_gs
29
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
30
+
31
+
32
+ class MultiResolutionDiscriminator(nn.Module):
33
+ def __init__(
34
+ self,
35
+ fft_sizes: Tuple[int, ...] = (2048, 1024, 512),
36
+ num_embeddings: Optional[int] = None,
37
+ ):
38
+ """
39
+ Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec.
40
+ Additionally, it allows incorporating conditional information with a learned embeddings table.
41
+
42
+ Args:
43
+ fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512).
44
+ num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
45
+ Defaults to None.
46
+ """
47
+
48
+ super().__init__()
49
+ self.discriminators = nn.ModuleList(
50
+ [DiscriminatorR(window_length=w, num_embeddings=num_embeddings) for w in fft_sizes]
51
+ )
52
+
53
+ def forward(
54
+ self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
55
+ ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
56
+ y_d_rs = []
57
+ y_d_gs = []
58
+ fmap_rs = []
59
+ fmap_gs = []
60
+
61
+ for d in self.discriminators:
62
+ y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
63
+ y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
64
+ y_d_rs.append(y_d_r)
65
+ fmap_rs.append(fmap_r)
66
+ y_d_gs.append(y_d_g)
67
+ fmap_gs.append(fmap_g)
68
+
69
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
70
+
71
+
72
+ class DiscriminatorR(nn.Module):
73
+ def __init__(
74
+ self,
75
+ window_length: int,
76
+ num_embeddings: Optional[int] = None,
77
+ channels: int = 32,
78
+ hop_factor: float = 0.25,
79
+ bands: Tuple[Tuple[float, float], ...] = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)),
80
+ ):
81
+ super().__init__()
82
+ self.window_length = window_length
83
+ self.hop_factor = hop_factor
84
+ self.spec_fn = Spectrogram(
85
+ n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, power=None
86
+ )
87
+ n_fft = window_length // 2 + 1
88
+ bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
89
+ self.bands = bands
90
+ convs = lambda: nn.ModuleList(
91
+ [
92
+ weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
93
+ weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
94
+ weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
95
+ weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
96
+ weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
97
+ ]
98
+ )
99
+ self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
100
+
101
+ if num_embeddings is not None:
102
+ self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
103
+ torch.nn.init.zeros_(self.emb.weight)
104
+
105
+ self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))
106
+
107
+ def spectrogram(self, x):
108
+ # Remove DC offset
109
+ x = x - x.mean(dim=-1, keepdims=True)
110
+ # Peak normalize the volume of input audio
111
+ x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
112
+ x = self.spec_fn(x)
113
+ x = torch.view_as_real(x)
114
+ x = rearrange(x, "b f t c -> b c t f")
115
+ # Split into bands
116
+ x_bands = [x[..., b[0]: b[1]] for b in self.bands]
117
+ return x_bands
118
+
119
+ def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None):
120
+ x_bands = self.spectrogram(x)
121
+ fmap = []
122
+ x = []
123
+ for band, stack in zip(x_bands, self.band_convs):
124
+ for i, layer in enumerate(stack):
125
+ band = layer(band)
126
+ band = torch.nn.functional.leaky_relu(band, 0.1)
127
+ if i > 0:
128
+ fmap.append(band)
129
+ x.append(band)
130
+ x = torch.cat(x, dim=-1)
131
+ if cond_embedding_id is not None:
132
+ emb = self.emb(cond_embedding_id)
133
+ h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
134
+ else:
135
+ h = 0
136
+ x = self.conv_post(x)
137
+ fmap.append(x)
138
+ x += h
139
+
140
+ return x, fmap
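A minimal usage sketch, not part of the committed code: the multi-resolution discriminator consumes raw waveforms for the real and generated speech and returns one logit map and one feature-map list per FFT resolution. Batch size and waveform length below are arbitrary placeholders.

```python
# Sketch only: exercising MultiResolutionDiscriminator from cosyvoice/hifigan/discriminator.py.
import torch
from cosyvoice.hifigan.discriminator import MultiResolutionDiscriminator

disc = MultiResolutionDiscriminator(fft_sizes=(2048, 1024, 512))
real = torch.randn(2, 24000)   # real waveforms, shape (B, T)
fake = torch.randn(2, 24000)   # generated waveforms, same shape
y_d_rs, y_d_gs, fmap_rs, fmap_gs = disc(real, fake)
print(len(y_d_rs), len(y_d_gs), len(fmap_rs), len(fmap_gs))  # 3 of each, one per FFT size
```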
cosyvoice/hifigan/f0_predictor.py ADDED
@@ -0,0 +1,55 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+ import torch.nn as nn
16
+ from torch.nn.utils import weight_norm
17
+
18
+
19
+ class ConvRNNF0Predictor(nn.Module):
20
+ def __init__(self,
21
+ num_class: int = 1,
22
+ in_channels: int = 80,
23
+ cond_channels: int = 512
24
+ ):
25
+ super().__init__()
26
+
27
+ self.num_class = num_class
28
+ self.condnet = nn.Sequential(
29
+ weight_norm(
30
+ nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
31
+ ),
32
+ nn.ELU(),
33
+ weight_norm(
34
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
35
+ ),
36
+ nn.ELU(),
37
+ weight_norm(
38
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
39
+ ),
40
+ nn.ELU(),
41
+ weight_norm(
42
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
43
+ ),
44
+ nn.ELU(),
45
+ weight_norm(
46
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
47
+ ),
48
+ nn.ELU(),
49
+ )
50
+ self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
51
+
52
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
53
+ x = self.condnet(x)
54
+ x = x.transpose(1, 2)
55
+ return torch.abs(self.classifier(x).squeeze(-1))
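A minimal shape sketch, not part of the committed code: ConvRNNF0Predictor maps 80-dimensional mel features to a non-negative frame-level F0 track of the same length; the (B, 80, T) layout is inferred from the Conv1d stack above.

```python
# Sketch only: input/output shapes of the F0 predictor.
import torch
from cosyvoice.hifigan.f0_predictor import ConvRNNF0Predictor

predictor = ConvRNNF0Predictor(num_class=1, in_channels=80, cond_channels=512)
mel = torch.randn(1, 80, 200)   # hypothetical mel-spectrogram chunk
f0 = predictor(mel)
print(f0.shape)                 # torch.Size([1, 200]); torch.abs keeps the estimate non-negative
```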
cosyvoice/hifigan/generator.py ADDED
@@ -0,0 +1,411 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """HIFI-GAN"""
16
+
17
+ from typing import Dict, Optional, List
18
+ import numpy as np
19
+ from scipy.signal import get_window
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ from torch.nn import Conv1d
24
+ from torch.nn import ConvTranspose1d
25
+ from torch.nn.utils import remove_weight_norm
26
+ from torch.nn.utils import weight_norm
27
+ from torch.distributions.uniform import Uniform
28
+
29
+ from cosyvoice.transformer.activation import Snake
30
+ from cosyvoice.utils.common import get_padding
31
+ from cosyvoice.utils.common import init_weights
32
+
33
+
34
+ """hifigan based generator implementation.
35
+
36
+ This code is modified from https://github.com/jik876/hifi-gan,
37
+ https://github.com/kan-bayashi/ParallelWaveGAN and
38
+ https://github.com/NVIDIA/BigVGAN
39
+
40
+ """
41
+
42
+
43
+ class ResBlock(torch.nn.Module):
44
+ """Residual block module in HiFiGAN/BigVGAN."""
45
+ def __init__(
46
+ self,
47
+ channels: int = 512,
48
+ kernel_size: int = 3,
49
+ dilations: List[int] = [1, 3, 5],
50
+ ):
51
+ super(ResBlock, self).__init__()
52
+ self.convs1 = nn.ModuleList()
53
+ self.convs2 = nn.ModuleList()
54
+
55
+ for dilation in dilations:
56
+ self.convs1.append(
57
+ weight_norm(
58
+ Conv1d(
59
+ channels,
60
+ channels,
61
+ kernel_size,
62
+ 1,
63
+ dilation=dilation,
64
+ padding=get_padding(kernel_size, dilation)
65
+ )
66
+ )
67
+ )
68
+ self.convs2.append(
69
+ weight_norm(
70
+ Conv1d(
71
+ channels,
72
+ channels,
73
+ kernel_size,
74
+ 1,
75
+ dilation=1,
76
+ padding=get_padding(kernel_size, 1)
77
+ )
78
+ )
79
+ )
80
+ self.convs1.apply(init_weights)
81
+ self.convs2.apply(init_weights)
82
+ self.activations1 = nn.ModuleList([
83
+ Snake(channels, alpha_logscale=False)
84
+ for _ in range(len(self.convs1))
85
+ ])
86
+ self.activations2 = nn.ModuleList([
87
+ Snake(channels, alpha_logscale=False)
88
+ for _ in range(len(self.convs2))
89
+ ])
90
+
91
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
92
+ for idx in range(len(self.convs1)):
93
+ xt = self.activations1[idx](x)
94
+ xt = self.convs1[idx](xt)
95
+ xt = self.activations2[idx](xt)
96
+ xt = self.convs2[idx](xt)
97
+ x = xt + x
98
+ return x
99
+
100
+ def remove_weight_norm(self):
101
+ for idx in range(len(self.convs1)):
102
+ remove_weight_norm(self.convs1[idx])
103
+ remove_weight_norm(self.convs2[idx])
104
+
105
+
106
+ class SineGen(torch.nn.Module):
107
+ """ Definition of sine generator
108
+ SineGen(samp_rate, harmonic_num = 0,
109
+ sine_amp = 0.1, noise_std = 0.003,
110
+ voiced_threshold = 0,
111
+ flag_for_pulse=False)
112
+ samp_rate: sampling rate in Hz
113
+ harmonic_num: number of harmonic overtones (default 0)
114
+ sine_amp: amplitude of sine waveform (default 0.1)
115
+ noise_std: std of Gaussian noise (default 0.003)
116
+ voiced_threshold: F0 threshold for U/V classification (default 0)
117
+ flag_for_pulse: this SineGen is used inside PulseGen (default False)
118
+ Note: when flag_for_pulse is True, the first time step of a voiced
119
+ segment is always sin(np.pi) or cos(0)
120
+ """
121
+
122
+ def __init__(self, samp_rate, harmonic_num=0,
123
+ sine_amp=0.1, noise_std=0.003,
124
+ voiced_threshold=0):
125
+ super(SineGen, self).__init__()
126
+ self.sine_amp = sine_amp
127
+ self.noise_std = noise_std
128
+ self.harmonic_num = harmonic_num
129
+ self.sampling_rate = samp_rate
130
+ self.voiced_threshold = voiced_threshold
131
+
132
+ def _f02uv(self, f0):
133
+ # generate uv signal
134
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
135
+ return uv
136
+
137
+ @torch.no_grad()
138
+ def forward(self, f0):
139
+ """
140
+ :param f0: [B, 1, sample_len], Hz
141
+ :return: [B, 1, sample_len]
142
+ """
143
+
144
+ F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
145
+ for i in range(self.harmonic_num + 1):
146
+ F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
147
+
148
+ theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
149
+ u_dist = Uniform(low=-np.pi, high=np.pi)
150
+ phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
151
+ phase_vec[:, 0, :] = 0
152
+
153
+ # generate sine waveforms
154
+ sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)
155
+
156
+ # generate uv signal
157
+ uv = self._f02uv(f0)
158
+
159
+ # noise: for unvoiced should be similar to sine_amp
160
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
161
+ # for voiced regions it is self.noise_std
162
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
163
+ noise = noise_amp * torch.randn_like(sine_waves)
164
+
165
+ # first: set the unvoiced part to 0 by uv
166
+ # then: additive noise
167
+ sine_waves = sine_waves * uv + noise
168
+ return sine_waves, uv, noise
169
+
170
+
171
+ class SourceModuleHnNSF(torch.nn.Module):
172
+ """ SourceModule for hn-nsf
173
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
174
+ add_noise_std=0.003, voiced_threshod=0)
175
+ sampling_rate: sampling_rate in Hz
176
+ harmonic_num: number of harmonic above F0 (default: 0)
177
+ sine_amp: amplitude of sine source signal (default: 0.1)
178
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
179
+ note that amplitude of noise in unvoiced is decided
180
+ by sine_amp
181
+ voiced_threshold: threshold to set U/V given F0 (default: 0)
182
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
183
+ F0_sampled (batchsize, length, 1)
184
+ Sine_source (batchsize, length, 1)
185
+ noise_source (batchsize, length, 1)
186
+ uv (batchsize, length, 1)
187
+ """
188
+
189
+ def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
190
+ add_noise_std=0.003, voiced_threshod=0):
191
+ super(SourceModuleHnNSF, self).__init__()
192
+
193
+ self.sine_amp = sine_amp
194
+ self.noise_std = add_noise_std
195
+
196
+ # to produce sine waveforms
197
+ self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
198
+ sine_amp, add_noise_std, voiced_threshod)
199
+
200
+ # to merge source harmonics into a single excitation
201
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
202
+ self.l_tanh = torch.nn.Tanh()
203
+
204
+ def forward(self, x):
205
+ """
206
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
207
+ F0_sampled (batchsize, length, 1)
208
+ Sine_source (batchsize, length, 1)
209
+ noise_source (batchsize, length, 1)
210
+ """
211
+ # source for harmonic branch
212
+ with torch.no_grad():
213
+ sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
214
+ sine_wavs = sine_wavs.transpose(1, 2)
215
+ uv = uv.transpose(1, 2)
216
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
217
+
218
+ # source for noise branch, in the same shape as uv
219
+ noise = torch.randn_like(uv) * self.sine_amp / 3
220
+ return sine_merge, noise, uv
221
+
222
+
223
+ class HiFTGenerator(nn.Module):
224
+ """
225
+ HiFTNet Generator: Neural Source Filter + ISTFTNet
226
+ https://arxiv.org/abs/2309.09493
227
+ """
228
+ def __init__(
229
+ self,
230
+ in_channels: int = 80,
231
+ base_channels: int = 512,
232
+ nb_harmonics: int = 8,
233
+ sampling_rate: int = 22050,
234
+ nsf_alpha: float = 0.1,
235
+ nsf_sigma: float = 0.003,
236
+ nsf_voiced_threshold: float = 10,
237
+ upsample_rates: List[int] = [8, 8],
238
+ upsample_kernel_sizes: List[int] = [16, 16],
239
+ istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
240
+ resblock_kernel_sizes: List[int] = [3, 7, 11],
241
+ resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
242
+ source_resblock_kernel_sizes: List[int] = [7, 11],
243
+ source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
244
+ lrelu_slope: float = 0.1,
245
+ audio_limit: float = 0.99,
246
+ f0_predictor: torch.nn.Module = None,
247
+ ):
248
+ super(HiFTGenerator, self).__init__()
249
+
250
+ self.out_channels = 1
251
+ self.nb_harmonics = nb_harmonics
252
+ self.sampling_rate = sampling_rate
253
+ self.istft_params = istft_params
254
+ self.lrelu_slope = lrelu_slope
255
+ self.audio_limit = audio_limit
256
+
257
+ self.num_kernels = len(resblock_kernel_sizes)
258
+ self.num_upsamples = len(upsample_rates)
259
+ self.m_source = SourceModuleHnNSF(
260
+ sampling_rate=sampling_rate,
261
+ upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
262
+ harmonic_num=nb_harmonics,
263
+ sine_amp=nsf_alpha,
264
+ add_noise_std=nsf_sigma,
265
+ voiced_threshod=nsf_voiced_threshold)
266
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
267
+
268
+ self.conv_pre = weight_norm(
269
+ Conv1d(in_channels, base_channels, 7, 1, padding=3)
270
+ )
271
+
272
+ # Up
273
+ self.ups = nn.ModuleList()
274
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
275
+ self.ups.append(
276
+ weight_norm(
277
+ ConvTranspose1d(
278
+ base_channels // (2**i),
279
+ base_channels // (2**(i + 1)),
280
+ k,
281
+ u,
282
+ padding=(k - u) // 2,
283
+ )
284
+ )
285
+ )
286
+
287
+ # Down
288
+ self.source_downs = nn.ModuleList()
289
+ self.source_resblocks = nn.ModuleList()
290
+ downsample_rates = [1] + upsample_rates[::-1][:-1]
291
+ downsample_cum_rates = np.cumprod(downsample_rates)
292
+ for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
293
+ if u == 1:
294
+ self.source_downs.append(
295
+ Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
296
+ )
297
+ else:
298
+ self.source_downs.append(
299
+ Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
300
+ )
301
+
302
+ self.source_resblocks.append(
303
+ ResBlock(base_channels // (2 ** (i + 1)), k, d)
304
+ )
305
+
306
+ self.resblocks = nn.ModuleList()
307
+ for i in range(len(self.ups)):
308
+ ch = base_channels // (2**(i + 1))
309
+ for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
310
+ self.resblocks.append(ResBlock(ch, k, d))
311
+
312
+ self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
313
+ self.ups.apply(init_weights)
314
+ self.conv_post.apply(init_weights)
315
+ self.reflection_pad = nn.ReflectionPad1d((1, 0))
316
+ self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
317
+ self.f0_predictor = f0_predictor
318
+
319
+ def remove_weight_norm(self):
320
+ print('Removing weight norm...')
321
+ for l in self.ups:
322
+ remove_weight_norm(l)
323
+ for l in self.resblocks:
324
+ l.remove_weight_norm()
325
+ remove_weight_norm(self.conv_pre)
326
+ remove_weight_norm(self.conv_post)
327
+ self.m_source.remove_weight_norm()
328
+ for l in self.source_downs:
329
+ remove_weight_norm(l)
330
+ for l in self.source_resblocks:
331
+ l.remove_weight_norm()
332
+
333
+ def _stft(self, x):
334
+ spec = torch.stft(
335
+ x,
336
+ self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
337
+ return_complex=True)
338
+ spec = torch.view_as_real(spec) # [B, F, TT, 2]
339
+ return spec[..., 0], spec[..., 1]
340
+
341
+ def _istft(self, magnitude, phase):
342
+ magnitude = torch.clip(magnitude, max=1e2)
343
+ real = magnitude * torch.cos(phase)
344
+ img = magnitude * torch.sin(phase)
345
+ inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
346
+ self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
347
+ return inverse_transform
348
+
349
+ def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
350
+ s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
351
+ s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
352
+
353
+ x = self.conv_pre(x)
354
+ for i in range(self.num_upsamples):
355
+ x = F.leaky_relu(x, self.lrelu_slope)
356
+ x = self.ups[i](x)
357
+
358
+ if i == self.num_upsamples - 1:
359
+ x = self.reflection_pad(x)
360
+
361
+ # fusion
362
+ si = self.source_downs[i](s_stft)
363
+ si = self.source_resblocks[i](si)
364
+ x = x + si
365
+
366
+ xs = None
367
+ for j in range(self.num_kernels):
368
+ if xs is None:
369
+ xs = self.resblocks[i * self.num_kernels + j](x)
370
+ else:
371
+ xs += self.resblocks[i * self.num_kernels + j](x)
372
+ x = xs / self.num_kernels
373
+
374
+ x = F.leaky_relu(x)
375
+ x = self.conv_post(x)
376
+ magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
377
+ phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundant
378
+
379
+ x = self._istft(magnitude, phase)
380
+ x = torch.clamp(x, -self.audio_limit, self.audio_limit)
381
+ return x
382
+
383
+ def forward(
384
+ self,
385
+ batch: dict,
386
+ device: torch.device,
387
+ ) -> Dict[str, Optional[torch.Tensor]]:
388
+ speech_feat = batch['speech_feat'].transpose(1, 2).to(device)
389
+ # mel->f0
390
+ f0 = self.f0_predictor(speech_feat)
391
+ # f0->source
392
+ s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
393
+ s, _, _ = self.m_source(s)
394
+ s = s.transpose(1, 2)
395
+ # mel+source->speech
396
+ generated_speech = self.decode(x=speech_feat, s=s)
397
+ return generated_speech, f0
398
+
399
+ @torch.inference_mode()
400
+ def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
401
+ # mel->f0
402
+ f0 = self.f0_predictor(speech_feat)
403
+ # f0->source
404
+ s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
405
+ s, _, _ = self.m_source(s)
406
+ s = s.transpose(1, 2)
407
+ # use cache_source to avoid glitch
408
+ if cache_source.shape[2] != 0:
409
+ s[:, :, :cache_source.shape[2]] = cache_source
410
+ generated_speech = self.decode(x=speech_feat, s=s)
411
+ return generated_speech, s
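A minimal vocoding sketch, not part of the committed code: `inference` returns both the waveform and the NSF source signal, and that source is what a caller can slice and pass back as `cache_source` for an overlapping next chunk to avoid boundary glitches. Constructing the generator with default arguments is an assumption; the released recipes build it from the yaml config.

```python
# Sketch only: one HiFTGenerator.inference call on random mel features.
import torch
from cosyvoice.hifigan.f0_predictor import ConvRNNF0Predictor
from cosyvoice.hifigan.generator import HiFTGenerator

hift = HiFTGenerator(f0_predictor=ConvRNNF0Predictor())
mel = torch.randn(1, 80, 120)                            # hypothetical mel features, shape (B, 80, T)
speech, source = hift.inference(speech_feat=mel, cache_source=torch.zeros(1, 1, 0))
print(speech.shape, source.shape)                        # waveform and the NSF source for caching
```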
cosyvoice/hifigan/hifigan.py ADDED
@@ -0,0 +1,67 @@
1
+ from typing import Dict, Optional
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from matcha.hifigan.models import feature_loss, generator_loss, discriminator_loss
6
+ from cosyvoice.utils.losses import tpr_loss, mel_loss
7
+
8
+
9
+ class HiFiGan(nn.Module):
10
+ def __init__(self, generator, discriminator, mel_spec_transform,
11
+ multi_mel_spectral_recon_loss_weight=45, feat_match_loss_weight=2.0,
12
+ tpr_loss_weight=1.0, tpr_loss_tau=0.04):
13
+ super(HiFiGan, self).__init__()
14
+ self.generator = generator
15
+ self.discriminator = discriminator
16
+ self.mel_spec_transform = mel_spec_transform
17
+ self.multi_mel_spectral_recon_loss_weight = multi_mel_spectral_recon_loss_weight
18
+ self.feat_match_loss_weight = feat_match_loss_weight
19
+ self.tpr_loss_weight = tpr_loss_weight
20
+ self.tpr_loss_tau = tpr_loss_tau
21
+
22
+ def forward(
23
+ self,
24
+ batch: dict,
25
+ device: torch.device,
26
+ ) -> Dict[str, Optional[torch.Tensor]]:
27
+ if batch['turn'] == 'generator':
28
+ return self.forward_generator(batch, device)
29
+ else:
30
+ return self.forward_discriminator(batch, device)
31
+
32
+ def forward_generator(self, batch, device):
33
+ real_speech = batch['speech'].to(device)
34
+ pitch_feat = batch['pitch_feat'].to(device)
35
+ # 1. calculate generator outputs
36
+ generated_speech, generated_f0 = self.generator(batch, device)
37
+ # 2. calculate discriminator outputs
38
+ y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech)
39
+ # 3. calculate generator losses, feature loss, mel loss, tpr losses [Optional]
40
+ loss_gen, _ = generator_loss(y_d_gs)
41
+ loss_fm = feature_loss(fmap_rs, fmap_gs)
42
+ loss_mel = mel_loss(real_speech, generated_speech, self.mel_spec_transform)
43
+ if self.tpr_loss_weight != 0:
44
+ loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau)
45
+ else:
46
+ loss_tpr = torch.zeros(1).to(device)
47
+ loss_f0 = F.l1_loss(generated_f0, pitch_feat)
48
+ loss = loss_gen + self.feat_match_loss_weight * loss_fm + \
49
+ self.multi_mel_spectral_recon_loss_weight * loss_mel + \
50
+ self.tpr_loss_weight * loss_tpr + loss_f0
51
+ return {'loss': loss, 'loss_gen': loss_gen, 'loss_fm': loss_fm, 'loss_mel': loss_mel, 'loss_tpr': loss_tpr, 'loss_f0': loss_f0}
52
+
53
+ def forward_discriminator(self, batch, device):
54
+ real_speech = batch['speech'].to(device)
55
+ # 1. calculate generator outputs
56
+ with torch.no_grad():
57
+ generated_speech, generated_f0 = self.generator(batch, device)
58
+ # 2. calculate discriminator outputs
59
+ y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech)
60
+ # 3. calculate discriminator losses, tpr losses [Optional]
61
+ loss_disc, _, _ = discriminator_loss(y_d_rs, y_d_gs)
62
+ if self.tpr_loss_weight != 0:
63
+ loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau)
64
+ else:
65
+ loss_tpr = torch.zeros(1).to(device)
66
+ loss = loss_disc + self.tpr_loss_weight * loss_tpr
67
+ return {'loss': loss, 'loss_disc': loss_disc, 'loss_tpr': loss_tpr}
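A minimal training-step sketch, not part of the committed code: because `forward` dispatches on `batch['turn']`, one GAN step simply calls the wrapper twice and steps the matching optimizer. The helper below is an assumption about how a training loop could drive it, not the project's executor.

```python
# Sketch only: alternating discriminator and generator turns with the HiFiGan wrapper.
import torch

def gan_step(hifigan, batch: dict, device: torch.device, optimizer_g, optimizer_d):
    # discriminator turn
    batch['turn'] = 'discriminator'
    d_out = hifigan(batch, device)            # {'loss', 'loss_disc', 'loss_tpr'}
    optimizer_d.zero_grad()
    d_out['loss'].backward()
    optimizer_d.step()
    # generator turn
    batch['turn'] = 'generator'
    g_out = hifigan(batch, device)            # {'loss', 'loss_gen', 'loss_fm', 'loss_mel', ...}
    optimizer_g.zero_grad()
    g_out['loss'].backward()
    optimizer_g.step()
    return d_out['loss'].item(), g_out['loss'].item()
```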
cosyvoice/llm/llm.py ADDED
@@ -0,0 +1,340 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Dict, Optional, Callable, List, Generator
15
+ import torch
16
+ from torch import nn
17
+ import torch.nn.functional as F
18
+ from transformers import Qwen2ForCausalLM
19
+ from torch.nn.utils.rnn import pad_sequence, unpad_sequence
20
+ from cosyvoice.utils.common import IGNORE_ID
21
+ from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
22
+ from cosyvoice.utils.common import th_accuracy
23
+
24
+
25
+ class TransformerLM(torch.nn.Module):
26
+ def __init__(
27
+ self,
28
+ text_encoder_input_size: int,
29
+ llm_input_size: int,
30
+ llm_output_size: int,
31
+ text_token_size: int,
32
+ speech_token_size: int,
33
+ text_encoder: torch.nn.Module,
34
+ llm: torch.nn.Module,
35
+ sampling: Callable,
36
+ length_normalized_loss: bool = True,
37
+ lsm_weight: float = 0.0,
38
+ spk_embed_dim: int = 192,
39
+ ):
40
+ super().__init__()
41
+ self.llm_input_size = llm_input_size
42
+ self.speech_token_size = speech_token_size
43
+ # 1. build text token inputs related modules
44
+ self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size)
45
+ self.text_encoder = text_encoder
46
+ self.text_encoder_affine_layer = nn.Linear(
47
+ self.text_encoder.output_size(),
48
+ llm_input_size
49
+ )
50
+
51
+ # 2. build speech token language model related modules
52
+ self.sos_eos = 0
53
+ self.task_id = 1
54
+ self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
55
+ self.llm = llm
56
+ self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
57
+ self.criterion_ce = LabelSmoothingLoss(
58
+ size=speech_token_size + 1,
59
+ padding_idx=IGNORE_ID,
60
+ smoothing=lsm_weight,
61
+ normalize_length=length_normalized_loss,
62
+ )
63
+
64
+ # 3. [Optional] build speech token related modules
65
+ self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
66
+ self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)
67
+
68
+ # 4. sampling method
69
+ self.sampling = sampling
70
+
71
+ def encode(
72
+ self,
73
+ text: torch.Tensor,
74
+ text_lengths: torch.Tensor,
75
+ ):
76
+ encoder_out, encoder_mask = self.text_encoder(text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1)
77
+ encoder_out_lens = encoder_mask.squeeze(1).sum(1)
78
+ encoder_out = self.text_encoder_affine_layer(encoder_out)
79
+ return encoder_out, encoder_out_lens
80
+
81
+ def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
82
+ text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
83
+ speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
84
+ lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
85
+ for i in range(len(text_token))]
86
+ lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
87
+ lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
88
+ return lm_input, lm_input_len
89
+
90
+ def forward(
91
+ self,
92
+ batch: dict,
93
+ device: torch.device,
94
+ ) -> Dict[str, Optional[torch.Tensor]]:
95
+ """
96
+ Args:
97
+ text: (B, L, D)
98
+ text_lengths: (B,)
99
+ audio: (B, T, N) or (B, T)
100
+ audio_lengths: (B,)
101
+ """
102
+ text_token = batch['text_token'].to(device)
103
+ text_token_len = batch['text_token_len'].to(device)
104
+ speech_token = batch['speech_token'].to(device)
105
+ speech_token_len = batch['speech_token_len'].to(device)
106
+ embedding = batch['embedding'].to(device)
107
+
108
+ # 1. prepare llm_target
109
+ lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
110
+ [self.speech_token_size]) for i in range(text_token.size(0))]
111
+ lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)
112
+
113
+ # 1. encode text_token
114
+ text_token = self.text_embedding(text_token)
115
+ text_token, text_token_len = self.encode(text_token, text_token_len)
116
+
117
+ # 2. embedding projection
118
+ embedding = F.normalize(embedding, dim=1)
119
+ embedding = self.spk_embed_affine_layer(embedding)
120
+ embedding = embedding.unsqueeze(1)
121
+
122
+ # 3. eos and task_id
123
+ sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
124
+ task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
125
+
126
+ # 4. encode speech_token
127
+ speech_token = self.speech_embedding(speech_token)
128
+
129
+ # 5. unpad and pad
130
+ lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len,
131
+ task_id_emb, speech_token, speech_token_len)
132
+
133
+ # 6. run lm forward
134
+ lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
135
+ logits = self.llm_decoder(lm_output)
136
+ loss = self.criterion_ce(logits, lm_target)
137
+ acc = th_accuracy(logits.view(-1, self.speech_token_size + 1), lm_target, ignore_label=IGNORE_ID)
138
+ return {'loss': loss, 'acc': acc}
139
+
140
+ def sampling_ids(
141
+ self,
142
+ weighted_scores: torch.Tensor,
143
+ decoded_tokens: List,
144
+ sampling: int,
145
+ ignore_eos: bool = True,
146
+ ):
147
+ while True:
148
+ top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
149
+ if (not ignore_eos) or (self.speech_token_size not in top_ids):
150
+ break
151
+ return top_ids
152
+
153
+ @torch.inference_mode()
154
+ def inference(
155
+ self,
156
+ text: torch.Tensor,
157
+ text_len: torch.Tensor,
158
+ prompt_text: torch.Tensor,
159
+ prompt_text_len: torch.Tensor,
160
+ prompt_speech_token: torch.Tensor,
161
+ prompt_speech_token_len: torch.Tensor,
162
+ embedding: torch.Tensor,
163
+ sampling: int = 25,
164
+ max_token_text_ratio: float = 20,
165
+ min_token_text_ratio: float = 2,
166
+ ) -> Generator[torch.Tensor, None, None]:
167
+ device = text.device
168
+ text = torch.concat([prompt_text, text], dim=1)
169
+ text_len += prompt_text_len
170
+ text = self.text_embedding(text)
171
+
172
+ # 1. encode text
173
+ text, text_len = self.encode(text, text_len)
174
+
175
+ # 2. encode embedding
176
+ if embedding.shape[0] != 0:
177
+ embedding = F.normalize(embedding, dim=1)
178
+ embedding = self.spk_embed_affine_layer(embedding)
179
+ embedding = embedding.unsqueeze(dim=1)
180
+ else:
181
+ embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
182
+
183
+ # 3. concat llm_input
184
+ sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
185
+ task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
186
+ if prompt_speech_token_len != 0:
187
+ prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
188
+ else:
189
+ prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
190
+ lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
191
+
192
+ # 4. cal min/max_length
193
+ min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
194
+ max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
195
+
196
+ # 5. step by step decode
197
+ out_tokens = []
198
+ offset = 0
199
+ att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
200
+ for i in range(max_len):
201
+ y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1,
202
+ att_cache=att_cache, cnn_cache=cnn_cache,
203
+ att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
204
+ device=lm_input.device)).to(torch.bool))
205
+ logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
206
+ # forbid eos at the first step so that at least one speech token is decoded
207
+ if i == 0:
208
+ logp[:, self.speech_token_size] = -float('inf')
209
+ top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
210
+ if top_ids == self.speech_token_size:
211
+ break
212
+ # in stream mode, yield token one by one
213
+ yield top_ids
214
+ out_tokens.append(top_ids)
215
+ offset += lm_input.size(1)
216
+ lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
217
+
218
+
219
+ class Qwen2Encoder(torch.nn.Module):
220
+ def __init__(self, pretrain_path):
221
+ super().__init__()
222
+ self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path)
223
+
224
+ def forward_one_step(self, xs, masks, cache=None):
225
+ input_masks = masks[:, -1, :]
226
+ outs = self.model(
227
+ inputs_embeds=xs,
228
+ attention_mask=input_masks,
229
+ output_hidden_states=True,
230
+ return_dict=True,
231
+ use_cache=True,
232
+ past_key_values=cache,
233
+ )
234
+ xs = outs.hidden_states[-1]
235
+ new_cache = outs.past_key_values
236
+ return xs, new_cache
237
+
238
+
239
+ class Qwen2LM(torch.nn.Module):
240
+ def __init__(
241
+ self,
242
+ llm_input_size: int,
243
+ llm_output_size: int,
244
+ speech_token_size: int,
245
+ llm: torch.nn.Module,
246
+ sampling: Callable,
247
+ length_normalized_loss: bool = True,
248
+ lsm_weight: float = 0.0,
249
+ ):
250
+ super().__init__()
251
+ self.llm_input_size = llm_input_size
252
+ self.llm_output_size = llm_output_size
253
+ self.speech_token_size = speech_token_size
254
+
255
+ # 2. build speech token language model related modules
256
+ self.sos_eos = 0
257
+ self.task_id = 1
258
+ self.fill_token = 2
259
+
260
+ self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
261
+ self.llm = llm
262
+ self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 3)
263
+ self.criterion_ce = LabelSmoothingLoss(
264
+ size=speech_token_size + 3,
265
+ padding_idx=IGNORE_ID,
266
+ smoothing=lsm_weight,
267
+ normalize_length=length_normalized_loss,
268
+ )
269
+
270
+ # 3. [Optional] build speech token related modules
271
+ self.speech_embedding = torch.nn.Embedding(speech_token_size + 3, llm_input_size)
272
+
273
+ # 4. sampling method
274
+ self.sampling = sampling
275
+
276
+ def sampling_ids(
277
+ self,
278
+ weighted_scores: torch.Tensor,
279
+ decoded_tokens: List,
280
+ sampling: int,
281
+ ignore_eos: bool = True,
282
+ ):
283
+ while True:
284
+ top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
285
+ if (not ignore_eos) or (self.speech_token_size not in top_ids):
286
+ break
287
+ return top_ids
288
+
289
+ @torch.inference_mode()
290
+ def inference(
291
+ self,
292
+ text: torch.Tensor,
293
+ text_len: torch.Tensor,
294
+ prompt_text: torch.Tensor,
295
+ prompt_text_len: torch.Tensor,
296
+ prompt_speech_token: torch.Tensor,
297
+ prompt_speech_token_len: torch.Tensor,
298
+ embedding: torch.Tensor,
299
+ sampling: int = 25,
300
+ max_token_text_ratio: float = 20,
301
+ min_token_text_ratio: float = 2,
302
+ ) -> Generator[torch.Tensor, None, None]:
303
+ device = text.device
304
+ text = torch.concat([prompt_text, text], dim=1)
305
+ text_len += prompt_text_len
306
+ text = self.llm.model.model.embed_tokens(text)
307
+
308
+ # 2. encode embedding
309
+ embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
310
+
311
+ # 3. concat llm_input
312
+ sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
313
+ task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
314
+ if prompt_speech_token_len != 0:
315
+ prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
316
+ else:
317
+ prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
318
+ lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
319
+
320
+ # 4. cal min/max_length
321
+ min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
322
+ max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
323
+
324
+ # 5. step by step decode
325
+ out_tokens = []
326
+ cache = None
327
+ for i in range(max_len):
328
+ y_pred, cache = self.llm.forward_one_step(lm_input,
329
+ masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
330
+ cache=cache)
331
+ logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
332
+ top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
333
+ if top_ids == self.speech_token_size:
334
+ break
335
+ if top_ids > self.speech_token_size:
336
+ continue
337
+ # in stream mode, yield token one by one
338
+ yield top_ids
339
+ out_tokens.append(top_ids)
340
+ lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
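A minimal consumer sketch, not part of the committed code: `Qwen2LM.inference` is a Python generator, so speech tokens can be handed downstream as soon as they are sampled. The zero-shot call below (no prompt text or prompt speech tokens) uses placeholder tensors and assumes `llm` is an already constructed `Qwen2LM`.

```python
# Sketch only: draining the streaming speech-token decoder into a list.
import torch

def collect_speech_tokens(llm, text_ids: torch.Tensor) -> list:
    """text_ids: (1, L) long tensor of text token ids; returns the sampled speech token ids."""
    empty = torch.zeros(1, 0, dtype=torch.long)
    tokens = []
    for token_id in llm.inference(
            text=text_ids,
            text_len=torch.tensor([text_ids.shape[1]], dtype=torch.long),
            prompt_text=empty,
            prompt_text_len=torch.zeros(1, dtype=torch.long),
            prompt_speech_token=empty,
            prompt_speech_token_len=torch.zeros(1, dtype=torch.long),
            embedding=torch.zeros(0, 192)):
        tokens.append(token_id)   # each id is yielded as soon as it is sampled
    return tokens
```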
cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken ADDED
The diff for this file is too large to render. See raw diff
 
cosyvoice/tokenizer/tokenizer.py ADDED
@@ -0,0 +1,277 @@
1
+ import base64
2
+ import os
3
+ from functools import lru_cache
4
+ from typing import Optional
5
+ import torch
6
+ from transformers import AutoTokenizer
7
+ from whisper.tokenizer import Tokenizer
8
+
9
+ import tiktoken
10
+
11
+ LANGUAGES = {
12
+ "en": "english",
13
+ "zh": "chinese",
14
+ "de": "german",
15
+ "es": "spanish",
16
+ "ru": "russian",
17
+ "ko": "korean",
18
+ "fr": "french",
19
+ "ja": "japanese",
20
+ "pt": "portuguese",
21
+ "tr": "turkish",
22
+ "pl": "polish",
23
+ "ca": "catalan",
24
+ "nl": "dutch",
25
+ "ar": "arabic",
26
+ "sv": "swedish",
27
+ "it": "italian",
28
+ "id": "indonesian",
29
+ "hi": "hindi",
30
+ "fi": "finnish",
31
+ "vi": "vietnamese",
32
+ "he": "hebrew",
33
+ "uk": "ukrainian",
34
+ "el": "greek",
35
+ "ms": "malay",
36
+ "cs": "czech",
37
+ "ro": "romanian",
38
+ "da": "danish",
39
+ "hu": "hungarian",
40
+ "ta": "tamil",
41
+ "no": "norwegian",
42
+ "th": "thai",
43
+ "ur": "urdu",
44
+ "hr": "croatian",
45
+ "bg": "bulgarian",
46
+ "lt": "lithuanian",
47
+ "la": "latin",
48
+ "mi": "maori",
49
+ "ml": "malayalam",
50
+ "cy": "welsh",
51
+ "sk": "slovak",
52
+ "te": "telugu",
53
+ "fa": "persian",
54
+ "lv": "latvian",
55
+ "bn": "bengali",
56
+ "sr": "serbian",
57
+ "az": "azerbaijani",
58
+ "sl": "slovenian",
59
+ "kn": "kannada",
60
+ "et": "estonian",
61
+ "mk": "macedonian",
62
+ "br": "breton",
63
+ "eu": "basque",
64
+ "is": "icelandic",
65
+ "hy": "armenian",
66
+ "ne": "nepali",
67
+ "mn": "mongolian",
68
+ "bs": "bosnian",
69
+ "kk": "kazakh",
70
+ "sq": "albanian",
71
+ "sw": "swahili",
72
+ "gl": "galician",
73
+ "mr": "marathi",
74
+ "pa": "punjabi",
75
+ "si": "sinhala",
76
+ "km": "khmer",
77
+ "sn": "shona",
78
+ "yo": "yoruba",
79
+ "so": "somali",
80
+ "af": "afrikaans",
81
+ "oc": "occitan",
82
+ "ka": "georgian",
83
+ "be": "belarusian",
84
+ "tg": "tajik",
85
+ "sd": "sindhi",
86
+ "gu": "gujarati",
87
+ "am": "amharic",
88
+ "yi": "yiddish",
89
+ "lo": "lao",
90
+ "uz": "uzbek",
91
+ "fo": "faroese",
92
+ "ht": "haitian creole",
93
+ "ps": "pashto",
94
+ "tk": "turkmen",
95
+ "nn": "nynorsk",
96
+ "mt": "maltese",
97
+ "sa": "sanskrit",
98
+ "lb": "luxembourgish",
99
+ "my": "myanmar",
100
+ "bo": "tibetan",
101
+ "tl": "tagalog",
102
+ "mg": "malagasy",
103
+ "as": "assamese",
104
+ "tt": "tatar",
105
+ "haw": "hawaiian",
106
+ "ln": "lingala",
107
+ "ha": "hausa",
108
+ "ba": "bashkir",
109
+ "jw": "javanese",
110
+ "su": "sundanese",
111
+ "yue": "cantonese",
112
+ "minnan": "minnan",
113
+ "wuyu": "wuyu",
114
+ "dialect": "dialect",
115
+ "zh/en": "zh/en",
116
+ "en/zh": "en/zh",
117
+ }
118
+
119
+ # language code lookup by name, with a few language aliases
120
+ TO_LANGUAGE_CODE = {
121
+ **{language: code for code, language in LANGUAGES.items()},
122
+ "burmese": "my",
123
+ "valencian": "ca",
124
+ "flemish": "nl",
125
+ "haitian": "ht",
126
+ "letzeburgesch": "lb",
127
+ "pushto": "ps",
128
+ "panjabi": "pa",
129
+ "moldavian": "ro",
130
+ "moldovan": "ro",
131
+ "sinhalese": "si",
132
+ "castilian": "es",
133
+ "mandarin": "zh",
134
+ }
135
+
136
+ AUDIO_EVENT = {
137
+ "ASR": "ASR",
138
+ "AED": "AED",
139
+ "SER": "SER",
140
+ "Speech": "Speech",
141
+ "/Speech": "/Speech",
142
+ "BGM": "BGM",
143
+ "/BGM": "/BGM",
144
+ "Laughter": "Laughter",
145
+ "/Laughter": "/Laughter",
146
+ "Applause": "Applause",
147
+ "/Applause": "/Applause",
148
+ }
149
+
150
+ EMOTION = {
151
+ "HAPPY": "HAPPY",
152
+ "SAD": "SAD",
153
+ "ANGRY": "ANGRY",
154
+ "NEUTRAL": "NEUTRAL",
155
+ }
156
+
157
+ TTS_Vocal_Token = {
158
+ "TTS/B": "TTS/B",
159
+ "TTS/O": "TTS/O",
160
+ "TTS/Q": "TTS/Q",
161
+ "TTS/A": "TTS/A",
162
+ "TTS/CO": "TTS/CO",
163
+ "TTS/CL": "TTS/CL",
164
+ "TTS/H": "TTS/H",
165
+ **{f"TTS/SP{i:02d}": f"TTS/SP{i:02d}" for i in range(1, 14)}
166
+ }
167
+
168
+
169
+ @lru_cache(maxsize=None)
170
+ def get_encoding(name: str = "gpt2", num_languages: int = 99):
171
+ vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
172
+ ranks = {
173
+ base64.b64decode(token): int(rank)
174
+ for token, rank in (line.split() for line in open(vocab_path) if line)
175
+ }
176
+ n_vocab = len(ranks)
177
+ special_tokens = {}
178
+
179
+ specials = [
180
+ "<|endoftext|>",
181
+ "<|startoftranscript|>",
182
+ *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
183
+ *[f"<|{audio_event}|>" for audio_event in list(AUDIO_EVENT.keys())],
184
+ *[f"<|{emotion}|>" for emotion in list(EMOTION.keys())],
185
+ "<|translate|>",
186
+ "<|transcribe|>",
187
+ "<|startoflm|>",
188
+ "<|startofprev|>",
189
+ "<|nospeech|>",
190
+ "<|notimestamps|>",
191
+ *[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 31)], # register special tokens for ASR
192
+ *[f"<|{tts}|>" for tts in list(TTS_Vocal_Token.keys())], # register special tokens for TTS
193
+ *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
194
+ ]
195
+
196
+ for token in specials:
197
+ special_tokens[token] = n_vocab
198
+ n_vocab += 1
199
+
200
+ return tiktoken.Encoding(
201
+ name=os.path.basename(vocab_path),
202
+ explicit_n_vocab=n_vocab,
203
+ pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
204
+ mergeable_ranks=ranks,
205
+ special_tokens=special_tokens,
206
+ )
207
+
208
+
209
+ @lru_cache(maxsize=None)
210
+ def get_tokenizer(
211
+ multilingual: bool,
212
+ *,
213
+ num_languages: int = 99,
214
+ language: Optional[str] = None,
215
+ task: Optional[str] = None, # Literal["transcribe", "translate", None]
216
+ ) -> Tokenizer:
217
+ if language is not None:
218
+ language = language.lower()
219
+ if language not in LANGUAGES:
220
+ if language in TO_LANGUAGE_CODE:
221
+ language = TO_LANGUAGE_CODE[language]
222
+ else:
223
+ raise ValueError(f"Unsupported language: {language}")
224
+
225
+ if multilingual:
226
+ encoding_name = "multilingual_zh_ja_yue_char_del"
227
+ language = language or "en"
228
+ task = task or "transcribe"
229
+ else:
230
+ encoding_name = "gpt2"
231
+ language = None
232
+ task = None
233
+
234
+ encoding = get_encoding(name=encoding_name, num_languages=num_languages)
235
+
236
+ return Tokenizer(
237
+ encoding=encoding, num_languages=num_languages, language=language, task=task
238
+ )
239
+
240
+
241
+ class QwenTokenizer():
242
+ def __init__(self, token_path, skip_special_tokens=True):
243
+ super().__init__()
244
+ # NOTE: non-chat model, all these special tokens remain randomly initialized.
245
+ special_tokens = {
246
+ 'eos_token': '<|endoftext|>',
247
+ 'pad_token': '<|endoftext|>',
248
+ 'additional_special_tokens': [
249
+ '<|im_start|>', '<|im_end|>', '<|endofprompt|>',
250
+ '[breath]', '<strong>', '</strong>', '[noise]',
251
+ '[laughter]', '[cough]', '[clucking]', '[accent]',
252
+ '[quick_breath]',
253
+ "<laughter>", "</laughter>",
254
+ "[hissing]", "[sigh]", "[vocalized-noise]",
255
+ "[lipsmack]", "[mn]"
256
+ ]
257
+ }
258
+ self.tokenizer = AutoTokenizer.from_pretrained(token_path)
259
+ self.tokenizer.add_special_tokens(special_tokens)
260
+ self.skip_special_tokens = skip_special_tokens
261
+
262
+ def encode(self, text, **kwargs):
263
+ tokens = self.tokenizer([text], return_tensors="pt")
264
+ tokens = tokens["input_ids"][0].cpu().tolist()
265
+ return tokens
266
+
267
+ def decode(self, tokens):
268
+ tokens = torch.tensor(tokens, dtype=torch.int64)
269
+ text = self.tokenizer.batch_decode([tokens], skip_special_tokens=self.skip_special_tokens)[0]
270
+ return text
271
+
272
+ @lru_cache(maxsize=None)
273
+ def get_qwen_tokenizer(
274
+ token_path: str,
275
+ skip_special_tokens: bool
276
+ ) -> QwenTokenizer:
277
+ return QwenTokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
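A minimal round-trip sketch, not part of the committed code: the cached factory returns a wrapper whose `encode`/`decode` pair should drop the registered control tokens when `skip_special_tokens=True`. The tokenizer directory below is a placeholder path.

```python
# Sketch only: encoding text that mixes plain characters with a registered control token.
from cosyvoice.tokenizer.tokenizer import get_qwen_tokenizer

tokenizer = get_qwen_tokenizer(token_path="<qwen-tokenizer-dir>",  # placeholder path
                               skip_special_tokens=True)
ids = tokenizer.encode("你好,世界 [laughter]")
print(ids)
print(tokenizer.decode(ids))  # "[laughter]" is expected to be dropped (skip_special_tokens=True)
```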
cosyvoice/transformer/__init__.py ADDED
File without changes
cosyvoice/transformer/activation.py ADDED
@@ -0,0 +1,84 @@
1
+ # Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe)
2
+ # 2020 Northwestern Polytechnical University (Pengcheng Guo)
3
+ # 2020 Mobvoi Inc (Binbin Zhang)
4
+ # 2024 Alibaba Inc (Xiang Lyu)
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """Swish() activation function for Conformer."""
18
+
19
+ import torch
20
+ from torch import nn, sin, pow
21
+ from torch.nn import Parameter
22
+
23
+
24
+ class Swish(torch.nn.Module):
25
+ """Construct an Swish object."""
26
+
27
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
28
+ """Return Swish activation function."""
29
+ return x * torch.sigmoid(x)
30
+
31
+
32
+ # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
33
+ # LICENSE is in incl_licenses directory.
34
+ class Snake(nn.Module):
35
+ '''
36
+ Implementation of a sine-based periodic activation function
37
+ Shape:
38
+ - Input: (B, C, T)
39
+ - Output: (B, C, T), same shape as the input
40
+ Parameters:
41
+ - alpha - trainable parameter
42
+ References:
43
+ - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
44
+ https://arxiv.org/abs/2006.08195
45
+ Examples:
46
+ >>> a1 = Snake(256)
47
+ >>> x = torch.randn(256)
48
+ >>> x = a1(x)
49
+ '''
50
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
51
+ '''
52
+ Initialization.
53
+ INPUT:
54
+ - in_features: shape of the input
55
+ - alpha: trainable parameter
56
+ alpha is initialized to 1 by default, higher values = higher-frequency.
57
+ alpha will be trained along with the rest of your model.
58
+ '''
59
+ super(Snake, self).__init__()
60
+ self.in_features = in_features
61
+
62
+ # initialize alpha
63
+ self.alpha_logscale = alpha_logscale
64
+ if self.alpha_logscale: # log scale alphas initialized to zeros
65
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
66
+ else: # linear scale alphas initialized to ones
67
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
68
+
69
+ self.alpha.requires_grad = alpha_trainable
70
+
71
+ self.no_div_by_zero = 0.000000001
72
+
73
+ def forward(self, x):
74
+ '''
75
+ Forward pass of the function.
76
+ Applies the function to the input elementwise.
77
+ Snake ∶= x + 1/a * sin^2 (xa)
78
+ '''
79
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
80
+ if self.alpha_logscale:
81
+ alpha = torch.exp(alpha)
82
+ x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
83
+
84
+ return x
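A minimal sanity-check sketch, not part of the committed code: Snake computes x + (1/alpha) * sin^2(alpha * x) with one learnable alpha per channel, so a zero input passes through unchanged and a unit input (with the default alpha = 1) is shifted by sin(1)^2.

```python
# Sketch only: checking the Snake activation on a (B, C, T) tensor.
import torch
from cosyvoice.transformer.activation import Snake

act = Snake(in_features=4)            # one alpha per channel, initialized to 1.0
x = torch.zeros(2, 4, 8)
print(torch.allclose(act(x), x))      # True: sin^2(0) == 0, so zeros are preserved
y = act(torch.ones(2, 4, 8))          # for alpha = 1: 1 + sin(1)^2, about 1.708 everywhere
print(y[0, 0, 0].item())
```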
cosyvoice/transformer/attention.py ADDED
@@ -0,0 +1,330 @@
1
+ # Copyright (c) 2019 Shigeki Karita
2
+ # 2020 Mobvoi Inc (Binbin Zhang)
3
+ # 2022 Xingchen Song ([email protected])
4
+ # 2024 Alibaba Inc (Xiang Lyu)
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """Multi-Head Attention layer definition."""
18
+
19
+ import math
20
+ from typing import Tuple
21
+
22
+ import torch
23
+ from torch import nn
24
+
25
+
26
+ class MultiHeadedAttention(nn.Module):
27
+ """Multi-Head Attention layer.
28
+
29
+ Args:
30
+ n_head (int): The number of heads.
31
+ n_feat (int): The number of features.
32
+ dropout_rate (float): Dropout rate.
33
+
34
+ """
35
+
36
+ def __init__(self,
37
+ n_head: int,
38
+ n_feat: int,
39
+ dropout_rate: float,
40
+ key_bias: bool = True):
41
+ """Construct an MultiHeadedAttention object."""
42
+ super().__init__()
43
+ assert n_feat % n_head == 0
44
+ # We assume d_v always equals d_k
45
+ self.d_k = n_feat // n_head
46
+ self.h = n_head
47
+ self.linear_q = nn.Linear(n_feat, n_feat)
48
+ self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
49
+ self.linear_v = nn.Linear(n_feat, n_feat)
50
+ self.linear_out = nn.Linear(n_feat, n_feat)
51
+ self.dropout = nn.Dropout(p=dropout_rate)
52
+
53
+ def forward_qkv(
54
+ self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
55
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
56
+ """Transform query, key and value.
57
+
58
+ Args:
59
+ query (torch.Tensor): Query tensor (#batch, time1, size).
60
+ key (torch.Tensor): Key tensor (#batch, time2, size).
61
+ value (torch.Tensor): Value tensor (#batch, time2, size).
62
+
63
+ Returns:
64
+ torch.Tensor: Transformed query tensor, size
65
+ (#batch, n_head, time1, d_k).
66
+ torch.Tensor: Transformed key tensor, size
67
+ (#batch, n_head, time2, d_k).
68
+ torch.Tensor: Transformed value tensor, size
69
+ (#batch, n_head, time2, d_k).
70
+
71
+ """
72
+ n_batch = query.size(0)
73
+ q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
74
+ k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
75
+ v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
76
+ q = q.transpose(1, 2) # (batch, head, time1, d_k)
77
+ k = k.transpose(1, 2) # (batch, head, time2, d_k)
78
+ v = v.transpose(1, 2) # (batch, head, time2, d_k)
79
+
80
+ return q, k, v
81
+
82
+ def forward_attention(
83
+ self,
84
+ value: torch.Tensor,
85
+ scores: torch.Tensor,
86
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
87
+ ) -> torch.Tensor:
88
+ """Compute attention context vector.
89
+
90
+ Args:
91
+ value (torch.Tensor): Transformed value, size
92
+ (#batch, n_head, time2, d_k).
93
+ scores (torch.Tensor): Attention score, size
94
+ (#batch, n_head, time1, time2).
95
+ mask (torch.Tensor): Mask, size (#batch, 1, time2) or
96
+ (#batch, time1, time2), (0, 0, 0) means fake mask.
97
+
98
+ Returns:
99
+ torch.Tensor: Transformed value (#batch, time1, d_model)
100
+ weighted by the attention score (#batch, time1, time2).
101
+
102
+ """
103
+ n_batch = value.size(0)
104
+ # NOTE(xcsong): When will `if mask.size(2) > 0` be True?
105
+ # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
106
+ # 1st chunk to ease the onnx export.]
107
+ # 2. pytorch training
108
+ if mask.size(2) > 0: # time2 > 0
109
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
110
+ # For last chunk, time2 might be larger than scores.size(-1)
111
+ mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
112
+ scores = scores.masked_fill(mask, -float('inf'))
113
+ attn = torch.softmax(scores, dim=-1).masked_fill(
114
+ mask, 0.0) # (batch, head, time1, time2)
115
+ # NOTE(xcsong): When will `if mask.size(2) > 0` be False?
116
+ # 1. onnx(16/-1, -1/-1, 16/0)
117
+ # 2. jit (16/-1, -1/-1, 16/0, 16/4)
118
+ else:
119
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
120
+
121
+ p_attn = self.dropout(attn)
122
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
123
+ x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
124
+ self.h * self.d_k)
125
+ ) # (batch, time1, d_model)
126
+
127
+ return self.linear_out(x) # (batch, time1, d_model)
128
+
129
+ def forward(
130
+ self,
131
+ query: torch.Tensor,
132
+ key: torch.Tensor,
133
+ value: torch.Tensor,
134
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
135
+ pos_emb: torch.Tensor = torch.empty(0),
136
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
137
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
138
+ """Compute scaled dot product attention.
139
+
140
+ Args:
141
+ query (torch.Tensor): Query tensor (#batch, time1, size).
142
+ key (torch.Tensor): Key tensor (#batch, time2, size).
143
+ value (torch.Tensor): Value tensor (#batch, time2, size).
144
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
145
+ (#batch, time1, time2).
146
+ 1.When applying cross attention between decoder and encoder,
147
+ the batch padding mask for input is in (#batch, 1, T) shape.
148
+ 2.When applying self attention of encoder,
149
+ the mask is in (#batch, T, T) shape.
150
+ 3.When applying self attention of decoder,
151
+ the mask is in (#batch, L, L) shape.
152
+ 4.If the different position in decoder see different block
153
+ of the encoder, such as Mocha, the passed in mask could be
154
+ in (#batch, L, T) shape. But there is no such case in current
155
+ CosyVoice.
156
+ cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
157
+ where `cache_t == chunk_size * num_decoding_left_chunks`
158
+ and `head * d_k == size`
159
+
160
+
161
+ Returns:
162
+ torch.Tensor: Output tensor (#batch, time1, d_model).
163
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
164
+ where `cache_t == chunk_size * num_decoding_left_chunks`
165
+ and `head * d_k == size`
166
+
167
+ """
168
+ q, k, v = self.forward_qkv(query, key, value)
169
+
170
+ # NOTE(xcsong):
171
+ # when export onnx model, for 1st chunk, we feed
172
+ # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
173
+ # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
174
+ # In all modes, `if cache.size(0) > 0` will always be `True`
175
+ # and we will always do splitting and
176
+ # concatenation (this will simplify onnx export). Note that
177
+ # it's OK to concat & split zero-shaped tensors(see code below).
178
+ # when export jit model, for 1st chunk, we always feed
179
+ # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
180
+ # >>> a = torch.ones((1, 2, 0, 4))
181
+ # >>> b = torch.ones((1, 2, 3, 4))
182
+ # >>> c = torch.cat((a, b), dim=2)
183
+ # >>> torch.equal(b, c) # True
184
+ # >>> d = torch.split(a, 2, dim=-1)
185
+ # >>> torch.equal(d[0], d[1]) # True
186
+ if cache.size(0) > 0:
187
+ key_cache, value_cache = torch.split(cache,
188
+ cache.size(-1) // 2,
189
+ dim=-1)
190
+ k = torch.cat([key_cache, k], dim=2)
191
+ v = torch.cat([value_cache, v], dim=2)
192
+ # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
193
+ # non-trivial to calculate `next_cache_start` here.
194
+ new_cache = torch.cat((k, v), dim=-1)
195
+
196
+ scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
197
+ return self.forward_attention(v, scores, mask), new_cache
198
+
199
+
200
+ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
201
+ """Multi-Head Attention layer with relative position encoding.
202
+ Paper: https://arxiv.org/abs/1901.02860
203
+ Args:
204
+ n_head (int): The number of heads.
205
+ n_feat (int): The number of features.
206
+ dropout_rate (float): Dropout rate.
207
+ """
208
+
209
+ def __init__(self,
210
+ n_head: int,
211
+ n_feat: int,
212
+ dropout_rate: float,
213
+ key_bias: bool = True):
214
+ """Construct an RelPositionMultiHeadedAttention object."""
215
+ super().__init__(n_head, n_feat, dropout_rate, key_bias)
216
+ # linear transformation for positional encoding
217
+ self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
218
+ # these two learnable bias are used in matrix c and matrix d
219
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
220
+ self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
221
+ self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
222
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
223
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
224
+
225
+ def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
226
+ """Compute relative positional encoding.
227
+
228
+ Args:
229
+ x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
230
+ time1 means the length of query vector.
231
+
232
+ Returns:
233
+ torch.Tensor: Output tensor.
234
+
235
+ """
236
+ zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
237
+ device=x.device,
238
+ dtype=x.dtype)
239
+ x_padded = torch.cat([zero_pad, x], dim=-1)
240
+
241
+ x_padded = x_padded.view(x.size()[0],
242
+ x.size()[1],
243
+ x.size(3) + 1, x.size(2))
244
+ x = x_padded[:, :, 1:].view_as(x)[
245
+ :, :, :, : x.size(-1) // 2 + 1
246
+ ] # only keep the positions from 0 to time2
247
+ return x
248
+
249
+ def forward(
250
+ self,
251
+ query: torch.Tensor,
252
+ key: torch.Tensor,
253
+ value: torch.Tensor,
254
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
255
+ pos_emb: torch.Tensor = torch.empty(0),
256
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
257
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
258
+ """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
259
+ Args:
260
+ query (torch.Tensor): Query tensor (#batch, time1, size).
261
+ key (torch.Tensor): Key tensor (#batch, time2, size).
262
+ value (torch.Tensor): Value tensor (#batch, time2, size).
263
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
264
+ (#batch, time1, time2), (0, 0, 0) means fake mask.
265
+ pos_emb (torch.Tensor): Positional embedding tensor
266
+ (#batch, time2, size).
267
+ cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
268
+ where `cache_t == chunk_size * num_decoding_left_chunks`
269
+ and `head * d_k == size`
270
+ Returns:
271
+ torch.Tensor: Output tensor (#batch, time1, d_model).
272
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
273
+ where `cache_t == chunk_size * num_decoding_left_chunks`
274
+ and `head * d_k == size`
275
+ """
276
+ q, k, v = self.forward_qkv(query, key, value)
277
+ q = q.transpose(1, 2) # (batch, time1, head, d_k)
278
+
279
+ # NOTE(xcsong):
280
+ # when export onnx model, for 1st chunk, we feed
281
+ # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
282
+ # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
283
+ # In all modes, `if cache.size(0) > 0` will always be `True`
284
+ # and we will always do splitting and
285
+ # concatenation (this will simplify onnx export). Note that
286
+ # it's OK to concat & split zero-shaped tensors(see code below).
287
+ # when export jit model, for 1st chunk, we always feed
288
+ # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
289
+ # >>> a = torch.ones((1, 2, 0, 4))
290
+ # >>> b = torch.ones((1, 2, 3, 4))
291
+ # >>> c = torch.cat((a, b), dim=2)
292
+ # >>> torch.equal(b, c) # True
293
+ # >>> d = torch.split(a, 2, dim=-1)
294
+ # >>> torch.equal(d[0], d[1]) # True
295
+ if cache.size(0) > 0:
296
+ key_cache, value_cache = torch.split(cache,
297
+ cache.size(-1) // 2,
298
+ dim=-1)
299
+ k = torch.cat([key_cache, k], dim=2)
300
+ v = torch.cat([value_cache, v], dim=2)
301
+ # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
302
+ # non-trivial to calculate `next_cache_start` here.
303
+ new_cache = torch.cat((k, v), dim=-1)
304
+
305
+ n_batch_pos = pos_emb.size(0)
306
+ p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
307
+ p = p.transpose(1, 2) # (batch, head, time1, d_k)
308
+
309
+ # (batch, head, time1, d_k)
310
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
311
+ # (batch, head, time1, d_k)
312
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
313
+
314
+ # compute attention score
315
+ # first compute matrix a and matrix c
316
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
317
+ # (batch, head, time1, time2)
318
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
319
+
320
+ # compute matrix b and matrix d
321
+ # (batch, head, time1, time2)
322
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
323
+ # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
324
+ if matrix_ac.shape != matrix_bd.shape:
325
+ matrix_bd = self.rel_shift(matrix_bd)
326
+
327
+ scores = (matrix_ac + matrix_bd) / math.sqrt(
328
+ self.d_k) # (batch, head, time1, time2)
329
+
330
+ return self.forward_attention(v, scores, mask), new_cache
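
Editor's note: the NOTE(xcsong) comments above describe why the attention cache is packed as (1, head, cache_t, d_k * 2) and why an empty cache can still be split and concatenated. Below is a minimal standalone sketch of that behaviour (plain PyTorch, not part of this commit; the head / d_k / chunk sizes are illustrative assumptions).

```python
# Sketch of the KV-cache layout used by MultiHeadedAttention.forward:
# keys and values are packed along the last dim, and concatenating a
# zero-length cache is a harmless no-op (which keeps ONNX export branch-free).
import torch

head, d_k, chunk = 2, 4, 3

# 1st chunk: cache with cache_t == 0
empty_cache = torch.zeros((1, head, 0, d_k * 2))
k = torch.randn(1, head, chunk, d_k)
v = torch.randn(1, head, chunk, d_k)

key_cache, value_cache = torch.split(empty_cache, empty_cache.size(-1) // 2, dim=-1)
assert torch.equal(torch.cat([key_cache, k], dim=2), k)   # concat of empty cache is a no-op
assert torch.equal(torch.cat([value_cache, v], dim=2), v)

# the new cache packs key and value along the last dimension ...
new_cache = torch.cat((k, v), dim=-1)                     # (1, head, chunk, d_k * 2)

# ... and the next chunk recovers them with a single split
k2, v2 = torch.split(new_cache, new_cache.size(-1) // 2, dim=-1)
assert torch.equal(k2, k) and torch.equal(v2, v)
print(new_cache.shape)                                    # torch.Size([1, 2, 3, 8])
```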
cosyvoice/transformer/convolution.py ADDED
@@ -0,0 +1,145 @@
1
+ # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
2
+ # 2024 Alibaba Inc (Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """ConvolutionModule definition."""
17
+
18
+ from typing import Tuple
19
+
20
+ import torch
21
+ from torch import nn
22
+
23
+
24
+ class ConvolutionModule(nn.Module):
25
+ """ConvolutionModule in Conformer model."""
26
+
27
+ def __init__(self,
28
+ channels: int,
29
+ kernel_size: int = 15,
30
+ activation: nn.Module = nn.ReLU(),
31
+ norm: str = "batch_norm",
32
+ causal: bool = False,
33
+ bias: bool = True):
34
+ """Construct an ConvolutionModule object.
35
+ Args:
36
+ channels (int): The number of channels of conv layers.
37
+ kernel_size (int): Kernel size of conv layers.
38
+ causal (bool): Whether to use causal convolution or not
39
+ """
40
+ super().__init__()
41
+
42
+ self.pointwise_conv1 = nn.Conv1d(
43
+ channels,
44
+ 2 * channels,
45
+ kernel_size=1,
46
+ stride=1,
47
+ padding=0,
48
+ bias=bias,
49
+ )
50
+ # self.lorder is used to distinguish if it's a causal convolution,
51
+ # if self.lorder > 0: it's a causal convolution, the input will be
52
+ # padded with self.lorder frames on the left in forward.
53
+ # else: it's a symmetrical convolution
54
+ if causal:
55
+ padding = 0
56
+ self.lorder = kernel_size - 1
57
+ else:
58
+ # kernel_size should be an odd number for non-causal convolution
59
+ assert (kernel_size - 1) % 2 == 0
60
+ padding = (kernel_size - 1) // 2
61
+ self.lorder = 0
62
+ self.depthwise_conv = nn.Conv1d(
63
+ channels,
64
+ channels,
65
+ kernel_size,
66
+ stride=1,
67
+ padding=padding,
68
+ groups=channels,
69
+ bias=bias,
70
+ )
71
+
72
+ assert norm in ['batch_norm', 'layer_norm']
73
+ if norm == "batch_norm":
74
+ self.use_layer_norm = False
75
+ self.norm = nn.BatchNorm1d(channels)
76
+ else:
77
+ self.use_layer_norm = True
78
+ self.norm = nn.LayerNorm(channels)
79
+
80
+ self.pointwise_conv2 = nn.Conv1d(
81
+ channels,
82
+ channels,
83
+ kernel_size=1,
84
+ stride=1,
85
+ padding=0,
86
+ bias=bias,
87
+ )
88
+ self.activation = activation
89
+
90
+ def forward(
91
+ self,
92
+ x: torch.Tensor,
93
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
94
+ cache: torch.Tensor = torch.zeros((0, 0, 0)),
95
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
96
+ """Compute convolution module.
97
+ Args:
98
+ x (torch.Tensor): Input tensor (#batch, time, channels).
99
+ mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
100
+ (0, 0, 0) means fake mask.
101
+ cache (torch.Tensor): left context cache, it is only
102
+ used in causal convolution (#batch, channels, cache_t),
103
+ (0, 0, 0) means fake cache.
104
+ Returns:
105
+ torch.Tensor: Output tensor (#batch, time, channels).
106
+ """
107
+ # exchange the temporal dimension and the feature dimension
108
+ x = x.transpose(1, 2) # (#batch, channels, time)
109
+
110
+ # mask batch padding
111
+ if mask_pad.size(2) > 0: # time > 0
112
+ x.masked_fill_(~mask_pad, 0.0)
113
+
114
+ if self.lorder > 0:
115
+ if cache.size(2) == 0: # cache_t == 0
116
+ x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
117
+ else:
118
+ assert cache.size(0) == x.size(0) # equal batch
119
+ assert cache.size(1) == x.size(1) # equal channel
120
+ x = torch.cat((cache, x), dim=2)
121
+ assert (x.size(2) > self.lorder)
122
+ new_cache = x[:, :, -self.lorder:]
123
+ else:
124
+ # It would be better to just return None if no cache is required,
125
+ # However, for JIT export, here we just fake one tensor instead of
126
+ # None.
127
+ new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
128
+
129
+ # GLU mechanism
130
+ x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
131
+ x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
132
+
133
+ # 1D Depthwise Conv
134
+ x = self.depthwise_conv(x)
135
+ if self.use_layer_norm:
136
+ x = x.transpose(1, 2)
137
+ x = self.activation(self.norm(x))
138
+ if self.use_layer_norm:
139
+ x = x.transpose(1, 2)
140
+ x = self.pointwise_conv2(x)
141
+ # mask batch padding
142
+ if mask_pad.size(2) > 0: # time > 0
143
+ x.masked_fill_(~mask_pad, 0.0)
144
+
145
+ return x.transpose(1, 2), new_cache
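
Editor's note: a small standalone sketch (assumed toy values, not part of this commit) of why the causal branch in ConvolutionModule uses padding=0 together with self.lorder = kernel_size - 1 frames of left context: left-padding keeps the output length equal to the input length while the convolution only looks at past frames, and the last lorder input frames become the cache for the next chunk.

```python
# Causal depthwise conv with explicit left padding, mirroring the forward above.
import torch
from torch import nn

channels, kernel_size, time = 8, 15, 20
lorder = kernel_size - 1                                  # left context, as in the causal branch

depthwise = nn.Conv1d(channels, channels, kernel_size,
                      stride=1, padding=0, groups=channels)

x = torch.randn(1, channels, time)                        # (batch, channels, time)
x_padded = nn.functional.pad(x, (lorder, 0), 'constant', 0.0)
y = depthwise(x_padded)
print(y.shape)                                            # torch.Size([1, 8, 20]) -> same length as x

# streaming: keep the last `lorder` frames as left context for the next chunk,
# mirroring `new_cache = x[:, :, -self.lorder:]` in the forward above
cache = x_padded[:, :, -lorder:]
print(cache.shape)                                        # torch.Size([1, 8, 14])
```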
cosyvoice/transformer/decoder.py ADDED
@@ -0,0 +1,396 @@
1
+ # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
2
+ # 2024 Alibaba Inc (Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """Decoder definition."""
17
+ from typing import Tuple, List, Optional
18
+
19
+ import torch
20
+ import torch.utils.checkpoint as ckpt
21
+ import logging
22
+
23
+ from cosyvoice.transformer.decoder_layer import DecoderLayer
24
+ from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward
25
+ from cosyvoice.utils.class_utils import (
26
+ COSYVOICE_EMB_CLASSES,
27
+ COSYVOICE_ATTENTION_CLASSES,
28
+ COSYVOICE_ACTIVATION_CLASSES,
29
+ )
30
+ from cosyvoice.utils.mask import (subsequent_mask, make_pad_mask)
31
+
32
+
33
+ class TransformerDecoder(torch.nn.Module):
34
+ """Base class of Transfomer decoder module.
35
+ Args:
36
+ vocab_size: output dim
37
+ encoder_output_size: dimension of attention
38
+ attention_heads: the number of heads of multi head attention
39
+ linear_units: the hidden units number of position-wise feedforward
40
+ num_blocks: the number of decoder blocks
41
+ dropout_rate: dropout rate
42
+ self_attention_dropout_rate: dropout rate for attention
43
+ input_layer: input layer type
44
+ use_output_layer: whether to use output layer
45
+ pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
46
+ normalize_before:
47
+ True: use layer_norm before each sub-block of a layer.
48
+ False: use layer_norm after each sub-block of a layer.
49
+ src_attention: if false, encoder-decoder cross attention is not
50
+ applied, such as CIF model
51
+ key_bias: whether use bias in attention.linear_k, False for whisper models.
52
+ gradient_checkpointing: rerunning a forward-pass segment for each
53
+ checkpointed segment during backward.
54
+ tie_word_embedding: Tie or clone module weights depending on whether we are
55
+ using TorchScript or not
56
+ """
57
+
58
+ def __init__(
59
+ self,
60
+ vocab_size: int,
61
+ encoder_output_size: int,
62
+ attention_heads: int = 4,
63
+ linear_units: int = 2048,
64
+ num_blocks: int = 6,
65
+ dropout_rate: float = 0.1,
66
+ positional_dropout_rate: float = 0.1,
67
+ self_attention_dropout_rate: float = 0.0,
68
+ src_attention_dropout_rate: float = 0.0,
69
+ input_layer: str = "embed",
70
+ use_output_layer: bool = True,
71
+ normalize_before: bool = True,
72
+ src_attention: bool = True,
73
+ key_bias: bool = True,
74
+ activation_type: str = "relu",
75
+ gradient_checkpointing: bool = False,
76
+ tie_word_embedding: bool = False,
77
+ ):
78
+ super().__init__()
79
+ attention_dim = encoder_output_size
80
+ activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
81
+
82
+ self.embed = torch.nn.Sequential(
83
+ torch.nn.Identity() if input_layer == "no_pos" else
84
+ torch.nn.Embedding(vocab_size, attention_dim),
85
+ COSYVOICE_EMB_CLASSES[input_layer](attention_dim,
86
+ positional_dropout_rate),
87
+ )
88
+
89
+ self.normalize_before = normalize_before
90
+ self.after_norm = torch.nn.LayerNorm(attention_dim, eps=1e-5)
91
+ self.use_output_layer = use_output_layer
92
+ if use_output_layer:
93
+ self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
94
+ else:
95
+ self.output_layer = torch.nn.Identity()
96
+ self.num_blocks = num_blocks
97
+ self.decoders = torch.nn.ModuleList([
98
+ DecoderLayer(
99
+ attention_dim,
100
+ COSYVOICE_ATTENTION_CLASSES["selfattn"](
101
+ attention_heads, attention_dim,
102
+ self_attention_dropout_rate, key_bias),
103
+ COSYVOICE_ATTENTION_CLASSES["selfattn"](
104
+ attention_heads, attention_dim, src_attention_dropout_rate,
105
+ key_bias) if src_attention else None,
106
+ PositionwiseFeedForward(attention_dim, linear_units,
107
+ dropout_rate, activation),
108
+ dropout_rate,
109
+ normalize_before,
110
+ ) for _ in range(self.num_blocks)
111
+ ])
112
+
113
+ self.gradient_checkpointing = gradient_checkpointing
114
+ self.tie_word_embedding = tie_word_embedding
115
+
116
+ def forward(
117
+ self,
118
+ memory: torch.Tensor,
119
+ memory_mask: torch.Tensor,
120
+ ys_in_pad: torch.Tensor,
121
+ ys_in_lens: torch.Tensor,
122
+ r_ys_in_pad: torch.Tensor = torch.empty(0),
123
+ reverse_weight: float = 0.0,
124
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
125
+ """Forward decoder.
126
+ Args:
127
+ memory: encoded memory, float32 (batch, maxlen_in, feat)
128
+ memory_mask: encoder memory mask, (batch, 1, maxlen_in)
129
+ ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
130
+ ys_in_lens: input lengths of this batch (batch)
131
+ r_ys_in_pad: not used in transformer decoder, in order to unify api
132
+ with bidirectional decoder
133
+ reverse_weight: not used in transformer decoder, in order to unify
134
+ api with bidirectional decoder
135
+ Returns:
136
+ (tuple): tuple containing:
137
+ x: decoded token score before softmax (batch, maxlen_out,
138
+ vocab_size) if use_output_layer is True,
139
+ torch.tensor(0.0), in order to unify api with bidirectional decoder
140
+ olens: (batch, )
141
+ NOTE(xcsong):
142
+ We pass the `__call__` method of the modules instead of `forward` to the
143
+ checkpointing API because `__call__` attaches all the hooks of the module.
144
+ https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
145
+ """
146
+ tgt = ys_in_pad
147
+ maxlen = tgt.size(1)
148
+ # tgt_mask: (B, 1, L)
149
+ tgt_mask = ~make_pad_mask(ys_in_lens, maxlen).unsqueeze(1)
150
+ tgt_mask = tgt_mask.to(tgt.device)
151
+ # m: (1, L, L)
152
+ m = subsequent_mask(tgt_mask.size(-1),
153
+ device=tgt_mask.device).unsqueeze(0)
154
+ # tgt_mask: (B, L, L)
155
+ tgt_mask = tgt_mask & m
156
+ x, _ = self.embed(tgt)
157
+ if self.gradient_checkpointing and self.training:
158
+ x = self.forward_layers_checkpointed(x, tgt_mask, memory,
159
+ memory_mask)
160
+ else:
161
+ x = self.forward_layers(x, tgt_mask, memory, memory_mask)
162
+ if self.normalize_before:
163
+ x = self.after_norm(x)
164
+ if self.use_output_layer:
165
+ x = self.output_layer(x)
166
+ olens = tgt_mask.sum(1)
167
+ return x, torch.tensor(0.0), olens
168
+
169
+ def forward_layers(self, x: torch.Tensor, tgt_mask: torch.Tensor,
170
+ memory: torch.Tensor,
171
+ memory_mask: torch.Tensor) -> torch.Tensor:
172
+ for layer in self.decoders:
173
+ x, tgt_mask, memory, memory_mask = layer(x, tgt_mask, memory,
174
+ memory_mask)
175
+ return x
176
+
177
+ @torch.jit.unused
178
+ def forward_layers_checkpointed(self, x: torch.Tensor,
179
+ tgt_mask: torch.Tensor,
180
+ memory: torch.Tensor,
181
+ memory_mask: torch.Tensor) -> torch.Tensor:
182
+ for layer in self.decoders:
183
+ x, tgt_mask, memory, memory_mask = ckpt.checkpoint(
184
+ layer.__call__, x, tgt_mask, memory, memory_mask)
185
+ return x
186
+
187
+ def forward_one_step(
188
+ self,
189
+ memory: torch.Tensor,
190
+ memory_mask: torch.Tensor,
191
+ tgt: torch.Tensor,
192
+ tgt_mask: torch.Tensor,
193
+ cache: Optional[List[torch.Tensor]] = None,
194
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
195
+ """Forward one step.
196
+ This is only used for decoding.
197
+ Args:
198
+ memory: encoded memory, float32 (batch, maxlen_in, feat)
199
+ memory_mask: encoded memory mask, (batch, 1, maxlen_in)
200
+ tgt: input token ids, int64 (batch, maxlen_out)
201
+ tgt_mask: input token mask, (batch, maxlen_out)
202
+ dtype=torch.uint8 in PyTorch 1.2-
203
+ dtype=torch.bool in PyTorch 1.2+ (include 1.2)
204
+ cache: cached output list of (batch, max_time_out-1, size)
205
+ Returns:
206
+ y, cache: NN output value and cache per `self.decoders`.
207
+ y.shape` is (batch, maxlen_out, token)
208
+ """
209
+ x, _ = self.embed(tgt)
210
+ new_cache = []
211
+ for i, decoder in enumerate(self.decoders):
212
+ if cache is None:
213
+ c = None
214
+ else:
215
+ c = cache[i]
216
+ x, tgt_mask, memory, memory_mask = decoder(x,
217
+ tgt_mask,
218
+ memory,
219
+ memory_mask,
220
+ cache=c)
221
+ new_cache.append(x)
222
+ if self.normalize_before:
223
+ y = self.after_norm(x[:, -1])
224
+ else:
225
+ y = x[:, -1]
226
+ if self.use_output_layer:
227
+ y = torch.log_softmax(self.output_layer(y), dim=-1)
228
+ return y, new_cache
229
+
230
+ def tie_or_clone_weights(self, jit_mode: bool = True):
231
+ """Tie or clone module weights (between word_emb and output_layer)
232
+ depending on whether we are using TorchScript or not"""
233
+ if not self.use_output_layer:
234
+ return
235
+ if jit_mode:
236
+ logging.info("clone emb.weight to output.weight")
237
+ self.output_layer.weight = torch.nn.Parameter(
238
+ self.embed[0].weight.clone())
239
+ else:
240
+ logging.info("tie emb.weight with output.weight")
241
+ self.output_layer.weight = self.embed[0].weight
242
+
243
+ if getattr(self.output_layer, "bias", None) is not None:
244
+ self.output_layer.bias.data = torch.nn.functional.pad(
245
+ self.output_layer.bias.data,
246
+ (
247
+ 0,
248
+ self.output_layer.weight.shape[0] -
249
+ self.output_layer.bias.shape[0],
250
+ ),
251
+ "constant",
252
+ 0,
253
+ )
254
+
255
+
256
+ class BiTransformerDecoder(torch.nn.Module):
257
+ """Base class of Transfomer decoder module.
258
+ Args:
259
+ vocab_size: output dim
260
+ encoder_output_size: dimension of attention
261
+ attention_heads: the number of heads of multi head attention
262
+ linear_units: the hidden units number of position-wise feedforward
263
+ num_blocks: the number of decoder blocks
264
+ r_num_blocks: the number of right to left decoder blocks
265
+ dropout_rate: dropout rate
266
+ self_attention_dropout_rate: dropout rate for attention
267
+ input_layer: input layer type
268
+ use_output_layer: whether to use output layer
269
+ pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
270
+ normalize_before:
271
+ True: use layer_norm before each sub-block of a layer.
272
+ False: use layer_norm after each sub-block of a layer.
273
+ key_bias: whether use bias in attention.linear_k, False for whisper models.
274
+ """
275
+
276
+ def __init__(
277
+ self,
278
+ vocab_size: int,
279
+ encoder_output_size: int,
280
+ attention_heads: int = 4,
281
+ linear_units: int = 2048,
282
+ num_blocks: int = 6,
283
+ r_num_blocks: int = 0,
284
+ dropout_rate: float = 0.1,
285
+ positional_dropout_rate: float = 0.1,
286
+ self_attention_dropout_rate: float = 0.0,
287
+ src_attention_dropout_rate: float = 0.0,
288
+ input_layer: str = "embed",
289
+ use_output_layer: bool = True,
290
+ normalize_before: bool = True,
291
+ key_bias: bool = True,
292
+ gradient_checkpointing: bool = False,
293
+ tie_word_embedding: bool = False,
294
+ ):
295
+
296
+ super().__init__()
297
+ self.tie_word_embedding = tie_word_embedding
298
+ self.left_decoder = TransformerDecoder(
299
+ vocab_size,
300
+ encoder_output_size,
301
+ attention_heads,
302
+ linear_units,
303
+ num_blocks,
304
+ dropout_rate,
305
+ positional_dropout_rate,
306
+ self_attention_dropout_rate,
307
+ src_attention_dropout_rate,
308
+ input_layer,
309
+ use_output_layer,
310
+ normalize_before,
311
+ key_bias=key_bias,
312
+ gradient_checkpointing=gradient_checkpointing,
313
+ tie_word_embedding=tie_word_embedding)
314
+
315
+ self.right_decoder = TransformerDecoder(
316
+ vocab_size,
317
+ encoder_output_size,
318
+ attention_heads,
319
+ linear_units,
320
+ r_num_blocks,
321
+ dropout_rate,
322
+ positional_dropout_rate,
323
+ self_attention_dropout_rate,
324
+ src_attention_dropout_rate,
325
+ input_layer,
326
+ use_output_layer,
327
+ normalize_before,
328
+ key_bias=key_bias,
329
+ gradient_checkpointing=gradient_checkpointing,
330
+ tie_word_embedding=tie_word_embedding)
331
+
332
+ def forward(
333
+ self,
334
+ memory: torch.Tensor,
335
+ memory_mask: torch.Tensor,
336
+ ys_in_pad: torch.Tensor,
337
+ ys_in_lens: torch.Tensor,
338
+ r_ys_in_pad: torch.Tensor,
339
+ reverse_weight: float = 0.0,
340
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
341
+ """Forward decoder.
342
+ Args:
343
+ memory: encoded memory, float32 (batch, maxlen_in, feat)
344
+ memory_mask: encoder memory mask, (batch, 1, maxlen_in)
345
+ ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
346
+ ys_in_lens: input lengths of this batch (batch)
347
+ r_ys_in_pad: padded input token ids, int64 (batch, maxlen_out),
348
+ used for right to left decoder
349
+ reverse_weight: used for right to left decoder
350
+ Returns:
351
+ (tuple): tuple containing:
352
+ x: decoded token score before softmax (batch, maxlen_out,
353
+ vocab_size) if use_output_layer is True,
354
+ r_x: x: decoded token score (right to left decoder)
355
+ before softmax (batch, maxlen_out, vocab_size)
356
+ if use_output_layer is True,
357
+ olens: (batch, )
358
+ """
359
+ l_x, _, olens = self.left_decoder(memory, memory_mask, ys_in_pad,
360
+ ys_in_lens)
361
+ r_x = torch.tensor(0.0)
362
+ if reverse_weight > 0.0:
363
+ r_x, _, olens = self.right_decoder(memory, memory_mask,
364
+ r_ys_in_pad, ys_in_lens)
365
+ return l_x, r_x, olens
366
+
367
+ def forward_one_step(
368
+ self,
369
+ memory: torch.Tensor,
370
+ memory_mask: torch.Tensor,
371
+ tgt: torch.Tensor,
372
+ tgt_mask: torch.Tensor,
373
+ cache: Optional[List[torch.Tensor]] = None,
374
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
375
+ """Forward one step.
376
+ This is only used for decoding.
377
+ Args:
378
+ memory: encoded memory, float32 (batch, maxlen_in, feat)
379
+ memory_mask: encoded memory mask, (batch, 1, maxlen_in)
380
+ tgt: input token ids, int64 (batch, maxlen_out)
381
+ tgt_mask: input token mask, (batch, maxlen_out)
382
+ dtype=torch.uint8 in PyTorch 1.2-
383
+ dtype=torch.bool in PyTorch 1.2+ (include 1.2)
384
+ cache: cached output list of (batch, max_time_out-1, size)
385
+ Returns:
386
+ y, cache: NN output value and cache per `self.decoders`.
387
+ y.shape` is (batch, maxlen_out, token)
388
+ """
389
+ return self.left_decoder.forward_one_step(memory, memory_mask, tgt,
390
+ tgt_mask, cache)
391
+
392
+ def tie_or_clone_weights(self, jit_mode: bool = True):
393
+ """Tie or clone module weights (between word_emb and output_layer)
394
+ depending on whether we are using TorchScript or not"""
395
+ self.left_decoder.tie_or_clone_weights(jit_mode)
396
+ self.right_decoder.tie_or_clone_weights(jit_mode)
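
Editor's note: TransformerDecoder.forward builds its (B, L, L) target mask by AND-ing a padding mask with a lower-triangular causal mask. The sketch below reproduces that logic with plain torch ops (it mirrors make_pad_mask / subsequent_mask from cosyvoice.utils.mask rather than importing them; the toy lengths are assumptions).

```python
import torch

ys_in_lens = torch.tensor([4, 2])      # two target sequences, padded to L = 4
L = int(ys_in_lens.max())

# padding mask over the key dimension: True for real tokens, shape (B, 1, L)
pad_mask = (torch.arange(L).unsqueeze(0) < ys_in_lens.unsqueeze(1)).unsqueeze(1)

# causal mask: position i may attend to positions <= i, shape (1, L, L)
causal = torch.tril(torch.ones(L, L, dtype=torch.bool)).unsqueeze(0)

tgt_mask = pad_mask & causal           # broadcast to (B, L, L)
print(tgt_mask[1])
# tensor([[ True, False, False, False],
#         [ True,  True, False, False],
#         [ True,  True, False, False],
#         [ True,  True, False, False]])
```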
cosyvoice/transformer/decoder_layer.py ADDED
@@ -0,0 +1,132 @@
1
+ # Copyright (c) 2019 Shigeki Karita
2
+ # 2020 Mobvoi Inc (Binbin Zhang)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Decoder self-attention layer definition."""
16
+ from typing import Optional, Tuple
17
+
18
+ import torch
19
+ from torch import nn
20
+
21
+
22
+ class DecoderLayer(nn.Module):
23
+ """Single decoder layer module.
24
+
25
+ Args:
26
+ size (int): Input dimension.
27
+ self_attn (torch.nn.Module): Self-attention module instance.
28
+ `MultiHeadedAttention` instance can be used as the argument.
29
+ src_attn (torch.nn.Module): Inter-attention module instance.
30
+ `MultiHeadedAttention` instance can be used as the argument.
31
+ If `None` is passed, Inter-attention is not used, such as
32
+ CIF, GPT, and other decoder-only models.
33
+ feed_forward (torch.nn.Module): Feed-forward module instance.
34
+ `PositionwiseFeedForward` instance can be used as the argument.
35
+ dropout_rate (float): Dropout rate.
36
+ normalize_before (bool):
37
+ True: use layer_norm before each sub-block.
38
+ False: use layer_norm after each sub-block.
39
+ """
40
+
41
+ def __init__(
42
+ self,
43
+ size: int,
44
+ self_attn: nn.Module,
45
+ src_attn: Optional[nn.Module],
46
+ feed_forward: nn.Module,
47
+ dropout_rate: float,
48
+ normalize_before: bool = True,
49
+ ):
50
+ """Construct an DecoderLayer object."""
51
+ super().__init__()
52
+ self.size = size
53
+ self.self_attn = self_attn
54
+ self.src_attn = src_attn
55
+ self.feed_forward = feed_forward
56
+ self.norm1 = nn.LayerNorm(size, eps=1e-5)
57
+ self.norm2 = nn.LayerNorm(size, eps=1e-5)
58
+ self.norm3 = nn.LayerNorm(size, eps=1e-5)
59
+ self.dropout = nn.Dropout(dropout_rate)
60
+ self.normalize_before = normalize_before
61
+
62
+ def forward(
63
+ self,
64
+ tgt: torch.Tensor,
65
+ tgt_mask: torch.Tensor,
66
+ memory: torch.Tensor,
67
+ memory_mask: torch.Tensor,
68
+ cache: Optional[torch.Tensor] = None
69
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
70
+ """Compute decoded features.
71
+
72
+ Args:
73
+ tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
74
+ tgt_mask (torch.Tensor): Mask for input tensor
75
+ (#batch, maxlen_out).
76
+ memory (torch.Tensor): Encoded memory
77
+ (#batch, maxlen_in, size).
78
+ memory_mask (torch.Tensor): Encoded memory mask
79
+ (#batch, maxlen_in).
80
+ cache (torch.Tensor): cached tensors.
81
+ (#batch, maxlen_out - 1, size).
82
+
83
+ Returns:
84
+ torch.Tensor: Output tensor (#batch, maxlen_out, size).
85
+ torch.Tensor: Mask for output tensor (#batch, maxlen_out).
86
+ torch.Tensor: Encoded memory (#batch, maxlen_in, size).
87
+ torch.Tensor: Encoded memory mask (#batch, maxlen_in).
88
+
89
+ """
90
+ residual = tgt
91
+ if self.normalize_before:
92
+ tgt = self.norm1(tgt)
93
+
94
+ if cache is None:
95
+ tgt_q = tgt
96
+ tgt_q_mask = tgt_mask
97
+ else:
98
+ # compute only the last frame query keeping dim: max_time_out -> 1
99
+ assert cache.shape == (
100
+ tgt.shape[0],
101
+ tgt.shape[1] - 1,
102
+ self.size,
103
+ ), "{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
104
+ tgt_q = tgt[:, -1:, :]
105
+ residual = residual[:, -1:, :]
106
+ tgt_q_mask = tgt_mask[:, -1:, :]
107
+
108
+ x = residual + self.dropout(
109
+ self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0])
110
+ if not self.normalize_before:
111
+ x = self.norm1(x)
112
+
113
+ if self.src_attn is not None:
114
+ residual = x
115
+ if self.normalize_before:
116
+ x = self.norm2(x)
117
+ x = residual + self.dropout(
118
+ self.src_attn(x, memory, memory, memory_mask)[0])
119
+ if not self.normalize_before:
120
+ x = self.norm2(x)
121
+
122
+ residual = x
123
+ if self.normalize_before:
124
+ x = self.norm3(x)
125
+ x = residual + self.dropout(self.feed_forward(x))
126
+ if not self.normalize_before:
127
+ x = self.norm3(x)
128
+
129
+ if cache is not None:
130
+ x = torch.cat([cache, x], dim=1)
131
+
132
+ return x, tgt_mask, memory, memory_mask
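
Editor's note: a toy sketch (pure torch stand-in, not the real layer) of the cache contract used by DecoderLayer.forward during incremental decoding: at step t the incoming cache holds the t - 1 frames already computed, only the newest query frame is processed, and the returned tensor is the cache with one new frame appended (the torch.cat([cache, x], dim=1) at the end of forward).

```python
from typing import Optional

import torch

size = 8

def toy_layer(tgt: torch.Tensor, cache: Optional[torch.Tensor]) -> torch.Tensor:
    """Stand-in for DecoderLayer.forward; attention and FFN replaced by a no-op."""
    if cache is not None:
        # same shape check as the assert in DecoderLayer.forward
        assert cache.shape == (tgt.shape[0], tgt.shape[1] - 1, size)
        tgt_q = tgt[:, -1:, :]            # only the newest frame is recomputed
    else:
        tgt_q = tgt
    x = tgt_q                             # placeholder for self-attn + src-attn + FFN
    return x if cache is None else torch.cat([cache, x], dim=1)

cache = None
for step in range(1, 4):
    tgt = torch.randn(1, step, size)      # all target frames decoded so far
    cache = toy_layer(tgt, cache)
    print(step, tuple(cache.shape))       # (1, 1, 8), (1, 2, 8), (1, 3, 8)
```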
cosyvoice/transformer/embedding.py ADDED
@@ -0,0 +1,294 @@
1
+ # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
2
+ # 2024 Alibaba Inc (Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """Positonal Encoding Module."""
17
+
18
+ import math
19
+ from typing import Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ import numpy as np
24
+
25
+
26
+ class PositionalEncoding(torch.nn.Module):
27
+ """Positional encoding.
28
+
29
+ :param int d_model: embedding dim
30
+ :param float dropout_rate: dropout rate
31
+ :param int max_len: maximum input length
32
+
33
+ PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
34
+ PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
35
+ """
36
+
37
+ def __init__(self,
38
+ d_model: int,
39
+ dropout_rate: float,
40
+ max_len: int = 5000,
41
+ reverse: bool = False):
42
+ """Construct an PositionalEncoding object."""
43
+ super().__init__()
44
+ self.d_model = d_model
45
+ self.xscale = math.sqrt(self.d_model)
46
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
47
+ self.max_len = max_len
48
+
49
+ self.pe = torch.zeros(self.max_len, self.d_model)
50
+ position = torch.arange(0, self.max_len,
51
+ dtype=torch.float32).unsqueeze(1)
52
+ div_term = torch.exp(
53
+ torch.arange(0, self.d_model, 2, dtype=torch.float32) *
54
+ -(math.log(10000.0) / self.d_model))
55
+ self.pe[:, 0::2] = torch.sin(position * div_term)
56
+ self.pe[:, 1::2] = torch.cos(position * div_term)
57
+ self.pe = self.pe.unsqueeze(0)
58
+
59
+ def forward(self,
60
+ x: torch.Tensor,
61
+ offset: Union[int, torch.Tensor] = 0) \
62
+ -> Tuple[torch.Tensor, torch.Tensor]:
63
+ """Add positional encoding.
64
+
65
+ Args:
66
+ x (torch.Tensor): Input. Its shape is (batch, time, ...)
67
+ offset (int, torch.tensor): position offset
68
+
69
+ Returns:
70
+ torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
71
+ torch.Tensor: for compatibility to RelPositionalEncoding
72
+ """
73
+
74
+ self.pe = self.pe.to(x.device)
75
+ pos_emb = self.position_encoding(offset, x.size(1), False)
76
+ x = x * self.xscale + pos_emb
77
+ return self.dropout(x), self.dropout(pos_emb)
78
+
79
+ def position_encoding(self,
80
+ offset: Union[int, torch.Tensor],
81
+ size: int,
82
+ apply_dropout: bool = True) -> torch.Tensor:
83
+ """ For getting encoding in a streaming fashion
84
+
85
+ Attention!!!!!
86
+ we apply dropout only once at the whole utterance level in a non-
87
+ streaming way, but will call this function several times with
88
+ increasing input size in a streaming scenario, so the dropout will
89
+ be applied several times.
90
+
91
+ Args:
92
+ offset (int or torch.tensor): start offset
93
+ size (int): required size of position encoding
94
+
95
+ Returns:
96
+ torch.Tensor: Corresponding encoding
97
+ """
98
+ # How to subscript a Union type:
99
+ # https://github.com/pytorch/pytorch/issues/69434
100
+ if isinstance(offset, int):
101
+ assert offset + size <= self.max_len
102
+ pos_emb = self.pe[:, offset:offset + size]
103
+ elif isinstance(offset, torch.Tensor) and offset.dim() == 0: # scalar
104
+ assert offset + size <= self.max_len
105
+ pos_emb = self.pe[:, offset:offset + size]
106
+ else: # for batched streaming decoding on GPU
107
+ assert torch.max(offset) + size <= self.max_len
108
+ index = offset.unsqueeze(1) + \
109
+ torch.arange(0, size).to(offset.device) # B X T
110
+ flag = index > 0
111
+ # remove negative offset
112
+ index = index * flag
113
+ pos_emb = F.embedding(index, self.pe[0]) # B X T X d_model
114
+
115
+ if apply_dropout:
116
+ pos_emb = self.dropout(pos_emb)
117
+ return pos_emb
118
+
119
+
120
+ class RelPositionalEncoding(PositionalEncoding):
121
+ """Relative positional encoding module.
122
+ See : Appendix B in https://arxiv.org/abs/1901.02860
123
+ Args:
124
+ d_model (int): Embedding dimension.
125
+ dropout_rate (float): Dropout rate.
126
+ max_len (int): Maximum input length.
127
+ """
128
+
129
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
130
+ """Initialize class."""
131
+ super().__init__(d_model, dropout_rate, max_len, reverse=True)
132
+
133
+ def forward(self,
134
+ x: torch.Tensor,
135
+ offset: Union[int, torch.Tensor] = 0) \
136
+ -> Tuple[torch.Tensor, torch.Tensor]:
137
+ """Compute positional encoding.
138
+ Args:
139
+ x (torch.Tensor): Input tensor (batch, time, `*`).
140
+ Returns:
141
+ torch.Tensor: Encoded tensor (batch, time, `*`).
142
+ torch.Tensor: Positional embedding tensor (1, time, `*`).
143
+ """
144
+ self.pe = self.pe.to(x.device)
145
+ x = x * self.xscale
146
+ pos_emb = self.position_encoding(offset, x.size(1), False)
147
+ return self.dropout(x), self.dropout(pos_emb)
148
+
149
+
150
+ class WhisperPositionalEncoding(PositionalEncoding):
151
+ """ Sinusoids position encoding used in openai-whisper.encoder
152
+ """
153
+
154
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 1500):
155
+ super().__init__(d_model, dropout_rate, max_len)
156
+ self.xscale = 1.0
157
+ log_timescale_increment = np.log(10000) / (d_model // 2 - 1)
158
+ inv_timescales = torch.exp(-log_timescale_increment *
159
+ torch.arange(d_model // 2))
160
+ scaled_time = torch.arange(max_len)[:, np.newaxis] * \
161
+ inv_timescales[np.newaxis, :]
162
+ pe = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
163
+ delattr(self, "pe")
164
+ self.register_buffer("pe", pe.unsqueeze(0))
165
+
166
+
167
+ class LearnablePositionalEncoding(PositionalEncoding):
168
+ """ Learnable position encoding used in openai-whisper.decoder
169
+ """
170
+
171
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 448):
172
+ super().__init__(d_model, dropout_rate, max_len)
173
+ # NOTE(xcsong): overwrite self.pe & self.xscale
174
+ self.pe = torch.nn.Parameter(torch.empty(1, max_len, d_model))
175
+ self.xscale = 1.0
176
+
177
+
178
+ class NoPositionalEncoding(torch.nn.Module):
179
+ """ No position encoding
180
+ """
181
+
182
+ def __init__(self, d_model: int, dropout_rate: float):
183
+ super().__init__()
184
+ self.d_model = d_model
185
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
186
+
187
+ def forward(self,
188
+ x: torch.Tensor,
189
+ offset: Union[int, torch.Tensor] = 0) \
190
+ -> Tuple[torch.Tensor, torch.Tensor]:
191
+ """ Just return zero vector for interface compatibility
192
+ """
193
+ pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device)
194
+ return self.dropout(x), pos_emb
195
+
196
+ def position_encoding(self, offset: Union[int, torch.Tensor],
197
+ size: int) -> torch.Tensor:
198
+ return torch.zeros(1, size, self.d_model)
199
+
200
+
201
+ class EspnetRelPositionalEncoding(torch.nn.Module):
202
+ """Relative positional encoding module (new implementation).
203
+
204
+ Details can be found in https://github.com/espnet/espnet/pull/2816.
205
+
206
+ See : Appendix B in https://arxiv.org/abs/1901.02860
207
+
208
+ Args:
209
+ d_model (int): Embedding dimension.
210
+ dropout_rate (float): Dropout rate.
211
+ max_len (int): Maximum input length.
212
+
213
+ """
214
+
215
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
216
+ """Construct an PositionalEncoding object."""
217
+ super(EspnetRelPositionalEncoding, self).__init__()
218
+ self.d_model = d_model
219
+ self.xscale = math.sqrt(self.d_model)
220
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
221
+ self.pe = None
222
+ self.extend_pe(torch.tensor(0.0).expand(1, max_len))
223
+
224
+ def extend_pe(self, x: torch.Tensor):
225
+ """Reset the positional encodings."""
226
+ if self.pe is not None:
227
+ # self.pe contains both positive and negative parts
228
+ # the length of self.pe is 2 * input_len - 1
229
+ if self.pe.size(1) >= x.size(1) * 2 - 1:
230
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
231
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
232
+ return
233
+ # Suppose `i` means the position of the query vector and `j` means the
234
+ # position of the key vector. We use positive relative positions when keys
235
+ # are to the left (i>j) and negative relative positions otherwise (i<j).
236
+ pe_positive = torch.zeros(x.size(1), self.d_model)
237
+ pe_negative = torch.zeros(x.size(1), self.d_model)
238
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
239
+ div_term = torch.exp(
240
+ torch.arange(0, self.d_model, 2, dtype=torch.float32)
241
+ * -(math.log(10000.0) / self.d_model)
242
+ )
243
+ pe_positive[:, 0::2] = torch.sin(position * div_term)
244
+ pe_positive[:, 1::2] = torch.cos(position * div_term)
245
+ pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
246
+ pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
247
+
248
+ # Reverse the order of positive indices and concat both positive and
249
+ # negative indices. This is used to support the shifting trick
250
+ # as in https://arxiv.org/abs/1901.02860
251
+ pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
252
+ pe_negative = pe_negative[1:].unsqueeze(0)
253
+ pe = torch.cat([pe_positive, pe_negative], dim=1)
254
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
255
+
256
+ def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) \
257
+ -> Tuple[torch.Tensor, torch.Tensor]:
258
+ """Add positional encoding.
259
+
260
+ Args:
261
+ x (torch.Tensor): Input tensor (batch, time, `*`).
262
+
263
+ Returns:
264
+ torch.Tensor: Encoded tensor (batch, time, `*`).
265
+
266
+ """
267
+ self.extend_pe(x)
268
+ x = x * self.xscale
269
+ pos_emb = self.position_encoding(size=x.size(1), offset=offset)
270
+ return self.dropout(x), self.dropout(pos_emb)
271
+
272
+ def position_encoding(self,
273
+ offset: Union[int, torch.Tensor],
274
+ size: int) -> torch.Tensor:
275
+ """ For getting encoding in a streaming fashion
276
+
277
+ Attention!!!!!
278
+ we apply dropout only once at the whole utterance level in a non-
279
+ streaming way, but will call this function several times with
280
+ increasing input size in a streaming scenario, so the dropout will
281
+ be applied several times.
282
+
283
+ Args:
284
+ offset (int or torch.tensor): start offset
285
+ size (int): required size of position encoding
286
+
287
+ Returns:
288
+ torch.Tensor: Corresponding encoding
289
+ """
290
+ pos_emb = self.pe[
291
+ :,
292
+ self.pe.size(1) // 2 - size + 1: self.pe.size(1) // 2 + size,
293
+ ]
294
+ return pos_emb
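
Editor's note: the relative table built by EspnetRelPositionalEncoding holds both positive and negative relative positions, so for an input of length T it has 2 * T - 1 entries, and position_encoding slices the 2 * size - 1 embeddings centred on relative position 0. The standalone sketch below repeats that arithmetic with the same formulas instead of importing the class; d_model and T are assumed toy values.

```python
import math

import torch

d_model, T = 8, 5
position = torch.arange(0, T, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32)
                     * -(math.log(10000.0) / d_model))

pe_positive = torch.zeros(T, d_model)
pe_negative = torch.zeros(T, d_model)
pe_positive[:, 0::2] = torch.sin(position * div_term)
pe_positive[:, 1::2] = torch.cos(position * div_term)
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)

pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)   # relative positions T-1 ... 0
pe_negative = pe_negative[1:].unsqueeze(0)                # relative positions -1 ... -(T-1)
pe = torch.cat([pe_positive, pe_negative], dim=1)
print(pe.shape)                                           # torch.Size([1, 9, 8]) == (1, 2*T - 1, d_model)

# the slice taken by position_encoding(offset, size) for size == T
size = T
window = pe[:, pe.size(1) // 2 - size + 1: pe.size(1) // 2 + size]
print(window.shape)                                       # torch.Size([1, 9, 8]) == (1, 2*size - 1, d_model)
```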
cosyvoice/transformer/encoder.py ADDED
@@ -0,0 +1,474 @@
1
+ # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
2
+ # 2022 Xingchen Song ([email protected])
3
+ # 2024 Alibaba Inc (Xiang Lyu)
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # Modified from ESPnet(https://github.com/espnet/espnet)
17
+ """Encoder definition."""
18
+ from typing import Tuple
19
+
20
+ import torch
21
+ import torch.utils.checkpoint as ckpt
22
+
23
+ from cosyvoice.transformer.convolution import ConvolutionModule
24
+ from cosyvoice.transformer.encoder_layer import TransformerEncoderLayer
25
+ from cosyvoice.transformer.encoder_layer import ConformerEncoderLayer
26
+ from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward
27
+ from cosyvoice.utils.class_utils import (
28
+ COSYVOICE_EMB_CLASSES,
29
+ COSYVOICE_SUBSAMPLE_CLASSES,
30
+ COSYVOICE_ATTENTION_CLASSES,
31
+ COSYVOICE_ACTIVATION_CLASSES,
32
+ )
33
+ from cosyvoice.utils.mask import make_pad_mask
34
+ from cosyvoice.utils.mask import add_optional_chunk_mask
35
+
36
+
37
+ class BaseEncoder(torch.nn.Module):
38
+
39
+ def __init__(
40
+ self,
41
+ input_size: int,
42
+ output_size: int = 256,
43
+ attention_heads: int = 4,
44
+ linear_units: int = 2048,
45
+ num_blocks: int = 6,
46
+ dropout_rate: float = 0.1,
47
+ positional_dropout_rate: float = 0.1,
48
+ attention_dropout_rate: float = 0.0,
49
+ input_layer: str = "conv2d",
50
+ pos_enc_layer_type: str = "abs_pos",
51
+ normalize_before: bool = True,
52
+ static_chunk_size: int = 0,
53
+ use_dynamic_chunk: bool = False,
54
+ global_cmvn: torch.nn.Module = None,
55
+ use_dynamic_left_chunk: bool = False,
56
+ gradient_checkpointing: bool = False,
57
+ ):
58
+ """
59
+ Args:
60
+ input_size (int): input dim
61
+ output_size (int): dimension of attention
62
+ attention_heads (int): the number of heads of multi head attention
63
+ linear_units (int): the hidden units number of position-wise feed
64
+ forward
65
+ num_blocks (int): the number of decoder blocks
66
+ dropout_rate (float): dropout rate
67
+ attention_dropout_rate (float): dropout rate in attention
68
+ positional_dropout_rate (float): dropout rate after adding
69
+ positional encoding
70
+ input_layer (str): input layer type.
71
+ optional [linear, conv2d, conv2d6, conv2d8]
72
+ pos_enc_layer_type (str): Encoder positional encoding layer type.
73
+ optional [abs_pos, scaled_abs_pos, rel_pos, no_pos]
74
+ normalize_before (bool):
75
+ True: use layer_norm before each sub-block of a layer.
76
+ False: use layer_norm after each sub-block of a layer.
77
+ static_chunk_size (int): chunk size for static chunk training and
78
+ decoding
79
+ use_dynamic_chunk (bool): whether use dynamic chunk size for
80
+ training or not. You can only use a fixed chunk (chunk_size > 0)
81
+ or a dynamic chunk size (use_dynamic_chunk = True)
82
+ global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
83
+ use_dynamic_left_chunk (bool): whether use dynamic left chunk in
84
+ dynamic chunk training
85
+ key_bias: whether use bias in attention.linear_k, False for whisper models.
86
+ gradient_checkpointing: rerunning a forward-pass segment for each
87
+ checkpointed segment during backward.
88
+ """
89
+ super().__init__()
90
+ self._output_size = output_size
91
+
92
+ self.global_cmvn = global_cmvn
93
+ self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
94
+ input_size,
95
+ output_size,
96
+ dropout_rate,
97
+ COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
98
+ positional_dropout_rate),
99
+ )
100
+
101
+ self.normalize_before = normalize_before
102
+ self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
103
+ self.static_chunk_size = static_chunk_size
104
+ self.use_dynamic_chunk = use_dynamic_chunk
105
+ self.use_dynamic_left_chunk = use_dynamic_left_chunk
106
+ self.gradient_checkpointing = gradient_checkpointing
107
+
108
+ def output_size(self) -> int:
109
+ return self._output_size
110
+
111
+ def forward(
112
+ self,
113
+ xs: torch.Tensor,
114
+ xs_lens: torch.Tensor,
115
+ decoding_chunk_size: int = 0,
116
+ num_decoding_left_chunks: int = -1,
117
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
118
+ """Embed positions in tensor.
119
+
120
+ Args:
121
+ xs: padded input tensor (B, T, D)
122
+ xs_lens: input length (B)
123
+ decoding_chunk_size: decoding chunk size for dynamic chunk
124
+ 0: default for training, use random dynamic chunk.
125
+ <0: for decoding, use full chunk.
126
+ >0: for decoding, use fixed chunk size as set.
127
+ num_decoding_left_chunks: number of left chunks, this is for decoding,
128
+ the chunk size is decoding_chunk_size.
129
+ >=0: use num_decoding_left_chunks
130
+ <0: use all left chunks
131
+ Returns:
132
+ encoder output tensor xs, and subsampled masks
133
+ xs: padded output tensor (B, T' ~= T/subsample_rate, D)
134
+ masks: torch.Tensor batch padding mask after subsample
135
+ (B, 1, T' ~= T/subsample_rate)
136
+ NOTE(xcsong):
137
+ We pass the `__call__` method of the modules instead of `forward` to the
138
+ checkpointing API because `__call__` attaches all the hooks of the module.
139
+ https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
140
+ """
141
+ T = xs.size(1)
142
+ masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
143
+ if self.global_cmvn is not None:
144
+ xs = self.global_cmvn(xs)
145
+ xs, pos_emb, masks = self.embed(xs, masks)
146
+ mask_pad = masks # (B, 1, T/subsample_rate)
147
+ chunk_masks = add_optional_chunk_mask(xs, masks,
148
+ self.use_dynamic_chunk,
149
+ self.use_dynamic_left_chunk,
150
+ decoding_chunk_size,
151
+ self.static_chunk_size,
152
+ num_decoding_left_chunks)
153
+ if self.gradient_checkpointing and self.training:
154
+ xs = self.forward_layers_checkpointed(xs, chunk_masks, pos_emb,
155
+ mask_pad)
156
+ else:
157
+ xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
158
+ if self.normalize_before:
159
+ xs = self.after_norm(xs)
160
+ # Here we assume the mask is not changed in encoder layers, so just
161
+ # return the masks before encoder layers, and the masks will be used
162
+ # for cross attention with decoder later
163
+ return xs, masks
164
+
165
+ def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
166
+ pos_emb: torch.Tensor,
167
+ mask_pad: torch.Tensor) -> torch.Tensor:
168
+ for layer in self.encoders:
169
+ xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
170
+ return xs
171
+
172
+ @torch.jit.unused
173
+ def forward_layers_checkpointed(self, xs: torch.Tensor,
174
+ chunk_masks: torch.Tensor,
175
+ pos_emb: torch.Tensor,
176
+ mask_pad: torch.Tensor) -> torch.Tensor:
177
+ for layer in self.encoders:
178
+ xs, chunk_masks, _, _ = ckpt.checkpoint(layer.__call__, xs,
179
+ chunk_masks, pos_emb,
180
+ mask_pad)
181
+ return xs
182
+
183
+ @torch.jit.export
184
+ def forward_chunk(
185
+ self,
186
+ xs: torch.Tensor,
187
+ offset: int,
188
+ required_cache_size: int,
189
+ att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
190
+ cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
191
+ att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
192
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
193
+ """ Forward just one chunk
194
+
195
+ Args:
196
+ xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim),
197
+ where `time == (chunk_size - 1) * subsample_rate + \
198
+ subsample.right_context + 1`
199
+ offset (int): current offset in encoder output time stamp
200
+ required_cache_size (int): cache size required for next chunk
201
+ computation
202
+ >=0: actual cache size
203
+ <0: means all history cache is required
204
+ att_cache (torch.Tensor): cache tensor for KEY & VALUE in
205
+ transformer/conformer attention, with shape
206
+ (elayers, head, cache_t1, d_k * 2), where
207
+ `head * d_k == hidden-dim` and
208
+ `cache_t1 == chunk_size * num_decoding_left_chunks`.
209
+ cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
210
+ (elayers, b=1, hidden-dim, cache_t2), where
211
+ `cache_t2 == cnn.lorder - 1`
212
+
213
+ Returns:
214
+ torch.Tensor: output of current input xs,
215
+ with shape (b=1, chunk_size, hidden-dim).
216
+ torch.Tensor: new attention cache required for next chunk, with
217
+ dynamic shape (elayers, head, ?, d_k * 2)
218
+ depending on required_cache_size.
219
+ torch.Tensor: new conformer cnn cache required for next chunk, with
220
+ same shape as the original cnn_cache.
221
+
222
+ """
223
+ assert xs.size(0) == 1
224
+ # tmp_masks is just for interface compatibility
225
+ tmp_masks = torch.ones(1,
226
+ xs.size(1),
227
+ device=xs.device,
228
+ dtype=torch.bool)
229
+ tmp_masks = tmp_masks.unsqueeze(1)
230
+ if self.global_cmvn is not None:
231
+ xs = self.global_cmvn(xs)
232
+ # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim)
233
+ xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
234
+ # NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim)
235
+ elayers, cache_t1 = att_cache.size(0), att_cache.size(2)
236
+ chunk_size = xs.size(1)
237
+ attention_key_size = cache_t1 + chunk_size
238
+ pos_emb = self.embed.position_encoding(offset=offset - cache_t1,
239
+ size=attention_key_size)
240
+ if required_cache_size < 0:
241
+ next_cache_start = 0
242
+ elif required_cache_size == 0:
243
+ next_cache_start = attention_key_size
244
+ else:
245
+ next_cache_start = max(attention_key_size - required_cache_size, 0)
246
+ r_att_cache = []
247
+ r_cnn_cache = []
248
+ for i, layer in enumerate(self.encoders):
249
+ # NOTE(xcsong): Before layer.forward
250
+ # shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2),
251
+ # shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2)
252
+ xs, _, new_att_cache, new_cnn_cache = layer(
253
+ xs,
254
+ att_mask,
255
+ pos_emb,
256
+ att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache,
257
+ cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache)
258
+ # NOTE(xcsong): After layer.forward
259
+ # shape(new_att_cache) is (1, head, attention_key_size, d_k * 2),
260
+ # shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2)
261
+ r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])
262
+ r_cnn_cache.append(new_cnn_cache.unsqueeze(0))
263
+ if self.normalize_before:
264
+ xs = self.after_norm(xs)
265
+
266
+ # NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2),
267
+ # ? may be larger than cache_t1, it depends on required_cache_size
268
+ r_att_cache = torch.cat(r_att_cache, dim=0)
269
+ # NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2)
270
+ r_cnn_cache = torch.cat(r_cnn_cache, dim=0)
271
+
272
+ return (xs, r_att_cache, r_cnn_cache)
273
+
274
+ @torch.jit.unused
275
+ def forward_chunk_by_chunk(
276
+ self,
277
+ xs: torch.Tensor,
278
+ decoding_chunk_size: int,
279
+ num_decoding_left_chunks: int = -1,
280
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
281
+ """ Forward input chunk by chunk with chunk_size like a streaming
282
+ fashion
283
+
284
+ Here we should pay special attention to computation cache in the
285
+ streaming style forward chunk by chunk. Three things should be taken
286
+ into account for computation in the current network:
287
+ 1. transformer/conformer encoder layers output cache
288
+ 2. convolution in conformer
289
+ 3. convolution in subsampling
290
+
291
+ However, we don't implement subsampling cache for:
292
+ 1. We can control subsampling module to output the right result by
293
+ overlapping the input instead of caching left context; this
294
+ wastes some computation, but subsampling only takes a very
295
+ small fraction of computation in the whole model.
296
+ 2. Typically, there are several convolution layers with subsampling
297
+ in subsampling module, it is tricky and complicated to do cache
298
+ with different convolution layers with different subsampling
299
+ rate.
300
+ 3. Currently, nn.Sequential is used to stack all the convolution
301
+ layers in subsampling, we need to rewrite it to make it work
302
+ with cache, which is not preferred.
303
+ Args:
304
+ xs (torch.Tensor): (1, max_len, dim)
305
+ chunk_size (int): decoding chunk size
306
+ """
307
+ assert decoding_chunk_size > 0
308
+ # The model is trained by static or dynamic chunk
309
+ assert self.static_chunk_size > 0 or self.use_dynamic_chunk
310
+ subsampling = self.embed.subsampling_rate
311
+ context = self.embed.right_context + 1 # Add current frame
312
+ stride = subsampling * decoding_chunk_size
313
+ decoding_window = (decoding_chunk_size - 1) * subsampling + context
314
+ num_frames = xs.size(1)
315
+ att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
316
+ cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
317
+ outputs = []
318
+ offset = 0
319
+ required_cache_size = decoding_chunk_size * num_decoding_left_chunks
320
+
321
+ # Feed forward overlap input step by step
322
+ for cur in range(0, num_frames - context + 1, stride):
323
+ end = min(cur + decoding_window, num_frames)
324
+ chunk_xs = xs[:, cur:end, :]
325
+ (y, att_cache,
326
+ cnn_cache) = self.forward_chunk(chunk_xs, offset,
327
+ required_cache_size, att_cache,
328
+ cnn_cache)
329
+ outputs.append(y)
330
+ offset += y.size(1)
331
+ ys = torch.cat(outputs, 1)
332
+ masks = torch.ones((1, 1, ys.size(1)),
333
+ device=ys.device,
334
+ dtype=torch.bool)
335
+ return ys, masks
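
Editor's note: a small sketch of the chunk arithmetic in forward_chunk_by_chunk above. The subsampling_rate = 4 and right_context = 6 values are assumptions typical of a conv2d front end, not read from this commit; the point is that each step advances by `stride` frames but feeds a slightly longer `decoding_window` so the subsampling module sees enough right context.

```python
subsampling_rate = 4              # assumed front-end subsampling factor
right_context = 6                 # assumed front-end right context
decoding_chunk_size = 16

context = right_context + 1       # add the current frame
stride = subsampling_rate * decoding_chunk_size
decoding_window = (decoding_chunk_size - 1) * subsampling_rate + context
print(stride, decoding_window)    # 64 67

num_frames = 200                  # toy utterance length in frames
for cur in range(0, num_frames - context + 1, stride):
    end = min(cur + decoding_window, num_frames)
    print(f"chunk input frames [{cur}, {end})")
# chunk input frames [0, 67)
# chunk input frames [64, 131)
# chunk input frames [128, 195)
# chunk input frames [192, 200)
```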
336
+
337
+
338
+ class TransformerEncoder(BaseEncoder):
339
+ """Transformer encoder module."""
340
+
341
+ def __init__(
342
+ self,
343
+ input_size: int,
344
+ output_size: int = 256,
345
+ attention_heads: int = 4,
346
+ linear_units: int = 2048,
347
+ num_blocks: int = 6,
348
+ dropout_rate: float = 0.1,
349
+ positional_dropout_rate: float = 0.1,
350
+ attention_dropout_rate: float = 0.0,
351
+ input_layer: str = "conv2d",
352
+ pos_enc_layer_type: str = "abs_pos",
353
+ normalize_before: bool = True,
354
+ static_chunk_size: int = 0,
355
+ use_dynamic_chunk: bool = False,
356
+ global_cmvn: torch.nn.Module = None,
357
+ use_dynamic_left_chunk: bool = False,
358
+ key_bias: bool = True,
359
+ selfattention_layer_type: str = "selfattn",
360
+ activation_type: str = "relu",
361
+ gradient_checkpointing: bool = False,
362
+ ):
363
+ """ Construct TransformerEncoder
364
+
365
+ See Encoder for the meaning of each parameter.
366
+ """
367
+ super().__init__(input_size, output_size, attention_heads,
368
+ linear_units, num_blocks, dropout_rate,
369
+ positional_dropout_rate, attention_dropout_rate,
370
+ input_layer, pos_enc_layer_type, normalize_before,
371
+ static_chunk_size, use_dynamic_chunk, global_cmvn,
372
+ use_dynamic_left_chunk, gradient_checkpointing)
373
+ activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
374
+ self.encoders = torch.nn.ModuleList([
375
+ TransformerEncoderLayer(
376
+ output_size,
377
+ COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](attention_heads,
378
+ output_size,
379
+ attention_dropout_rate,
380
+ key_bias),
381
+ PositionwiseFeedForward(output_size, linear_units,
382
+ dropout_rate, activation),
383
+ dropout_rate, normalize_before) for _ in range(num_blocks)
384
+ ])
385
+
386
+
387
+ class ConformerEncoder(BaseEncoder):
388
+ """Conformer encoder module."""
389
+
390
+ def __init__(
391
+ self,
392
+ input_size: int,
393
+ output_size: int = 256,
394
+ attention_heads: int = 4,
395
+ linear_units: int = 2048,
396
+ num_blocks: int = 6,
397
+ dropout_rate: float = 0.1,
398
+ positional_dropout_rate: float = 0.1,
399
+ attention_dropout_rate: float = 0.0,
400
+ input_layer: str = "conv2d",
401
+ pos_enc_layer_type: str = "rel_pos",
402
+ normalize_before: bool = True,
403
+ static_chunk_size: int = 0,
404
+ use_dynamic_chunk: bool = False,
405
+ global_cmvn: torch.nn.Module = None,
406
+ use_dynamic_left_chunk: bool = False,
407
+ positionwise_conv_kernel_size: int = 1,
408
+ macaron_style: bool = True,
409
+ selfattention_layer_type: str = "rel_selfattn",
410
+ activation_type: str = "swish",
411
+ use_cnn_module: bool = True,
412
+ cnn_module_kernel: int = 15,
413
+ causal: bool = False,
414
+ cnn_module_norm: str = "batch_norm",
415
+ key_bias: bool = True,
416
+ gradient_checkpointing: bool = False,
417
+ ):
418
+ """Construct ConformerEncoder
419
+
420
+ Args:
421
+ input_size to use_dynamic_chunk, see in BaseEncoder
422
+ positionwise_conv_kernel_size (int): Kernel size of positionwise
423
+ conv1d layer.
424
+ macaron_style (bool): Whether to use macaron style for
425
+ positionwise layer.
426
+ selfattention_layer_type (str): Encoder attention layer type,
427
+                the parameter has no effect now, it is kept only for configuration
428
+ compatibility.
429
+ activation_type (str): Encoder activation function type.
430
+ use_cnn_module (bool): Whether to use convolution module.
431
+ cnn_module_kernel (int): Kernel size of convolution module.
432
+ causal (bool): whether to use causal convolution or not.
433
+ key_bias: whether use bias in attention.linear_k, False for whisper models.
434
+ """
435
+ super().__init__(input_size, output_size, attention_heads,
436
+ linear_units, num_blocks, dropout_rate,
437
+ positional_dropout_rate, attention_dropout_rate,
438
+ input_layer, pos_enc_layer_type, normalize_before,
439
+ static_chunk_size, use_dynamic_chunk, global_cmvn,
440
+ use_dynamic_left_chunk, gradient_checkpointing)
441
+ activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
442
+
443
+ # self-attention module definition
444
+ encoder_selfattn_layer_args = (
445
+ attention_heads,
446
+ output_size,
447
+ attention_dropout_rate,
448
+ key_bias,
449
+ )
450
+ # feed-forward module definition
451
+ positionwise_layer_args = (
452
+ output_size,
453
+ linear_units,
454
+ dropout_rate,
455
+ activation,
456
+ )
457
+ # convolution module definition
458
+ convolution_layer_args = (output_size, cnn_module_kernel, activation,
459
+ cnn_module_norm, causal)
460
+
461
+ self.encoders = torch.nn.ModuleList([
462
+ ConformerEncoderLayer(
463
+ output_size,
464
+ COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
465
+ *encoder_selfattn_layer_args),
466
+ PositionwiseFeedForward(*positionwise_layer_args),
467
+ PositionwiseFeedForward(
468
+ *positionwise_layer_args) if macaron_style else None,
469
+ ConvolutionModule(
470
+ *convolution_layer_args) if use_cnn_module else None,
471
+ dropout_rate,
472
+ normalize_before,
473
+ ) for _ in range(num_blocks)
474
+ ])
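The chunk-by-chunk decoding loop in `forward_chunk_by_chunk` above is driven by a small amount of index arithmetic (context, stride, decoding window). The snippet below is a minimal sketch that re-derives this arithmetic in plain Python so the window boundaries can be inspected without instantiating a model; the subsampling rate of 4 and right context of 6 are illustrative values for a Conv2dSubsampling4 front-end, not taken from a specific CosyVoice config.

```python
# Sketch of the chunk arithmetic used by forward_chunk_by_chunk (illustrative values).
def chunk_windows(num_frames: int, decoding_chunk_size: int,
                  subsampling: int = 4, right_context: int = 6):
    """Yield the (start, end) feature-frame ranges fed to forward_chunk per step."""
    context = right_context + 1                       # add the current frame
    stride = subsampling * decoding_chunk_size        # frames consumed per chunk
    decoding_window = (decoding_chunk_size - 1) * subsampling + context
    for cur in range(0, num_frames - context + 1, stride):
        yield cur, min(cur + decoding_window, num_frames)

if __name__ == "__main__":
    # 16 output frames per chunk over a 200-frame feature sequence
    for start, end in chunk_windows(num_frames=200, decoding_chunk_size=16):
        print(f"feed frames [{start}, {end})")
```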
cosyvoice/transformer/encoder_layer.py ADDED
@@ -0,0 +1,236 @@
1
+ # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
2
+ # 2022 Xingchen Song ([email protected])
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """Encoder self-attention layer definition."""
17
+
18
+ from typing import Optional, Tuple
19
+
20
+ import torch
21
+ from torch import nn
22
+
23
+
24
+ class TransformerEncoderLayer(nn.Module):
25
+ """Encoder layer module.
26
+
27
+ Args:
28
+ size (int): Input dimension.
29
+ self_attn (torch.nn.Module): Self-attention module instance.
30
+ `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
31
+ instance can be used as the argument.
32
+ feed_forward (torch.nn.Module): Feed-forward module instance.
33
+            `PositionwiseFeedForward` instance can be used as the argument.
34
+ dropout_rate (float): Dropout rate.
35
+ normalize_before (bool):
36
+ True: use layer_norm before each sub-block.
37
+ False: to use layer_norm after each sub-block.
38
+ """
39
+
40
+ def __init__(
41
+ self,
42
+ size: int,
43
+ self_attn: torch.nn.Module,
44
+ feed_forward: torch.nn.Module,
45
+ dropout_rate: float,
46
+ normalize_before: bool = True,
47
+ ):
48
+ """Construct an EncoderLayer object."""
49
+ super().__init__()
50
+ self.self_attn = self_attn
51
+ self.feed_forward = feed_forward
52
+ self.norm1 = nn.LayerNorm(size, eps=1e-12)
53
+ self.norm2 = nn.LayerNorm(size, eps=1e-12)
54
+ self.dropout = nn.Dropout(dropout_rate)
55
+ self.size = size
56
+ self.normalize_before = normalize_before
57
+
58
+ def forward(
59
+ self,
60
+ x: torch.Tensor,
61
+ mask: torch.Tensor,
62
+ pos_emb: torch.Tensor,
63
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
64
+ att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
65
+ cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
66
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
67
+ """Compute encoded features.
68
+
69
+ Args:
70
+ x (torch.Tensor): (#batch, time, size)
71
+ mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
72
+ (0, 0, 0) means fake mask.
73
+ pos_emb (torch.Tensor): just for interface compatibility
74
+ to ConformerEncoderLayer
75
+            mask_pad (torch.Tensor): not used in the transformer layer,
76
+ just for unified api with conformer.
77
+ att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
78
+ (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
79
+ cnn_cache (torch.Tensor): Convolution cache in conformer layer
80
+ (#batch=1, size, cache_t2), not used here, it's for interface
81
+ compatibility to ConformerEncoderLayer.
82
+ Returns:
83
+ torch.Tensor: Output tensor (#batch, time, size).
84
+ torch.Tensor: Mask tensor (#batch, time, time).
85
+ torch.Tensor: att_cache tensor,
86
+ (#batch=1, head, cache_t1 + time, d_k * 2).
87
+            torch.Tensor: cnn_cache tensor (#batch=1, size, cache_t2).
88
+
89
+ """
90
+ residual = x
91
+ if self.normalize_before:
92
+ x = self.norm1(x)
93
+ x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb=pos_emb, cache=att_cache)
94
+ x = residual + self.dropout(x_att)
95
+ if not self.normalize_before:
96
+ x = self.norm1(x)
97
+
98
+ residual = x
99
+ if self.normalize_before:
100
+ x = self.norm2(x)
101
+ x = residual + self.dropout(self.feed_forward(x))
102
+ if not self.normalize_before:
103
+ x = self.norm2(x)
104
+
105
+ fake_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
106
+ return x, mask, new_att_cache, fake_cnn_cache
107
+
108
+
109
+ class ConformerEncoderLayer(nn.Module):
110
+ """Encoder layer module.
111
+ Args:
112
+ size (int): Input dimension.
113
+ self_attn (torch.nn.Module): Self-attention module instance.
114
+ `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
115
+ instance can be used as the argument.
116
+ feed_forward (torch.nn.Module): Feed-forward module instance.
117
+ `PositionwiseFeedForward` instance can be used as the argument.
118
+ feed_forward_macaron (torch.nn.Module): Additional feed-forward module
119
+ instance.
120
+ `PositionwiseFeedForward` instance can be used as the argument.
121
+ conv_module (torch.nn.Module): Convolution module instance.
122
+            `ConvolutionModule` instance can be used as the argument.
123
+ dropout_rate (float): Dropout rate.
124
+ normalize_before (bool):
125
+ True: use layer_norm before each sub-block.
126
+ False: use layer_norm after each sub-block.
127
+ """
128
+
129
+ def __init__(
130
+ self,
131
+ size: int,
132
+ self_attn: torch.nn.Module,
133
+ feed_forward: Optional[nn.Module] = None,
134
+ feed_forward_macaron: Optional[nn.Module] = None,
135
+ conv_module: Optional[nn.Module] = None,
136
+ dropout_rate: float = 0.1,
137
+ normalize_before: bool = True,
138
+ ):
139
+ """Construct an EncoderLayer object."""
140
+ super().__init__()
141
+ self.self_attn = self_attn
142
+ self.feed_forward = feed_forward
143
+ self.feed_forward_macaron = feed_forward_macaron
144
+ self.conv_module = conv_module
145
+ self.norm_ff = nn.LayerNorm(size, eps=1e-12) # for the FNN module
146
+ self.norm_mha = nn.LayerNorm(size, eps=1e-12) # for the MHA module
147
+ if feed_forward_macaron is not None:
148
+ self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-12)
149
+ self.ff_scale = 0.5
150
+ else:
151
+ self.ff_scale = 1.0
152
+ if self.conv_module is not None:
153
+ self.norm_conv = nn.LayerNorm(size, eps=1e-12) # for the CNN module
154
+ self.norm_final = nn.LayerNorm(
155
+ size, eps=1e-12) # for the final output of the block
156
+ self.dropout = nn.Dropout(dropout_rate)
157
+ self.size = size
158
+ self.normalize_before = normalize_before
159
+
160
+ def forward(
161
+ self,
162
+ x: torch.Tensor,
163
+ mask: torch.Tensor,
164
+ pos_emb: torch.Tensor,
165
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
166
+ att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
167
+ cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
168
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
169
+ """Compute encoded features.
170
+
171
+ Args:
172
+ x (torch.Tensor): (#batch, time, size)
173
+ mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
174
+ (0, 0, 0) means fake mask.
175
+ pos_emb (torch.Tensor): positional encoding, must not be None
176
+ for ConformerEncoderLayer.
177
+ mask_pad (torch.Tensor): batch padding mask used for conv module.
178
+ (#batch, 1,time), (0, 0, 0) means fake mask.
179
+ att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
180
+ (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
181
+ cnn_cache (torch.Tensor): Convolution cache in conformer layer
182
+ (#batch=1, size, cache_t2)
183
+ Returns:
184
+ torch.Tensor: Output tensor (#batch, time, size).
185
+ torch.Tensor: Mask tensor (#batch, time, time).
186
+ torch.Tensor: att_cache tensor,
187
+ (#batch=1, head, cache_t1 + time, d_k * 2).
188
+            torch.Tensor: cnn_cache tensor (#batch, size, cache_t2).
189
+ """
190
+
191
+ # whether to use macaron style
192
+ if self.feed_forward_macaron is not None:
193
+ residual = x
194
+ if self.normalize_before:
195
+ x = self.norm_ff_macaron(x)
196
+ x = residual + self.ff_scale * self.dropout(
197
+ self.feed_forward_macaron(x))
198
+ if not self.normalize_before:
199
+ x = self.norm_ff_macaron(x)
200
+
201
+ # multi-headed self-attention module
202
+ residual = x
203
+ if self.normalize_before:
204
+ x = self.norm_mha(x)
205
+ x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
206
+ att_cache)
207
+ x = residual + self.dropout(x_att)
208
+ if not self.normalize_before:
209
+ x = self.norm_mha(x)
210
+
211
+ # convolution module
212
+ # Fake new cnn cache here, and then change it in conv_module
213
+ new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
214
+ if self.conv_module is not None:
215
+ residual = x
216
+ if self.normalize_before:
217
+ x = self.norm_conv(x)
218
+ x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
219
+ x = residual + self.dropout(x)
220
+
221
+ if not self.normalize_before:
222
+ x = self.norm_conv(x)
223
+
224
+ # feed forward module
225
+ residual = x
226
+ if self.normalize_before:
227
+ x = self.norm_ff(x)
228
+
229
+ x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
230
+ if not self.normalize_before:
231
+ x = self.norm_ff(x)
232
+
233
+ if self.conv_module is not None:
234
+ x = self.norm_final(x)
235
+
236
+ return x, mask, new_att_cache, new_cnn_cache
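As a quick sanity check, a single TransformerEncoderLayer can be wired up by hand. This is a hedged usage sketch: it assumes the cosyvoice package from this commit is importable and that `MultiHeadedAttention` takes `(heads, size, dropout_rate)` positionally, as in the registry call in `encoder.py`; the shapes are illustrative.

```python
import torch
from cosyvoice.transformer.attention import MultiHeadedAttention
from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward
from cosyvoice.transformer.encoder_layer import TransformerEncoderLayer

size, heads, ffn_units = 256, 4, 1024          # illustrative sizes
layer = TransformerEncoderLayer(
    size,
    MultiHeadedAttention(heads, size, 0.0),    # assumed (heads, size, dropout) ctor
    PositionwiseFeedForward(size, ffn_units, 0.0),
    dropout_rate=0.0,
)
x = torch.randn(1, 50, size)                        # (batch, time, size)
mask = torch.ones(1, 50, 50, dtype=torch.bool)      # full self-attention mask
pos_emb = torch.zeros(1, 50, size)                  # unused by plain self-attention
y, out_mask, att_cache, cnn_cache = layer(x, mask, pos_emb)
print(y.shape)  # torch.Size([1, 50, 256])
```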
cosyvoice/transformer/label_smoothing_loss.py ADDED
@@ -0,0 +1,96 @@
1
+ # Copyright (c) 2019 Shigeki Karita
2
+ # 2020 Mobvoi Inc (Binbin Zhang)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Label smoothing module."""
16
+
17
+ import torch
18
+ from torch import nn
19
+
20
+
21
+ class LabelSmoothingLoss(nn.Module):
22
+ """Label-smoothing loss.
23
+
24
+ In a standard CE loss, the label's data distribution is:
25
+ [0,1,2] ->
26
+ [
27
+ [1.0, 0.0, 0.0],
28
+ [0.0, 1.0, 0.0],
29
+ [0.0, 0.0, 1.0],
30
+ ]
31
+
32
+    In the smoothed version of the CE loss, some probabilities
33
+ are taken from the true label prob (1.0) and are divided
34
+ among other labels.
35
+
36
+ e.g.
37
+ smoothing=0.1
38
+ [0,1,2] ->
39
+ [
40
+ [0.9, 0.05, 0.05],
41
+ [0.05, 0.9, 0.05],
42
+ [0.05, 0.05, 0.9],
43
+ ]
44
+
45
+ Args:
46
+        size (int): the number of classes
47
+ padding_idx (int): padding class id which will be ignored for loss
48
+ smoothing (float): smoothing rate (0.0 means the conventional CE)
49
+ normalize_length (bool):
50
+ normalize loss by sequence length if True
51
+ normalize loss by batch size if False
52
+ """
53
+
54
+ def __init__(self,
55
+ size: int,
56
+ padding_idx: int,
57
+ smoothing: float,
58
+ normalize_length: bool = False):
59
+        """Construct a LabelSmoothingLoss object."""
60
+ super(LabelSmoothingLoss, self).__init__()
61
+ self.criterion = nn.KLDivLoss(reduction="none")
62
+ self.padding_idx = padding_idx
63
+ self.confidence = 1.0 - smoothing
64
+ self.smoothing = smoothing
65
+ self.size = size
66
+ self.normalize_length = normalize_length
67
+
68
+ def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
69
+ """Compute loss between x and target.
70
+
71
+ The model outputs and data labels tensors are flatten to
72
+ (batch*seqlen, class) shape and a mask is applied to the
73
+ padding part which should not be calculated for loss.
74
+
75
+ Args:
76
+ x (torch.Tensor): prediction (batch, seqlen, class)
77
+ target (torch.Tensor):
78
+ target signal masked with self.padding_id (batch, seqlen)
79
+ Returns:
80
+ loss (torch.Tensor) : The KL loss, scalar float value
81
+ """
82
+ assert x.size(2) == self.size
83
+ batch_size = x.size(0)
84
+ x = x.view(-1, self.size)
85
+ target = target.view(-1)
86
+ # use zeros_like instead of torch.no_grad() for true_dist,
87
+ # since no_grad() can not be exported by JIT
88
+ true_dist = torch.zeros_like(x)
89
+ true_dist.fill_(self.smoothing / (self.size - 1))
90
+ ignore = target == self.padding_idx # (B,)
91
+ total = len(target) - ignore.sum().item()
92
+ target = target.masked_fill(ignore, 0) # avoid -1 index
93
+ true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
94
+ kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
95
+ denom = total if self.normalize_length else batch_size
96
+ return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
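The smoothed target distribution that `forward` builds internally can be reproduced in a few lines. This is a minimal sketch of the `fill_`/`scatter_` trick, using the same 3-class, smoothing=0.1 numbers as the docstring example, followed by a hypothetical call to the loss on random logits.

```python
import torch
from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss

# The smoothed one-hot targets from the docstring example: 3 classes, smoothing=0.1.
size, smoothing, confidence = 3, 0.1, 0.9
targets = torch.tensor([0, 1, 2])
true_dist = torch.full((len(targets), size), smoothing / (size - 1))
true_dist.scatter_(1, targets.unsqueeze(1), confidence)
print(true_dist)   # rows: [0.9, 0.05, 0.05], [0.05, 0.9, 0.05], [0.05, 0.05, 0.9]

# Illustrative usage of the loss itself on random logits.
loss_fn = LabelSmoothingLoss(size=size, padding_idx=-1, smoothing=smoothing)
logits = torch.randn(1, 3, size)                  # (batch, seqlen, class)
loss = loss_fn(logits, targets.unsqueeze(0))      # target shape (batch, seqlen)
print(loss)
```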
cosyvoice/transformer/positionwise_feed_forward.py ADDED
@@ -0,0 +1,115 @@
1
+ # Copyright (c) 2019 Shigeki Karita
2
+ # 2020 Mobvoi Inc (Binbin Zhang)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Positionwise feed forward layer definition."""
16
+
17
+ import torch
18
+
19
+
20
+ class PositionwiseFeedForward(torch.nn.Module):
21
+ """Positionwise feed forward layer.
22
+
23
+    The feed forward layer is applied at each position of the sequence.
24
+    The output dim is the same as the input dim.
25
+
26
+ Args:
27
+        idim (int): Input dimension.
28
+ hidden_units (int): The number of hidden units.
29
+ dropout_rate (float): Dropout rate.
30
+ activation (torch.nn.Module): Activation function
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ idim: int,
36
+ hidden_units: int,
37
+ dropout_rate: float,
38
+ activation: torch.nn.Module = torch.nn.ReLU(),
39
+ ):
40
+ """Construct a PositionwiseFeedForward object."""
41
+ super(PositionwiseFeedForward, self).__init__()
42
+ self.w_1 = torch.nn.Linear(idim, hidden_units)
43
+ self.activation = activation
44
+ self.dropout = torch.nn.Dropout(dropout_rate)
45
+ self.w_2 = torch.nn.Linear(hidden_units, idim)
46
+
47
+ def forward(self, xs: torch.Tensor) -> torch.Tensor:
48
+ """Forward function.
49
+
50
+ Args:
51
+ xs: input tensor (B, L, D)
52
+ Returns:
53
+ output tensor, (B, L, D)
54
+ """
55
+ return self.w_2(self.dropout(self.activation(self.w_1(xs))))
56
+
57
+
58
+ class MoEFFNLayer(torch.nn.Module):
59
+ """
60
+    Mixture of experts with a positionwise feed forward layer.
61
+    See also figure 1 in https://arxiv.org/pdf/2305.15663.pdf
62
+    The output dim is the same as the input dim.
63
+
64
+ Modified from https://github.com/Lightning-AI/lit-gpt/pull/823
65
+ https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219
66
+ Args:
67
+        n_expert: number of experts.
68
+        n_expert_per_token: the actual number of experts used for each frame
69
+        idim (int): Input dimension.
70
+ hidden_units (int): The number of hidden units.
71
+ dropout_rate (float): Dropout rate.
72
+ activation (torch.nn.Module): Activation function
73
+ """
74
+
75
+ def __init__(
76
+ self,
77
+ n_expert: int,
78
+ n_expert_per_token: int,
79
+ idim: int,
80
+ hidden_units: int,
81
+ dropout_rate: float,
82
+ activation: torch.nn.Module = torch.nn.ReLU(),
83
+ ):
84
+ super(MoEFFNLayer, self).__init__()
85
+ self.gate = torch.nn.Linear(idim, n_expert, bias=False)
86
+ self.experts = torch.nn.ModuleList(
87
+ PositionwiseFeedForward(idim, hidden_units, dropout_rate,
88
+ activation) for _ in range(n_expert))
89
+ self.n_expert_per_token = n_expert_per_token
90
+
91
+ def forward(self, xs: torch.Tensor) -> torch.Tensor:
92
+        """Forward function.
93
+ Args:
94
+ xs: input tensor (B, L, D)
95
+ Returns:
96
+ output tensor, (B, L, D)
97
+
98
+ """
99
+ B, L, D = xs.size(
100
+ ) # batch size, sequence length, embedding dimension (idim)
101
+ xs = xs.view(-1, D) # (B*L, D)
102
+ router = self.gate(xs) # (B*L, n_expert)
103
+ logits, indices = torch.topk(
104
+ router, self.n_expert_per_token
105
+ ) # probs:(B*L, n_expert), indices: (B*L, n_expert)
106
+ weights = torch.nn.functional.softmax(
107
+ logits, dim=1,
108
+ dtype=torch.float).to(dtype=xs.dtype) # (B*L, n_expert_per_token)
109
+ output = torch.zeros_like(xs) # (B*L, D)
110
+ for i, expert in enumerate(self.experts):
111
+ mask = indices == i
112
+ batch_idx, ith_expert = torch.where(mask)
113
+ output[batch_idx] += weights[batch_idx, ith_expert, None] * expert(
114
+ xs[batch_idx])
115
+ return output.view(B, L, D)
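A hedged usage sketch of the MoE layer: each position is routed to its top-2 of 4 experts and the expert outputs are mixed by the softmaxed router logits. The sizes below are illustrative, not taken from a CosyVoice config.

```python
import torch
from cosyvoice.transformer.positionwise_feed_forward import MoEFFNLayer

# 4 experts, 2 active per token; 64-dim features with 128 hidden units (illustrative).
moe = MoEFFNLayer(n_expert=4, n_expert_per_token=2,
                  idim=64, hidden_units=128, dropout_rate=0.0)
xs = torch.randn(2, 10, 64)   # (B, L, D)
out = moe(xs)
print(out.shape)              # torch.Size([2, 10, 64]) -- same shape as the input
```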
cosyvoice/transformer/subsampling.py ADDED
@@ -0,0 +1,383 @@
1
+ # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
2
+ # 2024 Alibaba Inc (Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """Subsampling layer definition."""
17
+
18
+ from typing import Tuple, Union
19
+
20
+ import torch
21
+
22
+
23
+ class BaseSubsampling(torch.nn.Module):
24
+
25
+ def __init__(self):
26
+ super().__init__()
27
+ self.right_context = 0
28
+ self.subsampling_rate = 1
29
+
30
+ def position_encoding(self, offset: Union[int, torch.Tensor],
31
+ size: int) -> torch.Tensor:
32
+ return self.pos_enc.position_encoding(offset, size)
33
+
34
+
35
+ class EmbedinigNoSubsampling(BaseSubsampling):
36
+ """Embedding input without subsampling
37
+ """
38
+
39
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
40
+ pos_enc_class: torch.nn.Module):
41
+ super().__init__()
42
+ self.embed = torch.nn.Embedding(idim, odim)
43
+ self.pos_enc = pos_enc_class
44
+
45
+ def forward(
46
+ self,
47
+ x: torch.Tensor,
48
+ x_mask: torch.Tensor,
49
+ offset: Union[int, torch.Tensor] = 0
50
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
51
+ """Input x.
52
+
53
+ Args:
54
+ x (torch.Tensor): Input tensor (#batch, time, idim).
55
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
56
+
57
+ Returns:
58
+ torch.Tensor: linear input tensor (#batch, time', odim),
59
+ where time' = time .
60
+ torch.Tensor: linear input mask (#batch, 1, time'),
61
+ where time' = time .
62
+
63
+ """
64
+ x = self.embed(x)
65
+ x, pos_emb = self.pos_enc(x, offset)
66
+ return x, pos_emb, x_mask
67
+
68
+
69
+ class LinearNoSubsampling(BaseSubsampling):
70
+ """Linear transform the input without subsampling
71
+
72
+ Args:
73
+ idim (int): Input dimension.
74
+ odim (int): Output dimension.
75
+ dropout_rate (float): Dropout rate.
76
+
77
+ """
78
+
79
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
80
+ pos_enc_class: torch.nn.Module):
81
+        """Construct a linear object."""
82
+ super().__init__()
83
+ self.out = torch.nn.Sequential(
84
+ torch.nn.Linear(idim, odim),
85
+ torch.nn.LayerNorm(odim, eps=1e-5),
86
+ torch.nn.Dropout(dropout_rate),
87
+ )
88
+ self.pos_enc = pos_enc_class
89
+ self.right_context = 0
90
+ self.subsampling_rate = 1
91
+
92
+ def forward(
93
+ self,
94
+ x: torch.Tensor,
95
+ x_mask: torch.Tensor,
96
+ offset: Union[int, torch.Tensor] = 0
97
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
98
+ """Input x.
99
+
100
+ Args:
101
+ x (torch.Tensor): Input tensor (#batch, time, idim).
102
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
103
+
104
+ Returns:
105
+ torch.Tensor: linear input tensor (#batch, time', odim),
106
+ where time' = time .
107
+ torch.Tensor: linear input mask (#batch, 1, time'),
108
+ where time' = time .
109
+
110
+ """
111
+ x = self.out(x)
112
+ x, pos_emb = self.pos_enc(x, offset)
113
+ return x, pos_emb, x_mask
114
+
115
+
116
+ class Conv1dSubsampling2(BaseSubsampling):
117
+ """Convolutional 1D subsampling (to 1/2 length).
118
+ It is designed for Whisper, ref:
119
+ https://github.com/openai/whisper/blob/main/whisper/model.py
120
+
121
+ Args:
122
+ idim (int): Input dimension.
123
+ odim (int): Output dimension.
124
+ dropout_rate (float): Dropout rate.
125
+
126
+ """
127
+
128
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
129
+ pos_enc_class: torch.nn.Module):
130
+        """Construct a Conv1dSubsampling2 object."""
131
+ super().__init__()
132
+ self.conv = torch.nn.Sequential(
133
+ torch.nn.Conv1d(idim, odim, kernel_size=3, padding=1),
134
+ torch.nn.GELU(),
135
+ torch.nn.Conv1d(odim, odim, kernel_size=3, stride=2, padding=1),
136
+ torch.nn.GELU(),
137
+ )
138
+ self.pos_enc = pos_enc_class
139
+ # The right context for every conv layer is computed by:
140
+ # (kernel_size - 1) * frame_rate_of_this_layer
141
+ self.subsampling_rate = 2
142
+ # 4 = (3 - 1) * 1 + (3 - 1) * 1
143
+ self.right_context = 4
144
+
145
+ def forward(
146
+ self,
147
+ x: torch.Tensor,
148
+ x_mask: torch.Tensor,
149
+ offset: Union[int, torch.Tensor] = 0
150
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
151
+ """Subsample x.
152
+
153
+ Args:
154
+ x (torch.Tensor): Input tensor (#batch, time, idim).
155
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
156
+
157
+ Returns:
158
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
159
+ where time' = time // 2.
160
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
161
+ where time' = time // 2.
162
+ torch.Tensor: positional encoding
163
+
164
+ """
165
+ time = x.size(1)
166
+ x = x.transpose(1, 2) # (b, f, t)
167
+ x = self.conv(x)
168
+ x = x.transpose(1, 2) # (b, t, f)
169
+ x, pos_emb = self.pos_enc(x, offset)
170
+ return x, pos_emb, x_mask[:, :, (time + 1) % 2::2]
171
+
172
+
173
+ class Conv2dSubsampling4(BaseSubsampling):
174
+ """Convolutional 2D subsampling (to 1/4 length).
175
+
176
+ Args:
177
+ idim (int): Input dimension.
178
+ odim (int): Output dimension.
179
+ dropout_rate (float): Dropout rate.
180
+
181
+ """
182
+
183
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
184
+ pos_enc_class: torch.nn.Module):
185
+        """Construct a Conv2dSubsampling4 object."""
186
+ super().__init__()
187
+ self.conv = torch.nn.Sequential(
188
+ torch.nn.Conv2d(1, odim, 3, 2),
189
+ torch.nn.ReLU(),
190
+ torch.nn.Conv2d(odim, odim, 3, 2),
191
+ torch.nn.ReLU(),
192
+ )
193
+ self.out = torch.nn.Sequential(
194
+ torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim))
195
+ self.pos_enc = pos_enc_class
196
+ # The right context for every conv layer is computed by:
197
+ # (kernel_size - 1) * frame_rate_of_this_layer
198
+ self.subsampling_rate = 4
199
+ # 6 = (3 - 1) * 1 + (3 - 1) * 2
200
+ self.right_context = 6
201
+
202
+ def forward(
203
+ self,
204
+ x: torch.Tensor,
205
+ x_mask: torch.Tensor,
206
+ offset: Union[int, torch.Tensor] = 0
207
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
208
+ """Subsample x.
209
+
210
+ Args:
211
+ x (torch.Tensor): Input tensor (#batch, time, idim).
212
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
213
+
214
+ Returns:
215
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
216
+ where time' = time // 4.
217
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
218
+ where time' = time // 4.
219
+ torch.Tensor: positional encoding
220
+
221
+ """
222
+ x = x.unsqueeze(1) # (b, c=1, t, f)
223
+ x = self.conv(x)
224
+ b, c, t, f = x.size()
225
+ x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
226
+ x, pos_emb = self.pos_enc(x, offset)
227
+ return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2]
228
+
229
+
230
+ class Conv2dSubsampling6(BaseSubsampling):
231
+ """Convolutional 2D subsampling (to 1/6 length).
232
+ Args:
233
+ idim (int): Input dimension.
234
+ odim (int): Output dimension.
235
+ dropout_rate (float): Dropout rate.
236
+ pos_enc (torch.nn.Module): Custom position encoding layer.
237
+ """
238
+
239
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
240
+ pos_enc_class: torch.nn.Module):
241
+        """Construct a Conv2dSubsampling6 object."""
242
+ super().__init__()
243
+ self.conv = torch.nn.Sequential(
244
+ torch.nn.Conv2d(1, odim, 3, 2),
245
+ torch.nn.ReLU(),
246
+ torch.nn.Conv2d(odim, odim, 5, 3),
247
+ torch.nn.ReLU(),
248
+ )
249
+ self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3),
250
+ odim)
251
+ self.pos_enc = pos_enc_class
252
+ # 10 = (3 - 1) * 1 + (5 - 1) * 2
253
+ self.subsampling_rate = 6
254
+ self.right_context = 10
255
+
256
+ def forward(
257
+ self,
258
+ x: torch.Tensor,
259
+ x_mask: torch.Tensor,
260
+ offset: Union[int, torch.Tensor] = 0
261
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
262
+ """Subsample x.
263
+ Args:
264
+ x (torch.Tensor): Input tensor (#batch, time, idim).
265
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
266
+
267
+ Returns:
268
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
269
+ where time' = time // 6.
270
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
271
+ where time' = time // 6.
272
+ torch.Tensor: positional encoding
273
+ """
274
+ x = x.unsqueeze(1) # (b, c, t, f)
275
+ x = self.conv(x)
276
+ b, c, t, f = x.size()
277
+ x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
278
+ x, pos_emb = self.pos_enc(x, offset)
279
+ return x, pos_emb, x_mask[:, :, 2::2][:, :, 4::3]
280
+
281
+
282
+ class Conv2dSubsampling8(BaseSubsampling):
283
+ """Convolutional 2D subsampling (to 1/8 length).
284
+
285
+ Args:
286
+ idim (int): Input dimension.
287
+ odim (int): Output dimension.
288
+ dropout_rate (float): Dropout rate.
289
+
290
+ """
291
+
292
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
293
+ pos_enc_class: torch.nn.Module):
294
+        """Construct a Conv2dSubsampling8 object."""
295
+ super().__init__()
296
+ self.conv = torch.nn.Sequential(
297
+ torch.nn.Conv2d(1, odim, 3, 2),
298
+ torch.nn.ReLU(),
299
+ torch.nn.Conv2d(odim, odim, 3, 2),
300
+ torch.nn.ReLU(),
301
+ torch.nn.Conv2d(odim, odim, 3, 2),
302
+ torch.nn.ReLU(),
303
+ )
304
+ self.linear = torch.nn.Linear(
305
+ odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim)
306
+ self.pos_enc = pos_enc_class
307
+ self.subsampling_rate = 8
308
+ # 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4
309
+ self.right_context = 14
310
+
311
+ def forward(
312
+ self,
313
+ x: torch.Tensor,
314
+ x_mask: torch.Tensor,
315
+ offset: Union[int, torch.Tensor] = 0
316
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
317
+ """Subsample x.
318
+
319
+ Args:
320
+ x (torch.Tensor): Input tensor (#batch, time, idim).
321
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
322
+
323
+ Returns:
324
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
325
+ where time' = time // 8.
326
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
327
+ where time' = time // 8.
328
+ torch.Tensor: positional encoding
329
+ """
330
+ x = x.unsqueeze(1) # (b, c, t, f)
331
+ x = self.conv(x)
332
+ b, c, t, f = x.size()
333
+ x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
334
+ x, pos_emb = self.pos_enc(x, offset)
335
+ return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2][:, :, 2::2]
336
+
337
+
338
+ class LegacyLinearNoSubsampling(BaseSubsampling):
339
+ """Linear transform the input without subsampling
340
+
341
+ Args:
342
+ idim (int): Input dimension.
343
+ odim (int): Output dimension.
344
+ dropout_rate (float): Dropout rate.
345
+
346
+ """
347
+
348
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
349
+ pos_enc_class: torch.nn.Module):
350
+        """Construct a linear object."""
351
+ super().__init__()
352
+ self.out = torch.nn.Sequential(
353
+ torch.nn.Linear(idim, odim),
354
+ torch.nn.LayerNorm(odim, eps=1e-5),
355
+ torch.nn.Dropout(dropout_rate),
356
+ torch.nn.ReLU(),
357
+ )
358
+ self.pos_enc = pos_enc_class
359
+ self.right_context = 0
360
+ self.subsampling_rate = 1
361
+
362
+ def forward(
363
+ self,
364
+ x: torch.Tensor,
365
+ x_mask: torch.Tensor,
366
+ offset: Union[int, torch.Tensor] = 0
367
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
368
+ """Input x.
369
+
370
+ Args:
371
+ x (torch.Tensor): Input tensor (#batch, time, idim).
372
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
373
+
374
+ Returns:
375
+ torch.Tensor: linear input tensor (#batch, time', odim),
376
+ where time' = time .
377
+ torch.Tensor: linear input mask (#batch, 1, time'),
378
+ where time' = time .
379
+
380
+ """
381
+ x = self.out(x)
382
+ x, pos_emb = self.pos_enc(x, offset)
383
+ return x, pos_emb, x_mask
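A hedged usage sketch of Conv2dSubsampling4: two stride-2 convolutions shrink a 100-frame, 80-dim feature sequence to 24 frames (each conv produces floor((T - 3) / 2) + 1 output steps). It assumes `RelPositionalEncoding` takes `(d_model, dropout_rate)`, as in the encoder constructors above; the shapes are illustrative.

```python
import torch
from cosyvoice.transformer.embedding import RelPositionalEncoding
from cosyvoice.transformer.subsampling import Conv2dSubsampling4

idim, odim = 80, 256   # illustrative feature and model sizes
sub = Conv2dSubsampling4(idim, odim, dropout_rate=0.1,
                         pos_enc_class=RelPositionalEncoding(odim, 0.1))
x = torch.randn(1, 100, idim)                     # (batch, time, feat)
x_mask = torch.ones(1, 1, 100, dtype=torch.bool)  # (batch, 1, time)
y, pos_emb, y_mask = sub(x, x_mask)
print(y.shape, y_mask.shape)                      # (1, 24, 256) and (1, 1, 24)
```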
cosyvoice/transformer/upsample_encoder.py ADDED
@@ -0,0 +1,322 @@
1
+ # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
2
+ # 2022 Xingchen Song ([email protected])
3
+ # 2024 Alibaba Inc (Xiang Lyu)
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # Modified from ESPnet(https://github.com/espnet/espnet)
17
+ """Encoder definition."""
18
+ from typing import Tuple
19
+
20
+ import torch
21
+ from torch import nn
22
+ import torch.utils.checkpoint as ckpt
23
+ from torch.nn import functional as F
24
+
25
+ from cosyvoice.transformer.convolution import ConvolutionModule
26
+ from cosyvoice.transformer.encoder_layer import ConformerEncoderLayer
27
+ from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward
28
+ from cosyvoice.utils.class_utils import (
29
+ COSYVOICE_EMB_CLASSES,
30
+ COSYVOICE_SUBSAMPLE_CLASSES,
31
+ COSYVOICE_ATTENTION_CLASSES,
32
+ COSYVOICE_ACTIVATION_CLASSES,
33
+ )
34
+ from cosyvoice.utils.mask import make_pad_mask
35
+ from cosyvoice.utils.mask import add_optional_chunk_mask
36
+
37
+
38
+ class Upsample1D(nn.Module):
39
+    """A 1D upsampling layer followed by a convolution.
40
+
41
+    Parameters:
42
+        channels (`int`):
43
+            number of input channels.
44
+        out_channels (`int`):
45
+            number of output channels.
46
+        stride (`int`, default `2`):
47
+            upsampling factor; the input is repeated by this factor
48
+            (nearest-neighbour interpolation) before the convolution,
49
+            whose kernel size is `stride * 2 + 1`.
50
+ """
51
+
52
+ def __init__(self, channels: int, out_channels: int, stride: int=2):
53
+ super().__init__()
54
+ self.channels = channels
55
+ self.out_channels = out_channels
56
+ self.stride = stride
57
+        # In this mode, first repeat-interpolate, then conv with stride=1
58
+ self.conv = nn.Conv1d(
59
+ self.channels, self.out_channels, stride*2+1, stride=1,
60
+ padding=0,
61
+ )
62
+
63
+ def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor):
64
+ outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest")
65
+ outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
66
+ outputs = self.conv(outputs)
67
+ return outputs, input_lengths * self.stride
68
+
69
+
70
+ class PreLookaheadLayer(nn.Module):
71
+ def __init__(self, channels: int, pre_lookahead_len: int = 1):
72
+ super().__init__()
73
+ self.channels = channels
74
+ self.pre_lookahead_len = pre_lookahead_len
75
+ self.conv1 = nn.Conv1d(
76
+ channels, channels,
77
+ kernel_size=pre_lookahead_len+1,
78
+ stride=1, padding=0,
79
+ )
80
+ self.conv2 = nn.Conv1d(
81
+ channels, channels,
82
+ kernel_size=3, stride=1, padding=0,
83
+ )
84
+
85
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
86
+ """
87
+ inputs: (batch_size, seq_len, channels)
88
+ """
89
+ outputs = inputs.transpose(1, 2).contiguous()
90
+ # look ahead
91
+ outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0)
92
+ outputs = F.leaky_relu(self.conv1(outputs))
93
+ # outputs
94
+ outputs = F.pad(outputs, (2, 0), mode='constant', value=0.0)
95
+ outputs = self.conv2(outputs)
96
+ outputs = outputs.transpose(1, 2).contiguous()
97
+
98
+ # residual connection
99
+ outputs = outputs + inputs
100
+ return outputs
101
+
102
+
103
+ class UpsampleConformerEncoder(torch.nn.Module):
104
+
105
+ def __init__(
106
+ self,
107
+ input_size: int,
108
+ output_size: int = 256,
109
+ attention_heads: int = 4,
110
+ linear_units: int = 2048,
111
+ num_blocks: int = 6,
112
+ dropout_rate: float = 0.1,
113
+ positional_dropout_rate: float = 0.1,
114
+ attention_dropout_rate: float = 0.0,
115
+ input_layer: str = "conv2d",
116
+ pos_enc_layer_type: str = "rel_pos",
117
+ normalize_before: bool = True,
118
+ static_chunk_size: int = 0,
119
+ use_dynamic_chunk: bool = False,
120
+ global_cmvn: torch.nn.Module = None,
121
+ use_dynamic_left_chunk: bool = False,
122
+ positionwise_conv_kernel_size: int = 1,
123
+ macaron_style: bool = True,
124
+ selfattention_layer_type: str = "rel_selfattn",
125
+ activation_type: str = "swish",
126
+ use_cnn_module: bool = True,
127
+ cnn_module_kernel: int = 15,
128
+ causal: bool = False,
129
+ cnn_module_norm: str = "batch_norm",
130
+ key_bias: bool = True,
131
+ gradient_checkpointing: bool = False,
132
+ ):
133
+ """
134
+ Args:
135
+ input_size (int): input dim
136
+ output_size (int): dimension of attention
137
+ attention_heads (int): the number of heads of multi head attention
138
+ linear_units (int): the hidden units number of position-wise feed
139
+ forward
140
+ num_blocks (int): the number of decoder blocks
141
+ dropout_rate (float): dropout rate
142
+ attention_dropout_rate (float): dropout rate in attention
143
+ positional_dropout_rate (float): dropout rate after adding
144
+ positional encoding
145
+ input_layer (str): input layer type.
146
+ optional [linear, conv2d, conv2d6, conv2d8]
147
+ pos_enc_layer_type (str): Encoder positional encoding layer type.
148
+                optional [abs_pos, scaled_abs_pos, rel_pos, no_pos]
149
+ normalize_before (bool):
150
+ True: use layer_norm before each sub-block of a layer.
151
+ False: use layer_norm after each sub-block of a layer.
152
+ static_chunk_size (int): chunk size for static chunk training and
153
+ decoding
154
+ use_dynamic_chunk (bool): whether use dynamic chunk size for
155
+                training or not. You can only use a fixed chunk (chunk_size > 0)
156
+                or a dynamic chunk size (use_dynamic_chunk = True).
157
+ global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
158
+ use_dynamic_left_chunk (bool): whether use dynamic left chunk in
159
+ dynamic chunk training
160
+ key_bias: whether use bias in attention.linear_k, False for whisper models.
161
+ gradient_checkpointing: rerunning a forward-pass segment for each
162
+ checkpointed segment during backward.
163
+ """
164
+ super().__init__()
165
+ self._output_size = output_size
166
+
167
+ self.global_cmvn = global_cmvn
168
+ self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
169
+ input_size,
170
+ output_size,
171
+ dropout_rate,
172
+ COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
173
+ positional_dropout_rate),
174
+ )
175
+
176
+ self.normalize_before = normalize_before
177
+ self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
178
+ self.static_chunk_size = static_chunk_size
179
+ self.use_dynamic_chunk = use_dynamic_chunk
180
+ self.use_dynamic_left_chunk = use_dynamic_left_chunk
181
+ self.gradient_checkpointing = gradient_checkpointing
182
+ activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
183
+ # self-attention module definition
184
+ encoder_selfattn_layer_args = (
185
+ attention_heads,
186
+ output_size,
187
+ attention_dropout_rate,
188
+ key_bias,
189
+ )
190
+ # feed-forward module definition
191
+ positionwise_layer_args = (
192
+ output_size,
193
+ linear_units,
194
+ dropout_rate,
195
+ activation,
196
+ )
197
+ # convolution module definition
198
+ convolution_layer_args = (output_size, cnn_module_kernel, activation,
199
+ cnn_module_norm, causal)
200
+ self.pre_lookahead_layer = PreLookaheadLayer(channels=512, pre_lookahead_len=3)
201
+ self.encoders = torch.nn.ModuleList([
202
+ ConformerEncoderLayer(
203
+ output_size,
204
+ COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
205
+ *encoder_selfattn_layer_args),
206
+ PositionwiseFeedForward(*positionwise_layer_args),
207
+ PositionwiseFeedForward(
208
+ *positionwise_layer_args) if macaron_style else None,
209
+ ConvolutionModule(
210
+ *convolution_layer_args) if use_cnn_module else None,
211
+ dropout_rate,
212
+ normalize_before,
213
+ ) for _ in range(num_blocks)
214
+ ])
215
+ self.up_layer = Upsample1D(channels=512, out_channels=512, stride=2)
216
+ self.up_embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
217
+ input_size,
218
+ output_size,
219
+ dropout_rate,
220
+ COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
221
+ positional_dropout_rate),
222
+ )
223
+ self.up_encoders = torch.nn.ModuleList([
224
+ ConformerEncoderLayer(
225
+ output_size,
226
+ COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
227
+ *encoder_selfattn_layer_args),
228
+ PositionwiseFeedForward(*positionwise_layer_args),
229
+ PositionwiseFeedForward(
230
+ *positionwise_layer_args) if macaron_style else None,
231
+ ConvolutionModule(
232
+ *convolution_layer_args) if use_cnn_module else None,
233
+ dropout_rate,
234
+ normalize_before,
235
+ ) for _ in range(4)
236
+ ])
237
+
238
+ def output_size(self) -> int:
239
+ return self._output_size
240
+
241
+ def forward(
242
+ self,
243
+ xs: torch.Tensor,
244
+ xs_lens: torch.Tensor,
245
+ decoding_chunk_size: int = 0,
246
+ num_decoding_left_chunks: int = -1,
247
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
248
+ """Embed positions in tensor.
249
+
250
+ Args:
251
+ xs: padded input tensor (B, T, D)
252
+ xs_lens: input length (B)
253
+ decoding_chunk_size: decoding chunk size for dynamic chunk
254
+ 0: default for training, use random dynamic chunk.
255
+ <0: for decoding, use full chunk.
256
+ >0: for decoding, use fixed chunk size as set.
257
+ num_decoding_left_chunks: number of left chunks, this is for decoding,
258
+ the chunk size is decoding_chunk_size.
259
+ >=0: use num_decoding_left_chunks
260
+ <0: use all left chunks
261
+ Returns:
262
+ encoder output tensor xs, and subsampled masks
263
+ xs: padded output tensor (B, T' ~= T/subsample_rate, D)
264
+ masks: torch.Tensor batch padding mask after subsample
265
+ (B, 1, T' ~= T/subsample_rate)
266
+ NOTE(xcsong):
267
+ We pass the `__call__` method of the modules instead of `forward` to the
268
+ checkpointing API because `__call__` attaches all the hooks of the module.
269
+ https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
270
+ """
271
+ T = xs.size(1)
272
+ masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
273
+ if self.global_cmvn is not None:
274
+ xs = self.global_cmvn(xs)
275
+ xs, pos_emb, masks = self.embed(xs, masks)
276
+ mask_pad = masks # (B, 1, T/subsample_rate)
277
+ chunk_masks = add_optional_chunk_mask(xs, masks,
278
+ self.use_dynamic_chunk,
279
+ self.use_dynamic_left_chunk,
280
+ decoding_chunk_size,
281
+ self.static_chunk_size,
282
+ num_decoding_left_chunks)
283
+ # lookahead + conformer encoder
284
+ xs = self.pre_lookahead_layer(xs)
285
+ xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
286
+
287
+ # upsample + conformer encoder
288
+ xs = xs.transpose(1, 2).contiguous()
289
+ xs, xs_lens = self.up_layer(xs, xs_lens)
290
+ xs = xs.transpose(1, 2).contiguous()
291
+ T = xs.size(1)
292
+ masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
293
+ xs, pos_emb, masks = self.up_embed(xs, masks)
294
+ mask_pad = masks # (B, 1, T/subsample_rate)
295
+ chunk_masks = add_optional_chunk_mask(xs, masks,
296
+ self.use_dynamic_chunk,
297
+ self.use_dynamic_left_chunk,
298
+ decoding_chunk_size,
299
+ self.static_chunk_size * self.up_layer.stride,
300
+ num_decoding_left_chunks)
301
+ xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad)
302
+
303
+ if self.normalize_before:
304
+ xs = self.after_norm(xs)
305
+ # Here we assume the mask is not changed in encoder layers, so just
306
+ # return the masks before encoder layers, and the masks will be used
307
+ # for cross attention with decoder later
308
+ return xs, masks
309
+
310
+ def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
311
+ pos_emb: torch.Tensor,
312
+ mask_pad: torch.Tensor) -> torch.Tensor:
313
+ for layer in self.encoders:
314
+ xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
315
+ return xs
316
+
317
+ def forward_up_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
318
+ pos_emb: torch.Tensor,
319
+ mask_pad: torch.Tensor) -> torch.Tensor:
320
+ for layer in self.up_encoders:
321
+ xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
322
+ return xs
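A hedged sketch of the token-rate doubling performed by Upsample1D: nearest-neighbour repetition by the stride, a left pad of `stride * 2`, then a convolution with kernel size `stride * 2 + 1`, so the output length is exactly `stride` times the input length. It assumes the cosyvoice package from this commit is importable; the shapes are illustrative.

```python
import torch
from cosyvoice.transformer.upsample_encoder import Upsample1D

up = Upsample1D(channels=512, out_channels=512, stride=2)
x = torch.randn(1, 512, 25)      # (batch, channels, tokens)
lengths = torch.tensor([25])
y, new_lengths = up(x, lengths)
# 25 tokens -> interpolate x2 -> pad 4 on the left -> kernel-5 conv -> 50 tokens
print(y.shape, new_lengths)      # torch.Size([1, 512, 50]) tensor([50])
```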
cosyvoice/utils/__init__.py ADDED
File without changes
cosyvoice/utils/class_utils.py ADDED
@@ -0,0 +1,70 @@
1
+ # Copyright [2023-11-28] <[email protected], Xingchen Song>
2
+ # 2024 Alibaba Inc (authors: Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import torch
16
+
17
+ from cosyvoice.transformer.activation import Swish
18
+ from cosyvoice.transformer.subsampling import (
19
+ LinearNoSubsampling,
20
+ EmbedinigNoSubsampling,
21
+ Conv1dSubsampling2,
22
+ Conv2dSubsampling4,
23
+ Conv2dSubsampling6,
24
+ Conv2dSubsampling8,
25
+ )
26
+ from cosyvoice.transformer.embedding import (PositionalEncoding,
27
+ RelPositionalEncoding,
28
+ WhisperPositionalEncoding,
29
+ LearnablePositionalEncoding,
30
+ NoPositionalEncoding)
31
+ from cosyvoice.transformer.attention import (MultiHeadedAttention,
32
+ RelPositionMultiHeadedAttention)
33
+ from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding
34
+ from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling
35
+
36
+
37
+ COSYVOICE_ACTIVATION_CLASSES = {
38
+ "hardtanh": torch.nn.Hardtanh,
39
+ "tanh": torch.nn.Tanh,
40
+ "relu": torch.nn.ReLU,
41
+ "selu": torch.nn.SELU,
42
+ "swish": getattr(torch.nn, "SiLU", Swish),
43
+ "gelu": torch.nn.GELU,
44
+ }
45
+
46
+ COSYVOICE_SUBSAMPLE_CLASSES = {
47
+ "linear": LinearNoSubsampling,
48
+ "linear_legacy": LegacyLinearNoSubsampling,
49
+ "embed": EmbedinigNoSubsampling,
50
+ "conv1d2": Conv1dSubsampling2,
51
+ "conv2d": Conv2dSubsampling4,
52
+ "conv2d6": Conv2dSubsampling6,
53
+ "conv2d8": Conv2dSubsampling8,
54
+ 'paraformer_dummy': torch.nn.Identity
55
+ }
56
+
57
+ COSYVOICE_EMB_CLASSES = {
58
+ "embed": PositionalEncoding,
59
+ "abs_pos": PositionalEncoding,
60
+ "rel_pos": RelPositionalEncoding,
61
+ "rel_pos_espnet": EspnetRelPositionalEncoding,
62
+ "no_pos": NoPositionalEncoding,
63
+ "abs_pos_whisper": WhisperPositionalEncoding,
64
+ "embed_learnable_pe": LearnablePositionalEncoding,
65
+ }
66
+
67
+ COSYVOICE_ATTENTION_CLASSES = {
68
+ "selfattn": MultiHeadedAttention,
69
+ "rel_selfattn": RelPositionMultiHeadedAttention,
70
+ }
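These registries are what let the YAML configs refer to components by string. A hedged sketch of how the encoders above resolve them; the constructor arguments mirror the positional calls in `encoder.py` and are otherwise illustrative.

```python
from cosyvoice.utils.class_utils import (
    COSYVOICE_EMB_CLASSES,
    COSYVOICE_ATTENTION_CLASSES,
    COSYVOICE_ACTIVATION_CLASSES,
)

# "rel_pos" / "rel_selfattn" / "swish" are the defaults used by ConformerEncoder.
pos_enc = COSYVOICE_EMB_CLASSES["rel_pos"](256, 0.1)                 # RelPositionalEncoding
attn = COSYVOICE_ATTENTION_CLASSES["rel_selfattn"](4, 256, 0.0, True)
act = COSYVOICE_ACTIVATION_CLASSES["swish"]()                        # torch.nn.SiLU on modern PyTorch
print(type(pos_enc).__name__, type(attn).__name__, type(act).__name__)
```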
cosyvoice/utils/common.py ADDED
@@ -0,0 +1,166 @@
1
+ # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
2
+ # 2024 Alibaba Inc (authors: Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """Utility functions for Transformer."""
17
+
18
+ import random
19
+ from typing import List
20
+
21
+ import numpy as np
22
+ import torch
23
+
24
+ IGNORE_ID = -1
25
+
26
+
27
+ def pad_list(xs: List[torch.Tensor], pad_value: int):
28
+ """Perform padding for the list of tensors.
29
+
30
+ Args:
31
+ xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
32
+ pad_value (float): Value for padding.
33
+
34
+ Returns:
35
+ Tensor: Padded tensor (B, Tmax, `*`).
36
+
37
+ Examples:
38
+ >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
39
+ >>> x
40
+ [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
41
+ >>> pad_list(x, 0)
42
+ tensor([[1., 1., 1., 1.],
43
+ [1., 1., 0., 0.],
44
+ [1., 0., 0., 0.]])
45
+
46
+ """
47
+ max_len = max([len(item) for item in xs])
48
+ batchs = len(xs)
49
+ ndim = xs[0].ndim
50
+ if ndim == 1:
51
+ pad_res = torch.zeros(batchs,
52
+ max_len,
53
+ dtype=xs[0].dtype,
54
+ device=xs[0].device)
55
+ elif ndim == 2:
56
+ pad_res = torch.zeros(batchs,
57
+ max_len,
58
+ xs[0].shape[1],
59
+ dtype=xs[0].dtype,
60
+ device=xs[0].device)
61
+ elif ndim == 3:
62
+ pad_res = torch.zeros(batchs,
63
+ max_len,
64
+ xs[0].shape[1],
65
+ xs[0].shape[2],
66
+ dtype=xs[0].dtype,
67
+ device=xs[0].device)
68
+ else:
69
+ raise ValueError(f"Unsupported ndim: {ndim}")
70
+ pad_res.fill_(pad_value)
71
+ for i in range(batchs):
72
+ pad_res[i, :len(xs[i])] = xs[i]
73
+ return pad_res
74
+
75
+
76
+ def th_accuracy(pad_outputs: torch.Tensor, pad_targets: torch.Tensor,
77
+ ignore_label: int) -> torch.Tensor:
78
+ """Calculate accuracy.
79
+
80
+ Args:
81
+ pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
82
+ pad_targets (LongTensor): Target label tensors (B, Lmax).
83
+ ignore_label (int): Ignore label id.
84
+
85
+ Returns:
86
+ torch.Tensor: Accuracy value (0.0 - 1.0).
87
+
88
+ """
89
+ pad_pred = pad_outputs.view(pad_targets.size(0), pad_targets.size(1),
90
+ pad_outputs.size(1)).argmax(2)
91
+ mask = pad_targets != ignore_label
92
+ numerator = torch.sum(
93
+ pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
94
+ denominator = torch.sum(mask)
95
+ return (numerator / denominator).detach()
96
+
97
+
98
+ def get_padding(kernel_size, dilation=1):
99
+ return int((kernel_size * dilation - dilation) / 2)
100
+
101
+
102
+ def init_weights(m, mean=0.0, std=0.01):
103
+ classname = m.__class__.__name__
104
+ if classname.find("Conv") != -1:
105
+ m.weight.data.normal_(mean, std)
106
+
107
+
108
+ # Repetition Aware Sampling in VALL-E 2
109
+ def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, top_k=25, win_size=10, tau_r=0.1):
110
+ top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
111
+ rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids).sum().item()
112
+ if rep_num >= win_size * tau_r:
113
+ top_ids = random_sampling(weighted_scores, decoded_tokens, sampling)
114
+ return top_ids
115
+
116
+
117
+ def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
118
+ prob, indices = [], []
119
+ cum_prob = 0.0
120
+ sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True)
121
+ for i in range(len(sorted_idx)):
122
+ # sampling both top-p and numbers.
123
+ if cum_prob < top_p and len(prob) < top_k:
124
+ cum_prob += sorted_value[i]
125
+ prob.append(sorted_value[i])
126
+ indices.append(sorted_idx[i])
127
+ else:
128
+ break
129
+ prob = torch.tensor(prob).to(weighted_scores)
130
+ indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device)
131
+ top_ids = indices[prob.multinomial(1, replacement=True)]
132
+ return top_ids
133
+
134
+
135
+ def random_sampling(weighted_scores, decoded_tokens, sampling):
136
+ top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
137
+ return top_ids
138
+
139
+
140
+ def fade_in_out(fade_in_mel, fade_out_mel, window):
141
+ device = fade_in_mel.device
142
+ fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
143
+ mel_overlap_len = int(window.shape[0] / 2)
144
+ if fade_in_mel.device == torch.device('cpu'):
145
+ fade_in_mel = fade_in_mel.clone()
146
+ fade_in_mel[..., :mel_overlap_len] = fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \
147
+ fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
148
+ return fade_in_mel.to(device)
149
+
150
+
151
+ def set_all_random_seed(seed):
152
+ random.seed(seed)
153
+ np.random.seed(seed)
154
+ torch.manual_seed(seed)
155
+ torch.cuda.manual_seed_all(seed)
156
+
157
+
158
+ def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
159
+ assert mask.dtype == torch.bool
160
+ assert dtype in [torch.float32, torch.bfloat16, torch.float16]
161
+ mask = mask.to(dtype)
162
+ # attention mask bias
163
+ # NOTE(Mddct): torch.finfo jit issues
164
+ # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
165
+ mask = (1.0 - mask) * torch.finfo(dtype).min
166
+ return mask
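A hedged usage sketch of repetition-aware sampling: `ras_sampling` first draws a candidate with nucleus (top-p/top-k) sampling and, when that candidate already fills at least `win_size * tau_r` of the last `win_size` decoded tokens, falls back to sampling from the full softmax. The toy logits and token history below are illustrative.

```python
import torch
from cosyvoice.utils.common import ras_sampling

torch.manual_seed(0)
vocab_size = 8
weighted_scores = torch.randn(vocab_size)        # unnormalized logits for the next token
decoded_tokens = [3] * 10                        # heavily repeated history window
# sampling is only forwarded to random_sampling, which ignores it here.
next_id = ras_sampling(weighted_scores, decoded_tokens, sampling=None,
                       top_p=0.8, top_k=25, win_size=10, tau_r=0.1)
print(next_id)                                   # 1-element LongTensor with the chosen token id
```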
cosyvoice/utils/executor.py ADDED
@@ -0,0 +1,172 @@
+ # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
+ #               2024 Alibaba Inc (authors: Xiang Lyu)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ 
+ import logging
+ from contextlib import nullcontext
+ import os
+ 
+ import torch
+ import torch.distributed as dist
+ 
+ from cosyvoice.utils.train_utils import update_parameter_and_lr, log_per_step, log_per_save, batch_forward, batch_backward, save_model, cosyvoice_join
+ 
+ 
+ class Executor:
+ 
+     def __init__(self, gan: bool = False):
+         self.gan = gan
+         self.step = 0
+         self.epoch = 0
+         self.rank = int(os.environ.get('RANK', 0))
+         self.device = torch.device('cuda:{}'.format(self.rank))
+ 
+     def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join):
+         ''' Train one epoch
+         '''
+ 
+         lr = optimizer.param_groups[0]['lr']
+         logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank))
+         logging.info('using accumulate grad, new batch size is {} times'
+                      ' larger than before'.format(info_dict['accum_grad']))
+         # A context manager to be used in conjunction with an instance of
+         # torch.nn.parallel.DistributedDataParallel to be able to train
+         # with uneven inputs across participating processes.
+         model.train()
+         model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext
+         with model_context():
+             for batch_idx, batch_dict in enumerate(train_data_loader):
+                 info_dict["tag"] = "TRAIN"
+                 info_dict["step"] = self.step
+                 info_dict["epoch"] = self.epoch
+                 info_dict["batch_idx"] = batch_idx
+                 if cosyvoice_join(group_join, info_dict):
+                     break
+ 
+                 # Disable gradient synchronizations across DDP processes.
+                 # Within this context, gradients will be accumulated on module
+                 # variables, which will later be synchronized.
+                 if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0:
+                     context = model.no_sync
+                 # Used for single gpu training and DDP gradient synchronization
+                 # processes.
+                 else:
+                     context = nullcontext
+ 
+                 with context():
+                     info_dict = batch_forward(model, batch_dict, scaler, info_dict)
+                     info_dict = batch_backward(model, scaler, info_dict)
+ 
+                 info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict)
+                 log_per_step(writer, info_dict)
+                 # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
+                 if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \
+                         (batch_idx + 1) % info_dict["accum_grad"] == 0:
+                     dist.barrier()
+                     self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False)
+                     model.train()
+                 if (batch_idx + 1) % info_dict["accum_grad"] == 0:
+                     self.step += 1
+         dist.barrier()
+         self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True)
+ 
+     def train_one_epoc_gan(self, model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
+                            writer, info_dict, scaler, group_join):
+         ''' Train one epoch
+         '''
+ 
+         lr = optimizer.param_groups[0]['lr']
+         logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank))
+         logging.info('using accumulate grad, new batch size is {} times'
+                      ' larger than before'.format(info_dict['accum_grad']))
+         # A context manager to be used in conjunction with an instance of
+         # torch.nn.parallel.DistributedDataParallel to be able to train
+         # with uneven inputs across participating processes.
+         model.train()
+         model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext
+         with model_context():
+             for batch_idx, batch_dict in enumerate(train_data_loader):
+                 info_dict["tag"] = "TRAIN"
+                 info_dict["step"] = self.step
+                 info_dict["epoch"] = self.epoch
+                 info_dict["batch_idx"] = batch_idx
+                 if cosyvoice_join(group_join, info_dict):
+                     break
+ 
+                 # Disable gradient synchronizations across DDP processes.
+                 # Within this context, gradients will be accumulated on module
+                 # variables, which will later be synchronized.
+                 if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0:
+                     context = model.no_sync
+                 # Used for single gpu training and DDP gradient synchronization
+                 # processes.
+                 else:
+                     context = nullcontext
+ 
+                 with context():
+                     batch_dict['turn'] = 'discriminator'
+                     info_dict = batch_forward(model, batch_dict, scaler, info_dict)
+                     info_dict = batch_backward(model, scaler, info_dict)
+                 info_dict = update_parameter_and_lr(model, optimizer_d, scheduler_d, scaler, info_dict)
+                 optimizer.zero_grad()
+                 log_per_step(writer, info_dict)
+                 with context():
+                     batch_dict['turn'] = 'generator'
+                     info_dict = batch_forward(model, batch_dict, scaler, info_dict)
+                     info_dict = batch_backward(model, scaler, info_dict)
+                 info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict)
+                 optimizer_d.zero_grad()
+                 log_per_step(writer, info_dict)
+                 # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
+                 if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \
+                         (batch_idx + 1) % info_dict["accum_grad"] == 0:
+                     dist.barrier()
+                     self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False)
+                     model.train()
+                 if (batch_idx + 1) % info_dict["accum_grad"] == 0:
+                     self.step += 1
+         dist.barrier()
+         self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True)
+ 
+     @torch.inference_mode()
+     def cv(self, model, cv_data_loader, writer, info_dict, on_batch_end=True):
+         ''' Cross validation on
+         '''
+         logging.info('Epoch {} Step {} on_batch_end {} CV rank {}'.format(self.epoch, self.step + 1, on_batch_end, self.rank))
+         model.eval()
+         total_num_utts, total_loss_dict = 0, {}  # avoid division by 0
+         for batch_idx, batch_dict in enumerate(cv_data_loader):
+             info_dict["tag"] = "CV"
+             info_dict["step"] = self.step
+             info_dict["epoch"] = self.epoch
+             info_dict["batch_idx"] = batch_idx
+ 
+             num_utts = len(batch_dict["utts"])
+             total_num_utts += num_utts
+ 
+             if self.gan is True:
+                 batch_dict['turn'] = 'generator'
+             info_dict = batch_forward(model, batch_dict, None, info_dict)
+ 
+             for k, v in info_dict['loss_dict'].items():
+                 if k not in total_loss_dict:
+                     total_loss_dict[k] = []
+                 total_loss_dict[k].append(v.item() * num_utts)
+             log_per_step(None, info_dict)
+         for k, v in total_loss_dict.items():
+             total_loss_dict[k] = sum(v) / total_num_utts
+         info_dict['loss_dict'] = total_loss_dict
+         log_per_save(writer, info_dict)
+         model_name = 'epoch_{}_whole'.format(self.epoch) if on_batch_end else 'epoch_{}_step_{}'.format(self.epoch, self.step + 1)
+         save_model(model, model_name, info_dict)
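For readers new to the `accum_grad` bookkeeping above: the sketch below is a toy, single-process illustration of the same gradient-accumulation pattern, with the DDP `no_sync()` context replaced by `nullcontext()`. It is not the project's training entry point (that wiring, including `info_dict`, the dataloaders, and the join group, lives in cosyvoice/bin/train.py and cosyvoice/utils/train_utils.py).

```python
# Toy illustration: accumulate gradients for accum_grad micro-batches, then
# step the optimizer and advance the global step once per boundary batch.
from contextlib import nullcontext
import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accum_grad, step = 2, 0

for batch_idx in range(6):
    x, y = torch.randn(8, 4), torch.randn(8, 1)
    context = nullcontext  # under DDP this would be model.no_sync on non-boundary batches
    with context():
        loss = torch.nn.functional.mse_loss(model(x), y) / accum_grad
        loss.backward()
    if (batch_idx + 1) % accum_grad == 0:
        optimizer.step()
        optimizer.zero_grad()
        step += 1

print('optimizer steps taken:', step)  # 3
```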
cosyvoice/utils/file_utils.py ADDED
@@ -0,0 +1,51 @@
+ # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
+ #               2024 Alibaba Inc (authors: Xiang Lyu)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ 
+ import json
+ import torchaudio
+ import logging
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
+ logging.basicConfig(level=logging.DEBUG,
+                     format='%(asctime)s %(levelname)s %(message)s')
+ 
+ 
+ def read_lists(list_file):
+     lists = []
+     with open(list_file, 'r', encoding='utf8') as fin:
+         for line in fin:
+             lists.append(line.strip())
+     return lists
+ 
+ 
+ def read_json_lists(list_file):
+     lists = read_lists(list_file)
+     results = {}
+     for fn in lists:
+         with open(fn, 'r', encoding='utf8') as fin:
+             results.update(json.load(fin))
+     return results
+ 
+ 
+ def load_wav(wav, target_sr):
+     # speech, sample_rate = torchaudio.load(wav)
+     # speech = speech.mean(dim=0, keepdim=True)
+     # if sample_rate != target_sr:
+     #     assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
+     #     speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
+ 
+     import librosa, torch
+     speech, _ = librosa.load(path=wav, sr=target_sr)
+     speech = torch.from_numpy(speech).unsqueeze(dim=0)
+     return speech
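A minimal usage sketch for `load_wav`; the file path is a placeholder and 16000 Hz is only an example target rate.

```python
from cosyvoice.utils.file_utils import load_wav

# 'prompt.wav' is a placeholder path; librosa loads it as mono and resamples to 16 kHz.
speech = load_wav('prompt.wav', 16000)
print(speech.shape, speech.dtype)  # torch.Size([1, num_samples]) torch.float32
```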
cosyvoice/utils/frontend_utils.py ADDED
@@ -0,0 +1,129 @@
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ 
+ import re
+ chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')
+ 
+ 
+ # whether the text contains any Chinese character
+ def contains_chinese(text):
+     return bool(chinese_char_pattern.search(text))
+ 
+ 
+ # replace special symbols
+ def replace_corner_mark(text):
+     text = text.replace('²', '平方')
+     text = text.replace('³', '立方')
+     return text
+ 
+ 
+ # remove meaningless symbols
+ def remove_bracket(text):
+     text = text.replace('(', '').replace(')', '')
+     text = text.replace('【', '').replace('】', '')
+     text = text.replace('`', '').replace('`', '')
+     text = text.replace("——", " ")
+     return text
+ 
+ 
+ # spell out Arabic numerals
+ def spell_out_number(text: str, inflect_parser):
+     new_text = []
+     st = None
+     for i, c in enumerate(text):
+         if not c.isdigit():
+             if st is not None:
+                 num_str = inflect_parser.number_to_words(text[st: i])
+                 new_text.append(num_str)
+                 st = None
+             new_text.append(c)
+         else:
+             if st is None:
+                 st = i
+     if st is not None and st < len(text):
+         num_str = inflect_parser.number_to_words(text[st:])
+         new_text.append(num_str)
+     return ''.join(new_text)
+ 
+ 
+ # split paragraph logic:
+ # 1. each sentence has max len token_max_n and min len token_min_n; merge the last sentence if it is shorter than merge_len
+ # 2. calculate sentence length according to lang
+ # 3. split sentences on punctuation
+ def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False):
+     def calc_utt_length(_text: str):
+         if lang == "zh":
+             return len(_text)
+         else:
+             return len(tokenize(_text))
+ 
+     def should_merge(_text: str):
+         if lang == "zh":
+             return len(_text) < merge_len
+         else:
+             return len(tokenize(_text)) < merge_len
+ 
+     if lang == "zh":
+         pounc = ['。', '?', '!', ';', ':', '、', '.', '?', '!', ';']
+     else:
+         pounc = ['.', '?', '!', ';', ':']
+     if comma_split:
+         pounc.extend([',', ','])
+ 
+     if text[-1] not in pounc:
+         if lang == "zh":
+             text += "。"
+         else:
+             text += "."
+ 
+     st = 0
+     utts = []
+     for i, c in enumerate(text):
+         if c in pounc:
+             if len(text[st: i]) > 0:
+                 utts.append(text[st: i] + c)
+             if i + 1 < len(text) and text[i + 1] in ['"', '”']:
+                 tmp = utts.pop(-1)
+                 utts.append(tmp + text[i + 1])
+                 st = i + 2
+             else:
+                 st = i + 1
+ 
+     final_utts = []
+     cur_utt = ""
+     for utt in utts:
+         if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n:
+             final_utts.append(cur_utt)
+             cur_utt = ""
+         cur_utt = cur_utt + utt
+     if len(cur_utt) > 0:
+         if should_merge(cur_utt) and len(final_utts) != 0:
+             final_utts[-1] = final_utts[-1] + cur_utt
+         else:
+             final_utts.append(cur_utt)
+ 
+     return final_utts
+ 
+ 
+ # remove blanks between Chinese characters
+ def replace_blank(text: str):
+     out_str = []
+     for i, c in enumerate(text):
+         if c == " ":
+             if ((text[i + 1].isascii() and text[i + 1] != " ") and
+                     (text[i - 1].isascii() and text[i - 1] != " ")):
+                 out_str.append(c)
+         else:
+             out_str.append(c)
+     return "".join(out_str)
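To illustrate how these text-frontend helpers behave, here is a small, self-contained sketch. The whitespace `fake_tokenize` and the threshold values are stand-ins chosen for the demo, not the project's tiktoken tokenizer or its defaults; the project's own call sites live in cosyvoice/cli/frontend.py.

```python
from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, split_paragraph

text = "CosyVoice 支持 中文 and English. It splits long input into shorter utterances!"
print(contains_chinese(text))        # True
print(replace_blank(text))           # drops the blanks around the Chinese words, keeps ASCII-ASCII blanks

fake_tokenize = lambda s: s.split()  # stand-in tokenizer: whitespace word split
utts = split_paragraph(text, fake_tokenize, lang="en",
                       token_max_n=10, token_min_n=4, merge_len=3)
print(utts)                          # two utterances with these toy thresholds
```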