|
import math |
|
|
|
import torch |
|
from torch import nn |
|
|
|
from TTS.tts.layers.generic.gated_conv import GatedConvBlock |
|
from TTS.tts.layers.generic.res_conv_bn import ResidualConv1dBNBlock |
|
from TTS.tts.layers.generic.time_depth_sep_conv import TimeDepthSeparableConvBlock |
|
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor |
|
from TTS.tts.layers.glow_tts.glow import ResidualConv1dLayerNormBlock |
|
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer |
|
from TTS.tts.utils.helpers import sequence_mask |
|
|
|
|
|
class Encoder(nn.Module): |
|
"""Glow-TTS encoder module. |
|
|
|
:: |
|
|
|
embedding -> <prenet> -> encoder_module -> <postnet> --> proj_mean |
|
| |
|
|-> proj_var |
|
| |
|
|-> concat -> duration_predictor |
|
↑ |
|
speaker_embed |
|
|
|
    Args:

        num_chars (int): number of characters.

        out_channels (int): number of output channels.

        hidden_channels (int): encoder's embedding size.

        hidden_channels_dp (int): number of hidden channels in the duration predictor.

        encoder_type (str): encoder module type, one of the types listed under ``suggested encoder params`` below.

        encoder_params (dict): parameters passed through to the chosen encoder module.

        dropout_p_dp (float): dropout rate of the duration predictor.

        mean_only (bool): if True, output only mean values and use constant std.

        use_prenet (bool): if True, use pre-convolutional layers before the encoder module.

        c_in_channels (int): number of channels of the conditional (e.g. speaker embedding) input.
|
|
|
Shapes: |
|
        - input: (B, T)
|
|
|
:: |
|
|
|
suggested encoder params... |
|
|
|
        for encoder_type == 'rel_pos_transformer'
            encoder_params = {
                'kernel_size': 3,
                'dropout_p': 0.1,
                'num_layers': 6,
                'num_heads': 2,
                'hidden_channels_ffn': 768,  # 4 times the hidden_channels
                'input_length': None
            }

        for encoder_type == 'gated_conv'
            encoder_params = {
                'kernel_size': 5,
                'dropout_p': 0.1,
                'num_layers': 9,
            }

        for encoder_type == 'residual_conv_bn'
            encoder_params = {
                'kernel_size': 4,
                'dilations': [1, 2, 4, 1, 2, 4, 1, 2, 4, 1, 2, 4, 1],
                'num_conv_blocks': 2,
                'num_res_blocks': 13
            }

        for encoder_type == 'time_depth_separable'
            encoder_params = {
                'kernel_size': 5,
                'num_layers': 9,
            }
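
    Example (a minimal usage sketch; the argument values below are illustrative
    assumptions, not recommended settings)::

        encoder = Encoder(
            num_chars=100,
            out_channels=80,
            hidden_channels=192,
            hidden_channels_dp=256,
            encoder_type='rel_pos_transformer',
            encoder_params={
                'kernel_size': 3,
                'dropout_p': 0.1,
                'num_layers': 6,
                'num_heads': 2,
                'hidden_channels_ffn': 768,
            },
        )
        x = torch.randint(0, 100, (2, 50))  # [B, T] token ids
        x_lengths = torch.tensor([50, 30])
        x_m, x_logs, logw, x_mask = encoder(x, x_lengths)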
|
""" |
|
|
|
def __init__( |
|
self, |
|
num_chars, |
|
out_channels, |
|
hidden_channels, |
|
hidden_channels_dp, |
|
encoder_type, |
|
encoder_params, |
|
dropout_p_dp=0.1, |
|
mean_only=False, |
|
use_prenet=True, |
|
c_in_channels=0, |
|
): |
|
super().__init__() |
|
|
|
self.num_chars = num_chars |
|
self.out_channels = out_channels |
|
self.hidden_channels = hidden_channels |
|
self.hidden_channels_dp = hidden_channels_dp |
|
self.dropout_p_dp = dropout_p_dp |
|
self.mean_only = mean_only |
|
self.use_prenet = use_prenet |
|
self.c_in_channels = c_in_channels |
|
self.encoder_type = encoder_type |
|
|
|
        self.emb = nn.Embedding(num_chars, hidden_channels)
        # init embedding weights with std = 1/sqrt(hidden_channels)
        nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
|
|
|
if encoder_type.lower() == "rel_pos_transformer": |
|
if use_prenet: |
|
self.prenet = ResidualConv1dLayerNormBlock( |
|
hidden_channels, hidden_channels, hidden_channels, kernel_size=5, num_layers=3, dropout_p=0.5 |
|
) |
|
self.encoder = RelativePositionTransformer( |
|
hidden_channels, hidden_channels, hidden_channels, **encoder_params |
|
) |
|
elif encoder_type.lower() == "gated_conv": |
|
self.encoder = GatedConvBlock(hidden_channels, **encoder_params) |
|
elif encoder_type.lower() == "residual_conv_bn": |
|
if use_prenet: |
|
self.prenet = nn.Sequential(nn.Conv1d(hidden_channels, hidden_channels, 1), nn.ReLU()) |
|
self.encoder = ResidualConv1dBNBlock(hidden_channels, hidden_channels, hidden_channels, **encoder_params) |
|
self.postnet = nn.Sequential( |
|
nn.Conv1d(self.hidden_channels, self.hidden_channels, 1), nn.BatchNorm1d(self.hidden_channels) |
|
) |
|
elif encoder_type.lower() == "time_depth_separable": |
|
if use_prenet: |
|
self.prenet = ResidualConv1dLayerNormBlock( |
|
hidden_channels, hidden_channels, hidden_channels, kernel_size=5, num_layers=3, dropout_p=0.5 |
|
) |
|
self.encoder = TimeDepthSeparableConvBlock( |
|
hidden_channels, hidden_channels, hidden_channels, **encoder_params |
|
) |
|
else: |
|
            raise ValueError(" [!] Unknown encoder type.")
|
|
|
|
|
        # 1x1 convs projecting encoder outputs to the prior mean and log-scale
        self.proj_m = nn.Conv1d(hidden_channels, out_channels, 1)
|
if not mean_only: |
|
self.proj_s = nn.Conv1d(hidden_channels, out_channels, 1) |
|
|
|
        # the duration predictor consumes the encoder output concatenated with the
        # speaker embedding (when provided), hence `hidden_channels + c_in_channels`
        self.duration_predictor = DurationPredictor(
            hidden_channels + c_in_channels, hidden_channels_dp, 3, dropout_p_dp
        )
|
|
|
def forward(self, x, x_lengths, g=None): |
|
""" |
|
Shapes: |
|
            - x: :math:`[B, T]`
|
- x_lengths: :math:`[B]` |
|
            - g (optional): :math:`[B, C, 1]`
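
        Example (a shape sketch; assumes an existing ``encoder`` instance built
        with ``c_in_channels > 0``)::

            x = torch.randint(0, encoder.num_chars, (2, 50))  # [B, T] token ids
            x_lengths = torch.tensor([50, 30])
            g = torch.randn(2, encoder.c_in_channels, 1)  # per-utterance speaker embedding
            x_m, x_logs, logw, x_mask = encoder(x, x_lengths, g=g)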
|
""" |
|
|
|
|
|
        # embed the token ids and scale: [B, T] -> [B, T, C]
        x = self.emb(x) * math.sqrt(self.hidden_channels)
|
|
|
        x = torch.transpose(x, 1, -1)  # [B, T, C] -> [B, C, T]
|
|
|
        # [B, 1, T] binary mask built from the input lengths
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
|
|
|
if hasattr(self, "prenet") and self.use_prenet: |
|
x = self.prenet(x, x_mask) |
|
|
|
x = self.encoder(x, x_mask) |
|
|
|
if hasattr(self, "postnet"): |
|
x = self.postnet(x) * x_mask |
|
|
|
        # `detach()` keeps duration-predictor gradients from flowing into the encoder
        if g is not None:
            g_exp = g.expand(-1, -1, x.size(-1))  # [B, C, 1] -> [B, C, T]
            x_dp = torch.cat([x.detach(), g_exp], 1)
        else:
            x_dp = x.detach()
|
|
|
        # project to the prior mean; with `mean_only` the log-scale is fixed to zero (std = 1)
        x_m = self.proj_m(x) * x_mask
        if not self.mean_only:
            x_logs = self.proj_s(x) * x_mask
        else:
            x_logs = torch.zeros_like(x_m)
|
|
|
logw = self.duration_predictor(x_dp, x_mask) |
|
return x_m, x_logs, logw, x_mask |
|
|