danieldux committed on
Commit
03c8589
1 Parent(s): ad04d80

Refactor calculate_hierarchical_precision_recall to use weighted sums

Browse files
Files changed (1) hide show
  1. ham.py +33 -16
ham.py CHANGED
@@ -46,36 +46,53 @@ def extend_with_ancestors(classes: set, hierarchy: dict) -> set:
46
  def calculate_hierarchical_precision_recall(
47
  reference_codes: List[str],
48
  predicted_codes: List[str],
49
- hierarchy: Dict[str, Set[str]],
50
  ) -> Tuple[float, float]:
51
  """
52
  Calculates the hierarchical precision and recall given the reference codes, predicted codes, and hierarchy definition.
53
 
54
  Args:
55
- real_codes (List[str]): The list of reference codes.
56
  predicted_codes (List[str]): The list of predicted codes.
57
  hierarchy (Dict[str, Set[str]]): The hierarchy definition where keys are nodes and values are sets of parent nodes.
58
 
59
  Returns:
60
  Tuple[float, float]: A tuple containing the hierarchical precision and recall floating point values.
61
  """
62
- # Extend the sets of real and predicted codes with their ancestors
63
- extended_real = set()
64
- for code in reference_codes:
65
- extended_real.add(code)
66
- extended_real.update(hierarchy.get(code, set()))
67
 
68
- extended_predicted = set()
69
- for code in predicted_codes:
70
- extended_predicted.add(code)
71
- extended_predicted.update(hierarchy.get(code, set()))
 
 
 
 
72
 
73
- # Calculate the intersection
74
- correct_predictions = extended_real.intersection(extended_predicted)
75
 
76
- # Calculate hierarchical precision and recall
77
- hP = len(correct_predictions) / len(extended_predicted) if extended_predicted else 0
78
- hR = len(correct_predictions) / len(extended_real) if extended_real else 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
  return hP, hR
81
 
 
46
  def calculate_hierarchical_precision_recall(
47
  reference_codes: List[str],
48
  predicted_codes: List[str],
49
+ hierarchy: Dict[str, Dict[str, float]],
50
  ) -> Tuple[float, float]:
51
  """
52
  Calculates the hierarchical precision and recall given the reference codes, predicted codes, and hierarchy definition.
53
 
54
  Args:
55
+ reference_codes (List[str]): The list of reference codes.
56
  predicted_codes (List[str]): The list of predicted codes.
57
  hierarchy (Dict[str, Set[str]]): The hierarchy definition where keys are nodes and values are sets of parent nodes.
58
 
59
  Returns:
60
  Tuple[float, float]: A tuple containing the hierarchical precision and recall floating point values.
61
  """
62
+ extended_real = {}
 
 
 
 
63
 
64
+ # Extend the sets of reference codes with their ancestors
65
+ for code in reference_codes:
66
+ weight = 1.0 # Full weight for exact match
67
+ extended_real[code] = weight
68
+ for ancestor, ancestor_weight in hierarchy.get(code, {}).items():
69
+ extended_real[ancestor] = max(
70
+ extended_real.get(ancestor, 0), ancestor_weight
71
+ )
72
 
73
+ extended_predicted = {}
 
74
 
75
+ # Extend the sets of predicted codes with their ancestors
76
+ for code in predicted_codes:
77
+ weight = 1.0
78
+ extended_predicted[code] = weight
79
+ for ancestor, ancestor_weight in hierarchy.get(code, {}).items():
80
+ extended_predicted[ancestor] = max(
81
+ extended_predicted.get(ancestor, 0), ancestor_weight
82
+ )
83
+
84
+ # Calculate weighted correct predictions
85
+ correct_weights = 0
86
+ for code, weight in extended_predicted.items():
87
+ if code in extended_real:
88
+ correct_weights += min(weight, extended_real[code])
89
+
90
+ total_predicted_weights = sum(extended_predicted.values())
91
+ total_real_weights = sum(extended_real.values())
92
+
93
+ # Calculate hierarchical precision and recall using weighted sums
94
+ hP = correct_weights / total_predicted_weights if total_predicted_weights else 0
95
+ hR = correct_weights / total_real_weights if total_real_weights else 0
96
 
97
  return hP, hR
98