|
from typing import List, Optional, Tuple |
|
import torch |
|
|
|
from mamba_ssm.ops.triton.ssd_combined import _mamba_chunk_scan_combined_fwd, _mamba_chunk_scan_combined_bwd |
|
|
|
|
|
@torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
def _compiled_mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=None):
    """torch.compile'd passthrough to mamba_ssm's ``_mamba_chunk_scan_combined_fwd``.

    All arguments are forwarded unchanged, except ``dt_limit=None``. That value
    is mapped to ``(0.0, inf)``, the default used by this module's public
    ``mamba_chunk_scan_combined``, instead of being handed to mamba_ssm as
    ``None``.
    """
    if dt_limit is None:
        dt_limit = (0.0, float("inf"))
    return _mamba_chunk_scan_combined_fwd(
        x, dt, A, B, C, chunk_size,
        D=D, z=z, dt_bias=dt_bias, initial_states=initial_states,
        seq_idx=seq_idx, cu_seqlens=cu_seqlens,
        dt_softplus=dt_softplus, dt_limit=dt_limit,
    )
|
|
|
@torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
def _compiled_mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, dfinal_states=None, seq_idx=None, dt_softplus=False, dt_limit=None):
    """torch.compile'd passthrough to mamba_ssm's ``_mamba_chunk_scan_combined_bwd``.

    All arguments are forwarded unchanged, except ``dt_limit=None``. That value
    is mapped to ``(0.0, inf)``, the default used by this module's public
    ``mamba_chunk_scan_combined``, instead of being handed to mamba_ssm as
    ``None``.
    """
    if dt_limit is None:
        dt_limit = (0.0, float("inf"))
    return _mamba_chunk_scan_combined_bwd(
        dout, x, dt, A, B, C, out, chunk_size,
        D=D, z=z, dt_bias=dt_bias, initial_states=initial_states,
        dfinal_states=dfinal_states, seq_idx=seq_idx,
        dt_softplus=dt_softplus, dt_limit=dt_limit,
    )
|
|
|
|
|
@torch.library.custom_op(
    "mamba_ssm::ssm_chunk_scan_combined_fwd",
    mutates_args=(),
    device_types="cuda",
)
def ssm_chunk_scan_combined_fwd(
    x: torch.Tensor,
    dt: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    chunk_size: int,
    D: Optional[torch.Tensor] = None,
    z: Optional[torch.Tensor] = None,
    dt_bias: Optional[torch.Tensor] = None,
    initial_states: Optional[torch.Tensor] = None,
    seq_idx: Optional[torch.Tensor] = None,
    cu_seqlens: Optional[torch.Tensor] = None,
    dt_softplus: bool = False,
    dt_limit: Optional[List[float]] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Custom-op wrapper around mamba_ssm's fused chunk-scan forward.

    The op always returns exactly three tensors, never ``None``:

    * ``out``: first output of ``_mamba_chunk_scan_combined_fwd``.
    * ``out_x``: its second output when ``z`` is given; otherwise an empty
      placeholder (``numel() == 0``).
    * ``varlen_states``: the extra trailing output produced when
      ``cu_seqlens`` is given; otherwise an empty placeholder.

    ``dt_limit=None`` is treated as ``(0.0, inf)``, the default of the public
    ``mamba_chunk_scan_combined`` wrapper, rather than forwarded as ``None``.
    """
    if dt_limit is None:
        dt_limit = (0.0, float("inf"))

    # Intermediates (dt, dA_cumsum, chunk states, final states) are not exposed.
    out, out_x, _dt, _dA_cumsum, _states, _final_states, *rest = _mamba_chunk_scan_combined_fwd(
        x, dt, A, B, C, chunk_size,
        D=D, z=z, dt_bias=dt_bias, initial_states=initial_states,
        seq_idx=seq_idx, cu_seqlens=cu_seqlens,
        dt_softplus=dt_softplus, dt_limit=dt_limit,
    )

    # Each placeholder is a separate allocation so no two outputs alias.
    out_x = out_x if out_x is not None else out.new_empty(0)
    varlen_states = rest[0] if cu_seqlens is not None else out.new_empty(0)
    return out, out_x, varlen_states
|
|
|
@ssm_chunk_scan_combined_fwd.register_fake
def _ssm_chunk_scan_combined_fwd_fake(
    x: torch.Tensor,
    dt: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    chunk_size: int,
    D: Optional[torch.Tensor] = None,
    z: Optional[torch.Tensor] = None,
    dt_bias: Optional[torch.Tensor] = None,
    initial_states: Optional[torch.Tensor] = None,
    seq_idx: Optional[torch.Tensor] = None,
    cu_seqlens: Optional[torch.Tensor] = None,
    dt_softplus: bool = False,
    dt_limit: Optional[List[float]] = None
):
    """Metadata-only implementation of ``ssm_chunk_scan_combined_fwd`` for tracing.

    The real op never returns ``None``. It returns empty 1-D placeholders for
    ``out_x`` (no ``z``) and ``varlen_states`` (no ``cu_seqlens``), so this
    fake returns placeholders in the same positions.
    """
    _, _, n_heads, head_dim = x.shape

    out = torch.empty_like(x)
    out_x = torch.empty_like(x) if z is not None else x.new_empty(0)

    if cu_seqlens is not None:
        # One state per packed sequence; trailing dim is dstate (last dim of B),
        # not batch.
        n_seqs = cu_seqlens.size(0) - 1
        d_state = B.size(-1)
        # NOTE(review): dtype assumed to follow x; confirm against mamba_ssm's
        # varlen-state output dtype if x and C dtypes can differ.
        varlen_states = x.new_empty((n_seqs, n_heads, head_dim, d_state))
    else:
        varlen_states = x.new_empty(0)

    return out, out_x, varlen_states
|
|
|
@torch.library.custom_op(
    "mamba_ssm::ssm_chunk_scan_combined_bwd",
    mutates_args=(),
    device_types="cuda",
)
def ssm_chunk_scan_combined_bwd(
    dout: torch.Tensor,
    x: torch.Tensor,
    dt: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    out: torch.Tensor,
    chunk_size: int,
    D: Optional[torch.Tensor] = None,
    z: Optional[torch.Tensor] = None,
    dt_bias: Optional[torch.Tensor] = None,
    initial_states: Optional[torch.Tensor] = None,
    seq_idx: Optional[torch.Tensor] = None,
    dt_softplus: bool = False,
    dt_limit: Optional[List[float]] = None
)-> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Custom-op wrapper around mamba_ssm's fused chunk-scan backward.

    Returns nine tensors: gradients for ``x, dt, A, B, C, D, z, dt_bias,
    initial_states``, in that order. Gradients for optional inputs that
    mamba_ssm returns as ``None`` are replaced by empty placeholders
    (``numel() == 0``), because the op must return tensors.

    No gradient with respect to the final states is supplied
    (``dfinal_states=None``). ``dt_limit=None`` is treated as ``(0.0, inf)``.
    """
    if dt_limit is None:
        dt_limit = (0.0, float("inf"))

    dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states = _mamba_chunk_scan_combined_bwd(
        dout, x, dt, A, B, C, out, chunk_size,
        D=D, z=z, dt_bias=dt_bias, initial_states=initial_states,
        dfinal_states=None, seq_idx=seq_idx,
        dt_softplus=dt_softplus, dt_limit=dt_limit,
    )

    def _or_placeholder(grad):
        # Fresh allocation per slot so that no two outputs alias each other.
        return grad if grad is not None else dx.new_empty(0)

    return (
        dx,
        ddt,
        dA,
        dB,
        dC,
        _or_placeholder(dD),
        _or_placeholder(dz),
        _or_placeholder(ddt_bias),
        _or_placeholder(dinitial_states),
    )
|
|
|
@ssm_chunk_scan_combined_bwd.register_fake
def _ssm_chunk_scan_combined_bwd_fake(
    dout: torch.Tensor,
    x: torch.Tensor,
    dt: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    out: torch.Tensor,
    chunk_size: int,
    D: Optional[torch.Tensor] = None,
    z: Optional[torch.Tensor] = None,
    dt_bias: Optional[torch.Tensor] = None,
    initial_states: Optional[torch.Tensor] = None,
    seq_idx: Optional[torch.Tensor] = None,
    dt_softplus: bool = False,
    dt_limit: Optional[List[float]] = None
):
    """Metadata-only implementation of ``ssm_chunk_scan_combined_bwd`` for tracing.

    The real op returns an empty placeholder (``numel() == 0``), never
    ``None``, for the gradient of each absent optional input. This fake
    mirrors that, so the traced and real outputs agree.
    """
    def _like_or_placeholder(t):
        return torch.empty_like(t) if t is not None else x.new_empty(0)

    return (
        torch.empty_like(x),
        torch.empty_like(dt),
        torch.empty_like(A),
        torch.empty_like(B),
        torch.empty_like(C),
        _like_or_placeholder(D),
        _like_or_placeholder(z),
        _like_or_placeholder(dt_bias),
        _like_or_placeholder(initial_states),
    )
|
|
|
|
|
def ssm_chunk_scan_combined_setup_context(ctx, inputs, output):
    """Record what the backward of ``ssm_chunk_scan_combined_fwd`` will need.

    ``inputs`` holds the forward op's 14 arguments in declaration order, and
    ``output`` holds its three results ``(out, out_x, varlen_states)``.
    Tensors go through ``ctx.save_for_backward``; the non-tensor
    configuration is stored as plain attributes on ``ctx``.
    """
    (
        x, dt, A, B, C, chunk_size,
        D, z, dt_bias, initial_states, seq_idx,
        _cu_seqlens, dt_softplus, dt_limit,
    ) = inputs
    out, out_x, _varlen_states = output

    # With a gate tensor z, the forward's second output is saved instead of
    # the first (presumably the un-gated result the kernel needs; confirm).
    saved_out = out_x if z is not None else out
    ctx.save_for_backward(
        saved_out, x, dt, A, B, C, D, z, dt_bias, initial_states, seq_idx
    )

    ctx.chunk_size = chunk_size
    ctx.dt_softplus = dt_softplus
    ctx.dt_limit = dt_limit
|
|
|
def ssm_chunk_scan_combined_bridge(ctx, dout, dout_x, dout_state_varlen):
    """Autograd backward for ``ssm_chunk_scan_combined_fwd``.

    Receives one incoming gradient per forward output
    ``(out, out_x, varlen_states)``. Returns one gradient per forward input,
    14 entries in the op's argument order. Only ``dout`` is passed to
    ``ssm_chunk_scan_combined_bwd``.

    NOTE(review): ``dout_x`` and ``dout_state_varlen`` are ignored. Any loss
    term that depends on the forward's second or third output therefore
    contributes no gradient. Confirm this is acceptable for callers that pass
    ``cu_seqlens``.
    """
    saved_out, x, dt, A, B, C, D, z, dt_bias, initial_states, seq_idx = ctx.saved_tensors

    grads = ssm_chunk_scan_combined_bwd(
        dout,
        x,
        dt,
        A,
        B,
        C,
        saved_out,
        ctx.chunk_size,
        D=D,
        z=z,
        dt_bias=dt_bias,
        initial_states=initial_states,
        seq_idx=seq_idx,
        dt_softplus=ctx.dt_softplus,
        dt_limit=ctx.dt_limit,
    )
    dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states = grads

    def keep_if_given(grad, given):
        # The backward op fills absent optional inputs with empty placeholders;
        # report no gradient (None) for those inputs instead.
        return grad if given is not None else None

    no_grad = None
    return (
        dx,
        ddt,
        dA,
        dB,
        dC,
        no_grad,  # chunk_size
        keep_if_given(dD, D),
        keep_if_given(dz, z),
        keep_if_given(ddt_bias, dt_bias),
        keep_if_given(dinitial_states, initial_states),
        no_grad,  # seq_idx
        no_grad,  # cu_seqlens
        no_grad,  # dt_softplus
        no_grad,  # dt_limit
    )
|
|
|
|
|
# Wire autograd onto the forward op: the setup_context hook stores what the
# backward needs, and the bridge calls the backward custom op.
ssm_chunk_scan_combined_fwd.register_autograd(
    ssm_chunk_scan_combined_bridge,
    setup_context=ssm_chunk_scan_combined_setup_context,
)
|
|
|
def mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf"))):
    """
    Chunked SSM scan dispatched through the
    ``mamba_ssm::ssm_chunk_scan_combined_fwd`` custom op, which has a fake
    implementation so the call can be traced by torch.compile.

    Argument:
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, seqlen, nheads)
        A: (nheads)
        B: (batch, seqlen, ngroups, dstate)
        C: (batch, seqlen, ngroups, dstate)
        chunk_size: int
        D: (nheads, headdim) or (nheads,)
        z: (batch, seqlen, nheads, headdim)
        dt_bias: (nheads,)
        initial_states: (batch, nheads, headdim, dstate)
        seq_idx: (batch, seqlen)
        cu_seqlens: (num_sequences + 1) or None
        dt_softplus: Whether to apply softplus to dt
        dt_limit: pair of floats forwarded to mamba_ssm (presumably a
            (min, max) clamp on dt). ``None`` is treated as the default
            ``(0.0, inf)``.
    Return:
        out: (batch, seqlen, nheads, headdim), when ``cu_seqlens`` is None.
        (out, varlen_states), when ``cu_seqlens`` is given. varlen_states holds
            one state per sequence (presumably
            (num_sequences, nheads, headdim, dstate)). No gradient flows back
            through varlen_states.
    """
    if dt_limit is None:
        dt_limit = (0.0, float("inf"))

    out, _out_x, varlen_states = ssm_chunk_scan_combined_fwd(
        x, dt, A, B, C, chunk_size,
        D=D, z=z, dt_bias=dt_bias, initial_states=initial_states,
        seq_idx=seq_idx, cu_seqlens=cu_seqlens,
        dt_softplus=dt_softplus, dt_limit=dt_limit,
    )
    if cu_seqlens is None:
        return out
    return out, varlen_states
|
|
|
if __name__ == "__main__":
    # Smoke test: run the custom-op path eagerly and under torch.compile, plus
    # the upstream mamba_ssm implementation, and check the outputs agree.
    from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined as mamba_chunk_scan_combined_ref

    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

    x = torch.randn(2, 3, 4, 5).cuda()
    dt = torch.randn(2, 3, 4).cuda()
    A = torch.randn(4).cuda()
    B = torch.randn(2, 3, 4, 5).cuda()
    C = torch.randn(2, 3, 4, 5).cuda()
    chunk_size = 2
    D = torch.randn(4, 5).cuda()
    z = torch.randn(2, 3, 4, 5).cuda()
    dt_bias = torch.randn(4).cuda()

    out = mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias)
    print(out.min(), out.max(), out.mean(), out.std())

    compiled_mamba_chunk_scan_combined = torch.compile(mamba_chunk_scan_combined)
    out_compiled = compiled_mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias)
    print(out_compiled.min(), out_compiled.max(), out_compiled.mean(), out_compiled.std())

    out_ref = mamba_chunk_scan_combined_ref(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias)
    print(out_ref.min(), out_ref.max(), out_ref.mean(), out_ref.std())

    # Printed statistics alone cannot catch a mismatch; fail loudly instead.
    torch.testing.assert_close(out_compiled, out)
    torch.testing.assert_close(out, out_ref)