File size: 3,626 Bytes
9c6b911
8a4a728
fcd15ea
 
 
8a4a728
020c3bd
8a4a728
020c3bd
 
8a4a728
 
020c3bd
 
8a4a728
020c3bd
fcd15ea
 
 
 
 
 
 
 
 
 
 
8a4a728
 
 
 
 
 
 
 
 
 
 
fcd15ea
 
 
 
 
 
 
 
8a4a728
 
 
fcd15ea
8a4a728
 
 
 
 
 
 
 
 
 
 
fcd15ea
 
8a4a728
fcd15ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8b7a053
d1fbaa3
 
 
 
 
 
 
 
 
 
 
8b7a053
 
 
9418c93
 
 
9c6b911
9418c93
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""This module provides functions for calculating hierarchical variants of precision, recall and F1."""

from typing import List, Set, Dict, Tuple


def find_ancestors(node: str, hierarchy: dict) -> set:
    """
    Find all (transitive) ancestors of a given node in a hierarchy.

    The hierarchy is walked depth-first starting from ``node``. Each ancestor is
    expanded at most once, so shared ancestors (diamond shapes) are not re-walked
    and a cyclic hierarchy terminates instead of looping forever.

    Args:
        node (str): The node for which to find ancestors.
        hierarchy (dict): A dictionary representing the hierarchy, where the keys are nodes
            and the values are iterables of their parent nodes. Nodes missing from the
            dictionary are treated as having no parents.

    Returns:
        set: A set of ancestors of the given node. The node itself is included only if
        it is reachable from itself through a cycle in ``hierarchy``.
    """
    ancestors = set()
    nodes_to_visit = [node]
    while nodes_to_visit:
        current_node = nodes_to_visit.pop()
        for parent in hierarchy.get(current_node, ()):
            # Only expand an ancestor the first time it is seen; this both avoids
            # redundant work on DAGs and guarantees termination on cycles.
            if parent not in ancestors:
                ancestors.add(parent)
                nodes_to_visit.append(parent)
    return ancestors


def extend_with_ancestors(classes: set, hierarchy: dict) -> set:
    """
    Return a new set containing the given classes plus all of their ancestors.

    Args:
        classes (set): The classes to start from. The input is copied, not modified.
        hierarchy (dict): The hierarchy of classes, in the form accepted by ``find_ancestors``.

    Returns:
        set: A new set holding every class from ``classes`` together with each class's
        ancestors as reported by ``find_ancestors``.
    """
    result = set(classes)
    lineages = (find_ancestors(item, hierarchy) for item in classes)
    for lineage in lineages:
        result |= lineage
    return result


def calculate_hierarchical_precision_recall(
    reference_codes: List[str],
    predicted_codes: List[str],
    hierarchy: Dict[str, Set[str]],
) -> Tuple[float, float]:
    """
    Calculate hierarchical precision and recall of predicted codes against reference codes.

    Both code collections are augmented with all of their ancestors from ``hierarchy``.
    Ancestors are followed transitively, so ``hierarchy`` may list either direct parents
    only or complete ancestor sets. Precision is the fraction of the augmented predicted
    set that also occurs in the augmented reference set. Recall is the fraction of the
    augmented reference set that also occurs in the augmented predicted set.

    Args:
        reference_codes (List[str]): The list of reference (ground-truth) codes.
        predicted_codes (List[str]): The list of predicted codes.
        hierarchy (Dict[str, Set[str]]): The hierarchy definition where keys are nodes and
            values are sets of parent nodes.

    Returns:
        Tuple[float, float]: A tuple ``(hP, hR)`` with the hierarchical precision and
        recall. A value is ``0.0`` when the set it is normalised by is empty.
    """
    # Materialise the inputs as sets first: extend_with_ancestors iterates its
    # argument twice, which would silently drop ancestors for one-shot iterators.
    extended_reference = extend_with_ancestors(set(reference_codes), hierarchy)
    extended_predicted = extend_with_ancestors(set(predicted_codes), hierarchy)

    # Codes (including ancestors) shared by the reference and the prediction
    correct_predictions = extended_reference & extended_predicted

    hP = len(correct_predictions) / len(extended_predicted) if extended_predicted else 0.0
    hR = len(correct_predictions) / len(extended_reference) if extended_reference else 0.0

    return hP, hR


def hierarchical_f_measure(hP: float, hR: float, beta: float = 1.0) -> float:
    """
    Calculate the hierarchical F-measure (F-beta) from hierarchical precision and recall.

    Computes ``(beta**2 + 1) * hP * hR / (beta**2 * hP + hR)``.

    Args:
        hP (float): The hierarchical precision.
        hR (float): The hierarchical recall.
        beta (float, optional): The beta value weighting recall against precision.
            Default is 1.0 (the F1 score).

    Returns:
        float: The hierarchical F-measure, or ``0.0`` when the denominator is zero.
    """
    beta_sq = beta ** 2
    denominator = beta_sq * hP + hR
    # Guard the actual denominator. ``hP + hR`` can be non-zero while
    # ``beta**2 * hP + hR`` is zero (e.g. beta=0 and hR=0), which previously
    # raised ZeroDivisionError.
    if denominator == 0:
        return 0.0
    return (beta_sq + 1) * hP * hR / denominator


# Example usage:
# reference_codes = ["1111", "1112", "1113", "1114"]
# predicted_codes = ["1111", "1113", "1120", "1211"]
# hierarchy_dict = {'1111': {'111', '1', '11'}, '1112': {'111', '1', '11'}, '1113': {'111', '1', '11'}, '1114': {'111', '1', '11'} ...}
# result = calculate_hierarchical_precision_recall(reference_codes, predicted_codes, hierarchy_dict)
# print(result)