|
from typing import List, Set, Dict, Tuple |
|
|
|
|
|
def find_ancestors1(tree, code):
    """
    Find the ancestors of a class (e.g. an ISCO-08 code) by following parent links.

    Args:
        tree: Mapping from a class label to a dict whose ``"parent"`` entry
            holds the label of its parent. A missing or falsy ``"parent"``
            marks a root.
        code: Label of the class whose ancestors are wanted.

    Returns:
        A list of ancestor labels, ordered from the immediate parent up to
        the root. Empty if ``code`` is falsy or is itself a root.

    Raises:
        KeyError: If ``code`` or one of its ancestors is not a key of ``tree``.
        ValueError: If the parent links form a cycle (which would otherwise
            loop forever).
    """
    ancestors = []
    seen = {code} if code else set()
    current = code
    while current:
        parent = tree[current].get("parent")
        if not parent:
            break
        # A revisited label means the parent links loop back on themselves.
        if parent in seen:
            raise ValueError(f"cycle detected in hierarchy at {parent!r}")
        seen.add(parent)
        ancestors.append(parent)
        current = parent
    return ancestors
|
|
|
|
|
def find_ancestors(node, hierarchy):
    """
    Return every transitive ancestor of ``node`` in a parent-map hierarchy.

    Args:
        node: Label whose ancestors are wanted.
        hierarchy: Mapping from a label to an iterable of its direct parents.
            Labels absent from the mapping are treated as roots.

    Returns:
        A set of all labels reachable from ``node`` by following parent links.
        ``node`` itself is included only if it lies on a cycle.
    """
    ancestors = set()
    nodes_to_visit = [node]
    while nodes_to_visit:
        current_node = nodes_to_visit.pop()
        if current_node not in hierarchy:
            continue
        for parent in hierarchy[current_node]:
            # Expand each ancestor only once. Shared ancestors in a DAG are not
            # re-walked along every path, and cycles cannot loop forever.
            if parent not in ancestors:
                ancestors.add(parent)
                nodes_to_visit.append(parent)
    return ancestors
|
|
|
|
|
def extend_with_ancestors(classes, hierarchy):
    """
    Return the given classes together with all of their ancestors.

    Args:
        classes: Iterable of class labels. A one-shot iterator or generator
            is accepted; it is consumed exactly once.
        hierarchy: Mapping from a label to an iterable of its direct parents,
            as consumed by ``find_ancestors``.

    Returns:
        A new set containing every input label and every ancestor of each.
    """
    extended_classes = set(classes)
    # Loop over a snapshot of the materialized input, for two reasons:
    # ``classes`` may be an iterator already exhausted by set() above, and
    # ``extended_classes`` grows inside the loop. Deduplication also avoids
    # walking the same label twice.
    for cls in list(extended_classes):
        extended_classes.update(find_ancestors(cls, hierarchy))
    return extended_classes
|
|
|
|
|
def calculate_hierarchical_precision_recall(
    real_codes: List[str], predicted_codes: List[str], hierarchy: Dict[str, Set[str]]
) -> Tuple[float, float]:
    """
    Compute set-based hierarchical precision and recall for one code assignment.

    All real codes are pooled into one set and all predicted codes into
    another. Each set is then extended with every transitive ancestor found in
    ``hierarchy`` (via ``extend_with_ancestors``) before the overlap is measured.

    Args:
        real_codes: The reference (gold) codes.
        predicted_codes: The predicted codes.
        hierarchy: Mapping from a code to its direct parents. Codes missing
            from the mapping are treated as roots. A mapping that already
            stores each code's full ancestor set gives the same result.

    Returns:
        (hP, hR), computed over the extended sets:
        hP = |real & predicted| / |predicted| and
        hR = |real & predicted| / |real|.
        Each value is 0.0 when its denominator set is empty.
    """
    # Walk the whole ancestor chain. A single ``hierarchy.get(code)`` lookup
    # would add only direct parents when ``hierarchy`` stores parent links.
    extended_real = extend_with_ancestors(real_codes, hierarchy)
    extended_predicted = extend_with_ancestors(predicted_codes, hierarchy)

    correct_predictions = extended_real & extended_predicted

    hP = len(correct_predictions) / len(extended_predicted) if extended_predicted else 0.0
    hR = len(correct_predictions) / len(extended_real) if extended_real else 0.0
    return hP, hR
|
|
|
|
|
def calculate_hierarchical_measures(true_labels, predicted_labels, tree):
    """
    Calculate micro-averaged hierarchical precision, recall and F1.

    Each (true, predicted) label pair is one instance. Both labels are
    extended with their ancestors (found with ``find_ancestors1``), and the
    overlap sizes are summed over all instances before dividing.

    Args:
        true_labels: Sequence of true class labels, one per instance.
        predicted_labels: Sequence of predicted class labels, aligned
            position-by-position with ``true_labels``.
        tree: Mapping from a class label to a dict whose ``"parent"`` entry
            holds the parent label (the format read by ``find_ancestors1``).

    Returns:
        (hP, hR, hF): hierarchical precision, recall and F1. Each is 0.0 when
        its denominator is zero.

    Raises:
        ValueError: If the two label sequences have different lengths.
        KeyError: May propagate from ``find_ancestors1`` for labels (or
            ancestors) missing from ``tree``.
    """
    true_labels = list(true_labels)
    predicted_labels = list(predicted_labels)
    # zip() would silently drop unmatched labels from the numerator while the
    # denominators still counted them, skewing hP/hR.
    if len(true_labels) != len(predicted_labels):
        raise ValueError(
            "true_labels and predicted_labels must have the same length "
            f"({len(true_labels)} != {len(predicted_labels)})"
        )

    # find_ancestors1 takes (tree, code) and returns a list, so convert it to
    # a set before taking the union with the label itself.
    extended_true = [set(find_ancestors1(tree, code)) | {code} for code in true_labels]
    extended_pred = [set(find_ancestors1(tree, code)) | {code} for code in predicted_labels]

    true_positive = sum(len(t & p) for t, p in zip(extended_true, extended_pred))
    predicted = sum(len(p) for p in extended_pred)
    actual = sum(len(t) for t in extended_true)

    hP = true_positive / predicted if predicted else 0.0
    hR = true_positive / actual if actual else 0.0
    hF = (2 * hP * hR) / (hP + hR) if (hP + hR) else 0.0

    return hP, hR, hF
|
|
|
|
|
def hierarchical_f_measure(hP, hR, beta=1.0):
    """
    Calculate the hierarchical F-beta measure from precision and recall.

    Args:
        hP: Hierarchical precision.
        hR: Hierarchical recall.
        beta: Relative weight of recall versus precision (beta > 1 favours
            recall). The default of 1.0 gives the harmonic mean (F1).

    Returns:
        (beta**2 + 1) * hP * hR / (beta**2 * hP + hR), or 0.0 when that
        denominator is zero.
    """
    beta_sq = beta ** 2
    denominator = beta_sq * hP + hR
    # Guard the actual denominator. Checking hP + hR instead misses beta == 0
    # with hR == 0, which divides by zero.
    if denominator == 0:
        return 0.0
    return (beta_sq + 1) * hP * hR / denominator
|
|