Update llama_xformers_attention.py

llama_xformers_attention.py  (+16 -28)
@@ -3,7 +3,7 @@ import torch.nn as nn
 
 from typing import Optional, Tuple
 
-from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
+from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv
 
 from xformers.ops.fmha import (
     memory_efficient_attention,
@@ -51,33 +51,21 @@ class LlamaXFormersAttention(LlamaAttention):
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-            raise ValueError(
-                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                f" {attn_weights.size()}"
-            )
-
-        if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
-                )
-            attn_weights = attn_weights + attention_mask
-
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
-        attn_output = torch.matmul(attn_weights, value_states)
-
-        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
-            raise ValueError(
-                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
-                f" {attn_output.size()}"
-            )
-
-        attn_output = attn_output.transpose(1, 2).contiguous()
+        dtype = query_states.dtype
+
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all zeros.
+        # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
+        if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
+            # input and output should be of form (bsz, q_len, num_heads, head_dim)
+            attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=None)
+        else:
+            # input and output should be of form (bsz, q_len, num_heads, head_dim)
+            attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=xformers.ops.LowerTriangularMask())
+        attn_weights = None
 
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 
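The attention_mask[0, 0, 0, 1] check works because transformers materializes its 4D additive mask with a large negative value at disallowed positions and 0 everywhere else, so the entry for query 0 attending to the future key at position 1 is nonzero exactly when the mask is causal. A hypothetical illustration (the tensors below are invented for this note, not part of the commit):

# Invented example: how transformers-style additive masks look, and why one
# upper-triangular element distinguishes a causal mask from a no-op mask.
import torch

q_len = 4
# Causal: large negative above the diagonal, 0 on and below it.
causal_mask = torch.full((1, 1, q_len, q_len), torch.finfo(torch.float32).min).triu(1)
# No-op: all zeros, masks nothing.
noop_mask = torch.zeros(1, 1, q_len, q_len)

print(causal_mask[0, 0, 0, 1] == 0)  # tensor(False) -> take the LowerTriangularMask branch
print(noop_mask[0, 0, 0, 1] == 0)    # tensor(True)  -> no attn_bias needed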
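As a sanity check that the LowerTriangularMask path reproduces the eager attention it replaces, here is a minimal sketch, assuming xformers is installed and a CUDA device is available; the sizes are arbitrary, and the tensors follow the patched (bsz, q_len, num_heads, head_dim) layout:

# Sketch only: compare xformers' causal memory-efficient attention with an
# eager reference. Assumes CUDA + xformers; sizes are arbitrary.
import torch
import xformers.ops
from xformers.ops.fmha import memory_efficient_attention

bsz, q_len, num_heads, head_dim = 2, 16, 4, 64
q = torch.randn(bsz, q_len, num_heads, head_dim, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Causal path, as taken when the transformers mask is lower triangular.
out = memory_efficient_attention(q, k, v, attn_bias=xformers.ops.LowerTriangularMask())

# Eager reference computed in (bsz, num_heads, q_len, head_dim) layout.
qh, kh, vh = (t.transpose(1, 2) for t in (q, k, v))
scores = qh @ kh.transpose(2, 3) / head_dim ** 0.5
future = torch.ones(q_len, q_len, dtype=torch.bool, device="cuda").triu(1)
scores = scores.masked_fill(future, float("-inf"))
ref = (scores.softmax(dim=-1) @ vh).transpose(1, 2)

print(torch.allclose(out, ref, atol=1e-2))  # loose tolerance for fp16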