import torch
from torch import nn
from TTS.tts.layers.glow_tts.decoder import Decoder as GlowDecoder
from TTS.tts.utils.helpers import sequence_mask


class Decoder(nn.Module):
"""Uses glow decoder with some modifications.
    ::

        Squeeze -> ActNorm -> InvertibleConv1x1 -> AffineCoupling -> Unsqueeze

    Args:
        in_channels (int): number of channels of the input tensor.
        hidden_channels (int): number of hidden decoder channels.
        kernel_size (int): coupling block kernel size. (Wavenet filter kernel size.)
        dilation_rate (int): rate to increase dilation by each layer in a decoder block.
        num_flow_blocks (int): number of decoder blocks.
        num_coupling_layers (int): number of coupling layers. (number of wavenet layers.)
        dropout_p (float): wavenet dropout rate.
        num_splits (int): number of channel splits used by the invertible 1x1 convolution.
        num_squeeze (int): number of time-steps squeezed into each step; input lengths are trimmed to a multiple of this.
        sigmoid_scale (bool): enable/disable sigmoid scaling in the coupling layer.
        c_in_channels (int): number of channels of the conditioning tensor (e.g. a speaker embedding); 0 disables conditioning.
    """

def __init__(
self,
in_channels,
hidden_channels,
kernel_size,
dilation_rate,
num_flow_blocks,
num_coupling_layers,
dropout_p=0.0,
num_splits=4,
num_squeeze=2,
sigmoid_scale=False,
c_in_channels=0,
):
super().__init__()
self.glow_decoder = GlowDecoder(
in_channels,
hidden_channels,
kernel_size,
dilation_rate,
num_flow_blocks,
num_coupling_layers,
dropout_p,
num_splits,
num_squeeze,
sigmoid_scale,
c_in_channels,
)
        self.n_sqz = num_squeeze  # time-squeeze factor used by `preprocess`

    def forward(self, x, x_len, g=None, reverse=False):
"""
Input shapes:
- x: :math:`[B, C, T]`
- x_len :math:`[B]`
- g: :math:`[B, C]`
Output shapes:
- x: :math:`[B, C, T]`
- x_len :math:`[B]`
- logget_tot :math:`[B]`
"""
        # trim inputs to a multiple of `n_sqz` so the squeeze operation is valid
        x, x_len, x_max_len = self.preprocess(x, x_len, x_len.max())
        # [B, 1, T] padding mask built from the (possibly trimmed) lengths
        x_mask = torch.unsqueeze(sequence_mask(x_len, x_max_len), 1).to(x.dtype)
        x, logdet_tot = self.glow_decoder(x, x_mask, g=g, reverse=reverse)
        return x, x_len, logdet_tot

    def preprocess(self, y, y_lengths, y_max_length):
        # Trim the time dimension and the lengths down to the largest multiple
        # of `n_sqz`, e.g. with `n_sqz=2` a length of 51 becomes 50.
        if y_max_length is not None:
            y_max_length = torch.div(y_max_length, self.n_sqz, rounding_mode="floor") * self.n_sqz
            y = y[:, :, :y_max_length]
        y_lengths = torch.div(y_lengths, self.n_sqz, rounding_mode="floor") * self.n_sqz
        return y, y_lengths, y_max_length

    def store_inverse(self):
        # cache inverted flow weights in the wrapped glow decoder for faster inference
        self.glow_decoder.store_inverse()
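

# --- Usage sketch (illustrative, not part of the original module) -------------
# A minimal forward/inverse round-trip, assuming hyper-parameters that satisfy
# the decoder's constraints: `num_splits` should divide the channel count after
# squeezing (here 80 * 2 = 160), and `preprocess` trims T and the lengths to a
# multiple of `num_squeeze`. All values below are assumptions for illustration.
if __name__ == "__main__":
    decoder = Decoder(
        in_channels=80,  # e.g. mel-spectrogram channels
        hidden_channels=192,
        kernel_size=5,
        dilation_rate=1,
        num_flow_blocks=12,
        num_coupling_layers=4,
    )
    x = torch.randn(2, 80, 50)  # [B, C, T]
    x_len = torch.tensor([50, 42])  # [B] valid lengths per batch item
    z, z_len, logdet = decoder(x, x_len)  # data -> latent (forward direction)
    decoder.store_inverse()  # cache inverted weights before inference
    y, y_len, _ = decoder(z, z_len, reverse=True)  # latent -> data (inverse)
    print(z.shape, logdet.shape, y.shape)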