drbh committed
Commit d76b04d · 1 Parent(s): 5cb0596

fix: remove unused trailing param

flash_mla/flash_mla_api.cu CHANGED
@@ -70,10 +70,10 @@ mha_fwd_kvcache_mla(
     const double softmax_scale,
     const bool is_causal_,
     const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize
-    const at::Tensor &num_splits, // batch_size + 1
+    const at::Tensor &num_splits // batch_size + 1
 
     // TODO: remove this once determined why build is adding this parameter
-    const int64_t unknown_param
+    // const int64_t unknown_param
 ) {
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
torch-ext/flash_mla/__init__.py CHANGED
@@ -19,8 +19,6 @@ def mha_fwd_kvcache_mla(
     tile_scheduler_metadata: torch.Tensor,
     num_splits: torch.Tensor,
 ) -> torch.Tensor:
-    # TODO: remove when resolved
-    unknown_param = 0
     return ops.mha_fwd_kvcache_mla(
         q,
         kcache,
@@ -31,6 +29,5 @@
         softmax_scale,
         is_causal_,
         tile_scheduler_metadata,
-        num_splits,
-        unknown_param,
+        num_splits
     )
torch-ext/torch_binding.cpp CHANGED
@@ -8,7 +8,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.impl("get_mla_metadata", torch::kCUDA, &get_mla_metadata);
 
   // TOOD: remove last unknown_param when resolved
-  ops.def("mha_fwd_kvcache_mla(Tensor! q, Tensor! kcache, Tensor! vcache_, int head_size_v, Tensor! seqlens_k, Tensor! block_table, float softmax_scale, bool is_causal_, Tensor! tile_scheduler_metadata, Tensor! num_splits, int unknown_param) -> Tensor[]");
+  ops.def("mha_fwd_kvcache_mla(Tensor! q, Tensor! kcache, Tensor! vcache_, int head_size_v, Tensor! seqlens_k, Tensor! block_table, float softmax_scale, bool is_causal_, Tensor! tile_scheduler_metadata, Tensor! num_splits) -> Tensor[]");
   ops.impl("mha_fwd_kvcache_mla", torch::kCUDA, &mha_fwd_kvcache_mla);
 }
 
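The schema string handed to ops.def must match the bound C++ function argument-for-argument, which is the agreement this commit restores by dropping the trailing int unknown_param. A minimal, self-contained sketch of the same registration pattern, using a toy op (demo_ops, scale_by, and the dispatch key are invented for illustration, not this repo's code):

// Toy registration sketch; demo_ops and scale_by are invented for illustration.
#include <torch/library.h>
#include <ATen/ATen.h>

// C++ side: one Tensor and one double (schema type "float" maps to double).
at::Tensor scale_by(const at::Tensor &x, double s) {
  return x * s;
}

TORCH_LIBRARY(demo_ops, m) {
  // The schema must mirror scale_by argument-for-argument; declaring an extra
  // trailing "int unknown_param" here with no matching C++ parameter (or vice
  // versa) trips a signature mismatch when the extension loads, the bug class
  // this commit cleans up.
  m.def("scale_by(Tensor x, float s) -> Tensor");
  m.impl("scale_by", torch::kCPU, &scale_by);
}

Registered this way, the op is callable from Python as torch.ops.demo_ops.scale_by(x, 2.0), with argument checking driven by the schema.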
torch-ext/torch_binding.h CHANGED
@@ -29,8 +29,5 @@ mha_fwd_kvcache_mla(
     const bool is_causal_,
 
     const torch::Tensor &tile_scheduler_metadata,
-    const torch::Tensor &num_splits,
-
-    // TODO: remove when resolved
-    const int64_t unknown_param = 0
+    const torch::Tensor &num_splits
 );
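A side note on the now-deleted default (const int64_t unknown_param = 0): in C++ a default argument is call-site sugar and is not part of the function's type, so a binding that takes the function's address, as ops.impl does above, still sees the trailing parameter. A standalone sketch of that language rule (toy names, nothing from this repo):

// Default arguments do not change a function's type; they only let callers
// omit trailing arguments.
#include <cstdint>
#include <iostream>
#include <type_traits>

int64_t add(int64_t a, int64_t b, int64_t unknown_param = 0) {
  return a + b + unknown_param;
}

int main() {
  std::cout << add(1, 2) << "\n";  // callers may skip the defaulted argument

  // But a pointer to add still has the full three-argument type, so anything
  // reflecting over &add (the way ops.impl binds &mha_fwd_kvcache_mla) must
  // account for all three parameters.
  static_assert(
      std::is_same_v<decltype(&add), int64_t (*)(int64_t, int64_t, int64_t)>);
  return 0;
}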