# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license. import torch from torch import nn from torch import pow from torch import sin from torch.nn import Parameter class SnakeBeta(nn.Module): """ A modified Snake function which uses separate parameters for the magnitude of the periodic components Shape: - Input: (B, C, T) - Output: (B, C, T), same shape as the input Parameters: - alpha - trainable parameter that controls frequency - beta - trainable parameter that controls magnitude References: - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: https://arxiv.org/abs/2006.08195 """ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): """ Initialization. INPUT: - in_features: shape of the input - alpha - trainable parameter that controls frequency - beta - trainable parameter that controls magnitude alpha is initialized to 1 by default, higher values = higher-frequency. beta is initialized to 1 by default, higher values = higher-magnitude. alpha will be trained along with the rest of your model. """ super(SnakeBeta, self).__init__() self.in_features = in_features # initialize alpha self.alpha_logscale = alpha_logscale if self.alpha_logscale: # log scale alphas initialized to zeros self.alpha = Parameter(torch.zeros(in_features) * alpha) self.beta = Parameter(torch.zeros(in_features) * alpha) else: # linear scale alphas initialized to ones self.alpha = Parameter(torch.ones(in_features) * alpha) self.beta = Parameter(torch.ones(in_features) * alpha) self.alpha.requires_grad = alpha_trainable self.beta.requires_grad = alpha_trainable self.no_div_by_zero = 0.000000001 def forward(self, x): """ Forward pass of the function. Applies the function to the input elementwise. SnakeBeta ∶= x + 1/b * sin^2 (xa) """ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] beta = self.beta.unsqueeze(0).unsqueeze(-1) if self.alpha_logscale: alpha = torch.exp(alpha) beta = torch.exp(beta) x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2) return x