Spaces:
Sleeping
Sleeping
File size: 4,697 Bytes
b97e015 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
from collections.abc import Mapping
from typing import Dict, Optional, Union

import datasets
import evaluate
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
# Long-form metric description; passed as `description=` to evaluate.MetricInfo
# in DmxPerplexity._info.
_DESCRIPTION = """
Perplexity metric implemented by d-Matrix.
Perplexity (PPL) is one of the most common metrics for evaluating language models.
It is defined as the exponentiated average negative log-likelihood of a sequence, calculated with exponent base `e`.
For more information, see https://huggingface.co/docs/transformers/perplexity
"""
_KWARGS_DESCRIPTION = """
Args:
model (Union[str,AutoModelForCausalLM]): model used for calculating Perplexity
NOTE: Perplexity can only be calculated for causal language models.
This includes models such as gpt2, causal variations of bert,
causal versions of t5, and more (the full list can be found
in the AutoModelForCausalLM documentation here:
https://huggingface.co/docs/transformers/master/en/model_doc/auto#transformers.AutoModelForCausalLM )
text (list of str): input text, each separate text snippet is one list entry.
device (str): device to run on, defaults to 'cuda' when available.
max_length (int): maximum sequence length, defaults to 2048.
Returns:
perplexity: dictionary containing the perplexity score and loss.
Examples:
Example:
>>> from datasets import load_dataset
>>> perplexity = evaluate.load("dmx_perplexity", module_type="metric")
>>> input_texts = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")["text"][:10] # doctest: +SKIP
>>> results = perplexity.compute(model='distilgpt2',
... text=input_texts)
>>> print(list(results.keys()))
['loss', 'perplexity']
>>> print(results['loss']) # doctest: +SKIP
3.8299286365509033
>>> print(results['perplexity']) # doctest: +SKIP
46.05925369262695
"""
class DmxPerplexity(evaluate.Metric):
    """Corpus perplexity of a causal language model, as an ``evaluate`` metric.

    All input snippets are joined (with a blank line between them) into one
    token stream, which is scored in windows of at most ``max_length`` tokens.
    """

    def _info(self):
        """Return the metric metadata consumed by ``evaluate``."""
        return evaluate.MetricInfo(
            module_type="metric",
            description=_DESCRIPTION,
            citation="",
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "text": datasets.Value("string"),
                }
            ),
            reference_urls=["https://huggingface.co/docs/transformers/perplexity"],
        )

    @staticmethod
    def _resolve_device(device):
        """Map the user-facing ``device`` argument to a torch device string.

        Raises:
            ValueError: if ``device`` is not None, "gpu", "cuda" or "cpu".
        """
        if device is None:
            return "cuda" if torch.cuda.is_available() else "cpu"
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if device not in ("gpu", "cpu", "cuda"):
            raise ValueError(
                f"device should be one of 'gpu', 'cuda' or 'cpu', got {device!r}."
            )
        return "cuda" if device == "gpu" else device

    @staticmethod
    def _resolve_max_seq_len(model, max_length):
        """Pick the window size in tokens.

        Uses ``max_length`` when truthy, else the model config's
        ``max_position_embeddings`` or ``n_positions`` (skipping attributes that
        are present but unset), else 2048.
        """
        if max_length:
            return max_length
        for attr in ("max_position_embeddings", "n_positions"):
            value = getattr(model.config, attr, None)
            if value:
                return value
        return 2048

    @staticmethod
    def _extract_loss(outputs):
        """Return the loss tensor from a model forward pass (mapping or attribute style)."""
        if isinstance(outputs, Mapping):
            return outputs["loss"]
        return outputs.loss

    def _accumulate_nll(self, model, input_ids, max_seq_len, stride, device):
        """Sum the negative log-likelihood over every scorable token.

        Args:
            model: causal LM already placed on ``device``.
            input_ids: token ids of the whole corpus, shape (1, seq_len) as
                produced by tokenizing a single joined string.
            max_seq_len: maximum window length in tokens.
            stride: distance between consecutive window starts.
            device: device the windows are moved to.

        Returns:
            ``(nll_sum, n_scored)``: total NLL and the number of tokens it covers.
        """
        seq_len = input_ids.size(1)
        nll_sum = 0.0
        n_scored = 0
        prev_end_loc = 0
        for begin_loc in tqdm(range(0, seq_len, stride)):
            end_loc = min(begin_loc + max_seq_len, seq_len)
            # Tokens already scored by the previous window serve only as context.
            trg_len = end_loc - prev_end_loc
            window = input_ids[:, begin_loc:end_loc].to(device)
            target_ids = window.clone()
            target_ids[:, :-trg_len] = -100
            # transformers causal-LM heads shift labels left by one internally, so
            # the window's first position is never a prediction target. Counting it
            # (as trg_len alone would) mis-weights windows, and a 1-token window
            # would yield a NaN loss over zero targets.
            n_window = min(trg_len, window.size(1) - 1)
            if n_window > 0:
                with torch.no_grad():
                    outputs = model(window, labels=target_ids)
                # Assumes the reported loss is the mean over non-ignored (-100)
                # labels, so multiply back by the count to get a sum.
                nll_sum += self._extract_loss(outputs).float().item() * n_window
                n_scored += n_window
            prev_end_loc = end_loc
            if end_loc == seq_len:
                break
        return nll_sum, n_scored

    def _compute(
        self,
        text,
        model: Union[str, AutoModelForCausalLM],
        device=None,
        max_length=None,
        stride: Optional[int] = None,
    ):
        """Compute corpus-level loss and perplexity.

        Args:
            text: list of strings, joined with a blank line between snippets and
                tokenized as one sequence.
            model: model id/path for ``from_pretrained`` or an already-loaded
                model. A loaded model is moved to ``device`` and temporarily
                switched to eval mode while scoring.
            device: "gpu", "cuda" or "cpu"; defaults to CUDA when available.
            max_length: window size in tokens; defaults to the model config's
                ``max_position_embeddings`` / ``n_positions``, else 2048.
            stride: step between window starts; defaults to the window size
                (non-overlapping windows). A smaller stride gives scored tokens
                more left context at extra compute cost.

        Returns:
            dict with ``loss`` (mean negative log-likelihood per scored token)
            and ``perplexity`` (``exp(loss)``), both Python floats.

        Raises:
            ValueError: on an invalid ``device`` or ``stride``, or when the text
                tokenizes to fewer than two tokens, so nothing can be scored.
        """
        device = self._resolve_device(device)
        if isinstance(model, str):
            tokenizer = AutoTokenizer.from_pretrained(model)
            model = AutoModelForCausalLM.from_pretrained(model)
        else:
            tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path)

        max_seq_len = self._resolve_max_seq_len(model, max_length)
        if stride is None:
            stride = max_seq_len
        elif not 0 < stride <= max_seq_len:
            raise ValueError(f"stride must be in [1, {max_seq_len}], got {stride}.")

        model = model.to(device)
        input_ids = tokenizer("\n\n".join(text), return_tensors="pt").input_ids

        # Score without dropout, but hand a caller-supplied model back in the
        # training mode it arrived in.
        was_training = bool(getattr(model, "training", False))
        if was_training:
            model.eval()
        try:
            nll_sum, n_scored = self._accumulate_nll(
                model, input_ids, max_seq_len, stride, device
            )
        finally:
            if was_training:
                model.train()

        if n_scored == 0:
            raise ValueError(
                "text must tokenize to at least two tokens to compute perplexity."
            )
        loss = torch.tensor(nll_sum / n_scored, dtype=torch.float64)
        ppl = torch.exp(loss)
        return dict(
            loss=loss.item(),
            perplexity=ppl.item(),
        )
|