Spaces:
Running
Running
Update harim_plus.py
Browse files- harim_plus.py +9 -8
harim_plus.py
CHANGED
@@ -171,7 +171,7 @@ class Harimplus_Scorer:
|
|
171 |
bsz:int=32,
|
172 |
use_aggregator:bool=False,
|
173 |
return_details:bool=False,
|
174 |
-
tokenwise_score:bool=False,
|
175 |
):
|
176 |
'''
|
177 |
returns harim+ score (List[float]) for predictions (summaries) and references (articles)
|
@@ -238,15 +238,15 @@ class Harimplus_Scorer:
|
|
238 |
harim_tok = self.harim(s2s_logits, lm_logits, tgt_in.input_ids, tgt_mask)
|
239 |
harim = harim_tok.sum(-1) / sent_lengths
|
240 |
|
241 |
-
harim_plus_normalized = ll + self._lambda * harim # loglikelihood + lambda * negative_harim (negative harim=-1* risk)
|
242 |
|
243 |
scores['harim+'].extend(harim_plus_normalized.tolist())
|
244 |
scores['harim'].extend(harim.tolist())
|
245 |
scores['log_ppl'].extend(ll.tolist())
|
246 |
|
247 |
-
if tokenwise_score:
|
248 |
-
|
249 |
-
|
250 |
|
251 |
if use_aggregator: # after
|
252 |
for k, v in scores.items():
|
@@ -314,13 +314,14 @@ class Harimplus(evaluate.Metric):
|
|
314 |
references=None,
|
315 |
use_aggregator=False,
|
316 |
bsz=32,
|
317 |
-
|
318 |
-
|
|
|
319 |
summaries = predictions
|
320 |
articles = references
|
321 |
scores = self.scorer.compute(predictions=summaries,
|
322 |
references=articles,
|
323 |
use_aggregator=use_aggregator,
|
324 |
-
bsz=bsz, tokenwise_score=tokenwise_score,
|
325 |
return_details=return_details)
|
326 |
return scores
|
|
|
171 |
bsz:int=32,
|
172 |
use_aggregator:bool=False,
|
173 |
return_details:bool=False,
|
174 |
+
# tokenwise_score:bool=False,
|
175 |
):
|
176 |
'''
|
177 |
returns harim+ score (List[float]) for predictions (summaries) and references (articles)
|
|
|
238 |
harim_tok = self.harim(s2s_logits, lm_logits, tgt_in.input_ids, tgt_mask)
|
239 |
harim = harim_tok.sum(-1) / sent_lengths
|
240 |
|
241 |
+
harim_plus_normalized = (ll + self._lambda * harim)/sent_lengths # loglikelihood + lambda * negative_harim (negative harim=-1* risk)
|
242 |
|
243 |
scores['harim+'].extend(harim_plus_normalized.tolist())
|
244 |
scores['harim'].extend(harim.tolist())
|
245 |
scores['log_ppl'].extend(ll.tolist())
|
246 |
|
247 |
+
# if tokenwise_score:
|
248 |
+
# scores['tok_harim+'].append(harim_tok*self._lambda + ll_tok)
|
249 |
+
# scores['tok_predictions'].append( [self._tokenizer.convert_ids_to_token(idxs) for idxs in src_in.labels] )
|
250 |
|
251 |
if use_aggregator: # after
|
252 |
for k, v in scores.items():
|
|
|
314 |
references=None,
|
315 |
use_aggregator=False,
|
316 |
bsz=32,
|
317 |
+
return_details=False):
|
318 |
+
# tokenwise_score=False,
|
319 |
+
|
320 |
summaries = predictions
|
321 |
articles = references
|
322 |
scores = self.scorer.compute(predictions=summaries,
|
323 |
references=articles,
|
324 |
use_aggregator=use_aggregator,
|
325 |
+
bsz=bsz, #tokenwise_score=tokenwise_score,
|
326 |
return_details=return_details)
|
327 |
return scores
|