drbh
commited on
Commit
·
1f83cde
0
Parent(s):
feat: build flash mla with kernel builder
Browse files- .gitignore +2 -0
- README.md +22 -0
- build.toml +27 -0
- flake.lock +117 -0
- flake.nix +14 -0
- flash_mla/flash_fwd_mla_bf16_sm90.cu +3 -0
- flash_mla/flash_fwd_mla_fp16_sm90.cu +3 -0
- flash_mla/flash_fwd_mla_kernel.h +603 -0
- flash_mla/flash_fwd_mla_metadata.cu +77 -0
- flash_mla/flash_mla.h +63 -0
- flash_mla/flash_mla_api.cu +259 -0
- flash_mla/named_barrier.h +15 -0
- flash_mla/softmax.h +197 -0
- flash_mla/static_switch.h +65 -0
- flash_mla/utils.h +238 -0
- tests/__init__.py +0 -0
- tests/test_flash_mla.py +69 -0
- torch-ext/flash_mla/__init__.py +36 -0
- torch-ext/torch_binding.cpp +15 -0
- torch-ext/torch_binding.h +36 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
.bak
|
2 |
+
__pycache__
|
README.md
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
tags:
|
3 |
+
- kernel
|
4 |
+
- flash-mla
|
5 |
+
- deepseek
|
6 |
+
- kernel-builder
|
7 |
+
---
|
8 |
+
|
9 |
+
# flash-mla
|
10 |
+
|
11 |
+
This repo builds Deepseeks [FlashMLA](https://github.com/deepseek-ai/FlashMLA) kernel via the HF [kernel-builder](https://github.com/huggingface/kernel-builder)
|
12 |
+
|
13 |
+
### Dev
|
14 |
+
```bash
|
15 |
+
nix develop -L
|
16 |
+
pytest -vv tests/
|
17 |
+
```
|
18 |
+
|
19 |
+
### Build
|
20 |
+
```bash
|
21 |
+
nix build .#bundle -L
|
22 |
+
```
|
build.toml
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[general]
|
2 |
+
name = "flash_mla"
|
3 |
+
|
4 |
+
[torch]
|
5 |
+
src = ["torch-ext/torch_binding.cpp", "torch-ext/torch_binding.h"]
|
6 |
+
|
7 |
+
|
8 |
+
[kernel.activation]
|
9 |
+
cuda-capabilities = [
|
10 |
+
# "7.0", "7.2", "7.5", "8.0", "8.6", "8.7", "8.9",
|
11 |
+
|
12 |
+
# Only available on H100 and H200
|
13 |
+
"9.0", # (Hopper)
|
14 |
+
]
|
15 |
+
src = [
|
16 |
+
"flash_mla/flash_mla_api.cu",
|
17 |
+
"flash_mla/flash_fwd_mla_bf16_sm90.cu",
|
18 |
+
"flash_mla/flash_fwd_mla_fp16_sm90.cu",
|
19 |
+
"flash_mla/flash_fwd_mla_kernel.h",
|
20 |
+
"flash_mla/flash_fwd_mla_metadata.cu",
|
21 |
+
"flash_mla/flash_mla.h",
|
22 |
+
"flash_mla/named_barrier.h",
|
23 |
+
"flash_mla/softmax.h",
|
24 |
+
"flash_mla/static_switch.h",
|
25 |
+
"flash_mla/utils.h",
|
26 |
+
]
|
27 |
+
depends = ["torch", "cutlass_3_6"]
|
flake.lock
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"nodes": {
|
3 |
+
"flake-compat": {
|
4 |
+
"locked": {
|
5 |
+
"lastModified": 1733328505,
|
6 |
+
"narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
|
7 |
+
"owner": "edolstra",
|
8 |
+
"repo": "flake-compat",
|
9 |
+
"rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
|
10 |
+
"type": "github"
|
11 |
+
},
|
12 |
+
"original": {
|
13 |
+
"owner": "edolstra",
|
14 |
+
"repo": "flake-compat",
|
15 |
+
"type": "github"
|
16 |
+
}
|
17 |
+
},
|
18 |
+
"flake-utils": {
|
19 |
+
"inputs": {
|
20 |
+
"systems": "systems"
|
21 |
+
},
|
22 |
+
"locked": {
|
23 |
+
"lastModified": 1731533236,
|
24 |
+
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
|
25 |
+
"owner": "numtide",
|
26 |
+
"repo": "flake-utils",
|
27 |
+
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
|
28 |
+
"type": "github"
|
29 |
+
},
|
30 |
+
"original": {
|
31 |
+
"owner": "numtide",
|
32 |
+
"repo": "flake-utils",
|
33 |
+
"type": "github"
|
34 |
+
}
|
35 |
+
},
|
36 |
+
"kernel-builder": {
|
37 |
+
"inputs": {
|
38 |
+
"flake-compat": "flake-compat",
|
39 |
+
"flake-utils": "flake-utils",
|
40 |
+
"nixpkgs": "nixpkgs",
|
41 |
+
"rocm-nix": "rocm-nix"
|
42 |
+
},
|
43 |
+
"locked": {
|
44 |
+
"lastModified": 1740571741,
|
45 |
+
"narHash": "sha256-MIy7OBrhz8OqFaLT3/MpvL+IGz720le2Ru2wGYPdswo=",
|
46 |
+
"ref": "refs/heads/main",
|
47 |
+
"rev": "5062dad8a4818d7239f59e1153b8c18b8b59da74",
|
48 |
+
"revCount": 89,
|
49 |
+
"type": "git",
|
50 |
+
"url": "ssh://[email protected]/huggingface/kernel-builder"
|
51 |
+
},
|
52 |
+
"original": {
|
53 |
+
"type": "git",
|
54 |
+
"url": "ssh://[email protected]/huggingface/kernel-builder"
|
55 |
+
}
|
56 |
+
},
|
57 |
+
"nixpkgs": {
|
58 |
+
"locked": {
|
59 |
+
"lastModified": 1740344854,
|
60 |
+
"narHash": "sha256-+TiHtSOo+RPUNrcfkcGXmapJ40O3gt6yOe2nA8y0KPw=",
|
61 |
+
"owner": "nixos",
|
62 |
+
"repo": "nixpkgs",
|
63 |
+
"rev": "0b3aa63c013cf9302afc0ba5dbd81f8fab7bd94f",
|
64 |
+
"type": "github"
|
65 |
+
},
|
66 |
+
"original": {
|
67 |
+
"owner": "nixos",
|
68 |
+
"ref": "nixos-unstable-small",
|
69 |
+
"repo": "nixpkgs",
|
70 |
+
"type": "github"
|
71 |
+
}
|
72 |
+
},
|
73 |
+
"rocm-nix": {
|
74 |
+
"inputs": {
|
75 |
+
"nixpkgs": [
|
76 |
+
"kernel-builder",
|
77 |
+
"nixpkgs"
|
78 |
+
]
|
79 |
+
},
|
80 |
+
"locked": {
|
81 |
+
"lastModified": 1740473629,
|
82 |
+
"narHash": "sha256-xW5RfZScKmFymmwdBSZvNZxvFzQSJu8lgF0cQKomT2E=",
|
83 |
+
"owner": "huggingface",
|
84 |
+
"repo": "rocm-nix",
|
85 |
+
"rev": "e4b51d092caf52c693c330c177369adbf6a153ba",
|
86 |
+
"type": "github"
|
87 |
+
},
|
88 |
+
"original": {
|
89 |
+
"owner": "huggingface",
|
90 |
+
"repo": "rocm-nix",
|
91 |
+
"type": "github"
|
92 |
+
}
|
93 |
+
},
|
94 |
+
"root": {
|
95 |
+
"inputs": {
|
96 |
+
"kernel-builder": "kernel-builder"
|
97 |
+
}
|
98 |
+
},
|
99 |
+
"systems": {
|
100 |
+
"locked": {
|
101 |
+
"lastModified": 1681028828,
|
102 |
+
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
103 |
+
"owner": "nix-systems",
|
104 |
+
"repo": "default",
|
105 |
+
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
106 |
+
"type": "github"
|
107 |
+
},
|
108 |
+
"original": {
|
109 |
+
"owner": "nix-systems",
|
110 |
+
"repo": "default",
|
111 |
+
"type": "github"
|
112 |
+
}
|
113 |
+
}
|
114 |
+
},
|
115 |
+
"root": "root",
|
116 |
+
"version": 7
|
117 |
+
}
|
flake.nix
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
description = "Flake for FlashMLA kernel";
|
3 |
+
|
4 |
+
inputs = {
|
5 |
+
kernel-builder.url = "git+ssh://[email protected]/huggingface/kernel-builder";
|
6 |
+
};
|
7 |
+
|
8 |
+
outputs =
|
9 |
+
{
|
10 |
+
self,
|
11 |
+
kernel-builder,
|
12 |
+
}:
|
13 |
+
kernel-builder.lib.genFlakeOutputs ./.;
|
14 |
+
}
|
flash_mla/flash_fwd_mla_bf16_sm90.cu
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
#include "flash_fwd_mla_kernel.h"
|
2 |
+
|
3 |
+
template void run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(Flash_fwd_mla_params ¶ms, cudaStream_t stream);
|
flash_mla/flash_fwd_mla_fp16_sm90.cu
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
#include "flash_fwd_mla_kernel.h"
|
2 |
+
|
3 |
+
template void run_mha_fwd_splitkv_mla<cutlass::half_t, 576>(Flash_fwd_mla_params ¶ms, cudaStream_t stream);
|
flash_mla/flash_fwd_mla_kernel.h
ADDED
@@ -0,0 +1,603 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#include <cute/tensor.hpp>
|
4 |
+
#include <cutlass/cutlass.h>
|
5 |
+
#include <cutlass/array.h>
|
6 |
+
#include <cutlass/numeric_types.h>
|
7 |
+
|
8 |
+
using namespace cute;
|
9 |
+
|
10 |
+
#include "named_barrier.h"
|
11 |
+
#include "utils.h"
|
12 |
+
#include "softmax.h"
|
13 |
+
#include "static_switch.h"
|
14 |
+
#include "flash_mla.h"
|
15 |
+
|
16 |
+
|
17 |
+
template<typename PrecType, int DIM, int DIM2 = DIM>
|
18 |
+
constexpr auto getSmemLayoutK() {
|
19 |
+
constexpr int headSizeBytes = sizeof(PrecType) * DIM;
|
20 |
+
constexpr int headSizeBytes2 = sizeof(PrecType) * DIM2;
|
21 |
+
|
22 |
+
if constexpr (headSizeBytes % 128 == 0 && headSizeBytes2 % 128 == 0) {
|
23 |
+
return GMMA::Layout_K_SW128_Atom<PrecType>{};
|
24 |
+
} else if constexpr (headSizeBytes % 64 == 0 && headSizeBytes2 % 64 == 0) {
|
25 |
+
return GMMA::Layout_K_SW64_Atom<PrecType>{};
|
26 |
+
} else {
|
27 |
+
return GMMA::Layout_K_SW32_Atom<PrecType>{};
|
28 |
+
}
|
29 |
+
}
|
30 |
+
|
31 |
+
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::bfloat16_t, int kHeadDimV_ = 0>
|
32 |
+
struct Flash_fwd_kernel_traits_mla {
|
33 |
+
using Element = elem_type;
|
34 |
+
using ElementAccum = float;
|
35 |
+
using index_t = int64_t;
|
36 |
+
|
37 |
+
static constexpr int kNWarps = kNWarps_;
|
38 |
+
static constexpr int kNThreads = kNWarps * 32;
|
39 |
+
static constexpr int kNWarpsS = 4;
|
40 |
+
static constexpr int kNThreadsS = kNWarpsS * 32;
|
41 |
+
|
42 |
+
static constexpr int kBlockM = kBlockM_;
|
43 |
+
static constexpr int kBlockN = kBlockN_;
|
44 |
+
static constexpr int kHeadDim = kHeadDim_;
|
45 |
+
static_assert(kHeadDim % 32 == 0);
|
46 |
+
static constexpr int kHeadDimV = kHeadDimV_ != 0 ? kHeadDimV_ : kHeadDim;
|
47 |
+
static_assert(kHeadDimV % 32 == 0);
|
48 |
+
static_assert(kHeadDimV <= kHeadDim);
|
49 |
+
static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
|
50 |
+
static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
|
51 |
+
|
52 |
+
using TiledMma = decltype(make_tiled_mma(
|
53 |
+
cute::GMMA::ss_op_selector<Element, Element, ElementAccum, Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>,
|
54 |
+
GMMA::Major::K, GMMA::Major::K>(),
|
55 |
+
Layout<Shape<Int<kNWarpsS / 4>, _1, _1>>{}));
|
56 |
+
|
57 |
+
static constexpr int AtomLayoutNO = kNThreads / kNThreadsS;
|
58 |
+
using TiledMmaO = decltype(make_tiled_mma(
|
59 |
+
cute::GMMA::rs_op_selector<Element, Element, ElementAccum, Shape<Int<kBlockM>, Int<kHeadDimV / AtomLayoutNO>, Int<kBlockN>>,
|
60 |
+
GMMA::Major::K, GMMA::Major::MN>(),
|
61 |
+
Layout<Shape<Int<kNWarpsS / 4>, Int<AtomLayoutNO>, _1>>{}));
|
62 |
+
|
63 |
+
using SmemLayoutQ = decltype(tile_to_shape(
|
64 |
+
getSmemLayoutK<Element, kHeadDim>(),
|
65 |
+
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
|
66 |
+
|
67 |
+
using SmemLayoutK = decltype(tile_to_shape(
|
68 |
+
getSmemLayoutK<Element, kHeadDim, kHeadDimV>(),
|
69 |
+
Shape<Int<kBlockN>, Int<kHeadDim>>{}));
|
70 |
+
|
71 |
+
using SmemLayoutV = decltype(tile_to_shape(
|
72 |
+
getSmemLayoutK<Element, kHeadDim, kHeadDimV>(),
|
73 |
+
Shape<Int<kBlockN>, Int<kHeadDimV>>{}));
|
74 |
+
using SmemLayoutVtransposed = decltype(composition(SmemLayoutV{}, make_layout(Shape<Int<kHeadDimV>, Int<kBlockN>>{}, GenRowMajor{})));
|
75 |
+
|
76 |
+
using SmemLayoutP = Layout<Shape<Shape<_2, _2>, Int<kNThreadsS>, _1, Int<kBlockN / 8>>>;
|
77 |
+
using SmemLayoutRow = Layout<Shape<_2, Int<kNThreadsS>>, Stride<_1, _2>>;
|
78 |
+
|
79 |
+
using SmemLayoutAtomO = decltype(composition(
|
80 |
+
Swizzle<kSwizzle, 3, 3>{},
|
81 |
+
Layout<Shape<Int<8>, Int<kBlockKSmem>>, Stride<Int<kBlockKSmem>, _1>>{}));
|
82 |
+
using SmemLayoutO = decltype(tile_to_shape(
|
83 |
+
SmemLayoutAtomO{},
|
84 |
+
Shape<Int<kBlockM>, Int<kHeadDimV>>{}));
|
85 |
+
using SmemCopyAtomO = Copy_Atom<SM90_U32x4_STSM_N, Element>;
|
86 |
+
using SmemCopyAtomOaccum = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>;
|
87 |
+
|
88 |
+
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
|
89 |
+
static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
|
90 |
+
static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
|
91 |
+
using Gmem_copy_struct = SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>;
|
92 |
+
static constexpr int kNThreadsLoad = kNThreads - kNThreadsS;
|
93 |
+
static_assert(kNThreadsLoad % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
|
94 |
+
|
95 |
+
using GmemLayoutAtom = Layout<
|
96 |
+
Shape<Int<kNThreadsLoad / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
|
97 |
+
Stride<Int<kGmemThreadsPerRow>, _1>>;
|
98 |
+
using GmemTiledCopy = decltype(make_tiled_copy(
|
99 |
+
Copy_Atom<Gmem_copy_struct, Element>{},
|
100 |
+
GmemLayoutAtom{},
|
101 |
+
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
|
102 |
+
|
103 |
+
using GmemLayoutAtomO = Layout<
|
104 |
+
Shape<Int<kNThreadsS / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
|
105 |
+
Stride<Int<kGmemThreadsPerRow>, _1>>;
|
106 |
+
using GmemTiledCopyO = decltype(make_tiled_copy(
|
107 |
+
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
|
108 |
+
GmemLayoutAtomO{},
|
109 |
+
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
|
110 |
+
|
111 |
+
static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum);
|
112 |
+
static constexpr int kGmemThreadsPerRowAccum = kBlockKSmem / kGmemElemsPerLoadAccum;
|
113 |
+
using GmemLayoutAtomOaccum = Layout<
|
114 |
+
Shape<Int<kNThreadsS / kGmemThreadsPerRowAccum>, Int<kGmemThreadsPerRowAccum>>,
|
115 |
+
Stride<Int<kGmemThreadsPerRowAccum>, _1>>;
|
116 |
+
using GmemTiledCopyOaccum = decltype(make_tiled_copy(
|
117 |
+
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
|
118 |
+
GmemLayoutAtomOaccum{},
|
119 |
+
Layout<Shape<_1, _4>>{})); // Val layout, 4 vals per store
|
120 |
+
};
|
121 |
+
|
122 |
+
namespace flash {
|
123 |
+
|
124 |
+
using namespace cute;
|
125 |
+
|
126 |
+
template<typename Kernel_traits>
|
127 |
+
struct SharedStorageMLA {
|
128 |
+
union {
|
129 |
+
struct {
|
130 |
+
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutQ>> smem_q;
|
131 |
+
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutK> * 2> smem_k; // Double buffer
|
132 |
+
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutP>> smem_p;
|
133 |
+
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_scale;
|
134 |
+
};
|
135 |
+
struct {
|
136 |
+
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_max;
|
137 |
+
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_sum;
|
138 |
+
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutO>> smem_o;
|
139 |
+
};
|
140 |
+
};
|
141 |
+
};
|
142 |
+
|
143 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
144 |
+
|
145 |
+
template<typename Kernel_traits, bool Split, typename SharedStorage, typename AccO, typename Softmax>
|
146 |
+
__forceinline__ __device__ void store(const Flash_fwd_mla_params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx,
|
147 |
+
SharedStorage &shared_storage, AccO tOrO, Softmax softmax) {
|
148 |
+
constexpr int kBlockM = Kernel_traits::kBlockM;
|
149 |
+
constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
|
150 |
+
constexpr int kNThreadsS = Kernel_traits::kNThreadsS;
|
151 |
+
using Element = typename Kernel_traits::Element;
|
152 |
+
using ElementAccum = typename Kernel_traits::ElementAccum;
|
153 |
+
using index_t = typename Kernel_traits::index_t;
|
154 |
+
|
155 |
+
const int tidx = threadIdx.x;
|
156 |
+
|
157 |
+
typename Kernel_traits::TiledMmaO tiled_mma_o;
|
158 |
+
auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);
|
159 |
+
|
160 |
+
// Epilogue
|
161 |
+
|
162 |
+
const int split_offset = __ldg(params.num_splits_ptr + bidb);
|
163 |
+
|
164 |
+
Tensor lse = softmax.template normalize_softmax_lse</*Is_dropout=*/false, Split>(tOrO, params.scale_softmax);
|
165 |
+
|
166 |
+
using ElementO = std::conditional_t<!Split, Element, ElementAccum>;
|
167 |
+
Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO *>(shared_storage.smem_o.data())), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
|
168 |
+
// Partition sO to match the accumulator partitioning
|
169 |
+
using SmemTiledCopyO = std::conditional_t<
|
170 |
+
!Split,
|
171 |
+
typename Kernel_traits::SmemCopyAtomO,
|
172 |
+
typename Kernel_traits::SmemCopyAtomOaccum
|
173 |
+
>;
|
174 |
+
auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma_o);
|
175 |
+
auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx);
|
176 |
+
Tensor rO = flash::convert_type<ElementO>(tOrO);
|
177 |
+
Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N)
|
178 |
+
Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N)
|
179 |
+
|
180 |
+
__syncthreads();
|
181 |
+
|
182 |
+
cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum);
|
183 |
+
|
184 |
+
const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
|
185 |
+
const index_t row_offset_oaccum = (((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_v;
|
186 |
+
const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
|
187 |
+
const index_t row_offset_lseaccum = ((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
|
188 |
+
|
189 |
+
Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
|
190 |
+
Shape<Int<kBlockM>, Int<kHeadDimV>>{},
|
191 |
+
make_stride(Split ? kHeadDimV : params.o_row_stride, _1{}));
|
192 |
+
Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (Split ? row_offset_lseaccum : row_offset_lse)),
|
193 |
+
Shape<Int<kBlockM>>{}, Stride<_1>{});
|
194 |
+
|
195 |
+
using GmemTiledCopyO = std::conditional_t<!Split, typename Kernel_traits::GmemTiledCopyO, typename Kernel_traits::GmemTiledCopyOaccum>;
|
196 |
+
GmemTiledCopyO gmem_tiled_copy_Oaccum;
|
197 |
+
auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
|
198 |
+
Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N)
|
199 |
+
Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
|
200 |
+
|
201 |
+
__syncthreads();
|
202 |
+
|
203 |
+
if (tidx >= kNThreadsS) { return; }
|
204 |
+
|
205 |
+
Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));
|
206 |
+
cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum);
|
207 |
+
|
208 |
+
Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDimV>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
209 |
+
Tensor taccOcO = thr_mma_o.partition_C(caccO); // ((MMA=4, X), MMA_M, MMA_K=1)
|
210 |
+
Tensor taccOcO_row = taccOcO(make_coord(0, _, 0), _, 0);
|
211 |
+
CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M
|
212 |
+
if (get<1>(taccOcO_row(0)) == 0) {
|
213 |
+
#pragma unroll
|
214 |
+
for (int mi = 0; mi < size(lse); ++mi) {
|
215 |
+
const int row = get<0>(taccOcO_row(mi));
|
216 |
+
if (row < params.seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); }
|
217 |
+
}
|
218 |
+
}
|
219 |
+
|
220 |
+
// Construct identity layout for sO
|
221 |
+
Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
222 |
+
// Repeat the partitioning with identity layouts
|
223 |
+
Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
|
224 |
+
Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
|
225 |
+
// Clear_OOB_K must be false since we don't want to write zeros to gmem
|
226 |
+
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
|
227 |
+
gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, params.seqlen_q - m_block * kBlockM
|
228 |
+
);
|
229 |
+
}
|
230 |
+
|
231 |
+
template<typename Kernel_traits, bool Is_causal, typename SharedStorage>
|
232 |
+
__forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_fwd_mla_params ¶ms,
|
233 |
+
const int bidb, const int bidh, const int m_block,
|
234 |
+
const int n_split_idx, const int seqlen_k,
|
235 |
+
const int n_block_min, const int n_block_max, const bool NoSplit,
|
236 |
+
SharedStorage &shared_storage) {
|
237 |
+
constexpr int kBlockM = Kernel_traits::kBlockM;
|
238 |
+
constexpr int kBlockN = Kernel_traits::kBlockN;
|
239 |
+
constexpr int kHeadDim = Kernel_traits::kHeadDim;
|
240 |
+
constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
|
241 |
+
constexpr int kNThreads = Kernel_traits::kNThreads;
|
242 |
+
constexpr int kNThreadsS = Kernel_traits::kNThreadsS;
|
243 |
+
static_assert(kNThreads == 256 and kNThreadsS == 128);
|
244 |
+
using Element = typename Kernel_traits::Element;
|
245 |
+
using index_t = typename Kernel_traits::index_t;
|
246 |
+
|
247 |
+
const int tidx = threadIdx.x;
|
248 |
+
int n_block = n_block_max - 1;
|
249 |
+
|
250 |
+
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{});
|
251 |
+
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{});
|
252 |
+
Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{});
|
253 |
+
Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{});
|
254 |
+
|
255 |
+
Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{});
|
256 |
+
Tensor tPsP = sP(_, tidx % kNThreadsS, _, _);
|
257 |
+
Tensor sScale_o = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), typename Kernel_traits::SmemLayoutRow{});
|
258 |
+
Tensor tScale_osScale_o = sScale_o(_, tidx % kNThreadsS);
|
259 |
+
Tensor sRow_max = make_tensor(make_smem_ptr(shared_storage.smem_max.data()), typename Kernel_traits::SmemLayoutRow{});
|
260 |
+
Tensor tRow_maxsRow_max = sRow_max(_, tidx % kNThreadsS);
|
261 |
+
Tensor sRow_sum = make_tensor(make_smem_ptr(shared_storage.smem_sum.data()), typename Kernel_traits::SmemLayoutRow{});
|
262 |
+
Tensor tRow_sumsRow_sum = sRow_sum(_, tidx % kNThreadsS);
|
263 |
+
|
264 |
+
typename Kernel_traits::TiledMmaO tiled_mma_o;
|
265 |
+
auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);
|
266 |
+
Tensor tOrVt = thr_mma_o.partition_fragment_B(sVt); // (MMA, MMA_K,MMA_N)
|
267 |
+
Tensor tOrO = partition_fragment_C(tiled_mma_o, Shape<Int<kBlockM>, Int<kHeadDimV>>{}); // ((MMA=4, X), MMA_M, MMA_N=1)
|
268 |
+
clear(tOrO);
|
269 |
+
|
270 |
+
flash::Softmax<2 * size<1>(tOrO)> softmax;
|
271 |
+
|
272 |
+
int warp_group_idx = cutlass::canonical_warp_group_idx();
|
273 |
+
if (warp_group_idx == 0) {
|
274 |
+
typename Kernel_traits::TiledMma tiled_mma;
|
275 |
+
auto thr_mma = tiled_mma.get_thread_slice(tidx);
|
276 |
+
Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K)
|
277 |
+
Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
|
278 |
+
|
279 |
+
if (n_block % 2 == 1) {
|
280 |
+
// Double buffer for sK
|
281 |
+
constexpr int sK_offset = size(sK);
|
282 |
+
tSrK.data() = tSrK.data() + sK_offset / 8;
|
283 |
+
tOrVt.data() = tOrVt.data() + sK_offset / 8;
|
284 |
+
}
|
285 |
+
|
286 |
+
// We need masking on S for the very last block when K and V has length not multiple of kBlockN.
|
287 |
+
// We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
|
288 |
+
// We will have at least 1 "masking" iteration.
|
289 |
+
// If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
|
290 |
+
// mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
|
291 |
+
constexpr int n_masking_steps = !Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1;
|
292 |
+
#pragma unroll 1
|
293 |
+
for (int masking_step = n_masking_steps; n_block >= n_block_min; --masking_step, --n_block) {
|
294 |
+
__syncthreads();
|
295 |
+
|
296 |
+
Tensor tSrS = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // ((MMA=4, X), MMA_M, MMA_N=1)
|
297 |
+
flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiled_mma, tSrQ, tSrK, tSrS);
|
298 |
+
|
299 |
+
const bool is_masking_step = masking_step > 0;
|
300 |
+
const bool is_first_masking_step = masking_step == n_masking_steps;
|
301 |
+
|
302 |
+
if (is_masking_step) {
|
303 |
+
Tensor cS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{});
|
304 |
+
Tensor tScS = thr_mma.partition_C(cS);
|
305 |
+
#pragma unroll
|
306 |
+
for (int i = 0; i < size(tSrS); ++i) {
|
307 |
+
if constexpr (!Is_causal) { // Just masking based on col
|
308 |
+
if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) tSrS(i) = -INFINITY;
|
309 |
+
} else {
|
310 |
+
// Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
|
311 |
+
// col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
|
312 |
+
int row = int(get<0>(tScS(i)));
|
313 |
+
int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.seqlen_q - 1 - (m_block * kBlockM + row)) / params.ngroups;
|
314 |
+
if (int(get<1>(tScS(i))) > col_limit_right) tSrS(i) = -INFINITY;
|
315 |
+
}
|
316 |
+
}
|
317 |
+
}
|
318 |
+
|
319 |
+
// We have key_padding_mask so we'll need to Check_inf
|
320 |
+
Tensor scale_o = is_first_masking_step
|
321 |
+
? softmax.template softmax</*Is_first=*/true, /*Check_inf=*/Is_causal>(tSrS, params.scale_softmax_log2)
|
322 |
+
: is_masking_step ?
|
323 |
+
softmax.template softmax</*Is_first=*/false, /*Check_inf=*/Is_causal>(tSrS, params.scale_softmax_log2)
|
324 |
+
: softmax.template softmax</*Is_first=*/false, /*Check_inf=*//*Is_local=*/false>(tSrS, params.scale_softmax_log2);
|
325 |
+
|
326 |
+
Tensor rP = flash::convert_type<Element>(tSrS);
|
327 |
+
cute::copy(rP, tPsP);
|
328 |
+
cute::copy(scale_o, tScale_osScale_o);
|
329 |
+
|
330 |
+
cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast<int>(NamedBarriers::SReady));
|
331 |
+
|
332 |
+
flash::rescale_o(tOrO, scale_o);
|
333 |
+
|
334 |
+
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
|
335 |
+
flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);
|
336 |
+
|
337 |
+
// Double buffer for sK
|
338 |
+
const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
|
339 |
+
tSrK.data() = tSrK.data() + sK_offset / 8;
|
340 |
+
tOrVt.data() = tOrVt.data() + sK_offset / 8;
|
341 |
+
}
|
342 |
+
|
343 |
+
cute::copy(softmax.row_max, tRow_maxsRow_max);
|
344 |
+
cute::copy(softmax.row_sum, tRow_sumsRow_sum);
|
345 |
+
cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast<int>(NamedBarriers::SoftmaxReady));
|
346 |
+
} else {
|
347 |
+
const int *block_table = params.block_table + bidb * params.block_table_batch_stride;
|
348 |
+
int cur_block_table = __ldg(&block_table[n_block]);
|
349 |
+
|
350 |
+
const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
|
351 |
+
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
|
352 |
+
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
353 |
+
make_stride(params.q_row_stride, _1{}));
|
354 |
+
typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_Q;
|
355 |
+
auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx - kNThreadsS);
|
356 |
+
Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ);
|
357 |
+
Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ);
|
358 |
+
Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
359 |
+
Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
|
360 |
+
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
|
361 |
+
|
362 |
+
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
|
363 |
+
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true>(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ,
|
364 |
+
params.seqlen_q - m_block * kBlockM);
|
365 |
+
|
366 |
+
const index_t row_offset_k = (bidh / params.h_h_k_ratio) * params.k_head_stride;
|
367 |
+
Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
|
368 |
+
Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
369 |
+
make_stride(params.k_row_stride, _1{}));
|
370 |
+
typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_K;
|
371 |
+
auto gmem_thr_copy_K = gmem_tiled_copy_K.get_thread_slice(tidx - kNThreadsS);
|
372 |
+
Tensor tKgK = gmem_thr_copy_K.partition_S(gK);
|
373 |
+
Tensor tKsK = gmem_thr_copy_K.partition_D(sK);
|
374 |
+
Tensor cK = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
|
375 |
+
Tensor tKcK = gmem_thr_copy_K.partition_S(cK); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
|
376 |
+
Tensor tKpK = make_tensor<bool>(make_shape(size<2>(tKsK)));
|
377 |
+
|
378 |
+
if (n_block % 2 == 1) {
|
379 |
+
// Double buffer for sK
|
380 |
+
constexpr int sK_offset = size(sK);
|
381 |
+
tKsK.data() = tKsK.data() + sK_offset;
|
382 |
+
tOrVt.data() = tOrVt.data() + sK_offset / 8;
|
383 |
+
}
|
384 |
+
|
385 |
+
// We need to clear the sK smem tiles because K is V.
|
386 |
+
const index_t offset_k = cur_block_table * params.k_batch_stride;
|
387 |
+
tKgK.data() = tKgK.data() + offset_k;
|
388 |
+
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true, /*Clear_OOB_MN=*/true>(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK,
|
389 |
+
seqlen_k - n_block * kBlockN);
|
390 |
+
tKgK.data() = tKgK.data() + -offset_k;
|
391 |
+
cute::cp_async_fence();
|
392 |
+
|
393 |
+
if (n_block - 1 >= n_block_min) {
|
394 |
+
cur_block_table = __ldg(&block_table[n_block - 1]);
|
395 |
+
}
|
396 |
+
|
397 |
+
#pragma unroll 1
|
398 |
+
for (; n_block >= n_block_min; --n_block) {
|
399 |
+
flash::cp_async_wait<0>();
|
400 |
+
__syncthreads();
|
401 |
+
|
402 |
+
if (n_block - 1 >= n_block_min) {
|
403 |
+
// Double buffer for sK
|
404 |
+
const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
|
405 |
+
tKsK.data() = tKsK.data() + sK_offset;
|
406 |
+
|
407 |
+
const index_t offset_k = cur_block_table * params.k_batch_stride;
|
408 |
+
tKgK.data() = tKgK.data() + offset_k;
|
409 |
+
flash::copy</*Is_even_MN=*/true, /*Is_even_K=*/true>(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK);
|
410 |
+
tKgK.data() = tKgK.data() + -offset_k;
|
411 |
+
cute::cp_async_fence();
|
412 |
+
}
|
413 |
+
|
414 |
+
cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::SReady));
|
415 |
+
|
416 |
+
if (n_block - 2 >= n_block_min) {
|
417 |
+
cur_block_table = __ldg(&block_table[n_block - 2]);
|
418 |
+
}
|
419 |
+
|
420 |
+
typename Kernel_traits::TiledMma tiled_mma;
|
421 |
+
auto tSrS_layout = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}).layout();
|
422 |
+
Tensor rP = make_tensor<Element>(tSrS_layout);
|
423 |
+
Tensor scale_o = make_tensor<float>(Shape<_2>{});
|
424 |
+
cute::copy(tScale_osScale_o, scale_o);
|
425 |
+
cute::copy(tPsP, rP);
|
426 |
+
|
427 |
+
flash::rescale_o(tOrO, scale_o);
|
428 |
+
|
429 |
+
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
|
430 |
+
flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);
|
431 |
+
|
432 |
+
// Double buffer for sK
|
433 |
+
const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
|
434 |
+
tOrVt.data() = tOrVt.data() + sK_offset / 8;
|
435 |
+
}
|
436 |
+
|
437 |
+
cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::SoftmaxReady));
|
438 |
+
cute::copy(tRow_maxsRow_max, softmax.row_max);
|
439 |
+
cute::copy(tRow_sumsRow_sum, softmax.row_sum);
|
440 |
+
}
|
441 |
+
|
442 |
+
if (NoSplit)
|
443 |
+
store<Kernel_traits, false>(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax);
|
444 |
+
else
|
445 |
+
store<Kernel_traits, true>(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax);
|
446 |
+
}
|
447 |
+
|
448 |
+
template<typename Kernel_traits, bool Is_causal, typename SharedStorage>
|
449 |
+
__global__ void __launch_bounds__(Kernel_traits::kNThreads, 1, 1)
|
450 |
+
flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params) {
|
451 |
+
constexpr int kBlockN = Kernel_traits::kBlockN;
|
452 |
+
const int m_block = blockIdx.x;
|
453 |
+
const int bidh = blockIdx.y;
|
454 |
+
const int partition_idx = blockIdx.z;
|
455 |
+
|
456 |
+
extern __shared__ char shared_memory[];
|
457 |
+
auto &shared_storage = *reinterpret_cast<SharedStorage *>(shared_memory);
|
458 |
+
|
459 |
+
int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize;
|
460 |
+
int4 tile_scheduler_metadata = __ldg(reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr));
|
461 |
+
int begin_idx = tile_scheduler_metadata.x;
|
462 |
+
int begin_seqlen = tile_scheduler_metadata.y;
|
463 |
+
int end_idx = tile_scheduler_metadata.z;
|
464 |
+
int end_seqlen = tile_scheduler_metadata.w;
|
465 |
+
if (begin_idx >= params.b) return;
|
466 |
+
int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4);
|
467 |
+
|
468 |
+
#pragma unroll 1
|
469 |
+
for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id) {
|
470 |
+
const int n_split_idx = batch_id == begin_idx ? begin_n_split_idx : 0;
|
471 |
+
const int seqlen_k = __ldg(params.cu_seqlens_k + batch_id);
|
472 |
+
const int n_block_min = batch_id == begin_idx ? begin_seqlen / kBlockN : 0;
|
473 |
+
const int n_block_max = batch_id == end_idx ? cute::ceil_div(end_seqlen, kBlockN) : cute::ceil_div(seqlen_k, kBlockN);
|
474 |
+
const bool NoSplit = n_block_min == 0 && n_block_max == cute::ceil_div(seqlen_k, kBlockN);
|
475 |
+
if (batch_id > begin_idx) {
|
476 |
+
__syncthreads(); // Barrier between two tiles.
|
477 |
+
}
|
478 |
+
flash::compute_attn_1rowblock_splitkv_mla<Kernel_traits, Is_causal>(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage);
|
479 |
+
}
|
480 |
+
}
|
481 |
+
|
482 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
483 |
+
|
484 |
+
template<typename Element, typename ElementAccum, typename index_t, int kHeadDimV, int kMaxSplits>
|
485 |
+
__global__ void __launch_bounds__(256, 1, 1)
|
486 |
+
flash_fwd_splitkv_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_params params) {
|
487 |
+
constexpr int kNThreads = 128;
|
488 |
+
|
489 |
+
const int tidx = threadIdx.x;
|
490 |
+
const int bidx = blockIdx.x;
|
491 |
+
const int hs = params.h * params.seqlen_q;
|
492 |
+
const int batch_idx = bidx / hs;
|
493 |
+
const int hs_idx = bidx % hs;
|
494 |
+
|
495 |
+
const int split_offset = __ldg(params.num_splits_ptr + batch_idx);
|
496 |
+
const int actual_num_splits = __ldg(params.num_splits_ptr + batch_idx + 1) - split_offset;
|
497 |
+
FLASH_DEVICE_ASSERT(actual_num_splits <= kMaxSplits);
|
498 |
+
if (actual_num_splits == 1) return;
|
499 |
+
|
500 |
+
__shared__ ElementAccum sLseScale[kMaxSplits];
|
501 |
+
|
502 |
+
const index_t row_offset_lseaccum = split_offset * hs + hs_idx;
|
503 |
+
const index_t row_offset_lse = bidx;
|
504 |
+
Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lseaccum_ptr) + row_offset_lseaccum),
|
505 |
+
Shape<Int<kMaxSplits>>{}, make_stride(hs));
|
506 |
+
Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
|
507 |
+
Shape<_1>{}, Stride<_1>{});
|
508 |
+
|
509 |
+
int warp_idx = cutlass::canonical_warp_idx_sync();
|
510 |
+
if (warp_idx == 0) {
|
511 |
+
constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32);
|
512 |
+
|
513 |
+
float local_lse[kNLsePerThread];
|
514 |
+
for (int i = 0; i < kNLsePerThread; ++i) {
|
515 |
+
const int split = i * 32 + tidx;
|
516 |
+
local_lse[i] = split < actual_num_splits ? gLSEaccum(split) : -INFINITY;
|
517 |
+
}
|
518 |
+
|
519 |
+
float max_lse = -INFINITY;
|
520 |
+
for (int i = 0; i < kNLsePerThread; ++i) max_lse = max(max_lse, local_lse[i]);
|
521 |
+
for (int offset = 16; offset >= 1; offset /= 2) max_lse = max(max_lse, __shfl_xor_sync(uint32_t(-1), max_lse, offset));
|
522 |
+
max_lse = max_lse == -INFINITY ? 0.0f : max_lse; // In case all local LSEs are -inf
|
523 |
+
|
524 |
+
float sum_lse = 0;
|
525 |
+
for (int i = 0; i < kNLsePerThread; ++i) sum_lse = sum_lse + expf(local_lse[i] - max_lse);
|
526 |
+
for (int offset = 16; offset >= 1; offset /= 2) sum_lse = sum_lse + __shfl_xor_sync(uint32_t(-1), sum_lse, offset);
|
527 |
+
|
528 |
+
float global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? INFINITY : logf(sum_lse) + max_lse;
|
529 |
+
if (tidx == 0) gLSE(0) = global_lse;
|
530 |
+
|
531 |
+
for (int i = 0; i < kNLsePerThread; ++i) {
|
532 |
+
const int split = i * 32 + tidx;
|
533 |
+
if (split < actual_num_splits) sLseScale[split] = expf(local_lse[i] - global_lse);
|
534 |
+
}
|
535 |
+
}
|
536 |
+
__syncthreads();
|
537 |
+
|
538 |
+
static_assert(kHeadDimV % kNThreads == 0);
|
539 |
+
constexpr int Elements = kHeadDimV / kNThreads;
|
540 |
+
const index_t row_offset_oaccum = (split_offset * hs + hs_idx) * kHeadDimV;
|
541 |
+
Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.oaccum_ptr) + row_offset_oaccum),
|
542 |
+
Shape<Int<kHeadDimV>>{}, Stride<_1>{});
|
543 |
+
using GmemTiledCopyOaccum = decltype(make_tiled_copy(
|
544 |
+
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
|
545 |
+
Layout<Shape<Int<kNThreads>>>{},
|
546 |
+
Layout<Shape<Int<Elements>>>{}));
|
547 |
+
GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
|
548 |
+
auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
|
549 |
+
Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum);
|
550 |
+
Tensor tOrOaccum = make_tensor<ElementAccum>(shape(tOgOaccum));
|
551 |
+
Tensor tOrO = make_tensor<ElementAccum>(shape(tOgOaccum));
|
552 |
+
clear(tOrO);
|
553 |
+
|
554 |
+
for (int split = 0; split < actual_num_splits; ++split) {
|
555 |
+
cute::copy(tOgOaccum, tOrOaccum);
|
556 |
+
ElementAccum lse_scale = sLseScale[split];
|
557 |
+
for (int i = 0; i < size(tOrO); ++i) {
|
558 |
+
tOrO(i) += lse_scale * tOrOaccum(i);
|
559 |
+
}
|
560 |
+
tOgOaccum.data() = tOgOaccum.data() + hs * kHeadDimV;
|
561 |
+
}
|
562 |
+
|
563 |
+
Tensor rO = flash::convert_type<Element>(tOrO);
|
564 |
+
const int head_idx = (bidx - batch_idx * hs) / params.seqlen_q;
|
565 |
+
const int row = bidx - batch_idx * hs - head_idx * params.seqlen_q;
|
566 |
+
auto o_ptr = reinterpret_cast<Element *>(params.o_ptr) + batch_idx * params.o_batch_stride + head_idx * params.o_head_stride + row * params.o_row_stride;
|
567 |
+
Tensor gO = make_tensor(make_gmem_ptr(o_ptr + tidx * Elements), Shape<Int<decltype(size<0>(rO))::value>>{}, Stride<_1>{});
|
568 |
+
cute::copy(rO, gO);
|
569 |
+
}
|
570 |
+
|
571 |
+
} // namespace flash
|
572 |
+
|
573 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
574 |
+
|
575 |
+
template<typename Kernel_traits, typename SharedStorage>
|
576 |
+
void run_flash_splitkv_fwd_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream) {
|
577 |
+
FLASH_ASSERT(params.page_block_size == Kernel_traits::kBlockN);
|
578 |
+
const int num_m_block = cute::ceil_div(params.seqlen_q, Kernel_traits::kBlockM);
|
579 |
+
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
580 |
+
auto kernel = &flash::flash_fwd_splitkv_mla_kernel<Kernel_traits, Is_causal, SharedStorage>;
|
581 |
+
constexpr size_t smem_size = sizeof(SharedStorage);
|
582 |
+
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
583 |
+
kernel<<<dim3(num_m_block, params.h, params.num_sm_parts), Kernel_traits::kNThreads, smem_size, stream>>>(params);
|
584 |
+
});
|
585 |
+
CHECK_CUDA_KERNEL_LAUNCH();
|
586 |
+
|
587 |
+
dim3 grid_combine(params.b * params.h * params.seqlen_q);
|
588 |
+
MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, kMaxSplits, [&] {
|
589 |
+
auto combine_kernel = &flash::flash_fwd_splitkv_mla_combine_kernel<
|
590 |
+
typename Kernel_traits::Element, typename Kernel_traits::ElementAccum, typename Kernel_traits::index_t, Kernel_traits::kHeadDimV, kMaxSplits>;
|
591 |
+
combine_kernel<<<grid_combine, 128, 0, stream>>>(params);
|
592 |
+
});
|
593 |
+
CHECK_CUDA_KERNEL_LAUNCH();
|
594 |
+
}
|
595 |
+
|
596 |
+
template<typename T, int Headdim>
|
597 |
+
void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream) {
|
598 |
+
static_assert(Headdim == 576);
|
599 |
+
FLASH_ASSERT(params.d_v == 512);
|
600 |
+
FLASH_ASSERT(params.k_ptr == params.v_ptr); // Shared_KV
|
601 |
+
using Kernel_traits = Flash_fwd_kernel_traits_mla<576, 64, 64, 8, T, 512>;
|
602 |
+
run_flash_splitkv_fwd_mla<Kernel_traits, flash::SharedStorageMLA<Kernel_traits>>(params, stream);
|
603 |
+
}
|
flash_mla/flash_fwd_mla_metadata.cu
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include "flash_fwd_mla_kernel.h"
|
2 |
+
|
3 |
+
static constexpr int MaxBatchSize = 4096;
|
4 |
+
|
5 |
+
__global__ void __launch_bounds__(256, 1, 1)
|
6 |
+
get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) {
|
7 |
+
int *seqlens_k_ptr = params.seqlens_k_ptr;
|
8 |
+
int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr;
|
9 |
+
int *num_splits_ptr = params.num_splits_ptr;
|
10 |
+
int batch_size = params.batch_size;
|
11 |
+
int block_size_n = params.block_size_n;
|
12 |
+
int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks;
|
13 |
+
int num_sm_parts = params.num_sm_parts;
|
14 |
+
|
15 |
+
__shared__ int num_blocks_shared[MaxBatchSize];
|
16 |
+
__shared__ int num_splits_shared[MaxBatchSize];
|
17 |
+
|
18 |
+
int total_num_blocks = 0;
|
19 |
+
for (int i = threadIdx.x; i < batch_size; i += 32) {
|
20 |
+
int num_blocks = cutlass::ceil_div(seqlens_k_ptr[i], block_size_n);
|
21 |
+
total_num_blocks += num_blocks + fixed_overhead_num_blocks;
|
22 |
+
num_blocks_shared[i] = num_blocks;
|
23 |
+
}
|
24 |
+
for (int offset = 16; offset >= 1; offset /= 2) {
|
25 |
+
total_num_blocks += __shfl_xor_sync(uint32_t(-1), total_num_blocks, offset);
|
26 |
+
}
|
27 |
+
__syncwarp();
|
28 |
+
|
29 |
+
if (threadIdx.x == 0) {
|
30 |
+
int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks;
|
31 |
+
|
32 |
+
int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0;
|
33 |
+
num_splits_shared[0] = 0;
|
34 |
+
for (int i = 0; i < num_sm_parts; ++i) {
|
35 |
+
int tile_scheduler_metadata0[4], tile_scheduler_metadata1;
|
36 |
+
tile_scheduler_metadata0[0] = now_idx;
|
37 |
+
tile_scheduler_metadata0[1] = now_block * block_size_n;
|
38 |
+
tile_scheduler_metadata1 = now_n_split_idx;
|
39 |
+
int remain_payload = payload;
|
40 |
+
while (now_idx < batch_size) {
|
41 |
+
int num_blocks = num_blocks_shared[now_idx];
|
42 |
+
int now_remain_blocks = num_blocks - now_block;
|
43 |
+
if (remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) {
|
44 |
+
cum_num_splits += now_n_split_idx + 1;
|
45 |
+
num_splits_shared[now_idx + 1] = cum_num_splits;
|
46 |
+
remain_payload -= now_remain_blocks + fixed_overhead_num_blocks;
|
47 |
+
++now_idx;
|
48 |
+
now_block = 0;
|
49 |
+
now_n_split_idx = 0;
|
50 |
+
} else {
|
51 |
+
if (remain_payload - fixed_overhead_num_blocks > 0) {
|
52 |
+
now_block += remain_payload - fixed_overhead_num_blocks;
|
53 |
+
++now_n_split_idx;
|
54 |
+
remain_payload = 0;
|
55 |
+
}
|
56 |
+
break;
|
57 |
+
}
|
58 |
+
}
|
59 |
+
tile_scheduler_metadata0[2] = now_block > 0 ? now_idx : now_idx - 1;
|
60 |
+
tile_scheduler_metadata0[3] = now_block > 0 ? now_block * block_size_n : seqlens_k_ptr[now_idx - 1];
|
61 |
+
*reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr + i * TileSchedulerMetaDataSize) = *reinterpret_cast<int4 *>(tile_scheduler_metadata0);
|
62 |
+
tile_scheduler_metadata_ptr[i * TileSchedulerMetaDataSize + 4] = tile_scheduler_metadata1;
|
63 |
+
}
|
64 |
+
FLASH_DEVICE_ASSERT(now_idx == batch_size && now_block == 0 && now_n_split_idx == 0);
|
65 |
+
}
|
66 |
+
__syncwarp();
|
67 |
+
|
68 |
+
for (int i = threadIdx.x; i <= batch_size; i += 32) {
|
69 |
+
num_splits_ptr[i] = num_splits_shared[i];
|
70 |
+
}
|
71 |
+
}
|
72 |
+
|
73 |
+
void get_mla_metadata_func(Mla_metadata_params ¶ms, cudaStream_t stream) {
|
74 |
+
FLASH_ASSERT(params.batch_size < MaxBatchSize);
|
75 |
+
get_mla_metadata_kernel<<<1, 32, 0, stream>>>(params);
|
76 |
+
CHECK_CUDA_KERNEL_LAUNCH();
|
77 |
+
}
|
flash_mla/flash_mla.h
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#pragma once
|
2 |
+
|
3 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
4 |
+
|
5 |
+
struct Flash_fwd_mla_params {
|
6 |
+
using index_t = int64_t;
|
7 |
+
|
8 |
+
int b, seqlen_q, d, d_v;
|
9 |
+
int h, h_h_k_ratio, ngroups;
|
10 |
+
bool is_causal;
|
11 |
+
float scale_softmax, scale_softmax_log2;
|
12 |
+
int *__restrict__ cu_seqlens_k;
|
13 |
+
|
14 |
+
void *__restrict__ q_ptr;
|
15 |
+
void *__restrict__ k_ptr;
|
16 |
+
void *__restrict__ v_ptr;
|
17 |
+
void *__restrict__ o_ptr;
|
18 |
+
void *__restrict__ softmax_lse_ptr;
|
19 |
+
|
20 |
+
index_t q_batch_stride;
|
21 |
+
index_t k_batch_stride;
|
22 |
+
index_t v_batch_stride;
|
23 |
+
index_t o_batch_stride;
|
24 |
+
index_t q_row_stride;
|
25 |
+
index_t k_row_stride;
|
26 |
+
index_t v_row_stride;
|
27 |
+
index_t o_row_stride;
|
28 |
+
index_t q_head_stride;
|
29 |
+
index_t k_head_stride;
|
30 |
+
index_t v_head_stride;
|
31 |
+
index_t o_head_stride;
|
32 |
+
|
33 |
+
int *__restrict__ block_table;
|
34 |
+
index_t block_table_batch_stride;
|
35 |
+
int page_block_size;
|
36 |
+
|
37 |
+
int *__restrict__ tile_scheduler_metadata_ptr;
|
38 |
+
int num_sm_parts;
|
39 |
+
int *__restrict__ num_splits_ptr;
|
40 |
+
|
41 |
+
void *__restrict__ softmax_lseaccum_ptr;
|
42 |
+
void *__restrict__ oaccum_ptr;
|
43 |
+
};
|
44 |
+
|
45 |
+
static constexpr int TileSchedulerMetaDataSize = 8;
|
46 |
+
// [begin_idx, begin_seqlen, end_idx, end_seqlen, begin_n_split_idx, _, _, _]
|
47 |
+
|
48 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
49 |
+
|
50 |
+
template<typename T, int Headdim>
|
51 |
+
void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream);
|
52 |
+
|
53 |
+
struct Mla_metadata_params {
|
54 |
+
int *__restrict__ seqlens_k_ptr;
|
55 |
+
int *__restrict__ tile_scheduler_metadata_ptr;
|
56 |
+
int *__restrict__ num_splits_ptr;
|
57 |
+
int batch_size;
|
58 |
+
int block_size_n;
|
59 |
+
int fixed_overhead_num_blocks;
|
60 |
+
int num_sm_parts;
|
61 |
+
};
|
62 |
+
|
63 |
+
void get_mla_metadata_func(Mla_metadata_params ¶ms, cudaStream_t stream);
|
flash_mla/flash_mla_api.cu
ADDED
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <ATen/cuda/CUDAContext.h>
|
2 |
+
#include <c10/cuda/CUDAGuard.h>
|
3 |
+
#include <torch/all.h>
|
4 |
+
|
5 |
+
|
6 |
+
#include <cutlass/fast_math.h>
|
7 |
+
|
8 |
+
#include "flash_mla.h"
|
9 |
+
#include "static_switch.h"
|
10 |
+
|
11 |
+
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
|
12 |
+
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
13 |
+
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
14 |
+
|
15 |
+
|
16 |
+
//
|
17 |
+
|
18 |
+
|
19 |
+
// #include <cmath>
|
20 |
+
|
21 |
+
// #include "cute/tensor.hpp"
|
22 |
+
#include <cute/tensor.hpp>
|
23 |
+
|
24 |
+
// __global__ void relu_kernel(float *__restrict__ out,
|
25 |
+
// float const *__restrict__ input,
|
26 |
+
// const int d) {
|
27 |
+
// const int64_t token_idx = blockIdx.x;
|
28 |
+
// for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
29 |
+
// auto x = input[token_idx * d + idx];
|
30 |
+
// out[token_idx * d + idx] = x > 0.0f ? x : 0.0f;
|
31 |
+
// }
|
32 |
+
// }
|
33 |
+
|
34 |
+
// void relu(torch::Tensor &out,
|
35 |
+
// torch::Tensor const &input)
|
36 |
+
// {
|
37 |
+
// TORCH_CHECK(input.scalar_type() == at::ScalarType::Float &&
|
38 |
+
// input.scalar_type() == at::ScalarType::Float,
|
39 |
+
// "relu_kernel only supports float32");
|
40 |
+
|
41 |
+
// int d = input.size(-1);
|
42 |
+
// int64_t num_tokens = input.numel() / d;
|
43 |
+
// dim3 grid(num_tokens);
|
44 |
+
// dim3 block(std::min(d, 1024));
|
45 |
+
// const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
46 |
+
// const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
47 |
+
// relu_kernel<<<grid, block, 0, stream>>>(out.data_ptr<float>(),
|
48 |
+
// input.data_ptr<float>(), d);
|
49 |
+
// }
|
50 |
+
|
51 |
+
std::vector<at::Tensor>
|
52 |
+
get_mla_metadata(
|
53 |
+
at::Tensor &seqlens_k,
|
54 |
+
const int64_t num_heads_per_head_k,
|
55 |
+
const int64_t num_heads_k
|
56 |
+
) {
|
57 |
+
// This should match the logic in the MLA kernel.
|
58 |
+
static constexpr int block_size_m = 64;
|
59 |
+
static constexpr int block_size_n = 64;
|
60 |
+
static constexpr int fixed_overhead_num_blocks = 5;
|
61 |
+
|
62 |
+
CHECK_DEVICE(seqlens_k);
|
63 |
+
TORCH_CHECK(seqlens_k.is_contiguous());
|
64 |
+
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32);
|
65 |
+
|
66 |
+
int batch_size = seqlens_k.size(0);
|
67 |
+
int *seqlens_k_ptr = seqlens_k.data_ptr<int>();
|
68 |
+
auto options = seqlens_k.options();
|
69 |
+
|
70 |
+
auto dprops = at::cuda::getCurrentDeviceProperties();
|
71 |
+
int sm_count = dprops->multiProcessorCount;
|
72 |
+
int num_sm_parts = sm_count / num_heads_k / cutlass::ceil_div(num_heads_per_head_k, block_size_m);
|
73 |
+
|
74 |
+
auto tile_scheduler_metadata = torch::empty({num_sm_parts, TileSchedulerMetaDataSize}, options);
|
75 |
+
auto num_splits = torch::empty({batch_size + 1}, options);
|
76 |
+
int *tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr<int>();
|
77 |
+
int *num_splits_ptr = num_splits.data_ptr<int>();
|
78 |
+
|
79 |
+
at::cuda::CUDAGuard device_guard{(char)seqlens_k.get_device()};
|
80 |
+
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
81 |
+
Mla_metadata_params params = {};
|
82 |
+
params.seqlens_k_ptr = seqlens_k_ptr;
|
83 |
+
params.tile_scheduler_metadata_ptr = tile_scheduler_metadata_ptr;
|
84 |
+
params.num_splits_ptr = num_splits_ptr;
|
85 |
+
params.batch_size = batch_size;
|
86 |
+
params.block_size_n = block_size_n;
|
87 |
+
params.fixed_overhead_num_blocks = fixed_overhead_num_blocks;
|
88 |
+
params.num_sm_parts = num_sm_parts;
|
89 |
+
get_mla_metadata_func(params, stream);
|
90 |
+
|
91 |
+
return {tile_scheduler_metadata, num_splits};
|
92 |
+
}
|
93 |
+
|
94 |
+
std::vector<at::Tensor>
|
95 |
+
mha_fwd_kvcache_mla(
|
96 |
+
at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
97 |
+
const at::Tensor &kcache, // num_blocks x page_block_size x num_heads_k x head_size
|
98 |
+
|
99 |
+
// TODO: fix for optional
|
100 |
+
// std::optional<const at::Tensor> &vcache_, // num_blocks x page_block_size x num_heads_k x head_size_v
|
101 |
+
|
102 |
+
const at::Tensor &vcache_, // num_blocks x page_block_size x num_heads_k x head_size_v
|
103 |
+
const int64_t head_size_v,
|
104 |
+
const at::Tensor &seqlens_k, // batch_size
|
105 |
+
const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq
|
106 |
+
// TODO: should be float
|
107 |
+
const double softmax_scale,
|
108 |
+
const bool is_causal_,
|
109 |
+
const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize
|
110 |
+
const at::Tensor &num_splits, // batch_size + 1
|
111 |
+
|
112 |
+
// TODO: remove this once determined why build is adding this parameter
|
113 |
+
const int64_t unknown_param
|
114 |
+
) {
|
115 |
+
auto dprops = at::cuda::getCurrentDeviceProperties();
|
116 |
+
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
|
117 |
+
TORCH_CHECK(is_sm90);
|
118 |
+
|
119 |
+
// TODO: fix for mutable bool
|
120 |
+
bool is_causal = is_causal_;
|
121 |
+
|
122 |
+
|
123 |
+
// TODO: fix for optional
|
124 |
+
// at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;
|
125 |
+
at::Tensor vcache = vcache_;
|
126 |
+
|
127 |
+
auto q_dtype = q.dtype();
|
128 |
+
TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
|
129 |
+
|
130 |
+
CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
|
131 |
+
|
132 |
+
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
133 |
+
TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
134 |
+
TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
135 |
+
|
136 |
+
CHECK_DEVICE(block_table);
|
137 |
+
TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
|
138 |
+
TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
|
139 |
+
|
140 |
+
const auto sizes = q.sizes();
|
141 |
+
const int batch_size = sizes[0];
|
142 |
+
const int seqlen_q_ori = sizes[1];
|
143 |
+
const int num_heads_ori = sizes[2];
|
144 |
+
const int head_size = sizes[3];
|
145 |
+
TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
|
146 |
+
TORCH_CHECK(head_size_v % 32 == 0, "head_size_v should be a multiple of 32");
|
147 |
+
|
148 |
+
const int max_num_blocks_per_seq = block_table.size(1);
|
149 |
+
const int num_blocks = kcache.size(0);
|
150 |
+
const int page_block_size = kcache.size(1);
|
151 |
+
const int num_heads_k = kcache.size(2);
|
152 |
+
TORCH_CHECK(batch_size > 0, "batch size must be postive");
|
153 |
+
TORCH_CHECK(num_heads_ori % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
154 |
+
|
155 |
+
if (seqlen_q_ori == 1) { is_causal = false; }
|
156 |
+
|
157 |
+
const int ngroups = num_heads_ori / num_heads_k;
|
158 |
+
const int seqlen_q = seqlen_q_ori * ngroups;
|
159 |
+
const int num_heads = num_heads_k;
|
160 |
+
q = q.view({batch_size, seqlen_q_ori, num_heads_k, ngroups, head_size}).transpose(2, 3)
|
161 |
+
.reshape({batch_size, seqlen_q, num_heads, head_size});
|
162 |
+
|
163 |
+
int head_size_k = head_size;
|
164 |
+
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
|
165 |
+
CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k);
|
166 |
+
|
167 |
+
// TODO: fix for optional
|
168 |
+
// if (vcache_.has_value()) { CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_v); }
|
169 |
+
CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_v);
|
170 |
+
|
171 |
+
CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
|
172 |
+
|
173 |
+
|
174 |
+
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
|
175 |
+
CHECK_DEVICE(seqlens_k);
|
176 |
+
CHECK_CONTIGUOUS(seqlens_k);
|
177 |
+
CHECK_SHAPE(seqlens_k, batch_size);
|
178 |
+
|
179 |
+
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
|
180 |
+
|
181 |
+
auto opts = q.options();
|
182 |
+
at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts);
|
183 |
+
at::Tensor softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
|
184 |
+
|
185 |
+
Flash_fwd_mla_params params = {};
|
186 |
+
// Set the sizes.
|
187 |
+
params.b = batch_size;
|
188 |
+
params.seqlen_q = seqlen_q;
|
189 |
+
params.cu_seqlens_k = seqlens_k.data_ptr<int>();
|
190 |
+
params.h = num_heads;
|
191 |
+
params.h_h_k_ratio = num_heads / num_heads_k;
|
192 |
+
params.ngroups = ngroups;
|
193 |
+
params.is_causal = is_causal;
|
194 |
+
params.d = head_size;
|
195 |
+
params.d_v = head_size_v;
|
196 |
+
params.scale_softmax = softmax_scale;
|
197 |
+
params.scale_softmax_log2 = float(softmax_scale * M_LOG2E);
|
198 |
+
// Set the pointers and strides.
|
199 |
+
params.q_ptr = q.data_ptr();
|
200 |
+
params.k_ptr = kcache.data_ptr();
|
201 |
+
params.v_ptr = vcache.data_ptr();
|
202 |
+
params.o_ptr = out.data_ptr();
|
203 |
+
params.softmax_lse_ptr = softmax_lse.data_ptr();
|
204 |
+
// All stride are in elements, not bytes.
|
205 |
+
params.q_batch_stride = q.stride(0);
|
206 |
+
params.k_batch_stride = kcache.stride(0);
|
207 |
+
params.v_batch_stride = vcache.stride(0);
|
208 |
+
params.o_batch_stride = out.stride(0);
|
209 |
+
params.q_row_stride = q.stride(-3);
|
210 |
+
params.k_row_stride = kcache.stride(-3);
|
211 |
+
params.v_row_stride = vcache.stride(-3);
|
212 |
+
params.o_row_stride = out.stride(-3);
|
213 |
+
params.q_head_stride = q.stride(-2);
|
214 |
+
params.k_head_stride = kcache.stride(-2);
|
215 |
+
params.v_head_stride = vcache.stride(-2);
|
216 |
+
params.o_head_stride = out.stride(-2);
|
217 |
+
|
218 |
+
params.block_table = block_table.data_ptr<int>();
|
219 |
+
params.block_table_batch_stride = block_table.stride(0);
|
220 |
+
params.page_block_size = page_block_size;
|
221 |
+
|
222 |
+
TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32");
|
223 |
+
TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize);
|
224 |
+
CHECK_DEVICE(tile_scheduler_metadata);
|
225 |
+
CHECK_CONTIGUOUS(tile_scheduler_metadata);
|
226 |
+
params.tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr<int>();
|
227 |
+
params.num_sm_parts = tile_scheduler_metadata.size(0);
|
228 |
+
TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32");
|
229 |
+
CHECK_DEVICE(num_splits);
|
230 |
+
CHECK_CONTIGUOUS(num_splits);
|
231 |
+
params.num_splits_ptr = num_splits.data_ptr<int>();
|
232 |
+
|
233 |
+
at::Tensor softmax_lse_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q}, opts.dtype(at::kFloat));
|
234 |
+
at::Tensor out_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q, head_size_v}, opts.dtype(at::kFloat));
|
235 |
+
params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
|
236 |
+
params.oaccum_ptr = out_accum.data_ptr();
|
237 |
+
|
238 |
+
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
239 |
+
TORCH_CHECK(head_size == 576);
|
240 |
+
|
241 |
+
if (q_dtype == torch::kBFloat16) {
|
242 |
+
run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(params, stream);
|
243 |
+
}
|
244 |
+
#ifndef FLASH_MLA_DISABLE_FP16
|
245 |
+
else if (q_dtype == torch::kHalf) {
|
246 |
+
run_mha_fwd_splitkv_mla<cutlass::half_t, 576>(params, stream);
|
247 |
+
}
|
248 |
+
#endif
|
249 |
+
else {
|
250 |
+
TORCH_CHECK(false, "Unsupported tensor dtype for query");
|
251 |
+
}
|
252 |
+
|
253 |
+
out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3)
|
254 |
+
.reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v});
|
255 |
+
softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, ngroups}).transpose(2, 3)
|
256 |
+
.reshape({batch_size, num_heads_ori, seqlen_q_ori});
|
257 |
+
|
258 |
+
return {out, softmax_lse};
|
259 |
+
}
|
flash_mla/named_barrier.h
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#include "cutlass/barrier.h"
|
4 |
+
|
5 |
+
namespace flash {
|
6 |
+
|
7 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
8 |
+
// Enumerates the reserved named barriers to avoid potential conflicts
|
9 |
+
|
10 |
+
enum class NamedBarriers {
|
11 |
+
SReady = 1,
|
12 |
+
SoftmaxReady = 2,
|
13 |
+
};
|
14 |
+
|
15 |
+
} // flash
|
flash_mla/softmax.h
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/softmax.h
|
2 |
+
|
3 |
+
#pragma once
|
4 |
+
|
5 |
+
#include <cmath>
|
6 |
+
|
7 |
+
#include <cute/tensor.hpp>
|
8 |
+
#include <cutlass/numeric_types.h>
|
9 |
+
|
10 |
+
#include "utils.h"
|
11 |
+
|
12 |
+
namespace flash {
|
13 |
+
|
14 |
+
using namespace cute;
|
15 |
+
|
16 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
17 |
+
|
18 |
+
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
|
19 |
+
__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
|
20 |
+
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
21 |
+
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
22 |
+
CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
|
23 |
+
#pragma unroll
|
24 |
+
for (int mi = 0; mi < size<0>(tensor); mi++) {
|
25 |
+
summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
|
26 |
+
#pragma unroll
|
27 |
+
for (int ni = 1; ni < size<1>(tensor); ni++) {
|
28 |
+
summary(mi) = op(summary(mi), tensor(mi, ni));
|
29 |
+
}
|
30 |
+
}
|
31 |
+
}
|
32 |
+
|
33 |
+
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
|
34 |
+
__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
|
35 |
+
CUTE_STATIC_ASSERT_V(size(dst) == size(src));
|
36 |
+
#pragma unroll
|
37 |
+
for (int i = 0; i < size(dst); i++){
|
38 |
+
dst(i) = Allreduce<4>::run(src(i), op);
|
39 |
+
}
|
40 |
+
}
|
41 |
+
|
42 |
+
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
|
43 |
+
__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
|
44 |
+
thread_reduce_<zero_init>(tensor, summary, op);
|
45 |
+
quad_allreduce_(summary, summary, op);
|
46 |
+
}
|
47 |
+
|
48 |
+
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
49 |
+
__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
|
50 |
+
MaxOp<float> max_op;
|
51 |
+
reduce_<zero_init>(tensor, max, max_op);
|
52 |
+
}
|
53 |
+
|
54 |
+
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
55 |
+
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
|
56 |
+
SumOp<float> sum_op;
|
57 |
+
thread_reduce_<zero_init>(tensor, sum, sum_op);
|
58 |
+
}
|
59 |
+
|
60 |
+
// Apply the exp to all the elements.
|
61 |
+
template <bool Scale_max=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
62 |
+
__forceinline__ __device__ auto scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
|
63 |
+
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
64 |
+
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
65 |
+
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
|
66 |
+
#pragma unroll
|
67 |
+
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
68 |
+
// If max is -inf, then all elements must have been -inf (possibly due to masking).
|
69 |
+
// We don't want (-inf - (-inf)) since that would give NaN.
|
70 |
+
// If we don't have float around M_LOG2E the multiplication is done in fp64.
|
71 |
+
const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E));
|
72 |
+
#pragma unroll
|
73 |
+
for (int ni = 0; ni < size<1>(tensor); ++ni) {
|
74 |
+
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
|
75 |
+
// max * log_2(e)) This allows the compiler to use the ffma
|
76 |
+
// instruction instead of fadd and fmul separately.
|
77 |
+
// The following macro will disable the use of fma.
|
78 |
+
// See: https://github.com/pytorch/pytorch/issues/121558 for more details
|
79 |
+
// This macro is set in PyTorch and not FlashAttention
|
80 |
+
#ifdef UNFUSE_FMA
|
81 |
+
tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled);
|
82 |
+
#else
|
83 |
+
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
|
84 |
+
#endif
|
85 |
+
}
|
86 |
+
}
|
87 |
+
return tensor;
|
88 |
+
}
|
89 |
+
|
90 |
+
// Apply the exp to all the elements.
|
91 |
+
template <bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
92 |
+
__forceinline__ __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
|
93 |
+
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
94 |
+
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
95 |
+
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
|
96 |
+
#pragma unroll
|
97 |
+
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
98 |
+
MaxOp<float> max_op;
|
99 |
+
max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0));
|
100 |
+
#pragma unroll
|
101 |
+
for (int ni = 1; ni < size<1>(tensor); ni++) {
|
102 |
+
max(mi) = max_op(max(mi), tensor(mi, ni));
|
103 |
+
}
|
104 |
+
max(mi) = Allreduce<4>::run(max(mi), max_op);
|
105 |
+
// If max is -inf, then all elements must have been -inf (possibly due to masking).
|
106 |
+
// We don't want (-inf - (-inf)) since that would give NaN.
|
107 |
+
const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale;
|
108 |
+
sum(mi) = 0;
|
109 |
+
#pragma unroll
|
110 |
+
for (int ni = 0; ni < size<1>(tensor); ++ni) {
|
111 |
+
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
|
112 |
+
// max * log_2(e)) This allows the compiler to use the ffma
|
113 |
+
// instruction instead of fadd and fmul separately.
|
114 |
+
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
|
115 |
+
sum(mi) += tensor(mi, ni);
|
116 |
+
}
|
117 |
+
SumOp<float> sum_op;
|
118 |
+
sum(mi) = Allreduce<4>::run(sum(mi), sum_op);
|
119 |
+
}
|
120 |
+
}
|
121 |
+
|
122 |
+
template<typename Tensor0, typename Tensor1>
|
123 |
+
__forceinline__ __device__ void rescale_o(Tensor0 &acc_o, Tensor1 &scale_o) {
|
124 |
+
// Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
|
125 |
+
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
|
126 |
+
#pragma unroll
|
127 |
+
for (int mi = 0; mi < size(scale_o); ++mi) {
|
128 |
+
#pragma unroll
|
129 |
+
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale_o(mi); }
|
130 |
+
}
|
131 |
+
}
|
132 |
+
|
133 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
134 |
+
|
135 |
+
template <int kNRows>
|
136 |
+
struct Softmax {
|
137 |
+
|
138 |
+
using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
|
139 |
+
TensorT row_max, row_sum;
|
140 |
+
|
141 |
+
__forceinline__ __device__ Softmax() {};
|
142 |
+
|
143 |
+
template<bool Is_first, bool Check_inf=false, typename Tensor0>
|
144 |
+
__forceinline__ __device__ TensorT softmax(Tensor0 &acc_s, float softmax_scale_log2) {
|
145 |
+
// Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
|
146 |
+
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
|
147 |
+
static_assert(decltype(size<0>(scores))::value == kNRows);
|
148 |
+
TensorT scale_o;
|
149 |
+
clear(scale_o);
|
150 |
+
if (Is_first) {
|
151 |
+
flash::template reduce_max</*zero_init=*/true>(scores, row_max);
|
152 |
+
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
|
153 |
+
flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
|
154 |
+
} else {
|
155 |
+
Tensor scores_max_prev = make_fragment_like(row_max);
|
156 |
+
cute::copy(row_max, scores_max_prev);
|
157 |
+
flash::template reduce_max</*zero_init=*/false>(scores, row_max);
|
158 |
+
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
|
159 |
+
#pragma unroll
|
160 |
+
for (int mi = 0; mi < size(row_max); ++mi) {
|
161 |
+
float scores_max_cur = !Check_inf
|
162 |
+
? row_max(mi)
|
163 |
+
: (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
|
164 |
+
float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
|
165 |
+
scale_o(mi) = scores_scale;
|
166 |
+
row_sum(mi) *= scores_scale;
|
167 |
+
}
|
168 |
+
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
|
169 |
+
// We don't do the reduce across threads here since we don't need to use the row_sum.
|
170 |
+
// We do that reduce at the end when we need to normalize the softmax.
|
171 |
+
flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
|
172 |
+
}
|
173 |
+
return scale_o;
|
174 |
+
};
|
175 |
+
|
176 |
+
template<bool Is_dropout=false, bool Split=false, typename Tensor0>
|
177 |
+
__forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {
|
178 |
+
SumOp<float> sum_op;
|
179 |
+
quad_allreduce_(row_sum, row_sum, sum_op);
|
180 |
+
TensorT lse = make_fragment_like(row_sum);
|
181 |
+
// Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
|
182 |
+
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
|
183 |
+
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
|
184 |
+
#pragma unroll
|
185 |
+
for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
|
186 |
+
float sum = row_sum(mi);
|
187 |
+
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
|
188 |
+
lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
|
189 |
+
float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
|
190 |
+
#pragma unroll
|
191 |
+
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
|
192 |
+
}
|
193 |
+
return lse;
|
194 |
+
};
|
195 |
+
};
|
196 |
+
|
197 |
+
} // namespace flash
|
flash_mla/static_switch.h
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#define CHECK_CUDA(call) \
|
4 |
+
do { \
|
5 |
+
cudaError_t status_ = call; \
|
6 |
+
if (status_ != cudaSuccess) { \
|
7 |
+
fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
|
8 |
+
exit(1); \
|
9 |
+
} \
|
10 |
+
} while(0)
|
11 |
+
|
12 |
+
#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError())
|
13 |
+
|
14 |
+
|
15 |
+
#define FLASH_ASSERT(cond) \
|
16 |
+
do { \
|
17 |
+
if (not (cond)) { \
|
18 |
+
fprintf(stderr, "Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond); \
|
19 |
+
exit(1); \
|
20 |
+
} \
|
21 |
+
} while(0)
|
22 |
+
|
23 |
+
|
24 |
+
#define FLASH_DEVICE_ASSERT(cond) \
|
25 |
+
do { \
|
26 |
+
if (not (cond)) { \
|
27 |
+
printf("Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond); \
|
28 |
+
asm("trap;"); \
|
29 |
+
} \
|
30 |
+
} while(0)
|
31 |
+
|
32 |
+
|
33 |
+
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
|
34 |
+
[&] { \
|
35 |
+
if (COND) { \
|
36 |
+
constexpr static bool CONST_NAME = true; \
|
37 |
+
return __VA_ARGS__(); \
|
38 |
+
} else { \
|
39 |
+
constexpr static bool CONST_NAME = false; \
|
40 |
+
return __VA_ARGS__(); \
|
41 |
+
} \
|
42 |
+
}()
|
43 |
+
|
44 |
+
|
45 |
+
#define MLA_NUM_SPLITS_SWITCH(NUM_SPLITS, NAME, ...) \
|
46 |
+
[&] { \
|
47 |
+
if (NUM_SPLITS <= 32) { \
|
48 |
+
constexpr static int NAME = 32; \
|
49 |
+
return __VA_ARGS__(); \
|
50 |
+
} else if (NUM_SPLITS <= 64) { \
|
51 |
+
constexpr static int NAME = 64; \
|
52 |
+
return __VA_ARGS__(); \
|
53 |
+
} else if (NUM_SPLITS <= 96) { \
|
54 |
+
constexpr static int NAME = 96; \
|
55 |
+
return __VA_ARGS__(); \
|
56 |
+
} else if (NUM_SPLITS <= 128) { \
|
57 |
+
constexpr static int NAME = 128; \
|
58 |
+
return __VA_ARGS__(); \
|
59 |
+
} else if (NUM_SPLITS <= 160) { \
|
60 |
+
constexpr static int NAME = 160; \
|
61 |
+
return __VA_ARGS__(); \
|
62 |
+
} else { \
|
63 |
+
FLASH_ASSERT(false); \
|
64 |
+
} \
|
65 |
+
}()
|
flash_mla/utils.h
ADDED
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/hopper/utils.h
|
2 |
+
|
3 |
+
#pragma once
|
4 |
+
|
5 |
+
#include <assert.h>
|
6 |
+
#include <stdint.h>
|
7 |
+
#include <stdlib.h>
|
8 |
+
|
9 |
+
#include <cuda_bf16.h>
|
10 |
+
|
11 |
+
#include <cute/tensor.hpp>
|
12 |
+
|
13 |
+
#include <cutlass/array.h>
|
14 |
+
#include <cutlass/cutlass.h>
|
15 |
+
#include <cutlass/numeric_conversion.h>
|
16 |
+
#include <cutlass/numeric_types.h>
|
17 |
+
|
18 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
19 |
+
|
20 |
+
namespace flash {
|
21 |
+
|
22 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
23 |
+
|
24 |
+
template<typename T>
|
25 |
+
struct MaxOp {
|
26 |
+
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; }
|
27 |
+
};
|
28 |
+
|
29 |
+
template <>
|
30 |
+
struct MaxOp<float> {
|
31 |
+
// This is slightly faster
|
32 |
+
__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); }
|
33 |
+
};
|
34 |
+
|
35 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
36 |
+
|
37 |
+
template<typename T>
|
38 |
+
struct SumOp {
|
39 |
+
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; }
|
40 |
+
};
|
41 |
+
|
42 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
43 |
+
|
44 |
+
template<int THREADS>
|
45 |
+
struct Allreduce {
|
46 |
+
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
|
47 |
+
template<typename T, typename Operator>
|
48 |
+
static __device__ __forceinline__ T run(T x, Operator &op) {
|
49 |
+
constexpr int OFFSET = THREADS / 2;
|
50 |
+
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
|
51 |
+
return Allreduce<OFFSET>::run(x, op);
|
52 |
+
}
|
53 |
+
};
|
54 |
+
|
55 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
56 |
+
|
57 |
+
template<>
|
58 |
+
struct Allreduce<2> {
|
59 |
+
template<typename T, typename Operator>
|
60 |
+
static __device__ __forceinline__ T run(T x, Operator &op) {
|
61 |
+
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
|
62 |
+
return x;
|
63 |
+
}
|
64 |
+
};
|
65 |
+
|
66 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
67 |
+
|
68 |
+
template <bool zero_init=false, int wg_wait=0, bool arrive=true, bool commit=true, typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
|
69 |
+
__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) {
|
70 |
+
constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
|
71 |
+
// Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const
|
72 |
+
if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
|
73 |
+
warpgroup_fence_operand(tCrC);
|
74 |
+
if constexpr (arrive) {
|
75 |
+
warpgroup_arrive();
|
76 |
+
}
|
77 |
+
if constexpr (zero_init) {
|
78 |
+
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
79 |
+
// Unroll the K mode manually to set scale D to 1
|
80 |
+
CUTLASS_PRAGMA_UNROLL
|
81 |
+
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
82 |
+
cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
|
83 |
+
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
84 |
+
}
|
85 |
+
} else {
|
86 |
+
// cute::gemm(tiled_mma, tCrA, tCrB, tCrC);
|
87 |
+
// Unroll the K mode manually to set scale D to 1
|
88 |
+
CUTLASS_PRAGMA_UNROLL
|
89 |
+
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
90 |
+
cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
|
91 |
+
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
92 |
+
}
|
93 |
+
}
|
94 |
+
if constexpr (commit) {
|
95 |
+
warpgroup_commit_batch();
|
96 |
+
}
|
97 |
+
if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
|
98 |
+
warpgroup_fence_operand(tCrC);
|
99 |
+
if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
|
100 |
+
}
|
101 |
+
|
102 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
103 |
+
|
104 |
+
// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
|
105 |
+
// For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
|
106 |
+
template<bool Transposed=false, typename Layout0>
|
107 |
+
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout0 acc_layout) {
|
108 |
+
if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90
|
109 |
+
static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
|
110 |
+
static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
|
111 |
+
static_assert(decltype(rank(acc_layout))::value == 3);
|
112 |
+
auto l = acc_layout;
|
113 |
+
if constexpr (!Transposed) {
|
114 |
+
return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)));
|
115 |
+
} else {
|
116 |
+
return make_layout(make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l)));
|
117 |
+
}
|
118 |
+
|
119 |
+
} else { // SM80
|
120 |
+
static_assert(decltype(size<0>(acc_layout))::value == 4);
|
121 |
+
static_assert(decltype(rank(acc_layout))::value == 3);
|
122 |
+
auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N)
|
123 |
+
if constexpr (!Transposed) {
|
124 |
+
return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
|
125 |
+
} else {
|
126 |
+
return make_layout(make_layout(get<0, 0>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l)));
|
127 |
+
}
|
128 |
+
}
|
129 |
+
};
|
130 |
+
|
131 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
132 |
+
|
133 |
+
// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
|
134 |
+
// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8.
|
135 |
+
// For SM90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N))
|
136 |
+
// For SM90, FP8, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((4, 2, 2), MMA_M, (N / 32, MMA_N))
|
137 |
+
template<typename MMA_Traits, typename Layout0>
|
138 |
+
__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout0 acc_layout) {
|
139 |
+
using X = Underscore;
|
140 |
+
if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90
|
141 |
+
static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
|
142 |
+
static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
|
143 |
+
static_assert(decltype(rank(acc_layout))::value == 3);
|
144 |
+
static_assert(decltype(rank(get<0>(acc_layout)))::value == 3);
|
145 |
+
if constexpr (sizeof(typename MMA_Traits::ValTypeA) == 2) {
|
146 |
+
auto l = logical_divide(get<0, 2>(acc_layout), Tile<_2>{}); // ((2, N / 16))
|
147 |
+
return make_layout(make_layout(get<0, 0>(acc_layout), get<0, 1>(acc_layout), get<0, 0>(l)), get<1>(acc_layout), coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout))));
|
148 |
+
} else {
|
149 |
+
static_assert(sizeof(typename MMA_Traits::ValTypeA) == 1);
|
150 |
+
static_assert(decltype(stride<0, 0>(acc_layout))::value == 1);
|
151 |
+
static_assert(decltype(stride<0, 1>(acc_layout))::value == 2);
|
152 |
+
auto l = logical_divide(get<0, 2>(acc_layout), Tile<Layout<Shape<_2, _2>>>{}); // (((2, 2), N / 32))
|
153 |
+
// This combines the first two modes (<0, 0> and <0, 1>) into one mode.
|
154 |
+
// Will require register shuffling later to be correct.
|
155 |
+
return make_layout(make_layout(Layout<_4>{}, get<0, 0, 0>(l), get<0, 0, 1>(l)),
|
156 |
+
get<1>(acc_layout),
|
157 |
+
coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout)))); // ((4, 2, 2), MMA_M, N / 32 * MMA_N)
|
158 |
+
// This combination is right but doesn't work with register shuffling.
|
159 |
+
// return make_layout(make_layout(coalesce(make_layout(get<0, 0>(acc_layout), get<0, 0, 0>(l))), get<0, 1>(acc_layout), get<0, 0, 1>(l)),
|
160 |
+
// get<1>(acc_layout),
|
161 |
+
// coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout))));
|
162 |
+
}
|
163 |
+
} else { // SM80
|
164 |
+
static_assert(decltype(size<0>(acc_layout))::value == 4);
|
165 |
+
static_assert(decltype(rank(acc_layout))::value == 3);
|
166 |
+
constexpr int mma_shape_K = get<2>(typename MMA_Traits::Shape_MNK{});
|
167 |
+
static_assert(mma_shape_K == 8 || mma_shape_K == 16);
|
168 |
+
if constexpr (mma_shape_K == 8) {
|
169 |
+
return acc_layout;
|
170 |
+
} else {
|
171 |
+
auto l = logical_divide(acc_layout, Shape<X, X, _2>{}); // (4, MMA_M, (2, MMA_N / 2)))
|
172 |
+
return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
|
173 |
+
}
|
174 |
+
}
|
175 |
+
};
|
176 |
+
|
177 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
178 |
+
|
179 |
+
template <typename To_type, typename Engine, typename Layout>
|
180 |
+
__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
|
181 |
+
using From_type = typename Engine::value_type;
|
182 |
+
constexpr int numel = decltype(size(tensor))::value;
|
183 |
+
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
|
184 |
+
// HACK: this requires tensor to be "contiguous"
|
185 |
+
auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
|
186 |
+
return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
|
187 |
+
}
|
188 |
+
|
189 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
190 |
+
|
191 |
+
// Blocks until all but N previous cp.async.commit_group operations have committed.
|
192 |
+
// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all
|
193 |
+
// (which is equivalent to commit_group then wait_group 0).
|
194 |
+
// Instead we just call cp.async.wait_group 0, which is slightly faster.
|
195 |
+
// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113
|
196 |
+
template <int N>
|
197 |
+
CUTE_HOST_DEVICE
|
198 |
+
void cp_async_wait() {
|
199 |
+
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
|
200 |
+
asm volatile("cp.async.wait_group %0;\n" :: "n"(N));
|
201 |
+
#endif
|
202 |
+
}
|
203 |
+
|
204 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
205 |
+
|
206 |
+
template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
|
207 |
+
typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
|
208 |
+
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
|
209 |
+
__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
|
210 |
+
Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
|
211 |
+
Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
|
212 |
+
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
|
213 |
+
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
|
214 |
+
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
|
215 |
+
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
|
216 |
+
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
|
217 |
+
// There's no case where !Clear_OOB_K && Clear_OOB_MN
|
218 |
+
static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
|
219 |
+
#pragma unroll
|
220 |
+
for (int m = 0; m < size<1>(S); ++m) {
|
221 |
+
if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
|
222 |
+
#pragma unroll
|
223 |
+
for (int k = 0; k < size<2>(S); ++k) {
|
224 |
+
if (Is_even_K || predicate_K(k)) {
|
225 |
+
cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
|
226 |
+
} else if (Clear_OOB_K) {
|
227 |
+
cute::clear(D(_, m, k));
|
228 |
+
}
|
229 |
+
}
|
230 |
+
} else if (Clear_OOB_MN) {
|
231 |
+
cute::clear(D(_, m, _));
|
232 |
+
}
|
233 |
+
}
|
234 |
+
}
|
235 |
+
|
236 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
237 |
+
|
238 |
+
} // namespace flash
|
tests/__init__.py
ADDED
File without changes
|
tests/test_flash_mla.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import random
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
import flash_mla
|
6 |
+
|
7 |
+
# TODO: revise to use the same test as the original code
|
8 |
+
|
9 |
+
|
10 |
+
def test_flash_mla():
|
11 |
+
# b = 128
|
12 |
+
# s_q = 4096
|
13 |
+
# mean_sk = 8192
|
14 |
+
# h_q = 16
|
15 |
+
# h_kv = 1
|
16 |
+
# d = 576
|
17 |
+
# dv = 512
|
18 |
+
|
19 |
+
b = 16
|
20 |
+
s_q = 16
|
21 |
+
mean_sk = 16
|
22 |
+
h_q = 16
|
23 |
+
h_kv = 1
|
24 |
+
d = 576
|
25 |
+
dv = 512
|
26 |
+
|
27 |
+
|
28 |
+
causal = True
|
29 |
+
varlen = False
|
30 |
+
|
31 |
+
print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}")
|
32 |
+
|
33 |
+
cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32)
|
34 |
+
if varlen:
|
35 |
+
for i in range(b):
|
36 |
+
cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), s_q)
|
37 |
+
total_seqlens = cache_seqlens.sum().item()
|
38 |
+
mean_seqlens = cache_seqlens.float().mean().int().item()
|
39 |
+
max_seqlen = cache_seqlens.max().item()
|
40 |
+
# TODO: avoid triton from original code
|
41 |
+
# max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
|
42 |
+
print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}")
|
43 |
+
max_seqlen_pad = max_seqlen + 255 & ~255 # round up to multiple of 256
|
44 |
+
q = torch.randn(b, s_q, h_q, d)
|
45 |
+
block_size = 64
|
46 |
+
block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(
|
47 |
+
b, max_seqlen_pad // block_size
|
48 |
+
)
|
49 |
+
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
|
50 |
+
print(blocked_k.shape)
|
51 |
+
for i in range(b):
|
52 |
+
blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item() :] = float(
|
53 |
+
"nan"
|
54 |
+
)
|
55 |
+
blocked_v = blocked_k[..., :dv]
|
56 |
+
print(blocked_k.shape, blocked_v.shape)
|
57 |
+
|
58 |
+
cache_seqlens = cache_seqlens.to("cuda")
|
59 |
+
|
60 |
+
tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata(
|
61 |
+
seqlens_k=cache_seqlens,
|
62 |
+
#
|
63 |
+
s_q=s_q * h_q // h_kv,
|
64 |
+
h_kv=h_kv,
|
65 |
+
)
|
66 |
+
print(tile_scheduler_metadata, num_splits)
|
67 |
+
|
68 |
+
# TODO: update to expect the correct output
|
69 |
+
assert False
|
torch-ext/flash_mla/__init__.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
from ._ops import ops
|
4 |
+
|
5 |
+
|
6 |
+
def get_mla_metadata(seqlens_k: torch.Tensor, s_q: int, h_kv: int):
|
7 |
+
return ops.get_mla_metadata(seqlens_k, s_q, h_kv)
|
8 |
+
|
9 |
+
|
10 |
+
def mha_fwd_kvcache_mla(
|
11 |
+
q: torch.Tensor,
|
12 |
+
kcache: torch.Tensor,
|
13 |
+
vcache_: torch.Tensor,
|
14 |
+
head_size_v: int,
|
15 |
+
seqlens_k: torch.Tensor,
|
16 |
+
block_table: torch.Tensor,
|
17 |
+
softmax_scale: float,
|
18 |
+
is_causal_: bool,
|
19 |
+
tile_scheduler_metadata: torch.Tensor,
|
20 |
+
num_splits: torch.Tensor,
|
21 |
+
) -> torch.Tensor:
|
22 |
+
# TODO: remove when resolved
|
23 |
+
unknown_param = 0
|
24 |
+
return ops.mha_fwd_kvcache_mla(
|
25 |
+
q,
|
26 |
+
kcache,
|
27 |
+
vcache_,
|
28 |
+
head_size_v,
|
29 |
+
seqlens_k,
|
30 |
+
block_table,
|
31 |
+
softmax_scale,
|
32 |
+
is_causal_,
|
33 |
+
tile_scheduler_metadata,
|
34 |
+
num_splits,
|
35 |
+
unknown_param,
|
36 |
+
)
|
torch-ext/torch_binding.cpp
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <torch/library.h>
|
2 |
+
|
3 |
+
#include "registration.h"
|
4 |
+
#include "torch_binding.h"
|
5 |
+
|
6 |
+
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
7 |
+
ops.def("get_mla_metadata(Tensor! seqlens_k, int num_heads_per_head_k, int num_heads_k) -> Tensor[]");
|
8 |
+
ops.impl("get_mla_metadata", torch::kCUDA, &get_mla_metadata);
|
9 |
+
|
10 |
+
// TOOD: remove last unknown_param when resolved
|
11 |
+
ops.def("mha_fwd_kvcache_mla(Tensor! q, Tensor! kcache, Tensor! vcache_, int head_size_v, Tensor! seqlens_k, Tensor! block_table, float softmax_scale, bool is_causal_, Tensor! tile_scheduler_metadata, Tensor! num_splits, int unknown_param) -> Tensor[]");
|
12 |
+
ops.impl("mha_fwd_kvcache_mla", torch::kCUDA, &mha_fwd_kvcache_mla);
|
13 |
+
}
|
14 |
+
|
15 |
+
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
torch-ext/torch_binding.h
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#include <torch/torch.h>
|
4 |
+
|
5 |
+
std::vector<torch::Tensor>
|
6 |
+
get_mla_metadata(
|
7 |
+
torch::Tensor &seqlens_k,
|
8 |
+
const int64_t num_heads_per_head_k,
|
9 |
+
const int64_t num_heads_k
|
10 |
+
);
|
11 |
+
|
12 |
+
std::vector<torch::Tensor>
|
13 |
+
mha_fwd_kvcache_mla(
|
14 |
+
torch::Tensor &q,
|
15 |
+
const torch::Tensor &kcache,
|
16 |
+
|
17 |
+
// TODO: fix for optional
|
18 |
+
// std::optional<torch::Tensor> &vcache_,
|
19 |
+
|
20 |
+
const torch::Tensor &vcache_,
|
21 |
+
const int64_t head_size_v,
|
22 |
+
const torch::Tensor &seqlens_k,
|
23 |
+
const torch::Tensor &block_table,
|
24 |
+
|
25 |
+
// TODO:should be float
|
26 |
+
const double softmax_scale,
|
27 |
+
|
28 |
+
// TODO: fix for mutable bool
|
29 |
+
const bool is_causal_,
|
30 |
+
|
31 |
+
const torch::Tensor &tile_scheduler_metadata,
|
32 |
+
const torch::Tensor &num_splits,
|
33 |
+
|
34 |
+
// TODO: remove when resolved
|
35 |
+
const int64_t unknown_param = 0
|
36 |
+
);
|