|
"""This module provides functions for calculating hierarchical variants of precision, recall and F1."""
|
|
|
from typing import List, Set, Dict, Tuple |
|
|
|
|
|
def find_ancestors(node: str, hierarchy: dict) -> set:
    """
    Find the ancestors of a given node in a hierarchy.

    The hierarchy is walked upwards from ``node``: the parents of ``node`` are
    collected, then the parents of those parents, and so on. Every node is
    recorded at most once, so the walk terminates even when the hierarchy
    contains a cycle, and shared ancestors are not expanded repeatedly.

    Args:
        node (str): The node for which to find ancestors.
        hierarchy (dict): A dictionary representing the hierarchy, where the keys are nodes and
            the values are iterables of their parents (for a dict value, its keys are used).

    Returns:
        set: A set of ancestors of the given node. The node itself is only included if it is
            reachable from itself through a cycle.
    """
    ancestors = set()
    nodes_to_visit = [node]
    while nodes_to_visit:
        current_node = nodes_to_visit.pop()
        # Nodes without an entry in the hierarchy have no known parents.
        for parent in hierarchy.get(current_node, ()):
            # Skip parents that were already recorded: they have been (or will be)
            # expanded, and re-queueing them would loop forever on a cycle.
            if parent not in ancestors:
                ancestors.add(parent)
                nodes_to_visit.append(parent)
    return ancestors
|
|
|
|
|
def extend_with_ancestors(classes: set, hierarchy: dict) -> set:
    """
    Return the given classes together with all of their ancestors.

    Args:
        classes (set): The starting classes. The input collection is not modified.
        hierarchy (dict): The hierarchy of classes, passed unchanged to ``find_ancestors``
            for every class.

    Returns:
        set: A new set containing the original classes plus every ancestor found for them.
    """
    # Lazily look up the ancestors of each class; the union below consumes them.
    ancestor_sets = (find_ancestors(member, hierarchy) for member in classes)
    return set(classes).union(*ancestor_sets)
|
|
|
|
|
def _extend_with_weights(
    codes: List[str],
    hierarchy: Dict[str, Dict[str, float]],
) -> Dict[str, float]:
    """Map every given code and its listed ancestors to a weight."""
    weighted: Dict[str, float] = {}
    for code in codes:
        # An explicitly listed code always counts with full weight.
        weighted[code] = 1.0
        for ancestor, ancestor_weight in hierarchy.get(code, {}).items():
            # Several codes may share an ancestor; keep the largest weight seen.
            weighted[ancestor] = max(weighted.get(ancestor, 0), ancestor_weight)
    return weighted


def calculate_hierarchical_precision_recall(
    reference_codes: List[str],
    predicted_codes: List[str],
    hierarchy: Dict[str, Dict[str, float]],
) -> Tuple[float, float]:
    """
    Calculates the hierarchical precision and recall given the reference codes, predicted codes, and hierarchy definition.

    Both code lists are first extended with the ancestors listed for each code
    in ``hierarchy``. A listed code gets weight 1.0; an ancestor gets the
    largest weight under which it appears. Only the entries stored directly
    under a code are used; the hierarchy is not traversed transitively here.
    The overlap between the two extended sets is the sum, over the shared
    codes, of the smaller of the two weights.

    Args:
        reference_codes (List[str]): The list of reference codes.
        predicted_codes (List[str]): The list of predicted codes.
        hierarchy (Dict[str, Dict[str, float]]): The hierarchy definition where keys are nodes and
            values map each ancestor node to its weight.

    Returns:
        Tuple[float, float]: A tuple containing the hierarchical precision and recall floating point
            values. A value is 0.0 when the corresponding total weight is zero (e.g. an empty code list).
    """
    extended_real = _extend_with_weights(reference_codes, hierarchy)
    extended_predicted = _extend_with_weights(predicted_codes, hierarchy)

    # Start from 0.0 so the result is a float even when nothing overlaps.
    correct_weights = sum(
        (
            min(weight, extended_real[code])
            for code, weight in extended_predicted.items()
            if code in extended_real
        ),
        0.0,
    )

    total_predicted_weights = sum(extended_predicted.values())
    total_real_weights = sum(extended_real.values())

    # Guard against division by zero when a side has no weight at all.
    hP = correct_weights / total_predicted_weights if total_predicted_weights else 0.0
    hR = correct_weights / total_real_weights if total_real_weights else 0.0

    return hP, hR
|
|
|
|
|
def hierarchical_f_measure(hP: float, hR: float, beta: float = 1.0) -> float:
    """
    Calculate the hierarchical F-measure.

    The value is ``(beta**2 + 1) * hP * hR / (beta**2 * hP + hR)``. When the
    measure is undefined because the denominator is zero, 0.0 is returned
    instead of raising ``ZeroDivisionError``.

    Parameters:
        hP (float): The hierarchical precision.
        hR (float): The hierarchical recall.
        beta (float, optional): The beta value for F-measure calculation. Default is 1.0.

    Returns:
        float: The hierarchical F-measure, or 0.0 when it is undefined.
    """
    beta_squared = beta**2
    denominator = beta_squared * hP + hR
    # The first check keeps the original guard. The second covers the case it
    # missed: the real divisor can be zero while hP + hR is not (beta == 0 and
    # hR == 0 with hP > 0).
    if hP + hR == 0 or denominator == 0:
        return 0.0
    return (beta_squared + 1) * hP * hR / denominator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|