from torch import nn


class ZeroTemporalPad(nn.Module):
    """Zero-pad a sequence along its temporal (second-to-last) dimension.

    The total amount of padding is ``dilation * (kernel_size - 1)``. It is
    split between the beginning and the end of the sequence; when the total
    is odd, the extra step goes to the end.

    Args:
        kernel_size (int): kernel size of the convolution being compensated.
        dilation (int): dilation of the convolution being compensated.
    """

    def __init__(self, kernel_size, dilation):
        super().__init__()
        total_pad = dilation * (kernel_size - 1)
        begin = total_pad // 2
        end = total_pad - begin
        # ZeroPad2d takes (left, right, top, bottom): the last dimension is
        # left untouched and only the second-to-last dimension is padded.
        self.pad_layer = nn.ZeroPad2d((0, 0, begin, end))

    def forward(self, x):
        """Return ``x`` zero-padded along its second-to-last dimension."""
        return self.pad_layer(x)


class Conv1dBN(nn.Module):
    """1d convolutional with batch norm.

    conv1d -> relu -> BN blocks.

    Note:
        Batch normalization is applied after ReLU regarding the original
        implementation.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        kernel_size (int): kernel size for convolutional filters.
        dilation (int): dilation for convolution layers.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation):
        super().__init__()
        padding = dilation * (kernel_size - 1)
        pad_s = padding // 2
        pad_e = padding - pad_s
        self.conv1d = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation)
        # Uneven left and right padding of the last dimension. It restores the
        # ``dilation * (kernel_size - 1)`` steps removed by the unpadded conv.
        self.pad = nn.ZeroPad2d((pad_s, pad_e, 0, 0))
        self.norm = nn.BatchNorm1d(out_channels)

    def forward(self, x):
        """Apply conv1d -> zero pad -> relu -> batch norm to ``x``."""
        o = self.conv1d(x)
        # NOTE: padding is applied to the convolution *output*, not its input;
        # this ordering is kept as-is because changing it alters the results.
        o = self.pad(o)
        o = nn.functional.relu(o)
        return self.norm(o)


class Conv1dBNBlock(nn.Module):
    """1d convolutional block with batch norm.

    It is a set of conv1d -> relu -> BN blocks.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        hidden_channels (int): number of inner convolution channels.
        kernel_size (int): kernel size for convolutional filters.
        dilation (int): dilation for convolution layers.
        num_conv_blocks (int, optional): number of convolutional blocks.
            Defaults to 2.
    """

    def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dilation, num_conv_blocks=2):
        super().__init__()
        last_idx = num_conv_blocks - 1
        # The first layer consumes ``in_channels``, the last one produces
        # ``out_channels`` and every layer in between uses ``hidden_channels``.
        layers = [
            Conv1dBN(
                in_channels if idx == 0 else hidden_channels,
                out_channels if idx == last_idx else hidden_channels,
                kernel_size,
                dilation,
            )
            for idx in range(num_conv_blocks)
        ]
        self.conv_bn_blocks = nn.Sequential(*layers)

    def forward(self, x):
        """
        Shapes:
            x: (B, D, T)
        """
        return self.conv_bn_blocks(x)


class ResidualConv1dBNBlock(nn.Module):
    """Residual Convolutional Blocks with BN.

    Each block has 'num_conv_blocks' conv layers and 'num_res_blocks' such
    blocks are connected with residual connections.

    conv_block = (conv1d -> relu -> bn) x 'num_conv_blocks'
    residual_conv_block = (x -> conv_block -> + ->) x 'num_res_blocks'
                           |                  ^
                           ' - - - - - - - - -'

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        hidden_channels (int): number of inner convolution channels.
        kernel_size (int): kernel size for convolutional filters.
        dilations (list): dilations for each convolution layer.
        num_res_blocks (int, optional): number of residual blocks.
            Defaults to 13.
        num_conv_blocks (int, optional): number of convolutional blocks in
            each residual block. Defaults to 2.

    Raises:
        AssertionError: if ``len(dilations)`` differs from ``num_res_blocks``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        kernel_size,
        dilations,
        num_res_blocks=13,
        num_conv_blocks=2,
    ):
        super().__init__()
        # Kept as an assert so callers still see AssertionError on a mismatch.
        assert len(dilations) == num_res_blocks, (
            f"expected {num_res_blocks} dilations, got {len(dilations)}"
        )
        last_idx = len(dilations) - 1
        self.res_blocks = nn.ModuleList()
        for idx, dilation in enumerate(dilations):
            block = Conv1dBNBlock(
                in_channels if idx == 0 else hidden_channels,
                out_channels if idx == last_idx else hidden_channels,
                hidden_channels,
                kernel_size,
                dilation,
                num_conv_blocks,
            )
            self.res_blocks.append(block)

    def forward(self, x, x_mask=None):
        """Run ``x`` through every residual block, masking after each one.

        Args:
            x: input tensor fed to the first residual block.
            x_mask: optional mask multiplied with the input and with the
                output of every residual block. ``None`` is equivalent to
                multiplying by ``1.0``.

        Returns:
            The masked output of the last residual block.
        """
        mask = 1.0 if x_mask is None else x_mask
        o = x * mask
        for block in self.res_blocks:
            res = o
            # The residual sum needs the block output to be shape-compatible
            # with the block input.
            o = (block(o) + res) * mask
        return o