File size: 4,294 Bytes
9c6b911 8a4a728 fcd15ea 8a4a728 020c3bd 8a4a728 020c3bd 8a4a728 020c3bd 8a4a728 020c3bd fcd15ea 8a4a728 fcd15ea 8a4a728 03c8589 fcd15ea 8a4a728 03c8589 8a4a728 03c8589 fcd15ea 03c8589 fcd15ea 03c8589 fcd15ea 03c8589 fcd15ea 8b7a053 d1fbaa3 8b7a053 9418c93 9c6b911 9418c93 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
"""This module provides functions for calculating hierarchical variants of precision, recall and F1."""
from collections.abc import Iterable, Mapping
from typing import List, Set, Dict, Tuple, Union
def find_ancestors(node: str, hierarchy: dict) -> set:
    """
    Find all ancestors of a node by following parent links transitively.

    Each ancestor is expanded at most once. This guarantees termination when
    the hierarchy contains cycles. It also avoids redundant work when several
    paths lead to the same ancestor.

    Args:
        node (str): The node for which to find ancestors.
        hierarchy (dict): A dictionary representing the hierarchy, where the
            keys are nodes and the values are iterables of their direct
            parents. Nodes absent from the dictionary have no parents.

    Returns:
        set: Every node reachable from ``node`` through parent links. The node
        itself is included only if it is reachable from itself (a cycle).
    """
    ancestors = set()
    nodes_to_visit = [node]
    while nodes_to_visit:
        current_node = nodes_to_visit.pop()
        for parent in hierarchy.get(current_node, ()):
            # Only schedule a parent the first time it is seen; without this
            # check a cyclic hierarchy would loop forever.
            if parent not in ancestors:
                ancestors.add(parent)
                nodes_to_visit.append(parent)
    return ancestors
def extend_with_ancestors(classes: set, hierarchy: dict) -> set:
    """
    Return a new set holding the given classes together with all their ancestors.

    The input collection is not modified. The ancestors of each class are
    looked up with ``find_ancestors`` and merged into a copy.

    Args:
        classes (set): The classes to extend.
        hierarchy (dict): Maps each class to its direct parents.

    Returns:
        set: A fresh set containing ``classes`` and every ancestor found.
    """
    closure = set(classes)
    for member in classes:
        closure |= find_ancestors(member, hierarchy)
    return closure
def _extend_with_weighted_ancestors(
    codes: List[str],
    hierarchy: Mapping[str, Union[Mapping[str, float], Iterable[str]]],
) -> Dict[str, float]:
    """
    Map every code, and every ancestor listed for it, to its strongest weight.

    Weights are assigned as follows:

    - An exact code always contributes weight 1.0.
    - If the hierarchy value is a mapping, ancestors take the weight it gives.
    - If the value is a plain iterable (e.g. a set), each ancestor gets 1.0.

    When a node is reached more than once, the largest weight wins. This makes
    the result independent of the order of ``codes``.
    """
    extended: Dict[str, float] = {}
    for code in codes:
        extended[code] = max(extended.get(code, 0.0), 1.0)
        ancestors = hierarchy.get(code, {})
        if isinstance(ancestors, Mapping):
            weighted_ancestors = ancestors.items()
        else:
            # Unweighted form, e.g. {'1111': {'111', '11', '1'}}.
            weighted_ancestors = ((ancestor, 1.0) for ancestor in ancestors)
        for ancestor, ancestor_weight in weighted_ancestors:
            extended[ancestor] = max(extended.get(ancestor, 0.0), ancestor_weight)
    return extended


def calculate_hierarchical_precision_recall(
    reference_codes: List[str],
    predicted_codes: List[str],
    hierarchy: Mapping[str, Union[Mapping[str, float], Iterable[str]]],
) -> Tuple[float, float]:
    """
    Calculate hierarchical precision and recall for reference and predicted codes.

    Both code lists are extended with their ancestors from ``hierarchy``. A
    node present on both sides counts as correct with the smaller of its two
    weights. Precision is the correct weight over the total predicted weight;
    recall is the correct weight over the total reference weight.

    Args:
        reference_codes (List[str]): The list of reference codes.
        predicted_codes (List[str]): The list of predicted codes.
        hierarchy: Maps a code to its ancestors. The value may be either:

            - a mapping ``{ancestor: weight}``, or
            - a plain iterable of ancestors (e.g. a set), each weighted 1.0.

            Codes absent from the hierarchy have no ancestors.

    Returns:
        Tuple[float, float]: ``(hP, hR)``. A side whose total weight is zero
        yields 0.0 for the corresponding value.
    """
    extended_real = _extend_with_weighted_ancestors(reference_codes, hierarchy)
    extended_predicted = _extend_with_weighted_ancestors(predicted_codes, hierarchy)

    # Weighted overlap: a shared node contributes the lesser of its two weights.
    correct_weights = sum(
        min(weight, extended_real[code])
        for code, weight in extended_predicted.items()
        if code in extended_real
    )
    total_predicted_weights = sum(extended_predicted.values())
    total_real_weights = sum(extended_real.values())

    # Avoid dividing by zero when either side is empty.
    hP = correct_weights / total_predicted_weights if total_predicted_weights else 0.0
    hR = correct_weights / total_real_weights if total_real_weights else 0.0
    return hP, hR
def hierarchical_f_measure(hP: float, hR: float, beta: float = 1.0) -> float:
    """
    Calculate the hierarchical F-beta measure from precision and recall.

    Computes ``(beta**2 + 1) * hP * hR / (beta**2 * hP + hR)``. A ``beta``
    above 1 weights recall more heavily; a ``beta`` below 1 weights precision
    more heavily.

    Parameters:
        hP (float): The hierarchical precision.
        hR (float): The hierarchical recall.
        beta (float, optional): Relative weight of recall versus precision.
            Default is 1.0 (the harmonic mean of hP and hR, i.e. F1).

    Returns:
        float: The hierarchical F-measure, or 0.0 when the denominator is zero.
        That happens when both inputs are zero, or when ``beta == 0`` and
        ``hR == 0``.
    """
    beta_sq = beta ** 2
    denominator = beta_sq * hP + hR
    # Guard the actual divisor. Checking hP + hR alone misses beta == 0 with
    # hR == 0, which would raise ZeroDivisionError.
    if denominator == 0:
        return 0.0
    return (beta_sq + 1) * hP * hR / denominator
# Example usage (ancestor weights below are illustrative):
# reference_codes = ["1111", "1112", "1113", "1114"]
# predicted_codes = ["1111", "1113", "1120", "1211"]
# hierarchy_dict = {'1111': {'111': 0.5, '11': 0.25, '1': 0.125}, '1112': {'111': 0.5, '11': 0.25, '1': 0.125}, ...}
# result = calculate_hierarchical_precision_recall(reference_codes, predicted_codes, hierarchy_dict)
# print(result)
|