Elron commited on
Commit
7e64b87
·
verified ·
1 Parent(s): c92ffc9

Upload metrics.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. metrics.py +164 -10
metrics.py CHANGED
@@ -16,7 +16,7 @@ from scipy.stats import bootstrap
16
  from scipy.stats._warnings_errors import DegenerateDataWarning
17
 
18
  from .artifact import Artifact
19
- from .dataclass import AbstractField, InternalField, OptionalField
20
  from .logging_utils import get_logger
21
  from .metric_utils import InstanceInput, MetricRequest, MetricResponse
22
  from .operator import (
@@ -648,6 +648,9 @@ class InstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
648
 
649
  reduction_map: Dict[str, List[str]] = AbstractField()
650
 
 
 
 
651
  def _validate_group_mean_reduction(self, instances: List[dict]):
652
  """Ensure that group_mean reduction_map is properly formatted.
653
 
@@ -827,10 +830,21 @@ class InstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
827
  instances = []
828
 
829
  for instance in stream:
830
- refs, pred = instance["references"], instance["prediction"]
 
 
 
 
 
 
 
 
 
 
 
 
831
  self._validate_prediction(pred)
832
  self._validate_reference(refs)
833
- task_data = instance["task_data"] if "task_data" in instance else {}
834
 
835
  instance_score = self.compute(
836
  references=refs, prediction=pred, task_data=task_data
@@ -1033,7 +1047,6 @@ class MetricPipeline(MultiStreamOperator, Metric):
1033
  [f"score/instance/{self.main_score}", "score/instance/score"],
1034
  [f"score/global/{self.main_score}", "score/global/score"],
1035
  ],
1036
- use_query=True,
1037
  )
1038
 
1039
  def process(self, multi_stream: MultiStream) -> MultiStream:
@@ -1447,13 +1460,15 @@ class Rouge(HuggingfaceMetric):
1447
 
1448
 
1449
  # Computes char edit distance, ignoring whitespace
1450
- class CharEditDistanceAccuracy(InstanceMetric):
1451
- reduction_map = {"mean": ["char_edit_dist_accuracy"]}
1452
- main_score = "char_edit_dist_accuracy"
1453
- ci_scores = ["char_edit_dist_accuracy"]
1454
  prediction_type = "str"
1455
  single_reference_per_prediction = True
1456
 
 
 
1457
  _requirements_list: List[str] = ["editdistance"]
1458
 
1459
  def prepare(self):
@@ -1467,9 +1482,21 @@ class CharEditDistanceAccuracy(InstanceMetric):
1467
  formatted_reference = "".join(references[0].split())
1468
  max_length = max(len(formatted_reference), len(formatted_prediction))
1469
  if max_length == 0:
1470
- return {"char_edit_dist_accuracy": 0.0}
1471
  edit_dist = self.eval(formatted_reference, formatted_prediction)
1472
- return {"char_edit_dist_accuracy": (1 - edit_dist / max_length)}
 
 
 
 
 
 
 
 
 
 
 
 
1473
 
1474
 
1475
  class Wer(HuggingfaceMetric):
@@ -1853,6 +1880,8 @@ class BertScore(HuggingfaceBulkMetric):
1853
  ci_scores = ["f1", "precision", "recall"]
1854
  model_name: str
1855
 
 
 
1856
  _requirements_list: List[str] = ["bert_score"]
1857
 
1858
  def prepare(self):
@@ -1949,6 +1978,38 @@ class Reward(BulkInstanceMetric):
1949
  return self.pipe(inputs, batch_size=self.batch_size)
1950
 
1951
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1952
  class LlamaIndexCorrectness(InstanceMetric):
1953
  """LlamaIndex based metric class for evaluating correctness."""
1954
 
@@ -3320,6 +3381,99 @@ class BinaryMaxAccuracy(GlobalMetric):
3320
  return {self.main_score: best_acc, "best_thr_max_acc": best_thr}
3321
 
3322
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3323
  KO_ERROR_MESSAGE = """
3324
 
3325
  Additional dependencies required. To install them, run:
 
16
  from scipy.stats._warnings_errors import DegenerateDataWarning
17
 
18
  from .artifact import Artifact
19
+ from .dataclass import AbstractField, InternalField, NonPositionalField, OptionalField
20
  from .logging_utils import get_logger
21
  from .metric_utils import InstanceInput, MetricRequest, MetricResponse
22
  from .operator import (
 
648
 
649
  reduction_map: Dict[str, List[str]] = AbstractField()
650
 
651
+ reference_field: str = NonPositionalField(default="references")
652
+ prediction_field: str = NonPositionalField(default="prediction")
653
+
654
  def _validate_group_mean_reduction(self, instances: List[dict]):
655
  """Ensure that group_mean reduction_map is properly formatted.
656
 
 
830
  instances = []
831
 
832
  for instance in stream:
833
+ task_data = instance["task_data"] if "task_data" in instance else {}
834
+
835
+ if self.reference_field == "references":
836
+ refs = instance["references"]
837
+ else:
838
+ refs = task_data[self.reference_field]
839
+ if not isinstance(refs, list):
840
+ refs = [refs]
841
+ if self.prediction_field == "prediction":
842
+ pred = instance["prediction"]
843
+ else:
844
+ pred = task_data[self.prediction_field]
845
+
846
  self._validate_prediction(pred)
847
  self._validate_reference(refs)
 
848
 
849
  instance_score = self.compute(
850
  references=refs, prediction=pred, task_data=task_data
 
1047
  [f"score/instance/{self.main_score}", "score/instance/score"],
1048
  [f"score/global/{self.main_score}", "score/global/score"],
1049
  ],
 
1050
  )
1051
 
1052
  def process(self, multi_stream: MultiStream) -> MultiStream:
 
1460
 
1461
 
1462
  # Computes char edit distance, ignoring whitespace
1463
+ class CharEditDistance(InstanceMetric):
1464
+ main_score = "char_edit_distance"
1465
+ reduction_map = {"mean": [main_score]}
1466
+ ci_scores = [main_score]
1467
  prediction_type = "str"
1468
  single_reference_per_prediction = True
1469
 
1470
+ accuracy_metric = False
1471
+
1472
  _requirements_list: List[str] = ["editdistance"]
1473
 
1474
  def prepare(self):
 
1482
  formatted_reference = "".join(references[0].split())
1483
  max_length = max(len(formatted_reference), len(formatted_prediction))
1484
  if max_length == 0:
1485
+ return {self.main_score: 0.0}
1486
  edit_dist = self.eval(formatted_reference, formatted_prediction)
1487
+ if self.accuracy_metric:
1488
+ score = 1 - edit_dist / max_length
1489
+ else:
1490
+ score = edit_dist
1491
+ return {self.main_score: score}
1492
+
1493
+
1494
+ class CharEditDistanceAccuracy(CharEditDistance):
1495
+ main_score = "char_edit_dist_accuracy"
1496
+ reduction_map = {"mean": [main_score]}
1497
+ ci_scores = [main_score]
1498
+
1499
+ accuracy_metric = True
1500
 
1501
 
1502
  class Wer(HuggingfaceMetric):
 
1880
  ci_scores = ["f1", "precision", "recall"]
1881
  model_name: str
1882
 
1883
+ prediction_type = "str"
1884
+
1885
  _requirements_list: List[str] = ["bert_score"]
1886
 
1887
  def prepare(self):
 
1978
  return self.pipe(inputs, batch_size=self.batch_size)
1979
 
1980
 
1981
+ class Detector(BulkInstanceMetric):
1982
+ reduction_map = {"mean": ["score"]}
1983
+ main_score = "score"
1984
+ batch_size: int = 32
1985
+
1986
+ prediction_type = "str"
1987
+
1988
+ model_name: str
1989
+
1990
+ _requirements_list: List[str] = ["transformers", "torch"]
1991
+
1992
+ def prepare(self):
1993
+ super().prepare()
1994
+ import torch
1995
+ from transformers import pipeline
1996
+
1997
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
1998
+ self.pipe = pipeline(
1999
+ "text-classification", model=self.model_name, device=device
2000
+ )
2001
+
2002
+ def compute(
2003
+ self,
2004
+ references: List[List[Any]],
2005
+ predictions: List[Any],
2006
+ task_data: List[Dict],
2007
+ ) -> List[Dict[str, Any]]:
2008
+ # compute the metric
2009
+ # add function_to_apply="none" to disable sigmoid
2010
+ return self.pipe(predictions, batch_size=self.batch_size)
2011
+
2012
+
2013
  class LlamaIndexCorrectness(InstanceMetric):
2014
  """LlamaIndex based metric class for evaluating correctness."""
2015
 
 
3381
  return {self.main_score: best_acc, "best_thr_max_acc": best_thr}
3382
 
3383
 
3384
+ ######################
3385
+ # RerankRecallMetric #
3386
+
3387
+
3388
+ def pytrec_eval_at_k(results, qrels, at_k, metric_name):
3389
+ import pandas as pd
3390
+ import pytrec_eval
3391
+
3392
+ metric = {}
3393
+
3394
+ for k in at_k:
3395
+ metric[f"{metric_name}@{k}"] = 0.0
3396
+
3397
+ metric_string = f"{metric_name}." + ",".join([str(k) for k in at_k])
3398
+ # print('metric_string = ', metric_string)
3399
+ evaluator = pytrec_eval.RelevanceEvaluator(
3400
+ qrels, {"ndcg", metric_string}
3401
+ ) # {map_string, ndcg_string, recall_string, precision_string})
3402
+ scores = evaluator.evaluate(results)
3403
+ scores = pd.DataFrame(scores).transpose()
3404
+
3405
+ keys = []
3406
+ column_map = {}
3407
+ for k in at_k:
3408
+ keys.append(f"{metric_name}_{k}")
3409
+ column_map[f"{metric_name}_{k}"] = k
3410
+ scores[keys].rename(columns=column_map)
3411
+
3412
+ return scores
3413
+
3414
+
3415
+ class RerankRecall(GlobalMetric):
3416
+ """RerankRecall: measures the quality of reranking with respect to ground truth ranking scores.
3417
+
3418
+ This metric measures ranking performance across a dataset. The
3419
+ references for a query will have a score of 1 for the gold passage
3420
+ and 0 for all other passages. The model returns scores in [0,1]
3421
+ for each passage,query pair. This metric measures recall at k by
3422
+ testing that the predicted score for the gold passage,query pair
3423
+ is at least the k'th highest for all passages for that query. A
3424
+ query receives 1 if so, and 0 if not. The 1's and 0's are
3425
+ averaged across the dataset.
3426
+
3427
+ query_id_field selects the field containing the query id for an instance.
3428
+ passage_id_field selects the field containing the passage id for an instance.
3429
+ at_k selects the value of k used to compute recall.
3430
+
3431
+ """
3432
+
3433
+ main_score = "recall_at_5"
3434
+ query_id_field: str = "query_id"
3435
+ passage_id_field: str = "passage_id"
3436
+ at_k: List[int] = [1, 2, 5]
3437
+
3438
+ # This doesn't seem to make sense
3439
+ n_resamples = None
3440
+
3441
+ _requirements_list: List[str] = ["pandas", "pytrec_eval"]
3442
+
3443
+ def compute(
3444
+ self,
3445
+ references: List[List[str]],
3446
+ predictions: List[str],
3447
+ task_data: List[Dict],
3448
+ ):
3449
+ # Collect relevance score and ref per query/passage pair
3450
+ results = {}
3451
+ qrels = {}
3452
+ for ref, pred, data in zip(references, predictions, task_data):
3453
+ qid = data[self.query_id_field]
3454
+ pid = data[self.passage_id_field]
3455
+ if qid not in results:
3456
+ results[qid] = {}
3457
+ qrels[qid] = {}
3458
+ # Convert string-wrapped float to regular float
3459
+ try:
3460
+ results[qid][pid] = float(pred)
3461
+ except ValueError:
3462
+ # Card testing feeds nonnumeric values in, so catch that.
3463
+ results[qid][pid] = np.nan
3464
+
3465
+ # There's always a single reference per pid/qid pair
3466
+ qrels[qid][pid] = int(ref[0])
3467
+
3468
+ # Compute recall @ 5
3469
+ scores = pytrec_eval_at_k(results, qrels, self.at_k, "recall")
3470
+ # print(scores.describe())
3471
+ # pytrec returns numpy float32
3472
+ return {
3473
+ f"recall_at_{i}": float(scores[f"recall_{i}"].mean()) for i in self.at_k
3474
+ }
3475
+
3476
+
3477
  KO_ERROR_MESSAGE = """
3478
 
3479
  Additional dependencies required. To install them, run: