"""Question-answering metrics (exact match and F1) for the MuSiQue-Answerable dataset."""

import collections
import re
import string
from typing import Callable

import datasets
import evaluate

_CITATION = """\ |
|
@InProceedings{huggingface:module, |
|
title = {A great new module}, |
|
authors={huggingface, Inc.}, |
|
year={2020} |
|
} |
|
""" |
|
|
|
_DESCRIPTION = """\ |
|
Question-answering metrics (`Exact Match` and `F1`) for Musique-Answerable dataset. |
|
|
|
The implementation is taken from Musique repository. |
|
https://github.com/StonyBrookNLP/musique |
|
""" |
|
|
|
|
|
_KWARGS_DESCRIPTION = """ |
|
Calculates how good are predictions given some references, using certain scores |
|
Args: |
|
predictions: list of predicted answers. |
|
references: list of ground truth answers. Each reference should be a list of |
|
ground truth answers for the corresponding prediction. |
|
Returns: |
|
exact_match: Exact match score, |
|
f1: F1 score over tokens |
|
Examples: |
|
>>> my_new_module = evaluate.load("musique") |
|
>>> results = my_new_module.compute( |
|
references=[["New York City", "NYC"], ["Einstein", "Albert Einstein"]], |
|
predictions=["New York City", "Albert Einstein"], |
|
) |
|
>>> print(results) |
|
{'exact_match': 1.0, 'f1': 1.0} |
|
""" |
|
|
|
|
|
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class musique(evaluate.Metric):
    """Question-answering metrics (exact match and F1) for the MuSiQue-Answerable dataset."""

    def _info(self):
        return evaluate.MetricInfo(
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            # Each prediction is a single answer string; each reference is a
            # list of acceptable answers for the corresponding prediction.
            features=datasets.Features(
                {
                    "predictions": datasets.Value("string"),
                    "references": datasets.features.Sequence(datasets.Value("string")),
                }
            ),
            homepage="https://huggingface.co/spaces/bdsaglam/musique",
            codebase_urls=["https://huggingface.co/spaces/bdsaglam/musique"],
            reference_urls=["https://github.com/StonyBrookNLP/musique"],
        )

    def _download_and_prepare(self, dl_manager):
        """Optional: download external resources useful to compute the scores."""
        pass

    def _compute(self, predictions, references):
        """Returns the average exact-match and F1 scores over all examples."""
        if len(predictions) != len(references):
            raise ValueError(
                "The number of predictions and references should be the same."
            )
        if len(predictions) == 0:
            return {"exact_match": 0.0, "f1": 0.0}
        # Each prediction is scored against its best-matching reference answer,
        # and the per-example scores are averaged.
        exact_scores = [
            metric_max_over_ground_truths(compute_exact, prediction, reference)
            for prediction, reference in zip(predictions, references)
        ]
        f1_scores = [
            metric_max_over_ground_truths(compute_f1, prediction, reference)
            for prediction, reference in zip(predictions, references)
        ]
        return {
            "exact_match": sum(exact_scores) / len(exact_scores),
            "f1": sum(f1_scores) / len(f1_scores),
        }
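

# Answer normalization and scoring helpers (SQuAD-style, as used in the
# MuSiQue repository): answers are normalized -- lowercased, with punctuation
# and articles stripped -- before exact match and token-level F1 are computed.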


def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
        return re.sub(regex, " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))
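

# Example of the normalization above (illustrative): "The Big Apple!" becomes
# "big apple" -- lowercased, punctuation stripped, the article "the" dropped,
# and whitespace collapsed.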


def get_tokens(s):
    if not s:
        return []
    return normalize_answer(s).split()


def compute_exact(a_gold, a_pred):
    return int(normalize_answer(a_gold) == normalize_answer(a_pred))


def compute_f1(a_gold, a_pred):
    gold_toks = get_tokens(a_gold)
    pred_toks = get_tokens(a_pred)
    common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
    num_same = sum(common.values())
    if len(gold_toks) == 0 or len(pred_toks) == 0:
        # If either answer is empty, F1 is 1 only when both are empty.
        return int(gold_toks == pred_toks)
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(pred_toks)
    recall = 1.0 * num_same / len(gold_toks)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1
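

# Worked example (illustrative): gold "Albert Einstein" vs. prediction
# "Einstein" share one normalized token, so precision = 1/1, recall = 1/2,
# and F1 = 2 * 1.0 * 0.5 / 1.5 = 2/3.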


def metric_max_over_ground_truths(
    metric_fn: Callable[[str, str], float],
    prediction: str,
    ground_truths: list[str],
) -> float:
    # Both compute_exact and compute_f1 are symmetric in their arguments, so
    # the (prediction, ground_truth) call order does not affect the score.
    scores_for_ground_truths = [
        metric_fn(prediction, ground_truth) for ground_truth in ground_truths
    ]
    return max(scores_for_ground_truths)
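

if __name__ == "__main__":
    # Minimal sanity check for the scoring helpers (illustrative only; in
    # normal use the metric is loaded through `evaluate.load`).
    assert compute_exact("Albert Einstein", "albert einstein!") == 1
    assert abs(compute_f1("Albert Einstein", "Einstein") - 2 / 3) < 1e-6
    best_f1 = metric_max_over_ground_truths(
        compute_f1, "Albert Einstein", ["Einstein", "Albert Einstein"]
    )
    print(f"best F1 over ground truths: {best_f1}")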