dragonSwing commited on
Commit
07eef75
·
1 Parent(s): dd9b3ed

Fix error_prob bug

Browse files
Files changed (1) hide show
  1. gec_model.py +5 -1
gec_model.py CHANGED
@@ -89,6 +89,7 @@ class GecBERTModel(torch.nn.Module):
89
  self.lowercase_tokens = lowercase_tokens
90
  self.min_error_probability = min_error_probability
91
  self.vocab = Vocabulary.from_files(vocab_path)
 
92
  self.log = log
93
  self.iterations = iterations
94
  self.confidence = confidence
@@ -337,7 +338,10 @@ class GecBERTModel(torch.nn.Module):
337
  for output, weight in zip(data, self.model_weights):
338
  class_probabilities_labels = torch.softmax(output['logits'], dim=-1)
339
  all_class_probs += weight * class_probabilities_labels / sum(self.model_weights)
340
- error_probs += weight * output['max_error_probability'] / sum(self.model_weights)
 
 
 
341
 
342
  max_vals = torch.max(all_class_probs, dim=-1)
343
  probs = max_vals[0].tolist()
 
89
  self.lowercase_tokens = lowercase_tokens
90
  self.min_error_probability = min_error_probability
91
  self.vocab = Vocabulary.from_files(vocab_path)
92
+ self.incorr_index = self.vocab.get_token_index("INCORRECT", "d_tags")
93
  self.log = log
94
  self.iterations = iterations
95
  self.confidence = confidence
 
338
  for output, weight in zip(data, self.model_weights):
339
  class_probabilities_labels = torch.softmax(output['logits'], dim=-1)
340
  all_class_probs += weight * class_probabilities_labels / sum(self.model_weights)
341
+ class_probabilities_d = torch.softmax(output['detect_logits'], dim=-1)
342
+ error_probs_d = class_probabilities_d[:, :, self.incorr_index]
343
+ incorr_prob = torch.max(error_probs_d, dim=-1)[0]
344
+ error_probs += weight * incorr_prob / sum(self.model_weights)
345
 
346
  max_vals = torch.max(all_class_probs, dim=-1)
347
  probs = max_vals[0].tolist()