Jackmin108 committed
Commit · 5ee2c37
1 Parent(s): 4fa2261

Remove triton flash implementation

modeling_bert.py: +0 -18
modeling_bert.py CHANGED

@@ -63,12 +63,6 @@ try:
 except ImportError:
     scaled_dot_product_attention = None
 
-# Triton implementation
-try:
-    from .flash_attn_triton import flash_attn_func
-except Exception:
-    flash_attn_func = None
-
 # This is used by encode but user may not have it installed
 try:
     from tqdm.autonotebook import trange
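After this removal, the only attention-related optional import left in this part of the file is torch's fused kernel, guarded by the try/except kept in the context lines above. A minimal sketch of that guard-and-select pattern follows; the exact import path and the pick_attn_implementation helper are illustrative assumptions, not code from this commit.

# Minimal sketch, not part of the commit: guard the optional torch kernel and
# fall back to the plain attention path when it is unavailable.
try:
    from torch.nn.functional import scaled_dot_product_attention
except ImportError:  # older torch builds without the fused kernel
    scaled_dot_product_attention = None

def pick_attn_implementation() -> str:
    # Hypothetical helper: prefer the fused kernel only when torch provides it.
    return 'sdpa' if scaled_dot_product_attention is not None else 'eager'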
@@ -324,18 +318,6 @@ class JinaBertSelfAttention(nn.Module):
         output_attentions: Optional[bool] = False,
         bias: Optional[torch.FloatTensor] = None,
     ) -> Tuple[torch.Tensor]:
-        if self.attn_implementation == 'triton':
-            b, s, h = hidden_states.shape
-            q = self.query(hidden_states)
-            k = self.key(hidden_states)
-            v = self.value(hidden_states)
-            # B x S x hidden_dim -> B x S x num_heads x head_dim
-            q = q.view(b, s, self.num_attention_heads, self.attention_head_size)
-            k = k.view(b, s, self.num_attention_heads, self.attention_head_size)
-            v = v.view(b, s, self.num_attention_heads, self.attention_head_size)
-            attn = flash_attn_func(q, k, v, bias)
-            return (attn.view(b, s, h),)
-
         mixed_query_layer = self.query(hidden_states)
 
         # If this is instantiated as a cross-attention module, the keys
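With the Triton branch gone, the forward pass falls through to the module's remaining attention code. For reference, here is a minimal sketch of how the same biased attention could be expressed with torch's scaled_dot_product_attention, which the file imports optionally near the top; the free-standing helper, its signature, and the extra transposes are illustrative assumptions, not the committed implementation.

import torch
import torch.nn.functional as F

def sdpa_self_attention(hidden_states: torch.Tensor,
                        query: torch.nn.Linear,
                        key: torch.nn.Linear,
                        value: torch.nn.Linear,
                        num_attention_heads: int,
                        attention_head_size: int,
                        bias: torch.Tensor = None) -> torch.Tensor:
    # Hypothetical stand-in for the removed Triton branch, not the committed code.
    b, s, h = hidden_states.shape
    # B x S x hidden_dim -> B x num_heads x S x head_dim (SDPA expects heads before sequence)
    q = query(hidden_states).view(b, s, num_attention_heads, attention_head_size).transpose(1, 2)
    k = key(hidden_states).view(b, s, num_attention_heads, attention_head_size).transpose(1, 2)
    v = value(hidden_states).view(b, s, num_attention_heads, attention_head_size).transpose(1, 2)
    # attn_mask accepts an additive float bias broadcastable to (B, num_heads, S, S),
    # mirroring how the `bias` tensor was passed to flash_attn_func in the removed branch.
    attn = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
    # B x num_heads x S x head_dim -> B x S x hidden_dim
    return attn.transpose(1, 2).reshape(b, s, h)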