feat: try to monkey-patch index_first_axis
modeling_bert.py  +8 -4
@@ -28,12 +28,16 @@ from transformers.models.bert.modeling_bert import (
     BaseModelOutputWithPoolingAndCrossAttentions,
     BertForPreTrainingOutput,
 )
-from .patched_padding_bert import index_first_axis
+from .patched_padding_bert import index_first_axis as index_first_axis_monkey_patch
+import flash_attn.bert_padding
+flash_attn.bert_padding.index_first_axis = index_first_axis_monkey_patch
+"""
 from flash_attn.bert_padding import (
     index_first_axis_residual,
     pad_input,
     unpad_input,
 )
+"""
 from flash_attn.modules.block import Block
 from flash_attn.modules.embedding import BertEmbeddings
 from flash_attn.modules.mha import MHA
@@ -172,14 +176,14 @@ class BertEncoder(nn.Module):
                 hidden_states = hidden_states[subset_mask]
         else:
             batch, seqlen = hidden_states.shape[:2]
-            hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
+            hidden_states, indices, cu_seqlens, max_seqlen_in_batch = flash_attn.bert_padding.unpad_input(
                 hidden_states, key_padding_mask
             )
             mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
             if subset_mask is None:
                 for layer in self.layers:
                     hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
-                hidden_states = pad_input(hidden_states, indices, batch, seqlen)
+                hidden_states = flash_attn.bert_padding.pad_input(hidden_states, indices, batch, seqlen)
             else:
                 for layer in self.layers[:-1]:
                     hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
@@ -197,7 +201,7 @@ class BertEncoder(nn.Module):
                     subset_cu_seqlens = F.pad(
                         torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32), (1, 0)
                     )
-                hidden_states_subset, hidden_states = index_first_axis_residual(
+                hidden_states_subset, hidden_states = flash_attn.bert_padding.index_first_axis_residual(
                     hidden_states, subset_idx
                 )
                 # It's ok to set max_seqlen_q to be much larger
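For context, here is a minimal sketch of the monkey-patching pattern this commit relies on. The patched_index_first_axis below is a hypothetical stand-in for .patched_padding_bert.index_first_axis, whose contents are not shown in this commit; the point is the rebinding of the module attribute and the binding-time behaviour that motivates switching the call sites to the fully qualified flash_attn.bert_padding.* form.

# Minimal sketch of the monkey-patch used above. patched_index_first_axis is a
# hypothetical placeholder for .patched_padding_bert.index_first_axis.
import flash_attn.bert_padding

_original_index_first_axis = flash_attn.bert_padding.index_first_axis

def patched_index_first_axis(input, indices):
    # Drop-in replacement with the same call shape as the original
    # flash_attn.bert_padding.index_first_axis; here it simply delegates.
    return _original_index_first_axis(input, indices)

# Rebind the attribute on the module object. Code that resolves the name at
# call time (flash_attn.bert_padding.index_first_axis(...), or functions
# defined inside bert_padding that refer to the module-level name, such as
# unpad_input) now picks up the patched version.
flash_attn.bert_padding.index_first_axis = patched_index_first_axis

# By contrast, a name created earlier with
#   from flash_attn.bert_padding import index_first_axis
# was bound at import time and still points at the original function. That is
# why the diff comments out the from-import block and calls
# flash_attn.bert_padding.unpad_input / pad_input / index_first_axis_residual
# through the module instead.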