from typing import Optional, Tuple

import torch

import causal_conv1d_cuda


@torch.library.custom_op(
    "mamba_causal_conv1d::causal_conv1d_fwd",
    mutates_args=(),
    device_types="cuda",
)
def causal_conv1d_fwd(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    seq_idx: Optional[torch.Tensor] = None,
    activation: Optional[str] = None,
) -> torch.Tensor:
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")

    # The CUDA kernel requires either the channel or the sequence dimension to be contiguous.
    if x.stride(2) != 1 and x.stride(1) != 1:
        x = x.contiguous()
    bias = bias.contiguous() if bias is not None else None
    seq_idx = seq_idx.contiguous() if seq_idx is not None else None

    use_activation = activation in ["silu", "swish"]

    out = causal_conv1d_cuda.causal_conv1d_fwd(
        x, weight, bias, seq_idx, None, None, use_activation
    )
    return out


@causal_conv1d_fwd.register_fake
def _causal_conv1d_fwd_fake(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    seq_idx: Optional[torch.Tensor] = None,
    activation: Optional[str] = None,
) -> torch.Tensor:
    # Fake (meta) kernel: the output matches x in shape, dtype, and device.
    torch._check(x.shape[-2] == weight.shape[0])
    return torch.empty_like(x)


@torch.library.custom_op(
    "mamba_causal_conv1d::causal_conv1d_bwd",
    mutates_args=(),
    device_types="cuda",
)
def causal_conv1d_bwd(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor],
    dout: torch.Tensor,
    seq_idx: Optional[torch.Tensor],
    activation: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    if dout.stride(2) != 1 and dout.stride(1) != 1:
        dout = dout.contiguous()

    dx, dweight, dbias, _ = causal_conv1d_cuda.causal_conv1d_bwd(
        x, weight, bias, dout, seq_idx, None, None, None, False, activation
    )

    # A custom op cannot return None, so use an empty placeholder when there is no bias.
    dbias = dbias if bias is not None else torch.empty((0,), device=dout.device)
    return dx, dweight, dbias


@causal_conv1d_bwd.register_fake
def _causal_conv1d_bwd_fake(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor],
    dout: torch.Tensor,
    seq_idx: Optional[torch.Tensor],
    activation: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    return (
        torch.empty_like(x),
        torch.empty_like(weight),
        # Match the real op, which returns an empty placeholder instead of None.
        torch.empty_like(bias) if bias is not None else torch.empty((0,), device=x.device),
    )


def causal_conv1d_setup_context(ctx, inputs, output):
    x, weight, bias, seq_idx, activation = inputs
    ctx.activation = activation in ["silu", "swish"]
    ctx.save_for_backward(x, weight, bias, seq_idx)


def causal_conv1d_bwd_bridge(ctx, dout):
    x, weight, bias, seq_idx = ctx.saved_tensors
    dx, dweight, dbias = causal_conv1d_bwd(x, weight, bias, dout, seq_idx, ctx.activation)

    # Autograd expects None (not an empty placeholder) when there is no bias gradient.
    dbias = dbias if bias is not None else None
    return dx, dweight, dbias, None, None


torch.library.register_autograd(
    "mamba_causal_conv1d::causal_conv1d_fwd",
    causal_conv1d_bwd_bridge,
    setup_context=causal_conv1d_setup_context,
)


def causal_conv1d_fn(x, weight, bias=None, seq_idx=None, activation=None):
    return causal_conv1d_fwd(x, weight, bias, seq_idx, activation)


@torch.library.custom_op(
    "mamba_causal_conv1d::causal_conv1d_update",
    # The update kernel rolls the cache forward in place, so declare the mutation.
    mutates_args=("conv_state",),
    device_types="cuda",
)
def causal_conv1d_update_fwd(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    activation: Optional[str] = None,
    cache_seqlens: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    x: (batch, dim) or (batch, dim, seqlen)
    conv_state: (batch, dim, state_len), where state_len >= width - 1
    weight: (dim, width)
    bias: (dim,)
    cache_seqlens: (batch,), dtype int32.
        If not None, conv_state is treated as a circular buffer: x is copied into
        conv_state starting at index cache_seqlens % state_len.

    out: (batch, dim) or (batch, dim, seqlen)

    A small smoke test of this path is included under __main__ at the bottom of the file.
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    use_activation = activation in ["silu", "swish"]
    unsqueeze = x.dim() == 2
    if unsqueeze:
        x = x.unsqueeze(-1)
    out = causal_conv1d_cuda.causal_conv1d_update(
        x, conv_state, weight, bias, use_activation, cache_seqlens
    )
    if unsqueeze:
        out = out.squeeze(-1)
    return out


@causal_conv1d_update_fwd.register_fake
def _causal_conv1d_update_fwd_fake(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    activation: Optional[str] = None,
    cache_seqlens: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    return torch.empty_like(x)


def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
    return causal_conv1d_update_fwd(x, conv_state, weight, bias, activation, cache_seqlens)


if __name__ == "__main__":
    from causal_conv1d import causal_conv1d_fn as causal_conv1d_fn_ref

    torch.manual_seed(0)

    x = torch.randn(8, 32, 16, device="cuda", requires_grad=True)
    weight = torch.randn(32, 3, device="cuda", requires_grad=True)
    bias = None

    print("Custom Implementation")
    out = causal_conv1d_fn(x, weight, bias, activation="silu")
    out.sum().backward()
    print(out.min(), out.max(), out.mean(), out.std())
    print(x.grad.min(), x.grad.max(), x.grad.mean(), x.grad.std())
    print(weight.grad.min(), weight.grad.max(), weight.grad.mean(), weight.grad.std())
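
    # Sanity-check the custom op registration (fake kernel, autograd, aot dispatch).
    # This is a sketch and assumes PyTorch >= 2.4, where torch.library.opcheck is available.
    torch.library.opcheck(causal_conv1d_fwd, (x, weight, bias, None, "silu"))
    print("opcheck passed")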

    x.grad.zero_()
    weight.grad.zero_()
    compiled_conv1d = torch.compile(causal_conv1d_fn)
    print(compiled_conv1d)

    print("Compiled Implementation")
    out = compiled_conv1d(x, weight, bias, activation="silu")
    out.sum().backward()
    print(out.min(), out.max(), out.mean(), out.std())
    print(x.grad.min(), x.grad.max(), x.grad.mean(), x.grad.std())
    print(weight.grad.min(), weight.grad.max(), weight.grad.mean(), weight.grad.std())

    print("Reference Implementation")
    x.grad.zero_()
    weight.grad.zero_()
    out = causal_conv1d_fn_ref(x, weight, bias, activation="silu")
    out.sum().backward()
    print(out.min(), out.max(), out.mean(), out.std())
    print(x.grad.min(), x.grad.max(), x.grad.mean(), x.grad.std())
    print(weight.grad.min(), weight.grad.max(), weight.grad.mean(), weight.grad.std())
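
    # A minimal decode-step smoke test for the update path. This is a sketch: it assumes
    # a state length of width - 1, the smallest value the docstring above permits.
    print("Update Kernel")
    batch, dim, width = 8, 32, weight.shape[1]
    conv_state = torch.zeros(batch, dim, width - 1, device="cuda")
    x_step = torch.randn(batch, dim, device="cuda")
    with torch.no_grad():
        out_step = causal_conv1d_update(x_step, conv_state, weight, bias, activation="silu")
    print(out_step.shape, out_step.min(), out_step.max())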