Update requirements.txt

#19
opened by SWivid
This view is limited to 50 files because it contains too many changes.
.DS_Store DELETED
Binary file (6.15 kB)
 
.gitattributes DELETED
@@ -1,35 +0,0 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
.github/workflows/pre-commit.yaml DELETED
@@ -1,14 +0,0 @@
- name: pre-commit
-
- on:
-   pull_request:
-   push:
-     branches: [main]
-
- jobs:
-   pre-commit:
-     runs-on: ubuntu-latest
-     steps:
-       - uses: actions/checkout@v3
-       - uses: actions/setup-python@v3
-       - uses: pre-commit/[email protected]
.github/workflows/sync-hf.yaml DELETED
@@ -1,18 +0,0 @@
- name: Sync to HF Space
-
- on:
-   push:
-     branches:
-       - main
-
- jobs:
-   trigger_curl:
-     runs-on: ubuntu-latest
-
-     steps:
-       - name: Send cURL POST request
-         run: |
-           curl -X POST https://mrfakename-sync-f5.hf.space/gradio_api/call/refresh \
-             -s \
-             -H "Content-Type: application/json" \
-             -d "{\"data\": [\"${{ secrets.REFRESH_PASSWORD }}\"]}"
.pre-commit-config.yaml DELETED
@@ -1,14 +0,0 @@
- repos:
-   - repo: https://github.com/astral-sh/ruff-pre-commit
-     # Ruff version.
-     rev: v0.7.0
-     hooks:
-       # Run the linter.
-       - id: ruff
-         args: [--fix]
-       # Run the formatter.
-       - id: ruff-format
-   - repo: https://github.com/pre-commit/pre-commit-hooks
-     rev: v2.3.0
-     hooks:
-       - id: check-yaml
Dockerfile DELETED
@@ -1,25 +0,0 @@
- FROM pytorch/pytorch:2.4.0-cuda12.4-cudnn9-devel
-
- USER root
-
- ARG DEBIAN_FRONTEND=noninteractive
-
- LABEL github_repo="https://github.com/SWivid/F5-TTS"
-
- RUN set -x \
-     && apt-get update \
-     && apt-get -y install wget curl man git less openssl libssl-dev unzip unar build-essential aria2 tmux vim \
-     && apt-get install -y openssh-server sox libsox-fmt-all libsox-fmt-mp3 libsndfile1-dev ffmpeg \
-     && rm -rf /var/lib/apt/lists/* \
-     && apt-get clean
-
- WORKDIR /workspace
-
- RUN git clone https://github.com/SWivid/F5-TTS.git \
-     && cd F5-TTS \
-     && pip install --no-cache-dir -r requirements.txt \
-     && pip install --no-cache-dir -r requirements_eval.txt
-
- ENV SHELL=/bin/bash
-
- WORKDIR /workspace/F5-TTS
README.md CHANGED
@@ -1,13 +1,13 @@
  ---
- title: F5-TTS
+ title: E2/F5 TTS
  emoji: 🗣️
  colorFrom: green
  colorTo: green
  sdk: gradio
  app_file: app.py
  pinned: true
- short_description: 'F5-TTS & E2-TTS: Zero-Shot Voice Cloning (Unofficial Demo)'
+ short_description: 'E2-TTS & F5-TTS: Zero-Shot Voice Cloning (Unofficial Demo)'
- sdk_version: 4.44.1
+ sdk_version: 5.0.2
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
README_REPO.md DELETED
@@ -1,269 +0,0 @@
1
- # F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching
2
-
3
- [![python](https://img.shields.io/badge/Python-3.10-brightgreen)](https://github.com/SWivid/F5-TTS)
4
- [![arXiv](https://img.shields.io/badge/arXiv-2410.06885-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2410.06885)
5
- [![demo](https://img.shields.io/badge/GitHub-Demo%20page-orange.svg)](https://swivid.github.io/F5-TTS/)
6
- [![hfspace](https://img.shields.io/badge/🤗-Space%20demo-yellow)](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
7
- [![msspace](https://img.shields.io/badge/🤖-Space%20demo-blue)](https://modelscope.cn/studios/modelscope/E2-F5-TTS)
8
- [![lab](https://img.shields.io/badge/X--LANCE-Lab-grey?labelColor=lightgrey)](https://x-lance.sjtu.edu.cn/)
9
- <img src="https://github.com/user-attachments/assets/12d7749c-071a-427c-81bf-b87b91def670" alt="Watermark" style="width: 40px; height: auto">
10
-
11
- **F5-TTS**: Diffusion Transformer with ConvNeXt V2, faster trained and inference.
12
-
13
- **E2 TTS**: Flat-UNet Transformer, closest reproduction from [paper](https://arxiv.org/abs/2406.18009).
14
-
15
- **Sway Sampling**: Inference-time flow step sampling strategy, greatly improves performance
16
-
17
- ### Thanks to all the contributors !
18
-
19
- ## Installation
20
-
21
- Clone the repository:
22
-
23
- ```bash
24
- git clone https://github.com/SWivid/F5-TTS.git
25
- cd F5-TTS
26
- ```
27
-
28
- Install torch with your CUDA version, e.g. :
29
-
30
- ```bash
31
- pip install torch==2.3.0+cu118 --extra-index-url https://download.pytorch.org/whl/cu118
32
- pip install torchaudio==2.3.0+cu118 --extra-index-url https://download.pytorch.org/whl/cu118
33
- ```
34
-
35
- Install other packages:
36
-
37
- ```bash
38
- pip install -r requirements.txt
39
- ```
40
-
41
- **[Optional]**: We provide [Dockerfile](https://github.com/SWivid/F5-TTS/blob/main/Dockerfile) and you can use the following command to build it.
42
- ```bash
43
- docker build -t f5tts:v1 .
44
- ```
45
-
46
- ### Development
47
-
48
- When making a pull request, please use pre-commit to ensure code quality:
49
-
50
- ```bash
51
- pip install pre-commit
52
- pre-commit install
53
- ```
54
-
55
- This will run linters and formatters automatically before each commit.
56
-
57
- Manually run using:
58
-
59
- ```bash
60
- pre-commit run --all-files
61
- ```
62
-
63
- Note: Some model components have linting exceptions for E722 to accommodate tensor notation
64
-
65
-
66
- ## Prepare Dataset
67
-
68
- Example data processing scripts for Emilia and Wenetspeech4TTS, and you may tailor your own one along with a Dataset class in `model/dataset.py`.
69
-
70
- ```bash
71
- # prepare custom dataset up to your need
72
- # download corresponding dataset first, and fill in the path in scripts
73
-
74
- # Prepare the Emilia dataset
75
- python scripts/prepare_emilia.py
76
-
77
- # Prepare the Wenetspeech4TTS dataset
78
- python scripts/prepare_wenetspeech4tts.py
79
- ```
80
-
81
- ## Training & Finetuning
82
-
83
- Once your datasets are prepared, you can start the training process.
84
-
85
- ```bash
86
- # setup accelerate config, e.g. use multi-gpu ddp, fp16
87
- # will be to: ~/.cache/huggingface/accelerate/default_config.yaml
88
- accelerate config
89
- accelerate launch train.py
90
- ```
91
- An initial guidance on Finetuning [#57](https://github.com/SWivid/F5-TTS/discussions/57).
92
-
93
- Gradio UI finetuning with `finetune_gradio.py` see [#143](https://github.com/SWivid/F5-TTS/discussions/143).
94
-
95
- ### Wandb Logging
96
-
97
- By default, the training script does NOT use logging (assuming you didn't manually log in using `wandb login`).
98
-
99
- To turn on wandb logging, you can either:
100
-
101
- 1. Manually login with `wandb login`: Learn more [here](https://docs.wandb.ai/ref/cli/wandb-login)
102
- 2. Automatically login programmatically by setting an environment variable: Get an API KEY at https://wandb.ai/site/ and set the environment variable as follows:
103
-
104
- On Mac & Linux:
105
-
106
- ```
107
- export WANDB_API_KEY=<YOUR WANDB API KEY>
108
- ```
109
-
110
- On Windows:
111
-
112
- ```
113
- set WANDB_API_KEY=<YOUR WANDB API KEY>
114
- ```
115
- Moreover, if you couldn't access Wandb and want to log metrics offline, you can the environment variable as follows:
116
-
117
- ```
118
- export WANDB_MODE=offline
119
- ```
120
-
121
- ## Inference
122
-
123
- The pretrained model checkpoints can be reached at [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS) and [🤖 Model Scope](https://www.modelscope.cn/models/SWivid/F5-TTS_Emilia-ZH-EN), or automatically downloaded with `inference-cli` and `gradio_app`.
124
-
125
- Currently support 30s for a single generation, which is the **TOTAL** length of prompt audio and the generated. Batch inference with chunks is supported by `inference-cli` and `gradio_app`.
126
- - To avoid possible inference failures, make sure you have seen through the following instructions.
127
- - A longer prompt audio allows shorter generated output. The part longer than 30s cannot be generated properly. Consider using a prompt audio <15s.
128
- - Uppercased letters will be uttered letter by letter, so use lowercased letters for normal words.
129
- - Add some spaces (blank: " ") or punctuations (e.g. "," ".") to explicitly introduce some pauses. If first few words skipped in code-switched generation (cuz different speed with different languages), this might help.
130
-
131
- ### CLI Inference
132
-
133
- Either you can specify everything in `inference-cli.toml` or override with flags. Leave `--ref_text ""` will have ASR model transcribe the reference audio automatically (use extra GPU memory). If encounter network error, consider use local ckpt, just set `ckpt_file` in `inference-cli.py`
134
-
135
- for change model use `--ckpt_file` to specify the model you want to load,
136
- for change vocab.txt use `--vocab_file` to provide your vocab.txt file.
137
-
138
- ```bash
139
- python inference-cli.py \
140
- --model "F5-TTS" \
141
- --ref_audio "tests/ref_audio/test_en_1_ref_short.wav" \
142
- --ref_text "Some call me nature, others call me mother nature." \
143
- --gen_text "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences."
144
-
145
- python inference-cli.py \
146
- --model "E2-TTS" \
147
- --ref_audio "tests/ref_audio/test_zh_1_ref_short.wav" \
148
- --ref_text "对,这就是我,万人敬仰的太乙真人。" \
149
- --gen_text "突然,身边一阵笑声。我看着他们,意气风发地挺直了胸膛,甩了甩那稍显肉感的双臂,轻笑道,我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?"
150
-
151
- # Multi voice
152
- python inference-cli.py -c samples/story.toml
153
- ```
154
-
155
- ### Gradio App
156
- Currently supported features:
157
- - Chunk inference
158
- - Podcast Generation
159
- - Multiple Speech-Type Generation
160
-
161
- You can launch a Gradio app (web interface) to launch a GUI for inference (will load ckpt from Huggingface, you may also use local file in `gradio_app.py`). Currently load ASR model, F5-TTS and E2 TTS all in once, thus use more GPU memory than `inference-cli`.
162
-
163
- ```bash
164
- python gradio_app.py
165
- ```
166
-
167
- You can specify the port/host:
168
-
169
- ```bash
170
- python gradio_app.py --port 7860 --host 0.0.0.0
171
- ```
172
-
173
- Or launch a share link:
174
-
175
- ```bash
176
- python gradio_app.py --share
177
- ```
178
-
179
- ### Speech Editing
180
-
181
- To test speech editing capabilities, use the following command.
182
-
183
- ```bash
184
- python speech_edit.py
185
- ```
186
-
187
- ## Evaluation
188
-
189
- ### Prepare Test Datasets
190
-
191
- 1. Seed-TTS test set: Download from [seed-tts-eval](https://github.com/BytedanceSpeech/seed-tts-eval).
192
- 2. LibriSpeech test-clean: Download from [OpenSLR](http://www.openslr.org/12/).
193
- 3. Unzip the downloaded datasets and place them in the data/ directory.
194
- 4. Update the path for the test-clean data in `scripts/eval_infer_batch.py`
195
- 5. Our filtered LibriSpeech-PC 4-10s subset is already under data/ in this repo
196
-
197
- ### Batch Inference for Test Set
198
-
199
- To run batch inference for evaluations, execute the following commands:
200
-
201
- ```bash
202
- # batch inference for evaluations
203
- accelerate config # if not set before
204
- bash scripts/eval_infer_batch.sh
205
- ```
206
-
207
- ### Download Evaluation Model Checkpoints
208
-
209
- 1. Chinese ASR Model: [Paraformer-zh](https://huggingface.co/funasr/paraformer-zh)
210
- 2. English ASR Model: [Faster-Whisper](https://huggingface.co/Systran/faster-whisper-large-v3)
211
- 3. WavLM Model: Download from [Google Drive](https://drive.google.com/file/d/1-aE1NfzpRCLxA4GUxX9ITI3F9LlbtEGP/view).
212
-
213
- ### Objective Evaluation
214
-
215
- Install packages for evaluation:
216
-
217
- ```bash
218
- pip install -r requirements_eval.txt
219
- ```
220
-
221
- **Some Notes**
222
-
223
- For faster-whisper with CUDA 11:
224
-
225
- ```bash
226
- pip install --force-reinstall ctranslate2==3.24.0
227
- ```
228
-
229
- (Recommended) To avoid possible ASR failures, such as abnormal repetitions in output:
230
-
231
- ```bash
232
- pip install faster-whisper==0.10.1
233
- ```
234
-
235
- Update the path with your batch-inferenced results, and carry out WER / SIM evaluations:
236
- ```bash
237
- # Evaluation for Seed-TTS test set
238
- python scripts/eval_seedtts_testset.py
239
-
240
- # Evaluation for LibriSpeech-PC test-clean (cross-sentence)
241
- python scripts/eval_librispeech_test_clean.py
242
- ```
243
-
244
- ## Acknowledgements
245
-
246
- - [E2-TTS](https://arxiv.org/abs/2406.18009) brilliant work, simple and effective
247
- - [Emilia](https://arxiv.org/abs/2407.05361), [WenetSpeech4TTS](https://arxiv.org/abs/2406.05763) valuable datasets
248
- - [lucidrains](https://github.com/lucidrains) initial CFM structure with also [bfs18](https://github.com/bfs18) for discussion
249
- - [SD3](https://arxiv.org/abs/2403.03206) & [Hugging Face diffusers](https://github.com/huggingface/diffusers) DiT and MMDiT code structure
250
- - [torchdiffeq](https://github.com/rtqichen/torchdiffeq) as ODE solver, [Vocos](https://huggingface.co/charactr/vocos-mel-24khz) as vocoder
251
- - [FunASR](https://github.com/modelscope/FunASR), [faster-whisper](https://github.com/SYSTRAN/faster-whisper), [UniSpeech](https://github.com/microsoft/UniSpeech) for evaluation tools
252
- - [ctc-forced-aligner](https://github.com/MahmoudAshraf97/ctc-forced-aligner) for speech edit test
253
- - [mrfakename](https://x.com/realmrfakename) huggingface space demo ~
254
- - [f5-tts-mlx](https://github.com/lucasnewman/f5-tts-mlx/tree/main) Implementation with MLX framework by [Lucas Newman](https://github.com/lucasnewman)
255
- - [F5-TTS-ONNX](https://github.com/DakeQQ/F5-TTS-ONNX) ONNX Runtime version by [DakeQQ](https://github.com/DakeQQ)
256
-
257
- ## Citation
258
- If our work and codebase is useful for you, please cite as:
259
- ```
260
- @article{chen-etal-2024-f5tts,
261
- title={F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching},
262
- author={Yushen Chen and Zhikang Niu and Ziyang Ma and Keqi Deng and Chunhui Wang and Jian Zhao and Kai Yu and Xie Chen},
263
- journal={arXiv preprint arXiv:2410.06885},
264
- year={2024},
265
- }
266
- ```
267
- ## License
268
-
269
- Our code is released under MIT License. The pre-trained models are licensed under the CC-BY-NC license due to the training data Emilia, which is an in-the-wild dataset. Sorry for any inconvenience this may cause.
api.py DELETED
@@ -1,132 +0,0 @@
- import soundfile as sf
- import torch
- import tqdm
- from cached_path import cached_path
-
- from model import DiT, UNetT
- from model.utils import save_spectrogram
-
- from model.utils_infer import load_vocoder, load_model, infer_process, remove_silence_for_generated_wav
- from model.utils import seed_everything
- import random
- import sys
-
-
- class F5TTS:
-     def __init__(
-         self,
-         model_type="F5-TTS",
-         ckpt_file="",
-         vocab_file="",
-         ode_method="euler",
-         use_ema=True,
-         local_path=None,
-         device=None,
-     ):
-         # Initialize parameters
-         self.final_wave = None
-         self.target_sample_rate = 24000
-         self.n_mel_channels = 100
-         self.hop_length = 256
-         self.target_rms = 0.1
-         self.seed = -1
-
-         # Set device
-         self.device = device or (
-             "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
-         )
-
-         # Load models
-         self.load_vocoder_model(local_path)
-         self.load_ema_model(model_type, ckpt_file, vocab_file, ode_method, use_ema)
-
-     def load_vocoder_model(self, local_path):
-         self.vocos = load_vocoder(local_path is not None, local_path, self.device)
-
-     def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema):
-         if model_type == "F5-TTS":
-             if not ckpt_file:
-                 ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
-             model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
-             model_cls = DiT
-         elif model_type == "E2-TTS":
-             if not ckpt_file:
-                 ckpt_file = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
-             model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
-             model_cls = UNetT
-         else:
-             raise ValueError(f"Unknown model type: {model_type}")
-
-         self.ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file, ode_method, use_ema, self.device)
-
-     def export_wav(self, wav, file_wave, remove_silence=False):
-         sf.write(file_wave, wav, self.target_sample_rate)
-
-         if remove_silence:
-             remove_silence_for_generated_wav(file_wave)
-
-     def export_spectrogram(self, spect, file_spect):
-         save_spectrogram(spect, file_spect)
-
-     def infer(
-         self,
-         ref_file,
-         ref_text,
-         gen_text,
-         show_info=print,
-         progress=tqdm,
-         target_rms=0.1,
-         cross_fade_duration=0.15,
-         sway_sampling_coef=-1,
-         cfg_strength=2,
-         nfe_step=32,
-         speed=1.0,
-         fix_duration=None,
-         remove_silence=False,
-         file_wave=None,
-         file_spect=None,
-         seed=-1,
-     ):
-         if seed == -1:
-             seed = random.randint(0, sys.maxsize)
-         seed_everything(seed)
-         self.seed = seed
-         wav, sr, spect = infer_process(
-             ref_file,
-             ref_text,
-             gen_text,
-             self.ema_model,
-             show_info=show_info,
-             progress=progress,
-             target_rms=target_rms,
-             cross_fade_duration=cross_fade_duration,
-             nfe_step=nfe_step,
-             cfg_strength=cfg_strength,
-             sway_sampling_coef=sway_sampling_coef,
-             speed=speed,
-             fix_duration=fix_duration,
-             device=self.device,
-         )
-
-         if file_wave is not None:
-             self.export_wav(wav, file_wave, remove_silence)
-
-         if file_spect is not None:
-             self.export_spectrogram(spect, file_spect)
-
-         return wav, sr, spect
-
-
- if __name__ == "__main__":
-     f5tts = F5TTS()
-
-     wav, sr, spect = f5tts.infer(
-         ref_file="tests/ref_audio/test_en_1_ref_short.wav",
-         ref_text="some call me nature, others call me mother nature.",
-         gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
-         file_wave="tests/out.wav",
-         file_spect="tests/out.png",
-         seed=-1,  # random seed = -1
-     )
-
-     print("seed :", f5tts.seed)
app.py CHANGED
@@ -1,593 +1,242 @@
1
- # ruff: noqa: E402
2
- # Above allows ruff to ignore E402: module level import not at top of file
3
-
4
  import re
5
- import tempfile
6
-
7
- import click
8
  import gradio as gr
9
  import numpy as np
10
- import soundfile as sf
11
- import torchaudio
12
- from cached_path import cached_path
 
13
  from pydub import AudioSegment
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
- try:
16
- import spaces
17
-
18
- USING_SPACES = True
19
- except ImportError:
20
- USING_SPACES = False
21
 
 
22
 
23
- def gpu_decorator(func):
24
- if USING_SPACES:
25
- return spaces.GPU(func)
26
- else:
27
- return func
28
 
29
 
30
- from model import DiT, UNetT
31
- from model.utils import (
32
- save_spectrogram,
 
 
33
  )
34
- from model.utils_infer import (
35
- load_vocoder,
36
- load_model,
37
- preprocess_ref_audio_text,
38
- infer_process,
39
- remove_silence_for_generated_wav,
40
- )
41
-
42
- vocos = load_vocoder()
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  # load models
46
  F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
47
- F5TTS_ema_model = load_model(
48
- DiT, F5TTS_model_cfg, str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
49
- )
50
-
51
  E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
52
- E2TTS_ema_model = load_model(
53
- UNetT, E2TTS_model_cfg, str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
54
- )
55
-
56
-
57
- @gpu_decorator
58
- def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15, speed=1):
59
- ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=gr.Info)
60
 
61
- if model == "F5-TTS":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  ema_model = F5TTS_ema_model
63
- elif model == "E2-TTS":
 
64
  ema_model = E2TTS_ema_model
65
-
66
- final_wave, final_sample_rate, combined_spectrogram = infer_process(
67
- ref_audio,
68
- ref_text,
69
- gen_text,
70
- ema_model,
71
- cross_fade_duration=cross_fade_duration,
72
- speed=speed,
73
- show_info=gr.Info,
74
- progress=gr.Progress(),
75
- )
76
-
77
- # Remove silence
78
  if remove_silence:
79
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
80
- sf.write(f.name, final_wave, final_sample_rate)
81
- remove_silence_for_generated_wav(f.name)
82
- final_wave, _ = torchaudio.load(f.name)
83
- final_wave = final_wave.squeeze().cpu().numpy()
84
-
85
- # Save the spectrogram
86
- with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
87
- spectrogram_path = tmp_spectrogram.name
88
- save_spectrogram(combined_spectrogram, spectrogram_path)
89
-
90
- return (final_sample_rate, final_wave), spectrogram_path
91
-
92
-
93
- @gpu_decorator
94
- def generate_podcast(
95
- script, speaker1_name, ref_audio1, ref_text1, speaker2_name, ref_audio2, ref_text2, model, remove_silence
96
- ):
97
- # Split the script into speaker blocks
98
- speaker_pattern = re.compile(f"^({re.escape(speaker1_name)}|{re.escape(speaker2_name)}):", re.MULTILINE)
99
- speaker_blocks = speaker_pattern.split(script)[1:] # Skip the first empty element
100
-
101
- generated_audio_segments = []
102
 
103
- for i in range(0, len(speaker_blocks), 2):
104
- speaker = speaker_blocks[i]
105
- text = speaker_blocks[i + 1].strip()
106
 
107
- # Determine which speaker is talking
108
- if speaker == speaker1_name:
109
- ref_audio = ref_audio1
110
- ref_text = ref_text1
111
- elif speaker == speaker2_name:
112
- ref_audio = ref_audio2
113
- ref_text = ref_text2
114
- else:
115
- continue # Skip if the speaker is neither speaker1 nor speaker2
116
 
117
- # Generate audio for this block
118
- audio, _ = infer(ref_audio, ref_text, text, model, remove_silence)
119
 
120
- # Convert the generated audio to a numpy array
121
- sr, audio_data = audio
122
-
123
- # Save the audio data as a WAV file
124
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
125
- sf.write(temp_file.name, audio_data, sr)
126
- audio_segment = AudioSegment.from_wav(temp_file.name)
127
-
128
- generated_audio_segments.append(audio_segment)
129
-
130
- # Add a short pause between speakers
131
- pause = AudioSegment.silent(duration=500) # 500ms pause
132
- generated_audio_segments.append(pause)
133
-
134
- # Concatenate all audio segments
135
- final_podcast = sum(generated_audio_segments)
136
-
137
- # Export the final podcast
138
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
139
- podcast_path = temp_file.name
140
- final_podcast.export(podcast_path, format="wav")
141
-
142
- return podcast_path
143
-
144
-
145
- def parse_speechtypes_text(gen_text):
146
- # Pattern to find (Emotion)
147
- pattern = r"\((.*?)\)"
148
-
149
- # Split the text by the pattern
150
- tokens = re.split(pattern, gen_text)
151
 
152
- segments = []
153
 
154
- current_emotion = "Regular"
 
155
 
156
- for i in range(len(tokens)):
157
- if i % 2 == 0:
158
- # This is text
159
- text = tokens[i].strip()
160
- if text:
161
- segments.append({"emotion": current_emotion, "text": text})
162
- else:
163
- # This is emotion
164
- emotion = tokens[i].strip()
165
- current_emotion = emotion
166
 
167
- return segments
168
 
 
169
 
170
- with gr.Blocks() as app_credits:
171
- gr.Markdown("""
172
- # Credits
173
 
174
- * [mrfakename](https://github.com/fakerybakery) for the original [online demo](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
175
- * [RootingInLoad](https://github.com/RootingInLoad) for the podcast generation
176
- * [jpgallegoar](https://github.com/jpgallegoar) for multiple speech-type generation
177
  """)
178
- with gr.Blocks() as app_tts:
179
- gr.Markdown("# Batched TTS")
180
  ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
181
- gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
182
  model_choice = gr.Radio(choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS")
183
  generate_btn = gr.Button("Synthesize", variant="primary")
184
  with gr.Accordion("Advanced Settings", open=False):
185
- ref_text_input = gr.Textbox(
186
- label="Reference Text",
187
- info="Leave blank to automatically transcribe the reference audio. If you enter text it will override automatic transcription.",
188
- lines=2,
189
- )
190
- remove_silence = gr.Checkbox(
191
- label="Remove Silences",
192
- info="The model tends to produce silences, especially on longer audio. We can manually remove silences if needed. Note that this is an experimental feature and may produce strange results. This will also increase generation time.",
193
- value=False,
194
- )
195
- speed_slider = gr.Slider(
196
- label="Speed",
197
- minimum=0.3,
198
- maximum=2.0,
199
- value=1.0,
200
- step=0.1,
201
- info="Adjust the speed of the audio.",
202
- )
203
- cross_fade_duration_slider = gr.Slider(
204
- label="Cross-Fade Duration (s)",
205
- minimum=0.0,
206
- maximum=1.0,
207
- value=0.15,
208
- step=0.01,
209
- info="Set the duration of the cross-fade between audio clips.",
210
- )
211
-
212
  audio_output = gr.Audio(label="Synthesized Audio")
213
- spectrogram_output = gr.Image(label="Spectrogram")
214
-
215
- generate_btn.click(
216
- infer,
217
- inputs=[
218
- ref_audio_input,
219
- ref_text_input,
220
- gen_text_input,
221
- model_choice,
222
- remove_silence,
223
- cross_fade_duration_slider,
224
- speed_slider,
225
- ],
226
- outputs=[audio_output, spectrogram_output],
227
- )
228
-
229
- with gr.Blocks() as app_podcast:
230
- gr.Markdown("# Podcast Generation")
231
- speaker1_name = gr.Textbox(label="Speaker 1 Name")
232
- ref_audio_input1 = gr.Audio(label="Reference Audio (Speaker 1)", type="filepath")
233
- ref_text_input1 = gr.Textbox(label="Reference Text (Speaker 1)", lines=2)
234
-
235
- speaker2_name = gr.Textbox(label="Speaker 2 Name")
236
- ref_audio_input2 = gr.Audio(label="Reference Audio (Speaker 2)", type="filepath")
237
- ref_text_input2 = gr.Textbox(label="Reference Text (Speaker 2)", lines=2)
238
-
239
- script_input = gr.Textbox(
240
- label="Podcast Script",
241
- lines=10,
242
- placeholder="Enter the script with speaker names at the start of each block, e.g.:\nSean: How did you start studying...\n\nMeghan: I came to my interest in technology...\nIt was a long journey...\n\nSean: That's fascinating. Can you elaborate...",
243
- )
244
-
245
- podcast_model_choice = gr.Radio(choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS")
246
- podcast_remove_silence = gr.Checkbox(
247
- label="Remove Silences",
248
- value=True,
249
- )
250
- generate_podcast_btn = gr.Button("Generate Podcast", variant="primary")
251
- podcast_output = gr.Audio(label="Generated Podcast")
252
-
253
- def podcast_generation(
254
- script, speaker1, ref_audio1, ref_text1, speaker2, ref_audio2, ref_text2, model, remove_silence
255
- ):
256
- return generate_podcast(
257
- script, speaker1, ref_audio1, ref_text1, speaker2, ref_audio2, ref_text2, model, remove_silence
258
- )
259
-
260
- generate_podcast_btn.click(
261
- podcast_generation,
262
- inputs=[
263
- script_input,
264
- speaker1_name,
265
- ref_audio_input1,
266
- ref_text_input1,
267
- speaker2_name,
268
- ref_audio_input2,
269
- ref_text_input2,
270
- podcast_model_choice,
271
- podcast_remove_silence,
272
- ],
273
- outputs=podcast_output,
274
- )
275
-
276
-
277
- def parse_emotional_text(gen_text):
278
- # Pattern to find (Emotion)
279
- pattern = r"\((.*?)\)"
280
-
281
- # Split the text by the pattern
282
- tokens = re.split(pattern, gen_text)
283
-
284
- segments = []
285
-
286
- current_emotion = "Regular"
287
-
288
- for i in range(len(tokens)):
289
- if i % 2 == 0:
290
- # This is text
291
- text = tokens[i].strip()
292
- if text:
293
- segments.append({"emotion": current_emotion, "text": text})
294
- else:
295
- # This is emotion
296
- emotion = tokens[i].strip()
297
- current_emotion = emotion
298
-
299
- return segments
300
-
301
-
302
- with gr.Blocks() as app_emotional:
303
- # New section for emotional generation
304
- gr.Markdown(
305
- """
306
- # Multiple Speech-Type Generation
307
-
308
- This section allows you to upload different audio clips for each speech type. 'Regular' emotion is mandatory. You can add additional speech types by clicking the "Add Speech Type" button. Enter your text in the format shown below, and the system will generate speech using the appropriate emotions. If unspecified, the model will use the regular speech type. The current speech type will be used until the next speech type is specified.
309
-
310
- **Example Input:**
311
-
312
- (Regular) Hello, I'd like to order a sandwich please. (Surprised) What do you mean you're out of bread? (Sad) I really wanted a sandwich though... (Angry) You know what, darn you and your little shop, you suck! (Whisper) I'll just go back home and cry now. (Shouting) Why me?!
313
- """
314
- )
315
-
316
- gr.Markdown(
317
- "Upload different audio clips for each speech type. 'Regular' emotion is mandatory. You can add additional speech types by clicking the 'Add Speech Type' button."
318
- )
319
-
320
- # Regular speech type (mandatory)
321
- with gr.Row():
322
- regular_name = gr.Textbox(value="Regular", label="Speech Type Name", interactive=False)
323
- regular_audio = gr.Audio(label="Regular Reference Audio", type="filepath")
324
- regular_ref_text = gr.Textbox(label="Reference Text (Regular)", lines=2)
325
-
326
- # Additional speech types (up to 99 more)
327
- max_speech_types = 100
328
- speech_type_names = []
329
- speech_type_audios = []
330
- speech_type_ref_texts = []
331
- speech_type_delete_btns = []
332
-
333
- for i in range(max_speech_types - 1):
334
- with gr.Row():
335
- name_input = gr.Textbox(label="Speech Type Name", visible=False)
336
- audio_input = gr.Audio(label="Reference Audio", type="filepath", visible=False)
337
- ref_text_input = gr.Textbox(label="Reference Text", lines=2, visible=False)
338
- delete_btn = gr.Button("Delete", variant="secondary", visible=False)
339
- speech_type_names.append(name_input)
340
- speech_type_audios.append(audio_input)
341
- speech_type_ref_texts.append(ref_text_input)
342
- speech_type_delete_btns.append(delete_btn)
343
-
344
- # Button to add speech type
345
- add_speech_type_btn = gr.Button("Add Speech Type")
346
-
347
- # Keep track of current number of speech types
348
- speech_type_count = gr.State(value=0)
349
-
350
- # Function to add a speech type
351
- def add_speech_type_fn(speech_type_count):
352
- if speech_type_count < max_speech_types - 1:
353
- speech_type_count += 1
354
- # Prepare updates for the components
355
- name_updates = []
356
- audio_updates = []
357
- ref_text_updates = []
358
- delete_btn_updates = []
359
- for i in range(max_speech_types - 1):
360
- if i < speech_type_count:
361
- name_updates.append(gr.update(visible=True))
362
- audio_updates.append(gr.update(visible=True))
363
- ref_text_updates.append(gr.update(visible=True))
364
- delete_btn_updates.append(gr.update(visible=True))
365
- else:
366
- name_updates.append(gr.update())
367
- audio_updates.append(gr.update())
368
- ref_text_updates.append(gr.update())
369
- delete_btn_updates.append(gr.update())
370
- else:
371
- # Optionally, show a warning
372
- # gr.Warning("Maximum number of speech types reached.")
373
- name_updates = [gr.update() for _ in range(max_speech_types - 1)]
374
- audio_updates = [gr.update() for _ in range(max_speech_types - 1)]
375
- ref_text_updates = [gr.update() for _ in range(max_speech_types - 1)]
376
- delete_btn_updates = [gr.update() for _ in range(max_speech_types - 1)]
377
- return [speech_type_count] + name_updates + audio_updates + ref_text_updates + delete_btn_updates
378
-
379
- add_speech_type_btn.click(
380
- add_speech_type_fn,
381
- inputs=speech_type_count,
382
- outputs=[speech_type_count]
383
- + speech_type_names
384
- + speech_type_audios
385
- + speech_type_ref_texts
386
- + speech_type_delete_btns,
387
- )
388
-
389
- # Function to delete a speech type
390
- def make_delete_speech_type_fn(index):
391
- def delete_speech_type_fn(speech_type_count):
392
- # Prepare updates
393
- name_updates = []
394
- audio_updates = []
395
- ref_text_updates = []
396
- delete_btn_updates = []
397
-
398
- for i in range(max_speech_types - 1):
399
- if i == index:
400
- name_updates.append(gr.update(visible=False, value=""))
401
- audio_updates.append(gr.update(visible=False, value=None))
402
- ref_text_updates.append(gr.update(visible=False, value=""))
403
- delete_btn_updates.append(gr.update(visible=False))
404
- else:
405
- name_updates.append(gr.update())
406
- audio_updates.append(gr.update())
407
- ref_text_updates.append(gr.update())
408
- delete_btn_updates.append(gr.update())
409
-
410
- speech_type_count = max(0, speech_type_count - 1)
411
-
412
- return [speech_type_count] + name_updates + audio_updates + ref_text_updates + delete_btn_updates
413
-
414
- return delete_speech_type_fn
415
-
416
- for i, delete_btn in enumerate(speech_type_delete_btns):
417
- delete_fn = make_delete_speech_type_fn(i)
418
- delete_btn.click(
419
- delete_fn,
420
- inputs=speech_type_count,
421
- outputs=[speech_type_count]
422
- + speech_type_names
423
- + speech_type_audios
424
- + speech_type_ref_texts
425
- + speech_type_delete_btns,
426
- )
427
-
428
- # Text input for the prompt
429
- gen_text_input_emotional = gr.Textbox(label="Text to Generate", lines=10)
430
-
431
- # Model choice
432
- model_choice_emotional = gr.Radio(choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS")
433
-
434
- with gr.Accordion("Advanced Settings", open=False):
435
- remove_silence_emotional = gr.Checkbox(
436
- label="Remove Silences",
437
- value=True,
438
- )
439
-
440
- # Generate button
441
- generate_emotional_btn = gr.Button("Generate Emotional Speech", variant="primary")
442
-
443
- # Output audio
444
- audio_output_emotional = gr.Audio(label="Synthesized Audio")
445
-
446
- @gpu_decorator
447
- def generate_emotional_speech(
448
- regular_audio,
449
- regular_ref_text,
450
- gen_text,
451
- *args,
452
- ):
453
- num_additional_speech_types = max_speech_types - 1
454
- speech_type_names_list = args[:num_additional_speech_types]
455
- speech_type_audios_list = args[num_additional_speech_types : 2 * num_additional_speech_types]
456
- speech_type_ref_texts_list = args[2 * num_additional_speech_types : 3 * num_additional_speech_types]
457
- model_choice = args[3 * num_additional_speech_types]
458
- remove_silence = args[3 * num_additional_speech_types + 1]
459
-
460
- # Collect the speech types and their audios into a dict
461
- speech_types = {"Regular": {"audio": regular_audio, "ref_text": regular_ref_text}}
462
-
463
- for name_input, audio_input, ref_text_input in zip(
464
- speech_type_names_list, speech_type_audios_list, speech_type_ref_texts_list
465
- ):
466
- if name_input and audio_input:
467
- speech_types[name_input] = {"audio": audio_input, "ref_text": ref_text_input}
468
-
469
- # Parse the gen_text into segments
470
- segments = parse_speechtypes_text(gen_text)
471
-
472
- # For each segment, generate speech
473
- generated_audio_segments = []
474
- current_emotion = "Regular"
475
-
476
- for segment in segments:
477
- emotion = segment["emotion"]
478
- text = segment["text"]
479
-
480
- if emotion in speech_types:
481
- current_emotion = emotion
482
- else:
483
- # If emotion not available, default to Regular
484
- current_emotion = "Regular"
485
-
486
- ref_audio = speech_types[current_emotion]["audio"]
487
- ref_text = speech_types[current_emotion].get("ref_text", "")
488
-
489
- # Generate speech for this segment
490
- audio, _ = infer(ref_audio, ref_text, text, model_choice, remove_silence, 0)
491
- sr, audio_data = audio
492
-
493
- generated_audio_segments.append(audio_data)
494
-
495
- # Concatenate all audio segments
496
- if generated_audio_segments:
497
- final_audio_data = np.concatenate(generated_audio_segments)
498
- return (sr, final_audio_data)
499
- else:
500
- gr.Warning("No audio generated.")
501
- return None
502
-
503
- generate_emotional_btn.click(
504
- generate_emotional_speech,
505
- inputs=[
506
- regular_audio,
507
- regular_ref_text,
508
- gen_text_input_emotional,
509
- ]
510
- + speech_type_names
511
- + speech_type_audios
512
- + speech_type_ref_texts
513
- + [
514
- model_choice_emotional,
515
- remove_silence_emotional,
516
- ],
517
- outputs=audio_output_emotional,
518
- )
519
-
520
- # Validation function to disable Generate button if speech types are missing
521
- def validate_speech_types(gen_text, regular_name, *args):
522
- num_additional_speech_types = max_speech_types - 1
523
- speech_type_names_list = args[:num_additional_speech_types]
524
-
525
- # Collect the speech types names
526
- speech_types_available = set()
527
- if regular_name:
528
- speech_types_available.add(regular_name)
529
- for name_input in speech_type_names_list:
530
- if name_input:
531
- speech_types_available.add(name_input)
532
-
533
- # Parse the gen_text to get the speech types used
534
- segments = parse_emotional_text(gen_text)
535
- speech_types_in_text = set(segment["emotion"] for segment in segments)
536
-
537
- # Check if all speech types in text are available
538
- missing_speech_types = speech_types_in_text - speech_types_available
539
-
540
- if missing_speech_types:
541
- # Disable the generate button
542
- return gr.update(interactive=False)
543
- else:
544
- # Enable the generate button
545
- return gr.update(interactive=True)
546
-
547
- gen_text_input_emotional.change(
548
- validate_speech_types,
549
- inputs=[gen_text_input_emotional, regular_name] + speech_type_names,
550
- outputs=generate_emotional_btn,
551
- )
552
- with gr.Blocks() as app:
553
- gr.Markdown(
554
- """
555
- # E2/F5 TTS
556
 
557
- This is a local web UI for F5 TTS with advanced batch processing support. This app supports the following TTS models:
 
 
558
 
559
- * [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
560
- * [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
561
 
562
- The checkpoints support English and Chinese.
563
 
564
- If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15s, and shortening your prompt.
 
 
 
 
 
565
 
566
- **NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<15s). Ensure the audio is fully uploaded before generating.**
567
- """
568
- )
569
- gr.TabbedInterface([app_tts, app_podcast, app_emotional, app_credits], ["TTS", "Podcast", "Multi-Style", "Credits"])
570
-
571
-
572
- @click.command()
573
- @click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
574
- @click.option("--host", "-H", default=None, help="Host to run the app on")
575
- @click.option(
576
- "--share",
577
- "-s",
578
- default=False,
579
- is_flag=True,
580
- help="Share the app via Gradio share link",
581
- )
582
- @click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
583
- def main(port, host, share, api):
584
- global app
585
- print("Starting app...")
586
- app.queue(api_open=api).launch(server_name=host, server_port=port, share=share, show_api=api)
587
 
588
 
589
- if __name__ == "__main__":
590
- if not USING_SPACES:
591
- main()
592
- else:
593
- app.queue().launch()
 
1
+ import os
 
 
2
  import re
3
+ import torch
4
+ import torchaudio
 
5
  import gradio as gr
6
  import numpy as np
7
+ import tempfile
8
+ from einops import rearrange
9
+ from ema_pytorch import EMA
10
+ from vocos import Vocos
11
  from pydub import AudioSegment
12
+ from model import CFM, UNetT, DiT, MMDiT
13
+ from cached_path import cached_path
14
+ from model.utils import (
15
+ get_tokenizer,
16
+ convert_char_to_pinyin,
17
+ save_spectrogram,
18
+ )
19
+ from transformers import pipeline
20
+ import spaces
21
+ import librosa
22
+ from txtsplit import txtsplit
23
+ from detoxify import Detoxify
24
 
 
 
 
 
 
 
25
 
26
+ device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
27
 
28
+ model = Detoxify('original', device=device)
 
 
 
 
29
 
30
 
31
+ pipe = pipeline(
32
+ "automatic-speech-recognition",
33
+ model="openai/whisper-large-v3-turbo",
34
+ torch_dtype=torch.float16,
35
+ device=device,
36
  )
 
 
 
 
 
 
 
 
 
37
 
38
+ # --------------------- Settings -------------------- #
39
+
40
+ target_sample_rate = 24000
41
+ n_mel_channels = 100
42
+ hop_length = 256
43
+ target_rms = 0.1
44
+ nfe_step = 32 # 16, 32
45
+ cfg_strength = 2.0
46
+ ode_method = 'euler'
47
+ sway_sampling_coef = -1.0
48
+ speed = 1.0
49
+ # fix_duration = 27 # None or float (duration in seconds)
50
+ fix_duration = None
51
+
52
+ def load_model(exp_name, model_cls, model_cfg, ckpt_step):
53
+ checkpoint = torch.load(str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.pt")), map_location=device)
54
+ vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
55
+ model = CFM(
56
+ transformer=model_cls(
57
+ **model_cfg,
58
+ text_num_embeds=vocab_size,
59
+ mel_dim=n_mel_channels
60
+ ),
61
+ mel_spec_kwargs=dict(
62
+ target_sample_rate=target_sample_rate,
63
+ n_mel_channels=n_mel_channels,
64
+ hop_length=hop_length,
65
+ ),
66
+ odeint_kwargs=dict(
67
+ method=ode_method,
68
+ ),
69
+ vocab_char_map=vocab_char_map,
70
+ ).to(device)
71
+
72
+ ema_model = EMA(model, include_online_model=False).to(device)
73
+ ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
74
+ ema_model.copy_params_from_ema_to_model()
75
+
76
+ return ema_model, model
77
 
78
  # load models
79
  F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
 
 
 
 
80
  E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
 
 
 
 
 
 
 
 
81
 
82
+ F5TTS_ema_model, F5TTS_base_model = load_model("F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
83
+ E2TTS_ema_model, E2TTS_base_model = load_model("E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)
84
+
85
+ @spaces.GPU
86
+ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, progress = gr.Progress()):
87
+ print(gen_text)
88
+ if model.predict(gen_text)['toxicity'] > 0.8:
89
+ print("Flagged for toxicity:", gen_text)
90
+ raise gr.Error("Your text was flagged for toxicity, please try again with a different text.")
91
+ gr.Info("Converting audio...")
92
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
93
+ aseg = AudioSegment.from_file(ref_audio_orig)
94
+ # Convert to mono
95
+ aseg = aseg.set_channels(1)
96
+ audio_duration = len(aseg)
97
+ if audio_duration > 15000:
98
+ gr.Warning("Audio is over 15s, clipping to only first 15s.")
99
+ aseg = aseg[:15000]
100
+ aseg.export(f.name, format="wav")
101
+ ref_audio = f.name
102
+ if exp_name == "F5-TTS":
103
  ema_model = F5TTS_ema_model
104
+ base_model = F5TTS_base_model
105
+ elif exp_name == "E2-TTS":
106
  ema_model = E2TTS_ema_model
107
+ base_model = E2TTS_base_model
108
+
109
+ if not ref_text.strip():
110
+ gr.Info("No reference text provided, transcribing reference audio...")
111
+ ref_text = outputs = pipe(
112
+ ref_audio,
113
+ chunk_length_s=30,
114
+ batch_size=128,
115
+ generate_kwargs={"task": "transcribe"},
116
+ return_timestamps=False,
117
+ )['text'].strip()
118
+ gr.Info("Finished transcription")
119
+ else:
120
+ gr.Info("Using custom reference text...")
121
+ audio, sr = torchaudio.load(ref_audio)
122
+ # Audio
123
+ if audio.shape[0] > 1:
124
+ audio = torch.mean(audio, dim=0, keepdim=True)
125
+ rms = torch.sqrt(torch.mean(torch.square(audio)))
126
+ if rms < target_rms:
127
+ audio = audio * target_rms / rms
128
+ if sr != target_sample_rate:
129
+ resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
130
+ audio = resampler(audio)
131
+ audio = audio.to(device)
132
+ # Chunk
133
+ chunks = txtsplit(gen_text, 100, 150) # 100 chars preferred, 150 max
134
+ results = []
135
+ generated_mel_specs = []
136
+ for chunk in progress.tqdm(chunks):
137
+ # Prepare the text
138
+ text_list = [ref_text + chunk]
139
+ final_text_list = convert_char_to_pinyin(text_list)
140
+
141
+ # Calculate duration
142
+ ref_audio_len = audio.shape[-1] // hop_length
143
+ # if fix_duration is not None:
144
+ # duration = int(fix_duration * target_sample_rate / hop_length)
145
+ # else:
146
+ zh_pause_punc = r"。,、;:?!"
147
+ ref_text_len = len(ref_text) + len(re.findall(zh_pause_punc, ref_text))
148
+ gen_text_len = len(gen_text) + len(re.findall(zh_pause_punc, gen_text))
149
+ duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
150
+
151
+ # inference
152
+ gr.Info(f"Generating audio using {exp_name}")
153
+ with torch.inference_mode():
154
+ generated, _ = base_model.sample(
155
+ cond=audio,
156
+ text=final_text_list,
157
+ duration=duration,
158
+ steps=nfe_step,
159
+ cfg_strength=cfg_strength,
160
+ sway_sampling_coef=sway_sampling_coef,
161
+ )
162
+
163
+ generated = generated[:, ref_audio_len:, :]
164
+ generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
165
+ gr.Info("Running vocoder")
166
+ vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
167
+ generated_wave = vocos.decode(generated_mel_spec.cpu())
168
+ if rms < target_rms:
169
+ generated_wave = generated_wave * rms / target_rms
170
+
171
+ # wav -> numpy
172
+ generated_wave = generated_wave.squeeze().cpu().numpy()
173
+ results.append(generated_wave)
174
+ generated_wave = np.concatenate(results)
175
  if remove_silence:
176
+ gr.Info("Removing audio silences... This may take a moment")
177
+ non_silent_intervals = librosa.effects.split(generated_wave, top_db=30)
178
+ non_silent_wave = np.array([])
179
+ for interval in non_silent_intervals:
180
+ start, end = interval
181
+ non_silent_wave = np.concatenate([non_silent_wave, generated_wave[start:end]])
182
+ generated_wave = non_silent_wave
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
 
 
 
 
184
 
185
+ # spectogram
186
+ # with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
187
+ # spectrogram_path = tmp_spectrogram.name
188
+ # save_spectrogram(generated_mel_spec[0].cpu().numpy(), spectrogram_path)
 
 
 
 
 
189
 
190
+ return (target_sample_rate, generated_wave)
 
191
 
192
+ with gr.Blocks() as app:
193
+ gr.Markdown("""
194
+ # E2/F5 TTS
195
 
196
+ This is an unofficial E2/F5 TTS demo. This demo supports the following TTS models:
197
 
198
+ * [E2-TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
199
+ * [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
200
 
201
+ This demo is based on the [F5-TTS](https://github.com/SWivid/F5-TTS) codebase, which is based on an [unofficial E2-TTS implementation](https://github.com/lucidrains/e2-tts-pytorch).
 
 
 
 
 
 
 
 
 
202
 
203
+ The checkpoints support English and Chinese.
204
 
205
+ If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15s, and shortening your prompt. If you're still running into issues, please open a [community Discussion](https://huggingface.co/spaces/mrfakename/E2-F5-TTS/discussions).
206
 
207
+ The model is licensed under the CC-BY-NC license, this demo cannot be used for commercial purposes.
 
 
208
 
209
+ **NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<15s). Ensure the audio is fully uploaded before generating.**
 
 
210
  """)
211
+
 
212
  ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
213
+ gen_text_input = gr.Textbox(label="Text to Generate (longer text will use chunking)", lines=4)
214
  model_choice = gr.Radio(choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS")
215
  generate_btn = gr.Button("Synthesize", variant="primary")
216
  with gr.Accordion("Advanced Settings", open=False):
217
+ ref_text_input = gr.Textbox(label="Reference Text", info="Leave blank to automatically transcribe the reference audio. If you enter text it will override automatic transcription.", lines=2)
218
+ remove_silence = gr.Checkbox(label="Remove Silences", info="The model tends to produce silences, especially on longer audio. We can manually remove silences if needed. Note that this is an experimental feature and may produce strange results. This will also increase generation time.", value=True)
219
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  audio_output = gr.Audio(label="Synthesized Audio")
221
+ # spectrogram_output = gr.Image(label="Spectrogram")
222
 
223
+ generate_btn.click(infer, inputs=[ref_audio_input, ref_text_input, gen_text_input, model_choice, remove_silence], outputs=[audio_output])
224
+ gr.Markdown("""
225
+ ## Run Locally
226
 
227
+ Run this demo locally on CPU, CUDA, or MPS/Apple Silicon (requires macOS >= 14):
 
228
 
229
+ First, ensure `ffmpeg` is installed.
230
 
231
+ ```bash
232
+ git clone https://huggingface.co/spaces/mrfakename/E2-F5-TTS
233
+ cd E2-F5-TTS
234
+ python -m pip install -r requirements.txt
235
+ python app_local.py
236
+ ```
237
 
238
+ """)
239
+ gr.Markdown("Unofficial demo by [mrfakename](https://x.com/realmrfakename)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
 
241
 
242
+ app.queue().launch()
 
 
 
 
app_local.py ADDED
@@ -0,0 +1,219 @@
1
+ print("WARNING: You are running this unofficial E2/F5 TTS demo locally, it may not be as up-to-date as the hosted version (https://huggingface.co/spaces/mrfakename/E2-F5-TTS)")
2
+
3
+ import os
4
+ import re
5
+ import torch
6
+ import torchaudio
7
+ import gradio as gr
8
+ import numpy as np
9
+ import tempfile
10
+ from einops import rearrange
11
+ from ema_pytorch import EMA
12
+ from vocos import Vocos
13
+ from pydub import AudioSegment
14
+ from model import CFM, UNetT, DiT, MMDiT
15
+ from cached_path import cached_path
16
+ from model.utils import (
17
+ get_tokenizer,
18
+ convert_char_to_pinyin,
19
+ save_spectrogram,
20
+ )
21
+ from transformers import pipeline
22
+ import librosa
23
+ from txtsplit import txtsplit
24
+
25
+ device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
26
+
27
+ pipe = pipeline(
28
+ "automatic-speech-recognition",
29
+ model="openai/whisper-large-v3-turbo",
30
+ torch_dtype=torch.float16,
31
+ device=device,
32
+ )
33
+
34
+ # --------------------- Settings -------------------- #
35
+
36
+ target_sample_rate = 24000
37
+ n_mel_channels = 100
38
+ hop_length = 256
39
+ target_rms = 0.1
40
+ nfe_step = 32 # 16, 32
41
+ cfg_strength = 2.0
42
+ ode_method = 'euler'
43
+ sway_sampling_coef = -1.0
44
+ speed = 1.0
45
+ # fix_duration = 27 # None or float (duration in seconds)
46
+ fix_duration = None
47
+
48
+ def load_model(exp_name, model_cls, model_cfg, ckpt_step):
49
+ checkpoint = torch.load(str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.pt")), map_location=device)
50
+ vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
51
+ model = CFM(
52
+ transformer=model_cls(
53
+ **model_cfg,
54
+ text_num_embeds=vocab_size,
55
+ mel_dim=n_mel_channels
56
+ ),
57
+ mel_spec_kwargs=dict(
58
+ target_sample_rate=target_sample_rate,
59
+ n_mel_channels=n_mel_channels,
60
+ hop_length=hop_length,
61
+ ),
62
+ odeint_kwargs=dict(
63
+ method=ode_method,
64
+ ),
65
+ vocab_char_map=vocab_char_map,
66
+ ).to(device)
67
+
68
+ ema_model = EMA(model, include_online_model=False).to(device)
69
+ ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
70
+ ema_model.copy_params_from_ema_to_model()
71
+
72
+ return ema_model, model
73
+
74
+ # load models
75
+ F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
76
+ E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
77
+
78
+ F5TTS_ema_model, F5TTS_base_model = load_model("F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
79
+ E2TTS_ema_model, E2TTS_base_model = load_model("E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)
80
+
81
+ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, progress = gr.Progress()):
82
+ print(gen_text)
83
+ gr.Info("Converting audio...")
84
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
85
+ aseg = AudioSegment.from_file(ref_audio_orig)
86
+ # Convert to mono
87
+ aseg = aseg.set_channels(1)
88
+ audio_duration = len(aseg)
89
+ if audio_duration > 15000:
90
+ gr.Warning("Audio is over 15s, clipping to only first 15s.")
91
+ aseg = aseg[:15000]
92
+ aseg.export(f.name, format="wav")
93
+ ref_audio = f.name
94
+ if exp_name == "F5-TTS":
95
+ ema_model = F5TTS_ema_model
96
+ base_model = F5TTS_base_model
97
+ elif exp_name == "E2-TTS":
98
+ ema_model = E2TTS_ema_model
99
+ base_model = E2TTS_base_model
100
+
101
+ if not ref_text.strip():
102
+ gr.Info("No reference text provided, transcribing reference audio...")
103
+ ref_text = outputs = pipe(
104
+ ref_audio,
105
+ chunk_length_s=30,
106
+ batch_size=128,
107
+ generate_kwargs={"task": "transcribe"},
108
+ return_timestamps=False,
109
+ )['text'].strip()
110
+ gr.Info("Finished transcription")
111
+ else:
112
+ gr.Info("Using custom reference text...")
113
+ audio, sr = torchaudio.load(ref_audio)
114
+ # Audio
115
+ if audio.shape[0] > 1:
116
+ audio = torch.mean(audio, dim=0, keepdim=True)
117
+ rms = torch.sqrt(torch.mean(torch.square(audio)))
118
+ if rms < target_rms:
119
+ audio = audio * target_rms / rms
120
+ if sr != target_sample_rate:
121
+ resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
122
+ audio = resampler(audio)
123
+ audio = audio.to(device)
124
+ # Chunk
125
+ chunks = txtsplit(gen_text, 100, 150) # 100 chars preferred, 150 max
126
+ results = []
127
+ generated_mel_specs = []
128
+ for chunk in progress.tqdm(chunks):
129
+ # Prepare the text
130
+ text_list = [ref_text + chunk]
131
+ final_text_list = convert_char_to_pinyin(text_list)
132
+
133
+ # Calculate duration
134
+ ref_audio_len = audio.shape[-1] // hop_length
135
+ # if fix_duration is not None:
136
+ # duration = int(fix_duration * target_sample_rate / hop_length)
137
+ # else:
138
+ zh_pause_punc = r"。,、;:?!"
139
+ ref_text_len = len(ref_text) + len(re.findall(zh_pause_punc, ref_text))
140
+ gen_text_len = len(gen_text) + len(re.findall(zh_pause_punc, gen_text))
141
+ duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
142
+
143
+ # inference
144
+ gr.Info(f"Generating audio using {exp_name}")
145
+ with torch.inference_mode():
146
+ generated, _ = base_model.sample(
147
+ cond=audio,
148
+ text=final_text_list,
149
+ duration=duration,
150
+ steps=nfe_step,
151
+ cfg_strength=cfg_strength,
152
+ sway_sampling_coef=sway_sampling_coef,
153
+ )
154
+
155
+ generated = generated[:, ref_audio_len:, :]
156
+ generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
157
+ gr.Info("Running vocoder")
158
+ vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
159
+ generated_wave = vocos.decode(generated_mel_spec.cpu())
160
+ if rms < target_rms:
161
+ generated_wave = generated_wave * rms / target_rms
162
+
163
+ # wav -> numpy
164
+ generated_wave = generated_wave.squeeze().cpu().numpy()
165
+ results.append(generated_wave)
166
+ generated_wave = np.concatenate(results)
167
+ if remove_silence:
168
+ gr.Info("Removing audio silences... This may take a moment")
169
+ non_silent_intervals = librosa.effects.split(generated_wave, top_db=30)
170
+ non_silent_wave = np.array([])
171
+ for interval in non_silent_intervals:
172
+ start, end = interval
173
+ non_silent_wave = np.concatenate([non_silent_wave, generated_wave[start:end]])
174
+ generated_wave = non_silent_wave
175
+
176
+
177
+ # spectogram
178
+ # with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
179
+ # spectrogram_path = tmp_spectrogram.name
180
+ # save_spectrogram(generated_mel_spec[0].cpu().numpy(), spectrogram_path)
181
+
182
+ return (target_sample_rate, generated_wave)
183
+
184
+ with gr.Blocks() as app:
185
+ gr.Markdown("""
186
+ # E2/F5 TTS
187
+
188
+ This is an unofficial E2/F5 TTS demo. This demo supports the following TTS models:
189
+
190
+ * [E2-TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
191
+ * [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
192
+
193
+ This demo is based on the [F5-TTS](https://github.com/SWivid/F5-TTS) codebase, which builds on an [unofficial E2-TTS implementation](https://github.com/lucidrains/e2-tts-pytorch).
194
+
195
+ The checkpoints support English and Chinese.
196
+
197
+ If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15s, and shortening your prompt. If you're still running into issues, please open a [community Discussion](https://huggingface.co/spaces/mrfakename/E2-F5-TTS/discussions).
198
+
199
+ Long-form/batched inference + speech editing is coming soon!
200
+
201
+ **NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<15s). Ensure the audio is fully uploaded before generating.**
202
+ """)
203
+
204
+ ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
205
+ gen_text_input = gr.Textbox(label="Text to Generate (longer text will use chunking)", lines=4)
206
+ model_choice = gr.Radio(choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS")
207
+ generate_btn = gr.Button("Synthesize", variant="primary")
208
+ with gr.Accordion("Advanced Settings", open=False):
209
+ ref_text_input = gr.Textbox(label="Reference Text", info="Leave blank to automatically transcribe the reference audio. If you enter text it will override automatic transcription.", lines=2)
210
+ remove_silence = gr.Checkbox(label="Remove Silences", info="The model tends to produce silences, especially on longer audio. We can manually remove silences if needed. Note that this is an experimental feature and may produce strange results. This will also increase generation time.", value=True)
211
+
212
+ audio_output = gr.Audio(label="Synthesized Audio")
213
+ # spectrogram_output = gr.Image(label="Spectrogram")
214
+
215
+ generate_btn.click(infer, inputs=[ref_audio_input, ref_text_input, gen_text_input, model_choice, remove_silence], outputs=[audio_output])
216
+ gr.Markdown("Unofficial demo by [mrfakename](https://x.com/realmrfakename)")
217
+
218
+
219
+ app.queue().launch()
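For intuition, here is a small worked example of the duration heuristic used by `infer` in `app_local.py` above (a minimal sketch; the numbers are invented purely for illustration, and the Chinese pause-punctuation adjustment from the real code is omitted):

```python
# Illustrative numbers only: a ~10 s reference clip and a short generation prompt
target_sample_rate = 24000
hop_length = 256
speed = 1.0

ref_audio_len = (10 * target_sample_rate) // hop_length  # 937 mel frames for ~10 s of audio
ref_text_len = 120  # characters in the reference transcript
gen_text_len = 60   # characters in the text to generate

# total frames = reference frames + a proportional allowance for the new text
duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
print(duration)  # 1405 frames, roughly 15 s of audio at 24 kHz with hop 256
```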
cog.py ADDED
@@ -0,0 +1,180 @@
1
+ # Prediction interface for Cog ⚙️
2
+ # https://cog.run/python
3
+
4
+ from cog import BasePredictor, Input, Path
5
+
6
+ import os
7
+ import re
8
+ import torch
9
+ import torchaudio
10
+ import numpy as np
11
+ import tempfile
12
+ from einops import rearrange
13
+ from ema_pytorch import EMA
14
+ from vocos import Vocos
15
+ from pydub import AudioSegment
16
+ from model import CFM, UNetT, DiT, MMDiT
17
+ from cached_path import cached_path
18
+ from model.utils import (
19
+ get_tokenizer,
20
+ convert_char_to_pinyin,
21
+ save_spectrogram,
22
+ )
23
+ from transformers import pipeline
24
+ import librosa
+ import gradio as gr  # gr.Info / gr.Warning / gr.Error are used in predict() below
25
+
26
+ device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
27
+
28
+ target_sample_rate = 24000
29
+ n_mel_channels = 100
30
+ hop_length = 256
31
+ target_rms = 0.1
32
+ nfe_step = 32 # 16, 32
33
+ cfg_strength = 2.0
34
+ ode_method = 'euler'
35
+ sway_sampling_coef = -1.0
36
+ speed = 1.0
37
+ # fix_duration = 27 # None or float (duration in seconds)
38
+ fix_duration = None
39
+
40
+
41
+ class Predictor(BasePredictor):
42
+ def load_model(self, exp_name, model_cls, model_cfg, ckpt_step):
43
+ checkpoint = torch.load(str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.pt")), map_location=device)
44
+ vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
45
+ model = CFM(
46
+ transformer=model_cls(
47
+ **model_cfg,
48
+ text_num_embeds=vocab_size,
49
+ mel_dim=n_mel_channels
50
+ ),
51
+ mel_spec_kwargs=dict(
52
+ target_sample_rate=target_sample_rate,
53
+ n_mel_channels=n_mel_channels,
54
+ hop_length=hop_length,
55
+ ),
56
+ odeint_kwargs=dict(
57
+ method=ode_method,
58
+ ),
59
+ vocab_char_map=vocab_char_map,
60
+ ).to(device)
61
+
62
+ ema_model = EMA(model, include_online_model=False).to(device)
63
+ ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
64
+ ema_model.copy_params_from_ema_to_model()
65
+
66
+ return ema_model, model
67
+ def setup(self) -> None:
68
+ """Load the model into memory to make running multiple predictions efficient"""
69
+ # self.model = torch.load("./weights.pth")
70
+ print("Loading Whisper model...")
71
+ self.pipe = pipeline(
72
+ "automatic-speech-recognition",
73
+ model="openai/whisper-large-v3-turbo",
74
+ torch_dtype=torch.float16,
75
+ device=device,
76
+ )
77
+ print("Loading F5-TTS model...")
78
+
79
+ F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
80
+ self.F5TTS_ema_model, self.F5TTS_base_model = self.load_model("F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
81
+
82
+
83
+ def predict(
84
+ self,
85
+ gen_text: str = Input(description="Text to generate"),
86
+ ref_audio_orig: Path = Input(description="Reference audio"),
87
+ remove_silence: bool = Input(description="Remove silences", default=True),
88
+ ) -> Path:
89
+ """Run a single prediction on the model"""
90
+ model_choice = "F5-TTS"
91
+ print(gen_text)
92
+ if len(gen_text) > 200:
93
+ raise gr.Error("Please keep your text under 200 chars.")
94
+ gr.Info("Converting audio...")
95
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
96
+ aseg = AudioSegment.from_file(ref_audio_orig)
97
+ audio_duration = len(aseg)
98
+ if audio_duration > 15000:
99
+ gr.Warning("Audio is over 15s, clipping to only first 15s.")
100
+ aseg = aseg[:15000]
101
+ aseg.export(f.name, format="wav")
102
+ ref_audio = f.name
103
+ ema_model = self.F5TTS_ema_model
104
+ base_model = self.F5TTS_base_model
105
+
+ ref_text = ""  # predict() takes no reference-text input, so default to empty and let Whisper transcribe below
106
+ if not ref_text.strip():
107
+ gr.Info("No reference text provided, transcribing reference audio...")
108
+ ref_text = outputs = self.pipe(
109
+ ref_audio,
110
+ chunk_length_s=30,
111
+ batch_size=128,
112
+ generate_kwargs={"task": "transcribe"},
113
+ return_timestamps=False,
114
+ )['text'].strip()
115
+ gr.Info("Finished transcription")
116
+ else:
117
+ gr.Info("Using custom reference text...")
118
+ audio, sr = torchaudio.load(ref_audio)
119
+
120
+ rms = torch.sqrt(torch.mean(torch.square(audio)))
121
+ if rms < target_rms:
122
+ audio = audio * target_rms / rms
123
+ if sr != target_sample_rate:
124
+ resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
125
+ audio = resampler(audio)
126
+ audio = audio.to(device)
127
+
128
+ # Prepare the text
129
+ text_list = [ref_text + gen_text]
130
+ final_text_list = convert_char_to_pinyin(text_list)
131
+
132
+ # Calculate duration
133
+ ref_audio_len = audio.shape[-1] // hop_length
134
+ # if fix_duration is not None:
135
+ # duration = int(fix_duration * target_sample_rate / hop_length)
136
+ # else:
137
+ zh_pause_punc = r"。,、;:?!"
138
+ ref_text_len = len(ref_text) + len(re.findall(zh_pause_punc, ref_text))
139
+ gen_text_len = len(gen_text) + len(re.findall(zh_pause_punc, gen_text))
140
+ duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
141
+
142
+ # inference
143
+ gr.Info(f"Generating audio using F5-TTS")
144
+ with torch.inference_mode():
145
+ generated, _ = base_model.sample(
146
+ cond=audio,
147
+ text=final_text_list,
148
+ duration=duration,
149
+ steps=nfe_step,
150
+ cfg_strength=cfg_strength,
151
+ sway_sampling_coef=sway_sampling_coef,
152
+ )
153
+
154
+ generated = generated[:, ref_audio_len:, :]
155
+ generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
156
+ gr.Info("Running vocoder")
157
+ vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
158
+ generated_wave = vocos.decode(generated_mel_spec.cpu())
159
+ if rms < target_rms:
160
+ generated_wave = generated_wave * rms / target_rms
161
+
162
+ # wav -> numpy
163
+ generated_wave = generated_wave.squeeze().cpu().numpy()
164
+
165
+ if remove_silence:
166
+ gr.Info("Removing audio silences... This may take a moment")
167
+ non_silent_intervals = librosa.effects.split(generated_wave, top_db=30)
168
+ non_silent_wave = np.array([])
169
+ for interval in non_silent_intervals:
170
+ start, end = interval
171
+ non_silent_wave = np.concatenate([non_silent_wave, generated_wave[start:end]])
172
+ generated_wave = non_silent_wave
173
+
174
+
175
+ # write the generated waveform to a temporary wav file
176
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_wav:
177
+ wav_path = tmp_wav.name
178
+ torchaudio.save(wav_path, torch.tensor(generated_wave).unsqueeze(0), target_sample_rate)
179
+
180
+ return Path(wav_path)
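Assuming the standard Cog CLI and a `cog.yaml` for this repo (neither is shown in this diff), a prediction against the `predict()` interface above would look roughly like this; the input names mirror the signature and `reference.wav` is a hypothetical file:

```bash
# Hypothetical invocation -- requires cog to be installed and configured for this repo
cog predict \
  -i gen_text="Hello from F5-TTS" \
  -i ref_audio_orig=@reference.wav \
  -i remove_silence=true
```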
data/.DS_Store DELETED
Binary file (6.15 kB)
 
finetune-cli.py DELETED
@@ -1,127 +0,0 @@
1
- import argparse
2
- from model import CFM, UNetT, DiT, Trainer
3
- from model.utils import get_tokenizer
4
- from model.dataset import load_dataset
5
- from cached_path import cached_path
6
- import shutil
7
- import os
8
-
9
- # -------------------------- Dataset Settings --------------------------- #
10
- target_sample_rate = 24000
11
- n_mel_channels = 100
12
- hop_length = 256
13
-
14
-
15
- # -------------------------- Argument Parsing --------------------------- #
16
- def parse_args():
17
- parser = argparse.ArgumentParser(description="Train CFM Model")
18
-
19
- parser.add_argument(
20
- "--exp_name", type=str, default="F5TTS_Base", choices=["F5TTS_Base", "E2TTS_Base"], help="Experiment name"
21
- )
22
- parser.add_argument("--dataset_name", type=str, default="Emilia_ZH_EN", help="Name of the dataset to use")
23
- parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate for training")
24
- parser.add_argument("--batch_size_per_gpu", type=int, default=256, help="Batch size per GPU")
25
- parser.add_argument(
26
- "--batch_size_type", type=str, default="frame", choices=["frame", "sample"], help="Batch size type"
27
- )
28
- parser.add_argument("--max_samples", type=int, default=16, help="Max sequences per batch")
29
- parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
30
- parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
31
- parser.add_argument("--epochs", type=int, default=10, help="Number of training epochs")
32
- parser.add_argument("--num_warmup_updates", type=int, default=5, help="Warmup steps")
33
- parser.add_argument("--save_per_updates", type=int, default=10, help="Save checkpoint every X steps")
34
- parser.add_argument("--last_per_steps", type=int, default=10, help="Save last checkpoint every X steps")
35
- parser.add_argument("--finetune", type=bool, default=True, help="Use Finetune")
36
-
37
- parser.add_argument(
38
- "--tokenizer", type=str, default="pinyin", choices=["pinyin", "char", "custom"], help="Tokenizer type"
39
- )
40
- parser.add_argument(
41
- "--tokenizer_path",
42
- type=str,
43
- default=None,
44
- help="Path to custom tokenizer vocab file (only used if tokenizer = 'custom')",
45
- )
46
-
47
- return parser.parse_args()
48
-
49
-
50
- # -------------------------- Training Settings -------------------------- #
51
-
52
-
53
- def main():
54
- args = parse_args()
55
-
56
- # Model parameters based on experiment name
57
- if args.exp_name == "F5TTS_Base":
58
- wandb_resume_id = None
59
- model_cls = DiT
60
- model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
61
- if args.finetune:
62
- ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
63
- elif args.exp_name == "E2TTS_Base":
64
- wandb_resume_id = None
65
- model_cls = UNetT
66
- model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
67
- if args.finetune:
68
- ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
69
-
70
- if args.finetune:
71
- path_ckpt = os.path.join("ckpts", args.dataset_name)
72
- if not os.path.isdir(path_ckpt):
73
- os.makedirs(path_ckpt, exist_ok=True)
74
- shutil.copy2(ckpt_path, os.path.join(path_ckpt, os.path.basename(ckpt_path)))
75
-
76
- checkpoint_path = os.path.join("ckpts", args.dataset_name)
77
-
78
- # Use the tokenizer and tokenizer_path provided in the command line arguments
79
- tokenizer = args.tokenizer
80
- if tokenizer == "custom":
81
- if not args.tokenizer_path:
82
- raise ValueError("Custom tokenizer selected, but no tokenizer_path provided.")
83
- tokenizer_path = args.tokenizer_path
84
- else:
85
- tokenizer_path = args.dataset_name
86
-
87
- vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
88
-
89
- mel_spec_kwargs = dict(
90
- target_sample_rate=target_sample_rate,
91
- n_mel_channels=n_mel_channels,
92
- hop_length=hop_length,
93
- )
94
-
95
- e2tts = CFM(
96
- transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
97
- mel_spec_kwargs=mel_spec_kwargs,
98
- vocab_char_map=vocab_char_map,
99
- )
100
-
101
- trainer = Trainer(
102
- e2tts,
103
- args.epochs,
104
- args.learning_rate,
105
- num_warmup_updates=args.num_warmup_updates,
106
- save_per_updates=args.save_per_updates,
107
- checkpoint_path=checkpoint_path,
108
- batch_size=args.batch_size_per_gpu,
109
- batch_size_type=args.batch_size_type,
110
- max_samples=args.max_samples,
111
- grad_accumulation_steps=args.grad_accumulation_steps,
112
- max_grad_norm=args.max_grad_norm,
113
- wandb_project="CFM-TTS",
114
- wandb_run_name=args.exp_name,
115
- wandb_resume_id=wandb_resume_id,
116
- last_per_steps=args.last_per_steps,
117
- )
118
-
119
- train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
120
- trainer.train(
121
- train_dataset,
122
- resumable_with_seed=666, # seed for shuffling dataset
123
- )
124
-
125
-
126
- if __name__ == "__main__":
127
- main()
finetune_gradio.py DELETED
@@ -1,944 +0,0 @@
1
- import os
2
- import sys
3
-
4
- import tempfile
5
- import random
6
- from transformers import pipeline
7
- import gradio as gr
8
- import torch
9
- import gc
10
- import click
11
- import torchaudio
12
- from glob import glob
13
- import librosa
14
- import numpy as np
15
- from scipy.io import wavfile
16
- import shutil
17
- import time
18
-
19
- import json
20
- from model.utils import convert_char_to_pinyin
21
- import signal
22
- import psutil
23
- import platform
24
- import subprocess
25
- from datasets.arrow_writer import ArrowWriter
26
- from datasets import Dataset as Dataset_
27
- from api import F5TTS
28
-
29
-
30
- training_process = None
31
- system = platform.system()
32
- python_executable = sys.executable or "python"
33
- tts_api = None
34
- last_checkpoint = ""
35
- last_device = ""
36
-
37
- path_data = "data"
38
-
39
- device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
40
-
41
- pipe = None
42
-
43
-
44
- # Load metadata
45
- def get_audio_duration(audio_path):
46
- """Calculate the duration of an audio file."""
47
- audio, sample_rate = torchaudio.load(audio_path)
48
- num_channels = audio.shape[0]
49
- return audio.shape[1] / (sample_rate * num_channels)
50
-
51
-
52
- def clear_text(text):
53
- """Clean and prepare text by lowering the case and stripping whitespace."""
54
- return text.lower().strip()
55
-
56
-
57
- def get_rms(
58
- y,
59
- frame_length=2048,
60
- hop_length=512,
61
- pad_mode="constant",
62
- ): # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py
63
- padding = (int(frame_length // 2), int(frame_length // 2))
64
- y = np.pad(y, padding, mode=pad_mode)
65
-
66
- axis = -1
67
- # put our new within-frame axis at the end for now
68
- out_strides = y.strides + tuple([y.strides[axis]])
69
- # Reduce the shape on the framing axis
70
- x_shape_trimmed = list(y.shape)
71
- x_shape_trimmed[axis] -= frame_length - 1
72
- out_shape = tuple(x_shape_trimmed) + tuple([frame_length])
73
- xw = np.lib.stride_tricks.as_strided(y, shape=out_shape, strides=out_strides)
74
- if axis < 0:
75
- target_axis = axis - 1
76
- else:
77
- target_axis = axis + 1
78
- xw = np.moveaxis(xw, -1, target_axis)
79
- # Downsample along the target axis
80
- slices = [slice(None)] * xw.ndim
81
- slices[axis] = slice(0, None, hop_length)
82
- x = xw[tuple(slices)]
83
-
84
- # Calculate power
85
- power = np.mean(np.abs(x) ** 2, axis=-2, keepdims=True)
86
-
87
- return np.sqrt(power)
88
-
89
-
90
- class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py
91
- def __init__(
92
- self,
93
- sr: int,
94
- threshold: float = -40.0,
95
- min_length: int = 2000,
96
- min_interval: int = 300,
97
- hop_size: int = 20,
98
- max_sil_kept: int = 2000,
99
- ):
100
- if not min_length >= min_interval >= hop_size:
101
- raise ValueError("The following condition must be satisfied: min_length >= min_interval >= hop_size")
102
- if not max_sil_kept >= hop_size:
103
- raise ValueError("The following condition must be satisfied: max_sil_kept >= hop_size")
104
- min_interval = sr * min_interval / 1000
105
- self.threshold = 10 ** (threshold / 20.0)
106
- self.hop_size = round(sr * hop_size / 1000)
107
- self.win_size = min(round(min_interval), 4 * self.hop_size)
108
- self.min_length = round(sr * min_length / 1000 / self.hop_size)
109
- self.min_interval = round(min_interval / self.hop_size)
110
- self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)
111
-
112
- def _apply_slice(self, waveform, begin, end):
113
- if len(waveform.shape) > 1:
114
- return waveform[:, begin * self.hop_size : min(waveform.shape[1], end * self.hop_size)]
115
- else:
116
- return waveform[begin * self.hop_size : min(waveform.shape[0], end * self.hop_size)]
117
-
118
- # @timeit
119
- def slice(self, waveform):
120
- if len(waveform.shape) > 1:
121
- samples = waveform.mean(axis=0)
122
- else:
123
- samples = waveform
124
- if samples.shape[0] <= self.min_length:
125
- return [waveform]
126
- rms_list = get_rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
127
- sil_tags = []
128
- silence_start = None
129
- clip_start = 0
130
- for i, rms in enumerate(rms_list):
131
- # Keep looping while frame is silent.
132
- if rms < self.threshold:
133
- # Record start of silent frames.
134
- if silence_start is None:
135
- silence_start = i
136
- continue
137
- # Keep looping while frame is not silent and silence start has not been recorded.
138
- if silence_start is None:
139
- continue
140
- # Clear recorded silence start if interval is not enough or clip is too short
141
- is_leading_silence = silence_start == 0 and i > self.max_sil_kept
142
- need_slice_middle = i - silence_start >= self.min_interval and i - clip_start >= self.min_length
143
- if not is_leading_silence and not need_slice_middle:
144
- silence_start = None
145
- continue
146
- # Need slicing. Record the range of silent frames to be removed.
147
- if i - silence_start <= self.max_sil_kept:
148
- pos = rms_list[silence_start : i + 1].argmin() + silence_start
149
- if silence_start == 0:
150
- sil_tags.append((0, pos))
151
- else:
152
- sil_tags.append((pos, pos))
153
- clip_start = pos
154
- elif i - silence_start <= self.max_sil_kept * 2:
155
- pos = rms_list[i - self.max_sil_kept : silence_start + self.max_sil_kept + 1].argmin()
156
- pos += i - self.max_sil_kept
157
- pos_l = rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start
158
- pos_r = rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept
159
- if silence_start == 0:
160
- sil_tags.append((0, pos_r))
161
- clip_start = pos_r
162
- else:
163
- sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
164
- clip_start = max(pos_r, pos)
165
- else:
166
- pos_l = rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start
167
- pos_r = rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept
168
- if silence_start == 0:
169
- sil_tags.append((0, pos_r))
170
- else:
171
- sil_tags.append((pos_l, pos_r))
172
- clip_start = pos_r
173
- silence_start = None
174
- # Deal with trailing silence.
175
- total_frames = rms_list.shape[0]
176
- if silence_start is not None and total_frames - silence_start >= self.min_interval:
177
- silence_end = min(total_frames, silence_start + self.max_sil_kept)
178
- pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start
179
- sil_tags.append((pos, total_frames + 1))
180
- # Apply and return slices.
181
- ####音频+起始时间+终止时间
182
- if len(sil_tags) == 0:
183
- return [[waveform, 0, int(total_frames * self.hop_size)]]
184
- else:
185
- chunks = []
186
- if sil_tags[0][0] > 0:
187
- chunks.append([self._apply_slice(waveform, 0, sil_tags[0][0]), 0, int(sil_tags[0][0] * self.hop_size)])
188
- for i in range(len(sil_tags) - 1):
189
- chunks.append(
190
- [
191
- self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0]),
192
- int(sil_tags[i][1] * self.hop_size),
193
- int(sil_tags[i + 1][0] * self.hop_size),
194
- ]
195
- )
196
- if sil_tags[-1][1] < total_frames:
197
- chunks.append(
198
- [
199
- self._apply_slice(waveform, sil_tags[-1][1], total_frames),
200
- int(sil_tags[-1][1] * self.hop_size),
201
- int(total_frames * self.hop_size),
202
- ]
203
- )
204
- return chunks
205
-
206
-
207
- # terminal
208
- def terminate_process_tree(pid, including_parent=True):
209
- try:
210
- parent = psutil.Process(pid)
211
- except psutil.NoSuchProcess:
212
- # Process already terminated
213
- return
214
-
215
- children = parent.children(recursive=True)
216
- for child in children:
217
- try:
218
- os.kill(child.pid, signal.SIGTERM) # or signal.SIGKILL
219
- except OSError:
220
- pass
221
- if including_parent:
222
- try:
223
- os.kill(parent.pid, signal.SIGTERM) # or signal.SIGKILL
224
- except OSError:
225
- pass
226
-
227
-
228
- def terminate_process(pid):
229
- if system == "Windows":
230
- cmd = f"taskkill /t /f /pid {pid}"
231
- os.system(cmd)
232
- else:
233
- terminate_process_tree(pid)
234
-
235
-
236
- def start_training(
237
- dataset_name="",
238
- exp_name="F5TTS_Base",
239
- learning_rate=1e-4,
240
- batch_size_per_gpu=400,
241
- batch_size_type="frame",
242
- max_samples=64,
243
- grad_accumulation_steps=1,
244
- max_grad_norm=1.0,
245
- epochs=11,
246
- num_warmup_updates=200,
247
- save_per_updates=400,
248
- last_per_steps=800,
249
- finetune=True,
250
- ):
251
- global training_process, tts_api
252
-
253
- if tts_api is not None:
254
- del tts_api
255
- gc.collect()
256
- torch.cuda.empty_cache()
257
- tts_api = None
258
-
259
- path_project = os.path.join(path_data, dataset_name + "_pinyin")
260
-
261
- if not os.path.isdir(path_project):
262
- yield (
263
- f"There is not project with name {dataset_name}",
264
- gr.update(interactive=True),
265
- gr.update(interactive=False),
266
- )
267
- return
268
-
269
- file_raw = os.path.join(path_project, "raw.arrow")
270
- if not os.path.isfile(file_raw):
271
- yield f"There is no file {file_raw}", gr.update(interactive=True), gr.update(interactive=False)
272
- return
273
-
274
- # Check if a training process is already running
275
- if training_process is not None:
276
- return "Train run already!", gr.update(interactive=False), gr.update(interactive=True)
277
-
278
- yield "start train", gr.update(interactive=False), gr.update(interactive=False)
279
-
280
- # Command to run the training script with the specified arguments
281
- cmd = (
282
- f"accelerate launch finetune-cli.py --exp_name {exp_name} "
283
- f"--learning_rate {learning_rate} "
284
- f"--batch_size_per_gpu {batch_size_per_gpu} "
285
- f"--batch_size_type {batch_size_type} "
286
- f"--max_samples {max_samples} "
287
- f"--grad_accumulation_steps {grad_accumulation_steps} "
288
- f"--max_grad_norm {max_grad_norm} "
289
- f"--epochs {epochs} "
290
- f"--num_warmup_updates {num_warmup_updates} "
291
- f"--save_per_updates {save_per_updates} "
292
- f"--last_per_steps {last_per_steps} "
293
- f"--dataset_name {dataset_name}"
294
- )
295
- if finetune:
296
- cmd += f" --finetune {finetune}"
297
-
298
- print(cmd)
299
-
300
- try:
301
- # Start the training process
302
- training_process = subprocess.Popen(cmd, shell=True)
303
-
304
- time.sleep(5)
305
- yield "train start", gr.update(interactive=False), gr.update(interactive=True)
306
-
307
- # Wait for the training process to finish
308
- training_process.wait()
309
- time.sleep(1)
310
-
311
- if training_process is None:
312
- text_info = "train stop"
313
- else:
314
- text_info = "train complete !"
315
-
316
- except Exception as e: # Catch all exceptions
317
- # Ensure that we reset the training process variable in case of an error
318
- text_info = f"An error occurred: {str(e)}"
319
-
320
- training_process = None
321
-
322
- yield text_info, gr.update(interactive=True), gr.update(interactive=False)
323
-
324
-
325
- def stop_training():
326
- global training_process
327
- if training_process is None:
328
- return "Train not run !", gr.update(interactive=True), gr.update(interactive=False)
329
- terminate_process_tree(training_process.pid)
330
- training_process = None
331
- return "train stop", gr.update(interactive=True), gr.update(interactive=False)
332
-
333
-
334
- def create_data_project(name):
335
- name += "_pinyin"
336
- os.makedirs(os.path.join(path_data, name), exist_ok=True)
337
- os.makedirs(os.path.join(path_data, name, "dataset"), exist_ok=True)
338
-
339
-
340
- def transcribe(file_audio, language="english"):
341
- global pipe
342
-
343
- if pipe is None:
344
- pipe = pipeline(
345
- "automatic-speech-recognition",
346
- model="openai/whisper-large-v3-turbo",
347
- torch_dtype=torch.float16,
348
- device=device,
349
- )
350
-
351
- text_transcribe = pipe(
352
- file_audio,
353
- chunk_length_s=30,
354
- batch_size=128,
355
- generate_kwargs={"task": "transcribe", "language": language},
356
- return_timestamps=False,
357
- )["text"].strip()
358
- return text_transcribe
359
-
360
-
361
- def transcribe_all(name_project, audio_files, language, user=False, progress=gr.Progress()):
362
- name_project += "_pinyin"
363
- path_project = os.path.join(path_data, name_project)
364
- path_dataset = os.path.join(path_project, "dataset")
365
- path_project_wavs = os.path.join(path_project, "wavs")
366
- file_metadata = os.path.join(path_project, "metadata.csv")
367
-
368
- if audio_files is None:
369
- return "You need to load an audio file."
370
-
371
- if os.path.isdir(path_project_wavs):
372
- shutil.rmtree(path_project_wavs)
373
-
374
- if os.path.isfile(file_metadata):
375
- os.remove(file_metadata)
376
-
377
- os.makedirs(path_project_wavs, exist_ok=True)
378
-
379
- if user:
380
- file_audios = [
381
- file
382
- for format in ("*.wav", "*.ogg", "*.opus", "*.mp3", "*.flac")
383
- for file in glob(os.path.join(path_dataset, format))
384
- ]
385
- if file_audios == []:
386
- return "No audio file was found in the dataset."
387
- else:
388
- file_audios = audio_files
389
-
390
- alpha = 0.5
391
- _max = 1.0
392
- slicer = Slicer(24000)
393
-
394
- num = 0
395
- error_num = 0
396
- data = ""
397
- for file_audio in progress.tqdm(file_audios, desc="transcribe files", total=len((file_audios))):
398
- audio, _ = librosa.load(file_audio, sr=24000, mono=True)
399
-
400
- list_slicer = slicer.slice(audio)
401
- for chunk, start, end in progress.tqdm(list_slicer, total=len(list_slicer), desc="slicer files"):
402
- name_segment = os.path.join(f"segment_{num}")
403
- file_segment = os.path.join(path_project_wavs, f"{name_segment}.wav")
404
-
405
- tmp_max = np.abs(chunk).max()
406
- if tmp_max > 1:
407
- chunk /= tmp_max
408
- chunk = (chunk / tmp_max * (_max * alpha)) + (1 - alpha) * chunk
409
- wavfile.write(file_segment, 24000, (chunk * 32767).astype(np.int16))
410
-
411
- try:
412
- text = transcribe(file_segment, language)
413
- text = text.lower().strip().replace('"', "")
414
-
415
- data += f"{name_segment}|{text}\n"
416
-
417
- num += 1
418
- except: # noqa: E722
419
- error_num += 1
420
-
421
- with open(file_metadata, "w", encoding="utf-8") as f:
422
- f.write(data)
423
-
424
- if error_num != []:
425
- error_text = f"\nerror files : {error_num}"
426
- else:
427
- error_text = ""
428
-
429
- return f"transcribe complete samples : {num}\npath : {path_project_wavs}{error_text}"
430
-
431
-
432
- def format_seconds_to_hms(seconds):
433
- hours = int(seconds / 3600)
434
- minutes = int((seconds % 3600) / 60)
435
- seconds = seconds % 60
436
- return "{:02d}:{:02d}:{:02d}".format(hours, minutes, int(seconds))
437
-
438
-
439
- def create_metadata(name_project, progress=gr.Progress()):
440
- name_project += "_pinyin"
441
- path_project = os.path.join(path_data, name_project)
442
- path_project_wavs = os.path.join(path_project, "wavs")
443
- file_metadata = os.path.join(path_project, "metadata.csv")
444
- file_raw = os.path.join(path_project, "raw.arrow")
445
- file_duration = os.path.join(path_project, "duration.json")
446
- file_vocab = os.path.join(path_project, "vocab.txt")
447
-
448
- if not os.path.isfile(file_metadata):
449
- return "The file was not found in " + file_metadata
450
-
451
- with open(file_metadata, "r", encoding="utf-8") as f:
452
- data = f.read()
453
-
454
- audio_path_list = []
455
- text_list = []
456
- duration_list = []
457
-
458
- count = data.split("\n")
459
- lenght = 0
460
- result = []
461
- error_files = []
462
- for line in progress.tqdm(data.split("\n"), total=count):
463
- sp_line = line.split("|")
464
- if len(sp_line) != 2:
465
- continue
466
- name_audio, text = sp_line[:2]
467
-
468
- file_audio = os.path.join(path_project_wavs, name_audio + ".wav")
469
-
470
- if not os.path.isfile(file_audio):
471
- error_files.append(file_audio)
472
- continue
473
-
474
- duraction = get_audio_duration(file_audio)
475
- if duraction < 2 and duraction > 15:
476
- continue
477
- if len(text) < 4:
478
- continue
479
-
480
- text = clear_text(text)
481
- text = convert_char_to_pinyin([text], polyphone=True)[0]
482
-
483
- audio_path_list.append(file_audio)
484
- duration_list.append(duraction)
485
- text_list.append(text)
486
-
487
- result.append({"audio_path": file_audio, "text": text, "duration": duraction})
488
-
489
- lenght += duraction
490
-
491
- if duration_list == []:
492
- error_files_text = "\n".join(error_files)
493
- return f"Error: No audio files found in the specified path : \n{error_files_text}"
494
-
495
- min_second = round(min(duration_list), 2)
496
- max_second = round(max(duration_list), 2)
497
-
498
- with ArrowWriter(path=file_raw, writer_batch_size=1) as writer:
499
- for line in progress.tqdm(result, total=len(result), desc="prepare data"):
500
- writer.write(line)
501
-
502
- with open(file_duration, "w", encoding="utf-8") as f:
503
- json.dump({"duration": duration_list}, f, ensure_ascii=False)
504
-
505
- file_vocab_finetune = "data/Emilia_ZH_EN_pinyin/vocab.txt"
506
- if not os.path.isfile(file_vocab_finetune):
507
- return "Error: Vocabulary file 'Emilia_ZH_EN_pinyin' not found!"
508
- shutil.copy2(file_vocab_finetune, file_vocab)
509
-
510
- if error_files != []:
511
- error_text = "error files\n" + "\n".join(error_files)
512
- else:
513
- error_text = ""
514
-
515
- return f"prepare complete \nsamples : {len(text_list)}\ntime data : {format_seconds_to_hms(lenght)}\nmin sec : {min_second}\nmax sec : {max_second}\nfile_arrow : {file_raw}\n{error_text}"
516
-
517
-
518
- def check_user(value):
519
- return gr.update(visible=not value), gr.update(visible=value)
520
-
521
-
522
- def calculate_train(
523
- name_project,
524
- batch_size_type,
525
- max_samples,
526
- learning_rate,
527
- num_warmup_updates,
528
- save_per_updates,
529
- last_per_steps,
530
- finetune,
531
- ):
532
- name_project += "_pinyin"
533
- path_project = os.path.join(path_data, name_project)
534
- file_duraction = os.path.join(path_project, "duration.json")
535
-
536
- if not os.path.isfile(file_duraction):
537
- return (
538
- 1000,
539
- max_samples,
540
- num_warmup_updates,
541
- save_per_updates,
542
- last_per_steps,
543
- "project not found !",
544
- learning_rate,
545
- )
546
-
547
- with open(file_duraction, "r") as file:
548
- data = json.load(file)
549
-
550
- duration_list = data["duration"]
551
-
552
- samples = len(duration_list)
553
-
554
- if torch.cuda.is_available():
555
- gpu_properties = torch.cuda.get_device_properties(0)
556
- total_memory = gpu_properties.total_memory / (1024**3)
557
- elif torch.backends.mps.is_available():
558
- total_memory = psutil.virtual_memory().available / (1024**3)
559
-
560
- if batch_size_type == "frame":
561
- batch = int(total_memory * 0.5)
562
- batch = (lambda num: num + 1 if num % 2 != 0 else num)(batch)
563
- batch_size_per_gpu = int(38400 / batch)
564
- else:
565
- batch_size_per_gpu = int(total_memory / 8)
566
- batch_size_per_gpu = (lambda num: num + 1 if num % 2 != 0 else num)(batch_size_per_gpu)
567
- batch = batch_size_per_gpu
568
-
569
- if batch_size_per_gpu <= 0:
570
- batch_size_per_gpu = 1
571
-
572
- if samples < 64:
573
- max_samples = int(samples * 0.25)
574
- else:
575
- max_samples = 64
576
-
577
- num_warmup_updates = int(samples * 0.05)
578
- save_per_updates = int(samples * 0.10)
579
- last_per_steps = int(save_per_updates * 5)
580
-
581
- max_samples = (lambda num: num + 1 if num % 2 != 0 else num)(max_samples)
582
- num_warmup_updates = (lambda num: num + 1 if num % 2 != 0 else num)(num_warmup_updates)
583
- save_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(save_per_updates)
584
- last_per_steps = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_steps)
585
-
586
- if finetune:
587
- learning_rate = 1e-5
588
- else:
589
- learning_rate = 7.5e-5
590
-
591
- return batch_size_per_gpu, max_samples, num_warmup_updates, save_per_updates, last_per_steps, samples, learning_rate
592
-
593
-
594
- def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str) -> None:
595
- try:
596
- checkpoint = torch.load(checkpoint_path)
597
- print("Original Checkpoint Keys:", checkpoint.keys())
598
-
599
- ema_model_state_dict = checkpoint.get("ema_model_state_dict", None)
600
-
601
- if ema_model_state_dict is not None:
602
- new_checkpoint = {"ema_model_state_dict": ema_model_state_dict}
603
- torch.save(new_checkpoint, new_checkpoint_path)
604
- return f"New checkpoint saved at: {new_checkpoint_path}"
605
- else:
606
- return "No 'ema_model_state_dict' found in the checkpoint."
607
-
608
- except Exception as e:
609
- return f"An error occurred: {e}"
610
-
611
-
612
- def vocab_check(project_name):
613
- name_project = project_name + "_pinyin"
614
- path_project = os.path.join(path_data, name_project)
615
-
616
- file_metadata = os.path.join(path_project, "metadata.csv")
617
-
618
- file_vocab = "data/Emilia_ZH_EN_pinyin/vocab.txt"
619
- if not os.path.isfile(file_vocab):
620
- return f"the file {file_vocab} not found !"
621
-
622
- with open(file_vocab, "r", encoding="utf-8") as f:
623
- data = f.read()
624
-
625
- vocab = data.split("\n")
626
-
627
- if not os.path.isfile(file_metadata):
628
- return f"the file {file_metadata} not found !"
629
-
630
- with open(file_metadata, "r", encoding="utf-8") as f:
631
- data = f.read()
632
-
633
- miss_symbols = []
634
- miss_symbols_keep = {}
635
- for item in data.split("\n"):
636
- sp = item.split("|")
637
- if len(sp) != 2:
638
- continue
639
-
640
- text = sp[1].lower().strip()
641
-
642
- for t in text:
643
- if t not in vocab and t not in miss_symbols_keep:
644
- miss_symbols.append(t)
645
- miss_symbols_keep[t] = t
646
- if miss_symbols == []:
647
- info = "You can train using your language !"
648
- else:
649
- info = f"The following symbols are missing in your language : {len(miss_symbols)}\n\n" + "\n".join(miss_symbols)
650
-
651
- return info
652
-
653
-
654
- def get_random_sample_prepare(project_name):
655
- name_project = project_name + "_pinyin"
656
- path_project = os.path.join(path_data, name_project)
657
- file_arrow = os.path.join(path_project, "raw.arrow")
658
- if not os.path.isfile(file_arrow):
659
- return "", None
660
- dataset = Dataset_.from_file(file_arrow)
661
- random_sample = dataset.shuffle(seed=random.randint(0, 1000)).select([0])
662
- text = "[" + " , ".join(["' " + t + " '" for t in random_sample["text"][0]]) + "]"
663
- audio_path = random_sample["audio_path"][0]
664
- return text, audio_path
665
-
666
-
667
- def get_random_sample_transcribe(project_name):
668
- name_project = project_name + "_pinyin"
669
- path_project = os.path.join(path_data, name_project)
670
- file_metadata = os.path.join(path_project, "metadata.csv")
671
- if not os.path.isfile(file_metadata):
672
- return "", None
673
-
674
- data = ""
675
- with open(file_metadata, "r", encoding="utf-8") as f:
676
- data = f.read()
677
-
678
- list_data = []
679
- for item in data.split("\n"):
680
- sp = item.split("|")
681
- if len(sp) != 2:
682
- continue
683
- list_data.append([os.path.join(path_project, "wavs", sp[0] + ".wav"), sp[1]])
684
-
685
- if list_data == []:
686
- return "", None
687
-
688
- random_item = random.choice(list_data)
689
-
690
- return random_item[1], random_item[0]
691
-
692
-
693
- def get_random_sample_infer(project_name):
694
- text, audio = get_random_sample_transcribe(project_name)
695
- return (
696
- text,
697
- text,
698
- audio,
699
- )
700
-
701
-
702
- def infer(file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step):
703
- global last_checkpoint, last_device, tts_api
704
-
705
- if not os.path.isfile(file_checkpoint):
706
- return None
707
-
708
- if training_process is not None:
709
- device_test = "cpu"
710
- else:
711
- device_test = None
712
-
713
- if last_checkpoint != file_checkpoint or last_device != device_test:
714
- if last_checkpoint != file_checkpoint:
715
- last_checkpoint = file_checkpoint
716
- if last_device != device_test:
717
- last_device = device_test
718
-
719
- tts_api = F5TTS(model_type=exp_name, ckpt_file=file_checkpoint, device=device_test)
720
-
721
- print("update", device_test, file_checkpoint)
722
-
723
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
724
- tts_api.infer(gen_text=gen_text, ref_text=ref_text, ref_file=ref_audio, nfe_step=nfe_step, file_wave=f.name)
725
- return f.name
726
-
727
-
728
- with gr.Blocks() as app:
729
- with gr.Row():
730
- project_name = gr.Textbox(label="project name", value="my_speak")
731
- bt_create = gr.Button("create new project")
732
-
733
- bt_create.click(fn=create_data_project, inputs=[project_name])
734
-
735
- with gr.Tabs():
736
- with gr.TabItem("transcribe Data"):
737
- ch_manual = gr.Checkbox(label="user", value=False)
738
-
739
- mark_info_transcribe = gr.Markdown(
740
- """```plaintext
741
- Place your 'wavs' folder and 'metadata.csv' file in the {your_project_name}' directory.
742
-
743
- my_speak/
744
-
745
- └── dataset/
746
- ├── audio1.wav
747
- └── audio2.wav
748
- ...
749
- ```""",
750
- visible=False,
751
- )
752
-
753
- audio_speaker = gr.File(label="voice", type="filepath", file_count="multiple")
754
- txt_lang = gr.Text(label="Language", value="english")
755
- bt_transcribe = bt_create = gr.Button("transcribe")
756
- txt_info_transcribe = gr.Text(label="info", value="")
757
- bt_transcribe.click(
758
- fn=transcribe_all,
759
- inputs=[project_name, audio_speaker, txt_lang, ch_manual],
760
- outputs=[txt_info_transcribe],
761
- )
762
- ch_manual.change(fn=check_user, inputs=[ch_manual], outputs=[audio_speaker, mark_info_transcribe])
763
-
764
- random_sample_transcribe = gr.Button("random sample")
765
-
766
- with gr.Row():
767
- random_text_transcribe = gr.Text(label="Text")
768
- random_audio_transcribe = gr.Audio(label="Audio", type="filepath")
769
-
770
- random_sample_transcribe.click(
771
- fn=get_random_sample_transcribe,
772
- inputs=[project_name],
773
- outputs=[random_text_transcribe, random_audio_transcribe],
774
- )
775
-
776
- with gr.TabItem("prepare Data"):
777
- gr.Markdown(
778
- """```plaintext
779
- place all your wavs folder and your metadata.csv file in {your name project}
780
- my_speak/
781
-
782
- ├── wavs/
783
- │ ├── audio1.wav
784
- │ └── audio2.wav
785
- | ...
786
-
787
- └── metadata.csv
788
-
789
- file format metadata.csv
790
-
791
- audio1|text1
792
- audio2|text1
793
- ...
794
-
795
- ```"""
796
- )
797
-
798
- bt_prepare = bt_create = gr.Button("prepare")
799
- txt_info_prepare = gr.Text(label="info", value="")
800
- bt_prepare.click(fn=create_metadata, inputs=[project_name], outputs=[txt_info_prepare])
801
-
802
- random_sample_prepare = gr.Button("random sample")
803
-
804
- with gr.Row():
805
- random_text_prepare = gr.Text(label="Pinyin")
806
- random_audio_prepare = gr.Audio(label="Audio", type="filepath")
807
-
808
- random_sample_prepare.click(
809
- fn=get_random_sample_prepare, inputs=[project_name], outputs=[random_text_prepare, random_audio_prepare]
810
- )
811
-
812
- with gr.TabItem("train Data"):
813
- with gr.Row():
814
- bt_calculate = bt_create = gr.Button("Auto Settings")
815
- ch_finetune = bt_create = gr.Checkbox(label="finetune", value=True)
816
- lb_samples = gr.Label(label="samples")
817
- batch_size_type = gr.Radio(label="Batch Size Type", choices=["frame", "sample"], value="frame")
818
-
819
- with gr.Row():
820
- exp_name = gr.Radio(label="Model", choices=["F5TTS_Base", "E2TTS_Base"], value="F5TTS_Base")
821
- learning_rate = gr.Number(label="Learning Rate", value=1e-5, step=1e-5)
822
-
823
- with gr.Row():
824
- batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=1000)
825
- max_samples = gr.Number(label="Max Samples", value=64)
826
-
827
- with gr.Row():
828
- grad_accumulation_steps = gr.Number(label="Gradient Accumulation Steps", value=1)
829
- max_grad_norm = gr.Number(label="Max Gradient Norm", value=1.0)
830
-
831
- with gr.Row():
832
- epochs = gr.Number(label="Epochs", value=10)
833
- num_warmup_updates = gr.Number(label="Warmup Updates", value=5)
834
-
835
- with gr.Row():
836
- save_per_updates = gr.Number(label="Save per Updates", value=10)
837
- last_per_steps = gr.Number(label="Last per Steps", value=50)
838
-
839
- with gr.Row():
840
- start_button = gr.Button("Start Training")
841
- stop_button = gr.Button("Stop Training", interactive=False)
842
-
843
- txt_info_train = gr.Text(label="info", value="")
844
- start_button.click(
845
- fn=start_training,
846
- inputs=[
847
- project_name,
848
- exp_name,
849
- learning_rate,
850
- batch_size_per_gpu,
851
- batch_size_type,
852
- max_samples,
853
- grad_accumulation_steps,
854
- max_grad_norm,
855
- epochs,
856
- num_warmup_updates,
857
- save_per_updates,
858
- last_per_steps,
859
- ch_finetune,
860
- ],
861
- outputs=[txt_info_train, start_button, stop_button],
862
- )
863
- stop_button.click(fn=stop_training, outputs=[txt_info_train, start_button, stop_button])
864
- bt_calculate.click(
865
- fn=calculate_train,
866
- inputs=[
867
- project_name,
868
- batch_size_type,
869
- max_samples,
870
- learning_rate,
871
- num_warmup_updates,
872
- save_per_updates,
873
- last_per_steps,
874
- ch_finetune,
875
- ],
876
- outputs=[
877
- batch_size_per_gpu,
878
- max_samples,
879
- num_warmup_updates,
880
- save_per_updates,
881
- last_per_steps,
882
- lb_samples,
883
- learning_rate,
884
- ],
885
- )
886
-
887
- with gr.TabItem("reduse checkpoint"):
888
- txt_path_checkpoint = gr.Text(label="path checkpoint :")
889
- txt_path_checkpoint_small = gr.Text(label="path output :")
890
- txt_info_reduse = gr.Text(label="info", value="")
891
- reduse_button = gr.Button("reduse")
892
- reduse_button.click(
893
- fn=extract_and_save_ema_model,
894
- inputs=[txt_path_checkpoint, txt_path_checkpoint_small],
895
- outputs=[txt_info_reduse],
896
- )
897
-
898
- with gr.TabItem("vocab check experiment"):
899
- check_button = gr.Button("check vocab")
900
- txt_info_check = gr.Text(label="info", value="")
901
- check_button.click(fn=vocab_check, inputs=[project_name], outputs=[txt_info_check])
902
-
903
- with gr.TabItem("test model"):
904
- exp_name = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS")
905
- nfe_step = gr.Number(label="n_step", value=32)
906
- file_checkpoint_pt = gr.Textbox(label="Checkpoint", value="")
907
-
908
- random_sample_infer = gr.Button("random sample")
909
-
910
- ref_text = gr.Textbox(label="ref text")
911
- ref_audio = gr.Audio(label="audio ref", type="filepath")
912
- gen_text = gr.Textbox(label="gen text")
913
- random_sample_infer.click(
914
- fn=get_random_sample_infer, inputs=[project_name], outputs=[ref_text, gen_text, ref_audio]
915
- )
916
- check_button_infer = gr.Button("infer")
917
- gen_audio = gr.Audio(label="audio gen", type="filepath")
918
-
919
- check_button_infer.click(
920
- fn=infer,
921
- inputs=[file_checkpoint_pt, exp_name, ref_text, ref_audio, gen_text, nfe_step],
922
- outputs=[gen_audio],
923
- )
924
-
925
-
926
- @click.command()
927
- @click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
928
- @click.option("--host", "-H", default=None, help="Host to run the app on")
929
- @click.option(
930
- "--share",
931
- "-s",
932
- default=False,
933
- is_flag=True,
934
- help="Share the app via Gradio share link",
935
- )
936
- @click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
937
- def main(port, host, share, api):
938
- global app
939
- print("Starting app...")
940
- app.queue(api_open=api).launch(server_name=host, server_port=port, share=share, show_api=api)
941
-
942
-
943
- if __name__ == "__main__":
944
- main()
gradio_app.py DELETED
@@ -1,824 +0,0 @@
1
- import os
2
- import re
3
- import torch
4
- import torchaudio
5
- import gradio as gr
6
- import numpy as np
7
- import tempfile
8
- from einops import rearrange
9
- from vocos import Vocos
10
- from pydub import AudioSegment, silence
11
- from model import CFM, UNetT, DiT, MMDiT
12
- from cached_path import cached_path
13
- from model.utils import (
14
- load_checkpoint,
15
- get_tokenizer,
16
- convert_char_to_pinyin,
17
- save_spectrogram,
18
- )
19
- from transformers import pipeline
20
- import librosa
21
- import click
22
- import soundfile as sf
23
-
24
- try:
25
- import spaces
26
- USING_SPACES = True
27
- except ImportError:
28
- USING_SPACES = False
29
-
30
- def gpu_decorator(func):
31
- if USING_SPACES:
32
- return spaces.GPU(func)
33
- else:
34
- return func
35
-
36
-
37
-
38
- SPLIT_WORDS = [
39
- "but", "however", "nevertheless", "yet", "still",
40
- "therefore", "thus", "hence", "consequently",
41
- "moreover", "furthermore", "additionally",
42
- "meanwhile", "alternatively", "otherwise",
43
- "namely", "specifically", "for example", "such as",
44
- "in fact", "indeed", "notably",
45
- "in contrast", "on the other hand", "conversely",
46
- "in conclusion", "to summarize", "finally"
47
- ]
48
-
49
- device = (
50
- "cuda"
51
- if torch.cuda.is_available()
52
- else "mps" if torch.backends.mps.is_available() else "cpu"
53
- )
54
-
55
- print(f"Using {device} device")
56
-
57
- pipe = pipeline(
58
- "automatic-speech-recognition",
59
- model="openai/whisper-large-v3-turbo",
60
- torch_dtype=torch.float16,
61
- device=device,
62
- )
63
- vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
64
-
65
- # --------------------- Settings -------------------- #
66
-
67
- target_sample_rate = 24000
68
- n_mel_channels = 100
69
- hop_length = 256
70
- target_rms = 0.1
71
- nfe_step = 32 # 16, 32
72
- cfg_strength = 2.0
73
- ode_method = "euler"
74
- sway_sampling_coef = -1.0
75
- speed = 1.0
76
- # fix_duration = 27 # None or float (duration in seconds)
77
- fix_duration = None
78
-
79
-
80
- def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
81
- ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
82
- # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
83
- vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
84
- model = CFM(
85
- transformer=model_cls(
86
- **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
87
- ),
88
- mel_spec_kwargs=dict(
89
- target_sample_rate=target_sample_rate,
90
- n_mel_channels=n_mel_channels,
91
- hop_length=hop_length,
92
- ),
93
- odeint_kwargs=dict(
94
- method=ode_method,
95
- ),
96
- vocab_char_map=vocab_char_map,
97
- ).to(device)
98
-
99
- model = load_checkpoint(model, ckpt_path, device, use_ema = True)
100
-
101
- return model
102
-
103
-
104
- # load models
105
- F5TTS_model_cfg = dict(
106
- dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
107
- )
108
- E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
109
-
110
- F5TTS_ema_model = load_model(
111
- "F5-TTS", "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000
112
- )
113
- E2TTS_ema_model = load_model(
114
- "E2-TTS", "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000
115
- )
116
-
117
- def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):
118
- if len(text.encode('utf-8')) <= max_chars:
119
- return [text]
120
- if text[-1] not in ['。', '.', '!', '!', '?', '?']:
121
- text += '.'
122
-
123
- sentences = re.split('([。.!?!?])', text)
124
- sentences = [''.join(i) for i in zip(sentences[0::2], sentences[1::2])]
125
-
126
- batches = []
127
- current_batch = ""
128
-
129
- def split_by_words(text):
130
- words = text.split()
131
- current_word_part = ""
132
- word_batches = []
133
- for word in words:
134
- if len(current_word_part.encode('utf-8')) + len(word.encode('utf-8')) + 1 <= max_chars:
135
- current_word_part += word + ' '
136
- else:
137
- if current_word_part:
138
- # Try to find a suitable split word
139
- for split_word in split_words:
140
- split_index = current_word_part.rfind(' ' + split_word + ' ')
141
- if split_index != -1:
142
- word_batches.append(current_word_part[:split_index].strip())
143
- current_word_part = current_word_part[split_index:].strip() + ' '
144
- break
145
- else:
146
- # If no suitable split word found, just append the current part
147
- word_batches.append(current_word_part.strip())
148
- current_word_part = ""
149
- current_word_part += word + ' '
150
- if current_word_part:
151
- word_batches.append(current_word_part.strip())
152
- return word_batches
153
-
154
- for sentence in sentences:
155
- if len(current_batch.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
156
- current_batch += sentence
157
- else:
158
- # If adding this sentence would exceed the limit
159
- if current_batch:
160
- batches.append(current_batch)
161
- current_batch = ""
162
-
163
- # If the sentence itself is longer than max_chars, split it
164
- if len(sentence.encode('utf-8')) > max_chars:
165
- # First, try to split by colon
166
- colon_parts = sentence.split(':')
167
- if len(colon_parts) > 1:
168
- for part in colon_parts:
169
- if len(part.encode('utf-8')) <= max_chars:
170
- batches.append(part)
171
- else:
172
- # If colon part is still too long, split by comma
173
- comma_parts = re.split('[,,]', part)
174
- if len(comma_parts) > 1:
175
- current_comma_part = ""
176
- for comma_part in comma_parts:
177
- if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
178
- current_comma_part += comma_part + ','
179
- else:
180
- if current_comma_part:
181
- batches.append(current_comma_part.rstrip(','))
182
- current_comma_part = comma_part + ','
183
- if current_comma_part:
184
- batches.append(current_comma_part.rstrip(','))
185
- else:
186
- # If no comma, split by words
187
- batches.extend(split_by_words(part))
188
- else:
189
- # If no colon, split by comma
190
- comma_parts = re.split('[,,]', sentence)
191
- if len(comma_parts) > 1:
192
- current_comma_part = ""
193
- for comma_part in comma_parts:
194
- if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
195
- current_comma_part += comma_part + ','
196
- else:
197
- if current_comma_part:
198
- batches.append(current_comma_part.rstrip(','))
199
- current_comma_part = comma_part + ','
200
- if current_comma_part:
201
- batches.append(current_comma_part.rstrip(','))
202
- else:
203
- # If no comma, split by words
204
- batches.extend(split_by_words(sentence))
205
- else:
206
- current_batch = sentence
207
-
208
- if current_batch:
209
- batches.append(current_batch)
210
-
211
- return batches
212
-
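
A minimal standalone sketch of the sentence-splitting step used by split_text_into_batches above (the sample string is made up for illustration): re.split with a capturing group keeps each sentence terminator as its own token, and zip re-attaches every terminator to the sentence it ends.

import re

text = "第一句话。Second sentence! Third one?"    # illustrative input only
parts = re.split('([。.!?!?])', text)            # capturing group keeps the punctuation tokens
sentences = [''.join(pair) for pair in zip(parts[0::2], parts[1::2])]
print(sentences)   # ['第一句话。', 'Second sentence!', ' Third one?']
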
213
- def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence, progress=gr.Progress()):
214
- if exp_name == "F5-TTS":
215
- ema_model = F5TTS_ema_model
216
- elif exp_name == "E2-TTS":
217
- ema_model = E2TTS_ema_model
218
-
219
- audio, sr = ref_audio
220
- if audio.shape[0] > 1:
221
- audio = torch.mean(audio, dim=0, keepdim=True)
222
-
223
- rms = torch.sqrt(torch.mean(torch.square(audio)))
224
- if rms < target_rms:
225
- audio = audio * target_rms / rms
226
- if sr != target_sample_rate:
227
- resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
228
- audio = resampler(audio)
229
- audio = audio.to(device)
230
-
231
- generated_waves = []
232
- spectrograms = []
233
-
234
- for i, gen_text in enumerate(progress.tqdm(gen_text_batches)):
235
- # Prepare the text
236
- if len(ref_text[-1].encode('utf-8')) == 1:
237
- ref_text = ref_text + " "
238
- text_list = [ref_text + gen_text]
239
- final_text_list = convert_char_to_pinyin(text_list)
240
-
241
- # Calculate duration
242
- ref_audio_len = audio.shape[-1] // hop_length
243
- zh_pause_punc = r"。,、;:?!"
244
- ref_text_len = len(ref_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, ref_text))
245
- gen_text_len = len(gen_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gen_text))
246
- duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
247
-
248
- # inference
249
- with torch.inference_mode():
250
- generated, _ = ema_model.sample(
251
- cond=audio,
252
- text=final_text_list,
253
- duration=duration,
254
- steps=nfe_step,
255
- cfg_strength=cfg_strength,
256
- sway_sampling_coef=sway_sampling_coef,
257
- )
258
-
259
- generated = generated[:, ref_audio_len:, :]
260
- generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
261
- generated_wave = vocos.decode(generated_mel_spec.cpu())
262
- if rms < target_rms:
263
- generated_wave = generated_wave * rms / target_rms
264
-
265
- # wav -> numpy
266
- generated_wave = generated_wave.squeeze().cpu().numpy()
267
-
268
- generated_waves.append(generated_wave)
269
- spectrograms.append(generated_mel_spec[0].cpu().numpy())
270
-
271
- # Combine all generated waves
272
- final_wave = np.concatenate(generated_waves)
273
-
274
- # Remove silence
275
- if remove_silence:
276
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
277
- sf.write(f.name, final_wave, target_sample_rate)
278
- aseg = AudioSegment.from_file(f.name)
279
- non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
280
- non_silent_wave = AudioSegment.silent(duration=0)
281
- for non_silent_seg in non_silent_segs:
282
- non_silent_wave += non_silent_seg
283
- aseg = non_silent_wave
284
- aseg.export(f.name, format="wav")
285
- final_wave, _ = torchaudio.load(f.name)
286
- final_wave = final_wave.squeeze().cpu().numpy()
287
-
288
- # Create a combined spectrogram
289
- combined_spectrogram = np.concatenate(spectrograms, axis=1)
290
-
291
- with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
292
- spectrogram_path = tmp_spectrogram.name
293
- save_spectrogram(combined_spectrogram, spectrogram_path)
294
-
295
- return (target_sample_rate, final_wave), spectrogram_path
296
-
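
A small worked example of the duration estimate used inside infer_batch above, with illustrative numbers only: the requested mel length is the reference length plus the reference length scaled by the byte-length ratio of generated to reference text, divided by the speed factor.

hop_length = 256                          # from the settings above
ref_audio_len = 24000 * 5 // hop_length   # a 5 s reference at 24 kHz -> 468 mel frames
ref_text_len = 60                         # byte length of the reference text (plus pause-punctuation bonus)
gen_text_len = 120                        # byte length of the text to generate
speed = 1.0

duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
print(duration)                           # 468 + 936 = 1404 frames requested from the model
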
297
- def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_split_words=''):
298
- if custom_split_words.strip():
299
- custom_words = [word.strip() for word in custom_split_words.split(',')]
300
- global SPLIT_WORDS
301
- SPLIT_WORDS = custom_words
302
-
303
- print(gen_text)
304
-
305
- gr.Info("Converting audio...")
306
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
307
- aseg = AudioSegment.from_file(ref_audio_orig)
308
-
309
- non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
310
- non_silent_wave = AudioSegment.silent(duration=0)
311
- for non_silent_seg in non_silent_segs:
312
- non_silent_wave += non_silent_seg
313
- aseg = non_silent_wave
314
-
315
- audio_duration = len(aseg)
316
- if audio_duration > 15000:
317
- gr.Warning("Audio is over 15s, clipping to only first 15s.")
318
- aseg = aseg[:15000]
319
- aseg.export(f.name, format="wav")
320
- ref_audio = f.name
321
-
322
- if not ref_text.strip():
323
- gr.Info("No reference text provided, transcribing reference audio...")
324
- ref_text = pipe(
325
- ref_audio,
326
- chunk_length_s=30,
327
- batch_size=128,
328
- generate_kwargs={"task": "transcribe"},
329
- return_timestamps=False,
330
- )["text"].strip()
331
- gr.Info("Finished transcription")
332
- else:
333
- gr.Info("Using custom reference text...")
334
-
335
- # Split the input text into batches
336
- audio, sr = torchaudio.load(ref_audio)
337
- max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (30 - audio.shape[-1] / sr))
338
- gen_text_batches = split_text_into_batches(gen_text, max_chars=max_chars)
339
- print('ref_text', ref_text)
340
- for i, gen_text in enumerate(gen_text_batches):
341
- print(f'gen_text {i}', gen_text)
342
-
343
- gr.Info(f"Generating audio using {exp_name} in {len(gen_text_batches)} batches")
344
- return infer_batch((audio, sr), ref_text, gen_text_batches, exp_name, remove_silence)
345
-
346
- def generate_podcast(script, speaker1_name, ref_audio1, ref_text1, speaker2_name, ref_audio2, ref_text2, exp_name, remove_silence):
347
- # Split the script into speaker blocks
348
- speaker_pattern = re.compile(f"^({re.escape(speaker1_name)}|{re.escape(speaker2_name)}):", re.MULTILINE)
349
- speaker_blocks = speaker_pattern.split(script)[1:] # Skip the first empty element
350
-
351
- generated_audio_segments = []
352
-
353
- for i in range(0, len(speaker_blocks), 2):
354
- speaker = speaker_blocks[i]
355
- text = speaker_blocks[i+1].strip()
356
-
357
- # Determine which speaker is talking
358
- if speaker == speaker1_name:
359
- ref_audio = ref_audio1
360
- ref_text = ref_text1
361
- elif speaker == speaker2_name:
362
- ref_audio = ref_audio2
363
- ref_text = ref_text2
364
- else:
365
- continue # Skip if the speaker is neither speaker1 nor speaker2
366
-
367
- # Generate audio for this block
368
- audio, _ = infer(ref_audio, ref_text, text, exp_name, remove_silence)
369
-
370
- # Convert the generated audio to a numpy array
371
- sr, audio_data = audio
372
-
373
- # Save the audio data as a WAV file
374
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
375
- sf.write(temp_file.name, audio_data, sr)
376
- audio_segment = AudioSegment.from_wav(temp_file.name)
377
-
378
- generated_audio_segments.append(audio_segment)
379
-
380
- # Add a short pause between speakers
381
- pause = AudioSegment.silent(duration=500) # 500ms pause
382
- generated_audio_segments.append(pause)
383
-
384
- # Concatenate all audio segments
385
- final_podcast = sum(generated_audio_segments)
386
-
387
- # Export the final podcast
388
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
389
- podcast_path = temp_file.name
390
- final_podcast.export(podcast_path, format="wav")
391
-
392
- return podcast_path
393
-
394
- def parse_speechtypes_text(gen_text):
395
- # Pattern to find (Emotion)
396
- pattern = r'\((.*?)\)'
397
-
398
- # Split the text by the pattern
399
- tokens = re.split(pattern, gen_text)
400
-
401
- segments = []
402
-
403
- current_emotion = 'Regular'
404
-
405
- for i in range(len(tokens)):
406
- if i % 2 == 0:
407
- # This is text
408
- text = tokens[i].strip()
409
- if text:
410
- segments.append({'emotion': current_emotion, 'text': text})
411
- else:
412
- # This is emotion
413
- emotion = tokens[i].strip()
414
- current_emotion = emotion
415
-
416
- return segments
417
-
418
- def update_speed(new_speed):
419
- global speed
420
- speed = new_speed
421
- return f"Speed set to: {speed}"
422
-
423
- with gr.Blocks() as app_credits:
424
- gr.Markdown("""
425
- # Credits
426
-
427
- * [mrfakename](https://github.com/fakerybakery) for the original [online demo](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
428
- * [RootingInLoad](https://github.com/RootingInLoad) for the podcast generation
429
- """)
430
- with gr.Blocks() as app_tts:
431
- gr.Markdown("# Batched TTS")
432
- ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
433
- gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
434
- model_choice = gr.Radio(
435
- choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
436
- )
437
- generate_btn = gr.Button("Synthesize", variant="primary")
438
- with gr.Accordion("Advanced Settings", open=False):
439
- ref_text_input = gr.Textbox(
440
- label="Reference Text",
441
- info="Leave blank to automatically transcribe the reference audio. If you enter text it will override automatic transcription.",
442
- lines=2,
443
- )
444
- remove_silence = gr.Checkbox(
445
- label="Remove Silences",
446
- info="The model tends to produce silences, especially on longer audio. We can manually remove silences if needed. Note that this is an experimental feature and may produce strange results. This will also increase generation time.",
447
- value=True,
448
- )
449
- split_words_input = gr.Textbox(
450
- label="Custom Split Words",
451
- info="Enter custom words to split on, separated by commas. Leave blank to use default list.",
452
- lines=2,
453
- )
454
- speed_slider = gr.Slider(
455
- label="Speed",
456
- minimum=0.3,
457
- maximum=2.0,
458
- value=speed,
459
- step=0.1,
460
- info="Adjust the speed of the audio.",
461
- )
462
- speed_slider.change(update_speed, inputs=speed_slider)
463
-
464
- audio_output = gr.Audio(label="Synthesized Audio")
465
- spectrogram_output = gr.Image(label="Spectrogram")
466
-
467
- generate_btn.click(
468
- infer,
469
- inputs=[
470
- ref_audio_input,
471
- ref_text_input,
472
- gen_text_input,
473
- model_choice,
474
- remove_silence,
475
- split_words_input,
476
- ],
477
- outputs=[audio_output, spectrogram_output],
478
- )
479
-
480
- with gr.Blocks() as app_podcast:
481
- gr.Markdown("# Podcast Generation")
482
- speaker1_name = gr.Textbox(label="Speaker 1 Name")
483
- ref_audio_input1 = gr.Audio(label="Reference Audio (Speaker 1)", type="filepath")
484
- ref_text_input1 = gr.Textbox(label="Reference Text (Speaker 1)", lines=2)
485
-
486
- speaker2_name = gr.Textbox(label="Speaker 2 Name")
487
- ref_audio_input2 = gr.Audio(label="Reference Audio (Speaker 2)", type="filepath")
488
- ref_text_input2 = gr.Textbox(label="Reference Text (Speaker 2)", lines=2)
489
-
490
- script_input = gr.Textbox(label="Podcast Script", lines=10,
491
- placeholder="Enter the script with speaker names at the start of each block, e.g.:\nSean: How did you start studying...\n\nMeghan: I came to my interest in technology...\nIt was a long journey...\n\nSean: That's fascinating. Can you elaborate...")
492
-
493
- podcast_model_choice = gr.Radio(
494
- choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
495
- )
496
- podcast_remove_silence = gr.Checkbox(
497
- label="Remove Silences",
498
- value=True,
499
- )
500
- generate_podcast_btn = gr.Button("Generate Podcast", variant="primary")
501
- podcast_output = gr.Audio(label="Generated Podcast")
502
-
503
- def podcast_generation(script, speaker1, ref_audio1, ref_text1, speaker2, ref_audio2, ref_text2, model, remove_silence):
504
- return generate_podcast(script, speaker1, ref_audio1, ref_text1, speaker2, ref_audio2, ref_text2, model, remove_silence)
505
-
506
- generate_podcast_btn.click(
507
- podcast_generation,
508
- inputs=[
509
- script_input,
510
- speaker1_name,
511
- ref_audio_input1,
512
- ref_text_input1,
513
- speaker2_name,
514
- ref_audio_input2,
515
- ref_text_input2,
516
- podcast_model_choice,
517
- podcast_remove_silence,
518
- ],
519
- outputs=podcast_output,
520
- )
521
-
522
- def parse_emotional_text(gen_text):
523
- # Pattern to find (Emotion)
524
- pattern = r'\((.*?)\)'
525
-
526
- # Split the text by the pattern
527
- tokens = re.split(pattern, gen_text)
528
-
529
- segments = []
530
-
531
- current_emotion = 'Regular'
532
-
533
- for i in range(len(tokens)):
534
- if i % 2 == 0:
535
- # This is text
536
- text = tokens[i].strip()
537
- if text:
538
- segments.append({'emotion': current_emotion, 'text': text})
539
- else:
540
- # This is emotion
541
- emotion = tokens[i].strip()
542
- current_emotion = emotion
543
-
544
- return segments
545
-
546
- with gr.Blocks() as app_emotional:
547
- # New section for emotional generation
548
- gr.Markdown(
549
- """
550
- # Multiple Speech-Type Generation
551
-
552
- This section allows you to upload different audio clips for each speech type. 'Regular' emotion is mandatory. You can add additional speech types by clicking the "Add Speech Type" button. Enter your text in the format shown below, and the system will generate speech using the appropriate emotions. If unspecified, the model will use the regular speech type. The current speech type will be used until the next speech type is specified.
553
-
554
- **Example Input:**
555
-
556
- (Regular) Hello, I'd like to order a sandwich please. (Surprised) What do you mean you're out of bread? (Sad) I really wanted a sandwich though... (Angry) You know what, darn you and your little shop, you suck! (Whisper) I'll just go back home and cry now. (Shouting) Why me?!
557
- """
558
- )
559
-
560
- gr.Markdown("Upload different audio clips for each speech type. 'Regular' emotion is mandatory. You can add additional speech types by clicking the 'Add Speech Type' button.")
561
-
562
- # Regular speech type (mandatory)
563
- with gr.Row():
564
- regular_name = gr.Textbox(value='Regular', label='Speech Type Name', interactive=False)
565
- regular_audio = gr.Audio(label='Regular Reference Audio', type='filepath')
566
- regular_ref_text = gr.Textbox(label='Reference Text (Regular)', lines=2)
567
-
568
- # Additional speech types (up to 9 more)
569
- max_speech_types = 10
570
- speech_type_names = []
571
- speech_type_audios = []
572
- speech_type_ref_texts = []
573
- speech_type_delete_btns = []
574
-
575
- for i in range(max_speech_types - 1):
576
- with gr.Row():
577
- name_input = gr.Textbox(label='Speech Type Name', visible=False)
578
- audio_input = gr.Audio(label='Reference Audio', type='filepath', visible=False)
579
- ref_text_input = gr.Textbox(label='Reference Text', lines=2, visible=False)
580
- delete_btn = gr.Button("Delete", variant="secondary", visible=False)
581
- speech_type_names.append(name_input)
582
- speech_type_audios.append(audio_input)
583
- speech_type_ref_texts.append(ref_text_input)
584
- speech_type_delete_btns.append(delete_btn)
585
-
586
- # Button to add speech type
587
- add_speech_type_btn = gr.Button("Add Speech Type")
588
-
589
- # Keep track of current number of speech types
590
- speech_type_count = gr.State(value=0)
591
-
592
- # Function to add a speech type
593
- def add_speech_type_fn(speech_type_count):
594
- if speech_type_count < max_speech_types - 1:
595
- speech_type_count += 1
596
- # Prepare updates for the components
597
- name_updates = []
598
- audio_updates = []
599
- ref_text_updates = []
600
- delete_btn_updates = []
601
- for i in range(max_speech_types - 1):
602
- if i < speech_type_count:
603
- name_updates.append(gr.update(visible=True))
604
- audio_updates.append(gr.update(visible=True))
605
- ref_text_updates.append(gr.update(visible=True))
606
- delete_btn_updates.append(gr.update(visible=True))
607
- else:
608
- name_updates.append(gr.update())
609
- audio_updates.append(gr.update())
610
- ref_text_updates.append(gr.update())
611
- delete_btn_updates.append(gr.update())
612
- else:
613
- # Optionally, show a warning
614
- # gr.Warning("Maximum number of speech types reached.")
615
- name_updates = [gr.update() for _ in range(max_speech_types - 1)]
616
- audio_updates = [gr.update() for _ in range(max_speech_types - 1)]
617
- ref_text_updates = [gr.update() for _ in range(max_speech_types - 1)]
618
- delete_btn_updates = [gr.update() for _ in range(max_speech_types - 1)]
619
- return [speech_type_count] + name_updates + audio_updates + ref_text_updates + delete_btn_updates
620
-
621
- add_speech_type_btn.click(
622
- add_speech_type_fn,
623
- inputs=speech_type_count,
624
- outputs=[speech_type_count] + speech_type_names + speech_type_audios + speech_type_ref_texts + speech_type_delete_btns
625
- )
626
-
627
- # Function to delete a speech type
628
- def make_delete_speech_type_fn(index):
629
- def delete_speech_type_fn(speech_type_count):
630
- # Prepare updates
631
- name_updates = []
632
- audio_updates = []
633
- ref_text_updates = []
634
- delete_btn_updates = []
635
-
636
- for i in range(max_speech_types - 1):
637
- if i == index:
638
- name_updates.append(gr.update(visible=False, value=''))
639
- audio_updates.append(gr.update(visible=False, value=None))
640
- ref_text_updates.append(gr.update(visible=False, value=''))
641
- delete_btn_updates.append(gr.update(visible=False))
642
- else:
643
- name_updates.append(gr.update())
644
- audio_updates.append(gr.update())
645
- ref_text_updates.append(gr.update())
646
- delete_btn_updates.append(gr.update())
647
-
648
- speech_type_count = max(0, speech_type_count - 1)
649
-
650
- return [speech_type_count] + name_updates + audio_updates + ref_text_updates + delete_btn_updates
651
-
652
- return delete_speech_type_fn
653
-
654
- for i, delete_btn in enumerate(speech_type_delete_btns):
655
- delete_fn = make_delete_speech_type_fn(i)
656
- delete_btn.click(
657
- delete_fn,
658
- inputs=speech_type_count,
659
- outputs=[speech_type_count] + speech_type_names + speech_type_audios + speech_type_ref_texts + speech_type_delete_btns
660
- )
661
-
662
- # Text input for the prompt
663
- gen_text_input_emotional = gr.Textbox(label="Text to Generate", lines=10)
664
-
665
- # Model choice
666
- model_choice_emotional = gr.Radio(
667
- choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
668
- )
669
-
670
- with gr.Accordion("Advanced Settings", open=False):
671
- remove_silence_emotional = gr.Checkbox(
672
- label="Remove Silences",
673
- value=True,
674
- )
675
-
676
- # Generate button
677
- generate_emotional_btn = gr.Button("Generate Emotional Speech", variant="primary")
678
-
679
- # Output audio
680
- audio_output_emotional = gr.Audio(label="Synthesized Audio")
681
-
682
- def generate_emotional_speech(
683
- regular_audio,
684
- regular_ref_text,
685
- gen_text,
686
- *args,
687
- ):
688
- num_additional_speech_types = max_speech_types - 1
689
- speech_type_names_list = args[:num_additional_speech_types]
690
- speech_type_audios_list = args[num_additional_speech_types:2 * num_additional_speech_types]
691
- speech_type_ref_texts_list = args[2 * num_additional_speech_types:3 * num_additional_speech_types]
692
- model_choice = args[3 * num_additional_speech_types]
693
- remove_silence = args[3 * num_additional_speech_types + 1]
694
-
695
- # Collect the speech types and their audios into a dict
696
- speech_types = {'Regular': {'audio': regular_audio, 'ref_text': regular_ref_text}}
697
-
698
- for name_input, audio_input, ref_text_input in zip(speech_type_names_list, speech_type_audios_list, speech_type_ref_texts_list):
699
- if name_input and audio_input:
700
- speech_types[name_input] = {'audio': audio_input, 'ref_text': ref_text_input}
701
-
702
- # Parse the gen_text into segments
703
- segments = parse_speechtypes_text(gen_text)
704
-
705
- # For each segment, generate speech
706
- generated_audio_segments = []
707
- current_emotion = 'Regular'
708
-
709
- for segment in segments:
710
- emotion = segment['emotion']
711
- text = segment['text']
712
-
713
- if emotion in speech_types:
714
- current_emotion = emotion
715
- else:
716
- # If emotion not available, default to Regular
717
- current_emotion = 'Regular'
718
-
719
- ref_audio = speech_types[current_emotion]['audio']
720
- ref_text = speech_types[current_emotion].get('ref_text', '')
721
-
722
- # Generate speech for this segment
723
- audio, _ = infer(ref_audio, ref_text, text, model_choice, remove_silence, "")
724
- sr, audio_data = audio
725
-
726
- generated_audio_segments.append(audio_data)
727
-
728
- # Concatenate all audio segments
729
- if generated_audio_segments:
730
- final_audio_data = np.concatenate(generated_audio_segments)
731
- return (sr, final_audio_data)
732
- else:
733
- gr.Warning("No audio generated.")
734
- return None
735
-
736
- generate_emotional_btn.click(
737
- generate_emotional_speech,
738
- inputs=[
739
- regular_audio,
740
- regular_ref_text,
741
- gen_text_input_emotional,
742
- ] + speech_type_names + speech_type_audios + speech_type_ref_texts + [
743
- model_choice_emotional,
744
- remove_silence_emotional,
745
- ],
746
- outputs=audio_output_emotional,
747
- )
748
-
749
- # Validation function to disable Generate button if speech types are missing
750
- def validate_speech_types(
751
- gen_text,
752
- regular_name,
753
- *args
754
- ):
755
- num_additional_speech_types = max_speech_types - 1
756
- speech_type_names_list = args[:num_additional_speech_types]
757
-
758
- # Collect the speech types names
759
- speech_types_available = set()
760
- if regular_name:
761
- speech_types_available.add(regular_name)
762
- for name_input in speech_type_names_list:
763
- if name_input:
764
- speech_types_available.add(name_input)
765
-
766
- # Parse the gen_text to get the speech types used
767
- segments = parse_emotional_text(gen_text)
768
- speech_types_in_text = set(segment['emotion'] for segment in segments)
769
-
770
- # Check if all speech types in text are available
771
- missing_speech_types = speech_types_in_text - speech_types_available
772
-
773
- if missing_speech_types:
774
- # Disable the generate button
775
- return gr.update(interactive=False)
776
- else:
777
- # Enable the generate button
778
- return gr.update(interactive=True)
779
-
780
- gen_text_input_emotional.change(
781
- validate_speech_types,
782
- inputs=[gen_text_input_emotional, regular_name] + speech_type_names,
783
- outputs=generate_emotional_btn
784
- )
785
- with gr.Blocks() as app:
786
- gr.Markdown(
787
- """
788
- # E2/F5 TTS
789
-
790
- This is a local web UI for F5 TTS with advanced batch processing support. This app supports the following TTS models:
791
-
792
- * [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
793
- * [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
794
-
795
- The checkpoints support English and Chinese.
796
-
797
- If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15s, and shortening your prompt.
798
-
799
- **NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<15s). Ensure the audio is fully uploaded before generating.**
800
- """
801
- )
802
- gr.TabbedInterface([app_tts, app_podcast, app_emotional, app_credits], ["TTS", "Podcast", "Multi-Style", "Credits"])
803
-
804
- @click.command()
805
- @click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
806
- @click.option("--host", "-H", default=None, help="Host to run the app on")
807
- @click.option(
808
- "--share",
809
- "-s",
810
- default=False,
811
- is_flag=True,
812
- help="Share the app via Gradio share link",
813
- )
814
- @click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
815
- def main(port, host, share, api):
816
- global app
817
- print(f"Starting app...")
818
- app.queue(api_open=api).launch(
819
- server_name=host, server_port=port, share=share, show_api=api
820
- )
821
-
822
-
823
- if __name__ == "__main__":
824
- main()
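
To illustrate the (Emotion) tagging format described in the Multi-Speech-Type section above, here is a minimal sketch that mirrors the parsing loop of parse_speechtypes_text; the input string is shortened from the example shown in the UI text and is not part of the original app.

import re

gen_text = "(Regular) Hello, I'd like to order a sandwich please. (Sad) I really wanted a sandwich though..."

tokens = re.split(r'\((.*?)\)', gen_text)   # even indices are text, odd indices are emotion labels
segments, current_emotion = [], 'Regular'
for i, token in enumerate(tokens):
    token = token.strip()
    if i % 2 == 0:
        if token:
            segments.append({'emotion': current_emotion, 'text': token})
    else:
        current_emotion = token

print(segments)
# [{'emotion': 'Regular', 'text': "Hello, I'd like to order a sandwich please."},
#  {'emotion': 'Sad', 'text': 'I really wanted a sandwich though...'}]
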
inference-cli.py DELETED
@@ -1,170 +0,0 @@
1
- import argparse
2
- import codecs
3
- import re
4
- from pathlib import Path
5
-
6
- import numpy as np
7
- import soundfile as sf
8
- import tomli
9
- from cached_path import cached_path
10
-
11
- from model import DiT, UNetT
12
- from model.utils_infer import (
13
- load_vocoder,
14
- load_model,
15
- preprocess_ref_audio_text,
16
- infer_process,
17
- remove_silence_for_generated_wav,
18
- )
19
-
20
-
21
- parser = argparse.ArgumentParser(
22
- prog="python3 inference-cli.py",
23
- description="Commandline interface for E2/F5 TTS with Advanced Batch Processing.",
24
- epilog="Specify options above to override one or more settings from config.",
25
- )
26
- parser.add_argument(
27
- "-c",
28
- "--config",
29
- help="Configuration file. Default=cli-config.toml",
30
- default="inference-cli.toml",
31
- )
32
- parser.add_argument(
33
- "-m",
34
- "--model",
35
- help="F5-TTS | E2-TTS",
36
- )
37
- parser.add_argument(
38
- "-p",
39
- "--ckpt_file",
40
- help="The Checkpoint .pt",
41
- )
42
- parser.add_argument(
43
- "-v",
44
- "--vocab_file",
45
- help="The vocab .txt",
46
- )
47
- parser.add_argument("-r", "--ref_audio", type=str, help="Reference audio file < 15 seconds.")
48
- parser.add_argument("-s", "--ref_text", type=str, default="666", help="Subtitle for the reference audio.")
49
- parser.add_argument(
50
- "-t",
51
- "--gen_text",
52
- type=str,
53
- help="Text to generate.",
54
- )
55
- parser.add_argument(
56
- "-f",
57
- "--gen_file",
58
- type=str,
59
- help="File with text to generate. Ignores --text",
60
- )
61
- parser.add_argument(
62
- "-o",
63
- "--output_dir",
64
- type=str,
65
- help="Path to output folder..",
66
- )
67
- parser.add_argument(
68
- "--remove_silence",
69
- help="Remove silence.",
70
- )
71
- parser.add_argument(
72
- "--load_vocoder_from_local",
73
- action="store_true",
74
- help="load vocoder from local. Default: ../checkpoints/charactr/vocos-mel-24khz",
75
- )
76
- args = parser.parse_args()
77
-
78
- config = tomli.load(open(args.config, "rb"))
79
-
80
- ref_audio = args.ref_audio if args.ref_audio else config["ref_audio"]
81
- ref_text = args.ref_text if args.ref_text != "666" else config["ref_text"]
82
- gen_text = args.gen_text if args.gen_text else config["gen_text"]
83
- gen_file = args.gen_file if args.gen_file else config["gen_file"]
84
- if gen_file:
85
- gen_text = codecs.open(gen_file, "r", "utf-8").read()
86
- output_dir = args.output_dir if args.output_dir else config["output_dir"]
87
- model = args.model if args.model else config["model"]
88
- ckpt_file = args.ckpt_file if args.ckpt_file else ""
89
- vocab_file = args.vocab_file if args.vocab_file else ""
90
- remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
91
- wave_path = Path(output_dir) / "out.wav"
92
- spectrogram_path = Path(output_dir) / "out.png"
93
- vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
94
-
95
- vocos = load_vocoder(is_local=args.load_vocoder_from_local, local_path=vocos_local_path)
96
-
97
-
98
- # load models
99
- if model == "F5-TTS":
100
- model_cls = DiT
101
- model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
102
- if ckpt_file == "":
103
- repo_name = "F5-TTS"
104
- exp_name = "F5TTS_Base"
105
- ckpt_step = 1200000
106
- ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
107
- # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
108
-
109
- elif model == "E2-TTS":
110
- model_cls = UNetT
111
- model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
112
- if ckpt_file == "":
113
- repo_name = "E2-TTS"
114
- exp_name = "E2TTS_Base"
115
- ckpt_step = 1200000
116
- ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
117
- # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
118
-
119
- print(f"Using {model}...")
120
- ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file)
121
-
122
-
123
- def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence):
124
- main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
125
- if "voices" not in config:
126
- voices = {"main": main_voice}
127
- else:
128
- voices = config["voices"]
129
- voices["main"] = main_voice
130
- for voice in voices:
131
- voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text(
132
- voices[voice]["ref_audio"], voices[voice]["ref_text"]
133
- )
134
- print("Voice:", voice)
135
- print("Ref_audio:", voices[voice]["ref_audio"])
136
- print("Ref_text:", voices[voice]["ref_text"])
137
-
138
- generated_audio_segments = []
139
- reg1 = r"(?=\[\w+\])"
140
- chunks = re.split(reg1, text_gen)
141
- reg2 = r"\[(\w+)\]"
142
- for text in chunks:
143
- match = re.match(reg2, text)
144
- if match:
145
- voice = match[1]
146
- else:
147
- print("No voice tag found, using main.")
148
- voice = "main"
149
- if voice not in voices:
150
- print(f"Voice {voice} not found, using main.")
151
- voice = "main"
152
- text = re.sub(reg2, "", text)
153
- gen_text = text.strip()
154
- ref_audio = voices[voice]["ref_audio"]
155
- ref_text = voices[voice]["ref_text"]
156
- print(f"Voice: {voice}")
157
- audio, final_sample_rate, spectrogram = infer_process(ref_audio, ref_text, gen_text, model_obj)
158
- generated_audio_segments.append(audio)
159
-
160
- if generated_audio_segments:
161
- final_wave = np.concatenate(generated_audio_segments)
162
- with open(wave_path, "wb") as f:
163
- sf.write(f.name, final_wave, final_sample_rate)
164
- # Remove silence
165
- if remove_silence:
166
- remove_silence_for_generated_wav(f.name)
167
- print(f.name)
168
-
169
-
170
- main_process(ref_audio, ref_text, gen_text, ema_model, remove_silence)
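
For reference, a minimal standalone sketch of the [voice] tag handling in main_process above; the voice names and script are made up. The lookahead regex splits the script just before each tag, and a second regex pulls the tag off each chunk.

import re

text_gen = "[main] Hello there. [town] Good morning! [main] See you soon."

chunks = re.split(r"(?=\[\w+\])", text_gen)   # split right before every [voice] tag
for chunk in chunks:
    if not chunk.strip():
        continue
    match = re.match(r"\[(\w+)\]", chunk)
    voice = match[1] if match else "main"     # fall back to the main voice, as above
    gen_text = re.sub(r"\[(\w+)\]", "", chunk).strip()
    print(voice, "->", gen_text)
# main -> Hello there.
# town -> Good morning!
# main -> See you soon.
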
inference-cli.toml DELETED
@@ -1,10 +0,0 @@
1
- # F5-TTS | E2-TTS
2
- model = "F5-TTS"
3
- ref_audio = "tests/ref_audio/test_en_1_ref_short.wav"
4
- # If left empty (""), the reference audio is transcribed automatically.
5
- ref_text = "Some call me nature, others call me mother nature."
6
- gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences."
7
- # File with text to generate. Ignores the text above.
8
- gen_file = ""
9
- remove_silence = false
10
- output_dir = "tests"
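
A quick sketch of how this config is consumed, mirroring the override pattern at the top of inference-cli.py above; it assumes tomli is installed and that inference-cli.toml sits in the working directory.

import tomli

with open("inference-cli.toml", "rb") as f:
    config = tomli.load(f)

# Command-line values win over the config file; cli_gen_text stands in for args.gen_text here.
cli_gen_text = None
gen_text = cli_gen_text if cli_gen_text else config["gen_text"]
print(config["model"], config["ref_audio"], gen_text[:40] + "...")
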
 
model/__init__.py CHANGED
@@ -5,6 +5,3 @@ from model.backbones.dit import DiT
5
  from model.backbones.mmdit import MMDiT
6
 
7
  from model.trainer import Trainer
8
-
9
-
10
- __all__ = ["CFM", "UNetT", "DiT", "MMDiT", "Trainer"]
 
5
  from model.backbones.mmdit import MMDiT
6
 
7
  from model.trainer import Trainer
 
 
 
model/backbones/dit.py CHANGED
@@ -13,6 +13,8 @@ import torch
13
  from torch import nn
14
  import torch.nn.functional as F
15
 
 
 
16
  from x_transformers.x_transformers import RotaryEmbedding
17
 
18
  from model.modules import (
@@ -21,16 +23,14 @@ from model.modules import (
21
  ConvPositionEmbedding,
22
  DiTBlock,
23
  AdaLayerNormZero_Final,
24
- precompute_freqs_cis,
25
- get_pos_embed_indices,
26
  )
27
 
28
 
29
  # Text embedding
30
 
31
-
32
  class TextEmbedding(nn.Module):
33
- def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
34
  super().__init__()
35
  self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
36
 
@@ -38,22 +38,20 @@ class TextEmbedding(nn.Module):
38
  self.extra_modeling = True
39
  self.precompute_max_pos = 4096 # ~44s of 24khz audio
40
  self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
41
- self.text_blocks = nn.Sequential(
42
- *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
43
- )
44
  else:
45
  self.extra_modeling = False
46
 
47
- def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
 
48
  text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
49
  text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
50
- batch, text_len = text.shape[0], text.shape[1]
51
- text = F.pad(text, (0, seq_len - text_len), value=0)
52
 
53
  if drop_text: # cfg for text
54
  text = torch.zeros_like(text)
55
 
56
- text = self.text_embed(text) # b n -> b n d
57
 
58
  # possible extra modeling
59
  if self.extra_modeling:
@@ -71,91 +69,88 @@ class TextEmbedding(nn.Module):
71
 
72
  # noised input audio and context mixing embedding
73
 
74
-
75
  class InputEmbedding(nn.Module):
76
  def __init__(self, mel_dim, text_dim, out_dim):
77
  super().__init__()
78
  self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
79
- self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
80
 
81
- def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722
82
  if drop_audio_cond: # cfg for cond audio
83
  cond = torch.zeros_like(cond)
84
 
85
- x = self.proj(torch.cat((x, cond, text_embed), dim=-1))
86
  x = self.conv_pos_embed(x) + x
87
  return x
88
-
89
 
90
  # Transformer backbone using DiT blocks
91
 
92
-
93
  class DiT(nn.Module):
94
- def __init__(
95
- self,
96
- *,
97
- dim,
98
- depth=8,
99
- heads=8,
100
- dim_head=64,
101
- dropout=0.1,
102
- ff_mult=4,
103
- mel_dim=100,
104
- text_num_embeds=256,
105
- text_dim=None,
106
- conv_layers=0,
107
- long_skip_connection=False,
108
  ):
109
  super().__init__()
110
 
111
  self.time_embed = TimestepEmbedding(dim)
112
  if text_dim is None:
113
  text_dim = mel_dim
114
- self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
115
  self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
116
 
117
  self.rotary_embed = RotaryEmbedding(dim_head)
118
 
119
  self.dim = dim
120
  self.depth = depth
121
-
122
  self.transformer_blocks = nn.ModuleList(
123
- [DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)]
 
 
 
 
 
 
 
 
 
124
  )
125
- self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
126
-
127
  self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
128
  self.proj_out = nn.Linear(dim, mel_dim)
129
 
130
  def forward(
131
  self,
132
- x: float["b n d"], # nosied input audio # noqa: F722
133
- cond: float["b n d"], # masked cond audio # noqa: F722
134
- text: int["b nt"], # text # noqa: F722
135
- time: float["b"] | float[""], # time step # noqa: F821 F722
136
  drop_audio_cond, # cfg for cond audio
137
- drop_text, # cfg for text
138
- mask: bool["b n"] | None = None, # noqa: F722
139
  ):
140
  batch, seq_len = x.shape[0], x.shape[1]
141
  if time.ndim == 0:
142
- time = time.repeat(batch)
143
-
144
  # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
145
  t = self.time_embed(time)
146
- text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
147
- x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
148
-
149
  rope = self.rotary_embed.forward_from_seq_len(seq_len)
150
 
151
  if self.long_skip_connection is not None:
152
  residual = x
153
 
154
  for block in self.transformer_blocks:
155
- x = block(x, t, mask=mask, rope=rope)
156
 
157
  if self.long_skip_connection is not None:
158
- x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
159
 
160
  x = self.norm_out(x, t)
161
  output = self.proj_out(x)
 
13
  from torch import nn
14
  import torch.nn.functional as F
15
 
16
+ from einops import repeat
17
+
18
  from x_transformers.x_transformers import RotaryEmbedding
19
 
20
  from model.modules import (
 
23
  ConvPositionEmbedding,
24
  DiTBlock,
25
  AdaLayerNormZero_Final,
26
+ precompute_freqs_cis, get_pos_embed_indices,
 
27
  )
28
 
29
 
30
  # Text embedding
31
 
 
32
  class TextEmbedding(nn.Module):
33
+ def __init__(self, text_num_embeds, text_dim, conv_layers = 0, conv_mult = 2):
34
  super().__init__()
35
  self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
36
 
 
38
  self.extra_modeling = True
39
  self.precompute_max_pos = 4096 # ~44s of 24khz audio
40
  self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
41
+ self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
 
 
42
  else:
43
  self.extra_modeling = False
44
 
45
+ def forward(self, text: int['b nt'], seq_len, drop_text = False):
46
+ batch, text_len = text.shape[0], text.shape[1]
47
  text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
48
  text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
49
+ text = F.pad(text, (0, seq_len - text_len), value = 0)
 
50
 
51
  if drop_text: # cfg for text
52
  text = torch.zeros_like(text)
53
 
54
+ text = self.text_embed(text) # b n -> b n d
55
 
56
  # possible extra modeling
57
  if self.extra_modeling:
 
69
 
70
  # noised input audio and context mixing embedding
71
 
 
72
  class InputEmbedding(nn.Module):
73
  def __init__(self, mel_dim, text_dim, out_dim):
74
  super().__init__()
75
  self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
76
+ self.conv_pos_embed = ConvPositionEmbedding(dim = out_dim)
77
 
78
+ def forward(self, x: float['b n d'], cond: float['b n d'], text_embed: float['b n d'], drop_audio_cond = False):
79
  if drop_audio_cond: # cfg for cond audio
80
  cond = torch.zeros_like(cond)
81
 
82
+ x = self.proj(torch.cat((x, cond, text_embed), dim = -1))
83
  x = self.conv_pos_embed(x) + x
84
  return x
85
+
86
 
87
  # Transformer backbone using DiT blocks
88
 
 
89
  class DiT(nn.Module):
90
+ def __init__(self, *,
91
+ dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
92
+ mel_dim = 100, text_num_embeds = 256, text_dim = None, conv_layers = 0,
93
+ long_skip_connection = False,
 
 
 
 
 
 
 
 
 
 
94
  ):
95
  super().__init__()
96
 
97
  self.time_embed = TimestepEmbedding(dim)
98
  if text_dim is None:
99
  text_dim = mel_dim
100
+ self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers = conv_layers)
101
  self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
102
 
103
  self.rotary_embed = RotaryEmbedding(dim_head)
104
 
105
  self.dim = dim
106
  self.depth = depth
107
+
108
  self.transformer_blocks = nn.ModuleList(
109
+ [
110
+ DiTBlock(
111
+ dim = dim,
112
+ heads = heads,
113
+ dim_head = dim_head,
114
+ ff_mult = ff_mult,
115
+ dropout = dropout
116
+ )
117
+ for _ in range(depth)
118
+ ]
119
  )
120
+ self.long_skip_connection = nn.Linear(dim * 2, dim, bias = False) if long_skip_connection else None
121
+
122
  self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
123
  self.proj_out = nn.Linear(dim, mel_dim)
124
 
125
  def forward(
126
  self,
127
+ x: float['b n d'], # nosied input audio
128
+ cond: float['b n d'], # masked cond audio
129
+ text: int['b nt'], # text
130
+ time: float['b'] | float[''], # time step
131
  drop_audio_cond, # cfg for cond audio
132
+ drop_text, # cfg for text
133
+ mask: bool['b n'] | None = None,
134
  ):
135
  batch, seq_len = x.shape[0], x.shape[1]
136
  if time.ndim == 0:
137
+ time = repeat(time, ' -> b', b = batch)
138
+
139
  # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
140
  t = self.time_embed(time)
141
+ text_embed = self.text_embed(text, seq_len, drop_text = drop_text)
142
+ x = self.input_embed(x, cond, text_embed, drop_audio_cond = drop_audio_cond)
143
+
144
  rope = self.rotary_embed.forward_from_seq_len(seq_len)
145
 
146
  if self.long_skip_connection is not None:
147
  residual = x
148
 
149
  for block in self.transformer_blocks:
150
+ x = block(x, t, mask = mask, rope = rope)
151
 
152
  if self.long_skip_connection is not None:
153
+ x = self.long_skip_connection(torch.cat((x, residual), dim = -1))
154
 
155
  x = self.norm_out(x, t)
156
  output = self.proj_out(x)
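
For orientation, a minimal sketch of constructing this backbone with the F5TTS_Base hyperparameters used elsewhere in this diff; it assumes the repo and its dependencies are installed, and the vocab size below is only a placeholder for the real tokenizer vocab.

import torch
from model import DiT                             # exported by model/__init__.py in this repo

vocab_size = 2546                                 # placeholder; the real value comes from the pinyin tokenizer
backbone = DiT(
    dim=1024, depth=22, heads=16, ff_mult=2,      # F5TTS_Base config used in the app above
    text_dim=512, conv_layers=4,
    text_num_embeds=vocab_size, mel_dim=100,
)

x = torch.randn(1, 128, 100)                      # noised mel frames: (batch, frames, mel bins)
cond = torch.randn(1, 128, 100)                   # masked conditioning audio, same shape
text = torch.randint(0, vocab_size, (1, 32))      # character token ids
t = torch.rand(1)                                 # flow-matching time step in [0, 1]
out = backbone(x, cond, text, t, drop_audio_cond=False, drop_text=False)
print(out.shape)                                  # torch.Size([1, 128, 100])
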
model/backbones/mmdit.py CHANGED
@@ -12,6 +12,8 @@ from __future__ import annotations
12
  import torch
13
  from torch import nn
14
 
 
 
15
  from x_transformers.x_transformers import RotaryEmbedding
16
 
17
  from model.modules import (
@@ -19,14 +21,12 @@ from model.modules import (
19
  ConvPositionEmbedding,
20
  MMDiTBlock,
21
  AdaLayerNormZero_Final,
22
- precompute_freqs_cis,
23
- get_pos_embed_indices,
24
  )
25
 
26
 
27
  # text embedding
28
 
29
-
30
  class TextEmbedding(nn.Module):
31
  def __init__(self, out_dim, text_num_embeds):
32
  super().__init__()
@@ -35,7 +35,7 @@ class TextEmbedding(nn.Module):
35
  self.precompute_max_pos = 1024
36
  self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
37
 
38
- def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]: # noqa: F722
39
  text = text + 1
40
  if drop_text:
41
  text = torch.zeros_like(text)
@@ -54,37 +54,27 @@ class TextEmbedding(nn.Module):
54
 
55
  # noised input & masked cond audio embedding
56
 
57
-
58
  class AudioEmbedding(nn.Module):
59
  def __init__(self, in_dim, out_dim):
60
  super().__init__()
61
  self.linear = nn.Linear(2 * in_dim, out_dim)
62
  self.conv_pos_embed = ConvPositionEmbedding(out_dim)
63
 
64
- def forward(self, x: float["b n d"], cond: float["b n d"], drop_audio_cond=False): # noqa: F722
65
  if drop_audio_cond:
66
  cond = torch.zeros_like(cond)
67
- x = torch.cat((x, cond), dim=-1)
68
  x = self.linear(x)
69
  x = self.conv_pos_embed(x) + x
70
  return x
71
-
72
 
73
  # Transformer backbone using MM-DiT blocks
74
 
75
-
76
  class MMDiT(nn.Module):
77
- def __init__(
78
- self,
79
- *,
80
- dim,
81
- depth=8,
82
- heads=8,
83
- dim_head=64,
84
- dropout=0.1,
85
- ff_mult=4,
86
- text_num_embeds=256,
87
- mel_dim=100,
88
  ):
89
  super().__init__()
90
 
@@ -96,16 +86,16 @@ class MMDiT(nn.Module):
96
 
97
  self.dim = dim
98
  self.depth = depth
99
-
100
  self.transformer_blocks = nn.ModuleList(
101
  [
102
  MMDiTBlock(
103
- dim=dim,
104
- heads=heads,
105
- dim_head=dim_head,
106
- dropout=dropout,
107
- ff_mult=ff_mult,
108
- context_pre_only=i == depth - 1,
109
  )
110
  for i in range(depth)
111
  ]
@@ -115,30 +105,30 @@ class MMDiT(nn.Module):
115
 
116
  def forward(
117
  self,
118
- x: float["b n d"], # nosied input audio # noqa: F722
119
- cond: float["b n d"], # masked cond audio # noqa: F722
120
- text: int["b nt"], # text # noqa: F722
121
- time: float["b"] | float[""], # time step # noqa: F821 F722
122
  drop_audio_cond, # cfg for cond audio
123
- drop_text, # cfg for text
124
- mask: bool["b n"] | None = None, # noqa: F722
125
  ):
126
  batch = x.shape[0]
127
  if time.ndim == 0:
128
- time = time.repeat(batch)
129
 
130
  # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
131
  t = self.time_embed(time)
132
- c = self.text_embed(text, drop_text=drop_text)
133
- x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)
134
 
135
  seq_len = x.shape[1]
136
  text_len = text.shape[1]
137
  rope_audio = self.rotary_embed.forward_from_seq_len(seq_len)
138
  rope_text = self.rotary_embed.forward_from_seq_len(text_len)
139
-
140
  for block in self.transformer_blocks:
141
- c, x = block(x, c, t, mask=mask, rope=rope_audio, c_rope=rope_text)
142
 
143
  x = self.norm_out(x, t)
144
  output = self.proj_out(x)
 
12
  import torch
13
  from torch import nn
14
 
15
+ from einops import repeat
16
+
17
  from x_transformers.x_transformers import RotaryEmbedding
18
 
19
  from model.modules import (
 
21
  ConvPositionEmbedding,
22
  MMDiTBlock,
23
  AdaLayerNormZero_Final,
24
+ precompute_freqs_cis, get_pos_embed_indices,
 
25
  )
26
 
27
 
28
  # text embedding
29
 
 
30
  class TextEmbedding(nn.Module):
31
  def __init__(self, out_dim, text_num_embeds):
32
  super().__init__()
 
35
  self.precompute_max_pos = 1024
36
  self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
37
 
38
+ def forward(self, text: int['b nt'], drop_text = False) -> int['b nt d']:
39
  text = text + 1
40
  if drop_text:
41
  text = torch.zeros_like(text)
 
54
 
55
  # noised input & masked cond audio embedding
56
 
 
57
  class AudioEmbedding(nn.Module):
58
  def __init__(self, in_dim, out_dim):
59
  super().__init__()
60
  self.linear = nn.Linear(2 * in_dim, out_dim)
61
  self.conv_pos_embed = ConvPositionEmbedding(out_dim)
62
 
63
+ def forward(self, x: float['b n d'], cond: float['b n d'], drop_audio_cond = False):
64
  if drop_audio_cond:
65
  cond = torch.zeros_like(cond)
66
+ x = torch.cat((x, cond), dim = -1)
67
  x = self.linear(x)
68
  x = self.conv_pos_embed(x) + x
69
  return x
70
+
71
 
72
  # Transformer backbone using MM-DiT blocks
73
 
 
74
  class MMDiT(nn.Module):
75
+ def __init__(self, *,
76
+ dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
77
+ text_num_embeds = 256, mel_dim = 100,
 
 
 
 
 
 
 
 
78
  ):
79
  super().__init__()
80
 
 
86
 
87
  self.dim = dim
88
  self.depth = depth
89
+
90
  self.transformer_blocks = nn.ModuleList(
91
  [
92
  MMDiTBlock(
93
+ dim = dim,
94
+ heads = heads,
95
+ dim_head = dim_head,
96
+ dropout = dropout,
97
+ ff_mult = ff_mult,
98
+ context_pre_only = i == depth - 1,
99
  )
100
  for i in range(depth)
101
  ]
 
105
 
106
  def forward(
107
  self,
108
+ x: float['b n d'], # noised input audio
109
+ cond: float['b n d'], # masked cond audio
110
+ text: int['b nt'], # text
111
+ time: float['b'] | float[''], # time step
112
  drop_audio_cond, # cfg for cond audio
113
+ drop_text, # cfg for text
114
+ mask: bool['b n'] | None = None,
115
  ):
116
  batch = x.shape[0]
117
  if time.ndim == 0:
118
+ time = repeat(time, ' -> b', b = batch)
119
 
120
  # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
121
  t = self.time_embed(time)
122
+ c = self.text_embed(text, drop_text = drop_text)
123
+ x = self.audio_embed(x, cond, drop_audio_cond = drop_audio_cond)
124
 
125
  seq_len = x.shape[1]
126
  text_len = text.shape[1]
127
  rope_audio = self.rotary_embed.forward_from_seq_len(seq_len)
128
  rope_text = self.rotary_embed.forward_from_seq_len(text_len)
129
+
130
  for block in self.transformer_blocks:
131
+ c, x = block(x, c, t, mask = mask, rope = rope_audio, c_rope = rope_text)
132
 
133
  x = self.norm_out(x, t)
134
  output = self.proj_out(x)
model/backbones/unett.py CHANGED
@@ -14,6 +14,8 @@ import torch
14
  from torch import nn
15
  import torch.nn.functional as F
16
 
 
 
17
  from x_transformers import RMSNorm
18
  from x_transformers.x_transformers import RotaryEmbedding
19
 
@@ -24,16 +26,14 @@ from model.modules import (
24
  Attention,
25
  AttnProcessor,
26
  FeedForward,
27
- precompute_freqs_cis,
28
- get_pos_embed_indices,
29
  )
30
 
31
 
32
  # Text embedding
33
 
34
-
35
  class TextEmbedding(nn.Module):
36
- def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
37
  super().__init__()
38
  self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
39
 
@@ -41,22 +41,20 @@ class TextEmbedding(nn.Module):
41
  self.extra_modeling = True
42
  self.precompute_max_pos = 4096 # ~44s of 24khz audio
43
  self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
44
- self.text_blocks = nn.Sequential(
45
- *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
46
- )
47
  else:
48
  self.extra_modeling = False
49
 
50
- def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
 
51
  text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
52
  text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
53
- batch, text_len = text.shape[0], text.shape[1]
54
- text = F.pad(text, (0, seq_len - text_len), value=0)
55
 
56
  if drop_text: # cfg for text
57
  text = torch.zeros_like(text)
58
 
59
- text = self.text_embed(text) # b n -> b n d
60
 
61
  # possible extra modeling
62
  if self.extra_modeling:
@@ -74,40 +72,28 @@ class TextEmbedding(nn.Module):
74
 
75
  # noised input audio and context mixing embedding
76
 
77
-
78
  class InputEmbedding(nn.Module):
79
  def __init__(self, mel_dim, text_dim, out_dim):
80
  super().__init__()
81
  self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
82
- self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
83
 
84
- def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722
85
  if drop_audio_cond: # cfg for cond audio
86
  cond = torch.zeros_like(cond)
87
 
88
- x = self.proj(torch.cat((x, cond, text_embed), dim=-1))
89
  x = self.conv_pos_embed(x) + x
90
  return x
91
 
92
 
93
  # Flat UNet Transformer backbone
94
 
95
-
96
  class UNetT(nn.Module):
97
- def __init__(
98
- self,
99
- *,
100
- dim,
101
- depth=8,
102
- heads=8,
103
- dim_head=64,
104
- dropout=0.1,
105
- ff_mult=4,
106
- mel_dim=100,
107
- text_num_embeds=256,
108
- text_dim=None,
109
- conv_layers=0,
110
- skip_connect_type: Literal["add", "concat", "none"] = "concat",
111
  ):
112
  super().__init__()
113
  assert depth % 2 == 0, "UNet-Transformer's depth should be even."
@@ -115,7 +101,7 @@ class UNetT(nn.Module):
115
  self.time_embed = TimestepEmbedding(dim)
116
  if text_dim is None:
117
  text_dim = mel_dim
118
- self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
119
  self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
120
 
121
  self.rotary_embed = RotaryEmbedding(dim_head)
@@ -124,7 +110,7 @@ class UNetT(nn.Module):
124
 
125
  self.dim = dim
126
  self.skip_connect_type = skip_connect_type
127
- needs_skip_proj = skip_connect_type == "concat"
128
 
129
  self.depth = depth
130
  self.layers = nn.ModuleList([])
@@ -134,57 +120,53 @@ class UNetT(nn.Module):
134
 
135
  attn_norm = RMSNorm(dim)
136
  attn = Attention(
137
- processor=AttnProcessor(),
138
- dim=dim,
139
- heads=heads,
140
- dim_head=dim_head,
141
- dropout=dropout,
142
- )
143
 
144
  ff_norm = RMSNorm(dim)
145
- ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
146
-
147
- skip_proj = nn.Linear(dim * 2, dim, bias=False) if needs_skip_proj and is_later_half else None
148
-
149
- self.layers.append(
150
- nn.ModuleList(
151
- [
152
- skip_proj,
153
- attn_norm,
154
- attn,
155
- ff_norm,
156
- ff,
157
- ]
158
- )
159
- )
160
 
161
  self.norm_out = RMSNorm(dim)
162
  self.proj_out = nn.Linear(dim, mel_dim)
163
 
164
  def forward(
165
  self,
166
- x: float["b n d"], # nosied input audio # noqa: F722
167
- cond: float["b n d"], # masked cond audio # noqa: F722
168
- text: int["b nt"], # text # noqa: F722
169
- time: float["b"] | float[""], # time step # noqa: F821 F722
170
  drop_audio_cond, # cfg for cond audio
171
- drop_text, # cfg for text
172
- mask: bool["b n"] | None = None, # noqa: F722
173
  ):
174
  batch, seq_len = x.shape[0], x.shape[1]
175
  if time.ndim == 0:
176
- time = time.repeat(batch)
177
-
178
  # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
179
  t = self.time_embed(time)
180
- text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
181
- x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
182
 
183
  # postfix time t to input x, [b n d] -> [b n+1 d]
184
- x = torch.cat([t.unsqueeze(1), x], dim=1) # pack t to x
185
  if mask is not None:
186
  mask = F.pad(mask, (1, 0), value=1)
187
-
188
  rope = self.rotary_embed.forward_from_seq_len(seq_len + 1)
189
 
190
  # flat unet transformer
@@ -202,18 +184,18 @@ class UNetT(nn.Module):
202
 
203
  if is_later_half:
204
  skip = skips.pop()
205
- if skip_connect_type == "concat":
206
- x = torch.cat((x, skip), dim=-1)
207
  x = maybe_skip_proj(x)
208
- elif skip_connect_type == "add":
209
  x = x + skip
210
 
211
  # attention and feedforward blocks
212
- x = attn(attn_norm(x), rope=rope, mask=mask) + x
213
  x = ff(ff_norm(x)) + x
214
 
215
  assert len(skips) == 0
216
 
217
- x = self.norm_out(x)[:, 1:, :] # unpack t from x
218
 
219
  return self.proj_out(x)
 
14
  from torch import nn
15
  import torch.nn.functional as F
16
 
17
+ from einops import repeat, pack, unpack
18
+
19
  from x_transformers import RMSNorm
20
  from x_transformers.x_transformers import RotaryEmbedding
21
 
 
26
  Attention,
27
  AttnProcessor,
28
  FeedForward,
29
+ precompute_freqs_cis, get_pos_embed_indices,
 
30
  )
31
 
32
 
33
  # Text embedding
34
 
 
35
  class TextEmbedding(nn.Module):
36
+ def __init__(self, text_num_embeds, text_dim, conv_layers = 0, conv_mult = 2):
37
  super().__init__()
38
  self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
39
 
 
41
  self.extra_modeling = True
42
  self.precompute_max_pos = 4096 # ~44s of 24khz audio
43
  self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
44
+ self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
 
 
45
  else:
46
  self.extra_modeling = False
47
 
48
+ def forward(self, text: int['b nt'], seq_len, drop_text = False):
49
+ batch, text_len = text.shape[0], text.shape[1]
50
  text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
51
  text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
52
+ text = F.pad(text, (0, seq_len - text_len), value = 0)
 
53
 
54
  if drop_text: # cfg for text
55
  text = torch.zeros_like(text)
56
 
57
+ text = self.text_embed(text) # b n -> b n d
58
 
59
  # possible extra modeling
60
  if self.extra_modeling:
 
72
 
73
  # noised input audio and context mixing embedding
74
 
 
75
  class InputEmbedding(nn.Module):
76
  def __init__(self, mel_dim, text_dim, out_dim):
77
  super().__init__()
78
  self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
79
+ self.conv_pos_embed = ConvPositionEmbedding(dim = out_dim)
80
 
81
+ def forward(self, x: float['b n d'], cond: float['b n d'], text_embed: float['b n d'], drop_audio_cond = False):
82
  if drop_audio_cond: # cfg for cond audio
83
  cond = torch.zeros_like(cond)
84
 
85
+ x = self.proj(torch.cat((x, cond, text_embed), dim = -1))
86
  x = self.conv_pos_embed(x) + x
87
  return x
88
 
89
 
90
  # Flat UNet Transformer backbone
91
 
 
92
  class UNetT(nn.Module):
93
+ def __init__(self, *,
94
+ dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
95
+ mel_dim = 100, text_num_embeds = 256, text_dim = None, conv_layers = 0,
96
+ skip_connect_type: Literal['add', 'concat', 'none'] = 'concat',
 
97
  ):
98
  super().__init__()
99
  assert depth % 2 == 0, "UNet-Transformer's depth should be even."
 
101
  self.time_embed = TimestepEmbedding(dim)
102
  if text_dim is None:
103
  text_dim = mel_dim
104
+ self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers = conv_layers)
105
  self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
106
 
107
  self.rotary_embed = RotaryEmbedding(dim_head)
 
110
 
111
  self.dim = dim
112
  self.skip_connect_type = skip_connect_type
113
+ needs_skip_proj = skip_connect_type == 'concat'
114
 
115
  self.depth = depth
116
  self.layers = nn.ModuleList([])
 
120
 
121
  attn_norm = RMSNorm(dim)
122
  attn = Attention(
123
+ processor = AttnProcessor(),
124
+ dim = dim,
125
+ heads = heads,
126
+ dim_head = dim_head,
127
+ dropout = dropout,
128
+ )
129
 
130
  ff_norm = RMSNorm(dim)
131
+ ff = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
132
+
133
+ skip_proj = nn.Linear(dim * 2, dim, bias = False) if needs_skip_proj and is_later_half else None
134
+
135
+ self.layers.append(nn.ModuleList([
136
+ skip_proj,
137
+ attn_norm,
138
+ attn,
139
+ ff_norm,
140
+ ff,
141
+ ]))
 
 
 
 
142
 
143
  self.norm_out = RMSNorm(dim)
144
  self.proj_out = nn.Linear(dim, mel_dim)
145
 
146
  def forward(
147
  self,
148
+ x: float['b n d'], # nosied input audio
149
+ cond: float['b n d'], # masked cond audio
150
+ text: int['b nt'], # text
151
+ time: float['b'] | float[''], # time step
152
  drop_audio_cond, # cfg for cond audio
153
+ drop_text, # cfg for text
154
+ mask: bool['b n'] | None = None,
155
  ):
156
  batch, seq_len = x.shape[0], x.shape[1]
157
  if time.ndim == 0:
158
+ time = repeat(time, ' -> b', b = batch)
159
+
160
  # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
161
  t = self.time_embed(time)
162
+ text_embed = self.text_embed(text, seq_len, drop_text = drop_text)
163
+ x = self.input_embed(x, cond, text_embed, drop_audio_cond = drop_audio_cond)
164
 
165
  # postfix time t to input x, [b n d] -> [b n+1 d]
166
+ x, ps = pack((t, x), 'b * d')
167
  if mask is not None:
168
  mask = F.pad(mask, (1, 0), value=1)
169
+
170
  rope = self.rotary_embed.forward_from_seq_len(seq_len + 1)
171
 
172
  # flat unet transformer
 
184
 
185
  if is_later_half:
186
  skip = skips.pop()
187
+ if skip_connect_type == 'concat':
188
+ x = torch.cat((x, skip), dim = -1)
189
  x = maybe_skip_proj(x)
190
+ elif skip_connect_type == 'add':
191
  x = x + skip
192
 
193
  # attention and feedforward blocks
194
+ x = attn(attn_norm(x), rope = rope, mask = mask) + x
195
  x = ff(ff_norm(x)) + x
196
 
197
  assert len(skips) == 0
198
 
199
+ _, x = unpack(self.norm_out(x), ps, 'b * d')
200
 
201
  return self.proj_out(x)
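For reference, a minimal standalone sketch of the `pack`/`unpack` idiom the updated `UNetT.forward` uses to prepend the time token `t` to the sequence `x` (shapes here are illustrative, not taken from the model config):

```python
import torch
from einops import pack, unpack

b, n, d = 2, 5, 8                 # illustrative sizes, not model defaults
t = torch.randn(b, d)             # time embedding: one token per batch item
x = torch.randn(b, n, d)          # noised mel frames

# pack prepends t along the sequence axis and remembers the split in ps
x, ps = pack((t, x), 'b * d')     # x: [b, 1 + n, d]
assert x.shape == (b, n + 1, d)

# after the transformer stack, unpack drops the time token again
t_out, x_out = unpack(x, ps, 'b * d')
assert t_out.shape == (b, d) and x_out.shape == (b, n, d)
```

This is the same effect as the explicit `torch.cat([t.unsqueeze(1), x], dim=1)` and `[:, 1:, :]` slicing on the removed side of the diff.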
model/cfm.py CHANGED
@@ -18,34 +18,34 @@ from torch.nn.utils.rnn import pad_sequence
18
 
19
  from torchdiffeq import odeint
20
 
 
 
21
  from model.modules import MelSpec
 
22
  from model.utils import (
23
- default,
24
- exists,
25
- list_str_to_idx,
26
- list_str_to_tensor,
27
- lens_to_mask,
28
- mask_from_frac_lengths,
29
- )
30
 
31
 
32
  class CFM(nn.Module):
33
  def __init__(
34
  self,
35
  transformer: nn.Module,
36
- sigma=0.0,
37
  odeint_kwargs: dict = dict(
38
  # atol = 1e-5,
39
  # rtol = 1e-5,
40
- method="euler" # 'midpoint'
41
  ),
42
- audio_drop_prob=0.3,
43
- cond_drop_prob=0.2,
44
- num_channels=None,
45
  mel_spec_module: nn.Module | None = None,
46
  mel_spec_kwargs: dict = dict(),
47
- frac_lengths_mask: tuple[float, float] = (0.7, 1.0),
48
- vocab_char_map: dict[str:int] | None = None,
49
  ):
50
  super().__init__()
51
 
@@ -81,37 +81,33 @@ class CFM(nn.Module):
81
  @torch.no_grad()
82
  def sample(
83
  self,
84
- cond: float["b n d"] | float["b nw"], # noqa: F722
85
- text: int["b nt"] | list[str], # noqa: F722
86
- duration: int | int["b"], # noqa: F821
87
  *,
88
- lens: int["b"] | None = None, # noqa: F821
89
- steps=32,
90
- cfg_strength=1.0,
91
- sway_sampling_coef=None,
92
  seed: int | None = None,
93
- max_duration=4096,
94
- vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722
95
- no_ref_audio=False,
96
- duplicate_test=False,
97
- t_inter=0.1,
98
- edit_mask=None,
99
  ):
100
  self.eval()
101
 
102
- if next(self.parameters()).dtype == torch.float16:
103
- cond = cond.half()
104
-
105
  # raw wave
106
 
107
  if cond.ndim == 2:
108
  cond = self.mel_spec(cond)
109
- cond = cond.permute(0, 2, 1)
110
  assert cond.shape[-1] == self.num_channels
111
 
112
  batch, cond_seq_len, device = *cond.shape[:2], cond.device
113
  if not exists(lens):
114
- lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long)
115
 
116
  # text
117
 
@@ -123,37 +119,30 @@ class CFM(nn.Module):
123
  assert text.shape[0] == batch
124
 
125
  if exists(text):
126
- text_lens = (text != -1).sum(dim=-1)
127
- lens = torch.maximum(text_lens, lens) # make sure lengths are at least those of the text characters
128
 
129
  # duration
130
 
131
  cond_mask = lens_to_mask(lens)
132
- if edit_mask is not None:
133
- cond_mask = cond_mask & edit_mask
134
 
135
  if isinstance(duration, int):
136
- duration = torch.full((batch,), duration, device=device, dtype=torch.long)
137
 
138
- duration = torch.maximum(lens + 1, duration) # just add one token so something is generated
139
- duration = duration.clamp(max=max_duration)
140
  max_duration = duration.amax()
141
-
142
  # duplicate test corner for inner time step oberservation
143
  if duplicate_test:
144
- test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0)
145
-
146
- cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0)
147
- cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False)
148
- cond_mask = cond_mask.unsqueeze(-1)
149
- step_cond = torch.where(
150
- cond_mask, cond, torch.zeros_like(cond)
151
- ) # allow direct control (cut cond audio) with lens passed in
152
 
153
- if batch > 1:
154
- mask = lens_to_mask(duration)
155
- else: # save memory and speed up, as single inference need no mask currently
156
- mask = None
157
 
158
  # test for no ref audio
159
  if no_ref_audio:
@@ -166,15 +155,11 @@ class CFM(nn.Module):
166
  # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
167
 
168
  # predict flow
169
- pred = self.transformer(
170
- x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False
171
- )
172
  if cfg_strength < 1e-5:
173
  return pred
174
-
175
- null_pred = self.transformer(
176
- x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True
177
- )
178
  return pred + (pred - null_pred) * cfg_strength
179
 
180
  # noise input
@@ -184,8 +169,8 @@ class CFM(nn.Module):
184
  for dur in duration:
185
  if exists(seed):
186
  torch.manual_seed(seed)
187
- y0.append(torch.randn(dur, self.num_channels, device=self.device, dtype=step_cond.dtype))
188
- y0 = pad_sequence(y0, padding_value=0, batch_first=True)
189
 
190
  t_start = 0
191
 
@@ -195,37 +180,37 @@ class CFM(nn.Module):
195
  y0 = (1 - t_start) * y0 + t_start * test_cond
196
  steps = int(steps * (1 - t_start))
197
 
198
- t = torch.linspace(t_start, 1, steps, device=self.device, dtype=step_cond.dtype)
199
  if sway_sampling_coef is not None:
200
  t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
201
 
202
  trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
203
-
204
  sampled = trajectory[-1]
205
  out = sampled
206
  out = torch.where(cond_mask, cond, out)
207
 
208
  if exists(vocoder):
209
- out = out.permute(0, 2, 1)
210
  out = vocoder(out)
211
 
212
  return out, trajectory
213
 
214
  def forward(
215
  self,
216
- inp: float["b n d"] | float["b nw"], # mel or raw wave # noqa: F722
217
- text: int["b nt"] | list[str], # noqa: F722
218
  *,
219
- lens: int["b"] | None = None, # noqa: F821
220
  noise_scheduler: str | None = None,
221
  ):
222
  # handle raw wave
223
  if inp.ndim == 2:
224
  inp = self.mel_spec(inp)
225
- inp = inp.permute(0, 2, 1)
226
  assert inp.shape[-1] == self.num_channels
227
 
228
- batch, seq_len, dtype, device, _σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma
229
 
230
  # handle text as string
231
  if isinstance(text, list):
@@ -237,12 +222,12 @@ class CFM(nn.Module):
237
 
238
  # lens and mask
239
  if not exists(lens):
240
- lens = torch.full((batch,), seq_len, device=device)
241
-
242
- mask = lens_to_mask(lens, length=seq_len) # useless here, as collate_fn will pad to max length in batch
243
 
244
  # get a random span to mask out for training conditionally
245
- frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask)
246
  rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)
247
 
248
  if exists(mask):
@@ -255,16 +240,19 @@ class CFM(nn.Module):
255
  x0 = torch.randn_like(x1)
256
 
257
  # time step
258
- time = torch.rand((batch,), dtype=dtype, device=self.device)
259
  # TODO. noise_scheduler
260
 
261
  # sample xt (φ_t(x) in the paper)
262
- t = time.unsqueeze(-1).unsqueeze(-1)
263
  φ = (1 - t) * x0 + t * x1
264
  flow = x1 - x0
265
 
266
  # only predict what is within the random mask span for infilling
267
- cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1)
 
 
 
268
 
269
  # transformer and cfg training with a drop rate
270
  drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper
@@ -273,15 +261,13 @@ class CFM(nn.Module):
273
  drop_text = True
274
  else:
275
  drop_text = False
276
-
277
  # if want rigourously mask out padding, record in collate_fn in dataset.py, and pass in here
278
  # adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
279
- pred = self.transformer(
280
- x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text
281
- )
282
 
283
  # flow matching loss
284
- loss = F.mse_loss(pred, flow, reduction="none")
285
  loss = loss[rand_span_mask]
286
 
287
  return loss.mean(), cond, pred
 
18
 
19
  from torchdiffeq import odeint
20
 
21
+ from einops import rearrange
22
+
23
  from model.modules import MelSpec
24
+
25
  from model.utils import (
26
+ default, exists,
27
+ list_str_to_idx, list_str_to_tensor,
28
+ lens_to_mask, mask_from_frac_lengths,
29
+ )
 
 
 
30
 
31
 
32
  class CFM(nn.Module):
33
  def __init__(
34
  self,
35
  transformer: nn.Module,
36
+ sigma = 0.,
37
  odeint_kwargs: dict = dict(
38
  # atol = 1e-5,
39
  # rtol = 1e-5,
40
+ method = 'euler' # 'midpoint'
41
  ),
42
+ audio_drop_prob = 0.3,
43
+ cond_drop_prob = 0.2,
44
+ num_channels = None,
45
  mel_spec_module: nn.Module | None = None,
46
  mel_spec_kwargs: dict = dict(),
47
+ frac_lengths_mask: tuple[float, float] = (0.7, 1.),
48
+ vocab_char_map: dict[str: int] | None = None
49
  ):
50
  super().__init__()
51
 
 
81
  @torch.no_grad()
82
  def sample(
83
  self,
84
+ cond: float['b n d'] | float['b nw'],
85
+ text: int['b nt'] | list[str],
86
+ duration: int | int['b'],
87
  *,
88
+ lens: int['b'] | None = None,
89
+ steps = 32,
90
+ cfg_strength = 1.,
91
+ sway_sampling_coef = None,
92
  seed: int | None = None,
93
+ max_duration = 4096,
94
+ vocoder: Callable[[float['b d n']], float['b nw']] | None = None,
95
+ no_ref_audio = False,
96
+ duplicate_test = False,
97
+ t_inter = 0.1,
 
98
  ):
99
  self.eval()
100
 
 
 
 
101
  # raw wave
102
 
103
  if cond.ndim == 2:
104
  cond = self.mel_spec(cond)
105
+ cond = rearrange(cond, 'b d n -> b n d')
106
  assert cond.shape[-1] == self.num_channels
107
 
108
  batch, cond_seq_len, device = *cond.shape[:2], cond.device
109
  if not exists(lens):
110
+ lens = torch.full((batch,), cond_seq_len, device = device, dtype = torch.long)
111
 
112
  # text
113
 
 
119
  assert text.shape[0] == batch
120
 
121
  if exists(text):
122
+ text_lens = (text != -1).sum(dim = -1)
123
+ lens = torch.maximum(text_lens, lens) # make sure lengths are at least those of the text characters
124
 
125
  # duration
126
 
127
  cond_mask = lens_to_mask(lens)
 
 
128
 
129
  if isinstance(duration, int):
130
+ duration = torch.full((batch,), duration, device = device, dtype = torch.long)
131
 
132
+ duration = torch.maximum(lens + 1, duration) # just add one token so something is generated
133
+ duration = duration.clamp(max = max_duration)
134
  max_duration = duration.amax()
135
+
136
  # duplicate test corner for inner time step oberservation
137
  if duplicate_test:
138
+ test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2*cond_seq_len), value = 0.)
139
+
140
+ cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value = 0.)
141
+ cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value = False)
142
+ cond_mask = rearrange(cond_mask, '... -> ... 1')
143
+ step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) # allow direct control (cut cond audio) with lens passed in
 
 
144
 
145
+ mask = lens_to_mask(duration)
 
 
 
146
 
147
  # test for no ref audio
148
  if no_ref_audio:
 
155
  # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
156
 
157
  # predict flow
158
+ pred = self.transformer(x = x, cond = step_cond, text = text, time = t, mask = mask, drop_audio_cond = False, drop_text = False)
 
 
159
  if cfg_strength < 1e-5:
160
  return pred
161
+
162
+ null_pred = self.transformer(x = x, cond = step_cond, text = text, time = t, mask = mask, drop_audio_cond = True, drop_text = True)
 
 
163
  return pred + (pred - null_pred) * cfg_strength
164
 
165
  # noise input
 
169
  for dur in duration:
170
  if exists(seed):
171
  torch.manual_seed(seed)
172
+ y0.append(torch.randn(dur, self.num_channels, device = self.device))
173
+ y0 = pad_sequence(y0, padding_value = 0, batch_first = True)
174
 
175
  t_start = 0
176
 
 
180
  y0 = (1 - t_start) * y0 + t_start * test_cond
181
  steps = int(steps * (1 - t_start))
182
 
183
+ t = torch.linspace(t_start, 1, steps, device = self.device)
184
  if sway_sampling_coef is not None:
185
  t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
186
 
187
  trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
188
+
189
  sampled = trajectory[-1]
190
  out = sampled
191
  out = torch.where(cond_mask, cond, out)
192
 
193
  if exists(vocoder):
194
+ out = rearrange(out, 'b n d -> b d n')
195
  out = vocoder(out)
196
 
197
  return out, trajectory
198
 
199
  def forward(
200
  self,
201
+ inp: float['b n d'] | float['b nw'], # mel or raw wave
202
+ text: int['b nt'] | list[str],
203
  *,
204
+ lens: int['b'] | None = None,
205
  noise_scheduler: str | None = None,
206
  ):
207
  # handle raw wave
208
  if inp.ndim == 2:
209
  inp = self.mel_spec(inp)
210
+ inp = rearrange(inp, 'b d n -> b n d')
211
  assert inp.shape[-1] == self.num_channels
212
 
213
+ batch, seq_len, dtype, device, σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma
214
 
215
  # handle text as string
216
  if isinstance(text, list):
 
222
 
223
  # lens and mask
224
  if not exists(lens):
225
+ lens = torch.full((batch,), seq_len, device = device)
226
+
227
+ mask = lens_to_mask(lens, length = seq_len) # useless here, as collate_fn will pad to max length in batch
228
 
229
  # get a random span to mask out for training conditionally
230
+ frac_lengths = torch.zeros((batch,), device = self.device).float().uniform_(*self.frac_lengths_mask)
231
  rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)
232
 
233
  if exists(mask):
 
240
  x0 = torch.randn_like(x1)
241
 
242
  # time step
243
+ time = torch.rand((batch,), dtype = dtype, device = self.device)
244
  # TODO. noise_scheduler
245
 
246
  # sample xt (φ_t(x) in the paper)
247
+ t = rearrange(time, 'b -> b 1 1')
248
  φ = (1 - t) * x0 + t * x1
249
  flow = x1 - x0
250
 
251
  # only predict what is within the random mask span for infilling
252
+ cond = torch.where(
253
+ rand_span_mask[..., None],
254
+ torch.zeros_like(x1), x1
255
+ )
256
 
257
  # transformer and cfg training with a drop rate
258
  drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper
 
261
  drop_text = True
262
  else:
263
  drop_text = False
264
+
265
  # if want rigourously mask out padding, record in collate_fn in dataset.py, and pass in here
266
  # adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
267
+ pred = self.transformer(x = φ, cond = cond, text = text, time = time, drop_audio_cond = drop_audio_cond, drop_text = drop_text)
 
 
268
 
269
  # flow matching loss
270
+ loss = F.mse_loss(pred, flow, reduction = 'none')
271
  loss = loss[rand_span_mask]
272
 
273
  return loss.mean(), cond, pred
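For reference, a minimal standalone sketch of the formulas touched in this `cfm.py` diff: the flow-matching interpolation built in `CFM.forward`, the sway-sampling warp applied to the time grid in `CFM.sample`, and the classifier-free guidance combination. Shapes and the coefficient values below are illustrative assumptions, not repo defaults:

```python
import torch

b, n, d = 2, 100, 100                     # illustrative sizes
x1 = torch.randn(b, n, d)                 # target mel
x0 = torch.randn_like(x1)                 # sampled noise
time = torch.rand(b)                      # per-sample t in [0, 1]

t = time.view(b, 1, 1)                    # same effect as rearrange(time, 'b -> b 1 1')
phi = (1 - t) * x0 + t * x1               # φ_t(x): the transformer input during training
flow = x1 - x0                            # regression target for the MSE loss

# inference-time step schedule with sway sampling
steps, sway_sampling_coef = 32, -1.0      # coefficient is an assumed example value
ts = torch.linspace(0, 1, steps)
ts = ts + sway_sampling_coef * (torch.cos(torch.pi / 2 * ts) - 1 + ts)

# classifier-free guidance combination used in CFM.sample (illustrative tensors)
pred, null_pred, cfg_strength = torch.randn(b, n, d), torch.randn(b, n, d), 2.0
guided = pred + (pred - null_pred) * cfg_strength
```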
model/dataset.py CHANGED
@@ -6,67 +6,65 @@ import torch
6
  import torch.nn.functional as F
7
  from torch.utils.data import Dataset, Sampler
8
  import torchaudio
9
- from datasets import load_from_disk
10
  from datasets import Dataset as Dataset_
11
- from torch import nn
 
12
 
13
  from model.modules import MelSpec
14
- from model.utils import default
15
 
16
 
17
  class HFDataset(Dataset):
18
  def __init__(
19
  self,
20
  hf_dataset: Dataset,
21
- target_sample_rate=24_000,
22
- n_mel_channels=100,
23
- hop_length=256,
24
  ):
25
  self.data = hf_dataset
26
  self.target_sample_rate = target_sample_rate
27
  self.hop_length = hop_length
28
- self.mel_spectrogram = MelSpec(
29
- target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length
30
- )
31
-
32
  def get_frame_len(self, index):
33
  row = self.data[index]
34
- audio = row["audio"]["array"]
35
- sample_rate = row["audio"]["sampling_rate"]
36
  return audio.shape[-1] / sample_rate * self.target_sample_rate / self.hop_length
37
 
38
  def __len__(self):
39
  return len(self.data)
40
-
41
  def __getitem__(self, index):
42
  row = self.data[index]
43
- audio = row["audio"]["array"]
44
 
45
  # logger.info(f"Audio shape: {audio.shape}")
46
 
47
- sample_rate = row["audio"]["sampling_rate"]
48
  duration = audio.shape[-1] / sample_rate
49
 
50
  if duration > 30 or duration < 0.3:
51
  return self.__getitem__((index + 1) % len(self.data))
52
-
53
  audio_tensor = torch.from_numpy(audio).float()
54
-
55
  if sample_rate != self.target_sample_rate:
56
  resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
57
  audio_tensor = resampler(audio_tensor)
58
-
59
- audio_tensor = audio_tensor.unsqueeze(0) # 't -> 1 t')
60
-
61
  mel_spec = self.mel_spectrogram(audio_tensor)
62
-
63
- mel_spec = mel_spec.squeeze(0) # '1 d t -> d t'
64
-
65
- text = row["text"]
66
-
67
  return dict(
68
- mel_spec=mel_spec,
69
- text=text,
70
  )
71
 
72
 
@@ -74,39 +72,28 @@ class CustomDataset(Dataset):
74
  def __init__(
75
  self,
76
  custom_dataset: Dataset,
77
- durations=None,
78
- target_sample_rate=24_000,
79
- hop_length=256,
80
- n_mel_channels=100,
81
- preprocessed_mel=False,
82
- mel_spec_module: nn.Module | None = None,
83
  ):
84
  self.data = custom_dataset
85
  self.durations = durations
86
  self.target_sample_rate = target_sample_rate
87
  self.hop_length = hop_length
88
  self.preprocessed_mel = preprocessed_mel
89
-
90
  if not preprocessed_mel:
91
- self.mel_spectrogram = default(
92
- mel_spec_module,
93
- MelSpec(
94
- target_sample_rate=target_sample_rate,
95
- hop_length=hop_length,
96
- n_mel_channels=n_mel_channels,
97
- ),
98
- )
99
 
100
  def get_frame_len(self, index):
101
- if (
102
- self.durations is not None
103
- ): # Please make sure the separately provided durations are correct, otherwise 99.99% OOM
104
  return self.durations[index] * self.target_sample_rate / self.hop_length
105
  return self.data[index]["duration"] * self.target_sample_rate / self.hop_length
106
-
107
  def __len__(self):
108
  return len(self.data)
109
-
110
  def __getitem__(self, index):
111
  row = self.data[index]
112
  audio_path = row["audio_path"]
@@ -118,57 +105,48 @@ class CustomDataset(Dataset):
118
 
119
  else:
120
  audio, source_sample_rate = torchaudio.load(audio_path)
121
- if audio.shape[0] > 1:
122
- audio = torch.mean(audio, dim=0, keepdim=True)
123
 
124
  if duration > 30 or duration < 0.3:
125
  return self.__getitem__((index + 1) % len(self.data))
126
-
127
  if source_sample_rate != self.target_sample_rate:
128
  resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
129
  audio = resampler(audio)
130
-
131
  mel_spec = self.mel_spectrogram(audio)
132
- mel_spec = mel_spec.squeeze(0) # '1 d t -> d t')
133
-
134
  return dict(
135
- mel_spec=mel_spec,
136
- text=text,
137
  )
138
-
139
 
140
  # Dynamic Batch Sampler
141
 
142
-
143
  class DynamicBatchSampler(Sampler[list[int]]):
144
- """Extension of Sampler that will do the following:
145
- 1. Change the batch size (essentially number of sequences)
146
- in a batch to ensure that the total number of frames are less
147
- than a certain threshold.
148
- 2. Make sure the padding efficiency in the batch is high.
149
  """
150
 
151
- def __init__(
152
- self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False
153
- ):
154
  self.sampler = sampler
155
  self.frames_threshold = frames_threshold
156
  self.max_samples = max_samples
157
 
158
  indices, batches = [], []
159
  data_source = self.sampler.data_source
160
-
161
- for idx in tqdm(
162
- self.sampler, desc="Sorting with sampler... if slow, check whether dataset is provided with duration"
163
- ):
164
  indices.append((idx, data_source.get_frame_len(idx)))
165
- indices.sort(key=lambda elem: elem[1])
166
 
167
  batch = []
168
  batch_frames = 0
169
- for idx, frame_len in tqdm(
170
- indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu"
171
- ):
172
  if batch_frames + frame_len <= self.frames_threshold and (max_samples == 0 or len(batch) < max_samples):
173
  batch.append(idx)
174
  batch_frames += frame_len
@@ -204,91 +182,61 @@ class DynamicBatchSampler(Sampler[list[int]]):
204
 
205
  # Load dataset
206
 
207
-
208
  def load_dataset(
209
- dataset_name: str,
210
- tokenizer: str = "pinyin",
211
- dataset_type: str = "CustomDataset",
212
- audio_type: str = "raw",
213
- mel_spec_module: nn.Module | None = None,
214
- mel_spec_kwargs: dict = dict(),
215
- ) -> CustomDataset | HFDataset:
216
- """
217
- dataset_type - "CustomDataset" if you want to use tokenizer name and default data path to load for train_dataset
218
- - "CustomDatasetPath" if you just want to pass the full path to a preprocessed dataset without relying on tokenizer
219
- """
220
-
221
  print("Loading dataset ...")
222
 
223
  if dataset_type == "CustomDataset":
224
  if audio_type == "raw":
225
  try:
226
  train_dataset = load_from_disk(f"data/{dataset_name}_{tokenizer}/raw")
227
- except: # noqa: E722
228
  train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/raw.arrow")
229
  preprocessed_mel = False
230
  elif audio_type == "mel":
231
  train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/mel.arrow")
232
  preprocessed_mel = True
233
- with open(f"data/{dataset_name}_{tokenizer}/duration.json", "r", encoding="utf-8") as f:
234
- data_dict = json.load(f)
235
- durations = data_dict["duration"]
236
- train_dataset = CustomDataset(
237
- train_dataset,
238
- durations=durations,
239
- preprocessed_mel=preprocessed_mel,
240
- mel_spec_module=mel_spec_module,
241
- **mel_spec_kwargs,
242
- )
243
-
244
- elif dataset_type == "CustomDatasetPath":
245
- try:
246
- train_dataset = load_from_disk(f"{dataset_name}/raw")
247
- except: # noqa: E722
248
- train_dataset = Dataset_.from_file(f"{dataset_name}/raw.arrow")
249
-
250
- with open(f"{dataset_name}/duration.json", "r", encoding="utf-8") as f:
251
  data_dict = json.load(f)
252
  durations = data_dict["duration"]
253
- train_dataset = CustomDataset(
254
- train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs
255
- )
256
 
257
  elif dataset_type == "HFDataset":
258
- print(
259
- "Should manually modify the path of huggingface dataset to your need.\n"
260
- + "May also the corresponding script cuz different dataset may have different format."
261
- )
262
  pre, post = dataset_name.split("_")
263
- train_dataset = HFDataset(
264
- load_dataset(f"{pre}/{pre}", split=f"train.{post}", cache_dir="./data"),
265
- )
266
 
267
  return train_dataset
268
 
269
 
270
  # collation
271
 
272
-
273
  def collate_fn(batch):
274
- mel_specs = [item["mel_spec"].squeeze(0) for item in batch]
275
  mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs])
276
  max_mel_length = mel_lengths.amax()
277
 
278
  padded_mel_specs = []
279
  for spec in mel_specs: # TODO. maybe records mask for attention here
280
  padding = (0, max_mel_length - spec.size(-1))
281
- padded_spec = F.pad(spec, padding, value=0)
282
  padded_mel_specs.append(padded_spec)
283
-
284
  mel_specs = torch.stack(padded_mel_specs)
285
 
286
- text = [item["text"] for item in batch]
287
  text_lengths = torch.LongTensor([len(item) for item in text])
288
 
289
  return dict(
290
- mel=mel_specs,
291
- mel_lengths=mel_lengths,
292
- text=text,
293
- text_lengths=text_lengths,
294
  )
 
6
  import torch.nn.functional as F
7
  from torch.utils.data import Dataset, Sampler
8
  import torchaudio
9
+ from datasets import load_dataset, load_from_disk
10
  from datasets import Dataset as Dataset_
11
+
12
+ from einops import rearrange
13
 
14
  from model.modules import MelSpec
 
15
 
16
 
17
  class HFDataset(Dataset):
18
  def __init__(
19
  self,
20
  hf_dataset: Dataset,
21
+ target_sample_rate = 24_000,
22
+ n_mel_channels = 100,
23
+ hop_length = 256,
24
  ):
25
  self.data = hf_dataset
26
  self.target_sample_rate = target_sample_rate
27
  self.hop_length = hop_length
28
+ self.mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length)
29
+
 
 
30
  def get_frame_len(self, index):
31
  row = self.data[index]
32
+ audio = row['audio']['array']
33
+ sample_rate = row['audio']['sampling_rate']
34
  return audio.shape[-1] / sample_rate * self.target_sample_rate / self.hop_length
35
 
36
  def __len__(self):
37
  return len(self.data)
38
+
39
  def __getitem__(self, index):
40
  row = self.data[index]
41
+ audio = row['audio']['array']
42
 
43
  # logger.info(f"Audio shape: {audio.shape}")
44
 
45
+ sample_rate = row['audio']['sampling_rate']
46
  duration = audio.shape[-1] / sample_rate
47
 
48
  if duration > 30 or duration < 0.3:
49
  return self.__getitem__((index + 1) % len(self.data))
50
+
51
  audio_tensor = torch.from_numpy(audio).float()
52
+
53
  if sample_rate != self.target_sample_rate:
54
  resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
55
  audio_tensor = resampler(audio_tensor)
56
+
57
+ audio_tensor = rearrange(audio_tensor, 't -> 1 t')
58
+
59
  mel_spec = self.mel_spectrogram(audio_tensor)
60
+
61
+ mel_spec = rearrange(mel_spec, '1 d t -> d t')
62
+
63
+ text = row['text']
64
+
65
  return dict(
66
+ mel_spec = mel_spec,
67
+ text = text,
68
  )
69
 
70
 
 
72
  def __init__(
73
  self,
74
  custom_dataset: Dataset,
75
+ durations = None,
76
+ target_sample_rate = 24_000,
77
+ hop_length = 256,
78
+ n_mel_channels = 100,
79
+ preprocessed_mel = False,
 
80
  ):
81
  self.data = custom_dataset
82
  self.durations = durations
83
  self.target_sample_rate = target_sample_rate
84
  self.hop_length = hop_length
85
  self.preprocessed_mel = preprocessed_mel
 
86
  if not preprocessed_mel:
87
+ self.mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, hop_length=hop_length, n_mel_channels=n_mel_channels)
88
 
89
  def get_frame_len(self, index):
90
+ if self.durations is not None: # Please make sure the separately provided durations are correct, otherwise 99.99% OOM
 
 
91
  return self.durations[index] * self.target_sample_rate / self.hop_length
92
  return self.data[index]["duration"] * self.target_sample_rate / self.hop_length
93
+
94
  def __len__(self):
95
  return len(self.data)
96
+
97
  def __getitem__(self, index):
98
  row = self.data[index]
99
  audio_path = row["audio_path"]
 
105
 
106
  else:
107
  audio, source_sample_rate = torchaudio.load(audio_path)
 
 
108
 
109
  if duration > 30 or duration < 0.3:
110
  return self.__getitem__((index + 1) % len(self.data))
111
+
112
  if source_sample_rate != self.target_sample_rate:
113
  resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
114
  audio = resampler(audio)
115
+
116
  mel_spec = self.mel_spectrogram(audio)
117
+ mel_spec = rearrange(mel_spec, '1 d t -> d t')
118
+
119
  return dict(
120
+ mel_spec = mel_spec,
121
+ text = text,
122
  )
123
+
124
 
125
  # Dynamic Batch Sampler
126
 
 
127
  class DynamicBatchSampler(Sampler[list[int]]):
128
+ """ Extension of Sampler that will do the following:
129
+ 1. Change the batch size (essentially number of sequences)
130
+ in a batch to ensure that the total number of frames are less
131
+ than a certain threshold.
132
+ 2. Make sure the padding efficiency in the batch is high.
133
  """
134
 
135
+ def __init__(self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False):
 
 
136
  self.sampler = sampler
137
  self.frames_threshold = frames_threshold
138
  self.max_samples = max_samples
139
 
140
  indices, batches = [], []
141
  data_source = self.sampler.data_source
142
+
143
+ for idx in tqdm(self.sampler, desc=f"Sorting with sampler... if slow, check whether dataset is provided with duration"):
 
 
144
  indices.append((idx, data_source.get_frame_len(idx)))
145
+ indices.sort(key=lambda elem : elem[1])
146
 
147
  batch = []
148
  batch_frames = 0
149
+ for idx, frame_len in tqdm(indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu"):
 
 
150
  if batch_frames + frame_len <= self.frames_threshold and (max_samples == 0 or len(batch) < max_samples):
151
  batch.append(idx)
152
  batch_frames += frame_len
 
182
 
183
  # Load dataset
184
 
 
185
  def load_dataset(
186
+ dataset_name: str,
187
+ tokenizer: str,
188
+ dataset_type: str = "CustomDataset",
189
+ audio_type: str = "raw",
190
+ mel_spec_kwargs: dict = dict()
191
+ ) -> CustomDataset | HFDataset:
192
+
 
 
 
 
 
193
  print("Loading dataset ...")
194
 
195
  if dataset_type == "CustomDataset":
196
  if audio_type == "raw":
197
  try:
198
  train_dataset = load_from_disk(f"data/{dataset_name}_{tokenizer}/raw")
199
+ except:
200
  train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/raw.arrow")
201
  preprocessed_mel = False
202
  elif audio_type == "mel":
203
  train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/mel.arrow")
204
  preprocessed_mel = True
205
+ with open(f"data/{dataset_name}_{tokenizer}/duration.json", 'r', encoding='utf-8') as f:
206
  data_dict = json.load(f)
207
  durations = data_dict["duration"]
208
+ train_dataset = CustomDataset(train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs)
 
 
209
 
210
  elif dataset_type == "HFDataset":
211
+ print("Should manually modify the path of huggingface dataset to your need.\n" +
212
+ "May also the corresponding script cuz different dataset may have different format.")
 
 
213
  pre, post = dataset_name.split("_")
214
+ train_dataset = HFDataset(load_dataset(f"{pre}/{pre}", split=f"train.{post}", cache_dir="./data"),)
 
 
215
 
216
  return train_dataset
217
 
218
 
219
  # collation
220
 
 
221
  def collate_fn(batch):
222
+ mel_specs = [item['mel_spec'].squeeze(0) for item in batch]
223
  mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs])
224
  max_mel_length = mel_lengths.amax()
225
 
226
  padded_mel_specs = []
227
  for spec in mel_specs: # TODO. maybe records mask for attention here
228
  padding = (0, max_mel_length - spec.size(-1))
229
+ padded_spec = F.pad(spec, padding, value = 0)
230
  padded_mel_specs.append(padded_spec)
231
+
232
  mel_specs = torch.stack(padded_mel_specs)
233
 
234
+ text = [item['text'] for item in batch]
235
  text_lengths = torch.LongTensor([len(item) for item in text])
236
 
237
  return dict(
238
+ mel = mel_specs,
239
+ mel_lengths = mel_lengths,
240
+ text = text,
241
+ text_lengths = text_lengths,
242
  )
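For reference, a minimal sketch of the frame-threshold batching rule that `DynamicBatchSampler` applies after sorting indices by frame length. The helper below is hypothetical (not part of the repo) and keeps only the accumulation logic visible in this diff:

```python
def make_dynamic_batches(frame_lens, frames_threshold, max_samples=0):
    """frame_lens: iterable of (index, frame_len) pairs, assumed pre-sorted by length."""
    batches, batch, batch_frames = [], [], 0
    for idx, frame_len in frame_lens:
        fits = batch_frames + frame_len <= frames_threshold
        has_room = max_samples == 0 or len(batch) < max_samples
        if fits and has_room:
            batch.append(idx)
            batch_frames += frame_len
        else:
            if batch:
                batches.append(batch)
            batch, batch_frames = [idx], frame_len
    if batch:
        batches.append(batch)
    return batches

# e.g. make_dynamic_batches([(0, 300), (1, 400), (2, 900)], frames_threshold=1000)
# -> [[0, 1], [2]]
```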
model/ecapa_tdnn.py CHANGED
@@ -9,14 +9,13 @@ import torch.nn as nn
9
  import torch.nn.functional as F
10
 
11
 
12
- """ Res2Conv1d + BatchNorm1d + ReLU
13
- """
14
-
15
 
16
  class Res2Conv1dReluBn(nn.Module):
17
- """
18
  in_channels == out_channels == channels
19
- """
20
 
21
  def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4):
22
  super().__init__()
@@ -52,9 +51,8 @@ class Res2Conv1dReluBn(nn.Module):
52
  return out
53
 
54
 
55
- """ Conv1d + BatchNorm1d + ReLU
56
- """
57
-
58
 
59
  class Conv1dReluBn(nn.Module):
60
  def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True):
@@ -66,9 +64,8 @@ class Conv1dReluBn(nn.Module):
66
  return self.bn(F.relu(self.conv(x)))
67
 
68
 
69
- """ The SE connection of 1D case.
70
- """
71
-
72
 
73
  class SE_Connect(nn.Module):
74
  def __init__(self, channels, se_bottleneck_dim=128):
@@ -85,8 +82,8 @@ class SE_Connect(nn.Module):
85
  return out
86
 
87
 
88
- """ SE-Res2Block of the ECAPA-TDNN architecture.
89
- """
90
 
91
  # def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
92
  # return nn.Sequential(
@@ -96,7 +93,6 @@ class SE_Connect(nn.Module):
96
  # SE_Connect(channels)
97
  # )
98
 
99
-
100
  class SE_Res2Block(nn.Module):
101
  def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim):
102
  super().__init__()
@@ -126,9 +122,8 @@ class SE_Res2Block(nn.Module):
126
  return x + residual
127
 
128
 
129
- """ Attentive weighted mean and standard deviation pooling.
130
- """
131
-
132
 
133
  class AttentiveStatsPool(nn.Module):
134
  def __init__(self, in_dim, attention_channels=128, global_context_att=False):
@@ -143,6 +138,7 @@ class AttentiveStatsPool(nn.Module):
143
  self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1) # equals V and k in the paper
144
 
145
  def forward(self, x):
 
146
  if self.global_context_att:
147
  context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
148
  context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
@@ -155,52 +151,38 @@ class AttentiveStatsPool(nn.Module):
155
  # alpha = F.relu(self.linear1(x_in))
156
  alpha = torch.softmax(self.linear2(alpha), dim=2)
157
  mean = torch.sum(alpha * x, dim=2)
158
- residuals = torch.sum(alpha * (x**2), dim=2) - mean**2
159
  std = torch.sqrt(residuals.clamp(min=1e-9))
160
  return torch.cat([mean, std], dim=1)
161
 
162
 
163
  class ECAPA_TDNN(nn.Module):
164
- def __init__(
165
- self,
166
- feat_dim=80,
167
- channels=512,
168
- emb_dim=192,
169
- global_context_att=False,
170
- feat_type="wavlm_large",
171
- sr=16000,
172
- feature_selection="hidden_states",
173
- update_extract=False,
174
- config_path=None,
175
- ):
176
  super().__init__()
177
 
178
  self.feat_type = feat_type
179
  self.feature_selection = feature_selection
180
  self.update_extract = update_extract
181
  self.sr = sr
182
-
183
- torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
184
  try:
185
  local_s3prl_path = os.path.expanduser("~/.cache/torch/hub/s3prl_s3prl_main")
186
- self.feature_extract = torch.hub.load(local_s3prl_path, feat_type, source="local", config_path=config_path)
187
- except: # noqa: E722
188
- self.feature_extract = torch.hub.load("s3prl/s3prl", feat_type)
189
 
190
- if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
191
- self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"
192
- ):
193
  self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
194
- if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
195
- self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"
196
- ):
197
  self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False
198
 
199
  self.feat_num = self.get_feat_num()
200
  self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
201
 
202
- if feat_type != "fbank" and feat_type != "mfcc":
203
- freeze_list = ["final_proj", "label_embs_concat", "mask_emb", "project_q", "quantizer"]
204
  for name, param in self.feature_extract.named_parameters():
205
  for freeze_val in freeze_list:
206
  if freeze_val in name:
@@ -216,46 +198,18 @@ class ECAPA_TDNN(nn.Module):
216
  self.channels = [channels] * 4 + [1536]
217
 
218
  self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
219
- self.layer2 = SE_Res2Block(
220
- self.channels[0],
221
- self.channels[1],
222
- kernel_size=3,
223
- stride=1,
224
- padding=2,
225
- dilation=2,
226
- scale=8,
227
- se_bottleneck_dim=128,
228
- )
229
- self.layer3 = SE_Res2Block(
230
- self.channels[1],
231
- self.channels[2],
232
- kernel_size=3,
233
- stride=1,
234
- padding=3,
235
- dilation=3,
236
- scale=8,
237
- se_bottleneck_dim=128,
238
- )
239
- self.layer4 = SE_Res2Block(
240
- self.channels[2],
241
- self.channels[3],
242
- kernel_size=3,
243
- stride=1,
244
- padding=4,
245
- dilation=4,
246
- scale=8,
247
- se_bottleneck_dim=128,
248
- )
249
 
250
  # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
251
  cat_channels = channels * 3
252
  self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
253
- self.pooling = AttentiveStatsPool(
254
- self.channels[-1], attention_channels=128, global_context_att=global_context_att
255
- )
256
  self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
257
  self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
258
 
 
259
  def get_feat_num(self):
260
  self.feature_extract.eval()
261
  wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
@@ -272,12 +226,12 @@ class ECAPA_TDNN(nn.Module):
272
  x = self.feature_extract([sample for sample in x])
273
  else:
274
  with torch.no_grad():
275
- if self.feat_type == "fbank" or self.feat_type == "mfcc":
276
  x = self.feature_extract(x) + 1e-6 # B x feat_dim x time_len
277
  else:
278
  x = self.feature_extract([sample for sample in x])
279
 
280
- if self.feat_type == "fbank":
281
  x = x.log()
282
 
283
  if self.feat_type != "fbank" and self.feat_type != "mfcc":
@@ -309,22 +263,6 @@ class ECAPA_TDNN(nn.Module):
309
  return out
310
 
311
 
312
- def ECAPA_TDNN_SMALL(
313
- feat_dim,
314
- emb_dim=256,
315
- feat_type="wavlm_large",
316
- sr=16000,
317
- feature_selection="hidden_states",
318
- update_extract=False,
319
- config_path=None,
320
- ):
321
- return ECAPA_TDNN(
322
- feat_dim=feat_dim,
323
- channels=512,
324
- emb_dim=emb_dim,
325
- feat_type=feat_type,
326
- sr=sr,
327
- feature_selection=feature_selection,
328
- update_extract=update_extract,
329
- config_path=config_path,
330
- )
 
9
  import torch.nn.functional as F
10
 
11
 
12
+ ''' Res2Conv1d + BatchNorm1d + ReLU
13
+ '''
 
14
 
15
  class Res2Conv1dReluBn(nn.Module):
16
+ '''
17
  in_channels == out_channels == channels
18
+ '''
19
 
20
  def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4):
21
  super().__init__()
 
51
  return out
52
 
53
 
54
+ ''' Conv1d + BatchNorm1d + ReLU
55
+ '''
 
56
 
57
  class Conv1dReluBn(nn.Module):
58
  def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True):
 
64
  return self.bn(F.relu(self.conv(x)))
65
 
66
 
67
+ ''' The SE connection of 1D case.
68
+ '''
 
69
 
70
  class SE_Connect(nn.Module):
71
  def __init__(self, channels, se_bottleneck_dim=128):
 
82
  return out
83
 
84
 
85
+ ''' SE-Res2Block of the ECAPA-TDNN architecture.
86
+ '''
87
 
88
  # def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
89
  # return nn.Sequential(
 
93
  # SE_Connect(channels)
94
  # )
95
 
 
96
  class SE_Res2Block(nn.Module):
97
  def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim):
98
  super().__init__()
 
122
  return x + residual
123
 
124
 
125
+ ''' Attentive weighted mean and standard deviation pooling.
126
+ '''
 
127
 
128
  class AttentiveStatsPool(nn.Module):
129
  def __init__(self, in_dim, attention_channels=128, global_context_att=False):
 
138
  self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1) # equals V and k in the paper
139
 
140
  def forward(self, x):
141
+
142
  if self.global_context_att:
143
  context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
144
  context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
 
151
  # alpha = F.relu(self.linear1(x_in))
152
  alpha = torch.softmax(self.linear2(alpha), dim=2)
153
  mean = torch.sum(alpha * x, dim=2)
154
+ residuals = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2
155
  std = torch.sqrt(residuals.clamp(min=1e-9))
156
  return torch.cat([mean, std], dim=1)
157
 
158
 
159
  class ECAPA_TDNN(nn.Module):
160
+ def __init__(self, feat_dim=80, channels=512, emb_dim=192, global_context_att=False,
161
+ feat_type='wavlm_large', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
162
  super().__init__()
163
 
164
  self.feat_type = feat_type
165
  self.feature_selection = feature_selection
166
  self.update_extract = update_extract
167
  self.sr = sr
168
+
169
+ torch.hub._validate_not_a_forked_repo=lambda a,b,c: True
170
  try:
171
  local_s3prl_path = os.path.expanduser("~/.cache/torch/hub/s3prl_s3prl_main")
172
+ self.feature_extract = torch.hub.load(local_s3prl_path, feat_type, source='local', config_path=config_path)
173
+ except:
174
+ self.feature_extract = torch.hub.load('s3prl/s3prl', feat_type)
175
 
176
+ if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"):
 
 
177
  self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
178
+ if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"):
 
 
179
  self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False
180
 
181
  self.feat_num = self.get_feat_num()
182
  self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
183
 
184
+ if feat_type != 'fbank' and feat_type != 'mfcc':
185
+ freeze_list = ['final_proj', 'label_embs_concat', 'mask_emb', 'project_q', 'quantizer']
186
  for name, param in self.feature_extract.named_parameters():
187
  for freeze_val in freeze_list:
188
  if freeze_val in name:
 
198
  self.channels = [channels] * 4 + [1536]
199
 
200
  self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
201
+ self.layer2 = SE_Res2Block(self.channels[0], self.channels[1], kernel_size=3, stride=1, padding=2, dilation=2, scale=8, se_bottleneck_dim=128)
202
+ self.layer3 = SE_Res2Block(self.channels[1], self.channels[2], kernel_size=3, stride=1, padding=3, dilation=3, scale=8, se_bottleneck_dim=128)
203
+ self.layer4 = SE_Res2Block(self.channels[2], self.channels[3], kernel_size=3, stride=1, padding=4, dilation=4, scale=8, se_bottleneck_dim=128)
 
204
 
205
  # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
206
  cat_channels = channels * 3
207
  self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
208
+ self.pooling = AttentiveStatsPool(self.channels[-1], attention_channels=128, global_context_att=global_context_att)
 
 
209
  self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
210
  self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
211
 
212
+
213
  def get_feat_num(self):
214
  self.feature_extract.eval()
215
  wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
 
226
  x = self.feature_extract([sample for sample in x])
227
  else:
228
  with torch.no_grad():
229
+ if self.feat_type == 'fbank' or self.feat_type == 'mfcc':
230
  x = self.feature_extract(x) + 1e-6 # B x feat_dim x time_len
231
  else:
232
  x = self.feature_extract([sample for sample in x])
233
 
234
+ if self.feat_type == 'fbank':
235
  x = x.log()
236
 
237
  if self.feat_type != "fbank" and self.feat_type != "mfcc":
 
263
  return out
264
 
265
 
266
+ def ECAPA_TDNN_SMALL(feat_dim, emb_dim=256, feat_type='wavlm_large', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
267
+ return ECAPA_TDNN(feat_dim=feat_dim, channels=512, emb_dim=emb_dim,
268
+ feat_type=feat_type, sr=sr, feature_selection=feature_selection, update_extract=update_extract, config_path=config_path)
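For reference, a minimal standalone sketch of the attentive statistics pooling math in `AttentiveStatsPool`: the attention weights here are random stand-ins for `linear2(tanh(linear1(x)))`, and the sizes are illustrative.

```python
import torch

B, C, T = 2, 1536, 200
x = torch.randn(B, C, T)                             # frame-level features
alpha = torch.softmax(torch.randn(B, C, T), dim=2)   # stand-in for the learned attention

mean = torch.sum(alpha * x, dim=2)                   # [B, C]
residuals = torch.sum(alpha * x ** 2, dim=2) - mean ** 2
std = torch.sqrt(residuals.clamp(min=1e-9))          # [B, C]
pooled = torch.cat([mean, std], dim=1)               # [B, 2C], fed to BatchNorm + Linear
```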
model/modules.py CHANGED
@@ -16,45 +16,45 @@ from torch import nn
16
  import torch.nn.functional as F
17
  import torchaudio
18
 
 
19
  from x_transformers.x_transformers import apply_rotary_pos_emb
20
 
21
 
22
  # raw wav to mel spec
23
 
24
-
25
  class MelSpec(nn.Module):
26
  def __init__(
27
  self,
28
- filter_length=1024,
29
- hop_length=256,
30
- win_length=1024,
31
- n_mel_channels=100,
32
- target_sample_rate=24_000,
33
- normalize=False,
34
- power=1,
35
- norm=None,
36
- center=True,
37
  ):
38
  super().__init__()
39
  self.n_mel_channels = n_mel_channels
40
 
41
  self.mel_stft = torchaudio.transforms.MelSpectrogram(
42
- sample_rate=target_sample_rate,
43
- n_fft=filter_length,
44
- win_length=win_length,
45
- hop_length=hop_length,
46
- n_mels=n_mel_channels,
47
- power=power,
48
- center=center,
49
- normalized=normalize,
50
- norm=norm,
51
  )
52
 
53
- self.register_buffer("dummy", torch.tensor(0), persistent=False)
54
 
55
  def forward(self, inp):
56
  if len(inp.shape) == 3:
57
- inp = inp.squeeze(1) # 'b 1 nw -> b nw'
58
 
59
  assert len(inp.shape) == 2
60
 
@@ -62,13 +62,12 @@ class MelSpec(nn.Module):
62
  self.to(inp.device)
63
 
64
  mel = self.mel_stft(inp)
65
- mel = mel.clamp(min=1e-5).log()
66
  return mel
67
-
68
 
69
  # sinusoidal position embedding
70
 
71
-
72
  class SinusPositionEmbedding(nn.Module):
73
  def __init__(self, dim):
74
  super().__init__()
@@ -86,37 +85,35 @@ class SinusPositionEmbedding(nn.Module):
86
 
87
  # convolutional position embedding
88
 
89
-
90
  class ConvPositionEmbedding(nn.Module):
91
- def __init__(self, dim, kernel_size=31, groups=16):
92
  super().__init__()
93
  assert kernel_size % 2 != 0
94
  self.conv1d = nn.Sequential(
95
- nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
96
  nn.Mish(),
97
- nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
98
  nn.Mish(),
99
  )
100
 
101
- def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
102
  if mask is not None:
103
  mask = mask[..., None]
104
- x = x.masked_fill(~mask, 0.0)
105
 
106
- x = x.permute(0, 2, 1)
107
  x = self.conv1d(x)
108
- out = x.permute(0, 2, 1)
109
 
110
  if mask is not None:
111
- out = out.masked_fill(~mask, 0.0)
112
 
113
  return out
114
 
115
 
116
  # rotary positional embedding related
117
 
118
-
119
- def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
120
  # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
121
  # has some connection to NTK literature
122
  # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
@@ -129,14 +126,12 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_resca
129
  freqs_sin = torch.sin(freqs) # imaginary part
130
  return torch.cat([freqs_cos, freqs_sin], dim=-1)
131
 
132
-
133
- def get_pos_embed_indices(start, length, max_pos, scale=1.0):
134
  # length = length if isinstance(length, int) else length.max()
135
  scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar
136
- pos = (
137
- start.unsqueeze(1)
138
- + (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long()
139
- )
140
  # avoid extra long error.
141
  pos = torch.where(pos < max_pos, pos, max_pos - 1)
142
  return pos
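For reference, a minimal sketch of the rotary-frequency precompute whose signature appears in this hunk; the NTK-aware rescale line follows the comment above but is written from memory here, so treat it as an assumption rather than the repo's exact code:

```python
import torch

def precompute_freqs_cis_sketch(dim: int, end: int, theta: float = 10000.0,
                                theta_rescale_factor: float = 1.0):
    # NTK-aware rescaling for longer sequences (assumed formula, per the comment)
    theta *= theta_rescale_factor ** (dim / (dim - 2))
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
    t = torch.arange(end)
    freqs = torch.outer(t, freqs).float()
    freqs_cos = torch.cos(freqs)                        # real part
    freqs_sin = torch.sin(freqs)                        # imaginary part
    return torch.cat([freqs_cos, freqs_sin], dim=-1)    # [end, dim]
```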
@@ -144,7 +139,6 @@ def get_pos_embed_indices(start, length, max_pos, scale=1.0):
144
 
145
  # Global Response Normalization layer (Instance Normalization ?)
146
 
147
-
148
  class GRN(nn.Module):
149
  def __init__(self, dim):
150
  super().__init__()
@@ -160,7 +154,6 @@ class GRN(nn.Module):
160
  # ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
161
  # ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
162
 
163
-
164
  class ConvNeXtV2Block(nn.Module):
165
  def __init__(
166
  self,
@@ -170,9 +163,7 @@ class ConvNeXtV2Block(nn.Module):
170
  ):
171
  super().__init__()
172
  padding = (dilation * (7 - 1)) // 2
173
- self.dwconv = nn.Conv1d(
174
- dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
175
- ) # depthwise conv
176
  self.norm = nn.LayerNorm(dim, eps=1e-6)
177
  self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
178
  self.act = nn.GELU()
@@ -195,7 +186,6 @@ class ConvNeXtV2Block(nn.Module):
195
  # AdaLayerNormZero
196
  # return with modulated x for attn input, and params for later mlp modulation
197
 
198
-
199
  class AdaLayerNormZero(nn.Module):
200
  def __init__(self, dim):
201
  super().__init__()
@@ -205,7 +195,7 @@ class AdaLayerNormZero(nn.Module):
205
 
206
  self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
207
 
208
- def forward(self, x, emb=None):
209
  emb = self.linear(self.silu(emb))
210
  shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
211
 
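For reference, a minimal standalone sketch of the adaLN-Zero modulation pattern shown in `AdaLayerNormZero.forward` and reused in `DiTBlock` below (layer sizes are illustrative):

```python
import torch
import torch.nn as nn

dim, b, n = 8, 2, 5                                   # illustrative sizes
silu, linear = nn.SiLU(), nn.Linear(dim, dim * 6)
norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

x, emb = torch.randn(b, n, dim), torch.randn(b, dim)  # sequence + time embedding
emb = linear(silu(emb))
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)

# modulate the normalized input for attention; the gates scale the residual branches
x_mod = norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
```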
@@ -216,7 +206,6 @@ class AdaLayerNormZero(nn.Module):
216
  # AdaLayerNormZero for final layer
217
  # return only with modulated x for attn input, cuz no more mlp modulation
218
 
219
-
220
  class AdaLayerNormZero_Final(nn.Module):
221
  def __init__(self, dim):
222
  super().__init__()
@@ -236,16 +225,22 @@ class AdaLayerNormZero_Final(nn.Module):
236
 
237
  # FeedForward
238
 
239
-
240
  class FeedForward(nn.Module):
241
- def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"):
242
  super().__init__()
243
  inner_dim = int(dim * mult)
244
  dim_out = dim_out if dim_out is not None else dim
245
 
246
  activation = nn.GELU(approximate=approximate)
247
- project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
248
- self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
249
 
250
  def forward(self, x):
251
  return self.ff(x)
@@ -254,7 +249,6 @@ class FeedForward(nn.Module):
254
  # Attention with possible joint part
255
  # modified from diffusers/src/diffusers/models/attention_processor.py
256
 
257
-
258
  class Attention(nn.Module):
259
  def __init__(
260
  self,
@@ -263,8 +257,8 @@ class Attention(nn.Module):
263
  heads: int = 8,
264
  dim_head: int = 64,
265
  dropout: float = 0.0,
266
- context_dim: Optional[int] = None, # if not None -> joint attention
267
- context_pre_only=None,
268
  ):
269
  super().__init__()
270
 
@@ -300,21 +294,20 @@ class Attention(nn.Module):
300
 
301
  def forward(
302
  self,
303
- x: float["b n d"], # noised input x # noqa: F722
304
- c: float["b n d"] = None, # context c # noqa: F722
305
- mask: bool["b n"] | None = None, # noqa: F722
306
- rope=None, # rotary position embedding for x
307
- c_rope=None, # rotary position embedding for c
308
  ) -> torch.Tensor:
309
  if c is not None:
310
- return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
311
  else:
312
- return self.processor(self, x, mask=mask, rope=rope)
313
 
314
 
315
  # Attention processor
316
 
317
-
318
  class AttnProcessor:
319
  def __init__(self):
320
  pass
@@ -322,10 +315,11 @@ class AttnProcessor:
322
  def __call__(
323
  self,
324
  attn: Attention,
325
- x: float["b n d"], # noised input x # noqa: F722
326
- mask: bool["b n"] | None = None, # noqa: F722
327
- rope=None, # rotary position embedding
328
  ) -> torch.FloatTensor:
 
329
  batch_size = x.shape[0]
330
 
331
  # `sample` projections.
@@ -336,7 +330,7 @@ class AttnProcessor:
336
  # apply rotary position embedding
337
  if rope is not None:
338
  freqs, xpos_scale = rope
339
- q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
340
 
341
  query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
342
  key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
@@ -351,7 +345,7 @@ class AttnProcessor:
351
  # mask. e.g. inference got a batch with different target durations, mask out the padding
352
  if mask is not None:
353
  attn_mask = mask
354
- attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
355
  attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
356
  else:
357
  attn_mask = None
@@ -366,16 +360,15 @@ class AttnProcessor:
366
  x = attn.to_out[1](x)
367
 
368
  if mask is not None:
369
- mask = mask.unsqueeze(-1)
370
- x = x.masked_fill(~mask, 0.0)
371
 
372
  return x
373
-
374
 
375
  # Joint Attention processor for MM-DiT
376
  # modified from diffusers/src/diffusers/models/attention_processor.py
377
 
378
-
379
  class JointAttnProcessor:
380
  def __init__(self):
381
  pass
@@ -383,11 +376,11 @@ class JointAttnProcessor:
383
  def __call__(
384
  self,
385
  attn: Attention,
386
- x: float["b n d"], # noised input x # noqa: F722
387
- c: float["b nt d"] = None, # context c, here text # noqa: F722
388
- mask: bool["b n"] | None = None, # noqa: F722
389
- rope=None, # rotary position embedding for x
390
- c_rope=None, # rotary position embedding for c
391
  ) -> torch.FloatTensor:
392
  residual = x
393
 
@@ -406,12 +399,12 @@ class JointAttnProcessor:
406
  # apply rope for context and noised input independently
407
  if rope is not None:
408
  freqs, xpos_scale = rope
409
- q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
410
  query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
411
  key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
412
  if c_rope is not None:
413
  freqs, xpos_scale = c_rope
414
- q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
415
  c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
416
  c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
417
 
@@ -428,8 +421,8 @@ class JointAttnProcessor:
428
 
429
  # mask. e.g. inference got a batch with different target durations, mask out the padding
430
  if mask is not None:
431
- attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text)
432
- attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
433
  attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
434
  else:
435
  attn_mask = None
@@ -440,8 +433,8 @@ class JointAttnProcessor:
440
 
441
  # Split the attention outputs.
442
  x, c = (
443
- x[:, : residual.shape[1]],
444
- x[:, residual.shape[1] :],
445
  )
446
 
447
  # linear proj
@@ -452,8 +445,8 @@ class JointAttnProcessor:
452
  c = attn.to_out_c(c)
453
 
454
  if mask is not None:
455
- mask = mask.unsqueeze(-1)
456
- x = x.masked_fill(~mask, 0.0)
457
  # c = c.masked_fill(~mask, 0.) # no mask for c (text)
458
 
459
  return x, c
@@ -461,24 +454,24 @@ class JointAttnProcessor:
461
 
462
  # DiT Block
463
 
464
-
465
  class DiTBlock(nn.Module):
466
- def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
467
- super().__init__()
468
 
 
 
 
469
  self.attn_norm = AdaLayerNormZero(dim)
470
  self.attn = Attention(
471
- processor=AttnProcessor(),
472
- dim=dim,
473
- heads=heads,
474
- dim_head=dim_head,
475
- dropout=dropout,
476
- )
477
-
478
  self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
479
- self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
480
 
481
- def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding
482
  # pre-norm & modulation for attention input
483
  norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
484
 
@@ -487,7 +480,7 @@ class DiTBlock(nn.Module):
487
 
488
  # process attention output for input x
489
  x = x + gate_msa.unsqueeze(1) * attn_output
490
-
491
  norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
492
  ff_output = self.ff(norm)
493
  x = x + gate_mlp.unsqueeze(1) * ff_output
@@ -497,9 +490,8 @@ class DiTBlock(nn.Module):
497
 
498
  # MMDiT Block https://arxiv.org/abs/2403.03206
499
 
500
-
501
  class MMDiTBlock(nn.Module):
502
- r"""
503
  modified from diffusers/src/diffusers/models/attention.py
504
 
505
  notes.
@@ -508,33 +500,33 @@ class MMDiTBlock(nn.Module):
508
  context_pre_only: last layer only do prenorm + modulation cuz no more ffn
509
  """
510
 
511
- def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False):
512
  super().__init__()
513
 
514
  self.context_pre_only = context_pre_only
515
-
516
  self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
517
  self.attn_norm_x = AdaLayerNormZero(dim)
518
  self.attn = Attention(
519
- processor=JointAttnProcessor(),
520
- dim=dim,
521
- heads=heads,
522
- dim_head=dim_head,
523
- dropout=dropout,
524
- context_dim=dim,
525
- context_pre_only=context_pre_only,
526
- )
527
 
528
  if not context_pre_only:
529
  self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
530
- self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
531
  else:
532
  self.ff_norm_c = None
533
  self.ff_c = None
534
  self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
535
- self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
536
 
537
- def forward(self, x, c, t, mask=None, rope=None, c_rope=None): # x: noised input, c: context, t: time embedding
538
  # pre-norm & modulation for attention input
539
  if self.context_pre_only:
540
  norm_c = self.attn_norm_c(c, t)
@@ -548,7 +540,7 @@ class MMDiTBlock(nn.Module):
548
  # process attention output for context c
549
  if self.context_pre_only:
550
  c = None
551
- else: # if not last layer
552
  c = c + c_gate_msa.unsqueeze(1) * c_attn_output
553
 
554
  norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
@@ -557,7 +549,7 @@ class MMDiTBlock(nn.Module):
557
 
558
  # process attention output for input x
559
  x = x + x_gate_msa.unsqueeze(1) * x_attn_output
560
-
561
  norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
562
  x_ff_output = self.ff_x(norm_x)
563
  x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
@@ -567,15 +559,17 @@ class MMDiTBlock(nn.Module):
567
 
568
  # time step conditioning embedding
569
 
570
-
571
  class TimestepEmbedding(nn.Module):
572
  def __init__(self, dim, freq_embed_dim=256):
573
  super().__init__()
574
  self.time_embed = SinusPositionEmbedding(freq_embed_dim)
575
- self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
 
 
 
 
576
 
577
- def forward(self, timestep: float["b"]): # noqa: F821
578
  time_hidden = self.time_embed(timestep)
579
- time_hidden = time_hidden.to(timestep.dtype)
580
  time = self.time_mlp(time_hidden) # b d
581
  return time
 
16
  import torch.nn.functional as F
17
  import torchaudio
18
 
19
+ from einops import rearrange
20
  from x_transformers.x_transformers import apply_rotary_pos_emb
21
 
22
 
23
  # raw wav to mel spec
24
 
 
25
  class MelSpec(nn.Module):
26
  def __init__(
27
  self,
28
+ filter_length = 1024,
29
+ hop_length = 256,
30
+ win_length = 1024,
31
+ n_mel_channels = 100,
32
+ target_sample_rate = 24_000,
33
+ normalize = False,
34
+ power = 1,
35
+ norm = None,
36
+ center = True,
37
  ):
38
  super().__init__()
39
  self.n_mel_channels = n_mel_channels
40
 
41
  self.mel_stft = torchaudio.transforms.MelSpectrogram(
42
+ sample_rate = target_sample_rate,
43
+ n_fft = filter_length,
44
+ win_length = win_length,
45
+ hop_length = hop_length,
46
+ n_mels = n_mel_channels,
47
+ power = power,
48
+ center = center,
49
+ normalized = normalize,
50
+ norm = norm,
51
  )
52
 
53
+ self.register_buffer('dummy', torch.tensor(0), persistent = False)
54
 
55
  def forward(self, inp):
56
  if len(inp.shape) == 3:
57
+ inp = rearrange(inp, 'b 1 nw -> b nw')
58
 
59
  assert len(inp.shape) == 2
60
 
 
62
  self.to(inp.device)
63
 
64
  mel = self.mel_stft(inp)
65
+ mel = mel.clamp(min = 1e-5).log()
66
  return mel
67
+
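As a quick sanity check of the MelSpec settings above, the same transform can be reproduced standalone with torchaudio (the values below are the class defaults shown in the diff: 24 kHz audio, 1024-point FFT, hop 256, 100 mel bins). This is a minimal sketch, not part of the diff.

    import torch
    import torchaudio

    # standalone sketch of the MelSpec pipeline: magnitude mel spectrogram, then log with a floor
    mel_stft = torchaudio.transforms.MelSpectrogram(
        sample_rate=24_000, n_fft=1024, win_length=1024, hop_length=256,
        n_mels=100, power=1, center=True, normalized=False, norm=None,
    )

    wav = torch.randn(1, 24_000)       # one second of dummy audio, shape (b, nw)
    mel = mel_stft(wav)                # (b, 100, 1 + nw // 256) = (1, 100, 94)
    mel = mel.clamp(min=1e-5).log()    # matches the clamp-then-log in MelSpec.forward
    print(mel.shape)                   # torch.Size([1, 100, 94])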
68
 
69
  # sinusoidal position embedding
70
 
 
71
  class SinusPositionEmbedding(nn.Module):
72
  def __init__(self, dim):
73
  super().__init__()
 
85
 
86
  # convolutional position embedding
87
 
 
88
  class ConvPositionEmbedding(nn.Module):
89
+ def __init__(self, dim, kernel_size = 31, groups = 16):
90
  super().__init__()
91
  assert kernel_size % 2 != 0
92
  self.conv1d = nn.Sequential(
93
+ nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2),
94
  nn.Mish(),
95
+ nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2),
96
  nn.Mish(),
97
  )
98
 
99
+ def forward(self, x: float['b n d'], mask: bool['b n'] | None = None):
100
  if mask is not None:
101
  mask = mask[..., None]
102
+ x = x.masked_fill(~mask, 0.)
103
 
104
+ x = rearrange(x, 'b n d -> b d n')
105
  x = self.conv1d(x)
106
+ out = rearrange(x, 'b d n -> b n d')
107
 
108
  if mask is not None:
109
+ out = out.masked_fill(~mask, 0.)
110
 
111
  return out
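For reference, the masking convention in ConvPositionEmbedding can be exercised in isolation. The sketch below (made-up shapes, not part of the diff) shows why padded frames are zeroed both before and after the grouped convolutions: otherwise the kernel would smear information across the padding boundary.

    import torch
    from torch import nn
    from einops import rearrange

    # sketch of ConvPositionEmbedding: two grouped 1D convs with Mish over the sequence axis,
    # with padded frames zeroed before and after so they never leak into real frames
    dim, ks, groups = 512, 31, 16
    conv1d = nn.Sequential(
        nn.Conv1d(dim, dim, ks, groups=groups, padding=ks // 2), nn.Mish(),
        nn.Conv1d(dim, dim, ks, groups=groups, padding=ks // 2), nn.Mish(),
    )

    x = torch.randn(2, 10, dim)                                        # (b, n, d)
    mask = torch.arange(10)[None, :] < torch.tensor([10, 7])[:, None]  # second sample has 3 padded frames

    x = x.masked_fill(~mask[..., None], 0.0)        # zero padded positions before the convs
    out = rearrange(conv1d(rearrange(x, 'b n d -> b d n')), 'b d n -> b n d')
    out = out.masked_fill(~mask[..., None], 0.0)    # ... and re-zero them afterwards
    print(out.shape)                                # torch.Size([2, 10, 512])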
112
 
113
 
114
  # rotary positional embedding related
115
 
116
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.):
 
117
  # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
118
  # has some connection to NTK literature
119
  # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
 
126
  freqs_sin = torch.sin(freqs) # imaginary part
127
  return torch.cat([freqs_cos, freqs_sin], dim=-1)
128
 
129
+ def get_pos_embed_indices(start, length, max_pos, scale=1.):
 
130
  # length = length if isinstance(length, int) else length.max()
131
  scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar
132
+ pos = start.unsqueeze(1) + (
133
+ torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) *
134
+ scale.unsqueeze(1)).long()
 
135
  # avoid extra long error.
136
  pos = torch.where(pos < max_pos, pos, max_pos - 1)
137
  return pos
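A small usage sketch of get_pos_embed_indices (simplified to CPU tensors; the start offsets and lengths below are made up): each sample gets its own arange of positions, optionally stretched by scale, and anything past max_pos is clamped to the last valid index rather than raising.

    import torch

    def get_pos_embed_indices(start, length, max_pos, scale=1.0):
        # per-sample position indices, as in the function above (device handling omitted)
        scale = scale * torch.ones_like(start, dtype=torch.float32)
        pos = start.unsqueeze(1) + (
            torch.arange(length, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)
        ).long()
        return torch.where(pos < max_pos, pos, max_pos - 1)   # clamp instead of indexing out of range

    start = torch.tensor([0, 100])                      # two sequences starting at different offsets
    pos = get_pos_embed_indices(start, length=8, max_pos=4096)
    print(pos.shape)                                    # torch.Size([2, 8])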
 
139
 
140
  # Global Response Normalization layer (Instance Normalization ?)
141
 
 
142
  class GRN(nn.Module):
143
  def __init__(self, dim):
144
  super().__init__()
 
154
  # ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
155
  # ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
156
 
 
157
  class ConvNeXtV2Block(nn.Module):
158
  def __init__(
159
  self,
 
163
  ):
164
  super().__init__()
165
  padding = (dilation * (7 - 1)) // 2
166
+ self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation) # depthwise conv
 
 
167
  self.norm = nn.LayerNorm(dim, eps=1e-6)
168
  self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
169
  self.act = nn.GELU()
 
186
  # AdaLayerNormZero
187
  # return with modulated x for attn input, and params for later mlp modulation
188
 
 
189
  class AdaLayerNormZero(nn.Module):
190
  def __init__(self, dim):
191
  super().__init__()
 
195
 
196
  self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
197
 
198
+ def forward(self, x, emb = None):
199
  emb = self.linear(self.silu(emb))
200
  shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
201
 
 
206
  # AdaLayerNormZero for final layer
207
  # return only with modulated x for attn input, cuz no more mlp modulation
208
 
 
209
  class AdaLayerNormZero_Final(nn.Module):
210
  def __init__(self, dim):
211
  super().__init__()
 
225
 
226
  # FeedForward
227
 
 
228
  class FeedForward(nn.Module):
229
+ def __init__(self, dim, dim_out = None, mult = 4, dropout = 0., approximate: str = 'none'):
230
  super().__init__()
231
  inner_dim = int(dim * mult)
232
  dim_out = dim_out if dim_out is not None else dim
233
 
234
  activation = nn.GELU(approximate=approximate)
235
+ project_in = nn.Sequential(
236
+ nn.Linear(dim, inner_dim),
237
+ activation
238
+ )
239
+ self.ff = nn.Sequential(
240
+ project_in,
241
+ nn.Dropout(dropout),
242
+ nn.Linear(inner_dim, dim_out)
243
+ )
244
 
245
  def forward(self, x):
246
  return self.ff(x)
 
249
  # Attention with possible joint part
250
  # modified from diffusers/src/diffusers/models/attention_processor.py
251
 
 
252
  class Attention(nn.Module):
253
  def __init__(
254
  self,
 
257
  heads: int = 8,
258
  dim_head: int = 64,
259
  dropout: float = 0.0,
260
+ context_dim: Optional[int] = None, # if not None -> joint attention
261
+ context_pre_only = None,
262
  ):
263
  super().__init__()
264
 
 
294
 
295
  def forward(
296
  self,
297
+ x: float['b n d'], # noised input x
298
+ c: float['b n d'] = None, # context c
299
+ mask: bool['b n'] | None = None,
300
+ rope = None, # rotary position embedding for x
301
+ c_rope = None, # rotary position embedding for c
302
  ) -> torch.Tensor:
303
  if c is not None:
304
+ return self.processor(self, x, c = c, mask = mask, rope = rope, c_rope = c_rope)
305
  else:
306
+ return self.processor(self, x, mask = mask, rope = rope)
307
 
308
 
309
  # Attention processor
310
 
 
311
  class AttnProcessor:
312
  def __init__(self):
313
  pass
 
315
  def __call__(
316
  self,
317
  attn: Attention,
318
+ x: float['b n d'], # noised input x
319
+ mask: bool['b n'] | None = None,
320
+ rope = None, # rotary position embedding
321
  ) -> torch.FloatTensor:
322
+
323
  batch_size = x.shape[0]
324
 
325
  # `sample` projections.
 
330
  # apply rotary position embedding
331
  if rope is not None:
332
  freqs, xpos_scale = rope
333
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
334
 
335
  query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
336
  key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
 
345
  # mask. e.g. inference got a batch with different target durations, mask out the padding
346
  if mask is not None:
347
  attn_mask = mask
348
+ attn_mask = rearrange(attn_mask, 'b n -> b 1 1 n')
349
  attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
350
  else:
351
  attn_mask = None
 
360
  x = attn.to_out[1](x)
361
 
362
  if mask is not None:
363
+ mask = rearrange(mask, 'b n -> b n 1')
364
+ x = x.masked_fill(~mask, 0.)
365
 
366
  return x
367
+
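The padding-mask handling above is easiest to see in isolation. A minimal sketch with made-up shapes (assuming torch>=2.0 for scaled_dot_product_attention): a per-sample boolean mask of shape (b, n) is broadcast to (b, heads, n, n) so padded key positions are excluded for every head and every query.

    import torch
    import torch.nn.functional as F
    from einops import rearrange

    b, heads, n, d = 2, 8, 6, 64
    lens = torch.tensor([6, 4])                              # second sample has 2 padded frames
    mask = torch.arange(n)[None, :] < lens[:, None]          # (b, n), True on real frames

    attn_mask = rearrange(mask, 'b n -> b 1 1 n')            # broadcast over heads and queries
    attn_mask = attn_mask.expand(b, heads, n, n)             # (b, h, n, n)

    q = k = v = torch.randn(b, heads, n, d)
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)  # padded keys get -inf scores
    out = rearrange(out, 'b h n d -> b n (h d)')
    print(out.shape)                                         # torch.Size([2, 6, 512])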
368
 
369
  # Joint Attention processor for MM-DiT
370
  # modified from diffusers/src/diffusers/models/attention_processor.py
371
 
 
372
  class JointAttnProcessor:
373
  def __init__(self):
374
  pass
 
376
  def __call__(
377
  self,
378
  attn: Attention,
379
+ x: float['b n d'], # noised input x
380
+ c: float['b nt d'] = None, # context c, here text
381
+ mask: bool['b n'] | None = None,
382
+ rope = None, # rotary position embedding for x
383
+ c_rope = None, # rotary position embedding for c
384
  ) -> torch.FloatTensor:
385
  residual = x
386
 
 
399
  # apply rope for context and noised input independently
400
  if rope is not None:
401
  freqs, xpos_scale = rope
402
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
403
  query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
404
  key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
405
  if c_rope is not None:
406
  freqs, xpos_scale = c_rope
407
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
408
  c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
409
  c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
410
 
 
421
 
422
  # mask. e.g. inference got a batch with different target durations, mask out the padding
423
  if mask is not None:
424
+ attn_mask = F.pad(mask, (0, c.shape[1]), value = True) # no mask for c (text)
425
+ attn_mask = rearrange(attn_mask, 'b n -> b 1 1 n')
426
  attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
427
  else:
428
  attn_mask = None
 
433
 
434
  # Split the attention outputs.
435
  x, c = (
436
+ x[:, :residual.shape[1]],
437
+ x[:, residual.shape[1]:],
438
  )
439
 
440
  # linear proj
 
445
  c = attn.to_out_c(c)
446
 
447
  if mask is not None:
448
+ mask = rearrange(mask, 'b n -> b n 1')
449
+ x = x.masked_fill(~mask, 0.)
450
  # c = c.masked_fill(~mask, 0.) # no mask for c (text)
451
 
452
  return x, c
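The core of the joint processor is concatenate, attend once, split back. A shape-only sketch under assumed dimensions (the real projections, rope and masking are omitted):

    import torch
    import torch.nn.functional as F

    b, n, nt, h, d = 2, 10, 4, 8, 64
    x_qkv = torch.randn(3, b, h, n, d)     # stand-ins for the projected q/k/v of the speech stream x
    c_qkv = torch.randn(3, b, h, nt, d)    # ... and of the text stream c

    query = torch.cat([x_qkv[0], c_qkv[0]], dim=2)   # (b, h, n + nt, d)
    key   = torch.cat([x_qkv[1], c_qkv[1]], dim=2)
    value = torch.cat([x_qkv[2], c_qkv[2]], dim=2)

    joint = F.scaled_dot_product_attention(query, key, value)   # one attention over both streams
    joint = joint.transpose(1, 2).reshape(b, n + nt, h * d)

    x_out, c_out = joint[:, :n], joint[:, n:]        # split as in JointAttnProcessor
    print(x_out.shape, c_out.shape)                  # (2, 10, 512) (2, 4, 512)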
 
454
 
455
  # DiT Block
456
 
 
457
  class DiTBlock(nn.Module):
 
 
458
 
459
+ def __init__(self, dim, heads, dim_head, ff_mult = 4, dropout = 0.1):
460
+ super().__init__()
461
+
462
  self.attn_norm = AdaLayerNormZero(dim)
463
  self.attn = Attention(
464
+ processor = AttnProcessor(),
465
+ dim = dim,
466
+ heads = heads,
467
+ dim_head = dim_head,
468
+ dropout = dropout,
469
+ )
470
+
471
  self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
472
+ self.ff = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
473
 
474
+ def forward(self, x, t, mask = None, rope = None): # x: noised input, t: time embedding
475
  # pre-norm & modulation for attention input
476
  norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
477
 
 
480
 
481
  # process attention output for input x
482
  x = x + gate_msa.unsqueeze(1) * attn_output
483
+
484
  norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
485
  ff_output = self.ff(norm)
486
  x = x + gate_mlp.unsqueeze(1) * ff_output
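DiTBlock relies on the adaLN-Zero pattern from AdaLayerNormZero: the time embedding is projected to six vectors that shift and scale the pre-norm activations and gate the residual branches. A minimal sketch with assumed dimensions, where the attention call is replaced by a stand-in:

    import torch
    from torch import nn

    dim, b, n = 512, 2, 10
    to_mod = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim))     # SiLU then Linear, as in AdaLayerNormZero
    norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    x = torch.randn(b, n, dim)
    t = torch.randn(b, dim)                                        # time conditioning, one vector per sample
    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = to_mod(t).chunk(6, dim=1)

    h = norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]    # modulated attention input
    attn_out = h                                                   # stand-in for self.attn(h, ...)
    x = x + gate_msa.unsqueeze(1) * attn_out                       # gated residual, as in the block
    print(x.shape)                                                 # torch.Size([2, 10, 512])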
 
490
 
491
  # MMDiT Block https://arxiv.org/abs/2403.03206
492
 
 
493
  class MMDiTBlock(nn.Module):
494
+ r"""
495
  modified from diffusers/src/diffusers/models/attention.py
496
 
497
  notes.
 
500
  context_pre_only: last layer only do prenorm + modulation cuz no more ffn
501
  """
502
 
503
+ def __init__(self, dim, heads, dim_head, ff_mult = 4, dropout = 0.1, context_pre_only = False):
504
  super().__init__()
505
 
506
  self.context_pre_only = context_pre_only
507
+
508
  self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
509
  self.attn_norm_x = AdaLayerNormZero(dim)
510
  self.attn = Attention(
511
+ processor = JointAttnProcessor(),
512
+ dim = dim,
513
+ heads = heads,
514
+ dim_head = dim_head,
515
+ dropout = dropout,
516
+ context_dim = dim,
517
+ context_pre_only = context_pre_only,
518
+ )
519
 
520
  if not context_pre_only:
521
  self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
522
+ self.ff_c = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
523
  else:
524
  self.ff_norm_c = None
525
  self.ff_c = None
526
  self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
527
+ self.ff_x = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
528
 
529
+ def forward(self, x, c, t, mask = None, rope = None, c_rope = None): # x: noised input, c: context, t: time embedding
530
  # pre-norm & modulation for attention input
531
  if self.context_pre_only:
532
  norm_c = self.attn_norm_c(c, t)
 
540
  # process attention output for context c
541
  if self.context_pre_only:
542
  c = None
543
+ else: # if not last layer
544
  c = c + c_gate_msa.unsqueeze(1) * c_attn_output
545
 
546
  norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
 
549
 
550
  # process attention output for input x
551
  x = x + x_gate_msa.unsqueeze(1) * x_attn_output
552
+
553
  norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
554
  x_ff_output = self.ff_x(norm_x)
555
  x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
 
559
 
560
  # time step conditioning embedding
561
 
 
562
  class TimestepEmbedding(nn.Module):
563
  def __init__(self, dim, freq_embed_dim=256):
564
  super().__init__()
565
  self.time_embed = SinusPositionEmbedding(freq_embed_dim)
566
+ self.time_mlp = nn.Sequential(
567
+ nn.Linear(freq_embed_dim, dim),
568
+ nn.SiLU(),
569
+ nn.Linear(dim, dim)
570
+ )
571
 
572
+ def forward(self, timestep: float['b']):
573
  time_hidden = self.time_embed(timestep)
 
574
  time = self.time_mlp(time_hidden) # b d
575
  return time
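TimestepEmbedding maps one scalar flow/diffusion time per sample to a conditioning vector. The sketch below uses a generic sinusoidal formula in place of SinusPositionEmbedding (whose body is not shown in this diff, so the exact frequencies and the 1000x scale are assumptions), followed by the same Linear-SiLU-Linear MLP shape:

    import math
    import torch
    from torch import nn

    def sinusoidal_embedding(t: torch.Tensor, dim: int = 256, scale: float = 1000.0) -> torch.Tensor:
        # generic sin/cos embedding; formula and scale are assumed, not copied from the repo
        half = dim // 2
        freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
        args = scale * t[:, None].float() * freqs[None, :]
        return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)   # (b, dim)

    dim = 512
    time_mlp = nn.Sequential(nn.Linear(256, dim), nn.SiLU(), nn.Linear(dim, dim))

    t = torch.rand(4)                         # time in [0, 1], one per sample
    emb = time_mlp(sinusoidal_embedding(t))   # (4, 512) conditioning vector fed to each DiT block
    print(emb.shape)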
model/trainer.py CHANGED
@@ -10,6 +10,8 @@ from torch.optim import AdamW
10
  from torch.utils.data import DataLoader, Dataset, SequentialSampler
11
  from torch.optim.lr_scheduler import LinearLR, SequentialLR
12
 
 
 
13
  from accelerate import Accelerator
14
  from accelerate.utils import DistributedDataParallelKwargs
15
 
@@ -22,69 +24,66 @@ from model.dataset import DynamicBatchSampler, collate_fn
22
 
23
  # trainer
24
 
25
-
26
  class Trainer:
27
  def __init__(
28
  self,
29
  model: CFM,
30
  epochs,
31
  learning_rate,
32
- num_warmup_updates=20000,
33
- save_per_updates=1000,
34
- checkpoint_path=None,
35
- batch_size=32,
36
  batch_size_type: str = "sample",
37
- max_samples=32,
38
- grad_accumulation_steps=1,
39
- max_grad_norm=1.0,
40
  noise_scheduler: str | None = None,
41
  duration_predictor: torch.nn.Module | None = None,
42
- wandb_project="test_e2-tts",
43
- wandb_run_name="test_run",
44
  wandb_resume_id: str = None,
45
- last_per_steps=None,
46
  accelerate_kwargs: dict = dict(),
47
- ema_kwargs: dict = dict(),
48
- bnb_optimizer: bool = False,
49
  ):
50
- ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
51
-
52
- logger = "wandb" if wandb.api.api_key else None
53
- print(f"Using logger: {logger}")
54
 
55
  self.accelerator = Accelerator(
56
- log_with=logger,
57
- kwargs_handlers=[ddp_kwargs],
58
- gradient_accumulation_steps=grad_accumulation_steps,
59
- **accelerate_kwargs,
60
  )
61
-
62
- if logger == "wandb":
63
- if exists(wandb_resume_id):
64
- init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name, "id": wandb_resume_id}}
65
- else:
66
- init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
67
- self.accelerator.init_trackers(
68
- project_name=wandb_project,
69
- init_kwargs=init_kwargs,
70
- config={
71
- "epochs": epochs,
72
  "learning_rate": learning_rate,
73
- "num_warmup_updates": num_warmup_updates,
74
  "batch_size": batch_size,
75
  "batch_size_type": batch_size_type,
76
  "max_samples": max_samples,
77
  "grad_accumulation_steps": grad_accumulation_steps,
78
  "max_grad_norm": max_grad_norm,
79
  "gpus": self.accelerator.num_processes,
80
- "noise_scheduler": noise_scheduler,
81
- },
82
  )
83
 
84
  self.model = model
85
 
86
  if self.is_main:
87
- self.ema_model = EMA(model, include_online_model=False, **ema_kwargs)
 
 
 
 
88
 
89
  self.ema_model.to(self.accelerator.device)
90
 
@@ -92,7 +91,7 @@ class Trainer:
92
  self.num_warmup_updates = num_warmup_updates
93
  self.save_per_updates = save_per_updates
94
  self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
95
- self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")
96
 
97
  self.batch_size = batch_size
98
  self.batch_size_type = batch_size_type
@@ -104,13 +103,10 @@ class Trainer:
104
 
105
  self.duration_predictor = duration_predictor
106
 
107
- if bnb_optimizer:
108
- import bitsandbytes as bnb
109
-
110
- self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
111
- else:
112
- self.optimizer = AdamW(model.parameters(), lr=learning_rate)
113
- self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
114
 
115
  @property
116
  def is_main(self):
@@ -120,112 +116,76 @@ class Trainer:
120
  self.accelerator.wait_for_everyone()
121
  if self.is_main:
122
  checkpoint = dict(
123
- model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
124
- optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(),
125
- ema_model_state_dict=self.ema_model.state_dict(),
126
- scheduler_state_dict=self.scheduler.state_dict(),
127
- step=step,
128
  )
129
  if not os.path.exists(self.checkpoint_path):
130
  os.makedirs(self.checkpoint_path)
131
- if last:
132
  self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
133
  print(f"Saved last checkpoint at step {step}")
134
  else:
135
  self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")
136
 
137
  def load_checkpoint(self):
138
- if (
139
- not exists(self.checkpoint_path)
140
- or not os.path.exists(self.checkpoint_path)
141
- or not os.listdir(self.checkpoint_path)
142
- ):
143
  return 0
144
-
145
  self.accelerator.wait_for_everyone()
146
  if "model_last.pt" in os.listdir(self.checkpoint_path):
147
  latest_checkpoint = "model_last.pt"
148
  else:
149
- latest_checkpoint = sorted(
150
- [f for f in os.listdir(self.checkpoint_path) if f.endswith(".pt")],
151
- key=lambda x: int("".join(filter(str.isdigit, x))),
152
- )[-1]
153
  # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
154
- checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
 
 
155
 
156
  if self.is_main:
157
- self.ema_model.load_state_dict(checkpoint["ema_model_state_dict"])
158
-
159
- if "step" in checkpoint:
160
- self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
161
- self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint["optimizer_state_dict"])
162
- if self.scheduler:
163
- self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
164
- step = checkpoint["step"]
165
- else:
166
- checkpoint["model_state_dict"] = {
167
- k.replace("ema_model.", ""): v
168
- for k, v in checkpoint["ema_model_state_dict"].items()
169
- if k not in ["initted", "step"]
170
- }
171
- self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
172
- step = 0
173
-
174
- del checkpoint
175
- gc.collect()
176
  return step
177
 
178
  def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
 
179
  if exists(resumable_with_seed):
180
  generator = torch.Generator()
181
  generator.manual_seed(resumable_with_seed)
182
- else:
183
  generator = None
184
 
185
  if self.batch_size_type == "sample":
186
- train_dataloader = DataLoader(
187
- train_dataset,
188
- collate_fn=collate_fn,
189
- num_workers=num_workers,
190
- pin_memory=True,
191
- persistent_workers=True,
192
- batch_size=self.batch_size,
193
- shuffle=True,
194
- generator=generator,
195
- )
196
  elif self.batch_size_type == "frame":
197
  self.accelerator.even_batches = False
198
  sampler = SequentialSampler(train_dataset)
199
- batch_sampler = DynamicBatchSampler(
200
- sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False
201
- )
202
- train_dataloader = DataLoader(
203
- train_dataset,
204
- collate_fn=collate_fn,
205
- num_workers=num_workers,
206
- pin_memory=True,
207
- persistent_workers=True,
208
- batch_sampler=batch_sampler,
209
- )
210
  else:
211
- raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}")
212
-
213
  # accelerator.prepare() dispatches batches to devices;
214
  # which means the length of dataloader calculated before, should consider the number of devices
215
- warmup_steps = (
216
- self.num_warmup_updates * self.accelerator.num_processes
217
- ) # consider a fixed warmup steps while using accelerate multi-gpu ddp
218
- # otherwise by default with split_batches=False, warmup steps change with num_processes
219
  total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
220
  decay_steps = total_steps - warmup_steps
221
  warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
222
  decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
223
- self.scheduler = SequentialLR(
224
- self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_steps]
225
- )
226
- train_dataloader, self.scheduler = self.accelerator.prepare(
227
- train_dataloader, self.scheduler
228
- ) # actual steps = 1 gpu steps / gpus
229
  start_step = self.load_checkpoint()
230
  global_step = start_step
231
 
@@ -240,36 +200,23 @@ class Trainer:
240
  for epoch in range(skipped_epoch, self.epochs):
241
  self.model.train()
242
  if exists(resumable_with_seed) and epoch == skipped_epoch:
243
- progress_bar = tqdm(
244
- skipped_dataloader,
245
- desc=f"Epoch {epoch+1}/{self.epochs}",
246
- unit="step",
247
- disable=not self.accelerator.is_local_main_process,
248
- initial=skipped_batch,
249
- total=orig_epoch_step,
250
- )
251
  else:
252
- progress_bar = tqdm(
253
- train_dataloader,
254
- desc=f"Epoch {epoch+1}/{self.epochs}",
255
- unit="step",
256
- disable=not self.accelerator.is_local_main_process,
257
- )
258
 
259
  for batch in progress_bar:
260
  with self.accelerator.accumulate(self.model):
261
- text_inputs = batch["text"]
262
- mel_spec = batch["mel"].permute(0, 2, 1)
263
  mel_lengths = batch["mel_lengths"]
264
 
265
  # TODO. add duration predictor training
266
  if self.duration_predictor is not None and self.accelerator.is_local_main_process:
267
- dur_loss = self.duration_predictor(mel_spec, lens=batch.get("durations"))
268
  self.accelerator.log({"duration loss": dur_loss.item()}, step=global_step)
269
 
270
- loss, cond, pred = self.model(
271
- mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler
272
- )
273
  self.accelerator.backward(loss)
274
 
275
  if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
@@ -286,15 +233,13 @@ class Trainer:
286
 
287
  if self.accelerator.is_local_main_process:
288
  self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
289
-
290
  progress_bar.set_postfix(step=str(global_step), loss=loss.item())
291
-
292
  if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
293
  self.save_checkpoint(global_step)
294
-
295
  if global_step % self.last_per_steps == 0:
296
  self.save_checkpoint(global_step, last=True)
297
-
298
- self.save_checkpoint(global_step, last=True)
299
-
300
  self.accelerator.end_training()
 
10
  from torch.utils.data import DataLoader, Dataset, SequentialSampler
11
  from torch.optim.lr_scheduler import LinearLR, SequentialLR
12
 
13
+ from einops import rearrange
14
+
15
  from accelerate import Accelerator
16
  from accelerate.utils import DistributedDataParallelKwargs
17
 
 
24
 
25
  # trainer
26
 
 
27
  class Trainer:
28
  def __init__(
29
  self,
30
  model: CFM,
31
  epochs,
32
  learning_rate,
33
+ num_warmup_updates = 20000,
34
+ save_per_updates = 1000,
35
+ checkpoint_path = None,
36
+ batch_size = 32,
37
  batch_size_type: str = "sample",
38
+ max_samples = 32,
39
+ grad_accumulation_steps = 1,
40
+ max_grad_norm = 1.0,
41
  noise_scheduler: str | None = None,
42
  duration_predictor: torch.nn.Module | None = None,
43
+ wandb_project = "test_e2-tts",
44
+ wandb_run_name = "test_run",
45
  wandb_resume_id: str = None,
46
+ last_per_steps = None,
47
  accelerate_kwargs: dict = dict(),
48
+ ema_kwargs: dict = dict()
 
49
  ):
50
+
51
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters = True)
 
 
52
 
53
  self.accelerator = Accelerator(
54
+ log_with = "wandb",
55
+ kwargs_handlers = [ddp_kwargs],
56
+ gradient_accumulation_steps = grad_accumulation_steps,
57
+ **accelerate_kwargs
58
  )
59
+
60
+ if exists(wandb_resume_id):
61
+ init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name, 'id': wandb_resume_id}}
62
+ else:
63
+ init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name}}
64
+ self.accelerator.init_trackers(
65
+ project_name = wandb_project,
66
+ init_kwargs=init_kwargs,
67
+ config={"epochs": epochs,
 
 
68
  "learning_rate": learning_rate,
69
+ "num_warmup_updates": num_warmup_updates,
70
  "batch_size": batch_size,
71
  "batch_size_type": batch_size_type,
72
  "max_samples": max_samples,
73
  "grad_accumulation_steps": grad_accumulation_steps,
74
  "max_grad_norm": max_grad_norm,
75
  "gpus": self.accelerator.num_processes,
76
+ "noise_scheduler": noise_scheduler}
 
77
  )
78
 
79
  self.model = model
80
 
81
  if self.is_main:
82
+ self.ema_model = EMA(
83
+ model,
84
+ include_online_model = False,
85
+ **ema_kwargs
86
+ )
87
 
88
  self.ema_model.to(self.accelerator.device)
89
 
 
91
  self.num_warmup_updates = num_warmup_updates
92
  self.save_per_updates = save_per_updates
93
  self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
94
+ self.checkpoint_path = default(checkpoint_path, 'ckpts/test_e2-tts')
95
 
96
  self.batch_size = batch_size
97
  self.batch_size_type = batch_size_type
 
103
 
104
  self.duration_predictor = duration_predictor
105
 
106
+ self.optimizer = AdamW(model.parameters(), lr=learning_rate)
107
+ self.model, self.optimizer = self.accelerator.prepare(
108
+ self.model, self.optimizer
109
+ )
 
 
 
110
 
111
  @property
112
  def is_main(self):
 
116
  self.accelerator.wait_for_everyone()
117
  if self.is_main:
118
  checkpoint = dict(
119
+ model_state_dict = self.accelerator.unwrap_model(self.model).state_dict(),
120
+ optimizer_state_dict = self.accelerator.unwrap_model(self.optimizer).state_dict(),
121
+ ema_model_state_dict = self.ema_model.state_dict(),
122
+ scheduler_state_dict = self.scheduler.state_dict(),
123
+ step = step
124
  )
125
  if not os.path.exists(self.checkpoint_path):
126
  os.makedirs(self.checkpoint_path)
127
+ if last:
128
  self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
129
  print(f"Saved last checkpoint at step {step}")
130
  else:
131
  self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")
132
 
133
  def load_checkpoint(self):
134
+ if not exists(self.checkpoint_path) or not os.path.exists(self.checkpoint_path) or not os.listdir(self.checkpoint_path):
 
 
 
 
135
  return 0
136
+
137
  self.accelerator.wait_for_everyone()
138
  if "model_last.pt" in os.listdir(self.checkpoint_path):
139
  latest_checkpoint = "model_last.pt"
140
  else:
141
+ latest_checkpoint = sorted(os.listdir(self.checkpoint_path), key=lambda x: int(''.join(filter(str.isdigit, x))))[-1]
 
 
 
142
  # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
143
+ checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location="cpu")
144
+ self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict'])
145
+ self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint['optimizer_state_dict'])
146
 
147
  if self.is_main:
148
+ self.ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
149
+
150
+ if self.scheduler:
151
+ self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
152
+
153
+ step = checkpoint['step']
154
+ del checkpoint; gc.collect()
 
 
 
 
 
 
 
 
 
 
 
 
155
  return step
156
 
157
  def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
158
+
159
  if exists(resumable_with_seed):
160
  generator = torch.Generator()
161
  generator.manual_seed(resumable_with_seed)
162
+ else:
163
  generator = None
164
 
165
  if self.batch_size_type == "sample":
166
+ train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True,
167
+ batch_size=self.batch_size, shuffle=True, generator=generator)
 
 
 
 
 
 
 
 
168
  elif self.batch_size_type == "frame":
169
  self.accelerator.even_batches = False
170
  sampler = SequentialSampler(train_dataset)
171
+ batch_sampler = DynamicBatchSampler(sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False)
172
+ train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True,
173
+ batch_sampler=batch_sampler)
 
 
 
 
 
 
 
 
174
  else:
175
+ raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but recieved {self.batch_size_type}")
176
+
177
  # accelerator.prepare() dispatches batches to devices;
178
  # which means the length of dataloader calculated before, should consider the number of devices
179
+ warmup_steps = self.num_warmup_updates * self.accelerator.num_processes # consider a fixed warmup steps while using accelerate multi-gpu ddp
180
+ # otherwise by default with split_batches=False, warmup steps change with num_processes
 
 
181
  total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
182
  decay_steps = total_steps - warmup_steps
183
  warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
184
  decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
185
+ self.scheduler = SequentialLR(self.optimizer,
186
+ schedulers=[warmup_scheduler, decay_scheduler],
187
+ milestones=[warmup_steps])
188
+ train_dataloader, self.scheduler = self.accelerator.prepare(train_dataloader, self.scheduler) # actual steps = 1 gpu steps / gpus
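The schedule built here is a linear warmup followed by a linear decay, stitched together with SequentialLR; warmup_steps is multiplied by num_processes because accelerator.prepare() shards the dataloader, so each process sees fewer steps per epoch. A toy sketch with made-up step counts:

    import torch
    from torch.optim import AdamW
    from torch.optim.lr_scheduler import LinearLR, SequentialLR

    model = torch.nn.Linear(8, 8)
    optimizer = AdamW(model.parameters(), lr=1e-4)

    warmup_steps, decay_steps = 100, 900
    warmup = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
    decay  = LinearLR(optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
    scheduler = SequentialLR(optimizer, schedulers=[warmup, decay], milestones=[warmup_steps])

    for step in range(warmup_steps + decay_steps):
        optimizer.step()
        scheduler.step()   # LR ramps up for 100 steps, then decays linearly back toward zero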
 
 
189
  start_step = self.load_checkpoint()
190
  global_step = start_step
191
 
 
200
  for epoch in range(skipped_epoch, self.epochs):
201
  self.model.train()
202
  if exists(resumable_with_seed) and epoch == skipped_epoch:
203
+ progress_bar = tqdm(skipped_dataloader, desc=f"Epoch {epoch+1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process,
204
+ initial=skipped_batch, total=orig_epoch_step)
 
 
 
 
 
 
205
  else:
206
+ progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process)
 
 
 
 
 
207
 
208
  for batch in progress_bar:
209
  with self.accelerator.accumulate(self.model):
210
+ text_inputs = batch['text']
211
+ mel_spec = rearrange(batch['mel'], 'b d n -> b n d')
212
  mel_lengths = batch["mel_lengths"]
213
 
214
  # TODO. add duration predictor training
215
  if self.duration_predictor is not None and self.accelerator.is_local_main_process:
216
+ dur_loss = self.duration_predictor(mel_spec, lens=batch.get('durations'))
217
  self.accelerator.log({"duration loss": dur_loss.item()}, step=global_step)
218
 
219
+ loss, cond, pred = self.model(mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler)
 
 
220
  self.accelerator.backward(loss)
221
 
222
  if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
 
233
 
234
  if self.accelerator.is_local_main_process:
235
  self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
236
+
237
  progress_bar.set_postfix(step=str(global_step), loss=loss.item())
238
+
239
  if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
240
  self.save_checkpoint(global_step)
241
+
242
  if global_step % self.last_per_steps == 0:
243
  self.save_checkpoint(global_step, last=True)
244
+
 
 
245
  self.accelerator.end_training()
model/utils.py CHANGED
@@ -1,6 +1,7 @@
1
  from __future__ import annotations
2
 
3
  import os
 
4
  import math
5
  import random
6
  import string
@@ -8,7 +9,6 @@ from tqdm import tqdm
8
  from collections import defaultdict
9
 
10
  import matplotlib
11
-
12
  matplotlib.use("Agg")
13
  import matplotlib.pylab as plt
14
 
@@ -17,8 +17,17 @@ import torch.nn.functional as F
17
  from torch.nn.utils.rnn import pad_sequence
18
  import torchaudio
19
 
 
 
 
20
  import jieba
21
  from pypinyin import lazy_pinyin, Style
 
 
 
 
 
 
22
 
23
  from model.ecapa_tdnn import ECAPA_TDNN_SMALL
24
  from model.modules import MelSpec
@@ -26,102 +35,106 @@ from model.modules import MelSpec
26
 
27
  # seed everything
28
 
29
-
30
- def seed_everything(seed=0):
31
  random.seed(seed)
32
- os.environ["PYTHONHASHSEED"] = str(seed)
33
  torch.manual_seed(seed)
34
  torch.cuda.manual_seed(seed)
35
  torch.cuda.manual_seed_all(seed)
36
  torch.backends.cudnn.deterministic = True
37
  torch.backends.cudnn.benchmark = False
38
 
39
-
40
  # helpers
41
 
42
-
43
  def exists(v):
44
  return v is not None
45
 
46
-
47
  def default(v, d):
48
  return v if exists(v) else d
49
 
50
-
51
  # tensor helpers
52
 
 
 
 
 
53
 
54
- def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]: # noqa: F722 F821
55
  if not exists(length):
56
  length = t.amax()
57
 
58
- seq = torch.arange(length, device=t.device)
59
- return seq[None, :] < t[:, None]
60
-
61
-
62
- def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"]): # noqa: F722 F821
63
- max_seq_len = seq_len.max().item()
64
- seq = torch.arange(max_seq_len, device=start.device).long()
65
- start_mask = seq[None, :] >= start[:, None]
66
- end_mask = seq[None, :] < end[:, None]
67
- return start_mask & end_mask
68
 
 
 
 
 
 
 
 
 
69
 
70
- def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]): # noqa: F722 F821
 
 
 
71
  lengths = (frac_lengths * seq_len).long()
72
  max_start = seq_len - lengths
73
 
74
  rand = torch.rand_like(frac_lengths)
75
- start = (max_start * rand).long().clamp(min=0)
76
  end = start + lengths
77
 
78
  return mask_from_start_end_indices(seq_len, start, end)
79
 
 
 
 
 
80
 
81
- def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d"]: # noqa: F722
82
  if not exists(mask):
83
- return t.mean(dim=1)
84
 
85
- t = torch.where(mask[:, :, None], t, torch.tensor(0.0, device=t.device))
86
- num = t.sum(dim=1)
87
- den = mask.float().sum(dim=1)
88
 
89
- return num / den.clamp(min=1.0)
90
 
91
 
92
  # simple utf-8 tokenizer, since paper went character based
93
- def list_str_to_tensor(text: list[str], padding_value=-1) -> int["b nt"]: # noqa: F722
94
- list_tensors = [torch.tensor([*bytes(t, "UTF-8")]) for t in text] # ByT5 style
95
- text = pad_sequence(list_tensors, padding_value=padding_value, batch_first=True)
 
 
 
96
  return text
97
 
98
-
99
  # char tokenizer, based on custom dataset's extracted .txt file
100
  def list_str_to_idx(
101
  text: list[str] | list[list[str]],
102
  vocab_char_map: dict[str, int], # {char: idx}
103
- padding_value=-1,
104
- ) -> int["b nt"]: # noqa: F722
105
  list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
106
- text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
107
  return text
108
 
109
 
110
  # Get tokenizer
111
 
112
-
113
  def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
114
- """
115
  tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
116
  - "char" for char-wise tokenizer, need .txt vocab_file
117
  - "byte" for utf-8 tokenizer
118
- - "custom" if you're directly passing in a path to the vocab.txt you want to use
119
  vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
120
  - if use "char", derived from unfiltered character & symbol counts of custom dataset
121
- - if use "byte", set to 256 (unicode byte range)
122
- """
123
  if tokenizer in ["pinyin", "char"]:
124
- with open(f"data/{dataset_name}_{tokenizer}/vocab.txt", "r", encoding="utf-8") as f:
125
  vocab_char_map = {}
126
  for i, char in enumerate(f):
127
  vocab_char_map[char[:-1]] = i
@@ -131,31 +144,20 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
131
  elif tokenizer == "byte":
132
  vocab_char_map = None
133
  vocab_size = 256
134
- elif tokenizer == "custom":
135
- with open(dataset_name, "r", encoding="utf-8") as f:
136
- vocab_char_map = {}
137
- for i, char in enumerate(f):
138
- vocab_char_map[char[:-1]] = i
139
- vocab_size = len(vocab_char_map)
140
 
141
  return vocab_char_map, vocab_size
142
 
143
 
144
  # convert char to pinyin
145
 
146
-
147
- def convert_char_to_pinyin(text_list, polyphone=True):
148
  final_text_list = []
149
- god_knows_why_en_testset_contains_zh_quote = str.maketrans(
150
- {"“": '"', "”": '"', "‘": "'", "’": "'"}
151
- ) # in case librispeech (orig no-pc) test-clean
152
- custom_trans = str.maketrans({";": ","}) # add custom trans here, to address oov
153
  for text in text_list:
154
  char_list = []
155
  text = text.translate(god_knows_why_en_testset_contains_zh_quote)
156
- text = text.translate(custom_trans)
157
  for seg in jieba.cut(text):
158
- seg_byte_len = len(bytes(seg, "UTF-8"))
159
  if seg_byte_len == len(seg): # if pure alphabets and symbols
160
  if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
161
  char_list.append(" ")
@@ -184,7 +186,7 @@ def convert_char_to_pinyin(text_list, polyphone=True):
184
  # save spectrogram
185
  def save_spectrogram(spectrogram, path):
186
  plt.figure(figsize=(12, 4))
187
- plt.imshow(spectrogram, origin="lower", aspect="auto")
188
  plt.colorbar()
189
  plt.savefig(path)
190
  plt.close()
@@ -192,15 +194,13 @@ def save_spectrogram(spectrogram, path):
192
 
193
  # seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav
194
  def get_seedtts_testset_metainfo(metalst):
195
- f = open(metalst)
196
- lines = f.readlines()
197
- f.close()
198
  metainfo = []
199
  for line in lines:
200
- if len(line.strip().split("|")) == 5:
201
- utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split("|")
202
- elif len(line.strip().split("|")) == 4:
203
- utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
204
  gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
205
  if not os.path.isabs(prompt_wav):
206
  prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
@@ -210,20 +210,18 @@ def get_seedtts_testset_metainfo(metalst):
210
 
211
  # librispeech test-clean metainfo: gen_utt, ref_txt, ref_wav, gen_txt, gen_wav
212
  def get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path):
213
- f = open(metalst)
214
- lines = f.readlines()
215
- f.close()
216
  metainfo = []
217
  for line in lines:
218
- ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split("\t")
219
 
220
  # ref_txt = ref_txt[0] + ref_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
221
- ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-")
222
- ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac")
223
 
224
  # gen_txt = gen_txt[0] + gen_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
225
- gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-")
226
- gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac")
227
 
228
  metainfo.append((gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav))
229
 
@@ -235,30 +233,21 @@ def padded_mel_batch(ref_mels):
235
  max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
236
  padded_ref_mels = []
237
  for mel in ref_mels:
238
- padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0)
239
  padded_ref_mels.append(padded_ref_mel)
240
  padded_ref_mels = torch.stack(padded_ref_mels)
241
- padded_ref_mels = padded_ref_mels.permute(0, 2, 1)
242
  return padded_ref_mels
243
 
244
 
245
  # get prompts from metainfo containing: utt, prompt_text, prompt_wav, gt_text, gt_wav
246
 
247
-
248
  def get_inference_prompt(
249
- metainfo,
250
- speed=1.0,
251
- tokenizer="pinyin",
252
- polyphone=True,
253
- target_sample_rate=24000,
254
- n_mel_channels=100,
255
- hop_length=256,
256
- target_rms=0.1,
257
- use_truth_duration=False,
258
- infer_batch_size=1,
259
- num_buckets=200,
260
- min_secs=3,
261
- max_secs=40,
262
  ):
263
  prompts_all = []
264
 
@@ -266,15 +255,13 @@ def get_inference_prompt(
266
  max_tokens = max_secs * target_sample_rate // hop_length
267
 
268
  batch_accum = [0] * num_buckets
269
- utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = (
270
- [[] for _ in range(num_buckets)] for _ in range(6)
271
- )
272
 
273
- mel_spectrogram = MelSpec(
274
- target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length
275
- )
276
 
277
  for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."):
 
278
  # Audio
279
  ref_audio, ref_sr = torchaudio.load(prompt_wav)
280
  ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
@@ -286,11 +273,9 @@ def get_inference_prompt(
286
  ref_audio = resampler(ref_audio)
287
 
288
  # Text
289
- if len(prompt_text[-1].encode("utf-8")) == 1:
290
- prompt_text = prompt_text + " "
291
  text = [prompt_text + gt_text]
292
  if tokenizer == "pinyin":
293
- text_list = convert_char_to_pinyin(text, polyphone=polyphone)
294
  else:
295
  text_list = text
296
 
@@ -306,19 +291,19 @@ def get_inference_prompt(
306
  # # test vocoder resynthesis
307
  # ref_audio = gt_audio
308
  else:
309
- ref_text_len = len(prompt_text.encode("utf-8"))
310
- gen_text_len = len(gt_text.encode("utf-8"))
 
311
  total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
312
 
313
  # to mel spectrogram
314
  ref_mel = mel_spectrogram(ref_audio)
315
- ref_mel = ref_mel.squeeze(0)
316
 
317
  # deal with batch
318
  assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
319
- assert (
320
- min_tokens <= total_mel_len <= max_tokens
321
- ), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
322
  bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
323
 
324
  utts[bucket_i].append(utt)
@@ -332,39 +317,28 @@ def get_inference_prompt(
332
 
333
  if batch_accum[bucket_i] >= infer_batch_size:
334
  # print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}")
335
- prompts_all.append(
336
- (
337
- utts[bucket_i],
338
- ref_rms_list[bucket_i],
339
- padded_mel_batch(ref_mels[bucket_i]),
340
- ref_mel_lens[bucket_i],
341
- total_mel_lens[bucket_i],
342
- final_text_list[bucket_i],
343
- )
344
- )
345
  batch_accum[bucket_i] = 0
346
- (
347
- utts[bucket_i],
348
- ref_rms_list[bucket_i],
349
- ref_mels[bucket_i],
350
- ref_mel_lens[bucket_i],
351
- total_mel_lens[bucket_i],
352
- final_text_list[bucket_i],
353
- ) = [], [], [], [], [], []
354
 
355
  # add residual
356
  for bucket_i, bucket_frames in enumerate(batch_accum):
357
  if bucket_frames > 0:
358
- prompts_all.append(
359
- (
360
- utts[bucket_i],
361
- ref_rms_list[bucket_i],
362
- padded_mel_batch(ref_mels[bucket_i]),
363
- ref_mel_lens[bucket_i],
364
- total_mel_lens[bucket_i],
365
- final_text_list[bucket_i],
366
- )
367
- )
368
  # not only leave easy work for last workers
369
  random.seed(666)
370
  random.shuffle(prompts_all)
@@ -375,7 +349,6 @@ def get_inference_prompt(
375
  # get wav_res_ref_text of seed-tts test metalst
376
  # https://github.com/BytedanceSpeech/seed-tts-eval
377
 
378
-
379
  def get_seed_tts_test(metalst, gen_wav_dir, gpus):
380
  f = open(metalst)
381
  lines = f.readlines()
@@ -383,14 +356,14 @@ def get_seed_tts_test(metalst, gen_wav_dir, gpus):
383
 
384
  test_set_ = []
385
  for line in tqdm(lines):
386
- if len(line.strip().split("|")) == 5:
387
- utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split("|")
388
- elif len(line.strip().split("|")) == 4:
389
- utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
390
 
391
- if not os.path.exists(os.path.join(gen_wav_dir, utt + ".wav")):
392
  continue
393
- gen_wav = os.path.join(gen_wav_dir, utt + ".wav")
394
  if not os.path.isabs(prompt_wav):
395
  prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
396
 
@@ -399,69 +372,63 @@ def get_seed_tts_test(metalst, gen_wav_dir, gpus):
399
  num_jobs = len(gpus)
400
  if num_jobs == 1:
401
  return [(gpus[0], test_set_)]
402
-
403
  wav_per_job = len(test_set_) // num_jobs + 1
404
  test_set = []
405
  for i in range(num_jobs):
406
- test_set.append((gpus[i], test_set_[i * wav_per_job : (i + 1) * wav_per_job]))
407
 
408
  return test_set
409
 
410
 
411
  # get librispeech test-clean cross sentence test
412
 
413
-
414
- def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth=False):
415
  f = open(metalst)
416
  lines = f.readlines()
417
  f.close()
418
 
419
  test_set_ = []
420
  for line in tqdm(lines):
421
- ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split("\t")
422
 
423
  if eval_ground_truth:
424
- gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-")
425
- gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac")
426
  else:
427
- if not os.path.exists(os.path.join(gen_wav_dir, gen_utt + ".wav")):
428
  raise FileNotFoundError(f"Generated wav not found: {gen_utt}")
429
- gen_wav = os.path.join(gen_wav_dir, gen_utt + ".wav")
430
 
431
- ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-")
432
- ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac")
433
 
434
  test_set_.append((gen_wav, ref_wav, gen_txt))
435
 
436
  num_jobs = len(gpus)
437
  if num_jobs == 1:
438
  return [(gpus[0], test_set_)]
439
-
440
  wav_per_job = len(test_set_) // num_jobs + 1
441
  test_set = []
442
  for i in range(num_jobs):
443
- test_set.append((gpus[i], test_set_[i * wav_per_job : (i + 1) * wav_per_job]))
444
 
445
  return test_set
446
 
447
 
448
  # load asr model
449
 
450
-
451
- def load_asr_model(lang, ckpt_dir=""):
452
  if lang == "zh":
453
- from funasr import AutoModel
454
-
455
  model = AutoModel(
456
- model=os.path.join(ckpt_dir, "paraformer-zh"),
457
- # vad_model = os.path.join(ckpt_dir, "fsmn-vad"),
458
  # punc_model = os.path.join(ckpt_dir, "ct-punc"),
459
- # spk_model = os.path.join(ckpt_dir, "cam++"),
460
  disable_update=True,
461
- ) # following seed-tts setting
462
  elif lang == "en":
463
- from faster_whisper import WhisperModel
464
-
465
  model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
466
  model = WhisperModel(model_size, device="cuda", compute_type="float16")
467
  return model
@@ -469,50 +436,41 @@ def load_asr_model(lang, ckpt_dir=""):
469
 
470
  # WER Evaluation, the way Seed-TTS does
471
 
472
-
473
  def run_asr_wer(args):
474
  rank, lang, test_set, ckpt_dir = args
475
 
476
  if lang == "zh":
477
- import zhconv
478
-
479
  torch.cuda.set_device(rank)
480
  elif lang == "en":
481
  os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
482
  else:
483
- raise NotImplementedError(
484
- "lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now."
485
- )
486
 
487
- asr_model = load_asr_model(lang, ckpt_dir=ckpt_dir)
488
-
489
- from zhon.hanzi import punctuation
490
 
491
  punctuation_all = punctuation + string.punctuation
492
  wers = []
493
 
494
- from jiwer import compute_measures
495
-
496
  for gen_wav, prompt_wav, truth in tqdm(test_set):
497
  if lang == "zh":
498
  res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True)
499
  hypo = res[0]["text"]
500
- hypo = zhconv.convert(hypo, "zh-cn")
501
  elif lang == "en":
502
  segments, _ = asr_model.transcribe(gen_wav, beam_size=5, language="en")
503
- hypo = ""
504
  for segment in segments:
505
- hypo = hypo + " " + segment.text
506
 
507
  # raw_truth = truth
508
  # raw_hypo = hypo
509
 
510
  for x in punctuation_all:
511
- truth = truth.replace(x, "")
512
- hypo = hypo.replace(x, "")
513
 
514
- truth = truth.replace(" ", " ")
515
- hypo = hypo.replace(" ", " ")
516
 
517
  if lang == "zh":
518
  truth = " ".join([x for x in truth])
@@ -536,22 +494,22 @@ def run_asr_wer(args):
536
 
537
  # SIM Evaluation
538
 
539
-
540
  def run_sim(args):
541
  rank, test_set, ckpt_dir = args
542
  device = f"cuda:{rank}"
543
 
544
- model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type="wavlm_large", config_path=None)
545
- state_dict = torch.load(ckpt_dir, weights_only=True, map_location=lambda storage, loc: storage)
546
- model.load_state_dict(state_dict["model"], strict=False)
547
 
548
- use_gpu = True if torch.cuda.is_available() else False
549
  if use_gpu:
550
  model = model.cuda(device)
551
  model.eval()
552
 
553
  sim_list = []
554
  for wav1, wav2, truth in tqdm(test_set):
 
555
  wav1, sr1 = torchaudio.load(wav1)
556
  wav2, sr2 = torchaudio.load(wav2)
557
 
@@ -566,55 +524,22 @@ def run_sim(args):
566
  with torch.no_grad():
567
  emb1 = model(wav1)
568
  emb2 = model(wav2)
569
-
570
  sim = F.cosine_similarity(emb1, emb2)[0].item()
571
  # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
572
  sim_list.append(sim)
573
-
574
  return sim_list
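The SIM score reported here is just the cosine similarity of two speaker embeddings; a trivial sketch with random stand-ins for the WavLM/ECAPA-TDNN outputs:

    import torch
    import torch.nn.functional as F

    emb1 = torch.randn(1, 256)                         # embedding of generated speech
    emb2 = torch.randn(1, 256)                         # embedding of the reference prompt
    sim = F.cosine_similarity(emb1, emb2)[0].item()    # in [-1, 1]; higher = more similar voices
    print(f"{sim:.4f}")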
575
 
576
 
577
  # filter func for dirty data with many repetitions
578
 
579
-
580
- def repetition_found(text, length=2, tolerance=10):
581
  pattern_count = defaultdict(int)
582
  for i in range(len(text) - length + 1):
583
- pattern = text[i : i + length]
584
  pattern_count[pattern] += 1
585
  for pattern, count in pattern_count.items():
586
  if count > tolerance:
587
  return True
588
  return False
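repetition_found is a cheap n-gram filter for degenerate outputs. A self-contained usage sketch consolidating the lines above into one runnable function (defaults kept; the example strings are made up):

    from collections import defaultdict

    def repetition_found(text, length=2, tolerance=10):
        # count every length-`length` substring and flag the text once any of them
        # occurs more than `tolerance` times
        pattern_count = defaultdict(int)
        for i in range(len(text) - length + 1):
            pattern_count[text[i:i + length]] += 1
        return any(count > tolerance for count in pattern_count.values())

    print(repetition_found("ha" * 3, tolerance=10))    # False, only a few repeats
    print(repetition_found("ha" * 30, tolerance=10))   # True, "ha" occurs dozens of times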
589
-
590
-
591
- # load model checkpoint for inference
592
-
593
-
594
- def load_checkpoint(model, ckpt_path, device, use_ema=True):
595
- if device == "cuda":
596
- model = model.half()
597
-
598
- ckpt_type = ckpt_path.split(".")[-1]
599
- if ckpt_type == "safetensors":
600
- from safetensors.torch import load_file
601
-
602
- checkpoint = load_file(ckpt_path)
603
- else:
604
- checkpoint = torch.load(ckpt_path, weights_only=True)
605
-
606
- if use_ema:
607
- if ckpt_type == "safetensors":
608
- checkpoint = {"ema_model_state_dict": checkpoint}
609
- checkpoint["model_state_dict"] = {
610
- k.replace("ema_model.", ""): v
611
- for k, v in checkpoint["ema_model_state_dict"].items()
612
- if k not in ["initted", "step"]
613
- }
614
- model.load_state_dict(checkpoint["model_state_dict"])
615
- else:
616
- if ckpt_type == "safetensors":
617
- checkpoint = {"model_state_dict": checkpoint}
618
- model.load_state_dict(checkpoint["model_state_dict"])
619
-
620
- return model.to(device)
 
1
  from __future__ import annotations
2
 
3
  import os
4
+ import re
5
  import math
6
  import random
7
  import string
 
9
  from collections import defaultdict
10
 
11
  import matplotlib
 
12
  matplotlib.use("Agg")
13
  import matplotlib.pylab as plt
14
 
 
17
  from torch.nn.utils.rnn import pad_sequence
18
  import torchaudio
19
 
20
+ import einx
21
+ from einops import rearrange, reduce
22
+
23
  import jieba
24
  from pypinyin import lazy_pinyin, Style
25
+ import zhconv
26
+ from zhon.hanzi import punctuation
27
+ from jiwer import compute_measures
28
+
29
+ from funasr import AutoModel
30
+ from faster_whisper import WhisperModel
31
 
32
  from model.ecapa_tdnn import ECAPA_TDNN_SMALL
33
  from model.modules import MelSpec
 
35
 
36
  # seed everything
37
 
38
+ def seed_everything(seed = 0):
 
39
  random.seed(seed)
40
+ os.environ['PYTHONHASHSEED'] = str(seed)
41
  torch.manual_seed(seed)
42
  torch.cuda.manual_seed(seed)
43
  torch.cuda.manual_seed_all(seed)
44
  torch.backends.cudnn.deterministic = True
45
  torch.backends.cudnn.benchmark = False
46
 
 
47
  # helpers
48
 
 
49
  def exists(v):
50
  return v is not None
51
 
 
52
  def default(v, d):
53
  return v if exists(v) else d
54
 
 
55
  # tensor helpers
56
 
57
+ def lens_to_mask(
58
+ t: int['b'],
59
+ length: int | None = None
60
+ ) -> bool['b n']:
61
 
 
62
  if not exists(length):
63
  length = t.amax()
64
 
65
+ seq = torch.arange(length, device = t.device)
66
+ return einx.less('n, b -> b n', seq, t)
 
 
 
 
 
 
 
 
67
 
68
+ def mask_from_start_end_indices(
69
+ seq_len: int['b'],
70
+ start: int['b'],
71
+ end: int['b']
72
+ ):
73
+ max_seq_len = seq_len.max().item()
74
+ seq = torch.arange(max_seq_len, device = start.device).long()
75
+ return einx.greater_equal('n, b -> b n', seq, start) & einx.less('n, b -> b n', seq, end)
76
 
77
+ def mask_from_frac_lengths(
78
+ seq_len: int['b'],
79
+ frac_lengths: float['b']
80
+ ):
81
  lengths = (frac_lengths * seq_len).long()
82
  max_start = seq_len - lengths
83
 
84
  rand = torch.rand_like(frac_lengths)
85
+ start = (max_start * rand).long().clamp(min = 0)
86
  end = start + lengths
87
 
88
  return mask_from_start_end_indices(seq_len, start, end)
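A quick, hypothetical usage check of the fractional-length masking above (assuming torch and einx are installed and this file is importable as model/utils.py); a frac_length of 0.5 marks one random contiguous half of each sequence:

import torch
from model.utils import mask_from_frac_lengths  # import path assumed from this repo layout

seq_len = torch.tensor([10, 10])
frac = torch.tensor([0.5, 0.5])
mask = mask_from_frac_lengths(seq_len, frac)  # bool mask of shape [2, 10]
print(mask.sum(dim=-1))  # tensor([5, 5]): one contiguous 5-frame span per row, at a random offset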
89
 
90
+ def maybe_masked_mean(
91
+ t: float['b n d'],
92
+ mask: bool['b n'] = None
93
+ ) -> float['b d']:
94
 
 
95
  if not exists(mask):
96
+ return t.mean(dim = 1)
97
 
98
+ t = einx.where('b n, b n d, -> b n d', mask, t, 0.)
99
+ num = reduce(t, 'b n d -> b d', 'sum')
100
+ den = reduce(mask.float(), 'b n -> b', 'sum')
101
 
102
+ return einx.divide('b d, b -> b d', num, den.clamp(min = 1.))
103
 
104
 
105
  # simple utf-8 tokenizer, since paper went character based
106
+ def list_str_to_tensor(
107
+ text: list[str],
108
+ padding_value = -1
109
+ ) -> int['b nt']:
110
+ list_tensors = [torch.tensor([*bytes(t, 'UTF-8')]) for t in text] # ByT5 style
111
+ text = pad_sequence(list_tensors, padding_value = padding_value, batch_first = True)
112
  return text
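For illustration, a standalone re-creation of what the ByT5-style byte tokenizer above produces (the sample texts and the -1 padding value are assumptions for this sketch, not taken from elsewhere in the repo):

import torch
from torch.nn.utils.rnn import pad_sequence

texts = ["hi", "你好"]  # "你好" encodes to 6 UTF-8 bytes
byte_tensors = [torch.tensor(list(bytes(t, "UTF-8"))) for t in texts]
batch = pad_sequence(byte_tensors, padding_value=-1, batch_first=True)
print(batch.shape)  # torch.Size([2, 6]); the "hi" row is right-padded with -1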
113
 
 
114
  # char tokenizer, based on custom dataset's extracted .txt file
115
  def list_str_to_idx(
116
  text: list[str] | list[list[str]],
117
  vocab_char_map: dict[str, int], # {char: idx}
118
+ padding_value = -1
119
+ ) -> int['b nt']:
120
  list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
121
+ text = pad_sequence(list_idx_tensors, padding_value = padding_value, batch_first = True)
122
  return text
123
 
124
 
125
  # Get tokenizer
126
 
 
127
  def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
128
+ '''
129
  tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
130
  - "char" for char-wise tokenizer, need .txt vocab_file
131
  - "byte" for utf-8 tokenizer
 
132
  vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
133
  - if use "char", derived from unfiltered character & symbol counts of custom dataset
134
+ - if use "byte", set to 256 (unicode byte range)
135
+ '''
136
  if tokenizer in ["pinyin", "char"]:
137
+ with open (f"data/{dataset_name}_{tokenizer}/vocab.txt", "r") as f:
138
  vocab_char_map = {}
139
  for i, char in enumerate(f):
140
  vocab_char_map[char[:-1]] = i
 
144
  elif tokenizer == "byte":
145
  vocab_char_map = None
146
  vocab_size = 256
 
 
 
 
 
 
147
 
148
  return vocab_char_map, vocab_size
149
 
150
 
151
  # convert char to pinyin
152
 
153
+ def convert_char_to_pinyin(text_list, polyphone = True):
 
154
  final_text_list = []
155
+ god_knows_why_en_testset_contains_zh_quote = str.maketrans({'“': '"', '”': '"', '‘': "'", '’': "'"}) # in case librispeech (orig no-pc) test-clean
 
 
 
156
  for text in text_list:
157
  char_list = []
158
  text = text.translate(god_knows_why_en_testset_contains_zh_quote)
 
159
  for seg in jieba.cut(text):
160
+ seg_byte_len = len(bytes(seg, 'UTF-8'))
161
  if seg_byte_len == len(seg): # if pure alphabets and symbols
162
  if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
163
  char_list.append(" ")
 
186
  # save spectrogram
187
  def save_spectrogram(spectrogram, path):
188
  plt.figure(figsize=(12, 4))
189
+ plt.imshow(spectrogram, origin='lower', aspect='auto')
190
  plt.colorbar()
191
  plt.savefig(path)
192
  plt.close()
 
194
 
195
  # seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav
196
  def get_seedtts_testset_metainfo(metalst):
197
+ f = open(metalst); lines = f.readlines(); f.close()
 
 
198
  metainfo = []
199
  for line in lines:
200
+ if len(line.strip().split('|')) == 5:
201
+ utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split('|')
202
+ elif len(line.strip().split('|')) == 4:
203
+ utt, prompt_text, prompt_wav, gt_text = line.strip().split('|')
204
  gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
205
  if not os.path.isabs(prompt_wav):
206
  prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
 
210
 
211
  # librispeech test-clean metainfo: gen_utt, ref_txt, ref_wav, gen_txt, gen_wav
212
  def get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path):
213
+ f = open(metalst); lines = f.readlines(); f.close()
 
 
214
  metainfo = []
215
  for line in lines:
216
+ ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split('\t')
217
 
218
  # ref_txt = ref_txt[0] + ref_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
219
+ ref_spk_id, ref_chaptr_id, _ = ref_utt.split('-')
220
+ ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + '.flac')
221
 
222
  # gen_txt = gen_txt[0] + gen_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
223
+ gen_spk_id, gen_chaptr_id, _ = gen_utt.split('-')
224
+ gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + '.flac')
225
 
226
  metainfo.append((gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav))
227
 
 
233
  max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
234
  padded_ref_mels = []
235
  for mel in ref_mels:
236
+ padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value = 0)
237
  padded_ref_mels.append(padded_ref_mel)
238
  padded_ref_mels = torch.stack(padded_ref_mels)
239
+ padded_ref_mels = rearrange(padded_ref_mels, 'b d n -> b n d')
240
  return padded_ref_mels
241
 
242
 
243
  # get prompts from metainfo containing: utt, prompt_text, prompt_wav, gt_text, gt_wav
244
 
 
245
  def get_inference_prompt(
246
+ metainfo,
247
+ speed = 1., tokenizer = "pinyin", polyphone = True,
248
+ target_sample_rate = 24000, n_mel_channels = 100, hop_length = 256, target_rms = 0.1,
249
+ use_truth_duration = False,
250
+ infer_batch_size = 1, num_buckets = 200, min_secs = 3, max_secs = 40,
 
 
 
 
 
 
 
 
251
  ):
252
  prompts_all = []
253
 
 
255
  max_tokens = max_secs * target_sample_rate // hop_length
256
 
257
  batch_accum = [0] * num_buckets
258
+ utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = \
259
+ ([[] for _ in range(num_buckets)] for _ in range(6))
 
260
 
261
+ mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length)
 
 
262
 
263
  for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."):
264
+
265
  # Audio
266
  ref_audio, ref_sr = torchaudio.load(prompt_wav)
267
  ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
 
273
  ref_audio = resampler(ref_audio)
274
 
275
  # Text
 
 
276
  text = [prompt_text + gt_text]
277
  if tokenizer == "pinyin":
278
+ text_list = convert_char_to_pinyin(text, polyphone = polyphone)
279
  else:
280
  text_list = text
281
 
 
291
  # # test vocoder resynthesis
292
  # ref_audio = gt_audio
293
  else:
294
+ zh_pause_punc = r"。,、;:?!"
295
+ ref_text_len = len(prompt_text) + len(re.findall(zh_pause_punc, prompt_text))
296
+ gen_text_len = len(gt_text) + len(re.findall(zh_pause_punc, gt_text))
297
  total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
298
 
299
  # to mel spectrogram
300
  ref_mel = mel_spectrogram(ref_audio)
301
+ ref_mel = rearrange(ref_mel, '1 d n -> d n')
302
 
303
  # deal with batch
304
  assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
305
+ assert min_tokens <= total_mel_len <= max_tokens, \
306
+ f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
 
307
  bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
308
 
309
  utts[bucket_i].append(utt)
 
317
 
318
  if batch_accum[bucket_i] >= infer_batch_size:
319
  # print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}")
320
+ prompts_all.append((
321
+ utts[bucket_i],
322
+ ref_rms_list[bucket_i],
323
+ padded_mel_batch(ref_mels[bucket_i]),
324
+ ref_mel_lens[bucket_i],
325
+ total_mel_lens[bucket_i],
326
+ final_text_list[bucket_i]
327
+ ))
 
 
328
  batch_accum[bucket_i] = 0
329
+ utts[bucket_i], ref_rms_list[bucket_i], ref_mels[bucket_i], ref_mel_lens[bucket_i], total_mel_lens[bucket_i], final_text_list[bucket_i] = [], [], [], [], [], []
 
 
 
 
 
 
 
330
 
331
  # add residual
332
  for bucket_i, bucket_frames in enumerate(batch_accum):
333
  if bucket_frames > 0:
334
+ prompts_all.append((
335
+ utts[bucket_i],
336
+ ref_rms_list[bucket_i],
337
+ padded_mel_batch(ref_mels[bucket_i]),
338
+ ref_mel_lens[bucket_i],
339
+ total_mel_lens[bucket_i],
340
+ final_text_list[bucket_i]
341
+ ))
 
 
342
  # not only leave easy work for last workers
343
  random.seed(666)
344
  random.shuffle(prompts_all)
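A small worked example of the duration bucketing used above (all values assumed for illustration; min_tokens is taken to mirror the max_tokens formula, which sits outside this hunk):

import math

target_sample_rate, hop_length = 24000, 256
min_secs, max_secs, num_buckets = 3, 40, 200
min_tokens = min_secs * target_sample_rate // hop_length  # 281 frames
max_tokens = max_secs * target_sample_rate // hop_length  # 3750 frames

total_mel_len = 10 * target_sample_rate // hop_length  # a ~10 s utterance -> 937 frames
bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
print(bucket_i)  # 37: utterances of similar length land in the same bucket and are batched together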
 
349
  # get wav_res_ref_text of seed-tts test metalst
350
  # https://github.com/BytedanceSpeech/seed-tts-eval
351
 
 
352
  def get_seed_tts_test(metalst, gen_wav_dir, gpus):
353
  f = open(metalst)
354
  lines = f.readlines()
 
356
 
357
  test_set_ = []
358
  for line in tqdm(lines):
359
+ if len(line.strip().split('|')) == 5:
360
+ utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split('|')
361
+ elif len(line.strip().split('|')) == 4:
362
+ utt, prompt_text, prompt_wav, gt_text = line.strip().split('|')
363
 
364
+ if not os.path.exists(os.path.join(gen_wav_dir, utt + '.wav')):
365
  continue
366
+ gen_wav = os.path.join(gen_wav_dir, utt + '.wav')
367
  if not os.path.isabs(prompt_wav):
368
  prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
369
 
 
372
  num_jobs = len(gpus)
373
  if num_jobs == 1:
374
  return [(gpus[0], test_set_)]
375
+
376
  wav_per_job = len(test_set_) // num_jobs + 1
377
  test_set = []
378
  for i in range(num_jobs):
379
+ test_set.append((gpus[i], test_set_[i*wav_per_job:(i+1)*wav_per_job]))
380
 
381
  return test_set
382
 
383
 
384
  # get librispeech test-clean cross sentence test
385
 
386
+ def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = False):
 
387
  f = open(metalst)
388
  lines = f.readlines()
389
  f.close()
390
 
391
  test_set_ = []
392
  for line in tqdm(lines):
393
+ ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split('\t')
394
 
395
  if eval_ground_truth:
396
+ gen_spk_id, gen_chaptr_id, _ = gen_utt.split('-')
397
+ gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + '.flac')
398
  else:
399
+ if not os.path.exists(os.path.join(gen_wav_dir, gen_utt + '.wav')):
400
  raise FileNotFoundError(f"Generated wav not found: {gen_utt}")
401
+ gen_wav = os.path.join(gen_wav_dir, gen_utt + '.wav')
402
 
403
+ ref_spk_id, ref_chaptr_id, _ = ref_utt.split('-')
404
+ ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + '.flac')
405
 
406
  test_set_.append((gen_wav, ref_wav, gen_txt))
407
 
408
  num_jobs = len(gpus)
409
  if num_jobs == 1:
410
  return [(gpus[0], test_set_)]
411
+
412
  wav_per_job = len(test_set_) // num_jobs + 1
413
  test_set = []
414
  for i in range(num_jobs):
415
+ test_set.append((gpus[i], test_set_[i*wav_per_job:(i+1)*wav_per_job]))
416
 
417
  return test_set
418
 
419
 
420
  # load asr model
421
 
422
+ def load_asr_model(lang, ckpt_dir = ""):
 
423
  if lang == "zh":
 
 
424
  model = AutoModel(
425
+ model = os.path.join(ckpt_dir, "paraformer-zh"),
426
+ # vad_model = os.path.join(ckpt_dir, "fsmn-vad"),
427
  # punc_model = os.path.join(ckpt_dir, "ct-punc"),
428
+ # spk_model = os.path.join(ckpt_dir, "cam++"),
429
  disable_update=True,
430
+ ) # following seed-tts setting
431
  elif lang == "en":
 
 
432
  model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
433
  model = WhisperModel(model_size, device="cuda", compute_type="float16")
434
  return model
 
436
 
437
  # WER Evaluation, the way Seed-TTS does
438
 
 
439
  def run_asr_wer(args):
440
  rank, lang, test_set, ckpt_dir = args
441
 
442
  if lang == "zh":
 
 
443
  torch.cuda.set_device(rank)
444
  elif lang == "en":
445
  os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
446
  else:
447
+ raise NotImplementedError("lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now.")
 
 
448
 
449
+ asr_model = load_asr_model(lang, ckpt_dir = ckpt_dir)
 
 
450
 
451
  punctuation_all = punctuation + string.punctuation
452
  wers = []
453
 
 
 
454
  for gen_wav, prompt_wav, truth in tqdm(test_set):
455
  if lang == "zh":
456
  res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True)
457
  hypo = res[0]["text"]
458
+ hypo = zhconv.convert(hypo, 'zh-cn')
459
  elif lang == "en":
460
  segments, _ = asr_model.transcribe(gen_wav, beam_size=5, language="en")
461
+ hypo = ''
462
  for segment in segments:
463
+ hypo = hypo + ' ' + segment.text
464
 
465
  # raw_truth = truth
466
  # raw_hypo = hypo
467
 
468
  for x in punctuation_all:
469
+ truth = truth.replace(x, '')
470
+ hypo = hypo.replace(x, '')
471
 
472
+ truth = truth.replace('  ', ' ')
473
+ hypo = hypo.replace('  ', ' ')
474
 
475
  if lang == "zh":
476
  truth = " ".join([x for x in truth])
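As a rough illustration of the metric being prepared here (a minimal sketch, not the repo's exact WER call): with punctuation stripped and Chinese split per character, jiwer effectively scores character-level errors.

from jiwer import compute_measures

truth = "今 天 天 气 不 错"
hypo = "今 天 天 汽 不 错"
wer = compute_measures(truth, hypo)["wer"]  # (S + D + I) / reference length
print(f"{wer:.3f}")  # 0.167: one substituted character out of six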
 
494
 
495
  # SIM Evaluation
496
 
 
497
  def run_sim(args):
498
  rank, test_set, ckpt_dir = args
499
  device = f"cuda:{rank}"
500
 
501
+ model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wavlm_large', config_path=None)
502
+ state_dict = torch.load(ckpt_dir, map_location=lambda storage, loc: storage)
503
+ model.load_state_dict(state_dict['model'], strict=False)
504
 
505
+ use_gpu=True if torch.cuda.is_available() else False
506
  if use_gpu:
507
  model = model.cuda(device)
508
  model.eval()
509
 
510
  sim_list = []
511
  for wav1, wav2, truth in tqdm(test_set):
512
+
513
  wav1, sr1 = torchaudio.load(wav1)
514
  wav2, sr2 = torchaudio.load(wav2)
515
 
 
524
  with torch.no_grad():
525
  emb1 = model(wav1)
526
  emb2 = model(wav2)
527
+
528
  sim = F.cosine_similarity(emb1, emb2)[0].item()
529
  # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
530
  sim_list.append(sim)
531
+
532
  return sim_list
533
 
534
 
535
  # filter func for dirty data with many repetitions
536
 
537
+ def repetition_found(text, length = 2, tolerance = 10):
 
538
  pattern_count = defaultdict(int)
539
  for i in range(len(text) - length + 1):
540
+ pattern = text[i:i + length]
541
  pattern_count[pattern] += 1
542
  for pattern, count in pattern_count.items():
543
  if count > tolerance:
544
  return True
545
  return False
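A hypothetical usage example of the filter above (assuming model/utils.py is importable): any n-gram occurring more than tolerance times flags the transcript as dirty.

from model.utils import repetition_found  # import path assumed from this repo layout

print(repetition_found("ha" * 14, length=2, tolerance=10))  # True: "ha" occurs 14 times
print(repetition_found("a clean transcript", length=2, tolerance=10))  # False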
 
model/utils_infer.py DELETED
@@ -1,357 +0,0 @@
1
- # A unified script for inference process
2
- # Make adjustments inside functions, and consider both gradio and cli scripts if need to change func output format
3
-
4
- import re
5
- import tempfile
6
-
7
- import numpy as np
8
- import torch
9
- import torchaudio
10
- import tqdm
11
- from pydub import AudioSegment, silence
12
- from transformers import pipeline
13
- from vocos import Vocos
14
-
15
- from model import CFM
16
- from model.utils import (
17
- load_checkpoint,
18
- get_tokenizer,
19
- convert_char_to_pinyin,
20
- )
21
-
22
-
23
- device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
24
-
25
- vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
26
-
27
-
28
- # -----------------------------------------
29
-
30
- target_sample_rate = 24000
31
- n_mel_channels = 100
32
- hop_length = 256
33
- target_rms = 0.1
34
- cross_fade_duration = 0.15
35
- ode_method = "euler"
36
- nfe_step = 32 # 16, 32
37
- cfg_strength = 2.0
38
- sway_sampling_coef = -1.0
39
- speed = 1.0
40
- fix_duration = None
41
-
42
- # -----------------------------------------
43
-
44
-
45
- # chunk text into smaller pieces
46
-
47
-
48
- def chunk_text(text, max_chars=135):
49
- """
50
- Splits the input text into chunks, each with a maximum number of characters.
51
-
52
- Args:
53
- text (str): The text to be split.
54
- max_chars (int): The maximum number of characters per chunk.
55
-
56
- Returns:
57
- List[str]: A list of text chunks.
58
- """
59
- chunks = []
60
- current_chunk = ""
61
- # Split the text into sentences based on punctuation followed by whitespace
62
- sentences = re.split(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])", text)
63
-
64
- for sentence in sentences:
65
- if len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8")) <= max_chars:
66
- current_chunk += sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
67
- else:
68
- if current_chunk:
69
- chunks.append(current_chunk.strip())
70
- current_chunk = sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
71
-
72
- if current_chunk:
73
- chunks.append(current_chunk.strip())
74
-
75
- return chunks
76
-
77
-
78
- # load vocoder
79
- def load_vocoder(is_local=False, local_path="", device=device):
80
- if is_local:
81
- print(f"Load vocos from local path {local_path}")
82
- vocos = Vocos.from_hparams(f"{local_path}/config.yaml")
83
- state_dict = torch.load(f"{local_path}/pytorch_model.bin", map_location=device)
84
- vocos.load_state_dict(state_dict)
85
- vocos.eval()
86
- else:
87
- print("Download Vocos from huggingface charactr/vocos-mel-24khz")
88
- vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
89
- return vocos
90
-
91
-
92
- # load asr pipeline
93
-
94
- asr_pipe = None
95
-
96
-
97
- def initialize_asr_pipeline(device=device):
98
- global asr_pipe
99
- asr_pipe = pipeline(
100
- "automatic-speech-recognition",
101
- model="openai/whisper-large-v3-turbo",
102
- torch_dtype=torch.float16,
103
- device=device,
104
- )
105
-
106
-
107
- # load model for inference
108
-
109
-
110
- def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_method, use_ema=True, device=device):
111
- if vocab_file == "":
112
- vocab_file = "Emilia_ZH_EN"
113
- tokenizer = "pinyin"
114
- else:
115
- tokenizer = "custom"
116
-
117
- print("\nvocab : ", vocab_file)
118
- print("tokenizer : ", tokenizer)
119
- print("model : ", ckpt_path, "\n")
120
-
121
- vocab_char_map, vocab_size = get_tokenizer(vocab_file, tokenizer)
122
- model = CFM(
123
- transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
124
- mel_spec_kwargs=dict(
125
- target_sample_rate=target_sample_rate,
126
- n_mel_channels=n_mel_channels,
127
- hop_length=hop_length,
128
- ),
129
- odeint_kwargs=dict(
130
- method=ode_method,
131
- ),
132
- vocab_char_map=vocab_char_map,
133
- ).to(device)
134
-
135
- model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema)
136
-
137
- return model
138
-
139
-
140
- # preprocess reference audio and text
141
-
142
-
143
- def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print, device=device):
144
- show_info("Converting audio...")
145
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
146
- aseg = AudioSegment.from_file(ref_audio_orig)
147
-
148
- non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000)
149
- non_silent_wave = AudioSegment.silent(duration=0)
150
- for non_silent_seg in non_silent_segs:
151
- non_silent_wave += non_silent_seg
152
- aseg = non_silent_wave
153
-
154
- audio_duration = len(aseg)
155
- if audio_duration > 15000:
156
- show_info("Audio is over 15s, clipping to only first 15s.")
157
- aseg = aseg[:15000]
158
- aseg.export(f.name, format="wav")
159
- ref_audio = f.name
160
-
161
- if not ref_text.strip():
162
- global asr_pipe
163
- if asr_pipe is None:
164
- initialize_asr_pipeline(device=device)
165
- show_info("No reference text provided, transcribing reference audio...")
166
- ref_text = asr_pipe(
167
- ref_audio,
168
- chunk_length_s=30,
169
- batch_size=128,
170
- generate_kwargs={"task": "transcribe"},
171
- return_timestamps=False,
172
- )["text"].strip()
173
- show_info("Finished transcription")
174
- else:
175
- show_info("Using custom reference text...")
176
-
177
- # Add the functionality to ensure it ends with ". "
178
- if not ref_text.endswith(". ") and not ref_text.endswith("。"):
179
- if ref_text.endswith("."):
180
- ref_text += " "
181
- else:
182
- ref_text += ". "
183
-
184
- return ref_audio, ref_text
185
-
186
-
187
- # infer process: chunk text -> infer batches [i.e. infer_batch_process()]
188
-
189
-
190
- def infer_process(
191
- ref_audio,
192
- ref_text,
193
- gen_text,
194
- model_obj,
195
- show_info=print,
196
- progress=tqdm,
197
- target_rms=target_rms,
198
- cross_fade_duration=cross_fade_duration,
199
- nfe_step=nfe_step,
200
- cfg_strength=cfg_strength,
201
- sway_sampling_coef=sway_sampling_coef,
202
- speed=speed,
203
- fix_duration=fix_duration,
204
- device=device,
205
- ):
206
- # Split the input text into batches
207
- audio, sr = torchaudio.load(ref_audio)
208
- max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
209
- gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
210
- for i, gen_text in enumerate(gen_text_batches):
211
- print(f"gen_text {i}", gen_text)
212
-
213
- show_info(f"Generating audio in {len(gen_text_batches)} batches...")
214
- return infer_batch_process(
215
- (audio, sr),
216
- ref_text,
217
- gen_text_batches,
218
- model_obj,
219
- progress=progress,
220
- target_rms=target_rms,
221
- cross_fade_duration=cross_fade_duration,
222
- nfe_step=nfe_step,
223
- cfg_strength=cfg_strength,
224
- sway_sampling_coef=sway_sampling_coef,
225
- speed=speed,
226
- fix_duration=fix_duration,
227
- device=device,
228
- )
229
-
230
-
231
- # infer batches
232
-
233
-
234
- def infer_batch_process(
235
- ref_audio,
236
- ref_text,
237
- gen_text_batches,
238
- model_obj,
239
- progress=tqdm,
240
- target_rms=0.1,
241
- cross_fade_duration=0.15,
242
- nfe_step=32,
243
- cfg_strength=2.0,
244
- sway_sampling_coef=-1,
245
- speed=1,
246
- fix_duration=None,
247
- device=None,
248
- ):
249
- audio, sr = ref_audio
250
- if audio.shape[0] > 1:
251
- audio = torch.mean(audio, dim=0, keepdim=True)
252
-
253
- rms = torch.sqrt(torch.mean(torch.square(audio)))
254
- if rms < target_rms:
255
- audio = audio * target_rms / rms
256
- if sr != target_sample_rate:
257
- resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
258
- audio = resampler(audio)
259
- audio = audio.to(device)
260
-
261
- generated_waves = []
262
- spectrograms = []
263
-
264
- if len(ref_text[-1].encode("utf-8")) == 1:
265
- ref_text = ref_text + " "
266
- for i, gen_text in enumerate(progress.tqdm(gen_text_batches)):
267
- # Prepare the text
268
- text_list = [ref_text + gen_text]
269
- final_text_list = convert_char_to_pinyin(text_list)
270
-
271
- ref_audio_len = audio.shape[-1] // hop_length
272
- if fix_duration is not None:
273
- duration = int(fix_duration * target_sample_rate / hop_length)
274
- else:
275
- # Calculate duration
276
- ref_text_len = len(ref_text.encode("utf-8"))
277
- gen_text_len = len(gen_text.encode("utf-8"))
278
- duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
279
-
280
- # inference
281
- with torch.inference_mode():
282
- generated, _ = model_obj.sample(
283
- cond=audio,
284
- text=final_text_list,
285
- duration=duration,
286
- steps=nfe_step,
287
- cfg_strength=cfg_strength,
288
- sway_sampling_coef=sway_sampling_coef,
289
- )
290
-
291
- generated = generated.to(torch.float32)
292
- generated = generated[:, ref_audio_len:, :]
293
- generated_mel_spec = generated.permute(0, 2, 1)
294
- generated_wave = vocos.decode(generated_mel_spec.cpu())
295
- if rms < target_rms:
296
- generated_wave = generated_wave * rms / target_rms
297
-
298
- # wav -> numpy
299
- generated_wave = generated_wave.squeeze().cpu().numpy()
300
-
301
- generated_waves.append(generated_wave)
302
- spectrograms.append(generated_mel_spec[0].cpu().numpy())
303
-
304
- # Combine all generated waves with cross-fading
305
- if cross_fade_duration <= 0:
306
- # Simply concatenate
307
- final_wave = np.concatenate(generated_waves)
308
- else:
309
- final_wave = generated_waves[0]
310
- for i in range(1, len(generated_waves)):
311
- prev_wave = final_wave
312
- next_wave = generated_waves[i]
313
-
314
- # Calculate cross-fade samples, ensuring it does not exceed wave lengths
315
- cross_fade_samples = int(cross_fade_duration * target_sample_rate)
316
- cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
317
-
318
- if cross_fade_samples <= 0:
319
- # No overlap possible, concatenate
320
- final_wave = np.concatenate([prev_wave, next_wave])
321
- continue
322
-
323
- # Overlapping parts
324
- prev_overlap = prev_wave[-cross_fade_samples:]
325
- next_overlap = next_wave[:cross_fade_samples]
326
-
327
- # Fade out and fade in
328
- fade_out = np.linspace(1, 0, cross_fade_samples)
329
- fade_in = np.linspace(0, 1, cross_fade_samples)
330
-
331
- # Cross-faded overlap
332
- cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
333
-
334
- # Combine
335
- new_wave = np.concatenate(
336
- [prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]]
337
- )
338
-
339
- final_wave = new_wave
340
-
341
- # Create a combined spectrogram
342
- combined_spectrogram = np.concatenate(spectrograms, axis=1)
343
-
344
- return final_wave, target_sample_rate, combined_spectrogram
345
-
346
-
347
- # remove silence from generated wav
348
-
349
-
350
- def remove_silence_for_generated_wav(filename):
351
- aseg = AudioSegment.from_file(filename)
352
- non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
353
- non_silent_wave = AudioSegment.silent(duration=0)
354
- for non_silent_seg in non_silent_segs:
355
- non_silent_wave += non_silent_seg
356
- aseg = non_silent_wave
357
- aseg.export(filename, format="wav")
 
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ ffmpeg
requirements.txt CHANGED
@@ -1,22 +1,27 @@
1
  accelerate>=0.33.0
2
- bitsandbytes>0.37.0
3
- cached_path
4
- click
5
  datasets
 
 
6
  ema_pytorch>=0.5.2
7
- gradio
 
8
  jieba
 
9
  librosa
10
  matplotlib
11
- numpy<=1.26.4
12
- pydub
13
  pypinyin
14
- safetensors
15
- soundfile
16
- tomli
17
  torchdiffeq
18
  tqdm>=4.65.0
19
  transformers
20
  vocos
21
  wandb
22
  x_transformers>=1.31.14
 
 
 
 
 
 
 
 
1
  accelerate>=0.33.0
 
 
 
2
  datasets
3
+ einops>=0.8.0
4
+ einx>=0.3.0
5
  ema_pytorch>=0.5.2
6
+ faster_whisper
7
+ funasr
8
  jieba
9
+ jiwer
10
  librosa
11
  matplotlib
 
 
12
  pypinyin
13
+ torch>=2.0
14
+ torchaudio>=2.3.0
 
15
  torchdiffeq
16
  tqdm>=4.65.0
17
  transformers
18
  vocos
19
  wandb
20
  x_transformers>=1.31.14
21
+ zhconv
22
+ zhon
23
+ cached_path
24
+ pydub
25
+ txtsplit
26
+ detoxify
27
+ soundfile
requirements_eval.txt DELETED
@@ -1,5 +0,0 @@
1
- faster_whisper
2
- funasr
3
- jiwer
4
- zhconv
5
- zhon
 
 
 
 
 
 
ruff.toml DELETED
@@ -1,10 +0,0 @@
1
- line-length = 120
2
- target-version = "py310"
3
-
4
- [lint]
5
- # Only ignore variables with names starting with "_".
6
- dummy-variable-rgx = "^_.*$"
7
-
8
- [lint.isort]
9
- force-single-line = true
10
- lines-after-imports = 2
 
samples/country.flac DELETED
Binary file (180 kB)
 
samples/main.flac DELETED
Binary file (279 kB)
 
samples/story.toml DELETED
@@ -1,19 +0,0 @@
1
- # F5-TTS | E2-TTS
2
- model = "F5-TTS"
3
- ref_audio = "samples/main.flac"
4
- # If an empty "", transcribes the reference audio automatically.
5
- ref_text = ""
6
- gen_text = ""
7
- # File with text to generate. Ignores the text above.
8
- gen_file = "samples/story.txt"
9
- remove_silence = true
10
- output_dir = "samples"
11
-
12
- [voices.town]
13
- ref_audio = "samples/town.flac"
14
- ref_text = ""
15
-
16
- [voices.country]
17
- ref_audio = "samples/country.flac"
18
- ref_text = ""
19
-
 
 
samples/story.txt DELETED
@@ -1 +0,0 @@
1
- A Town Mouse and a Country Mouse were acquaintances, and the Country Mouse one day invited his friend to come and see him at his home in the fields. The Town Mouse came, and they sat down to a dinner of barleycorns and roots, the latter of which had a distinctly earthy flavour. The fare was not much to the taste of the guest, and presently he broke out with [town] “My poor dear friend, you live here no better than the ants. Now, you should just see how I fare! My larder is a regular horn of plenty. You must come and stay with me, and I promise you you shall live on the fat of the land.” [main] So when he returned to town he took the Country Mouse with him, and showed him into a larder containing flour and oatmeal and figs and honey and dates. The Country Mouse had never seen anything like it, and sat down to enjoy the luxuries his friend provided: but before they had well begun, the door of the larder opened and someone came in. The two Mice scampered off and hid themselves in a narrow and exceedingly uncomfortable hole. Presently, when all was quiet, they ventured out again; but someone else came in, and off they scuttled again. This was too much for the visitor. [country] “Goodbye,” [main] said he, [country] “I’m off. You live in the lap of luxury, I can see, but you are surrounded by dangers; whereas at home I can enjoy my simple dinner of roots and corn in peace.”
 
 
samples/town.flac DELETED
Binary file (229 kB)
 
scripts/count_max_epoch.py CHANGED
@@ -1,7 +1,6 @@
1
- """ADAPTIVE BATCH SIZE"""
2
-
3
- print("Adaptive batch size: using grouping batch sampler, frames_per_gpu fixed fed in")
4
- print(" -> least padding, gather wavs with accumulated frames in a batch\n")
5
 
6
  # data
7
  total_hours = 95282
 
1
+ '''ADAPTIVE BATCH SIZE'''
2
+ print('Adaptive batch size: using grouping batch sampler, frames_per_gpu fixed fed in')
3
+ print(' -> least padding, gather wavs with accumulated frames in a batch\n')
 
4
 
5
  # data
6
  total_hours = 95282
scripts/count_params_gflops.py CHANGED
@@ -1,15 +1,13 @@
1
- import sys
2
- import os
3
-
4
  sys.path.append(os.getcwd())
5
 
6
- from model import M2_TTS, DiT
7
 
8
  import torch
9
  import thop
10
 
11
 
12
- """ ~155M """
13
  # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4)
14
  # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4, text_dim = 512, conv_layers = 4)
15
  # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2)
@@ -17,11 +15,11 @@ import thop
17
  # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4, long_skip_connection = True)
18
  # transformer = MMDiT(dim = 512, depth = 16, heads = 16, ff_mult = 2)
19
 
20
- """ ~335M """
21
  # FLOPs: 622.1 G, Params: 333.2 M
22
  # transformer = UNetT(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
23
  # FLOPs: 363.4 G, Params: 335.8 M
24
- transformer = DiT(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
25
 
26
 
27
  model = M2_TTS(transformer=transformer)
@@ -32,8 +30,6 @@ duration = 20
32
  frame_length = int(duration * target_sample_rate / hop_length)
33
  text_length = 150
34
 
35
- flops, params = thop.profile(
36
- model, inputs=(torch.randn(1, frame_length, n_mel_channels), torch.zeros(1, text_length, dtype=torch.long))
37
- )
38
  print(f"FLOPs: {flops / 1e9} G")
39
  print(f"Params: {params / 1e6} M")
 
1
+ import sys, os
 
 
2
  sys.path.append(os.getcwd())
3
 
4
+ from model import M2_TTS, UNetT, DiT, MMDiT
5
 
6
  import torch
7
  import thop
8
 
9
 
10
+ ''' ~155M '''
11
  # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4)
12
  # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4, text_dim = 512, conv_layers = 4)
13
  # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2)
 
15
  # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4, long_skip_connection = True)
16
  # transformer = MMDiT(dim = 512, depth = 16, heads = 16, ff_mult = 2)
17
 
18
+ ''' ~335M '''
19
  # FLOPs: 622.1 G, Params: 333.2 M
20
  # transformer = UNetT(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
21
  # FLOPs: 363.4 G, Params: 335.8 M
22
+ transformer = DiT(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
23
 
24
 
25
  model = M2_TTS(transformer=transformer)
 
30
  frame_length = int(duration * target_sample_rate / hop_length)
31
  text_length = 150
32
 
33
+ flops, params = thop.profile(model, inputs=(torch.randn(1, frame_length, n_mel_channels), torch.zeros(1, text_length, dtype=torch.long)))
 
 
34
  print(f"FLOPs: {flops / 1e9} G")
35
  print(f"Params: {params / 1e6} M")
scripts/eval_infer_batch.sh DELETED
@@ -1,13 +0,0 @@
1
- #!/bin/bash
2
-
3
- # e.g. F5-TTS, 16 NFE
4
- accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16
5
- accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16
6
- accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16
7
-
8
- # e.g. Vanilla E2 TTS, 32 NFE
9
- accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0
10
- accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0
11
- accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0
12
-
13
- # etc.
 
scripts/eval_librispeech_test_clean.py CHANGED
@@ -1,8 +1,6 @@
1
  # Evaluate with Librispeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation)
2
 
3
- import sys
4
- import os
5
-
6
  sys.path.append(os.getcwd())
7
 
8
  import multiprocessing as mp
@@ -21,7 +19,7 @@ metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
21
  librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
22
  gen_wav_dir = "PATH_TO_GENERATED" # generated wavs
23
 
24
- gpus = [0, 1, 2, 3, 4, 5, 6, 7]
25
  test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path)
26
 
27
  ## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book,
@@ -48,7 +46,7 @@ if eval_task == "wer":
48
  for wers_ in results:
49
  wers.extend(wers_)
50
 
51
- wer = round(np.mean(wers) * 100, 3)
52
  print(f"\nTotal {len(wers)} samples")
53
  print(f"WER : {wer}%")
54
 
@@ -64,6 +62,6 @@ if eval_task == "sim":
64
  for sim_ in results:
65
  sim_list.extend(sim_)
66
 
67
- sim = round(sum(sim_list) / len(sim_list), 3)
68
  print(f"\nTotal {len(sim_list)} samples")
69
  print(f"SIM : {sim}")
 
1
  # Evaluate with Librispeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation)
2
 
3
+ import sys, os
 
 
4
  sys.path.append(os.getcwd())
5
 
6
  import multiprocessing as mp
 
19
  librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
20
  gen_wav_dir = "PATH_TO_GENERATED" # generated wavs
21
 
22
+ gpus = [0,1,2,3,4,5,6,7]
23
  test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path)
24
 
25
  ## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book,
 
46
  for wers_ in results:
47
  wers.extend(wers_)
48
 
49
+ wer = round(np.mean(wers)*100, 3)
50
  print(f"\nTotal {len(wers)} samples")
51
  print(f"WER : {wer}%")
52
 
 
62
  for sim_ in results:
63
  sim_list.extend(sim_)
64
 
65
+ sim = round(sum(sim_list)/len(sim_list), 3)
66
  print(f"\nTotal {len(sim_list)} samples")
67
  print(f"SIM : {sim}")
scripts/eval_seedtts_testset.py CHANGED
@@ -1,8 +1,6 @@
1
  # Evaluate with Seed-TTS testset
2
 
3
- import sys
4
- import os
5
-
6
  sys.path.append(os.getcwd())
7
 
8
  import multiprocessing as mp
@@ -16,21 +14,21 @@ from model.utils import (
16
 
17
 
18
  eval_task = "wer" # sim | wer
19
- lang = "zh" # zh | en
20
  metalst = f"data/seedtts_testset/{lang}/meta.lst" # seed-tts testset
21
  # gen_wav_dir = f"data/seedtts_testset/{lang}/wavs" # ground truth wavs
22
- gen_wav_dir = "PATH_TO_GENERATED" # generated wavs
23
 
24
 
25
  # NOTE. paraformer-zh result will be slightly different according to the number of gpus, cuz batchsize is different
26
- # zh 1.254 seems a result of 4 workers wer_seed_tts
27
- gpus = [0, 1, 2, 3, 4, 5, 6, 7]
28
  test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus)
29
 
30
  local = False
31
  if local: # use local custom checkpoint dir
32
  if lang == "zh":
33
- asr_ckpt_dir = "../checkpoints/funasr" # paraformer-zh dir under funasr
34
  elif lang == "en":
35
  asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
36
  else:
@@ -50,7 +48,7 @@ if eval_task == "wer":
50
  for wers_ in results:
51
  wers.extend(wers_)
52
 
53
- wer = round(np.mean(wers) * 100, 3)
54
  print(f"\nTotal {len(wers)} samples")
55
  print(f"WER : {wer}%")
56
 
@@ -66,6 +64,6 @@ if eval_task == "sim":
66
  for sim_ in results:
67
  sim_list.extend(sim_)
68
 
69
- sim = round(sum(sim_list) / len(sim_list), 3)
70
  print(f"\nTotal {len(sim_list)} samples")
71
  print(f"SIM : {sim}")
 
1
  # Evaluate with Seed-TTS testset
2
 
3
+ import sys, os
 
 
4
  sys.path.append(os.getcwd())
5
 
6
  import multiprocessing as mp
 
14
 
15
 
16
  eval_task = "wer" # sim | wer
17
+ lang = "zh" # zh | en
18
  metalst = f"data/seedtts_testset/{lang}/meta.lst" # seed-tts testset
19
  # gen_wav_dir = f"data/seedtts_testset/{lang}/wavs" # ground truth wavs
20
+ gen_wav_dir = f"PATH_TO_GENERATED" # generated wavs
21
 
22
 
23
  # NOTE. paraformer-zh result will be slightly different according to the number of gpus, cuz batchsize is different
24
+ # zh 1.254 seems a result of 4 workers wer_seed_tts
25
+ gpus = [0,1,2,3,4,5,6,7]
26
  test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus)
27
 
28
  local = False
29
  if local: # use local custom checkpoint dir
30
  if lang == "zh":
31
+ asr_ckpt_dir = "../checkpoints/funasr" # paraformer-zh dir under funasr
32
  elif lang == "en":
33
  asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
34
  else:
 
48
  for wers_ in results:
49
  wers.extend(wers_)
50
 
51
+ wer = round(np.mean(wers)*100, 3)
52
  print(f"\nTotal {len(wers)} samples")
53
  print(f"WER : {wer}%")
54
 
 
64
  for sim_ in results:
65
  sim_list.extend(sim_)
66
 
67
+ sim = round(sum(sim_list)/len(sim_list), 3)
68
  print(f"\nTotal {len(sim_list)} samples")
69
  print(f"SIM : {sim}")
scripts/prepare_csv_wavs.py DELETED
@@ -1,138 +0,0 @@
1
- import sys
2
- import os
3
-
4
- sys.path.append(os.getcwd())
5
-
6
- from pathlib import Path
7
- import json
8
- import shutil
9
- import argparse
10
-
11
- import csv
12
- import torchaudio
13
- from tqdm import tqdm
14
- from datasets.arrow_writer import ArrowWriter
15
-
16
- from model.utils import (
17
- convert_char_to_pinyin,
18
- )
19
-
20
- PRETRAINED_VOCAB_PATH = Path(__file__).parent.parent / "data/Emilia_ZH_EN_pinyin/vocab.txt"
21
-
22
-
23
- def is_csv_wavs_format(input_dataset_dir):
24
- fpath = Path(input_dataset_dir)
25
- metadata = fpath / "metadata.csv"
26
- wavs = fpath / "wavs"
27
- return metadata.exists() and metadata.is_file() and wavs.exists() and wavs.is_dir()
28
-
29
-
30
- def prepare_csv_wavs_dir(input_dir):
31
- assert is_csv_wavs_format(input_dir), f"not csv_wavs format: {input_dir}"
32
- input_dir = Path(input_dir)
33
- metadata_path = input_dir / "metadata.csv"
34
- audio_path_text_pairs = read_audio_text_pairs(metadata_path.as_posix())
35
-
36
- sub_result, durations = [], []
37
- vocab_set = set()
38
- polyphone = True
39
- for audio_path, text in audio_path_text_pairs:
40
- if not Path(audio_path).exists():
41
- print(f"audio {audio_path} not found, skipping")
42
- continue
43
- audio_duration = get_audio_duration(audio_path)
44
- # assume tokenizer = "pinyin" ("pinyin" | "char")
45
- text = convert_char_to_pinyin([text], polyphone=polyphone)[0]
46
- sub_result.append({"audio_path": audio_path, "text": text, "duration": audio_duration})
47
- durations.append(audio_duration)
48
- vocab_set.update(list(text))
49
-
50
- return sub_result, durations, vocab_set
51
-
52
-
53
- def get_audio_duration(audio_path):
54
- audio, sample_rate = torchaudio.load(audio_path)
55
- num_channels = audio.shape[0]
56
- return audio.shape[1] / (sample_rate * num_channels)
57
-
58
-
59
- def read_audio_text_pairs(csv_file_path):
60
- audio_text_pairs = []
61
-
62
- parent = Path(csv_file_path).parent
63
- with open(csv_file_path, mode="r", newline="", encoding="utf-8") as csvfile:
64
- reader = csv.reader(csvfile, delimiter="|")
65
- next(reader) # Skip the header row
66
- for row in reader:
67
- if len(row) >= 2:
68
- audio_file = row[0].strip() # First column: audio file path
69
- text = row[1].strip() # Second column: text
70
- audio_file_path = parent / audio_file
71
- audio_text_pairs.append((audio_file_path.as_posix(), text))
72
-
73
- return audio_text_pairs
74
-
75
-
76
- def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_finetune):
77
- out_dir = Path(out_dir)
78
- # save preprocessed dataset to disk
79
- out_dir.mkdir(exist_ok=True, parents=True)
80
- print(f"\nSaving to {out_dir} ...")
81
-
82
- # dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) # oom
83
- # dataset.save_to_disk(f"data/{dataset_name}/raw", max_shard_size="2GB")
84
- raw_arrow_path = out_dir / "raw.arrow"
85
- with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=1) as writer:
86
- for line in tqdm(result, desc="Writing to raw.arrow ..."):
87
- writer.write(line)
88
-
89
- # dup a json separately saving duration in case for DynamicBatchSampler ease
90
- dur_json_path = out_dir / "duration.json"
91
- with open(dur_json_path.as_posix(), "w", encoding="utf-8") as f:
92
- json.dump({"duration": duration_list}, f, ensure_ascii=False)
93
-
94
- # vocab map, i.e. tokenizer
95
- # add alphabets and symbols (optional, if plan to ft on de/fr etc.)
96
- # if tokenizer == "pinyin":
97
- # text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
98
- voca_out_path = out_dir / "vocab.txt"
99
- with open(voca_out_path.as_posix(), "w") as f:
100
- for vocab in sorted(text_vocab_set):
101
- f.write(vocab + "\n")
102
-
103
- if is_finetune:
104
- file_vocab_finetune = PRETRAINED_VOCAB_PATH.as_posix()
105
- shutil.copy2(file_vocab_finetune, voca_out_path)
106
- else:
107
- with open(voca_out_path, "w") as f:
108
- for vocab in sorted(text_vocab_set):
109
- f.write(vocab + "\n")
110
-
111
- dataset_name = out_dir.stem
112
- print(f"\nFor {dataset_name}, sample count: {len(result)}")
113
- print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
114
- print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
115
-
116
-
117
- def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True):
118
- if is_finetune:
119
- assert PRETRAINED_VOCAB_PATH.exists(), f"pretrained vocab.txt not found: {PRETRAINED_VOCAB_PATH}"
120
- sub_result, durations, vocab_set = prepare_csv_wavs_dir(inp_dir)
121
- save_prepped_dataset(out_dir, sub_result, durations, vocab_set, is_finetune)
122
-
123
-
124
- def cli():
125
- # finetune: python scripts/prepare_csv_wavs.py /path/to/input_dir /path/to/output_dir_pinyin
126
- # pretrain: python scripts/prepare_csv_wavs.py /path/to/output_dir_pinyin --pretrain
127
- parser = argparse.ArgumentParser(description="Prepare and save dataset.")
128
- parser.add_argument("inp_dir", type=str, help="Input directory containing the data.")
129
- parser.add_argument("out_dir", type=str, help="Output directory to save the prepared data.")
130
- parser.add_argument("--pretrain", action="store_true", help="Enable for new pretrain, otherwise is a fine-tune")
131
-
132
- args = parser.parse_args()
133
-
134
- prepare_and_save_set(args.inp_dir, args.out_dir, is_finetune=not args.pretrain)
135
-
136
-
137
- if __name__ == "__main__":
138
- cli()
 
scripts/prepare_emilia.py CHANGED
@@ -4,9 +4,7 @@
4
  # generate audio text map for Emilia ZH & EN
5
  # evaluate for vocab size
6
 
7
- import sys
8
- import os
9
-
10
  sys.path.append(os.getcwd())
11
 
12
  from pathlib import Path
@@ -14,6 +12,7 @@ import json
14
  from tqdm import tqdm
15
  from concurrent.futures import ProcessPoolExecutor
16
 
 
17
  from datasets.arrow_writer import ArrowWriter
18
 
19
  from model.utils import (
@@ -22,89 +21,13 @@ from model.utils import (
22
  )
23
 
24
 
25
- out_zh = {
26
- "ZH_B00041_S06226",
27
- "ZH_B00042_S09204",
28
- "ZH_B00065_S09430",
29
- "ZH_B00065_S09431",
30
- "ZH_B00066_S09327",
31
- "ZH_B00066_S09328",
32
- }
33
  zh_filters = ["い", "て"]
34
  # seems synthesized audios, or heavily code-switched
35
  out_en = {
36
- "EN_B00013_S00913",
37
- "EN_B00042_S00120",
38
- "EN_B00055_S04111",
39
- "EN_B00061_S00693",
40
- "EN_B00061_S01494",
41
- "EN_B00061_S03375",
42
- "EN_B00059_S00092",
43
- "EN_B00111_S04300",
44
- "EN_B00100_S03759",
45
- "EN_B00087_S03811",
46
- "EN_B00059_S00950",
47
- "EN_B00089_S00946",
48
- "EN_B00078_S05127",
49
- "EN_B00070_S04089",
50
- "EN_B00074_S09659",
51
- "EN_B00061_S06983",
52
- "EN_B00061_S07060",
53
- "EN_B00059_S08397",
54
- "EN_B00082_S06192",
55
- "EN_B00091_S01238",
56
- "EN_B00089_S07349",
57
- "EN_B00070_S04343",
58
- "EN_B00061_S02400",
59
- "EN_B00076_S01262",
60
- "EN_B00068_S06467",
61
- "EN_B00076_S02943",
62
- "EN_B00064_S05954",
63
- "EN_B00061_S05386",
64
- "EN_B00066_S06544",
65
- "EN_B00076_S06944",
66
- "EN_B00072_S08620",
67
- "EN_B00076_S07135",
68
- "EN_B00076_S09127",
69
- "EN_B00065_S00497",
70
- "EN_B00059_S06227",
71
- "EN_B00063_S02859",
72
- "EN_B00075_S01547",
73
- "EN_B00061_S08286",
74
- "EN_B00079_S02901",
75
- "EN_B00092_S03643",
76
- "EN_B00096_S08653",
77
- "EN_B00063_S04297",
78
- "EN_B00063_S04614",
79
- "EN_B00079_S04698",
80
- "EN_B00104_S01666",
81
- "EN_B00061_S09504",
82
- "EN_B00061_S09694",
83
- "EN_B00065_S05444",
84
- "EN_B00063_S06860",
85
- "EN_B00065_S05725",
86
- "EN_B00069_S07628",
87
- "EN_B00083_S03875",
88
- "EN_B00071_S07665",
89
- "EN_B00071_S07665",
90
- "EN_B00062_S04187",
91
- "EN_B00065_S09873",
92
- "EN_B00065_S09922",
93
- "EN_B00084_S02463",
94
- "EN_B00067_S05066",
95
- "EN_B00106_S08060",
96
- "EN_B00073_S06399",
97
- "EN_B00073_S09236",
98
- "EN_B00087_S00432",
99
- "EN_B00085_S05618",
100
- "EN_B00064_S01262",
101
- "EN_B00072_S01739",
102
- "EN_B00059_S03913",
103
- "EN_B00069_S04036",
104
- "EN_B00067_S05623",
105
- "EN_B00060_S05389",
106
- "EN_B00060_S07290",
107
- "EN_B00062_S08995",
108
  }
109
  en_filters = ["ا", "い", "て"]
110
 
@@ -120,24 +43,18 @@ def deal_with_audio_dir(audio_dir):
120
  for line in tqdm(lines, desc=f"{audio_jsonl.stem}"):
121
  obj = json.loads(line)
122
  text = obj["text"]
123
- if obj["language"] == "zh":
124
  if obj["wav"].split("/")[1] in out_zh or any(f in text for f in zh_filters) or repetition_found(text):
125
  bad_case_zh += 1
126
  continue
127
  else:
128
- text = text.translate(
129
- str.maketrans({",": ",", "!": "!", "?": "?"})
130
- ) # not "。" cuz much code-switched
131
- if obj["language"] == "en":
132
- if (
133
- obj["wav"].split("/")[1] in out_en
134
- or any(f in text for f in en_filters)
135
- or repetition_found(text, length=4)
136
- ):
137
  bad_case_en += 1
138
  continue
139
  if tokenizer == "pinyin":
140
- text = convert_char_to_pinyin([text], polyphone=polyphone)[0]
141
  duration = obj["duration"]
142
  sub_result.append({"audio_path": str(audio_dir.parent / obj["wav"]), "text": text, "duration": duration})
143
  durations.append(duration)
@@ -179,11 +96,11 @@ def main():
179
  # dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) # oom
180
  # dataset.save_to_disk(f"data/{dataset_name}/raw", max_shard_size="2GB")
181
  with ArrowWriter(path=f"data/{dataset_name}/raw.arrow") as writer:
182
- for line in tqdm(result, desc="Writing to raw.arrow ..."):
183
  writer.write(line)
184
 
185
  # dup a json separately saving duration in case for DynamicBatchSampler ease
186
- with open(f"data/{dataset_name}/duration.json", "w", encoding="utf-8") as f:
187
  json.dump({"duration": duration_list}, f, ensure_ascii=False)
188
 
189
  # vocab map, i.e. tokenizer
@@ -197,13 +114,12 @@ def main():
197
  print(f"\nFor {dataset_name}, sample count: {len(result)}")
198
  print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
199
  print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
200
- if "ZH" in langs:
201
- print(f"Bad zh transcription case: {total_bad_case_zh}")
202
- if "EN" in langs:
203
- print(f"Bad en transcription case: {total_bad_case_en}\n")
204
 
205
 
206
  if __name__ == "__main__":
 
207
  max_workers = 32
208
 
209
  tokenizer = "pinyin" # "pinyin" | "char"
 
4
  # generate audio text map for Emilia ZH & EN
5
  # evaluate for vocab size
6
 
7
+ import sys, os
 
 
8
  sys.path.append(os.getcwd())
9
 
10
  from pathlib import Path
 
12
  from tqdm import tqdm
13
  from concurrent.futures import ProcessPoolExecutor
14
 
15
+ from datasets import Dataset
16
  from datasets.arrow_writer import ArrowWriter
17
 
18
  from model.utils import (
 
21
  )
22
 
23
 
24
+ out_zh = {"ZH_B00041_S06226", "ZH_B00042_S09204", "ZH_B00065_S09430", "ZH_B00065_S09431", "ZH_B00066_S09327", "ZH_B00066_S09328"}
 
 
 
 
 
 
 
25
  zh_filters = ["い", "て"]
26
  # seems synthesized audios, or heavily code-switched
27
  out_en = {
28
+ "EN_B00013_S00913", "EN_B00042_S00120", "EN_B00055_S04111", "EN_B00061_S00693", "EN_B00061_S01494", "EN_B00061_S03375",
29
+
30
+ "EN_B00059_S00092", "EN_B00111_S04300", "EN_B00100_S03759", "EN_B00087_S03811", "EN_B00059_S00950", "EN_B00089_S00946", "EN_B00078_S05127", "EN_B00070_S04089", "EN_B00074_S09659", "EN_B00061_S06983", "EN_B00061_S07060", "EN_B00059_S08397", "EN_B00082_S06192", "EN_B00091_S01238", "EN_B00089_S07349", "EN_B00070_S04343", "EN_B00061_S02400", "EN_B00076_S01262", "EN_B00068_S06467", "EN_B00076_S02943", "EN_B00064_S05954", "EN_B00061_S05386", "EN_B00066_S06544", "EN_B00076_S06944", "EN_B00072_S08620", "EN_B00076_S07135", "EN_B00076_S09127", "EN_B00065_S00497", "EN_B00059_S06227", "EN_B00063_S02859", "EN_B00075_S01547", "EN_B00061_S08286", "EN_B00079_S02901", "EN_B00092_S03643", "EN_B00096_S08653", "EN_B00063_S04297", "EN_B00063_S04614", "EN_B00079_S04698", "EN_B00104_S01666", "EN_B00061_S09504", "EN_B00061_S09694", "EN_B00065_S05444", "EN_B00063_S06860", "EN_B00065_S05725", "EN_B00069_S07628", "EN_B00083_S03875", "EN_B00071_S07665", "EN_B00071_S07665", "EN_B00062_S04187", "EN_B00065_S09873", "EN_B00065_S09922", "EN_B00084_S02463", "EN_B00067_S05066", "EN_B00106_S08060", "EN_B00073_S06399", "EN_B00073_S09236", "EN_B00087_S00432", "EN_B00085_S05618", "EN_B00064_S01262", "EN_B00072_S01739", "EN_B00059_S03913", "EN_B00069_S04036", "EN_B00067_S05623", "EN_B00060_S05389", "EN_B00060_S07290", "EN_B00062_S08995",
 
31
  }
32
  en_filters = ["ا", "い", "て"]
33
 
 
43
  for line in tqdm(lines, desc=f"{audio_jsonl.stem}"):
44
  obj = json.loads(line)
45
  text = obj["text"]
46
+ if obj['language'] == "zh":
47
  if obj["wav"].split("/")[1] in out_zh or any(f in text for f in zh_filters) or repetition_found(text):
48
  bad_case_zh += 1
49
  continue
50
  else:
51
+ text = text.translate(str.maketrans({',': ',', '!': '!', '?': '?'})) # not "。" cuz much code-switched
52
+ if obj['language'] == "en":
53
+ if obj["wav"].split("/")[1] in out_en or any(f in text for f in en_filters) or repetition_found(text, length=4):
 
 
 
 
 
 
54
  bad_case_en += 1
55
  continue
56
  if tokenizer == "pinyin":
57
+ text = convert_char_to_pinyin([text], polyphone = polyphone)[0]
58
  duration = obj["duration"]
59
  sub_result.append({"audio_path": str(audio_dir.parent / obj["wav"]), "text": text, "duration": duration})
60
  durations.append(duration)
 
96
  # dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) # oom
97
  # dataset.save_to_disk(f"data/{dataset_name}/raw", max_shard_size="2GB")
98
  with ArrowWriter(path=f"data/{dataset_name}/raw.arrow") as writer:
99
+ for line in tqdm(result, desc=f"Writing to raw.arrow ..."):
100
  writer.write(line)
101
 
102
  # dup a json separately saving duration in case for DynamicBatchSampler ease
103
+ with open(f"data/{dataset_name}/duration.json", 'w', encoding='utf-8') as f:
104
  json.dump({"duration": duration_list}, f, ensure_ascii=False)
105
 
106
  # vocab map, i.e. tokenizer
 
114
  print(f"\nFor {dataset_name}, sample count: {len(result)}")
115
  print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
116
  print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
117
+ if "ZH" in langs: print(f"Bad zh transcription case: {total_bad_case_zh}")
118
+ if "EN" in langs: print(f"Bad en transcription case: {total_bad_case_en}\n")
 
 
119
 
120
 
121
  if __name__ == "__main__":
122
+
123
  max_workers = 32
124
 
125
  tokenizer = "pinyin" # "pinyin" | "char"
scripts/prepare_wenetspeech4tts.py CHANGED
@@ -1,9 +1,7 @@
1
  # generate audio text map for WenetSpeech4TTS
2
  # evaluate for vocab size
3
 
4
- import sys
5
- import os
6
-
7
  sys.path.append(os.getcwd())
8
 
9
  import json
@@ -25,7 +23,7 @@ def deal_with_sub_path_files(dataset_path, sub_path):
25
 
26
  audio_paths, texts, durations = [], [], []
27
  for text_file in tqdm(text_files):
28
- with open(os.path.join(text_dir, text_file), "r", encoding="utf-8") as file:
29
  first_line = file.readline().split("\t")
30
  audio_nm = first_line[0]
31
  audio_path = os.path.join(audio_dir, audio_nm + ".wav")
@@ -34,7 +32,7 @@ def deal_with_sub_path_files(dataset_path, sub_path):
34
  audio_paths.append(audio_path)
35
 
36
  if tokenizer == "pinyin":
37
- texts.extend(convert_char_to_pinyin([text], polyphone=polyphone))
38
  elif tokenizer == "char":
39
  texts.append(text)
40
 
@@ -48,7 +46,7 @@ def main():
48
  assert tokenizer in ["pinyin", "char"]
49
 
50
  audio_path_list, text_list, duration_list = [], [], []
51
-
52
  executor = ProcessPoolExecutor(max_workers=max_workers)
53
  futures = []
54
  for dataset_path in dataset_paths:
@@ -70,10 +68,8 @@ def main():
70
  dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list})
71
  dataset.save_to_disk(f"data/{dataset_name}_{tokenizer}/raw", max_shard_size="2GB") # arrow format
72
 
73
- with open(f"data/{dataset_name}_{tokenizer}/duration.json", "w", encoding="utf-8") as f:
74
- json.dump(
75
- {"duration": duration_list}, f, ensure_ascii=False
76
- ) # dup a json separately saving duration in case for DynamicBatchSampler ease
77
 
78
  print("\nEvaluating vocab size (all characters and symbols / all phonemes) ...")
79
  text_vocab_set = set()
@@ -89,21 +85,22 @@ def main():
89
  f.write(vocab + "\n")
90
  print(f"\nFor {dataset_name}, sample count: {len(text_list)}")
91
  print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}\n")
92
-
93
 
94
  if __name__ == "__main__":
 
95
  max_workers = 32
96
 
97
  tokenizer = "pinyin" # "pinyin" | "char"
98
  polyphone = True
99
  dataset_choice = 1 # 1: Premium, 2: Standard, 3: Basic
100
 
101
- dataset_name = ["WenetSpeech4TTS_Premium", "WenetSpeech4TTS_Standard", "WenetSpeech4TTS_Basic"][dataset_choice - 1]
102
  dataset_paths = [
103
  "<SOME_PATH>/WenetSpeech4TTS/Basic",
104
  "<SOME_PATH>/WenetSpeech4TTS/Standard",
105
  "<SOME_PATH>/WenetSpeech4TTS/Premium",
106
- ][-dataset_choice:]
107
  print(f"\nChoose Dataset: {dataset_name}\n")
108
 
109
  main()
@@ -112,8 +109,8 @@ if __name__ == "__main__":
112
  # WenetSpeech4TTS Basic Standard Premium
113
  # samples count 3932473 1941220 407494
114
  # pinyin vocab size 1349 1348 1344 (no polyphone)
115
- # - - 1459 (polyphone)
116
  # char vocab size 5264 5219 5042
117
-
118
  # vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. the way polyphones are handled)
119
  # please be careful if using a pretrained model, make sure the vocab.txt is the same
 
1
  # generate audio text map for WenetSpeech4TTS
2
  # evaluate for vocab size
3
 
4
+ import sys, os
 
 
5
  sys.path.append(os.getcwd())
6
 
7
  import json
 
23
 
24
  audio_paths, texts, durations = [], [], []
25
  for text_file in tqdm(text_files):
26
+ with open(os.path.join(text_dir, text_file), 'r', encoding='utf-8') as file:
27
  first_line = file.readline().split("\t")
28
  audio_nm = first_line[0]
29
  audio_path = os.path.join(audio_dir, audio_nm + ".wav")
 
32
  audio_paths.append(audio_path)
33
 
34
  if tokenizer == "pinyin":
35
+ texts.extend(convert_char_to_pinyin([text], polyphone = polyphone))
36
  elif tokenizer == "char":
37
  texts.append(text)
38
 
 
46
  assert tokenizer in ["pinyin", "char"]
47
 
48
  audio_path_list, text_list, duration_list = [], [], []
49
+
50
  executor = ProcessPoolExecutor(max_workers=max_workers)
51
  futures = []
52
  for dataset_path in dataset_paths:
 
68
  dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list})
69
  dataset.save_to_disk(f"data/{dataset_name}_{tokenizer}/raw", max_shard_size="2GB") # arrow format
70
 
71
+ with open(f"data/{dataset_name}_{tokenizer}/duration.json", 'w', encoding='utf-8') as f:
72
+ json.dump({"duration": duration_list}, f, ensure_ascii=False) # dup a json separately saving duration in case for DynamicBatchSampler ease
 
 
73
 
74
  print("\nEvaluating vocab size (all characters and symbols / all phonemes) ...")
75
  text_vocab_set = set()
 
85
  f.write(vocab + "\n")
86
  print(f"\nFor {dataset_name}, sample count: {len(text_list)}")
87
  print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}\n")
88
+
89
 
90
  if __name__ == "__main__":
91
+
92
  max_workers = 32
93
 
94
  tokenizer = "pinyin" # "pinyin" | "char"
95
  polyphone = True
96
  dataset_choice = 1 # 1: Premium, 2: Standard, 3: Basic
97
 
98
+ dataset_name = ["WenetSpeech4TTS_Premium", "WenetSpeech4TTS_Standard", "WenetSpeech4TTS_Basic"][dataset_choice-1]
99
  dataset_paths = [
100
  "<SOME_PATH>/WenetSpeech4TTS/Basic",
101
  "<SOME_PATH>/WenetSpeech4TTS/Standard",
102
  "<SOME_PATH>/WenetSpeech4TTS/Premium",
103
+ ][-dataset_choice:]
104
  print(f"\nChoose Dataset: {dataset_name}\n")
105
 
106
  main()
 
109
  # WenetSpeech4TTS Basic Standard Premium
110
  # samples count 3932473 1941220 407494
111
  # pinyin vocab size 1349 1348 1344 (no polyphone)
112
+ # - - 1459 (polyphone)
113
  # char vocab size 5264 5219 5042
114
+
115
  # vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. the way polyphones are handled)
116
  # please be careful if using a pretrained model, make sure the vocab.txt is the same
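A side note on the dataset_choice / dataset_paths logic kept in both versions above: the path list is ordered Basic, Standard, Premium, and the negative slice keeps the last dataset_choice entries, so choice 1 (Premium) processes only the Premium directory while choice 3 (Basic) aggregates all three. A quick check, with the same placeholder paths as the script:

```python
# Behavior of the [-dataset_choice:] selection in prepare_wenetspeech4tts.py.
dataset_paths = [
    "<SOME_PATH>/WenetSpeech4TTS/Basic",
    "<SOME_PATH>/WenetSpeech4TTS/Standard",
    "<SOME_PATH>/WenetSpeech4TTS/Premium",
]
for dataset_choice, name in [(1, "Premium"), (2, "Standard"), (3, "Basic")]:
    print(f"choice {dataset_choice} ({name}):", dataset_paths[-dataset_choice:])
# choice 1 (Premium): only the Premium directory
# choice 2 (Standard): Standard + Premium
# choice 3 (Basic): Basic + Standard + Premium
```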
scripts/eval_infer_batch.py → test_infer_batch.py RENAMED
@@ -1,8 +1,4 @@
1
- import sys
2
  import os
3
-
4
- sys.path.append(os.getcwd())
5
-
6
  import time
7
  import random
8
  from tqdm import tqdm
@@ -11,14 +7,15 @@ import argparse
11
  import torch
12
  import torchaudio
13
  from accelerate import Accelerator
 
 
14
  from vocos import Vocos
15
 
16
  from model import CFM, UNetT, DiT
17
  from model.utils import (
18
- load_checkpoint,
19
- get_tokenizer,
20
- get_seedtts_testset_metainfo,
21
- get_librispeech_test_clean_metainfo,
22
  get_inference_prompt,
23
  )
24
 
@@ -40,16 +37,16 @@ tokenizer = "pinyin"
40
 
41
  parser = argparse.ArgumentParser(description="batch inference")
42
 
43
- parser.add_argument("-s", "--seed", default=None, type=int)
44
- parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN")
45
- parser.add_argument("-n", "--expname", required=True)
46
- parser.add_argument("-c", "--ckptstep", default=1200000, type=int)
47
 
48
- parser.add_argument("-nfe", "--nfestep", default=32, type=int)
49
- parser.add_argument("-o", "--odemethod", default="euler")
50
- parser.add_argument("-ss", "--swaysampling", default=-1, type=float)
51
 
52
- parser.add_argument("-t", "--testset", required=True)
53
 
54
  args = parser.parse_args()
55
 
@@ -58,7 +55,7 @@ seed = args.seed
58
  dataset_name = args.dataset
59
  exp_name = args.expname
60
  ckpt_step = args.ckptstep
61
- ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
62
 
63
  nfe_step = args.nfestep
64
  ode_method = args.odemethod
@@ -68,26 +65,26 @@ testset = args.testset
68
 
69
 
70
  infer_batch_size = 1 # max frames. 1 for ddp single inference (recommended)
71
- cfg_strength = 2.0
72
- speed = 1.0
73
  use_truth_duration = False
74
  no_ref_audio = False
75
 
76
 
77
  if exp_name == "F5TTS_Base":
78
  model_cls = DiT
79
- model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
80
 
81
  elif exp_name == "E2TTS_Base":
82
  model_cls = UNetT
83
- model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
84
 
85
 
86
  if testset == "ls_pc_test_clean":
87
  metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
88
  librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
89
  metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)
90
-
91
  elif testset == "seedtts_test_zh":
92
  metalst = "data/seedtts_testset/zh/meta.lst"
93
  metainfo = get_seedtts_testset_metainfo(metalst)
@@ -98,16 +95,13 @@ elif testset == "seedtts_test_en":
98
 
99
 
100
  # path to save generated wavs
101
- if seed is None:
102
- seed = random.randint(-10000, 10000)
103
- output_dir = (
104
- f"results/{exp_name}_{ckpt_step}/{testset}/"
105
- f"seed{seed}_{ode_method}_nfe{nfe_step}"
106
- f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}"
107
- f"_cfg{cfg_strength}_speed{speed}"
108
- f"{'_gt-dur' if use_truth_duration else ''}"
109
  f"{'_no-ref-audio' if no_ref_audio else ''}"
110
- )
111
 
112
 
113
  # -------------------------------------------------#
@@ -115,15 +109,15 @@ output_dir = (
115
  use_ema = True
116
 
117
  prompts_all = get_inference_prompt(
118
- metainfo,
119
- speed=speed,
120
- tokenizer=tokenizer,
121
- target_sample_rate=target_sample_rate,
122
- n_mel_channels=n_mel_channels,
123
- hop_length=hop_length,
124
- target_rms=target_rms,
125
- use_truth_duration=use_truth_duration,
126
- infer_batch_size=infer_batch_size,
127
  )
128
 
129
  # Vocoder model
@@ -131,7 +125,7 @@ local = False
131
  if local:
132
  vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
133
  vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
134
- state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", weights_only=True, map_location=device)
135
  vocos.load_state_dict(state_dict)
136
  vocos.eval()
137
  else:
@@ -142,19 +136,28 @@ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
142
 
143
  # Model
144
  model = CFM(
145
- transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
146
- mel_spec_kwargs=dict(
147
- target_sample_rate=target_sample_rate,
148
- n_mel_channels=n_mel_channels,
149
- hop_length=hop_length,
150
  ),
151
- odeint_kwargs=dict(
152
- method=ode_method,
 
 
153
  ),
154
- vocab_char_map=vocab_char_map,
 
 
 
155
  ).to(device)
156
 
157
- model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema)
158
 
159
  if not os.path.exists(output_dir) and accelerator.is_main_process:
160
  os.makedirs(output_dir)
@@ -164,29 +167,30 @@ accelerator.wait_for_everyone()
164
  start = time.time()
165
 
166
  with accelerator.split_between_processes(prompts_all) as prompts:
 
167
  for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
168
  utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt
169
  ref_mels = ref_mels.to(device)
170
- ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device)
171
- total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)
172
-
173
  # Inference
174
  with torch.inference_mode():
175
  generated, _ = model.sample(
176
- cond=ref_mels,
177
- text=final_text_list,
178
- duration=total_mel_lens,
179
- lens=ref_mel_lens,
180
- steps=nfe_step,
181
- cfg_strength=cfg_strength,
182
- sway_sampling_coef=sway_sampling_coef,
183
- no_ref_audio=no_ref_audio,
184
- seed=seed,
185
  )
186
  # Final result
187
  for i, gen in enumerate(generated):
188
- gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
189
- gen_mel_spec = gen.permute(0, 2, 1)
190
  generated_wave = vocos.decode(gen_mel_spec.cpu())
191
  if ref_rms_list[i] < target_rms:
192
  generated_wave = generated_wave * ref_rms_list[i] / target_rms
 
 
1
  import os
 
 
 
2
  import time
3
  import random
4
  from tqdm import tqdm
 
7
  import torch
8
  import torchaudio
9
  from accelerate import Accelerator
10
+ from einops import rearrange
11
+ from ema_pytorch import EMA
12
  from vocos import Vocos
13
 
14
  from model import CFM, UNetT, DiT
15
  from model.utils import (
16
+ get_tokenizer,
17
+ get_seedtts_testset_metainfo,
18
+ get_librispeech_test_clean_metainfo,
 
19
  get_inference_prompt,
20
  )
21
 
 
37
 
38
  parser = argparse.ArgumentParser(description="batch inference")
39
 
40
+ parser.add_argument('-s', '--seed', default=None, type=int)
41
+ parser.add_argument('-d', '--dataset', default="Emilia_ZH_EN")
42
+ parser.add_argument('-n', '--expname', required=True)
43
+ parser.add_argument('-c', '--ckptstep', default=1200000, type=int)
44
 
45
+ parser.add_argument('-nfe', '--nfestep', default=32, type=int)
46
+ parser.add_argument('-o', '--odemethod', default="euler")
47
+ parser.add_argument('-ss', '--swaysampling', default=-1, type=float)
48
 
49
+ parser.add_argument('-t', '--testset', required=True)
50
 
51
  args = parser.parse_args()
52
 
 
55
  dataset_name = args.dataset
56
  exp_name = args.expname
57
  ckpt_step = args.ckptstep
58
+ checkpoint = torch.load(f"ckpts/{exp_name}/model_{ckpt_step}.pt", map_location=device)
59
 
60
  nfe_step = args.nfestep
61
  ode_method = args.odemethod
 
65
 
66
 
67
  infer_batch_size = 1 # max frames. 1 for ddp single inference (recommended)
68
+ cfg_strength = 2.
69
+ speed = 1.
70
  use_truth_duration = False
71
  no_ref_audio = False
72
 
73
 
74
  if exp_name == "F5TTS_Base":
75
  model_cls = DiT
76
+ model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
77
 
78
  elif exp_name == "E2TTS_Base":
79
  model_cls = UNetT
80
+ model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
81
 
82
 
83
  if testset == "ls_pc_test_clean":
84
  metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
85
  librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
86
  metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)
87
+
88
  elif testset == "seedtts_test_zh":
89
  metalst = "data/seedtts_testset/zh/meta.lst"
90
  metainfo = get_seedtts_testset_metainfo(metalst)
 
95
 
96
 
97
  # path to save generated wavs
98
+ if seed is None: seed = random.randint(-10000, 10000)
99
+ output_dir = f"results/{exp_name}_{ckpt_step}/{testset}/" \
100
+ f"seed{seed}_{ode_method}_nfe{nfe_step}" \
101
+ f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}" \
102
+ f"_cfg{cfg_strength}_speed{speed}" \
103
+ f"{'_gt-dur' if use_truth_duration else ''}" \
 
 
104
  f"{'_no-ref-audio' if no_ref_audio else ''}"
 
105
 
106
 
107
  # -------------------------------------------------#
 
109
  use_ema = True
110
 
111
  prompts_all = get_inference_prompt(
112
+ metainfo,
113
+ speed = speed,
114
+ tokenizer = tokenizer,
115
+ target_sample_rate = target_sample_rate,
116
+ n_mel_channels = n_mel_channels,
117
+ hop_length = hop_length,
118
+ target_rms = target_rms,
119
+ use_truth_duration = use_truth_duration,
120
+ infer_batch_size = infer_batch_size,
121
  )
122
 
123
  # Vocoder model
 
125
  if local:
126
  vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
127
  vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
128
+ state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
129
  vocos.load_state_dict(state_dict)
130
  vocos.eval()
131
  else:
 
136
 
137
  # Model
138
  model = CFM(
139
+ transformer = model_cls(
140
+ **model_cfg,
141
+ text_num_embeds = vocab_size,
142
+ mel_dim = n_mel_channels
 
143
  ),
144
+ mel_spec_kwargs = dict(
145
+ target_sample_rate = target_sample_rate,
146
+ n_mel_channels = n_mel_channels,
147
+ hop_length = hop_length,
148
  ),
149
+ odeint_kwargs = dict(
150
+ method = ode_method,
151
+ ),
152
+ vocab_char_map = vocab_char_map,
153
  ).to(device)
154
 
155
+ if use_ema == True:
156
+ ema_model = EMA(model, include_online_model = False).to(device)
157
+ ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
158
+ ema_model.copy_params_from_ema_to_model()
159
+ else:
160
+ model.load_state_dict(checkpoint['model_state_dict'])
161
 
162
  if not os.path.exists(output_dir) and accelerator.is_main_process:
163
  os.makedirs(output_dir)
 
167
  start = time.time()
168
 
169
  with accelerator.split_between_processes(prompts_all) as prompts:
170
+
171
  for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
172
  utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt
173
  ref_mels = ref_mels.to(device)
174
+ ref_mel_lens = torch.tensor(ref_mel_lens, dtype = torch.long).to(device)
175
+ total_mel_lens = torch.tensor(total_mel_lens, dtype = torch.long).to(device)
176
+
177
  # Inference
178
  with torch.inference_mode():
179
  generated, _ = model.sample(
180
+ cond = ref_mels,
181
+ text = final_text_list,
182
+ duration = total_mel_lens,
183
+ lens = ref_mel_lens,
184
+ steps = nfe_step,
185
+ cfg_strength = cfg_strength,
186
+ sway_sampling_coef = sway_sampling_coef,
187
+ no_ref_audio = no_ref_audio,
188
+ seed = seed,
189
  )
190
  # Final result
191
  for i, gen in enumerate(generated):
192
+ gen = gen[ref_mel_lens[i]:total_mel_lens[i], :].unsqueeze(0)
193
+ gen_mel_spec = rearrange(gen, '1 n d -> 1 d n')
194
  generated_wave = vocos.decode(gen_mel_spec.cpu())
195
  if ref_rms_list[i] < target_rms:
196
  generated_wave = generated_wave * ref_rms_list[i] / target_rms
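For readers comparing the two checkpoint-loading styles in this rename (the removed load_checkpoint helper versus the explicit EMA branch restored above), here is that added logic pulled out as a standalone function. This is only a sketch: the ema_pytorch dependency and the checkpoint keys 'ema_model_state_dict' / 'model_state_dict' come from the diff, while the helper name load_weights is hypothetical.

```python
# Sketch of the checkpoint loading added in test_infer_batch.py / test_infer_single.py.
import torch
from ema_pytorch import EMA

def load_weights(model, ckpt_path: str, device: str, use_ema: bool = True):
    checkpoint = torch.load(ckpt_path, map_location=device)
    if use_ema:
        # Wrap the online model, restore the EMA shadow weights from the checkpoint,
        # then copy the EMA parameters back onto the model used for sampling.
        ema_model = EMA(model, include_online_model=False).to(device)
        ema_model.load_state_dict(checkpoint["ema_model_state_dict"])
        ema_model.copy_params_from_ema_to_model()
    else:
        model.load_state_dict(checkpoint["model_state_dict"])
    return model
```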
test_infer_batch.sh ADDED
@@ -0,0 +1,13 @@
1
+ #!/bin/bash
2
+
3
+ # e.g. F5-TTS, 16 NFE
4
+ accelerate launch test_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16
5
+ accelerate launch test_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16
6
+ accelerate launch test_infer_batch.py -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16
7
+
8
+ # e.g. Vanilla E2 TTS, 32 NFE
9
+ accelerate launch test_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0
10
+ accelerate launch test_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0
11
+ accelerate launch test_infer_batch.py -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0
12
+
13
+ # etc.
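The flags used in the commands above feed directly into the results directory that test_infer_batch.py assembles from seed, ODE method, NFE steps, sway sampling, CFG strength and speed. A small illustration with the values of the first command; the seed is drawn at random when -s is omitted, so 0 here is purely for the example:

```python
# How the CLI flags map to the output directory name (values mirror the first command).
exp_name, ckpt_step, testset = "F5TTS_Base", 1200000, "seedtts_test_zh"
seed, ode_method, nfe_step = 0, "euler", 16
sway_sampling_coef, cfg_strength, speed = -1.0, 2.0, 1.0
use_truth_duration, no_ref_audio = False, False

output_dir = f"results/{exp_name}_{ckpt_step}/{testset}/" \
             f"seed{seed}_{ode_method}_nfe{nfe_step}" \
             f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}" \
             f"_cfg{cfg_strength}_speed{speed}" \
             f"{'_gt-dur' if use_truth_duration else ''}" \
             f"{'_no-ref-audio' if no_ref_audio else ''}"
print(output_dir)
# results/F5TTS_Base_1200000/seedtts_test_zh/seed0_euler_nfe16_ss-1.0_cfg2.0_speed1.0
```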
speech_edit.py → test_infer_single.py RENAMED
@@ -1,19 +1,20 @@
1
  import os
 
2
 
3
  import torch
4
- import torch.nn.functional as F
5
  import torchaudio
 
 
6
  from vocos import Vocos
7
 
8
- from model import CFM, UNetT, DiT
9
  from model.utils import (
10
- load_checkpoint,
11
- get_tokenizer,
12
- convert_char_to_pinyin,
13
  save_spectrogram,
14
  )
15
 
16
- device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
17
 
18
 
19
  # --------------------- Dataset Settings -------------------- #
@@ -35,47 +36,30 @@ exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base
35
  ckpt_step = 1200000
36
 
37
  nfe_step = 32 # 16, 32
38
- cfg_strength = 2.0
39
- ode_method = "euler" # euler | midpoint
40
- sway_sampling_coef = -1.0
41
- speed = 1.0
 
42
 
43
  if exp_name == "F5TTS_Base":
44
  model_cls = DiT
45
- model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
46
 
47
  elif exp_name == "E2TTS_Base":
48
  model_cls = UNetT
49
- model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
50
 
51
- ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.safetensors"
52
  output_dir = "tests"
53
 
54
- # [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment]
55
- # pip install git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git
56
- # [write the origin_text into a file, e.g. tests/test_edit.txt]
57
- # ctc-forced-aligner --audio_path "tests/ref_audio/test_en_1_ref_short.wav" --text_path "tests/test_edit.txt" --language "zho" --romanize --split_size "char"
58
- # [result will be saved at same path of audio file]
59
- # [--language "zho" for Chinese, "eng" for English]
60
- # [if local ckpt, set --alignment_model "../checkpoints/mms-300m-1130-forced-aligner"]
61
-
62
- audio_to_edit = "tests/ref_audio/test_en_1_ref_short.wav"
63
- origin_text = "Some call me nature, others call me mother nature."
64
- target_text = "Some call me optimist, others call me realist."
65
- parts_to_edit = [
66
- [1.42, 2.44],
67
- [4.04, 4.9],
68
- ] # start_ends of "nature" & "mother nature", in seconds
69
- fix_duration = [
70
- 1.2,
71
- 1,
72
- ] # fix duration for "optimist" & "realist", in seconds
73
-
74
- # audio_to_edit = "tests/ref_audio/test_zh_1_ref_short.wav"
75
- # origin_text = "对,这就是我,万人敬仰的太乙真人。"
76
- # target_text = "对,那就是你,万人敬仰的太白金星。"
77
- # parts_to_edit = [[0.84, 1.4], [1.92, 2.4], [4.26, 6.26], ]
78
- # fix_duration = None # use origin text duration
79
 
80
 
81
  # -------------------------------------------------#
@@ -90,9 +74,8 @@ local = False
90
  if local:
91
  vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
92
  vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
93
- state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", weights_only=True, map_location=device)
94
  vocos.load_state_dict(state_dict)
95
-
96
  vocos.eval()
97
  else:
98
  vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
@@ -102,55 +85,41 @@ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
102
 
103
  # Model
104
  model = CFM(
105
- transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
106
- mel_spec_kwargs=dict(
107
- target_sample_rate=target_sample_rate,
108
- n_mel_channels=n_mel_channels,
109
- hop_length=hop_length,
110
  ),
111
- odeint_kwargs=dict(
112
- method=ode_method,
113
  ),
114
- vocab_char_map=vocab_char_map,
115
  ).to(device)
116
 
117
- model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema)
118
 
119
  # Audio
120
- audio, sr = torchaudio.load(audio_to_edit)
121
- if audio.shape[0] > 1:
122
- audio = torch.mean(audio, dim=0, keepdim=True)
123
  rms = torch.sqrt(torch.mean(torch.square(audio)))
124
  if rms < target_rms:
125
  audio = audio * target_rms / rms
126
  if sr != target_sample_rate:
127
  resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
128
  audio = resampler(audio)
129
- offset = 0
130
- audio_ = torch.zeros(1, 0)
131
- edit_mask = torch.zeros(1, 0, dtype=torch.bool)
132
- for part in parts_to_edit:
133
- start, end = part
134
- part_dur = end - start if fix_duration is None else fix_duration.pop(0)
135
- part_dur = part_dur * target_sample_rate
136
- start = start * target_sample_rate
137
- audio_ = torch.cat((audio_, audio[:, round(offset) : round(start)], torch.zeros(1, round(part_dur))), dim=-1)
138
- edit_mask = torch.cat(
139
- (
140
- edit_mask,
141
- torch.ones(1, round((start - offset) / hop_length), dtype=torch.bool),
142
- torch.zeros(1, round(part_dur / hop_length), dtype=torch.bool),
143
- ),
144
- dim=-1,
145
- )
146
- offset = end * target_sample_rate
147
- # audio = torch.cat((audio_, audio[:, round(offset):]), dim = -1)
148
- edit_mask = F.pad(edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value=True)
149
  audio = audio.to(device)
150
- edit_mask = edit_mask.to(device)
151
 
152
  # Text
153
- text_list = [target_text]
154
  if tokenizer == "pinyin":
155
  final_text_list = convert_char_to_pinyin(text_list)
156
  else:
@@ -159,31 +128,35 @@ print(f"text : {text_list}")
159
  print(f"pinyin: {final_text_list}")
160
 
161
  # Duration
162
- ref_audio_len = 0
163
- duration = audio.shape[-1] // hop_length
164
 
165
  # Inference
166
  with torch.inference_mode():
167
  generated, trajectory = model.sample(
168
- cond=audio,
169
- text=final_text_list,
170
- duration=duration,
171
- steps=nfe_step,
172
- cfg_strength=cfg_strength,
173
- sway_sampling_coef=sway_sampling_coef,
174
- seed=seed,
175
- edit_mask=edit_mask,
176
  )
177
  print(f"Generated mel: {generated.shape}")
178
 
179
  # Final result
180
- generated = generated.to(torch.float32)
181
  generated = generated[:, ref_audio_len:, :]
182
- generated_mel_spec = generated.permute(0, 2, 1)
183
  generated_wave = vocos.decode(generated_mel_spec.cpu())
184
  if rms < target_rms:
185
  generated_wave = generated_wave * rms / target_rms
186
 
187
- save_spectrogram(generated_mel_spec[0].cpu().numpy(), f"{output_dir}/speech_edit_out.png")
188
- torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave, target_sample_rate)
189
  print(f"Generated wav: {generated_wave.shape}")
 
1
  import os
2
+ import re
3
 
4
  import torch
 
5
  import torchaudio
6
+ from einops import rearrange
7
+ from ema_pytorch import EMA
8
  from vocos import Vocos
9
 
10
+ from model import CFM, UNetT, DiT, MMDiT
11
  from model.utils import (
12
+ get_tokenizer,
13
+ convert_char_to_pinyin,
 
14
  save_spectrogram,
15
  )
16
 
17
+ device = "cuda" if torch.cuda.is_available() else "cpu"
18
 
19
 
20
  # --------------------- Dataset Settings -------------------- #
 
36
  ckpt_step = 1200000
37
 
38
  nfe_step = 32 # 16, 32
39
+ cfg_strength = 2.
40
+ ode_method = 'euler' # euler | midpoint
41
+ sway_sampling_coef = -1.
42
+ speed = 1.
43
+ fix_duration = 27 # None (will linear estimate. if code-switched, consider fix) | float (total in seconds, include ref audio)
44
 
45
  if exp_name == "F5TTS_Base":
46
  model_cls = DiT
47
+ model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
48
 
49
  elif exp_name == "E2TTS_Base":
50
  model_cls = UNetT
51
+ model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
52
 
53
+ checkpoint = torch.load(f"ckpts/{exp_name}/model_{ckpt_step}.pt", map_location=device)
54
  output_dir = "tests"
55
 
56
+ ref_audio = "tests/ref_audio/test_en_1_ref_short.wav"
57
+ ref_text = "Some call me nature, others call me mother nature."
58
+ gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences."
59
+
60
+ # ref_audio = "tests/ref_audio/test_zh_1_ref_short.wav"
61
+ # ref_text = "对,这就是我,万人敬仰的太乙真人。"
62
+ # gen_text = "突然,身边一阵笑声。我看着他们,意气风发地挺直了胸膛,甩了甩那稍显肉感的双臂,轻笑道:\"我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?\""
63
 
64
 
65
  # -------------------------------------------------#
 
74
  if local:
75
  vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
76
  vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
77
+ state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
78
  vocos.load_state_dict(state_dict)
 
79
  vocos.eval()
80
  else:
81
  vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
 
85
 
86
  # Model
87
  model = CFM(
88
+ transformer = model_cls(
89
+ **model_cfg,
90
+ text_num_embeds = vocab_size,
91
+ mel_dim = n_mel_channels
92
+ ),
93
+ mel_spec_kwargs = dict(
94
+ target_sample_rate = target_sample_rate,
95
+ n_mel_channels = n_mel_channels,
96
+ hop_length = hop_length,
97
  ),
98
+ odeint_kwargs = dict(
99
+ method = ode_method,
100
  ),
101
+ vocab_char_map = vocab_char_map,
102
  ).to(device)
103
 
104
+ if use_ema == True:
105
+ ema_model = EMA(model, include_online_model = False).to(device)
106
+ ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
107
+ ema_model.copy_params_from_ema_to_model()
108
+ else:
109
+ model.load_state_dict(checkpoint['model_state_dict'])
110
 
111
  # Audio
112
+ audio, sr = torchaudio.load(ref_audio)
 
 
113
  rms = torch.sqrt(torch.mean(torch.square(audio)))
114
  if rms < target_rms:
115
  audio = audio * target_rms / rms
116
  if sr != target_sample_rate:
117
  resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
118
  audio = resampler(audio)
119
  audio = audio.to(device)
 
120
 
121
  # Text
122
+ text_list = [ref_text + gen_text]
123
  if tokenizer == "pinyin":
124
  final_text_list = convert_char_to_pinyin(text_list)
125
  else:
 
128
  print(f"pinyin: {final_text_list}")
129
 
130
  # Duration
131
+ ref_audio_len = audio.shape[-1] // hop_length
132
+ if fix_duration is not None:
133
+ duration = int(fix_duration * target_sample_rate / hop_length)
134
+ else: # simple linear scale calculation
135
+ zh_pause_punc = r"。,、;:?!"
136
+ ref_text_len = len(ref_text) + len(re.findall(zh_pause_punc, ref_text))
137
+ gen_text_len = len(gen_text) + len(re.findall(zh_pause_punc, gen_text))
138
+ duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
139
 
140
  # Inference
141
  with torch.inference_mode():
142
  generated, trajectory = model.sample(
143
+ cond = audio,
144
+ text = final_text_list,
145
+ duration = duration,
146
+ steps = nfe_step,
147
+ cfg_strength = cfg_strength,
148
+ sway_sampling_coef = sway_sampling_coef,
149
+ seed = seed,
 
150
  )
151
  print(f"Generated mel: {generated.shape}")
152
 
153
  # Final result
 
154
  generated = generated[:, ref_audio_len:, :]
155
+ generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
156
  generated_wave = vocos.decode(generated_mel_spec.cpu())
157
  if rms < target_rms:
158
  generated_wave = generated_wave * rms / target_rms
159
 
160
+ save_spectrogram(generated_mel_spec[0].cpu().numpy(), f"{output_dir}/test_single.png")
161
+ torchaudio.save(f"{output_dir}/test_single.wav", generated_wave, target_sample_rate)
162
  print(f"Generated wav: {generated_wave.shape}")
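The duration logic added in test_infer_single.py either takes fix_duration directly or scales the reference length by the text-length ratio, counting Chinese pause punctuation as extra weight. A compact version of that estimate is sketched below; wrapping zh_pause_punc in a regex character class is my reading of the intent (the hunk passes the raw string to re.findall), and the helper name is hypothetical.

```python
# Sketch of the duration estimate; constants are the ones shown in the diff.
import re

hop_length, target_sample_rate, speed = 256, 24000, 1.0
zh_pause_punc = r"。,、;:?!"

def estimate_total_frames(ref_audio_samples: int, ref_text: str, gen_text: str) -> int:
    ref_audio_len = ref_audio_samples // hop_length  # reference length in mel frames
    # Each pause mark counts twice, giving pauses a larger share of the duration budget.
    ref_text_len = len(ref_text) + len(re.findall(f"[{zh_pause_punc}]", ref_text))
    gen_text_len = len(gen_text) + len(re.findall(f"[{zh_pause_punc}]", gen_text))
    # Linear scaling: generated frames ~ reference frames * (gen text / ref text) / speed.
    return ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)

# e.g. a ~5 s reference clip at 24 kHz:
print(estimate_total_frames(5 * target_sample_rate,
                            "Some call me nature, others call me mother nature.",
                            "I don't really care what you call me."))
```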
train.py → test_train.py RENAMED
@@ -1,4 +1,4 @@
1
- from model import CFM, UNetT, DiT, Trainer
2
  from model.utils import get_tokenizer
3
  from model.dataset import load_dataset
4
 
@@ -9,10 +9,10 @@ target_sample_rate = 24000
9
  n_mel_channels = 100
10
  hop_length = 256
11
 
12
- tokenizer = "pinyin" # 'pinyin', 'char', or 'custom'
13
- tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
14
  dataset_name = "Emilia_ZH_EN"
15
 
 
16
  # -------------------------- Training Settings -------------------------- #
17
 
18
  exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base
@@ -23,7 +23,7 @@ batch_size_per_gpu = 38400 # 8 GPUs, 8 * 38400 = 307200
23
  batch_size_type = "frame" # "frame" or "sample"
24
  max_samples = 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
25
  grad_accumulation_steps = 1 # note: updates = steps / grad_accumulation_steps
26
- max_grad_norm = 1.0
27
 
28
  epochs = 11 # use linear decay, thus epochs control the slope
29
  num_warmup_updates = 20000 # warmup steps
@@ -34,59 +34,58 @@ last_per_steps = 5000 # save last checkpoint per steps
34
  if exp_name == "F5TTS_Base":
35
  wandb_resume_id = None
36
  model_cls = DiT
37
- model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
38
  elif exp_name == "E2TTS_Base":
39
  wandb_resume_id = None
40
  model_cls = UNetT
41
- model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
42
 
43
 
44
  # ----------------------------------------------------------------------- #
45
 
46
-
47
  def main():
48
- if tokenizer == "custom":
49
- tokenizer_path = tokenizer_path
50
- else:
51
- tokenizer_path = dataset_name
52
- vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
53
 
54
- mel_spec_kwargs = dict(
55
- target_sample_rate=target_sample_rate,
56
- n_mel_channels=n_mel_channels,
57
- hop_length=hop_length,
58
- )
59
 
60
- model = CFM(
61
- transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
62
- mel_spec_kwargs=mel_spec_kwargs,
63
- vocab_char_map=vocab_char_map,
64
  )
65
 
66
  trainer = Trainer(
67
- model,
68
- epochs,
69
  learning_rate,
70
- num_warmup_updates=num_warmup_updates,
71
- save_per_updates=save_per_updates,
72
- checkpoint_path=f"ckpts/{exp_name}",
73
- batch_size=batch_size_per_gpu,
74
- batch_size_type=batch_size_type,
75
- max_samples=max_samples,
76
- grad_accumulation_steps=grad_accumulation_steps,
77
- max_grad_norm=max_grad_norm,
78
- wandb_project="CFM-TTS",
79
- wandb_run_name=exp_name,
80
- wandb_resume_id=wandb_resume_id,
81
- last_per_steps=last_per_steps,
82
  )
83
 
84
  train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
85
- trainer.train(
86
- train_dataset,
87
- resumable_with_seed=666, # seed for shuffling dataset
88
- )
89
 
90
 
91
- if __name__ == "__main__":
92
  main()
 
1
+ from model import CFM, UNetT, DiT, MMDiT, Trainer
2
  from model.utils import get_tokenizer
3
  from model.dataset import load_dataset
4
 
 
9
  n_mel_channels = 100
10
  hop_length = 256
11
 
12
+ tokenizer = "pinyin"
 
13
  dataset_name = "Emilia_ZH_EN"
14
 
15
+
16
  # -------------------------- Training Settings -------------------------- #
17
 
18
  exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base
 
23
  batch_size_type = "frame" # "frame" or "sample"
24
  max_samples = 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
25
  grad_accumulation_steps = 1 # note: updates = steps / grad_accumulation_steps
26
+ max_grad_norm = 1.
27
 
28
  epochs = 11 # use linear decay, thus epochs control the slope
29
  num_warmup_updates = 20000 # warmup steps
 
34
  if exp_name == "F5TTS_Base":
35
  wandb_resume_id = None
36
  model_cls = DiT
37
+ model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
38
  elif exp_name == "E2TTS_Base":
39
  wandb_resume_id = None
40
  model_cls = UNetT
41
+ model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
42
 
43
 
44
  # ----------------------------------------------------------------------- #
45
 
 
46
  def main():
47
 
48
+ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
49
 
50
+ mel_spec_kwargs = dict(
51
+ target_sample_rate = target_sample_rate,
52
+ n_mel_channels = n_mel_channels,
53
+ hop_length = hop_length,
54
+ )
55
+
56
+ e2tts = CFM(
57
+ transformer = model_cls(
58
+ **model_cfg,
59
+ text_num_embeds = vocab_size,
60
+ mel_dim = n_mel_channels
61
+ ),
62
+ mel_spec_kwargs = mel_spec_kwargs,
63
+ vocab_char_map = vocab_char_map,
64
  )
65
 
66
  trainer = Trainer(
67
+ e2tts,
68
+ epochs,
69
  learning_rate,
70
+ num_warmup_updates = num_warmup_updates,
71
+ save_per_updates = save_per_updates,
72
+ checkpoint_path = f'ckpts/{exp_name}',
73
+ batch_size = batch_size_per_gpu,
74
+ batch_size_type = batch_size_type,
75
+ max_samples = max_samples,
76
+ grad_accumulation_steps = grad_accumulation_steps,
77
+ max_grad_norm = max_grad_norm,
78
+ wandb_project = "CFM-TTS",
79
+ wandb_run_name = exp_name,
80
+ wandb_resume_id = wandb_resume_id,
81
+ last_per_steps = last_per_steps,
82
  )
83
 
84
  train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
85
+ trainer.train(train_dataset,
86
+ resumable_with_seed = 666 # seed for shuffling dataset
87
+ )
 
88
 
89
 
90
+ if __name__ == '__main__':
91
  main()
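To put the frame-based batch size in test_train.py into perspective, the numbers shown above (batch_size_per_gpu = 38400 mel frames, hop_length = 256, 24 kHz audio, and 8 GPUs per the inline comment) translate into audio time as follows; the arithmetic below is mine, not something stated in the diff:

```python
# Rough sanity check of what a 38400-frame batch corresponds to in audio seconds.
hop_length, target_sample_rate = 256, 24000
batch_size_per_gpu = 38400   # mel frames per GPU when batch_size_type = "frame"
n_gpus = 8                   # per the "8 GPUs, 8 * 38400 = 307200" comment

seconds_per_gpu = batch_size_per_gpu * hop_length / target_sample_rate
print(f"{seconds_per_gpu:.1f} s of audio per GPU per step")                    # 409.6 s
print(f"{n_gpus * seconds_per_gpu / 3600:.2f} h of audio per optimizer step")  # 0.91 h
```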