# causal_conv1d_compilable.py
from typing import Optional, Tuple
import torch
import causal_conv1d_cuda
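# torch.compile-friendly wrappers around the causal_conv1d CUDA kernels: each kernel is
# exposed as a torch.library custom op with a fake (meta) implementation for tracing and
# an autograd registration for the backward pass. `causal_conv1d_cuda` is the compiled
# extension shipped with the causal-conv1d package.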
# Causal Conv1D Forward Function
@torch.library.custom_op(
"mamba_causal_conv1d::causal_conv1d_fwd",
mutates_args=(),
device_types="cuda",
)
def causal_conv1d_fwd(
x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
seq_idx: Optional[torch.Tensor] = None,
activation: Optional[str] = None,
) -> torch.Tensor:
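    """
    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    seq_idx: (batch, seqlen), sequence indices for variable-length (packed) batches
    activation: None, "silu", or "swish"
    out: (batch, dim, seqlen)
    """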
# Ensure activation is valid
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
    # Make x contiguous if neither the length nor the channel dimension has stride 1
if x.stride(2) != 1 and x.stride(1) != 1:
x = x.contiguous()
# Make bias and seq_idx contiguous if they exist
bias = bias.contiguous() if bias is not None else None
seq_idx = seq_idx.contiguous() if seq_idx is not None else None
# Translate activation to bool for custom CUDA kernel
use_activation = activation in ["silu", "swish"]
# Call custom CUDA kernel for forward pass
out = causal_conv1d_cuda.causal_conv1d_fwd(
x, weight, bias, seq_idx, None, None, use_activation
)
return out
# Register a fake (meta) forward implementation so torch.compile can infer output
# shape and dtype without launching the CUDA kernel
@causal_conv1d_fwd.register_fake
def _causal_conv1d_fwd_fake(
x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
seq_idx: Optional[torch.Tensor] = None,
activation: Optional[str] = None,
) -> torch.Tensor:
    torch._check(x.shape[-2] == weight.shape[0], lambda: "x channel dim must match weight dim")
return torch.empty_like(x)
# Causal Conv1D Backward Function
@torch.library.custom_op(
"mamba_causal_conv1d::causal_conv1d_bwd",
mutates_args=(),
device_types="cuda",
)
def causal_conv1d_bwd(
x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor],
dout: torch.Tensor,
seq_idx: Optional[torch.Tensor],
activation: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
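    """
    Backward of causal_conv1d_fwd.
    dout: (batch, dim, seqlen), upstream gradient.
    Returns (dx, dweight, dbias); dbias is an empty placeholder tensor when bias is None.
    """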
    # Make dout contiguous if neither the length nor the channel dimension has stride 1
if dout.stride(2) != 1 and dout.stride(1) != 1:
dout = dout.contiguous()
# Call custom CUDA kernel for backward pass
dx, dweight, dbias, _ = causal_conv1d_cuda.causal_conv1d_bwd(
x, weight, bias, dout, seq_idx, None, None, None, False, activation
)
    # The op's schema declares three Tensor outputs, so return an empty placeholder
    # instead of None when bias is not provided
    dbias = dbias if bias is not None else torch.empty((0,), device=dout.device)
return dx, dweight, dbias
# Register a fake (meta) backward implementation for tracing
@causal_conv1d_bwd.register_fake
def _causal_conv1d_bwd_fake(
x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor],
dout: torch.Tensor,
seq_idx: Optional[torch.Tensor],
activation: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    return (
        torch.empty_like(x),
        torch.empty_like(weight),
        # Match the real op: an empty placeholder stands in for a missing bias gradient
        torch.empty_like(bias) if bias is not None else torch.empty((0,), device=x.device),
    )
# Setup context for autograd
def causal_conv1d_setup_context(ctx, inputs, output):
x, weight, bias, seq_idx, activation = inputs
ctx.activation = activation in ["silu", "swish"]
ctx.save_for_backward(x, weight, bias, seq_idx)
# Bridge for backward pass in autograd
def causal_conv1d_bwd_bridge(ctx, dout):
x, weight, bias, seq_idx = ctx.saved_tensors
dx, dweight, dbias = causal_conv1d_bwd(x, weight, bias, dout, seq_idx, ctx.activation)
    # Convert the empty placeholder back to None when bias was not provided
dbias = dbias if bias is not None else None
return dx, dweight, dbias, None, None
# Register custom autograd function
torch.library.register_autograd(
"mamba_causal_conv1d::causal_conv1d_fwd",
causal_conv1d_bwd_bridge,
setup_context=causal_conv1d_setup_context,
)
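
# The custom op is also reachable as torch.ops.mamba_causal_conv1d.causal_conv1d_fwd;
# with the registration above it is differentiable both in eager mode and under torch.compile.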
# Define a higher-level function to invoke the custom op
def causal_conv1d_fn(x, weight, bias=None, seq_idx=None, activation=None):
return causal_conv1d_fwd(x, weight, bias, seq_idx, activation)
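
# For intuition: the custom op above computes a depthwise, causally padded 1D convolution.
# A minimal eager-mode sketch of the same math (illustrative only, not used by the custom
# op; assumes x is (batch, dim, seqlen) and weight is (dim, width)):
#
#   import torch.nn.functional as F
#   def _causal_conv1d_ref(x, weight, bias=None, activation=None):
#       dim, width = weight.shape
#       out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
#       out = out[..., : x.shape[-1]]
#       return F.silu(out) if activation in ["silu", "swish"] else out

# Causal Conv1D Update Function (single decoding step against a rolling conv state)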
@torch.library.custom_op(
"mamba_causal_conv1d::causal_conv1d_update",
mutates_args=(),
device_types="cuda",
)
def causal_conv1d_update_fwd(
x: torch.Tensor,
conv_state: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
activation: Optional[str] = None,
cache_seqlens: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
x: (batch, dim) or (batch, dim, seqlen)
conv_state: (batch, dim, state_len), where state_len >= width - 1
weight: (dim, width)
bias: (dim,)
cache_seqlens: (batch,), dtype int32.
If not None, the conv_state is treated as a circular buffer.
The conv_state will be updated by copying x to the conv_state starting at the index
@cache_seqlens % state_len.
out: (batch, dim) or (batch, dim, seqlen)
"""
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
    use_activation = activation in ["silu", "swish"]
unsqueeze = x.dim() == 2
if unsqueeze:
x = x.unsqueeze(-1)
out = causal_conv1d_cuda.causal_conv1d_update(
        x, conv_state, weight, bias, use_activation, cache_seqlens
)
if unsqueeze:
out = out.squeeze(-1)
return out
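# Register a fake update pass for tracing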
@causal_conv1d_update_fwd.register_fake
def _causal_conv1d_update_fwd_fake(
x: torch.Tensor,
conv_state: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
activation: Optional[str] = None,
cache_seqlens: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return torch.empty_like(x)
def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
return causal_conv1d_update_fwd(x, conv_state, weight, bias, activation, cache_seqlens)
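
# Typical single-step decoding usage (illustrative sketch only; tensor sizes are
# assumptions consistent with the docstring above, not part of this module):
#
#   batch, dim, width = 8, 32, 3
#   weight = torch.randn(dim, width, device="cuda")
#   conv_state = torch.zeros(batch, dim, width, device="cuda")  # state_len >= width - 1
#   x_t = torch.randn(batch, dim, device="cuda")                # one token per step
#   y_t = causal_conv1d_update(x_t, conv_state, weight, None, activation="silu")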
# Test the implementation
if __name__ == "__main__":
from causal_conv1d import causal_conv1d_fn as causal_conv1d_fn_ref
torch.manual_seed(0)
x = torch.randn(8, 32, 16, device="cuda", requires_grad=True)
weight = torch.randn(32, 3, device="cuda", requires_grad=True)
    bias = None  # torch.randn(32, device="cuda", requires_grad=True)
# Test the forward and backward pass
print("Custom Implementation")
out = causal_conv1d_fn(x, weight, bias, activation="silu")
out.sum().backward()
print(out.min(), out.max(), out.mean(), out.std())
print(x.grad.min(), x.grad.max(), x.grad.mean(), x.grad.std())
print(weight.grad.min(), weight.grad.max(), weight.grad.mean(), weight.grad.std())
# Try compiling the function using torch.compile
    x.grad.zero_()
    weight.grad.zero_()
compiled_conv1d = torch.compile(causal_conv1d_fn)
print(compiled_conv1d)
# Run the compiled function
print("Compiled Implementation")
out = compiled_conv1d(x, weight, bias, activation="silu")
out.sum().backward()
print(out.min(), out.max(), out.mean(), out.std())
print(x.grad.min(), x.grad.max(), x.grad.mean(), x.grad.std())
print(weight.grad.min(), weight.grad.max(), weight.grad.mean(), weight.grad.std())
print("Reference Implementation")
    x.grad.zero_()
    weight.grad.zero_()
out = causal_conv1d_fn_ref(x, weight, bias, activation="silu")
out.sum().backward()
print(out.min(), out.max(), out.mean(), out.std())
print(x.grad.min(), x.grad.max(), x.grad.mean(), x.grad.std())
print(weight.grad.min(), weight.grad.max(), weight.grad.mean(), weight.grad.std())