File size: 11,736 Bytes
9b2107c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 |
from typing import List, Tuple
import torch
import torch.nn.functional as F
from torch import nn
from tqdm.auto import tqdm
from TTS.tts.layers.tacotron.common_layers import Linear
from TTS.tts.layers.tacotron.tacotron2 import ConvBNBlock
class Encoder(nn.Module):
    r"""Neural HMM Encoder

    Same as Tacotron 2 encoder but increases the input length by states per phone:
    the bidirectional LSTM emits ``state_per_phone`` times more features per input
    token, which are then reshaped into ``state_per_phone`` separate time steps.

    Args:
        num_chars (int): Number of characters in the input.
        state_per_phone (int): Number of states per phone.
        in_out_channels (int): number of input and output channels.
        n_convolutions (int): number of convolutional layers.
    """

    def __init__(self, num_chars, state_per_phone, in_out_channels=512, n_convolutions=3):
        super().__init__()

        self.state_per_phone = state_per_phone
        self.in_out_channels = in_out_channels

        self.emb = nn.Embedding(num_chars, in_out_channels)
        self.convolutions = nn.ModuleList()
        for _ in range(n_convolutions):
            self.convolutions.append(ConvBNBlock(in_out_channels, in_out_channels, 5, "relu"))
        # Each direction outputs (in_out_channels // 2) * state_per_phone features, so the
        # concatenated bidirectional output can be reshaped into `state_per_phone` vectors
        # of size `in_out_channels` per input token (for even `in_out_channels`).
        self.lstm = nn.LSTM(
            in_out_channels,
            (in_out_channels // 2) * state_per_phone,
            num_layers=1,
            batch_first=True,
            bias=True,
            bidirectional=True,
        )
        self.rnn_state = None

    def _embed_and_convolve(self, x):
        """Embed the token indices and run the convolution stack.

        Args:
            x (torch.LongTensor): input text indices.
                - shape: :math:`(b, T_{in})`

        Returns:
            torch.FloatTensor: features laid out batch-first for the LSTM.
                - shape: :math:`(b, T_{in}, in_out_channels)`
        """
        # The convolution blocks are applied channel-first, the LSTM is batch-first.
        o = self.emb(x).transpose(1, 2)
        for layer in self.convolutions:
            o = layer(o)
        return o.transpose(1, 2)

    def forward(self, x: torch.LongTensor, x_len: torch.LongTensor) -> Tuple[torch.FloatTensor, torch.LongTensor]:
        """Forward pass to the encoder.

        The sequences are packed before the LSTM, so ``x_len`` must be sorted in
        decreasing order (``pack_padded_sequence`` is called with its default
        ``enforce_sorted=True``).

        Args:
            x (torch.LongTensor): input text indices.
                - shape: :math:`(b, T_{in})`
            x_len (torch.LongTensor): input text lengths.
                - shape: :math:`(b,)`

        Returns:
            Tuple[torch.FloatTensor, torch.LongTensor]: encoder outputs and output lengths.
                -shape: :math:`((b, T_{in} * states_per_phone, in_out_channels), (b,))`
        """
        b, T = x.shape
        o = self._embed_and_convolve(x)
        o = nn.utils.rnn.pack_padded_sequence(o, x_len.cpu(), batch_first=True)
        self.lstm.flatten_parameters()
        o, _ = self.lstm(o)
        # Pad back to the full input width T rather than to max(x_len): otherwise the
        # reshape below fails whenever the longest sequence in this batch is shorter
        # than the padded input (e.g. a shard of a larger, globally padded batch).
        o, _ = nn.utils.rnn.pad_packed_sequence(o, batch_first=True, total_length=T)
        o = o.reshape(b, T * self.state_per_phone, self.in_out_channels)
        x_len = x_len * self.state_per_phone
        return o, x_len

    def inference(self, x, x_len):
        """Inference to the encoder.

        Unlike :meth:`forward`, the sequences are not packed, so the LSTM also
        processes any padded positions of ``x``.

        Args:
            x (torch.LongTensor): input text indices.
                - shape: :math:`(b, T_{in})`
            x_len (torch.LongTensor): input text lengths.
                - shape: :math:`(b,)`

        Returns:
            Tuple[torch.FloatTensor, torch.LongTensor]: encoder outputs and output lengths.
                -shape: :math:`((b, T_{in} * states_per_phone, in_out_channels), (b,))`
        """
        b, T = x.shape
        o = self._embed_and_convolve(x)
        o, _ = self.lstm(o)
        o = o.reshape(b, T * self.state_per_phone, self.in_out_channels)
        x_len = x_len * self.state_per_phone
        return o, x_len
class ParameterModel(nn.Module):
    r"""Main neural network of the outputnet

    A stack of ReLU-activated linear layers followed by a plain linear output
    layer whose weights start at zero and whose bias is set from the flat-start
    statistics.

    Note: Do not put dropout layers here, the model will not converge.

    Args:
        outputnet_size (List[int]): the architecture of the parameter model
            (output width of each hidden layer, in order)
        input_size (int): size of input for the first layer
        output_size (int): size of output i.e size of the feature dim
        frame_channels (int): feature dim to set the flat start bias
        flat_start_params (dict): flat start parameters to set the bias; the keys
            ``"mean"``, ``"std"`` and ``"transition_p"`` are read
    """

    def __init__(
        self,
        outputnet_size: List[int],
        input_size: int,
        output_size: int,
        frame_channels: int,
        flat_start_params: dict,
    ):
        super().__init__()
        self.frame_channels = frame_channels

        # Hidden layer i maps (previous width) -> outputnet_size[i]; the first
        # layer consumes `input_size`.
        in_widths = [input_size] + outputnet_size[:-1]
        hidden_layers = []
        for n_in, n_out in zip(in_widths, outputnet_size):
            hidden_layers.append(Linear(n_in, n_out))
        self.layers = nn.ModuleList(hidden_layers)
        self.last_layer = nn.Linear(outputnet_size[-1], output_size)

        start_mean = flat_start_params["mean"]
        start_std = flat_start_params["std"]
        start_transition_p = flat_start_params["transition_p"]
        self.flat_start_output_layer(start_mean, start_std, start_transition_p)

    def flat_start_output_layer(self, mean, std, transition_p):
        """Reset the output layer so that it initially ignores its input.

        The weights are zeroed and the bias is filled in three consecutive
        segments: ``mean`` for the first ``frame_channels`` entries, the inverse
        softplus of ``std`` for the next ``frame_channels`` entries and the
        inverse sigmoid of ``transition_p`` for everything after that.
        """
        n_feats = self.frame_channels
        self.last_layer.weight.data.zero_()
        bias = self.last_layer.bias.data  # shares storage with the parameter
        bias[:n_feats] = mean
        bias[n_feats : 2 * n_feats] = OverflowUtils.inverse_softplus(std)
        bias[2 * n_feats :] = OverflowUtils.inverse_sigmod(transition_p)

    def forward(self, x):
        """Apply every hidden layer with a ReLU, then the linear output layer."""
        hidden = x
        for layer in self.layers:
            hidden = F.relu(layer(hidden))
        return self.last_layer(hidden)
class Outputnet(nn.Module):
    r"""
    This network takes current state and previous observed values as input
    and returns its parameters, mean, standard deviation and probability
    of transition to the next state
    """

    def __init__(
        self,
        encoder_dim: int,
        memory_rnn_dim: int,
        frame_channels: int,
        outputnet_size: List[int],
        flat_start_params: dict,
        std_floor: float = 1e-2,
    ):
        super().__init__()

        self.frame_channels = frame_channels
        self.flat_start_params = flat_start_params
        self.std_floor = std_floor

        # The parameter model sees the autoregressive features concatenated with
        # a state vector, and emits a mean and a std per feature plus one
        # transition value.
        self.parametermodel = ParameterModel(
            outputnet_size=outputnet_size,
            input_size=memory_rnn_dim + encoder_dim,
            output_size=2 * frame_channels + 1,
            flat_start_params=flat_start_params,
            frame_channels=frame_channels,
        )

    def forward(self, ar_mels, inputs):
        r"""Inputs observation and returns the means, stds and transition probability for the current state

        Args:
            ar_mels (torch.FloatTensor): autoregressive features, shape (batch, prenet_dim)
            inputs (torch.FloatTensor): hidden states, shape (batch, hidden_states, hidden_state_dim)

        Returns:
            means: means for the emission observation for each feature
                - shape: (B, hidden_states, feature_size)
            stds: standard deviations for the emission observation for each feature
                (softplus-activated and floored at ``std_floor``)
                - shape: (batch, hidden_states, feature_size)
            transition_vectors: transition vector for the current hidden state
                (returned as produced by the parameter model, no activation applied here)
                - shape: (batch, hidden_states)
        """
        batch_size = ar_mels.shape[0]
        prenet_dim = ar_mels.shape[1]
        n_states = inputs.shape[1]

        # Repeat the same autoregressive features for every hidden state.
        tiled = ar_mels.unsqueeze(1).expand(batch_size, n_states, prenet_dim)
        params = self.parametermodel(torch.cat((tiled, inputs), dim=2))

        n_feats = self.frame_channels
        mean = params[:, :, :n_feats]
        raw_std = params[:, :, n_feats : 2 * n_feats]
        transition_vector = params[:, :, 2 * n_feats :].squeeze(2)

        std = self._floor_std(F.softplus(raw_std))
        return mean, std, transition_vector

    def _floor_std(self, std):
        r"""
        It clamps the standard deviation to not to go below some level
        This removes the problem when the model tries to cheat for higher likelihoods by converting
        one of the gaussians to a point mass.

        Args:
            std (float Tensor): tensor containing the standard deviation to be floored

        Returns:
            float Tensor: ``std`` with every entry below ``self.std_floor`` raised to it
        """
        floored = torch.clamp(std, min=self.std_floor)
        # Report whenever the clamp actually changed at least one value.
        if torch.any(std != floored):
            print(
                "[*] Standard deviation was floored! The model is preventing overfitting, nothing serious to worry about"
            )
        return floored
class OverflowUtils:
    """Stateless helpers for flat-start initialisation and numerically safe tensor math."""

    @staticmethod
    def get_data_parameters_for_flat_start(
        data_loader: torch.utils.data.DataLoader, out_channels: int, states_per_phone: int
    ):
        """Generates data parameters for flat starting the HMM.

        Each batch must be a mapping with the keys ``"token_id_lengths"``, ``"mel"``
        and ``"mel_lengths"``.

        Args:
            data_loader (torch.utils.data.Dataloader): loader iterated once over the data;
                ``len(data_loader.dataset)`` must be defined
            out_channels (int): mel spectrogram channels
            states_per_phone (int): HMM states per phone

        Returns:
            tuple: ``(data_mean, data_std, init_transition_prob * states_per_phone)``
        """
        # State related information for transition_p
        total_state_len = 0
        total_mel_len = 0

        # Useful for data mean and std
        total_mel_sum = 0
        total_mel_sq_sum = 0

        for batch in tqdm(data_loader, leave=False):
            text_lengths = batch["token_id_lengths"]
            mels = batch["mel"]
            mel_lengths = batch["mel_lengths"]

            total_state_len += torch.sum(text_lengths)
            total_mel_len += torch.sum(mel_lengths)
            # NOTE(review): the sums run over every element of the batch tensor while
            # the divisor below counts only `mel_lengths` frames -- this presumably
            # relies on padded frames being zero; confirm against the collate function.
            total_mel_sum += torch.sum(mels)
            total_mel_sq_sum += torch.sum(torch.pow(mels, 2))

        # Single global mean / std over all frames and channels: Var[x] = E[x^2] - E[x]^2.
        data_mean = total_mel_sum / (total_mel_len * out_channels)
        data_std = torch.sqrt((total_mel_sq_sum / (total_mel_len * out_channels)) - torch.pow(data_mean, 2))

        # Average number of frames spent per input token, inverted to get the
        # probability of leaving a state at each frame.
        average_num_states = total_state_len / len(data_loader.dataset)
        average_mel_len = total_mel_len / len(data_loader.dataset)
        average_duration_each_state = average_mel_len / average_num_states
        init_transition_prob = 1 / average_duration_each_state

        return data_mean, data_std, (init_transition_prob * states_per_phone)

    @staticmethod
    @torch.no_grad()
    def update_flat_start_transition(model, transition_p):
        """Re-run the flat start of the model's output layer with mean 0.0, std 1.0 and ``transition_p``."""
        model.neural_hmm.output_net.parametermodel.flat_start_output_layer(0.0, 1.0, transition_p)

    @staticmethod
    def log_clamped(x, eps=1e-04):
        """
        Avoids the log(0) problem

        Args:
            x (torch.tensor): input tensor
            eps (float, optional): lower bound. Defaults to 1e-04.

        Returns:
            torch.tensor: :math:`log(max(x, eps))`
        """
        clamped_x = torch.clamp(x, min=eps)
        return torch.log(clamped_x)

    @staticmethod
    def inverse_sigmod(x):
        r"""
        Inverse of the sigmoid function

        Computes :math:`log(x / (1 - x))` with the argument of the log clamped from
        below (see :meth:`log_clamped`). Non-tensor inputs are converted to tensors.
        """
        if not torch.is_tensor(x):
            x = torch.tensor(x)
        return OverflowUtils.log_clamped(x / (1.0 - x))

    @staticmethod
    def inverse_softplus(x):
        r"""
        Inverse of the softplus function

        Computes :math:`log(exp(x) - 1)` without overflowing for large ``x``. For
        small or non-positive ``x`` the argument of the log is clamped from below
        (see :meth:`log_clamped`). Non-tensor inputs are converted to tensors.
        """
        if not torch.is_tensor(x):
            x = torch.tensor(x)
        if not torch.is_floating_point(x):
            # torch.exp promoted integer tensors to the default float dtype as well.
            x = x.to(torch.get_default_dtype())

        # exp(x) overflows to inf for large x, although log(exp(x) - 1) is then just
        # x + log(1 - exp(-x)). Evaluate each form only on inputs that are safe for
        # it (the clamps keep the unused branch finite) and select per element.
        threshold = 20.0
        x_low = torch.clamp(x, max=threshold)
        x_high = torch.clamp(x, min=threshold)
        low_branch = OverflowUtils.log_clamped(torch.expm1(x_low))
        high_branch = x_high + torch.log(-torch.expm1(-x_high))
        return torch.where(x > threshold, high_branch, low_branch)

    @staticmethod
    def logsumexp(x, dim):
        r"""
        Differentiable LogSumExp: Does not create nan gradients
        when all the inputs are -inf; yields 0 gradients instead.

        Args:
            x : torch.Tensor - The input tensor
            dim: int - The dimension on which the log sum exp has to be applied
        """
        m, _ = x.max(dim=dim)
        # Slices whose maximum is -inf contain only -inf values.
        mask = m == -float("inf")
        # For those slices subtract 0 instead of -inf, so that (-inf) - (-inf) = nan
        # never appears; their exponentials then sum to 0.
        s = (x - m.masked_fill_(mask, 0).unsqueeze(dim=dim)).exp().sum(dim=dim)
        # Replace the 0 sums by 1 so that log() gives 0, then restore -inf via `m`.
        return s.masked_fill_(mask, 1).log() + m.masked_fill_(mask, -float("inf"))

    @staticmethod
    def double_pad(list_of_different_shape_tensors):
        r"""
        Pads the list of tensors in 2 dimensions

        Every tensor is first right-padded along its last dimension up to the
        largest ``shape[1]`` in the list, then the tensors are stacked batch-first
        with their first dimension padded by ``pad_sequence``.
        """
        widths = [tensor.shape[1] for tensor in list_of_different_shape_tensors]
        max_width = max(widths)
        padded_x = [
            F.pad(tensor, (0, max_width - width)) for tensor, width in zip(list_of_different_shape_tensors, widths)
        ]
        return nn.utils.rnn.pad_sequence(padded_x, batch_first=True)
|