import torch
import torch.nn.functional as F
import triton
import triton.language as tl

import fla.modules.fused_bitlinear as fused_bitlinear
from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous

sigmoid_fwd_codestring = """
template <typename T> T sigmoid_fwd(T x) {
    return 1.0f / (1.0f + ::exp(-float(x)));
}
"""
sigmoid_bwd_codestring = """
template <typename T> T sigmoid_bwd(T x, T g) {
    float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
    return float(g) * x_sigmoid * (1.0f - x_sigmoid);
}
"""

sigmoid_fwd = torch.cuda.jiterator._create_jit_fn(sigmoid_fwd_codestring)
sigmoid_bwd = torch.cuda.jiterator._create_jit_fn(sigmoid_bwd_codestring)


class SigmoidFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return sigmoid_fwd(x)

    @staticmethod
    def backward(ctx, dout):
        x, = ctx.saved_tensors
        return sigmoid_bwd(x, dout)


sigmoid = SigmoidFunction.apply
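
# Usage sketch (illustrative only; the jiterator kernels above run on CUDA tensors):
#   x = torch.randn(4, 16, device='cuda', requires_grad=True)
#   y = sigmoid(x)              # numerically matches torch.sigmoid(x)
#   y.sum().backward()          # x.grad == sigmoid(x) * (1 - sigmoid(x))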


@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32)
    ],
    key=['D']
)
@triton.jit
def logsigmoid_fwd_kernel(
    x,
    y,
    temperature,
    T: tl.constexpr,
    D: tl.constexpr,
    B: tl.constexpr
):
    i = tl.program_id(0)
    o_i = i * B + tl.arange(0, B)
    m_i = o_i < T

    b_x = tl.load(x + o_i, mask=m_i, other=0.).to(tl.float32)
    # numerically stable log-sigmoid: log(sigmoid(x)) = min(x, 0) - log(1 + exp(-|x|))
    b_m = tl.minimum(0., b_x)
    b_z = 1. + tl.exp(-tl.abs(b_x))
    b_y = (b_m - tl.log(b_z)) / temperature
    tl.store(y + o_i, b_y.to(y.dtype.element_ty), mask=m_i)


@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32)
    ],
    key=['D']
)
@triton.jit
def logsigmoid_bwd_kernel(
    x,
    dx,
    dy,
    temperature,
    T: tl.constexpr,
    D: tl.constexpr,
    B: tl.constexpr
):
    i = tl.program_id(0)
    o_i = i * B + tl.arange(0, B)
    m_i = o_i < T

    b_x = tl.load(x + o_i, mask=m_i, other=0.).to(tl.float32)
    b_dy = tl.load(dy + o_i, mask=m_i, other=0.).to(tl.float32)
    # d/dx log(sigmoid(x)) = 1 - sigmoid(x)
    b_dx = b_dy * (1. - tl.sigmoid(b_x)) / temperature
    tl.store(dx + o_i, b_dx.to(dx.dtype.element_ty), mask=m_i)


def logsigmoid_fwd(x: torch.Tensor, temperature: float = 1.) -> torch.Tensor:
    T, D = x.numel(), x.shape[-1]
    # pick a block size so the launch grid has roughly one program per SM
    B = triton.next_power_of_2(triton.cdiv(T, torch.cuda.get_device_properties(x.device).multi_processor_count))
    y = torch.empty_like(x)
    logsigmoid_fwd_kernel[(triton.cdiv(T, B),)](
        x=x,
        y=y,
        temperature=temperature,
        T=T,
        D=D,
        B=B
    )
    return y


def logsigmoid_bwd(x: torch.Tensor, dy: torch.Tensor, temperature: float = 1.) -> torch.Tensor:
    T, D = x.numel(), x.shape[-1]
    B = triton.next_power_of_2(triton.cdiv(T, torch.cuda.get_device_properties(x.device).multi_processor_count))
    dx = torch.empty_like(x)
    logsigmoid_bwd_kernel[(triton.cdiv(T, B),)](
        x=x,
        dx=dx,
        dy=dy,
        temperature=temperature,
        T=T,
        D=D,
        B=B
    )
    return dx


class LogSigmoidFunction(torch.autograd.Function):

    @staticmethod
    @contiguous
    def forward(ctx, x, temperature):
        ctx.save_for_backward(x,)
        ctx.temperature = temperature
        return logsigmoid_fwd(x, temperature)

    @staticmethod
    @contiguous
    def backward(ctx, dy):
        x, = ctx.saved_tensors
        return logsigmoid_bwd(x, dy, ctx.temperature), None


def logsigmoid(x: torch.Tensor, temperature: float = 1.) -> torch.Tensor:
    return LogSigmoidFunction.apply(x, temperature)
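
# Usage sketch (illustrative; computes log(sigmoid(x)) / temperature with the Triton
# kernels above, so x must live on a CUDA device):
#   x = torch.randn(2, 8, device='cuda', requires_grad=True)
#   y = logsigmoid(x, 2.0)      # close to F.logsigmoid(x) / 2.0
#   y.sum().backward()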


swish_fwd_codestring = """
template <typename T> T swish_fwd(T x) {
    float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
    return float(x) * x_sigmoid;
}
"""
swish_bwd_codestring = """
template <typename T> T swish_bwd(T x, T g) {
    float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
    return float(g) * x_sigmoid * (1.0f - float(x) * x_sigmoid + float(x));
}
"""

swish_fwd = torch.cuda.jiterator._create_jit_fn(swish_fwd_codestring)
swish_bwd = torch.cuda.jiterator._create_jit_fn(swish_bwd_codestring)


class SwishFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return swish_fwd(x)

    @staticmethod
    def backward(ctx, dout):
        x, = ctx.saved_tensors
        return swish_bwd(x, dout)


swish = SwishFunction.apply
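
# Usage sketch (illustrative; swish(x) = x * sigmoid(x), i.e. SiLU):
#   x = torch.randn(4, 16, device='cuda', requires_grad=True)
#   y = swish(x)                # matches F.silu(x)
#   y.sum().backward()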


# The fused GeLU below uses the tanh approximation
#   gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
# with sqrt(2/pi) ~= 0.79788456.
@torch.jit.script
def bias_gelu(y, bias):
    x = bias + y
    return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=y.dtype)


@torch.jit.script
def bias_gelu_bwd(g, y, bias):
    """Assume that y has shape (B, D) and bias has shape (D)"""
    x = bias + y
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # 0.1070322243 ~= 3 * 0.044715 * sqrt(2/pi)
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
        1 + tanh_out
    )
    grad_y = ff * g
    return grad_y.to(dtype=y.dtype), grad_y.sum(dim=(0), dtype=bias.dtype)


class GeLUFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, bias):
        ctx.save_for_backward(input, bias)
        return bias_gelu(input, bias)

    @staticmethod
    def backward(ctx, grad_output):
        input, bias = ctx.saved_tensors
        # bias_gelu_bwd returns (grad wrt input, grad wrt bias)
        grad_input, grad_bias = bias_gelu_bwd(grad_output, input, bias)
        return grad_input, grad_bias


bias_gelu_impl = GeLUFunction.apply
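
# Usage sketch (illustrative; fuses the bias addition into the tanh-approximate GeLU):
#   x = torch.randn(8, 64, requires_grad=True)
#   b = torch.randn(64, requires_grad=True)
#   y = bias_gelu_impl(x, b)    # close to F.gelu(x + b, approximate='tanh')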


@torch.jit.script
def gelu_fwd(x):
    return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=x.dtype)


@torch.jit.script
def gelu_bwd(g, x):
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
        1 + tanh_out
    )
    return (ff * g).to(dtype=x.dtype)


class FastGeLUFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return gelu_fwd(input)

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        tmp = gelu_bwd(grad_output, input)
        return tmp


fast_gelu_impl = FastGeLUFunction.apply
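
# Usage sketch (illustrative; bias-free tanh-approximate GeLU):
#   y = fast_gelu_impl(torch.randn(8, 64, requires_grad=True))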


@torch.jit.script
def relu_bwd(g, x):
    return torch.where(x >= 0, g, 0.0).to(dtype=x.dtype)


@torch.jit.script
def sqrelu_fwd(x):
    r = F.relu(x)
    return (r * r).to(dtype=x.dtype)


@torch.jit.script
def sqrelu_bwd(g, x):
    return (2.0 * g * F.relu(x)).to(dtype=x.dtype)


class SquaredReLUFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return sqrelu_fwd(input)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        return sqrelu_bwd(grad_output, input)


sqrelu = SquaredReLUFunction.apply
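
# Usage sketch (illustrative; squared ReLU):
#   x = torch.randn(8, 64, requires_grad=True)
#   y = sqrelu(x)               # matches F.relu(x).square()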


swiglu_fwd_codestring = """
template <typename T> T swiglu_fwd(T x, T y) {
    return float(x) * float(y) / (1.0f + ::exp(-float(x)));
}
"""
swiglu_bwd_codestring = """
template <typename T> T swiglu_bwd(T x, T y, T g, T& dx, T& dy) {
    float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
    dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y);
    dy = float(x) * x_sigmoid * float(g);
}
"""

swiglu_bwd_with_output_codestring = """
template <typename T> T swiglu_bwd_with_output(T x, T y, T g, T& dx, T& dy, T& z) {
    float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
    float x_swish = float(x) * x_sigmoid;
    dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y);
    dy = x_swish * float(g);
    z = x_swish * float(y);
}
"""

swiglu_fwd = torch.cuda.jiterator._create_jit_fn(swiglu_fwd_codestring)
swiglu_bwd = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_codestring, num_outputs=2)
swiglu_bwd_with_output = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_with_output_codestring, num_outputs=3)


class SwiGLUFunction(torch.autograd.Function):
    r"""
    Swish-Gated Linear Unit (SwiGLU) function.

    .. math::
        \text{SwiGLU}(x, y) = \text{swish}(x) * y = \frac{x}{1 + \exp(-x)} * y
    """

    @staticmethod
    def forward(ctx, x, y):
        ctx.save_for_backward(x, y)
        return swiglu_fwd(x, y)

    @staticmethod
    def backward(ctx, dout):
        x, y = ctx.saved_tensors
        return swiglu_bwd(x, y, dout)


class SwiGLULinearFunction(torch.autograd.Function):
    r"""
    Swish-Gated Linear Unit (SwiGLU) function followed by a linear transformation.

    .. math::
        \text{SwiGLULinear}(x, y, W, b) = (\text{swish}(x) * y) W + b

    This wrapper discards the intermediate result of SwiGLU(x, y) to save memory;
    it is recomputed from x and y in the backward pass.
    """

    @staticmethod
    @autocast_custom_fwd
    def forward(ctx, x, y, weight, bias):
        z = swiglu_fwd(x, y)
        out = F.linear(z, weight, bias)

        ctx.save_for_backward(x, y, weight)
        ctx.linear_bias_is_none = bias is None
        return out

    @staticmethod
    @autocast_custom_bwd
    def backward(ctx, dout, *args):
        x, y, weight = ctx.saved_tensors
        dout = dout.reshape(-1, dout.shape[-1])
        dz = F.linear(dout, weight.t()).view_as(x)
        dx, dy, z = swiglu_bwd_with_output(x, y, dz)
        dlinear_weight = torch.einsum("bo,bi->oi", dout, z.reshape(-1, z.shape[-1]))
        dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
        return dx, dy, dlinear_weight, dlinear_bias


class SwiGLUBitLinearFunction(torch.autograd.Function):
    r"""
    Swish-Gated Linear Unit (SwiGLU) function followed by a BitLinear transformation
    (`fla.modules.fused_bitlinear.bit_linear`) in place of the dense ``F.linear``.

    This wrapper discards the intermediate result of SwiGLU(x, y) to save memory;
    it is recomputed from x and y in the backward pass.
    """

    @staticmethod
    @autocast_custom_fwd
    def forward(ctx, x, y, weight, bias):
        z = swiglu_fwd(x, y)
        out = fused_bitlinear.bit_linear(z, weight, bias)

        ctx.save_for_backward(x, y, weight)
        ctx.linear_bias_is_none = bias is None
        return out

    @staticmethod
    @autocast_custom_bwd
    def backward(ctx, dout, *args):
        x, y, weight = ctx.saved_tensors
        dout = dout.reshape(-1, dout.shape[-1])
        dz = fused_bitlinear.bit_linear(dout, weight.t()).view_as(x)
        dx, dy, z = swiglu_bwd_with_output(x, y, dz)
        dlinear_weight = torch.einsum("bo,bi->oi", dout, z.reshape(-1, z.shape[-1]))
        dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
        return dx, dy, dlinear_weight, dlinear_bias


swiglu = SwiGLUFunction.apply

swiglu_linear = SwiGLULinearFunction.apply

swiglu_bitlinear = SwiGLUBitLinearFunction.apply
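
# Usage sketch (illustrative; the shapes below are hypothetical):
#   x = torch.randn(2, 32, 512, device='cuda')
#   y = torch.randn(2, 32, 512, device='cuda')
#   w = torch.randn(1024, 512, device='cuda')
#   z = swiglu(x, y)                    # swish(x) * y
#   o = swiglu_linear(x, y, w, None)    # (swish(x) * y) @ w.T, without storing z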


ACT2FN = {
    'relu': F.relu,
    'sigmoid': sigmoid,
    'logsigmoid': logsigmoid,
    'silu': swish,
    'swish': swish,
    'sqrelu': sqrelu,
    'gelu': fast_gelu_impl,
    'bias_gelu': bias_gelu_impl,
}
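
# Lookup sketch (illustrative):
#   act_fn = ACT2FN['swish']
#   out = act_fn(torch.randn(8, 64, device='cuda'))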