Spaces: Running on T4
update to the current version
Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
- Architectures/ToucanTTS/StochasticToucanTTS/README.md +0 -1
- Architectures/ToucanTTS/StochasticToucanTTS/StochasticToucanTTS.py +0 -493
- Architectures/ToucanTTS/StochasticToucanTTS/StochasticVariancePredictor.py +0 -440
- Architectures/__init__.py +0 -0
- InferenceInterfaces/ControllableInterface.py +25 -18
- InferenceInterfaces/ToucanTTSInterface.py +73 -72
- InferenceInterfaces/UtteranceCloner.py +8 -6
- InferenceInterfaces/audioseal_wm_16bits.yaml +0 -39
- {Architectures → Modules}/Aligner/Aligner.py +27 -31
- {Architectures → Modules}/Aligner/CodecAlignerDataset.py +57 -14
- {Architectures → Modules}/Aligner/README.md +0 -0
- {Architectures → Modules}/Aligner/Reconstructor.py +8 -15
- {Architectures → Modules}/Aligner/__init__.py +0 -0
- {Architectures → Modules}/Aligner/autoaligner_train_loop.py +4 -2
- {Architectures → Modules}/ControllabilityGAN/GAN.py +23 -10
- {Architectures → Modules}/ControllabilityGAN/__init__.py +0 -0
- {Architectures → Modules}/ControllabilityGAN/dataset/__init__.py +0 -0
- {Architectures → Modules}/ControllabilityGAN/dataset/speaker_embeddings_dataset.py +0 -0
- {Architectures → Modules}/ControllabilityGAN/wgan/__init__.py +0 -0
- {Architectures → Modules}/ControllabilityGAN/wgan/init_weights.py +0 -0
- {Architectures → Modules}/ControllabilityGAN/wgan/init_wgan.py +2 -2
- {Architectures → Modules}/ControllabilityGAN/wgan/resnet_1.py +2 -2
- {Architectures → Modules}/ControllabilityGAN/wgan/resnet_init.py +4 -4
- {Architectures → Modules}/ControllabilityGAN/wgan/wgan_qc.py +6 -11
- {Architectures → Modules}/EmbeddingModel/GST.py +1 -1
- {Architectures → Modules}/EmbeddingModel/README.md +0 -0
- {Architectures → Modules}/EmbeddingModel/StyleEmbedding.py +2 -2
- {Architectures → Modules}/EmbeddingModel/StyleTTSEncoder.py +0 -0
- {Architectures → Modules}/EmbeddingModel/__init__.py +0 -0
- {Architectures → Modules}/GeneralLayers/Attention.py +0 -0
- {Architectures → Modules}/GeneralLayers/ConditionalLayerNorm.py +0 -0
- {Architectures → Modules}/GeneralLayers/Conformer.py +29 -18
- {Architectures → Modules}/GeneralLayers/Convolution.py +1 -1
- {Architectures → Modules}/GeneralLayers/DurationPredictor.py +3 -3
- {Architectures → Modules}/GeneralLayers/EncoderLayer.py +1 -1
- {Architectures → Modules}/GeneralLayers/LayerNorm.py +0 -0
- {Architectures → Modules}/GeneralLayers/LengthRegulator.py +0 -0
- {Architectures → Modules}/GeneralLayers/MultiLayeredConv1d.py +0 -0
- {Architectures → Modules}/GeneralLayers/MultiSequential.py +0 -0
- {Architectures → Modules}/GeneralLayers/PositionalEncoding.py +0 -0
- {Architectures → Modules}/GeneralLayers/PositionwiseFeedForward.py +0 -0
- {Architectures → Modules}/GeneralLayers/README.md +0 -0
- {Architectures → Modules}/GeneralLayers/ResidualBlock.py +0 -0
- {Architectures → Modules}/GeneralLayers/ResidualStack.py +0 -0
- {Architectures → Modules}/GeneralLayers/STFT.py +0 -0
- {Architectures → Modules}/GeneralLayers/Swish.py +0 -0
- {Architectures → Modules}/GeneralLayers/VariancePredictor.py +3 -3
- {Architectures → Modules}/GeneralLayers/__init__.py +0 -0
- {Architectures → Modules}/README.md +0 -0
- {Architectures → Modules}/ToucanTTS/CodecDiscriminator.py +0 -0
Architectures/ToucanTTS/StochasticToucanTTS/README.md
DELETED
@@ -1 +0,0 @@
This is an experimental version of the TTS that uses normalizing flows to predict the prosody explicitly, so that we can still have the controllability of the explicit prosody predictors, however a much better naturalness and livelyness than what we get from a deterministic predictor.
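The deleted README above describes the core idea of this experimental branch: the variance predictors are normalizing flows, so prosody is sampled rather than regressed, while the per-phone values stay exposed for control. As a rough illustration of what that means in use (not part of this commit; `model`, `phones`, `spk_emb` and `lang_id` are placeholder names for a trained StochasticToucanTTS instance, a prepared articulatory phone tensor, a speaker embedding and a language ID tensor), two inference calls draw two different prosody samples:

import torch

# Each call samples fresh duration, pitch and energy values from the flows, so
# repeated synthesis of the same sentence varies naturally instead of collapsing
# to one average prosody contour, which is the point the README makes.
with torch.inference_mode():
    for _ in range(2):
        before, after, durations, pitch, energy = model.inference(phones,
                                                                  utterance_embedding=spk_emb,
                                                                  lang_id=lang_id,
                                                                  return_duration_pitch_energy=True)
        # the sampled per-phone values are returned explicitly, so they can still
        # be inspected, plotted, or rescaled for controllable prosody
        print(durations.sum().item(), pitch.mean().item(), energy.mean().item())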
Architectures/ToucanTTS/StochasticToucanTTS/StochasticToucanTTS.py
DELETED
@@ -1,493 +0,0 @@
import torch
from torch.nn import Linear
from torch.nn import Sequential
from torch.nn import Tanh

from Architectures.GeneralLayers.Conformer import Conformer
from Architectures.GeneralLayers.LengthRegulator import LengthRegulator
from Architectures.ToucanTTS.Glow import Glow
from Architectures.ToucanTTS.StochasticToucanTTS.StochasticToucanTTSLoss import StochasticToucanTTSLoss
from Architectures.ToucanTTS.StochasticToucanTTS.StochasticVariancePredictor import StochasticVariancePredictor
from Preprocessing.articulatory_features import get_feature_to_index_lookup
from Utility.utils import initialize
from Utility.utils import make_non_pad_mask
from Utility.utils import make_pad_mask


class StochasticToucanTTS(torch.nn.Module):
    """
    StochasticToucanTTS module, which is mostly just a FastSpeech 2 module,
    but with lots of designs from different architectures accumulated
    and some major components added to put a large focus on multilinguality.

    Original contributions:
    - Inputs are configurations of the articulatory tract
    - Word boundaries are modeled explicitly in the encoder end removed before the decoder
    - Speaker embedding conditioning is derived from GST and Adaspeech 4
    - Responsiveness of variance predictors to utterance embedding is increased through conditional layer norm
    - The final output receives a GAN discriminator feedback signal
    - Stochastic Duration Prediction through a normalizing flow
    - Stochastic Pitch Prediction through a normalizing flow
    - Stochastic Energy prediction through a normalizing flow

    Contributions inspired from elsewhere:
    - The PostNet is also a normalizing flow, like in PortaSpeech
    - Pitch and energy values are averaged per-phone, as in FastPitch to enable great controllability
    - The encoder and decoder are Conformers

    """

    def __init__(self,
                 # network structure related
                 input_feature_dimensions=62,
                 output_spectrogram_channels=80,
                 attention_dimension=192,
                 attention_heads=4,
                 positionwise_conv_kernel_size=1,
                 use_scaled_positional_encoding=True,
                 init_type="xavier_uniform",
                 use_macaron_style_in_conformer=True,
                 use_cnn_in_conformer=True,

                 # encoder
                 encoder_layers=6,
                 encoder_units=1536,
                 encoder_normalize_before=True,
                 encoder_concat_after=False,
                 conformer_encoder_kernel_size=7,
                 transformer_enc_dropout_rate=0.2,
                 transformer_enc_positional_dropout_rate=0.2,
                 transformer_enc_attn_dropout_rate=0.2,

                 # decoder
                 decoder_layers=6,
                 decoder_units=1536,
                 decoder_concat_after=False,
                 conformer_decoder_kernel_size=31,
                 decoder_normalize_before=True,
                 transformer_dec_dropout_rate=0.2,
                 transformer_dec_positional_dropout_rate=0.2,
                 transformer_dec_attn_dropout_rate=0.2,

                 # duration predictor
                 duration_predictor_layers=3,
                 duration_predictor_chans=256,
                 duration_predictor_kernel_size=3,
                 duration_predictor_dropout_rate=0.2,

                 # pitch predictor
                 pitch_embed_kernel_size=1,
                 pitch_embed_dropout=0.0,

                 # energy predictor
                 energy_embed_kernel_size=1,
                 energy_embed_dropout=0.0,

                 # additional features
                 utt_embed_dim=192,
                 lang_embs=8000):
        super().__init__()

        self.input_feature_dimensions = input_feature_dimensions
        self.output_spectrogram_channels = output_spectrogram_channels
        self.attention_dimension = attention_dimension
        self.use_scaled_pos_enc = use_scaled_positional_encoding
        self.multilingual_model = lang_embs is not None
        self.multispeaker_model = utt_embed_dim is not None

        articulatory_feature_embedding = Sequential(Linear(input_feature_dimensions, 100), Tanh(), Linear(100, attention_dimension))
        self.encoder = Conformer(conformer_type="encoder",
                                 attention_dim=attention_dimension,
                                 attention_heads=attention_heads,
                                 linear_units=encoder_units,
                                 num_blocks=encoder_layers,
                                 input_layer=articulatory_feature_embedding,
                                 dropout_rate=transformer_enc_dropout_rate,
                                 positional_dropout_rate=transformer_enc_positional_dropout_rate,
                                 attention_dropout_rate=transformer_enc_attn_dropout_rate,
                                 normalize_before=encoder_normalize_before,
                                 concat_after=encoder_concat_after,
                                 positionwise_conv_kernel_size=positionwise_conv_kernel_size,
                                 macaron_style=use_macaron_style_in_conformer,
                                 use_cnn_module=use_cnn_in_conformer,
                                 cnn_module_kernel=conformer_encoder_kernel_size,
                                 zero_triu=False,
                                 utt_embed=utt_embed_dim,
                                 lang_embs=lang_embs,
                                 use_output_norm=True)

        self.duration_flow = StochasticVariancePredictor(in_channels=attention_dimension,
                                                         kernel_size=3,
                                                         p_dropout=0.5,
                                                         n_flows=5,
                                                         conditioning_signal_channels=utt_embed_dim)

        self.pitch_flow = StochasticVariancePredictor(in_channels=attention_dimension,
                                                      kernel_size=5,
                                                      p_dropout=0.5,
                                                      n_flows=6,
                                                      conditioning_signal_channels=utt_embed_dim)

        self.energy_flow = StochasticVariancePredictor(in_channels=attention_dimension,
                                                       kernel_size=3,
                                                       p_dropout=0.5,
                                                       n_flows=3,
                                                       conditioning_signal_channels=utt_embed_dim)

        self.pitch_embed = Sequential(torch.nn.Conv1d(in_channels=1,
                                                      out_channels=attention_dimension,
                                                      kernel_size=pitch_embed_kernel_size,
                                                      padding=(pitch_embed_kernel_size - 1) // 2),
                                      torch.nn.Dropout(pitch_embed_dropout))

        self.energy_embed = Sequential(torch.nn.Conv1d(in_channels=1, out_channels=attention_dimension, kernel_size=energy_embed_kernel_size,
                                                       padding=(energy_embed_kernel_size - 1) // 2),
                                       torch.nn.Dropout(energy_embed_dropout))

        self.length_regulator = LengthRegulator()

        self.decoder = Conformer(conformer_type="decoder",
                                 attention_dim=attention_dimension,
                                 attention_heads=attention_heads,
                                 linear_units=decoder_units,
                                 num_blocks=decoder_layers,
                                 input_layer=None,
                                 dropout_rate=transformer_dec_dropout_rate,
                                 positional_dropout_rate=transformer_dec_positional_dropout_rate,
                                 attention_dropout_rate=transformer_dec_attn_dropout_rate,
                                 normalize_before=decoder_normalize_before,
                                 concat_after=decoder_concat_after,
                                 positionwise_conv_kernel_size=positionwise_conv_kernel_size,
                                 macaron_style=use_macaron_style_in_conformer,
                                 use_cnn_module=use_cnn_in_conformer,
                                 cnn_module_kernel=conformer_decoder_kernel_size,
                                 use_output_norm=False,
                                 utt_embed=utt_embed_dim)

        self.feat_out = Linear(attention_dimension, output_spectrogram_channels)

        self.post_flow = Glow(
            in_channels=output_spectrogram_channels,
            hidden_channels=192,  # post_glow_hidden
            kernel_size=3,  # post_glow_kernel_size
            dilation_rate=1,
            n_blocks=12,  # post_glow_n_blocks (original 12 in paper)
            n_layers=3,  # post_glow_n_block_layers (original 3 in paper)
            n_split=4,
            n_sqz=2,
            text_condition_channels=attention_dimension,
            share_cond_layers=False,  # post_share_cond_layers
            share_wn_layers=4,
            sigmoid_scale=False,
            condition_integration_projection=torch.nn.Conv1d(output_spectrogram_channels + attention_dimension, attention_dimension, 5, padding=2)
        )

        # initialize parameters
        self._reset_parameters(init_type=init_type)
        if lang_embs is not None:
            torch.nn.init.normal_(self.encoder.language_embedding.weight, mean=0, std=attention_dimension ** -0.5)

        self.criterion = StochasticToucanTTSLoss()

    def forward(self,
                text_tensors,
                text_lengths,
                gold_speech,
                speech_lengths,
                gold_durations,
                gold_pitch,
                gold_energy,
                utterance_embedding,
                return_feats=False,
                lang_ids=None,
                run_glow=True
                ):
        """
        Args:
            return_feats (Boolean): whether to return the predicted spectrogram
            text_tensors (LongTensor): Batch of padded text vectors (B, Tmax).
            text_lengths (LongTensor): Batch of lengths of each input (B,).
            gold_speech (Tensor): Batch of padded target features (B, Lmax, odim).
            speech_lengths (LongTensor): Batch of the lengths of each target (B,).
            gold_durations (LongTensor): Batch of padded durations (B, Tmax + 1).
            gold_pitch (Tensor): Batch of padded token-averaged pitch (B, Tmax + 1, 1).
            gold_energy (Tensor): Batch of padded token-averaged energy (B, Tmax + 1, 1).
            run_glow (Boolean): Whether to run the PostNet. There should be a warmup phase in the beginning.
            lang_ids (LongTensor): The language IDs used to access the language embedding table, if the model is multilingual
            utterance_embedding (Tensor): Batch of embeddings to condition the TTS on, if the model is multispeaker
        """
        before_outs, \
        after_outs, \
        duration_loss, \
        pitch_loss, \
        energy_loss, \
        glow_loss = self._forward(text_tensors=text_tensors,
                                  text_lengths=text_lengths,
                                  gold_speech=gold_speech,
                                  speech_lengths=speech_lengths,
                                  gold_durations=gold_durations,
                                  gold_pitch=gold_pitch,
                                  gold_energy=gold_energy,
                                  utterance_embedding=utterance_embedding,
                                  is_inference=False,
                                  lang_ids=lang_ids,
                                  run_glow=run_glow)

        # calculate loss
        l1_loss = self.criterion(after_outs=after_outs,
                                 before_outs=before_outs,
                                 gold_spectrograms=gold_speech,
                                 spectrogram_lengths=speech_lengths,
                                 text_lengths=text_lengths)

        if return_feats:
            if after_outs is None:
                after_outs = before_outs
            return l1_loss, duration_loss, pitch_loss, energy_loss, glow_loss, after_outs
        return l1_loss, duration_loss, pitch_loss, energy_loss, glow_loss

    def _forward(self,
                 text_tensors,
                 text_lengths,
                 gold_speech=None,
                 speech_lengths=None,
                 gold_durations=None,
                 gold_pitch=None,
                 gold_energy=None,
                 is_inference=False,
                 utterance_embedding=None,
                 lang_ids=None,
                 run_glow=True):

        if not self.multilingual_model:
            lang_ids = None

        if not self.multispeaker_model:
            utterance_embedding = None

        # encoding the texts
        text_masks = make_non_pad_mask(text_lengths, device=text_lengths.device).unsqueeze(-2)
        padding_masks = make_pad_mask(text_lengths, device=text_lengths.device)
        encoded_texts, _ = self.encoder(text_tensors, text_masks, utterance_embedding=utterance_embedding, lang_ids=lang_ids)

        if is_inference:
            variance_mask = torch.ones(size=[text_tensors.size(1)], device=text_tensors.device)

            # predicting pitch
            pitch_predictions = self.pitch_flow(encoded_texts.transpose(1, 2), variance_mask, w=None, g=utterance_embedding.unsqueeze(-1), reverse=True).squeeze(-1).transpose(1, 2)
            for phoneme_index, phoneme_vector in enumerate(text_tensors.squeeze(0)):
                if phoneme_vector[get_feature_to_index_lookup()["voiced"]] == 0:
                    pitch_predictions[0][phoneme_index] = 0.0
            embedded_pitch_curve = self.pitch_embed(pitch_predictions.transpose(1, 2)).transpose(1, 2)
            encoded_texts = encoded_texts + embedded_pitch_curve

            # predicting energy
            energy_predictions = self.energy_flow(encoded_texts.transpose(1, 2), variance_mask, w=None, g=utterance_embedding.unsqueeze(-1), reverse=True).squeeze(-1).transpose(1, 2)
            embedded_energy_curve = self.energy_embed(energy_predictions.transpose(1, 2)).transpose(1, 2)
            encoded_texts = encoded_texts + embedded_energy_curve

            # predicting durations
            predicted_durations = self.duration_flow(encoded_texts.transpose(1, 2), variance_mask, w=None, g=utterance_embedding.unsqueeze(-1), reverse=True).squeeze(-1).transpose(1, 2).squeeze(-1)
            predicted_durations = torch.ceil(torch.exp(predicted_durations)).long()
            for phoneme_index, phoneme_vector in enumerate(text_tensors.squeeze(0)):
                if phoneme_vector[get_feature_to_index_lookup()["word-boundary"]] == 1:
                    predicted_durations[0][phoneme_index] = 0

            # predicting durations for text and upsampling accordingly
            upsampled_enriched_encoded_texts = self.length_regulator(encoded_texts, predicted_durations)

        else:
            # learning to predict pitch
            idx = gold_pitch != 0
            pitch_mask = torch.logical_and(text_masks, idx.transpose(1, 2))
            scaled_pitch_targets = gold_pitch.detach().clone()
            scaled_pitch_targets[idx] = torch.exp(gold_pitch[idx])  # we scale up, so that the log in the flow can handle the value ranges better.
            pitch_flow_loss = torch.sum(self.pitch_flow(encoded_texts.transpose(1, 2).detach(), pitch_mask, w=scaled_pitch_targets.transpose(1, 2), g=utterance_embedding.unsqueeze(-1), reverse=False))
            pitch_flow_loss = torch.sum(pitch_flow_loss / torch.sum(pitch_mask))  # weighted masking
            embedded_pitch_curve = self.pitch_embed(gold_pitch.transpose(1, 2)).transpose(1, 2)
            encoded_texts = encoded_texts + embedded_pitch_curve

            # learning to predict energy
            idx = gold_energy != 0
            energy_mask = torch.logical_and(text_masks, idx.transpose(1, 2))
            scaled_energy_targets = gold_energy.detach().clone()
            scaled_energy_targets[idx] = torch.exp(gold_energy[idx])  # we scale up, so that the log in the flow can handle the value ranges better.
            energy_flow_loss = torch.sum(self.energy_flow(encoded_texts.transpose(1, 2).detach(), energy_mask, w=scaled_energy_targets.transpose(1, 2), g=utterance_embedding.unsqueeze(-1), reverse=False))
            energy_flow_loss = torch.sum(energy_flow_loss / torch.sum(energy_mask))  # weighted masking
            embedded_energy_curve = self.energy_embed(gold_energy.transpose(1, 2)).transpose(1, 2)
            encoded_texts = encoded_texts + embedded_energy_curve

            # learning to predict durations
            idx = gold_durations.unsqueeze(-1) != 0
            duration_mask = torch.logical_and(text_masks, idx.transpose(1, 2))
            duration_targets = gold_durations.unsqueeze(-1).detach().clone().float()
            duration_flow_loss = torch.sum(self.duration_flow(encoded_texts.transpose(1, 2).detach(), duration_mask, w=duration_targets.transpose(1, 2), g=utterance_embedding.unsqueeze(-1), reverse=False))
            duration_flow_loss = torch.sum(duration_flow_loss / torch.sum(duration_mask))  # weighted masking

            upsampled_enriched_encoded_texts = self.length_regulator(encoded_texts, gold_durations)

        # decoding spectrogram
        decoder_masks = make_non_pad_mask(speech_lengths, device=speech_lengths.device).unsqueeze(-2) if speech_lengths is not None and not is_inference else None
        decoded_speech, _ = self.decoder(upsampled_enriched_encoded_texts, decoder_masks, utterance_embedding=utterance_embedding)
        decoded_spectrogram = self.feat_out(decoded_speech).view(decoded_speech.size(0), -1, self.output_spectrogram_channels)

        # refine spectrogram further with a normalizing flow (requires warmup, so it's not always on)
        glow_loss = None
        if run_glow:
            if is_inference:
                refined_spectrogram = self.post_flow(tgt_mels=None,
                                                     infer=is_inference,
                                                     mel_out=decoded_spectrogram,
                                                     encoded_texts=upsampled_enriched_encoded_texts,
                                                     tgt_nonpadding=None).squeeze()
            else:
                glow_loss = self.post_flow(tgt_mels=gold_speech,
                                           infer=is_inference,
                                           mel_out=decoded_spectrogram.detach().clone(),
                                           encoded_texts=upsampled_enriched_encoded_texts.detach().clone(),
                                           tgt_nonpadding=decoder_masks)
        if is_inference:
            return decoded_spectrogram.squeeze(), \
                   refined_spectrogram.squeeze(), \
                   predicted_durations.squeeze(), \
                   pitch_predictions.squeeze(), \
                   energy_predictions.squeeze()
        else:
            return decoded_spectrogram, \
                   None, \
                   duration_flow_loss, \
                   pitch_flow_loss, \
                   energy_flow_loss, \
                   glow_loss

    @torch.inference_mode()
    def inference(self,
                  text,
                  speech=None,
                  utterance_embedding=None,
                  return_duration_pitch_energy=False,
                  lang_id=None,
                  run_postflow=True):
        """
        Args:
            text (LongTensor): Input sequence of characters (T,).
            speech (Tensor, optional): Feature sequence to extract style (N, idim).
            return_duration_pitch_energy (Boolean): whether to return the list of predicted durations for nicer plotting
            run_postflow (Boolean): Whether to run the PostNet. There should be a warmup phase in the beginning.
            lang_id (LongTensor): The language ID used to access the language embedding table, if the model is multilingual
            utterance_embedding (Tensor): Embedding to condition the TTS on, if the model is multispeaker
        """
        self.eval()
        x, y = text, speech

        # setup batch axis
        ilens = torch.tensor([x.shape[0]], dtype=torch.long, device=x.device)
        xs, ys = x.unsqueeze(0), None
        if y is not None:
            ys = y.unsqueeze(0)
        if lang_id is not None:
            lang_id = lang_id.unsqueeze(0)
        utterance_embeddings = utterance_embedding.unsqueeze(0) if utterance_embedding is not None else None

        before_outs, \
        after_outs, \
        duration_predictions, \
        pitch_predictions, \
        energy_predictions = self._forward(xs,
                                           ilens,
                                           ys,
                                           is_inference=True,
                                           utterance_embedding=utterance_embeddings,
                                           lang_ids=lang_id,
                                           run_glow=run_postflow)  # (1, L, odim)
        self.train()
        if after_outs is None:
            after_outs = before_outs
        if return_duration_pitch_energy:
            return before_outs, after_outs, duration_predictions, pitch_predictions, energy_predictions
        return after_outs

    def _reset_parameters(self, init_type):
        # initialize parameters
        if init_type != "pytorch":
            initialize(self, init_type)


if __name__ == '__main__':
    print(sum(p.numel() for p in StochasticToucanTTS().parameters() if p.requires_grad))

    print(" TESTING TRAINING ")

    print(" batchsize 3 ")
    dummy_text_batch = torch.randint(low=0, high=2, size=[3, 3, 62]).float()  # [Batch, Sequence Length, Features per Phone]
    dummy_text_lens = torch.LongTensor([2, 3, 3])

    dummy_speech_batch = torch.randn([3, 30, 80])  # [Batch, Sequence Length, Spectrogram Buckets]
    dummy_speech_lens = torch.LongTensor([10, 30, 20])

    dummy_durations = torch.LongTensor([[10, 0, 0], [10, 15, 5], [5, 5, 10]])
    dummy_pitch = torch.Tensor([[[1.0], [0.], [0.]], [[1.1], [1.2], [0.8]], [[1.1], [1.2], [0.8]]])
    dummy_energy = torch.Tensor([[[1.0], [1.3], [0.]], [[1.1], [1.4], [0.8]], [[1.1], [1.2], [0.8]]])

    dummy_utterance_embed = torch.randn([3, 192])  # [Batch, Dimensions of Speaker Embedding]
    dummy_language_id = torch.LongTensor([5, 3, 2]).unsqueeze(1)

    model = StochasticToucanTTS()
    l1, dl, pl, el, gl = model(dummy_text_batch,
                               dummy_text_lens,
                               dummy_speech_batch,
                               dummy_speech_lens,
                               dummy_durations,
                               dummy_pitch,
                               dummy_energy,
                               utterance_embedding=dummy_utterance_embed,
                               lang_ids=dummy_language_id)

    loss = l1 + gl + dl + pl + el
    print(loss)
    loss.backward()

    # from Utility.utils import plot_grad_flow

    # plot_grad_flow(model.encoder.named_parameters())
    # plot_grad_flow(model.decoder.named_parameters())
    # plot_grad_flow(model.pitch_predictor.named_parameters())
    # plot_grad_flow(model.duration_predictor.named_parameters())
    # plot_grad_flow(model.post_flow.named_parameters())

    print(" batchsize 2 ")
    dummy_text_batch = torch.randint(low=0, high=2, size=[2, 3, 62]).float()  # [Batch, Sequence Length, Features per Phone]
    dummy_text_lens = torch.LongTensor([2, 3])

    dummy_speech_batch = torch.randn([2, 30, 80])  # [Batch, Sequence Length, Spectrogram Buckets]
    dummy_speech_lens = torch.LongTensor([10, 30])

    dummy_durations = torch.LongTensor([[10, 0, 0], [10, 15, 5]])
    dummy_pitch = torch.Tensor([[[1.0], [0.], [0.]], [[1.1], [1.2], [0.8]]])
    dummy_energy = torch.Tensor([[[1.0], [1.3], [0.]], [[1.1], [1.4], [0.8]]])

    dummy_utterance_embed = torch.randn([2, 192])  # [Batch, Dimensions of Speaker Embedding]
    dummy_language_id = torch.LongTensor([5, 3]).unsqueeze(1)

    model = StochasticToucanTTS()
    l1, dl, pl, el, gl = model(dummy_text_batch,
                               dummy_text_lens,
                               dummy_speech_batch,
                               dummy_speech_lens,
                               dummy_durations,
                               dummy_pitch,
                               dummy_energy,
                               utterance_embedding=dummy_utterance_embed,
                               lang_ids=dummy_language_id)

    loss = l1 + gl + dl + el + pl
    print(loss)
    loss.backward()

    print(" TESTING INFERENCE ")
    dummy_text_batch = torch.randint(low=0, high=2, size=[12, 62]).float()  # [Sequence Length, Features per Phone]
    dummy_utterance_embed = torch.randn([192])  # [Dimensions of Speaker Embedding]
    dummy_language_id = torch.LongTensor([2])
    print(StochasticToucanTTS().inference(dummy_text_batch,
                                          utterance_embedding=dummy_utterance_embed,
                                          lang_id=dummy_language_id).shape)
Architectures/ToucanTTS/StochasticToucanTTS/StochasticVariancePredictor.py
DELETED
@@ -1,440 +0,0 @@
"""
Code taken and adapted from https://github.com/jaywalnut310/vits

MIT License

Copyright (c) 2021 Jaehyeon Kim

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

import math

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

DEFAULT_MIN_BIN_WIDTH = 1e-3
DEFAULT_MIN_BIN_HEIGHT = 1e-3
DEFAULT_MIN_DERIVATIVE = 1e-3


class StochasticVariancePredictor(nn.Module):
    def __init__(self, in_channels, kernel_size, p_dropout, n_flows=4, conditioning_signal_channels=0):
        super().__init__()
        self.in_channels = in_channels
        self.filter_channels = in_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.n_flows = n_flows
        self.gin_channels = conditioning_signal_channels if conditioning_signal_channels is not None else 0

        self.log_flow = Log()
        self.flows = nn.ModuleList()
        self.flows.append(ElementwiseAffine(2))
        for i in range(n_flows):
            self.flows.append(ConvFlow(2, in_channels, kernel_size, n_layers=3))
            self.flows.append(Flip())

        self.post_pre = nn.Conv1d(1, in_channels, 1)
        self.post_proj = nn.Conv1d(in_channels, in_channels, 1)
        self.post_convs = DDSConv(in_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
        self.post_flows = nn.ModuleList()
        self.post_flows.append(ElementwiseAffine(2))
        for i in range(4):
            self.post_flows.append(ConvFlow(2, in_channels, kernel_size, n_layers=3))
            self.post_flows.append(Flip())

        self.pre = nn.Conv1d(in_channels, in_channels, 1)
        self.proj = nn.Conv1d(in_channels, in_channels, 1)
        self.convs = DDSConv(in_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
        if self.gin_channels != 0:
            self.cond = nn.Conv1d(self.gin_channels, in_channels, 1)

    def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=0.3):
        x = self.pre(x)
        if g is not None:
            g = torch.detach(g)
            x = x + self.cond(g)
        x = self.convs(x, x_mask)
        x = self.proj(x) * x_mask

        if not reverse:
            flows = self.flows
            assert w is not None

            logdet_tot_q = 0
            h_w = self.post_pre(w)
            h_w = self.post_convs(h_w, x_mask)
            h_w = self.post_proj(h_w) * x_mask
            e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
            z_q = e_q
            for flow in self.post_flows:
                z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
                logdet_tot_q += logdet_q
            z_u, z1 = torch.split(z_q, [1, 1], 1)
            u = torch.sigmoid(z_u) * x_mask
            z0 = (w - u) * x_mask
            logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
            logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q ** 2)) * x_mask, [1, 2]) - logdet_tot_q

            logdet_tot = 0
            z0, logdet = self.log_flow(z0, x_mask)
            logdet_tot += logdet
            z = torch.cat([z0, z1], 1)
            for flow in flows:
                z, logdet = flow(z, x_mask, g=x, reverse=reverse)
                logdet_tot = logdet_tot + logdet
            nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z ** 2)) * x_mask, [1, 2]) - logdet_tot
            return nll + logq  # [b]
        else:
            flows = list(reversed(self.flows))
            flows = flows[:-2] + [flows[-1]]  # remove a useless vflow
            z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
            # noise scale 0.8 derived from coqui implementation, but dropped to 0.3 during testing. Might not be ideal yet.
            for flow in flows:
                z = flow(z, x_mask, g=x, reverse=reverse)
            z0, z1 = torch.split(z, [1, 1], 1)
            logw = z0
            return logw


class Log(nn.Module):
    def forward(self, x, x_mask, reverse=False, **kwargs):
        if not reverse:
            y = torch.log(torch.clamp_min(x, 1e-6)) * x_mask
            logdet = torch.sum(-y, [1, 2])
            return y, logdet
        else:
            x = torch.exp(x) * x_mask
            return x


class Flip(nn.Module):
    def forward(self, x, *args, reverse=False, **kwargs):
        x = torch.flip(x, [1])
        if not reverse:
            logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
            return x, logdet
        else:
            return x


class DDSConv(nn.Module):
    """
    Dialted and Depth-Separable Convolution
    """

    def __init__(self, channels, kernel_size, n_layers, p_dropout=0.):
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout

        self.drop = nn.Dropout(p_dropout)
        self.convs_sep = nn.ModuleList()
        self.convs_1x1 = nn.ModuleList()
        self.norms_1 = nn.ModuleList()
        self.norms_2 = nn.ModuleList()
        for i in range(n_layers):
            dilation = kernel_size ** i
            padding = (kernel_size * dilation - dilation) // 2
            self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size,
                                            groups=channels, dilation=dilation, padding=padding
                                            ))
            self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
            self.norms_1.append(LayerNorm(channels))
            self.norms_2.append(LayerNorm(channels))

    def forward(self, x, x_mask, g=None):
        if g is not None:
            x = x + g
        for i in range(self.n_layers):
            y = self.convs_sep[i](x * x_mask)
            y = self.norms_1[i](y)
            y = F.gelu(y)
            y = self.convs_1x1[i](y)
            y = self.norms_2[i](y)
            y = F.gelu(y)
            y = self.drop(y)
            x = x + y
        return x * x_mask


class ConvFlow(nn.Module):
    def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0):
        super().__init__()
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.num_bins = num_bins
        self.tail_bound = tail_bound
        self.half_channels = in_channels // 2

        self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
        self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.)
        self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
        h = self.pre(x0)
        h = self.convs(h, x_mask, g=g)
        h = self.proj(h) * x_mask

        b, c, t = x0.shape
        h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2)  # [b, cx?, t] -> [b, c, t, ?]

        unnormalized_widths = h[..., :self.num_bins] / math.sqrt(self.filter_channels)
        unnormalized_heights = h[..., self.num_bins:2 * self.num_bins] / math.sqrt(self.filter_channels)
        unnormalized_derivatives = h[..., 2 * self.num_bins:]

        x1, logabsdet = piecewise_rational_quadratic_transform(x1,
                                                               unnormalized_widths,
                                                               unnormalized_heights,
                                                               unnormalized_derivatives,
                                                               inverse=reverse,
                                                               tails='linear',
                                                               tail_bound=self.tail_bound
                                                               )

        x = torch.cat([x0, x1], 1) * x_mask
        logdet = torch.sum(logabsdet * x_mask, [1, 2])
        if not reverse:
            return x, logdet
        else:
            return x


class ElementwiseAffine(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.m = nn.Parameter(torch.zeros(channels, 1))
        self.logs = nn.Parameter(torch.zeros(channels, 1))

    def forward(self, x, x_mask, reverse=False, **kwargs):
        if not reverse:
            y = self.m + torch.exp(self.logs) * x
            y = y * x_mask
            logdet = torch.sum(self.logs * x_mask, [1, 2])
            return y, logdet
        else:
            x = (x - self.m) * torch.exp(-self.logs) * x_mask
            return x


class LayerNorm(nn.Module):
    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        x = x.transpose(1, -1)
        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
        return x.transpose(1, -1)


def piecewise_rational_quadratic_transform(inputs,
                                           unnormalized_widths,
                                           unnormalized_heights,
                                           unnormalized_derivatives,
                                           inverse=False,
                                           tails=None,
                                           tail_bound=1.,
                                           min_bin_width=DEFAULT_MIN_BIN_WIDTH,
                                           min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
                                           min_derivative=DEFAULT_MIN_DERIVATIVE):
    if tails is None:
        spline_fn = rational_quadratic_spline
        spline_kwargs = {}
    else:
        spline_fn = unconstrained_rational_quadratic_spline
        spline_kwargs = {
            'tails'     : tails,
            'tail_bound': tail_bound
        }

    outputs, logabsdet = spline_fn(
        inputs=inputs,
        unnormalized_widths=unnormalized_widths,
        unnormalized_heights=unnormalized_heights,
        unnormalized_derivatives=unnormalized_derivatives,
        inverse=inverse,
        min_bin_width=min_bin_width,
        min_bin_height=min_bin_height,
        min_derivative=min_derivative,
        **spline_kwargs
    )
    return outputs, logabsdet


def rational_quadratic_spline(inputs,
                              unnormalized_widths,
                              unnormalized_heights,
                              unnormalized_derivatives,
                              inverse=False,
                              left=0., right=1., bottom=0., top=1.,
                              min_bin_width=DEFAULT_MIN_BIN_WIDTH,
                              min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
                              min_derivative=DEFAULT_MIN_DERIVATIVE):
    if torch.min(inputs) < left or torch.max(inputs) > right:
        raise ValueError('Input to a transform is not within its domain')

    num_bins = unnormalized_widths.shape[-1]

    if min_bin_width * num_bins > 1.0:
        raise ValueError('Minimal bin width too large for the number of bins')
    if min_bin_height * num_bins > 1.0:
        raise ValueError('Minimal bin height too large for the number of bins')

    widths = F.softmax(unnormalized_widths, dim=-1)
    widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
    cumwidths = torch.cumsum(widths, dim=-1)
    cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0)
    cumwidths = (right - left) * cumwidths + left
    cumwidths[..., 0] = left
    cumwidths[..., -1] = right
    widths = cumwidths[..., 1:] - cumwidths[..., :-1]

    derivatives = min_derivative + F.softplus(unnormalized_derivatives)

    heights = F.softmax(unnormalized_heights, dim=-1)
    heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
    cumheights = torch.cumsum(heights, dim=-1)
    cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0)
    cumheights = (top - bottom) * cumheights + bottom
    cumheights[..., 0] = bottom
    cumheights[..., -1] = top
    heights = cumheights[..., 1:] - cumheights[..., :-1]

    if inverse:
        bin_idx = searchsorted(cumheights, inputs)[..., None]
    else:
        bin_idx = searchsorted(cumwidths, inputs)[..., None]

    input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
    input_bin_widths = widths.gather(-1, bin_idx)[..., 0]

    input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
    delta = heights / widths
    input_delta = delta.gather(-1, bin_idx)[..., 0]

    input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
    input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]

    input_heights = heights.gather(-1, bin_idx)[..., 0]

    if inverse:
        a = (((inputs - input_cumheights) * (input_derivatives
                                             + input_derivatives_plus_one
                                             - 2 * input_delta)
              + input_heights * (input_delta - input_derivatives)))
        b = (input_heights * input_derivatives
             - (inputs - input_cumheights) * (input_derivatives
                                              + input_derivatives_plus_one
                                              - 2 * input_delta))
        c = - input_delta * (inputs - input_cumheights)

        discriminant = b.pow(2) - 4 * a * c
        assert (discriminant >= 0).all()

        root = (2 * c) / (-b - torch.sqrt(discriminant))
        outputs = root * input_bin_widths + input_cumwidths

        theta_one_minus_theta = root * (1 - root)
        denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
                                     * theta_one_minus_theta)
        derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2)
                                                     + 2 * input_delta * theta_one_minus_theta
                                                     + input_derivatives * (1 - root).pow(2))
        logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)

        return outputs, -logabsdet
    else:
        theta = (inputs - input_cumwidths) / input_bin_widths
        theta_one_minus_theta = theta * (1 - theta)

        numerator = input_heights * (input_delta * theta.pow(2)
                                     + input_derivatives * theta_one_minus_theta)
        denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
                                     * theta_one_minus_theta)
        outputs = input_cumheights + numerator / denominator

        derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2)
                                                     + 2 * input_delta * theta_one_minus_theta
                                                     + input_derivatives * (1 - theta).pow(2))
        logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)

        return outputs, logabsdet


def searchsorted(bin_locations, inputs, eps=1e-6):
    bin_locations[..., -1] += eps
    return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1


def unconstrained_rational_quadratic_spline(inputs,
                                            unnormalized_widths,
                                            unnormalized_heights,
                                            unnormalized_derivatives,
                                            inverse=False,
                                            tails='linear',
                                            tail_bound=1.,
                                            min_bin_width=DEFAULT_MIN_BIN_WIDTH,
                                            min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
                                            min_derivative=DEFAULT_MIN_DERIVATIVE):
    inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
    outside_interval_mask = ~inside_interval_mask

    outputs = torch.zeros_like(inputs)
    logabsdet = torch.zeros_like(inputs)

    if tails == 'linear':
        unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
        constant = np.log(np.exp(1 - min_derivative) - 1)
        unnormalized_derivatives[..., 0] = constant
        unnormalized_derivatives[..., -1] = constant

        outputs[outside_interval_mask] = inputs[outside_interval_mask]
        logabsdet[outside_interval_mask] = 0
    else:
        raise RuntimeError('{} tails are not implemented.'.format(tails))

    outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline(
        inputs=inputs[inside_interval_mask],
        unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
        unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
        unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
        inverse=inverse,
        left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound,
        min_bin_width=min_bin_width,
        min_bin_height=min_bin_height,
        min_derivative=min_derivative
    )

    return outputs, logabsdet
Architectures/__init__.py
DELETED
File without changes
InferenceInterfaces/ControllableInterface.py
CHANGED
@@ -2,8 +2,8 @@ import os
 
 import torch
 
-from Architectures.ControllabilityGAN.GAN import GanWrapper
 from InferenceInterfaces.ToucanTTSInterface import ToucanTTSInterface
+from Modules.ControllabilityGAN.GAN import GanWrapper
 from Utility.storage_config import MODELS_DIR
 
 
@@ -15,7 +15,7 @@ class ControllableInterface:
         else:
             os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
             os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}"
-        self.device = "cuda" if
+        self.device = "cuda" if gpu_id != "cpu" else "cpu"
         self.model = ToucanTTSInterface(device=self.device, tts_model_path="Meta")
         self.wgan = GanWrapper(os.path.join(MODELS_DIR, "Embedding", "embedding_gan.pt"), device=self.device)
         self.generated_speaker_embeds = list()
@@ -25,9 +25,11 @@ class ControllableInterface:
 
     def read(self,
              prompt,
+             reference_audio,
              language,
              accent,
             voice_seed,
+             prosody_creativity,
              duration_scaling_factor,
              pause_duration_scaling_factor,
              pitch_variance_scale,
@@ -37,24 +39,29 @@ class ControllableInterface:
             emb_slider_3,
             emb_slider_4,
             emb_slider_5,
-            emb_slider_6
+            emb_slider_6,
+            loudness_in_db
             ):
        if self.current_language != language:
            self.model.set_phonemizer_language(language)
+            print(f"switched phonemizer language to {language}")
            self.current_language = language
        if self.current_accent != accent:
            self.model.set_accent_language(accent)
+            print(f"switched accent language to {accent}")
            self.current_accent = accent
+        if reference_audio is None:
+            self.wgan.set_latent(voice_seed)
+            controllability_vector = torch.tensor([emb_slider_1,
+                                                   emb_slider_2,
+                                                   emb_slider_3,
+                                                   emb_slider_4,
+                                                   emb_slider_5,
+                                                   emb_slider_6], dtype=torch.float32)
+            embedding = self.wgan.modify_embed(controllability_vector)
+            self.model.set_utterance_embedding(embedding=embedding)
+        else:
+            self.model.set_utterance_embedding(reference_audio)

        phones = self.model.text2phone.get_phone_string(prompt)
        if len(phones) > 1800:
@@ -92,15 +99,15 @@ class ControllableInterface:
        if self.current_accent != "eng":
            self.model.set_accent_language("eng")
            self.current_accent = "eng"
-        print(prompt)
-        print(language)
-        print("\n\n")
+
+        print(prompt + "\n\n")
        wav, sr, fig = self.model(prompt,
                                  input_is_phones=False,
                                  duration_scaling_factor=duration_scaling_factor,
                                  pitch_variance_scale=pitch_variance_scale,
                                  energy_variance_scale=energy_variance_scale,
                                  pause_duration_scaling_factor=pause_duration_scaling_factor,
-                                  return_plot_as_filepath=True
+                                  return_plot_as_filepath=True,
+                                  prosody_creativity=prosody_creativity,
+                                  loudness_in_db=loudness_in_db)
        return sr, wav, fig
InferenceInterfaces/ToucanTTSInterface.py
CHANGED
@@ -1,19 +1,17 @@
 import itertools
 import os
-import warnings
 
+import librosa
 import matplotlib.pyplot as plt
 import pyloudnorm
 import sounddevice
 import soundfile
 import torch
-
-
-from speechbrain.pretrained import EncoderClassifier
-from torchaudio.transforms import Resample
+from speechbrain.pretrained import EncoderClassifier
+from torchaudio.transforms import Resample
 
+from Modules.ToucanTTS.InferenceToucanTTS import ToucanTTS
+from Modules.Vocoder.HiFiGAN_Generator import HiFiGAN
 from Preprocessing.AudioPreprocessor import AudioPreprocessor
 from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend
 from Preprocessing.TextFrontend import get_language_id
@@ -29,7 +27,6 @@ class ToucanTTSInterface(torch.nn.Module):
                  tts_model_path=os.path.join(MODELS_DIR, f"ToucanTTS_Meta", "best.pt"),  # path to the ToucanTTS checkpoint or just a shorthand if run standalone
                  vocoder_model_path=os.path.join(MODELS_DIR, f"Vocoder", "best.pt"),  # path to the Vocoder checkpoint
                  language="eng",  # initial language of the model, can be changed later with the setter methods
-                 enhance=None  # legacy argument
                  ):
         super().__init__()
         self.device = device
@@ -40,7 +37,7 @@ class ToucanTTSInterface(torch.nn.Module):
         ################################
         #   build text to phone        #
         ################################
-        self.text2phone = ArticulatoryCombinedTextFrontend(language=language, add_silence_to_end=True)
+        self.text2phone = ArticulatoryCombinedTextFrontend(language=language, add_silence_to_end=True, device=device)
 
         #####################################
         #   load phone to features model    #
@@ -92,8 +89,12 @@ class ToucanTTSInterface(torch.nn.Module):
             speaker_embs = list()
             for path in path_to_reference_audio:
                 wave, sr = soundfile.read(path)
+                if len(wave.shape) > 1:  # oh no, we found a stereo audio!
+                    if len(wave[0]) == 2:  # let's figure out whether we need to switch the axes
+                        wave = wave.transpose()  # if yes, we switch the axes.
+                    wave = librosa.to_mono(wave)
                 wave = Resample(orig_freq=sr, new_freq=16000).to(self.device)(torch.tensor(wave, device=self.device, dtype=torch.float32))
-                speaker_embedding = self.speaker_embedding_func_ecapa.encode_batch(wavs=wave.to(self.device).unsqueeze(0)).squeeze()
+                speaker_embedding = self.speaker_embedding_func_ecapa.encode_batch(wavs=wave.to(self.device).squeeze().unsqueeze(0)).squeeze()
                 speaker_embs.append(speaker_embedding)
             self.default_utterance_embedding = sum(speaker_embs) / len(speaker_embs)
 
@@ -105,10 +106,10 @@ class ToucanTTSInterface(torch.nn.Module):
         self.set_accent_language(lang_id=lang_id)
 
     def set_phonemizer_language(self, lang_id):
-        self.text2phone
+        self.text2phone = ArticulatoryCombinedTextFrontend(language=lang_id, add_silence_to_end=True, device=self.device)
 
     def set_accent_language(self, lang_id):
-        if lang_id in
+        if lang_id in {'ajp', 'ajt', 'lak', 'lno', 'nul', 'pii', 'plj', 'slq', 'smd', 'snb', 'tpw', 'wya', 'zua', 'en-us', 'en-sc', 'fr-be', 'fr-sw', 'pt-br', 'spa-lat', 'vi-ctr', 'vi-so'}:
             if lang_id == 'vi-so' or lang_id == 'vi-ctr':
                 lang_id = 'vie'
             elif lang_id == 'spa-lat':
@@ -120,7 +121,7 @@ class ToucanTTSInterface(torch.nn.Module):
             elif lang_id == 'en-sc' or lang_id == 'en-us':
                 lang_id = 'eng'
             else:
-                # no clue where these others are even coming from, they are not in ISO 639-
+                # no clue where these others are even coming from, they are not in ISO 639-3
                 lang_id = 'eng'
 
         self.lang_id = get_language_id(lang_id).to(self.device)
@@ -138,7 +139,7 @@ class ToucanTTSInterface(torch.nn.Module):
                 input_is_phones=False,
                 return_plot_as_filepath=False,
                 loudness_in_db=-24.0,
+                prosody_creativity=0.1):
        """
        duration_scaling_factor: reasonable values are 0.8 < scale < 1.2.
                                 1.0 means no scaling happens, higher values increase durations for the whole
@@ -154,16 +155,16 @@ class ToucanTTSInterface(torch.nn.Module):
         phones = self.text2phone.string_to_tensor(text, input_phonemes=input_is_phones).to(torch.device(self.device))
         mel, durations, pitch, energy = self.phone2mel(phones,
                                                        return_duration_pitch_energy=True,
-                                                       utterance_embedding=self.default_utterance_embedding
+                                                       utterance_embedding=self.default_utterance_embedding,
                                                        durations=durations,
                                                        pitch=pitch,
                                                        energy=energy,
-                                                       lang_id=self.lang_id
+                                                       lang_id=self.lang_id,
                                                        duration_scaling_factor=duration_scaling_factor,
                                                        pitch_variance_scale=pitch_variance_scale,
                                                        energy_variance_scale=energy_variance_scale,
                                                        pause_duration_scaling_factor=pause_duration_scaling_factor,
+                                                       prosody_creativity=prosody_creativity)
 
         wave, _, _ = self.vocoder(mel.unsqueeze(0))
         wave = wave.squeeze().cpu()
@@ -177,63 +178,56 @@ class ToucanTTSInterface(torch.nn.Module):
             pass
 
         if view or return_plot_as_filepath:
-            fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(9, 5))
-            except RuntimeError:
-                ax.set_title(text)
-            except:
-                pass
+            fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(9, 5))
+            ax.imshow(mel.cpu().numpy(), origin="lower", cmap='GnBu')
+            ax.yaxis.set_visible(False)
+            duration_splits, label_positions = cumsum_durations(durations.cpu().numpy())
+            ax.xaxis.grid(True, which='minor')
+            ax.set_xticks(label_positions, minor=False)
+            if input_is_phones:
+                phones = text.replace(" ", "|")
+            else:
+                phones = self.text2phone.get_phone_string(text, for_plot_labels=True)
+            try:
+                ax.set_xticklabels(phones)
+            except IndexError:
+                pass
+            except ValueError:
+                pass
+            word_boundaries = list()
+            for label_index, phone in enumerate(phones):
+                if phone == "|":
+                    word_boundaries.append(label_positions[label_index])
 
+            try:
+                prev_word_boundary = 0
+                word_label_positions = list()
+                for word_boundary in word_boundaries:
+                    word_label_positions.append((word_boundary + prev_word_boundary) / 2)
+                    prev_word_boundary = word_boundary
+                word_label_positions.append((duration_splits[-1] + prev_word_boundary) / 2)
 
+                secondary_ax = ax.secondary_xaxis('bottom')
+                secondary_ax.tick_params(axis="x", direction="out", pad=24)
+                secondary_ax.set_xticks(word_label_positions, minor=False)
+                secondary_ax.set_xticklabels(text.split())
+                secondary_ax.tick_params(axis='x', colors='orange')
+                secondary_ax.xaxis.label.set_color('orange')
+            except ValueError:
+                ax.set_title(text)
+            except IndexError:
+                ax.set_title(text)
+
+            ax.vlines(x=duration_splits, colors="green", linestyles="solid", ymin=0, ymax=120, linewidth=0.5)
+            ax.vlines(x=word_boundaries, colors="orange", linestyles="solid", ymin=0, ymax=120, linewidth=1.0)
+            plt.subplots_adjust(left=0.02, bottom=0.2, right=0.98, top=.9, wspace=0.0, hspace=0.0)
+            ax.set_aspect("auto")
 
         if return_plot_as_filepath:
-            plt.close()
-            except:
-                pass
+            plt.savefig("tmp.png")
+            plt.close()
             return wave, sr, "tmp.png"
-
         return wave, sr
 
     def read_to_file(self,
@@ -247,7 +241,7 @@ class ToucanTTSInterface(torch.nn.Module):
                      dur_list=None,
                      pitch_list=None,
                      energy_list=None,
+                     prosody_creativity=0.1):
         """
         Args:
             silent: Whether to be verbose about the process
@@ -259,12 +253,19 @@ class ToucanTTSInterface(torch.nn.Module):
             duration_scaling_factor: reasonable values are 0.8 < scale < 1.2.
                                      1.0 means no scaling happens, higher values increase durations for the whole
                                      utterance, lower values decrease durations for the whole utterance.
+            pause_duration_scaling_factor: reasonable values are 0.8 < scale < 1.2.
+                                           1.0 means no scaling happens, higher values increase durations for the pauses,
+                                           lower values decrease durations for the whole utterance.
             pitch_variance_scale: reasonable values are 0.6 < scale < 1.4.
                                   1.0 means no scaling happens, higher values increase variance of the pitch curve,
                                   lower values decrease variance of the pitch curve.
             energy_variance_scale: reasonable values are 0.6 < scale < 1.4.
                                    1.0 means no scaling happens, higher values increase variance of the energy curve,
                                    lower values decrease variance of the energy curve.
+            prosody_creativity: sampling temperature of the generative model that comes up with the pitch, energy and
+                                durations. Higher values mena more variance, lower temperature means less variance across
+                                generations. reasonable values are between 0.0 and 1.2, anything higher makes the voice
+                                sound very weird.
         """
         if not dur_list:
             dur_list = []
@@ -272,7 +273,7 @@ class ToucanTTSInterface(torch.nn.Module):
             pitch_list = []
         if not energy_list:
             energy_list = []
-        silence = torch.zeros([
+        silence = torch.zeros([400])
         wav = silence.clone()
         for (text, durations, pitch, energy) in itertools.zip_longest(text_list, dur_list, pitch_list, energy_list):
             if text.strip() != "":
@@ -286,7 +287,7 @@ class ToucanTTSInterface(torch.nn.Module):
                                         pitch_variance_scale=pitch_variance_scale,
                                         energy_variance_scale=energy_variance_scale,
                                         pause_duration_scaling_factor=pause_duration_scaling_factor,
+                                        prosody_creativity=prosody_creativity)
                 spoken_sentence = torch.tensor(spoken_sentence).cpu()
                 wav = torch.cat((wav, spoken_sentence, silence), 0)
         soundfile.write(file=file_location, data=float2pcm(wav), samplerate=sr, subtype="PCM_16")
@@ -298,7 +299,7 @@ class ToucanTTSInterface(torch.nn.Module):
                        pitch_variance_scale=1.0,
                        energy_variance_scale=1.0,
                        blocking=False,
+                       prosody_creativity=0.1):
         if text.strip() == "":
             return
         wav, sr = self(text,
@@ -306,7 +307,7 @@ class ToucanTTSInterface(torch.nn.Module):
                        duration_scaling_factor=duration_scaling_factor,
                        pitch_variance_scale=pitch_variance_scale,
                        energy_variance_scale=energy_variance_scale,
+                       prosody_creativity=prosody_creativity)
         silence = torch.zeros([sr // 2])
         wav = torch.cat((silence, torch.tensor(wav), silence), 0).numpy()
         sounddevice.play(float2pcm(wav), samplerate=sr)
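For orientation on the new prosody_creativity argument, a hypothetical usage sketch follows. It assumes the pretrained checkpoints referenced by MODELS_DIR are available locally and that the read_to_file keyword arguments match the excerpt above; treat it as an assumption-laden sketch rather than a verified recipe.

    import torch
    from InferenceInterfaces.ToucanTTSInterface import ToucanTTSInterface

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tts = ToucanTTSInterface(device=device, language="eng")  # loads the ToucanTTS_Meta and Vocoder checkpoints
    tts.read_to_file(text_list=["Prosody creativity acts like a sampling temperature."],
                     file_location="test_sentence.wav",
                     duration_scaling_factor=1.0,
                     prosody_creativity=0.4)  # higher values -> more variation between generations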
InferenceInterfaces/UtteranceCloner.py
CHANGED
@@ -4,11 +4,11 @@ import numpy
 import soundfile as sf
 import torch
 
-from Architectures.Aligner.Aligner import Aligner
-from Architectures.ToucanTTS.DurationCalculator import DurationCalculator
-from Architectures.ToucanTTS.EnergyCalculator import EnergyCalculator
-from Architectures.ToucanTTS.PitchCalculator import Parselmouth
 from InferenceInterfaces.ToucanTTSInterface import ToucanTTSInterface
+from Modules.Aligner.Aligner import Aligner
+from Modules.ToucanTTS.DurationCalculator import DurationCalculator
+from Modules.ToucanTTS.EnergyCalculator import EnergyCalculator
+from Modules.ToucanTTS.PitchCalculator import Parselmouth
 from Preprocessing.AudioPreprocessor import AudioPreprocessor
 from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend
 from Preprocessing.articulatory_features import get_feature_to_index_lookup
@@ -26,7 +26,7 @@ class UtteranceCloner:
     def __init__(self, model_id, device, language="eng"):
         self.tts = ToucanTTSInterface(device=device, tts_model_path=model_id)
         self.ap = AudioPreprocessor(input_sr=100, output_sr=16000, cut_silence=False)
-        self.tf = ArticulatoryCombinedTextFrontend(language=language)
+        self.tf = ArticulatoryCombinedTextFrontend(language=language, device=device)
         self.device = device
         acoustic_checkpoint_path = os.path.join(MODELS_DIR, "Aligner", "aligner.pt")
         self.aligner_weights = torch.load(acoustic_checkpoint_path, map_location=device)["asr_model"]
@@ -43,6 +43,7 @@ class UtteranceCloner:
         self.acoustic_model = Aligner()
         self.acoustic_model = self.acoustic_model.to(self.device)
         self.acoustic_model.load_state_dict(self.aligner_weights)
+        self.acoustic_model.eval()
         self.parsel = Parselmouth(reduction_factor=1, fs=16000)
         self.energy_calc = EnergyCalculator(reduction_factor=1, fs=16000)
         self.dc = DurationCalculator(reduction_factor=1)
@@ -50,10 +51,11 @@ class UtteranceCloner:
     def extract_prosody(self, transcript, ref_audio_path, lang="eng", on_line_fine_tune=True):
         if on_line_fine_tune:
             self.acoustic_model.load_state_dict(self.aligner_weights)
+            self.acoustic_model.eval()
 
         wave, sr = sf.read(ref_audio_path)
         if self.tf.language != lang:
-            self.tf = ArticulatoryCombinedTextFrontend(language=lang)
+            self.tf = ArticulatoryCombinedTextFrontend(language=lang, device=self.device)
         if self.ap.input_sr != sr:
             self.ap = AudioPreprocessor(input_sr=sr, output_sr=16000, cut_silence=False)
         try:
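The newly added eval() calls matter because loading a state dict does not change a module's train/eval mode; a tiny self-contained demonstration of that behavior (generic torch modules, not the actual Aligner):

    import torch

    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Dropout(p=0.5))
    weights = model.state_dict()
    model.load_state_dict(weights)   # loading weights leaves the module in training mode
    print(model.training)            # True -> dropout would still be active
    model.eval()
    print(model.training)            # False -> deterministic inference, as in the diff above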
InferenceInterfaces/audioseal_wm_16bits.yaml
DELETED
@@ -1,39 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-name: audioseal_wm_16bits
-model_type: seanet
-checkpoint: "https://dl.fbaipublicfiles.com/audioseal/6edcf62f/generator.pth"
-nbits: 16
-seanet:
-  activation: ELU
-  activation_params:
-    alpha: 1.0
-  causal: false
-  channels: 1
-  compress: 2
-  dilation_base: 2
-  dimension: 128
-  disable_norm_outer_blocks: 0
-  kernel_size: 7
-  last_kernel_size: 7
-  lstm: 2
-  n_filters: 32
-  n_residual_layers: 1
-  norm: weight_norm
-  norm_params: { }
-  pad_mode: constant
-  ratios:
-  - 8
-  - 5
-  - 4
-  - 2
-  residual_kernel_size: 3
-  true_skip: true
-  decoder:
-    final_activation: null
-    final_activation_params: null
-    trim_right_ratio: 1.0
{Architectures → Modules}/Aligner/Aligner.py
RENAMED
@@ -1,27 +1,31 @@
 """
 taken and adapted from https://github.com/as-ideas/DeepForcedAligner
+
+refined with insights from https://www.audiolabs-erlangen.de/resources/NLUI/2023-ICASSP-eval-alignment-tts
+EVALUATING SPEECH-PHONEME ALIGNMENT AND ITS IMPACT ON NEURAL TEXT-TO-SPEECH SYNTHESIS
+by Frank Zalkow, Prachi Govalkar, Meinard Muller, Emanuel A. P. Habets, Christian Dittmar
 """
 import matplotlib.pyplot as plt
 import numpy as np
 import torch
 import torch.multiprocessing
-import torch.nn as nn
 from torch.nn import CTCLoss
 from torch.nn.utils.rnn import pack_padded_sequence
 from torch.nn.utils.rnn import pad_packed_sequence
 
 from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend
+from Utility.utils import make_non_pad_mask
 
 
-class BatchNormConv(nn.Module):
+class BatchNormConv(torch.nn.Module):
 
     def __init__(self, in_channels: int, out_channels: int, kernel_size: int):
         super().__init__()
-        self.conv = nn.Conv1d(
+        self.conv = torch.nn.Conv1d(
             in_channels, out_channels, kernel_size,
             stride=1, padding=kernel_size // 2, bias=False)
-        self.bnorm = nn.BatchNorm1d(out_channels)
-        self.relu = nn.ReLU()
+        self.bnorm = torch.nn.SyncBatchNorm.convert_sync_batchnorm(torch.nn.BatchNorm1d(out_channels))
+        self.relu = torch.nn.ReLU()
 
     def forward(self, x):
         x = x.transpose(1, 2)
@@ -37,22 +41,23 @@ class Aligner(torch.nn.Module):
     def __init__(self,
                  n_features=128,
                  num_symbols=145,
+                 conv_dim=512,
+                 lstm_dim=512):
         super().__init__()
-        self.convs = nn.ModuleList([
+        self.convs = torch.nn.ModuleList([
             BatchNormConv(n_features, conv_dim, 3),
-            nn.Dropout(p=0.5),
+            torch.nn.Dropout(p=0.5),
             BatchNormConv(conv_dim, conv_dim, 3),
-            nn.Dropout(p=0.5),
+            torch.nn.Dropout(p=0.5),
             BatchNormConv(conv_dim, conv_dim, 3),
-            nn.Dropout(p=0.5),
+            torch.nn.Dropout(p=0.5),
             BatchNormConv(conv_dim, conv_dim, 3),
-            nn.Dropout(p=0.5),
+            torch.nn.Dropout(p=0.5),
             BatchNormConv(conv_dim, conv_dim, 3),
-            nn.Dropout(p=0.5),
+            torch.nn.Dropout(p=0.5),
         ])
+        self.rnn1 = torch.nn.LSTM(conv_dim, lstm_dim, batch_first=True, bidirectional=True)
+        self.rnn2 = torch.nn.LSTM(2 * lstm_dim, lstm_dim, batch_first=True, bidirectional=True)
         self.proj = torch.nn.Linear(2 * lstm_dim, num_symbols)
         self.tf = ArticulatoryCombinedTextFrontend(language="eng")
         self.ctc_loss = CTCLoss(blank=144, zero_infinity=True)
@@ -61,14 +66,17 @@ class Aligner(torch.nn.Module):
     def forward(self, x, lens=None):
         for conv in self.convs:
             x = conv(x)
         if lens is not None:
             x = pack_padded_sequence(x, lens.cpu(), batch_first=True, enforce_sorted=False)
+        x, _ = self.rnn1(x)
+        x, _ = self.rnn2(x)
         if lens is not None:
             x, _ = pad_packed_sequence(x, batch_first=True)
 
         x = self.proj(x)
+        if lens is not None:
+            out_masks = make_non_pad_mask(lens).unsqueeze(-1).to(x.device)
+            x = x * out_masks.float()
 
         return x
 
@@ -88,15 +96,12 @@ class Aligner(torch.nn.Module):
         pred_max = pred[:, tokens]
 
         # run monotonic alignment search
-
         alignment_matrix = binarize_alignment(pred_max)
 
         if save_img_for_debug is not None:
             phones = list()
             for index in tokens:
-                if self.tf.phone_to_id[phone] == index:
-                    phones.append(phone)
+                phones.append(self.tf.id_to_phone[index])
             fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 5))
 
             ax.imshow(alignment_matrix, interpolation='nearest', aspect='auto', origin="lower", cmap='cividis')
@@ -115,7 +120,6 @@ class Aligner(torch.nn.Module):
         return alignment_matrix
 
 
-
 def binarize_alignment(alignment_prob):
     """
     # Implementation by:
@@ -152,13 +156,5 @@ def binarize_alignment(alignment_prob):
 
 
 if __name__ == '__main__':
-    cap = CodecAudioPreprocessor(input_sr=-1)
-    dummy_codebook_indexes = torch.randint(low=0, high=1023, size=[9, 20])
-    codebook_frames = cap.indexes_to_codec_frames(dummy_codebook_indexes)
-    alignment = Aligner().inference(codebook_frames.transpose(0, 1), tokens=tf.string_to_tensor("Hello world"))
-    print(alignment.shape)
-    plt.imshow(alignment, origin="lower", cmap="GnBu")
-    plt.show()
+    print(sum(p.numel() for p in Aligner().parameters() if p.requires_grad))
+    print(Aligner()(x=torch.randn(size=[3, 30, 128]), lens=torch.LongTensor([20, 30, 10])).shape)
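The new masking step after the projection zeroes out frames beyond each sequence's true length. A self-contained approximation, with make_non_pad_mask reimplemented under its assumed semantics (True for real frames, False for padding):

    import torch

    def make_non_pad_mask(lengths):
        # assumed semantics of the project's helper: boolean mask of valid frames
        max_len = int(lengths.max())
        positions = torch.arange(max_len).unsqueeze(0)   # [1, T]
        return positions < lengths.unsqueeze(1)           # [B, T]

    lens = torch.LongTensor([20, 30, 10])
    x = torch.randn(3, 30, 145)                            # [batch, frames, num_symbols]
    out_masks = make_non_pad_mask(lens).unsqueeze(-1)      # [B, T, 1]
    x = x * out_masks.float()                              # padded frames become exactly zero
    print(x[2, 10:].abs().sum())                           # tensor(0.) -> frames past length 10 are masked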
{Architectures → Modules}/Aligner/CodecAlignerDataset.py
RENAMED
@@ -32,6 +32,7 @@ class CodecAlignerDataset(Dataset):
                  allow_unknown_symbols=False,
                  gpu_count=1,
                  rank=0):
+
         self.gpu_count = gpu_count
         self.rank = rank
         if not os.path.exists(os.path.join(cache_dir, "aligner_train_cache.pt")) or rebuild_cache:
@@ -50,9 +51,10 @@ class CodecAlignerDataset(Dataset):
         self.lang = lang
         self.device = device
         self.cache_dir = cache_dir
-        self.tf = ArticulatoryCombinedTextFrontend(language=self.lang)
+        self.tf = ArticulatoryCombinedTextFrontend(language=self.lang, device=device)
         cache = torch.load(os.path.join(self.cache_dir, "aligner_train_cache.pt"), map_location='cpu')
         self.speaker_embeddings = cache[2]
+        self.filepaths = cache[3]
         self.datapoints = cache[0]
         if self.gpu_count > 1:
             # we only keep a chunk of the dataset in memory to avoid redundancy. Which chunk, we figure out using the rank.
@@ -85,6 +87,7 @@ class CodecAlignerDataset(Dataset):
         if type(path_to_transcript_dict) != dict:
             path_to_transcript_dict = path_to_transcript_dict()  # in this case we passed a function instead of the dict, so that the function isn't executed if not necessary.
         torch.multiprocessing.set_start_method('spawn', force=True)
+        torch.multiprocessing.set_sharing_strategy('file_system')
         resource_manager = Manager()
         self.path_to_transcript_dict = resource_manager.dict(path_to_transcript_dict)
         key_list = list(self.path_to_transcript_dict.keys())
@@ -93,6 +96,13 @@ class CodecAlignerDataset(Dataset):
             fisher_yates_shuffle(key_list)
         # build cache
         print("... building dataset cache ...")
+        torch.hub._validate_not_a_forked_repo = lambda a, b, c: True  # torch 1.9 has a bug in the hub loading, this is a workaround
+        # careful: assumes 16kHz or 8kHz audio
+        _, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',  # make sure it gets downloaded during single-processing first, if it's not already downloaded
+                              model='silero_vad',
+                              force_reload=False,
+                              onnx=False,
+                              verbose=False)
         self.result_pool = resource_manager.list()
         # make processes
         key_splits = list()
@@ -176,8 +186,8 @@ class CodecAlignerDataset(Dataset):
         torch.set_grad_enabled(True)  # finding this issue was very infuriating: silero sets
         # this to false globally during model loading rather than using inference mode or no_grad
         silero_model = silero_model.to(device)
-        silence = torch.zeros([16000 //
-        tf = ArticulatoryCombinedTextFrontend(language=lang)
+        silence = torch.zeros([16000 // 8]).to(device)
+        tf = ArticulatoryCombinedTextFrontend(language=lang, device=device)
         _, sr = sf.read(path_list[0])
         assumed_sr = sr
         ap = CodecAudioPreprocessor(input_sr=assumed_sr, device=device)
@@ -186,13 +196,15 @@ class CodecAlignerDataset(Dataset):
         for path in tqdm(path_list):
             if self.path_to_transcript_dict[path].strip() == "":
                 continue
-
             try:
                 wave, sr = sf.read(path)
             except:
                 print(f"Problem with an audio file: {path}")
                 continue
 
+            if len(wave.shape) > 1:  # oh no, we found a stereo audio!
+                if len(wave[0]) == 2:  # let's figure out whether we need to switch the axes
+                    wave = wave.transpose()  # if yes, we switch the axes.
             wave = librosa.to_mono(wave)
 
             if sr != assumed_sr:
@@ -210,16 +222,19 @@ class CodecAlignerDataset(Dataset):
                 if verbose:
                     print(f"Excluding {path} because of its duration of {round(dur_in_seconds, 2)} seconds.")
                 continue
-
-            # remove silences from front and back, then add constant 1/4th second silences back to front and back
-            with torch.no_grad():
+            with torch.inference_mode():
                 speech_timestamps = get_speech_timestamps(norm_wave, silero_model, sampling_rate=16000)
             try:
+                silence_timestamps = invert_segments(speech_timestamps, len(norm_wave))
+                for silence_timestamp in silence_timestamps:
+                    begin = silence_timestamp['start']
+                    end = silence_timestamp['end']
+                    norm_wave = torch.cat([norm_wave[:begin], torch.zeros([end - begin], device=device), norm_wave[end:]])
                 result = norm_wave[speech_timestamps[0]['start']:speech_timestamps[-1]['end']]
            except IndexError:
                print("Audio might be too short to cut silences from front and back.")
                continue
+            norm_wave = torch.cat([silence, result, silence])
 
             # raw audio preprocessing is done
             transcript = self.path_to_transcript_dict[path]
@@ -238,10 +253,10 @@ class CodecAlignerDataset(Dataset):
                 # this can happen for Mandarin Chinese, when the syllabification of pinyin doesn't work. In that case, we just skip the sample.
                 continue
 
-            cached_speech = ap.audio_to_codebook_indexes(audio=
+            cached_speech = ap.audio_to_codebook_indexes(audio=norm_wave, current_sampling_rate=16000).transpose(0, 1).cpu().numpy()
             process_internal_dataset_chunk.append([cached_text,
                                                    cached_speech,
+                                                   norm_wave.cpu().detach().numpy(),
                                                    path])
         self.result_pool.append(process_internal_dataset_chunk)
@@ -256,16 +271,44 @@ class CodecAlignerDataset(Dataset):
         codes = codes.transpose(0, 1)
 
         return tokens, \
+               token_len, \
+               codes, \
+               None, \
+               self.speaker_embeddings[index]
 
     def __len__(self):
         return len(self.datapoints)
 
+    def remove_samples(self, list_of_samples_to_remove):
+        for remove_id in sorted(list_of_samples_to_remove, reverse=True):
+            self.datapoints.pop(remove_id)
+            self.speaker_embeddings.pop(remove_id)
+            self.filepaths.pop(remove_id)
+        torch.save((self.datapoints, None, self.speaker_embeddings, self.filepaths),
+                   os.path.join(self.cache_dir, "aligner_train_cache.pt"))
+        print("Dataset updated!")
+
 
 def fisher_yates_shuffle(lst):
     for i in range(len(lst) - 1, 0, -1):
         j = random.randint(0, i)
         lst[i], lst[j] = lst[j], lst[i]
+
+
+def invert_segments(segments, total_duration):
+    if not segments:
+        return [{'start': 0, 'end': total_duration}]
+
+    inverted_segments = []
+    previous_end = 0
+
+    for segment in segments:
+        start = segment['start']
+        if previous_end < start:
+            inverted_segments.append({'start': previous_end, 'end': start})
+        previous_end = segment['end']
+
+    if previous_end < total_duration:
+        inverted_segments.append({'start': previous_end, 'end': total_duration})
+
+    return inverted_segments
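The added invert_segments helper turns the VAD's speech segments into the complementary silence segments so they can be zeroed out. A quick check of its behavior, using the logic exactly as added above and toy timestamps:

    # invert_segments copied from the diff above, exercised on toy sample indices
    def invert_segments(segments, total_duration):
        if not segments:
            return [{'start': 0, 'end': total_duration}]
        inverted_segments = []
        previous_end = 0
        for segment in segments:
            start = segment['start']
            if previous_end < start:
                inverted_segments.append({'start': previous_end, 'end': start})
            previous_end = segment['end']
        if previous_end < total_duration:
            inverted_segments.append({'start': previous_end, 'end': total_duration})
        return inverted_segments

    speech = [{'start': 160, 'end': 4000}, {'start': 4800, 'end': 9000}]
    print(invert_segments(speech, total_duration=10000))
    # [{'start': 0, 'end': 160}, {'start': 4000, 'end': 4800}, {'start': 9000, 'end': 10000}]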
{Architectures → Modules}/Aligner/README.md
RENAMED
File without changes
{Architectures → Modules}/Aligner/Reconstructor.py
RENAMED
@@ -1,7 +1,5 @@
 import torch
 import torch.multiprocessing
-from torch.nn.utils.rnn import pack_padded_sequence
-from torch.nn.utils.rnn import pad_packed_sequence
 
 from Utility.utils import make_non_pad_mask
 
@@ -12,28 +10,23 @@ class Reconstructor(torch.nn.Module):
                  n_features=128,
                  num_symbols=145,
                  speaker_embedding_dim=192,
+                 hidden_dim=256):
         super().__init__()
-        self.in_proj = torch.nn.Linear(num_symbols + speaker_embedding_dim,
-        self.out_proj = torch.nn.Linear(2 * lstm_dim, n_features)
+        self.in_proj = torch.nn.Linear(num_symbols + speaker_embedding_dim, hidden_dim)
+        self.hidden_proj = torch.nn.Linear(hidden_dim, hidden_dim)
+        self.out_proj = torch.nn.Linear(hidden_dim, n_features)
         self.l1_criterion = torch.nn.L1Loss(reduction="none")
-        self.l2_criterion = torch.nn.MSELoss(reduction="none")
 
     def forward(self, x, lens, ys):
         x = self.in_proj(x)
-        x, _ = pad_packed_sequence(x, batch_first=True)
+        x = torch.nn.functional.leaky_relu(x)
+        x = self.hidden_proj(x)
+        x = torch.nn.functional.leaky_relu(x)
         x = self.out_proj(x)
         out_masks = make_non_pad_mask(lens).unsqueeze(-1).to(ys.device)
         out_weights = out_masks.float() / out_masks.sum(dim=1, keepdim=True).float()
         out_weights /= ys.size(0) * ys.size(2)
-        l2_loss = self.l2_criterion(x, ys).mul(out_weights).masked_select(out_masks).sum()
-        return l1_loss + l2_loss
+        return self.l1_criterion(x, ys).mul(out_weights).masked_select(out_masks).sum()
 
 
 if __name__ == '__main__':
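The reworked Reconstructor is now a plain feed-forward stack instead of a recurrent one. A minimal stand-alone sketch of the new data flow with the default dimensions from the diff (the tensors here are random placeholders):

    import torch

    # two linear layers with leaky ReLU, mirroring the new Reconstructor
    in_proj = torch.nn.Linear(145 + 192, 256)      # symbols + speaker embedding -> hidden
    hidden_proj = torch.nn.Linear(256, 256)
    out_proj = torch.nn.Linear(256, 128)           # hidden -> n_features

    x = torch.randn(4, 50, 145 + 192)              # [batch, frames, symbols + speaker embedding]
    x = torch.nn.functional.leaky_relu(in_proj(x))
    x = torch.nn.functional.leaky_relu(hidden_proj(x))
    print(out_proj(x).shape)                       # torch.Size([4, 50, 128]) -> predicted feature frames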
{Architectures → Modules}/Aligner/__init__.py
RENAMED
File without changes
{Architectures → Modules}/Aligner/autoaligner_train_loop.py
RENAMED
@@ -8,8 +8,8 @@ from torch.optim import RAdam
 from torch.utils.data.dataloader import DataLoader
 from tqdm import tqdm
 
+from Modules.Aligner.Aligner import Aligner
+from Modules.Aligner.Reconstructor import Reconstructor
 from Preprocessing.AudioPreprocessor import AudioPreprocessor
 from Preprocessing.EnCodecAudioPreprocessor import CodecAudioPreprocessor
@@ -152,6 +152,8 @@ def train_loop(train_dataset,
             optim_asr.zero_grad()
             if use_reconstruction:
                 optim_tts.zero_grad()
+            if gpu_count > 1:
+                torch.distributed.barrier()
             loss.backward()
             torch.nn.utils.clip_grad_norm_(asr_model.parameters(), 1.0)
             if use_reconstruction:
{Architectures → Modules}/ControllabilityGAN/GAN.py
RENAMED
@@ -1,12 +1,11 @@
 import torch
 
+from Modules.ControllabilityGAN.wgan.init_wgan import create_wgan
 
 
-class GanWrapper
+class GanWrapper:
 
-    def __init__(self, path_wgan, device
-        super().__init__(*args, **kwargs)
+    def __init__(self, path_wgan, device):
         self.device = device
         self.path_wgan = path_wgan
 
@@ -20,27 +19,41 @@ class GanWrapper:
         self.U = self.compute_controllability()
 
         self.z_list = list()
+
         for _ in range(1100):
-            self.z_list.append(self.wgan.G.
+            self.z_list.append(self.wgan.G.sample_latent(1, self.wgan.G.z_dim, temperature=0.8))
         self.z = self.z_list[0]
 
     def set_latent(self, seed):
         self.z = self.z = self.z_list[seed]
 
     def reset_default_latent(self):
-        self.z = self.wgan.G.
+        self.z = self.wgan.G.sample_latent(1, self.wgan.G.z_dim, temperature=0.8)
 
     def load_model(self, path):
         gan_checkpoint = torch.load(path, map_location="cpu")
 
         self.wgan = create_wgan(parameters=gan_checkpoint['model_parameters'], device=self.device)
+        # Create a new state dict without 'module.' prefix
+        new_state_dict_G = {}
+        for key, value in gan_checkpoint['generator_state_dict'].items():
+            # Remove 'module.' prefix
+            new_key = key.replace('module.', '')
+            new_state_dict_G[new_key] = value
+
+        new_state_dict_D = {}
+        for key, value in gan_checkpoint['critic_state_dict'].items():
+            # Remove 'module.' prefix
+            new_key = key.replace('module.', '')
+            new_state_dict_D[new_key] = value
+
+        self.wgan.G.load_state_dict(new_state_dict_G)
+        self.wgan.D.load_state_dict(new_state_dict_D)
 
         self.mean = gan_checkpoint["dataset_mean"]
         self.std = gan_checkpoint["dataset_std"]
 
-    def compute_controllability(self, n_samples=
+    def compute_controllability(self, n_samples=100000):
         _, intermediate, z = self.wgan.sample_generator(num_samples=n_samples, nograd=True, return_intermediate=True)
         intermediate = intermediate.cpu()
         z = z.cpu()
@@ -69,7 +82,7 @@ class GanWrapper:
     def modify_embed(self, x):
         self.wgan.G.eval()
         z_new = self.z.squeeze() + torch.matmul(self.U.solution.t(), x)
-        embed_modified = self.wgan.G.
+        embed_modified = self.wgan.G.forward(z_new.unsqueeze(0).to(self.device))
         if self.normalize:
             embed_modified = inverse_normalize(
                 embed_modified.cpu(),
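The new load_model logic strips the 'module.' prefix that DataParallel/DistributedDataParallel adds to state-dict keys, so the checkpoint can be loaded into a plain module. A small self-contained demonstration of the same trick with a toy model:

    import torch

    # toy model saved as if it had been wrapped in DataParallel (keys get a 'module.' prefix)
    model = torch.nn.Linear(4, 2)
    wrapped_state = {f"module.{k}": v for k, v in model.state_dict().items()}

    # same trick as in load_model above: drop the prefix so a plain module can load it
    clean_state = {k.replace("module.", ""): v for k, v in wrapped_state.items()}
    model.load_state_dict(clean_state)
    print(list(clean_state.keys()))  # ['weight', 'bias']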
{Architectures → Modules}/ControllabilityGAN/__init__.py
RENAMED
File without changes
{Architectures → Modules}/ControllabilityGAN/dataset/__init__.py
RENAMED
File without changes
{Architectures → Modules}/ControllabilityGAN/dataset/speaker_embeddings_dataset.py
RENAMED
File without changes
{Architectures → Modules}/ControllabilityGAN/wgan/__init__.py
RENAMED
File without changes
{Architectures → Modules}/ControllabilityGAN/wgan/init_weights.py
RENAMED
File without changes
{Architectures → Modules}/ControllabilityGAN/wgan/init_wgan.py
RENAMED
@@ -1,7 +1,7 @@
 import torch
 
+from Modules.ControllabilityGAN.wgan.resnet_init import init_resnet
+from Modules.ControllabilityGAN.wgan.wgan_qc import WassersteinGanQuadraticCost
 
 
 def create_wgan(parameters, device, optimizer='adam'):
{Architectures → Modules}/ControllabilityGAN/wgan/resnet_1.py
RENAMED
@@ -76,8 +76,8 @@ class ResNet_G(nn.Module):
             return out, l_1
         return out
 
-    def sample_latent(self, n_samples, z_size):
-        return torch.randn((n_samples, z_size))
+    def sample_latent(self, n_samples, z_size, temperature=0.7):
+        return torch.randn((n_samples, z_size)) * temperature
 
 
 class ResNet_D(nn.Module):
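The change above scales the standard-normal latent by a temperature, so lower values keep samples closer to the mean of the latent distribution (more conservative generations). A tiny self-contained sketch of that effect:

    import torch

    def sample_latent(n_samples, z_size, temperature=0.7):
        # as in the diff above: scale standard-normal noise by a temperature
        return torch.randn((n_samples, z_size)) * temperature

    torch.manual_seed(0)
    print(sample_latent(64, 32, temperature=0.7).std())  # roughly 0.7
    print(sample_latent(64, 32, temperature=0.1).std())  # roughly 0.1 -> much tighter around zero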
{Architectures → Modules}/ControllabilityGAN/wgan/resnet_init.py
RENAMED
@@ -1,7 +1,7 @@
+from Modules.ControllabilityGAN.wgan.init_weights import weights_init_D
+from Modules.ControllabilityGAN.wgan.init_weights import weights_init_G
+from Modules.ControllabilityGAN.wgan.resnet_1 import ResNet_D
+from Modules.ControllabilityGAN.wgan.resnet_1 import ResNet_G
 
 
 def init_resnet(parameters):
{Architectures → Modules}/ControllabilityGAN/wgan/wgan_qc.py
RENAMED
@@ -3,7 +3,6 @@ import time
 
 import numpy as np
 import torch
-import torch.nn as nn
 import torch.optim as optim
 from cvxopt import matrix
 from cvxopt import solvers
@@ -11,13 +10,12 @@ from cvxopt import sparse
 from cvxopt import spmatrix
 from torch.autograd import grad as torch_grad
 from tqdm import tqdm
-import spaces
 
 
-class WassersteinGanQuadraticCost(torch.nn.Module):
+class WassersteinGanQuadraticCost:
 
-    def __init__(self, generator, discriminator, gen_optimizer, dis_optimizer, criterion, epochs, n_max_iterations,
+    def __init__(self, generator, discriminator, gen_optimizer, dis_optimizer, criterion, epochs, n_max_iterations,
+                 data_dimensions, batch_size, device, gamma=0.1, K=-1, milestones=[150000, 250000], lr_anneal=1.0):
         self.G = generator
         self.G_opt = gen_optimizer
         self.D = discriminator
@@ -46,8 +44,8 @@ class WassersteinGanQuadraticCost(torch.nn.Module):
         self.Kr = np.sqrt(self.K)
         self.LAMBDA = 2 * self.Kr * gamma * 2
 
-        self.G =
-        self.D =
+        self.G = self.G.to(self.device)
+        self.D = self.D.to(self.device)
 
         self.schedulerD = self._build_lr_scheduler_(self.D_opt, milestones, lr_anneal)
         self.schedulerG = self._build_lr_scheduler_(self.G_opt, milestones, lr_anneal)
@@ -245,10 +243,7 @@ class WassersteinGanQuadraticCost(torch.nn.Module):
         latent_samples = latent_samples.to(self.device)
         if nograd:
             with torch.no_grad():
-                generated_data = self.G.module(latent_samples, return_intermediate=return_intermediate)
-                else:
-                generated_data = self.G(latent_samples, return_intermediate=return_intermediate)
+                generated_data = self.G(latent_samples, return_intermediate=return_intermediate)
         else:
             generated_data = self.G(latent_samples)
         self.G.train()
{Architectures → Modules}/EmbeddingModel/GST.py
RENAMED
@@ -3,7 +3,7 @@
 
 import torch
 
+from Modules.GeneralLayers.Attention import MultiHeadedAttention as BaseMultiHeadedAttention
 
 
 class GSTStyleEncoder(torch.nn.Module):
{Architectures → Modules}/EmbeddingModel/README.md
RENAMED
File without changes
{Architectures → Modules}/EmbeddingModel/StyleEmbedding.py
RENAMED
@@ -1,7 +1,7 @@
 import torch
 
+from Modules.EmbeddingModel.GST import GSTStyleEncoder
+from Modules.EmbeddingModel.StyleTTSEncoder import StyleEncoder as StyleTTSEncoder
 
 
 class StyleEmbedding(torch.nn.Module):
{Architectures → Modules}/EmbeddingModel/StyleTTSEncoder.py
RENAMED
File without changes
{Architectures → Modules}/EmbeddingModel/__init__.py
RENAMED
File without changes
{Architectures → Modules}/GeneralLayers/Attention.py
RENAMED
File without changes
{Architectures → Modules}/GeneralLayers/ConditionalLayerNorm.py
RENAMED
File without changes
{Architectures → Modules}/GeneralLayers/Conformer.py
RENAMED
@@ -4,16 +4,16 @@ Taken from ESPNet, but heavily modified
 
 import torch
 
+from Modules.GeneralLayers.Attention import RelPositionMultiHeadedAttention
+from Modules.GeneralLayers.ConditionalLayerNorm import AdaIN1d
+from Modules.GeneralLayers.ConditionalLayerNorm import ConditionalLayerNorm
+from Modules.GeneralLayers.Convolution import ConvolutionModule
+from Modules.GeneralLayers.EncoderLayer import EncoderLayer
+from Modules.GeneralLayers.LayerNorm import LayerNorm
+from Modules.GeneralLayers.MultiLayeredConv1d import MultiLayeredConv1d
+from Modules.GeneralLayers.MultiSequential import repeat
+from Modules.GeneralLayers.PositionalEncoding import RelPositionalEncoding
+from Modules.GeneralLayers.Swish import Swish
 from Utility.utils import integrate_with_utt_embed
 
 
@@ -84,8 +84,12 @@ class Conformer(torch.nn.Module):
         self.decoder_embedding_projections = repeat(num_blocks, lambda lnum: torch.nn.Linear(attention_dim + utt_embed, attention_dim))
         if lang_embs is not None:
             self.language_embedding = torch.nn.Embedding(num_embeddings=lang_embs, embedding_dim=lang_emb_size)
+            if lang_emb_size == attention_dim:
+                self.language_embedding_projection = lambda x: x
+            else:
+                self.language_embedding_projection = torch.nn.Linear(lang_emb_size, attention_dim)
             self.language_emb_norm = LayerNorm(attention_dim)
+
         # self-attention module definition
         encoder_selfattn_layer = RelPositionMultiHeadedAttention
         encoder_selfattn_layer_args = (attention_heads, attention_dim, attention_dropout_rate, zero_triu)
@@ -138,21 +142,28 @@ class Conformer(torch.nn.Module):
             if isinstance(xs, tuple):
                 x, pos_emb = xs[0], xs[1]
                 if self.conformer_type != "encoder":
-                    x = integrate_with_utt_embed(hs=x,
+                    x = integrate_with_utt_embed(hs=x,
+                                                 utt_embeddings=utterance_embedding,
+                                                 projection=self.decoder_embedding_projections[encoder_index],
+                                                 embedding_training=self.use_conditional_layernorm_embedding_integration)
                 xs = (x, pos_emb)
            else:
                 if self.conformer_type != "encoder":
-                    xs = integrate_with_utt_embed(hs=xs,
+                    xs = integrate_with_utt_embed(hs=xs,
+                                                  utt_embeddings=utterance_embedding,
+                                                  projection=self.decoder_embedding_projections[encoder_index],
+                                                  embedding_training=self.use_conditional_layernorm_embedding_integration)
             xs, masks = encoder(xs, masks)
 
         if isinstance(xs, tuple):
             xs = xs[0]
 
-        if self.use_output_norm and not (self.utt_embed and self.conformer_type == "encoder"):
-            xs = self.output_norm(xs)
-
         if self.utt_embed and self.conformer_type == "encoder":
-            xs = integrate_with_utt_embed(hs=xs,
+            xs = integrate_with_utt_embed(hs=xs,
+                                          utt_embeddings=utterance_embedding,
+                                          projection=self.encoder_embedding_projection,
+                                          embedding_training=self.use_conditional_layernorm_embedding_integration)
+        elif self.use_output_norm:
+            xs = self.output_norm(xs)
 
         return xs, masks
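The new language-embedding branch only inserts a Linear projection when the embedding size differs from the attention dimension; otherwise it uses an identity function. A small self-contained sketch of that pattern (the dimensions here are arbitrary placeholders, not the model's actual configuration):

    import torch

    def make_language_embedding_projection(lang_emb_size, attention_dim):
        # mirrors the branch added in the diff: identity if the sizes already match
        if lang_emb_size == attention_dim:
            return lambda x: x
        return torch.nn.Linear(lang_emb_size, attention_dim)

    proj = make_language_embedding_projection(lang_emb_size=16, attention_dim=384)
    print(proj(torch.randn(2, 16)).shape)  # torch.Size([2, 384])
    identity = make_language_embedding_projection(lang_emb_size=384, attention_dim=384)
    print(identity(torch.randn(2, 384)).shape)  # unchanged, no extra parameters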
{Architectures → Modules}/GeneralLayers/Convolution.py
RENAMED
@@ -24,7 +24,7 @@ class ConvolutionModule(nn.Module):

         self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias, )
         self.depthwise_conv = nn.Conv1d(channels, channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2, groups=channels, bias=bias, )
-        self.norm = nn.BatchNorm1d(channels)
+        self.norm = nn.SyncBatchNorm.convert_sync_batchnorm(nn.BatchNorm1d(channels))
         self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1, stride=1, padding=0, bias=bias, )
         self.activation = activation

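The convolution module's batch norm is now wrapped with nn.SyncBatchNorm.convert_sync_batchnorm, which replaces the BatchNorm1d with a SyncBatchNorm layer so that batch statistics are synchronized across processes when the model is trained with DistributedDataParallel. A small standalone sketch of what the conversion does (the channel count is chosen arbitrarily):

    import torch
    from torch import nn

    channels = 4  # arbitrary illustrative size

    # convert_sync_batchnorm swaps BatchNorm layers for SyncBatchNorm in place of the given module.
    norm = nn.SyncBatchNorm.convert_sync_batchnorm(nn.BatchNorm1d(channels))
    print(type(norm).__name__)   # SyncBatchNorm

    # Without an initialized process group (e.g. in evaluation, as here) the layer still behaves
    # like a regular batch norm over (batch, channels, time) inputs.
    norm.eval()
    x = torch.randn(2, channels, 10)
    y = norm(x)
    print(y.shape)               # torch.Size([2, 4, 10])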
{Architectures → Modules}/GeneralLayers/DurationPredictor.py
RENAMED
@@ -5,9 +5,9 @@
(the three removed import lines are truncated to "from" in this view)

 import torch

-from ...
-from ...
-from ...
+from Modules.GeneralLayers.ConditionalLayerNorm import AdaIN1d
+from Modules.GeneralLayers.ConditionalLayerNorm import ConditionalLayerNorm
+from Modules.GeneralLayers.LayerNorm import LayerNorm
 from Utility.utils import integrate_with_utt_embed

{Architectures → Modules}/GeneralLayers/EncoderLayer.py
RENAMED
@@ -7,7 +7,7 @@
(the removed import line is truncated to "from" in this view)

 import torch
 from torch import nn

-from ...
+from Modules.GeneralLayers.LayerNorm import LayerNorm


 class EncoderLayer(nn.Module):

{Architectures → Modules}/GeneralLayers/LayerNorm.py (RENAMED, file without changes)
{Architectures → Modules}/GeneralLayers/LengthRegulator.py (RENAMED, file without changes)
{Architectures → Modules}/GeneralLayers/MultiLayeredConv1d.py (RENAMED, file without changes)
{Architectures → Modules}/GeneralLayers/MultiSequential.py (RENAMED, file without changes)
{Architectures → Modules}/GeneralLayers/PositionalEncoding.py (RENAMED, file without changes)
{Architectures → Modules}/GeneralLayers/PositionwiseFeedForward.py (RENAMED, file without changes)
{Architectures → Modules}/GeneralLayers/README.md (RENAMED, file without changes)
{Architectures → Modules}/GeneralLayers/ResidualBlock.py (RENAMED, file without changes)
{Architectures → Modules}/GeneralLayers/ResidualStack.py (RENAMED, file without changes)
{Architectures → Modules}/GeneralLayers/STFT.py (RENAMED, file without changes)
{Architectures → Modules}/GeneralLayers/Swish.py (RENAMED, file without changes)
{Architectures → Modules}/GeneralLayers/VariancePredictor.py
RENAMED
@@ -6,9 +6,9 @@ from abc import ABC
(the three removed import lines are truncated to "from" in this view)

 import torch

-from ...
-from ...
-from ...
+from Modules.GeneralLayers.ConditionalLayerNorm import AdaIN1d
+from Modules.GeneralLayers.ConditionalLayerNorm import ConditionalLayerNorm
+from Modules.GeneralLayers.LayerNorm import LayerNorm
 from Utility.utils import integrate_with_utt_embed

{Architectures → Modules}/GeneralLayers/__init__.py (RENAMED, file without changes)
{Architectures → Modules}/README.md (RENAMED, file without changes)
{Architectures → Modules}/ToucanTTS/CodecDiscriminator.py (RENAMED, file without changes)