# File size: 4,274 Bytes
# d2ae486
import torch
import numpy as np
import torch.nn.functional as F
from lm_eval.base import BaseLM
from datasets import load_dataset
def set_seed(seed):
    """Seed the NumPy and PyTorch global random number generators.

    Args:
        seed: Value handed unchanged to ``np.random.seed`` and then to
            ``torch.random.manual_seed``.
    """
    # NumPy first, then torch.
    for seed_fn in (np.random.seed, torch.random.manual_seed):
        seed_fn(seed)
def get_test_dataset(dataset_name, tokenizer, seqlen=2048):
    """Load an evaluation corpus and pack it into token sequences.

    Every non-empty text item is tokenized without special tokens and
    terminated with ``tokenizer.eos_token_id``. Items are then concatenated
    greedily, in order, into sequences that each begin with
    ``tokenizer.bos_token_id``; a new sequence is started whenever adding the
    next item would exceed ``seqlen`` tokens.

    Args:
        dataset_name: ``"wikitext2"`` (test split of ``wikitext-2-raw-v1``,
            re-split on newlines) or ``"c4"`` (the
            ``en/c4-validation.00000-of-00008.json.gz`` shard of ``allenai/c4``).
        tokenizer: Object callable as ``tokenizer(text, add_special_tokens=False)``
            that returns a mapping with an ``"input_ids"`` list, and that
            exposes ``bos_token_id`` and ``eos_token_id``.
        seqlen: Maximum number of tokens per packed sequence, BOS included.

    Returns:
        A list of token-id lists. Each one starts with the BOS id, holds at
        least one item, and is never longer than ``seqlen``.

    Raises:
        NotImplementedError: If ``dataset_name`` is not one of the supported
            names.
    """
    if dataset_name == "wikitext2":
        testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
        testdata = "".join(testdata['text']).split('\n')
    elif dataset_name == "c4":
        testdata = load_dataset(
            'allenai/c4',
            data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'},
            split='validation',
        )['text']
    else:
        raise NotImplementedError(f"Unsupported dataset: {dataset_name!r}")

    bos_id = tokenizer.bos_token_id
    eos_id = tokenizer.eos_token_id

    testdata = [item for item in testdata if item != ""]
    tokenized_text = [
        tokenizer(item, add_special_tokens=False)['input_ids'] + [eos_id]
        for item in testdata
    ]

    data, doc = [], [bos_id]
    for sen in tokenized_text:
        # An item must fit in a sequence together with the leading BOS token.
        # Comparing against seqlen alone would let an item of exactly seqlen
        # tokens through, producing a BOS-only sequence followed by one of
        # seqlen + 1 tokens.
        if len(sen) + 1 > seqlen:
            continue
        if len(doc) + len(sen) > seqlen:
            # Current sequence is full: emit it and start a fresh one.
            data.append(doc)
            doc = [bos_id]
        doc.extend(sen)

    # Emit the trailing sequence unless it holds nothing but BOS; the loop
    # above guarantees it never exceeds seqlen.
    if len(doc) > 1:
        data.append(doc)
    return data
class LMEvalAdaptor(BaseLM):
    """Adapter exposing a model/tokenizer pair through the ``BaseLM`` interface.

    The wrapped ``model`` is invoked as ``model(input_ids)`` and
    ``model.generate(...)``; it must expose a ``config`` object unless an
    explicit ``max_length`` is supplied. The ``tokenizer`` must provide
    ``encode``, ``decode``, ``vocab_size`` and ``eos_token_id``.
    """

    # Config attributes that may hold the context window, in lookup order.
    _CONTEXT_LENGTH_ATTRS = ("n_ctx", "max_position_embeddings", "n_positions")
    # Model-name substrings for which the fallback window below is assumed
    # when the config exposes none of the attributes above.
    # TODO: the value for "llama" did not get checked.
    _FALLBACK_FAMILIES = ("bloom", "llama", "mpt", "falcon")
    _FALLBACK_MAX_LENGTH = 2048

    def __init__(self, model_name, model, tokenizer, batch_size=1, max_length=-1):
        """Wrap ``model`` and ``tokenizer`` for evaluation.

        Args:
            model_name: Name of the model; only used to guess the context
                length by substring when the config does not provide one.
            model: Model to evaluate. It is switched to eval mode here.
            tokenizer: Tokenizer paired with ``model``.
            batch_size: Evaluation batch size; must be an ``int``.
            max_length: Context-length override. ``-1`` means "infer it from
                the model config or the model name".
        """
        super().__init__()

        assert isinstance(batch_size, int)

        self.model_name = model_name

        self.model = model
        self.model.eval()

        self.tokenizer = tokenizer
        self.vocab_size = self.tokenizer.vocab_size

        self._batch_size = batch_size
        self._max_length = max_length

    @property
    def eot_token_id(self):
        """Token id used as end-of-text (the tokenizer's EOS id)."""
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        """Context length, in tokens.

        Resolution order: the explicit override given to ``__init__``, then
        the first of ``n_ctx`` / ``max_position_embeddings`` / ``n_positions``
        present on ``model.config``, then 2048 if the model name contains a
        known family substring.

        Raises:
            NotImplementedError: If none of the above yields a value.
        """
        if self._max_length != -1:
            return self._max_length

        config = self.model.config
        for attr in self._CONTEXT_LENGTH_ATTRS:
            if hasattr(config, attr):
                return getattr(config, attr)

        if any(family in self.model_name for family in self._FALLBACK_FAMILIES):
            return self._FALLBACK_MAX_LENGTH

        # Put the diagnostic details into the exception itself so they
        # survive wherever the error is logged or re-raised.
        raise NotImplementedError(
            f"Cannot determine max_length for model {self.model_name!r}; "
            f"config: {config}"
        )

    @property
    def max_gen_toks(self):
        """Maximum number of tokens to generate (fixed at 256)."""
        return 256

    @property
    def batch_size(self):
        """Batch size passed to ``__init__``."""
        return self._batch_size

    @property
    def device(self):
        """Device string for input tensors.

        Returns ``"cuda"`` whenever CUDA is available and falls back to
        ``"cpu"`` only on hosts without CUDA, where ``"cuda"`` could never
        work. The device of ``self.model`` is not inspected, so the model is
        expected to already live on the returned device.
        """
        return "cuda" if torch.cuda.is_available() else "cpu"

    def tok_encode(self, string: str, add_special_tokens=True):
        """Encode ``string`` with the tokenizer, optionally adding special tokens."""
        return self.tokenizer.encode(string, add_special_tokens=add_special_tokens)

    def tok_decode(self, tokens):
        """Decode ``tokens`` back to text with the tokenizer."""
        return self.tokenizer.decode(tokens)

    def loglikelihood(self, requests):
        """Tokenize ``(context, continuation)`` pairs and score them.

        Both strings are stripped of surrounding whitespace before
        tokenization. An empty context is replaced by the single EOT token;
        otherwise the context is encoded with special tokens and the
        continuation without them.

        Args:
            requests: Iterable of ``(context, continuation)`` string pairs.

        Returns:
            Whatever ``self._loglikelihood_tokens`` returns for the encoded
            requests (that method is not defined in this class; presumably
            it is inherited from ``BaseLM``).
        """
        new_reqs = []
        for context, continuation in requests:
            context, continuation = context.strip(), continuation.strip()
            if context == "":
                # end of text as context
                context_enc = [self.eot_token_id]
            else:
                context_enc = self.tok_encode(context, add_special_tokens=True)

            continuation_enc = self.tok_encode(continuation, add_special_tokens=False)

            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs)

    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call
        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model
        """
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            # The first element of the model output is taken as the logits.
            out = self.model(inps)[0]
        return out

    def _model_generate(self, context, max_length, eos_token_id):
        """Generate greedily (``do_sample=False``) from ``context``.

        ``max_length`` and ``eos_token_id`` are forwarded unchanged to
        ``self.model.generate``.
        """
        return self.model.generate(
            context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False
        )