Spaces:
Build error
Build error
# Copyright 2022 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import evaluate | |
import datasets | |
import numpy as np | |
from vendi_score import vendi, image_utils, text_utils | |
# TODO: Add BibTeX citation
_CITATION = ""

# Short summary of the metric; handed to `evaluate.MetricInfo(description=...)`.
_DESCRIPTION = """\
The Vendi Score is a metric for evaluating diversity in machine learning.
The input to the metric is a collection of samples and a pairwise similarity function, and the output is a number, which can be interpreted as the effective number of unique elements in the sample.
See the project's README at https://github.com/vertaix/Vendi-Score for more information.
The interactive example calculates the Vendi Score for a set of strings using the n-gram overlap similarity, averaged between n=1 and n=2.
"""

# Argument reference; handed to `evaluate.MetricInfo(inputs_description=...)`.
# The kernel names quoted below must match the strings `_compute` compares
# `k` against ("ngram_overlap", "text_embeddings", "pixels",
# "image_embeddings").
_KWARGS_DESCRIPTION = """
Calculates the Vendi Score given samples and a similarity function.
Args:
    samples: an iterable containing n samples to score, an n x n similarity
        matrix K, or an n x d feature matrix X.
    k: a pairwise similarity function, or a string identifying a predefined
        similarity function.
        Options: ngram_overlap, text_embeddings, pixels, image_embeddings.
    score_K: if true, samples is an n x n similarity matrix K.
    score_X: if true, samples is an n x d feature matrix X.
    score_dual: if true, compute diversity score of X @ X.T.
    normalize: if true, normalize the similarity scores.
    model (optional): if k is "text_embeddings", a model mapping sentences to
        embeddings (output should be an object with an attribute called
        `pooler_output` or `last_hidden_state`). If k is "image_embeddings", a
        model mapping images to embeddings.
    tokenizer (optional): if k is "text_embeddings" or "ngram_overlap", a
        tokenizer mapping strings to lists.
    transform (optional): if k is "image_embeddings", a torchvision transform
        to apply to the samples.
    model_path (optional): if k is "text_embeddings", the name of a model on the
        HuggingFace hub.
    ns (optional): if k is "ngram_overlap", the values of n to calculate.
    batch_size (optional): batch size to use if k is "text_embeddings" or
        "image_embeddings".
    device (optional): a string (e.g. "cuda", "cpu") or torch.device identifying
        the device to use if k is "text_embeddings" or "image_embeddings".
Returns:
    VS: The Vendi Score.
Examples:
    >>> vendiscore = evaluate.load("danf0/vendiscore")
    >>> samples = ["Look, Jane.",
    ...            "See Spot.",
    ...            "See Spot run.",
    ...            "Run, Spot, run.",
    ...            "Jane sees Spot run."]
    >>> results = vendiscore.compute(samples, k="ngram_overlap", ns=[1, 2])
    >>> print(results)
    {'VS': 3.90657...}
"""
def get_dtype(config_name):
    """Map a configuration name to a ``datasets`` feature specification.

    Args:
        config_name: the configuration name to translate.

    Returns:
        * ``"text"``  -> ``Features`` with a string-valued ``"samples"`` column.
        * ``"image"`` -> ``Features`` whose ``"samples"`` entry is
          ``datasets.Image``.
        * ``"X"`` or ``"K"`` -> ``datasets.Array2D``.
        * ``"default"`` -> ``datasets.Value("string")``.
        * anything else -> ``datasets.Value(config_name)``, i.e. the name is
          used directly as the value's dtype string.
    """
    if config_name == "text":
        string_column = datasets.Value("string")
        return datasets.Features({"samples": string_column})

    if config_name == "image":
        # NOTE(review): the Image *class* (not an instance) is stored here --
        # confirm whether `datasets.Image()` was intended.
        return datasets.Features({"samples": datasets.Image})

    if config_name in ("X", "K"):
        # NOTE(review): returned as a class, without shape/dtype -- confirm.
        return datasets.Array2D

    # "default" is an alias for plain strings; any other name is taken to be
    # a dtype string understood by `datasets.Value`.
    dtype_name = "string" if config_name == "default" else config_name
    return datasets.Value(dtype_name)
def get_features(config_name):
    """Build the feature specification(s) for the ``"samples"`` input.

    Args:
        config_name: the configuration name of the module.

    Returns:
        For ``"text"`` and ``"default"``, a single ``datasets.Features`` with
        a string-valued ``"samples"`` column. For every other name, a list of
        alternative ``datasets.Features`` objects, one per candidate sample
        type, in this order: int32, sequence of int32, float, sequence of
        float, ``datasets.Image``, ``datasets.Array2D``.
    """
    if config_name in ("text", "default"):
        return datasets.Features({"samples": datasets.Value("string")})

    # Candidate types for a single sample, in the order they are offered.
    # NOTE(review): `datasets.Image` and `datasets.Array2D` are listed as
    # classes rather than instances -- confirm this is intended.
    candidate_sample_types = (
        datasets.Value("int32"),
        datasets.Sequence(datasets.Value("int32")),
        datasets.Value("float"),
        datasets.Sequence(datasets.Value("float")),
        datasets.Image,
        datasets.Array2D,
    )
    return [
        datasets.Features({"samples": sample_type})
        for sample_type in candidate_sample_types
    ]
class VendiScore(evaluate.Metric):
    """Evaluation module that reports the Vendi Score of a set of samples.

    All numerical work is delegated to the ``vendi_score`` package
    (``vendi``, ``text_utils``, ``image_utils``); this class only describes
    the module and routes ``compute`` arguments to the matching routine.
    """

    def _info(self):
        """Return the ``evaluate.MetricInfo`` describing this module.

        The accepted input features depend on ``self.config_name`` and are
        produced by ``get_features``.
        """
        return evaluate.MetricInfo(
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=get_features(self.config_name),
            homepage="http://github.com/Vertaix/Vendi-Score",
            codebase_urls=["http://github.com/Vertaix/Vendi-Score"],
            reference_urls=[],
        )

    def _download_and_prepare(self, dl_manager):
        """Download the NLTK ``"punkt"`` resource.

        Args:
            dl_manager: unused by this implementation.
        """
        # Imported lazily so that nltk is only needed when this hook runs.
        import nltk

        nltk.download("punkt")

    def _compute(
        self,
        samples,
        k="ngram_overlap",
        score_K=False,
        score_X=False,
        score_dual=False,
        normalize=False,
        model=None,
        tokenizer=None,
        transform=None,
        model_path=None,
        ns=None,
        batch_size=16,
        device="cpu",
    ):
        """Compute the Vendi Score for ``samples``.

        The scoring routine is selected in this order of precedence:

        1. ``score_K``    -> ``vendi.score_K(samples, normalize=...)``
        2. ``score_dual`` -> ``vendi.score_dual(samples, normalize=...)``
        3. ``score_X``    -> ``vendi.score_X(samples, normalize=...)``
        4. ``k`` equal to one of the predefined kernel names
           (``"ngram_overlap"``, ``"text_embeddings"``, ``"pixels"``,
           ``"image_embeddings"``) -> the matching ``text_utils`` /
           ``image_utils`` helper.
        5. otherwise -> ``vendi.score(samples, k)``, with ``k`` passed through
           unchanged.

        Args:
            samples: the samples, similarity matrix or feature matrix to
                score, depending on the flags above.
            k: a predefined kernel name, or a value forwarded to
                ``vendi.score``.
            score_K, score_X, score_dual: select the matrix-based routines.
            normalize: forwarded only on the ``score_K`` / ``score_dual`` /
                ``score_X`` paths.
            model, tokenizer, transform, model_path, batch_size, device:
                forwarded to the embedding / n-gram helpers that accept them.
            ns: values of n for ``"ngram_overlap"``; ``None`` means
                ``[1, 2]``.

        Returns:
            A dict ``{"VS": score}``.
        """
        # Resolve the default here rather than in the signature so that a
        # single list object is not shared between calls.
        if ns is None:
            ns = [1, 2]

        # Only strings can name a predefined kernel; anything else (e.g. a
        # callable) falls through to the generic `vendi.score` path.
        kernel_name = k if isinstance(k, str) else None

        if score_K:
            vs = vendi.score_K(samples, normalize=normalize)
        elif score_dual:
            vs = vendi.score_dual(samples, normalize=normalize)
        elif score_X:
            vs = vendi.score_X(samples, normalize=normalize)
        elif kernel_name == "ngram_overlap":
            vs = text_utils.ngram_vendi_score(
                samples, ns=ns, tokenizer=tokenizer
            )
        elif kernel_name == "text_embeddings":
            vs = text_utils.embedding_vendi_score(
                samples,
                model=model,
                tokenizer=tokenizer,
                batch_size=batch_size,
                device=device,
                model_path=model_path,
            )
        elif kernel_name == "pixels":
            vs = image_utils.pixel_vendi_score(samples)
        elif kernel_name == "image_embeddings":
            vs = image_utils.embedding_vendi_score(
                samples,
                batch_size=batch_size,
                device=device,
                model=model,
                transform=transform,
            )
        else:
            vs = vendi.score(samples, k)
        return {"VS": vs}