emanuelaboros commited on
Commit
bdff5a0
·
1 Parent(s): 5fa0758
Files changed (1) hide show
  1. modeling_lang.py +9 -9
modeling_lang.py CHANGED
@@ -58,15 +58,15 @@ class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
58
  def get_floret_model(self):
59
  return self.model_floret
60
 
61
- def get_extended_attention_mask(
62
- self, attention_mask, input_shape, device=None, dtype=torch.float
63
- ):
64
- if attention_mask is None:
65
- attention_mask = torch.ones(input_shape, device=device)
66
- extended_attention_mask = attention_mask[:, None, None, :]
67
- extended_attention_mask = extended_attention_mask.to(dtype=dtype)
68
- extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
69
- return extended_attention_mask
70
 
71
  @property
72
  def device(self):
 
58
  def get_floret_model(self):
59
  return self.model_floret
60
 
61
+ # def get_extended_attention_mask(
62
+ # self, attention_mask, input_shape, device=None, dtype=torch.float
63
+ # ):
64
+ # if attention_mask is None:
65
+ # attention_mask = torch.ones(input_shape, device=device)
66
+ # extended_attention_mask = attention_mask[:, None, None, :]
67
+ # extended_attention_mask = extended_attention_mask.to(dtype=dtype)
68
+ # extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
69
+ # return extended_attention_mask
70
 
71
  @property
72
  def device(self):