sonsus commited on
Commit
6c4d03f
1 Parent(s): 82e28d8

Update harim_plus.py

Browse files
Files changed (1) hide show
  1. harim_plus.py +3 -3
harim_plus.py CHANGED
@@ -150,7 +150,7 @@ class Harimplus_Scorer:
150
  idx=0
151
  minibatches = []
152
  while True:
153
- start = idx
154
  end = idx+bsz
155
  if start >= len(exs):
156
  break
@@ -232,10 +232,10 @@ class Harimplus_Scorer:
232
  labels = tgt_in.input_ids.masked_fill(fill_ignore_mask, -100),
233
  return_dict=True).logits
234
  sent_lengths = tgt_mask.sum(-1)
235
- ll_tok = self.log_likelihoods(s2s_logits, tgt_in.input_ids, tgt_mask)
236
  ll = ll_tok.sum(-1) / sent_lengths
237
 
238
- harim_tok = self.harim(s2s_logits, lm_logits, tgt_in.input_ids, tgt_mask)
239
  harim = harim_tok.sum(-1) / sent_lengths
240
 
241
  harim_plus_normalized = (ll + self._lambda * harim) # loglikelihood + lambda * negative_harim (negative harim=-1* risk)
 
150
  idx=0
151
  minibatches = []
152
  while True:
153
+ start = id
154
  end = idx+bsz
155
  if start >= len(exs):
156
  break
 
232
  labels = tgt_in.input_ids.masked_fill(fill_ignore_mask, -100),
233
  return_dict=True).logits
234
  sent_lengths = tgt_mask.sum(-1)
235
+ ll_tok = self.log_likelihoods(s2s_logits, tgt_in.input_ids, 1)#tgt_mask)
236
  ll = ll_tok.sum(-1) / sent_lengths
237
 
238
+ harim_tok = self.harim(s2s_logits, lm_logits, tgt_in.input_ids, 1)#tgt_mask)
239
  harim = harim_tok.sum(-1) / sent_lengths
240
 
241
  harim_plus_normalized = (ll + self._lambda * harim) # loglikelihood + lambda * negative_harim (negative harim=-1* risk)