|
import torch
from torch import nn

from TTS.tts.layers.glow_tts.decoder import Decoder as GlowDecoder
from TTS.tts.utils.helpers import sequence_mask


class Decoder(nn.Module):
    """Uses the Glow decoder with some modifications.

    ::

        Squeeze -> ActNorm -> InvertibleConv1x1 -> AffineCoupling -> Unsqueeze

    Args:
        in_channels (int): channels of the input tensor.
        hidden_channels (int): hidden decoder channels.
        kernel_size (int): coupling block kernel size (WaveNet filter kernel size).
        dilation_rate (int): rate to increase dilation by in each layer of a decoder block.
        num_flow_blocks (int): number of decoder blocks.
        num_coupling_layers (int): number of coupling layers (number of WaveNet layers).
        dropout_p (float): WaveNet dropout rate.
        num_splits (int): number of channel splits used by the invertible 1x1 convolution.
        num_squeeze (int): number of time frames squeezed into the channel dimension.
        sigmoid_scale (bool): enable/disable sigmoid scaling in the coupling layer.
        c_in_channels (int): number of channels of the conditioning tensor (e.g. speaker embedding), 0 to disable.
    """
|
    def __init__(
        self,
        in_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        num_flow_blocks,
        num_coupling_layers,
        dropout_p=0.0,
        num_splits=4,
        num_squeeze=2,
        sigmoid_scale=False,
        c_in_channels=0,
    ):
        super().__init__()

        self.glow_decoder = GlowDecoder(
            in_channels,
            hidden_channels,
            kernel_size,
            dilation_rate,
            num_flow_blocks,
            num_coupling_layers,
            dropout_p,
            num_splits,
            num_squeeze,
            sigmoid_scale,
            c_in_channels,
        )
        # keep the squeeze factor so `preprocess` can trim inputs to a compatible length
        self.n_sqz = num_squeeze
|
    def forward(self, x, x_len, g=None, reverse=False):
        """
        Input shapes:
            - x: :math:`[B, C, T]`
            - x_len: :math:`[B]`
            - g: :math:`[B, C]`

        Output shapes:
            - x: :math:`[B, C, T]`
            - x_len: :math:`[B]`
            - logdet_tot: :math:`[B]`
        """
        # trim the time dimension and the lengths to a multiple of the squeeze factor
        x, x_len, x_max_len = self.preprocess(x, x_len, x_len.max())
        # [B, 1, T] mask marking the valid (non-padded) frames
        x_mask = torch.unsqueeze(sequence_mask(x_len, x_max_len), 1).to(x.dtype)
        x, logdet_tot = self.glow_decoder(x, x_mask, g, reverse)
        return x, x_len, logdet_tot
|
    def preprocess(self, y, y_lengths, y_max_length):
        # crop the input and its lengths to the largest multiple of the squeeze factor
        if y_max_length is not None:
            y_max_length = torch.div(y_max_length, self.n_sqz, rounding_mode="floor") * self.n_sqz
            y = y[:, :, :y_max_length]
        y_lengths = torch.div(y_lengths, self.n_sqz, rounding_mode="floor") * self.n_sqz
        return y, y_lengths, y_max_length

    def store_inverse(self):
        # cache the inverses of the invertible 1x1 convolutions in the glow decoder for faster inference
        self.glow_decoder.store_inverse()
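

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the library): it only illustrates the
# expected tensor shapes and the forward/reverse calls of `Decoder`. The
# hyperparameter values and the dummy tensors below are illustrative
# assumptions, not settings taken from any model config.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch_size, in_channels, num_frames = 2, 80, 50  # hypothetical [B, C, T]
    decoder = Decoder(
        in_channels=in_channels,
        hidden_channels=192,
        kernel_size=5,
        dilation_rate=1,
        num_flow_blocks=12,
        num_coupling_layers=4,
    )

    x = torch.randn(batch_size, in_channels, num_frames)
    x_len = torch.tensor([num_frames, num_frames - 7])

    # forward pass (data -> latent); lengths are trimmed to multiples of `num_squeeze`
    z, z_len, logdet = decoder(x, x_len)
    print(z.shape, z_len, logdet.shape)

    # inverse pass (latent -> data); cache the inverse weights first for efficiency
    decoder.store_inverse()
    with torch.no_grad():
        y, y_len, _ = decoder(z, z_len, reverse=True)
    print(y.shape, y_len)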
|
|