Spaces:
Running
Running
File size: 11,465 Bytes
7156337 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 |
# Copyright (c) 2021, EleutherAI
# This file is based on code by the authors denoted below and has been modified from its original version.
#
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron tokenizers."""
from abc import ABC
from abc import abstractmethod
from tokenizers import Tokenizer
from transformers import GPT2Tokenizer, GPT2TokenizerFast
import numpy as np
import sentencepiece as spm
from typing import List, Union
from .gpt2_tokenization import GPT2Tokenizer
def build_tokenizer(args):
    """Instantiate the tokenizer named by ``args.tokenizer_type``.

    The type name is matched case-insensitively against the supported
    tokenizers. On success, ``args.padded_vocab_size`` is set to the
    tokenizer's vocabulary size rounded up by ``_vocab_size_with_padding``.

    Args:
        args: namespace providing at least ``rank`` and ``tokenizer_type``,
            plus ``vocab_file`` / ``merge_file`` for the types that need them.

    Returns:
        The constructed tokenizer instance.

    Raises:
        AssertionError: a file argument required by the chosen type is None.
        NotImplementedError: ``args.tokenizer_type`` is not recognised.
    """
    if args.rank == 0:
        print("> building {} tokenizer ...".format(args.tokenizer_type), flush=True)

    # One constructor per supported type. Each is only called after the
    # lookup below succeeds, so only the selected type's checks run.
    def gpt2_bpe():
        assert args.vocab_file is not None
        assert args.merge_file is not None
        return _GPT2BPETokenizer(args.vocab_file, args.merge_file)

    def sentencepiece():
        assert args.vocab_file is not None
        return SentencePieceTokenizer(args.vocab_file)

    def hf_tokenizers():
        assert args.vocab_file is not None
        return HFTokenizer(args.vocab_file)

    def hf_gpt2():
        if args.vocab_file is None:
            print(
                "WARNING: No vocab file found, loading Huggingface's pretrained GPT2Tokenizer"
            )
        return HFGPT2Tokenizer(args.vocab_file)

    def char_level():
        return CharLevelTokenizer(vocab_size=512)

    def tiktoken_bpe():
        assert args.vocab_file is not None
        return TiktokenTokenizer(args.vocab_file)

    # Keys are lower-cased so that the lookup is case-insensitive.
    constructors = {
        "GPT2BPETokenizer".lower(): gpt2_bpe,
        "SPMTokenizer".lower(): sentencepiece,
        "HFTokenizer".lower(): hf_tokenizers,
        "HFGPT2Tokenizer".lower(): hf_gpt2,
        "CharLevelTokenizer".lower(): char_level,
        "TiktokenTokenizer".lower(): tiktoken_bpe,
    }
    make = constructors.get(args.tokenizer_type.lower())
    if make is None:
        raise NotImplementedError(
            "{} tokenizer is not implemented.".format(args.tokenizer_type)
        )
    tokenizer = make()

    # Record the padded vocabulary size for model construction.
    args.padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size, args)
    return tokenizer
def _vocab_size_with_padding(orig_vocab_size, args):
"""Pad vocab size so it is divisible by model parallel size and
still having GPU friendly size."""
after = orig_vocab_size
multiple = args.make_vocab_size_divisible_by * args.model_parallel_size
while (after % multiple) != 0:
after += 1
if args.rank == 0:
print(
" > padded vocab (size: {}) with {} dummy tokens "
"(new size: {})".format(orig_vocab_size, after - orig_vocab_size, after),
flush=True,
)
return after
class AbstractTokenizer(ABC):
    """Common interface for every tokenizer built by ``build_tokenizer``.

    Subclasses must implement ``vocab_size``, ``vocab``, ``inv_vocab`` and
    ``tokenize``. ``detokenize`` and the special-token properties
    (``cls``, ``sep``, ``pad``, ``eod``, ``mask``) raise
    ``NotImplementedError`` unless a subclass overrides them.
    """

    def __init__(self, name):
        # Human-readable identifier, interpolated into the error messages below.
        self.name = name
        super().__init__()

    # --- required interface -------------------------------------------------

    @property
    @abstractmethod
    def vocab_size(self):
        """Number of entries in the vocabulary."""

    @property
    @abstractmethod
    def vocab(self):
        """Dictionary from vocab text token to id token."""

    @property
    @abstractmethod
    def inv_vocab(self):
        """Dictionary from vocab id token to text token."""

    @abstractmethod
    def tokenize(self, text):
        """Convert ``text`` into a sequence of token ids."""

    # --- optional interface (unsupported unless overridden) -----------------

    def detokenize(self, token_ids):
        """Convert token ids back into text."""
        raise NotImplementedError(
            f"detokenizer is not implemented for {self.name} tokenizer"
        )

    @property
    def cls(self):
        """Id of the CLS token."""
        raise NotImplementedError(f"CLS is not provided for {self.name} tokenizer")

    @property
    def sep(self):
        """Id of the SEP token."""
        raise NotImplementedError(f"SEP is not provided for {self.name} tokenizer")

    @property
    def pad(self):
        """Id of the padding token."""
        raise NotImplementedError(f"PAD is not provided for {self.name} tokenizer")

    @property
    def eod(self):
        """Id of the end-of-document token."""
        raise NotImplementedError(f"EOD is not provided for {self.name} tokenizer")

    @property
    def mask(self):
        """Id of the MASK token."""
        raise NotImplementedError(f"MASK is not provided for {self.name} tokenizer")
class _GPT2BPETokenizer(AbstractTokenizer):
    """Original GPT2 BPE tokenizer, built from a vocab file and a merges file.

    Wraps the module-level ``GPT2Tokenizer``. Because of the import order at
    the top of this module, that name refers to the class from
    ``.gpt2_tokenization``, not the transformers class.
    """

    def __init__(self, vocab_file, merge_file):
        super().__init__("GPT2 BPE")
        self.tokenizer = GPT2Tokenizer(
            vocab_file,
            merge_file,
            errors="replace",
            special_tokens=[],
            max_len=None,
        )
        # Direct dict lookup: a vocabulary without "<|endoftext|>" raises KeyError here.
        self.eod_id = self.tokenizer.encoder["<|endoftext|>"]

    @property
    def vocab_size(self):
        """Number of entries in the token -> id encoder."""
        return len(self.vocab)

    @property
    def vocab(self):
        """The wrapped tokenizer's token -> id mapping (not a copy)."""
        return self.tokenizer.encoder

    @property
    def inv_vocab(self):
        """The wrapped tokenizer's id -> token mapping (not a copy)."""
        return self.tokenizer.decoder

    def tokenize(self, text):
        """Encode ``text`` with the wrapped BPE tokenizer."""
        return self.tokenizer.encode(text)

    def detokenize(self, token_ids):
        """Decode ``token_ids`` with the wrapped BPE tokenizer."""
        return self.tokenizer.decode(token_ids)

    @property
    def eod(self):
        """Id of "<|endoftext|>"."""
        return self.eod_id
class SentencePieceTokenizer(AbstractTokenizer):
    """Adapter exposing a SentencePiece model through ``AbstractTokenizer``."""

    def __init__(self, vocab_file):
        super().__init__("SPM")
        # ``vocab_file`` is handed to SentencePiece as its model file.
        self.tokenizer = spm.SentencePieceProcessor(model_file=vocab_file)
        # NOTE(review): if the model has no "<|endoftext|>" piece this is
        # presumably SentencePiece's unknown-piece id rather than an error; confirm.
        self.eod_id = self.tokenizer.piece_to_id("<|endoftext|>")

    @property
    def vocab_size(self):
        """Number of pieces in the loaded model."""
        return self.tokenizer.get_piece_size()

    @property
    def vocab(self):
        """A freshly built ``{piece: id}`` dict (rebuilt on every access)."""
        sp = self.tokenizer
        return dict((sp.id_to_piece(i), i) for i in range(sp.get_piece_size()))

    @property
    def inv_vocab(self):
        """A freshly built ``{id: piece}`` dict (rebuilt on every access)."""
        sp = self.tokenizer
        ids = range(sp.get_piece_size())
        return dict(zip(ids, map(sp.id_to_piece, ids)))

    def tokenize(self, text):
        """Encode ``text`` with the SentencePiece model."""
        return self.tokenizer.encode(text)

    def detokenize(self, token_ids):
        """Decode ``token_ids`` with the SentencePiece model."""
        return self.tokenizer.decode(token_ids)

    @property
    def eod(self):
        """Id looked up for "<|endoftext|>" at construction time."""
        return self.eod_id
class HFTokenizer(AbstractTokenizer):
    """Designed to Integrate HF's Tokenizer library.

    Loads a serialized ``tokenizers.Tokenizer`` from ``vocab_file`` via
    ``Tokenizer.from_file``.
    """

    def __init__(self, vocab_file):
        name = "HFTokenizer"
        super().__init__(name)
        self.tokenizer = Tokenizer.from_file(vocab_file)
        # NOTE(review): token_to_id presumably returns None when the token is
        # absent from the vocabulary; confirm callers tolerate a None id.
        self.eod_id = self.tokenizer.token_to_id("<|endoftext|>")
        self.pad_id = self.tokenizer.token_to_id("<|padding|>")

    @property
    def vocab_size(self):
        """Vocabulary size as reported by the wrapped tokenizer."""
        return self.tokenizer.get_vocab_size()

    @property
    def vocab(self):
        """``{token: id}`` mapping from the wrapped tokenizer."""
        return self.tokenizer.get_vocab()

    @property
    def inv_vocab(self):
        """``{id: token}`` mapping, built by inverting ``vocab``."""
        # ``Tokenizer.decoder`` is the decoding pipeline component, not an
        # id -> token mapping, so derive the inverse from the vocabulary.
        return {idx: token for token, idx in self.tokenizer.get_vocab().items()}

    def tokenize(self, text: str):
        """Encode ``text`` and return its list of token ids."""
        return self.tokenizer.encode(text).ids

    def tokenize_batch(self, text_batch: Union[List[str], str]):
        """Encode a batch of texts.

        A single string is treated as a batch of one. Returns the wrapped
        tokenizer's per-text encoding objects (not bare id lists).
        """
        # encode_batch expects a list of inputs; wrap a lone string so the
        # advertised ``str`` input works.
        if isinstance(text_batch, str):
            text_batch = [text_batch]
        return self.tokenizer.encode_batch(text_batch)

    def detokenize(self, token_ids):
        """Decode ``token_ids`` back to text."""
        return self.tokenizer.decode(token_ids)

    @property
    def eod(self):
        """Id of "<|endoftext|>" (looked up at construction time)."""
        return self.eod_id
class HFGPT2Tokenizer(AbstractTokenizer):
    """Designed to Integrate the pretrained OpenAI GPT2 Tokenizers from HF.

    Args:
        vocab_file: name or path passed to ``from_pretrained``; ``None`` falls
            back to ``"gpt2"``.
        fast: when True, use ``GPT2TokenizerFast``; otherwise use the
            ``transformers`` slow ``GPT2Tokenizer``.

    A ``"<|padding|>"`` pad token is registered on the wrapped tokenizer, so
    ``vocab_size`` (``len(tokenizer)``) includes it.
    """

    def __init__(self, vocab_file=None, fast=True):
        name = "HFGPT2Tokenizer"
        if fast:
            name += "Fast"
        super().__init__(name)
        if vocab_file is None:
            vocab_file = "gpt2"
        if fast:
            self.tokenizer = GPT2TokenizerFast.from_pretrained(vocab_file)
        else:
            # The module-level ``GPT2Tokenizer`` is rebound by the later
            # ``from .gpt2_tokenization import GPT2Tokenizer`` import, so it is
            # not the transformers class. The code below relies on the
            # transformers API (add_special_tokens, eos_token_id, pad_token_id),
            # so import that class explicitly.
            from transformers import GPT2Tokenizer as _TransformersGPT2Tokenizer

            self.tokenizer = _TransformersGPT2Tokenizer.from_pretrained(vocab_file)
        self.tokenizer.add_special_tokens({"pad_token": "<|padding|>"})
        self.eod_id = self.tokenizer.eos_token_id
        self.pad_id = self.tokenizer.pad_token_id

    @property
    def vocab_size(self):
        """Number of tokens, including added special tokens."""
        return len(self.tokenizer)

    @property
    def vocab(self):
        """``{token: id}`` mapping from the wrapped tokenizer."""
        return self.tokenizer.get_vocab()

    @property
    def inv_vocab(self):
        """``{id: token}`` mapping, built by inverting ``vocab``."""
        # The previous ``_tokenizer.decoder`` was a decoder component on the
        # fast tokenizer (and absent on the slow one), not a mapping.
        return {idx: token for token, idx in self.tokenizer.get_vocab().items()}

    def tokenize(self, text: str):
        """Encode ``text`` into token ids."""
        return self.tokenizer.encode(text)

    def tokenize_batch(self, text_batch: Union[List[str], str]):
        """Encode each text separately; a single string becomes a batch of one."""
        if isinstance(text_batch, str):
            text_batch = [text_batch]
        return [self.tokenize(t) for t in text_batch]

    def detokenize(self, token_ids):
        """Decode ``token_ids`` back to text."""
        return self.tokenizer.decode(token_ids)

    @property
    def eod(self):
        """The wrapped tokenizer's ``eos_token_id``."""
        return self.eod_id
class CharLevelTokenizer(AbstractTokenizer):
    """Character Level Tokenizer.

    ``tokenize`` emits one ``np.uint8`` id per UTF-8 *byte* of the input.
    ``detokenize`` maps each id back to one character via ``chr`` after
    clamping. Non-ASCII characters therefore do not round-trip: their
    multi-byte encodings are decoded byte by byte.
    """

    def __init__(self, vocab_size):
        name = "CharLevelTokenizer"
        super().__init__(name)
        self._vocab_size = vocab_size
        self.eod_id = 0
        self.pad_id = 1

    def clamp(self, n):
        """Clamp ``n`` to at most ``vocab_size`` and at least 32.

        Ids below 32 (including the eod/pad ids 0 and 1) decode as ``chr(32)``,
        a space.
        """
        return max(32, min(n, self.vocab_size))

    @property
    def vocab_size(self):
        """The vocabulary size given at construction."""
        return self._vocab_size

    @property
    def vocab(self):
        raise NotImplementedError

    @property
    def inv_vocab(self):
        raise NotImplementedError

    def decode_token(self, token: int):
        """Map a single id to a one-character string after clamping."""
        return str(chr(self.clamp(token)))

    def tokenize(self, text: str):
        """Return the UTF-8 bytes of ``text`` as a list of ``np.uint8`` ids."""
        # Same result as the deprecated binary mode of
        # ``np.fromstring(text, dtype=np.uint8)``, which UTF-8-encodes a str
        # argument. frombuffer is the documented replacement.
        return list(np.frombuffer(text.encode("utf-8"), dtype=np.uint8))

    def tokenize_batch(self, text_batch: Union[List[str], str]):
        """Tokenize each string of a list, or a single string directly."""
        if isinstance(text_batch, list):
            return [self.tokenize(s) for s in text_batch]
        else:
            return self.tokenize(text_batch)

    def detokenize(self, token_ids):
        """Join the clamped characters of ``token_ids`` into a string."""
        return "".join(list(map(self.decode_token, token_ids)))

    @property
    def eod(self):
        """End-of-document id (0)."""
        return self.eod_id
class TiktokenTokenizer(AbstractTokenizer):
    """Tokenizer from OpenAI's tiktoken implementation.

    Args:
        vocab_file: despite the name, this is passed to
            ``tiktoken.get_encoding`` and so names a tiktoken encoding.

    Raises:
        ModuleNotFoundError: if ``tiktoken`` is not installed.
    """

    def __init__(self, vocab_file):
        try:
            import tiktoken
        except ModuleNotFoundError as err:
            print("Please install tiktoken: (https://github.com/openai/tiktoken)")
            # Previously a bare ``Exception`` with no message or cause.
            # ModuleNotFoundError still subclasses Exception, so existing
            # ``except Exception`` handlers keep working.
            raise ModuleNotFoundError(
                "TiktokenTokenizer requires the 'tiktoken' package"
            ) from err
        name = "TiktokenTokenizer"
        super().__init__(name)
        self.tokenizer = tiktoken.get_encoding(vocab_file)
        self.eod_id = self.tokenizer.eot_token
        self.pad_id = None

    @property
    def vocab_size(self):
        """``n_vocab`` of the loaded encoding."""
        return self.tokenizer.n_vocab

    @property
    def vocab(self):
        raise NotImplementedError(
            "TiktokenTokenizer does not implement vocabulary access."
        )

    @property
    def inv_vocab(self):
        # Implicit concatenation instead of a backslash continuation, which
        # embedded the source indentation into the message.
        raise NotImplementedError(
            "TiktokenTokenizer does not implement vocabulary access. "
            "To get the idx-th token in vocabulary, use tokenizer.decode([idx]) ."
        )

    def tokenize(self, text: str):
        """Encode a single text."""
        # NOTE(review): unlike tokenize_batch, no ``allowed_special`` is passed
        # here, so special-token strings are handled with tiktoken's defaults.
        # Confirm whether the two methods are meant to differ.
        return self.tokenizer.encode(text)

    def tokenize_batch(self, text_batch: List[str]):
        """Encode a list of texts, allowing all special tokens."""
        return self.tokenizer.encode_batch(text_batch, allowed_special="all")

    def detokenize(self, token_ids):
        """Decode ``token_ids``, using ``errors="strict"`` for byte decoding."""
        return self.tokenizer.decode(tokens=token_ids, errors="strict")

    @property
    def eod(self):
        """The encoding's ``eot_token`` id."""
        return self.eod_id

    @property
    def pad(self):
        raise NotImplementedError
|