Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
f79db70
0
Parent(s):
Clean multilingual TTS repo
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +48 -0
- README.md +10 -0
- app.py +319 -0
- requirements.txt +18 -0
- src/chatterbox/__init__.py +11 -0
- src/chatterbox/models/__init__.py +0 -0
- src/chatterbox/models/s3gen/__init__.py +2 -0
- src/chatterbox/models/s3gen/configs.py +10 -0
- src/chatterbox/models/s3gen/const.py +1 -0
- src/chatterbox/models/s3gen/decoder.py +317 -0
- src/chatterbox/models/s3gen/f0_predictor.py +55 -0
- src/chatterbox/models/s3gen/flow.py +290 -0
- src/chatterbox/models/s3gen/flow_matching.py +218 -0
- src/chatterbox/models/s3gen/hifigan.py +474 -0
- src/chatterbox/models/s3gen/matcha/decoder.py +443 -0
- src/chatterbox/models/s3gen/matcha/flow_matching.py +129 -0
- src/chatterbox/models/s3gen/matcha/text_encoder.py +413 -0
- src/chatterbox/models/s3gen/matcha/transformer.py +316 -0
- src/chatterbox/models/s3gen/s3gen.py +298 -0
- src/chatterbox/models/s3gen/transformer/__init__.py +0 -0
- src/chatterbox/models/s3gen/transformer/activation.py +84 -0
- src/chatterbox/models/s3gen/transformer/attention.py +330 -0
- src/chatterbox/models/s3gen/transformer/convolution.py +145 -0
- src/chatterbox/models/s3gen/transformer/embedding.py +294 -0
- src/chatterbox/models/s3gen/transformer/encoder_layer.py +236 -0
- src/chatterbox/models/s3gen/transformer/positionwise_feed_forward.py +115 -0
- src/chatterbox/models/s3gen/transformer/subsampling.py +383 -0
- src/chatterbox/models/s3gen/transformer/upsample_encoder.py +318 -0
- src/chatterbox/models/s3gen/utils/class_utils.py +71 -0
- src/chatterbox/models/s3gen/utils/mask.py +193 -0
- src/chatterbox/models/s3gen/utils/mel.py +85 -0
- src/chatterbox/models/s3gen/xvector.py +428 -0
- src/chatterbox/models/s3tokenizer/__init__.py +30 -0
- src/chatterbox/models/s3tokenizer/s3tokenizer.py +168 -0
- src/chatterbox/models/t3/__init__.py +1 -0
- src/chatterbox/models/t3/inference/alignment_stream_analyzer.py +178 -0
- src/chatterbox/models/t3/inference/t3_hf_backend.py +116 -0
- src/chatterbox/models/t3/llama_configs.py +37 -0
- src/chatterbox/models/t3/modules/cond_enc.py +97 -0
- src/chatterbox/models/t3/modules/learned_pos_emb.py +32 -0
- src/chatterbox/models/t3/modules/perceiver.py +212 -0
- src/chatterbox/models/t3/modules/t3_config.py +37 -0
- src/chatterbox/models/t3/t3.py +391 -0
- src/chatterbox/models/tokenizers/__init__.py +1 -0
- src/chatterbox/models/tokenizers/tokenizer.py +323 -0
- src/chatterbox/models/utils.py +4 -0
- src/chatterbox/models/voice_encoder/__init__.py +1 -0
- src/chatterbox/models/voice_encoder/config.py +18 -0
- src/chatterbox/models/voice_encoder/melspec.py +78 -0
- src/chatterbox/models/voice_encoder/voice_encoder.py +274 -0
.gitignore
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.vscode
|
2 |
+
|
3 |
+
# Pylance
|
4 |
+
pyrightconfig.json
|
5 |
+
|
6 |
+
# Byte-compiled / optimized / DLL files
|
7 |
+
__pycache__/
|
8 |
+
*.py[cod]
|
9 |
+
*$py.class
|
10 |
+
|
11 |
+
# C extensions
|
12 |
+
*.so
|
13 |
+
|
14 |
+
# Distribution / packaging
|
15 |
+
.Python
|
16 |
+
build/
|
17 |
+
develop-eggs/
|
18 |
+
dist/
|
19 |
+
downloads/
|
20 |
+
eggs/
|
21 |
+
.eggs/
|
22 |
+
lib/
|
23 |
+
lib64/
|
24 |
+
parts/
|
25 |
+
sdist/
|
26 |
+
var/
|
27 |
+
wheels/
|
28 |
+
*.egg-info/
|
29 |
+
.installed.cfg
|
30 |
+
*.egg
|
31 |
+
MANIFEST
|
32 |
+
|
33 |
+
# PyInstaller
|
34 |
+
# Usually these files are written by a python script from a template
|
35 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
36 |
+
*.manifest
|
37 |
+
*.spec
|
38 |
+
|
39 |
+
# Installer logs
|
40 |
+
pip-log.txt
|
41 |
+
pip-delete-this-directory.txt
|
42 |
+
|
43 |
+
syn_out/
|
44 |
+
checkpoints/
|
45 |
+
.gradio
|
46 |
+
|
47 |
+
# Ignore generated sample .wav files
|
48 |
+
**/*.wav
|
README.md
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
metadata
|
2 |
+
title: Chatterbox-Multilingual-TTS
|
3 |
+
emoji: 🌎
|
4 |
+
colorFrom: indigo
|
5 |
+
colorTo: blue
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 5.29.0
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
short_description: Chatterbox TTS supporting 23 languages
|
app.py
ADDED
@@ -0,0 +1,319 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from src.chatterbox.mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES
|
5 |
+
import gradio as gr
|
6 |
+
import spaces
|
7 |
+
|
8 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
9 |
+
print(f"🚀 Running on device: {DEVICE}")
|
10 |
+
|
11 |
+
# --- Global Model Initialization ---
|
12 |
+
MODEL = None
|
13 |
+
|
14 |
+
LANGUAGE_CONFIG = {
|
15 |
+
"ar": {
|
16 |
+
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ar_m1.flac",
|
17 |
+
"text": "في الشهر الماضي، وصلنا إلى معلم جديد بمليارين من المشاهدات على قناتنا على يوتيوب."
|
18 |
+
},
|
19 |
+
"da": {
|
20 |
+
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/da_m1.flac",
|
21 |
+
"text": "Sidste måned nåede vi en ny milepæl med to milliarder visninger på vores YouTube-kanal."
|
22 |
+
},
|
23 |
+
"de": {
|
24 |
+
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/de_f1.flac",
|
25 |
+
"text": "Letzten Monat haben wir einen neuen Meilenstein erreicht: zwei Milliarden Aufrufe auf unserem YouTube-Kanal."
|
26 |
+
},
|
27 |
+
"el": {
|
28 |
+
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/el_m.flac",
|
29 |
+
"text": "Τον περασμένο μήνα, φτάσαμε σε ένα νέο ορόσημο με δύο δισεκατομμύρια προβολές στο κανάλι μας στο YouTube."
|
30 |
+
},
|
31 |
+
"en": {
|
32 |
+
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/en_f1.flac",
|
33 |
+
"text": "Last month, we reached a new milestone with two billion views on our YouTube channel."
|
34 |
+
},
|
35 |
+
"es": {
|
36 |
+
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/es_f1.flac",
|
37 |
+
"text": "El mes pasado alcanzamos un nuevo hito: dos mil millones de visualizaciones en nuestro canal de YouTube."
|
38 |
+
},
|
39 |
+
"fi": {
|
40 |
+
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/fi_m.flac",
|
41 |
+
"text": "Viime kuussa saavutimme uuden virstanpylvään kahden miljardin katselukerran kanssa YouTube-kanavallamme."
|
42 |
+
},
|
43 |
+
"fr": {
|
44 |
+
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/fr_f1.flac",
|
45 |
+
"text": "Le mois dernier, nous avons atteint un nouveau jalon avec deux milliards de vues sur notre chaîne YouTube."
|
46 |
+
},
|
47 |
+
"he": {
|
48 |
+
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/he_m1.flac",
|
49 |
+
"text": "בחודש שעבר הגענו לאבן דרך חדשה עם שני מיליארד צפיות בערוץ היוטיוב שלנו."
|
50 |
+
},
|
51 |
+
"hi": {
|
52 |
+
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/hi_f1.flac",
|
53 |
+
"text": "पिछले महीने हमने एक नया मील का पत्थर छुआ: हमारे YouTube चैनल पर दो अरब व्यूज़।"
|
54 |
+
},
|
55 |
+
"it": {
|
56 |
+
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/it_m1.flac",
|
57 |
+
"text": "Il mese scorso abbiamo raggiunto un nuovo traguardo: due miliardi di visualizzazioni sul nostro canale YouTube."
|
58 |
+
},
|
59 |
+
"ja": {
|
60 |
+
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ja_f.flac",
|
61 |
+
"text": "先月、私たちのYouTubeチャンネルで二十億回の再生回数という新たなマイルストーンに到達しました。"
|
62 |
+
},
|
63 |
+
"ko": {
|
64 |
+
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ko_f.flac",
|
65 |
+
"text": "지난달 우리는 유튜브 채널에서 이십억 조회수라는 새로운 이정표에 도달했습니다."
|
66 |
+
},
|
67 |
+
"ms": {
|
68 |
+
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ms_f.flac",
|
69 |
+
"text": "Bulan lepas, kami mencapai pencapaian baru dengan dua bilion tontonan di saluran YouTube kami."
|
70 |
+
},
|
71 |
+
"nl": {
|
72 |
+
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/nl_m.flac",
|
73 |
+
"text": "Vorige maand bereikten we een nieuwe mijlpaal met twee miljard weergaven op ons YouTube-kanaal."
|
74 |
+
},
|
75 |
+
"no": {
|
76 |
+
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/no_f1.flac",
|
77 |
+
"text": "Forrige måned nådde vi en ny milepæl med to milliarder visninger på YouTube-kanalen vår."
|
78 |
+
},
|
79 |
+
"pl": {
|
80 |
+
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/pl_m.flac",
|
81 |
+
"text": "W zeszłym miesiącu osiągnęliśmy nowy kamień milowy z dwoma miliardami wyświetleń na naszym kanale YouTube."
|
82 |
+
},
|
83 |
+
"pt": {
|
84 |
+
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/pt_m1.flac",
|
85 |
+
"text": "No mês passado, alcançámos um novo marco: dois mil milhões de visualizações no nosso canal do YouTube."
|
86 |
+
},
|
87 |
+
"ru": {
|
88 |
+
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ru_m.flac",
|
89 |
+
"text": "В прошлом месяце мы достигли нового рубежа: два миллиарда просмотров на нашем YouTube-канале."
|
90 |
+
},
|
91 |
+
"sv": {
|
92 |
+
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/sv_f.flac",
|
93 |
+
"text": "Förra månaden nådde vi en ny milstolpe med två miljarder visningar på vår YouTube-kanal."
|
94 |
+
},
|
95 |
+
"sw": {
|
96 |
+
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/sw_m.flac",
|
97 |
+
"text": "Mwezi uliopita, tulifika hatua mpya ya maoni ya bilioni mbili kweny kituo chetu cha YouTube."
|
98 |
+
},
|
99 |
+
"tr": {
|
100 |
+
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/tr_m.flac",
|
101 |
+
"text": "Geçen ay YouTube kanalımızda iki milyar görüntüleme ile yeni bir dönüm noktasına ulaştık."
|
102 |
+
},
|
103 |
+
"zh": {
|
104 |
+
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/zh_f.flac",
|
105 |
+
"text": "上个月,我们达到了一个新的里程碑,我们的YouTube频道观看次数达到了二十亿次,这绝对令人难以置信。"
|
106 |
+
},
|
107 |
+
}
|
108 |
+
|
109 |
+
# --- UI Helpers ---
|
110 |
+
def default_audio_for_ui(lang: str) -> str | None:
|
111 |
+
return LANGUAGE_CONFIG.get(lang, {}).get("audio")
|
112 |
+
|
113 |
+
|
114 |
+
def default_text_for_ui(lang: str) -> str:
|
115 |
+
return LANGUAGE_CONFIG.get(lang, {}).get("text", "")
|
116 |
+
|
117 |
+
|
118 |
+
def get_supported_languages_display() -> str:
|
119 |
+
"""Generate a formatted display of all supported languages."""
|
120 |
+
language_items = []
|
121 |
+
for code, name in sorted(SUPPORTED_LANGUAGES.items()):
|
122 |
+
language_items.append(f"**{name}** (`{code}`)")
|
123 |
+
|
124 |
+
# Split into 2 lines
|
125 |
+
mid = len(language_items) // 2
|
126 |
+
line1 = " • ".join(language_items[:mid])
|
127 |
+
line2 = " • ".join(language_items[mid:])
|
128 |
+
|
129 |
+
return f"""
|
130 |
+
### 🌍 Supported Languages ({len(SUPPORTED_LANGUAGES)} total)
|
131 |
+
{line1}
|
132 |
+
|
133 |
+
{line2}
|
134 |
+
"""
|
135 |
+
|
136 |
+
|
137 |
+
def get_or_load_model():
|
138 |
+
"""Loads the ChatterboxMultilingualTTS model if it hasn't been loaded already,
|
139 |
+
and ensures it's on the correct device."""
|
140 |
+
global MODEL
|
141 |
+
if MODEL is None:
|
142 |
+
print("Model not loaded, initializing...")
|
143 |
+
try:
|
144 |
+
MODEL = ChatterboxMultilingualTTS.from_pretrained(DEVICE)
|
145 |
+
if hasattr(MODEL, 'to') and str(MODEL.device) != DEVICE:
|
146 |
+
MODEL.to(DEVICE)
|
147 |
+
print(f"Model loaded successfully. Internal device: {getattr(MODEL, 'device', 'N/A')}")
|
148 |
+
except Exception as e:
|
149 |
+
print(f"Error loading model: {e}")
|
150 |
+
raise
|
151 |
+
return MODEL
|
152 |
+
|
153 |
+
# Attempt to load the model at startup.
|
154 |
+
try:
|
155 |
+
get_or_load_model()
|
156 |
+
except Exception as e:
|
157 |
+
print(f"CRITICAL: Failed to load model on startup. Application may not function. Error: {e}")
|
158 |
+
|
159 |
+
def set_seed(seed: int):
|
160 |
+
"""Sets the random seed for reproducibility across torch, numpy, and random."""
|
161 |
+
torch.manual_seed(seed)
|
162 |
+
if DEVICE == "cuda":
|
163 |
+
torch.cuda.manual_seed(seed)
|
164 |
+
torch.cuda.manual_seed_all(seed)
|
165 |
+
random.seed(seed)
|
166 |
+
np.random.seed(seed)
|
167 |
+
|
168 |
+
def resolve_audio_prompt(language_id: str, provided_path: str | None) -> str | None:
|
169 |
+
"""
|
170 |
+
Decide which audio prompt to use:
|
171 |
+
- If user provided a path (upload/mic/url), use it.
|
172 |
+
- Else, fall back to language-specific default (if any).
|
173 |
+
"""
|
174 |
+
if provided_path and str(provided_path).strip():
|
175 |
+
return provided_path
|
176 |
+
return LANGUAGE_CONFIG.get(language_id, {}).get("audio")
|
177 |
+
|
178 |
+
|
179 |
+
@spaces.GPU
|
180 |
+
def generate_tts_audio(
|
181 |
+
text_input: str,
|
182 |
+
language_id: str,
|
183 |
+
audio_prompt_path_input: str = None,
|
184 |
+
exaggeration_input: float = 0.5,
|
185 |
+
temperature_input: float = 0.8,
|
186 |
+
seed_num_input: int = 0,
|
187 |
+
cfgw_input: float = 0.5
|
188 |
+
) -> tuple[int, np.ndarray]:
|
189 |
+
"""
|
190 |
+
Generate high-quality speech audio from text using Chatterbox Multilingual model with optional reference audio styling.
|
191 |
+
Supported languages: English, French, German, Spanish, Italian, Portuguese, and Hindi.
|
192 |
+
|
193 |
+
This tool synthesizes natural-sounding speech from input text. When a reference audio file
|
194 |
+
is provided, it captures the speaker's voice characteristics and speaking style. The generated audio
|
195 |
+
maintains the prosody, tone, and vocal qualities of the reference speaker, or uses default voice if no reference is provided.
|
196 |
+
|
197 |
+
Args:
|
198 |
+
text_input (str): The text to synthesize into speech (maximum 300 characters)
|
199 |
+
language_id (str): The language code for synthesis (eg. en, fr, de, es, it, pt, hi)
|
200 |
+
audio_prompt_path_input (str, optional): File path or URL to the reference audio file that defines the target voice style. Defaults to None.
|
201 |
+
exaggeration_input (float, optional): Controls speech expressiveness (0.25-2.0, neutral=0.5, extreme values may be unstable). Defaults to 0.5.
|
202 |
+
temperature_input (float, optional): Controls randomness in generation (0.05-5.0, higher=more varied). Defaults to 0.8.
|
203 |
+
seed_num_input (int, optional): Random seed for reproducible results (0 for random generation). Defaults to 0.
|
204 |
+
cfgw_input (float, optional): CFG/Pace weight controlling generation guidance (0.2-1.0). Defaults to 0.5, 0 for language transfer.
|
205 |
+
|
206 |
+
Returns:
|
207 |
+
tuple[int, np.ndarray]: A tuple containing the sample rate (int) and the generated audio waveform (numpy.ndarray)
|
208 |
+
"""
|
209 |
+
current_model = get_or_load_model()
|
210 |
+
|
211 |
+
if current_model is None:
|
212 |
+
raise RuntimeError("TTS model is not loaded.")
|
213 |
+
|
214 |
+
if seed_num_input != 0:
|
215 |
+
set_seed(int(seed_num_input))
|
216 |
+
|
217 |
+
print(f"Generating audio for text: '{text_input[:50]}...'")
|
218 |
+
|
219 |
+
# Handle optional audio prompt
|
220 |
+
chosen_prompt = audio_prompt_path_input or default_audio_for_ui(language_id)
|
221 |
+
|
222 |
+
generate_kwargs = {
|
223 |
+
"exaggeration": exaggeration_input,
|
224 |
+
"temperature": temperature_input,
|
225 |
+
"cfg_weight": cfgw_input,
|
226 |
+
}
|
227 |
+
if chosen_prompt:
|
228 |
+
generate_kwargs["audio_prompt_path"] = chosen_prompt
|
229 |
+
print(f"Using audio prompt: {chosen_prompt}")
|
230 |
+
else:
|
231 |
+
print("No audio prompt provided; using default voice.")
|
232 |
+
|
233 |
+
wav = current_model.generate(
|
234 |
+
text_input[:300], # Truncate text to max chars
|
235 |
+
language_id=language_id,
|
236 |
+
**generate_kwargs
|
237 |
+
)
|
238 |
+
print("Audio generation complete.")
|
239 |
+
return (current_model.sr, wav.squeeze(0).numpy())
|
240 |
+
|
241 |
+
with gr.Blocks() as demo:
|
242 |
+
gr.Markdown(
|
243 |
+
"""
|
244 |
+
# Chatterbox Multilingual Demo
|
245 |
+
Generate high-quality multilingual speech from text with reference audio styling, supporting 23 languages.
|
246 |
+
"""
|
247 |
+
)
|
248 |
+
|
249 |
+
# Display supported languages
|
250 |
+
gr.Markdown(get_supported_languages_display())
|
251 |
+
with gr.Row():
|
252 |
+
with gr.Column():
|
253 |
+
initial_lang = "fr"
|
254 |
+
text = gr.Textbox(
|
255 |
+
value=default_text_for_ui(initial_lang),
|
256 |
+
label="Text to synthesize (max chars 300)",
|
257 |
+
max_lines=5
|
258 |
+
)
|
259 |
+
|
260 |
+
language_id = gr.Dropdown(
|
261 |
+
choices=list(ChatterboxMultilingualTTS.get_supported_languages().keys()),
|
262 |
+
value=initial_lang,
|
263 |
+
label="Language",
|
264 |
+
info="Select the language for text-to-speech synthesis"
|
265 |
+
)
|
266 |
+
|
267 |
+
ref_wav = gr.Audio(
|
268 |
+
sources=["upload", "microphone"],
|
269 |
+
type="filepath",
|
270 |
+
label="Reference Audio File (Optional)",
|
271 |
+
value=default_audio_for_ui(initial_lang)
|
272 |
+
)
|
273 |
+
|
274 |
+
gr.Markdown(
|
275 |
+
"💡 **Note**: Ensure that the reference clip matches the specified language tag. Otherwise, language transfer outputs may inherit the accent of the reference clip's language. To mitigate this, set the CFG weight to 0.",
|
276 |
+
elem_classes=["audio-note"]
|
277 |
+
)
|
278 |
+
|
279 |
+
exaggeration = gr.Slider(
|
280 |
+
0.25, 2, step=.05, label="Exaggeration (Neutral = 0.5, extreme values can be unstable)", value=.5
|
281 |
+
)
|
282 |
+
cfg_weight = gr.Slider(
|
283 |
+
0.2, 1, step=.05, label="CFG/Pace", value=0.5
|
284 |
+
)
|
285 |
+
|
286 |
+
with gr.Accordion("More options", open=False):
|
287 |
+
seed_num = gr.Number(value=0, label="Random seed (0 for random)")
|
288 |
+
temp = gr.Slider(0.05, 5, step=.05, label="Temperature", value=.8)
|
289 |
+
|
290 |
+
run_btn = gr.Button("Generate", variant="primary")
|
291 |
+
|
292 |
+
with gr.Column():
|
293 |
+
audio_output = gr.Audio(label="Output Audio")
|
294 |
+
|
295 |
+
def on_language_change(lang, current_ref, current_text):
|
296 |
+
return default_audio_for_ui(lang), default_text_for_ui(lang)
|
297 |
+
|
298 |
+
language_id.change(
|
299 |
+
fn=on_language_change,
|
300 |
+
inputs=[language_id, ref_wav, text],
|
301 |
+
outputs=[ref_wav, text],
|
302 |
+
show_progress=False
|
303 |
+
)
|
304 |
+
|
305 |
+
run_btn.click(
|
306 |
+
fn=generate_tts_audio,
|
307 |
+
inputs=[
|
308 |
+
text,
|
309 |
+
language_id,
|
310 |
+
ref_wav,
|
311 |
+
exaggeration,
|
312 |
+
temp,
|
313 |
+
seed_num,
|
314 |
+
cfg_weight,
|
315 |
+
],
|
316 |
+
outputs=[audio_output],
|
317 |
+
)
|
318 |
+
|
319 |
+
demo.launch(mcp_server=True)
|
requirements.txt
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio
|
2 |
+
numpy==1.26.0
|
3 |
+
resampy==0.4.3
|
4 |
+
librosa==0.10.0
|
5 |
+
s3tokenizer
|
6 |
+
transformers==4.46.3
|
7 |
+
diffusers==0.29.0
|
8 |
+
omegaconf==2.3.0
|
9 |
+
resemble-perth==1.0.1
|
10 |
+
silero-vad==5.1.2
|
11 |
+
conformer==0.3.2
|
12 |
+
safetensors
|
13 |
+
|
14 |
+
# Optional language-specific dependencies
|
15 |
+
# Uncomment the ones you need for specific languages:
|
16 |
+
# pkuseg # For Chinese text segmentation (improves mixed text handling)
|
17 |
+
# pykakasi>=2.2.0 # For Japanese text processing (Kanji to Hiragana)
|
18 |
+
# dicta-onnx>=0.1.0 # For Hebrew diacritization
|
src/chatterbox/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
try:
|
2 |
+
from importlib.metadata import version
|
3 |
+
except ImportError:
|
4 |
+
from importlib_metadata import version # For Python <3.8
|
5 |
+
|
6 |
+
__version__ = version("chatterbox-tts")
|
7 |
+
|
8 |
+
|
9 |
+
from .tts import ChatterboxTTS
|
10 |
+
from .vc import ChatterboxVC
|
11 |
+
from .mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES
|
src/chatterbox/models/__init__.py
ADDED
File without changes
|
src/chatterbox/models/s3gen/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .s3gen import S3Token2Wav as S3Gen
|
2 |
+
from .const import S3GEN_SR
|
src/chatterbox/models/s3gen/configs.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ..utils import AttrDict
|
2 |
+
|
3 |
+
CFM_PARAMS = AttrDict({
|
4 |
+
"sigma_min": 1e-06,
|
5 |
+
"solver": "euler",
|
6 |
+
"t_scheduler": "cosine",
|
7 |
+
"training_cfg_rate": 0.2,
|
8 |
+
"inference_cfg_rate": 0.7,
|
9 |
+
"reg_loss_type": "l1"
|
10 |
+
})
|
src/chatterbox/models/s3gen/const.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
S3GEN_SR = 24000
|
src/chatterbox/models/s3gen/decoder.py
ADDED
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
import torch.nn.functional as F
|
17 |
+
from einops import pack, rearrange, repeat
|
18 |
+
|
19 |
+
from .utils.mask import add_optional_chunk_mask
|
20 |
+
from .matcha.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, \
|
21 |
+
TimestepEmbedding, Upsample1D
|
22 |
+
from .matcha.transformer import BasicTransformerBlock
|
23 |
+
|
24 |
+
|
25 |
+
def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
26 |
+
assert mask.dtype == torch.bool
|
27 |
+
assert dtype in [torch.float32, torch.bfloat16, torch.float16]
|
28 |
+
mask = mask.to(dtype)
|
29 |
+
# attention mask bias
|
30 |
+
# NOTE(Mddct): torch.finfo jit issues
|
31 |
+
# chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
|
32 |
+
mask = (1.0 - mask) * -1.0e+10
|
33 |
+
return mask
|
34 |
+
|
35 |
+
|
36 |
+
|
37 |
+
class Transpose(torch.nn.Module):
|
38 |
+
def __init__(self, dim0: int, dim1: int):
|
39 |
+
super().__init__()
|
40 |
+
self.dim0 = dim0
|
41 |
+
self.dim1 = dim1
|
42 |
+
|
43 |
+
def forward(self, x: torch.Tensor):
|
44 |
+
x = torch.transpose(x, self.dim0, self.dim1)
|
45 |
+
return x
|
46 |
+
|
47 |
+
|
48 |
+
class CausalBlock1D(Block1D):
|
49 |
+
def __init__(self, dim: int, dim_out: int):
|
50 |
+
super(CausalBlock1D, self).__init__(dim, dim_out)
|
51 |
+
self.block = torch.nn.Sequential(
|
52 |
+
CausalConv1d(dim, dim_out, 3),
|
53 |
+
Transpose(1, 2),
|
54 |
+
nn.LayerNorm(dim_out),
|
55 |
+
Transpose(1, 2),
|
56 |
+
nn.Mish(),
|
57 |
+
)
|
58 |
+
|
59 |
+
def forward(self, x: torch.Tensor, mask: torch.Tensor):
|
60 |
+
output = self.block(x * mask)
|
61 |
+
return output * mask
|
62 |
+
|
63 |
+
|
64 |
+
class CausalResnetBlock1D(ResnetBlock1D):
|
65 |
+
def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
|
66 |
+
super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
|
67 |
+
self.block1 = CausalBlock1D(dim, dim_out)
|
68 |
+
self.block2 = CausalBlock1D(dim_out, dim_out)
|
69 |
+
|
70 |
+
|
71 |
+
class CausalConv1d(torch.nn.Conv1d):
|
72 |
+
def __init__(
|
73 |
+
self,
|
74 |
+
in_channels: int,
|
75 |
+
out_channels: int,
|
76 |
+
kernel_size: int,
|
77 |
+
stride: int = 1,
|
78 |
+
dilation: int = 1,
|
79 |
+
groups: int = 1,
|
80 |
+
bias: bool = True,
|
81 |
+
padding_mode: str = 'zeros',
|
82 |
+
device=None,
|
83 |
+
dtype=None
|
84 |
+
) -> None:
|
85 |
+
super(CausalConv1d, self).__init__(in_channels, out_channels,
|
86 |
+
kernel_size, stride,
|
87 |
+
padding=0, dilation=dilation,
|
88 |
+
groups=groups, bias=bias,
|
89 |
+
padding_mode=padding_mode,
|
90 |
+
device=device, dtype=dtype)
|
91 |
+
assert stride == 1
|
92 |
+
self.causal_padding = (kernel_size - 1, 0)
|
93 |
+
|
94 |
+
def forward(self, x: torch.Tensor):
|
95 |
+
x = F.pad(x, self.causal_padding)
|
96 |
+
x = super(CausalConv1d, self).forward(x)
|
97 |
+
return x
|
98 |
+
|
99 |
+
|
100 |
+
class ConditionalDecoder(nn.Module):
|
101 |
+
def __init__(
|
102 |
+
self,
|
103 |
+
in_channels=320,
|
104 |
+
out_channels=80,
|
105 |
+
causal=True,
|
106 |
+
channels=[256],
|
107 |
+
dropout=0.0,
|
108 |
+
attention_head_dim=64,
|
109 |
+
n_blocks=4,
|
110 |
+
num_mid_blocks=12,
|
111 |
+
num_heads=8,
|
112 |
+
act_fn="gelu",
|
113 |
+
):
|
114 |
+
"""
|
115 |
+
This decoder requires an input with the same shape of the target. So, if your text content
|
116 |
+
is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
|
117 |
+
"""
|
118 |
+
super().__init__()
|
119 |
+
channels = tuple(channels)
|
120 |
+
self.in_channels = in_channels
|
121 |
+
self.out_channels = out_channels
|
122 |
+
self.causal = causal
|
123 |
+
self.time_embeddings = SinusoidalPosEmb(in_channels)
|
124 |
+
time_embed_dim = channels[0] * 4
|
125 |
+
self.time_mlp = TimestepEmbedding(
|
126 |
+
in_channels=in_channels,
|
127 |
+
time_embed_dim=time_embed_dim,
|
128 |
+
act_fn="silu",
|
129 |
+
)
|
130 |
+
self.down_blocks = nn.ModuleList([])
|
131 |
+
self.mid_blocks = nn.ModuleList([])
|
132 |
+
self.up_blocks = nn.ModuleList([])
|
133 |
+
|
134 |
+
# NOTE jrm: `static_chunk_size` is missing?
|
135 |
+
self.static_chunk_size = 0
|
136 |
+
|
137 |
+
output_channel = in_channels
|
138 |
+
for i in range(len(channels)): # pylint: disable=consider-using-enumerate
|
139 |
+
input_channel = output_channel
|
140 |
+
output_channel = channels[i]
|
141 |
+
is_last = i == len(channels) - 1
|
142 |
+
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
|
143 |
+
ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
144 |
+
transformer_blocks = nn.ModuleList(
|
145 |
+
[
|
146 |
+
BasicTransformerBlock(
|
147 |
+
dim=output_channel,
|
148 |
+
num_attention_heads=num_heads,
|
149 |
+
attention_head_dim=attention_head_dim,
|
150 |
+
dropout=dropout,
|
151 |
+
activation_fn=act_fn,
|
152 |
+
)
|
153 |
+
for _ in range(n_blocks)
|
154 |
+
]
|
155 |
+
)
|
156 |
+
downsample = (
|
157 |
+
Downsample1D(output_channel) if not is_last else
|
158 |
+
CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
159 |
+
)
|
160 |
+
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
|
161 |
+
|
162 |
+
for _ in range(num_mid_blocks):
|
163 |
+
input_channel = channels[-1]
|
164 |
+
out_channels = channels[-1]
|
165 |
+
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
|
166 |
+
ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
167 |
+
|
168 |
+
transformer_blocks = nn.ModuleList(
|
169 |
+
[
|
170 |
+
BasicTransformerBlock(
|
171 |
+
dim=output_channel,
|
172 |
+
num_attention_heads=num_heads,
|
173 |
+
attention_head_dim=attention_head_dim,
|
174 |
+
dropout=dropout,
|
175 |
+
activation_fn=act_fn,
|
176 |
+
)
|
177 |
+
for _ in range(n_blocks)
|
178 |
+
]
|
179 |
+
)
|
180 |
+
|
181 |
+
self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
|
182 |
+
|
183 |
+
channels = channels[::-1] + (channels[0],)
|
184 |
+
for i in range(len(channels) - 1):
|
185 |
+
input_channel = channels[i] * 2
|
186 |
+
output_channel = channels[i + 1]
|
187 |
+
is_last = i == len(channels) - 2
|
188 |
+
resnet = CausalResnetBlock1D(
|
189 |
+
dim=input_channel,
|
190 |
+
dim_out=output_channel,
|
191 |
+
time_emb_dim=time_embed_dim,
|
192 |
+
) if self.causal else ResnetBlock1D(
|
193 |
+
dim=input_channel,
|
194 |
+
dim_out=output_channel,
|
195 |
+
time_emb_dim=time_embed_dim,
|
196 |
+
)
|
197 |
+
transformer_blocks = nn.ModuleList(
|
198 |
+
[
|
199 |
+
BasicTransformerBlock(
|
200 |
+
dim=output_channel,
|
201 |
+
num_attention_heads=num_heads,
|
202 |
+
attention_head_dim=attention_head_dim,
|
203 |
+
dropout=dropout,
|
204 |
+
activation_fn=act_fn,
|
205 |
+
)
|
206 |
+
for _ in range(n_blocks)
|
207 |
+
]
|
208 |
+
)
|
209 |
+
upsample = (
|
210 |
+
Upsample1D(output_channel, use_conv_transpose=True)
|
211 |
+
if not is_last
|
212 |
+
else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
213 |
+
)
|
214 |
+
self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
|
215 |
+
self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
|
216 |
+
self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
|
217 |
+
self.initialize_weights()
|
218 |
+
|
219 |
+
def initialize_weights(self):
|
220 |
+
for m in self.modules():
|
221 |
+
if isinstance(m, nn.Conv1d):
|
222 |
+
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
|
223 |
+
if m.bias is not None:
|
224 |
+
nn.init.constant_(m.bias, 0)
|
225 |
+
elif isinstance(m, nn.GroupNorm):
|
226 |
+
nn.init.constant_(m.weight, 1)
|
227 |
+
nn.init.constant_(m.bias, 0)
|
228 |
+
elif isinstance(m, nn.Linear):
|
229 |
+
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
|
230 |
+
if m.bias is not None:
|
231 |
+
nn.init.constant_(m.bias, 0)
|
232 |
+
|
233 |
+
def forward(self, x, mask, mu, t, spks=None, cond=None):
|
234 |
+
"""Forward pass of the UNet1DConditional model.
|
235 |
+
|
236 |
+
Args:
|
237 |
+
x (torch.Tensor): shape (batch_size, in_channels, time)
|
238 |
+
mask (_type_): shape (batch_size, 1, time)
|
239 |
+
t (_type_): shape (batch_size)
|
240 |
+
spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
|
241 |
+
cond (_type_, optional): placeholder for future use. Defaults to None.
|
242 |
+
|
243 |
+
Raises:
|
244 |
+
ValueError: _description_
|
245 |
+
ValueError: _description_
|
246 |
+
|
247 |
+
Returns:
|
248 |
+
_type_: _description_
|
249 |
+
"""
|
250 |
+
|
251 |
+
t = self.time_embeddings(t).to(t.dtype)
|
252 |
+
t = self.time_mlp(t)
|
253 |
+
|
254 |
+
x = pack([x, mu], "b * t")[0]
|
255 |
+
|
256 |
+
if spks is not None:
|
257 |
+
spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
|
258 |
+
x = pack([x, spks], "b * t")[0]
|
259 |
+
if cond is not None:
|
260 |
+
x = pack([x, cond], "b * t")[0]
|
261 |
+
|
262 |
+
hiddens = []
|
263 |
+
masks = [mask]
|
264 |
+
for resnet, transformer_blocks, downsample in self.down_blocks:
|
265 |
+
mask_down = masks[-1]
|
266 |
+
x = resnet(x, mask_down, t)
|
267 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
268 |
+
# attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
|
269 |
+
attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
|
270 |
+
attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
|
271 |
+
for transformer_block in transformer_blocks:
|
272 |
+
x = transformer_block(
|
273 |
+
hidden_states=x,
|
274 |
+
attention_mask=attn_mask,
|
275 |
+
timestep=t,
|
276 |
+
)
|
277 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
278 |
+
hiddens.append(x) # Save hidden states for skip connections
|
279 |
+
x = downsample(x * mask_down)
|
280 |
+
masks.append(mask_down[:, :, ::2])
|
281 |
+
masks = masks[:-1]
|
282 |
+
mask_mid = masks[-1]
|
283 |
+
|
284 |
+
for resnet, transformer_blocks in self.mid_blocks:
|
285 |
+
x = resnet(x, mask_mid, t)
|
286 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
287 |
+
# attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
|
288 |
+
attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
|
289 |
+
attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
|
290 |
+
for transformer_block in transformer_blocks:
|
291 |
+
x = transformer_block(
|
292 |
+
hidden_states=x,
|
293 |
+
attention_mask=attn_mask,
|
294 |
+
timestep=t,
|
295 |
+
)
|
296 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
297 |
+
|
298 |
+
for resnet, transformer_blocks, upsample in self.up_blocks:
|
299 |
+
mask_up = masks.pop()
|
300 |
+
skip = hiddens.pop()
|
301 |
+
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
|
302 |
+
x = resnet(x, mask_up, t)
|
303 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
304 |
+
# attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
|
305 |
+
attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
|
306 |
+
attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
|
307 |
+
for transformer_block in transformer_blocks:
|
308 |
+
x = transformer_block(
|
309 |
+
hidden_states=x,
|
310 |
+
attention_mask=attn_mask,
|
311 |
+
timestep=t,
|
312 |
+
)
|
313 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
314 |
+
x = upsample(x * mask_up)
|
315 |
+
x = self.final_block(x, mask_up)
|
316 |
+
output = self.final_proj(x * mask_up)
|
317 |
+
return output * mask
|
src/chatterbox/models/s3gen/f0_predictor.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
from torch.nn.utils.parametrizations import weight_norm
|
17 |
+
|
18 |
+
|
19 |
+
class ConvRNNF0Predictor(nn.Module):
|
20 |
+
def __init__(self,
|
21 |
+
num_class: int = 1,
|
22 |
+
in_channels: int = 80,
|
23 |
+
cond_channels: int = 512
|
24 |
+
):
|
25 |
+
super().__init__()
|
26 |
+
|
27 |
+
self.num_class = num_class
|
28 |
+
self.condnet = nn.Sequential(
|
29 |
+
weight_norm(
|
30 |
+
nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
|
31 |
+
),
|
32 |
+
nn.ELU(),
|
33 |
+
weight_norm(
|
34 |
+
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
|
35 |
+
),
|
36 |
+
nn.ELU(),
|
37 |
+
weight_norm(
|
38 |
+
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
|
39 |
+
),
|
40 |
+
nn.ELU(),
|
41 |
+
weight_norm(
|
42 |
+
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
|
43 |
+
),
|
44 |
+
nn.ELU(),
|
45 |
+
weight_norm(
|
46 |
+
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
|
47 |
+
),
|
48 |
+
nn.ELU(),
|
49 |
+
)
|
50 |
+
self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
|
51 |
+
|
52 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
53 |
+
x = self.condnet(x)
|
54 |
+
x = x.transpose(1, 2)
|
55 |
+
return torch.abs(self.classifier(x).squeeze(-1))
|
src/chatterbox/models/s3gen/flow.py
ADDED
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import logging
|
15 |
+
import random
|
16 |
+
from typing import Dict, Optional
|
17 |
+
|
18 |
+
logger = logging.getLogger(__name__)
|
19 |
+
import torch
|
20 |
+
import torch.nn as nn
|
21 |
+
from torch.nn import functional as F
|
22 |
+
from .utils.mask import make_pad_mask
|
23 |
+
from .configs import CFM_PARAMS
|
24 |
+
|
25 |
+
|
26 |
+
class MaskedDiffWithXvec(torch.nn.Module):
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
input_size: int = 512,
|
30 |
+
output_size: int = 80,
|
31 |
+
spk_embed_dim: int = 192,
|
32 |
+
output_type: str = "mel",
|
33 |
+
vocab_size: int = 4096,
|
34 |
+
input_frame_rate: int = 50,
|
35 |
+
only_mask_loss: bool = True,
|
36 |
+
encoder: torch.nn.Module = None,
|
37 |
+
length_regulator: torch.nn.Module = None,
|
38 |
+
decoder: torch.nn.Module = None,
|
39 |
+
decoder_conf: Dict = {
|
40 |
+
'in_channels': 240,
|
41 |
+
'out_channel': 80,
|
42 |
+
'spk_emb_dim': 80,
|
43 |
+
'n_spks': 1,
|
44 |
+
'cfm_params': CFM_PARAMS,
|
45 |
+
'decoder_params': {
|
46 |
+
'channels': [256, 256],
|
47 |
+
'dropout': 0.0,
|
48 |
+
'attention_head_dim': 64,
|
49 |
+
'n_blocks': 4,
|
50 |
+
'num_mid_blocks': 12,
|
51 |
+
'num_heads': 8,
|
52 |
+
'act_fn': 'gelu',
|
53 |
+
}
|
54 |
+
},
|
55 |
+
mel_feat_conf: Dict = {
|
56 |
+
'n_fft': 1024,
|
57 |
+
'num_mels': 80,
|
58 |
+
'sampling_rate': 22050,
|
59 |
+
'hop_size': 256,
|
60 |
+
'win_size': 1024,
|
61 |
+
'fmin': 0,
|
62 |
+
'fmax': 8000
|
63 |
+
}
|
64 |
+
):
|
65 |
+
super().__init__()
|
66 |
+
self.input_size = input_size
|
67 |
+
self.output_size = output_size
|
68 |
+
self.decoder_conf = decoder_conf
|
69 |
+
self.mel_feat_conf = mel_feat_conf
|
70 |
+
self.vocab_size = vocab_size
|
71 |
+
self.output_type = output_type
|
72 |
+
self.input_frame_rate = input_frame_rate
|
73 |
+
logging.info(f"input frame rate={self.input_frame_rate}")
|
74 |
+
self.input_embedding = nn.Embedding(vocab_size, input_size)
|
75 |
+
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
|
76 |
+
self.encoder = encoder
|
77 |
+
self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
|
78 |
+
self.decoder = decoder
|
79 |
+
self.length_regulator = length_regulator
|
80 |
+
self.only_mask_loss = only_mask_loss
|
81 |
+
|
82 |
+
def forward(
|
83 |
+
self,
|
84 |
+
batch: dict,
|
85 |
+
device: torch.device,
|
86 |
+
) -> Dict[str, Optional[torch.Tensor]]:
|
87 |
+
token = batch['speech_token'].to(device)
|
88 |
+
token_len = batch['speech_token_len'].to(device)
|
89 |
+
feat = batch['speech_feat'].to(device)
|
90 |
+
feat_len = batch['speech_feat_len'].to(device)
|
91 |
+
embedding = batch['embedding'].to(device)
|
92 |
+
|
93 |
+
# xvec projection
|
94 |
+
embedding = F.normalize(embedding, dim=1)
|
95 |
+
embedding = self.spk_embed_affine_layer(embedding)
|
96 |
+
|
97 |
+
# concat text and prompt_text
|
98 |
+
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
|
99 |
+
token = self.input_embedding(torch.clamp(token, min=0, max=self.input_embedding.num_embeddings-1)) * mask
|
100 |
+
|
101 |
+
# text encode
|
102 |
+
h, h_lengths = self.encoder(token, token_len)
|
103 |
+
h = self.encoder_proj(h)
|
104 |
+
h, h_lengths = self.length_regulator(h, feat_len)
|
105 |
+
|
106 |
+
# get conditions
|
107 |
+
conds = torch.zeros(feat.shape, device=token.device)
|
108 |
+
for i, j in enumerate(feat_len):
|
109 |
+
if random.random() < 0.5:
|
110 |
+
continue
|
111 |
+
index = random.randint(0, int(0.3 * j))
|
112 |
+
conds[i, :index] = feat[i, :index]
|
113 |
+
conds = conds.transpose(1, 2)
|
114 |
+
|
115 |
+
mask = (~make_pad_mask(feat_len)).to(h)
|
116 |
+
feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
|
117 |
+
loss, _ = self.decoder.compute_loss(
|
118 |
+
feat.transpose(1, 2).contiguous(),
|
119 |
+
mask.unsqueeze(1),
|
120 |
+
h.transpose(1, 2).contiguous(),
|
121 |
+
embedding,
|
122 |
+
cond=conds
|
123 |
+
)
|
124 |
+
return {'loss': loss}
|
125 |
+
|
126 |
+
@torch.inference_mode()
|
127 |
+
def inference(self,
|
128 |
+
token,
|
129 |
+
token_len,
|
130 |
+
prompt_token,
|
131 |
+
prompt_token_len,
|
132 |
+
prompt_feat,
|
133 |
+
prompt_feat_len,
|
134 |
+
embedding,
|
135 |
+
flow_cache):
|
136 |
+
if self.fp16 is True:
|
137 |
+
prompt_feat = prompt_feat.half()
|
138 |
+
embedding = embedding.half()
|
139 |
+
|
140 |
+
assert token.shape[0] == 1
|
141 |
+
# xvec projection
|
142 |
+
embedding = F.normalize(embedding, dim=1)
|
143 |
+
embedding = self.spk_embed_affine_layer(embedding)
|
144 |
+
|
145 |
+
# concat text and prompt_text
|
146 |
+
token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
|
147 |
+
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
|
148 |
+
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
|
149 |
+
|
150 |
+
# Check for out-of-bounds token IDs
|
151 |
+
vocab_size = self.input_embedding.num_embeddings
|
152 |
+
if token.max() >= vocab_size or token.min() < 0:
|
153 |
+
logging.warning(f"S3Gen: Token IDs out of bounds: min={token.min().item()}, max={token.max().item()}, vocab_size={vocab_size}")
|
154 |
+
|
155 |
+
token = self.input_embedding(torch.clamp(token, min=0, max=vocab_size-1)) * mask
|
156 |
+
|
157 |
+
# text encode
|
158 |
+
h, h_lengths = self.encoder(token, token_len)
|
159 |
+
h = self.encoder_proj(h)
|
160 |
+
mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
|
161 |
+
h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
|
162 |
+
|
163 |
+
# get conditions
|
164 |
+
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
|
165 |
+
conds[:, :mel_len1] = prompt_feat
|
166 |
+
conds = conds.transpose(1, 2)
|
167 |
+
|
168 |
+
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
|
169 |
+
feat, flow_cache = self.decoder(
|
170 |
+
mu=h.transpose(1, 2).contiguous(),
|
171 |
+
mask=mask.unsqueeze(1),
|
172 |
+
spks=embedding,
|
173 |
+
cond=conds,
|
174 |
+
n_timesteps=10,
|
175 |
+
prompt_len=mel_len1,
|
176 |
+
flow_cache=flow_cache
|
177 |
+
)
|
178 |
+
feat = feat[:, :, mel_len1:]
|
179 |
+
assert feat.shape[2] == mel_len2
|
180 |
+
return feat.float(), flow_cache
|
181 |
+
|
182 |
+
|
183 |
+
class CausalMaskedDiffWithXvec(torch.nn.Module):
|
184 |
+
def __init__(
|
185 |
+
self,
|
186 |
+
input_size: int = 512,
|
187 |
+
output_size: int = 80,
|
188 |
+
spk_embed_dim: int = 192,
|
189 |
+
output_type: str = "mel",
|
190 |
+
vocab_size: int = 6561,
|
191 |
+
input_frame_rate: int = 25,
|
192 |
+
only_mask_loss: bool = True,
|
193 |
+
token_mel_ratio: int = 2,
|
194 |
+
pre_lookahead_len: int = 3,
|
195 |
+
encoder: torch.nn.Module = None,
|
196 |
+
decoder: torch.nn.Module = None,
|
197 |
+
decoder_conf: Dict = {
|
198 |
+
'in_channels': 240,
|
199 |
+
'out_channel': 80,
|
200 |
+
'spk_emb_dim': 80,
|
201 |
+
'n_spks': 1,
|
202 |
+
'cfm_params': CFM_PARAMS,
|
203 |
+
'decoder_params': {
|
204 |
+
'channels': [256, 256],
|
205 |
+
'dropout': 0.0,
|
206 |
+
'attention_head_dim': 64,
|
207 |
+
'n_blocks': 4,
|
208 |
+
'num_mid_blocks': 12,
|
209 |
+
'num_heads': 8,
|
210 |
+
'act_fn': 'gelu',
|
211 |
+
}
|
212 |
+
},
|
213 |
+
mel_feat_conf: Dict = {
|
214 |
+
'n_fft': 1024,
|
215 |
+
'num_mels': 80,
|
216 |
+
'sampling_rate': 22050,
|
217 |
+
'hop_size': 256,
|
218 |
+
'win_size': 1024,
|
219 |
+
'fmin': 0,
|
220 |
+
'fmax': 8000
|
221 |
+
}
|
222 |
+
):
|
223 |
+
super().__init__()
|
224 |
+
self.input_size = input_size
|
225 |
+
self.output_size = output_size
|
226 |
+
self.decoder_conf = decoder_conf
|
227 |
+
self.mel_feat_conf = mel_feat_conf
|
228 |
+
self.vocab_size = vocab_size
|
229 |
+
self.output_type = output_type
|
230 |
+
self.input_frame_rate = input_frame_rate
|
231 |
+
logging.info(f"input frame rate={self.input_frame_rate}")
|
232 |
+
self.input_embedding = nn.Embedding(vocab_size, input_size)
|
233 |
+
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
|
234 |
+
self.encoder = encoder
|
235 |
+
self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
|
236 |
+
self.decoder = decoder
|
237 |
+
self.only_mask_loss = only_mask_loss
|
238 |
+
self.token_mel_ratio = token_mel_ratio
|
239 |
+
self.pre_lookahead_len = pre_lookahead_len
|
240 |
+
|
241 |
+
# FIXME: this was missing - just putting it in as false
|
242 |
+
self.fp16 = False
|
243 |
+
|
244 |
+
@torch.inference_mode()
|
245 |
+
def inference(self,
|
246 |
+
token,
|
247 |
+
token_len,
|
248 |
+
prompt_token,
|
249 |
+
prompt_token_len,
|
250 |
+
prompt_feat,
|
251 |
+
prompt_feat_len,
|
252 |
+
embedding,
|
253 |
+
finalize):
|
254 |
+
if self.fp16 is True:
|
255 |
+
prompt_feat = prompt_feat.half()
|
256 |
+
embedding = embedding.half()
|
257 |
+
|
258 |
+
assert token.shape[0] == 1
|
259 |
+
# xvec projection
|
260 |
+
embedding = F.normalize(embedding, dim=1)
|
261 |
+
embedding = self.spk_embed_affine_layer(embedding)
|
262 |
+
|
263 |
+
# concat text and prompt_text
|
264 |
+
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
|
265 |
+
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
|
266 |
+
token = self.input_embedding(torch.clamp(token, min=0, max=self.input_embedding.num_embeddings-1)) * mask
|
267 |
+
|
268 |
+
# text encode
|
269 |
+
h, h_lengths = self.encoder(token, token_len)
|
270 |
+
if finalize is False:
|
271 |
+
h = h[:, :-self.pre_lookahead_len * self.token_mel_ratio]
|
272 |
+
mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
|
273 |
+
h = self.encoder_proj(h)
|
274 |
+
|
275 |
+
# get conditions
|
276 |
+
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
|
277 |
+
conds[:, :mel_len1] = prompt_feat
|
278 |
+
conds = conds.transpose(1, 2)
|
279 |
+
|
280 |
+
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
|
281 |
+
feat, _ = self.decoder(
|
282 |
+
mu=h.transpose(1, 2).contiguous(),
|
283 |
+
mask=mask.unsqueeze(1),
|
284 |
+
spks=embedding,
|
285 |
+
cond=conds,
|
286 |
+
n_timesteps=10
|
287 |
+
)
|
288 |
+
feat = feat[:, :, mel_len1:]
|
289 |
+
assert feat.shape[2] == mel_len2
|
290 |
+
return feat.float(), None # NOTE jrm: why are they returning None here?
|
src/chatterbox/models/s3gen/flow_matching.py
ADDED
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import threading
|
15 |
+
import torch
|
16 |
+
import torch.nn.functional as F
|
17 |
+
from .matcha.flow_matching import BASECFM
|
18 |
+
from .configs import CFM_PARAMS
|
19 |
+
|
20 |
+
|
21 |
+
class ConditionalCFM(BASECFM):
|
22 |
+
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
|
23 |
+
super().__init__(
|
24 |
+
n_feats=in_channels,
|
25 |
+
cfm_params=cfm_params,
|
26 |
+
n_spks=n_spks,
|
27 |
+
spk_emb_dim=spk_emb_dim,
|
28 |
+
)
|
29 |
+
self.t_scheduler = cfm_params.t_scheduler
|
30 |
+
self.training_cfg_rate = cfm_params.training_cfg_rate
|
31 |
+
self.inference_cfg_rate = cfm_params.inference_cfg_rate
|
32 |
+
in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
|
33 |
+
# Just change the architecture of the estimator here
|
34 |
+
self.estimator = estimator
|
35 |
+
self.lock = threading.Lock()
|
36 |
+
|
37 |
+
@torch.inference_mode()
|
38 |
+
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
|
39 |
+
"""Forward diffusion
|
40 |
+
|
41 |
+
Args:
|
42 |
+
mu (torch.Tensor): output of encoder
|
43 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
44 |
+
mask (torch.Tensor): output_mask
|
45 |
+
shape: (batch_size, 1, mel_timesteps)
|
46 |
+
n_timesteps (int): number of diffusion steps
|
47 |
+
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
48 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
49 |
+
shape: (batch_size, spk_emb_dim)
|
50 |
+
cond: Not used but kept for future purposes
|
51 |
+
|
52 |
+
Returns:
|
53 |
+
sample: generated mel-spectrogram
|
54 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
55 |
+
"""
|
56 |
+
|
57 |
+
z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
|
58 |
+
cache_size = flow_cache.shape[2]
|
59 |
+
# fix prompt and overlap part mu and z
|
60 |
+
if cache_size != 0:
|
61 |
+
z[:, :, :cache_size] = flow_cache[:, :, :, 0]
|
62 |
+
mu[:, :, :cache_size] = flow_cache[:, :, :, 1]
|
63 |
+
z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
|
64 |
+
mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
|
65 |
+
flow_cache = torch.stack([z_cache, mu_cache], dim=-1)
|
66 |
+
|
67 |
+
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
|
68 |
+
if self.t_scheduler == 'cosine':
|
69 |
+
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
|
70 |
+
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache
|
71 |
+
|
72 |
+
def solve_euler(self, x, t_span, mu, mask, spks, cond):
|
73 |
+
"""
|
74 |
+
Fixed euler solver for ODEs.
|
75 |
+
Args:
|
76 |
+
x (torch.Tensor): random noise
|
77 |
+
t_span (torch.Tensor): n_timesteps interpolated
|
78 |
+
shape: (n_timesteps + 1,)
|
79 |
+
mu (torch.Tensor): output of encoder
|
80 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
81 |
+
mask (torch.Tensor): output_mask
|
82 |
+
shape: (batch_size, 1, mel_timesteps)
|
83 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
84 |
+
shape: (batch_size, spk_emb_dim)
|
85 |
+
cond: Not used but kept for future purposes
|
86 |
+
"""
|
87 |
+
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
88 |
+
t = t.unsqueeze(dim=0)
|
89 |
+
|
90 |
+
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
|
91 |
+
# Or in future might add like a return_all_steps flag
|
92 |
+
sol = []
|
93 |
+
|
94 |
+
# Do not use concat, it may cause memory format changed and trt infer with wrong results!
|
95 |
+
x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
96 |
+
mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
|
97 |
+
mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
98 |
+
t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
|
99 |
+
spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
|
100 |
+
cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
101 |
+
for step in range(1, len(t_span)):
|
102 |
+
# Classifier-Free Guidance inference introduced in VoiceBox
|
103 |
+
x_in[:] = x
|
104 |
+
mask_in[:] = mask
|
105 |
+
mu_in[0] = mu
|
106 |
+
t_in[:] = t.unsqueeze(0)
|
107 |
+
spks_in[0] = spks
|
108 |
+
cond_in[0] = cond
|
109 |
+
dphi_dt = self.forward_estimator(
|
110 |
+
x_in, mask_in,
|
111 |
+
mu_in, t_in,
|
112 |
+
spks_in,
|
113 |
+
cond_in
|
114 |
+
)
|
115 |
+
dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
|
116 |
+
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
|
117 |
+
x = x + dt * dphi_dt
|
118 |
+
t = t + dt
|
119 |
+
sol.append(x)
|
120 |
+
if step < len(t_span) - 1:
|
121 |
+
dt = t_span[step + 1] - t
|
122 |
+
|
123 |
+
return sol[-1].float()
|
124 |
+
|
125 |
+
def forward_estimator(self, x, mask, mu, t, spks, cond):
|
126 |
+
if isinstance(self.estimator, torch.nn.Module):
|
127 |
+
return self.estimator.forward(x, mask, mu, t, spks, cond)
|
128 |
+
else:
|
129 |
+
with self.lock:
|
130 |
+
self.estimator.set_input_shape('x', (2, 80, x.size(2)))
|
131 |
+
self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
|
132 |
+
self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
|
133 |
+
self.estimator.set_input_shape('t', (2,))
|
134 |
+
self.estimator.set_input_shape('spks', (2, 80))
|
135 |
+
self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
|
136 |
+
# run trt engine
|
137 |
+
self.estimator.execute_v2([x.contiguous().data_ptr(),
|
138 |
+
mask.contiguous().data_ptr(),
|
139 |
+
mu.contiguous().data_ptr(),
|
140 |
+
t.contiguous().data_ptr(),
|
141 |
+
spks.contiguous().data_ptr(),
|
142 |
+
cond.contiguous().data_ptr(),
|
143 |
+
x.data_ptr()])
|
144 |
+
return x
|
145 |
+
|
146 |
+
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
|
147 |
+
"""Computes diffusion loss
|
148 |
+
|
149 |
+
Args:
|
150 |
+
x1 (torch.Tensor): Target
|
151 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
152 |
+
mask (torch.Tensor): target mask
|
153 |
+
shape: (batch_size, 1, mel_timesteps)
|
154 |
+
mu (torch.Tensor): output of encoder
|
155 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
156 |
+
spks (torch.Tensor, optional): speaker embedding. Defaults to None.
|
157 |
+
shape: (batch_size, spk_emb_dim)
|
158 |
+
|
159 |
+
Returns:
|
160 |
+
loss: conditional flow matching loss
|
161 |
+
y: conditional flow
|
162 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
163 |
+
"""
|
164 |
+
b, _, t = mu.shape
|
165 |
+
|
166 |
+
# random timestep
|
167 |
+
t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
|
168 |
+
if self.t_scheduler == 'cosine':
|
169 |
+
t = 1 - torch.cos(t * 0.5 * torch.pi)
|
170 |
+
# sample noise p(x_0)
|
171 |
+
z = torch.randn_like(x1)
|
172 |
+
|
173 |
+
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
|
174 |
+
u = x1 - (1 - self.sigma_min) * z
|
175 |
+
|
176 |
+
# during training, we randomly drop condition to trade off mode coverage and sample fidelity
|
177 |
+
if self.training_cfg_rate > 0:
|
178 |
+
cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
|
179 |
+
mu = mu * cfg_mask.view(-1, 1, 1)
|
180 |
+
spks = spks * cfg_mask.view(-1, 1)
|
181 |
+
cond = cond * cfg_mask.view(-1, 1, 1)
|
182 |
+
|
183 |
+
pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
|
184 |
+
loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
|
185 |
+
return loss, y
|
186 |
+
|
187 |
+
|
188 |
+
class CausalConditionalCFM(ConditionalCFM):
|
189 |
+
def __init__(self, in_channels=240, cfm_params=CFM_PARAMS, n_spks=1, spk_emb_dim=80, estimator=None):
|
190 |
+
super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
|
191 |
+
self.rand_noise = torch.randn([1, 80, 50 * 300])
|
192 |
+
|
193 |
+
@torch.inference_mode()
|
194 |
+
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
|
195 |
+
"""Forward diffusion
|
196 |
+
|
197 |
+
Args:
|
198 |
+
mu (torch.Tensor): output of encoder
|
199 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
200 |
+
mask (torch.Tensor): output_mask
|
201 |
+
shape: (batch_size, 1, mel_timesteps)
|
202 |
+
n_timesteps (int): number of diffusion steps
|
203 |
+
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
204 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
205 |
+
shape: (batch_size, spk_emb_dim)
|
206 |
+
cond: Not used but kept for future purposes
|
207 |
+
|
208 |
+
Returns:
|
209 |
+
sample: generated mel-spectrogram
|
210 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
211 |
+
"""
|
212 |
+
|
213 |
+
z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature
|
214 |
+
# fix prompt and overlap part mu and z
|
215 |
+
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
|
216 |
+
if self.t_scheduler == 'cosine':
|
217 |
+
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
|
218 |
+
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None
|
src/chatterbox/models/s3gen/hifigan.py
ADDED
@@ -0,0 +1,474 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# jrm: adapted from CosyVoice/cosyvoice/hifigan/generator.py
|
2 |
+
# most modules should be reusable, but I found their SineGen changed a git.
|
3 |
+
|
4 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
|
18 |
+
"""HIFI-GAN"""
|
19 |
+
|
20 |
+
from typing import Dict, Optional, List
|
21 |
+
import numpy as np
|
22 |
+
from scipy.signal import get_window
|
23 |
+
import torch
|
24 |
+
import torch.nn.functional as F
|
25 |
+
from torch.nn import Conv1d
|
26 |
+
from torch.nn import ConvTranspose1d
|
27 |
+
from torch.nn.utils import remove_weight_norm
|
28 |
+
from torch.nn.utils.parametrizations import weight_norm
|
29 |
+
from torch.distributions.uniform import Uniform
|
30 |
+
from torch import nn, sin, pow
|
31 |
+
from torch.nn import Parameter
|
32 |
+
|
33 |
+
|
34 |
+
class Snake(nn.Module):
|
35 |
+
'''
|
36 |
+
Implementation of a sine-based periodic activation function
|
37 |
+
Shape:
|
38 |
+
- Input: (B, C, T)
|
39 |
+
- Output: (B, C, T), same shape as the input
|
40 |
+
Parameters:
|
41 |
+
- alpha - trainable parameter
|
42 |
+
References:
|
43 |
+
- This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
44 |
+
https://arxiv.org/abs/2006.08195
|
45 |
+
Examples:
|
46 |
+
>>> a1 = snake(256)
|
47 |
+
>>> x = torch.randn(256)
|
48 |
+
>>> x = a1(x)
|
49 |
+
'''
|
50 |
+
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
51 |
+
'''
|
52 |
+
Initialization.
|
53 |
+
INPUT:
|
54 |
+
- in_features: shape of the input
|
55 |
+
- alpha: trainable parameter
|
56 |
+
alpha is initialized to 1 by default, higher values = higher-frequency.
|
57 |
+
alpha will be trained along with the rest of your model.
|
58 |
+
'''
|
59 |
+
super(Snake, self).__init__()
|
60 |
+
self.in_features = in_features
|
61 |
+
|
62 |
+
# initialize alpha
|
63 |
+
self.alpha_logscale = alpha_logscale
|
64 |
+
if self.alpha_logscale: # log scale alphas initialized to zeros
|
65 |
+
self.alpha = Parameter(torch.zeros(in_features) * alpha)
|
66 |
+
else: # linear scale alphas initialized to ones
|
67 |
+
self.alpha = Parameter(torch.ones(in_features) * alpha)
|
68 |
+
|
69 |
+
self.alpha.requires_grad = alpha_trainable
|
70 |
+
|
71 |
+
self.no_div_by_zero = 0.000000001
|
72 |
+
|
73 |
+
def forward(self, x):
|
74 |
+
'''
|
75 |
+
Forward pass of the function.
|
76 |
+
Applies the function to the input elementwise.
|
77 |
+
Snake ∶= x + 1/a * sin^2 (xa)
|
78 |
+
'''
|
79 |
+
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
80 |
+
if self.alpha_logscale:
|
81 |
+
alpha = torch.exp(alpha)
|
82 |
+
x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
83 |
+
|
84 |
+
return x
|
85 |
+
|
86 |
+
|
87 |
+
|
88 |
+
def get_padding(kernel_size, dilation=1):
|
89 |
+
return int((kernel_size * dilation - dilation) / 2)
|
90 |
+
|
91 |
+
def init_weights(m, mean=0.0, std=0.01):
|
92 |
+
classname = m.__class__.__name__
|
93 |
+
if classname.find("Conv") != -1:
|
94 |
+
m.weight.data.normal_(mean, std)
|
95 |
+
|
96 |
+
|
97 |
+
"""hifigan based generator implementation.
|
98 |
+
|
99 |
+
This code is modified from https://github.com/jik876/hifi-gan
|
100 |
+
,https://github.com/kan-bayashi/ParallelWaveGAN and
|
101 |
+
https://github.com/NVIDIA/BigVGAN
|
102 |
+
|
103 |
+
"""
|
104 |
+
|
105 |
+
|
106 |
+
class ResBlock(torch.nn.Module):
|
107 |
+
"""Residual block module in HiFiGAN/BigVGAN."""
|
108 |
+
def __init__(
|
109 |
+
self,
|
110 |
+
channels: int = 512,
|
111 |
+
kernel_size: int = 3,
|
112 |
+
dilations: List[int] = [1, 3, 5],
|
113 |
+
):
|
114 |
+
super(ResBlock, self).__init__()
|
115 |
+
self.convs1 = nn.ModuleList()
|
116 |
+
self.convs2 = nn.ModuleList()
|
117 |
+
|
118 |
+
for dilation in dilations:
|
119 |
+
self.convs1.append(
|
120 |
+
weight_norm(
|
121 |
+
Conv1d(
|
122 |
+
channels,
|
123 |
+
channels,
|
124 |
+
kernel_size,
|
125 |
+
1,
|
126 |
+
dilation=dilation,
|
127 |
+
padding=get_padding(kernel_size, dilation)
|
128 |
+
)
|
129 |
+
)
|
130 |
+
)
|
131 |
+
self.convs2.append(
|
132 |
+
weight_norm(
|
133 |
+
Conv1d(
|
134 |
+
channels,
|
135 |
+
channels,
|
136 |
+
kernel_size,
|
137 |
+
1,
|
138 |
+
dilation=1,
|
139 |
+
padding=get_padding(kernel_size, 1)
|
140 |
+
)
|
141 |
+
)
|
142 |
+
)
|
143 |
+
self.convs1.apply(init_weights)
|
144 |
+
self.convs2.apply(init_weights)
|
145 |
+
self.activations1 = nn.ModuleList([
|
146 |
+
Snake(channels, alpha_logscale=False)
|
147 |
+
for _ in range(len(self.convs1))
|
148 |
+
])
|
149 |
+
self.activations2 = nn.ModuleList([
|
150 |
+
Snake(channels, alpha_logscale=False)
|
151 |
+
for _ in range(len(self.convs2))
|
152 |
+
])
|
153 |
+
|
154 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
155 |
+
for idx in range(len(self.convs1)):
|
156 |
+
xt = self.activations1[idx](x)
|
157 |
+
xt = self.convs1[idx](xt)
|
158 |
+
xt = self.activations2[idx](xt)
|
159 |
+
xt = self.convs2[idx](xt)
|
160 |
+
x = xt + x
|
161 |
+
return x
|
162 |
+
|
163 |
+
def remove_weight_norm(self):
|
164 |
+
for idx in range(len(self.convs1)):
|
165 |
+
remove_weight_norm(self.convs1[idx])
|
166 |
+
remove_weight_norm(self.convs2[idx])
|
167 |
+
|
168 |
+
|
169 |
+
class SineGen(torch.nn.Module):
|
170 |
+
""" Definition of sine generator
|
171 |
+
SineGen(samp_rate, harmonic_num = 0,
|
172 |
+
sine_amp = 0.1, noise_std = 0.003,
|
173 |
+
voiced_threshold = 0,
|
174 |
+
flag_for_pulse=False)
|
175 |
+
samp_rate: sampling rate in Hz
|
176 |
+
harmonic_num: number of harmonic overtones (default 0)
|
177 |
+
sine_amp: amplitude of sine-wavefrom (default 0.1)
|
178 |
+
noise_std: std of Gaussian noise (default 0.003)
|
179 |
+
voiced_thoreshold: F0 threshold for U/V classification (default 0)
|
180 |
+
flag_for_pulse: this SinGen is used inside PulseGen (default False)
|
181 |
+
Note: when flag_for_pulse is True, the first time step of a voiced
|
182 |
+
segment is always sin(np.pi) or cos(0)
|
183 |
+
"""
|
184 |
+
|
185 |
+
def __init__(self, samp_rate, harmonic_num=0,
|
186 |
+
sine_amp=0.1, noise_std=0.003,
|
187 |
+
voiced_threshold=0):
|
188 |
+
super(SineGen, self).__init__()
|
189 |
+
self.sine_amp = sine_amp
|
190 |
+
self.noise_std = noise_std
|
191 |
+
self.harmonic_num = harmonic_num
|
192 |
+
self.sampling_rate = samp_rate
|
193 |
+
self.voiced_threshold = voiced_threshold
|
194 |
+
|
195 |
+
def _f02uv(self, f0):
|
196 |
+
# generate uv signal
|
197 |
+
uv = (f0 > self.voiced_threshold).type(torch.float32)
|
198 |
+
return uv
|
199 |
+
|
200 |
+
@torch.no_grad()
|
201 |
+
def forward(self, f0):
|
202 |
+
"""
|
203 |
+
:param f0: [B, 1, sample_len], Hz
|
204 |
+
:return: [B, 1, sample_len]
|
205 |
+
"""
|
206 |
+
|
207 |
+
F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
|
208 |
+
for i in range(self.harmonic_num + 1):
|
209 |
+
F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
|
210 |
+
|
211 |
+
theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
|
212 |
+
u_dist = Uniform(low=-np.pi, high=np.pi)
|
213 |
+
phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
|
214 |
+
phase_vec[:, 0, :] = 0
|
215 |
+
|
216 |
+
# generate sine waveforms
|
217 |
+
sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)
|
218 |
+
|
219 |
+
# generate uv signal
|
220 |
+
uv = self._f02uv(f0)
|
221 |
+
|
222 |
+
# noise: for unvoiced should be similar to sine_amp
|
223 |
+
# std = self.sine_amp/3 -> max value ~ self.sine_amp
|
224 |
+
# . for voiced regions is self.noise_std
|
225 |
+
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
|
226 |
+
noise = noise_amp * torch.randn_like(sine_waves)
|
227 |
+
|
228 |
+
# first: set the unvoiced part to 0 by uv
|
229 |
+
# then: additive noise
|
230 |
+
sine_waves = sine_waves * uv + noise
|
231 |
+
return sine_waves, uv, noise
|
232 |
+
|
233 |
+
|
234 |
+
class SourceModuleHnNSF(torch.nn.Module):
|
235 |
+
""" SourceModule for hn-nsf
|
236 |
+
SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
|
237 |
+
add_noise_std=0.003, voiced_threshod=0)
|
238 |
+
sampling_rate: sampling_rate in Hz
|
239 |
+
harmonic_num: number of harmonic above F0 (default: 0)
|
240 |
+
sine_amp: amplitude of sine source signal (default: 0.1)
|
241 |
+
add_noise_std: std of additive Gaussian noise (default: 0.003)
|
242 |
+
note that amplitude of noise in unvoiced is decided
|
243 |
+
by sine_amp
|
244 |
+
voiced_threshold: threhold to set U/V given F0 (default: 0)
|
245 |
+
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
246 |
+
F0_sampled (batchsize, length, 1)
|
247 |
+
Sine_source (batchsize, length, 1)
|
248 |
+
noise_source (batchsize, length 1)
|
249 |
+
uv (batchsize, length, 1)
|
250 |
+
"""
|
251 |
+
|
252 |
+
def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
|
253 |
+
add_noise_std=0.003, voiced_threshod=0):
|
254 |
+
super(SourceModuleHnNSF, self).__init__()
|
255 |
+
|
256 |
+
self.sine_amp = sine_amp
|
257 |
+
self.noise_std = add_noise_std
|
258 |
+
|
259 |
+
# to produce sine waveforms
|
260 |
+
self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
|
261 |
+
sine_amp, add_noise_std, voiced_threshod)
|
262 |
+
|
263 |
+
# to merge source harmonics into a single excitation
|
264 |
+
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
|
265 |
+
self.l_tanh = torch.nn.Tanh()
|
266 |
+
|
267 |
+
def forward(self, x):
|
268 |
+
"""
|
269 |
+
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
270 |
+
F0_sampled (batchsize, length, 1)
|
271 |
+
Sine_source (batchsize, length, 1)
|
272 |
+
noise_source (batchsize, length 1)
|
273 |
+
"""
|
274 |
+
# source for harmonic branch
|
275 |
+
with torch.no_grad():
|
276 |
+
sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
|
277 |
+
sine_wavs = sine_wavs.transpose(1, 2)
|
278 |
+
uv = uv.transpose(1, 2)
|
279 |
+
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
|
280 |
+
|
281 |
+
# source for noise branch, in the same shape as uv
|
282 |
+
noise = torch.randn_like(uv) * self.sine_amp / 3
|
283 |
+
return sine_merge, noise, uv
|
284 |
+
|
285 |
+
|
286 |
+
class HiFTGenerator(nn.Module):
|
287 |
+
"""
|
288 |
+
HiFTNet Generator: Neural Source Filter + ISTFTNet
|
289 |
+
https://arxiv.org/abs/2309.09493
|
290 |
+
"""
|
291 |
+
def __init__(
|
292 |
+
self,
|
293 |
+
in_channels: int = 80,
|
294 |
+
base_channels: int = 512,
|
295 |
+
nb_harmonics: int = 8,
|
296 |
+
sampling_rate: int = 22050,
|
297 |
+
nsf_alpha: float = 0.1,
|
298 |
+
nsf_sigma: float = 0.003,
|
299 |
+
nsf_voiced_threshold: float = 10,
|
300 |
+
upsample_rates: List[int] = [8, 8],
|
301 |
+
upsample_kernel_sizes: List[int] = [16, 16],
|
302 |
+
istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
|
303 |
+
resblock_kernel_sizes: List[int] = [3, 7, 11],
|
304 |
+
resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
305 |
+
source_resblock_kernel_sizes: List[int] = [7, 11],
|
306 |
+
source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
|
307 |
+
lrelu_slope: float = 0.1,
|
308 |
+
audio_limit: float = 0.99,
|
309 |
+
f0_predictor: torch.nn.Module = None,
|
310 |
+
):
|
311 |
+
super(HiFTGenerator, self).__init__()
|
312 |
+
|
313 |
+
self.out_channels = 1
|
314 |
+
self.nb_harmonics = nb_harmonics
|
315 |
+
self.sampling_rate = sampling_rate
|
316 |
+
self.istft_params = istft_params
|
317 |
+
self.lrelu_slope = lrelu_slope
|
318 |
+
self.audio_limit = audio_limit
|
319 |
+
|
320 |
+
self.num_kernels = len(resblock_kernel_sizes)
|
321 |
+
self.num_upsamples = len(upsample_rates)
|
322 |
+
self.m_source = SourceModuleHnNSF(
|
323 |
+
sampling_rate=sampling_rate,
|
324 |
+
upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
|
325 |
+
harmonic_num=nb_harmonics,
|
326 |
+
sine_amp=nsf_alpha,
|
327 |
+
add_noise_std=nsf_sigma,
|
328 |
+
voiced_threshod=nsf_voiced_threshold)
|
329 |
+
self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
|
330 |
+
|
331 |
+
self.conv_pre = weight_norm(
|
332 |
+
Conv1d(in_channels, base_channels, 7, 1, padding=3)
|
333 |
+
)
|
334 |
+
|
335 |
+
# Up
|
336 |
+
self.ups = nn.ModuleList()
|
337 |
+
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
338 |
+
self.ups.append(
|
339 |
+
weight_norm(
|
340 |
+
ConvTranspose1d(
|
341 |
+
base_channels // (2**i),
|
342 |
+
base_channels // (2**(i + 1)),
|
343 |
+
k,
|
344 |
+
u,
|
345 |
+
padding=(k - u) // 2,
|
346 |
+
)
|
347 |
+
)
|
348 |
+
)
|
349 |
+
|
350 |
+
# Down
|
351 |
+
self.source_downs = nn.ModuleList()
|
352 |
+
self.source_resblocks = nn.ModuleList()
|
353 |
+
downsample_rates = [1] + upsample_rates[::-1][:-1]
|
354 |
+
downsample_cum_rates = np.cumprod(downsample_rates)
|
355 |
+
for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
|
356 |
+
if u == 1:
|
357 |
+
self.source_downs.append(
|
358 |
+
Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
|
359 |
+
)
|
360 |
+
else:
|
361 |
+
self.source_downs.append(
|
362 |
+
Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
|
363 |
+
)
|
364 |
+
|
365 |
+
self.source_resblocks.append(
|
366 |
+
ResBlock(base_channels // (2 ** (i + 1)), k, d)
|
367 |
+
)
|
368 |
+
|
369 |
+
self.resblocks = nn.ModuleList()
|
370 |
+
for i in range(len(self.ups)):
|
371 |
+
ch = base_channels // (2**(i + 1))
|
372 |
+
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
373 |
+
self.resblocks.append(ResBlock(ch, k, d))
|
374 |
+
|
375 |
+
self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
|
376 |
+
self.ups.apply(init_weights)
|
377 |
+
self.conv_post.apply(init_weights)
|
378 |
+
self.reflection_pad = nn.ReflectionPad1d((1, 0))
|
379 |
+
self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
|
380 |
+
self.f0_predictor = f0_predictor
|
381 |
+
|
382 |
+
def remove_weight_norm(self):
|
383 |
+
print('Removing weight norm...')
|
384 |
+
for l in self.ups:
|
385 |
+
remove_weight_norm(l)
|
386 |
+
for l in self.resblocks:
|
387 |
+
l.remove_weight_norm()
|
388 |
+
remove_weight_norm(self.conv_pre)
|
389 |
+
remove_weight_norm(self.conv_post)
|
390 |
+
self.m_source.remove_weight_norm()
|
391 |
+
for l in self.source_downs:
|
392 |
+
remove_weight_norm(l)
|
393 |
+
for l in self.source_resblocks:
|
394 |
+
l.remove_weight_norm()
|
395 |
+
|
396 |
+
def _stft(self, x):
|
397 |
+
spec = torch.stft(
|
398 |
+
x,
|
399 |
+
self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
|
400 |
+
return_complex=True)
|
401 |
+
spec = torch.view_as_real(spec) # [B, F, TT, 2]
|
402 |
+
return spec[..., 0], spec[..., 1]
|
403 |
+
|
404 |
+
def _istft(self, magnitude, phase):
|
405 |
+
magnitude = torch.clip(magnitude, max=1e2)
|
406 |
+
real = magnitude * torch.cos(phase)
|
407 |
+
img = magnitude * torch.sin(phase)
|
408 |
+
inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
|
409 |
+
self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
|
410 |
+
return inverse_transform
|
411 |
+
|
412 |
+
def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
|
413 |
+
s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
|
414 |
+
s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
|
415 |
+
|
416 |
+
x = self.conv_pre(x)
|
417 |
+
for i in range(self.num_upsamples):
|
418 |
+
x = F.leaky_relu(x, self.lrelu_slope)
|
419 |
+
x = self.ups[i](x)
|
420 |
+
|
421 |
+
if i == self.num_upsamples - 1:
|
422 |
+
x = self.reflection_pad(x)
|
423 |
+
|
424 |
+
# fusion
|
425 |
+
si = self.source_downs[i](s_stft)
|
426 |
+
si = self.source_resblocks[i](si)
|
427 |
+
x = x + si
|
428 |
+
|
429 |
+
xs = None
|
430 |
+
for j in range(self.num_kernels):
|
431 |
+
if xs is None:
|
432 |
+
xs = self.resblocks[i * self.num_kernels + j](x)
|
433 |
+
else:
|
434 |
+
xs += self.resblocks[i * self.num_kernels + j](x)
|
435 |
+
x = xs / self.num_kernels
|
436 |
+
|
437 |
+
x = F.leaky_relu(x)
|
438 |
+
x = self.conv_post(x)
|
439 |
+
magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
|
440 |
+
phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy
|
441 |
+
|
442 |
+
x = self._istft(magnitude, phase)
|
443 |
+
x = torch.clamp(x, -self.audio_limit, self.audio_limit)
|
444 |
+
return x
|
445 |
+
|
446 |
+
def forward(
|
447 |
+
self,
|
448 |
+
batch: dict,
|
449 |
+
device: torch.device,
|
450 |
+
) -> Dict[str, Optional[torch.Tensor]]:
|
451 |
+
speech_feat = batch['speech_feat'].transpose(1, 2).to(device)
|
452 |
+
# mel->f0
|
453 |
+
f0 = self.f0_predictor(speech_feat)
|
454 |
+
# f0->source
|
455 |
+
s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
456 |
+
s, _, _ = self.m_source(s)
|
457 |
+
s = s.transpose(1, 2)
|
458 |
+
# mel+source->speech
|
459 |
+
generated_speech = self.decode(x=speech_feat, s=s)
|
460 |
+
return generated_speech, f0
|
461 |
+
|
462 |
+
@torch.inference_mode()
|
463 |
+
def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
|
464 |
+
# mel->f0
|
465 |
+
f0 = self.f0_predictor(speech_feat)
|
466 |
+
# f0->source
|
467 |
+
s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
468 |
+
s, _, _ = self.m_source(s)
|
469 |
+
s = s.transpose(1, 2)
|
470 |
+
# use cache_source to avoid glitch
|
471 |
+
if cache_source.shape[2] != 0:
|
472 |
+
s[:, :, :cache_source.shape[2]] = cache_source
|
473 |
+
generated_speech = self.decode(x=speech_feat, s=s)
|
474 |
+
return generated_speech, s
|
src/chatterbox/models/s3gen/matcha/decoder.py
ADDED
@@ -0,0 +1,443 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from conformer import ConformerBlock
|
8 |
+
from diffusers.models.activations import get_activation
|
9 |
+
from einops import pack, rearrange, repeat
|
10 |
+
|
11 |
+
from .transformer import BasicTransformerBlock
|
12 |
+
|
13 |
+
|
14 |
+
class SinusoidalPosEmb(torch.nn.Module):
|
15 |
+
def __init__(self, dim):
|
16 |
+
super().__init__()
|
17 |
+
self.dim = dim
|
18 |
+
assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
|
19 |
+
|
20 |
+
def forward(self, x, scale=1000):
|
21 |
+
if x.ndim < 1:
|
22 |
+
x = x.unsqueeze(0)
|
23 |
+
device = x.device
|
24 |
+
half_dim = self.dim // 2
|
25 |
+
emb = math.log(10000) / (half_dim - 1)
|
26 |
+
emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
|
27 |
+
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
|
28 |
+
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
29 |
+
return emb
|
30 |
+
|
31 |
+
|
32 |
+
class Block1D(torch.nn.Module):
|
33 |
+
def __init__(self, dim, dim_out, groups=8):
|
34 |
+
super().__init__()
|
35 |
+
self.block = torch.nn.Sequential(
|
36 |
+
torch.nn.Conv1d(dim, dim_out, 3, padding=1),
|
37 |
+
torch.nn.GroupNorm(groups, dim_out),
|
38 |
+
nn.Mish(),
|
39 |
+
)
|
40 |
+
|
41 |
+
def forward(self, x, mask):
|
42 |
+
output = self.block(x * mask)
|
43 |
+
return output * mask
|
44 |
+
|
45 |
+
|
46 |
+
class ResnetBlock1D(torch.nn.Module):
|
47 |
+
def __init__(self, dim, dim_out, time_emb_dim, groups=8):
|
48 |
+
super().__init__()
|
49 |
+
self.mlp = torch.nn.Sequential(nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out))
|
50 |
+
|
51 |
+
self.block1 = Block1D(dim, dim_out, groups=groups)
|
52 |
+
self.block2 = Block1D(dim_out, dim_out, groups=groups)
|
53 |
+
|
54 |
+
self.res_conv = torch.nn.Conv1d(dim, dim_out, 1)
|
55 |
+
|
56 |
+
def forward(self, x, mask, time_emb):
|
57 |
+
h = self.block1(x, mask)
|
58 |
+
h += self.mlp(time_emb).unsqueeze(-1)
|
59 |
+
h = self.block2(h, mask)
|
60 |
+
output = h + self.res_conv(x * mask)
|
61 |
+
return output
|
62 |
+
|
63 |
+
|
64 |
+
class Downsample1D(nn.Module):
|
65 |
+
def __init__(self, dim):
|
66 |
+
super().__init__()
|
67 |
+
self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1)
|
68 |
+
|
69 |
+
def forward(self, x):
|
70 |
+
return self.conv(x)
|
71 |
+
|
72 |
+
|
73 |
+
class TimestepEmbedding(nn.Module):
|
74 |
+
def __init__(
|
75 |
+
self,
|
76 |
+
in_channels: int,
|
77 |
+
time_embed_dim: int,
|
78 |
+
act_fn: str = "silu",
|
79 |
+
out_dim: int = None,
|
80 |
+
post_act_fn: Optional[str] = None,
|
81 |
+
cond_proj_dim=None,
|
82 |
+
):
|
83 |
+
super().__init__()
|
84 |
+
|
85 |
+
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
|
86 |
+
|
87 |
+
if cond_proj_dim is not None:
|
88 |
+
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
|
89 |
+
else:
|
90 |
+
self.cond_proj = None
|
91 |
+
|
92 |
+
self.act = get_activation(act_fn)
|
93 |
+
|
94 |
+
if out_dim is not None:
|
95 |
+
time_embed_dim_out = out_dim
|
96 |
+
else:
|
97 |
+
time_embed_dim_out = time_embed_dim
|
98 |
+
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
|
99 |
+
|
100 |
+
if post_act_fn is None:
|
101 |
+
self.post_act = None
|
102 |
+
else:
|
103 |
+
self.post_act = get_activation(post_act_fn)
|
104 |
+
|
105 |
+
def forward(self, sample, condition=None):
|
106 |
+
if condition is not None:
|
107 |
+
sample = sample + self.cond_proj(condition)
|
108 |
+
sample = self.linear_1(sample)
|
109 |
+
|
110 |
+
if self.act is not None:
|
111 |
+
sample = self.act(sample)
|
112 |
+
|
113 |
+
sample = self.linear_2(sample)
|
114 |
+
|
115 |
+
if self.post_act is not None:
|
116 |
+
sample = self.post_act(sample)
|
117 |
+
return sample
|
118 |
+
|
119 |
+
|
120 |
+
class Upsample1D(nn.Module):
|
121 |
+
"""A 1D upsampling layer with an optional convolution.
|
122 |
+
|
123 |
+
Parameters:
|
124 |
+
channels (`int`):
|
125 |
+
number of channels in the inputs and outputs.
|
126 |
+
use_conv (`bool`, default `False`):
|
127 |
+
option to use a convolution.
|
128 |
+
use_conv_transpose (`bool`, default `False`):
|
129 |
+
option to use a convolution transpose.
|
130 |
+
out_channels (`int`, optional):
|
131 |
+
number of output channels. Defaults to `channels`.
|
132 |
+
"""
|
133 |
+
|
134 |
+
def __init__(self, channels, use_conv=False, use_conv_transpose=True, out_channels=None, name="conv"):
|
135 |
+
super().__init__()
|
136 |
+
self.channels = channels
|
137 |
+
self.out_channels = out_channels or channels
|
138 |
+
self.use_conv = use_conv
|
139 |
+
self.use_conv_transpose = use_conv_transpose
|
140 |
+
self.name = name
|
141 |
+
|
142 |
+
self.conv = None
|
143 |
+
if use_conv_transpose:
|
144 |
+
self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
|
145 |
+
elif use_conv:
|
146 |
+
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
|
147 |
+
|
148 |
+
def forward(self, inputs):
|
149 |
+
assert inputs.shape[1] == self.channels
|
150 |
+
if self.use_conv_transpose:
|
151 |
+
return self.conv(inputs)
|
152 |
+
|
153 |
+
outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
|
154 |
+
|
155 |
+
if self.use_conv:
|
156 |
+
outputs = self.conv(outputs)
|
157 |
+
|
158 |
+
return outputs
|
159 |
+
|
160 |
+
|
161 |
+
class ConformerWrapper(ConformerBlock):
|
162 |
+
def __init__( # pylint: disable=useless-super-delegation
|
163 |
+
self,
|
164 |
+
*,
|
165 |
+
dim,
|
166 |
+
dim_head=64,
|
167 |
+
heads=8,
|
168 |
+
ff_mult=4,
|
169 |
+
conv_expansion_factor=2,
|
170 |
+
conv_kernel_size=31,
|
171 |
+
attn_dropout=0,
|
172 |
+
ff_dropout=0,
|
173 |
+
conv_dropout=0,
|
174 |
+
conv_causal=False,
|
175 |
+
):
|
176 |
+
super().__init__(
|
177 |
+
dim=dim,
|
178 |
+
dim_head=dim_head,
|
179 |
+
heads=heads,
|
180 |
+
ff_mult=ff_mult,
|
181 |
+
conv_expansion_factor=conv_expansion_factor,
|
182 |
+
conv_kernel_size=conv_kernel_size,
|
183 |
+
attn_dropout=attn_dropout,
|
184 |
+
ff_dropout=ff_dropout,
|
185 |
+
conv_dropout=conv_dropout,
|
186 |
+
conv_causal=conv_causal,
|
187 |
+
)
|
188 |
+
|
189 |
+
def forward(
|
190 |
+
self,
|
191 |
+
hidden_states,
|
192 |
+
attention_mask,
|
193 |
+
encoder_hidden_states=None,
|
194 |
+
encoder_attention_mask=None,
|
195 |
+
timestep=None,
|
196 |
+
):
|
197 |
+
return super().forward(x=hidden_states, mask=attention_mask.bool())
|
198 |
+
|
199 |
+
|
200 |
+
class Decoder(nn.Module):
|
201 |
+
def __init__(
|
202 |
+
self,
|
203 |
+
in_channels,
|
204 |
+
out_channels,
|
205 |
+
channels=(256, 256),
|
206 |
+
dropout=0.05,
|
207 |
+
attention_head_dim=64,
|
208 |
+
n_blocks=1,
|
209 |
+
num_mid_blocks=2,
|
210 |
+
num_heads=4,
|
211 |
+
act_fn="snake",
|
212 |
+
down_block_type="transformer",
|
213 |
+
mid_block_type="transformer",
|
214 |
+
up_block_type="transformer",
|
215 |
+
):
|
216 |
+
super().__init__()
|
217 |
+
channels = tuple(channels)
|
218 |
+
self.in_channels = in_channels
|
219 |
+
self.out_channels = out_channels
|
220 |
+
|
221 |
+
self.time_embeddings = SinusoidalPosEmb(in_channels)
|
222 |
+
time_embed_dim = channels[0] * 4
|
223 |
+
self.time_mlp = TimestepEmbedding(
|
224 |
+
in_channels=in_channels,
|
225 |
+
time_embed_dim=time_embed_dim,
|
226 |
+
act_fn="silu",
|
227 |
+
)
|
228 |
+
|
229 |
+
self.down_blocks = nn.ModuleList([])
|
230 |
+
self.mid_blocks = nn.ModuleList([])
|
231 |
+
self.up_blocks = nn.ModuleList([])
|
232 |
+
|
233 |
+
output_channel = in_channels
|
234 |
+
for i in range(len(channels)): # pylint: disable=consider-using-enumerate
|
235 |
+
input_channel = output_channel
|
236 |
+
output_channel = channels[i]
|
237 |
+
is_last = i == len(channels) - 1
|
238 |
+
resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
239 |
+
transformer_blocks = nn.ModuleList(
|
240 |
+
[
|
241 |
+
self.get_block(
|
242 |
+
down_block_type,
|
243 |
+
output_channel,
|
244 |
+
attention_head_dim,
|
245 |
+
num_heads,
|
246 |
+
dropout,
|
247 |
+
act_fn,
|
248 |
+
)
|
249 |
+
for _ in range(n_blocks)
|
250 |
+
]
|
251 |
+
)
|
252 |
+
downsample = (
|
253 |
+
Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
254 |
+
)
|
255 |
+
|
256 |
+
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
|
257 |
+
|
258 |
+
for i in range(num_mid_blocks):
|
259 |
+
input_channel = channels[-1]
|
260 |
+
out_channels = channels[-1]
|
261 |
+
|
262 |
+
resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
263 |
+
|
264 |
+
transformer_blocks = nn.ModuleList(
|
265 |
+
[
|
266 |
+
self.get_block(
|
267 |
+
mid_block_type,
|
268 |
+
output_channel,
|
269 |
+
attention_head_dim,
|
270 |
+
num_heads,
|
271 |
+
dropout,
|
272 |
+
act_fn,
|
273 |
+
)
|
274 |
+
for _ in range(n_blocks)
|
275 |
+
]
|
276 |
+
)
|
277 |
+
|
278 |
+
self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
|
279 |
+
|
280 |
+
channels = channels[::-1] + (channels[0],)
|
281 |
+
for i in range(len(channels) - 1):
|
282 |
+
input_channel = channels[i]
|
283 |
+
output_channel = channels[i + 1]
|
284 |
+
is_last = i == len(channels) - 2
|
285 |
+
|
286 |
+
resnet = ResnetBlock1D(
|
287 |
+
dim=2 * input_channel,
|
288 |
+
dim_out=output_channel,
|
289 |
+
time_emb_dim=time_embed_dim,
|
290 |
+
)
|
291 |
+
transformer_blocks = nn.ModuleList(
|
292 |
+
[
|
293 |
+
self.get_block(
|
294 |
+
up_block_type,
|
295 |
+
output_channel,
|
296 |
+
attention_head_dim,
|
297 |
+
num_heads,
|
298 |
+
dropout,
|
299 |
+
act_fn,
|
300 |
+
)
|
301 |
+
for _ in range(n_blocks)
|
302 |
+
]
|
303 |
+
)
|
304 |
+
upsample = (
|
305 |
+
Upsample1D(output_channel, use_conv_transpose=True)
|
306 |
+
if not is_last
|
307 |
+
else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
308 |
+
)
|
309 |
+
|
310 |
+
self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
|
311 |
+
|
312 |
+
self.final_block = Block1D(channels[-1], channels[-1])
|
313 |
+
self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
|
314 |
+
|
315 |
+
self.initialize_weights()
|
316 |
+
# nn.init.normal_(self.final_proj.weight)
|
317 |
+
|
318 |
+
@staticmethod
|
319 |
+
def get_block(block_type, dim, attention_head_dim, num_heads, dropout, act_fn):
|
320 |
+
if block_type == "conformer":
|
321 |
+
block = ConformerWrapper(
|
322 |
+
dim=dim,
|
323 |
+
dim_head=attention_head_dim,
|
324 |
+
heads=num_heads,
|
325 |
+
ff_mult=1,
|
326 |
+
conv_expansion_factor=2,
|
327 |
+
ff_dropout=dropout,
|
328 |
+
attn_dropout=dropout,
|
329 |
+
conv_dropout=dropout,
|
330 |
+
conv_kernel_size=31,
|
331 |
+
)
|
332 |
+
elif block_type == "transformer":
|
333 |
+
block = BasicTransformerBlock(
|
334 |
+
dim=dim,
|
335 |
+
num_attention_heads=num_heads,
|
336 |
+
attention_head_dim=attention_head_dim,
|
337 |
+
dropout=dropout,
|
338 |
+
activation_fn=act_fn,
|
339 |
+
)
|
340 |
+
else:
|
341 |
+
raise ValueError(f"Unknown block type {block_type}")
|
342 |
+
|
343 |
+
return block
|
344 |
+
|
345 |
+
def initialize_weights(self):
|
346 |
+
for m in self.modules():
|
347 |
+
if isinstance(m, nn.Conv1d):
|
348 |
+
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
|
349 |
+
|
350 |
+
if m.bias is not None:
|
351 |
+
nn.init.constant_(m.bias, 0)
|
352 |
+
|
353 |
+
elif isinstance(m, nn.GroupNorm):
|
354 |
+
nn.init.constant_(m.weight, 1)
|
355 |
+
nn.init.constant_(m.bias, 0)
|
356 |
+
|
357 |
+
elif isinstance(m, nn.Linear):
|
358 |
+
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
|
359 |
+
|
360 |
+
if m.bias is not None:
|
361 |
+
nn.init.constant_(m.bias, 0)
|
362 |
+
|
363 |
+
def forward(self, x, mask, mu, t, spks=None, cond=None):
|
364 |
+
"""Forward pass of the UNet1DConditional model.
|
365 |
+
|
366 |
+
Args:
|
367 |
+
x (torch.Tensor): shape (batch_size, in_channels, time)
|
368 |
+
mask (_type_): shape (batch_size, 1, time)
|
369 |
+
t (_type_): shape (batch_size)
|
370 |
+
spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
|
371 |
+
cond (_type_, optional): placeholder for future use. Defaults to None.
|
372 |
+
|
373 |
+
Raises:
|
374 |
+
ValueError: _description_
|
375 |
+
ValueError: _description_
|
376 |
+
|
377 |
+
Returns:
|
378 |
+
_type_: _description_
|
379 |
+
"""
|
380 |
+
|
381 |
+
t = self.time_embeddings(t)
|
382 |
+
t = self.time_mlp(t)
|
383 |
+
|
384 |
+
x = pack([x, mu], "b * t")[0]
|
385 |
+
|
386 |
+
if spks is not None:
|
387 |
+
spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
|
388 |
+
x = pack([x, spks], "b * t")[0]
|
389 |
+
|
390 |
+
hiddens = []
|
391 |
+
masks = [mask]
|
392 |
+
for resnet, transformer_blocks, downsample in self.down_blocks:
|
393 |
+
mask_down = masks[-1]
|
394 |
+
x = resnet(x, mask_down, t)
|
395 |
+
x = rearrange(x, "b c t -> b t c")
|
396 |
+
mask_down = rearrange(mask_down, "b 1 t -> b t")
|
397 |
+
for transformer_block in transformer_blocks:
|
398 |
+
x = transformer_block(
|
399 |
+
hidden_states=x,
|
400 |
+
attention_mask=mask_down,
|
401 |
+
timestep=t,
|
402 |
+
)
|
403 |
+
x = rearrange(x, "b t c -> b c t")
|
404 |
+
mask_down = rearrange(mask_down, "b t -> b 1 t")
|
405 |
+
hiddens.append(x) # Save hidden states for skip connections
|
406 |
+
x = downsample(x * mask_down)
|
407 |
+
masks.append(mask_down[:, :, ::2])
|
408 |
+
|
409 |
+
masks = masks[:-1]
|
410 |
+
mask_mid = masks[-1]
|
411 |
+
|
412 |
+
for resnet, transformer_blocks in self.mid_blocks:
|
413 |
+
x = resnet(x, mask_mid, t)
|
414 |
+
x = rearrange(x, "b c t -> b t c")
|
415 |
+
mask_mid = rearrange(mask_mid, "b 1 t -> b t")
|
416 |
+
for transformer_block in transformer_blocks:
|
417 |
+
x = transformer_block(
|
418 |
+
hidden_states=x,
|
419 |
+
attention_mask=mask_mid,
|
420 |
+
timestep=t,
|
421 |
+
)
|
422 |
+
x = rearrange(x, "b t c -> b c t")
|
423 |
+
mask_mid = rearrange(mask_mid, "b t -> b 1 t")
|
424 |
+
|
425 |
+
for resnet, transformer_blocks, upsample in self.up_blocks:
|
426 |
+
mask_up = masks.pop()
|
427 |
+
x = resnet(pack([x, hiddens.pop()], "b * t")[0], mask_up, t)
|
428 |
+
x = rearrange(x, "b c t -> b t c")
|
429 |
+
mask_up = rearrange(mask_up, "b 1 t -> b t")
|
430 |
+
for transformer_block in transformer_blocks:
|
431 |
+
x = transformer_block(
|
432 |
+
hidden_states=x,
|
433 |
+
attention_mask=mask_up,
|
434 |
+
timestep=t,
|
435 |
+
)
|
436 |
+
x = rearrange(x, "b t c -> b c t")
|
437 |
+
mask_up = rearrange(mask_up, "b t -> b 1 t")
|
438 |
+
x = upsample(x * mask_up)
|
439 |
+
|
440 |
+
x = self.final_block(x, mask_up)
|
441 |
+
output = self.final_proj(x * mask_up)
|
442 |
+
|
443 |
+
return output * mask
|
src/chatterbox/models/s3gen/matcha/flow_matching.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
from .decoder import Decoder
|
7 |
+
|
8 |
+
|
9 |
+
class BASECFM(torch.nn.Module, ABC):
|
10 |
+
def __init__(
|
11 |
+
self,
|
12 |
+
n_feats,
|
13 |
+
cfm_params,
|
14 |
+
n_spks=1,
|
15 |
+
spk_emb_dim=128,
|
16 |
+
):
|
17 |
+
super().__init__()
|
18 |
+
self.n_feats = n_feats
|
19 |
+
self.n_spks = n_spks
|
20 |
+
self.spk_emb_dim = spk_emb_dim
|
21 |
+
self.solver = cfm_params.solver
|
22 |
+
if hasattr(cfm_params, "sigma_min"):
|
23 |
+
self.sigma_min = cfm_params.sigma_min
|
24 |
+
else:
|
25 |
+
self.sigma_min = 1e-4
|
26 |
+
|
27 |
+
self.estimator = None
|
28 |
+
|
29 |
+
@torch.inference_mode()
|
30 |
+
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
|
31 |
+
"""Forward diffusion
|
32 |
+
|
33 |
+
Args:
|
34 |
+
mu (torch.Tensor): output of encoder
|
35 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
36 |
+
mask (torch.Tensor): output_mask
|
37 |
+
shape: (batch_size, 1, mel_timesteps)
|
38 |
+
n_timesteps (int): number of diffusion steps
|
39 |
+
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
40 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
41 |
+
shape: (batch_size, spk_emb_dim)
|
42 |
+
cond: Not used but kept for future purposes
|
43 |
+
|
44 |
+
Returns:
|
45 |
+
sample: generated mel-spectrogram
|
46 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
47 |
+
"""
|
48 |
+
z = torch.randn_like(mu) * temperature
|
49 |
+
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
|
50 |
+
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
|
51 |
+
|
52 |
+
def solve_euler(self, x, t_span, mu, mask, spks, cond):
|
53 |
+
"""
|
54 |
+
Fixed euler solver for ODEs.
|
55 |
+
Args:
|
56 |
+
x (torch.Tensor): random noise
|
57 |
+
t_span (torch.Tensor): n_timesteps interpolated
|
58 |
+
shape: (n_timesteps + 1,)
|
59 |
+
mu (torch.Tensor): output of encoder
|
60 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
61 |
+
mask (torch.Tensor): output_mask
|
62 |
+
shape: (batch_size, 1, mel_timesteps)
|
63 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
64 |
+
shape: (batch_size, spk_emb_dim)
|
65 |
+
cond: Not used but kept for future purposes
|
66 |
+
"""
|
67 |
+
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
68 |
+
|
69 |
+
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
|
70 |
+
# Or in future might add like a return_all_steps flag
|
71 |
+
sol = []
|
72 |
+
|
73 |
+
for step in range(1, len(t_span)):
|
74 |
+
dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
|
75 |
+
|
76 |
+
x = x + dt * dphi_dt
|
77 |
+
t = t + dt
|
78 |
+
sol.append(x)
|
79 |
+
if step < len(t_span) - 1:
|
80 |
+
dt = t_span[step + 1] - t
|
81 |
+
|
82 |
+
return sol[-1]
|
83 |
+
|
84 |
+
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
|
85 |
+
"""Computes diffusion loss
|
86 |
+
|
87 |
+
Args:
|
88 |
+
x1 (torch.Tensor): Target
|
89 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
90 |
+
mask (torch.Tensor): target mask
|
91 |
+
shape: (batch_size, 1, mel_timesteps)
|
92 |
+
mu (torch.Tensor): output of encoder
|
93 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
94 |
+
spks (torch.Tensor, optional): speaker embedding. Defaults to None.
|
95 |
+
shape: (batch_size, spk_emb_dim)
|
96 |
+
|
97 |
+
Returns:
|
98 |
+
loss: conditional flow matching loss
|
99 |
+
y: conditional flow
|
100 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
101 |
+
"""
|
102 |
+
b, _, t = mu.shape
|
103 |
+
|
104 |
+
# random timestep
|
105 |
+
t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
|
106 |
+
# sample noise p(x_0)
|
107 |
+
z = torch.randn_like(x1)
|
108 |
+
|
109 |
+
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
|
110 |
+
u = x1 - (1 - self.sigma_min) * z
|
111 |
+
|
112 |
+
loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / (
|
113 |
+
torch.sum(mask) * u.shape[1]
|
114 |
+
)
|
115 |
+
return loss, y
|
116 |
+
|
117 |
+
|
118 |
+
class CFM(BASECFM):
|
119 |
+
def __init__(self, in_channels, out_channel, cfm_params, decoder_params, n_spks=1, spk_emb_dim=64):
|
120 |
+
super().__init__(
|
121 |
+
n_feats=in_channels,
|
122 |
+
cfm_params=cfm_params,
|
123 |
+
n_spks=n_spks,
|
124 |
+
spk_emb_dim=spk_emb_dim,
|
125 |
+
)
|
126 |
+
|
127 |
+
in_channels = in_channels + (spk_emb_dim if n_spks > 1 else 0)
|
128 |
+
# Just change the architecture of the estimator here
|
129 |
+
self.estimator = Decoder(in_channels=in_channels, out_channels=out_channel, **decoder_params)
|
src/chatterbox/models/s3gen/matcha/text_encoder.py
ADDED
@@ -0,0 +1,413 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" from https://github.com/jaywalnut310/glow-tts """
|
2 |
+
|
3 |
+
import math
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from einops import rearrange
|
8 |
+
|
9 |
+
|
10 |
+
def sequence_mask(length, max_length=None):
|
11 |
+
if max_length is None:
|
12 |
+
max_length = length.max()
|
13 |
+
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
|
14 |
+
return x.unsqueeze(0) < length.unsqueeze(1)
|
15 |
+
|
16 |
+
|
17 |
+
|
18 |
+
class LayerNorm(nn.Module):
|
19 |
+
def __init__(self, channels, eps=1e-4):
|
20 |
+
super().__init__()
|
21 |
+
self.channels = channels
|
22 |
+
self.eps = eps
|
23 |
+
|
24 |
+
self.gamma = torch.nn.Parameter(torch.ones(channels))
|
25 |
+
self.beta = torch.nn.Parameter(torch.zeros(channels))
|
26 |
+
|
27 |
+
def forward(self, x):
|
28 |
+
n_dims = len(x.shape)
|
29 |
+
mean = torch.mean(x, 1, keepdim=True)
|
30 |
+
variance = torch.mean((x - mean) ** 2, 1, keepdim=True)
|
31 |
+
|
32 |
+
x = (x - mean) * torch.rsqrt(variance + self.eps)
|
33 |
+
|
34 |
+
shape = [1, -1] + [1] * (n_dims - 2)
|
35 |
+
x = x * self.gamma.view(*shape) + self.beta.view(*shape)
|
36 |
+
return x
|
37 |
+
|
38 |
+
|
39 |
+
class ConvReluNorm(nn.Module):
|
40 |
+
def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
|
41 |
+
super().__init__()
|
42 |
+
self.in_channels = in_channels
|
43 |
+
self.hidden_channels = hidden_channels
|
44 |
+
self.out_channels = out_channels
|
45 |
+
self.kernel_size = kernel_size
|
46 |
+
self.n_layers = n_layers
|
47 |
+
self.p_dropout = p_dropout
|
48 |
+
|
49 |
+
self.conv_layers = torch.nn.ModuleList()
|
50 |
+
self.norm_layers = torch.nn.ModuleList()
|
51 |
+
self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
|
52 |
+
self.norm_layers.append(LayerNorm(hidden_channels))
|
53 |
+
self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout))
|
54 |
+
for _ in range(n_layers - 1):
|
55 |
+
self.conv_layers.append(
|
56 |
+
torch.nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)
|
57 |
+
)
|
58 |
+
self.norm_layers.append(LayerNorm(hidden_channels))
|
59 |
+
self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)
|
60 |
+
self.proj.weight.data.zero_()
|
61 |
+
self.proj.bias.data.zero_()
|
62 |
+
|
63 |
+
def forward(self, x, x_mask):
|
64 |
+
x_org = x
|
65 |
+
for i in range(self.n_layers):
|
66 |
+
x = self.conv_layers[i](x * x_mask)
|
67 |
+
x = self.norm_layers[i](x)
|
68 |
+
x = self.relu_drop(x)
|
69 |
+
x = x_org + self.proj(x)
|
70 |
+
return x * x_mask
|
71 |
+
|
72 |
+
|
73 |
+
class DurationPredictor(nn.Module):
|
74 |
+
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
|
75 |
+
super().__init__()
|
76 |
+
self.in_channels = in_channels
|
77 |
+
self.filter_channels = filter_channels
|
78 |
+
self.p_dropout = p_dropout
|
79 |
+
|
80 |
+
self.drop = torch.nn.Dropout(p_dropout)
|
81 |
+
self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
|
82 |
+
self.norm_1 = LayerNorm(filter_channels)
|
83 |
+
self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
|
84 |
+
self.norm_2 = LayerNorm(filter_channels)
|
85 |
+
self.proj = torch.nn.Conv1d(filter_channels, 1, 1)
|
86 |
+
|
87 |
+
def forward(self, x, x_mask):
|
88 |
+
x = self.conv_1(x * x_mask)
|
89 |
+
x = torch.relu(x)
|
90 |
+
x = self.norm_1(x)
|
91 |
+
x = self.drop(x)
|
92 |
+
x = self.conv_2(x * x_mask)
|
93 |
+
x = torch.relu(x)
|
94 |
+
x = self.norm_2(x)
|
95 |
+
x = self.drop(x)
|
96 |
+
x = self.proj(x * x_mask)
|
97 |
+
return x * x_mask
|
98 |
+
|
99 |
+
|
100 |
+
class RotaryPositionalEmbeddings(nn.Module):
|
101 |
+
"""
|
102 |
+
## RoPE module
|
103 |
+
|
104 |
+
Rotary encoding transforms pairs of features by rotating in the 2D plane.
|
105 |
+
That is, it organizes the $d$ features as $\frac{d}{2}$ pairs.
|
106 |
+
Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it
|
107 |
+
by an angle depending on the position of the token.
|
108 |
+
"""
|
109 |
+
|
110 |
+
def __init__(self, d: int, base: int = 10_000):
|
111 |
+
r"""
|
112 |
+
* `d` is the number of features $d$
|
113 |
+
* `base` is the constant used for calculating $\Theta$
|
114 |
+
"""
|
115 |
+
super().__init__()
|
116 |
+
|
117 |
+
self.base = base
|
118 |
+
self.d = int(d)
|
119 |
+
self.cos_cached = None
|
120 |
+
self.sin_cached = None
|
121 |
+
|
122 |
+
def _build_cache(self, x: torch.Tensor):
|
123 |
+
r"""
|
124 |
+
Cache $\cos$ and $\sin$ values
|
125 |
+
"""
|
126 |
+
# Return if cache is already built
|
127 |
+
if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
|
128 |
+
return
|
129 |
+
|
130 |
+
# Get sequence length
|
131 |
+
seq_len = x.shape[0]
|
132 |
+
|
133 |
+
# $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
|
134 |
+
theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)
|
135 |
+
|
136 |
+
# Create position indexes `[0, 1, ..., seq_len - 1]`
|
137 |
+
seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)
|
138 |
+
|
139 |
+
# Calculate the product of position index and $\theta_i$
|
140 |
+
idx_theta = torch.einsum("n,d->nd", seq_idx, theta)
|
141 |
+
|
142 |
+
# Concatenate so that for row $m$ we have
|
143 |
+
# $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$
|
144 |
+
idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
|
145 |
+
|
146 |
+
# Cache them
|
147 |
+
self.cos_cached = idx_theta2.cos()[:, None, None, :]
|
148 |
+
self.sin_cached = idx_theta2.sin()[:, None, None, :]
|
149 |
+
|
150 |
+
def _neg_half(self, x: torch.Tensor):
|
151 |
+
# $\frac{d}{2}$
|
152 |
+
d_2 = self.d // 2
|
153 |
+
|
154 |
+
# Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
|
155 |
+
return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
|
156 |
+
|
157 |
+
def forward(self, x: torch.Tensor):
|
158 |
+
"""
|
159 |
+
* `x` is the Tensor at the head of a key or a query with shape `[seq_len, batch_size, n_heads, d]`
|
160 |
+
"""
|
161 |
+
# Cache $\cos$ and $\sin$ values
|
162 |
+
x = rearrange(x, "b h t d -> t b h d")
|
163 |
+
|
164 |
+
self._build_cache(x)
|
165 |
+
|
166 |
+
# Split the features, we can choose to apply rotary embeddings only to a partial set of features.
|
167 |
+
x_rope, x_pass = x[..., : self.d], x[..., self.d :]
|
168 |
+
|
169 |
+
# Calculate
|
170 |
+
# $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
|
171 |
+
neg_half_x = self._neg_half(x_rope)
|
172 |
+
|
173 |
+
x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (neg_half_x * self.sin_cached[: x.shape[0]])
|
174 |
+
|
175 |
+
return rearrange(torch.cat((x_rope, x_pass), dim=-1), "t b h d -> b h t d")
|
176 |
+
|
177 |
+
|
178 |
+
class MultiHeadAttention(nn.Module):
|
179 |
+
def __init__(
|
180 |
+
self,
|
181 |
+
channels,
|
182 |
+
out_channels,
|
183 |
+
n_heads,
|
184 |
+
heads_share=True,
|
185 |
+
p_dropout=0.0,
|
186 |
+
proximal_bias=False,
|
187 |
+
proximal_init=False,
|
188 |
+
):
|
189 |
+
super().__init__()
|
190 |
+
assert channels % n_heads == 0
|
191 |
+
|
192 |
+
self.channels = channels
|
193 |
+
self.out_channels = out_channels
|
194 |
+
self.n_heads = n_heads
|
195 |
+
self.heads_share = heads_share
|
196 |
+
self.proximal_bias = proximal_bias
|
197 |
+
self.p_dropout = p_dropout
|
198 |
+
self.attn = None
|
199 |
+
|
200 |
+
self.k_channels = channels // n_heads
|
201 |
+
self.conv_q = torch.nn.Conv1d(channels, channels, 1)
|
202 |
+
self.conv_k = torch.nn.Conv1d(channels, channels, 1)
|
203 |
+
self.conv_v = torch.nn.Conv1d(channels, channels, 1)
|
204 |
+
|
205 |
+
# from https://nn.labml.ai/transformers/rope/index.html
|
206 |
+
self.query_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)
|
207 |
+
self.key_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)
|
208 |
+
|
209 |
+
self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
|
210 |
+
self.drop = torch.nn.Dropout(p_dropout)
|
211 |
+
|
212 |
+
torch.nn.init.xavier_uniform_(self.conv_q.weight)
|
213 |
+
torch.nn.init.xavier_uniform_(self.conv_k.weight)
|
214 |
+
if proximal_init:
|
215 |
+
self.conv_k.weight.data.copy_(self.conv_q.weight.data)
|
216 |
+
self.conv_k.bias.data.copy_(self.conv_q.bias.data)
|
217 |
+
torch.nn.init.xavier_uniform_(self.conv_v.weight)
|
218 |
+
|
219 |
+
def forward(self, x, c, attn_mask=None):
|
220 |
+
q = self.conv_q(x)
|
221 |
+
k = self.conv_k(c)
|
222 |
+
v = self.conv_v(c)
|
223 |
+
|
224 |
+
x, self.attn = self.attention(q, k, v, mask=attn_mask)
|
225 |
+
|
226 |
+
x = self.conv_o(x)
|
227 |
+
return x
|
228 |
+
|
229 |
+
def attention(self, query, key, value, mask=None):
|
230 |
+
b, d, t_s, t_t = (*key.size(), query.size(2))
|
231 |
+
query = rearrange(query, "b (h c) t-> b h t c", h=self.n_heads)
|
232 |
+
key = rearrange(key, "b (h c) t-> b h t c", h=self.n_heads)
|
233 |
+
value = rearrange(value, "b (h c) t-> b h t c", h=self.n_heads)
|
234 |
+
|
235 |
+
query = self.query_rotary_pe(query)
|
236 |
+
key = self.key_rotary_pe(key)
|
237 |
+
|
238 |
+
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
|
239 |
+
|
240 |
+
if self.proximal_bias:
|
241 |
+
assert t_s == t_t, "Proximal bias is only available for self-attention."
|
242 |
+
scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
|
243 |
+
if mask is not None:
|
244 |
+
scores = scores.masked_fill(mask == 0, -1e4)
|
245 |
+
p_attn = torch.nn.functional.softmax(scores, dim=-1)
|
246 |
+
p_attn = self.drop(p_attn)
|
247 |
+
output = torch.matmul(p_attn, value)
|
248 |
+
output = output.transpose(2, 3).contiguous().view(b, d, t_t)
|
249 |
+
return output, p_attn
|
250 |
+
|
251 |
+
@staticmethod
|
252 |
+
def _attention_bias_proximal(length):
|
253 |
+
r = torch.arange(length, dtype=torch.float32)
|
254 |
+
diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
|
255 |
+
return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
|
256 |
+
|
257 |
+
|
258 |
+
class FFN(nn.Module):
|
259 |
+
def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0):
|
260 |
+
super().__init__()
|
261 |
+
self.in_channels = in_channels
|
262 |
+
self.out_channels = out_channels
|
263 |
+
self.filter_channels = filter_channels
|
264 |
+
self.kernel_size = kernel_size
|
265 |
+
self.p_dropout = p_dropout
|
266 |
+
|
267 |
+
self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
|
268 |
+
self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2)
|
269 |
+
self.drop = torch.nn.Dropout(p_dropout)
|
270 |
+
|
271 |
+
def forward(self, x, x_mask):
|
272 |
+
x = self.conv_1(x * x_mask)
|
273 |
+
x = torch.relu(x)
|
274 |
+
x = self.drop(x)
|
275 |
+
x = self.conv_2(x * x_mask)
|
276 |
+
return x * x_mask
|
277 |
+
|
278 |
+
|
279 |
+
class Encoder(nn.Module):
|
280 |
+
def __init__(
|
281 |
+
self,
|
282 |
+
hidden_channels,
|
283 |
+
filter_channels,
|
284 |
+
n_heads,
|
285 |
+
n_layers,
|
286 |
+
kernel_size=1,
|
287 |
+
p_dropout=0.0,
|
288 |
+
**kwargs,
|
289 |
+
):
|
290 |
+
super().__init__()
|
291 |
+
self.hidden_channels = hidden_channels
|
292 |
+
self.filter_channels = filter_channels
|
293 |
+
self.n_heads = n_heads
|
294 |
+
self.n_layers = n_layers
|
295 |
+
self.kernel_size = kernel_size
|
296 |
+
self.p_dropout = p_dropout
|
297 |
+
|
298 |
+
self.drop = torch.nn.Dropout(p_dropout)
|
299 |
+
self.attn_layers = torch.nn.ModuleList()
|
300 |
+
self.norm_layers_1 = torch.nn.ModuleList()
|
301 |
+
self.ffn_layers = torch.nn.ModuleList()
|
302 |
+
self.norm_layers_2 = torch.nn.ModuleList()
|
303 |
+
for _ in range(self.n_layers):
|
304 |
+
self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout))
|
305 |
+
self.norm_layers_1.append(LayerNorm(hidden_channels))
|
306 |
+
self.ffn_layers.append(
|
307 |
+
FFN(
|
308 |
+
hidden_channels,
|
309 |
+
hidden_channels,
|
310 |
+
filter_channels,
|
311 |
+
kernel_size,
|
312 |
+
p_dropout=p_dropout,
|
313 |
+
)
|
314 |
+
)
|
315 |
+
self.norm_layers_2.append(LayerNorm(hidden_channels))
|
316 |
+
|
317 |
+
def forward(self, x, x_mask):
|
318 |
+
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
|
319 |
+
for i in range(self.n_layers):
|
320 |
+
x = x * x_mask
|
321 |
+
y = self.attn_layers[i](x, x, attn_mask)
|
322 |
+
y = self.drop(y)
|
323 |
+
x = self.norm_layers_1[i](x + y)
|
324 |
+
y = self.ffn_layers[i](x, x_mask)
|
325 |
+
y = self.drop(y)
|
326 |
+
x = self.norm_layers_2[i](x + y)
|
327 |
+
x = x * x_mask
|
328 |
+
return x
|
329 |
+
|
330 |
+
|
331 |
+
class TextEncoder(nn.Module):
|
332 |
+
def __init__(
|
333 |
+
self,
|
334 |
+
encoder_type,
|
335 |
+
encoder_params,
|
336 |
+
duration_predictor_params,
|
337 |
+
n_vocab,
|
338 |
+
n_spks=1,
|
339 |
+
spk_emb_dim=128,
|
340 |
+
):
|
341 |
+
super().__init__()
|
342 |
+
self.encoder_type = encoder_type
|
343 |
+
self.n_vocab = n_vocab
|
344 |
+
self.n_feats = encoder_params.n_feats
|
345 |
+
self.n_channels = encoder_params.n_channels
|
346 |
+
self.spk_emb_dim = spk_emb_dim
|
347 |
+
self.n_spks = n_spks
|
348 |
+
|
349 |
+
self.emb = torch.nn.Embedding(n_vocab, self.n_channels)
|
350 |
+
torch.nn.init.normal_(self.emb.weight, 0.0, self.n_channels**-0.5)
|
351 |
+
|
352 |
+
if encoder_params.prenet:
|
353 |
+
self.prenet = ConvReluNorm(
|
354 |
+
self.n_channels,
|
355 |
+
self.n_channels,
|
356 |
+
self.n_channels,
|
357 |
+
kernel_size=5,
|
358 |
+
n_layers=3,
|
359 |
+
p_dropout=0.5,
|
360 |
+
)
|
361 |
+
else:
|
362 |
+
self.prenet = lambda x, x_mask: x
|
363 |
+
|
364 |
+
self.encoder = Encoder(
|
365 |
+
encoder_params.n_channels + (spk_emb_dim if n_spks > 1 else 0),
|
366 |
+
encoder_params.filter_channels,
|
367 |
+
encoder_params.n_heads,
|
368 |
+
encoder_params.n_layers,
|
369 |
+
encoder_params.kernel_size,
|
370 |
+
encoder_params.p_dropout,
|
371 |
+
)
|
372 |
+
|
373 |
+
self.proj_m = torch.nn.Conv1d(self.n_channels + (spk_emb_dim if n_spks > 1 else 0), self.n_feats, 1)
|
374 |
+
self.proj_w = DurationPredictor(
|
375 |
+
self.n_channels + (spk_emb_dim if n_spks > 1 else 0),
|
376 |
+
duration_predictor_params.filter_channels_dp,
|
377 |
+
duration_predictor_params.kernel_size,
|
378 |
+
duration_predictor_params.p_dropout,
|
379 |
+
)
|
380 |
+
|
381 |
+
def forward(self, x, x_lengths, spks=None):
|
382 |
+
"""Run forward pass to the transformer based encoder and duration predictor
|
383 |
+
|
384 |
+
Args:
|
385 |
+
x (torch.Tensor): text input
|
386 |
+
shape: (batch_size, max_text_length)
|
387 |
+
x_lengths (torch.Tensor): text input lengths
|
388 |
+
shape: (batch_size,)
|
389 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
390 |
+
shape: (batch_size,)
|
391 |
+
|
392 |
+
Returns:
|
393 |
+
mu (torch.Tensor): average output of the encoder
|
394 |
+
shape: (batch_size, n_feats, max_text_length)
|
395 |
+
logw (torch.Tensor): log duration predicted by the duration predictor
|
396 |
+
shape: (batch_size, 1, max_text_length)
|
397 |
+
x_mask (torch.Tensor): mask for the text input
|
398 |
+
shape: (batch_size, 1, max_text_length)
|
399 |
+
"""
|
400 |
+
x = self.emb(x) * math.sqrt(self.n_channels)
|
401 |
+
x = torch.transpose(x, 1, -1)
|
402 |
+
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
|
403 |
+
|
404 |
+
x = self.prenet(x, x_mask)
|
405 |
+
if self.n_spks > 1:
|
406 |
+
x = torch.cat([x, spks.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)
|
407 |
+
x = self.encoder(x, x_mask)
|
408 |
+
mu = self.proj_m(x) * x_mask
|
409 |
+
|
410 |
+
x_dp = torch.detach(x)
|
411 |
+
logw = self.proj_w(x_dp, x_mask)
|
412 |
+
|
413 |
+
return mu, logw, x_mask
|
src/chatterbox/models/s3gen/matcha/transformer.py
ADDED
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Dict, Optional
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from diffusers.models.attention import (
|
6 |
+
GEGLU,
|
7 |
+
GELU,
|
8 |
+
AdaLayerNorm,
|
9 |
+
AdaLayerNormZero,
|
10 |
+
ApproximateGELU,
|
11 |
+
)
|
12 |
+
from diffusers.models.attention_processor import Attention
|
13 |
+
from diffusers.models.lora import LoRACompatibleLinear
|
14 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
15 |
+
|
16 |
+
|
17 |
+
class SnakeBeta(nn.Module):
|
18 |
+
"""
|
19 |
+
A modified Snake function which uses separate parameters for the magnitude of the periodic components
|
20 |
+
Shape:
|
21 |
+
- Input: (B, C, T)
|
22 |
+
- Output: (B, C, T), same shape as the input
|
23 |
+
Parameters:
|
24 |
+
- alpha - trainable parameter that controls frequency
|
25 |
+
- beta - trainable parameter that controls magnitude
|
26 |
+
References:
|
27 |
+
- This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
28 |
+
https://arxiv.org/abs/2006.08195
|
29 |
+
Examples:
|
30 |
+
>>> a1 = snakebeta(256)
|
31 |
+
>>> x = torch.randn(256)
|
32 |
+
>>> x = a1(x)
|
33 |
+
"""
|
34 |
+
|
35 |
+
def __init__(self, in_features, out_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
|
36 |
+
"""
|
37 |
+
Initialization.
|
38 |
+
INPUT:
|
39 |
+
- in_features: shape of the input
|
40 |
+
- alpha - trainable parameter that controls frequency
|
41 |
+
- beta - trainable parameter that controls magnitude
|
42 |
+
alpha is initialized to 1 by default, higher values = higher-frequency.
|
43 |
+
beta is initialized to 1 by default, higher values = higher-magnitude.
|
44 |
+
alpha will be trained along with the rest of your model.
|
45 |
+
"""
|
46 |
+
super().__init__()
|
47 |
+
self.in_features = out_features if isinstance(out_features, list) else [out_features]
|
48 |
+
self.proj = LoRACompatibleLinear(in_features, out_features)
|
49 |
+
|
50 |
+
# initialize alpha
|
51 |
+
self.alpha_logscale = alpha_logscale
|
52 |
+
if self.alpha_logscale: # log scale alphas initialized to zeros
|
53 |
+
self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha)
|
54 |
+
self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha)
|
55 |
+
else: # linear scale alphas initialized to ones
|
56 |
+
self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha)
|
57 |
+
self.beta = nn.Parameter(torch.ones(self.in_features) * alpha)
|
58 |
+
|
59 |
+
self.alpha.requires_grad = alpha_trainable
|
60 |
+
self.beta.requires_grad = alpha_trainable
|
61 |
+
|
62 |
+
self.no_div_by_zero = 0.000000001
|
63 |
+
|
64 |
+
def forward(self, x):
|
65 |
+
"""
|
66 |
+
Forward pass of the function.
|
67 |
+
Applies the function to the input elementwise.
|
68 |
+
SnakeBeta ∶= x + 1/b * sin^2 (xa)
|
69 |
+
"""
|
70 |
+
x = self.proj(x)
|
71 |
+
if self.alpha_logscale:
|
72 |
+
alpha = torch.exp(self.alpha)
|
73 |
+
beta = torch.exp(self.beta)
|
74 |
+
else:
|
75 |
+
alpha = self.alpha
|
76 |
+
beta = self.beta
|
77 |
+
|
78 |
+
x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2)
|
79 |
+
|
80 |
+
return x
|
81 |
+
|
82 |
+
|
83 |
+
class FeedForward(nn.Module):
|
84 |
+
r"""
|
85 |
+
A feed-forward layer.
|
86 |
+
|
87 |
+
Parameters:
|
88 |
+
dim (`int`): The number of channels in the input.
|
89 |
+
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
|
90 |
+
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
|
91 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
92 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
93 |
+
final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
|
94 |
+
"""
|
95 |
+
|
96 |
+
def __init__(
|
97 |
+
self,
|
98 |
+
dim: int,
|
99 |
+
dim_out: Optional[int] = None,
|
100 |
+
mult: int = 4,
|
101 |
+
dropout: float = 0.0,
|
102 |
+
activation_fn: str = "geglu",
|
103 |
+
final_dropout: bool = False,
|
104 |
+
):
|
105 |
+
super().__init__()
|
106 |
+
inner_dim = int(dim * mult)
|
107 |
+
dim_out = dim_out if dim_out is not None else dim
|
108 |
+
|
109 |
+
if activation_fn == "gelu":
|
110 |
+
act_fn = GELU(dim, inner_dim)
|
111 |
+
if activation_fn == "gelu-approximate":
|
112 |
+
act_fn = GELU(dim, inner_dim, approximate="tanh")
|
113 |
+
elif activation_fn == "geglu":
|
114 |
+
act_fn = GEGLU(dim, inner_dim)
|
115 |
+
elif activation_fn == "geglu-approximate":
|
116 |
+
act_fn = ApproximateGELU(dim, inner_dim)
|
117 |
+
elif activation_fn == "snakebeta":
|
118 |
+
act_fn = SnakeBeta(dim, inner_dim)
|
119 |
+
|
120 |
+
self.net = nn.ModuleList([])
|
121 |
+
# project in
|
122 |
+
self.net.append(act_fn)
|
123 |
+
# project dropout
|
124 |
+
self.net.append(nn.Dropout(dropout))
|
125 |
+
# project out
|
126 |
+
self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
|
127 |
+
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
|
128 |
+
if final_dropout:
|
129 |
+
self.net.append(nn.Dropout(dropout))
|
130 |
+
|
131 |
+
def forward(self, hidden_states):
|
132 |
+
for module in self.net:
|
133 |
+
hidden_states = module(hidden_states)
|
134 |
+
return hidden_states
|
135 |
+
|
136 |
+
|
137 |
+
@maybe_allow_in_graph
|
138 |
+
class BasicTransformerBlock(nn.Module):
|
139 |
+
r"""
|
140 |
+
A basic Transformer block.
|
141 |
+
|
142 |
+
Parameters:
|
143 |
+
dim (`int`): The number of channels in the input and output.
|
144 |
+
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
145 |
+
attention_head_dim (`int`): The number of channels in each head.
|
146 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
147 |
+
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
148 |
+
only_cross_attention (`bool`, *optional*):
|
149 |
+
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
150 |
+
double_self_attention (`bool`, *optional*):
|
151 |
+
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
152 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
153 |
+
num_embeds_ada_norm (:
|
154 |
+
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
155 |
+
attention_bias (:
|
156 |
+
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
157 |
+
"""
|
158 |
+
|
159 |
+
def __init__(
|
160 |
+
self,
|
161 |
+
dim: int,
|
162 |
+
num_attention_heads: int,
|
163 |
+
attention_head_dim: int,
|
164 |
+
dropout=0.0,
|
165 |
+
cross_attention_dim: Optional[int] = None,
|
166 |
+
activation_fn: str = "geglu",
|
167 |
+
num_embeds_ada_norm: Optional[int] = None,
|
168 |
+
attention_bias: bool = False,
|
169 |
+
only_cross_attention: bool = False,
|
170 |
+
double_self_attention: bool = False,
|
171 |
+
upcast_attention: bool = False,
|
172 |
+
norm_elementwise_affine: bool = True,
|
173 |
+
norm_type: str = "layer_norm",
|
174 |
+
final_dropout: bool = False,
|
175 |
+
):
|
176 |
+
super().__init__()
|
177 |
+
self.only_cross_attention = only_cross_attention
|
178 |
+
|
179 |
+
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
|
180 |
+
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
|
181 |
+
|
182 |
+
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
|
183 |
+
raise ValueError(
|
184 |
+
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
|
185 |
+
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
|
186 |
+
)
|
187 |
+
|
188 |
+
# Define 3 blocks. Each block has its own normalization layer.
|
189 |
+
# 1. Self-Attn
|
190 |
+
if self.use_ada_layer_norm:
|
191 |
+
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
192 |
+
elif self.use_ada_layer_norm_zero:
|
193 |
+
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
|
194 |
+
else:
|
195 |
+
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
196 |
+
self.attn1 = Attention(
|
197 |
+
query_dim=dim,
|
198 |
+
heads=num_attention_heads,
|
199 |
+
dim_head=attention_head_dim,
|
200 |
+
dropout=dropout,
|
201 |
+
bias=attention_bias,
|
202 |
+
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
203 |
+
upcast_attention=upcast_attention,
|
204 |
+
)
|
205 |
+
|
206 |
+
# 2. Cross-Attn
|
207 |
+
if cross_attention_dim is not None or double_self_attention:
|
208 |
+
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
|
209 |
+
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
|
210 |
+
# the second cross attention block.
|
211 |
+
self.norm2 = (
|
212 |
+
AdaLayerNorm(dim, num_embeds_ada_norm)
|
213 |
+
if self.use_ada_layer_norm
|
214 |
+
else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
215 |
+
)
|
216 |
+
self.attn2 = Attention(
|
217 |
+
query_dim=dim,
|
218 |
+
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
|
219 |
+
heads=num_attention_heads,
|
220 |
+
dim_head=attention_head_dim,
|
221 |
+
dropout=dropout,
|
222 |
+
bias=attention_bias,
|
223 |
+
upcast_attention=upcast_attention,
|
224 |
+
# scale_qk=False, # uncomment this to not to use flash attention
|
225 |
+
) # is self-attn if encoder_hidden_states is none
|
226 |
+
else:
|
227 |
+
self.norm2 = None
|
228 |
+
self.attn2 = None
|
229 |
+
|
230 |
+
# 3. Feed-forward
|
231 |
+
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
232 |
+
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
|
233 |
+
|
234 |
+
# let chunk size default to None
|
235 |
+
self._chunk_size = None
|
236 |
+
self._chunk_dim = 0
|
237 |
+
|
238 |
+
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
|
239 |
+
# Sets chunk feed-forward
|
240 |
+
self._chunk_size = chunk_size
|
241 |
+
self._chunk_dim = dim
|
242 |
+
|
243 |
+
def forward(
|
244 |
+
self,
|
245 |
+
hidden_states: torch.FloatTensor,
|
246 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
247 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
248 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
249 |
+
timestep: Optional[torch.LongTensor] = None,
|
250 |
+
cross_attention_kwargs: Dict[str, Any] = None,
|
251 |
+
class_labels: Optional[torch.LongTensor] = None,
|
252 |
+
):
|
253 |
+
# Notice that normalization is always applied before the real computation in the following blocks.
|
254 |
+
# 1. Self-Attention
|
255 |
+
if self.use_ada_layer_norm:
|
256 |
+
norm_hidden_states = self.norm1(hidden_states, timestep)
|
257 |
+
elif self.use_ada_layer_norm_zero:
|
258 |
+
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
|
259 |
+
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
|
260 |
+
)
|
261 |
+
else:
|
262 |
+
norm_hidden_states = self.norm1(hidden_states)
|
263 |
+
|
264 |
+
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
265 |
+
|
266 |
+
attn_output = self.attn1(
|
267 |
+
norm_hidden_states,
|
268 |
+
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
269 |
+
attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask,
|
270 |
+
**cross_attention_kwargs,
|
271 |
+
)
|
272 |
+
if self.use_ada_layer_norm_zero:
|
273 |
+
attn_output = gate_msa.unsqueeze(1) * attn_output
|
274 |
+
hidden_states = attn_output + hidden_states
|
275 |
+
|
276 |
+
# 2. Cross-Attention
|
277 |
+
if self.attn2 is not None:
|
278 |
+
norm_hidden_states = (
|
279 |
+
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
|
280 |
+
)
|
281 |
+
|
282 |
+
attn_output = self.attn2(
|
283 |
+
norm_hidden_states,
|
284 |
+
encoder_hidden_states=encoder_hidden_states,
|
285 |
+
attention_mask=encoder_attention_mask,
|
286 |
+
**cross_attention_kwargs,
|
287 |
+
)
|
288 |
+
hidden_states = attn_output + hidden_states
|
289 |
+
|
290 |
+
# 3. Feed-forward
|
291 |
+
norm_hidden_states = self.norm3(hidden_states)
|
292 |
+
|
293 |
+
if self.use_ada_layer_norm_zero:
|
294 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
295 |
+
|
296 |
+
if self._chunk_size is not None:
|
297 |
+
# "feed_forward_chunk_size" can be used to save memory
|
298 |
+
if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
|
299 |
+
raise ValueError(
|
300 |
+
f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
|
301 |
+
)
|
302 |
+
|
303 |
+
num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
|
304 |
+
ff_output = torch.cat(
|
305 |
+
[self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
|
306 |
+
dim=self._chunk_dim,
|
307 |
+
)
|
308 |
+
else:
|
309 |
+
ff_output = self.ff(norm_hidden_states)
|
310 |
+
|
311 |
+
if self.use_ada_layer_norm_zero:
|
312 |
+
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
313 |
+
|
314 |
+
hidden_states = ff_output + hidden_states
|
315 |
+
|
316 |
+
return hidden_states
|
src/chatterbox/models/s3gen/s3gen.py
ADDED
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from CosyVoice https://github.com/FunAudioLLM/CosyVoice
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import logging
|
16 |
+
|
17 |
+
import numpy as np
|
18 |
+
import torch
|
19 |
+
import torchaudio as ta
|
20 |
+
from functools import lru_cache
|
21 |
+
from typing import Optional
|
22 |
+
|
23 |
+
from ..s3tokenizer import S3_SR, SPEECH_VOCAB_SIZE, S3Tokenizer
|
24 |
+
from .const import S3GEN_SR
|
25 |
+
from .flow import CausalMaskedDiffWithXvec
|
26 |
+
from .xvector import CAMPPlus
|
27 |
+
from .utils.mel import mel_spectrogram
|
28 |
+
from .f0_predictor import ConvRNNF0Predictor
|
29 |
+
from .hifigan import HiFTGenerator
|
30 |
+
from .transformer.upsample_encoder import UpsampleConformerEncoder
|
31 |
+
from .flow_matching import CausalConditionalCFM
|
32 |
+
from .decoder import ConditionalDecoder
|
33 |
+
from .configs import CFM_PARAMS
|
34 |
+
|
35 |
+
|
36 |
+
def drop_invalid_tokens(x):
|
37 |
+
assert len(x.shape) <= 2 and x.shape[0] == 1, "only batch size of one allowed for now"
|
38 |
+
return x[x < SPEECH_VOCAB_SIZE]
|
39 |
+
|
40 |
+
|
41 |
+
# TODO: global resampler cache
|
42 |
+
@lru_cache(100)
|
43 |
+
def get_resampler(src_sr, dst_sr, device):
|
44 |
+
return ta.transforms.Resample(src_sr, dst_sr).to(device)
|
45 |
+
|
46 |
+
|
47 |
+
class S3Token2Mel(torch.nn.Module):
|
48 |
+
"""
|
49 |
+
CosyVoice2's CFM decoder maps S3 speech tokens to mel-spectrograms.
|
50 |
+
|
51 |
+
TODO: make these modules configurable?
|
52 |
+
"""
|
53 |
+
def __init__(self):
|
54 |
+
super().__init__()
|
55 |
+
self.tokenizer = S3Tokenizer("speech_tokenizer_v2_25hz")
|
56 |
+
self.mel_extractor = mel_spectrogram # TODO: make it a torch module?
|
57 |
+
self.speaker_encoder = CAMPPlus() # use default args
|
58 |
+
|
59 |
+
encoder = UpsampleConformerEncoder(
|
60 |
+
output_size=512,
|
61 |
+
attention_heads=8,
|
62 |
+
linear_units=2048,
|
63 |
+
num_blocks=6,
|
64 |
+
dropout_rate=0.1,
|
65 |
+
positional_dropout_rate=0.1,
|
66 |
+
attention_dropout_rate=0.1,
|
67 |
+
normalize_before=True,
|
68 |
+
input_layer='linear',
|
69 |
+
pos_enc_layer_type='rel_pos_espnet',
|
70 |
+
selfattention_layer_type='rel_selfattn',
|
71 |
+
input_size=512,
|
72 |
+
use_cnn_module=False,
|
73 |
+
macaron_style=False,
|
74 |
+
)
|
75 |
+
|
76 |
+
estimator = ConditionalDecoder(
|
77 |
+
in_channels=320,
|
78 |
+
out_channels=80,
|
79 |
+
causal=True,
|
80 |
+
channels=[256],
|
81 |
+
dropout=0.0,
|
82 |
+
attention_head_dim=64,
|
83 |
+
n_blocks=4,
|
84 |
+
num_mid_blocks=12,
|
85 |
+
num_heads=8,
|
86 |
+
act_fn='gelu',
|
87 |
+
)
|
88 |
+
cfm_params = CFM_PARAMS
|
89 |
+
decoder = CausalConditionalCFM(
|
90 |
+
spk_emb_dim=80,
|
91 |
+
cfm_params=cfm_params,
|
92 |
+
estimator=estimator,
|
93 |
+
)
|
94 |
+
|
95 |
+
self.flow = CausalMaskedDiffWithXvec(
|
96 |
+
encoder=encoder,
|
97 |
+
decoder=decoder
|
98 |
+
)
|
99 |
+
|
100 |
+
self.resamplers = {}
|
101 |
+
|
102 |
+
@property
|
103 |
+
def device(self):
|
104 |
+
params = self.tokenizer.parameters()
|
105 |
+
return next(params).device
|
106 |
+
|
107 |
+
def embed_ref(
|
108 |
+
self,
|
109 |
+
ref_wav: torch.Tensor,
|
110 |
+
ref_sr: int,
|
111 |
+
device="auto",
|
112 |
+
ref_fade_out=True,
|
113 |
+
):
|
114 |
+
device = self.device if device == "auto" else device
|
115 |
+
if isinstance(ref_wav, np.ndarray):
|
116 |
+
ref_wav = torch.from_numpy(ref_wav).float()
|
117 |
+
|
118 |
+
if ref_wav.device != device:
|
119 |
+
ref_wav = ref_wav.to(device)
|
120 |
+
|
121 |
+
if len(ref_wav.shape) == 1:
|
122 |
+
ref_wav = ref_wav.unsqueeze(0) # (B, L)
|
123 |
+
|
124 |
+
if ref_wav.size(1) > 10 * ref_sr:
|
125 |
+
print("WARNING: cosydec received ref longer than 10s")
|
126 |
+
|
127 |
+
ref_wav_24 = ref_wav
|
128 |
+
if ref_sr != S3GEN_SR:
|
129 |
+
ref_wav_24 = get_resampler(ref_sr, S3GEN_SR, device)(ref_wav)
|
130 |
+
|
131 |
+
ref_mels_24 = self.mel_extractor(ref_wav_24).transpose(1, 2).to(device)
|
132 |
+
ref_mels_24_len = None
|
133 |
+
|
134 |
+
# Resample to 16kHz
|
135 |
+
ref_wav_16 = get_resampler(ref_sr, S3_SR, device)(ref_wav).to(device)
|
136 |
+
|
137 |
+
# Speaker embedding
|
138 |
+
ref_x_vector = self.speaker_encoder.inference(ref_wav_16)
|
139 |
+
|
140 |
+
# Tokenize 16khz reference
|
141 |
+
ref_speech_tokens, ref_speech_token_lens = self.tokenizer(ref_wav_16)
|
142 |
+
|
143 |
+
# Make sure mel_len = 2 * stoken_len (happens when the input is not padded to multiple of 40ms)
|
144 |
+
if ref_mels_24.shape[1] != 2 * ref_speech_tokens.shape[1]:
|
145 |
+
logging.warning(
|
146 |
+
"Reference mel length is not equal to 2 * reference token length.\n"
|
147 |
+
)
|
148 |
+
ref_speech_tokens = ref_speech_tokens[:, :ref_mels_24.shape[1] // 2]
|
149 |
+
ref_speech_token_lens[0] = ref_speech_tokens.shape[1]
|
150 |
+
|
151 |
+
return dict(
|
152 |
+
prompt_token=ref_speech_tokens.to(device),
|
153 |
+
prompt_token_len=ref_speech_token_lens,
|
154 |
+
prompt_feat=ref_mels_24,
|
155 |
+
prompt_feat_len=ref_mels_24_len,
|
156 |
+
embedding=ref_x_vector,
|
157 |
+
)
|
158 |
+
|
159 |
+
def forward(
|
160 |
+
self,
|
161 |
+
speech_tokens: torch.LongTensor,
|
162 |
+
# locally-computed ref embedding (mutex with ref_dict)
|
163 |
+
ref_wav: Optional[torch.Tensor],
|
164 |
+
ref_sr: Optional[int],
|
165 |
+
# pre-computed ref embedding (prod API)
|
166 |
+
ref_dict: Optional[dict] = None,
|
167 |
+
finalize: bool = False,
|
168 |
+
):
|
169 |
+
"""
|
170 |
+
Generate waveforms from S3 speech tokens and a reference waveform, which the speaker timbre is inferred from.
|
171 |
+
|
172 |
+
NOTE:
|
173 |
+
- The speaker encoder accepts 16 kHz waveform.
|
174 |
+
- S3TokenizerV2 accepts 16 kHz waveform.
|
175 |
+
- The mel-spectrogram for the reference assumes 24 kHz input signal.
|
176 |
+
- This function is designed for batch_size=1 only.
|
177 |
+
|
178 |
+
Args
|
179 |
+
----
|
180 |
+
- `speech_tokens`: S3 speech tokens [B=1, T]
|
181 |
+
- `ref_wav`: reference waveform (`torch.Tensor` with shape=[B=1, T])
|
182 |
+
- `ref_sr`: reference sample rate
|
183 |
+
- `finalize`: whether streaming is finished or not. Note that if False, the last 3 tokens will be ignored.
|
184 |
+
"""
|
185 |
+
assert (ref_wav is None) ^ (ref_dict is None), f"Must provide exactly one of ref_wav or ref_dict (got {ref_wav} and {ref_dict})"
|
186 |
+
|
187 |
+
if ref_dict is None:
|
188 |
+
ref_dict = self.embed_ref(ref_wav, ref_sr)
|
189 |
+
else:
|
190 |
+
# type/device casting (all values will be numpy if it's from a prod API call)
|
191 |
+
for rk in list(ref_dict):
|
192 |
+
if isinstance(ref_dict[rk], np.ndarray):
|
193 |
+
ref_dict[rk] = torch.from_numpy(ref_dict[rk])
|
194 |
+
if torch.is_tensor(ref_dict[rk]):
|
195 |
+
ref_dict[rk] = ref_dict[rk].to(self.device)
|
196 |
+
|
197 |
+
if len(speech_tokens.shape) == 1:
|
198 |
+
speech_tokens = speech_tokens.unsqueeze(0)
|
199 |
+
|
200 |
+
# assert speech_tokens.shape[0] == 1, "only batch size of one allowed for now"
|
201 |
+
speech_token_lens = torch.LongTensor([speech_tokens.size(1)]).to(self.device)
|
202 |
+
|
203 |
+
output_mels, _ = self.flow.inference(
|
204 |
+
token=speech_tokens,
|
205 |
+
token_len=speech_token_lens,
|
206 |
+
finalize=finalize,
|
207 |
+
**ref_dict,
|
208 |
+
)
|
209 |
+
return output_mels
|
210 |
+
|
211 |
+
|
212 |
+
class S3Token2Wav(S3Token2Mel):
|
213 |
+
"""
|
214 |
+
The decoder of CosyVoice2 is a concat of token-to-mel (CFM) and a mel-to-waveform (HiFiGAN) modules.
|
215 |
+
|
216 |
+
TODO: make these modules configurable?
|
217 |
+
"""
|
218 |
+
|
219 |
+
def __init__(self):
|
220 |
+
super().__init__()
|
221 |
+
|
222 |
+
f0_predictor = ConvRNNF0Predictor()
|
223 |
+
self.mel2wav = HiFTGenerator(
|
224 |
+
sampling_rate=S3GEN_SR,
|
225 |
+
upsample_rates=[8, 5, 3],
|
226 |
+
upsample_kernel_sizes=[16, 11, 7],
|
227 |
+
source_resblock_kernel_sizes=[7, 7, 11],
|
228 |
+
source_resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
229 |
+
f0_predictor=f0_predictor,
|
230 |
+
)
|
231 |
+
|
232 |
+
# silence out a few ms and fade audio in to reduce artifacts
|
233 |
+
n_trim = S3GEN_SR // 50 # 20ms = half of a frame
|
234 |
+
trim_fade = torch.zeros(2 * n_trim)
|
235 |
+
trim_fade[n_trim:] = (torch.cos(torch.linspace(torch.pi, 0, n_trim)) + 1) / 2
|
236 |
+
self.register_buffer("trim_fade", trim_fade, persistent=False) # (buffers get automatic device casting)
|
237 |
+
|
238 |
+
def forward(
|
239 |
+
self,
|
240 |
+
speech_tokens,
|
241 |
+
# locally-computed ref embedding (mutex with ref_dict)
|
242 |
+
ref_wav: Optional[torch.Tensor],
|
243 |
+
ref_sr: Optional[int],
|
244 |
+
# pre-computed ref embedding (prod API)
|
245 |
+
ref_dict: Optional[dict] = None,
|
246 |
+
finalize: bool = False
|
247 |
+
):
|
248 |
+
output_mels = super().forward(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)
|
249 |
+
|
250 |
+
# TODO jrm: ignoring the speed control (mel interpolation) and the HiFTGAN caching mechanisms for now.
|
251 |
+
hift_cache_source = torch.zeros(1, 1, 0).to(self.device)
|
252 |
+
|
253 |
+
output_wavs, *_ = self.mel2wav.inference(speech_feat=output_mels, cache_source=hift_cache_source)
|
254 |
+
|
255 |
+
if not self.training:
|
256 |
+
# NOTE: ad-hoc method to reduce "spillover" from the reference clip.
|
257 |
+
output_wavs[:, :len(self.trim_fade)] *= self.trim_fade
|
258 |
+
|
259 |
+
return output_wavs
|
260 |
+
|
261 |
+
@torch.inference_mode()
|
262 |
+
def flow_inference(
|
263 |
+
self,
|
264 |
+
speech_tokens,
|
265 |
+
# locally-computed ref embedding (mutex with ref_dict)
|
266 |
+
ref_wav: Optional[torch.Tensor] = None,
|
267 |
+
ref_sr: Optional[int] = None,
|
268 |
+
# pre-computed ref embedding (prod API)
|
269 |
+
ref_dict: Optional[dict] = None,
|
270 |
+
finalize: bool = False,
|
271 |
+
):
|
272 |
+
return super().forward(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)
|
273 |
+
|
274 |
+
@torch.inference_mode()
|
275 |
+
def hift_inference(self, speech_feat, cache_source: torch.Tensor = None):
|
276 |
+
if cache_source is None:
|
277 |
+
cache_source = torch.zeros(1, 1, 0).to(self.device)
|
278 |
+
return self.mel2wav.inference(speech_feat=speech_feat, cache_source=cache_source)
|
279 |
+
|
280 |
+
@torch.inference_mode()
|
281 |
+
def inference(
|
282 |
+
self,
|
283 |
+
speech_tokens,
|
284 |
+
# locally-computed ref embedding (mutex with ref_dict)
|
285 |
+
ref_wav: Optional[torch.Tensor] = None,
|
286 |
+
ref_sr: Optional[int] = None,
|
287 |
+
# pre-computed ref embedding (prod API)
|
288 |
+
ref_dict: Optional[dict] = None,
|
289 |
+
cache_source: torch.Tensor = None, # NOTE: this arg is for streaming, it can probably be removed here
|
290 |
+
finalize: bool = True,
|
291 |
+
):
|
292 |
+
output_mels = self.flow_inference(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)
|
293 |
+
output_wavs, output_sources = self.hift_inference(output_mels, cache_source)
|
294 |
+
|
295 |
+
# NOTE: ad-hoc method to reduce "spillover" from the reference clip.
|
296 |
+
output_wavs[:, :len(self.trim_fade)] *= self.trim_fade
|
297 |
+
|
298 |
+
return output_wavs, output_sources
|
src/chatterbox/models/s3gen/transformer/__init__.py
ADDED
File without changes
|
src/chatterbox/models/s3gen/transformer/activation.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe)
|
2 |
+
# 2020 Northwestern Polytechnical University (Pengcheng Guo)
|
3 |
+
# 2020 Mobvoi Inc (Binbin Zhang)
|
4 |
+
# 2024 Alibaba Inc (Xiang Lyu)
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
"""Swish() activation function for Conformer."""
|
18 |
+
|
19 |
+
import torch
|
20 |
+
from torch import nn, sin, pow
|
21 |
+
from torch.nn import Parameter
|
22 |
+
|
23 |
+
|
24 |
+
class Swish(torch.nn.Module):
|
25 |
+
"""Construct an Swish object."""
|
26 |
+
|
27 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
28 |
+
"""Return Swish activation function."""
|
29 |
+
return x * torch.sigmoid(x)
|
30 |
+
|
31 |
+
|
32 |
+
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
|
33 |
+
# LICENSE is in incl_licenses directory.
|
34 |
+
class Snake(nn.Module):
|
35 |
+
'''
|
36 |
+
Implementation of a sine-based periodic activation function
|
37 |
+
Shape:
|
38 |
+
- Input: (B, C, T)
|
39 |
+
- Output: (B, C, T), same shape as the input
|
40 |
+
Parameters:
|
41 |
+
- alpha - trainable parameter
|
42 |
+
References:
|
43 |
+
- This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
44 |
+
https://arxiv.org/abs/2006.08195
|
45 |
+
Examples:
|
46 |
+
>>> a1 = snake(256)
|
47 |
+
>>> x = torch.randn(256)
|
48 |
+
>>> x = a1(x)
|
49 |
+
'''
|
50 |
+
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
51 |
+
'''
|
52 |
+
Initialization.
|
53 |
+
INPUT:
|
54 |
+
- in_features: shape of the input
|
55 |
+
- alpha: trainable parameter
|
56 |
+
alpha is initialized to 1 by default, higher values = higher-frequency.
|
57 |
+
alpha will be trained along with the rest of your model.
|
58 |
+
'''
|
59 |
+
super(Snake, self).__init__()
|
60 |
+
self.in_features = in_features
|
61 |
+
|
62 |
+
# initialize alpha
|
63 |
+
self.alpha_logscale = alpha_logscale
|
64 |
+
if self.alpha_logscale: # log scale alphas initialized to zeros
|
65 |
+
self.alpha = Parameter(torch.zeros(in_features) * alpha)
|
66 |
+
else: # linear scale alphas initialized to ones
|
67 |
+
self.alpha = Parameter(torch.ones(in_features) * alpha)
|
68 |
+
|
69 |
+
self.alpha.requires_grad = alpha_trainable
|
70 |
+
|
71 |
+
self.no_div_by_zero = 0.000000001
|
72 |
+
|
73 |
+
def forward(self, x):
|
74 |
+
'''
|
75 |
+
Forward pass of the function.
|
76 |
+
Applies the function to the input elementwise.
|
77 |
+
Snake ∶= x + 1/a * sin^2 (xa)
|
78 |
+
'''
|
79 |
+
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
80 |
+
if self.alpha_logscale:
|
81 |
+
alpha = torch.exp(alpha)
|
82 |
+
x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
83 |
+
|
84 |
+
return x
|
src/chatterbox/models/s3gen/transformer/attention.py
ADDED
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2019 Shigeki Karita
|
2 |
+
# 2020 Mobvoi Inc (Binbin Zhang)
|
3 |
+
# 2022 Xingchen Song ([email protected])
|
4 |
+
# 2024 Alibaba Inc (Xiang Lyu)
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
"""Multi-Head Attention layer definition."""
|
18 |
+
|
19 |
+
import math
|
20 |
+
from typing import Tuple
|
21 |
+
|
22 |
+
import torch
|
23 |
+
from torch import nn
|
24 |
+
|
25 |
+
|
26 |
+
class MultiHeadedAttention(nn.Module):
|
27 |
+
"""Multi-Head Attention layer.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
n_head (int): The number of heads.
|
31 |
+
n_feat (int): The number of features.
|
32 |
+
dropout_rate (float): Dropout rate.
|
33 |
+
|
34 |
+
"""
|
35 |
+
|
36 |
+
def __init__(self,
|
37 |
+
n_head: int,
|
38 |
+
n_feat: int,
|
39 |
+
dropout_rate: float,
|
40 |
+
key_bias: bool = True):
|
41 |
+
"""Construct an MultiHeadedAttention object."""
|
42 |
+
super().__init__()
|
43 |
+
assert n_feat % n_head == 0
|
44 |
+
# We assume d_v always equals d_k
|
45 |
+
self.d_k = n_feat // n_head
|
46 |
+
self.h = n_head
|
47 |
+
self.linear_q = nn.Linear(n_feat, n_feat)
|
48 |
+
self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
|
49 |
+
self.linear_v = nn.Linear(n_feat, n_feat)
|
50 |
+
self.linear_out = nn.Linear(n_feat, n_feat)
|
51 |
+
self.dropout = nn.Dropout(p=dropout_rate)
|
52 |
+
|
53 |
+
def forward_qkv(
|
54 |
+
self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
|
55 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
56 |
+
"""Transform query, key and value.
|
57 |
+
|
58 |
+
Args:
|
59 |
+
query (torch.Tensor): Query tensor (#batch, time1, size).
|
60 |
+
key (torch.Tensor): Key tensor (#batch, time2, size).
|
61 |
+
value (torch.Tensor): Value tensor (#batch, time2, size).
|
62 |
+
|
63 |
+
Returns:
|
64 |
+
torch.Tensor: Transformed query tensor, size
|
65 |
+
(#batch, n_head, time1, d_k).
|
66 |
+
torch.Tensor: Transformed key tensor, size
|
67 |
+
(#batch, n_head, time2, d_k).
|
68 |
+
torch.Tensor: Transformed value tensor, size
|
69 |
+
(#batch, n_head, time2, d_k).
|
70 |
+
|
71 |
+
"""
|
72 |
+
n_batch = query.size(0)
|
73 |
+
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
|
74 |
+
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
|
75 |
+
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
|
76 |
+
q = q.transpose(1, 2) # (batch, head, time1, d_k)
|
77 |
+
k = k.transpose(1, 2) # (batch, head, time2, d_k)
|
78 |
+
v = v.transpose(1, 2) # (batch, head, time2, d_k)
|
79 |
+
|
80 |
+
return q, k, v
|
81 |
+
|
82 |
+
def forward_attention(
|
83 |
+
self,
|
84 |
+
value: torch.Tensor,
|
85 |
+
scores: torch.Tensor,
|
86 |
+
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
|
87 |
+
) -> torch.Tensor:
|
88 |
+
"""Compute attention context vector.
|
89 |
+
|
90 |
+
Args:
|
91 |
+
value (torch.Tensor): Transformed value, size
|
92 |
+
(#batch, n_head, time2, d_k).
|
93 |
+
scores (torch.Tensor): Attention score, size
|
94 |
+
(#batch, n_head, time1, time2).
|
95 |
+
mask (torch.Tensor): Mask, size (#batch, 1, time2) or
|
96 |
+
(#batch, time1, time2), (0, 0, 0) means fake mask.
|
97 |
+
|
98 |
+
Returns:
|
99 |
+
torch.Tensor: Transformed value (#batch, time1, d_model)
|
100 |
+
weighted by the attention score (#batch, time1, time2).
|
101 |
+
|
102 |
+
"""
|
103 |
+
n_batch = value.size(0)
|
104 |
+
# NOTE(xcsong): When will `if mask.size(2) > 0` be True?
|
105 |
+
# 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
|
106 |
+
# 1st chunk to ease the onnx export.]
|
107 |
+
# 2. pytorch training
|
108 |
+
if mask.size(2) > 0: # time2 > 0
|
109 |
+
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
|
110 |
+
# For last chunk, time2 might be larger than scores.size(-1)
|
111 |
+
mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
|
112 |
+
scores = scores.masked_fill(mask, -float('inf'))
|
113 |
+
attn = torch.softmax(scores, dim=-1).masked_fill(
|
114 |
+
mask, 0.0) # (batch, head, time1, time2)
|
115 |
+
# NOTE(xcsong): When will `if mask.size(2) > 0` be False?
|
116 |
+
# 1. onnx(16/-1, -1/-1, 16/0)
|
117 |
+
# 2. jit (16/-1, -1/-1, 16/0, 16/4)
|
118 |
+
else:
|
119 |
+
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
|
120 |
+
|
121 |
+
p_attn = self.dropout(attn)
|
122 |
+
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
|
123 |
+
x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
|
124 |
+
self.h * self.d_k)
|
125 |
+
) # (batch, time1, d_model)
|
126 |
+
|
127 |
+
return self.linear_out(x) # (batch, time1, d_model)
|
128 |
+
|
129 |
+
def forward(
|
130 |
+
self,
|
131 |
+
query: torch.Tensor,
|
132 |
+
key: torch.Tensor,
|
133 |
+
value: torch.Tensor,
|
134 |
+
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
135 |
+
pos_emb: torch.Tensor = torch.empty(0),
|
136 |
+
cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
|
137 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
138 |
+
"""Compute scaled dot product attention.
|
139 |
+
|
140 |
+
Args:
|
141 |
+
query (torch.Tensor): Query tensor (#batch, time1, size).
|
142 |
+
key (torch.Tensor): Key tensor (#batch, time2, size).
|
143 |
+
value (torch.Tensor): Value tensor (#batch, time2, size).
|
144 |
+
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
|
145 |
+
(#batch, time1, time2).
|
146 |
+
1.When applying cross attention between decoder and encoder,
|
147 |
+
the batch padding mask for input is in (#batch, 1, T) shape.
|
148 |
+
2.When applying self attention of encoder,
|
149 |
+
the mask is in (#batch, T, T) shape.
|
150 |
+
3.When applying self attention of decoder,
|
151 |
+
the mask is in (#batch, L, L) shape.
|
152 |
+
4.If the different position in decoder see different block
|
153 |
+
of the encoder, such as Mocha, the passed in mask could be
|
154 |
+
in (#batch, L, T) shape. But there is no such case in current
|
155 |
+
CosyVoice.
|
156 |
+
cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
|
157 |
+
where `cache_t == chunk_size * num_decoding_left_chunks`
|
158 |
+
and `head * d_k == size`
|
159 |
+
|
160 |
+
|
161 |
+
Returns:
|
162 |
+
torch.Tensor: Output tensor (#batch, time1, d_model).
|
163 |
+
torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
|
164 |
+
where `cache_t == chunk_size * num_decoding_left_chunks`
|
165 |
+
and `head * d_k == size`
|
166 |
+
|
167 |
+
"""
|
168 |
+
q, k, v = self.forward_qkv(query, key, value)
|
169 |
+
|
170 |
+
# NOTE(xcsong):
|
171 |
+
# when export onnx model, for 1st chunk, we feed
|
172 |
+
# cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
|
173 |
+
# or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
|
174 |
+
# In all modes, `if cache.size(0) > 0` will alwayse be `True`
|
175 |
+
# and we will always do splitting and
|
176 |
+
# concatnation(this will simplify onnx export). Note that
|
177 |
+
# it's OK to concat & split zero-shaped tensors(see code below).
|
178 |
+
# when export jit model, for 1st chunk, we always feed
|
179 |
+
# cache(0, 0, 0, 0) since jit supports dynamic if-branch.
|
180 |
+
# >>> a = torch.ones((1, 2, 0, 4))
|
181 |
+
# >>> b = torch.ones((1, 2, 3, 4))
|
182 |
+
# >>> c = torch.cat((a, b), dim=2)
|
183 |
+
# >>> torch.equal(b, c) # True
|
184 |
+
# >>> d = torch.split(a, 2, dim=-1)
|
185 |
+
# >>> torch.equal(d[0], d[1]) # True
|
186 |
+
if cache.size(0) > 0:
|
187 |
+
key_cache, value_cache = torch.split(cache,
|
188 |
+
cache.size(-1) // 2,
|
189 |
+
dim=-1)
|
190 |
+
k = torch.cat([key_cache, k], dim=2)
|
191 |
+
v = torch.cat([value_cache, v], dim=2)
|
192 |
+
# NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
|
193 |
+
# non-trivial to calculate `next_cache_start` here.
|
194 |
+
new_cache = torch.cat((k, v), dim=-1)
|
195 |
+
|
196 |
+
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
|
197 |
+
return self.forward_attention(v, scores, mask), new_cache
|
198 |
+
|
199 |
+
|
200 |
+
class RelPositionMultiHeadedAttention(MultiHeadedAttention):
|
201 |
+
"""Multi-Head Attention layer with relative position encoding.
|
202 |
+
Paper: https://arxiv.org/abs/1901.02860
|
203 |
+
Args:
|
204 |
+
n_head (int): The number of heads.
|
205 |
+
n_feat (int): The number of features.
|
206 |
+
dropout_rate (float): Dropout rate.
|
207 |
+
"""
|
208 |
+
|
209 |
+
def __init__(self,
|
210 |
+
n_head: int,
|
211 |
+
n_feat: int,
|
212 |
+
dropout_rate: float,
|
213 |
+
key_bias: bool = True):
|
214 |
+
"""Construct an RelPositionMultiHeadedAttention object."""
|
215 |
+
super().__init__(n_head, n_feat, dropout_rate, key_bias)
|
216 |
+
# linear transformation for positional encoding
|
217 |
+
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
|
218 |
+
# these two learnable bias are used in matrix c and matrix d
|
219 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
220 |
+
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
221 |
+
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
222 |
+
torch.nn.init.xavier_uniform_(self.pos_bias_u)
|
223 |
+
torch.nn.init.xavier_uniform_(self.pos_bias_v)
|
224 |
+
|
225 |
+
def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
|
226 |
+
"""Compute relative positional encoding.
|
227 |
+
|
228 |
+
Args:
|
229 |
+
x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
|
230 |
+
time1 means the length of query vector.
|
231 |
+
|
232 |
+
Returns:
|
233 |
+
torch.Tensor: Output tensor.
|
234 |
+
|
235 |
+
"""
|
236 |
+
zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
|
237 |
+
device=x.device,
|
238 |
+
dtype=x.dtype)
|
239 |
+
x_padded = torch.cat([zero_pad, x], dim=-1)
|
240 |
+
|
241 |
+
x_padded = x_padded.view(x.size()[0],
|
242 |
+
x.size()[1],
|
243 |
+
x.size(3) + 1, x.size(2))
|
244 |
+
x = x_padded[:, :, 1:].view_as(x)[
|
245 |
+
:, :, :, : x.size(-1) // 2 + 1
|
246 |
+
] # only keep the positions from 0 to time2
|
247 |
+
return x
|
248 |
+
|
249 |
+
def forward(
|
250 |
+
self,
|
251 |
+
query: torch.Tensor,
|
252 |
+
key: torch.Tensor,
|
253 |
+
value: torch.Tensor,
|
254 |
+
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
255 |
+
pos_emb: torch.Tensor = torch.empty(0),
|
256 |
+
cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
|
257 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
258 |
+
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
|
259 |
+
Args:
|
260 |
+
query (torch.Tensor): Query tensor (#batch, time1, size).
|
261 |
+
key (torch.Tensor): Key tensor (#batch, time2, size).
|
262 |
+
value (torch.Tensor): Value tensor (#batch, time2, size).
|
263 |
+
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
|
264 |
+
(#batch, time1, time2), (0, 0, 0) means fake mask.
|
265 |
+
pos_emb (torch.Tensor): Positional embedding tensor
|
266 |
+
(#batch, time2, size).
|
267 |
+
cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
|
268 |
+
where `cache_t == chunk_size * num_decoding_left_chunks`
|
269 |
+
and `head * d_k == size`
|
270 |
+
Returns:
|
271 |
+
torch.Tensor: Output tensor (#batch, time1, d_model).
|
272 |
+
torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
|
273 |
+
where `cache_t == chunk_size * num_decoding_left_chunks`
|
274 |
+
and `head * d_k == size`
|
275 |
+
"""
|
276 |
+
q, k, v = self.forward_qkv(query, key, value)
|
277 |
+
q = q.transpose(1, 2) # (batch, time1, head, d_k)
|
278 |
+
|
279 |
+
# NOTE(xcsong):
|
280 |
+
# when export onnx model, for 1st chunk, we feed
|
281 |
+
# cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
|
282 |
+
# or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
|
283 |
+
# In all modes, `if cache.size(0) > 0` will alwayse be `True`
|
284 |
+
# and we will always do splitting and
|
285 |
+
# concatnation(this will simplify onnx export). Note that
|
286 |
+
# it's OK to concat & split zero-shaped tensors(see code below).
|
287 |
+
# when export jit model, for 1st chunk, we always feed
|
288 |
+
# cache(0, 0, 0, 0) since jit supports dynamic if-branch.
|
289 |
+
# >>> a = torch.ones((1, 2, 0, 4))
|
290 |
+
# >>> b = torch.ones((1, 2, 3, 4))
|
291 |
+
# >>> c = torch.cat((a, b), dim=2)
|
292 |
+
# >>> torch.equal(b, c) # True
|
293 |
+
# >>> d = torch.split(a, 2, dim=-1)
|
294 |
+
# >>> torch.equal(d[0], d[1]) # True
|
295 |
+
if cache.size(0) > 0:
|
296 |
+
key_cache, value_cache = torch.split(cache,
|
297 |
+
cache.size(-1) // 2,
|
298 |
+
dim=-1)
|
299 |
+
k = torch.cat([key_cache, k], dim=2)
|
300 |
+
v = torch.cat([value_cache, v], dim=2)
|
301 |
+
# NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
|
302 |
+
# non-trivial to calculate `next_cache_start` here.
|
303 |
+
new_cache = torch.cat((k, v), dim=-1)
|
304 |
+
|
305 |
+
n_batch_pos = pos_emb.size(0)
|
306 |
+
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
|
307 |
+
p = p.transpose(1, 2) # (batch, head, time1, d_k)
|
308 |
+
|
309 |
+
# (batch, head, time1, d_k)
|
310 |
+
q_with_bias_u = (q + self.pos_bias_u.to(q.device)).transpose(1, 2)
|
311 |
+
# (batch, head, time1, d_k)
|
312 |
+
q_with_bias_v = (q + self.pos_bias_v.to(q.device)).transpose(1, 2)
|
313 |
+
|
314 |
+
# compute attention score
|
315 |
+
# first compute matrix a and matrix c
|
316 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
317 |
+
# (batch, head, time1, time2)
|
318 |
+
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
|
319 |
+
|
320 |
+
# compute matrix b and matrix d
|
321 |
+
# (batch, head, time1, time2)
|
322 |
+
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
|
323 |
+
# NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
|
324 |
+
if matrix_ac.shape != matrix_bd.shape:
|
325 |
+
matrix_bd = self.rel_shift(matrix_bd)
|
326 |
+
|
327 |
+
scores = (matrix_ac + matrix_bd) / math.sqrt(
|
328 |
+
self.d_k) # (batch, head, time1, time2)
|
329 |
+
|
330 |
+
return self.forward_attention(v, scores, mask), new_cache
|
src/chatterbox/models/s3gen/transformer/convolution.py
ADDED
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
|
2 |
+
# 2024 Alibaba Inc (Xiang Lyu)
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# Modified from ESPnet(https://github.com/espnet/espnet)
|
16 |
+
"""ConvolutionModule definition."""
|
17 |
+
|
18 |
+
from typing import Tuple
|
19 |
+
|
20 |
+
import torch
|
21 |
+
from torch import nn
|
22 |
+
|
23 |
+
|
24 |
+
class ConvolutionModule(nn.Module):
|
25 |
+
"""ConvolutionModule in Conformer model."""
|
26 |
+
|
27 |
+
def __init__(self,
|
28 |
+
channels: int,
|
29 |
+
kernel_size: int = 15,
|
30 |
+
activation: nn.Module = nn.ReLU(),
|
31 |
+
norm: str = "batch_norm",
|
32 |
+
causal: bool = False,
|
33 |
+
bias: bool = True):
|
34 |
+
"""Construct an ConvolutionModule object.
|
35 |
+
Args:
|
36 |
+
channels (int): The number of channels of conv layers.
|
37 |
+
kernel_size (int): Kernel size of conv layers.
|
38 |
+
causal (int): Whether use causal convolution or not
|
39 |
+
"""
|
40 |
+
super().__init__()
|
41 |
+
|
42 |
+
self.pointwise_conv1 = nn.Conv1d(
|
43 |
+
channels,
|
44 |
+
2 * channels,
|
45 |
+
kernel_size=1,
|
46 |
+
stride=1,
|
47 |
+
padding=0,
|
48 |
+
bias=bias,
|
49 |
+
)
|
50 |
+
# self.lorder is used to distinguish if it's a causal convolution,
|
51 |
+
# if self.lorder > 0: it's a causal convolution, the input will be
|
52 |
+
# padded with self.lorder frames on the left in forward.
|
53 |
+
# else: it's a symmetrical convolution
|
54 |
+
if causal:
|
55 |
+
padding = 0
|
56 |
+
self.lorder = kernel_size - 1
|
57 |
+
else:
|
58 |
+
# kernel_size should be an odd number for none causal convolution
|
59 |
+
assert (kernel_size - 1) % 2 == 0
|
60 |
+
padding = (kernel_size - 1) // 2
|
61 |
+
self.lorder = 0
|
62 |
+
self.depthwise_conv = nn.Conv1d(
|
63 |
+
channels,
|
64 |
+
channels,
|
65 |
+
kernel_size,
|
66 |
+
stride=1,
|
67 |
+
padding=padding,
|
68 |
+
groups=channels,
|
69 |
+
bias=bias,
|
70 |
+
)
|
71 |
+
|
72 |
+
assert norm in ['batch_norm', 'layer_norm']
|
73 |
+
if norm == "batch_norm":
|
74 |
+
self.use_layer_norm = False
|
75 |
+
self.norm = nn.BatchNorm1d(channels)
|
76 |
+
else:
|
77 |
+
self.use_layer_norm = True
|
78 |
+
self.norm = nn.LayerNorm(channels)
|
79 |
+
|
80 |
+
self.pointwise_conv2 = nn.Conv1d(
|
81 |
+
channels,
|
82 |
+
channels,
|
83 |
+
kernel_size=1,
|
84 |
+
stride=1,
|
85 |
+
padding=0,
|
86 |
+
bias=bias,
|
87 |
+
)
|
88 |
+
self.activation = activation
|
89 |
+
|
90 |
+
def forward(
|
91 |
+
self,
|
92 |
+
x: torch.Tensor,
|
93 |
+
mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
94 |
+
cache: torch.Tensor = torch.zeros((0, 0, 0)),
|
95 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
96 |
+
"""Compute convolution module.
|
97 |
+
Args:
|
98 |
+
x (torch.Tensor): Input tensor (#batch, time, channels).
|
99 |
+
mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
|
100 |
+
(0, 0, 0) means fake mask.
|
101 |
+
cache (torch.Tensor): left context cache, it is only
|
102 |
+
used in causal convolution (#batch, channels, cache_t),
|
103 |
+
(0, 0, 0) meas fake cache.
|
104 |
+
Returns:
|
105 |
+
torch.Tensor: Output tensor (#batch, time, channels).
|
106 |
+
"""
|
107 |
+
# exchange the temporal dimension and the feature dimension
|
108 |
+
x = x.transpose(1, 2) # (#batch, channels, time)
|
109 |
+
|
110 |
+
# mask batch padding
|
111 |
+
if mask_pad.size(2) > 0: # time > 0
|
112 |
+
x.masked_fill_(~mask_pad, 0.0)
|
113 |
+
|
114 |
+
if self.lorder > 0:
|
115 |
+
if cache.size(2) == 0: # cache_t == 0
|
116 |
+
x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
|
117 |
+
else:
|
118 |
+
assert cache.size(0) == x.size(0) # equal batch
|
119 |
+
assert cache.size(1) == x.size(1) # equal channel
|
120 |
+
x = torch.cat((cache, x), dim=2)
|
121 |
+
assert (x.size(2) > self.lorder)
|
122 |
+
new_cache = x[:, :, -self.lorder:]
|
123 |
+
else:
|
124 |
+
# It's better we just return None if no cache is required,
|
125 |
+
# However, for JIT export, here we just fake one tensor instead of
|
126 |
+
# None.
|
127 |
+
new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
|
128 |
+
|
129 |
+
# GLU mechanism
|
130 |
+
x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
|
131 |
+
x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
|
132 |
+
|
133 |
+
# 1D Depthwise Conv
|
134 |
+
x = self.depthwise_conv(x)
|
135 |
+
if self.use_layer_norm:
|
136 |
+
x = x.transpose(1, 2)
|
137 |
+
x = self.activation(self.norm(x))
|
138 |
+
if self.use_layer_norm:
|
139 |
+
x = x.transpose(1, 2)
|
140 |
+
x = self.pointwise_conv2(x)
|
141 |
+
# mask batch padding
|
142 |
+
if mask_pad.size(2) > 0: # time > 0
|
143 |
+
x.masked_fill_(~mask_pad, 0.0)
|
144 |
+
|
145 |
+
return x.transpose(1, 2), new_cache
|
src/chatterbox/models/s3gen/transformer/embedding.py
ADDED
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
|
2 |
+
# 2024 Alibaba Inc (Xiang Lyu)
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# Modified from ESPnet(https://github.com/espnet/espnet)
|
16 |
+
"""Positonal Encoding Module."""
|
17 |
+
|
18 |
+
import math
|
19 |
+
from typing import Tuple, Union
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import torch.nn.functional as F
|
23 |
+
import numpy as np
|
24 |
+
|
25 |
+
|
26 |
+
class PositionalEncoding(torch.nn.Module):
|
27 |
+
"""Positional encoding.
|
28 |
+
|
29 |
+
:param int d_model: embedding dim
|
30 |
+
:param float dropout_rate: dropout rate
|
31 |
+
:param int max_len: maximum input length
|
32 |
+
|
33 |
+
PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
|
34 |
+
PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
|
35 |
+
"""
|
36 |
+
|
37 |
+
def __init__(self,
|
38 |
+
d_model: int,
|
39 |
+
dropout_rate: float,
|
40 |
+
max_len: int = 5000,
|
41 |
+
reverse: bool = False):
|
42 |
+
"""Construct an PositionalEncoding object."""
|
43 |
+
super().__init__()
|
44 |
+
self.d_model = d_model
|
45 |
+
self.xscale = math.sqrt(self.d_model)
|
46 |
+
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
47 |
+
self.max_len = max_len
|
48 |
+
|
49 |
+
self.pe = torch.zeros(self.max_len, self.d_model)
|
50 |
+
position = torch.arange(0, self.max_len,
|
51 |
+
dtype=torch.float32).unsqueeze(1)
|
52 |
+
div_term = torch.exp(
|
53 |
+
torch.arange(0, self.d_model, 2, dtype=torch.float32) *
|
54 |
+
-(math.log(10000.0) / self.d_model))
|
55 |
+
self.pe[:, 0::2] = torch.sin(position * div_term)
|
56 |
+
self.pe[:, 1::2] = torch.cos(position * div_term)
|
57 |
+
self.pe = self.pe.unsqueeze(0)
|
58 |
+
|
59 |
+
def forward(self,
|
60 |
+
x: torch.Tensor,
|
61 |
+
offset: Union[int, torch.Tensor] = 0) \
|
62 |
+
-> Tuple[torch.Tensor, torch.Tensor]:
|
63 |
+
"""Add positional encoding.
|
64 |
+
|
65 |
+
Args:
|
66 |
+
x (torch.Tensor): Input. Its shape is (batch, time, ...)
|
67 |
+
offset (int, torch.tensor): position offset
|
68 |
+
|
69 |
+
Returns:
|
70 |
+
torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
|
71 |
+
torch.Tensor: for compatibility to RelPositionalEncoding
|
72 |
+
"""
|
73 |
+
|
74 |
+
self.pe = self.pe.to(x.device)
|
75 |
+
pos_emb = self.position_encoding(offset, x.size(1), False)
|
76 |
+
x = x * self.xscale + pos_emb
|
77 |
+
return self.dropout(x), self.dropout(pos_emb)
|
78 |
+
|
79 |
+
def position_encoding(self,
|
80 |
+
offset: Union[int, torch.Tensor],
|
81 |
+
size: int,
|
82 |
+
apply_dropout: bool = True) -> torch.Tensor:
|
83 |
+
""" For getting encoding in a streaming fashion
|
84 |
+
|
85 |
+
Attention!!!!!
|
86 |
+
we apply dropout only once at the whole utterance level in a none
|
87 |
+
streaming way, but will call this function several times with
|
88 |
+
increasing input size in a streaming scenario, so the dropout will
|
89 |
+
be applied several times.
|
90 |
+
|
91 |
+
Args:
|
92 |
+
offset (int or torch.tensor): start offset
|
93 |
+
size (int): required size of position encoding
|
94 |
+
|
95 |
+
Returns:
|
96 |
+
torch.Tensor: Corresponding encoding
|
97 |
+
"""
|
98 |
+
# How to subscript a Union type:
|
99 |
+
# https://github.com/pytorch/pytorch/issues/69434
|
100 |
+
if isinstance(offset, int):
|
101 |
+
assert offset + size <= self.max_len
|
102 |
+
pos_emb = self.pe[:, offset:offset + size]
|
103 |
+
elif isinstance(offset, torch.Tensor) and offset.dim() == 0: # scalar
|
104 |
+
assert offset + size <= self.max_len
|
105 |
+
pos_emb = self.pe[:, offset:offset + size]
|
106 |
+
else: # for batched streaming decoding on GPU
|
107 |
+
assert torch.max(offset) + size <= self.max_len
|
108 |
+
index = offset.unsqueeze(1) + \
|
109 |
+
torch.arange(0, size).to(offset.device) # B X T
|
110 |
+
flag = index > 0
|
111 |
+
# remove negative offset
|
112 |
+
index = index * flag
|
113 |
+
pos_emb = F.embedding(index, self.pe[0]) # B X T X d_model
|
114 |
+
|
115 |
+
if apply_dropout:
|
116 |
+
pos_emb = self.dropout(pos_emb)
|
117 |
+
return pos_emb
|
118 |
+
|
119 |
+
|
120 |
+
class RelPositionalEncoding(PositionalEncoding):
|
121 |
+
"""Relative positional encoding module.
|
122 |
+
See : Appendix B in https://arxiv.org/abs/1901.02860
|
123 |
+
Args:
|
124 |
+
d_model (int): Embedding dimension.
|
125 |
+
dropout_rate (float): Dropout rate.
|
126 |
+
max_len (int): Maximum input length.
|
127 |
+
"""
|
128 |
+
|
129 |
+
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
|
130 |
+
"""Initialize class."""
|
131 |
+
super().__init__(d_model, dropout_rate, max_len, reverse=True)
|
132 |
+
|
133 |
+
def forward(self,
|
134 |
+
x: torch.Tensor,
|
135 |
+
offset: Union[int, torch.Tensor] = 0) \
|
136 |
+
-> Tuple[torch.Tensor, torch.Tensor]:
|
137 |
+
"""Compute positional encoding.
|
138 |
+
Args:
|
139 |
+
x (torch.Tensor): Input tensor (batch, time, `*`).
|
140 |
+
Returns:
|
141 |
+
torch.Tensor: Encoded tensor (batch, time, `*`).
|
142 |
+
torch.Tensor: Positional embedding tensor (1, time, `*`).
|
143 |
+
"""
|
144 |
+
self.pe = self.pe.to(x.device)
|
145 |
+
x = x * self.xscale
|
146 |
+
pos_emb = self.position_encoding(offset, x.size(1), False)
|
147 |
+
return self.dropout(x), self.dropout(pos_emb)
|
148 |
+
|
149 |
+
|
150 |
+
class WhisperPositionalEncoding(PositionalEncoding):
|
151 |
+
""" Sinusoids position encoding used in openai-whisper.encoder
|
152 |
+
"""
|
153 |
+
|
154 |
+
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 1500):
|
155 |
+
super().__init__(d_model, dropout_rate, max_len)
|
156 |
+
self.xscale = 1.0
|
157 |
+
log_timescale_increment = np.log(10000) / (d_model // 2 - 1)
|
158 |
+
inv_timescales = torch.exp(-log_timescale_increment *
|
159 |
+
torch.arange(d_model // 2))
|
160 |
+
scaled_time = torch.arange(max_len)[:, np.newaxis] * \
|
161 |
+
inv_timescales[np.newaxis, :]
|
162 |
+
pe = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
|
163 |
+
delattr(self, "pe")
|
164 |
+
self.register_buffer("pe", pe.unsqueeze(0))
|
165 |
+
|
166 |
+
|
167 |
+
class LearnablePositionalEncoding(PositionalEncoding):
|
168 |
+
""" Learnable position encoding used in openai-whisper.decoder
|
169 |
+
"""
|
170 |
+
|
171 |
+
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 448):
|
172 |
+
super().__init__(d_model, dropout_rate, max_len)
|
173 |
+
# NOTE(xcsong): overwrite self.pe & self.xscale
|
174 |
+
self.pe = torch.nn.Parameter(torch.empty(1, max_len, d_model))
|
175 |
+
self.xscale = 1.0
|
176 |
+
|
177 |
+
|
178 |
+
class NoPositionalEncoding(torch.nn.Module):
|
179 |
+
""" No position encoding
|
180 |
+
"""
|
181 |
+
|
182 |
+
def __init__(self, d_model: int, dropout_rate: float):
|
183 |
+
super().__init__()
|
184 |
+
self.d_model = d_model
|
185 |
+
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
186 |
+
|
187 |
+
def forward(self,
|
188 |
+
x: torch.Tensor,
|
189 |
+
offset: Union[int, torch.Tensor] = 0) \
|
190 |
+
-> Tuple[torch.Tensor, torch.Tensor]:
|
191 |
+
""" Just return zero vector for interface compatibility
|
192 |
+
"""
|
193 |
+
pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device)
|
194 |
+
return self.dropout(x), pos_emb
|
195 |
+
|
196 |
+
def position_encoding(self, offset: Union[int, torch.Tensor],
|
197 |
+
size: int) -> torch.Tensor:
|
198 |
+
return torch.zeros(1, size, self.d_model)
|
199 |
+
|
200 |
+
|
201 |
+
class EspnetRelPositionalEncoding(torch.nn.Module):
|
202 |
+
"""Relative positional encoding module (new implementation).
|
203 |
+
|
204 |
+
Details can be found in https://github.com/espnet/espnet/pull/2816.
|
205 |
+
|
206 |
+
See : Appendix B in https://arxiv.org/abs/1901.02860
|
207 |
+
|
208 |
+
Args:
|
209 |
+
d_model (int): Embedding dimension.
|
210 |
+
dropout_rate (float): Dropout rate.
|
211 |
+
max_len (int): Maximum input length.
|
212 |
+
|
213 |
+
"""
|
214 |
+
|
215 |
+
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
|
216 |
+
"""Construct an PositionalEncoding object."""
|
217 |
+
super(EspnetRelPositionalEncoding, self).__init__()
|
218 |
+
self.d_model = d_model
|
219 |
+
self.xscale = math.sqrt(self.d_model)
|
220 |
+
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
221 |
+
self.pe = None
|
222 |
+
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
223 |
+
|
224 |
+
def extend_pe(self, x: torch.Tensor):
|
225 |
+
"""Reset the positional encodings."""
|
226 |
+
if self.pe is not None:
|
227 |
+
# self.pe contains both positive and negative parts
|
228 |
+
# the length of self.pe is 2 * input_len - 1
|
229 |
+
if self.pe.size(1) >= x.size(1) * 2 - 1:
|
230 |
+
if self.pe.dtype != x.dtype or self.pe.device != x.device:
|
231 |
+
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
232 |
+
return
|
233 |
+
# Suppose `i` means to the position of query vecotr and `j` means the
|
234 |
+
# position of key vector. We use position relative positions when keys
|
235 |
+
# are to the left (i>j) and negative relative positions otherwise (i<j).
|
236 |
+
pe_positive = torch.zeros(x.size(1), self.d_model)
|
237 |
+
pe_negative = torch.zeros(x.size(1), self.d_model)
|
238 |
+
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
239 |
+
div_term = torch.exp(
|
240 |
+
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
241 |
+
* -(math.log(10000.0) / self.d_model)
|
242 |
+
)
|
243 |
+
pe_positive[:, 0::2] = torch.sin(position * div_term)
|
244 |
+
pe_positive[:, 1::2] = torch.cos(position * div_term)
|
245 |
+
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
|
246 |
+
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
|
247 |
+
|
248 |
+
# Reserve the order of positive indices and concat both positive and
|
249 |
+
# negative indices. This is used to support the shifting trick
|
250 |
+
# as in https://arxiv.org/abs/1901.02860
|
251 |
+
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
|
252 |
+
pe_negative = pe_negative[1:].unsqueeze(0)
|
253 |
+
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
254 |
+
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
255 |
+
|
256 |
+
def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) \
|
257 |
+
-> Tuple[torch.Tensor, torch.Tensor]:
|
258 |
+
"""Add positional encoding.
|
259 |
+
|
260 |
+
Args:
|
261 |
+
x (torch.Tensor): Input tensor (batch, time, `*`).
|
262 |
+
|
263 |
+
Returns:
|
264 |
+
torch.Tensor: Encoded tensor (batch, time, `*`).
|
265 |
+
|
266 |
+
"""
|
267 |
+
self.extend_pe(x)
|
268 |
+
x = x * self.xscale
|
269 |
+
pos_emb = self.position_encoding(size=x.size(1), offset=offset)
|
270 |
+
return self.dropout(x), self.dropout(pos_emb)
|
271 |
+
|
272 |
+
def position_encoding(self,
|
273 |
+
offset: Union[int, torch.Tensor],
|
274 |
+
size: int) -> torch.Tensor:
|
275 |
+
""" For getting encoding in a streaming fashion
|
276 |
+
|
277 |
+
Attention!!!!!
|
278 |
+
we apply dropout only once at the whole utterance level in a none
|
279 |
+
streaming way, but will call this function several times with
|
280 |
+
increasing input size in a streaming scenario, so the dropout will
|
281 |
+
be applied several times.
|
282 |
+
|
283 |
+
Args:
|
284 |
+
offset (int or torch.tensor): start offset
|
285 |
+
size (int): required size of position encoding
|
286 |
+
|
287 |
+
Returns:
|
288 |
+
torch.Tensor: Corresponding encoding
|
289 |
+
"""
|
290 |
+
pos_emb = self.pe[
|
291 |
+
:,
|
292 |
+
self.pe.size(1) // 2 - size + 1: self.pe.size(1) // 2 + size,
|
293 |
+
]
|
294 |
+
return pos_emb
|
src/chatterbox/models/s3gen/transformer/encoder_layer.py
ADDED
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
|
2 |
+
# 2022 Xingchen Song ([email protected])
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# Modified from ESPnet(https://github.com/espnet/espnet)
|
16 |
+
"""Encoder self-attention layer definition."""
|
17 |
+
|
18 |
+
from typing import Optional, Tuple
|
19 |
+
|
20 |
+
import torch
|
21 |
+
from torch import nn
|
22 |
+
|
23 |
+
|
24 |
+
class TransformerEncoderLayer(nn.Module):
|
25 |
+
"""Encoder layer module.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
size (int): Input dimension.
|
29 |
+
self_attn (torch.nn.Module): Self-attention module instance.
|
30 |
+
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
|
31 |
+
instance can be used as the argument.
|
32 |
+
feed_forward (torch.nn.Module): Feed-forward module instance.
|
33 |
+
`PositionwiseFeedForward`, instance can be used as the argument.
|
34 |
+
dropout_rate (float): Dropout rate.
|
35 |
+
normalize_before (bool):
|
36 |
+
True: use layer_norm before each sub-block.
|
37 |
+
False: to use layer_norm after each sub-block.
|
38 |
+
"""
|
39 |
+
|
40 |
+
def __init__(
|
41 |
+
self,
|
42 |
+
size: int,
|
43 |
+
self_attn: torch.nn.Module,
|
44 |
+
feed_forward: torch.nn.Module,
|
45 |
+
dropout_rate: float,
|
46 |
+
normalize_before: bool = True,
|
47 |
+
):
|
48 |
+
"""Construct an EncoderLayer object."""
|
49 |
+
super().__init__()
|
50 |
+
self.self_attn = self_attn
|
51 |
+
self.feed_forward = feed_forward
|
52 |
+
self.norm1 = nn.LayerNorm(size, eps=1e-12)
|
53 |
+
self.norm2 = nn.LayerNorm(size, eps=1e-12)
|
54 |
+
self.dropout = nn.Dropout(dropout_rate)
|
55 |
+
self.size = size
|
56 |
+
self.normalize_before = normalize_before
|
57 |
+
|
58 |
+
def forward(
|
59 |
+
self,
|
60 |
+
x: torch.Tensor,
|
61 |
+
mask: torch.Tensor,
|
62 |
+
pos_emb: torch.Tensor,
|
63 |
+
mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
64 |
+
att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
|
65 |
+
cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
|
66 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
67 |
+
"""Compute encoded features.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
x (torch.Tensor): (#batch, time, size)
|
71 |
+
mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
|
72 |
+
(0, 0, 0) means fake mask.
|
73 |
+
pos_emb (torch.Tensor): just for interface compatibility
|
74 |
+
to ConformerEncoderLayer
|
75 |
+
mask_pad (torch.Tensor): does not used in transformer layer,
|
76 |
+
just for unified api with conformer.
|
77 |
+
att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
|
78 |
+
(#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
|
79 |
+
cnn_cache (torch.Tensor): Convolution cache in conformer layer
|
80 |
+
(#batch=1, size, cache_t2), not used here, it's for interface
|
81 |
+
compatibility to ConformerEncoderLayer.
|
82 |
+
Returns:
|
83 |
+
torch.Tensor: Output tensor (#batch, time, size).
|
84 |
+
torch.Tensor: Mask tensor (#batch, time, time).
|
85 |
+
torch.Tensor: att_cache tensor,
|
86 |
+
(#batch=1, head, cache_t1 + time, d_k * 2).
|
87 |
+
torch.Tensor: cnn_cahce tensor (#batch=1, size, cache_t2).
|
88 |
+
|
89 |
+
"""
|
90 |
+
residual = x
|
91 |
+
if self.normalize_before:
|
92 |
+
x = self.norm1(x)
|
93 |
+
x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb=pos_emb, cache=att_cache)
|
94 |
+
x = residual + self.dropout(x_att)
|
95 |
+
if not self.normalize_before:
|
96 |
+
x = self.norm1(x)
|
97 |
+
|
98 |
+
residual = x
|
99 |
+
if self.normalize_before:
|
100 |
+
x = self.norm2(x)
|
101 |
+
x = residual + self.dropout(self.feed_forward(x))
|
102 |
+
if not self.normalize_before:
|
103 |
+
x = self.norm2(x)
|
104 |
+
|
105 |
+
fake_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
|
106 |
+
return x, mask, new_att_cache, fake_cnn_cache
|
107 |
+
|
108 |
+
|
109 |
+
class ConformerEncoderLayer(nn.Module):
|
110 |
+
"""Encoder layer module.
|
111 |
+
Args:
|
112 |
+
size (int): Input dimension.
|
113 |
+
self_attn (torch.nn.Module): Self-attention module instance.
|
114 |
+
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
|
115 |
+
instance can be used as the argument.
|
116 |
+
feed_forward (torch.nn.Module): Feed-forward module instance.
|
117 |
+
`PositionwiseFeedForward` instance can be used as the argument.
|
118 |
+
feed_forward_macaron (torch.nn.Module): Additional feed-forward module
|
119 |
+
instance.
|
120 |
+
`PositionwiseFeedForward` instance can be used as the argument.
|
121 |
+
conv_module (torch.nn.Module): Convolution module instance.
|
122 |
+
`ConvlutionModule` instance can be used as the argument.
|
123 |
+
dropout_rate (float): Dropout rate.
|
124 |
+
normalize_before (bool):
|
125 |
+
True: use layer_norm before each sub-block.
|
126 |
+
False: use layer_norm after each sub-block.
|
127 |
+
"""
|
128 |
+
|
129 |
+
def __init__(
|
130 |
+
self,
|
131 |
+
size: int,
|
132 |
+
self_attn: torch.nn.Module,
|
133 |
+
feed_forward: Optional[nn.Module] = None,
|
134 |
+
feed_forward_macaron: Optional[nn.Module] = None,
|
135 |
+
conv_module: Optional[nn.Module] = None,
|
136 |
+
dropout_rate: float = 0.1,
|
137 |
+
normalize_before: bool = True,
|
138 |
+
):
|
139 |
+
"""Construct an EncoderLayer object."""
|
140 |
+
super().__init__()
|
141 |
+
self.self_attn = self_attn
|
142 |
+
self.feed_forward = feed_forward
|
143 |
+
self.feed_forward_macaron = feed_forward_macaron
|
144 |
+
self.conv_module = conv_module
|
145 |
+
self.norm_ff = nn.LayerNorm(size, eps=1e-12) # for the FNN module
|
146 |
+
self.norm_mha = nn.LayerNorm(size, eps=1e-12) # for the MHA module
|
147 |
+
if feed_forward_macaron is not None:
|
148 |
+
self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-12)
|
149 |
+
self.ff_scale = 0.5
|
150 |
+
else:
|
151 |
+
self.ff_scale = 1.0
|
152 |
+
if self.conv_module is not None:
|
153 |
+
self.norm_conv = nn.LayerNorm(size, eps=1e-12) # for the CNN module
|
154 |
+
self.norm_final = nn.LayerNorm(
|
155 |
+
size, eps=1e-12) # for the final output of the block
|
156 |
+
self.dropout = nn.Dropout(dropout_rate)
|
157 |
+
self.size = size
|
158 |
+
self.normalize_before = normalize_before
|
159 |
+
|
160 |
+
def forward(
|
161 |
+
self,
|
162 |
+
x: torch.Tensor,
|
163 |
+
mask: torch.Tensor,
|
164 |
+
pos_emb: torch.Tensor,
|
165 |
+
mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
166 |
+
att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
|
167 |
+
cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
|
168 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
169 |
+
"""Compute encoded features.
|
170 |
+
|
171 |
+
Args:
|
172 |
+
x (torch.Tensor): (#batch, time, size)
|
173 |
+
mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
|
174 |
+
(0, 0, 0) means fake mask.
|
175 |
+
pos_emb (torch.Tensor): positional encoding, must not be None
|
176 |
+
for ConformerEncoderLayer.
|
177 |
+
mask_pad (torch.Tensor): batch padding mask used for conv module.
|
178 |
+
(#batch, 1,time), (0, 0, 0) means fake mask.
|
179 |
+
att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
|
180 |
+
(#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
|
181 |
+
cnn_cache (torch.Tensor): Convolution cache in conformer layer
|
182 |
+
(#batch=1, size, cache_t2)
|
183 |
+
Returns:
|
184 |
+
torch.Tensor: Output tensor (#batch, time, size).
|
185 |
+
torch.Tensor: Mask tensor (#batch, time, time).
|
186 |
+
torch.Tensor: att_cache tensor,
|
187 |
+
(#batch=1, head, cache_t1 + time, d_k * 2).
|
188 |
+
torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2).
|
189 |
+
"""
|
190 |
+
|
191 |
+
# whether to use macaron style
|
192 |
+
if self.feed_forward_macaron is not None:
|
193 |
+
residual = x
|
194 |
+
if self.normalize_before:
|
195 |
+
x = self.norm_ff_macaron(x)
|
196 |
+
x = residual + self.ff_scale * self.dropout(
|
197 |
+
self.feed_forward_macaron(x))
|
198 |
+
if not self.normalize_before:
|
199 |
+
x = self.norm_ff_macaron(x)
|
200 |
+
|
201 |
+
# multi-headed self-attention module
|
202 |
+
residual = x
|
203 |
+
if self.normalize_before:
|
204 |
+
x = self.norm_mha(x)
|
205 |
+
x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
|
206 |
+
att_cache)
|
207 |
+
x = residual + self.dropout(x_att)
|
208 |
+
if not self.normalize_before:
|
209 |
+
x = self.norm_mha(x)
|
210 |
+
|
211 |
+
# convolution module
|
212 |
+
# Fake new cnn cache here, and then change it in conv_module
|
213 |
+
new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
|
214 |
+
if self.conv_module is not None:
|
215 |
+
residual = x
|
216 |
+
if self.normalize_before:
|
217 |
+
x = self.norm_conv(x)
|
218 |
+
x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
|
219 |
+
x = residual + self.dropout(x)
|
220 |
+
|
221 |
+
if not self.normalize_before:
|
222 |
+
x = self.norm_conv(x)
|
223 |
+
|
224 |
+
# feed forward module
|
225 |
+
residual = x
|
226 |
+
if self.normalize_before:
|
227 |
+
x = self.norm_ff(x)
|
228 |
+
|
229 |
+
x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
|
230 |
+
if not self.normalize_before:
|
231 |
+
x = self.norm_ff(x)
|
232 |
+
|
233 |
+
if self.conv_module is not None:
|
234 |
+
x = self.norm_final(x)
|
235 |
+
|
236 |
+
return x, mask, new_att_cache, new_cnn_cache
|
src/chatterbox/models/s3gen/transformer/positionwise_feed_forward.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2019 Shigeki Karita
|
2 |
+
# 2020 Mobvoi Inc (Binbin Zhang)
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""Positionwise feed forward layer definition."""
|
16 |
+
|
17 |
+
import torch
|
18 |
+
|
19 |
+
|
20 |
+
class PositionwiseFeedForward(torch.nn.Module):
|
21 |
+
"""Positionwise feed forward layer.
|
22 |
+
|
23 |
+
FeedForward are appied on each position of the sequence.
|
24 |
+
The output dim is same with the input dim.
|
25 |
+
|
26 |
+
Args:
|
27 |
+
idim (int): Input dimenstion.
|
28 |
+
hidden_units (int): The number of hidden units.
|
29 |
+
dropout_rate (float): Dropout rate.
|
30 |
+
activation (torch.nn.Module): Activation function
|
31 |
+
"""
|
32 |
+
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
idim: int,
|
36 |
+
hidden_units: int,
|
37 |
+
dropout_rate: float,
|
38 |
+
activation: torch.nn.Module = torch.nn.ReLU(),
|
39 |
+
):
|
40 |
+
"""Construct a PositionwiseFeedForward object."""
|
41 |
+
super(PositionwiseFeedForward, self).__init__()
|
42 |
+
self.w_1 = torch.nn.Linear(idim, hidden_units)
|
43 |
+
self.activation = activation
|
44 |
+
self.dropout = torch.nn.Dropout(dropout_rate)
|
45 |
+
self.w_2 = torch.nn.Linear(hidden_units, idim)
|
46 |
+
|
47 |
+
def forward(self, xs: torch.Tensor) -> torch.Tensor:
|
48 |
+
"""Forward function.
|
49 |
+
|
50 |
+
Args:
|
51 |
+
xs: input tensor (B, L, D)
|
52 |
+
Returns:
|
53 |
+
output tensor, (B, L, D)
|
54 |
+
"""
|
55 |
+
return self.w_2(self.dropout(self.activation(self.w_1(xs))))
|
56 |
+
|
57 |
+
|
58 |
+
class MoEFFNLayer(torch.nn.Module):
|
59 |
+
"""
|
60 |
+
Mixture of expert with Positionwise feed forward layer
|
61 |
+
See also figure 1 in https://arxiv.org/pdf/2305.15663.pdf
|
62 |
+
The output dim is same with the input dim.
|
63 |
+
|
64 |
+
Modified from https://github.com/Lightning-AI/lit-gpt/pull/823
|
65 |
+
https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219
|
66 |
+
Args:
|
67 |
+
n_expert: number of expert.
|
68 |
+
n_expert_per_token: The actual number of experts used for each frame
|
69 |
+
idim (int): Input dimenstion.
|
70 |
+
hidden_units (int): The number of hidden units.
|
71 |
+
dropout_rate (float): Dropout rate.
|
72 |
+
activation (torch.nn.Module): Activation function
|
73 |
+
"""
|
74 |
+
|
75 |
+
def __init__(
|
76 |
+
self,
|
77 |
+
n_expert: int,
|
78 |
+
n_expert_per_token: int,
|
79 |
+
idim: int,
|
80 |
+
hidden_units: int,
|
81 |
+
dropout_rate: float,
|
82 |
+
activation: torch.nn.Module = torch.nn.ReLU(),
|
83 |
+
):
|
84 |
+
super(MoEFFNLayer, self).__init__()
|
85 |
+
self.gate = torch.nn.Linear(idim, n_expert, bias=False)
|
86 |
+
self.experts = torch.nn.ModuleList(
|
87 |
+
PositionwiseFeedForward(idim, hidden_units, dropout_rate,
|
88 |
+
activation) for _ in range(n_expert))
|
89 |
+
self.n_expert_per_token = n_expert_per_token
|
90 |
+
|
91 |
+
def forward(self, xs: torch.Tensor) -> torch.Tensor:
|
92 |
+
"""Foward function.
|
93 |
+
Args:
|
94 |
+
xs: input tensor (B, L, D)
|
95 |
+
Returns:
|
96 |
+
output tensor, (B, L, D)
|
97 |
+
|
98 |
+
"""
|
99 |
+
B, L, D = xs.size(
|
100 |
+
) # batch size, sequence length, embedding dimension (idim)
|
101 |
+
xs = xs.view(-1, D) # (B*L, D)
|
102 |
+
router = self.gate(xs) # (B*L, n_expert)
|
103 |
+
logits, indices = torch.topk(
|
104 |
+
router, self.n_expert_per_token
|
105 |
+
) # probs:(B*L, n_expert), indices: (B*L, n_expert)
|
106 |
+
weights = torch.nn.functional.softmax(
|
107 |
+
logits, dim=1,
|
108 |
+
dtype=torch.float).to(dtype=xs.dtype) # (B*L, n_expert_per_token)
|
109 |
+
output = torch.zeros_like(xs) # (B*L, D)
|
110 |
+
for i, expert in enumerate(self.experts):
|
111 |
+
mask = indices == i
|
112 |
+
batch_idx, ith_expert = torch.where(mask)
|
113 |
+
output[batch_idx] += weights[batch_idx, ith_expert, None] * expert(
|
114 |
+
xs[batch_idx])
|
115 |
+
return output.view(B, L, D)
|
src/chatterbox/models/s3gen/transformer/subsampling.py
ADDED
@@ -0,0 +1,383 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
|
2 |
+
# 2024 Alibaba Inc (Xiang Lyu)
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# Modified from ESPnet(https://github.com/espnet/espnet)
|
16 |
+
"""Subsampling layer definition."""
|
17 |
+
|
18 |
+
from typing import Tuple, Union
|
19 |
+
|
20 |
+
import torch
|
21 |
+
|
22 |
+
|
23 |
+
class BaseSubsampling(torch.nn.Module):
|
24 |
+
|
25 |
+
def __init__(self):
|
26 |
+
super().__init__()
|
27 |
+
self.right_context = 0
|
28 |
+
self.subsampling_rate = 1
|
29 |
+
|
30 |
+
def position_encoding(self, offset: Union[int, torch.Tensor],
|
31 |
+
size: int) -> torch.Tensor:
|
32 |
+
return self.pos_enc.position_encoding(offset, size)
|
33 |
+
|
34 |
+
|
35 |
+
class EmbedinigNoSubsampling(BaseSubsampling):
|
36 |
+
"""Embedding input without subsampling
|
37 |
+
"""
|
38 |
+
|
39 |
+
def __init__(self, idim: int, odim: int, dropout_rate: float,
|
40 |
+
pos_enc_class: torch.nn.Module):
|
41 |
+
super().__init__()
|
42 |
+
self.embed = torch.nn.Embedding(idim, odim)
|
43 |
+
self.pos_enc = pos_enc_class
|
44 |
+
|
45 |
+
def forward(
|
46 |
+
self,
|
47 |
+
x: torch.Tensor,
|
48 |
+
x_mask: torch.Tensor,
|
49 |
+
offset: Union[int, torch.Tensor] = 0
|
50 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
51 |
+
"""Input x.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
x (torch.Tensor): Input tensor (#batch, time, idim).
|
55 |
+
x_mask (torch.Tensor): Input mask (#batch, 1, time).
|
56 |
+
|
57 |
+
Returns:
|
58 |
+
torch.Tensor: linear input tensor (#batch, time', odim),
|
59 |
+
where time' = time .
|
60 |
+
torch.Tensor: linear input mask (#batch, 1, time'),
|
61 |
+
where time' = time .
|
62 |
+
|
63 |
+
"""
|
64 |
+
x = self.embed(x)
|
65 |
+
x, pos_emb = self.pos_enc(x, offset)
|
66 |
+
return x, pos_emb, x_mask
|
67 |
+
|
68 |
+
|
69 |
+
class LinearNoSubsampling(BaseSubsampling):
|
70 |
+
"""Linear transform the input without subsampling
|
71 |
+
|
72 |
+
Args:
|
73 |
+
idim (int): Input dimension.
|
74 |
+
odim (int): Output dimension.
|
75 |
+
dropout_rate (float): Dropout rate.
|
76 |
+
|
77 |
+
"""
|
78 |
+
|
79 |
+
def __init__(self, idim: int, odim: int, dropout_rate: float,
|
80 |
+
pos_enc_class: torch.nn.Module):
|
81 |
+
"""Construct an linear object."""
|
82 |
+
super().__init__()
|
83 |
+
self.out = torch.nn.Sequential(
|
84 |
+
torch.nn.Linear(idim, odim),
|
85 |
+
torch.nn.LayerNorm(odim, eps=1e-5),
|
86 |
+
torch.nn.Dropout(dropout_rate),
|
87 |
+
)
|
88 |
+
self.pos_enc = pos_enc_class
|
89 |
+
self.right_context = 0
|
90 |
+
self.subsampling_rate = 1
|
91 |
+
|
92 |
+
def forward(
|
93 |
+
self,
|
94 |
+
x: torch.Tensor,
|
95 |
+
x_mask: torch.Tensor,
|
96 |
+
offset: Union[int, torch.Tensor] = 0
|
97 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
98 |
+
"""Input x.
|
99 |
+
|
100 |
+
Args:
|
101 |
+
x (torch.Tensor): Input tensor (#batch, time, idim).
|
102 |
+
x_mask (torch.Tensor): Input mask (#batch, 1, time).
|
103 |
+
|
104 |
+
Returns:
|
105 |
+
torch.Tensor: linear input tensor (#batch, time', odim),
|
106 |
+
where time' = time .
|
107 |
+
torch.Tensor: linear input mask (#batch, 1, time'),
|
108 |
+
where time' = time .
|
109 |
+
|
110 |
+
"""
|
111 |
+
x = self.out(x)
|
112 |
+
x, pos_emb = self.pos_enc(x, offset)
|
113 |
+
return x, pos_emb, x_mask
|
114 |
+
|
115 |
+
|
116 |
+
class Conv1dSubsampling2(BaseSubsampling):
|
117 |
+
"""Convolutional 1D subsampling (to 1/2 length).
|
118 |
+
It is designed for Whisper, ref:
|
119 |
+
https://github.com/openai/whisper/blob/main/whisper/model.py
|
120 |
+
|
121 |
+
Args:
|
122 |
+
idim (int): Input dimension.
|
123 |
+
odim (int): Output dimension.
|
124 |
+
dropout_rate (float): Dropout rate.
|
125 |
+
|
126 |
+
"""
|
127 |
+
|
128 |
+
def __init__(self, idim: int, odim: int, dropout_rate: float,
|
129 |
+
pos_enc_class: torch.nn.Module):
|
130 |
+
"""Construct an Conv1dSubsampling2 object."""
|
131 |
+
super().__init__()
|
132 |
+
self.conv = torch.nn.Sequential(
|
133 |
+
torch.nn.Conv1d(idim, odim, kernel_size=3, padding=1),
|
134 |
+
torch.nn.GELU(),
|
135 |
+
torch.nn.Conv1d(odim, odim, kernel_size=3, stride=2, padding=1),
|
136 |
+
torch.nn.GELU(),
|
137 |
+
)
|
138 |
+
self.pos_enc = pos_enc_class
|
139 |
+
# The right context for every conv layer is computed by:
|
140 |
+
# (kernel_size - 1) * frame_rate_of_this_layer
|
141 |
+
self.subsampling_rate = 2
|
142 |
+
# 4 = (3 - 1) * 1 + (3 - 1) * 1
|
143 |
+
self.right_context = 4
|
144 |
+
|
145 |
+
def forward(
|
146 |
+
self,
|
147 |
+
x: torch.Tensor,
|
148 |
+
x_mask: torch.Tensor,
|
149 |
+
offset: Union[int, torch.Tensor] = 0
|
150 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
151 |
+
"""Subsample x.
|
152 |
+
|
153 |
+
Args:
|
154 |
+
x (torch.Tensor): Input tensor (#batch, time, idim).
|
155 |
+
x_mask (torch.Tensor): Input mask (#batch, 1, time).
|
156 |
+
|
157 |
+
Returns:
|
158 |
+
torch.Tensor: Subsampled tensor (#batch, time', odim),
|
159 |
+
where time' = time // 2.
|
160 |
+
torch.Tensor: Subsampled mask (#batch, 1, time'),
|
161 |
+
where time' = time // 2.
|
162 |
+
torch.Tensor: positional encoding
|
163 |
+
|
164 |
+
"""
|
165 |
+
time = x.size(1)
|
166 |
+
x = x.transpose(1, 2) # (b, f, t)
|
167 |
+
x = self.conv(x)
|
168 |
+
x = x.transpose(1, 2) # (b, t, f)
|
169 |
+
x, pos_emb = self.pos_enc(x, offset)
|
170 |
+
return x, pos_emb, x_mask[:, :, (time + 1) % 2::2]
|
171 |
+
|
172 |
+
|
173 |
+
class Conv2dSubsampling4(BaseSubsampling):
|
174 |
+
"""Convolutional 2D subsampling (to 1/4 length).
|
175 |
+
|
176 |
+
Args:
|
177 |
+
idim (int): Input dimension.
|
178 |
+
odim (int): Output dimension.
|
179 |
+
dropout_rate (float): Dropout rate.
|
180 |
+
|
181 |
+
"""
|
182 |
+
|
183 |
+
def __init__(self, idim: int, odim: int, dropout_rate: float,
|
184 |
+
pos_enc_class: torch.nn.Module):
|
185 |
+
"""Construct an Conv2dSubsampling4 object."""
|
186 |
+
super().__init__()
|
187 |
+
self.conv = torch.nn.Sequential(
|
188 |
+
torch.nn.Conv2d(1, odim, 3, 2),
|
189 |
+
torch.nn.ReLU(),
|
190 |
+
torch.nn.Conv2d(odim, odim, 3, 2),
|
191 |
+
torch.nn.ReLU(),
|
192 |
+
)
|
193 |
+
self.out = torch.nn.Sequential(
|
194 |
+
torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim))
|
195 |
+
self.pos_enc = pos_enc_class
|
196 |
+
# The right context for every conv layer is computed by:
|
197 |
+
# (kernel_size - 1) * frame_rate_of_this_layer
|
198 |
+
self.subsampling_rate = 4
|
199 |
+
# 6 = (3 - 1) * 1 + (3 - 1) * 2
|
200 |
+
self.right_context = 6
|
201 |
+
|
202 |
+
def forward(
|
203 |
+
self,
|
204 |
+
x: torch.Tensor,
|
205 |
+
x_mask: torch.Tensor,
|
206 |
+
offset: Union[int, torch.Tensor] = 0
|
207 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
208 |
+
"""Subsample x.
|
209 |
+
|
210 |
+
Args:
|
211 |
+
x (torch.Tensor): Input tensor (#batch, time, idim).
|
212 |
+
x_mask (torch.Tensor): Input mask (#batch, 1, time).
|
213 |
+
|
214 |
+
Returns:
|
215 |
+
torch.Tensor: Subsampled tensor (#batch, time', odim),
|
216 |
+
where time' = time // 4.
|
217 |
+
torch.Tensor: Subsampled mask (#batch, 1, time'),
|
218 |
+
where time' = time // 4.
|
219 |
+
torch.Tensor: positional encoding
|
220 |
+
|
221 |
+
"""
|
222 |
+
x = x.unsqueeze(1) # (b, c=1, t, f)
|
223 |
+
x = self.conv(x)
|
224 |
+
b, c, t, f = x.size()
|
225 |
+
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
226 |
+
x, pos_emb = self.pos_enc(x, offset)
|
227 |
+
return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2]
|
228 |
+
|
229 |
+
|
230 |
+
class Conv2dSubsampling6(BaseSubsampling):
|
231 |
+
"""Convolutional 2D subsampling (to 1/6 length).
|
232 |
+
Args:
|
233 |
+
idim (int): Input dimension.
|
234 |
+
odim (int): Output dimension.
|
235 |
+
dropout_rate (float): Dropout rate.
|
236 |
+
pos_enc (torch.nn.Module): Custom position encoding layer.
|
237 |
+
"""
|
238 |
+
|
239 |
+
def __init__(self, idim: int, odim: int, dropout_rate: float,
|
240 |
+
pos_enc_class: torch.nn.Module):
|
241 |
+
"""Construct an Conv2dSubsampling6 object."""
|
242 |
+
super().__init__()
|
243 |
+
self.conv = torch.nn.Sequential(
|
244 |
+
torch.nn.Conv2d(1, odim, 3, 2),
|
245 |
+
torch.nn.ReLU(),
|
246 |
+
torch.nn.Conv2d(odim, odim, 5, 3),
|
247 |
+
torch.nn.ReLU(),
|
248 |
+
)
|
249 |
+
self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3),
|
250 |
+
odim)
|
251 |
+
self.pos_enc = pos_enc_class
|
252 |
+
# 10 = (3 - 1) * 1 + (5 - 1) * 2
|
253 |
+
self.subsampling_rate = 6
|
254 |
+
self.right_context = 10
|
255 |
+
|
256 |
+
def forward(
|
257 |
+
self,
|
258 |
+
x: torch.Tensor,
|
259 |
+
x_mask: torch.Tensor,
|
260 |
+
offset: Union[int, torch.Tensor] = 0
|
261 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
262 |
+
"""Subsample x.
|
263 |
+
Args:
|
264 |
+
x (torch.Tensor): Input tensor (#batch, time, idim).
|
265 |
+
x_mask (torch.Tensor): Input mask (#batch, 1, time).
|
266 |
+
|
267 |
+
Returns:
|
268 |
+
torch.Tensor: Subsampled tensor (#batch, time', odim),
|
269 |
+
where time' = time // 6.
|
270 |
+
torch.Tensor: Subsampled mask (#batch, 1, time'),
|
271 |
+
where time' = time // 6.
|
272 |
+
torch.Tensor: positional encoding
|
273 |
+
"""
|
274 |
+
x = x.unsqueeze(1) # (b, c, t, f)
|
275 |
+
x = self.conv(x)
|
276 |
+
b, c, t, f = x.size()
|
277 |
+
x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
278 |
+
x, pos_emb = self.pos_enc(x, offset)
|
279 |
+
return x, pos_emb, x_mask[:, :, 2::2][:, :, 4::3]
|
280 |
+
|
281 |
+
|
282 |
+
class Conv2dSubsampling8(BaseSubsampling):
|
283 |
+
"""Convolutional 2D subsampling (to 1/8 length).
|
284 |
+
|
285 |
+
Args:
|
286 |
+
idim (int): Input dimension.
|
287 |
+
odim (int): Output dimension.
|
288 |
+
dropout_rate (float): Dropout rate.
|
289 |
+
|
290 |
+
"""
|
291 |
+
|
292 |
+
def __init__(self, idim: int, odim: int, dropout_rate: float,
|
293 |
+
pos_enc_class: torch.nn.Module):
|
294 |
+
"""Construct an Conv2dSubsampling8 object."""
|
295 |
+
super().__init__()
|
296 |
+
self.conv = torch.nn.Sequential(
|
297 |
+
torch.nn.Conv2d(1, odim, 3, 2),
|
298 |
+
torch.nn.ReLU(),
|
299 |
+
torch.nn.Conv2d(odim, odim, 3, 2),
|
300 |
+
torch.nn.ReLU(),
|
301 |
+
torch.nn.Conv2d(odim, odim, 3, 2),
|
302 |
+
torch.nn.ReLU(),
|
303 |
+
)
|
304 |
+
self.linear = torch.nn.Linear(
|
305 |
+
odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim)
|
306 |
+
self.pos_enc = pos_enc_class
|
307 |
+
self.subsampling_rate = 8
|
308 |
+
# 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4
|
309 |
+
self.right_context = 14
|
310 |
+
|
311 |
+
def forward(
|
312 |
+
self,
|
313 |
+
x: torch.Tensor,
|
314 |
+
x_mask: torch.Tensor,
|
315 |
+
offset: Union[int, torch.Tensor] = 0
|
316 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
317 |
+
"""Subsample x.
|
318 |
+
|
319 |
+
Args:
|
320 |
+
x (torch.Tensor): Input tensor (#batch, time, idim).
|
321 |
+
x_mask (torch.Tensor): Input mask (#batch, 1, time).
|
322 |
+
|
323 |
+
Returns:
|
324 |
+
torch.Tensor: Subsampled tensor (#batch, time', odim),
|
325 |
+
where time' = time // 8.
|
326 |
+
torch.Tensor: Subsampled mask (#batch, 1, time'),
|
327 |
+
where time' = time // 8.
|
328 |
+
torch.Tensor: positional encoding
|
329 |
+
"""
|
330 |
+
x = x.unsqueeze(1) # (b, c, t, f)
|
331 |
+
x = self.conv(x)
|
332 |
+
b, c, t, f = x.size()
|
333 |
+
x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
334 |
+
x, pos_emb = self.pos_enc(x, offset)
|
335 |
+
return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2][:, :, 2::2]
|
336 |
+
|
337 |
+
|
338 |
+
class LegacyLinearNoSubsampling(BaseSubsampling):
|
339 |
+
"""Linear transform the input without subsampling
|
340 |
+
|
341 |
+
Args:
|
342 |
+
idim (int): Input dimension.
|
343 |
+
odim (int): Output dimension.
|
344 |
+
dropout_rate (float): Dropout rate.
|
345 |
+
|
346 |
+
"""
|
347 |
+
|
348 |
+
def __init__(self, idim: int, odim: int, dropout_rate: float,
|
349 |
+
pos_enc_class: torch.nn.Module):
|
350 |
+
"""Construct an linear object."""
|
351 |
+
super().__init__()
|
352 |
+
self.out = torch.nn.Sequential(
|
353 |
+
torch.nn.Linear(idim, odim),
|
354 |
+
torch.nn.LayerNorm(odim, eps=1e-5),
|
355 |
+
torch.nn.Dropout(dropout_rate),
|
356 |
+
torch.nn.ReLU(),
|
357 |
+
)
|
358 |
+
self.pos_enc = pos_enc_class
|
359 |
+
self.right_context = 0
|
360 |
+
self.subsampling_rate = 1
|
361 |
+
|
362 |
+
def forward(
|
363 |
+
self,
|
364 |
+
x: torch.Tensor,
|
365 |
+
x_mask: torch.Tensor,
|
366 |
+
offset: Union[int, torch.Tensor] = 0
|
367 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
368 |
+
"""Input x.
|
369 |
+
|
370 |
+
Args:
|
371 |
+
x (torch.Tensor): Input tensor (#batch, time, idim).
|
372 |
+
x_mask (torch.Tensor): Input mask (#batch, 1, time).
|
373 |
+
|
374 |
+
Returns:
|
375 |
+
torch.Tensor: linear input tensor (#batch, time', odim),
|
376 |
+
where time' = time .
|
377 |
+
torch.Tensor: linear input mask (#batch, 1, time'),
|
378 |
+
where time' = time .
|
379 |
+
|
380 |
+
"""
|
381 |
+
x = self.out(x)
|
382 |
+
x, pos_emb = self.pos_enc(x, offset)
|
383 |
+
return x, pos_emb, x_mask
|
src/chatterbox/models/s3gen/transformer/upsample_encoder.py
ADDED
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
|
2 |
+
# 2022 Xingchen Song ([email protected])
|
3 |
+
# 2024 Alibaba Inc (Xiang Lyu)
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
# Modified from ESPnet(https://github.com/espnet/espnet)
|
17 |
+
"""Encoder definition."""
|
18 |
+
from typing import Tuple
|
19 |
+
|
20 |
+
import torch
|
21 |
+
from torch import nn
|
22 |
+
from torch.nn import functional as F
|
23 |
+
|
24 |
+
from .convolution import ConvolutionModule
|
25 |
+
from .encoder_layer import ConformerEncoderLayer
|
26 |
+
from .positionwise_feed_forward import PositionwiseFeedForward
|
27 |
+
from ..utils.class_utils import (
|
28 |
+
COSYVOICE_EMB_CLASSES,
|
29 |
+
COSYVOICE_SUBSAMPLE_CLASSES,
|
30 |
+
COSYVOICE_ATTENTION_CLASSES,
|
31 |
+
COSYVOICE_ACTIVATION_CLASSES,
|
32 |
+
)
|
33 |
+
from ..utils.mask import make_pad_mask
|
34 |
+
from ..utils.mask import add_optional_chunk_mask
|
35 |
+
|
36 |
+
|
37 |
+
class Upsample1D(nn.Module):
|
38 |
+
"""A 1D upsampling layer with an optional convolution.
|
39 |
+
|
40 |
+
Parameters:
|
41 |
+
channels (`int`):
|
42 |
+
number of channels in the inputs and outputs.
|
43 |
+
use_conv (`bool`, default `False`):
|
44 |
+
option to use a convolution.
|
45 |
+
use_conv_transpose (`bool`, default `False`):
|
46 |
+
option to use a convolution transpose.
|
47 |
+
out_channels (`int`, optional):
|
48 |
+
number of output channels. Defaults to `channels`.
|
49 |
+
"""
|
50 |
+
|
51 |
+
def __init__(self, channels: int, out_channels: int, stride: int = 2):
|
52 |
+
super().__init__()
|
53 |
+
self.channels = channels
|
54 |
+
self.out_channels = out_channels
|
55 |
+
self.stride = stride
|
56 |
+
# In this mode, first repeat interpolate, than conv with stride=1
|
57 |
+
self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0)
|
58 |
+
|
59 |
+
def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor):
|
60 |
+
outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest")
|
61 |
+
outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
|
62 |
+
outputs = self.conv(outputs)
|
63 |
+
return outputs, input_lengths * self.stride
|
64 |
+
|
65 |
+
|
66 |
+
class PreLookaheadLayer(nn.Module):
|
67 |
+
def __init__(self, channels: int, pre_lookahead_len: int = 1):
|
68 |
+
super().__init__()
|
69 |
+
self.channels = channels
|
70 |
+
self.pre_lookahead_len = pre_lookahead_len
|
71 |
+
self.conv1 = nn.Conv1d(
|
72 |
+
channels, channels,
|
73 |
+
kernel_size=pre_lookahead_len + 1,
|
74 |
+
stride=1, padding=0,
|
75 |
+
)
|
76 |
+
self.conv2 = nn.Conv1d(
|
77 |
+
channels, channels,
|
78 |
+
kernel_size=3, stride=1, padding=0,
|
79 |
+
)
|
80 |
+
|
81 |
+
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
82 |
+
"""
|
83 |
+
inputs: (batch_size, seq_len, channels)
|
84 |
+
"""
|
85 |
+
outputs = inputs.transpose(1, 2).contiguous()
|
86 |
+
# look ahead
|
87 |
+
outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0)
|
88 |
+
outputs = F.leaky_relu(self.conv1(outputs))
|
89 |
+
# outputs
|
90 |
+
outputs = F.pad(outputs, (2, 0), mode='constant', value=0.0)
|
91 |
+
outputs = self.conv2(outputs)
|
92 |
+
outputs = outputs.transpose(1, 2).contiguous()
|
93 |
+
|
94 |
+
# residual connection
|
95 |
+
outputs = outputs + inputs
|
96 |
+
return outputs
|
97 |
+
|
98 |
+
|
99 |
+
class UpsampleConformerEncoder(torch.nn.Module):
|
100 |
+
|
101 |
+
def __init__(
|
102 |
+
self,
|
103 |
+
input_size: int = 512,
|
104 |
+
output_size: int = 512,
|
105 |
+
attention_heads: int = 8,
|
106 |
+
linear_units: int = 2048,
|
107 |
+
num_blocks: int = 6,
|
108 |
+
dropout_rate: float = 0.1,
|
109 |
+
positional_dropout_rate: float = 0.1,
|
110 |
+
attention_dropout_rate: float = 0.1,
|
111 |
+
input_layer: str = "linear",
|
112 |
+
pos_enc_layer_type: str = "rel_pos_espnet",
|
113 |
+
normalize_before: bool = True,
|
114 |
+
static_chunk_size: int = 0,
|
115 |
+
use_dynamic_chunk: bool = False,
|
116 |
+
global_cmvn: torch.nn.Module = None,
|
117 |
+
use_dynamic_left_chunk: bool = False,
|
118 |
+
positionwise_conv_kernel_size: int = 1,
|
119 |
+
macaron_style: bool = False,
|
120 |
+
selfattention_layer_type: str = "rel_selfattn",
|
121 |
+
activation_type: str = "swish",
|
122 |
+
use_cnn_module: bool = False,
|
123 |
+
cnn_module_kernel: int = 15,
|
124 |
+
causal: bool = False,
|
125 |
+
cnn_module_norm: str = "batch_norm",
|
126 |
+
key_bias: bool = True,
|
127 |
+
gradient_checkpointing: bool = False,
|
128 |
+
):
|
129 |
+
"""
|
130 |
+
Args:
|
131 |
+
input_size (int): input dim
|
132 |
+
output_size (int): dimension of attention
|
133 |
+
attention_heads (int): the number of heads of multi head attention
|
134 |
+
linear_units (int): the hidden units number of position-wise feed
|
135 |
+
forward
|
136 |
+
num_blocks (int): the number of decoder blocks
|
137 |
+
dropout_rate (float): dropout rate
|
138 |
+
attention_dropout_rate (float): dropout rate in attention
|
139 |
+
positional_dropout_rate (float): dropout rate after adding
|
140 |
+
positional encoding
|
141 |
+
input_layer (str): input layer type.
|
142 |
+
optional [linear, conv2d, conv2d6, conv2d8]
|
143 |
+
pos_enc_layer_type (str): Encoder positional encoding layer type.
|
144 |
+
opitonal [abs_pos, scaled_abs_pos, rel_pos, no_pos]
|
145 |
+
normalize_before (bool):
|
146 |
+
True: use layer_norm before each sub-block of a layer.
|
147 |
+
False: use layer_norm after each sub-block of a layer.
|
148 |
+
static_chunk_size (int): chunk size for static chunk training and
|
149 |
+
decoding
|
150 |
+
use_dynamic_chunk (bool): whether use dynamic chunk size for
|
151 |
+
training or not, You can only use fixed chunk(chunk_size > 0)
|
152 |
+
or dyanmic chunk size(use_dynamic_chunk = True)
|
153 |
+
global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
|
154 |
+
use_dynamic_left_chunk (bool): whether use dynamic left chunk in
|
155 |
+
dynamic chunk training
|
156 |
+
key_bias: whether use bias in attention.linear_k, False for whisper models.
|
157 |
+
gradient_checkpointing: rerunning a forward-pass segment for each
|
158 |
+
checkpointed segment during backward.
|
159 |
+
"""
|
160 |
+
super().__init__()
|
161 |
+
self._output_size = output_size
|
162 |
+
|
163 |
+
self.global_cmvn = global_cmvn
|
164 |
+
self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
|
165 |
+
input_size,
|
166 |
+
output_size,
|
167 |
+
dropout_rate,
|
168 |
+
COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
|
169 |
+
positional_dropout_rate),
|
170 |
+
)
|
171 |
+
|
172 |
+
self.normalize_before = normalize_before
|
173 |
+
self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
|
174 |
+
self.static_chunk_size = static_chunk_size
|
175 |
+
self.use_dynamic_chunk = use_dynamic_chunk
|
176 |
+
self.use_dynamic_left_chunk = use_dynamic_left_chunk
|
177 |
+
self.gradient_checkpointing = gradient_checkpointing
|
178 |
+
activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
|
179 |
+
# self-attention module definition
|
180 |
+
encoder_selfattn_layer_args = (
|
181 |
+
attention_heads,
|
182 |
+
output_size,
|
183 |
+
attention_dropout_rate,
|
184 |
+
key_bias,
|
185 |
+
)
|
186 |
+
# feed-forward module definition
|
187 |
+
positionwise_layer_args = (
|
188 |
+
output_size,
|
189 |
+
linear_units,
|
190 |
+
dropout_rate,
|
191 |
+
activation,
|
192 |
+
)
|
193 |
+
# convolution module definition
|
194 |
+
convolution_layer_args = (output_size, cnn_module_kernel, activation,
|
195 |
+
cnn_module_norm, causal)
|
196 |
+
self.pre_lookahead_layer = PreLookaheadLayer(channels=512, pre_lookahead_len=3)
|
197 |
+
self.encoders = torch.nn.ModuleList([
|
198 |
+
ConformerEncoderLayer(
|
199 |
+
output_size,
|
200 |
+
COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
|
201 |
+
*encoder_selfattn_layer_args),
|
202 |
+
PositionwiseFeedForward(*positionwise_layer_args),
|
203 |
+
PositionwiseFeedForward(
|
204 |
+
*positionwise_layer_args) if macaron_style else None,
|
205 |
+
ConvolutionModule(
|
206 |
+
*convolution_layer_args) if use_cnn_module else None,
|
207 |
+
dropout_rate,
|
208 |
+
normalize_before,
|
209 |
+
) for _ in range(num_blocks)
|
210 |
+
])
|
211 |
+
self.up_layer = Upsample1D(channels=512, out_channels=512, stride=2)
|
212 |
+
self.up_embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
|
213 |
+
input_size,
|
214 |
+
output_size,
|
215 |
+
dropout_rate,
|
216 |
+
COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
|
217 |
+
positional_dropout_rate),
|
218 |
+
)
|
219 |
+
self.up_encoders = torch.nn.ModuleList([
|
220 |
+
ConformerEncoderLayer(
|
221 |
+
output_size,
|
222 |
+
COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
|
223 |
+
*encoder_selfattn_layer_args),
|
224 |
+
PositionwiseFeedForward(*positionwise_layer_args),
|
225 |
+
PositionwiseFeedForward(
|
226 |
+
*positionwise_layer_args) if macaron_style else None,
|
227 |
+
ConvolutionModule(
|
228 |
+
*convolution_layer_args) if use_cnn_module else None,
|
229 |
+
dropout_rate,
|
230 |
+
normalize_before,
|
231 |
+
) for _ in range(4)
|
232 |
+
])
|
233 |
+
|
234 |
+
def output_size(self) -> int:
|
235 |
+
return self._output_size
|
236 |
+
|
237 |
+
def forward(
|
238 |
+
self,
|
239 |
+
xs: torch.Tensor,
|
240 |
+
xs_lens: torch.Tensor,
|
241 |
+
decoding_chunk_size: int = 0,
|
242 |
+
num_decoding_left_chunks: int = -1,
|
243 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
244 |
+
"""Embed positions in tensor.
|
245 |
+
|
246 |
+
Args:
|
247 |
+
xs: padded input tensor (B, T, D)
|
248 |
+
xs_lens: input length (B)
|
249 |
+
decoding_chunk_size: decoding chunk size for dynamic chunk
|
250 |
+
0: default for training, use random dynamic chunk.
|
251 |
+
<0: for decoding, use full chunk.
|
252 |
+
>0: for decoding, use fixed chunk size as set.
|
253 |
+
num_decoding_left_chunks: number of left chunks, this is for decoding,
|
254 |
+
the chunk size is decoding_chunk_size.
|
255 |
+
>=0: use num_decoding_left_chunks
|
256 |
+
<0: use all left chunks
|
257 |
+
Returns:
|
258 |
+
encoder output tensor xs, and subsampled masks
|
259 |
+
xs: padded output tensor (B, T' ~= T/subsample_rate, D)
|
260 |
+
masks: torch.Tensor batch padding mask after subsample
|
261 |
+
(B, 1, T' ~= T/subsample_rate)
|
262 |
+
NOTE(xcsong):
|
263 |
+
We pass the `__call__` method of the modules instead of `forward` to the
|
264 |
+
checkpointing API because `__call__` attaches all the hooks of the module.
|
265 |
+
https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
|
266 |
+
"""
|
267 |
+
T = xs.size(1)
|
268 |
+
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
|
269 |
+
if self.global_cmvn is not None:
|
270 |
+
xs = self.global_cmvn(xs)
|
271 |
+
xs, pos_emb, masks = self.embed(xs, masks)
|
272 |
+
mask_pad = masks # (B, 1, T/subsample_rate)
|
273 |
+
chunk_masks = add_optional_chunk_mask(xs, masks,
|
274 |
+
self.use_dynamic_chunk,
|
275 |
+
self.use_dynamic_left_chunk,
|
276 |
+
decoding_chunk_size,
|
277 |
+
self.static_chunk_size,
|
278 |
+
num_decoding_left_chunks)
|
279 |
+
# lookahead + conformer encoder
|
280 |
+
xs = self.pre_lookahead_layer(xs)
|
281 |
+
xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
|
282 |
+
|
283 |
+
# upsample + conformer encoder
|
284 |
+
xs = xs.transpose(1, 2).contiguous()
|
285 |
+
xs, xs_lens = self.up_layer(xs, xs_lens)
|
286 |
+
xs = xs.transpose(1, 2).contiguous()
|
287 |
+
T = xs.size(1)
|
288 |
+
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
|
289 |
+
xs, pos_emb, masks = self.up_embed(xs, masks)
|
290 |
+
mask_pad = masks # (B, 1, T/subsample_rate)
|
291 |
+
chunk_masks = add_optional_chunk_mask(xs, masks,
|
292 |
+
self.use_dynamic_chunk,
|
293 |
+
self.use_dynamic_left_chunk,
|
294 |
+
decoding_chunk_size,
|
295 |
+
self.static_chunk_size * self.up_layer.stride,
|
296 |
+
num_decoding_left_chunks)
|
297 |
+
xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad)
|
298 |
+
|
299 |
+
if self.normalize_before:
|
300 |
+
xs = self.after_norm(xs)
|
301 |
+
# Here we assume the mask is not changed in encoder layers, so just
|
302 |
+
# return the masks before encoder layers, and the masks will be used
|
303 |
+
# for cross attention with decoder later
|
304 |
+
return xs, masks
|
305 |
+
|
306 |
+
def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
|
307 |
+
pos_emb: torch.Tensor,
|
308 |
+
mask_pad: torch.Tensor) -> torch.Tensor:
|
309 |
+
for layer in self.encoders:
|
310 |
+
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
|
311 |
+
return xs
|
312 |
+
|
313 |
+
def forward_up_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
|
314 |
+
pos_emb: torch.Tensor,
|
315 |
+
mask_pad: torch.Tensor) -> torch.Tensor:
|
316 |
+
for layer in self.up_encoders:
|
317 |
+
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
|
318 |
+
return xs
|
src/chatterbox/models/s3gen/utils/class_utils.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright [2023-11-28] <[email protected], Xingchen Song>
|
2 |
+
# 2024 Alibaba Inc (authors: Xiang Lyu)
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
import torch
|
16 |
+
|
17 |
+
from ..transformer.activation import Swish
|
18 |
+
from ..transformer.subsampling import (
|
19 |
+
LinearNoSubsampling,
|
20 |
+
EmbedinigNoSubsampling,
|
21 |
+
Conv1dSubsampling2,
|
22 |
+
Conv2dSubsampling4,
|
23 |
+
Conv2dSubsampling6,
|
24 |
+
Conv2dSubsampling8,
|
25 |
+
)
|
26 |
+
from ..transformer.embedding import (
|
27 |
+
PositionalEncoding,
|
28 |
+
RelPositionalEncoding,
|
29 |
+
WhisperPositionalEncoding,
|
30 |
+
LearnablePositionalEncoding,
|
31 |
+
NoPositionalEncoding)
|
32 |
+
from ..transformer.attention import (MultiHeadedAttention,
|
33 |
+
RelPositionMultiHeadedAttention)
|
34 |
+
from ..transformer.embedding import EspnetRelPositionalEncoding
|
35 |
+
from ..transformer.subsampling import LegacyLinearNoSubsampling
|
36 |
+
|
37 |
+
|
38 |
+
COSYVOICE_ACTIVATION_CLASSES = {
|
39 |
+
"hardtanh": torch.nn.Hardtanh,
|
40 |
+
"tanh": torch.nn.Tanh,
|
41 |
+
"relu": torch.nn.ReLU,
|
42 |
+
"selu": torch.nn.SELU,
|
43 |
+
"swish": getattr(torch.nn, "SiLU", Swish),
|
44 |
+
"gelu": torch.nn.GELU,
|
45 |
+
}
|
46 |
+
|
47 |
+
COSYVOICE_SUBSAMPLE_CLASSES = {
|
48 |
+
"linear": LinearNoSubsampling,
|
49 |
+
"linear_legacy": LegacyLinearNoSubsampling,
|
50 |
+
"embed": EmbedinigNoSubsampling,
|
51 |
+
"conv1d2": Conv1dSubsampling2,
|
52 |
+
"conv2d": Conv2dSubsampling4,
|
53 |
+
"conv2d6": Conv2dSubsampling6,
|
54 |
+
"conv2d8": Conv2dSubsampling8,
|
55 |
+
'paraformer_dummy': torch.nn.Identity
|
56 |
+
}
|
57 |
+
|
58 |
+
COSYVOICE_EMB_CLASSES = {
|
59 |
+
"embed": PositionalEncoding,
|
60 |
+
"abs_pos": PositionalEncoding,
|
61 |
+
"rel_pos": RelPositionalEncoding,
|
62 |
+
"rel_pos_espnet": EspnetRelPositionalEncoding,
|
63 |
+
"no_pos": NoPositionalEncoding,
|
64 |
+
"abs_pos_whisper": WhisperPositionalEncoding,
|
65 |
+
"embed_learnable_pe": LearnablePositionalEncoding,
|
66 |
+
}
|
67 |
+
|
68 |
+
COSYVOICE_ATTENTION_CLASSES = {
|
69 |
+
"selfattn": MultiHeadedAttention,
|
70 |
+
"rel_selfattn": RelPositionMultiHeadedAttention,
|
71 |
+
}
|
src/chatterbox/models/s3gen/utils/mask.py
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2019 Shigeki Karita
|
2 |
+
# 2020 Mobvoi Inc (Binbin Zhang)
|
3 |
+
# 2024 Alibaba Inc (authors: Xiang Lyu)
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
import torch
|
18 |
+
|
19 |
+
'''
|
20 |
+
def subsequent_mask(
|
21 |
+
size: int,
|
22 |
+
device: torch.device = torch.device("cpu"),
|
23 |
+
) -> torch.Tensor:
|
24 |
+
"""Create mask for subsequent steps (size, size).
|
25 |
+
|
26 |
+
This mask is used only in decoder which works in an auto-regressive mode.
|
27 |
+
This means the current step could only do attention with its left steps.
|
28 |
+
|
29 |
+
In encoder, fully attention is used when streaming is not necessary and
|
30 |
+
the sequence is not long. In this case, no attention mask is needed.
|
31 |
+
|
32 |
+
When streaming is need, chunk-based attention is used in encoder. See
|
33 |
+
subsequent_chunk_mask for the chunk-based attention mask.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
size (int): size of mask
|
37 |
+
str device (str): "cpu" or "cuda" or torch.Tensor.device
|
38 |
+
dtype (torch.device): result dtype
|
39 |
+
|
40 |
+
Returns:
|
41 |
+
torch.Tensor: mask
|
42 |
+
|
43 |
+
Examples:
|
44 |
+
>>> subsequent_mask(3)
|
45 |
+
[[1, 0, 0],
|
46 |
+
[1, 1, 0],
|
47 |
+
[1, 1, 1]]
|
48 |
+
"""
|
49 |
+
ret = torch.ones(size, size, device=device, dtype=torch.bool)
|
50 |
+
return torch.tril(ret)
|
51 |
+
'''
|
52 |
+
|
53 |
+
|
54 |
+
def subsequent_chunk_mask(
|
55 |
+
size: int,
|
56 |
+
chunk_size: int,
|
57 |
+
num_left_chunks: int = -1,
|
58 |
+
device: torch.device = torch.device("cpu"),
|
59 |
+
) -> torch.Tensor:
|
60 |
+
"""Create mask for subsequent steps (size, size) with chunk size,
|
61 |
+
this is for streaming encoder
|
62 |
+
|
63 |
+
Args:
|
64 |
+
size (int): size of mask
|
65 |
+
chunk_size (int): size of chunk
|
66 |
+
num_left_chunks (int): number of left chunks
|
67 |
+
<0: use full chunk
|
68 |
+
>=0: use num_left_chunks
|
69 |
+
device (torch.device): "cpu" or "cuda" or torch.Tensor.device
|
70 |
+
|
71 |
+
Returns:
|
72 |
+
torch.Tensor: mask
|
73 |
+
|
74 |
+
Examples:
|
75 |
+
>>> subsequent_chunk_mask(4, 2)
|
76 |
+
[[1, 1, 0, 0],
|
77 |
+
[1, 1, 0, 0],
|
78 |
+
[1, 1, 1, 1],
|
79 |
+
[1, 1, 1, 1]]
|
80 |
+
"""
|
81 |
+
# NOTE this modified implementation meets onnx export requirements, but it doesn't support num_left_chunks
|
82 |
+
# actually this is not needed after we have inference cache implemented, will remove it later
|
83 |
+
pos_idx = torch.arange(size, device=device)
|
84 |
+
block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size
|
85 |
+
ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1)
|
86 |
+
return ret
|
87 |
+
|
88 |
+
|
89 |
+
def add_optional_chunk_mask(xs: torch.Tensor,
|
90 |
+
masks: torch.Tensor,
|
91 |
+
use_dynamic_chunk: bool,
|
92 |
+
use_dynamic_left_chunk: bool,
|
93 |
+
decoding_chunk_size: int,
|
94 |
+
static_chunk_size: int,
|
95 |
+
num_decoding_left_chunks: int,
|
96 |
+
enable_full_context: bool = True):
|
97 |
+
""" Apply optional mask for encoder.
|
98 |
+
|
99 |
+
Args:
|
100 |
+
xs (torch.Tensor): padded input, (B, L, D), L for max length
|
101 |
+
mask (torch.Tensor): mask for xs, (B, 1, L)
|
102 |
+
use_dynamic_chunk (bool): whether to use dynamic chunk or not
|
103 |
+
use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
|
104 |
+
training.
|
105 |
+
decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
|
106 |
+
0: default for training, use random dynamic chunk.
|
107 |
+
<0: for decoding, use full chunk.
|
108 |
+
>0: for decoding, use fixed chunk size as set.
|
109 |
+
static_chunk_size (int): chunk size for static chunk training/decoding
|
110 |
+
if it's greater than 0, if use_dynamic_chunk is true,
|
111 |
+
this parameter will be ignored
|
112 |
+
num_decoding_left_chunks: number of left chunks, this is for decoding,
|
113 |
+
the chunk size is decoding_chunk_size.
|
114 |
+
>=0: use num_decoding_left_chunks
|
115 |
+
<0: use all left chunks
|
116 |
+
enable_full_context (bool):
|
117 |
+
True: chunk size is either [1, 25] or full context(max_len)
|
118 |
+
False: chunk size ~ U[1, 25]
|
119 |
+
|
120 |
+
Returns:
|
121 |
+
torch.Tensor: chunk mask of the input xs.
|
122 |
+
"""
|
123 |
+
# Whether to use chunk mask or not
|
124 |
+
if use_dynamic_chunk:
|
125 |
+
max_len = xs.size(1)
|
126 |
+
if decoding_chunk_size < 0:
|
127 |
+
chunk_size = max_len
|
128 |
+
num_left_chunks = -1
|
129 |
+
elif decoding_chunk_size > 0:
|
130 |
+
chunk_size = decoding_chunk_size
|
131 |
+
num_left_chunks = num_decoding_left_chunks
|
132 |
+
else:
|
133 |
+
# chunk size is either [1, 25] or full context(max_len).
|
134 |
+
# Since we use 4 times subsampling and allow up to 1s(100 frames)
|
135 |
+
# delay, the maximum frame is 100 / 4 = 25.
|
136 |
+
chunk_size = torch.randint(1, max_len, (1, )).item()
|
137 |
+
num_left_chunks = -1
|
138 |
+
if chunk_size > max_len // 2 and enable_full_context:
|
139 |
+
chunk_size = max_len
|
140 |
+
else:
|
141 |
+
chunk_size = chunk_size % 25 + 1
|
142 |
+
if use_dynamic_left_chunk:
|
143 |
+
max_left_chunks = (max_len - 1) // chunk_size
|
144 |
+
num_left_chunks = torch.randint(0, max_left_chunks,
|
145 |
+
(1, )).item()
|
146 |
+
chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
|
147 |
+
num_left_chunks,
|
148 |
+
xs.device) # (L, L)
|
149 |
+
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
|
150 |
+
chunk_masks = masks & chunk_masks # (B, L, L)
|
151 |
+
elif static_chunk_size > 0:
|
152 |
+
num_left_chunks = num_decoding_left_chunks
|
153 |
+
chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
|
154 |
+
num_left_chunks,
|
155 |
+
xs.device) # (L, L)
|
156 |
+
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
|
157 |
+
chunk_masks = masks & chunk_masks # (B, L, L)
|
158 |
+
else:
|
159 |
+
chunk_masks = masks
|
160 |
+
assert chunk_masks.dtype == torch.bool
|
161 |
+
if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0:
|
162 |
+
logging.warning('get chunk_masks all false at some timestep, force set to true, make sure they are masked in futuer computation!')
|
163 |
+
chunk_masks[chunk_masks.sum(dim=-1)==0] = True
|
164 |
+
return chunk_masks
|
165 |
+
|
166 |
+
|
167 |
+
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
|
168 |
+
"""Make mask tensor containing indices of padded part.
|
169 |
+
|
170 |
+
See description of make_non_pad_mask.
|
171 |
+
|
172 |
+
Args:
|
173 |
+
lengths (torch.Tensor): Batch of lengths (B,).
|
174 |
+
Returns:
|
175 |
+
torch.Tensor: Mask tensor containing indices of padded part.
|
176 |
+
|
177 |
+
Examples:
|
178 |
+
>>> lengths = [5, 3, 2]
|
179 |
+
>>> make_pad_mask(lengths)
|
180 |
+
masks = [[0, 0, 0, 0 ,0],
|
181 |
+
[0, 0, 0, 1, 1],
|
182 |
+
[0, 0, 1, 1, 1]]
|
183 |
+
"""
|
184 |
+
batch_size = lengths.size(0)
|
185 |
+
max_len = max_len if max_len > 0 else lengths.max().item()
|
186 |
+
seq_range = torch.arange(0,
|
187 |
+
max_len,
|
188 |
+
dtype=torch.int64,
|
189 |
+
device=lengths.device)
|
190 |
+
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
|
191 |
+
seq_length_expand = lengths.unsqueeze(-1)
|
192 |
+
mask = seq_range_expand >= seq_length_expand
|
193 |
+
return mask
|
src/chatterbox/models/s3gen/utils/mel.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""mel-spectrogram extraction in Matcha-TTS"""
|
2 |
+
import logging
|
3 |
+
from librosa.filters import mel as librosa_mel_fn
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
logger = logging.getLogger(__name__)
|
8 |
+
|
9 |
+
|
10 |
+
# NOTE: they decalred these global vars
|
11 |
+
mel_basis = {}
|
12 |
+
hann_window = {}
|
13 |
+
|
14 |
+
|
15 |
+
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
|
16 |
+
return torch.log(torch.clamp(x, min=clip_val) * C)
|
17 |
+
|
18 |
+
|
19 |
+
def spectral_normalize_torch(magnitudes):
|
20 |
+
output = dynamic_range_compression_torch(magnitudes)
|
21 |
+
return output
|
22 |
+
|
23 |
+
"""
|
24 |
+
feat_extractor: !name:matcha.utils.audio.mel_spectrogram
|
25 |
+
n_fft: 1920
|
26 |
+
num_mels: 80
|
27 |
+
sampling_rate: 24000
|
28 |
+
hop_size: 480
|
29 |
+
win_size: 1920
|
30 |
+
fmin: 0
|
31 |
+
fmax: 8000
|
32 |
+
center: False
|
33 |
+
|
34 |
+
"""
|
35 |
+
|
36 |
+
def mel_spectrogram(y, n_fft=1920, num_mels=80, sampling_rate=24000, hop_size=480, win_size=1920,
|
37 |
+
fmin=0, fmax=8000, center=False):
|
38 |
+
"""Copied from https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/utils/audio.py
|
39 |
+
Set default values according to Cosyvoice's config.
|
40 |
+
"""
|
41 |
+
|
42 |
+
if isinstance(y, np.ndarray):
|
43 |
+
y = torch.tensor(y).float()
|
44 |
+
|
45 |
+
if len(y.shape) == 1:
|
46 |
+
y = y[None, ]
|
47 |
+
|
48 |
+
# Debug: Check for audio clipping (values outside [-1.0, 1.0] range)
|
49 |
+
min_val = torch.min(y)
|
50 |
+
max_val = torch.max(y)
|
51 |
+
if min_val < -1.0 or max_val > 1.0:
|
52 |
+
logger.warning(f"Audio values outside normalized range: min={min_val.item():.4f}, max={max_val.item():.4f}")
|
53 |
+
|
54 |
+
global mel_basis, hann_window # pylint: disable=global-statement,global-variable-not-assigned
|
55 |
+
if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
|
56 |
+
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
|
57 |
+
mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
|
58 |
+
hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
|
59 |
+
|
60 |
+
y = torch.nn.functional.pad(
|
61 |
+
y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
|
62 |
+
)
|
63 |
+
y = y.squeeze(1)
|
64 |
+
|
65 |
+
spec = torch.view_as_real(
|
66 |
+
torch.stft(
|
67 |
+
y,
|
68 |
+
n_fft,
|
69 |
+
hop_length=hop_size,
|
70 |
+
win_length=win_size,
|
71 |
+
window=hann_window[str(y.device)],
|
72 |
+
center=center,
|
73 |
+
pad_mode="reflect",
|
74 |
+
normalized=False,
|
75 |
+
onesided=True,
|
76 |
+
return_complex=True,
|
77 |
+
)
|
78 |
+
)
|
79 |
+
|
80 |
+
spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
|
81 |
+
|
82 |
+
spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
|
83 |
+
spec = spectral_normalize_torch(spec)
|
84 |
+
|
85 |
+
return spec
|
src/chatterbox/models/s3gen/xvector.py
ADDED
@@ -0,0 +1,428 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
4 |
+
# MIT License (https://opensource.org/licenses/MIT)
|
5 |
+
# Modified from 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker)
|
6 |
+
|
7 |
+
|
8 |
+
from collections import OrderedDict
|
9 |
+
import torch
|
10 |
+
import torch.nn.functional as F
|
11 |
+
import torch.utils.checkpoint as cp
|
12 |
+
import torchaudio.compliance.kaldi as Kaldi
|
13 |
+
|
14 |
+
|
15 |
+
def pad_list(xs, pad_value):
|
16 |
+
"""Perform padding for the list of tensors.
|
17 |
+
|
18 |
+
Args:
|
19 |
+
xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
|
20 |
+
pad_value (float): Value for padding.
|
21 |
+
|
22 |
+
Returns:
|
23 |
+
Tensor: Padded tensor (B, Tmax, `*`).
|
24 |
+
|
25 |
+
Examples:
|
26 |
+
>>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
|
27 |
+
>>> x
|
28 |
+
[tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
|
29 |
+
>>> pad_list(x, 0)
|
30 |
+
tensor([[1., 1., 1., 1.],
|
31 |
+
[1., 1., 0., 0.],
|
32 |
+
[1., 0., 0., 0.]])
|
33 |
+
|
34 |
+
"""
|
35 |
+
n_batch = len(xs)
|
36 |
+
max_len = max(x.size(0) for x in xs)
|
37 |
+
pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
|
38 |
+
|
39 |
+
for i in range(n_batch):
|
40 |
+
pad[i, : xs[i].size(0)] = xs[i]
|
41 |
+
|
42 |
+
return pad
|
43 |
+
|
44 |
+
|
45 |
+
def extract_feature(audio):
|
46 |
+
features = []
|
47 |
+
feature_times = []
|
48 |
+
feature_lengths = []
|
49 |
+
for au in audio:
|
50 |
+
feature = Kaldi.fbank(au.unsqueeze(0), num_mel_bins=80)
|
51 |
+
feature = feature - feature.mean(dim=0, keepdim=True)
|
52 |
+
features.append(feature)
|
53 |
+
feature_times.append(au.shape[0])
|
54 |
+
feature_lengths.append(feature.shape[0])
|
55 |
+
# padding for batch inference
|
56 |
+
features_padded = pad_list(features, pad_value=0)
|
57 |
+
# features = torch.cat(features)
|
58 |
+
return features_padded, feature_lengths, feature_times
|
59 |
+
|
60 |
+
|
61 |
+
class BasicResBlock(torch.nn.Module):
|
62 |
+
expansion = 1
|
63 |
+
|
64 |
+
def __init__(self, in_planes, planes, stride=1):
|
65 |
+
super(BasicResBlock, self).__init__()
|
66 |
+
self.conv1 = torch.nn.Conv2d(
|
67 |
+
in_planes, planes, kernel_size=3, stride=(stride, 1), padding=1, bias=False
|
68 |
+
)
|
69 |
+
self.bn1 = torch.nn.BatchNorm2d(planes)
|
70 |
+
self.conv2 = torch.nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
|
71 |
+
self.bn2 = torch.nn.BatchNorm2d(planes)
|
72 |
+
|
73 |
+
self.shortcut = torch.nn.Sequential()
|
74 |
+
if stride != 1 or in_planes != self.expansion * planes:
|
75 |
+
self.shortcut = torch.nn.Sequential(
|
76 |
+
torch.nn.Conv2d(
|
77 |
+
in_planes,
|
78 |
+
self.expansion * planes,
|
79 |
+
kernel_size=1,
|
80 |
+
stride=(stride, 1),
|
81 |
+
bias=False,
|
82 |
+
),
|
83 |
+
torch.nn.BatchNorm2d(self.expansion * planes),
|
84 |
+
)
|
85 |
+
|
86 |
+
def forward(self, x):
|
87 |
+
out = F.relu(self.bn1(self.conv1(x)))
|
88 |
+
out = self.bn2(self.conv2(out))
|
89 |
+
out += self.shortcut(x)
|
90 |
+
out = F.relu(out)
|
91 |
+
return out
|
92 |
+
|
93 |
+
|
94 |
+
class FCM(torch.nn.Module):
|
95 |
+
def __init__(self, block=BasicResBlock, num_blocks=[2, 2], m_channels=32, feat_dim=80):
|
96 |
+
super(FCM, self).__init__()
|
97 |
+
self.in_planes = m_channels
|
98 |
+
self.conv1 = torch.nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
|
99 |
+
self.bn1 = torch.nn.BatchNorm2d(m_channels)
|
100 |
+
|
101 |
+
self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
|
102 |
+
self.layer2 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
|
103 |
+
|
104 |
+
self.conv2 = torch.nn.Conv2d(
|
105 |
+
m_channels, m_channels, kernel_size=3, stride=(2, 1), padding=1, bias=False
|
106 |
+
)
|
107 |
+
self.bn2 = torch.nn.BatchNorm2d(m_channels)
|
108 |
+
self.out_channels = m_channels * (feat_dim // 8)
|
109 |
+
|
110 |
+
def _make_layer(self, block, planes, num_blocks, stride):
|
111 |
+
strides = [stride] + [1] * (num_blocks - 1)
|
112 |
+
layers = []
|
113 |
+
for stride in strides:
|
114 |
+
layers.append(block(self.in_planes, planes, stride))
|
115 |
+
self.in_planes = planes * block.expansion
|
116 |
+
return torch.nn.Sequential(*layers)
|
117 |
+
|
118 |
+
def forward(self, x):
|
119 |
+
x = x.unsqueeze(1)
|
120 |
+
out = F.relu(self.bn1(self.conv1(x)))
|
121 |
+
out = self.layer1(out)
|
122 |
+
out = self.layer2(out)
|
123 |
+
out = F.relu(self.bn2(self.conv2(out)))
|
124 |
+
|
125 |
+
shape = out.shape
|
126 |
+
out = out.reshape(shape[0], shape[1] * shape[2], shape[3])
|
127 |
+
return out
|
128 |
+
|
129 |
+
|
130 |
+
def get_nonlinear(config_str, channels):
|
131 |
+
nonlinear = torch.nn.Sequential()
|
132 |
+
for name in config_str.split("-"):
|
133 |
+
if name == "relu":
|
134 |
+
nonlinear.add_module("relu", torch.nn.ReLU(inplace=True))
|
135 |
+
elif name == "prelu":
|
136 |
+
nonlinear.add_module("prelu", torch.nn.PReLU(channels))
|
137 |
+
elif name == "batchnorm":
|
138 |
+
nonlinear.add_module("batchnorm", torch.nn.BatchNorm1d(channels))
|
139 |
+
elif name == "batchnorm_":
|
140 |
+
nonlinear.add_module("batchnorm", torch.nn.BatchNorm1d(channels, affine=False))
|
141 |
+
else:
|
142 |
+
raise ValueError("Unexpected module ({}).".format(name))
|
143 |
+
return nonlinear
|
144 |
+
|
145 |
+
|
146 |
+
def statistics_pooling(x, dim=-1, keepdim=False, unbiased=True, eps=1e-2):
|
147 |
+
mean = x.mean(dim=dim)
|
148 |
+
std = x.std(dim=dim, unbiased=unbiased)
|
149 |
+
stats = torch.cat([mean, std], dim=-1)
|
150 |
+
if keepdim:
|
151 |
+
stats = stats.unsqueeze(dim=dim)
|
152 |
+
return stats
|
153 |
+
|
154 |
+
|
155 |
+
class StatsPool(torch.nn.Module):
|
156 |
+
def forward(self, x):
|
157 |
+
return statistics_pooling(x)
|
158 |
+
|
159 |
+
|
160 |
+
class TDNNLayer(torch.nn.Module):
|
161 |
+
def __init__(
|
162 |
+
self,
|
163 |
+
in_channels,
|
164 |
+
out_channels,
|
165 |
+
kernel_size,
|
166 |
+
stride=1,
|
167 |
+
padding=0,
|
168 |
+
dilation=1,
|
169 |
+
bias=False,
|
170 |
+
config_str="batchnorm-relu",
|
171 |
+
):
|
172 |
+
super(TDNNLayer, self).__init__()
|
173 |
+
if padding < 0:
|
174 |
+
assert (
|
175 |
+
kernel_size % 2 == 1
|
176 |
+
), "Expect equal paddings, but got even kernel size ({})".format(kernel_size)
|
177 |
+
padding = (kernel_size - 1) // 2 * dilation
|
178 |
+
self.linear = torch.nn.Conv1d(
|
179 |
+
in_channels,
|
180 |
+
out_channels,
|
181 |
+
kernel_size,
|
182 |
+
stride=stride,
|
183 |
+
padding=padding,
|
184 |
+
dilation=dilation,
|
185 |
+
bias=bias,
|
186 |
+
)
|
187 |
+
self.nonlinear = get_nonlinear(config_str, out_channels)
|
188 |
+
|
189 |
+
def forward(self, x):
|
190 |
+
x = self.linear(x)
|
191 |
+
x = self.nonlinear(x)
|
192 |
+
return x
|
193 |
+
|
194 |
+
|
195 |
+
class CAMLayer(torch.nn.Module):
|
196 |
+
def __init__(
|
197 |
+
self, bn_channels, out_channels, kernel_size, stride, padding, dilation, bias, reduction=2
|
198 |
+
):
|
199 |
+
super(CAMLayer, self).__init__()
|
200 |
+
self.linear_local = torch.nn.Conv1d(
|
201 |
+
bn_channels,
|
202 |
+
out_channels,
|
203 |
+
kernel_size,
|
204 |
+
stride=stride,
|
205 |
+
padding=padding,
|
206 |
+
dilation=dilation,
|
207 |
+
bias=bias,
|
208 |
+
)
|
209 |
+
self.linear1 = torch.nn.Conv1d(bn_channels, bn_channels // reduction, 1)
|
210 |
+
self.relu = torch.nn.ReLU(inplace=True)
|
211 |
+
self.linear2 = torch.nn.Conv1d(bn_channels // reduction, out_channels, 1)
|
212 |
+
self.sigmoid = torch.nn.Sigmoid()
|
213 |
+
|
214 |
+
def forward(self, x):
|
215 |
+
y = self.linear_local(x)
|
216 |
+
context = x.mean(-1, keepdim=True) + self.seg_pooling(x)
|
217 |
+
context = self.relu(self.linear1(context))
|
218 |
+
m = self.sigmoid(self.linear2(context))
|
219 |
+
return y * m
|
220 |
+
|
221 |
+
def seg_pooling(self, x, seg_len=100, stype="avg"):
|
222 |
+
if stype == "avg":
|
223 |
+
seg = F.avg_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
|
224 |
+
elif stype == "max":
|
225 |
+
seg = F.max_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
|
226 |
+
else:
|
227 |
+
raise ValueError("Wrong segment pooling type.")
|
228 |
+
shape = seg.shape
|
229 |
+
seg = seg.unsqueeze(-1).expand(*shape, seg_len).reshape(*shape[:-1], -1)
|
230 |
+
seg = seg[..., : x.shape[-1]]
|
231 |
+
return seg
|
232 |
+
|
233 |
+
|
234 |
+
class CAMDenseTDNNLayer(torch.nn.Module):
|
235 |
+
def __init__(
|
236 |
+
self,
|
237 |
+
in_channels,
|
238 |
+
out_channels,
|
239 |
+
bn_channels,
|
240 |
+
kernel_size,
|
241 |
+
stride=1,
|
242 |
+
dilation=1,
|
243 |
+
bias=False,
|
244 |
+
config_str="batchnorm-relu",
|
245 |
+
memory_efficient=False,
|
246 |
+
):
|
247 |
+
super(CAMDenseTDNNLayer, self).__init__()
|
248 |
+
assert kernel_size % 2 == 1, "Expect equal paddings, but got even kernel size ({})".format(
|
249 |
+
kernel_size
|
250 |
+
)
|
251 |
+
padding = (kernel_size - 1) // 2 * dilation
|
252 |
+
self.memory_efficient = memory_efficient
|
253 |
+
self.nonlinear1 = get_nonlinear(config_str, in_channels)
|
254 |
+
self.linear1 = torch.nn.Conv1d(in_channels, bn_channels, 1, bias=False)
|
255 |
+
self.nonlinear2 = get_nonlinear(config_str, bn_channels)
|
256 |
+
self.cam_layer = CAMLayer(
|
257 |
+
bn_channels,
|
258 |
+
out_channels,
|
259 |
+
kernel_size,
|
260 |
+
stride=stride,
|
261 |
+
padding=padding,
|
262 |
+
dilation=dilation,
|
263 |
+
bias=bias,
|
264 |
+
)
|
265 |
+
|
266 |
+
def bn_function(self, x):
|
267 |
+
return self.linear1(self.nonlinear1(x))
|
268 |
+
|
269 |
+
def forward(self, x):
|
270 |
+
if self.training and self.memory_efficient:
|
271 |
+
x = cp.checkpoint(self.bn_function, x)
|
272 |
+
else:
|
273 |
+
x = self.bn_function(x)
|
274 |
+
x = self.cam_layer(self.nonlinear2(x))
|
275 |
+
return x
|
276 |
+
|
277 |
+
|
278 |
+
class CAMDenseTDNNBlock(torch.nn.ModuleList):
|
279 |
+
def __init__(
|
280 |
+
self,
|
281 |
+
num_layers,
|
282 |
+
in_channels,
|
283 |
+
out_channels,
|
284 |
+
bn_channels,
|
285 |
+
kernel_size,
|
286 |
+
stride=1,
|
287 |
+
dilation=1,
|
288 |
+
bias=False,
|
289 |
+
config_str="batchnorm-relu",
|
290 |
+
memory_efficient=False,
|
291 |
+
):
|
292 |
+
super(CAMDenseTDNNBlock, self).__init__()
|
293 |
+
for i in range(num_layers):
|
294 |
+
layer = CAMDenseTDNNLayer(
|
295 |
+
in_channels=in_channels + i * out_channels,
|
296 |
+
out_channels=out_channels,
|
297 |
+
bn_channels=bn_channels,
|
298 |
+
kernel_size=kernel_size,
|
299 |
+
stride=stride,
|
300 |
+
dilation=dilation,
|
301 |
+
bias=bias,
|
302 |
+
config_str=config_str,
|
303 |
+
memory_efficient=memory_efficient,
|
304 |
+
)
|
305 |
+
self.add_module("tdnnd%d" % (i + 1), layer)
|
306 |
+
|
307 |
+
def forward(self, x):
|
308 |
+
for layer in self:
|
309 |
+
x = torch.cat([x, layer(x)], dim=1)
|
310 |
+
return x
|
311 |
+
|
312 |
+
|
313 |
+
class TransitLayer(torch.nn.Module):
|
314 |
+
def __init__(self, in_channels, out_channels, bias=True, config_str="batchnorm-relu"):
|
315 |
+
super(TransitLayer, self).__init__()
|
316 |
+
self.nonlinear = get_nonlinear(config_str, in_channels)
|
317 |
+
self.linear = torch.nn.Conv1d(in_channels, out_channels, 1, bias=bias)
|
318 |
+
|
319 |
+
def forward(self, x):
|
320 |
+
x = self.nonlinear(x)
|
321 |
+
x = self.linear(x)
|
322 |
+
return x
|
323 |
+
|
324 |
+
|
325 |
+
class DenseLayer(torch.nn.Module):
|
326 |
+
def __init__(self, in_channels, out_channels, bias=False, config_str="batchnorm-relu"):
|
327 |
+
super(DenseLayer, self).__init__()
|
328 |
+
self.linear = torch.nn.Conv1d(in_channels, out_channels, 1, bias=bias)
|
329 |
+
self.nonlinear = get_nonlinear(config_str, out_channels)
|
330 |
+
|
331 |
+
def forward(self, x):
|
332 |
+
if len(x.shape) == 2:
|
333 |
+
x = self.linear(x.unsqueeze(dim=-1)).squeeze(dim=-1)
|
334 |
+
else:
|
335 |
+
x = self.linear(x)
|
336 |
+
x = self.nonlinear(x)
|
337 |
+
return x
|
338 |
+
|
339 |
+
# @tables.register("model_classes", "CAMPPlus")
|
340 |
+
class CAMPPlus(torch.nn.Module):
|
341 |
+
def __init__(
|
342 |
+
self,
|
343 |
+
feat_dim=80,
|
344 |
+
embedding_size=192,
|
345 |
+
growth_rate=32,
|
346 |
+
bn_size=4,
|
347 |
+
init_channels=128,
|
348 |
+
config_str="batchnorm-relu",
|
349 |
+
memory_efficient=True,
|
350 |
+
output_level="segment",
|
351 |
+
**kwargs,
|
352 |
+
):
|
353 |
+
super().__init__()
|
354 |
+
|
355 |
+
self.head = FCM(feat_dim=feat_dim)
|
356 |
+
channels = self.head.out_channels
|
357 |
+
self.output_level = output_level
|
358 |
+
|
359 |
+
self.xvector = torch.nn.Sequential(
|
360 |
+
OrderedDict(
|
361 |
+
[
|
362 |
+
(
|
363 |
+
"tdnn",
|
364 |
+
TDNNLayer(
|
365 |
+
channels,
|
366 |
+
init_channels,
|
367 |
+
5,
|
368 |
+
stride=2,
|
369 |
+
dilation=1,
|
370 |
+
padding=-1,
|
371 |
+
config_str=config_str,
|
372 |
+
),
|
373 |
+
),
|
374 |
+
]
|
375 |
+
)
|
376 |
+
)
|
377 |
+
channels = init_channels
|
378 |
+
for i, (num_layers, kernel_size, dilation) in enumerate(
|
379 |
+
zip((12, 24, 16), (3, 3, 3), (1, 2, 2))
|
380 |
+
):
|
381 |
+
block = CAMDenseTDNNBlock(
|
382 |
+
num_layers=num_layers,
|
383 |
+
in_channels=channels,
|
384 |
+
out_channels=growth_rate,
|
385 |
+
bn_channels=bn_size * growth_rate,
|
386 |
+
kernel_size=kernel_size,
|
387 |
+
dilation=dilation,
|
388 |
+
config_str=config_str,
|
389 |
+
memory_efficient=memory_efficient,
|
390 |
+
)
|
391 |
+
self.xvector.add_module("block%d" % (i + 1), block)
|
392 |
+
channels = channels + num_layers * growth_rate
|
393 |
+
self.xvector.add_module(
|
394 |
+
"transit%d" % (i + 1),
|
395 |
+
TransitLayer(channels, channels // 2, bias=False, config_str=config_str),
|
396 |
+
)
|
397 |
+
channels //= 2
|
398 |
+
|
399 |
+
self.xvector.add_module("out_nonlinear", get_nonlinear(config_str, channels))
|
400 |
+
|
401 |
+
if self.output_level == "segment":
|
402 |
+
self.xvector.add_module("stats", StatsPool())
|
403 |
+
self.xvector.add_module(
|
404 |
+
"dense", DenseLayer(channels * 2, embedding_size, config_str="batchnorm_")
|
405 |
+
)
|
406 |
+
else:
|
407 |
+
assert (
|
408 |
+
self.output_level == "frame"
|
409 |
+
), "`output_level` should be set to 'segment' or 'frame'. "
|
410 |
+
|
411 |
+
for m in self.modules():
|
412 |
+
if isinstance(m, (torch.nn.Conv1d, torch.nn.Linear)):
|
413 |
+
torch.nn.init.kaiming_normal_(m.weight.data)
|
414 |
+
if m.bias is not None:
|
415 |
+
torch.nn.init.zeros_(m.bias)
|
416 |
+
|
417 |
+
def forward(self, x):
|
418 |
+
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
|
419 |
+
x = self.head(x)
|
420 |
+
x = self.xvector(x)
|
421 |
+
if self.output_level == "frame":
|
422 |
+
x = x.transpose(1, 2)
|
423 |
+
return x
|
424 |
+
|
425 |
+
def inference(self, audio_list):
|
426 |
+
speech, speech_lengths, speech_times = extract_feature(audio_list)
|
427 |
+
results = self.forward(speech.to(torch.float32))
|
428 |
+
return results
|
src/chatterbox/models/s3tokenizer/__init__.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .s3tokenizer import (
|
2 |
+
S3_SR,
|
3 |
+
S3_HOP,
|
4 |
+
S3_TOKEN_HOP,
|
5 |
+
S3_TOKEN_RATE,
|
6 |
+
SPEECH_VOCAB_SIZE,
|
7 |
+
S3Tokenizer,
|
8 |
+
)
|
9 |
+
|
10 |
+
|
11 |
+
SOS = SPEECH_VOCAB_SIZE
|
12 |
+
EOS = SPEECH_VOCAB_SIZE + 1
|
13 |
+
|
14 |
+
|
15 |
+
|
16 |
+
def drop_invalid_tokens(x):
|
17 |
+
"""Drop SoS and EoS"""
|
18 |
+
assert len(x.shape) == 1 or (len(x.shape) == 2 and x.shape[0] == 1), "only batch size of one allowed for now"
|
19 |
+
if SOS in x:
|
20 |
+
s = (x == SOS).nonzero(as_tuple=True)[0].squeeze(0) + 1
|
21 |
+
else:
|
22 |
+
s = 0
|
23 |
+
|
24 |
+
if EOS in x:
|
25 |
+
e = (x == EOS).nonzero(as_tuple=True)[0].squeeze(0)
|
26 |
+
else:
|
27 |
+
e = None
|
28 |
+
|
29 |
+
x = x[s: e]
|
30 |
+
return x
|
src/chatterbox/models/s3tokenizer/s3tokenizer.py
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Tuple
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import librosa
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from s3tokenizer.utils import padding
|
8 |
+
from s3tokenizer.model_v2 import (
|
9 |
+
S3TokenizerV2,
|
10 |
+
ModelConfig,
|
11 |
+
)
|
12 |
+
|
13 |
+
|
14 |
+
# Sampling rate of the inputs to S3TokenizerV2
|
15 |
+
S3_SR = 16_000
|
16 |
+
S3_HOP = 160 # 100 frames/sec
|
17 |
+
S3_TOKEN_HOP = 640 # 25 tokens/sec
|
18 |
+
S3_TOKEN_RATE = 25
|
19 |
+
SPEECH_VOCAB_SIZE = 6561
|
20 |
+
|
21 |
+
|
22 |
+
class S3Tokenizer(S3TokenizerV2):
|
23 |
+
"""
|
24 |
+
s3tokenizer.S3TokenizerV2 with the following changes:
|
25 |
+
- a more integrated `forward`
|
26 |
+
- compute `log_mel_spectrogram` using `_mel_filters` and `window` in `register_buffers`
|
27 |
+
"""
|
28 |
+
|
29 |
+
ignore_state_dict_missing = ("_mel_filters", "window")
|
30 |
+
|
31 |
+
def __init__(
|
32 |
+
self,
|
33 |
+
name: str="speech_tokenizer_v2_25hz",
|
34 |
+
config: ModelConfig = ModelConfig()
|
35 |
+
):
|
36 |
+
super().__init__(name)
|
37 |
+
|
38 |
+
self.n_fft = 400
|
39 |
+
_mel_filters = librosa.filters.mel(
|
40 |
+
sr=S3_SR,
|
41 |
+
n_fft=self.n_fft,
|
42 |
+
n_mels=config.n_mels
|
43 |
+
)
|
44 |
+
self.register_buffer(
|
45 |
+
"_mel_filters",
|
46 |
+
torch.FloatTensor(_mel_filters),
|
47 |
+
)
|
48 |
+
|
49 |
+
self.register_buffer(
|
50 |
+
"window",
|
51 |
+
torch.hann_window(self.n_fft),
|
52 |
+
)
|
53 |
+
|
54 |
+
def pad(self, wavs, sr) -> List[torch.Tensor]:
|
55 |
+
"""
|
56 |
+
Given a list of wavs with the same `sample_rate`, pad them so that the length is multiple of 40ms (S3 runs at 25 token/sec).
|
57 |
+
"""
|
58 |
+
processed_wavs = []
|
59 |
+
for wav in wavs:
|
60 |
+
if isinstance(wav, np.ndarray):
|
61 |
+
wav = torch.from_numpy(wav)
|
62 |
+
if wav.dim() == 1:
|
63 |
+
wav = wav.unsqueeze(0)
|
64 |
+
|
65 |
+
n_tokens = (wav.shape[1] / sr) * S3_TOKEN_RATE
|
66 |
+
n_tokens = np.ceil(n_tokens)
|
67 |
+
intended_wav_len = n_tokens * (sr / S3_TOKEN_RATE)
|
68 |
+
intended_wav_len = int(intended_wav_len)
|
69 |
+
wav = torch.nn.functional.pad(
|
70 |
+
wav,
|
71 |
+
(0, intended_wav_len - wav.shape[-1]),
|
72 |
+
mode="constant",
|
73 |
+
value=0
|
74 |
+
)
|
75 |
+
processed_wavs.append(wav)
|
76 |
+
return processed_wavs
|
77 |
+
|
78 |
+
def _prepare_audio(self, wavs):
|
79 |
+
"""Prepare a list of audios for s3tokenizer processing."""
|
80 |
+
processed_wavs = []
|
81 |
+
for wav in wavs:
|
82 |
+
if isinstance(wav, np.ndarray):
|
83 |
+
wav = torch.from_numpy(wav)
|
84 |
+
if wav.dim() == 1:
|
85 |
+
wav = wav.unsqueeze(0)
|
86 |
+
|
87 |
+
processed_wavs.append(wav)
|
88 |
+
return processed_wavs
|
89 |
+
|
90 |
+
@torch.no_grad()
|
91 |
+
def forward(
|
92 |
+
self,
|
93 |
+
wavs: torch.Tensor,
|
94 |
+
accelerator: 'Accelerator'=None,
|
95 |
+
max_len: int=None,
|
96 |
+
) -> Tuple[torch.Tensor, torch.LongTensor]:
|
97 |
+
"""
|
98 |
+
NOTE: mel-spec has a hop size of 160 points (100 frame/sec).
|
99 |
+
FIXME: this class inherits `nn.Module` but doesn't accept `torch.Tensor` and handles a list of wavs one by one, which is unexpected.
|
100 |
+
|
101 |
+
Args
|
102 |
+
----
|
103 |
+
- `wavs`: 16 kHz speech audio
|
104 |
+
- `max_len` max length to truncate the output sequence to (25 token/sec).
|
105 |
+
NOTE: please pad the waveform if longer sequence is needed.
|
106 |
+
"""
|
107 |
+
processed_wavs = self._prepare_audio(wavs)
|
108 |
+
mels, mel_lens = [], []
|
109 |
+
for wav in processed_wavs:
|
110 |
+
wav = wav.to(self.device)
|
111 |
+
mel = self.log_mel_spectrogram(wav) # [B=1, F, T]
|
112 |
+
if max_len is not None:
|
113 |
+
mel = mel[..., :max_len * 4] # num_mel_frames = 4 * num_tokens
|
114 |
+
mels.append(mel.squeeze(0))
|
115 |
+
|
116 |
+
mels, mel_lens = padding(mels)
|
117 |
+
if accelerator is None:
|
118 |
+
tokenizer = self
|
119 |
+
else:
|
120 |
+
tokenizer = accelerator.unwrap_model(self)
|
121 |
+
|
122 |
+
speech_tokens, speech_token_lens = tokenizer.quantize(mels, mel_lens.to(self.device))
|
123 |
+
return (
|
124 |
+
speech_tokens.long().detach(),
|
125 |
+
speech_token_lens.long().detach(),
|
126 |
+
)
|
127 |
+
|
128 |
+
def log_mel_spectrogram(
|
129 |
+
self,
|
130 |
+
audio: torch.Tensor,
|
131 |
+
padding: int = 0,
|
132 |
+
):
|
133 |
+
"""
|
134 |
+
Compute the log-Mel spectrogram of
|
135 |
+
|
136 |
+
Parameters
|
137 |
+
----------
|
138 |
+
audio: torch.Tensor, shape = (*)
|
139 |
+
The path to audio or either a NumPy array or Tensor containing the
|
140 |
+
audio waveform in 16 kHz
|
141 |
+
|
142 |
+
padding: int
|
143 |
+
Number of zero samples to pad to the right
|
144 |
+
|
145 |
+
Returns
|
146 |
+
-------
|
147 |
+
torch.Tensor, shape = (128, n_frames)
|
148 |
+
A Tensor that contains the Mel spectrogram
|
149 |
+
"""
|
150 |
+
if not torch.is_tensor(audio):
|
151 |
+
audio = torch.from_numpy(audio)
|
152 |
+
|
153 |
+
audio = audio.to(self.device)
|
154 |
+
if padding > 0:
|
155 |
+
audio = F.pad(audio, (0, padding))
|
156 |
+
stft = torch.stft(
|
157 |
+
audio, self.n_fft, S3_HOP,
|
158 |
+
window=self.window.to(self.device),
|
159 |
+
return_complex=True
|
160 |
+
)
|
161 |
+
magnitudes = stft[..., :-1].abs()**2
|
162 |
+
|
163 |
+
mel_spec = self._mel_filters.to(self.device) @ magnitudes
|
164 |
+
|
165 |
+
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
166 |
+
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
167 |
+
log_spec = (log_spec + 4.0) / 4.0
|
168 |
+
return log_spec
|
src/chatterbox/models/t3/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .t3 import T3
|
src/chatterbox/models/t3/inference/alignment_stream_analyzer.py
ADDED
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025 Resemble AI
|
2 |
+
# Author: John Meade, Jeremy Hsu
|
3 |
+
# MIT License
|
4 |
+
import logging
|
5 |
+
import torch
|
6 |
+
from dataclasses import dataclass
|
7 |
+
from types import MethodType
|
8 |
+
|
9 |
+
|
10 |
+
logger = logging.getLogger(__name__)
|
11 |
+
|
12 |
+
|
13 |
+
LLAMA_ALIGNED_HEADS = [(12, 15), (13, 11), (9, 2)]
|
14 |
+
|
15 |
+
|
16 |
+
@dataclass
|
17 |
+
class AlignmentAnalysisResult:
|
18 |
+
# was this frame detected as being part of a noisy beginning chunk with potential hallucinations?
|
19 |
+
false_start: bool
|
20 |
+
# was this frame detected as being part of a long tail with potential hallucinations?
|
21 |
+
long_tail: bool
|
22 |
+
# was this frame detected as repeating existing text content?
|
23 |
+
repetition: bool
|
24 |
+
# was the alignment position of this frame too far from the previous frame?
|
25 |
+
discontinuity: bool
|
26 |
+
# has inference reached the end of the text tokens? eg, this remains false if inference stops early
|
27 |
+
complete: bool
|
28 |
+
# approximate position in the text token sequence. Can be used for generating online timestamps.
|
29 |
+
position: int
|
30 |
+
|
31 |
+
|
32 |
+
class AlignmentStreamAnalyzer:
|
33 |
+
def __init__(self, tfmr, queue, text_tokens_slice, alignment_layer_idx=9, eos_idx=0):
|
34 |
+
"""
|
35 |
+
Some transformer TTS models implicitly solve text-speech alignment in one or more of their self-attention
|
36 |
+
activation maps. This module exploits this to perform online integrity checks which streaming.
|
37 |
+
A hook is injected into the specified attention layer, and heuristics are used to determine alignment
|
38 |
+
position, repetition, etc.
|
39 |
+
|
40 |
+
NOTE: currently requires no queues.
|
41 |
+
"""
|
42 |
+
# self.queue = queue
|
43 |
+
self.text_tokens_slice = (i, j) = text_tokens_slice
|
44 |
+
self.eos_idx = eos_idx
|
45 |
+
self.alignment = torch.zeros(0, j-i)
|
46 |
+
# self.alignment_bin = torch.zeros(0, j-i)
|
47 |
+
self.curr_frame_pos = 0
|
48 |
+
self.text_position = 0
|
49 |
+
|
50 |
+
self.started = False
|
51 |
+
self.started_at = None
|
52 |
+
|
53 |
+
self.complete = False
|
54 |
+
self.completed_at = None
|
55 |
+
|
56 |
+
# Track generated tokens for repetition detection
|
57 |
+
self.generated_tokens = []
|
58 |
+
|
59 |
+
# Using `output_attentions=True` is incompatible with optimized attention kernels, so
|
60 |
+
# using it for all layers slows things down too much. We can apply it to just one layer
|
61 |
+
# by intercepting the kwargs and adding a forward hook (credit: jrm)
|
62 |
+
self.last_aligned_attns = []
|
63 |
+
for i, (layer_idx, head_idx) in enumerate(LLAMA_ALIGNED_HEADS):
|
64 |
+
self.last_aligned_attns += [None]
|
65 |
+
self._add_attention_spy(tfmr, i, layer_idx, head_idx)
|
66 |
+
|
67 |
+
def _add_attention_spy(self, tfmr, buffer_idx, layer_idx, head_idx):
|
68 |
+
"""
|
69 |
+
Adds a forward hook to a specific attention layer to collect outputs.
|
70 |
+
"""
|
71 |
+
def attention_forward_hook(module, input, output):
|
72 |
+
"""
|
73 |
+
See `LlamaAttention.forward`; the output is a 3-tuple: `attn_output, attn_weights, past_key_value`.
|
74 |
+
NOTE:
|
75 |
+
- When `output_attentions=True`, `LlamaSdpaAttention.forward` calls `LlamaAttention.forward`.
|
76 |
+
- `attn_output` has shape [B, H, T0, T0] for the 0th entry, and [B, H, 1, T0+i] for the rest i-th.
|
77 |
+
"""
|
78 |
+
if isinstance(output, tuple) and len(output) > 1 and output[1] is not None:
|
79 |
+
step_attention = output[1].cpu() # (B, n_heads, T0, Ti)
|
80 |
+
self.last_aligned_attns[buffer_idx] = step_attention[0, head_idx] # (T0, Ti)
|
81 |
+
|
82 |
+
target_layer = tfmr.layers[layer_idx].self_attn
|
83 |
+
# Register hook and store the handle
|
84 |
+
target_layer.register_forward_hook(attention_forward_hook)
|
85 |
+
if hasattr(tfmr, 'config') and hasattr(tfmr.config, 'output_attentions'):
|
86 |
+
self.original_output_attentions = tfmr.config.output_attentions
|
87 |
+
tfmr.config.output_attentions = True
|
88 |
+
|
89 |
+
def step(self, logits, next_token=None):
|
90 |
+
"""
|
91 |
+
Emits an AlignmentAnalysisResult into the output queue, and potentially modifies the logits to force an EOS.
|
92 |
+
"""
|
93 |
+
# extract approximate alignment matrix chunk (1 frame at a time after the first chunk)
|
94 |
+
aligned_attn = torch.stack(self.last_aligned_attns).mean(dim=0) # (N, N)
|
95 |
+
i, j = self.text_tokens_slice
|
96 |
+
if self.curr_frame_pos == 0:
|
97 |
+
# first chunk has conditioning info, text tokens, and BOS token
|
98 |
+
A_chunk = aligned_attn[j:, i:j].clone().cpu() # (T, S)
|
99 |
+
else:
|
100 |
+
# subsequent chunks have 1 frame due to KV-caching
|
101 |
+
A_chunk = aligned_attn[:, i:j].clone().cpu() # (1, S)
|
102 |
+
|
103 |
+
# TODO: monotonic masking; could have issue b/c spaces are often skipped.
|
104 |
+
A_chunk[:, self.curr_frame_pos + 1:] = 0
|
105 |
+
|
106 |
+
|
107 |
+
self.alignment = torch.cat((self.alignment, A_chunk), dim=0)
|
108 |
+
|
109 |
+
A = self.alignment
|
110 |
+
T, S = A.shape
|
111 |
+
|
112 |
+
# update position
|
113 |
+
cur_text_posn = A_chunk[-1].argmax()
|
114 |
+
discontinuity = not(-4 < cur_text_posn - self.text_position < 7) # NOTE: very lenient!
|
115 |
+
if not discontinuity:
|
116 |
+
self.text_position = cur_text_posn
|
117 |
+
|
118 |
+
# Hallucinations at the start of speech show up as activations at the bottom of the attention maps!
|
119 |
+
# To mitigate this, we just wait until there are no activations far off-diagonal in the last 2 tokens,
|
120 |
+
# and there are some strong activations in the first few tokens.
|
121 |
+
false_start = (not self.started) and (A[-2:, -2:].max() > 0.1 or A[:, :4].max() < 0.5)
|
122 |
+
self.started = not false_start
|
123 |
+
if self.started and self.started_at is None:
|
124 |
+
self.started_at = T
|
125 |
+
|
126 |
+
# Is generation likely complete?
|
127 |
+
self.complete = self.complete or self.text_position >= S - 3
|
128 |
+
if self.complete and self.completed_at is None:
|
129 |
+
self.completed_at = T
|
130 |
+
|
131 |
+
# NOTE: EOS rarely assigned activations, and second-last token is often punctuation, so use last 3 tokens.
|
132 |
+
# NOTE: due to the false-start behaviour, we need to make sure we skip activations for the first few tokens.
|
133 |
+
last_text_token_duration = A[15:, -3:].sum()
|
134 |
+
|
135 |
+
# Activations for the final token that last too long are likely hallucinations.
|
136 |
+
long_tail = self.complete and (A[self.completed_at:, -3:].sum(dim=0).max() >= 5) # 200ms
|
137 |
+
|
138 |
+
# If there are activations in previous tokens after generation has completed, assume this is a repetition error.
|
139 |
+
alignment_repetition = self.complete and (A[self.completed_at:, :-5].max(dim=1).values.sum() > 5)
|
140 |
+
|
141 |
+
# Track generated tokens for repetition detection
|
142 |
+
if next_token is not None:
|
143 |
+
# Convert tensor to scalar if needed
|
144 |
+
if isinstance(next_token, torch.Tensor):
|
145 |
+
token_id = next_token.item() if next_token.numel() == 1 else next_token.view(-1)[0].item()
|
146 |
+
else:
|
147 |
+
token_id = next_token
|
148 |
+
self.generated_tokens.append(token_id)
|
149 |
+
|
150 |
+
# Keep only last 8 tokens to prevent memory issues
|
151 |
+
if len(self.generated_tokens) > 8:
|
152 |
+
self.generated_tokens = self.generated_tokens[-8:]
|
153 |
+
|
154 |
+
# Check for excessive token repetition (3x same token in a row)
|
155 |
+
token_repetition = (
|
156 |
+
# self.complete and
|
157 |
+
len(self.generated_tokens) >= 3 and
|
158 |
+
len(set(self.generated_tokens[-3:])) == 1
|
159 |
+
)
|
160 |
+
|
161 |
+
if token_repetition:
|
162 |
+
repeated_token = self.generated_tokens[-1]
|
163 |
+
logger.warning(f"🚨 Detected 3x repetition of token {repeated_token}")
|
164 |
+
|
165 |
+
# Suppress EoS to prevent early termination
|
166 |
+
if cur_text_posn < S - 3 and S > 5: # Only suppress if text is longer than 5 tokens
|
167 |
+
logits[..., self.eos_idx] = -2**15
|
168 |
+
|
169 |
+
# If a bad ending is detected, force emit EOS by modifying logits
|
170 |
+
# NOTE: this means logits may be inconsistent with latents!
|
171 |
+
if long_tail or alignment_repetition or token_repetition:
|
172 |
+
logger.warning(f"forcing EOS token, {long_tail=}, {alignment_repetition=}, {token_repetition=}")
|
173 |
+
# (±2**15 is safe for all dtypes >= 16bit)
|
174 |
+
logits = -(2**15) * torch.ones_like(logits)
|
175 |
+
logits[..., self.eos_idx] = 2**15
|
176 |
+
|
177 |
+
self.curr_frame_pos += 1
|
178 |
+
return logits
|
src/chatterbox/models/t3/inference/t3_hf_backend.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn as nn
|
5 |
+
from transformers import LlamaConfig, LlamaModel, LlamaPreTrainedModel, GenerationMixin
|
6 |
+
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
|
7 |
+
|
8 |
+
|
9 |
+
class T3HuggingfaceBackend(LlamaPreTrainedModel, GenerationMixin):
|
10 |
+
"""
|
11 |
+
Override some HuggingFace interface methods so we can use the standard `generate` method with our
|
12 |
+
custom embedding / logit layers.
|
13 |
+
|
14 |
+
NOTE: need to extend "*PreTrainedModel" to avoid re-initializing weights!
|
15 |
+
"""
|
16 |
+
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
config: LlamaConfig,
|
20 |
+
llama: LlamaModel,
|
21 |
+
*,
|
22 |
+
speech_enc,
|
23 |
+
speech_head,
|
24 |
+
latents_queue=None,
|
25 |
+
logits_queue=None,
|
26 |
+
alignment_stream_analyzer: 'AlignmentStreamAnalyzer'=None,
|
27 |
+
):
|
28 |
+
super().__init__(config)
|
29 |
+
self.model = llama
|
30 |
+
self.speech_enc = speech_enc
|
31 |
+
self.speech_head = speech_head
|
32 |
+
self._added_cond = False
|
33 |
+
self.alignment_stream_analyzer = alignment_stream_analyzer
|
34 |
+
|
35 |
+
@torch.inference_mode()
|
36 |
+
def prepare_inputs_for_generation(
|
37 |
+
self, input_ids: torch.Tensor, decoder_cond: torch.Tensor, use_cache: bool, past_key_values=None,
|
38 |
+
# This argument was introduced in some recent version of transformers (>=4.29.1)
|
39 |
+
cache_position=None
|
40 |
+
):
|
41 |
+
"""
|
42 |
+
This is a method used by huggingface's generate() method.
|
43 |
+
Overridden here to apply our custom speech token embedding layer.
|
44 |
+
|
45 |
+
:param input_ids: (B, S) int64 tensors of input tokens.
|
46 |
+
:param decoder_cond: (B, T, C) float32 tensor of conditioning (prefixed to <input_embeds>)
|
47 |
+
"""
|
48 |
+
|
49 |
+
# Make use of the kv cache: only the last input ID is new, we trim away all the ones before
|
50 |
+
if not use_cache:
|
51 |
+
past_key_values = None
|
52 |
+
if past_key_values is not None:
|
53 |
+
input_ids = input_ids[:, -1:]
|
54 |
+
|
55 |
+
# custom speech token embedding layer
|
56 |
+
inputs_embeds = self.speech_enc(input_ids)
|
57 |
+
|
58 |
+
# prefix decoder conditioning if applicable
|
59 |
+
if not self._added_cond:
|
60 |
+
assert past_key_values is not None # should be first step
|
61 |
+
if decoder_cond.size(0) != inputs_embeds.size(0):
|
62 |
+
decoder_cond = decoder_cond.expand(inputs_embeds.size(0), -1, -1)
|
63 |
+
inputs_embeds = torch.cat([decoder_cond, inputs_embeds], dim=1)
|
64 |
+
self._added_cond = True
|
65 |
+
|
66 |
+
return {
|
67 |
+
"inputs_embeds": inputs_embeds,
|
68 |
+
"past_key_values": past_key_values,
|
69 |
+
"use_cache": use_cache,
|
70 |
+
}
|
71 |
+
|
72 |
+
@torch.inference_mode()
|
73 |
+
def forward(
|
74 |
+
self,
|
75 |
+
inputs_embeds: torch.Tensor,
|
76 |
+
past_key_values: Optional[torch.Tensor]=None,
|
77 |
+
use_cache=True,
|
78 |
+
output_attentions=False,
|
79 |
+
output_hidden_states=True,
|
80 |
+
return_dict=True,
|
81 |
+
):
|
82 |
+
"""
|
83 |
+
This is a method used by huggingface's generate() method.
|
84 |
+
Overridden here to apply our custom layer norm and speech logit projection layers.
|
85 |
+
|
86 |
+
:param inputs_embeds: (B, S, C) float32 tensor of conditioning inputs. If past key values are given,
|
87 |
+
S should be 1.
|
88 |
+
"""
|
89 |
+
is_large_input = inputs_embeds.size(1) != 1
|
90 |
+
has_cache = past_key_values is not None and len(past_key_values) > 0
|
91 |
+
assert not (is_large_input and has_cache)
|
92 |
+
assert return_dict
|
93 |
+
assert output_hidden_states
|
94 |
+
|
95 |
+
tfmr_out = self.model(
|
96 |
+
inputs_embeds=inputs_embeds,
|
97 |
+
past_key_values=past_key_values,
|
98 |
+
use_cache=use_cache,
|
99 |
+
output_attentions=output_attentions,
|
100 |
+
output_hidden_states=output_hidden_states,
|
101 |
+
return_dict=True,
|
102 |
+
)
|
103 |
+
hidden_states = tfmr_out.hidden_states[-1] # (B, seq, dim)
|
104 |
+
|
105 |
+
logits = self.speech_head(hidden_states)
|
106 |
+
# assert inputs_embeds.size(0) == 1 # (disabled for CFG)
|
107 |
+
|
108 |
+
# NOTE: hallucination handler may modify logits to force emit an EOS token
|
109 |
+
# logits = self.alignment_stream_analyzer.step(logits)
|
110 |
+
|
111 |
+
return CausalLMOutputWithCrossAttentions(
|
112 |
+
logits=logits,
|
113 |
+
past_key_values=tfmr_out.past_key_values,
|
114 |
+
hidden_states=tfmr_out.hidden_states,
|
115 |
+
attentions=tfmr_out.attentions,
|
116 |
+
)
|
src/chatterbox/models/t3/llama_configs.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
LLAMA_520M_CONFIG_DICT = dict(
|
2 |
+
# Arbitrary small number that won't cause problems when loading.
|
3 |
+
# These param are unused due to custom input layers.
|
4 |
+
vocab_size=8,
|
5 |
+
# default params needed for loading most pretrained 1B weights
|
6 |
+
max_position_embeddings=131072,
|
7 |
+
hidden_size=1024,
|
8 |
+
intermediate_size=4096,
|
9 |
+
num_hidden_layers=30,
|
10 |
+
num_attention_heads=16,
|
11 |
+
attn_implementation="sdpa",
|
12 |
+
head_dim=64,
|
13 |
+
tie_word_embeddings=False,
|
14 |
+
hidden_act="silu",
|
15 |
+
attention_bias=False,
|
16 |
+
attention_dropout=0.0,
|
17 |
+
initializer_range=0.02,
|
18 |
+
mlp_bias=False,
|
19 |
+
model_type="llama",
|
20 |
+
num_key_value_heads=16,
|
21 |
+
pretraining_tp=1,
|
22 |
+
rms_norm_eps=1e-05,
|
23 |
+
rope_scaling=dict(
|
24 |
+
factor=8.0,
|
25 |
+
high_freq_factor=4.0,
|
26 |
+
low_freq_factor=1.0,
|
27 |
+
original_max_position_embeddings=8192,
|
28 |
+
rope_type="llama3"
|
29 |
+
),
|
30 |
+
rope_theta=500000.0,
|
31 |
+
torch_dtype="bfloat16",
|
32 |
+
use_cache=True,
|
33 |
+
)
|
34 |
+
|
35 |
+
LLAMA_CONFIGS = {
|
36 |
+
"Llama_520M": LLAMA_520M_CONFIG_DICT,
|
37 |
+
}
|
src/chatterbox/models/t3/modules/cond_enc.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import nn, Tensor
|
6 |
+
|
7 |
+
from .perceiver import Perceiver
|
8 |
+
from .t3_config import T3Config
|
9 |
+
|
10 |
+
|
11 |
+
@dataclass
|
12 |
+
class T3Cond:
|
13 |
+
"""
|
14 |
+
Dataclass container for most / all conditioning info.
|
15 |
+
TODO: serialization methods aren't used, keeping them around for convenience
|
16 |
+
"""
|
17 |
+
|
18 |
+
speaker_emb: Tensor
|
19 |
+
clap_emb: Optional[Tensor] = None
|
20 |
+
cond_prompt_speech_tokens: Optional[Tensor] = None
|
21 |
+
cond_prompt_speech_emb: Optional[Tensor] = None
|
22 |
+
emotion_adv: Optional[Tensor] = 0.5
|
23 |
+
|
24 |
+
def to(self, *, device=None, dtype=None):
|
25 |
+
"Cast to a device and dtype. Dtype casting is ignored for long/int tensors."
|
26 |
+
for k, v in self.__dict__.items():
|
27 |
+
if torch.is_tensor(v):
|
28 |
+
is_fp = type(v.view(-1)[0].item()) is not int
|
29 |
+
setattr(self, k, v.to(device=device, dtype=dtype if is_fp else None))
|
30 |
+
return self
|
31 |
+
|
32 |
+
def save(self, fpath):
|
33 |
+
torch.save(self.__dict__, fpath)
|
34 |
+
|
35 |
+
@staticmethod
|
36 |
+
def load(fpath, map_location="cpu"):
|
37 |
+
kwargs = torch.load(fpath, map_location=map_location, weights_only=True)
|
38 |
+
return T3Cond(**kwargs)
|
39 |
+
|
40 |
+
|
41 |
+
class T3CondEnc(nn.Module):
|
42 |
+
"""
|
43 |
+
Handle all non-text conditioning, like speaker embeddings / prompts, CLAP, emotion, etc.
|
44 |
+
"""
|
45 |
+
|
46 |
+
def __init__(self, hp: T3Config):
|
47 |
+
super().__init__()
|
48 |
+
self.hp = hp
|
49 |
+
if hp.encoder_type == "voice_encoder":
|
50 |
+
self.spkr_enc = nn.Linear(hp.speaker_embed_size, hp.n_channels)
|
51 |
+
else:
|
52 |
+
raise NotImplementedError(str(hp.encoder_type))
|
53 |
+
|
54 |
+
# emotion adv
|
55 |
+
self.emotion_adv_fc = None
|
56 |
+
if hp.emotion_adv:
|
57 |
+
self.emotion_adv_fc = nn.Linear(1, hp.n_channels, bias=False)
|
58 |
+
|
59 |
+
# perceiver resampler
|
60 |
+
self.perceiver = None
|
61 |
+
if hp.use_perceiver_resampler:
|
62 |
+
self.perceiver = Perceiver()
|
63 |
+
|
64 |
+
def forward(self, cond: T3Cond):
|
65 |
+
# Validate
|
66 |
+
assert (cond.cond_prompt_speech_tokens is None) == (cond.cond_prompt_speech_emb is None), \
|
67 |
+
"no embeddings for cond_prompt_speech_tokens"
|
68 |
+
|
69 |
+
# Speaker embedding projection
|
70 |
+
cond_spkr = self.spkr_enc(cond.speaker_emb.view(-1, self.hp.speaker_embed_size))[:, None] # (B, 1, dim)
|
71 |
+
empty = torch.zeros_like(cond_spkr[:, :0]) # (B, 0, dim)
|
72 |
+
|
73 |
+
# TODO CLAP
|
74 |
+
assert cond.clap_emb is None, "clap_embed not implemented"
|
75 |
+
cond_clap = empty # (B, 0, dim)
|
76 |
+
|
77 |
+
# Cond prompt
|
78 |
+
cond_prompt_speech_emb = cond.cond_prompt_speech_emb
|
79 |
+
if cond_prompt_speech_emb is None:
|
80 |
+
cond_prompt_speech_emb = empty # (B, 0, dim)
|
81 |
+
elif self.hp.use_perceiver_resampler:
|
82 |
+
cond_prompt_speech_emb = self.perceiver(cond_prompt_speech_emb)
|
83 |
+
|
84 |
+
# Emotion Adv: must provide a value if this model uses emotion conditioning
|
85 |
+
cond_emotion_adv = empty # (B, 0, dim)
|
86 |
+
if self.hp.emotion_adv:
|
87 |
+
assert cond.emotion_adv is not None
|
88 |
+
cond_emotion_adv = self.emotion_adv_fc(cond.emotion_adv.view(-1, 1, 1))
|
89 |
+
|
90 |
+
# Concat and return
|
91 |
+
cond_embeds = torch.cat((
|
92 |
+
cond_spkr,
|
93 |
+
cond_clap,
|
94 |
+
cond_prompt_speech_emb,
|
95 |
+
cond_emotion_adv,
|
96 |
+
), dim=1)
|
97 |
+
return cond_embeds
|
src/chatterbox/models/t3/modules/learned_pos_emb.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Union
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn, Tensor
|
5 |
+
|
6 |
+
|
7 |
+
class LearnedPositionEmbeddings(nn.Module):
|
8 |
+
def __init__(self, seq_len, model_dim, init=.02):
|
9 |
+
super().__init__()
|
10 |
+
self.emb = nn.Embedding(seq_len, model_dim)
|
11 |
+
# Initializing this way is standard for GPT-2
|
12 |
+
self.emb.weight.data.normal_(mean=0.0, std=init)
|
13 |
+
|
14 |
+
def forward(self, x):
|
15 |
+
"""
|
16 |
+
Returns positional embeddings for index 0 up to the length of x
|
17 |
+
"""
|
18 |
+
sl = x.shape[1]
|
19 |
+
return self.emb(torch.arange(0, sl, device=x.device))
|
20 |
+
|
21 |
+
def get_fixed_embedding(self, idx: 'Union[int, Tensor]'):
|
22 |
+
"""
|
23 |
+
Args:
|
24 |
+
idx: scalar int or an integer tensor of shape (T,) or (B, T)
|
25 |
+
Returns:
|
26 |
+
positional embeddings for given indices, shape (B, T, dim), ie (1, 1, dim) for int input
|
27 |
+
"""
|
28 |
+
device = self.emb.weight.device
|
29 |
+
idx = idx.to(device) if torch.is_tensor(idx) else torch.tensor(idx, device=device)
|
30 |
+
idx = torch.atleast_2d(idx)
|
31 |
+
assert idx.ndim == 2
|
32 |
+
return self.emb(idx) # (B, T, dim)
|
src/chatterbox/models/t3/modules/perceiver.py
ADDED
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025 Resemble AI
|
2 |
+
# Author: Manmay Nakhashi
|
3 |
+
# MIT License
|
4 |
+
import math
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from torch import nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from einops import rearrange
|
10 |
+
|
11 |
+
|
12 |
+
class RelativePositionBias(nn.Module):
|
13 |
+
def __init__(self, scale, causal=False, num_buckets=32, max_distance=128, heads=8):
|
14 |
+
super().__init__()
|
15 |
+
self.scale = scale
|
16 |
+
self.causal = causal
|
17 |
+
self.num_buckets = num_buckets
|
18 |
+
self.max_distance = max_distance
|
19 |
+
self.relative_attention_bias = nn.Embedding(num_buckets, heads)
|
20 |
+
|
21 |
+
@staticmethod
|
22 |
+
def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128):
|
23 |
+
ret = 0
|
24 |
+
n = -relative_position
|
25 |
+
if not causal:
|
26 |
+
num_buckets //= 2
|
27 |
+
ret += (n < 0).long() * num_buckets
|
28 |
+
n = torch.abs(n)
|
29 |
+
else:
|
30 |
+
n = torch.max(n, torch.zeros_like(n))
|
31 |
+
|
32 |
+
max_exact = num_buckets // 2
|
33 |
+
is_small = n < max_exact
|
34 |
+
|
35 |
+
val_if_large = max_exact + (
|
36 |
+
torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
|
37 |
+
).long()
|
38 |
+
val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
|
39 |
+
|
40 |
+
ret += torch.where(is_small, n, val_if_large)
|
41 |
+
return ret
|
42 |
+
|
43 |
+
def forward(self, qk_dots):
|
44 |
+
i, j, device = *qk_dots.shape[-2:], qk_dots.device
|
45 |
+
q_pos = torch.arange(i, dtype=torch.long, device=device)
|
46 |
+
k_pos = torch.arange(j, dtype=torch.long, device=device)
|
47 |
+
rel_pos = k_pos[None, :] - q_pos[:, None]
|
48 |
+
rp_bucket = self._relative_position_bucket(rel_pos, causal=self.causal, num_buckets=self.num_buckets,
|
49 |
+
max_distance=self.max_distance)
|
50 |
+
values = self.relative_attention_bias(rp_bucket)
|
51 |
+
bias = rearrange(values, 'i j h -> () h i j')
|
52 |
+
return qk_dots + (bias * self.scale)
|
53 |
+
|
54 |
+
|
55 |
+
class AttentionQKV(nn.Module):
|
56 |
+
def __init__(self, n_heads, head_dim, dropout_rate=0.1, scale=None, flash=False):
|
57 |
+
super().__init__()
|
58 |
+
self.n_heads = n_heads
|
59 |
+
self.head_dim = head_dim
|
60 |
+
self.scale = scale if scale is not None else head_dim ** -0.5
|
61 |
+
self.flash = flash
|
62 |
+
self.dropout_rate = dropout_rate
|
63 |
+
self.dropout = nn.Dropout(dropout_rate)
|
64 |
+
self.flash_config = self.setup_flash_config() if flash else None
|
65 |
+
|
66 |
+
def setup_flash_config(self):
|
67 |
+
# Setup flash attention configuration
|
68 |
+
flash_config = {
|
69 |
+
'enable_flash': True,
|
70 |
+
'enable_math': True,
|
71 |
+
'enable_mem_efficient': True
|
72 |
+
}
|
73 |
+
return flash_config
|
74 |
+
|
75 |
+
def forward(self, q, k, v, mask=None):
|
76 |
+
q, k, v = [self.split_heads(tensor) for tensor in [q, k, v]]
|
77 |
+
if self.flash:
|
78 |
+
out = self.flash_attention(q, k, v, mask=mask)
|
79 |
+
else:
|
80 |
+
out = self.scaled_dot_product_attention(q, k, v, mask=mask)
|
81 |
+
|
82 |
+
return self.combine_heads(out)
|
83 |
+
|
84 |
+
def scaled_dot_product_attention(self, q, k, v, mask=None):
|
85 |
+
sim = torch.einsum("bhlt,bhls->bhts", q, k) * self.scale
|
86 |
+
if mask is not None:
|
87 |
+
sim = sim.masked_fill(mask == 0, float('-inf'))
|
88 |
+
attn = torch.softmax(sim, dim=-1)
|
89 |
+
attn = self.dropout(attn)
|
90 |
+
return torch.einsum("bhts,bhls->bhlt", attn, v)
|
91 |
+
|
92 |
+
def flash_attention(self, q, k, v, mask=None):
|
93 |
+
config = self.flash_config if self.flash_config else {}
|
94 |
+
with torch.backends.cuda.sdp_kernel(**config):
|
95 |
+
out = F.scaled_dot_product_attention(
|
96 |
+
q, k, v,
|
97 |
+
attn_mask=mask,
|
98 |
+
dropout_p=self.dropout_rate if self.training else 0.
|
99 |
+
)
|
100 |
+
return out
|
101 |
+
|
102 |
+
def split_heads(self, x):
|
103 |
+
bs, length, _ = x.shape
|
104 |
+
x = x.view(bs, length, self.n_heads, self.head_dim)
|
105 |
+
return x.permute(0, 2, 1, 3)
|
106 |
+
|
107 |
+
def combine_heads(self, x):
|
108 |
+
bs, _, length, _ = x.shape
|
109 |
+
x = x.permute(0, 2, 1, 3).contiguous()
|
110 |
+
return x.view(bs, length, -1)
|
111 |
+
|
112 |
+
|
113 |
+
class AttentionBlock2(nn.Module):
|
114 |
+
"""
|
115 |
+
An attention block that allows spatial positions to attend to each other,
|
116 |
+
using AttentionQKV and separate linear transformations for Q, K, and V.
|
117 |
+
"""
|
118 |
+
|
119 |
+
def __init__(
|
120 |
+
self,
|
121 |
+
channels,
|
122 |
+
num_heads=1,
|
123 |
+
num_head_channels=-1,
|
124 |
+
relative_pos_embeddings=False,
|
125 |
+
flash_attention=True,
|
126 |
+
dropout_rate=0.2,
|
127 |
+
scale=None
|
128 |
+
):
|
129 |
+
super().__init__()
|
130 |
+
self.channels = channels
|
131 |
+
|
132 |
+
if num_head_channels == -1:
|
133 |
+
self.num_heads = num_heads
|
134 |
+
else:
|
135 |
+
assert (
|
136 |
+
channels % num_head_channels == 0
|
137 |
+
), f"channels {channels} is not divisible by num_head_channels {num_head_channels}"
|
138 |
+
self.num_heads = channels // num_head_channels
|
139 |
+
|
140 |
+
self.norm = nn.LayerNorm(channels)
|
141 |
+
|
142 |
+
# Separate linear layers for Q, K, and V
|
143 |
+
self.to_q = nn.Linear(channels, channels)
|
144 |
+
self.to_k = nn.Linear(channels, channels)
|
145 |
+
self.to_v = nn.Linear(channels, channels)
|
146 |
+
|
147 |
+
self.attention = AttentionQKV(self.num_heads, channels // self.num_heads, dropout_rate=dropout_rate, flash=flash_attention, scale=scale)
|
148 |
+
|
149 |
+
self.proj_out = nn.Linear(channels, channels)
|
150 |
+
|
151 |
+
if relative_pos_embeddings:
|
152 |
+
self.relative_pos_embeddings = RelativePositionBias(scale=(channels // self.num_heads) ** .5, causal=False, heads=num_heads, num_buckets=32, max_distance=64)
|
153 |
+
else:
|
154 |
+
self.relative_pos_embeddings = None
|
155 |
+
|
156 |
+
def forward(self, x1, x2, mask=None):
|
157 |
+
b1, c1, *spatial1 = x1.shape
|
158 |
+
b2, c2, *spatial2 = x2.shape
|
159 |
+
|
160 |
+
x1_norm = self.norm(x1)
|
161 |
+
x2_norm = self.norm(x2)
|
162 |
+
|
163 |
+
q = self.to_q(x1_norm)
|
164 |
+
k = self.to_k(x2_norm)
|
165 |
+
v = self.to_v(x2_norm)
|
166 |
+
|
167 |
+
h = self.attention(q, k, v, mask=mask)
|
168 |
+
h = self.proj_out(h)
|
169 |
+
|
170 |
+
return (x1 + h).reshape(b1, c1, *spatial1)
|
171 |
+
|
172 |
+
|
173 |
+
class Perceiver(nn.Module):
|
174 |
+
"""Inspired by https://arxiv.org/abs/2103.03206"""
|
175 |
+
def __init__(self, pre_attention_query_token=32, pre_attention_query_size=1024, embedding_dim=1024, num_attn_heads=4):
|
176 |
+
"""
|
177 |
+
Initialize the perceiver module.
|
178 |
+
|
179 |
+
:param pre_attention_query_token: Number of query tokens for pre-attention
|
180 |
+
:param pre_attention_query_size: Size of each query token
|
181 |
+
:param embedding_dim: Dimension of the embedding space
|
182 |
+
:param num_attn_heads: Number of attention heads
|
183 |
+
"""
|
184 |
+
super().__init__()
|
185 |
+
|
186 |
+
# Initialize the pre-attention query parameter
|
187 |
+
self.pre_attention_query = torch.nn.Parameter(
|
188 |
+
torch.empty(1, pre_attention_query_token, pre_attention_query_size)
|
189 |
+
)
|
190 |
+
|
191 |
+
# Calculate the variance for uniform initialization
|
192 |
+
query_variance = math.sqrt(3.0) * math.sqrt(2.0 / (pre_attention_query_token + pre_attention_query_token))
|
193 |
+
|
194 |
+
# Initialize the pre-attention query with uniform distribution
|
195 |
+
self.pre_attention_query.data.uniform_(-query_variance, query_variance)
|
196 |
+
|
197 |
+
# Initialize the attention block
|
198 |
+
self.attn = AttentionBlock2(embedding_dim, num_attn_heads)
|
199 |
+
|
200 |
+
def forward(self, h):
|
201 |
+
"""
|
202 |
+
Forward pass of the perceiver module.
|
203 |
+
:param h: Input tensor
|
204 |
+
:return: Output after applying attention mechanisms
|
205 |
+
"""
|
206 |
+
# Expand the pre-attention query to match the batch size of the input
|
207 |
+
query_ = self.pre_attention_query.expand(h.shape[0], -1, -1)
|
208 |
+
# Apply the first attention mechanism (cross-attention)
|
209 |
+
pre_att = self.attn(query_, h)
|
210 |
+
# Apply the second attention mechanism (self-attention)
|
211 |
+
attn = self.attn(pre_att, pre_att)
|
212 |
+
return attn
|
src/chatterbox/models/t3/modules/t3_config.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ..llama_configs import LLAMA_CONFIGS
|
2 |
+
|
3 |
+
|
4 |
+
class T3Config:
|
5 |
+
def __init__(self, text_tokens_dict_size=704):
|
6 |
+
self.start_text_token = 255
|
7 |
+
self.stop_text_token = 0
|
8 |
+
self.text_tokens_dict_size = text_tokens_dict_size
|
9 |
+
self.max_text_tokens = 2048
|
10 |
+
|
11 |
+
self.start_speech_token = 6561
|
12 |
+
self.stop_speech_token = 6562
|
13 |
+
self.speech_tokens_dict_size = 8194
|
14 |
+
self.max_speech_tokens = 4096
|
15 |
+
|
16 |
+
self.llama_config_name = "Llama_520M"
|
17 |
+
self.input_pos_emb = "learned"
|
18 |
+
self.speech_cond_prompt_len = 150
|
19 |
+
|
20 |
+
self.encoder_type = "voice_encoder"
|
21 |
+
self.speaker_embed_size = 256
|
22 |
+
self.use_perceiver_resampler = True
|
23 |
+
self.emotion_adv = True
|
24 |
+
|
25 |
+
@property
|
26 |
+
def n_channels(self):
|
27 |
+
return LLAMA_CONFIGS[self.llama_config_name]["hidden_size"]
|
28 |
+
|
29 |
+
@classmethod
|
30 |
+
def english_only(cls):
|
31 |
+
"""Create configuration for English-only TTS model."""
|
32 |
+
return cls(text_tokens_dict_size=704)
|
33 |
+
|
34 |
+
@classmethod
|
35 |
+
def multilingual(cls):
|
36 |
+
"""Create configuration for multilingual TTS model."""
|
37 |
+
return cls(text_tokens_dict_size=2352)
|
src/chatterbox/models/t3/t3.py
ADDED
@@ -0,0 +1,391 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025 Resemble AI
|
2 |
+
# MIT License
|
3 |
+
import logging
|
4 |
+
from typing import Union, Optional, List
|
5 |
+
|
6 |
+
logger = logging.getLogger(__name__)
|
7 |
+
|
8 |
+
from tqdm import tqdm
|
9 |
+
import torch
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from torch import nn, Tensor
|
12 |
+
from transformers import LlamaModel, LlamaConfig
|
13 |
+
from transformers.generation.logits_process import TopPLogitsWarper, RepetitionPenaltyLogitsProcessor, MinPLogitsWarper
|
14 |
+
|
15 |
+
from .modules.learned_pos_emb import LearnedPositionEmbeddings
|
16 |
+
|
17 |
+
from .modules.cond_enc import T3CondEnc, T3Cond
|
18 |
+
from .modules.t3_config import T3Config
|
19 |
+
from .llama_configs import LLAMA_CONFIGS
|
20 |
+
from .inference.t3_hf_backend import T3HuggingfaceBackend
|
21 |
+
from .inference.alignment_stream_analyzer import AlignmentStreamAnalyzer
|
22 |
+
from ..utils import AttrDict
|
23 |
+
|
24 |
+
|
25 |
+
logger = logging.getLogger(__name__)
|
26 |
+
|
27 |
+
|
28 |
+
def _ensure_BOT_EOT(text_tokens: Tensor, hp):
|
29 |
+
B = text_tokens.size(0)
|
30 |
+
assert (text_tokens == hp.start_text_token).int().sum() >= B, "missing start_text_token"
|
31 |
+
assert (text_tokens == hp.stop_text_token).int().sum() >= B, "missing stop_text_token"
|
32 |
+
|
33 |
+
|
34 |
+
class T3(nn.Module):
|
35 |
+
"""
|
36 |
+
Token-To-Token (T3) TTS model using huggingface transformer models as backbones,
|
37 |
+
* tokenization, including start / stop tokens are always added externally to this class
|
38 |
+
* conditioning data like CLAP, emotion, etc are all in a separate file for more modularity
|
39 |
+
* careful! this class assumes relative positional encoding -- with absolute PE, we would at
|
40 |
+
least want to reset the position to 0 when speech tokens begin, and optionally use a
|
41 |
+
different PE embedding space for speech.
|
42 |
+
"""
|
43 |
+
|
44 |
+
def __init__(self, hp=None):
|
45 |
+
if hp is None:
|
46 |
+
hp = T3Config.english_only() # Default to English-only config for backward compatibility
|
47 |
+
super().__init__()
|
48 |
+
self.hp = hp
|
49 |
+
self.cfg = LlamaConfig(**LLAMA_CONFIGS[hp.llama_config_name])
|
50 |
+
self.tfmr = LlamaModel(self.cfg)
|
51 |
+
self.dim = self.cfg.hidden_size
|
52 |
+
self.deepspeed_patch_applied = False
|
53 |
+
|
54 |
+
# conditioning / embedding
|
55 |
+
self.cond_enc = T3CondEnc(hp)
|
56 |
+
self.text_emb = nn.Embedding(hp.text_tokens_dict_size, self.dim)
|
57 |
+
self.speech_emb = nn.Embedding(hp.speech_tokens_dict_size, self.dim)
|
58 |
+
|
59 |
+
# custom position embedding
|
60 |
+
if hp.input_pos_emb == "learned":
|
61 |
+
max_text_seq_len = hp.max_text_tokens + 2
|
62 |
+
self.text_pos_emb = LearnedPositionEmbeddings(max_text_seq_len, self.dim)
|
63 |
+
|
64 |
+
max_mel_seq_len = hp.max_speech_tokens + 2 + 2
|
65 |
+
self.speech_pos_emb = LearnedPositionEmbeddings(max_mel_seq_len, self.dim)
|
66 |
+
|
67 |
+
# logit projection
|
68 |
+
self.text_head = nn.Linear(self.cfg.hidden_size, hp.text_tokens_dict_size, bias=False)
|
69 |
+
self.speech_head = nn.Linear(self.cfg.hidden_size, hp.speech_tokens_dict_size, bias=False)
|
70 |
+
self.compiled = False
|
71 |
+
|
72 |
+
@property
|
73 |
+
def device(self):
|
74 |
+
return self.speech_head.weight.device
|
75 |
+
|
76 |
+
def prepare_conditioning(self, t3_cond: T3Cond):
|
77 |
+
"""
|
78 |
+
Token cond data needs to be embedded, so that needs to be here instead of in `T3CondEnc`.
|
79 |
+
"""
|
80 |
+
if t3_cond.cond_prompt_speech_tokens is not None and t3_cond.cond_prompt_speech_emb is None:
|
81 |
+
t3_cond.cond_prompt_speech_emb = self.speech_emb(t3_cond.cond_prompt_speech_tokens) + \
|
82 |
+
self.speech_pos_emb(t3_cond.cond_prompt_speech_tokens)
|
83 |
+
return self.cond_enc(t3_cond) # (B, len_cond, dim)
|
84 |
+
|
85 |
+
def prepare_input_embeds(
|
86 |
+
self,
|
87 |
+
*,
|
88 |
+
t3_cond: T3Cond,
|
89 |
+
text_tokens: torch.LongTensor,
|
90 |
+
speech_tokens: torch.LongTensor,
|
91 |
+
cfg_weight: float = 0.0,
|
92 |
+
):
|
93 |
+
# prepare input embeddings (skip backbone tranformer embeddings)
|
94 |
+
cond_emb = self.prepare_conditioning(t3_cond) # (B, len_cond, dim)
|
95 |
+
text_emb = self.text_emb(text_tokens) # (B, len_text, dim)
|
96 |
+
if cfg_weight > 0.0:
|
97 |
+
text_emb[1].zero_() # CFG uncond
|
98 |
+
|
99 |
+
speech_emb = self.speech_emb(speech_tokens) # (B, len_speech, dim)
|
100 |
+
if self.hp.input_pos_emb == "learned":
|
101 |
+
text_emb = text_emb + self.text_pos_emb(text_tokens)
|
102 |
+
speech_emb = speech_emb + self.speech_pos_emb(speech_tokens)
|
103 |
+
len_cond = cond_emb.size(1)
|
104 |
+
|
105 |
+
if cond_emb.size(0) != text_emb.size(0):
|
106 |
+
cond_emb = cond_emb.expand(text_emb.size(0), -1, -1)
|
107 |
+
|
108 |
+
# concat
|
109 |
+
embeds = torch.stack([
|
110 |
+
torch.cat((ce, te, se))
|
111 |
+
for ce, te, se in zip(cond_emb, text_emb, speech_emb)
|
112 |
+
]) # (B, length, dim)
|
113 |
+
return embeds, len_cond
|
114 |
+
|
115 |
+
def forward(
|
116 |
+
self,
|
117 |
+
*,
|
118 |
+
t3_cond: T3Cond,
|
119 |
+
text_tokens: torch.LongTensor,
|
120 |
+
text_token_lens: torch.LongTensor,
|
121 |
+
speech_tokens: torch.LongTensor,
|
122 |
+
speech_token_lens: torch.LongTensor,
|
123 |
+
training=False,
|
124 |
+
):
|
125 |
+
_ensure_BOT_EOT(text_tokens, self.hp)
|
126 |
+
|
127 |
+
# prepare custom input embeds
|
128 |
+
embeds, len_cond = self.prepare_input_embeds(
|
129 |
+
t3_cond=t3_cond,
|
130 |
+
text_tokens=text_tokens,
|
131 |
+
speech_tokens=speech_tokens,
|
132 |
+
)
|
133 |
+
|
134 |
+
# backbone tranformer forward
|
135 |
+
tfmr_out = self.tfmr.forward(
|
136 |
+
input_ids=None,
|
137 |
+
# position_ids=position_ids, # TODO? ROPE should be fine?
|
138 |
+
inputs_embeds=embeds,
|
139 |
+
output_hidden_states=True,
|
140 |
+
return_dict=True,
|
141 |
+
use_cache=(not training),
|
142 |
+
)
|
143 |
+
hidden_states = tfmr_out.hidden_states[-1] # final tfmr layer output, (B, seq, dim)
|
144 |
+
|
145 |
+
# post-processing: splice out text and speech parts of hidden states
|
146 |
+
len_text = text_tokens.size(1)
|
147 |
+
len_speech = speech_tokens.size(1)
|
148 |
+
B, _, dim = hidden_states.shape
|
149 |
+
device, dtype = hidden_states.device, hidden_states.dtype
|
150 |
+
text_latents = torch.zeros(B, len_text, dim, dtype=dtype, device=device)
|
151 |
+
speech_latents = torch.zeros(B, len_speech, dim, dtype=dtype, device=device)
|
152 |
+
ttl, stl = text_token_lens, speech_token_lens
|
153 |
+
for i in range(B):
|
154 |
+
text_end = len_cond + ttl[i].item()
|
155 |
+
speech_start = len_cond + text_tokens.size(1)
|
156 |
+
speech_end = speech_start + stl[i].item()
|
157 |
+
text_latents[i, :ttl[i]] = hidden_states[i, len_cond:text_end]
|
158 |
+
speech_latents[i, :stl[i]] = hidden_states[i, speech_start:speech_end]
|
159 |
+
|
160 |
+
# logit projection
|
161 |
+
text_logits = self.text_head(text_latents)
|
162 |
+
speech_logits = self.speech_head(speech_latents)
|
163 |
+
|
164 |
+
return AttrDict(
|
165 |
+
text_logits=text_logits,
|
166 |
+
text_latents=text_latents,
|
167 |
+
speech_logits=speech_logits,
|
168 |
+
speech_latents=speech_latents,
|
169 |
+
hidden_states=hidden_states,
|
170 |
+
)
|
171 |
+
|
172 |
+
def loss(
|
173 |
+
self,
|
174 |
+
*,
|
175 |
+
t3_cond: T3Cond,
|
176 |
+
text_tokens: torch.LongTensor,
|
177 |
+
text_token_lens: torch.LongTensor,
|
178 |
+
speech_tokens: torch.LongTensor,
|
179 |
+
speech_token_lens: torch.LongTensor,
|
180 |
+
):
|
181 |
+
"training method"
|
182 |
+
len_text = text_tokens.size(1)
|
183 |
+
len_speech = speech_tokens.size(1)
|
184 |
+
assert len_text == text_token_lens.max()
|
185 |
+
assert len_speech == speech_token_lens.max()
|
186 |
+
|
187 |
+
out = self.forward(
|
188 |
+
t3_cond=t3_cond,
|
189 |
+
text_tokens=text_tokens,
|
190 |
+
text_token_lens=text_token_lens,
|
191 |
+
speech_tokens=speech_tokens,
|
192 |
+
speech_token_lens=speech_token_lens,
|
193 |
+
training=True,
|
194 |
+
) # (B, seq, vocab_size)
|
195 |
+
|
196 |
+
# Calc CCE losses
|
197 |
+
IGNORE_ID = -100
|
198 |
+
device = out.text_logits.device
|
199 |
+
mask_text = torch.arange(len_text, device=device)[None] >= text_token_lens[:, None] # (B, len_text)
|
200 |
+
mask_speech = torch.arange(len_speech, device=device)[None] >= speech_token_lens[:, None] # (B, len_speech)
|
201 |
+
masked_text = text_tokens.masked_fill(mask_text, IGNORE_ID)
|
202 |
+
masked_speech = speech_tokens.masked_fill(mask_speech, IGNORE_ID)
|
203 |
+
loss_text = F.cross_entropy(out.text_logits, masked_text, ignore_index=IGNORE_ID)
|
204 |
+
loss_speech = F.cross_entropy(out.speech_logits, masked_speech, ignore_index=IGNORE_ID)
|
205 |
+
|
206 |
+
return loss_text, loss_speech
|
207 |
+
|
208 |
+
@torch.inference_mode()
|
209 |
+
def inference(
|
210 |
+
self,
|
211 |
+
*,
|
212 |
+
t3_cond: T3Cond,
|
213 |
+
text_tokens: Tensor,
|
214 |
+
initial_speech_tokens: Optional[Tensor]=None,
|
215 |
+
|
216 |
+
# misc conditioning
|
217 |
+
prepend_prompt_speech_tokens: Optional[Tensor]=None,
|
218 |
+
|
219 |
+
# HF generate args
|
220 |
+
num_return_sequences=1,
|
221 |
+
max_new_tokens=None,
|
222 |
+
stop_on_eos=True,
|
223 |
+
do_sample=True,
|
224 |
+
temperature=0.8,
|
225 |
+
top_p=0.95,
|
226 |
+
min_p=0.05,
|
227 |
+
length_penalty=1.0,
|
228 |
+
repetition_penalty=1.2,
|
229 |
+
cfg_weight=0.5,
|
230 |
+
):
|
231 |
+
"""
|
232 |
+
Args:
|
233 |
+
text_tokens: a 1D (unbatched) or 2D (batched) tensor.
|
234 |
+
"""
|
235 |
+
# Validate / sanitize inputs
|
236 |
+
assert prepend_prompt_speech_tokens is None, "not implemented"
|
237 |
+
_ensure_BOT_EOT(text_tokens, self.hp)
|
238 |
+
text_tokens = torch.atleast_2d(text_tokens).to(dtype=torch.long, device=self.device)
|
239 |
+
|
240 |
+
# Default initial speech to a single start-of-speech token
|
241 |
+
if initial_speech_tokens is None:
|
242 |
+
initial_speech_tokens = self.hp.start_speech_token * torch.ones_like(text_tokens[:, :1])
|
243 |
+
|
244 |
+
# Prepare custom input embeds
|
245 |
+
embeds, len_cond = self.prepare_input_embeds(
|
246 |
+
t3_cond=t3_cond,
|
247 |
+
text_tokens=text_tokens,
|
248 |
+
speech_tokens=initial_speech_tokens,
|
249 |
+
cfg_weight=cfg_weight,
|
250 |
+
)
|
251 |
+
|
252 |
+
# In order to use the standard HF generate method, we need to extend some methods to inject our custom logic
|
253 |
+
# Note the llama-specific logic. Other tfmr types can be added later.
|
254 |
+
|
255 |
+
self.compiled = False
|
256 |
+
|
257 |
+
# TODO? synchronize the expensive compile function
|
258 |
+
# with self.compile_lock:
|
259 |
+
if not self.compiled:
|
260 |
+
alignment_stream_analyzer = AlignmentStreamAnalyzer(
|
261 |
+
self.tfmr,
|
262 |
+
None,
|
263 |
+
text_tokens_slice=(len_cond, len_cond + text_tokens.size(-1)),
|
264 |
+
alignment_layer_idx=9, # TODO: hparam or something?
|
265 |
+
eos_idx=self.hp.stop_speech_token,
|
266 |
+
)
|
267 |
+
assert alignment_stream_analyzer.eos_idx == self.hp.stop_speech_token
|
268 |
+
|
269 |
+
patched_model = T3HuggingfaceBackend(
|
270 |
+
config=self.cfg,
|
271 |
+
llama=self.tfmr,
|
272 |
+
speech_enc=self.speech_emb,
|
273 |
+
speech_head=self.speech_head,
|
274 |
+
alignment_stream_analyzer=alignment_stream_analyzer,
|
275 |
+
)
|
276 |
+
self.patched_model = patched_model
|
277 |
+
self.compiled = True
|
278 |
+
|
279 |
+
# # Run normal generate method, which calls our custom extended methods
|
280 |
+
# return self.patched_model.generate(
|
281 |
+
# inputs=initial_speech_tokens,
|
282 |
+
# decoder_cond=embeds,
|
283 |
+
# bos_token_id=self.hp.start_speech_token,
|
284 |
+
# eos_token_id=(self.hp.stop_speech_token if stop_on_eos else -1),
|
285 |
+
# pad_token_id=self.hp.stop_speech_token,
|
286 |
+
# max_new_tokens=max_new_tokens or self.hp.max_speech_tokens,
|
287 |
+
# num_return_sequences=num_return_sequences,
|
288 |
+
# temperature=temperature,
|
289 |
+
# min_p=min_p,
|
290 |
+
# length_penalty=length_penalty,
|
291 |
+
# repetition_penalty=repetition_penalty,
|
292 |
+
# do_sample=do_sample,
|
293 |
+
# # cache_implementation=None if not self.compiled else "static",
|
294 |
+
# )
|
295 |
+
|
296 |
+
device = embeds.device
|
297 |
+
|
298 |
+
bos_token = torch.tensor([[self.hp.start_speech_token]], dtype=torch.long, device=device)
|
299 |
+
bos_embed = self.speech_emb(bos_token) # shape: (B, 1, embed_dim)
|
300 |
+
bos_embed = bos_embed + self.speech_pos_emb.get_fixed_embedding(0)
|
301 |
+
|
302 |
+
# batch_size=2 for CFG
|
303 |
+
bos_embed = torch.cat([bos_embed, bos_embed])
|
304 |
+
|
305 |
+
# Combine condition and BOS token for the initial input
|
306 |
+
inputs_embeds = torch.cat([embeds, bos_embed], dim=1)
|
307 |
+
|
308 |
+
# Track generated token ids; start with the BOS token.
|
309 |
+
generated_ids = bos_token.clone()
|
310 |
+
predicted = [] # To store the predicted tokens
|
311 |
+
|
312 |
+
# Instantiate the logits processors.
|
313 |
+
top_p_warper = TopPLogitsWarper(top_p=top_p)
|
314 |
+
min_p_warper = MinPLogitsWarper(min_p=min_p)
|
315 |
+
top_p_warper = TopPLogitsWarper(top_p=top_p)
|
316 |
+
repetition_penalty_processor = RepetitionPenaltyLogitsProcessor(penalty=float(repetition_penalty))
|
317 |
+
|
318 |
+
# ---- Initial Forward Pass (no kv_cache yet) ----
|
319 |
+
output = self.patched_model(
|
320 |
+
inputs_embeds=inputs_embeds,
|
321 |
+
past_key_values=None,
|
322 |
+
use_cache=True,
|
323 |
+
output_attentions=True,
|
324 |
+
output_hidden_states=True,
|
325 |
+
return_dict=True,
|
326 |
+
)
|
327 |
+
# Initialize kv_cache with the full context.
|
328 |
+
past = output.past_key_values
|
329 |
+
|
330 |
+
# ---- Generation Loop using kv_cache ----
|
331 |
+
for i in tqdm(range(max_new_tokens), desc="Sampling", dynamic_ncols=True):
|
332 |
+
logits_step = output.logits[:, -1, :]
|
333 |
+
# CFG combine → (1, V)
|
334 |
+
cond = logits_step[0:1, :]
|
335 |
+
uncond = logits_step[1:2, :]
|
336 |
+
cfg = torch.as_tensor(cfg_weight, device=cond.device, dtype=cond.dtype)
|
337 |
+
logits = cond + cfg * (cond - uncond)
|
338 |
+
|
339 |
+
# Apply alignment stream analyzer integrity checks
|
340 |
+
if self.patched_model.alignment_stream_analyzer is not None:
|
341 |
+
if logits.dim() == 1: # guard in case something upstream squeezed
|
342 |
+
logits = logits.unsqueeze(0) # (1, V)
|
343 |
+
# Pass the last generated token for repetition tracking
|
344 |
+
last_token = generated_ids[0, -1].item() if len(generated_ids[0]) > 0 else None
|
345 |
+
logits = self.patched_model.alignment_stream_analyzer.step(logits, next_token=last_token) # (1, V)
|
346 |
+
|
347 |
+
# Apply repetition penalty
|
348 |
+
ids_for_proc = generated_ids[:1, ...] # batch = 1
|
349 |
+
logits = repetition_penalty_processor(ids_for_proc, logits) # expects (B,V)
|
350 |
+
|
351 |
+
# Apply temperature scaling.
|
352 |
+
if temperature != 1.0:
|
353 |
+
logits = logits / temperature
|
354 |
+
|
355 |
+
# Apply min_p and top_p filtering
|
356 |
+
logits = min_p_warper(ids_for_proc, logits)
|
357 |
+
logits = top_p_warper(ids_for_proc, logits)
|
358 |
+
|
359 |
+
# Convert logits to probabilities and sample the next token.
|
360 |
+
probs = torch.softmax(logits, dim=-1)
|
361 |
+
next_token = torch.multinomial(probs, num_samples=1) # shape: (B, 1)
|
362 |
+
|
363 |
+
predicted.append(next_token)
|
364 |
+
generated_ids = torch.cat([generated_ids, next_token], dim=1)
|
365 |
+
|
366 |
+
# Check for EOS token.
|
367 |
+
if next_token.view(-1) == self.hp.stop_speech_token:
|
368 |
+
logger.info(f"✅ EOS token detected! Stopping generation at step {i+1}")
|
369 |
+
break
|
370 |
+
|
371 |
+
# Get embedding for the new token.
|
372 |
+
next_token_embed = self.speech_emb(next_token)
|
373 |
+
next_token_embed = next_token_embed + self.speech_pos_emb.get_fixed_embedding(i + 1)
|
374 |
+
|
375 |
+
# For CFG
|
376 |
+
next_token_embed = torch.cat([next_token_embed, next_token_embed])
|
377 |
+
|
378 |
+
# Forward pass with only the new token and the cached past.
|
379 |
+
output = self.patched_model(
|
380 |
+
inputs_embeds=next_token_embed,
|
381 |
+
past_key_values=past,
|
382 |
+
output_attentions=True,
|
383 |
+
output_hidden_states=True,
|
384 |
+
return_dict=True,
|
385 |
+
)
|
386 |
+
# Update the kv_cache.
|
387 |
+
past = output.past_key_values
|
388 |
+
|
389 |
+
# Concatenate all predicted tokens along the sequence dimension.
|
390 |
+
predicted_tokens = torch.cat(predicted, dim=1) # shape: (B, num_tokens)
|
391 |
+
return predicted_tokens
|
src/chatterbox/models/tokenizers/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .tokenizer import EnTokenizer, MTLTokenizer
|
src/chatterbox/models/tokenizers/tokenizer.py
ADDED
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import json
|
3 |
+
import re
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from pathlib import Path
|
7 |
+
from unicodedata import category
|
8 |
+
from tokenizers import Tokenizer
|
9 |
+
from huggingface_hub import hf_hub_download
|
10 |
+
|
11 |
+
|
12 |
+
# Special tokens
|
13 |
+
SOT = "[START]"
|
14 |
+
EOT = "[STOP]"
|
15 |
+
UNK = "[UNK]"
|
16 |
+
SPACE = "[SPACE]"
|
17 |
+
SPECIAL_TOKENS = [SOT, EOT, UNK, SPACE, "[PAD]", "[SEP]", "[CLS]", "[MASK]"]
|
18 |
+
|
19 |
+
logger = logging.getLogger(__name__)
|
20 |
+
|
21 |
+
class EnTokenizer:
|
22 |
+
def __init__(self, vocab_file_path):
|
23 |
+
self.tokenizer: Tokenizer = Tokenizer.from_file(vocab_file_path)
|
24 |
+
self.check_vocabset_sot_eot()
|
25 |
+
|
26 |
+
def check_vocabset_sot_eot(self):
|
27 |
+
voc = self.tokenizer.get_vocab()
|
28 |
+
assert SOT in voc
|
29 |
+
assert EOT in voc
|
30 |
+
|
31 |
+
def text_to_tokens(self, text: str):
|
32 |
+
text_tokens = self.encode(text)
|
33 |
+
text_tokens = torch.IntTensor(text_tokens).unsqueeze(0)
|
34 |
+
return text_tokens
|
35 |
+
|
36 |
+
def encode( self, txt: str, verbose=False):
|
37 |
+
"""
|
38 |
+
clean_text > (append `lang_id`) > replace SPACE > encode text using Tokenizer
|
39 |
+
"""
|
40 |
+
txt = txt.replace(' ', SPACE)
|
41 |
+
code = self.tokenizer.encode(txt)
|
42 |
+
ids = code.ids
|
43 |
+
return ids
|
44 |
+
|
45 |
+
def decode(self, seq):
|
46 |
+
if isinstance(seq, torch.Tensor):
|
47 |
+
seq = seq.cpu().numpy()
|
48 |
+
|
49 |
+
txt: str = self.tokenizer.decode(seq,
|
50 |
+
skip_special_tokens=False)
|
51 |
+
txt = txt.replace(' ', '')
|
52 |
+
txt = txt.replace(SPACE, ' ')
|
53 |
+
txt = txt.replace(EOT, '')
|
54 |
+
txt = txt.replace(UNK, '')
|
55 |
+
return txt
|
56 |
+
|
57 |
+
|
58 |
+
# Model repository
|
59 |
+
REPO_ID = "ResembleAI/chatterbox-multilingual"
|
60 |
+
|
61 |
+
# Global instances for optional dependencies
|
62 |
+
_kakasi = None
|
63 |
+
_dicta = None
|
64 |
+
|
65 |
+
|
66 |
+
def is_kanji(c: str) -> bool:
|
67 |
+
"""Check if character is kanji."""
|
68 |
+
return 19968 <= ord(c) <= 40959
|
69 |
+
|
70 |
+
|
71 |
+
def is_katakana(c: str) -> bool:
|
72 |
+
"""Check if character is katakana."""
|
73 |
+
return 12449 <= ord(c) <= 12538
|
74 |
+
|
75 |
+
|
76 |
+
def hiragana_normalize(text: str) -> str:
|
77 |
+
"""Japanese text normalization: converts kanji to hiragana; katakana remains the same."""
|
78 |
+
global _kakasi
|
79 |
+
|
80 |
+
try:
|
81 |
+
if _kakasi is None:
|
82 |
+
import pykakasi
|
83 |
+
_kakasi = pykakasi.kakasi()
|
84 |
+
|
85 |
+
result = _kakasi.convert(text)
|
86 |
+
out = []
|
87 |
+
|
88 |
+
for r in result:
|
89 |
+
inp = r['orig']
|
90 |
+
hira = r["hira"]
|
91 |
+
|
92 |
+
# Any kanji in the phrase
|
93 |
+
if any([is_kanji(c) for c in inp]):
|
94 |
+
if hira and hira[0] in ["は", "へ"]: # Safety check for empty hira
|
95 |
+
hira = " " + hira
|
96 |
+
out.append(hira)
|
97 |
+
|
98 |
+
# All katakana
|
99 |
+
elif all([is_katakana(c) for c in inp]) if inp else False: # Safety check for empty inp
|
100 |
+
out.append(r['orig'])
|
101 |
+
|
102 |
+
else:
|
103 |
+
out.append(inp)
|
104 |
+
|
105 |
+
normalized_text = "".join(out)
|
106 |
+
|
107 |
+
# Decompose Japanese characters for tokenizer compatibility
|
108 |
+
import unicodedata
|
109 |
+
normalized_text = unicodedata.normalize('NFKD', normalized_text)
|
110 |
+
|
111 |
+
return normalized_text
|
112 |
+
|
113 |
+
except ImportError:
|
114 |
+
logger.warning("pykakasi not available - Japanese text processing skipped")
|
115 |
+
return text
|
116 |
+
|
117 |
+
|
118 |
+
def add_hebrew_diacritics(text: str) -> str:
|
119 |
+
"""Hebrew text normalization: adds diacritics to Hebrew text."""
|
120 |
+
global _dicta
|
121 |
+
|
122 |
+
try:
|
123 |
+
if _dicta is None:
|
124 |
+
from dicta_onnx import Dicta
|
125 |
+
_dicta = Dicta()
|
126 |
+
|
127 |
+
return _dicta.add_diacritics(text)
|
128 |
+
|
129 |
+
except ImportError:
|
130 |
+
logger.warning("dicta_onnx not available - Hebrew text processing skipped")
|
131 |
+
return text
|
132 |
+
except Exception as e:
|
133 |
+
logger.warning(f"Hebrew diacritization failed: {e}")
|
134 |
+
return text
|
135 |
+
|
136 |
+
|
137 |
+
def korean_normalize(text: str) -> str:
|
138 |
+
"""Korean text normalization: decompose syllables into Jamo for tokenization."""
|
139 |
+
|
140 |
+
def decompose_hangul(char):
|
141 |
+
"""Decompose Korean syllable into Jamo components."""
|
142 |
+
if not ('\uac00' <= char <= '\ud7af'):
|
143 |
+
return char
|
144 |
+
|
145 |
+
# Hangul decomposition formula
|
146 |
+
base = ord(char) - 0xAC00
|
147 |
+
initial = chr(0x1100 + base // (21 * 28))
|
148 |
+
medial = chr(0x1161 + (base % (21 * 28)) // 28)
|
149 |
+
final = chr(0x11A7 + base % 28) if base % 28 > 0 else ''
|
150 |
+
|
151 |
+
return initial + medial + final
|
152 |
+
|
153 |
+
# Decompose syllables and normalize punctuation
|
154 |
+
result = ''.join(decompose_hangul(char) for char in text)
|
155 |
+
result = re.sub(r'[…~?!,:;()「」『』]', '.', result) # Korean punctuation
|
156 |
+
|
157 |
+
return result.strip()
|
158 |
+
|
159 |
+
|
160 |
+
class ChineseCangjieConverter:
|
161 |
+
"""Converts Chinese characters to Cangjie codes for tokenization."""
|
162 |
+
|
163 |
+
def __init__(self, model_dir=None):
|
164 |
+
self.word2cj = {}
|
165 |
+
self.cj2word = {}
|
166 |
+
self.segmenter = None
|
167 |
+
self._load_cangjie_mapping(model_dir)
|
168 |
+
self._init_segmenter()
|
169 |
+
|
170 |
+
def _load_cangjie_mapping(self, model_dir=None):
|
171 |
+
"""Load Cangjie mapping from HuggingFace model repository."""
|
172 |
+
try:
|
173 |
+
cangjie_file = hf_hub_download(
|
174 |
+
repo_id=REPO_ID,
|
175 |
+
filename="Cangjie5_TC.json",
|
176 |
+
cache_dir=model_dir
|
177 |
+
)
|
178 |
+
|
179 |
+
with open(cangjie_file, "r", encoding="utf-8") as fp:
|
180 |
+
data = json.load(fp)
|
181 |
+
|
182 |
+
for entry in data:
|
183 |
+
word, code = entry.split("\t")[:2]
|
184 |
+
self.word2cj[word] = code
|
185 |
+
if code not in self.cj2word:
|
186 |
+
self.cj2word[code] = [word]
|
187 |
+
else:
|
188 |
+
self.cj2word[code].append(word)
|
189 |
+
|
190 |
+
except Exception as e:
|
191 |
+
logger.warning(f"Could not load Cangjie mapping: {e}")
|
192 |
+
|
193 |
+
def _init_segmenter(self):
|
194 |
+
"""Initialize pkuseg segmenter."""
|
195 |
+
try:
|
196 |
+
from pkuseg import pkuseg
|
197 |
+
self.segmenter = pkuseg()
|
198 |
+
except ImportError:
|
199 |
+
logger.warning("pkuseg not available - Chinese segmentation will be skipped")
|
200 |
+
self.segmenter = None
|
201 |
+
|
202 |
+
def _cangjie_encode(self, glyph: str):
|
203 |
+
"""Encode a single Chinese glyph to Cangjie code."""
|
204 |
+
code = self.word2cj.get(glyph, None)
|
205 |
+
if code is None:
|
206 |
+
return None
|
207 |
+
|
208 |
+
index = self.cj2word[code].index(glyph)
|
209 |
+
index_suffix = str(index) if index > 0 else ""
|
210 |
+
return code + index_suffix
|
211 |
+
|
212 |
+
def _normalize_numbers(self, text: str) -> str:
|
213 |
+
"""Convert Arabic numerals (1-99) to Chinese characters."""
|
214 |
+
digit_map = {'0': '零', '1': '一', '2': '二', '3': '三', '4': '四',
|
215 |
+
'5': '五', '6': '六', '7': '七', '8': '八', '9': '九'}
|
216 |
+
|
217 |
+
pattern = re.compile(r'(?<!\d)(\d{1,2})(?!\d)')
|
218 |
+
|
219 |
+
def convert_number(match):
|
220 |
+
num = int(match.group(1))
|
221 |
+
|
222 |
+
if num == 0:
|
223 |
+
return '零'
|
224 |
+
elif 1 <= num <= 9:
|
225 |
+
return digit_map[str(num)]
|
226 |
+
elif num == 10:
|
227 |
+
return '十'
|
228 |
+
elif 11 <= num <= 19:
|
229 |
+
return '十' + digit_map[str(num % 10)]
|
230 |
+
elif 20 <= num <= 99:
|
231 |
+
tens, ones = divmod(num, 10)
|
232 |
+
if ones == 0:
|
233 |
+
return digit_map[str(tens)] + '十'
|
234 |
+
else:
|
235 |
+
return digit_map[str(tens)] + '十' + digit_map[str(ones)]
|
236 |
+
else:
|
237 |
+
return match.group(1)
|
238 |
+
|
239 |
+
return pattern.sub(convert_number, text)
|
240 |
+
|
241 |
+
def convert_chinese_text(self, text: str) -> str:
|
242 |
+
"""Convert Chinese characters in text to Cangjie tokens."""
|
243 |
+
text = re.sub('[、,:;〜-()⦅⦆]', ',', text)
|
244 |
+
text = re.sub('(。|…)', '.', text)
|
245 |
+
text = self._normalize_numbers(text)
|
246 |
+
|
247 |
+
# Skip segmentation for simple sequences (numbers, punctuation, short phrases)
|
248 |
+
if self.segmenter is not None:
|
249 |
+
# This avoids over-segmenting number sequences like "一, 二, 三"
|
250 |
+
is_simple_sequence = (
|
251 |
+
len([c for c in text if category(c) == "Lo"]) <= 15 and # Max 15 Chinese chars
|
252 |
+
text.count(',') >= 2 # Contains multiple commas (likely enumeration)
|
253 |
+
)
|
254 |
+
|
255 |
+
# Only segment complex Chinese text (longer sentences without enumeration patterns)
|
256 |
+
if not is_simple_sequence and len(text) > 10:
|
257 |
+
chinese_chars = sum(1 for c in text if category(c) == "Lo")
|
258 |
+
total_chars = len([c for c in text if c.strip()])
|
259 |
+
|
260 |
+
if chinese_chars > 5 and chinese_chars / total_chars > 0.7:
|
261 |
+
segmented_words = self.segmenter.cut(text)
|
262 |
+
text = " ".join(segmented_words)
|
263 |
+
|
264 |
+
output = []
|
265 |
+
for char in text:
|
266 |
+
if category(char) == "Lo": # Chinese character
|
267 |
+
cangjie = self._cangjie_encode(char)
|
268 |
+
if cangjie is None:
|
269 |
+
output.append(char)
|
270 |
+
continue
|
271 |
+
|
272 |
+
code_tokens = [f"[cj_{c}]" for c in cangjie]
|
273 |
+
code_tokens.append("[cj_.]")
|
274 |
+
|
275 |
+
output.append("".join(code_tokens))
|
276 |
+
else:
|
277 |
+
output.append(char)
|
278 |
+
|
279 |
+
return "".join(output)
|
280 |
+
|
281 |
+
|
282 |
+
class MTLTokenizer:
|
283 |
+
def __init__(self, vocab_file_path):
|
284 |
+
self.tokenizer: Tokenizer = Tokenizer.from_file(vocab_file_path)
|
285 |
+
model_dir = Path(vocab_file_path).parent
|
286 |
+
self.cangjie_converter = ChineseCangjieConverter(model_dir)
|
287 |
+
self.check_vocabset_sot_eot()
|
288 |
+
|
289 |
+
def check_vocabset_sot_eot(self):
|
290 |
+
voc = self.tokenizer.get_vocab()
|
291 |
+
assert SOT in voc
|
292 |
+
assert EOT in voc
|
293 |
+
|
294 |
+
def text_to_tokens(self, text: str, language_id: str = None):
|
295 |
+
text_tokens = self.encode(text, language_id=language_id)
|
296 |
+
text_tokens = torch.IntTensor(text_tokens).unsqueeze(0)
|
297 |
+
return text_tokens
|
298 |
+
|
299 |
+
def encode(self, txt: str, language_id: str = None):
|
300 |
+
# Language-specific text processing
|
301 |
+
if language_id == 'zh':
|
302 |
+
txt = self.cangjie_converter.convert_chinese_text(txt)
|
303 |
+
elif language_id == 'ja':
|
304 |
+
txt = hiragana_normalize(txt)
|
305 |
+
elif language_id == 'he':
|
306 |
+
txt = add_hebrew_diacritics(txt)
|
307 |
+
elif language_id == 'ko':
|
308 |
+
txt = korean_normalize(txt)
|
309 |
+
|
310 |
+
# Prepend language token
|
311 |
+
if language_id:
|
312 |
+
txt = f"[{language_id.lower()}]{txt}"
|
313 |
+
|
314 |
+
txt = txt.replace(' ', SPACE)
|
315 |
+
return self.tokenizer.encode(txt).ids
|
316 |
+
|
317 |
+
def decode(self, seq):
|
318 |
+
if isinstance(seq, torch.Tensor):
|
319 |
+
seq = seq.cpu().numpy()
|
320 |
+
|
321 |
+
txt = self.tokenizer.decode(seq, skip_special_tokens=False)
|
322 |
+
txt = txt.replace(' ', '').replace(SPACE, ' ').replace(EOT, '').replace(UNK, '')
|
323 |
+
return txt
|
src/chatterbox/models/utils.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class AttrDict(dict):
|
2 |
+
def __init__(self, *args, **kwargs):
|
3 |
+
super(AttrDict, self).__init__(*args, **kwargs)
|
4 |
+
self.__dict__ = self
|
src/chatterbox/models/voice_encoder/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .voice_encoder import VoiceEncoder, VoiceEncConfig
|
src/chatterbox/models/voice_encoder/config.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class VoiceEncConfig:
|
2 |
+
num_mels = 40
|
3 |
+
sample_rate = 16000
|
4 |
+
speaker_embed_size = 256
|
5 |
+
ve_hidden_size = 256
|
6 |
+
flatten_lstm_params = False
|
7 |
+
n_fft = 400
|
8 |
+
hop_size = 160
|
9 |
+
win_size = 400
|
10 |
+
fmax = 8000
|
11 |
+
fmin = 0
|
12 |
+
preemphasis = 0.
|
13 |
+
mel_power = 2.0
|
14 |
+
mel_type = "amp"
|
15 |
+
normalized_mels = False
|
16 |
+
ve_partial_frames = 160
|
17 |
+
ve_final_relu = True
|
18 |
+
stft_magnitude_min = 1e-4
|
src/chatterbox/models/voice_encoder/melspec.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import lru_cache
|
2 |
+
|
3 |
+
from scipy import signal
|
4 |
+
import numpy as np
|
5 |
+
import librosa
|
6 |
+
|
7 |
+
|
8 |
+
@lru_cache()
|
9 |
+
def mel_basis(hp):
|
10 |
+
assert hp.fmax <= hp.sample_rate // 2
|
11 |
+
return librosa.filters.mel(
|
12 |
+
sr=hp.sample_rate,
|
13 |
+
n_fft=hp.n_fft,
|
14 |
+
n_mels=hp.num_mels,
|
15 |
+
fmin=hp.fmin,
|
16 |
+
fmax=hp.fmax) # -> (nmel, nfreq)
|
17 |
+
|
18 |
+
|
19 |
+
def preemphasis(wav, hp):
|
20 |
+
assert hp.preemphasis != 0
|
21 |
+
wav = signal.lfilter([1, -hp.preemphasis], [1], wav)
|
22 |
+
wav = np.clip(wav, -1, 1)
|
23 |
+
return wav
|
24 |
+
|
25 |
+
|
26 |
+
def melspectrogram(wav, hp, pad=True):
|
27 |
+
# Run through pre-emphasis
|
28 |
+
if hp.preemphasis > 0:
|
29 |
+
wav = preemphasis(wav, hp)
|
30 |
+
assert np.abs(wav).max() - 1 < 1e-07
|
31 |
+
|
32 |
+
# Do the stft
|
33 |
+
spec_complex = _stft(wav, hp, pad=pad)
|
34 |
+
|
35 |
+
# Get the magnitudes
|
36 |
+
spec_magnitudes = np.abs(spec_complex)
|
37 |
+
|
38 |
+
if hp.mel_power != 1.0:
|
39 |
+
spec_magnitudes **= hp.mel_power
|
40 |
+
|
41 |
+
# Get the mel and convert magnitudes->db
|
42 |
+
mel = np.dot(mel_basis(hp), spec_magnitudes)
|
43 |
+
if hp.mel_type == "db":
|
44 |
+
mel = _amp_to_db(mel, hp)
|
45 |
+
|
46 |
+
# Normalise the mel from db to 0,1
|
47 |
+
if hp.normalized_mels:
|
48 |
+
mel = _normalize(mel, hp).astype(np.float32)
|
49 |
+
|
50 |
+
assert not pad or mel.shape[1] == 1 + len(wav) // hp.hop_size # Sanity check
|
51 |
+
return mel # (M, T)
|
52 |
+
|
53 |
+
|
54 |
+
def _stft(y, hp, pad=True):
|
55 |
+
# NOTE: after 0.8, pad mode defaults to constant, setting this to reflect for
|
56 |
+
# historical consistency and streaming-version consistency
|
57 |
+
return librosa.stft(
|
58 |
+
y,
|
59 |
+
n_fft=hp.n_fft,
|
60 |
+
hop_length=hp.hop_size,
|
61 |
+
win_length=hp.win_size,
|
62 |
+
center=pad,
|
63 |
+
pad_mode="reflect",
|
64 |
+
)
|
65 |
+
|
66 |
+
|
67 |
+
def _amp_to_db(x, hp):
|
68 |
+
return 20 * np.log10(np.maximum(hp.stft_magnitude_min, x))
|
69 |
+
|
70 |
+
|
71 |
+
def _db_to_amp(x):
|
72 |
+
return np.power(10.0, x * 0.05)
|
73 |
+
|
74 |
+
|
75 |
+
def _normalize(s, hp, headroom_db=15):
|
76 |
+
min_level_db = 20 * np.log10(hp.stft_magnitude_min)
|
77 |
+
s = (s - min_level_db) / (-min_level_db + headroom_db)
|
78 |
+
return s
|
src/chatterbox/models/voice_encoder/voice_encoder.py
ADDED
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/CorentinJ/Real-Time-Voice-Cloning
|
2 |
+
# MIT License
|
3 |
+
from typing import List, Union, Optional
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
from numpy.lib.stride_tricks import as_strided
|
7 |
+
import librosa
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from torch import nn, Tensor
|
11 |
+
|
12 |
+
from .config import VoiceEncConfig
|
13 |
+
from .melspec import melspectrogram
|
14 |
+
|
15 |
+
|
16 |
+
def pack(arrays, seq_len: int=None, pad_value=0):
|
17 |
+
"""
|
18 |
+
Given a list of length B of array-like objects of shapes (Ti, ...), packs them in a single tensor of
|
19 |
+
shape (B, T, ...) by padding each individual array on the right.
|
20 |
+
|
21 |
+
:param arrays: a list of array-like objects of matching shapes except for the first axis.
|
22 |
+
:param seq_len: the value of T. It must be the maximum of the lengths Ti of the arrays at
|
23 |
+
minimum. Will default to that value if None.
|
24 |
+
:param pad_value: the value to pad the arrays with.
|
25 |
+
:return: a (B, T, ...) tensor
|
26 |
+
"""
|
27 |
+
if seq_len is None:
|
28 |
+
seq_len = max(len(array) for array in arrays)
|
29 |
+
else:
|
30 |
+
assert seq_len >= max(len(array) for array in arrays)
|
31 |
+
|
32 |
+
# Convert lists to np.array
|
33 |
+
if isinstance(arrays[0], list):
|
34 |
+
arrays = [np.array(array) for array in arrays]
|
35 |
+
|
36 |
+
# Convert to tensor and handle device
|
37 |
+
device = None
|
38 |
+
if isinstance(arrays[0], torch.Tensor):
|
39 |
+
tensors = arrays
|
40 |
+
device = tensors[0].device
|
41 |
+
else:
|
42 |
+
tensors = [torch.as_tensor(array) for array in arrays]
|
43 |
+
|
44 |
+
# Fill the packed tensor with the array data
|
45 |
+
packed_shape = (len(tensors), seq_len, *tensors[0].shape[1:])
|
46 |
+
packed_tensor = torch.full(packed_shape, pad_value, dtype=tensors[0].dtype, device=device)
|
47 |
+
|
48 |
+
for i, tensor in enumerate(tensors):
|
49 |
+
packed_tensor[i, :tensor.size(0)] = tensor
|
50 |
+
|
51 |
+
return packed_tensor
|
52 |
+
|
53 |
+
|
54 |
+
def get_num_wins(
|
55 |
+
n_frames: int,
|
56 |
+
step: int,
|
57 |
+
min_coverage: float,
|
58 |
+
hp: VoiceEncConfig,
|
59 |
+
):
|
60 |
+
assert n_frames > 0
|
61 |
+
win_size = hp.ve_partial_frames
|
62 |
+
n_wins, remainder = divmod(max(n_frames - win_size + step, 0), step)
|
63 |
+
if n_wins == 0 or (remainder + (win_size - step)) / win_size >= min_coverage:
|
64 |
+
n_wins += 1
|
65 |
+
target_n = win_size + step * (n_wins - 1)
|
66 |
+
return n_wins, target_n
|
67 |
+
|
68 |
+
|
69 |
+
def get_frame_step(
|
70 |
+
overlap: float,
|
71 |
+
rate: float,
|
72 |
+
hp: VoiceEncConfig,
|
73 |
+
):
|
74 |
+
# Compute how many frames separate two partial utterances
|
75 |
+
assert 0 <= overlap < 1
|
76 |
+
if rate is None:
|
77 |
+
frame_step = int(np.round(hp.ve_partial_frames * (1 - overlap)))
|
78 |
+
else:
|
79 |
+
frame_step = int(np.round((hp.sample_rate / rate) / hp.ve_partial_frames))
|
80 |
+
assert 0 < frame_step <= hp.ve_partial_frames
|
81 |
+
return frame_step
|
82 |
+
|
83 |
+
|
84 |
+
def stride_as_partials(
|
85 |
+
mel: np.ndarray,
|
86 |
+
hp: VoiceEncConfig,
|
87 |
+
overlap=0.5,
|
88 |
+
rate: float=None,
|
89 |
+
min_coverage=0.8,
|
90 |
+
):
|
91 |
+
"""
|
92 |
+
Takes unscaled mels in (T, M) format
|
93 |
+
TODO: doc
|
94 |
+
"""
|
95 |
+
assert 0 < min_coverage <= 1
|
96 |
+
frame_step = get_frame_step(overlap, rate, hp)
|
97 |
+
|
98 |
+
# Compute how many partials can fit in the mel
|
99 |
+
n_partials, target_len = get_num_wins(len(mel), frame_step, min_coverage, hp)
|
100 |
+
|
101 |
+
# Trim or pad the mel spectrogram to match the number of partials
|
102 |
+
if target_len > len(mel):
|
103 |
+
mel = np.concatenate((mel, np.full((target_len - len(mel), hp.num_mels), 0)))
|
104 |
+
elif target_len < len(mel):
|
105 |
+
mel = mel[:target_len]
|
106 |
+
|
107 |
+
# Ensure the numpy array data is float32 and contiguous in memory
|
108 |
+
mel = mel.astype(np.float32, order="C")
|
109 |
+
|
110 |
+
# Re-arrange the array in memory to be of shape (N, P, M) with partials overlapping eachother,
|
111 |
+
# where N is the number of partials, P is the number of frames of each partial and M the
|
112 |
+
# number of channels of the mel spectrograms.
|
113 |
+
shape = (n_partials, hp.ve_partial_frames, hp.num_mels)
|
114 |
+
strides = (mel.strides[0] * frame_step, mel.strides[0], mel.strides[1])
|
115 |
+
partials = as_strided(mel, shape, strides)
|
116 |
+
return partials
|
117 |
+
|
118 |
+
|
119 |
+
class VoiceEncoder(nn.Module):
|
120 |
+
def __init__(self, hp=VoiceEncConfig()):
|
121 |
+
super().__init__()
|
122 |
+
|
123 |
+
self.hp = hp
|
124 |
+
|
125 |
+
# Network definition
|
126 |
+
self.lstm = nn.LSTM(self.hp.num_mels, self.hp.ve_hidden_size, num_layers=3, batch_first=True)
|
127 |
+
if hp.flatten_lstm_params:
|
128 |
+
self.lstm.flatten_parameters()
|
129 |
+
self.proj = nn.Linear(self.hp.ve_hidden_size, self.hp.speaker_embed_size)
|
130 |
+
|
131 |
+
# Cosine similarity scaling (fixed initial parameter values)
|
132 |
+
self.similarity_weight = nn.Parameter(torch.tensor([10.]), requires_grad=True)
|
133 |
+
self.similarity_bias = nn.Parameter(torch.tensor([-5.]), requires_grad=True)
|
134 |
+
|
135 |
+
@property
|
136 |
+
def device(self):
|
137 |
+
return next(self.parameters()).device
|
138 |
+
|
139 |
+
def forward(self, mels: torch.FloatTensor):
|
140 |
+
"""
|
141 |
+
Computes the embeddings of a batch of partial utterances.
|
142 |
+
|
143 |
+
:param mels: a batch of unscaled mel spectrograms of same duration as a float32 tensor
|
144 |
+
of shape (B, T, M) where T is hp.ve_partial_frames
|
145 |
+
:return: the embeddings as a float32 tensor of shape (B, E) where E is
|
146 |
+
hp.speaker_embed_size. Embeddings are L2-normed and thus lay in the range [-1, 1].
|
147 |
+
"""
|
148 |
+
if self.hp.normalized_mels and (mels.min() < 0 or mels.max() > 1):
|
149 |
+
raise Exception(f"Mels outside [0, 1]. Min={mels.min()}, Max={mels.max()}")
|
150 |
+
|
151 |
+
# Pass the input through the LSTM layers
|
152 |
+
_, (hidden, _) = self.lstm(mels)
|
153 |
+
|
154 |
+
# Project the final hidden state
|
155 |
+
raw_embeds = self.proj(hidden[-1])
|
156 |
+
if self.hp.ve_final_relu:
|
157 |
+
raw_embeds = F.relu(raw_embeds)
|
158 |
+
|
159 |
+
# L2 normalize the embeddings.
|
160 |
+
return raw_embeds / torch.linalg.norm(raw_embeds, dim=1, keepdim=True)
|
161 |
+
|
162 |
+
def inference(self, mels: torch.Tensor, mel_lens, overlap=0.5, rate: float=None, min_coverage=0.8, batch_size=None):
|
163 |
+
"""
|
164 |
+
Computes the embeddings of a batch of full utterances with gradients.
|
165 |
+
|
166 |
+
:param mels: (B, T, M) unscaled mels
|
167 |
+
:return: (B, E) embeddings on CPU
|
168 |
+
"""
|
169 |
+
mel_lens = mel_lens.tolist() if torch.is_tensor(mel_lens) else mel_lens
|
170 |
+
|
171 |
+
# Compute where to split the utterances into partials
|
172 |
+
frame_step = get_frame_step(overlap, rate, self.hp)
|
173 |
+
n_partials, target_lens = zip(*(get_num_wins(l, frame_step, min_coverage, self.hp) for l in mel_lens))
|
174 |
+
|
175 |
+
# Possibly pad the mels to reach the target lengths
|
176 |
+
len_diff = max(target_lens) - mels.size(1)
|
177 |
+
if len_diff > 0:
|
178 |
+
pad = torch.full((mels.size(0), len_diff, self.hp.num_mels), 0, dtype=torch.float32)
|
179 |
+
mels = torch.cat((mels, pad.to(mels.device)), dim=1)
|
180 |
+
|
181 |
+
# Group all partials together so that we can batch them easily
|
182 |
+
partials = [
|
183 |
+
mel[i * frame_step: i * frame_step + self.hp.ve_partial_frames]
|
184 |
+
for mel, n_partial in zip(mels, n_partials) for i in range(n_partial)
|
185 |
+
]
|
186 |
+
assert all(partials[0].shape == partial.shape for partial in partials)
|
187 |
+
partials = torch.stack(partials)
|
188 |
+
|
189 |
+
# Forward the partials
|
190 |
+
n_chunks = int(np.ceil(len(partials) / (batch_size or len(partials))))
|
191 |
+
partial_embeds = torch.cat([self(batch) for batch in partials.chunk(n_chunks)], dim=0).cpu()
|
192 |
+
|
193 |
+
# Reduce the partial embeds into full embeds and L2-normalize them
|
194 |
+
slices = np.concatenate(([0], np.cumsum(n_partials)))
|
195 |
+
raw_embeds = [torch.mean(partial_embeds[start:end], dim=0) for start, end in zip(slices[:-1], slices[1:])]
|
196 |
+
raw_embeds = torch.stack(raw_embeds)
|
197 |
+
embeds = raw_embeds / torch.linalg.norm(raw_embeds, dim=1, keepdim=True)
|
198 |
+
|
199 |
+
return embeds
|
200 |
+
|
201 |
+
@staticmethod
|
202 |
+
def utt_to_spk_embed(utt_embeds: np.ndarray):
|
203 |
+
"""
|
204 |
+
Takes an array of L2-normalized utterance embeddings, computes the mean embedding and L2-normalize it to get a
|
205 |
+
speaker embedding.
|
206 |
+
"""
|
207 |
+
assert utt_embeds.ndim == 2
|
208 |
+
utt_embeds = np.mean(utt_embeds, axis=0)
|
209 |
+
return utt_embeds / np.linalg.norm(utt_embeds, 2)
|
210 |
+
|
211 |
+
@staticmethod
|
212 |
+
def voice_similarity(embeds_x: np.ndarray, embeds_y: np.ndarray):
|
213 |
+
"""
|
214 |
+
Cosine similarity for L2-normalized utterance embeddings or speaker embeddings
|
215 |
+
"""
|
216 |
+
embeds_x = embeds_x if embeds_x.ndim == 1 else VoiceEncoder.utt_to_spk_embed(embeds_x)
|
217 |
+
embeds_y = embeds_y if embeds_y.ndim == 1 else VoiceEncoder.utt_to_spk_embed(embeds_y)
|
218 |
+
return embeds_x @ embeds_y
|
219 |
+
|
220 |
+
def embeds_from_mels(
|
221 |
+
self, mels: Union[Tensor, List[np.ndarray]], mel_lens=None, as_spk=False, batch_size=32, **kwargs
|
222 |
+
):
|
223 |
+
"""
|
224 |
+
Convenience function for deriving utterance or speaker embeddings from mel spectrograms.
|
225 |
+
|
226 |
+
:param mels: unscaled mels strictly within [0, 1] as either a (B, T, M) tensor or a list of (Ti, M) arrays.
|
227 |
+
:param mel_lens: if passing mels as a tensor, individual mel lengths
|
228 |
+
:param as_spk: whether to return utterance embeddings or a single speaker embedding
|
229 |
+
:param kwargs: args for inference()
|
230 |
+
|
231 |
+
:returns: embeds as a (B, E) float32 numpy array if <as_spk> is False, else as a (E,) array
|
232 |
+
"""
|
233 |
+
# Load mels in memory and pack them
|
234 |
+
if isinstance(mels, List):
|
235 |
+
mels = [np.asarray(mel) for mel in mels]
|
236 |
+
assert all(m.shape[1] == mels[0].shape[1] for m in mels), "Mels aren't in (B, T, M) format"
|
237 |
+
mel_lens = [mel.shape[0] for mel in mels]
|
238 |
+
mels = pack(mels)
|
239 |
+
|
240 |
+
# Embed them
|
241 |
+
with torch.inference_mode():
|
242 |
+
utt_embeds = self.inference(mels.to(self.device), mel_lens, batch_size=batch_size, **kwargs).numpy()
|
243 |
+
|
244 |
+
return self.utt_to_spk_embed(utt_embeds) if as_spk else utt_embeds
|
245 |
+
|
246 |
+
def embeds_from_wavs(
|
247 |
+
self,
|
248 |
+
wavs: List[np.ndarray],
|
249 |
+
sample_rate,
|
250 |
+
as_spk=False,
|
251 |
+
batch_size=32,
|
252 |
+
trim_top_db: Optional[float]=20,
|
253 |
+
**kwargs
|
254 |
+
):
|
255 |
+
"""
|
256 |
+
Wrapper around embeds_from_mels
|
257 |
+
|
258 |
+
:param trim_top_db: this argument was only added for the sake of compatibility with metavoice's implementation
|
259 |
+
"""
|
260 |
+
if sample_rate != self.hp.sample_rate:
|
261 |
+
wavs = [
|
262 |
+
librosa.resample(wav, orig_sr=sample_rate, target_sr=self.hp.sample_rate, res_type="kaiser_fast")
|
263 |
+
for wav in wavs
|
264 |
+
]
|
265 |
+
|
266 |
+
if trim_top_db:
|
267 |
+
wavs = [librosa.effects.trim(wav, top_db=trim_top_db)[0] for wav in wavs]
|
268 |
+
|
269 |
+
if "rate" not in kwargs:
|
270 |
+
kwargs["rate"] = 1.3 # Resemble's default value.
|
271 |
+
|
272 |
+
mels = [melspectrogram(w, self.hp).T for w in wavs]
|
273 |
+
|
274 |
+
return self.embeds_from_mels(mels, as_spk=as_spk, batch_size=batch_size, **kwargs)
|