|
from typing import List, Tuple |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from torch import nn |
|
from tqdm.auto import tqdm |
|
|
|
from TTS.tts.layers.tacotron.common_layers import Linear |
|
from TTS.tts.layers.tacotron.tacotron2 import ConvBNBlock |
|
|
|
|
|
class Encoder(nn.Module):
    r"""Neural HMM Encoder

    Same as Tacotron 2 encoder but increases the input length by states per phone

    Args:
        num_chars (int): Number of characters in the input.
        state_per_phone (int): Number of states per phone.
        in_out_channels (int): number of input and output channels.
        n_convolutions (int): number of convolutional layers.
    """

    def __init__(self, num_chars, state_per_phone, in_out_channels=512, n_convolutions=3):
        super().__init__()

        self.state_per_phone = state_per_phone
        self.in_out_channels = in_out_channels

        self.emb = nn.Embedding(num_chars, in_out_channels)
        self.convolutions = nn.ModuleList(
            [ConvBNBlock(in_out_channels, in_out_channels, 5, "relu") for _ in range(n_convolutions)]
        )
        # Each direction emits (in_out_channels // 2) * state_per_phone features, so the
        # bidirectional output per time step holds in_out_channels * state_per_phone values
        # that forward()/inference() reshape into `state_per_phone` separate states.
        self.lstm = nn.LSTM(
            in_out_channels,
            (in_out_channels // 2) * state_per_phone,
            num_layers=1,
            batch_first=True,
            bias=True,
            bidirectional=True,
        )
        self.rnn_state = None

    def _embed_and_convolve(self, x):
        """Embed the token indices and run them through the convolution stack.

        The embedding output is transposed to channels-first for the convolution
        blocks and transposed back afterwards.
        """
        o = self.emb(x).transpose(1, 2)
        for layer in self.convolutions:
            o = layer(o)
        return o.transpose(1, 2)

    def forward(self, x: torch.LongTensor, x_len: torch.LongTensor) -> Tuple[torch.FloatTensor, torch.LongTensor]:
        """Forward pass to the encoder.

        Args:
            x (torch.LongTensor): input text indices.
                - shape: :math:`(b, T_{in})`
            x_len (torch.LongTensor): input text lengths.
                - shape: :math:`(b,)`

        Returns:
            Tuple[torch.FloatTensor, torch.LongTensor]: encoder outputs and output lengths.
                -shape: :math:`((b, T_{in} * states_per_phone, in_out_channels), (b,))`
        """
        b, T = x.shape
        o = self._embed_and_convolve(x)
        # enforce_sorted=False lets callers pass batches that are not sorted by length.
        o = nn.utils.rnn.pack_padded_sequence(o, x_len.cpu(), batch_first=True, enforce_sorted=False)
        self.lstm.flatten_parameters()
        o, _ = self.lstm(o)
        # Pad back to the full input length T (not just max(x_len)); otherwise the reshape
        # below fails whenever the batch is padded beyond its longest sequence.
        o, _ = nn.utils.rnn.pad_packed_sequence(o, batch_first=True, total_length=T)
        o = o.reshape(b, T * self.state_per_phone, self.in_out_channels)
        return o, x_len * self.state_per_phone

    def inference(self, x, x_len):
        """Inference to the encoder.

        Unlike :meth:`forward`, the sequence is not packed, so the LSTM also runs over
        any padded positions.

        Args:
            x (torch.LongTensor): input text indices.
                - shape: :math:`(b, T_{in})`
            x_len (torch.LongTensor): input text lengths.
                - shape: :math:`(b,)`

        Returns:
            Tuple[torch.FloatTensor, torch.LongTensor]: encoder outputs and output lengths.
                -shape: :math:`((b, T_{in} * states_per_phone, in_out_channels), (b,))`
        """
        b, T = x.shape
        o = self._embed_and_convolve(x)
        o, _ = self.lstm(o)
        o = o.reshape(b, T * self.state_per_phone, self.in_out_channels)
        return o, x_len * self.state_per_phone
|
|
|
|
|
class ParameterModel(nn.Module):
    r"""Main neural network of the outputnet

    Note: Do not put dropout layers here, the model will not converge.

    Args:
        outputnet_size (List[int]): the architecture of the parameter model
            (hidden layer widths; may be empty, in which case only the output layer is used)
        input_size (int): size of input for the first layer
        output_size (int): size of output i.e size of the feature dim
        frame_channels (int): feature dim to set the flat start bias
        flat_start_params (dict): flat start parameters to set the bias;
            read keys are ``"mean"``, ``"std"`` and ``"transition_p"``
    """

    def __init__(
        self,
        outputnet_size: List[int],
        input_size: int,
        output_size: int,
        frame_channels: int,
        flat_start_params: dict,
    ):
        super().__init__()
        self.frame_channels = frame_channels

        # Chain of layer widths: input -> hidden... ; consecutive pairs define the hidden layers.
        # Building the chain this way also covers an empty `outputnet_size`.
        layer_sizes = [input_size, *outputnet_size]
        self.layers = nn.ModuleList(
            [Linear(n_in, n_out) for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])]
        )
        self.last_layer = nn.Linear(layer_sizes[-1], output_size)
        self.flat_start_output_layer(
            flat_start_params["mean"], flat_start_params["std"], flat_start_params["transition_p"]
        )

    def flat_start_output_layer(self, mean, std, transition_p):
        """Reset the output layer so that it initially ignores its input.

        The weights are zeroed and the bias is filled in three segments:
        ``[0, frame_channels)`` with ``mean``, ``[frame_channels, 2 * frame_channels)``
        with the inverse softplus of ``std`` and the remainder with the inverse
        sigmoid of ``transition_p``.
        """
        fc = self.frame_channels
        self.last_layer.weight.data.zero_()
        bias = self.last_layer.bias.data
        bias[0:fc] = mean
        bias[fc : 2 * fc] = OverflowUtils.inverse_softplus(std)
        bias[2 * fc :] = OverflowUtils.inverse_sigmod(transition_p)

    def forward(self, x):
        """Apply every hidden layer followed by ReLU, then the linear output layer."""
        for layer in self.layers:
            x = F.relu(layer(x))
        return self.last_layer(x)
|
|
|
|
|
class Outputnet(nn.Module):
    r"""
    This network takes current state and previous observed values as input
    and returns its parameters, mean, standard deviation and probability
    of transition to the next state
    """

    def __init__(
        self,
        encoder_dim: int,
        memory_rnn_dim: int,
        frame_channels: int,
        outputnet_size: List[int],
        flat_start_params: dict,
        std_floor: float = 1e-2,
    ):
        super().__init__()

        self.frame_channels = frame_channels
        self.flat_start_params = flat_start_params
        self.std_floor = std_floor

        # The parameter model sees the autoregressive features concatenated with a state vector
        # and emits `frame_channels` means, `frame_channels` stds and one transition value.
        input_size = memory_rnn_dim + encoder_dim
        output_size = 2 * frame_channels + 1

        self.parametermodel = ParameterModel(
            outputnet_size=outputnet_size,
            input_size=input_size,
            output_size=output_size,
            flat_start_params=flat_start_params,
            frame_channels=frame_channels,
        )

    def forward(self, ar_mels, inputs):
        r"""Inputs observation and returns the means, stds and transition probability for the current state

        Args:
            ar_mels (torch.FloatTensor): shape (batch, prenet_dim)
            inputs (torch.FloatTensor): (batch, hidden_states, hidden_state_dim)

        Returns:
            means: means for the emission observation for each feature
                - shape: (B, hidden_states, feature_size)
            stds: standard deviations for the emission observation for each feature
                - shape: (batch, hidden_states, feature_size)
            transition_vectors: transition vector for the current hidden state
                - shape: (batch, hidden_states)
        """
        batch_size, prenet_dim = ar_mels.shape[0], ar_mels.shape[1]
        n_states = inputs.shape[1]

        # Repeat the autoregressive features for every state and pair them with that state's vector.
        ar_per_state = ar_mels.unsqueeze(1).expand(batch_size, n_states, prenet_dim)
        params = self.parametermodel(torch.cat((ar_per_state, inputs), dim=2))

        fc = self.frame_channels
        mean = params[:, :, :fc]
        std = params[:, :, fc : 2 * fc]
        transition_vector = params[:, :, 2 * fc :].squeeze(2)

        # softplus keeps the std positive; the floor keeps it away from zero.
        std = self._floor_std(F.softplus(std))
        return mean, std, transition_vector

    def _floor_std(self, std):
        r"""
        It clamps the standard deviation to not to go below some level
        This removes the problem when the model tries to cheat for higher likelihoods by converting
        one of the gaussians to a point mass.

        Args:
            std (float Tensor): tensor containing the standard deviation to be floored

        Returns:
            float Tensor: ``std`` with every value below ``self.std_floor`` raised to it
        """
        # clamp is out-of-place, so `std` itself still holds the unfloored values for comparison.
        floored = torch.clamp(std, min=self.std_floor)
        if torch.any(floored != std):
            print(
                "[*] Standard deviation was floored! The model is preventing overfitting, nothing serious to worry about"
            )
        return floored
|
|
|
|
|
class OverflowUtils:
    """Static helpers for flat-start initialisation and numerically safe tensor math."""

    @staticmethod
    def get_data_parameters_for_flat_start(
        data_loader: torch.utils.data.DataLoader, out_channels: int, states_per_phone: int
    ):
        """Generates data parameters for flat starting the HMM.

        Every batch is expected to be a mapping with the keys ``"token_id_lengths"``,
        ``"mel"`` and ``"mel_lengths"``.

        Args:
            data_loader (torch.utils.data.Dataloader): loader iterated once to accumulate the statistics
            out_channels (int): mel spectrogram channels
            states_per_phone (int): HMM states per phone

        Returns:
            Tuple: ``(data_mean, data_std, transition_p)`` where the first two are the mean and
            standard deviation of the mel values and the last is the inverse of the average
            number of frames per token, multiplied by ``states_per_phone``.
        """
        # State related information for the transition probability
        total_state_len = 0
        total_mel_len = 0

        # Useful for data mean and std
        total_mel_sum = 0
        total_mel_sq_sum = 0

        for batch in tqdm(data_loader, leave=False):
            mels = batch["mel"]

            total_state_len += torch.sum(batch["token_id_lengths"])
            total_mel_len += torch.sum(batch["mel_lengths"])
            # NOTE(review): the sums run over the whole (padded) tensor while the count below uses
            # the true lengths -- assumes padded frames are zero, confirm against the collate fn.
            total_mel_sum += torch.sum(mels)
            total_mel_sq_sum += torch.sum(torch.pow(mels, 2))

        n_values = total_mel_len * out_channels
        data_mean = total_mel_sum / n_values
        variance = (total_mel_sq_sum / n_values) - torch.pow(data_mean, 2)
        # E[x^2] - E[x]^2 can come out slightly negative through rounding; clamp so sqrt never gives nan.
        data_std = torch.sqrt(torch.clamp(variance, min=0.0))

        # Average frames per token; the dataset size cancels out of the ratio of the two averages.
        average_duration_each_state = total_mel_len / total_state_len
        init_transition_prob = 1 / average_duration_each_state

        return data_mean, data_std, (init_transition_prob * states_per_phone)

    @staticmethod
    @torch.no_grad()
    def update_flat_start_transition(model, transition_p):
        """Re-run the flat start of the model's output layer with mean 0.0, std 1.0 and ``transition_p``."""
        model.neural_hmm.output_net.parametermodel.flat_start_output_layer(0.0, 1.0, transition_p)

    @staticmethod
    def log_clamped(x, eps=1e-04):
        """
        Avoids the log(0) problem

        Args:
            x (torch.tensor): input tensor
            eps (float, optional): lower bound. Defaults to 1e-04.

        Returns:
            torch.tensor: :math:`log(x)`
        """
        return torch.log(torch.clamp(x, min=eps))

    @staticmethod
    def inverse_sigmod(x):
        r"""
        Inverse of the sigmoid function

        Non-tensor inputs are converted to tensors. The odds ``x / (1 - x)`` are clamped
        from below by :meth:`log_clamped` before taking the log.
        """
        if not torch.is_tensor(x):
            x = torch.tensor(x)
        return OverflowUtils.log_clamped(x / (1.0 - x))

    @staticmethod
    def inverse_softplus(x):
        r"""
        Inverse of the softplus function

        Computes :math:`\log(e^x - 1)` with ``e^x - 1`` clamped from below by
        :meth:`log_clamped`. Large inputs use the equivalent form
        :math:`x + \log(1 - e^{-x})`, so the result stays finite where ``exp(x)``
        would overflow. Non-tensor inputs are converted to tensors.
        """
        if not torch.is_tensor(x):
            x = torch.tensor(x)
        if not torch.is_floating_point(x):
            x = x.to(torch.get_default_dtype())

        threshold = 20.0
        # Each branch only ever sees inputs from its own side of the threshold, so neither
        # produces inf/nan for values that torch.where ends up discarding.
        below = torch.clamp(x, max=threshold)
        above = torch.clamp(x, min=threshold)
        small_branch = OverflowUtils.log_clamped(torch.expm1(below))
        large_branch = above + torch.log1p(-torch.exp(-above))
        return torch.where(x > threshold, large_branch, small_branch)

    @staticmethod
    def logsumexp(x, dim):
        r"""
        Differentiable LogSumExp: Does not create nan gradients
            when all the inputs are -inf; yields 0 gradients instead.
        Args:
            x : torch.Tensor -  The input tensor
            dim: int - The dimension on which the log sum exp has to be applied
        """
        max_vals, _ = x.max(dim=dim)
        all_neg_inf = max_vals == -float("inf")
        # Shift by 0 instead of -inf where the whole slice is -inf, avoiding (-inf) - (-inf) = nan.
        shift = max_vals.masked_fill(all_neg_inf, 0)
        summed = (x - shift.unsqueeze(dim=dim)).exp().sum(dim=dim)
        # log(1) = 0 for the fully masked slices; the -inf is restored through the shift term.
        return summed.masked_fill(all_neg_inf, 1).log() + shift.masked_fill(all_neg_inf, -float("inf"))

    @staticmethod
    def double_pad(list_of_different_shape_tensors):
        r"""
        Pads the list of tensors in 2 dimensions

        The last dimension of every tensor is right-padded to the largest size of
        dimension 1 found in the list, then the tensors are stacked with
        ``pad_sequence`` (batch first), which pads the first dimension.
        """
        second_dim_max = max(t.shape[1] for t in list_of_different_shape_tensors)
        padded = [F.pad(t, (0, second_dim_max - t.shape[1])) for t in list_of_different_shape_tensors]
        return nn.utils.rnn.pad_sequence(padded, batch_first=True)
|
|