"""This module provides functions for calculating hierarchical precision, recall and F-measure."""

from typing import List, Set, Dict, Tuple


def find_ancestors(node: str, hierarchy: Dict[str, Set[str]]) -> Set[str]:
    """
    Find all (transitive) ancestors of a given node in a hierarchy.

    The hierarchy is walked upward from ``node`` through its parents, their
    parents, and so on. Each node is expanded at most once, so the walk
    terminates even when the hierarchy contains cycles and does not repeat
    work on diamond-shaped graphs (nodes with several parents sharing an
    ancestor).

    Args:
        node (str): The node for which to find ancestors.
        hierarchy (dict): A dictionary representing the hierarchy, where the
            keys are nodes and the values are iterables of their parent nodes.
            Nodes that are not keys are treated as having no parents.

    Returns:
        set: The set of ancestors of the given node. The node itself is only
        included if the hierarchy contains a cycle leading back to it.
    """
    ancestors = set()
    nodes_to_visit = [node]
    while nodes_to_visit:
        current_node = nodes_to_visit.pop()
        # Only expand parents we have not seen yet: without this, a cyclic
        # hierarchy loops forever and shared ancestors are re-walked.
        new_parents = set(hierarchy.get(current_node, ())) - ancestors
        ancestors.update(new_parents)
        nodes_to_visit.extend(new_parents)
    return ancestors


def extend_with_ancestors(classes: Set[str], hierarchy: Dict[str, Set[str]]) -> Set[str]:
    """
    Extend the given collection of classes with all of their ancestors.

    Args:
        classes (set): The classes to extend. Any iterable of class names is
            accepted; it is iterated only once.
        hierarchy (dict): The hierarchy of classes, mapping each node to an
            iterable of its parent nodes.

    Returns:
        set: A new set containing the given classes plus all their ancestors.
    """
    # Materialize once so one-shot iterables (e.g. generators) are not
    # exhausted before the loop below runs.
    base_classes = set(classes)
    extended_classes = set(base_classes)
    for cls in base_classes:
        extended_classes.update(find_ancestors(cls, hierarchy))
    return extended_classes


def calculate_hierarchical_precision_recall(
    reference_codes: List[str],
    predicted_codes: List[str],
    hierarchy: Dict[str, Set[str]],
) -> Tuple[float, float]:
    """
    Calculate hierarchical precision and recall.

    Both the reference and the predicted codes are augmented with all of
    their ancestors in ``hierarchy`` (not only their direct parents). The
    scores are then computed on the augmented sets:

        hP = |R' & P'| / |P'|        hR = |R' & P'| / |R'|

    Args:
        reference_codes (List[str]): The list of reference (true) codes.
        predicted_codes (List[str]): The list of predicted codes.
        hierarchy (Dict[str, Set[str]]): The hierarchy definition where keys
            are nodes and values are sets of parent nodes.

    Returns:
        Tuple[float, float]: The hierarchical precision and recall. Precision
        is 0.0 when there are no predicted codes. Recall is 0.0 when there
        are no reference codes.
    """
    # Use the full ancestor closure so multi-level hierarchies are scored
    # correctly (direct parents alone miss grandparents and above).
    extended_real = extend_with_ancestors(reference_codes, hierarchy)
    extended_predicted = extend_with_ancestors(predicted_codes, hierarchy)

    correct_predictions = extended_real & extended_predicted

    hP = len(correct_predictions) / len(extended_predicted) if extended_predicted else 0.0
    hR = len(correct_predictions) / len(extended_real) if extended_real else 0.0
    return hP, hR


def hierarchical_f_measure(hP: float, hR: float, beta: float = 1.0) -> float:
    """
    Calculate the hierarchical F-measure (F-beta) from precision and recall.

        F_beta = (beta^2 + 1) * hP * hR / (beta^2 * hP + hR)

    Args:
        hP (float): Hierarchical precision.
        hR (float): Hierarchical recall.
        beta (float): Relative weight of recall versus precision
            (``beta=1`` gives the harmonic mean, i.e. F1).

    Returns:
        float: The F-beta score, or 0.0 when the denominator is zero.
    """
    # Guard the actual denominator: with beta == 0 it reduces to hR, so
    # checking hP + hR alone would still divide by zero when hR == 0.
    denominator = beta**2 * hP + hR
    if denominator == 0:
        return 0.0
    return (beta**2 + 1) * hP * hR / denominator