ybelkada committed on
Commit
491c189
1 Parent(s): d281c7a

Update llama_xformers_attention.py

Files changed (1)
  1. llama_xformers_attention.py +3 -2
llama_xformers_attention.py CHANGED
@@ -57,8 +57,9 @@ class LlamaXFormersAttention(LlamaAttention):
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)
 
-        #This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
-        #We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
+        # copied from https://github.com/oobabooga/text-generation-webui/pull/950/files
+        # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
+        # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
         if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
             # input and output should be of form (bsz, q_len, num_heads, head_dim)
             attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=None)
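
For readers skimming the diff, the sketch below spells out the check the comments describe: transformers builds attention_mask as a (bsz, 1, q_len, kv_len) tensor that is either all zeros (nothing masked) or a lower-triangular causal mask with large negative values above the diagonal, so inspecting the single element [0, 0, 0, 1] in the upper triangle tells the two cases apart. This is a minimal sketch, not part of the commit; the tensor shapes are made up for illustration, and the else branch using LowerTriangularMask is an assumption about the surrounding code, which this hunk does not show.

import torch
from xformers.ops import memory_efficient_attention, LowerTriangularMask

# Hypothetical shapes for illustration; xformers expects (bsz, q_len, num_heads, head_dim).
device = "cuda" if torch.cuda.is_available() else "cpu"
bsz, q_len, num_heads, head_dim = 1, 4, 2, 64
query_states = torch.randn(bsz, q_len, num_heads, head_dim, device=device)
key_states = torch.randn(bsz, q_len, num_heads, head_dim, device=device)
value_states = torch.randn(bsz, q_len, num_heads, head_dim, device=device)

# transformers passes attention_mask as (bsz, 1, q_len, kv_len): all zeros when nothing is
# masked, or lower-triangular (large negative values above the diagonal) for causal decoding.
# Element [0, 0, 0, 1] lies above the diagonal, so it is 0 only in the all-zeros case.
attention_mask = torch.zeros(bsz, 1, q_len, q_len, device=device)

if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
    # no masking needed; input and output are (bsz, q_len, num_heads, head_dim)
    attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=None)
else:
    # assumed causal branch (not shown in the hunk): let xformers build the triangular bias itself
    attn_output = memory_efficient_attention(
        query_states, key_states, value_states, attn_bias=LowerTriangularMask()
    )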