Zihan428 committed on
Commit f79db70 · 0 Parent(s)

Clean multilingual TTS repo

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitignore +48 -0
  2. README.md +10 -0
  3. app.py +319 -0
  4. requirements.txt +18 -0
  5. src/chatterbox/__init__.py +11 -0
  6. src/chatterbox/models/__init__.py +0 -0
  7. src/chatterbox/models/s3gen/__init__.py +2 -0
  8. src/chatterbox/models/s3gen/configs.py +10 -0
  9. src/chatterbox/models/s3gen/const.py +1 -0
  10. src/chatterbox/models/s3gen/decoder.py +317 -0
  11. src/chatterbox/models/s3gen/f0_predictor.py +55 -0
  12. src/chatterbox/models/s3gen/flow.py +290 -0
  13. src/chatterbox/models/s3gen/flow_matching.py +218 -0
  14. src/chatterbox/models/s3gen/hifigan.py +474 -0
  15. src/chatterbox/models/s3gen/matcha/decoder.py +443 -0
  16. src/chatterbox/models/s3gen/matcha/flow_matching.py +129 -0
  17. src/chatterbox/models/s3gen/matcha/text_encoder.py +413 -0
  18. src/chatterbox/models/s3gen/matcha/transformer.py +316 -0
  19. src/chatterbox/models/s3gen/s3gen.py +298 -0
  20. src/chatterbox/models/s3gen/transformer/__init__.py +0 -0
  21. src/chatterbox/models/s3gen/transformer/activation.py +84 -0
  22. src/chatterbox/models/s3gen/transformer/attention.py +330 -0
  23. src/chatterbox/models/s3gen/transformer/convolution.py +145 -0
  24. src/chatterbox/models/s3gen/transformer/embedding.py +294 -0
  25. src/chatterbox/models/s3gen/transformer/encoder_layer.py +236 -0
  26. src/chatterbox/models/s3gen/transformer/positionwise_feed_forward.py +115 -0
  27. src/chatterbox/models/s3gen/transformer/subsampling.py +383 -0
  28. src/chatterbox/models/s3gen/transformer/upsample_encoder.py +318 -0
  29. src/chatterbox/models/s3gen/utils/class_utils.py +71 -0
  30. src/chatterbox/models/s3gen/utils/mask.py +193 -0
  31. src/chatterbox/models/s3gen/utils/mel.py +85 -0
  32. src/chatterbox/models/s3gen/xvector.py +428 -0
  33. src/chatterbox/models/s3tokenizer/__init__.py +30 -0
  34. src/chatterbox/models/s3tokenizer/s3tokenizer.py +168 -0
  35. src/chatterbox/models/t3/__init__.py +1 -0
  36. src/chatterbox/models/t3/inference/alignment_stream_analyzer.py +178 -0
  37. src/chatterbox/models/t3/inference/t3_hf_backend.py +116 -0
  38. src/chatterbox/models/t3/llama_configs.py +37 -0
  39. src/chatterbox/models/t3/modules/cond_enc.py +97 -0
  40. src/chatterbox/models/t3/modules/learned_pos_emb.py +32 -0
  41. src/chatterbox/models/t3/modules/perceiver.py +212 -0
  42. src/chatterbox/models/t3/modules/t3_config.py +37 -0
  43. src/chatterbox/models/t3/t3.py +391 -0
  44. src/chatterbox/models/tokenizers/__init__.py +1 -0
  45. src/chatterbox/models/tokenizers/tokenizer.py +323 -0
  46. src/chatterbox/models/utils.py +4 -0
  47. src/chatterbox/models/voice_encoder/__init__.py +1 -0
  48. src/chatterbox/models/voice_encoder/config.py +18 -0
  49. src/chatterbox/models/voice_encoder/melspec.py +78 -0
  50. src/chatterbox/models/voice_encoder/voice_encoder.py +274 -0
.gitignore ADDED
@@ -0,0 +1,48 @@
1
+ .vscode
2
+
3
+ # Pylance
4
+ pyrightconfig.json
5
+
6
+ # Byte-compiled / optimized / DLL files
7
+ __pycache__/
8
+ *.py[cod]
9
+ *$py.class
10
+
11
+ # C extensions
12
+ *.so
13
+
14
+ # Distribution / packaging
15
+ .Python
16
+ build/
17
+ develop-eggs/
18
+ dist/
19
+ downloads/
20
+ eggs/
21
+ .eggs/
22
+ lib/
23
+ lib64/
24
+ parts/
25
+ sdist/
26
+ var/
27
+ wheels/
28
+ *.egg-info/
29
+ .installed.cfg
30
+ *.egg
31
+ MANIFEST
32
+
33
+ # PyInstaller
34
+ # Usually these files are written by a python script from a template
35
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
36
+ *.manifest
37
+ *.spec
38
+
39
+ # Installer logs
40
+ pip-log.txt
41
+ pip-delete-this-directory.txt
42
+
43
+ syn_out/
44
+ checkpoints/
45
+ .gradio
46
+
47
+ # Ignore generated sample .wav files
48
+ **/*.wav
README.md ADDED
@@ -0,0 +1,10 @@
1
+ metadata
2
+ title: Chatterbox-Multilingual-TTS
3
+ emoji: 🌎
4
+ colorFrom: indigo
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 5.29.0
8
+ app_file: app.py
9
+ pinned: false
10
+ short_description: Chatterbox TTS supporting 23 languages
app.py ADDED
@@ -0,0 +1,319 @@
1
+ import random
2
+ import numpy as np
3
+ import torch
4
+ from src.chatterbox.mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES
5
+ import gradio as gr
6
+ import spaces
7
+
8
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
9
+ print(f"🚀 Running on device: {DEVICE}")
10
+
11
+ # --- Global Model Initialization ---
12
+ MODEL = None
13
+
14
+ LANGUAGE_CONFIG = {
15
+ "ar": {
16
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ar_m1.flac",
17
+ "text": "في الشهر الماضي، وصلنا إلى معلم جديد بمليارين من المشاهدات على قناتنا على يوتيوب."
18
+ },
19
+ "da": {
20
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/da_m1.flac",
21
+ "text": "Sidste måned nåede vi en ny milepæl med to milliarder visninger på vores YouTube-kanal."
22
+ },
23
+ "de": {
24
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/de_f1.flac",
25
+ "text": "Letzten Monat haben wir einen neuen Meilenstein erreicht: zwei Milliarden Aufrufe auf unserem YouTube-Kanal."
26
+ },
27
+ "el": {
28
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/el_m.flac",
29
+ "text": "Τον περασμένο μήνα, φτάσαμε σε ένα νέο ορόσημο με δύο δισεκατομμύρια προβολές στο κανάλι μας στο YouTube."
30
+ },
31
+ "en": {
32
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/en_f1.flac",
33
+ "text": "Last month, we reached a new milestone with two billion views on our YouTube channel."
34
+ },
35
+ "es": {
36
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/es_f1.flac",
37
+ "text": "El mes pasado alcanzamos un nuevo hito: dos mil millones de visualizaciones en nuestro canal de YouTube."
38
+ },
39
+ "fi": {
40
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/fi_m.flac",
41
+ "text": "Viime kuussa saavutimme uuden virstanpylvään kahden miljardin katselukerran kanssa YouTube-kanavallamme."
42
+ },
43
+ "fr": {
44
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/fr_f1.flac",
45
+ "text": "Le mois dernier, nous avons atteint un nouveau jalon avec deux milliards de vues sur notre chaîne YouTube."
46
+ },
47
+ "he": {
48
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/he_m1.flac",
49
+ "text": "בחודש שעבר הגענו לאבן דרך חדשה עם שני מיליארד צפיות בערוץ היוטיוב שלנו."
50
+ },
51
+ "hi": {
52
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/hi_f1.flac",
53
+ "text": "पिछले महीने हमने एक नया मील का पत्थर छुआ: हमारे YouTube चैनल पर दो अरब व्यूज़।"
54
+ },
55
+ "it": {
56
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/it_m1.flac",
57
+ "text": "Il mese scorso abbiamo raggiunto un nuovo traguardo: due miliardi di visualizzazioni sul nostro canale YouTube."
58
+ },
59
+ "ja": {
60
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ja_f.flac",
61
+ "text": "先月、私たちのYouTubeチャンネルで二十億回の再生回数という新たなマイルストーンに到達しました。"
62
+ },
63
+ "ko": {
64
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ko_f.flac",
65
+ "text": "지난달 우리는 유튜브 채널에서 이십억 조회수라는 새로운 이정표에 도달했습니다."
66
+ },
67
+ "ms": {
68
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ms_f.flac",
69
+ "text": "Bulan lepas, kami mencapai pencapaian baru dengan dua bilion tontonan di saluran YouTube kami."
70
+ },
71
+ "nl": {
72
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/nl_m.flac",
73
+ "text": "Vorige maand bereikten we een nieuwe mijlpaal met twee miljard weergaven op ons YouTube-kanaal."
74
+ },
75
+ "no": {
76
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/no_f1.flac",
77
+ "text": "Forrige måned nådde vi en ny milepæl med to milliarder visninger på YouTube-kanalen vår."
78
+ },
79
+ "pl": {
80
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/pl_m.flac",
81
+ "text": "W zeszłym miesiącu osiągnęliśmy nowy kamień milowy z dwoma miliardami wyświetleń na naszym kanale YouTube."
82
+ },
83
+ "pt": {
84
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/pt_m1.flac",
85
+ "text": "No mês passado, alcançámos um novo marco: dois mil milhões de visualizações no nosso canal do YouTube."
86
+ },
87
+ "ru": {
88
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ru_m.flac",
89
+ "text": "В прошлом месяце мы достигли нового рубежа: два миллиарда просмотров на нашем YouTube-канале."
90
+ },
91
+ "sv": {
92
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/sv_f.flac",
93
+ "text": "Förra månaden nådde vi en ny milstolpe med två miljarder visningar på vår YouTube-kanal."
94
+ },
95
+ "sw": {
96
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/sw_m.flac",
97
+ "text": "Mwezi uliopita, tulifika hatua mpya ya maoni ya bilioni mbili kweny kituo chetu cha YouTube."
98
+ },
99
+ "tr": {
100
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/tr_m.flac",
101
+ "text": "Geçen ay YouTube kanalımızda iki milyar görüntüleme ile yeni bir dönüm noktasına ulaştık."
102
+ },
103
+ "zh": {
104
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/zh_f.flac",
105
+ "text": "上个月,我们达到了一个新的里程碑,我们的YouTube频道观看次数达到了二十亿次,这绝对令人难以置信。"
106
+ },
107
+ }
108
+
109
+ # --- UI Helpers ---
110
+ def default_audio_for_ui(lang: str) -> str | None:
111
+ return LANGUAGE_CONFIG.get(lang, {}).get("audio")
112
+
113
+
114
+ def default_text_for_ui(lang: str) -> str:
115
+ return LANGUAGE_CONFIG.get(lang, {}).get("text", "")
116
+
117
+
118
+ def get_supported_languages_display() -> str:
119
+ """Generate a formatted display of all supported languages."""
120
+ language_items = []
121
+ for code, name in sorted(SUPPORTED_LANGUAGES.items()):
122
+ language_items.append(f"**{name}** (`{code}`)")
123
+
124
+ # Split into 2 lines
125
+ mid = len(language_items) // 2
126
+ line1 = " • ".join(language_items[:mid])
127
+ line2 = " • ".join(language_items[mid:])
128
+
129
+ return f"""
130
+ ### 🌍 Supported Languages ({len(SUPPORTED_LANGUAGES)} total)
131
+ {line1}
132
+
133
+ {line2}
134
+ """
135
+
136
+
137
+ def get_or_load_model():
138
+ """Loads the ChatterboxMultilingualTTS model if it hasn't been loaded already,
139
+ and ensures it's on the correct device."""
140
+ global MODEL
141
+ if MODEL is None:
142
+ print("Model not loaded, initializing...")
143
+ try:
144
+ MODEL = ChatterboxMultilingualTTS.from_pretrained(DEVICE)
145
+ if hasattr(MODEL, 'to') and str(MODEL.device) != DEVICE:
146
+ MODEL.to(DEVICE)
147
+ print(f"Model loaded successfully. Internal device: {getattr(MODEL, 'device', 'N/A')}")
148
+ except Exception as e:
149
+ print(f"Error loading model: {e}")
150
+ raise
151
+ return MODEL
152
+
153
+ # Attempt to load the model at startup.
154
+ try:
155
+ get_or_load_model()
156
+ except Exception as e:
157
+ print(f"CRITICAL: Failed to load model on startup. Application may not function. Error: {e}")
158
+
159
+ def set_seed(seed: int):
160
+ """Sets the random seed for reproducibility across torch, numpy, and random."""
161
+ torch.manual_seed(seed)
162
+ if DEVICE == "cuda":
163
+ torch.cuda.manual_seed(seed)
164
+ torch.cuda.manual_seed_all(seed)
165
+ random.seed(seed)
166
+ np.random.seed(seed)
167
+
168
+ def resolve_audio_prompt(language_id: str, provided_path: str | None) -> str | None:
169
+ """
170
+ Decide which audio prompt to use:
171
+ - If user provided a path (upload/mic/url), use it.
172
+ - Else, fall back to language-specific default (if any).
173
+ """
174
+ if provided_path and str(provided_path).strip():
175
+ return provided_path
176
+ return LANGUAGE_CONFIG.get(language_id, {}).get("audio")
177
+
178
+
179
+ @spaces.GPU
180
+ def generate_tts_audio(
181
+ text_input: str,
182
+ language_id: str,
183
+ audio_prompt_path_input: str = None,
184
+ exaggeration_input: float = 0.5,
185
+ temperature_input: float = 0.8,
186
+ seed_num_input: int = 0,
187
+ cfgw_input: float = 0.5
188
+ ) -> tuple[int, np.ndarray]:
189
+ """
190
+ Generate high-quality speech audio from text using Chatterbox Multilingual model with optional reference audio styling.
191
+ Supported languages: all 23 languages in SUPPORTED_LANGUAGES, including English, French, German, Spanish, Italian, Portuguese, and Hindi.
192
+
193
+ This tool synthesizes natural-sounding speech from input text. When a reference audio file
194
+ is provided, it captures the speaker's voice characteristics and speaking style. The generated audio
195
+ maintains the prosody, tone, and vocal qualities of the reference speaker, or falls back to a default voice if no reference is provided.
196
+
197
+ Args:
198
+ text_input (str): The text to synthesize into speech (maximum 300 characters)
199
+ language_id (str): The language code for synthesis (e.g., en, fr, de, es, it, pt, hi)
200
+ audio_prompt_path_input (str, optional): File path or URL to the reference audio file that defines the target voice style. Defaults to None.
201
+ exaggeration_input (float, optional): Controls speech expressiveness (0.25-2.0, neutral=0.5, extreme values may be unstable). Defaults to 0.5.
202
+ temperature_input (float, optional): Controls randomness in generation (0.05-5.0, higher=more varied). Defaults to 0.8.
203
+ seed_num_input (int, optional): Random seed for reproducible results (0 for random generation). Defaults to 0.
204
+ cfgw_input (float, optional): CFG/Pace weight controlling generation guidance (0.2-1.0). Defaults to 0.5; set to 0 for language transfer.
205
+
206
+ Returns:
207
+ tuple[int, np.ndarray]: A tuple containing the sample rate (int) and the generated audio waveform (numpy.ndarray)
208
+ """
209
+ current_model = get_or_load_model()
210
+
211
+ if current_model is None:
212
+ raise RuntimeError("TTS model is not loaded.")
213
+
214
+ if seed_num_input != 0:
215
+ set_seed(int(seed_num_input))
216
+
217
+ print(f"Generating audio for text: '{text_input[:50]}...'")
218
+
219
+ # Handle optional audio prompt
220
+ chosen_prompt = audio_prompt_path_input or default_audio_for_ui(language_id)
221
+
222
+ generate_kwargs = {
223
+ "exaggeration": exaggeration_input,
224
+ "temperature": temperature_input,
225
+ "cfg_weight": cfgw_input,
226
+ }
227
+ if chosen_prompt:
228
+ generate_kwargs["audio_prompt_path"] = chosen_prompt
229
+ print(f"Using audio prompt: {chosen_prompt}")
230
+ else:
231
+ print("No audio prompt provided; using default voice.")
232
+
233
+ wav = current_model.generate(
234
+ text_input[:300], # Truncate text to max chars
235
+ language_id=language_id,
236
+ **generate_kwargs
237
+ )
238
+ print("Audio generation complete.")
239
+ return (current_model.sr, wav.squeeze(0).numpy())
240
+
241
+ with gr.Blocks() as demo:
242
+ gr.Markdown(
243
+ """
244
+ # Chatterbox Multilingual Demo
245
+ Generate high-quality multilingual speech from text with reference audio styling, supporting 23 languages.
246
+ """
247
+ )
248
+
249
+ # Display supported languages
250
+ gr.Markdown(get_supported_languages_display())
251
+ with gr.Row():
252
+ with gr.Column():
253
+ initial_lang = "fr"
254
+ text = gr.Textbox(
255
+ value=default_text_for_ui(initial_lang),
256
+ label="Text to synthesize (max chars 300)",
257
+ max_lines=5
258
+ )
259
+
260
+ language_id = gr.Dropdown(
261
+ choices=list(ChatterboxMultilingualTTS.get_supported_languages().keys()),
262
+ value=initial_lang,
263
+ label="Language",
264
+ info="Select the language for text-to-speech synthesis"
265
+ )
266
+
267
+ ref_wav = gr.Audio(
268
+ sources=["upload", "microphone"],
269
+ type="filepath",
270
+ label="Reference Audio File (Optional)",
271
+ value=default_audio_for_ui(initial_lang)
272
+ )
273
+
274
+ gr.Markdown(
275
+ "💡 **Note**: Ensure that the reference clip matches the specified language tag. Otherwise, language transfer outputs may inherit the accent of the reference clip's language. To mitigate this, set the CFG weight to 0.",
276
+ elem_classes=["audio-note"]
277
+ )
278
+
279
+ exaggeration = gr.Slider(
280
+ 0.25, 2, step=.05, label="Exaggeration (Neutral = 0.5, extreme values can be unstable)", value=.5
281
+ )
282
+ cfg_weight = gr.Slider(
283
+ 0.2, 1, step=.05, label="CFG/Pace", value=0.5
284
+ )
285
+
286
+ with gr.Accordion("More options", open=False):
287
+ seed_num = gr.Number(value=0, label="Random seed (0 for random)")
288
+ temp = gr.Slider(0.05, 5, step=.05, label="Temperature", value=.8)
289
+
290
+ run_btn = gr.Button("Generate", variant="primary")
291
+
292
+ with gr.Column():
293
+ audio_output = gr.Audio(label="Output Audio")
294
+
295
+ def on_language_change(lang, current_ref, current_text):
296
+ return default_audio_for_ui(lang), default_text_for_ui(lang)
297
+
298
+ language_id.change(
299
+ fn=on_language_change,
300
+ inputs=[language_id, ref_wav, text],
301
+ outputs=[ref_wav, text],
302
+ show_progress=False
303
+ )
304
+
305
+ run_btn.click(
306
+ fn=generate_tts_audio,
307
+ inputs=[
308
+ text,
309
+ language_id,
310
+ ref_wav,
311
+ exaggeration,
312
+ temp,
313
+ seed_num,
314
+ cfg_weight,
315
+ ],
316
+ outputs=[audio_output],
317
+ )
318
+
319
+ demo.launch(mcp_server=True)
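For reference, everything app.py does with the model can be reproduced outside Gradio using only the calls shown above (from_pretrained, generate, and sr). A minimal sketch, assuming the repository root is on the Python path and that scipy (installed transitively via librosa) is acceptable for writing the waveform:

    # Minimal sketch of driving the model directly, mirroring generate_tts_audio() above.
    import torch
    from scipy.io import wavfile
    from src.chatterbox.mtl_tts import ChatterboxMultilingualTTS

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = ChatterboxMultilingualTTS.from_pretrained(device)

    wav = model.generate(
        "Last month, we reached a new milestone with two billion views on our YouTube channel.",
        language_id="en",
        exaggeration=0.5,   # same defaults the UI sliders use
        temperature=0.8,
        cfg_weight=0.5,
    )
    # As in app.py, model.sr is the sample rate and wav is a (1, T) tensor.
    wavfile.write("sample_en.wav", model.sr, wav.squeeze(0).numpy())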
requirements.txt ADDED
@@ -0,0 +1,18 @@
1
+ gradio
2
+ numpy==1.26.0
3
+ resampy==0.4.3
4
+ librosa==0.10.0
5
+ s3tokenizer
6
+ transformers==4.46.3
7
+ diffusers==0.29.0
8
+ omegaconf==2.3.0
9
+ resemble-perth==1.0.1
10
+ silero-vad==5.1.2
11
+ conformer==0.3.2
12
+ safetensors
13
+
14
+ # Optional language-specific dependencies
15
+ # Uncomment the ones you need for specific languages:
16
+ # pkuseg # For Chinese text segmentation (improves mixed text handling)
17
+ # pykakasi>=2.2.0 # For Japanese text processing (Kanji to Hiragana)
18
+ # dicta-onnx>=0.1.0 # For Hebrew diacritization
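The three commented-out packages are optional, so the tokenizer presumably probes for them at runtime rather than importing them unconditionally. An illustrative sketch of that optional-import pattern (the helper name and fallback behaviour are hypothetical; only the package names come from the comments above):

    # Illustrative optional-dependency probe; the repo's actual handling may differ.
    def optional_import(name):
        try:
            return __import__(name)
        except ImportError:
            return None

    pkuseg = optional_import("pkuseg")        # Chinese segmentation
    pykakasi = optional_import("pykakasi")    # Japanese Kanji -> Hiragana
    if pykakasi is None:
        print("pykakasi not installed; Japanese-specific preprocessing would be skipped.")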
src/chatterbox/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ try:
2
+ from importlib.metadata import version
3
+ except ImportError:
4
+ from importlib_metadata import version # For Python <3.8
5
+
6
+ __version__ = version("chatterbox-tts")
7
+
8
+
9
+ from .tts import ChatterboxTTS
10
+ from .vc import ChatterboxVC
11
+ from .mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES
src/chatterbox/models/__init__.py ADDED
File without changes
src/chatterbox/models/s3gen/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .s3gen import S3Token2Wav as S3Gen
2
+ from .const import S3GEN_SR
src/chatterbox/models/s3gen/configs.py ADDED
@@ -0,0 +1,10 @@
1
+ from ..utils import AttrDict
2
+
3
+ CFM_PARAMS = AttrDict({
4
+ "sigma_min": 1e-06,
5
+ "solver": "euler",
6
+ "t_scheduler": "cosine",
7
+ "training_cfg_rate": 0.2,
8
+ "inference_cfg_rate": 0.7,
9
+ "reg_loss_type": "l1"
10
+ })
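CFM_PARAMS is read with attribute access in flow_matching.py (cfm_params.t_scheduler, cfm_params.inference_cfg_rate, and so on), so AttrDict from ..utils is presumably a dict that also exposes its keys as attributes. A minimal stand-in under that assumption:

    # Hypothetical stand-in for ..utils.AttrDict: a dict whose keys double as attributes.
    class AttrDict(dict):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.__dict__ = self  # attribute lookups resolve to the dict's keys

    params = AttrDict({"sigma_min": 1e-06, "t_scheduler": "cosine", "inference_cfg_rate": 0.7})
    assert params.t_scheduler == params["t_scheduler"] == "cosine"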
src/chatterbox/models/s3gen/const.py ADDED
@@ -0,0 +1 @@
1
+ S3GEN_SR = 24000
src/chatterbox/models/s3gen/decoder.py ADDED
@@ -0,0 +1,317 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from einops import pack, rearrange, repeat
18
+
19
+ from .utils.mask import add_optional_chunk_mask
20
+ from .matcha.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, \
21
+ TimestepEmbedding, Upsample1D
22
+ from .matcha.transformer import BasicTransformerBlock
23
+
24
+
25
+ def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
26
+ assert mask.dtype == torch.bool
27
+ assert dtype in [torch.float32, torch.bfloat16, torch.float16]
28
+ mask = mask.to(dtype)
29
+ # attention mask bias
30
+ # NOTE(Mddct): torch.finfo jit issues
31
+ # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
32
+ mask = (1.0 - mask) * -1.0e+10
33
+ return mask
34
+
35
+
36
+
37
+ class Transpose(torch.nn.Module):
38
+ def __init__(self, dim0: int, dim1: int):
39
+ super().__init__()
40
+ self.dim0 = dim0
41
+ self.dim1 = dim1
42
+
43
+ def forward(self, x: torch.Tensor):
44
+ x = torch.transpose(x, self.dim0, self.dim1)
45
+ return x
46
+
47
+
48
+ class CausalBlock1D(Block1D):
49
+ def __init__(self, dim: int, dim_out: int):
50
+ super(CausalBlock1D, self).__init__(dim, dim_out)
51
+ self.block = torch.nn.Sequential(
52
+ CausalConv1d(dim, dim_out, 3),
53
+ Transpose(1, 2),
54
+ nn.LayerNorm(dim_out),
55
+ Transpose(1, 2),
56
+ nn.Mish(),
57
+ )
58
+
59
+ def forward(self, x: torch.Tensor, mask: torch.Tensor):
60
+ output = self.block(x * mask)
61
+ return output * mask
62
+
63
+
64
+ class CausalResnetBlock1D(ResnetBlock1D):
65
+ def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
66
+ super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
67
+ self.block1 = CausalBlock1D(dim, dim_out)
68
+ self.block2 = CausalBlock1D(dim_out, dim_out)
69
+
70
+
71
+ class CausalConv1d(torch.nn.Conv1d):
72
+ def __init__(
73
+ self,
74
+ in_channels: int,
75
+ out_channels: int,
76
+ kernel_size: int,
77
+ stride: int = 1,
78
+ dilation: int = 1,
79
+ groups: int = 1,
80
+ bias: bool = True,
81
+ padding_mode: str = 'zeros',
82
+ device=None,
83
+ dtype=None
84
+ ) -> None:
85
+ super(CausalConv1d, self).__init__(in_channels, out_channels,
86
+ kernel_size, stride,
87
+ padding=0, dilation=dilation,
88
+ groups=groups, bias=bias,
89
+ padding_mode=padding_mode,
90
+ device=device, dtype=dtype)
91
+ assert stride == 1
92
+ self.causal_padding = (kernel_size - 1, 0)
93
+
94
+ def forward(self, x: torch.Tensor):
95
+ x = F.pad(x, self.causal_padding)
96
+ x = super(CausalConv1d, self).forward(x)
97
+ return x
98
+
99
+
100
+ class ConditionalDecoder(nn.Module):
101
+ def __init__(
102
+ self,
103
+ in_channels=320,
104
+ out_channels=80,
105
+ causal=True,
106
+ channels=[256],
107
+ dropout=0.0,
108
+ attention_head_dim=64,
109
+ n_blocks=4,
110
+ num_mid_blocks=12,
111
+ num_heads=8,
112
+ act_fn="gelu",
113
+ ):
114
+ """
115
+ This decoder requires an input with the same shape as the target. So, if your text content
116
+ is shorter or longer than the outputs, please re-sample it before feeding it to the decoder.
117
+ """
118
+ super().__init__()
119
+ channels = tuple(channels)
120
+ self.in_channels = in_channels
121
+ self.out_channels = out_channels
122
+ self.causal = causal
123
+ self.time_embeddings = SinusoidalPosEmb(in_channels)
124
+ time_embed_dim = channels[0] * 4
125
+ self.time_mlp = TimestepEmbedding(
126
+ in_channels=in_channels,
127
+ time_embed_dim=time_embed_dim,
128
+ act_fn="silu",
129
+ )
130
+ self.down_blocks = nn.ModuleList([])
131
+ self.mid_blocks = nn.ModuleList([])
132
+ self.up_blocks = nn.ModuleList([])
133
+
134
+ # NOTE jrm: `static_chunk_size` is missing?
135
+ self.static_chunk_size = 0
136
+
137
+ output_channel = in_channels
138
+ for i in range(len(channels)): # pylint: disable=consider-using-enumerate
139
+ input_channel = output_channel
140
+ output_channel = channels[i]
141
+ is_last = i == len(channels) - 1
142
+ resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
143
+ ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
144
+ transformer_blocks = nn.ModuleList(
145
+ [
146
+ BasicTransformerBlock(
147
+ dim=output_channel,
148
+ num_attention_heads=num_heads,
149
+ attention_head_dim=attention_head_dim,
150
+ dropout=dropout,
151
+ activation_fn=act_fn,
152
+ )
153
+ for _ in range(n_blocks)
154
+ ]
155
+ )
156
+ downsample = (
157
+ Downsample1D(output_channel) if not is_last else
158
+ CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
159
+ )
160
+ self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
161
+
162
+ for _ in range(num_mid_blocks):
163
+ input_channel = channels[-1]
164
+ out_channels = channels[-1]
165
+ resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
166
+ ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
167
+
168
+ transformer_blocks = nn.ModuleList(
169
+ [
170
+ BasicTransformerBlock(
171
+ dim=output_channel,
172
+ num_attention_heads=num_heads,
173
+ attention_head_dim=attention_head_dim,
174
+ dropout=dropout,
175
+ activation_fn=act_fn,
176
+ )
177
+ for _ in range(n_blocks)
178
+ ]
179
+ )
180
+
181
+ self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
182
+
183
+ channels = channels[::-1] + (channels[0],)
184
+ for i in range(len(channels) - 1):
185
+ input_channel = channels[i] * 2
186
+ output_channel = channels[i + 1]
187
+ is_last = i == len(channels) - 2
188
+ resnet = CausalResnetBlock1D(
189
+ dim=input_channel,
190
+ dim_out=output_channel,
191
+ time_emb_dim=time_embed_dim,
192
+ ) if self.causal else ResnetBlock1D(
193
+ dim=input_channel,
194
+ dim_out=output_channel,
195
+ time_emb_dim=time_embed_dim,
196
+ )
197
+ transformer_blocks = nn.ModuleList(
198
+ [
199
+ BasicTransformerBlock(
200
+ dim=output_channel,
201
+ num_attention_heads=num_heads,
202
+ attention_head_dim=attention_head_dim,
203
+ dropout=dropout,
204
+ activation_fn=act_fn,
205
+ )
206
+ for _ in range(n_blocks)
207
+ ]
208
+ )
209
+ upsample = (
210
+ Upsample1D(output_channel, use_conv_transpose=True)
211
+ if not is_last
212
+ else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
213
+ )
214
+ self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
215
+ self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
216
+ self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
217
+ self.initialize_weights()
218
+
219
+ def initialize_weights(self):
220
+ for m in self.modules():
221
+ if isinstance(m, nn.Conv1d):
222
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
223
+ if m.bias is not None:
224
+ nn.init.constant_(m.bias, 0)
225
+ elif isinstance(m, nn.GroupNorm):
226
+ nn.init.constant_(m.weight, 1)
227
+ nn.init.constant_(m.bias, 0)
228
+ elif isinstance(m, nn.Linear):
229
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
230
+ if m.bias is not None:
231
+ nn.init.constant_(m.bias, 0)
232
+
233
+ def forward(self, x, mask, mu, t, spks=None, cond=None):
234
+ """Forward pass of the UNet1DConditional model.
235
+
236
+ Args:
237
+ x (torch.Tensor): shape (batch_size, in_channels, time)
238
+ mask (_type_): shape (batch_size, 1, time)
239
+ t (_type_): shape (batch_size)
240
+ spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
241
+ cond (_type_, optional): placeholder for future use. Defaults to None.
242
+
243
+ Raises:
244
+ ValueError: _description_
245
+ ValueError: _description_
246
+
247
+ Returns:
248
+ _type_: _description_
249
+ """
250
+
251
+ t = self.time_embeddings(t).to(t.dtype)
252
+ t = self.time_mlp(t)
253
+
254
+ x = pack([x, mu], "b * t")[0]
255
+
256
+ if spks is not None:
257
+ spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
258
+ x = pack([x, spks], "b * t")[0]
259
+ if cond is not None:
260
+ x = pack([x, cond], "b * t")[0]
261
+
262
+ hiddens = []
263
+ masks = [mask]
264
+ for resnet, transformer_blocks, downsample in self.down_blocks:
265
+ mask_down = masks[-1]
266
+ x = resnet(x, mask_down, t)
267
+ x = rearrange(x, "b c t -> b t c").contiguous()
268
+ # attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
269
+ attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
270
+ attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
271
+ for transformer_block in transformer_blocks:
272
+ x = transformer_block(
273
+ hidden_states=x,
274
+ attention_mask=attn_mask,
275
+ timestep=t,
276
+ )
277
+ x = rearrange(x, "b t c -> b c t").contiguous()
278
+ hiddens.append(x) # Save hidden states for skip connections
279
+ x = downsample(x * mask_down)
280
+ masks.append(mask_down[:, :, ::2])
281
+ masks = masks[:-1]
282
+ mask_mid = masks[-1]
283
+
284
+ for resnet, transformer_blocks in self.mid_blocks:
285
+ x = resnet(x, mask_mid, t)
286
+ x = rearrange(x, "b c t -> b t c").contiguous()
287
+ # attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
288
+ attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
289
+ attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
290
+ for transformer_block in transformer_blocks:
291
+ x = transformer_block(
292
+ hidden_states=x,
293
+ attention_mask=attn_mask,
294
+ timestep=t,
295
+ )
296
+ x = rearrange(x, "b t c -> b c t").contiguous()
297
+
298
+ for resnet, transformer_blocks, upsample in self.up_blocks:
299
+ mask_up = masks.pop()
300
+ skip = hiddens.pop()
301
+ x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
302
+ x = resnet(x, mask_up, t)
303
+ x = rearrange(x, "b c t -> b t c").contiguous()
304
+ # attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
305
+ attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
306
+ attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
307
+ for transformer_block in transformer_blocks:
308
+ x = transformer_block(
309
+ hidden_states=x,
310
+ attention_mask=attn_mask,
311
+ timestep=t,
312
+ )
313
+ x = rearrange(x, "b t c -> b c t").contiguous()
314
+ x = upsample(x * mask_up)
315
+ x = self.final_block(x, mask_up)
316
+ output = self.final_proj(x * mask_up)
317
+ return output * mask
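The causality of CausalConv1d above comes entirely from its asymmetric padding: kernel_size - 1 zeros on the left and none on the right, so each output frame depends only on the current and past input frames. A standalone illustration of that padding, independent of the classes above:

    # Standalone check that left-only padding keeps a Conv1d causal.
    import torch
    import torch.nn.functional as F

    kernel_size = 3
    conv = torch.nn.Conv1d(1, 1, kernel_size, padding=0, bias=False)
    x = torch.randn(1, 1, 10)

    # Pad (kernel_size - 1) frames on the left only, as CausalConv1d.forward does.
    y = conv(F.pad(x, (kernel_size - 1, 0)))
    assert y.shape[-1] == x.shape[-1]          # same length as the input

    # Perturbing the last input frame leaves all earlier outputs untouched.
    x2 = x.clone()
    x2[..., -1] += 1.0
    y2 = conv(F.pad(x2, (kernel_size - 1, 0)))
    assert torch.allclose(y[..., :-1], y2[..., :-1])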
src/chatterbox/models/s3gen/f0_predictor.py ADDED
@@ -0,0 +1,55 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+ import torch.nn as nn
16
+ from torch.nn.utils.parametrizations import weight_norm
17
+
18
+
19
+ class ConvRNNF0Predictor(nn.Module):
20
+ def __init__(self,
21
+ num_class: int = 1,
22
+ in_channels: int = 80,
23
+ cond_channels: int = 512
24
+ ):
25
+ super().__init__()
26
+
27
+ self.num_class = num_class
28
+ self.condnet = nn.Sequential(
29
+ weight_norm(
30
+ nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
31
+ ),
32
+ nn.ELU(),
33
+ weight_norm(
34
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
35
+ ),
36
+ nn.ELU(),
37
+ weight_norm(
38
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
39
+ ),
40
+ nn.ELU(),
41
+ weight_norm(
42
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
43
+ ),
44
+ nn.ELU(),
45
+ weight_norm(
46
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
47
+ ),
48
+ nn.ELU(),
49
+ )
50
+ self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
51
+
52
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
53
+ x = self.condnet(x)
54
+ x = x.transpose(1, 2)
55
+ return torch.abs(self.classifier(x).squeeze(-1))
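ConvRNNF0Predictor maps an 80-bin mel spectrogram of shape (batch, 80, frames) to one non-negative F0 value per frame, shape (batch, frames); the convolutions are padded so the frame count is preserved. A quick shape check, assuming the repository root is on the Python path:

    # Shape check for the F0 predictor (run from the repo root so `src` is importable).
    import torch
    from src.chatterbox.models.s3gen.f0_predictor import ConvRNNF0Predictor

    predictor = ConvRNNF0Predictor(num_class=1, in_channels=80, cond_channels=512)
    mel = torch.randn(2, 80, 120)   # (batch, mel bins, frames)
    f0 = predictor(mel)
    print(f0.shape)                 # torch.Size([2, 120]); torch.abs keeps values non-negative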
src/chatterbox/models/s3gen/flow.py ADDED
@@ -0,0 +1,290 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import logging
15
+ import random
16
+ from typing import Dict, Optional
17
+
18
+ logger = logging.getLogger(__name__)
19
+ import torch
20
+ import torch.nn as nn
21
+ from torch.nn import functional as F
22
+ from .utils.mask import make_pad_mask
23
+ from .configs import CFM_PARAMS
24
+
25
+
26
+ class MaskedDiffWithXvec(torch.nn.Module):
27
+ def __init__(
28
+ self,
29
+ input_size: int = 512,
30
+ output_size: int = 80,
31
+ spk_embed_dim: int = 192,
32
+ output_type: str = "mel",
33
+ vocab_size: int = 4096,
34
+ input_frame_rate: int = 50,
35
+ only_mask_loss: bool = True,
36
+ encoder: torch.nn.Module = None,
37
+ length_regulator: torch.nn.Module = None,
38
+ decoder: torch.nn.Module = None,
39
+ decoder_conf: Dict = {
40
+ 'in_channels': 240,
41
+ 'out_channel': 80,
42
+ 'spk_emb_dim': 80,
43
+ 'n_spks': 1,
44
+ 'cfm_params': CFM_PARAMS,
45
+ 'decoder_params': {
46
+ 'channels': [256, 256],
47
+ 'dropout': 0.0,
48
+ 'attention_head_dim': 64,
49
+ 'n_blocks': 4,
50
+ 'num_mid_blocks': 12,
51
+ 'num_heads': 8,
52
+ 'act_fn': 'gelu',
53
+ }
54
+ },
55
+ mel_feat_conf: Dict = {
56
+ 'n_fft': 1024,
57
+ 'num_mels': 80,
58
+ 'sampling_rate': 22050,
59
+ 'hop_size': 256,
60
+ 'win_size': 1024,
61
+ 'fmin': 0,
62
+ 'fmax': 8000
63
+ }
64
+ ):
65
+ super().__init__()
66
+ self.input_size = input_size
67
+ self.output_size = output_size
68
+ self.decoder_conf = decoder_conf
69
+ self.mel_feat_conf = mel_feat_conf
70
+ self.vocab_size = vocab_size
71
+ self.output_type = output_type
72
+ self.input_frame_rate = input_frame_rate
73
+ logging.info(f"input frame rate={self.input_frame_rate}")
74
+ self.input_embedding = nn.Embedding(vocab_size, input_size)
75
+ self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
76
+ self.encoder = encoder
77
+ self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
78
+ self.decoder = decoder
79
+ self.length_regulator = length_regulator
80
+ self.only_mask_loss = only_mask_loss
81
+
82
+ def forward(
83
+ self,
84
+ batch: dict,
85
+ device: torch.device,
86
+ ) -> Dict[str, Optional[torch.Tensor]]:
87
+ token = batch['speech_token'].to(device)
88
+ token_len = batch['speech_token_len'].to(device)
89
+ feat = batch['speech_feat'].to(device)
90
+ feat_len = batch['speech_feat_len'].to(device)
91
+ embedding = batch['embedding'].to(device)
92
+
93
+ # xvec projection
94
+ embedding = F.normalize(embedding, dim=1)
95
+ embedding = self.spk_embed_affine_layer(embedding)
96
+
97
+ # concat text and prompt_text
98
+ mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
99
+ token = self.input_embedding(torch.clamp(token, min=0, max=self.input_embedding.num_embeddings-1)) * mask
100
+
101
+ # text encode
102
+ h, h_lengths = self.encoder(token, token_len)
103
+ h = self.encoder_proj(h)
104
+ h, h_lengths = self.length_regulator(h, feat_len)
105
+
106
+ # get conditions
107
+ conds = torch.zeros(feat.shape, device=token.device)
108
+ for i, j in enumerate(feat_len):
109
+ if random.random() < 0.5:
110
+ continue
111
+ index = random.randint(0, int(0.3 * j))
112
+ conds[i, :index] = feat[i, :index]
113
+ conds = conds.transpose(1, 2)
114
+
115
+ mask = (~make_pad_mask(feat_len)).to(h)
116
+ feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
117
+ loss, _ = self.decoder.compute_loss(
118
+ feat.transpose(1, 2).contiguous(),
119
+ mask.unsqueeze(1),
120
+ h.transpose(1, 2).contiguous(),
121
+ embedding,
122
+ cond=conds
123
+ )
124
+ return {'loss': loss}
125
+
126
+ @torch.inference_mode()
127
+ def inference(self,
128
+ token,
129
+ token_len,
130
+ prompt_token,
131
+ prompt_token_len,
132
+ prompt_feat,
133
+ prompt_feat_len,
134
+ embedding,
135
+ flow_cache):
136
+ if self.fp16 is True:
137
+ prompt_feat = prompt_feat.half()
138
+ embedding = embedding.half()
139
+
140
+ assert token.shape[0] == 1
141
+ # xvec projection
142
+ embedding = F.normalize(embedding, dim=1)
143
+ embedding = self.spk_embed_affine_layer(embedding)
144
+
145
+ # concat text and prompt_text
146
+ token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
147
+ token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
148
+ mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
149
+
150
+ # Check for out-of-bounds token IDs
151
+ vocab_size = self.input_embedding.num_embeddings
152
+ if token.max() >= vocab_size or token.min() < 0:
153
+ logging.warning(f"S3Gen: Token IDs out of bounds: min={token.min().item()}, max={token.max().item()}, vocab_size={vocab_size}")
154
+
155
+ token = self.input_embedding(torch.clamp(token, min=0, max=vocab_size-1)) * mask
156
+
157
+ # text encode
158
+ h, h_lengths = self.encoder(token, token_len)
159
+ h = self.encoder_proj(h)
160
+ mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
161
+ h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
162
+
163
+ # get conditions
164
+ conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
165
+ conds[:, :mel_len1] = prompt_feat
166
+ conds = conds.transpose(1, 2)
167
+
168
+ mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
169
+ feat, flow_cache = self.decoder(
170
+ mu=h.transpose(1, 2).contiguous(),
171
+ mask=mask.unsqueeze(1),
172
+ spks=embedding,
173
+ cond=conds,
174
+ n_timesteps=10,
175
+ prompt_len=mel_len1,
176
+ flow_cache=flow_cache
177
+ )
178
+ feat = feat[:, :, mel_len1:]
179
+ assert feat.shape[2] == mel_len2
180
+ return feat.float(), flow_cache
181
+
182
+
183
+ class CausalMaskedDiffWithXvec(torch.nn.Module):
184
+ def __init__(
185
+ self,
186
+ input_size: int = 512,
187
+ output_size: int = 80,
188
+ spk_embed_dim: int = 192,
189
+ output_type: str = "mel",
190
+ vocab_size: int = 6561,
191
+ input_frame_rate: int = 25,
192
+ only_mask_loss: bool = True,
193
+ token_mel_ratio: int = 2,
194
+ pre_lookahead_len: int = 3,
195
+ encoder: torch.nn.Module = None,
196
+ decoder: torch.nn.Module = None,
197
+ decoder_conf: Dict = {
198
+ 'in_channels': 240,
199
+ 'out_channel': 80,
200
+ 'spk_emb_dim': 80,
201
+ 'n_spks': 1,
202
+ 'cfm_params': CFM_PARAMS,
203
+ 'decoder_params': {
204
+ 'channels': [256, 256],
205
+ 'dropout': 0.0,
206
+ 'attention_head_dim': 64,
207
+ 'n_blocks': 4,
208
+ 'num_mid_blocks': 12,
209
+ 'num_heads': 8,
210
+ 'act_fn': 'gelu',
211
+ }
212
+ },
213
+ mel_feat_conf: Dict = {
214
+ 'n_fft': 1024,
215
+ 'num_mels': 80,
216
+ 'sampling_rate': 22050,
217
+ 'hop_size': 256,
218
+ 'win_size': 1024,
219
+ 'fmin': 0,
220
+ 'fmax': 8000
221
+ }
222
+ ):
223
+ super().__init__()
224
+ self.input_size = input_size
225
+ self.output_size = output_size
226
+ self.decoder_conf = decoder_conf
227
+ self.mel_feat_conf = mel_feat_conf
228
+ self.vocab_size = vocab_size
229
+ self.output_type = output_type
230
+ self.input_frame_rate = input_frame_rate
231
+ logging.info(f"input frame rate={self.input_frame_rate}")
232
+ self.input_embedding = nn.Embedding(vocab_size, input_size)
233
+ self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
234
+ self.encoder = encoder
235
+ self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
236
+ self.decoder = decoder
237
+ self.only_mask_loss = only_mask_loss
238
+ self.token_mel_ratio = token_mel_ratio
239
+ self.pre_lookahead_len = pre_lookahead_len
240
+
241
+ # FIXME: this was missing - just putting it in as false
242
+ self.fp16 = False
243
+
244
+ @torch.inference_mode()
245
+ def inference(self,
246
+ token,
247
+ token_len,
248
+ prompt_token,
249
+ prompt_token_len,
250
+ prompt_feat,
251
+ prompt_feat_len,
252
+ embedding,
253
+ finalize):
254
+ if self.fp16 is True:
255
+ prompt_feat = prompt_feat.half()
256
+ embedding = embedding.half()
257
+
258
+ assert token.shape[0] == 1
259
+ # xvec projection
260
+ embedding = F.normalize(embedding, dim=1)
261
+ embedding = self.spk_embed_affine_layer(embedding)
262
+
263
+ # concat text and prompt_text
264
+ token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
265
+ mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
266
+ token = self.input_embedding(torch.clamp(token, min=0, max=self.input_embedding.num_embeddings-1)) * mask
267
+
268
+ # text encode
269
+ h, h_lengths = self.encoder(token, token_len)
270
+ if finalize is False:
271
+ h = h[:, :-self.pre_lookahead_len * self.token_mel_ratio]
272
+ mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
273
+ h = self.encoder_proj(h)
274
+
275
+ # get conditions
276
+ conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
277
+ conds[:, :mel_len1] = prompt_feat
278
+ conds = conds.transpose(1, 2)
279
+
280
+ mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
281
+ feat, _ = self.decoder(
282
+ mu=h.transpose(1, 2).contiguous(),
283
+ mask=mask.unsqueeze(1),
284
+ spks=embedding,
285
+ cond=conds,
286
+ n_timesteps=10
287
+ )
288
+ feat = feat[:, :, mel_len1:]
289
+ assert feat.shape[2] == mel_len2
290
+ return feat.float(), None # NOTE jrm: why are they returning None here?
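In MaskedDiffWithXvec.inference, the target mel length is derived from the token count: tokens arrive at input_frame_rate per second while mel frames are produced at 22050 / 256 per second (the hop size from mel_feat_conf). A worked example of that conversion; the causal variant above instead takes the length directly from the encoder output and token_mel_ratio:

    # Worked example of mel_len2 = int(token_len2 / self.input_frame_rate * 22050 / 256)
    input_frame_rate = 50      # MaskedDiffWithXvec default
    token_len2 = 100           # 100 speech tokens -> 2 seconds of audio
    mel_len2 = int(token_len2 / input_frame_rate * 22050 / 256)
    print(mel_len2)            # 172 mel frames (~86.13 frames per second)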
src/chatterbox/models/s3gen/flow_matching.py ADDED
@@ -0,0 +1,218 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import threading
15
+ import torch
16
+ import torch.nn.functional as F
17
+ from .matcha.flow_matching import BASECFM
18
+ from .configs import CFM_PARAMS
19
+
20
+
21
+ class ConditionalCFM(BASECFM):
22
+ def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
23
+ super().__init__(
24
+ n_feats=in_channels,
25
+ cfm_params=cfm_params,
26
+ n_spks=n_spks,
27
+ spk_emb_dim=spk_emb_dim,
28
+ )
29
+ self.t_scheduler = cfm_params.t_scheduler
30
+ self.training_cfg_rate = cfm_params.training_cfg_rate
31
+ self.inference_cfg_rate = cfm_params.inference_cfg_rate
32
+ in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
33
+ # Just change the architecture of the estimator here
34
+ self.estimator = estimator
35
+ self.lock = threading.Lock()
36
+
37
+ @torch.inference_mode()
38
+ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
39
+ """Forward diffusion
40
+
41
+ Args:
42
+ mu (torch.Tensor): output of encoder
43
+ shape: (batch_size, n_feats, mel_timesteps)
44
+ mask (torch.Tensor): output_mask
45
+ shape: (batch_size, 1, mel_timesteps)
46
+ n_timesteps (int): number of diffusion steps
47
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
48
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
49
+ shape: (batch_size, spk_emb_dim)
50
+ cond: Not used but kept for future purposes
51
+
52
+ Returns:
53
+ sample: generated mel-spectrogram
54
+ shape: (batch_size, n_feats, mel_timesteps)
55
+ """
56
+
57
+ z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
58
+ cache_size = flow_cache.shape[2]
59
+ # fix prompt and overlap part mu and z
60
+ if cache_size != 0:
61
+ z[:, :, :cache_size] = flow_cache[:, :, :, 0]
62
+ mu[:, :, :cache_size] = flow_cache[:, :, :, 1]
63
+ z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
64
+ mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
65
+ flow_cache = torch.stack([z_cache, mu_cache], dim=-1)
66
+
67
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
68
+ if self.t_scheduler == 'cosine':
69
+ t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
70
+ return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache
71
+
72
+ def solve_euler(self, x, t_span, mu, mask, spks, cond):
73
+ """
74
+ Fixed euler solver for ODEs.
75
+ Args:
76
+ x (torch.Tensor): random noise
77
+ t_span (torch.Tensor): n_timesteps interpolated
78
+ shape: (n_timesteps + 1,)
79
+ mu (torch.Tensor): output of encoder
80
+ shape: (batch_size, n_feats, mel_timesteps)
81
+ mask (torch.Tensor): output_mask
82
+ shape: (batch_size, 1, mel_timesteps)
83
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
84
+ shape: (batch_size, spk_emb_dim)
85
+ cond: Not used but kept for future purposes
86
+ """
87
+ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
88
+ t = t.unsqueeze(dim=0)
89
+
90
+ # I am storing this because I can later plot it by putting a debugger here and saving it to a file
91
+ # Or in future might add like a return_all_steps flag
92
+ sol = []
93
+
94
+ # Do not use concat, it may cause memory format changed and trt infer with wrong results!
95
+ x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
96
+ mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
97
+ mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
98
+ t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
99
+ spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
100
+ cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
101
+ for step in range(1, len(t_span)):
102
+ # Classifier-Free Guidance inference introduced in VoiceBox
103
+ x_in[:] = x
104
+ mask_in[:] = mask
105
+ mu_in[0] = mu
106
+ t_in[:] = t.unsqueeze(0)
107
+ spks_in[0] = spks
108
+ cond_in[0] = cond
109
+ dphi_dt = self.forward_estimator(
110
+ x_in, mask_in,
111
+ mu_in, t_in,
112
+ spks_in,
113
+ cond_in
114
+ )
115
+ dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
116
+ dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
117
+ x = x + dt * dphi_dt
118
+ t = t + dt
119
+ sol.append(x)
120
+ if step < len(t_span) - 1:
121
+ dt = t_span[step + 1] - t
122
+
123
+ return sol[-1].float()
124
+
125
+ def forward_estimator(self, x, mask, mu, t, spks, cond):
126
+ if isinstance(self.estimator, torch.nn.Module):
127
+ return self.estimator.forward(x, mask, mu, t, spks, cond)
128
+ else:
129
+ with self.lock:
130
+ self.estimator.set_input_shape('x', (2, 80, x.size(2)))
131
+ self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
132
+ self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
133
+ self.estimator.set_input_shape('t', (2,))
134
+ self.estimator.set_input_shape('spks', (2, 80))
135
+ self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
136
+ # run trt engine
137
+ self.estimator.execute_v2([x.contiguous().data_ptr(),
138
+ mask.contiguous().data_ptr(),
139
+ mu.contiguous().data_ptr(),
140
+ t.contiguous().data_ptr(),
141
+ spks.contiguous().data_ptr(),
142
+ cond.contiguous().data_ptr(),
143
+ x.data_ptr()])
144
+ return x
145
+
146
+ def compute_loss(self, x1, mask, mu, spks=None, cond=None):
147
+ """Computes diffusion loss
148
+
149
+ Args:
150
+ x1 (torch.Tensor): Target
151
+ shape: (batch_size, n_feats, mel_timesteps)
152
+ mask (torch.Tensor): target mask
153
+ shape: (batch_size, 1, mel_timesteps)
154
+ mu (torch.Tensor): output of encoder
155
+ shape: (batch_size, n_feats, mel_timesteps)
156
+ spks (torch.Tensor, optional): speaker embedding. Defaults to None.
157
+ shape: (batch_size, spk_emb_dim)
158
+
159
+ Returns:
160
+ loss: conditional flow matching loss
161
+ y: conditional flow
162
+ shape: (batch_size, n_feats, mel_timesteps)
163
+ """
164
+ b, _, t = mu.shape
165
+
166
+ # random timestep
167
+ t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
168
+ if self.t_scheduler == 'cosine':
169
+ t = 1 - torch.cos(t * 0.5 * torch.pi)
170
+ # sample noise p(x_0)
171
+ z = torch.randn_like(x1)
172
+
173
+ y = (1 - (1 - self.sigma_min) * t) * z + t * x1
174
+ u = x1 - (1 - self.sigma_min) * z
175
+
176
+ # during training, we randomly drop condition to trade off mode coverage and sample fidelity
177
+ if self.training_cfg_rate > 0:
178
+ cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
179
+ mu = mu * cfg_mask.view(-1, 1, 1)
180
+ spks = spks * cfg_mask.view(-1, 1)
181
+ cond = cond * cfg_mask.view(-1, 1, 1)
182
+
183
+ pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
184
+ loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
185
+ return loss, y
186
+
187
+
188
+ class CausalConditionalCFM(ConditionalCFM):
189
+ def __init__(self, in_channels=240, cfm_params=CFM_PARAMS, n_spks=1, spk_emb_dim=80, estimator=None):
190
+ super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
191
+ self.rand_noise = torch.randn([1, 80, 50 * 300])
192
+
193
+ @torch.inference_mode()
194
+ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
195
+ """Forward diffusion
196
+
197
+ Args:
198
+ mu (torch.Tensor): output of encoder
199
+ shape: (batch_size, n_feats, mel_timesteps)
200
+ mask (torch.Tensor): output_mask
201
+ shape: (batch_size, 1, mel_timesteps)
202
+ n_timesteps (int): number of diffusion steps
203
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
204
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
205
+ shape: (batch_size, spk_emb_dim)
206
+ cond: Not used but kept for future purposes
207
+
208
+ Returns:
209
+ sample: generated mel-spectrogram
210
+ shape: (batch_size, n_feats, mel_timesteps)
211
+ """
212
+
213
+ z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature
214
+ # fix prompt and overlap part mu and z
215
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
216
+ if self.t_scheduler == 'cosine':
217
+ t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
218
+ return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None
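Both CFM variants share the same sampling loop: a uniform time grid is warped by the cosine scheduler (smaller steps near t = 0) and then integrated with fixed-step Euler updates. A standalone sketch of just that loop, with a trivial stand-in for the estimator's vector field:

    # Standalone sketch of the cosine schedule + Euler stepping used by solve_euler.
    import torch

    n_timesteps = 10
    t_span = torch.linspace(0, 1, n_timesteps + 1)
    t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)   # cosine warp: finer steps early on

    x = torch.randn(1, 80, 50)                        # stands in for the noise z
    t, dt = t_span[0], t_span[1] - t_span[0]
    for step in range(1, len(t_span)):
        dphi_dt = -x                                  # dummy vector field; the model predicts this
        x = x + dt * dphi_dt                          # Euler update
        t = t + dt
        if step < len(t_span) - 1:
            dt = t_span[step + 1] - t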
src/chatterbox/models/s3gen/hifigan.py ADDED
@@ -0,0 +1,474 @@
1
+ # jrm: adapted from CosyVoice/cosyvoice/hifigan/generator.py
2
+ # most modules should be reusable, but I found their SineGen changed a git.
3
+
4
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """HIFI-GAN"""
19
+
20
+ from typing import Dict, Optional, List
21
+ import numpy as np
22
+ from scipy.signal import get_window
23
+ import torch
24
+ import torch.nn.functional as F
25
+ from torch.nn import Conv1d
26
+ from torch.nn import ConvTranspose1d
27
+ from torch.nn.utils import remove_weight_norm
28
+ from torch.nn.utils.parametrizations import weight_norm
29
+ from torch.distributions.uniform import Uniform
30
+ from torch import nn, sin, pow
31
+ from torch.nn import Parameter
32
+
33
+
34
+ class Snake(nn.Module):
35
+ '''
36
+ Implementation of a sine-based periodic activation function
37
+ Shape:
38
+ - Input: (B, C, T)
39
+ - Output: (B, C, T), same shape as the input
40
+ Parameters:
41
+ - alpha - trainable parameter
42
+ References:
43
+ - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
44
+ https://arxiv.org/abs/2006.08195
45
+ Examples:
46
+ >>> a1 = Snake(256)
47
+ >>> x = torch.randn(256)
48
+ >>> x = a1(x)
49
+ '''
50
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
51
+ '''
52
+ Initialization.
53
+ INPUT:
54
+ - in_features: shape of the input
55
+ - alpha: trainable parameter
56
+ alpha is initialized to 1 by default, higher values = higher-frequency.
57
+ alpha will be trained along with the rest of your model.
58
+ '''
59
+ super(Snake, self).__init__()
60
+ self.in_features = in_features
61
+
62
+ # initialize alpha
63
+ self.alpha_logscale = alpha_logscale
64
+ if self.alpha_logscale: # log scale alphas initialized to zeros
65
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
66
+ else: # linear scale alphas initialized to ones
67
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
68
+
69
+ self.alpha.requires_grad = alpha_trainable
70
+
71
+ self.no_div_by_zero = 0.000000001
72
+
73
+ def forward(self, x):
74
+ '''
75
+ Forward pass of the function.
76
+ Applies the function to the input elementwise.
77
+ Snake ∶= x + 1/a * sin^2 (xa)
78
+ '''
79
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
80
+ if self.alpha_logscale:
81
+ alpha = torch.exp(alpha)
82
+ x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
83
+
84
+ return x
85
+
86
+
87
+
88
+ def get_padding(kernel_size, dilation=1):
89
+ return int((kernel_size * dilation - dilation) / 2)
90
+
91
+ def init_weights(m, mean=0.0, std=0.01):
92
+ classname = m.__class__.__name__
93
+ if classname.find("Conv") != -1:
94
+ m.weight.data.normal_(mean, std)
95
+
96
+
97
+ """hifigan based generator implementation.
98
+
99
+ This code is modified from https://github.com/jik876/hifi-gan
100
+ ,https://github.com/kan-bayashi/ParallelWaveGAN and
101
+ https://github.com/NVIDIA/BigVGAN
102
+
103
+ """
104
+
105
+
106
+ class ResBlock(torch.nn.Module):
107
+ """Residual block module in HiFiGAN/BigVGAN."""
108
+ def __init__(
109
+ self,
110
+ channels: int = 512,
111
+ kernel_size: int = 3,
112
+ dilations: List[int] = [1, 3, 5],
113
+ ):
114
+ super(ResBlock, self).__init__()
115
+ self.convs1 = nn.ModuleList()
116
+ self.convs2 = nn.ModuleList()
117
+
118
+ for dilation in dilations:
119
+ self.convs1.append(
120
+ weight_norm(
121
+ Conv1d(
122
+ channels,
123
+ channels,
124
+ kernel_size,
125
+ 1,
126
+ dilation=dilation,
127
+ padding=get_padding(kernel_size, dilation)
128
+ )
129
+ )
130
+ )
131
+ self.convs2.append(
132
+ weight_norm(
133
+ Conv1d(
134
+ channels,
135
+ channels,
136
+ kernel_size,
137
+ 1,
138
+ dilation=1,
139
+ padding=get_padding(kernel_size, 1)
140
+ )
141
+ )
142
+ )
143
+ self.convs1.apply(init_weights)
144
+ self.convs2.apply(init_weights)
145
+ self.activations1 = nn.ModuleList([
146
+ Snake(channels, alpha_logscale=False)
147
+ for _ in range(len(self.convs1))
148
+ ])
149
+ self.activations2 = nn.ModuleList([
150
+ Snake(channels, alpha_logscale=False)
151
+ for _ in range(len(self.convs2))
152
+ ])
153
+
154
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
155
+ for idx in range(len(self.convs1)):
156
+ xt = self.activations1[idx](x)
157
+ xt = self.convs1[idx](xt)
158
+ xt = self.activations2[idx](xt)
159
+ xt = self.convs2[idx](xt)
160
+ x = xt + x
161
+ return x
162
+
163
+ def remove_weight_norm(self):
164
+ for idx in range(len(self.convs1)):
165
+ remove_weight_norm(self.convs1[idx])
166
+ remove_weight_norm(self.convs2[idx])
167
+
168
+
169
+ class SineGen(torch.nn.Module):
170
+ """ Definition of sine generator
171
+ SineGen(samp_rate, harmonic_num = 0,
172
+ sine_amp = 0.1, noise_std = 0.003,
173
+ voiced_threshold = 0,
174
+ flag_for_pulse=False)
175
+ samp_rate: sampling rate in Hz
176
+ harmonic_num: number of harmonic overtones (default 0)
177
+ sine_amp: amplitude of sine waveform (default 0.1)
178
+ noise_std: std of Gaussian noise (default 0.003)
179
+ voiced_threshold: F0 threshold for U/V classification (default 0)
180
+ flag_for_pulse: this SineGen is used inside PulseGen (default False)
181
+ Note: when flag_for_pulse is True, the first time step of a voiced
182
+ segment is always sin(np.pi) or cos(0)
183
+ """
184
+
185
+ def __init__(self, samp_rate, harmonic_num=0,
186
+ sine_amp=0.1, noise_std=0.003,
187
+ voiced_threshold=0):
188
+ super(SineGen, self).__init__()
189
+ self.sine_amp = sine_amp
190
+ self.noise_std = noise_std
191
+ self.harmonic_num = harmonic_num
192
+ self.sampling_rate = samp_rate
193
+ self.voiced_threshold = voiced_threshold
194
+
195
+ def _f02uv(self, f0):
196
+ # generate uv signal
197
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
198
+ return uv
199
+
200
+ @torch.no_grad()
201
+ def forward(self, f0):
202
+ """
203
+ :param f0: [B, 1, sample_len], Hz
204
+ :return: [B, 1, sample_len]
205
+ """
206
+
207
+ F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
208
+ for i in range(self.harmonic_num + 1):
209
+ F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
210
+
211
+ theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
212
+ u_dist = Uniform(low=-np.pi, high=np.pi)
213
+ phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
214
+ phase_vec[:, 0, :] = 0
215
+
216
+ # generate sine waveforms
217
+ sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)
218
+
219
+ # generate uv signal
220
+ uv = self._f02uv(f0)
221
+
222
+ # noise: for unvoiced regions the amplitude should be similar to sine_amp
223
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
224
+ # for voiced regions the std is self.noise_std
225
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
226
+ noise = noise_amp * torch.randn_like(sine_waves)
227
+
228
+ # first: set the unvoiced part to 0 by uv
229
+ # then: additive noise
230
+ sine_waves = sine_waves * uv + noise
231
+ return sine_waves, uv, noise
232
+
233
+
234
+ class SourceModuleHnNSF(torch.nn.Module):
235
+ """ SourceModule for hn-nsf
236
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
237
+ add_noise_std=0.003, voiced_threshod=0)
238
+ sampling_rate: sampling_rate in Hz
239
+ harmonic_num: number of harmonic above F0 (default: 0)
240
+ sine_amp: amplitude of sine source signal (default: 0.1)
241
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
242
+ note that amplitude of noise in unvoiced is decided
243
+ by sine_amp
244
+ voiced_threshold: threshold to set U/V given F0 (default: 0)
245
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
246
+ F0_sampled (batchsize, length, 1)
247
+ Sine_source (batchsize, length, 1)
248
+ noise_source (batchsize, length, 1)
249
+ uv (batchsize, length, 1)
250
+ """
251
+
252
+ def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
253
+ add_noise_std=0.003, voiced_threshod=0):
254
+ super(SourceModuleHnNSF, self).__init__()
255
+
256
+ self.sine_amp = sine_amp
257
+ self.noise_std = add_noise_std
258
+
259
+ # to produce sine waveforms
260
+ self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
261
+ sine_amp, add_noise_std, voiced_threshod)
262
+
263
+ # to merge source harmonics into a single excitation
264
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
265
+ self.l_tanh = torch.nn.Tanh()
266
+
267
+ def forward(self, x):
268
+ """
269
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
270
+ F0_sampled (batchsize, length, 1)
271
+ Sine_source (batchsize, length, 1)
272
+ noise_source (batchsize, length, 1)
273
+ """
274
+ # source for harmonic branch
275
+ with torch.no_grad():
276
+ sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
277
+ sine_wavs = sine_wavs.transpose(1, 2)
278
+ uv = uv.transpose(1, 2)
279
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
280
+
281
+ # source for noise branch, in the same shape as uv
282
+ noise = torch.randn_like(uv) * self.sine_amp / 3
283
+ return sine_merge, noise, uv
284
+
285
+
286
+ class HiFTGenerator(nn.Module):
287
+ """
288
+ HiFTNet Generator: Neural Source Filter + ISTFTNet
289
+ https://arxiv.org/abs/2309.09493
290
+ """
291
+ def __init__(
292
+ self,
293
+ in_channels: int = 80,
294
+ base_channels: int = 512,
295
+ nb_harmonics: int = 8,
296
+ sampling_rate: int = 22050,
297
+ nsf_alpha: float = 0.1,
298
+ nsf_sigma: float = 0.003,
299
+ nsf_voiced_threshold: float = 10,
300
+ upsample_rates: List[int] = [8, 8],
301
+ upsample_kernel_sizes: List[int] = [16, 16],
302
+ istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
303
+ resblock_kernel_sizes: List[int] = [3, 7, 11],
304
+ resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
305
+ source_resblock_kernel_sizes: List[int] = [7, 11],
306
+ source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
307
+ lrelu_slope: float = 0.1,
308
+ audio_limit: float = 0.99,
309
+ f0_predictor: torch.nn.Module = None,
310
+ ):
311
+ super(HiFTGenerator, self).__init__()
312
+
313
+ self.out_channels = 1
314
+ self.nb_harmonics = nb_harmonics
315
+ self.sampling_rate = sampling_rate
316
+ self.istft_params = istft_params
317
+ self.lrelu_slope = lrelu_slope
318
+ self.audio_limit = audio_limit
319
+
320
+ self.num_kernels = len(resblock_kernel_sizes)
321
+ self.num_upsamples = len(upsample_rates)
322
+ self.m_source = SourceModuleHnNSF(
323
+ sampling_rate=sampling_rate,
324
+ upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
325
+ harmonic_num=nb_harmonics,
326
+ sine_amp=nsf_alpha,
327
+ add_noise_std=nsf_sigma,
328
+ voiced_threshod=nsf_voiced_threshold)
329
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
330
+
331
+ self.conv_pre = weight_norm(
332
+ Conv1d(in_channels, base_channels, 7, 1, padding=3)
333
+ )
334
+
335
+ # Up
336
+ self.ups = nn.ModuleList()
337
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
338
+ self.ups.append(
339
+ weight_norm(
340
+ ConvTranspose1d(
341
+ base_channels // (2**i),
342
+ base_channels // (2**(i + 1)),
343
+ k,
344
+ u,
345
+ padding=(k - u) // 2,
346
+ )
347
+ )
348
+ )
349
+
350
+ # Down
351
+ self.source_downs = nn.ModuleList()
352
+ self.source_resblocks = nn.ModuleList()
353
+ downsample_rates = [1] + upsample_rates[::-1][:-1]
354
+ downsample_cum_rates = np.cumprod(downsample_rates)
355
+ for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
356
+ if u == 1:
357
+ self.source_downs.append(
358
+ Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
359
+ )
360
+ else:
361
+ self.source_downs.append(
362
+ Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
363
+ )
364
+
365
+ self.source_resblocks.append(
366
+ ResBlock(base_channels // (2 ** (i + 1)), k, d)
367
+ )
368
+
369
+ self.resblocks = nn.ModuleList()
370
+ for i in range(len(self.ups)):
371
+ ch = base_channels // (2**(i + 1))
372
+ for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
373
+ self.resblocks.append(ResBlock(ch, k, d))
374
+
375
+ self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
376
+ self.ups.apply(init_weights)
377
+ self.conv_post.apply(init_weights)
378
+ self.reflection_pad = nn.ReflectionPad1d((1, 0))
379
+ self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
380
+ self.f0_predictor = f0_predictor
381
+
382
+ def remove_weight_norm(self):
383
+ print('Removing weight norm...')
384
+ for l in self.ups:
385
+ remove_weight_norm(l)
386
+ for l in self.resblocks:
387
+ l.remove_weight_norm()
388
+ remove_weight_norm(self.conv_pre)
389
+ remove_weight_norm(self.conv_post)
390
+ self.m_source.remove_weight_norm()
391
+ for l in self.source_downs:
392
+ remove_weight_norm(l)
393
+ for l in self.source_resblocks:
394
+ l.remove_weight_norm()
395
+
396
+ def _stft(self, x):
397
+ spec = torch.stft(
398
+ x,
399
+ self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
400
+ return_complex=True)
401
+ spec = torch.view_as_real(spec) # [B, F, TT, 2]
402
+ return spec[..., 0], spec[..., 1]
403
+
404
+ def _istft(self, magnitude, phase):
405
+ magnitude = torch.clip(magnitude, max=1e2)
406
+ real = magnitude * torch.cos(phase)
407
+ img = magnitude * torch.sin(phase)
408
+ inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
409
+ self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
410
+ return inverse_transform
411
+
412
+ def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
413
+ s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
414
+ s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
415
+
416
+ x = self.conv_pre(x)
417
+ for i in range(self.num_upsamples):
418
+ x = F.leaky_relu(x, self.lrelu_slope)
419
+ x = self.ups[i](x)
420
+
421
+ if i == self.num_upsamples - 1:
422
+ x = self.reflection_pad(x)
423
+
424
+ # fusion
425
+ si = self.source_downs[i](s_stft)
426
+ si = self.source_resblocks[i](si)
427
+ x = x + si
428
+
429
+ xs = None
430
+ for j in range(self.num_kernels):
431
+ if xs is None:
432
+ xs = self.resblocks[i * self.num_kernels + j](x)
433
+ else:
434
+ xs += self.resblocks[i * self.num_kernels + j](x)
435
+ x = xs / self.num_kernels
436
+
437
+ x = F.leaky_relu(x)
438
+ x = self.conv_post(x)
439
+ magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
440
+ phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundant
441
+
442
+ x = self._istft(magnitude, phase)
443
+ x = torch.clamp(x, -self.audio_limit, self.audio_limit)
444
+ return x
445
+
446
+ def forward(
447
+ self,
448
+ batch: dict,
449
+ device: torch.device,
450
+ ) -> Dict[str, Optional[torch.Tensor]]:
451
+ speech_feat = batch['speech_feat'].transpose(1, 2).to(device)
452
+ # mel->f0
453
+ f0 = self.f0_predictor(speech_feat)
454
+ # f0->source
455
+ s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
456
+ s, _, _ = self.m_source(s)
457
+ s = s.transpose(1, 2)
458
+ # mel+source->speech
459
+ generated_speech = self.decode(x=speech_feat, s=s)
460
+ return generated_speech, f0
461
+
462
+ @torch.inference_mode()
463
+ def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
464
+ # mel->f0
465
+ f0 = self.f0_predictor(speech_feat)
466
+ # f0->source
467
+ s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
468
+ s, _, _ = self.m_source(s)
469
+ s = s.transpose(1, 2)
470
+ # use cache_source to avoid glitch
471
+ if cache_source.shape[2] != 0:
472
+ s[:, :, :cache_source.shape[2]] = cache_source
473
+ generated_speech = self.decode(x=speech_feat, s=s)
474
+ return generated_speech, s
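The Snake activation used throughout the ResBlocks above is just the input plus a scaled squared sine. A minimal functional restatement (not code from this repo; snake here is a hypothetical helper, and the module's per-channel trainable alpha is collapsed to one scalar for brevity):

import torch

def snake(x: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    # Snake(x) = x + (1/alpha) * sin^2(alpha * x); the module above also adds a
    # tiny epsilon to alpha before dividing and learns one alpha per channel.
    return x + (1.0 / alpha) * torch.sin(alpha * x) ** 2

x = torch.randn(2, 4, 16)   # (B, C, T), the layout the generator blocks expect
y = snake(x)
assert y.shape == x.shape   # elementwise activation, so the shape is preserved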
src/chatterbox/models/s3gen/matcha/decoder.py ADDED
@@ -0,0 +1,443 @@
1
+ import math
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from conformer import ConformerBlock
8
+ from diffusers.models.activations import get_activation
9
+ from einops import pack, rearrange, repeat
10
+
11
+ from .transformer import BasicTransformerBlock
12
+
13
+
14
+ class SinusoidalPosEmb(torch.nn.Module):
15
+ def __init__(self, dim):
16
+ super().__init__()
17
+ self.dim = dim
18
+ assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
19
+
20
+ def forward(self, x, scale=1000):
21
+ if x.ndim < 1:
22
+ x = x.unsqueeze(0)
23
+ device = x.device
24
+ half_dim = self.dim // 2
25
+ emb = math.log(10000) / (half_dim - 1)
26
+ emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
27
+ emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
28
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
29
+ return emb
30
+
31
+
32
+ class Block1D(torch.nn.Module):
33
+ def __init__(self, dim, dim_out, groups=8):
34
+ super().__init__()
35
+ self.block = torch.nn.Sequential(
36
+ torch.nn.Conv1d(dim, dim_out, 3, padding=1),
37
+ torch.nn.GroupNorm(groups, dim_out),
38
+ nn.Mish(),
39
+ )
40
+
41
+ def forward(self, x, mask):
42
+ output = self.block(x * mask)
43
+ return output * mask
44
+
45
+
46
+ class ResnetBlock1D(torch.nn.Module):
47
+ def __init__(self, dim, dim_out, time_emb_dim, groups=8):
48
+ super().__init__()
49
+ self.mlp = torch.nn.Sequential(nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out))
50
+
51
+ self.block1 = Block1D(dim, dim_out, groups=groups)
52
+ self.block2 = Block1D(dim_out, dim_out, groups=groups)
53
+
54
+ self.res_conv = torch.nn.Conv1d(dim, dim_out, 1)
55
+
56
+ def forward(self, x, mask, time_emb):
57
+ h = self.block1(x, mask)
58
+ h += self.mlp(time_emb).unsqueeze(-1)
59
+ h = self.block2(h, mask)
60
+ output = h + self.res_conv(x * mask)
61
+ return output
62
+
63
+
64
+ class Downsample1D(nn.Module):
65
+ def __init__(self, dim):
66
+ super().__init__()
67
+ self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1)
68
+
69
+ def forward(self, x):
70
+ return self.conv(x)
71
+
72
+
73
+ class TimestepEmbedding(nn.Module):
74
+ def __init__(
75
+ self,
76
+ in_channels: int,
77
+ time_embed_dim: int,
78
+ act_fn: str = "silu",
79
+ out_dim: int = None,
80
+ post_act_fn: Optional[str] = None,
81
+ cond_proj_dim=None,
82
+ ):
83
+ super().__init__()
84
+
85
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim)
86
+
87
+ if cond_proj_dim is not None:
88
+ self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
89
+ else:
90
+ self.cond_proj = None
91
+
92
+ self.act = get_activation(act_fn)
93
+
94
+ if out_dim is not None:
95
+ time_embed_dim_out = out_dim
96
+ else:
97
+ time_embed_dim_out = time_embed_dim
98
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
99
+
100
+ if post_act_fn is None:
101
+ self.post_act = None
102
+ else:
103
+ self.post_act = get_activation(post_act_fn)
104
+
105
+ def forward(self, sample, condition=None):
106
+ if condition is not None:
107
+ sample = sample + self.cond_proj(condition)
108
+ sample = self.linear_1(sample)
109
+
110
+ if self.act is not None:
111
+ sample = self.act(sample)
112
+
113
+ sample = self.linear_2(sample)
114
+
115
+ if self.post_act is not None:
116
+ sample = self.post_act(sample)
117
+ return sample
118
+
119
+
120
+ class Upsample1D(nn.Module):
121
+ """A 1D upsampling layer with an optional convolution.
122
+
123
+ Parameters:
124
+ channels (`int`):
125
+ number of channels in the inputs and outputs.
126
+ use_conv (`bool`, default `False`):
127
+ option to use a convolution.
128
+ use_conv_transpose (`bool`, default `False`):
129
+ option to use a convolution transpose.
130
+ out_channels (`int`, optional):
131
+ number of output channels. Defaults to `channels`.
132
+ """
133
+
134
+ def __init__(self, channels, use_conv=False, use_conv_transpose=True, out_channels=None, name="conv"):
135
+ super().__init__()
136
+ self.channels = channels
137
+ self.out_channels = out_channels or channels
138
+ self.use_conv = use_conv
139
+ self.use_conv_transpose = use_conv_transpose
140
+ self.name = name
141
+
142
+ self.conv = None
143
+ if use_conv_transpose:
144
+ self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
145
+ elif use_conv:
146
+ self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
147
+
148
+ def forward(self, inputs):
149
+ assert inputs.shape[1] == self.channels
150
+ if self.use_conv_transpose:
151
+ return self.conv(inputs)
152
+
153
+ outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
154
+
155
+ if self.use_conv:
156
+ outputs = self.conv(outputs)
157
+
158
+ return outputs
159
+
160
+
161
+ class ConformerWrapper(ConformerBlock):
162
+ def __init__( # pylint: disable=useless-super-delegation
163
+ self,
164
+ *,
165
+ dim,
166
+ dim_head=64,
167
+ heads=8,
168
+ ff_mult=4,
169
+ conv_expansion_factor=2,
170
+ conv_kernel_size=31,
171
+ attn_dropout=0,
172
+ ff_dropout=0,
173
+ conv_dropout=0,
174
+ conv_causal=False,
175
+ ):
176
+ super().__init__(
177
+ dim=dim,
178
+ dim_head=dim_head,
179
+ heads=heads,
180
+ ff_mult=ff_mult,
181
+ conv_expansion_factor=conv_expansion_factor,
182
+ conv_kernel_size=conv_kernel_size,
183
+ attn_dropout=attn_dropout,
184
+ ff_dropout=ff_dropout,
185
+ conv_dropout=conv_dropout,
186
+ conv_causal=conv_causal,
187
+ )
188
+
189
+ def forward(
190
+ self,
191
+ hidden_states,
192
+ attention_mask,
193
+ encoder_hidden_states=None,
194
+ encoder_attention_mask=None,
195
+ timestep=None,
196
+ ):
197
+ return super().forward(x=hidden_states, mask=attention_mask.bool())
198
+
199
+
200
+ class Decoder(nn.Module):
201
+ def __init__(
202
+ self,
203
+ in_channels,
204
+ out_channels,
205
+ channels=(256, 256),
206
+ dropout=0.05,
207
+ attention_head_dim=64,
208
+ n_blocks=1,
209
+ num_mid_blocks=2,
210
+ num_heads=4,
211
+ act_fn="snake",
212
+ down_block_type="transformer",
213
+ mid_block_type="transformer",
214
+ up_block_type="transformer",
215
+ ):
216
+ super().__init__()
217
+ channels = tuple(channels)
218
+ self.in_channels = in_channels
219
+ self.out_channels = out_channels
220
+
221
+ self.time_embeddings = SinusoidalPosEmb(in_channels)
222
+ time_embed_dim = channels[0] * 4
223
+ self.time_mlp = TimestepEmbedding(
224
+ in_channels=in_channels,
225
+ time_embed_dim=time_embed_dim,
226
+ act_fn="silu",
227
+ )
228
+
229
+ self.down_blocks = nn.ModuleList([])
230
+ self.mid_blocks = nn.ModuleList([])
231
+ self.up_blocks = nn.ModuleList([])
232
+
233
+ output_channel = in_channels
234
+ for i in range(len(channels)): # pylint: disable=consider-using-enumerate
235
+ input_channel = output_channel
236
+ output_channel = channels[i]
237
+ is_last = i == len(channels) - 1
238
+ resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
239
+ transformer_blocks = nn.ModuleList(
240
+ [
241
+ self.get_block(
242
+ down_block_type,
243
+ output_channel,
244
+ attention_head_dim,
245
+ num_heads,
246
+ dropout,
247
+ act_fn,
248
+ )
249
+ for _ in range(n_blocks)
250
+ ]
251
+ )
252
+ downsample = (
253
+ Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
254
+ )
255
+
256
+ self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
257
+
258
+ for i in range(num_mid_blocks):
259
+ input_channel = channels[-1]
260
+ out_channels = channels[-1]
261
+
262
+ resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
263
+
264
+ transformer_blocks = nn.ModuleList(
265
+ [
266
+ self.get_block(
267
+ mid_block_type,
268
+ output_channel,
269
+ attention_head_dim,
270
+ num_heads,
271
+ dropout,
272
+ act_fn,
273
+ )
274
+ for _ in range(n_blocks)
275
+ ]
276
+ )
277
+
278
+ self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
279
+
280
+ channels = channels[::-1] + (channels[0],)
281
+ for i in range(len(channels) - 1):
282
+ input_channel = channels[i]
283
+ output_channel = channels[i + 1]
284
+ is_last = i == len(channels) - 2
285
+
286
+ resnet = ResnetBlock1D(
287
+ dim=2 * input_channel,
288
+ dim_out=output_channel,
289
+ time_emb_dim=time_embed_dim,
290
+ )
291
+ transformer_blocks = nn.ModuleList(
292
+ [
293
+ self.get_block(
294
+ up_block_type,
295
+ output_channel,
296
+ attention_head_dim,
297
+ num_heads,
298
+ dropout,
299
+ act_fn,
300
+ )
301
+ for _ in range(n_blocks)
302
+ ]
303
+ )
304
+ upsample = (
305
+ Upsample1D(output_channel, use_conv_transpose=True)
306
+ if not is_last
307
+ else nn.Conv1d(output_channel, output_channel, 3, padding=1)
308
+ )
309
+
310
+ self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
311
+
312
+ self.final_block = Block1D(channels[-1], channels[-1])
313
+ self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
314
+
315
+ self.initialize_weights()
316
+ # nn.init.normal_(self.final_proj.weight)
317
+
318
+ @staticmethod
319
+ def get_block(block_type, dim, attention_head_dim, num_heads, dropout, act_fn):
320
+ if block_type == "conformer":
321
+ block = ConformerWrapper(
322
+ dim=dim,
323
+ dim_head=attention_head_dim,
324
+ heads=num_heads,
325
+ ff_mult=1,
326
+ conv_expansion_factor=2,
327
+ ff_dropout=dropout,
328
+ attn_dropout=dropout,
329
+ conv_dropout=dropout,
330
+ conv_kernel_size=31,
331
+ )
332
+ elif block_type == "transformer":
333
+ block = BasicTransformerBlock(
334
+ dim=dim,
335
+ num_attention_heads=num_heads,
336
+ attention_head_dim=attention_head_dim,
337
+ dropout=dropout,
338
+ activation_fn=act_fn,
339
+ )
340
+ else:
341
+ raise ValueError(f"Unknown block type {block_type}")
342
+
343
+ return block
344
+
345
+ def initialize_weights(self):
346
+ for m in self.modules():
347
+ if isinstance(m, nn.Conv1d):
348
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
349
+
350
+ if m.bias is not None:
351
+ nn.init.constant_(m.bias, 0)
352
+
353
+ elif isinstance(m, nn.GroupNorm):
354
+ nn.init.constant_(m.weight, 1)
355
+ nn.init.constant_(m.bias, 0)
356
+
357
+ elif isinstance(m, nn.Linear):
358
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
359
+
360
+ if m.bias is not None:
361
+ nn.init.constant_(m.bias, 0)
362
+
363
+ def forward(self, x, mask, mu, t, spks=None, cond=None):
364
+ """Forward pass of the UNet1DConditional model.
365
+
366
+ Args:
367
+ x (torch.Tensor): shape (batch_size, in_channels, time)
368
+ mask (torch.Tensor): shape (batch_size, 1, time)
369
+ t (torch.Tensor): timestep, shape (batch_size,)
370
+ spks (torch.Tensor, optional): shape (batch_size, condition_channels). Defaults to None.
371
+ cond (_type_, optional): placeholder for future use. Defaults to None.
372
+
373
+ Raises:
374
+ ValueError: _description_
375
+ ValueError: _description_
376
+
377
+ Returns:
378
+ output (torch.Tensor): shape (batch_size, out_channels, time)
379
+ """
380
+
381
+ t = self.time_embeddings(t)
382
+ t = self.time_mlp(t)
383
+
384
+ x = pack([x, mu], "b * t")[0]
385
+
386
+ if spks is not None:
387
+ spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
388
+ x = pack([x, spks], "b * t")[0]
389
+
390
+ hiddens = []
391
+ masks = [mask]
392
+ for resnet, transformer_blocks, downsample in self.down_blocks:
393
+ mask_down = masks[-1]
394
+ x = resnet(x, mask_down, t)
395
+ x = rearrange(x, "b c t -> b t c")
396
+ mask_down = rearrange(mask_down, "b 1 t -> b t")
397
+ for transformer_block in transformer_blocks:
398
+ x = transformer_block(
399
+ hidden_states=x,
400
+ attention_mask=mask_down,
401
+ timestep=t,
402
+ )
403
+ x = rearrange(x, "b t c -> b c t")
404
+ mask_down = rearrange(mask_down, "b t -> b 1 t")
405
+ hiddens.append(x) # Save hidden states for skip connections
406
+ x = downsample(x * mask_down)
407
+ masks.append(mask_down[:, :, ::2])
408
+
409
+ masks = masks[:-1]
410
+ mask_mid = masks[-1]
411
+
412
+ for resnet, transformer_blocks in self.mid_blocks:
413
+ x = resnet(x, mask_mid, t)
414
+ x = rearrange(x, "b c t -> b t c")
415
+ mask_mid = rearrange(mask_mid, "b 1 t -> b t")
416
+ for transformer_block in transformer_blocks:
417
+ x = transformer_block(
418
+ hidden_states=x,
419
+ attention_mask=mask_mid,
420
+ timestep=t,
421
+ )
422
+ x = rearrange(x, "b t c -> b c t")
423
+ mask_mid = rearrange(mask_mid, "b t -> b 1 t")
424
+
425
+ for resnet, transformer_blocks, upsample in self.up_blocks:
426
+ mask_up = masks.pop()
427
+ x = resnet(pack([x, hiddens.pop()], "b * t")[0], mask_up, t)
428
+ x = rearrange(x, "b c t -> b t c")
429
+ mask_up = rearrange(mask_up, "b 1 t -> b t")
430
+ for transformer_block in transformer_blocks:
431
+ x = transformer_block(
432
+ hidden_states=x,
433
+ attention_mask=mask_up,
434
+ timestep=t,
435
+ )
436
+ x = rearrange(x, "b t c -> b c t")
437
+ mask_up = rearrange(mask_up, "b t -> b 1 t")
438
+ x = upsample(x * mask_up)
439
+
440
+ x = self.final_block(x, mask_up)
441
+ output = self.final_proj(x * mask_up)
442
+
443
+ return output * mask
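The Decoder above conditions every block on a timestep embedding: SinusoidalPosEmb maps a scalar t to a dim-dimensional sin/cos vector, which TimestepEmbedding then projects through a small MLP. A minimal sketch of the embedding step alone (sinusoidal_pos_emb is a hypothetical helper; the batch size and dim=80 are illustrative, dim only needs to be even):

import math
import torch

def sinusoidal_pos_emb(t: torch.Tensor, dim: int, scale: float = 1000.0) -> torch.Tensor:
    half_dim = dim // 2
    freqs = torch.exp(torch.arange(half_dim, dtype=torch.float32)
                      * -(math.log(10000) / (half_dim - 1)))
    angles = scale * t.float().unsqueeze(1) * freqs.unsqueeze(0)
    return torch.cat((angles.sin(), angles.cos()), dim=-1)

t = torch.rand(4)                    # one flow time per batch element
emb = sinusoidal_pos_emb(t, dim=80)
print(emb.shape)                     # torch.Size([4, 80])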
src/chatterbox/models/s3gen/matcha/flow_matching.py ADDED
@@ -0,0 +1,129 @@
1
+ from abc import ABC
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ from .decoder import Decoder
7
+
8
+
9
+ class BASECFM(torch.nn.Module, ABC):
10
+ def __init__(
11
+ self,
12
+ n_feats,
13
+ cfm_params,
14
+ n_spks=1,
15
+ spk_emb_dim=128,
16
+ ):
17
+ super().__init__()
18
+ self.n_feats = n_feats
19
+ self.n_spks = n_spks
20
+ self.spk_emb_dim = spk_emb_dim
21
+ self.solver = cfm_params.solver
22
+ if hasattr(cfm_params, "sigma_min"):
23
+ self.sigma_min = cfm_params.sigma_min
24
+ else:
25
+ self.sigma_min = 1e-4
26
+
27
+ self.estimator = None
28
+
29
+ @torch.inference_mode()
30
+ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
31
+ """Forward diffusion
32
+
33
+ Args:
34
+ mu (torch.Tensor): output of encoder
35
+ shape: (batch_size, n_feats, mel_timesteps)
36
+ mask (torch.Tensor): output_mask
37
+ shape: (batch_size, 1, mel_timesteps)
38
+ n_timesteps (int): number of diffusion steps
39
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
40
+ spks (torch.Tensor, optional): speaker embedding. Defaults to None.
41
+ shape: (batch_size, spk_emb_dim)
42
+ cond: Not used but kept for future purposes
43
+
44
+ Returns:
45
+ sample: generated mel-spectrogram
46
+ shape: (batch_size, n_feats, mel_timesteps)
47
+ """
48
+ z = torch.randn_like(mu) * temperature
49
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
50
+ return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
51
+
52
+ def solve_euler(self, x, t_span, mu, mask, spks, cond):
53
+ """
54
+ Fixed-step Euler solver for ODEs.
55
+ Args:
56
+ x (torch.Tensor): random noise
57
+ t_span (torch.Tensor): n_timesteps interpolated
58
+ shape: (n_timesteps + 1,)
59
+ mu (torch.Tensor): output of encoder
60
+ shape: (batch_size, n_feats, mel_timesteps)
61
+ mask (torch.Tensor): output_mask
62
+ shape: (batch_size, 1, mel_timesteps)
63
+ spks (torch.Tensor, optional): speaker embedding. Defaults to None.
64
+ shape: (batch_size, spk_emb_dim)
65
+ cond: Not used but kept for future purposes
66
+ """
67
+ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
68
+
69
+ # I am storing this because I can later plot it by putting a debugger here and saving it to a file
70
+ # Or in future might add like a return_all_steps flag
71
+ sol = []
72
+
73
+ for step in range(1, len(t_span)):
74
+ dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
75
+
76
+ x = x + dt * dphi_dt
77
+ t = t + dt
78
+ sol.append(x)
79
+ if step < len(t_span) - 1:
80
+ dt = t_span[step + 1] - t
81
+
82
+ return sol[-1]
83
+
84
+ def compute_loss(self, x1, mask, mu, spks=None, cond=None):
85
+ """Computes diffusion loss
86
+
87
+ Args:
88
+ x1 (torch.Tensor): Target
89
+ shape: (batch_size, n_feats, mel_timesteps)
90
+ mask (torch.Tensor): target mask
91
+ shape: (batch_size, 1, mel_timesteps)
92
+ mu (torch.Tensor): output of encoder
93
+ shape: (batch_size, n_feats, mel_timesteps)
94
+ spks (torch.Tensor, optional): speaker embedding. Defaults to None.
95
+ shape: (batch_size, spk_emb_dim)
96
+
97
+ Returns:
98
+ loss: conditional flow matching loss
99
+ y: conditional flow
100
+ shape: (batch_size, n_feats, mel_timesteps)
101
+ """
102
+ b, _, t = mu.shape
103
+
104
+ # random timestep
105
+ t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
106
+ # sample noise p(x_0)
107
+ z = torch.randn_like(x1)
108
+
109
+ y = (1 - (1 - self.sigma_min) * t) * z + t * x1
110
+ u = x1 - (1 - self.sigma_min) * z
111
+
112
+ loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / (
113
+ torch.sum(mask) * u.shape[1]
114
+ )
115
+ return loss, y
116
+
117
+
118
+ class CFM(BASECFM):
119
+ def __init__(self, in_channels, out_channel, cfm_params, decoder_params, n_spks=1, spk_emb_dim=64):
120
+ super().__init__(
121
+ n_feats=in_channels,
122
+ cfm_params=cfm_params,
123
+ n_spks=n_spks,
124
+ spk_emb_dim=spk_emb_dim,
125
+ )
126
+
127
+ in_channels = in_channels + (spk_emb_dim if n_spks > 1 else 0)
128
+ # Just change the architecture of the estimator here
129
+ self.estimator = Decoder(in_channels=in_channels, out_channels=out_channel, **decoder_params)
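compute_loss above implements optimal-transport conditional flow matching: noise z is interpolated toward the target x1 at a random time t, and the estimator regresses onto a velocity that does not depend on t. A minimal sketch of just those two formulas, with made-up tensor sizes:

import torch

sigma_min = 1e-4
x1 = torch.randn(2, 80, 100)   # target mel: (batch_size, n_feats, mel_timesteps)
z = torch.randn_like(x1)       # noise sample x_0
t = torch.rand(2, 1, 1)        # one random timestep per batch element

x_t = (1 - (1 - sigma_min) * t) * z + t * x1   # the point y fed to the estimator
u = x1 - (1 - sigma_min) * z                   # regression target, constant in t
print(x_t.shape, u.shape)                      # both torch.Size([2, 80, 100])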
src/chatterbox/models/s3gen/matcha/text_encoder.py ADDED
@@ -0,0 +1,413 @@
1
+ """ from https://github.com/jaywalnut310/glow-tts """
2
+
3
+ import math
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from einops import rearrange
8
+
9
+
10
+ def sequence_mask(length, max_length=None):
11
+ if max_length is None:
12
+ max_length = length.max()
13
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
14
+ return x.unsqueeze(0) < length.unsqueeze(1)
15
+
16
+
17
+
18
+ class LayerNorm(nn.Module):
19
+ def __init__(self, channels, eps=1e-4):
20
+ super().__init__()
21
+ self.channels = channels
22
+ self.eps = eps
23
+
24
+ self.gamma = torch.nn.Parameter(torch.ones(channels))
25
+ self.beta = torch.nn.Parameter(torch.zeros(channels))
26
+
27
+ def forward(self, x):
28
+ n_dims = len(x.shape)
29
+ mean = torch.mean(x, 1, keepdim=True)
30
+ variance = torch.mean((x - mean) ** 2, 1, keepdim=True)
31
+
32
+ x = (x - mean) * torch.rsqrt(variance + self.eps)
33
+
34
+ shape = [1, -1] + [1] * (n_dims - 2)
35
+ x = x * self.gamma.view(*shape) + self.beta.view(*shape)
36
+ return x
37
+
38
+
39
+ class ConvReluNorm(nn.Module):
40
+ def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
41
+ super().__init__()
42
+ self.in_channels = in_channels
43
+ self.hidden_channels = hidden_channels
44
+ self.out_channels = out_channels
45
+ self.kernel_size = kernel_size
46
+ self.n_layers = n_layers
47
+ self.p_dropout = p_dropout
48
+
49
+ self.conv_layers = torch.nn.ModuleList()
50
+ self.norm_layers = torch.nn.ModuleList()
51
+ self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
52
+ self.norm_layers.append(LayerNorm(hidden_channels))
53
+ self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout))
54
+ for _ in range(n_layers - 1):
55
+ self.conv_layers.append(
56
+ torch.nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)
57
+ )
58
+ self.norm_layers.append(LayerNorm(hidden_channels))
59
+ self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)
60
+ self.proj.weight.data.zero_()
61
+ self.proj.bias.data.zero_()
62
+
63
+ def forward(self, x, x_mask):
64
+ x_org = x
65
+ for i in range(self.n_layers):
66
+ x = self.conv_layers[i](x * x_mask)
67
+ x = self.norm_layers[i](x)
68
+ x = self.relu_drop(x)
69
+ x = x_org + self.proj(x)
70
+ return x * x_mask
71
+
72
+
73
+ class DurationPredictor(nn.Module):
74
+ def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
75
+ super().__init__()
76
+ self.in_channels = in_channels
77
+ self.filter_channels = filter_channels
78
+ self.p_dropout = p_dropout
79
+
80
+ self.drop = torch.nn.Dropout(p_dropout)
81
+ self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
82
+ self.norm_1 = LayerNorm(filter_channels)
83
+ self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
84
+ self.norm_2 = LayerNorm(filter_channels)
85
+ self.proj = torch.nn.Conv1d(filter_channels, 1, 1)
86
+
87
+ def forward(self, x, x_mask):
88
+ x = self.conv_1(x * x_mask)
89
+ x = torch.relu(x)
90
+ x = self.norm_1(x)
91
+ x = self.drop(x)
92
+ x = self.conv_2(x * x_mask)
93
+ x = torch.relu(x)
94
+ x = self.norm_2(x)
95
+ x = self.drop(x)
96
+ x = self.proj(x * x_mask)
97
+ return x * x_mask
98
+
99
+
100
+ class RotaryPositionalEmbeddings(nn.Module):
101
+ """
102
+ ## RoPE module
103
+
104
+ Rotary encoding transforms pairs of features by rotating in the 2D plane.
105
+ That is, it organizes the $d$ features as $\frac{d}{2}$ pairs.
106
+ Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it
107
+ by an angle depending on the position of the token.
108
+ """
109
+
110
+ def __init__(self, d: int, base: int = 10_000):
111
+ r"""
112
+ * `d` is the number of features $d$
113
+ * `base` is the constant used for calculating $\Theta$
114
+ """
115
+ super().__init__()
116
+
117
+ self.base = base
118
+ self.d = int(d)
119
+ self.cos_cached = None
120
+ self.sin_cached = None
121
+
122
+ def _build_cache(self, x: torch.Tensor):
123
+ r"""
124
+ Cache $\cos$ and $\sin$ values
125
+ """
126
+ # Return if cache is already built
127
+ if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
128
+ return
129
+
130
+ # Get sequence length
131
+ seq_len = x.shape[0]
132
+
133
+ # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
134
+ theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)
135
+
136
+ # Create position indexes `[0, 1, ..., seq_len - 1]`
137
+ seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)
138
+
139
+ # Calculate the product of position index and $\theta_i$
140
+ idx_theta = torch.einsum("n,d->nd", seq_idx, theta)
141
+
142
+ # Concatenate so that for row $m$ we have
143
+ # $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$
144
+ idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
145
+
146
+ # Cache them
147
+ self.cos_cached = idx_theta2.cos()[:, None, None, :]
148
+ self.sin_cached = idx_theta2.sin()[:, None, None, :]
149
+
150
+ def _neg_half(self, x: torch.Tensor):
151
+ # $\frac{d}{2}$
152
+ d_2 = self.d // 2
153
+
154
+ # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
155
+ return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
156
+
157
+ def forward(self, x: torch.Tensor):
158
+ """
159
+ * `x` is the Tensor at the head of a key or a query with shape `[batch_size, n_heads, seq_len, d]` (rearranged internally to `[seq_len, batch_size, n_heads, d]`)
160
+ """
161
+ # Cache $\cos$ and $\sin$ values
162
+ x = rearrange(x, "b h t d -> t b h d")
163
+
164
+ self._build_cache(x)
165
+
166
+ # Split the features, we can choose to apply rotary embeddings only to a partial set of features.
167
+ x_rope, x_pass = x[..., : self.d], x[..., self.d :]
168
+
169
+ # Calculate
170
+ # $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
171
+ neg_half_x = self._neg_half(x_rope)
172
+
173
+ x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (neg_half_x * self.sin_cached[: x.shape[0]])
174
+
175
+ return rearrange(torch.cat((x_rope, x_pass), dim=-1), "t b h d -> b h t d")
176
+
177
+
178
+ class MultiHeadAttention(nn.Module):
179
+ def __init__(
180
+ self,
181
+ channels,
182
+ out_channels,
183
+ n_heads,
184
+ heads_share=True,
185
+ p_dropout=0.0,
186
+ proximal_bias=False,
187
+ proximal_init=False,
188
+ ):
189
+ super().__init__()
190
+ assert channels % n_heads == 0
191
+
192
+ self.channels = channels
193
+ self.out_channels = out_channels
194
+ self.n_heads = n_heads
195
+ self.heads_share = heads_share
196
+ self.proximal_bias = proximal_bias
197
+ self.p_dropout = p_dropout
198
+ self.attn = None
199
+
200
+ self.k_channels = channels // n_heads
201
+ self.conv_q = torch.nn.Conv1d(channels, channels, 1)
202
+ self.conv_k = torch.nn.Conv1d(channels, channels, 1)
203
+ self.conv_v = torch.nn.Conv1d(channels, channels, 1)
204
+
205
+ # from https://nn.labml.ai/transformers/rope/index.html
206
+ self.query_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)
207
+ self.key_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)
208
+
209
+ self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
210
+ self.drop = torch.nn.Dropout(p_dropout)
211
+
212
+ torch.nn.init.xavier_uniform_(self.conv_q.weight)
213
+ torch.nn.init.xavier_uniform_(self.conv_k.weight)
214
+ if proximal_init:
215
+ self.conv_k.weight.data.copy_(self.conv_q.weight.data)
216
+ self.conv_k.bias.data.copy_(self.conv_q.bias.data)
217
+ torch.nn.init.xavier_uniform_(self.conv_v.weight)
218
+
219
+ def forward(self, x, c, attn_mask=None):
220
+ q = self.conv_q(x)
221
+ k = self.conv_k(c)
222
+ v = self.conv_v(c)
223
+
224
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
225
+
226
+ x = self.conv_o(x)
227
+ return x
228
+
229
+ def attention(self, query, key, value, mask=None):
230
+ b, d, t_s, t_t = (*key.size(), query.size(2))
231
+ query = rearrange(query, "b (h c) t-> b h t c", h=self.n_heads)
232
+ key = rearrange(key, "b (h c) t-> b h t c", h=self.n_heads)
233
+ value = rearrange(value, "b (h c) t-> b h t c", h=self.n_heads)
234
+
235
+ query = self.query_rotary_pe(query)
236
+ key = self.key_rotary_pe(key)
237
+
238
+ scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
239
+
240
+ if self.proximal_bias:
241
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
242
+ scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
243
+ if mask is not None:
244
+ scores = scores.masked_fill(mask == 0, -1e4)
245
+ p_attn = torch.nn.functional.softmax(scores, dim=-1)
246
+ p_attn = self.drop(p_attn)
247
+ output = torch.matmul(p_attn, value)
248
+ output = output.transpose(2, 3).contiguous().view(b, d, t_t)
249
+ return output, p_attn
250
+
251
+ @staticmethod
252
+ def _attention_bias_proximal(length):
253
+ r = torch.arange(length, dtype=torch.float32)
254
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
255
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
256
+
257
+
258
+ class FFN(nn.Module):
259
+ def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0):
260
+ super().__init__()
261
+ self.in_channels = in_channels
262
+ self.out_channels = out_channels
263
+ self.filter_channels = filter_channels
264
+ self.kernel_size = kernel_size
265
+ self.p_dropout = p_dropout
266
+
267
+ self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
268
+ self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2)
269
+ self.drop = torch.nn.Dropout(p_dropout)
270
+
271
+ def forward(self, x, x_mask):
272
+ x = self.conv_1(x * x_mask)
273
+ x = torch.relu(x)
274
+ x = self.drop(x)
275
+ x = self.conv_2(x * x_mask)
276
+ return x * x_mask
277
+
278
+
279
+ class Encoder(nn.Module):
280
+ def __init__(
281
+ self,
282
+ hidden_channels,
283
+ filter_channels,
284
+ n_heads,
285
+ n_layers,
286
+ kernel_size=1,
287
+ p_dropout=0.0,
288
+ **kwargs,
289
+ ):
290
+ super().__init__()
291
+ self.hidden_channels = hidden_channels
292
+ self.filter_channels = filter_channels
293
+ self.n_heads = n_heads
294
+ self.n_layers = n_layers
295
+ self.kernel_size = kernel_size
296
+ self.p_dropout = p_dropout
297
+
298
+ self.drop = torch.nn.Dropout(p_dropout)
299
+ self.attn_layers = torch.nn.ModuleList()
300
+ self.norm_layers_1 = torch.nn.ModuleList()
301
+ self.ffn_layers = torch.nn.ModuleList()
302
+ self.norm_layers_2 = torch.nn.ModuleList()
303
+ for _ in range(self.n_layers):
304
+ self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout))
305
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
306
+ self.ffn_layers.append(
307
+ FFN(
308
+ hidden_channels,
309
+ hidden_channels,
310
+ filter_channels,
311
+ kernel_size,
312
+ p_dropout=p_dropout,
313
+ )
314
+ )
315
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
316
+
317
+ def forward(self, x, x_mask):
318
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
319
+ for i in range(self.n_layers):
320
+ x = x * x_mask
321
+ y = self.attn_layers[i](x, x, attn_mask)
322
+ y = self.drop(y)
323
+ x = self.norm_layers_1[i](x + y)
324
+ y = self.ffn_layers[i](x, x_mask)
325
+ y = self.drop(y)
326
+ x = self.norm_layers_2[i](x + y)
327
+ x = x * x_mask
328
+ return x
329
+
330
+
331
+ class TextEncoder(nn.Module):
332
+ def __init__(
333
+ self,
334
+ encoder_type,
335
+ encoder_params,
336
+ duration_predictor_params,
337
+ n_vocab,
338
+ n_spks=1,
339
+ spk_emb_dim=128,
340
+ ):
341
+ super().__init__()
342
+ self.encoder_type = encoder_type
343
+ self.n_vocab = n_vocab
344
+ self.n_feats = encoder_params.n_feats
345
+ self.n_channels = encoder_params.n_channels
346
+ self.spk_emb_dim = spk_emb_dim
347
+ self.n_spks = n_spks
348
+
349
+ self.emb = torch.nn.Embedding(n_vocab, self.n_channels)
350
+ torch.nn.init.normal_(self.emb.weight, 0.0, self.n_channels**-0.5)
351
+
352
+ if encoder_params.prenet:
353
+ self.prenet = ConvReluNorm(
354
+ self.n_channels,
355
+ self.n_channels,
356
+ self.n_channels,
357
+ kernel_size=5,
358
+ n_layers=3,
359
+ p_dropout=0.5,
360
+ )
361
+ else:
362
+ self.prenet = lambda x, x_mask: x
363
+
364
+ self.encoder = Encoder(
365
+ encoder_params.n_channels + (spk_emb_dim if n_spks > 1 else 0),
366
+ encoder_params.filter_channels,
367
+ encoder_params.n_heads,
368
+ encoder_params.n_layers,
369
+ encoder_params.kernel_size,
370
+ encoder_params.p_dropout,
371
+ )
372
+
373
+ self.proj_m = torch.nn.Conv1d(self.n_channels + (spk_emb_dim if n_spks > 1 else 0), self.n_feats, 1)
374
+ self.proj_w = DurationPredictor(
375
+ self.n_channels + (spk_emb_dim if n_spks > 1 else 0),
376
+ duration_predictor_params.filter_channels_dp,
377
+ duration_predictor_params.kernel_size,
378
+ duration_predictor_params.p_dropout,
379
+ )
380
+
381
+ def forward(self, x, x_lengths, spks=None):
382
+ """Run forward pass to the transformer based encoder and duration predictor
383
+
384
+ Args:
385
+ x (torch.Tensor): text input
386
+ shape: (batch_size, max_text_length)
387
+ x_lengths (torch.Tensor): text input lengths
388
+ shape: (batch_size,)
389
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
390
+ shape: (batch_size,)
391
+
392
+ Returns:
393
+ mu (torch.Tensor): average output of the encoder
394
+ shape: (batch_size, n_feats, max_text_length)
395
+ logw (torch.Tensor): log duration predicted by the duration predictor
396
+ shape: (batch_size, 1, max_text_length)
397
+ x_mask (torch.Tensor): mask for the text input
398
+ shape: (batch_size, 1, max_text_length)
399
+ """
400
+ x = self.emb(x) * math.sqrt(self.n_channels)
401
+ x = torch.transpose(x, 1, -1)
402
+ x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
403
+
404
+ x = self.prenet(x, x_mask)
405
+ if self.n_spks > 1:
406
+ x = torch.cat([x, spks.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)
407
+ x = self.encoder(x, x_mask)
408
+ mu = self.proj_m(x) * x_mask
409
+
410
+ x_dp = torch.detach(x)
411
+ logw = self.proj_w(x_dp, x_mask)
412
+
413
+ return mu, logw, x_mask
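Both the Encoder and the DurationPredictor above rely on sequence_mask to zero out padded frames at every layer. A minimal sketch of what it produces for a batch of three utterances (the lengths are illustrative):

import torch

lengths = torch.tensor([5, 3, 1])
mask = torch.arange(lengths.max()).unsqueeze(0) < lengths.unsqueeze(1)
print(mask.int())
# tensor([[1, 1, 1, 1, 1],
#         [1, 1, 1, 0, 0],
#         [1, 0, 0, 0, 0]])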
src/chatterbox/models/s3gen/matcha/transformer.py ADDED
@@ -0,0 +1,316 @@
1
+ from typing import Any, Dict, Optional
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from diffusers.models.attention import (
6
+ GEGLU,
7
+ GELU,
8
+ AdaLayerNorm,
9
+ AdaLayerNormZero,
10
+ ApproximateGELU,
11
+ )
12
+ from diffusers.models.attention_processor import Attention
13
+ from diffusers.models.lora import LoRACompatibleLinear
14
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
15
+
16
+
17
+ class SnakeBeta(nn.Module):
18
+ """
19
+ A modified Snake function which uses separate parameters for the magnitude of the periodic components
20
+ Shape:
21
+ - Input: (B, C, T)
22
+ - Output: (B, C, T), same shape as the input
23
+ Parameters:
24
+ - alpha - trainable parameter that controls frequency
25
+ - beta - trainable parameter that controls magnitude
26
+ References:
27
+ - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
28
+ https://arxiv.org/abs/2006.08195
29
+ Examples:
30
+ >>> a1 = SnakeBeta(256, 256)
31
+ >>> x = torch.randn(256)
32
+ >>> x = a1(x)
33
+ """
34
+
35
+ def __init__(self, in_features, out_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
36
+ """
37
+ Initialization.
38
+ INPUT:
39
+ - in_features: shape of the input
40
+ - alpha - trainable parameter that controls frequency
41
+ - beta - trainable parameter that controls magnitude
42
+ alpha is initialized to 1 by default, higher values = higher-frequency.
43
+ beta is initialized to 1 by default, higher values = higher-magnitude.
44
+ alpha will be trained along with the rest of your model.
45
+ """
46
+ super().__init__()
47
+ self.in_features = out_features if isinstance(out_features, list) else [out_features]
48
+ self.proj = LoRACompatibleLinear(in_features, out_features)
49
+
50
+ # initialize alpha
51
+ self.alpha_logscale = alpha_logscale
52
+ if self.alpha_logscale: # log scale alphas initialized to zeros
53
+ self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha)
54
+ self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha)
55
+ else: # linear scale alphas initialized to ones
56
+ self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha)
57
+ self.beta = nn.Parameter(torch.ones(self.in_features) * alpha)
58
+
59
+ self.alpha.requires_grad = alpha_trainable
60
+ self.beta.requires_grad = alpha_trainable
61
+
62
+ self.no_div_by_zero = 0.000000001
63
+
64
+ def forward(self, x):
65
+ """
66
+ Forward pass of the function.
67
+ Applies the function to the input elementwise.
68
+ SnakeBeta ∶= x + 1/b * sin^2 (xa)
69
+ """
70
+ x = self.proj(x)
71
+ if self.alpha_logscale:
72
+ alpha = torch.exp(self.alpha)
73
+ beta = torch.exp(self.beta)
74
+ else:
75
+ alpha = self.alpha
76
+ beta = self.beta
77
+
78
+ x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2)
79
+
80
+ return x
81
+
82
+
83
+ class FeedForward(nn.Module):
84
+ r"""
85
+ A feed-forward layer.
86
+
87
+ Parameters:
88
+ dim (`int`): The number of channels in the input.
89
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
90
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
91
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
92
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
93
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
94
+ """
95
+
96
+ def __init__(
97
+ self,
98
+ dim: int,
99
+ dim_out: Optional[int] = None,
100
+ mult: int = 4,
101
+ dropout: float = 0.0,
102
+ activation_fn: str = "geglu",
103
+ final_dropout: bool = False,
104
+ ):
105
+ super().__init__()
106
+ inner_dim = int(dim * mult)
107
+ dim_out = dim_out if dim_out is not None else dim
108
+
109
+ if activation_fn == "gelu":
110
+ act_fn = GELU(dim, inner_dim)
111
+ if activation_fn == "gelu-approximate":
112
+ act_fn = GELU(dim, inner_dim, approximate="tanh")
113
+ elif activation_fn == "geglu":
114
+ act_fn = GEGLU(dim, inner_dim)
115
+ elif activation_fn == "geglu-approximate":
116
+ act_fn = ApproximateGELU(dim, inner_dim)
117
+ elif activation_fn == "snakebeta":
118
+ act_fn = SnakeBeta(dim, inner_dim)
119
+
120
+ self.net = nn.ModuleList([])
121
+ # project in
122
+ self.net.append(act_fn)
123
+ # project dropout
124
+ self.net.append(nn.Dropout(dropout))
125
+ # project out
126
+ self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
127
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
128
+ if final_dropout:
129
+ self.net.append(nn.Dropout(dropout))
130
+
131
+ def forward(self, hidden_states):
132
+ for module in self.net:
133
+ hidden_states = module(hidden_states)
134
+ return hidden_states
135
+
136
+
137
+ @maybe_allow_in_graph
138
+ class BasicTransformerBlock(nn.Module):
139
+ r"""
140
+ A basic Transformer block.
141
+
142
+ Parameters:
143
+ dim (`int`): The number of channels in the input and output.
144
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
145
+ attention_head_dim (`int`): The number of channels in each head.
146
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
147
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
148
+ only_cross_attention (`bool`, *optional*):
149
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
150
+ double_self_attention (`bool`, *optional*):
151
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
152
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
153
+ num_embeds_ada_norm (:
154
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
155
+ attention_bias (:
156
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
157
+ """
158
+
159
+ def __init__(
160
+ self,
161
+ dim: int,
162
+ num_attention_heads: int,
163
+ attention_head_dim: int,
164
+ dropout=0.0,
165
+ cross_attention_dim: Optional[int] = None,
166
+ activation_fn: str = "geglu",
167
+ num_embeds_ada_norm: Optional[int] = None,
168
+ attention_bias: bool = False,
169
+ only_cross_attention: bool = False,
170
+ double_self_attention: bool = False,
171
+ upcast_attention: bool = False,
172
+ norm_elementwise_affine: bool = True,
173
+ norm_type: str = "layer_norm",
174
+ final_dropout: bool = False,
175
+ ):
176
+ super().__init__()
177
+ self.only_cross_attention = only_cross_attention
178
+
179
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
180
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
181
+
182
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
183
+ raise ValueError(
184
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
185
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
186
+ )
187
+
188
+ # Define 3 blocks. Each block has its own normalization layer.
189
+ # 1. Self-Attn
190
+ if self.use_ada_layer_norm:
191
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
192
+ elif self.use_ada_layer_norm_zero:
193
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
194
+ else:
195
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
196
+ self.attn1 = Attention(
197
+ query_dim=dim,
198
+ heads=num_attention_heads,
199
+ dim_head=attention_head_dim,
200
+ dropout=dropout,
201
+ bias=attention_bias,
202
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
203
+ upcast_attention=upcast_attention,
204
+ )
205
+
206
+ # 2. Cross-Attn
207
+ if cross_attention_dim is not None or double_self_attention:
208
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
209
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
210
+ # the second cross attention block.
211
+ self.norm2 = (
212
+ AdaLayerNorm(dim, num_embeds_ada_norm)
213
+ if self.use_ada_layer_norm
214
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
215
+ )
216
+ self.attn2 = Attention(
217
+ query_dim=dim,
218
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
219
+ heads=num_attention_heads,
220
+ dim_head=attention_head_dim,
221
+ dropout=dropout,
222
+ bias=attention_bias,
223
+ upcast_attention=upcast_attention,
224
+ # scale_qk=False, # uncomment this to not use flash attention
225
+ ) # is self-attn if encoder_hidden_states is none
226
+ else:
227
+ self.norm2 = None
228
+ self.attn2 = None
229
+
230
+ # 3. Feed-forward
231
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
232
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
233
+
234
+ # let chunk size default to None
235
+ self._chunk_size = None
236
+ self._chunk_dim = 0
237
+
238
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
239
+ # Sets chunk feed-forward
240
+ self._chunk_size = chunk_size
241
+ self._chunk_dim = dim
242
+
243
+ def forward(
244
+ self,
245
+ hidden_states: torch.FloatTensor,
246
+ attention_mask: Optional[torch.FloatTensor] = None,
247
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
248
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
249
+ timestep: Optional[torch.LongTensor] = None,
250
+ cross_attention_kwargs: Dict[str, Any] = None,
251
+ class_labels: Optional[torch.LongTensor] = None,
252
+ ):
253
+ # Notice that normalization is always applied before the real computation in the following blocks.
254
+ # 1. Self-Attention
255
+ if self.use_ada_layer_norm:
256
+ norm_hidden_states = self.norm1(hidden_states, timestep)
257
+ elif self.use_ada_layer_norm_zero:
258
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
259
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
260
+ )
261
+ else:
262
+ norm_hidden_states = self.norm1(hidden_states)
263
+
264
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
265
+
266
+ attn_output = self.attn1(
267
+ norm_hidden_states,
268
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
269
+ attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask,
270
+ **cross_attention_kwargs,
271
+ )
272
+ if self.use_ada_layer_norm_zero:
273
+ attn_output = gate_msa.unsqueeze(1) * attn_output
274
+ hidden_states = attn_output + hidden_states
275
+
276
+ # 2. Cross-Attention
277
+ if self.attn2 is not None:
278
+ norm_hidden_states = (
279
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
280
+ )
281
+
282
+ attn_output = self.attn2(
283
+ norm_hidden_states,
284
+ encoder_hidden_states=encoder_hidden_states,
285
+ attention_mask=encoder_attention_mask,
286
+ **cross_attention_kwargs,
287
+ )
288
+ hidden_states = attn_output + hidden_states
289
+
290
+ # 3. Feed-forward
291
+ norm_hidden_states = self.norm3(hidden_states)
292
+
293
+ if self.use_ada_layer_norm_zero:
294
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
295
+
296
+ if self._chunk_size is not None:
297
+ # "feed_forward_chunk_size" can be used to save memory
298
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
299
+ raise ValueError(
300
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
301
+ )
302
+
303
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
304
+ ff_output = torch.cat(
305
+ [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
306
+ dim=self._chunk_dim,
307
+ )
308
+ else:
309
+ ff_output = self.ff(norm_hidden_states)
310
+
311
+ if self.use_ada_layer_norm_zero:
312
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
313
+
314
+ hidden_states = ff_output + hidden_states
315
+
316
+ return hidden_states
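The chunked feed-forward path above works because the feed-forward acts on each position independently, so splitting the input along a sequence dimension and concatenating the per-chunk outputs reproduces the full result while lowering peak memory. A minimal sketch of that equivalence, using a plain `torch.nn.Linear` as a stand-in for the block's `FeedForward` (illustrative sizes, not values from this repo):

import torch

ff = torch.nn.Linear(64, 64)                 # stand-in for the position-wise feed-forward
x = torch.randn(1, 128, 64)                  # (batch, time, channels)

chunk_size, chunk_dim = 32, 1                # what set_chunk_feed_forward(32, 1) would configure
num_chunks = x.shape[chunk_dim] // chunk_size
chunked = torch.cat(
    [ff(piece) for piece in x.chunk(num_chunks, dim=chunk_dim)],
    dim=chunk_dim,
)
assert torch.allclose(ff(x), chunked, atol=1e-6)   # identical output, smaller intermediates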
src/chatterbox/models/s3gen/s3gen.py ADDED
@@ -0,0 +1,298 @@
1
+ # Modified from CosyVoice https://github.com/FunAudioLLM/CosyVoice
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import logging
16
+
17
+ import numpy as np
18
+ import torch
19
+ import torchaudio as ta
20
+ from functools import lru_cache
21
+ from typing import Optional
22
+
23
+ from ..s3tokenizer import S3_SR, SPEECH_VOCAB_SIZE, S3Tokenizer
24
+ from .const import S3GEN_SR
25
+ from .flow import CausalMaskedDiffWithXvec
26
+ from .xvector import CAMPPlus
27
+ from .utils.mel import mel_spectrogram
28
+ from .f0_predictor import ConvRNNF0Predictor
29
+ from .hifigan import HiFTGenerator
30
+ from .transformer.upsample_encoder import UpsampleConformerEncoder
31
+ from .flow_matching import CausalConditionalCFM
32
+ from .decoder import ConditionalDecoder
33
+ from .configs import CFM_PARAMS
34
+
35
+
36
+ def drop_invalid_tokens(x):
37
+ assert len(x.shape) <= 2 and x.shape[0] == 1, "only batch size of one allowed for now"
38
+ return x[x < SPEECH_VOCAB_SIZE]
39
+
40
+
41
+ # TODO: global resampler cache
42
+ @lru_cache(100)
43
+ def get_resampler(src_sr, dst_sr, device):
44
+ return ta.transforms.Resample(src_sr, dst_sr).to(device)
45
+
46
+
47
+ class S3Token2Mel(torch.nn.Module):
48
+ """
49
+ CosyVoice2's CFM decoder maps S3 speech tokens to mel-spectrograms.
50
+
51
+ TODO: make these modules configurable?
52
+ """
53
+ def __init__(self):
54
+ super().__init__()
55
+ self.tokenizer = S3Tokenizer("speech_tokenizer_v2_25hz")
56
+ self.mel_extractor = mel_spectrogram # TODO: make it a torch module?
57
+ self.speaker_encoder = CAMPPlus() # use default args
58
+
59
+ encoder = UpsampleConformerEncoder(
60
+ output_size=512,
61
+ attention_heads=8,
62
+ linear_units=2048,
63
+ num_blocks=6,
64
+ dropout_rate=0.1,
65
+ positional_dropout_rate=0.1,
66
+ attention_dropout_rate=0.1,
67
+ normalize_before=True,
68
+ input_layer='linear',
69
+ pos_enc_layer_type='rel_pos_espnet',
70
+ selfattention_layer_type='rel_selfattn',
71
+ input_size=512,
72
+ use_cnn_module=False,
73
+ macaron_style=False,
74
+ )
75
+
76
+ estimator = ConditionalDecoder(
77
+ in_channels=320,
78
+ out_channels=80,
79
+ causal=True,
80
+ channels=[256],
81
+ dropout=0.0,
82
+ attention_head_dim=64,
83
+ n_blocks=4,
84
+ num_mid_blocks=12,
85
+ num_heads=8,
86
+ act_fn='gelu',
87
+ )
88
+ cfm_params = CFM_PARAMS
89
+ decoder = CausalConditionalCFM(
90
+ spk_emb_dim=80,
91
+ cfm_params=cfm_params,
92
+ estimator=estimator,
93
+ )
94
+
95
+ self.flow = CausalMaskedDiffWithXvec(
96
+ encoder=encoder,
97
+ decoder=decoder
98
+ )
99
+
100
+ self.resamplers = {}
101
+
102
+ @property
103
+ def device(self):
104
+ params = self.tokenizer.parameters()
105
+ return next(params).device
106
+
107
+ def embed_ref(
108
+ self,
109
+ ref_wav: torch.Tensor,
110
+ ref_sr: int,
111
+ device="auto",
112
+ ref_fade_out=True,
113
+ ):
114
+ device = self.device if device == "auto" else device
115
+ if isinstance(ref_wav, np.ndarray):
116
+ ref_wav = torch.from_numpy(ref_wav).float()
117
+
118
+ if ref_wav.device != device:
119
+ ref_wav = ref_wav.to(device)
120
+
121
+ if len(ref_wav.shape) == 1:
122
+ ref_wav = ref_wav.unsqueeze(0) # (B, L)
123
+
124
+ if ref_wav.size(1) > 10 * ref_sr:
125
+ print("WARNING: cosydec received ref longer than 10s")
126
+
127
+ ref_wav_24 = ref_wav
128
+ if ref_sr != S3GEN_SR:
129
+ ref_wav_24 = get_resampler(ref_sr, S3GEN_SR, device)(ref_wav)
130
+
131
+ ref_mels_24 = self.mel_extractor(ref_wav_24).transpose(1, 2).to(device)
132
+ ref_mels_24_len = None
133
+
134
+ # Resample to 16kHz
135
+ ref_wav_16 = get_resampler(ref_sr, S3_SR, device)(ref_wav).to(device)
136
+
137
+ # Speaker embedding
138
+ ref_x_vector = self.speaker_encoder.inference(ref_wav_16)
139
+
140
+ # Tokenize 16khz reference
141
+ ref_speech_tokens, ref_speech_token_lens = self.tokenizer(ref_wav_16)
142
+
143
+ # Make sure mel_len = 2 * stoken_len (happens when the input is not padded to a multiple of 40ms)
144
+ if ref_mels_24.shape[1] != 2 * ref_speech_tokens.shape[1]:
145
+ logging.warning(
146
+ "Reference mel length is not equal to 2 * reference token length.\n"
147
+ )
148
+ ref_speech_tokens = ref_speech_tokens[:, :ref_mels_24.shape[1] // 2]
149
+ ref_speech_token_lens[0] = ref_speech_tokens.shape[1]
150
+
151
+ return dict(
152
+ prompt_token=ref_speech_tokens.to(device),
153
+ prompt_token_len=ref_speech_token_lens,
154
+ prompt_feat=ref_mels_24,
155
+ prompt_feat_len=ref_mels_24_len,
156
+ embedding=ref_x_vector,
157
+ )
158
+
159
+ def forward(
160
+ self,
161
+ speech_tokens: torch.LongTensor,
162
+ # locally-computed ref embedding (mutex with ref_dict)
163
+ ref_wav: Optional[torch.Tensor],
164
+ ref_sr: Optional[int],
165
+ # pre-computed ref embedding (prod API)
166
+ ref_dict: Optional[dict] = None,
167
+ finalize: bool = False,
168
+ ):
169
+ """
170
+ Generate waveforms from S3 speech tokens and a reference waveform, from which the speaker timbre is inferred.
171
+
172
+ NOTE:
173
+ - The speaker encoder accepts 16 kHz waveform.
174
+ - S3TokenizerV2 accepts 16 kHz waveform.
175
+ - The mel-spectrogram for the reference assumes 24 kHz input signal.
176
+ - This function is designed for batch_size=1 only.
177
+
178
+ Args
179
+ ----
180
+ - `speech_tokens`: S3 speech tokens [B=1, T]
181
+ - `ref_wav`: reference waveform (`torch.Tensor` with shape=[B=1, T])
182
+ - `ref_sr`: reference sample rate
183
+ - `finalize`: whether streaming is finished or not. Note that if False, the last 3 tokens will be ignored.
184
+ """
185
+ assert (ref_wav is None) ^ (ref_dict is None), f"Must provide exactly one of ref_wav or ref_dict (got {ref_wav} and {ref_dict})"
186
+
187
+ if ref_dict is None:
188
+ ref_dict = self.embed_ref(ref_wav, ref_sr)
189
+ else:
190
+ # type/device casting (all values will be numpy if it's from a prod API call)
191
+ for rk in list(ref_dict):
192
+ if isinstance(ref_dict[rk], np.ndarray):
193
+ ref_dict[rk] = torch.from_numpy(ref_dict[rk])
194
+ if torch.is_tensor(ref_dict[rk]):
195
+ ref_dict[rk] = ref_dict[rk].to(self.device)
196
+
197
+ if len(speech_tokens.shape) == 1:
198
+ speech_tokens = speech_tokens.unsqueeze(0)
199
+
200
+ # assert speech_tokens.shape[0] == 1, "only batch size of one allowed for now"
201
+ speech_token_lens = torch.LongTensor([speech_tokens.size(1)]).to(self.device)
202
+
203
+ output_mels, _ = self.flow.inference(
204
+ token=speech_tokens,
205
+ token_len=speech_token_lens,
206
+ finalize=finalize,
207
+ **ref_dict,
208
+ )
209
+ return output_mels
210
+
211
+
212
+ class S3Token2Wav(S3Token2Mel):
213
+ """
214
+ The decoder of CosyVoice2 is a concatenation of a token-to-mel (CFM) module and a mel-to-waveform (HiFiGAN) module.
215
+
216
+ TODO: make these modules configurable?
217
+ """
218
+
219
+ def __init__(self):
220
+ super().__init__()
221
+
222
+ f0_predictor = ConvRNNF0Predictor()
223
+ self.mel2wav = HiFTGenerator(
224
+ sampling_rate=S3GEN_SR,
225
+ upsample_rates=[8, 5, 3],
226
+ upsample_kernel_sizes=[16, 11, 7],
227
+ source_resblock_kernel_sizes=[7, 7, 11],
228
+ source_resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
229
+ f0_predictor=f0_predictor,
230
+ )
231
+
232
+ # silence out a few ms and fade audio in to reduce artifacts
233
+ n_trim = S3GEN_SR // 50 # 20ms = half of a frame
234
+ trim_fade = torch.zeros(2 * n_trim)
235
+ trim_fade[n_trim:] = (torch.cos(torch.linspace(torch.pi, 0, n_trim)) + 1) / 2
236
+ self.register_buffer("trim_fade", trim_fade, persistent=False) # (buffers get automatic device casting)
237
+
238
+ def forward(
239
+ self,
240
+ speech_tokens,
241
+ # locally-computed ref embedding (mutex with ref_dict)
242
+ ref_wav: Optional[torch.Tensor],
243
+ ref_sr: Optional[int],
244
+ # pre-computed ref embedding (prod API)
245
+ ref_dict: Optional[dict] = None,
246
+ finalize: bool = False
247
+ ):
248
+ output_mels = super().forward(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)
249
+
250
+ # TODO jrm: ignoring the speed control (mel interpolation) and the HiFTGAN caching mechanisms for now.
251
+ hift_cache_source = torch.zeros(1, 1, 0).to(self.device)
252
+
253
+ output_wavs, *_ = self.mel2wav.inference(speech_feat=output_mels, cache_source=hift_cache_source)
254
+
255
+ if not self.training:
256
+ # NOTE: ad-hoc method to reduce "spillover" from the reference clip.
257
+ output_wavs[:, :len(self.trim_fade)] *= self.trim_fade
258
+
259
+ return output_wavs
260
+
261
+ @torch.inference_mode()
262
+ def flow_inference(
263
+ self,
264
+ speech_tokens,
265
+ # locally-computed ref embedding (mutex with ref_dict)
266
+ ref_wav: Optional[torch.Tensor] = None,
267
+ ref_sr: Optional[int] = None,
268
+ # pre-computed ref embedding (prod API)
269
+ ref_dict: Optional[dict] = None,
270
+ finalize: bool = False,
271
+ ):
272
+ return super().forward(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)
273
+
274
+ @torch.inference_mode()
275
+ def hift_inference(self, speech_feat, cache_source: torch.Tensor = None):
276
+ if cache_source is None:
277
+ cache_source = torch.zeros(1, 1, 0).to(self.device)
278
+ return self.mel2wav.inference(speech_feat=speech_feat, cache_source=cache_source)
279
+
280
+ @torch.inference_mode()
281
+ def inference(
282
+ self,
283
+ speech_tokens,
284
+ # locally-computed ref embedding (mutex with ref_dict)
285
+ ref_wav: Optional[torch.Tensor] = None,
286
+ ref_sr: Optional[int] = None,
287
+ # pre-computed ref embedding (prod API)
288
+ ref_dict: Optional[dict] = None,
289
+ cache_source: torch.Tensor = None, # NOTE: this arg is for streaming, it can probably be removed here
290
+ finalize: bool = True,
291
+ ):
292
+ output_mels = self.flow_inference(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)
293
+ output_wavs, output_sources = self.hift_inference(output_mels, cache_source)
294
+
295
+ # NOTE: ad-hoc method to reduce "spillover" from the reference clip.
296
+ output_wavs[:, :len(self.trim_fade)] *= self.trim_fade
297
+
298
+ return output_wavs, output_sources
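For orientation, a minimal usage sketch of the inference path defined above; the checkpoint loading is omitted, `reference.wav` and the zero token tensor are placeholders, and the import path is inferred from this diff rather than guaranteed by it:

import torch
import torchaudio as ta
# from chatterbox.models.s3gen.s3gen import S3Token2Wav  # path assumed from this diff

model = S3Token2Wav().eval()                 # weights would normally come from a checkpoint
ref_wav, ref_sr = ta.load("reference.wav")   # (1, T) reference clip, ideally under 10 s
speech_tokens = torch.zeros(1, 50, dtype=torch.long)   # placeholder S3 tokens (25 Hz -> ~2 s)

with torch.inference_mode():
    wavs, _ = model.inference(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr)
# `wavs` holds 24 kHz audio (S3GEN_SR); the first ~40 ms are silenced/faded by `trim_fade`.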
src/chatterbox/models/s3gen/transformer/__init__.py ADDED
File without changes
src/chatterbox/models/s3gen/transformer/activation.py ADDED
@@ -0,0 +1,84 @@
1
+ # Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe)
2
+ # 2020 Northwestern Polytechnical University (Pengcheng Guo)
3
+ # 2020 Mobvoi Inc (Binbin Zhang)
4
+ # 2024 Alibaba Inc (Xiang Lyu)
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """Swish() activation function for Conformer."""
18
+
19
+ import torch
20
+ from torch import nn, sin, pow
21
+ from torch.nn import Parameter
22
+
23
+
24
+ class Swish(torch.nn.Module):
25
+ """Construct an Swish object."""
26
+
27
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
28
+ """Return Swish activation function."""
29
+ return x * torch.sigmoid(x)
30
+
31
+
32
+ # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
33
+ # LICENSE is in incl_licenses directory.
34
+ class Snake(nn.Module):
35
+ '''
36
+ Implementation of a sine-based periodic activation function
37
+ Shape:
38
+ - Input: (B, C, T)
39
+ - Output: (B, C, T), same shape as the input
40
+ Parameters:
41
+ - alpha - trainable parameter
42
+ References:
43
+ - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
44
+ https://arxiv.org/abs/2006.08195
45
+ Examples:
46
+ >>> a1 = Snake(256)
47
+ >>> x = torch.randn(256)
48
+ >>> x = a1(x)
49
+ '''
50
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
51
+ '''
52
+ Initialization.
53
+ INPUT:
54
+ - in_features: shape of the input
55
+ - alpha: trainable parameter
56
+ alpha is initialized to 1 by default, higher values = higher-frequency.
57
+ alpha will be trained along with the rest of your model.
58
+ '''
59
+ super(Snake, self).__init__()
60
+ self.in_features = in_features
61
+
62
+ # initialize alpha
63
+ self.alpha_logscale = alpha_logscale
64
+ if self.alpha_logscale: # log scale alphas initialized to zeros
65
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
66
+ else: # linear scale alphas initialized to ones
67
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
68
+
69
+ self.alpha.requires_grad = alpha_trainable
70
+
71
+ self.no_div_by_zero = 0.000000001
72
+
73
+ def forward(self, x):
74
+ '''
75
+ Forward pass of the function.
76
+ Applies the function to the input elementwise.
77
+ Snake ∶= x + 1/a * sin^2 (xa)
78
+ '''
79
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
80
+ if self.alpha_logscale:
81
+ alpha = torch.exp(alpha)
82
+ x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
83
+
84
+ return x
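The Snake forward above computes x + (1/α)·sin²(αx) per channel, with a small constant guarding against division by zero. A short sketch checking the module against that formula directly (import path assumed from this diff; sizes are illustrative):

import torch
# from chatterbox.models.s3gen.transformer.activation import Snake  # path assumed from this diff

snake = Snake(in_features=4)                 # default alpha = 1.0, linear scale
x = torch.randn(2, 4, 8)                     # (B, C, T), matching the documented shape
y = snake(x)

alpha = snake.alpha.view(1, -1, 1)           # broadcast over batch and time
manual = x + (1.0 / (alpha + 1e-9)) * torch.sin(alpha * x) ** 2
assert torch.allclose(y, manual, atol=1e-6)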
src/chatterbox/models/s3gen/transformer/attention.py ADDED
@@ -0,0 +1,330 @@
1
+ # Copyright (c) 2019 Shigeki Karita
2
+ # 2020 Mobvoi Inc (Binbin Zhang)
3
+ # 2022 Xingchen Song ([email protected])
4
+ # 2024 Alibaba Inc (Xiang Lyu)
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """Multi-Head Attention layer definition."""
18
+
19
+ import math
20
+ from typing import Tuple
21
+
22
+ import torch
23
+ from torch import nn
24
+
25
+
26
+ class MultiHeadedAttention(nn.Module):
27
+ """Multi-Head Attention layer.
28
+
29
+ Args:
30
+ n_head (int): The number of heads.
31
+ n_feat (int): The number of features.
32
+ dropout_rate (float): Dropout rate.
33
+
34
+ """
35
+
36
+ def __init__(self,
37
+ n_head: int,
38
+ n_feat: int,
39
+ dropout_rate: float,
40
+ key_bias: bool = True):
41
+ """Construct an MultiHeadedAttention object."""
42
+ super().__init__()
43
+ assert n_feat % n_head == 0
44
+ # We assume d_v always equals d_k
45
+ self.d_k = n_feat // n_head
46
+ self.h = n_head
47
+ self.linear_q = nn.Linear(n_feat, n_feat)
48
+ self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
49
+ self.linear_v = nn.Linear(n_feat, n_feat)
50
+ self.linear_out = nn.Linear(n_feat, n_feat)
51
+ self.dropout = nn.Dropout(p=dropout_rate)
52
+
53
+ def forward_qkv(
54
+ self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
55
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
56
+ """Transform query, key and value.
57
+
58
+ Args:
59
+ query (torch.Tensor): Query tensor (#batch, time1, size).
60
+ key (torch.Tensor): Key tensor (#batch, time2, size).
61
+ value (torch.Tensor): Value tensor (#batch, time2, size).
62
+
63
+ Returns:
64
+ torch.Tensor: Transformed query tensor, size
65
+ (#batch, n_head, time1, d_k).
66
+ torch.Tensor: Transformed key tensor, size
67
+ (#batch, n_head, time2, d_k).
68
+ torch.Tensor: Transformed value tensor, size
69
+ (#batch, n_head, time2, d_k).
70
+
71
+ """
72
+ n_batch = query.size(0)
73
+ q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
74
+ k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
75
+ v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
76
+ q = q.transpose(1, 2) # (batch, head, time1, d_k)
77
+ k = k.transpose(1, 2) # (batch, head, time2, d_k)
78
+ v = v.transpose(1, 2) # (batch, head, time2, d_k)
79
+
80
+ return q, k, v
81
+
82
+ def forward_attention(
83
+ self,
84
+ value: torch.Tensor,
85
+ scores: torch.Tensor,
86
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
87
+ ) -> torch.Tensor:
88
+ """Compute attention context vector.
89
+
90
+ Args:
91
+ value (torch.Tensor): Transformed value, size
92
+ (#batch, n_head, time2, d_k).
93
+ scores (torch.Tensor): Attention score, size
94
+ (#batch, n_head, time1, time2).
95
+ mask (torch.Tensor): Mask, size (#batch, 1, time2) or
96
+ (#batch, time1, time2), (0, 0, 0) means fake mask.
97
+
98
+ Returns:
99
+ torch.Tensor: Transformed value (#batch, time1, d_model)
100
+ weighted by the attention score (#batch, time1, time2).
101
+
102
+ """
103
+ n_batch = value.size(0)
104
+ # NOTE(xcsong): When will `if mask.size(2) > 0` be True?
105
+ # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
106
+ # 1st chunk to ease the onnx export.]
107
+ # 2. pytorch training
108
+ if mask.size(2) > 0: # time2 > 0
109
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
110
+ # For last chunk, time2 might be larger than scores.size(-1)
111
+ mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
112
+ scores = scores.masked_fill(mask, -float('inf'))
113
+ attn = torch.softmax(scores, dim=-1).masked_fill(
114
+ mask, 0.0) # (batch, head, time1, time2)
115
+ # NOTE(xcsong): When will `if mask.size(2) > 0` be False?
116
+ # 1. onnx(16/-1, -1/-1, 16/0)
117
+ # 2. jit (16/-1, -1/-1, 16/0, 16/4)
118
+ else:
119
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
120
+
121
+ p_attn = self.dropout(attn)
122
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
123
+ x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
124
+ self.h * self.d_k)
125
+ ) # (batch, time1, d_model)
126
+
127
+ return self.linear_out(x) # (batch, time1, d_model)
128
+
129
+ def forward(
130
+ self,
131
+ query: torch.Tensor,
132
+ key: torch.Tensor,
133
+ value: torch.Tensor,
134
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
135
+ pos_emb: torch.Tensor = torch.empty(0),
136
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
137
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
138
+ """Compute scaled dot product attention.
139
+
140
+ Args:
141
+ query (torch.Tensor): Query tensor (#batch, time1, size).
142
+ key (torch.Tensor): Key tensor (#batch, time2, size).
143
+ value (torch.Tensor): Value tensor (#batch, time2, size).
144
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
145
+ (#batch, time1, time2).
146
+ 1.When applying cross attention between decoder and encoder,
147
+ the batch padding mask for input is in (#batch, 1, T) shape.
148
+ 2.When applying self attention of encoder,
149
+ the mask is in (#batch, T, T) shape.
150
+ 3.When applying self attention of decoder,
151
+ the mask is in (#batch, L, L) shape.
152
+ 4.If different positions in the decoder see different blocks
153
+ of the encoder, such as Mocha, the passed in mask could be
154
+ in (#batch, L, T) shape. But there is no such case in current
155
+ CosyVoice.
156
+ cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
157
+ where `cache_t == chunk_size * num_decoding_left_chunks`
158
+ and `head * d_k == size`
159
+
160
+
161
+ Returns:
162
+ torch.Tensor: Output tensor (#batch, time1, d_model).
163
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
164
+ where `cache_t == chunk_size * num_decoding_left_chunks`
165
+ and `head * d_k == size`
166
+
167
+ """
168
+ q, k, v = self.forward_qkv(query, key, value)
169
+
170
+ # NOTE(xcsong):
171
+ # when export onnx model, for 1st chunk, we feed
172
+ # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
173
+ # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
174
+ # In all modes, `if cache.size(0) > 0` will always be `True`
175
+ # and we will always do splitting and
176
+ # concatenation (this will simplify onnx export). Note that
177
+ # it's OK to concat & split zero-shaped tensors(see code below).
178
+ # when export jit model, for 1st chunk, we always feed
179
+ # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
180
+ # >>> a = torch.ones((1, 2, 0, 4))
181
+ # >>> b = torch.ones((1, 2, 3, 4))
182
+ # >>> c = torch.cat((a, b), dim=2)
183
+ # >>> torch.equal(b, c) # True
184
+ # >>> d = torch.split(a, 2, dim=-1)
185
+ # >>> torch.equal(d[0], d[1]) # True
186
+ if cache.size(0) > 0:
187
+ key_cache, value_cache = torch.split(cache,
188
+ cache.size(-1) // 2,
189
+ dim=-1)
190
+ k = torch.cat([key_cache, k], dim=2)
191
+ v = torch.cat([value_cache, v], dim=2)
192
+ # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
193
+ # non-trivial to calculate `next_cache_start` here.
194
+ new_cache = torch.cat((k, v), dim=-1)
195
+
196
+ scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
197
+ return self.forward_attention(v, scores, mask), new_cache
198
+
199
+
200
+ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
201
+ """Multi-Head Attention layer with relative position encoding.
202
+ Paper: https://arxiv.org/abs/1901.02860
203
+ Args:
204
+ n_head (int): The number of heads.
205
+ n_feat (int): The number of features.
206
+ dropout_rate (float): Dropout rate.
207
+ """
208
+
209
+ def __init__(self,
210
+ n_head: int,
211
+ n_feat: int,
212
+ dropout_rate: float,
213
+ key_bias: bool = True):
214
+ """Construct an RelPositionMultiHeadedAttention object."""
215
+ super().__init__(n_head, n_feat, dropout_rate, key_bias)
216
+ # linear transformation for positional encoding
217
+ self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
218
+ # these two learnable biases are used in matrix c and matrix d
219
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
220
+ self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
221
+ self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
222
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
223
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
224
+
225
+ def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
226
+ """Compute relative positional encoding.
227
+
228
+ Args:
229
+ x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
230
+ time1 means the length of query vector.
231
+
232
+ Returns:
233
+ torch.Tensor: Output tensor.
234
+
235
+ """
236
+ zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
237
+ device=x.device,
238
+ dtype=x.dtype)
239
+ x_padded = torch.cat([zero_pad, x], dim=-1)
240
+
241
+ x_padded = x_padded.view(x.size()[0],
242
+ x.size()[1],
243
+ x.size(3) + 1, x.size(2))
244
+ x = x_padded[:, :, 1:].view_as(x)[
245
+ :, :, :, : x.size(-1) // 2 + 1
246
+ ] # only keep the positions from 0 to time2
247
+ return x
248
+
249
+ def forward(
250
+ self,
251
+ query: torch.Tensor,
252
+ key: torch.Tensor,
253
+ value: torch.Tensor,
254
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
255
+ pos_emb: torch.Tensor = torch.empty(0),
256
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
257
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
258
+ """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
259
+ Args:
260
+ query (torch.Tensor): Query tensor (#batch, time1, size).
261
+ key (torch.Tensor): Key tensor (#batch, time2, size).
262
+ value (torch.Tensor): Value tensor (#batch, time2, size).
263
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
264
+ (#batch, time1, time2), (0, 0, 0) means fake mask.
265
+ pos_emb (torch.Tensor): Positional embedding tensor
266
+ (#batch, time2, size).
267
+ cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
268
+ where `cache_t == chunk_size * num_decoding_left_chunks`
269
+ and `head * d_k == size`
270
+ Returns:
271
+ torch.Tensor: Output tensor (#batch, time1, d_model).
272
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
273
+ where `cache_t == chunk_size * num_decoding_left_chunks`
274
+ and `head * d_k == size`
275
+ """
276
+ q, k, v = self.forward_qkv(query, key, value)
277
+ q = q.transpose(1, 2) # (batch, time1, head, d_k)
278
+
279
+ # NOTE(xcsong):
280
+ # when export onnx model, for 1st chunk, we feed
281
+ # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
282
+ # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
283
+ # In all modes, `if cache.size(0) > 0` will always be `True`
284
+ # and we will always do splitting and
285
+ # concatenation (this will simplify onnx export). Note that
286
+ # it's OK to concat & split zero-shaped tensors(see code below).
287
+ # when export jit model, for 1st chunk, we always feed
288
+ # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
289
+ # >>> a = torch.ones((1, 2, 0, 4))
290
+ # >>> b = torch.ones((1, 2, 3, 4))
291
+ # >>> c = torch.cat((a, b), dim=2)
292
+ # >>> torch.equal(b, c) # True
293
+ # >>> d = torch.split(a, 2, dim=-1)
294
+ # >>> torch.equal(d[0], d[1]) # True
295
+ if cache.size(0) > 0:
296
+ key_cache, value_cache = torch.split(cache,
297
+ cache.size(-1) // 2,
298
+ dim=-1)
299
+ k = torch.cat([key_cache, k], dim=2)
300
+ v = torch.cat([value_cache, v], dim=2)
301
+ # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
302
+ # non-trivial to calculate `next_cache_start` here.
303
+ new_cache = torch.cat((k, v), dim=-1)
304
+
305
+ n_batch_pos = pos_emb.size(0)
306
+ p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
307
+ p = p.transpose(1, 2) # (batch, head, time1, d_k)
308
+
309
+ # (batch, head, time1, d_k)
310
+ q_with_bias_u = (q + self.pos_bias_u.to(q.device)).transpose(1, 2)
311
+ # (batch, head, time1, d_k)
312
+ q_with_bias_v = (q + self.pos_bias_v.to(q.device)).transpose(1, 2)
313
+
314
+ # compute attention score
315
+ # first compute matrix a and matrix c
316
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
317
+ # (batch, head, time1, time2)
318
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
319
+
320
+ # compute matrix b and matrix d
321
+ # (batch, head, time1, time2)
322
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
323
+ # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
324
+ if matrix_ac.shape != matrix_bd.shape:
325
+ matrix_bd = self.rel_shift(matrix_bd)
326
+
327
+ scores = (matrix_ac + matrix_bd) / math.sqrt(
328
+ self.d_k) # (batch, head, time1, time2)
329
+
330
+ return self.forward_attention(v, scores, mask), new_cache
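Both attention classes above return a cache that simply packs K and V concatenated along the last dimension, and they rely on the fact that concatenating a zero-length cache is a no-op (which keeps the ONNX/JIT branches uniform). A small sketch of those two conventions with made-up shapes:

import torch

head, d_k = 8, 64
k = torch.randn(1, head, 10, d_k)
v = torch.randn(1, head, 10, d_k)

new_cache = torch.cat((k, v), dim=-1)        # (1, head, 10, 2 * d_k), as returned by forward()
k2, v2 = torch.split(new_cache, new_cache.size(-1) // 2, dim=-1)
assert torch.equal(k, k2) and torch.equal(v, v2)

empty = torch.zeros(1, head, 0, d_k)         # zero-length cache fed for the first chunk
assert torch.equal(torch.cat((empty, k), dim=2), k)   # concatenation with it changes nothing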
src/chatterbox/models/s3gen/transformer/convolution.py ADDED
@@ -0,0 +1,145 @@
1
+ # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
2
+ # 2024 Alibaba Inc (Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """ConvolutionModule definition."""
17
+
18
+ from typing import Tuple
19
+
20
+ import torch
21
+ from torch import nn
22
+
23
+
24
+ class ConvolutionModule(nn.Module):
25
+ """ConvolutionModule in Conformer model."""
26
+
27
+ def __init__(self,
28
+ channels: int,
29
+ kernel_size: int = 15,
30
+ activation: nn.Module = nn.ReLU(),
31
+ norm: str = "batch_norm",
32
+ causal: bool = False,
33
+ bias: bool = True):
34
+ """Construct an ConvolutionModule object.
35
+ Args:
36
+ channels (int): The number of channels of conv layers.
37
+ kernel_size (int): Kernel size of conv layers.
38
+ causal (bool): Whether to use causal convolution or not
39
+ """
40
+ super().__init__()
41
+
42
+ self.pointwise_conv1 = nn.Conv1d(
43
+ channels,
44
+ 2 * channels,
45
+ kernel_size=1,
46
+ stride=1,
47
+ padding=0,
48
+ bias=bias,
49
+ )
50
+ # self.lorder is used to distinguish if it's a causal convolution,
51
+ # if self.lorder > 0: it's a causal convolution, the input will be
52
+ # padded with self.lorder frames on the left in forward.
53
+ # else: it's a symmetrical convolution
54
+ if causal:
55
+ padding = 0
56
+ self.lorder = kernel_size - 1
57
+ else:
58
+ # kernel_size should be an odd number for non-causal convolution
59
+ assert (kernel_size - 1) % 2 == 0
60
+ padding = (kernel_size - 1) // 2
61
+ self.lorder = 0
62
+ self.depthwise_conv = nn.Conv1d(
63
+ channels,
64
+ channels,
65
+ kernel_size,
66
+ stride=1,
67
+ padding=padding,
68
+ groups=channels,
69
+ bias=bias,
70
+ )
71
+
72
+ assert norm in ['batch_norm', 'layer_norm']
73
+ if norm == "batch_norm":
74
+ self.use_layer_norm = False
75
+ self.norm = nn.BatchNorm1d(channels)
76
+ else:
77
+ self.use_layer_norm = True
78
+ self.norm = nn.LayerNorm(channels)
79
+
80
+ self.pointwise_conv2 = nn.Conv1d(
81
+ channels,
82
+ channels,
83
+ kernel_size=1,
84
+ stride=1,
85
+ padding=0,
86
+ bias=bias,
87
+ )
88
+ self.activation = activation
89
+
90
+ def forward(
91
+ self,
92
+ x: torch.Tensor,
93
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
94
+ cache: torch.Tensor = torch.zeros((0, 0, 0)),
95
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
96
+ """Compute convolution module.
97
+ Args:
98
+ x (torch.Tensor): Input tensor (#batch, time, channels).
99
+ mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
100
+ (0, 0, 0) means fake mask.
101
+ cache (torch.Tensor): left context cache, it is only
102
+ used in causal convolution (#batch, channels, cache_t),
103
+ (0, 0, 0) means fake cache.
104
+ Returns:
105
+ torch.Tensor: Output tensor (#batch, time, channels).
106
+ """
107
+ # exchange the temporal dimension and the feature dimension
108
+ x = x.transpose(1, 2) # (#batch, channels, time)
109
+
110
+ # mask batch padding
111
+ if mask_pad.size(2) > 0: # time > 0
112
+ x.masked_fill_(~mask_pad, 0.0)
113
+
114
+ if self.lorder > 0:
115
+ if cache.size(2) == 0: # cache_t == 0
116
+ x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
117
+ else:
118
+ assert cache.size(0) == x.size(0) # equal batch
119
+ assert cache.size(1) == x.size(1) # equal channel
120
+ x = torch.cat((cache, x), dim=2)
121
+ assert (x.size(2) > self.lorder)
122
+ new_cache = x[:, :, -self.lorder:]
123
+ else:
124
+ # It's better we just return None if no cache is required,
125
+ # However, for JIT export, here we just fake one tensor instead of
126
+ # None.
127
+ new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
128
+
129
+ # GLU mechanism
130
+ x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
131
+ x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
132
+
133
+ # 1D Depthwise Conv
134
+ x = self.depthwise_conv(x)
135
+ if self.use_layer_norm:
136
+ x = x.transpose(1, 2)
137
+ x = self.activation(self.norm(x))
138
+ if self.use_layer_norm:
139
+ x = x.transpose(1, 2)
140
+ x = self.pointwise_conv2(x)
141
+ # mask batch padding
142
+ if mask_pad.size(2) > 0: # time > 0
143
+ x.masked_fill_(~mask_pad, 0.0)
144
+
145
+ return x.transpose(1, 2), new_cache
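When `causal=True`, the module pads `lorder = kernel_size - 1` frames on the left, and in streaming mode those frames come from the cache of the previous chunk, which is what makes chunked decoding match full-utterance decoding. A minimal sketch of that property with a bare depthwise `nn.Conv1d` (illustrative sizes, not the module's defaults):

import torch
import torch.nn as nn

channels, kernel_size = 4, 5
lorder = kernel_size - 1                     # left context required by the causal conv
conv = nn.Conv1d(channels, channels, kernel_size, groups=channels)

x = torch.randn(1, channels, 20)             # (B, C, T)
full = conv(nn.functional.pad(x, (lorder, 0)))          # full-utterance causal convolution

x1, x2 = x[:, :, :12], x[:, :, 12:]          # two streaming chunks
y1 = conv(nn.functional.pad(x1, (lorder, 0)))
cache = x1[:, :, -lorder:]                   # last lorder frames carried over, like new_cache
y2 = conv(torch.cat((cache, x2), dim=2))
assert torch.allclose(full, torch.cat((y1, y2), dim=2), atol=1e-5)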
src/chatterbox/models/s3gen/transformer/embedding.py ADDED
@@ -0,0 +1,294 @@
1
+ # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
2
+ # 2024 Alibaba Inc (Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """Positonal Encoding Module."""
17
+
18
+ import math
19
+ from typing import Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ import numpy as np
24
+
25
+
26
+ class PositionalEncoding(torch.nn.Module):
27
+ """Positional encoding.
28
+
29
+ :param int d_model: embedding dim
30
+ :param float dropout_rate: dropout rate
31
+ :param int max_len: maximum input length
32
+
33
+ PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
34
+ PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
35
+ """
36
+
37
+ def __init__(self,
38
+ d_model: int,
39
+ dropout_rate: float,
40
+ max_len: int = 5000,
41
+ reverse: bool = False):
42
+ """Construct an PositionalEncoding object."""
43
+ super().__init__()
44
+ self.d_model = d_model
45
+ self.xscale = math.sqrt(self.d_model)
46
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
47
+ self.max_len = max_len
48
+
49
+ self.pe = torch.zeros(self.max_len, self.d_model)
50
+ position = torch.arange(0, self.max_len,
51
+ dtype=torch.float32).unsqueeze(1)
52
+ div_term = torch.exp(
53
+ torch.arange(0, self.d_model, 2, dtype=torch.float32) *
54
+ -(math.log(10000.0) / self.d_model))
55
+ self.pe[:, 0::2] = torch.sin(position * div_term)
56
+ self.pe[:, 1::2] = torch.cos(position * div_term)
57
+ self.pe = self.pe.unsqueeze(0)
58
+
59
+ def forward(self,
60
+ x: torch.Tensor,
61
+ offset: Union[int, torch.Tensor] = 0) \
62
+ -> Tuple[torch.Tensor, torch.Tensor]:
63
+ """Add positional encoding.
64
+
65
+ Args:
66
+ x (torch.Tensor): Input. Its shape is (batch, time, ...)
67
+ offset (int, torch.tensor): position offset
68
+
69
+ Returns:
70
+ torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
71
+ torch.Tensor: for compatibility to RelPositionalEncoding
72
+ """
73
+
74
+ self.pe = self.pe.to(x.device)
75
+ pos_emb = self.position_encoding(offset, x.size(1), False)
76
+ x = x * self.xscale + pos_emb
77
+ return self.dropout(x), self.dropout(pos_emb)
78
+
79
+ def position_encoding(self,
80
+ offset: Union[int, torch.Tensor],
81
+ size: int,
82
+ apply_dropout: bool = True) -> torch.Tensor:
83
+ """ For getting encoding in a streaming fashion
84
+
85
+ Attention!!!!!
86
+ we apply dropout only once at the whole utterance level in a non-
87
+ streaming way, but will call this function several times with
88
+ increasing input size in a streaming scenario, so the dropout will
89
+ be applied several times.
90
+
91
+ Args:
92
+ offset (int or torch.tensor): start offset
93
+ size (int): required size of position encoding
94
+
95
+ Returns:
96
+ torch.Tensor: Corresponding encoding
97
+ """
98
+ # How to subscript a Union type:
99
+ # https://github.com/pytorch/pytorch/issues/69434
100
+ if isinstance(offset, int):
101
+ assert offset + size <= self.max_len
102
+ pos_emb = self.pe[:, offset:offset + size]
103
+ elif isinstance(offset, torch.Tensor) and offset.dim() == 0: # scalar
104
+ assert offset + size <= self.max_len
105
+ pos_emb = self.pe[:, offset:offset + size]
106
+ else: # for batched streaming decoding on GPU
107
+ assert torch.max(offset) + size <= self.max_len
108
+ index = offset.unsqueeze(1) + \
109
+ torch.arange(0, size).to(offset.device) # B X T
110
+ flag = index > 0
111
+ # remove negative offset
112
+ index = index * flag
113
+ pos_emb = F.embedding(index, self.pe[0]) # B X T X d_model
114
+
115
+ if apply_dropout:
116
+ pos_emb = self.dropout(pos_emb)
117
+ return pos_emb
118
+
119
+
120
+ class RelPositionalEncoding(PositionalEncoding):
121
+ """Relative positional encoding module.
122
+ See : Appendix B in https://arxiv.org/abs/1901.02860
123
+ Args:
124
+ d_model (int): Embedding dimension.
125
+ dropout_rate (float): Dropout rate.
126
+ max_len (int): Maximum input length.
127
+ """
128
+
129
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
130
+ """Initialize class."""
131
+ super().__init__(d_model, dropout_rate, max_len, reverse=True)
132
+
133
+ def forward(self,
134
+ x: torch.Tensor,
135
+ offset: Union[int, torch.Tensor] = 0) \
136
+ -> Tuple[torch.Tensor, torch.Tensor]:
137
+ """Compute positional encoding.
138
+ Args:
139
+ x (torch.Tensor): Input tensor (batch, time, `*`).
140
+ Returns:
141
+ torch.Tensor: Encoded tensor (batch, time, `*`).
142
+ torch.Tensor: Positional embedding tensor (1, time, `*`).
143
+ """
144
+ self.pe = self.pe.to(x.device)
145
+ x = x * self.xscale
146
+ pos_emb = self.position_encoding(offset, x.size(1), False)
147
+ return self.dropout(x), self.dropout(pos_emb)
148
+
149
+
150
+ class WhisperPositionalEncoding(PositionalEncoding):
151
+ """ Sinusoids position encoding used in openai-whisper.encoder
152
+ """
153
+
154
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 1500):
155
+ super().__init__(d_model, dropout_rate, max_len)
156
+ self.xscale = 1.0
157
+ log_timescale_increment = np.log(10000) / (d_model // 2 - 1)
158
+ inv_timescales = torch.exp(-log_timescale_increment *
159
+ torch.arange(d_model // 2))
160
+ scaled_time = torch.arange(max_len)[:, np.newaxis] * \
161
+ inv_timescales[np.newaxis, :]
162
+ pe = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
163
+ delattr(self, "pe")
164
+ self.register_buffer("pe", pe.unsqueeze(0))
165
+
166
+
167
+ class LearnablePositionalEncoding(PositionalEncoding):
168
+ """ Learnable position encoding used in openai-whisper.decoder
169
+ """
170
+
171
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 448):
172
+ super().__init__(d_model, dropout_rate, max_len)
173
+ # NOTE(xcsong): overwrite self.pe & self.xscale
174
+ self.pe = torch.nn.Parameter(torch.empty(1, max_len, d_model))
175
+ self.xscale = 1.0
176
+
177
+
178
+ class NoPositionalEncoding(torch.nn.Module):
179
+ """ No position encoding
180
+ """
181
+
182
+ def __init__(self, d_model: int, dropout_rate: float):
183
+ super().__init__()
184
+ self.d_model = d_model
185
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
186
+
187
+ def forward(self,
188
+ x: torch.Tensor,
189
+ offset: Union[int, torch.Tensor] = 0) \
190
+ -> Tuple[torch.Tensor, torch.Tensor]:
191
+ """ Just return zero vector for interface compatibility
192
+ """
193
+ pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device)
194
+ return self.dropout(x), pos_emb
195
+
196
+ def position_encoding(self, offset: Union[int, torch.Tensor],
197
+ size: int) -> torch.Tensor:
198
+ return torch.zeros(1, size, self.d_model)
199
+
200
+
201
+ class EspnetRelPositionalEncoding(torch.nn.Module):
202
+ """Relative positional encoding module (new implementation).
203
+
204
+ Details can be found in https://github.com/espnet/espnet/pull/2816.
205
+
206
+ See : Appendix B in https://arxiv.org/abs/1901.02860
207
+
208
+ Args:
209
+ d_model (int): Embedding dimension.
210
+ dropout_rate (float): Dropout rate.
211
+ max_len (int): Maximum input length.
212
+
213
+ """
214
+
215
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
216
+ """Construct an PositionalEncoding object."""
217
+ super(EspnetRelPositionalEncoding, self).__init__()
218
+ self.d_model = d_model
219
+ self.xscale = math.sqrt(self.d_model)
220
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
221
+ self.pe = None
222
+ self.extend_pe(torch.tensor(0.0).expand(1, max_len))
223
+
224
+ def extend_pe(self, x: torch.Tensor):
225
+ """Reset the positional encodings."""
226
+ if self.pe is not None:
227
+ # self.pe contains both positive and negative parts
228
+ # the length of self.pe is 2 * input_len - 1
229
+ if self.pe.size(1) >= x.size(1) * 2 - 1:
230
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
231
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
232
+ return
233
+ # Suppose `i` means to the position of query vecotr and `j` means the
234
+ # position of key vector. We use position relative positions when keys
235
+ # are to the left (i>j) and negative relative positions otherwise (i<j).
236
+ pe_positive = torch.zeros(x.size(1), self.d_model)
237
+ pe_negative = torch.zeros(x.size(1), self.d_model)
238
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
239
+ div_term = torch.exp(
240
+ torch.arange(0, self.d_model, 2, dtype=torch.float32)
241
+ * -(math.log(10000.0) / self.d_model)
242
+ )
243
+ pe_positive[:, 0::2] = torch.sin(position * div_term)
244
+ pe_positive[:, 1::2] = torch.cos(position * div_term)
245
+ pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
246
+ pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
247
+
248
+ # Reserve the order of positive indices and concat both positive and
249
+ # negative indices. This is used to support the shifting trick
250
+ # as in https://arxiv.org/abs/1901.02860
251
+ pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
252
+ pe_negative = pe_negative[1:].unsqueeze(0)
253
+ pe = torch.cat([pe_positive, pe_negative], dim=1)
254
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
255
+
256
+ def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) \
257
+ -> Tuple[torch.Tensor, torch.Tensor]:
258
+ """Add positional encoding.
259
+
260
+ Args:
261
+ x (torch.Tensor): Input tensor (batch, time, `*`).
262
+
263
+ Returns:
264
+ torch.Tensor: Encoded tensor (batch, time, `*`).
265
+
266
+ """
267
+ self.extend_pe(x)
268
+ x = x * self.xscale
269
+ pos_emb = self.position_encoding(size=x.size(1), offset=offset)
270
+ return self.dropout(x), self.dropout(pos_emb)
271
+
272
+ def position_encoding(self,
273
+ offset: Union[int, torch.Tensor],
274
+ size: int) -> torch.Tensor:
275
+ """ For getting encoding in a streaming fashion
276
+
277
+ Attention!!!!!
278
+ we apply dropout only once at the whole utterance level in a none
279
+ streaming way, but will call this function several times with
280
+ increasing input size in a streaming scenario, so the dropout will
281
+ be applied several times.
282
+
283
+ Args:
284
+ offset (int or torch.tensor): start offset
285
+ size (int): required size of position encoding
286
+
287
+ Returns:
288
+ torch.Tensor: Corresponding encoding
289
+ """
290
+ pos_emb = self.pe[
291
+ :,
292
+ self.pe.size(1) // 2 - size + 1: self.pe.size(1) // 2 + size,
293
+ ]
294
+ return pos_emb
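The absolute encoding at the top of this file follows PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)), and the relative variants reuse the same sinusoid table (symmetrically extended in the ESPnet case). A short sketch checking one entry of the precomputed table against the formula (import path assumed from this diff; the table is read directly, so dropout never enters the check):

import math
import torch
# from chatterbox.models.s3gen.transformer.embedding import PositionalEncoding  # path assumed

d_model, pos, i = 256, 7, 10
enc = PositionalEncoding(d_model, dropout_rate=0.0, max_len=100)

angle = pos / (10000 ** (2 * i / d_model))
assert torch.isclose(enc.pe[0, pos, 2 * i], torch.tensor(math.sin(angle)), atol=1e-4)
assert torch.isclose(enc.pe[0, pos, 2 * i + 1], torch.tensor(math.cos(angle)), atol=1e-4)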
src/chatterbox/models/s3gen/transformer/encoder_layer.py ADDED
@@ -0,0 +1,236 @@
1
+ # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
2
+ # 2022 Xingchen Song ([email protected])
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """Encoder self-attention layer definition."""
17
+
18
+ from typing import Optional, Tuple
19
+
20
+ import torch
21
+ from torch import nn
22
+
23
+
24
+ class TransformerEncoderLayer(nn.Module):
25
+ """Encoder layer module.
26
+
27
+ Args:
28
+ size (int): Input dimension.
29
+ self_attn (torch.nn.Module): Self-attention module instance.
30
+ `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
31
+ instance can be used as the argument.
32
+ feed_forward (torch.nn.Module): Feed-forward module instance.
33
+ `PositionwiseFeedForward` instance can be used as the argument.
34
+ dropout_rate (float): Dropout rate.
35
+ normalize_before (bool):
36
+ True: use layer_norm before each sub-block.
37
+ False: to use layer_norm after each sub-block.
38
+ """
39
+
40
+ def __init__(
41
+ self,
42
+ size: int,
43
+ self_attn: torch.nn.Module,
44
+ feed_forward: torch.nn.Module,
45
+ dropout_rate: float,
46
+ normalize_before: bool = True,
47
+ ):
48
+ """Construct an EncoderLayer object."""
49
+ super().__init__()
50
+ self.self_attn = self_attn
51
+ self.feed_forward = feed_forward
52
+ self.norm1 = nn.LayerNorm(size, eps=1e-12)
53
+ self.norm2 = nn.LayerNorm(size, eps=1e-12)
54
+ self.dropout = nn.Dropout(dropout_rate)
55
+ self.size = size
56
+ self.normalize_before = normalize_before
57
+
58
+ def forward(
59
+ self,
60
+ x: torch.Tensor,
61
+ mask: torch.Tensor,
62
+ pos_emb: torch.Tensor,
63
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
64
+ att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
65
+ cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
66
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
67
+ """Compute encoded features.
68
+
69
+ Args:
70
+ x (torch.Tensor): (#batch, time, size)
71
+ mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
72
+ (0, 0, 0) means fake mask.
73
+ pos_emb (torch.Tensor): just for interface compatibility
74
+ to ConformerEncoderLayer
75
+ mask_pad (torch.Tensor): not used in the transformer layer,
76
+ just for unified api with conformer.
77
+ att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
78
+ (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
79
+ cnn_cache (torch.Tensor): Convolution cache in conformer layer
80
+ (#batch=1, size, cache_t2), not used here, it's for interface
81
+ compatibility to ConformerEncoderLayer.
82
+ Returns:
83
+ torch.Tensor: Output tensor (#batch, time, size).
84
+ torch.Tensor: Mask tensor (#batch, time, time).
85
+ torch.Tensor: att_cache tensor,
86
+ (#batch=1, head, cache_t1 + time, d_k * 2).
87
+ torch.Tensor: cnn_cache tensor (#batch=1, size, cache_t2).
88
+
89
+ """
90
+ residual = x
91
+ if self.normalize_before:
92
+ x = self.norm1(x)
93
+ x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb=pos_emb, cache=att_cache)
94
+ x = residual + self.dropout(x_att)
95
+ if not self.normalize_before:
96
+ x = self.norm1(x)
97
+
98
+ residual = x
99
+ if self.normalize_before:
100
+ x = self.norm2(x)
101
+ x = residual + self.dropout(self.feed_forward(x))
102
+ if not self.normalize_before:
103
+ x = self.norm2(x)
104
+
105
+ fake_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
106
+ return x, mask, new_att_cache, fake_cnn_cache
107
+
108
+
109
+ class ConformerEncoderLayer(nn.Module):
110
+ """Encoder layer module.
111
+ Args:
112
+ size (int): Input dimension.
113
+ self_attn (torch.nn.Module): Self-attention module instance.
114
+ `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
115
+ instance can be used as the argument.
116
+ feed_forward (torch.nn.Module): Feed-forward module instance.
117
+ `PositionwiseFeedForward` instance can be used as the argument.
118
+ feed_forward_macaron (torch.nn.Module): Additional feed-forward module
119
+ instance.
120
+ `PositionwiseFeedForward` instance can be used as the argument.
121
+ conv_module (torch.nn.Module): Convolution module instance.
122
+ `ConvolutionModule` instance can be used as the argument.
123
+ dropout_rate (float): Dropout rate.
124
+ normalize_before (bool):
125
+ True: use layer_norm before each sub-block.
126
+ False: use layer_norm after each sub-block.
127
+ """
128
+
129
+ def __init__(
130
+ self,
131
+ size: int,
132
+ self_attn: torch.nn.Module,
133
+ feed_forward: Optional[nn.Module] = None,
134
+ feed_forward_macaron: Optional[nn.Module] = None,
135
+ conv_module: Optional[nn.Module] = None,
136
+ dropout_rate: float = 0.1,
137
+ normalize_before: bool = True,
138
+ ):
139
+ """Construct an EncoderLayer object."""
140
+ super().__init__()
141
+ self.self_attn = self_attn
142
+ self.feed_forward = feed_forward
143
+ self.feed_forward_macaron = feed_forward_macaron
144
+ self.conv_module = conv_module
145
+ self.norm_ff = nn.LayerNorm(size, eps=1e-12) # for the FNN module
146
+ self.norm_mha = nn.LayerNorm(size, eps=1e-12) # for the MHA module
147
+ if feed_forward_macaron is not None:
148
+ self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-12)
149
+ self.ff_scale = 0.5
150
+ else:
151
+ self.ff_scale = 1.0
152
+ if self.conv_module is not None:
153
+ self.norm_conv = nn.LayerNorm(size, eps=1e-12) # for the CNN module
154
+ self.norm_final = nn.LayerNorm(
155
+ size, eps=1e-12) # for the final output of the block
156
+ self.dropout = nn.Dropout(dropout_rate)
157
+ self.size = size
158
+ self.normalize_before = normalize_before
159
+
160
+ def forward(
161
+ self,
162
+ x: torch.Tensor,
163
+ mask: torch.Tensor,
164
+ pos_emb: torch.Tensor,
165
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
166
+ att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
167
+ cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
168
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
169
+ """Compute encoded features.
170
+
171
+ Args:
172
+ x (torch.Tensor): (#batch, time, size)
173
+ mask (torch.Tensor): Mask tensor for the input (#batch, time, time),
174
+ (0, 0, 0) means fake mask.
175
+ pos_emb (torch.Tensor): positional encoding, must not be None
176
+ for ConformerEncoderLayer.
177
+ mask_pad (torch.Tensor): batch padding mask used for conv module.
178
+ (#batch, 1, time), (0, 0, 0) means fake mask.
179
+ att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
180
+ (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
181
+ cnn_cache (torch.Tensor): Convolution cache in conformer layer
182
+ (#batch=1, size, cache_t2)
183
+ Returns:
184
+ torch.Tensor: Output tensor (#batch, time, size).
185
+ torch.Tensor: Mask tensor (#batch, time, time).
186
+ torch.Tensor: att_cache tensor,
187
+ (#batch=1, head, cache_t1 + time, d_k * 2).
188
+ torch.Tensor: cnn_cache tensor (#batch, size, cache_t2).
189
+ """
190
+
191
+ # whether to use macaron style
192
+ if self.feed_forward_macaron is not None:
193
+ residual = x
194
+ if self.normalize_before:
195
+ x = self.norm_ff_macaron(x)
196
+ x = residual + self.ff_scale * self.dropout(
197
+ self.feed_forward_macaron(x))
198
+ if not self.normalize_before:
199
+ x = self.norm_ff_macaron(x)
200
+
201
+ # multi-headed self-attention module
202
+ residual = x
203
+ if self.normalize_before:
204
+ x = self.norm_mha(x)
205
+ x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
206
+ att_cache)
207
+ x = residual + self.dropout(x_att)
208
+ if not self.normalize_before:
209
+ x = self.norm_mha(x)
210
+
211
+ # convolution module
212
+ # Fake new cnn cache here, and then change it in conv_module
213
+ new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
214
+ if self.conv_module is not None:
215
+ residual = x
216
+ if self.normalize_before:
217
+ x = self.norm_conv(x)
218
+ x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
219
+ x = residual + self.dropout(x)
220
+
221
+ if not self.normalize_before:
222
+ x = self.norm_conv(x)
223
+
224
+ # feed forward module
225
+ residual = x
226
+ if self.normalize_before:
227
+ x = self.norm_ff(x)
228
+
229
+ x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
230
+ if not self.normalize_before:
231
+ x = self.norm_ff(x)
232
+
233
+ if self.conv_module is not None:
234
+ x = self.norm_final(x)
235
+
236
+ return x, mask, new_att_cache, new_cnn_cache
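A minimal usage sketch for ConformerEncoderLayer, assuming the sibling attention.py exposes MultiHeadedAttention(n_head, n_feat, dropout_rate) as it is instantiated in class_utils.py/upsample_encoder.py, and that the src/ layout is importable as `chatterbox`:

    import torch
    from chatterbox.models.s3gen.transformer.attention import MultiHeadedAttention
    from chatterbox.models.s3gen.transformer.encoder_layer import ConformerEncoderLayer
    from chatterbox.models.s3gen.transformer.positionwise_feed_forward import PositionwiseFeedForward

    size, heads = 256, 4
    layer = ConformerEncoderLayer(
        size,
        MultiHeadedAttention(heads, size, 0.1),     # assumed constructor order, mirrors upsample_encoder.py
        PositionwiseFeedForward(size, 1024, 0.1),
    )
    x = torch.randn(2, 50, size)                    # (batch, time, size)
    mask = torch.ones(2, 50, 50, dtype=torch.bool)  # full self-attention, no padding
    pos_emb = torch.zeros(2, 1, size)               # ignored by plain MultiHeadedAttention
    out, out_mask, att_cache, cnn_cache = layer(x, mask, pos_emb)
    assert out.shape == (2, 50, size)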
src/chatterbox/models/s3gen/transformer/positionwise_feed_forward.py ADDED
@@ -0,0 +1,115 @@
1
+ # Copyright (c) 2019 Shigeki Karita
2
+ # 2020 Mobvoi Inc (Binbin Zhang)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Positionwise feed forward layer definition."""
16
+
17
+ import torch
18
+
19
+
20
+ class PositionwiseFeedForward(torch.nn.Module):
21
+ """Positionwise feed forward layer.
22
+
23
+ FeedForward is applied at each position of the sequence.
24
+ The output dim is same with the input dim.
25
+
26
+ Args:
27
+ idim (int): Input dimension.
28
+ hidden_units (int): The number of hidden units.
29
+ dropout_rate (float): Dropout rate.
30
+ activation (torch.nn.Module): Activation function
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ idim: int,
36
+ hidden_units: int,
37
+ dropout_rate: float,
38
+ activation: torch.nn.Module = torch.nn.ReLU(),
39
+ ):
40
+ """Construct a PositionwiseFeedForward object."""
41
+ super(PositionwiseFeedForward, self).__init__()
42
+ self.w_1 = torch.nn.Linear(idim, hidden_units)
43
+ self.activation = activation
44
+ self.dropout = torch.nn.Dropout(dropout_rate)
45
+ self.w_2 = torch.nn.Linear(hidden_units, idim)
46
+
47
+ def forward(self, xs: torch.Tensor) -> torch.Tensor:
48
+ """Forward function.
49
+
50
+ Args:
51
+ xs: input tensor (B, L, D)
52
+ Returns:
53
+ output tensor, (B, L, D)
54
+ """
55
+ return self.w_2(self.dropout(self.activation(self.w_1(xs))))
56
+
57
+
58
+ class MoEFFNLayer(torch.nn.Module):
59
+ """
60
+ Mixture of expert with Positionwise feed forward layer
61
+ See also figure 1 in https://arxiv.org/pdf/2305.15663.pdf
62
+ The output dim is same with the input dim.
63
+
64
+ Modified from https://github.com/Lightning-AI/lit-gpt/pull/823
65
+ https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219
66
+ Args:
67
+ n_expert: number of expert.
68
+ n_expert_per_token: The actual number of experts used for each frame
69
+ idim (int): Input dimension.
70
+ hidden_units (int): The number of hidden units.
71
+ dropout_rate (float): Dropout rate.
72
+ activation (torch.nn.Module): Activation function
73
+ """
74
+
75
+ def __init__(
76
+ self,
77
+ n_expert: int,
78
+ n_expert_per_token: int,
79
+ idim: int,
80
+ hidden_units: int,
81
+ dropout_rate: float,
82
+ activation: torch.nn.Module = torch.nn.ReLU(),
83
+ ):
84
+ super(MoEFFNLayer, self).__init__()
85
+ self.gate = torch.nn.Linear(idim, n_expert, bias=False)
86
+ self.experts = torch.nn.ModuleList(
87
+ PositionwiseFeedForward(idim, hidden_units, dropout_rate,
88
+ activation) for _ in range(n_expert))
89
+ self.n_expert_per_token = n_expert_per_token
90
+
91
+ def forward(self, xs: torch.Tensor) -> torch.Tensor:
92
+ """Foward function.
93
+ Args:
94
+ xs: input tensor (B, L, D)
95
+ Returns:
96
+ output tensor, (B, L, D)
97
+
98
+ """
99
+ B, L, D = xs.size(
100
+ ) # batch size, sequence length, embedding dimension (idim)
101
+ xs = xs.view(-1, D) # (B*L, D)
102
+ router = self.gate(xs) # (B*L, n_expert)
103
+ logits, indices = torch.topk(
104
+ router, self.n_expert_per_token
105
+ ) # logits: (B*L, n_expert_per_token), indices: (B*L, n_expert_per_token)
106
+ weights = torch.nn.functional.softmax(
107
+ logits, dim=1,
108
+ dtype=torch.float).to(dtype=xs.dtype) # (B*L, n_expert_per_token)
109
+ output = torch.zeros_like(xs) # (B*L, D)
110
+ for i, expert in enumerate(self.experts):
111
+ mask = indices == i
112
+ batch_idx, ith_expert = torch.where(mask)
113
+ output[batch_idx] += weights[batch_idx, ith_expert, None] * expert(
114
+ xs[batch_idx])
115
+ return output.view(B, L, D)
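A quick sanity sketch for the two feed-forward variants above; both preserve the (B, L, D) shape, with MoEFFNLayer routing each frame to its top-k experts (import path assumed from the src/ layout):

    import torch
    from chatterbox.models.s3gen.transformer.positionwise_feed_forward import (
        PositionwiseFeedForward, MoEFFNLayer)

    xs = torch.randn(2, 10, 256)  # (B, L, D)
    ffn = PositionwiseFeedForward(idim=256, hidden_units=1024, dropout_rate=0.1)
    moe = MoEFFNLayer(n_expert=4, n_expert_per_token=2, idim=256,
                      hidden_units=1024, dropout_rate=0.1)
    assert ffn(xs).shape == (2, 10, 256)
    assert moe(xs).shape == (2, 10, 256)  # top-2 of 4 experts, output dim unchanged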
src/chatterbox/models/s3gen/transformer/subsampling.py ADDED
@@ -0,0 +1,383 @@
1
+ # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
2
+ # 2024 Alibaba Inc (Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """Subsampling layer definition."""
17
+
18
+ from typing import Tuple, Union
19
+
20
+ import torch
21
+
22
+
23
+ class BaseSubsampling(torch.nn.Module):
24
+
25
+ def __init__(self):
26
+ super().__init__()
27
+ self.right_context = 0
28
+ self.subsampling_rate = 1
29
+
30
+ def position_encoding(self, offset: Union[int, torch.Tensor],
31
+ size: int) -> torch.Tensor:
32
+ return self.pos_enc.position_encoding(offset, size)
33
+
34
+
35
+ class EmbedinigNoSubsampling(BaseSubsampling):
36
+ """Embedding input without subsampling
37
+ """
38
+
39
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
40
+ pos_enc_class: torch.nn.Module):
41
+ super().__init__()
42
+ self.embed = torch.nn.Embedding(idim, odim)
43
+ self.pos_enc = pos_enc_class
44
+
45
+ def forward(
46
+ self,
47
+ x: torch.Tensor,
48
+ x_mask: torch.Tensor,
49
+ offset: Union[int, torch.Tensor] = 0
50
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
51
+ """Input x.
52
+
53
+ Args:
54
+ x (torch.Tensor): Input tensor (#batch, time, idim).
55
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
56
+
57
+ Returns:
58
+ torch.Tensor: linear input tensor (#batch, time', odim),
59
+ where time' = time .
60
+ torch.Tensor: linear input mask (#batch, 1, time'),
61
+ where time' = time .
62
+
63
+ """
64
+ x = self.embed(x)
65
+ x, pos_emb = self.pos_enc(x, offset)
66
+ return x, pos_emb, x_mask
67
+
68
+
69
+ class LinearNoSubsampling(BaseSubsampling):
70
+ """Linear transform the input without subsampling
71
+
72
+ Args:
73
+ idim (int): Input dimension.
74
+ odim (int): Output dimension.
75
+ dropout_rate (float): Dropout rate.
76
+
77
+ """
78
+
79
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
80
+ pos_enc_class: torch.nn.Module):
81
+ """Construct an linear object."""
82
+ super().__init__()
83
+ self.out = torch.nn.Sequential(
84
+ torch.nn.Linear(idim, odim),
85
+ torch.nn.LayerNorm(odim, eps=1e-5),
86
+ torch.nn.Dropout(dropout_rate),
87
+ )
88
+ self.pos_enc = pos_enc_class
89
+ self.right_context = 0
90
+ self.subsampling_rate = 1
91
+
92
+ def forward(
93
+ self,
94
+ x: torch.Tensor,
95
+ x_mask: torch.Tensor,
96
+ offset: Union[int, torch.Tensor] = 0
97
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
98
+ """Input x.
99
+
100
+ Args:
101
+ x (torch.Tensor): Input tensor (#batch, time, idim).
102
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
103
+
104
+ Returns:
105
+ torch.Tensor: linear input tensor (#batch, time', odim),
106
+ where time' = time .
107
+ torch.Tensor: linear input mask (#batch, 1, time'),
108
+ where time' = time .
109
+
110
+ """
111
+ x = self.out(x)
112
+ x, pos_emb = self.pos_enc(x, offset)
113
+ return x, pos_emb, x_mask
114
+
115
+
116
+ class Conv1dSubsampling2(BaseSubsampling):
117
+ """Convolutional 1D subsampling (to 1/2 length).
118
+ It is designed for Whisper, ref:
119
+ https://github.com/openai/whisper/blob/main/whisper/model.py
120
+
121
+ Args:
122
+ idim (int): Input dimension.
123
+ odim (int): Output dimension.
124
+ dropout_rate (float): Dropout rate.
125
+
126
+ """
127
+
128
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
129
+ pos_enc_class: torch.nn.Module):
130
+ """Construct an Conv1dSubsampling2 object."""
131
+ super().__init__()
132
+ self.conv = torch.nn.Sequential(
133
+ torch.nn.Conv1d(idim, odim, kernel_size=3, padding=1),
134
+ torch.nn.GELU(),
135
+ torch.nn.Conv1d(odim, odim, kernel_size=3, stride=2, padding=1),
136
+ torch.nn.GELU(),
137
+ )
138
+ self.pos_enc = pos_enc_class
139
+ # The right context for every conv layer is computed by:
140
+ # (kernel_size - 1) * frame_rate_of_this_layer
141
+ self.subsampling_rate = 2
142
+ # 4 = (3 - 1) * 1 + (3 - 1) * 1
143
+ self.right_context = 4
144
+
145
+ def forward(
146
+ self,
147
+ x: torch.Tensor,
148
+ x_mask: torch.Tensor,
149
+ offset: Union[int, torch.Tensor] = 0
150
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
151
+ """Subsample x.
152
+
153
+ Args:
154
+ x (torch.Tensor): Input tensor (#batch, time, idim).
155
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
156
+
157
+ Returns:
158
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
159
+ where time' = time // 2.
160
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
161
+ where time' = time // 2.
162
+ torch.Tensor: positional encoding
163
+
164
+ """
165
+ time = x.size(1)
166
+ x = x.transpose(1, 2) # (b, f, t)
167
+ x = self.conv(x)
168
+ x = x.transpose(1, 2) # (b, t, f)
169
+ x, pos_emb = self.pos_enc(x, offset)
170
+ return x, pos_emb, x_mask[:, :, (time + 1) % 2::2]
171
+
172
+
173
+ class Conv2dSubsampling4(BaseSubsampling):
174
+ """Convolutional 2D subsampling (to 1/4 length).
175
+
176
+ Args:
177
+ idim (int): Input dimension.
178
+ odim (int): Output dimension.
179
+ dropout_rate (float): Dropout rate.
180
+
181
+ """
182
+
183
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
184
+ pos_enc_class: torch.nn.Module):
185
+ """Construct an Conv2dSubsampling4 object."""
186
+ super().__init__()
187
+ self.conv = torch.nn.Sequential(
188
+ torch.nn.Conv2d(1, odim, 3, 2),
189
+ torch.nn.ReLU(),
190
+ torch.nn.Conv2d(odim, odim, 3, 2),
191
+ torch.nn.ReLU(),
192
+ )
193
+ self.out = torch.nn.Sequential(
194
+ torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim))
195
+ self.pos_enc = pos_enc_class
196
+ # The right context for every conv layer is computed by:
197
+ # (kernel_size - 1) * frame_rate_of_this_layer
198
+ self.subsampling_rate = 4
199
+ # 6 = (3 - 1) * 1 + (3 - 1) * 2
200
+ self.right_context = 6
201
+
202
+ def forward(
203
+ self,
204
+ x: torch.Tensor,
205
+ x_mask: torch.Tensor,
206
+ offset: Union[int, torch.Tensor] = 0
207
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
208
+ """Subsample x.
209
+
210
+ Args:
211
+ x (torch.Tensor): Input tensor (#batch, time, idim).
212
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
213
+
214
+ Returns:
215
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
216
+ where time' = time // 4.
217
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
218
+ where time' = time // 4.
219
+ torch.Tensor: positional encoding
220
+
221
+ """
222
+ x = x.unsqueeze(1) # (b, c=1, t, f)
223
+ x = self.conv(x)
224
+ b, c, t, f = x.size()
225
+ x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
226
+ x, pos_emb = self.pos_enc(x, offset)
227
+ return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2]
228
+
229
+
230
+ class Conv2dSubsampling6(BaseSubsampling):
231
+ """Convolutional 2D subsampling (to 1/6 length).
232
+ Args:
233
+ idim (int): Input dimension.
234
+ odim (int): Output dimension.
235
+ dropout_rate (float): Dropout rate.
236
+ pos_enc (torch.nn.Module): Custom position encoding layer.
237
+ """
238
+
239
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
240
+ pos_enc_class: torch.nn.Module):
241
+ """Construct an Conv2dSubsampling6 object."""
242
+ super().__init__()
243
+ self.conv = torch.nn.Sequential(
244
+ torch.nn.Conv2d(1, odim, 3, 2),
245
+ torch.nn.ReLU(),
246
+ torch.nn.Conv2d(odim, odim, 5, 3),
247
+ torch.nn.ReLU(),
248
+ )
249
+ self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3),
250
+ odim)
251
+ self.pos_enc = pos_enc_class
252
+ # 10 = (3 - 1) * 1 + (5 - 1) * 2
253
+ self.subsampling_rate = 6
254
+ self.right_context = 10
255
+
256
+ def forward(
257
+ self,
258
+ x: torch.Tensor,
259
+ x_mask: torch.Tensor,
260
+ offset: Union[int, torch.Tensor] = 0
261
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
262
+ """Subsample x.
263
+ Args:
264
+ x (torch.Tensor): Input tensor (#batch, time, idim).
265
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
266
+
267
+ Returns:
268
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
269
+ where time' = time // 6.
270
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
271
+ where time' = time // 6.
272
+ torch.Tensor: positional encoding
273
+ """
274
+ x = x.unsqueeze(1) # (b, c, t, f)
275
+ x = self.conv(x)
276
+ b, c, t, f = x.size()
277
+ x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
278
+ x, pos_emb = self.pos_enc(x, offset)
279
+ return x, pos_emb, x_mask[:, :, 2::2][:, :, 4::3]
280
+
281
+
282
+ class Conv2dSubsampling8(BaseSubsampling):
283
+ """Convolutional 2D subsampling (to 1/8 length).
284
+
285
+ Args:
286
+ idim (int): Input dimension.
287
+ odim (int): Output dimension.
288
+ dropout_rate (float): Dropout rate.
289
+
290
+ """
291
+
292
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
293
+ pos_enc_class: torch.nn.Module):
294
+ """Construct an Conv2dSubsampling8 object."""
295
+ super().__init__()
296
+ self.conv = torch.nn.Sequential(
297
+ torch.nn.Conv2d(1, odim, 3, 2),
298
+ torch.nn.ReLU(),
299
+ torch.nn.Conv2d(odim, odim, 3, 2),
300
+ torch.nn.ReLU(),
301
+ torch.nn.Conv2d(odim, odim, 3, 2),
302
+ torch.nn.ReLU(),
303
+ )
304
+ self.linear = torch.nn.Linear(
305
+ odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim)
306
+ self.pos_enc = pos_enc_class
307
+ self.subsampling_rate = 8
308
+ # 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4
309
+ self.right_context = 14
310
+
311
+ def forward(
312
+ self,
313
+ x: torch.Tensor,
314
+ x_mask: torch.Tensor,
315
+ offset: Union[int, torch.Tensor] = 0
316
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
317
+ """Subsample x.
318
+
319
+ Args:
320
+ x (torch.Tensor): Input tensor (#batch, time, idim).
321
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
322
+
323
+ Returns:
324
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
325
+ where time' = time // 8.
326
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
327
+ where time' = time // 8.
328
+ torch.Tensor: positional encoding
329
+ """
330
+ x = x.unsqueeze(1) # (b, c, t, f)
331
+ x = self.conv(x)
332
+ b, c, t, f = x.size()
333
+ x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
334
+ x, pos_emb = self.pos_enc(x, offset)
335
+ return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2][:, :, 2::2]
336
+
337
+
338
+ class LegacyLinearNoSubsampling(BaseSubsampling):
339
+ """Linear transform the input without subsampling
340
+
341
+ Args:
342
+ idim (int): Input dimension.
343
+ odim (int): Output dimension.
344
+ dropout_rate (float): Dropout rate.
345
+
346
+ """
347
+
348
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
349
+ pos_enc_class: torch.nn.Module):
350
+ """Construct an linear object."""
351
+ super().__init__()
352
+ self.out = torch.nn.Sequential(
353
+ torch.nn.Linear(idim, odim),
354
+ torch.nn.LayerNorm(odim, eps=1e-5),
355
+ torch.nn.Dropout(dropout_rate),
356
+ torch.nn.ReLU(),
357
+ )
358
+ self.pos_enc = pos_enc_class
359
+ self.right_context = 0
360
+ self.subsampling_rate = 1
361
+
362
+ def forward(
363
+ self,
364
+ x: torch.Tensor,
365
+ x_mask: torch.Tensor,
366
+ offset: Union[int, torch.Tensor] = 0
367
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
368
+ """Input x.
369
+
370
+ Args:
371
+ x (torch.Tensor): Input tensor (#batch, time, idim).
372
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
373
+
374
+ Returns:
375
+ torch.Tensor: linear input tensor (#batch, time', odim),
376
+ where time' = time .
377
+ torch.Tensor: linear input mask (#batch, 1, time'),
378
+ where time' = time .
379
+
380
+ """
381
+ x = self.out(x)
382
+ x, pos_emb = self.pos_enc(x, offset)
383
+ return x, pos_emb, x_mask
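A short sketch of the 4x subsampling path, using a stub positional-encoding module purely for illustration (the real callers pass a class from ..transformer.embedding via class_utils.py; import path assumed):

    import torch
    from chatterbox.models.s3gen.transformer.subsampling import Conv2dSubsampling4

    class _IdentityPosEnc(torch.nn.Module):
        # stub: returns the input unchanged plus a dummy positional encoding
        def forward(self, x, offset=0):
            return x, torch.zeros(1, x.size(1), x.size(2))

    sub = Conv2dSubsampling4(idim=80, odim=256, dropout_rate=0.1,
                             pos_enc_class=_IdentityPosEnc())
    x = torch.randn(2, 100, 80)                       # (batch, time, feat)
    x_mask = torch.ones(2, 1, 100, dtype=torch.bool)  # (batch, 1, time)
    y, pos_emb, y_mask = sub(x, x_mask)
    assert y.shape == (2, 24, 256)        # ((100 - 1) // 2 - 1) // 2 = 24 frames
    assert y_mask.shape == (2, 1, 24)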
src/chatterbox/models/s3gen/transformer/upsample_encoder.py ADDED
@@ -0,0 +1,318 @@
1
+ # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
2
+ # 2022 Xingchen Song ([email protected])
3
+ # 2024 Alibaba Inc (Xiang Lyu)
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # Modified from ESPnet(https://github.com/espnet/espnet)
17
+ """Encoder definition."""
18
+ from typing import Tuple
19
+
20
+ import torch
21
+ from torch import nn
22
+ from torch.nn import functional as F
23
+
24
+ from .convolution import ConvolutionModule
25
+ from .encoder_layer import ConformerEncoderLayer
26
+ from .positionwise_feed_forward import PositionwiseFeedForward
27
+ from ..utils.class_utils import (
28
+ COSYVOICE_EMB_CLASSES,
29
+ COSYVOICE_SUBSAMPLE_CLASSES,
30
+ COSYVOICE_ATTENTION_CLASSES,
31
+ COSYVOICE_ACTIVATION_CLASSES,
32
+ )
33
+ from ..utils.mask import make_pad_mask
34
+ from ..utils.mask import add_optional_chunk_mask
35
+
36
+
37
+ class Upsample1D(nn.Module):
38
+ """A 1D upsampling layer with an optional convolution.
39
+
40
+ Parameters:
41
+ channels (`int`):
42
+ number of channels in the inputs and outputs.
43
+ use_conv (`bool`, default `False`):
44
+ option to use a convolution.
45
+ use_conv_transpose (`bool`, default `False`):
46
+ option to use a convolution transpose.
47
+ out_channels (`int`, optional):
48
+ number of output channels. Defaults to `channels`.
49
+ """
50
+
51
+ def __init__(self, channels: int, out_channels: int, stride: int = 2):
52
+ super().__init__()
53
+ self.channels = channels
54
+ self.out_channels = out_channels
55
+ self.stride = stride
56
+ # In this mode, first repeat-interpolate, then conv with stride=1
57
+ self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0)
58
+
59
+ def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor):
60
+ outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest")
61
+ outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
62
+ outputs = self.conv(outputs)
63
+ return outputs, input_lengths * self.stride
64
+
65
+
66
+ class PreLookaheadLayer(nn.Module):
67
+ def __init__(self, channels: int, pre_lookahead_len: int = 1):
68
+ super().__init__()
69
+ self.channels = channels
70
+ self.pre_lookahead_len = pre_lookahead_len
71
+ self.conv1 = nn.Conv1d(
72
+ channels, channels,
73
+ kernel_size=pre_lookahead_len + 1,
74
+ stride=1, padding=0,
75
+ )
76
+ self.conv2 = nn.Conv1d(
77
+ channels, channels,
78
+ kernel_size=3, stride=1, padding=0,
79
+ )
80
+
81
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
82
+ """
83
+ inputs: (batch_size, seq_len, channels)
84
+ """
85
+ outputs = inputs.transpose(1, 2).contiguous()
86
+ # look ahead
87
+ outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0)
88
+ outputs = F.leaky_relu(self.conv1(outputs))
89
+ # outputs
90
+ outputs = F.pad(outputs, (2, 0), mode='constant', value=0.0)
91
+ outputs = self.conv2(outputs)
92
+ outputs = outputs.transpose(1, 2).contiguous()
93
+
94
+ # residual connection
95
+ outputs = outputs + inputs
96
+ return outputs
97
+
98
+
99
+ class UpsampleConformerEncoder(torch.nn.Module):
100
+
101
+ def __init__(
102
+ self,
103
+ input_size: int = 512,
104
+ output_size: int = 512,
105
+ attention_heads: int = 8,
106
+ linear_units: int = 2048,
107
+ num_blocks: int = 6,
108
+ dropout_rate: float = 0.1,
109
+ positional_dropout_rate: float = 0.1,
110
+ attention_dropout_rate: float = 0.1,
111
+ input_layer: str = "linear",
112
+ pos_enc_layer_type: str = "rel_pos_espnet",
113
+ normalize_before: bool = True,
114
+ static_chunk_size: int = 0,
115
+ use_dynamic_chunk: bool = False,
116
+ global_cmvn: torch.nn.Module = None,
117
+ use_dynamic_left_chunk: bool = False,
118
+ positionwise_conv_kernel_size: int = 1,
119
+ macaron_style: bool = False,
120
+ selfattention_layer_type: str = "rel_selfattn",
121
+ activation_type: str = "swish",
122
+ use_cnn_module: bool = False,
123
+ cnn_module_kernel: int = 15,
124
+ causal: bool = False,
125
+ cnn_module_norm: str = "batch_norm",
126
+ key_bias: bool = True,
127
+ gradient_checkpointing: bool = False,
128
+ ):
129
+ """
130
+ Args:
131
+ input_size (int): input dim
132
+ output_size (int): dimension of attention
133
+ attention_heads (int): the number of heads of multi head attention
134
+ linear_units (int): the hidden units number of position-wise feed
135
+ forward
136
+ num_blocks (int): the number of decoder blocks
137
+ dropout_rate (float): dropout rate
138
+ attention_dropout_rate (float): dropout rate in attention
139
+ positional_dropout_rate (float): dropout rate after adding
140
+ positional encoding
141
+ input_layer (str): input layer type.
142
+ optional [linear, conv2d, conv2d6, conv2d8]
143
+ pos_enc_layer_type (str): Encoder positional encoding layer type.
144
+ optional [abs_pos, scaled_abs_pos, rel_pos, no_pos]
145
+ normalize_before (bool):
146
+ True: use layer_norm before each sub-block of a layer.
147
+ False: use layer_norm after each sub-block of a layer.
148
+ static_chunk_size (int): chunk size for static chunk training and
149
+ decoding
150
+ use_dynamic_chunk (bool): whether use dynamic chunk size for
151
+ training or not, You can only use fixed chunk(chunk_size > 0)
152
+ or dyanmic chunk size(use_dynamic_chunk = True)
153
+ global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
154
+ use_dynamic_left_chunk (bool): whether use dynamic left chunk in
155
+ dynamic chunk training
156
+ key_bias: whether use bias in attention.linear_k, False for whisper models.
157
+ gradient_checkpointing: rerunning a forward-pass segment for each
158
+ checkpointed segment during backward.
159
+ """
160
+ super().__init__()
161
+ self._output_size = output_size
162
+
163
+ self.global_cmvn = global_cmvn
164
+ self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
165
+ input_size,
166
+ output_size,
167
+ dropout_rate,
168
+ COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
169
+ positional_dropout_rate),
170
+ )
171
+
172
+ self.normalize_before = normalize_before
173
+ self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
174
+ self.static_chunk_size = static_chunk_size
175
+ self.use_dynamic_chunk = use_dynamic_chunk
176
+ self.use_dynamic_left_chunk = use_dynamic_left_chunk
177
+ self.gradient_checkpointing = gradient_checkpointing
178
+ activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
179
+ # self-attention module definition
180
+ encoder_selfattn_layer_args = (
181
+ attention_heads,
182
+ output_size,
183
+ attention_dropout_rate,
184
+ key_bias,
185
+ )
186
+ # feed-forward module definition
187
+ positionwise_layer_args = (
188
+ output_size,
189
+ linear_units,
190
+ dropout_rate,
191
+ activation,
192
+ )
193
+ # convolution module definition
194
+ convolution_layer_args = (output_size, cnn_module_kernel, activation,
195
+ cnn_module_norm, causal)
196
+ self.pre_lookahead_layer = PreLookaheadLayer(channels=512, pre_lookahead_len=3)
197
+ self.encoders = torch.nn.ModuleList([
198
+ ConformerEncoderLayer(
199
+ output_size,
200
+ COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
201
+ *encoder_selfattn_layer_args),
202
+ PositionwiseFeedForward(*positionwise_layer_args),
203
+ PositionwiseFeedForward(
204
+ *positionwise_layer_args) if macaron_style else None,
205
+ ConvolutionModule(
206
+ *convolution_layer_args) if use_cnn_module else None,
207
+ dropout_rate,
208
+ normalize_before,
209
+ ) for _ in range(num_blocks)
210
+ ])
211
+ self.up_layer = Upsample1D(channels=512, out_channels=512, stride=2)
212
+ self.up_embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
213
+ input_size,
214
+ output_size,
215
+ dropout_rate,
216
+ COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
217
+ positional_dropout_rate),
218
+ )
219
+ self.up_encoders = torch.nn.ModuleList([
220
+ ConformerEncoderLayer(
221
+ output_size,
222
+ COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
223
+ *encoder_selfattn_layer_args),
224
+ PositionwiseFeedForward(*positionwise_layer_args),
225
+ PositionwiseFeedForward(
226
+ *positionwise_layer_args) if macaron_style else None,
227
+ ConvolutionModule(
228
+ *convolution_layer_args) if use_cnn_module else None,
229
+ dropout_rate,
230
+ normalize_before,
231
+ ) for _ in range(4)
232
+ ])
233
+
234
+ def output_size(self) -> int:
235
+ return self._output_size
236
+
237
+ def forward(
238
+ self,
239
+ xs: torch.Tensor,
240
+ xs_lens: torch.Tensor,
241
+ decoding_chunk_size: int = 0,
242
+ num_decoding_left_chunks: int = -1,
243
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
244
+ """Embed positions in tensor.
245
+
246
+ Args:
247
+ xs: padded input tensor (B, T, D)
248
+ xs_lens: input length (B)
249
+ decoding_chunk_size: decoding chunk size for dynamic chunk
250
+ 0: default for training, use random dynamic chunk.
251
+ <0: for decoding, use full chunk.
252
+ >0: for decoding, use fixed chunk size as set.
253
+ num_decoding_left_chunks: number of left chunks, this is for decoding,
254
+ the chunk size is decoding_chunk_size.
255
+ >=0: use num_decoding_left_chunks
256
+ <0: use all left chunks
257
+ Returns:
258
+ encoder output tensor xs, and subsampled masks
259
+ xs: padded output tensor (B, T' ~= T/subsample_rate, D)
260
+ masks: torch.Tensor batch padding mask after subsample
261
+ (B, 1, T' ~= T/subsample_rate)
262
+ NOTE(xcsong):
263
+ We pass the `__call__` method of the modules instead of `forward` to the
264
+ checkpointing API because `__call__` attaches all the hooks of the module.
265
+ https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
266
+ """
267
+ T = xs.size(1)
268
+ masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
269
+ if self.global_cmvn is not None:
270
+ xs = self.global_cmvn(xs)
271
+ xs, pos_emb, masks = self.embed(xs, masks)
272
+ mask_pad = masks # (B, 1, T/subsample_rate)
273
+ chunk_masks = add_optional_chunk_mask(xs, masks,
274
+ self.use_dynamic_chunk,
275
+ self.use_dynamic_left_chunk,
276
+ decoding_chunk_size,
277
+ self.static_chunk_size,
278
+ num_decoding_left_chunks)
279
+ # lookahead + conformer encoder
280
+ xs = self.pre_lookahead_layer(xs)
281
+ xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
282
+
283
+ # upsample + conformer encoder
284
+ xs = xs.transpose(1, 2).contiguous()
285
+ xs, xs_lens = self.up_layer(xs, xs_lens)
286
+ xs = xs.transpose(1, 2).contiguous()
287
+ T = xs.size(1)
288
+ masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
289
+ xs, pos_emb, masks = self.up_embed(xs, masks)
290
+ mask_pad = masks # (B, 1, T/subsample_rate)
291
+ chunk_masks = add_optional_chunk_mask(xs, masks,
292
+ self.use_dynamic_chunk,
293
+ self.use_dynamic_left_chunk,
294
+ decoding_chunk_size,
295
+ self.static_chunk_size * self.up_layer.stride,
296
+ num_decoding_left_chunks)
297
+ xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad)
298
+
299
+ if self.normalize_before:
300
+ xs = self.after_norm(xs)
301
+ # Here we assume the mask is not changed in encoder layers, so just
302
+ # return the masks before encoder layers, and the masks will be used
303
+ # for cross attention with decoder later
304
+ return xs, masks
305
+
306
+ def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
307
+ pos_emb: torch.Tensor,
308
+ mask_pad: torch.Tensor) -> torch.Tensor:
309
+ for layer in self.encoders:
310
+ xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
311
+ return xs
312
+
313
+ def forward_up_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
314
+ pos_emb: torch.Tensor,
315
+ mask_pad: torch.Tensor) -> torch.Tensor:
316
+ for layer in self.up_encoders:
317
+ xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
318
+ return xs
src/chatterbox/models/s3gen/utils/class_utils.py ADDED
@@ -0,0 +1,71 @@
1
+ # Copyright [2023-11-28] <[email protected], Xingchen Song>
2
+ # 2024 Alibaba Inc (authors: Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import torch
16
+
17
+ from ..transformer.activation import Swish
18
+ from ..transformer.subsampling import (
19
+ LinearNoSubsampling,
20
+ EmbedinigNoSubsampling,
21
+ Conv1dSubsampling2,
22
+ Conv2dSubsampling4,
23
+ Conv2dSubsampling6,
24
+ Conv2dSubsampling8,
25
+ )
26
+ from ..transformer.embedding import (
27
+ PositionalEncoding,
28
+ RelPositionalEncoding,
29
+ WhisperPositionalEncoding,
30
+ LearnablePositionalEncoding,
31
+ NoPositionalEncoding)
32
+ from ..transformer.attention import (MultiHeadedAttention,
33
+ RelPositionMultiHeadedAttention)
34
+ from ..transformer.embedding import EspnetRelPositionalEncoding
35
+ from ..transformer.subsampling import LegacyLinearNoSubsampling
36
+
37
+
38
+ COSYVOICE_ACTIVATION_CLASSES = {
39
+ "hardtanh": torch.nn.Hardtanh,
40
+ "tanh": torch.nn.Tanh,
41
+ "relu": torch.nn.ReLU,
42
+ "selu": torch.nn.SELU,
43
+ "swish": getattr(torch.nn, "SiLU", Swish),
44
+ "gelu": torch.nn.GELU,
45
+ }
46
+
47
+ COSYVOICE_SUBSAMPLE_CLASSES = {
48
+ "linear": LinearNoSubsampling,
49
+ "linear_legacy": LegacyLinearNoSubsampling,
50
+ "embed": EmbedinigNoSubsampling,
51
+ "conv1d2": Conv1dSubsampling2,
52
+ "conv2d": Conv2dSubsampling4,
53
+ "conv2d6": Conv2dSubsampling6,
54
+ "conv2d8": Conv2dSubsampling8,
55
+ 'paraformer_dummy': torch.nn.Identity
56
+ }
57
+
58
+ COSYVOICE_EMB_CLASSES = {
59
+ "embed": PositionalEncoding,
60
+ "abs_pos": PositionalEncoding,
61
+ "rel_pos": RelPositionalEncoding,
62
+ "rel_pos_espnet": EspnetRelPositionalEncoding,
63
+ "no_pos": NoPositionalEncoding,
64
+ "abs_pos_whisper": WhisperPositionalEncoding,
65
+ "embed_learnable_pe": LearnablePositionalEncoding,
66
+ }
67
+
68
+ COSYVOICE_ATTENTION_CLASSES = {
69
+ "selfattn": MultiHeadedAttention,
70
+ "rel_selfattn": RelPositionMultiHeadedAttention,
71
+ }
src/chatterbox/models/s3gen/utils/mask.py ADDED
@@ -0,0 +1,193 @@
1
+ # Copyright (c) 2019 Shigeki Karita
2
+ # 2020 Mobvoi Inc (Binbin Zhang)
3
+ # 2024 Alibaba Inc (authors: Xiang Lyu)
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import logging
+
+ import torch
18
+
19
+ '''
20
+ def subsequent_mask(
21
+ size: int,
22
+ device: torch.device = torch.device("cpu"),
23
+ ) -> torch.Tensor:
24
+ """Create mask for subsequent steps (size, size).
25
+
26
+ This mask is used only in decoder which works in an auto-regressive mode.
27
+ This means the current step could only do attention with its left steps.
28
+
29
+ In encoder, fully attention is used when streaming is not necessary and
30
+ the sequence is not long. In this case, no attention mask is needed.
31
+
32
+ When streaming is need, chunk-based attention is used in encoder. See
33
+ subsequent_chunk_mask for the chunk-based attention mask.
34
+
35
+ Args:
36
+ size (int): size of mask
37
+ str device (str): "cpu" or "cuda" or torch.Tensor.device
38
+ dtype (torch.device): result dtype
39
+
40
+ Returns:
41
+ torch.Tensor: mask
42
+
43
+ Examples:
44
+ >>> subsequent_mask(3)
45
+ [[1, 0, 0],
46
+ [1, 1, 0],
47
+ [1, 1, 1]]
48
+ """
49
+ ret = torch.ones(size, size, device=device, dtype=torch.bool)
50
+ return torch.tril(ret)
51
+ '''
52
+
53
+
54
+ def subsequent_chunk_mask(
55
+ size: int,
56
+ chunk_size: int,
57
+ num_left_chunks: int = -1,
58
+ device: torch.device = torch.device("cpu"),
59
+ ) -> torch.Tensor:
60
+ """Create mask for subsequent steps (size, size) with chunk size,
61
+ this is for streaming encoder
62
+
63
+ Args:
64
+ size (int): size of mask
65
+ chunk_size (int): size of chunk
66
+ num_left_chunks (int): number of left chunks
67
+ <0: use full chunk
68
+ >=0: use num_left_chunks
69
+ device (torch.device): "cpu" or "cuda" or torch.Tensor.device
70
+
71
+ Returns:
72
+ torch.Tensor: mask
73
+
74
+ Examples:
75
+ >>> subsequent_chunk_mask(4, 2)
76
+ [[1, 1, 0, 0],
77
+ [1, 1, 0, 0],
78
+ [1, 1, 1, 1],
79
+ [1, 1, 1, 1]]
80
+ """
81
+ # NOTE this modified implementation meets onnx export requirements, but it doesn't support num_left_chunks
82
+ # actually this is not needed after we have inference cache implemented, will remove it later
83
+ pos_idx = torch.arange(size, device=device)
84
+ block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size
85
+ ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1)
86
+ return ret
87
+
88
+
89
+ def add_optional_chunk_mask(xs: torch.Tensor,
90
+ masks: torch.Tensor,
91
+ use_dynamic_chunk: bool,
92
+ use_dynamic_left_chunk: bool,
93
+ decoding_chunk_size: int,
94
+ static_chunk_size: int,
95
+ num_decoding_left_chunks: int,
96
+ enable_full_context: bool = True):
97
+ """ Apply optional mask for encoder.
98
+
99
+ Args:
100
+ xs (torch.Tensor): padded input, (B, L, D), L for max length
101
+ mask (torch.Tensor): mask for xs, (B, 1, L)
102
+ use_dynamic_chunk (bool): whether to use dynamic chunk or not
103
+ use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
104
+ training.
105
+ decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
106
+ 0: default for training, use random dynamic chunk.
107
+ <0: for decoding, use full chunk.
108
+ >0: for decoding, use fixed chunk size as set.
109
+ static_chunk_size (int): chunk size for static chunk training/decoding
110
+ if it's greater than 0, if use_dynamic_chunk is true,
111
+ this parameter will be ignored
112
+ num_decoding_left_chunks: number of left chunks, this is for decoding,
113
+ the chunk size is decoding_chunk_size.
114
+ >=0: use num_decoding_left_chunks
115
+ <0: use all left chunks
116
+ enable_full_context (bool):
117
+ True: chunk size is either [1, 25] or full context(max_len)
118
+ False: chunk size ~ U[1, 25]
119
+
120
+ Returns:
121
+ torch.Tensor: chunk mask of the input xs.
122
+ """
123
+ # Whether to use chunk mask or not
124
+ if use_dynamic_chunk:
125
+ max_len = xs.size(1)
126
+ if decoding_chunk_size < 0:
127
+ chunk_size = max_len
128
+ num_left_chunks = -1
129
+ elif decoding_chunk_size > 0:
130
+ chunk_size = decoding_chunk_size
131
+ num_left_chunks = num_decoding_left_chunks
132
+ else:
133
+ # chunk size is either [1, 25] or full context(max_len).
134
+ # Since we use 4 times subsampling and allow up to 1s(100 frames)
135
+ # delay, the maximum frame is 100 / 4 = 25.
136
+ chunk_size = torch.randint(1, max_len, (1, )).item()
137
+ num_left_chunks = -1
138
+ if chunk_size > max_len // 2 and enable_full_context:
139
+ chunk_size = max_len
140
+ else:
141
+ chunk_size = chunk_size % 25 + 1
142
+ if use_dynamic_left_chunk:
143
+ max_left_chunks = (max_len - 1) // chunk_size
144
+ num_left_chunks = torch.randint(0, max_left_chunks,
145
+ (1, )).item()
146
+ chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
147
+ num_left_chunks,
148
+ xs.device) # (L, L)
149
+ chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
150
+ chunk_masks = masks & chunk_masks # (B, L, L)
151
+ elif static_chunk_size > 0:
152
+ num_left_chunks = num_decoding_left_chunks
153
+ chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
154
+ num_left_chunks,
155
+ xs.device) # (L, L)
156
+ chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
157
+ chunk_masks = masks & chunk_masks # (B, L, L)
158
+ else:
159
+ chunk_masks = masks
160
+ assert chunk_masks.dtype == torch.bool
161
+ if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0:
162
+ logging.warning('chunk_masks is all false at some timestep; forcing it to true. Make sure these positions are masked in future computation!')
163
+ chunk_masks[chunk_masks.sum(dim=-1)==0] = True
164
+ return chunk_masks
165
+
166
+
167
+ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
168
+ """Make mask tensor containing indices of padded part.
169
+
170
+ See description of make_non_pad_mask.
171
+
172
+ Args:
173
+ lengths (torch.Tensor): Batch of lengths (B,).
174
+ Returns:
175
+ torch.Tensor: Mask tensor containing indices of padded part.
176
+
177
+ Examples:
178
+ >>> lengths = [5, 3, 2]
179
+ >>> make_pad_mask(lengths)
180
+ masks = [[0, 0, 0, 0 ,0],
181
+ [0, 0, 0, 1, 1],
182
+ [0, 0, 1, 1, 1]]
183
+ """
184
+ batch_size = lengths.size(0)
185
+ max_len = max_len if max_len > 0 else lengths.max().item()
186
+ seq_range = torch.arange(0,
187
+ max_len,
188
+ dtype=torch.int64,
189
+ device=lengths.device)
190
+ seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
191
+ seq_length_expand = lengths.unsqueeze(-1)
192
+ mask = seq_range_expand >= seq_length_expand
193
+ return mask
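A small sketch of how these helpers combine in the encoders above: make_pad_mask marks padded frames, and add_optional_chunk_mask expands the (B, 1, T) padding mask into a (B, T, T) attention mask (here with a static chunk size of 2; import path assumed):

    import torch
    from chatterbox.models.s3gen.utils.mask import make_pad_mask, add_optional_chunk_mask

    lengths = torch.tensor([5, 3, 2])
    pad = make_pad_mask(lengths)       # (3, 5), True on padded positions
    masks = ~pad.unsqueeze(1)          # (3, 1, 5), True on real frames
    xs = torch.randn(3, 5, 8)
    chunk_masks = add_optional_chunk_mask(
        xs, masks,
        use_dynamic_chunk=False, use_dynamic_left_chunk=False,
        decoding_chunk_size=0, static_chunk_size=2,
        num_decoding_left_chunks=-1)
    assert chunk_masks.shape == (3, 5, 5) and chunk_masks.dtype == torch.bool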
src/chatterbox/models/s3gen/utils/mel.py ADDED
@@ -0,0 +1,85 @@
1
+ """mel-spectrogram extraction in Matcha-TTS"""
2
+ import logging
3
+ from librosa.filters import mel as librosa_mel_fn
4
+ import torch
5
+ import numpy as np
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+
10
+ # NOTE: the upstream Matcha-TTS code declares these as module-level global caches
11
+ mel_basis = {}
12
+ hann_window = {}
13
+
14
+
15
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
16
+ return torch.log(torch.clamp(x, min=clip_val) * C)
17
+
18
+
19
+ def spectral_normalize_torch(magnitudes):
20
+ output = dynamic_range_compression_torch(magnitudes)
21
+ return output
22
+
23
+ """
24
+ feat_extractor: !name:matcha.utils.audio.mel_spectrogram
25
+ n_fft: 1920
26
+ num_mels: 80
27
+ sampling_rate: 24000
28
+ hop_size: 480
29
+ win_size: 1920
30
+ fmin: 0
31
+ fmax: 8000
32
+ center: False
33
+
34
+ """
35
+
36
+ def mel_spectrogram(y, n_fft=1920, num_mels=80, sampling_rate=24000, hop_size=480, win_size=1920,
37
+ fmin=0, fmax=8000, center=False):
38
+ """Copied from https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/utils/audio.py
39
+ Set default values according to Cosyvoice's config.
40
+ """
41
+
42
+ if isinstance(y, np.ndarray):
43
+ y = torch.tensor(y).float()
44
+
45
+ if len(y.shape) == 1:
46
+ y = y[None, ]
47
+
48
+ # Debug: Check for audio clipping (values outside [-1.0, 1.0] range)
49
+ min_val = torch.min(y)
50
+ max_val = torch.max(y)
51
+ if min_val < -1.0 or max_val > 1.0:
52
+ logger.warning(f"Audio values outside normalized range: min={min_val.item():.4f}, max={max_val.item():.4f}")
53
+
54
+ global mel_basis, hann_window # pylint: disable=global-statement,global-variable-not-assigned
55
+ if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
56
+ mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
57
+ mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
58
+ hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
59
+
60
+ y = torch.nn.functional.pad(
61
+ y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
62
+ )
63
+ y = y.squeeze(1)
64
+
65
+ spec = torch.view_as_real(
66
+ torch.stft(
67
+ y,
68
+ n_fft,
69
+ hop_length=hop_size,
70
+ win_length=win_size,
71
+ window=hann_window[str(y.device)],
72
+ center=center,
73
+ pad_mode="reflect",
74
+ normalized=False,
75
+ onesided=True,
76
+ return_complex=True,
77
+ )
78
+ )
79
+
80
+ spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
81
+
82
+ spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
83
+ spec = spectral_normalize_torch(spec)
84
+
85
+ return spec
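A usage sketch with the defaults above (24 kHz audio, 80 mel bins, hop 480); the import path is assumed from the src/ layout:

    import torch
    from chatterbox.models.s3gen.utils.mel import mel_spectrogram

    wav = torch.rand(1, 24000) * 2 - 1   # one second of audio in [-1, 1]
    mel = mel_spectrogram(wav)
    assert mel.shape == (1, 80, 50)      # (batch, num_mels, frames), hop_size=480 -> 50 frames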
src/chatterbox/models/s3gen/xvector.py ADDED
@@ -0,0 +1,428 @@
1
+ #!/usr/bin/env python3
2
+ # -*- encoding: utf-8 -*-
3
+ # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
4
+ # MIT License (https://opensource.org/licenses/MIT)
5
+ # Modified from 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker)
6
+
7
+
8
+ from collections import OrderedDict
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import torch.utils.checkpoint as cp
12
+ import torchaudio.compliance.kaldi as Kaldi
13
+
14
+
15
+ def pad_list(xs, pad_value):
16
+ """Perform padding for the list of tensors.
17
+
18
+ Args:
19
+ xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
20
+ pad_value (float): Value for padding.
21
+
22
+ Returns:
23
+ Tensor: Padded tensor (B, Tmax, `*`).
24
+
25
+ Examples:
26
+ >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
27
+ >>> x
28
+ [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
29
+ >>> pad_list(x, 0)
30
+ tensor([[1., 1., 1., 1.],
31
+ [1., 1., 0., 0.],
32
+ [1., 0., 0., 0.]])
33
+
34
+ """
35
+ n_batch = len(xs)
36
+ max_len = max(x.size(0) for x in xs)
37
+ pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
38
+
39
+ for i in range(n_batch):
40
+ pad[i, : xs[i].size(0)] = xs[i]
41
+
42
+ return pad
43
+
44
+
45
+ def extract_feature(audio):
46
+ features = []
47
+ feature_times = []
48
+ feature_lengths = []
49
+ for au in audio:
50
+ feature = Kaldi.fbank(au.unsqueeze(0), num_mel_bins=80)
51
+ feature = feature - feature.mean(dim=0, keepdim=True)
52
+ features.append(feature)
53
+ feature_times.append(au.shape[0])
54
+ feature_lengths.append(feature.shape[0])
55
+ # padding for batch inference
56
+ features_padded = pad_list(features, pad_value=0)
57
+ # features = torch.cat(features)
58
+ return features_padded, feature_lengths, feature_times
59
+
60
+
61
+ class BasicResBlock(torch.nn.Module):
62
+ expansion = 1
63
+
64
+ def __init__(self, in_planes, planes, stride=1):
65
+ super(BasicResBlock, self).__init__()
66
+ self.conv1 = torch.nn.Conv2d(
67
+ in_planes, planes, kernel_size=3, stride=(stride, 1), padding=1, bias=False
68
+ )
69
+ self.bn1 = torch.nn.BatchNorm2d(planes)
70
+ self.conv2 = torch.nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
71
+ self.bn2 = torch.nn.BatchNorm2d(planes)
72
+
73
+ self.shortcut = torch.nn.Sequential()
74
+ if stride != 1 or in_planes != self.expansion * planes:
75
+ self.shortcut = torch.nn.Sequential(
76
+ torch.nn.Conv2d(
77
+ in_planes,
78
+ self.expansion * planes,
79
+ kernel_size=1,
80
+ stride=(stride, 1),
81
+ bias=False,
82
+ ),
83
+ torch.nn.BatchNorm2d(self.expansion * planes),
84
+ )
85
+
86
+ def forward(self, x):
87
+ out = F.relu(self.bn1(self.conv1(x)))
88
+ out = self.bn2(self.conv2(out))
89
+ out += self.shortcut(x)
90
+ out = F.relu(out)
91
+ return out
92
+
93
+
94
+ class FCM(torch.nn.Module):
95
+ def __init__(self, block=BasicResBlock, num_blocks=[2, 2], m_channels=32, feat_dim=80):
96
+ super(FCM, self).__init__()
97
+ self.in_planes = m_channels
98
+ self.conv1 = torch.nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
99
+ self.bn1 = torch.nn.BatchNorm2d(m_channels)
100
+
101
+ self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
102
+ self.layer2 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
103
+
104
+ self.conv2 = torch.nn.Conv2d(
105
+ m_channels, m_channels, kernel_size=3, stride=(2, 1), padding=1, bias=False
106
+ )
107
+ self.bn2 = torch.nn.BatchNorm2d(m_channels)
108
+ self.out_channels = m_channels * (feat_dim // 8)
109
+
110
+ def _make_layer(self, block, planes, num_blocks, stride):
111
+ strides = [stride] + [1] * (num_blocks - 1)
112
+ layers = []
113
+ for stride in strides:
114
+ layers.append(block(self.in_planes, planes, stride))
115
+ self.in_planes = planes * block.expansion
116
+ return torch.nn.Sequential(*layers)
117
+
118
+ def forward(self, x):
119
+ x = x.unsqueeze(1)
120
+ out = F.relu(self.bn1(self.conv1(x)))
121
+ out = self.layer1(out)
122
+ out = self.layer2(out)
123
+ out = F.relu(self.bn2(self.conv2(out)))
124
+
125
+ shape = out.shape
126
+ out = out.reshape(shape[0], shape[1] * shape[2], shape[3])
127
+ return out
128
+
129
+
130
+ def get_nonlinear(config_str, channels):
131
+ nonlinear = torch.nn.Sequential()
132
+ for name in config_str.split("-"):
133
+ if name == "relu":
134
+ nonlinear.add_module("relu", torch.nn.ReLU(inplace=True))
135
+ elif name == "prelu":
136
+ nonlinear.add_module("prelu", torch.nn.PReLU(channels))
137
+ elif name == "batchnorm":
138
+ nonlinear.add_module("batchnorm", torch.nn.BatchNorm1d(channels))
139
+ elif name == "batchnorm_":
140
+ nonlinear.add_module("batchnorm", torch.nn.BatchNorm1d(channels, affine=False))
141
+ else:
142
+ raise ValueError("Unexpected module ({}).".format(name))
143
+ return nonlinear
144
+
145
+
146
+ def statistics_pooling(x, dim=-1, keepdim=False, unbiased=True, eps=1e-2):
147
+ mean = x.mean(dim=dim)
148
+ std = x.std(dim=dim, unbiased=unbiased)
149
+ stats = torch.cat([mean, std], dim=-1)
150
+ if keepdim:
151
+ stats = stats.unsqueeze(dim=dim)
152
+ return stats
153
+
154
+
155
+ class StatsPool(torch.nn.Module):
156
+ def forward(self, x):
157
+ return statistics_pooling(x)
158
+
159
+
160
+ class TDNNLayer(torch.nn.Module):
161
+ def __init__(
162
+ self,
163
+ in_channels,
164
+ out_channels,
165
+ kernel_size,
166
+ stride=1,
167
+ padding=0,
168
+ dilation=1,
169
+ bias=False,
170
+ config_str="batchnorm-relu",
171
+ ):
172
+ super(TDNNLayer, self).__init__()
173
+ if padding < 0:
174
+ assert (
175
+ kernel_size % 2 == 1
176
+ ), "Expect equal paddings, but got even kernel size ({})".format(kernel_size)
177
+ padding = (kernel_size - 1) // 2 * dilation
178
+ self.linear = torch.nn.Conv1d(
179
+ in_channels,
180
+ out_channels,
181
+ kernel_size,
182
+ stride=stride,
183
+ padding=padding,
184
+ dilation=dilation,
185
+ bias=bias,
186
+ )
187
+ self.nonlinear = get_nonlinear(config_str, out_channels)
188
+
189
+ def forward(self, x):
190
+ x = self.linear(x)
191
+ x = self.nonlinear(x)
192
+ return x
193
+
194
+
195
+ class CAMLayer(torch.nn.Module):
196
+ def __init__(
197
+ self, bn_channels, out_channels, kernel_size, stride, padding, dilation, bias, reduction=2
198
+ ):
199
+ super(CAMLayer, self).__init__()
200
+ self.linear_local = torch.nn.Conv1d(
201
+ bn_channels,
202
+ out_channels,
203
+ kernel_size,
204
+ stride=stride,
205
+ padding=padding,
206
+ dilation=dilation,
207
+ bias=bias,
208
+ )
209
+ self.linear1 = torch.nn.Conv1d(bn_channels, bn_channels // reduction, 1)
210
+ self.relu = torch.nn.ReLU(inplace=True)
211
+ self.linear2 = torch.nn.Conv1d(bn_channels // reduction, out_channels, 1)
212
+ self.sigmoid = torch.nn.Sigmoid()
213
+
214
+ def forward(self, x):
215
+ y = self.linear_local(x)
216
+ context = x.mean(-1, keepdim=True) + self.seg_pooling(x)
217
+ context = self.relu(self.linear1(context))
218
+ m = self.sigmoid(self.linear2(context))
219
+ return y * m
220
+
221
+ def seg_pooling(self, x, seg_len=100, stype="avg"):
222
+ if stype == "avg":
223
+ seg = F.avg_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
224
+ elif stype == "max":
225
+ seg = F.max_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
226
+ else:
227
+ raise ValueError("Wrong segment pooling type.")
228
+ shape = seg.shape
229
+ seg = seg.unsqueeze(-1).expand(*shape, seg_len).reshape(*shape[:-1], -1)
230
+ seg = seg[..., : x.shape[-1]]
231
+ return seg
232
+
233
+
234
+ class CAMDenseTDNNLayer(torch.nn.Module):
235
+ def __init__(
236
+ self,
237
+ in_channels,
238
+ out_channels,
239
+ bn_channels,
240
+ kernel_size,
241
+ stride=1,
242
+ dilation=1,
243
+ bias=False,
244
+ config_str="batchnorm-relu",
245
+ memory_efficient=False,
246
+ ):
247
+ super(CAMDenseTDNNLayer, self).__init__()
248
+ assert kernel_size % 2 == 1, "Expect equal paddings, but got even kernel size ({})".format(
249
+ kernel_size
250
+ )
251
+ padding = (kernel_size - 1) // 2 * dilation
252
+ self.memory_efficient = memory_efficient
253
+ self.nonlinear1 = get_nonlinear(config_str, in_channels)
254
+ self.linear1 = torch.nn.Conv1d(in_channels, bn_channels, 1, bias=False)
255
+ self.nonlinear2 = get_nonlinear(config_str, bn_channels)
256
+ self.cam_layer = CAMLayer(
257
+ bn_channels,
258
+ out_channels,
259
+ kernel_size,
260
+ stride=stride,
261
+ padding=padding,
262
+ dilation=dilation,
263
+ bias=bias,
264
+ )
265
+
266
+ def bn_function(self, x):
267
+ return self.linear1(self.nonlinear1(x))
268
+
269
+ def forward(self, x):
270
+ if self.training and self.memory_efficient:
271
+ x = cp.checkpoint(self.bn_function, x)
272
+ else:
273
+ x = self.bn_function(x)
274
+ x = self.cam_layer(self.nonlinear2(x))
275
+ return x
276
+
277
+
278
+ class CAMDenseTDNNBlock(torch.nn.ModuleList):
279
+ def __init__(
280
+ self,
281
+ num_layers,
282
+ in_channels,
283
+ out_channels,
284
+ bn_channels,
285
+ kernel_size,
286
+ stride=1,
287
+ dilation=1,
288
+ bias=False,
289
+ config_str="batchnorm-relu",
290
+ memory_efficient=False,
291
+ ):
292
+ super(CAMDenseTDNNBlock, self).__init__()
293
+ for i in range(num_layers):
294
+ layer = CAMDenseTDNNLayer(
295
+ in_channels=in_channels + i * out_channels,
296
+ out_channels=out_channels,
297
+ bn_channels=bn_channels,
298
+ kernel_size=kernel_size,
299
+ stride=stride,
300
+ dilation=dilation,
301
+ bias=bias,
302
+ config_str=config_str,
303
+ memory_efficient=memory_efficient,
304
+ )
305
+ self.add_module("tdnnd%d" % (i + 1), layer)
306
+
307
+ def forward(self, x):
308
+ for layer in self:
309
+ x = torch.cat([x, layer(x)], dim=1)
310
+ return x
311
+
312
+
313
+ class TransitLayer(torch.nn.Module):
314
+ def __init__(self, in_channels, out_channels, bias=True, config_str="batchnorm-relu"):
315
+ super(TransitLayer, self).__init__()
316
+ self.nonlinear = get_nonlinear(config_str, in_channels)
317
+ self.linear = torch.nn.Conv1d(in_channels, out_channels, 1, bias=bias)
318
+
319
+ def forward(self, x):
320
+ x = self.nonlinear(x)
321
+ x = self.linear(x)
322
+ return x
323
+
324
+
325
+ class DenseLayer(torch.nn.Module):
326
+ def __init__(self, in_channels, out_channels, bias=False, config_str="batchnorm-relu"):
327
+ super(DenseLayer, self).__init__()
328
+ self.linear = torch.nn.Conv1d(in_channels, out_channels, 1, bias=bias)
329
+ self.nonlinear = get_nonlinear(config_str, out_channels)
330
+
331
+ def forward(self, x):
332
+ if len(x.shape) == 2:
333
+ x = self.linear(x.unsqueeze(dim=-1)).squeeze(dim=-1)
334
+ else:
335
+ x = self.linear(x)
336
+ x = self.nonlinear(x)
337
+ return x
338
+
339
+ # @tables.register("model_classes", "CAMPPlus")
340
+ class CAMPPlus(torch.nn.Module):
341
+ def __init__(
342
+ self,
343
+ feat_dim=80,
344
+ embedding_size=192,
345
+ growth_rate=32,
346
+ bn_size=4,
347
+ init_channels=128,
348
+ config_str="batchnorm-relu",
349
+ memory_efficient=True,
350
+ output_level="segment",
351
+ **kwargs,
352
+ ):
353
+ super().__init__()
354
+
355
+ self.head = FCM(feat_dim=feat_dim)
356
+ channels = self.head.out_channels
357
+ self.output_level = output_level
358
+
359
+ self.xvector = torch.nn.Sequential(
360
+ OrderedDict(
361
+ [
362
+ (
363
+ "tdnn",
364
+ TDNNLayer(
365
+ channels,
366
+ init_channels,
367
+ 5,
368
+ stride=2,
369
+ dilation=1,
370
+ padding=-1,
371
+ config_str=config_str,
372
+ ),
373
+ ),
374
+ ]
375
+ )
376
+ )
377
+ channels = init_channels
378
+ for i, (num_layers, kernel_size, dilation) in enumerate(
379
+ zip((12, 24, 16), (3, 3, 3), (1, 2, 2))
380
+ ):
381
+ block = CAMDenseTDNNBlock(
382
+ num_layers=num_layers,
383
+ in_channels=channels,
384
+ out_channels=growth_rate,
385
+ bn_channels=bn_size * growth_rate,
386
+ kernel_size=kernel_size,
387
+ dilation=dilation,
388
+ config_str=config_str,
389
+ memory_efficient=memory_efficient,
390
+ )
391
+ self.xvector.add_module("block%d" % (i + 1), block)
392
+ channels = channels + num_layers * growth_rate
393
+ self.xvector.add_module(
394
+ "transit%d" % (i + 1),
395
+ TransitLayer(channels, channels // 2, bias=False, config_str=config_str),
396
+ )
397
+ channels //= 2
398
+
399
+ self.xvector.add_module("out_nonlinear", get_nonlinear(config_str, channels))
400
+
401
+ if self.output_level == "segment":
402
+ self.xvector.add_module("stats", StatsPool())
403
+ self.xvector.add_module(
404
+ "dense", DenseLayer(channels * 2, embedding_size, config_str="batchnorm_")
405
+ )
406
+ else:
407
+ assert (
408
+ self.output_level == "frame"
409
+ ), "`output_level` should be set to 'segment' or 'frame'. "
410
+
411
+ for m in self.modules():
412
+ if isinstance(m, (torch.nn.Conv1d, torch.nn.Linear)):
413
+ torch.nn.init.kaiming_normal_(m.weight.data)
414
+ if m.bias is not None:
415
+ torch.nn.init.zeros_(m.bias)
416
+
417
+ def forward(self, x):
418
+ x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
419
+ x = self.head(x)
420
+ x = self.xvector(x)
421
+ if self.output_level == "frame":
422
+ x = x.transpose(1, 2)
423
+ return x
424
+
425
+ def inference(self, audio_list):
426
+ speech, speech_lengths, speech_times = extract_feature(audio_list)
427
+ results = self.forward(speech.to(torch.float32))
428
+ return results
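The CAMPPlus x-vector above runs mel features through the FCM head, three CAM-dense TDNN blocks, and (at "segment" level) a stats-pooling plus dense projection into a fixed-size speaker embedding. A minimal usage sketch, assuming the classes in this file are importable from the repo layout and using random features purely to illustrate shapes:

    import torch
    from chatterbox.models.s3gen.xvector import CAMPPlus  # import path assumed from the repo layout

    model = CAMPPlus(feat_dim=80, embedding_size=192, output_level="segment").eval()
    mels = torch.randn(2, 300, 80)      # (batch, frames, mel bins); forward() permutes to (B, F, T)
    with torch.no_grad():
        emb = model(mels)
    print(emb.shape)                    # expected: torch.Size([2, 192])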
src/chatterbox/models/s3tokenizer/__init__.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .s3tokenizer import (
2
+ S3_SR,
3
+ S3_HOP,
4
+ S3_TOKEN_HOP,
5
+ S3_TOKEN_RATE,
6
+ SPEECH_VOCAB_SIZE,
7
+ S3Tokenizer,
8
+ )
9
+
10
+
11
+ SOS = SPEECH_VOCAB_SIZE
12
+ EOS = SPEECH_VOCAB_SIZE + 1
13
+
14
+
15
+
16
+ def drop_invalid_tokens(x):
17
+ """Drop SoS and EoS"""
18
+ assert len(x.shape) == 1 or (len(x.shape) == 2 and x.shape[0] == 1), "only batch size of one allowed for now"
19
+ if SOS in x:
20
+ s = (x == SOS).nonzero(as_tuple=True)[0].squeeze(0) + 1
21
+ else:
22
+ s = 0
23
+
24
+ if EOS in x:
25
+ e = (x == EOS).nonzero(as_tuple=True)[0].squeeze(0)
26
+ else:
27
+ e = None
28
+
29
+ x = x[s: e]
30
+ return x
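Here SOS sits just past the 6561-entry speech vocabulary and EOS right after it, and `drop_invalid_tokens` slices both back out. A small illustrative check (the token values themselves are arbitrary):

    import torch

    # SOS == 6561, EOS == 6562 for SPEECH_VOCAB_SIZE == 6561
    seq = torch.tensor([6561, 12, 87, 4000, 6562])
    print(drop_invalid_tokens(seq))     # tensor([  12,   87, 4000])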
src/chatterbox/models/s3tokenizer/s3tokenizer.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple
2
+
3
+ import numpy as np
4
+ import librosa
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from s3tokenizer.utils import padding
8
+ from s3tokenizer.model_v2 import (
9
+ S3TokenizerV2,
10
+ ModelConfig,
11
+ )
12
+
13
+
14
+ # Sampling rate of the inputs to S3TokenizerV2
15
+ S3_SR = 16_000
16
+ S3_HOP = 160 # 100 frames/sec
17
+ S3_TOKEN_HOP = 640 # 25 tokens/sec
18
+ S3_TOKEN_RATE = 25
19
+ SPEECH_VOCAB_SIZE = 6561
20
+
21
+
22
+ class S3Tokenizer(S3TokenizerV2):
23
+ """
24
+ s3tokenizer.S3TokenizerV2 with the following changes:
25
+ - a more integrated `forward`
26
+ - compute `log_mel_spectrogram` using `_mel_filters` and `window` in `register_buffers`
27
+ """
28
+
29
+ ignore_state_dict_missing = ("_mel_filters", "window")
30
+
31
+ def __init__(
32
+ self,
33
+ name: str="speech_tokenizer_v2_25hz",
34
+ config: ModelConfig = ModelConfig()
35
+ ):
36
+ super().__init__(name)
37
+
38
+ self.n_fft = 400
39
+ _mel_filters = librosa.filters.mel(
40
+ sr=S3_SR,
41
+ n_fft=self.n_fft,
42
+ n_mels=config.n_mels
43
+ )
44
+ self.register_buffer(
45
+ "_mel_filters",
46
+ torch.FloatTensor(_mel_filters),
47
+ )
48
+
49
+ self.register_buffer(
50
+ "window",
51
+ torch.hann_window(self.n_fft),
52
+ )
53
+
54
+ def pad(self, wavs, sr) -> List[torch.Tensor]:
55
+ """
56
+ Given a list of wavs with the same `sample_rate`, pad them so that the length is a multiple of 40 ms (S3 runs at 25 tokens/sec).

57
+ """
58
+ processed_wavs = []
59
+ for wav in wavs:
60
+ if isinstance(wav, np.ndarray):
61
+ wav = torch.from_numpy(wav)
62
+ if wav.dim() == 1:
63
+ wav = wav.unsqueeze(0)
64
+
65
+ n_tokens = (wav.shape[1] / sr) * S3_TOKEN_RATE
66
+ n_tokens = np.ceil(n_tokens)
67
+ intended_wav_len = n_tokens * (sr / S3_TOKEN_RATE)
68
+ intended_wav_len = int(intended_wav_len)
69
+ wav = torch.nn.functional.pad(
70
+ wav,
71
+ (0, intended_wav_len - wav.shape[-1]),
72
+ mode="constant",
73
+ value=0
74
+ )
75
+ processed_wavs.append(wav)
76
+ return processed_wavs
77
+
78
+ def _prepare_audio(self, wavs):
79
+ """Prepare a list of audios for s3tokenizer processing."""
80
+ processed_wavs = []
81
+ for wav in wavs:
82
+ if isinstance(wav, np.ndarray):
83
+ wav = torch.from_numpy(wav)
84
+ if wav.dim() == 1:
85
+ wav = wav.unsqueeze(0)
86
+
87
+ processed_wavs.append(wav)
88
+ return processed_wavs
89
+
90
+ @torch.no_grad()
91
+ def forward(
92
+ self,
93
+ wavs: torch.Tensor,
94
+ accelerator: 'Accelerator'=None,
95
+ max_len: int=None,
96
+ ) -> Tuple[torch.Tensor, torch.LongTensor]:
97
+ """
98
+ NOTE: mel-spec has a hop size of 160 points (100 frame/sec).
99
+ FIXME: this class inherits `nn.Module` but doesn't accept `torch.Tensor` and handles a list of wavs one by one, which is unexpected.
100
+
101
+ Args
102
+ ----
103
+ - `wavs`: 16 kHz speech audio
104
+ - `max_len` max length to truncate the output sequence to (25 token/sec).
105
+ NOTE: please pad the waveform if a longer sequence is needed.
106
+ """
107
+ processed_wavs = self._prepare_audio(wavs)
108
+ mels, mel_lens = [], []
109
+ for wav in processed_wavs:
110
+ wav = wav.to(self.device)
111
+ mel = self.log_mel_spectrogram(wav) # [B=1, F, T]
112
+ if max_len is not None:
113
+ mel = mel[..., :max_len * 4] # num_mel_frames = 4 * num_tokens
114
+ mels.append(mel.squeeze(0))
115
+
116
+ mels, mel_lens = padding(mels)
117
+ if accelerator is None:
118
+ tokenizer = self
119
+ else:
120
+ tokenizer = accelerator.unwrap_model(self)
121
+
122
+ speech_tokens, speech_token_lens = tokenizer.quantize(mels, mel_lens.to(self.device))
123
+ return (
124
+ speech_tokens.long().detach(),
125
+ speech_token_lens.long().detach(),
126
+ )
127
+
128
+ def log_mel_spectrogram(
129
+ self,
130
+ audio: torch.Tensor,
131
+ padding: int = 0,
132
+ ):
133
+ """
134
+ Compute the log-Mel spectrogram of the given audio waveform.
135
+
136
+ Parameters
137
+ ----------
138
+ audio: torch.Tensor, shape = (*)
139
+ The path to audio or either a NumPy array or Tensor containing the
140
+ audio waveform in 16 kHz
141
+
142
+ padding: int
143
+ Number of zero samples to pad to the right
144
+
145
+ Returns
146
+ -------
147
+ torch.Tensor, shape = (128, n_frames)
148
+ A Tensor that contains the Mel spectrogram
149
+ """
150
+ if not torch.is_tensor(audio):
151
+ audio = torch.from_numpy(audio)
152
+
153
+ audio = audio.to(self.device)
154
+ if padding > 0:
155
+ audio = F.pad(audio, (0, padding))
156
+ stft = torch.stft(
157
+ audio, self.n_fft, S3_HOP,
158
+ window=self.window.to(self.device),
159
+ return_complex=True
160
+ )
161
+ magnitudes = stft[..., :-1].abs()**2
162
+
163
+ mel_spec = self._mel_filters.to(self.device) @ magnitudes
164
+
165
+ log_spec = torch.clamp(mel_spec, min=1e-10).log10()
166
+ log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
167
+ log_spec = (log_spec + 4.0) / 4.0
168
+ return log_spec
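The constants above imply 100 mel frames per second (hop 160 at 16 kHz) and 25 speech tokens per second, i.e. 4 mel frames per token, which is why `forward` truncates mels at `max_len * 4`. A small numeric sketch mirroring the padding arithmetic in `S3Tokenizer.pad`:

    import numpy as np

    sr, wav_len = 16_000, 52_000                      # 3.25 s of 16 kHz audio
    n_tokens = int(np.ceil(wav_len / sr * 25))        # 82 tokens
    padded_len = n_tokens * (sr // 25)                # 82 * 640 = 52_480 samples
    print(n_tokens, padded_len)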
src/chatterbox/models/t3/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .t3 import T3
src/chatterbox/models/t3/inference/alignment_stream_analyzer.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 Resemble AI
2
+ # Author: John Meade, Jeremy Hsu
3
+ # MIT License
4
+ import logging
5
+ import torch
6
+ from dataclasses import dataclass
7
+ from types import MethodType
8
+
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ LLAMA_ALIGNED_HEADS = [(12, 15), (13, 11), (9, 2)]
14
+
15
+
16
+ @dataclass
17
+ class AlignmentAnalysisResult:
18
+ # was this frame detected as being part of a noisy beginning chunk with potential hallucinations?
19
+ false_start: bool
20
+ # was this frame detected as being part of a long tail with potential hallucinations?
21
+ long_tail: bool
22
+ # was this frame detected as repeating existing text content?
23
+ repetition: bool
24
+ # was the alignment position of this frame too far from the previous frame?
25
+ discontinuity: bool
26
+ # has inference reached the end of the text tokens? eg, this remains false if inference stops early
27
+ complete: bool
28
+ # approximate position in the text token sequence. Can be used for generating online timestamps.
29
+ position: int
30
+
31
+
32
+ class AlignmentStreamAnalyzer:
33
+ def __init__(self, tfmr, queue, text_tokens_slice, alignment_layer_idx=9, eos_idx=0):
34
+ """
35
+ Some transformer TTS models implicitly solve text-speech alignment in one or more of their self-attention
36
+ activation maps. This module exploits this to perform online integrity checks which streaming.
37
+ A hook is injected into the specified attention layer, and heuristics are used to determine alignment
38
+ position, repetition, etc.
39
+
40
+ NOTE: currently requires no queues.
41
+ """
42
+ # self.queue = queue
43
+ self.text_tokens_slice = (i, j) = text_tokens_slice
44
+ self.eos_idx = eos_idx
45
+ self.alignment = torch.zeros(0, j-i)
46
+ # self.alignment_bin = torch.zeros(0, j-i)
47
+ self.curr_frame_pos = 0
48
+ self.text_position = 0
49
+
50
+ self.started = False
51
+ self.started_at = None
52
+
53
+ self.complete = False
54
+ self.completed_at = None
55
+
56
+ # Track generated tokens for repetition detection
57
+ self.generated_tokens = []
58
+
59
+ # Using `output_attentions=True` is incompatible with optimized attention kernels, so
60
+ # using it for all layers slows things down too much. We can apply it to just one layer
61
+ # by intercepting the kwargs and adding a forward hook (credit: jrm)
62
+ self.last_aligned_attns = []
63
+ for i, (layer_idx, head_idx) in enumerate(LLAMA_ALIGNED_HEADS):
64
+ self.last_aligned_attns += [None]
65
+ self._add_attention_spy(tfmr, i, layer_idx, head_idx)
66
+
67
+ def _add_attention_spy(self, tfmr, buffer_idx, layer_idx, head_idx):
68
+ """
69
+ Adds a forward hook to a specific attention layer to collect outputs.
70
+ """
71
+ def attention_forward_hook(module, input, output):
72
+ """
73
+ See `LlamaAttention.forward`; the output is a 3-tuple: `attn_output, attn_weights, past_key_value`.
74
+ NOTE:
75
+ - When `output_attentions=True`, `LlamaSdpaAttention.forward` calls `LlamaAttention.forward`.
76
+ - `attn_output` has shape [B, H, T0, T0] for the 0th entry, and [B, H, 1, T0+i] for the rest i-th.
77
+ """
78
+ if isinstance(output, tuple) and len(output) > 1 and output[1] is not None:
79
+ step_attention = output[1].cpu() # (B, n_heads, T0, Ti)
80
+ self.last_aligned_attns[buffer_idx] = step_attention[0, head_idx] # (T0, Ti)
81
+
82
+ target_layer = tfmr.layers[layer_idx].self_attn
83
+ # Register hook and store the handle
84
+ target_layer.register_forward_hook(attention_forward_hook)
85
+ if hasattr(tfmr, 'config') and hasattr(tfmr.config, 'output_attentions'):
86
+ self.original_output_attentions = tfmr.config.output_attentions
87
+ tfmr.config.output_attentions = True
88
+
89
+ def step(self, logits, next_token=None):
90
+ """
91
+ Emits an AlignmentAnalysisResult into the output queue, and potentially modifies the logits to force an EOS.
92
+ """
93
+ # extract approximate alignment matrix chunk (1 frame at a time after the first chunk)
94
+ aligned_attn = torch.stack(self.last_aligned_attns).mean(dim=0) # (N, N)
95
+ i, j = self.text_tokens_slice
96
+ if self.curr_frame_pos == 0:
97
+ # first chunk has conditioning info, text tokens, and BOS token
98
+ A_chunk = aligned_attn[j:, i:j].clone().cpu() # (T, S)
99
+ else:
100
+ # subsequent chunks have 1 frame due to KV-caching
101
+ A_chunk = aligned_attn[:, i:j].clone().cpu() # (1, S)
102
+
103
+ # TODO: monotonic masking; could have issue b/c spaces are often skipped.
104
+ A_chunk[:, self.curr_frame_pos + 1:] = 0
105
+
106
+
107
+ self.alignment = torch.cat((self.alignment, A_chunk), dim=0)
108
+
109
+ A = self.alignment
110
+ T, S = A.shape
111
+
112
+ # update position
113
+ cur_text_posn = A_chunk[-1].argmax()
114
+ discontinuity = not(-4 < cur_text_posn - self.text_position < 7) # NOTE: very lenient!
115
+ if not discontinuity:
116
+ self.text_position = cur_text_posn
117
+
118
+ # Hallucinations at the start of speech show up as activations at the bottom of the attention maps!
119
+ # To mitigate this, we just wait until there are no activations far off-diagonal in the last 2 tokens,
120
+ # and there are some strong activations in the first few tokens.
121
+ false_start = (not self.started) and (A[-2:, -2:].max() > 0.1 or A[:, :4].max() < 0.5)
122
+ self.started = not false_start
123
+ if self.started and self.started_at is None:
124
+ self.started_at = T
125
+
126
+ # Is generation likely complete?
127
+ self.complete = self.complete or self.text_position >= S - 3
128
+ if self.complete and self.completed_at is None:
129
+ self.completed_at = T
130
+
131
+ # NOTE: EOS rarely assigned activations, and second-last token is often punctuation, so use last 3 tokens.
132
+ # NOTE: due to the false-start behaviour, we need to make sure we skip activations for the first few tokens.
133
+ last_text_token_duration = A[15:, -3:].sum()
134
+
135
+ # Activations for the final token that last too long are likely hallucinations.
136
+ long_tail = self.complete and (A[self.completed_at:, -3:].sum(dim=0).max() >= 5) # 200ms
137
+
138
+ # If there are activations in previous tokens after generation has completed, assume this is a repetition error.
139
+ alignment_repetition = self.complete and (A[self.completed_at:, :-5].max(dim=1).values.sum() > 5)
140
+
141
+ # Track generated tokens for repetition detection
142
+ if next_token is not None:
143
+ # Convert tensor to scalar if needed
144
+ if isinstance(next_token, torch.Tensor):
145
+ token_id = next_token.item() if next_token.numel() == 1 else next_token.view(-1)[0].item()
146
+ else:
147
+ token_id = next_token
148
+ self.generated_tokens.append(token_id)
149
+
150
+ # Keep only last 8 tokens to prevent memory issues
151
+ if len(self.generated_tokens) > 8:
152
+ self.generated_tokens = self.generated_tokens[-8:]
153
+
154
+ # Check for excessive token repetition (3x same token in a row)
155
+ token_repetition = (
156
+ # self.complete and
157
+ len(self.generated_tokens) >= 3 and
158
+ len(set(self.generated_tokens[-3:])) == 1
159
+ )
160
+
161
+ if token_repetition:
162
+ repeated_token = self.generated_tokens[-1]
163
+ logger.warning(f"🚨 Detected 3x repetition of token {repeated_token}")
164
+
165
+ # Suppress EoS to prevent early termination
166
+ if cur_text_posn < S - 3 and S > 5: # Only suppress if text is longer than 5 tokens
167
+ logits[..., self.eos_idx] = -2**15
168
+
169
+ # If a bad ending is detected, force emit EOS by modifying logits
170
+ # NOTE: this means logits may be inconsistent with latents!
171
+ if long_tail or alignment_repetition or token_repetition:
172
+ logger.warning(f"forcing EOS token, {long_tail=}, {alignment_repetition=}, {token_repetition=}")
173
+ # (±2**15 is safe for all dtypes >= 16bit)
174
+ logits = -(2**15) * torch.ones_like(logits)
175
+ logits[..., self.eos_idx] = 2**15
176
+
177
+ self.curr_frame_pos += 1
178
+ return logits
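Forcing or suppressing EOS is done purely by editing the logits: ±2**15 saturates far beyond any realistic logit value, so the softmax collapses onto the EOS index. A minimal sketch of just that mechanism in isolation, with the speech vocabulary size assumed from this repo's T3Config:

    import torch

    eos_idx = 6562                        # stop_speech_token
    logits = torch.randn(1, 8194)         # (B=1, speech vocab incl. SOS/EOS)
    forced = -(2**15) * torch.ones_like(logits)
    forced[..., eos_idx] = 2**15
    probs = torch.softmax(forced, dim=-1)
    print(probs.argmax(-1).item())        # 6562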
src/chatterbox/models/t3/inference/t3_hf_backend.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import torch
4
+ from torch import nn as nn
5
+ from transformers import LlamaConfig, LlamaModel, LlamaPreTrainedModel, GenerationMixin
6
+ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
7
+
8
+
9
+ class T3HuggingfaceBackend(LlamaPreTrainedModel, GenerationMixin):
10
+ """
11
+ Override some HuggingFace interface methods so we can use the standard `generate` method with our
12
+ custom embedding / logit layers.
13
+
14
+ NOTE: need to extend "*PreTrainedModel" to avoid re-initializing weights!
15
+ """
16
+
17
+ def __init__(
18
+ self,
19
+ config: LlamaConfig,
20
+ llama: LlamaModel,
21
+ *,
22
+ speech_enc,
23
+ speech_head,
24
+ latents_queue=None,
25
+ logits_queue=None,
26
+ alignment_stream_analyzer: 'AlignmentStreamAnalyzer'=None,
27
+ ):
28
+ super().__init__(config)
29
+ self.model = llama
30
+ self.speech_enc = speech_enc
31
+ self.speech_head = speech_head
32
+ self._added_cond = False
33
+ self.alignment_stream_analyzer = alignment_stream_analyzer
34
+
35
+ @torch.inference_mode()
36
+ def prepare_inputs_for_generation(
37
+ self, input_ids: torch.Tensor, decoder_cond: torch.Tensor, use_cache: bool, past_key_values=None,
38
+ # This argument was introduced in some recent version of transformers (>=4.29.1)
39
+ cache_position=None
40
+ ):
41
+ """
42
+ This is a method used by huggingface's generate() method.
43
+ Overridden here to apply our custom speech token embedding layer.
44
+
45
+ :param input_ids: (B, S) int64 tensors of input tokens.
46
+ :param decoder_cond: (B, T, C) float32 tensor of conditioning (prefixed to <input_embeds>)
47
+ """
48
+
49
+ # Make use of the kv cache: only the last input ID is new, we trim away all the ones before
50
+ if not use_cache:
51
+ past_key_values = None
52
+ if past_key_values is not None:
53
+ input_ids = input_ids[:, -1:]
54
+
55
+ # custom speech token embedding layer
56
+ inputs_embeds = self.speech_enc(input_ids)
57
+
58
+ # prefix decoder conditioning if applicable
59
+ if not self._added_cond:
60
+ assert past_key_values is not None # should be first step
61
+ if decoder_cond.size(0) != inputs_embeds.size(0):
62
+ decoder_cond = decoder_cond.expand(inputs_embeds.size(0), -1, -1)
63
+ inputs_embeds = torch.cat([decoder_cond, inputs_embeds], dim=1)
64
+ self._added_cond = True
65
+
66
+ return {
67
+ "inputs_embeds": inputs_embeds,
68
+ "past_key_values": past_key_values,
69
+ "use_cache": use_cache,
70
+ }
71
+
72
+ @torch.inference_mode()
73
+ def forward(
74
+ self,
75
+ inputs_embeds: torch.Tensor,
76
+ past_key_values: Optional[torch.Tensor]=None,
77
+ use_cache=True,
78
+ output_attentions=False,
79
+ output_hidden_states=True,
80
+ return_dict=True,
81
+ ):
82
+ """
83
+ This is a method used by huggingface's generate() method.
84
+ Overridden here to apply our custom layer norm and speech logit projection layers.
85
+
86
+ :param inputs_embeds: (B, S, C) float32 tensor of conditioning inputs. If past key values are given,
87
+ S should be 1.
88
+ """
89
+ is_large_input = inputs_embeds.size(1) != 1
90
+ has_cache = past_key_values is not None and len(past_key_values) > 0
91
+ assert not (is_large_input and has_cache)
92
+ assert return_dict
93
+ assert output_hidden_states
94
+
95
+ tfmr_out = self.model(
96
+ inputs_embeds=inputs_embeds,
97
+ past_key_values=past_key_values,
98
+ use_cache=use_cache,
99
+ output_attentions=output_attentions,
100
+ output_hidden_states=output_hidden_states,
101
+ return_dict=True,
102
+ )
103
+ hidden_states = tfmr_out.hidden_states[-1] # (B, seq, dim)
104
+
105
+ logits = self.speech_head(hidden_states)
106
+ # assert inputs_embeds.size(0) == 1 # (disabled for CFG)
107
+
108
+ # NOTE: hallucination handler may modify logits to force emit an EOS token
109
+ # logits = self.alignment_stream_analyzer.step(logits)
110
+
111
+ return CausalLMOutputWithCrossAttentions(
112
+ logits=logits,
113
+ past_key_values=tfmr_out.past_key_values,
114
+ hidden_states=tfmr_out.hidden_states,
115
+ attentions=tfmr_out.attentions,
116
+ )
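On the first step the conditioning prefix is concatenated in front of the token embeddings; on later steps only the single newest token is embedded and the KV cache supplies the rest. A toy shape check of that first-step concatenation (random tensors, dimensions chosen to match the 1024-channel Llama backbone):

    import torch

    dim = 1024
    decoder_cond = torch.randn(2, 34, dim)   # conditioning prefix (speaker + prompt + emotion)
    step_embeds = torch.randn(2, 1, dim)     # embedding of the newest speech token
    first_step_inputs = torch.cat([decoder_cond, step_embeds], dim=1)
    print(first_step_inputs.shape)           # torch.Size([2, 35, 1024])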
src/chatterbox/models/t3/llama_configs.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ LLAMA_520M_CONFIG_DICT = dict(
2
+ # Arbitrary small number that won't cause problems when loading.
3
+ # These param are unused due to custom input layers.
4
+ vocab_size=8,
5
+ # default params needed for loading most pretrained 1B weights
6
+ max_position_embeddings=131072,
7
+ hidden_size=1024,
8
+ intermediate_size=4096,
9
+ num_hidden_layers=30,
10
+ num_attention_heads=16,
11
+ attn_implementation="sdpa",
12
+ head_dim=64,
13
+ tie_word_embeddings=False,
14
+ hidden_act="silu",
15
+ attention_bias=False,
16
+ attention_dropout=0.0,
17
+ initializer_range=0.02,
18
+ mlp_bias=False,
19
+ model_type="llama",
20
+ num_key_value_heads=16,
21
+ pretraining_tp=1,
22
+ rms_norm_eps=1e-05,
23
+ rope_scaling=dict(
24
+ factor=8.0,
25
+ high_freq_factor=4.0,
26
+ low_freq_factor=1.0,
27
+ original_max_position_embeddings=8192,
28
+ rope_type="llama3"
29
+ ),
30
+ rope_theta=500000.0,
31
+ torch_dtype="bfloat16",
32
+ use_cache=True,
33
+ )
34
+
35
+ LLAMA_CONFIGS = {
36
+ "Llama_520M": LLAMA_520M_CONFIG_DICT,
37
+ }
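The dict is plain keyword arguments for `transformers.LlamaConfig`, which is how `T3.__init__` consumes it. A quick sanity check (import path assumed from the repo layout):

    from transformers import LlamaConfig
    from chatterbox.models.t3.llama_configs import LLAMA_CONFIGS

    cfg = LlamaConfig(**LLAMA_CONFIGS["Llama_520M"])
    print(cfg.hidden_size, cfg.num_hidden_layers)   # 1024 30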
src/chatterbox/models/t3/modules/cond_enc.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ import torch
5
+ from torch import nn, Tensor
6
+
7
+ from .perceiver import Perceiver
8
+ from .t3_config import T3Config
9
+
10
+
11
+ @dataclass
12
+ class T3Cond:
13
+ """
14
+ Dataclass container for most / all conditioning info.
15
+ TODO: serialization methods aren't used, keeping them around for convenience
16
+ """
17
+
18
+ speaker_emb: Tensor
19
+ clap_emb: Optional[Tensor] = None
20
+ cond_prompt_speech_tokens: Optional[Tensor] = None
21
+ cond_prompt_speech_emb: Optional[Tensor] = None
22
+ emotion_adv: Optional[Tensor] = 0.5
23
+
24
+ def to(self, *, device=None, dtype=None):
25
+ "Cast to a device and dtype. Dtype casting is ignored for long/int tensors."
26
+ for k, v in self.__dict__.items():
27
+ if torch.is_tensor(v):
28
+ is_fp = type(v.view(-1)[0].item()) is not int
29
+ setattr(self, k, v.to(device=device, dtype=dtype if is_fp else None))
30
+ return self
31
+
32
+ def save(self, fpath):
33
+ torch.save(self.__dict__, fpath)
34
+
35
+ @staticmethod
36
+ def load(fpath, map_location="cpu"):
37
+ kwargs = torch.load(fpath, map_location=map_location, weights_only=True)
38
+ return T3Cond(**kwargs)
39
+
40
+
41
+ class T3CondEnc(nn.Module):
42
+ """
43
+ Handle all non-text conditioning, like speaker embeddings / prompts, CLAP, emotion, etc.
44
+ """
45
+
46
+ def __init__(self, hp: T3Config):
47
+ super().__init__()
48
+ self.hp = hp
49
+ if hp.encoder_type == "voice_encoder":
50
+ self.spkr_enc = nn.Linear(hp.speaker_embed_size, hp.n_channels)
51
+ else:
52
+ raise NotImplementedError(str(hp.encoder_type))
53
+
54
+ # emotion adv
55
+ self.emotion_adv_fc = None
56
+ if hp.emotion_adv:
57
+ self.emotion_adv_fc = nn.Linear(1, hp.n_channels, bias=False)
58
+
59
+ # perceiver resampler
60
+ self.perceiver = None
61
+ if hp.use_perceiver_resampler:
62
+ self.perceiver = Perceiver()
63
+
64
+ def forward(self, cond: T3Cond):
65
+ # Validate
66
+ assert (cond.cond_prompt_speech_tokens is None) == (cond.cond_prompt_speech_emb is None), \
67
+ "no embeddings for cond_prompt_speech_tokens"
68
+
69
+ # Speaker embedding projection
70
+ cond_spkr = self.spkr_enc(cond.speaker_emb.view(-1, self.hp.speaker_embed_size))[:, None] # (B, 1, dim)
71
+ empty = torch.zeros_like(cond_spkr[:, :0]) # (B, 0, dim)
72
+
73
+ # TODO CLAP
74
+ assert cond.clap_emb is None, "clap_embed not implemented"
75
+ cond_clap = empty # (B, 0, dim)
76
+
77
+ # Cond prompt
78
+ cond_prompt_speech_emb = cond.cond_prompt_speech_emb
79
+ if cond_prompt_speech_emb is None:
80
+ cond_prompt_speech_emb = empty # (B, 0, dim)
81
+ elif self.hp.use_perceiver_resampler:
82
+ cond_prompt_speech_emb = self.perceiver(cond_prompt_speech_emb)
83
+
84
+ # Emotion Adv: must provide a value if this model uses emotion conditioning
85
+ cond_emotion_adv = empty # (B, 0, dim)
86
+ if self.hp.emotion_adv:
87
+ assert cond.emotion_adv is not None
88
+ cond_emotion_adv = self.emotion_adv_fc(cond.emotion_adv.view(-1, 1, 1))
89
+
90
+ # Concat and return
91
+ cond_embeds = torch.cat((
92
+ cond_spkr,
93
+ cond_clap,
94
+ cond_prompt_speech_emb,
95
+ cond_emotion_adv,
96
+ ), dim=1)
97
+ return cond_embeds
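With the default T3Config (perceiver resampler and emotion conditioning enabled), the concatenation is 1 speaker frame + 32 resampled prompt frames + 1 emotion frame = 34 frames. A hedged sketch with random tensors standing in for real embeddings; the classes above are assumed in scope, and the dummy prompt tokens only satisfy the consistency assert:

    import torch

    hp = T3Config()
    enc = T3CondEnc(hp).eval()
    cond = T3Cond(
        speaker_emb=torch.randn(2, hp.speaker_embed_size),
        cond_prompt_speech_tokens=torch.zeros(2, 150, dtype=torch.long),
        cond_prompt_speech_emb=torch.randn(2, 150, hp.n_channels),
        emotion_adv=torch.full((2, 1, 1), 0.5),
    )
    with torch.no_grad():
        print(enc(cond).shape)   # torch.Size([2, 34, 1024])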
src/chatterbox/models/t3/modules/learned_pos_emb.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union
2
+
3
+ import torch
4
+ from torch import nn, Tensor
5
+
6
+
7
+ class LearnedPositionEmbeddings(nn.Module):
8
+ def __init__(self, seq_len, model_dim, init=.02):
9
+ super().__init__()
10
+ self.emb = nn.Embedding(seq_len, model_dim)
11
+ # Initializing this way is standard for GPT-2
12
+ self.emb.weight.data.normal_(mean=0.0, std=init)
13
+
14
+ def forward(self, x):
15
+ """
16
+ Returns positional embeddings for index 0 up to the length of x
17
+ """
18
+ sl = x.shape[1]
19
+ return self.emb(torch.arange(0, sl, device=x.device))
20
+
21
+ def get_fixed_embedding(self, idx: 'Union[int, Tensor]'):
22
+ """
23
+ Args:
24
+ idx: scalar int or an integer tensor of shape (T,) or (B, T)
25
+ Returns:
26
+ positional embeddings for given indices, shape (B, T, dim), ie (1, 1, dim) for int input
27
+ """
28
+ device = self.emb.weight.device
29
+ idx = idx.to(device) if torch.is_tensor(idx) else torch.tensor(idx, device=device)
30
+ idx = torch.atleast_2d(idx)
31
+ assert idx.ndim == 2
32
+ return self.emb(idx) # (B, T, dim)
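`forward` embeds positions 0..T-1 for a whole sequence (broadcast over the batch when added), while `get_fixed_embedding` looks up one position at a time during incremental decoding. A tiny shape check:

    import torch

    pos_emb = LearnedPositionEmbeddings(seq_len=4096, model_dim=1024)
    tokens = torch.zeros(2, 7, dtype=torch.long)       # only the length matters
    print(pos_emb(tokens).shape)                       # torch.Size([7, 1024])
    print(pos_emb.get_fixed_embedding(3).shape)        # torch.Size([1, 1, 1024])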
src/chatterbox/models/t3/modules/perceiver.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 Resemble AI
2
+ # Author: Manmay Nakhashi
3
+ # MIT License
4
+ import math
5
+
6
+ import torch
7
+ from torch import nn
8
+ import torch.nn.functional as F
9
+ from einops import rearrange
10
+
11
+
12
+ class RelativePositionBias(nn.Module):
13
+ def __init__(self, scale, causal=False, num_buckets=32, max_distance=128, heads=8):
14
+ super().__init__()
15
+ self.scale = scale
16
+ self.causal = causal
17
+ self.num_buckets = num_buckets
18
+ self.max_distance = max_distance
19
+ self.relative_attention_bias = nn.Embedding(num_buckets, heads)
20
+
21
+ @staticmethod
22
+ def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128):
23
+ ret = 0
24
+ n = -relative_position
25
+ if not causal:
26
+ num_buckets //= 2
27
+ ret += (n < 0).long() * num_buckets
28
+ n = torch.abs(n)
29
+ else:
30
+ n = torch.max(n, torch.zeros_like(n))
31
+
32
+ max_exact = num_buckets // 2
33
+ is_small = n < max_exact
34
+
35
+ val_if_large = max_exact + (
36
+ torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
37
+ ).long()
38
+ val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
39
+
40
+ ret += torch.where(is_small, n, val_if_large)
41
+ return ret
42
+
43
+ def forward(self, qk_dots):
44
+ i, j, device = *qk_dots.shape[-2:], qk_dots.device
45
+ q_pos = torch.arange(i, dtype=torch.long, device=device)
46
+ k_pos = torch.arange(j, dtype=torch.long, device=device)
47
+ rel_pos = k_pos[None, :] - q_pos[:, None]
48
+ rp_bucket = self._relative_position_bucket(rel_pos, causal=self.causal, num_buckets=self.num_buckets,
49
+ max_distance=self.max_distance)
50
+ values = self.relative_attention_bias(rp_bucket)
51
+ bias = rearrange(values, 'i j h -> () h i j')
52
+ return qk_dots + (bias * self.scale)
53
+
54
+
55
+ class AttentionQKV(nn.Module):
56
+ def __init__(self, n_heads, head_dim, dropout_rate=0.1, scale=None, flash=False):
57
+ super().__init__()
58
+ self.n_heads = n_heads
59
+ self.head_dim = head_dim
60
+ self.scale = scale if scale is not None else head_dim ** -0.5
61
+ self.flash = flash
62
+ self.dropout_rate = dropout_rate
63
+ self.dropout = nn.Dropout(dropout_rate)
64
+ self.flash_config = self.setup_flash_config() if flash else None
65
+
66
+ def setup_flash_config(self):
67
+ # Setup flash attention configuration
68
+ flash_config = {
69
+ 'enable_flash': True,
70
+ 'enable_math': True,
71
+ 'enable_mem_efficient': True
72
+ }
73
+ return flash_config
74
+
75
+ def forward(self, q, k, v, mask=None):
76
+ q, k, v = [self.split_heads(tensor) for tensor in [q, k, v]]
77
+ if self.flash:
78
+ out = self.flash_attention(q, k, v, mask=mask)
79
+ else:
80
+ out = self.scaled_dot_product_attention(q, k, v, mask=mask)
81
+
82
+ return self.combine_heads(out)
83
+
84
+ def scaled_dot_product_attention(self, q, k, v, mask=None):
85
+ sim = torch.einsum("bhlt,bhls->bhts", q, k) * self.scale
86
+ if mask is not None:
87
+ sim = sim.masked_fill(mask == 0, float('-inf'))
88
+ attn = torch.softmax(sim, dim=-1)
89
+ attn = self.dropout(attn)
90
+ return torch.einsum("bhts,bhls->bhlt", attn, v)
91
+
92
+ def flash_attention(self, q, k, v, mask=None):
93
+ config = self.flash_config if self.flash_config else {}
94
+ with torch.backends.cuda.sdp_kernel(**config):
95
+ out = F.scaled_dot_product_attention(
96
+ q, k, v,
97
+ attn_mask=mask,
98
+ dropout_p=self.dropout_rate if self.training else 0.
99
+ )
100
+ return out
101
+
102
+ def split_heads(self, x):
103
+ bs, length, _ = x.shape
104
+ x = x.view(bs, length, self.n_heads, self.head_dim)
105
+ return x.permute(0, 2, 1, 3)
106
+
107
+ def combine_heads(self, x):
108
+ bs, _, length, _ = x.shape
109
+ x = x.permute(0, 2, 1, 3).contiguous()
110
+ return x.view(bs, length, -1)
111
+
112
+
113
+ class AttentionBlock2(nn.Module):
114
+ """
115
+ An attention block that allows spatial positions to attend to each other,
116
+ using AttentionQKV and separate linear transformations for Q, K, and V.
117
+ """
118
+
119
+ def __init__(
120
+ self,
121
+ channels,
122
+ num_heads=1,
123
+ num_head_channels=-1,
124
+ relative_pos_embeddings=False,
125
+ flash_attention=True,
126
+ dropout_rate=0.2,
127
+ scale=None
128
+ ):
129
+ super().__init__()
130
+ self.channels = channels
131
+
132
+ if num_head_channels == -1:
133
+ self.num_heads = num_heads
134
+ else:
135
+ assert (
136
+ channels % num_head_channels == 0
137
+ ), f"channels {channels} is not divisible by num_head_channels {num_head_channels}"
138
+ self.num_heads = channels // num_head_channels
139
+
140
+ self.norm = nn.LayerNorm(channels)
141
+
142
+ # Separate linear layers for Q, K, and V
143
+ self.to_q = nn.Linear(channels, channels)
144
+ self.to_k = nn.Linear(channels, channels)
145
+ self.to_v = nn.Linear(channels, channels)
146
+
147
+ self.attention = AttentionQKV(self.num_heads, channels // self.num_heads, dropout_rate=dropout_rate, flash=flash_attention, scale=scale)
148
+
149
+ self.proj_out = nn.Linear(channels, channels)
150
+
151
+ if relative_pos_embeddings:
152
+ self.relative_pos_embeddings = RelativePositionBias(scale=(channels // self.num_heads) ** .5, causal=False, heads=num_heads, num_buckets=32, max_distance=64)
153
+ else:
154
+ self.relative_pos_embeddings = None
155
+
156
+ def forward(self, x1, x2, mask=None):
157
+ b1, c1, *spatial1 = x1.shape
158
+ b2, c2, *spatial2 = x2.shape
159
+
160
+ x1_norm = self.norm(x1)
161
+ x2_norm = self.norm(x2)
162
+
163
+ q = self.to_q(x1_norm)
164
+ k = self.to_k(x2_norm)
165
+ v = self.to_v(x2_norm)
166
+
167
+ h = self.attention(q, k, v, mask=mask)
168
+ h = self.proj_out(h)
169
+
170
+ return (x1 + h).reshape(b1, c1, *spatial1)
171
+
172
+
173
+ class Perceiver(nn.Module):
174
+ """Inspired by https://arxiv.org/abs/2103.03206"""
175
+ def __init__(self, pre_attention_query_token=32, pre_attention_query_size=1024, embedding_dim=1024, num_attn_heads=4):
176
+ """
177
+ Initialize the perceiver module.
178
+
179
+ :param pre_attention_query_token: Number of query tokens for pre-attention
180
+ :param pre_attention_query_size: Size of each query token
181
+ :param embedding_dim: Dimension of the embedding space
182
+ :param num_attn_heads: Number of attention heads
183
+ """
184
+ super().__init__()
185
+
186
+ # Initialize the pre-attention query parameter
187
+ self.pre_attention_query = torch.nn.Parameter(
188
+ torch.empty(1, pre_attention_query_token, pre_attention_query_size)
189
+ )
190
+
191
+ # Calculate the variance for uniform initialization
192
+ query_variance = math.sqrt(3.0) * math.sqrt(2.0 / (pre_attention_query_token + pre_attention_query_token))
193
+
194
+ # Initialize the pre-attention query with uniform distribution
195
+ self.pre_attention_query.data.uniform_(-query_variance, query_variance)
196
+
197
+ # Initialize the attention block
198
+ self.attn = AttentionBlock2(embedding_dim, num_attn_heads)
199
+
200
+ def forward(self, h):
201
+ """
202
+ Forward pass of the perceiver module.
203
+ :param h: Input tensor
204
+ :return: Output after applying attention mechanisms
205
+ """
206
+ # Expand the pre-attention query to match the batch size of the input
207
+ query_ = self.pre_attention_query.expand(h.shape[0], -1, -1)
208
+ # Apply the first attention mechanism (cross-attention)
209
+ pre_att = self.attn(query_, h)
210
+ # Apply the second attention mechanism (self-attention)
211
+ attn = self.attn(pre_att, pre_att)
212
+ return attn
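The perceiver cross-attends 32 learned queries over the variable-length prompt embedding and then self-attends once more, so any prompt length collapses to a fixed 32 frames. A quick shape check with untrained weights:

    import torch

    perceiver = Perceiver()            # 32 query tokens, 1024-dim embedding by default
    h = torch.randn(3, 150, 1024)      # e.g. 150 prompt-token embeddings
    with torch.no_grad():
        print(perceiver(h).shape)      # torch.Size([3, 32, 1024])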
src/chatterbox/models/t3/modules/t3_config.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ..llama_configs import LLAMA_CONFIGS
2
+
3
+
4
+ class T3Config:
5
+ def __init__(self, text_tokens_dict_size=704):
6
+ self.start_text_token = 255
7
+ self.stop_text_token = 0
8
+ self.text_tokens_dict_size = text_tokens_dict_size
9
+ self.max_text_tokens = 2048
10
+
11
+ self.start_speech_token = 6561
12
+ self.stop_speech_token = 6562
13
+ self.speech_tokens_dict_size = 8194
14
+ self.max_speech_tokens = 4096
15
+
16
+ self.llama_config_name = "Llama_520M"
17
+ self.input_pos_emb = "learned"
18
+ self.speech_cond_prompt_len = 150
19
+
20
+ self.encoder_type = "voice_encoder"
21
+ self.speaker_embed_size = 256
22
+ self.use_perceiver_resampler = True
23
+ self.emotion_adv = True
24
+
25
+ @property
26
+ def n_channels(self):
27
+ return LLAMA_CONFIGS[self.llama_config_name]["hidden_size"]
28
+
29
+ @classmethod
30
+ def english_only(cls):
31
+ """Create configuration for English-only TTS model."""
32
+ return cls(text_tokens_dict_size=704)
33
+
34
+ @classmethod
35
+ def multilingual(cls):
36
+ """Create configuration for multilingual TTS model."""
37
+ return cls(text_tokens_dict_size=2352)
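The two presets differ only in the text vocabulary size; everything else, including the 1024-channel hidden size pulled from the Llama config, is shared:

    en = T3Config.english_only()
    ml = T3Config.multilingual()
    print(en.text_tokens_dict_size, ml.text_tokens_dict_size, ml.n_channels)   # 704 2352 1024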
src/chatterbox/models/t3/t3.py ADDED
@@ -0,0 +1,391 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 Resemble AI
2
+ # MIT License
3
+ import logging
4
+ from typing import Union, Optional, List
5
+
6
+ logger = logging.getLogger(__name__)
7
+
8
+ from tqdm import tqdm
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch import nn, Tensor
12
+ from transformers import LlamaModel, LlamaConfig
13
+ from transformers.generation.logits_process import TopPLogitsWarper, RepetitionPenaltyLogitsProcessor, MinPLogitsWarper
14
+
15
+ from .modules.learned_pos_emb import LearnedPositionEmbeddings
16
+
17
+ from .modules.cond_enc import T3CondEnc, T3Cond
18
+ from .modules.t3_config import T3Config
19
+ from .llama_configs import LLAMA_CONFIGS
20
+ from .inference.t3_hf_backend import T3HuggingfaceBackend
21
+ from .inference.alignment_stream_analyzer import AlignmentStreamAnalyzer
22
+ from ..utils import AttrDict
23
+
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ def _ensure_BOT_EOT(text_tokens: Tensor, hp):
29
+ B = text_tokens.size(0)
30
+ assert (text_tokens == hp.start_text_token).int().sum() >= B, "missing start_text_token"
31
+ assert (text_tokens == hp.stop_text_token).int().sum() >= B, "missing stop_text_token"
32
+
33
+
34
+ class T3(nn.Module):
35
+ """
36
+ Token-To-Token (T3) TTS model using huggingface transformer models as backbones,
37
+ * tokenization, including start / stop tokens are always added externally to this class
38
+ * conditioning data like CLAP, emotion, etc are all in a separate file for more modularity
39
+ * careful! this class assumes relative positional encoding -- with absolute PE, we would at
40
+ least want to reset the position to 0 when speech tokens begin, and optionally use a
41
+ different PE embedding space for speech.
42
+ """
43
+
44
+ def __init__(self, hp=None):
45
+ if hp is None:
46
+ hp = T3Config.english_only() # Default to English-only config for backward compatibility
47
+ super().__init__()
48
+ self.hp = hp
49
+ self.cfg = LlamaConfig(**LLAMA_CONFIGS[hp.llama_config_name])
50
+ self.tfmr = LlamaModel(self.cfg)
51
+ self.dim = self.cfg.hidden_size
52
+ self.deepspeed_patch_applied = False
53
+
54
+ # conditioning / embedding
55
+ self.cond_enc = T3CondEnc(hp)
56
+ self.text_emb = nn.Embedding(hp.text_tokens_dict_size, self.dim)
57
+ self.speech_emb = nn.Embedding(hp.speech_tokens_dict_size, self.dim)
58
+
59
+ # custom position embedding
60
+ if hp.input_pos_emb == "learned":
61
+ max_text_seq_len = hp.max_text_tokens + 2
62
+ self.text_pos_emb = LearnedPositionEmbeddings(max_text_seq_len, self.dim)
63
+
64
+ max_mel_seq_len = hp.max_speech_tokens + 2 + 2
65
+ self.speech_pos_emb = LearnedPositionEmbeddings(max_mel_seq_len, self.dim)
66
+
67
+ # logit projection
68
+ self.text_head = nn.Linear(self.cfg.hidden_size, hp.text_tokens_dict_size, bias=False)
69
+ self.speech_head = nn.Linear(self.cfg.hidden_size, hp.speech_tokens_dict_size, bias=False)
70
+ self.compiled = False
71
+
72
+ @property
73
+ def device(self):
74
+ return self.speech_head.weight.device
75
+
76
+ def prepare_conditioning(self, t3_cond: T3Cond):
77
+ """
78
+ Token cond data needs to be embedded, so that needs to be here instead of in `T3CondEnc`.
79
+ """
80
+ if t3_cond.cond_prompt_speech_tokens is not None and t3_cond.cond_prompt_speech_emb is None:
81
+ t3_cond.cond_prompt_speech_emb = self.speech_emb(t3_cond.cond_prompt_speech_tokens) + \
82
+ self.speech_pos_emb(t3_cond.cond_prompt_speech_tokens)
83
+ return self.cond_enc(t3_cond) # (B, len_cond, dim)
84
+
85
+ def prepare_input_embeds(
86
+ self,
87
+ *,
88
+ t3_cond: T3Cond,
89
+ text_tokens: torch.LongTensor,
90
+ speech_tokens: torch.LongTensor,
91
+ cfg_weight: float = 0.0,
92
+ ):
93
+ # prepare input embeddings (skip backbone transformer embeddings)
94
+ cond_emb = self.prepare_conditioning(t3_cond) # (B, len_cond, dim)
95
+ text_emb = self.text_emb(text_tokens) # (B, len_text, dim)
96
+ if cfg_weight > 0.0:
97
+ text_emb[1].zero_() # CFG uncond
98
+
99
+ speech_emb = self.speech_emb(speech_tokens) # (B, len_speech, dim)
100
+ if self.hp.input_pos_emb == "learned":
101
+ text_emb = text_emb + self.text_pos_emb(text_tokens)
102
+ speech_emb = speech_emb + self.speech_pos_emb(speech_tokens)
103
+ len_cond = cond_emb.size(1)
104
+
105
+ if cond_emb.size(0) != text_emb.size(0):
106
+ cond_emb = cond_emb.expand(text_emb.size(0), -1, -1)
107
+
108
+ # concat
109
+ embeds = torch.stack([
110
+ torch.cat((ce, te, se))
111
+ for ce, te, se in zip(cond_emb, text_emb, speech_emb)
112
+ ]) # (B, length, dim)
113
+ return embeds, len_cond
114
+
115
+ def forward(
116
+ self,
117
+ *,
118
+ t3_cond: T3Cond,
119
+ text_tokens: torch.LongTensor,
120
+ text_token_lens: torch.LongTensor,
121
+ speech_tokens: torch.LongTensor,
122
+ speech_token_lens: torch.LongTensor,
123
+ training=False,
124
+ ):
125
+ _ensure_BOT_EOT(text_tokens, self.hp)
126
+
127
+ # prepare custom input embeds
128
+ embeds, len_cond = self.prepare_input_embeds(
129
+ t3_cond=t3_cond,
130
+ text_tokens=text_tokens,
131
+ speech_tokens=speech_tokens,
132
+ )
133
+
134
+ # backbone transformer forward
135
+ tfmr_out = self.tfmr.forward(
136
+ input_ids=None,
137
+ # position_ids=position_ids, # TODO? ROPE should be fine?
138
+ inputs_embeds=embeds,
139
+ output_hidden_states=True,
140
+ return_dict=True,
141
+ use_cache=(not training),
142
+ )
143
+ hidden_states = tfmr_out.hidden_states[-1] # final tfmr layer output, (B, seq, dim)
144
+
145
+ # post-processing: splice out text and speech parts of hidden states
146
+ len_text = text_tokens.size(1)
147
+ len_speech = speech_tokens.size(1)
148
+ B, _, dim = hidden_states.shape
149
+ device, dtype = hidden_states.device, hidden_states.dtype
150
+ text_latents = torch.zeros(B, len_text, dim, dtype=dtype, device=device)
151
+ speech_latents = torch.zeros(B, len_speech, dim, dtype=dtype, device=device)
152
+ ttl, stl = text_token_lens, speech_token_lens
153
+ for i in range(B):
154
+ text_end = len_cond + ttl[i].item()
155
+ speech_start = len_cond + text_tokens.size(1)
156
+ speech_end = speech_start + stl[i].item()
157
+ text_latents[i, :ttl[i]] = hidden_states[i, len_cond:text_end]
158
+ speech_latents[i, :stl[i]] = hidden_states[i, speech_start:speech_end]
159
+
160
+ # logit projection
161
+ text_logits = self.text_head(text_latents)
162
+ speech_logits = self.speech_head(speech_latents)
163
+
164
+ return AttrDict(
165
+ text_logits=text_logits,
166
+ text_latents=text_latents,
167
+ speech_logits=speech_logits,
168
+ speech_latents=speech_latents,
169
+ hidden_states=hidden_states,
170
+ )
171
+
172
+ def loss(
173
+ self,
174
+ *,
175
+ t3_cond: T3Cond,
176
+ text_tokens: torch.LongTensor,
177
+ text_token_lens: torch.LongTensor,
178
+ speech_tokens: torch.LongTensor,
179
+ speech_token_lens: torch.LongTensor,
180
+ ):
181
+ "training method"
182
+ len_text = text_tokens.size(1)
183
+ len_speech = speech_tokens.size(1)
184
+ assert len_text == text_token_lens.max()
185
+ assert len_speech == speech_token_lens.max()
186
+
187
+ out = self.forward(
188
+ t3_cond=t3_cond,
189
+ text_tokens=text_tokens,
190
+ text_token_lens=text_token_lens,
191
+ speech_tokens=speech_tokens,
192
+ speech_token_lens=speech_token_lens,
193
+ training=True,
194
+ ) # (B, seq, vocab_size)
195
+
196
+ # Calc CCE losses
197
+ IGNORE_ID = -100
198
+ device = out.text_logits.device
199
+ mask_text = torch.arange(len_text, device=device)[None] >= text_token_lens[:, None] # (B, len_text)
200
+ mask_speech = torch.arange(len_speech, device=device)[None] >= speech_token_lens[:, None] # (B, len_speech)
201
+ masked_text = text_tokens.masked_fill(mask_text, IGNORE_ID)
202
+ masked_speech = speech_tokens.masked_fill(mask_speech, IGNORE_ID)
203
+ loss_text = F.cross_entropy(out.text_logits.transpose(1, 2), masked_text, ignore_index=IGNORE_ID)  # cross_entropy expects (B, vocab, seq)
204
+ loss_speech = F.cross_entropy(out.speech_logits.transpose(1, 2), masked_speech, ignore_index=IGNORE_ID)
205
+
206
+ return loss_text, loss_speech
207
+
208
+ @torch.inference_mode()
209
+ def inference(
210
+ self,
211
+ *,
212
+ t3_cond: T3Cond,
213
+ text_tokens: Tensor,
214
+ initial_speech_tokens: Optional[Tensor]=None,
215
+
216
+ # misc conditioning
217
+ prepend_prompt_speech_tokens: Optional[Tensor]=None,
218
+
219
+ # HF generate args
220
+ num_return_sequences=1,
221
+ max_new_tokens=None,
222
+ stop_on_eos=True,
223
+ do_sample=True,
224
+ temperature=0.8,
225
+ top_p=0.95,
226
+ min_p=0.05,
227
+ length_penalty=1.0,
228
+ repetition_penalty=1.2,
229
+ cfg_weight=0.5,
230
+ ):
231
+ """
232
+ Args:
233
+ text_tokens: a 1D (unbatched) or 2D (batched) tensor.
234
+ """
235
+ # Validate / sanitize inputs
236
+ assert prepend_prompt_speech_tokens is None, "not implemented"
237
+ _ensure_BOT_EOT(text_tokens, self.hp)
238
+ text_tokens = torch.atleast_2d(text_tokens).to(dtype=torch.long, device=self.device)
239
+
240
+ # Default initial speech to a single start-of-speech token
241
+ if initial_speech_tokens is None:
242
+ initial_speech_tokens = self.hp.start_speech_token * torch.ones_like(text_tokens[:, :1])
243
+
244
+ # Prepare custom input embeds
245
+ embeds, len_cond = self.prepare_input_embeds(
246
+ t3_cond=t3_cond,
247
+ text_tokens=text_tokens,
248
+ speech_tokens=initial_speech_tokens,
249
+ cfg_weight=cfg_weight,
250
+ )
251
+
252
+ # In order to use the standard HF generate method, we need to extend some methods to inject our custom logic
253
+ # Note the llama-specific logic. Other tfmr types can be added later.
254
+
255
+ self.compiled = False
256
+
257
+ # TODO? synchronize the expensive compile function
258
+ # with self.compile_lock:
259
+ if not self.compiled:
260
+ alignment_stream_analyzer = AlignmentStreamAnalyzer(
261
+ self.tfmr,
262
+ None,
263
+ text_tokens_slice=(len_cond, len_cond + text_tokens.size(-1)),
264
+ alignment_layer_idx=9, # TODO: hparam or something?
265
+ eos_idx=self.hp.stop_speech_token,
266
+ )
267
+ assert alignment_stream_analyzer.eos_idx == self.hp.stop_speech_token
268
+
269
+ patched_model = T3HuggingfaceBackend(
270
+ config=self.cfg,
271
+ llama=self.tfmr,
272
+ speech_enc=self.speech_emb,
273
+ speech_head=self.speech_head,
274
+ alignment_stream_analyzer=alignment_stream_analyzer,
275
+ )
276
+ self.patched_model = patched_model
277
+ self.compiled = True
278
+
279
+ # # Run normal generate method, which calls our custom extended methods
280
+ # return self.patched_model.generate(
281
+ # inputs=initial_speech_tokens,
282
+ # decoder_cond=embeds,
283
+ # bos_token_id=self.hp.start_speech_token,
284
+ # eos_token_id=(self.hp.stop_speech_token if stop_on_eos else -1),
285
+ # pad_token_id=self.hp.stop_speech_token,
286
+ # max_new_tokens=max_new_tokens or self.hp.max_speech_tokens,
287
+ # num_return_sequences=num_return_sequences,
288
+ # temperature=temperature,
289
+ # min_p=min_p,
290
+ # length_penalty=length_penalty,
291
+ # repetition_penalty=repetition_penalty,
292
+ # do_sample=do_sample,
293
+ # # cache_implementation=None if not self.compiled else "static",
294
+ # )
295
+
296
+ device = embeds.device
297
+
298
+ bos_token = torch.tensor([[self.hp.start_speech_token]], dtype=torch.long, device=device)
299
+ bos_embed = self.speech_emb(bos_token) # shape: (B, 1, embed_dim)
300
+ bos_embed = bos_embed + self.speech_pos_emb.get_fixed_embedding(0)
301
+
302
+ # batch_size=2 for CFG
303
+ bos_embed = torch.cat([bos_embed, bos_embed])
304
+
305
+ # Combine condition and BOS token for the initial input
306
+ inputs_embeds = torch.cat([embeds, bos_embed], dim=1)
307
+
308
+ # Track generated token ids; start with the BOS token.
309
+ generated_ids = bos_token.clone()
310
+ predicted = [] # To store the predicted tokens
311
+
312
+ # Instantiate the logits processors.
313
+ top_p_warper = TopPLogitsWarper(top_p=top_p)
314
+ min_p_warper = MinPLogitsWarper(min_p=min_p)
316
+ repetition_penalty_processor = RepetitionPenaltyLogitsProcessor(penalty=float(repetition_penalty))
317
+
318
+ # ---- Initial Forward Pass (no kv_cache yet) ----
319
+ output = self.patched_model(
320
+ inputs_embeds=inputs_embeds,
321
+ past_key_values=None,
322
+ use_cache=True,
323
+ output_attentions=True,
324
+ output_hidden_states=True,
325
+ return_dict=True,
326
+ )
327
+ # Initialize kv_cache with the full context.
328
+ past = output.past_key_values
329
+
330
+ # ---- Generation Loop using kv_cache ----
331
+ for i in tqdm(range(max_new_tokens or self.hp.max_speech_tokens), desc="Sampling", dynamic_ncols=True):
332
+ logits_step = output.logits[:, -1, :]
333
+ # CFG combine → (1, V)
334
+ cond = logits_step[0:1, :]
335
+ uncond = logits_step[1:2, :]
336
+ cfg = torch.as_tensor(cfg_weight, device=cond.device, dtype=cond.dtype)
337
+ logits = cond + cfg * (cond - uncond)
338
+
339
+ # Apply alignment stream analyzer integrity checks
340
+ if self.patched_model.alignment_stream_analyzer is not None:
341
+ if logits.dim() == 1: # guard in case something upstream squeezed
342
+ logits = logits.unsqueeze(0) # (1, V)
343
+ # Pass the last generated token for repetition tracking
344
+ last_token = generated_ids[0, -1].item() if len(generated_ids[0]) > 0 else None
345
+ logits = self.patched_model.alignment_stream_analyzer.step(logits, next_token=last_token) # (1, V)
346
+
347
+ # Apply repetition penalty
348
+ ids_for_proc = generated_ids[:1, ...] # batch = 1
349
+ logits = repetition_penalty_processor(ids_for_proc, logits) # expects (B,V)
350
+
351
+ # Apply temperature scaling.
352
+ if temperature != 1.0:
353
+ logits = logits / temperature
354
+
355
+ # Apply min_p and top_p filtering
356
+ logits = min_p_warper(ids_for_proc, logits)
357
+ logits = top_p_warper(ids_for_proc, logits)
358
+
359
+ # Convert logits to probabilities and sample the next token.
360
+ probs = torch.softmax(logits, dim=-1)
361
+ next_token = torch.multinomial(probs, num_samples=1) # shape: (B, 1)
362
+
363
+ predicted.append(next_token)
364
+ generated_ids = torch.cat([generated_ids, next_token], dim=1)
365
+
366
+ # Check for EOS token.
367
+ if next_token.view(-1) == self.hp.stop_speech_token:
368
+ logger.info(f"✅ EOS token detected! Stopping generation at step {i+1}")
369
+ break
370
+
371
+ # Get embedding for the new token.
372
+ next_token_embed = self.speech_emb(next_token)
373
+ next_token_embed = next_token_embed + self.speech_pos_emb.get_fixed_embedding(i + 1)
374
+
375
+ # For CFG
376
+ next_token_embed = torch.cat([next_token_embed, next_token_embed])
377
+
378
+ # Forward pass with only the new token and the cached past.
379
+ output = self.patched_model(
380
+ inputs_embeds=next_token_embed,
381
+ past_key_values=past,
382
+ output_attentions=True,
383
+ output_hidden_states=True,
384
+ return_dict=True,
385
+ )
386
+ # Update the kv_cache.
387
+ past = output.past_key_values
388
+
389
+ # Concatenate all predicted tokens along the sequence dimension.
390
+ predicted_tokens = torch.cat(predicted, dim=1) # shape: (B, num_tokens)
391
+ return predicted_tokens
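The sampling loop above applies classifier-free guidance by recombining the conditional and unconditional logit rows before the warpers and penalties. The guidance rule itself is a one-liner; a standalone sketch with toy logits:

    import torch

    cfg_weight = 0.5
    cond = torch.tensor([[2.0, 0.0, -1.0]])      # conditioned row
    uncond = torch.tensor([[1.0, 0.5, -1.0]])    # unconditioned (CFG) row
    logits = cond + cfg_weight * (cond - uncond) # same combination as in T3.inference
    print(logits)                                # tensor([[ 2.5000, -0.2500, -1.0000]])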
src/chatterbox/models/tokenizers/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .tokenizer import EnTokenizer, MTLTokenizer
src/chatterbox/models/tokenizers/tokenizer.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import json
3
+ import re
4
+
5
+ import torch
6
+ from pathlib import Path
7
+ from unicodedata import category
8
+ from tokenizers import Tokenizer
9
+ from huggingface_hub import hf_hub_download
10
+
11
+
12
+ # Special tokens
13
+ SOT = "[START]"
14
+ EOT = "[STOP]"
15
+ UNK = "[UNK]"
16
+ SPACE = "[SPACE]"
17
+ SPECIAL_TOKENS = [SOT, EOT, UNK, SPACE, "[PAD]", "[SEP]", "[CLS]", "[MASK]"]
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+ class EnTokenizer:
22
+ def __init__(self, vocab_file_path):
23
+ self.tokenizer: Tokenizer = Tokenizer.from_file(vocab_file_path)
24
+ self.check_vocabset_sot_eot()
25
+
26
+ def check_vocabset_sot_eot(self):
27
+ voc = self.tokenizer.get_vocab()
28
+ assert SOT in voc
29
+ assert EOT in voc
30
+
31
+ def text_to_tokens(self, text: str):
32
+ text_tokens = self.encode(text)
33
+ text_tokens = torch.IntTensor(text_tokens).unsqueeze(0)
34
+ return text_tokens
35
+
36
+ def encode(self, txt: str, verbose=False):
37
+ """
38
+ clean_text > (append `lang_id`) > replace SPACE > encode text using Tokenizer
39
+ """
40
+ txt = txt.replace(' ', SPACE)
41
+ code = self.tokenizer.encode(txt)
42
+ ids = code.ids
43
+ return ids
44
+
45
+ def decode(self, seq):
46
+ if isinstance(seq, torch.Tensor):
47
+ seq = seq.cpu().numpy()
48
+
49
+ txt: str = self.tokenizer.decode(seq,
50
+ skip_special_tokens=False)
51
+ txt = txt.replace(' ', '')
52
+ txt = txt.replace(SPACE, ' ')
53
+ txt = txt.replace(EOT, '')
54
+ txt = txt.replace(UNK, '')
55
+ return txt
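Since plain spaces are swapped for [SPACE] before encoding and restored after decoding, a round trip should reproduce the input. A hedged usage sketch (the vocab path is hypothetical; the real tokenizer.json ships with the model checkpoint):

    tok = EnTokenizer("checkpoints/tokenizer.json")   # hypothetical path
    ids = tok.text_to_tokens("Hello world")           # (1, T) IntTensor
    print(tok.decode(ids.squeeze(0)))                 # "Hello world" (modulo normalization)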
56
+
57
+
58
+ # Model repository
59
+ REPO_ID = "ResembleAI/chatterbox-multilingual"
60
+
61
+ # Global instances for optional dependencies
62
+ _kakasi = None
63
+ _dicta = None
64
+
65
+
66
+ def is_kanji(c: str) -> bool:
67
+ """Check if character is kanji."""
68
+ return 19968 <= ord(c) <= 40959
69
+
70
+
71
+ def is_katakana(c: str) -> bool:
72
+ """Check if character is katakana."""
73
+ return 12449 <= ord(c) <= 12538
74
+
75
+
76
+ def hiragana_normalize(text: str) -> str:
77
+ """Japanese text normalization: converts kanji to hiragana; katakana remains the same."""
78
+ global _kakasi
79
+
80
+ try:
81
+ if _kakasi is None:
82
+ import pykakasi
83
+ _kakasi = pykakasi.kakasi()
84
+
85
+ result = _kakasi.convert(text)
86
+ out = []
87
+
88
+ for r in result:
89
+ inp = r['orig']
90
+ hira = r["hira"]
91
+
92
+ # Any kanji in the phrase
93
+ if any([is_kanji(c) for c in inp]):
94
+ if hira and hira[0] in ["は", "へ"]: # Safety check for empty hira
95
+ hira = " " + hira
96
+ out.append(hira)
97
+
98
+ # All katakana
99
+ elif inp and all(is_katakana(c) for c in inp):  # safe even for empty inp
100
+ out.append(r['orig'])
101
+
102
+ else:
103
+ out.append(inp)
104
+
105
+ normalized_text = "".join(out)
106
+
107
+ # Decompose Japanese characters for tokenizer compatibility
108
+ import unicodedata
109
+ normalized_text = unicodedata.normalize('NFKD', normalized_text)
110
+
111
+ return normalized_text
112
+
113
+ except ImportError:
114
+ logger.warning("pykakasi not available - Japanese text processing skipped")
115
+ return text
116
+
117
+
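A quick hedged check of the Japanese path (requires pykakasi; the exact reading comes from pykakasi's dictionary, so no output is asserted here):

    text = hiragana_normalize("日本語を話します")   # kanji -> hiragana reading, then NFKD-decomposed
    print(text)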
118
+ def add_hebrew_diacritics(text: str) -> str:
119
+ """Hebrew text normalization: adds diacritics to Hebrew text."""
120
+ global _dicta
121
+
122
+ try:
123
+ if _dicta is None:
124
+ from dicta_onnx import Dicta
125
+ _dicta = Dicta()
126
+
127
+ return _dicta.add_diacritics(text)
128
+
129
+ except ImportError:
130
+ logger.warning("dicta_onnx not available - Hebrew text processing skipped")
131
+ return text
132
+ except Exception as e:
133
+ logger.warning(f"Hebrew diacritization failed: {e}")
134
+ return text
135
+
136
+
137
+ def korean_normalize(text: str) -> str:
138
+ """Korean text normalization: decompose syllables into Jamo for tokenization."""
139
+
140
+ def decompose_hangul(char):
141
+ """Decompose Korean syllable into Jamo components."""
142
+ if not ('\uac00' <= char <= '\ud7af'):
143
+ return char
144
+
145
+ # Hangul decomposition formula
146
+ base = ord(char) - 0xAC00
147
+ initial = chr(0x1100 + base // (21 * 28))
148
+ medial = chr(0x1161 + (base % (21 * 28)) // 28)
149
+ final = chr(0x11A7 + base % 28) if base % 28 > 0 else ''
150
+
151
+ return initial + medial + final
152
+
153
+ # Decompose syllables and normalize punctuation
154
+ result = ''.join(decompose_hangul(char) for char in text)
155
+ result = re.sub(r'[…~?!,:;()「」『』]', '.', result) # Korean punctuation
156
+
157
+ return result.strip()
158
+
159
+
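A worked example of the decomposition above: '한' (U+D55C) has base 0x295C, which splits into the conjoining jamo U+1112, U+1161 and U+11AB, so a two-syllable word becomes six jamo and NFC recomposes it losslessly:

    import unicodedata

    decomposed = korean_normalize("한글")
    assert len(decomposed) == 6                               # 2 syllables -> 3 jamo each
    assert unicodedata.normalize("NFC", decomposed) == "한글"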
160
+ class ChineseCangjieConverter:
161
+ """Converts Chinese characters to Cangjie codes for tokenization."""
162
+
163
+ def __init__(self, model_dir=None):
164
+ self.word2cj = {}
165
+ self.cj2word = {}
166
+ self.segmenter = None
167
+ self._load_cangjie_mapping(model_dir)
168
+ self._init_segmenter()
169
+
170
+ def _load_cangjie_mapping(self, model_dir=None):
171
+ """Load Cangjie mapping from HuggingFace model repository."""
172
+ try:
173
+ cangjie_file = hf_hub_download(
174
+ repo_id=REPO_ID,
175
+ filename="Cangjie5_TC.json",
176
+ cache_dir=model_dir
177
+ )
178
+
179
+ with open(cangjie_file, "r", encoding="utf-8") as fp:
180
+ data = json.load(fp)
181
+
182
+ for entry in data:
183
+ word, code = entry.split("\t")[:2]
184
+ self.word2cj[word] = code
185
+ if code not in self.cj2word:
186
+ self.cj2word[code] = [word]
187
+ else:
188
+ self.cj2word[code].append(word)
189
+
190
+ except Exception as e:
191
+ logger.warning(f"Could not load Cangjie mapping: {e}")
192
+
193
+ def _init_segmenter(self):
194
+ """Initialize pkuseg segmenter."""
195
+ try:
196
+ from pkuseg import pkuseg
197
+ self.segmenter = pkuseg()
198
+ except ImportError:
199
+ logger.warning("pkuseg not available - Chinese segmentation will be skipped")
200
+ self.segmenter = None
201
+
202
+ def _cangjie_encode(self, glyph: str):
203
+ """Encode a single Chinese glyph to Cangjie code."""
204
+ code = self.word2cj.get(glyph, None)
205
+ if code is None:
206
+ return None
207
+
208
+ index = self.cj2word[code].index(glyph)
209
+ index_suffix = str(index) if index > 0 else ""
210
+ return code + index_suffix
211
+
212
+ def _normalize_numbers(self, text: str) -> str:
213
+ """Convert Arabic numerals (1-99) to Chinese characters."""
214
+ digit_map = {'0': '零', '1': '一', '2': '二', '3': '三', '4': '四',
215
+ '5': '五', '6': '六', '7': '七', '8': '八', '9': '九'}
216
+
217
+ pattern = re.compile(r'(?<!\d)(\d{1,2})(?!\d)')
218
+
219
+ def convert_number(match):
220
+ num = int(match.group(1))
221
+
222
+ if num == 0:
223
+ return '零'
224
+ elif 1 <= num <= 9:
225
+ return digit_map[str(num)]
226
+ elif num == 10:
227
+ return '十'
228
+ elif 11 <= num <= 19:
229
+ return '十' + digit_map[str(num % 10)]
230
+ elif 20 <= num <= 99:
231
+ tens, ones = divmod(num, 10)
232
+ if ones == 0:
233
+ return digit_map[str(tens)] + '十'
234
+ else:
235
+ return digit_map[str(tens)] + '十' + digit_map[str(ones)]
236
+ else:
237
+ return match.group(1)
238
+
239
+ return pattern.sub(convert_number, text)
240
+
241
+ def convert_chinese_text(self, text: str) -> str:
242
+ """Convert Chinese characters in text to Cangjie tokens."""
243
+ text = re.sub('[、,:;〜()⦅⦆-]', ',', text) # '-' kept last so it is not parsed as a character range
244
+ text = re.sub('(。|…)', '.', text)
245
+ text = self._normalize_numbers(text)
246
+
247
+ # Skip segmentation for simple sequences (numbers, punctuation, short phrases)
248
+ if self.segmenter is not None:
249
+ # This avoids over-segmenting number sequences like "一, 二, 三"
250
+ is_simple_sequence = (
251
+ len([c for c in text if category(c) == "Lo"]) <= 15 and # Max 15 Chinese chars
252
+ text.count(',') >= 2 # Contains multiple commas (likely enumeration)
253
+ )
254
+
255
+ # Only segment complex Chinese text (longer sentences without enumeration patterns)
256
+ if not is_simple_sequence and len(text) > 10:
257
+ chinese_chars = sum(1 for c in text if category(c) == "Lo")
258
+ total_chars = len([c for c in text if c.strip()])
259
+
260
+ if chinese_chars > 5 and chinese_chars / total_chars > 0.7:
261
+ segmented_words = self.segmenter.cut(text)
262
+ text = " ".join(segmented_words)
263
+
264
+ output = []
265
+ for char in text:
266
+ if category(char) == "Lo": # Chinese character
267
+ cangjie = self._cangjie_encode(char)
268
+ if cangjie is None:
269
+ output.append(char)
270
+ continue
271
+
272
+ code_tokens = [f"[cj_{c}]" for c in cangjie]
273
+ code_tokens.append("[cj_.]")
274
+
275
+ output.append("".join(code_tokens))
276
+ else:
277
+ output.append(char)
278
+
279
+ return "".join(output)
280
+
281
+
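A hedged usage sketch for the converter: constructing it downloads Cangjie5_TC.json from REPO_ID (network access assumed), while the number normalization is deterministic and follows from the digit map (21 -> 二十一):

    converter = ChineseCangjieConverter()                 # pulls the Cangjie mapping from the Hub
    assert converter._normalize_numbers("21") == "二十一"
    print(converter.convert_chinese_text("你好, 21"))      # Chinese glyphs become [cj_*] token runs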
282
+ class MTLTokenizer:
283
+ def __init__(self, vocab_file_path):
284
+ self.tokenizer: Tokenizer = Tokenizer.from_file(vocab_file_path)
285
+ model_dir = Path(vocab_file_path).parent
286
+ self.cangjie_converter = ChineseCangjieConverter(model_dir)
287
+ self.check_vocabset_sot_eot()
288
+
289
+ def check_vocabset_sot_eot(self):
290
+ voc = self.tokenizer.get_vocab()
291
+ assert SOT in voc
292
+ assert EOT in voc
293
+
294
+ def text_to_tokens(self, text: str, language_id: str = None):
295
+ text_tokens = self.encode(text, language_id=language_id)
296
+ text_tokens = torch.IntTensor(text_tokens).unsqueeze(0)
297
+ return text_tokens
298
+
299
+ def encode(self, txt: str, language_id: str = None):
300
+ # Language-specific text processing
301
+ if language_id == 'zh':
302
+ txt = self.cangjie_converter.convert_chinese_text(txt)
303
+ elif language_id == 'ja':
304
+ txt = hiragana_normalize(txt)
305
+ elif language_id == 'he':
306
+ txt = add_hebrew_diacritics(txt)
307
+ elif language_id == 'ko':
308
+ txt = korean_normalize(txt)
309
+
310
+ # Prepend language token
311
+ if language_id:
312
+ txt = f"[{language_id.lower()}]{txt}"
313
+
314
+ txt = txt.replace(' ', SPACE)
315
+ return self.tokenizer.encode(txt).ids
316
+
317
+ def decode(self, seq):
318
+ if isinstance(seq, torch.Tensor):
319
+ seq = seq.cpu().numpy()
320
+
321
+ txt = self.tokenizer.decode(seq, skip_special_tokens=False)
322
+ txt = txt.replace(' ', '').replace(SPACE, ' ').replace(EOT, '').replace(UNK, '')
323
+ return txt
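End-to-end, the multilingual tokenizer routes text through the language-specific normalizers above, prepends the language token and encodes; a minimal sketch (checkpoints/mtl_tokenizer.json is a placeholder for the vocab file shipped with the checkpoint):

    mtl = MTLTokenizer("checkpoints/mtl_tokenizer.json")        # hypothetical local path
    tokens = mtl.text_to_tokens("안녕하세요", language_id="ko")  # Jamo-decomposed, prefixed with [ko]
    print(tokens.shape)                                         # (1, T) int tensor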
src/chatterbox/models/utils.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ class AttrDict(dict):
2
+ def __init__(self, *args, **kwargs):
3
+ super(AttrDict, self).__init__(*args, **kwargs)
4
+ self.__dict__ = self
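AttrDict simply aliases __dict__ to the dict itself so keys double as attributes; a one-line sanity check (the hparams values are illustrative only):

    hparams = AttrDict(sampling_rate=24000, upsample_rates=[8, 8])
    assert hparams.sampling_rate == hparams["sampling_rate"] == 24000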
src/chatterbox/models/voice_encoder/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .voice_encoder import VoiceEncoder, VoiceEncConfig
src/chatterbox/models/voice_encoder/config.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class VoiceEncConfig:
2
+ num_mels = 40
3
+ sample_rate = 16000
4
+ speaker_embed_size = 256
5
+ ve_hidden_size = 256
6
+ flatten_lstm_params = False
7
+ n_fft = 400
8
+ hop_size = 160
9
+ win_size = 400
10
+ fmax = 8000
11
+ fmin = 0
12
+ preemphasis = 0.
13
+ mel_power = 2.0
14
+ mel_type = "amp"
15
+ normalized_mels = False
16
+ ve_partial_frames = 160
17
+ ve_final_relu = True
18
+ stft_magnitude_min = 1e-4
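For orientation: a hop of 160 samples at 16 kHz gives 100 mel frames per second, so ve_partial_frames = 160 corresponds to roughly 1.6 s partial windows for the voice encoder:

    cfg = VoiceEncConfig()
    frames_per_second = cfg.sample_rate / cfg.hop_size            # 16000 / 160 = 100
    partial_seconds = cfg.ve_partial_frames / frames_per_second   # 160 / 100 = 1.6 s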
src/chatterbox/models/voice_encoder/melspec.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import lru_cache
2
+
3
+ from scipy import signal
4
+ import numpy as np
5
+ import librosa
6
+
7
+
8
+ @lru_cache()
9
+ def mel_basis(hp):
10
+ assert hp.fmax <= hp.sample_rate // 2
11
+ return librosa.filters.mel(
12
+ sr=hp.sample_rate,
13
+ n_fft=hp.n_fft,
14
+ n_mels=hp.num_mels,
15
+ fmin=hp.fmin,
16
+ fmax=hp.fmax) # -> (nmel, nfreq)
17
+
18
+
19
+ def preemphasis(wav, hp):
20
+ assert hp.preemphasis != 0
21
+ wav = signal.lfilter([1, -hp.preemphasis], [1], wav)
22
+ wav = np.clip(wav, -1, 1)
23
+ return wav
24
+
25
+
26
+ def melspectrogram(wav, hp, pad=True):
27
+ # Run through pre-emphasis
28
+ if hp.preemphasis > 0:
29
+ wav = preemphasis(wav, hp)
30
+ assert np.abs(wav).max() - 1 < 1e-07
31
+
32
+ # Do the stft
33
+ spec_complex = _stft(wav, hp, pad=pad)
34
+
35
+ # Get the magnitudes
36
+ spec_magnitudes = np.abs(spec_complex)
37
+
38
+ if hp.mel_power != 1.0:
39
+ spec_magnitudes **= hp.mel_power
40
+
41
+ # Get the mel and convert magnitudes->db
42
+ mel = np.dot(mel_basis(hp), spec_magnitudes)
43
+ if hp.mel_type == "db":
44
+ mel = _amp_to_db(mel, hp)
45
+
46
+ # Normalise the mel from db to 0,1
47
+ if hp.normalized_mels:
48
+ mel = _normalize(mel, hp).astype(np.float32)
49
+
50
+ assert not pad or mel.shape[1] == 1 + len(wav) // hp.hop_size # Sanity check
51
+ return mel # (M, T)
52
+
53
+
54
+ def _stft(y, hp, pad=True):
55
+ # NOTE: after 0.8, pad mode defaults to constant, setting this to reflect for
56
+ # historical consistency and streaming-version consistency
57
+ return librosa.stft(
58
+ y,
59
+ n_fft=hp.n_fft,
60
+ hop_length=hp.hop_size,
61
+ win_length=hp.win_size,
62
+ center=pad,
63
+ pad_mode="reflect",
64
+ )
65
+
66
+
67
+ def _amp_to_db(x, hp):
68
+ return 20 * np.log10(np.maximum(hp.stft_magnitude_min, x))
69
+
70
+
71
+ def _db_to_amp(x):
72
+ return np.power(10.0, x * 0.05)
73
+
74
+
75
+ def _normalize(s, hp, headroom_db=15):
76
+ min_level_db = 20 * np.log10(hp.stft_magnitude_min)
77
+ s = (s - min_level_db) / (-min_level_db + headroom_db)
78
+ return s
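A shape sanity check for this mel front end (random audio, so only the shapes matter; the frame count follows from the assert in melspectrogram, and the import paths are assumptions):

    import numpy as np
    from chatterbox.models.voice_encoder.config import VoiceEncConfig       # assumed import path
    from chatterbox.models.voice_encoder.melspec import melspectrogram      # assumed import path

    cfg = VoiceEncConfig()
    wav = np.random.uniform(-0.5, 0.5, cfg.sample_rate).astype(np.float32)  # 1 s of noise
    mel = melspectrogram(wav, cfg)
    assert mel.shape == (cfg.num_mels, 1 + len(wav) // cfg.hop_size)        # (40, 101)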
src/chatterbox/models/voice_encoder/voice_encoder.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/CorentinJ/Real-Time-Voice-Cloning
2
+ # MIT License
3
+ from typing import List, Union, Optional
4
+
5
+ import numpy as np
6
+ from numpy.lib.stride_tricks import as_strided
7
+ import librosa
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torch import nn, Tensor
11
+
12
+ from .config import VoiceEncConfig
13
+ from .melspec import melspectrogram
14
+
15
+
16
+ def pack(arrays, seq_len: int=None, pad_value=0):
17
+ """
18
+ Given a list of length B of array-like objects of shapes (Ti, ...), packs them in a single tensor of
19
+ shape (B, T, ...) by padding each individual array on the right.
20
+
21
+ :param arrays: a list of array-like objects of matching shapes except for the first axis.
22
+ :param seq_len: the value of T. Must be at least the maximum of the lengths Ti of the
23
+ arrays; defaults to that maximum if None.
24
+ :param pad_value: the value to pad the arrays with.
25
+ :return: a (B, T, ...) tensor
26
+ """
27
+ if seq_len is None:
28
+ seq_len = max(len(array) for array in arrays)
29
+ else:
30
+ assert seq_len >= max(len(array) for array in arrays)
31
+
32
+ # Convert lists to np.array
33
+ if isinstance(arrays[0], list):
34
+ arrays = [np.array(array) for array in arrays]
35
+
36
+ # Convert to tensor and handle device
37
+ device = None
38
+ if isinstance(arrays[0], torch.Tensor):
39
+ tensors = arrays
40
+ device = tensors[0].device
41
+ else:
42
+ tensors = [torch.as_tensor(array) for array in arrays]
43
+
44
+ # Fill the packed tensor with the array data
45
+ packed_shape = (len(tensors), seq_len, *tensors[0].shape[1:])
46
+ packed_tensor = torch.full(packed_shape, pad_value, dtype=tensors[0].dtype, device=device)
47
+
48
+ for i, tensor in enumerate(tensors):
49
+ packed_tensor[i, :tensor.size(0)] = tensor
50
+
51
+ return packed_tensor
52
+
53
+
54
+ def get_num_wins(
55
+ n_frames: int,
56
+ step: int,
57
+ min_coverage: float,
58
+ hp: VoiceEncConfig,
59
+ ):
60
+ assert n_frames > 0
61
+ win_size = hp.ve_partial_frames
62
+ n_wins, remainder = divmod(max(n_frames - win_size + step, 0), step)
63
+ if n_wins == 0 or (remainder + (win_size - step)) / win_size >= min_coverage:
64
+ n_wins += 1
65
+ target_n = win_size + step * (n_wins - 1)
66
+ return n_wins, target_n
67
+
68
+
69
+ def get_frame_step(
70
+ overlap: float,
71
+ rate: float,
72
+ hp: VoiceEncConfig,
73
+ ):
74
+ # Compute how many frames separate two partial utterances
75
+ assert 0 <= overlap < 1
76
+ if rate is None:
77
+ frame_step = int(np.round(hp.ve_partial_frames * (1 - overlap)))
78
+ else:
79
+ frame_step = int(np.round((hp.sample_rate / rate) / hp.ve_partial_frames))
80
+ assert 0 < frame_step <= hp.ve_partial_frames
81
+ return frame_step
82
+
83
+
84
+ def stride_as_partials(
85
+ mel: np.ndarray,
86
+ hp: VoiceEncConfig,
87
+ overlap=0.5,
88
+ rate: float=None,
89
+ min_coverage=0.8,
90
+ ):
91
+ """
92
+ Takes unscaled mels in (T, M) format
93
+ and strides it into overlapping partials of shape (n_partials, hp.ve_partial_frames, hp.num_mels).
94
+ """
95
+ assert 0 < min_coverage <= 1
96
+ frame_step = get_frame_step(overlap, rate, hp)
97
+
98
+ # Compute how many partials can fit in the mel
99
+ n_partials, target_len = get_num_wins(len(mel), frame_step, min_coverage, hp)
100
+
101
+ # Trim or pad the mel spectrogram to match the number of partials
102
+ if target_len > len(mel):
103
+ mel = np.concatenate((mel, np.full((target_len - len(mel), hp.num_mels), 0)))
104
+ elif target_len < len(mel):
105
+ mel = mel[:target_len]
106
+
107
+ # Ensure the numpy array data is float32 and contiguous in memory
108
+ mel = mel.astype(np.float32, order="C")
109
+
110
+ # Re-arrange the array in memory to be of shape (N, P, M) with partials overlapping each other,
111
+ # where N is the number of partials, P is the number of frames of each partial and M the
112
+ # number of channels of the mel spectrograms.
113
+ shape = (n_partials, hp.ve_partial_frames, hp.num_mels)
114
+ strides = (mel.strides[0] * frame_step, mel.strides[0], mel.strides[1])
115
+ partials = as_strided(mel, shape, strides)
116
+ return partials
117
+
118
+
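To make the windowing arithmetic concrete: with the defaults (ve_partial_frames = 160, overlap = 0.5 so frame_step = 80), a 400-frame mel yields 4 partials and needs no padding. The import paths below are assumptions:

    import numpy as np
    from chatterbox.models.voice_encoder.config import VoiceEncConfig            # assumed import path
    from chatterbox.models.voice_encoder.voice_encoder import (                  # assumed import path
        get_frame_step, get_num_wins, stride_as_partials,
    )

    cfg = VoiceEncConfig()
    step = get_frame_step(overlap=0.5, rate=None, hp=cfg)                        # round(160 * 0.5) = 80
    assert get_num_wins(400, step, min_coverage=0.8, hp=cfg) == (4, 400)
    partials = stride_as_partials(np.zeros((400, cfg.num_mels), dtype=np.float32), cfg)
    assert partials.shape == (4, cfg.ve_partial_frames, cfg.num_mels)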
119
+ class VoiceEncoder(nn.Module):
120
+ def __init__(self, hp=VoiceEncConfig()):
121
+ super().__init__()
122
+
123
+ self.hp = hp
124
+
125
+ # Network definition
126
+ self.lstm = nn.LSTM(self.hp.num_mels, self.hp.ve_hidden_size, num_layers=3, batch_first=True)
127
+ if hp.flatten_lstm_params:
128
+ self.lstm.flatten_parameters()
129
+ self.proj = nn.Linear(self.hp.ve_hidden_size, self.hp.speaker_embed_size)
130
+
131
+ # Cosine similarity scaling (fixed initial parameter values)
132
+ self.similarity_weight = nn.Parameter(torch.tensor([10.]), requires_grad=True)
133
+ self.similarity_bias = nn.Parameter(torch.tensor([-5.]), requires_grad=True)
134
+
135
+ @property
136
+ def device(self):
137
+ return next(self.parameters()).device
138
+
139
+ def forward(self, mels: torch.FloatTensor):
140
+ """
141
+ Computes the embeddings of a batch of partial utterances.
142
+
143
+ :param mels: a batch of unscaled mel spectrograms of same duration as a float32 tensor
144
+ of shape (B, T, M) where T is hp.ve_partial_frames
145
+ :return: the embeddings as a float32 tensor of shape (B, E) where E is
146
+ hp.speaker_embed_size. Embeddings are L2-normed and thus lie in the range [-1, 1].
147
+ """
148
+ if self.hp.normalized_mels and (mels.min() < 0 or mels.max() > 1):
149
+ raise Exception(f"Mels outside [0, 1]. Min={mels.min()}, Max={mels.max()}")
150
+
151
+ # Pass the input through the LSTM layers
152
+ _, (hidden, _) = self.lstm(mels)
153
+
154
+ # Project the final hidden state
155
+ raw_embeds = self.proj(hidden[-1])
156
+ if self.hp.ve_final_relu:
157
+ raw_embeds = F.relu(raw_embeds)
158
+
159
+ # L2 normalize the embeddings.
160
+ return raw_embeds / torch.linalg.norm(raw_embeds, dim=1, keepdim=True)
161
+
162
+ def inference(self, mels: torch.Tensor, mel_lens, overlap=0.5, rate: float=None, min_coverage=0.8, batch_size=None):
163
+ """
164
+ Computes the embeddings of a batch of full utterances with gradients.
165
+
166
+ :param mels: (B, T, M) unscaled mels
167
+ :return: (B, E) embeddings on CPU
168
+ """
169
+ mel_lens = mel_lens.tolist() if torch.is_tensor(mel_lens) else mel_lens
170
+
171
+ # Compute where to split the utterances into partials
172
+ frame_step = get_frame_step(overlap, rate, self.hp)
173
+ n_partials, target_lens = zip(*(get_num_wins(l, frame_step, min_coverage, self.hp) for l in mel_lens))
174
+
175
+ # Possibly pad the mels to reach the target lengths
176
+ len_diff = max(target_lens) - mels.size(1)
177
+ if len_diff > 0:
178
+ pad = torch.full((mels.size(0), len_diff, self.hp.num_mels), 0, dtype=torch.float32)
179
+ mels = torch.cat((mels, pad.to(mels.device)), dim=1)
180
+
181
+ # Group all partials together so that we can batch them easily
182
+ partials = [
183
+ mel[i * frame_step: i * frame_step + self.hp.ve_partial_frames]
184
+ for mel, n_partial in zip(mels, n_partials) for i in range(n_partial)
185
+ ]
186
+ assert all(partials[0].shape == partial.shape for partial in partials)
187
+ partials = torch.stack(partials)
188
+
189
+ # Forward the partials
190
+ n_chunks = int(np.ceil(len(partials) / (batch_size or len(partials))))
191
+ partial_embeds = torch.cat([self(batch) for batch in partials.chunk(n_chunks)], dim=0).cpu()
192
+
193
+ # Reduce the partial embeds into full embeds and L2-normalize them
194
+ slices = np.concatenate(([0], np.cumsum(n_partials)))
195
+ raw_embeds = [torch.mean(partial_embeds[start:end], dim=0) for start, end in zip(slices[:-1], slices[1:])]
196
+ raw_embeds = torch.stack(raw_embeds)
197
+ embeds = raw_embeds / torch.linalg.norm(raw_embeds, dim=1, keepdim=True)
198
+
199
+ return embeds
200
+
201
+ @staticmethod
202
+ def utt_to_spk_embed(utt_embeds: np.ndarray):
203
+ """
204
+ Takes an array of L2-normalized utterance embeddings, computes the mean embedding and L2-normalize it to get a
205
+ speaker embedding.
206
+ """
207
+ assert utt_embeds.ndim == 2
208
+ utt_embeds = np.mean(utt_embeds, axis=0)
209
+ return utt_embeds / np.linalg.norm(utt_embeds, 2)
210
+
211
+ @staticmethod
212
+ def voice_similarity(embeds_x: np.ndarray, embeds_y: np.ndarray):
213
+ """
214
+ Cosine similarity for L2-normalized utterance embeddings or speaker embeddings
215
+ """
216
+ embeds_x = embeds_x if embeds_x.ndim == 1 else VoiceEncoder.utt_to_spk_embed(embeds_x)
217
+ embeds_y = embeds_y if embeds_y.ndim == 1 else VoiceEncoder.utt_to_spk_embed(embeds_y)
218
+ return embeds_x @ embeds_y
219
+
220
+ def embeds_from_mels(
221
+ self, mels: Union[Tensor, List[np.ndarray]], mel_lens=None, as_spk=False, batch_size=32, **kwargs
222
+ ):
223
+ """
224
+ Convenience function for deriving utterance or speaker embeddings from mel spectrograms.
225
+
226
+ :param mels: unscaled mels strictly within [0, 1] as either a (B, T, M) tensor or a list of (Ti, M) arrays.
227
+ :param mel_lens: if passing mels as a tensor, individual mel lengths
228
+ :param as_spk: whether to return utterance embeddings or a single speaker embedding
229
+ :param kwargs: args for inference()
230
+
231
+ :returns: embeds as a (B, E) float32 numpy array if <as_spk> is False, else as a (E,) array
232
+ """
233
+ # Load mels in memory and pack them
234
+ if isinstance(mels, List):
235
+ mels = [np.asarray(mel) for mel in mels]
236
+ assert all(m.shape[1] == mels[0].shape[1] for m in mels), "Mels aren't in (B, T, M) format"
237
+ mel_lens = [mel.shape[0] for mel in mels]
238
+ mels = pack(mels)
239
+
240
+ # Embed them
241
+ with torch.inference_mode():
242
+ utt_embeds = self.inference(mels.to(self.device), mel_lens, batch_size=batch_size, **kwargs).numpy()
243
+
244
+ return self.utt_to_spk_embed(utt_embeds) if as_spk else utt_embeds
245
+
246
+ def embeds_from_wavs(
247
+ self,
248
+ wavs: List[np.ndarray],
249
+ sample_rate,
250
+ as_spk=False,
251
+ batch_size=32,
252
+ trim_top_db: Optional[float]=20,
253
+ **kwargs
254
+ ):
255
+ """
256
+ Wrapper around embeds_from_mels
257
+
258
+ :param trim_top_db: this argument was only added for the sake of compatibility with metavoice's implementation
259
+ """
260
+ if sample_rate != self.hp.sample_rate:
261
+ wavs = [
262
+ librosa.resample(wav, orig_sr=sample_rate, target_sr=self.hp.sample_rate, res_type="kaiser_fast")
263
+ for wav in wavs
264
+ ]
265
+
266
+ if trim_top_db:
267
+ wavs = [librosa.effects.trim(wav, top_db=trim_top_db)[0] for wav in wavs]
268
+
269
+ if "rate" not in kwargs:
270
+ kwargs["rate"] = 1.3 # Resemble's default value.
271
+
272
+ mels = [melspectrogram(w, self.hp).T for w in wavs]
273
+
274
+ return self.embeds_from_mels(mels, as_spk=as_spk, batch_size=batch_size, **kwargs)
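Finally, a hedged smoke test of the encoder as wired above (random weights and random audio, so only shapes are meaningful; the chatterbox.models.voice_encoder import path is an assumption):

    import numpy as np
    from chatterbox.models.voice_encoder import VoiceEncoder

    ve = VoiceEncoder()                                                    # randomly initialised, CPU
    wav = np.random.uniform(-0.5, 0.5, 32000).astype(np.float32)           # 2 s at 16 kHz
    utt_embeds = ve.embeds_from_wavs([wav], sample_rate=16000)
    assert utt_embeds.shape == (1, ve.hp.speaker_embed_size)               # (1, 256)
    spk_embed = ve.embeds_from_wavs([wav], sample_rate=16000, as_spk=True)
    assert spk_embed.shape == (ve.hp.speaker_embed_size,)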