File size: 6,078 Bytes
d10ecd7
 
6f9d07b
d10ecd7
d27a756
d10ecd7
 
79b95c3
9495a4f
 
d10ecd7
 
9495a4f
 
d10ecd7
 
9495a4f
d10ecd7
 
d27a756
 
 
 
d10ecd7
 
 
 
 
 
 
 
f4973d4
d10ecd7
 
 
 
 
a37f943
79b95c3
 
d10ecd7
 
9495a4f
d10ecd7
 
 
9495a4f
d10ecd7
e6543ac
 
 
 
d10ecd7
 
 
 
 
 
 
9495a4f
d10ecd7
 
 
 
 
e6543ac
d10ecd7
 
9495a4f
d10ecd7
 
9495a4f
 
 
 
 
b15345c
d10ecd7
 
 
 
7cb27ea
d10ecd7
 
 
 
 
 
9495a4f
d10ecd7
 
 
9495a4f
 
 
 
 
 
 
 
 
 
 
 
 
d10ecd7
9495a4f
 
d10ecd7
 
 
9495a4f
 
 
 
 
 
 
 
7a8d6d6
9495a4f
 
 
 
 
 
7a8d6d6
 
 
 
9495a4f
7cb27ea
9495a4f
 
 
 
 
 
7a8d6d6
 
 
 
 
 
 
9495a4f
 
 
d10ecd7
 
 
 
 
 
9495a4f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import gradio as gr
import json
import socket
import pandas as pd
import config
from vocab import load_tokener
from utils.zh_util import iter_vocab
from utils.log_util import logger
from functools import lru_cache
from urllib.parse import urlparse, parse_qs


@lru_cache
def tokenize(text, tokenizer_type, color_num=5):
    """Tokenize ``text`` and build a colored token view plus a per-token table.

    Args:
        text: input string to tokenize.
        tokenizer_type: tokenizer name passed to ``load_tokener``.
        color_num: number of color classes; token ``idx`` is labelled
            ``str(idx % color_num)`` so neighbouring tokens alternate colors.

    Returns:
        A 2-tuple ``(update, table_df)``:
        - ``update``: ``gr.update`` whose value is a list of
          ``(decoded_text, color_label)`` pairs and whose label shows the
          token count.
        - ``table_df``: DataFrame with columns TokenID, Token, Text,
          UTF8 Bytes (one row per token; empty but still typed for empty
          input).

    Results are memoized by ``lru_cache``. Cache hits return the *same*
    DataFrame object, so callers must not mutate it.
    """
    logger.info("param=" + json.dumps({"text": text, "tokenizer_type": tokenizer_type}, ensure_ascii=False))
    tokenizer = load_tokener(tokenizer_type)
    # Forward the config flag directly; bool() keeps the exact True/False the
    # original if/else passed.
    encoding = tokenizer.encode(text, add_special_tokens=bool(config.ADD_SPECIAL_TOKEN))

    pos_tokens = []
    table = []
    for idx, token_id in enumerate(encoding):
        # Special characters decode to the replacement character U+FFFD ("\ufffd").
        decode_text = tokenizer.decode([token_id])
        pos_tokens.append((decode_text, str(idx % color_num)))

        # Depending on the tokenizer, the raw token is bytes or str.
        token = tokenizer.convert_ids_to_tokens([token_id], skip_special_tokens=False)[0]
        if isinstance(token, bytes):
            try:
                token_str = token.decode("utf-8")
            except UnicodeDecodeError:
                # A byte-level token may hold an incomplete UTF-8 sequence
                # (gpt_35_turbo hits this often); drop the undecodable bytes
                # and record the case.
                token_str = token.decode("utf-8", errors="ignore")
                logger.error("decode_error: " + json.dumps(
                    {"tokenizer_type": tokenizer_type, "token": str(token), "token_str": token_str},
                    ensure_ascii=False))
            token_bytes = token
        elif isinstance(token, str):
            token_str = token
            token_bytes = bytes(token_str, "utf-8")
        else:
            logger.error(f"{idx}: wrong type for token {token_id} {type(token)} " + json.dumps({"text": text, "tokenizer_type": tokenizer_type}, ensure_ascii=False))
            token_str = token
            token_bytes = token

        table.append({
            "TokenID": token_id,
            # UTF-8 token string. Some tokenizers (e.g. llama) yield tokens
            # like "<0xE7>" here. NOTE(review): presumably byte-fallback
            # pieces; confirm.
            "Token": token_str,
            "Text": decode_text,
            # str() because the gradio frontend would render raw bytes as
            # decoded text (e.g. b'\xe4\xb8\xad' would display as "中").
            "UTF8 Bytes": str(token_bytes),
        })

    # Explicit columns keep the table schema even when there are no tokens.
    table_df = pd.DataFrame(table, columns=["TokenID", "Token", "Text", "UTF8 Bytes"])
    logger.info(f"tokenizer_type={tokenizer_type}, Tokens={table[:4]}")

    return gr.update(value=pos_tokens, label=f"Tokens: {len(encoding)}"), table_df


@lru_cache
def tokenize_pair(text, tokenizer_type_1, tokenizer_type_2):
    """Tokenize the same text with two tokenizers for a side-by-side view.

    Intended as the handler for the input text's change event
    (``input_text.change``).

    Returns:
        ``(update_1, table_1, update_2, table_2)``: the two results of
        ``tokenize``, flattened into one 4-tuple.
    """
    first = tokenize(text, tokenizer_type_1)
    second = tokenize(text, tokenizer_type_2)
    (update_a, frame_a), (update_b, frame_b) = first, second
    return update_a, frame_a, update_b, frame_b


@lru_cache
def basic_count(tokenizer_type):
    """Return a tokenizer's vocabulary size and a Chinese-character summary.

    Returns:
        ``(vocab_size, summary)``, where ``summary`` is the string
        ``"<中文单字>/<中文多字>"`` built from the ``"中文汉字数"`` entry of
        ``iter_vocab``'s stats. NOTE(review): by the key names, these are
        presumably the counts of single-character and multi-character
        Chinese tokens; confirm against ``iter_vocab``.
    """
    tok = load_tokener(tokenizer_type)
    stats = iter_vocab(tok, tokenizer_type)
    vocab_size = tok.vocab_size
    han = stats["中文汉字数"]
    return vocab_size, "{}/{}".format(han["中文单字"], han["中文多字"])


@lru_cache
def get_overlap_token_size(tokenizer_type_1, tokenizer_type_2):
    """Count the tokens present in both tokenizers' vocabularies.

    Vocabulary keys can be ``str`` for one tokenizer and ``bytes`` for the
    other. If any key in either vocabulary is ``bytes``, every ``str`` key is
    UTF-8 encoded so both sides are compared as bytes. Otherwise the keys are
    compared as-is.

    Returns:
        ``(overlap_token_size, overlap_token_size)``: the same count twice.
        NOTE(review): presumably this feeds two output components; confirm
        with the UI wiring.
    """
    tokenizer1 = load_tokener(tokenizer_type_1)
    tokenizer2 = load_tokener(tokenizer_type_2)

    vocab_set_1 = tokenizer1.get_vocab().keys()
    vocab_set_2 = tokenizer2.get_vocab().keys()

    def as_bytes(vocab):
        # Encode str keys; leave bytes keys untouched.
        return {tok.encode("utf-8") if isinstance(tok, str) else tok for tok in vocab}

    # Check every key rather than probing only the first one. A vocabulary may
    # be empty (the probe would raise StopIteration) or may mix str and bytes
    # keys.
    has_bytes = (any(isinstance(tok, bytes) for tok in vocab_set_1)
                 or any(isinstance(tok, bytes) for tok in vocab_set_2))
    if has_bytes:
        vocab_set_1 = as_bytes(vocab_set_1)
        vocab_set_2 = as_bytes(vocab_set_2)

    # dict_keys & dict_keys (or set & set) yields a plain set.
    overlap_tokens = vocab_set_1 & vocab_set_2
    overlap_token_size = len(overlap_tokens)
    logger.info(
        f"{overlap_token_size} OverlapTokens of {tokenizer_type_1} {tokenizer_type_2}: {list(overlap_tokens)[:10]}")
    return overlap_token_size, overlap_token_size


# Sample text shown on first load: one English, one Chinese and one Japanese line.
default_user_input = "\n".join([
    "Replace this text in the input field to see how tokenization works",
    "华为发布Mate60手机",
    "ラグビーワールドカップ2023フランス",
])
# Tokenizers used by on_load when the URL does not name any.
# ("internlm_chat_7b" was an earlier choice for the second one.)
default_tokenizer_type_1 = "llama"
default_tokenizer_type_2 = "gpt_35_turbo"


def on_load(url_params, request: gr.Request):
    """Resolve the initial text and tokenizer choices when the page loads.

    Args:
        url_params: JSON text of the URL query parameters. Recognized keys
            are "text", "tokenizer1" and "tokenizer2". Unparseable input, or
            JSON that is not an object, is treated as no parameters.
        request: the incoming gradio request; may be falsy.

    Returns:
        ``(text, tokenizer_type_1, tokenizer_type_2)``. All three are None
        when ``request`` is falsy. Otherwise, missing keys fall back to the
        module-level defaults.
    """
    text = None
    tokenizer_type_1 = None
    tokenizer_type_2 = None
    try:
        url_params = json.loads(url_params)
    except (TypeError, ValueError):
        # TypeError: None or a non-str/bytes argument.
        # ValueError: malformed JSON (json.JSONDecodeError subclasses it).
        url_params = {}
    if not isinstance(url_params, dict):
        # Valid JSON such as "null" or "[]" has no .get(); treat it as empty.
        url_params = {}
    if request:
        # NOTE(review): this logs every request header (including cookies, if
        # sent); consider redacting.
        logger.info(str(request.headers))
        # request.client may be None (no client address available).
        client_ip = getattr(request.client, "host", None)
        # NOTE: reading the client IP from x-forwarded-for, and URL params from
        # the referer header, were tried before; the referer approach does
        # not work on Hugging Face Spaces.
        tokenizer_type_1 = url_params.get("tokenizer1", default_tokenizer_type_1)
        tokenizer_type_2 = url_params.get("tokenizer2", default_tokenizer_type_2)
        text = url_params.get("text", default_user_input)
        logger.info(f"client_ip: {client_ip}; params: {url_params}")
    return text, tokenizer_type_1, tokenizer_type_2


def test_coding():
    """Print the UTF-8 encoding of U+4E2D ("中") to show how bytes are repr'd."""
    sample = "中".encode("utf-8")
    print(sample)  # b'\xe4\xb8\xad'


if __name__ == "__main__":
    # Manual smoke check: print the vocabulary overlap of two tokenizers.
    overlap = get_overlap_token_size("gpt_35_turbo", "gpt_4")
    print(overlap)
    # Alternative check: print(basic_count("internlm_chat_7b"))