|
from torch import nn |
|
|
|
from .normalization import LayerNorm |
|
|
|
|
|
class GatedConvBlock(nn.Module):
    """Gated convolutional block as in https://arxiv.org/pdf/1612.08083.pdf

    Stacks ``num_layers`` residual sub-layers. Each one computes
    ``o + GLU(LayerNorm(Conv1d(dropout(o) * x_mask)))``.

    Args:
        in_out_channels (int): number of input/output channels.
        kernel_size (int): convolution kernel size. Must be odd so that
            ``padding=kernel_size // 2`` keeps the time axis length unchanged,
            which the residual addition requires.
        dropout_p (float): dropout rate applied to the input of every layer.
        num_layers (int): number of stacked gated convolution layers.

    Raises:
        ValueError: if ``kernel_size`` is even.
    """

    def __init__(self, in_out_channels, kernel_size, dropout_p, num_layers):
        super().__init__()
        if kernel_size % 2 == 0:
            # With an even kernel, padding=kernel_size // 2 grows the time axis
            # by one frame, and the residual addition in forward() would then
            # fail with a shape mismatch. Reject it here with a clear message.
            raise ValueError(
                f"kernel_size must be odd to preserve sequence length, got {kernel_size}"
            )

        self.dropout_p = dropout_p
        self.num_layers = num_layers

        self.conv_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        # NOTE(review): never used in this class; kept so that existing
        # references to `.layers` keep working. An empty ModuleList registers
        # no parameters, so it does not affect the state dict.
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            # Each conv doubles the channel count; the GLU in forward() halves
            # it back to in_out_channels.
            self.conv_layers.append(
                nn.Conv1d(
                    in_out_channels,
                    2 * in_out_channels,
                    kernel_size,
                    padding=kernel_size // 2,
                )
            )
            self.norm_layers.append(LayerNorm(2 * in_out_channels))

    def forward(self, x, x_mask):
        """Apply the stacked gated residual layers.

        Args:
            x (Tensor): input laid out as ``[B, C, T]`` (the layout ``nn.Conv1d``
                expects), with ``C == in_out_channels``.
            x_mask (Tensor): multiplied into each layer's conv input. It must
                broadcast against ``x``; presumably ``[B, 1, T]``, to be
                confirmed with callers.

        Returns:
            Tensor: same shape as ``x``. The final output is NOT multiplied by
            ``x_mask``, so masked frames may hold non-zero values.
        """
        o = x
        for conv, norm in zip(self.conv_layers, self.norm_layers):
            h = nn.functional.dropout(o, p=self.dropout_p, training=self.training)
            # Apply the mask right before the convolution so that masked frames
            # are zeroed in the conv input.
            h = conv(h * x_mask)
            h = norm(h)
            # GLU splits the 2*C channels into halves (a, b) and returns a * sigmoid(b).
            h = nn.functional.glu(h, dim=1)
            o = o + h  # residual connection
        return o