qfournier committed
Commit af2c1f1
Parent: 908a709

Update amplify.py

Files changed (1)
  1. amplify.py +15 -16
amplify.py CHANGED
@@ -124,13 +124,13 @@ class EncoderBlock(nn.Module):
 
         self.ffn_dropout = nn.Dropout(config.dropout_prob)
 
-    def forward(self, x: torch.Tensor, pad_mask: torch.Tensor, freqs_cis: torch.Tensor, output_attentions: bool):
-        attn, contact = self._att_block(self.attention_norm(x), pad_mask, freqs_cis, output_attentions)
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor, freqs_cis: torch.Tensor, output_attentions: bool):
+        attn, contact = self._att_block(self.attention_norm(x), attention_mask, freqs_cis, output_attentions)
         x = x + attn
         x = x + self._ff_block(self.ffn_norm(x))
         return x, contact
 
-    def _att_block(self, x: torch.Tensor, pad_mask: torch.Tensor, freqs_cis: torch.Tensor, output_attentions: bool):
+    def _att_block(self, x: torch.Tensor, attention_mask: torch.Tensor, freqs_cis: torch.Tensor, output_attentions: bool):
         batch_size, seq_len, _ = x.shape
         xq, xk, xv = self.q(x), self.k(x), self.v(x)
 
@@ -144,15 +144,15 @@ class EncoderBlock(nn.Module):
             query=xq,
             key=xk,
             value=xv,
-            attn_bias=pad_mask,
+            attn_bias=attention_mask,
             p=self.config.dropout_prob if self.training else 0,
         )
 
         _attn = None
         if output_attentions:
             _attn = xq.permute(0, 2, 1, 3) @ xk.permute(0, 2, 3, 1) / (xq.size(-1) ** 0.5)
-            if pad_mask is not None:
-                _attn = _attn + pad_mask
+            if attention_mask is not None:
+                _attn = _attn + attention_mask
             _attn = _attn.softmax(-1)
 
         return self.resid_dropout(self.wo(attn.view(batch_size, seq_len, self.config.num_attention_heads * self.d_head))), _attn
@@ -203,28 +203,28 @@ class AMPLIFY(AMPLIFYPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
-    def forward(self, src, pad_mask=None, output_hidden_states=False, output_attentions=False):
+    def forward(self, input_ids, attention_mask=None, output_hidden_states=False, output_attentions=False, **kwargs):
         # Initialize
         hidden_states, attentions = [], []
 
         # Expand and repeat: (Batch, Length) -> (Batch, Heads, Length, Length)
-        if pad_mask is not None and not torch.all(pad_mask == 0):
-            pad_mask = pad_mask.unsqueeze(1).unsqueeze(1).repeat(1, self.config.num_attention_heads, pad_mask.size(-1), 1)
+        if attention_mask is not None and not torch.all(attention_mask == 0):
+            attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).repeat(1, self.config.num_attention_heads, attention_mask.size(-1), 1)
         else:
-            pad_mask = None
+            attention_mask = None
 
         # RoPE
-        self.freqs_cis = self.freqs_cis.to(src.device, non_blocking=True)
-        freqs_cis = self.freqs_cis[: src.shape[1]]
+        self.freqs_cis = self.freqs_cis.to(input_ids.device, non_blocking=True)
+        freqs_cis = self.freqs_cis[: input_ids.shape[1]]
 
         # Embedding
-        x = self.encoder(src)
+        x = self.encoder(input_ids)
         if self.config.layer_norm_after_embedding:
             x = self.layer_norm_1(x)
 
         # Transformer encoder
         for layer in self.transformer_encoder:
-            x, attn = layer(x, pad_mask, freqs_cis, output_attentions)
+            x, attn = layer(x, attention_mask, freqs_cis, output_attentions)
             if output_hidden_states:
                 hidden_states.append(x)
             if output_attentions:
@@ -234,5 +234,4 @@ class AMPLIFY(AMPLIFYPreTrainedModel):
         logits = self.decoder(self.layer_norm_2(x) if self.config.layer_norm_before_last_layer else x)
 
         # Return logits or the output of the last hidden layer
-        return MaskedLMOutput(logits=logits, hidden_states=hidden_states, attentions=attentions)
-
+        return MaskedLMOutput(logits=logits, hidden_states=hidden_states, attentions=attentions)
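The renamed arguments (input_ids, attention_mask) and the added **kwargs match the keys a Hugging Face tokenizer returns, so the tokenizer output can be unpacked straight into forward(). A minimal usage sketch under that assumption; the checkpoint name and input sequence below are illustrative, not taken from this commit:

import torch
from transformers import AutoModel, AutoTokenizer

name = "chandar-lab/AMPLIFY_120M"  # illustrative checkpoint name
tokenizer = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
model = AutoModel.from_pretrained(name, trust_remote_code=True).eval()

# The tokenizer already emits input_ids and attention_mask; any extra keys
# are absorbed by the **kwargs added to forward() in this commit.
inputs = tokenizer("MSVKKPLTQ", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs, output_hidden_states=True)
print(out.logits.shape)  # (batch, length, vocab_size)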
 
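For reference, a standalone sketch of the mask expansion in AMPLIFY.forward() above: a per-sequence additive mask of shape (Batch, Length) is expanded to (Batch, Heads, Length, Length) before being handed to the attention blocks as attn_bias (the addition to the attention logits in _att_block shows the mask is additive). Sizes and values below are illustrative:

import torch

batch, heads, length = 2, 4, 6     # illustrative sizes
mask = torch.zeros(batch, length)  # 0 keeps a position
mask[:, -2:] = float("-inf")       # large negative bias on padded positions

# (Batch, Length) -> (Batch, 1, 1, Length) -> (Batch, Heads, Length, Length)
expanded = mask.unsqueeze(1).unsqueeze(1).repeat(1, heads, mask.size(-1), 1)
print(expanded.shape)  # torch.Size([2, 4, 6, 6])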