# flash-mla / build.toml
# drbh
# feat: build flash mla with kernel builder
# 1f83cde
# Build-wide settings for this kernel-builder project.
# NOTE(review): `name` is presumably the name of the produced package/extension —
# confirm against the kernel-builder documentation.
[general]
name = "flash_mla"
# Sources for the Torch extension: the binding translation unit and its header.
[torch]
src = [
    "torch-ext/torch_binding.cpp",
    "torch-ext/torch_binding.h",
]
# CUDA kernel definition: dependencies, target compute capabilities, and sources.
# NOTE(review): the table is keyed "activation" although every source listed
# below lives under flash_mla/ — possibly carried over from a template. Confirm
# how the build uses this key before renaming it.
[kernel.activation]
depends = ["torch", "cutlass_3_6"]

# Only available on H100 and H200, so 9.0 (Hopper) is the sole enabled target.
# Disabled capabilities, kept for reference:
# "7.0", "7.2", "7.5", "8.0", "8.6", "8.7", "8.9"
cuda-capabilities = ["9.0"]

# Kernel sources and the headers they use, all under flash_mla/.
src = [
    "flash_mla/flash_mla_api.cu",
    "flash_mla/flash_fwd_mla_bf16_sm90.cu",
    "flash_mla/flash_fwd_mla_fp16_sm90.cu",
    "flash_mla/flash_fwd_mla_kernel.h",
    "flash_mla/flash_fwd_mla_metadata.cu",
    "flash_mla/flash_mla.h",
    "flash_mla/named_barrier.h",
    "flash_mla/softmax.h",
    "flash_mla/static_switch.h",
    "flash_mla/utils.h",
]