HuggingFace version does NOT use efficient MLA caching

#95
by Avelina - opened

As the title suggests, the version of this model provided in modeling_deepseek.py does NOT make use of the efficient MLA caching mechanism pioneered by DeepSeek V2 and V3.

The relevant code is here: https://huggingface.co/deepseek-ai/DeepSeek-R1/blob/main/modeling_deepseek.py#L810

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

Notice that the full up-projected keys and values are cached, rather than caching compressed_kv (or self.kv_a_layernorm(compressed_kv)) and k_pe as is done in the native implementation in the DeepSeek repo.

I imagine this was done because of the way HF handles the Cache class, and the devs porting this to HF didn't want to deal with incompatibilities between typical MHA and MLA.

However, by storing k_pe as the keys and compressed_kv as the values, we would be able to use efficient MLA caching AND support cache-managed rotary embeddings. Additionally, by fiddling with the config class we can 'trick' the special Cache variants -- like Static or Sink caches -- into pre-allocating the correct tensor shapes, which would otherwise be incompatible with MLA.
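For concreteness, here is a rough sketch of what the cache update plus subsequent decompression could look like under this scheme. The tensor names follow the HF modeling_deepseek.py attention module, but the helper itself (mla_cache_update), its signature, and the exact shapes/layouts are my own assumptions, not code from either repo:

    import torch
    from torch import nn

    def mla_cache_update(
        past_key_value,              # a transformers Cache instance
        layer_idx: int,
        k_pe: torch.Tensor,          # (bsz, 1, q_len, qk_rope_head_dim), RoPE already applied
        compressed_kv: torch.Tensor, # (bsz, 1, q_len, kv_lora_rank), after kv_a_layernorm
        kv_b_proj: nn.Linear,        # up-projection to per-head no-RoPE keys and values
        num_heads: int,
        qk_nope_head_dim: int,
        v_head_dim: int,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ):
        # Store the small latents in the cache: k_pe goes in the "key" slot (so
        # cache variants that manage RoPE still operate on something rotary) and
        # the compressed latent goes in the "value" slot.
        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}
            k_pe, compressed_kv = past_key_value.update(
                k_pe, compressed_kv, layer_idx, cache_kwargs
            )

        # Decompress the cached latents back into per-head keys and values.
        bsz, _, kv_len, _ = compressed_kv.shape
        kv = (
            kv_b_proj(compressed_kv)
            .view(bsz, kv_len, num_heads, qk_nope_head_dim + v_head_dim)
            .transpose(1, 2)
        )
        k_nope, value_states = torch.split(kv, [qk_nope_head_dim, v_head_dim], dim=-1)

        # The RoPE part of the key is shared across heads, so broadcast it.
        key_states = torch.cat(
            [k_nope, k_pe.expand(-1, num_heads, -1, -1)], dim=-1
        )
        return key_states, value_states

Note that the up-projection now runs over the full cached sequence on every forward pass, so this trades extra compute for the smaller KV cache.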

@Avelina I instantiated a mock model with the mod you suggested but got no inference speed-up via model.generate in the transformers library. Do you think it is necessary to use a specific inference setup?

@hertric I did some more research into the matter and there are a few things at play here. There are like... 3 or 4 levels of optimization that can be done.

  1. How HF does it with caching of decompressed keys and values. This has increased KV cache memory overhead.
  2. How I suggested things here. This uses the least memory for storage of the KV cache, but requires decompression which uses additional compute (essentially what the sketch above does). I find this has the same inference latency as 1 because it trades better memory efficiency for slightly worse compute efficiency.
  3. Using actual latent attention: building on 2, but instead of decompressing the keys and values before attention, performing attention in the latent space and then decompressing the output of the attention operation (see the sketch after this list). This is actually slower than 1 and 2 because the attention kernel isn't optimized for this, but it reduces memory for both the KV cache and the transient buffers, as we only decompress 1 token at a time. (Look at the DeepSeek repo for how they do this naively.)
  4. Doing 3 but with either the FlashMLA kernel on a Hopper or Blackwell GPU, or with torch 2.6's FlexAttention op, which allows you to fuse the RoPE computation into the score calculation and share memory access across all key and value heads. This is both the fastest and the lowest-memory solution.
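To make level 3 concrete, here is a minimal sketch of the "absorbed" / latent-space attention step, loosely following the einsum formulation in the DeepSeek reference code. The function name, argument layout, and the omission of masking are my own simplifications rather than the actual implementation:

    import torch

    def mla_absorbed_attention(
        q_nope: torch.Tensor,      # (bsz, q_len, n_heads, qk_nope_head_dim)
        q_pe: torch.Tensor,        # (bsz, q_len, n_heads, qk_rope_head_dim), RoPE applied
        kv_cache: torch.Tensor,    # (bsz, kv_len, kv_lora_rank) cached latents
        pe_cache: torch.Tensor,    # (bsz, kv_len, qk_rope_head_dim) cached RoPE keys
        kv_b_weight: torch.Tensor, # (n_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)
        n_heads: int,
        qk_nope_head_dim: int,
        v_head_dim: int,
        softmax_scale: float,
    ):
        # Split the up-projection weight into its key and value halves, per head.
        w = kv_b_weight.view(n_heads, qk_nope_head_dim + v_head_dim, -1)
        w_uk, w_uv = w[:, :qk_nope_head_dim], w[:, qk_nope_head_dim:]

        # Absorb W_UK into the query so scores are computed directly against the
        # latent cache instead of against decompressed keys.
        q_latent = torch.einsum("bshd,hdc->bshc", q_nope, w_uk)

        # Content term (latent space) + positional term (RoPE space).
        # Causal masking is omitted here for brevity.
        scores = (
            torch.einsum("bshc,btc->bsht", q_latent, kv_cache)
            + torch.einsum("bshr,btr->bsht", q_pe, pe_cache)
        ) * softmax_scale
        attn = scores.softmax(dim=-1)

        # The attention output is still in the latent space; only this output gets
        # decompressed (one token per step during decoding), never the whole cache.
        out_latent = torch.einsum("bsht,btc->bshc", attn, kv_cache)
        return torch.einsum("bshc,hdc->bshd", out_latent, w_uv)  # (bsz, q_len, n_heads, v_head_dim)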

If you want an out-of-the-box solution and have an H100 lying around, the FlashMLA kernel is your best bet; otherwise, try to wrangle FlexAttention using this implementation as a baseline: https://github.com/pytorch-labs/attention-gym/blob/main/examples/mla.py
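For reference, the basic FlexAttention call pattern looks like the following. This is only a toy illustration of the score_mod hook and of a single shared KV head; the shapes are made up and the actual MLA/RoPE fusion lives in the linked mla.py, not here:

    import torch
    from torch.nn.attention.flex_attention import flex_attention

    # Toy shapes: many query heads but a single shared "KV head", mirroring how
    # MLA stores one latent per token rather than one per attention head.
    B, Hq, S, D = 1, 8, 128, 64
    q = torch.randn(B, Hq, S, D)
    k = torch.randn(B, 1, S, D)
    v = torch.randn(B, 1, S, D)

    def causal(score, b, h, q_idx, kv_idx):
        # score_mod is evaluated per (batch, head, q position, kv position); this
        # is the kind of hook where a fused positional/RoPE score term can live.
        # Here it only applies a causal mask.
        return torch.where(q_idx >= kv_idx, score, -float("inf"))

    # Compiling flex_attention is what makes it fast; the eager call also works.
    attn = torch.compile(flex_attention)
    out = attn(q, k, v, score_mod=causal, enable_gqa=True)  # (B, Hq, S, D)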
