진승환 committed
Commit 44a8c76 · 1 Parent(s): ee29f36
README.md CHANGED
@@ -1,3 +1,9 @@
+ ---
+ title: Spark-TTS
+ app_file: webui.py
+ sdk: gradio
+ sdk_version: 5.18.0
+ ---
  <div align="center">
  <h1>
  Spark-TTS
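Note: the added front matter registers the repository as a Gradio Space whose entry point is webui.py (Gradio SDK 5.18.0). The entry point itself is not part of this diff, so the sketch below is only a hypothetical illustration of the minimal shape such a file takes; the synthesize function is a placeholder, not the actual Spark-TTS pipeline.

    # Hypothetical sketch of a Gradio Space entry point; the real webui.py is not shown in this commit.
    import gradio as gr

    def synthesize(text: str) -> str:
        # Placeholder: a real app would run the Spark-TTS pipeline here and
        # return the path of the generated audio file.
        return "example/prompt_audio.wav"

    demo = gr.Interface(
        fn=synthesize,
        inputs=gr.Textbox(label="Text to synthesize"),
        outputs=gr.Audio(label="Generated speech"),
        title="Spark-TTS",
    )

    if __name__ == "__main__":
        demo.launch()  # Spaces run the module declared as app_file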
sparkTTS/.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
sparkTTS/README.md ADDED
@@ -0,0 +1,12 @@
+ ---
+ title: SparkTTS
+ emoji: 🐨
+ colorFrom: blue
+ colorTo: purple
+ sdk: gradio
+ sdk_version: 5.21.0
+ app_file: app.py
+ pinned: false
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
sparktts/models/audio_tokenizer.py DELETED
@@ -1,163 +0,0 @@
1
- # Copyright (c) 2025 SparkAudio
2
- # 2025 Xinsheng Wang ([email protected])
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
-
17
- import torch
18
- import numpy as np
19
-
20
- from pathlib import Path
21
- from typing import Any, Dict, Tuple
22
- from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model
23
-
24
- from sparktts.utils.file import load_config
25
- from sparktts.utils.audio import load_audio
26
- from sparktts.models.bicodec import BiCodec
27
-
28
-
29
- class BiCodecTokenizer:
30
- """BiCodec tokenizer for handling audio input and tokenization."""
31
-
32
- def __init__(self, model_dir: Path, device: torch.device = None, **kwargs):
33
- super().__init__()
34
- """
35
- Args:
36
- model_dir: Path to the model directory.
37
- device: Device to run the model on (default is GPU if available).
38
- """
39
- self.device = device
40
- self.model_dir = model_dir
41
- self.config = load_config(f"{model_dir}/config.yaml")
42
- self._initialize_model()
43
-
44
- def _initialize_model(self):
45
- """Load and initialize the BiCodec model and Wav2Vec2 feature extractor."""
46
- self.model = BiCodec.load_from_checkpoint(f"{self.model_dir}/BiCodec").to(
47
- self.device
48
- )
49
- self.processor = Wav2Vec2FeatureExtractor.from_pretrained(
50
- f"{self.model_dir}/wav2vec2-large-xlsr-53"
51
- )
52
- self.feature_extractor = Wav2Vec2Model.from_pretrained(
53
- f"{self.model_dir}/wav2vec2-large-xlsr-53"
54
- ).to(self.device)
55
- self.feature_extractor.config.output_hidden_states = True
56
-
57
- def get_ref_clip(self, wav: np.ndarray) -> np.ndarray:
58
- """Get reference audio clip for speaker embedding."""
59
- ref_segment_length = (
60
- int(self.config["sample_rate"] * self.config["ref_segment_duration"])
61
- // self.config["latent_hop_length"]
62
- * self.config["latent_hop_length"]
63
- )
64
- wav_length = len(wav)
65
-
66
- if ref_segment_length > wav_length:
67
- # Repeat and truncate to handle insufficient length
68
- wav = np.tile(wav, ref_segment_length // wav_length + 1)
69
-
70
- return wav[:ref_segment_length]
71
-
72
- def process_audio(self, wav_path: Path) -> Tuple[np.ndarray, torch.Tensor]:
73
- """load auido and get reference audio from wav path"""
74
- wav = load_audio(
75
- wav_path,
76
- sampling_rate=self.config["sample_rate"],
77
- volume_normalize=self.config["volume_normalize"],
78
- )
79
-
80
- wav_ref = self.get_ref_clip(wav)
81
-
82
- wav_ref = torch.from_numpy(wav_ref).unsqueeze(0).float()
83
- return wav, wav_ref
84
-
85
- def extract_wav2vec2_features(self, wavs: torch.Tensor) -> torch.Tensor:
86
- """extract wav2vec2 features"""
87
- inputs = self.processor(
88
- wavs,
89
- sampling_rate=16000,
90
- return_tensors="pt",
91
- padding=True,
92
- output_hidden_states=True,
93
- ).input_values
94
- feat = self.feature_extractor(inputs.to(self.feature_extractor.device))
95
- feats_mix = (
96
- feat.hidden_states[11] + feat.hidden_states[14] + feat.hidden_states[16]
97
- ) / 3
98
-
99
- return feats_mix
100
-
101
- def tokenize_batch(self, batch: Dict[str, Any]) -> torch.Tensor:
102
- """tokenize the batch of audio
103
-
104
- Args:
105
- batch:
106
- wavs (List[np.ndarray]): batch of audio
107
- ref_wavs (torch.Tensor): reference audio. shape: (batch_size, seq_len)
108
-
109
- Returns:
110
- semantic_tokens: semantic tokens. shape: (batch_size, seq_len, latent_dim)
111
- global_tokens: global tokens. shape: (batch_size, seq_len, global_dim)
112
- """
113
- feats = self.extract_wav2vec2_features(batch["wav"])
114
- batch["feat"] = feats
115
- semantic_tokens, global_tokens = self.model.tokenize(batch)
116
-
117
- return global_tokens, semantic_tokens
118
-
119
- def tokenize(self, audio_path: str) -> Tuple[torch.Tensor, torch.Tensor]:
120
- """tokenize the audio"""
121
- wav, ref_wav = self.process_audio(audio_path)
122
- feat = self.extract_wav2vec2_features(wav)
123
- batch = {
124
- "wav": torch.from_numpy(wav).unsqueeze(0).float().to(self.device),
125
- "ref_wav": ref_wav.to(self.device),
126
- "feat": feat.to(self.device),
127
- }
128
- semantic_tokens, global_tokens = self.model.tokenize(batch)
129
-
130
- return global_tokens, semantic_tokens
131
-
132
- def detokenize(
133
- self, global_tokens: torch.Tensor, semantic_tokens: torch.Tensor
134
- ) -> np.ndarray:
135
- """detokenize the tokens to waveform
136
-
137
- Args:
138
- global_tokens: global tokens. shape: (batch_size, global_dim)
139
- semantic_tokens: semantic tokens. shape: (batch_size, latent_dim)
140
-
141
- Returns:
142
- wav_rec: waveform. shape: (batch_size, seq_len) for batch or (seq_len,) for single
143
- """
144
- global_tokens = global_tokens.unsqueeze(1)
145
- wav_rec = self.model.detokenize(semantic_tokens, global_tokens)
146
- return wav_rec.detach().squeeze().cpu().numpy()
147
-
148
-
149
- # test
150
- if __name__ == "__main__":
151
- import soundfile as sf
152
-
153
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
154
- tokenizer = BiCodecTokenizer(
155
- model_dir="pretrained_models/Spark-TTS-0.5B",
156
- device=device,
157
- )
158
- wav_path = "example/prompt_audio.wav"
159
-
160
- global_tokens, semantic_tokens = tokenizer.tokenize(wav_path)
161
-
162
- wav_rec = tokenizer.detokenize(global_tokens.squeeze(0), semantic_tokens)
163
- sf.write("example/prompt_recon.wav", wav_rec, 16000)
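Note: the deleted get_ref_clip snaps the reference-clip length to a whole number of latent frames,

    L_{\mathrm{ref}} = \left\lfloor \frac{f_s \cdot d_{\mathrm{ref}}}{h} \right\rfloor \cdot h

where f_s is sample_rate, d_ref is ref_segment_duration and h is latent_hop_length from the model's config.yaml; audio shorter than L_ref is tiled with np.tile before truncation, so the speaker-embedding reference always covers an integer number of hops.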
sparktts/models/bicodec.py DELETED
@@ -1,247 +0,0 @@
1
- # Copyright (c) 2025 SparkAudio
2
- # 2025 Xinsheng Wang ([email protected])
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import torch
17
- import torch.nn as nn
18
- from pathlib import Path
19
- from typing import Dict, Any
20
- from omegaconf import DictConfig
21
- from safetensors.torch import load_file
22
-
23
- from sparktts.utils.file import load_config
24
- from sparktts.modules.speaker.speaker_encoder import SpeakerEncoder
25
- from sparktts.modules.encoder_decoder.feat_encoder import Encoder
26
- from sparktts.modules.encoder_decoder.feat_decoder import Decoder
27
- from sparktts.modules.encoder_decoder.wave_generator import WaveGenerator
28
- from sparktts.modules.vq.factorized_vector_quantize import FactorizedVectorQuantize
29
-
30
-
31
- class BiCodec(nn.Module):
32
- """
33
- BiCodec model for speech synthesis, incorporating a speaker encoder, feature encoder/decoder,
34
- quantizer, and wave generator.
35
- """
36
-
37
- def __init__(
38
- self,
39
- mel_params: Dict[str, Any],
40
- encoder: nn.Module,
41
- decoder: nn.Module,
42
- quantizer: nn.Module,
43
- speaker_encoder: nn.Module,
44
- prenet: nn.Module,
45
- postnet: nn.Module,
46
- **kwargs
47
- ) -> None:
48
- """
49
- Initializes the BiCodec model with the required components.
50
-
51
- Args:
52
- mel_params (dict): Parameters for the mel-spectrogram transformer.
53
- encoder (nn.Module): Encoder module.
54
- decoder (nn.Module): Decoder module.
55
- quantizer (nn.Module): Quantizer module.
56
- speaker_encoder (nn.Module): Speaker encoder module.
57
- prenet (nn.Module): Prenet network.
58
- postnet (nn.Module): Postnet network.
59
- """
60
- super().__init__()
61
- self.encoder = encoder
62
- self.decoder = decoder
63
- self.quantizer = quantizer
64
- self.speaker_encoder = speaker_encoder
65
- self.prenet = prenet
66
- self.postnet = postnet
67
- self.init_mel_transformer(mel_params)
68
-
69
- @classmethod
70
- def load_from_checkpoint(cls, model_dir: Path, **kwargs) -> "BiCodec":
71
- """
72
- Loads the model from a checkpoint.
73
-
74
- Args:
75
- model_dir (Path): Path to the model directory containing checkpoint and config.
76
-
77
- Returns:
78
- BiCodec: The initialized BiCodec model.
79
- """
80
- ckpt_path = f'{model_dir}/model.safetensors'
81
- config = load_config(f'{model_dir}/config.yaml')['audio_tokenizer']
82
- mel_params = config["mel_params"]
83
- encoder = Encoder(**config["encoder"])
84
- quantizer = FactorizedVectorQuantize(**config["quantizer"])
85
- prenet = Decoder(**config["prenet"])
86
- postnet = Decoder(**config["postnet"])
87
- decoder = WaveGenerator(**config["decoder"])
88
- speaker_encoder = SpeakerEncoder(**config["speaker_encoder"])
89
-
90
- model = cls(
91
- mel_params=mel_params,
92
- encoder=encoder,
93
- decoder=decoder,
94
- quantizer=quantizer,
95
- speaker_encoder=speaker_encoder,
96
- prenet=prenet,
97
- postnet=postnet,
98
- )
99
-
100
- state_dict = load_file(ckpt_path)
101
- missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
102
-
103
- for key in missing_keys:
104
- print(f"Missing tensor: {key}")
105
- for key in unexpected_keys:
106
- print(f"Unexpected tensor: {key}")
107
-
108
- model.eval()
109
- model.remove_weight_norm()
110
-
111
- return model
112
-
113
- def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]:
114
- """
115
- Performs a forward pass through the model.
116
-
117
- Args:
118
- batch (dict): A dictionary containing features, reference waveform, and target waveform.
119
-
120
- Returns:
121
- dict: A dictionary containing the reconstruction, features, and other metrics.
122
- """
123
- feat = batch["feat"]
124
- mel = self.mel_transformer(batch["ref_wav"]).squeeze(1)
125
-
126
- z = self.encoder(feat.transpose(1, 2))
127
- vq_outputs = self.quantizer(z)
128
-
129
- x_vector, d_vector = self.speaker_encoder(mel.transpose(1, 2))
130
-
131
- conditions = d_vector
132
- with_speaker_loss = False
133
-
134
- x = self.prenet(vq_outputs["z_q"], conditions)
135
- pred_feat = self.postnet(x)
136
- x = x + conditions.unsqueeze(-1)
137
- wav_recon = self.decoder(x)
138
-
139
- return {
140
- "vq_loss": vq_outputs["vq_loss"],
141
- "perplexity": vq_outputs["perplexity"],
142
- "cluster_size": vq_outputs["active_num"],
143
- "recons": wav_recon,
144
- "pred_feat": pred_feat,
145
- "x_vector": x_vector,
146
- "d_vector": d_vector,
147
- "audios": batch["wav"].unsqueeze(1),
148
- "with_speaker_loss": with_speaker_loss,
149
- }
150
-
151
- @torch.no_grad()
152
- def tokenize(self, batch: Dict[str, Any]):
153
- """
154
- Tokenizes the input audio into semantic and global tokens.
155
-
156
- Args:
157
- batch (dict): The input audio features and reference waveform.
158
-
159
- Returns:
160
- tuple: Semantic tokens and global tokens.
161
- """
162
- feat = batch["feat"]
163
- mel = self.mel_transformer(batch["ref_wav"]).squeeze(1)
164
-
165
- z = self.encoder(feat.transpose(1, 2))
166
- semantic_tokens = self.quantizer.tokenize(z)
167
- global_tokens = self.speaker_encoder.tokenize(mel.transpose(1, 2))
168
-
169
- return semantic_tokens, global_tokens
170
-
171
- @torch.no_grad()
172
- def detokenize(self, semantic_tokens, global_tokens):
173
- """
174
- Detokenizes the semantic and global tokens into a waveform.
175
-
176
- Args:
177
- semantic_tokens (tensor): Semantic tokens.
178
- global_tokens (tensor): Global tokens.
179
-
180
- Returns:
181
- tensor: Reconstructed waveform.
182
- """
183
- z_q = self.quantizer.detokenize(semantic_tokens)
184
- d_vector = self.speaker_encoder.detokenize(global_tokens)
185
- x = self.prenet(z_q, d_vector)
186
- x = x + d_vector.unsqueeze(-1)
187
- wav_recon = self.decoder(x)
188
-
189
- return wav_recon
190
-
191
- def init_mel_transformer(self, config: Dict[str, Any]):
192
- """
193
- Initializes the MelSpectrogram transformer based on the provided configuration.
194
-
195
- Args:
196
- config (dict): Configuration parameters for MelSpectrogram.
197
- """
198
- import torchaudio.transforms as TT
199
-
200
- self.mel_transformer = TT.MelSpectrogram(
201
- config["sample_rate"],
202
- config["n_fft"],
203
- config["win_length"],
204
- config["hop_length"],
205
- config["mel_fmin"],
206
- config["mel_fmax"],
207
- n_mels=config["num_mels"],
208
- power=1,
209
- norm="slaney",
210
- mel_scale="slaney",
211
- )
212
-
213
- def remove_weight_norm(self):
214
- """Removes weight normalization from all layers."""
215
- def _remove_weight_norm(m):
216
- try:
217
- torch.nn.utils.remove_weight_norm(m)
218
- except ValueError:
219
- pass # The module didn't have weight norm
220
-
221
- self.apply(_remove_weight_norm)
222
-
223
-
224
- # Test the model
225
- if __name__ == "__main__":
226
-
227
- config = load_config("pretrained_models/SparkTTS-0.5B/BiCodec/config.yaml")
228
- model = BiCodec.load_from_checkpoint(
229
- model_dir="pretrained_models/SparkTTS-0.5B/BiCodec",
230
- )
231
-
232
- # Generate random inputs for testing
233
- duration = 0.96
234
- x = torch.randn(20, 1, int(duration * 16000))
235
- feat = torch.randn(20, int(duration * 50), 1024)
236
- inputs = {"feat": feat, "wav": x, "ref_wav": x}
237
-
238
- # Forward pass
239
- outputs = model(inputs)
240
- semantic_tokens, global_tokens = model.tokenize(inputs)
241
- wav_recon = model.detokenize(semantic_tokens, global_tokens)
242
-
243
- # Verify if the reconstruction matches
244
- if torch.allclose(outputs["recons"].detach(), wav_recon):
245
- print("Test successful")
246
- else:
247
- print("Test failed")
sparktts/modules/blocks/layers.py DELETED
@@ -1,73 +0,0 @@
1
- # Copyright (c) 2025 SparkAudio
2
- # 2025 Xinsheng Wang ([email protected])
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- # Adapted from https://github.com/descriptinc/descript-audio-codec under the Apache License 2.0
17
-
18
-
19
- import torch
20
- import torch.nn as nn
21
- from torch.nn.utils import weight_norm
22
-
23
-
24
- def WNConv1d(*args, **kwargs):
25
- return weight_norm(nn.Conv1d(*args, **kwargs))
26
-
27
-
28
- def WNConvTranspose1d(*args, **kwargs):
29
- return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
30
-
31
-
32
- # Scripting this brings model speed up 1.4x
33
- @torch.jit.script
34
- def snake(x, alpha):
35
- shape = x.shape
36
- x = x.reshape(shape[0], shape[1], -1)
37
- x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
38
- x = x.reshape(shape)
39
- return x
40
-
41
-
42
- class Snake1d(nn.Module):
43
- def __init__(self, channels):
44
- super().__init__()
45
- self.alpha = nn.Parameter(torch.ones(1, channels, 1))
46
-
47
- def forward(self, x):
48
- return snake(x, self.alpha)
49
-
50
-
51
- class ResidualUnit(nn.Module):
52
- def __init__(self, dim: int = 16, dilation: int = 1):
53
- super().__init__()
54
- pad = ((7 - 1) * dilation) // 2
55
- self.block = nn.Sequential(
56
- Snake1d(dim),
57
- WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
58
- Snake1d(dim),
59
- WNConv1d(dim, dim, kernel_size=1),
60
- )
61
-
62
- def forward(self, x):
63
- y = self.block(x)
64
- pad = (x.shape[-1] - y.shape[-1]) // 2
65
- if pad > 0:
66
- x = x[..., pad:-pad]
67
- return x + y
68
-
69
-
70
- def init_weights(m):
71
- if isinstance(m, nn.Conv1d):
72
- nn.init.trunc_normal_(m.weight, std=0.02)
73
- nn.init.constant_(m.bias, 0)
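Note: the scripted snake activation applied by Snake1d evaluates, per channel with a learnable alpha initialized to 1,

    \operatorname{snake}(x) = x + \frac{1}{\alpha + 10^{-9}} \sin^{2}(\alpha x)

where the 1e-9 term only guards against division by zero if alpha collapses to 0; the @torch.jit.script decoration is what the "1.4x speed up" comment refers to.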
sparktts/modules/blocks/samper.py DELETED
@@ -1,115 +0,0 @@
1
- # Copyright (c) 2025 SparkAudio
2
- # 2025 Xinsheng Wang ([email protected])
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
-
17
- import torch
18
- import torch.nn as nn
19
- import torch.nn.functional as F
20
-
21
-
22
- class SamplingBlock(nn.Module):
23
- """Sampling block for upsampling or downsampling"""
24
-
25
- def __init__(
26
- self,
27
- dim: int,
28
- groups: int = 1,
29
- upsample_scale: int = 1,
30
- downsample_scale: int = 1,
31
- ) -> None:
32
- """
33
- Args:
34
- dim: input dimension
35
- groups: number of groups
36
- upsample_scale: upsampling scale
37
- downsample_scale: downsampling scale
38
- """
39
- super(SamplingBlock, self).__init__()
40
-
41
- self.upsample_scale = upsample_scale
42
- self.downsample_scale = downsample_scale
43
-
44
- if self.upsample_scale > 1:
45
- self.de_conv_upsampler = nn.Sequential(
46
- nn.LeakyReLU(0.2),
47
- nn.ConvTranspose1d(
48
- dim,
49
- dim,
50
- kernel_size=upsample_scale * 2,
51
- stride=upsample_scale,
52
- padding=upsample_scale // 2 + upsample_scale % 2,
53
- output_padding=upsample_scale % 2,
54
- groups=groups,
55
- ),
56
- )
57
-
58
- if self.downsample_scale > 1:
59
- self.conv_downsampler = nn.Sequential(
60
- nn.LeakyReLU(0.2),
61
- nn.Conv1d(
62
- dim,
63
- dim,
64
- kernel_size=2 * downsample_scale,
65
- stride=downsample_scale,
66
- padding=downsample_scale // 2 + downsample_scale % 2,
67
- groups=groups,
68
- ),
69
- )
70
-
71
- @staticmethod
72
- def repeat_upsampler(x, upsample_scale):
73
- return x.repeat_interleave(upsample_scale, dim=2)
74
-
75
- @staticmethod
76
- def skip_downsampler(x, downsample_scale):
77
- return F.avg_pool1d(x, kernel_size=downsample_scale, stride=downsample_scale)
78
-
79
- def forward(self, x):
80
- x = x.transpose(1, 2)
81
- if self.upsample_scale > 1:
82
- repeat_res = self.repeat_upsampler(x, self.upsample_scale)
83
- deconv_res = self.de_conv_upsampler(x)
84
- upmerge_res = repeat_res + deconv_res
85
- else:
86
- upmerge_res = x
87
- repeat_res = x
88
-
89
- if self.downsample_scale > 1:
90
- conv_res = self.conv_downsampler(upmerge_res)
91
- skip2_res = self.skip_downsampler(upmerge_res, self.downsample_scale)
92
- skip1_res = self.skip_downsampler(repeat_res, self.downsample_scale)
93
- else:
94
- conv_res = upmerge_res
95
- skip2_res = upmerge_res
96
- skip1_res = repeat_res
97
-
98
- final_res = conv_res + skip1_res + skip2_res
99
-
100
- return final_res
101
-
102
-
103
- # test
104
- if __name__ == "__main__":
105
- test_input = torch.randn(8, 1024, 50) # Batch size = 8, 1024 channels, length = 50
106
- model = SamplingBlock(1024, 1024, upsample_scale=2)
107
- model_down = SamplingBlock(1024, 1024, downsample_scale=2)
108
- output = model(test_input)
109
- output_down = model_down(test_input)
110
- print("shape after upsample * 2", output.shape) # torch.Size([8, 1024, 100])
111
- print("shape after downsample * 2", output_down.shape) # torch.Size([8, 1024, 25])
112
- if output.shape == torch.Size([8, 1024, 100]) and output_down.shape == torch.Size(
113
- [8, 1024, 25]
114
- ):
115
- print("test successful")
sparktts/modules/blocks/vocos.py DELETED
@@ -1,373 +0,0 @@
1
- # Copyright (c) 2025 SparkAudio
2
- # 2025 Xinsheng Wang ([email protected])
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
-
17
- import torch
18
- import torch.nn as nn
19
-
20
- from typing import Tuple
21
- from torch.nn.utils import weight_norm, remove_weight_norm
22
-
23
- from typing import Optional
24
-
25
-
26
- class ConvNeXtBlock(nn.Module):
27
- """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
28
-
29
- Args:
30
- dim (int): Number of input channels.
31
- intermediate_dim (int): Dimensionality of the intermediate layer.
32
- layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
33
- Defaults to None.
34
- adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
35
- None means non-conditional LayerNorm. Defaults to None.
36
- """
37
-
38
- def __init__(
39
- self,
40
- dim: int,
41
- intermediate_dim: int,
42
- layer_scale_init_value: float,
43
- condition_dim: Optional[int] = None,
44
- ):
45
- super().__init__()
46
- self.dwconv = nn.Conv1d(
47
- dim, dim, kernel_size=7, padding=3, groups=dim
48
- ) # depthwise conv
49
- self.adanorm = condition_dim is not None
50
- if condition_dim:
51
- self.norm = AdaLayerNorm(condition_dim, dim, eps=1e-6)
52
- else:
53
- self.norm = nn.LayerNorm(dim, eps=1e-6)
54
- self.pwconv1 = nn.Linear(
55
- dim, intermediate_dim
56
- ) # pointwise/1x1 convs, implemented with linear layers
57
- self.act = nn.GELU()
58
- self.pwconv2 = nn.Linear(intermediate_dim, dim)
59
- self.gamma = (
60
- nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
61
- if layer_scale_init_value > 0
62
- else None
63
- )
64
-
65
- def forward(
66
- self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
67
- ) -> torch.Tensor:
68
- residual = x
69
- x = self.dwconv(x)
70
- x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
71
- if self.adanorm:
72
- assert cond_embedding_id is not None
73
- x = self.norm(x, cond_embedding_id)
74
- else:
75
- x = self.norm(x)
76
- x = self.pwconv1(x)
77
- x = self.act(x)
78
- x = self.pwconv2(x)
79
- if self.gamma is not None:
80
- x = self.gamma * x
81
- x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
82
-
83
- x = residual + x
84
- return x
85
-
86
-
87
- class AdaLayerNorm(nn.Module):
88
- """
89
- Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes
90
-
91
- Args:
92
- condition_dim (int): Dimension of the condition.
93
- embedding_dim (int): Dimension of the embeddings.
94
- """
95
-
96
- def __init__(self, condition_dim: int, embedding_dim: int, eps: float = 1e-6):
97
- super().__init__()
98
- self.eps = eps
99
- self.dim = embedding_dim
100
- self.scale = nn.Linear(condition_dim, embedding_dim)
101
- self.shift = nn.Linear(condition_dim, embedding_dim)
102
- torch.nn.init.ones_(self.scale.weight)
103
- torch.nn.init.zeros_(self.shift.weight)
104
-
105
- def forward(self, x: torch.Tensor, cond_embedding: torch.Tensor) -> torch.Tensor:
106
- scale = self.scale(cond_embedding)
107
- shift = self.shift(cond_embedding)
108
- x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
109
- x = x * scale.unsqueeze(1) + shift.unsqueeze(1)
110
- return x
111
-
112
-
113
- class ResBlock1(nn.Module):
114
- """
115
- ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
116
- but without upsampling layers.
117
-
118
- Args:
119
- dim (int): Number of input channels.
120
- kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
121
- dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
122
- Defaults to (1, 3, 5).
123
- lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
124
- Defaults to 0.1.
125
- layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
126
- Defaults to None.
127
- """
128
-
129
- def __init__(
130
- self,
131
- dim: int,
132
- kernel_size: int = 3,
133
- dilation: Tuple[int, int, int] = (1, 3, 5),
134
- lrelu_slope: float = 0.1,
135
- layer_scale_init_value: Optional[float] = None,
136
- ):
137
- super().__init__()
138
- self.lrelu_slope = lrelu_slope
139
- self.convs1 = nn.ModuleList(
140
- [
141
- weight_norm(
142
- nn.Conv1d(
143
- dim,
144
- dim,
145
- kernel_size,
146
- 1,
147
- dilation=dilation[0],
148
- padding=self.get_padding(kernel_size, dilation[0]),
149
- )
150
- ),
151
- weight_norm(
152
- nn.Conv1d(
153
- dim,
154
- dim,
155
- kernel_size,
156
- 1,
157
- dilation=dilation[1],
158
- padding=self.get_padding(kernel_size, dilation[1]),
159
- )
160
- ),
161
- weight_norm(
162
- nn.Conv1d(
163
- dim,
164
- dim,
165
- kernel_size,
166
- 1,
167
- dilation=dilation[2],
168
- padding=self.get_padding(kernel_size, dilation[2]),
169
- )
170
- ),
171
- ]
172
- )
173
-
174
- self.convs2 = nn.ModuleList(
175
- [
176
- weight_norm(
177
- nn.Conv1d(
178
- dim,
179
- dim,
180
- kernel_size,
181
- 1,
182
- dilation=1,
183
- padding=self.get_padding(kernel_size, 1),
184
- )
185
- ),
186
- weight_norm(
187
- nn.Conv1d(
188
- dim,
189
- dim,
190
- kernel_size,
191
- 1,
192
- dilation=1,
193
- padding=self.get_padding(kernel_size, 1),
194
- )
195
- ),
196
- weight_norm(
197
- nn.Conv1d(
198
- dim,
199
- dim,
200
- kernel_size,
201
- 1,
202
- dilation=1,
203
- padding=self.get_padding(kernel_size, 1),
204
- )
205
- ),
206
- ]
207
- )
208
-
209
- self.gamma = nn.ParameterList(
210
- [
211
- (
212
- nn.Parameter(
213
- layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
214
- )
215
- if layer_scale_init_value is not None
216
- else None
217
- ),
218
- (
219
- nn.Parameter(
220
- layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
221
- )
222
- if layer_scale_init_value is not None
223
- else None
224
- ),
225
- (
226
- nn.Parameter(
227
- layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
228
- )
229
- if layer_scale_init_value is not None
230
- else None
231
- ),
232
- ]
233
- )
234
-
235
- def forward(self, x: torch.Tensor) -> torch.Tensor:
236
- for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
237
- xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
238
- xt = c1(xt)
239
- xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
240
- xt = c2(xt)
241
- if gamma is not None:
242
- xt = gamma * xt
243
- x = xt + x
244
- return x
245
-
246
- def remove_weight_norm(self):
247
- for l in self.convs1:
248
- remove_weight_norm(l)
249
- for l in self.convs2:
250
- remove_weight_norm(l)
251
-
252
- @staticmethod
253
- def get_padding(kernel_size: int, dilation: int = 1) -> int:
254
- return int((kernel_size * dilation - dilation) / 2)
255
-
256
-
257
- class Backbone(nn.Module):
258
- """Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
259
-
260
- def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
261
- """
262
- Args:
263
- x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
264
- C denotes output features, and L is the sequence length.
265
-
266
- Returns:
267
- Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
268
- and H denotes the model dimension.
269
- """
270
- raise NotImplementedError("Subclasses must implement the forward method.")
271
-
272
-
273
- class VocosBackbone(Backbone):
274
- """
275
- Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization
276
-
277
- Args:
278
- input_channels (int): Number of input features channels.
279
- dim (int): Hidden dimension of the model.
280
- intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
281
- num_layers (int): Number of ConvNeXtBlock layers.
282
- layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
283
- adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
284
- None means non-conditional model. Defaults to None.
285
- """
286
-
287
- def __init__(
288
- self,
289
- input_channels: int,
290
- dim: int,
291
- intermediate_dim: int,
292
- num_layers: int,
293
- layer_scale_init_value: Optional[float] = None,
294
- condition_dim: Optional[int] = None,
295
- ):
296
- super().__init__()
297
- self.input_channels = input_channels
298
- self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
299
- self.adanorm = condition_dim is not None
300
- if condition_dim:
301
- self.norm = AdaLayerNorm(condition_dim, dim, eps=1e-6)
302
- else:
303
- self.norm = nn.LayerNorm(dim, eps=1e-6)
304
- layer_scale_init_value = layer_scale_init_value or 1 / num_layers
305
- self.convnext = nn.ModuleList(
306
- [
307
- ConvNeXtBlock(
308
- dim=dim,
309
- intermediate_dim=intermediate_dim,
310
- layer_scale_init_value=layer_scale_init_value,
311
- condition_dim=condition_dim,
312
- )
313
- for _ in range(num_layers)
314
- ]
315
- )
316
- self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
317
- self.apply(self._init_weights)
318
-
319
- def _init_weights(self, m):
320
- if isinstance(m, (nn.Conv1d, nn.Linear)):
321
- nn.init.trunc_normal_(m.weight, std=0.02)
322
- nn.init.constant_(m.bias, 0)
323
-
324
- def forward(self, x: torch.Tensor, condition: torch.Tensor = None) -> torch.Tensor:
325
- x = self.embed(x)
326
- if self.adanorm:
327
- assert condition is not None
328
- x = self.norm(x.transpose(1, 2), condition)
329
- else:
330
- x = self.norm(x.transpose(1, 2))
331
- x = x.transpose(1, 2)
332
- for conv_block in self.convnext:
333
- x = conv_block(x, condition)
334
- x = self.final_layer_norm(x.transpose(1, 2))
335
- return x
336
-
337
-
338
- class VocosResNetBackbone(Backbone):
339
- """
340
- Vocos backbone module built with ResBlocks.
341
-
342
- Args:
343
- input_channels (int): Number of input features channels.
344
- dim (int): Hidden dimension of the model.
345
- num_blocks (int): Number of ResBlock1 blocks.
346
- layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
347
- """
348
-
349
- def __init__(
350
- self,
351
- input_channels,
352
- dim,
353
- num_blocks,
354
- layer_scale_init_value=None,
355
- ):
356
- super().__init__()
357
- self.input_channels = input_channels
358
- self.embed = weight_norm(
359
- nn.Conv1d(input_channels, dim, kernel_size=3, padding=1)
360
- )
361
- layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
362
- self.resnet = nn.Sequential(
363
- *[
364
- ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value)
365
- for _ in range(num_blocks)
366
- ]
367
- )
368
-
369
- def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
370
- x = self.embed(x)
371
- x = self.resnet(x)
372
- x = x.transpose(1, 2)
373
- return x
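Note: the conditioning used throughout the deleted vocos.py boils down to the AdaLayerNorm transform

    \operatorname{AdaLN}(x, c) = \operatorname{LayerNorm}(x) \odot (W_{\gamma} c) + W_{\beta} c

applied per time step, with the weight matrix of the scale projection W_gamma initialized to ones and that of the shift projection W_beta to zeros (the default Linear biases are left untouched); the LayerNorm here is the affine-free functional form.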
sparktts/modules/encoder_decoder/feat_decoder.py DELETED
@@ -1,115 +0,0 @@
1
- # Copyright (c) 2025 SparkAudio
2
- # 2025 Xinsheng Wang ([email protected])
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
-
17
- import torch
18
- import torch.nn as nn
19
-
20
- from typing import List
21
-
22
- from sparktts.modules.blocks.vocos import VocosBackbone
23
- from sparktts.modules.blocks.samper import SamplingBlock
24
-
25
-
26
- class Decoder(nn.Module):
27
- """Decoder module with convnext and upsampling blocks
28
-
29
- Args:
30
- sample_ratios (List[int]): sample ratios
31
- example: [2, 2] means upsample by 2x and then by 2x again (4x total)
32
- """
33
-
34
- def __init__(
35
- self,
36
- input_channels: int,
37
- vocos_dim: int,
38
- vocos_intermediate_dim: int,
39
- vocos_num_layers: int,
40
- out_channels: int,
41
- condition_dim: int = None,
42
- sample_ratios: List[int] = [1, 1],
43
- use_tanh_at_final: bool = False,
44
- ):
45
- super().__init__()
46
-
47
- self.linear_pre = nn.Linear(input_channels, vocos_dim)
48
- modules = [
49
- nn.Sequential(
50
- SamplingBlock(
51
- dim=vocos_dim,
52
- groups=vocos_dim,
53
- upsample_scale=ratio,
54
- ),
55
- VocosBackbone(
56
- input_channels=vocos_dim,
57
- dim=vocos_dim,
58
- intermediate_dim=vocos_intermediate_dim,
59
- num_layers=2,
60
- condition_dim=None,
61
- ),
62
- )
63
- for ratio in sample_ratios
64
- ]
65
-
66
- self.downsample = nn.Sequential(*modules)
67
-
68
- self.vocos_backbone = VocosBackbone(
69
- input_channels=vocos_dim,
70
- dim=vocos_dim,
71
- intermediate_dim=vocos_intermediate_dim,
72
- num_layers=vocos_num_layers,
73
- condition_dim=condition_dim,
74
- )
75
- self.linear = nn.Linear(vocos_dim, out_channels)
76
- self.use_tanh_at_final = use_tanh_at_final
77
-
78
- def forward(self, x: torch.Tensor, c: torch.Tensor = None):
79
- """encoder forward.
80
-
81
- Args:
82
- x (torch.Tensor): (batch_size, input_channels, length)
83
-
84
- Returns:
85
- x (torch.Tensor): (batch_size, encode_channels, length)
86
- """
87
- x = self.linear_pre(x.transpose(1, 2))
88
- x = self.downsample(x).transpose(1, 2)
89
- x = self.vocos_backbone(x, condition=c)
90
- x = self.linear(x).transpose(1, 2)
91
- if self.use_tanh_at_final:
92
- x = torch.tanh(x)
93
-
94
- return x
95
-
96
-
97
- # test
98
- if __name__ == "__main__":
99
- test_input = torch.randn(8, 1024, 50) # Batch size = 8, 1024 channels, length = 50
100
- condition = torch.randn(8, 256)
101
- decoder = Decoder(
102
- input_channels=1024,
103
- vocos_dim=384,
104
- vocos_intermediate_dim=2048,
105
- vocos_num_layers=12,
106
- out_channels=256,
107
- condition_dim=256,
108
- sample_ratios=[2, 2],
109
- )
110
- output = decoder(test_input, condition)
111
- print(output.shape) # torch.Size([8, 256, 200])
112
- if output.shape == torch.Size([8, 256, 200]):
113
- print("Decoder test passed")
114
- else:
115
- print("Decoder test failed")
sparktts/modules/encoder_decoder/feat_encoder.py DELETED
@@ -1,105 +0,0 @@
1
- # Copyright (c) 2025 SparkAudio
2
- # 2025 Xinsheng Wang ([email protected])
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
-
17
- import torch
18
- import torch.nn as nn
19
-
20
- from typing import List
21
-
22
- from sparktts.modules.blocks.vocos import VocosBackbone
23
- from sparktts.modules.blocks.samper import SamplingBlock
24
-
25
-
26
- class Encoder(nn.Module):
27
- """Encoder module with convnext and downsampling blocks"""
28
-
29
- def __init__(
30
- self,
31
- input_channels: int,
32
- vocos_dim: int,
33
- vocos_intermediate_dim: int,
34
- vocos_num_layers: int,
35
- out_channels: int,
36
- sample_ratios: List[int] = [1, 1],
37
- ):
38
- super().__init__()
39
- """
40
- Encoder module with VocosBackbone and sampling blocks.
41
-
42
- Args:
43
- sample_ratios (List[int]): sample ratios
44
- example: [2, 2] means downsample by 2x and then by 2x again (4x total)
45
- """
46
- self.encoder = VocosBackbone(
47
- input_channels=input_channels,
48
- dim=vocos_dim,
49
- intermediate_dim=vocos_intermediate_dim,
50
- num_layers=vocos_num_layers,
51
- condition_dim=None,
52
- )
53
-
54
- modules = [
55
- nn.Sequential(
56
- SamplingBlock(
57
- dim=vocos_dim,
58
- groups=vocos_dim,
59
- downsample_scale=ratio,
60
- ),
61
- VocosBackbone(
62
- input_channels=vocos_dim,
63
- dim=vocos_dim,
64
- intermediate_dim=vocos_intermediate_dim,
65
- num_layers=2,
66
- condition_dim=None,
67
- ),
68
- )
69
- for ratio in sample_ratios
70
- ]
71
-
72
- self.downsample = nn.Sequential(*modules)
73
-
74
- self.project = nn.Linear(vocos_dim, out_channels)
75
-
76
- def forward(self, x: torch.Tensor, *args):
77
- """
78
- Args:
79
- x (torch.Tensor): (batch_size, input_channels, length)
80
-
81
- Returns:
82
- x (torch.Tensor): (batch_size, encode_channels, length)
83
- """
84
- x = self.encoder(x)
85
- x = self.downsample(x)
86
- x = self.project(x)
87
- return x.transpose(1, 2)
88
-
89
-
90
- # test
91
- if __name__ == "__main__":
92
- test_input = torch.randn(8, 1024, 50) # Batch size = 8, 1024 channels, length = 50
93
- encoder = Encoder(
94
- input_channels=1024,
95
- vocos_dim=384,
96
- vocos_intermediate_dim=2048,
97
- vocos_num_layers=12,
98
- out_channels=256,
99
- sample_ratios=[2, 2],
100
- )
101
-
102
- output = encoder(test_input)
103
- print(output.shape) # torch.Size([8, 256, 12])
104
- if output.shape == torch.Size([8, 256, 12]):
105
- print("test successful")
sparktts/modules/encoder_decoder/wave_generator.py DELETED
@@ -1,88 +0,0 @@
1
- # Copyright (c) 2024 Xinsheng Wang ([email protected])
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- # Adapted from https://github.com/descriptinc/descript-audio-codec under the Apache License 2.0
16
-
17
-
18
- import torch.nn as nn
19
-
20
- from sparktts.modules.blocks.layers import (
21
- Snake1d,
22
- WNConv1d,
23
- ResidualUnit,
24
- WNConvTranspose1d,
25
- init_weights,
26
- )
27
-
28
-
29
- class DecoderBlock(nn.Module):
30
- def __init__(
31
- self,
32
- input_dim: int = 16,
33
- output_dim: int = 8,
34
- kernel_size: int = 2,
35
- stride: int = 1,
36
- ):
37
- super().__init__()
38
- self.block = nn.Sequential(
39
- Snake1d(input_dim),
40
- WNConvTranspose1d(
41
- input_dim,
42
- output_dim,
43
- kernel_size=kernel_size,
44
- stride=stride,
45
- padding=(kernel_size - stride) // 2,
46
- ),
47
- ResidualUnit(output_dim, dilation=1),
48
- ResidualUnit(output_dim, dilation=3),
49
- ResidualUnit(output_dim, dilation=9),
50
- )
51
-
52
- def forward(self, x):
53
- return self.block(x)
54
-
55
-
56
- class WaveGenerator(nn.Module):
57
- def __init__(
58
- self,
59
- input_channel,
60
- channels,
61
- rates,
62
- kernel_sizes,
63
- d_out: int = 1,
64
- ):
65
- super().__init__()
66
-
67
- # Add first conv layer
68
- layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]
69
-
70
- # Add upsampling + MRF blocks
71
- for i, (kernel_size, stride) in enumerate(zip(kernel_sizes, rates)):
72
- input_dim = channels // 2**i
73
- output_dim = channels // 2 ** (i + 1)
74
- layers += [DecoderBlock(input_dim, output_dim, kernel_size, stride)]
75
-
76
- # Add final conv layer
77
- layers += [
78
- Snake1d(output_dim),
79
- WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
80
- nn.Tanh(),
81
- ]
82
-
83
- self.model = nn.Sequential(*layers)
84
-
85
- self.apply(init_weights)
86
-
87
- def forward(self, x):
88
- return self.model(x)
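Note: the deleted WaveGenerator halves the channel width at every upsampling stage,

    C^{\mathrm{in}}_{i} = C / 2^{i}, \qquad C^{\mathrm{out}}_{i} = C / 2^{\,i+1}, \qquad i = 0, \dots, N-1

with C = channels and N = len(rates), and the overall upsampling factor of the generated waveform is the product of the entries of rates; the final Snake1d + WNConv1d + Tanh stage maps the last width down to d_out channels. The concrete values of channels, rates and kernel_sizes live in the BiCodec config.yaml, which is not shown in this diff.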
sparktts/modules/fsq/finite_scalar_quantization.py DELETED
@@ -1,251 +0,0 @@
1
- """
2
- Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505
3
- Code adapted from Jax version in Appendix A.1
4
- """
5
-
6
- from __future__ import annotations
7
- from functools import wraps, partial
8
- from contextlib import nullcontext
9
- from typing import List, Tuple
10
-
11
- import torch
12
- import torch.nn as nn
13
- from torch.nn import Module
14
- from torch import Tensor, int32
15
- from torch.amp import autocast
16
-
17
- from einops import rearrange, pack, unpack
18
-
19
- # helper functions
20
-
21
-
22
- def exists(v):
23
- return v is not None
24
-
25
-
26
- def default(*args):
27
- for arg in args:
28
- if exists(arg):
29
- return arg
30
- return None
31
-
32
-
33
- def maybe(fn):
34
- @wraps(fn)
35
- def inner(x, *args, **kwargs):
36
- if not exists(x):
37
- return x
38
- return fn(x, *args, **kwargs)
39
-
40
- return inner
41
-
42
-
43
- def pack_one(t, pattern):
44
- return pack([t], pattern)
45
-
46
-
47
- def unpack_one(t, ps, pattern):
48
- return unpack(t, ps, pattern)[0]
49
-
50
-
51
- # tensor helpers
52
-
53
-
54
- def round_ste(z: Tensor) -> Tensor:
55
- """Round with straight through gradients."""
56
- zhat = z.round()
57
- return z + (zhat - z).detach()
58
-
59
-
60
- # main class
61
-
62
-
63
- class FSQ(Module):
64
- def __init__(
65
- self,
66
- levels: List[int],
67
- dim: int | None = None,
68
- num_codebooks=1,
69
- keep_num_codebooks_dim: bool | None = None,
70
- scale: float | None = None,
71
- allowed_dtypes: Tuple[torch.dtype, ...] = (torch.float32, torch.float64),
72
- channel_first: bool = False,
73
- projection_has_bias: bool = True,
74
- return_indices=True,
75
- force_quantization_f32=True,
76
- ):
77
- super().__init__()
78
- _levels = torch.tensor(levels, dtype=int32)
79
- self.register_buffer("_levels", _levels, persistent=False)
80
-
81
- _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32)
82
- self.register_buffer("_basis", _basis, persistent=False)
83
-
84
- self.scale = scale
85
-
86
- codebook_dim = len(levels)
87
- self.codebook_dim = codebook_dim
88
-
89
- effective_codebook_dim = codebook_dim * num_codebooks
90
- self.num_codebooks = num_codebooks
91
- self.effective_codebook_dim = effective_codebook_dim
92
-
93
- keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
94
- assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
95
- self.keep_num_codebooks_dim = keep_num_codebooks_dim
96
-
97
- self.dim = default(dim, len(_levels) * num_codebooks)
98
-
99
- self.channel_first = channel_first
100
-
101
- has_projections = self.dim != effective_codebook_dim
102
- self.project_in = (
103
- nn.Linear(self.dim, effective_codebook_dim, bias=projection_has_bias)
104
- if has_projections
105
- else nn.Identity()
106
- )
107
- self.project_out = (
108
- nn.Linear(effective_codebook_dim, self.dim, bias=projection_has_bias)
109
- if has_projections
110
- else nn.Identity()
111
- )
112
-
113
- self.has_projections = has_projections
114
-
115
- self.return_indices = return_indices
116
- if return_indices:
117
- self.codebook_size = self._levels.prod().item()
118
- implicit_codebook = self._indices_to_codes(torch.arange(self.codebook_size))
119
- self.register_buffer(
120
- "implicit_codebook", implicit_codebook, persistent=False
121
- )
122
-
123
- self.allowed_dtypes = allowed_dtypes
124
- self.force_quantization_f32 = force_quantization_f32
125
-
126
- def bound(self, z, eps: float = 1e-3):
127
- """Bound `z`, an array of shape (..., d)."""
128
- half_l = (self._levels - 1) * (1 + eps) / 2
129
- offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
130
- shift = (offset / half_l).atanh()
131
- return (z + shift).tanh() * half_l - offset
132
-
133
- def quantize(self, z):
134
- """Quantizes z, returns quantized zhat, same shape as z."""
135
- quantized = round_ste(self.bound(z))
136
- half_width = self._levels // 2 # Renormalize to [-1, 1].
137
- return quantized / half_width
138
-
139
- def _scale_and_shift(self, zhat_normalized):
140
- half_width = self._levels // 2
141
- return (zhat_normalized * half_width) + half_width
142
-
143
- def _scale_and_shift_inverse(self, zhat):
144
- half_width = self._levels // 2
145
- return (zhat - half_width) / half_width
146
-
147
- def _indices_to_codes(self, indices):
148
- level_indices = self.indices_to_level_indices(indices)
149
- codes = self._scale_and_shift_inverse(level_indices)
150
- return codes
151
-
152
- def codes_to_indices(self, zhat):
153
- """Converts a `code` to an index in the codebook."""
154
- assert zhat.shape[-1] == self.codebook_dim
155
- zhat = self._scale_and_shift(zhat)
156
- return (zhat * self._basis).sum(dim=-1).to(int32)
157
-
158
- def indices_to_level_indices(self, indices):
159
- """Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings"""
160
- indices = rearrange(indices, "... -> ... 1")
161
- codes_non_centered = (indices // self._basis) % self._levels
162
- return codes_non_centered
163
-
164
- def indices_to_codes(self, indices):
165
- """Inverse of `codes_to_indices`."""
166
- assert exists(indices)
167
-
168
- is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
169
-
170
- codes = self._indices_to_codes(indices)
171
-
172
- if self.keep_num_codebooks_dim:
173
- codes = rearrange(codes, "... c d -> ... (c d)")
174
-
175
- codes = self.project_out(codes)
176
-
177
- if is_img_or_video or self.channel_first:
178
- codes = rearrange(codes, "b ... d -> b d ...")
179
-
180
- return codes
181
-
182
- def forward(self, z):
183
- """
184
- einstein notation
185
- b - batch
186
- n - sequence (or flattened spatial dimensions)
187
- d - feature dimension
188
- c - number of codebook dim
189
- """
190
-
191
- is_img_or_video = z.ndim >= 4
192
- need_move_channel_last = is_img_or_video or self.channel_first
193
-
194
- # standardize image or video into (batch, seq, dimension)
195
-
196
- if need_move_channel_last:
197
- z = rearrange(z, "b d ... -> b ... d")
198
- z, ps = pack_one(z, "b * d")
199
-
200
- assert (
201
- z.shape[-1] == self.dim
202
- ), f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}"
203
-
204
- z = self.project_in(z)
205
-
206
- z = rearrange(z, "b n (c d) -> b n c d", c=self.num_codebooks)
207
-
208
- # whether to force quantization step to be full precision or not
209
-
210
- force_f32 = self.force_quantization_f32
211
- quantization_context = (
212
- partial(autocast, "cuda", enabled=False) if force_f32 else nullcontext
213
- )
214
-
215
- with quantization_context():
216
- orig_dtype = z.dtype
217
-
218
- if force_f32 and orig_dtype not in self.allowed_dtypes:
219
- z = z.float()
220
-
221
- codes = self.quantize(z)
222
-
223
- # returning indices could be optional
224
-
225
- indices = None
226
-
227
- if self.return_indices:
228
- indices = self.codes_to_indices(codes)
229
-
230
- codes = rearrange(codes, "b n c d -> b n (c d)")
231
-
232
- codes = codes.type(orig_dtype)
233
-
234
- # project out
235
-
236
- out = self.project_out(codes)
237
-
238
- # reconstitute image or video dimensions
239
-
240
- if need_move_channel_last:
241
- out = unpack_one(out, ps, "b * d")
242
- out = rearrange(out, "b ... d -> b d ...")
243
-
244
- indices = maybe(unpack_one)(indices, ps, "b * c")
245
-
246
- if not self.keep_num_codebooks_dim and self.return_indices:
247
- indices = maybe(rearrange)(indices, "... 1 -> ...")
248
-
249
- # return quantized output and indices
250
-
251
- return out, indices
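Note: the deleted FSQ rounds with a straight-through estimator and flattens each quantized code vector into a single integer index via a mixed-radix basis,

    \operatorname{round\_ste}(z) = z + \operatorname{sg}\!\left(\operatorname{round}(z) - z\right), \qquad \mathrm{index} = \sum_{j} \hat z_{j}\, b_{j}, \qquad b_{j} = \prod_{k < j} \ell_{k}

where sg is stop-gradient (.detach()), \ell_k are the entries of levels, and \hat z_j is the j-th code component shifted into \{0, \dots, \ell_j - 1\} by _scale_and_shift; the implicit codebook therefore has \prod_k \ell_k entries.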
sparktts/modules/fsq/residual_fsq.py DELETED
@@ -1,355 +0,0 @@
1
- import random
- from math import ceil
2
- import torch
3
- import torch.nn.functional as F
4
- import torch.distributed as dist
5
-
6
- from typing import List
7
- from torch import nn
8
- from torch.nn import Module
9
- from torch.amp import autocast
10
- from einx import get_at
11
- from einops import rearrange, reduce, pack, unpack
12
-
13
- from sparktts.modules.fsq.finite_scalar_quantization import FSQ
14
-
15
-
16
- def exists(val):
17
- return val is not None
18
-
19
-
20
- def first(l):
21
- return l[0]
22
-
23
-
24
- def default(val, d):
25
- return val if exists(val) else d
26
-
27
-
28
- def round_up_multiple(num, mult):
29
- return ceil(num / mult) * mult
30
-
31
-
32
- # distributed helpers
33
-
34
-
35
- def is_distributed():
36
- return dist.is_initialized() and dist.get_world_size() > 1
37
-
38
-
39
- def get_maybe_sync_seed(device, max_size=10_000):
40
- rand_int = torch.randint(0, max_size, (), device=device)
41
-
42
- if is_distributed():
43
- dist.all_reduce(rand_int)
44
-
45
- return rand_int.item()
46
-
47
-
48
- class ResidualFSQ(Module):
49
- """Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf"""
50
-
51
- def __init__(
52
- self,
53
- *,
54
- levels: List[int],
55
- num_quantizers,
56
- dim=None,
57
- is_channel_first=False,
58
- quantize_dropout=False,
59
- quantize_dropout_cutoff_index=0,
60
- quantize_dropout_multiple_of=1,
61
- **kwargs,
62
- ):
63
- super().__init__()
64
- codebook_dim = len(levels)
65
- dim = default(dim, codebook_dim)
66
-
67
- requires_projection = codebook_dim != dim
68
- self.project_in = (
69
- nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
70
- )
71
- self.project_out = (
72
- nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
73
- )
74
- self.has_projections = requires_projection
75
-
76
- self.is_channel_first = is_channel_first
77
- self.num_quantizers = num_quantizers
78
-
79
- self.levels = levels
80
- self.layers = nn.ModuleList([])
81
-
82
- levels_tensor = torch.Tensor(levels)
83
-
84
- scales = []
85
-
86
- for ind in range(num_quantizers):
87
- scales.append((levels_tensor - 1) ** -ind)
88
-
89
- fsq = FSQ(levels=levels, dim=codebook_dim, **kwargs)
90
-
91
- self.layers.append(fsq)
92
-
93
- assert all([not fsq.has_projections for fsq in self.layers])
94
-
95
- self.codebook_size = self.layers[0].codebook_size
96
-
97
- self.register_buffer("scales", torch.stack(scales), persistent=False)
98
-
99
- self.quantize_dropout = quantize_dropout and num_quantizers > 1
100
-
101
- assert quantize_dropout_cutoff_index >= 0
102
-
103
- self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
104
- self.quantize_dropout_multiple_of = quantize_dropout_multiple_of # encodec paper proposes structured dropout, believe this was set to 4
105
-
106
- @property
107
- def codebooks(self):
108
- codebooks = [layer.implicit_codebook for layer in self.layers]
109
- codebooks = torch.stack(codebooks, dim=0)
110
- return codebooks
111
-
112
- def get_codes_from_indices(self, indices):
113
-
114
- batch, quantize_dim = indices.shape[0], indices.shape[-1]
115
-
116
- # may also receive indices in the shape of 'b h w q' (accept_image_fmap)
117
-
118
- indices, ps = pack([indices], "b * q")
119
-
120
- # because of quantize dropout, one can pass in indices that are coarse
121
- # and the network should be able to reconstruct
122
-
123
- if quantize_dim < self.num_quantizers:
124
- assert (
125
- self.quantize_dropout > 0.0
126
- ), "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
127
- indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value=-1)
128
-
129
- # take care of quantizer dropout
130
-
131
- mask = indices == -1
132
- indices = indices.masked_fill(
133
- mask, 0
134
- ) # have it fetch a dummy code to be masked out later
135
-
136
- all_codes = get_at("q [c] d, b n q -> q b n d", self.codebooks, indices)
137
-
138
- # mask out any codes that were dropout-ed
139
-
140
- all_codes = all_codes.masked_fill(rearrange(mask, "b n q -> q b n 1"), 0.0)
141
-
142
- # scale the codes
143
-
144
- scales = rearrange(self.scales, "q d -> q 1 1 d")
145
- all_codes = all_codes * scales
146
-
147
- # if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension)
148
-
149
- (all_codes,) = unpack(all_codes, ps, "q b * d")
150
-
151
- return all_codes
152
-
153
- def get_output_from_indices(self, indices):
154
- codes = self.get_codes_from_indices(indices)
155
- codes_summed = reduce(codes, "q ... -> ...", "sum")
156
- return self.project_out(codes_summed)
157
-
158
- def forward(self, x, return_all_codes=False, rand_quantize_dropout_fixed_seed=None):
159
- num_quant, quant_dropout_multiple_of, device = (
160
- self.num_quantizers,
161
- self.quantize_dropout_multiple_of,
162
- x.device,
163
- )
164
-
165
- # handle channel first
166
-
167
- if self.is_channel_first:
168
- x = rearrange(x, "b d ... -> b ... d")
169
- x, ps = pack([x], "b * d")
170
-
171
- # maybe project in
172
-
173
- x = self.project_in(x)
174
-
175
- quantized_out = 0.0
176
- residual = x
177
-
178
- all_indices = []
179
-
180
- should_quantize_dropout = self.training and self.quantize_dropout
181
-
182
- # sample a layer index at which to dropout further residual quantization
183
- # also prepare null indices
184
-
185
- if should_quantize_dropout:
186
-
187
- # check if seed is manually passed in
188
-
189
- if not exists(rand_quantize_dropout_fixed_seed):
190
- rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device)
191
-
192
- rand = random.Random(rand_quantize_dropout_fixed_seed)
193
-
194
- rand_quantize_dropout_index = rand.randrange(
195
- self.quantize_dropout_cutoff_index, num_quant
196
- )
197
-
198
- if quant_dropout_multiple_of != 1:
199
- rand_quantize_dropout_index = (
200
- round_up_multiple(
201
- rand_quantize_dropout_index + 1, quant_dropout_multiple_of
202
- )
203
- - 1
204
- )
205
-
206
- null_indices = torch.full(
207
- x.shape[:2], -1.0, device=device, dtype=torch.long
208
- )
209
-
210
- # go through the layers
211
-
212
- with autocast("cuda", enabled=False):
213
- for quantizer_index, (layer, scale) in enumerate(
214
- zip(self.layers, self.scales)
215
- ):
216
-
217
- if (
218
- should_quantize_dropout
219
- and quantizer_index > rand_quantize_dropout_index
220
- ):
221
- all_indices.append(null_indices)
222
- continue
223
-
224
- quantized, indices = layer(residual / scale)
225
-
226
- quantized = quantized * scale
227
-
228
- residual = residual - quantized.detach()
229
- quantized_out = quantized_out + quantized
230
-
231
- all_indices.append(indices)
232
-
233
- # project out, if needed
234
-
235
- quantized_out = self.project_out(quantized_out)
236
-
237
- # stack all indices
238
-
239
- all_indices = torch.stack(all_indices, dim=-1)
240
-
241
- # channel first out
242
-
243
- if self.is_channel_first:
244
- (quantized_out,) = unpack(quantized_out, ps, "b * d")
245
- (all_indices,) = unpack(all_indices, ps, "b * d")
246
-
247
- quantized_out = rearrange(quantized_out, "b ... d -> b d ...")
248
- all_indices = rearrange(all_indices, "b ... d -> b d ...")
249
-
250
- # return
251
-
252
- ret = (quantized_out, all_indices)
253
-
254
- if not return_all_codes:
255
- return ret
256
-
257
- # whether to return all codes from all codebooks across layers
258
-
259
- all_codes = self.get_codes_from_indices(all_indices)
260
-
261
- # will return all codes in shape (quantizer, batch, sequence length, codebook dimension)
262
-
263
- return (*ret, all_codes)
264
-
265
-
266
- # grouped residual fsq
267
-
268
-
269
- class GroupedResidualFSQ(Module):
270
- def __init__(self, *, dim, groups=1, accept_image_fmap=False, **kwargs):
271
- super().__init__()
272
- self.dim = dim
273
- self.groups = groups
274
- assert (dim % groups) == 0
275
- dim_per_group = dim // groups
276
-
277
- self.accept_image_fmap = accept_image_fmap
278
-
279
- self.rvqs = nn.ModuleList([])
280
-
281
- for _ in range(groups):
282
- self.rvqs.append(ResidualFSQ(dim=dim_per_group, **kwargs))
283
-
284
- self.codebook_size = self.rvqs[0].codebook_size
285
-
286
- @property
287
- def codebooks(self):
288
- return torch.stack(tuple(rvq.codebooks for rvq in self.rvqs))
289
-
290
- @property
291
- def split_dim(self):
292
- return 1 if self.accept_image_fmap else -1
293
-
294
- def get_codes_from_indices(self, indices):
295
- codes = tuple(
296
- rvq.get_codes_from_indices(chunk_indices)
297
- for rvq, chunk_indices in zip(self.rvqs, indices)
298
- )
299
- return torch.stack(codes)
300
-
301
- def get_output_from_indices(self, indices):
302
- outputs = tuple(
303
- rvq.get_output_from_indices(chunk_indices)
304
- for rvq, chunk_indices in zip(self.rvqs, indices)
305
- )
306
- return torch.cat(outputs, dim=self.split_dim)
307
-
308
- def forward(self, x, return_all_codes=False):
309
- shape, split_dim, device = x.shape, self.split_dim, x.device
310
- assert shape[split_dim] == self.dim
311
-
312
- # split the feature dimension into groups
313
-
314
- x = x.chunk(self.groups, dim=split_dim)
315
-
316
- forward_kwargs = dict(
317
- return_all_codes=return_all_codes,
318
- rand_quantize_dropout_fixed_seed=(
319
- get_maybe_sync_seed(device) if self.training else None
320
- ),
321
- )
322
-
323
- # invoke residual vq on each group
324
-
325
- out = tuple(rvq(chunk, **forward_kwargs) for rvq, chunk in zip(self.rvqs, x))
326
- out = tuple(zip(*out))
327
-
328
- # otherwise, get all the zipped outputs and combine them
329
-
330
- quantized, all_indices, *maybe_all_codes = out
331
-
332
- quantized = torch.cat(quantized, dim=split_dim)
333
- all_indices = torch.stack(all_indices)
334
-
335
- ret = (quantized, all_indices, *maybe_all_codes)
336
- return ret
337
-
338
-
339
- if __name__ == "__main__":
340
- model = ResidualFSQ(
341
- levels=[4, 4, 4, 4, 4, 4],
342
- num_quantizers=1,
343
- dim=30,
344
- is_channel_first=True,
345
- quantize_dropout=False,
346
- )
347
- x = torch.randn(2, 30, 10)
348
- quantize, embed_ind = model(x)
349
-
350
- emb_from_ind = model.get_output_from_indices(embed_ind.transpose(1, 2))
351
-
352
- print(quantize == emb_from_ind.transpose(1, 2))
353
-
354
- print("quantize shape", quantize.shape)
355
- print("embed_ind", embed_ind)
sparktts/modules/speaker/ecapa_tdnn.py DELETED
@@ -1,267 +0,0 @@
1
- # Copyright (c) 2021 Zhengyang Chen ([email protected])
2
- # 2022 Hongji Wang ([email protected])
3
- # 2023 Bing Han ([email protected])
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
-
17
- """ This implementation is adapted from github repo:
18
- https://github.com/lawlict/ECAPA-TDNN.
19
- """
20
-
21
- import torch
22
- import torch.nn as nn
23
- import torch.nn.functional as F
24
-
25
- import sparktts.modules.speaker.pooling_layers as pooling_layers
26
-
27
-
28
- class Res2Conv1dReluBn(nn.Module):
29
- """
30
- in_channels == out_channels == channels
31
- """
32
-
33
- def __init__(
34
- self,
35
- channels,
36
- kernel_size=1,
37
- stride=1,
38
- padding=0,
39
- dilation=1,
40
- bias=True,
41
- scale=4,
42
- ):
43
- super().__init__()
44
- assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
45
- self.scale = scale
46
- self.width = channels // scale
47
- self.nums = scale if scale == 1 else scale - 1
48
-
49
- self.convs = []
50
- self.bns = []
51
- for i in range(self.nums):
52
- self.convs.append(
53
- nn.Conv1d(
54
- self.width,
55
- self.width,
56
- kernel_size,
57
- stride,
58
- padding,
59
- dilation,
60
- bias=bias,
61
- )
62
- )
63
- self.bns.append(nn.BatchNorm1d(self.width))
64
- self.convs = nn.ModuleList(self.convs)
65
- self.bns = nn.ModuleList(self.bns)
66
-
67
- def forward(self, x):
68
- out = []
69
- spx = torch.split(x, self.width, 1)
70
- sp = spx[0]
71
- for i, (conv, bn) in enumerate(zip(self.convs, self.bns)):
72
- # Order: conv -> relu -> bn
73
- if i >= 1:
74
- sp = sp + spx[i]
75
- sp = conv(sp)
76
- sp = bn(F.relu(sp))
77
- out.append(sp)
78
- if self.scale != 1:
79
- out.append(spx[self.nums])
80
- out = torch.cat(out, dim=1)
81
-
82
- return out
83
-
84
-
85
- """ Conv1d + BatchNorm1d + ReLU
86
- """
87
-
88
-
89
- class Conv1dReluBn(nn.Module):
90
-
91
- def __init__(
92
- self,
93
- in_channels,
94
- out_channels,
95
- kernel_size=1,
96
- stride=1,
97
- padding=0,
98
- dilation=1,
99
- bias=True,
100
- ):
101
- super().__init__()
102
- self.conv = nn.Conv1d(
103
- in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias
104
- )
105
- self.bn = nn.BatchNorm1d(out_channels)
106
-
107
- def forward(self, x):
108
- return self.bn(F.relu(self.conv(x)))
109
-
110
-
111
- """ The SE connection of 1D case.
112
- """
113
-
114
-
115
- class SE_Connect(nn.Module):
116
-
117
- def __init__(self, channels, se_bottleneck_dim=128):
118
- super().__init__()
119
- self.linear1 = nn.Linear(channels, se_bottleneck_dim)
120
- self.linear2 = nn.Linear(se_bottleneck_dim, channels)
121
-
122
- def forward(self, x):
123
- out = x.mean(dim=2)
124
- out = F.relu(self.linear1(out))
125
- out = torch.sigmoid(self.linear2(out))
126
- out = x * out.unsqueeze(2)
127
-
128
- return out
129
-
130
-
131
- """ SE-Res2Block of the ECAPA-TDNN architecture.
132
- """
133
-
134
-
135
- class SE_Res2Block(nn.Module):
136
-
137
- def __init__(self, channels, kernel_size, stride, padding, dilation, scale):
138
- super().__init__()
139
- self.se_res2block = nn.Sequential(
140
- Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0),
141
- Res2Conv1dReluBn(
142
- channels, kernel_size, stride, padding, dilation, scale=scale
143
- ),
144
- Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0),
145
- SE_Connect(channels),
146
- )
147
-
148
- def forward(self, x):
149
- return x + self.se_res2block(x)
150
-
151
-
152
- class ECAPA_TDNN(nn.Module):
153
-
154
- def __init__(
155
- self,
156
- channels=512,
157
- feat_dim=80,
158
- embed_dim=192,
159
- pooling_func="ASTP",
160
- global_context_att=False,
161
- emb_bn=False,
162
- ):
163
- super().__init__()
164
-
165
- self.layer1 = Conv1dReluBn(feat_dim, channels, kernel_size=5, padding=2)
166
- self.layer2 = SE_Res2Block(
167
- channels, kernel_size=3, stride=1, padding=2, dilation=2, scale=8
168
- )
169
- self.layer3 = SE_Res2Block(
170
- channels, kernel_size=3, stride=1, padding=3, dilation=3, scale=8
171
- )
172
- self.layer4 = SE_Res2Block(
173
- channels, kernel_size=3, stride=1, padding=4, dilation=4, scale=8
174
- )
175
-
176
- cat_channels = channels * 3
177
- out_channels = 512 * 3
178
- self.conv = nn.Conv1d(cat_channels, out_channels, kernel_size=1)
179
- self.pool = getattr(pooling_layers, pooling_func)(
180
- in_dim=out_channels, global_context_att=global_context_att
181
- )
182
- self.pool_out_dim = self.pool.get_out_dim()
183
- self.bn = nn.BatchNorm1d(self.pool_out_dim)
184
- self.linear = nn.Linear(self.pool_out_dim, embed_dim)
185
- self.emb_bn = emb_bn
186
- if emb_bn: # better in SSL for SV
187
- self.bn2 = nn.BatchNorm1d(embed_dim)
188
- else:
189
- self.bn2 = nn.Identity()
190
-
191
- def forward(self, x, return_latent=False):
192
- x = x.permute(0, 2, 1) # (B,T,F) -> (B,F,T)
193
-
194
- out1 = self.layer1(x)
195
- out2 = self.layer2(out1)
196
- out3 = self.layer3(out2)
197
- out4 = self.layer4(out3)
198
-
199
- out = torch.cat([out2, out3, out4], dim=1)
200
- latent = F.relu(self.conv(out))
201
- out = self.bn(self.pool(latent))
202
- out = self.linear(out)
203
- if self.emb_bn:
204
- out = self.bn2(out)
205
-
206
- if return_latent:
207
- return out, latent
208
- return out
209
-
210
-
211
- def ECAPA_TDNN_c1024(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False):
212
- return ECAPA_TDNN(
213
- channels=1024,
214
- feat_dim=feat_dim,
215
- embed_dim=embed_dim,
216
- pooling_func=pooling_func,
217
- emb_bn=emb_bn,
218
- )
219
-
220
-
221
- def ECAPA_TDNN_GLOB_c1024(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False):
222
- return ECAPA_TDNN(
223
- channels=1024,
224
- feat_dim=feat_dim,
225
- embed_dim=embed_dim,
226
- pooling_func=pooling_func,
227
- global_context_att=True,
228
- emb_bn=emb_bn,
229
- )
230
-
231
-
232
- def ECAPA_TDNN_c512(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False):
233
- return ECAPA_TDNN(
234
- channels=512,
235
- feat_dim=feat_dim,
236
- embed_dim=embed_dim,
237
- pooling_func=pooling_func,
238
- emb_bn=emb_bn,
239
- )
240
-
241
-
242
- def ECAPA_TDNN_GLOB_c512(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False):
243
- return ECAPA_TDNN(
244
- channels=512,
245
- feat_dim=feat_dim,
246
- embed_dim=embed_dim,
247
- pooling_func=pooling_func,
248
- global_context_att=True,
249
- emb_bn=emb_bn,
250
- )
251
-
252
-
253
- if __name__ == "__main__":
254
- x = torch.zeros(1, 200, 100)
255
- model = ECAPA_TDNN_GLOB_c512(feat_dim=100, embed_dim=256, pooling_func="ASTP")
256
- model.eval()
257
- out, latent = model(x, True)
258
- print(out.shape)
259
- print(latent.shape)
260
-
261
- num_params = sum(param.numel() for param in model.parameters())
262
- print("{} M".format(num_params / 1e6))
263
-
264
- # from thop import profile
265
- # x_np = torch.randn(1, 200, 80)
266
- # flops, params = profile(model, inputs=(x_np, ))
267
- # print("FLOPs: {} G, Params: {} M".format(flops / 1e9, params / 1e6))
sparktts/modules/speaker/perceiver_encoder.py DELETED
@@ -1,360 +0,0 @@
1
- # Copyright (c) 2025 SparkAudio
2
- # 2025 Xinsheng Wang ([email protected])
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- # Adapted from https://github.com/lucidrains/naturalspeech2-pytorch/blob/659bec7f7543e7747e809e950cc2f84242fbeec7/naturalspeech2_pytorch/naturalspeech2_pytorch.py#L532
17
-
18
- from collections import namedtuple
19
- from functools import wraps
20
-
21
- import torch
22
- import torch.nn.functional as F
23
- from einops import rearrange, repeat
24
- from einops.layers.torch import Rearrange
25
- from packaging import version
26
- from torch import einsum, nn
27
-
28
-
29
- def exists(val):
30
- return val is not None
31
-
32
-
33
- def once(fn):
34
- called = False
35
-
36
- @wraps(fn)
37
- def inner(x):
38
- nonlocal called
39
- if called:
40
- return
41
- called = True
42
- return fn(x)
43
-
44
- return inner
45
-
46
-
47
- print_once = once(print)
48
-
49
- # main class
50
-
51
-
52
- class Attend(nn.Module):
53
- def __init__(self, dropout=0.0, causal=False, use_flash=False):
54
- super().__init__()
55
- self.dropout = dropout
56
- self.attn_dropout = nn.Dropout(dropout)
57
-
58
- self.causal = causal
59
- self.register_buffer("mask", None, persistent=False)
60
-
61
- self.use_flash = use_flash
62
- assert not (
63
- use_flash and version.parse(torch.__version__) < version.parse("2.0.0")
64
- ), "in order to use flash attention, you must be using pytorch 2.0 or above"
65
-
66
- # determine efficient attention configs for cuda and cpu
67
- self.config = namedtuple(
68
- "EfficientAttentionConfig",
69
- ["enable_flash", "enable_math", "enable_mem_efficient"],
70
- )
71
- self.cpu_config = self.config(True, True, True)
72
- self.cuda_config = None
73
-
74
- if not torch.cuda.is_available() or not use_flash:
75
- return
76
-
77
- device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
78
-
79
- if device_properties.major == 8 and device_properties.minor == 0:
80
- print_once(
81
- "A100 GPU detected, using flash attention if input tensor is on cuda"
82
- )
83
- self.cuda_config = self.config(True, False, False)
84
- else:
85
- print_once(
86
- "Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda"
87
- )
88
- self.cuda_config = self.config(False, True, True)
89
-
90
- def get_mask(self, n, device):
91
- if exists(self.mask) and self.mask.shape[-1] >= n:
92
- return self.mask[:n, :n]
93
-
94
- mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
95
- self.register_buffer("mask", mask, persistent=False)
96
- return mask
97
-
98
- def flash_attn(self, q, k, v, mask=None):
99
- _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda
100
-
101
- # Recommended for multi-query single-key-value attention by Tri Dao
102
- # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
103
-
104
- if k.ndim == 3:
105
- k = rearrange(k, "b ... -> b 1 ...").expand_as(q)
106
-
107
- if v.ndim == 3:
108
- v = rearrange(v, "b ... -> b 1 ...").expand_as(q)
109
-
110
- # Check if mask exists and expand to compatible shape
111
- # The mask is B L, so it would have to be expanded to B H N L
112
-
113
- if exists(mask):
114
- mask = rearrange(mask, "b j -> b 1 1 j")
115
- mask = mask.expand(-1, heads, q_len, -1)
116
-
117
- # Check if there is a compatible device for flash attention
118
-
119
- config = self.cuda_config if is_cuda else self.cpu_config
120
-
121
- # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
122
-
123
- with torch.backends.cuda.sdp_kernel(**config._asdict()):
124
- out = F.scaled_dot_product_attention(
125
- q,
126
- k,
127
- v,
128
- attn_mask=mask,
129
- dropout_p=self.dropout if self.training else 0.0,
130
- is_causal=self.causal,
131
- )
132
-
133
- return out
134
-
135
- def forward(self, q, k, v, mask=None):
136
- """
137
- einstein notation
138
- b - batch
139
- h - heads
140
- n, i, j - sequence length (base sequence length, source, target)
141
- d - feature dimension
142
- """
143
-
144
- n, device = q.shape[-2], q.device
145
-
146
- scale = q.shape[-1] ** -0.5
147
-
148
- if self.use_flash:
149
- return self.flash_attn(q, k, v, mask=mask)
150
-
151
- kv_einsum_eq = "b j d" if k.ndim == 3 else "b h j d"
152
-
153
- # similarity
154
-
155
- sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale
156
-
157
- # key padding mask
158
-
159
- if exists(mask):
160
- mask = rearrange(mask, "b j -> b 1 1 j")
161
- sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
162
-
163
- # causal mask
164
-
165
- if self.causal:
166
- causal_mask = self.get_mask(n, device)
167
- sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
168
-
169
- # attention
170
-
171
- attn = sim.softmax(dim=-1)
172
- attn = self.attn_dropout(attn)
173
-
174
- # aggregate values
175
-
176
- out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v)
177
-
178
- return out
179
-
180
-
181
- def Sequential(*mods):
182
- return nn.Sequential(*filter(exists, mods))
183
-
184
-
185
- def exists(x):
186
- return x is not None
187
-
188
-
189
- def default(val, d):
190
- if exists(val):
191
- return val
192
- return d() if callable(d) else d
193
-
194
-
195
- class RMSNorm(nn.Module):
196
- def __init__(self, dim, scale=True, dim_cond=None):
197
- super().__init__()
198
- self.cond = exists(dim_cond)
199
- self.to_gamma_beta = nn.Linear(dim_cond, dim * 2) if self.cond else None
200
-
201
- self.scale = dim**0.5
202
- self.gamma = nn.Parameter(torch.ones(dim)) if scale else None
203
-
204
- def forward(self, x, cond=None):
205
- gamma = default(self.gamma, 1)
206
- out = F.normalize(x, dim=-1) * self.scale * gamma
207
-
208
- if not self.cond:
209
- return out
210
-
211
- assert exists(cond)
212
- gamma, beta = self.to_gamma_beta(cond).chunk(2, dim=-1)
213
- gamma, beta = map(lambda t: rearrange(t, "b d -> b 1 d"), (gamma, beta))
214
- return out * gamma + beta
215
-
216
-
217
- class CausalConv1d(nn.Conv1d):
218
- def __init__(self, *args, **kwargs):
219
- super().__init__(*args, **kwargs)
220
- (kernel_size,) = self.kernel_size
221
- (dilation,) = self.dilation
222
- (stride,) = self.stride
223
-
224
- assert stride == 1
225
- self.causal_padding = dilation * (kernel_size - 1)
226
-
227
- def forward(self, x):
228
- causal_padded_x = F.pad(x, (self.causal_padding, 0), value=0.0)
229
- return super().forward(causal_padded_x)
230
-
231
-
232
- class GEGLU(nn.Module):
233
- def forward(self, x):
234
- x, gate = x.chunk(2, dim=-1)
235
- return F.gelu(gate) * x
236
-
237
-
238
- def FeedForward(dim, mult=4, causal_conv=False):
239
- dim_inner = int(dim * mult * 2 / 3)
240
-
241
- conv = None
242
- if causal_conv:
243
- conv = nn.Sequential(
244
- Rearrange("b n d -> b d n"),
245
- CausalConv1d(dim_inner, dim_inner, 3),
246
- Rearrange("b d n -> b n d"),
247
- )
248
-
249
- return Sequential(
250
- nn.Linear(dim, dim_inner * 2), GEGLU(), conv, nn.Linear(dim_inner, dim)
251
- )
252
-
253
-
254
- class Attention(nn.Module):
255
- def __init__(
256
- self,
257
- dim,
258
- *,
259
- dim_context=None,
260
- causal=False,
261
- dim_head=64,
262
- heads=8,
263
- dropout=0.0,
264
- use_flash=False,
265
- cross_attn_include_queries=False,
266
- ):
267
- super().__init__()
268
- self.scale = dim_head**-0.5
269
- self.heads = heads
270
- self.cross_attn_include_queries = cross_attn_include_queries
271
-
272
- dim_inner = dim_head * heads
273
- dim_context = default(dim_context, dim)
274
-
275
- self.attend = Attend(causal=causal, dropout=dropout, use_flash=use_flash)
276
- self.to_q = nn.Linear(dim, dim_inner, bias=False)
277
- self.to_kv = nn.Linear(dim_context, dim_inner * 2, bias=False)
278
- self.to_out = nn.Linear(dim_inner, dim, bias=False)
279
-
280
- def forward(self, x, context=None, mask=None):
281
- h, has_context = self.heads, exists(context)
282
-
283
- context = default(context, x)
284
-
285
- if has_context and self.cross_attn_include_queries:
286
- context = torch.cat((x, context), dim=-2)
287
-
288
- q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim=-1))
289
- q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
290
-
291
- out = self.attend(q, k, v, mask=mask)
292
-
293
- out = rearrange(out, "b h n d -> b n (h d)")
294
- return self.to_out(out)
295
-
296
-
297
- class PerceiverResampler(nn.Module):
298
- def __init__(
299
- self,
300
- *,
301
- dim,
302
- depth=2,
303
- dim_context=None,
304
- num_latents=32,
305
- dim_head=64,
306
- heads=8,
307
- ff_mult=4,
308
- use_flash_attn=False,
309
- ):
310
- super().__init__()
311
- dim_context = default(dim_context, dim)
312
-
313
- self.proj_context = (
314
- nn.Linear(dim_context, dim) if dim_context != dim else nn.Identity()
315
- )
316
-
317
- self.latents = nn.Parameter(torch.randn(num_latents, dim))
318
- nn.init.normal_(self.latents, std=0.02)
319
-
320
- self.layers = nn.ModuleList([])
321
- for _ in range(depth):
322
- self.layers.append(
323
- nn.ModuleList(
324
- [
325
- Attention(
326
- dim=dim,
327
- dim_head=dim_head,
328
- heads=heads,
329
- use_flash=use_flash_attn,
330
- cross_attn_include_queries=True,
331
- ),
332
- FeedForward(dim=dim, mult=ff_mult),
333
- ]
334
- )
335
- )
336
-
337
- self.norm = RMSNorm(dim)
338
-
339
- def forward(self, x, mask=None):
340
- batch = x.shape[0]
341
-
342
- x = self.proj_context(x)
343
-
344
- latents = repeat(self.latents, "n d -> b n d", b=batch)
345
-
346
- for attn, ff in self.layers:
347
- latents = attn(latents, x, mask=mask) + latents
348
- latents = ff(latents) + latents
349
-
350
- return self.norm(latents)
351
-
352
-
353
- if __name__ == "__main__":
354
- model = PerceiverResampler(dim=256, dim_context=80)
355
- x = torch.randn(8, 200, 80)
356
- out = model(x)
357
- print(out.shape) # [8, 32, 80]
358
-
359
- num_params = sum(param.numel() for param in model.parameters())
360
- print("{} M".format(num_params / 1e6))
sparktts/modules/speaker/pooling_layers.py DELETED
@@ -1,298 +0,0 @@
1
- # Copyright (c) 2021 Shuai Wang ([email protected])
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- """
15
- Pooling functions to aggregate frame-level deep features
16
- into segment-level speaker embeddings
17
-
18
- High-order statistics are surprisingly effective, TSDP acts similarly to TSTP,
19
- even though we remove the mean statistic, on Voxceleb.
20
- """
21
-
22
- import torch
23
- import torch.nn as nn
24
- import torch.nn.functional as F
25
-
26
-
27
- class TAP(nn.Module):
28
- """
29
- Temporal average pooling, only first-order mean is considered
30
- """
31
-
32
- def __init__(self, in_dim=0, **kwargs):
33
- super(TAP, self).__init__()
34
- self.in_dim = in_dim
35
-
36
- def forward(self, x):
37
- pooling_mean = x.mean(dim=-1)
38
- # To be compatible with 2D input
39
- pooling_mean = pooling_mean.flatten(start_dim=1)
40
- return pooling_mean
41
-
42
- def get_out_dim(self):
43
- self.out_dim = self.in_dim
44
- return self.out_dim
45
-
46
-
47
- class TSDP(nn.Module):
48
- """
49
- Temporal standard deviation pooling, only second-order std is considered
50
- """
51
-
52
- def __init__(self, in_dim=0, **kwargs):
53
- super(TSDP, self).__init__()
54
- self.in_dim = in_dim
55
-
56
- def forward(self, x):
57
- # The last dimension is the temporal axis
58
- pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7)
59
- pooling_std = pooling_std.flatten(start_dim=1)
60
- return pooling_std
61
-
62
- def get_out_dim(self):
63
- self.out_dim = self.in_dim
64
- return self.out_dim
65
-
66
-
67
- class TSTP(nn.Module):
68
- """
69
- Temporal statistics pooling, concatenate mean and std, which is used in
70
- x-vector
71
- Comment: simple concatenation can not make full use of both statistics
72
- """
73
-
74
- def __init__(self, in_dim=0, **kwargs):
75
- super(TSTP, self).__init__()
76
- self.in_dim = in_dim
77
-
78
- def forward(self, x):
79
- # The last dimension is the temporal axis
80
- pooling_mean = x.mean(dim=-1)
81
- pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7)
82
- pooling_mean = pooling_mean.flatten(start_dim=1)
83
- pooling_std = pooling_std.flatten(start_dim=1)
84
- stats = torch.cat((pooling_mean, pooling_std), 1)
85
- return stats
86
-
87
- def get_out_dim(self):
88
- self.out_dim = self.in_dim * 2
89
- return self.out_dim
90
-
91
-
92
- class ASTP(nn.Module):
93
- """ Attentive statistics pooling: Channel- and context-dependent
94
- statistics pooling, first used in ECAPA_TDNN.
95
- """
96
-
97
- def __init__(self,
98
- in_dim,
99
- bottleneck_dim=128,
100
- global_context_att=False,
101
- **kwargs):
102
- super(ASTP, self).__init__()
103
- self.in_dim = in_dim
104
- self.global_context_att = global_context_att
105
-
106
- # Use Conv1d with stride == 1 rather than Linear, then we don't
107
- # need to transpose inputs.
108
- if global_context_att:
109
- self.linear1 = nn.Conv1d(
110
- in_dim * 3, bottleneck_dim,
111
- kernel_size=1) # equals W and b in the paper
112
- else:
113
- self.linear1 = nn.Conv1d(
114
- in_dim, bottleneck_dim,
115
- kernel_size=1) # equals W and b in the paper
116
- self.linear2 = nn.Conv1d(bottleneck_dim, in_dim,
117
- kernel_size=1) # equals V and k in the paper
118
-
119
- def forward(self, x):
120
- """
121
- x: a 3-dimensional tensor in tdnn-based architecture (B,F,T)
122
- or a 4-dimensional tensor in resnet architecture (B,C,F,T)
123
- 0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
124
- """
125
- if len(x.shape) == 4:
126
- x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3])
127
- assert len(x.shape) == 3
128
-
129
- if self.global_context_att:
130
- context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
131
- context_std = torch.sqrt(
132
- torch.var(x, dim=-1, keepdim=True) + 1e-7).expand_as(x)
133
- x_in = torch.cat((x, context_mean, context_std), dim=1)
134
- else:
135
- x_in = x
136
-
137
- # DON'T use ReLU here! ReLU may be hard to converge.
138
- alpha = torch.tanh(
139
- self.linear1(x_in)) # alpha = F.relu(self.linear1(x_in))
140
- alpha = torch.softmax(self.linear2(alpha), dim=2)
141
- mean = torch.sum(alpha * x, dim=2)
142
- var = torch.sum(alpha * (x**2), dim=2) - mean**2
143
- std = torch.sqrt(var.clamp(min=1e-7))
144
- return torch.cat([mean, std], dim=1)
145
-
146
- def get_out_dim(self):
147
- self.out_dim = 2 * self.in_dim
148
- return self.out_dim
149
-
150
-
151
- class MHASTP(torch.nn.Module):
152
- """ Multi head attentive statistics pooling
153
- Reference:
154
- Self Multi-Head Attention for Speaker Recognition
155
- https://arxiv.org/pdf/1906.09890.pdf
156
- """
157
-
158
- def __init__(self,
159
- in_dim,
160
- layer_num=2,
161
- head_num=2,
162
- d_s=1,
163
- bottleneck_dim=64,
164
- **kwargs):
165
- super(MHASTP, self).__init__()
166
- assert (in_dim % head_num
167
- ) == 0 # make sure that head num can be divided by input_dim
168
- self.in_dim = in_dim
169
- self.head_num = head_num
170
- d_model = int(in_dim / head_num)
171
- channel_dims = [bottleneck_dim for i in range(layer_num + 1)]
172
- if d_s > 1:
173
- d_s = d_model
174
- else:
175
- d_s = 1
176
- self.d_s = d_s
177
- channel_dims[0], channel_dims[-1] = d_model, d_s
178
- heads_att_trans = []
179
- for i in range(self.head_num):
180
- att_trans = nn.Sequential()
181
- for i in range(layer_num - 1):
182
- att_trans.add_module(
183
- 'att_' + str(i),
184
- nn.Conv1d(channel_dims[i], channel_dims[i + 1], 1, 1))
185
- att_trans.add_module('tanh' + str(i), nn.Tanh())
186
- att_trans.add_module(
187
- 'att_' + str(layer_num - 1),
188
- nn.Conv1d(channel_dims[layer_num - 1], channel_dims[layer_num],
189
- 1, 1))
190
- heads_att_trans.append(att_trans)
191
- self.heads_att_trans = nn.ModuleList(heads_att_trans)
192
-
193
- def forward(self, input):
194
- """
195
- input: a 3-dimensional tensor in xvector architecture
196
- or a 4-dimensional tensor in resnet architecture
197
- 0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
198
- """
199
- if len(input.shape) == 4: # B x F x T
200
- input = input.reshape(input.shape[0],
201
- input.shape[1] * input.shape[2],
202
- input.shape[3])
203
- assert len(input.shape) == 3
204
- bs, f_dim, t_dim = input.shape
205
- chunks = torch.chunk(input, self.head_num, 1)
206
- # split
207
- chunks_out = []
208
- # for i in range(self.head_num):
209
- # att_score = self.heads_att_trans[i](chunks[i])
210
- for i, layer in enumerate(self.heads_att_trans):
211
- att_score = layer(chunks[i])
212
- alpha = F.softmax(att_score, dim=-1)
213
- mean = torch.sum(alpha * chunks[i], dim=2)
214
- var = torch.sum(alpha * chunks[i]**2, dim=2) - mean**2
215
- std = torch.sqrt(var.clamp(min=1e-7))
216
- chunks_out.append(torch.cat((mean, std), dim=1))
217
- out = torch.cat(chunks_out, dim=1)
218
- return out
219
-
220
- def get_out_dim(self):
221
- self.out_dim = 2 * self.in_dim
222
- return self.out_dim
223
-
224
-
225
- class MQMHASTP(torch.nn.Module):
226
- """ An attentive pooling
227
- Reference:
228
- multi query multi head attentive statistics pooling
229
- https://arxiv.org/pdf/2110.05042.pdf
230
- Args:
231
- in_dim: the feature dimension of input
232
- layer_num: the number of layer in the pooling layer
233
- query_num: the number of querys
234
- head_num: the number of heads
235
- bottleneck_dim: the bottleneck dimension
236
-
237
- SA (H = 1, Q = 1, n = 2, d_s = 1) ref:
238
- https://www.danielpovey.com/files/2018_interspeech_xvector_attention.pdf
239
- MHA (H > 1, Q = 1, n = 1, d_s = 1) ref:
240
- https://arxiv.org/pdf/1906.09890.pdf
241
- AS (H = 1, Q > 1, n = 2, d_s = 1) ref:
242
- https://arxiv.org/pdf/1803.10963.pdf
243
- VSA (H = 1, Q > 1, n = 2, d_s = d_h) ref:
244
- http://www.interspeech2020.org/uploadfile/pdf/Mon-2-10-5.pdf
245
- """
246
-
247
- def __init__(self,
248
- in_dim,
249
- layer_num=2,
250
- query_num=2,
251
- head_num=8,
252
- d_s=2,
253
- bottleneck_dim=64,
254
- **kwargs):
255
- super(MQMHASTP, self).__init__()
256
- self.n_query = nn.ModuleList([
257
- MHASTP(in_dim,
258
- layer_num=layer_num,
259
- head_num=head_num,
260
- d_s=d_s,
261
- bottleneck_dim=bottleneck_dim) for i in range(query_num)
262
- ])
263
- self.query_num = query_num
264
- self.in_dim = in_dim
265
-
266
- def forward(self, input):
267
- """
268
- input: a 3-dimensional tensor in xvector architecture
269
- or a 4-dimensional tensor in resnet architecture
270
- 0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
271
- """
272
- if len(input.shape) == 4: # B x F x T
273
- input = input.reshape(input.shape[0],
274
- input.shape[1] * input.shape[2],
275
- input.shape[3])
276
- assert len(input.shape) == 3
277
- res = []
278
- for i, layer in enumerate(self.n_query):
279
- res.append(layer(input))
280
- out = torch.cat(res, dim=-1)
281
- return out
282
-
283
- def get_out_dim(self):
284
- self.out_dim = self.in_dim * 2 * self.query_num
285
- return self.out_dim
286
-
287
-
288
- if __name__ == '__main__':
289
- data = torch.randn(16, 512, 10, 35)
290
- # model = StatisticsPooling()
291
- model = MQMHASTP(512 * 10)
292
- model = MHASTP(512 * 10)
293
- model = MQMHASTP(512 * 10, context=False)
294
- print(model)
295
-
296
- out = model(data)
297
- print(out.shape)
298
- print(model.get_out_dim())
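
The demo above only exercises the multi-head variants; below is a minimal sketch of ASTP, the pooling that the ECAPA_TDNN in this commit instantiates (shapes are illustrative):

import torch
from sparktts.modules.speaker.pooling_layers import ASTP

pool = ASTP(in_dim=512 * 3, global_context_att=True)  # ECAPA_TDNN feeds it 512 * 3 channels
feats = torch.randn(4, 512 * 3, 200)                  # (batch, channels, frames)
emb = pool(feats)                                     # attention-weighted mean and std, concatenated
print(emb.shape, pool.get_out_dim())                  # torch.Size([4, 3072]) 3072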
sparktts/modules/speaker/speaker_encoder.py DELETED
@@ -1,136 +0,0 @@
1
- # Copyright (c) 2025 SparkAudio
2
- # 2025 Xinsheng Wang ([email protected])
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import torch
17
- import torch.nn as nn
18
-
19
- from typing import List, Tuple
20
- from sparktts.modules.fsq.residual_fsq import ResidualFSQ
21
- from sparktts.modules.speaker.ecapa_tdnn import ECAPA_TDNN_GLOB_c512
22
- from sparktts.modules.speaker.perceiver_encoder import PerceiverResampler
23
-
24
- """
25
- x-vector + d-vector
26
- """
27
-
28
-
29
- class SpeakerEncoder(nn.Module):
30
- """
31
-
32
- Args:
33
- input_dim (int): acoustic feature dimension
34
- out_dim (int): output dimension of x-vector and d-vector
35
- latent_dim (int): latent dimension before quantization
36
- token_num (int): sequence length of speaker tokens
37
- fsq_levels (List[int]): number of levels for each quantizer
38
- fsq_num_quantizers (int): number of quantizers
39
-
40
- Return:
41
- speaker_embs: (B, T2, out_dim)
42
- """
43
-
44
- def __init__(
45
- self,
46
- input_dim: int = 100,
47
- out_dim: int = 512,
48
- latent_dim: int = 128,
49
- token_num: int = 32,
50
- fsq_levels: List[int] = [4, 4, 4, 4, 4, 4],
51
- fsq_num_quantizers: int = 1,
52
- ):
53
- super(SpeakerEncoder, self).__init__()
54
-
55
- self.speaker_encoder = ECAPA_TDNN_GLOB_c512(
56
- feat_dim=input_dim, embed_dim=out_dim
57
- )
58
- self.perceiver_sampler = PerceiverResampler(
59
- dim=latent_dim, dim_context=512 * 3, num_latents=token_num
60
- )
61
- self.quantizer = ResidualFSQ(
62
- levels=fsq_levels,
63
- num_quantizers=fsq_num_quantizers,
64
- dim=latent_dim,
65
- is_channel_first=True,
66
- quantize_dropout=False,
67
- )
68
-
69
- self.project = nn.Linear(latent_dim * token_num, out_dim)
70
-
71
- def get_codes_from_indices(self, indices: torch.Tensor) -> torch.Tensor:
72
- zq = self.quantizer.get_codes_from_indices(indices.transpose(1, 2))
73
- return zq.transpose(1, 2)
74
-
75
- def get_indices(self, mels: torch.Tensor) -> torch.Tensor:
76
- mels = mels.transpose(1, 2)
77
- x = self.perceiver_sampler(mels).transpose(1, 2)
78
- zq, indices = self.quantizer(x)
79
- return indices
80
-
81
- def forward(self, mels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
82
- """
83
- Args:
84
- mels: (B, D_mel, T1)
85
-
86
- Return:
87
- x_vector: (B, out_dim)
88
- d_vector: (B, out_dim)
89
- """
90
- # mels = mels.transpose(1,2)
91
-
92
- x_vector, features = self.speaker_encoder(mels, True)
93
- x = self.perceiver_sampler(features.transpose(1, 2)).transpose(1, 2)
94
- zq, indices = self.quantizer(x)  # zq: (B, latent_dim, T2)
95
- x = zq.reshape(zq.shape[0], -1)
96
- d_vector = self.project(x)
97
-
98
- return x_vector, d_vector
99
-
100
- def tokenize(self, mels: torch.Tensor) -> torch.Tensor:
101
- """tokenize the input mel spectrogram"""
102
- _, features = self.speaker_encoder(mels, True)
103
- x = self.perceiver_sampler(features.transpose(1, 2)).transpose(1, 2)
104
- zq, indices = self.quantizer(x)
105
- return indices
106
-
107
- def detokenize(self, indices: torch.Tensor) -> torch.Tensor:
108
- """detokenize the input indices to d-vector"""
109
- zq = self.quantizer.get_output_from_indices(indices.transpose(1, 2)).transpose(1, 2)
110
- x = zq.reshape(zq.shape[0], -1)
111
- d_vector = self.project(x)
112
- return d_vector
113
-
114
- if __name__ == "__main__":
115
- model = SpeakerEncoder(
116
- input_dim=100,
117
- latent_dim=128,
118
- token_num=32,
119
- fsq_levels=[4, 4, 4, 4, 4, 4],
120
- fsq_num_quantizers=1,
121
- )
122
- mel = torch.randn(8, 200, 100)
123
- x_vector, d_vector = model(mel)
124
- print("x-vector shape", x_vector.shape)
125
- print("d-vector shape", d_vector.shape)
126
-
127
- indices = model.tokenize(mel)
128
- print("indices shape", indices.shape)
129
- d_vector_post = model.detokenize(indices)
130
- print("d-vector shape", d_vector_post.shape)
131
- if torch.allclose(d_vector_post, d_vector):
132
- print("d-vector post and d-vector are the same")
133
- else:
134
- print("d-vector post and d-vector are different")
135
- num_params = sum(param.numel() for param in model.parameters())
136
- print("{} M".format(num_params / 1e6))
sparktts/modules/vq/factorized_vector_quantize.py DELETED
@@ -1,187 +0,0 @@
1
- # Copyright (c) 2025 SparkAudio
2
- # 2025 Xinsheng Wang ([email protected])
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- # Heavily based on https://github.com/lucidrains/vector-quantize-pytorch
17
-
18
-
19
- from typing import Any, Dict
20
-
21
- import torch
22
- import torch.nn as nn
23
- import torch.nn.functional as F
24
- from einops import rearrange
25
- from torch.nn.utils import weight_norm
26
-
27
-
28
- def WNConv1d(*args, **kwargs):
29
- return weight_norm(nn.Conv1d(*args, **kwargs))
30
-
31
-
32
- def ema_inplace(moving_avg, new, decay):
33
- moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
34
-
35
-
36
- class FactorizedVectorQuantize(nn.Module):
37
- def __init__(
38
- self,
39
- input_dim: int,
40
- codebook_size: int,
41
- codebook_dim: int,
42
- commitment: float,
43
- codebook_loss_weight: float = 1.0,
44
- decay: float = 0.99,
45
- threshold_ema_dead_code: float = 2,
46
- momentum: float = 0.99,
47
- **kwargs,
48
- ):
49
- super().__init__()
50
- self.input_dim = input_dim
51
- self.codebook_size = codebook_size
52
- self.codebook_dim = codebook_dim
53
- self.commitment = commitment
54
- self.codebook_loss_weight = codebook_loss_weight
55
- self.decay = decay
56
- self.threshold_ema_dead_code = threshold_ema_dead_code
57
- self.momentum = momentum
58
-
59
- if input_dim != self.codebook_dim:
60
- self.in_project = WNConv1d(input_dim, self.codebook_dim, kernel_size=1)
61
- self.out_project = WNConv1d(self.codebook_dim, input_dim, kernel_size=1)
62
-
63
- else:
64
- self.in_project = nn.Identity()
65
- self.out_project = nn.Identity()
66
-
67
- self.codebook = nn.Embedding(self.codebook_size, self.codebook_dim)
68
- self.register_buffer("cluster_size", torch.zeros(self.codebook_size))
69
-
70
- def forward(self, z: torch.Tensor) -> Dict[str, Any]:
71
- """Quantized the input tensor using a fixed codebook and returns
72
- the corresponding codebook vectors
73
-
74
- Parameters
75
- ----------
76
- z : Tensor[B x D x T]
77
-
78
- Returns
79
- -------
80
- Tensor[B x D x T]
81
- Quantized continuous representation of input
82
- Tensor[1]
83
- Commitment loss to train encoder to predict vectors closer to codebook
84
- entries
85
- Tensor[1]
86
- Codebook loss to update the codebook
87
- Tensor[B x T]
88
- Codebook indices (quantized discrete representation of input)
89
- Tensor[B x D x T]
90
- Projected latents (continuous representation of input before quantization)
91
- """
92
- # transpose since we use linear
93
-
94
- # Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim
95
- z_e = self.in_project(z)
96
- z_q, indices, dists = self.decode_latents(z_e)
97
-
98
- # track codebook usage statistics
99
- embed_onehot = F.one_hot(indices, self.codebook_size).type(z_e.dtype)
100
- avg_probs = torch.mean(embed_onehot.reshape(-1, self.codebook_size), dim=0)
101
- perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
102
-
103
- active_num = (embed_onehot.sum(0).sum(0) > 0).sum()
104
- if self.training:
105
- # We do the expiry of code at that point as buffers are in sync
106
- # and all the workers will take the same decision.
107
- ema_inplace(self.cluster_size, embed_onehot.sum(0).sum(0), self.decay)
108
- active_num = sum(self.cluster_size > self.threshold_ema_dead_code)
109
-
110
- if self.training:
111
- commit_loss = (
112
- F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
113
- * self.commitment
114
- )
115
-
116
- codebook_loss = (
117
- F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
118
- * self.codebook_loss_weight
119
- )
120
-
121
- else:
122
- commit_loss = torch.zeros(0, device=z.device)
123
- codebook_loss = torch.zeros(0, device=z.device)
124
-
125
- z_q = (
126
- z_e + (z_q - z_e).detach()
127
- ) # noop in forward pass, straight-through gradient estimator in backward pass
128
-
129
- z_q = self.out_project(z_q)
130
-
131
- vq_loss = (commit_loss + codebook_loss).mean()
132
-
133
- return {
134
- "z_q": z_q,
135
- "indices": indices,
136
- "dists": dists,
137
- "vq_loss": vq_loss,
138
- "perplexity": perplexity,
139
- "active_num": active_num.float(),
140
- }
141
-
142
- def vq2emb(self, vq, out_proj=True):
143
- emb = self.embed_code(vq)
144
- if out_proj:
145
- emb = self.out_project(emb)
146
- return emb
147
-
148
- def tokenize(self, z: torch.Tensor) -> torch.Tensor:
149
- """tokenize the input tensor"""
150
- z_e = self.in_project(z)
151
- _, indices, _ = self.decode_latents(z_e)
152
- return indices
153
-
154
- def detokenize(self, indices):
155
- """detokenize the input indices"""
156
- z_q = self.decode_code(indices)
157
- z_q = self.out_project(z_q)
158
- return z_q
159
-
160
- def get_emb(self):
161
- return self.codebook.weight
162
-
163
- def embed_code(self, embed_id):
164
- return F.embedding(embed_id, self.codebook.weight)
165
-
166
- def decode_code(self, embed_id):
167
- return self.embed_code(embed_id).transpose(1, 2)
168
-
169
- def decode_latents(self, latents):
170
- encodings = rearrange(latents, "b d t -> (b t) d")
171
- codebook = self.codebook.weight
172
-
173
- # L2 normalize encodings and codebook
174
- encodings = F.normalize(encodings)
175
- codebook = F.normalize(codebook)
176
-
177
- # Compute euclidean distance between encodings and codebook,
178
- # with L2 normalization, the distance is equal to cosine distance
179
- dist = (
180
- encodings.pow(2).sum(1, keepdim=True)
181
- - 2 * encodings @ codebook.t()
182
- + codebook.pow(2).sum(1, keepdim=True).t()
183
- )
184
- indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
185
- z_q = self.decode_code(indices)
186
-
187
- return z_q, indices, dist
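
A minimal tokenize/detokenize sketch for the quantizer above (constructor values are illustrative, not the released BiCodec configuration):

import torch
from sparktts.modules.vq.factorized_vector_quantize import FactorizedVectorQuantize

vq = FactorizedVectorQuantize(input_dim=256, codebook_size=1024, codebook_dim=8, commitment=0.25)
vq.eval()

z = torch.randn(2, 256, 100)      # (B, D, T) encoder output
tokens = vq.tokenize(z)           # (B, T) codebook indices
recon = vq.detokenize(tokens)     # (B, D, T) reconstruction from the codebook
print(tokens.shape, recon.shape)

During training, forward() additionally returns the commitment and codebook losses, perplexity, and active-code count in its output dict.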
sparktts/utils/__init__.py DELETED
File without changes
sparktts/utils/audio.py DELETED
@@ -1,271 +0,0 @@
1
- # Copyright (c) 2025 SparkAudio
2
- # 2025 Xinsheng Wang ([email protected])
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """
16
- Description:
17
- This script contains a collection of functions designed to handle various
18
- audio processing.
19
- """
20
-
21
- import random
22
- import soxr
23
- import soundfile
24
- import torch
25
- import torchaudio
26
- import numpy as np
27
-
28
- from pathlib import Path
29
- from typing import Tuple
30
- from numpy.lib.stride_tricks import sliding_window_view
31
-
32
-
33
- def audio_volume_normalize(audio: np.ndarray, coeff: float = 0.2) -> np.ndarray:
34
- """
35
- Normalize the volume of an audio signal.
36
-
37
- Parameters:
38
- audio (numpy array): Input audio signal array.
39
- coeff (float): Target coefficient for normalization, default is 0.2.
40
-
41
- Returns:
42
- numpy array: The volume-normalized audio signal.
43
- """
44
- # Sort the absolute values of the audio signal
45
- temp = np.sort(np.abs(audio))
46
-
47
- # If the maximum value is less than 0.1, scale the array to have a maximum of 0.1
48
- if temp[-1] < 0.1:
49
- scaling_factor = max(
50
- temp[-1], 1e-3
51
- ) # Prevent division by zero with a small constant
52
- audio = audio / scaling_factor * 0.1
53
-
54
- # Filter out values less than 0.01 from temp
55
- temp = temp[temp > 0.01]
56
- L = temp.shape[0] # Length of the filtered array
57
-
58
- # If there are fewer than or equal to 10 significant values, return the audio without further processing
59
- if L <= 10:
60
- return audio
61
-
62
- # Compute the average of the top 10% to 1% of values in temp
63
- volume = np.mean(temp[int(0.9 * L) : int(0.99 * L)])
64
-
65
- # Normalize the audio to the target coefficient level, clamping the scale factor between 0.1 and 10
66
- audio = audio * np.clip(coeff / volume, a_min=0.1, a_max=10)
67
-
68
- # Ensure the maximum absolute value in the audio does not exceed 1
69
- max_value = np.max(np.abs(audio))
70
- if max_value > 1:
71
- audio = audio / max_value
72
-
73
- return audio
74
-
75
-
76
- def load_audio(
77
- adfile: Path,
78
- sampling_rate: int = None,
79
- length: int = None,
80
- volume_normalize: bool = False,
81
- segment_duration: int = None,
82
- ) -> np.ndarray:
83
- r"""Load audio file with target sampling rate and lsength
84
-
85
- Args:
86
- adfile (Path): path to audio file.
87
- sampling_rate (int, optional): target sampling rate. Defaults to None.
88
- length (int, optional): target audio length. Defaults to None.
89
- volume_normalize (bool, optional): whether to perform volume normalization. Defaults to False.
90
- segment_duration (int): randomly select a segment with duration of {segment_duration}s.
91
- Defaults to None, which means the whole audio will be used.
92
-
93
- Returns:
94
- audio (np.ndarray): audio
95
- """
96
-
97
- audio, sr = soundfile.read(adfile)
98
- if len(audio.shape) > 1:
99
- audio = audio[:, 0]
100
-
101
- if sampling_rate is not None and sr != sampling_rate:
102
- audio = soxr.resample(audio, sr, sampling_rate, quality="VHQ")
103
- sr = sampling_rate
104
-
105
- if segment_duration is not None:
106
- seg_length = int(sr * segment_duration)
107
- audio = random_select_audio_segment(audio, seg_length)
108
-
109
- # Audio volume normalize
110
- if volume_normalize:
111
- audio = audio_volume_normalize(audio)
112
- # check the audio length
113
- if length is not None:
114
- assert abs(audio.shape[0] - length) < 1000
115
- if audio.shape[0] > length:
116
- audio = audio[:length]
117
- else:
118
- audio = np.pad(audio, (0, int(length - audio.shape[0])))
119
- return audio
120
-
121
-
122
- def random_select_audio_segment(audio: np.ndarray, length: int) -> np.ndarray:
123
- """get an audio segment given the length
124
-
125
- Args:
126
- audio (np.ndarray):
127
- length (int): audio length = sampling_rate * duration
128
- """
129
- if audio.shape[0] < length:
130
- audio = np.pad(audio, (0, int(length - audio.shape[0])))
131
- start_index = random.randint(0, audio.shape[0] - length)
132
- end_index = int(start_index + length)
133
-
134
- return audio[start_index:end_index]
135
-
136
-
137
- def audio_highpass_filter(audio, sample_rate, highpass_cutoff_freq):
138
- """apply highpass fileter to audio
139
-
140
- Args:
141
- audio (np.ndarray):
142
- sample_rate (int):
143
- highpass_cutoff_freq (int):
144
- """
145
-
146
- audio = torchaudio.functional.highpass_biquad(
147
- torch.from_numpy(audio), sample_rate, cutoff_freq=highpass_cutoff_freq
148
- )
149
- return audio.numpy()
150
-
151
-
152
- def stft(
153
- x: torch.Tensor,
154
- fft_size: int,
155
- hop_size: int,
156
- win_length: int,
157
- window: str,
158
- use_complex: bool = False,
159
- ) -> torch.Tensor:
160
- """Perform STFT and convert to magnitude spectrogram.
161
- Args:
162
- x (Tensor): Input signal tensor (B, T).
163
- fft_size (int): FFT size.
164
- hop_size (int): Hop size.
165
- win_length (int): Window length.
166
- window (str): Window function type.
167
- Returns:
168
- Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
169
- """
170
-
171
- x_stft = torch.stft(
172
- x, fft_size, hop_size, win_length, window.to(x.device), return_complex=True
173
- )
174
-
175
- # clamp is needed to avoid nan or inf
176
- if not use_complex:
177
- return torch.sqrt(
178
- torch.clamp(x_stft.real**2 + x_stft.imag**2, min=1e-7, max=1e3)
179
- ).transpose(2, 1)
180
- else:
181
- res = torch.cat([x_stft.real.unsqueeze(1), x_stft.imag.unsqueeze(1)], dim=1)
182
- res = res.transpose(2, 3) # [B, 2, T, F]
183
- return res
184
-
185
-
186
- def detect_speech_boundaries(
187
- wav: np.ndarray,
188
- sample_rate: int,
189
- window_duration: float = 0.1,
190
- energy_threshold: float = 0.01,
191
- margin_factor: int = 2
192
- ) -> Tuple[int, int]:
193
- """Detect the start and end points of speech in an audio signal using RMS energy.
194
-
195
- Args:
196
- wav: Input audio signal array with values in [-1, 1]
197
- sample_rate: Audio sample rate in Hz
198
- window_duration: Duration of detection window in seconds
199
- energy_threshold: RMS energy threshold for speech detection
200
- margin_factor: Factor to determine extra margin around detected boundaries
201
-
202
- Returns:
203
- tuple: (start_index, end_index) of speech segment
204
-
205
- Raises:
206
- ValueError: If the audio contains only silence
207
- """
208
- window_size = int(window_duration * sample_rate)
209
- margin = margin_factor * window_size
210
- step_size = window_size // 10
211
-
212
- # Create sliding windows using stride tricks to avoid loops
213
- windows = sliding_window_view(wav, window_size)[::step_size]
214
-
215
- # Calculate RMS energy for each window
216
- energy = np.sqrt(np.mean(windows ** 2, axis=1))
217
- speech_mask = energy >= energy_threshold
218
-
219
- if not np.any(speech_mask):
220
- raise ValueError("No speech detected in audio (only silence)")
221
-
222
- start = max(0, np.argmax(speech_mask) * step_size - margin)
223
- end = min(len(wav), (len(speech_mask) - 1 - np.argmax(speech_mask[::-1])) * step_size + margin)
224
-
225
- return start, end
226
-
227
-
228
- def remove_silence_on_both_ends(
229
- wav: np.ndarray,
230
- sample_rate: int,
231
- window_duration: float = 0.1,
232
- volume_threshold: float = 0.01
233
- ) -> np.ndarray:
234
- """Remove silence from both ends of an audio signal.
235
-
236
- Args:
237
- wav: Input audio signal array
238
- sample_rate: Audio sample rate in Hz
239
- window_duration: Duration of detection window in seconds
240
- volume_threshold: Amplitude threshold for silence detection
241
-
242
- Returns:
243
- np.ndarray: Audio signal with silence removed from both ends
244
-
245
- Raises:
246
- ValueError: If the audio contains only silence
247
- """
248
- start, end = detect_speech_boundaries(
249
- wav,
250
- sample_rate,
251
- window_duration,
252
- volume_threshold
253
- )
254
- return wav[start:end]
255
-
256
-
257
-
258
- def hertz_to_mel(pitch: float) -> float:
259
- """
260
- Converts a frequency from the Hertz scale to the Mel scale.
261
-
262
- Parameters:
263
- - pitch: float or ndarray
264
- Frequency in Hertz.
265
-
266
- Returns:
267
- - mel: float or ndarray
268
- Frequency in Mel scale.
269
- """
270
- mel = 2595 * np.log10(1 + pitch / 700)
271
- return mel
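
The deleted audio helpers above cover reference-clip padding, random segment selection, highpass filtering, magnitude STFT, RMS-based silence trimming, and the Hz-to-mel conversion `mel = 2595 * log10(1 + f / 700)`. For readers of this commit, the snippet below is a small self-contained sketch (not code from the repository) that reproduces the two simplest pieces on a synthetic signal; the `mel_to_hertz` inverse and the test tone are illustrative additions.

```python
# Standalone illustration: the O'Shaughnessy mel formula used by hertz_to_mel
# above (plus an inverse added here for illustration), and the RMS-energy
# boundary detection idea from detect_speech_boundaries.
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def hertz_to_mel(pitch):
    return 2595 * np.log10(1 + pitch / 700)

def mel_to_hertz(mel):
    # Inverse of the formula above; not part of the deleted file.
    return 700 * (10 ** (mel / 2595) - 1)

sample_rate = 16000
t = np.linspace(0, 1.0, sample_rate, endpoint=False)
tone = 0.5 * np.sin(2 * np.pi * 220 * t)
wav = np.concatenate([np.zeros(4000), tone, np.zeros(4000)])  # silence-padded signal

# RMS energy per 100 ms window, stepped at one tenth of the window size.
window_size = int(0.1 * sample_rate)
step = window_size // 10
windows = sliding_window_view(wav, window_size)[::step]
energy = np.sqrt(np.mean(windows ** 2, axis=1))
mask = energy >= 0.01
start = np.argmax(mask) * step                              # first energetic window
end = (len(mask) - 1 - np.argmax(mask[::-1])) * step        # last energetic window
print(f"trimmed span: [{start}, {end}) of {len(wav)} samples")
print(f"220 Hz -> {hertz_to_mel(220):.1f} mel -> {mel_to_hertz(hertz_to_mel(220)):.1f} Hz")
```
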
sparktts/utils/file.py DELETED
@@ -1,221 +0,0 @@
- # Copyright (c) 2025 SparkAudio
- #               2025 Xinsheng Wang ([email protected])
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """
- Description:
-     This script contains a collection of functions designed to handle various
-     file reading and writing operations. It provides utilities to read from files,
-     write data to files, and perform file manipulation tasks.
- """
-
-
- import os
- import json
- import csv
-
- from tqdm import tqdm
- from typing import List, Dict, Any, Set, Union
- from pathlib import Path
- from omegaconf import OmegaConf, DictConfig
-
-
- def resolve_symbolic_link(symbolic_link_path: Path) -> Path:
-     """
-     Resolves the absolute path of a symbolic link.
-
-     Args:
-         symbolic_link_path (Path): The path to the symbolic link.
-
-     Returns:
-         Path: The absolute path that the symbolic link points to.
-     """
-
-     link_directory = os.path.dirname(symbolic_link_path)
-     target_path_relative = os.readlink(symbolic_link_path)
-     return os.path.join(link_directory, target_path_relative)
-
-
- def write_jsonl(metadata: List[dict], file_path: Path) -> None:
-     """Writes a list of dictionaries to a JSONL file.
-
-     Args:
-         metadata : List[dict]
-             A list of dictionaries, each representing a piece of metadata.
-         file_path : Path
-             The file path to save the JSONL file.
-
-     This function writes each dictionary in the list to a new line in the specified file.
-     """
-     with open(file_path, "w", encoding="utf-8") as f:
-         for meta in tqdm(metadata, desc="writing jsonl"):
-             # Convert dictionary to JSON string and write it to the file with a newline
-             json_str = json.dumps(meta, ensure_ascii=False) + "\n"
-             f.write(json_str)
-     print(f"jsonl saved to {file_path}")
-
-
- def read_jsonl(file_path: Path) -> List[dict]:
-     """
-     Reads a JSONL file and returns a list of dictionaries.
-
-     Args:
-         file_path : Path
-             The path to the JSONL file to be read.
-
-     Returns:
-         List[dict]
-             A list of dictionaries parsed from each line of the JSONL file.
-     """
-     metadata = []
-     # Open the file for reading
-     with open(file_path, "r", encoding="utf-8") as f:
-         # Split the file into lines
-         lines = f.read().splitlines()
-     # Process each line
-     for line in lines:
-         # Convert JSON string back to dictionary and append to list
-         meta = json.loads(line)
-         metadata.append(meta)
-     # Return the list of metadata
-     return metadata
-
-
- def read_json_as_jsonl(file_path: Path) -> List[dict]:
-     """Read a JSON file keyed by id and return its records as a list of dicts, sorted by key."""
-     metadata = []
-     with open(file_path, 'r', encoding='utf-8') as infile:
-         data = json.load(infile)
-     for k in sorted(data.keys()):
-         meta = {'index': k}
-         meta.update(data[k])
-         metadata.append(meta)
-     return metadata
-
-
- def decode_unicode_strings(meta: Dict[str, Any]) -> Dict[str, Any]:
-     """Decode escaped unicode sequences in all string values of a metadata dict."""
-     processed_meta = {}
-     for k, v in meta.items():
-         if isinstance(v, str):
-             processed_meta[k] = v.encode("utf-8").decode("unicode_escape")
-         else:
-             processed_meta[k] = v
-     return processed_meta
-
-
- def load_config(config_path: Path) -> DictConfig:
-     """Loads a configuration file and optionally merges it with a base configuration.
-
-     Args:
-         config_path (Path): Path to the configuration file.
-     """
-     # Load the initial configuration from the given path
-     config = OmegaConf.load(config_path)
-
-     # Check if there is a base configuration specified and merge if necessary
-     if config.get("base_config", None) is not None:
-         base_config = OmegaConf.load(config["base_config"])
-         config = OmegaConf.merge(base_config, config)
-
-     return config
-
-
- def jsonl_to_csv(jsonl_file_path: str, csv_file_path: str) -> None:
-     """
-     Converts a JSONL file to a CSV file.
-
-     This function reads a JSONL file, determines all unique keys present in the file,
-     and writes the data to a CSV file with columns for all these keys.
-     """
-     all_keys = set()
-     data_rows = []
-
-     # Read the JSONL file once to extract keys and collect data
-     with open(jsonl_file_path, 'r') as file:
-         for line in file:
-             data = json.loads(line.strip())
-             data_rows.append(data)
-             all_keys.update(data.keys())
-
-     # Convert the set of keys to a sorted list for consistent column order
-     sorted_keys = sorted(all_keys)
-
-     # Write the data to a CSV file
-     with open(csv_file_path, 'w', newline='') as csvfile:
-         writer = csv.DictWriter(csvfile, fieldnames=sorted_keys)
-
-         # Write the header row
-         writer.writeheader()
-
-         # Write each row of data
-         for data in data_rows:
-             writer.writerow(data)
-
-     print(f"CSV file has been created at {csv_file_path}")
-
-
- def save_metadata(data, filename, headers=None):
-     """
-     Save metadata to a file.
-
-     Args:
-         data (list of dict): Metadata to be saved.
-         filename (str): Name of the file to save the metadata.
-         headers (list of str): The order of column names to be saved; defaults to the keys from the first dictionary in data if not provided.
-     """
-     # Set headers to keys from the first dictionary in data if not explicitly provided
-     if headers is None:
-         headers = list(data[0].keys())
-
-     with open(filename, "w", encoding="utf-8") as file:
-         # Write the headers to the file
-         file.write("|".join(headers) + "\n")
-         for entry in data:
-             # Retrieve values in the order of headers, replacing any '|' characters with a space to prevent formatting errors
-             formatted_values = [str(entry.get(key, "")).replace("|", " ") for key in headers]
-             # Write the formatted values to the file
-             file.write("|".join(formatted_values) + "\n")
-
-
- def read_metadata(filename, headers=None):
-     """
-     Read metadata from a file.
-
-     Args:
-         filename (str): The file from which to read the metadata.
-
-     Returns:
-         list of dict: The metadata read from the file.
-         list of str: The headers used in the file.
-     """
-     with open(filename, "r", encoding="utf-8") as file:
-         lines = file.readlines()
-
-     data = []
-     # Set headers from the first line of the file if not provided
-     if headers is None:
-         headers = lines[0].strip().split("|")
-         lines = lines[1:]
-
-     for line in lines:
-         line = line.strip()
-         # Skip empty lines
-         if not line:
-             continue
-         # Split the line by '|' and pair with headers to form a dictionary
-         entry_data = dict(zip(headers, line.split("|")))
-         data.append(entry_data)
-
-     return data, headers
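
The removed sparktts/utils/file.py bundled several small I/O helpers: JSONL read/write, JSONL-to-CSV conversion, OmegaConf config loading with optional `base_config` merging, and a simple pipe-separated metadata format. For reference, the following self-contained sketch reproduces the pipe-separated round trip that `save_metadata` writes and `read_metadata` parses; the file name and example records are invented for this illustration.

```python
# Illustration only: the pipe-separated metadata layout used by the deleted
# save_metadata/read_metadata helpers, re-implemented with the stdlib so the
# example runs without the removed module.
records = [
    {"utt_id": "0001", "text": "hello world", "speaker": "spk_a"},
    {"utt_id": "0002", "text": "good morning", "speaker": "spk_b"},
]
headers = list(records[0].keys())

# Write: one header row, then one '|'-joined row per record (as save_metadata did).
with open("metadata.txt", "w", encoding="utf-8") as f:
    f.write("|".join(headers) + "\n")
    for entry in records:
        f.write("|".join(str(entry.get(k, "")).replace("|", " ") for k in headers) + "\n")

# Read back: split the header line, then zip each data row into a dict (as read_metadata did).
with open("metadata.txt", "r", encoding="utf-8") as f:
    lines = f.read().splitlines()
parsed = [dict(zip(lines[0].split("|"), line.split("|"))) for line in lines[1:] if line]
assert parsed == records  # the round trip preserves the string fields
```
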
sparktts/utils/parse_options.sh DELETED
@@ -1,97 +0,0 @@
- #!/bin/bash
-
- # Copyright 2012 Johns Hopkins University (Author: Daniel Povey);
- #                Arnab Ghoshal, Karel Vesely
-
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #   http://www.apache.org/licenses/LICENSE-2.0
- #
- # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
- # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
- # MERCHANTABLITY OR NON-INFRINGEMENT.
- # See the Apache 2 License for the specific language governing permissions and
- # limitations under the License.
-
-
- # Parse command-line options.
- # To be sourced by another script (as in ". parse_options.sh").
- # Option format is: --option-name arg
- # and shell variable "option_name" gets set to value "arg."
- # The exception is --help, which takes no arguments, but prints the
- # $help_message variable (if defined).
-
-
- ###
- ### The --config file options have lower priority to command line
- ### options, so we need to import them first...
- ###
-
- # Now import all the configs specified by command-line, in left-to-right order
- # for ((argpos=1; argpos<$#; argpos++)); do
- #   if [ "${!argpos}" == "--config" ]; then
- #     argpos_plus1=$((argpos+1))
- #     config=${!argpos_plus1}
- #     [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
- #     . $config  # source the config file.
- #   fi
- # done
-
-
- ###
- ### Now we process the command line options
- ###
- while true; do
-   [ -z "${1:-}" ] && break;  # break if there are no arguments
-   case "$1" in
-     # If the enclosing script is called with --help option, print the help
-     # message and exit. Scripts should put help messages in $help_message
-     --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
-                else printf "$help_message\n" 1>&2 ; fi;
-                exit 0 ;;
-     --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
-            exit 1 ;;
-     # If the first command-line argument begins with "--" (e.g. --foo-bar),
-     # then work out the variable name as $name, which will equal "foo_bar".
-     --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
-       # Next we test whether the variable in question is undefined -- if so it's
-       # an invalid option and we die. Note: $0 evaluates to the name of the
-       # enclosing script.
-       # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
-       # is undefined. We then have to wrap this test inside "eval" because
-       # foo_bar is itself inside a variable ($name).
-       eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
-
-       oldval="`eval echo \\$$name`";
-       # Work out whether we seem to be expecting a Boolean argument.
-       if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
-         was_bool=true;
-       else
-         was_bool=false;
-       fi
-
-       # Set the variable to the right value-- the escaped quotes make it work if
-       # the option had spaces, like --cmd "queue.pl -sync y"
-       eval $name=\"$2\";
-
-       # Check that Boolean-valued arguments are really Boolean.
-       if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
-         echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
-         exit 1;
-       fi
-       shift 2;
-       ;;
-     *) break;
-   esac
- done
-
-
- # Check for an empty argument to the --cmd option, which can easily occur as a
- # result of scripting errors.
- [ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
-
-
- true; # so this script returns exit code 0.
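
parse_options.sh is the Kaldi-style helper that shell scripts source so that `--option-name value` arguments set a matching `option_name` variable, with special handling for `--help` and `--config`. With the shell recipes removed, the same convention on the Python side is usually covered by argparse; the sketch below is a generic illustration, and its flag names and defaults are not taken from this repository.

```python
# Generic argparse sketch of the "--option-name value" convention that
# parse_options.sh implemented for shell scripts. The flags and defaults
# below are illustrative only.
import argparse

parser = argparse.ArgumentParser(description="Example of --name value options")
parser.add_argument("--model-dir", type=str, default="pretrained_models",
                    help="becomes args.model_dir, like option_name in the shell parser")
parser.add_argument("--device", type=int, default=0)
parser.add_argument("--share", action="store_true",
                    help="boolean flag, analogous to the parser's true/false options")
args = parser.parse_args()
print(args.model_dir, args.device, args.share)
```
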
sparktts/utils/token_parser.py DELETED
@@ -1,187 +0,0 @@
- TASK_TOKEN_MAP = {
-     "vc": "<|task_vc|>",
-     "tts": "<|task_tts|>",
-     "asr": "<|task_asr|>",
-     "s2s": "<|task_s2s|>",
-     "t2s": "<|task_t2s|>",
-     "understand": "<|task_understand|>",
-     "caption": "<|task_cap|>",
-     "controllable_tts": "<|task_controllable_tts|>",
-     "prompt_tts": "<|task_prompt_tts|>",
-     "speech_edit": "<|task_edit|>",
- }
-
- LEVELS_MAP = {
-     "very_low": 0,
-     "low": 1,
-     "moderate": 2,
-     "high": 3,
-     "very_high": 4,
- }
-
- LEVELS_MAP_UI = {
-     1: 'very_low',
-     2: 'low',
-     3: 'moderate',
-     4: 'high',
-     5: 'very_high'
- }
-
- GENDER_MAP = {
-     "female": 0,
-     "male": 1,
- }
-
- AGE_MAP = {"Child": 0, "Teenager": 1, "Youth-Adult": 2, "Middle-aged": 3, "Elderly": 4}
-
- EMO_MAP = {
-     "UNKNOWN": 0,
-     "NEUTRAL": 1,
-     "ANGRY": 2,
-     "HAPPY": 3,
-     "SAD": 4,
-     "FEARFUL": 5,
-     "DISGUSTED": 6,
-     "SURPRISED": 7,
-     "SARCASTIC": 8,
-     "EXCITED": 9,
-     "SLEEPY": 10,
-     "CONFUSED": 11,
-     "EMPHASIS": 12,
-     "LAUGHING": 13,
-     "SINGING": 14,
-     "WORRIED": 15,
-     "WHISPER": 16,
-     "ANXIOUS": 17,
-     "NO-AGREEMENT": 18,
-     "APOLOGETIC": 19,
-     "CONCERNED": 20,
-     "ENUNCIATED": 21,
-     "ASSERTIVE": 22,
-     "ENCOURAGING": 23,
-     "CONTEMPT": 24,
- }
-
-
- class TokenParser:
-     """Turn attribute labels of a person into special tokens."""
-
-     def __init__(self):
-         pass
-
-     @staticmethod
-     def age(age: str) -> str:
-         """Turn age label into special token."""
-         age_id = AGE_MAP[age]
-         return f"<|age_{age_id}|>"
-
-     @staticmethod
-     def gender(gender: str) -> str:
-         """Turn gender label into special token."""
-         gender_id = GENDER_MAP[gender]
-         return f"<|gender_{gender_id}|>"
-
-     @staticmethod
-     def mel_value(mel: int):
-         """Turn special token of mel scale pitch."""
-         mel = max(0, int(mel))
-         mel = min(1000, int(mel))
-         return f"<|pitch_value_{mel}|>"
-
-     @staticmethod
-     def mel_level(level: str):
-         """Turn special token of mel level."""
-         level_tag = LEVELS_MAP[level]
-         return f"<|pitch_label_{level_tag}|>"
-
-     @staticmethod
-     def pitch_var_value(pitch_std: int):
-         """Turn special token of pitch_std value."""
-         assert isinstance(pitch_std, int)
-         pitch_std = max(0, int(pitch_std))
-         pitch_std = min(10, int(pitch_std))
-         return f"<|pitch_var_value_{pitch_std}|>"
-
-     @staticmethod
-     def pitch_var_level(level: str):
-         """Turn special token of pitch std level."""
-         level_tag = LEVELS_MAP[level]
-         return f"<|pitch_var_label_{level_tag}|>"
-
-     @staticmethod
-     def loudness_value(loudness: int):
-         """Turn special token of loudness value [0, 30]."""
-         assert loudness >= 0
-         loudness = max(0, int(loudness))
-         loudness = min(30, int(loudness))
-         return f"<|loudness_value_{loudness}|>"
-
-     @staticmethod
-     def loudness_level(level: str):
-         """Turn special token of loudness level."""
-         level_tag = LEVELS_MAP[level]
-         return f"<|loudness_label_{level_tag}|>"
-
-     @staticmethod
-     def speed_value(speed: int):
-         """Turn special token of speed value."""
-         speed = max(0, int(speed))
-         speed = min(10, int(speed))
-         return f"<|speed_value_{speed}|>"
-
-     @staticmethod
-     def speed_level(level: str):
-         """Turn special token of speed level."""
-         level_tag = LEVELS_MAP[level]
-         return f"<|speed_label_{level_tag}|>"
-
-     @staticmethod
-     def task(task: str) -> str:
-         """Turn special token of task."""
-         assert task in TASK_TOKEN_MAP.keys()
-
-         return TASK_TOKEN_MAP[task]
-
-     @staticmethod
-     def emotion(emotion: str):
-         emo_id = EMO_MAP[emotion]
-
-         return f"<|emotion_{emo_id}|>"
-
-
- # test
- if __name__ == "__main__":
-     from transformers import AutoTokenizer
-
-     tokenizer = AutoTokenizer.from_pretrained(
-         "/aifs4su/xinshengwang/code/StyleCraft/tokenizer/stylecraft-bicodec-pitch-loudness-speed-emotion-tokenizer"
-     )
-
-     tasks = ["tts", "tts", "understand", "controllable_tts", "prompt_tts"]
-     ages = ["Child", "Teenager", "Youth-Adult", "Middle-aged", "Elderly"]
-     genders = ["female", "female", "female", "male", "male"]
-     mels = [100, 200, 300, 400, 500]
-     mel_levels = ["very_low", "low", "moderate", "high", "very_high"]
-     loudnesses = [1, 10, 23, 19, 30]
-     loudness_levels = ["very_low", "low", "moderate", "high", "very_high"]
-     emotions = ["UNKNOWN", "NEUTRAL", "ANGRY", "HAPPY", "SAD"]
-
-     for i in range(5):
-         task = TokenParser.task(tasks[i])
-         age = TokenParser.age(ages[i])
-         gender = TokenParser.gender(genders[i])
-         mel = TokenParser.mel_value(mels[i])
-         mel_level = TokenParser.mel_level(mel_levels[i])
-         loudness = TokenParser.loudness_value(loudnesses[i])
-         loudness_level = TokenParser.loudness_level(loudness_levels[i])
-         emotion = TokenParser.emotion(emotions[i])
-         inputs = [task, age, gender, mel, mel_level, loudness, loudness_level, emotion]
-         inputs = "".join(inputs)
-         ids = tokenizer.encode(inputs, add_special_tokens=False)
-         print(ids)
-         print("decode", tokenizer.decode(ids))
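
The removed token_parser.py defined the label-to-special-token maps (task, gender, age, pitch, loudness, speed, emotion) used to condition the language model, plus a `TokenParser` with one static method per attribute. The sketch below is a self-contained illustration of how such a control prefix is assembled; the token spellings follow the maps and f-strings in the deleted file, while the attribute values and the helper name are made up for this example.

```python
# Illustration: assembling a control prefix from the special-token formats
# defined in the deleted file. Only a subset of the maps is inlined here.
TASK_TOKEN_MAP = {"controllable_tts": "<|task_controllable_tts|>"}
GENDER_MAP = {"female": 0, "male": 1}
LEVELS_MAP = {"very_low": 0, "low": 1, "moderate": 2, "high": 3, "very_high": 4}

def control_prefix(task: str, gender: str, pitch_level: str, speed_level: str) -> str:
    # Concatenate the per-attribute tokens in a fixed order.
    return "".join([
        TASK_TOKEN_MAP[task],
        f"<|gender_{GENDER_MAP[gender]}|>",
        f"<|pitch_label_{LEVELS_MAP[pitch_level]}|>",
        f"<|speed_label_{LEVELS_MAP[speed_level]}|>",
    ])

print(control_prefix("controllable_tts", "female", "moderate", "high"))
# -> <|task_controllable_tts|><|gender_0|><|pitch_label_2|><|speed_label_3|>
```
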
webui.py CHANGED
@@ -265,5 +265,6 @@ if __name__ == "__main__":
     # Launch Gradio with the specified server name and port
     demo.launch(
         server_name=args.server_name,
-        server_port=args.server_port
-    )
+        server_port=args.server_port,
+        share=True
+    )
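
The only functional change to webui.py is the added `share=True`, which asks Gradio to open a temporary public share link in addition to serving locally on `server_name:server_port`. If the public tunnel should be opt-in rather than always on, a common pattern is to gate it behind an environment variable; the sketch below is a standalone illustration, and the `SPARKTTS_SHARE` variable is invented for this example rather than something the repository reads.

```python
# Sketch only: a minimal Gradio app where the public share link is opt-in.
# "SPARKTTS_SHARE" is an invented environment variable, not part of the repo.
import os
import gradio as gr

with gr.Blocks() as demo:
    gr.Markdown("Spark-TTS demo placeholder")

demo.launch(
    server_name="0.0.0.0",
    server_port=7860,
    share=os.environ.get("SPARKTTS_SHARE", "0") == "1",  # public link only when requested
)
```

This keeps local runs private by default while still allowing a one-line opt-in to the public link.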