Elron committed on
Commit
e8292e5
·
verified ·
1 Parent(s): 895d817

Upload eval_utils.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. eval_utils.py +82 -12
eval_utils.py CHANGED
@@ -1,21 +1,91 @@
1
- from typing import List
 
2
 
3
  import pandas as pd
4
 
 
 
5
  from .operator import SequentialOperator
6
  from .stream import MultiStream
7
 
8
 
9
- def evaluate(dataset: pd.DataFrame, metric_names: List[str]):
10
- result = dataset.copy()
11
- # prepare the input stream
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  for metric_name in metric_names:
13
- multi_stream = MultiStream.from_iterables(
14
- {"test": dataset.to_dict("records")}, copying=True
15
- )
16
- metrics_operator = SequentialOperator(steps=[metric_name])
 
 
 
 
 
 
 
 
 
17
  instances = list(metrics_operator(multi_stream)["test"])
18
- result[metric_name] = [
19
- instance["score"]["instance"]["score"] for instance in instances
20
- ]
21
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import singledispatch
2
+ from typing import List, Optional
3
 
4
  import pandas as pd
5
 
6
+ from .artifact import verbosed_fetch_artifact
7
+ from .metric_utils import get_remote_metrics_endpoint, get_remote_metrics_names
8
  from .operator import SequentialOperator
9
  from .stream import MultiStream
10
 
11
 
12
@singledispatch
def evaluate(
    dataset, metric_names: List[str], compute_conf_intervals: Optional[bool] = False
):
    """Compute the given metrics over a dataset (single-dispatch entry point).

    This is the generic fallback of a ``functools.singledispatch`` function:
    concrete implementations are registered per dataset type (a ``list`` of
    instances or a ``pandas.DataFrame``) and are selected by the type of the
    first positional argument. This fallback is reached only when no
    registered implementation matches.

    Args:
        dataset: The instances to evaluate. Must be passed positionally,
            since dispatch is based on its runtime type.
        metric_names: Names of the metrics to compute.
        compute_conf_intervals: Whether confidence intervals should be
            computed by the metrics.

    Raises:
        TypeError: Always, because reaching this fallback means ``dataset``
            is of a type no implementation was registered for.
    """
    # Fail loudly instead of silently returning None, which would otherwise
    # surface later as a confusing unpacking error at the call site.
    raise TypeError(
        f"evaluate() does not support a dataset of type "
        f"'{type(dataset).__name__}'; expected a list of instances or a "
        f"pandas DataFrame."
    )
18
+
19
+
20
@evaluate.register
def _(
    dataset: list,
    metric_names: List[str],
    compute_conf_intervals: Optional[bool] = False,
):
    """Evaluate a list of instances with each of the named metrics.

    For every metric name a fresh ``MultiStream`` holding the dataset under
    the ``"test"`` key is built and run through a ``SequentialOperator``
    whose single step is the metric. Metrics whose name is reported by
    ``get_remote_metrics_names()`` are fetched and wrapped for remote
    execution first.

    Note that the entries of ``dataset`` are updated in place: each entry
    gains a key named after the metric, holding that instance's score.

    Args:
        dataset: The instances to score.
        metric_names: Names of the metrics to compute.
        compute_conf_intervals: When falsy, confidence interval calculation
            is disabled on the metric for the duration of the run and the
            previous resample setting is restored afterwards.

    Returns:
        A tuple ``(dataset, global_scores)`` where ``dataset`` is the input
        list with per-instance scores added, and ``global_scores`` maps each
        metric name to the ``"global"`` score dict of the first resulting
        instance (metrics producing no instances are omitted).
    """
    global_scores = {}
    remote_metrics = get_remote_metrics_names()
    for metric_name in metric_names:
        multi_stream = MultiStream.from_iterables({"test": dataset}, copying=True)
        if metric_name in remote_metrics:
            metric = verbosed_fetch_artifact(metric_name)
            metric_step = as_remote_metric(metric)
        else:
            # The SequentialOperator below will handle the load of the metric from its name
            metric_step = metric_name
        metrics_operator = SequentialOperator(steps=[metric_step])

        if not compute_conf_intervals:
            first_step = metrics_operator.steps[0]
            n_resamples = first_step.disable_confidence_interval_calculation()

        try:
            instances = list(metrics_operator(multi_stream)["test"])
            for entry, instance in zip(dataset, instances):
                entry[metric_name] = instance["score"]["instance"]["score"]

            if instances:
                global_scores[metric_name] = instances[0]["score"].get("global", {})
        finally:
            # To overcome issue #325: the modified metric artifact is cached and
            # a sequential retrieval of an artifact with the same name will
            # retrieve the metric with the previous modification.
            # This reverts the confidence interval change and restores the initial metric.
            # Done in ``finally`` so the cached metric is restored even when
            # running the metric raises.
            if not compute_conf_intervals:
                first_step.set_n_resamples(n_resamples)

    return dataset, global_scores
57
+
58
+
59
@evaluate.register
def _(
    dataset: pd.DataFrame,
    metric_names: List[str],
    compute_conf_intervals: Optional[bool] = False,
):
    """Evaluate a DataFrame of instances with each of the named metrics.

    The frame is converted to a list of per-row dicts and handed back to
    ``evaluate``, which dispatches on the new argument type.

    Args:
        dataset: One row per instance.
        metric_names: Names of the metrics to compute.
        compute_conf_intervals: Forwarded unchanged to the delegated call.

    Returns:
        A tuple of two DataFrames: one built from the scored records and one
        built from the global scores returned by the delegated call.
    """
    # ``to_dict("records")`` builds new row dicts to pass on.
    records = dataset.to_dict("records")
    scored_records, global_scores = evaluate(
        records,
        metric_names=metric_names,
        compute_conf_intervals=compute_conf_intervals,
    )
    instance_frame = pd.DataFrame(scored_records)
    global_frame = pd.DataFrame(global_scores)
    return instance_frame, global_frame
71
+
72
+
73
def as_remote_metric(metric):
    """Wrap a metric with a RemoteMetric.

    Currently supported is wrapping the inner metric within a MetricPipeline.

    Args:
        metric: The metric object to wrap; must be a ``MetricPipeline``.

    Returns:
        The result of ``RemoteMetric.wrap_inner_metric_pipeline_metric`` for
        the given pipeline and the configured remote metrics endpoint.

    Raises:
        ValueError: If ``metric`` is not a ``MetricPipeline``.
    """
    # Imported locally rather than at module level (presumably to avoid a
    # circular import with the metrics module - TODO confirm).
    from .metrics import MetricPipeline, RemoteMetric

    remote_metrics_endpoint = get_remote_metrics_endpoint()
    if not isinstance(metric, MetricPipeline):
        # Use getattr so that an object lacking ``artifact_identifier`` still
        # yields the intended ValueError rather than an AttributeError.
        metric_identifier = getattr(metric, "artifact_identifier", None)
        raise ValueError(
            f"Unexpected remote metric type {type(metric)} for the metric named '{metric_identifier}'. "
            "Remotely executed metrics should be MetricPipeline objects."
        )
    return RemoteMetric.wrap_inner_metric_pipeline_metric(
        metric_pipeline=metric,
        remote_metrics_endpoint=remote_metrics_endpoint,
    )