|
from typing import Tuple |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.nn.utils import parametrize |
|
|
|
from TTS.tts.layers.delightful_tts.kernel_predictor import KernelPredictor |
|
|
|
|
|
def calc_same_padding(kernel_size: int) -> Tuple[int, int]: |
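    """Return ``(left, right)`` padding that preserves the input length for a
    stride-1 convolution with the given kernel size.

    For example, ``kernel_size=5`` yields ``(2, 2)`` and ``kernel_size=4``
    yields ``(2, 1)``.
    """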
|
pad = kernel_size // 2 |
|
return (pad, pad - (kernel_size + 1) % 2) |
|
|
|
|
|
class ConvNorm(nn.Module): |
|
"""A 1-dimensional convolutional layer with optional weight normalization. |
|
|
|
This layer wraps a 1D convolutional layer from PyTorch and applies |
|
optional weight normalization. The layer can be used in a similar way to |
|
the convolutional layers in PyTorch's `torch.nn` module. |
|
|
|
Args: |
|
in_channels (int): The number of channels in the input signal. |
|
out_channels (int): The number of channels in the output signal. |
|
kernel_size (int, optional): The size of the convolving kernel. |
|
Defaults to 1. |
|
stride (int, optional): The stride of the convolution. Defaults to 1. |
|
padding (int, optional): Zero-padding added to both sides of the input. |
|
If `None`, the padding will be calculated so that the output has |
|
the same length as the input. Defaults to `None`. |
|
dilation (int, optional): Spacing between kernel elements. Defaults to 1. |
|
bias (bool, optional): If `True`, add bias after convolution. Defaults to `True`. |
|
w_init_gain (str, optional): The weight initialization function to use. |
|
Can be either 'linear' or 'relu'. Defaults to 'linear'. |
|
use_weight_norm (bool, optional): If `True`, apply weight normalization |
|
to the convolutional weights. Defaults to `False`. |
|
|
|
Shapes: |
|
        - Input: :math:`[N, in_channels, T]`

        - Output: :math:`[N, out_channels, T]`
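
    Example (an illustrative sketch; the sizes are arbitrary assumptions)::

        >>> conv = ConvNorm(80, 256, kernel_size=5)
        >>> conv(torch.randn(4, 80, 100)).shape
        torch.Size([4, 256, 100])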
|
|
|
""" |
|
|
|
def __init__( |
|
self, |
|
in_channels, |
|
out_channels, |
|
kernel_size=1, |
|
stride=1, |
|
padding=None, |
|
dilation=1, |
|
bias=True, |
|
w_init_gain="linear", |
|
use_weight_norm=False, |
|
): |
|
        super().__init__()
|
if padding is None: |
|
assert kernel_size % 2 == 1 |
|
padding = int(dilation * (kernel_size - 1) / 2) |
|
self.kernel_size = kernel_size |
|
self.dilation = dilation |
|
self.use_weight_norm = use_weight_norm |
|
        self.conv = nn.Conv1d(
|
in_channels, |
|
out_channels, |
|
kernel_size=kernel_size, |
|
stride=stride, |
|
padding=padding, |
|
dilation=dilation, |
|
bias=bias, |
|
) |
|
nn.init.xavier_uniform_(self.conv.weight, gain=nn.init.calculate_gain(w_init_gain)) |
|
if self.use_weight_norm: |
|
self.conv = nn.utils.parametrizations.weight_norm(self.conv) |
|
|
|
def forward(self, signal, mask=None): |
|
conv_signal = self.conv(signal) |
|
        if mask is not None:
            # Zero out positions that correspond to padding in the input.
            conv_signal = conv_signal * mask
|
return conv_signal |
|
|
|
|
|
class ConvLSTMLinear(nn.Module): |
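    """A stack of weight-normalized 1D convolutions followed by an optional
    spectral-normalized (bi)LSTM and an optional linear projection.

    Example (an illustrative sketch; the sizes are arbitrary assumptions)::

        >>> layer = ConvLSTMLinear(in_dim=80, out_dim=4)
        >>> lens = torch.tensor([100, 80])
        >>> layer(torch.randn(2, 80, 100), lens).shape
        torch.Size([2, 4, 100])
    """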
|
def __init__( |
|
self, |
|
in_dim, |
|
out_dim, |
|
n_layers=2, |
|
n_channels=256, |
|
kernel_size=3, |
|
p_dropout=0.1, |
|
lstm_type="bilstm", |
|
use_linear=True, |
|
): |
|
        super().__init__()
|
self.out_dim = out_dim |
|
self.lstm_type = lstm_type |
|
self.use_linear = use_linear |
|
self.dropout = nn.Dropout(p=p_dropout) |
|
|
|
convolutions = [] |
|
for i in range(n_layers): |
|
conv_layer = ConvNorm( |
|
in_dim if i == 0 else n_channels, |
|
n_channels, |
|
kernel_size=kernel_size, |
|
stride=1, |
|
padding=int((kernel_size - 1) / 2), |
|
dilation=1, |
|
w_init_gain="relu", |
|
) |
|
            # Keep only the underlying Conv1d; the ConvNorm wrapper has already initialized its weights.
            conv_layer = nn.utils.parametrizations.weight_norm(conv_layer.conv, name="weight")
|
convolutions.append(conv_layer) |
|
|
|
self.convolutions = nn.ModuleList(convolutions) |
|
|
|
if not self.use_linear: |
|
n_channels = out_dim |
|
|
|
if self.lstm_type != "": |
|
use_bilstm = False |
|
lstm_channels = n_channels |
|
if self.lstm_type == "bilstm": |
|
use_bilstm = True |
|
                lstm_channels = n_channels // 2
|
|
|
self.bilstm = nn.LSTM(n_channels, lstm_channels, 1, batch_first=True, bidirectional=use_bilstm) |
|
        lstm_norm_fn_pntr = nn.utils.parametrizations.spectral_norm
|
self.bilstm = lstm_norm_fn_pntr(self.bilstm, "weight_hh_l0") |
|
if self.lstm_type == "bilstm": |
|
self.bilstm = lstm_norm_fn_pntr(self.bilstm, "weight_hh_l0_reverse") |
|
|
|
if self.use_linear: |
|
self.dense = nn.Linear(n_channels, out_dim) |
|
|
|
def run_padded_sequence(self, context, lens): |
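        """Run the convolution stack on each batch element separately, cropped
        to its true length, so that padded positions never leak into the
        convolution outputs; re-pad the results into a single tensor."""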
|
context_embedded = [] |
|
for b_ind in range(context.size()[0]): |
|
curr_context = context[b_ind : b_ind + 1, :, : lens[b_ind]].clone() |
|
for conv in self.convolutions: |
|
curr_context = self.dropout(F.relu(conv(curr_context))) |
|
context_embedded.append(curr_context[0].transpose(0, 1)) |
|
context = nn.utils.rnn.pad_sequence(context_embedded, batch_first=True) |
|
return context |
|
|
|
def run_unsorted_inputs(self, fn, context, lens): |
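        """Sort the batch by decreasing length, run `fn` over a packed sequence,
        and restore the original batch order afterwards."""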
|
        lens_sorted, ids_sorted = torch.sort(lens, descending=True)
        unsort_ids = torch.argsort(ids_sorted)
        lens_sorted = lens_sorted.long().cpu()
|
|
|
context = context[ids_sorted] |
|
context = nn.utils.rnn.pack_padded_sequence(context, lens_sorted, batch_first=True) |
|
context = fn(context)[0] |
|
context = nn.utils.rnn.pad_packed_sequence(context, batch_first=True)[0] |
|
|
|
|
|
context = context[unsort_ids] |
|
return context |
|
|
|
def forward(self, context, lens): |
|
if context.size()[0] > 1: |
|
context = self.run_padded_sequence(context, lens) |
|
|
|
context = context.transpose(1, 2) |
|
else: |
|
for conv in self.convolutions: |
|
context = self.dropout(F.relu(conv(context))) |
|
|
|
if self.lstm_type != "": |
|
context = context.transpose(1, 2) |
|
self.bilstm.flatten_parameters() |
|
if lens is not None: |
|
context = self.run_unsorted_inputs(self.bilstm, context, lens) |
|
else: |
|
context = self.bilstm(context)[0] |
|
context = context.transpose(1, 2) |
|
|
|
x_hat = context |
|
if self.use_linear: |
|
x_hat = self.dense(context.transpose(1, 2)).transpose(1, 2) |
|
|
|
return x_hat |
|
|
|
|
|
class DepthWiseConv1d(nn.Module): |
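    """Depthwise 1D convolution (``groups=in_channels``); ``out_channels``
    must be an integer multiple of ``in_channels``."""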
|
def __init__(self, in_channels: int, out_channels: int, kernel_size: int, padding: int): |
|
super().__init__() |
|
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding, groups=in_channels) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
return self.conv(x) |
|
|
|
|
|
class PointwiseConv1d(nn.Module): |
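    """Pointwise (kernel size 1) 1D convolution, i.e. a per-timestep linear
    projection across channels."""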
|
def __init__( |
|
self, |
|
in_channels: int, |
|
out_channels: int, |
|
stride: int = 1, |
|
padding: int = 0, |
|
bias: bool = True, |
|
): |
|
super().__init__() |
|
self.conv = nn.Conv1d( |
|
in_channels=in_channels, |
|
out_channels=out_channels, |
|
kernel_size=1, |
|
stride=stride, |
|
padding=padding, |
|
bias=bias, |
|
) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
return self.conv(x) |
|
|
|
|
|
class BSConv1d(nn.Module): |
|
"""https://arxiv.org/pdf/2003.13549.pdf""" |
|
|
|
def __init__(self, channels_in: int, channels_out: int, kernel_size: int, padding: int): |
|
super().__init__() |
|
self.pointwise = nn.Conv1d(channels_in, channels_out, kernel_size=1) |
|
self.depthwise = nn.Conv1d( |
|
channels_out, |
|
channels_out, |
|
kernel_size=kernel_size, |
|
padding=padding, |
|
groups=channels_out, |
|
) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
x1 = self.pointwise(x) |
|
x2 = self.depthwise(x1) |
|
return x2 |
|
|
|
|
|
class BSConv2d(nn.Module): |
|
"""https://arxiv.org/pdf/2003.13549.pdf""" |
|
|
|
def __init__(self, channels_in: int, channels_out: int, kernel_size: int, padding: int): |
|
super().__init__() |
|
self.pointwise = nn.Conv2d(channels_in, channels_out, kernel_size=1) |
|
self.depthwise = nn.Conv2d( |
|
channels_out, |
|
channels_out, |
|
kernel_size=kernel_size, |
|
padding=padding, |
|
groups=channels_out, |
|
) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
x1 = self.pointwise(x) |
|
x2 = self.depthwise(x1) |
|
return x2 |
|
|
|
|
|
class Conv1dGLU(nn.Module): |
|
"""From DeepVoice 3""" |
|
|
|
def __init__(self, d_model: int, kernel_size: int, padding: int, embedding_dim: int): |
|
super().__init__() |
|
self.conv = BSConv1d(d_model, 2 * d_model, kernel_size=kernel_size, padding=padding) |
|
self.embedding_proj = nn.Linear(embedding_dim, d_model) |
|
self.register_buffer("sqrt", torch.sqrt(torch.FloatTensor([0.5])).squeeze(0)) |
|
self.softsign = torch.nn.Softsign() |
|
|
|
def forward(self, x: torch.Tensor, embeddings: torch.Tensor) -> torch.Tensor: |
|
        x = x.permute((0, 2, 1))  # [B, T, C] -> [B, C, T] for the convolution
|
residual = x |
|
x = self.conv(x) |
|
splitdim = 1 |
|
a, b = x.split(x.size(splitdim) // 2, dim=splitdim) |
|
embeddings = self.embedding_proj(embeddings).unsqueeze(2) |
|
softsign = self.softsign(embeddings) |
|
softsign = softsign.expand_as(a) |
|
a = a + softsign |
|
x = a * torch.sigmoid(b) |
|
x = x + residual |
|
        x = x * self.sqrt  # scale by sqrt(0.5) to keep the residual sum's variance stable
|
x = x.permute((0, 2, 1)) |
|
return x |
|
|
|
|
|
class ConvTransposed(nn.Module): |
|
""" |
|
A 1D convolutional transposed layer for PyTorch. |
|
This layer applies a 1D convolutional transpose operation to its input tensor, |
|
where the number of channels of the input tensor is the same as the number of channels of the output tensor. |
|
|
|
Attributes: |
|
in_channels (int): The number of channels in the input tensor. |
|
out_channels (int): The number of channels in the output tensor. |
|
kernel_size (int): The size of the convolutional kernel. Default: 1. |
|
padding (int): The number of padding elements to add to the input tensor. Default: 0. |
|
conv (BSConv1d): The 1D convolutional transpose layer. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
in_channels: int, |
|
out_channels: int, |
|
kernel_size: int = 1, |
|
padding: int = 0, |
|
): |
|
super().__init__() |
|
self.conv = BSConv1d( |
|
in_channels, |
|
out_channels, |
|
kernel_size=kernel_size, |
|
padding=padding, |
|
) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
x = x.contiguous().transpose(1, 2) |
|
x = self.conv(x) |
|
x = x.contiguous().transpose(1, 2) |
|
return x |
|
|
|
|
|
class DepthwiseConvModule(nn.Module): |
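    """LayerNorm, then a depthwise convolution that expands the channels by
    ``expansion``, a LeakyReLU, and a pointwise projection back to ``dim``.
    Operates on ``[B, T, C]`` tensors.

    Example (an illustrative sketch; the sizes are arbitrary assumptions)::

        >>> block = DepthwiseConvModule(dim=64)
        >>> block(torch.randn(2, 50, 64)).shape
        torch.Size([2, 50, 64])
    """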
|
def __init__(self, dim: int, kernel_size: int = 7, expansion: int = 4, lrelu_slope: float = 0.3): |
|
super().__init__() |
|
padding = calc_same_padding(kernel_size) |
|
self.depthwise = nn.Conv1d( |
|
dim, |
|
dim * expansion, |
|
kernel_size=kernel_size, |
|
padding=padding[0], |
|
groups=dim, |
|
) |
|
self.act = nn.LeakyReLU(lrelu_slope) |
|
self.out = nn.Conv1d(dim * expansion, dim, 1, 1, 0) |
|
self.ln = nn.LayerNorm(dim) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
x = self.ln(x) |
|
x = x.permute((0, 2, 1)) |
|
x = self.depthwise(x) |
|
x = self.act(x) |
|
x = self.out(x) |
|
x = x.permute((0, 2, 1)) |
|
return x |
|
|
|
|
|
class AddCoords(nn.Module): |
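    """Concatenate normalized coordinate channels to the input, as proposed in
    the CoordConv paper (https://arxiv.org/abs/1807.03247). ``rank`` selects
    1D/2D/3D inputs; ``with_r`` appends an additional radial-distance channel.

    Example (an illustrative sketch; the sizes are arbitrary assumptions)::

        >>> addcoords = AddCoords(rank=1, with_r=True)
        >>> addcoords(torch.randn(2, 8, 50)).shape  # 8 + 1 coordinate + 1 radius
        torch.Size([2, 10, 50])
    """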
|
def __init__(self, rank: int, with_r: bool = False): |
|
super().__init__() |
|
self.rank = rank |
|
self.with_r = with_r |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
if self.rank == 1: |
|
batch_size_shape, channel_in_shape, dim_x = x.shape |
|
xx_range = torch.arange(dim_x, dtype=torch.int32) |
|
xx_channel = xx_range[None, None, :] |
|
|
|
            xx_channel = xx_channel.float() / (dim_x - 1)  # normalize to [0, 1]
            xx_channel = xx_channel * 2 - 1  # rescale to [-1, 1]
            xx_channel = xx_channel.repeat(batch_size_shape, 1, 1)
|
|
|
xx_channel = xx_channel.to(x.device) |
|
out = torch.cat([x, xx_channel], dim=1) |
|
|
|
if self.with_r: |
|
rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2)) |
|
out = torch.cat([out, rr], dim=1) |
|
|
|
elif self.rank == 2: |
|
batch_size_shape, channel_in_shape, dim_y, dim_x = x.shape |
|
xx_ones = torch.ones([1, 1, 1, dim_x], dtype=torch.int32) |
|
yy_ones = torch.ones([1, 1, 1, dim_y], dtype=torch.int32) |
|
|
|
xx_range = torch.arange(dim_y, dtype=torch.int32) |
|
yy_range = torch.arange(dim_x, dtype=torch.int32) |
|
xx_range = xx_range[None, None, :, None] |
|
yy_range = yy_range[None, None, :, None] |
|
|
|
xx_channel = torch.matmul(xx_range, xx_ones) |
|
yy_channel = torch.matmul(yy_range, yy_ones) |
|
|
|
|
|
yy_channel = yy_channel.permute(0, 1, 3, 2) |
|
|
|
xx_channel = xx_channel.float() / (dim_y - 1) |
|
yy_channel = yy_channel.float() / (dim_x - 1) |
|
|
|
xx_channel = xx_channel * 2 - 1 |
|
yy_channel = yy_channel * 2 - 1 |
|
|
|
xx_channel = xx_channel.repeat(batch_size_shape, 1, 1, 1) |
|
yy_channel = yy_channel.repeat(batch_size_shape, 1, 1, 1) |
|
|
|
xx_channel = xx_channel.to(x.device) |
|
yy_channel = yy_channel.to(x.device) |
|
|
|
out = torch.cat([x, xx_channel, yy_channel], dim=1) |
|
|
|
if self.with_r: |
|
rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2)) |
|
out = torch.cat([out, rr], dim=1) |
|
|
|
elif self.rank == 3: |
|
batch_size_shape, channel_in_shape, dim_z, dim_y, dim_x = x.shape |
|
xx_ones = torch.ones([1, 1, 1, 1, dim_x], dtype=torch.int32) |
|
yy_ones = torch.ones([1, 1, 1, 1, dim_y], dtype=torch.int32) |
|
zz_ones = torch.ones([1, 1, 1, 1, dim_z], dtype=torch.int32) |
|
|
|
xy_range = torch.arange(dim_y, dtype=torch.int32) |
|
xy_range = xy_range[None, None, None, :, None] |
|
|
|
yz_range = torch.arange(dim_z, dtype=torch.int32) |
|
yz_range = yz_range[None, None, None, :, None] |
|
|
|
zx_range = torch.arange(dim_x, dtype=torch.int32) |
|
zx_range = zx_range[None, None, None, :, None] |
|
|
|
xy_channel = torch.matmul(xy_range, xx_ones) |
|
xx_channel = torch.cat([xy_channel + i for i in range(dim_z)], dim=2) |
|
|
|
yz_channel = torch.matmul(yz_range, yy_ones) |
|
yz_channel = yz_channel.permute(0, 1, 3, 4, 2) |
|
yy_channel = torch.cat([yz_channel + i for i in range(dim_x)], dim=4) |
|
|
|
zx_channel = torch.matmul(zx_range, zz_ones) |
|
zx_channel = zx_channel.permute(0, 1, 4, 2, 3) |
|
zz_channel = torch.cat([zx_channel + i for i in range(dim_y)], dim=3) |
|
|
|
xx_channel = xx_channel.to(x.device) |
|
yy_channel = yy_channel.to(x.device) |
|
zz_channel = zz_channel.to(x.device) |
|
out = torch.cat([x, xx_channel, yy_channel, zz_channel], dim=1) |
|
|
|
if self.with_r: |
|
rr = torch.sqrt( |
|
torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2) + torch.pow(zz_channel - 0.5, 2) |
|
) |
|
out = torch.cat([out, rr], dim=1) |
|
else: |
|
raise NotImplementedError |
|
|
|
return out |
|
|
|
|
|
class CoordConv1d(nn.modules.conv.Conv1d): |
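    """1D CoordConv layer (https://arxiv.org/abs/1807.03247): coordinate
    channels from :class:`AddCoords` are concatenated to the input before an
    ordinary ``Conv1d`` sized for the extra channel(s). Note that ``forward``
    only uses ``self.conv``; the parameters created by the parent ``Conv1d``
    are unused."""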
|
def __init__( |
|
self, |
|
in_channels: int, |
|
out_channels: int, |
|
kernel_size: int, |
|
stride: int = 1, |
|
padding: int = 0, |
|
dilation: int = 1, |
|
groups: int = 1, |
|
bias: bool = True, |
|
with_r: bool = False, |
|
): |
|
super().__init__( |
|
in_channels, |
|
out_channels, |
|
kernel_size, |
|
stride, |
|
padding, |
|
dilation, |
|
groups, |
|
bias, |
|
) |
|
self.rank = 1 |
|
self.addcoords = AddCoords(self.rank, with_r) |
|
self.conv = nn.Conv1d( |
|
in_channels + self.rank + int(with_r), |
|
out_channels, |
|
kernel_size, |
|
stride, |
|
padding, |
|
dilation, |
|
groups, |
|
bias, |
|
) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
x = self.addcoords(x) |
|
x = self.conv(x) |
|
return x |
|
|
|
|
|
class CoordConv2d(nn.modules.conv.Conv2d): |
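    """2D counterpart of :class:`CoordConv1d`; see :class:`AddCoords` for the
    coordinate channels concatenated before the convolution."""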
|
def __init__( |
|
self, |
|
in_channels: int, |
|
out_channels: int, |
|
kernel_size: int, |
|
stride: int = 1, |
|
padding: int = 0, |
|
dilation: int = 1, |
|
groups: int = 1, |
|
bias: bool = True, |
|
with_r: bool = False, |
|
): |
|
super().__init__( |
|
in_channels, |
|
out_channels, |
|
kernel_size, |
|
stride, |
|
padding, |
|
dilation, |
|
groups, |
|
bias, |
|
) |
|
self.rank = 2 |
|
self.addcoords = AddCoords(self.rank, with_r) |
|
self.conv = nn.Conv2d( |
|
in_channels + self.rank + int(with_r), |
|
out_channels, |
|
kernel_size, |
|
stride, |
|
padding, |
|
dilation, |
|
groups, |
|
bias, |
|
) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
x = self.addcoords(x) |
|
x = self.conv(x) |
|
return x |
|
|
|
|
|
class LVCBlock(torch.nn.Module): |
|
"""the location-variable convolutions""" |
|
|
|
def __init__( |
|
self, |
|
in_channels, |
|
cond_channels, |
|
stride, |
|
        dilations=(1, 3, 9, 27),
|
lReLU_slope=0.2, |
|
conv_kernel_size=3, |
|
cond_hop_length=256, |
|
kpnet_hidden_channels=64, |
|
kpnet_conv_size=3, |
|
kpnet_dropout=0.0, |
|
): |
|
super().__init__() |
|
|
|
self.cond_hop_length = cond_hop_length |
|
self.conv_layers = len(dilations) |
|
self.conv_kernel_size = conv_kernel_size |
|
|
|
self.kernel_predictor = KernelPredictor( |
|
cond_channels=cond_channels, |
|
conv_in_channels=in_channels, |
|
conv_out_channels=2 * in_channels, |
|
conv_layers=len(dilations), |
|
conv_kernel_size=conv_kernel_size, |
|
kpnet_hidden_channels=kpnet_hidden_channels, |
|
kpnet_conv_size=kpnet_conv_size, |
|
kpnet_dropout=kpnet_dropout, |
|
kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope}, |
|
) |
|
|
|
self.convt_pre = nn.Sequential( |
|
nn.LeakyReLU(lReLU_slope), |
|
nn.utils.parametrizations.weight_norm( |
|
nn.ConvTranspose1d( |
|
in_channels, |
|
in_channels, |
|
2 * stride, |
|
stride=stride, |
|
padding=stride // 2 + stride % 2, |
|
output_padding=stride % 2, |
|
) |
|
), |
|
) |
|
|
|
self.conv_blocks = nn.ModuleList() |
|
for dilation in dilations: |
|
self.conv_blocks.append( |
|
nn.Sequential( |
|
nn.LeakyReLU(lReLU_slope), |
|
nn.utils.parametrizations.weight_norm( |
|
nn.Conv1d( |
|
in_channels, |
|
in_channels, |
|
conv_kernel_size, |
|
padding=dilation * (conv_kernel_size - 1) // 2, |
|
dilation=dilation, |
|
) |
|
), |
|
nn.LeakyReLU(lReLU_slope), |
|
) |
|
) |
|
|
|
def forward(self, x, c): |
|
"""forward propagation of the location-variable convolutions. |
|
Args: |
|
x (Tensor): the input sequence (batch, in_channels, in_length) |
|
c (Tensor): the conditioning sequence (batch, cond_channels, cond_length) |
|
|
|
Returns: |
|
            Tensor: the output sequence (batch, in_channels, stride * in_length)
|
""" |
|
_, in_channels, _ = x.shape |
|
|
|
x = self.convt_pre(x) |
|
kernels, bias = self.kernel_predictor(c) |
|
|
|
for i, conv in enumerate(self.conv_blocks): |
|
output = conv(x) |
|
|
|
k = kernels[:, i, :, :, :, :] |
|
b = bias[:, i, :, :] |
|
|
|
output = self.location_variable_convolution( |
|
output, k, b, hop_size=self.cond_hop_length |
|
) |
|
            # Gated activation: sigmoid gate times tanh filter, added residually.
            x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh(output[:, in_channels:, :])
|
|
|
return x |
|
|
|
def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256): |
|
"""perform location-variable convolution operation on the input sequence (x) using the local convolution kernl. |
|
Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100. |
|
Args: |
|
x (Tensor): the input sequence (batch, in_channels, in_length). |
|
kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length) |
|
bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length) |
|
dilation (int): the dilation of convolution. |
|
hop_size (int): the hop_size of the conditioning sequence. |
|
Returns: |
|
(Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length). |
|
""" |
|
batch, _, in_length = x.shape |
|
batch, _, out_channels, kernel_size, kernel_length = kernel.shape |
|
        assert in_length == (kernel_length * hop_size), "input length must equal kernel_length * hop_size"
|
|
|
padding = dilation * int((kernel_size - 1) / 2) |
|
        x = F.pad(x, (padding, padding), "constant", 0)  # (batch, in_channels, in_length + 2 * padding)
        x = x.unfold(2, hop_size + 2 * padding, hop_size)  # (batch, in_channels, kernel_length, hop_size + 2 * padding)
|
|
|
        if hop_size < dilation:
            x = F.pad(x, (0, dilation), "constant", 0)
        x = x.unfold(3, dilation, dilation)  # (batch, in_channels, kernel_length, (hop_size + 2 * padding) // dilation, dilation)
        x = x[:, :, :, :, :hop_size]
        x = x.transpose(3, 4)  # (batch, in_channels, kernel_length, dilation, (hop_size + 2 * padding) // dilation)
        x = x.unfold(4, kernel_size, 1)  # (batch, in_channels, kernel_length, dilation, _, kernel_size)
|
|
|
        o = torch.einsum("bildsk,biokl->bolsd", x, kernel)  # contract each unfolded patch with its per-frame kernel
|
o = o.to(memory_format=torch.channels_last_3d) |
|
bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d) |
|
o = o + bias |
|
o = o.contiguous().view(batch, out_channels, -1) |
|
|
|
return o |
|
|
|
def remove_weight_norm(self): |
|
self.kernel_predictor.remove_weight_norm() |
|
parametrize.remove_parametrizations(self.convt_pre[1], "weight") |
|
for block in self.conv_blocks: |
|
parametrize.remove_parametrizations(block[1], "weight") |
|
|