"""Parametrized tests for the ``ner_eval`` metric loaded via ``evaluate.load``."""

import math
import numbers

import evaluate
import pytest

ner_eval = evaluate.load("ner_eval.py")

# The four evaluation schemas whose scores appear in every expected result dict.
_SCHEMAS = ("strict", "ent_type", "partial", "exact")
_STATS = ("precision", "recall", "f1")


def _scores(strict, ent_type, partial, exact):
    """Build one expected score dict from per-schema ``(precision, recall, f1)`` triples.

    Produces the twelve keys ``<schema>_<stat>`` (e.g. ``"ent_type_f1"``) that the
    metric reports for the overall result and for each entity type.
    """
    triples = dict(zip(_SCHEMAS, (strict, ent_type, partial, exact)))
    return {
        f"{schema}_{stat}": value
        for schema, triple in triples.items()
        for stat, value in zip(_STATS, triple)
    }


def _uniform_scores(triple):
    """Return a score dict in which every schema has the same ``(precision, recall, f1)``."""
    return _scores(triple, triple, triple, triple)


PERFECT = (1.0, 1.0, 1.0)
ZERO = (0.0, 0.0, 0)

# Each case feeds one tag sequence pair to the metric. Cases without a "results"
# key are expected to make ``compute`` raise ValueError.
test_cases = [
    {
        "predictions": ["B-PER", "I-PER", "O", "B-LOC", "I-LOC", "O", "O", "B-ORG"],
        "references": ["B-PER", "I-PER", "O", "B-LOC", "I-LOC", "O", "O", "B-ORG"],
        "results": {
            "overall": _uniform_scores(PERFECT),
            "LOC": _uniform_scores(PERFECT),
            "PER": _uniform_scores(PERFECT),
            "ORG": _uniform_scores(PERFECT),
        },
    },
    {
        "predictions": [
            "B-LOC",
            "I-LOC",
            "O",
            "B-PER",
            "I-PER",
            "I-PER",
            "I-PER",
            "O",
            "B-LOC",
            "O",
        ],
        "references": [
            "B-LOC",
            "I-LOC",
            "O",
            "B-PER",
            "I-PER",
            "I-PER",
            "I-PER",
            "O",
            "B-LOC",
            "O",
        ],
        "results": {
            "overall": _uniform_scores(PERFECT),
            "LOC": _uniform_scores(PERFECT),
            "PER": _uniform_scores(PERFECT),
        },
    },
    {
        # 7 predicted tags vs 8 reference tags: no "results", so ValueError is expected.
        "predictions": ["O", "B-LOC", "I-LOC", "B-PER", "I-PER", "O", "B-ORG"],
        "references": ["O", "B-LOC", "I-LOC", "O", "B-PER", "I-PER", "O", "B-ORG"],
    },
    {
        "predictions": ["B-PER", "O", "B-LOC", "I-LOC", "O", "B-ORG", "I-ORG"],
        "references": ["B-PER", "I-PER", "O", "B-LOC", "I-LOC", "O", "B-ORG"],
        "results": {
            "overall": _uniform_scores(ZERO),
            "ORG": _uniform_scores(ZERO),
            "PER": _uniform_scores(ZERO),
            "LOC": _uniform_scores(ZERO),
        },
    },
    {
        "predictions": [
            "B-LOC",
            "I-LOC",
            "I-LOC",
            "B-ORG",
            "I-ORG",
            "O",
            "B-PER",
            "I-PER",
            "I-PER",
            "O",
        ],
        "references": [
            "B-LOC",
            "I-LOC",
            "O",
            "O",
            "B-ORG",
            "I-ORG",
            "O",
            "B-PER",
            "I-PER",
            "O",
        ],
        "results": {
            "overall": _scores(
                strict=ZERO,
                ent_type=(2 / 3, 2 / 3, 2 / 3),
                partial=(1 / 3, 1 / 3, 1 / 3),
                exact=ZERO,
            ),
            "ORG": _uniform_scores(ZERO),
            "PER": _scores(
                strict=ZERO,
                ent_type=(0.5, 1.0, 2 / 3),
                partial=(0.25, 0.5, 1 / 3),
                exact=ZERO,
            ),
            "LOC": _scores(
                strict=ZERO,
                ent_type=(0.5, 1.0, 2 / 3),
                partial=(0.25, 0.5, 1 / 3),
                exact=ZERO,
            ),
        },
    },
]


def _is_number(value):
    """Return True for real numbers (int/float and numeric scalar types registered with ``numbers``)."""
    return isinstance(value, numbers.Real) and not isinstance(value, bool)


def compare_results(result1, result2, rel_tol=1e-9, abs_tol=1e-9):
    """Recursively check that two nested metric results are equal.

    Args:
        result1: The actual value (e.g. the dict returned by ``ner_eval.compute``).
        result2: The expected value.
        rel_tol: Relative tolerance used when comparing numbers.
        abs_tol: Absolute tolerance used when comparing numbers (needed near 0).

    Returns:
        True when both structures match. Dicts must have exactly the same keys,
        lists/tuples must have the same length, and numbers must agree within
        tolerance. Anything else is compared with ``==``.
    """
    if isinstance(result1, dict):
        # Require identical key sets. Otherwise expected keys the metric omits
        # would never be checked, and extra keys would raise KeyError.
        if not isinstance(result2, dict) or result1.keys() != result2.keys():
            return False
        return all(
            compare_results(result1[key], result2[key], rel_tol, abs_tol)
            for key in result1
        )
    if isinstance(result1, (list, tuple)):
        # zip() silently truncates, so check the lengths first.
        if not isinstance(result2, (list, tuple)) or len(result1) != len(result2):
            return False
        return all(
            compare_results(item1, item2, rel_tol, abs_tol)
            for item1, item2 in zip(result1, result2)
        )
    if _is_number(result1) and _is_number(result2):
        # Expected values such as 2 / 3 need not be bit-identical to the computed scores.
        return math.isclose(result1, result2, rel_tol=rel_tol, abs_tol=abs_tol)
    return result1 == result2


@pytest.mark.parametrize("case", test_cases)
def test_metric(case):
    """Run ``ner_eval.compute`` on one case and check the result or the expected ValueError."""
    predictions = [case["predictions"]]
    references = [case["references"]]
    if "results" not in case:
        with pytest.raises(ValueError):
            ner_eval.compute(predictions=predictions, references=references)
        return
    results = ner_eval.compute(predictions=predictions, references=references)
    assert compare_results(results, case["results"])