from tiktoken import Encoding

from utils.log_util import logger

def decode(self, tokens, errors="replace", skip_special_tokens=False):
    """
    The default tiktoken decode may raise; see decode_test.py for details.

    skip_special_tokens is accepted (and ignored) for hf_tokenizer compatibility.

    errors:
        decoded bytes are not guaranteed to be valid UTF-8.
        "strict"            Raise UnicodeError
        "ignore"            Ignore and continue
        "replace"           Replace with replacement character
        "backslashreplace"  Replace with backslashed escape sequence
        "xmlcharrefreplace" Replace with XML character reference (encoding only)
        "namereplace"       Replace with \\N{...} escape sequence (encoding only)
    """
    try:
        decode_str = self._core_bpe.decode_bytes(tokens).decode("utf-8", errors=errors)
    except Exception as e:
        logger.error(f"{e} for {tokens} -> return 'null'")
        decode_str = "null"
    except:  # noqa: E722  -- e.g. a pyo3 PanicException from the Rust core, which is not an Exception subclass
        logger.error(f"unknown exception for {tokens} -> return 'null'")
        decode_str = "null"
    return decode_str
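# A hedged sketch of the patched behavior (assumes tiktoken's cl100k_base
# files are available; the slice is illustrative, since a single token may
# cover only part of a multi-byte UTF-8 character):
#
#   enc = tiktoken.get_encoding("cl100k_base")
#   part = enc.encode("你好")[:1]
#   enc.decode(part)                    # may yield "\ufffd" via errors="replace"
#   enc.decode(part, errors="strict")   # failure is logged, "null" is returned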
def convert_ids_to_tokens(self, tokens, skip_special_tokens=False):
    """
    Why doesn't tiktoken provide this method natively? Added here for
    hf_tokenizer compatibility; each id is mapped to its raw token bytes.
    """
    try:
        return self.decode_tokens_bytes(tokens)
    except Exception as e:
        logger.error(f"{e} for {tokens} -> return None")
        return [None for _ in tokens]
    except:  # noqa: E722
        logger.error(f"unknown exception for {tokens} -> return None")
        return [None for _ in tokens]
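# Hedged example (ids are illustrative for cl100k_base):
#
#   enc.convert_ids_to_tokens([15339, 1917])   # e.g. [b"hello", b" world"]
#   enc.convert_ids_to_tokens([10**9])         # out of range: logged, [None]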
def get_vocab(self, token_type="str"):
    """Returns the vocab as a dict of {token: id}

    :param token_type: one of ["str", "byte"]; with "str", tokens that are not
        valid UTF-8 keep their bytes form, so keys may be mixed
    :return: dict
    """
    vocab = {}
    key_error_list = []  # currently never populated; kept for the summary log below
    unicode_decode_error_list = []
    for i in range(self.vocab_size):
        token_byte = self.convert_ids_to_tokens([i])[0]
        if token_byte is None:
            continue
        if token_type == "str":
            try:
                vocab[token_byte.decode("utf-8")] = i
            except UnicodeDecodeError:  # not every token is valid UTF-8
                unicode_decode_error_list.append((i, str(token_byte)))
                vocab[token_byte] = i
        else:
            vocab[token_byte] = i

    logger.info(f"{self.name} {len(key_error_list)} KeyError: {key_error_list}")
    logger.info(f"{self.name} {len(unicode_decode_error_list)} UnicodeDecodeError: {unicode_decode_error_list[:5]}")
    return vocab
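# Hedged sketch: building the vocab walks every id, so this is O(vocab_size)
# and best cached by the caller:
#
#   vocab = enc.get_vocab(token_type="byte")   # e.g. {b"hello": 15339, ...}
#   assert len(vocab) <= enc.vocab_size        # undecodable ids are skipped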
@property
def vocab_size(self):
    """Returns vocab size"""
    return self.n_vocab
def encode(self, *args, **kwargs):
    """
    add_special_tokens is accepted (and dropped) for hf_tokenizer compatibility.
    """
    kwargs.pop("add_special_tokens", None)
    kwargs["allowed_special"] = "all"  # never raise on special tokens in the input
    return self._encode(*args, **kwargs)
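# Hedged example: plain tiktoken raises ValueError when special tokens appear
# in the input; the patched encode permits them (the id shown is illustrative
# for cl100k_base):
#
#   enc.encode("<|endoftext|>")   # e.g. [100257] instead of ValueError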
def __len__(self):
    return self.n_vocab
# Monkey-patch tiktoken.Encoding in place: keep the original encode reachable
# as _encode, then graft on the hf_tokenizer-style interface defined above.
Encoding._encode = Encoding.encode
Encoding.encode = encode
Encoding.decode = decode
Encoding.convert_ids_to_tokens = convert_ids_to_tokens
Encoding.get_vocab = get_vocab
Encoding.vocab_size = vocab_size
Encoding.__len__ = __len__
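# A guarded smoke test for the patched interface (hedged: requires the
# cl100k_base BPE files to be available to tiktoken):
if __name__ == "__main__":
    import tiktoken

    enc = tiktoken.get_encoding("cl100k_base")
    ids = enc.encode("hello world <|endoftext|>")  # special token is allowed
    print(ids)
    print(enc.decode(ids))
    print(enc.convert_ids_to_tokens(ids))
    print(enc.vocab_size, len(enc))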