Update amplify.py
amplify.py CHANGED (+15 -16)
@@ -124,13 +124,13 @@ class EncoderBlock(nn.Module):

         self.ffn_dropout = nn.Dropout(config.dropout_prob)

-    def forward(self, x: torch.Tensor,
-        attn, contact = self._att_block(self.attention_norm(x),
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor, freqs_cis: torch.Tensor, output_attentions: bool):
+        attn, contact = self._att_block(self.attention_norm(x), attention_mask, freqs_cis, output_attentions)
         x = x + attn
         x = x + self._ff_block(self.ffn_norm(x))
         return x, contact

-    def _att_block(self, x: torch.Tensor,
+    def _att_block(self, x: torch.Tensor, attention_mask: torch.Tensor, freqs_cis: torch.Tensor, output_attentions: bool):
         batch_size, seq_len, _ = x.shape
         xq, xk, xv = self.q(x), self.k(x), self.v(x)

@@ -144,15 +144,15 @@ class EncoderBlock(nn.Module):
             query=xq,
             key=xk,
             value=xv,
-            attn_bias=
+            attn_bias=attention_mask,
             p=self.config.dropout_prob if self.training else 0,
         )

         _attn = None
         if output_attentions:
             _attn = xq.permute(0, 2, 1, 3) @ xk.permute(0, 2, 3, 1) / (xq.size(-1) ** 0.5)
-            if
-                _attn = _attn +
+            if attention_mask is not None:
+                _attn = _attn + attention_mask
             _attn = _attn.softmax(-1)

         return self.resid_dropout(self.wo(attn.view(batch_size, seq_len, self.config.num_attention_heads * self.d_head))), _attn

@@ -203,28 +203,28 @@ class AMPLIFY(AMPLIFYPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()

-    def forward(self,
+    def forward(self, input_ids, attention_mask=None, output_hidden_states=False, output_attentions=False, **kwargs):
         # Initialize
         hidden_states, attentions = [], []

         # Expand and repeat: (Batch, Length) -> (Batch, Heads, Length, Length)
-        if
-
+        if attention_mask is not None and not torch.all(attention_mask == 0):
+            attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).repeat(1, self.config.num_attention_heads, attention_mask.size(-1), 1)
         else:
-
+            attention_mask = None

         # RoPE
-        self.freqs_cis = self.freqs_cis.to(
-        freqs_cis = self.freqs_cis[:
+        self.freqs_cis = self.freqs_cis.to(input_ids.device, non_blocking=True)
+        freqs_cis = self.freqs_cis[: input_ids.shape[1]]

         # Embedding
-        x = self.encoder(
+        x = self.encoder(input_ids)
         if self.config.layer_norm_after_embedding:
             x = self.layer_norm_1(x)

         # Transformer encoder
         for layer in self.transformer_encoder:
-            x, attn = layer(x,
+            x, attn = layer(x, attention_mask, freqs_cis, output_attentions)
             if output_hidden_states:
                 hidden_states.append(x)
             if output_attentions:

@@ -234,5 +234,4 @@ class AMPLIFY(AMPLIFYPreTrainedModel):
         logits = self.decoder(self.layer_norm_2(x) if self.config.layer_norm_before_last_layer else x)

         # Return logits or the output of the last hidden layer
-        return MaskedLMOutput(logits=logits, hidden_states=hidden_states, attentions=attentions)
-
+        return MaskedLMOutput(logits=logits, hidden_states=hidden_states, attentions=attentions)
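Reading the added lines, `attention_mask` is consumed as an additive bias rather than a 0/1 keep-mask: `AMPLIFY.forward` expands the (Batch, Length) tensor to (Batch, Heads, Length, Length), `_att_block` passes it to the fused attention as `attn_bias`, and the optional attention-weight path adds it to the raw scores before the softmax. If that reading is right, visible positions should carry 0.0 and padded positions a large negative value. A minimal sketch of how a caller might build such a mask from a 1/0 padding mask; the helper name is mine, not part of amplify.py:

import torch

def to_additive_mask(pad_mask: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    # Hypothetical helper (not in amplify.py): 1 marks a real token, 0 marks padding.
    # Visible positions get 0.0, padded positions get the most negative finite value,
    # so adding the mask to the attention scores drives padded keys to ~0 after softmax.
    return torch.where(
        pad_mask.bool(),
        torch.zeros((), dtype=dtype),
        torch.full((), torch.finfo(dtype).min, dtype=dtype),
    )

# Two sequences of length 5; the second has two padding positions.
pad_mask = torch.tensor([[1, 1, 1, 1, 1],
                         [1, 1, 1, 0, 0]])
attention_mask = to_additive_mask(pad_mask)  # shape (2, 5), values 0.0 / finfo.min

Note that an all-zero additive mask (no padding anywhere) is treated as "no mask": the updated forward checks `not torch.all(attention_mask == 0)` and otherwise falls back to `attention_mask = None`, skipping the expansion entirely.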
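For completeness, a hedged sketch of a call against the new `AMPLIFY.forward` signature. `model` and `input_ids` are placeholders for an already-loaded AMPLIFY instance and a tokenized batch; only the keyword names and the `MaskedLMOutput` fields come from the diff:

# Assumes `model` is an AMPLIFY instance and `input_ids` a (Batch, Length)
# LongTensor of token ids; `attention_mask` is the additive mask built above.
out = model(
    input_ids=input_ids,
    attention_mask=attention_mask,   # or None when nothing is padded
    output_hidden_states=True,
    output_attentions=True,
)
logits = out.logits                  # output of self.decoder on the (optionally normed) last hidden state
hidden_states = out.hidden_states    # plain list, one tensor per encoder layer (per the diff)
attentions = out.attentions          # per-layer attention maps, populated since output_attentions=True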