/// Compute scheduling metadata for MLA (multi-head latent attention) kernels.
///
/// NOTE(review): removed stray " | |" table-extraction artifacts that made this
/// declaration invalid C++; the signature itself is unchanged.
///
/// @param seqlens_k            per-batch key/value sequence lengths
///                             (presumably an int32 CUDA tensor — TODO confirm
///                             against the definition, which is not in view)
/// @param num_heads_per_head_k ratio of query heads to KV heads (GQA/MLA grouping
///                             — assumption; verify against callers)
/// @param num_heads_k          number of key/value heads
/// @return a vector of tensors — likely tile-scheduler metadata plus a
///         num_splits tensor consumed by mha_fwd_kvcache_mla below; confirm
///         ordering/contents against the implementation.
std::vector<torch::Tensor>
get_mla_metadata(
    torch::Tensor &seqlens_k,
    const int64_t num_heads_per_head_k,
    const int64_t num_heads_k
);
/// Forward pass of multi-head attention with a paged KV cache (MLA variant).
///
/// NOTE(review): removed stray " | |" table-extraction artifacts that made this
/// declaration invalid C++; the signature itself is unchanged.
///
/// @param q                       query tensor
/// @param kcache                  key cache (paged — assumption based on the
///                                block_table parameter; TODO confirm layout)
/// @param vcache_                 optional value cache; when absent the value
///                                data presumably lives inside kcache — verify
///                                against the definition
/// @param head_size_v             head dimension of the value/output
/// @param seqlens_k               per-batch key sequence lengths
/// @param block_table             page table mapping batch entries to cache
///                                blocks (paged-attention style — assumption)
/// @param softmax_scale           scaling applied before softmax
/// @param is_causal               whether to apply a causal attention mask
/// @param tile_scheduler_metadata scheduling metadata, presumably produced by
///                                get_mla_metadata — confirm pairing
/// @param num_splits              per-batch split counts for the split-KV
///                                reduction (assumption; verify)
/// @return a vector of tensors — likely the attention output and softmax LSE;
///         confirm against the implementation, which is not in view.
std::vector<torch::Tensor>
mha_fwd_kvcache_mla(
    torch::Tensor &q,
    const torch::Tensor &kcache,
    const c10::optional<torch::Tensor> &vcache_,
    const int64_t head_size_v,
    const torch::Tensor &seqlens_k,
    const torch::Tensor &block_table,
    const double softmax_scale,
    bool is_causal,
    const torch::Tensor &tile_scheduler_metadata,
    const torch::Tensor &num_splits
);