# Upload metadata captured along with this file (not code):
#   Somunia's picture
#   Upload 116 files
#   306b4ac verified
import math
import torch
import torch.nn.functional as F
import pytest
from einops import rearrange, repeat
from mamba_ssm.ops.triton.ssd_chunk_state import chunk_state, chunk_state_ref
from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_cumsum_fwd, _chunk_state_fwd
from mamba_ssm.ops.triton.ssd_chunk_state import chunk_state_varlen
from mamba_ssm.ops.triton.ssd_state_passing import state_passing, state_passing_ref
from mamba_ssm.ops.triton.ssd_state_passing import _state_passing_fwd
from mamba_ssm.ops.triton.ssd_chunk_scan import chunk_scan, chunk_scan_ref
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_chunk_scan, ssd_chunk_scan_combined_ref, ssd_selective_scan
from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined, mamba_split_conv1d_scan_ref
def detach_clone(*args):
    """Return fresh autograd leaves copied from the given tensors.

    Every non-None argument is detached from its graph, copied with
    ``clone()``, and flagged with ``requires_grad_()``. ``None`` entries are
    passed through as ``None``. Order is preserved and the result is a tuple.
    """
    copies = []
    for tensor in args:
        if tensor is None:
            copies.append(None)
            continue
        copies.append(tensor.detach().clone().requires_grad_())
    return tuple(copies)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires a CUDA device")
@pytest.mark.parametrize('dtype', [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize('ngroups', [1, 2, 8, "max"])
@pytest.mark.parametrize('chunk_size', [64, 128])
def test_chunk_state_varlen(chunk_size, ngroups, dtype):
    """Compare ``chunk_state_varlen`` on packed sequences with a per-sequence reference.

    ``batch`` random-length sequences are packed back to back along one axis.
    ``cu_seqlens`` holds the boundaries and ``seq_idx`` labels every packed
    position with the index of its sequence.

    Packed path: ``_chunk_cumsum_fwd`` -> ``_chunk_state_fwd`` and
    ``_state_passing_fwd`` (both given ``seq_idx``) -> ``chunk_state_varlen``.

    Reference path: each sequence is sliced out and run alone through
    ``_chunk_cumsum_fwd`` -> ``chunk_state`` -> ``_state_passing_fwd``. The
    second return value (``final_states``) is collected, and the results are
    concatenated along dim 0.

    The two results must match under ``torch.allclose`` with ``rtol``/``atol``.
    """
    device = 'cuda'
    rtol, atol = (1e-2, 3e-3)
    # Deterministic seed per parametrization; "max" contributes 64 to the seed only.
    torch.random.manual_seed(chunk_size + (ngroups if ngroups != "max" else 64))
    batch = 300
    seqlens = torch.randint(1, 200, (batch,), device=device)
    cu_seqlens = F.pad(seqlens.cumsum(0), (1, 0))
    total_seqlen = seqlens.sum().item()
    # Sequence id for every packed position: i repeated seqlens[i] times.
    # A single repeat_interleave replaces one torch.full per sequence, each of
    # which had to sync with the device to read its 0-d CUDA length.
    seq_idx = torch.repeat_interleave(
        torch.arange(batch, dtype=torch.int32, device=device), seqlens
    ).unsqueeze(0)
    dim = 4096
    headdim = 64
    dstate = 32
    assert dim % headdim == 0
    nheads = dim // headdim
    if ngroups == "max":
        ngroups = nheads
    assert nheads % ngroups == 0
    # NOTE: keep the order of these random draws fixed; the seed above makes
    # each parametrization reproducible only if the draw sequence is unchanged.
    B = torch.randn(total_seqlen, ngroups, dstate, dtype=dtype, device=device) / 5
    x = torch.randn(total_seqlen, nheads, headdim, dtype=dtype, device=device)
    A = -0.1 * (torch.rand(nheads, device=device))
    dt = F.softplus(torch.randn(total_seqlen, nheads, device=device, dtype=torch.float32) - 4)

    # Packed path: a leading batch dim of 1 is added for the chunked kernels
    # and squeezed away again before calling chunk_state_varlen.
    dA_cumsum, dt_rounded = _chunk_cumsum_fwd(dt.unsqueeze(0), A, chunk_size)
    chunk_states = _chunk_state_fwd(B.unsqueeze(0), x.unsqueeze(0), dt_rounded, dA_cumsum, seq_idx=seq_idx)
    # _state_passing_fwd works on flattened (p n) states; take the last
    # dA_cumsum entry of each chunk as the per-chunk value it consumes.
    chunk_states, _ = _state_passing_fwd(rearrange(chunk_states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1],
                                         seq_idx=seq_idx, chunk_size=chunk_size)
    chunk_states = rearrange(chunk_states, "... (p n) -> ... p n", n=dstate)
    chunk_states = chunk_states.squeeze(0)
    dA_cumsum = dA_cumsum.squeeze(0)
    dt_rounded = dt_rounded.squeeze(0)
    out = chunk_state_varlen(B, x, dt_rounded, dA_cumsum, cu_seqlens, chunk_states)

    # Reference path: copy the boundaries to the host once, so slicing uses
    # Python ints instead of syncing on a 0-d CUDA tensor for every slice.
    bounds = cu_seqlens.tolist()
    out_ref = []
    for start, end in zip(bounds[:-1], bounds[1:]):
        x_s = x[start:end].unsqueeze(0)
        B_s = B[start:end].unsqueeze(0)
        dt_s = dt[start:end].unsqueeze(0)
        # Chunking restarts at each sequence, so the cumsum is recomputed
        # from the un-rounded dt rather than sliced from the packed result.
        dA_cumsum_s, dt_rounded_s = _chunk_cumsum_fwd(dt_s, A, chunk_size)
        states = chunk_state(B_s, x_s, dt_rounded_s, dA_cumsum_s)
        _, final_states = _state_passing_fwd(rearrange(states, "... p n -> ... (p n)"), dA_cumsum_s[:, :, :, -1],
                                             chunk_size=chunk_size)
        final_states = rearrange(final_states, "... (p n) -> ... p n", n=dstate)
        out_ref.append(final_states)
    out_ref = torch.cat(out_ref, dim=0)

    max_diff = (out - out_ref).abs().max().item()
    print(f"Max diff = {max_diff}")
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol), f"Max diff = {max_diff}"