Spaces:
Running
Running
from tiktoken import Encoding | |
from utils.log_util import logger | |
def decode(self, tokens, errors="replace", skip_special_tokens=False):
    """Decode token ids to text without letting the tokenizer core crash the caller.

    The stock tiktoken decode can raise on some inputs (see decode_test.py).
    This replacement logs the failure and returns the string ``"null"`` instead.

    :param tokens: sequence of token ids.
    :param errors: error handler passed to ``bytes.decode("utf-8", ...)``.
        The decoded bytes are not guaranteed to be valid UTF-8. Handlers:
        "strict" raises UnicodeError (caught here, so the result is "null");
        "ignore" drops invalid bytes; "replace" inserts U+FFFD;
        "backslashreplace" inserts a backslashed escape sequence.
        "xmlcharrefreplace" and "namereplace" are encode-only handlers, so
        with invalid bytes they fail and the result is "null".
    :param skip_special_tokens: accepted only for compatibility with Hugging
        Face tokenizers; ignored.
    :return: the decoded text, or ``"null"`` if decoding failed.
    """
    try:
        decode_str = self._core_bpe.decode_bytes(tokens).decode("utf-8", errors=errors)
    except (KeyboardInterrupt, SystemExit, GeneratorExit):
        # Never swallow Ctrl-C / interpreter shutdown / generator close.
        raise
    except Exception as e:
        logger.error(f"{e} for {tokens} -> return 'null'")
        decode_str = "null"
    except BaseException:
        # PyO3's PanicException (raised when the Rust core panics) derives from
        # BaseException, so ``except Exception`` above does not catch it.
        logger.error(f"unknown exception for {tokens} -> return 'null'")
        decode_str = "null"
    return decode_str
def convert_ids_to_tokens(self, tokens, skip_special_tokens=False):
    """Map token ids to their raw byte tokens, mirroring the HF tokenizer method.

    tiktoken's Encoding has no ``convert_ids_to_tokens``, so it is provided here
    on top of ``self.decode_tokens_bytes``.

    :param tokens: sequence of token ids.
    :param skip_special_tokens: accepted only for compatibility with Hugging
        Face tokenizers; ignored.
    :return: list of ``bytes`` tokens, one per id. If decoding fails for any id,
        the whole result is a list of ``None`` with the same length as ``tokens``.
    """
    try:
        return self.decode_tokens_bytes(tokens)
    except (KeyboardInterrupt, SystemExit, GeneratorExit):
        # Never swallow Ctrl-C / interpreter shutdown / generator close.
        raise
    except Exception as e:
        # Why return None? See zh_util.py.
        # There are 16 unused ids: 100256 and 100261-100275 (presumably the
        # gaps in cl100k_base — verify).
        logger.error(f"{e} for {tokens} -> return None")
        return [None for _ in tokens]
    except BaseException:
        # PyO3's PanicException (raised when the Rust core panics) derives from
        # BaseException, so ``except Exception`` above does not catch it.
        logger.error(f"unknown exception for {tokens} -> return None")
        return [None for _ in tokens]
def get_vocab(self, token_type="str"):
    """Return the vocabulary as a ``{token_bytes: token_id}`` dict.

    Ids ``0 .. vocab_size - 1`` are converted one at a time, so a single
    undecodable id cannot invalidate its neighbours. Ids for which
    ``convert_ids_to_tokens`` yields ``None`` are skipped. Added/special tokens
    are not merged in.

    :param token_type: documented as ``"str"`` or ``"byte"`` but currently
        ignored; keys are always the raw ``bytes`` tokens.
    :return: dict mapping token bytes to token id.
    """
    vocab = {}
    # Only reported in the summary log lines below. Nothing here populates them,
    # because convert_ids_to_tokens already swallows decode errors.
    key_error_list = []
    unicode_decode_error_list = []
    # vocab_size is patched onto Encoding as a plain method, so it must be
    # called; also accept a plain int in case it becomes a property.
    size = self.vocab_size
    if callable(size):
        size = size()
    for i in range(size):
        token_byte = self.convert_ids_to_tokens([i])[0]
        if token_byte is None:
            continue
        vocab[token_byte] = i
    logger.info(f"{self.name} {len(key_error_list)} KeyError: {key_error_list}")
    logger.info(f"{self.name} {len(unicode_decode_error_list)} UnicodeDecodeError: {unicode_decode_error_list[:5]}")
    return vocab
def vocab_size(self):
    """Size of the BPE vocabulary, special tokens not counted.

    Computed as the number of entries in ``self._mergeable_ranks``.
    """
    mergeable_ranks = self._mergeable_ranks
    return len(mergeable_ranks)
def encode(self, *args, **kwargs):
    """Encode text with Hugging Face-compatible keyword handling.

    Wraps tiktoken's original ``Encoding.encode``, which the patch saves as
    ``self._encode``.

    * ``add_special_tokens`` is accepted for HF compatibility and discarded.
    * ``allowed_special`` defaults to ``"all"``, so special-token text in the
      input is encoded as special tokens. tiktoken's own default rejects such
      text. An ``allowed_special`` given explicitly by the caller is respected
      instead of being silently overwritten.
    """
    kwargs.pop("add_special_tokens", None)
    kwargs.setdefault("allowed_special", "all")
    return self._encode(*args, **kwargs)
def __len__(self):
    """Support ``len(encoding)`` by reporting ``self.n_vocab``."""
    total = self.n_vocab
    return total
# tiktoken patch: graft a Hugging Face-style API onto tiktoken.Encoding.
# Save tiktoken's native encode under ``_encode`` exactly once. Without the
# guard, a module reload would alias ``_encode`` to the already-patched
# wrapper, and ``encode`` would then call itself forever.
if not getattr(Encoding, "_hf_patch_applied", False):
    Encoding._encode = Encoding.encode
    Encoding._hf_patch_applied = True
Encoding.encode = encode
Encoding.decode = decode
Encoding.convert_ids_to_tokens = convert_ids_to_tokens
Encoding.get_vocab = get_vocab
Encoding.vocab_size = vocab_size  # plain method: call as ``enc.vocab_size()``
Encoding.__len__ = __len__