Transformers
PyTorch
code
custom_code
Inference Endpoints
codesage committed on
Commit
15b098e
·
verified ·
1 Parent(s): 9d0025f

Update modeling_codesage.py

Browse files
Files changed (1) hide show
  1. modeling_codesage.py +67 -1
modeling_codesage.py CHANGED
@@ -11,7 +11,11 @@ from transformers.activations import ACT2FN
11
  from transformers.modeling_utils import Conv1D, PreTrainedModel
12
  from transformers.utils import logging
13
  from .config_codesage import CodeSageConfig
14
- from transformers.modeling_outputs import BaseModelOutputWithPooling, SequenceClassifierOutput
 
 
 
 
15
 
16
  logger = logging.get_logger(__name__)
17
 
@@ -151,6 +155,7 @@ class CodeSageBlock(nn.Module):
151
 
152
  class CodeSagePreTrainedModel(PreTrainedModel):
153
  config_class = CodeSageConfig
 
154
 
155
  def _init_weights(self, module):
156
  """Initialize the weights."""
@@ -277,7 +282,68 @@ class CodeSageModel(CodeSagePreTrainedModel):
277
  )
278
 
279
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
  class CodeSageForSequenceClassification(CodeSagePreTrainedModel):
 
281
  def __init__(self, config):
282
  super().__init__(config)
283
  self.num_labels = config.num_labels
 
11
  from transformers.modeling_utils import Conv1D, PreTrainedModel
12
  from transformers.utils import logging
13
  from .config_codesage import CodeSageConfig
14
+ from transformers.modeling_outputs import (
15
+ BaseModelOutputWithPooling,
16
+ MaskedLMOutput,
17
+ SequenceClassifierOutput
18
+ )
19
 
20
  logger = logging.get_logger(__name__)
21
 
 
155
 
156
  class CodeSagePreTrainedModel(PreTrainedModel):
157
  config_class = CodeSageConfig
158
+ base_model_prefix = "transformer"
159
 
160
  def _init_weights(self, module):
161
  """Initialize the weights."""
 
282
  )
283
 
284
 
285
class CodeSageForMaskedLM(CodeSagePreTrainedModel):
    """CodeSage encoder followed by a bias-free linear language-modeling head.

    The head maps each position's hidden state (``config.hidden_size``) to
    ``config.vocab_size`` logits. ``lm_head.weight`` is declared in
    ``_tied_weights_keys``.
    """

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        """Build the backbone and the LM head, then initialize weights."""
        super().__init__(config)
        self.transformer = CodeSageModel(config)
        self.lm_head = nn.Linear(
            config.hidden_size,
            config.vocab_size,
            bias=False,
        )

        self.init_weights()

    def get_output_embeddings(self):
        """Return the LM head module."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head module with ``new_embeddings``."""
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None
    ):
        """Run the backbone, project to vocabulary logits, optionally compute a loss.

        All arguments except ``labels`` are forwarded unchanged to
        ``self.transformer``. When ``labels`` is given, the loss is
        ``CrossEntropyLoss()`` (default arguments) over the logits flattened
        to ``(-1, vocab)`` and the labels flattened to ``(-1,)``.

        Returns a ``MaskedLMOutput`` when ``return_dict`` resolves to true
        (falling back to ``config.use_return_dict`` when it is ``None``);
        otherwise a tuple ``(loss?, logits, *backbone_outputs[1:])`` where the
        loss is present only if it was computed.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )
        # The first backbone output is used as the per-token hidden states.
        sequence_output = encoder_outputs[0]
        prediction_scores = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            vocab_dim = prediction_scores.size(-1)
            criterion = CrossEntropyLoss()
            loss = criterion(prediction_scores.view(-1, vocab_dim), labels.view(-1))

        if return_dict:
            return MaskedLMOutput(
                loss=loss,
                logits=prediction_scores,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        # NOTE(review): everything after the first backbone output is forwarded
        # as-is; if the backbone tuple carries a pooled output at index 1 it
        # ends up in this tuple too -- confirm against CodeSageModel.forward.
        tuple_output = (prediction_scores,) + encoder_outputs[1:]
        if loss is None:
            return tuple_output
        return (loss,) + tuple_output
343
+
344
+
345
  class CodeSageForSequenceClassification(CodeSagePreTrainedModel):
346
+
347
  def __init__(self, config):
348
  super().__init__(config)
349
  self.num_labels = config.num_labels