# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from copy import deepcopy
from typing import Optional, Sequence

import datasets
import evaluate

_CITATION = """\
@misc{nereval,
  title={{NER-Evaluation}: Named Entity Evaluation as in SemEval 2013 task 9.1},
  url={https://github.com/davidsbatista/NER-Evaluation},
  note={Software available from https://github.com/davidsbatista/NER-Evaluation},
  author={Batista David},
  year={2018},
}
"""

_DESCRIPTION = """\
ner-eval is a Python framework for sequence labeling evaluation. It was used in SemEval 2013 task 9.1.
It supports exact match, partial match, spurious and other errors.
"""

_KWARGS_DESCRIPTION = """
Computes entity-level precision, recall and F1 for IOB-tagged sequences, following
the four evaluation schemas of SemEval 2013 task 9.1 (strict, exact, partial, ent_type).
Args:
    predictions: List of List of predicted labels (Estimated targets as returned by a tagger)
    references: List of List of reference labels (Ground truth (correct) target values)
    tags: List of entity types (labels without their "B-"/"I-" prefix) to evaluate.
        default: None, meaning every entity type found in predictions or references.
    modes: List of schemas to report, any of "strict", "ent_type", "partial", "exact".
        default: None, meaning all four.
Returns:
    'scores' dict with one entry for "overall" and one entry per tag. Each entry maps
    "<mode>_precision", "<mode>_recall" and "<mode>_f1" to a number for every requested mode.
Examples:
    >>> my_new_module = evaluate.load("fschlatt/ner_eval")
    >>> results = my_new_module.compute(
    ...     references=[["B-LOC", "I-LOC", "I-LOC", "B-ORG", "I-ORG", "O", "B-PER", "I-PER", "I-PER", "O"]],
    ...     predictions=[["B-LOC", "I-LOC", "O", "O", "B-ORG", "I-ORG", "O", "B-PER", "I-PER", "O"]]
    ... )
    >>> print(results)
    {
        "overall": {
            "strict_precision": 0.0, "strict_recall": 0.0, "strict_f1": 0,
            "ent_type_precision": 1.0, "ent_type_recall": 1.0, "ent_type_f1": 1.0,
            "partial_precision": 0.5, "partial_recall": 0.5, "partial_f1": 0.5,
            "exact_precision": 0.0, "exact_recall": 0.0, "exact_f1": 0,
        },
        "LOC": {
            "strict_precision": 0.0, "strict_recall": 0.0, "strict_f1": 0,
            "ent_type_precision": 1.0, "ent_type_recall": 1.0, "ent_type_f1": 1.0,
            "partial_precision": 0.5, "partial_recall": 0.5, "partial_f1": 0.5,
            "exact_precision": 0.0, "exact_recall": 0.0, "exact_f1": 0,
        },
        "ORG": {
            "strict_precision": 0.0, "strict_recall": 0.0, "strict_f1": 0,
            "ent_type_precision": 1.0, "ent_type_recall": 1.0, "ent_type_f1": 1.0,
            "partial_precision": 0.5, "partial_recall": 0.5, "partial_f1": 0.5,
            "exact_precision": 0.0, "exact_recall": 0.0, "exact_f1": 0,
        },
        "PER": {
            "strict_precision": 0.0, "strict_recall": 0.0, "strict_f1": 0,
            "ent_type_precision": 1.0, "ent_type_recall": 1.0, "ent_type_f1": 1.0,
            "partial_precision": 0.5, "partial_recall": 0.5, "partial_f1": 0.5,
            "exact_precision": 0.0, "exact_recall": 0.0, "exact_f1": 0,
        },
    }
"""


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class NEREval(evaluate.Metric):
    """Entity-level NER metric reporting strict, ent_type, partial and exact scores."""

    def _info(self):
        return evaluate.MetricInfo(
            # This is the description that will appear on the modules page.
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            homepage="https://github.com/davidsbatista/NER-Evaluation",
            inputs_description=_KWARGS_DESCRIPTION,
            # Each prediction / reference is a sequence of string labels.
            features=datasets.Features(
                {
                    "predictions": datasets.Sequence(
                        datasets.Value("string", id="label"), id="sequence"
                    ),
                    "references": datasets.Sequence(
                        datasets.Value("string", id="label"), id="sequence"
                    ),
                }
            ),
            # Additional links to the codebase or references
            codebase_urls=["https://github.com/davidsbatista/NER-Evaluation"],
            reference_urls=[
                "https://github.com/davidsbatista/NER-Evaluation",
                "https://www.davidsbatista.net/blog/2018/05/09/Named_Entity_Evaluation/",
            ],
        )

    def _download_and_prepare(self, dl_manager):
        """No external resources are needed by this metric."""
        pass

    def _compute(
        self,
        predictions: Sequence[Sequence[str]],
        references: Sequence[Sequence[str]],
        tags: Optional[Sequence[str]] = None,
        modes: Optional[Sequence[str]] = None,
    ):
        """Score ``predictions`` against ``references``; see _KWARGS_DESCRIPTION."""
        if tags is None:
            # Sorted so the per-tag output order does not depend on set/hash ordering.
            tags = sorted(parse_tags(predictions) | parse_tags(references))
        # Evaluator takes the gold standard first and the system output second;
        # passing them the other way round would swap precision and recall.
        evaluator = Evaluator(true=references, pred=predictions, tags=tags)
        results, agg_results = evaluator.evaluate()

        out = {"overall": parse_results(results, modes)}
        for tag, tag_result in agg_results.items():
            out[tag] = parse_results(tag_result, modes)
        return out


def parse_results(results, modes: Optional[Sequence[str]] = None):
    """
    Flatten per-schema results into ``{mode}_precision``, ``{mode}_recall`` and
    ``{mode}_f1`` keys.

    :param results: dict mapping schema name to a dict holding at least
        "precision", "recall" and "f1"
    :param modes: schemas to report; defaults to all four
    :return: flat dict of scores
    """
    if modes is None:
        modes = _SCHEMAS
    out = {}
    for mode in modes:
        for metric in ("precision", "recall", "f1"):
            out[f"{mode}_{metric}"] = results[mode][metric]
    return out


def _entity_type(token_tag):
    """Strip the two-character IOB prefix ("B-"/"I-") from a label."""
    return token_tag[2:]


def parse_tags(tokens: Sequence[Sequence[str]]):
    """
    Return the set of entity types occurring in IOB-tagged ``tokens``.

    Types are extracted exactly as in collect_named_entities, so the default tag
    set always matches the entity types that are evaluated (types containing
    hyphens are kept whole). "O" labels are ignored.
    """
    return {_entity_type(tag) for seq in tokens for tag in seq if tag != "O"}


Entity = namedtuple("Entity", "e_type start_offset end_offset")

# The four SemEval 2013 task 9.1 evaluation schemas.
_SCHEMAS = ("strict", "ent_type", "partial", "exact")
# Schemas scored with half credit for partial matches.
_PARTIAL_CREDIT_SCHEMAS = ("partial", "ent_type")
# Per-document counters that are summed across documents.
_COUNT_KEYS = ("correct", "incorrect", "partial", "missed", "spurious", "possible", "actual")

# Counter to increment in each schema for every scenario of compute_metrics.
_EXACT_MATCH = dict.fromkeys(_SCHEMAS, "correct")  # Scenario I
_SPURIOUS = dict.fromkeys(_SCHEMAS, "spurious")  # Scenario II
_MISSED = dict.fromkeys(_SCHEMAS, "missed")  # Scenario III
_WRONG_TYPE = {  # Scenario IV: same boundaries, different type
    "strict": "incorrect",
    "ent_type": "incorrect",
    "partial": "correct",
    "exact": "correct",
}
_OVERLAP_SAME_TYPE = {  # Scenario V: overlapping boundaries, same type
    "strict": "incorrect",
    "ent_type": "correct",
    "partial": "partial",
    "exact": "incorrect",
}
_OVERLAP_WRONG_TYPE = {  # Scenario VI: overlapping boundaries, different type
    "strict": "incorrect",
    "ent_type": "incorrect",
    "partial": "partial",
    "exact": "incorrect",
}


def _accumulate_counts(totals, increments):
    """Add the count fields (_COUNT_KEYS) of ``increments`` into ``totals``, per schema."""
    for schema, counts in totals.items():
        for key in _COUNT_KEYS:
            counts[key] += increments[schema][key]


class Evaluator:
    """Accumulate SemEval 2013 entity-level counts and scores over several documents."""

    def __init__(self, true, pred, tags):
        """
        :param true: gold-standard label sequences, one per document
        :param pred: predicted label sequences, aligned with ``true``
        :param tags: entity types to evaluate; entities of other types are ignored
        :raises ValueError: if ``true`` and ``pred`` hold a different number of documents
        """
        if len(true) != len(pred):
            raise ValueError("Number of predicted documents does not equal true")

        self.true = true
        self.pred = pred
        self.tags = tags

        # Setup dict into which metrics will be stored.
        self.metrics_results = {
            "correct": 0,
            "incorrect": 0,
            "partial": 0,
            "missed": 0,
            "spurious": 0,
            "possible": 0,
            "actual": 0,
            "precision": 0,
            "recall": 0,
        }

        # One copy of the counters per schema, overall and per entity type.
        self.results = {schema: deepcopy(self.metrics_results) for schema in _SCHEMAS}
        self.evaluation_agg_entities_type = {e: deepcopy(self.results) for e in tags}

    def evaluate(self):
        """
        Score every document and return ``(results, results_by_entity_type)``.

        Counts are summed over all documents first; precision, recall and F1 are
        computed once at the end from the summed counts.

        :raises ValueError: if a document's true and predicted sequences differ in length
        """
        for true_ents, pred_ents in zip(self.true, self.pred):
            # Checked explicitly because mismatched lengths would otherwise
            # produce wrong scores without raising any error.
            if len(true_ents) != len(pred_ents):
                raise ValueError("Prediction length does not match true example length")

            tmp_results, tmp_agg_results = compute_metrics(
                collect_named_entities(true_ents),
                collect_named_entities(pred_ents),
                self.tags,
            )

            # Only counters are summed: score fields (e.g. "f1") are derived
            # values and are absent from the per-document results.
            _accumulate_counts(self.results, tmp_results)
            for e_type, e_type_results in self.evaluation_agg_entities_type.items():
                _accumulate_counts(e_type_results, tmp_agg_results[e_type])

        self.results = compute_precision_recall_f1_wrapper(self.results)
        self.evaluation_agg_entities_type = {
            e_type: compute_precision_recall_f1_wrapper(e_type_results)
            for e_type, e_type_results in self.evaluation_agg_entities_type.items()
        }
        return self.results, self.evaluation_agg_entities_type


def collect_named_entities(tokens):
    """
    Creates a list of Entity named-tuples, storing the entity type and the start
    and end offsets of the entity. Offsets are token indices and both are inclusive.

    A "B-" label, or a change of entity type, starts a new entity; "O" closes the
    current one.

    :param tokens: a list of tags
    :return: a list of Entity named-tuples
    """
    named_entities = []
    start_offset = None
    ent_type = None

    for offset, token_tag in enumerate(tokens):
        if token_tag == "O":
            if ent_type is not None:
                # The open entity ended on the previous token.
                named_entities.append(Entity(ent_type, start_offset, offset - 1))
                start_offset = None
                ent_type = None
            continue

        token_type = _entity_type(token_tag)
        if ent_type is None:
            ent_type, start_offset = token_type, offset
        elif token_type != ent_type or token_tag[:1] == "B":
            named_entities.append(Entity(ent_type, start_offset, offset - 1))
            # start of a new entity
            ent_type, start_offset = token_type, offset

    # catches an entity that goes up until the last token
    if ent_type is not None:
        named_entities.append(Entity(ent_type, start_offset, len(tokens) - 1))

    return named_entities


def _record(evaluation, evaluation_agg_entities_type, e_types, outcomes):
    """Increment, per schema, the counter named in ``outcomes`` overall and for each of ``e_types``."""
    for schema, outcome in outcomes.items():
        evaluation[schema][outcome] += 1
        for e_type in e_types:
            evaluation_agg_entities_type[e_type][schema][outcome] += 1


def compute_metrics(true_named_entities, pred_named_entities, tags):
    """
    Count, for a single document, how the predicted entities match the true ones.

    Scenarios are those of
    http://www.davidsbatista.net/blog/2018/05/09/Named_Entity_Evaluation/ :
    I exact match, II spurious, III missed, IV same boundaries but wrong type,
    V overlap with the same type, VI overlap with a different type.

    :param true_named_entities: gold Entity tuples (inclusive offsets)
    :param pred_named_entities: predicted Entity tuples (inclusive offsets)
    :param tags: entity types to evaluate; entities of other types are ignored
    :return: ``(evaluation, evaluation_agg_entities_type)``; both include
        "possible" and "actual" counts, and precision/recall are left at 0
    """
    eval_metrics = {
        "correct": 0,
        "incorrect": 0,
        "partial": 0,
        "missed": 0,
        "spurious": 0,
        "precision": 0,
        "recall": 0,
    }

    # overall results
    evaluation = {schema: deepcopy(eval_metrics) for schema in _SCHEMAS}
    # results by entity type
    evaluation_agg_entities_type = {e: deepcopy(evaluation) for e in tags}

    # Subset into only the tags that we are interested in.
    # NOTE: we remove the tags we don't want from both the predicted and the
    # true entities. This covers the two cases where mismatches can occur:
    #
    # 1) Where the model predicts a tag that is not present in the true data
    # 2) Where there is a tag in the true data that the model is not capable of
    # predicting.
    wanted = set(tags)
    true_named_entities = [ent for ent in true_named_entities if ent.e_type in wanted]
    pred_named_entities = [ent for ent in pred_named_entities if ent.e_type in wanted]
    true_lookup = set(true_named_entities)

    # keep track of entities that overlapped
    true_which_overlapped_with_pred = set()

    for pred in pred_named_entities:
        # Scenario I: Exact match between true and pred
        if pred in true_lookup:
            true_which_overlapped_with_pred.add(pred)
            _record(evaluation, evaluation_agg_entities_type, (pred.e_type,), _EXACT_MATCH)
            continue

        # Offsets are inclusive, hence the +1: otherwise the last token is
        # ignored and single-token entities can never overlap.
        pred_range = range(pred.start_offset, pred.end_offset + 1)
        for true in true_named_entities:
            true_range = range(true.start_offset, true.end_offset + 1)

            if (
                true.start_offset == pred.start_offset
                and pred.end_offset == true.end_offset
                and true.e_type != pred.e_type
            ):
                outcomes = _WRONG_TYPE  # Scenario IV
            elif find_overlap(true_range, pred_range):
                # Scenario V (same type) or VI (different type)
                outcomes = (
                    _OVERLAP_SAME_TYPE if pred.e_type == true.e_type else _OVERLAP_WRONG_TYPE
                )
            else:
                continue

            # Scenarios IV-VI are attributed to the true entity's type.
            true_which_overlapped_with_pred.add(true)
            _record(evaluation, evaluation_agg_entities_type, (true.e_type,), outcomes)
            break
        else:
            # Scenario II: Entities are spurious (i.e., over-generated).
            # NOTE: it is not clear which tag a spurious entity belongs to, so it
            # is counted for every evaluated tag. This means the sum of the
            # per-type results will not equal the overall results.
            _record(
                evaluation,
                evaluation_agg_entities_type,
                tuple(evaluation_agg_entities_type),
                _SPURIOUS,
            )

    # Scenario III: Entity was missed entirely.
    for true in true_named_entities:
        if true not in true_which_overlapped_with_pred:
            _record(evaluation, evaluation_agg_entities_type, (true.e_type,), _MISSED)

    # Compute 'possible', 'actual' according to SemEval-2013 Task 9.1 on the
    # overall and the per-type results.
    for eval_type in evaluation:
        evaluation[eval_type] = compute_actual_possible(evaluation[eval_type])
    for entity_level in evaluation_agg_entities_type.values():
        for eval_type in entity_level:
            entity_level[eval_type] = compute_actual_possible(entity_level[eval_type])

    return evaluation, evaluation_agg_entities_type


def find_overlap(true_range, pred_range):
    """Find the overlap between two ranges

    Find the overlap between two ranges. Return the overlapping values if
    present, else return an empty set().

    Examples:

    >>> find_overlap((1, 2), (2, 3))
    {2}
    >>> find_overlap((1, 2), (3, 4))
    set()
    """
    true_set = set(true_range)
    pred_set = set(pred_range)
    overlaps = true_set.intersection(pred_set)
    return overlaps


def compute_actual_possible(results):
    """
    Takes a result dict that has been output by compute metrics.
    Returns the same dict with "actual" and "possible" populated.
    """
    correct = results["correct"]
    incorrect = results["incorrect"]
    partial = results["partial"]
    missed = results["missed"]
    spurious = results["spurious"]

    # Possible: number annotations in the gold-standard which contribute to the
    # final score
    possible = correct + incorrect + partial + missed

    # Actual: number of annotations produced by the NER system
    actual = correct + incorrect + partial + spurious

    results["actual"] = actual
    results["possible"] = possible

    return results


def compute_precision_recall_f1(results, partial_or_type=False):
    """
    Takes a result dict that has been output by compute metrics.
    Returns the same dict with precision, recall and f1 populated.

    When the results dict is from the partial or ent_type schema, pass
    partial_or_type=True so that partial matches earn half credit.
    """
    actual = results["actual"]
    possible = results["possible"]
    partial = results["partial"]
    correct = results["correct"]

    if partial_or_type:
        precision = (correct + 0.5 * partial) / actual if actual > 0 else 0
        recall = (correct + 0.5 * partial) / possible if possible > 0 else 0
    else:
        precision = correct / actual if actual > 0 else 0
        recall = correct / possible if possible > 0 else 0

    results["precision"] = precision
    results["recall"] = recall
    results["f1"] = (
        precision * recall * 2 / (precision + recall) if precision + recall > 0 else 0
    )

    return results


def compute_precision_recall_f1_wrapper(results):
    """
    Run compute_precision_recall_f1 on every schema of ``results``.

    The "partial" and "ent_type" schemas use partial credit; "strict" and
    "exact" do not. Keys other than the four schemas are dropped.
    """
    return {
        key: compute_precision_recall_f1(value, key in _PARTIAL_CREDIT_SCHEMAS)
        for key, value in results.items()
        if key in _SCHEMAS
    }