#pragma once #include std::vector get_mla_metadata( torch::Tensor &seqlens_k, const int64_t num_heads_per_head_k, const int64_t num_heads_k ); std::vector mha_fwd_kvcache_mla( torch::Tensor &q, const torch::Tensor &kcache, const c10::optional &vcache_, const int64_t head_size_v, const torch::Tensor &seqlens_k, const torch::Tensor &block_table, const double softmax_scale, bool is_causal, const torch::Tensor &tile_scheduler_metadata, const torch::Tensor &num_splits );