danieldux commited on
Commit
087d986
1 Parent(s): 03c8589

Refactor ISCO_Hierarchical_Accuracy class to use weighted hierarchy dictionary

Browse files
Files changed (1) hide show
  1. isco_hierarchical_accuracy.py +42 -25
isco_hierarchical_accuracy.py CHANGED
@@ -114,15 +114,14 @@ class ISCO_Hierarchical_Accuracy(evaluate.Metric):
114
 
115
  def create_hierarchy_dict(self, file: str) -> dict:
116
  """
117
- Creates a dictionary where keys are nodes and values are sets of parent nodes representing the group level hierarchy of the ISCO-08 structure.
118
- The function assumes that the input CSV file has a column named 'unit' with the 4-digit ISCO-08 codes.
119
- A csv file with the ISCO-08 structure can be downloaded from the International Labour Organization (ILO) at [https://www.ilo.org/ilostat-files/ISCO/newdocs-08-2021/ISCO-08/ISCO-08 EN.csv](https://www.ilo.org/ilostat-files/ISCO/newdocs-08-2021/ISCO-08/ISCO-08%20EN.csv)
120
 
121
  Args:
122
  - file: A string representing the path to the CSV file containing the 4-digit ISCO-08 codes. It can be a local path or a web URL.
123
 
124
  Returns:
125
- - A dictionary where keys are ISCO-08 unit codes and values are sets of their parent codes.
126
  """
127
 
128
  try:
@@ -146,7 +145,12 @@ class ISCO_Hierarchical_Accuracy(evaluate.Metric):
146
  minor_code = unit_code[0:3]
147
  sub_major_code = unit_code[0:2]
148
  major_code = unit_code[0]
149
- isco_hierarchy[unit_code] = {minor_code, major_code, sub_major_code}
 
 
 
 
 
150
 
151
  return isco_hierarchy
152
 
@@ -192,40 +196,53 @@ class ISCO_Hierarchical_Accuracy(evaluate.Metric):
192
  self,
193
  reference_codes: List[str],
194
  predicted_codes: List[str],
195
- hierarchy: Dict[str, Set[str]],
196
  ) -> Tuple[float, float]:
197
  """
198
  Calculates the hierarchical precision and recall given the reference codes, predicted codes, and hierarchy definition.
199
 
200
  Args:
201
- real_codes (List[str]): The list of reference codes.
202
  predicted_codes (List[str]): The list of predicted codes.
203
  hierarchy (Dict[str, Set[str]]): The hierarchy definition where keys are nodes and values are sets of parent nodes.
204
 
205
  Returns:
206
  Tuple[float, float]: A tuple containing the hierarchical precision and recall floating point values.
207
  """
208
- # Extend the sets of real and predicted codes with their ancestors
209
- extended_real = set()
210
- for code in reference_codes:
211
- extended_real.add(code)
212
- extended_real.update(hierarchy.get(code, set()))
213
 
214
- extended_predicted = set()
215
- for code in predicted_codes:
216
- extended_predicted.add(code)
217
- extended_predicted.update(hierarchy.get(code, set()))
 
 
 
 
218
 
219
- # Calculate the intersection
220
- correct_predictions = extended_real.intersection(extended_predicted)
221
 
222
- # Calculate hierarchical precision and recall
223
- hP = (
224
- len(correct_predictions) / len(extended_predicted)
225
- if extended_predicted
226
- else 0
227
- )
228
- hR = len(correct_predictions) / len(extended_real) if extended_real else 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
 
230
  return hP, hR
231
 
 
114
 
115
  def create_hierarchy_dict(self, file: str) -> dict:
116
  """
117
+ Creates a dictionary where keys are nodes and values are dictionaries of their parent nodes with distance as weights,
118
+ representing the group level hierarchy of the ISCO-08 structure.
 
119
 
120
  Args:
121
  - file: A string representing the path to the CSV file containing the 4-digit ISCO-08 codes. It can be a local path or a web URL.
122
 
123
  Returns:
124
+ - A dictionary where keys are ISCO-08 unit codes and values are dictionaries of their parent codes with distances.
125
  """
126
 
127
  try:
 
145
  minor_code = unit_code[0:3]
146
  sub_major_code = unit_code[0:2]
147
  major_code = unit_code[0]
148
+
149
+ # Assign weights, higher for closer ancestors
150
+ weights = {minor_code: 0.75, sub_major_code: 0.5, major_code: 0.25}
151
+
152
+ # Store ancestors with their weights
153
+ isco_hierarchy[unit_code] = weights
154
 
155
  return isco_hierarchy
156
 
 
196
  self,
197
  reference_codes: List[str],
198
  predicted_codes: List[str],
199
+ hierarchy: Dict[str, Dict[str, float]],
200
  ) -> Tuple[float, float]:
201
  """
202
  Calculates the hierarchical precision and recall given the reference codes, predicted codes, and hierarchy definition.
203
 
204
  Args:
205
+ reference_codes (List[str]): The list of reference codes.
206
  predicted_codes (List[str]): The list of predicted codes.
207
  hierarchy (Dict[str, Set[str]]): The hierarchy definition where keys are nodes and values are sets of parent nodes.
208
 
209
  Returns:
210
  Tuple[float, float]: A tuple containing the hierarchical precision and recall floating point values.
211
  """
212
+ extended_real = {}
 
 
 
 
213
 
214
+ # Extend the sets of reference codes with their ancestors
215
+ for code in reference_codes:
216
+ weight = 1.0 # Full weight for exact match
217
+ extended_real[code] = weight
218
+ for ancestor, ancestor_weight in hierarchy.get(code, {}).items():
219
+ extended_real[ancestor] = max(
220
+ extended_real.get(ancestor, 0), ancestor_weight
221
+ )
222
 
223
+ extended_predicted = {}
 
224
 
225
+ # Extend the sets of predicted codes with their ancestors
226
+ for code in predicted_codes:
227
+ weight = 1.0
228
+ extended_predicted[code] = weight
229
+ for ancestor, ancestor_weight in hierarchy.get(code, {}).items():
230
+ extended_predicted[ancestor] = max(
231
+ extended_predicted.get(ancestor, 0), ancestor_weight
232
+ )
233
+
234
+ # Calculate weighted correct predictions
235
+ correct_weights = 0
236
+ for code, weight in extended_predicted.items():
237
+ if code in extended_real:
238
+ correct_weights += min(weight, extended_real[code])
239
+
240
+ total_predicted_weights = sum(extended_predicted.values())
241
+ total_real_weights = sum(extended_real.values())
242
+
243
+ # Calculate hierarchical precision and recall using weighted sums
244
+ hP = correct_weights / total_predicted_weights if total_predicted_weights else 0
245
+ hR = correct_weights / total_real_weights if total_real_weights else 0
246
 
247
  return hP, hR
248