danieldux's picture
Refactor code to improve readability and add type annotations
8a4a728
raw
history blame
2.99 kB
"""This module provides functions for calculating hierarchical precision, recall, and F-measure."""
from typing import List, Set, Dict, Tuple
def find_ancestors(node: str, hierarchy: Dict[str, Set[str]]) -> Set[str]:
    """
    Find all (transitive) ancestors of a given node in a hierarchy.

    Args:
        node (str): The node for which to find ancestors.
        hierarchy (Dict[str, Set[str]]): A dictionary representing the hierarchy, where the keys
            are nodes and the values are collections of their parent nodes. Nodes that are not
            keys of the dictionary are treated as having no parents.

    Returns:
        Set[str]: Every node reachable from ``node`` by repeatedly following parent links.
        ``node`` itself is only included if the hierarchy contains a cycle leading back to it.
    """
    ancestors: Set[str] = set()
    nodes_to_visit = [node]
    while nodes_to_visit:
        current_node = nodes_to_visit.pop()
        for parent in hierarchy.get(current_node, ()):
            # Expand each ancestor at most once: this avoids re-walking shared
            # ancestors in a DAG and guarantees termination if the hierarchy
            # accidentally contains a cycle (the previous version looped forever).
            if parent not in ancestors:
                ancestors.add(parent)
                nodes_to_visit.append(parent)
    return ancestors
def extend_with_ancestors(classes: set, hierarchy: dict) -> set:
    """
    Return a new set holding ``classes`` together with all of their ancestors.

    Args:
        classes (set): The starting classes. The input set itself is not modified.
        hierarchy (dict): Mapping from each class to its parent classes, in the form
            accepted by ``find_ancestors``.

    Returns:
        set: A fresh set containing every original class plus every ancestor found.
    """
    closure = set(classes)  # copy so the caller's set is left untouched
    for member in classes:
        closure |= find_ancestors(member, hierarchy)
    return closure
def calculate_hierarchical_precision_recall(
    reference_codes: List[str],
    predicted_codes: List[str],
    hierarchy: Dict[str, Set[str]],
) -> Tuple[float, float]:
    """
    Calculate hierarchical precision (hP) and recall (hR).

    Both code collections are extended with all of their ancestors in ``hierarchy``
    before being compared. A prediction that misses the exact reference code but
    shares ancestors with it therefore still earns partial credit:

        hP = |extended_reference & extended_predicted| / |extended_predicted|
        hR = |extended_reference & extended_predicted| / |extended_reference|

    Args:
        reference_codes (List[str]): The list of reference (ground-truth) codes.
        predicted_codes (List[str]): The list of predicted codes.
        hierarchy (Dict[str, Set[str]]): The hierarchy definition where keys are nodes and
            values are sets of parent nodes. Parent links are followed transitively, so the
            values may contain either direct parents only or the complete ancestor set.

    Returns:
        Tuple[float, float]: The hierarchical precision and recall. A metric whose
        denominator set is empty is reported as 0.0.
    """
    # Extend with ALL ancestors, not just direct parents. The previous version only
    # added hierarchy[code], which under-counts when the hierarchy lists direct parents.
    extended_real = extend_with_ancestors(set(reference_codes), hierarchy)
    extended_predicted = extend_with_ancestors(set(predicted_codes), hierarchy)

    correct_predictions = extended_real & extended_predicted

    hP = len(correct_predictions) / len(extended_predicted) if extended_predicted else 0.0
    hR = len(correct_predictions) / len(extended_real) if extended_real else 0.0
    return hP, hR
def hierarchical_f_measure(hP: float, hR: float, beta: float = 1.0) -> float:
    """
    Calculate the hierarchical F-measure (F-beta) from hierarchical precision and recall.

        hF_beta = (beta**2 + 1) * hP * hR / (beta**2 * hP + hR)

    Args:
        hP (float): Hierarchical precision.
        hR (float): Hierarchical recall.
        beta (float): Weight of recall relative to precision; ``beta=1.0`` yields the
            harmonic mean of hP and hR.

    Returns:
        float: The hierarchical F-beta score, or 0.0 when the denominator is zero.
    """
    denominator = beta**2 * hP + hR
    # Guard the actual denominator rather than hP + hR: with beta == 0 the
    # denominator reduces to hR, so hP > 0 with hR == 0 used to raise ZeroDivisionError.
    if denominator == 0:
        return 0.0
    return (beta**2 + 1) * hP * hR / denominator