import json
import socket

import gradio as gr
import pandas as pd

from vocab import load_tokener
from utils.zh_util import iter_vocab
from utils.log_util import logger


def _token_to_str(token, tokenizer_type):
    """Convert one raw vocabulary token into ``(token_str, token_bytes)``.

    ``token`` is an element returned by ``tokenizer.convert_ids_to_tokens``.
    ``bytes`` tokens are decoded as UTF-8. ``str`` tokens are encoded to UTF-8
    for the byte column. Any other type (e.g. ``None``) falls back to
    ``str(token)``, so that a single odd token cannot abort the whole
    tokenization.
    """
    if isinstance(token, bytes):
        try:
            token_str = token.decode("utf-8")
        except UnicodeDecodeError:
            # A byte-level token can hold only part of a multi-byte UTF-8
            # sequence; drop the undecodable bytes for display and record it.
            token_str = token.decode("utf-8", errors="ignore")
            logger.info("[decode_error]: " + json.dumps(
                {"tokenizer_type": tokenizer_type, "token": str(token), "token_str": token_str},
                ensure_ascii=False))
        return token_str, token
    if isinstance(token, str):
        return token, token.encode("utf-8")
    # Unexpected token type: show its str() form instead of returning None,
    # which previously broke every caller that unpacks tokenize()'s result.
    token_str = str(token)
    logger.info("[unknown_token_type]: " + json.dumps(
        {"tokenizer_type": tokenizer_type, "token_type": type(token).__name__, "token_str": token_str},
        ensure_ascii=False))
    return token_str, token_str.encode("utf-8")


def tokenize(text, tokenizer_type, color_num=5, update=True):
    """Tokenize ``text`` with ``tokenizer_type`` and describe every token.

    TODO: cache tokenizer

    :param text: input text to encode.
    :param tokenizer_type: tokenizer name passed to ``load_tokener``.
    :param color_num: number of colour classes; token ``idx`` is labelled
        ``str(idx % color_num)`` for the highlighted-text view.
    :param update: if True, return ``(gr.update(value=pos_tokens, label=...), table_df)``
        for direct use as gradio outputs; otherwise return
        ``(pos_tokens, table_df, token_count)``.
    """
    logger.info("[param]:" + json.dumps({"text": text, "tokenizer_type": tokenizer_type}, ensure_ascii=False))
    tokenizer = load_tokener(tokenizer_type)
    encoding = tokenizer.encode(text)

    pos_tokens = []
    table = []
    for idx, token_id in enumerate(encoding):
        # Special characters all decode to "\ufffd" ("�") when a single id is
        # decoded on its own.
        decode_text = tokenizer.decode([token_id])
        pos_tokens.append((decode_text, str(idx % color_num)))

        token = tokenizer.convert_ids_to_tokens([token_id])[0]
        token_str, token_bytes = _token_to_str(token, tokenizer_type)
        table.append(
            {"TokenID": token_id,
             # Vocabulary form of the token. Some tokenizers (e.g. llama) show
             # entries such as "<0xE7>" here -- presumably byte-fallback
             # pieces; TODO confirm.
             "Token": token_str,
             "Text": decode_text,
             # gradio's frontend renders bytes values as decoded text (e.g.
             # b'\xe4\xb8\xad' would still display as "中"), so pass the
             # repr string instead.
             "Bytes": str(token_bytes),
             # json.dumps escapes non-ASCII characters as \uXXXX; ASCII is
             # shown as-is.
             "Unicode": json.dumps(token_str),
             }
        )

    table_df = pd.DataFrame(table)
    logger.info(f"[Tokens {tokenizer_type}]: {table[:2]}")
    if update:
        return gr.update(value=pos_tokens, label=f"Tokens: {len(encoding)}"), table_df
    return pos_tokens, table_df, len(encoding)


def _get_client_ip(request):
    """Return the originating client IP for a gradio ``request``.

    Prefers the first entry of the ``X-Forwarded-For`` header (set by reverse
    proxies) and falls back to ``request.client.host``.
    """
    client_ip = request.client.host
    headers = request.kwargs['headers']
    if headers and 'x-forwarded-for' in headers:
        x_forwarded_for = headers['x-forwarded-for']
        # The header is a comma-separated list "client, proxy1, proxy2";
        # splitting on spaces used to leave a trailing comma on the IP.
        client_ip = x_forwarded_for.split(",")[0].strip() if x_forwarded_for else ""
    return client_ip


def tokenize_pair(text, tokenizer_type_1, tokenizer_type_2, request: gr.Request):
    """Tokenize ``text`` with two tokenizers for side-by-side comparison.

    Returns ``(pos_tokens_1, table_df_1, pos_tokens_2, table_df_2)``, where the
    ``pos_tokens_*`` values are the ``gr.update`` objects produced by
    ``tokenize(..., update=True)``.
    """
    if request:
        client_ip = _get_client_ip(request)
        # One pre-formatted message: extra positional args without
        # %-placeholders make a std logging.Logger (assumed here) report a
        # formatting error instead of logging them.
        logger.info(f"[client ip]: {client_ip} {tokenizer_type_1} {tokenizer_type_2}")
    pos_tokens_1, table_df_1 = tokenize(text, tokenizer_type_1)
    pos_tokens_2, table_df_2 = tokenize(text, tokenizer_type_2)
    return pos_tokens_1, table_df_1, pos_tokens_2, table_df_2


def basic_count(tokenizer_type):
    """Return ``(vocab_size, "<single-char count>/<multi-char count>")``.

    The counts come from ``iter_vocab``'s stats dict under the
    "中文汉字数" (Chinese character count) entry -- presumably
    single-character vs multi-character Chinese tokens; verify against
    ``utils.zh_util``.
    """
    tokenizer = load_tokener(tokenizer_type)
    stats = iter_vocab(tokenizer, tokenizer_type)
    return tokenizer.vocab_size, f'{stats["中文汉字数"]["中文单字"]}/{stats["中文汉字数"]["中文多字"]}'


def get_overlap_token_size(tokenizer_type_1, tokenizer_type_2):
    """Count vocabulary entries (token strings) shared by two tokenizers.

    The count is returned twice -- presumably to feed two gradio output
    components; verify against the caller.
    """
    tokenizer1 = load_tokener(tokenizer_type_1)
    tokenizer2 = load_tokener(tokenizer_type_2)

    vocab1 = tokenizer1.get_vocab()
    vocab2 = tokenizer2.get_vocab()
    overlap_tokens = vocab1.keys() & vocab2.keys()
    overlap_token_size = len(overlap_tokens)
    logger.info(f"[OverlapTokens of {tokenizer_type_1} {tokenizer_type_2}]: {list(overlap_tokens)[:10]}")
    return overlap_token_size, overlap_token_size


def test_coding():
    """Print a UTF-8 byte literal; Python shows its escaped repr, not "中"."""
    bytes1 = b'\xe4\xb8\xad'
    print(bytes1)  # b'\xe4\xb8\xad'


if __name__ == "__main__":
    print(basic_count("internlm_chat_7b"))