File size: 3,510 Bytes
d10ecd7
 
 
 
 
79b95c3
d10ecd7
 
b15345c
d10ecd7
 
 
79b95c3
d10ecd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79b95c3
 
 
d10ecd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79b95c3
d10ecd7
 
b15345c
 
 
 
d10ecd7
 
 
b15345c
d10ecd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79b95c3
d10ecd7
 
 
 
 
 
 
 
 
79b95c3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import gradio as gr
import json
import pandas as pd
from vocab import load_tokener
from utils.zh_util import iter_vocab
from utils.log_util import logger


def _token_to_str_and_bytes(token, tokenizer_type):
    """Normalize one vocabulary token into a ``(token_str, token_bytes)`` pair.

    Args:
        token: The value returned by ``convert_ids_to_tokens`` for one id;
            ``bytes`` and ``str`` are handled natively.
        tokenizer_type: Tokenizer name, used only in log messages.

    Returns:
        A tuple ``(token_str, token_bytes)`` where ``token_str`` is a ``str``
        and ``token_bytes`` is a ``bytes`` object.
    """
    if isinstance(token, bytes):
        try:
            token_str = token.decode("utf-8")
        except UnicodeDecodeError:
            # The token is not valid UTF-8 on its own; drop the undecodable
            # bytes so there is still something to display, and log the case.
            token_str = token.decode("utf-8", errors="ignore")
            logger.info("[decode_error]: " + json.dumps(
                {"tokenizer_type": tokenizer_type, "token": str(token), "token_str": token_str},
                ensure_ascii=False))
        return token_str, token

    if isinstance(token, str):
        return token, token.encode("utf-8")

    # Neither bytes nor str: fall back to the string form instead of aborting
    # the whole call, so callers always get their (tokens, table) pair.
    token_str = str(token)
    logger.info("[unknown_token_type]: " + json.dumps(
        {"tokenizer_type": tokenizer_type, "token": token_str, "type": type(token).__name__},
        ensure_ascii=False))
    return token_str, token_str.encode("utf-8")


def tokenize(text, tokenizer_type, color_num=5, update=True):
    """Tokenize ``text`` and build the highlighted-token list and token table.

    TODO: cache tokenizer

    Args:
        text: The input string to encode.
        tokenizer_type: Name passed to ``load_tokener`` to obtain the tokenizer.
        color_num: Number of color labels to cycle through; each token is
            labelled with ``str(index % color_num)``.
        update: When true, the token list is wrapped in ``gr.update`` together
            with a ``"Tokens: <count>"`` label; otherwise the raw list is
            returned.

    Returns:
        A pair ``(pos_tokens, table_df)``. ``pos_tokens`` is a list of
        ``(decoded_text, color_label)`` tuples (or a ``gr.update`` wrapping it),
        and ``table_df`` is a DataFrame with the columns ``TokenID``, ``Token``,
        ``Text`` and ``Bytes``, one row per token id.
    """
    logger.info("[param]:" + json.dumps({"text": text, "tokenizer_type": tokenizer_type}, ensure_ascii=False))
    tokenizer = load_tokener(tokenizer_type)
    encoding = tokenizer.encode(text)

    pos_tokens = []
    table = []

    for idx, token_id in enumerate(encoding):
        # Special characters all decode to the replacement character "\ufffd".
        decode_text = tokenizer.decode([token_id])
        pos_tokens.append((decode_text, str(idx % color_num)))

        # The raw vocabulary entry may be bytes or str depending on the tokenizer.
        token = tokenizer.convert_ids_to_tokens([token_id])[0]
        token_str, token_bytes = _token_to_str_and_bytes(token, tokenizer_type)

        table.append(
            {"TokenID": token_id,
             # String form of the token; some tokenizers (e.g. llama) yield
             # entries such as <0xE7> here.
             "Token": token_str,
             "Text": decode_text,
             # A bytes value is rendered as decoded text by the gradio front end
             # (b'\xe4\xb8\xad' would still show as the character), so the
             # repr-like str() form is stored instead.
             "Bytes": str(token_bytes),
             }
        )

    # Explicit columns keep the table schema stable even when no token was produced.
    table_df = pd.DataFrame(table, columns=["TokenID", "Token", "Text", "Bytes"])
    logger.info(f"[Tokens {tokenizer_type}]: {table[:2]}")

    if update:
        return gr.update(value=pos_tokens, label=f"Tokens: {len(encoding)}"), table_df
    return pos_tokens, table_df


def tokenize_pair(text, tokenizer_type_1, tokenizer_type_2):
    """Run ``tokenize`` on the same text with two tokenizers.

    Returns a 4-tuple: the token output and table of the first tokenizer,
    followed by the token output and table of the second one.
    """
    outputs = []
    for tokenizer_type in (tokenizer_type_1, tokenizer_type_2):
        highlighted, frame = tokenize(text, tokenizer_type)
        outputs.append(highlighted)
        outputs.append(frame)
    return tuple(outputs)


def basic_count(tokenizer_type):
    """Return the vocabulary size and a "single/multi" summary string.

    The summary is built from the ``"中文单字"`` and ``"中文多字"`` entries
    of the ``"中文汉字数"`` section of the ``iter_vocab`` result
    (presumably single- vs multi-character Chinese token counts; confirm
    against ``iter_vocab``).
    """
    tokenizer = load_tokener(tokenizer_type)
    stats = iter_vocab(tokenizer, tokenizer_type)
    vocab_size = tokenizer.vocab_size
    zh_counts = stats["中文汉字数"]
    single_part = zh_counts["中文单字"]
    multi_part = zh_counts["中文多字"]
    return vocab_size, "{}/{}".format(single_part, multi_part)


def get_overlap_token_size(tokenizer_type_1, tokenizer_type_2):
    """Count the tokens that appear in the vocabularies of both tokenizers.

    The first ten shared tokens are logged. The count is returned twice
    (presumably to feed two output components; confirm against the caller).
    """
    first_tokenizer = load_tokener(tokenizer_type_1)
    second_tokenizer = load_tokener(tokenizer_type_2)
    first_vocab = first_tokenizer.get_vocab()
    second_vocab = second_tokenizer.get_vocab()

    # Key views support set intersection directly.
    shared = first_vocab.keys() & second_vocab.keys()
    shared_count = len(shared)

    sample = list(shared)[:10]
    logger.info(f"[OverlapTokens of {tokenizer_type_1} {tokenizer_type_2}]: {sample}")
    return (shared_count, shared_count)


def test_coding():
    """Print a three-byte ``bytes`` object to show how Python displays it."""
    utf8_sample = bytes((0xE4, 0xB8, 0xAD))
    print(utf8_sample)  # -> b'\xe4\xb8\xad'


if __name__ == "__main__":
    # Manual smoke run: show the vocabulary statistics for one tokenizer.
    summary = basic_count("internlm_chat_7b")
    print(summary)