|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""ISCO-08 Hierarchical Accuracy Measure.""" |
|
|
|
import csv
import logging
from typing import Dict, List, Set, Tuple

import datasets
import evaluate
|
|
|
|
|
|
|
|
|
|
|
|
|
# BibTeX entry for the paper that defines the hierarchical precision/recall
# measure this module implements (the same work named in _DESCRIPTION).
_CITATION = """
@inproceedings{kiritchenko2005functional,
  title={Functional Annotation of Genes Using Hierarchical Text Categorization},
  author={Kiritchenko, Svetlana and Matwin, Stan and Famili, A. Fazel},
  booktitle={Proceedings of the BioLINK SIG: Linking Literature, Information and Knowledge for Biology (held with ISMB-05)},
  year={2005}
}
"""
|
|
|
# Short summary of the metric. It is handed to the metric info as
# `description` and prepended to the metric class docstring.
_DESCRIPTION = (
    "\n"
    "The ISCO-08 Hierarchical Accuracy Measure is an implementation of the measure described in "
    "[Functional Annotation of Genes Using Hierarchical Text Categorization]"
    "(https://www.researchgate.net/publication/44046343_Functional_Annotation_of_Genes_Using_Hierarchical_Text_Categorization) "
    "(Kiritchenko, Svetlana and Famili, Fazel. 2005) "
    "and adapted for the ISCO-08 classification scheme by the International Labour Organization."
    "\n"
)
|
|
|
# Usage text for the metric: argument/return documentation plus a worked
# example. The keyword names in the example must match the declared input
# features ("references" and "predictions").
_KWARGS_DESCRIPTION = """
Calculates hierarchical precision, hierarchical recall and hierarchical F1 given a list of reference codes and predicted codes from the ISCO-08 taxonomy by the International Labour Organization.

Args:
    - references (List[str]): List of ISCO-08 reference codes. Each reference code should be a single token, 4-digit ISCO-08 code string.
    - predictions (List[str]): List of machine predicted or human assigned ISCO-08 codes to score. Each prediction should be a single token, 4-digit ISCO-08 code string.

Returns:
    - accuracy (`float`): Fraction of predictions that exactly match the reference code at the same position. Minimum possible value is 0. Maximum possible value is 1.0. A higher score means higher accuracy.
    - hierarchical_precision (`float` or `int`): Hierarchical precision score. Minimum possible value is 0. Maximum possible value is 1.0. A higher score means higher accuracy.
    - hierarchical_recall (`float` or `int`): Hierarchical recall score. Minimum possible value is 0. Maximum possible value is 1.0. A higher score means higher accuracy.
    - hierarchical_fmeasure (`float` or `int`): Hierarchical F1 score. Minimum possible value is 0. Maximum possible value is 1.0. A higher score means higher accuracy.

Examples:
    Example 1

    >>> ham = evaluate.load("danieldux/isco_hierarchical_accuracy")
    >>> results = ham.compute(references=["1111", "1112", "1113", "1114"], predictions=["1111", "1113", "1120", "1211"])
    >>> print(results)
    {
    'accuracy': 0.25,
    'hierarchical_precision': 0.5,
    'hierarchical_recall': 0.7142857142857143,
    'hierarchical_fmeasure': 0.588235294117647
    }
"""
|
|
|
|
|
# Mirror of the ISCO-08 structure table (CSV) on Google Cloud Storage.
ISCO_CSV_MIRROR_URL = "https://storage.googleapis.com/isco-public/tables/ISCO_structure.csv"

# ISCO-08 CSV as published on ilo.org.
ILO_ISCO_CSV_URL = "https://www.ilo.org/ilostat-files/ISCO/newdocs-08-2021/ISCO-08/ISCO-08%20EN.csv"
|
|
|
|
|
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class ISCO_Hierarchical_Accuracy(evaluate.Metric):
    """Exact-match accuracy plus hierarchical precision/recall/F for ISCO-08 codes.

    A 4-digit unit code is expanded with its 3-, 2- and 1-digit prefixes (its
    minor, sub-major and major groups). Hierarchical precision and recall are
    then computed on the expanded code sets, so a prediction that lands in the
    right group but on the wrong unit still earns partial credit.
    """

    def _info(self):
        """Return the metric description and the declared input schema.

        With the ``"multilabel"`` configuration each reference/prediction is a
        sequence of strings; otherwise each is a single string.
        """
        if self.config_name == "multilabel":
            feature_spec = {
                "references": datasets.Sequence(datasets.Value("string")),
                "predictions": datasets.Sequence(datasets.Value("string")),
            }
        else:
            feature_spec = {
                "references": datasets.Value("string"),
                "predictions": datasets.Value("string"),
            }

        return evaluate.MetricInfo(
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(feature_spec),
            # NOTE(review): the three URLs below look like unedited template
            # placeholders; replace them once the real locations are known.
            homepage="http://module.homepage",
            codebase_urls=["http://github.com/path/to/codebase/of/new_module"],
            reference_urls=["http://path.to.reference.url/new_module"],
        )

    def create_hierarchy_dict(self, file: str) -> dict:
        """Build the unit-code -> parent-groups mapping from an ISCO-08 CSV.

        Args:
            file: Local path or ``http(s)://`` URL of a CSV file that has a
                ``unit`` column holding ISCO-08 unit codes.

        Returns:
            A dictionary keyed by 4-character unit code. Each value maps the
            code's 3-, 2- and 1-character prefixes (minor, sub-major and major
            group) to the weights 0.75, 0.5 and 0.25 respectively.

        Raises:
            KeyError: If the CSV has no ``unit`` column.
            requests.HTTPError: If ``file`` is a URL and the server answers
                with an error status.
        """
        if file.startswith(("http://", "https://")):
            # Only remote files need requests, so local use works without it.
            import requests

            response = requests.get(file, timeout=30)
            # Fail on 4xx/5xx instead of trying to parse an error page as CSV.
            response.raise_for_status()
            lines = response.text.splitlines()
        else:
            with open(file, newline="", encoding="utf-8") as csvfile:
                lines = csvfile.readlines()

        isco_hierarchy = {}
        for row in csv.DictReader(lines):
            raw_unit = row["unit"].strip()
            if not raw_unit:
                # A blank cell would otherwise be padded to a bogus "0000" code.
                continue
            # Codes that lost a leading zero (e.g. "110") are padded back to 4 characters.
            unit_code = raw_unit.zfill(4)
            isco_hierarchy[unit_code] = {
                unit_code[:3]: 0.75,  # minor group
                unit_code[:2]: 0.5,  # sub-major group
                unit_code[:1]: 0.25,  # major group
            }
        return isco_hierarchy

    def find_ancestors(
        self, node: str, hierarchy: Dict[str, Dict[str, float]]
    ) -> Set[str]:
        """Collect every ancestor reachable from ``node``.

        Args:
            node: The code whose ancestors are wanted.
            hierarchy: Maps a node to its parents. Only the parent identifiers
                are used, so the value may be a weights dict (as produced by
                ``create_hierarchy_dict``) or any iterable of parent ids.

        Returns:
            The set of all direct and transitive parents of ``node``; empty if
            ``node`` is not a key of ``hierarchy``.
        """
        ancestors: Set[str] = set()
        pending = [node]
        while pending:
            current = pending.pop()
            # Iterating a dict yields its keys, i.e. the parent identifiers.
            for parent in hierarchy.get(current, ()):
                # Expanding each ancestor only once keeps the walk finite even
                # if the hierarchy contains a cycle.
                if parent not in ancestors:
                    ancestors.add(parent)
                    pending.append(parent)
        return ancestors

    def extend_with_ancestors(self, classes: set, hierarchy: dict) -> set:
        """Return ``classes`` together with all ancestors of its members.

        Args:
            classes: The codes to extend.
            hierarchy: Node -> parents mapping, as accepted by ``find_ancestors``.

        Returns:
            A new set holding the given codes and their ancestors; the input
            is not modified.
        """
        extended = set(classes)
        for code in classes:
            extended |= self.find_ancestors(code, hierarchy)
        return extended

    def calculate_hierarchical_precision_recall(
        self,
        reference_codes: List[str],
        predicted_codes: List[str],
        hierarchy: Dict[str, Dict[str, float]],
    ) -> Tuple[float, float]:
        """Compute hierarchical precision and recall.

        All reference codes are pooled into one set and all predicted codes
        into another; each set is extended with the ancestors of its members.
        Precision is the share of the extended predicted set that also occurs
        in the extended reference set, and recall is the share of the extended
        reference set that also occurs in the extended predicted set.

        Args:
            reference_codes: The reference codes.
            predicted_codes: The predicted codes.
            hierarchy: Node -> parents mapping, as accepted by ``find_ancestors``.

        Returns:
            ``(hP, hR)``. A component is 0.0 when its denominator set is empty.
        """
        extended_real = self.extend_with_ancestors(set(reference_codes), hierarchy)
        extended_predicted = self.extend_with_ancestors(
            set(predicted_codes), hierarchy
        )

        # The same overlap is the numerator of both precision and recall.
        overlap = extended_real & extended_predicted

        hP = len(overlap) / len(extended_predicted) if extended_predicted else 0.0
        hR = len(overlap) / len(extended_real) if extended_real else 0.0
        return hP, hR

    def hierarchical_f_measure(self, hP, hR, beta=1.0):
        """Combine hierarchical precision and recall into an F-measure.

        Args:
            hP (float): The hierarchical precision.
            hR (float): The hierarchical recall.
            beta (float, optional): Weight of recall relative to precision.
                Default is 1.0.

        Returns:
            float: ``(beta**2 + 1) * hP * hR / (beta**2 * hP + hR)``, or 0.0
            when that denominator is zero.
        """
        denominator = beta**2 * hP + hR
        # Test the actual denominator: with beta == 0 it is zero whenever
        # hR == 0, even if hP > 0.
        if denominator == 0:
            return 0.0
        return (beta**2 + 1) * hP * hR / denominator

    def _download_and_prepare(self, dl_manager):
        """Download the ISCO-08 CSV from the mirror and build ``self.isco_hierarchy``."""
        logger = logging.getLogger(__name__)

        isco_csv = dl_manager.download_and_extract(ISCO_CSV_MIRROR_URL)
        logger.info("ISCO CSV file downloaded")

        self.isco_hierarchy = self.create_hierarchy_dict(isco_csv)
        logger.info(
            "Weighted ISCO hierarchy dictionary created (%d unit codes)",
            len(self.isco_hierarchy),
        )

    def _compute(self, predictions, references):
        """Return exact-match accuracy and the hierarchical scores.

        Args:
            predictions: Predicted codes; each item is converted with ``str``.
            references: Reference codes; each item is converted with ``str``.

        Returns:
            A dict with the keys ``accuracy``, ``hierarchical_precision``,
            ``hierarchical_recall`` and ``hierarchical_fmeasure``.
        """
        # NOTE(review): in the "multilabel" configuration each item is a
        # sequence, so str() yields the list's repr rather than a code —
        # confirm whether that configuration is meant to be supported here.
        predictions = [str(p) for p in predictions]
        references = [str(r) for r in references]

        if predictions:
            exact_matches = sum(p == r for p, r in zip(predictions, references))
            accuracy = exact_matches / len(predictions)
        else:
            # Nothing to score: report 0.0 instead of dividing by zero.
            accuracy = 0.0

        hP, hR = self.calculate_hierarchical_precision_recall(
            references, predictions, self.isco_hierarchy
        )
        hF = self.hierarchical_f_measure(hP, hR)

        logging.getLogger(__name__).debug(
            "Accuracy: %s, Hierarchical Precision: %s, "
            "Hierarchical Recall: %s, Hierarchical F-measure: %s",
            accuracy,
            hP,
            hR,
            hF,
        )

        return {
            "accuracy": accuracy,
            "hierarchical_precision": hP,
            "hierarchical_recall": hR,
            "hierarchical_fmeasure": hF,
        }
|
|