from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from TTS.tts.layers.delightful_tts.conformer import ConformerMultiHeadedSelfAttention
from TTS.tts.layers.delightful_tts.conv_layers import CoordConv1d
from TTS.tts.layers.delightful_tts.networks import STL


def get_mask_from_lengths(lengths: torch.Tensor) -> torch.Tensor:
    """Build a boolean padding mask from a batch of sequence lengths.

    Shapes:
        lengths: :math:`[B]`
        returns: :math:`[B, T_max]`, where `True` marks padded positions.
    """
    batch_size = lengths.shape[0]
    max_len = torch.max(lengths).item()
    ids = torch.arange(0, max_len, device=lengths.device).unsqueeze(0).expand(batch_size, -1)
    mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)
    return mask


def stride_lens(lens: torch.Tensor, stride: int = 2) -> torch.Tensor:
    """Compute sequence lengths after a strided convolution with `same`-style padding."""
    return torch.ceil(lens / stride).int()
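

# Illustrative example of the two helpers above (the values are made up for
# demonstration and are not executed at import time):
#
#   lengths = torch.tensor([3, 5])
#   get_mask_from_lengths(lengths)
#   # tensor([[False, False, False,  True,  True],
#   #         [False, False, False, False, False]])
#   stride_lens(lengths)  # ceil(lengths / 2)
#   # tensor([2, 3], dtype=torch.int32)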


class ReferenceEncoder(nn.Module):
    """
    Reference encoder for the utterance and phoneme prosody encoders. The reference
    encoder is made up of convolution and RNN layers.

    Args:
        num_mels (int): Number of mel channels in the input spectrogram.
        ref_enc_filters (List[int]): List of channel sizes for the encoder conv layers.
        ref_enc_size (int): Kernel size for the conv layers.
        ref_enc_strides (List[int]): List of strides to use for the conv layers.
        ref_enc_gru_size (int): Number of hidden features for the gated recurrent unit.

    Inputs: inputs, lengths
        - **inputs** (batch, dim, time): Tensor containing the mel spectrogram.
        - **lengths** (batch): Tensor containing the mel lengths.

    Returns:
        - **outputs** (batch, time, dim): Tensor produced by the reference encoder.
    """

    def __init__(
        self,
        num_mels: int,
        ref_enc_filters: List[int],
        ref_enc_size: int,
        ref_enc_strides: List[int],
        ref_enc_gru_size: int,
    ):
        super().__init__()

        self.n_mel_channels = num_mels
        K = len(ref_enc_filters)
        filters = [self.n_mel_channels] + ref_enc_filters
        strides = [1] + ref_enc_strides

        # The first layer is a coordinate-aware convolution so positional information is kept.
        convs = [
            CoordConv1d(
                in_channels=filters[0],
                out_channels=filters[1],
                kernel_size=ref_enc_size,
                stride=strides[0],
                padding=ref_enc_size // 2,
                with_r=True,
            )
        ]
        convs2 = [
            nn.Conv1d(
                in_channels=filters[i],
                out_channels=filters[i + 1],
                kernel_size=ref_enc_size,
                stride=strides[i],
                padding=ref_enc_size // 2,
            )
            for i in range(1, K)
        ]
        convs.extend(convs2)
        self.convs = nn.ModuleList(convs)

        self.norms = nn.ModuleList(
            [nn.InstanceNorm1d(num_features=ref_enc_filters[i], affine=True) for i in range(K)]
        )

        self.gru = nn.GRU(
            input_size=ref_enc_filters[-1],
            hidden_size=ref_enc_gru_size,
            batch_first=True,
        )

    def forward(self, x: torch.Tensor, mel_lens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Shapes:
            x: :math:`[N, n_mels, T]`
            mel_lens: :math:`[N]`

        Returns the GRU outputs :math:`[N, T', ref_enc_gru_size]`, the final GRU state
        :math:`[1, N, ref_enc_gru_size]` and the downsampled mel mask :math:`[N, T']`.
        """
        mel_masks = get_mask_from_lengths(mel_lens).unsqueeze(1)
        x = x.masked_fill(mel_masks, 0)
        for conv, norm in zip(self.convs, self.norms):
            x = conv(x)
            x = F.leaky_relu(x, 0.3)
            x = norm(x)

        # The conv stack is expected to contain exactly two stride-2 layers,
        # so the sequence lengths are halved twice.
        for _ in range(2):
            mel_lens = stride_lens(mel_lens)

        mel_masks = get_mask_from_lengths(mel_lens)

        x = x.masked_fill(mel_masks.unsqueeze(1), 0)
        x = x.permute((0, 2, 1))
        x = torch.nn.utils.rnn.pack_padded_sequence(x, mel_lens.cpu().int(), batch_first=True, enforce_sorted=False)

        self.gru.flatten_parameters()
        x, memory = self.gru(x)
        x, _ = torch.nn.utils.rnn.pad_packed_sequence(x, batch_first=True)

        return x, memory, mel_masks

    def calculate_channels(self, L: int, kernel_size: int, stride: int, pad: int, n_convs: int) -> int:
        """Compute the output length after `n_convs` 1D convolutions with the given settings."""
        for _ in range(n_convs):
            L = (L - kernel_size + 2 * pad) // stride + 1
        return L
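

# Usage sketch for ReferenceEncoder (the config values below are illustrative
# assumptions, not taken from a specific model config; the hard-coded length
# halving in forward() assumes exactly two stride-2 conv layers, as here):
#
#   encoder = ReferenceEncoder(
#       num_mels=80,
#       ref_enc_filters=[32, 32, 64, 64, 128, 128],
#       ref_enc_size=3,
#       ref_enc_strides=[1, 2, 1, 2, 1],
#       ref_enc_gru_size=32,
#   )
#   mels = torch.randn(2, 80, 200)      # [B, n_mels, T]
#   mel_lens = torch.tensor([200, 180])
#   outputs, memory, masks = encoder(mels, mel_lens)
#   # outputs: [2, 50, 32], memory: [1, 2, 32], masks: [2, 50]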


class UtteranceLevelProsodyEncoder(nn.Module):
    def __init__(
        self,
        num_mels: int,
        ref_enc_filters: List[int],
        ref_enc_size: int,
        ref_enc_strides: List[int],
        ref_enc_gru_size: int,
        dropout: float,
        n_hidden: int,
        bottleneck_size_u: int,
        token_num: int,
    ):
        """
        Encoder to extract prosody from a whole utterance. It is made up of a reference
        encoder with a couple of linear layers and a style token layer with dropout.

        Args:
            num_mels (int): Number of mel channels in the input spectrogram.
            ref_enc_filters (List[int]): List of channel sizes for the reference encoder conv layers.
            ref_enc_size (int): Kernel size for the reference encoder conv layers.
            ref_enc_strides (List[int]): List of strides to use for the reference encoder conv layers.
            ref_enc_gru_size (int): Number of hidden features for the gated recurrent unit.
            dropout (float): Probability of dropout.
            n_hidden (int): Size of the hidden layers.
            bottleneck_size_u (int): Size of the bottleneck layer.
            token_num (int): Number of style tokens.

        Inputs: inputs, lengths
            - **inputs** (batch, dim, time): Tensor containing the mel spectrogram.
            - **lengths** (batch): Tensor containing the mel lengths.

        Returns:
            - **outputs** (batch, 1, dim): Tensor produced by the utterance-level prosody encoder.
        """
        super().__init__()

        self.E = n_hidden
        self.d_q = self.d_k = n_hidden
        bottleneck_size = bottleneck_size_u

        self.encoder = ReferenceEncoder(
            ref_enc_filters=ref_enc_filters,
            ref_enc_gru_size=ref_enc_gru_size,
            ref_enc_size=ref_enc_size,
            ref_enc_strides=ref_enc_strides,
            num_mels=num_mels,
        )
        self.encoder_prj = nn.Linear(ref_enc_gru_size, self.E // 2)
        self.stl = STL(n_hidden=n_hidden, token_num=token_num)
        self.encoder_bottleneck = nn.Linear(self.E, bottleneck_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, mels: torch.Tensor, mel_lens: torch.Tensor) -> torch.Tensor:
        """
        Shapes:
            mels: :math:`[B, C, T]`
            mel_lens: :math:`[B]`
            out: :math:`[B, 1, bottleneck_size_u]`
        """
        # The final GRU state of the reference encoder acts as the utterance summary.
        _, embedded_prosody, _ = self.encoder(mels, mel_lens)

        embedded_prosody = self.encoder_prj(embedded_prosody)

        out = self.encoder_bottleneck(self.stl(embedded_prosody))
        out = self.dropout(out)

        out = out.view((-1, 1, out.shape[3]))
        return out
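

# Usage sketch for UtteranceLevelProsodyEncoder (illustrative, assumed values;
# the reference-encoder arguments follow the same assumptions as the sketch above):
#
#   u_encoder = UtteranceLevelProsodyEncoder(
#       num_mels=80,
#       ref_enc_filters=[32, 32, 64, 64, 128, 128],
#       ref_enc_size=3,
#       ref_enc_strides=[1, 2, 1, 2, 1],
#       ref_enc_gru_size=32,
#       dropout=0.5,
#       n_hidden=128,
#       bottleneck_size_u=256,
#       token_num=32,
#   )
#   u_prosody = u_encoder(mels, mel_lens)  # [B, 1, bottleneck_size_u]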


class PhonemeLevelProsodyEncoder(nn.Module):
    def __init__(
        self,
        num_mels: int,
        ref_enc_filters: List[int],
        ref_enc_size: int,
        ref_enc_strides: List[int],
        ref_enc_gru_size: int,
        dropout: float,
        n_hidden: int,
        n_heads: int,
        bottleneck_size_p: int,
    ):
        """
        Encoder to extract phoneme-level prosody. A reference encoder encodes the mel
        spectrogram frame by frame, and multi-headed attention aligns those prosody
        features with the phoneme sequence before a bottleneck projection.

        Args:
            num_mels (int): Number of mel channels in the input spectrogram.
            ref_enc_filters (List[int]): List of channel sizes for the reference encoder conv layers.
            ref_enc_size (int): Kernel size for the reference encoder conv layers.
            ref_enc_strides (List[int]): List of strides to use for the reference encoder conv layers.
            ref_enc_gru_size (int): Number of hidden features for the gated recurrent unit.
            dropout (float): Probability of dropout.
            n_hidden (int): Size of the hidden layers.
            n_heads (int): Number of attention heads.
            bottleneck_size_p (int): Size of the bottleneck layer.
        """
        super().__init__()

        self.E = n_hidden
        self.d_q = self.d_k = n_hidden
        bottleneck_size = bottleneck_size_p

        self.encoder = ReferenceEncoder(
            ref_enc_filters=ref_enc_filters,
            ref_enc_gru_size=ref_enc_gru_size,
            ref_enc_size=ref_enc_size,
            ref_enc_strides=ref_enc_strides,
            num_mels=num_mels,
        )
        self.encoder_prj = nn.Linear(ref_enc_gru_size, n_hidden)
        self.attention = ConformerMultiHeadedSelfAttention(
            d_model=n_hidden,
            num_heads=n_heads,
            dropout_p=dropout,
        )
        self.encoder_bottleneck = nn.Linear(n_hidden, bottleneck_size)

    def forward(
        self,
        x: torch.Tensor,
        src_mask: torch.Tensor,
        mels: torch.Tensor,
        mel_lens: torch.Tensor,
        encoding: torch.Tensor,
    ) -> torch.Tensor:
        """
        Shapes:
            x: :math:`[N, seq_len, encoder_embedding_dim]`
            src_mask: :math:`[N, seq_len]`
            mels: :math:`[N, n_mels, T_mel]`
            mel_lens: :math:`[N]`
            out: :math:`[N, seq_len, bottleneck_size_p]`
        """
        # Per-frame prosody features from the reference encoder, projected to the model dim.
        embedded_prosody, _, mel_masks = self.encoder(mels, mel_lens)
        embedded_prosody = self.encoder_prj(embedded_prosody)

        # Phoneme embeddings attend over the prosody features; padded mel frames are masked out.
        attn_mask = mel_masks.view((mel_masks.shape[0], 1, 1, -1))
        x, _ = self.attention(
            query=x,
            key=embedded_prosody,
            value=embedded_prosody,
            mask=attn_mask,
            encoding=encoding,
        )
        x = self.encoder_bottleneck(x)
        x = x.masked_fill(src_mask.unsqueeze(-1), 0.0)
        return x
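

# Usage sketch for PhonemeLevelProsodyEncoder (illustrative, assumed values; the
# `encoding` argument is the positional-encoding tensor expected by
# ConformerMultiHeadedSelfAttention, and its construction is not shown here):
#
#   p_encoder = PhonemeLevelProsodyEncoder(
#       num_mels=80,
#       ref_enc_filters=[32, 32, 64, 64, 128, 128],
#       ref_enc_size=3,
#       ref_enc_strides=[1, 2, 1, 2, 1],
#       ref_enc_gru_size=32,
#       dropout=0.1,
#       n_hidden=128,
#       n_heads=2,
#       bottleneck_size_p=4,
#   )
#   # x: phoneme encodings [B, seq_len, n_hidden], src_mask: [B, seq_len],
#   # mels/mel_lens as in the sketches above.
#   p_prosody = p_encoder(x, src_mask, mels, mel_lens, encoding)  # [B, seq_len, 4]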