"""
TODO: 繁体、简体、语种、
"""
import os
import json
from collections import Counter
from utils.text_util import is_chinese, get_zh_count, get_digit_count
from zhon.hanzi import punctuation as zh_punc

CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))

with open(os.path.join(CURRENT_DIR, "vocab.jd.txt.v2"), "r", encoding="utf-8") as f:
    zh_tokens = [line.strip() for line in f if is_chinese(line.strip())]


def zh_iterator():
    # Yield every character in the CJK Unified Ideographs range (U+4E00..U+9FA5 inclusive).
    for idx in range(ord(u'\u4e00'), ord(u'\u9fa5') + 1):
        yield chr(idx)


def get_coding_length(tokenizer, vocab, filter=None):
    """
    Compute the distribution and mean of encoding lengths
    (some Chinese characters are encoded into more than one token).
    """
    all_length = []
    for word in vocab:
        if len(word) > 1:
            continue
        if filter is not None and filter(word):
            continue
        tokens = tokenizer.encode(word)
        all_length.append(len(tokens))
        # Debugging hook: print characters that need unusually many tokens,
        # e.g. `if len(tokens) > 3: print(word, tokens)`

    dist_length = Counter(all_length)
    mean_length = round(sum(all_length) / len(all_length), 2)
    return dist_length, mean_length
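

# A minimal usage sketch of get_coding_length (an assumption, not part of the original
# pipeline): any Hugging Face tokenizer exposing `encode()` works; "bert-base-chinese"
# is only a placeholder model name. Note that tokenizers which add special tokens
# (e.g. [CLS]/[SEP]) will inflate the reported lengths.
def _example_coding_length():
    from transformers import AutoTokenizer  # assumed dependency for this sketch only
    hf_tok = AutoTokenizer.from_pretrained("bert-base-chinese")
    dist, mean = get_coding_length(hf_tok, zh_iterator())
    print(dist, mean)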


def has_zh_punc(text):
    """
    Return True if the text contains any Chinese punctuation.
    """
    return any(ch in zh_punc for ch in text)



def get_space_count(text):
    """
    Count whitespace characters in the text.
    """
    space_count = 0
    for char in text:
        if len(char.strip()) == 0:
            space_count += 1
    return space_count


def remove_special_char():
    """
    Strip tokenizer-specific markers from decoded tokens (not implemented yet).
    """
    # BERT-style vocabularies contain subword tokens prefixed with "##".
    # Byte-level BPE vocabularies contain tokens carrying space markers.
    # decode_str = decode_str.strip().replace("#", "")  # TODO: handle per tokenizer type
    pass
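

# Hedged sketch of what the TODO above might become. The marker conventions
# ("##" for WordPiece, "Ġ" for byte-level BPE, "▁" for SentencePiece) are standard,
# but the helper name and the `tokenizer_type` argument are assumptions and are not
# used anywhere else in this module.
def _strip_token_markers(token, tokenizer_type):
    if tokenizer_type == "wordpiece":       # e.g. BERT
        return token[2:] if token.startswith("##") else token
    if tokenizer_type == "byte_bpe":        # e.g. GPT-2 style
        return token.lstrip("Ġ")
    if tokenizer_type == "sentencepiece":   # e.g. LLaMA style
        return token.lstrip("▁")
    return token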


# Cache of per-tokenizer statistics, keyed by tokenizer name.
cache = {}


def iter_vocab(tokenizer, name="", from_cache=True):
    if from_cache and name in cache:
        return cache[name]

    f_out = open(name + "_vocab.jsonl", "w", encoding="utf-8")
    # "中文单字": single-character Chinese tokens; "中文多字": multi-character Chinese tokens.
    zh_token_count = {"total": 0, "中文单字": 0, "中文多字": 0}

    all_single_zh_tokens = set()
    zh_symbol_count = 0  # Chinese punctuation tokens (currently never incremented)
    for token_id in range(tokenizer.vocab_size):
        decode_str = tokenizer.decode([token_id], skip_special_tokens=False)
        token = tokenizer.convert_ids_to_tokens([token_id], skip_special_tokens=False)[0]
        # tokenizer.convert_tokens_to_string(tokens)

        if token is None:  # some vocabularies have unused ids (ids are not contiguous)
            continue
        if isinstance(token, bytes):
            token = token.decode("utf-8", errors="ignore")

        digit_count = get_digit_count(decode_str)
        zh_count = get_zh_count(decode_str)
        space_count = get_space_count(decode_str)

        f_out.write(json.dumps(
            {"id": token_id,
             "token": token,
             "token_decode": decode_str,
             "token_len": len(decode_str),
             "zh_count": zh_count,
             "space_count": space_count,
             "digit_count": digit_count,
             "zh_symbol_count": zh_symbol_count,
             },
            ensure_ascii=False) + "\n"
                    )

        if zh_count >= 1:
            zh_token_count["total"] += 1
            if zh_count > 1:
                zh_token_count["中文多字"] += 1
            else:
                zh_token_count["中文单字"] += 1
                all_single_zh_tokens.add(decode_str.strip().replace("#", ""))

    f_out.close()

    dist_length, mean_length = get_coding_length(tokenizer, zh_tokens, filter=lambda k: not is_chinese(k))

    # TODO: distinguish traditional vs. simplified characters
    zh_token_count["中文单字-去重后"] = len(all_single_zh_tokens)  # deduplicated single-character tokens

    result = {
        "name": name,
        "impl": str(tokenizer.__class__),
        "vocab_size": tokenizer.vocab_size,
        "中文汉字数": zh_token_count,  # Chinese-character token counts
        "中文标点数": zh_symbol_count,  # Chinese punctuation token count
        "中文汉字编码长度均值": mean_length,  # mean tokens per Chinese character
        "中文汉字编码长度分布": json.dumps(dist_length),  # distribution of tokens per character
    }
    cache[name] = result
    return result
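

# Hedged comparison sketch (an assumption; not called anywhere). It reuses the project's
# `vocab.*` wrapper modules that already appear in __main__ below to compare the mean
# Chinese-character encoding length of two tokenizers.
def _example_compare_tokenizers():
    from vocab.chatglm2_6b import tokenizer as t1
    from vocab.baichuan2 import tokenizer as t2
    for tok, tok_name in [(t1, "chatglm2_6b"), (t2, "baichuan2")]:
        stats = iter_vocab(tok, name=tok_name)
        print(tok_name, stats["vocab_size"], stats["中文汉字编码长度均值"])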


if __name__ == "__main__":
    # test_coding_length(jd_vocab_tokens, filter=lambda k: not is_chinese(k))
    # test_coding_length(zh_punc)
    # test_coding_length(zh_iterator())

    from vocab.chatglm2_6b import tokenizer; name = "chatglm2_6b"
    # from vocab.chatglm_6b import tokenizer; name="chatglm_6b"
    # from vocab.baichuan2 import tokenizer;  name="baichuan2"
    # from vocab.gpt_4 import tokenizer; name="gpt4"

    print(iter_vocab(tokenizer, name=name))