# tokenizer-arena / tokenizer/tiktoken_patch.py
# Author: xu-song — commit adcfb97: "fix tiktoken special tokens" (2.15 kB)
from tiktoken import Encoding
from utils.log_util import logger
def decode(self, tokens, errors="replace", skip_special_tokens=False):
    """Decode token ids to text without raising on bad input.

    Replacement for ``Encoding.decode``. The stock decode can fail on some
    inputs (see decode_test.py), so failures are logged and reported as the
    string ``"null"`` instead of propagating.

    :param tokens: sequence of token ids.
    :param errors: ``bytes.decode`` error handler for the UTF-8 bytes:
        ``"replace"`` substitutes U+FFFD for invalid bytes, ``"ignore"`` drops
        them, ``"strict"`` raises UnicodeDecodeError (turned into ``"null"``
        here).
    :param skip_special_tokens: accepted only for compatibility with the
        HuggingFace tokenizer API; currently ignored.
    :return: the decoded string, or ``"null"`` if decoding failed.
    """
    try:
        token_bytes = self._core_bpe.decode_bytes(tokens)
        return token_bytes.decode("utf-8", errors=errors)
    except (KeyboardInterrupt, SystemExit):
        # Never swallow interpreter-control exceptions.
        raise
    except BaseException as e:
        # Deliberately broader than Exception: failures from tiktoken's Rust
        # core may surface as pyo3's PanicException, which derives from
        # BaseException.
        logger.error(e)
        return "null"
def convert_ids_to_tokens(self, tokens, skip_special_tokens=False):
    """Map token ids to their raw token bytes (HuggingFace-style helper).

    tiktoken's Encoding has no ``convert_ids_to_tokens``; this adds one.
    Tokens are returned as ``bytes``, not ``str``. An id that cannot be
    resolved becomes ``None`` in its own slot instead of raising, and the
    other ids are still converted. The ``None`` marker is what zh_util.py
    expects. Per the original notes there are 16 unused ids, 100256 and
    100261-100275 (presumably in cl100k_base — confirm).

    :param tokens: iterable of token ids.
    :param skip_special_tokens: accepted for HF API compatibility; ignored.
    :return: list with one entry per id: the token's bytes, or ``None``.
    """
    # Materialize once: the input may be iterated twice below, and an
    # iterator would be exhausted by the first attempt.
    tokens = list(tokens)
    try:
        # Fast path: every id is valid.
        return self.decode_tokens_bytes(tokens)
    except Exception as e:
        logger.error(e)

    # Slow path: resolve ids one at a time so that a single unknown id only
    # blanks its own slot rather than the whole result.
    token_bytes = []
    for token in tokens:
        try:
            token_bytes.append(self.decode_single_token_bytes(token))
        except Exception:
            token_bytes.append(None)
    return token_bytes
def get_vocab(self, token_type="str"):
    """Return the vocabulary as a ``{token_bytes: token_id}`` dict.

    Every id in ``range(self.vocab_size)`` is resolved through
    ``self.convert_ids_to_tokens``. Ids it cannot resolve (it returns ``None``
    for them) are skipped, and both their count and their values are logged.

    :param token_type: intended to select ``"str"`` or ``"byte"`` keys, but
        currently ignored: keys are always the raw token ``bytes``.
    :return: dict mapping token bytes to token id.
    """
    vocab = {}
    # Ids with no token. Labelled "KeyError" in the log because an unknown id
    # lookup is presumably what fails inside convert_ids_to_tokens — verify
    # against the installed tiktoken version.
    key_error_list = []
    for token_id in range(self.vocab_size):
        token_byte = self.convert_ids_to_tokens([token_id])[0]
        if token_byte is None:
            key_error_list.append(token_id)
            continue
        vocab[token_byte] = token_id
    logger.info(f"{self.name} {len(key_error_list)} KeyError: {key_error_list}")
    return vocab
def encode(self, *args, **kwargs):
    """Encode text with every special token allowed (HF-compatible wrapper).

    Delegates to the original ``Encoding.encode``, which is saved as
    ``self._encode`` when the patch is installed, after adjusting keywords:

    * ``add_special_tokens`` is an HF-only keyword; it is accepted and dropped.
    * ``allowed_special`` is always forced to ``"all"``, so special-token text
      is encoded as special tokens. Any value the caller passed is overridden.
    """
    forwarded = dict(kwargs)
    forwarded.pop("add_special_tokens", None)
    forwarded.update(allowed_special="all")
    return self._encode(*args, **forwarded)
# tiktoken patch: install the HF-compatible helpers defined above on
# tiktoken's Encoding class.
#
# The patched encode() delegates to Encoding._encode, which must be tiktoken's
# ORIGINAL encode. The sentinel guard matters if this module is executed more
# than once (e.g. importlib.reload, or an import under two module names).
# Without it, a second run would store the already-patched encode as _encode,
# and encode() would then call itself forever.
if not getattr(Encoding, "_tiktoken_patch_applied", False):
    Encoding._encode = Encoding.encode
    Encoding._tiktoken_patch_applied = True
Encoding.encode = encode
Encoding.decode = decode
Encoding.convert_ids_to_tokens = convert_ids_to_tokens
Encoding.get_vocab = get_vocab