Update llama_xformers_attention.py
llama_xformers_attention.py
@@ -57,8 +57,9 @@ class LlamaXFormersAttention(LlamaAttention):
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)
 
-        #
-        #
+        # copied from https://github.com/oobabooga/text-generation-webui/pull/950/files
+        # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
+        # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
         if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
             # input and output should be of form (bsz, q_len, num_heads, head_dim)
             attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=None)
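
For context, here is a minimal sketch of the pattern this patch documents, assuming xformers is installed. The standalone function name and the LowerTriangularMask fallback for the masked case are illustrative assumptions, not part of the PR itself:

# Minimal sketch of the mask heuristic above; not the PR's full attention class.
from xformers.ops import memory_efficient_attention, LowerTriangularMask

def xformers_attention(query_states, key_states, value_states, attention_mask=None):
    # All three tensors are (bsz, q_len, num_heads, head_dim), the layout xformers expects.
    # Heuristic from the diff: transformers builds attention_mask as either a
    # lower-triangular causal mask or all zeros, so a zero at an upper-triangular
    # position ([0, 0, 0, 1]) means the whole mask is zeros and no bias is needed.
    if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
        return memory_efficient_attention(query_states, key_states, value_states, attn_bias=None)
    # Illustrative fallback: use xformers' built-in causal bias instead of
    # materializing the dense mask.
    return memory_efficient_attention(query_states, key_states, value_states, attn_bias=LowerTriangularMask())

Checking a single element rather than the whole tensor keeps the branch cheap; the trade-off is that it relies on transformers only ever producing those two mask shapes.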