Spaces:
Running
on
L4
Running
on
L4
init push
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- README.md +231 -14
- app.py +211 -0
- cert.pem +32 -0
- cosyvoice/__init__.py +0 -0
- cosyvoice/bin/average_model.py +92 -0
- cosyvoice/bin/convert.py +168 -0
- cosyvoice/bin/export_jit.py +74 -0
- cosyvoice/bin/export_jit_cosyvoice2.py +60 -0
- cosyvoice/bin/export_onnx.py +112 -0
- cosyvoice/bin/export_onnx_cosyvoice2.py +110 -0
- cosyvoice/bin/export_trt_cosyvoce2.sh +3 -0
- cosyvoice/bin/inference.py +115 -0
- cosyvoice/bin/train.py +170 -0
- cosyvoice/cli/__init__.py +0 -0
- cosyvoice/cli/cosyvoice.py +167 -0
- cosyvoice/cli/frontend.py +213 -0
- cosyvoice/cli/model.py +421 -0
- cosyvoice/dataset/__init__.py +0 -0
- cosyvoice/dataset/dataset.py +164 -0
- cosyvoice/dataset/processor.py +431 -0
- cosyvoice/flow/decoder.py +299 -0
- cosyvoice/flow/flow.py +232 -0
- cosyvoice/flow/flow_matching.py +235 -0
- cosyvoice/flow/length_regulator.py +69 -0
- cosyvoice/hifigan/discriminator.py +140 -0
- cosyvoice/hifigan/f0_predictor.py +55 -0
- cosyvoice/hifigan/generator.py +411 -0
- cosyvoice/hifigan/hifigan.py +67 -0
- cosyvoice/llm/llm.py +340 -0
- cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +0 -0
- cosyvoice/tokenizer/tokenizer.py +277 -0
- cosyvoice/transformer/__init__.py +0 -0
- cosyvoice/transformer/activation.py +84 -0
- cosyvoice/transformer/attention.py +330 -0
- cosyvoice/transformer/convolution.py +145 -0
- cosyvoice/transformer/decoder.py +396 -0
- cosyvoice/transformer/decoder_layer.py +132 -0
- cosyvoice/transformer/embedding.py +294 -0
- cosyvoice/transformer/encoder.py +474 -0
- cosyvoice/transformer/encoder_layer.py +236 -0
- cosyvoice/transformer/label_smoothing_loss.py +96 -0
- cosyvoice/transformer/positionwise_feed_forward.py +115 -0
- cosyvoice/transformer/subsampling.py +383 -0
- cosyvoice/transformer/upsample_encoder.py +322 -0
- cosyvoice/utils/__init__.py +0 -0
- cosyvoice/utils/class_utils.py +70 -0
- cosyvoice/utils/common.py +166 -0
- cosyvoice/utils/executor.py +172 -0
- cosyvoice/utils/file_utils.py +51 -0
- cosyvoice/utils/frontend_utils.py +129 -0
README.md
CHANGED
@@ -1,14 +1,231 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[![SVG Banners](https://svg-banners.vercel.app/api?type=origin&text1=CosyVoice🤠&text2=Text-to-Speech%20💖%20Large%20Language%20Model&width=800&height=210)](https://github.com/Akshay090/svg-banners)
|
2 |
+
|
3 |
+
## 👉🏻 CosyVoice 👈🏻
|
4 |
+
**CosyVoice 2.0**: [Demos](https://funaudiollm.github.io/cosyvoice2/); [Paper](https://funaudiollm.github.io/pdf/CosyVoice_2.pdf); [Modelscope](https://www.modelscope.cn/studios/iic/CosyVoice2-0.5B)
|
5 |
+
|
6 |
+
**CosyVoice 1.0**: [Demos](https://fun-audio-llm.github.io); [Paper](https://funaudiollm.github.io/pdf/CosyVoice_v1.pdf); [Modelscope](https://www.modelscope.cn/studios/iic/CosyVoice-300M)
|
7 |
+
|
8 |
+
## Highlight🔥
|
9 |
+
|
10 |
+
**CosyVoice 2.0** has been released! Compared to version 1.0, the new version offers more accurate, more stable, faster, and better speech generation capabilities.
|
11 |
+
### Multilingual
|
12 |
+
- **Support Language**: Chinese, English, Japanese, Korean, Chinese dialects (Cantonese, Sichuanese, Shanghainese, Tianjinese, Wuhanese, etc.)
|
13 |
+
- **Crosslingual & Mixlingual**:Support zero-shot voice cloning for cross-lingual and code-switching scenarios.
|
14 |
+
### Ultra-Low Latency
|
15 |
+
- **Bidirectional Streaming Support**: CosyVoice 2.0 integrates offline and streaming modeling technologies.
|
16 |
+
- **Rapid First Packet Synthesis**: Achieves latency as low as 150ms while maintaining high-quality audio output.
|
17 |
+
### High Accuracy
|
18 |
+
- **Improved Pronunciation**: Reduces pronunciation errors by 30% to 50% compared to CosyVoice 1.0.
|
19 |
+
- **Benchmark Achievements**: Attains the lowest character error rate on the hard test set of the Seed-TTS evaluation set.
|
20 |
+
### Strong Stability
|
21 |
+
- **Consistency in Timbre**: Ensures reliable voice consistency for zero-shot and cross-language speech synthesis.
|
22 |
+
- **Cross-language Synthesis**: Marked improvements compared to version 1.0.
|
23 |
+
### Natural Experience
|
24 |
+
- **Enhanced Prosody and Sound Quality**: Improved alignment of synthesized audio, raising MOS evaluation scores from 5.4 to 5.53.
|
25 |
+
- **Emotional and Dialectal Flexibility**: Now supports more granular emotional controls and accent adjustments.
|
26 |
+
|
27 |
+
## Roadmap
|
28 |
+
|
29 |
+
- [x] 2024/12
|
30 |
+
|
31 |
+
- [x] 25hz cosyvoice 2.0 released
|
32 |
+
|
33 |
+
- [x] 2024/09
|
34 |
+
|
35 |
+
- [x] 25hz cosyvoice base model
|
36 |
+
- [x] 25hz cosyvoice voice conversion model
|
37 |
+
|
38 |
+
- [x] 2024/08
|
39 |
+
|
40 |
+
- [x] Repetition Aware Sampling(RAS) inference for llm stability
|
41 |
+
- [x] Streaming inference mode support, including kv cache and sdpa for rtf optimization
|
42 |
+
|
43 |
+
- [x] 2024/07
|
44 |
+
|
45 |
+
- [x] Flow matching training support
|
46 |
+
- [x] WeTextProcessing support when ttsfrd is not avaliable
|
47 |
+
- [x] Fastapi server and client
|
48 |
+
|
49 |
+
|
50 |
+
## Install
|
51 |
+
|
52 |
+
**Clone and install**
|
53 |
+
|
54 |
+
- Clone the repo
|
55 |
+
``` sh
|
56 |
+
git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
|
57 |
+
# If you failed to clone submodule due to network failures, please run following command until success
|
58 |
+
cd CosyVoice
|
59 |
+
git submodule update --init --recursive
|
60 |
+
```
|
61 |
+
|
62 |
+
- Install Conda: please see https://docs.conda.io/en/latest/miniconda.html
|
63 |
+
- Create Conda env:
|
64 |
+
|
65 |
+
``` sh
|
66 |
+
conda create -n cosyvoice python=3.8
|
67 |
+
conda activate cosyvoice
|
68 |
+
# pynini is required by WeTextProcessing, use conda to install it as it can be executed on all platform.
|
69 |
+
conda install -y -c conda-forge pynini==2.1.5
|
70 |
+
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
|
71 |
+
|
72 |
+
# If you encounter sox compatibility issues
|
73 |
+
# ubuntu
|
74 |
+
sudo apt-get install sox libsox-dev
|
75 |
+
# centos
|
76 |
+
sudo yum install sox sox-devel
|
77 |
+
```
|
78 |
+
|
79 |
+
**Model download**
|
80 |
+
|
81 |
+
We strongly recommend that you download our pretrained `CosyVoice-300M` `CosyVoice-300M-SFT` `CosyVoice-300M-Instruct` model and `CosyVoice-ttsfrd` resource.
|
82 |
+
|
83 |
+
If you are expert in this field, and you are only interested in training your own CosyVoice model from scratch, you can skip this step.
|
84 |
+
|
85 |
+
``` python
|
86 |
+
# SDK模型下载
|
87 |
+
from modelscope import snapshot_download
|
88 |
+
snapshot_download('iic/CosyVoice2-0.5B', local_dir='pretrained_models/CosyVoice2-0.5B')
|
89 |
+
snapshot_download('iic/CosyVoice-300M', local_dir='pretrained_models/CosyVoice-300M')
|
90 |
+
snapshot_download('iic/CosyVoice-300M-25Hz', local_dir='pretrained_models/CosyVoice-300M-25Hz')
|
91 |
+
snapshot_download('iic/CosyVoice-300M-SFT', local_dir='pretrained_models/CosyVoice-300M-SFT')
|
92 |
+
snapshot_download('iic/CosyVoice-300M-Instruct', local_dir='pretrained_models/CosyVoice-300M-Instruct')
|
93 |
+
snapshot_download('iic/CosyVoice-ttsfrd', local_dir='pretrained_models/CosyVoice-ttsfrd')
|
94 |
+
```
|
95 |
+
|
96 |
+
``` sh
|
97 |
+
# git模型下载,请确保已安装git lfs
|
98 |
+
mkdir -p pretrained_models
|
99 |
+
git clone https://www.modelscope.cn/iic/CosyVoice2-0.5B.git pretrained_models/CosyVoice2-0.5B
|
100 |
+
git clone https://www.modelscope.cn/iic/CosyVoice-300M.git pretrained_models/CosyVoice-300M
|
101 |
+
git clone https://www.modelscope.cn/iic/CosyVoice-300M-25Hz.git pretrained_models/CosyVoice-300M-25Hz
|
102 |
+
git clone https://www.modelscope.cn/iic/CosyVoice-300M-SFT.git pretrained_models/CosyVoice-300M-SFT
|
103 |
+
git clone https://www.modelscope.cn/iic/CosyVoice-300M-Instruct.git pretrained_models/CosyVoice-300M-Instruct
|
104 |
+
git clone https://www.modelscope.cn/iic/CosyVoice-ttsfrd.git pretrained_models/CosyVoice-ttsfrd
|
105 |
+
```
|
106 |
+
|
107 |
+
Optionaly, you can unzip `ttsfrd` resouce and install `ttsfrd` package for better text normalization performance.
|
108 |
+
|
109 |
+
Notice that this step is not necessary. If you do not install `ttsfrd` package, we will use WeTextProcessing by default.
|
110 |
+
|
111 |
+
``` sh
|
112 |
+
cd pretrained_models/CosyVoice-ttsfrd/
|
113 |
+
unzip resource.zip -d .
|
114 |
+
pip install ttsfrd_dependency-0.1-py3-none-any.whl
|
115 |
+
pip install ttsfrd-0.4.2-cp310-cp310-linux_x86_64.whl
|
116 |
+
```
|
117 |
+
|
118 |
+
**Basic Usage**
|
119 |
+
|
120 |
+
We strongly recommend using `CosyVoice2-0.5B` for better performance.
|
121 |
+
For zero_shot/cross_lingual inference, please use `CosyVoice-300M` model.
|
122 |
+
For sft inference, please use `CosyVoice-300M-SFT` model.
|
123 |
+
For instruct inference, please use `CosyVoice-300M-Instruct` model.
|
124 |
+
First, add `third_party/Matcha-TTS` to your `PYTHONPATH`.
|
125 |
+
|
126 |
+
``` sh
|
127 |
+
export PYTHONPATH=third_party/Matcha-TTS
|
128 |
+
```
|
129 |
+
|
130 |
+
``` python
|
131 |
+
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
|
132 |
+
from cosyvoice.utils.file_utils import load_wav
|
133 |
+
import torchaudio
|
134 |
+
```
|
135 |
+
|
136 |
+
**CosyVoice2 Usage**
|
137 |
+
```python
|
138 |
+
cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=True, load_onnx=False, load_trt=False)
|
139 |
+
|
140 |
+
# zero_shot usage
|
141 |
+
prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)
|
142 |
+
for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
|
143 |
+
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
144 |
+
|
145 |
+
# instruct usage
|
146 |
+
for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '用四川话说这句话', prompt_speech_16k, stream=False)):
|
147 |
+
torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
148 |
+
```
|
149 |
+
|
150 |
+
**CosyVoice Usage**
|
151 |
+
```python
|
152 |
+
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT', load_jit=True, load_onnx=False, fp16=True)
|
153 |
+
# sft usage
|
154 |
+
print(cosyvoice.list_avaliable_spks())
|
155 |
+
# change stream=True for chunk stream inference
|
156 |
+
for i, j in enumerate(cosyvoice.inference_sft('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '中文女', stream=False)):
|
157 |
+
torchaudio.save('sft_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
158 |
+
|
159 |
+
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-25Hz') # or change to pretrained_models/CosyVoice-300M for 50Hz inference
|
160 |
+
# zero_shot usage, <|zh|><|en|><|jp|><|yue|><|ko|> for Chinese/English/Japanese/Cantonese/Korean
|
161 |
+
prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)
|
162 |
+
for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
|
163 |
+
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
164 |
+
# cross_lingual usage
|
165 |
+
prompt_speech_16k = load_wav('cross_lingual_prompt.wav', 16000)
|
166 |
+
for i, j in enumerate(cosyvoice.inference_cross_lingual('<|en|>And then later on, fully acquiring that company. So keeping management in line, interest in line with the asset that\'s coming into the family is a reason why sometimes we don\'t buy the whole thing.', prompt_speech_16k, stream=False)):
|
167 |
+
torchaudio.save('cross_lingual_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
168 |
+
# vc usage
|
169 |
+
prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)
|
170 |
+
source_speech_16k = load_wav('cross_lingual_prompt.wav', 16000)
|
171 |
+
for i, j in enumerate(cosyvoice.inference_vc(source_speech_16k, prompt_speech_16k, stream=False)):
|
172 |
+
torchaudio.save('vc_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
173 |
+
|
174 |
+
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-Instruct')
|
175 |
+
# instruct usage, support <laughter></laughter><strong></strong>[laughter][breath]
|
176 |
+
for i, j in enumerate(cosyvoice.inference_instruct('在面对挑战时,他展现了非凡的<strong>勇气</strong>与<strong>智慧</strong>。', '中文男', 'Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.', stream=False)):
|
177 |
+
torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
178 |
+
```
|
179 |
+
|
180 |
+
**Start web demo**
|
181 |
+
|
182 |
+
You can use our web demo page to get familiar with CosyVoice quickly.
|
183 |
+
We support sft/zero_shot/cross_lingual/instruct inference in web demo.
|
184 |
+
|
185 |
+
Please see the demo website for details.
|
186 |
+
|
187 |
+
``` python
|
188 |
+
# change iic/CosyVoice-300M-SFT for sft inference, or iic/CosyVoice-300M-Instruct for instruct inference
|
189 |
+
python3 webui.py --port 50000 --model_dir pretrained_models/CosyVoice-300M
|
190 |
+
```
|
191 |
+
|
192 |
+
**Advanced Usage**
|
193 |
+
|
194 |
+
For advanced user, we have provided train and inference scripts in `examples/libritts/cosyvoice/run.sh`.
|
195 |
+
You can get familiar with CosyVoice following this recipie.
|
196 |
+
|
197 |
+
**Build for deployment**
|
198 |
+
|
199 |
+
Optionally, if you want to use grpc for service deployment,
|
200 |
+
you can run following steps. Otherwise, you can just ignore this step.
|
201 |
+
|
202 |
+
``` sh
|
203 |
+
cd runtime/python
|
204 |
+
docker build -t cosyvoice:v1.0 .
|
205 |
+
# change iic/CosyVoice-300M to iic/CosyVoice-300M-Instruct if you want to use instruct inference
|
206 |
+
# for grpc usage
|
207 |
+
docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/grpc && python3 server.py --port 50000 --max_conc 4 --model_dir iic/CosyVoice-300M && sleep infinity"
|
208 |
+
cd grpc && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
|
209 |
+
# for fastapi usage
|
210 |
+
docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/fastapi && python3 server.py --port 50000 --model_dir iic/CosyVoice-300M && sleep infinity"
|
211 |
+
cd fastapi && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
|
212 |
+
```
|
213 |
+
|
214 |
+
## Discussion & Communication
|
215 |
+
|
216 |
+
You can directly discuss on [Github Issues](https://github.com/FunAudioLLM/CosyVoice/issues).
|
217 |
+
|
218 |
+
You can also scan the QR code to join our official Dingding chat group.
|
219 |
+
|
220 |
+
<img src="./asset/dingding.png" width="250px">
|
221 |
+
|
222 |
+
## Acknowledge
|
223 |
+
|
224 |
+
1. We borrowed a lot of code from [FunASR](https://github.com/modelscope/FunASR).
|
225 |
+
2. We borrowed a lot of code from [FunCodec](https://github.com/modelscope/FunCodec).
|
226 |
+
3. We borrowed a lot of code from [Matcha-TTS](https://github.com/shivammehta25/Matcha-TTS).
|
227 |
+
4. We borrowed a lot of code from [AcademiCodec](https://github.com/yangdongchao/AcademiCodec).
|
228 |
+
5. We borrowed a lot of code from [WeNet](https://github.com/wenet-e2e/wenet).
|
229 |
+
|
230 |
+
## Disclaimer
|
231 |
+
The content provided above is for academic purposes only and is intended to demonstrate technical capabilities. Some examples are sourced from the internet. If any content infringes on your rights, please contact us to request its removal.
|
app.py
ADDED
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Liu Yue)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import os
|
15 |
+
import sys
|
16 |
+
import argparse
|
17 |
+
import gradio as gr
|
18 |
+
import numpy as np
|
19 |
+
import torch
|
20 |
+
import torchaudio
|
21 |
+
import random
|
22 |
+
import librosa
|
23 |
+
from funasr import AutoModel
|
24 |
+
from funasr.utils.postprocess_utils import rich_transcription_postprocess
|
25 |
+
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
26 |
+
sys.path.append('{}/third_party/Matcha-TTS'.format(ROOT_DIR))
|
27 |
+
|
28 |
+
from modelscope import snapshot_download
|
29 |
+
snapshot_download('iic/CosyVoice2-0.5B', local_dir='pretrained_models/CosyVoice2-0.5B')
|
30 |
+
snapshot_download('iic/CosyVoice-ttsfrd', local_dir='pretrained_models/CosyVoice-ttsfrd')
|
31 |
+
os.system('cd pretrained_models/CosyVoice-ttsfrd/ && pip install ttsfrd_dependency-0.1-py3-none-any.whl && pip install ttsfrd-0.4.2-cp310-cp310-linux_x86_64.whl && apt install -y unzip && unzip resource.zip -d .')
|
32 |
+
|
33 |
+
from cosyvoice.cli.cosyvoice import CosyVoice2
|
34 |
+
from cosyvoice.utils.file_utils import load_wav, logging
|
35 |
+
from cosyvoice.utils.common import set_all_random_seed
|
36 |
+
|
37 |
+
inference_mode_list = ['3s极速复刻', '自然语言控制']
|
38 |
+
instruct_dict = {'3s极速复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 输入prompt文本\n3. 点击生成音频按钮',
|
39 |
+
'自然语言控制': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 输入instruct文本\n3. 点击生成音频按钮'}
|
40 |
+
stream_mode_list = [('否', False), ('是', True)]
|
41 |
+
max_val = 0.8
|
42 |
+
|
43 |
+
|
44 |
+
def generate_seed():
|
45 |
+
seed = random.randint(1, 100000000)
|
46 |
+
return {
|
47 |
+
"__type__": "update",
|
48 |
+
"value": seed
|
49 |
+
}
|
50 |
+
|
51 |
+
|
52 |
+
def postprocess(speech, top_db=60, hop_length=220, win_length=440):
|
53 |
+
speech, _ = librosa.effects.trim(
|
54 |
+
speech, top_db=top_db,
|
55 |
+
frame_length=win_length,
|
56 |
+
hop_length=hop_length
|
57 |
+
)
|
58 |
+
if speech.abs().max() > max_val:
|
59 |
+
speech = speech / speech.abs().max() * max_val
|
60 |
+
speech = torch.concat([speech, torch.zeros(1, int(target_sr * 0.2))], dim=1)
|
61 |
+
return speech
|
62 |
+
|
63 |
+
|
64 |
+
def change_instruction(mode_checkbox_group):
|
65 |
+
return instruct_dict[mode_checkbox_group]
|
66 |
+
|
67 |
+
def prompt_wav_recognition(prompt_wav):
|
68 |
+
res = asr_model.generate(input=prompt_wav,
|
69 |
+
language="auto", # "zn", "en", "yue", "ja", "ko", "nospeech"
|
70 |
+
use_itn=True,
|
71 |
+
)
|
72 |
+
text = res[0]["text"].split('|>')[-1]
|
73 |
+
return text
|
74 |
+
|
75 |
+
def generate_audio(tts_text, mode_checkbox_group, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text,
|
76 |
+
seed, stream):
|
77 |
+
sft_dropdown, speed = '', 1.0
|
78 |
+
if prompt_wav_upload is not None:
|
79 |
+
prompt_wav = prompt_wav_upload
|
80 |
+
elif prompt_wav_record is not None:
|
81 |
+
prompt_wav = prompt_wav_record
|
82 |
+
else:
|
83 |
+
prompt_wav = None
|
84 |
+
# if instruct mode, please make sure that model is iic/CosyVoice-300M-Instruct and not cross_lingual mode
|
85 |
+
if mode_checkbox_group in ['自然语言控制']:
|
86 |
+
if instruct_text == '':
|
87 |
+
gr.Warning('您正在使用自然语言控制模式, 请输入instruct文本')
|
88 |
+
yield (target_sr, default_data)
|
89 |
+
if prompt_wav is None:
|
90 |
+
gr.Info('您正在使用自然语言控制模式, 请输入prompt音频')
|
91 |
+
# if cross_lingual mode, please make sure that model is iic/CosyVoice-300M and tts_text prompt_text are different language
|
92 |
+
if mode_checkbox_group in ['跨语种复刻']:
|
93 |
+
if cosyvoice.frontend.instruct is True:
|
94 |
+
gr.Warning('您正在使用跨语种复刻模式, {}模型不支持此模式, 请使用iic/CosyVoice-300M模型'.format(args.model_dir))
|
95 |
+
yield (target_sr, default_data)
|
96 |
+
if instruct_text != '':
|
97 |
+
gr.Info('您正在使用跨语种复刻模式, instruct文本会被忽略')
|
98 |
+
if prompt_wav is None:
|
99 |
+
gr.Warning('您正在使用跨语种复刻模式, 请提供prompt音频')
|
100 |
+
yield (target_sr, default_data)
|
101 |
+
gr.Info('您正在使用跨语种复刻模式, 请确保合成文本和prompt文本为不同语言')
|
102 |
+
# if in zero_shot cross_lingual, please make sure that prompt_text and prompt_wav meets requirements
|
103 |
+
if mode_checkbox_group in ['3s极速复刻', '跨语种复刻']:
|
104 |
+
if prompt_wav is None:
|
105 |
+
gr.Warning('prompt音频为空,您是否忘记输入prompt音频?')
|
106 |
+
yield (target_sr, default_data)
|
107 |
+
if torchaudio.info(prompt_wav).sample_rate < prompt_sr:
|
108 |
+
gr.Warning('prompt音频采样率{}低于{}'.format(torchaudio.info(prompt_wav).sample_rate, prompt_sr))
|
109 |
+
yield (target_sr, default_data)
|
110 |
+
# sft mode only use sft_dropdown
|
111 |
+
if mode_checkbox_group in ['预训练音色']:
|
112 |
+
if instruct_text != '' or prompt_wav is not None or prompt_text != '':
|
113 |
+
gr.Info('您正在使用预训练音色模式,prompt文本/prompt音频/instruct文本会被忽略!')
|
114 |
+
# zero_shot mode only use prompt_wav prompt text
|
115 |
+
if mode_checkbox_group in ['3s极速复刻']:
|
116 |
+
if prompt_text == '':
|
117 |
+
gr.Warning('prompt文本为空,您是否忘记输入prompt文本?')
|
118 |
+
yield (target_sr, default_data)
|
119 |
+
if instruct_text != '':
|
120 |
+
gr.Info('您正在使用3s极速复刻模式,预训练音色/instruct文本会被忽略!')
|
121 |
+
info = torchaudio.info(prompt_wav)
|
122 |
+
if info.num_frames / info.sample_rate > 10:
|
123 |
+
gr.Warning('请限制输入音频在10s内,避免推理效果过低')
|
124 |
+
yield (target_sr, default_data)
|
125 |
+
|
126 |
+
if mode_checkbox_group == '预训练音色':
|
127 |
+
logging.info('get sft inference request')
|
128 |
+
set_all_random_seed(seed)
|
129 |
+
for i in cosyvoice.inference_sft(tts_text, sft_dropdown, stream=stream, speed=speed):
|
130 |
+
yield (target_sr, i['tts_speech'].numpy().flatten())
|
131 |
+
elif mode_checkbox_group == '3s极速复刻':
|
132 |
+
logging.info('get zero_shot inference request')
|
133 |
+
prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
|
134 |
+
set_all_random_seed(seed)
|
135 |
+
for i in cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k, stream=stream, speed=speed):
|
136 |
+
yield (target_sr, i['tts_speech'].numpy().flatten())
|
137 |
+
elif mode_checkbox_group == '跨语种复刻':
|
138 |
+
logging.info('get cross_lingual inference request')
|
139 |
+
prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
|
140 |
+
set_all_random_seed(seed)
|
141 |
+
for i in cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k, stream=stream, speed=speed):
|
142 |
+
yield (target_sr, i['tts_speech'].numpy().flatten())
|
143 |
+
else:
|
144 |
+
logging.info('get instruct inference request')
|
145 |
+
logging.info('get instruct inference request')
|
146 |
+
prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
|
147 |
+
set_all_random_seed(seed)
|
148 |
+
for i in cosyvoice.inference_instruct2(tts_text, instruct_text, prompt_speech_16k, stream=stream, speed=speed):
|
149 |
+
yield (target_sr, i['tts_speech'].numpy().flatten())
|
150 |
+
|
151 |
+
|
152 |
+
def main():
|
153 |
+
with gr.Blocks() as demo:
|
154 |
+
gr.Markdown("### 代码库 [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) \
|
155 |
+
预训练模型 [CosyVoice2-0.5B](https://www.modelscope.cn/models/iic/CosyVoice2-0.5B) \
|
156 |
+
[CosyVoice-300M](https://www.modelscope.cn/models/iic/CosyVoice-300M) \
|
157 |
+
[CosyVoice-300M-Instruct](https://www.modelscope.cn/models/iic/CosyVoice-300M-Instruct) \
|
158 |
+
[CosyVoice-300M-SFT](https://www.modelscope.cn/models/iic/CosyVoice-300M-SFT)")
|
159 |
+
gr.Markdown("#### 请输入需要合成的文本,选择推理模式,并按照提示步骤进行操作")
|
160 |
+
|
161 |
+
tts_text = gr.Textbox(label="输入合成文本", lines=1, value="CosyVoice迎来全面升级,提供更准、更稳、更快、 更好的语音生成能力。CosyVoice is undergoing a comprehensive upgrade, providing more accurate, stable, faster, and better voice generation capabilities.")
|
162 |
+
with gr.Row():
|
163 |
+
mode_checkbox_group = gr.Radio(choices=inference_mode_list, label='选择推理模式', value=inference_mode_list[0])
|
164 |
+
instruction_text = gr.Text(label="操作步骤", value=instruct_dict[inference_mode_list[0]], scale=0.5)
|
165 |
+
stream = gr.Radio(choices=stream_mode_list, label='是否流式推理', value=stream_mode_list[0][1])
|
166 |
+
with gr.Column(scale=0.25):
|
167 |
+
seed_button = gr.Button(value="\U0001F3B2")
|
168 |
+
seed = gr.Number(value=0, label="随机推理种子")
|
169 |
+
|
170 |
+
with gr.Row():
|
171 |
+
prompt_wav_upload = gr.Audio(sources='upload', type='filepath', label='选择prompt音频文件,注意采样率不低于16khz')
|
172 |
+
prompt_wav_record = gr.Audio(sources='microphone', type='filepath', label='录制prompt音频文件')
|
173 |
+
prompt_text = gr.Textbox(label="prompt文本", lines=1, placeholder="请输入prompt文本,支持自动识别,您可以自行修正识别结果...", value='')
|
174 |
+
instruct_text = gr.Textbox(label="输入instruct文本", lines=1, placeholder="请输入instruct文本.例如:用四川话说这句话。", value='')
|
175 |
+
|
176 |
+
generate_button = gr.Button("生成音频")
|
177 |
+
|
178 |
+
audio_output = gr.Audio(label="合成音频", autoplay=True, streaming=True)
|
179 |
+
|
180 |
+
seed_button.click(generate_seed, inputs=[], outputs=seed)
|
181 |
+
generate_button.click(generate_audio,
|
182 |
+
inputs=[tts_text, mode_checkbox_group, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text,
|
183 |
+
seed, stream],
|
184 |
+
outputs=[audio_output])
|
185 |
+
mode_checkbox_group.change(fn=change_instruction, inputs=[mode_checkbox_group], outputs=[instruction_text])
|
186 |
+
prompt_wav_upload.change(fn=prompt_wav_recognition, inputs=[prompt_wav_upload], outputs=[prompt_text])
|
187 |
+
prompt_wav_record.change(fn=prompt_wav_recognition, inputs=[prompt_wav_record], outputs=[prompt_text])
|
188 |
+
demo.queue(max_size=4, default_concurrency_limit=2).launch(server_port=50000)
|
189 |
+
|
190 |
+
|
191 |
+
if __name__ == '__main__':
|
192 |
+
load_jit = True if os.environ.get('jit') == '1' else False
|
193 |
+
load_onnx = True if os.environ.get('onnx') == '1' else False
|
194 |
+
load_trt = True if os.environ.get('trt') == '1' else False
|
195 |
+
logging.info('cosyvoice args load_jit {} load_onnx {} load_trt {}'.format(load_jit, load_onnx, load_trt))
|
196 |
+
cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=load_jit, load_onnx=load_onnx, load_trt=load_trt)
|
197 |
+
sft_spk = cosyvoice.list_avaliable_spks()
|
198 |
+
prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)
|
199 |
+
for stream in [True, False]:
|
200 |
+
for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=stream)):
|
201 |
+
continue
|
202 |
+
prompt_sr, target_sr = 16000, 24000
|
203 |
+
default_data = np.zeros(target_sr)
|
204 |
+
|
205 |
+
model_dir = "iic/SenseVoiceSmall"
|
206 |
+
asr_model = AutoModel(
|
207 |
+
model=model_dir,
|
208 |
+
disable_update=True,
|
209 |
+
log_level='DEBUG',
|
210 |
+
device="cuda:0")
|
211 |
+
main()
|
cert.pem
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
-----BEGIN CERTIFICATE-----
|
2 |
+
MIIFkTCCA3mgAwIBAgIUEO2zq0OQeuRFIFH4lfHLgcR5hTUwDQYJKoZIhvcNAQEL
|
3 |
+
BQAwWDELMAkGA1UEBhMCQ04xCzAJBgNVBAgMAlpKMQswCQYDVQQHDAJIWjEQMA4G
|
4 |
+
A1UECgwHY29tcGFueTEMMAoGA1UECwwDYWxpMQ8wDQYDVQQDDAZncmFkaW8wHhcN
|
5 |
+
MjQwNDI5MTIxNDQxWhcNMjUwNDI5MTIxNDQxWjBYMQswCQYDVQQGEwJDTjELMAkG
|
6 |
+
A1UECAwCWkoxCzAJBgNVBAcMAkhaMRAwDgYDVQQKDAdjb21wYW55MQwwCgYDVQQL
|
7 |
+
DANhbGkxDzANBgNVBAMMBmdyYWRpbzCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCC
|
8 |
+
AgoCggIBAKohzP3V7VdDyMgfRO4+xzh/mWFPapQWJIIrhnHj8GRJ9tgFVXf71vcU
|
9 |
+
PMo+/t+y0rjupw3WwWIj6kJP15t46xxmLzoJHZKHV7d1Y7XJTyN1hvRCzeGz6w/E
|
10 |
+
VX0y6U+0y9m1HG0kvsfLwCKZPxEN21RfPukGN3qOIpjaRvE6fxg8DCUQN8qEpjQ9
|
11 |
+
DQehq/g0B/wZFwIB2089+BeqesjaOinY2+z4YiMreIj2dy8XM6G59quS21oe0u5n
|
12 |
+
6SW80ayf/yA6CHqblCHNfdi3vrzxMalNjT5EHKxQsLEDd2nWSndoPeXClXdSoIpE
|
13 |
+
1+H86dWHZpzPLd6rOfa+FCZ3TQsZbL+p3ree2AIMIB7zWw59oKGE8UuZbtyCVWK6
|
14 |
+
hufMOs703ZT97WeBEoOA72itUwCBqAakYNoULvYSOuXZT0LvJN1Z4YLNTkJXDA0u
|
15 |
+
vMABPbRFXfFK67F/fLm/vges4dhhpQNeSxSuXEC7rMA5hCQRk3BccdEgxoBfNZcM
|
16 |
+
HKo8CaB3wxbK7inXZb3JD4sFK64H5VjfJE8ibFzoIhiPICuC+0bzSKfc0+dcUNMb
|
17 |
+
KsE5M3etmS1TcPKuebk9OTu8YUJiNMYgEInw7vCq004v4IOqQr0aX/LGRm21RB/i
|
18 |
+
M3qFKCSHSw5/Z+o9sZ/kw3AeNnx5r5dq4OAswx3RhScPJtd6qesZAgMBAAGjUzBR
|
19 |
+
MB0GA1UdDgQWBBSNZx2v1BNAGL4gGM4TUXIvn1OyFTAfBgNVHSMEGDAWgBSNZx2v
|
20 |
+
1BNAGL4gGM4TUXIvn1OyFTAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUA
|
21 |
+
A4ICAQA/khg91VtI/tDLCLyQ6ZMulfOzJHuGmIs4cvG5fIOvzYjQpvAGSgNeivKp
|
22 |
+
+5RIkpUabcwdUCq6VXeXheo+SaGgVdwpxQy/p/E+i+AengRB5Qm/hJ5lLU6CdNBq
|
23 |
+
WCN/0Aa1GL/pM4HAzVQY81HeB46UaHWtW6J9hnBbVg2MF2GanAqfeODpZqIHEggt
|
24 |
+
Vw2ivElV47JTFZsNU+JYG5ECsfTjNQYpoA6Hyb/d5ZW8YsfOjr8oIBM4QyZWq1Ke
|
25 |
+
eAlytVwl9lj4AkAQIAgkrJHkLjj5yjZ7Hir5NjBuBx06FDAIFb2XWgNnq4ua/pSq
|
26 |
+
9fL4cxx4cEJku1X/FYtUBbWsXe8uFGwTEGHuEZR3pj5VSFbuNlARLIsq8/gh8MRQ
|
27 |
+
NjKQIlTVINkuOFuVmSrLC5nIwTPhlpEFwIQPGzFD2DbVNor9EXQ2b89WtHqZAZik
|
28 |
+
qFDb76JM9jctf9n8l96oSKrwEaCoFmRojnyyYl9UByJxPRCeTJ//i2vxeTvLC3FT
|
29 |
+
Rw2jFi/pwoqSVmJtuAFLT96/x2qKpgk+M1zG3oFiDV1lxY8sw1RA3Mm4s3Cm8H5A
|
30 |
+
3E+6R34XZLifqhxLVcyDsRWPcqte3Pt6v/xXWN+EuOigK4tr69p8aU7WR5mskmzO
|
31 |
+
tZFeEb0OxL1WjF/rmwCkd/SvSuWSiszMoX5hcOA7/GGw3pl3YQ==
|
32 |
+
-----END CERTIFICATE-----
|
cosyvoice/__init__.py
ADDED
File without changes
|
cosyvoice/bin/average_model.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 Mobvoi Inc (Di Wu)
|
2 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import os
|
17 |
+
import argparse
|
18 |
+
import glob
|
19 |
+
|
20 |
+
import yaml
|
21 |
+
import torch
|
22 |
+
|
23 |
+
|
24 |
+
def get_args():
|
25 |
+
parser = argparse.ArgumentParser(description='average model')
|
26 |
+
parser.add_argument('--dst_model', required=True, help='averaged model')
|
27 |
+
parser.add_argument('--src_path',
|
28 |
+
required=True,
|
29 |
+
help='src model path for average')
|
30 |
+
parser.add_argument('--val_best',
|
31 |
+
action="store_true",
|
32 |
+
help='averaged model')
|
33 |
+
parser.add_argument('--num',
|
34 |
+
default=5,
|
35 |
+
type=int,
|
36 |
+
help='nums for averaged model')
|
37 |
+
|
38 |
+
args = parser.parse_args()
|
39 |
+
print(args)
|
40 |
+
return args
|
41 |
+
|
42 |
+
|
43 |
+
def main():
|
44 |
+
args = get_args()
|
45 |
+
val_scores = []
|
46 |
+
if args.val_best:
|
47 |
+
yamls = glob.glob('{}/*.yaml'.format(args.src_path))
|
48 |
+
yamls = [
|
49 |
+
f for f in yamls
|
50 |
+
if not (os.path.basename(f).startswith('train')
|
51 |
+
or os.path.basename(f).startswith('init'))
|
52 |
+
]
|
53 |
+
for y in yamls:
|
54 |
+
with open(y, 'r') as f:
|
55 |
+
dic_yaml = yaml.load(f, Loader=yaml.BaseLoader)
|
56 |
+
loss = float(dic_yaml['loss_dict']['loss'])
|
57 |
+
epoch = int(dic_yaml['epoch'])
|
58 |
+
step = int(dic_yaml['step'])
|
59 |
+
tag = dic_yaml['tag']
|
60 |
+
val_scores += [[epoch, step, loss, tag]]
|
61 |
+
sorted_val_scores = sorted(val_scores,
|
62 |
+
key=lambda x: x[2],
|
63 |
+
reverse=False)
|
64 |
+
print("best val (epoch, step, loss, tag) = " +
|
65 |
+
str(sorted_val_scores[:args.num]))
|
66 |
+
path_list = [
|
67 |
+
args.src_path + '/epoch_{}_whole.pt'.format(score[0])
|
68 |
+
for score in sorted_val_scores[:args.num]
|
69 |
+
]
|
70 |
+
print(path_list)
|
71 |
+
avg = {}
|
72 |
+
num = args.num
|
73 |
+
assert num == len(path_list)
|
74 |
+
for path in path_list:
|
75 |
+
print('Processing {}'.format(path))
|
76 |
+
states = torch.load(path, map_location=torch.device('cpu'))
|
77 |
+
for k in states.keys():
|
78 |
+
if k not in avg.keys():
|
79 |
+
avg[k] = states[k].clone()
|
80 |
+
else:
|
81 |
+
avg[k] += states[k]
|
82 |
+
# average
|
83 |
+
for k in avg.keys():
|
84 |
+
if avg[k] is not None:
|
85 |
+
# pytorch 1.6 use true_divide instead of /=
|
86 |
+
avg[k] = torch.true_divide(avg[k], num)
|
87 |
+
print('Saving to {}'.format(args.dst_model))
|
88 |
+
torch.save(avg, args.dst_model)
|
89 |
+
|
90 |
+
|
91 |
+
if __name__ == '__main__':
|
92 |
+
main()
|
cosyvoice/bin/convert.py
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import torch
|
3 |
+
|
4 |
+
def convert_llm(state_dict):
|
5 |
+
# 调整了lm的结构,把codec_lm.encoder作为llm,codec_lm.decoder作为decoder
|
6 |
+
keys = list(state_dict.keys())
|
7 |
+
for k in keys:
|
8 |
+
if k.startswith('codec_lm.encoder.'):
|
9 |
+
v = state_dict.pop(k)
|
10 |
+
k = k.replace('codec_lm.encoder.', 'llm.')
|
11 |
+
state_dict[k] = v
|
12 |
+
if k.startswith('codec_lm.decoder.'):
|
13 |
+
v = state_dict.pop(k)
|
14 |
+
k = k.replace('codec_lm.decoder.', 'llm_decoder.')
|
15 |
+
state_dict[k] = v
|
16 |
+
# espnet和wenet具体实现上的差异
|
17 |
+
keys = list(state_dict.keys())
|
18 |
+
for k in keys:
|
19 |
+
if k.startswith('text_encoder.embed.'):
|
20 |
+
v = state_dict.pop(k)
|
21 |
+
k = k.replace('text_encoder.embed.', 'text_encoder.embed.out.')
|
22 |
+
state_dict[k] = v
|
23 |
+
if k.startswith('llm.embed.'):
|
24 |
+
v = state_dict.pop(k)
|
25 |
+
k = k.replace('llm.embed.', 'llm.embed.out.')
|
26 |
+
state_dict[k] = v
|
27 |
+
keys = list(state_dict.keys())
|
28 |
+
for k in keys:
|
29 |
+
if k.startswith('text_enc_out_layer.'):
|
30 |
+
v = state_dict.pop(k)
|
31 |
+
k = k.replace('text_enc_out_layer.', 'text_encoder_affine_layer.')
|
32 |
+
state_dict[k] = v
|
33 |
+
if k.startswith('token_embedding.'):
|
34 |
+
v = state_dict.pop(k)
|
35 |
+
k = k.replace('token_embedding.', 'text_embedding.')
|
36 |
+
state_dict[k] = v
|
37 |
+
if k.startswith('xvec_proj.'):
|
38 |
+
v = state_dict.pop(k)
|
39 |
+
k = k.replace('xvec_proj.', 'spk_embed_affine_layer.')
|
40 |
+
state_dict[k] = v
|
41 |
+
if k.startswith('lm_embedding.'):
|
42 |
+
v = state_dict.pop(k)
|
43 |
+
k = k.replace('lm_embedding.', 'llm_embedding.')
|
44 |
+
state_dict[k] = v
|
45 |
+
if k.startswith('codec_embedder.'):
|
46 |
+
v = state_dict.pop(k)
|
47 |
+
k = k.replace('codec_embedder.', 'speech_embedding.')
|
48 |
+
state_dict[k] = v
|
49 |
+
# instruct少了spk embedding参数,加个全0上去
|
50 |
+
keys = list(state_dict.keys())
|
51 |
+
if 'spk_embed_affine_layer.weight' not in keys:
|
52 |
+
print('no spk_embed_affine_layer.weight, should be instruct model')
|
53 |
+
state_dict['spk_embed_affine_layer.weight'] = torch.zeros(1024, 192)
|
54 |
+
if 'spk_embed_affine_layer.bias' not in keys:
|
55 |
+
print('no spk_embed_affine_layer.bias, should be instruct model')
|
56 |
+
state_dict['spk_embed_affine_layer.bias'] = torch.zeros(1024)
|
57 |
+
return state_dict
|
58 |
+
|
59 |
+
def convert_hift(state_dict):
|
60 |
+
# 调整了cosyvoice中hifigan的结构,把f0_predictor放到generator里
|
61 |
+
keys = list(state_dict.keys())
|
62 |
+
for k in keys:
|
63 |
+
if k.startswith('decoder.'):
|
64 |
+
v = state_dict.pop(k)
|
65 |
+
k = k.replace('decoder.', '')
|
66 |
+
state_dict[k] = v
|
67 |
+
if k.startswith('generator.'):
|
68 |
+
v = state_dict.pop(k)
|
69 |
+
k = k.replace('generator.', '')
|
70 |
+
state_dict[k] = v
|
71 |
+
return state_dict
|
72 |
+
|
73 |
+
def convert_flow(state_dict):
|
74 |
+
keys = list(state_dict.keys())
|
75 |
+
for k in keys:
|
76 |
+
if k.startswith('encoder.embed.'):
|
77 |
+
v = state_dict.pop(k)
|
78 |
+
k = k.replace('encoder.embed.', 'encoder.embed.out.')
|
79 |
+
state_dict[k] = v
|
80 |
+
for k in keys:
|
81 |
+
if k.startswith('xvec_proj.'):
|
82 |
+
v = state_dict.pop(k)
|
83 |
+
k = k.replace('xvec_proj.', 'spk_embed_affine_layer.')
|
84 |
+
state_dict[k] = v
|
85 |
+
return state_dict
|
86 |
+
|
87 |
+
def convert_llm2(state_dict):
|
88 |
+
# 调整了lm的结构,把codec_lm.encoder作为llm,codec_lm.decoder作为decoder
|
89 |
+
keys = list(state_dict.keys())
|
90 |
+
for k in keys:
|
91 |
+
if k.startswith('codec_lm.encoder.'):
|
92 |
+
v = state_dict.pop(k)
|
93 |
+
k = k.replace('codec_lm.encoder.', 'llm.')
|
94 |
+
state_dict[k] = v
|
95 |
+
if k.startswith('codec_lm.decoder.'):
|
96 |
+
v = state_dict.pop(k)
|
97 |
+
k = k.replace('codec_lm.decoder.', 'llm_decoder.')
|
98 |
+
state_dict[k] = v
|
99 |
+
if k.startswith('lm_embedding.'):
|
100 |
+
v = state_dict.pop(k)
|
101 |
+
k = k.replace('lm_embedding.', 'llm_embedding.')
|
102 |
+
state_dict[k] = v
|
103 |
+
if k.startswith('codec_embedder.'):
|
104 |
+
v = state_dict.pop(k)
|
105 |
+
k = k.replace('codec_embedder.', 'speech_embedding.')
|
106 |
+
state_dict[k] = v
|
107 |
+
if k.startswith('text_enc_out_layer.'):
|
108 |
+
state_dict.pop(k)
|
109 |
+
if k.startswith('token_embedding.weight'):
|
110 |
+
state_dict.pop(k)
|
111 |
+
return state_dict
|
112 |
+
|
113 |
+
def convert_flow2(state_dict):
|
114 |
+
keys = list(state_dict.keys())
|
115 |
+
for k in keys:
|
116 |
+
if k.startswith('encoder.embed.'):
|
117 |
+
v = state_dict.pop(k)
|
118 |
+
k = k.replace('encoder.embed.', 'encoder.embed.out.')
|
119 |
+
state_dict[k] = v
|
120 |
+
for k in keys:
|
121 |
+
if k.startswith('xvec_proj.'):
|
122 |
+
v = state_dict.pop(k)
|
123 |
+
k = k.replace('xvec_proj.', 'spk_embed_affine_layer.')
|
124 |
+
state_dict[k] = v
|
125 |
+
for k in keys:
|
126 |
+
if k.startswith('mel_extractor.'):
|
127 |
+
state_dict.pop(k)
|
128 |
+
for k in keys:
|
129 |
+
if k.startswith('encoder.upsample_blocks.0.0.'):
|
130 |
+
v = state_dict.pop(k)
|
131 |
+
k = k.replace('encoder.upsample_blocks.0.0.', 'encoder.up_layer.')
|
132 |
+
state_dict[k] = v
|
133 |
+
if k.startswith('encoder.upsample_blocks.0.1.'):
|
134 |
+
v = state_dict.pop(k)
|
135 |
+
k = k.replace('encoder.upsample_blocks.0.1.', 'encoder.up_embed.out.')
|
136 |
+
state_dict[k] = v
|
137 |
+
if k.startswith('encoder.upsample_blocks.0.2.'):
|
138 |
+
v = state_dict.pop(k)
|
139 |
+
k = k.replace('encoder.upsample_blocks.0.2.', 'encoder.up_encoders.')
|
140 |
+
state_dict[k] = v
|
141 |
+
# CausalBlock1D中sequantial 1->2
|
142 |
+
if k.startswith('decoder.estimator.') and k.endswith('block.1.weight'):
|
143 |
+
v = state_dict.pop(k)
|
144 |
+
k = k.replace('block.1.weight', 'block.2.weight')
|
145 |
+
state_dict[k] = v
|
146 |
+
if k.startswith('decoder.estimator.') and k.endswith('block.1.bias'):
|
147 |
+
v = state_dict.pop(k)
|
148 |
+
k = k.replace('block.1.bias', 'block.2.bias')
|
149 |
+
state_dict[k] = v
|
150 |
+
return state_dict
|
151 |
+
|
152 |
+
if __name__ == '__main__':
|
153 |
+
# 使用方法 python3 convert.py 原格式llm.pt llm normalize 新格式llm.pt
|
154 |
+
# 或者 python3 convert.py 新格式llm.pt llm inverse_normalize 原格式llm.pt
|
155 |
+
state_dict = torch.load(sys.argv[1], map_location='cpu')
|
156 |
+
if sys.argv[2] == 'llm':
|
157 |
+
state_dict = convert_llm(state_dict)
|
158 |
+
elif sys.argv[2] == 'flow':
|
159 |
+
state_dict = convert_flow(state_dict)
|
160 |
+
elif sys.argv[2] == 'hift':
|
161 |
+
state_dict = convert_hift(state_dict)
|
162 |
+
elif sys.argv[2] == 'llm2':
|
163 |
+
state_dict = convert_llm2(state_dict)
|
164 |
+
elif sys.argv[2] == 'flow2':
|
165 |
+
state_dict = convert_flow2(state_dict)
|
166 |
+
else:
|
167 |
+
raise ValueError
|
168 |
+
torch.save(state_dict, sys.argv[4])
|
cosyvoice/bin/export_jit.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from __future__ import print_function
|
16 |
+
|
17 |
+
import argparse
|
18 |
+
import logging
|
19 |
+
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
20 |
+
import os
|
21 |
+
import sys
|
22 |
+
import torch
|
23 |
+
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
24 |
+
sys.path.append('{}/../..'.format(ROOT_DIR))
|
25 |
+
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
|
26 |
+
from cosyvoice.cli.cosyvoice import CosyVoice
|
27 |
+
|
28 |
+
|
29 |
+
def get_args():
|
30 |
+
parser = argparse.ArgumentParser(description='export your model for deployment')
|
31 |
+
parser.add_argument('--model_dir',
|
32 |
+
type=str,
|
33 |
+
default='pretrained_models/CosyVoice-300M',
|
34 |
+
help='local path')
|
35 |
+
args = parser.parse_args()
|
36 |
+
print(args)
|
37 |
+
return args
|
38 |
+
|
39 |
+
|
40 |
+
def main():
|
41 |
+
args = get_args()
|
42 |
+
logging.basicConfig(level=logging.DEBUG,
|
43 |
+
format='%(asctime)s %(levelname)s %(message)s')
|
44 |
+
|
45 |
+
torch._C._jit_set_fusion_strategy([('STATIC', 1)])
|
46 |
+
torch._C._jit_set_profiling_mode(False)
|
47 |
+
torch._C._jit_set_profiling_executor(False)
|
48 |
+
|
49 |
+
cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_onnx=False)
|
50 |
+
|
51 |
+
# 1. export llm text_encoder
|
52 |
+
llm_text_encoder = cosyvoice.model.llm.text_encoder.half()
|
53 |
+
script = torch.jit.script(llm_text_encoder)
|
54 |
+
script = torch.jit.freeze(script)
|
55 |
+
script = torch.jit.optimize_for_inference(script)
|
56 |
+
script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir))
|
57 |
+
|
58 |
+
# 2. export llm llm
|
59 |
+
llm_llm = cosyvoice.model.llm.llm.half()
|
60 |
+
script = torch.jit.script(llm_llm)
|
61 |
+
script = torch.jit.freeze(script, preserved_attrs=['forward_chunk'])
|
62 |
+
script = torch.jit.optimize_for_inference(script)
|
63 |
+
script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
|
64 |
+
|
65 |
+
# 3. export flow encoder
|
66 |
+
flow_encoder = cosyvoice.model.flow.encoder
|
67 |
+
script = torch.jit.script(flow_encoder)
|
68 |
+
script = torch.jit.freeze(script)
|
69 |
+
script = torch.jit.optimize_for_inference(script)
|
70 |
+
script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
|
71 |
+
|
72 |
+
|
73 |
+
if __name__ == '__main__':
|
74 |
+
main()
|
cosyvoice/bin/export_jit_cosyvoice2.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from __future__ import print_function
|
16 |
+
|
17 |
+
import argparse
|
18 |
+
import logging
|
19 |
+
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
20 |
+
import os
|
21 |
+
import sys
|
22 |
+
import torch
|
23 |
+
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
24 |
+
sys.path.append('{}/../..'.format(ROOT_DIR))
|
25 |
+
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
|
26 |
+
from cosyvoice.cli.cosyvoice import CosyVoice2
|
27 |
+
|
28 |
+
|
29 |
+
def get_args():
|
30 |
+
parser = argparse.ArgumentParser(description='export your model for deployment')
|
31 |
+
parser.add_argument('--model_dir',
|
32 |
+
type=str,
|
33 |
+
default='pretrained_models/CosyVoice-300M',
|
34 |
+
help='local path')
|
35 |
+
args = parser.parse_args()
|
36 |
+
print(args)
|
37 |
+
return args
|
38 |
+
|
39 |
+
|
40 |
+
def main():
|
41 |
+
args = get_args()
|
42 |
+
logging.basicConfig(level=logging.DEBUG,
|
43 |
+
format='%(asctime)s %(levelname)s %(message)s')
|
44 |
+
|
45 |
+
torch._C._jit_set_fusion_strategy([('STATIC', 1)])
|
46 |
+
torch._C._jit_set_profiling_mode(False)
|
47 |
+
torch._C._jit_set_profiling_executor(False)
|
48 |
+
|
49 |
+
cosyvoice = CosyVoice2(args.model_dir, load_jit=False, load_onnx=False)
|
50 |
+
|
51 |
+
# 3. export flow encoder
|
52 |
+
flow_encoder = cosyvoice.model.flow.encoder
|
53 |
+
script = torch.jit.script(flow_encoder)
|
54 |
+
script = torch.jit.freeze(script)
|
55 |
+
script = torch.jit.optimize_for_inference(script)
|
56 |
+
script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
|
57 |
+
|
58 |
+
|
59 |
+
if __name__ == '__main__':
|
60 |
+
main()
|
cosyvoice/bin/export_onnx.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, [email protected])
|
2 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from __future__ import print_function
|
17 |
+
|
18 |
+
import argparse
|
19 |
+
import logging
|
20 |
+
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
21 |
+
import os
|
22 |
+
import sys
|
23 |
+
import onnxruntime
|
24 |
+
import random
|
25 |
+
import torch
|
26 |
+
from tqdm import tqdm
|
27 |
+
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
28 |
+
sys.path.append('{}/../..'.format(ROOT_DIR))
|
29 |
+
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
|
30 |
+
from cosyvoice.cli.cosyvoice import CosyVoice
|
31 |
+
|
32 |
+
|
33 |
+
def get_dummy_input(batch_size, seq_len, out_channels, device):
|
34 |
+
x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
|
35 |
+
mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device)
|
36 |
+
mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
|
37 |
+
t = torch.rand((batch_size), dtype=torch.float32, device=device)
|
38 |
+
spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device)
|
39 |
+
cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
|
40 |
+
return x, mask, mu, t, spks, cond
|
41 |
+
|
42 |
+
|
43 |
+
def get_args():
|
44 |
+
parser = argparse.ArgumentParser(description='export your model for deployment')
|
45 |
+
parser.add_argument('--model_dir',
|
46 |
+
type=str,
|
47 |
+
default='pretrained_models/CosyVoice-300M',
|
48 |
+
help='local path')
|
49 |
+
args = parser.parse_args()
|
50 |
+
print(args)
|
51 |
+
return args
|
52 |
+
|
53 |
+
|
54 |
+
def main():
|
55 |
+
args = get_args()
|
56 |
+
logging.basicConfig(level=logging.DEBUG,
|
57 |
+
format='%(asctime)s %(levelname)s %(message)s')
|
58 |
+
|
59 |
+
cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_onnx=False)
|
60 |
+
|
61 |
+
# 1. export flow decoder estimator
|
62 |
+
estimator = cosyvoice.model.flow.decoder.estimator
|
63 |
+
|
64 |
+
device = cosyvoice.model.device
|
65 |
+
batch_size, seq_len = 1, 256
|
66 |
+
out_channels = cosyvoice.model.flow.decoder.estimator.out_channels
|
67 |
+
x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
|
68 |
+
torch.onnx.export(
|
69 |
+
estimator,
|
70 |
+
(x, mask, mu, t, spks, cond),
|
71 |
+
'{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
|
72 |
+
export_params=True,
|
73 |
+
opset_version=18,
|
74 |
+
do_constant_folding=True,
|
75 |
+
input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
|
76 |
+
output_names=['estimator_out'],
|
77 |
+
dynamic_axes={
|
78 |
+
'x': {0: 'batch_size', 2: 'seq_len'},
|
79 |
+
'mask': {0: 'batch_size', 2: 'seq_len'},
|
80 |
+
'mu': {0: 'batch_size', 2: 'seq_len'},
|
81 |
+
'cond': {0: 'batch_size', 2: 'seq_len'},
|
82 |
+
't': {0: 'batch_size'},
|
83 |
+
'spks': {0: 'batch_size'},
|
84 |
+
'estimator_out': {0: 'batch_size', 2: 'seq_len'},
|
85 |
+
}
|
86 |
+
)
|
87 |
+
|
88 |
+
# 2. test computation consistency
|
89 |
+
option = onnxruntime.SessionOptions()
|
90 |
+
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
91 |
+
option.intra_op_num_threads = 1
|
92 |
+
providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
|
93 |
+
estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
|
94 |
+
sess_options=option, providers=providers)
|
95 |
+
|
96 |
+
for _ in tqdm(range(10)):
|
97 |
+
x, mask, mu, t, spks, cond = get_dummy_input(random.randint(1, 6), random.randint(16, 512), out_channels, device)
|
98 |
+
output_pytorch = estimator(x, mask, mu, t, spks, cond)
|
99 |
+
ort_inputs = {
|
100 |
+
'x': x.cpu().numpy(),
|
101 |
+
'mask': mask.cpu().numpy(),
|
102 |
+
'mu': mu.cpu().numpy(),
|
103 |
+
't': t.cpu().numpy(),
|
104 |
+
'spks': spks.cpu().numpy(),
|
105 |
+
'cond': cond.cpu().numpy()
|
106 |
+
}
|
107 |
+
output_onnx = estimator_onnx.run(None, ort_inputs)[0]
|
108 |
+
torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
|
109 |
+
|
110 |
+
|
111 |
+
if __name__ == "__main__":
|
112 |
+
main()
|
cosyvoice/bin/export_onnx_cosyvoice2.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, [email protected])
|
2 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from __future__ import print_function
|
17 |
+
|
18 |
+
import argparse
|
19 |
+
import logging
|
20 |
+
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
21 |
+
import os
|
22 |
+
import sys
|
23 |
+
import onnxruntime
|
24 |
+
import random
|
25 |
+
import torch
|
26 |
+
from tqdm import tqdm
|
27 |
+
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
28 |
+
sys.path.append('{}/../..'.format(ROOT_DIR))
|
29 |
+
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
|
30 |
+
from cosyvoice.cli.cosyvoice import CosyVoice2
|
31 |
+
|
32 |
+
|
33 |
+
def get_dummy_input(batch_size, seq_len, out_channels, device):
|
34 |
+
x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
|
35 |
+
mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device)
|
36 |
+
mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
|
37 |
+
t = torch.rand((batch_size), dtype=torch.float32, device=device)
|
38 |
+
spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device)
|
39 |
+
cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
|
40 |
+
return x, mask, mu, t, spks, cond
|
41 |
+
|
42 |
+
|
43 |
+
def get_args():
|
44 |
+
parser = argparse.ArgumentParser(description='export your model for deployment')
|
45 |
+
parser.add_argument('--model_dir',
|
46 |
+
type=str,
|
47 |
+
default='pretrained_models/CosyVoice-300M',
|
48 |
+
help='local path')
|
49 |
+
args = parser.parse_args()
|
50 |
+
print(args)
|
51 |
+
return args
|
52 |
+
|
53 |
+
|
54 |
+
def main():
|
55 |
+
args = get_args()
|
56 |
+
logging.basicConfig(level=logging.DEBUG,
|
57 |
+
format='%(asctime)s %(levelname)s %(message)s')
|
58 |
+
|
59 |
+
cosyvoice = CosyVoice2(args.model_dir, load_jit=False, load_onnx=False)
|
60 |
+
|
61 |
+
# 1. export flow decoder estimator
|
62 |
+
estimator = cosyvoice.model.flow.decoder.estimator
|
63 |
+
|
64 |
+
device = cosyvoice.model.device
|
65 |
+
batch_size, seq_len = 2, 320
|
66 |
+
out_channels = cosyvoice.model.flow.decoder.estimator.out_channels
|
67 |
+
x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
|
68 |
+
torch.onnx.export(
|
69 |
+
estimator,
|
70 |
+
(x, mask, mu, t, spks, cond),
|
71 |
+
'{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
|
72 |
+
export_params=True,
|
73 |
+
opset_version=18,
|
74 |
+
do_constant_folding=True,
|
75 |
+
input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
|
76 |
+
output_names=['estimator_out'],
|
77 |
+
dynamic_axes={
|
78 |
+
'x': {2: 'seq_len'},
|
79 |
+
'mask': {2: 'seq_len'},
|
80 |
+
'mu': {2: 'seq_len'},
|
81 |
+
'cond': {2: 'seq_len'},
|
82 |
+
'estimator_out': {2: 'seq_len'},
|
83 |
+
}
|
84 |
+
)
|
85 |
+
|
86 |
+
# 2. test computation consistency
|
87 |
+
option = onnxruntime.SessionOptions()
|
88 |
+
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
89 |
+
option.intra_op_num_threads = 1
|
90 |
+
providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
|
91 |
+
estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
|
92 |
+
sess_options=option, providers=providers)
|
93 |
+
|
94 |
+
for _ in tqdm(range(10)):
|
95 |
+
x, mask, mu, t, spks, cond = get_dummy_input(random.randint(1, 6), random.randint(16, 512), out_channels, device)
|
96 |
+
output_pytorch = estimator(x, mask, mu, t, spks, cond)
|
97 |
+
ort_inputs = {
|
98 |
+
'x': x.cpu().numpy(),
|
99 |
+
'mask': mask.cpu().numpy(),
|
100 |
+
'mu': mu.cpu().numpy(),
|
101 |
+
't': t.cpu().numpy(),
|
102 |
+
'spks': spks.cpu().numpy(),
|
103 |
+
'cond': cond.cpu().numpy()
|
104 |
+
}
|
105 |
+
output_onnx = estimator_onnx.run(None, ort_inputs)[0]
|
106 |
+
torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
|
107 |
+
|
108 |
+
|
109 |
+
if __name__ == "__main__":
|
110 |
+
main()
|
cosyvoice/bin/export_trt_cosyvoce2.sh
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/mnt/lyuxiang.lx/data/TensorRT-10.0.1.6-cu124/TensorRT-10.0.1.6/lib:/usr/local/cuda-12.4/lib64
|
3 |
+
/mnt/lyuxiang.lx/data/TensorRT-10.0.1.6-cu124/TensorRT-10.0.1.6/bin/trtexec --onnx=/mnt/lyuxiang.lx/CosyVoice_github/pretrained_models/CosyVoice2-0.5B/flow.decoder.estimator.fp32.onnx --saveEngine=/mnt/lyuxiang.lx/CosyVoice_github/pretrained_models/CosyVoice2-0.5B/flow.decoder.estimator.fp16.Volta.plan --fp16 --minShapes=x:2x80x4,mask:2x1x4,mu:2x80x4,cond:2x80x4 --optShapes=x:2x80x193,mask:2x1x193,mu:2x80x193,cond:2x80x193 --maxShapes=x:2x80x6800,mask:2x1x6800,mu:2x80x6800,cond:2x80x6800 --inputIOFormats=fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw --outputIOFormats=fp16:chw
|
cosyvoice/bin/inference.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from __future__ import print_function
|
16 |
+
|
17 |
+
import argparse
|
18 |
+
import logging
|
19 |
+
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
20 |
+
import os
|
21 |
+
import torch
|
22 |
+
from torch.utils.data import DataLoader
|
23 |
+
import torchaudio
|
24 |
+
from hyperpyyaml import load_hyperpyyaml
|
25 |
+
from tqdm import tqdm
|
26 |
+
from cosyvoice.cli.model import CosyVoiceModel
|
27 |
+
from cosyvoice.dataset.dataset import Dataset
|
28 |
+
|
29 |
+
|
30 |
+
def get_args():
|
31 |
+
parser = argparse.ArgumentParser(description='inference with your model')
|
32 |
+
parser.add_argument('--config', required=True, help='config file')
|
33 |
+
parser.add_argument('--prompt_data', required=True, help='prompt data file')
|
34 |
+
parser.add_argument('--prompt_utt2data', required=True, help='prompt data file')
|
35 |
+
parser.add_argument('--tts_text', required=True, help='tts input file')
|
36 |
+
parser.add_argument('--llm_model', required=True, help='llm model file')
|
37 |
+
parser.add_argument('--flow_model', required=True, help='flow model file')
|
38 |
+
parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
|
39 |
+
parser.add_argument('--gpu',
|
40 |
+
type=int,
|
41 |
+
default=-1,
|
42 |
+
help='gpu id for this rank, -1 for cpu')
|
43 |
+
parser.add_argument('--mode',
|
44 |
+
default='sft',
|
45 |
+
choices=['sft', 'zero_shot'],
|
46 |
+
help='inference mode')
|
47 |
+
parser.add_argument('--result_dir', required=True, help='asr result file')
|
48 |
+
args = parser.parse_args()
|
49 |
+
print(args)
|
50 |
+
return args
|
51 |
+
|
52 |
+
|
53 |
+
def main():
|
54 |
+
args = get_args()
|
55 |
+
logging.basicConfig(level=logging.DEBUG,
|
56 |
+
format='%(asctime)s %(levelname)s %(message)s')
|
57 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
|
58 |
+
|
59 |
+
# Init cosyvoice models from configs
|
60 |
+
use_cuda = args.gpu >= 0 and torch.cuda.is_available()
|
61 |
+
device = torch.device('cuda' if use_cuda else 'cpu')
|
62 |
+
with open(args.config, 'r') as f:
|
63 |
+
configs = load_hyperpyyaml(f)
|
64 |
+
|
65 |
+
model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
|
66 |
+
model.load(args.llm_model, args.flow_model, args.hifigan_model)
|
67 |
+
|
68 |
+
test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
|
69 |
+
tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
|
70 |
+
test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
|
71 |
+
|
72 |
+
del configs
|
73 |
+
os.makedirs(args.result_dir, exist_ok=True)
|
74 |
+
fn = os.path.join(args.result_dir, 'wav.scp')
|
75 |
+
f = open(fn, 'w')
|
76 |
+
with torch.no_grad():
|
77 |
+
for _, batch in tqdm(enumerate(test_data_loader)):
|
78 |
+
utts = batch["utts"]
|
79 |
+
assert len(utts) == 1, "inference mode only support batchsize 1"
|
80 |
+
text_token = batch["text_token"].to(device)
|
81 |
+
text_token_len = batch["text_token_len"].to(device)
|
82 |
+
tts_index = batch["tts_index"]
|
83 |
+
tts_text_token = batch["tts_text_token"].to(device)
|
84 |
+
tts_text_token_len = batch["tts_text_token_len"].to(device)
|
85 |
+
speech_token = batch["speech_token"].to(device)
|
86 |
+
speech_token_len = batch["speech_token_len"].to(device)
|
87 |
+
speech_feat = batch["speech_feat"].to(device)
|
88 |
+
speech_feat_len = batch["speech_feat_len"].to(device)
|
89 |
+
utt_embedding = batch["utt_embedding"].to(device)
|
90 |
+
spk_embedding = batch["spk_embedding"].to(device)
|
91 |
+
if args.mode == 'sft':
|
92 |
+
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
|
93 |
+
'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
|
94 |
+
else:
|
95 |
+
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
|
96 |
+
'prompt_text': text_token, 'prompt_text_len': text_token_len,
|
97 |
+
'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
|
98 |
+
'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
|
99 |
+
'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
|
100 |
+
'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
|
101 |
+
tts_speeches = []
|
102 |
+
for model_output in model.tts(**model_input):
|
103 |
+
tts_speeches.append(model_output['tts_speech'])
|
104 |
+
tts_speeches = torch.concat(tts_speeches, dim=1)
|
105 |
+
tts_key = '{}_{}'.format(utts[0], tts_index[0])
|
106 |
+
tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
|
107 |
+
torchaudio.save(tts_fn, tts_speeches, sample_rate=22050)
|
108 |
+
f.write('{} {}\n'.format(tts_key, tts_fn))
|
109 |
+
f.flush()
|
110 |
+
f.close()
|
111 |
+
logging.info('Result wav.scp saved in {}'.format(fn))
|
112 |
+
|
113 |
+
|
114 |
+
if __name__ == '__main__':
|
115 |
+
main()
|
cosyvoice/bin/train.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from __future__ import print_function
|
16 |
+
import argparse
|
17 |
+
import datetime
|
18 |
+
import logging
|
19 |
+
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
20 |
+
from copy import deepcopy
|
21 |
+
import os
|
22 |
+
import torch
|
23 |
+
import torch.distributed as dist
|
24 |
+
import deepspeed
|
25 |
+
|
26 |
+
from hyperpyyaml import load_hyperpyyaml
|
27 |
+
|
28 |
+
from torch.distributed.elastic.multiprocessing.errors import record
|
29 |
+
|
30 |
+
from cosyvoice.utils.executor import Executor
|
31 |
+
from cosyvoice.utils.train_utils import (
|
32 |
+
init_distributed,
|
33 |
+
init_dataset_and_dataloader,
|
34 |
+
init_optimizer_and_scheduler,
|
35 |
+
init_summarywriter, save_model,
|
36 |
+
wrap_cuda_model, check_modify_and_save_config)
|
37 |
+
|
38 |
+
|
39 |
+
def get_args():
|
40 |
+
parser = argparse.ArgumentParser(description='training your network')
|
41 |
+
parser.add_argument('--train_engine',
|
42 |
+
default='torch_ddp',
|
43 |
+
choices=['torch_ddp', 'deepspeed'],
|
44 |
+
help='Engine for paralleled training')
|
45 |
+
parser.add_argument('--model', required=True, help='model which will be trained')
|
46 |
+
parser.add_argument('--config', required=True, help='config file')
|
47 |
+
parser.add_argument('--train_data', required=True, help='train data file')
|
48 |
+
parser.add_argument('--cv_data', required=True, help='cv data file')
|
49 |
+
parser.add_argument('--checkpoint', help='checkpoint model')
|
50 |
+
parser.add_argument('--model_dir', required=True, help='save model dir')
|
51 |
+
parser.add_argument('--tensorboard_dir',
|
52 |
+
default='tensorboard',
|
53 |
+
help='tensorboard log dir')
|
54 |
+
parser.add_argument('--ddp.dist_backend',
|
55 |
+
dest='dist_backend',
|
56 |
+
default='nccl',
|
57 |
+
choices=['nccl', 'gloo'],
|
58 |
+
help='distributed backend')
|
59 |
+
parser.add_argument('--num_workers',
|
60 |
+
default=0,
|
61 |
+
type=int,
|
62 |
+
help='num of subprocess workers for reading')
|
63 |
+
parser.add_argument('--prefetch',
|
64 |
+
default=100,
|
65 |
+
type=int,
|
66 |
+
help='prefetch number')
|
67 |
+
parser.add_argument('--pin_memory',
|
68 |
+
action='store_true',
|
69 |
+
default=False,
|
70 |
+
help='Use pinned memory buffers used for reading')
|
71 |
+
parser.add_argument('--use_amp',
|
72 |
+
action='store_true',
|
73 |
+
default=False,
|
74 |
+
help='Use automatic mixed precision training')
|
75 |
+
parser.add_argument('--deepspeed.save_states',
|
76 |
+
dest='save_states',
|
77 |
+
default='model_only',
|
78 |
+
choices=['model_only', 'model+optimizer'],
|
79 |
+
help='save model/optimizer states')
|
80 |
+
parser.add_argument('--timeout',
|
81 |
+
default=60,
|
82 |
+
type=int,
|
83 |
+
help='timeout (in seconds) of cosyvoice_join.')
|
84 |
+
parser = deepspeed.add_config_arguments(parser)
|
85 |
+
args = parser.parse_args()
|
86 |
+
return args
|
87 |
+
|
88 |
+
|
89 |
+
@record
|
90 |
+
def main():
|
91 |
+
args = get_args()
|
92 |
+
logging.basicConfig(level=logging.DEBUG,
|
93 |
+
format='%(asctime)s %(levelname)s %(message)s')
|
94 |
+
# gan train has some special initialization logic
|
95 |
+
gan = True if args.model == 'hifigan' else False
|
96 |
+
|
97 |
+
override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
|
98 |
+
if gan is True:
|
99 |
+
override_dict.pop('hift')
|
100 |
+
with open(args.config, 'r') as f:
|
101 |
+
configs = load_hyperpyyaml(f, overrides=override_dict)
|
102 |
+
if gan is True:
|
103 |
+
configs['train_conf'] = configs['train_conf_gan']
|
104 |
+
configs['train_conf'].update(vars(args))
|
105 |
+
|
106 |
+
# Init env for ddp
|
107 |
+
init_distributed(args)
|
108 |
+
|
109 |
+
# Get dataset & dataloader
|
110 |
+
train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
|
111 |
+
init_dataset_and_dataloader(args, configs, gan)
|
112 |
+
|
113 |
+
# Do some sanity checks and save config to arsg.model_dir
|
114 |
+
configs = check_modify_and_save_config(args, configs)
|
115 |
+
|
116 |
+
# Tensorboard summary
|
117 |
+
writer = init_summarywriter(args)
|
118 |
+
|
119 |
+
# load checkpoint
|
120 |
+
model = configs[args.model]
|
121 |
+
start_step, start_epoch = 0, -1
|
122 |
+
if args.checkpoint is not None:
|
123 |
+
if os.path.exists(args.checkpoint):
|
124 |
+
state_dict = torch.load(args.checkpoint, map_location='cpu')
|
125 |
+
model.load_state_dict(state_dict, strict=False)
|
126 |
+
if 'step' in state_dict:
|
127 |
+
start_step = state_dict['step']
|
128 |
+
if 'epoch' in state_dict:
|
129 |
+
start_epoch = state_dict['epoch']
|
130 |
+
else:
|
131 |
+
logging.warning('checkpoint {} do not exsist!'.format(args.checkpoint))
|
132 |
+
|
133 |
+
# Dispatch model from cpu to gpu
|
134 |
+
model = wrap_cuda_model(args, model)
|
135 |
+
|
136 |
+
# Get optimizer & scheduler
|
137 |
+
model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan)
|
138 |
+
scheduler.set_step(start_step)
|
139 |
+
if scheduler_d is not None:
|
140 |
+
scheduler_d.set_step(start_step)
|
141 |
+
|
142 |
+
# Save init checkpoints
|
143 |
+
info_dict = deepcopy(configs['train_conf'])
|
144 |
+
info_dict['step'] = start_step
|
145 |
+
info_dict['epoch'] = start_epoch
|
146 |
+
save_model(model, 'init', info_dict)
|
147 |
+
|
148 |
+
# Get executor
|
149 |
+
executor = Executor(gan=gan)
|
150 |
+
executor.step = start_step
|
151 |
+
|
152 |
+
# Init scaler, used for pytorch amp mixed precision training
|
153 |
+
scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
|
154 |
+
print('start step {} start epoch {}'.format(start_step, start_epoch))
|
155 |
+
# Start training loop
|
156 |
+
for epoch in range(start_epoch + 1, info_dict['max_epoch']):
|
157 |
+
executor.epoch = epoch
|
158 |
+
train_dataset.set_epoch(epoch)
|
159 |
+
dist.barrier()
|
160 |
+
group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
|
161 |
+
if gan is True:
|
162 |
+
executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
|
163 |
+
writer, info_dict, scaler, group_join)
|
164 |
+
else:
|
165 |
+
executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join)
|
166 |
+
dist.destroy_process_group(group_join)
|
167 |
+
|
168 |
+
|
169 |
+
if __name__ == '__main__':
|
170 |
+
main()
|
cosyvoice/cli/__init__.py
ADDED
File without changes
|
cosyvoice/cli/cosyvoice.py
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import os
|
15 |
+
import time
|
16 |
+
from tqdm import tqdm
|
17 |
+
from hyperpyyaml import load_hyperpyyaml
|
18 |
+
from modelscope import snapshot_download
|
19 |
+
import torch
|
20 |
+
from cosyvoice.cli.frontend import CosyVoiceFrontEnd
|
21 |
+
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
|
22 |
+
from cosyvoice.utils.file_utils import logging
|
23 |
+
|
24 |
+
|
25 |
+
class CosyVoice:
|
26 |
+
|
27 |
+
def __init__(self, model_dir, load_jit=True, load_onnx=False, fp16=True):
|
28 |
+
instruct = True if '-Instruct' in model_dir else False
|
29 |
+
self.model_dir = model_dir
|
30 |
+
if not os.path.exists(model_dir):
|
31 |
+
model_dir = snapshot_download(model_dir)
|
32 |
+
with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
|
33 |
+
configs = load_hyperpyyaml(f)
|
34 |
+
self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
|
35 |
+
configs['feat_extractor'],
|
36 |
+
'{}/campplus.onnx'.format(model_dir),
|
37 |
+
'{}/speech_tokenizer_v1.onnx'.format(model_dir),
|
38 |
+
'{}/spk2info.pt'.format(model_dir),
|
39 |
+
instruct,
|
40 |
+
configs['allowed_special'])
|
41 |
+
self.sample_rate = configs['sample_rate']
|
42 |
+
if torch.cuda.is_available() is False and (fp16 is True or load_jit is True):
|
43 |
+
load_jit = False
|
44 |
+
fp16 = False
|
45 |
+
logging.warning('cpu do not support fp16 and jit, force set to False')
|
46 |
+
self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16)
|
47 |
+
self.model.load('{}/llm.pt'.format(model_dir),
|
48 |
+
'{}/flow.pt'.format(model_dir),
|
49 |
+
'{}/hift.pt'.format(model_dir))
|
50 |
+
if load_jit:
|
51 |
+
self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
|
52 |
+
'{}/llm.llm.fp16.zip'.format(model_dir),
|
53 |
+
'{}/flow.encoder.fp32.zip'.format(model_dir))
|
54 |
+
if load_onnx:
|
55 |
+
self.model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir))
|
56 |
+
del configs
|
57 |
+
|
58 |
+
def list_avaliable_spks(self):
|
59 |
+
spks = list(self.frontend.spk2info.keys())
|
60 |
+
return spks
|
61 |
+
|
62 |
+
def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0):
|
63 |
+
for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
|
64 |
+
model_input = self.frontend.frontend_sft(i, spk_id)
|
65 |
+
start_time = time.time()
|
66 |
+
logging.info('synthesis text {}'.format(i))
|
67 |
+
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
68 |
+
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
69 |
+
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
70 |
+
yield model_output
|
71 |
+
start_time = time.time()
|
72 |
+
|
73 |
+
def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False, speed=1.0):
|
74 |
+
prompt_text = self.frontend.text_normalize(prompt_text, split=False)
|
75 |
+
for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
|
76 |
+
if len(i) < 0.5 * len(prompt_text):
|
77 |
+
logging.warning('synthesis text {} too short than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
|
78 |
+
model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate)
|
79 |
+
start_time = time.time()
|
80 |
+
logging.info('synthesis text {}'.format(i))
|
81 |
+
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
82 |
+
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
83 |
+
logging.info('yield speech len {}, rtf {}, abs mean {}, std {}'.format(speech_len, (time.time() - start_time) / speech_len, model_output['tts_speech'].abs().mean(), model_output['tts_speech'].std()))
|
84 |
+
yield model_output
|
85 |
+
start_time = time.time()
|
86 |
+
|
87 |
+
def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0):
|
88 |
+
if self.frontend.instruct is True:
|
89 |
+
raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir))
|
90 |
+
for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
|
91 |
+
model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate)
|
92 |
+
start_time = time.time()
|
93 |
+
logging.info('synthesis text {}'.format(i))
|
94 |
+
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
95 |
+
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
96 |
+
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
97 |
+
yield model_output
|
98 |
+
start_time = time.time()
|
99 |
+
|
100 |
+
def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0):
|
101 |
+
if self.frontend.instruct is False:
|
102 |
+
raise ValueError('{} do not support instruct inference'.format(self.model_dir))
|
103 |
+
instruct_text = self.frontend.text_normalize(instruct_text, split=False)
|
104 |
+
for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
|
105 |
+
model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
|
106 |
+
start_time = time.time()
|
107 |
+
logging.info('synthesis text {}'.format(i))
|
108 |
+
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
109 |
+
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
110 |
+
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
111 |
+
yield model_output
|
112 |
+
start_time = time.time()
|
113 |
+
|
114 |
+
def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, stream=False, speed=1.0):
|
115 |
+
for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
|
116 |
+
model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate)
|
117 |
+
start_time = time.time()
|
118 |
+
logging.info('synthesis text {}'.format(i))
|
119 |
+
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
120 |
+
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
121 |
+
logging.info('yield speech len {}, rtf {}, abs mean {}, std {}'.format(speech_len, (time.time() - start_time) / speech_len, model_output['tts_speech'].abs().mean(), model_output['tts_speech'].std()))
|
122 |
+
yield model_output
|
123 |
+
start_time = time.time()
|
124 |
+
|
125 |
+
def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
|
126 |
+
model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate)
|
127 |
+
start_time = time.time()
|
128 |
+
for model_output in self.model.vc(**model_input, stream=stream, speed=speed):
|
129 |
+
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
130 |
+
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
131 |
+
yield model_output
|
132 |
+
start_time = time.time()
|
133 |
+
|
134 |
+
class CosyVoice2(CosyVoice):
|
135 |
+
|
136 |
+
def __init__(self, model_dir, load_jit=False, load_onnx=False, load_trt=False):
|
137 |
+
instruct = True if '-Instruct' in model_dir else False
|
138 |
+
self.model_dir = model_dir
|
139 |
+
if not os.path.exists(model_dir):
|
140 |
+
model_dir = snapshot_download(model_dir)
|
141 |
+
with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
|
142 |
+
configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
|
143 |
+
self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
|
144 |
+
configs['feat_extractor'],
|
145 |
+
'{}/campplus.onnx'.format(model_dir),
|
146 |
+
'{}/speech_tokenizer_v2.onnx'.format(model_dir),
|
147 |
+
'{}/spk2info.pt'.format(model_dir),
|
148 |
+
instruct,
|
149 |
+
configs['allowed_special'])
|
150 |
+
self.sample_rate = configs['sample_rate']
|
151 |
+
if torch.cuda.is_available() is False and load_jit is True:
|
152 |
+
load_jit = False
|
153 |
+
logging.warning('cpu do not support jit, force set to False')
|
154 |
+
self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'])
|
155 |
+
self.model.load('{}/llm.pt'.format(model_dir),
|
156 |
+
'{}/flow.pt'.format(model_dir),
|
157 |
+
'{}/hift.pt'.format(model_dir))
|
158 |
+
if load_jit:
|
159 |
+
self.model.load_jit('{}/flow.encoder.fp32.zip'.format(model_dir))
|
160 |
+
if load_trt is True and load_onnx is True:
|
161 |
+
load_onnx = False
|
162 |
+
logging.warning('can not set both load_trt and load_onnx to True, force set load_onnx to False')
|
163 |
+
if load_onnx:
|
164 |
+
self.model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir))
|
165 |
+
if load_trt:
|
166 |
+
self.model.load_trt('{}/flow.decoder.estimator.fp16.A10.plan'.format(model_dir))
|
167 |
+
del configs
|
cosyvoice/cli/frontend.py
ADDED
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from functools import partial
|
15 |
+
import json
|
16 |
+
import onnxruntime
|
17 |
+
import torch
|
18 |
+
import numpy as np
|
19 |
+
import whisper
|
20 |
+
from typing import Callable
|
21 |
+
import torchaudio.compliance.kaldi as kaldi
|
22 |
+
import torchaudio
|
23 |
+
import os
|
24 |
+
import re
|
25 |
+
import inflect
|
26 |
+
try:
|
27 |
+
import ttsfrd
|
28 |
+
use_ttsfrd = True
|
29 |
+
except ImportError:
|
30 |
+
print("failed to import ttsfrd, use WeTextProcessing instead")
|
31 |
+
from tn.chinese.normalizer import Normalizer as ZhNormalizer
|
32 |
+
from tn.english.normalizer import Normalizer as EnNormalizer
|
33 |
+
use_ttsfrd = False
|
34 |
+
from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph
|
35 |
+
|
36 |
+
|
37 |
+
class CosyVoiceFrontEnd:
|
38 |
+
|
39 |
+
def __init__(self,
|
40 |
+
get_tokenizer: Callable,
|
41 |
+
feat_extractor: Callable,
|
42 |
+
campplus_model: str,
|
43 |
+
speech_tokenizer_model: str,
|
44 |
+
spk2info: str = '',
|
45 |
+
instruct: bool = False,
|
46 |
+
allowed_special: str = 'all'):
|
47 |
+
self.tokenizer = get_tokenizer()
|
48 |
+
self.feat_extractor = feat_extractor
|
49 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
50 |
+
option = onnxruntime.SessionOptions()
|
51 |
+
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
52 |
+
option.intra_op_num_threads = 1
|
53 |
+
self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
|
54 |
+
self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
|
55 |
+
providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
|
56 |
+
"CPUExecutionProvider"])
|
57 |
+
if os.path.exists(spk2info):
|
58 |
+
self.spk2info = torch.load(spk2info, map_location=self.device)
|
59 |
+
else:
|
60 |
+
self.spk2info = {}
|
61 |
+
self.instruct = instruct
|
62 |
+
self.allowed_special = allowed_special
|
63 |
+
self.inflect_parser = inflect.engine()
|
64 |
+
self.use_ttsfrd = use_ttsfrd
|
65 |
+
if self.use_ttsfrd:
|
66 |
+
self.frd = ttsfrd.TtsFrontendEngine()
|
67 |
+
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
68 |
+
assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
|
69 |
+
'failed to initialize ttsfrd resource'
|
70 |
+
self.frd.set_lang_type('pinyinvg')
|
71 |
+
else:
|
72 |
+
self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False)
|
73 |
+
self.en_tn_model = EnNormalizer()
|
74 |
+
|
75 |
+
def _extract_text_token(self, text):
|
76 |
+
text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
|
77 |
+
text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
|
78 |
+
text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
|
79 |
+
return text_token, text_token_len
|
80 |
+
|
81 |
+
def _extract_speech_token(self, speech):
|
82 |
+
assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
|
83 |
+
feat = whisper.log_mel_spectrogram(speech, n_mels=128)
|
84 |
+
speech_token = self.speech_tokenizer_session.run(None,
|
85 |
+
{self.speech_tokenizer_session.get_inputs()[0].name:
|
86 |
+
feat.detach().cpu().numpy(),
|
87 |
+
self.speech_tokenizer_session.get_inputs()[1].name:
|
88 |
+
np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
|
89 |
+
speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
|
90 |
+
speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
|
91 |
+
return speech_token, speech_token_len
|
92 |
+
|
93 |
+
def _extract_spk_embedding(self, speech):
|
94 |
+
feat = kaldi.fbank(speech,
|
95 |
+
num_mel_bins=80,
|
96 |
+
dither=0,
|
97 |
+
sample_frequency=16000)
|
98 |
+
feat = feat - feat.mean(dim=0, keepdim=True)
|
99 |
+
embedding = self.campplus_session.run(None,
|
100 |
+
{self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
|
101 |
+
embedding = torch.tensor([embedding]).to(self.device)
|
102 |
+
return embedding
|
103 |
+
|
104 |
+
def _extract_speech_feat(self, speech):
|
105 |
+
speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
|
106 |
+
speech_feat = speech_feat.unsqueeze(dim=0)
|
107 |
+
speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
|
108 |
+
return speech_feat, speech_feat_len
|
109 |
+
|
110 |
+
def text_normalize(self, text, split=True):
|
111 |
+
text = text.strip()
|
112 |
+
if contains_chinese(text):
|
113 |
+
if self.use_ttsfrd:
|
114 |
+
texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
|
115 |
+
text = ''.join(texts)
|
116 |
+
else:
|
117 |
+
text = self.zh_tn_model.normalize(text)
|
118 |
+
text = text.replace("\n", "")
|
119 |
+
text = replace_blank(text)
|
120 |
+
text = replace_corner_mark(text)
|
121 |
+
text = text.replace(".", "。")
|
122 |
+
text = text.replace(" - ", ",")
|
123 |
+
text = remove_bracket(text)
|
124 |
+
text = re.sub(r'[,,、]+$', '。', text)
|
125 |
+
texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
|
126 |
+
token_min_n=60, merge_len=20, comma_split=False))
|
127 |
+
else:
|
128 |
+
if self.use_ttsfrd:
|
129 |
+
texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
|
130 |
+
text = ''.join(texts)
|
131 |
+
else:
|
132 |
+
text = self.en_tn_model.normalize(text)
|
133 |
+
text = spell_out_number(text, self.inflect_parser)
|
134 |
+
texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
|
135 |
+
token_min_n=60, merge_len=20, comma_split=False))
|
136 |
+
if split is False:
|
137 |
+
return text
|
138 |
+
return texts
|
139 |
+
|
140 |
+
def frontend_sft(self, tts_text, spk_id):
|
141 |
+
tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
|
142 |
+
embedding = self.spk2info[spk_id]['embedding']
|
143 |
+
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
|
144 |
+
return model_input
|
145 |
+
|
146 |
+
def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, resample_rate):
|
147 |
+
tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
|
148 |
+
prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
|
149 |
+
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
|
150 |
+
speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
|
151 |
+
speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
|
152 |
+
if resample_rate == 24000:
|
153 |
+
# cosyvoice2, force speech_feat % speech_token = 2
|
154 |
+
token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
|
155 |
+
speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2* token_len
|
156 |
+
speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
|
157 |
+
embedding = self._extract_spk_embedding(prompt_speech_16k)
|
158 |
+
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
|
159 |
+
'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
|
160 |
+
'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
|
161 |
+
'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
|
162 |
+
'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
|
163 |
+
'llm_embedding': embedding, 'flow_embedding': embedding}
|
164 |
+
return model_input
|
165 |
+
|
166 |
+
def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate):
|
167 |
+
tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
|
168 |
+
prompt_text_token, prompt_text_token_len = self._extract_text_token(instruct_text + '<|endofprompt|>')
|
169 |
+
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
|
170 |
+
speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
|
171 |
+
speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
|
172 |
+
if resample_rate == 24000:
|
173 |
+
# cosyvoice2, force speech_feat % speech_token = 2
|
174 |
+
token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
|
175 |
+
speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2* token_len
|
176 |
+
speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
|
177 |
+
embedding = self._extract_spk_embedding(prompt_speech_16k)
|
178 |
+
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
|
179 |
+
'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
|
180 |
+
'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
|
181 |
+
'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
|
182 |
+
'llm_embedding': embedding, 'flow_embedding': embedding}
|
183 |
+
return model_input
|
184 |
+
|
185 |
+
def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate):
|
186 |
+
model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate)
|
187 |
+
# in cross lingual mode, we remove prompt in llm
|
188 |
+
del model_input['prompt_text']
|
189 |
+
del model_input['prompt_text_len']
|
190 |
+
del model_input['llm_prompt_speech_token']
|
191 |
+
del model_input['llm_prompt_speech_token_len']
|
192 |
+
return model_input
|
193 |
+
|
194 |
+
def frontend_instruct(self, tts_text, spk_id, instruct_text):
|
195 |
+
model_input = self.frontend_sft(tts_text, spk_id)
|
196 |
+
# in instruct mode, we remove spk_embedding in llm due to information leakage
|
197 |
+
del model_input['llm_embedding']
|
198 |
+
instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text + '<endofprompt>')
|
199 |
+
model_input['prompt_text'] = instruct_text_token
|
200 |
+
model_input['prompt_text_len'] = instruct_text_token_len
|
201 |
+
return model_input
|
202 |
+
|
203 |
+
def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate):
|
204 |
+
prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
|
205 |
+
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
|
206 |
+
prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
|
207 |
+
embedding = self._extract_spk_embedding(prompt_speech_16k)
|
208 |
+
source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
|
209 |
+
model_input = {'source_speech_token': source_speech_token, 'source_speech_token_len': source_speech_token_len,
|
210 |
+
'flow_prompt_speech_token': prompt_speech_token, 'flow_prompt_speech_token_len': prompt_speech_token_len,
|
211 |
+
'prompt_speech_feat': prompt_speech_feat, 'prompt_speech_feat_len': prompt_speech_feat_len,
|
212 |
+
'flow_embedding': embedding}
|
213 |
+
return model_input
|
cosyvoice/cli/model.py
ADDED
@@ -0,0 +1,421 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import torch
|
15 |
+
import numpy as np
|
16 |
+
import threading
|
17 |
+
import time
|
18 |
+
from torch.nn import functional as F
|
19 |
+
from contextlib import nullcontext
|
20 |
+
import uuid
|
21 |
+
from cosyvoice.utils.common import fade_in_out
|
22 |
+
|
23 |
+
|
24 |
+
class CosyVoiceModel:
|
25 |
+
|
26 |
+
def __init__(self,
|
27 |
+
llm: torch.nn.Module,
|
28 |
+
flow: torch.nn.Module,
|
29 |
+
hift: torch.nn.Module,
|
30 |
+
fp16: bool):
|
31 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
32 |
+
self.llm = llm
|
33 |
+
self.flow = flow
|
34 |
+
self.hift = hift
|
35 |
+
self.fp16 = fp16
|
36 |
+
self.token_min_hop_len = 2 * self.flow.input_frame_rate
|
37 |
+
self.token_max_hop_len = 4 * self.flow.input_frame_rate
|
38 |
+
self.token_overlap_len = 20
|
39 |
+
# mel fade in out
|
40 |
+
self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
|
41 |
+
self.mel_window = np.hamming(2 * self.mel_overlap_len)
|
42 |
+
# hift cache
|
43 |
+
self.mel_cache_len = 20
|
44 |
+
self.source_cache_len = int(self.mel_cache_len * 256)
|
45 |
+
# speech fade in out
|
46 |
+
self.speech_window = np.hamming(2 * self.source_cache_len)
|
47 |
+
# rtf and decoding related
|
48 |
+
self.stream_scale_factor = 1
|
49 |
+
assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
|
50 |
+
self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
|
51 |
+
self.lock = threading.Lock()
|
52 |
+
# dict used to store session related variable
|
53 |
+
self.tts_speech_token_dict = {}
|
54 |
+
self.llm_end_dict = {}
|
55 |
+
self.mel_overlap_dict = {}
|
56 |
+
self.flow_cache_dict = {}
|
57 |
+
self.hift_cache_dict = {}
|
58 |
+
|
59 |
+
def load(self, llm_model, flow_model, hift_model):
|
60 |
+
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
|
61 |
+
self.llm.to(self.device).eval()
|
62 |
+
if self.fp16 is True:
|
63 |
+
self.llm.half()
|
64 |
+
self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
|
65 |
+
self.flow.to(self.device).eval()
|
66 |
+
# in case hift_model is a hifigan model
|
67 |
+
hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
|
68 |
+
self.hift.load_state_dict(hift_state_dict, strict=True)
|
69 |
+
self.hift.to(self.device).eval()
|
70 |
+
|
71 |
+
def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
|
72 |
+
assert self.fp16 is True, "we only provide fp16 jit model, set fp16=True if you want to use jit model"
|
73 |
+
llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
|
74 |
+
self.llm.text_encoder = llm_text_encoder
|
75 |
+
llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
|
76 |
+
self.llm.llm = llm_llm
|
77 |
+
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
|
78 |
+
self.flow.encoder = flow_encoder
|
79 |
+
|
80 |
+
def load_onnx(self, flow_decoder_estimator_model):
|
81 |
+
import onnxruntime
|
82 |
+
option = onnxruntime.SessionOptions()
|
83 |
+
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
84 |
+
option.intra_op_num_threads = 1
|
85 |
+
providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
|
86 |
+
del self.flow.decoder.estimator
|
87 |
+
self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers)
|
88 |
+
|
89 |
+
def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
|
90 |
+
if self.fp16 is True:
|
91 |
+
llm_embedding = llm_embedding.half()
|
92 |
+
with self.llm_context:
|
93 |
+
for i in self.llm.inference(text=text.to(self.device),
|
94 |
+
text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
|
95 |
+
prompt_text=prompt_text.to(self.device),
|
96 |
+
prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
|
97 |
+
prompt_speech_token=llm_prompt_speech_token.to(self.device),
|
98 |
+
prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
|
99 |
+
embedding=llm_embedding.to(self.device)):
|
100 |
+
self.tts_speech_token_dict[uuid].append(i)
|
101 |
+
self.llm_end_dict[uuid] = True
|
102 |
+
|
103 |
+
def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
|
104 |
+
tts_mel, flow_cache = self.flow.inference(token=token.to(self.device),
|
105 |
+
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
|
106 |
+
prompt_token=prompt_token.to(self.device),
|
107 |
+
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
|
108 |
+
prompt_feat=prompt_feat.to(self.device),
|
109 |
+
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
|
110 |
+
embedding=embedding.to(self.device),
|
111 |
+
flow_cache=self.flow_cache_dict[uuid])
|
112 |
+
self.flow_cache_dict[uuid] = flow_cache
|
113 |
+
|
114 |
+
# mel overlap fade in out
|
115 |
+
if self.mel_overlap_dict[uuid].shape[2] != 0:
|
116 |
+
tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
|
117 |
+
# append hift cache
|
118 |
+
if self.hift_cache_dict[uuid] is not None:
|
119 |
+
hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
|
120 |
+
tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
|
121 |
+
else:
|
122 |
+
hift_cache_source = torch.zeros(1, 1, 0)
|
123 |
+
# keep overlap mel and hift cache
|
124 |
+
if finalize is False:
|
125 |
+
self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
|
126 |
+
tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
|
127 |
+
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
|
128 |
+
if self.hift_cache_dict[uuid] is not None:
|
129 |
+
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
|
130 |
+
self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
|
131 |
+
'source': tts_source[:, :, -self.source_cache_len:],
|
132 |
+
'speech': tts_speech[:, -self.source_cache_len:]}
|
133 |
+
tts_speech = tts_speech[:, :-self.source_cache_len]
|
134 |
+
else:
|
135 |
+
if speed != 1.0:
|
136 |
+
assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
|
137 |
+
tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
|
138 |
+
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
|
139 |
+
if self.hift_cache_dict[uuid] is not None:
|
140 |
+
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
|
141 |
+
return tts_speech
|
142 |
+
|
143 |
+
def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
|
144 |
+
prompt_text=torch.zeros(1, 0, dtype=torch.int32),
|
145 |
+
llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
|
146 |
+
flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
|
147 |
+
prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
|
148 |
+
# this_uuid is used to track variables related to this inference thread
|
149 |
+
this_uuid = str(uuid.uuid1())
|
150 |
+
with self.lock:
|
151 |
+
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
|
152 |
+
self.hift_cache_dict[this_uuid] = None
|
153 |
+
self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
|
154 |
+
self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
|
155 |
+
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
|
156 |
+
p.start()
|
157 |
+
if stream is True:
|
158 |
+
token_hop_len = self.token_min_hop_len
|
159 |
+
while True:
|
160 |
+
time.sleep(0.1)
|
161 |
+
if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
|
162 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
|
163 |
+
.unsqueeze(dim=0)
|
164 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
165 |
+
prompt_token=flow_prompt_speech_token,
|
166 |
+
prompt_feat=prompt_speech_feat,
|
167 |
+
embedding=flow_embedding,
|
168 |
+
uuid=this_uuid,
|
169 |
+
finalize=False)
|
170 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
171 |
+
with self.lock:
|
172 |
+
self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
|
173 |
+
# increase token_hop_len for better speech quality
|
174 |
+
token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
|
175 |
+
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
|
176 |
+
break
|
177 |
+
p.join()
|
178 |
+
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
|
179 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
180 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
181 |
+
prompt_token=flow_prompt_speech_token,
|
182 |
+
prompt_feat=prompt_speech_feat,
|
183 |
+
embedding=flow_embedding,
|
184 |
+
uuid=this_uuid,
|
185 |
+
finalize=True)
|
186 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
187 |
+
else:
|
188 |
+
# deal with all tokens
|
189 |
+
p.join()
|
190 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
191 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
192 |
+
prompt_token=flow_prompt_speech_token,
|
193 |
+
prompt_feat=prompt_speech_feat,
|
194 |
+
embedding=flow_embedding,
|
195 |
+
uuid=this_uuid,
|
196 |
+
finalize=True,
|
197 |
+
speed=speed)
|
198 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
199 |
+
with self.lock:
|
200 |
+
self.tts_speech_token_dict.pop(this_uuid)
|
201 |
+
self.llm_end_dict.pop(this_uuid)
|
202 |
+
self.mel_overlap_dict.pop(this_uuid)
|
203 |
+
self.hift_cache_dict.pop(this_uuid)
|
204 |
+
|
205 |
+
def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs):
|
206 |
+
# this_uuid is used to track variables related to this inference thread
|
207 |
+
this_uuid = str(uuid.uuid1())
|
208 |
+
with self.lock:
|
209 |
+
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True
|
210 |
+
self.hift_cache_dict[this_uuid] = None
|
211 |
+
self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
|
212 |
+
self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
|
213 |
+
if stream is True:
|
214 |
+
token_hop_len = self.token_min_hop_len
|
215 |
+
while True:
|
216 |
+
if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
|
217 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
|
218 |
+
.unsqueeze(dim=0)
|
219 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
220 |
+
prompt_token=flow_prompt_speech_token,
|
221 |
+
prompt_feat=prompt_speech_feat,
|
222 |
+
embedding=flow_embedding,
|
223 |
+
uuid=this_uuid,
|
224 |
+
finalize=False)
|
225 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
226 |
+
with self.lock:
|
227 |
+
self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
|
228 |
+
# increase token_hop_len for better speech quality
|
229 |
+
token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
|
230 |
+
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
|
231 |
+
break
|
232 |
+
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
|
233 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid], dim=1).unsqueeze(dim=0)
|
234 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
235 |
+
prompt_token=flow_prompt_speech_token,
|
236 |
+
prompt_feat=prompt_speech_feat,
|
237 |
+
embedding=flow_embedding,
|
238 |
+
uuid=this_uuid,
|
239 |
+
finalize=True)
|
240 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
241 |
+
else:
|
242 |
+
# deal with all tokens
|
243 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
244 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
245 |
+
prompt_token=flow_prompt_speech_token,
|
246 |
+
prompt_feat=prompt_speech_feat,
|
247 |
+
embedding=flow_embedding,
|
248 |
+
uuid=this_uuid,
|
249 |
+
finalize=True,
|
250 |
+
speed=speed)
|
251 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
252 |
+
with self.lock:
|
253 |
+
self.tts_speech_token_dict.pop(this_uuid)
|
254 |
+
self.llm_end_dict.pop(this_uuid)
|
255 |
+
self.mel_overlap_dict.pop(this_uuid)
|
256 |
+
self.hift_cache_dict.pop(this_uuid)
|
257 |
+
|
258 |
+
|
259 |
+
class CosyVoice2Model:
|
260 |
+
|
261 |
+
def __init__(self,
|
262 |
+
llm: torch.nn.Module,
|
263 |
+
flow: torch.nn.Module,
|
264 |
+
hift: torch.nn.Module):
|
265 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
266 |
+
self.llm = llm
|
267 |
+
self.flow = flow
|
268 |
+
self.hift = hift
|
269 |
+
self.token_hop_len = 2 * self.flow.input_frame_rate
|
270 |
+
# here we fix flow encoder/decoder decoding_chunk_size, in the future we will send it as arguments, or use cache
|
271 |
+
self.flow.encoder.static_chunk_size = 2 * self.flow.input_frame_rate
|
272 |
+
self.flow.decoder.estimator.static_chunk_size = 2 * self.flow.input_frame_rate * self.flow.token_mel_ratio
|
273 |
+
# hift cache
|
274 |
+
self.mel_cache_len = 8
|
275 |
+
self.source_cache_len = int(self.mel_cache_len * 480)
|
276 |
+
# speech fade in out
|
277 |
+
self.speech_window = np.hamming(2 * self.source_cache_len)
|
278 |
+
# rtf and decoding related
|
279 |
+
self.stream_scale_factor = 1
|
280 |
+
self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
|
281 |
+
self.lock = threading.Lock()
|
282 |
+
# dict used to store session related variable
|
283 |
+
self.tts_speech_token_dict = {}
|
284 |
+
self.llm_end_dict = {}
|
285 |
+
self.hift_cache_dict = {}
|
286 |
+
|
287 |
+
def load(self, llm_model, flow_model, hift_model):
|
288 |
+
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
|
289 |
+
self.llm.to(self.device).eval()
|
290 |
+
self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
|
291 |
+
self.flow.to(self.device).eval()
|
292 |
+
self.flow.decoder.fp16 = False
|
293 |
+
# in case hift_model is a hifigan model
|
294 |
+
hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
|
295 |
+
self.hift.load_state_dict(hift_state_dict, strict=True)
|
296 |
+
self.hift.to(self.device).eval()
|
297 |
+
|
298 |
+
def load_jit(self, flow_encoder_model):
|
299 |
+
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
|
300 |
+
self.flow.encoder = flow_encoder
|
301 |
+
|
302 |
+
def load_onnx(self, flow_decoder_estimator_model):
|
303 |
+
import onnxruntime
|
304 |
+
option = onnxruntime.SessionOptions()
|
305 |
+
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
306 |
+
option.intra_op_num_threads = 1
|
307 |
+
providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
|
308 |
+
del self.flow.decoder.estimator
|
309 |
+
self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers)
|
310 |
+
|
311 |
+
def load_trt(self, flow_decoder_estimator_model):
|
312 |
+
del self.flow.decoder.estimator
|
313 |
+
import tensorrt as trt
|
314 |
+
with open(flow_decoder_estimator_model, 'rb') as f:
|
315 |
+
self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
|
316 |
+
self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()
|
317 |
+
self.flow.decoder.fp16 = True
|
318 |
+
|
319 |
+
def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
|
320 |
+
with self.llm_context:
|
321 |
+
for i in self.llm.inference(text=text.to(self.device),
|
322 |
+
text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
|
323 |
+
prompt_text=prompt_text.to(self.device),
|
324 |
+
prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
|
325 |
+
prompt_speech_token=llm_prompt_speech_token.to(self.device),
|
326 |
+
prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
|
327 |
+
embedding=llm_embedding.to(self.device)):
|
328 |
+
self.tts_speech_token_dict[uuid].append(i)
|
329 |
+
self.llm_end_dict[uuid] = True
|
330 |
+
|
331 |
+
def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0):
|
332 |
+
tts_mel, _ = self.flow.inference(token=token.to(self.device),
|
333 |
+
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
|
334 |
+
prompt_token=prompt_token.to(self.device),
|
335 |
+
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
|
336 |
+
prompt_feat=prompt_feat.to(self.device),
|
337 |
+
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
|
338 |
+
embedding=embedding.to(self.device),
|
339 |
+
finalize=finalize)
|
340 |
+
tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
|
341 |
+
# append hift cache
|
342 |
+
if self.hift_cache_dict[uuid] is not None:
|
343 |
+
hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
|
344 |
+
tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
|
345 |
+
else:
|
346 |
+
hift_cache_source = torch.zeros(1, 1, 0)
|
347 |
+
# keep overlap mel and hift cache
|
348 |
+
if finalize is False:
|
349 |
+
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
|
350 |
+
if self.hift_cache_dict[uuid] is not None:
|
351 |
+
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
|
352 |
+
self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
|
353 |
+
'source': tts_source[:, :, -self.source_cache_len:],
|
354 |
+
'speech': tts_speech[:, -self.source_cache_len:]}
|
355 |
+
tts_speech = tts_speech[:, :-self.source_cache_len]
|
356 |
+
else:
|
357 |
+
if speed != 1.0:
|
358 |
+
assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
|
359 |
+
tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
|
360 |
+
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
|
361 |
+
if self.hift_cache_dict[uuid] is not None:
|
362 |
+
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
|
363 |
+
return tts_speech
|
364 |
+
|
365 |
+
def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
|
366 |
+
prompt_text=torch.zeros(1, 0, dtype=torch.int32),
|
367 |
+
llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
|
368 |
+
flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
|
369 |
+
prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
|
370 |
+
# this_uuid is used to track variables related to this inference thread
|
371 |
+
this_uuid = str(uuid.uuid1())
|
372 |
+
with self.lock:
|
373 |
+
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
|
374 |
+
self.hift_cache_dict[this_uuid] = None
|
375 |
+
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
|
376 |
+
p.start()
|
377 |
+
if stream is True:
|
378 |
+
token_offset = 0
|
379 |
+
while True:
|
380 |
+
time.sleep(0.1)
|
381 |
+
if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= self.token_hop_len + self.flow.pre_lookahead_len:
|
382 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + self.token_hop_len + self.flow.pre_lookahead_len]) \
|
383 |
+
.unsqueeze(dim=0)
|
384 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
385 |
+
prompt_token=flow_prompt_speech_token,
|
386 |
+
prompt_feat=prompt_speech_feat,
|
387 |
+
embedding=flow_embedding,
|
388 |
+
uuid=this_uuid,
|
389 |
+
token_offset=token_offset,
|
390 |
+
finalize=False)
|
391 |
+
token_offset += self.token_hop_len
|
392 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
393 |
+
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < self.token_hop_len + self.flow.pre_lookahead_len:
|
394 |
+
break
|
395 |
+
p.join()
|
396 |
+
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
|
397 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
398 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
399 |
+
prompt_token=flow_prompt_speech_token,
|
400 |
+
prompt_feat=prompt_speech_feat,
|
401 |
+
embedding=flow_embedding,
|
402 |
+
uuid=this_uuid,
|
403 |
+
token_offset=token_offset,
|
404 |
+
finalize=True)
|
405 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
406 |
+
else:
|
407 |
+
# deal with all tokens
|
408 |
+
p.join()
|
409 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
410 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
411 |
+
prompt_token=flow_prompt_speech_token,
|
412 |
+
prompt_feat=prompt_speech_feat,
|
413 |
+
embedding=flow_embedding,
|
414 |
+
uuid=this_uuid,
|
415 |
+
token_offset=0,
|
416 |
+
finalize=True,
|
417 |
+
speed=speed)
|
418 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
419 |
+
with self.lock:
|
420 |
+
self.tts_speech_token_dict.pop(this_uuid)
|
421 |
+
self.llm_end_dict.pop(this_uuid)
|
cosyvoice/dataset/__init__.py
ADDED
File without changes
|
cosyvoice/dataset/dataset.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
|
2 |
+
# 2024 Alibaba Inc (authors: Xiang Lyu)
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import random
|
17 |
+
import json
|
18 |
+
import math
|
19 |
+
from functools import partial
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import torch.distributed as dist
|
23 |
+
from torch.utils.data import IterableDataset
|
24 |
+
from cosyvoice.utils.file_utils import read_lists, read_json_lists
|
25 |
+
|
26 |
+
|
27 |
+
class Processor(IterableDataset):
|
28 |
+
|
29 |
+
def __init__(self, source, f, *args, **kw):
|
30 |
+
assert callable(f)
|
31 |
+
self.source = source
|
32 |
+
self.f = f
|
33 |
+
self.args = args
|
34 |
+
self.kw = kw
|
35 |
+
|
36 |
+
def set_epoch(self, epoch):
|
37 |
+
self.source.set_epoch(epoch)
|
38 |
+
|
39 |
+
def __iter__(self):
|
40 |
+
""" Return an iterator over the source dataset processed by the
|
41 |
+
given processor.
|
42 |
+
"""
|
43 |
+
assert self.source is not None
|
44 |
+
assert callable(self.f)
|
45 |
+
return self.f(iter(self.source), *self.args, **self.kw)
|
46 |
+
|
47 |
+
def apply(self, f):
|
48 |
+
assert callable(f)
|
49 |
+
return Processor(self, f, *self.args, **self.kw)
|
50 |
+
|
51 |
+
|
52 |
+
class DistributedSampler:
|
53 |
+
|
54 |
+
def __init__(self, shuffle=True, partition=True):
|
55 |
+
self.epoch = -1
|
56 |
+
self.update()
|
57 |
+
self.shuffle = shuffle
|
58 |
+
self.partition = partition
|
59 |
+
|
60 |
+
def update(self):
|
61 |
+
assert dist.is_available()
|
62 |
+
if dist.is_initialized():
|
63 |
+
self.rank = dist.get_rank()
|
64 |
+
self.world_size = dist.get_world_size()
|
65 |
+
else:
|
66 |
+
self.rank = 0
|
67 |
+
self.world_size = 1
|
68 |
+
worker_info = torch.utils.data.get_worker_info()
|
69 |
+
if worker_info is None:
|
70 |
+
self.worker_id = 0
|
71 |
+
self.num_workers = 1
|
72 |
+
else:
|
73 |
+
self.worker_id = worker_info.id
|
74 |
+
self.num_workers = worker_info.num_workers
|
75 |
+
return dict(rank=self.rank,
|
76 |
+
world_size=self.world_size,
|
77 |
+
worker_id=self.worker_id,
|
78 |
+
num_workers=self.num_workers)
|
79 |
+
|
80 |
+
def set_epoch(self, epoch):
|
81 |
+
self.epoch = epoch
|
82 |
+
|
83 |
+
def sample(self, data):
|
84 |
+
""" Sample data according to rank/world_size/num_workers
|
85 |
+
|
86 |
+
Args:
|
87 |
+
data(List): input data list
|
88 |
+
|
89 |
+
Returns:
|
90 |
+
List: data list after sample
|
91 |
+
"""
|
92 |
+
data = list(range(len(data)))
|
93 |
+
# force datalist even
|
94 |
+
if self.partition:
|
95 |
+
if self.shuffle:
|
96 |
+
random.Random(self.epoch).shuffle(data)
|
97 |
+
if len(data) < self.world_size:
|
98 |
+
data = data * math.ceil(self.world_size / len(data))
|
99 |
+
data = data[:self.world_size]
|
100 |
+
data = data[self.rank::self.world_size]
|
101 |
+
if len(data) < self.num_workers:
|
102 |
+
data = data * math.ceil(self.num_workers / len(data))
|
103 |
+
data = data[:self.num_workers]
|
104 |
+
data = data[self.worker_id::self.num_workers]
|
105 |
+
return data
|
106 |
+
|
107 |
+
|
108 |
+
class DataList(IterableDataset):
|
109 |
+
|
110 |
+
def __init__(self, lists, shuffle=True, partition=True):
|
111 |
+
self.lists = lists
|
112 |
+
self.sampler = DistributedSampler(shuffle, partition)
|
113 |
+
|
114 |
+
def set_epoch(self, epoch):
|
115 |
+
self.sampler.set_epoch(epoch)
|
116 |
+
|
117 |
+
def __iter__(self):
|
118 |
+
sampler_info = self.sampler.update()
|
119 |
+
indexes = self.sampler.sample(self.lists)
|
120 |
+
for index in indexes:
|
121 |
+
data = dict(src=self.lists[index])
|
122 |
+
data.update(sampler_info)
|
123 |
+
yield data
|
124 |
+
|
125 |
+
|
126 |
+
def Dataset(data_list_file,
|
127 |
+
data_pipeline,
|
128 |
+
mode='train',
|
129 |
+
gan=False,
|
130 |
+
shuffle=True,
|
131 |
+
partition=True,
|
132 |
+
tts_file='',
|
133 |
+
prompt_utt2data=''):
|
134 |
+
""" Construct dataset from arguments
|
135 |
+
|
136 |
+
We have two shuffle stage in the Dataset. The first is global
|
137 |
+
shuffle at shards tar/raw file level. The second is global shuffle
|
138 |
+
at training samples level.
|
139 |
+
|
140 |
+
Args:
|
141 |
+
data_type(str): raw/shard
|
142 |
+
tokenizer (BaseTokenizer): tokenizer to tokenize
|
143 |
+
partition(bool): whether to do data partition in terms of rank
|
144 |
+
"""
|
145 |
+
assert mode in ['train', 'inference']
|
146 |
+
lists = read_lists(data_list_file)
|
147 |
+
if mode == 'inference':
|
148 |
+
with open(tts_file) as f:
|
149 |
+
tts_data = json.load(f)
|
150 |
+
utt2lists = read_json_lists(prompt_utt2data)
|
151 |
+
# filter unnecessary file in inference mode
|
152 |
+
lists = list({utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists})
|
153 |
+
dataset = DataList(lists,
|
154 |
+
shuffle=shuffle,
|
155 |
+
partition=partition)
|
156 |
+
if mode == 'inference':
|
157 |
+
# map partial arg to parquet_opener func in inference mode
|
158 |
+
data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data)
|
159 |
+
if gan is True:
|
160 |
+
# map partial arg to padding func in gan mode
|
161 |
+
data_pipeline[-1] = partial(data_pipeline[-1], gan=gan)
|
162 |
+
for func in data_pipeline:
|
163 |
+
dataset = Processor(dataset, func, mode=mode)
|
164 |
+
return dataset
|
cosyvoice/dataset/processor.py
ADDED
@@ -0,0 +1,431 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import logging
|
15 |
+
import random
|
16 |
+
|
17 |
+
import pyarrow.parquet as pq
|
18 |
+
from io import BytesIO
|
19 |
+
import torch
|
20 |
+
import torchaudio
|
21 |
+
from torch.nn.utils.rnn import pad_sequence
|
22 |
+
import torch.nn.functional as F
|
23 |
+
|
24 |
+
torchaudio.set_audio_backend('soundfile')
|
25 |
+
|
26 |
+
AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
|
27 |
+
|
28 |
+
|
29 |
+
def parquet_opener(data, mode='train', tts_data={}):
|
30 |
+
""" Give url or local file, return file descriptor
|
31 |
+
Inplace operation.
|
32 |
+
|
33 |
+
Args:
|
34 |
+
data(Iterable[str]): url or local file list
|
35 |
+
|
36 |
+
Returns:
|
37 |
+
Iterable[{src, stream}]
|
38 |
+
"""
|
39 |
+
for sample in data:
|
40 |
+
assert 'src' in sample
|
41 |
+
url = sample['src']
|
42 |
+
try:
|
43 |
+
for df in pq.ParquetFile(url).iter_batches(batch_size=64):
|
44 |
+
df = df.to_pandas()
|
45 |
+
for i in range(len(df)):
|
46 |
+
if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
|
47 |
+
continue
|
48 |
+
sample.update(dict(df.loc[i]))
|
49 |
+
if mode == 'train':
|
50 |
+
# NOTE do not return sample directly, must initialize a new dict
|
51 |
+
yield {**sample}
|
52 |
+
else:
|
53 |
+
for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
|
54 |
+
yield {**sample, 'tts_index': index, 'tts_text': text}
|
55 |
+
except Exception as ex:
|
56 |
+
logging.warning('Failed to open {}, ex info {}'.format(url, ex))
|
57 |
+
|
58 |
+
|
59 |
+
def filter(data,
|
60 |
+
max_length=10240,
|
61 |
+
min_length=10,
|
62 |
+
token_max_length=200,
|
63 |
+
token_min_length=1,
|
64 |
+
min_output_input_ratio=0.0005,
|
65 |
+
max_output_input_ratio=1,
|
66 |
+
mode='train'):
|
67 |
+
""" Filter sample according to feature and label length
|
68 |
+
Inplace operation.
|
69 |
+
|
70 |
+
Args::
|
71 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
72 |
+
max_length: drop utterance which is greater than max_length(10ms)
|
73 |
+
min_length: drop utterance which is less than min_length(10ms)
|
74 |
+
token_max_length: drop utterance which is greater than
|
75 |
+
token_max_length, especially when use char unit for
|
76 |
+
english modeling
|
77 |
+
token_min_length: drop utterance which is
|
78 |
+
less than token_max_length
|
79 |
+
min_output_input_ratio: minimal ration of
|
80 |
+
token_length / feats_length(10ms)
|
81 |
+
max_output_input_ratio: maximum ration of
|
82 |
+
token_length / feats_length(10ms)
|
83 |
+
|
84 |
+
Returns:
|
85 |
+
Iterable[{key, wav, label, sample_rate}]
|
86 |
+
"""
|
87 |
+
for sample in data:
|
88 |
+
sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
|
89 |
+
sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
|
90 |
+
del sample['audio_data']
|
91 |
+
# sample['wav'] is torch.Tensor, we have 100 frames every second
|
92 |
+
num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
|
93 |
+
if num_frames < min_length:
|
94 |
+
continue
|
95 |
+
if num_frames > max_length:
|
96 |
+
continue
|
97 |
+
if len(sample['text_token']) < token_min_length:
|
98 |
+
continue
|
99 |
+
if len(sample['text_token']) > token_max_length:
|
100 |
+
continue
|
101 |
+
if len(sample['speech_token']) == 0:
|
102 |
+
continue
|
103 |
+
if num_frames != 0:
|
104 |
+
if len(sample['text_token']) / num_frames < min_output_input_ratio:
|
105 |
+
continue
|
106 |
+
if len(sample['text_token']) / num_frames > max_output_input_ratio:
|
107 |
+
continue
|
108 |
+
yield sample
|
109 |
+
|
110 |
+
|
111 |
+
def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
|
112 |
+
""" Resample data.
|
113 |
+
Inplace operation.
|
114 |
+
|
115 |
+
Args:
|
116 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
117 |
+
resample_rate: target resample rate
|
118 |
+
|
119 |
+
Returns:
|
120 |
+
Iterable[{key, wav, label, sample_rate}]
|
121 |
+
"""
|
122 |
+
for sample in data:
|
123 |
+
assert 'sample_rate' in sample
|
124 |
+
assert 'speech' in sample
|
125 |
+
sample_rate = sample['sample_rate']
|
126 |
+
waveform = sample['speech']
|
127 |
+
if sample_rate != resample_rate:
|
128 |
+
if sample_rate < min_sample_rate:
|
129 |
+
continue
|
130 |
+
sample['sample_rate'] = resample_rate
|
131 |
+
sample['speech'] = torchaudio.transforms.Resample(
|
132 |
+
orig_freq=sample_rate, new_freq=resample_rate)(waveform)
|
133 |
+
max_val = sample['speech'].abs().max()
|
134 |
+
if max_val > 1:
|
135 |
+
sample['speech'] /= max_val
|
136 |
+
yield sample
|
137 |
+
|
138 |
+
|
139 |
+
def truncate(data, truncate_length=24576, mode='train'):
|
140 |
+
""" Truncate data.
|
141 |
+
|
142 |
+
Args:
|
143 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
144 |
+
truncate_length: truncate length
|
145 |
+
|
146 |
+
Returns:
|
147 |
+
Iterable[{key, wav, label, sample_rate}]
|
148 |
+
"""
|
149 |
+
for sample in data:
|
150 |
+
waveform = sample['speech']
|
151 |
+
if waveform.shape[1] > truncate_length:
|
152 |
+
start = random.randint(0, waveform.shape[1] - truncate_length)
|
153 |
+
waveform = waveform[:, start: start + truncate_length]
|
154 |
+
else:
|
155 |
+
waveform = torch.concat([waveform, torch.zeros(1, truncate_length - waveform.shape[1])], dim=1)
|
156 |
+
sample['speech'] = waveform
|
157 |
+
yield sample
|
158 |
+
|
159 |
+
|
160 |
+
def compute_fbank(data,
|
161 |
+
feat_extractor,
|
162 |
+
mode='train'):
|
163 |
+
""" Extract fbank
|
164 |
+
|
165 |
+
Args:
|
166 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
167 |
+
|
168 |
+
Returns:
|
169 |
+
Iterable[{key, feat, label}]
|
170 |
+
"""
|
171 |
+
for sample in data:
|
172 |
+
assert 'sample_rate' in sample
|
173 |
+
assert 'speech' in sample
|
174 |
+
assert 'utt' in sample
|
175 |
+
assert 'text_token' in sample
|
176 |
+
waveform = sample['speech']
|
177 |
+
mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
|
178 |
+
sample['speech_feat'] = mat
|
179 |
+
yield sample
|
180 |
+
|
181 |
+
|
182 |
+
def compute_f0(data, pitch_extractor, mode='train'):
|
183 |
+
""" Extract f0
|
184 |
+
|
185 |
+
Args:
|
186 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
187 |
+
|
188 |
+
Returns:
|
189 |
+
Iterable[{key, feat, label}]
|
190 |
+
"""
|
191 |
+
for sample in data:
|
192 |
+
assert 'sample_rate' in sample
|
193 |
+
assert 'speech' in sample
|
194 |
+
assert 'utt' in sample
|
195 |
+
assert 'text_token' in sample
|
196 |
+
waveform = sample['speech']
|
197 |
+
mat = pitch_extractor(waveform).transpose(1, 2)
|
198 |
+
mat = F.interpolate(mat, size=sample['speech_feat'].shape[0], mode='linear')
|
199 |
+
sample['pitch_feat'] = mat[0, 0]
|
200 |
+
yield sample
|
201 |
+
|
202 |
+
|
203 |
+
def parse_embedding(data, normalize, mode='train'):
|
204 |
+
""" Parse utt_embedding/spk_embedding
|
205 |
+
|
206 |
+
Args:
|
207 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
208 |
+
|
209 |
+
Returns:
|
210 |
+
Iterable[{key, feat, label}]
|
211 |
+
"""
|
212 |
+
for sample in data:
|
213 |
+
sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
|
214 |
+
sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
|
215 |
+
if normalize:
|
216 |
+
sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
|
217 |
+
sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
|
218 |
+
yield sample
|
219 |
+
|
220 |
+
|
221 |
+
def tokenize(data, get_tokenizer, allowed_special, mode='train'):
|
222 |
+
""" Decode text to chars or BPE
|
223 |
+
Inplace operation
|
224 |
+
|
225 |
+
Args:
|
226 |
+
data: Iterable[{key, wav, txt, sample_rate}]
|
227 |
+
|
228 |
+
Returns:
|
229 |
+
Iterable[{key, wav, txt, tokens, label, sample_rate}]
|
230 |
+
"""
|
231 |
+
tokenizer = get_tokenizer()
|
232 |
+
for sample in data:
|
233 |
+
assert 'text' in sample
|
234 |
+
sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
|
235 |
+
if mode == 'inference':
|
236 |
+
sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special)
|
237 |
+
yield sample
|
238 |
+
|
239 |
+
|
240 |
+
def shuffle(data, shuffle_size=10000, mode='train'):
|
241 |
+
""" Local shuffle the data
|
242 |
+
|
243 |
+
Args:
|
244 |
+
data: Iterable[{key, feat, label}]
|
245 |
+
shuffle_size: buffer size for shuffle
|
246 |
+
|
247 |
+
Returns:
|
248 |
+
Iterable[{key, feat, label}]
|
249 |
+
"""
|
250 |
+
buf = []
|
251 |
+
for sample in data:
|
252 |
+
buf.append(sample)
|
253 |
+
if len(buf) >= shuffle_size:
|
254 |
+
random.shuffle(buf)
|
255 |
+
for x in buf:
|
256 |
+
yield x
|
257 |
+
buf = []
|
258 |
+
# The sample left over
|
259 |
+
random.shuffle(buf)
|
260 |
+
for x in buf:
|
261 |
+
yield x
|
262 |
+
|
263 |
+
|
264 |
+
def sort(data, sort_size=500, mode='train'):
|
265 |
+
""" Sort the data by feature length.
|
266 |
+
Sort is used after shuffle and before batch, so we can group
|
267 |
+
utts with similar lengths into a batch, and `sort_size` should
|
268 |
+
be less than `shuffle_size`
|
269 |
+
|
270 |
+
Args:
|
271 |
+
data: Iterable[{key, feat, label}]
|
272 |
+
sort_size: buffer size for sort
|
273 |
+
|
274 |
+
Returns:
|
275 |
+
Iterable[{key, feat, label}]
|
276 |
+
"""
|
277 |
+
|
278 |
+
buf = []
|
279 |
+
for sample in data:
|
280 |
+
buf.append(sample)
|
281 |
+
if len(buf) >= sort_size:
|
282 |
+
buf.sort(key=lambda x: x['speech_feat'].size(0))
|
283 |
+
for x in buf:
|
284 |
+
yield x
|
285 |
+
buf = []
|
286 |
+
# The sample left over
|
287 |
+
buf.sort(key=lambda x: x['speech_feat'].size(0))
|
288 |
+
for x in buf:
|
289 |
+
yield x
|
290 |
+
|
291 |
+
|
292 |
+
def static_batch(data, batch_size=16):
|
293 |
+
""" Static batch the data by `batch_size`
|
294 |
+
|
295 |
+
Args:
|
296 |
+
data: Iterable[{key, feat, label}]
|
297 |
+
batch_size: batch size
|
298 |
+
|
299 |
+
Returns:
|
300 |
+
Iterable[List[{key, feat, label}]]
|
301 |
+
"""
|
302 |
+
buf = []
|
303 |
+
for sample in data:
|
304 |
+
buf.append(sample)
|
305 |
+
if len(buf) >= batch_size:
|
306 |
+
yield buf
|
307 |
+
buf = []
|
308 |
+
if len(buf) > 0:
|
309 |
+
yield buf
|
310 |
+
|
311 |
+
|
312 |
+
def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
|
313 |
+
""" Dynamic batch the data until the total frames in batch
|
314 |
+
reach `max_frames_in_batch`
|
315 |
+
|
316 |
+
Args:
|
317 |
+
data: Iterable[{key, feat, label}]
|
318 |
+
max_frames_in_batch: max_frames in one batch
|
319 |
+
|
320 |
+
Returns:
|
321 |
+
Iterable[List[{key, feat, label}]]
|
322 |
+
"""
|
323 |
+
buf = []
|
324 |
+
longest_frames = 0
|
325 |
+
for sample in data:
|
326 |
+
assert 'speech_feat' in sample
|
327 |
+
assert isinstance(sample['speech_feat'], torch.Tensor)
|
328 |
+
new_sample_frames = sample['speech_feat'].size(0)
|
329 |
+
longest_frames = max(longest_frames, new_sample_frames)
|
330 |
+
frames_after_padding = longest_frames * (len(buf) + 1)
|
331 |
+
if frames_after_padding > max_frames_in_batch:
|
332 |
+
yield buf
|
333 |
+
buf = [sample]
|
334 |
+
longest_frames = new_sample_frames
|
335 |
+
else:
|
336 |
+
buf.append(sample)
|
337 |
+
if len(buf) > 0:
|
338 |
+
yield buf
|
339 |
+
|
340 |
+
|
341 |
+
def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
|
342 |
+
""" Wrapper for static/dynamic batch
|
343 |
+
"""
|
344 |
+
if mode == 'inference':
|
345 |
+
return static_batch(data, 1)
|
346 |
+
else:
|
347 |
+
if batch_type == 'static':
|
348 |
+
return static_batch(data, batch_size)
|
349 |
+
elif batch_type == 'dynamic':
|
350 |
+
return dynamic_batch(data, max_frames_in_batch)
|
351 |
+
else:
|
352 |
+
logging.fatal('Unsupported batch type {}'.format(batch_type))
|
353 |
+
|
354 |
+
|
355 |
+
def padding(data, use_spk_embedding, mode='train', gan=False):
|
356 |
+
""" Padding the data into training data
|
357 |
+
|
358 |
+
Args:
|
359 |
+
data: Iterable[List[{key, feat, label}]]
|
360 |
+
|
361 |
+
Returns:
|
362 |
+
Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
|
363 |
+
"""
|
364 |
+
for sample in data:
|
365 |
+
assert isinstance(sample, list)
|
366 |
+
speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
|
367 |
+
dtype=torch.int32)
|
368 |
+
order = torch.argsort(speech_feat_len, descending=True)
|
369 |
+
|
370 |
+
utts = [sample[i]['utt'] for i in order]
|
371 |
+
speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
|
372 |
+
speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
|
373 |
+
speech = pad_sequence(speech, batch_first=True, padding_value=0)
|
374 |
+
speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
|
375 |
+
speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
|
376 |
+
speech_token = pad_sequence(speech_token,
|
377 |
+
batch_first=True,
|
378 |
+
padding_value=0)
|
379 |
+
speech_feat = [sample[i]['speech_feat'] for i in order]
|
380 |
+
speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
|
381 |
+
speech_feat = pad_sequence(speech_feat,
|
382 |
+
batch_first=True,
|
383 |
+
padding_value=0)
|
384 |
+
text = [sample[i]['text'] for i in order]
|
385 |
+
text_token = [torch.tensor(sample[i]['text_token']) for i in order]
|
386 |
+
text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
|
387 |
+
text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
|
388 |
+
utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
|
389 |
+
spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
|
390 |
+
batch = {
|
391 |
+
"utts": utts,
|
392 |
+
"speech": speech,
|
393 |
+
"speech_len": speech_len,
|
394 |
+
"speech_token": speech_token,
|
395 |
+
"speech_token_len": speech_token_len,
|
396 |
+
"speech_feat": speech_feat,
|
397 |
+
"speech_feat_len": speech_feat_len,
|
398 |
+
"text": text,
|
399 |
+
"text_token": text_token,
|
400 |
+
"text_token_len": text_token_len,
|
401 |
+
"utt_embedding": utt_embedding,
|
402 |
+
"spk_embedding": spk_embedding,
|
403 |
+
}
|
404 |
+
if gan is True:
|
405 |
+
# in gan train, we need pitch_feat
|
406 |
+
pitch_feat = [sample[i]['pitch_feat'] for i in order]
|
407 |
+
pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
|
408 |
+
pitch_feat = pad_sequence(pitch_feat,
|
409 |
+
batch_first=True,
|
410 |
+
padding_value=0)
|
411 |
+
batch["pitch_feat"] = pitch_feat
|
412 |
+
batch["pitch_feat_len"] = pitch_feat_len
|
413 |
+
else:
|
414 |
+
# only gan train needs speech, delete it to save memory
|
415 |
+
del batch["speech"]
|
416 |
+
del batch["speech_len"]
|
417 |
+
if mode == 'inference':
|
418 |
+
tts_text = [sample[i]['tts_text'] for i in order]
|
419 |
+
tts_index = [sample[i]['tts_index'] for i in order]
|
420 |
+
tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
|
421 |
+
tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
|
422 |
+
tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
|
423 |
+
batch.update({'tts_text': tts_text,
|
424 |
+
'tts_index': tts_index,
|
425 |
+
'tts_text_token': tts_text_token,
|
426 |
+
'tts_text_token_len': tts_text_token_len})
|
427 |
+
if use_spk_embedding is True:
|
428 |
+
batch["embedding"] = batch["spk_embedding"]
|
429 |
+
else:
|
430 |
+
batch["embedding"] = batch["utt_embedding"]
|
431 |
+
yield batch
|
cosyvoice/flow/decoder.py
ADDED
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
import torch.nn.functional as F
|
17 |
+
from einops import pack, rearrange, repeat
|
18 |
+
from cosyvoice.utils.common import mask_to_bias
|
19 |
+
from cosyvoice.utils.mask import add_optional_chunk_mask
|
20 |
+
from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
|
21 |
+
from matcha.models.components.transformer import BasicTransformerBlock
|
22 |
+
|
23 |
+
|
24 |
+
class Transpose(torch.nn.Module):
|
25 |
+
def __init__(self, dim0: int, dim1: int):
|
26 |
+
super().__init__()
|
27 |
+
self.dim0 = dim0
|
28 |
+
self.dim1 = dim1
|
29 |
+
|
30 |
+
def forward(self, x: torch.Tensor):
|
31 |
+
x = torch.transpose(x, self.dim0, self.dim1)
|
32 |
+
return x
|
33 |
+
|
34 |
+
|
35 |
+
class CausalBlock1D(Block1D):
|
36 |
+
def __init__(self, dim: int, dim_out: int):
|
37 |
+
super(CausalBlock1D, self).__init__(dim, dim_out)
|
38 |
+
self.block = torch.nn.Sequential(
|
39 |
+
CausalConv1d(dim, dim_out, 3),
|
40 |
+
Transpose(1, 2),
|
41 |
+
nn.LayerNorm(dim_out),
|
42 |
+
Transpose(1, 2),
|
43 |
+
nn.Mish(),
|
44 |
+
)
|
45 |
+
|
46 |
+
def forward(self, x: torch.Tensor, mask: torch.Tensor):
|
47 |
+
output = self.block(x * mask)
|
48 |
+
return output * mask
|
49 |
+
|
50 |
+
|
51 |
+
class CausalResnetBlock1D(ResnetBlock1D):
|
52 |
+
def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int=8):
|
53 |
+
super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
|
54 |
+
self.block1 = CausalBlock1D(dim, dim_out)
|
55 |
+
self.block2 = CausalBlock1D(dim_out, dim_out)
|
56 |
+
|
57 |
+
|
58 |
+
class CausalConv1d(torch.nn.Conv1d):
|
59 |
+
def __init__(
|
60 |
+
self,
|
61 |
+
in_channels: int,
|
62 |
+
out_channels: int,
|
63 |
+
kernel_size: int,
|
64 |
+
stride: int = 1,
|
65 |
+
dilation: int = 1,
|
66 |
+
groups: int = 1,
|
67 |
+
bias: bool = True,
|
68 |
+
padding_mode: str = 'zeros',
|
69 |
+
device=None,
|
70 |
+
dtype=None
|
71 |
+
) -> None:
|
72 |
+
super(CausalConv1d, self).__init__(in_channels, out_channels,
|
73 |
+
kernel_size, stride,
|
74 |
+
padding=0, dilation=dilation,
|
75 |
+
groups=groups, bias=bias,
|
76 |
+
padding_mode=padding_mode,
|
77 |
+
device=device, dtype=dtype
|
78 |
+
)
|
79 |
+
assert stride == 1
|
80 |
+
self.causal_padding = (kernel_size - 1, 0)
|
81 |
+
|
82 |
+
def forward(self, x: torch.Tensor):
|
83 |
+
x = F.pad(x, self.causal_padding)
|
84 |
+
x = super(CausalConv1d, self).forward(x)
|
85 |
+
return x
|
86 |
+
|
87 |
+
|
88 |
+
class ConditionalDecoder(nn.Module):
|
89 |
+
def __init__(
|
90 |
+
self,
|
91 |
+
in_channels,
|
92 |
+
out_channels,
|
93 |
+
causal=False,
|
94 |
+
channels=(256, 256),
|
95 |
+
dropout=0.05,
|
96 |
+
attention_head_dim=64,
|
97 |
+
n_blocks=1,
|
98 |
+
num_mid_blocks=2,
|
99 |
+
num_heads=4,
|
100 |
+
act_fn="snake",
|
101 |
+
):
|
102 |
+
"""
|
103 |
+
This decoder requires an input with the same shape of the target. So, if your text content
|
104 |
+
is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
|
105 |
+
"""
|
106 |
+
super().__init__()
|
107 |
+
channels = tuple(channels)
|
108 |
+
self.in_channels = in_channels
|
109 |
+
self.out_channels = out_channels
|
110 |
+
self.causal = causal
|
111 |
+
self.time_embeddings = SinusoidalPosEmb(in_channels)
|
112 |
+
time_embed_dim = channels[0] * 4
|
113 |
+
self.time_mlp = TimestepEmbedding(
|
114 |
+
in_channels=in_channels,
|
115 |
+
time_embed_dim=time_embed_dim,
|
116 |
+
act_fn="silu",
|
117 |
+
)
|
118 |
+
self.down_blocks = nn.ModuleList([])
|
119 |
+
self.mid_blocks = nn.ModuleList([])
|
120 |
+
self.up_blocks = nn.ModuleList([])
|
121 |
+
|
122 |
+
output_channel = in_channels
|
123 |
+
for i in range(len(channels)): # pylint: disable=consider-using-enumerate
|
124 |
+
input_channel = output_channel
|
125 |
+
output_channel = channels[i]
|
126 |
+
is_last = i == len(channels) - 1
|
127 |
+
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
128 |
+
transformer_blocks = nn.ModuleList(
|
129 |
+
[
|
130 |
+
BasicTransformerBlock(
|
131 |
+
dim=output_channel,
|
132 |
+
num_attention_heads=num_heads,
|
133 |
+
attention_head_dim=attention_head_dim,
|
134 |
+
dropout=dropout,
|
135 |
+
activation_fn=act_fn,
|
136 |
+
)
|
137 |
+
for _ in range(n_blocks)
|
138 |
+
]
|
139 |
+
)
|
140 |
+
downsample = (
|
141 |
+
Downsample1D(output_channel) if not is_last else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
142 |
+
)
|
143 |
+
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
|
144 |
+
|
145 |
+
for _ in range(num_mid_blocks):
|
146 |
+
input_channel = channels[-1]
|
147 |
+
out_channels = channels[-1]
|
148 |
+
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
149 |
+
|
150 |
+
transformer_blocks = nn.ModuleList(
|
151 |
+
[
|
152 |
+
BasicTransformerBlock(
|
153 |
+
dim=output_channel,
|
154 |
+
num_attention_heads=num_heads,
|
155 |
+
attention_head_dim=attention_head_dim,
|
156 |
+
dropout=dropout,
|
157 |
+
activation_fn=act_fn,
|
158 |
+
)
|
159 |
+
for _ in range(n_blocks)
|
160 |
+
]
|
161 |
+
)
|
162 |
+
|
163 |
+
self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
|
164 |
+
|
165 |
+
channels = channels[::-1] + (channels[0],)
|
166 |
+
for i in range(len(channels) - 1):
|
167 |
+
input_channel = channels[i] * 2
|
168 |
+
output_channel = channels[i + 1]
|
169 |
+
is_last = i == len(channels) - 2
|
170 |
+
resnet = CausalResnetBlock1D(
|
171 |
+
dim=input_channel,
|
172 |
+
dim_out=output_channel,
|
173 |
+
time_emb_dim=time_embed_dim,
|
174 |
+
) if self.causal else ResnetBlock1D(
|
175 |
+
dim=input_channel,
|
176 |
+
dim_out=output_channel,
|
177 |
+
time_emb_dim=time_embed_dim,
|
178 |
+
)
|
179 |
+
transformer_blocks = nn.ModuleList(
|
180 |
+
[
|
181 |
+
BasicTransformerBlock(
|
182 |
+
dim=output_channel,
|
183 |
+
num_attention_heads=num_heads,
|
184 |
+
attention_head_dim=attention_head_dim,
|
185 |
+
dropout=dropout,
|
186 |
+
activation_fn=act_fn,
|
187 |
+
)
|
188 |
+
for _ in range(n_blocks)
|
189 |
+
]
|
190 |
+
)
|
191 |
+
upsample = (
|
192 |
+
Upsample1D(output_channel, use_conv_transpose=True)
|
193 |
+
if not is_last
|
194 |
+
else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
195 |
+
)
|
196 |
+
self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
|
197 |
+
self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
|
198 |
+
self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
|
199 |
+
self.initialize_weights()
|
200 |
+
|
201 |
+
def initialize_weights(self):
|
202 |
+
for m in self.modules():
|
203 |
+
if isinstance(m, nn.Conv1d):
|
204 |
+
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
|
205 |
+
if m.bias is not None:
|
206 |
+
nn.init.constant_(m.bias, 0)
|
207 |
+
elif isinstance(m, nn.GroupNorm):
|
208 |
+
nn.init.constant_(m.weight, 1)
|
209 |
+
nn.init.constant_(m.bias, 0)
|
210 |
+
elif isinstance(m, nn.Linear):
|
211 |
+
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
|
212 |
+
if m.bias is not None:
|
213 |
+
nn.init.constant_(m.bias, 0)
|
214 |
+
|
215 |
+
def forward(self, x, mask, mu, t, spks=None, cond=None):
|
216 |
+
"""Forward pass of the UNet1DConditional model.
|
217 |
+
|
218 |
+
Args:
|
219 |
+
x (torch.Tensor): shape (batch_size, in_channels, time)
|
220 |
+
mask (_type_): shape (batch_size, 1, time)
|
221 |
+
t (_type_): shape (batch_size)
|
222 |
+
spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
|
223 |
+
cond (_type_, optional): placeholder for future use. Defaults to None.
|
224 |
+
|
225 |
+
Raises:
|
226 |
+
ValueError: _description_
|
227 |
+
ValueError: _description_
|
228 |
+
|
229 |
+
Returns:
|
230 |
+
_type_: _description_
|
231 |
+
"""
|
232 |
+
|
233 |
+
t = self.time_embeddings(t).to(t.dtype)
|
234 |
+
t = self.time_mlp(t)
|
235 |
+
|
236 |
+
x = pack([x, mu], "b * t")[0]
|
237 |
+
|
238 |
+
if spks is not None:
|
239 |
+
spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
|
240 |
+
x = pack([x, spks], "b * t")[0]
|
241 |
+
if cond is not None:
|
242 |
+
x = pack([x, cond], "b * t")[0]
|
243 |
+
|
244 |
+
hiddens = []
|
245 |
+
masks = [mask]
|
246 |
+
for resnet, transformer_blocks, downsample in self.down_blocks:
|
247 |
+
mask_down = masks[-1]
|
248 |
+
x = resnet(x, mask_down, t)
|
249 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
250 |
+
# attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
|
251 |
+
attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
|
252 |
+
attn_mask = mask_to_bias(attn_mask==1, x.dtype)
|
253 |
+
for transformer_block in transformer_blocks:
|
254 |
+
x = transformer_block(
|
255 |
+
hidden_states=x,
|
256 |
+
attention_mask=attn_mask,
|
257 |
+
timestep=t,
|
258 |
+
)
|
259 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
260 |
+
hiddens.append(x) # Save hidden states for skip connections
|
261 |
+
x = downsample(x * mask_down)
|
262 |
+
masks.append(mask_down[:, :, ::2])
|
263 |
+
masks = masks[:-1]
|
264 |
+
mask_mid = masks[-1]
|
265 |
+
|
266 |
+
for resnet, transformer_blocks in self.mid_blocks:
|
267 |
+
x = resnet(x, mask_mid, t)
|
268 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
269 |
+
# attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
|
270 |
+
attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
|
271 |
+
attn_mask = mask_to_bias(attn_mask==1, x.dtype)
|
272 |
+
for transformer_block in transformer_blocks:
|
273 |
+
x = transformer_block(
|
274 |
+
hidden_states=x,
|
275 |
+
attention_mask=attn_mask,
|
276 |
+
timestep=t,
|
277 |
+
)
|
278 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
279 |
+
|
280 |
+
for resnet, transformer_blocks, upsample in self.up_blocks:
|
281 |
+
mask_up = masks.pop()
|
282 |
+
skip = hiddens.pop()
|
283 |
+
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
|
284 |
+
x = resnet(x, mask_up, t)
|
285 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
286 |
+
# attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
|
287 |
+
attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
|
288 |
+
attn_mask = mask_to_bias(attn_mask==1, x.dtype)
|
289 |
+
for transformer_block in transformer_blocks:
|
290 |
+
x = transformer_block(
|
291 |
+
hidden_states=x,
|
292 |
+
attention_mask=attn_mask,
|
293 |
+
timestep=t,
|
294 |
+
)
|
295 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
296 |
+
x = upsample(x * mask_up)
|
297 |
+
x = self.final_block(x, mask_up)
|
298 |
+
output = self.final_proj(x * mask_up)
|
299 |
+
return output * mask
|
cosyvoice/flow/flow.py
ADDED
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import logging
|
15 |
+
import random
|
16 |
+
from typing import Dict, Optional
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
from torch.nn import functional as F
|
20 |
+
from omegaconf import DictConfig
|
21 |
+
from cosyvoice.utils.mask import make_pad_mask
|
22 |
+
|
23 |
+
|
24 |
+
class MaskedDiffWithXvec(torch.nn.Module):
|
25 |
+
def __init__(self,
|
26 |
+
input_size: int = 512,
|
27 |
+
output_size: int = 80,
|
28 |
+
spk_embed_dim: int = 192,
|
29 |
+
output_type: str = "mel",
|
30 |
+
vocab_size: int = 4096,
|
31 |
+
input_frame_rate: int = 50,
|
32 |
+
only_mask_loss: bool = True,
|
33 |
+
encoder: torch.nn.Module = None,
|
34 |
+
length_regulator: torch.nn.Module = None,
|
35 |
+
decoder: torch.nn.Module = None,
|
36 |
+
decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
|
37 |
+
'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
|
38 |
+
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
|
39 |
+
'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
|
40 |
+
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
|
41 |
+
mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
|
42 |
+
'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
|
43 |
+
super().__init__()
|
44 |
+
self.input_size = input_size
|
45 |
+
self.output_size = output_size
|
46 |
+
self.decoder_conf = decoder_conf
|
47 |
+
self.mel_feat_conf = mel_feat_conf
|
48 |
+
self.vocab_size = vocab_size
|
49 |
+
self.output_type = output_type
|
50 |
+
self.input_frame_rate = input_frame_rate
|
51 |
+
logging.info(f"input frame rate={self.input_frame_rate}")
|
52 |
+
self.input_embedding = nn.Embedding(vocab_size, input_size)
|
53 |
+
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
|
54 |
+
self.encoder = encoder
|
55 |
+
self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
|
56 |
+
self.decoder = decoder
|
57 |
+
self.length_regulator = length_regulator
|
58 |
+
self.only_mask_loss = only_mask_loss
|
59 |
+
|
60 |
+
def forward(
|
61 |
+
self,
|
62 |
+
batch: dict,
|
63 |
+
device: torch.device,
|
64 |
+
) -> Dict[str, Optional[torch.Tensor]]:
|
65 |
+
token = batch['speech_token'].to(device)
|
66 |
+
token_len = batch['speech_token_len'].to(device)
|
67 |
+
feat = batch['speech_feat'].to(device)
|
68 |
+
feat_len = batch['speech_feat_len'].to(device)
|
69 |
+
embedding = batch['embedding'].to(device)
|
70 |
+
|
71 |
+
# xvec projection
|
72 |
+
embedding = F.normalize(embedding, dim=1)
|
73 |
+
embedding = self.spk_embed_affine_layer(embedding)
|
74 |
+
|
75 |
+
# concat text and prompt_text
|
76 |
+
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
|
77 |
+
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
78 |
+
|
79 |
+
# text encode
|
80 |
+
h, h_lengths = self.encoder(token, token_len)
|
81 |
+
h = self.encoder_proj(h)
|
82 |
+
h, h_lengths = self.length_regulator(h, feat_len)
|
83 |
+
|
84 |
+
# get conditions
|
85 |
+
conds = torch.zeros(feat.shape, device=token.device)
|
86 |
+
for i, j in enumerate(feat_len):
|
87 |
+
if random.random() < 0.5:
|
88 |
+
continue
|
89 |
+
index = random.randint(0, int(0.3 * j))
|
90 |
+
conds[i, :index] = feat[i, :index]
|
91 |
+
conds = conds.transpose(1, 2)
|
92 |
+
|
93 |
+
mask = (~make_pad_mask(feat_len)).to(h)
|
94 |
+
feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
|
95 |
+
loss, _ = self.decoder.compute_loss(
|
96 |
+
feat.transpose(1, 2).contiguous(),
|
97 |
+
mask.unsqueeze(1),
|
98 |
+
h.transpose(1, 2).contiguous(),
|
99 |
+
embedding,
|
100 |
+
cond=conds
|
101 |
+
)
|
102 |
+
return {'loss': loss}
|
103 |
+
|
104 |
+
@torch.inference_mode()
|
105 |
+
def inference(self,
|
106 |
+
token,
|
107 |
+
token_len,
|
108 |
+
prompt_token,
|
109 |
+
prompt_token_len,
|
110 |
+
prompt_feat,
|
111 |
+
prompt_feat_len,
|
112 |
+
embedding,
|
113 |
+
flow_cache):
|
114 |
+
assert token.shape[0] == 1
|
115 |
+
# xvec projection
|
116 |
+
embedding = F.normalize(embedding, dim=1)
|
117 |
+
embedding = self.spk_embed_affine_layer(embedding)
|
118 |
+
|
119 |
+
# concat text and prompt_text
|
120 |
+
token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
|
121 |
+
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
|
122 |
+
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
|
123 |
+
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
124 |
+
|
125 |
+
# text encode
|
126 |
+
h, h_lengths = self.encoder(token, token_len)
|
127 |
+
h = self.encoder_proj(h)
|
128 |
+
mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
|
129 |
+
h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
|
130 |
+
|
131 |
+
# get conditions
|
132 |
+
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
|
133 |
+
conds[:, :mel_len1] = prompt_feat
|
134 |
+
conds = conds.transpose(1, 2)
|
135 |
+
|
136 |
+
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
|
137 |
+
feat, flow_cache = self.decoder(
|
138 |
+
mu=h.transpose(1, 2).contiguous(),
|
139 |
+
mask=mask.unsqueeze(1),
|
140 |
+
spks=embedding,
|
141 |
+
cond=conds,
|
142 |
+
n_timesteps=10,
|
143 |
+
prompt_len=mel_len1,
|
144 |
+
flow_cache=flow_cache
|
145 |
+
)
|
146 |
+
feat = feat[:, :, mel_len1:]
|
147 |
+
assert feat.shape[2] == mel_len2
|
148 |
+
return feat, flow_cache
|
149 |
+
|
150 |
+
|
151 |
+
class CausalMaskedDiffWithXvec(torch.nn.Module):
|
152 |
+
def __init__(self,
|
153 |
+
input_size: int = 512,
|
154 |
+
output_size: int = 80,
|
155 |
+
spk_embed_dim: int = 192,
|
156 |
+
output_type: str = "mel",
|
157 |
+
vocab_size: int = 4096,
|
158 |
+
input_frame_rate: int = 50,
|
159 |
+
only_mask_loss: bool = True,
|
160 |
+
token_mel_ratio: int = 2,
|
161 |
+
pre_lookahead_len: int = 3,
|
162 |
+
encoder: torch.nn.Module = None,
|
163 |
+
decoder: torch.nn.Module = None,
|
164 |
+
decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
|
165 |
+
'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
|
166 |
+
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
|
167 |
+
'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
|
168 |
+
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
|
169 |
+
mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
|
170 |
+
'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
|
171 |
+
super().__init__()
|
172 |
+
self.input_size = input_size
|
173 |
+
self.output_size = output_size
|
174 |
+
self.decoder_conf = decoder_conf
|
175 |
+
self.mel_feat_conf = mel_feat_conf
|
176 |
+
self.vocab_size = vocab_size
|
177 |
+
self.output_type = output_type
|
178 |
+
self.input_frame_rate = input_frame_rate
|
179 |
+
logging.info(f"input frame rate={self.input_frame_rate}")
|
180 |
+
self.input_embedding = nn.Embedding(vocab_size, input_size)
|
181 |
+
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
|
182 |
+
self.encoder = encoder
|
183 |
+
self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
|
184 |
+
self.decoder = decoder
|
185 |
+
self.only_mask_loss = only_mask_loss
|
186 |
+
self.token_mel_ratio = token_mel_ratio
|
187 |
+
self.pre_lookahead_len = pre_lookahead_len
|
188 |
+
|
189 |
+
@torch.inference_mode()
|
190 |
+
def inference(self,
|
191 |
+
token,
|
192 |
+
token_len,
|
193 |
+
prompt_token,
|
194 |
+
prompt_token_len,
|
195 |
+
prompt_feat,
|
196 |
+
prompt_feat_len,
|
197 |
+
embedding,
|
198 |
+
finalize):
|
199 |
+
assert token.shape[0] == 1
|
200 |
+
# xvec projection
|
201 |
+
embedding = F.normalize(embedding, dim=1)
|
202 |
+
embedding = self.spk_embed_affine_layer(embedding)
|
203 |
+
|
204 |
+
# concat text and prompt_text
|
205 |
+
token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
|
206 |
+
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
|
207 |
+
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
|
208 |
+
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
209 |
+
|
210 |
+
# text encode
|
211 |
+
h, h_lengths = self.encoder(token, token_len)
|
212 |
+
if finalize is False:
|
213 |
+
h = h[:, :-self.pre_lookahead_len * self.token_mel_ratio]
|
214 |
+
mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
|
215 |
+
h = self.encoder_proj(h)
|
216 |
+
|
217 |
+
# get conditions
|
218 |
+
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
|
219 |
+
conds[:, :mel_len1] = prompt_feat
|
220 |
+
conds = conds.transpose(1, 2)
|
221 |
+
|
222 |
+
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
|
223 |
+
feat, _ = self.decoder(
|
224 |
+
mu=h.transpose(1, 2).contiguous(),
|
225 |
+
mask=mask.unsqueeze(1),
|
226 |
+
spks=embedding,
|
227 |
+
cond=conds,
|
228 |
+
n_timesteps=10
|
229 |
+
)
|
230 |
+
feat = feat[:, :, mel_len1:]
|
231 |
+
assert feat.shape[2] == mel_len2
|
232 |
+
return feat, None
|
cosyvoice/flow/flow_matching.py
ADDED
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import onnxruntime
|
15 |
+
import torch
|
16 |
+
import torch.nn.functional as F
|
17 |
+
from matcha.models.components.flow_matching import BASECFM
|
18 |
+
|
19 |
+
|
20 |
+
class ConditionalCFM(BASECFM):
|
21 |
+
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
|
22 |
+
super().__init__(
|
23 |
+
n_feats=in_channels,
|
24 |
+
cfm_params=cfm_params,
|
25 |
+
n_spks=n_spks,
|
26 |
+
spk_emb_dim=spk_emb_dim,
|
27 |
+
)
|
28 |
+
self.t_scheduler = cfm_params.t_scheduler
|
29 |
+
self.training_cfg_rate = cfm_params.training_cfg_rate
|
30 |
+
self.inference_cfg_rate = cfm_params.inference_cfg_rate
|
31 |
+
in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
|
32 |
+
# Just change the architecture of the estimator here
|
33 |
+
self.estimator = estimator
|
34 |
+
|
35 |
+
@torch.inference_mode()
|
36 |
+
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
|
37 |
+
"""Forward diffusion
|
38 |
+
|
39 |
+
Args:
|
40 |
+
mu (torch.Tensor): output of encoder
|
41 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
42 |
+
mask (torch.Tensor): output_mask
|
43 |
+
shape: (batch_size, 1, mel_timesteps)
|
44 |
+
n_timesteps (int): number of diffusion steps
|
45 |
+
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
46 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
47 |
+
shape: (batch_size, spk_emb_dim)
|
48 |
+
cond: Not used but kept for future purposes
|
49 |
+
|
50 |
+
Returns:
|
51 |
+
sample: generated mel-spectrogram
|
52 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
53 |
+
"""
|
54 |
+
|
55 |
+
z = torch.randn_like(mu) * temperature
|
56 |
+
cache_size = flow_cache.shape[2]
|
57 |
+
# fix prompt and overlap part mu and z
|
58 |
+
if cache_size != 0:
|
59 |
+
z[:, :, :cache_size] = flow_cache[:, :, :, 0]
|
60 |
+
mu[:, :, :cache_size] = flow_cache[:, :, :, 1]
|
61 |
+
z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
|
62 |
+
mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
|
63 |
+
flow_cache = torch.stack([z_cache, mu_cache], dim=-1)
|
64 |
+
|
65 |
+
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
|
66 |
+
if self.t_scheduler == 'cosine':
|
67 |
+
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
|
68 |
+
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache
|
69 |
+
|
70 |
+
def solve_euler(self, x, t_span, mu, mask, spks, cond):
|
71 |
+
"""
|
72 |
+
Fixed euler solver for ODEs.
|
73 |
+
Args:
|
74 |
+
x (torch.Tensor): random noise
|
75 |
+
t_span (torch.Tensor): n_timesteps interpolated
|
76 |
+
shape: (n_timesteps + 1,)
|
77 |
+
mu (torch.Tensor): output of encoder
|
78 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
79 |
+
mask (torch.Tensor): output_mask
|
80 |
+
shape: (batch_size, 1, mel_timesteps)
|
81 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
82 |
+
shape: (batch_size, spk_emb_dim)
|
83 |
+
cond: Not used but kept for future purposes
|
84 |
+
"""
|
85 |
+
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
86 |
+
t = t.unsqueeze(dim=0)
|
87 |
+
|
88 |
+
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
|
89 |
+
# Or in future might add like a return_all_steps flag
|
90 |
+
sol = []
|
91 |
+
|
92 |
+
if self.inference_cfg_rate > 0:
|
93 |
+
# Do not use concat, it may cause memory format changed and trt infer with wrong results!
|
94 |
+
x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
95 |
+
mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
|
96 |
+
mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
97 |
+
t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
|
98 |
+
spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
|
99 |
+
cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
100 |
+
else:
|
101 |
+
x_in, mask_in, mu_in, t_in, spks_in, cond_in = x, mask, mu, t, spks, cond
|
102 |
+
for step in range(1, len(t_span)):
|
103 |
+
# Classifier-Free Guidance inference introduced in VoiceBox
|
104 |
+
if self.inference_cfg_rate > 0:
|
105 |
+
x_in[:] = x
|
106 |
+
mask_in[:] = mask
|
107 |
+
mu_in[0] = mu
|
108 |
+
t_in[:] = t.unsqueeze(0)
|
109 |
+
spks_in[0] = spks
|
110 |
+
cond_in[0] = cond
|
111 |
+
else:
|
112 |
+
x_in, mask_in, mu_in, t_in, spks_in, cond_in = x, mask, mu, t, spks, cond
|
113 |
+
dphi_dt = self.forward_estimator(
|
114 |
+
x_in, mask_in,
|
115 |
+
mu_in, t_in,
|
116 |
+
spks_in,
|
117 |
+
cond_in
|
118 |
+
)
|
119 |
+
if self.inference_cfg_rate > 0:
|
120 |
+
dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
|
121 |
+
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
|
122 |
+
x = x + dt * dphi_dt
|
123 |
+
t = t + dt
|
124 |
+
sol.append(x)
|
125 |
+
if step < len(t_span) - 1:
|
126 |
+
dt = t_span[step + 1] - t
|
127 |
+
|
128 |
+
return sol[-1].float()
|
129 |
+
|
130 |
+
def forward_estimator(self, x, mask, mu, t, spks, cond):
|
131 |
+
if isinstance(self.estimator, torch.nn.Module):
|
132 |
+
return self.estimator.forward(x, mask, mu, t, spks, cond)
|
133 |
+
elif isinstance(self.estimator, onnxruntime.InferenceSession):
|
134 |
+
ort_inputs = {
|
135 |
+
'x': x.cpu().numpy(),
|
136 |
+
'mask': mask.cpu().numpy(),
|
137 |
+
'mu': mu.cpu().numpy(),
|
138 |
+
't': t.cpu().numpy(),
|
139 |
+
'spks': spks.cpu().numpy(),
|
140 |
+
'cond': cond.cpu().numpy()
|
141 |
+
}
|
142 |
+
output = self.estimator.run(None, ort_inputs)[0]
|
143 |
+
return torch.tensor(output, dtype=x.dtype, device=x.device)
|
144 |
+
else:
|
145 |
+
self.estimator.set_input_shape('x', (2, 80, x.size(2)))
|
146 |
+
self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
|
147 |
+
self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
|
148 |
+
self.estimator.set_input_shape('t', (2,))
|
149 |
+
self.estimator.set_input_shape('spks', (2, 80))
|
150 |
+
self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
|
151 |
+
# run trt engine
|
152 |
+
self.estimator.execute_v2([x.contiguous().data_ptr(),
|
153 |
+
mask.contiguous().data_ptr(),
|
154 |
+
mu.contiguous().data_ptr(),
|
155 |
+
t.contiguous().data_ptr(),
|
156 |
+
spks.contiguous().data_ptr(),
|
157 |
+
cond.contiguous().data_ptr(),
|
158 |
+
x.data_ptr()])
|
159 |
+
return x
|
160 |
+
|
161 |
+
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
|
162 |
+
"""Computes diffusion loss
|
163 |
+
|
164 |
+
Args:
|
165 |
+
x1 (torch.Tensor): Target
|
166 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
167 |
+
mask (torch.Tensor): target mask
|
168 |
+
shape: (batch_size, 1, mel_timesteps)
|
169 |
+
mu (torch.Tensor): output of encoder
|
170 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
171 |
+
spks (torch.Tensor, optional): speaker embedding. Defaults to None.
|
172 |
+
shape: (batch_size, spk_emb_dim)
|
173 |
+
|
174 |
+
Returns:
|
175 |
+
loss: conditional flow matching loss
|
176 |
+
y: conditional flow
|
177 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
178 |
+
"""
|
179 |
+
b, _, t = mu.shape
|
180 |
+
|
181 |
+
# random timestep
|
182 |
+
t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
|
183 |
+
if self.t_scheduler == 'cosine':
|
184 |
+
t = 1 - torch.cos(t * 0.5 * torch.pi)
|
185 |
+
# sample noise p(x_0)
|
186 |
+
z = torch.randn_like(x1)
|
187 |
+
|
188 |
+
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
|
189 |
+
u = x1 - (1 - self.sigma_min) * z
|
190 |
+
|
191 |
+
# during training, we randomly drop condition to trade off mode coverage and sample fidelity
|
192 |
+
if self.training_cfg_rate > 0:
|
193 |
+
cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
|
194 |
+
mu = mu * cfg_mask.view(-1, 1, 1)
|
195 |
+
spks = spks * cfg_mask.view(-1, 1)
|
196 |
+
cond = cond * cfg_mask.view(-1, 1, 1)
|
197 |
+
|
198 |
+
pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
|
199 |
+
loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
|
200 |
+
return loss, y
|
201 |
+
|
202 |
+
|
203 |
+
class CausalConditionalCFM(ConditionalCFM):
|
204 |
+
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
|
205 |
+
super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
|
206 |
+
self.rand_noise = torch.randn([1, 80, 50 * 300])
|
207 |
+
|
208 |
+
@torch.inference_mode()
|
209 |
+
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
|
210 |
+
"""Forward diffusion
|
211 |
+
|
212 |
+
Args:
|
213 |
+
mu (torch.Tensor): output of encoder
|
214 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
215 |
+
mask (torch.Tensor): output_mask
|
216 |
+
shape: (batch_size, 1, mel_timesteps)
|
217 |
+
n_timesteps (int): number of diffusion steps
|
218 |
+
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
219 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
220 |
+
shape: (batch_size, spk_emb_dim)
|
221 |
+
cond: Not used but kept for future purposes
|
222 |
+
|
223 |
+
Returns:
|
224 |
+
sample: generated mel-spectrogram
|
225 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
226 |
+
"""
|
227 |
+
|
228 |
+
z = self.rand_noise[:, :, :mu.size(2)].to(mu.device) * temperature
|
229 |
+
if self.fp16 is True:
|
230 |
+
z = z.half()
|
231 |
+
# fix prompt and overlap part mu and z
|
232 |
+
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
|
233 |
+
if self.t_scheduler == 'cosine':
|
234 |
+
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
|
235 |
+
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None
|
cosyvoice/flow/length_regulator.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import Tuple
|
15 |
+
import torch.nn as nn
|
16 |
+
import torch
|
17 |
+
from torch.nn import functional as F
|
18 |
+
from cosyvoice.utils.mask import make_pad_mask
|
19 |
+
|
20 |
+
|
21 |
+
class InterpolateRegulator(nn.Module):
|
22 |
+
def __init__(
|
23 |
+
self,
|
24 |
+
channels: int,
|
25 |
+
sampling_ratios: Tuple,
|
26 |
+
out_channels: int = None,
|
27 |
+
groups: int = 1,
|
28 |
+
):
|
29 |
+
super().__init__()
|
30 |
+
self.sampling_ratios = sampling_ratios
|
31 |
+
out_channels = out_channels or channels
|
32 |
+
model = nn.ModuleList([])
|
33 |
+
if len(sampling_ratios) > 0:
|
34 |
+
for _ in sampling_ratios:
|
35 |
+
module = nn.Conv1d(channels, channels, 3, 1, 1)
|
36 |
+
norm = nn.GroupNorm(groups, channels)
|
37 |
+
act = nn.Mish()
|
38 |
+
model.extend([module, norm, act])
|
39 |
+
model.append(
|
40 |
+
nn.Conv1d(channels, out_channels, 1, 1)
|
41 |
+
)
|
42 |
+
self.model = nn.Sequential(*model)
|
43 |
+
|
44 |
+
def forward(self, x, ylens=None):
|
45 |
+
# x in (B, T, D)
|
46 |
+
mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
|
47 |
+
x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear')
|
48 |
+
out = self.model(x).transpose(1, 2).contiguous()
|
49 |
+
olens = ylens
|
50 |
+
return out * mask, olens
|
51 |
+
|
52 |
+
def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
|
53 |
+
# in inference mode, interploate prompt token and token(head/mid/tail) seprately, so we can get a clear separation point of mel
|
54 |
+
# x in (B, T, D)
|
55 |
+
if x2.shape[1] > 40:
|
56 |
+
x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
|
57 |
+
x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2,
|
58 |
+
mode='linear')
|
59 |
+
x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
|
60 |
+
x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
|
61 |
+
else:
|
62 |
+
x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
|
63 |
+
if x1.shape[1] != 0:
|
64 |
+
x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear')
|
65 |
+
x = torch.concat([x1, x2], dim=2)
|
66 |
+
else:
|
67 |
+
x = x2
|
68 |
+
out = self.model(x).transpose(1, 2).contiguous()
|
69 |
+
return out, mel_len1 + mel_len2
|
cosyvoice/hifigan/discriminator.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from torch.nn.utils import weight_norm
|
4 |
+
from typing import List, Optional, Tuple
|
5 |
+
from einops import rearrange
|
6 |
+
from torchaudio.transforms import Spectrogram
|
7 |
+
|
8 |
+
|
9 |
+
class MultipleDiscriminator(nn.Module):
|
10 |
+
def __init__(
|
11 |
+
self, mpd: nn.Module, mrd: nn.Module
|
12 |
+
):
|
13 |
+
super().__init__()
|
14 |
+
self.mpd = mpd
|
15 |
+
self.mrd = mrd
|
16 |
+
|
17 |
+
def forward(self, y: torch.Tensor, y_hat: torch.Tensor):
|
18 |
+
y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
|
19 |
+
this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mpd(y.unsqueeze(dim=1), y_hat.unsqueeze(dim=1))
|
20 |
+
y_d_rs += this_y_d_rs
|
21 |
+
y_d_gs += this_y_d_gs
|
22 |
+
fmap_rs += this_fmap_rs
|
23 |
+
fmap_gs += this_fmap_gs
|
24 |
+
this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mrd(y, y_hat)
|
25 |
+
y_d_rs += this_y_d_rs
|
26 |
+
y_d_gs += this_y_d_gs
|
27 |
+
fmap_rs += this_fmap_rs
|
28 |
+
fmap_gs += this_fmap_gs
|
29 |
+
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
30 |
+
|
31 |
+
|
32 |
+
class MultiResolutionDiscriminator(nn.Module):
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
fft_sizes: Tuple[int, ...] = (2048, 1024, 512),
|
36 |
+
num_embeddings: Optional[int] = None,
|
37 |
+
):
|
38 |
+
"""
|
39 |
+
Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec.
|
40 |
+
Additionally, it allows incorporating conditional information with a learned embeddings table.
|
41 |
+
|
42 |
+
Args:
|
43 |
+
fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512).
|
44 |
+
num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
|
45 |
+
Defaults to None.
|
46 |
+
"""
|
47 |
+
|
48 |
+
super().__init__()
|
49 |
+
self.discriminators = nn.ModuleList(
|
50 |
+
[DiscriminatorR(window_length=w, num_embeddings=num_embeddings) for w in fft_sizes]
|
51 |
+
)
|
52 |
+
|
53 |
+
def forward(
|
54 |
+
self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
|
55 |
+
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
|
56 |
+
y_d_rs = []
|
57 |
+
y_d_gs = []
|
58 |
+
fmap_rs = []
|
59 |
+
fmap_gs = []
|
60 |
+
|
61 |
+
for d in self.discriminators:
|
62 |
+
y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
|
63 |
+
y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
|
64 |
+
y_d_rs.append(y_d_r)
|
65 |
+
fmap_rs.append(fmap_r)
|
66 |
+
y_d_gs.append(y_d_g)
|
67 |
+
fmap_gs.append(fmap_g)
|
68 |
+
|
69 |
+
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
70 |
+
|
71 |
+
|
72 |
+
class DiscriminatorR(nn.Module):
|
73 |
+
def __init__(
|
74 |
+
self,
|
75 |
+
window_length: int,
|
76 |
+
num_embeddings: Optional[int] = None,
|
77 |
+
channels: int = 32,
|
78 |
+
hop_factor: float = 0.25,
|
79 |
+
bands: Tuple[Tuple[float, float], ...] = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)),
|
80 |
+
):
|
81 |
+
super().__init__()
|
82 |
+
self.window_length = window_length
|
83 |
+
self.hop_factor = hop_factor
|
84 |
+
self.spec_fn = Spectrogram(
|
85 |
+
n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, power=None
|
86 |
+
)
|
87 |
+
n_fft = window_length // 2 + 1
|
88 |
+
bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
|
89 |
+
self.bands = bands
|
90 |
+
convs = lambda: nn.ModuleList(
|
91 |
+
[
|
92 |
+
weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
|
93 |
+
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
|
94 |
+
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
|
95 |
+
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
|
96 |
+
weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
|
97 |
+
]
|
98 |
+
)
|
99 |
+
self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
|
100 |
+
|
101 |
+
if num_embeddings is not None:
|
102 |
+
self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
|
103 |
+
torch.nn.init.zeros_(self.emb.weight)
|
104 |
+
|
105 |
+
self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))
|
106 |
+
|
107 |
+
def spectrogram(self, x):
|
108 |
+
# Remove DC offset
|
109 |
+
x = x - x.mean(dim=-1, keepdims=True)
|
110 |
+
# Peak normalize the volume of input audio
|
111 |
+
x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
|
112 |
+
x = self.spec_fn(x)
|
113 |
+
x = torch.view_as_real(x)
|
114 |
+
x = rearrange(x, "b f t c -> b c t f")
|
115 |
+
# Split into bands
|
116 |
+
x_bands = [x[..., b[0]: b[1]] for b in self.bands]
|
117 |
+
return x_bands
|
118 |
+
|
119 |
+
def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None):
|
120 |
+
x_bands = self.spectrogram(x)
|
121 |
+
fmap = []
|
122 |
+
x = []
|
123 |
+
for band, stack in zip(x_bands, self.band_convs):
|
124 |
+
for i, layer in enumerate(stack):
|
125 |
+
band = layer(band)
|
126 |
+
band = torch.nn.functional.leaky_relu(band, 0.1)
|
127 |
+
if i > 0:
|
128 |
+
fmap.append(band)
|
129 |
+
x.append(band)
|
130 |
+
x = torch.cat(x, dim=-1)
|
131 |
+
if cond_embedding_id is not None:
|
132 |
+
emb = self.emb(cond_embedding_id)
|
133 |
+
h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
|
134 |
+
else:
|
135 |
+
h = 0
|
136 |
+
x = self.conv_post(x)
|
137 |
+
fmap.append(x)
|
138 |
+
x += h
|
139 |
+
|
140 |
+
return x, fmap
|
cosyvoice/hifigan/f0_predictor.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
from torch.nn.utils import weight_norm
|
17 |
+
|
18 |
+
|
19 |
+
class ConvRNNF0Predictor(nn.Module):
|
20 |
+
def __init__(self,
|
21 |
+
num_class: int = 1,
|
22 |
+
in_channels: int = 80,
|
23 |
+
cond_channels: int = 512
|
24 |
+
):
|
25 |
+
super().__init__()
|
26 |
+
|
27 |
+
self.num_class = num_class
|
28 |
+
self.condnet = nn.Sequential(
|
29 |
+
weight_norm(
|
30 |
+
nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
|
31 |
+
),
|
32 |
+
nn.ELU(),
|
33 |
+
weight_norm(
|
34 |
+
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
|
35 |
+
),
|
36 |
+
nn.ELU(),
|
37 |
+
weight_norm(
|
38 |
+
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
|
39 |
+
),
|
40 |
+
nn.ELU(),
|
41 |
+
weight_norm(
|
42 |
+
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
|
43 |
+
),
|
44 |
+
nn.ELU(),
|
45 |
+
weight_norm(
|
46 |
+
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
|
47 |
+
),
|
48 |
+
nn.ELU(),
|
49 |
+
)
|
50 |
+
self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
|
51 |
+
|
52 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
53 |
+
x = self.condnet(x)
|
54 |
+
x = x.transpose(1, 2)
|
55 |
+
return torch.abs(self.classifier(x).squeeze(-1))
|
cosyvoice/hifigan/generator.py
ADDED
@@ -0,0 +1,411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""HIFI-GAN"""
|
16 |
+
|
17 |
+
from typing import Dict, Optional, List
|
18 |
+
import numpy as np
|
19 |
+
from scipy.signal import get_window
|
20 |
+
import torch
|
21 |
+
import torch.nn as nn
|
22 |
+
import torch.nn.functional as F
|
23 |
+
from torch.nn import Conv1d
|
24 |
+
from torch.nn import ConvTranspose1d
|
25 |
+
from torch.nn.utils import remove_weight_norm
|
26 |
+
from torch.nn.utils import weight_norm
|
27 |
+
from torch.distributions.uniform import Uniform
|
28 |
+
|
29 |
+
from cosyvoice.transformer.activation import Snake
|
30 |
+
from cosyvoice.utils.common import get_padding
|
31 |
+
from cosyvoice.utils.common import init_weights
|
32 |
+
|
33 |
+
|
34 |
+
"""hifigan based generator implementation.
|
35 |
+
|
36 |
+
This code is modified from https://github.com/jik876/hifi-gan
|
37 |
+
,https://github.com/kan-bayashi/ParallelWaveGAN and
|
38 |
+
https://github.com/NVIDIA/BigVGAN
|
39 |
+
|
40 |
+
"""
|
41 |
+
|
42 |
+
|
43 |
+
class ResBlock(torch.nn.Module):
|
44 |
+
"""Residual block module in HiFiGAN/BigVGAN."""
|
45 |
+
def __init__(
|
46 |
+
self,
|
47 |
+
channels: int = 512,
|
48 |
+
kernel_size: int = 3,
|
49 |
+
dilations: List[int] = [1, 3, 5],
|
50 |
+
):
|
51 |
+
super(ResBlock, self).__init__()
|
52 |
+
self.convs1 = nn.ModuleList()
|
53 |
+
self.convs2 = nn.ModuleList()
|
54 |
+
|
55 |
+
for dilation in dilations:
|
56 |
+
self.convs1.append(
|
57 |
+
weight_norm(
|
58 |
+
Conv1d(
|
59 |
+
channels,
|
60 |
+
channels,
|
61 |
+
kernel_size,
|
62 |
+
1,
|
63 |
+
dilation=dilation,
|
64 |
+
padding=get_padding(kernel_size, dilation)
|
65 |
+
)
|
66 |
+
)
|
67 |
+
)
|
68 |
+
self.convs2.append(
|
69 |
+
weight_norm(
|
70 |
+
Conv1d(
|
71 |
+
channels,
|
72 |
+
channels,
|
73 |
+
kernel_size,
|
74 |
+
1,
|
75 |
+
dilation=1,
|
76 |
+
padding=get_padding(kernel_size, 1)
|
77 |
+
)
|
78 |
+
)
|
79 |
+
)
|
80 |
+
self.convs1.apply(init_weights)
|
81 |
+
self.convs2.apply(init_weights)
|
82 |
+
self.activations1 = nn.ModuleList([
|
83 |
+
Snake(channels, alpha_logscale=False)
|
84 |
+
for _ in range(len(self.convs1))
|
85 |
+
])
|
86 |
+
self.activations2 = nn.ModuleList([
|
87 |
+
Snake(channels, alpha_logscale=False)
|
88 |
+
for _ in range(len(self.convs2))
|
89 |
+
])
|
90 |
+
|
91 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
92 |
+
for idx in range(len(self.convs1)):
|
93 |
+
xt = self.activations1[idx](x)
|
94 |
+
xt = self.convs1[idx](xt)
|
95 |
+
xt = self.activations2[idx](xt)
|
96 |
+
xt = self.convs2[idx](xt)
|
97 |
+
x = xt + x
|
98 |
+
return x
|
99 |
+
|
100 |
+
def remove_weight_norm(self):
|
101 |
+
for idx in range(len(self.convs1)):
|
102 |
+
remove_weight_norm(self.convs1[idx])
|
103 |
+
remove_weight_norm(self.convs2[idx])
|
104 |
+
|
105 |
+
|
106 |
+
class SineGen(torch.nn.Module):
|
107 |
+
""" Definition of sine generator
|
108 |
+
SineGen(samp_rate, harmonic_num = 0,
|
109 |
+
sine_amp = 0.1, noise_std = 0.003,
|
110 |
+
voiced_threshold = 0,
|
111 |
+
flag_for_pulse=False)
|
112 |
+
samp_rate: sampling rate in Hz
|
113 |
+
harmonic_num: number of harmonic overtones (default 0)
|
114 |
+
sine_amp: amplitude of sine-wavefrom (default 0.1)
|
115 |
+
noise_std: std of Gaussian noise (default 0.003)
|
116 |
+
voiced_thoreshold: F0 threshold for U/V classification (default 0)
|
117 |
+
flag_for_pulse: this SinGen is used inside PulseGen (default False)
|
118 |
+
Note: when flag_for_pulse is True, the first time step of a voiced
|
119 |
+
segment is always sin(np.pi) or cos(0)
|
120 |
+
"""
|
121 |
+
|
122 |
+
def __init__(self, samp_rate, harmonic_num=0,
|
123 |
+
sine_amp=0.1, noise_std=0.003,
|
124 |
+
voiced_threshold=0):
|
125 |
+
super(SineGen, self).__init__()
|
126 |
+
self.sine_amp = sine_amp
|
127 |
+
self.noise_std = noise_std
|
128 |
+
self.harmonic_num = harmonic_num
|
129 |
+
self.sampling_rate = samp_rate
|
130 |
+
self.voiced_threshold = voiced_threshold
|
131 |
+
|
132 |
+
def _f02uv(self, f0):
|
133 |
+
# generate uv signal
|
134 |
+
uv = (f0 > self.voiced_threshold).type(torch.float32)
|
135 |
+
return uv
|
136 |
+
|
137 |
+
@torch.no_grad()
|
138 |
+
def forward(self, f0):
|
139 |
+
"""
|
140 |
+
:param f0: [B, 1, sample_len], Hz
|
141 |
+
:return: [B, 1, sample_len]
|
142 |
+
"""
|
143 |
+
|
144 |
+
F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
|
145 |
+
for i in range(self.harmonic_num + 1):
|
146 |
+
F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
|
147 |
+
|
148 |
+
theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
|
149 |
+
u_dist = Uniform(low=-np.pi, high=np.pi)
|
150 |
+
phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
|
151 |
+
phase_vec[:, 0, :] = 0
|
152 |
+
|
153 |
+
# generate sine waveforms
|
154 |
+
sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)
|
155 |
+
|
156 |
+
# generate uv signal
|
157 |
+
uv = self._f02uv(f0)
|
158 |
+
|
159 |
+
# noise: for unvoiced should be similar to sine_amp
|
160 |
+
# std = self.sine_amp/3 -> max value ~ self.sine_amp
|
161 |
+
# . for voiced regions is self.noise_std
|
162 |
+
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
|
163 |
+
noise = noise_amp * torch.randn_like(sine_waves)
|
164 |
+
|
165 |
+
# first: set the unvoiced part to 0 by uv
|
166 |
+
# then: additive noise
|
167 |
+
sine_waves = sine_waves * uv + noise
|
168 |
+
return sine_waves, uv, noise
|
169 |
+
|
170 |
+
|
171 |
+
class SourceModuleHnNSF(torch.nn.Module):
|
172 |
+
""" SourceModule for hn-nsf
|
173 |
+
SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
|
174 |
+
add_noise_std=0.003, voiced_threshod=0)
|
175 |
+
sampling_rate: sampling_rate in Hz
|
176 |
+
harmonic_num: number of harmonic above F0 (default: 0)
|
177 |
+
sine_amp: amplitude of sine source signal (default: 0.1)
|
178 |
+
add_noise_std: std of additive Gaussian noise (default: 0.003)
|
179 |
+
note that amplitude of noise in unvoiced is decided
|
180 |
+
by sine_amp
|
181 |
+
voiced_threshold: threhold to set U/V given F0 (default: 0)
|
182 |
+
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
183 |
+
F0_sampled (batchsize, length, 1)
|
184 |
+
Sine_source (batchsize, length, 1)
|
185 |
+
noise_source (batchsize, length 1)
|
186 |
+
uv (batchsize, length, 1)
|
187 |
+
"""
|
188 |
+
|
189 |
+
def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
|
190 |
+
add_noise_std=0.003, voiced_threshod=0):
|
191 |
+
super(SourceModuleHnNSF, self).__init__()
|
192 |
+
|
193 |
+
self.sine_amp = sine_amp
|
194 |
+
self.noise_std = add_noise_std
|
195 |
+
|
196 |
+
# to produce sine waveforms
|
197 |
+
self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
|
198 |
+
sine_amp, add_noise_std, voiced_threshod)
|
199 |
+
|
200 |
+
# to merge source harmonics into a single excitation
|
201 |
+
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
|
202 |
+
self.l_tanh = torch.nn.Tanh()
|
203 |
+
|
204 |
+
def forward(self, x):
|
205 |
+
"""
|
206 |
+
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
207 |
+
F0_sampled (batchsize, length, 1)
|
208 |
+
Sine_source (batchsize, length, 1)
|
209 |
+
noise_source (batchsize, length 1)
|
210 |
+
"""
|
211 |
+
# source for harmonic branch
|
212 |
+
with torch.no_grad():
|
213 |
+
sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
|
214 |
+
sine_wavs = sine_wavs.transpose(1, 2)
|
215 |
+
uv = uv.transpose(1, 2)
|
216 |
+
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
|
217 |
+
|
218 |
+
# source for noise branch, in the same shape as uv
|
219 |
+
noise = torch.randn_like(uv) * self.sine_amp / 3
|
220 |
+
return sine_merge, noise, uv
|
221 |
+
|
222 |
+
|
223 |
+
class HiFTGenerator(nn.Module):
|
224 |
+
"""
|
225 |
+
HiFTNet Generator: Neural Source Filter + ISTFTNet
|
226 |
+
https://arxiv.org/abs/2309.09493
|
227 |
+
"""
|
228 |
+
def __init__(
|
229 |
+
self,
|
230 |
+
in_channels: int = 80,
|
231 |
+
base_channels: int = 512,
|
232 |
+
nb_harmonics: int = 8,
|
233 |
+
sampling_rate: int = 22050,
|
234 |
+
nsf_alpha: float = 0.1,
|
235 |
+
nsf_sigma: float = 0.003,
|
236 |
+
nsf_voiced_threshold: float = 10,
|
237 |
+
upsample_rates: List[int] = [8, 8],
|
238 |
+
upsample_kernel_sizes: List[int] = [16, 16],
|
239 |
+
istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
|
240 |
+
resblock_kernel_sizes: List[int] = [3, 7, 11],
|
241 |
+
resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
242 |
+
source_resblock_kernel_sizes: List[int] = [7, 11],
|
243 |
+
source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
|
244 |
+
lrelu_slope: float = 0.1,
|
245 |
+
audio_limit: float = 0.99,
|
246 |
+
f0_predictor: torch.nn.Module = None,
|
247 |
+
):
|
248 |
+
super(HiFTGenerator, self).__init__()
|
249 |
+
|
250 |
+
self.out_channels = 1
|
251 |
+
self.nb_harmonics = nb_harmonics
|
252 |
+
self.sampling_rate = sampling_rate
|
253 |
+
self.istft_params = istft_params
|
254 |
+
self.lrelu_slope = lrelu_slope
|
255 |
+
self.audio_limit = audio_limit
|
256 |
+
|
257 |
+
self.num_kernels = len(resblock_kernel_sizes)
|
258 |
+
self.num_upsamples = len(upsample_rates)
|
259 |
+
self.m_source = SourceModuleHnNSF(
|
260 |
+
sampling_rate=sampling_rate,
|
261 |
+
upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
|
262 |
+
harmonic_num=nb_harmonics,
|
263 |
+
sine_amp=nsf_alpha,
|
264 |
+
add_noise_std=nsf_sigma,
|
265 |
+
voiced_threshod=nsf_voiced_threshold)
|
266 |
+
self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
|
267 |
+
|
268 |
+
self.conv_pre = weight_norm(
|
269 |
+
Conv1d(in_channels, base_channels, 7, 1, padding=3)
|
270 |
+
)
|
271 |
+
|
272 |
+
# Up
|
273 |
+
self.ups = nn.ModuleList()
|
274 |
+
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
275 |
+
self.ups.append(
|
276 |
+
weight_norm(
|
277 |
+
ConvTranspose1d(
|
278 |
+
base_channels // (2**i),
|
279 |
+
base_channels // (2**(i + 1)),
|
280 |
+
k,
|
281 |
+
u,
|
282 |
+
padding=(k - u) // 2,
|
283 |
+
)
|
284 |
+
)
|
285 |
+
)
|
286 |
+
|
287 |
+
# Down
|
288 |
+
self.source_downs = nn.ModuleList()
|
289 |
+
self.source_resblocks = nn.ModuleList()
|
290 |
+
downsample_rates = [1] + upsample_rates[::-1][:-1]
|
291 |
+
downsample_cum_rates = np.cumprod(downsample_rates)
|
292 |
+
for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
|
293 |
+
if u == 1:
|
294 |
+
self.source_downs.append(
|
295 |
+
Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
|
296 |
+
)
|
297 |
+
else:
|
298 |
+
self.source_downs.append(
|
299 |
+
Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
|
300 |
+
)
|
301 |
+
|
302 |
+
self.source_resblocks.append(
|
303 |
+
ResBlock(base_channels // (2 ** (i + 1)), k, d)
|
304 |
+
)
|
305 |
+
|
306 |
+
self.resblocks = nn.ModuleList()
|
307 |
+
for i in range(len(self.ups)):
|
308 |
+
ch = base_channels // (2**(i + 1))
|
309 |
+
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
310 |
+
self.resblocks.append(ResBlock(ch, k, d))
|
311 |
+
|
312 |
+
self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
|
313 |
+
self.ups.apply(init_weights)
|
314 |
+
self.conv_post.apply(init_weights)
|
315 |
+
self.reflection_pad = nn.ReflectionPad1d((1, 0))
|
316 |
+
self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
|
317 |
+
self.f0_predictor = f0_predictor
|
318 |
+
|
319 |
+
def remove_weight_norm(self):
|
320 |
+
print('Removing weight norm...')
|
321 |
+
for l in self.ups:
|
322 |
+
remove_weight_norm(l)
|
323 |
+
for l in self.resblocks:
|
324 |
+
l.remove_weight_norm()
|
325 |
+
remove_weight_norm(self.conv_pre)
|
326 |
+
remove_weight_norm(self.conv_post)
|
327 |
+
self.m_source.remove_weight_norm()
|
328 |
+
for l in self.source_downs:
|
329 |
+
remove_weight_norm(l)
|
330 |
+
for l in self.source_resblocks:
|
331 |
+
l.remove_weight_norm()
|
332 |
+
|
333 |
+
def _stft(self, x):
|
334 |
+
spec = torch.stft(
|
335 |
+
x,
|
336 |
+
self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
|
337 |
+
return_complex=True)
|
338 |
+
spec = torch.view_as_real(spec) # [B, F, TT, 2]
|
339 |
+
return spec[..., 0], spec[..., 1]
|
340 |
+
|
341 |
+
def _istft(self, magnitude, phase):
|
342 |
+
magnitude = torch.clip(magnitude, max=1e2)
|
343 |
+
real = magnitude * torch.cos(phase)
|
344 |
+
img = magnitude * torch.sin(phase)
|
345 |
+
inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
|
346 |
+
self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
|
347 |
+
return inverse_transform
|
348 |
+
|
349 |
+
def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
|
350 |
+
s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
|
351 |
+
s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
|
352 |
+
|
353 |
+
x = self.conv_pre(x)
|
354 |
+
for i in range(self.num_upsamples):
|
355 |
+
x = F.leaky_relu(x, self.lrelu_slope)
|
356 |
+
x = self.ups[i](x)
|
357 |
+
|
358 |
+
if i == self.num_upsamples - 1:
|
359 |
+
x = self.reflection_pad(x)
|
360 |
+
|
361 |
+
# fusion
|
362 |
+
si = self.source_downs[i](s_stft)
|
363 |
+
si = self.source_resblocks[i](si)
|
364 |
+
x = x + si
|
365 |
+
|
366 |
+
xs = None
|
367 |
+
for j in range(self.num_kernels):
|
368 |
+
if xs is None:
|
369 |
+
xs = self.resblocks[i * self.num_kernels + j](x)
|
370 |
+
else:
|
371 |
+
xs += self.resblocks[i * self.num_kernels + j](x)
|
372 |
+
x = xs / self.num_kernels
|
373 |
+
|
374 |
+
x = F.leaky_relu(x)
|
375 |
+
x = self.conv_post(x)
|
376 |
+
magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
|
377 |
+
phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy
|
378 |
+
|
379 |
+
x = self._istft(magnitude, phase)
|
380 |
+
x = torch.clamp(x, -self.audio_limit, self.audio_limit)
|
381 |
+
return x
|
382 |
+
|
383 |
+
def forward(
|
384 |
+
self,
|
385 |
+
batch: dict,
|
386 |
+
device: torch.device,
|
387 |
+
) -> Dict[str, Optional[torch.Tensor]]:
|
388 |
+
speech_feat = batch['speech_feat'].transpose(1, 2).to(device)
|
389 |
+
# mel->f0
|
390 |
+
f0 = self.f0_predictor(speech_feat)
|
391 |
+
# f0->source
|
392 |
+
s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
393 |
+
s, _, _ = self.m_source(s)
|
394 |
+
s = s.transpose(1, 2)
|
395 |
+
# mel+source->speech
|
396 |
+
generated_speech = self.decode(x=speech_feat, s=s)
|
397 |
+
return generated_speech, f0
|
398 |
+
|
399 |
+
@torch.inference_mode()
|
400 |
+
def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
|
401 |
+
# mel->f0
|
402 |
+
f0 = self.f0_predictor(speech_feat)
|
403 |
+
# f0->source
|
404 |
+
s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
405 |
+
s, _, _ = self.m_source(s)
|
406 |
+
s = s.transpose(1, 2)
|
407 |
+
# use cache_source to avoid glitch
|
408 |
+
if cache_source.shape[2] != 0:
|
409 |
+
s[:, :, :cache_source.shape[2]] = cache_source
|
410 |
+
generated_speech = self.decode(x=speech_feat, s=s)
|
411 |
+
return generated_speech, s
|
cosyvoice/hifigan/hifigan.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, Optional
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from matcha.hifigan.models import feature_loss, generator_loss, discriminator_loss
|
6 |
+
from cosyvoice.utils.losses import tpr_loss, mel_loss
|
7 |
+
|
8 |
+
|
9 |
+
class HiFiGan(nn.Module):
|
10 |
+
def __init__(self, generator, discriminator, mel_spec_transform,
|
11 |
+
multi_mel_spectral_recon_loss_weight=45, feat_match_loss_weight=2.0,
|
12 |
+
tpr_loss_weight=1.0, tpr_loss_tau=0.04):
|
13 |
+
super(HiFiGan, self).__init__()
|
14 |
+
self.generator = generator
|
15 |
+
self.discriminator = discriminator
|
16 |
+
self.mel_spec_transform = mel_spec_transform
|
17 |
+
self.multi_mel_spectral_recon_loss_weight = multi_mel_spectral_recon_loss_weight
|
18 |
+
self.feat_match_loss_weight = feat_match_loss_weight
|
19 |
+
self.tpr_loss_weight = tpr_loss_weight
|
20 |
+
self.tpr_loss_tau = tpr_loss_tau
|
21 |
+
|
22 |
+
def forward(
|
23 |
+
self,
|
24 |
+
batch: dict,
|
25 |
+
device: torch.device,
|
26 |
+
) -> Dict[str, Optional[torch.Tensor]]:
|
27 |
+
if batch['turn'] == 'generator':
|
28 |
+
return self.forward_generator(batch, device)
|
29 |
+
else:
|
30 |
+
return self.forward_discriminator(batch, device)
|
31 |
+
|
32 |
+
def forward_generator(self, batch, device):
|
33 |
+
real_speech = batch['speech'].to(device)
|
34 |
+
pitch_feat = batch['pitch_feat'].to(device)
|
35 |
+
# 1. calculate generator outputs
|
36 |
+
generated_speech, generated_f0 = self.generator(batch, device)
|
37 |
+
# 2. calculate discriminator outputs
|
38 |
+
y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech)
|
39 |
+
# 3. calculate generator losses, feature loss, mel loss, tpr losses [Optional]
|
40 |
+
loss_gen, _ = generator_loss(y_d_gs)
|
41 |
+
loss_fm = feature_loss(fmap_rs, fmap_gs)
|
42 |
+
loss_mel = mel_loss(real_speech, generated_speech, self.mel_spec_transform)
|
43 |
+
if self.tpr_loss_weight != 0:
|
44 |
+
loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau)
|
45 |
+
else:
|
46 |
+
loss_tpr = torch.zeros(1).to(device)
|
47 |
+
loss_f0 = F.l1_loss(generated_f0, pitch_feat)
|
48 |
+
loss = loss_gen + self.feat_match_loss_weight * loss_fm + \
|
49 |
+
self.multi_mel_spectral_recon_loss_weight * loss_mel + \
|
50 |
+
self.tpr_loss_weight * loss_tpr + loss_f0
|
51 |
+
return {'loss': loss, 'loss_gen': loss_gen, 'loss_fm': loss_fm, 'loss_mel': loss_mel, 'loss_tpr': loss_tpr, 'loss_f0': loss_f0}
|
52 |
+
|
53 |
+
def forward_discriminator(self, batch, device):
|
54 |
+
real_speech = batch['speech'].to(device)
|
55 |
+
# 1. calculate generator outputs
|
56 |
+
with torch.no_grad():
|
57 |
+
generated_speech, generated_f0 = self.generator(batch, device)
|
58 |
+
# 2. calculate discriminator outputs
|
59 |
+
y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech)
|
60 |
+
# 3. calculate discriminator losses, tpr losses [Optional]
|
61 |
+
loss_disc, _, _ = discriminator_loss(y_d_rs, y_d_gs)
|
62 |
+
if self.tpr_loss_weight != 0:
|
63 |
+
loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau)
|
64 |
+
else:
|
65 |
+
loss_tpr = torch.zeros(1).to(device)
|
66 |
+
loss = loss_disc + self.tpr_loss_weight * loss_tpr
|
67 |
+
return {'loss': loss, 'loss_disc': loss_disc, 'loss_tpr': loss_tpr}
|
cosyvoice/llm/llm.py
ADDED
@@ -0,0 +1,340 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import Dict, Optional, Callable, List, Generator
|
15 |
+
import torch
|
16 |
+
from torch import nn
|
17 |
+
import torch.nn.functional as F
|
18 |
+
from transformers import Qwen2ForCausalLM
|
19 |
+
from torch.nn.utils.rnn import pad_sequence, unpad_sequence
|
20 |
+
from cosyvoice.utils.common import IGNORE_ID
|
21 |
+
from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
|
22 |
+
from cosyvoice.utils.common import th_accuracy
|
23 |
+
|
24 |
+
|
25 |
+
class TransformerLM(torch.nn.Module):
|
26 |
+
def __init__(
|
27 |
+
self,
|
28 |
+
text_encoder_input_size: int,
|
29 |
+
llm_input_size: int,
|
30 |
+
llm_output_size: int,
|
31 |
+
text_token_size: int,
|
32 |
+
speech_token_size: int,
|
33 |
+
text_encoder: torch.nn.Module,
|
34 |
+
llm: torch.nn.Module,
|
35 |
+
sampling: Callable,
|
36 |
+
length_normalized_loss: bool = True,
|
37 |
+
lsm_weight: float = 0.0,
|
38 |
+
spk_embed_dim: int = 192,
|
39 |
+
):
|
40 |
+
super().__init__()
|
41 |
+
self.llm_input_size = llm_input_size
|
42 |
+
self.speech_token_size = speech_token_size
|
43 |
+
# 1. build text token inputs related modules
|
44 |
+
self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size)
|
45 |
+
self.text_encoder = text_encoder
|
46 |
+
self.text_encoder_affine_layer = nn.Linear(
|
47 |
+
self.text_encoder.output_size(),
|
48 |
+
llm_input_size
|
49 |
+
)
|
50 |
+
|
51 |
+
# 2. build speech token language model related modules
|
52 |
+
self.sos_eos = 0
|
53 |
+
self.task_id = 1
|
54 |
+
self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
|
55 |
+
self.llm = llm
|
56 |
+
self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
|
57 |
+
self.criterion_ce = LabelSmoothingLoss(
|
58 |
+
size=speech_token_size + 1,
|
59 |
+
padding_idx=IGNORE_ID,
|
60 |
+
smoothing=lsm_weight,
|
61 |
+
normalize_length=length_normalized_loss,
|
62 |
+
)
|
63 |
+
|
64 |
+
# 3. [Optional] build speech token related modules
|
65 |
+
self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
|
66 |
+
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)
|
67 |
+
|
68 |
+
# 4. sampling method
|
69 |
+
self.sampling = sampling
|
70 |
+
|
71 |
+
def encode(
|
72 |
+
self,
|
73 |
+
text: torch.Tensor,
|
74 |
+
text_lengths: torch.Tensor,
|
75 |
+
):
|
76 |
+
encoder_out, encoder_mask = self.text_encoder(text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1)
|
77 |
+
encoder_out_lens = encoder_mask.squeeze(1).sum(1)
|
78 |
+
encoder_out = self.text_encoder_affine_layer(encoder_out)
|
79 |
+
return encoder_out, encoder_out_lens
|
80 |
+
|
81 |
+
def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
|
82 |
+
text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
|
83 |
+
speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
|
84 |
+
lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
|
85 |
+
for i in range(len(text_token))]
|
86 |
+
lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
|
87 |
+
lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
|
88 |
+
return lm_input, lm_input_len
|
89 |
+
|
90 |
+
def forward(
|
91 |
+
self,
|
92 |
+
batch: dict,
|
93 |
+
device: torch.device,
|
94 |
+
) -> Dict[str, Optional[torch.Tensor]]:
|
95 |
+
"""
|
96 |
+
Args:
|
97 |
+
text: (B, L, D)
|
98 |
+
text_lengths: (B,)
|
99 |
+
audio: (B, T, N) or (B, T)
|
100 |
+
audio_lengths: (B,)
|
101 |
+
"""
|
102 |
+
text_token = batch['text_token'].to(device)
|
103 |
+
text_token_len = batch['text_token_len'].to(device)
|
104 |
+
speech_token = batch['speech_token'].to(device)
|
105 |
+
speech_token_len = batch['speech_token_len'].to(device)
|
106 |
+
embedding = batch['embedding'].to(device)
|
107 |
+
|
108 |
+
# 1. prepare llm_target
|
109 |
+
lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
|
110 |
+
[self.speech_token_size]) for i in range(text_token.size(0))]
|
111 |
+
lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)
|
112 |
+
|
113 |
+
# 1. encode text_token
|
114 |
+
text_token = self.text_embedding(text_token)
|
115 |
+
text_token, text_token_len = self.encode(text_token, text_token_len)
|
116 |
+
|
117 |
+
# 2. embedding projection
|
118 |
+
embedding = F.normalize(embedding, dim=1)
|
119 |
+
embedding = self.spk_embed_affine_layer(embedding)
|
120 |
+
embedding = embedding.unsqueeze(1)
|
121 |
+
|
122 |
+
# 3. eos and task_id
|
123 |
+
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
|
124 |
+
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
|
125 |
+
|
126 |
+
# 4. encode speech_token
|
127 |
+
speech_token = self.speech_embedding(speech_token)
|
128 |
+
|
129 |
+
# 5. unpad and pad
|
130 |
+
lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len,
|
131 |
+
task_id_emb, speech_token, speech_token_len)
|
132 |
+
|
133 |
+
# 6. run lm forward
|
134 |
+
lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
|
135 |
+
logits = self.llm_decoder(lm_output)
|
136 |
+
loss = self.criterion_ce(logits, lm_target)
|
137 |
+
acc = th_accuracy(logits.view(-1, self.speech_token_size + 1), lm_target, ignore_label=IGNORE_ID)
|
138 |
+
return {'loss': loss, 'acc': acc}
|
139 |
+
|
140 |
+
def sampling_ids(
|
141 |
+
self,
|
142 |
+
weighted_scores: torch.Tensor,
|
143 |
+
decoded_tokens: List,
|
144 |
+
sampling: int,
|
145 |
+
ignore_eos: bool = True,
|
146 |
+
):
|
147 |
+
while True:
|
148 |
+
top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
|
149 |
+
if (not ignore_eos) or (self.speech_token_size not in top_ids):
|
150 |
+
break
|
151 |
+
return top_ids
|
152 |
+
|
153 |
+
@torch.inference_mode()
|
154 |
+
def inference(
|
155 |
+
self,
|
156 |
+
text: torch.Tensor,
|
157 |
+
text_len: torch.Tensor,
|
158 |
+
prompt_text: torch.Tensor,
|
159 |
+
prompt_text_len: torch.Tensor,
|
160 |
+
prompt_speech_token: torch.Tensor,
|
161 |
+
prompt_speech_token_len: torch.Tensor,
|
162 |
+
embedding: torch.Tensor,
|
163 |
+
sampling: int = 25,
|
164 |
+
max_token_text_ratio: float = 20,
|
165 |
+
min_token_text_ratio: float = 2,
|
166 |
+
) -> Generator[torch.Tensor, None, None]:
|
167 |
+
device = text.device
|
168 |
+
text = torch.concat([prompt_text, text], dim=1)
|
169 |
+
text_len += prompt_text_len
|
170 |
+
text = self.text_embedding(text)
|
171 |
+
|
172 |
+
# 1. encode text
|
173 |
+
text, text_len = self.encode(text, text_len)
|
174 |
+
|
175 |
+
# 2. encode embedding
|
176 |
+
if embedding.shape[0] != 0:
|
177 |
+
embedding = F.normalize(embedding, dim=1)
|
178 |
+
embedding = self.spk_embed_affine_layer(embedding)
|
179 |
+
embedding = embedding.unsqueeze(dim=1)
|
180 |
+
else:
|
181 |
+
embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
|
182 |
+
|
183 |
+
# 3. concat llm_input
|
184 |
+
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
|
185 |
+
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
|
186 |
+
if prompt_speech_token_len != 0:
|
187 |
+
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
|
188 |
+
else:
|
189 |
+
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
|
190 |
+
lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
|
191 |
+
|
192 |
+
# 4. cal min/max_length
|
193 |
+
min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
|
194 |
+
max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
|
195 |
+
|
196 |
+
# 5. step by step decode
|
197 |
+
out_tokens = []
|
198 |
+
offset = 0
|
199 |
+
att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
|
200 |
+
for i in range(max_len):
|
201 |
+
y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1,
|
202 |
+
att_cache=att_cache, cnn_cache=cnn_cache,
|
203 |
+
att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
|
204 |
+
device=lm_input.device)).to(torch.bool))
|
205 |
+
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
|
206 |
+
# force continue decode first token
|
207 |
+
if i == 0:
|
208 |
+
logp[:, self.speech_token_size] = -float('inf')
|
209 |
+
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
|
210 |
+
if top_ids == self.speech_token_size:
|
211 |
+
break
|
212 |
+
# in stream mode, yield token one by one
|
213 |
+
yield top_ids
|
214 |
+
out_tokens.append(top_ids)
|
215 |
+
offset += lm_input.size(1)
|
216 |
+
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
|
217 |
+
|
218 |
+
|
219 |
+
class Qwen2Encoder(torch.nn.Module):
|
220 |
+
def __init__(self, pretrain_path):
|
221 |
+
super().__init__()
|
222 |
+
self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path)
|
223 |
+
|
224 |
+
def forward_one_step(self, xs, masks, cache=None):
|
225 |
+
input_masks = masks[:, -1, :]
|
226 |
+
outs = self.model(
|
227 |
+
inputs_embeds=xs,
|
228 |
+
attention_mask=input_masks,
|
229 |
+
output_hidden_states=True,
|
230 |
+
return_dict=True,
|
231 |
+
use_cache=True,
|
232 |
+
past_key_values=cache,
|
233 |
+
)
|
234 |
+
xs = outs.hidden_states[-1]
|
235 |
+
new_cache = outs.past_key_values
|
236 |
+
return xs, new_cache
|
237 |
+
|
238 |
+
|
239 |
+
class Qwen2LM(torch.nn.Module):
|
240 |
+
def __init__(
|
241 |
+
self,
|
242 |
+
llm_input_size: int,
|
243 |
+
llm_output_size: int,
|
244 |
+
speech_token_size: int,
|
245 |
+
llm: torch.nn.Module,
|
246 |
+
sampling: Callable,
|
247 |
+
length_normalized_loss: bool = True,
|
248 |
+
lsm_weight: float = 0.0,
|
249 |
+
):
|
250 |
+
super().__init__()
|
251 |
+
self.llm_input_size = llm_input_size
|
252 |
+
self.llm_output_size = llm_output_size
|
253 |
+
self.speech_token_size = speech_token_size
|
254 |
+
|
255 |
+
# 2. build speech token language model related modules
|
256 |
+
self.sos_eos = 0
|
257 |
+
self.task_id = 1
|
258 |
+
self.fill_token = 2
|
259 |
+
|
260 |
+
self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
|
261 |
+
self.llm = llm
|
262 |
+
self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 3)
|
263 |
+
self.criterion_ce = LabelSmoothingLoss(
|
264 |
+
size=speech_token_size + 3,
|
265 |
+
padding_idx=IGNORE_ID,
|
266 |
+
smoothing=lsm_weight,
|
267 |
+
normalize_length=length_normalized_loss,
|
268 |
+
)
|
269 |
+
|
270 |
+
# 3. [Optional] build speech token related modules
|
271 |
+
self.speech_embedding = torch.nn.Embedding(speech_token_size + 3, llm_input_size)
|
272 |
+
|
273 |
+
# 4. sampling method
|
274 |
+
self.sampling = sampling
|
275 |
+
|
276 |
+
def sampling_ids(
|
277 |
+
self,
|
278 |
+
weighted_scores: torch.Tensor,
|
279 |
+
decoded_tokens: List,
|
280 |
+
sampling: int,
|
281 |
+
ignore_eos: bool = True,
|
282 |
+
):
|
283 |
+
while True:
|
284 |
+
top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
|
285 |
+
if (not ignore_eos) or (self.speech_token_size not in top_ids):
|
286 |
+
break
|
287 |
+
return top_ids
|
288 |
+
|
289 |
+
@torch.inference_mode()
|
290 |
+
def inference(
|
291 |
+
self,
|
292 |
+
text: torch.Tensor,
|
293 |
+
text_len: torch.Tensor,
|
294 |
+
prompt_text: torch.Tensor,
|
295 |
+
prompt_text_len: torch.Tensor,
|
296 |
+
prompt_speech_token: torch.Tensor,
|
297 |
+
prompt_speech_token_len: torch.Tensor,
|
298 |
+
embedding: torch.Tensor,
|
299 |
+
sampling: int = 25,
|
300 |
+
max_token_text_ratio: float = 20,
|
301 |
+
min_token_text_ratio: float = 2,
|
302 |
+
) -> Generator[torch.Tensor, None, None]:
|
303 |
+
device = text.device
|
304 |
+
text = torch.concat([prompt_text, text], dim=1)
|
305 |
+
text_len += prompt_text_len
|
306 |
+
text = self.llm.model.model.embed_tokens(text)
|
307 |
+
|
308 |
+
# 2. encode embedding
|
309 |
+
embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
|
310 |
+
|
311 |
+
# 3. concat llm_input
|
312 |
+
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
|
313 |
+
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
|
314 |
+
if prompt_speech_token_len != 0:
|
315 |
+
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
|
316 |
+
else:
|
317 |
+
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
|
318 |
+
lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
|
319 |
+
|
320 |
+
# 4. cal min/max_length
|
321 |
+
min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
|
322 |
+
max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
|
323 |
+
|
324 |
+
# 5. step by step decode
|
325 |
+
out_tokens = []
|
326 |
+
cache = None
|
327 |
+
for i in range(max_len):
|
328 |
+
y_pred, cache = self.llm.forward_one_step(lm_input,
|
329 |
+
masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
|
330 |
+
cache=cache)
|
331 |
+
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
|
332 |
+
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
|
333 |
+
if top_ids == self.speech_token_size:
|
334 |
+
break
|
335 |
+
if top_ids > self.speech_token_size:
|
336 |
+
continue
|
337 |
+
# in stream mode, yield token one by one
|
338 |
+
yield top_ids
|
339 |
+
out_tokens.append(top_ids)
|
340 |
+
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
|
cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken
ADDED
The diff for this file is too large to render.
See raw diff
|
|
cosyvoice/tokenizer/tokenizer.py
ADDED
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import os
|
3 |
+
from functools import lru_cache
|
4 |
+
from typing import Optional
|
5 |
+
import torch
|
6 |
+
from transformers import AutoTokenizer
|
7 |
+
from whisper.tokenizer import Tokenizer
|
8 |
+
|
9 |
+
import tiktoken
|
10 |
+
|
11 |
+
LANGUAGES = {
|
12 |
+
"en": "english",
|
13 |
+
"zh": "chinese",
|
14 |
+
"de": "german",
|
15 |
+
"es": "spanish",
|
16 |
+
"ru": "russian",
|
17 |
+
"ko": "korean",
|
18 |
+
"fr": "french",
|
19 |
+
"ja": "japanese",
|
20 |
+
"pt": "portuguese",
|
21 |
+
"tr": "turkish",
|
22 |
+
"pl": "polish",
|
23 |
+
"ca": "catalan",
|
24 |
+
"nl": "dutch",
|
25 |
+
"ar": "arabic",
|
26 |
+
"sv": "swedish",
|
27 |
+
"it": "italian",
|
28 |
+
"id": "indonesian",
|
29 |
+
"hi": "hindi",
|
30 |
+
"fi": "finnish",
|
31 |
+
"vi": "vietnamese",
|
32 |
+
"he": "hebrew",
|
33 |
+
"uk": "ukrainian",
|
34 |
+
"el": "greek",
|
35 |
+
"ms": "malay",
|
36 |
+
"cs": "czech",
|
37 |
+
"ro": "romanian",
|
38 |
+
"da": "danish",
|
39 |
+
"hu": "hungarian",
|
40 |
+
"ta": "tamil",
|
41 |
+
"no": "norwegian",
|
42 |
+
"th": "thai",
|
43 |
+
"ur": "urdu",
|
44 |
+
"hr": "croatian",
|
45 |
+
"bg": "bulgarian",
|
46 |
+
"lt": "lithuanian",
|
47 |
+
"la": "latin",
|
48 |
+
"mi": "maori",
|
49 |
+
"ml": "malayalam",
|
50 |
+
"cy": "welsh",
|
51 |
+
"sk": "slovak",
|
52 |
+
"te": "telugu",
|
53 |
+
"fa": "persian",
|
54 |
+
"lv": "latvian",
|
55 |
+
"bn": "bengali",
|
56 |
+
"sr": "serbian",
|
57 |
+
"az": "azerbaijani",
|
58 |
+
"sl": "slovenian",
|
59 |
+
"kn": "kannada",
|
60 |
+
"et": "estonian",
|
61 |
+
"mk": "macedonian",
|
62 |
+
"br": "breton",
|
63 |
+
"eu": "basque",
|
64 |
+
"is": "icelandic",
|
65 |
+
"hy": "armenian",
|
66 |
+
"ne": "nepali",
|
67 |
+
"mn": "mongolian",
|
68 |
+
"bs": "bosnian",
|
69 |
+
"kk": "kazakh",
|
70 |
+
"sq": "albanian",
|
71 |
+
"sw": "swahili",
|
72 |
+
"gl": "galician",
|
73 |
+
"mr": "marathi",
|
74 |
+
"pa": "punjabi",
|
75 |
+
"si": "sinhala",
|
76 |
+
"km": "khmer",
|
77 |
+
"sn": "shona",
|
78 |
+
"yo": "yoruba",
|
79 |
+
"so": "somali",
|
80 |
+
"af": "afrikaans",
|
81 |
+
"oc": "occitan",
|
82 |
+
"ka": "georgian",
|
83 |
+
"be": "belarusian",
|
84 |
+
"tg": "tajik",
|
85 |
+
"sd": "sindhi",
|
86 |
+
"gu": "gujarati",
|
87 |
+
"am": "amharic",
|
88 |
+
"yi": "yiddish",
|
89 |
+
"lo": "lao",
|
90 |
+
"uz": "uzbek",
|
91 |
+
"fo": "faroese",
|
92 |
+
"ht": "haitian creole",
|
93 |
+
"ps": "pashto",
|
94 |
+
"tk": "turkmen",
|
95 |
+
"nn": "nynorsk",
|
96 |
+
"mt": "maltese",
|
97 |
+
"sa": "sanskrit",
|
98 |
+
"lb": "luxembourgish",
|
99 |
+
"my": "myanmar",
|
100 |
+
"bo": "tibetan",
|
101 |
+
"tl": "tagalog",
|
102 |
+
"mg": "malagasy",
|
103 |
+
"as": "assamese",
|
104 |
+
"tt": "tatar",
|
105 |
+
"haw": "hawaiian",
|
106 |
+
"ln": "lingala",
|
107 |
+
"ha": "hausa",
|
108 |
+
"ba": "bashkir",
|
109 |
+
"jw": "javanese",
|
110 |
+
"su": "sundanese",
|
111 |
+
"yue": "cantonese",
|
112 |
+
"minnan": "minnan",
|
113 |
+
"wuyu": "wuyu",
|
114 |
+
"dialect": "dialect",
|
115 |
+
"zh/en": "zh/en",
|
116 |
+
"en/zh": "en/zh",
|
117 |
+
}
|
118 |
+
|
119 |
+
# language code lookup by name, with a few language aliases
|
120 |
+
TO_LANGUAGE_CODE = {
|
121 |
+
**{language: code for code, language in LANGUAGES.items()},
|
122 |
+
"burmese": "my",
|
123 |
+
"valencian": "ca",
|
124 |
+
"flemish": "nl",
|
125 |
+
"haitian": "ht",
|
126 |
+
"letzeburgesch": "lb",
|
127 |
+
"pushto": "ps",
|
128 |
+
"panjabi": "pa",
|
129 |
+
"moldavian": "ro",
|
130 |
+
"moldovan": "ro",
|
131 |
+
"sinhalese": "si",
|
132 |
+
"castilian": "es",
|
133 |
+
"mandarin": "zh",
|
134 |
+
}
|
135 |
+
|
136 |
+
AUDIO_EVENT = {
|
137 |
+
"ASR": "ASR",
|
138 |
+
"AED": "AED",
|
139 |
+
"SER": "SER",
|
140 |
+
"Speech": "Speech",
|
141 |
+
"/Speech": "/Speech",
|
142 |
+
"BGM": "BGM",
|
143 |
+
"/BGM": "/BGM",
|
144 |
+
"Laughter": "Laughter",
|
145 |
+
"/Laughter": "/Laughter",
|
146 |
+
"Applause": "Applause",
|
147 |
+
"/Applause": "/Applause",
|
148 |
+
}
|
149 |
+
|
150 |
+
EMOTION = {
|
151 |
+
"HAPPY": "HAPPY",
|
152 |
+
"SAD": "SAD",
|
153 |
+
"ANGRY": "ANGRY",
|
154 |
+
"NEUTRAL": "NEUTRAL",
|
155 |
+
}
|
156 |
+
|
157 |
+
TTS_Vocal_Token = {
|
158 |
+
"TTS/B": "TTS/B",
|
159 |
+
"TTS/O": "TTS/O",
|
160 |
+
"TTS/Q": "TTS/Q",
|
161 |
+
"TTS/A": "TTS/A",
|
162 |
+
"TTS/CO": "TTS/CO",
|
163 |
+
"TTS/CL": "TTS/CL",
|
164 |
+
"TTS/H": "TTS/H",
|
165 |
+
**{f"TTS/SP{i:02d}": f"TTS/SP{i:02d}" for i in range(1, 14)}
|
166 |
+
}
|
167 |
+
|
168 |
+
|
169 |
+
@lru_cache(maxsize=None)
|
170 |
+
def get_encoding(name: str = "gpt2", num_languages: int = 99):
|
171 |
+
vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
|
172 |
+
ranks = {
|
173 |
+
base64.b64decode(token): int(rank)
|
174 |
+
for token, rank in (line.split() for line in open(vocab_path) if line)
|
175 |
+
}
|
176 |
+
n_vocab = len(ranks)
|
177 |
+
special_tokens = {}
|
178 |
+
|
179 |
+
specials = [
|
180 |
+
"<|endoftext|>",
|
181 |
+
"<|startoftranscript|>",
|
182 |
+
*[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
|
183 |
+
*[f"<|{audio_event}|>" for audio_event in list(AUDIO_EVENT.keys())],
|
184 |
+
*[f"<|{emotion}|>" for emotion in list(EMOTION.keys())],
|
185 |
+
"<|translate|>",
|
186 |
+
"<|transcribe|>",
|
187 |
+
"<|startoflm|>",
|
188 |
+
"<|startofprev|>",
|
189 |
+
"<|nospeech|>",
|
190 |
+
"<|notimestamps|>",
|
191 |
+
*[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 31)], # register special tokens for ASR
|
192 |
+
*[f"<|{tts}|>" for tts in list(TTS_Vocal_Token.keys())], # register special tokens for TTS
|
193 |
+
*[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
|
194 |
+
]
|
195 |
+
|
196 |
+
for token in specials:
|
197 |
+
special_tokens[token] = n_vocab
|
198 |
+
n_vocab += 1
|
199 |
+
|
200 |
+
return tiktoken.Encoding(
|
201 |
+
name=os.path.basename(vocab_path),
|
202 |
+
explicit_n_vocab=n_vocab,
|
203 |
+
pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
|
204 |
+
mergeable_ranks=ranks,
|
205 |
+
special_tokens=special_tokens,
|
206 |
+
)
|
207 |
+
|
208 |
+
|
209 |
+
@lru_cache(maxsize=None)
|
210 |
+
def get_tokenizer(
|
211 |
+
multilingual: bool,
|
212 |
+
*,
|
213 |
+
num_languages: int = 99,
|
214 |
+
language: Optional[str] = None,
|
215 |
+
task: Optional[str] = None, # Literal["transcribe", "translate", None]
|
216 |
+
) -> Tokenizer:
|
217 |
+
if language is not None:
|
218 |
+
language = language.lower()
|
219 |
+
if language not in LANGUAGES:
|
220 |
+
if language in TO_LANGUAGE_CODE:
|
221 |
+
language = TO_LANGUAGE_CODE[language]
|
222 |
+
else:
|
223 |
+
raise ValueError(f"Unsupported language: {language}")
|
224 |
+
|
225 |
+
if multilingual:
|
226 |
+
encoding_name = "multilingual_zh_ja_yue_char_del"
|
227 |
+
language = language or "en"
|
228 |
+
task = task or "transcribe"
|
229 |
+
else:
|
230 |
+
encoding_name = "gpt2"
|
231 |
+
language = None
|
232 |
+
task = None
|
233 |
+
|
234 |
+
encoding = get_encoding(name=encoding_name, num_languages=num_languages)
|
235 |
+
|
236 |
+
return Tokenizer(
|
237 |
+
encoding=encoding, num_languages=num_languages, language=language, task=task
|
238 |
+
)
|
239 |
+
|
240 |
+
|
241 |
+
class QwenTokenizer():
|
242 |
+
def __init__(self, token_path, skip_special_tokens=True):
|
243 |
+
super().__init__()
|
244 |
+
# NOTE: non-chat model, all these special tokens keep randomly initialized.
|
245 |
+
special_tokens = {
|
246 |
+
'eos_token': '<|endoftext|>',
|
247 |
+
'pad_token': '<|endoftext|>',
|
248 |
+
'additional_special_tokens': [
|
249 |
+
'<|im_start|>', '<|im_end|>', '<|endofprompt|>',
|
250 |
+
'[breath]', '<strong>', '</strong>', '[noise]',
|
251 |
+
'[laughter]', '[cough]', '[clucking]', '[accent]',
|
252 |
+
'[quick_breath]',
|
253 |
+
"<laughter>", "</laughter>",
|
254 |
+
"[hissing]", "[sigh]", "[vocalized-noise]",
|
255 |
+
"[lipsmack]", "[mn]"
|
256 |
+
]
|
257 |
+
}
|
258 |
+
self.tokenizer = AutoTokenizer.from_pretrained(token_path)
|
259 |
+
self.tokenizer.add_special_tokens(special_tokens)
|
260 |
+
self.skip_special_tokens = skip_special_tokens
|
261 |
+
|
262 |
+
def encode(self, text, **kwargs):
|
263 |
+
tokens = self.tokenizer([text], return_tensors="pt")
|
264 |
+
tokens = tokens["input_ids"][0].cpu().tolist()
|
265 |
+
return tokens
|
266 |
+
|
267 |
+
def decode(self, tokens):
|
268 |
+
tokens = torch.tensor(tokens, dtype=torch.int64)
|
269 |
+
text = self.tokenizer.batch_decode([tokens], skip_special_tokens=self.skip_special_tokens)[0]
|
270 |
+
return text
|
271 |
+
|
272 |
+
@lru_cache(maxsize=None)
|
273 |
+
def get_qwen_tokenizer(
|
274 |
+
token_path: str,
|
275 |
+
skip_special_tokens: bool
|
276 |
+
) -> QwenTokenizer:
|
277 |
+
return QwenTokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
|
cosyvoice/transformer/__init__.py
ADDED
File without changes
|
cosyvoice/transformer/activation.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe)
|
2 |
+
# 2020 Northwestern Polytechnical University (Pengcheng Guo)
|
3 |
+
# 2020 Mobvoi Inc (Binbin Zhang)
|
4 |
+
# 2024 Alibaba Inc (Xiang Lyu)
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
"""Swish() activation function for Conformer."""
|
18 |
+
|
19 |
+
import torch
|
20 |
+
from torch import nn, sin, pow
|
21 |
+
from torch.nn import Parameter
|
22 |
+
|
23 |
+
|
24 |
+
class Swish(torch.nn.Module):
|
25 |
+
"""Construct an Swish object."""
|
26 |
+
|
27 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
28 |
+
"""Return Swish activation function."""
|
29 |
+
return x * torch.sigmoid(x)
|
30 |
+
|
31 |
+
|
32 |
+
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
|
33 |
+
# LICENSE is in incl_licenses directory.
|
34 |
+
class Snake(nn.Module):
|
35 |
+
'''
|
36 |
+
Implementation of a sine-based periodic activation function
|
37 |
+
Shape:
|
38 |
+
- Input: (B, C, T)
|
39 |
+
- Output: (B, C, T), same shape as the input
|
40 |
+
Parameters:
|
41 |
+
- alpha - trainable parameter
|
42 |
+
References:
|
43 |
+
- This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
44 |
+
https://arxiv.org/abs/2006.08195
|
45 |
+
Examples:
|
46 |
+
>>> a1 = snake(256)
|
47 |
+
>>> x = torch.randn(256)
|
48 |
+
>>> x = a1(x)
|
49 |
+
'''
|
50 |
+
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
51 |
+
'''
|
52 |
+
Initialization.
|
53 |
+
INPUT:
|
54 |
+
- in_features: shape of the input
|
55 |
+
- alpha: trainable parameter
|
56 |
+
alpha is initialized to 1 by default, higher values = higher-frequency.
|
57 |
+
alpha will be trained along with the rest of your model.
|
58 |
+
'''
|
59 |
+
super(Snake, self).__init__()
|
60 |
+
self.in_features = in_features
|
61 |
+
|
62 |
+
# initialize alpha
|
63 |
+
self.alpha_logscale = alpha_logscale
|
64 |
+
if self.alpha_logscale: # log scale alphas initialized to zeros
|
65 |
+
self.alpha = Parameter(torch.zeros(in_features) * alpha)
|
66 |
+
else: # linear scale alphas initialized to ones
|
67 |
+
self.alpha = Parameter(torch.ones(in_features) * alpha)
|
68 |
+
|
69 |
+
self.alpha.requires_grad = alpha_trainable
|
70 |
+
|
71 |
+
self.no_div_by_zero = 0.000000001
|
72 |
+
|
73 |
+
def forward(self, x):
|
74 |
+
'''
|
75 |
+
Forward pass of the function.
|
76 |
+
Applies the function to the input elementwise.
|
77 |
+
Snake ∶= x + 1/a * sin^2 (xa)
|
78 |
+
'''
|
79 |
+
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
80 |
+
if self.alpha_logscale:
|
81 |
+
alpha = torch.exp(alpha)
|
82 |
+
x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
83 |
+
|
84 |
+
return x
|
cosyvoice/transformer/attention.py
ADDED
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2019 Shigeki Karita
|
2 |
+
# 2020 Mobvoi Inc (Binbin Zhang)
|
3 |
+
# 2022 Xingchen Song ([email protected])
|
4 |
+
# 2024 Alibaba Inc (Xiang Lyu)
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
"""Multi-Head Attention layer definition."""
|
18 |
+
|
19 |
+
import math
|
20 |
+
from typing import Tuple
|
21 |
+
|
22 |
+
import torch
|
23 |
+
from torch import nn
|
24 |
+
|
25 |
+
|
26 |
+
class MultiHeadedAttention(nn.Module):
|
27 |
+
"""Multi-Head Attention layer.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
n_head (int): The number of heads.
|
31 |
+
n_feat (int): The number of features.
|
32 |
+
dropout_rate (float): Dropout rate.
|
33 |
+
|
34 |
+
"""
|
35 |
+
|
36 |
+
def __init__(self,
|
37 |
+
n_head: int,
|
38 |
+
n_feat: int,
|
39 |
+
dropout_rate: float,
|
40 |
+
key_bias: bool = True):
|
41 |
+
"""Construct an MultiHeadedAttention object."""
|
42 |
+
super().__init__()
|
43 |
+
assert n_feat % n_head == 0
|
44 |
+
# We assume d_v always equals d_k
|
45 |
+
self.d_k = n_feat // n_head
|
46 |
+
self.h = n_head
|
47 |
+
self.linear_q = nn.Linear(n_feat, n_feat)
|
48 |
+
self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
|
49 |
+
self.linear_v = nn.Linear(n_feat, n_feat)
|
50 |
+
self.linear_out = nn.Linear(n_feat, n_feat)
|
51 |
+
self.dropout = nn.Dropout(p=dropout_rate)
|
52 |
+
|
53 |
+
def forward_qkv(
|
54 |
+
self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
|
55 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
56 |
+
"""Transform query, key and value.
|
57 |
+
|
58 |
+
Args:
|
59 |
+
query (torch.Tensor): Query tensor (#batch, time1, size).
|
60 |
+
key (torch.Tensor): Key tensor (#batch, time2, size).
|
61 |
+
value (torch.Tensor): Value tensor (#batch, time2, size).
|
62 |
+
|
63 |
+
Returns:
|
64 |
+
torch.Tensor: Transformed query tensor, size
|
65 |
+
(#batch, n_head, time1, d_k).
|
66 |
+
torch.Tensor: Transformed key tensor, size
|
67 |
+
(#batch, n_head, time2, d_k).
|
68 |
+
torch.Tensor: Transformed value tensor, size
|
69 |
+
(#batch, n_head, time2, d_k).
|
70 |
+
|
71 |
+
"""
|
72 |
+
n_batch = query.size(0)
|
73 |
+
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
|
74 |
+
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
|
75 |
+
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
|
76 |
+
q = q.transpose(1, 2) # (batch, head, time1, d_k)
|
77 |
+
k = k.transpose(1, 2) # (batch, head, time2, d_k)
|
78 |
+
v = v.transpose(1, 2) # (batch, head, time2, d_k)
|
79 |
+
|
80 |
+
return q, k, v
|
81 |
+
|
82 |
+
def forward_attention(
|
83 |
+
self,
|
84 |
+
value: torch.Tensor,
|
85 |
+
scores: torch.Tensor,
|
86 |
+
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
|
87 |
+
) -> torch.Tensor:
|
88 |
+
"""Compute attention context vector.
|
89 |
+
|
90 |
+
Args:
|
91 |
+
value (torch.Tensor): Transformed value, size
|
92 |
+
(#batch, n_head, time2, d_k).
|
93 |
+
scores (torch.Tensor): Attention score, size
|
94 |
+
(#batch, n_head, time1, time2).
|
95 |
+
mask (torch.Tensor): Mask, size (#batch, 1, time2) or
|
96 |
+
(#batch, time1, time2), (0, 0, 0) means fake mask.
|
97 |
+
|
98 |
+
Returns:
|
99 |
+
torch.Tensor: Transformed value (#batch, time1, d_model)
|
100 |
+
weighted by the attention score (#batch, time1, time2).
|
101 |
+
|
102 |
+
"""
|
103 |
+
n_batch = value.size(0)
|
104 |
+
# NOTE(xcsong): When will `if mask.size(2) > 0` be True?
|
105 |
+
# 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
|
106 |
+
# 1st chunk to ease the onnx export.]
|
107 |
+
# 2. pytorch training
|
108 |
+
if mask.size(2) > 0: # time2 > 0
|
109 |
+
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
|
110 |
+
# For last chunk, time2 might be larger than scores.size(-1)
|
111 |
+
mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
|
112 |
+
scores = scores.masked_fill(mask, -float('inf'))
|
113 |
+
attn = torch.softmax(scores, dim=-1).masked_fill(
|
114 |
+
mask, 0.0) # (batch, head, time1, time2)
|
115 |
+
# NOTE(xcsong): When will `if mask.size(2) > 0` be False?
|
116 |
+
# 1. onnx(16/-1, -1/-1, 16/0)
|
117 |
+
# 2. jit (16/-1, -1/-1, 16/0, 16/4)
|
118 |
+
else:
|
119 |
+
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
|
120 |
+
|
121 |
+
p_attn = self.dropout(attn)
|
122 |
+
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
|
123 |
+
x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
|
124 |
+
self.h * self.d_k)
|
125 |
+
) # (batch, time1, d_model)
|
126 |
+
|
127 |
+
return self.linear_out(x) # (batch, time1, d_model)
|
128 |
+
|
129 |
+
def forward(
|
130 |
+
self,
|
131 |
+
query: torch.Tensor,
|
132 |
+
key: torch.Tensor,
|
133 |
+
value: torch.Tensor,
|
134 |
+
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
135 |
+
pos_emb: torch.Tensor = torch.empty(0),
|
136 |
+
cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
|
137 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
138 |
+
"""Compute scaled dot product attention.
|
139 |
+
|
140 |
+
Args:
|
141 |
+
query (torch.Tensor): Query tensor (#batch, time1, size).
|
142 |
+
key (torch.Tensor): Key tensor (#batch, time2, size).
|
143 |
+
value (torch.Tensor): Value tensor (#batch, time2, size).
|
144 |
+
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
|
145 |
+
(#batch, time1, time2).
|
146 |
+
1.When applying cross attention between decoder and encoder,
|
147 |
+
the batch padding mask for input is in (#batch, 1, T) shape.
|
148 |
+
2.When applying self attention of encoder,
|
149 |
+
the mask is in (#batch, T, T) shape.
|
150 |
+
3.When applying self attention of decoder,
|
151 |
+
the mask is in (#batch, L, L) shape.
|
152 |
+
4.If the different position in decoder see different block
|
153 |
+
of the encoder, such as Mocha, the passed in mask could be
|
154 |
+
in (#batch, L, T) shape. But there is no such case in current
|
155 |
+
CosyVoice.
|
156 |
+
cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
|
157 |
+
where `cache_t == chunk_size * num_decoding_left_chunks`
|
158 |
+
and `head * d_k == size`
|
159 |
+
|
160 |
+
|
161 |
+
Returns:
|
162 |
+
torch.Tensor: Output tensor (#batch, time1, d_model).
|
163 |
+
torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
|
164 |
+
where `cache_t == chunk_size * num_decoding_left_chunks`
|
165 |
+
and `head * d_k == size`
|
166 |
+
|
167 |
+
"""
|
168 |
+
q, k, v = self.forward_qkv(query, key, value)
|
169 |
+
|
170 |
+
# NOTE(xcsong):
|
171 |
+
# when export onnx model, for 1st chunk, we feed
|
172 |
+
# cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
|
173 |
+
# or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
|
174 |
+
# In all modes, `if cache.size(0) > 0` will alwayse be `True`
|
175 |
+
# and we will always do splitting and
|
176 |
+
# concatnation(this will simplify onnx export). Note that
|
177 |
+
# it's OK to concat & split zero-shaped tensors(see code below).
|
178 |
+
# when export jit model, for 1st chunk, we always feed
|
179 |
+
# cache(0, 0, 0, 0) since jit supports dynamic if-branch.
|
180 |
+
# >>> a = torch.ones((1, 2, 0, 4))
|
181 |
+
# >>> b = torch.ones((1, 2, 3, 4))
|
182 |
+
# >>> c = torch.cat((a, b), dim=2)
|
183 |
+
# >>> torch.equal(b, c) # True
|
184 |
+
# >>> d = torch.split(a, 2, dim=-1)
|
185 |
+
# >>> torch.equal(d[0], d[1]) # True
|
186 |
+
if cache.size(0) > 0:
|
187 |
+
key_cache, value_cache = torch.split(cache,
|
188 |
+
cache.size(-1) // 2,
|
189 |
+
dim=-1)
|
190 |
+
k = torch.cat([key_cache, k], dim=2)
|
191 |
+
v = torch.cat([value_cache, v], dim=2)
|
192 |
+
# NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
|
193 |
+
# non-trivial to calculate `next_cache_start` here.
|
194 |
+
new_cache = torch.cat((k, v), dim=-1)
|
195 |
+
|
196 |
+
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
|
197 |
+
return self.forward_attention(v, scores, mask), new_cache
|
198 |
+
|
199 |
+
|
200 |
+
class RelPositionMultiHeadedAttention(MultiHeadedAttention):
|
201 |
+
"""Multi-Head Attention layer with relative position encoding.
|
202 |
+
Paper: https://arxiv.org/abs/1901.02860
|
203 |
+
Args:
|
204 |
+
n_head (int): The number of heads.
|
205 |
+
n_feat (int): The number of features.
|
206 |
+
dropout_rate (float): Dropout rate.
|
207 |
+
"""
|
208 |
+
|
209 |
+
def __init__(self,
|
210 |
+
n_head: int,
|
211 |
+
n_feat: int,
|
212 |
+
dropout_rate: float,
|
213 |
+
key_bias: bool = True):
|
214 |
+
"""Construct an RelPositionMultiHeadedAttention object."""
|
215 |
+
super().__init__(n_head, n_feat, dropout_rate, key_bias)
|
216 |
+
# linear transformation for positional encoding
|
217 |
+
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
|
218 |
+
# these two learnable bias are used in matrix c and matrix d
|
219 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
220 |
+
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
221 |
+
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
222 |
+
torch.nn.init.xavier_uniform_(self.pos_bias_u)
|
223 |
+
torch.nn.init.xavier_uniform_(self.pos_bias_v)
|
224 |
+
|
225 |
+
def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
|
226 |
+
"""Compute relative positional encoding.
|
227 |
+
|
228 |
+
Args:
|
229 |
+
x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
|
230 |
+
time1 means the length of query vector.
|
231 |
+
|
232 |
+
Returns:
|
233 |
+
torch.Tensor: Output tensor.
|
234 |
+
|
235 |
+
"""
|
236 |
+
zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
|
237 |
+
device=x.device,
|
238 |
+
dtype=x.dtype)
|
239 |
+
x_padded = torch.cat([zero_pad, x], dim=-1)
|
240 |
+
|
241 |
+
x_padded = x_padded.view(x.size()[0],
|
242 |
+
x.size()[1],
|
243 |
+
x.size(3) + 1, x.size(2))
|
244 |
+
x = x_padded[:, :, 1:].view_as(x)[
|
245 |
+
:, :, :, : x.size(-1) // 2 + 1
|
246 |
+
] # only keep the positions from 0 to time2
|
247 |
+
return x
|
248 |
+
|
249 |
+
def forward(
|
250 |
+
self,
|
251 |
+
query: torch.Tensor,
|
252 |
+
key: torch.Tensor,
|
253 |
+
value: torch.Tensor,
|
254 |
+
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
255 |
+
pos_emb: torch.Tensor = torch.empty(0),
|
256 |
+
cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
|
257 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
258 |
+
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
|
259 |
+
Args:
|
260 |
+
query (torch.Tensor): Query tensor (#batch, time1, size).
|
261 |
+
key (torch.Tensor): Key tensor (#batch, time2, size).
|
262 |
+
value (torch.Tensor): Value tensor (#batch, time2, size).
|
263 |
+
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
|
264 |
+
(#batch, time1, time2), (0, 0, 0) means fake mask.
|
265 |
+
pos_emb (torch.Tensor): Positional embedding tensor
|
266 |
+
(#batch, time2, size).
|
267 |
+
cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
|
268 |
+
where `cache_t == chunk_size * num_decoding_left_chunks`
|
269 |
+
and `head * d_k == size`
|
270 |
+
Returns:
|
271 |
+
torch.Tensor: Output tensor (#batch, time1, d_model).
|
272 |
+
torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
|
273 |
+
where `cache_t == chunk_size * num_decoding_left_chunks`
|
274 |
+
and `head * d_k == size`
|
275 |
+
"""
|
276 |
+
q, k, v = self.forward_qkv(query, key, value)
|
277 |
+
q = q.transpose(1, 2) # (batch, time1, head, d_k)
|
278 |
+
|
279 |
+
# NOTE(xcsong):
|
280 |
+
# when export onnx model, for 1st chunk, we feed
|
281 |
+
# cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
|
282 |
+
# or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
|
283 |
+
# In all modes, `if cache.size(0) > 0` will alwayse be `True`
|
284 |
+
# and we will always do splitting and
|
285 |
+
# concatnation(this will simplify onnx export). Note that
|
286 |
+
# it's OK to concat & split zero-shaped tensors(see code below).
|
287 |
+
# when export jit model, for 1st chunk, we always feed
|
288 |
+
# cache(0, 0, 0, 0) since jit supports dynamic if-branch.
|
289 |
+
# >>> a = torch.ones((1, 2, 0, 4))
|
290 |
+
# >>> b = torch.ones((1, 2, 3, 4))
|
291 |
+
# >>> c = torch.cat((a, b), dim=2)
|
292 |
+
# >>> torch.equal(b, c) # True
|
293 |
+
# >>> d = torch.split(a, 2, dim=-1)
|
294 |
+
# >>> torch.equal(d[0], d[1]) # True
|
295 |
+
if cache.size(0) > 0:
|
296 |
+
key_cache, value_cache = torch.split(cache,
|
297 |
+
cache.size(-1) // 2,
|
298 |
+
dim=-1)
|
299 |
+
k = torch.cat([key_cache, k], dim=2)
|
300 |
+
v = torch.cat([value_cache, v], dim=2)
|
301 |
+
# NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
|
302 |
+
# non-trivial to calculate `next_cache_start` here.
|
303 |
+
new_cache = torch.cat((k, v), dim=-1)
|
304 |
+
|
305 |
+
n_batch_pos = pos_emb.size(0)
|
306 |
+
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
|
307 |
+
p = p.transpose(1, 2) # (batch, head, time1, d_k)
|
308 |
+
|
309 |
+
# (batch, head, time1, d_k)
|
310 |
+
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
|
311 |
+
# (batch, head, time1, d_k)
|
312 |
+
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
|
313 |
+
|
314 |
+
# compute attention score
|
315 |
+
# first compute matrix a and matrix c
|
316 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
317 |
+
# (batch, head, time1, time2)
|
318 |
+
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
|
319 |
+
|
320 |
+
# compute matrix b and matrix d
|
321 |
+
# (batch, head, time1, time2)
|
322 |
+
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
|
323 |
+
# NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
|
324 |
+
if matrix_ac.shape != matrix_bd.shape:
|
325 |
+
matrix_bd = self.rel_shift(matrix_bd)
|
326 |
+
|
327 |
+
scores = (matrix_ac + matrix_bd) / math.sqrt(
|
328 |
+
self.d_k) # (batch, head, time1, time2)
|
329 |
+
|
330 |
+
return self.forward_attention(v, scores, mask), new_cache
|
cosyvoice/transformer/convolution.py
ADDED
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
|
2 |
+
# 2024 Alibaba Inc (Xiang Lyu)
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# Modified from ESPnet(https://github.com/espnet/espnet)
|
16 |
+
"""ConvolutionModule definition."""
|
17 |
+
|
18 |
+
from typing import Tuple
|
19 |
+
|
20 |
+
import torch
|
21 |
+
from torch import nn
|
22 |
+
|
23 |
+
|
24 |
+
class ConvolutionModule(nn.Module):
|
25 |
+
"""ConvolutionModule in Conformer model."""
|
26 |
+
|
27 |
+
def __init__(self,
|
28 |
+
channels: int,
|
29 |
+
kernel_size: int = 15,
|
30 |
+
activation: nn.Module = nn.ReLU(),
|
31 |
+
norm: str = "batch_norm",
|
32 |
+
causal: bool = False,
|
33 |
+
bias: bool = True):
|
34 |
+
"""Construct an ConvolutionModule object.
|
35 |
+
Args:
|
36 |
+
channels (int): The number of channels of conv layers.
|
37 |
+
kernel_size (int): Kernel size of conv layers.
|
38 |
+
causal (int): Whether use causal convolution or not
|
39 |
+
"""
|
40 |
+
super().__init__()
|
41 |
+
|
42 |
+
self.pointwise_conv1 = nn.Conv1d(
|
43 |
+
channels,
|
44 |
+
2 * channels,
|
45 |
+
kernel_size=1,
|
46 |
+
stride=1,
|
47 |
+
padding=0,
|
48 |
+
bias=bias,
|
49 |
+
)
|
50 |
+
# self.lorder is used to distinguish if it's a causal convolution,
|
51 |
+
# if self.lorder > 0: it's a causal convolution, the input will be
|
52 |
+
# padded with self.lorder frames on the left in forward.
|
53 |
+
# else: it's a symmetrical convolution
|
54 |
+
if causal:
|
55 |
+
padding = 0
|
56 |
+
self.lorder = kernel_size - 1
|
57 |
+
else:
|
58 |
+
# kernel_size should be an odd number for none causal convolution
|
59 |
+
assert (kernel_size - 1) % 2 == 0
|
60 |
+
padding = (kernel_size - 1) // 2
|
61 |
+
self.lorder = 0
|
62 |
+
self.depthwise_conv = nn.Conv1d(
|
63 |
+
channels,
|
64 |
+
channels,
|
65 |
+
kernel_size,
|
66 |
+
stride=1,
|
67 |
+
padding=padding,
|
68 |
+
groups=channels,
|
69 |
+
bias=bias,
|
70 |
+
)
|
71 |
+
|
72 |
+
assert norm in ['batch_norm', 'layer_norm']
|
73 |
+
if norm == "batch_norm":
|
74 |
+
self.use_layer_norm = False
|
75 |
+
self.norm = nn.BatchNorm1d(channels)
|
76 |
+
else:
|
77 |
+
self.use_layer_norm = True
|
78 |
+
self.norm = nn.LayerNorm(channels)
|
79 |
+
|
80 |
+
self.pointwise_conv2 = nn.Conv1d(
|
81 |
+
channels,
|
82 |
+
channels,
|
83 |
+
kernel_size=1,
|
84 |
+
stride=1,
|
85 |
+
padding=0,
|
86 |
+
bias=bias,
|
87 |
+
)
|
88 |
+
self.activation = activation
|
89 |
+
|
90 |
+
def forward(
|
91 |
+
self,
|
92 |
+
x: torch.Tensor,
|
93 |
+
mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
94 |
+
cache: torch.Tensor = torch.zeros((0, 0, 0)),
|
95 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
96 |
+
"""Compute convolution module.
|
97 |
+
Args:
|
98 |
+
x (torch.Tensor): Input tensor (#batch, time, channels).
|
99 |
+
mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
|
100 |
+
(0, 0, 0) means fake mask.
|
101 |
+
cache (torch.Tensor): left context cache, it is only
|
102 |
+
used in causal convolution (#batch, channels, cache_t),
|
103 |
+
(0, 0, 0) meas fake cache.
|
104 |
+
Returns:
|
105 |
+
torch.Tensor: Output tensor (#batch, time, channels).
|
106 |
+
"""
|
107 |
+
# exchange the temporal dimension and the feature dimension
|
108 |
+
x = x.transpose(1, 2) # (#batch, channels, time)
|
109 |
+
|
110 |
+
# mask batch padding
|
111 |
+
if mask_pad.size(2) > 0: # time > 0
|
112 |
+
x.masked_fill_(~mask_pad, 0.0)
|
113 |
+
|
114 |
+
if self.lorder > 0:
|
115 |
+
if cache.size(2) == 0: # cache_t == 0
|
116 |
+
x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
|
117 |
+
else:
|
118 |
+
assert cache.size(0) == x.size(0) # equal batch
|
119 |
+
assert cache.size(1) == x.size(1) # equal channel
|
120 |
+
x = torch.cat((cache, x), dim=2)
|
121 |
+
assert (x.size(2) > self.lorder)
|
122 |
+
new_cache = x[:, :, -self.lorder:]
|
123 |
+
else:
|
124 |
+
# It's better we just return None if no cache is required,
|
125 |
+
# However, for JIT export, here we just fake one tensor instead of
|
126 |
+
# None.
|
127 |
+
new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
|
128 |
+
|
129 |
+
# GLU mechanism
|
130 |
+
x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
|
131 |
+
x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
|
132 |
+
|
133 |
+
# 1D Depthwise Conv
|
134 |
+
x = self.depthwise_conv(x)
|
135 |
+
if self.use_layer_norm:
|
136 |
+
x = x.transpose(1, 2)
|
137 |
+
x = self.activation(self.norm(x))
|
138 |
+
if self.use_layer_norm:
|
139 |
+
x = x.transpose(1, 2)
|
140 |
+
x = self.pointwise_conv2(x)
|
141 |
+
# mask batch padding
|
142 |
+
if mask_pad.size(2) > 0: # time > 0
|
143 |
+
x.masked_fill_(~mask_pad, 0.0)
|
144 |
+
|
145 |
+
return x.transpose(1, 2), new_cache
|
cosyvoice/transformer/decoder.py
ADDED
@@ -0,0 +1,396 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
|
2 |
+
# 2024 Alibaba Inc (Xiang Lyu)
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# Modified from ESPnet(https://github.com/espnet/espnet)
|
16 |
+
"""Decoder definition."""
|
17 |
+
from typing import Tuple, List, Optional
|
18 |
+
|
19 |
+
import torch
|
20 |
+
import torch.utils.checkpoint as ckpt
|
21 |
+
import logging
|
22 |
+
|
23 |
+
from cosyvoice.transformer.decoder_layer import DecoderLayer
|
24 |
+
from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward
|
25 |
+
from cosyvoice.utils.class_utils import (
|
26 |
+
COSYVOICE_EMB_CLASSES,
|
27 |
+
COSYVOICE_ATTENTION_CLASSES,
|
28 |
+
COSYVOICE_ACTIVATION_CLASSES,
|
29 |
+
)
|
30 |
+
from cosyvoice.utils.mask import (subsequent_mask, make_pad_mask)
|
31 |
+
|
32 |
+
|
33 |
+
class TransformerDecoder(torch.nn.Module):
|
34 |
+
"""Base class of Transfomer decoder module.
|
35 |
+
Args:
|
36 |
+
vocab_size: output dim
|
37 |
+
encoder_output_size: dimension of attention
|
38 |
+
attention_heads: the number of heads of multi head attention
|
39 |
+
linear_units: the hidden units number of position-wise feedforward
|
40 |
+
num_blocks: the number of decoder blocks
|
41 |
+
dropout_rate: dropout rate
|
42 |
+
self_attention_dropout_rate: dropout rate for attention
|
43 |
+
input_layer: input layer type
|
44 |
+
use_output_layer: whether to use output layer
|
45 |
+
pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
|
46 |
+
normalize_before:
|
47 |
+
True: use layer_norm before each sub-block of a layer.
|
48 |
+
False: use layer_norm after each sub-block of a layer.
|
49 |
+
src_attention: if false, encoder-decoder cross attention is not
|
50 |
+
applied, such as CIF model
|
51 |
+
key_bias: whether use bias in attention.linear_k, False for whisper models.
|
52 |
+
gradient_checkpointing: rerunning a forward-pass segment for each
|
53 |
+
checkpointed segment during backward.
|
54 |
+
tie_word_embedding: Tie or clone module weights depending of whether we are
|
55 |
+
using TorchScript or not
|
56 |
+
"""
|
57 |
+
|
58 |
+
def __init__(
|
59 |
+
self,
|
60 |
+
vocab_size: int,
|
61 |
+
encoder_output_size: int,
|
62 |
+
attention_heads: int = 4,
|
63 |
+
linear_units: int = 2048,
|
64 |
+
num_blocks: int = 6,
|
65 |
+
dropout_rate: float = 0.1,
|
66 |
+
positional_dropout_rate: float = 0.1,
|
67 |
+
self_attention_dropout_rate: float = 0.0,
|
68 |
+
src_attention_dropout_rate: float = 0.0,
|
69 |
+
input_layer: str = "embed",
|
70 |
+
use_output_layer: bool = True,
|
71 |
+
normalize_before: bool = True,
|
72 |
+
src_attention: bool = True,
|
73 |
+
key_bias: bool = True,
|
74 |
+
activation_type: str = "relu",
|
75 |
+
gradient_checkpointing: bool = False,
|
76 |
+
tie_word_embedding: bool = False,
|
77 |
+
):
|
78 |
+
super().__init__()
|
79 |
+
attention_dim = encoder_output_size
|
80 |
+
activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
|
81 |
+
|
82 |
+
self.embed = torch.nn.Sequential(
|
83 |
+
torch.nn.Identity() if input_layer == "no_pos" else
|
84 |
+
torch.nn.Embedding(vocab_size, attention_dim),
|
85 |
+
COSYVOICE_EMB_CLASSES[input_layer](attention_dim,
|
86 |
+
positional_dropout_rate),
|
87 |
+
)
|
88 |
+
|
89 |
+
self.normalize_before = normalize_before
|
90 |
+
self.after_norm = torch.nn.LayerNorm(attention_dim, eps=1e-5)
|
91 |
+
self.use_output_layer = use_output_layer
|
92 |
+
if use_output_layer:
|
93 |
+
self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
|
94 |
+
else:
|
95 |
+
self.output_layer = torch.nn.Identity()
|
96 |
+
self.num_blocks = num_blocks
|
97 |
+
self.decoders = torch.nn.ModuleList([
|
98 |
+
DecoderLayer(
|
99 |
+
attention_dim,
|
100 |
+
COSYVOICE_ATTENTION_CLASSES["selfattn"](
|
101 |
+
attention_heads, attention_dim,
|
102 |
+
self_attention_dropout_rate, key_bias),
|
103 |
+
COSYVOICE_ATTENTION_CLASSES["selfattn"](
|
104 |
+
attention_heads, attention_dim, src_attention_dropout_rate,
|
105 |
+
key_bias) if src_attention else None,
|
106 |
+
PositionwiseFeedForward(attention_dim, linear_units,
|
107 |
+
dropout_rate, activation),
|
108 |
+
dropout_rate,
|
109 |
+
normalize_before,
|
110 |
+
) for _ in range(self.num_blocks)
|
111 |
+
])
|
112 |
+
|
113 |
+
self.gradient_checkpointing = gradient_checkpointing
|
114 |
+
self.tie_word_embedding = tie_word_embedding
|
115 |
+
|
116 |
+
def forward(
|
117 |
+
self,
|
118 |
+
memory: torch.Tensor,
|
119 |
+
memory_mask: torch.Tensor,
|
120 |
+
ys_in_pad: torch.Tensor,
|
121 |
+
ys_in_lens: torch.Tensor,
|
122 |
+
r_ys_in_pad: torch.Tensor = torch.empty(0),
|
123 |
+
reverse_weight: float = 0.0,
|
124 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
125 |
+
"""Forward decoder.
|
126 |
+
Args:
|
127 |
+
memory: encoded memory, float32 (batch, maxlen_in, feat)
|
128 |
+
memory_mask: encoder memory mask, (batch, 1, maxlen_in)
|
129 |
+
ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
|
130 |
+
ys_in_lens: input lengths of this batch (batch)
|
131 |
+
r_ys_in_pad: not used in transformer decoder, in order to unify api
|
132 |
+
with bidirectional decoder
|
133 |
+
reverse_weight: not used in transformer decoder, in order to unify
|
134 |
+
api with bidirectional decode
|
135 |
+
Returns:
|
136 |
+
(tuple): tuple containing:
|
137 |
+
x: decoded token score before softmax (batch, maxlen_out,
|
138 |
+
vocab_size) if use_output_layer is True,
|
139 |
+
torch.tensor(0.0), in order to unify api with bidirectional decoder
|
140 |
+
olens: (batch, )
|
141 |
+
NOTE(xcsong):
|
142 |
+
We pass the `__call__` method of the modules instead of `forward` to the
|
143 |
+
checkpointing API because `__call__` attaches all the hooks of the module.
|
144 |
+
https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
|
145 |
+
"""
|
146 |
+
tgt = ys_in_pad
|
147 |
+
maxlen = tgt.size(1)
|
148 |
+
# tgt_mask: (B, 1, L)
|
149 |
+
tgt_mask = ~make_pad_mask(ys_in_lens, maxlen).unsqueeze(1)
|
150 |
+
tgt_mask = tgt_mask.to(tgt.device)
|
151 |
+
# m: (1, L, L)
|
152 |
+
m = subsequent_mask(tgt_mask.size(-1),
|
153 |
+
device=tgt_mask.device).unsqueeze(0)
|
154 |
+
# tgt_mask: (B, L, L)
|
155 |
+
tgt_mask = tgt_mask & m
|
156 |
+
x, _ = self.embed(tgt)
|
157 |
+
if self.gradient_checkpointing and self.training:
|
158 |
+
x = self.forward_layers_checkpointed(x, tgt_mask, memory,
|
159 |
+
memory_mask)
|
160 |
+
else:
|
161 |
+
x = self.forward_layers(x, tgt_mask, memory, memory_mask)
|
162 |
+
if self.normalize_before:
|
163 |
+
x = self.after_norm(x)
|
164 |
+
if self.use_output_layer:
|
165 |
+
x = self.output_layer(x)
|
166 |
+
olens = tgt_mask.sum(1)
|
167 |
+
return x, torch.tensor(0.0), olens
|
168 |
+
|
169 |
+
def forward_layers(self, x: torch.Tensor, tgt_mask: torch.Tensor,
|
170 |
+
memory: torch.Tensor,
|
171 |
+
memory_mask: torch.Tensor) -> torch.Tensor:
|
172 |
+
for layer in self.decoders:
|
173 |
+
x, tgt_mask, memory, memory_mask = layer(x, tgt_mask, memory,
|
174 |
+
memory_mask)
|
175 |
+
return x
|
176 |
+
|
177 |
+
@torch.jit.unused
|
178 |
+
def forward_layers_checkpointed(self, x: torch.Tensor,
|
179 |
+
tgt_mask: torch.Tensor,
|
180 |
+
memory: torch.Tensor,
|
181 |
+
memory_mask: torch.Tensor) -> torch.Tensor:
|
182 |
+
for layer in self.decoders:
|
183 |
+
x, tgt_mask, memory, memory_mask = ckpt.checkpoint(
|
184 |
+
layer.__call__, x, tgt_mask, memory, memory_mask)
|
185 |
+
return x
|
186 |
+
|
187 |
+
def forward_one_step(
|
188 |
+
self,
|
189 |
+
memory: torch.Tensor,
|
190 |
+
memory_mask: torch.Tensor,
|
191 |
+
tgt: torch.Tensor,
|
192 |
+
tgt_mask: torch.Tensor,
|
193 |
+
cache: Optional[List[torch.Tensor]] = None,
|
194 |
+
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
195 |
+
"""Forward one step.
|
196 |
+
This is only used for decoding.
|
197 |
+
Args:
|
198 |
+
memory: encoded memory, float32 (batch, maxlen_in, feat)
|
199 |
+
memory_mask: encoded memory mask, (batch, 1, maxlen_in)
|
200 |
+
tgt: input token ids, int64 (batch, maxlen_out)
|
201 |
+
tgt_mask: input token mask, (batch, maxlen_out)
|
202 |
+
dtype=torch.uint8 in PyTorch 1.2-
|
203 |
+
dtype=torch.bool in PyTorch 1.2+ (include 1.2)
|
204 |
+
cache: cached output list of (batch, max_time_out-1, size)
|
205 |
+
Returns:
|
206 |
+
y, cache: NN output value and cache per `self.decoders`.
|
207 |
+
y.shape` is (batch, maxlen_out, token)
|
208 |
+
"""
|
209 |
+
x, _ = self.embed(tgt)
|
210 |
+
new_cache = []
|
211 |
+
for i, decoder in enumerate(self.decoders):
|
212 |
+
if cache is None:
|
213 |
+
c = None
|
214 |
+
else:
|
215 |
+
c = cache[i]
|
216 |
+
x, tgt_mask, memory, memory_mask = decoder(x,
|
217 |
+
tgt_mask,
|
218 |
+
memory,
|
219 |
+
memory_mask,
|
220 |
+
cache=c)
|
221 |
+
new_cache.append(x)
|
222 |
+
if self.normalize_before:
|
223 |
+
y = self.after_norm(x[:, -1])
|
224 |
+
else:
|
225 |
+
y = x[:, -1]
|
226 |
+
if self.use_output_layer:
|
227 |
+
y = torch.log_softmax(self.output_layer(y), dim=-1)
|
228 |
+
return y, new_cache
|
229 |
+
|
230 |
+
def tie_or_clone_weights(self, jit_mode: bool = True):
|
231 |
+
"""Tie or clone module weights (between word_emb and output_layer)
|
232 |
+
depending of whether we are using TorchScript or not"""
|
233 |
+
if not self.use_output_layer:
|
234 |
+
return
|
235 |
+
if jit_mode:
|
236 |
+
logging.info("clone emb.weight to output.weight")
|
237 |
+
self.output_layer.weight = torch.nn.Parameter(
|
238 |
+
self.embed[0].weight.clone())
|
239 |
+
else:
|
240 |
+
logging.info("tie emb.weight with output.weight")
|
241 |
+
self.output_layer.weight = self.embed[0].weight
|
242 |
+
|
243 |
+
if getattr(self.output_layer, "bias", None) is not None:
|
244 |
+
self.output_layer.bias.data = torch.nn.functional.pad(
|
245 |
+
self.output_layer.bias.data,
|
246 |
+
(
|
247 |
+
0,
|
248 |
+
self.output_layer.weight.shape[0] -
|
249 |
+
self.output_layer.bias.shape[0],
|
250 |
+
),
|
251 |
+
"constant",
|
252 |
+
0,
|
253 |
+
)
|
254 |
+
|
255 |
+
|
256 |
+
class BiTransformerDecoder(torch.nn.Module):
|
257 |
+
"""Base class of Transfomer decoder module.
|
258 |
+
Args:
|
259 |
+
vocab_size: output dim
|
260 |
+
encoder_output_size: dimension of attention
|
261 |
+
attention_heads: the number of heads of multi head attention
|
262 |
+
linear_units: the hidden units number of position-wise feedforward
|
263 |
+
num_blocks: the number of decoder blocks
|
264 |
+
r_num_blocks: the number of right to left decoder blocks
|
265 |
+
dropout_rate: dropout rate
|
266 |
+
self_attention_dropout_rate: dropout rate for attention
|
267 |
+
input_layer: input layer type
|
268 |
+
use_output_layer: whether to use output layer
|
269 |
+
pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
|
270 |
+
normalize_before:
|
271 |
+
True: use layer_norm before each sub-block of a layer.
|
272 |
+
False: use layer_norm after each sub-block of a layer.
|
273 |
+
key_bias: whether use bias in attention.linear_k, False for whisper models.
|
274 |
+
"""
|
275 |
+
|
276 |
+
def __init__(
|
277 |
+
self,
|
278 |
+
vocab_size: int,
|
279 |
+
encoder_output_size: int,
|
280 |
+
attention_heads: int = 4,
|
281 |
+
linear_units: int = 2048,
|
282 |
+
num_blocks: int = 6,
|
283 |
+
r_num_blocks: int = 0,
|
284 |
+
dropout_rate: float = 0.1,
|
285 |
+
positional_dropout_rate: float = 0.1,
|
286 |
+
self_attention_dropout_rate: float = 0.0,
|
287 |
+
src_attention_dropout_rate: float = 0.0,
|
288 |
+
input_layer: str = "embed",
|
289 |
+
use_output_layer: bool = True,
|
290 |
+
normalize_before: bool = True,
|
291 |
+
key_bias: bool = True,
|
292 |
+
gradient_checkpointing: bool = False,
|
293 |
+
tie_word_embedding: bool = False,
|
294 |
+
):
|
295 |
+
|
296 |
+
super().__init__()
|
297 |
+
self.tie_word_embedding = tie_word_embedding
|
298 |
+
self.left_decoder = TransformerDecoder(
|
299 |
+
vocab_size,
|
300 |
+
encoder_output_size,
|
301 |
+
attention_heads,
|
302 |
+
linear_units,
|
303 |
+
num_blocks,
|
304 |
+
dropout_rate,
|
305 |
+
positional_dropout_rate,
|
306 |
+
self_attention_dropout_rate,
|
307 |
+
src_attention_dropout_rate,
|
308 |
+
input_layer,
|
309 |
+
use_output_layer,
|
310 |
+
normalize_before,
|
311 |
+
key_bias=key_bias,
|
312 |
+
gradient_checkpointing=gradient_checkpointing,
|
313 |
+
tie_word_embedding=tie_word_embedding)
|
314 |
+
|
315 |
+
self.right_decoder = TransformerDecoder(
|
316 |
+
vocab_size,
|
317 |
+
encoder_output_size,
|
318 |
+
attention_heads,
|
319 |
+
linear_units,
|
320 |
+
r_num_blocks,
|
321 |
+
dropout_rate,
|
322 |
+
positional_dropout_rate,
|
323 |
+
self_attention_dropout_rate,
|
324 |
+
src_attention_dropout_rate,
|
325 |
+
input_layer,
|
326 |
+
use_output_layer,
|
327 |
+
normalize_before,
|
328 |
+
key_bias=key_bias,
|
329 |
+
gradient_checkpointing=gradient_checkpointing,
|
330 |
+
tie_word_embedding=tie_word_embedding)
|
331 |
+
|
332 |
+
def forward(
|
333 |
+
self,
|
334 |
+
memory: torch.Tensor,
|
335 |
+
memory_mask: torch.Tensor,
|
336 |
+
ys_in_pad: torch.Tensor,
|
337 |
+
ys_in_lens: torch.Tensor,
|
338 |
+
r_ys_in_pad: torch.Tensor,
|
339 |
+
reverse_weight: float = 0.0,
|
340 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
341 |
+
"""Forward decoder.
|
342 |
+
Args:
|
343 |
+
memory: encoded memory, float32 (batch, maxlen_in, feat)
|
344 |
+
memory_mask: encoder memory mask, (batch, 1, maxlen_in)
|
345 |
+
ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
|
346 |
+
ys_in_lens: input lengths of this batch (batch)
|
347 |
+
r_ys_in_pad: padded input token ids, int64 (batch, maxlen_out),
|
348 |
+
used for right to left decoder
|
349 |
+
reverse_weight: used for right to left decoder
|
350 |
+
Returns:
|
351 |
+
(tuple): tuple containing:
|
352 |
+
x: decoded token score before softmax (batch, maxlen_out,
|
353 |
+
vocab_size) if use_output_layer is True,
|
354 |
+
r_x: x: decoded token score (right to left decoder)
|
355 |
+
before softmax (batch, maxlen_out, vocab_size)
|
356 |
+
if use_output_layer is True,
|
357 |
+
olens: (batch, )
|
358 |
+
"""
|
359 |
+
l_x, _, olens = self.left_decoder(memory, memory_mask, ys_in_pad,
|
360 |
+
ys_in_lens)
|
361 |
+
r_x = torch.tensor(0.0)
|
362 |
+
if reverse_weight > 0.0:
|
363 |
+
r_x, _, olens = self.right_decoder(memory, memory_mask,
|
364 |
+
r_ys_in_pad, ys_in_lens)
|
365 |
+
return l_x, r_x, olens
|
366 |
+
|
367 |
+
def forward_one_step(
|
368 |
+
self,
|
369 |
+
memory: torch.Tensor,
|
370 |
+
memory_mask: torch.Tensor,
|
371 |
+
tgt: torch.Tensor,
|
372 |
+
tgt_mask: torch.Tensor,
|
373 |
+
cache: Optional[List[torch.Tensor]] = None,
|
374 |
+
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
375 |
+
"""Forward one step.
|
376 |
+
This is only used for decoding.
|
377 |
+
Args:
|
378 |
+
memory: encoded memory, float32 (batch, maxlen_in, feat)
|
379 |
+
memory_mask: encoded memory mask, (batch, 1, maxlen_in)
|
380 |
+
tgt: input token ids, int64 (batch, maxlen_out)
|
381 |
+
tgt_mask: input token mask, (batch, maxlen_out)
|
382 |
+
dtype=torch.uint8 in PyTorch 1.2-
|
383 |
+
dtype=torch.bool in PyTorch 1.2+ (include 1.2)
|
384 |
+
cache: cached output list of (batch, max_time_out-1, size)
|
385 |
+
Returns:
|
386 |
+
y, cache: NN output value and cache per `self.decoders`.
|
387 |
+
y.shape` is (batch, maxlen_out, token)
|
388 |
+
"""
|
389 |
+
return self.left_decoder.forward_one_step(memory, memory_mask, tgt,
|
390 |
+
tgt_mask, cache)
|
391 |
+
|
392 |
+
def tie_or_clone_weights(self, jit_mode: bool = True):
|
393 |
+
"""Tie or clone module weights (between word_emb and output_layer)
|
394 |
+
depending of whether we are using TorchScript or not"""
|
395 |
+
self.left_decoder.tie_or_clone_weights(jit_mode)
|
396 |
+
self.right_decoder.tie_or_clone_weights(jit_mode)
|
cosyvoice/transformer/decoder_layer.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2019 Shigeki Karita
|
2 |
+
# 2020 Mobvoi Inc (Binbin Zhang)
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""Decoder self-attention layer definition."""
|
16 |
+
from typing import Optional, Tuple
|
17 |
+
|
18 |
+
import torch
|
19 |
+
from torch import nn
|
20 |
+
|
21 |
+
|
22 |
+
class DecoderLayer(nn.Module):
|
23 |
+
"""Single decoder layer module.
|
24 |
+
|
25 |
+
Args:
|
26 |
+
size (int): Input dimension.
|
27 |
+
self_attn (torch.nn.Module): Self-attention module instance.
|
28 |
+
`MultiHeadedAttention` instance can be used as the argument.
|
29 |
+
src_attn (torch.nn.Module): Inter-attention module instance.
|
30 |
+
`MultiHeadedAttention` instance can be used as the argument.
|
31 |
+
If `None` is passed, Inter-attention is not used, such as
|
32 |
+
CIF, GPT, and other decoder only model.
|
33 |
+
feed_forward (torch.nn.Module): Feed-forward module instance.
|
34 |
+
`PositionwiseFeedForward` instance can be used as the argument.
|
35 |
+
dropout_rate (float): Dropout rate.
|
36 |
+
normalize_before (bool):
|
37 |
+
True: use layer_norm before each sub-block.
|
38 |
+
False: to use layer_norm after each sub-block.
|
39 |
+
"""
|
40 |
+
|
41 |
+
def __init__(
|
42 |
+
self,
|
43 |
+
size: int,
|
44 |
+
self_attn: nn.Module,
|
45 |
+
src_attn: Optional[nn.Module],
|
46 |
+
feed_forward: nn.Module,
|
47 |
+
dropout_rate: float,
|
48 |
+
normalize_before: bool = True,
|
49 |
+
):
|
50 |
+
"""Construct an DecoderLayer object."""
|
51 |
+
super().__init__()
|
52 |
+
self.size = size
|
53 |
+
self.self_attn = self_attn
|
54 |
+
self.src_attn = src_attn
|
55 |
+
self.feed_forward = feed_forward
|
56 |
+
self.norm1 = nn.LayerNorm(size, eps=1e-5)
|
57 |
+
self.norm2 = nn.LayerNorm(size, eps=1e-5)
|
58 |
+
self.norm3 = nn.LayerNorm(size, eps=1e-5)
|
59 |
+
self.dropout = nn.Dropout(dropout_rate)
|
60 |
+
self.normalize_before = normalize_before
|
61 |
+
|
62 |
+
def forward(
|
63 |
+
self,
|
64 |
+
tgt: torch.Tensor,
|
65 |
+
tgt_mask: torch.Tensor,
|
66 |
+
memory: torch.Tensor,
|
67 |
+
memory_mask: torch.Tensor,
|
68 |
+
cache: Optional[torch.Tensor] = None
|
69 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
70 |
+
"""Compute decoded features.
|
71 |
+
|
72 |
+
Args:
|
73 |
+
tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
|
74 |
+
tgt_mask (torch.Tensor): Mask for input tensor
|
75 |
+
(#batch, maxlen_out).
|
76 |
+
memory (torch.Tensor): Encoded memory
|
77 |
+
(#batch, maxlen_in, size).
|
78 |
+
memory_mask (torch.Tensor): Encoded memory mask
|
79 |
+
(#batch, maxlen_in).
|
80 |
+
cache (torch.Tensor): cached tensors.
|
81 |
+
(#batch, maxlen_out - 1, size).
|
82 |
+
|
83 |
+
Returns:
|
84 |
+
torch.Tensor: Output tensor (#batch, maxlen_out, size).
|
85 |
+
torch.Tensor: Mask for output tensor (#batch, maxlen_out).
|
86 |
+
torch.Tensor: Encoded memory (#batch, maxlen_in, size).
|
87 |
+
torch.Tensor: Encoded memory mask (#batch, maxlen_in).
|
88 |
+
|
89 |
+
"""
|
90 |
+
residual = tgt
|
91 |
+
if self.normalize_before:
|
92 |
+
tgt = self.norm1(tgt)
|
93 |
+
|
94 |
+
if cache is None:
|
95 |
+
tgt_q = tgt
|
96 |
+
tgt_q_mask = tgt_mask
|
97 |
+
else:
|
98 |
+
# compute only the last frame query keeping dim: max_time_out -> 1
|
99 |
+
assert cache.shape == (
|
100 |
+
tgt.shape[0],
|
101 |
+
tgt.shape[1] - 1,
|
102 |
+
self.size,
|
103 |
+
), "{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
|
104 |
+
tgt_q = tgt[:, -1:, :]
|
105 |
+
residual = residual[:, -1:, :]
|
106 |
+
tgt_q_mask = tgt_mask[:, -1:, :]
|
107 |
+
|
108 |
+
x = residual + self.dropout(
|
109 |
+
self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0])
|
110 |
+
if not self.normalize_before:
|
111 |
+
x = self.norm1(x)
|
112 |
+
|
113 |
+
if self.src_attn is not None:
|
114 |
+
residual = x
|
115 |
+
if self.normalize_before:
|
116 |
+
x = self.norm2(x)
|
117 |
+
x = residual + self.dropout(
|
118 |
+
self.src_attn(x, memory, memory, memory_mask)[0])
|
119 |
+
if not self.normalize_before:
|
120 |
+
x = self.norm2(x)
|
121 |
+
|
122 |
+
residual = x
|
123 |
+
if self.normalize_before:
|
124 |
+
x = self.norm3(x)
|
125 |
+
x = residual + self.dropout(self.feed_forward(x))
|
126 |
+
if not self.normalize_before:
|
127 |
+
x = self.norm3(x)
|
128 |
+
|
129 |
+
if cache is not None:
|
130 |
+
x = torch.cat([cache, x], dim=1)
|
131 |
+
|
132 |
+
return x, tgt_mask, memory, memory_mask
|
cosyvoice/transformer/embedding.py
ADDED
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
|
2 |
+
# 2024 Alibaba Inc (Xiang Lyu)
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# Modified from ESPnet(https://github.com/espnet/espnet)
|
16 |
+
"""Positonal Encoding Module."""
|
17 |
+
|
18 |
+
import math
|
19 |
+
from typing import Tuple, Union
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import torch.nn.functional as F
|
23 |
+
import numpy as np
|
24 |
+
|
25 |
+
|
26 |
+
class PositionalEncoding(torch.nn.Module):
|
27 |
+
"""Positional encoding.
|
28 |
+
|
29 |
+
:param int d_model: embedding dim
|
30 |
+
:param float dropout_rate: dropout rate
|
31 |
+
:param int max_len: maximum input length
|
32 |
+
|
33 |
+
PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
|
34 |
+
PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
|
35 |
+
"""
|
36 |
+
|
37 |
+
def __init__(self,
|
38 |
+
d_model: int,
|
39 |
+
dropout_rate: float,
|
40 |
+
max_len: int = 5000,
|
41 |
+
reverse: bool = False):
|
42 |
+
"""Construct an PositionalEncoding object."""
|
43 |
+
super().__init__()
|
44 |
+
self.d_model = d_model
|
45 |
+
self.xscale = math.sqrt(self.d_model)
|
46 |
+
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
47 |
+
self.max_len = max_len
|
48 |
+
|
49 |
+
self.pe = torch.zeros(self.max_len, self.d_model)
|
50 |
+
position = torch.arange(0, self.max_len,
|
51 |
+
dtype=torch.float32).unsqueeze(1)
|
52 |
+
div_term = torch.exp(
|
53 |
+
torch.arange(0, self.d_model, 2, dtype=torch.float32) *
|
54 |
+
-(math.log(10000.0) / self.d_model))
|
55 |
+
self.pe[:, 0::2] = torch.sin(position * div_term)
|
56 |
+
self.pe[:, 1::2] = torch.cos(position * div_term)
|
57 |
+
self.pe = self.pe.unsqueeze(0)
|
58 |
+
|
59 |
+
def forward(self,
|
60 |
+
x: torch.Tensor,
|
61 |
+
offset: Union[int, torch.Tensor] = 0) \
|
62 |
+
-> Tuple[torch.Tensor, torch.Tensor]:
|
63 |
+
"""Add positional encoding.
|
64 |
+
|
65 |
+
Args:
|
66 |
+
x (torch.Tensor): Input. Its shape is (batch, time, ...)
|
67 |
+
offset (int, torch.tensor): position offset
|
68 |
+
|
69 |
+
Returns:
|
70 |
+
torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
|
71 |
+
torch.Tensor: for compatibility to RelPositionalEncoding
|
72 |
+
"""
|
73 |
+
|
74 |
+
self.pe = self.pe.to(x.device)
|
75 |
+
pos_emb = self.position_encoding(offset, x.size(1), False)
|
76 |
+
x = x * self.xscale + pos_emb
|
77 |
+
return self.dropout(x), self.dropout(pos_emb)
|
78 |
+
|
79 |
+
def position_encoding(self,
|
80 |
+
offset: Union[int, torch.Tensor],
|
81 |
+
size: int,
|
82 |
+
apply_dropout: bool = True) -> torch.Tensor:
|
83 |
+
""" For getting encoding in a streaming fashion
|
84 |
+
|
85 |
+
Attention!!!!!
|
86 |
+
we apply dropout only once at the whole utterance level in a none
|
87 |
+
streaming way, but will call this function several times with
|
88 |
+
increasing input size in a streaming scenario, so the dropout will
|
89 |
+
be applied several times.
|
90 |
+
|
91 |
+
Args:
|
92 |
+
offset (int or torch.tensor): start offset
|
93 |
+
size (int): required size of position encoding
|
94 |
+
|
95 |
+
Returns:
|
96 |
+
torch.Tensor: Corresponding encoding
|
97 |
+
"""
|
98 |
+
# How to subscript a Union type:
|
99 |
+
# https://github.com/pytorch/pytorch/issues/69434
|
100 |
+
if isinstance(offset, int):
|
101 |
+
assert offset + size <= self.max_len
|
102 |
+
pos_emb = self.pe[:, offset:offset + size]
|
103 |
+
elif isinstance(offset, torch.Tensor) and offset.dim() == 0: # scalar
|
104 |
+
assert offset + size <= self.max_len
|
105 |
+
pos_emb = self.pe[:, offset:offset + size]
|
106 |
+
else: # for batched streaming decoding on GPU
|
107 |
+
assert torch.max(offset) + size <= self.max_len
|
108 |
+
index = offset.unsqueeze(1) + \
|
109 |
+
torch.arange(0, size).to(offset.device) # B X T
|
110 |
+
flag = index > 0
|
111 |
+
# remove negative offset
|
112 |
+
index = index * flag
|
113 |
+
pos_emb = F.embedding(index, self.pe[0]) # B X T X d_model
|
114 |
+
|
115 |
+
if apply_dropout:
|
116 |
+
pos_emb = self.dropout(pos_emb)
|
117 |
+
return pos_emb
|
118 |
+
|
119 |
+
|
120 |
+
class RelPositionalEncoding(PositionalEncoding):
|
121 |
+
"""Relative positional encoding module.
|
122 |
+
See : Appendix B in https://arxiv.org/abs/1901.02860
|
123 |
+
Args:
|
124 |
+
d_model (int): Embedding dimension.
|
125 |
+
dropout_rate (float): Dropout rate.
|
126 |
+
max_len (int): Maximum input length.
|
127 |
+
"""
|
128 |
+
|
129 |
+
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
|
130 |
+
"""Initialize class."""
|
131 |
+
super().__init__(d_model, dropout_rate, max_len, reverse=True)
|
132 |
+
|
133 |
+
def forward(self,
|
134 |
+
x: torch.Tensor,
|
135 |
+
offset: Union[int, torch.Tensor] = 0) \
|
136 |
+
-> Tuple[torch.Tensor, torch.Tensor]:
|
137 |
+
"""Compute positional encoding.
|
138 |
+
Args:
|
139 |
+
x (torch.Tensor): Input tensor (batch, time, `*`).
|
140 |
+
Returns:
|
141 |
+
torch.Tensor: Encoded tensor (batch, time, `*`).
|
142 |
+
torch.Tensor: Positional embedding tensor (1, time, `*`).
|
143 |
+
"""
|
144 |
+
self.pe = self.pe.to(x.device)
|
145 |
+
x = x * self.xscale
|
146 |
+
pos_emb = self.position_encoding(offset, x.size(1), False)
|
147 |
+
return self.dropout(x), self.dropout(pos_emb)
|
148 |
+
|
149 |
+
|
150 |
+
class WhisperPositionalEncoding(PositionalEncoding):
|
151 |
+
""" Sinusoids position encoding used in openai-whisper.encoder
|
152 |
+
"""
|
153 |
+
|
154 |
+
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 1500):
|
155 |
+
super().__init__(d_model, dropout_rate, max_len)
|
156 |
+
self.xscale = 1.0
|
157 |
+
log_timescale_increment = np.log(10000) / (d_model // 2 - 1)
|
158 |
+
inv_timescales = torch.exp(-log_timescale_increment *
|
159 |
+
torch.arange(d_model // 2))
|
160 |
+
scaled_time = torch.arange(max_len)[:, np.newaxis] * \
|
161 |
+
inv_timescales[np.newaxis, :]
|
162 |
+
pe = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
|
163 |
+
delattr(self, "pe")
|
164 |
+
self.register_buffer("pe", pe.unsqueeze(0))
|
165 |
+
|
166 |
+
|
167 |
+
class LearnablePositionalEncoding(PositionalEncoding):
|
168 |
+
""" Learnable position encoding used in openai-whisper.decoder
|
169 |
+
"""
|
170 |
+
|
171 |
+
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 448):
|
172 |
+
super().__init__(d_model, dropout_rate, max_len)
|
173 |
+
# NOTE(xcsong): overwrite self.pe & self.xscale
|
174 |
+
self.pe = torch.nn.Parameter(torch.empty(1, max_len, d_model))
|
175 |
+
self.xscale = 1.0
|
176 |
+
|
177 |
+
|
178 |
+
class NoPositionalEncoding(torch.nn.Module):
|
179 |
+
""" No position encoding
|
180 |
+
"""
|
181 |
+
|
182 |
+
def __init__(self, d_model: int, dropout_rate: float):
|
183 |
+
super().__init__()
|
184 |
+
self.d_model = d_model
|
185 |
+
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
186 |
+
|
187 |
+
def forward(self,
|
188 |
+
x: torch.Tensor,
|
189 |
+
offset: Union[int, torch.Tensor] = 0) \
|
190 |
+
-> Tuple[torch.Tensor, torch.Tensor]:
|
191 |
+
""" Just return zero vector for interface compatibility
|
192 |
+
"""
|
193 |
+
pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device)
|
194 |
+
return self.dropout(x), pos_emb
|
195 |
+
|
196 |
+
def position_encoding(self, offset: Union[int, torch.Tensor],
|
197 |
+
size: int) -> torch.Tensor:
|
198 |
+
return torch.zeros(1, size, self.d_model)
|
199 |
+
|
200 |
+
|
201 |
+
class EspnetRelPositionalEncoding(torch.nn.Module):
|
202 |
+
"""Relative positional encoding module (new implementation).
|
203 |
+
|
204 |
+
Details can be found in https://github.com/espnet/espnet/pull/2816.
|
205 |
+
|
206 |
+
See : Appendix B in https://arxiv.org/abs/1901.02860
|
207 |
+
|
208 |
+
Args:
|
209 |
+
d_model (int): Embedding dimension.
|
210 |
+
dropout_rate (float): Dropout rate.
|
211 |
+
max_len (int): Maximum input length.
|
212 |
+
|
213 |
+
"""
|
214 |
+
|
215 |
+
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
|
216 |
+
"""Construct an PositionalEncoding object."""
|
217 |
+
super(EspnetRelPositionalEncoding, self).__init__()
|
218 |
+
self.d_model = d_model
|
219 |
+
self.xscale = math.sqrt(self.d_model)
|
220 |
+
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
221 |
+
self.pe = None
|
222 |
+
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
223 |
+
|
224 |
+
def extend_pe(self, x: torch.Tensor):
|
225 |
+
"""Reset the positional encodings."""
|
226 |
+
if self.pe is not None:
|
227 |
+
# self.pe contains both positive and negative parts
|
228 |
+
# the length of self.pe is 2 * input_len - 1
|
229 |
+
if self.pe.size(1) >= x.size(1) * 2 - 1:
|
230 |
+
if self.pe.dtype != x.dtype or self.pe.device != x.device:
|
231 |
+
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
232 |
+
return
|
233 |
+
# Suppose `i` means to the position of query vecotr and `j` means the
|
234 |
+
# position of key vector. We use position relative positions when keys
|
235 |
+
# are to the left (i>j) and negative relative positions otherwise (i<j).
|
236 |
+
pe_positive = torch.zeros(x.size(1), self.d_model)
|
237 |
+
pe_negative = torch.zeros(x.size(1), self.d_model)
|
238 |
+
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
239 |
+
div_term = torch.exp(
|
240 |
+
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
241 |
+
* -(math.log(10000.0) / self.d_model)
|
242 |
+
)
|
243 |
+
pe_positive[:, 0::2] = torch.sin(position * div_term)
|
244 |
+
pe_positive[:, 1::2] = torch.cos(position * div_term)
|
245 |
+
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
|
246 |
+
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
|
247 |
+
|
248 |
+
# Reserve the order of positive indices and concat both positive and
|
249 |
+
# negative indices. This is used to support the shifting trick
|
250 |
+
# as in https://arxiv.org/abs/1901.02860
|
251 |
+
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
|
252 |
+
pe_negative = pe_negative[1:].unsqueeze(0)
|
253 |
+
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
254 |
+
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
255 |
+
|
256 |
+
def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) \
|
257 |
+
-> Tuple[torch.Tensor, torch.Tensor]:
|
258 |
+
"""Add positional encoding.
|
259 |
+
|
260 |
+
Args:
|
261 |
+
x (torch.Tensor): Input tensor (batch, time, `*`).
|
262 |
+
|
263 |
+
Returns:
|
264 |
+
torch.Tensor: Encoded tensor (batch, time, `*`).
|
265 |
+
|
266 |
+
"""
|
267 |
+
self.extend_pe(x)
|
268 |
+
x = x * self.xscale
|
269 |
+
pos_emb = self.position_encoding(size=x.size(1), offset=offset)
|
270 |
+
return self.dropout(x), self.dropout(pos_emb)
|
271 |
+
|
272 |
+
def position_encoding(self,
|
273 |
+
offset: Union[int, torch.Tensor],
|
274 |
+
size: int) -> torch.Tensor:
|
275 |
+
""" For getting encoding in a streaming fashion
|
276 |
+
|
277 |
+
Attention!!!!!
|
278 |
+
we apply dropout only once at the whole utterance level in a none
|
279 |
+
streaming way, but will call this function several times with
|
280 |
+
increasing input size in a streaming scenario, so the dropout will
|
281 |
+
be applied several times.
|
282 |
+
|
283 |
+
Args:
|
284 |
+
offset (int or torch.tensor): start offset
|
285 |
+
size (int): required size of position encoding
|
286 |
+
|
287 |
+
Returns:
|
288 |
+
torch.Tensor: Corresponding encoding
|
289 |
+
"""
|
290 |
+
pos_emb = self.pe[
|
291 |
+
:,
|
292 |
+
self.pe.size(1) // 2 - size + 1: self.pe.size(1) // 2 + size,
|
293 |
+
]
|
294 |
+
return pos_emb
|
cosyvoice/transformer/encoder.py
ADDED
@@ -0,0 +1,474 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
|
2 |
+
# 2022 Xingchen Song ([email protected])
|
3 |
+
# 2024 Alibaba Inc (Xiang Lyu)
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
# Modified from ESPnet(https://github.com/espnet/espnet)
|
17 |
+
"""Encoder definition."""
|
18 |
+
from typing import Tuple
|
19 |
+
|
20 |
+
import torch
|
21 |
+
import torch.utils.checkpoint as ckpt
|
22 |
+
|
23 |
+
from cosyvoice.transformer.convolution import ConvolutionModule
|
24 |
+
from cosyvoice.transformer.encoder_layer import TransformerEncoderLayer
|
25 |
+
from cosyvoice.transformer.encoder_layer import ConformerEncoderLayer
|
26 |
+
from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward
|
27 |
+
from cosyvoice.utils.class_utils import (
|
28 |
+
COSYVOICE_EMB_CLASSES,
|
29 |
+
COSYVOICE_SUBSAMPLE_CLASSES,
|
30 |
+
COSYVOICE_ATTENTION_CLASSES,
|
31 |
+
COSYVOICE_ACTIVATION_CLASSES,
|
32 |
+
)
|
33 |
+
from cosyvoice.utils.mask import make_pad_mask
|
34 |
+
from cosyvoice.utils.mask import add_optional_chunk_mask
|
35 |
+
|
36 |
+
|
37 |
+
class BaseEncoder(torch.nn.Module):
|
38 |
+
|
39 |
+
def __init__(
|
40 |
+
self,
|
41 |
+
input_size: int,
|
42 |
+
output_size: int = 256,
|
43 |
+
attention_heads: int = 4,
|
44 |
+
linear_units: int = 2048,
|
45 |
+
num_blocks: int = 6,
|
46 |
+
dropout_rate: float = 0.1,
|
47 |
+
positional_dropout_rate: float = 0.1,
|
48 |
+
attention_dropout_rate: float = 0.0,
|
49 |
+
input_layer: str = "conv2d",
|
50 |
+
pos_enc_layer_type: str = "abs_pos",
|
51 |
+
normalize_before: bool = True,
|
52 |
+
static_chunk_size: int = 0,
|
53 |
+
use_dynamic_chunk: bool = False,
|
54 |
+
global_cmvn: torch.nn.Module = None,
|
55 |
+
use_dynamic_left_chunk: bool = False,
|
56 |
+
gradient_checkpointing: bool = False,
|
57 |
+
):
|
58 |
+
"""
|
59 |
+
Args:
|
60 |
+
input_size (int): input dim
|
61 |
+
output_size (int): dimension of attention
|
62 |
+
attention_heads (int): the number of heads of multi head attention
|
63 |
+
linear_units (int): the hidden units number of position-wise feed
|
64 |
+
forward
|
65 |
+
num_blocks (int): the number of decoder blocks
|
66 |
+
dropout_rate (float): dropout rate
|
67 |
+
attention_dropout_rate (float): dropout rate in attention
|
68 |
+
positional_dropout_rate (float): dropout rate after adding
|
69 |
+
positional encoding
|
70 |
+
input_layer (str): input layer type.
|
71 |
+
optional [linear, conv2d, conv2d6, conv2d8]
|
72 |
+
pos_enc_layer_type (str): Encoder positional encoding layer type.
|
73 |
+
opitonal [abs_pos, scaled_abs_pos, rel_pos, no_pos]
|
74 |
+
normalize_before (bool):
|
75 |
+
True: use layer_norm before each sub-block of a layer.
|
76 |
+
False: use layer_norm after each sub-block of a layer.
|
77 |
+
static_chunk_size (int): chunk size for static chunk training and
|
78 |
+
decoding
|
79 |
+
use_dynamic_chunk (bool): whether use dynamic chunk size for
|
80 |
+
training or not, You can only use fixed chunk(chunk_size > 0)
|
81 |
+
or dyanmic chunk size(use_dynamic_chunk = True)
|
82 |
+
global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
|
83 |
+
use_dynamic_left_chunk (bool): whether use dynamic left chunk in
|
84 |
+
dynamic chunk training
|
85 |
+
key_bias: whether use bias in attention.linear_k, False for whisper models.
|
86 |
+
gradient_checkpointing: rerunning a forward-pass segment for each
|
87 |
+
checkpointed segment during backward.
|
88 |
+
"""
|
89 |
+
super().__init__()
|
90 |
+
self._output_size = output_size
|
91 |
+
|
92 |
+
self.global_cmvn = global_cmvn
|
93 |
+
self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
|
94 |
+
input_size,
|
95 |
+
output_size,
|
96 |
+
dropout_rate,
|
97 |
+
COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
|
98 |
+
positional_dropout_rate),
|
99 |
+
)
|
100 |
+
|
101 |
+
self.normalize_before = normalize_before
|
102 |
+
self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
|
103 |
+
self.static_chunk_size = static_chunk_size
|
104 |
+
self.use_dynamic_chunk = use_dynamic_chunk
|
105 |
+
self.use_dynamic_left_chunk = use_dynamic_left_chunk
|
106 |
+
self.gradient_checkpointing = gradient_checkpointing
|
107 |
+
|
108 |
+
def output_size(self) -> int:
|
109 |
+
return self._output_size
|
110 |
+
|
111 |
+
def forward(
|
112 |
+
self,
|
113 |
+
xs: torch.Tensor,
|
114 |
+
xs_lens: torch.Tensor,
|
115 |
+
decoding_chunk_size: int = 0,
|
116 |
+
num_decoding_left_chunks: int = -1,
|
117 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
118 |
+
"""Embed positions in tensor.
|
119 |
+
|
120 |
+
Args:
|
121 |
+
xs: padded input tensor (B, T, D)
|
122 |
+
xs_lens: input length (B)
|
123 |
+
decoding_chunk_size: decoding chunk size for dynamic chunk
|
124 |
+
0: default for training, use random dynamic chunk.
|
125 |
+
<0: for decoding, use full chunk.
|
126 |
+
>0: for decoding, use fixed chunk size as set.
|
127 |
+
num_decoding_left_chunks: number of left chunks, this is for decoding,
|
128 |
+
the chunk size is decoding_chunk_size.
|
129 |
+
>=0: use num_decoding_left_chunks
|
130 |
+
<0: use all left chunks
|
131 |
+
Returns:
|
132 |
+
encoder output tensor xs, and subsampled masks
|
133 |
+
xs: padded output tensor (B, T' ~= T/subsample_rate, D)
|
134 |
+
masks: torch.Tensor batch padding mask after subsample
|
135 |
+
(B, 1, T' ~= T/subsample_rate)
|
136 |
+
NOTE(xcsong):
|
137 |
+
We pass the `__call__` method of the modules instead of `forward` to the
|
138 |
+
checkpointing API because `__call__` attaches all the hooks of the module.
|
139 |
+
https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
|
140 |
+
"""
|
141 |
+
T = xs.size(1)
|
142 |
+
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
|
143 |
+
if self.global_cmvn is not None:
|
144 |
+
xs = self.global_cmvn(xs)
|
145 |
+
xs, pos_emb, masks = self.embed(xs, masks)
|
146 |
+
mask_pad = masks # (B, 1, T/subsample_rate)
|
147 |
+
chunk_masks = add_optional_chunk_mask(xs, masks,
|
148 |
+
self.use_dynamic_chunk,
|
149 |
+
self.use_dynamic_left_chunk,
|
150 |
+
decoding_chunk_size,
|
151 |
+
self.static_chunk_size,
|
152 |
+
num_decoding_left_chunks)
|
153 |
+
if self.gradient_checkpointing and self.training:
|
154 |
+
xs = self.forward_layers_checkpointed(xs, chunk_masks, pos_emb,
|
155 |
+
mask_pad)
|
156 |
+
else:
|
157 |
+
xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
|
158 |
+
if self.normalize_before:
|
159 |
+
xs = self.after_norm(xs)
|
160 |
+
# Here we assume the mask is not changed in encoder layers, so just
|
161 |
+
# return the masks before encoder layers, and the masks will be used
|
162 |
+
# for cross attention with decoder later
|
163 |
+
return xs, masks
|
164 |
+
|
165 |
+
def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
|
166 |
+
pos_emb: torch.Tensor,
|
167 |
+
mask_pad: torch.Tensor) -> torch.Tensor:
|
168 |
+
for layer in self.encoders:
|
169 |
+
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
|
170 |
+
return xs
|
171 |
+
|
172 |
+
@torch.jit.unused
|
173 |
+
def forward_layers_checkpointed(self, xs: torch.Tensor,
|
174 |
+
chunk_masks: torch.Tensor,
|
175 |
+
pos_emb: torch.Tensor,
|
176 |
+
mask_pad: torch.Tensor) -> torch.Tensor:
|
177 |
+
for layer in self.encoders:
|
178 |
+
xs, chunk_masks, _, _ = ckpt.checkpoint(layer.__call__, xs,
|
179 |
+
chunk_masks, pos_emb,
|
180 |
+
mask_pad)
|
181 |
+
return xs
|
182 |
+
|
183 |
+
@torch.jit.export
|
184 |
+
def forward_chunk(
|
185 |
+
self,
|
186 |
+
xs: torch.Tensor,
|
187 |
+
offset: int,
|
188 |
+
required_cache_size: int,
|
189 |
+
att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
|
190 |
+
cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
|
191 |
+
att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
192 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
193 |
+
""" Forward just one chunk
|
194 |
+
|
195 |
+
Args:
|
196 |
+
xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim),
|
197 |
+
where `time == (chunk_size - 1) * subsample_rate + \
|
198 |
+
subsample.right_context + 1`
|
199 |
+
offset (int): current offset in encoder output time stamp
|
200 |
+
required_cache_size (int): cache size required for next chunk
|
201 |
+
compuation
|
202 |
+
>=0: actual cache size
|
203 |
+
<0: means all history cache is required
|
204 |
+
att_cache (torch.Tensor): cache tensor for KEY & VALUE in
|
205 |
+
transformer/conformer attention, with shape
|
206 |
+
(elayers, head, cache_t1, d_k * 2), where
|
207 |
+
`head * d_k == hidden-dim` and
|
208 |
+
`cache_t1 == chunk_size * num_decoding_left_chunks`.
|
209 |
+
cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
|
210 |
+
(elayers, b=1, hidden-dim, cache_t2), where
|
211 |
+
`cache_t2 == cnn.lorder - 1`
|
212 |
+
|
213 |
+
Returns:
|
214 |
+
torch.Tensor: output of current input xs,
|
215 |
+
with shape (b=1, chunk_size, hidden-dim).
|
216 |
+
torch.Tensor: new attention cache required for next chunk, with
|
217 |
+
dynamic shape (elayers, head, ?, d_k * 2)
|
218 |
+
depending on required_cache_size.
|
219 |
+
torch.Tensor: new conformer cnn cache required for next chunk, with
|
220 |
+
same shape as the original cnn_cache.
|
221 |
+
|
222 |
+
"""
|
223 |
+
assert xs.size(0) == 1
|
224 |
+
# tmp_masks is just for interface compatibility
|
225 |
+
tmp_masks = torch.ones(1,
|
226 |
+
xs.size(1),
|
227 |
+
device=xs.device,
|
228 |
+
dtype=torch.bool)
|
229 |
+
tmp_masks = tmp_masks.unsqueeze(1)
|
230 |
+
if self.global_cmvn is not None:
|
231 |
+
xs = self.global_cmvn(xs)
|
232 |
+
# NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim)
|
233 |
+
xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
|
234 |
+
# NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim)
|
235 |
+
elayers, cache_t1 = att_cache.size(0), att_cache.size(2)
|
236 |
+
chunk_size = xs.size(1)
|
237 |
+
attention_key_size = cache_t1 + chunk_size
|
238 |
+
pos_emb = self.embed.position_encoding(offset=offset - cache_t1,
|
239 |
+
size=attention_key_size)
|
240 |
+
if required_cache_size < 0:
|
241 |
+
next_cache_start = 0
|
242 |
+
elif required_cache_size == 0:
|
243 |
+
next_cache_start = attention_key_size
|
244 |
+
else:
|
245 |
+
next_cache_start = max(attention_key_size - required_cache_size, 0)
|
246 |
+
r_att_cache = []
|
247 |
+
r_cnn_cache = []
|
248 |
+
for i, layer in enumerate(self.encoders):
|
249 |
+
# NOTE(xcsong): Before layer.forward
|
250 |
+
# shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2),
|
251 |
+
# shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2)
|
252 |
+
xs, _, new_att_cache, new_cnn_cache = layer(
|
253 |
+
xs,
|
254 |
+
att_mask,
|
255 |
+
pos_emb,
|
256 |
+
att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache,
|
257 |
+
cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache)
|
258 |
+
# NOTE(xcsong): After layer.forward
|
259 |
+
# shape(new_att_cache) is (1, head, attention_key_size, d_k * 2),
|
260 |
+
# shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2)
|
261 |
+
r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])
|
262 |
+
r_cnn_cache.append(new_cnn_cache.unsqueeze(0))
|
263 |
+
if self.normalize_before:
|
264 |
+
xs = self.after_norm(xs)
|
265 |
+
|
266 |
+
# NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2),
|
267 |
+
# ? may be larger than cache_t1, it depends on required_cache_size
|
268 |
+
r_att_cache = torch.cat(r_att_cache, dim=0)
|
269 |
+
# NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2)
|
270 |
+
r_cnn_cache = torch.cat(r_cnn_cache, dim=0)
|
271 |
+
|
272 |
+
return (xs, r_att_cache, r_cnn_cache)
|
273 |
+
|
274 |
+
@torch.jit.unused
|
275 |
+
def forward_chunk_by_chunk(
|
276 |
+
self,
|
277 |
+
xs: torch.Tensor,
|
278 |
+
decoding_chunk_size: int,
|
279 |
+
num_decoding_left_chunks: int = -1,
|
280 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
281 |
+
""" Forward input chunk by chunk with chunk_size like a streaming
|
282 |
+
fashion
|
283 |
+
|
284 |
+
Here we should pay special attention to computation cache in the
|
285 |
+
streaming style forward chunk by chunk. Three things should be taken
|
286 |
+
into account for computation in the current network:
|
287 |
+
1. transformer/conformer encoder layers output cache
|
288 |
+
2. convolution in conformer
|
289 |
+
3. convolution in subsampling
|
290 |
+
|
291 |
+
However, we don't implement subsampling cache for:
|
292 |
+
1. We can control subsampling module to output the right result by
|
293 |
+
overlapping input instead of cache left context, even though it
|
294 |
+
wastes some computation, but subsampling only takes a very
|
295 |
+
small fraction of computation in the whole model.
|
296 |
+
2. Typically, there are several covolution layers with subsampling
|
297 |
+
in subsampling module, it is tricky and complicated to do cache
|
298 |
+
with different convolution layers with different subsampling
|
299 |
+
rate.
|
300 |
+
3. Currently, nn.Sequential is used to stack all the convolution
|
301 |
+
layers in subsampling, we need to rewrite it to make it work
|
302 |
+
with cache, which is not preferred.
|
303 |
+
Args:
|
304 |
+
xs (torch.Tensor): (1, max_len, dim)
|
305 |
+
chunk_size (int): decoding chunk size
|
306 |
+
"""
|
307 |
+
assert decoding_chunk_size > 0
|
308 |
+
# The model is trained by static or dynamic chunk
|
309 |
+
assert self.static_chunk_size > 0 or self.use_dynamic_chunk
|
310 |
+
subsampling = self.embed.subsampling_rate
|
311 |
+
context = self.embed.right_context + 1 # Add current frame
|
312 |
+
stride = subsampling * decoding_chunk_size
|
313 |
+
decoding_window = (decoding_chunk_size - 1) * subsampling + context
|
314 |
+
num_frames = xs.size(1)
|
315 |
+
att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
|
316 |
+
cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
|
317 |
+
outputs = []
|
318 |
+
offset = 0
|
319 |
+
required_cache_size = decoding_chunk_size * num_decoding_left_chunks
|
320 |
+
|
321 |
+
# Feed forward overlap input step by step
|
322 |
+
for cur in range(0, num_frames - context + 1, stride):
|
323 |
+
end = min(cur + decoding_window, num_frames)
|
324 |
+
chunk_xs = xs[:, cur:end, :]
|
325 |
+
(y, att_cache,
|
326 |
+
cnn_cache) = self.forward_chunk(chunk_xs, offset,
|
327 |
+
required_cache_size, att_cache,
|
328 |
+
cnn_cache)
|
329 |
+
outputs.append(y)
|
330 |
+
offset += y.size(1)
|
331 |
+
ys = torch.cat(outputs, 1)
|
332 |
+
masks = torch.ones((1, 1, ys.size(1)),
|
333 |
+
device=ys.device,
|
334 |
+
dtype=torch.bool)
|
335 |
+
return ys, masks
|
336 |
+
|
337 |
+
|
338 |
+
class TransformerEncoder(BaseEncoder):
|
339 |
+
"""Transformer encoder module."""
|
340 |
+
|
341 |
+
def __init__(
|
342 |
+
self,
|
343 |
+
input_size: int,
|
344 |
+
output_size: int = 256,
|
345 |
+
attention_heads: int = 4,
|
346 |
+
linear_units: int = 2048,
|
347 |
+
num_blocks: int = 6,
|
348 |
+
dropout_rate: float = 0.1,
|
349 |
+
positional_dropout_rate: float = 0.1,
|
350 |
+
attention_dropout_rate: float = 0.0,
|
351 |
+
input_layer: str = "conv2d",
|
352 |
+
pos_enc_layer_type: str = "abs_pos",
|
353 |
+
normalize_before: bool = True,
|
354 |
+
static_chunk_size: int = 0,
|
355 |
+
use_dynamic_chunk: bool = False,
|
356 |
+
global_cmvn: torch.nn.Module = None,
|
357 |
+
use_dynamic_left_chunk: bool = False,
|
358 |
+
key_bias: bool = True,
|
359 |
+
selfattention_layer_type: str = "selfattn",
|
360 |
+
activation_type: str = "relu",
|
361 |
+
gradient_checkpointing: bool = False,
|
362 |
+
):
|
363 |
+
""" Construct TransformerEncoder
|
364 |
+
|
365 |
+
See Encoder for the meaning of each parameter.
|
366 |
+
"""
|
367 |
+
super().__init__(input_size, output_size, attention_heads,
|
368 |
+
linear_units, num_blocks, dropout_rate,
|
369 |
+
positional_dropout_rate, attention_dropout_rate,
|
370 |
+
input_layer, pos_enc_layer_type, normalize_before,
|
371 |
+
static_chunk_size, use_dynamic_chunk, global_cmvn,
|
372 |
+
use_dynamic_left_chunk, gradient_checkpointing)
|
373 |
+
activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
|
374 |
+
self.encoders = torch.nn.ModuleList([
|
375 |
+
TransformerEncoderLayer(
|
376 |
+
output_size,
|
377 |
+
COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](attention_heads,
|
378 |
+
output_size,
|
379 |
+
attention_dropout_rate,
|
380 |
+
key_bias),
|
381 |
+
PositionwiseFeedForward(output_size, linear_units,
|
382 |
+
dropout_rate, activation),
|
383 |
+
dropout_rate, normalize_before) for _ in range(num_blocks)
|
384 |
+
])
|
385 |
+
|
386 |
+
|
387 |
+
class ConformerEncoder(BaseEncoder):
|
388 |
+
"""Conformer encoder module."""
|
389 |
+
|
390 |
+
def __init__(
|
391 |
+
self,
|
392 |
+
input_size: int,
|
393 |
+
output_size: int = 256,
|
394 |
+
attention_heads: int = 4,
|
395 |
+
linear_units: int = 2048,
|
396 |
+
num_blocks: int = 6,
|
397 |
+
dropout_rate: float = 0.1,
|
398 |
+
positional_dropout_rate: float = 0.1,
|
399 |
+
attention_dropout_rate: float = 0.0,
|
400 |
+
input_layer: str = "conv2d",
|
401 |
+
pos_enc_layer_type: str = "rel_pos",
|
402 |
+
normalize_before: bool = True,
|
403 |
+
static_chunk_size: int = 0,
|
404 |
+
use_dynamic_chunk: bool = False,
|
405 |
+
global_cmvn: torch.nn.Module = None,
|
406 |
+
use_dynamic_left_chunk: bool = False,
|
407 |
+
positionwise_conv_kernel_size: int = 1,
|
408 |
+
macaron_style: bool = True,
|
409 |
+
selfattention_layer_type: str = "rel_selfattn",
|
410 |
+
activation_type: str = "swish",
|
411 |
+
use_cnn_module: bool = True,
|
412 |
+
cnn_module_kernel: int = 15,
|
413 |
+
causal: bool = False,
|
414 |
+
cnn_module_norm: str = "batch_norm",
|
415 |
+
key_bias: bool = True,
|
416 |
+
gradient_checkpointing: bool = False,
|
417 |
+
):
|
418 |
+
"""Construct ConformerEncoder
|
419 |
+
|
420 |
+
Args:
|
421 |
+
input_size to use_dynamic_chunk, see in BaseEncoder
|
422 |
+
positionwise_conv_kernel_size (int): Kernel size of positionwise
|
423 |
+
conv1d layer.
|
424 |
+
macaron_style (bool): Whether to use macaron style for
|
425 |
+
positionwise layer.
|
426 |
+
selfattention_layer_type (str): Encoder attention layer type,
|
427 |
+
the parameter has no effect now, it's just for configure
|
428 |
+
compatibility.
|
429 |
+
activation_type (str): Encoder activation function type.
|
430 |
+
use_cnn_module (bool): Whether to use convolution module.
|
431 |
+
cnn_module_kernel (int): Kernel size of convolution module.
|
432 |
+
causal (bool): whether to use causal convolution or not.
|
433 |
+
key_bias: whether use bias in attention.linear_k, False for whisper models.
|
434 |
+
"""
|
435 |
+
super().__init__(input_size, output_size, attention_heads,
|
436 |
+
linear_units, num_blocks, dropout_rate,
|
437 |
+
positional_dropout_rate, attention_dropout_rate,
|
438 |
+
input_layer, pos_enc_layer_type, normalize_before,
|
439 |
+
static_chunk_size, use_dynamic_chunk, global_cmvn,
|
440 |
+
use_dynamic_left_chunk, gradient_checkpointing)
|
441 |
+
activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
|
442 |
+
|
443 |
+
# self-attention module definition
|
444 |
+
encoder_selfattn_layer_args = (
|
445 |
+
attention_heads,
|
446 |
+
output_size,
|
447 |
+
attention_dropout_rate,
|
448 |
+
key_bias,
|
449 |
+
)
|
450 |
+
# feed-forward module definition
|
451 |
+
positionwise_layer_args = (
|
452 |
+
output_size,
|
453 |
+
linear_units,
|
454 |
+
dropout_rate,
|
455 |
+
activation,
|
456 |
+
)
|
457 |
+
# convolution module definition
|
458 |
+
convolution_layer_args = (output_size, cnn_module_kernel, activation,
|
459 |
+
cnn_module_norm, causal)
|
460 |
+
|
461 |
+
self.encoders = torch.nn.ModuleList([
|
462 |
+
ConformerEncoderLayer(
|
463 |
+
output_size,
|
464 |
+
COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
|
465 |
+
*encoder_selfattn_layer_args),
|
466 |
+
PositionwiseFeedForward(*positionwise_layer_args),
|
467 |
+
PositionwiseFeedForward(
|
468 |
+
*positionwise_layer_args) if macaron_style else None,
|
469 |
+
ConvolutionModule(
|
470 |
+
*convolution_layer_args) if use_cnn_module else None,
|
471 |
+
dropout_rate,
|
472 |
+
normalize_before,
|
473 |
+
) for _ in range(num_blocks)
|
474 |
+
])
|
cosyvoice/transformer/encoder_layer.py
ADDED
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
|
2 |
+
# 2022 Xingchen Song ([email protected])
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# Modified from ESPnet(https://github.com/espnet/espnet)
|
16 |
+
"""Encoder self-attention layer definition."""
|
17 |
+
|
18 |
+
from typing import Optional, Tuple
|
19 |
+
|
20 |
+
import torch
|
21 |
+
from torch import nn
|
22 |
+
|
23 |
+
|
24 |
+
class TransformerEncoderLayer(nn.Module):
|
25 |
+
"""Encoder layer module.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
size (int): Input dimension.
|
29 |
+
self_attn (torch.nn.Module): Self-attention module instance.
|
30 |
+
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
|
31 |
+
instance can be used as the argument.
|
32 |
+
feed_forward (torch.nn.Module): Feed-forward module instance.
|
33 |
+
`PositionwiseFeedForward`, instance can be used as the argument.
|
34 |
+
dropout_rate (float): Dropout rate.
|
35 |
+
normalize_before (bool):
|
36 |
+
True: use layer_norm before each sub-block.
|
37 |
+
False: to use layer_norm after each sub-block.
|
38 |
+
"""
|
39 |
+
|
40 |
+
def __init__(
|
41 |
+
self,
|
42 |
+
size: int,
|
43 |
+
self_attn: torch.nn.Module,
|
44 |
+
feed_forward: torch.nn.Module,
|
45 |
+
dropout_rate: float,
|
46 |
+
normalize_before: bool = True,
|
47 |
+
):
|
48 |
+
"""Construct an EncoderLayer object."""
|
49 |
+
super().__init__()
|
50 |
+
self.self_attn = self_attn
|
51 |
+
self.feed_forward = feed_forward
|
52 |
+
self.norm1 = nn.LayerNorm(size, eps=1e-12)
|
53 |
+
self.norm2 = nn.LayerNorm(size, eps=1e-12)
|
54 |
+
self.dropout = nn.Dropout(dropout_rate)
|
55 |
+
self.size = size
|
56 |
+
self.normalize_before = normalize_before
|
57 |
+
|
58 |
+
def forward(
|
59 |
+
self,
|
60 |
+
x: torch.Tensor,
|
61 |
+
mask: torch.Tensor,
|
62 |
+
pos_emb: torch.Tensor,
|
63 |
+
mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
64 |
+
att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
|
65 |
+
cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
|
66 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
67 |
+
"""Compute encoded features.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
x (torch.Tensor): (#batch, time, size)
|
71 |
+
mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
|
72 |
+
(0, 0, 0) means fake mask.
|
73 |
+
pos_emb (torch.Tensor): just for interface compatibility
|
74 |
+
to ConformerEncoderLayer
|
75 |
+
mask_pad (torch.Tensor): does not used in transformer layer,
|
76 |
+
just for unified api with conformer.
|
77 |
+
att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
|
78 |
+
(#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
|
79 |
+
cnn_cache (torch.Tensor): Convolution cache in conformer layer
|
80 |
+
(#batch=1, size, cache_t2), not used here, it's for interface
|
81 |
+
compatibility to ConformerEncoderLayer.
|
82 |
+
Returns:
|
83 |
+
torch.Tensor: Output tensor (#batch, time, size).
|
84 |
+
torch.Tensor: Mask tensor (#batch, time, time).
|
85 |
+
torch.Tensor: att_cache tensor,
|
86 |
+
(#batch=1, head, cache_t1 + time, d_k * 2).
|
87 |
+
torch.Tensor: cnn_cahce tensor (#batch=1, size, cache_t2).
|
88 |
+
|
89 |
+
"""
|
90 |
+
residual = x
|
91 |
+
if self.normalize_before:
|
92 |
+
x = self.norm1(x)
|
93 |
+
x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb=pos_emb, cache=att_cache)
|
94 |
+
x = residual + self.dropout(x_att)
|
95 |
+
if not self.normalize_before:
|
96 |
+
x = self.norm1(x)
|
97 |
+
|
98 |
+
residual = x
|
99 |
+
if self.normalize_before:
|
100 |
+
x = self.norm2(x)
|
101 |
+
x = residual + self.dropout(self.feed_forward(x))
|
102 |
+
if not self.normalize_before:
|
103 |
+
x = self.norm2(x)
|
104 |
+
|
105 |
+
fake_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
|
106 |
+
return x, mask, new_att_cache, fake_cnn_cache
|
107 |
+
|
108 |
+
|
109 |
+
class ConformerEncoderLayer(nn.Module):
|
110 |
+
"""Encoder layer module.
|
111 |
+
Args:
|
112 |
+
size (int): Input dimension.
|
113 |
+
self_attn (torch.nn.Module): Self-attention module instance.
|
114 |
+
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
|
115 |
+
instance can be used as the argument.
|
116 |
+
feed_forward (torch.nn.Module): Feed-forward module instance.
|
117 |
+
`PositionwiseFeedForward` instance can be used as the argument.
|
118 |
+
feed_forward_macaron (torch.nn.Module): Additional feed-forward module
|
119 |
+
instance.
|
120 |
+
`PositionwiseFeedForward` instance can be used as the argument.
|
121 |
+
conv_module (torch.nn.Module): Convolution module instance.
|
122 |
+
`ConvlutionModule` instance can be used as the argument.
|
123 |
+
dropout_rate (float): Dropout rate.
|
124 |
+
normalize_before (bool):
|
125 |
+
True: use layer_norm before each sub-block.
|
126 |
+
False: use layer_norm after each sub-block.
|
127 |
+
"""
|
128 |
+
|
129 |
+
def __init__(
|
130 |
+
self,
|
131 |
+
size: int,
|
132 |
+
self_attn: torch.nn.Module,
|
133 |
+
feed_forward: Optional[nn.Module] = None,
|
134 |
+
feed_forward_macaron: Optional[nn.Module] = None,
|
135 |
+
conv_module: Optional[nn.Module] = None,
|
136 |
+
dropout_rate: float = 0.1,
|
137 |
+
normalize_before: bool = True,
|
138 |
+
):
|
139 |
+
"""Construct an EncoderLayer object."""
|
140 |
+
super().__init__()
|
141 |
+
self.self_attn = self_attn
|
142 |
+
self.feed_forward = feed_forward
|
143 |
+
self.feed_forward_macaron = feed_forward_macaron
|
144 |
+
self.conv_module = conv_module
|
145 |
+
self.norm_ff = nn.LayerNorm(size, eps=1e-12) # for the FNN module
|
146 |
+
self.norm_mha = nn.LayerNorm(size, eps=1e-12) # for the MHA module
|
147 |
+
if feed_forward_macaron is not None:
|
148 |
+
self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-12)
|
149 |
+
self.ff_scale = 0.5
|
150 |
+
else:
|
151 |
+
self.ff_scale = 1.0
|
152 |
+
if self.conv_module is not None:
|
153 |
+
self.norm_conv = nn.LayerNorm(size, eps=1e-12) # for the CNN module
|
154 |
+
self.norm_final = nn.LayerNorm(
|
155 |
+
size, eps=1e-12) # for the final output of the block
|
156 |
+
self.dropout = nn.Dropout(dropout_rate)
|
157 |
+
self.size = size
|
158 |
+
self.normalize_before = normalize_before
|
159 |
+
|
160 |
+
def forward(
|
161 |
+
self,
|
162 |
+
x: torch.Tensor,
|
163 |
+
mask: torch.Tensor,
|
164 |
+
pos_emb: torch.Tensor,
|
165 |
+
mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
166 |
+
att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
|
167 |
+
cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
|
168 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
169 |
+
"""Compute encoded features.
|
170 |
+
|
171 |
+
Args:
|
172 |
+
x (torch.Tensor): (#batch, time, size)
|
173 |
+
mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
|
174 |
+
(0, 0, 0) means fake mask.
|
175 |
+
pos_emb (torch.Tensor): positional encoding, must not be None
|
176 |
+
for ConformerEncoderLayer.
|
177 |
+
mask_pad (torch.Tensor): batch padding mask used for conv module.
|
178 |
+
(#batch, 1,time), (0, 0, 0) means fake mask.
|
179 |
+
att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
|
180 |
+
(#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
|
181 |
+
cnn_cache (torch.Tensor): Convolution cache in conformer layer
|
182 |
+
(#batch=1, size, cache_t2)
|
183 |
+
Returns:
|
184 |
+
torch.Tensor: Output tensor (#batch, time, size).
|
185 |
+
torch.Tensor: Mask tensor (#batch, time, time).
|
186 |
+
torch.Tensor: att_cache tensor,
|
187 |
+
(#batch=1, head, cache_t1 + time, d_k * 2).
|
188 |
+
torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2).
|
189 |
+
"""
|
190 |
+
|
191 |
+
# whether to use macaron style
|
192 |
+
if self.feed_forward_macaron is not None:
|
193 |
+
residual = x
|
194 |
+
if self.normalize_before:
|
195 |
+
x = self.norm_ff_macaron(x)
|
196 |
+
x = residual + self.ff_scale * self.dropout(
|
197 |
+
self.feed_forward_macaron(x))
|
198 |
+
if not self.normalize_before:
|
199 |
+
x = self.norm_ff_macaron(x)
|
200 |
+
|
201 |
+
# multi-headed self-attention module
|
202 |
+
residual = x
|
203 |
+
if self.normalize_before:
|
204 |
+
x = self.norm_mha(x)
|
205 |
+
x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
|
206 |
+
att_cache)
|
207 |
+
x = residual + self.dropout(x_att)
|
208 |
+
if not self.normalize_before:
|
209 |
+
x = self.norm_mha(x)
|
210 |
+
|
211 |
+
# convolution module
|
212 |
+
# Fake new cnn cache here, and then change it in conv_module
|
213 |
+
new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
|
214 |
+
if self.conv_module is not None:
|
215 |
+
residual = x
|
216 |
+
if self.normalize_before:
|
217 |
+
x = self.norm_conv(x)
|
218 |
+
x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
|
219 |
+
x = residual + self.dropout(x)
|
220 |
+
|
221 |
+
if not self.normalize_before:
|
222 |
+
x = self.norm_conv(x)
|
223 |
+
|
224 |
+
# feed forward module
|
225 |
+
residual = x
|
226 |
+
if self.normalize_before:
|
227 |
+
x = self.norm_ff(x)
|
228 |
+
|
229 |
+
x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
|
230 |
+
if not self.normalize_before:
|
231 |
+
x = self.norm_ff(x)
|
232 |
+
|
233 |
+
if self.conv_module is not None:
|
234 |
+
x = self.norm_final(x)
|
235 |
+
|
236 |
+
return x, mask, new_att_cache, new_cnn_cache
|
cosyvoice/transformer/label_smoothing_loss.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2019 Shigeki Karita
|
2 |
+
# 2020 Mobvoi Inc (Binbin Zhang)
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""Label smoothing module."""
|
16 |
+
|
17 |
+
import torch
|
18 |
+
from torch import nn
|
19 |
+
|
20 |
+
|
21 |
+
class LabelSmoothingLoss(nn.Module):
|
22 |
+
"""Label-smoothing loss.
|
23 |
+
|
24 |
+
In a standard CE loss, the label's data distribution is:
|
25 |
+
[0,1,2] ->
|
26 |
+
[
|
27 |
+
[1.0, 0.0, 0.0],
|
28 |
+
[0.0, 1.0, 0.0],
|
29 |
+
[0.0, 0.0, 1.0],
|
30 |
+
]
|
31 |
+
|
32 |
+
In the smoothing version CE Loss,some probabilities
|
33 |
+
are taken from the true label prob (1.0) and are divided
|
34 |
+
among other labels.
|
35 |
+
|
36 |
+
e.g.
|
37 |
+
smoothing=0.1
|
38 |
+
[0,1,2] ->
|
39 |
+
[
|
40 |
+
[0.9, 0.05, 0.05],
|
41 |
+
[0.05, 0.9, 0.05],
|
42 |
+
[0.05, 0.05, 0.9],
|
43 |
+
]
|
44 |
+
|
45 |
+
Args:
|
46 |
+
size (int): the number of class
|
47 |
+
padding_idx (int): padding class id which will be ignored for loss
|
48 |
+
smoothing (float): smoothing rate (0.0 means the conventional CE)
|
49 |
+
normalize_length (bool):
|
50 |
+
normalize loss by sequence length if True
|
51 |
+
normalize loss by batch size if False
|
52 |
+
"""
|
53 |
+
|
54 |
+
def __init__(self,
|
55 |
+
size: int,
|
56 |
+
padding_idx: int,
|
57 |
+
smoothing: float,
|
58 |
+
normalize_length: bool = False):
|
59 |
+
"""Construct an LabelSmoothingLoss object."""
|
60 |
+
super(LabelSmoothingLoss, self).__init__()
|
61 |
+
self.criterion = nn.KLDivLoss(reduction="none")
|
62 |
+
self.padding_idx = padding_idx
|
63 |
+
self.confidence = 1.0 - smoothing
|
64 |
+
self.smoothing = smoothing
|
65 |
+
self.size = size
|
66 |
+
self.normalize_length = normalize_length
|
67 |
+
|
68 |
+
def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
69 |
+
"""Compute loss between x and target.
|
70 |
+
|
71 |
+
The model outputs and data labels tensors are flatten to
|
72 |
+
(batch*seqlen, class) shape and a mask is applied to the
|
73 |
+
padding part which should not be calculated for loss.
|
74 |
+
|
75 |
+
Args:
|
76 |
+
x (torch.Tensor): prediction (batch, seqlen, class)
|
77 |
+
target (torch.Tensor):
|
78 |
+
target signal masked with self.padding_id (batch, seqlen)
|
79 |
+
Returns:
|
80 |
+
loss (torch.Tensor) : The KL loss, scalar float value
|
81 |
+
"""
|
82 |
+
assert x.size(2) == self.size
|
83 |
+
batch_size = x.size(0)
|
84 |
+
x = x.view(-1, self.size)
|
85 |
+
target = target.view(-1)
|
86 |
+
# use zeros_like instead of torch.no_grad() for true_dist,
|
87 |
+
# since no_grad() can not be exported by JIT
|
88 |
+
true_dist = torch.zeros_like(x)
|
89 |
+
true_dist.fill_(self.smoothing / (self.size - 1))
|
90 |
+
ignore = target == self.padding_idx # (B,)
|
91 |
+
total = len(target) - ignore.sum().item()
|
92 |
+
target = target.masked_fill(ignore, 0) # avoid -1 index
|
93 |
+
true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
|
94 |
+
kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
|
95 |
+
denom = total if self.normalize_length else batch_size
|
96 |
+
return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
|
cosyvoice/transformer/positionwise_feed_forward.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2019 Shigeki Karita
|
2 |
+
# 2020 Mobvoi Inc (Binbin Zhang)
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""Positionwise feed forward layer definition."""
|
16 |
+
|
17 |
+
import torch
|
18 |
+
|
19 |
+
|
20 |
+
class PositionwiseFeedForward(torch.nn.Module):
|
21 |
+
"""Positionwise feed forward layer.
|
22 |
+
|
23 |
+
FeedForward are appied on each position of the sequence.
|
24 |
+
The output dim is same with the input dim.
|
25 |
+
|
26 |
+
Args:
|
27 |
+
idim (int): Input dimenstion.
|
28 |
+
hidden_units (int): The number of hidden units.
|
29 |
+
dropout_rate (float): Dropout rate.
|
30 |
+
activation (torch.nn.Module): Activation function
|
31 |
+
"""
|
32 |
+
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
idim: int,
|
36 |
+
hidden_units: int,
|
37 |
+
dropout_rate: float,
|
38 |
+
activation: torch.nn.Module = torch.nn.ReLU(),
|
39 |
+
):
|
40 |
+
"""Construct a PositionwiseFeedForward object."""
|
41 |
+
super(PositionwiseFeedForward, self).__init__()
|
42 |
+
self.w_1 = torch.nn.Linear(idim, hidden_units)
|
43 |
+
self.activation = activation
|
44 |
+
self.dropout = torch.nn.Dropout(dropout_rate)
|
45 |
+
self.w_2 = torch.nn.Linear(hidden_units, idim)
|
46 |
+
|
47 |
+
def forward(self, xs: torch.Tensor) -> torch.Tensor:
|
48 |
+
"""Forward function.
|
49 |
+
|
50 |
+
Args:
|
51 |
+
xs: input tensor (B, L, D)
|
52 |
+
Returns:
|
53 |
+
output tensor, (B, L, D)
|
54 |
+
"""
|
55 |
+
return self.w_2(self.dropout(self.activation(self.w_1(xs))))
|
56 |
+
|
57 |
+
|
58 |
+
class MoEFFNLayer(torch.nn.Module):
|
59 |
+
"""
|
60 |
+
Mixture of expert with Positionwise feed forward layer
|
61 |
+
See also figure 1 in https://arxiv.org/pdf/2305.15663.pdf
|
62 |
+
The output dim is same with the input dim.
|
63 |
+
|
64 |
+
Modified from https://github.com/Lightning-AI/lit-gpt/pull/823
|
65 |
+
https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219
|
66 |
+
Args:
|
67 |
+
n_expert: number of expert.
|
68 |
+
n_expert_per_token: The actual number of experts used for each frame
|
69 |
+
idim (int): Input dimenstion.
|
70 |
+
hidden_units (int): The number of hidden units.
|
71 |
+
dropout_rate (float): Dropout rate.
|
72 |
+
activation (torch.nn.Module): Activation function
|
73 |
+
"""
|
74 |
+
|
75 |
+
def __init__(
|
76 |
+
self,
|
77 |
+
n_expert: int,
|
78 |
+
n_expert_per_token: int,
|
79 |
+
idim: int,
|
80 |
+
hidden_units: int,
|
81 |
+
dropout_rate: float,
|
82 |
+
activation: torch.nn.Module = torch.nn.ReLU(),
|
83 |
+
):
|
84 |
+
super(MoEFFNLayer, self).__init__()
|
85 |
+
self.gate = torch.nn.Linear(idim, n_expert, bias=False)
|
86 |
+
self.experts = torch.nn.ModuleList(
|
87 |
+
PositionwiseFeedForward(idim, hidden_units, dropout_rate,
|
88 |
+
activation) for _ in range(n_expert))
|
89 |
+
self.n_expert_per_token = n_expert_per_token
|
90 |
+
|
91 |
+
def forward(self, xs: torch.Tensor) -> torch.Tensor:
|
92 |
+
"""Foward function.
|
93 |
+
Args:
|
94 |
+
xs: input tensor (B, L, D)
|
95 |
+
Returns:
|
96 |
+
output tensor, (B, L, D)
|
97 |
+
|
98 |
+
"""
|
99 |
+
B, L, D = xs.size(
|
100 |
+
) # batch size, sequence length, embedding dimension (idim)
|
101 |
+
xs = xs.view(-1, D) # (B*L, D)
|
102 |
+
router = self.gate(xs) # (B*L, n_expert)
|
103 |
+
logits, indices = torch.topk(
|
104 |
+
router, self.n_expert_per_token
|
105 |
+
) # probs:(B*L, n_expert), indices: (B*L, n_expert)
|
106 |
+
weights = torch.nn.functional.softmax(
|
107 |
+
logits, dim=1,
|
108 |
+
dtype=torch.float).to(dtype=xs.dtype) # (B*L, n_expert_per_token)
|
109 |
+
output = torch.zeros_like(xs) # (B*L, D)
|
110 |
+
for i, expert in enumerate(self.experts):
|
111 |
+
mask = indices == i
|
112 |
+
batch_idx, ith_expert = torch.where(mask)
|
113 |
+
output[batch_idx] += weights[batch_idx, ith_expert, None] * expert(
|
114 |
+
xs[batch_idx])
|
115 |
+
return output.view(B, L, D)
|
cosyvoice/transformer/subsampling.py
ADDED
@@ -0,0 +1,383 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
|
2 |
+
# 2024 Alibaba Inc (Xiang Lyu)
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# Modified from ESPnet(https://github.com/espnet/espnet)
|
16 |
+
"""Subsampling layer definition."""
|
17 |
+
|
18 |
+
from typing import Tuple, Union
|
19 |
+
|
20 |
+
import torch
|
21 |
+
|
22 |
+
|
23 |
+
class BaseSubsampling(torch.nn.Module):
|
24 |
+
|
25 |
+
def __init__(self):
|
26 |
+
super().__init__()
|
27 |
+
self.right_context = 0
|
28 |
+
self.subsampling_rate = 1
|
29 |
+
|
30 |
+
def position_encoding(self, offset: Union[int, torch.Tensor],
|
31 |
+
size: int) -> torch.Tensor:
|
32 |
+
return self.pos_enc.position_encoding(offset, size)
|
33 |
+
|
34 |
+
|
35 |
+
class EmbedinigNoSubsampling(BaseSubsampling):
|
36 |
+
"""Embedding input without subsampling
|
37 |
+
"""
|
38 |
+
|
39 |
+
def __init__(self, idim: int, odim: int, dropout_rate: float,
|
40 |
+
pos_enc_class: torch.nn.Module):
|
41 |
+
super().__init__()
|
42 |
+
self.embed = torch.nn.Embedding(idim, odim)
|
43 |
+
self.pos_enc = pos_enc_class
|
44 |
+
|
45 |
+
def forward(
|
46 |
+
self,
|
47 |
+
x: torch.Tensor,
|
48 |
+
x_mask: torch.Tensor,
|
49 |
+
offset: Union[int, torch.Tensor] = 0
|
50 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
51 |
+
"""Input x.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
x (torch.Tensor): Input tensor (#batch, time, idim).
|
55 |
+
x_mask (torch.Tensor): Input mask (#batch, 1, time).
|
56 |
+
|
57 |
+
Returns:
|
58 |
+
torch.Tensor: linear input tensor (#batch, time', odim),
|
59 |
+
where time' = time .
|
60 |
+
torch.Tensor: linear input mask (#batch, 1, time'),
|
61 |
+
where time' = time .
|
62 |
+
|
63 |
+
"""
|
64 |
+
x = self.embed(x)
|
65 |
+
x, pos_emb = self.pos_enc(x, offset)
|
66 |
+
return x, pos_emb, x_mask
|
67 |
+
|
68 |
+
|
69 |
+
class LinearNoSubsampling(BaseSubsampling):
|
70 |
+
"""Linear transform the input without subsampling
|
71 |
+
|
72 |
+
Args:
|
73 |
+
idim (int): Input dimension.
|
74 |
+
odim (int): Output dimension.
|
75 |
+
dropout_rate (float): Dropout rate.
|
76 |
+
|
77 |
+
"""
|
78 |
+
|
79 |
+
def __init__(self, idim: int, odim: int, dropout_rate: float,
|
80 |
+
pos_enc_class: torch.nn.Module):
|
81 |
+
"""Construct an linear object."""
|
82 |
+
super().__init__()
|
83 |
+
self.out = torch.nn.Sequential(
|
84 |
+
torch.nn.Linear(idim, odim),
|
85 |
+
torch.nn.LayerNorm(odim, eps=1e-5),
|
86 |
+
torch.nn.Dropout(dropout_rate),
|
87 |
+
)
|
88 |
+
self.pos_enc = pos_enc_class
|
89 |
+
self.right_context = 0
|
90 |
+
self.subsampling_rate = 1
|
91 |
+
|
92 |
+
def forward(
|
93 |
+
self,
|
94 |
+
x: torch.Tensor,
|
95 |
+
x_mask: torch.Tensor,
|
96 |
+
offset: Union[int, torch.Tensor] = 0
|
97 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
98 |
+
"""Input x.
|
99 |
+
|
100 |
+
Args:
|
101 |
+
x (torch.Tensor): Input tensor (#batch, time, idim).
|
102 |
+
x_mask (torch.Tensor): Input mask (#batch, 1, time).
|
103 |
+
|
104 |
+
Returns:
|
105 |
+
torch.Tensor: linear input tensor (#batch, time', odim),
|
106 |
+
where time' = time .
|
107 |
+
torch.Tensor: linear input mask (#batch, 1, time'),
|
108 |
+
where time' = time .
|
109 |
+
|
110 |
+
"""
|
111 |
+
x = self.out(x)
|
112 |
+
x, pos_emb = self.pos_enc(x, offset)
|
113 |
+
return x, pos_emb, x_mask
|
114 |
+
|
115 |
+
|
116 |
+
class Conv1dSubsampling2(BaseSubsampling):
|
117 |
+
"""Convolutional 1D subsampling (to 1/2 length).
|
118 |
+
It is designed for Whisper, ref:
|
119 |
+
https://github.com/openai/whisper/blob/main/whisper/model.py
|
120 |
+
|
121 |
+
Args:
|
122 |
+
idim (int): Input dimension.
|
123 |
+
odim (int): Output dimension.
|
124 |
+
dropout_rate (float): Dropout rate.
|
125 |
+
|
126 |
+
"""
|
127 |
+
|
128 |
+
def __init__(self, idim: int, odim: int, dropout_rate: float,
|
129 |
+
pos_enc_class: torch.nn.Module):
|
130 |
+
"""Construct an Conv1dSubsampling2 object."""
|
131 |
+
super().__init__()
|
132 |
+
self.conv = torch.nn.Sequential(
|
133 |
+
torch.nn.Conv1d(idim, odim, kernel_size=3, padding=1),
|
134 |
+
torch.nn.GELU(),
|
135 |
+
torch.nn.Conv1d(odim, odim, kernel_size=3, stride=2, padding=1),
|
136 |
+
torch.nn.GELU(),
|
137 |
+
)
|
138 |
+
self.pos_enc = pos_enc_class
|
139 |
+
# The right context for every conv layer is computed by:
|
140 |
+
# (kernel_size - 1) * frame_rate_of_this_layer
|
141 |
+
self.subsampling_rate = 2
|
142 |
+
# 4 = (3 - 1) * 1 + (3 - 1) * 1
|
143 |
+
self.right_context = 4
|
144 |
+
|
145 |
+
def forward(
|
146 |
+
self,
|
147 |
+
x: torch.Tensor,
|
148 |
+
x_mask: torch.Tensor,
|
149 |
+
offset: Union[int, torch.Tensor] = 0
|
150 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
151 |
+
"""Subsample x.
|
152 |
+
|
153 |
+
Args:
|
154 |
+
x (torch.Tensor): Input tensor (#batch, time, idim).
|
155 |
+
x_mask (torch.Tensor): Input mask (#batch, 1, time).
|
156 |
+
|
157 |
+
Returns:
|
158 |
+
torch.Tensor: Subsampled tensor (#batch, time', odim),
|
159 |
+
where time' = time // 2.
|
160 |
+
torch.Tensor: Subsampled mask (#batch, 1, time'),
|
161 |
+
where time' = time // 2.
|
162 |
+
torch.Tensor: positional encoding
|
163 |
+
|
164 |
+
"""
|
165 |
+
time = x.size(1)
|
166 |
+
x = x.transpose(1, 2) # (b, f, t)
|
167 |
+
x = self.conv(x)
|
168 |
+
x = x.transpose(1, 2) # (b, t, f)
|
169 |
+
x, pos_emb = self.pos_enc(x, offset)
|
170 |
+
return x, pos_emb, x_mask[:, :, (time + 1) % 2::2]
|
171 |
+
|
172 |
+
|
173 |
+
class Conv2dSubsampling4(BaseSubsampling):
|
174 |
+
"""Convolutional 2D subsampling (to 1/4 length).
|
175 |
+
|
176 |
+
Args:
|
177 |
+
idim (int): Input dimension.
|
178 |
+
odim (int): Output dimension.
|
179 |
+
dropout_rate (float): Dropout rate.
|
180 |
+
|
181 |
+
"""
|
182 |
+
|
183 |
+
def __init__(self, idim: int, odim: int, dropout_rate: float,
|
184 |
+
pos_enc_class: torch.nn.Module):
|
185 |
+
"""Construct an Conv2dSubsampling4 object."""
|
186 |
+
super().__init__()
|
187 |
+
self.conv = torch.nn.Sequential(
|
188 |
+
torch.nn.Conv2d(1, odim, 3, 2),
|
189 |
+
torch.nn.ReLU(),
|
190 |
+
torch.nn.Conv2d(odim, odim, 3, 2),
|
191 |
+
torch.nn.ReLU(),
|
192 |
+
)
|
193 |
+
self.out = torch.nn.Sequential(
|
194 |
+
torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim))
|
195 |
+
self.pos_enc = pos_enc_class
|
196 |
+
# The right context for every conv layer is computed by:
|
197 |
+
# (kernel_size - 1) * frame_rate_of_this_layer
|
198 |
+
self.subsampling_rate = 4
|
199 |
+
# 6 = (3 - 1) * 1 + (3 - 1) * 2
|
200 |
+
self.right_context = 6
|
201 |
+
|
202 |
+
def forward(
|
203 |
+
self,
|
204 |
+
x: torch.Tensor,
|
205 |
+
x_mask: torch.Tensor,
|
206 |
+
offset: Union[int, torch.Tensor] = 0
|
207 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
208 |
+
"""Subsample x.
|
209 |
+
|
210 |
+
Args:
|
211 |
+
x (torch.Tensor): Input tensor (#batch, time, idim).
|
212 |
+
x_mask (torch.Tensor): Input mask (#batch, 1, time).
|
213 |
+
|
214 |
+
Returns:
|
215 |
+
torch.Tensor: Subsampled tensor (#batch, time', odim),
|
216 |
+
where time' = time // 4.
|
217 |
+
torch.Tensor: Subsampled mask (#batch, 1, time'),
|
218 |
+
where time' = time // 4.
|
219 |
+
torch.Tensor: positional encoding
|
220 |
+
|
221 |
+
"""
|
222 |
+
x = x.unsqueeze(1) # (b, c=1, t, f)
|
223 |
+
x = self.conv(x)
|
224 |
+
b, c, t, f = x.size()
|
225 |
+
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
226 |
+
x, pos_emb = self.pos_enc(x, offset)
|
227 |
+
return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2]
|
228 |
+
|
229 |
+
|
230 |
+
class Conv2dSubsampling6(BaseSubsampling):
|
231 |
+
"""Convolutional 2D subsampling (to 1/6 length).
|
232 |
+
Args:
|
233 |
+
idim (int): Input dimension.
|
234 |
+
odim (int): Output dimension.
|
235 |
+
dropout_rate (float): Dropout rate.
|
236 |
+
pos_enc (torch.nn.Module): Custom position encoding layer.
|
237 |
+
"""
|
238 |
+
|
239 |
+
def __init__(self, idim: int, odim: int, dropout_rate: float,
|
240 |
+
pos_enc_class: torch.nn.Module):
|
241 |
+
"""Construct an Conv2dSubsampling6 object."""
|
242 |
+
super().__init__()
|
243 |
+
self.conv = torch.nn.Sequential(
|
244 |
+
torch.nn.Conv2d(1, odim, 3, 2),
|
245 |
+
torch.nn.ReLU(),
|
246 |
+
torch.nn.Conv2d(odim, odim, 5, 3),
|
247 |
+
torch.nn.ReLU(),
|
248 |
+
)
|
249 |
+
self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3),
|
250 |
+
odim)
|
251 |
+
self.pos_enc = pos_enc_class
|
252 |
+
# 10 = (3 - 1) * 1 + (5 - 1) * 2
|
253 |
+
self.subsampling_rate = 6
|
254 |
+
self.right_context = 10
|
255 |
+
|
256 |
+
def forward(
|
257 |
+
self,
|
258 |
+
x: torch.Tensor,
|
259 |
+
x_mask: torch.Tensor,
|
260 |
+
offset: Union[int, torch.Tensor] = 0
|
261 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
262 |
+
"""Subsample x.
|
263 |
+
Args:
|
264 |
+
x (torch.Tensor): Input tensor (#batch, time, idim).
|
265 |
+
x_mask (torch.Tensor): Input mask (#batch, 1, time).
|
266 |
+
|
267 |
+
Returns:
|
268 |
+
torch.Tensor: Subsampled tensor (#batch, time', odim),
|
269 |
+
where time' = time // 6.
|
270 |
+
torch.Tensor: Subsampled mask (#batch, 1, time'),
|
271 |
+
where time' = time // 6.
|
272 |
+
torch.Tensor: positional encoding
|
273 |
+
"""
|
274 |
+
x = x.unsqueeze(1) # (b, c, t, f)
|
275 |
+
x = self.conv(x)
|
276 |
+
b, c, t, f = x.size()
|
277 |
+
x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
278 |
+
x, pos_emb = self.pos_enc(x, offset)
|
279 |
+
return x, pos_emb, x_mask[:, :, 2::2][:, :, 4::3]
|
280 |
+
|
281 |
+
|
282 |
+
class Conv2dSubsampling8(BaseSubsampling):
|
283 |
+
"""Convolutional 2D subsampling (to 1/8 length).
|
284 |
+
|
285 |
+
Args:
|
286 |
+
idim (int): Input dimension.
|
287 |
+
odim (int): Output dimension.
|
288 |
+
dropout_rate (float): Dropout rate.
|
289 |
+
|
290 |
+
"""
|
291 |
+
|
292 |
+
def __init__(self, idim: int, odim: int, dropout_rate: float,
|
293 |
+
pos_enc_class: torch.nn.Module):
|
294 |
+
"""Construct an Conv2dSubsampling8 object."""
|
295 |
+
super().__init__()
|
296 |
+
self.conv = torch.nn.Sequential(
|
297 |
+
torch.nn.Conv2d(1, odim, 3, 2),
|
298 |
+
torch.nn.ReLU(),
|
299 |
+
torch.nn.Conv2d(odim, odim, 3, 2),
|
300 |
+
torch.nn.ReLU(),
|
301 |
+
torch.nn.Conv2d(odim, odim, 3, 2),
|
302 |
+
torch.nn.ReLU(),
|
303 |
+
)
|
304 |
+
self.linear = torch.nn.Linear(
|
305 |
+
odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim)
|
306 |
+
self.pos_enc = pos_enc_class
|
307 |
+
self.subsampling_rate = 8
|
308 |
+
# 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4
|
309 |
+
self.right_context = 14
|
310 |
+
|
311 |
+
def forward(
|
312 |
+
self,
|
313 |
+
x: torch.Tensor,
|
314 |
+
x_mask: torch.Tensor,
|
315 |
+
offset: Union[int, torch.Tensor] = 0
|
316 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
317 |
+
"""Subsample x.
|
318 |
+
|
319 |
+
Args:
|
320 |
+
x (torch.Tensor): Input tensor (#batch, time, idim).
|
321 |
+
x_mask (torch.Tensor): Input mask (#batch, 1, time).
|
322 |
+
|
323 |
+
Returns:
|
324 |
+
torch.Tensor: Subsampled tensor (#batch, time', odim),
|
325 |
+
where time' = time // 8.
|
326 |
+
torch.Tensor: Subsampled mask (#batch, 1, time'),
|
327 |
+
where time' = time // 8.
|
328 |
+
torch.Tensor: positional encoding
|
329 |
+
"""
|
330 |
+
x = x.unsqueeze(1) # (b, c, t, f)
|
331 |
+
x = self.conv(x)
|
332 |
+
b, c, t, f = x.size()
|
333 |
+
x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
334 |
+
x, pos_emb = self.pos_enc(x, offset)
|
335 |
+
return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2][:, :, 2::2]
|
336 |
+
|
337 |
+
|
338 |
+
class LegacyLinearNoSubsampling(BaseSubsampling):
|
339 |
+
"""Linear transform the input without subsampling
|
340 |
+
|
341 |
+
Args:
|
342 |
+
idim (int): Input dimension.
|
343 |
+
odim (int): Output dimension.
|
344 |
+
dropout_rate (float): Dropout rate.
|
345 |
+
|
346 |
+
"""
|
347 |
+
|
348 |
+
def __init__(self, idim: int, odim: int, dropout_rate: float,
|
349 |
+
pos_enc_class: torch.nn.Module):
|
350 |
+
"""Construct an linear object."""
|
351 |
+
super().__init__()
|
352 |
+
self.out = torch.nn.Sequential(
|
353 |
+
torch.nn.Linear(idim, odim),
|
354 |
+
torch.nn.LayerNorm(odim, eps=1e-5),
|
355 |
+
torch.nn.Dropout(dropout_rate),
|
356 |
+
torch.nn.ReLU(),
|
357 |
+
)
|
358 |
+
self.pos_enc = pos_enc_class
|
359 |
+
self.right_context = 0
|
360 |
+
self.subsampling_rate = 1
|
361 |
+
|
362 |
+
def forward(
|
363 |
+
self,
|
364 |
+
x: torch.Tensor,
|
365 |
+
x_mask: torch.Tensor,
|
366 |
+
offset: Union[int, torch.Tensor] = 0
|
367 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
368 |
+
"""Input x.
|
369 |
+
|
370 |
+
Args:
|
371 |
+
x (torch.Tensor): Input tensor (#batch, time, idim).
|
372 |
+
x_mask (torch.Tensor): Input mask (#batch, 1, time).
|
373 |
+
|
374 |
+
Returns:
|
375 |
+
torch.Tensor: linear input tensor (#batch, time', odim),
|
376 |
+
where time' = time .
|
377 |
+
torch.Tensor: linear input mask (#batch, 1, time'),
|
378 |
+
where time' = time .
|
379 |
+
|
380 |
+
"""
|
381 |
+
x = self.out(x)
|
382 |
+
x, pos_emb = self.pos_enc(x, offset)
|
383 |
+
return x, pos_emb, x_mask
|
cosyvoice/transformer/upsample_encoder.py
ADDED
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
|
2 |
+
# 2022 Xingchen Song ([email protected])
|
3 |
+
# 2024 Alibaba Inc (Xiang Lyu)
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
# Modified from ESPnet(https://github.com/espnet/espnet)
|
17 |
+
"""Encoder definition."""
|
18 |
+
from typing import Tuple
|
19 |
+
|
20 |
+
import torch
|
21 |
+
from torch import nn
|
22 |
+
import torch.utils.checkpoint as ckpt
|
23 |
+
from torch.nn import functional as F
|
24 |
+
|
25 |
+
from cosyvoice.transformer.convolution import ConvolutionModule
|
26 |
+
from cosyvoice.transformer.encoder_layer import ConformerEncoderLayer
|
27 |
+
from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward
|
28 |
+
from cosyvoice.utils.class_utils import (
|
29 |
+
COSYVOICE_EMB_CLASSES,
|
30 |
+
COSYVOICE_SUBSAMPLE_CLASSES,
|
31 |
+
COSYVOICE_ATTENTION_CLASSES,
|
32 |
+
COSYVOICE_ACTIVATION_CLASSES,
|
33 |
+
)
|
34 |
+
from cosyvoice.utils.mask import make_pad_mask
|
35 |
+
from cosyvoice.utils.mask import add_optional_chunk_mask
|
36 |
+
|
37 |
+
|
38 |
+
class Upsample1D(nn.Module):
|
39 |
+
"""A 1D upsampling layer with an optional convolution.
|
40 |
+
|
41 |
+
Parameters:
|
42 |
+
channels (`int`):
|
43 |
+
number of channels in the inputs and outputs.
|
44 |
+
use_conv (`bool`, default `False`):
|
45 |
+
option to use a convolution.
|
46 |
+
use_conv_transpose (`bool`, default `False`):
|
47 |
+
option to use a convolution transpose.
|
48 |
+
out_channels (`int`, optional):
|
49 |
+
number of output channels. Defaults to `channels`.
|
50 |
+
"""
|
51 |
+
|
52 |
+
def __init__(self, channels: int, out_channels: int, stride: int=2):
|
53 |
+
super().__init__()
|
54 |
+
self.channels = channels
|
55 |
+
self.out_channels = out_channels
|
56 |
+
self.stride = stride
|
57 |
+
# In this mode, first repeat interpolate, than conv with stride=1
|
58 |
+
self.conv = nn.Conv1d(
|
59 |
+
self.channels, self.out_channels, stride*2+1, stride=1,
|
60 |
+
padding=0,
|
61 |
+
)
|
62 |
+
|
63 |
+
def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor):
|
64 |
+
outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest")
|
65 |
+
outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
|
66 |
+
outputs = self.conv(outputs)
|
67 |
+
return outputs, input_lengths * self.stride
|
68 |
+
|
69 |
+
|
70 |
+
class PreLookaheadLayer(nn.Module):
|
71 |
+
def __init__(self, channels: int, pre_lookahead_len: int = 1):
|
72 |
+
super().__init__()
|
73 |
+
self.channels = channels
|
74 |
+
self.pre_lookahead_len = pre_lookahead_len
|
75 |
+
self.conv1 = nn.Conv1d(
|
76 |
+
channels, channels,
|
77 |
+
kernel_size=pre_lookahead_len+1,
|
78 |
+
stride=1, padding=0,
|
79 |
+
)
|
80 |
+
self.conv2 = nn.Conv1d(
|
81 |
+
channels, channels,
|
82 |
+
kernel_size=3, stride=1, padding=0,
|
83 |
+
)
|
84 |
+
|
85 |
+
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
86 |
+
"""
|
87 |
+
inputs: (batch_size, seq_len, channels)
|
88 |
+
"""
|
89 |
+
outputs = inputs.transpose(1, 2).contiguous()
|
90 |
+
# look ahead
|
91 |
+
outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0)
|
92 |
+
outputs = F.leaky_relu(self.conv1(outputs))
|
93 |
+
# outputs
|
94 |
+
outputs = F.pad(outputs, (2, 0), mode='constant', value=0.0)
|
95 |
+
outputs = self.conv2(outputs)
|
96 |
+
outputs = outputs.transpose(1, 2).contiguous()
|
97 |
+
|
98 |
+
# residual connection
|
99 |
+
outputs = outputs + inputs
|
100 |
+
return outputs
|
101 |
+
|
102 |
+
|
103 |
+
class UpsampleConformerEncoder(torch.nn.Module):
|
104 |
+
|
105 |
+
def __init__(
|
106 |
+
self,
|
107 |
+
input_size: int,
|
108 |
+
output_size: int = 256,
|
109 |
+
attention_heads: int = 4,
|
110 |
+
linear_units: int = 2048,
|
111 |
+
num_blocks: int = 6,
|
112 |
+
dropout_rate: float = 0.1,
|
113 |
+
positional_dropout_rate: float = 0.1,
|
114 |
+
attention_dropout_rate: float = 0.0,
|
115 |
+
input_layer: str = "conv2d",
|
116 |
+
pos_enc_layer_type: str = "rel_pos",
|
117 |
+
normalize_before: bool = True,
|
118 |
+
static_chunk_size: int = 0,
|
119 |
+
use_dynamic_chunk: bool = False,
|
120 |
+
global_cmvn: torch.nn.Module = None,
|
121 |
+
use_dynamic_left_chunk: bool = False,
|
122 |
+
positionwise_conv_kernel_size: int = 1,
|
123 |
+
macaron_style: bool = True,
|
124 |
+
selfattention_layer_type: str = "rel_selfattn",
|
125 |
+
activation_type: str = "swish",
|
126 |
+
use_cnn_module: bool = True,
|
127 |
+
cnn_module_kernel: int = 15,
|
128 |
+
causal: bool = False,
|
129 |
+
cnn_module_norm: str = "batch_norm",
|
130 |
+
key_bias: bool = True,
|
131 |
+
gradient_checkpointing: bool = False,
|
132 |
+
):
|
133 |
+
"""
|
134 |
+
Args:
|
135 |
+
input_size (int): input dim
|
136 |
+
output_size (int): dimension of attention
|
137 |
+
attention_heads (int): the number of heads of multi head attention
|
138 |
+
linear_units (int): the hidden units number of position-wise feed
|
139 |
+
forward
|
140 |
+
num_blocks (int): the number of decoder blocks
|
141 |
+
dropout_rate (float): dropout rate
|
142 |
+
attention_dropout_rate (float): dropout rate in attention
|
143 |
+
positional_dropout_rate (float): dropout rate after adding
|
144 |
+
positional encoding
|
145 |
+
input_layer (str): input layer type.
|
146 |
+
optional [linear, conv2d, conv2d6, conv2d8]
|
147 |
+
pos_enc_layer_type (str): Encoder positional encoding layer type.
|
148 |
+
opitonal [abs_pos, scaled_abs_pos, rel_pos, no_pos]
|
149 |
+
normalize_before (bool):
|
150 |
+
True: use layer_norm before each sub-block of a layer.
|
151 |
+
False: use layer_norm after each sub-block of a layer.
|
152 |
+
static_chunk_size (int): chunk size for static chunk training and
|
153 |
+
decoding
|
154 |
+
use_dynamic_chunk (bool): whether use dynamic chunk size for
|
155 |
+
training or not, You can only use fixed chunk(chunk_size > 0)
|
156 |
+
or dyanmic chunk size(use_dynamic_chunk = True)
|
157 |
+
global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
|
158 |
+
use_dynamic_left_chunk (bool): whether use dynamic left chunk in
|
159 |
+
dynamic chunk training
|
160 |
+
key_bias: whether use bias in attention.linear_k, False for whisper models.
|
161 |
+
gradient_checkpointing: rerunning a forward-pass segment for each
|
162 |
+
checkpointed segment during backward.
|
163 |
+
"""
|
164 |
+
super().__init__()
|
165 |
+
self._output_size = output_size
|
166 |
+
|
167 |
+
self.global_cmvn = global_cmvn
|
168 |
+
self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
|
169 |
+
input_size,
|
170 |
+
output_size,
|
171 |
+
dropout_rate,
|
172 |
+
COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
|
173 |
+
positional_dropout_rate),
|
174 |
+
)
|
175 |
+
|
176 |
+
self.normalize_before = normalize_before
|
177 |
+
self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
|
178 |
+
self.static_chunk_size = static_chunk_size
|
179 |
+
self.use_dynamic_chunk = use_dynamic_chunk
|
180 |
+
self.use_dynamic_left_chunk = use_dynamic_left_chunk
|
181 |
+
self.gradient_checkpointing = gradient_checkpointing
|
182 |
+
activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
|
183 |
+
# self-attention module definition
|
184 |
+
encoder_selfattn_layer_args = (
|
185 |
+
attention_heads,
|
186 |
+
output_size,
|
187 |
+
attention_dropout_rate,
|
188 |
+
key_bias,
|
189 |
+
)
|
190 |
+
# feed-forward module definition
|
191 |
+
positionwise_layer_args = (
|
192 |
+
output_size,
|
193 |
+
linear_units,
|
194 |
+
dropout_rate,
|
195 |
+
activation,
|
196 |
+
)
|
197 |
+
# convolution module definition
|
198 |
+
convolution_layer_args = (output_size, cnn_module_kernel, activation,
|
199 |
+
cnn_module_norm, causal)
|
200 |
+
self.pre_lookahead_layer = PreLookaheadLayer(channels=512, pre_lookahead_len=3)
|
201 |
+
self.encoders = torch.nn.ModuleList([
|
202 |
+
ConformerEncoderLayer(
|
203 |
+
output_size,
|
204 |
+
COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
|
205 |
+
*encoder_selfattn_layer_args),
|
206 |
+
PositionwiseFeedForward(*positionwise_layer_args),
|
207 |
+
PositionwiseFeedForward(
|
208 |
+
*positionwise_layer_args) if macaron_style else None,
|
209 |
+
ConvolutionModule(
|
210 |
+
*convolution_layer_args) if use_cnn_module else None,
|
211 |
+
dropout_rate,
|
212 |
+
normalize_before,
|
213 |
+
) for _ in range(num_blocks)
|
214 |
+
])
|
215 |
+
self.up_layer = Upsample1D(channels=512, out_channels=512, stride=2)
|
216 |
+
self.up_embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
|
217 |
+
input_size,
|
218 |
+
output_size,
|
219 |
+
dropout_rate,
|
220 |
+
COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
|
221 |
+
positional_dropout_rate),
|
222 |
+
)
|
223 |
+
self.up_encoders = torch.nn.ModuleList([
|
224 |
+
ConformerEncoderLayer(
|
225 |
+
output_size,
|
226 |
+
COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
|
227 |
+
*encoder_selfattn_layer_args),
|
228 |
+
PositionwiseFeedForward(*positionwise_layer_args),
|
229 |
+
PositionwiseFeedForward(
|
230 |
+
*positionwise_layer_args) if macaron_style else None,
|
231 |
+
ConvolutionModule(
|
232 |
+
*convolution_layer_args) if use_cnn_module else None,
|
233 |
+
dropout_rate,
|
234 |
+
normalize_before,
|
235 |
+
) for _ in range(4)
|
236 |
+
])
|
237 |
+
|
238 |
+
def output_size(self) -> int:
|
239 |
+
return self._output_size
|
240 |
+
|
241 |
+
def forward(
|
242 |
+
self,
|
243 |
+
xs: torch.Tensor,
|
244 |
+
xs_lens: torch.Tensor,
|
245 |
+
decoding_chunk_size: int = 0,
|
246 |
+
num_decoding_left_chunks: int = -1,
|
247 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
248 |
+
"""Embed positions in tensor.
|
249 |
+
|
250 |
+
Args:
|
251 |
+
xs: padded input tensor (B, T, D)
|
252 |
+
xs_lens: input length (B)
|
253 |
+
decoding_chunk_size: decoding chunk size for dynamic chunk
|
254 |
+
0: default for training, use random dynamic chunk.
|
255 |
+
<0: for decoding, use full chunk.
|
256 |
+
>0: for decoding, use fixed chunk size as set.
|
257 |
+
num_decoding_left_chunks: number of left chunks, this is for decoding,
|
258 |
+
the chunk size is decoding_chunk_size.
|
259 |
+
>=0: use num_decoding_left_chunks
|
260 |
+
<0: use all left chunks
|
261 |
+
Returns:
|
262 |
+
encoder output tensor xs, and subsampled masks
|
263 |
+
xs: padded output tensor (B, T' ~= T/subsample_rate, D)
|
264 |
+
masks: torch.Tensor batch padding mask after subsample
|
265 |
+
(B, 1, T' ~= T/subsample_rate)
|
266 |
+
NOTE(xcsong):
|
267 |
+
We pass the `__call__` method of the modules instead of `forward` to the
|
268 |
+
checkpointing API because `__call__` attaches all the hooks of the module.
|
269 |
+
https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
|
270 |
+
"""
|
271 |
+
T = xs.size(1)
|
272 |
+
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
|
273 |
+
if self.global_cmvn is not None:
|
274 |
+
xs = self.global_cmvn(xs)
|
275 |
+
xs, pos_emb, masks = self.embed(xs, masks)
|
276 |
+
mask_pad = masks # (B, 1, T/subsample_rate)
|
277 |
+
chunk_masks = add_optional_chunk_mask(xs, masks,
|
278 |
+
self.use_dynamic_chunk,
|
279 |
+
self.use_dynamic_left_chunk,
|
280 |
+
decoding_chunk_size,
|
281 |
+
self.static_chunk_size,
|
282 |
+
num_decoding_left_chunks)
|
283 |
+
# lookahead + conformer encoder
|
284 |
+
xs = self.pre_lookahead_layer(xs)
|
285 |
+
xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
|
286 |
+
|
287 |
+
# upsample + conformer encoder
|
288 |
+
xs = xs.transpose(1, 2).contiguous()
|
289 |
+
xs, xs_lens = self.up_layer(xs, xs_lens)
|
290 |
+
xs = xs.transpose(1, 2).contiguous()
|
291 |
+
T = xs.size(1)
|
292 |
+
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
|
293 |
+
xs, pos_emb, masks = self.up_embed(xs, masks)
|
294 |
+
mask_pad = masks # (B, 1, T/subsample_rate)
|
295 |
+
chunk_masks = add_optional_chunk_mask(xs, masks,
|
296 |
+
self.use_dynamic_chunk,
|
297 |
+
self.use_dynamic_left_chunk,
|
298 |
+
decoding_chunk_size,
|
299 |
+
self.static_chunk_size * self.up_layer.stride,
|
300 |
+
num_decoding_left_chunks)
|
301 |
+
xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad)
|
302 |
+
|
303 |
+
if self.normalize_before:
|
304 |
+
xs = self.after_norm(xs)
|
305 |
+
# Here we assume the mask is not changed in encoder layers, so just
|
306 |
+
# return the masks before encoder layers, and the masks will be used
|
307 |
+
# for cross attention with decoder later
|
308 |
+
return xs, masks
|
309 |
+
|
310 |
+
def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
|
311 |
+
pos_emb: torch.Tensor,
|
312 |
+
mask_pad: torch.Tensor) -> torch.Tensor:
|
313 |
+
for layer in self.encoders:
|
314 |
+
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
|
315 |
+
return xs
|
316 |
+
|
317 |
+
def forward_up_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
|
318 |
+
pos_emb: torch.Tensor,
|
319 |
+
mask_pad: torch.Tensor) -> torch.Tensor:
|
320 |
+
for layer in self.up_encoders:
|
321 |
+
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
|
322 |
+
return xs
|
cosyvoice/utils/__init__.py
ADDED
File without changes
|
cosyvoice/utils/class_utils.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright [2023-11-28] <[email protected], Xingchen Song>
|
2 |
+
# 2024 Alibaba Inc (authors: Xiang Lyu)
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
import torch
|
16 |
+
|
17 |
+
from cosyvoice.transformer.activation import Swish
|
18 |
+
from cosyvoice.transformer.subsampling import (
|
19 |
+
LinearNoSubsampling,
|
20 |
+
EmbedinigNoSubsampling,
|
21 |
+
Conv1dSubsampling2,
|
22 |
+
Conv2dSubsampling4,
|
23 |
+
Conv2dSubsampling6,
|
24 |
+
Conv2dSubsampling8,
|
25 |
+
)
|
26 |
+
from cosyvoice.transformer.embedding import (PositionalEncoding,
|
27 |
+
RelPositionalEncoding,
|
28 |
+
WhisperPositionalEncoding,
|
29 |
+
LearnablePositionalEncoding,
|
30 |
+
NoPositionalEncoding)
|
31 |
+
from cosyvoice.transformer.attention import (MultiHeadedAttention,
|
32 |
+
RelPositionMultiHeadedAttention)
|
33 |
+
from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding
|
34 |
+
from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling
|
35 |
+
|
36 |
+
|
37 |
+
COSYVOICE_ACTIVATION_CLASSES = {
|
38 |
+
"hardtanh": torch.nn.Hardtanh,
|
39 |
+
"tanh": torch.nn.Tanh,
|
40 |
+
"relu": torch.nn.ReLU,
|
41 |
+
"selu": torch.nn.SELU,
|
42 |
+
"swish": getattr(torch.nn, "SiLU", Swish),
|
43 |
+
"gelu": torch.nn.GELU,
|
44 |
+
}
|
45 |
+
|
46 |
+
COSYVOICE_SUBSAMPLE_CLASSES = {
|
47 |
+
"linear": LinearNoSubsampling,
|
48 |
+
"linear_legacy": LegacyLinearNoSubsampling,
|
49 |
+
"embed": EmbedinigNoSubsampling,
|
50 |
+
"conv1d2": Conv1dSubsampling2,
|
51 |
+
"conv2d": Conv2dSubsampling4,
|
52 |
+
"conv2d6": Conv2dSubsampling6,
|
53 |
+
"conv2d8": Conv2dSubsampling8,
|
54 |
+
'paraformer_dummy': torch.nn.Identity
|
55 |
+
}
|
56 |
+
|
57 |
+
COSYVOICE_EMB_CLASSES = {
|
58 |
+
"embed": PositionalEncoding,
|
59 |
+
"abs_pos": PositionalEncoding,
|
60 |
+
"rel_pos": RelPositionalEncoding,
|
61 |
+
"rel_pos_espnet": EspnetRelPositionalEncoding,
|
62 |
+
"no_pos": NoPositionalEncoding,
|
63 |
+
"abs_pos_whisper": WhisperPositionalEncoding,
|
64 |
+
"embed_learnable_pe": LearnablePositionalEncoding,
|
65 |
+
}
|
66 |
+
|
67 |
+
COSYVOICE_ATTENTION_CLASSES = {
|
68 |
+
"selfattn": MultiHeadedAttention,
|
69 |
+
"rel_selfattn": RelPositionMultiHeadedAttention,
|
70 |
+
}
|
cosyvoice/utils/common.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
|
2 |
+
# 2024 Alibaba Inc (authors: Xiang Lyu)
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# Modified from ESPnet(https://github.com/espnet/espnet)
|
16 |
+
"""Unility functions for Transformer."""
|
17 |
+
|
18 |
+
import random
|
19 |
+
from typing import List
|
20 |
+
|
21 |
+
import numpy as np
|
22 |
+
import torch
|
23 |
+
|
24 |
+
IGNORE_ID = -1
|
25 |
+
|
26 |
+
|
27 |
+
def pad_list(xs: List[torch.Tensor], pad_value: int):
|
28 |
+
"""Perform padding for the list of tensors.
|
29 |
+
|
30 |
+
Args:
|
31 |
+
xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
|
32 |
+
pad_value (float): Value for padding.
|
33 |
+
|
34 |
+
Returns:
|
35 |
+
Tensor: Padded tensor (B, Tmax, `*`).
|
36 |
+
|
37 |
+
Examples:
|
38 |
+
>>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
|
39 |
+
>>> x
|
40 |
+
[tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
|
41 |
+
>>> pad_list(x, 0)
|
42 |
+
tensor([[1., 1., 1., 1.],
|
43 |
+
[1., 1., 0., 0.],
|
44 |
+
[1., 0., 0., 0.]])
|
45 |
+
|
46 |
+
"""
|
47 |
+
max_len = max([len(item) for item in xs])
|
48 |
+
batchs = len(xs)
|
49 |
+
ndim = xs[0].ndim
|
50 |
+
if ndim == 1:
|
51 |
+
pad_res = torch.zeros(batchs,
|
52 |
+
max_len,
|
53 |
+
dtype=xs[0].dtype,
|
54 |
+
device=xs[0].device)
|
55 |
+
elif ndim == 2:
|
56 |
+
pad_res = torch.zeros(batchs,
|
57 |
+
max_len,
|
58 |
+
xs[0].shape[1],
|
59 |
+
dtype=xs[0].dtype,
|
60 |
+
device=xs[0].device)
|
61 |
+
elif ndim == 3:
|
62 |
+
pad_res = torch.zeros(batchs,
|
63 |
+
max_len,
|
64 |
+
xs[0].shape[1],
|
65 |
+
xs[0].shape[2],
|
66 |
+
dtype=xs[0].dtype,
|
67 |
+
device=xs[0].device)
|
68 |
+
else:
|
69 |
+
raise ValueError(f"Unsupported ndim: {ndim}")
|
70 |
+
pad_res.fill_(pad_value)
|
71 |
+
for i in range(batchs):
|
72 |
+
pad_res[i, :len(xs[i])] = xs[i]
|
73 |
+
return pad_res
|
74 |
+
|
75 |
+
|
76 |
+
def th_accuracy(pad_outputs: torch.Tensor, pad_targets: torch.Tensor,
|
77 |
+
ignore_label: int) -> torch.Tensor:
|
78 |
+
"""Calculate accuracy.
|
79 |
+
|
80 |
+
Args:
|
81 |
+
pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
|
82 |
+
pad_targets (LongTensor): Target label tensors (B, Lmax).
|
83 |
+
ignore_label (int): Ignore label id.
|
84 |
+
|
85 |
+
Returns:
|
86 |
+
torch.Tensor: Accuracy value (0.0 - 1.0).
|
87 |
+
|
88 |
+
"""
|
89 |
+
pad_pred = pad_outputs.view(pad_targets.size(0), pad_targets.size(1),
|
90 |
+
pad_outputs.size(1)).argmax(2)
|
91 |
+
mask = pad_targets != ignore_label
|
92 |
+
numerator = torch.sum(
|
93 |
+
pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
|
94 |
+
denominator = torch.sum(mask)
|
95 |
+
return (numerator / denominator).detach()
|
96 |
+
|
97 |
+
|
98 |
+
def get_padding(kernel_size, dilation=1):
|
99 |
+
return int((kernel_size * dilation - dilation) / 2)
|
100 |
+
|
101 |
+
|
102 |
+
def init_weights(m, mean=0.0, std=0.01):
|
103 |
+
classname = m.__class__.__name__
|
104 |
+
if classname.find("Conv") != -1:
|
105 |
+
m.weight.data.normal_(mean, std)
|
106 |
+
|
107 |
+
|
108 |
+
# Repetition Aware Sampling in VALL-E 2
|
109 |
+
def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, top_k=25, win_size=10, tau_r=0.1):
|
110 |
+
top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
|
111 |
+
rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids).sum().item()
|
112 |
+
if rep_num >= win_size * tau_r:
|
113 |
+
top_ids = random_sampling(weighted_scores, decoded_tokens, sampling)
|
114 |
+
return top_ids
|
115 |
+
|
116 |
+
|
117 |
+
def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
|
118 |
+
prob, indices = [], []
|
119 |
+
cum_prob = 0.0
|
120 |
+
sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True)
|
121 |
+
for i in range(len(sorted_idx)):
|
122 |
+
# sampling both top-p and numbers.
|
123 |
+
if cum_prob < top_p and len(prob) < top_k:
|
124 |
+
cum_prob += sorted_value[i]
|
125 |
+
prob.append(sorted_value[i])
|
126 |
+
indices.append(sorted_idx[i])
|
127 |
+
else:
|
128 |
+
break
|
129 |
+
prob = torch.tensor(prob).to(weighted_scores)
|
130 |
+
indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device)
|
131 |
+
top_ids = indices[prob.multinomial(1, replacement=True)]
|
132 |
+
return top_ids
|
133 |
+
|
134 |
+
|
135 |
+
def random_sampling(weighted_scores, decoded_tokens, sampling):
|
136 |
+
top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
|
137 |
+
return top_ids
|
138 |
+
|
139 |
+
|
140 |
+
def fade_in_out(fade_in_mel, fade_out_mel, window):
|
141 |
+
device = fade_in_mel.device
|
142 |
+
fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
|
143 |
+
mel_overlap_len = int(window.shape[0] / 2)
|
144 |
+
if fade_in_mel.device == torch.device('cpu'):
|
145 |
+
fade_in_mel = fade_in_mel.clone()
|
146 |
+
fade_in_mel[..., :mel_overlap_len] = fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \
|
147 |
+
fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
|
148 |
+
return fade_in_mel.to(device)
|
149 |
+
|
150 |
+
|
151 |
+
def set_all_random_seed(seed):
|
152 |
+
random.seed(seed)
|
153 |
+
np.random.seed(seed)
|
154 |
+
torch.manual_seed(seed)
|
155 |
+
torch.cuda.manual_seed_all(seed)
|
156 |
+
|
157 |
+
|
158 |
+
def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
159 |
+
assert mask.dtype == torch.bool
|
160 |
+
assert dtype in [torch.float32, torch.bfloat16, torch.float16]
|
161 |
+
mask = mask.to(dtype)
|
162 |
+
# attention mask bias
|
163 |
+
# NOTE(Mddct): torch.finfo jit issues
|
164 |
+
# chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
|
165 |
+
mask = (1.0 - mask) * torch.finfo(dtype).min
|
166 |
+
return mask
|
cosyvoice/utils/executor.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
|
2 |
+
# 2024 Alibaba Inc (authors: Xiang Lyu)
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import logging
|
17 |
+
from contextlib import nullcontext
|
18 |
+
import os
|
19 |
+
|
20 |
+
import torch
|
21 |
+
import torch.distributed as dist
|
22 |
+
|
23 |
+
from cosyvoice.utils.train_utils import update_parameter_and_lr, log_per_step, log_per_save, batch_forward, batch_backward, save_model, cosyvoice_join
|
24 |
+
|
25 |
+
|
26 |
+
class Executor:
|
27 |
+
|
28 |
+
def __init__(self, gan: bool = False):
|
29 |
+
self.gan = gan
|
30 |
+
self.step = 0
|
31 |
+
self.epoch = 0
|
32 |
+
self.rank = int(os.environ.get('RANK', 0))
|
33 |
+
self.device = torch.device('cuda:{}'.format(self.rank))
|
34 |
+
|
35 |
+
def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join):
|
36 |
+
''' Train one epoch
|
37 |
+
'''
|
38 |
+
|
39 |
+
lr = optimizer.param_groups[0]['lr']
|
40 |
+
logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank))
|
41 |
+
logging.info('using accumulate grad, new batch size is {} times'
|
42 |
+
' larger than before'.format(info_dict['accum_grad']))
|
43 |
+
# A context manager to be used in conjunction with an instance of
|
44 |
+
# torch.nn.parallel.DistributedDataParallel to be able to train
|
45 |
+
# with uneven inputs across participating processes.
|
46 |
+
model.train()
|
47 |
+
model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext
|
48 |
+
with model_context():
|
49 |
+
for batch_idx, batch_dict in enumerate(train_data_loader):
|
50 |
+
info_dict["tag"] = "TRAIN"
|
51 |
+
info_dict["step"] = self.step
|
52 |
+
info_dict["epoch"] = self.epoch
|
53 |
+
info_dict["batch_idx"] = batch_idx
|
54 |
+
if cosyvoice_join(group_join, info_dict):
|
55 |
+
break
|
56 |
+
|
57 |
+
# Disable gradient synchronizations across DDP processes.
|
58 |
+
# Within this context, gradients will be accumulated on module
|
59 |
+
# variables, which will later be synchronized.
|
60 |
+
if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0:
|
61 |
+
context = model.no_sync
|
62 |
+
# Used for single gpu training and DDP gradient synchronization
|
63 |
+
# processes.
|
64 |
+
else:
|
65 |
+
context = nullcontext
|
66 |
+
|
67 |
+
with context():
|
68 |
+
info_dict = batch_forward(model, batch_dict, scaler, info_dict)
|
69 |
+
info_dict = batch_backward(model, scaler, info_dict)
|
70 |
+
|
71 |
+
info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict)
|
72 |
+
log_per_step(writer, info_dict)
|
73 |
+
# NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
|
74 |
+
if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \
|
75 |
+
(batch_idx + 1) % info_dict["accum_grad"] == 0:
|
76 |
+
dist.barrier()
|
77 |
+
self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False)
|
78 |
+
model.train()
|
79 |
+
if (batch_idx + 1) % info_dict["accum_grad"] == 0:
|
80 |
+
self.step += 1
|
81 |
+
dist.barrier()
|
82 |
+
self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True)
|
83 |
+
|
84 |
+
def train_one_epoc_gan(self, model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
|
85 |
+
writer, info_dict, scaler, group_join):
|
86 |
+
''' Train one epoch
|
87 |
+
'''
|
88 |
+
|
89 |
+
lr = optimizer.param_groups[0]['lr']
|
90 |
+
logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank))
|
91 |
+
logging.info('using accumulate grad, new batch size is {} times'
|
92 |
+
' larger than before'.format(info_dict['accum_grad']))
|
93 |
+
# A context manager to be used in conjunction with an instance of
|
94 |
+
# torch.nn.parallel.DistributedDataParallel to be able to train
|
95 |
+
# with uneven inputs across participating processes.
|
96 |
+
model.train()
|
97 |
+
model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext
|
98 |
+
with model_context():
|
99 |
+
for batch_idx, batch_dict in enumerate(train_data_loader):
|
100 |
+
info_dict["tag"] = "TRAIN"
|
101 |
+
info_dict["step"] = self.step
|
102 |
+
info_dict["epoch"] = self.epoch
|
103 |
+
info_dict["batch_idx"] = batch_idx
|
104 |
+
if cosyvoice_join(group_join, info_dict):
|
105 |
+
break
|
106 |
+
|
107 |
+
# Disable gradient synchronizations across DDP processes.
|
108 |
+
# Within this context, gradients will be accumulated on module
|
109 |
+
# variables, which will later be synchronized.
|
110 |
+
if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0:
|
111 |
+
context = model.no_sync
|
112 |
+
# Used for single gpu training and DDP gradient synchronization
|
113 |
+
# processes.
|
114 |
+
else:
|
115 |
+
context = nullcontext
|
116 |
+
|
117 |
+
with context():
|
118 |
+
batch_dict['turn'] = 'discriminator'
|
119 |
+
info_dict = batch_forward(model, batch_dict, scaler, info_dict)
|
120 |
+
info_dict = batch_backward(model, scaler, info_dict)
|
121 |
+
info_dict = update_parameter_and_lr(model, optimizer_d, scheduler_d, scaler, info_dict)
|
122 |
+
optimizer.zero_grad()
|
123 |
+
log_per_step(writer, info_dict)
|
124 |
+
with context():
|
125 |
+
batch_dict['turn'] = 'generator'
|
126 |
+
info_dict = batch_forward(model, batch_dict, scaler, info_dict)
|
127 |
+
info_dict = batch_backward(model, scaler, info_dict)
|
128 |
+
info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict)
|
129 |
+
optimizer_d.zero_grad()
|
130 |
+
log_per_step(writer, info_dict)
|
131 |
+
# NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
|
132 |
+
if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \
|
133 |
+
(batch_idx + 1) % info_dict["accum_grad"] == 0:
|
134 |
+
dist.barrier()
|
135 |
+
self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False)
|
136 |
+
model.train()
|
137 |
+
if (batch_idx + 1) % info_dict["accum_grad"] == 0:
|
138 |
+
self.step += 1
|
139 |
+
dist.barrier()
|
140 |
+
self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True)
|
141 |
+
|
142 |
+
@torch.inference_mode()
|
143 |
+
def cv(self, model, cv_data_loader, writer, info_dict, on_batch_end=True):
|
144 |
+
''' Cross validation on
|
145 |
+
'''
|
146 |
+
logging.info('Epoch {} Step {} on_batch_end {} CV rank {}'.format(self.epoch, self.step + 1, on_batch_end, self.rank))
|
147 |
+
model.eval()
|
148 |
+
total_num_utts, total_loss_dict = 0, {} # avoid division by 0
|
149 |
+
for batch_idx, batch_dict in enumerate(cv_data_loader):
|
150 |
+
info_dict["tag"] = "CV"
|
151 |
+
info_dict["step"] = self.step
|
152 |
+
info_dict["epoch"] = self.epoch
|
153 |
+
info_dict["batch_idx"] = batch_idx
|
154 |
+
|
155 |
+
num_utts = len(batch_dict["utts"])
|
156 |
+
total_num_utts += num_utts
|
157 |
+
|
158 |
+
if self.gan is True:
|
159 |
+
batch_dict['turn'] = 'generator'
|
160 |
+
info_dict = batch_forward(model, batch_dict, None, info_dict)
|
161 |
+
|
162 |
+
for k, v in info_dict['loss_dict'].items():
|
163 |
+
if k not in total_loss_dict:
|
164 |
+
total_loss_dict[k] = []
|
165 |
+
total_loss_dict[k].append(v.item() * num_utts)
|
166 |
+
log_per_step(None, info_dict)
|
167 |
+
for k, v in total_loss_dict.items():
|
168 |
+
total_loss_dict[k] = sum(v) / total_num_utts
|
169 |
+
info_dict['loss_dict'] = total_loss_dict
|
170 |
+
log_per_save(writer, info_dict)
|
171 |
+
model_name = 'epoch_{}_whole'.format(self.epoch) if on_batch_end else 'epoch_{}_step_{}'.format(self.epoch, self.step + 1)
|
172 |
+
save_model(model, model_name, info_dict)
|
cosyvoice/utils/file_utils.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
|
2 |
+
# 2024 Alibaba Inc (authors: Xiang Lyu)
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import json
|
17 |
+
import torchaudio
|
18 |
+
import logging
|
19 |
+
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
20 |
+
logging.basicConfig(level=logging.DEBUG,
|
21 |
+
format='%(asctime)s %(levelname)s %(message)s')
|
22 |
+
|
23 |
+
|
24 |
+
def read_lists(list_file):
|
25 |
+
lists = []
|
26 |
+
with open(list_file, 'r', encoding='utf8') as fin:
|
27 |
+
for line in fin:
|
28 |
+
lists.append(line.strip())
|
29 |
+
return lists
|
30 |
+
|
31 |
+
|
32 |
+
def read_json_lists(list_file):
|
33 |
+
lists = read_lists(list_file)
|
34 |
+
results = {}
|
35 |
+
for fn in lists:
|
36 |
+
with open(fn, 'r', encoding='utf8') as fin:
|
37 |
+
results.update(json.load(fin))
|
38 |
+
return results
|
39 |
+
|
40 |
+
|
41 |
+
def load_wav(wav, target_sr):
|
42 |
+
# speech, sample_rate = torchaudio.load(wav)
|
43 |
+
# speech = speech.mean(dim=0, keepdim=True)
|
44 |
+
# if sample_rate != target_sr:
|
45 |
+
# assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
|
46 |
+
# speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
|
47 |
+
|
48 |
+
import librosa, torch
|
49 |
+
speech, _ = librosa.load(path=wav, sr=target_sr)
|
50 |
+
speech = torch.from_numpy(speech).unsqueeze(dim=0)
|
51 |
+
return speech
|
cosyvoice/utils/frontend_utils.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import re
|
16 |
+
chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')
|
17 |
+
|
18 |
+
|
19 |
+
# whether contain chinese character
|
20 |
+
def contains_chinese(text):
|
21 |
+
return bool(chinese_char_pattern.search(text))
|
22 |
+
|
23 |
+
|
24 |
+
# replace special symbol
|
25 |
+
def replace_corner_mark(text):
|
26 |
+
text = text.replace('²', '平方')
|
27 |
+
text = text.replace('³', '立方')
|
28 |
+
return text
|
29 |
+
|
30 |
+
|
31 |
+
# remove meaningless symbol
|
32 |
+
def remove_bracket(text):
|
33 |
+
text = text.replace('(', '').replace(')', '')
|
34 |
+
text = text.replace('【', '').replace('】', '')
|
35 |
+
text = text.replace('`', '').replace('`', '')
|
36 |
+
text = text.replace("——", " ")
|
37 |
+
return text
|
38 |
+
|
39 |
+
|
40 |
+
# spell Arabic numerals
|
41 |
+
def spell_out_number(text: str, inflect_parser):
|
42 |
+
new_text = []
|
43 |
+
st = None
|
44 |
+
for i, c in enumerate(text):
|
45 |
+
if not c.isdigit():
|
46 |
+
if st is not None:
|
47 |
+
num_str = inflect_parser.number_to_words(text[st: i])
|
48 |
+
new_text.append(num_str)
|
49 |
+
st = None
|
50 |
+
new_text.append(c)
|
51 |
+
else:
|
52 |
+
if st is None:
|
53 |
+
st = i
|
54 |
+
if st is not None and st < len(text):
|
55 |
+
num_str = inflect_parser.number_to_words(text[st:])
|
56 |
+
new_text.append(num_str)
|
57 |
+
return ''.join(new_text)
|
58 |
+
|
59 |
+
|
60 |
+
# split paragrah logic:
|
61 |
+
# 1. per sentence max len token_max_n, min len token_min_n, merge if last sentence len less than merge_len
|
62 |
+
# 2. cal sentence len according to lang
|
63 |
+
# 3. split sentence according to puncatation
|
64 |
+
def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False):
|
65 |
+
def calc_utt_length(_text: str):
|
66 |
+
if lang == "zh":
|
67 |
+
return len(_text)
|
68 |
+
else:
|
69 |
+
return len(tokenize(_text))
|
70 |
+
|
71 |
+
def should_merge(_text: str):
|
72 |
+
if lang == "zh":
|
73 |
+
return len(_text) < merge_len
|
74 |
+
else:
|
75 |
+
return len(tokenize(_text)) < merge_len
|
76 |
+
|
77 |
+
if lang == "zh":
|
78 |
+
pounc = ['。', '?', '!', ';', ':', '、', '.', '?', '!', ';']
|
79 |
+
else:
|
80 |
+
pounc = ['.', '?', '!', ';', ':']
|
81 |
+
if comma_split:
|
82 |
+
pounc.extend([',', ','])
|
83 |
+
|
84 |
+
if text[-1] not in pounc:
|
85 |
+
if lang == "zh":
|
86 |
+
text += "。"
|
87 |
+
else:
|
88 |
+
text += "."
|
89 |
+
|
90 |
+
st = 0
|
91 |
+
utts = []
|
92 |
+
for i, c in enumerate(text):
|
93 |
+
if c in pounc:
|
94 |
+
if len(text[st: i]) > 0:
|
95 |
+
utts.append(text[st: i] + c)
|
96 |
+
if i + 1 < len(text) and text[i + 1] in ['"', '”']:
|
97 |
+
tmp = utts.pop(-1)
|
98 |
+
utts.append(tmp + text[i + 1])
|
99 |
+
st = i + 2
|
100 |
+
else:
|
101 |
+
st = i + 1
|
102 |
+
|
103 |
+
final_utts = []
|
104 |
+
cur_utt = ""
|
105 |
+
for utt in utts:
|
106 |
+
if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n:
|
107 |
+
final_utts.append(cur_utt)
|
108 |
+
cur_utt = ""
|
109 |
+
cur_utt = cur_utt + utt
|
110 |
+
if len(cur_utt) > 0:
|
111 |
+
if should_merge(cur_utt) and len(final_utts) != 0:
|
112 |
+
final_utts[-1] = final_utts[-1] + cur_utt
|
113 |
+
else:
|
114 |
+
final_utts.append(cur_utt)
|
115 |
+
|
116 |
+
return final_utts
|
117 |
+
|
118 |
+
|
119 |
+
# remove blank between chinese character
|
120 |
+
def replace_blank(text: str):
|
121 |
+
out_str = []
|
122 |
+
for i, c in enumerate(text):
|
123 |
+
if c == " ":
|
124 |
+
if ((text[i + 1].isascii() and text[i + 1] != " ") and
|
125 |
+
(text[i - 1].isascii() and text[i - 1] != " ")):
|
126 |
+
out_str.append(c)
|
127 |
+
else:
|
128 |
+
out_str.append(c)
|
129 |
+
return "".join(out_str)
|