# Copyright 2022 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import evaluate
import datasets
import numpy as np
from vendi_score import vendi, text_utils
# TODO: Add BibTeX citation
_CITATION = ""
_DESCRIPTION = """\
The Vendi Score is a metric for evaluating diversity in machine learning.
The input to the metric is a collection of samples and a pairwise similarity function, and the output is a number, which can be interpreted as the effective number of unique elements in the sample.
See the project's README at https://github.com/vertaix/Vendi-Score for more information.
The interactive example calculates the Vendi Score for a set of strings using the n-gram overlap similarity, averaged between n=1 and n=2.
"""
_KWARGS_DESCRIPTION = """
Calculates the Vendi Score given samples and a similarity function.
Args:
    samples: an iterable containing n samples to score, an n x n similarity
        matrix K, or an n x d feature matrix X.
    k: a pairwise similarity function, or a string identifying a predefined
        similarity function.
        Options: ngram_overlap, text_embeddings.
    score_K: if true, samples is an n x n similarity matrix K.
    score_X: if true, samples is an n x d feature matrix X.
    score_dual: if true, compute the diversity score of X @ X.T.
    normalize: if true, normalize the similarity scores.
    model (optional): if k is "text_embeddings", a model mapping sentences to
        embeddings (output should be an object with an attribute called
        `pooler_output` or `last_hidden_state`).
    tokenizer (optional): if k is "text_embeddings" or "ngram_overlap", a
        tokenizer mapping strings to lists.
    model_path (optional): if k is "text_embeddings", the name of a model on the
        HuggingFace hub.
    ns (optional): if k is "ngram_overlap", the values of n to calculate.
    batch_size (optional): batch size to use if k is "text_embeddings".
    device (optional): a string (e.g. "cuda", "cpu") or torch.device identifying
        the device to use if k is "text_embeddings".
Returns:
    VS: The Vendi Score.
Examples:
    >>> vendiscore = evaluate.load("Vertaix/vendiscore", "text")
    >>> samples = ["Look, Jane.",
    ...            "See Spot.",
    ...            "See Spot run.",
    ...            "Run, Spot, run.",
    ...            "Jane sees Spot run."]
    >>> results = vendiscore.compute(samples=samples, k="ngram_overlap", ns=[1, 2])
    >>> print(results)
    {'VS': 3.90657...}
"""
def get_features(config_name):
    if config_name in ("text", "default"):
        return datasets.Features({"samples": datasets.Value("string")})
    # if config_name == "image":
    #     return datasets.Features({"samples": datasets.Image})
    if config_name in ("K", "X"):
        return [
            datasets.Features(
                {"samples": datasets.Sequence(datasets.Value("float"))}
            ),
            datasets.Features(
                {"samples": datasets.Sequence(datasets.Value("int32"))}
            ),
        ]
    return [
        datasets.Features({"samples": datasets.Value("float")}),
        datasets.Features({"samples": datasets.Value("int32")}),
        datasets.Features({"samples": datasets.Array2D}),
    ]
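
# In other words: the "text" / "default" configs expect one string per sample, the "K" and
# "X" configs expect one row of floats (or ints) per sample, and any other config falls
# through to the scalar / array features above.
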
@evaluate.utils.file_utils.add_start_docstrings(
    _DESCRIPTION, _KWARGS_DESCRIPTION
)
class VendiScore(evaluate.Metric):
    """Vendi Score: a metric for evaluating diversity in machine learning,
    computed from a collection of samples and a pairwise similarity function."""

    def _info(self):
        # Specifies the evaluate.EvaluationModuleInfo object for this metric.
        return evaluate.MetricInfo(
            # This is the description that will appear on the modules page.
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=get_features(self.config_name),
            homepage="http://github.com/Vertaix/Vendi-Score",
            codebase_urls=["http://github.com/Vertaix/Vendi-Score"],
            reference_urls=[],
        )

    def _download_and_prepare(self, dl_manager):
        # nltk's punkt data is used by the default word tokenization for the
        # ngram_overlap similarity.
        import nltk

        nltk.download("punkt")

    def _compute(
        self,
        samples,
        k="ngram_overlap",
        score_K=False,
        score_X=False,
        score_dual=False,
        normalize=False,
        model=None,
        tokenizer=None,
        model_path=None,
        ns=[1, 2],
        batch_size=16,
        device="cpu",
    ):
        # Dispatch on how the samples are provided: a precomputed similarity or
        # feature matrix, a named text similarity, or a user-supplied similarity
        # function.
        if score_K:
            vs = vendi.score_K(np.array(samples), normalize=normalize)
        elif score_dual:
            vs = vendi.score_dual(np.array(samples), normalize=normalize)
        elif score_X:
            vs = vendi.score_X(np.array(samples), normalize=normalize)
        elif isinstance(k, str) and k == "ngram_overlap":
            vs = text_utils.ngram_vendi_score(
                samples, ns=ns, tokenizer=tokenizer
            )
        elif isinstance(k, str) and k == "text_embeddings":
            vs = text_utils.embedding_vendi_score(
                samples,
                model=model,
                tokenizer=tokenizer,
                batch_size=batch_size,
                device=device,
                model_path=model_path,
            )
        # elif type(k) == str and k == "pixels":
        #     vs = image_utils.pixel_vendi_score(
        #         [Image.fromarray(x) for x in samples]
        #     )
        # elif type(k) == str and k == "image_embeddings":
        #     vs = image_utils.embedding_vendi_score(
        #         [Image.fromarray(x) for x in samples],
        #         batch_size=batch_size,
        #         device=device,
        #         model=model,
        #         transform=transform,
        #     )
        else:
            vs = vendi.score(samples, k)
        return {"VS": vs}
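
# A callable can also be passed as k, in which case the final branch above defers to
# vendi.score(samples, k); an illustrative sketch with a toy word-overlap similarity
# (the function name and values are hypothetical):
#
#     def word_overlap(a, b):
#         wa, wb = set(a.split()), set(b.split())
#         return len(wa & wb) / len(wa | wb)
#
#     vendiscore = evaluate.load("Vertaix/vendiscore", "text")
#     results = vendiscore.compute(samples=["a b", "a c", "d e"], k=word_overlap)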