Elron committed
Commit cc0572c · 1 Parent(s): da1b3a8

Upload metrics.py with huggingface_hub

Files changed (1): metrics.py +41 -11
metrics.py CHANGED
@@ -7,6 +7,7 @@ from typing import Any, Dict, Generator, List, Optional
 import evaluate
 import nltk
 import numpy
+from editdistance import eval

 from .dataclass import InternalField
 from .operator import (
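Note: the new import comes from the `editdistance` package, whose `eval(a, b)` returns the Levenshtein distance between two sequences as an integer; importing it under the name `eval` also shadows Python's builtin `eval` inside this module. A standalone illustration of the behavior the new metric relies on:

```python
import editdistance

# Minimum number of single-character insertions, deletions,
# and substitutions needed to turn one string into the other.
assert editdistance.eval("kitten", "sitting") == 3
assert editdistance.eval("same", "same") == 0
```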
@@ -21,12 +22,12 @@ from .stream import MultiStream, Stream
 nltk.download("punkt")


-def absrtact_factory():
+def abstract_factory():
     return {}


 def abstract_field():
-    return field(default_factory=absrtact_factory)
+    return field(default_factory=abstract_factory)


 class UpdateStream(StreamInstanceOperator):
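Note: this hunk only fixes the `absrtact_factory` typo. The `default_factory` indirection itself is needed because dataclass fields reject mutable defaults; a minimal standalone sketch of why:

```python
from dataclasses import dataclass, field

@dataclass
class Holder:
    # `data: dict = {}` would raise ValueError (mutable default);
    # default_factory builds a fresh dict for each instance.
    data: dict = field(default_factory=dict)

a, b = Holder(), Holder()
a.data["k"] = 1
assert b.data == {}  # no shared state between instances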
@@ -253,7 +254,7 @@ class F1(GlobalMetric):
     def compute(self, references: List[List[str]], predictions: List[str]) -> dict:
         assert all(
             len(reference) == 1 for reference in references
-        ), "One single reference per predictition are allowed in F1 metric"
+        ), "Only a single reference per prediction is allowed in F1 metric"
         self.str_to_id = {}
         self.id_to_str = {}
         formatted_references = [self.get_str_id(reference[0]) for reference in references]
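Note: the reworded assertion enforces the input shape below (values hypothetical):

```python
# F1 accepts exactly one reference string per prediction:
references = [["yes"], ["no"], ["yes"]]
predictions = ["yes", "yes", "no"]
# A row like ["yes", "no"] would trip the assertion.
```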
@@ -287,7 +288,6 @@ class F1MultiLabel(GlobalMetric):
     _metric = None
     main_score = "f1_macro"
     average = None  # Report per class then aggregate by mean
-    seperator = ","

     def prepare(self):
         super(F1MultiLabel, self).prepare()
@@ -310,17 +310,15 @@ class F1MultiLabel(GlobalMetric):
     def compute(self, references: List[List[str]], predictions: List[str]) -> dict:
         self.str_to_id = {}
         self.id_to_str = {}
+        assert all(
+            len(reference) == 1 for reference in references
+        ), "Only a single reference per prediction is allowed in F1 metric"
+        references = [reference[0] for reference in references]
         labels = list(set([label for reference in references for label in reference]))
         for label in labels:
-            assert (
-                not self.seperator in label
-            ), "Reference label (f{label}) can not contain multi label seperator (f{self.seperator}) "
             self.add_str_to_id(label)
         formatted_references = [self.get_one_hot_vector(reference) for reference in references]
-        split_predictions = [
-            [label.strip() for label in prediction.split(self.seperator)] for prediction in predictions
-        ]
-        formatted_predictions = [self.get_one_hot_vector(prediction) for prediction in split_predictions]
+        formatted_predictions = [self.get_one_hot_vector(prediction) for prediction in predictions]
         result = self._metric.compute(
             predictions=formatted_predictions, references=formatted_references, average=self.average
         )
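Note: after this change each prediction is expected to already be a list of labels rather than a comma-joined string, and each reference row must hold exactly one such list. `get_one_hot_vector` is not shown in this diff, so the sketch below is only an assumption about the encoding it performs:

```python
# Hypothetical stand-in for get_one_hot_vector; the real method
# lives elsewhere in metrics.py and uses self.str_to_id.
str_to_id = {"sports": 0, "politics": 1, "tech": 2}

def one_hot(labels):
    vec = [0] * len(str_to_id)
    for label in labels:
        vec[str_to_id[label]] = 1
    return vec

assert one_hot(["sports", "tech"]) == [1, 0, 1]
```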
@@ -356,6 +354,38 @@ class Rouge(HuggingfaceMetric):
         return super().compute(references, predictions)


+# Computes character edit distance, ignoring repeated whitespace
+class CharEditDistanceAccuracy(SingleReferenceInstanceMetric):
+    reduction_map = {"mean": ["char_edit_dist_accuracy"]}
+    main_score = "char_edit_dist_accuracy"
+
+    def compute(self, reference, prediction: str) -> dict:
+        formatted_prediction = " ".join(prediction.split())
+        formatted_reference = " ".join(reference.split())
+        max_length = max(len(formatted_reference), len(formatted_prediction))
+        if max_length == 0:
+            return {"char_edit_dist_accuracy": 0.0}
+        edit_dist = eval(formatted_reference, formatted_prediction)
+        return {"char_edit_dist_accuracy": (1 - edit_dist / max_length)}
+
+
+class Wer(HuggingfaceMetric):
+    metric_name = "wer"
+    main_score = "wer"
+
+    def prepare(self):
+        super().prepare()
+        self.metric = evaluate.load(self.metric_name)
+
+    def compute(self, references: List[List[str]], predictions: List[str]) -> dict:
+        assert all(
+            len(reference) == 1 for reference in references
+        ), "Only a single reference per prediction is allowed in wer metric"
+        formatted_references = [reference[0] for reference in references]
+        result = self.metric.compute(predictions=predictions, references=formatted_references)
+        return {self.main_score: result}
+
+
 class Bleu(HuggingfaceMetric):
     metric_name = "bleu"
     main_score = "bleu"
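Note: a quick sanity check of the arithmetic in CharEditDistanceAccuracy, and the `evaluate` call that Wer delegates to (standalone sketch; the surrounding base classes are not reproduced here):

```python
import editdistance
import evaluate

# CharEditDistanceAccuracy by hand: collapse whitespace, then
# score = 1 - edit_distance / max(len(ref), len(pred)).
ref = " ".join("the  cat sat".split())      # "the cat sat"
pred = " ".join("the cat sit".split())      # "the cat sit"
dist = editdistance.eval(ref, pred)         # 1 substitution
score = 1 - dist / max(len(ref), len(pred))
assert round(score, 3) == 0.909             # 1 - 1/11

# Wer.compute wraps the Hugging Face `wer` metric, which returns a
# float word error rate (here 1 wrong word out of 3, about 0.333).
wer = evaluate.load("wer")
rate = wer.compute(predictions=["the cat sit"], references=["the cat sat"])
assert round(rate, 3) == 0.333
```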
 