from typing import List

import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.utils import parametrize

from TTS.vocoder.layers.lvc_block import LVCBlock

LRELU_SLOPE = 0.1


class UnivnetGenerator(torch.nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden_channels: int,
        cond_channels: int,
        upsample_factors: List[int],
        lvc_layers_each_block: int,
        lvc_kernel_size: int,
        kpnet_hidden_channels: int,
        kpnet_conv_size: int,
        dropout: float,
        use_weight_norm=True,
    ):
"""Univnet Generator network. |
|
|
|
Paper: https://arxiv.org/pdf/2106.07889.pdf |
|
|
|
Args: |
|
in_channels (int): Number of input tensor channels. |
|
out_channels (int): Number of channels of the output tensor. |
|
hidden_channels (int): Number of hidden network channels. |
|
cond_channels (int): Number of channels of the conditioning tensors. |
|
upsample_factors (List[int]): List of uplsample factors for the upsampling layers. |
|
lvc_layers_each_block (int): Number of LVC layers in each block. |
|
lvc_kernel_size (int): Kernel size of the LVC layers. |
|
kpnet_hidden_channels (int): Number of hidden channels in the key-point network. |
|
kpnet_conv_size (int): Number of convolution channels in the key-point network. |
|
dropout (float): Dropout rate. |
|
use_weight_norm (bool, optional): Enable/disable weight norm. Defaults to True. |
|
""" |
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.cond_channels = cond_channels
        self.upsample_scale = np.prod(upsample_factors)
        self.lvc_block_nums = len(upsample_factors)

        self.first_conv = torch.nn.Conv1d(
            in_channels, hidden_channels, kernel_size=7, padding=(7 - 1) // 2, dilation=1, bias=True
        )

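        # Each LVC block upsamples by one factor; ``cond_hop_length`` accumulates
        # their product so that block ``n`` knows how many output samples a
        # single conditioning frame spans at its stage of the upsampling chain.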
        self.lvc_blocks = torch.nn.ModuleList()
        cond_hop_length = 1
        for n in range(self.lvc_block_nums):
            cond_hop_length = cond_hop_length * upsample_factors[n]
            lvcb = LVCBlock(
                in_channels=hidden_channels,
                cond_channels=cond_channels,
                upsample_ratio=upsample_factors[n],
                conv_layers=lvc_layers_each_block,
                conv_kernel_size=lvc_kernel_size,
                cond_hop_length=cond_hop_length,
                kpnet_hidden_channels=kpnet_hidden_channels,
                kpnet_conv_size=kpnet_conv_size,
                kpnet_dropout=dropout,
            )
            self.lvc_blocks.append(lvcb)

        self.last_conv_layers = torch.nn.ModuleList(
            [
                torch.nn.Conv1d(
                    hidden_channels, out_channels, kernel_size=7, padding=(7 - 1) // 2, dilation=1, bias=True
                ),
            ]
        )

        if use_weight_norm:
            self.apply_weight_norm()

    def forward(self, c):
        """Calculate forward propagation.

        Args:
            c (Tensor): Local conditioning auxiliary features (B, C, T').

        Returns:
            Tensor: Output tensor (B, out_channels, T).
        """
        x = torch.randn([c.shape[0], self.in_channels, c.shape[2]])
        x = x.to(self.first_conv.bias.device)
        x = self.first_conv(x)

        for n in range(self.lvc_block_nums):
            x = self.lvc_blocks[n](x, c)

        for f in self.last_conv_layers:
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = f(x)
        x = torch.tanh(x)
        return x

    def remove_weight_norm(self):
        """Remove weight normalization from all layers."""

        def _remove_weight_norm(m):
            try:
                parametrize.remove_parametrizations(m, "weight")
            except ValueError:  # this module has no "weight" parametrization
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization to all convolutional layers."""

        def _apply_weight_norm(m):
            if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
                torch.nn.utils.parametrizations.weight_norm(m)

        self.apply(_apply_weight_norm)

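    # Note: ``parametrizations.weight_norm`` registers a parametrization on the
    # module, which is why removal above goes through
    # ``parametrize.remove_parametrizations`` rather than the deprecated
    # ``torch.nn.utils.remove_weight_norm``.
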
    @staticmethod
    def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x):
        assert layers % stacks == 0
        layers_per_cycle = layers // stacks
        dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
        return (kernel_size - 1) * sum(dilations) + 1

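    # Illustrative example: layers=8, stacks=2, kernel_size=3 yields dilations
    # [1, 2, 4, 8, 1, 2, 4, 8], so the receptive field is (3 - 1) * 30 + 1 = 61.
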
    @property
    def receptive_field_size(self):
        """Return receptive field size.

        Note: this relies on ``self.layers``, ``self.stacks`` and
        ``self.kernel_size``, which are not set in ``__init__`` and must be
        provided by a subclass or attached externally before use.
        """
        return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)

    @torch.no_grad()
    def inference(self, c):
        """Perform inference.

        Args:
            c (Tensor): Local conditioning auxiliary features :math:`(B, C, T)`.

        Returns:
            Tensor: Output tensor (B, out_channels, T).
        """
        # ``forward`` draws its own noise input, so only the conditioning tensor
        # needs to be matched to the model's device and dtype.
        c = c.to(next(self.parameters()))
        return self.forward(c)
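

# A minimal usage sketch: the hyperparameters below are illustrative
# placeholders (assumed values, not a published UnivNet configuration), chosen
# so that prod(upsample_factors) matches a hop size of 256 samples.
if __name__ == "__main__":
    model = UnivnetGenerator(
        in_channels=64,  # channels of the noise input
        out_channels=1,  # mono waveform
        hidden_channels=32,
        cond_channels=80,  # e.g. an 80-band mel spectrogram
        upsample_factors=[8, 8, 4],  # 8 * 8 * 4 = 256x upsampling
        lvc_layers_each_block=4,
        lvc_kernel_size=3,
        kpnet_hidden_channels=64,
        kpnet_conv_size=3,
        dropout=0.0,
    )
    mel = torch.randn(1, 80, 50)  # (B, cond_channels, frames)
    wav = model.inference(mel)
    print(wav.shape)  # expected: torch.Size([1, 1, 12800]) == 50 frames * 256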