Upload metrics.py with huggingface_hub
Browse files- metrics.py +18 -3
metrics.py
CHANGED
@@ -220,13 +220,14 @@ class HuggingfaceMetric(GlobalMetric):
|
|
220 |
metric_name: str = None
|
221 |
main_score: str = None
|
222 |
scale: float = 1.0
|
|
|
223 |
|
224 |
def prepare(self):
|
225 |
super().prepare()
|
226 |
self.metric = evaluate.load(self.metric_name)
|
227 |
|
228 |
def compute(self, references: List[List[str]], predictions: List[str]) -> dict:
|
229 |
-
result = self.metric.compute(predictions=predictions, references=references)
|
230 |
if self.scale != 1.0:
|
231 |
for key in result:
|
232 |
if isinstance(result[key], float):
|
@@ -373,7 +374,14 @@ class Rouge(HuggingfaceMetric):
|
|
373 |
main_score = "rougeL"
|
374 |
scale = 1.0
|
375 |
|
|
|
|
|
|
|
|
|
|
|
376 |
def prepare(self):
|
|
|
|
|
377 |
super().prepare()
|
378 |
import nltk
|
379 |
|
@@ -381,8 +389,9 @@ class Rouge(HuggingfaceMetric):
|
|
381 |
self.sent_tokenize = nltk.sent_tokenize
|
382 |
|
383 |
def compute(self, references, predictions):
|
384 |
-
|
385 |
-
|
|
|
386 |
return super().compute(references, predictions)
|
387 |
|
388 |
|
@@ -429,6 +438,12 @@ class Bleu(HuggingfaceMetric):
|
|
429 |
scale = 1.0
|
430 |
|
431 |
|
|
|
|
|
|
|
|
|
|
|
|
|
432 |
class MatthewsCorrelation(HuggingfaceMetric):
|
433 |
metric_name = "matthews_correlation"
|
434 |
main_score = "matthews_correlation"
|
|
|
220 |
metric_name: str = None
|
221 |
main_score: str = None
|
222 |
scale: float = 1.0
|
223 |
+
hf_compute_args: dict = {}
|
224 |
|
225 |
def prepare(self):
|
226 |
super().prepare()
|
227 |
self.metric = evaluate.load(self.metric_name)
|
228 |
|
229 |
def compute(self, references: List[List[str]], predictions: List[str]) -> dict:
|
230 |
+
result = self.metric.compute(predictions=predictions, references=references, **self.hf_compute_args)
|
231 |
if self.scale != 1.0:
|
232 |
for key in result:
|
233 |
if isinstance(result[key], float):
|
|
|
374 |
main_score = "rougeL"
|
375 |
scale = 1.0
|
376 |
|
377 |
+
use_aggregator: bool = True
|
378 |
+
rouge_types: List[str] = ["rouge1", "rouge2", "rougeL", "rougeLsum"]
|
379 |
+
|
380 |
+
sent_split_newline: bool = True
|
381 |
+
|
382 |
def prepare(self):
|
383 |
+
self.hf_compute_args = {"use_aggregator": self.use_aggregator, "rouge_types": self.rouge_types}
|
384 |
+
|
385 |
super().prepare()
|
386 |
import nltk
|
387 |
|
|
|
389 |
self.sent_tokenize = nltk.sent_tokenize
|
390 |
|
391 |
def compute(self, references, predictions):
|
392 |
+
if self.sent_split_newline:
|
393 |
+
predictions = ["\n".join(self.sent_tokenize(prediction.strip())) for prediction in predictions]
|
394 |
+
references = [["\n".join(self.sent_tokenize(r.strip())) for r in reference] for reference in references]
|
395 |
return super().compute(references, predictions)
|
396 |
|
397 |
|
|
|
438 |
scale = 1.0
|
439 |
|
440 |
|
441 |
+
class SacreBleu(HuggingfaceMetric):
|
442 |
+
metric_name = "sacrebleu"
|
443 |
+
main_score = "score"
|
444 |
+
scale = 1.0
|
445 |
+
|
446 |
+
|
447 |
class MatthewsCorrelation(HuggingfaceMetric):
|
448 |
metric_name = "matthews_correlation"
|
449 |
main_score = "matthews_correlation"
|