File size: 4,697 Bytes
b97e015
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import evaluate
import datasets
from typing import Union, Dict
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from tqdm import tqdm

_DESCRIPTION = """
Perplexity metric implemented by d-Matrix.
Perplexity (PPL) is one of the most common metrics for evaluating language models.
It is defined as the exponentiated average negative log-likelihood of a sequence, calculated with exponent base `e`.
For more information, see https://huggingface.co/docs/transformers/perplexity
"""

_KWARGS_DESCRIPTION = """
Args:
    model (Union[str,AutoModelForCausalLM]): model used for calculating Perplexity
            NOTE: Perplexity can only be calculated for causal language models.
                    This includes models such as gpt2, causal variations of bert,
                    causal versions of t5, and more (the full list can be found
                    in the AutoModelForCausalLM documentation here:
                    https://huggingface.co/docs/transformers/master/en/model_doc/auto#transformers.AutoModelForCausalLM )
    text (list of str): input text, each separate text snippet is one list entry.
    device (str): device to run on, defaults to 'cuda' when available.
    max_length (int): maximum sequence length, defaults to 2048.
Returns:
    perplexity: dictionary containing the perplexity score and loss.
Examples:
    Example:
        >>> from datasets import load_dataset
        >>> perplexity = evaluate.load("dmx_perplexity", module_type="metric")
        >>> input_texts = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")["text"][:10] # doctest: +SKIP
        >>> results = perplexity.compute(model='distilgpt2',
        ...                              text=input_texts)
        >>> print(list(results.keys()))
        ['loss', 'perplexity']
        >>> print(results['loss']) # doctest: +SKIP
        3.8299286365509033
        >>> print(results['perplexity']) # doctest: +SKIP
        46.05925369262695
"""


class DmxPerplexity(evaluate.Metric):
    """Corpus perplexity of a causal language model.

    All text snippets are joined into one token stream. The stream is scored
    in consecutive, non-overlapping windows of ``max_seq_len`` tokens, and the
    per-window losses are combined into a single per-token average.
    """

    def _info(self):
        """Describe the metric (inputs, features and references) for `evaluate`."""
        return evaluate.MetricInfo(
            module_type="metric",
            description=_DESCRIPTION,
            citation="",
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "text": datasets.Value("string"),
                }
            ),
            reference_urls=["https://huggingface.co/docs/transformers/perplexity"],
        )

    @staticmethod
    def _resolve_device(device):
        """Return the torch device string to run on.

        ``None`` selects ``"cuda"`` when available, otherwise ``"cpu"``.
        ``"gpu"`` is accepted as an alias for ``"cuda"``.

        Raises:
            ValueError: if ``device`` is not one of ``"gpu"``, ``"cpu"`` or ``"cuda"``.
        """
        if device is None:
            return "cuda" if torch.cuda.is_available() else "cpu"
        # Explicit raise rather than `assert`, which is stripped under `python -O`.
        if device not in ("gpu", "cpu", "cuda"):
            raise ValueError("device should be one of 'gpu', 'cuda' or 'cpu'.")
        return "cuda" if device == "gpu" else device

    @staticmethod
    def _resolve_max_seq_len(config, max_length):
        """Return the window length in tokens.

        Precedence: a truthy ``max_length``, then
        ``config.max_position_embeddings``, then ``config.n_positions``, then 2048.
        """
        if max_length:
            return max_length
        if hasattr(config, "max_position_embeddings"):
            return config.max_position_embeddings
        if hasattr(config, "n_positions"):
            return config.n_positions
        return 2048

    @staticmethod
    def _extract_loss(outputs):
        """Return the loss tensor from a model forward-pass result.

        Mapping outputs (e.g. ``ModelOutput``, a dict subclass) are indexed by
        ``"loss"``. Plain tuples (``return_dict=False``) are assumed to carry
        the loss first, as is conventional when labels are passed. Anything
        else is read via ``.loss``.
        """
        if isinstance(outputs, dict):
            return outputs["loss"]
        if isinstance(outputs, tuple):
            return outputs[0]
        return outputs.loss

    def _compute(
        self,
        text,
        model: Union[str, AutoModelForCausalLM],
        device=None,
        max_length=None,
    ):
        """Compute the perplexity of ``model`` on ``text``.

        Args:
            text: list of text snippets. They are joined with two newlines and
                tokenized as a single sequence.
            model: a model name/path, loaded with ``AutoModelForCausalLM``, or a
                model instance. For an instance, the tokenizer is loaded from
                ``model.config._name_or_path``.
            device: ``"gpu"``, ``"cuda"``, ``"cpu"`` or ``None`` (auto-detect).
            max_length: window length in tokens (see ``_resolve_max_seq_len``).

        Returns:
            dict with ``loss`` (average negative log-likelihood per scored token)
            and ``perplexity`` (``exp(loss)``), both as Python floats.

        Raises:
            ValueError: for an unsupported ``device``, or when no window has the
                two or more tokens needed for a next-token prediction.
        """
        device = self._resolve_device(device)

        if isinstance(model, str):
            tokenizer = AutoTokenizer.from_pretrained(model)
            model = AutoModelForCausalLM.from_pretrained(model)
        else:
            tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path)

        max_seq_len = self._resolve_max_seq_len(model.config, max_length)

        model = model.to(device)
        encodings = tokenizer("\n\n".join(text), return_tensors="pt")
        seq_len = encodings.input_ids.size(1)

        # The stride equals the window length, so windows are consecutive and
        # non-overlapping.
        stride = max_seq_len
        nlls = []
        n_scored_tokens = 0
        prev_end_loc = 0
        for begin_loc in tqdm(range(0, seq_len, stride)):
            end_loc = min(begin_loc + max_seq_len, seq_len)
            trg_len = end_loc - prev_end_loc
            prev_end_loc = end_loc

            input_ids = encodings.input_ids[:, begin_loc:end_loc]
            if input_ids.size(1) < 2:
                # A one-token window has no next-token target. The model's mean
                # loss over zero targets is undefined (NaN for HF models) and
                # would turn the whole result into NaN, so skip the window.
                continue
            input_ids = input_ids.to(device)
            target_ids = input_ids.clone()
            # Mask context-only positions. With stride == window length this
            # slice is empty, so every token in the window is a target.
            target_ids[:, :-trg_len] = -100

            with torch.no_grad():
                outputs = model(input_ids, labels=target_ids)
                # Approximate the window's summed NLL as mean loss * target count.
                neg_log_likelihood = self._extract_loss(outputs) * trg_len

            nlls.append(neg_log_likelihood)
            n_scored_tokens += trg_len

        if not nlls:
            raise ValueError(
                "Perplexity needs at least one window of two or more tokens; "
                f"got {seq_len} token(s) with a window length of {max_seq_len}."
            )

        # When no window was skipped, n_scored_tokens == seq_len.
        loss = torch.stack(nlls).float().sum() / n_scored_tokens
        ppl = torch.exp(loss)

        return dict(
            loss=loss.item(),
            perplexity=ppl.item(),
        )