Update attention.py
attention.py  +1 -1
@@ -87,7 +87,7 @@ def flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None
 
 def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
     try:
-        from
+        from flash_attn import flash_attn_triton
     except:
         raise RuntimeError('Please install flash-attn==1.0.3.post0 and triton==2.0.0.dev20221202')
     check_valid_inputs(query, key, value)
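For context, the imported flash_attn_triton module exposes the Triton kernel through flash_attn_func, which is what triton_flash_attn_fn ultimately calls. A minimal usage sketch, assuming flash-attn==1.0.3.post0's interface and a CUDA device; the shapes and values below are illustrative only and not part of this commit:

    import math
    import torch

    try:
        # Same guarded import as the patched code; the Triton kernel is an optional dependency.
        from flash_attn import flash_attn_triton
    except ImportError:
        raise RuntimeError('Please install flash-attn==1.0.3.post0 and triton==2.0.0.dev20221202')

    # flash-attn expects (batch, seq_len, n_heads, head_dim) half-precision tensors on GPU.
    batch, seq_len, n_heads, head_dim = 2, 128, 8, 64
    q = torch.randn(batch, seq_len, n_heads, head_dim, device='cuda', dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # flash_attn_func(q, k, v, bias, causal, softmax_scale): bias may be None,
    # and passing None for softmax_scale falls back to 1/sqrt(head_dim).
    out = flash_attn_triton.flash_attn_func(q, k, v, None, True, 1.0 / math.sqrt(head_dim))
    print(out.shape)  # (batch, seq_len, n_heads, head_dim)

The try/except around the import keeps the Triton path optional: callers that never select this attention implementation can run without flash-attn or triton installed, and anyone who does select it gets an actionable error naming the pinned versions.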