ner_eval / ner_eval.py
fschlatt's picture
compute f1
97af528
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from copy import deepcopy
from typing import Sequence, Optional
import datasets
import evaluate
# TODO: Add BibTeX citation
_CITATION = """\
@misc{nereval,
title={{NER-Evaluation}: Named Entity Evaluation as in SemEval 2013 task 9.1},
url={https://github.com/davidsbatista/NER-Evaluation},
note={Software available from https://github.com/davidsbatista/NER-Evaluation},
author={Batista David},
year={2018},
}
"""
# TODO: Add description of the module here
_DESCRIPTION = """\
ner-eval is a Python frame for sequence labeling evaluation. I twas used in SemEval 2013 task 9.1.
It supports exact match, partial match, spurious and other errors.
"""
# TODO: Add description of the arguments of the module here
_KWARGS_DESCRIPTION = """
Calculates how good are predictions given some references, using certain scores
Args:
predictions: List of List of predicted labels (Estimated targets as returned by a tagger)
references: List of List of reference labels (Ground truth (correct) target values)
tags: List of tags to evaluate. default: None
Returns:
'scores' dict. Summary of the scores for overall and each tag.
{
"overall": {
"strict_precision": 0.0,
"strict_recall": 0.0,
"strict_f1": 0,
"ent_type_precision": 0.0,
"ent_type_recall": 0.0,
"ent_type_f1": 0,
"partial_precision": 0.0,
"partial_recall": 0.0,
"partial_f1": 0,
"exact_precision": 0.0,
"exact_recall": 0.0,
"exact_f1": 0,
},
"ORG": {
"strict_precision": 0.0,
"strict_recall": 0.0,
"strict_f1": 0,
"ent_type_precision": 0.0,
"ent_type_recall": 0.0,
"ent_type_f1": 0,
"partial_precision": 0.0,
"partial_recall": 0.0,
"partial_f1": 0,
"exact_precision": 0.0,
"exact_recall": 0.0,
"exact_f1": 0,
},
"PER": {
"strict_precision": 0.0,
"strict_recall": 0.0,
"strict_f1": 0,
"ent_type_precision": 0.0,
"ent_type_recall": 0.0,
"ent_type_f1": 0,
"partial_precision": 0.0,
"partial_recall": 0.0,
"partial_f1": 0,
"exact_precision": 0.0,
"exact_recall": 0.0,
"exact_f1": 0,
},
"LOC": {
"strict_precision": 0.0,
"strict_recall": 0.0,
"strict_f1": 0,
"ent_type_precision": 0.0,
"ent_type_recall": 0.0,
"ent_type_f1": 0,
"partial_precision": 0.0,
"partial_recall": 0.0,
"partial_f1": 0,
"exact_precision": 0.0,
"exact_recall": 0.0,
"exact_f1": 0,
},
}
Examples:
>>> my_new_module = evaluate.load("fschlatt/ner_eval")
>>> results = my_new_module.compute(
... references=[["B-LOC", "I-LOC", "I-LOC", "B-ORG", "I-ORG", "O", "B-PER", "I-PER", "I-PER", "O"]],
... predictions=[["B-LOC", "I-LOC", "O", "O", "B-ORG", "I-ORG", "O", "B-PER", "I-PER", "O"]]
... )
>>> print(results)
{
"overall": {
"strict_precision": 0.0,
"strict_recall": 0.0,
"strict_f1": 0,
"ent_type_precision": 2 / 3,
"ent_type_recall": 2 / 3,
"ent_type_f1": 2 / 3,
"partial_precision": 1 / 3,
"partial_recall": 1 / 3,
"partial_f1": 1 / 3,
"exact_precision": 0.0,
"exact_recall": 0.0,
"exact_f1": 0,
},
"ORG": {
"strict_precision": 0.0,
"strict_recall": 0.0,
"strict_f1": 0,
"ent_type_precision": 0.0,
"ent_type_recall": 0.0,
"ent_type_f1": 0,
"partial_precision": 0.0,
"partial_recall": 0.0,
"partial_f1": 0,
"exact_precision": 0.0,
"exact_recall": 0.0,
"exact_f1": 0,
},
"PER": {
"strict_precision": 0.0,
"strict_recall": 0.0,
"strict_f1": 0,
"ent_type_precision": 0.5,
"ent_type_recall": 1.0,
"ent_type_f1": 2 / 3,
"partial_precision": 0.25,
"partial_recall": 0.5,
"partial_f1": 1 / 3,
"exact_precision": 0.0,
"exact_recall": 0.0,
"exact_f1": 0,
},
"LOC": {
"strict_precision": 0.0,
"strict_recall": 0.0,
"strict_f1": 0,
"ent_type_precision": 0.5,
"ent_type_recall": 1.0,
"ent_type_f1": 2 / 3,
"partial_precision": 0.25,
"partial_recall": 0.5,
"partial_f1": 1 / 3,
"exact_precision": 0.0,
"exact_recall": 0.0,
"exact_f1": 0,
}
}
"""
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class NEREval(evaluate.Metric):
    """Entity-level NER metric with strict, ent_type, partial and exact scores
    (SemEval 2013 task 9.1 schemes), reported overall and per entity type."""

    def _info(self):
        """Describe the metric: input features, citation and reference links."""
        return evaluate.MetricInfo(
            # This is the description that will appear on the modules page.
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            homepage="https://github.com/davidsbatista/NER-Evaluation",
            inputs_description=_KWARGS_DESCRIPTION,
            # Each prediction / reference is a sequence of string labels.
            features=datasets.Features(
                {
                    "predictions": datasets.Sequence(
                        datasets.Value("string", id="label"), id="sequence"
                    ),
                    "references": datasets.Sequence(
                        datasets.Value("string", id="label"), id="sequence"
                    ),
                }
            ),
            # Additional links to the codebase or references
            codebase_urls=["https://github.com/davidsbatista/NER-Evaluation"],
            reference_urls=[
                "https://github.com/davidsbatista/NER-Evaluation",
                "https://www.davidsbatista.net/blog/2018/05/09/Named_Entity_Evaluation/",
            ],
        )

    def _download_and_prepare(self, dl_manager):
        """No external resources are needed, so there is nothing to download."""
        pass

    def _compute(
        self,
        predictions: Sequence[Sequence[str]],
        references: Sequence[Sequence[str]],
        tags: Optional[Sequence[str]] = None,
        modes: Optional[Sequence[str]] = None,
    ):
        """Score ``predictions`` against ``references``.

        :param predictions: predicted tag sequences, one per document
        :param references: gold tag sequences, aligned with ``predictions``
        :param tags: entity types to score; ``None`` uses every type seen in
            either input
        :param modes: schemes to report (passed to ``parse_results``);
            ``None`` reports all four
        :return: ``{"overall": {...}, <tag>: {...}, ...}`` where each value maps
            ``<mode>_precision`` / ``<mode>_recall`` / ``<mode>_f1`` to a score
        :raises ValueError: if the number of documents, or the length of a
            document, differs between predictions and references
        """
        if tags is None:
            # Sorted so per-tag keys come out in a stable order; iteration order
            # of a set of strings can differ between interpreter runs.
            tags = sorted(parse_tags(predictions) | parse_tags(references))
        # Evaluator's signature is (true, pred, tags): gold references go first.
        # Passing them the other way round swaps precision with recall and
        # missed with spurious.
        evaluator = Evaluator(references, predictions, tags)
        results, agg_results = evaluator.evaluate()
        out = {"overall": parse_results(results, modes)}
        for tag, tag_result in agg_results.items():
            out[tag] = parse_results(tag_result, modes)
        return out
def parse_results(results, modes: Optional[Sequence[str]] = None):
    """Flatten one scheme -> metrics dict into ``{"<mode>_<metric>": value}``.

    :param results: mapping from scheme name to a dict holding at least
        ``precision``, ``recall`` and ``f1``
    :param modes: schemes to include, in output order; ``None`` selects
        strict, ent_type, partial and exact
    :return: flat dict with precision, recall and f1 for each selected scheme
    """
    selected = ("strict", "ent_type", "partial", "exact") if modes is None else modes
    return {
        f"{mode}_{metric}": results[mode][metric]
        for mode in selected
        for metric in ("precision", "recall", "f1")
    }
def parse_tags(tokens: Sequence[Sequence[str]]):
    """Collect the entity types occurring in a batch of tag sequences.

    The scheme prefix is removed by splitting on the FIRST ``-`` only, so a
    hyphenated type such as ``"B-LOC-DERIV"`` yields ``"LOC-DERIV"``. That is the
    same type ``collect_named_entities`` extracts with ``tag[2:]``; splitting on
    the last ``-`` would yield ``"DERIV"`` and those entities would be filtered
    out. The outside tag ``"O"`` is excluded.

    :param tokens: sequences of tag strings, one per document
    :return: set of entity type names
    """
    types = {tag.split("-", 1)[-1] for seq in tokens for tag in seq}
    types.discard("O")
    return types
# One entity mention: its type plus the token offsets of its first and last
# token (collect_named_entities stores both ends inclusively).
Entity = namedtuple("Entity", ["e_type", "start_offset", "end_offset"])
class Evaluator:
    """Accumulate SemEval 2013 task 9.1 entity counts over a corpus and score them.

    ``true`` and ``pred`` are parallel sequences of documents. Each document is a
    sequence of tag strings: ``"O"``, or a two-character prefix followed by the
    entity type. Only entity types listed in ``tags`` are scored.
    """

    def __init__(self, true, pred, tags):
        """Store the corpus and set up zeroed counters.

        :param true: gold tag sequences, one per document
        :param pred: predicted tag sequences, aligned with ``true``
        :param tags: entity types to evaluate
        :raises ValueError: if ``true`` and ``pred`` hold different numbers of documents
        """
        if len(true) != len(pred):
            raise ValueError("Number of predicted documents does not equal true")
        self.true = true
        self.pred = pred
        self.tags = tags
        # Counters for one evaluation scheme; precision/recall/f1 are derived
        # from the other counters at the end of evaluate().
        self.metrics_results = {
            "correct": 0,
            "incorrect": 0,
            "partial": 0,
            "missed": 0,
            "spurious": 0,
            "possible": 0,
            "actual": 0,
            "precision": 0,
            "recall": 0,
            "f1": 0,
        }
        # One independent copy of the counters per scheme.
        self.results = {
            "strict": deepcopy(self.metrics_results),
            "ent_type": deepcopy(self.metrics_results),
            "partial": deepcopy(self.metrics_results),
            "exact": deepcopy(self.metrics_results),
        }
        # The same per-scheme structure, once per entity type.
        self.evaluation_agg_entities_type = {e: deepcopy(self.results) for e in tags}

    @staticmethod
    def _accumulate(target, source):
        """Add each ``source[scheme][metric]`` into ``target[scheme][metric]`` in place."""
        for scheme, counters in target.items():
            for metric in counters:
                counters[metric] += source[scheme][metric]

    def evaluate(self):
        """Score every document and return overall and per-type results.

        :return: ``(results, per_type_results)``. ``results`` maps each scheme to
            its counters plus ``precision``/``recall``/``f1``; ``per_type_results``
            maps each tag to a dict of the same shape.
        :raises ValueError: if a predicted document's length differs from its
            gold document's length
        """
        for true_ents, pred_ents in zip(self.true, self.pred):
            # Checked here because a length mismatch would otherwise not be
            # reported by any later step.
            if len(true_ents) != len(pred_ents):
                raise ValueError("Prediction length does not match true example length")
            # Compute counts for one document, then fold them into the totals.
            tmp_results, tmp_agg_results = compute_metrics(
                collect_named_entities(true_ents),
                collect_named_entities(pred_ents),
                self.tags,
            )
            self._accumulate(self.results, tmp_results)
            for e_type in self.tags:
                self._accumulate(
                    self.evaluation_agg_entities_type[e_type], tmp_agg_results[e_type]
                )
        # Scores depend only on the accumulated counts, so derive them once here
        # rather than after every document.
        self.results = compute_precision_recall_f1_wrapper(self.results)
        for e_type in self.tags:
            self.evaluation_agg_entities_type[e_type] = compute_precision_recall_f1_wrapper(
                self.evaluation_agg_entities_type[e_type]
            )
        return self.results, self.evaluation_agg_entities_type
def collect_named_entities(tokens):
    """Group a sequence of tags into Entity tuples.

    A tag is either ``"O"`` or a two-character prefix followed by the entity
    type (read as ``tag[2:]``). A non-``"O"`` tag with no open entity starts
    one. An open entity is closed by an ``"O"``, by a change of type, or by a
    tag beginning with ``"B"``; the latter two also start a new entity. Offsets
    are token indices and both ends are inclusive.

    :param tokens: sequence of tag strings
    :return: list of ``Entity(e_type, start_offset, end_offset)`` in order
    """
    entities = []
    open_type = None
    open_start = None
    for position, tag in enumerate(tokens):
        if tag == "O":
            # Outside tag: close whatever entity is currently open.
            if open_type is not None and open_start is not None:
                entities.append(Entity(open_type, open_start, position - 1))
            open_type = open_start = None
            continue
        label = tag[2:]
        if open_type is None:
            open_type, open_start = label, position
        elif label != open_type or tag[:1] == "B":
            # Type change or explicit B- tag: close the open entity, start another.
            entities.append(Entity(open_type, open_start, position - 1))
            open_type, open_start = label, position
    # An entity still open after the loop runs up to the last token.
    if open_type is not None and open_start is not None:
        entities.append(Entity(open_type, open_start, len(tokens) - 1))
    return entities
# For each matching outcome, the counter it increments in every scheme. The
# scenario numbering follows
# http://www.davidsbatista.net/blog/2018/05/09/Named_Entity_Evaluation/
# Scenario I: same boundaries, same type.
_EXACT_MATCH = {
    "strict": "correct",
    "ent_type": "correct",
    "partial": "correct",
    "exact": "correct",
}
# Scenario IV: same boundaries, different type.
_SAME_SPAN_WRONG_TYPE = {
    "strict": "incorrect",
    "ent_type": "incorrect",
    "partial": "correct",
    "exact": "correct",
}
# Scenario V: boundaries overlap but differ, same type.
_OVERLAP_SAME_TYPE = {
    "strict": "incorrect",
    "ent_type": "correct",
    "partial": "partial",
    "exact": "incorrect",
}
# Scenario VI: boundaries overlap but differ, different type.
_OVERLAP_WRONG_TYPE = {
    "strict": "incorrect",
    "ent_type": "incorrect",
    "partial": "partial",
    "exact": "incorrect",
}
# Scenario II: a prediction that overlaps no gold entity.
_SPURIOUS = {scheme: "spurious" for scheme in ("strict", "ent_type", "partial", "exact")}
# Scenario III: a gold entity that no prediction overlaps.
_MISSED = {scheme: "missed" for scheme in ("strict", "ent_type", "partial", "exact")}


def _record_outcome(counts, outcome):
    """Increment ``counts[scheme][outcome[scheme]]`` by one for each scheme in ``outcome``."""
    for scheme, counter in outcome.items():
        counts[scheme][counter] += 1


def compute_metrics(true_named_entities, pred_named_entities, tags):
    """Count SemEval 2013 task 9.1 outcomes for a single document.

    :param true_named_entities: gold entities with ``e_type``, ``start_offset``
        and ``end_offset`` attributes, both offsets inclusive (as produced by
        ``collect_named_entities``)
    :param pred_named_entities: predicted entities in the same form
    :param tags: entity types to evaluate; entities of any other type are
        removed from both sides before matching
    :return: ``(evaluation, evaluation_agg_entities_type)``: counters per scheme
        for the whole document, and the same counters per entity type. Every
        scheme dict also carries ``possible`` and ``actual``. A spurious
        prediction is charged to every type in ``tags``, so per-type counts need
        not sum to the overall counts.
    """
    eval_metrics = {
        "correct": 0,
        "incorrect": 0,
        "partial": 0,
        "missed": 0,
        "spurious": 0,
        "precision": 0,
        "recall": 0,
        "f1": 0,
    }
    # overall results
    evaluation = {
        "strict": deepcopy(eval_metrics),
        "ent_type": deepcopy(eval_metrics),
        "partial": deepcopy(eval_metrics),
        "exact": deepcopy(eval_metrics),
    }
    # results by entity type
    evaluation_agg_entities_type = {e: deepcopy(evaluation) for e in tags}
    # gold entities that some prediction matched or overlapped
    true_which_overlapped_with_pred = []

    # Drop entities of types we are not evaluating from BOTH sides. This covers
    # (1) a predicted type absent from the gold data and (2) a gold type the
    # model cannot predict.
    wanted = set(tags)
    true_named_entities = [ent for ent in true_named_entities if ent.e_type in wanted]
    pred_named_entities = [ent for ent in pred_named_entities if ent.e_type in wanted]

    for pred in pred_named_entities:
        # Scenario I: exact match between true and pred.
        if pred in true_named_entities:
            true_which_overlapped_with_pred.append(pred)
            _record_outcome(evaluation, _EXACT_MATCH)
            _record_outcome(evaluation_agg_entities_type[pred.e_type], _EXACT_MATCH)
            continue

        # Offsets are inclusive, so the covered tokens are range(start, end + 1).
        # Without the +1 a single-token entity has an empty range and never
        # overlaps anything, and the last token of longer entities is ignored.
        pred_range = range(pred.start_offset, pred.end_offset + 1)
        for true in true_named_entities:
            if (
                true.start_offset == pred.start_offset
                and true.end_offset == pred.end_offset
                and true.e_type != pred.e_type
            ):
                # Scenario IV: offsets match, but entity type is wrong.
                outcome = _SAME_SPAN_WRONG_TYPE
            elif find_overlap(range(true.start_offset, true.end_offset + 1), pred_range):
                # Scenario V (same type) / VI (different type): overlap with
                # differing boundaries.
                if pred.e_type == true.e_type:
                    outcome = _OVERLAP_SAME_TYPE
                else:
                    outcome = _OVERLAP_WRONG_TYPE
            else:
                continue
            # The first gold entity that matches wins; per-type counts are
            # charged to the gold entity's type.
            true_which_overlapped_with_pred.append(true)
            _record_outcome(evaluation, outcome)
            _record_outcome(evaluation_agg_entities_type[true.e_type], outcome)
            break
        else:
            # Scenario II: spurious (over-generated) entity. It is unclear which
            # type it belongs to at the tag level, so it is charged to every
            # evaluated type.
            _record_outcome(evaluation, _SPURIOUS)
            for tag in tags:
                _record_outcome(evaluation_agg_entities_type[tag], _SPURIOUS)

    # Scenario III: gold entities that no prediction touched.
    for true in true_named_entities:
        if true in true_which_overlapped_with_pred:
            continue
        _record_outcome(evaluation, _MISSED)
        _record_outcome(evaluation_agg_entities_type[true.e_type], _MISSED)

    # Add 'possible' and 'actual' (SemEval 2013 task 9.1) overall and per type.
    for eval_type in evaluation:
        evaluation[eval_type] = compute_actual_possible(evaluation[eval_type])
    for entity_level in evaluation_agg_entities_type.values():
        for eval_type in entity_level:
            entity_level[eval_type] = compute_actual_possible(entity_level[eval_type])
    return evaluation, evaluation_agg_entities_type


def find_overlap(true_range, pred_range):
    """Return the set of values present in both ranges (empty set if none).

    Examples:
    >>> find_overlap((1, 2), (2, 3))
    {2}
    >>> find_overlap((1, 2), (3, 4))
    set()
    """
    return set(true_range).intersection(pred_range)


def compute_actual_possible(results):
    """Fill in ``actual`` and ``possible`` in a per-scheme counts dict (in place) and return it.

    possible = correct + incorrect + partial + missed   (gold annotations that count)
    actual   = correct + incorrect + partial + spurious (annotations produced by the system)
    """
    shared = results["correct"] + results["incorrect"] + results["partial"]
    results["actual"] = shared + results["spurious"]
    results["possible"] = shared + results["missed"]
    return results
def compute_precision_recall_f1(results, partial_or_type=False):
    """Fill in ``precision``, ``recall`` and ``f1`` of a per-scheme counts dict.

    ``actual`` and ``possible`` must already be present. With
    ``partial_or_type=True`` each partial match earns half a point (the formula
    used for the partial and ent_type schemes); otherwise only ``correct``
    counts. A zero denominator yields 0. The dict is updated in place and
    returned.
    """
    n_actual = results["actual"]
    n_possible = results["possible"]
    n_partial = results["partial"]
    n_correct = results["correct"]
    score = n_correct + 0.5 * n_partial if partial_or_type else n_correct
    precision = score / n_actual if n_actual > 0 else 0
    recall = score / n_possible if n_possible > 0 else 0
    results["precision"] = precision
    results["recall"] = recall
    if precision + recall > 0:
        results["f1"] = precision * recall * 2 / (precision + recall)
    else:
        results["f1"] = 0
    return results


def compute_precision_recall_f1_wrapper(results):
    """Apply compute_precision_recall_f1 to every scheme of a results dict.

    ``partial`` and ``ent_type`` use the half-credit formula; ``strict`` and
    ``exact`` use the plain one. Any other key is dropped. The returned dict lists
    the half-credit schemes first, each group in the input's order.
    """
    scored = {}
    for scheme, counts in results.items():
        if scheme in ("partial", "ent_type"):
            scored[scheme] = compute_precision_recall_f1(counts, True)
    for scheme, counts in results.items():
        if scheme in ("strict", "exact"):
            scored[scheme] = compute_precision_recall_f1(counts)
    return scored