Refactor calculate_hierarchical_precision_recall to use weighted sums
Browse files
ham.py
CHANGED
@@ -46,36 +46,53 @@ def extend_with_ancestors(classes: set, hierarchy: dict) -> set:
|
|
46 |
def calculate_hierarchical_precision_recall(
|
47 |
reference_codes: List[str],
|
48 |
predicted_codes: List[str],
|
49 |
-
hierarchy: Dict[str,
|
50 |
) -> Tuple[float, float]:
|
51 |
"""
|
52 |
Calculates the hierarchical precision and recall given the reference codes, predicted codes, and hierarchy definition.
|
53 |
|
54 |
Args:
|
55 |
-
|
56 |
predicted_codes (List[str]): The list of predicted codes.
|
57 |
hierarchy (Dict[str, Set[str]]): The hierarchy definition where keys are nodes and values are sets of parent nodes.
|
58 |
|
59 |
Returns:
|
60 |
Tuple[float, float]: A tuple containing the hierarchical precision and recall floating point values.
|
61 |
"""
|
62 |
-
|
63 |
-
extended_real = set()
|
64 |
-
for code in reference_codes:
|
65 |
-
extended_real.add(code)
|
66 |
-
extended_real.update(hierarchy.get(code, set()))
|
67 |
|
68 |
-
|
69 |
-
for code in
|
70 |
-
|
71 |
-
|
|
|
|
|
|
|
|
|
72 |
|
73 |
-
|
74 |
-
correct_predictions = extended_real.intersection(extended_predicted)
|
75 |
|
76 |
-
#
|
77 |
-
|
78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
|
80 |
return hP, hR
|
81 |
|
|
|
46 |
def calculate_hierarchical_precision_recall(
    reference_codes: List[str],
    predicted_codes: List[str],
    hierarchy: Dict[str, Dict[str, float]],
) -> Tuple[float, float]:
    """
    Calculates the weighted hierarchical precision and recall.

    Both code lists are first extended with their ancestors. Every listed
    code counts with weight 1.0 and every ancestor with the weight given in
    ``hierarchy``. The overlap between the two extended mappings is the sum,
    over codes present in both, of the smaller of the two weights.

    Args:
        reference_codes (List[str]): The list of reference codes.
        predicted_codes (List[str]): The list of predicted codes.
        hierarchy (Dict[str, Dict[str, float]]): The hierarchy definition
            where keys are codes and values map each ancestor of that code
            to its weight. Codes missing from the mapping have no ancestors.

    Returns:
        Tuple[float, float]: The hierarchical precision (overlap divided by
        the total predicted weight) and recall (overlap divided by the total
        reference weight). A value is 0.0 when its denominator is zero.
    """
    extended_real = _extend_with_ancestor_weights(reference_codes, hierarchy)
    extended_predicted = _extend_with_ancestor_weights(predicted_codes, hierarchy)

    # A shared code only earns the weight that both sides agree on.
    correct_weights = sum(
        min(weight, extended_real[code])
        for code, weight in extended_predicted.items()
        if code in extended_real
    )

    total_predicted_weights = sum(extended_predicted.values())
    total_real_weights = sum(extended_real.values())

    # Return float 0.0 (not int 0) on an empty denominator to honour the
    # declared Tuple[float, float] return type.
    hP = correct_weights / total_predicted_weights if total_predicted_weights else 0.0
    hR = correct_weights / total_real_weights if total_real_weights else 0.0

    return hP, hR


def _extend_with_ancestor_weights(
    codes: List[str],
    hierarchy: Dict[str, Dict[str, float]],
) -> Dict[str, float]:
    """Map each code and each of its ancestors to the highest weight seen."""
    extended: Dict[str, float] = {}
    for code in codes:
        # Use max() here as well: a plain assignment would overwrite a larger
        # weight already recorded for this code as somebody's ancestor, which
        # made the result depend on the order of ``codes``.
        extended[code] = max(extended.get(code, 0.0), 1.0)
        for ancestor, ancestor_weight in hierarchy.get(code, {}).items():
            extended[ancestor] = max(extended.get(ancestor, 0.0), ancestor_weight)
    return extended
|
98 |
|