drbh
committed on
Commit
·
d76b04d
1
Parent(s):
5cb0596
fix: remove unused trailing param
Browse files
flash_mla/flash_mla_api.cu
CHANGED
@@ -70,10 +70,10 @@ mha_fwd_kvcache_mla(
|
|
70 |
const double softmax_scale,
|
71 |
const bool is_causal_,
|
72 |
const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize
|
73 |
-
const at::Tensor &num_splits
|
74 |
|
75 |
// TODO: remove this once determined why build is adding this parameter
|
76 |
-
const int64_t unknown_param
|
77 |
) {
|
78 |
auto dprops = at::cuda::getCurrentDeviceProperties();
|
79 |
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
|
|
|
70 |
const double softmax_scale,
|
71 |
const bool is_causal_,
|
72 |
const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize
|
73 |
+
const at::Tensor &num_splits // batch_size + 1
|
74 |
|
75 |
// TODO: remove this once determined why build is adding this parameter
|
76 |
+
// const int64_t unknown_param
|
77 |
) {
|
78 |
auto dprops = at::cuda::getCurrentDeviceProperties();
|
79 |
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
|
torch-ext/flash_mla/__init__.py
CHANGED
@@ -19,8 +19,6 @@ def mha_fwd_kvcache_mla(
|
|
19 |
tile_scheduler_metadata: torch.Tensor,
|
20 |
num_splits: torch.Tensor,
|
21 |
) -> torch.Tensor:
|
22 |
-
# TODO: remove when resolved
|
23 |
-
unknown_param = 0
|
24 |
return ops.mha_fwd_kvcache_mla(
|
25 |
q,
|
26 |
kcache,
|
@@ -31,6 +29,5 @@ def mha_fwd_kvcache_mla(
|
|
31 |
softmax_scale,
|
32 |
is_causal_,
|
33 |
tile_scheduler_metadata,
|
34 |
-
num_splits
|
35 |
-
unknown_param,
|
36 |
)
|
|
|
19 |
tile_scheduler_metadata: torch.Tensor,
|
20 |
num_splits: torch.Tensor,
|
21 |
) -> torch.Tensor:
|
|
|
|
|
22 |
return ops.mha_fwd_kvcache_mla(
|
23 |
q,
|
24 |
kcache,
|
|
|
29 |
softmax_scale,
|
30 |
is_causal_,
|
31 |
tile_scheduler_metadata,
|
32 |
+
num_splits
|
|
|
33 |
)
|
torch-ext/torch_binding.cpp
CHANGED
@@ -8,7 +8,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|
8 |
ops.impl("get_mla_metadata", torch::kCUDA, &get_mla_metadata);
|
9 |
|
10 |
// TOOD: remove last unknown_param when resolved
|
11 |
-
ops.def("mha_fwd_kvcache_mla(Tensor! q, Tensor! kcache, Tensor! vcache_, int head_size_v, Tensor! seqlens_k, Tensor! block_table, float softmax_scale, bool is_causal_, Tensor! tile_scheduler_metadata, Tensor! num_splits
|
12 |
ops.impl("mha_fwd_kvcache_mla", torch::kCUDA, &mha_fwd_kvcache_mla);
|
13 |
}
|
14 |
|
|
|
8 |
ops.impl("get_mla_metadata", torch::kCUDA, &get_mla_metadata);
|
9 |
|
10 |
// TOOD: remove last unknown_param when resolved
|
11 |
+
ops.def("mha_fwd_kvcache_mla(Tensor! q, Tensor! kcache, Tensor! vcache_, int head_size_v, Tensor! seqlens_k, Tensor! block_table, float softmax_scale, bool is_causal_, Tensor! tile_scheduler_metadata, Tensor! num_splits) -> Tensor[]");
|
12 |
ops.impl("mha_fwd_kvcache_mla", torch::kCUDA, &mha_fwd_kvcache_mla);
|
13 |
}
|
14 |
|
torch-ext/torch_binding.h
CHANGED
@@ -29,8 +29,5 @@ mha_fwd_kvcache_mla(
|
|
29 |
const bool is_causal_,
|
30 |
|
31 |
const torch::Tensor &tile_scheduler_metadata,
|
32 |
-
const torch::Tensor &num_splits
|
33 |
-
|
34 |
-
// TODO: remove when resolved
|
35 |
-
const int64_t unknown_param = 0
|
36 |
);
|
|
|
29 |
const bool is_causal_,
|
30 |
|
31 |
const torch::Tensor &tile_scheduler_metadata,
|
32 |
+
const torch::Tensor &num_splits
|
|
|
|
|
|
|
33 |
);
|