import math

import torch
import torch.nn as nn
import torch.nn.functional as F

import pflow.models.components.commons as commons
import pflow.models.components.vits_modules as modules
class Mish(nn.Module):
    # Mish activation: x * tanh(softplus(x)).
    def forward(self, x):
        return x * torch.tanh(F.softplus(x))
class SinusoidalPosEmb(nn.Module):
    # Standard sinusoidal embedding for per-batch timesteps.
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x, scale=1000):
        if x.ndim < 1:
            x = x.unsqueeze(0)
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
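
# Usage sketch (not part of the original file; shapes inferred from the code above):
# for a batch of 4 timesteps and dim=128, the first 64 output channels are sine
# terms and the last 64 are cosine terms.
#   emb = SinusoidalPosEmb(128)
#   t = torch.rand(4)            # one timestep per batch element
#   pe = emb(t, scale=1000)      # -> torch.Size([4, 128])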
class VitsWNDecoder(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
        pe_scale=1000,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.pe_scale = pe_scale

        # Timestep embedding followed by a small MLP, as in diffusion-style estimators.
        self.time_pos_emb = SinusoidalPosEmb(hidden_channels * 2)
        dim = hidden_channels * 2
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            Mish(),
            nn.Linear(dim * 4, dim),
        )

        # 1x1 pre-projection shared by the input x and the condition mu, a WaveNet (WN)
        # stack over their concatenation, and a 1x1 output projection.
        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = modules.WN(
            hidden_channels * 2,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels * 2, out_channels, 1)
    def forward(self, x, x_mask, mu, t, *args, **kwargs):
        # Expected shapes (inferred): x, mu: [B, in_channels, T]; x_mask: [B, 1, T]; t: [B].
        # x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)),
        #                          1).to(x.dtype)
        t = self.time_pos_emb(t, scale=self.pe_scale)
        t = self.mlp(t)
        x = self.pre(x) * x_mask
        mu = self.pre(mu)
        x = torch.cat((x, mu), dim=1)
        # The MLP-processed timestep embedding conditions the WN stack via its g input.
        x = self.enc(x, x_mask, g=t)
        stats = self.proj(x) * x_mask
        return stats
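
# Shape walkthrough of forward (inferred from the code above, not part of the
# original file): x is the noisy input, mu the conditioning sequence, t the timestep.
#   pre(x) * x_mask      -> [B, hidden_channels, T]
#   pre(mu)              -> [B, hidden_channels, T]
#   cat((x, mu), dim=1)  -> [B, 2 * hidden_channels, T]
#   enc(x, x_mask, g=t)  -> [B, 2 * hidden_channels, T]   (WN preserves its channel count)
#   proj(x) * x_mask     -> [B, out_channels, T]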