|
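"""DelightfulTTS acoustic model: Conformer encoder/decoder with prosody, pitch, energy, and duration modelling."""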
|
|
from typing import Callable, Dict, Tuple |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from coqpit import Coqpit |
|
from torch import nn |
|
|
|
from TTS.tts.layers.delightful_tts.conformer import Conformer |
|
from TTS.tts.layers.delightful_tts.encoders import ( |
|
PhonemeLevelProsodyEncoder, |
|
UtteranceLevelProsodyEncoder, |
|
get_mask_from_lengths, |
|
) |
|
from TTS.tts.layers.delightful_tts.energy_adaptor import EnergyAdaptor |
|
from TTS.tts.layers.delightful_tts.networks import EmbeddingPadded, positional_encoding |
|
from TTS.tts.layers.delightful_tts.phoneme_prosody_predictor import PhonemeProsodyPredictor |
|
from TTS.tts.layers.delightful_tts.pitch_adaptor import PitchAdaptor |
|
from TTS.tts.layers.delightful_tts.variance_predictor import VariancePredictor |
|
from TTS.tts.layers.generic.aligner import AlignmentNetwork |
|
from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask |
|
|
|
|
|
class AcousticModel(torch.nn.Module): |
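    """Acoustic model combining a Conformer text encoder and decoder with utterance- and
    phoneme-level prosody encoders/predictors, pitch and energy adaptors, and an alignment
    network that provides token durations."""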
|
def __init__( |
|
self, |
|
args: "ModelArgs", |
|
tokenizer: "TTSTokenizer" = None, |
|
speaker_manager: "SpeakerManager" = None, |
|
): |
|
super().__init__() |
|
self.args = args |
|
self.tokenizer = tokenizer |
|
self.speaker_manager = speaker_manager |
|
|
|
self.init_multispeaker(args) |
|
|
|
|
|
self.length_scale = ( |
|
float(self.args.length_scale) if isinstance(self.args.length_scale, int) else self.args.length_scale |
|
) |
|
|
|
self.emb_dim = args.n_hidden_conformer_encoder |
|
self.encoder = Conformer( |
|
dim=self.args.n_hidden_conformer_encoder, |
|
n_layers=self.args.n_layers_conformer_encoder, |
|
n_heads=self.args.n_heads_conformer_encoder, |
|
speaker_embedding_dim=self.embedded_speaker_dim, |
|
p_dropout=self.args.dropout_conformer_encoder, |
|
kernel_size_conv_mod=self.args.kernel_size_conv_mod_conformer_encoder, |
|
lrelu_slope=self.args.lrelu_slope, |
|
) |
|
self.pitch_adaptor = PitchAdaptor( |
|
n_input=self.args.n_hidden_conformer_encoder, |
|
n_hidden=self.args.n_hidden_variance_adaptor, |
|
n_out=1, |
|
kernel_size=self.args.kernel_size_variance_adaptor, |
|
emb_kernel_size=self.args.emb_kernel_size_variance_adaptor, |
|
p_dropout=self.args.dropout_variance_adaptor, |
|
lrelu_slope=self.args.lrelu_slope, |
|
) |
|
self.energy_adaptor = EnergyAdaptor( |
|
channels_in=self.args.n_hidden_conformer_encoder, |
|
channels_hidden=self.args.n_hidden_variance_adaptor, |
|
channels_out=1, |
|
kernel_size=self.args.kernel_size_variance_adaptor, |
|
emb_kernel_size=self.args.emb_kernel_size_variance_adaptor, |
|
dropout=self.args.dropout_variance_adaptor, |
|
lrelu_slope=self.args.lrelu_slope, |
|
) |
|
|
|
self.aligner = AlignmentNetwork( |
|
in_query_channels=self.args.out_channels, |
|
in_key_channels=self.args.n_hidden_conformer_encoder, |
|
) |
|
|
|
self.duration_predictor = VariancePredictor( |
|
channels_in=self.args.n_hidden_conformer_encoder, |
|
channels=self.args.n_hidden_variance_adaptor, |
|
channels_out=1, |
|
kernel_size=self.args.kernel_size_variance_adaptor, |
|
p_dropout=self.args.dropout_variance_adaptor, |
|
lrelu_slope=self.args.lrelu_slope, |
|
) |
|
|
|
self.utterance_prosody_encoder = UtteranceLevelProsodyEncoder( |
|
num_mels=self.args.num_mels, |
|
ref_enc_filters=self.args.ref_enc_filters_reference_encoder, |
|
ref_enc_size=self.args.ref_enc_size_reference_encoder, |
|
ref_enc_gru_size=self.args.ref_enc_gru_size_reference_encoder, |
|
ref_enc_strides=self.args.ref_enc_strides_reference_encoder, |
|
n_hidden=self.args.n_hidden_conformer_encoder, |
|
dropout=self.args.dropout_conformer_encoder, |
|
bottleneck_size_u=self.args.bottleneck_size_u_reference_encoder, |
|
token_num=self.args.token_num_reference_encoder, |
|
) |
|
|
|
self.utterance_prosody_predictor = PhonemeProsodyPredictor( |
|
hidden_size=self.args.n_hidden_conformer_encoder, |
|
kernel_size=self.args.predictor_kernel_size_reference_encoder, |
|
dropout=self.args.dropout_conformer_encoder, |
|
bottleneck_size=self.args.bottleneck_size_u_reference_encoder, |
|
lrelu_slope=self.args.lrelu_slope, |
|
) |
|
|
|
self.phoneme_prosody_encoder = PhonemeLevelProsodyEncoder( |
|
num_mels=self.args.num_mels, |
|
ref_enc_filters=self.args.ref_enc_filters_reference_encoder, |
|
ref_enc_size=self.args.ref_enc_size_reference_encoder, |
|
ref_enc_gru_size=self.args.ref_enc_gru_size_reference_encoder, |
|
ref_enc_strides=self.args.ref_enc_strides_reference_encoder, |
|
n_hidden=self.args.n_hidden_conformer_encoder, |
|
dropout=self.args.dropout_conformer_encoder, |
|
bottleneck_size_p=self.args.bottleneck_size_p_reference_encoder, |
|
n_heads=self.args.n_heads_conformer_encoder, |
|
) |
|
|
|
self.phoneme_prosody_predictor = PhonemeProsodyPredictor( |
|
hidden_size=self.args.n_hidden_conformer_encoder, |
|
kernel_size=self.args.predictor_kernel_size_reference_encoder, |
|
dropout=self.args.dropout_conformer_encoder, |
|
bottleneck_size=self.args.bottleneck_size_p_reference_encoder, |
|
lrelu_slope=self.args.lrelu_slope, |
|
) |
|
|
|
self.u_bottle_out = nn.Linear( |
|
self.args.bottleneck_size_u_reference_encoder, |
|
self.args.n_hidden_conformer_encoder, |
|
) |
|
|
|
self.u_norm = nn.InstanceNorm1d(self.args.bottleneck_size_u_reference_encoder) |
|
self.p_bottle_out = nn.Linear( |
|
self.args.bottleneck_size_p_reference_encoder, |
|
self.args.n_hidden_conformer_encoder, |
|
) |
|
self.p_norm = nn.InstanceNorm1d( |
|
self.args.bottleneck_size_p_reference_encoder, |
|
) |
|
self.decoder = Conformer( |
|
dim=self.args.n_hidden_conformer_decoder, |
|
n_layers=self.args.n_layers_conformer_decoder, |
|
n_heads=self.args.n_heads_conformer_decoder, |
|
speaker_embedding_dim=self.embedded_speaker_dim, |
|
p_dropout=self.args.dropout_conformer_decoder, |
|
kernel_size_conv_mod=self.args.kernel_size_conv_mod_conformer_decoder, |
|
lrelu_slope=self.args.lrelu_slope, |
|
) |
|
|
|
padding_idx = self.tokenizer.characters.pad_id |
|
self.src_word_emb = EmbeddingPadded( |
|
self.args.num_chars, self.args.n_hidden_conformer_encoder, padding_idx=padding_idx |
|
) |
|
self.to_mel = nn.Linear( |
|
self.args.n_hidden_conformer_decoder, |
|
self.args.num_mels, |
|
) |
|
|
|
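        # BatchNorm used only as an energy scaler: affine=False and momentum=None keep cumulative
        # running statistics, and the module is excluded from gradient updates.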
self.energy_scaler = torch.nn.BatchNorm1d(1, affine=False, track_running_stats=True, momentum=None) |
|
self.energy_scaler.requires_grad_(False) |
|
|
|
def init_multispeaker(self, args: Coqpit): |
|
"""Init for multi-speaker training.""" |
|
self.embedded_speaker_dim = 0 |
|
self.num_speakers = self.args.num_speakers |
|
self.audio_transform = None |
|
|
|
if self.speaker_manager: |
|
self.num_speakers = self.speaker_manager.num_speakers |
|
|
|
if self.args.use_speaker_embedding: |
|
self._init_speaker_embedding() |
|
|
|
if self.args.use_d_vector_file: |
|
self._init_d_vector() |
|
|
|
@staticmethod |
|
def _set_cond_input(aux_input: Dict): |
|
"""Set the speaker conditioning input based on the multi-speaker mode.""" |
|
sid, g, lid, durations = None, None, None, None |
|
if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None: |
|
sid = aux_input["speaker_ids"] |
|
if sid.ndim == 0: |
|
sid = sid.unsqueeze_(0) |
|
if "d_vectors" in aux_input and aux_input["d_vectors"] is not None: |
|
g = F.normalize(aux_input["d_vectors"]) |
|
|
|
|
if "durations" in aux_input and aux_input["durations"] is not None: |
|
durations = aux_input["durations"] |
|
|
|
return sid, g, lid, durations |
|
|
|
def get_aux_input(self, aux_input: Dict): |
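        """Return the conditioning inputs (speaker ids, d-vectors, language ids) as a dictionary."""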
|
sid, g, lid, _ = self._set_cond_input(aux_input) |
|
return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid} |
|
|
|
def _set_speaker_input(self, aux_input: Dict): |
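        """Select the speaker conditioning signal, allowing either speaker ids or d-vectors but not both."""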
|
d_vectors = aux_input.get("d_vectors", None) |
|
speaker_ids = aux_input.get("speaker_ids", None) |
|
|
|
if d_vectors is not None and speaker_ids is not None: |
|
raise ValueError("[!] Cannot use d-vectors and speaker-ids together.") |
|
|
|
if speaker_ids is not None and not hasattr(self, "emb_g"): |
|
raise ValueError("[!] Cannot use speaker-ids without enabling speaker embedding.") |
|
|
|
g = speaker_ids if speaker_ids is not None else d_vectors |
|
return g |
|
|
|
def _init_speaker_embedding(self): |
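        """Create the speaker embedding layer used for speaker-id conditioning."""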
|
|
|
if self.num_speakers > 0: |
|
print(" > initialization of speaker-embedding layers.") |
|
self.embedded_speaker_dim = self.args.speaker_embedding_channels |
|
self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) |
|
|
|
def _init_d_vector(self): |
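        """Configure the model for external d-vector conditioning."""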
|
|
|
if hasattr(self, "emb_g"): |
|
raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.") |
|
self.embedded_speaker_dim = self.args.d_vector_dim |
|
|
|
@staticmethod |
|
    def generate_attn(dr, x_mask, y_mask=None):
        """Generate an attention map from the linear-scale durations.
|
|
|
Args: |
|
dr (Tensor): Linear scale durations. |
|
x_mask (Tensor): Mask for the input (character) sequence. |
|
y_mask (Tensor): Mask for the output (spectrogram) sequence. Compute it from the predicted durations |
|
if None. Defaults to None. |
|
|
|
Shapes |
|
- dr: :math:`(B, T_{en})` |
|
- x_mask: :math:`(B, T_{en})` |
|
- y_mask: :math:`(B, T_{de})` |
|
""" |
|
|
|
if y_mask is None: |
|
y_lengths = dr.sum(1).long() |
|
y_lengths[y_lengths < 1] = 1 |
|
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype) |
|
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) |
|
attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype) |
|
return attn |
|
|
|
def _expand_encoder_with_durations( |
|
self, |
|
o_en: torch.FloatTensor, |
|
dr: torch.IntTensor, |
|
x_mask: torch.IntTensor, |
|
y_lengths: torch.IntTensor, |
|
): |
|
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype) |
|
attn = self.generate_attn(dr, x_mask, y_mask) |
|
o_en_ex = torch.einsum("kmn, kjm -> kjn", [attn.float(), o_en]) |
|
return y_mask, o_en_ex, attn.transpose(1, 2) |
|
|
|
def _forward_aligner( |
|
self, |
|
x: torch.FloatTensor, |
|
y: torch.FloatTensor, |
|
x_mask: torch.IntTensor, |
|
y_mask: torch.IntTensor, |
|
attn_priors: torch.FloatTensor, |
|
) -> Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: |
|
"""Aligner forward pass. |
|
|
|
1. Compute a mask to apply to the attention map. |
|
2. Run the alignment network. |
|
3. Apply MAS to compute the hard alignment map. |
|
4. Compute the durations from the hard alignment map. |
|
|
|
Args: |
|
x (torch.FloatTensor): Input sequence. |
|
y (torch.FloatTensor): Output sequence. |
|
x_mask (torch.IntTensor): Input sequence mask. |
|
y_mask (torch.IntTensor): Output sequence mask. |
|
attn_priors (torch.FloatTensor): Prior for the aligner network map. |
|
|
|
Returns: |
|
Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: |
|
Durations from the hard alignment map, soft alignment potentials, log scale alignment potentials, |
|
hard alignment map. |
|
|
|
Shapes: |
|
- x: :math:`[B, T_en, C_en]` |
|
- y: :math:`[B, T_de, C_de]` |
|
- x_mask: :math:`[B, 1, T_en]` |
|
- y_mask: :math:`[B, 1, T_de]` |
|
- attn_priors: :math:`[B, T_de, T_en]` |
|
|
|
- aligner_durations: :math:`[B, T_en]` |
|
- aligner_soft: :math:`[B, T_de, T_en]` |
|
- aligner_logprob: :math:`[B, 1, T_de, T_en]` |
|
- aligner_mas: :math:`[B, T_de, T_en]` |
|
""" |
|
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) |
|
aligner_soft, aligner_logprob = self.aligner(y.transpose(1, 2), x.transpose(1, 2), x_mask, attn_priors) |
|
aligner_mas = maximum_path( |
|
aligner_soft.squeeze(1).transpose(1, 2).contiguous(), attn_mask.squeeze(1).contiguous() |
|
) |
|
aligner_durations = torch.sum(aligner_mas, -1).int() |
|
aligner_soft = aligner_soft.squeeze(1) |
|
aligner_mas = aligner_mas.transpose(1, 2) |
|
return aligner_durations, aligner_soft, aligner_logprob, aligner_mas |
|
|
|
def average_utterance_prosody( |
|
self, u_prosody_pred: torch.Tensor, src_mask: torch.Tensor |
|
) -> torch.Tensor: |
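        """Average the predicted utterance prosody over the non-padded time steps."""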
|
lengths = ((~src_mask) * 1.0).sum(1) |
|
u_prosody_pred = u_prosody_pred.sum(1, keepdim=True) / lengths.view(-1, 1, 1) |
|
return u_prosody_pred |
|
|
|
def forward( |
|
self, |
|
tokens: torch.Tensor, |
|
src_lens: torch.Tensor, |
|
mels: torch.Tensor, |
|
mel_lens: torch.Tensor, |
|
pitches: torch.Tensor, |
|
energies: torch.Tensor, |
|
attn_priors: torch.Tensor, |
|
use_ground_truth: bool = True, |
|
d_vectors: torch.Tensor = None, |
|
speaker_idx: torch.Tensor = None, |
|
) -> Dict[str, torch.Tensor]: |
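        """Training forward pass: align tokens to mel frames, encode, add prosody, pitch and energy,
        expand to frame level with the aligner durations, and decode to mel frames."""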
|
sid, g, lid, _ = self._set_cond_input( |
|
{"d_vectors": d_vectors, "speaker_ids": speaker_idx} |
|
) |
|
|
|
src_mask = get_mask_from_lengths(src_lens) |
|
mel_mask = get_mask_from_lengths(mel_lens) |
|
|
|
|
|
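        # Embed the input tokens and zero out the padded positions.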
token_embeddings = self.src_word_emb(tokens) |
|
token_embeddings = token_embeddings.masked_fill(src_mask.unsqueeze(-1), 0.0) |
|
|
|
|
|
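        # Alignment network + monotonic alignment search: soft/hard alignments between tokens and
        # mel frames, and the per-token durations derived from the hard alignment.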
aligner_durations, aligner_soft, aligner_logprob, aligner_mas = self._forward_aligner( |
|
x=token_embeddings, |
|
y=mels.transpose(1, 2), |
|
x_mask=~src_mask[:, None], |
|
y_mask=~mel_mask[:, None], |
|
attn_priors=attn_priors, |
|
) |
|
dr = aligner_durations |
|
|
|
|
|
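        # Speaker conditioning: prefer externally provided d-vectors, otherwise use the learned embedding.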
speaker_embedding = None |
|
if d_vectors is not None: |
|
speaker_embedding = g |
|
elif speaker_idx is not None: |
|
speaker_embedding = F.normalize(self.emb_g(sid)) |
|
|
|
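        # Positional encoding sized for the longer of the token and mel sequences; shared by encoder and decoder.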
pos_encoding = positional_encoding( |
|
self.emb_dim, |
|
max(token_embeddings.shape[1], max(mel_lens)), |
|
device=token_embeddings.device, |
|
) |
|
encoder_outputs = self.encoder( |
|
token_embeddings, |
|
src_mask, |
|
speaker_embedding=speaker_embedding, |
|
encoding=pos_encoding, |
|
) |
|
|
|
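        # Utterance-level prosody: a reference embedding from the target mels and a prediction from the text encoding.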
u_prosody_ref = self.u_norm(self.utterance_prosody_encoder(mels=mels, mel_lens=mel_lens)) |
|
u_prosody_pred = self.u_norm( |
|
self.average_utterance_prosody( |
|
u_prosody_pred=self.utterance_prosody_predictor(x=encoder_outputs, mask=src_mask), |
|
src_mask=src_mask, |
|
) |
|
) |
|
|
|
if use_ground_truth: |
|
encoder_outputs = encoder_outputs + self.u_bottle_out(u_prosody_ref) |
|
else: |
|
encoder_outputs = encoder_outputs + self.u_bottle_out(u_prosody_pred) |
|
|
|
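        # Phoneme-level prosody: a reference embedding from the mels and a prediction from the encoder outputs.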
p_prosody_ref = self.p_norm( |
|
self.phoneme_prosody_encoder( |
|
x=encoder_outputs, src_mask=src_mask, mels=mels, mel_lens=mel_lens, encoding=pos_encoding |
|
) |
|
) |
|
p_prosody_pred = self.p_norm(self.phoneme_prosody_predictor(x=encoder_outputs, mask=src_mask)) |
|
|
|
if use_ground_truth: |
|
encoder_outputs = encoder_outputs + self.p_bottle_out(p_prosody_ref) |
|
else: |
|
encoder_outputs = encoder_outputs + self.p_bottle_out(p_prosody_pred) |
|
|
|
encoder_outputs_res = encoder_outputs |
|
|
|
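        # Pitch and energy adaptors: predict per-token values and embed the ground-truth targets averaged over durations.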
pitch_pred, avg_pitch_target, pitch_emb = self.pitch_adaptor.get_pitch_embedding_train( |
|
x=encoder_outputs, |
|
target=pitches, |
|
dr=dr, |
|
mask=src_mask, |
|
) |
|
|
|
energy_pred, avg_energy_target, energy_emb = self.energy_adaptor.get_energy_embedding_train( |
|
x=encoder_outputs, |
|
target=energies, |
|
dr=dr, |
|
mask=src_mask, |
|
) |
|
|
|
encoder_outputs = encoder_outputs.transpose(1, 2) + pitch_emb + energy_emb |
|
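        # Predict log-durations from the detached text encoding so the duration loss does not reach the encoder.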
log_duration_prediction = self.duration_predictor(x=encoder_outputs_res.detach(), mask=src_mask) |
|
|
|
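        # Expand token-level features to frame level using the aligner durations.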
mel_pred_mask, encoder_outputs_ex, alignments = self._expand_encoder_with_durations( |
|
o_en=encoder_outputs, y_lengths=mel_lens, dr=dr, x_mask=~src_mask[:, None] |
|
) |
|
|
|
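        # Decode the expanded sequence and project to mel bins.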
x = self.decoder( |
|
encoder_outputs_ex.transpose(1, 2), |
|
mel_mask, |
|
speaker_embedding=speaker_embedding, |
|
encoding=pos_encoding, |
|
) |
|
x = self.to_mel(x) |
|
|
|
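        # Duration targets are supervised in the log domain; also expand the predicted durations into an attention map.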
dr = torch.log(dr + 1) |
|
|
|
dr_pred = torch.exp(log_duration_prediction) - 1 |
|
alignments_dp = self.generate_attn(dr_pred, src_mask.unsqueeze(1), mel_pred_mask) |
|
|
|
return { |
|
"model_outputs": x, |
|
"pitch_pred": pitch_pred, |
|
"pitch_target": avg_pitch_target, |
|
"energy_pred": energy_pred, |
|
"energy_target": avg_energy_target, |
|
"u_prosody_pred": u_prosody_pred, |
|
"u_prosody_ref": u_prosody_ref, |
|
"p_prosody_pred": p_prosody_pred, |
|
"p_prosody_ref": p_prosody_ref, |
|
"alignments_dp": alignments_dp, |
|
"alignments": alignments, |
|
"aligner_soft": aligner_soft, |
|
"aligner_mas": aligner_mas, |
|
"aligner_durations": aligner_durations, |
|
"aligner_logprob": aligner_logprob, |
|
"dr_log_pred": log_duration_prediction.squeeze(1), |
|
"dr_log_target": dr.squeeze(1), |
|
"spk_emb": speaker_embedding, |
|
} |
|
|
|
@torch.no_grad() |
|
def inference( |
|
self, |
|
tokens: torch.Tensor, |
|
speaker_idx: torch.Tensor, |
|
p_control: float = None, |
|
d_control: float = None, |
|
d_vectors: torch.Tensor = None, |
|
pitch_transform: Callable = None, |
|
energy_transform: Callable = None, |
|
    ) -> Dict[str, torch.Tensor]:
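        """Synthesize mel frames from tokens using predicted durations, pitch, energy, and prosody."""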
|
src_mask = get_mask_from_lengths(torch.tensor([tokens.shape[1]], dtype=torch.int64, device=tokens.device)) |
|
src_lens = torch.tensor(tokens.shape[1:2]).to(tokens.device) |
|
sid, g, lid, _ = self._set_cond_input( |
|
{"d_vectors": d_vectors, "speaker_ids": speaker_idx} |
|
) |
|
|
|
token_embeddings = self.src_word_emb(tokens) |
|
token_embeddings = token_embeddings.masked_fill(src_mask.unsqueeze(-1), 0.0) |
|
|
|
|
|
speaker_embedding = None |
|
if d_vectors is not None: |
|
speaker_embedding = g |
|
elif speaker_idx is not None: |
|
speaker_embedding = F.normalize(self.emb_g(sid)) |
|
|
|
pos_encoding = positional_encoding( |
|
self.emb_dim, |
|
token_embeddings.shape[1], |
|
device=token_embeddings.device, |
|
) |
|
encoder_outputs = self.encoder( |
|
token_embeddings, |
|
src_mask, |
|
speaker_embedding=speaker_embedding, |
|
encoding=pos_encoding, |
|
) |
|
|
|
u_prosody_pred = self.u_norm( |
|
self.average_utterance_prosody( |
|
u_prosody_pred=self.utterance_prosody_predictor(x=encoder_outputs, mask=src_mask), |
|
src_mask=src_mask, |
|
) |
|
) |
|
encoder_outputs = encoder_outputs + self.u_bottle_out(u_prosody_pred).expand_as(encoder_outputs) |
|
|
|
p_prosody_pred = self.p_norm( |
|
self.phoneme_prosody_predictor( |
|
x=encoder_outputs, |
|
mask=src_mask, |
|
) |
|
) |
|
encoder_outputs = encoder_outputs + self.p_bottle_out(p_prosody_pred).expand_as(encoder_outputs) |
|
|
|
encoder_outputs_res = encoder_outputs |
|
|
|
pitch_emb_pred, pitch_pred = self.pitch_adaptor.get_pitch_embedding( |
|
x=encoder_outputs, |
|
mask=src_mask, |
|
pitch_transform=pitch_transform, |
|
pitch_mean=self.pitch_mean if hasattr(self, "pitch_mean") else None, |
|
pitch_std=self.pitch_std if hasattr(self, "pitch_std") else None, |
|
) |
|
|
|
energy_emb_pred, energy_pred = self.energy_adaptor.get_energy_embedding( |
|
x=encoder_outputs, mask=src_mask, energy_transform=energy_transform |
|
) |
|
encoder_outputs = encoder_outputs.transpose(1, 2) + pitch_emb_pred + energy_emb_pred |
|
|
|
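        # Predict log-durations, convert to linear scale, apply the length scale, clamp to a minimum of one frame, and round to whole frames.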
log_duration_pred = self.duration_predictor( |
|
x=encoder_outputs_res.detach(), mask=src_mask |
|
) |
|
duration_pred = (torch.exp(log_duration_pred) - 1) * (~src_mask) * self.length_scale |
|
duration_pred[duration_pred < 1] = 1.0 |
|
duration_pred = torch.round(duration_pred) |
|
mel_lens = duration_pred.sum(1) |
|
|
|
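        # Expand token-level features to frame level with the predicted durations.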
_, encoder_outputs_ex, alignments = self._expand_encoder_with_durations( |
|
o_en=encoder_outputs, y_lengths=mel_lens, dr=duration_pred.squeeze(1), x_mask=~src_mask[:, None] |
|
) |
|
|
|
mel_mask = get_mask_from_lengths( |
|
torch.tensor([encoder_outputs_ex.shape[2]], dtype=torch.int64, device=encoder_outputs_ex.device) |
|
) |
|
|
|
        # Regenerate the positional encoding when the expanded (frame-level) sequence is longer than
        # the token-level one computed above; otherwise reuse it so `encoding` is always defined.
        encoding = pos_encoding
        if encoder_outputs_ex.shape[2] > pos_encoding.shape[1]:
            encoding = positional_encoding(self.emb_dim, encoder_outputs_ex.shape[2], device=tokens.device)
|
|
|
|
|
x = self.decoder( |
|
encoder_outputs_ex.transpose(1, 2), |
|
mel_mask, |
|
speaker_embedding=speaker_embedding, |
|
encoding=encoding, |
|
) |
|
x = self.to_mel(x) |
|
outputs = { |
|
"model_outputs": x, |
|
"alignments": alignments, |
|
"durations": duration_pred, |
|
"pitch": pitch_pred, |
|
"energy": energy_pred, |
|
"spk_emb": speaker_embedding, |
|
} |
|
return outputs |
|
|