Text Generation
English
Eval Results
d-matrix-user commited on
Commit
6626064
1 Parent(s): 3434f81

adding distilgpt2 dmatrix model

Browse files
Files changed (1) hide show
  1. perplexity.py +129 -0
perplexity.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import evaluate
2
+ import datasets
3
+ from evaluate import logging
4
+ from typing import Union, Dict
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer
6
+ import torch
7
+ from tqdm import tqdm
8
+
9
+ _DESCRIPTION = """
10
+ Perplexity metric implemented by d-Matrix.
11
+ Perplexity (PPL) is one of the most common metrics for evaluating language models.
12
+ It is defined as the exponentiated average negative log-likelihood of a sequence, calculated with exponent base `e`.
13
+ For more information, see https://huggingface.co/docs/transformers/perplexity
14
+ """
15
+
16
+ _KWARGS_DESCRIPTION = """
17
+ Args:
18
+ model (Union[str,AutoModelForCausalLM]): model used for calculating Perplexity
19
+ NOTE: Perplexity can only be calculated for causal language models.
20
+ This includes models such as gpt2, causal variations of bert,
21
+ causal versions of t5, and more (the full list can be found
22
+ in the AutoModelForCausalLM documentation here:
23
+ https://huggingface.co/docs/transformers/master/en/model_doc/auto#transformers.AutoModelForCausalLM )
24
+ predictions (list of str): input text, each separate text snippet
25
+ is one list entry.
26
+ device (str): device to run on, defaults to 'cuda' when available.
27
+ max_length (int): maximum sequence length, defaults to 2048.
28
+ Returns:
29
+ perplexity: dictionary containing the perplexity score and loss.
30
+ Examples:
31
+ Example:
32
+ >>> from datasets import load_dataset
33
+ >>> perplexity = evaluate.load("dmx_perplexity", module_type="metric")
34
+ >>> input_texts = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")["text"][:10] # doctest: +SKIP
35
+ >>> results = perplexity.compute(model='distilgpt2',
36
+ ... predictions=input_texts)
37
+ >>> print(list(results.keys()))
38
+ ['loss', 'perplexity']
39
+ >>> print(results['loss']) # doctest: +SKIP
40
+ 3.8299286365509033
41
+ >>> print(results['perplexity']) # doctest: +SKIP
42
+ 46.05925369262695
43
+ """
44
+
45
+
46
+ class DmxPerplexity(evaluate.Metric):
47
+ def _info(self):
48
+ return evaluate.MetricInfo(
49
+ module_type="metric",
50
+ description=_DESCRIPTION,
51
+ citation="",
52
+ inputs_description=_KWARGS_DESCRIPTION,
53
+ features=datasets.Features(
54
+ {
55
+ "predictions": datasets.Value("string"),
56
+ }
57
+ ),
58
+ reference_urls=["https://huggingface.co/docs/transformers/perplexity"],
59
+ )
60
+
61
+ def _compute(
62
+ self,
63
+ predictions,
64
+ model: Union[str, AutoModelForCausalLM],
65
+ device=None,
66
+ max_length=None,
67
+ ):
68
+ if device is not None:
69
+ assert device in [
70
+ "gpu",
71
+ "cpu",
72
+ "cuda",
73
+ ], "device should be either gpu or cpu."
74
+ if device == "gpu":
75
+ device = "cuda"
76
+ else:
77
+ device = "cuda" if torch.cuda.is_available() else "cpu"
78
+
79
+ if isinstance(model, str):
80
+ tokenizer = AutoTokenizer.from_pretrained(model)
81
+ model = AutoModelForCausalLM.from_pretrained(model)
82
+
83
+ else:
84
+ tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path)
85
+
86
+ if max_length:
87
+ max_seq_len = max_length
88
+ elif hasattr(model.config, "max_position_embeddings"):
89
+ max_seq_len = model.config.max_position_embeddings
90
+ elif hasattr(model.config, "n_positions"):
91
+ max_seq_len = model.config.n_positions
92
+ else:
93
+ max_seq_len = 2048
94
+
95
+ model = model.to(device)
96
+ encodings = tokenizer("\n\n".join(predictions), return_tensors="pt")
97
+
98
+ stride = max_seq_len
99
+ seq_len = encodings.input_ids.size(1)
100
+
101
+ nlls = []
102
+ prev_end_loc = 0
103
+ for begin_loc in tqdm(range(0, seq_len, stride)):
104
+ end_loc = min(begin_loc + max_seq_len, seq_len)
105
+ trg_len = end_loc - prev_end_loc
106
+ input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
107
+ target_ids = input_ids.clone()
108
+ target_ids[:, :-trg_len] = -100
109
+
110
+ with torch.no_grad():
111
+ outputs = model(input_ids, labels=target_ids)
112
+ if isinstance(outputs, Dict):
113
+ neg_log_likelihood = outputs["loss"] * trg_len
114
+ else:
115
+ neg_log_likelihood = outputs.loss * trg_len
116
+
117
+ nlls.append(neg_log_likelihood)
118
+
119
+ prev_end_loc = end_loc
120
+ if end_loc == seq_len:
121
+ break
122
+
123
+ loss = torch.stack(nlls).float().sum() / end_loc
124
+ ppl = torch.exp(loss)
125
+
126
+ return dict(
127
+ loss=loss.item(),
128
+ perplexity=ppl.item(),
129
+ )