danieldux committed on
Commit
8a4a728
1 Parent(s): 17ea6af

Refactor code to improve readability and add type annotations

Browse files
Files changed (1) hide show
  1. ham.py +33 -50
ham.py CHANGED
@@ -1,28 +1,19 @@
 
 
1
  from typing import List, Set, Dict, Tuple
2
 
3
 
4
- def find_ancestors1(tree, code):
5
  """
6
- Recursively finds ancestors of a given class (e.g., an ISCO-08 code) in a hierarchical JSON structure.
7
 
8
  Args:
9
- - tree: A dictionary representing the hierarchical structure.
10
- - code: A string representing the label of the class.
11
 
12
  Returns:
13
- - A list of strings, each representing an ancestor of the input class.
14
  """
15
- ancestors = []
16
- current = code
17
- while current:
18
- parent = tree[current]["parent"]
19
- if parent:
20
- ancestors.append(parent)
21
- current = parent
22
- return ancestors
23
-
24
-
25
- def find_ancestors(node, hierarchy):
26
  ancestors = set()
27
  nodes_to_visit = [node]
28
  while nodes_to_visit:
@@ -34,7 +25,17 @@ def find_ancestors(node, hierarchy):
34
  return ancestors
35
 
36
 
37
- def extend_with_ancestors(classes, hierarchy):
 
 
 
 
 
 
 
 
 
 
38
  extended_classes = set(classes)
39
  for cls in classes:
40
  ancestors = find_ancestors(cls, hierarchy)
@@ -43,11 +44,24 @@ def extend_with_ancestors(classes, hierarchy):
43
 
44
 
45
  def calculate_hierarchical_precision_recall(
46
- real_codes: List[str], predicted_codes: List[str], hierarchy: Dict[str, Set[str]]
 
 
47
  ) -> Tuple[float, float]:
 
 
 
 
 
 
 
 
 
 
 
48
  # Extend the sets of real and predicted codes with their ancestors
49
  extended_real = set()
50
- for code in real_codes:
51
  extended_real.add(code)
52
  extended_real.update(hierarchy.get(code, set()))
53
 
@@ -66,37 +80,6 @@ def calculate_hierarchical_precision_recall(
66
  return hP, hR
67
 
68
 
69
- def calculate_hierarchical_measures(true_labels, predicted_labels, tree):
70
- """
71
- Calculates hierarchical precision, recall, and F-measure in a hierarchical structure.
72
-
73
- Args:
74
- - true_labels: A list of strings representing true class labels.
75
- - predicted_labels: A list of strings representing predicted class labels.
76
- - tree: A dictionary representing the hierarchical structure.
77
-
78
- Returns:
79
- - hP: A floating point number representing hierarchical precision.
80
- - hR: A floating point number representing hierarchical recall.
81
- - hF: A floating point number representing hierarchical F-measure.
82
- """
83
-
84
- extended_true = [set(find_ancestors(tree, code) | {code}) for code in true_labels]
85
- extended_pred = [
86
- set(find_ancestors(tree, code) | {code}) for code in predicted_labels
87
- ]
88
-
89
- true_positive = sum(len(t & p) for t, p in zip(extended_true, extended_pred))
90
- predicted = sum(len(p) for p in extended_pred)
91
- actual = sum(len(t) for t in extended_true)
92
-
93
- hP = true_positive / predicted if predicted else 0
94
- hR = true_positive / actual if actual else 0
95
- hF = (2 * hP * hR) / (hP + hR) if (hP + hR) else 0
96
-
97
- return hP, hR, hF
98
-
99
-
100
  def hierarchical_f_measure(hP, hR, beta=1.0):
101
  """Calculate the hierarchical F-measure."""
102
  if hP + hR == 0:
 
1
+ """This module provides functions for calculating hierarchical precision, recall and f1."""
2
+
3
  from typing import List, Set, Dict, Tuple
4
 
5
 
6
+ def find_ancestors(node: str, hierarchy: dict) -> set:
7
  """
8
+ Find the ancestors of a given node in a hierarchy.
9
 
10
  Args:
11
+ node (str): The node for which to find ancestors.
12
+ hierarchy (dict): A dictionary representing the hierarchy, where the keys are nodes and the values are their parents.
13
 
14
  Returns:
15
+ set: A set of ancestors of the given node.
16
  """
 
 
 
 
 
 
 
 
 
 
 
17
  ancestors = set()
18
  nodes_to_visit = [node]
19
  while nodes_to_visit:
 
25
  return ancestors
26
 
27
 
28
+ def extend_with_ancestors(classes: set, hierarchy: dict) -> set:
29
+ """
30
+ Extend the given set of classes with their ancestors from the hierarchy.
31
+
32
+ Args:
33
+ classes (set): The set of classes to extend.
34
+ hierarchy (dict): The hierarchy of classes.
35
+
36
+ Returns:
37
+ set: The extended set of classes including their ancestors.
38
+ """
39
  extended_classes = set(classes)
40
  for cls in classes:
41
  ancestors = find_ancestors(cls, hierarchy)
 
44
 
45
 
46
  def calculate_hierarchical_precision_recall(
47
+ reference_codes: List[str],
48
+ predicted_codes: List[str],
49
+ hierarchy: Dict[str, Set[str]],
50
  ) -> Tuple[float, float]:
51
+ """
52
+ Calculates the hierarchical precision and recall given the reference codes, predicted codes, and hierarchy definition.
53
+
54
+ Args:
55
+ reference_codes (List[str]): The list of reference codes.
56
+ predicted_codes (List[str]): The list of predicted codes.
57
+ hierarchy (Dict[str, Set[str]]): The hierarchy definition where keys are nodes and values are sets of parent nodes.
58
+
59
+ Returns:
60
+ Tuple[float, float]: A tuple containing the hierarchical precision and recall floating point values.
61
+ """
62
  # Extend the sets of real and predicted codes with their ancestors
63
  extended_real = set()
64
+ for code in reference_codes:
65
  extended_real.add(code)
66
  extended_real.update(hierarchy.get(code, set()))
67
 
 
80
  return hP, hR
81
 
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  def hierarchical_f_measure(hP, hR, beta=1.0):
84
  """Calculate the hierarchical F-measure."""
85
  if hP + hR == 0: