File size: 7,162 Bytes
d10ecd7
 
9d1b27e
d10ecd7
2bd606a
 
79b95c3
9495a4f
d10ecd7
2bd606a
 
 
 
 
 
 
 
d10ecd7
9d1b27e
 
 
2bd606a
 
 
 
 
1b7fc74
d10ecd7
2bd606a
 
d27a756
 
 
d10ecd7
 
 
 
 
 
 
 
f4973d4
d10ecd7
 
 
 
 
2bd606a
1b7fc74
79b95c3
d10ecd7
 
9495a4f
d10ecd7
 
 
9495a4f
d10ecd7
2bd606a
 
e6543ac
 
 
d10ecd7
 
a6c67ec
d10ecd7
 
 
 
 
9495a4f
d10ecd7
 
 
 
 
1b7fc74
9d1b27e
d10ecd7
9d1b27e
 
 
 
 
 
 
 
 
 
 
 
d10ecd7
 
9495a4f
 
 
 
b15345c
d10ecd7
 
 
 
7cb27ea
1b7fc74
 
2bd606a
814ee6b
 
d10ecd7
2bd606a
 
 
 
 
 
1b7fc74
9495a4f
2bd606a
 
 
9495a4f
 
 
 
 
 
 
 
 
 
 
 
 
d10ecd7
9495a4f
2bd606a
d10ecd7
 
 
7a8d6d6
9495a4f
 
 
 
 
 
7a8d6d6
 
 
 
9495a4f
7cb27ea
9495a4f
 
 
 
 
 
7a8d6d6
 
 
2bd606a
 
 
7a8d6d6
9495a4f
 
 
2bd606a
 
 
814ee6b
d10ecd7
 
 
 
 
 
2bd606a
9495a4f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
import gradio as gr
import json
import copy
import pandas as pd
from vocab import tokenizer_factory
from character_util import iter_vocab
from utils.log_util import logger
from functools import lru_cache

# Sample text returned by on_load when the URL does not supply a "text" param.
# It mixes English, Spanish, Chinese (with Latin/digits) and Japanese katakana
# so that tokenizers can be compared across scripts.
default_user_input = """\
Replace this text in the input field to see how tokenization works.
Buenos días!
华为发布Mate60手机。
ラグビーワールドカップ2023フランス"""
# Fallback tokenizer names used by on_load when "tokenizer1"/"tokenizer2" are absent.
# Previous default: default_tokenizer_name_1 = "Meta/llama3"
default_tokenizer_name_1 = "gradientai/Llama-3-8B-Instruct-Gradient-1048k"
default_tokenizer_name_2 = "openai/gpt-4"


@lru_cache
def _tokenize(
        text: str,
        tokenizer_name: str,
        color_num: int = 5,
        add_special_token: bool = False
):
    """Tokenize ``text`` with the named tokenizer and build display data.

    Results are memoized with ``functools.lru_cache`` (default maxsize=128).
    All arguments must therefore be hashable. Calls with identical arguments
    share the same returned objects, so callers must not mutate them.

    Args:
        text: Input text to tokenize.
        tokenizer_name: Name passed to ``tokenizer_factory.get_tokenizer``.
        color_num: Number of highlight classes. Token ``idx`` is labelled
            ``str(idx % color_num)``.
        add_special_token: Forwarded to ``tokenizer.encode`` as
            ``add_special_tokens``.

    Returns:
        A tuple ``(pos_tokens, num_tokens, table_df)``:
        - ``pos_tokens``: list of ``(decoded_text, color_label)`` pairs.
        - ``num_tokens``: number of token ids produced.
        - ``table_df``: DataFrame with one row per token and columns
          ``TokenID``, ``Token``, ``Text`` and ``UTF8 Bytes``.
    """
    logger.info("param=" + json.dumps({"text": text, "tokenizer_type": tokenizer_name}, ensure_ascii=False))
    tokenizer = tokenizer_factory.get_tokenizer(tokenizer_name)
    # bool() passes exactly True/False, as the former if/else branches did.
    encoding = tokenizer.encode(text, add_special_tokens=bool(add_special_token))

    pos_tokens = []
    table = []

    for idx, token_id in enumerate(encoding):
        # Decoding a single id may yield U+FFFD ("\ufffd") for special or
        # partial-byte tokens.
        decode_text = tokenizer.decode([token_id])
        pos_tokens.append((decode_text, str(idx % color_num)))

        # The raw vocabulary entry; depending on the tokenizer it is bytes or str.
        token = tokenizer.convert_ids_to_tokens([token_id], skip_special_tokens=False)[0]
        if isinstance(token, bytes):
            try:
                token_str = token.decode("utf-8")
            except UnicodeDecodeError:
                # A byte-level token can hold an incomplete UTF-8 sequence
                # (frequently observed with gpt_35_turbo). Keep the decodable
                # part and record the failure.
                token_str = token.decode("utf-8", errors="ignore")
                logger.error(f"{idx}: decode_error: " + json.dumps(
                    {"tokenizer_type": tokenizer_name, "token": str(token), "token_str": token_str},
                    ensure_ascii=False))
            token_bytes = token
        elif isinstance(token, str):
            token_str = token
            token_bytes = token_str.encode("utf-8")
        else:
            logger.error(f"{idx}: wrong type for token {token_id} {type(token)} " + json.dumps(
                {"text": text, "tokenizer_type": tokenizer_name}, ensure_ascii=False))
            token_str = token
            token_bytes = token

        # TODO: for gpt3.5_turbo only "TokenID" and "Text" are correct. "Token"
        # and "UTF8 Bytes" are wrong, which points at convert_ids_to_tokens.
        table.append(
            {"TokenID": token_id,
             # Some tokenizers (e.g. llama) show byte-fallback pieces such as
             # <0xE7> here instead of readable text.
             "Token": token_str,
             "Text": decode_text,
             # Stringified on purpose: the Gradio frontend renders raw bytes
             # decoded (b'\xe4\xb8\xad' would display as "中").
             "UTF8 Bytes": str(token_bytes),
             }
        )

    table_df = pd.DataFrame(table)
    logger.info(f"tokenizer_type={tokenizer_name}, Tokens={table[:4]}")
    return pos_tokens, len(encoding), table_df


def tokenize(
        text: str,
        tokenizer_name: str,
        color_num: int = 5,
        add_special_token: bool = False
):
    """Tokenize ``text`` and package the result for the Gradio UI.

    Caching lives in ``_tokenize`` rather than here. ``gr.update`` objects get
    overwritten once they are handed to the frontend, so they must be built
    fresh on every call.

    Returns:
        ``(update, table_df)``:
        - ``update``: a ``gr.update`` holding the highlighted tokens, labelled
          "Tokens: N".
        - ``table_df``: the per-token DataFrame.
    """
    highlighted, token_count, token_table = _tokenize(
        text, tokenizer_name, color_num, add_special_token
    )
    token_label = f"Tokens: {token_count}"
    return gr.update(value=highlighted, label=token_label), token_table


def tokenize_pair(text, tokenizer_type_1, tokenizer_type_2):
    """Tokenize ``text`` with two tokenizers, for the ``input_text.change`` event.

    Returns:
        ``(pos_tokens_1, table_df_1, pos_tokens_2, table_df_2)``, one pair per
        tokenizer, in argument order.
    """
    outputs = []
    for tokenizer_type in (tokenizer_type_1, tokenizer_type_2):
        highlighted, table_df = tokenize(text, tokenizer_type)
        outputs.extend((highlighted, table_df))
    return tuple(outputs)


@lru_cache
def basic_count(tokenizer_name):
    """Return ``(vocab_size, organization)`` for ``tokenizer_name``.

    Both values come from the stats dict produced by ``iter_vocab``. The
    organization is formatted as a string. Results are memoized per name.
    """
    vocab_stats = iter_vocab(tokenizer_name)
    vocab_size = vocab_stats['vocab_size']
    organization = f'{vocab_stats["organization"]}'
    return vocab_size, organization


# def get_compress_rate(tokenizer_name, all_corpus, unit):
#     tokenizer = tokenizer_factory.get_tokenizer(tokenizer_name)
#     compress_rate_stats = tokenize_corpus(tokenizer, all_corpus)
#     compress_rate = unit_convertor(compress_rate_stats, unit)
#     return compress_rate


@lru_cache
def get_overlap_token_size(tokenizer_name_1, tokenizer_name_2):
    """Count vocabulary entries shared by two tokenizers.

    The key type of each vocabulary is judged from its first entry. When one
    vocabulary is keyed by ``str`` and the other by a different type (e.g.
    ``bytes``), the ``str`` side is UTF-8 encoded so the two sets can be
    compared.

    Args:
        tokenizer_name_1: Name passed to ``tokenizer_factory.get_tokenizer``.
        tokenizer_name_2: Name passed to ``tokenizer_factory.get_tokenizer``.

    Returns:
        ``(overlap_token_size, overlap_token_size)``, the same count twice.
        Presumably there is one value per UI output; confirm with the caller.
        Returns ``(0, 0)`` when either vocabulary is empty.
    """
    from itertools import islice

    tokenizer1 = tokenizer_factory.get_tokenizer(tokenizer_name_1)
    tokenizer2 = tokenizer_factory.get_tokenizer(tokenizer_name_2)

    vocab_set_1 = tokenizer1.get_vocab().keys()
    vocab_set_2 = tokenizer2.get_vocab().keys()

    # Guard: next(iter(...)) on an empty vocabulary would raise StopIteration.
    if vocab_set_1 and vocab_set_2:
        token1 = next(iter(vocab_set_1))
        token2 = next(iter(vocab_set_2))
        # Keys of different types (e.g. bytes vs str) never compare equal, so
        # normalize the str side to UTF-8 bytes before intersecting.
        if type(token1) is not type(token2):
            if isinstance(token1, str):
                vocab_set_1 = {token.encode("utf-8") for token in vocab_set_1}
            if isinstance(token2, str):
                vocab_set_2 = {token.encode("utf-8") for token in vocab_set_2}

    overlap_tokens = vocab_set_1 & vocab_set_2
    overlap_token_size = len(overlap_tokens)
    # islice avoids materialising the whole overlap just to log 10 samples.
    logger.info(
        f"{overlap_token_size} OverlapTokens of {tokenizer_name_1} {tokenizer_name_2}: "
        f"{list(islice(overlap_tokens, 10))}")
    return overlap_token_size, overlap_token_size


def on_load(url_params, request: gr.Request):
    """Resolve the initial text and tokenizer pair when the page loads.

    Args:
        url_params: JSON-encoded URL query parameters. The recognised keys are
            ``text``, ``tokenizer1`` and ``tokenizer2``. Missing, malformed,
            or non-object JSON is treated as no parameters.
        request: The incoming Gradio request. When falsy, nothing is resolved.

    Returns:
        ``(text, tokenizer_type_1, tokenizer_type_2)``. Each value falls back
        to the module-level defaults when absent from ``url_params``. All three
        are ``None`` when ``request`` is falsy.
    """
    text = None
    tokenizer_type_1 = None
    tokenizer_type_2 = None
    try:
        url_params = json.loads(url_params)
    except (TypeError, ValueError, RecursionError):
        # None/non-string input, invalid JSON, or pathologically nested JSON.
        url_params = {}
    if not isinstance(url_params, dict):
        # Valid JSON that is not an object (e.g. "null", "[]") has no .get().
        url_params = {}
    if request:
        # NOTE(review): this logs every request header (cookies/auth included);
        # consider redacting sensitive headers.
        logger.info(str(request.headers))
        client_ip = request.client.host
        # Params arrive as JSON because parsing the Referer header does not
        # work on Hugging Face Spaces.
        tokenizer_type_1 = url_params.get("tokenizer1", default_tokenizer_name_1)
        tokenizer_type_2 = url_params.get("tokenizer2", default_tokenizer_name_2)
        text = url_params.get("text", default_user_input)
        logger.info(f"client_ip: {client_ip}; params: {url_params}")
    return text, tokenizer_type_1, tokenizer_type_2


# def compress_rate_unit_change(unit):
#     return gr.update(label=f"Compress Rate: {unit}"), gr.update(label=f"Compress Rate: {unit}"),


def test_coding():
    """Print the UTF-8 encoding of "中" to show how a bytes object is displayed."""
    sample = "中".encode("utf-8")
    print(sample)  # b'\xe4\xb8\xad'


if __name__ == "__main__":
    # Manual smoke check: vocabulary overlap between two OpenAI tokenizers.
    overlap = get_overlap_token_size("gpt-35-turbo", "gpt-4")
    print(overlap)
    # print(basic_count("internlm_chat_7b"))