진승환
commited on
Commit
·
44a8c76
1
Parent(s):
ee29f36
Update UI
Browse files- README.md +6 -0
- sparkTTS/.gitattributes +35 -0
- sparkTTS/README.md +12 -0
- sparktts/models/audio_tokenizer.py +0 -163
- sparktts/models/bicodec.py +0 -247
- sparktts/modules/blocks/layers.py +0 -73
- sparktts/modules/blocks/samper.py +0 -115
- sparktts/modules/blocks/vocos.py +0 -373
- sparktts/modules/encoder_decoder/feat_decoder.py +0 -115
- sparktts/modules/encoder_decoder/feat_encoder.py +0 -105
- sparktts/modules/encoder_decoder/wave_generator.py +0 -88
- sparktts/modules/fsq/finite_scalar_quantization.py +0 -251
- sparktts/modules/fsq/residual_fsq.py +0 -355
- sparktts/modules/speaker/ecapa_tdnn.py +0 -267
- sparktts/modules/speaker/perceiver_encoder.py +0 -360
- sparktts/modules/speaker/pooling_layers.py +0 -298
- sparktts/modules/speaker/speaker_encoder.py +0 -136
- sparktts/modules/vq/factorized_vector_quantize.py +0 -187
- sparktts/utils/__init__.py +0 -0
- sparktts/utils/audio.py +0 -271
- sparktts/utils/file.py +0 -221
- sparktts/utils/parse_options.sh +0 -97
- sparktts/utils/token_parser.py +0 -187
- webui.py +3 -2
README.md
CHANGED
@@ -1,3 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
<div align="center">
|
2 |
<h1>
|
3 |
Spark-TTS
|
|
|
1 |
+
---
|
2 |
+
title: Spark-TTS
|
3 |
+
app_file: webui.py
|
4 |
+
sdk: gradio
|
5 |
+
sdk_version: 5.18.0
|
6 |
+
---
|
7 |
<div align="center">
|
8 |
<h1>
|
9 |
Spark-TTS
|
sparkTTS/.gitattributes
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
sparkTTS/README.md
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: SparkTTS
|
3 |
+
emoji: 🐨
|
4 |
+
colorFrom: blue
|
5 |
+
colorTo: purple
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 5.21.0
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
---
|
11 |
+
|
12 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
sparktts/models/audio_tokenizer.py
DELETED
@@ -1,163 +0,0 @@
|
|
1 |
-
# Copyright (c) 2025 SparkAudio
|
2 |
-
# 2025 Xinsheng Wang ([email protected])
|
3 |
-
#
|
4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
-
# you may not use this file except in compliance with the License.
|
6 |
-
# You may obtain a copy of the License at
|
7 |
-
#
|
8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
-
#
|
10 |
-
# Unless required by applicable law or agreed to in writing, software
|
11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
-
# See the License for the specific language governing permissions and
|
14 |
-
# limitations under the License.
|
15 |
-
|
16 |
-
|
17 |
-
import torch
|
18 |
-
import numpy as np
|
19 |
-
|
20 |
-
from pathlib import Path
|
21 |
-
from typing import Any, Dict, Tuple
|
22 |
-
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model
|
23 |
-
|
24 |
-
from sparktts.utils.file import load_config
|
25 |
-
from sparktts.utils.audio import load_audio
|
26 |
-
from sparktts.models.bicodec import BiCodec
|
27 |
-
|
28 |
-
|
29 |
-
class BiCodecTokenizer:
|
30 |
-
"""BiCodec tokenizer for handling audio input and tokenization."""
|
31 |
-
|
32 |
-
def __init__(self, model_dir: Path, device: torch.device = None, **kwargs):
|
33 |
-
super().__init__()
|
34 |
-
"""
|
35 |
-
Args:
|
36 |
-
model_dir: Path to the model directory.
|
37 |
-
device: Device to run the model on (default is GPU if available).
|
38 |
-
"""
|
39 |
-
self.device = device
|
40 |
-
self.model_dir = model_dir
|
41 |
-
self.config = load_config(f"{model_dir}/config.yaml")
|
42 |
-
self._initialize_model()
|
43 |
-
|
44 |
-
def _initialize_model(self):
|
45 |
-
"""Load and initialize the BiCodec model and Wav2Vec2 feature extractor."""
|
46 |
-
self.model = BiCodec.load_from_checkpoint(f"{self.model_dir}/BiCodec").to(
|
47 |
-
self.device
|
48 |
-
)
|
49 |
-
self.processor = Wav2Vec2FeatureExtractor.from_pretrained(
|
50 |
-
f"{self.model_dir}/wav2vec2-large-xlsr-53"
|
51 |
-
)
|
52 |
-
self.feature_extractor = Wav2Vec2Model.from_pretrained(
|
53 |
-
f"{self.model_dir}/wav2vec2-large-xlsr-53"
|
54 |
-
).to(self.device)
|
55 |
-
self.feature_extractor.config.output_hidden_states = True
|
56 |
-
|
57 |
-
def get_ref_clip(self, wav: np.ndarray) -> np.ndarray:
|
58 |
-
"""Get reference audio clip for speaker embedding."""
|
59 |
-
ref_segment_length = (
|
60 |
-
int(self.config["sample_rate"] * self.config["ref_segment_duration"])
|
61 |
-
// self.config["latent_hop_length"]
|
62 |
-
* self.config["latent_hop_length"]
|
63 |
-
)
|
64 |
-
wav_length = len(wav)
|
65 |
-
|
66 |
-
if ref_segment_length > wav_length:
|
67 |
-
# Repeat and truncate to handle insufficient length
|
68 |
-
wav = np.tile(wav, ref_segment_length // wav_length + 1)
|
69 |
-
|
70 |
-
return wav[:ref_segment_length]
|
71 |
-
|
72 |
-
def process_audio(self, wav_path: Path) -> Tuple[np.ndarray, torch.Tensor]:
|
73 |
-
"""load auido and get reference audio from wav path"""
|
74 |
-
wav = load_audio(
|
75 |
-
wav_path,
|
76 |
-
sampling_rate=self.config["sample_rate"],
|
77 |
-
volume_normalize=self.config["volume_normalize"],
|
78 |
-
)
|
79 |
-
|
80 |
-
wav_ref = self.get_ref_clip(wav)
|
81 |
-
|
82 |
-
wav_ref = torch.from_numpy(wav_ref).unsqueeze(0).float()
|
83 |
-
return wav, wav_ref
|
84 |
-
|
85 |
-
def extract_wav2vec2_features(self, wavs: torch.Tensor) -> torch.Tensor:
|
86 |
-
"""extract wav2vec2 features"""
|
87 |
-
inputs = self.processor(
|
88 |
-
wavs,
|
89 |
-
sampling_rate=16000,
|
90 |
-
return_tensors="pt",
|
91 |
-
padding=True,
|
92 |
-
output_hidden_states=True,
|
93 |
-
).input_values
|
94 |
-
feat = self.feature_extractor(inputs.to(self.feature_extractor.device))
|
95 |
-
feats_mix = (
|
96 |
-
feat.hidden_states[11] + feat.hidden_states[14] + feat.hidden_states[16]
|
97 |
-
) / 3
|
98 |
-
|
99 |
-
return feats_mix
|
100 |
-
|
101 |
-
def tokenize_batch(self, batch: Dict[str, Any]) -> torch.Tensor:
|
102 |
-
"""tokenize the batch of audio
|
103 |
-
|
104 |
-
Args:
|
105 |
-
batch:
|
106 |
-
wavs (List[np.ndarray]): batch of audio
|
107 |
-
ref_wavs (torch.Tensor): reference audio. shape: (batch_size, seq_len)
|
108 |
-
|
109 |
-
Returns:
|
110 |
-
semantic_tokens: semantic tokens. shape: (batch_size, seq_len, latent_dim)
|
111 |
-
global_tokens: global tokens. shape: (batch_size, seq_len, global_dim)
|
112 |
-
"""
|
113 |
-
feats = self.extract_wav2vec2_features(batch["wav"])
|
114 |
-
batch["feat"] = feats
|
115 |
-
semantic_tokens, global_tokens = self.model.tokenize(batch)
|
116 |
-
|
117 |
-
return global_tokens, semantic_tokens
|
118 |
-
|
119 |
-
def tokenize(self, audio_path: str) -> Tuple[torch.Tensor, torch.Tensor]:
|
120 |
-
"""tokenize the audio"""
|
121 |
-
wav, ref_wav = self.process_audio(audio_path)
|
122 |
-
feat = self.extract_wav2vec2_features(wav)
|
123 |
-
batch = {
|
124 |
-
"wav": torch.from_numpy(wav).unsqueeze(0).float().to(self.device),
|
125 |
-
"ref_wav": ref_wav.to(self.device),
|
126 |
-
"feat": feat.to(self.device),
|
127 |
-
}
|
128 |
-
semantic_tokens, global_tokens = self.model.tokenize(batch)
|
129 |
-
|
130 |
-
return global_tokens, semantic_tokens
|
131 |
-
|
132 |
-
def detokenize(
|
133 |
-
self, global_tokens: torch.Tensor, semantic_tokens: torch.Tensor
|
134 |
-
) -> np.array:
|
135 |
-
"""detokenize the tokens to waveform
|
136 |
-
|
137 |
-
Args:
|
138 |
-
global_tokens: global tokens. shape: (batch_size, global_dim)
|
139 |
-
semantic_tokens: semantic tokens. shape: (batch_size, latent_dim)
|
140 |
-
|
141 |
-
Returns:
|
142 |
-
wav_rec: waveform. shape: (batch_size, seq_len) for batch or (seq_len,) for single
|
143 |
-
"""
|
144 |
-
global_tokens = global_tokens.unsqueeze(1)
|
145 |
-
wav_rec = self.model.detokenize(semantic_tokens, global_tokens)
|
146 |
-
return wav_rec.detach().squeeze().cpu().numpy()
|
147 |
-
|
148 |
-
|
149 |
-
# test
|
150 |
-
if __name__ == "__main__":
|
151 |
-
import soundfile as sf
|
152 |
-
|
153 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
154 |
-
tokenizer = BiCodecTokenizer(
|
155 |
-
model_dir="pretrained_models/Spark-TTS-0.5B",
|
156 |
-
device=device,
|
157 |
-
)
|
158 |
-
wav_path = "example/prompt_audio.wav"
|
159 |
-
|
160 |
-
global_tokens, semantic_tokens = tokenizer.tokenize(wav_path)
|
161 |
-
|
162 |
-
wav_rec = tokenizer.detokenize(global_tokens.squeeze(0), semantic_tokens)
|
163 |
-
sf.write("example/prompt_recon.wav", wav_rec, 16000)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sparktts/models/bicodec.py
DELETED
@@ -1,247 +0,0 @@
|
|
1 |
-
# Copyright (c) 2025 SparkAudio
|
2 |
-
# 2025 Xinsheng Wang ([email protected])
|
3 |
-
#
|
4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
-
# you may not use this file except in compliance with the License.
|
6 |
-
# You may obtain a copy of the License at
|
7 |
-
#
|
8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
-
#
|
10 |
-
# Unless required by applicable law or agreed to in writing, software
|
11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
-
# See the License for the specific language governing permissions and
|
14 |
-
# limitations under the License.
|
15 |
-
|
16 |
-
import torch
|
17 |
-
import torch.nn as nn
|
18 |
-
from pathlib import Path
|
19 |
-
from typing import Dict, Any
|
20 |
-
from omegaconf import DictConfig
|
21 |
-
from safetensors.torch import load_file
|
22 |
-
|
23 |
-
from sparktts.utils.file import load_config
|
24 |
-
from sparktts.modules.speaker.speaker_encoder import SpeakerEncoder
|
25 |
-
from sparktts.modules.encoder_decoder.feat_encoder import Encoder
|
26 |
-
from sparktts.modules.encoder_decoder.feat_decoder import Decoder
|
27 |
-
from sparktts.modules.encoder_decoder.wave_generator import WaveGenerator
|
28 |
-
from sparktts.modules.vq.factorized_vector_quantize import FactorizedVectorQuantize
|
29 |
-
|
30 |
-
|
31 |
-
class BiCodec(nn.Module):
|
32 |
-
"""
|
33 |
-
BiCodec model for speech synthesis, incorporating a speaker encoder, feature encoder/decoder,
|
34 |
-
quantizer, and wave generator.
|
35 |
-
"""
|
36 |
-
|
37 |
-
def __init__(
|
38 |
-
self,
|
39 |
-
mel_params: Dict[str, Any],
|
40 |
-
encoder: nn.Module,
|
41 |
-
decoder: nn.Module,
|
42 |
-
quantizer: nn.Module,
|
43 |
-
speaker_encoder: nn.Module,
|
44 |
-
prenet: nn.Module,
|
45 |
-
postnet: nn.Module,
|
46 |
-
**kwargs
|
47 |
-
) -> None:
|
48 |
-
"""
|
49 |
-
Initializes the BiCodec model with the required components.
|
50 |
-
|
51 |
-
Args:
|
52 |
-
mel_params (dict): Parameters for the mel-spectrogram transformer.
|
53 |
-
encoder (nn.Module): Encoder module.
|
54 |
-
decoder (nn.Module): Decoder module.
|
55 |
-
quantizer (nn.Module): Quantizer module.
|
56 |
-
speaker_encoder (nn.Module): Speaker encoder module.
|
57 |
-
prenet (nn.Module): Prenet network.
|
58 |
-
postnet (nn.Module): Postnet network.
|
59 |
-
"""
|
60 |
-
super().__init__()
|
61 |
-
self.encoder = encoder
|
62 |
-
self.decoder = decoder
|
63 |
-
self.quantizer = quantizer
|
64 |
-
self.speaker_encoder = speaker_encoder
|
65 |
-
self.prenet = prenet
|
66 |
-
self.postnet = postnet
|
67 |
-
self.init_mel_transformer(mel_params)
|
68 |
-
|
69 |
-
@classmethod
|
70 |
-
def load_from_checkpoint(cls, model_dir: Path, **kwargs) -> "BiCodec":
|
71 |
-
"""
|
72 |
-
Loads the model from a checkpoint.
|
73 |
-
|
74 |
-
Args:
|
75 |
-
model_dir (Path): Path to the model directory containing checkpoint and config.
|
76 |
-
|
77 |
-
Returns:
|
78 |
-
BiCodec: The initialized BiCodec model.
|
79 |
-
"""
|
80 |
-
ckpt_path = f'{model_dir}/model.safetensors'
|
81 |
-
config = load_config(f'{model_dir}/config.yaml')['audio_tokenizer']
|
82 |
-
mel_params = config["mel_params"]
|
83 |
-
encoder = Encoder(**config["encoder"])
|
84 |
-
quantizer = FactorizedVectorQuantize(**config["quantizer"])
|
85 |
-
prenet = Decoder(**config["prenet"])
|
86 |
-
postnet = Decoder(**config["postnet"])
|
87 |
-
decoder = WaveGenerator(**config["decoder"])
|
88 |
-
speaker_encoder = SpeakerEncoder(**config["speaker_encoder"])
|
89 |
-
|
90 |
-
model = cls(
|
91 |
-
mel_params=mel_params,
|
92 |
-
encoder=encoder,
|
93 |
-
decoder=decoder,
|
94 |
-
quantizer=quantizer,
|
95 |
-
speaker_encoder=speaker_encoder,
|
96 |
-
prenet=prenet,
|
97 |
-
postnet=postnet,
|
98 |
-
)
|
99 |
-
|
100 |
-
state_dict = load_file(ckpt_path)
|
101 |
-
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
|
102 |
-
|
103 |
-
for key in missing_keys:
|
104 |
-
print(f"Missing tensor: {key}")
|
105 |
-
for key in unexpected_keys:
|
106 |
-
print(f"Unexpected tensor: {key}")
|
107 |
-
|
108 |
-
model.eval()
|
109 |
-
model.remove_weight_norm()
|
110 |
-
|
111 |
-
return model
|
112 |
-
|
113 |
-
def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]:
|
114 |
-
"""
|
115 |
-
Performs a forward pass through the model.
|
116 |
-
|
117 |
-
Args:
|
118 |
-
batch (dict): A dictionary containing features, reference waveform, and target waveform.
|
119 |
-
|
120 |
-
Returns:
|
121 |
-
dict: A dictionary containing the reconstruction, features, and other metrics.
|
122 |
-
"""
|
123 |
-
feat = batch["feat"]
|
124 |
-
mel = self.mel_transformer(batch["ref_wav"]).squeeze(1)
|
125 |
-
|
126 |
-
z = self.encoder(feat.transpose(1, 2))
|
127 |
-
vq_outputs = self.quantizer(z)
|
128 |
-
|
129 |
-
x_vector, d_vector = self.speaker_encoder(mel.transpose(1, 2))
|
130 |
-
|
131 |
-
conditions = d_vector
|
132 |
-
with_speaker_loss = False
|
133 |
-
|
134 |
-
x = self.prenet(vq_outputs["z_q"], conditions)
|
135 |
-
pred_feat = self.postnet(x)
|
136 |
-
x = x + conditions.unsqueeze(-1)
|
137 |
-
wav_recon = self.decoder(x)
|
138 |
-
|
139 |
-
return {
|
140 |
-
"vq_loss": vq_outputs["vq_loss"],
|
141 |
-
"perplexity": vq_outputs["perplexity"],
|
142 |
-
"cluster_size": vq_outputs["active_num"],
|
143 |
-
"recons": wav_recon,
|
144 |
-
"pred_feat": pred_feat,
|
145 |
-
"x_vector": x_vector,
|
146 |
-
"d_vector": d_vector,
|
147 |
-
"audios": batch["wav"].unsqueeze(1),
|
148 |
-
"with_speaker_loss": with_speaker_loss,
|
149 |
-
}
|
150 |
-
|
151 |
-
@torch.no_grad()
|
152 |
-
def tokenize(self, batch: Dict[str, Any]):
|
153 |
-
"""
|
154 |
-
Tokenizes the input audio into semantic and global tokens.
|
155 |
-
|
156 |
-
Args:
|
157 |
-
batch (dict): The input audio features and reference waveform.
|
158 |
-
|
159 |
-
Returns:
|
160 |
-
tuple: Semantic tokens and global tokens.
|
161 |
-
"""
|
162 |
-
feat = batch["feat"]
|
163 |
-
mel = self.mel_transformer(batch["ref_wav"]).squeeze(1)
|
164 |
-
|
165 |
-
z = self.encoder(feat.transpose(1, 2))
|
166 |
-
semantic_tokens = self.quantizer.tokenize(z)
|
167 |
-
global_tokens = self.speaker_encoder.tokenize(mel.transpose(1, 2))
|
168 |
-
|
169 |
-
return semantic_tokens, global_tokens
|
170 |
-
|
171 |
-
@torch.no_grad()
|
172 |
-
def detokenize(self, semantic_tokens, global_tokens):
|
173 |
-
"""
|
174 |
-
Detokenizes the semantic and global tokens into a waveform.
|
175 |
-
|
176 |
-
Args:
|
177 |
-
semantic_tokens (tensor): Semantic tokens.
|
178 |
-
global_tokens (tensor): Global tokens.
|
179 |
-
|
180 |
-
Returns:
|
181 |
-
tensor: Reconstructed waveform.
|
182 |
-
"""
|
183 |
-
z_q = self.quantizer.detokenize(semantic_tokens)
|
184 |
-
d_vector = self.speaker_encoder.detokenize(global_tokens)
|
185 |
-
x = self.prenet(z_q, d_vector)
|
186 |
-
x = x + d_vector.unsqueeze(-1)
|
187 |
-
wav_recon = self.decoder(x)
|
188 |
-
|
189 |
-
return wav_recon
|
190 |
-
|
191 |
-
def init_mel_transformer(self, config: Dict[str, Any]):
|
192 |
-
"""
|
193 |
-
Initializes the MelSpectrogram transformer based on the provided configuration.
|
194 |
-
|
195 |
-
Args:
|
196 |
-
config (dict): Configuration parameters for MelSpectrogram.
|
197 |
-
"""
|
198 |
-
import torchaudio.transforms as TT
|
199 |
-
|
200 |
-
self.mel_transformer = TT.MelSpectrogram(
|
201 |
-
config["sample_rate"],
|
202 |
-
config["n_fft"],
|
203 |
-
config["win_length"],
|
204 |
-
config["hop_length"],
|
205 |
-
config["mel_fmin"],
|
206 |
-
config["mel_fmax"],
|
207 |
-
n_mels=config["num_mels"],
|
208 |
-
power=1,
|
209 |
-
norm="slaney",
|
210 |
-
mel_scale="slaney",
|
211 |
-
)
|
212 |
-
|
213 |
-
def remove_weight_norm(self):
|
214 |
-
"""Removes weight normalization from all layers."""
|
215 |
-
def _remove_weight_norm(m):
|
216 |
-
try:
|
217 |
-
torch.nn.utils.remove_weight_norm(m)
|
218 |
-
except ValueError:
|
219 |
-
pass # The module didn't have weight norm
|
220 |
-
|
221 |
-
self.apply(_remove_weight_norm)
|
222 |
-
|
223 |
-
|
224 |
-
# Test the model
|
225 |
-
if __name__ == "__main__":
|
226 |
-
|
227 |
-
config = load_config("pretrained_models/SparkTTS-0.5B/BiCodec/config.yaml")
|
228 |
-
model = BiCodec.load_from_checkpoint(
|
229 |
-
model_dir="pretrained_models/SparkTTS-0.5B/BiCodec",
|
230 |
-
)
|
231 |
-
|
232 |
-
# Generate random inputs for testing
|
233 |
-
duration = 0.96
|
234 |
-
x = torch.randn(20, 1, int(duration * 16000))
|
235 |
-
feat = torch.randn(20, int(duration * 50), 1024)
|
236 |
-
inputs = {"feat": feat, "wav": x, "ref_wav": x}
|
237 |
-
|
238 |
-
# Forward pass
|
239 |
-
outputs = model(inputs)
|
240 |
-
semantic_tokens, global_tokens = model.tokenize(inputs)
|
241 |
-
wav_recon = model.detokenize(semantic_tokens, global_tokens)
|
242 |
-
|
243 |
-
# Verify if the reconstruction matches
|
244 |
-
if torch.allclose(outputs["recons"].detach(), wav_recon):
|
245 |
-
print("Test successful")
|
246 |
-
else:
|
247 |
-
print("Test failed")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sparktts/modules/blocks/layers.py
DELETED
@@ -1,73 +0,0 @@
|
|
1 |
-
# Copyright (c) 2025 SparkAudio
|
2 |
-
# 2025 Xinsheng Wang ([email protected])
|
3 |
-
#
|
4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
-
# you may not use this file except in compliance with the License.
|
6 |
-
# You may obtain a copy of the License at
|
7 |
-
#
|
8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
-
#
|
10 |
-
# Unless required by applicable law or agreed to in writing, software
|
11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
-
# See the License for the specific language governing permissions and
|
14 |
-
# limitations under the License.
|
15 |
-
|
16 |
-
# Adapted from https://github.com/descriptinc/descript-audio-codec under the Apache License 2.0
|
17 |
-
|
18 |
-
|
19 |
-
import torch
|
20 |
-
import torch.nn as nn
|
21 |
-
from torch.nn.utils import weight_norm
|
22 |
-
|
23 |
-
|
24 |
-
def WNConv1d(*args, **kwargs):
|
25 |
-
return weight_norm(nn.Conv1d(*args, **kwargs))
|
26 |
-
|
27 |
-
|
28 |
-
def WNConvTranspose1d(*args, **kwargs):
|
29 |
-
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
|
30 |
-
|
31 |
-
|
32 |
-
# Scripting this brings model speed up 1.4x
|
33 |
-
@torch.jit.script
|
34 |
-
def snake(x, alpha):
|
35 |
-
shape = x.shape
|
36 |
-
x = x.reshape(shape[0], shape[1], -1)
|
37 |
-
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
|
38 |
-
x = x.reshape(shape)
|
39 |
-
return x
|
40 |
-
|
41 |
-
|
42 |
-
class Snake1d(nn.Module):
|
43 |
-
def __init__(self, channels):
|
44 |
-
super().__init__()
|
45 |
-
self.alpha = nn.Parameter(torch.ones(1, channels, 1))
|
46 |
-
|
47 |
-
def forward(self, x):
|
48 |
-
return snake(x, self.alpha)
|
49 |
-
|
50 |
-
|
51 |
-
class ResidualUnit(nn.Module):
|
52 |
-
def __init__(self, dim: int = 16, dilation: int = 1):
|
53 |
-
super().__init__()
|
54 |
-
pad = ((7 - 1) * dilation) // 2
|
55 |
-
self.block = nn.Sequential(
|
56 |
-
Snake1d(dim),
|
57 |
-
WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
|
58 |
-
Snake1d(dim),
|
59 |
-
WNConv1d(dim, dim, kernel_size=1),
|
60 |
-
)
|
61 |
-
|
62 |
-
def forward(self, x):
|
63 |
-
y = self.block(x)
|
64 |
-
pad = (x.shape[-1] - y.shape[-1]) // 2
|
65 |
-
if pad > 0:
|
66 |
-
x = x[..., pad:-pad]
|
67 |
-
return x + y
|
68 |
-
|
69 |
-
|
70 |
-
def init_weights(m):
|
71 |
-
if isinstance(m, nn.Conv1d):
|
72 |
-
nn.init.trunc_normal_(m.weight, std=0.02)
|
73 |
-
nn.init.constant_(m.bias, 0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sparktts/modules/blocks/samper.py
DELETED
@@ -1,115 +0,0 @@
|
|
1 |
-
# Copyright (c) 2025 SparkAudio
|
2 |
-
# 2025 Xinsheng Wang ([email protected])
|
3 |
-
#
|
4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
-
# you may not use this file except in compliance with the License.
|
6 |
-
# You may obtain a copy of the License at
|
7 |
-
#
|
8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
-
#
|
10 |
-
# Unless required by applicable law or agreed to in writing, software
|
11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
-
# See the License for the specific language governing permissions and
|
14 |
-
# limitations under the License.
|
15 |
-
|
16 |
-
|
17 |
-
import torch
|
18 |
-
import torch.nn as nn
|
19 |
-
import torch.nn.functional as F
|
20 |
-
|
21 |
-
|
22 |
-
class SamplingBlock(nn.Module):
|
23 |
-
"""Sampling block for upsampling or downsampling"""
|
24 |
-
|
25 |
-
def __init__(
|
26 |
-
self,
|
27 |
-
dim: int,
|
28 |
-
groups: int = 1,
|
29 |
-
upsample_scale: int = 1,
|
30 |
-
downsample_scale: int = 1,
|
31 |
-
) -> None:
|
32 |
-
"""
|
33 |
-
Args:
|
34 |
-
dim: input dimension
|
35 |
-
groups: number of groups
|
36 |
-
upsample_scale: upsampling scale
|
37 |
-
downsample_scale: downsampling scale
|
38 |
-
"""
|
39 |
-
super(SamplingBlock, self).__init__()
|
40 |
-
|
41 |
-
self.upsample_scale = upsample_scale
|
42 |
-
self.downsample_scale = downsample_scale
|
43 |
-
|
44 |
-
if self.upsample_scale > 1:
|
45 |
-
self.de_conv_upsampler = nn.Sequential(
|
46 |
-
nn.LeakyReLU(0.2),
|
47 |
-
nn.ConvTranspose1d(
|
48 |
-
dim,
|
49 |
-
dim,
|
50 |
-
kernel_size=upsample_scale * 2,
|
51 |
-
stride=upsample_scale,
|
52 |
-
padding=upsample_scale // 2 + upsample_scale % 2,
|
53 |
-
output_padding=upsample_scale % 2,
|
54 |
-
groups=groups,
|
55 |
-
),
|
56 |
-
)
|
57 |
-
|
58 |
-
if self.downsample_scale > 1:
|
59 |
-
self.conv_downsampler = nn.Sequential(
|
60 |
-
nn.LeakyReLU(0.2),
|
61 |
-
nn.Conv1d(
|
62 |
-
dim,
|
63 |
-
dim,
|
64 |
-
kernel_size=2 * downsample_scale,
|
65 |
-
stride=downsample_scale,
|
66 |
-
padding=downsample_scale // 2 + downsample_scale % 2,
|
67 |
-
groups=groups,
|
68 |
-
),
|
69 |
-
)
|
70 |
-
|
71 |
-
@staticmethod
|
72 |
-
def repeat_upsampler(x, upsample_scale):
|
73 |
-
return x.repeat_interleave(upsample_scale, dim=2)
|
74 |
-
|
75 |
-
@staticmethod
|
76 |
-
def skip_downsampler(x, downsample_scale):
|
77 |
-
return F.avg_pool1d(x, kernel_size=downsample_scale, stride=downsample_scale)
|
78 |
-
|
79 |
-
def forward(self, x):
|
80 |
-
x = x.transpose(1, 2)
|
81 |
-
if self.upsample_scale > 1:
|
82 |
-
repeat_res = self.repeat_upsampler(x, self.upsample_scale)
|
83 |
-
deconv_res = self.de_conv_upsampler(x)
|
84 |
-
upmerge_res = repeat_res + deconv_res
|
85 |
-
else:
|
86 |
-
upmerge_res = x
|
87 |
-
repeat_res = x
|
88 |
-
|
89 |
-
if self.downsample_scale > 1:
|
90 |
-
conv_res = self.conv_downsampler(upmerge_res)
|
91 |
-
skip2_res = self.skip_downsampler(upmerge_res, self.downsample_scale)
|
92 |
-
skip1_res = self.skip_downsampler(repeat_res, self.downsample_scale)
|
93 |
-
else:
|
94 |
-
conv_res = upmerge_res
|
95 |
-
skip2_res = upmerge_res
|
96 |
-
skip1_res = repeat_res
|
97 |
-
|
98 |
-
final_res = conv_res + skip1_res + skip2_res
|
99 |
-
|
100 |
-
return final_res
|
101 |
-
|
102 |
-
|
103 |
-
# test
|
104 |
-
if __name__ == "__main__":
|
105 |
-
test_input = torch.randn(8, 1024, 50) # Batch size = 8, 1024 channels, length = 50
|
106 |
-
model = SamplingBlock(1024, 1024, upsample_scale=2)
|
107 |
-
model_down = SamplingBlock(1024, 1024, downsample_scale=2)
|
108 |
-
output = model(test_input)
|
109 |
-
output_down = model_down(test_input)
|
110 |
-
print("shape after upsample * 2", output.shape) # torch.Size([8, 1024, 100])
|
111 |
-
print("shape after downsample * 2", output_down.shape) # torch.Size([8, 1024, 25])
|
112 |
-
if output.shape == torch.Size([8, 1024, 100]) and output_down.shape == torch.Size(
|
113 |
-
[8, 1024, 25]
|
114 |
-
):
|
115 |
-
print("test successful")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sparktts/modules/blocks/vocos.py
DELETED
@@ -1,373 +0,0 @@
|
|
1 |
-
# Copyright (c) 2025 SparkAudio
|
2 |
-
# 2025 Xinsheng Wang ([email protected])
|
3 |
-
#
|
4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
-
# you may not use this file except in compliance with the License.
|
6 |
-
# You may obtain a copy of the License at
|
7 |
-
#
|
8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
-
#
|
10 |
-
# Unless required by applicable law or agreed to in writing, software
|
11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
-
# See the License for the specific language governing permissions and
|
14 |
-
# limitations under the License.
|
15 |
-
|
16 |
-
|
17 |
-
import torch
|
18 |
-
import torch.nn as nn
|
19 |
-
|
20 |
-
from typing import Tuple
|
21 |
-
from torch.nn.utils import weight_norm, remove_weight_norm
|
22 |
-
|
23 |
-
from typing import Optional
|
24 |
-
|
25 |
-
|
26 |
-
class ConvNeXtBlock(nn.Module):
|
27 |
-
"""ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
|
28 |
-
|
29 |
-
Args:
|
30 |
-
dim (int): Number of input channels.
|
31 |
-
intermediate_dim (int): Dimensionality of the intermediate layer.
|
32 |
-
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
|
33 |
-
Defaults to None.
|
34 |
-
adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
|
35 |
-
None means non-conditional LayerNorm. Defaults to None.
|
36 |
-
"""
|
37 |
-
|
38 |
-
def __init__(
|
39 |
-
self,
|
40 |
-
dim: int,
|
41 |
-
intermediate_dim: int,
|
42 |
-
layer_scale_init_value: float,
|
43 |
-
condition_dim: Optional[int] = None,
|
44 |
-
):
|
45 |
-
super().__init__()
|
46 |
-
self.dwconv = nn.Conv1d(
|
47 |
-
dim, dim, kernel_size=7, padding=3, groups=dim
|
48 |
-
) # depthwise conv
|
49 |
-
self.adanorm = condition_dim is not None
|
50 |
-
if condition_dim:
|
51 |
-
self.norm = AdaLayerNorm(condition_dim, dim, eps=1e-6)
|
52 |
-
else:
|
53 |
-
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
54 |
-
self.pwconv1 = nn.Linear(
|
55 |
-
dim, intermediate_dim
|
56 |
-
) # pointwise/1x1 convs, implemented with linear layers
|
57 |
-
self.act = nn.GELU()
|
58 |
-
self.pwconv2 = nn.Linear(intermediate_dim, dim)
|
59 |
-
self.gamma = (
|
60 |
-
nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
|
61 |
-
if layer_scale_init_value > 0
|
62 |
-
else None
|
63 |
-
)
|
64 |
-
|
65 |
-
def forward(
|
66 |
-
self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
|
67 |
-
) -> torch.Tensor:
|
68 |
-
residual = x
|
69 |
-
x = self.dwconv(x)
|
70 |
-
x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
|
71 |
-
if self.adanorm:
|
72 |
-
assert cond_embedding_id is not None
|
73 |
-
x = self.norm(x, cond_embedding_id)
|
74 |
-
else:
|
75 |
-
x = self.norm(x)
|
76 |
-
x = self.pwconv1(x)
|
77 |
-
x = self.act(x)
|
78 |
-
x = self.pwconv2(x)
|
79 |
-
if self.gamma is not None:
|
80 |
-
x = self.gamma * x
|
81 |
-
x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
|
82 |
-
|
83 |
-
x = residual + x
|
84 |
-
return x
|
85 |
-
|
86 |
-
|
87 |
-
class AdaLayerNorm(nn.Module):
|
88 |
-
"""
|
89 |
-
Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes
|
90 |
-
|
91 |
-
Args:
|
92 |
-
condition_dim (int): Dimension of the condition.
|
93 |
-
embedding_dim (int): Dimension of the embeddings.
|
94 |
-
"""
|
95 |
-
|
96 |
-
def __init__(self, condition_dim: int, embedding_dim: int, eps: float = 1e-6):
|
97 |
-
super().__init__()
|
98 |
-
self.eps = eps
|
99 |
-
self.dim = embedding_dim
|
100 |
-
self.scale = nn.Linear(condition_dim, embedding_dim)
|
101 |
-
self.shift = nn.Linear(condition_dim, embedding_dim)
|
102 |
-
torch.nn.init.ones_(self.scale.weight)
|
103 |
-
torch.nn.init.zeros_(self.shift.weight)
|
104 |
-
|
105 |
-
def forward(self, x: torch.Tensor, cond_embedding: torch.Tensor) -> torch.Tensor:
|
106 |
-
scale = self.scale(cond_embedding)
|
107 |
-
shift = self.shift(cond_embedding)
|
108 |
-
x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
|
109 |
-
x = x * scale.unsqueeze(1) + shift.unsqueeze(1)
|
110 |
-
return x
|
111 |
-
|
112 |
-
|
113 |
-
class ResBlock1(nn.Module):
|
114 |
-
"""
|
115 |
-
ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
|
116 |
-
but without upsampling layers.
|
117 |
-
|
118 |
-
Args:
|
119 |
-
dim (int): Number of input channels.
|
120 |
-
kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
|
121 |
-
dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
|
122 |
-
Defaults to (1, 3, 5).
|
123 |
-
lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
|
124 |
-
Defaults to 0.1.
|
125 |
-
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
|
126 |
-
Defaults to None.
|
127 |
-
"""
|
128 |
-
|
129 |
-
def __init__(
|
130 |
-
self,
|
131 |
-
dim: int,
|
132 |
-
kernel_size: int = 3,
|
133 |
-
dilation: Tuple[int, int, int] = (1, 3, 5),
|
134 |
-
lrelu_slope: float = 0.1,
|
135 |
-
layer_scale_init_value: Optional[float] = None,
|
136 |
-
):
|
137 |
-
super().__init__()
|
138 |
-
self.lrelu_slope = lrelu_slope
|
139 |
-
self.convs1 = nn.ModuleList(
|
140 |
-
[
|
141 |
-
weight_norm(
|
142 |
-
nn.Conv1d(
|
143 |
-
dim,
|
144 |
-
dim,
|
145 |
-
kernel_size,
|
146 |
-
1,
|
147 |
-
dilation=dilation[0],
|
148 |
-
padding=self.get_padding(kernel_size, dilation[0]),
|
149 |
-
)
|
150 |
-
),
|
151 |
-
weight_norm(
|
152 |
-
nn.Conv1d(
|
153 |
-
dim,
|
154 |
-
dim,
|
155 |
-
kernel_size,
|
156 |
-
1,
|
157 |
-
dilation=dilation[1],
|
158 |
-
padding=self.get_padding(kernel_size, dilation[1]),
|
159 |
-
)
|
160 |
-
),
|
161 |
-
weight_norm(
|
162 |
-
nn.Conv1d(
|
163 |
-
dim,
|
164 |
-
dim,
|
165 |
-
kernel_size,
|
166 |
-
1,
|
167 |
-
dilation=dilation[2],
|
168 |
-
padding=self.get_padding(kernel_size, dilation[2]),
|
169 |
-
)
|
170 |
-
),
|
171 |
-
]
|
172 |
-
)
|
173 |
-
|
174 |
-
self.convs2 = nn.ModuleList(
|
175 |
-
[
|
176 |
-
weight_norm(
|
177 |
-
nn.Conv1d(
|
178 |
-
dim,
|
179 |
-
dim,
|
180 |
-
kernel_size,
|
181 |
-
1,
|
182 |
-
dilation=1,
|
183 |
-
padding=self.get_padding(kernel_size, 1),
|
184 |
-
)
|
185 |
-
),
|
186 |
-
weight_norm(
|
187 |
-
nn.Conv1d(
|
188 |
-
dim,
|
189 |
-
dim,
|
190 |
-
kernel_size,
|
191 |
-
1,
|
192 |
-
dilation=1,
|
193 |
-
padding=self.get_padding(kernel_size, 1),
|
194 |
-
)
|
195 |
-
),
|
196 |
-
weight_norm(
|
197 |
-
nn.Conv1d(
|
198 |
-
dim,
|
199 |
-
dim,
|
200 |
-
kernel_size,
|
201 |
-
1,
|
202 |
-
dilation=1,
|
203 |
-
padding=self.get_padding(kernel_size, 1),
|
204 |
-
)
|
205 |
-
),
|
206 |
-
]
|
207 |
-
)
|
208 |
-
|
209 |
-
self.gamma = nn.ParameterList(
|
210 |
-
[
|
211 |
-
(
|
212 |
-
nn.Parameter(
|
213 |
-
layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
|
214 |
-
)
|
215 |
-
if layer_scale_init_value is not None
|
216 |
-
else None
|
217 |
-
),
|
218 |
-
(
|
219 |
-
nn.Parameter(
|
220 |
-
layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
|
221 |
-
)
|
222 |
-
if layer_scale_init_value is not None
|
223 |
-
else None
|
224 |
-
),
|
225 |
-
(
|
226 |
-
nn.Parameter(
|
227 |
-
layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
|
228 |
-
)
|
229 |
-
if layer_scale_init_value is not None
|
230 |
-
else None
|
231 |
-
),
|
232 |
-
]
|
233 |
-
)
|
234 |
-
|
235 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
236 |
-
for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
|
237 |
-
xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
|
238 |
-
xt = c1(xt)
|
239 |
-
xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
|
240 |
-
xt = c2(xt)
|
241 |
-
if gamma is not None:
|
242 |
-
xt = gamma * xt
|
243 |
-
x = xt + x
|
244 |
-
return x
|
245 |
-
|
246 |
-
def remove_weight_norm(self):
|
247 |
-
for l in self.convs1:
|
248 |
-
remove_weight_norm(l)
|
249 |
-
for l in self.convs2:
|
250 |
-
remove_weight_norm(l)
|
251 |
-
|
252 |
-
@staticmethod
|
253 |
-
def get_padding(kernel_size: int, dilation: int = 1) -> int:
|
254 |
-
return int((kernel_size * dilation - dilation) / 2)
|
255 |
-
|
256 |
-
|
257 |
-
class Backbone(nn.Module):
|
258 |
-
"""Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
|
259 |
-
|
260 |
-
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
|
261 |
-
"""
|
262 |
-
Args:
|
263 |
-
x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
|
264 |
-
C denotes output features, and L is the sequence length.
|
265 |
-
|
266 |
-
Returns:
|
267 |
-
Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
|
268 |
-
and H denotes the model dimension.
|
269 |
-
"""
|
270 |
-
raise NotImplementedError("Subclasses must implement the forward method.")
|
271 |
-
|
272 |
-
|
273 |
-
class VocosBackbone(Backbone):
|
274 |
-
"""
|
275 |
-
Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization
|
276 |
-
|
277 |
-
Args:
|
278 |
-
input_channels (int): Number of input features channels.
|
279 |
-
dim (int): Hidden dimension of the model.
|
280 |
-
intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
|
281 |
-
num_layers (int): Number of ConvNeXtBlock layers.
|
282 |
-
layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
|
283 |
-
adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
|
284 |
-
None means non-conditional model. Defaults to None.
|
285 |
-
"""
|
286 |
-
|
287 |
-
def __init__(
|
288 |
-
self,
|
289 |
-
input_channels: int,
|
290 |
-
dim: int,
|
291 |
-
intermediate_dim: int,
|
292 |
-
num_layers: int,
|
293 |
-
layer_scale_init_value: Optional[float] = None,
|
294 |
-
condition_dim: Optional[int] = None,
|
295 |
-
):
|
296 |
-
super().__init__()
|
297 |
-
self.input_channels = input_channels
|
298 |
-
self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
|
299 |
-
self.adanorm = condition_dim is not None
|
300 |
-
if condition_dim:
|
301 |
-
self.norm = AdaLayerNorm(condition_dim, dim, eps=1e-6)
|
302 |
-
else:
|
303 |
-
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
304 |
-
layer_scale_init_value = layer_scale_init_value or 1 / num_layers
|
305 |
-
self.convnext = nn.ModuleList(
|
306 |
-
[
|
307 |
-
ConvNeXtBlock(
|
308 |
-
dim=dim,
|
309 |
-
intermediate_dim=intermediate_dim,
|
310 |
-
layer_scale_init_value=layer_scale_init_value,
|
311 |
-
condition_dim=condition_dim,
|
312 |
-
)
|
313 |
-
for _ in range(num_layers)
|
314 |
-
]
|
315 |
-
)
|
316 |
-
self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
|
317 |
-
self.apply(self._init_weights)
|
318 |
-
|
319 |
-
def _init_weights(self, m):
|
320 |
-
if isinstance(m, (nn.Conv1d, nn.Linear)):
|
321 |
-
nn.init.trunc_normal_(m.weight, std=0.02)
|
322 |
-
nn.init.constant_(m.bias, 0)
|
323 |
-
|
324 |
-
def forward(self, x: torch.Tensor, condition: torch.Tensor = None) -> torch.Tensor:
|
325 |
-
x = self.embed(x)
|
326 |
-
if self.adanorm:
|
327 |
-
assert condition is not None
|
328 |
-
x = self.norm(x.transpose(1, 2), condition)
|
329 |
-
else:
|
330 |
-
x = self.norm(x.transpose(1, 2))
|
331 |
-
x = x.transpose(1, 2)
|
332 |
-
for conv_block in self.convnext:
|
333 |
-
x = conv_block(x, condition)
|
334 |
-
x = self.final_layer_norm(x.transpose(1, 2))
|
335 |
-
return x
|
336 |
-
|
337 |
-
|
338 |
-
class VocosResNetBackbone(Backbone):
|
339 |
-
"""
|
340 |
-
Vocos backbone module built with ResBlocks.
|
341 |
-
|
342 |
-
Args:
|
343 |
-
input_channels (int): Number of input features channels.
|
344 |
-
dim (int): Hidden dimension of the model.
|
345 |
-
num_blocks (int): Number of ResBlock1 blocks.
|
346 |
-
layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
|
347 |
-
"""
|
348 |
-
|
349 |
-
def __init__(
|
350 |
-
self,
|
351 |
-
input_channels,
|
352 |
-
dim,
|
353 |
-
num_blocks,
|
354 |
-
layer_scale_init_value=None,
|
355 |
-
):
|
356 |
-
super().__init__()
|
357 |
-
self.input_channels = input_channels
|
358 |
-
self.embed = weight_norm(
|
359 |
-
nn.Conv1d(input_channels, dim, kernel_size=3, padding=1)
|
360 |
-
)
|
361 |
-
layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
|
362 |
-
self.resnet = nn.Sequential(
|
363 |
-
*[
|
364 |
-
ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value)
|
365 |
-
for _ in range(num_blocks)
|
366 |
-
]
|
367 |
-
)
|
368 |
-
|
369 |
-
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
|
370 |
-
x = self.embed(x)
|
371 |
-
x = self.resnet(x)
|
372 |
-
x = x.transpose(1, 2)
|
373 |
-
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sparktts/modules/encoder_decoder/feat_decoder.py
DELETED
@@ -1,115 +0,0 @@
|
|
1 |
-
# Copyright (c) 2025 SparkAudio
|
2 |
-
# 2025 Xinsheng Wang ([email protected])
|
3 |
-
#
|
4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
-
# you may not use this file except in compliance with the License.
|
6 |
-
# You may obtain a copy of the License at
|
7 |
-
#
|
8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
-
#
|
10 |
-
# Unless required by applicable law or agreed to in writing, software
|
11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
-
# See the License for the specific language governing permissions and
|
14 |
-
# limitations under the License.
|
15 |
-
|
16 |
-
|
17 |
-
import torch
|
18 |
-
import torch.nn as nn
|
19 |
-
|
20 |
-
from typing import List
|
21 |
-
|
22 |
-
from sparktts.modules.blocks.vocos import VocosBackbone
|
23 |
-
from sparktts.modules.blocks.samper import SamplingBlock
|
24 |
-
|
25 |
-
|
26 |
-
class Decoder(nn.Module):
|
27 |
-
"""Decoder module with convnext and upsampling blocks
|
28 |
-
|
29 |
-
Args:
|
30 |
-
sample_ratios (List[int]): sample ratios
|
31 |
-
example: [2, 2] means downsample by 2x and then upsample by 2x
|
32 |
-
"""
|
33 |
-
|
34 |
-
def __init__(
|
35 |
-
self,
|
36 |
-
input_channels: int,
|
37 |
-
vocos_dim: int,
|
38 |
-
vocos_intermediate_dim: int,
|
39 |
-
vocos_num_layers: int,
|
40 |
-
out_channels: int,
|
41 |
-
condition_dim: int = None,
|
42 |
-
sample_ratios: List[int] = [1, 1],
|
43 |
-
use_tanh_at_final: bool = False,
|
44 |
-
):
|
45 |
-
super().__init__()
|
46 |
-
|
47 |
-
self.linear_pre = nn.Linear(input_channels, vocos_dim)
|
48 |
-
modules = [
|
49 |
-
nn.Sequential(
|
50 |
-
SamplingBlock(
|
51 |
-
dim=vocos_dim,
|
52 |
-
groups=vocos_dim,
|
53 |
-
upsample_scale=ratio,
|
54 |
-
),
|
55 |
-
VocosBackbone(
|
56 |
-
input_channels=vocos_dim,
|
57 |
-
dim=vocos_dim,
|
58 |
-
intermediate_dim=vocos_intermediate_dim,
|
59 |
-
num_layers=2,
|
60 |
-
condition_dim=None,
|
61 |
-
),
|
62 |
-
)
|
63 |
-
for ratio in sample_ratios
|
64 |
-
]
|
65 |
-
|
66 |
-
self.downsample = nn.Sequential(*modules)
|
67 |
-
|
68 |
-
self.vocos_backbone = VocosBackbone(
|
69 |
-
input_channels=vocos_dim,
|
70 |
-
dim=vocos_dim,
|
71 |
-
intermediate_dim=vocos_intermediate_dim,
|
72 |
-
num_layers=vocos_num_layers,
|
73 |
-
condition_dim=condition_dim,
|
74 |
-
)
|
75 |
-
self.linear = nn.Linear(vocos_dim, out_channels)
|
76 |
-
self.use_tanh_at_final = use_tanh_at_final
|
77 |
-
|
78 |
-
def forward(self, x: torch.Tensor, c: torch.Tensor = None):
|
79 |
-
"""encoder forward.
|
80 |
-
|
81 |
-
Args:
|
82 |
-
x (torch.Tensor): (batch_size, input_channels, length)
|
83 |
-
|
84 |
-
Returns:
|
85 |
-
x (torch.Tensor): (batch_size, encode_channels, length)
|
86 |
-
"""
|
87 |
-
x = self.linear_pre(x.transpose(1, 2))
|
88 |
-
x = self.downsample(x).transpose(1, 2)
|
89 |
-
x = self.vocos_backbone(x, condition=c)
|
90 |
-
x = self.linear(x).transpose(1, 2)
|
91 |
-
if self.use_tanh_at_final:
|
92 |
-
x = torch.tanh(x)
|
93 |
-
|
94 |
-
return x
|
95 |
-
|
96 |
-
|
97 |
-
# test
|
98 |
-
if __name__ == "__main__":
|
99 |
-
test_input = torch.randn(8, 1024, 50) # Batch size = 8, 1024 channels, length = 50
|
100 |
-
condition = torch.randn(8, 256)
|
101 |
-
decoder = Decoder(
|
102 |
-
input_channels=1024,
|
103 |
-
vocos_dim=384,
|
104 |
-
vocos_intermediate_dim=2048,
|
105 |
-
vocos_num_layers=12,
|
106 |
-
out_channels=256,
|
107 |
-
condition_dim=256,
|
108 |
-
sample_ratios=[2, 2],
|
109 |
-
)
|
110 |
-
output = decoder(test_input, condition)
|
111 |
-
print(output.shape) # torch.Size([8, 256, 200])
|
112 |
-
if output.shape == torch.Size([8, 256, 200]):
|
113 |
-
print("Decoder test passed")
|
114 |
-
else:
|
115 |
-
print("Decoder test failed")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sparktts/modules/encoder_decoder/feat_encoder.py
DELETED
@@ -1,105 +0,0 @@
|
|
1 |
-
# Copyright (c) 2025 SparkAudio
|
2 |
-
# 2025 Xinsheng Wang ([email protected])
|
3 |
-
#
|
4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
-
# you may not use this file except in compliance with the License.
|
6 |
-
# You may obtain a copy of the License at
|
7 |
-
#
|
8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
-
#
|
10 |
-
# Unless required by applicable law or agreed to in writing, software
|
11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
-
# See the License for the specific language governing permissions and
|
14 |
-
# limitations under the License.
|
15 |
-
|
16 |
-
|
17 |
-
import torch
|
18 |
-
import torch.nn as nn
|
19 |
-
|
20 |
-
from typing import List
|
21 |
-
|
22 |
-
from sparktts.modules.blocks.vocos import VocosBackbone
|
23 |
-
from sparktts.modules.blocks.samper import SamplingBlock
|
24 |
-
|
25 |
-
|
26 |
-
class Encoder(nn.Module):
|
27 |
-
"""Encoder module with convnext and downsampling blocks"""
|
28 |
-
|
29 |
-
def __init__(
|
30 |
-
self,
|
31 |
-
input_channels: int,
|
32 |
-
vocos_dim: int,
|
33 |
-
vocos_intermediate_dim: int,
|
34 |
-
vocos_num_layers: int,
|
35 |
-
out_channels: int,
|
36 |
-
sample_ratios: List[int] = [1, 1],
|
37 |
-
):
|
38 |
-
super().__init__()
|
39 |
-
"""
|
40 |
-
Encoder module with VocosBackbone and sampling blocks.
|
41 |
-
|
42 |
-
Args:
|
43 |
-
sample_ratios (List[int]): sample ratios
|
44 |
-
example: [2, 2] means downsample by 2x and then upsample by 2x
|
45 |
-
"""
|
46 |
-
self.encoder = VocosBackbone(
|
47 |
-
input_channels=input_channels,
|
48 |
-
dim=vocos_dim,
|
49 |
-
intermediate_dim=vocos_intermediate_dim,
|
50 |
-
num_layers=vocos_num_layers,
|
51 |
-
condition_dim=None,
|
52 |
-
)
|
53 |
-
|
54 |
-
modules = [
|
55 |
-
nn.Sequential(
|
56 |
-
SamplingBlock(
|
57 |
-
dim=vocos_dim,
|
58 |
-
groups=vocos_dim,
|
59 |
-
downsample_scale=ratio,
|
60 |
-
),
|
61 |
-
VocosBackbone(
|
62 |
-
input_channels=vocos_dim,
|
63 |
-
dim=vocos_dim,
|
64 |
-
intermediate_dim=vocos_intermediate_dim,
|
65 |
-
num_layers=2,
|
66 |
-
condition_dim=None,
|
67 |
-
),
|
68 |
-
)
|
69 |
-
for ratio in sample_ratios
|
70 |
-
]
|
71 |
-
|
72 |
-
self.downsample = nn.Sequential(*modules)
|
73 |
-
|
74 |
-
self.project = nn.Linear(vocos_dim, out_channels)
|
75 |
-
|
76 |
-
def forward(self, x: torch.Tensor, *args):
|
77 |
-
"""
|
78 |
-
Args:
|
79 |
-
x (torch.Tensor): (batch_size, input_channels, length)
|
80 |
-
|
81 |
-
Returns:
|
82 |
-
x (torch.Tensor): (batch_size, encode_channels, length)
|
83 |
-
"""
|
84 |
-
x = self.encoder(x)
|
85 |
-
x = self.downsample(x)
|
86 |
-
x = self.project(x)
|
87 |
-
return x.transpose(1, 2)
|
88 |
-
|
89 |
-
|
90 |
-
# test
|
91 |
-
if __name__ == "__main__":
|
92 |
-
test_input = torch.randn(8, 1024, 50) # Batch size = 8, 1024 channels, length = 50
|
93 |
-
encoder = Encoder(
|
94 |
-
input_channels=1024,
|
95 |
-
vocos_dim=384,
|
96 |
-
vocos_intermediate_dim=2048,
|
97 |
-
vocos_num_layers=12,
|
98 |
-
out_channels=256,
|
99 |
-
sample_ratios=[2, 2],
|
100 |
-
)
|
101 |
-
|
102 |
-
output = encoder(test_input)
|
103 |
-
print(output.shape) # torch.Size([8, 256, 12])
|
104 |
-
if output.shape == torch.Size([8, 256, 12]):
|
105 |
-
print("test successful")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sparktts/modules/encoder_decoder/wave_generator.py
DELETED
@@ -1,88 +0,0 @@
|
|
1 |
-
# Copyright (c) 2024 Xinsheng Wang ([email protected])
|
2 |
-
#
|
3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
-
# you may not use this file except in compliance with the License.
|
5 |
-
# You may obtain a copy of the License at
|
6 |
-
#
|
7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
-
#
|
9 |
-
# Unless required by applicable law or agreed to in writing, software
|
10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
-
# See the License for the specific language governing permissions and
|
13 |
-
# limitations under the License.
|
14 |
-
|
15 |
-
# Adapted from https://github.com/descriptinc/descript-audio-codec under the Apache License 2.0
|
16 |
-
|
17 |
-
|
18 |
-
import torch.nn as nn
|
19 |
-
|
20 |
-
from sparktts.modules.blocks.layers import (
|
21 |
-
Snake1d,
|
22 |
-
WNConv1d,
|
23 |
-
ResidualUnit,
|
24 |
-
WNConvTranspose1d,
|
25 |
-
init_weights,
|
26 |
-
)
|
27 |
-
|
28 |
-
|
29 |
-
class DecoderBlock(nn.Module):
|
30 |
-
def __init__(
|
31 |
-
self,
|
32 |
-
input_dim: int = 16,
|
33 |
-
output_dim: int = 8,
|
34 |
-
kernel_size: int = 2,
|
35 |
-
stride: int = 1,
|
36 |
-
):
|
37 |
-
super().__init__()
|
38 |
-
self.block = nn.Sequential(
|
39 |
-
Snake1d(input_dim),
|
40 |
-
WNConvTranspose1d(
|
41 |
-
input_dim,
|
42 |
-
output_dim,
|
43 |
-
kernel_size=kernel_size,
|
44 |
-
stride=stride,
|
45 |
-
padding=(kernel_size - stride) // 2,
|
46 |
-
),
|
47 |
-
ResidualUnit(output_dim, dilation=1),
|
48 |
-
ResidualUnit(output_dim, dilation=3),
|
49 |
-
ResidualUnit(output_dim, dilation=9),
|
50 |
-
)
|
51 |
-
|
52 |
-
def forward(self, x):
|
53 |
-
return self.block(x)
|
54 |
-
|
55 |
-
|
56 |
-
class WaveGenerator(nn.Module):
|
57 |
-
def __init__(
|
58 |
-
self,
|
59 |
-
input_channel,
|
60 |
-
channels,
|
61 |
-
rates,
|
62 |
-
kernel_sizes,
|
63 |
-
d_out: int = 1,
|
64 |
-
):
|
65 |
-
super().__init__()
|
66 |
-
|
67 |
-
# Add first conv layer
|
68 |
-
layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]
|
69 |
-
|
70 |
-
# Add upsampling + MRF blocks
|
71 |
-
for i, (kernel_size, stride) in enumerate(zip(kernel_sizes, rates)):
|
72 |
-
input_dim = channels // 2**i
|
73 |
-
output_dim = channels // 2 ** (i + 1)
|
74 |
-
layers += [DecoderBlock(input_dim, output_dim, kernel_size, stride)]
|
75 |
-
|
76 |
-
# Add final conv layer
|
77 |
-
layers += [
|
78 |
-
Snake1d(output_dim),
|
79 |
-
WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
|
80 |
-
nn.Tanh(),
|
81 |
-
]
|
82 |
-
|
83 |
-
self.model = nn.Sequential(*layers)
|
84 |
-
|
85 |
-
self.apply(init_weights)
|
86 |
-
|
87 |
-
def forward(self, x):
|
88 |
-
return self.model(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sparktts/modules/fsq/finite_scalar_quantization.py
DELETED
@@ -1,251 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505
|
3 |
-
Code adapted from Jax version in Appendix A.1
|
4 |
-
"""
|
5 |
-
|
6 |
-
from __future__ import annotations
|
7 |
-
from functools import wraps, partial
|
8 |
-
from contextlib import nullcontext
|
9 |
-
from typing import List, Tuple
|
10 |
-
|
11 |
-
import torch
|
12 |
-
import torch.nn as nn
|
13 |
-
from torch.nn import Module
|
14 |
-
from torch import Tensor, int32
|
15 |
-
from torch.amp import autocast
|
16 |
-
|
17 |
-
from einops import rearrange, pack, unpack
|
18 |
-
|
19 |
-
# helper functions
|
20 |
-
|
21 |
-
|
22 |
-
def exists(v):
|
23 |
-
return v is not None
|
24 |
-
|
25 |
-
|
26 |
-
def default(*args):
|
27 |
-
for arg in args:
|
28 |
-
if exists(arg):
|
29 |
-
return arg
|
30 |
-
return None
|
31 |
-
|
32 |
-
|
33 |
-
def maybe(fn):
|
34 |
-
@wraps(fn)
|
35 |
-
def inner(x, *args, **kwargs):
|
36 |
-
if not exists(x):
|
37 |
-
return x
|
38 |
-
return fn(x, *args, **kwargs)
|
39 |
-
|
40 |
-
return inner
|
41 |
-
|
42 |
-
|
43 |
-
def pack_one(t, pattern):
|
44 |
-
return pack([t], pattern)
|
45 |
-
|
46 |
-
|
47 |
-
def unpack_one(t, ps, pattern):
|
48 |
-
return unpack(t, ps, pattern)[0]
|
49 |
-
|
50 |
-
|
51 |
-
# tensor helpers
|
52 |
-
|
53 |
-
|
54 |
-
def round_ste(z: Tensor) -> Tensor:
|
55 |
-
"""Round with straight through gradients."""
|
56 |
-
zhat = z.round()
|
57 |
-
return z + (zhat - z).detach()
|
58 |
-
|
59 |
-
|
60 |
-
# main class
|
61 |
-
|
62 |
-
|
63 |
-
class FSQ(Module):
|
64 |
-
def __init__(
|
65 |
-
self,
|
66 |
-
levels: List[int],
|
67 |
-
dim: int | None = None,
|
68 |
-
num_codebooks=1,
|
69 |
-
keep_num_codebooks_dim: bool | None = None,
|
70 |
-
scale: float | None = None,
|
71 |
-
allowed_dtypes: Tuple[torch.dtype, ...] = (torch.float32, torch.float64),
|
72 |
-
channel_first: bool = False,
|
73 |
-
projection_has_bias: bool = True,
|
74 |
-
return_indices=True,
|
75 |
-
force_quantization_f32=True,
|
76 |
-
):
|
77 |
-
super().__init__()
|
78 |
-
_levels = torch.tensor(levels, dtype=int32)
|
79 |
-
self.register_buffer("_levels", _levels, persistent=False)
|
80 |
-
|
81 |
-
_basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32)
|
82 |
-
self.register_buffer("_basis", _basis, persistent=False)
|
83 |
-
|
84 |
-
self.scale = scale
|
85 |
-
|
86 |
-
codebook_dim = len(levels)
|
87 |
-
self.codebook_dim = codebook_dim
|
88 |
-
|
89 |
-
effective_codebook_dim = codebook_dim * num_codebooks
|
90 |
-
self.num_codebooks = num_codebooks
|
91 |
-
self.effective_codebook_dim = effective_codebook_dim
|
92 |
-
|
93 |
-
keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
|
94 |
-
assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
|
95 |
-
self.keep_num_codebooks_dim = keep_num_codebooks_dim
|
96 |
-
|
97 |
-
self.dim = default(dim, len(_levels) * num_codebooks)
|
98 |
-
|
99 |
-
self.channel_first = channel_first
|
100 |
-
|
101 |
-
has_projections = self.dim != effective_codebook_dim
|
102 |
-
self.project_in = (
|
103 |
-
nn.Linear(self.dim, effective_codebook_dim, bias=projection_has_bias)
|
104 |
-
if has_projections
|
105 |
-
else nn.Identity()
|
106 |
-
)
|
107 |
-
self.project_out = (
|
108 |
-
nn.Linear(effective_codebook_dim, self.dim, bias=projection_has_bias)
|
109 |
-
if has_projections
|
110 |
-
else nn.Identity()
|
111 |
-
)
|
112 |
-
|
113 |
-
self.has_projections = has_projections
|
114 |
-
|
115 |
-
self.return_indices = return_indices
|
116 |
-
if return_indices:
|
117 |
-
self.codebook_size = self._levels.prod().item()
|
118 |
-
implicit_codebook = self._indices_to_codes(torch.arange(self.codebook_size))
|
119 |
-
self.register_buffer(
|
120 |
-
"implicit_codebook", implicit_codebook, persistent=False
|
121 |
-
)
|
122 |
-
|
123 |
-
self.allowed_dtypes = allowed_dtypes
|
124 |
-
self.force_quantization_f32 = force_quantization_f32
|
125 |
-
|
126 |
-
def bound(self, z, eps: float = 1e-3):
|
127 |
-
"""Bound `z`, an array of shape (..., d)."""
|
128 |
-
half_l = (self._levels - 1) * (1 + eps) / 2
|
129 |
-
offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
|
130 |
-
shift = (offset / half_l).atanh()
|
131 |
-
return (z + shift).tanh() * half_l - offset
|
132 |
-
|
133 |
-
def quantize(self, z):
|
134 |
-
"""Quantizes z, returns quantized zhat, same shape as z."""
|
135 |
-
quantized = round_ste(self.bound(z))
|
136 |
-
half_width = self._levels // 2 # Renormalize to [-1, 1].
|
137 |
-
return quantized / half_width
|
138 |
-
|
139 |
-
def _scale_and_shift(self, zhat_normalized):
|
140 |
-
half_width = self._levels // 2
|
141 |
-
return (zhat_normalized * half_width) + half_width
|
142 |
-
|
143 |
-
def _scale_and_shift_inverse(self, zhat):
|
144 |
-
half_width = self._levels // 2
|
145 |
-
return (zhat - half_width) / half_width
|
146 |
-
|
147 |
-
def _indices_to_codes(self, indices):
|
148 |
-
level_indices = self.indices_to_level_indices(indices)
|
149 |
-
codes = self._scale_and_shift_inverse(level_indices)
|
150 |
-
return codes
|
151 |
-
|
152 |
-
def codes_to_indices(self, zhat):
|
153 |
-
"""Converts a `code` to an index in the codebook."""
|
154 |
-
assert zhat.shape[-1] == self.codebook_dim
|
155 |
-
zhat = self._scale_and_shift(zhat)
|
156 |
-
return (zhat * self._basis).sum(dim=-1).to(int32)
|
157 |
-
|
158 |
-
def indices_to_level_indices(self, indices):
|
159 |
-
"""Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings"""
|
160 |
-
indices = rearrange(indices, "... -> ... 1")
|
161 |
-
codes_non_centered = (indices // self._basis) % self._levels
|
162 |
-
return codes_non_centered
|
163 |
-
|
164 |
-
def indices_to_codes(self, indices):
|
165 |
-
"""Inverse of `codes_to_indices`."""
|
166 |
-
assert exists(indices)
|
167 |
-
|
168 |
-
is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
|
169 |
-
|
170 |
-
codes = self._indices_to_codes(indices)
|
171 |
-
|
172 |
-
if self.keep_num_codebooks_dim:
|
173 |
-
codes = rearrange(codes, "... c d -> ... (c d)")
|
174 |
-
|
175 |
-
codes = self.project_out(codes)
|
176 |
-
|
177 |
-
if is_img_or_video or self.channel_first:
|
178 |
-
codes = rearrange(codes, "b ... d -> b d ...")
|
179 |
-
|
180 |
-
return codes
|
181 |
-
|
182 |
-
def forward(self, z):
|
183 |
-
"""
|
184 |
-
einstein notation
|
185 |
-
b - batch
|
186 |
-
n - sequence (or flattened spatial dimensions)
|
187 |
-
d - feature dimension
|
188 |
-
c - number of codebook dim
|
189 |
-
"""
|
190 |
-
|
191 |
-
is_img_or_video = z.ndim >= 4
|
192 |
-
need_move_channel_last = is_img_or_video or self.channel_first
|
193 |
-
|
194 |
-
# standardize image or video into (batch, seq, dimension)
|
195 |
-
|
196 |
-
if need_move_channel_last:
|
197 |
-
z = rearrange(z, "b d ... -> b ... d")
|
198 |
-
z, ps = pack_one(z, "b * d")
|
199 |
-
|
200 |
-
assert (
|
201 |
-
z.shape[-1] == self.dim
|
202 |
-
), f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}"
|
203 |
-
|
204 |
-
z = self.project_in(z)
|
205 |
-
|
206 |
-
z = rearrange(z, "b n (c d) -> b n c d", c=self.num_codebooks)
|
207 |
-
|
208 |
-
# whether to force quantization step to be full precision or not
|
209 |
-
|
210 |
-
force_f32 = self.force_quantization_f32
|
211 |
-
quantization_context = (
|
212 |
-
partial(autocast, "cuda", enabled=False) if force_f32 else nullcontext
|
213 |
-
)
|
214 |
-
|
215 |
-
with quantization_context():
|
216 |
-
orig_dtype = z.dtype
|
217 |
-
|
218 |
-
if force_f32 and orig_dtype not in self.allowed_dtypes:
|
219 |
-
z = z.float()
|
220 |
-
|
221 |
-
codes = self.quantize(z)
|
222 |
-
|
223 |
-
# returning indices could be optional
|
224 |
-
|
225 |
-
indices = None
|
226 |
-
|
227 |
-
if self.return_indices:
|
228 |
-
indices = self.codes_to_indices(codes)
|
229 |
-
|
230 |
-
codes = rearrange(codes, "b n c d -> b n (c d)")
|
231 |
-
|
232 |
-
codes = codes.type(orig_dtype)
|
233 |
-
|
234 |
-
# project out
|
235 |
-
|
236 |
-
out = self.project_out(codes)
|
237 |
-
|
238 |
-
# reconstitute image or video dimensions
|
239 |
-
|
240 |
-
if need_move_channel_last:
|
241 |
-
out = unpack_one(out, ps, "b * d")
|
242 |
-
out = rearrange(out, "b ... d -> b d ...")
|
243 |
-
|
244 |
-
indices = maybe(unpack_one)(indices, ps, "b * c")
|
245 |
-
|
246 |
-
if not self.keep_num_codebooks_dim and self.return_indices:
|
247 |
-
indices = maybe(rearrange)(indices, "... 1 -> ...")
|
248 |
-
|
249 |
-
# return quantized output and indices
|
250 |
-
|
251 |
-
return out, indices
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sparktts/modules/fsq/residual_fsq.py
DELETED
@@ -1,355 +0,0 @@
|
|
1 |
-
import random
|
2 |
-
import torch
|
3 |
-
import torch.nn.functional as F
|
4 |
-
import torch.distributed as dist
|
5 |
-
|
6 |
-
from typing import List
|
7 |
-
from torch import nn
|
8 |
-
from torch.nn import Module
|
9 |
-
from torch.amp import autocast
|
10 |
-
from einx import get_at
|
11 |
-
from einops import rearrange, reduce, pack, unpack
|
12 |
-
|
13 |
-
from sparktts.modules.fsq.finite_scalar_quantization import FSQ
|
14 |
-
|
15 |
-
|
16 |
-
def exists(val):
|
17 |
-
return val is not None
|
18 |
-
|
19 |
-
|
20 |
-
def first(l):
|
21 |
-
return l[0]
|
22 |
-
|
23 |
-
|
24 |
-
def default(val, d):
|
25 |
-
return val if exists(val) else d
|
26 |
-
|
27 |
-
|
28 |
-
def round_up_multiple(num, mult):
|
29 |
-
return ceil(num / mult) * mult
|
30 |
-
|
31 |
-
|
32 |
-
# distributed helpers
|
33 |
-
|
34 |
-
|
35 |
-
def is_distributed():
|
36 |
-
return dist.is_initialized() and dist.get_world_size() > 1
|
37 |
-
|
38 |
-
|
39 |
-
def get_maybe_sync_seed(device, max_size=10_000):
|
40 |
-
rand_int = torch.randint(0, max_size, (), device=device)
|
41 |
-
|
42 |
-
if is_distributed():
|
43 |
-
dist.all_reduce(rand_int)
|
44 |
-
|
45 |
-
return rand_int.item()
|
46 |
-
|
47 |
-
|
48 |
-
class ResidualFSQ(Module):
|
49 |
-
"""Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf"""
|
50 |
-
|
51 |
-
def __init__(
|
52 |
-
self,
|
53 |
-
*,
|
54 |
-
levels: List[int],
|
55 |
-
num_quantizers,
|
56 |
-
dim=None,
|
57 |
-
is_channel_first=False,
|
58 |
-
quantize_dropout=False,
|
59 |
-
quantize_dropout_cutoff_index=0,
|
60 |
-
quantize_dropout_multiple_of=1,
|
61 |
-
**kwargs,
|
62 |
-
):
|
63 |
-
super().__init__()
|
64 |
-
codebook_dim = len(levels)
|
65 |
-
dim = default(dim, codebook_dim)
|
66 |
-
|
67 |
-
requires_projection = codebook_dim != dim
|
68 |
-
self.project_in = (
|
69 |
-
nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
|
70 |
-
)
|
71 |
-
self.project_out = (
|
72 |
-
nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
|
73 |
-
)
|
74 |
-
self.has_projections = requires_projection
|
75 |
-
|
76 |
-
self.is_channel_first = is_channel_first
|
77 |
-
self.num_quantizers = num_quantizers
|
78 |
-
|
79 |
-
self.levels = levels
|
80 |
-
self.layers = nn.ModuleList([])
|
81 |
-
|
82 |
-
levels_tensor = torch.Tensor(levels)
|
83 |
-
|
84 |
-
scales = []
|
85 |
-
|
86 |
-
for ind in range(num_quantizers):
|
87 |
-
scales.append((levels_tensor - 1) ** -ind)
|
88 |
-
|
89 |
-
fsq = FSQ(levels=levels, dim=codebook_dim, **kwargs)
|
90 |
-
|
91 |
-
self.layers.append(fsq)
|
92 |
-
|
93 |
-
assert all([not fsq.has_projections for fsq in self.layers])
|
94 |
-
|
95 |
-
self.codebook_size = self.layers[0].codebook_size
|
96 |
-
|
97 |
-
self.register_buffer("scales", torch.stack(scales), persistent=False)
|
98 |
-
|
99 |
-
self.quantize_dropout = quantize_dropout and num_quantizers > 1
|
100 |
-
|
101 |
-
assert quantize_dropout_cutoff_index >= 0
|
102 |
-
|
103 |
-
self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
|
104 |
-
self.quantize_dropout_multiple_of = quantize_dropout_multiple_of # encodec paper proposes structured dropout, believe this was set to 4
|
105 |
-
|
106 |
-
@property
|
107 |
-
def codebooks(self):
|
108 |
-
codebooks = [layer.implicit_codebook for layer in self.layers]
|
109 |
-
codebooks = torch.stack(codebooks, dim=0)
|
110 |
-
return codebooks
|
111 |
-
|
112 |
-
def get_codes_from_indices(self, indices):
|
113 |
-
|
114 |
-
batch, quantize_dim = indices.shape[0], indices.shape[-1]
|
115 |
-
|
116 |
-
# may also receive indices in the shape of 'b h w q' (accept_image_fmap)
|
117 |
-
|
118 |
-
indices, ps = pack([indices], "b * q")
|
119 |
-
|
120 |
-
# because of quantize dropout, one can pass in indices that are coarse
|
121 |
-
# and the network should be able to reconstruct
|
122 |
-
|
123 |
-
if quantize_dim < self.num_quantizers:
|
124 |
-
assert (
|
125 |
-
self.quantize_dropout > 0.0
|
126 |
-
), "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
|
127 |
-
indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value=-1)
|
128 |
-
|
129 |
-
# take care of quantizer dropout
|
130 |
-
|
131 |
-
mask = indices == -1
|
132 |
-
indices = indices.masked_fill(
|
133 |
-
mask, 0
|
134 |
-
) # have it fetch a dummy code to be masked out later
|
135 |
-
|
136 |
-
all_codes = get_at("q [c] d, b n q -> q b n d", self.codebooks, indices)
|
137 |
-
|
138 |
-
# mask out any codes that were dropout-ed
|
139 |
-
|
140 |
-
all_codes = all_codes.masked_fill(rearrange(mask, "b n q -> q b n 1"), 0.0)
|
141 |
-
|
142 |
-
# scale the codes
|
143 |
-
|
144 |
-
scales = rearrange(self.scales, "q d -> q 1 1 d")
|
145 |
-
all_codes = all_codes * scales
|
146 |
-
|
147 |
-
# if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension)
|
148 |
-
|
149 |
-
(all_codes,) = unpack(all_codes, ps, "q b * d")
|
150 |
-
|
151 |
-
return all_codes
|
152 |
-
|
153 |
-
def get_output_from_indices(self, indices):
|
154 |
-
codes = self.get_codes_from_indices(indices)
|
155 |
-
codes_summed = reduce(codes, "q ... -> ...", "sum")
|
156 |
-
return self.project_out(codes_summed)
|
157 |
-
|
158 |
-
def forward(self, x, return_all_codes=False, rand_quantize_dropout_fixed_seed=None):
|
159 |
-
num_quant, quant_dropout_multiple_of, device = (
|
160 |
-
self.num_quantizers,
|
161 |
-
self.quantize_dropout_multiple_of,
|
162 |
-
x.device,
|
163 |
-
)
|
164 |
-
|
165 |
-
# handle channel first
|
166 |
-
|
167 |
-
if self.is_channel_first:
|
168 |
-
x = rearrange(x, "b d ... -> b ... d")
|
169 |
-
x, ps = pack([x], "b * d")
|
170 |
-
|
171 |
-
# maybe project in
|
172 |
-
|
173 |
-
x = self.project_in(x)
|
174 |
-
|
175 |
-
quantized_out = 0.0
|
176 |
-
residual = x
|
177 |
-
|
178 |
-
all_indices = []
|
179 |
-
|
180 |
-
should_quantize_dropout = self.training and self.quantize_dropout
|
181 |
-
|
182 |
-
# sample a layer index at which to dropout further residual quantization
|
183 |
-
# also prepare null indices
|
184 |
-
|
185 |
-
if should_quantize_dropout:
|
186 |
-
|
187 |
-
# check if seed is manually passed in
|
188 |
-
|
189 |
-
if not exists(rand_quantize_dropout_fixed_seed):
|
190 |
-
rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device)
|
191 |
-
|
192 |
-
rand = random.Random(rand_quantize_dropout_fixed_seed)
|
193 |
-
|
194 |
-
rand_quantize_dropout_index = rand.randrange(
|
195 |
-
self.quantize_dropout_cutoff_index, num_quant
|
196 |
-
)
|
197 |
-
|
198 |
-
if quant_dropout_multiple_of != 1:
|
199 |
-
rand_quantize_dropout_index = (
|
200 |
-
round_up_multiple(
|
201 |
-
rand_quantize_dropout_index + 1, quant_dropout_multiple_of
|
202 |
-
)
|
203 |
-
- 1
|
204 |
-
)
|
205 |
-
|
206 |
-
null_indices = torch.full(
|
207 |
-
x.shape[:2], -1.0, device=device, dtype=torch.long
|
208 |
-
)
|
209 |
-
|
210 |
-
# go through the layers
|
211 |
-
|
212 |
-
with autocast("cuda", enabled=False):
|
213 |
-
for quantizer_index, (layer, scale) in enumerate(
|
214 |
-
zip(self.layers, self.scales)
|
215 |
-
):
|
216 |
-
|
217 |
-
if (
|
218 |
-
should_quantize_dropout
|
219 |
-
and quantizer_index > rand_quantize_dropout_index
|
220 |
-
):
|
221 |
-
all_indices.append(null_indices)
|
222 |
-
continue
|
223 |
-
|
224 |
-
quantized, indices = layer(residual / scale)
|
225 |
-
|
226 |
-
quantized = quantized * scale
|
227 |
-
|
228 |
-
residual = residual - quantized.detach()
|
229 |
-
quantized_out = quantized_out + quantized
|
230 |
-
|
231 |
-
all_indices.append(indices)
|
232 |
-
|
233 |
-
# project out, if needed
|
234 |
-
|
235 |
-
quantized_out = self.project_out(quantized_out)
|
236 |
-
|
237 |
-
# stack all indices
|
238 |
-
|
239 |
-
all_indices = torch.stack(all_indices, dim=-1)
|
240 |
-
|
241 |
-
# channel first out
|
242 |
-
|
243 |
-
if self.is_channel_first:
|
244 |
-
(quantized_out,) = unpack(quantized_out, ps, "b * d")
|
245 |
-
(all_indices,) = unpack(all_indices, ps, "b * d")
|
246 |
-
|
247 |
-
quantized_out = rearrange(quantized_out, "b ... d -> b d ...")
|
248 |
-
all_indices = rearrange(all_indices, "b ... d -> b d ...")
|
249 |
-
|
250 |
-
# return
|
251 |
-
|
252 |
-
ret = (quantized_out, all_indices)
|
253 |
-
|
254 |
-
if not return_all_codes:
|
255 |
-
return ret
|
256 |
-
|
257 |
-
# whether to return all codes from all codebooks across layers
|
258 |
-
|
259 |
-
all_codes = self.get_codes_from_indices(all_indices)
|
260 |
-
|
261 |
-
# will return all codes in shape (quantizer, batch, sequence length, codebook dimension)
|
262 |
-
|
263 |
-
return (*ret, all_codes)
|
264 |
-
|
265 |
-
|
266 |
-
# grouped residual fsq
|
267 |
-
|
268 |
-
|
269 |
-
class GroupedResidualFSQ(Module):
|
270 |
-
def __init__(self, *, dim, groups=1, accept_image_fmap=False, **kwargs):
|
271 |
-
super().__init__()
|
272 |
-
self.dim = dim
|
273 |
-
self.groups = groups
|
274 |
-
assert (dim % groups) == 0
|
275 |
-
dim_per_group = dim // groups
|
276 |
-
|
277 |
-
self.accept_image_fmap = accept_image_fmap
|
278 |
-
|
279 |
-
self.rvqs = nn.ModuleList([])
|
280 |
-
|
281 |
-
for _ in range(groups):
|
282 |
-
self.rvqs.append(ResidualFSQ(dim=dim_per_group, **kwargs))
|
283 |
-
|
284 |
-
self.codebook_size = self.rvqs[0].codebook_size
|
285 |
-
|
286 |
-
@property
|
287 |
-
def codebooks(self):
|
288 |
-
return torch.stack(tuple(rvq.codebooks for rvq in self.rvqs))
|
289 |
-
|
290 |
-
@property
|
291 |
-
def split_dim(self):
|
292 |
-
return 1 if self.accept_image_fmap else -1
|
293 |
-
|
294 |
-
def get_codes_from_indices(self, indices):
|
295 |
-
codes = tuple(
|
296 |
-
rvq.get_codes_from_indices(chunk_indices)
|
297 |
-
for rvq, chunk_indices in zip(self.rvqs, indices)
|
298 |
-
)
|
299 |
-
return torch.stack(codes)
|
300 |
-
|
301 |
-
def get_output_from_indices(self, indices):
|
302 |
-
outputs = tuple(
|
303 |
-
rvq.get_output_from_indices(chunk_indices)
|
304 |
-
for rvq, chunk_indices in zip(self.rvqs, indices)
|
305 |
-
)
|
306 |
-
return torch.cat(outputs, dim=self.split_dim)
|
307 |
-
|
308 |
-
def forward(self, x, return_all_codes=False):
|
309 |
-
shape, split_dim, device = x.shape, self.split_dim, x.device
|
310 |
-
assert shape[split_dim] == self.dim
|
311 |
-
|
312 |
-
# split the feature dimension into groups
|
313 |
-
|
314 |
-
x = x.chunk(self.groups, dim=split_dim)
|
315 |
-
|
316 |
-
forward_kwargs = dict(
|
317 |
-
return_all_codes=return_all_codes,
|
318 |
-
rand_quantize_dropout_fixed_seed=(
|
319 |
-
get_maybe_sync_seed(device) if self.training else None
|
320 |
-
),
|
321 |
-
)
|
322 |
-
|
323 |
-
# invoke residual vq on each group
|
324 |
-
|
325 |
-
out = tuple(rvq(chunk, **forward_kwargs) for rvq, chunk in zip(self.rvqs, x))
|
326 |
-
out = tuple(zip(*out))
|
327 |
-
|
328 |
-
# otherwise, get all the zipped outputs and combine them
|
329 |
-
|
330 |
-
quantized, all_indices, *maybe_all_codes = out
|
331 |
-
|
332 |
-
quantized = torch.cat(quantized, dim=split_dim)
|
333 |
-
all_indices = torch.stack(all_indices)
|
334 |
-
|
335 |
-
ret = (quantized, all_indices, *maybe_all_codes)
|
336 |
-
return ret
|
337 |
-
|
338 |
-
|
339 |
-
if __name__ == "__main__":
|
340 |
-
model = ResidualFSQ(
|
341 |
-
levels=[4, 4, 4, 4, 4, 4],
|
342 |
-
num_quantizers=1,
|
343 |
-
dim=30,
|
344 |
-
is_channel_first=True,
|
345 |
-
quantize_dropout=False,
|
346 |
-
)
|
347 |
-
x = torch.randn(2, 30, 10)
|
348 |
-
quantize, embed_ind = model(x)
|
349 |
-
|
350 |
-
emb_from_ind = model.get_output_from_indices(embed_ind.transpose(1, 2))
|
351 |
-
|
352 |
-
print(quantize == emb_from_ind.transpose(1, 2))
|
353 |
-
|
354 |
-
print("quantize shape", quantize.shape)
|
355 |
-
print("embed_ind", embed_ind)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sparktts/modules/speaker/ecapa_tdnn.py
DELETED
@@ -1,267 +0,0 @@
|
|
1 |
-
# Copyright (c) 2021 Zhengyang Chen ([email protected])
|
2 |
-
# 2022 Hongji Wang ([email protected])
|
3 |
-
# 2023 Bing Han ([email protected])
|
4 |
-
#
|
5 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
-
# you may not use this file except in compliance with the License.
|
7 |
-
# You may obtain a copy of the License at
|
8 |
-
#
|
9 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
-
#
|
11 |
-
# Unless required by applicable law or agreed to in writing, software
|
12 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
-
# See the License for the specific language governing permissions and
|
15 |
-
# limitations under the License.
|
16 |
-
|
17 |
-
""" This implementation is adapted from github repo:
|
18 |
-
https://github.com/lawlict/ECAPA-TDNN.
|
19 |
-
"""
|
20 |
-
|
21 |
-
import torch
|
22 |
-
import torch.nn as nn
|
23 |
-
import torch.nn.functional as F
|
24 |
-
|
25 |
-
import sparktts.modules.speaker.pooling_layers as pooling_layers
|
26 |
-
|
27 |
-
|
28 |
-
class Res2Conv1dReluBn(nn.Module):
|
29 |
-
"""
|
30 |
-
in_channels == out_channels == channels
|
31 |
-
"""
|
32 |
-
|
33 |
-
def __init__(
|
34 |
-
self,
|
35 |
-
channels,
|
36 |
-
kernel_size=1,
|
37 |
-
stride=1,
|
38 |
-
padding=0,
|
39 |
-
dilation=1,
|
40 |
-
bias=True,
|
41 |
-
scale=4,
|
42 |
-
):
|
43 |
-
super().__init__()
|
44 |
-
assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
|
45 |
-
self.scale = scale
|
46 |
-
self.width = channels // scale
|
47 |
-
self.nums = scale if scale == 1 else scale - 1
|
48 |
-
|
49 |
-
self.convs = []
|
50 |
-
self.bns = []
|
51 |
-
for i in range(self.nums):
|
52 |
-
self.convs.append(
|
53 |
-
nn.Conv1d(
|
54 |
-
self.width,
|
55 |
-
self.width,
|
56 |
-
kernel_size,
|
57 |
-
stride,
|
58 |
-
padding,
|
59 |
-
dilation,
|
60 |
-
bias=bias,
|
61 |
-
)
|
62 |
-
)
|
63 |
-
self.bns.append(nn.BatchNorm1d(self.width))
|
64 |
-
self.convs = nn.ModuleList(self.convs)
|
65 |
-
self.bns = nn.ModuleList(self.bns)
|
66 |
-
|
67 |
-
def forward(self, x):
|
68 |
-
out = []
|
69 |
-
spx = torch.split(x, self.width, 1)
|
70 |
-
sp = spx[0]
|
71 |
-
for i, (conv, bn) in enumerate(zip(self.convs, self.bns)):
|
72 |
-
# Order: conv -> relu -> bn
|
73 |
-
if i >= 1:
|
74 |
-
sp = sp + spx[i]
|
75 |
-
sp = conv(sp)
|
76 |
-
sp = bn(F.relu(sp))
|
77 |
-
out.append(sp)
|
78 |
-
if self.scale != 1:
|
79 |
-
out.append(spx[self.nums])
|
80 |
-
out = torch.cat(out, dim=1)
|
81 |
-
|
82 |
-
return out
|
83 |
-
|
84 |
-
|
85 |
-
""" Conv1d + BatchNorm1d + ReLU
|
86 |
-
"""
|
87 |
-
|
88 |
-
|
89 |
-
class Conv1dReluBn(nn.Module):
|
90 |
-
|
91 |
-
def __init__(
|
92 |
-
self,
|
93 |
-
in_channels,
|
94 |
-
out_channels,
|
95 |
-
kernel_size=1,
|
96 |
-
stride=1,
|
97 |
-
padding=0,
|
98 |
-
dilation=1,
|
99 |
-
bias=True,
|
100 |
-
):
|
101 |
-
super().__init__()
|
102 |
-
self.conv = nn.Conv1d(
|
103 |
-
in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias
|
104 |
-
)
|
105 |
-
self.bn = nn.BatchNorm1d(out_channels)
|
106 |
-
|
107 |
-
def forward(self, x):
|
108 |
-
return self.bn(F.relu(self.conv(x)))
|
109 |
-
|
110 |
-
|
111 |
-
""" The SE connection of 1D case.
|
112 |
-
"""
|
113 |
-
|
114 |
-
|
115 |
-
class SE_Connect(nn.Module):
|
116 |
-
|
117 |
-
def __init__(self, channels, se_bottleneck_dim=128):
|
118 |
-
super().__init__()
|
119 |
-
self.linear1 = nn.Linear(channels, se_bottleneck_dim)
|
120 |
-
self.linear2 = nn.Linear(se_bottleneck_dim, channels)
|
121 |
-
|
122 |
-
def forward(self, x):
|
123 |
-
out = x.mean(dim=2)
|
124 |
-
out = F.relu(self.linear1(out))
|
125 |
-
out = torch.sigmoid(self.linear2(out))
|
126 |
-
out = x * out.unsqueeze(2)
|
127 |
-
|
128 |
-
return out
|
129 |
-
|
130 |
-
|
131 |
-
""" SE-Res2Block of the ECAPA-TDNN architecture.
|
132 |
-
"""
|
133 |
-
|
134 |
-
|
135 |
-
class SE_Res2Block(nn.Module):
|
136 |
-
|
137 |
-
def __init__(self, channels, kernel_size, stride, padding, dilation, scale):
|
138 |
-
super().__init__()
|
139 |
-
self.se_res2block = nn.Sequential(
|
140 |
-
Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0),
|
141 |
-
Res2Conv1dReluBn(
|
142 |
-
channels, kernel_size, stride, padding, dilation, scale=scale
|
143 |
-
),
|
144 |
-
Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0),
|
145 |
-
SE_Connect(channels),
|
146 |
-
)
|
147 |
-
|
148 |
-
def forward(self, x):
|
149 |
-
return x + self.se_res2block(x)
|
150 |
-
|
151 |
-
|
152 |
-
class ECAPA_TDNN(nn.Module):
|
153 |
-
|
154 |
-
def __init__(
|
155 |
-
self,
|
156 |
-
channels=512,
|
157 |
-
feat_dim=80,
|
158 |
-
embed_dim=192,
|
159 |
-
pooling_func="ASTP",
|
160 |
-
global_context_att=False,
|
161 |
-
emb_bn=False,
|
162 |
-
):
|
163 |
-
super().__init__()
|
164 |
-
|
165 |
-
self.layer1 = Conv1dReluBn(feat_dim, channels, kernel_size=5, padding=2)
|
166 |
-
self.layer2 = SE_Res2Block(
|
167 |
-
channels, kernel_size=3, stride=1, padding=2, dilation=2, scale=8
|
168 |
-
)
|
169 |
-
self.layer3 = SE_Res2Block(
|
170 |
-
channels, kernel_size=3, stride=1, padding=3, dilation=3, scale=8
|
171 |
-
)
|
172 |
-
self.layer4 = SE_Res2Block(
|
173 |
-
channels, kernel_size=3, stride=1, padding=4, dilation=4, scale=8
|
174 |
-
)
|
175 |
-
|
176 |
-
cat_channels = channels * 3
|
177 |
-
out_channels = 512 * 3
|
178 |
-
self.conv = nn.Conv1d(cat_channels, out_channels, kernel_size=1)
|
179 |
-
self.pool = getattr(pooling_layers, pooling_func)(
|
180 |
-
in_dim=out_channels, global_context_att=global_context_att
|
181 |
-
)
|
182 |
-
self.pool_out_dim = self.pool.get_out_dim()
|
183 |
-
self.bn = nn.BatchNorm1d(self.pool_out_dim)
|
184 |
-
self.linear = nn.Linear(self.pool_out_dim, embed_dim)
|
185 |
-
self.emb_bn = emb_bn
|
186 |
-
if emb_bn: # better in SSL for SV
|
187 |
-
self.bn2 = nn.BatchNorm1d(embed_dim)
|
188 |
-
else:
|
189 |
-
self.bn2 = nn.Identity()
|
190 |
-
|
191 |
-
def forward(self, x, return_latent=False):
|
192 |
-
x = x.permute(0, 2, 1) # (B,T,F) -> (B,F,T)
|
193 |
-
|
194 |
-
out1 = self.layer1(x)
|
195 |
-
out2 = self.layer2(out1)
|
196 |
-
out3 = self.layer3(out2)
|
197 |
-
out4 = self.layer4(out3)
|
198 |
-
|
199 |
-
out = torch.cat([out2, out3, out4], dim=1)
|
200 |
-
latent = F.relu(self.conv(out))
|
201 |
-
out = self.bn(self.pool(latent))
|
202 |
-
out = self.linear(out)
|
203 |
-
if self.emb_bn:
|
204 |
-
out = self.bn2(out)
|
205 |
-
|
206 |
-
if return_latent:
|
207 |
-
return out, latent
|
208 |
-
return out
|
209 |
-
|
210 |
-
|
211 |
-
def ECAPA_TDNN_c1024(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False):
|
212 |
-
return ECAPA_TDNN(
|
213 |
-
channels=1024,
|
214 |
-
feat_dim=feat_dim,
|
215 |
-
embed_dim=embed_dim,
|
216 |
-
pooling_func=pooling_func,
|
217 |
-
emb_bn=emb_bn,
|
218 |
-
)
|
219 |
-
|
220 |
-
|
221 |
-
def ECAPA_TDNN_GLOB_c1024(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False):
|
222 |
-
return ECAPA_TDNN(
|
223 |
-
channels=1024,
|
224 |
-
feat_dim=feat_dim,
|
225 |
-
embed_dim=embed_dim,
|
226 |
-
pooling_func=pooling_func,
|
227 |
-
global_context_att=True,
|
228 |
-
emb_bn=emb_bn,
|
229 |
-
)
|
230 |
-
|
231 |
-
|
232 |
-
def ECAPA_TDNN_c512(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False):
|
233 |
-
return ECAPA_TDNN(
|
234 |
-
channels=512,
|
235 |
-
feat_dim=feat_dim,
|
236 |
-
embed_dim=embed_dim,
|
237 |
-
pooling_func=pooling_func,
|
238 |
-
emb_bn=emb_bn,
|
239 |
-
)
|
240 |
-
|
241 |
-
|
242 |
-
def ECAPA_TDNN_GLOB_c512(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False):
|
243 |
-
return ECAPA_TDNN(
|
244 |
-
channels=512,
|
245 |
-
feat_dim=feat_dim,
|
246 |
-
embed_dim=embed_dim,
|
247 |
-
pooling_func=pooling_func,
|
248 |
-
global_context_att=True,
|
249 |
-
emb_bn=emb_bn,
|
250 |
-
)
|
251 |
-
|
252 |
-
|
253 |
-
if __name__ == "__main__":
|
254 |
-
x = torch.zeros(1, 200, 100)
|
255 |
-
model = ECAPA_TDNN_GLOB_c512(feat_dim=100, embed_dim=256, pooling_func="ASTP")
|
256 |
-
model.eval()
|
257 |
-
out, latent = model(x, True)
|
258 |
-
print(out.shape)
|
259 |
-
print(latent.shape)
|
260 |
-
|
261 |
-
num_params = sum(param.numel() for param in model.parameters())
|
262 |
-
print("{} M".format(num_params / 1e6))
|
263 |
-
|
264 |
-
# from thop import profile
|
265 |
-
# x_np = torch.randn(1, 200, 80)
|
266 |
-
# flops, params = profile(model, inputs=(x_np, ))
|
267 |
-
# print("FLOPs: {} G, Params: {} M".format(flops / 1e9, params / 1e6))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sparktts/modules/speaker/perceiver_encoder.py
DELETED
@@ -1,360 +0,0 @@
|
|
1 |
-
# Copyright (c) 2025 SparkAudio
|
2 |
-
# 2025 Xinsheng Wang ([email protected])
|
3 |
-
#
|
4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
-
# you may not use this file except in compliance with the License.
|
6 |
-
# You may obtain a copy of the License at
|
7 |
-
#
|
8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
-
#
|
10 |
-
# Unless required by applicable law or agreed to in writing, software
|
11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
-
# See the License for the specific language governing permissions and
|
14 |
-
# limitations under the License.
|
15 |
-
|
16 |
-
# Adapted from https://github.com/lucidrains/naturalspeech2-pytorch/blob/659bec7f7543e7747e809e950cc2f84242fbeec7/naturalspeech2_pytorch/naturalspeech2_pytorch.py#L532
|
17 |
-
|
18 |
-
from collections import namedtuple
|
19 |
-
from functools import wraps
|
20 |
-
|
21 |
-
import torch
|
22 |
-
import torch.nn.functional as F
|
23 |
-
from einops import rearrange, repeat
|
24 |
-
from einops.layers.torch import Rearrange
|
25 |
-
from packaging import version
|
26 |
-
from torch import einsum, nn
|
27 |
-
|
28 |
-
|
29 |
-
def exists(val):
|
30 |
-
return val is not None
|
31 |
-
|
32 |
-
|
33 |
-
def once(fn):
|
34 |
-
called = False
|
35 |
-
|
36 |
-
@wraps(fn)
|
37 |
-
def inner(x):
|
38 |
-
nonlocal called
|
39 |
-
if called:
|
40 |
-
return
|
41 |
-
called = True
|
42 |
-
return fn(x)
|
43 |
-
|
44 |
-
return inner
|
45 |
-
|
46 |
-
|
47 |
-
print_once = once(print)
|
48 |
-
|
49 |
-
# main class
|
50 |
-
|
51 |
-
|
52 |
-
class Attend(nn.Module):
|
53 |
-
def __init__(self, dropout=0.0, causal=False, use_flash=False):
|
54 |
-
super().__init__()
|
55 |
-
self.dropout = dropout
|
56 |
-
self.attn_dropout = nn.Dropout(dropout)
|
57 |
-
|
58 |
-
self.causal = causal
|
59 |
-
self.register_buffer("mask", None, persistent=False)
|
60 |
-
|
61 |
-
self.use_flash = use_flash
|
62 |
-
assert not (
|
63 |
-
use_flash and version.parse(torch.__version__) < version.parse("2.0.0")
|
64 |
-
), "in order to use flash attention, you must be using pytorch 2.0 or above"
|
65 |
-
|
66 |
-
# determine efficient attention configs for cuda and cpu
|
67 |
-
self.config = namedtuple(
|
68 |
-
"EfficientAttentionConfig",
|
69 |
-
["enable_flash", "enable_math", "enable_mem_efficient"],
|
70 |
-
)
|
71 |
-
self.cpu_config = self.config(True, True, True)
|
72 |
-
self.cuda_config = None
|
73 |
-
|
74 |
-
if not torch.cuda.is_available() or not use_flash:
|
75 |
-
return
|
76 |
-
|
77 |
-
device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
|
78 |
-
|
79 |
-
if device_properties.major == 8 and device_properties.minor == 0:
|
80 |
-
print_once(
|
81 |
-
"A100 GPU detected, using flash attention if input tensor is on cuda"
|
82 |
-
)
|
83 |
-
self.cuda_config = self.config(True, False, False)
|
84 |
-
else:
|
85 |
-
print_once(
|
86 |
-
"Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda"
|
87 |
-
)
|
88 |
-
self.cuda_config = self.config(False, True, True)
|
89 |
-
|
90 |
-
def get_mask(self, n, device):
|
91 |
-
if exists(self.mask) and self.mask.shape[-1] >= n:
|
92 |
-
return self.mask[:n, :n]
|
93 |
-
|
94 |
-
mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
|
95 |
-
self.register_buffer("mask", mask, persistent=False)
|
96 |
-
return mask
|
97 |
-
|
98 |
-
def flash_attn(self, q, k, v, mask=None):
|
99 |
-
_, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda
|
100 |
-
|
101 |
-
# Recommended for multi-query single-key-value attention by Tri Dao
|
102 |
-
# kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
|
103 |
-
|
104 |
-
if k.ndim == 3:
|
105 |
-
k = rearrange(k, "b ... -> b 1 ...").expand_as(q)
|
106 |
-
|
107 |
-
if v.ndim == 3:
|
108 |
-
v = rearrange(v, "b ... -> b 1 ...").expand_as(q)
|
109 |
-
|
110 |
-
# Check if mask exists and expand to compatible shape
|
111 |
-
# The mask is B L, so it would have to be expanded to B H N L
|
112 |
-
|
113 |
-
if exists(mask):
|
114 |
-
mask = rearrange(mask, "b j -> b 1 1 j")
|
115 |
-
mask = mask.expand(-1, heads, q_len, -1)
|
116 |
-
|
117 |
-
# Check if there is a compatible device for flash attention
|
118 |
-
|
119 |
-
config = self.cuda_config if is_cuda else self.cpu_config
|
120 |
-
|
121 |
-
# pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
|
122 |
-
|
123 |
-
with torch.backends.cuda.sdp_kernel(**config._asdict()):
|
124 |
-
out = F.scaled_dot_product_attention(
|
125 |
-
q,
|
126 |
-
k,
|
127 |
-
v,
|
128 |
-
attn_mask=mask,
|
129 |
-
dropout_p=self.dropout if self.training else 0.0,
|
130 |
-
is_causal=self.causal,
|
131 |
-
)
|
132 |
-
|
133 |
-
return out
|
134 |
-
|
135 |
-
def forward(self, q, k, v, mask=None):
|
136 |
-
"""
|
137 |
-
einstein notation
|
138 |
-
b - batch
|
139 |
-
h - heads
|
140 |
-
n, i, j - sequence length (base sequence length, source, target)
|
141 |
-
d - feature dimension
|
142 |
-
"""
|
143 |
-
|
144 |
-
n, device = q.shape[-2], q.device
|
145 |
-
|
146 |
-
scale = q.shape[-1] ** -0.5
|
147 |
-
|
148 |
-
if self.use_flash:
|
149 |
-
return self.flash_attn(q, k, v, mask=mask)
|
150 |
-
|
151 |
-
kv_einsum_eq = "b j d" if k.ndim == 3 else "b h j d"
|
152 |
-
|
153 |
-
# similarity
|
154 |
-
|
155 |
-
sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale
|
156 |
-
|
157 |
-
# key padding mask
|
158 |
-
|
159 |
-
if exists(mask):
|
160 |
-
mask = rearrange(mask, "b j -> b 1 1 j")
|
161 |
-
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
|
162 |
-
|
163 |
-
# causal mask
|
164 |
-
|
165 |
-
if self.causal:
|
166 |
-
causal_mask = self.get_mask(n, device)
|
167 |
-
sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
|
168 |
-
|
169 |
-
# attention
|
170 |
-
|
171 |
-
attn = sim.softmax(dim=-1)
|
172 |
-
attn = self.attn_dropout(attn)
|
173 |
-
|
174 |
-
# aggregate values
|
175 |
-
|
176 |
-
out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v)
|
177 |
-
|
178 |
-
return out
|
179 |
-
|
180 |
-
|
181 |
-
def Sequential(*mods):
|
182 |
-
return nn.Sequential(*filter(exists, mods))
|
183 |
-
|
184 |
-
|
185 |
-
def exists(x):
|
186 |
-
return x is not None
|
187 |
-
|
188 |
-
|
189 |
-
def default(val, d):
|
190 |
-
if exists(val):
|
191 |
-
return val
|
192 |
-
return d() if callable(d) else d
|
193 |
-
|
194 |
-
|
195 |
-
class RMSNorm(nn.Module):
|
196 |
-
def __init__(self, dim, scale=True, dim_cond=None):
|
197 |
-
super().__init__()
|
198 |
-
self.cond = exists(dim_cond)
|
199 |
-
self.to_gamma_beta = nn.Linear(dim_cond, dim * 2) if self.cond else None
|
200 |
-
|
201 |
-
self.scale = dim**0.5
|
202 |
-
self.gamma = nn.Parameter(torch.ones(dim)) if scale else None
|
203 |
-
|
204 |
-
def forward(self, x, cond=None):
|
205 |
-
gamma = default(self.gamma, 1)
|
206 |
-
out = F.normalize(x, dim=-1) * self.scale * gamma
|
207 |
-
|
208 |
-
if not self.cond:
|
209 |
-
return out
|
210 |
-
|
211 |
-
assert exists(cond)
|
212 |
-
gamma, beta = self.to_gamma_beta(cond).chunk(2, dim=-1)
|
213 |
-
gamma, beta = map(lambda t: rearrange(t, "b d -> b 1 d"), (gamma, beta))
|
214 |
-
return out * gamma + beta
|
215 |
-
|
216 |
-
|
217 |
-
class CausalConv1d(nn.Conv1d):
|
218 |
-
def __init__(self, *args, **kwargs):
|
219 |
-
super().__init__(*args, **kwargs)
|
220 |
-
(kernel_size,) = self.kernel_size
|
221 |
-
(dilation,) = self.dilation
|
222 |
-
(stride,) = self.stride
|
223 |
-
|
224 |
-
assert stride == 1
|
225 |
-
self.causal_padding = dilation * (kernel_size - 1)
|
226 |
-
|
227 |
-
def forward(self, x):
|
228 |
-
causal_padded_x = F.pad(x, (self.causal_padding, 0), value=0.0)
|
229 |
-
return super().forward(causal_padded_x)
|
230 |
-
|
231 |
-
|
232 |
-
class GEGLU(nn.Module):
|
233 |
-
def forward(self, x):
|
234 |
-
x, gate = x.chunk(2, dim=-1)
|
235 |
-
return F.gelu(gate) * x
|
236 |
-
|
237 |
-
|
238 |
-
def FeedForward(dim, mult=4, causal_conv=False):
|
239 |
-
dim_inner = int(dim * mult * 2 / 3)
|
240 |
-
|
241 |
-
conv = None
|
242 |
-
if causal_conv:
|
243 |
-
conv = nn.Sequential(
|
244 |
-
Rearrange("b n d -> b d n"),
|
245 |
-
CausalConv1d(dim_inner, dim_inner, 3),
|
246 |
-
Rearrange("b d n -> b n d"),
|
247 |
-
)
|
248 |
-
|
249 |
-
return Sequential(
|
250 |
-
nn.Linear(dim, dim_inner * 2), GEGLU(), conv, nn.Linear(dim_inner, dim)
|
251 |
-
)
|
252 |
-
|
253 |
-
|
254 |
-
class Attention(nn.Module):
|
255 |
-
def __init__(
|
256 |
-
self,
|
257 |
-
dim,
|
258 |
-
*,
|
259 |
-
dim_context=None,
|
260 |
-
causal=False,
|
261 |
-
dim_head=64,
|
262 |
-
heads=8,
|
263 |
-
dropout=0.0,
|
264 |
-
use_flash=False,
|
265 |
-
cross_attn_include_queries=False,
|
266 |
-
):
|
267 |
-
super().__init__()
|
268 |
-
self.scale = dim_head**-0.5
|
269 |
-
self.heads = heads
|
270 |
-
self.cross_attn_include_queries = cross_attn_include_queries
|
271 |
-
|
272 |
-
dim_inner = dim_head * heads
|
273 |
-
dim_context = default(dim_context, dim)
|
274 |
-
|
275 |
-
self.attend = Attend(causal=causal, dropout=dropout, use_flash=use_flash)
|
276 |
-
self.to_q = nn.Linear(dim, dim_inner, bias=False)
|
277 |
-
self.to_kv = nn.Linear(dim_context, dim_inner * 2, bias=False)
|
278 |
-
self.to_out = nn.Linear(dim_inner, dim, bias=False)
|
279 |
-
|
280 |
-
def forward(self, x, context=None, mask=None):
|
281 |
-
h, has_context = self.heads, exists(context)
|
282 |
-
|
283 |
-
context = default(context, x)
|
284 |
-
|
285 |
-
if has_context and self.cross_attn_include_queries:
|
286 |
-
context = torch.cat((x, context), dim=-2)
|
287 |
-
|
288 |
-
q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim=-1))
|
289 |
-
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
|
290 |
-
|
291 |
-
out = self.attend(q, k, v, mask=mask)
|
292 |
-
|
293 |
-
out = rearrange(out, "b h n d -> b n (h d)")
|
294 |
-
return self.to_out(out)
|
295 |
-
|
296 |
-
|
297 |
-
class PerceiverResampler(nn.Module):
|
298 |
-
def __init__(
|
299 |
-
self,
|
300 |
-
*,
|
301 |
-
dim,
|
302 |
-
depth=2,
|
303 |
-
dim_context=None,
|
304 |
-
num_latents=32,
|
305 |
-
dim_head=64,
|
306 |
-
heads=8,
|
307 |
-
ff_mult=4,
|
308 |
-
use_flash_attn=False,
|
309 |
-
):
|
310 |
-
super().__init__()
|
311 |
-
dim_context = default(dim_context, dim)
|
312 |
-
|
313 |
-
self.proj_context = (
|
314 |
-
nn.Linear(dim_context, dim) if dim_context != dim else nn.Identity()
|
315 |
-
)
|
316 |
-
|
317 |
-
self.latents = nn.Parameter(torch.randn(num_latents, dim))
|
318 |
-
nn.init.normal_(self.latents, std=0.02)
|
319 |
-
|
320 |
-
self.layers = nn.ModuleList([])
|
321 |
-
for _ in range(depth):
|
322 |
-
self.layers.append(
|
323 |
-
nn.ModuleList(
|
324 |
-
[
|
325 |
-
Attention(
|
326 |
-
dim=dim,
|
327 |
-
dim_head=dim_head,
|
328 |
-
heads=heads,
|
329 |
-
use_flash=use_flash_attn,
|
330 |
-
cross_attn_include_queries=True,
|
331 |
-
),
|
332 |
-
FeedForward(dim=dim, mult=ff_mult),
|
333 |
-
]
|
334 |
-
)
|
335 |
-
)
|
336 |
-
|
337 |
-
self.norm = RMSNorm(dim)
|
338 |
-
|
339 |
-
def forward(self, x, mask=None):
|
340 |
-
batch = x.shape[0]
|
341 |
-
|
342 |
-
x = self.proj_context(x)
|
343 |
-
|
344 |
-
latents = repeat(self.latents, "n d -> b n d", b=batch)
|
345 |
-
|
346 |
-
for attn, ff in self.layers:
|
347 |
-
latents = attn(latents, x, mask=mask) + latents
|
348 |
-
latents = ff(latents) + latents
|
349 |
-
|
350 |
-
return self.norm(latents)
|
351 |
-
|
352 |
-
|
353 |
-
if __name__ == "__main__":
|
354 |
-
model = PerceiverResampler(dim=256, dim_context=80)
|
355 |
-
x = torch.randn(8, 200, 80)
|
356 |
-
out = model(x)
|
357 |
-
print(out.shape) # [8, 32, 80]
|
358 |
-
|
359 |
-
num_params = sum(param.numel() for param in model.parameters())
|
360 |
-
print("{} M".format(num_params / 1e6))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sparktts/modules/speaker/pooling_layers.py
DELETED
@@ -1,298 +0,0 @@
|
|
1 |
-
# Copyright (c) 2021 Shuai Wang ([email protected])
|
2 |
-
#
|
3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
-
# you may not use this file except in compliance with the License.
|
5 |
-
# You may obtain a copy of the License at
|
6 |
-
#
|
7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
-
#
|
9 |
-
# Unless required by applicable law or agreed to in writing, software
|
10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
-
# See the License for the specific language governing permissions and
|
13 |
-
# limitations under the License.
|
14 |
-
"""
|
15 |
-
Pooling functions to aggregate frame-level deep features
|
16 |
-
into segment-level speaker embeddings
|
17 |
-
|
18 |
-
High-order statistics are surprisingly effective, TSDP acts similarly as TSTP,
|
19 |
-
even though we remove the mean statistic, on Voxceleb.
|
20 |
-
"""
|
21 |
-
|
22 |
-
import torch
|
23 |
-
import torch.nn as nn
|
24 |
-
import torch.nn.functional as F
|
25 |
-
|
26 |
-
|
27 |
-
class TAP(nn.Module):
|
28 |
-
"""
|
29 |
-
Temporal average pooling, only first-order mean is considered
|
30 |
-
"""
|
31 |
-
|
32 |
-
def __init__(self, in_dim=0, **kwargs):
|
33 |
-
super(TAP, self).__init__()
|
34 |
-
self.in_dim = in_dim
|
35 |
-
|
36 |
-
def forward(self, x):
|
37 |
-
pooling_mean = x.mean(dim=-1)
|
38 |
-
# To be compatable with 2D input
|
39 |
-
pooling_mean = pooling_mean.flatten(start_dim=1)
|
40 |
-
return pooling_mean
|
41 |
-
|
42 |
-
def get_out_dim(self):
|
43 |
-
self.out_dim = self.in_dim
|
44 |
-
return self.out_dim
|
45 |
-
|
46 |
-
|
47 |
-
class TSDP(nn.Module):
|
48 |
-
"""
|
49 |
-
Temporal standard deviation pooling, only second-order std is considered
|
50 |
-
"""
|
51 |
-
|
52 |
-
def __init__(self, in_dim=0, **kwargs):
|
53 |
-
super(TSDP, self).__init__()
|
54 |
-
self.in_dim = in_dim
|
55 |
-
|
56 |
-
def forward(self, x):
|
57 |
-
# The last dimension is the temporal axis
|
58 |
-
pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7)
|
59 |
-
pooling_std = pooling_std.flatten(start_dim=1)
|
60 |
-
return pooling_std
|
61 |
-
|
62 |
-
def get_out_dim(self):
|
63 |
-
self.out_dim = self.in_dim
|
64 |
-
return self.out_dim
|
65 |
-
|
66 |
-
|
67 |
-
class TSTP(nn.Module):
|
68 |
-
"""
|
69 |
-
Temporal statistics pooling, concatenate mean and std, which is used in
|
70 |
-
x-vector
|
71 |
-
Comment: simple concatenation can not make full use of both statistics
|
72 |
-
"""
|
73 |
-
|
74 |
-
def __init__(self, in_dim=0, **kwargs):
|
75 |
-
super(TSTP, self).__init__()
|
76 |
-
self.in_dim = in_dim
|
77 |
-
|
78 |
-
def forward(self, x):
|
79 |
-
# The last dimension is the temporal axis
|
80 |
-
pooling_mean = x.mean(dim=-1)
|
81 |
-
pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7)
|
82 |
-
pooling_mean = pooling_mean.flatten(start_dim=1)
|
83 |
-
pooling_std = pooling_std.flatten(start_dim=1)
|
84 |
-
stats = torch.cat((pooling_mean, pooling_std), 1)
|
85 |
-
return stats
|
86 |
-
|
87 |
-
def get_out_dim(self):
|
88 |
-
self.out_dim = self.in_dim * 2
|
89 |
-
return self.out_dim
|
90 |
-
|
91 |
-
|
92 |
-
class ASTP(nn.Module):
|
93 |
-
""" Attentive statistics pooling: Channel- and context-dependent
|
94 |
-
statistics pooling, first used in ECAPA_TDNN.
|
95 |
-
"""
|
96 |
-
|
97 |
-
def __init__(self,
|
98 |
-
in_dim,
|
99 |
-
bottleneck_dim=128,
|
100 |
-
global_context_att=False,
|
101 |
-
**kwargs):
|
102 |
-
super(ASTP, self).__init__()
|
103 |
-
self.in_dim = in_dim
|
104 |
-
self.global_context_att = global_context_att
|
105 |
-
|
106 |
-
# Use Conv1d with stride == 1 rather than Linear, then we don't
|
107 |
-
# need to transpose inputs.
|
108 |
-
if global_context_att:
|
109 |
-
self.linear1 = nn.Conv1d(
|
110 |
-
in_dim * 3, bottleneck_dim,
|
111 |
-
kernel_size=1) # equals W and b in the paper
|
112 |
-
else:
|
113 |
-
self.linear1 = nn.Conv1d(
|
114 |
-
in_dim, bottleneck_dim,
|
115 |
-
kernel_size=1) # equals W and b in the paper
|
116 |
-
self.linear2 = nn.Conv1d(bottleneck_dim, in_dim,
|
117 |
-
kernel_size=1) # equals V and k in the paper
|
118 |
-
|
119 |
-
def forward(self, x):
|
120 |
-
"""
|
121 |
-
x: a 3-dimensional tensor in tdnn-based architecture (B,F,T)
|
122 |
-
or a 4-dimensional tensor in resnet architecture (B,C,F,T)
|
123 |
-
0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
|
124 |
-
"""
|
125 |
-
if len(x.shape) == 4:
|
126 |
-
x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3])
|
127 |
-
assert len(x.shape) == 3
|
128 |
-
|
129 |
-
if self.global_context_att:
|
130 |
-
context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
|
131 |
-
context_std = torch.sqrt(
|
132 |
-
torch.var(x, dim=-1, keepdim=True) + 1e-7).expand_as(x)
|
133 |
-
x_in = torch.cat((x, context_mean, context_std), dim=1)
|
134 |
-
else:
|
135 |
-
x_in = x
|
136 |
-
|
137 |
-
# DON'T use ReLU here! ReLU may be hard to converge.
|
138 |
-
alpha = torch.tanh(
|
139 |
-
self.linear1(x_in)) # alpha = F.relu(self.linear1(x_in))
|
140 |
-
alpha = torch.softmax(self.linear2(alpha), dim=2)
|
141 |
-
mean = torch.sum(alpha * x, dim=2)
|
142 |
-
var = torch.sum(alpha * (x**2), dim=2) - mean**2
|
143 |
-
std = torch.sqrt(var.clamp(min=1e-7))
|
144 |
-
return torch.cat([mean, std], dim=1)
|
145 |
-
|
146 |
-
def get_out_dim(self):
|
147 |
-
self.out_dim = 2 * self.in_dim
|
148 |
-
return self.out_dim
|
149 |
-
|
150 |
-
|
151 |
-
class MHASTP(torch.nn.Module):
|
152 |
-
""" Multi head attentive statistics pooling
|
153 |
-
Reference:
|
154 |
-
Self Multi-Head Attention for Speaker Recognition
|
155 |
-
https://arxiv.org/pdf/1906.09890.pdf
|
156 |
-
"""
|
157 |
-
|
158 |
-
def __init__(self,
|
159 |
-
in_dim,
|
160 |
-
layer_num=2,
|
161 |
-
head_num=2,
|
162 |
-
d_s=1,
|
163 |
-
bottleneck_dim=64,
|
164 |
-
**kwargs):
|
165 |
-
super(MHASTP, self).__init__()
|
166 |
-
assert (in_dim % head_num
|
167 |
-
) == 0 # make sure that head num can be divided by input_dim
|
168 |
-
self.in_dim = in_dim
|
169 |
-
self.head_num = head_num
|
170 |
-
d_model = int(in_dim / head_num)
|
171 |
-
channel_dims = [bottleneck_dim for i in range(layer_num + 1)]
|
172 |
-
if d_s > 1:
|
173 |
-
d_s = d_model
|
174 |
-
else:
|
175 |
-
d_s = 1
|
176 |
-
self.d_s = d_s
|
177 |
-
channel_dims[0], channel_dims[-1] = d_model, d_s
|
178 |
-
heads_att_trans = []
|
179 |
-
for i in range(self.head_num):
|
180 |
-
att_trans = nn.Sequential()
|
181 |
-
for i in range(layer_num - 1):
|
182 |
-
att_trans.add_module(
|
183 |
-
'att_' + str(i),
|
184 |
-
nn.Conv1d(channel_dims[i], channel_dims[i + 1], 1, 1))
|
185 |
-
att_trans.add_module('tanh' + str(i), nn.Tanh())
|
186 |
-
att_trans.add_module(
|
187 |
-
'att_' + str(layer_num - 1),
|
188 |
-
nn.Conv1d(channel_dims[layer_num - 1], channel_dims[layer_num],
|
189 |
-
1, 1))
|
190 |
-
heads_att_trans.append(att_trans)
|
191 |
-
self.heads_att_trans = nn.ModuleList(heads_att_trans)
|
192 |
-
|
193 |
-
def forward(self, input):
|
194 |
-
"""
|
195 |
-
input: a 3-dimensional tensor in xvector architecture
|
196 |
-
or a 4-dimensional tensor in resnet architecture
|
197 |
-
0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
|
198 |
-
"""
|
199 |
-
if len(input.shape) == 4: # B x F x T
|
200 |
-
input = input.reshape(input.shape[0],
|
201 |
-
input.shape[1] * input.shape[2],
|
202 |
-
input.shape[3])
|
203 |
-
assert len(input.shape) == 3
|
204 |
-
bs, f_dim, t_dim = input.shape
|
205 |
-
chunks = torch.chunk(input, self.head_num, 1)
|
206 |
-
# split
|
207 |
-
chunks_out = []
|
208 |
-
# for i in range(self.head_num):
|
209 |
-
# att_score = self.heads_att_trans[i](chunks[i])
|
210 |
-
for i, layer in enumerate(self.heads_att_trans):
|
211 |
-
att_score = layer(chunks[i])
|
212 |
-
alpha = F.softmax(att_score, dim=-1)
|
213 |
-
mean = torch.sum(alpha * chunks[i], dim=2)
|
214 |
-
var = torch.sum(alpha * chunks[i]**2, dim=2) - mean**2
|
215 |
-
std = torch.sqrt(var.clamp(min=1e-7))
|
216 |
-
chunks_out.append(torch.cat((mean, std), dim=1))
|
217 |
-
out = torch.cat(chunks_out, dim=1)
|
218 |
-
return out
|
219 |
-
|
220 |
-
def get_out_dim(self):
|
221 |
-
self.out_dim = 2 * self.in_dim
|
222 |
-
return self.out_dim
|
223 |
-
|
224 |
-
|
225 |
-
class MQMHASTP(torch.nn.Module):
|
226 |
-
""" An attentive pooling
|
227 |
-
Reference:
|
228 |
-
multi query multi head attentive statistics pooling
|
229 |
-
https://arxiv.org/pdf/2110.05042.pdf
|
230 |
-
Args:
|
231 |
-
in_dim: the feature dimension of input
|
232 |
-
layer_num: the number of layer in the pooling layer
|
233 |
-
query_num: the number of querys
|
234 |
-
head_num: the number of heads
|
235 |
-
bottleneck_dim: the bottleneck dimension
|
236 |
-
|
237 |
-
SA (H = 1, Q = 1, n = 2, d_s = 1) ref:
|
238 |
-
https://www.danielpovey.com/files/2018_interspeech_xvector_attention.pdf
|
239 |
-
MHA (H > 1, Q = 1, n = 1, d_s = 1) ref:
|
240 |
-
https://arxiv.org/pdf/1906.09890.pdf
|
241 |
-
AS (H = 1, Q > 1, n = 2, d_s = 1) ref:
|
242 |
-
https://arxiv.org/pdf/1803.10963.pdf
|
243 |
-
VSA (H = 1, Q > 1, n = 2, d_s = d_h) ref:
|
244 |
-
http://www.interspeech2020.org/uploadfile/pdf/Mon-2-10-5.pdf
|
245 |
-
"""
|
246 |
-
|
247 |
-
def __init__(self,
|
248 |
-
in_dim,
|
249 |
-
layer_num=2,
|
250 |
-
query_num=2,
|
251 |
-
head_num=8,
|
252 |
-
d_s=2,
|
253 |
-
bottleneck_dim=64,
|
254 |
-
**kwargs):
|
255 |
-
super(MQMHASTP, self).__init__()
|
256 |
-
self.n_query = nn.ModuleList([
|
257 |
-
MHASTP(in_dim,
|
258 |
-
layer_num=layer_num,
|
259 |
-
head_num=head_num,
|
260 |
-
d_s=d_s,
|
261 |
-
bottleneck_dim=bottleneck_dim) for i in range(query_num)
|
262 |
-
])
|
263 |
-
self.query_num = query_num
|
264 |
-
self.in_dim = in_dim
|
265 |
-
|
266 |
-
def forward(self, input):
|
267 |
-
"""
|
268 |
-
input: a 3-dimensional tensor in xvector architecture
|
269 |
-
or a 4-dimensional tensor in resnet architecture
|
270 |
-
0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
|
271 |
-
"""
|
272 |
-
if len(input.shape) == 4: # B x F x T
|
273 |
-
input = input.reshape(input.shape[0],
|
274 |
-
input.shape[1] * input.shape[2],
|
275 |
-
input.shape[3])
|
276 |
-
assert len(input.shape) == 3
|
277 |
-
res = []
|
278 |
-
for i, layer in enumerate(self.n_query):
|
279 |
-
res.append(layer(input))
|
280 |
-
out = torch.cat(res, dim=-1)
|
281 |
-
return out
|
282 |
-
|
283 |
-
def get_out_dim(self):
|
284 |
-
self.out_dim = self.in_dim * 2 * self.query_num
|
285 |
-
return self.out_dim
|
286 |
-
|
287 |
-
|
288 |
-
if __name__ == '__main__':
|
289 |
-
data = torch.randn(16, 512, 10, 35)
|
290 |
-
# model = StatisticsPooling()
|
291 |
-
model = MQMHASTP(512 * 10)
|
292 |
-
model = MHASTP(512 * 10)
|
293 |
-
model = MQMHASTP(512 * 10, context=False)
|
294 |
-
print(model)
|
295 |
-
|
296 |
-
out = model(data)
|
297 |
-
print(out.shape)
|
298 |
-
print(model.get_out_dim())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sparktts/modules/speaker/speaker_encoder.py
DELETED
@@ -1,136 +0,0 @@
|
|
1 |
-
# Copyright (c) 2025 SparkAudio
|
2 |
-
# 2025 Xinsheng Wang ([email protected])
|
3 |
-
#
|
4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
-
# you may not use this file except in compliance with the License.
|
6 |
-
# You may obtain a copy of the License at
|
7 |
-
#
|
8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
-
#
|
10 |
-
# Unless required by applicable law or agreed to in writing, software
|
11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
-
# See the License for the specific language governing permissions and
|
14 |
-
# limitations under the License.
|
15 |
-
|
16 |
-
import torch
|
17 |
-
import torch.nn as nn
|
18 |
-
|
19 |
-
from typing import List, Tuple
|
20 |
-
from sparktts.modules.fsq.residual_fsq import ResidualFSQ
|
21 |
-
from sparktts.modules.speaker.ecapa_tdnn import ECAPA_TDNN_GLOB_c512
|
22 |
-
from sparktts.modules.speaker.perceiver_encoder import PerceiverResampler
|
23 |
-
|
24 |
-
"""
|
25 |
-
x-vector + d-vector
|
26 |
-
"""
|
27 |
-
|
28 |
-
|
29 |
-
class SpeakerEncoder(nn.Module):
|
30 |
-
"""
|
31 |
-
|
32 |
-
Args:
|
33 |
-
input_dim (int): acoustic feature dimension
|
34 |
-
out_dim (int): output dimension of x-vector and d-vector
|
35 |
-
latent_dim (int): latent dimension before quantization
|
36 |
-
token_num (int): sequence length of speaker tokens
|
37 |
-
fsq_levels (List[int]): number of levels for each quantizer
|
38 |
-
fsq_num_quantizers (int): number of quantizers
|
39 |
-
|
40 |
-
Return:
|
41 |
-
speaker_embs: (B, T2, out_dim)
|
42 |
-
"""
|
43 |
-
|
44 |
-
def __init__(
|
45 |
-
self,
|
46 |
-
input_dim: int = 100,
|
47 |
-
out_dim: int = 512,
|
48 |
-
latent_dim: int = 128,
|
49 |
-
token_num: int = 32,
|
50 |
-
fsq_levels: List[int] = [4, 4, 4, 4, 4, 4],
|
51 |
-
fsq_num_quantizers: int = 1,
|
52 |
-
):
|
53 |
-
super(SpeakerEncoder, self).__init__()
|
54 |
-
|
55 |
-
self.speaker_encoder = ECAPA_TDNN_GLOB_c512(
|
56 |
-
feat_dim=input_dim, embed_dim=out_dim
|
57 |
-
)
|
58 |
-
self.perceiver_sampler = PerceiverResampler(
|
59 |
-
dim=latent_dim, dim_context=512 * 3, num_latents=token_num
|
60 |
-
)
|
61 |
-
self.quantizer = ResidualFSQ(
|
62 |
-
levels=fsq_levels,
|
63 |
-
num_quantizers=fsq_num_quantizers,
|
64 |
-
dim=latent_dim,
|
65 |
-
is_channel_first=True,
|
66 |
-
quantize_dropout=False,
|
67 |
-
)
|
68 |
-
|
69 |
-
self.project = nn.Linear(latent_dim * token_num, out_dim)
|
70 |
-
|
71 |
-
def get_codes_from_indices(self, indices: torch.Tensor) -> torch.Tensor:
|
72 |
-
zq = self.quantizer.get_codes_from_indices(indices.transpose(1, 2))
|
73 |
-
return zq.transpose(1, 2)
|
74 |
-
|
75 |
-
def get_indices(self, mels: torch.Tensor) -> torch.Tensor:
|
76 |
-
mels = mels.transpose(1, 2)
|
77 |
-
x = self.perceiver_sampler(mels).transpose(1, 2)
|
78 |
-
zq, indices = self.quantizer(x)
|
79 |
-
return indices
|
80 |
-
|
81 |
-
def forward(self, mels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
82 |
-
"""
|
83 |
-
Args:
|
84 |
-
mels: (B, D_mel, T1)
|
85 |
-
|
86 |
-
Return:
|
87 |
-
x_vector: (B, out_dim)
|
88 |
-
d_vector: (B, out_dim)
|
89 |
-
"""
|
90 |
-
# mels = mels.transpose(1,2)
|
91 |
-
|
92 |
-
x_vector, features = self.speaker_encoder(mels, True)
|
93 |
-
x = self.perceiver_sampler(features.transpose(1, 2)).transpose(1, 2)
|
94 |
-
zq, indices = self.quantizer(x) # zq: (B, latent_dim, T2, latent_dim)
|
95 |
-
x = zq.reshape(zq.shape[0], -1)
|
96 |
-
d_vector = self.project(x)
|
97 |
-
|
98 |
-
return x_vector, d_vector
|
99 |
-
|
100 |
-
def tokenize(self, mels: torch.Tensor) -> torch.Tensor:
|
101 |
-
"""tokenize the input mel spectrogram"""
|
102 |
-
_, features = self.speaker_encoder(mels, True)
|
103 |
-
x = self.perceiver_sampler(features.transpose(1, 2)).transpose(1, 2)
|
104 |
-
zq, indices = self.quantizer(x)
|
105 |
-
return indices
|
106 |
-
|
107 |
-
def detokenize(self, indices: torch.Tensor) -> torch.Tensor:
|
108 |
-
"""detokenize the input indices to d-vector"""
|
109 |
-
zq = self.quantizer.get_output_from_indices(indices.transpose(1, 2)).transpose(1, 2)
|
110 |
-
x = zq.reshape(zq.shape[0], -1)
|
111 |
-
d_vector = self.project(x)
|
112 |
-
return d_vector
|
113 |
-
|
114 |
-
if __name__ == "__main__":
|
115 |
-
model = SpeakerEncoder(
|
116 |
-
input_dim=100,
|
117 |
-
latent_dim=128,
|
118 |
-
token_num=32,
|
119 |
-
fsq_levels=[4, 4, 4, 4, 4, 4],
|
120 |
-
fsq_num_quantizers=1,
|
121 |
-
)
|
122 |
-
mel = torch.randn(8, 200, 100)
|
123 |
-
x_vector, d_vector = model(mel)
|
124 |
-
print("x-vector shape", x_vector.shape)
|
125 |
-
print("d-vector shape", d_vector.shape)
|
126 |
-
|
127 |
-
indices = model.tokenize(mel)
|
128 |
-
print("indices shape", indices.shape)
|
129 |
-
d_vector_post = model.detokenize(indices)
|
130 |
-
print("d-vector shape", d_vector_post.shape)
|
131 |
-
if d_vector_post.all() == d_vector.all():
|
132 |
-
print("d-vector post and d-vector are the same")
|
133 |
-
else:
|
134 |
-
print("d-vector post and d-vector are different")
|
135 |
-
num_params = sum(param.numel() for param in model.parameters())
|
136 |
-
print("{} M".format(num_params / 1e6))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sparktts/modules/vq/factorized_vector_quantize.py
DELETED
@@ -1,187 +0,0 @@
|
|
1 |
-
# Copyright (c) 2025 SparkAudio
|
2 |
-
# 2025 Xinsheng Wang ([email protected])
|
3 |
-
#
|
4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
-
# you may not use this file except in compliance with the License.
|
6 |
-
# You may obtain a copy of the License at
|
7 |
-
#
|
8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
-
#
|
10 |
-
# Unless required by applicable law or agreed to in writing, software
|
11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
-
# See the License for the specific language governing permissions and
|
14 |
-
# limitations under the License.
|
15 |
-
|
16 |
-
# Heavily based on https://github.com/lucidrains/vector-quantize-pytorch
|
17 |
-
|
18 |
-
|
19 |
-
from typing import Any, Dict
|
20 |
-
|
21 |
-
import torch
|
22 |
-
import torch.nn as nn
|
23 |
-
import torch.nn.functional as F
|
24 |
-
from einops import rearrange
|
25 |
-
from torch.nn.utils import weight_norm
|
26 |
-
|
27 |
-
|
28 |
-
def WNConv1d(*args, **kwargs):
|
29 |
-
return weight_norm(nn.Conv1d(*args, **kwargs))
|
30 |
-
|
31 |
-
|
32 |
-
def ema_inplace(moving_avg, new, decay):
|
33 |
-
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
|
34 |
-
|
35 |
-
|
36 |
-
class FactorizedVectorQuantize(nn.Module):
|
37 |
-
def __init__(
|
38 |
-
self,
|
39 |
-
input_dim: int,
|
40 |
-
codebook_size: int,
|
41 |
-
codebook_dim: int,
|
42 |
-
commitment: float,
|
43 |
-
codebook_loss_weight: float = 1.0,
|
44 |
-
decay: float = 0.99,
|
45 |
-
threshold_ema_dead_code: float = 2,
|
46 |
-
momentum: float = 0.99,
|
47 |
-
**kwargs,
|
48 |
-
):
|
49 |
-
super().__init__()
|
50 |
-
self.input_dim = input_dim
|
51 |
-
self.codebook_size = codebook_size
|
52 |
-
self.codebook_dim = codebook_dim
|
53 |
-
self.commitment = commitment
|
54 |
-
self.codebook_loss_weight = codebook_loss_weight
|
55 |
-
self.decay = decay
|
56 |
-
self.threshold_ema_dead_code = threshold_ema_dead_code
|
57 |
-
self.momentum = momentum
|
58 |
-
|
59 |
-
if input_dim != self.codebook_dim:
|
60 |
-
self.in_project = WNConv1d(input_dim, self.codebook_dim, kernel_size=1)
|
61 |
-
self.out_project = WNConv1d(self.codebook_dim, input_dim, kernel_size=1)
|
62 |
-
|
63 |
-
else:
|
64 |
-
self.in_project = nn.Identity()
|
65 |
-
self.out_project = nn.Identity()
|
66 |
-
|
67 |
-
self.codebook = nn.Embedding(self.codebook_size, self.codebook_dim)
|
68 |
-
self.register_buffer("cluster_size", torch.zeros(self.codebook_size))
|
69 |
-
|
70 |
-
def forward(self, z: torch.Tensor) -> Dict[str, Any]:
|
71 |
-
"""Quantized the input tensor using a fixed codebook and returns
|
72 |
-
the corresponding codebook vectors
|
73 |
-
|
74 |
-
Parameters
|
75 |
-
----------
|
76 |
-
z : Tensor[B x D x T]
|
77 |
-
|
78 |
-
Returns
|
79 |
-
-------
|
80 |
-
Tensor[B x D x T]
|
81 |
-
Quantized continuous representation of input
|
82 |
-
Tensor[1]
|
83 |
-
Commitment loss to train encoder to predict vectors closer to codebook
|
84 |
-
entries
|
85 |
-
Tensor[1]
|
86 |
-
Codebook loss to update the codebook
|
87 |
-
Tensor[B x T]
|
88 |
-
Codebook indices (quantized discrete representation of input)
|
89 |
-
Tensor[B x D x T]
|
90 |
-
Projected latents (continuous representation of input before quantization)
|
91 |
-
"""
|
92 |
-
# transpose since we use linear
|
93 |
-
|
94 |
-
# Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim
|
95 |
-
z_e = self.in_project(z)
|
96 |
-
z_q, indices, dists = self.decode_latents(z_e)
|
97 |
-
|
98 |
-
# statistic the usage of codes
|
99 |
-
embed_onehot = F.one_hot(indices, self.codebook_size).type(z_e.dtype)
|
100 |
-
avg_probs = torch.mean(embed_onehot.reshape(-1, self.codebook_size), dim=0)
|
101 |
-
perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
|
102 |
-
|
103 |
-
active_num = (embed_onehot.sum(0).sum(0) > 0).sum()
|
104 |
-
if self.training:
|
105 |
-
# We do the expiry of code at that point as buffers are in sync
|
106 |
-
# and all the workers will take the same decision.
|
107 |
-
ema_inplace(self.cluster_size, embed_onehot.sum(0).sum(0), self.decay)
|
108 |
-
active_num = sum(self.cluster_size > self.threshold_ema_dead_code)
|
109 |
-
|
110 |
-
if self.training:
|
111 |
-
commit_loss = (
|
112 |
-
F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
|
113 |
-
* self.commitment
|
114 |
-
)
|
115 |
-
|
116 |
-
codebook_loss = (
|
117 |
-
F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
|
118 |
-
* self.codebook_loss_weight
|
119 |
-
)
|
120 |
-
|
121 |
-
else:
|
122 |
-
commit_loss = torch.zeros(0, device=z.device)
|
123 |
-
codebook_loss = torch.zeros(0, device=z.device)
|
124 |
-
|
125 |
-
z_q = (
|
126 |
-
z_e + (z_q - z_e).detach()
|
127 |
-
) # noop in forward pass, straight-through gradient estimator in backward pass
|
128 |
-
|
129 |
-
z_q = self.out_project(z_q)
|
130 |
-
|
131 |
-
vq_loss = (commit_loss + codebook_loss).mean()
|
132 |
-
|
133 |
-
return {
|
134 |
-
"z_q": z_q,
|
135 |
-
"indices": indices,
|
136 |
-
"dists": dists,
|
137 |
-
"vq_loss": vq_loss,
|
138 |
-
"perplexity": perplexity,
|
139 |
-
"active_num": active_num.float(),
|
140 |
-
}
|
141 |
-
|
142 |
-
def vq2emb(self, vq, out_proj=True):
|
143 |
-
emb = self.embed_code(vq)
|
144 |
-
if out_proj:
|
145 |
-
emb = self.out_project(emb)
|
146 |
-
return emb
|
147 |
-
|
148 |
-
def tokenize(self, z: torch.Tensor) -> torch.Tensor:
|
149 |
-
"""tokenize the input tensor"""
|
150 |
-
z_e = self.in_project(z)
|
151 |
-
_, indices, _ = self.decode_latents(z_e)
|
152 |
-
return indices
|
153 |
-
|
154 |
-
def detokenize(self, indices):
|
155 |
-
"""detokenize the input indices"""
|
156 |
-
z_q = self.decode_code(indices)
|
157 |
-
z_q = self.out_project(z_q)
|
158 |
-
return z_q
|
159 |
-
|
160 |
-
def get_emb(self):
|
161 |
-
return self.codebook.weight
|
162 |
-
|
163 |
-
def embed_code(self, embed_id):
|
164 |
-
return F.embedding(embed_id, self.codebook.weight)
|
165 |
-
|
166 |
-
def decode_code(self, embed_id):
|
167 |
-
return self.embed_code(embed_id).transpose(1, 2)
|
168 |
-
|
169 |
-
def decode_latents(self, latents):
|
170 |
-
encodings = rearrange(latents, "b d t -> (b t) d")
|
171 |
-
codebook = self.codebook.weight
|
172 |
-
|
173 |
-
# L2 normalize encodings and codebook
|
174 |
-
encodings = F.normalize(encodings)
|
175 |
-
codebook = F.normalize(codebook)
|
176 |
-
|
177 |
-
# Compute euclidean distance between encodings and codebook,
|
178 |
-
# with L2 normalization, the distance is equal to cosine distance
|
179 |
-
dist = (
|
180 |
-
encodings.pow(2).sum(1, keepdim=True)
|
181 |
-
- 2 * encodings @ codebook.t()
|
182 |
-
+ codebook.pow(2).sum(1, keepdim=True).t()
|
183 |
-
)
|
184 |
-
indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
|
185 |
-
z_q = self.decode_code(indices)
|
186 |
-
|
187 |
-
return z_q, indices, dist
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sparktts/utils/__init__.py
DELETED
File without changes
|
sparktts/utils/audio.py
DELETED
@@ -1,271 +0,0 @@
|
|
1 |
-
# Copyright (c) 2025 SparkAudio
|
2 |
-
# 2025 Xinsheng Wang ([email protected])
|
3 |
-
#
|
4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
-
# you may not use this file except in compliance with the License.
|
6 |
-
# You may obtain a copy of the License at
|
7 |
-
#
|
8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
-
#
|
10 |
-
# Unless required by applicable law or agreed to in writing, software
|
11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
-
# See the License for the specific language governing permissions and
|
14 |
-
# limitations under the License.
|
15 |
-
"""
|
16 |
-
Description:
|
17 |
-
This script contains a collection of functions designed to handle various
|
18 |
-
audio processing.
|
19 |
-
"""
|
20 |
-
|
21 |
-
import random
|
22 |
-
import soxr
|
23 |
-
import soundfile
|
24 |
-
import torch
|
25 |
-
import torchaudio
|
26 |
-
import numpy as np
|
27 |
-
|
28 |
-
from pathlib import Path
|
29 |
-
from typing import Tuple
|
30 |
-
from numpy.lib.stride_tricks import sliding_window_view
|
31 |
-
|
32 |
-
|
33 |
-
def audio_volume_normalize(audio: np.ndarray, coeff: float = 0.2) -> np.ndarray:
|
34 |
-
"""
|
35 |
-
Normalize the volume of an audio signal.
|
36 |
-
|
37 |
-
Parameters:
|
38 |
-
audio (numpy array): Input audio signal array.
|
39 |
-
coeff (float): Target coefficient for normalization, default is 0.2.
|
40 |
-
|
41 |
-
Returns:
|
42 |
-
numpy array: The volume-normalized audio signal.
|
43 |
-
"""
|
44 |
-
# Sort the absolute values of the audio signal
|
45 |
-
temp = np.sort(np.abs(audio))
|
46 |
-
|
47 |
-
# If the maximum value is less than 0.1, scale the array to have a maximum of 0.1
|
48 |
-
if temp[-1] < 0.1:
|
49 |
-
scaling_factor = max(
|
50 |
-
temp[-1], 1e-3
|
51 |
-
) # Prevent division by zero with a small constant
|
52 |
-
audio = audio / scaling_factor * 0.1
|
53 |
-
|
54 |
-
# Filter out values less than 0.01 from temp
|
55 |
-
temp = temp[temp > 0.01]
|
56 |
-
L = temp.shape[0] # Length of the filtered array
|
57 |
-
|
58 |
-
# If there are fewer than or equal to 10 significant values, return the audio without further processing
|
59 |
-
if L <= 10:
|
60 |
-
return audio
|
61 |
-
|
62 |
-
# Compute the average of the top 10% to 1% of values in temp
|
63 |
-
volume = np.mean(temp[int(0.9 * L) : int(0.99 * L)])
|
64 |
-
|
65 |
-
# Normalize the audio to the target coefficient level, clamping the scale factor between 0.1 and 10
|
66 |
-
audio = audio * np.clip(coeff / volume, a_min=0.1, a_max=10)
|
67 |
-
|
68 |
-
# Ensure the maximum absolute value in the audio does not exceed 1
|
69 |
-
max_value = np.max(np.abs(audio))
|
70 |
-
if max_value > 1:
|
71 |
-
audio = audio / max_value
|
72 |
-
|
73 |
-
return audio
|
74 |
-
|
75 |
-
|
76 |
-
def load_audio(
|
77 |
-
adfile: Path,
|
78 |
-
sampling_rate: int = None,
|
79 |
-
length: int = None,
|
80 |
-
volume_normalize: bool = False,
|
81 |
-
segment_duration: int = None,
|
82 |
-
) -> np.ndarray:
|
83 |
-
r"""Load audio file with target sampling rate and lsength
|
84 |
-
|
85 |
-
Args:
|
86 |
-
adfile (Path): path to audio file.
|
87 |
-
sampling_rate (int, optional): target sampling rate. Defaults to None.
|
88 |
-
length (int, optional): target audio length. Defaults to None.
|
89 |
-
volume_normalize (bool, optional): whether perform volume normalization. Defaults to False.
|
90 |
-
segment_duration (int): random select a segment with duration of {segment_duration}s.
|
91 |
-
Defualt to None which means the whole audio will be used.
|
92 |
-
|
93 |
-
Returns:
|
94 |
-
audio (np.ndarray): audio
|
95 |
-
"""
|
96 |
-
|
97 |
-
audio, sr = soundfile.read(adfile)
|
98 |
-
if len(audio.shape) > 1:
|
99 |
-
audio = audio[:, 0]
|
100 |
-
|
101 |
-
if sampling_rate is not None and sr != sampling_rate:
|
102 |
-
audio = soxr.resample(audio, sr, sampling_rate, quality="VHQ")
|
103 |
-
sr = sampling_rate
|
104 |
-
|
105 |
-
if segment_duration is not None:
|
106 |
-
seg_length = int(sr * segment_duration)
|
107 |
-
audio = random_select_audio_segment(audio, seg_length)
|
108 |
-
|
109 |
-
# Audio volume normalize
|
110 |
-
if volume_normalize:
|
111 |
-
audio = audio_volume_normalize(audio)
|
112 |
-
# check the audio length
|
113 |
-
if length is not None:
|
114 |
-
assert abs(audio.shape[0] - length) < 1000
|
115 |
-
if audio.shape[0] > length:
|
116 |
-
audio = audio[:length]
|
117 |
-
else:
|
118 |
-
audio = np.pad(audio, (0, int(length - audio.shape[0])))
|
119 |
-
return audio
|
120 |
-
|
121 |
-
|
122 |
-
def random_select_audio_segment(audio: np.ndarray, length: int) -> np.ndarray:
|
123 |
-
"""get an audio segment given the length
|
124 |
-
|
125 |
-
Args:
|
126 |
-
audio (np.ndarray):
|
127 |
-
length (int): audio length = sampling_rate * duration
|
128 |
-
"""
|
129 |
-
if audio.shape[0] < length:
|
130 |
-
audio = np.pad(audio, (0, int(length - audio.shape[0])))
|
131 |
-
start_index = random.randint(0, audio.shape[0] - length)
|
132 |
-
end_index = int(start_index + length)
|
133 |
-
|
134 |
-
return audio[start_index:end_index]
|
135 |
-
|
136 |
-
|
137 |
-
def audio_highpass_filter(audio, sample_rate, highpass_cutoff_freq):
|
138 |
-
"""apply highpass fileter to audio
|
139 |
-
|
140 |
-
Args:
|
141 |
-
audio (np.ndarray):
|
142 |
-
sample_rate (ind):
|
143 |
-
highpass_cutoff_freq (int):
|
144 |
-
"""
|
145 |
-
|
146 |
-
audio = torchaudio.functional.highpass_biquad(
|
147 |
-
torch.from_numpy(audio), sample_rate, cutoff_freq=highpass_cutoff_freq
|
148 |
-
)
|
149 |
-
return audio.numpy()
|
150 |
-
|
151 |
-
|
152 |
-
def stft(
|
153 |
-
x: torch.Tensor,
|
154 |
-
fft_size: int,
|
155 |
-
hop_size: int,
|
156 |
-
win_length: int,
|
157 |
-
window: str,
|
158 |
-
use_complex: bool = False,
|
159 |
-
) -> torch.Tensor:
|
160 |
-
"""Perform STFT and convert to magnitude spectrogram.
|
161 |
-
Args:
|
162 |
-
x (Tensor): Input signal tensor (B, T).
|
163 |
-
fft_size (int): FFT size.
|
164 |
-
hop_size (int): Hop size.
|
165 |
-
win_length (int): Window length.
|
166 |
-
window (str): Window function type.
|
167 |
-
Returns:
|
168 |
-
Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
|
169 |
-
"""
|
170 |
-
|
171 |
-
x_stft = torch.stft(
|
172 |
-
x, fft_size, hop_size, win_length, window.to(x.device), return_complex=True
|
173 |
-
)
|
174 |
-
|
175 |
-
# clamp is needed to avoid nan or inf
|
176 |
-
if not use_complex:
|
177 |
-
return torch.sqrt(
|
178 |
-
torch.clamp(x_stft.real**2 + x_stft.imag**2, min=1e-7, max=1e3)
|
179 |
-
).transpose(2, 1)
|
180 |
-
else:
|
181 |
-
res = torch.cat([x_stft.real.unsqueeze(1), x_stft.imag.unsqueeze(1)], dim=1)
|
182 |
-
res = res.transpose(2, 3) # [B, 2, T, F]
|
183 |
-
return res
|
184 |
-
|
185 |
-
|
186 |
-
def detect_speech_boundaries(
|
187 |
-
wav: np.ndarray,
|
188 |
-
sample_rate: int,
|
189 |
-
window_duration: float = 0.1,
|
190 |
-
energy_threshold: float = 0.01,
|
191 |
-
margin_factor: int = 2
|
192 |
-
) -> Tuple[int, int]:
|
193 |
-
"""Detect the start and end points of speech in an audio signal using RMS energy.
|
194 |
-
|
195 |
-
Args:
|
196 |
-
wav: Input audio signal array with values in [-1, 1]
|
197 |
-
sample_rate: Audio sample rate in Hz
|
198 |
-
window_duration: Duration of detection window in seconds
|
199 |
-
energy_threshold: RMS energy threshold for speech detection
|
200 |
-
margin_factor: Factor to determine extra margin around detected boundaries
|
201 |
-
|
202 |
-
Returns:
|
203 |
-
tuple: (start_index, end_index) of speech segment
|
204 |
-
|
205 |
-
Raises:
|
206 |
-
ValueError: If the audio contains only silence
|
207 |
-
"""
|
208 |
-
window_size = int(window_duration * sample_rate)
|
209 |
-
margin = margin_factor * window_size
|
210 |
-
step_size = window_size // 10
|
211 |
-
|
212 |
-
# Create sliding windows using stride tricks to avoid loops
|
213 |
-
windows = sliding_window_view(wav, window_size)[::step_size]
|
214 |
-
|
215 |
-
# Calculate RMS energy for each window
|
216 |
-
energy = np.sqrt(np.mean(windows ** 2, axis=1))
|
217 |
-
speech_mask = energy >= energy_threshold
|
218 |
-
|
219 |
-
if not np.any(speech_mask):
|
220 |
-
raise ValueError("No speech detected in audio (only silence)")
|
221 |
-
|
222 |
-
start = max(0, np.argmax(speech_mask) * step_size - margin)
|
223 |
-
end = min(len(wav), (len(speech_mask) - 1 - np.argmax(speech_mask[::-1])) * step_size + margin)
|
224 |
-
|
225 |
-
return start, end
|
226 |
-
|
227 |
-
|
228 |
-
def remove_silence_on_both_ends(
|
229 |
-
wav: np.ndarray,
|
230 |
-
sample_rate: int,
|
231 |
-
window_duration: float = 0.1,
|
232 |
-
volume_threshold: float = 0.01
|
233 |
-
) -> np.ndarray:
|
234 |
-
"""Remove silence from both ends of an audio signal.
|
235 |
-
|
236 |
-
Args:
|
237 |
-
wav: Input audio signal array
|
238 |
-
sample_rate: Audio sample rate in Hz
|
239 |
-
window_duration: Duration of detection window in seconds
|
240 |
-
volume_threshold: Amplitude threshold for silence detection
|
241 |
-
|
242 |
-
Returns:
|
243 |
-
np.ndarray: Audio signal with silence removed from both ends
|
244 |
-
|
245 |
-
Raises:
|
246 |
-
ValueError: If the audio contains only silence
|
247 |
-
"""
|
248 |
-
start, end = detect_speech_boundaries(
|
249 |
-
wav,
|
250 |
-
sample_rate,
|
251 |
-
window_duration,
|
252 |
-
volume_threshold
|
253 |
-
)
|
254 |
-
return wav[start:end]
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
def hertz_to_mel(pitch: float) -> float:
|
259 |
-
"""
|
260 |
-
Converts a frequency from the Hertz scale to the Mel scale.
|
261 |
-
|
262 |
-
Parameters:
|
263 |
-
- pitch: float or ndarray
|
264 |
-
Frequency in Hertz.
|
265 |
-
|
266 |
-
Returns:
|
267 |
-
- mel: float or ndarray
|
268 |
-
Frequency in Mel scale.
|
269 |
-
"""
|
270 |
-
mel = 2595 * np.log10(1 + pitch / 700)
|
271 |
-
return mel
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sparktts/utils/file.py
DELETED
@@ -1,221 +0,0 @@
|
|
1 |
-
# Copyright (c) 2025 SparkAudio
|
2 |
-
# 2025 Xinsheng Wang ([email protected])
|
3 |
-
#
|
4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
-
# you may not use this file except in compliance with the License.
|
6 |
-
# You may obtain a copy of the License at
|
7 |
-
#
|
8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
-
#
|
10 |
-
# Unless required by applicable law or agreed to in writing, software
|
11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
-
# See the License for the specific language governing permissions and
|
14 |
-
# limitations under the License.
|
15 |
-
"""
|
16 |
-
Description:
|
17 |
-
This script contains a collection of functions designed to handle various
|
18 |
-
file reading and writing operations. It provides utilities to read from files,
|
19 |
-
write data to files, and perform file manipulation tasks.
|
20 |
-
"""
|
21 |
-
|
22 |
-
|
23 |
-
import os
|
24 |
-
import json
|
25 |
-
import json
|
26 |
-
import csv
|
27 |
-
|
28 |
-
from tqdm import tqdm
|
29 |
-
from typing import List, Dict, Any, Set, Union
|
30 |
-
from pathlib import Path
|
31 |
-
from omegaconf import OmegaConf, DictConfig
|
32 |
-
|
33 |
-
|
34 |
-
def resolve_symbolic_link(symbolic_link_path: Path) -> Path:
|
35 |
-
"""
|
36 |
-
Resolves the absolute path of a symbolic link.
|
37 |
-
|
38 |
-
Args:
|
39 |
-
symbolic_link_path (Path): The path to the symbolic link.
|
40 |
-
|
41 |
-
Returns:
|
42 |
-
Path: The absolute path that the symbolic link points to.
|
43 |
-
"""
|
44 |
-
|
45 |
-
link_directory = os.path.dirname(symbolic_link_path)
|
46 |
-
target_path_relative = os.readlink(symbolic_link_path)
|
47 |
-
return os.path.join(link_directory, target_path_relative)
|
48 |
-
|
49 |
-
|
50 |
-
def write_jsonl(metadata: List[dict], file_path: Path) -> None:
|
51 |
-
"""Writes a list of dictionaries to a JSONL file.
|
52 |
-
|
53 |
-
Args:
|
54 |
-
metadata : List[dict]
|
55 |
-
A list of dictionaries, each representing a piece of meta.
|
56 |
-
file_path : Path
|
57 |
-
The file path to save the JSONL file
|
58 |
-
|
59 |
-
This function writes each dictionary in the list to a new line in the specified file.
|
60 |
-
"""
|
61 |
-
with open(file_path, "w", encoding="utf-8") as f:
|
62 |
-
for meta in tqdm(metadata, desc="writing jsonl"):
|
63 |
-
# Convert dictionary to JSON string and write it to the file with a newline
|
64 |
-
json_str = json.dumps(meta, ensure_ascii=False) + "\n"
|
65 |
-
f.write(json_str)
|
66 |
-
print(f"jsonl saved to {file_path}")
|
67 |
-
|
68 |
-
|
69 |
-
def read_jsonl(file_path: Path) -> List[dict]:
|
70 |
-
"""
|
71 |
-
Reads a JSONL file and returns a list of dictionaries.
|
72 |
-
|
73 |
-
Args:
|
74 |
-
file_path : Path
|
75 |
-
The path to the JSONL file to be read.
|
76 |
-
|
77 |
-
Returns:
|
78 |
-
List[dict]
|
79 |
-
A list of dictionaries parsed from each line of the JSONL file.
|
80 |
-
"""
|
81 |
-
metadata = []
|
82 |
-
# Open the file for reading
|
83 |
-
with open(file_path, "r", encoding="utf-8") as f:
|
84 |
-
# Split the file into lines
|
85 |
-
lines = f.read().splitlines()
|
86 |
-
# Process each line
|
87 |
-
for line in lines:
|
88 |
-
# Convert JSON string back to dictionary and append to list
|
89 |
-
meta = json.loads(line)
|
90 |
-
metadata.append(meta)
|
91 |
-
# Return the list of metadata
|
92 |
-
return metadata
|
93 |
-
|
94 |
-
def read_json_as_jsonl(file_path: Path) -> List[dict]:
|
95 |
-
metadata = []
|
96 |
-
with open(file_path, 'r', encoding='utf-8') as infile:
|
97 |
-
data = json.load(infile)
|
98 |
-
for k in sorted(data.keys()):
|
99 |
-
meta = {'index': k}
|
100 |
-
meta.update(data[k])
|
101 |
-
metadata.append(meta)
|
102 |
-
return metadata
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
def decode_unicode_strings(meta: Dict[str, Any]) -> Dict[str, Any]:
|
107 |
-
processed_meta = {}
|
108 |
-
for k, v in meta.items():
|
109 |
-
if isinstance(v, str):
|
110 |
-
processed_meta[k] = v.encode("utf-8").decode("unicode_escape")
|
111 |
-
else:
|
112 |
-
processed_meta[k] = v
|
113 |
-
return processed_meta
|
114 |
-
|
115 |
-
|
116 |
-
def load_config(config_path: Path) -> DictConfig:
|
117 |
-
"""Loads a configuration file and optionally merges it with a base configuration.
|
118 |
-
|
119 |
-
Args:
|
120 |
-
config_path (Path): Path to the configuration file.
|
121 |
-
"""
|
122 |
-
# Load the initial configuration from the given path
|
123 |
-
config = OmegaConf.load(config_path)
|
124 |
-
|
125 |
-
# Check if there is a base configuration specified and merge if necessary
|
126 |
-
if config.get("base_config", None) is not None:
|
127 |
-
base_config = OmegaConf.load(config["base_config"])
|
128 |
-
config = OmegaConf.merge(base_config, config)
|
129 |
-
|
130 |
-
return config
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
def jsonl_to_csv(jsonl_file_path: str, csv_file_path: str) -> None:
|
135 |
-
"""
|
136 |
-
Converts a JSONL file to a CSV file.
|
137 |
-
|
138 |
-
This function reads a JSONL file, determines all unique keys present in the file,
|
139 |
-
and writes the data to a CSV file with columns for all these keys.
|
140 |
-
"""
|
141 |
-
|
142 |
-
all_keys = set()
|
143 |
-
data_rows = []
|
144 |
-
|
145 |
-
# Read the JSONL file once to extract keys and collect data
|
146 |
-
with open(jsonl_file_path, 'r') as file:
|
147 |
-
for line in file:
|
148 |
-
data = json.loads(line.strip())
|
149 |
-
data_rows.append(data)
|
150 |
-
all_keys.update(data.keys())
|
151 |
-
|
152 |
-
# Convert the set of keys to a sorted list for consistent column order
|
153 |
-
sorted_keys = sorted(all_keys)
|
154 |
-
|
155 |
-
# Write the data to a CSV file
|
156 |
-
with open(csv_file_path, 'w', newline='') as csvfile:
|
157 |
-
writer = csv.DictWriter(csvfile, fieldnames=sorted_keys)
|
158 |
-
|
159 |
-
# Write the header row
|
160 |
-
writer.writeheader()
|
161 |
-
|
162 |
-
# Write each row of data
|
163 |
-
for data in data_rows:
|
164 |
-
writer.writerow(data)
|
165 |
-
|
166 |
-
print(f"CSV file has been created at {csv_file_path}")
|
167 |
-
|
168 |
-
|
169 |
-
def save_metadata(data, filename, headers=None):
|
170 |
-
"""
|
171 |
-
Save metadata to a file.
|
172 |
-
|
173 |
-
Args:
|
174 |
-
data (list of dict): Metadata to be saved.
|
175 |
-
filename (str): Name of the file to save the metadata.
|
176 |
-
headers (list of str): The order of column names to be saved; defaults to the keys from the first dictionary in data if not provided.
|
177 |
-
"""
|
178 |
-
# Set headers to keys from the first dictionary in data if not explicitly provided
|
179 |
-
if headers is None:
|
180 |
-
headers = list(data[0].keys())
|
181 |
-
|
182 |
-
with open(filename, "w", encoding="utf-8") as file:
|
183 |
-
# Write the headers to the file
|
184 |
-
file.write("|".join(headers) + "\n")
|
185 |
-
for entry in data:
|
186 |
-
# Retrieve values in the order of headers, replacing any '|' characters with a space to prevent formatting errors
|
187 |
-
formatted_values = [str(entry.get(key, "")).replace("|", " ") for key in headers]
|
188 |
-
# Write the formatted values to the file
|
189 |
-
file.write("|".join(formatted_values) + "\n")
|
190 |
-
|
191 |
-
|
192 |
-
def read_metadata(filename, headers=None):
|
193 |
-
"""
|
194 |
-
Read metadata from a file.
|
195 |
-
|
196 |
-
Args:
|
197 |
-
filename (str): The file from which to read the metadata.
|
198 |
-
|
199 |
-
Returns:
|
200 |
-
list of dict: The metadata read from the file.
|
201 |
-
list of str: The headers used in the file.
|
202 |
-
"""
|
203 |
-
with open(filename, "r", encoding="utf-8") as file:
|
204 |
-
lines = file.readlines()
|
205 |
-
|
206 |
-
data = []
|
207 |
-
# Set headers from the first line of the file if not provided
|
208 |
-
if headers is None:
|
209 |
-
headers = lines[0].strip().split("|")
|
210 |
-
lines = lines[1:]
|
211 |
-
|
212 |
-
for line in lines:
|
213 |
-
line = line.strip()
|
214 |
-
# Skip empty lines
|
215 |
-
if not line:
|
216 |
-
continue
|
217 |
-
# Split the line by '|' and pair with headers to form a dictionary
|
218 |
-
entry_data = dict(zip(headers, line.split("|")))
|
219 |
-
data.append(entry_data)
|
220 |
-
|
221 |
-
return data, headers
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sparktts/utils/parse_options.sh
DELETED
@@ -1,97 +0,0 @@
|
|
1 |
-
#!/bin/bash
|
2 |
-
|
3 |
-
# Copyright 2012 Johns Hopkins University (Author: Daniel Povey);
|
4 |
-
# Arnab Ghoshal, Karel Vesely
|
5 |
-
|
6 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
-
# you may not use this file except in compliance with the License.
|
8 |
-
# You may obtain a copy of the License at
|
9 |
-
#
|
10 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
-
#
|
12 |
-
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
13 |
-
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
14 |
-
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
15 |
-
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
16 |
-
# See the Apache 2 License for the specific language governing permissions and
|
17 |
-
# limitations under the License.
|
18 |
-
|
19 |
-
|
20 |
-
# Parse command-line options.
|
21 |
-
# To be sourced by another script (as in ". parse_options.sh").
|
22 |
-
# Option format is: --option-name arg
|
23 |
-
# and shell variable "option_name" gets set to value "arg."
|
24 |
-
# The exception is --help, which takes no arguments, but prints the
|
25 |
-
# $help_message variable (if defined).
|
26 |
-
|
27 |
-
|
28 |
-
###
|
29 |
-
### The --config file options have lower priority to command line
|
30 |
-
### options, so we need to import them first...
|
31 |
-
###
|
32 |
-
|
33 |
-
# Now import all the configs specified by command-line, in left-to-right order
|
34 |
-
# for ((argpos=1; argpos<$#; argpos++)); do
|
35 |
-
# if [ "${!argpos}" == "--config" ]; then
|
36 |
-
# argpos_plus1=$((argpos+1))
|
37 |
-
# config=${!argpos_plus1}
|
38 |
-
# [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
|
39 |
-
# . $config # source the config file.
|
40 |
-
# fi
|
41 |
-
# done
|
42 |
-
|
43 |
-
|
44 |
-
###
|
45 |
-
### No we process the command line options
|
46 |
-
###
|
47 |
-
while true; do
|
48 |
-
[ -z "${1:-}" ] && break; # break if there are no arguments
|
49 |
-
case "$1" in
|
50 |
-
# If the enclosing script is called with --help option, print the help
|
51 |
-
# message and exit. Scripts should put help messages in $help_message
|
52 |
-
--help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
|
53 |
-
else printf "$help_message\n" 1>&2 ; fi;
|
54 |
-
exit 0 ;;
|
55 |
-
--*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
|
56 |
-
exit 1 ;;
|
57 |
-
# If the first command-line argument begins with "--" (e.g. --foo-bar),
|
58 |
-
# then work out the variable name as $name, which will equal "foo_bar".
|
59 |
-
--*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
|
60 |
-
# Next we test whether the variable in question is undefned-- if so it's
|
61 |
-
# an invalid option and we die. Note: $0 evaluates to the name of the
|
62 |
-
# enclosing script.
|
63 |
-
# The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
|
64 |
-
# is undefined. We then have to wrap this test inside "eval" because
|
65 |
-
# foo_bar is itself inside a variable ($name).
|
66 |
-
eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
|
67 |
-
|
68 |
-
oldval="`eval echo \\$$name`";
|
69 |
-
# Work out whether we seem to be expecting a Boolean argument.
|
70 |
-
if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
|
71 |
-
was_bool=true;
|
72 |
-
else
|
73 |
-
was_bool=false;
|
74 |
-
fi
|
75 |
-
|
76 |
-
# Set the variable to the right value-- the escaped quotes make it work if
|
77 |
-
# the option had spaces, like --cmd "queue.pl -sync y"
|
78 |
-
eval $name=\"$2\";
|
79 |
-
|
80 |
-
# Check that Boolean-valued arguments are really Boolean.
|
81 |
-
if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
|
82 |
-
echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
|
83 |
-
exit 1;
|
84 |
-
fi
|
85 |
-
shift 2;
|
86 |
-
;;
|
87 |
-
*) break;
|
88 |
-
esac
|
89 |
-
done
|
90 |
-
|
91 |
-
|
92 |
-
# Check for an empty argument to the --cmd option, which can easily occur as a
|
93 |
-
# result of scripting errors.
|
94 |
-
[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
|
95 |
-
|
96 |
-
|
97 |
-
true; # so this script returns exit code 0.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sparktts/utils/token_parser.py
DELETED
@@ -1,187 +0,0 @@
|
|
1 |
-
TASK_TOKEN_MAP = {
|
2 |
-
"vc": "<|task_vc|>",
|
3 |
-
"tts": "<|task_tts|>",
|
4 |
-
"asr": "<|task_asr|>",
|
5 |
-
"s2s": "<|task_s2s|>",
|
6 |
-
"t2s": "<|task_t2s|>",
|
7 |
-
"understand": "<|task_understand|>",
|
8 |
-
"caption": "<|task_cap|>",
|
9 |
-
"controllable_tts": "<|task_controllable_tts|>",
|
10 |
-
"prompt_tts": "<|task_prompt_tts|>",
|
11 |
-
"speech_edit": "<|task_edit|>",
|
12 |
-
}
|
13 |
-
|
14 |
-
LEVELS_MAP = {
|
15 |
-
"very_low": 0,
|
16 |
-
"low": 1,
|
17 |
-
"moderate": 2,
|
18 |
-
"high": 3,
|
19 |
-
"very_high": 4,
|
20 |
-
}
|
21 |
-
|
22 |
-
LEVELS_MAP_UI = {
|
23 |
-
1: 'very_low',
|
24 |
-
2: 'low',
|
25 |
-
3: 'moderate',
|
26 |
-
4: 'high',
|
27 |
-
5: 'very_high'
|
28 |
-
}
|
29 |
-
|
30 |
-
GENDER_MAP = {
|
31 |
-
"female": 0,
|
32 |
-
"male": 1,
|
33 |
-
}
|
34 |
-
|
35 |
-
AGE_MAP = {"Child": 0, "Teenager": 1, "Youth-Adult": 2, "Middle-aged": 3, "Elderly": 4}
|
36 |
-
|
37 |
-
EMO_MAP = {
|
38 |
-
"UNKNOWN": 0,
|
39 |
-
"NEUTRAL": 1,
|
40 |
-
"ANGRY": 2,
|
41 |
-
"HAPPY": 3,
|
42 |
-
"SAD": 4,
|
43 |
-
"FEARFUL": 5,
|
44 |
-
"DISGUSTED": 6,
|
45 |
-
"SURPRISED": 7,
|
46 |
-
"SARCASTIC": 8,
|
47 |
-
"EXCITED": 9,
|
48 |
-
"SLEEPY": 10,
|
49 |
-
"CONFUSED": 11,
|
50 |
-
"EMPHASIS": 12,
|
51 |
-
"LAUGHING": 13,
|
52 |
-
"SINGING": 14,
|
53 |
-
"WORRIED": 15,
|
54 |
-
"WHISPER": 16,
|
55 |
-
"ANXIOUS": 17,
|
56 |
-
"NO-AGREEMENT": 18,
|
57 |
-
"APOLOGETIC": 19,
|
58 |
-
"CONCERNED": 20,
|
59 |
-
"ENUNCIATED": 21,
|
60 |
-
"ASSERTIVE": 22,
|
61 |
-
"ENCOURAGING": 23,
|
62 |
-
"CONTEMPT": 24,
|
63 |
-
}
|
64 |
-
|
65 |
-
|
66 |
-
class TokenParser:
|
67 |
-
"""Turn label to special token"""
|
68 |
-
|
69 |
-
def __init__(self):
|
70 |
-
pass
|
71 |
-
|
72 |
-
"""Parse the attributes of a person."""
|
73 |
-
|
74 |
-
def __init__(self):
|
75 |
-
pass
|
76 |
-
|
77 |
-
@staticmethod
|
78 |
-
def age(age: str) -> str:
|
79 |
-
"""Turn age token."""
|
80 |
-
age_id = AGE_MAP[age]
|
81 |
-
return f"<|age_{age_id}|>"
|
82 |
-
|
83 |
-
@staticmethod
|
84 |
-
def gender(gender: str) -> str:
|
85 |
-
"""Turn gender token."""
|
86 |
-
gender_id = GENDER_MAP[gender]
|
87 |
-
return f"<|gender_{gender_id}|>"
|
88 |
-
|
89 |
-
@staticmethod
|
90 |
-
def mel_value(mel: int):
|
91 |
-
"""Turn special token of mel scale pitch."""
|
92 |
-
mel = max(0, int(mel))
|
93 |
-
mel = min(1000, int(mel))
|
94 |
-
return f"<|pitch_value_{mel}|>"
|
95 |
-
|
96 |
-
@staticmethod
|
97 |
-
def mel_level(level: str):
|
98 |
-
"""Turn special token of mel level."""
|
99 |
-
level_tag = LEVELS_MAP[level]
|
100 |
-
return f"<|pitch_label_{level_tag}|>"
|
101 |
-
|
102 |
-
@staticmethod
|
103 |
-
def pitch_var_value(pitch_std: int):
|
104 |
-
"""Turn special token of pitch_std value."""
|
105 |
-
assert isinstance(pitch_std, int)
|
106 |
-
pitch_std = max(0, int(pitch_std))
|
107 |
-
pitch_std = min(10, int(pitch_std))
|
108 |
-
return f"<|pitch_var_value_{pitch_std}|>"
|
109 |
-
|
110 |
-
@staticmethod
|
111 |
-
def pitch_var_level(level: str):
|
112 |
-
"""Turn special token of pitch std level."""
|
113 |
-
level_tag = LEVELS_MAP[level]
|
114 |
-
return f"<|pitch_var_label_{level_tag}|>"
|
115 |
-
|
116 |
-
@staticmethod
|
117 |
-
def loudness_value(loudness: int):
|
118 |
-
"""Turn special toak of loudness value [0, 30]"""
|
119 |
-
assert loudness >= 0
|
120 |
-
loudness = max(0, int(loudness))
|
121 |
-
loudness = min(30, int(loudness))
|
122 |
-
return f"<|loudness_value_{loudness}|>"
|
123 |
-
|
124 |
-
@staticmethod
|
125 |
-
def loudness_level(level: str):
|
126 |
-
"""Turn special token of loudness level."""
|
127 |
-
level_tag = LEVELS_MAP[level]
|
128 |
-
return f"<|loudness_label_{level_tag}|>"
|
129 |
-
|
130 |
-
@staticmethod
|
131 |
-
def speed_value(speed: int):
|
132 |
-
"""Turn special token of speed value."""
|
133 |
-
speed = max(0, int(speed))
|
134 |
-
speed = min(10, int(speed))
|
135 |
-
return f"<|speed_value_{speed}|>"
|
136 |
-
|
137 |
-
@staticmethod
|
138 |
-
def speed_level(level: str):
|
139 |
-
"""Turn special token of speed level."""
|
140 |
-
level_tag = LEVELS_MAP[level]
|
141 |
-
return f"<|speed_label_{level_tag}|>"
|
142 |
-
|
143 |
-
@staticmethod
|
144 |
-
def task(task: str) -> str:
|
145 |
-
"""Turn special token of task."""
|
146 |
-
assert task in TASK_TOKEN_MAP.keys()
|
147 |
-
|
148 |
-
return TASK_TOKEN_MAP[task]
|
149 |
-
|
150 |
-
@staticmethod
|
151 |
-
def emotion(emotion: str):
|
152 |
-
emo_id = EMO_MAP[emotion]
|
153 |
-
|
154 |
-
return f"<|emotion_{emo_id}|>"
|
155 |
-
|
156 |
-
|
157 |
-
# test
|
158 |
-
if __name__ == "__main__":
|
159 |
-
from transformers import AutoTokenizer
|
160 |
-
|
161 |
-
tokenizer = AutoTokenizer.from_pretrained(
|
162 |
-
"/aifs4su/xinshengwang/code/StyleCraft/tokenizer/stylecraft-bicodec-pitch-loudness-speed-emotion-tokenizer"
|
163 |
-
)
|
164 |
-
|
165 |
-
tasks = ["tts", "tts", "understand", "controllable_tts", "prompt_tts"]
|
166 |
-
ages = ["Child", "Teenager", "Youth-Adult", "Middle-aged", "Elderly"]
|
167 |
-
genders = ["female", "female", "female", "male", "male"]
|
168 |
-
mels = [100, 200, 300, 400, 500]
|
169 |
-
mel_levels = ["very_low", "low", "moderate", "high", "very_high"]
|
170 |
-
loudnesses = [1, 10, 23, 19, 30]
|
171 |
-
loudness_levels = ["very_low", "low", "moderate", "high", "very_high"]
|
172 |
-
emotions = ["UNKNOWN", "NEUTRAL", "ANGRY", "HAPPY", "SAD"]
|
173 |
-
|
174 |
-
for i in range(5):
|
175 |
-
task = TokenParser.task(tasks[i])
|
176 |
-
age = TokenParser.age(ages[i])
|
177 |
-
gender = TokenParser.gender(genders[i])
|
178 |
-
mel = TokenParser.mel_value(mels[i])
|
179 |
-
mel_level = TokenParser.mel_level(mel_levels[i])
|
180 |
-
loudness = TokenParser.loudness_value(loudnesses[i])
|
181 |
-
loudness_level = TokenParser.loudness_level(loudness_levels[i])
|
182 |
-
emotion = TokenParser.emotion(emotions[i])
|
183 |
-
inputs = [task, age, gender, mel, mel_level, loudness, loudness_level, emotion]
|
184 |
-
inputs = "".join(inputs)
|
185 |
-
ids = tokenizer.encode(inputs, add_special_tokens=False)
|
186 |
-
print(ids)
|
187 |
-
print("decode", tokenizer.decode(ids))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
webui.py
CHANGED
@@ -265,5 +265,6 @@ if __name__ == "__main__":
|
|
265 |
# Launch Gradio with the specified server name and port
|
266 |
demo.launch(
|
267 |
server_name=args.server_name,
|
268 |
-
server_port=args.server_port
|
269 |
-
|
|
|
|
265 |
# Launch Gradio with the specified server name and port
|
266 |
demo.launch(
|
267 |
server_name=args.server_name,
|
268 |
+
server_port=args.server_port,
|
269 |
+
share=True
|
270 |
+
)
|