from collections import namedtuple |
|
from copy import deepcopy |
|
from typing import Sequence, Optional |
|
|
|
import datasets |
|
import evaluate |
|
|
|
|
|
_CITATION = """\ |
|
@misc{nereval, |
|
title={{NER-Evaluation}: Named Entity Evaluation as in SemEval 2013 task 9.1}, |
|
url={https://github.com/davidsbatista/NER-Evaluation}, |
|
note={Software available from https://github.com/davidsbatista/NER-Evaluation}, |
|
  author={Batista, David},
|
year={2018}, |
|
} |
|
""" |
|
|
|
|
|
_DESCRIPTION = """\ |
|
ner-eval is a Python framework for sequence labeling evaluation. It was used in SemEval 2013 task 9.1.

It supports strict, exact, partial, and entity-type matching and reports missed and spurious entities.
|
""" |
|
|
|
|
|
|
|
_KWARGS_DESCRIPTION = """ |
|
Computes precision, recall, and F1 for named-entity predictions against references under four matching schemes (strict, exact, partial, ent_type).
|
Args: |
|
predictions: List of List of predicted labels (Estimated targets as returned by a tagger) |
|
references: List of List of reference labels (Ground truth (correct) target values) |
|
    tags: List of entity tags to evaluate, e.g. ["ORG", "PER"]. If None, the tag set is inferred from the predictions and references. default: None
    modes: List of evaluation modes to report, any of ["strict", "ent_type", "partial", "exact"]. If None, all four modes are reported. default: None
|
Returns: |
|
    A dict with precision, recall, and F1 for each evaluation mode, reported for the overall corpus ("overall") and for each tag.
|
{ |
|
"overall": { |
|
"strict_precision": 0.0, |
|
"strict_recall": 0.0, |
|
"strict_f1": 0, |
|
"ent_type_precision": 0.0, |
|
"ent_type_recall": 0.0, |
|
"ent_type_f1": 0, |
|
"partial_precision": 0.0, |
|
"partial_recall": 0.0, |
|
"partial_f1": 0, |
|
"exact_precision": 0.0, |
|
"exact_recall": 0.0, |
|
"exact_f1": 0, |
|
}, |
|
"ORG": { |
|
"strict_precision": 0.0, |
|
"strict_recall": 0.0, |
|
"strict_f1": 0, |
|
"ent_type_precision": 0.0, |
|
"ent_type_recall": 0.0, |
|
"ent_type_f1": 0, |
|
"partial_precision": 0.0, |
|
"partial_recall": 0.0, |
|
"partial_f1": 0, |
|
"exact_precision": 0.0, |
|
"exact_recall": 0.0, |
|
"exact_f1": 0, |
|
}, |
|
"PER": { |
|
"strict_precision": 0.0, |
|
"strict_recall": 0.0, |
|
"strict_f1": 0, |
|
"ent_type_precision": 0.0, |
|
"ent_type_recall": 0.0, |
|
"ent_type_f1": 0, |
|
"partial_precision": 0.0, |
|
"partial_recall": 0.0, |
|
"partial_f1": 0, |
|
"exact_precision": 0.0, |
|
"exact_recall": 0.0, |
|
"exact_f1": 0, |
|
}, |
|
"LOC": { |
|
"strict_precision": 0.0, |
|
"strict_recall": 0.0, |
|
"strict_f1": 0, |
|
"ent_type_precision": 0.0, |
|
"ent_type_recall": 0.0, |
|
"ent_type_f1": 0, |
|
"partial_precision": 0.0, |
|
"partial_recall": 0.0, |
|
"partial_f1": 0, |
|
"exact_precision": 0.0, |
|
"exact_recall": 0.0, |
|
"exact_f1": 0, |
|
}, |
|
} |
|
Examples: |
|
    >>> ner_eval = evaluate.load("fschlatt/ner_eval")

    >>> results = ner_eval.compute(
|
... references=[["B-LOC", "I-LOC", "I-LOC", "B-ORG", "I-ORG", "O", "B-PER", "I-PER", "I-PER", "O"]], |
|
... predictions=[["B-LOC", "I-LOC", "O", "O", "B-ORG", "I-ORG", "O", "B-PER", "I-PER", "O"]] |
|
... ) |
|
>>> print(results) |
|
{ |
|
"overall": { |
|
"strict_precision": 0.0, |
|
"strict_recall": 0.0, |
|
"strict_f1": 0, |
|
"ent_type_precision": 2 / 3, |
|
"ent_type_recall": 2 / 3, |
|
"ent_type_f1": 2 / 3, |
|
"partial_precision": 1 / 3, |
|
"partial_recall": 1 / 3, |
|
"partial_f1": 1 / 3, |
|
"exact_precision": 0.0, |
|
"exact_recall": 0.0, |
|
"exact_f1": 0, |
|
}, |
|
"ORG": { |
|
"strict_precision": 0.0, |
|
"strict_recall": 0.0, |
|
"strict_f1": 0, |
|
"ent_type_precision": 0.0, |
|
"ent_type_recall": 0.0, |
|
"ent_type_f1": 0, |
|
"partial_precision": 0.0, |
|
"partial_recall": 0.0, |
|
"partial_f1": 0, |
|
"exact_precision": 0.0, |
|
"exact_recall": 0.0, |
|
"exact_f1": 0, |
|
}, |
|
"PER": { |
|
"strict_precision": 0.0, |
|
"strict_recall": 0.0, |
|
"strict_f1": 0, |
|
"ent_type_precision": 0.5, |
|
"ent_type_recall": 1.0, |
|
"ent_type_f1": 2 / 3, |
|
"partial_precision": 0.25, |
|
"partial_recall": 0.5, |
|
"partial_f1": 1 / 3, |
|
"exact_precision": 0.0, |
|
"exact_recall": 0.0, |
|
"exact_f1": 0, |
|
}, |
|
"LOC": { |
|
"strict_precision": 0.0, |
|
"strict_recall": 0.0, |
|
"strict_f1": 0, |
|
"ent_type_precision": 0.5, |
|
"ent_type_recall": 1.0, |
|
"ent_type_f1": 2 / 3, |
|
"partial_precision": 0.25, |
|
"partial_recall": 0.5, |
|
"partial_f1": 1 / 3, |
|
"exact_precision": 0.0, |
|
"exact_recall": 0.0, |
|
"exact_f1": 0, |
|
} |
|
} |
|
""" |
|
|
|
|
|
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) |
|
class NEREval(evaluate.Metric): |
|
"""TODO: Short description of my evaluation module.""" |
|
|
|
def _info(self): |
|
return evaluate.MetricInfo( |
|
|
|
module_type="metric", |
|
description=_DESCRIPTION, |
|
citation=_CITATION, |
|
homepage="https://github.com/davidsbatista/NER-Evaluation", |
|
inputs_description=_KWARGS_DESCRIPTION, |
|
|
|
features=datasets.Features( |
|
{ |
|
"predictions": datasets.Sequence( |
|
datasets.Value("string", id="label"), id="sequence" |
|
), |
|
"references": datasets.Sequence( |
|
datasets.Value("string", id="label"), id="sequence" |
|
), |
|
} |
|
), |
|
|
|
codebase_urls=["https://github.com/davidsbatista/NER-Evaluation"], |
|
reference_urls=[ |
|
"https://github.com/davidsbatista/NER-Evaluation", |
|
"https://www.davidsbatista.net/blog/2018/05/09/Named_Entity_Evaluation/", |
|
], |
|
) |
|
|
|
def _download_and_prepare(self, dl_manager): |
|
"""Optional: download external resources useful to compute the scores""" |
|
|
|
pass |
|
|
|
def _compute( |
|
self, |
|
predictions: Sequence[Sequence[str]], |
|
references: Sequence[Sequence[str]], |
|
tags: Optional[Sequence[str]] = None, |
|
modes: Optional[Sequence[str]] = None, |
|
): |
|
if tags is None: |
|
tags = list(parse_tags(predictions).union(parse_tags(references))) |
|
|
|
        evaluator = Evaluator(references, predictions, tags)
|
results, agg_results = evaluator.evaluate() |
|
|
|
out = {"overall": parse_results(results, modes)} |
|
for tag, tag_result in agg_results.items(): |
|
            out[tag] = parse_results(tag_result, modes)
|
|
|
return out |
|
|
|
|
|
def parse_results(results, modes: Optional[Sequence[str]] = None): |
|
if modes is None: |
|
modes = ["strict", "ent_type", "partial", "exact"] |
|
|
|
out = {} |
|
for mode in modes: |
|
out[f"{mode}_precision"] = results[mode]["precision"] |
|
out[f"{mode}_recall"] = results[mode]["recall"] |
|
out[f"{mode}_f1"] = results[mode]["f1"] |
|
return out |
|
|
|
|
|
def parse_tags(tokens: Sequence[Sequence[str]]): |
|
tags = set() |
|
for seq in tokens: |
|
for t in seq: |
|
tags.add(t.split("-")[-1]) |
|
tags.discard("O") |
|
return tags |
|
|
|
|
|
Entity = namedtuple("Entity", "e_type start_offset end_offset") |
|
|
|
|
|
class Evaluator: |
|
def __init__(self, true, pred, tags): |
|
""" """ |
|
|
|
if len(true) != len(pred): |
|
raise ValueError("Number of predicted documents does not equal true") |
|
|
|
self.true = true |
|
self.pred = pred |
|
self.tags = tags |
|
|
|
|
|
|
|
self.metrics_results = { |
|
"correct": 0, |
|
"incorrect": 0, |
|
"partial": 0, |
|
"missed": 0, |
|
"spurious": 0, |
|
"possible": 0, |
|
"actual": 0, |
|
"precision": 0, |
|
"recall": 0, |
|
"f1": 0, |
|
} |
|
|
|
|
|
|
|
self.results = { |
|
"strict": deepcopy(self.metrics_results), |
|
"ent_type": deepcopy(self.metrics_results), |
|
"partial": deepcopy(self.metrics_results), |
|
"exact": deepcopy(self.metrics_results), |
|
} |
|
|
|
|
|
|
|
self.evaluation_agg_entities_type = {e: deepcopy(self.results) for e in tags} |
|
|
|
def evaluate(self): |
|
for true_ents, pred_ents in zip(self.true, self.pred): |
|
|
|
|
|
|
|
|
|
if len(true_ents) != len(pred_ents): |
|
raise ValueError("Prediction length does not match true example length") |
|
|
|
|
|
|
|
tmp_results, tmp_agg_results = compute_metrics( |
|
collect_named_entities(true_ents), |
|
collect_named_entities(pred_ents), |
|
self.tags, |
|
) |
|
|
|
|
|
|
|
|
|
|
|
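            # accumulate this document's counts into the overall results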
for eval_schema in self.results: |
|
for metric in self.results[eval_schema]: |
|
self.results[eval_schema][metric] += tmp_results[eval_schema][ |
|
metric |
|
] |
|
|
|
|
|
|
|
self.results = compute_precision_recall_f1_wrapper(self.results) |
|
|
|
|
|
|
|
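            # accumulate this document's counts for each entity type and recompute precision/recall/F1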
for e_type in self.tags: |
|
for eval_schema in tmp_agg_results[e_type]: |
|
for metric in tmp_agg_results[e_type][eval_schema]: |
|
self.evaluation_agg_entities_type[e_type][eval_schema][ |
|
metric |
|
] += tmp_agg_results[e_type][eval_schema][metric] |
|
|
|
|
|
|
|
self.evaluation_agg_entities_type[ |
|
e_type |
|
] = compute_precision_recall_f1_wrapper( |
|
self.evaluation_agg_entities_type[e_type] |
|
) |
|
|
|
return self.results, self.evaluation_agg_entities_type |
|
|
|
|
|
def collect_named_entities(tokens): |
|
""" |
|
Creates a list of Entity named-tuples, storing the entity type and the start and end |
|
offsets of the entity. |
|
|
|
    :param tokens: a list of BIO tags for a single sequence
|
:return: a list of Entity named-tuples |
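
    Example (illustrative):

        >>> collect_named_entities(["B-PER", "I-PER", "O", "B-LOC"])
        [Entity(e_type='PER', start_offset=0, end_offset=1), Entity(e_type='LOC', start_offset=3, end_offset=3)]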
|
""" |
|
|
|
named_entities = [] |
|
start_offset = None |
|
end_offset = None |
|
ent_type = None |
|
|
|
for offset, token_tag in enumerate(tokens): |
|
if token_tag == "O": |
|
if ent_type is not None and start_offset is not None: |
|
end_offset = offset - 1 |
|
named_entities.append(Entity(ent_type, start_offset, end_offset)) |
|
start_offset = None |
|
end_offset = None |
|
ent_type = None |
|
|
|
elif ent_type is None: |
|
ent_type = token_tag[2:] |
|
start_offset = offset |
|
|
|
elif ent_type != token_tag[2:] or ( |
|
ent_type == token_tag[2:] and token_tag[:1] == "B" |
|
): |
|
end_offset = offset - 1 |
|
named_entities.append(Entity(ent_type, start_offset, end_offset)) |
|
|
|
|
|
ent_type = token_tag[2:] |
|
start_offset = offset |
|
end_offset = None |
|
|
|
|
|
|
|
if ent_type is not None and start_offset is not None and end_offset is None: |
|
named_entities.append(Entity(ent_type, start_offset, len(tokens) - 1)) |
|
|
|
return named_entities |
|
|
|
|
|
def compute_metrics(true_named_entities, pred_named_entities, tags): |
|
eval_metrics = { |
|
"correct": 0, |
|
"incorrect": 0, |
|
"partial": 0, |
|
"missed": 0, |
|
"spurious": 0, |
|
"precision": 0, |
|
"recall": 0, |
|
} |
|
|
|
|
|
|
|
evaluation = { |
|
"strict": deepcopy(eval_metrics), |
|
"ent_type": deepcopy(eval_metrics), |
|
"partial": deepcopy(eval_metrics), |
|
"exact": deepcopy(eval_metrics), |
|
} |
|
|
|
|
|
|
|
evaluation_agg_entities_type = {e: deepcopy(evaluation) for e in tags} |
|
|
|
|
|
|
|
true_which_overlapped_with_pred = [] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
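    # restrict evaluation to the entity types listed in tags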
true_named_entities = [ent for ent in true_named_entities if ent.e_type in tags] |
|
pred_named_entities = [ent for ent in pred_named_entities if ent.e_type in tags] |
|
|
|
|
|
|
|
for pred in pred_named_entities: |
|
found_overlap = False |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
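        # exact boundary and type match with a true entity: correct under all four schemes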
if pred in true_named_entities: |
|
true_which_overlapped_with_pred.append(pred) |
|
evaluation["strict"]["correct"] += 1 |
|
evaluation["ent_type"]["correct"] += 1 |
|
evaluation["exact"]["correct"] += 1 |
|
evaluation["partial"]["correct"] += 1 |
|
|
|
|
|
evaluation_agg_entities_type[pred.e_type]["strict"]["correct"] += 1 |
|
evaluation_agg_entities_type[pred.e_type]["ent_type"]["correct"] += 1 |
|
evaluation_agg_entities_type[pred.e_type]["exact"]["correct"] += 1 |
|
evaluation_agg_entities_type[pred.e_type]["partial"]["correct"] += 1 |
|
|
|
else: |
|
|
|
|
|
for true in true_named_entities: |
|
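                # note: range() excludes each entity's end_offset, so the final token does not take part in the overlap check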
pred_range = range(pred.start_offset, pred.end_offset) |
|
true_range = range(true.start_offset, true.end_offset) |
|
|
|
|
|
|
|
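                # boundaries match exactly but the entity type is wrong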
if ( |
|
true.start_offset == pred.start_offset |
|
and pred.end_offset == true.end_offset |
|
and true.e_type != pred.e_type |
|
): |
|
|
|
evaluation["strict"]["incorrect"] += 1 |
|
evaluation["ent_type"]["incorrect"] += 1 |
|
evaluation["partial"]["correct"] += 1 |
|
evaluation["exact"]["correct"] += 1 |
|
|
|
|
|
evaluation_agg_entities_type[true.e_type]["strict"][ |
|
"incorrect" |
|
] += 1 |
|
evaluation_agg_entities_type[true.e_type]["ent_type"][ |
|
"incorrect" |
|
] += 1 |
|
evaluation_agg_entities_type[true.e_type]["partial"]["correct"] += 1 |
|
evaluation_agg_entities_type[true.e_type]["exact"]["correct"] += 1 |
|
|
|
true_which_overlapped_with_pred.append(true) |
|
found_overlap = True |
|
|
|
break |
|
|
|
|
|
|
|
elif find_overlap(true_range, pred_range): |
|
true_which_overlapped_with_pred.append(true) |
|
|
|
|
|
|
|
|
|
|
|
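                    # partial boundary overlap with the correct entity type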
if pred.e_type == true.e_type: |
|
|
|
evaluation["strict"]["incorrect"] += 1 |
|
evaluation["ent_type"]["correct"] += 1 |
|
evaluation["partial"]["partial"] += 1 |
|
evaluation["exact"]["incorrect"] += 1 |
|
|
|
|
|
evaluation_agg_entities_type[true.e_type]["strict"][ |
|
"incorrect" |
|
] += 1 |
|
evaluation_agg_entities_type[true.e_type]["ent_type"][ |
|
"correct" |
|
] += 1 |
|
evaluation_agg_entities_type[true.e_type]["partial"][ |
|
"partial" |
|
] += 1 |
|
evaluation_agg_entities_type[true.e_type]["exact"][ |
|
"incorrect" |
|
] += 1 |
|
|
|
found_overlap = True |
|
|
|
break |
|
|
|
|
|
|
|
|
|
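                    # partial boundary overlap but the entity type is wrong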
else: |
|
|
|
evaluation["strict"]["incorrect"] += 1 |
|
evaluation["ent_type"]["incorrect"] += 1 |
|
evaluation["partial"]["partial"] += 1 |
|
evaluation["exact"]["incorrect"] += 1 |
|
|
|
|
|
|
|
|
|
evaluation_agg_entities_type[true.e_type]["strict"][ |
|
"incorrect" |
|
] += 1 |
|
evaluation_agg_entities_type[true.e_type]["partial"][ |
|
"partial" |
|
] += 1 |
|
evaluation_agg_entities_type[true.e_type]["ent_type"][ |
|
"incorrect" |
|
] += 1 |
|
evaluation_agg_entities_type[true.e_type]["exact"][ |
|
"incorrect" |
|
] += 1 |
|
|
|
|
|
|
|
|
|
|
|
found_overlap = True |
|
|
|
break |
|
|
|
|
|
|
|
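        # prediction overlaps no true entity: count it as spurious (over-generated)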
if not found_overlap: |
|
|
|
|
|
evaluation["strict"]["spurious"] += 1 |
|
evaluation["ent_type"]["spurious"] += 1 |
|
evaluation["partial"]["spurious"] += 1 |
|
evaluation["exact"]["spurious"] += 1 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
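            # a spurious prediction is counted against every evaluated entity type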
for true in tags: |
|
evaluation_agg_entities_type[true]["strict"]["spurious"] += 1 |
|
evaluation_agg_entities_type[true]["ent_type"]["spurious"] += 1 |
|
evaluation_agg_entities_type[true]["partial"]["spurious"] += 1 |
|
evaluation_agg_entities_type[true]["exact"]["spurious"] += 1 |
|
|
|
|
|
|
|
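    # true entities with no overlapping prediction are counted as missed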
for true in true_named_entities: |
|
if true in true_which_overlapped_with_pred: |
|
continue |
|
else: |
|
|
|
evaluation["strict"]["missed"] += 1 |
|
evaluation["ent_type"]["missed"] += 1 |
|
evaluation["partial"]["missed"] += 1 |
|
evaluation["exact"]["missed"] += 1 |
|
|
|
|
|
evaluation_agg_entities_type[true.e_type]["strict"]["missed"] += 1 |
|
evaluation_agg_entities_type[true.e_type]["ent_type"]["missed"] += 1 |
|
evaluation_agg_entities_type[true.e_type]["partial"]["missed"] += 1 |
|
evaluation_agg_entities_type[true.e_type]["exact"]["missed"] += 1 |
|
|
|
|
|
|
|
|
|
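    # fill in the "actual" and "possible" counts for each evaluation schema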
for eval_type in evaluation: |
|
evaluation[eval_type] = compute_actual_possible(evaluation[eval_type]) |
|
|
|
|
|
|
|
|
|
for entity_type, entity_level in evaluation_agg_entities_type.items(): |
|
|
|
|
|
|
|
for eval_type in entity_level: |
|
evaluation_agg_entities_type[entity_type][ |
|
eval_type |
|
] = compute_actual_possible(entity_level[eval_type]) |
|
|
|
return evaluation, evaluation_agg_entities_type |
|
|
|
|
|
def find_overlap(true_range, pred_range): |
|
"""Find the overlap between two ranges |
|
|
|
Find the overlap between two ranges. Return the overlapping values if |
|
present, else return an empty set(). |
|
|
|
Examples: |
|
|
|
>>> find_overlap((1, 2), (2, 3)) |
|
    {2}
|
>>> find_overlap((1, 2), (3, 4)) |
|
set() |
|
""" |
|
|
|
true_set = set(true_range) |
|
pred_set = set(pred_range) |
|
|
|
overlaps = true_set.intersection(pred_set) |
|
|
|
return overlaps |
|
|
|
|
|
def compute_actual_possible(results): |
|
""" |
|
Takes a result dict that has been output by compute metrics. |
|
Returns the results dict with actual, possible populated. |
|
|
|
When the results dicts is from partial or ent_type metrics, then |
|
partial_or_type=True to ensure the right calculation is used for |
|
calculating precision and recall. |
|
""" |
|
|
|
correct = results["correct"] |
|
incorrect = results["incorrect"] |
|
partial = results["partial"] |
|
missed = results["missed"] |
|
spurious = results["spurious"] |
|
|
|
|
|
|
|
|
|
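    # possible: annotations in the gold standard (correct + incorrect + partial + missed)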
possible = correct + incorrect + partial + missed |
|
|
|
|
|
|
|
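    # actual: annotations produced by the system (correct + incorrect + partial + spurious)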
actual = correct + incorrect + partial + spurious |
|
|
|
results["actual"] = actual |
|
results["possible"] = possible |
|
|
|
return results |
|
|
|
|
|
def compute_precision_recall_f1(results, partial_or_type=False): |
|
""" |
|
    Takes a results dict that has been output by compute_metrics.
    Returns the results dict with precision, recall, and f1 populated.
|
|
|
    When the results dict comes from the partial or ent_type evaluation, pass
    partial_or_type=True so that partial matches receive half credit when
|
calculating precision and recall. |
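
    Example (illustrative counts):

        >>> r = compute_precision_recall_f1(
        ...     {"correct": 2, "partial": 2, "actual": 5, "possible": 4}, partial_or_type=True
        ... )
        >>> r["precision"], r["recall"]
        (0.6, 0.75)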
|
""" |
|
|
|
actual = results["actual"] |
|
possible = results["possible"] |
|
partial = results["partial"] |
|
correct = results["correct"] |
|
|
|
if partial_or_type: |
|
precision = (correct + 0.5 * partial) / actual if actual > 0 else 0 |
|
recall = (correct + 0.5 * partial) / possible if possible > 0 else 0 |
|
|
|
else: |
|
precision = correct / actual if actual > 0 else 0 |
|
recall = correct / possible if possible > 0 else 0 |
|
|
|
results["precision"] = precision |
|
results["recall"] = recall |
|
results["f1"] = ( |
|
precision * recall * 2 / (precision + recall) if precision + recall > 0 else 0 |
|
) |
|
|
|
return results |
|
|
|
|
|
def compute_precision_recall_f1_wrapper(results): |
|
""" |
|
Wraps the compute_precision_recall_f1 function and runs on a dict of results |
|
""" |
|
|
|
results_a = { |
|
key: compute_precision_recall_f1(value, True) |
|
for key, value in results.items() |
|
if key in ["partial", "ent_type"] |
|
} |
|
results_b = { |
|
key: compute_precision_recall_f1(value) |
|
for key, value in results.items() |
|
if key in ["strict", "exact"] |
|
} |
|
|
|
results = {**results_a, **results_b} |
|
|
|
return results |
|
|