Spaces:
Running
Running
""" | |
TODO: | |
1. add more language | |
2. check space count of bert | |
3. add token_impl | |
4. | |
""" | |
import os | |
import json | |
import numpy as np | |
import pandas as pd | |
from collections import Counter, defaultdict | |
from vocab import tokenizer_factory | |
from typing import Optional, Union, Literal | |
from utils.log_util import logger | |
from utils.text_util import contains_digit, get_space_count | |
from utils.lang_util import detect_language_by_unicode, language_ranges | |
# Absolute directory of this module; relative cache directories are resolved against it.
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
# Column keywords used by get_character_table() when no `columns` argument is given.
default_columns = ["digit", "zh"]
def _to_unicode(text): | |
return ''.join(r'\u{:04X}'.format(ord(chr)) for chr in text) | |
def _get_coding_length(tokenizer, vocab, filter=None): | |
""" | |
oov character may be tokenized into more than one token. | |
""" | |
all_length = [] | |
for word in vocab: | |
if len(word) > 1: | |
continue | |
if filter is not None and filter(word): | |
continue | |
try: | |
tokens = tokenizer.encode(word) | |
except Exception as e: | |
print(e) | |
all_length.append(len(tokens)) | |
# if len(tokens.ids) > 1: | |
# if len(tokens) > 3: | |
# print(word, tokens) | |
dist_length = Counter(all_length) | |
mean_length = round(sum(all_length) / len(all_length), 2) | |
return dist_length, mean_length | |
# In-memory cache of iter_vocab() results keyed by tokenizer name. When empty, it is
# filled from the on-disk character_stats.json, and rewritten to that file after each
# newly computed entry.
cache = {}
def _dist(token_lens): | |
""" | |
:param token_lens: | |
:return: min,median,max of token_lens | |
""" | |
if not token_lens: | |
return "-" | |
return f"{min(token_lens)},{round(np.median(token_lens))},{max(token_lens)}" | |
def iter_vocab(
        tokenizer_name: str,
        from_cache: bool = True,
        cache_dir: str = "stats",
) -> dict:
    """Compute (or load from cache) vocabulary character statistics for one tokenizer.

    Each id in ``range(tokenizer.vocab_size)`` is decoded and bucketed by
    language (via ``detect_language_by_unicode``), digit content and space
    content. A per-token JSONL dump is written to
    ``<cache_dir>/iter_vocab/<tokenizer_name with '/' replaced by '_'>.vocab.jsonl``.
    The summary is stored in the module-level ``cache`` dict and in
    ``<cache_dir>/character_stats.json``.

    :param tokenizer_name: tokenizer name understood by ``tokenizer_factory``.
    :param from_cache: when True and a summary for ``tokenizer_name`` is already
        cached, return it without loading the tokenizer.
    :param cache_dir: cache directory; a relative path is resolved against this
        module's directory.
    :return: dict containing:
        - ``tokenizer``, ``organization`` and ``vocab_size``;
        - ``num(<key>)`` (count) and ``len(<key>)`` (``"min,median,max"`` string)
          for ``digit``, ``space`` and every language.
    """
    tokenizer_config = tokenizer_factory.get_tokenizer_config(tokenizer_name)
    cache_dir = os.path.join(CURRENT_DIR, cache_dir)
    os.makedirs(cache_dir, exist_ok=True)

    # Populate the in-memory cache from disk the first time it is needed.
    # NOTE(review): `cache` is module-global and not keyed by cache_dir, so calls
    # with different cache_dir values share (and cross-write) one cache. Confirm intent.
    cache_path = os.path.join(cache_dir, "character_stats.json")
    if not cache and os.path.exists(cache_path):
        with open(cache_path, "r", encoding="utf-8") as f_tmp:
            cache.update(json.load(f_tmp))
    if from_cache and tokenizer_name in cache:
        return cache[tokenizer_name]

    tokenizer = tokenizer_factory.get_tokenizer(tokenizer_name)
    tokens_by_lang = {lang[1]: [] for lang in language_ranges.keys()}
    digit_tokens = []
    space_tokens = []
    byte_tokens = []  # collected for a future "num(byte)" column; not reported yet
    buffer = []

    # NOTE: iterates tokenizer.vocab_size ids (vocabulary without added tokens),
    # while "vocab_size" below reports len(tokenizer).
    for token_id in range(tokenizer.vocab_size):
        decode_str = tokenizer.decode([token_id], skip_special_tokens=False)
        token = tokenizer.convert_ids_to_tokens([token_id], skip_special_tokens=False)[0]
        if token is None:  # some vocabularies have gaps (non-contiguous ids)
            continue
        if isinstance(token, bytes):
            token = token.decode("utf-8", errors="ignore")
        # sentencepiece-based tokenizers expose byte pieces through sp_model.
        if hasattr(tokenizer, "sp_model") and tokenizer.sp_model.is_byte(token_id):
            byte_tokens.append(token)
        for language in detect_language_by_unicode(decode_str):
            tokens_by_lang[language[1]].append(decode_str)
        if contains_digit(decode_str):
            digit_tokens.append(decode_str)
        if get_space_count(decode_str) > 0:
            space_tokens.append(decode_str)
        buffer.append(json.dumps(
            {
                "id": token_id,
                "token": token,
                "token_decode": decode_str,
                "token_dumps": json.dumps(token),
                "token_unicode": _to_unicode(token),
                "token_len": len(decode_str),
            },
            ensure_ascii=False) + "\n")

    # The mean / distribution of Chinese-character encoding length is deliberately
    # not reported: a vocabulary containing many Chinese characters generally
    # already implies short encodings for them.
    result = {
        "tokenizer": tokenizer_factory.get_name_with_hyperlink(tokenizer_name),
        "organization": tokenizer_config.org,
        "vocab_size": len(tokenizer),
        "num(digit)": len(digit_tokens),
        "len(digit)": _dist([len(token) for token in digit_tokens]),
        "num(space)": len(space_tokens),
        "len(space)": _dist([len(token) for token in space_tokens]),
    }
    for lang, tokens in tokens_by_lang.items():
        result[f"num({lang})"] = len(tokens)
        result[f"len({lang})"] = _dist([len(token) for token in tokens])

    out_path = os.path.join(cache_dir, f"iter_vocab/{tokenizer_name.replace('/', '_')}.vocab.jsonl")
    # Only cache_dir itself was created above. Create the iter_vocab/ sub-directory
    # too, or open() fails with FileNotFoundError on a fresh cache directory.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as f_out:
        f_out.writelines(buffer)

    len_before = len(cache)
    cache[tokenizer_name] = result
    len_after = len(cache)
    logger.info(f"saving {tokenizer_name} to memory and file cache: {len_before}->{len_after}")
    with open(cache_path, "w", encoding="utf-8") as f_out:
        f_out.write(json.dumps(cache, ensure_ascii=False, indent=2))
    return result
def to_dataframe(stats, columns):
    """Flatten per-tokenizer stats into a DataFrame limited to the selected columns.

    Keys not starting with ``num``/``len`` (e.g. ``tokenizer``) are always kept.
    Any key containing one of ``columns`` as a substring is also kept, with
    ``ja-kana`` shortened to ``kana`` in its name.

    NOTE(review): matching is by substring, so a keyword that occurs inside
    ``num(``/``len(`` (for example ``"en"``) would select every ``len(...)``
    column. Confirm the keywords passed in cannot collide.

    :param stats: mapping of tokenizer name -> stats dict (as built by iter_vocab).
    :param columns: iterable of column keywords.
    :return: ``pd.DataFrame`` with one row per tokenizer.
    """
    rows = []
    for per_tokenizer in stats.values():
        row = {}
        for key, value in per_tokenizer.items():
            is_metric = key.startswith("num") or key.startswith("len")
            if not is_metric:
                row[key] = value
            if any(keyword in key for keyword in columns):
                row[key.replace("ja-kana", "kana")] = value
        rows.append(row)
    return pd.DataFrame(rows)
def get_character_table(
        tokenizer_filter: Optional[str] = None,
        columns: Optional[list] = None,
        return_type: Optional[Literal["dict", "dataframe"]] = "dataframe"
) -> Union[pd.DataFrame, dict]:
    """Build the character-statistics table for all, or a filtered subset of, tokenizers.

    :param tokenizer_filter: case-insensitive substring matched against each
        tokenizer config's ``name_or_path``; ``None`` selects every tokenizer.
    :param columns: column keywords forwarded to ``to_dataframe``; defaults to
        ``default_columns``.
    :param return_type: ``"dataframe"`` returns a ``pd.DataFrame``; any other
        value returns the raw ``{tokenizer_name: stats}`` dict.
    """
    logger.info(f"columns: {columns}, tokenizer_filter: {tokenizer_filter}")
    if columns is None:
        columns = default_columns

    if tokenizer_filter is None:
        tokenizer_names = tokenizer_factory.all_tokenizer_names
    else:
        needle = tokenizer_filter.lower()
        tokenizer_names = [
            config.name_or_path
            for config in tokenizer_factory.all_tokenizer_configs
            if needle in config.name_or_path.lower()
        ]

    stats = {name: iter_vocab(name) for name in tokenizer_names}
    if return_type != "dataframe":
        return stats
    return to_dataframe(stats, columns)
if __name__ == "__main__":
    # Script entry point: build the default table and log it as markdown.
    # To restrict to one family, call get_character_table(tokenizer_filter="baichuan").
    table = get_character_table()
    logger.info(f"\n{table.to_markdown(index=False)}")