# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Question-answering metrics (Exact Match and F1) for the MuSiQue-Answerable dataset."""

import collections
import re
import string
from typing import Callable

import datasets
import evaluate

_CITATION = """\
@article{trivedi2022musique,
    title={MuSiQue: Multihop Questions via Single-hop Question Composition},
    author={Trivedi, Harsh and Balasubramanian, Niranjan and Khot, Tushar and Sabharwal, Ashish},
    journal={Transactions of the Association for Computational Linguistics},
    year={2022}
}
"""

_DESCRIPTION = """\
Question-answering metrics (`Exact Match` and `F1`) for the MuSiQue-Answerable dataset.
The implementation is taken from the MuSiQue repository.
https://github.com/StonyBrookNLP/musique
"""

_KWARGS_DESCRIPTION = """
Computes answer exact-match (EM) and token-level F1 scores for predictions against reference answers.
Args:
    predictions: list of predicted answers.
    references: list of ground-truth answers. Each reference should be a list of
        ground-truth answers for the corresponding prediction.
Returns:
    exact_match: exact-match score averaged over examples,
    f1: token-level F1 score averaged over examples.
Examples:

    >>> musique_metric = evaluate.load("musique")
    >>> results = musique_metric.compute(
    ...     references=[["New York City", "NYC"], ["Einstein", "Albert Einstein"]],
    ...     predictions=["New York City", "Albert Einstein"],
    ... )
    >>> print(results)
    {'exact_match': 1.0, 'f1': 1.0}
"""


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class musique(evaluate.Metric):
    """Question-answering metrics (EM and F1) for the MuSiQue-Answerable dataset."""

    def _info(self):
        # Specifies the evaluate.EvaluationModuleInfo object.
        return evaluate.MetricInfo(
            # This is the description that will appear on the modules page.
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            # This defines the format of each prediction and reference.
            features=datasets.Features(
                {
                    "predictions": datasets.Value("string"),
                    "references": datasets.features.Sequence(datasets.Value("string")),
                }
            ),
            # Homepage of the module for documentation.
            homepage="https://huggingface.co/spaces/bdsaglam/musique",
            # Additional links to the codebase or references.
            codebase_urls=["https://huggingface.co/spaces/bdsaglam/musique"],
            reference_urls=["https://github.com/StonyBrookNLP/musique"],
        )

    def _download_and_prepare(self, dl_manager):
        """Optional: download external resources useful to compute the scores."""
        pass

    def _compute(self, predictions, references):
        """Returns the average exact-match and F1 scores over all examples."""
        if len(predictions) != len(references):
            raise ValueError(
                "The number of predictions and references should be the same."
            )
        if len(predictions) == 0:
            return {"exact_match": 0.0, "f1": 0.0}
        exact_scores = [
            metric_max_over_ground_truths(compute_exact, prediction, reference)
            for prediction, reference in zip(predictions, references)
        ]
        f1_scores = [
            metric_max_over_ground_truths(compute_f1, prediction, reference)
            for prediction, reference in zip(predictions, references)
        ]
        return {
            "exact_match": sum(exact_scores) / len(exact_scores),
            "f1": sum(f1_scores) / len(f1_scores),
        }


# Source: https://github.com/StonyBrookNLP/musique/blob/main/metrics/answer.py
def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
        return re.sub(regex, " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def get_tokens(s):
    if not s:
        return []
    return normalize_answer(s).split()


def compute_exact(a_gold, a_pred):
    return int(normalize_answer(a_gold) == normalize_answer(a_pred))


def compute_f1(a_gold, a_pred):
    gold_toks = get_tokens(a_gold)
    pred_toks = get_tokens(a_pred)
    common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
    num_same = sum(common.values())
    if len(gold_toks) == 0 or len(pred_toks) == 0:
        # If either is no-answer, then F1 is 1 if they agree, 0 otherwise.
        return int(gold_toks == pred_toks)
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(pred_toks)
    recall = 1.0 * num_same / len(gold_toks)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1


def metric_max_over_ground_truths(
    metric_fn: Callable[[str, str], float],
    prediction: str,
    ground_truths: list[str],
) -> float:
    # metric_fn expects (a_gold, a_pred); take the best score over all ground truths.
    scores_for_ground_truths = [
        metric_fn(ground_truth, prediction) for ground_truth in ground_truths
    ]
    return max(scores_for_ground_truths)
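

# ---------------------------------------------------------------------------
# Illustrative usage (not part of the metric's public API): a minimal sketch
# that exercises the helper functions above directly on toy inputs, so it runs
# without loading the module from the Hub (the doctest above shows the
# `evaluate.load("musique")` path instead). Run this file directly to see it.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Normalization lowercases, strips punctuation/articles, and squeezes whitespace.
    assert normalize_answer("The Eiffel Tower!") == "eiffel tower"

    predictions = ["New York City", "Albert Einstein"]
    references = [["New York City", "NYC"], ["Einstein", "Albert Einstein"]]

    # Per-example score: the best score over the available ground-truth answers.
    em_scores = [
        metric_max_over_ground_truths(compute_exact, pred, refs)
        for pred, refs in zip(predictions, references)
    ]
    f1_scores = [
        metric_max_over_ground_truths(compute_f1, pred, refs)
        for pred, refs in zip(predictions, references)
    ]

    print(
        {
            "exact_match": sum(em_scores) / len(em_scores),
            "f1": sum(f1_scores) / len(f1_scores),
        }
    )
    # Expected output: {'exact_match': 1.0, 'f1': 1.0}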