// flash-mla / torch-ext / torch_binding.h
// Author: drbh — "fix: adjust missed double type" (commit 90cbc4b)
#pragma once
#include <torch/torch.h>
/// @brief Computes scheduling metadata for the MLA (Multi-head Latent
///        Attention) decode kernel ahead of the attention call.
///
/// @param seqlens_k            Per-batch key sequence lengths. NOTE(review):
///                             taken by non-const reference — presumably only
///                             read; confirm against the kernel, and consider
///                             const-qualifying if so.
/// @param num_heads_per_head_k Number of query heads handled per key/value
///                             head — assumed to be the q-head/kv-head ratio;
///                             TODO confirm with the implementation.
/// @param num_heads_k          Number of key/value heads in the cache.
/// @return Vector of metadata tensors — presumably the
///         {tile_scheduler_metadata, num_splits} pair that
///         mha_fwd_kvcache_mla consumes; verify against the .cpp definition.
std::vector<torch::Tensor>
get_mla_metadata(
torch::Tensor &seqlens_k,
const int64_t num_heads_per_head_k,
const int64_t num_heads_k
);
/// @brief Forward pass of multi-head attention with a paged KV cache for
///        MLA (Multi-head Latent Attention) decoding.
///
/// @param q                       Query tensor. NOTE(review): non-const
///                                reference — presumably read-only; confirm.
/// @param kcache                  Paged key cache tensor.
/// @param vcache_                 Optional separate value cache; when absent,
///                                values are presumably read from kcache
///                                (MLA shares the latent cache) — TODO confirm.
/// @param head_size_v             Head dimension of the value/output tensors.
/// @param seqlens_k               Per-batch key sequence lengths.
/// @param block_table             Page table mapping each batch's logical
///                                cache blocks to physical blocks in kcache.
/// @param softmax_scale           Scale applied to QK^T before softmax
///                                (double, as required by torch's op
///                                registration for floating-point scalars).
/// @param is_causal               Whether to apply a causal attention mask.
/// @param tile_scheduler_metadata Scheduling metadata — presumably produced
///                                by get_mla_metadata; verify at call sites.
/// @param num_splits              Split-KV partition counts — presumably also
///                                from get_mla_metadata; verify at call sites.
/// @return Vector of output tensors — likely {output, softmax_lse}; confirm
///         against the .cpp definition.
std::vector<torch::Tensor>
mha_fwd_kvcache_mla(
torch::Tensor &q,
const torch::Tensor &kcache,
const c10::optional<torch::Tensor> &vcache_,
const int64_t head_size_v,
const torch::Tensor &seqlens_k,
const torch::Tensor &block_table,
const double softmax_scale,
bool is_causal,
const torch::Tensor &tile_scheduler_metadata,
const torch::Tensor &num_splits
);