danieldux commited on
Commit
fcd15ea
1 Parent(s): de0138a

Add functions for finding ancestors and calculating hierarchical precision and recall

Browse files
Files changed (1) hide show
  1. ham.py +48 -1
ham.py CHANGED
@@ -1,4 +1,7 @@
1
- def find_ancestors(tree, code):
 
 
 
2
  """
3
  Recursively finds ancestors of a given class (e.g., an ISCO-08 code) in a hierarchical JSON structure.
4
 
@@ -19,6 +22,50 @@ def find_ancestors(tree, code):
19
  return ancestors
20
 
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  def calculate_hierarchical_measures(true_labels, predicted_labels, tree):
23
  """
24
  Calculates hierarchical precision, recall, and F-measure in a hierarchical structure.
 
1
+ from typing import List, Set, Dict, Tuple
2
+
3
+
4
+ def find_ancestors1(tree, code):
5
  """
6
  Recursively finds ancestors of a given class (e.g., an ISCO-08 code) in a hierarchical JSON structure.
7
 
 
22
  return ancestors
23
 
24
 
25
+ def find_ancestors(node, hierarchy):
26
+ ancestors = set()
27
+ nodes_to_visit = [node]
28
+ while nodes_to_visit:
29
+ current_node = nodes_to_visit.pop()
30
+ if current_node in hierarchy:
31
+ parents = hierarchy[current_node]
32
+ ancestors.update(parents)
33
+ nodes_to_visit.extend(parents)
34
+ return ancestors
35
+
36
+
37
+ def extend_with_ancestors(classes, hierarchy):
38
+ extended_classes = set(classes)
39
+ for cls in classes:
40
+ ancestors = find_ancestors(cls, hierarchy)
41
+ extended_classes.update(ancestors)
42
+ return extended_classes
43
+
44
+
45
+ def calculate_hierarchical_precision_recall(
46
+ real_codes: List[str], predicted_codes: List[str], hierarchy: Dict[str, Set[str]]
47
+ ) -> Tuple[float, float]:
48
+ # Extend the sets of real and predicted codes with their ancestors
49
+ extended_real = set()
50
+ for code in real_codes:
51
+ extended_real.add(code)
52
+ extended_real.update(hierarchy.get(code, set()))
53
+
54
+ extended_predicted = set()
55
+ for code in predicted_codes:
56
+ extended_predicted.add(code)
57
+ extended_predicted.update(hierarchy.get(code, set()))
58
+
59
+ # Calculate the intersection
60
+ correct_predictions = extended_real.intersection(extended_predicted)
61
+
62
+ # Calculate hierarchical precision and recall
63
+ hP = len(correct_predictions) / len(extended_predicted) if extended_predicted else 0
64
+ hR = len(correct_predictions) / len(extended_real) if extended_real else 0
65
+
66
+ return hP, hR
67
+
68
+
69
  def calculate_hierarchical_measures(true_labels, predicted_labels, tree):
70
  """
71
  Calculates hierarchical precision, recall, and F-measure in a hierarchical structure.