"""Character-level statistics over tokenizer vocabularies.

TODO:
1. add more language
2. check space count of bert
3. add token_impl
4.
"""
import json
import os
from collections import Counter, defaultdict
from typing import Literal, Optional, Union

import numpy as np
import pandas as pd

from vocab import tokenizer_factory
from utils.log_util import logger
from utils.text_util import contains_digit, get_space_count
from utils.lang_util import detect_language, language_ranges

CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))

default_columns = ["digit", "zh"]


def _to_unicode(text):
    """Return *text* as a string of ``\\uXXXX`` escapes, one per character.

    Each code point is rendered as upper-case hex zero-padded to at least four
    digits, so code points above U+FFFF yield more than four digits.
    """
    return "".join(f"\\u{ord(char):04X}" for char in text)


def _get_coding_length(tokenizer, vocab, filter=None):
    """Measure how many tokens each single-character vocab entry is encoded into.

    An OOV character may be tokenized into more than one token.

    :param tokenizer: object with an ``encode(str)`` method returning a sized sequence.
    :param vocab: iterable of strings; entries longer than one character are skipped.
    :param filter: optional predicate; characters for which it returns True are skipped.
    :return: ``(dist_length, mean_length)``: a Counter of encoded lengths and their
        mean rounded to 2 decimals (0 when no character could be measured).
    """
    all_length = []
    for word in vocab:
        if len(word) > 1:
            continue
        if filter is not None and filter(word):
            continue
        try:
            tokens = tokenizer.encode(word)
        except Exception as e:
            # Skip characters the tokenizer cannot encode instead of counting a
            # stale (or undefined) `tokens` value from an earlier iteration.
            logger.warning(f"failed to encode {word!r}: {e}")
            continue
        all_length.append(len(tokens))

    dist_length = Counter(all_length)
    mean_length = round(sum(all_length) / len(all_length), 2) if all_length else 0
    return dist_length, mean_length


# In-memory mirror of <cache_dir>/character_stats.json, keyed by tokenizer name.
cache = {}


def _dist(token_lens):
    """Summarize a list of lengths.

    :param token_lens: list of ints.
    :return: ``"min,median,max"`` (median rounded), or ``"-"`` for an empty list.
    """
    if not token_lens:
        return "-"
    return f"{min(token_lens)},{round(np.median(token_lens))},{max(token_lens)}"


def iter_vocab(
        tokenizer_name: str,
        from_cache: bool = True,
        cache_dir: str = "stats",
) -> dict:
    """Compute (or load from cache) character statistics for one tokenizer's vocabulary.

    Every token id in ``range(tokenizer.vocab_size)`` is decoded. The decoded strings
    are bucketed by detected language, by whether they contain a digit, and by
    whether they contain spaces. A per-token JSONL dump is written to
    ``<cache_dir>/iter_vocab/<tokenizer_name with '/' replaced by '_'>.vocab.jsonl``.
    The result is stored in the module-level ``cache`` and the whole cache is
    rewritten to ``<cache_dir>/character_stats.json``.

    :param tokenizer_name: name understood by ``tokenizer_factory``.
    :param from_cache: when True and the name is already cached, return the cached
        stats without loading the tokenizer.
    :param cache_dir: directory joined onto this file's directory (an absolute path
        is used as-is) holding ``character_stats.json`` and the ``iter_vocab/`` dumps.
    :return: dict with keys ``tokenizer``, ``organization``, ``vocab_size``,
        ``num(digit)``, ``len(digit)``, ``num(space)``, ``len(space)`` plus
        ``num(<lang>)``/``len(<lang>)`` for every language in ``language_ranges``.
        ``len(...)`` values are ``"min,median,max"`` strings of decoded lengths or ``"-"``.
    """
    tokenizer_config = tokenizer_factory.get_tokenizer_config(tokenizer_name)

    cache_dir = os.path.join(CURRENT_DIR, cache_dir)
    os.makedirs(cache_dir, exist_ok=True)

    # Populate the in-memory cache from disk only while it is still empty.
    cache_path = os.path.join(cache_dir, "character_stats.json")
    if not cache and os.path.exists(cache_path):
        with open(cache_path, "r", encoding="utf-8") as f_tmp:
            cache.update(json.load(f_tmp))
    if from_cache and tokenizer_name in cache:
        return cache[tokenizer_name]

    tokenizer = tokenizer_factory.get_tokenizer(tokenizer_name)

    tokens_by_lang = {lang[1]: [] for lang in language_ranges.keys()}
    digit_tokens = []
    space_tokens = []
    byte_tokens = []  # collected for the (currently disabled) "num(byte)" column

    buffer = []
    # tokenizer.vocab_size is the vocab size without added tokens; the reported
    # "vocab_size" below uses len(tokenizer) instead.
    for token_id in range(tokenizer.vocab_size):
        decode_str = tokenizer.decode([token_id], skip_special_tokens=False)
        token = tokenizer.convert_ids_to_tokens([token_id], skip_special_tokens=False)[0]
        if token is None:  # some vocabularies have unassigned (non-contiguous) ids
            continue
        if isinstance(token, bytes):
            token = token.decode("utf-8", errors="ignore")

        # Tokenizers built on the sentencepiece package expose byte pieces via sp_model.
        if hasattr(tokenizer, "sp_model") and tokenizer.sp_model.is_byte(token_id):
            byte_tokens.append(token)

        for language in detect_language(decode_str):
            tokens_by_lang[language[1]].append(decode_str)

        if contains_digit(decode_str):
            digit_tokens.append(decode_str)

        if get_space_count(decode_str) > 0:
            space_tokens.append(decode_str)

        buffer.append(json.dumps(
            {
                "id": token_id,
                "token": token,
                "token_decode": decode_str,
                "token_dumps": json.dumps(token),
                "token_unicode": _to_unicode(token),
                "token_len": len(decode_str),
            },
            ensure_ascii=False) + "\n")

    result = {
        "tokenizer": tokenizer_factory.get_name_with_hyperlink(tokenizer_name),
        "organization": tokenizer_config.org,
        "vocab_size": len(tokenizer),
        # Chinese-character encoding length (mean/distribution) is intentionally not
        # reported: a vocab containing many Chinese characters generally implies
        # short Chinese-character encodings.
        "num(digit)": len(digit_tokens),
        "len(digit)": _dist([len(token) for token in digit_tokens]),
        "num(space)": len(space_tokens),
        "len(space)": _dist([len(token) for token in space_tokens]),
        # "num(byte)": len(byte_tokens)
    }
    for lang, tokens in tokens_by_lang.items():
        result[f"num({lang})"] = len(tokens)
        result[f"len({lang})"] = _dist([len(token) for token in tokens])

    out_path = os.path.join(cache_dir, "iter_vocab", f"{tokenizer_name.replace('/', '_')}.vocab.jsonl")
    # The iter_vocab/ subdirectory is not created by the makedirs call above.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as f_out:
        f_out.writelines(buffer)

    len_before = len(cache)
    cache[tokenizer_name] = result
    len_after = len(cache)
    logger.info(f"saving {tokenizer_name} to memory and file cache: {len_before}->{len_after}")
    with open(cache_path, "w", encoding="utf-8") as f_out:
        f_out.write(json.dumps(cache, ensure_ascii=False, indent=2))
    return result


def to_dataframe(stats, columns):
    """Flatten per-tokenizer stats into a DataFrame, one row per tokenizer.

    Keys not starting with ``num``/``len`` are always kept. Any key containing one of
    *columns* as a substring is also kept, under a name with ``ja-kana`` shortened
    to ``kana``.

    :param stats: mapping of tokenizer name -> stats dict (as returned by ``iter_vocab``).
    :param columns: substrings selecting which ``num(...)``/``len(...)`` keys to keep.
    :return: pandas DataFrame.
    """
    table = []
    for stat in stats.values():
        filtered_stat = {}
        for key, value in stat.items():
            if not key.startswith(("num", "len")):
                filtered_stat[key] = value
            if any(column in key for column in columns):
                filtered_stat[key.replace("ja-kana", "kana")] = value
        table.append(filtered_stat)
    return pd.DataFrame(table)


def get_character_table(
        tokenizer_filter: Optional[str] = None,
        columns: Optional[list] = None,
        return_type: Optional[Literal["dict", "dataframe"]] = "dataframe"
) -> Union[pd.DataFrame, dict]:
    """Collect ``iter_vocab`` stats for all (or a filtered subset of) tokenizers.

    :param tokenizer_filter: case-insensitive substring matched against each
        tokenizer config's ``name_or_path``; None selects all tokenizers.
    :param columns: stat-column substrings passed to ``to_dataframe``;
        defaults to ``default_columns``.
    :param return_type: ``"dataframe"`` returns a DataFrame; any other value
        returns the raw ``{tokenizer_name: stats}`` dict.
    """
    logger.info(f"columns: {columns}, tokenizer_filter: {tokenizer_filter}")
    if columns is None:
        columns = default_columns
    if tokenizer_filter is not None:
        needle = tokenizer_filter.lower()
        tokenizer_names = [
            tokenizer_config.name_or_path
            for tokenizer_config in tokenizer_factory.all_tokenizer_configs
            if needle in tokenizer_config.name_or_path.lower()
        ]
    else:
        tokenizer_names = tokenizer_factory.all_tokenizer_names

    stats = {tokenizer_name: iter_vocab(tokenizer_name) for tokenizer_name in tokenizer_names}
    if return_type == "dataframe":
        return to_dataframe(stats, columns)
    return stats


if __name__ == "__main__":
    # e.g. get_character_table(tokenizer_filter="baichuan")
    df = get_character_table()
    logger.info(f"\n{df.to_markdown(index=False)}")