|
import torch |
|
from torch.nn import functional as F |
|
|
|
|
|
class ResidualBlock(torch.nn.Module):
    """Residual block module in WaveNet.

    A dilated 1-D convolution feeds a gated activation (``tanh * sigmoid``),
    optionally conditioned on an auxiliary feature sequence ``c``. The gated
    output is projected twice: once into a skip output and once back to the
    residual channels, where it is added to the block input.

    Args:
        kernel_size: Kernel size of the dilated convolution. Must be odd when
            ``use_causal_conv`` is False.
        res_channels: Number of residual (input and residual-output) channels.
        gate_channels: Channels produced by the dilated convolution. They are
            split in half for the tanh / sigmoid branches, so this must be even.
        skip_channels: Number of channels of the skip output.
        aux_channels: Number of auxiliary conditioning channels. When 0, no
            conditioning projection is created and ``forward`` requires
            ``c is None``.
        dropout: Dropout probability applied to the block input (training only).
        dilation: Dilation of the convolution.
        bias: Whether the dilated conv and the output/skip 1x1 convs use a bias.
        use_causal_conv: If True, each output step depends only on the current
            and past input steps.
    """

    def __init__(
        self,
        kernel_size=3,
        res_channels=64,
        gate_channels=128,
        skip_channels=64,
        aux_channels=80,
        dropout=0.0,
        dilation=1,
        bias=True,
        use_causal_conv=False,
    ):
        super().__init__()
        # The gated activation splits the conv output into two equal halves;
        # an odd channel count would otherwise fail later inside forward().
        assert gate_channels % 2 == 0, "gate_channels must be an even number."
        self.dropout = dropout

        if use_causal_conv:
            # Pad (k - 1) * d on both sides; forward() crops the trailing
            # steps, which would otherwise see "future" samples.
            padding = (kernel_size - 1) * dilation
        else:
            assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."
            # Symmetric "same" padding: output length equals input length.
            padding = (kernel_size - 1) // 2 * dilation
        self.use_causal_conv = use_causal_conv

        # Dilated convolution producing both gate halves at once.
        self.conv = torch.nn.Conv1d(
            res_channels,
            gate_channels,
            kernel_size,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )

        # Optional local-conditioning projection onto the gate channels.
        if aux_channels > 0:
            self.conv1x1_aux = torch.nn.Conv1d(aux_channels, gate_channels, 1, bias=False)
        else:
            self.conv1x1_aux = None

        # After gating, half of the gate channels remain.
        gate_out_channels = gate_channels // 2
        self.conv1x1_out = torch.nn.Conv1d(gate_out_channels, res_channels, 1, bias=bias)
        self.conv1x1_skip = torch.nn.Conv1d(gate_out_channels, skip_channels, 1, bias=bias)

    def forward(self, x, c):
        """Run the residual block.

        Args:
            x: Input tensor, B x res_channels x T.
            c: Auxiliary conditioning, B x aux_channels x T (same T as ``x``),
                or None for no conditioning.

        Returns:
            Tuple ``(out, skip)`` where ``out`` is B x res_channels x T (scaled
            residual sum) and ``skip`` is B x skip_channels x T.
        """
        residual = x
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv(x)

        # Causal mode pads both sides; keep only the first T steps so no
        # output step depends on future input.
        x = x[:, :, : residual.size(-1)] if self.use_causal_conv else x

        # Split into the tanh and sigmoid halves of the gated activation.
        splitdim = 1
        xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim)

        # Local conditioning: add the projected aux features to both halves.
        if c is not None:
            assert self.conv1x1_aux is not None, (
                "Conditioning features were given but the block was built with aux_channels=0."
            )
            c = self.conv1x1_aux(c)
            ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim)
            xa, xb = xa + ca, xb + cb

        x = torch.tanh(xa) * torch.sigmoid(xb)

        # Skip connection output.
        s = self.conv1x1_skip(x)

        # Residual connection, scaled by sqrt(0.5) so the sum of two roughly
        # unit-variance terms stays roughly unit-variance across stacked blocks.
        x = (self.conv1x1_out(x) + residual) * (0.5**0.5)

        return x, s
|
|