|
import math

import evaluate
import pytest
|
|
|
# Metric under test. It is loaded once at import time from the script file
# named below and shared by every parametrized test case.
# NOTE(review): the relative path is presumably resolved against the current
# working directory -- confirm how the test suite is launched.
_METRIC_SCRIPT = "ner_eval.py"
ner_eval = evaluate.load(_METRIC_SCRIPT)
|
|
|
# (precision, recall, f1) triples reused by several fixtures below.
_PERFECT = (1.0, 1.0, 1.0)
# The f1 entry is deliberately the int 0, mirroring the original fixture data.
_ZERO = (0.0, 0.0, 0)


def _scores(strict, ent_type, partial, exact):
    """Build one metric table from four (precision, recall, f1) triples.

    The returned dict is freshly created on every call and lists its keys in
    the order ``<scheme>_precision``, ``<scheme>_recall``, ``<scheme>_f1`` for
    the schemes strict, ent_type, partial and exact (in that order).
    """
    table = {}
    schemes = (
        ("strict", strict),
        ("ent_type", ent_type),
        ("partial", partial),
        ("exact", exact),
    )
    for scheme, triple in schemes:
        for metric, value in zip(("precision", "recall", "f1"), triple):
            table[f"{scheme}_{metric}"] = value
    return table


def _uniform(triple):
    """Return a metric table in which every scheme uses the same triple."""
    return _scores(triple, triple, triple, triple)


# Each case holds one tag sequence of predictions and one of references.
# A case with a "results" entry lists the expected metric tables, keyed by
# "overall" and by entity type; a case without one is expected to make the
# metric raise ValueError (see test_metric).
test_cases = [
    # Predictions identical to references, three entity types.
    {
        "predictions": ["B-PER", "I-PER", "O", "B-LOC", "I-LOC", "O", "O", "B-ORG"],
        "references": ["B-PER", "I-PER", "O", "B-LOC", "I-LOC", "O", "O", "B-ORG"],
        "results": {
            "overall": _uniform(_PERFECT),
            "LOC": _uniform(_PERFECT),
            "PER": _uniform(_PERFECT),
            "ORG": _uniform(_PERFECT),
        },
    },
    # Predictions identical to references, two entity types.
    {
        "predictions": [
            "B-LOC", "I-LOC", "O", "B-PER", "I-PER",
            "I-PER", "I-PER", "O", "B-LOC", "O",
        ],
        "references": [
            "B-LOC", "I-LOC", "O", "B-PER", "I-PER",
            "I-PER", "I-PER", "O", "B-LOC", "O",
        ],
        "results": {
            "overall": _uniform(_PERFECT),
            "LOC": _uniform(_PERFECT),
            "PER": _uniform(_PERFECT),
        },
    },
    # Sequences of different length (7 vs 8 tags); no "results" entry.
    {
        "predictions": ["O", "B-LOC", "I-LOC", "B-PER", "I-PER", "O", "B-ORG"],
        "references": ["O", "B-LOC", "I-LOC", "O", "B-PER", "I-PER", "O", "B-ORG"],
    },
    # Every predicted span boundary differs from the reference; all scores zero.
    {
        "predictions": ["B-PER", "O", "B-LOC", "I-LOC", "O", "B-ORG", "I-ORG"],
        "references": ["B-PER", "I-PER", "O", "B-LOC", "I-LOC", "O", "B-ORG"],
        "results": {
            "overall": _uniform(_ZERO),
            "ORG": _uniform(_ZERO),
            "PER": _uniform(_ZERO),
            "LOC": _uniform(_ZERO),
        },
    },
    # Mixed outcome: non-zero ent_type and partial scores, zero strict/exact.
    {
        "predictions": [
            "B-LOC", "I-LOC", "I-LOC", "B-ORG", "I-ORG",
            "O", "B-PER", "I-PER", "I-PER", "O",
        ],
        "references": [
            "B-LOC", "I-LOC", "O", "O", "B-ORG",
            "I-ORG", "O", "B-PER", "I-PER", "O",
        ],
        "results": {
            "overall": _scores(
                strict=_ZERO,
                ent_type=(2 / 3, 2 / 3, 2 / 3),
                partial=(1 / 3, 1 / 3, 1 / 3),
                exact=_ZERO,
            ),
            "ORG": _uniform(_ZERO),
            "PER": _scores(
                strict=_ZERO,
                ent_type=(0.5, 1.0, 2 / 3),
                partial=(0.25, 0.5, 1 / 3),
                exact=_ZERO,
            ),
            "LOC": _scores(
                strict=_ZERO,
                ent_type=(0.5, 1.0, 2 / 3),
                partial=(0.25, 0.5, 1 / 3),
                exact=_ZERO,
            ),
        },
    },
]
|
|
|
|
|
def compare_results(result1, result2):
    """Recursively check that ``result1`` agrees with ``result2``.

    Args:
        result1: Value to check; may be a dict, a list, or a leaf value.
        result2: Value to check against, expected to have the same nesting.

    Returns:
        True when every element reachable from ``result1`` matches the
        corresponding element of ``result2``; False otherwise.

    Comparison rules:
        * dict: every key of ``result1`` must exist in ``result2`` and the two
          values must match recursively.  Keys that exist only in ``result2``
          are ignored.
        * list: both sequences must have the same length and match pairwise.
        * numbers (int/float, excluding bool): compared with a small
          tolerance, so that values such as ``2 / 3`` are not rejected
          because of floating-point rounding.
        * anything else: compared with ``==``.

    A structural mismatch (missing key, wrong container type, different list
    length) yields False rather than raising.
    """
    if isinstance(result1, dict):
        if not isinstance(result2, dict):
            return False
        return all(
            key in result2 and compare_results(value, result2[key])
            for key, value in result1.items()
        )

    if isinstance(result1, list):
        if not isinstance(result2, (list, tuple)):
            return False
        # zip() alone would silently drop the tail of the longer sequence.
        if len(result1) != len(result2):
            return False
        return all(
            compare_results(left, right) for left, right in zip(result1, result2)
        )

    both_numeric = all(
        isinstance(value, (int, float)) and not isinstance(value, bool)
        for value in (result1, result2)
    )
    if both_numeric:
        # abs_tol keeps comparisons against an expected 0 meaningful.
        return math.isclose(result1, result2, rel_tol=1e-9, abs_tol=1e-12)

    return result1 == result2
|
|
|
|
|
@pytest.mark.parametrize("case", test_cases)
def test_metric(case):
    """Run the metric on one fixture and check the outcome.

    A fixture that carries a "results" entry must produce matching scores;
    a fixture without one must make ``ner_eval.compute`` raise ValueError.
    """
    # The metric is called with a batch containing exactly one sequence.
    batch = {
        "predictions": [case["predictions"]],
        "references": [case["references"]],
    }

    if "results" in case:
        computed = ner_eval.compute(**batch)
        assert compare_results(computed, case["results"])
        return

    with pytest.raises(ValueError):
        ner_eval.compute(**batch)
|
|