Spaces:
Running
Running
# Copyright 2023 The HuggingFace Evaluate Authors. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Jaccard similarity metric.""" | |
import evaluate | |
import datasets | |
from sklearn.metrics import jaccard_score | |
import numpy as np | |
_CITATION = """\ | |
@article{jaccard1912distribution, | |
title={The distribution of the flora in the alpine zone}, | |
author={Jaccard, Paul}, | |
journal={New phytologist}, | |
volume={11}, | |
number={2}, | |
pages={37--50}, | |
year={1912}, | |
publisher={Wiley Online Library} | |
} | |
""" | |
_DESCRIPTION = """\ | |
Jaccard similarity is a statistic used for gauging the similarity and diversity of sample sets. | |
The Jaccard coefficient measures similarity between finite sample sets, and is defined as the size of | |
the intersection divided by the size of the union of the sample sets. This implementation uses | |
scikit-learn's jaccard_score function. | |
""" | |
_KWARGS_DESCRIPTION = """ | |
Calculates the Jaccard similarity between predictions and references using scikit-learn. | |
Args: | |
predictions: 1d array-like, or label indicator array / sparse matrix | |
Predicted labels, as returned by a classifier. | |
references: 1d array-like, or label indicator array / sparse matrix | |
Ground truth (correct) labels. | |
labels: array-like of shape (n_classes,), default=None | |
The set of labels to include when average != 'binary', and their order if average is None. | |
pos_label: int, float, bool or str, default=1 | |
The class to report if average='binary' and the data is binary. | |
average: {'micro', 'macro', 'samples', 'weighted', 'binary'} or None, default='binary' | |
This parameter is required for multiclass/multilabel targets. | |
sample_weight: array-like of shape (n_samples,), default=None | |
Sample weights. | |
zero_division: "warn", {0.0, 1.0}, default="warn" | |
Sets the value to return when there is a zero division. | |
Returns: | |
jaccard_similarity: float or ndarray of shape (n_unique_labels,) | |
Jaccard similarity score. | |
Examples: | |
>>> jaccard_metric = evaluate.load("jaccard_similarity") | |
>>> predictions = [0, 2, 1, 3] | |
>>> references = [0, 1, 2, 3] | |
>>> results = jaccard_metric.compute(predictions=predictions, references=references, average='macro') | |
>>> print(results) | |
{'jaccard_similarity': 0.5} | |
""" | |
class JaccardSimilarity(evaluate.Metric):
    """Jaccard similarity (intersection over union) backed by scikit-learn's ``jaccard_score``."""

    def _info(self):
        """Return the metric metadata, including the accepted input feature formats."""
        return evaluate.MetricInfo(
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            # Two accepted input formats. evaluate tries them in order and uses the first
            # one that can encode the incoming examples. The multilabel (label-indicator
            # rows) format stays first so existing callers behave as before; the scalar
            # format makes flat binary/multiclass label lists (and the 1-D branch of
            # `_compute`) actually reachable.
            features=[
                datasets.Features(
                    {
                        "predictions": datasets.Sequence(datasets.Value("int32")),
                        "references": datasets.Sequence(datasets.Value("int32")),
                    }
                ),
                datasets.Features(
                    {
                        "predictions": datasets.Value("int32"),
                        "references": datasets.Value("int32"),
                    }
                ),
            ],
            reference_urls=[
                "https://scikit-learn.org/stable/modules/generated/sklearn.metrics.jaccard_score.html",
                "https://en.wikipedia.org/wiki/Jaccard_index",
            ],
        )

    def _compute(
        self,
        predictions,
        references,
        labels=None,
        pos_label=1,
        average="binary",
        sample_weight=None,
        zero_division="warn",
    ):
        """Compute the Jaccard similarity between ``predictions`` and ``references``.

        Args:
            predictions: 1-D labels (binary/multiclass) or 2-D label-indicator rows (multilabel).
            references: Ground-truth labels with the same shape as ``predictions``.
            labels, pos_label, average, sample_weight, zero_division:
                Forwarded to ``sklearn.metrics.jaccard_score``. For 2-D (multilabel) input,
                ``average='binary'`` is replaced by ``'micro'`` because scikit-learn only
                accepts ``'binary'`` for binary targets.

        Returns:
            dict: ``{"jaccard_similarity": score}``. ``score`` is a Python ``float`` when the
            result is averaged, or a NumPy array of per-label scores when ``average=None``.

        Raises:
            ValueError: If the inputs are not both 1-D or both 2-D, or their shapes differ.
        """
        predictions = np.asarray(predictions)
        references = np.asarray(references)

        # Only 1-D (binary/multiclass) and 2-D (multilabel indicator) inputs are supported,
        # and both sides must use the same layout.
        if predictions.ndim not in (1, 2) or predictions.ndim != references.ndim:
            raise ValueError(
                "Predictions and references must both be 1-D (binary/multiclass) or both 2-D "
                f"(multilabel), got shapes {predictions.shape} and {references.shape}"
            )
        if predictions.shape != references.shape:
            raise ValueError(
                "Predictions and references should have the same shape, "
                f"got {predictions.shape} and {references.shape}"
            )
        if predictions.ndim == 2 and average == "binary":
            # 'binary' doesn't make sense for multilabel targets; fall back to micro-averaging.
            average = "micro"

        score = jaccard_score(
            references,
            predictions,
            labels=labels,
            pos_label=pos_label,
            average=average,
            sample_weight=sample_weight,
            zero_division=zero_division,
        )
        # Averaged results come back as a NumPy scalar; return a plain float so the result
        # serializes cleanly and prints as a bare number (NumPy >= 2 reprs np.float64 as
        # "np.float64(...)"). Per-label arrays (average=None) are returned unchanged.
        if np.ndim(score) == 0:
            score = float(score)
        return {"jaccard_similarity": score}