import evaluate
from transformers.trainer_utils import EvalPrediction

# Load each metric once at import time and keep a reference to its ``compute``
# method so the helpers below can call it like a plain function.
accuracy = evaluate.load("accuracy").compute
precision = evaluate.load("precision").compute
recall = evaluate.load("recall").compute
f1 = evaluate.load("f1").compute
squad_v2 = evaluate.load("squad_v2").compute


def compute_classification_metric(p: EvalPrediction) -> dict:
    """
    Compute classification metrics for a given prediction.

    Args:
        p (EvalPrediction): The prediction object produced at evaluation time.

    Returns:
        dict: The computed accuracy, precision, recall, and f1 score.
    """
    # Turn the raw logits into predicted class ids and compare them with the
    # gold labels.
    predictions = p.predictions.argmax(axis=1)
    references = p.label_ids

    # Note: precision, recall, and f1 use the metric defaults (binary averaging);
    # multi-class setups need an explicit ``average`` argument in these calls.
    metric = accuracy(predictions=predictions, references=references)
    metric.update(precision(predictions=predictions, references=references))
    metric.update(recall(predictions=predictions, references=references))
    metric.update(f1(predictions=predictions, references=references))

    return metric
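
# Usage sketch (illustrative, not part of the original module): a helper with this
# signature is typically handed to a Hugging Face ``Trainer`` via its
# ``compute_metrics`` argument, e.g.
#
#     trainer = Trainer(
#         model=model,
#         args=training_args,
#         train_dataset=train_dataset,
#         eval_dataset=eval_dataset,
#         compute_metrics=compute_classification_metric,
#     )
#
# The ``Trainer`` then calls it with an ``EvalPrediction`` after each evaluation
# run; ``model``, ``training_args``, and the datasets above are placeholders.
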
def compute_squad_v2(p: EvalPrediction) -> dict:
    """
    Compute SQuAD v2 metrics for a given prediction.

    Args:
        p (EvalPrediction): The prediction object. ``p.predictions`` and
            ``p.label_ids`` are expected to already be post-processed into the
            format the squad_v2 metric accepts.

    Returns:
        dict: The computed SQuAD v2 metrics (exact match, f1, and related
        statistics).
    """
    predictions = p.predictions
    references = p.label_ids

    return squad_v2(predictions=predictions, references=references)
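

# A minimal smoke test is sketched below. Assumptions not taken from the module:
# the classification labels are binary (so the metric defaults apply), and the
# SQuAD-style predictions/references are already post-processed into the dict
# format the squad_v2 metric expects; every value is made up for illustration.
if __name__ == "__main__":
    import numpy as np

    # Binary classification: logits for four examples with two classes each.
    logits = np.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7], [0.6, 0.4]])
    labels = np.array([1, 0, 1, 1])
    print(compute_classification_metric(EvalPrediction(predictions=logits, label_ids=labels)))

    # Extractive QA: a single, already post-processed prediction/reference pair.
    preds = [{"id": "0", "prediction_text": "Paris", "no_answer_probability": 0.0}]
    refs = [{"id": "0", "answers": {"text": ["Paris"], "answer_start": [0]}}]
    print(compute_squad_v2(EvalPrediction(predictions=preds, label_ids=refs)))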