Upload metrics.py with huggingface_hub

metrics.py  CHANGED  (+41 / -11)
@@ -7,6 +7,7 @@ from typing import Any, Dict, Generator, List, Optional
 import evaluate
 import nltk
 import numpy
+from editdistance import eval
 
 from .dataclass import InternalField
 from .operator import (
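The new import pulls in eval from the editdistance package (a C-backed Levenshtein implementation); note that at module scope it shadows Python's builtin eval. A quick sanity check of the function, assuming the editdistance package from PyPI:

import editdistance

# Levenshtein distance: "kitten" -> "sitting" takes 3 edits (2 substitutions, 1 insertion)
assert editdistance.eval("kitten", "sitting") == 3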
@@ -21,12 +22,12 @@ from .stream import MultiStream, Stream
 nltk.download("punkt")
 
 
-def
+def abstract_factory():
     return {}
 
 
 def abstract_field():
-    return field(default_factory=
+    return field(default_factory=abstract_factory)
 
 
 class UpdateStream(StreamInstanceOperator):
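The fix matters because dataclass fields must not share a mutable default; field(default_factory=...) expects a zero-argument callable that builds a fresh value per instance, which is exactly what abstract_factory provides. A minimal sketch of the pattern (the names below are illustrative, not taken from metrics.py):

from dataclasses import dataclass, field

def empty_dict():
    return {}

@dataclass
class Record:
    data: dict = field(default_factory=empty_dict)  # new dict for every instance

a, b = Record(), Record()
a.data["x"] = 1
assert b.data == {}  # instances do not share the default dict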
@@ -253,7 +254,7 @@ class F1(GlobalMetric):
     def compute(self, references: List[List[str]], predictions: List[str]) -> dict:
         assert all(
             len(reference) == 1 for reference in references
-        ), "
+        ), "Only a single reference per prediction is allowed in F1 metric"
         self.str_to_id = {}
         self.id_to_str = {}
         formatted_references = [self.get_str_id(reference[0]) for reference in references]
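The new assertion message spells out the expected shape: references is a list of single-element lists, one per prediction, and each label string is mapped to an integer id before scoring. Assuming _metric here wraps Hugging Face evaluate's "f1" (suggested by the compute(..., average=...) call elsewhere in the class, but not shown in this diff), the downstream call looks roughly like:

import evaluate

f1 = evaluate.load("f1")
# references such as [["cat"], ["dog"], ["cat"]] unwrap and map to ids 0/1
result = f1.compute(references=[0, 1, 0], predictions=[0, 0, 0], average="macro")
print(result)  # {'f1': ...}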
@@ -287,7 +288,6 @@ class F1MultiLabel(GlobalMetric):
     _metric = None
     main_score = "f1_macro"
     average = None  # Report per class then aggregate by mean
-    seperator = ","
 
     def prepare(self):
         super(F1MultiLabel, self).prepare()
@@ -310,17 +310,15 @@ class F1MultiLabel(GlobalMetric):
     def compute(self, references: List[List[str]], predictions: List[str]) -> dict:
         self.str_to_id = {}
         self.id_to_str = {}
+        assert all(
+            len(reference) == 1 for reference in references
+        ), "Only a single reference per prediction is allowed in F1 metric"
+        references = [reference[0] for reference in references]
         labels = list(set([label for reference in references for label in reference]))
         for label in labels:
-            assert (
-                not self.seperator in label
-            ), "Reference label (f{label}) can not contain multi label seperator (f{self.seperator}) "
             self.add_str_to_id(label)
         formatted_references = [self.get_one_hot_vector(reference) for reference in references]
-        split_predictions = [
-            [label.strip() for label in prediction.split(self.seperator)] for prediction in predictions
-        ]
-        formatted_predictions = [self.get_one_hot_vector(prediction) for prediction in split_predictions]
+        formatted_predictions = [self.get_one_hot_vector(prediction) for prediction in predictions]
         result = self._metric.compute(
             predictions=formatted_predictions, references=formatted_references, average=self.average
         )
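Together with the seperator attribute removed in the previous hunk, this changes the input contract of F1MultiLabel: predictions now arrive as lists of label strings (the same shape as each unwrapped reference) rather than separator-joined strings, so the splitting logic disappears and both sides are one-hot encoded before scoring. A standalone sketch of that one-hot multilabel F1, using scikit-learn for illustration rather than the class's wrapped _metric:

from sklearn.metrics import f1_score

label_space = ["cats", "dogs", "fish"]

def one_hot(labels):
    return [int(name in labels) for name in label_space]

references = [["cats", "dogs"], ["fish"]]   # already lists of labels
predictions = [["cats"], ["dogs", "fish"]]  # no comma-splitting anymore
y_true = [one_hot(r) for r in references]
y_pred = [one_hot(p) for p in predictions]
print(f1_score(y_true, y_pred, average="macro"))  # mean of per-class F1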
@@ -356,6 +354,38 @@ class Rouge(HuggingfaceMetric):
         return super().compute(references, predictions)
 
 
+# Computes char edit distance, ignoring repeating whitespace
+class CharEditDistanceAccuracy(SingleReferenceInstanceMetric):
+    reduction_map = {"mean": ["char_edit_dist_accuracy"]}
+    main_score = "char_edit_dist_accuracy"
+
+    def compute(self, reference, prediction: str) -> dict:
+        formatted_prediction = " ".join(prediction.split())
+        formatted_reference = " ".join(reference.split())
+        max_length = max(len(formatted_reference), len(formatted_prediction))
+        if max_length == 0:
+            return {"char_edit_dist_accuracy": 0}
+        edit_dist = eval(formatted_reference, formatted_prediction)
+        return {"char_edit_dist_accuracy": (1 - edit_dist / max_length)}
+
+
+class Wer(HuggingfaceMetric):
+    metric_name = "wer"
+    main_score = "wer"
+
+    def prepare(self):
+        super().prepare()
+        self.metric = evaluate.load(self.metric_name)
+
+    def compute(self, references: List[List[str]], predictions: List[str]) -> dict:
+        assert all(
+            len(reference) == 1 for reference in references
+        ), "Only a single reference per prediction is allowed in wer metric"
+        formatted_references = [reference[0] for reference in references]
+        result = self.metric.compute(predictions=predictions, references=formatted_references)
+        return {self.main_score: result}
+
+
 class Bleu(HuggingfaceMetric):
     metric_name = "bleu"
     main_score = "bleu"
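Two new metrics land in this hunk. CharEditDistanceAccuracy scores 1 minus the Levenshtein distance normalized by the longer string, after collapsing repeated whitespace on both sides; Wer wraps evaluate's "wer", which returns a bare float, hence the manual wrapping into a dict keyed by main_score. A usage sketch of both computations, assuming the editdistance and evaluate packages from PyPI:

import editdistance
import evaluate

reference, prediction = "the  quick brown fox", "the quick brown fax"
ref = " ".join(reference.split())   # collapse repeated whitespace
pred = " ".join(prediction.split())
dist = editdistance.eval(ref, pred)         # 1 substitution
print(1 - dist / max(len(ref), len(pred)))  # ~0.947 char_edit_dist_accuracy

wer = evaluate.load("wer")
print(wer.compute(predictions=[prediction], references=[reference]))  # 0.25 (1 of 4 words wrong)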