结合Deepseek代码探讨MLA的改进及收益
官方训练代码: https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/modeling_deepseek.py
官方推理代码:https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
Q1:MLA的效果为什么能超过MHA/GQA?
A1:MLA使用了大于常规MHA/GQA的attention头数和头维度,这允许它使用了比隐藏层维度还要大很多的投影矩阵来表达Q/K/V的信息。比如DS-V3的heads=128,q_head_dim=192,128 * 192远大于它的隐藏层维度7168。
Q2:为什么MLA的推理效果好?
A2:因为在推理的decode阶段,推理是逐token进行的,计算量少但需要线性积累KV Cache,总的KV Cache的大小就成为了显存瓶颈。对比主流的GQA,MLA维持了一个极小的KV Cache,所以此时起到显著的作用。
Part1:参数定义和理论能力分析
hidden_size:隐藏层维度,设定为7168(V2 = 5120)
num_heads:attention头数,设定为128,比一般GQA的64要来得大
q_lora_rank:q投影的低秩压缩维度,设定为1536
qk_nope_head_dim:qk投影的非rope部分维度,设定为128
qk_rope_head_dim:qk投影的rope部分维度,设定为64
qk_head_dim:qk投影的合并维度,等于非rope+rope,即128+64=192
kv_lora_rank:kv投影的低秩压缩维度,设定为512
v_head_dim:v投影的维度,设定为128
(1)如果是常规MHA/MQA/GQA,head_dim头维度=7168/128=56。但注意到MLA里qk_head_dim=192,v_head_dim=128,所以MLA的attention头数和头维度都要比同等量级(7168)的MHA/MQA/GQA来的大,所以理论的信息表达能力更强。一个匹配量级的MHA/MQA/GQA应该是128 * 128 = 16384隐藏维度,等于llama3-405B的配置。
(2)rope的存在是因为MLA在推理时用到了矩阵吸收来进一步减少KV Cache,即只保存投影前的信息,而不是投影后的信息。但rope的位置信息无法被直接吸收,因此设计了专门的分量来保持rope信息。
(3)MLA的KV Cache大小只和qk_rope_head_dim、kv_lora_rank这两个变量有关。常规GQA的KV Cache等价于head_dim * n_group * 2 (K和V),而MLA的KV Cache等价于qk_rope_head_dim + kv_lora_rank(论文表1),因此:
【考虑DS-V2的情况,根据论文】
- head_dim=5120/128=40,注意力头数依然是128,大于一般GQA模型
- qk_rope_head_dim为0.5 * head_dim = 20,kv_lora_rank为4 * head_dim = 160
- 128头、5120维度的DS dense 67B模型的KV Cache等价于40 * 8 * 2=640
- 而DS-V2的KV Cache等价于160 + 20 = 180 = 40 * 2.25 * 2
- 所以得到DS-V2的MLA从Cache上只等价于组数为2.25的GQA,即论文表1结论:So, its KV cache is equal to GQA with only 2.25 groups。
- 同时DS-V2引入了非MLA的手段进一步压缩KV Cache,即论文图1(b)的结论:相比DS dense 67B节约了93.3%的KV Cache。这个结论的计算是:1 - 2.25 / 8 (MLA的节约) * 60 / 95 (层的减少) * 6 / 16 (6bit量化) = 93.3%。如果单考虑MLA,DS-V2的MLA相比dense GQA节约了71.88%的存储。
【考虑DS-V3的情况,激活参数37B】
- 这里head_dim=7168/128=56,注意力头数保持128
- qk_rope_head_dim为64 = 1.14 * 56,kv_lora_rank为512 = 9.14 * 56。注意到这里没有遵循V2论文的比例,而是把两个维度都进一步变大了
- 近似激活参数的Qwen2.5-32B:KV Cache等价于5120 / 40 * 8 * 2 = 2048
- 更大激活参数的Qwen2.5-72B:KV Cache等价于8192 / 64 * 8 * 2=2048
- 而DS-V3的KV Cache等价于512 + 64 = 576 = 28.13% * 2048,相对于Qwen2.5系列,每一层节约了1 - 28.13% = 71.88%的存储。
- 假如我们设计一个头数和头维度等于DS-V3的GQA,即128头 * 128头维度 = 16384隐藏维度,如果要保持一样的KV Cache,组数还只能是 = 576 / 2 / 128 = 2.25。
所以,相比于Qwen2.5,DS-V3实现了注意力头数翻倍、头维度更大以及KV Cache更小的方案。理论上来看,至少需要一个16384维度的GQA/MHA才能对齐DS-V3中MLA的头数和头维度信息表达。进一步,如果要维持同等KV Cache,这个GQA的组数只能开到2.25。常规的GQA模型都把组数开到8,可以预想2.25组的GQA效果是不够的。
Part2:Q的低秩压缩和回放
**常规MHA/MQA/GQA里Q的投影计算**
# (B, L, 7168) -> (B, L, 128 * 56 = 7168)
q_proj = nn.Linear(hidden_size, num_heads * head_dim)
----------------------------------------------------------------
q = q_proj(hidden_states)
**MLA中Q的投影计算**
# 7168 -> 1536
q_a_proj = nn.Linear(hidden_size, q_lora_rank)
# 1536 -> 1536
q_a_layernorm = DeepseekV3RMSNorm(q_lora_rank)
# 1536 -> 128 * 192 = 24576,注意这个维度远大于7168
q_b_proj = nn.Linear(q_lora_rank, num_heads * qk_head_dim)
----------------------------------------------------------------
# (B, L, 7168) -> (B, L, 1536) -> (B, L, 24576)
q = q_b_proj(q_a_layernorm(q_a_proj(hidden_states)))
# (B, L, 24576) -> (B, L, 128, 192) -> (B, 128, L, 192)
q = q.view(bsz, q_len, num_heads, qk_head_dim).transpose(1, 2)
可以对比看到,MLA中的Q被投影到了比常规MHA/MQA/GQA更大的矩阵上,这使得模型对Q的表达能力更强。从投影矩阵的计算量看:
- 常规MHA/MQA/GQA:7168 * 7168
- MLA:7168 * 1536 + 1536 + 1536 * 24576 = 6802.5 * 7168
综上,MLA中对Q的低秩压缩和回放实现了近似常规MHA的计算量,但投影到了更大矩阵上。
Part3:KV的低秩压缩和回放
**常规GQA里KV的投影计算,以K为例**
# (B, L, 7168) -> (B, L, 8 * 56 = 448)
k_proj = nn.Linear(hidden_size, n_groups * head_dim)
----------------------------------------------------------------
k = k_proj(hidden_states)
如果n_groups=1,那就是MQA -> (B, L, 56)
如果n_groups=128,那就是MHA -> (B, L, 128 * 56 = 7168)
**MLA中KV的投影计算**
# 7168 -> 512 + 64 = 576
kv_a_proj_with_mqa = nn.Linear(hidden_size, kv_lora_rank + qk_rope_head_dim)
# 512 -> 512
kv_a_layernorm = DeepseekV3RMSNorm(kv_lora_rank)
# 512 -> 128 * (128 + 128) = 32768,注意这个维度远大于7168
kv_b_proj = nn.Linear(kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim))
----------------------------------------------------------------
# (B, L, 7168) -> (B, L, 576)
compressed_kv = kv_a_proj_with_mqa(hidden_states)
# (B, L, 576) -> (B, L, 512)和(B, L, 64),rope解耦
compressed_kv, k_pe = torch.split(compressed_kv, [kv_lora_rank, qk_rope_head_dim], dim=-1)
# (B, L, 512) -> (B, L, 32768) -> (B, 128, L, 256)
kv = (kv_b_proj(kv_a_layernorm(compressed_kv)).view(bsz, q_len, num_heads, qk_nope_head_dim + v_head_dim).transpose(1, 2))
(1)注意到Q的投影是低秩压缩->回放,而KV的投影是低秩压缩->rope解耦->非rope部分回放。这里KV先做rope解耦是为了让K的rope共享。论文2.1.3的第二段中提到:a shared key to carry RoPE。
(2)可以对比看到,MLA中的KV也被投影到了比常规GQA更大的矩阵上,这使得模型对KV的表达能力更强。从投影矩阵的计算量看:
- 常规GQA:7168 * (56 * 8) * 2 (K和V)
- MLA:7168 * 576 + 512 + 512 * 32768 = 3.26倍常规GQA
综上,MLA中对KV的低秩压缩和回放使用了等价于常规GQA 3.26倍的计算量,投影到更大的矩阵。此外,尽管KV的投影处理耗费了更多的计算量,但是KV Cache并不大 (Part4)。
Part4: KV Cache的管理
**常规GQA里KV Cache的管理,做法是保存投影后的k**
k = k_proj(hidden_states)
保存的KV Cache等价于head_dim * groups * 2 = 56 * 8 * 2 (K和V)
如果是MHA,KV Cache就等于56 * 128 * 2
如果是MQA,KV Cache就等于56 * 2
**MLA里KV Cache的管理 -- 训练代码**
# (B, 128, L, 256) -> (B, 128, L, 128)和(B, 128, L, 128),KV分离
k_nope, value_states = torch.split(kv, [qk_nope_head_dim, v_head_dim], dim=-1)
# (B, L, 64) -> (B, L, 1, 64) -> (B, 1, L, 64),这里k_pe是被扩展了维度,在后面用于广播和复制,对应论文中说的K的rope共享
k_pe = k_pe.view(bsz, q_len, 1, qk_rope_head_dim).transpose(1, 2)
# 添加rope信息
k_pe = apply_rotary_pos_emb(k_pe, cos, sin, position_ids)
# 初始化一个新矩阵,与k_pe同类型和设备,大小为(B, 128, L, 192)
key_states = k_pe.new_empty(bsz, num_heads, q_len, qk_head_dim)
key_states[:, :, :, :qk_nope_head_dim] = k_nope
key_states[:, :, :, qk_nope_head_dim:] = k_pe
**MLA里KV Cache的管理 -- 推理代码**
# 保存KV Cache,大小是(B, L, 512) + (B, L, 64)(降维,只需要存1份)
kv_cache = torch.zeros(bsz, seq_len, kv_lora_rank)
pe_cache = torch.zeros(bsz, seq_len, qk_rope_head_dim)
kv_cache[:bsz, start_pos:end_pos] = compressed_kv
pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(1)
(1)注意到训练代码中,MLA保存的KV Cache是key_states和vale_states,那么实际保存的KV Cache等价于heads * (qk_head_dim + v_head_dim) = 128 * (192 + 128) > 128 * (56 * 2),这实际上比MHA的KV Cache还要大。这和上面提到的KV投影参数量大是吻合的,因为正常来说KV Cache就是跟头维度、组数有关,MLA的头维度比正常的维度 (7168 // 128 = 56) 要大很多。
(2)注意到推理代码中,MLA的KV Cache才是论文表1给到的(kv_lora_rank + qk_rope_head_dim),这种方式需要与矩阵吸收联合 (Part5)。论文2.1.2中提到:In addition, during inference, since can be absorbed into and can be absorbed into , we even do not need to compute keys and values out for attention。
Part5: 应用矩阵吸收计算Attention的权重和分数
**常规MHA/MQA/GQA的Attention权重和分数计算**
# 128 * 56 = 7168 -> 7168
o_proj = nn.Linear(num_heads * head_dim, hidden_size)
----------------------------------------------------------------
# (B, L, 7168) * (B, 7168, L) -> (B, L, L)
attn_weights = torch.matmul(q, k.transpose(1, 2))
# (B, L, L) * (B, L, 7168) -> (B, L, 7168)
output = o_proj(torch.matmul(attn_weights, v))
**MLA的Attention权重和分数计算 -- 训练代码,无矩阵吸收**
# 128 * 128 = 16384 -> 7168,注意到最后的整合矩阵也非常大
o_proj = nn.Linear(num_heads * v_head_dim, hidden_size)
----------------------------------------------------------------
# (B, 128, L, 192) -> (B, 128, L, 128)和(B, 128, L, 64),rope解耦
q_nope, q_pe = torch.split(q, [qk_nope_head_dim, qk_rope_head_dim], dim=-1)
# 添加rope信息
q_pe = apply_rotary_pos_emb(q_pe, cos, sin, position_ids)
# 同样对Q初始化一个新矩阵,大小为(B, 128, L, 192),与k_pe同类型和设备
query_states = k_pe.new_empty(bsz, num_heads, q_len, qk_head_dim)
# 将q的非rope部分和rope部分(添加信息后)合并
query_states[:, :, :, :self.qk_nope_head_dim] = q_nope
query_states[:, :, :, self.qk_nope_head_dim:] = q_pe
# (B, 128, L, 192) * (B, 128, 192, L) -> (B, 128, L, L)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
# (B, 128, L, L) * (B, 128, L, 128) -> (B, 128, L, 128)
attn_output = torch.matmul(attn_weights, value_states)
# (B, 128, L, 128) -> (B, L, 128, 128)
attn_output = attn_output.transpose(1, 2).contiguous()
# (B, L, 128, 128) -> (B, L, 32768)
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
# (B, L, 32768) -> (B, L, 7168)
attn_output = self.o_proj(attn_output)
**MLA的Attention权重和分数计算 -- 推理代码,含矩阵吸收**
# 128 * 128 = 16384 -> 7168
o_proj = nn.Linear(num_heads * v_head_dim, hidden_size)
----------------------------------------------------------------
# 提取KV回放矩阵的权重: (32768, 512) -> (128, 256, 512)
wkv_b = kv_b_proj.weight
wkv_b = wkv_b.view(num_heads, -1, kv_lora_rank)
# (B, 128, L, 128) X (128, 128, 512) (分割出前一半k上非rope部分的权重,后一半是v的权重) -> (B, 128, L, 512),得到吸收了Wk的q
q_nope = torch.einsum("bhsd,hdc->bhsc", q_nope, wkv_b[:, :qk_nope_head_dim])
# (B, 128, L, 512) X (B, L, 512) -> (B, 128, L, L)
attn_weights_nope = torch.einsum("bhsc,btc->bhst", q_nope, kv_cache[:bsz, :end_pos])
# (B, 128, L, 64) X (B, L, 64) -> (B, 128, L, L)
attn_weights_pe = torch.einsum("bhsr,btr->bhst", q_pe, pe_cache[:bsz, :end_pos])
attn_weights = attn_weights_nope + attn_weights_pe
# (B, 128, L, L) * (B, L, 512) -> (B, 128, L, 512)
attn_output = torch.einsum("bhst,btc->bhsc", attn_weights, kv_cache[:bsz, :end_pos])
# (B, 128, L, 512) * (128, 128, 512) (从wkv_b中分割出v的权重,前一半是k上非rope部分的权重) -> (B, 128, L, 128),得到吸收了Wv的输出
attn_output = torch.einsum("bhsc,hdc->bhsd", attn_output, wkv_b[:, -self.v_head_dim:])
# (B, 128, L, 128) -> 展开为(B, L, 16384) -> (B, L, 7168)
attn_output = o_proj(attn_output.transpose(1, 2).flatten(2))
(1)注意到训练代码中,MLA把q的rope部分处理后,重新拼接成heads * 192(nope_dim + rope_dim)的矩阵,然后使用标准MHA的方式计算Attention权重和分数。这里的计算量比标准MHA还要大,因为7168维度的标准MHA的head_dim只有7168 / 128 = 56。
(2)注意到推理代码中使用了矩阵吸收。矩阵吸收原理是:
- 标准Attention权重的计算是 ,对应保存的K Cache是 ;Attention值的计算是 ,对应保存的V Cache是 。简单来说,保存的KV Cache都是投影后的矩阵。
- 但可以通过矩阵的恒等变化,得到吸收后的权重计算方式为 ,对应保存的K Cache就只有末尾的x了, 整个作为q投影的计算。同理,Attention值的计算也可以将 吸收到 中去。此时V Cache的保存也只有x,和K Cache完全一致,所以只需要保存此前标准计算一半的Cache。
- 不难发现,矩阵吸收不是MLA独有的,MHA也可以做这样的操作。但矩阵吸收的做法本质上是用计算量换空间的策略。对于MHA来说,标准计算方式的计算量为 / / / 四个矩阵的映射,等于4 * dim * head_dim;如果改成矩阵吸收的做法,计算量为 / 两个吸收矩阵的映射,每个吸收矩阵的维度从(dim, head_dim)变成了(dim, dim),所以计算量等于2 * dim * dim,等于heads/2倍的标准版计算量。这个计算量会明显增加推理时延,因此标准MHA的KV Cache还是采用了缓存投影分量 + 低计算量的方案。
- 但对于MLA来说,这个方案是可行的,因为MLA提前做了维度压缩。对于MLA来说,要保存的KV Cache不再是整个x,而是压缩后的KV和共享的rope;这实际上等于一个低秩压缩后的MQA,即kv_lora_rank + qk_rope_head_dim(根据上面的结论,压缩了71.88%的存储)。在计算量上,MLA等于dim * kv_lora_rank + dim * qk_rope_head_dim。如果考虑DS-V2的设定,kv_lora_rank = 4 * head_dim,qk_rope_head_dim = 0.5 * head_dim,MLA的计算量就等于4.5 * dim * head_dim,即等于1.125倍标准版计算量,远小于之前的heads/2倍(64)。所以,MLA + 矩阵吸收等于显著优化了KV Cache,同时用维度压缩控制了计算量。需要注意到,这里的标准版计算量是基于5120维度、128头的DS dense 67B模型给出的。
总而言之,MLA在推理时维持一个比标准Dense模型稍大一些的计算量,但维持了一个极小的KV Cache。LLM推理时,Prefill阶段的瓶颈是计算量,MLA的矩阵吸收并没有优势,甚至更慢(参考GitHub上的讨论与实测);但在Decode阶段,由于推理是逐token进行的,计算量少但需要线性积累KV Cache,总的KV Cache的大小就成为了显存瓶颈,MLA此时起到显著的作用。
参考资料
【1】DS-V2:https://arxiv.org/abs/2405.04434
【2】DS-V3: https://arxiv.org/abs/2412.19437v1
【3】MLA原理解析:https://kexue.fm/archives/10091
【4】MLA矩阵吸收解析:https://zhuanlan.zhihu.com/p/697781431
【5】DS推理解析:https://zhuanlan.zhihu.com/p/23129261011
【6】DS推理代码解析:https://zhuanlan.zhihu.com/p/21380265337
【7】DS-V3参数量精算:https://yangwenbo.com/articles/deepseek-v3-parameter-size.html