davda54 committed
Commit d25295b · verified · 1 Parent(s): 2697fb5

simplified softmax (to allow torch.compile)

Files changed (1)
  1. modeling_norbert.py +4 -25
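
Context on the change: the deleted MaskedSoftmax was a custom torch.autograd.Function with a hand-written backward built on softmax_backward_data, which got in the way of torch.compile tracing (the motivation stated in the commit message). The replacement expresses the same masked softmax with ordinary masked_fill and F.softmax calls inside Attention.forward. A minimal sketch of why the two are equivalent, with illustrative shapes (not the model's real dimensions) and assuming every row of the mask keeps at least one unmasked position:

import torch
import torch.nn.functional as F

# Illustrative shapes only; the real attention scores are [B*H, Q, K].
scores = torch.randn(2, 4, 8)
mask = torch.zeros(2, 4, 8, dtype=torch.bool)
mask[..., -2:] = True  # pretend the last two key positions are padding

# New path from this commit: ordinary ops that torch.compile can trace.
probs = F.softmax(scores.masked_fill(mask, float('-inf')), dim=-1)

# exp(-inf) == 0, so masked positions already come out as exactly zero and the
# old "fill with 0.0 after the softmax" step is redundant, provided every row
# keeps at least one unmasked position (a fully masked row would produce NaNs).
assert torch.all(probs.masked_select(mask) == 0.0)
assert torch.allclose(probs.sum(dim=-1), torch.ones(2, 4))

With only standard ops involved, autograd derives the backward pass automatically, so the hand-written backward and its saved tensors are no longer needed.
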
modeling_norbert.py CHANGED
@@ -101,23 +101,6 @@ class FeedForward(nn.Module):
         return self.mlp(x)
 
 
-class MaskedSoftmax(torch.autograd.Function):
-    @staticmethod
-    def forward(self, x, mask, dim):
-        self.dim = dim
-        x.masked_fill_(mask, float('-inf'))
-        x = torch.softmax(x, self.dim)
-        x.masked_fill_(mask, 0.0)
-        self.save_for_backward(x)
-        return x
-
-    @staticmethod
-    def backward(self, grad_output):
-        output, = self.saved_tensors
-        input_grad = softmax_backward_data(self, grad_output, output, self.dim, output)
-        return input_grad, None, None
-
-
 class Attention(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -155,7 +138,7 @@ class Attention(nn.Module):
         bucket_pos = torch.where(abs_pos <= mid, relative_pos, log_pos * sign).long()
         return bucket_pos
 
-    def compute_attention_scores(self, hidden_states, relative_embedding):
+    def forward(self, hidden_states, attention_mask, relative_embedding):
         key_len, batch_size, _ = hidden_states.size()
         query_len = key_len
 
@@ -193,21 +176,17 @@ class Attention(nn.Module):
         attention_scores.add_(attention_c_p)
         attention_scores.add_(attention_p_c)
 
-        return attention_scores, value
+        attention_scores = attention_scores.masked_fill(attention_mask, float('-inf'))
+        attention_probs = F.softmax(attention_scores, dim=-1)
 
-    def compute_output(self, attention_probs, value):
         attention_probs = self.dropout(attention_probs)
         context = torch.bmm(attention_probs.flatten(0, 1), value)  # shape: [B*H, Q, D]
         context = context.transpose(0, 1).reshape(context.size(1), -1, self.hidden_size)  # shape: [Q, B, H*D]
         context = self.out_proj(context)
         context = self.post_layer_norm(context)
         context = self.dropout(context)
-        return context
 
-    def forward(self, hidden_states, attention_mask, relative_embedding):
-        attention_scores, value = self.compute_attention_scores(hidden_states, relative_embedding)
-        attention_probs = MaskedSoftmax.apply(attention_scores, attention_mask, -1)
-        return self.compute_output(attention_probs, value), attention_probs.detach()
+        return context, attention_probs.detach()
 
 
 class Embedding(nn.Module):
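
With the custom Function removed, the attention block is built entirely from standard PyTorch ops and can be wrapped in torch.compile directly. A hedged usage sketch (requires PyTorch >= 2.0; masked_attention_probs below is an illustrative stand-in, not the actual NorBERT module):

import torch
import torch.nn.functional as F

# Hypothetical stand-in for the simplified masked softmax introduced in this
# commit; names and shapes are illustrative, not the exact modeling_norbert.py code.
def masked_attention_probs(scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    scores = scores.masked_fill(mask, float('-inf'))
    return F.softmax(scores, dim=-1)

compiled_probs = torch.compile(masked_attention_probs)

scores = torch.randn(2, 8, 16, 16)                 # [B, H, Q, K]
mask = torch.zeros(2, 1, 1, 16, dtype=torch.bool)  # broadcastable key-padding mask
mask[..., -4:] = True                              # mask the last four key positions

probs = compiled_probs(scores, mask)
print(probs.shape)  # torch.Size([2, 8, 16, 16])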