tokenizer-arena / character_util.py
xu-song's picture
update
7c73423
raw
history blame
7.38 kB
"""
TODO:
1. add more language
2. check space count of bert
3. add token_impl
4.
"""
import os
import json
import numpy as np
import pandas as pd
from collections import Counter, defaultdict
from vocab import tokenizer_factory
from typing import Optional, Union, Literal
from utils.log_util import logger
from utils.text_util import contains_digit, get_space_count
from utils.lang_util import detect_language_by_unicode, language_ranges
# Absolute directory of this module; iter_vocab resolves its cache directory against it.
CURRENT_DIR = os.path.split(os.path.abspath(__file__))[0]
# Column substrings get_character_table shows when the caller passes columns=None.
default_columns = ["digit", "zh"]
def _to_unicode(text):
return ''.join(r'\u{:04X}'.format(ord(chr)) for chr in text)
def _get_coding_length(tokenizer, vocab, filter=None):
"""
oov character may be tokenized into more than one token.
"""
all_length = []
for word in vocab:
if len(word) > 1:
continue
if filter is not None and filter(word):
continue
try:
tokens = tokenizer.encode(word)
except Exception as e:
print(e)
all_length.append(len(tokens))
# if len(tokens.ids) > 1:
# if len(tokens) > 3:
# print(word, tokens)
dist_length = Counter(all_length)
mean_length = round(sum(all_length) / len(all_length), 2)
return dist_length, mean_length
# Process-wide memo of iter_vocab results keyed by tokenizer name; iter_vocab
# seeds it from, and writes it back to, <cache_dir>/character_stats.json.
cache = dict()
def _dist(token_lens):
"""
:param token_lens:
:return: min,median,max of token_lens
"""
if not token_lens:
return "-"
return f"{min(token_lens)},{round(np.median(token_lens))},{max(token_lens)}"
def iter_vocab(
        tokenizer_name: str,
        from_cache: bool = True,
        cache_dir: str = "stats",
) -> dict:
    """Collect per-category token statistics for one tokenizer's vocabulary.

    Every id in ``range(tokenizer.vocab_size)`` is converted and decoded, so
    tokens added beyond ``vocab_size`` are not scanned. Decoded tokens are
    grouped by the languages ``detect_language_by_unicode`` reports, and
    counted when they contain a digit or at least one space. A JSONL dump of
    every visited token is written to
    ``<cache_dir>/iter_vocab/<tokenizer_name with '/' replaced by '_'>.vocab.jsonl``.

    Results are memoized in the module-level ``cache`` dict and persisted to
    ``<cache_dir>/character_stats.json``. That file is read only when the
    in-memory cache is empty.

    :param tokenizer_name: name understood by ``tokenizer_factory``.
    :param from_cache: when True, return a previously cached result if present.
    :param cache_dir: cache directory, relative to this module's directory.
    :return: dict holding the tokenizer link, organization and ``vocab_size``,
        plus ``num(...)``/``len(...)`` entries for digits, spaces and each
        language. The ``len(...)`` values are "min,median,max" strings from
        ``_dist``.
    """
    tokenizer_config = tokenizer_factory.get_tokenizer_config(tokenizer_name)
    cache_dir = os.path.join(CURRENT_DIR, cache_dir)
    os.makedirs(cache_dir, exist_ok=True)

    # Load the on-disk cache lazily, only while the in-memory cache is empty.
    cache_path = os.path.join(cache_dir, "character_stats.json")
    if not cache and os.path.exists(cache_path):
        with open(cache_path, "r", encoding="utf-8") as f_tmp:
            cache.update(json.load(f_tmp))
    if from_cache and tokenizer_name in cache:
        return cache[tokenizer_name]

    tokenizer = tokenizer_factory.get_tokenizer(tokenizer_name)
    tokens_by_lang = {lang[1]: [] for lang in language_ranges.keys()}
    digit_tokens = []
    space_tokens = []
    byte_tokens = []  # collected for a possible "num(byte)" column; not reported yet
    buffer = []
    for token_id in range(tokenizer.vocab_size):
        token = tokenizer.convert_ids_to_tokens([token_id], skip_special_tokens=False)[0]
        # Some vocabularies have holes (non-contiguous ids). Check before
        # decoding so an unused id is skipped without being decoded.
        if token is None:
            continue
        decode_str = tokenizer.decode([token_id], skip_special_tokens=False)
        if isinstance(token, bytes):
            token = token.decode("utf-8", errors="ignore")
        # sentencepiece-based tokenizers expose sp_model, which can flag byte pieces.
        if hasattr(tokenizer, "sp_model") and tokenizer.sp_model.is_byte(token_id):
            byte_tokens.append(token)
        for language in detect_language_by_unicode(decode_str):
            tokens_by_lang[language[1]].append(decode_str)
        if contains_digit(decode_str):
            digit_tokens.append(decode_str)
        if get_space_count(decode_str) > 0:
            space_tokens.append(decode_str)
        buffer.append(json.dumps(
            {
                "id": token_id,
                "token": token,
                "token_decode": decode_str,
                "token_dumps": json.dumps(token),
                "token_unicode": _to_unicode(token),
                "token_len": len(decode_str),
            },
            ensure_ascii=False) + "\n")

    # The mean/distribution of Chinese-character encoding length is deliberately
    # not reported: a vocabulary with many Chinese characters generally implies
    # short encodings for them.
    result = {
        "tokenizer": tokenizer_factory.get_name_with_hyperlink(tokenizer_name),
        "organization": tokenizer_config.org,
        # len(tokenizer) includes added tokens; tokenizer.vocab_size does not.
        "vocab_size": len(tokenizer),
        "num(digit)": len(digit_tokens),
        "len(digit)": _dist([len(token) for token in digit_tokens]),
        "num(space)": len(space_tokens),
        "len(space)": _dist([len(token) for token in space_tokens]),
    }
    for lang, tokens in tokens_by_lang.items():
        result[f"num({lang})"] = len(tokens)
        result["len(" + lang + ")"] = _dist([len(token) for token in tokens])

    out_path = os.path.join(cache_dir, "iter_vocab", f"{tokenizer_name.replace('/', '_')}.vocab.jsonl")
    # Only cache_dir was created above; the iter_vocab/ subdirectory must exist too.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as f_out:
        f_out.writelines(buffer)

    len_before = len(cache)
    cache[tokenizer_name] = result
    len_after = len(cache)
    logger.info(f"saving {tokenizer_name} to memory and file cache: {len_before}->{len_after}")
    with open(cache_path, "w", encoding="utf-8") as f_out:
        json.dump(cache, f_out, ensure_ascii=False, indent=2)
    return result
def to_dataframe(stats, columns):
    """Flatten per-tokenizer stats into a DataFrame, keeping selected metrics.

    Keys that do not start with "num" or "len" (tokenizer, organization,
    vocab_size, ...) are always kept. In addition, any key containing one of
    ``columns`` as a substring is kept, with "ja-kana" shortened to "kana" in
    its name.

    :param stats: mapping of tokenizer name to the dict produced by iter_vocab.
    :param columns: iterable of substrings selecting which metric keys to show.
    :return: a DataFrame with one row per tokenizer.
    """
    rows = []
    for stat in stats.values():
        row = {}
        for key, value in stat.items():
            if not key.startswith(("num", "len")):
                row[key] = value
            if any(col in key for col in columns):
                row[key.replace("ja-kana", "kana")] = value
        rows.append(row)
    return pd.DataFrame(rows)
def get_character_table(
        tokenizer_filter: Optional[str] = None,
        columns: Optional[list] = None,
        return_type: Optional[Literal["dict", "dataframe"]] = "dataframe"
) -> Union[pd.DataFrame, dict]:
    """Collect iter_vocab statistics for many tokenizers.

    :param tokenizer_filter: case-insensitive substring matched against each
        tokenizer config's ``name_or_path``. None selects every name in
        ``tokenizer_factory.all_tokenizer_names``.
    :param columns: substrings choosing which ``num(...)``/``len(...)`` metrics
        appear in the DataFrame. None means ``default_columns``.
    :param return_type: "dataframe" returns the table built by to_dataframe;
        any other value returns the raw ``{tokenizer_name: stats}`` dict.
    :return: a pandas DataFrame or a dict, depending on ``return_type``.
    """
    logger.info(f"columns: {columns}, tokenizer_filter: {tokenizer_filter}")
    if columns is None:
        columns = default_columns

    if tokenizer_filter is None:
        tokenizer_names = tokenizer_factory.all_tokenizer_names
    else:
        tokenizer_names = [
            config.name_or_path
            for config in tokenizer_factory.all_tokenizer_configs
            if tokenizer_filter.lower() in config.name_or_path.lower()
        ]

    stats = {name: iter_vocab(name) for name in tokenizer_names}
    if return_type != "dataframe":
        return stats
    return to_dataframe(stats, columns)
if __name__ == "__main__":
    # To restrict the table to one model family, pass a filter, e.g.
    # get_character_table(tokenizer_filter="baichuan").
    table = get_character_table()
    logger.info("\n" + table.to_markdown(index=False))