# tokenizer-arena / util.py
# Author: xu-song · last commit 9495a4f ("update") · 5.4 kB
import gradio as gr
import json
import socket
import pandas as pd
from vocab import load_tokener
from utils.zh_util import iter_vocab
from utils.log_util import logger
from functools import lru_cache
from urllib.parse import urlparse, parse_qs
def _token_str_and_bytes(token, tokenizer_type):
    """Return ``(display_string, raw_bytes)`` for one vocabulary token.

    - ``bytes`` tokens are decoded as UTF-8. If that fails, the undecodable
      bytes are dropped from the display string and the failure is logged.
    - ``str`` tokens are returned as-is, together with their UTF-8 encoding.
    - Any other type is rendered with ``str()`` and logged, so the caller can
      still produce a table row instead of aborting.
    """
    if isinstance(token, bytes):
        try:
            return token.decode("utf-8"), token
        except UnicodeDecodeError:
            token_str = token.decode("utf-8", errors="ignore")
            logger.error("decode_error: " + json.dumps(
                {"tokenizer_type": tokenizer_type, "token": str(token), "token_str": token_str},
                ensure_ascii=False))
            return token_str, token
    if isinstance(token, str):
        return token, token.encode("utf-8")
    # Unexpected token type: previously this aborted tokenize() with a bare
    # ``return`` (yielding None and breaking callers that unpack two values).
    token_str = str(token)
    logger.error("unexpected_token_type: " + json.dumps(
        {"tokenizer_type": tokenizer_type, "token": token_str, "type": type(token).__name__},
        ensure_ascii=False))
    return token_str, token_str.encode("utf-8")


@lru_cache
def tokenize(text, tokenizer_type, color_num=5):
    """Tokenize ``text`` with the tokenizer named ``tokenizer_type``.

    Args:
        text: the input string.
        tokenizer_type: name handed to ``load_tokener`` to obtain the tokenizer.
        color_num: number of colour labels to cycle through; token ``i`` is
            labelled ``str(i % color_num)`` for highlighted display.

    Returns:
        A pair ``(update, table_df)``:
        - ``update``: ``gr.update`` whose value is a list of
          ``(decoded_text, colour_label)`` tuples and whose label shows the
          token count;
        - ``table_df``: ``pd.DataFrame`` with one row per token and columns
          ``TokenID``, ``Token``, ``Text`` and ``UTF8 Bytes``.

    Results are memoised by ``lru_cache``: repeated calls with the same
    arguments return the same objects, so callers must not mutate them.
    """
    logger.info("param=" + json.dumps({"text": text, "tokenizer_type": tokenizer_type}, ensure_ascii=False))
    tokenizer = load_tokener(tokenizer_type)
    encoding = tokenizer.encode(text)

    pos_tokens = []
    table = []
    for idx, token_id in enumerate(encoding):
        # Special characters may decode to the replacement character "\ufffd"
        # when a single id is decoded on its own.
        decode_text = tokenizer.decode([token_id])
        pos_tokens.append((decode_text, str(idx % color_num)))

        token = tokenizer.convert_ids_to_tokens([token_id])[0]
        token_str, token_bytes = _token_str_and_bytes(token, tokenizer_type)
        table.append({
            "TokenID": token_id,
            # Vocabulary entry as text. NOTE(review): some tokenizers (e.g. llama)
            # show placeholders such as "<0xE7>" here, presumably byte-fallback pieces.
            "Token": token_str,
            "Text": decode_text,
            # str() on purpose: gradio would otherwise render a bytes value decoded,
            # e.g. b'\xe4\xb8\xad' would be shown as "中".
            "UTF8 Bytes": str(token_bytes),
        })

    table_df = pd.DataFrame(table)
    logger.info(f"Tokens={table[:2]}")
    return gr.update(value=pos_tokens, label=f"Tokens: {len(encoding)}"), table_df
@lru_cache
def tokenize_pair(text, tokenizer_type_1, tokenizer_type_2):
    """Tokenize ``text`` with two tokenizers for side-by-side comparison.

    Used as the ``input_text.change`` callback. Returns a 4-tuple: the
    highlighted-token update and token table from ``tokenize`` for the first
    tokenizer, followed by the same two values for the second tokenizer.
    """
    outputs = []
    for tokenizer_type in (tokenizer_type_1, tokenizer_type_2):
        highlighted, token_table = tokenize(text, tokenizer_type)
        outputs.extend((highlighted, token_table))
    return tuple(outputs)
def basic_count(tokenizer_type):
    """Return the vocabulary size and Chinese-token statistics of a tokenizer.

    Returns:
        ``(vocab_size, "<single>/<multi>")``. The second element joins the
        ``"中文单字"`` (single Chinese character) and ``"中文多字"`` (multiple
        Chinese characters) counts found under ``"中文汉字数"`` in the stats
        returned by ``iter_vocab``.
    """
    tokenizer = load_tokener(tokenizer_type)
    stats = iter_vocab(tokenizer, tokenizer_type)
    vocab_size = tokenizer.vocab_size
    han_counts = stats["中文汉字数"]
    single_char, multi_char = han_counts["中文单字"], han_counts["中文多字"]
    return vocab_size, f'{single_char}/{multi_char}'
@lru_cache
def get_overlap_token_size(tokenizer_type_1, tokenizer_type_2):
    """Count the vocabulary entries shared by two tokenizers.

    Vocabularies may be keyed by ``str`` or by ``bytes``. If the two
    vocabularies together contain more than one key type (different types
    between them, or a mix inside one), every ``str`` key is UTF-8 encoded so
    the comparison is done on ``bytes``. Otherwise the keys are compared as-is.

    Returns:
        ``(overlap_token_size, overlap_token_size)``: the count is returned
        twice (interface kept unchanged).
    """
    from itertools import islice

    tokenizer1 = load_tokener(tokenizer_type_1)
    tokenizer2 = load_tokener(tokenizer_type_2)
    vocab_set_1 = set(tokenizer1.get_vocab())
    vocab_set_2 = set(tokenizer2.get_vocab())

    # Check every key, not just the first one: a vocab can mix str and bytes
    # keys, and an empty vocab must not raise StopIteration.
    key_types = {type(token) for token in vocab_set_1} | {type(token) for token in vocab_set_2}
    if len(key_types) > 1:  # e.g. bytes vs str
        vocab_set_1 = {t.encode("utf-8") if isinstance(t, str) else t for t in vocab_set_1}
        vocab_set_2 = {t.encode("utf-8") if isinstance(t, str) else t for t in vocab_set_2}

    overlap_tokens = vocab_set_1 & vocab_set_2
    overlap_token_size = len(overlap_tokens)
    logger.info(
        f"{overlap_token_size} OverlapTokens of {tokenizer_type_1} {tokenizer_type_2}: "
        f"{list(islice(overlap_tokens, 10))}")
    return overlap_token_size, overlap_token_size
# Sample text used when the page URL has no "text" parameter; it mixes
# English, Chinese and Japanese to show how tokenization differs.
default_user_input = (
    "Replace this text in the input field to see how tokenization works\n"
    "华为发布Mate60手机\n"
    "ラグビーワールドカップ2023フランス"
)
# Tokenizers used when the page URL has no "tokenizer1" / "tokenizer2" parameter.
default_tokenizer_type_1 = "llama"
default_tokenizer_type_2 = "gpt_35_turbo"  # formerly "internlm_chat_7b"
def on_load(request: gr.Request):
    """Return the initial ``(text, tokenizer_type_1, tokenizer_type_2)`` for the UI.

    Values come from the query string of the page URL in the request's
    ``referer`` header (parameters ``text``, ``tokenizer1``, ``tokenizer2``).
    Any missing parameter falls back to ``default_user_input``,
    ``default_tokenizer_type_1`` or ``default_tokenizer_type_2``. The defaults
    are also returned when no request is available, instead of ``None`` values.

    Args:
        request: the gradio request for the page load; may be ``None``.

    Returns:
        Tuple ``(text, tokenizer_type_1, tokenizer_type_2)``.
    """
    query_params = {}
    if request:
        # NOTE(review): request.client is optional in starlette; guard against None.
        client_ip = request.client.host if request.client else None
        if "referer" in request.headers:
            raw_params = parse_qs(urlparse(request.headers["referer"]).query)
            # parse_qs maps each name to a list of values; keep the first one.
            query_params = {k: v[0] for k, v in raw_params.items() if v}
        logger.info(f"client_ip: {client_ip}; params: {query_params}")

    tokenizer_type_1 = query_params.get("tokenizer1", default_tokenizer_type_1)
    tokenizer_type_2 = query_params.get("tokenizer2", default_tokenizer_type_2)
    text = query_params.get("text", default_user_input)
    return text, tokenizer_type_1, tokenizer_type_2
def test_coding():
    """Print the UTF-8 bytes of "中" to show how a bytes object is displayed."""
    utf8_zhong = b"\xe4\xb8\xad"  # UTF-8 encoding of "中"
    print(utf8_zhong)  # prints: b'\xe4\xb8\xad'
if __name__ == "__main__":
    # Manual check: vocabulary overlap between the "gpt_35_turbo" and "gpt_4" tokenizers.
    overlap = get_overlap_token_size("gpt_35_turbo", "gpt_4")
    print(overlap)