File size: 3,463 Bytes
fcd15ea
 
 
 
020c3bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fcd15ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
020c3bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8b7a053
 
020c3bd
 
 
8b7a053
020c3bd
 
 
8b7a053
020c3bd
8b7a053
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from typing import List, Set, Dict, Tuple


def find_ancestors1(tree, code):
    """
    Iteratively collects the ancestors of a given class (e.g., an ISCO-08 code)
    by following the "parent" links of a hierarchical JSON structure.

    Args:
    - tree: A dictionary mapping each class label to a record whose "parent"
      entry holds the label of its parent (a falsy value marks a root).
    - code: A string representing the label of the class.

    Returns:
    - A list of strings, each representing an ancestor of the input class,
      ordered from the direct parent upwards.

    Raises:
    - KeyError: If a visited label is missing from `tree`, or its record has
      no "parent" entry.
    """
    ancestors = []
    # Labels already walked. Without this guard a malformed tree containing a
    # parent cycle (or a node that is its own parent) would never terminate.
    visited = {code}
    current = code
    while current:
        parent = tree[current]["parent"]
        if not parent or parent in visited:
            break
        ancestors.append(parent)
        visited.add(parent)
        current = parent
    return ancestors


def find_ancestors(node, hierarchy):
    """
    Collects every ancestor of a node by repeatedly following parent links.

    Args:
    - node: The label of the class whose ancestors are wanted.
    - hierarchy: A dictionary mapping a label to an iterable of its parent
      labels. Labels without an entry are treated as having no parents.

    Returns:
    - A set with all labels reachable from `node` through parent links. The
      node itself is only included if the hierarchy contains a cycle that
      leads back to it.
    """
    ancestors = set()
    nodes_to_visit = [node]
    while nodes_to_visit:
        current_node = nodes_to_visit.pop()
        for parent in hierarchy.get(current_node, ()):
            # Queue each ancestor once: this keeps shared ancestors from being
            # re-walked and guarantees termination on cyclic hierarchies.
            if parent not in ancestors:
                ancestors.add(parent)
                nodes_to_visit.append(parent)
    return ancestors


def extend_with_ancestors(classes, hierarchy):
    """
    Builds the set of the given classes plus all of their ancestors.

    Args:
    - classes: An iterable of class labels.
    - hierarchy: The parent mapping handed to `find_ancestors` for each label.

    Returns:
    - A set containing every label in `classes` and every ancestor that
      `find_ancestors` reports for those labels.
    """
    result = set(classes)
    result |= {
        ancestor
        for label in classes
        for ancestor in find_ancestors(label, hierarchy)
    }
    return result


def _expand_with_ancestors(codes, hierarchy):
    """Return `codes` plus every code reachable by following `hierarchy` upwards."""
    expanded = set(codes)
    pending = list(expanded)
    while pending:
        for parent in hierarchy.get(pending.pop(), ()):
            # Queue only unseen codes so shared ancestors are walked once and
            # a cyclic hierarchy cannot loop forever.
            if parent not in expanded:
                expanded.add(parent)
                pending.append(parent)
    return expanded


def calculate_hierarchical_precision_recall(
    real_codes: List[str], predicted_codes: List[str], hierarchy: Dict[str, Set[str]]
) -> Tuple[float, float]:
    """
    Calculates hierarchical precision and recall for one set of real codes
    against one set of predicted codes.

    Both code collections are extended with all of their ancestors (not only
    the codes listed directly under each code in `hierarchy`) before being
    compared, so a prediction in the right branch earns partial credit. A
    hierarchy that already maps each code to its full ancestor set yields the
    same result as one that maps each code to its direct parents.

    Args:
    - real_codes: A list of strings representing the true class labels.
    - predicted_codes: A list of strings representing the predicted class labels.
    - hierarchy: A dictionary mapping a code to the set of codes above it.
      Codes without an entry are treated as having no ancestors.

    Returns:
    - hP: Share of the extended predicted set that is also in the extended
      real set (0.0 when nothing was predicted).
    - hR: Share of the extended real set that is also in the extended
      predicted set (0.0 when there are no real codes).
    """
    extended_real = _expand_with_ancestors(real_codes, hierarchy)
    extended_predicted = _expand_with_ancestors(predicted_codes, hierarchy)

    # Codes (including ancestors) that both sides agree on.
    correct_predictions = extended_real & extended_predicted

    hP = (
        len(correct_predictions) / len(extended_predicted)
        if extended_predicted
        else 0.0
    )
    hR = len(correct_predictions) / len(extended_real) if extended_real else 0.0

    return hP, hR


def _label_with_ancestors(tree, code):
    """Return a set holding `code` and every ancestor of it found in `tree`."""
    collected = {code}
    pending = [code]
    while pending:
        entry = tree.get(pending.pop())
        if not entry:
            continue
        if isinstance(entry, dict):
            # Node record in the layout read by find_ancestors1:
            # {"parent": <label or falsy for a root>}.
            parent = entry.get("parent")
            parents = (parent,) if parent else ()
        else:
            # Otherwise the entry is taken to be a collection of parent labels.
            parents = entry
        for parent in parents:
            # Skip labels already collected so cycles cannot loop forever.
            if parent not in collected:
                collected.add(parent)
                pending.append(parent)
    return collected


def calculate_hierarchical_measures(true_labels, predicted_labels, tree):
    """
    Calculates hierarchical precision, recall, and F-measure in a hierarchical structure.

    Each true label is paired with the predicted label at the same position.
    Both are extended with their ancestors, and the overlaps are summed over
    all pairs before the ratios are taken. Labels without a partner (when the
    two lists differ in length) add to the denominators only.

    Args:
    - true_labels: A list of strings representing true class labels.
    - predicted_labels: A list of strings representing predicted class labels.
    - tree: A dictionary representing the hierarchical structure. Each label
      maps either to a record with a "parent" entry or to a collection of
      parent labels; labels without an entry are treated as roots.

    Returns:
    - hP: A floating point number representing hierarchical precision.
    - hR: A floating point number representing hierarchical recall.
    - hF: A floating point number representing hierarchical F-measure.
    """
    extended_true = [_label_with_ancestors(tree, code) for code in true_labels]
    extended_pred = [_label_with_ancestors(tree, code) for code in predicted_labels]

    true_positive = sum(len(t & p) for t, p in zip(extended_true, extended_pred))
    predicted = sum(len(p) for p in extended_pred)
    actual = sum(len(t) for t in extended_true)

    hP = true_positive / predicted if predicted else 0.0
    hR = true_positive / actual if actual else 0.0
    hF = (2 * hP * hR) / (hP + hR) if (hP + hR) else 0.0

    return hP, hR, hF


def hierarchical_f_measure(hP, hR, beta=1.0):
    """
    Calculate the hierarchical F-measure.

    Args:
    - hP: A number representing hierarchical precision.
    - hR: A number representing hierarchical recall.
    - beta: Weight of recall relative to precision; 1.0 gives the balanced
      F1 form.

    Returns:
    - (beta^2 + 1) * hP * hR / (beta^2 * hP + hR), or 0.0 when that ratio is
      undefined because its denominator is zero.
    """
    beta_sq = beta**2
    denominator = beta_sq * hP + hR
    # The denominator can vanish even when hP + hR is non-zero
    # (e.g. beta == 0 with hR == 0), so guard the actual divisor as well.
    if hP + hR == 0 or denominator == 0:
        return 0.0
    return (beta_sq + 1) * hP * hR / denominator