Commit 3561472 (verified) by robinzixuan · Parent: 423395a

Update modeling_opt.py

Files changed (1)
  1. modeling_opt.py +5 -4
modeling_opt.py CHANGED
@@ -133,7 +133,8 @@ def softmax_1(input: torch.Tensor, dim=-1, dtype=torch.float32) -> torch.Tensor:
     """
     $\text{softmax}_n(x_i) = \exp(x_i) / (1 + \sum_j \exp(x_j))$
     """
-    return softmax_n_shifted_zeros(input, 1, dim=dim)
+    output = softmax_n_shifted_zeros(input, 1, dim=dim)
+    return output if dtype is None else output.type(dtype=dtype)
 
 
 class OPTAttention(nn.Module):
@@ -151,7 +152,7 @@ class OPTAttention(nn.Module):
         self.num_heads = config.num_attention_heads
         self.dropout = config.attention_dropout
         self.enable_bias = config.enable_bias
-
+        self.attention = nn.functional.softmax
         self.head_dim = self.embed_dim // self.num_heads
         self.is_causal = True
 
@@ -327,7 +328,7 @@ class OPTOutEffHop(OPTAttention):
         self.num_heads = config.num_attention_heads
         self.dropout = config.attention_dropout
         self.enable_bias = config.enable_bias
-
+        self.attention = softmax_1
         self.head_dim = self.embed_dim // self.num_heads
         self.is_causal = True
 
@@ -488,7 +489,7 @@ class OPTOutEffHop(OPTAttention):
         return attn_output, attn_weights_reshaped, past_key_value
 
 
-class OptFlashAttention2(OPTAttention):
+class OptFlashAttention2(OPTOutEffHop):
     """
     OPT flash attention module. This module inherits from `OPTAttention` as the weights of the module stays untouched.
     The only required change would be on the forward pass where it needs to correctly call the public API of flash
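
In summary, the commit makes three small changes: softmax_1 now honors its dtype argument instead of ignoring it, the attention classes expose their normalization function as self.attention (nn.functional.softmax on OPTAttention, softmax_1 on OPTOutEffHop), and OptFlashAttention2 is re-based on OPTOutEffHop so it starts from the softmax_1 variant rather than vanilla OPTAttention. The sketch below is a minimal, self-contained illustration of what softmax_1 computes according to the docstring formula; softmax_1_sketch and its zero-padding trick are hypothetical stand-ins, not the repository's softmax_n_shifted_zeros implementation.

```python
import torch

def softmax_1_sketch(x: torch.Tensor, dim: int = -1, dtype=torch.float32) -> torch.Tensor:
    """Hypothetical stand-in for softmax_1: exp(x_i) / (1 + sum_j exp(x_j)).

    Equivalent to appending a single zero logit along `dim` and dropping its
    weight after a regular softmax, so all weights can shrink toward zero
    when every logit is strongly negative.
    """
    zeros = torch.zeros_like(x.narrow(dim, 0, 1))       # one extra "null" logit
    padded = torch.cat([x, zeros], dim=dim)              # [x_1, ..., x_n, 0]
    out = torch.nn.functional.softmax(padded, dim=dim)   # denominator = 1 + sum_j exp(x_j)
    out = out.narrow(dim, 0, x.size(dim))                # drop the zero logit's weight
    return out if dtype is None else out.type(dtype)     # mirror the dtype handling in the commit

# Usage: unlike the standard softmax, the weights need not sum to 1.
scores = torch.tensor([[-10.0, -10.0, -10.0]])
print(torch.nn.functional.softmax(scores, dim=-1))  # ~[0.333, 0.333, 0.333]
print(softmax_1_sketch(scores, dim=-1))             # ~[4.5e-5, 4.5e-5, 4.5e-5]
```

The extra 1 in the denominator lets a head assign near-zero total attention when no key is relevant, which is the behavior OPTOutEffHop opts into by setting self.attention = softmax_1, and which OptFlashAttention2 now inherits through its new base class.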