wunein committed · Commit 2129890 · Parent: d0db884

Add support for flash-attention2

FlashAttention 2 renames `flash_attn_unpadded_func` to `flash_attn_varlen_func`; import it under the old name so the rest of the code keeps working.

Files changed (1)
  1. modeling_qwen.py (+7 −3)
modeling_qwen.py CHANGED

```diff
@@ -66,11 +66,15 @@ _CONFIG_FOR_DOC = "QWenConfig"
 QWen_PRETRAINED_MODEL_ARCHIVE_LIST = ["qwen-7b"]
 
 try:
-    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
+    # from flash_attn.flash_attn_interface import flash_attn_unpadded_func
+    import flash_attn
+    if int(flash_attn.__version__.split(".")[0]) == 1:
+        from flash_attn.flash_attn_interface import flash_attn_unpadded_func
+    if int(flash_attn.__version__.split(".")[0]) == 2:
+        from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_unpadded_func
 except ImportError:
     flash_attn_unpadded_func = None
-    print("Warning: import flash_attn fail, please install FlashAttention "
-          "https://github.com/Dao-AILab/flash-attention")
+    print("import flash_attn qkv fail")
 
 
 class FlashSelfAttention(torch.nn.Module):
```
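
For context, the alias works because FlashAttention 1's `flash_attn_unpadded_func` and FlashAttention 2's `flash_attn_varlen_func` take the same variable-length (unpadded) arguments. Below is a minimal usage sketch of the resolved alias; it is not part of the commit, the tensor shapes and `causal=True` are illustrative assumptions, and it requires an installed flash-attn build plus a CUDA GPU.

```python
# Minimal sketch (not from the commit): call the alias resolved by the version check above.
import torch
import flash_attn

if int(flash_attn.__version__.split(".")[0]) >= 2:
    from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_unpadded_func
else:
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func

# Two sequences of length 3 and 5 packed into one "unpadded" batch (shapes are illustrative).
seqlens = torch.tensor([3, 5], dtype=torch.int32, device="cuda")
cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))  # [0, 3, 8]
total, n_heads, head_dim = int(seqlens.sum()), 32, 128

q = torch.randn(total, n_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Both major versions accept the same positional arguments, which is what makes the alias safe.
out = flash_attn_unpadded_func(
    q, k, v,
    cu_seqlens, cu_seqlens,                      # cu_seqlens_q, cu_seqlens_k
    seqlens.max().item(), seqlens.max().item(),  # max_seqlen_q, max_seqlen_k
    dropout_p=0.0,
    softmax_scale=None,
    causal=True,
)
print(out.shape)  # (total, n_heads, head_dim)
```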