import torch
import torch.nn.functional as F

# Run TorchScript with the legacy (non-profiling) executor and allow kernel
# fusion on both CPU and GPU so the @torch.jit.script functions below can be fused.
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)


def get_activation(neox_args):
    """Retrieves the activation function specified in neox_args and whether or not the activation is gated."""
    # The gated variants (geglu, reglu, swiglu, ...) only return the elementwise
    # non-linearity here; `is_gated` signals to the caller that GLU-style gating
    # should be applied around it.
    is_gated = False
    if neox_args.activation == "geglu":
        is_gated = True
        activation_func = F.gelu
    elif neox_args.activation == "reglu":
        is_gated = True
        activation_func = F.relu
    elif neox_args.activation == "bilinear":
        is_gated = True
        activation_func = lambda x: x
    elif neox_args.activation == "swiglu":
        is_gated = True
        activation_func = swish
    elif neox_args.activation == "glu":
        is_gated = True
        activation_func = F.sigmoid
    elif neox_args.activation == "gelu":
        if neox_args.onnx_safe and neox_args.bias_gelu_fusion:
            raise ValueError("onnx_safe + bias_gelu_fusion not compatible")
        if neox_args.onnx_safe:
            activation_func = erf_gelu
        elif neox_args.bias_gelu_fusion:
            activation_func = bias_gelu_impl
        else:
            activation_func = F.gelu
    elif neox_args.activation == "relu":
        activation_func = F.relu
    elif neox_args.activation == "softsign":
        activation_func = F.softsign
    elif neox_args.activation == "swish":
        activation_func = swish
    elif neox_args.activation == "mish":
        activation_func = mish
    elif neox_args.activation == "silu":
        activation_func = F.silu
    else:
        raise ValueError(f"Activation function {neox_args.activation} not recognized")
    return activation_func, is_gated
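# A minimal usage sketch (illustrative only): `SimpleNamespace` stands in for the
# real NeoX argument object, and only the attributes read above are populated.
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(activation="swiglu", onnx_safe=False, bias_gelu_fusion=False)
#   activation_func, is_gated = get_activation(args)
#   # activation_func is `swish`, is_gated is True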
# Tanh approximation of GELU, fused with the bias add:
#   bias_gelu(bias, y) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))), with x = bias + y
# where sqrt(2/pi) ~= 0.79788456.
@torch.jit.script
def bias_gelu(bias, y):
    x = bias + y
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
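# Worked note (no extra functionality): the backward below is the exact analytic
# derivative of the tanh approximation above. With u = sqrt(2/pi) * (x + 0.044715 * x^3),
#   d/dx [0.5 * x * (1 + tanh(u))] = 0.5 * (1 + tanh(u)) + 0.5 * x * (1 - tanh(u)^2) * du/dx,
#   du/dx = 0.79788456 + 3 * 0.044715 * 0.79788456 * x^2 = 0.79788456 + 0.1070322243 * x^2,
# which is the expression assembled in `ff` below, scaled by the incoming gradient `g`.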
@torch.jit.script
def bias_gelu_back(g, bias, y):
    x = bias + y
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    ff = 0.5 * x * (
        (1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)
    ) + 0.5 * (1 + tanh_out)
    return ff * g


class GeLUFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, bias):
        ctx.save_for_backward(input, bias)
        return bias_gelu(bias, input)

    @staticmethod
    def backward(ctx, grad_output):
        input, bias = ctx.saved_tensors
        tmp = bias_gelu_back(grad_output, bias, input)
        # x = bias + input, so the same tensor is the gradient w.r.t. both arguments.
        return tmp, tmp


bias_gelu_impl = GeLUFunction.apply
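# Usage / sanity-check sketch (illustrative only; the shapes are made up):
#
#   x = torch.randn(16, 64, dtype=torch.double, requires_grad=True)
#   b = torch.randn(16, 64, dtype=torch.double, requires_grad=True)
#   out = bias_gelu_impl(x, b)                        # fused bias-add + gelu
#   torch.autograd.gradcheck(bias_gelu_impl, (x, b))  # custom backward vs. numerical grads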
# Exact, erf-based GELU written out explicitly; selected when `onnx_safe` is set.
# 1.41421 approximates sqrt(2).
@torch.jit.script
def erf_gelu(x):
    return (
        x
        * 0.5
        * (
            torch.erf(x / 1.41421).to(dtype=x.dtype)
            + torch.ones_like(x).to(dtype=x.dtype)
        )
    )


# Swish: x * sigmoid(beta * x). With the default beta of 1.0 this is SiLU.
@torch.jit.script
def swish(x, beta: float = 1.0):
    return x * torch.sigmoid(beta * x)


# Mish: x * tanh(softplus(x)).
@torch.jit.script
def mish(x):
    return x * torch.tanh(F.softplus(x))
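# Quick equivalence sketch (illustrative only, not executed here):
#
#   t = torch.randn(8)
#   torch.allclose(erf_gelu(t), F.gelu(t), atol=1e-5)  # both are the exact (erf) GELU
#   torch.allclose(swish(t), F.silu(t))                # swish with beta=1.0 is SiLU
#   torch.allclose(mish(t), F.mish(t))                 # matches the built-in Mish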