# FAcodecV2/modules/style_encoder.py
# Source: Hugging Face file viewer (uploader: Plachta, commit "Upload 69 files",
# revision a4d0945, verified, 2.89 kB).
from . import attentions
from torch import nn
import torch
from torch.nn import functional as F
class Mish(nn.Module):
    """Mish activation, computed element-wise as ``x * tanh(softplus(x))``."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Apply the activation to ``x`` and return a tensor of the same shape."""
        gate = F.softplus(x).tanh()
        return x * gate
class Conv1dGLU(nn.Module):
    """Conv1d + GLU (Gated Linear Unit) with a residual connection.

    For GLU refer to https://arxiv.org/abs/1612.08083.

    The convolution emits ``2 * out_channels`` channels that are split along
    the channel axis into a value half and a gate half; the output is
    ``input + dropout(value * sigmoid(gate))``.

    Args:
        in_channels: Number of channels of the input.
        out_channels: Number of channels after gating. Because the input is
            added back as a residual, this has to match ``in_channels``.
        kernel_size: Width of the convolution kernel. Use an odd value so the
            symmetric padding keeps the sequence length unchanged, which the
            residual addition requires.
        dropout: Dropout probability applied to the gated activation.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dropout):
        super().__init__()
        self.out_channels = out_channels
        # Derive the padding from the kernel width instead of hard-coding 2,
        # which only preserved the length for kernel_size == 5. For
        # kernel_size == 5 this still evaluates to 2.
        same_padding = (kernel_size - 1) // 2
        self.conv1 = nn.Conv1d(
            in_channels,
            2 * out_channels,
            kernel_size=kernel_size,
            padding=same_padding,
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return ``x + dropout(GLU(conv(x)))``; the output has the shape of ``x``."""
        residual = x
        x = self.conv1(x)
        # First half of the channels is the value, second half the gate.
        x1, x2 = torch.split(x, split_size_or_sections=self.out_channels, dim=1)
        x = x1 * torch.sigmoid(x2)
        x = residual + self.dropout(x)
        return x
class StyleEncoder(torch.nn.Module):
    """Encode a variable-length feature sequence into one vector per item.

    Pipeline: two 1x1 convolutions with Mish and dropout ("spectral"), two
    ``Conv1dGLU`` residual blocks ("temporal"), one multi-head self-attention
    layer with a residual connection, a 1x1 output convolution, and a masked
    average over the time axis.

    Args:
        in_dim: Channel count of the input features.
        hidden_dim: Channel count used by the internal layers.
        out_dim: Channel count of the returned vector.
    """

    def __init__(self, in_dim=513, hidden_dim=128, out_dim=256):
        super().__init__()

        # Original note: 513 for linear (spectrogram) input, 1024 for wav2vec 2.0.
        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.out_dim = out_dim
        self.kernel_size = 5
        self.n_head = 2
        self.dropout = 0.1

        # Per-frame (kernel size 1) processing of the input channels.
        self.spectral = nn.Sequential(
            nn.Conv1d(self.in_dim, self.hidden_dim, 1),
            Mish(),
            nn.Dropout(self.dropout),
            nn.Conv1d(self.hidden_dim, self.hidden_dim, 1),
            Mish(),
            nn.Dropout(self.dropout)
        )

        # Gated convolutions that mix information across neighbouring frames.
        self.temporal = nn.Sequential(
            Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
            Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
        )

        self.slf_attn = attentions.MultiHeadAttention(self.hidden_dim, self.hidden_dim, self.n_head, p_dropout = self.dropout, proximal_bias= False, proximal_init=True)
        self.atten_drop = nn.Dropout(self.dropout)
        self.fc = nn.Conv1d(self.hidden_dim, self.out_dim, 1)

    def forward(self, x, mask=None):
        """Compute the pooled style vector.

        Args:
            x: Input features with channels on dim 1 and time on dim 2
                (the layout ``nn.Conv1d`` consumes).
            mask: Optional validity mask, 1 for real frames and 0 for padding.
                The indexing below is consistent with a ``(batch, 1, time)``
                layout -- NOTE(review): confirm against callers. When omitted,
                every frame is treated as valid.

        Returns:
            Tensor with the time axis averaged out (``out_dim`` channels).
        """
        if mask is None:
            # The signature allows mask=None, but the multiplications below
            # need a tensor; an all-ones mask means "no padding".
            mask = x.new_ones(x.size(0), 1, x.size(2))

        # spectral
        x = self.spectral(x) * mask
        # temporal
        x = self.temporal(x) * mask
        # self-attention: pairwise mask is the outer product of the frame
        # mask with itself, so a pair is kept only when both frames are valid
        attn_mask = mask.unsqueeze(2) * mask.unsqueeze(-1)
        y = self.slf_attn(x, x, attn_mask=attn_mask)
        x = x + self.atten_drop(y)
        # fc
        x = self.fc(x)
        # temporal average pooling
        w = self.temporal_avg_pool(x, mask=mask)
        return w

    def temporal_avg_pool(self, x, mask=None):
        """Average ``x`` over dim 2, counting only frames where ``mask`` is set.

        Args:
            x: Tensor whose dim 2 is the axis to average over.
            mask: Optional mask broadcastable to ``x`` with the same dim-2
                length; non-zero marks frames that take part in the average.
                Without a mask a plain mean over dim 2 is returned.

        Returns:
            ``x`` with dim 2 reduced away.
        """
        if mask is None:
            return torch.mean(x, dim=2)

        # Guard against division by zero for a fully masked sequence.
        valid_frames = mask.sum(dim=2).clamp(min=1)
        # Zero the padded frames before summing: they are generally non-zero
        # here (e.g. the bias of ``fc``) and would otherwise skew the average.
        summed = (x * mask).sum(dim=2)
        return torch.div(summed, valid_frames)