Update modeling_opt.py

modeling_opt.py  (+5 -4)
@@ -133,7 +133,8 @@ def softmax_1(input: torch.Tensor, dim=-1, dtype=torch.float32) -> torch.Tensor:
     """
     $\text(softmax)_n(x_i) = exp(x_i) / (1 + \sum_j exp(x_j))$
     """
-
+    output = softmax_n_shifted_zeros(input, 1, dim=dim)
+    return output if dtype is None else output.type(dtype=dtype)
 
 
 class OPTAttention(nn.Module):
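
For reference, a minimal sketch of what the two added lines compute, going only by the docstring formula above (a softmax with an extra 1 in the denominator, equivalent to appending an implicit all-zero logit). `softmax_n_shifted_zeros` itself is defined elsewhere in this repo; the name `softmax_1_reference` below is made up for illustration and is not part of this commit:

import torch

def softmax_1_reference(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # exp(x_i) / (1 + sum_j exp(x_j)), computed with a max-shift so exp()
    # cannot overflow; clamping the shift at 0 keeps the "1" term represented
    # exactly as exp(-m) whenever m > 0.
    m = x.max(dim=dim, keepdim=True).values.clamp(min=0)
    exp_x = torch.exp(x - m)
    return exp_x / (torch.exp(-m) + exp_x.sum(dim=dim, keepdim=True))

Unlike the standard softmax, the outputs along `dim` sum to at most 1 and approach 0 when every logit is strongly negative, which appears to be the point of swapping it in for the attention normalization below.
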
@@ -151,7 +152,7 @@ class OPTAttention(nn.Module):
         self.num_heads = config.num_attention_heads
         self.dropout = config.attention_dropout
         self.enable_bias = config.enable_bias
-
+        self.attention = nn.functional.softmax
         self.head_dim = self.embed_dim // self.num_heads
         self.is_causal = True
 
@@ -327,7 +328,7 @@ class OPTOutEffHop(OPTAttention):
         self.num_heads = config.num_attention_heads
         self.dropout = config.attention_dropout
         self.enable_bias = config.enable_bias
-
+        self.attention = softmax_1
         self.head_dim = self.embed_dim // self.num_heads
         self.is_causal = True
 
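
The new `self.attention` attribute is not called anywhere in this diff; presumably the forward pass uses it in place of a hard-coded `nn.functional.softmax`, so that `OPTAttention` keeps the standard normalization while `OPTOutEffHop` gets the shifted one through the same call site. A small, self-contained illustration of why the two callables are interchangeable there (reusing the hypothetical `softmax_1_reference` sketch above):

import torch
import torch.nn as nn

scores = torch.randn(2, 4, 8, 8)                    # (batch, heads, tgt_len, src_len); made-up shapes
standard = nn.functional.softmax(scores, dim=-1)    # rows sum to exactly 1
shifted = softmax_1_reference(scores, dim=-1)       # rows sum to strictly less than 1
print(standard.sum(-1).mean().item(), shifted.sum(-1).mean().item())
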
@@ -488,7 +489,7 @@ class OPTOutEffHop(OPTAttention):
         return attn_output, attn_weights_reshaped, past_key_value
 
 
-class OptFlashAttention2(OPTAttention):
+class OptFlashAttention2(OPTOutEffHop):
     """
     OPT flash attention module. This module inherits from `OPTAttention` as the weights of the module stays untouched.
     The only required change would be on the forward pass where it needs to correctly call the public API of flash
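
Changing the base class from `OPTAttention` to `OPTOutEffHop` means `OptFlashAttention2.__init__` now runs through `OPTOutEffHop.__init__`, so, assuming the flash-attention class does not reassign the attribute itself, its instances also end up with `self.attention = softmax_1` rather than `nn.functional.softmax`.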