Spaces:
Running
Running
Update harim_plus.py
Browse files- harim_plus.py +3 -3
harim_plus.py
CHANGED
@@ -150,7 +150,7 @@ class Harimplus_Scorer:
|
|
150 |
idx=0
|
151 |
minibatches = []
|
152 |
while True:
|
153 |
-
start =
|
154 |
end = idx+bsz
|
155 |
if start >= len(exs):
|
156 |
break
|
@@ -232,10 +232,10 @@ class Harimplus_Scorer:
|
|
232 |
labels = tgt_in.input_ids.masked_fill(fill_ignore_mask, -100),
|
233 |
return_dict=True).logits
|
234 |
sent_lengths = tgt_mask.sum(-1)
|
235 |
-
ll_tok = self.log_likelihoods(s2s_logits, tgt_in.input_ids, tgt_mask)
|
236 |
ll = ll_tok.sum(-1) / sent_lengths
|
237 |
|
238 |
-
harim_tok = self.harim(s2s_logits, lm_logits, tgt_in.input_ids, tgt_mask)
|
239 |
harim = harim_tok.sum(-1) / sent_lengths
|
240 |
|
241 |
harim_plus_normalized = (ll + self._lambda * harim) # loglikelihood + lambda * negative_harim (negative harim=-1* risk)
|
|
|
150 |
idx=0
|
151 |
minibatches = []
|
152 |
while True:
|
153 |
+
start = id
|
154 |
end = idx+bsz
|
155 |
if start >= len(exs):
|
156 |
break
|
|
|
232 |
labels = tgt_in.input_ids.masked_fill(fill_ignore_mask, -100),
|
233 |
return_dict=True).logits
|
234 |
sent_lengths = tgt_mask.sum(-1)
|
235 |
+
ll_tok = self.log_likelihoods(s2s_logits, tgt_in.input_ids, 1)#tgt_mask)
|
236 |
ll = ll_tok.sum(-1) / sent_lengths
|
237 |
|
238 |
+
harim_tok = self.harim(s2s_logits, lm_logits, tgt_in.input_ids, 1)#tgt_mask)
|
239 |
harim = harim_tok.sum(-1) / sent_lengths
|
240 |
|
241 |
harim_plus_normalized = (ll + self._lambda * harim) # loglikelihood + lambda * negative_harim (negative harim=-1* risk)
|