Update modeling_baichuan.py
This seems like the right way to implement it: it uses xops, stays consistent with the 7B model's implementation, and remains compatible with older PyTorch versions.
modeling_baichuan.py  (+10 -7)  CHANGED
@@ -30,7 +30,8 @@ except ImportError:
     logger.warning(
         "Xformers is not installed correctly. If you want to use memory_efficient_attention to accelerate training use the following command to install Xformers\npip install xformers."
     )
-
+
+pytorch_major_version = int(torch.__version__.split('.')[0])
 
 def _get_interleave(n):
     def _get_interleave_power_of_2(n):
@@ -173,12 +174,14 @@ class BaichuanAttention(torch.nn.Module):
         past_key_value = (key_states, value_states) if use_cache else None
         if xops is not None and self.training:
             attn_weights = None
-
-
-
-
-
-
+            query_states = query_states.transpose(1, 2)
+            key_states = key_states.transpose(1, 2)
+            value_states = value_states.transpose(1, 2)
+            attn_output = xops.memory_efficient_attention(
+                query_states, key_states, value_states, attn_bias=attention_mask
+            )
+        elif pytorch_major_version >= 2:
+            attn_weights = None
             with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
                 attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask = attention_mask)
             attn_output = attn_output.transpose(1, 2)
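
For readers who want to see the resulting control flow in isolation, here is a minimal, self-contained sketch of the attention dispatch this commit produces. It is illustrative only: the helper name `attention_forward`, the toy tensors, and the final `else` branch (a plain matmul/softmax reference for PyTorch < 2 without xformers) are assumptions for the sketch, not code from the repository.

```python
# Illustrative sketch, not the repository's code: dispatch between xformers'
# memory_efficient_attention (training, xformers available) and the
# PyTorch >= 2 scaled_dot_product_attention fallback shown in the diff above.
import torch
import torch.nn.functional as F

try:
    from xformers import ops as xops
except ImportError:
    xops = None

pytorch_major_version = int(torch.__version__.split('.')[0])


def attention_forward(query_states, key_states, value_states,
                      attention_mask=None, training=True):
    """Inputs are (batch, num_heads, seq_len, head_dim), as inside BaichuanAttention.

    Returns a tensor of shape (batch, seq_len, num_heads, head_dim)."""
    if xops is not None and training:
        # xformers expects (batch, seq_len, num_heads, head_dim)
        q = query_states.transpose(1, 2)
        k = key_states.transpose(1, 2)
        v = value_states.transpose(1, 2)
        return xops.memory_efficient_attention(q, k, v, attn_bias=attention_mask)
    elif pytorch_major_version >= 2:
        # PyTorch >= 2 path, mirroring the second branch of the diff
        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True,
                                            enable_mem_efficient=True):
            attn_output = F.scaled_dot_product_attention(
                query_states, key_states, value_states, attn_mask=attention_mask
            )
        return attn_output.transpose(1, 2)
    else:
        # Plain-PyTorch reference for older versions without xformers
        # (a stand-in for the model's existing non-xops branch, not a copy of it).
        scores = torch.matmul(query_states, key_states.transpose(2, 3))
        scores = scores / (query_states.size(-1) ** 0.5)
        if attention_mask is not None:
            scores = scores + attention_mask
        probs = torch.softmax(scores, dim=-1)
        return torch.matmul(probs, value_states).transpose(1, 2)


# Quick shape check on random tensors:
q = k = v = torch.randn(1, 8, 16, 32)            # (batch, heads, seq, head_dim)
out = attention_forward(q, k, v, training=False)
print(out.shape)                                  # torch.Size([1, 16, 8, 32])
```

The `pytorch_major_version` check added at module level is what allows the same file to run on PyTorch 1.x, where `F.scaled_dot_product_attention` does not exist, as long as xformers is installed; that is the compatibility point the commit message makes.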
|