wunein committed
Commit 2129890 · 1 Parent(s): d0db884
Add support for flash-attention2
flash-attention2 renamed `flash_attn_unpadded_func` (to `flash_attn_varlen_func`), so import it under whichever name the installed version provides.
modeling_qwen.py +7 -3

modeling_qwen.py CHANGED
@@ -66,11 +66,15 @@ _CONFIG_FOR_DOC = "QWenConfig"
 QWen_PRETRAINED_MODEL_ARCHIVE_LIST = ["qwen-7b"]
 
 try:
-    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
+    # from flash_attn.flash_attn_interface import flash_attn_unpadded_func
+    import flash_attn
+    if int(flash_attn.__version__.split(".")[0]) == 1:
+        from flash_attn.flash_attn_interface import flash_attn_unpadded_func
+    if int(flash_attn.__version__.split(".")[0]) == 2:
+        from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_unpadded_func
 except ImportError:
     flash_attn_unpadded_func = None
-    print("
-        "https://github.com/Dao-AILab/flash-attention")
+    print("import flash_attn qkv fail")
 
 
 class FlashSelfAttention(torch.nn.Module):
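For context, here is a minimal standalone sketch of the same compatibility shim. Instead of parsing `flash_attn.__version__` as the commit does, it tries the flash-attention 2 name first and falls back to the 1.x name; only the import paths are taken from the diff above, the rest is illustrative.

```python
# Sketch: resolve the unpadded/varlen attention kernel across flash-attn 1.x and 2.x
# by aliasing whichever symbol the installed package actually exports.
try:
    # flash-attention 2.x renamed the function to flash_attn_varlen_func
    from flash_attn.flash_attn_interface import (
        flash_attn_varlen_func as flash_attn_unpadded_func,
    )
except ImportError:
    try:
        # flash-attention 1.x still exposes the old name
        from flash_attn.flash_attn_interface import flash_attn_unpadded_func
    except ImportError:
        # No usable flash-attention install; as in the diff, leave the symbol
        # as None and warn, so callers can take a non-flash attention path.
        flash_attn_unpadded_func = None
        print("import flash_attn qkv fail")
```

Either approach gives downstream code a single `flash_attn_unpadded_func` name to call; the try/except variant avoids depending on the version-string format, while the version check in the commit keeps the 1.x and 2.x import paths explicit.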