# ner_eval/tests/test_ner_eval.py
# Source: commit d93bc17 ("initial commit") by fschlatt.
import math
import numbers
from pathlib import Path

import evaluate
import pytest
# Load the metric script under test. Prefer the copy one level above this
# test file (presumably the repository root, since this file lives in tests/),
# so pytest works regardless of the working directory it is launched from;
# otherwise fall back to the original CWD-relative lookup.
_METRIC_SCRIPT = Path(__file__).resolve().parent.parent / "ner_eval.py"
ner_eval = evaluate.load(
    str(_METRIC_SCRIPT) if _METRIC_SCRIPT.is_file() else "ner_eval.py"
)
# Scoring schemes reported by ner_eval; each contributes *_precision,
# *_recall and *_f1 keys, in this order, to every score dict below.
_SCHEMES = ("strict", "ent_type", "partial", "exact")
# (precision, recall, f1) triples reused across the expectations. The zero
# triple keeps f1 as the int 0, exactly as the expectations were written.
_PERFECT = (1.0, 1.0, 1.0)
_ZERO = (0.0, 0.0, 0)


def _scheme_scores(**triples):
    """Flatten ``{scheme: (precision, recall, f1)}`` into metric-name keys.

    Keys come out in ``_SCHEMES`` order: ``strict_precision``,
    ``strict_recall``, ``strict_f1``, ``ent_type_precision``, ...
    """
    scores = {}
    for scheme in _SCHEMES:
        precision, recall, f1 = triples[scheme]
        scores[f"{scheme}_precision"] = precision
        scores[f"{scheme}_recall"] = recall
        scores[f"{scheme}_f1"] = f1
    return scores


def _same_for_all(triple):
    """Return a fresh score dict in which every scheme has ``triple``."""
    return _scheme_scores(
        strict=triple, ent_type=triple, partial=triple, exact=triple
    )


# Each case holds one predicted/reference pair of B-/I-/O tag sequences
# (written whitespace-separated and split into lists) plus the expected
# ``ner_eval.compute`` output. Cases without a "results" key are expected
# to make compute() raise a ValueError (see test_metric).
test_cases = [
    # Identical sequences: every scheme is perfect for every entity type.
    {
        "predictions": "B-PER I-PER O B-LOC I-LOC O O B-ORG".split(),
        "references": "B-PER I-PER O B-LOC I-LOC O O B-ORG".split(),
        "results": {
            "overall": _same_for_all(_PERFECT),
            "LOC": _same_for_all(_PERFECT),
            "PER": _same_for_all(_PERFECT),
            "ORG": _same_for_all(_PERFECT),
        },
    },
    # Identical sequences with a four-token PER span and no ORG entity.
    {
        "predictions": "B-LOC I-LOC O B-PER I-PER I-PER I-PER O B-LOC O".split(),
        "references": "B-LOC I-LOC O B-PER I-PER I-PER I-PER O B-LOC O".split(),
        "results": {
            "overall": _same_for_all(_PERFECT),
            "LOC": _same_for_all(_PERFECT),
            "PER": _same_for_all(_PERFECT),
        },
    },
    # No expected results: predictions has 7 tags and references has 8, so
    # compute() is presumably rejecting the length mismatch.
    {
        "predictions": "O B-LOC I-LOC B-PER I-PER O B-ORG".split(),
        "references": "O B-LOC I-LOC O B-PER I-PER O B-ORG".split(),
    },
    # Every predicted span has different boundaries from its gold span.
    # NOTE(review): each predicted span still overlaps a gold span of the
    # same type (PER at 0, LOC at 3, ORG at 6), yet ent_type/partial are
    # expected to be 0 -- confirm this matches the metric's matching rules.
    {
        "predictions": "B-PER O B-LOC I-LOC O B-ORG I-ORG".split(),
        "references": "B-PER I-PER O B-LOC I-LOC O B-ORG".split(),
        "results": {
            "overall": _same_for_all(_ZERO),
            "ORG": _same_for_all(_ZERO),
            "PER": _same_for_all(_ZERO),
            "LOC": _same_for_all(_ZERO),
        },
    },
    # Boundary errors on every entity; overall, only ent_type and partial
    # are expected to be non-zero.
    {
        "predictions": "B-LOC I-LOC I-LOC B-ORG I-ORG O B-PER I-PER I-PER O".split(),
        "references": "B-LOC I-LOC O O B-ORG I-ORG O B-PER I-PER O".split(),
        "results": {
            "overall": _scheme_scores(
                strict=_ZERO,
                ent_type=(2 / 3, 2 / 3, 2 / 3),
                partial=(1 / 3, 1 / 3, 1 / 3),
                exact=_ZERO,
            ),
            "ORG": _same_for_all(_ZERO),
            "PER": _scheme_scores(
                strict=_ZERO,
                ent_type=(0.5, 1.0, 2 / 3),
                partial=(0.25, 0.5, 1 / 3),
                exact=_ZERO,
            ),
            "LOC": _scheme_scores(
                strict=_ZERO,
                ent_type=(0.5, 1.0, 2 / 3),
                partial=(0.25, 0.5, 1 / 3),
                exact=_ZERO,
            ),
        },
    },
]
def compare_results(result1, result2):
    """Recursively check whether two (possibly nested) metric results match.

    Dictionaries must have the same set of keys and lists/tuples the same
    length; their contents are compared recursively. Real numbers are
    compared with ``math.isclose`` so that values such as ``2 / 3`` reached
    through different arithmetic still match. Anything else uses ``==``.

    Args:
        result1: Value produced by the metric (dict, list, number, ...).
        result2: Expected value to compare against.

    Returns:
        True if both structures match, False otherwise.
    """
    if isinstance(result1, dict):
        # Check key sets on both sides: iterating only one side would
        # silently ignore expected entries the metric never produced.
        if not isinstance(result2, dict) or result1.keys() != result2.keys():
            return False
        return all(compare_results(result1[key], result2[key]) for key in result1)
    if isinstance(result1, (list, tuple)):
        # zip() silently truncates, so compare lengths explicitly first.
        if not isinstance(result2, (list, tuple)) or len(result1) != len(result2):
            return False
        return all(compare_results(a, b) for a, b in zip(result1, result2))
    if isinstance(result1, numbers.Real) and isinstance(result2, numbers.Real):
        # Tolerant float comparison; abs_tol lets exact zeros compare sanely.
        return math.isclose(result1, result2, rel_tol=1e-9, abs_tol=1e-12)
    return result1 == result2
@pytest.mark.parametrize("case", test_cases)
def test_metric(case):
    """Run ner_eval on one case and check the outcome.

    A case without a ``"results"`` entry must make ``ner_eval.compute``
    raise ``ValueError``; otherwise the computed results must match
    ``case["results"]`` according to ``compare_results``.
    """
    # The single prediction/reference pair is passed as a one-element list
    # (presumably a batch of one sequence).
    inputs = {
        "predictions": [case["predictions"]],
        "references": [case["references"]],
    }
    if "results" not in case:
        with pytest.raises(ValueError):
            ner_eval.compute(**inputs)
        return
    expected = case["results"]
    results = ner_eval.compute(**inputs)
    assert compare_results(results, expected), (
        f"got {results!r}, expected {expected!r}"
    )