File size: 6,035 Bytes
5fded96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c30dc0f
5fded96
c30dc0f
5fded96
 
 
 
c30dc0f
 
 
 
 
 
 
 
5fded96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import regex as re
import base64
import tiktoken
import os
import json
from transformers import PreTrainedTokenizer

class BaseTokenizer(PreTrainedTokenizer):
    """Abstract base class for the tokenizers defined in this module.

    Subclasses are expected to override ``vocab_size``, ``tokenize``,
    ``detokenize``, ``eos_id`` and ``get_command``; the base versions raise
    ``NotImplementedError``.  The chat helpers (``build_single_message`` and
    ``build_chat_input``) additionally read ``self.tokenizer``, which this
    class never sets, so the subclass must provide it.
    """

    def __init__(self, **kwargs):
        # NOTE(review): ``kwargs`` are accepted but not forwarded to the parent
        # constructor -- confirm whether that is intentional.
        super().__init__()

    @property
    def add_prefix_space(self):
        """Always ``False``."""
        return False

    @property
    def vocab_size(self):
        """Vocabulary size; must be provided by the subclass.

        Raises:
            NotImplementedError: always, in the base class.
        """
        # ``NotImplemented`` is a comparison sentinel, not an exception;
        # raising it produces a confusing TypeError.
        raise NotImplementedError

    def tokenize(self, text):
        """Split ``text`` into tokens; must be provided by the subclass.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def detokenize(self, token_ids, ignore_special_tokens=True):
        """Turn ``token_ids`` back into text; must be provided by the subclass.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def build_single_message(self, role, metadata, message):
        """Encode one chat message as a flat list of token ids.

        The result is the command id for ``<|role|>``, followed by the
        encoding of ``metadata`` plus a newline, followed by the encoding of
        ``message``.

        Args:
            role: One of "system", "user", "assistant" or "observation".
            metadata: Text placed on the role line, before the newline.
            message: The message body.

        Returns:
            ``role_tokens + message_tokens`` as a single list.

        Raises:
            AssertionError: if ``role`` is not one of the accepted roles.
        """
        assert role in ["system", "user", "assistant", "observation"], role
        role_tokens = [self.get_command(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n")
        # NOTE(review): ``disallowed_special=()`` is passed only for the body,
        # not for the metadata line -- presumably so special-token text inside
        # a message is encoded as ordinary text; confirm against the encoder.
        message_tokens = self.tokenizer.encode(message, disallowed_special=())
        tokens = role_tokens + message_tokens
        return tokens

    def build_chat_input(self, query, history=None, role="user", metadata=""):
        """Build model input for a chat turn from ``history`` plus a new query.

        Every history item is encoded with ``build_single_message``; a system
        item that carries a "tools" key gets the JSON dump of those tools
        appended to its content on a new line.  The new ``query`` is encoded
        the same way and the sequence is terminated with the ``<|assistant|>``
        command id so the model continues as the assistant.

        Args:
            query: Text of the new message.
            history: Optional list of dicts with "role" and "content" keys and
                optional "metadata" / "tools" keys.  ``None`` means no history.
            role: Role of the new message.
            metadata: Metadata of the new message.

        Returns:
            The result of ``self.batch_encode_plus`` on the single id sequence,
            requested with ``return_tensors="pt"``.
        """
        if history is None:
            history = []
        input_ids = []
        for item in history:
            content = item["content"]
            if item["role"] == "system" and "tools" in item:
                content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False)
            input_ids.extend(self.build_single_message(item["role"], item.get("metadata", ""), content))
        input_ids.extend(self.build_single_message(role, metadata, query))
        input_ids.extend([self.get_command("<|assistant|>")])
        return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True)

    @property
    def eos_id(self):
        """End-of-sequence id; must be provided by the subclass.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def get_command(self, token):
        """Return the id of the command token ``token``; subclass hook.

        Raises:
            NotImplementedError: always, in the base class.  (Returning the
                ``NotImplemented`` sentinel here would silently end up inside
                the id lists built by ``build_single_message``.)
        """
        raise NotImplementedError

class TikTokenizer(BaseTokenizer):
    """Tokenizer backed by a ``tiktoken.Encoding`` built from a rank file.

    The vocabulary file holds one ``<base64 token> <rank>`` pair per line.
    Ten special tokens are appended after the ordinary vocabulary, with ids
    starting at the number of ordinary tokens.

    Ordinary tokens are exposed to callers as the string form of their
    ``bytes`` value (e.g. ``"b'hello'"``); ``convert_token_to_id`` and
    ``convert_id_to_token`` both use that representation.
    """

    vocab_files_names = {"vocab_file": "tokenizer.tiktoken"}

    def __init__(self, vocab_file, **kwargs):
        """Load ``vocab_file`` and build the underlying tiktoken encoding.

        Args:
            vocab_file: Path to the rank file described in the class docstring.
            **kwargs: Accepted but not used by this constructor.
        """
        pat_str = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
        self.pat_str = re.compile(pat_str)

        self.b64_vocab = {}
        mergeable_ranks = {}
        # The file content is base64 text and decimal digits, i.e. pure ASCII.
        with open(vocab_file, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue
                token, rank = line.split()
                rank = int(rank)
                token = base64.b64decode(token)
                mergeable_ranks[token] = rank
                # Keyed by the string form of the bytes object ("b'...'").
                self.b64_vocab['%s' % token] = rank

        special_token_names = ["<|endoftext|>", "[MASK]", "[gMASK]", "[sMASK]", "<sop>", "<eop>", "<|system|>",
                               "<|user|>", "<|assistant|>", "<|observation|>"]
        # Special ids continue right after the ordinary vocabulary.
        self.special_tokens = {
            token: idx for idx, token in enumerate(special_token_names, start=len(mergeable_ranks))
        }
        self.special_token_ids = {idx: token for token, idx in self.special_tokens.items()}

        self.tokenizer = tiktoken.Encoding(
            name="my_tokenizer",
            pat_str=pat_str,
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens
        )
        self.decoder = {rank: token for token, rank in mergeable_ranks.items()}
        self.n_words = len(self.decoder) + len(self.special_tokens)
        # Called last, after every attribute above has been set.
        super().__init__()

    @property
    def add_prefix_space(self):
        """Always ``False``."""
        return False

    def tokenize(self, text, add_special_tokens=True):
        """Encode ``text`` and return the token for each id.

        Args:
            text: Text to tokenize.
            add_special_tokens: Forwarded to ``encode``.

        Returns:
            List of tokens as produced by ``convert_id_to_token``.
        """
        ids = self.encode(text, add_special_tokens=add_special_tokens)
        return [self.convert_id_to_token(_id) for _id in ids]

    def detokenize(self, ids, ignore_special_tokens=True):
        """Decode ``ids`` to text.

        Args:
            ids: Iterable of token ids.
            ignore_special_tokens: When true, ids of special tokens are
                dropped before decoding.

        Returns:
            The decoded string.
        """
        if ignore_special_tokens:
            ids = [idx for idx in ids if idx not in self.special_token_ids]
        return self.tokenizer.decode(ids)

    def encode(self, text, add_special_tokens=True):
        """Encode ``text`` to token ids.

        All special tokens are allowed in the input (``allowed_special="all"``).

        Args:
            text: Text to encode.
            add_special_tokens: When true, the ids of ``[gMASK]`` and ``<sop>``
                are prepended to the result.

        Returns:
            List of token ids.
        """
        ids = self.tokenizer.encode(text, disallowed_special=(), allowed_special="all")
        if add_special_tokens:
            ids = [self.special_tokens["[gMASK]"], self.special_tokens["<sop>"]] + ids
        return ids

    def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=False):
        """Decode a single id or a sequence of ids to text.

        Args:
            ids: An int, a sequence of ints, or an array/tensor-like object
                exposing ``tolist()``.
            skip_special_tokens: When true, special-token ids are dropped.
            clean_up_tokenization_spaces: Accepted for signature
                compatibility; not used.

        Returns:
            The decoded string.
        """
        # Unwrap tensors / arrays into plain Python ints first.
        if hasattr(ids, "tolist"):
            ids = ids.tolist()
        if isinstance(ids, int):
            ids = [ids]
        return self.detokenize(ids, ignore_special_tokens=skip_special_tokens)

    def encode_pieces(self, text):
        """Encode ``text`` and return the text of each resulting token.

        Special-token text in ``text`` is not treated specially here
        (no ``allowed_special``), so every id is looked up in ``self.decoder``.

        Returns:
            List of strings, one per token; token bytes that are not valid
            UTF-8 are decoded with ``errors='replace'``.
        """
        ids = self.tokenizer.encode(text, disallowed_special=())
        # Vocabulary entries are ``bytes``; ``bytes.decode`` turns them into text.
        return [self.decoder[token_id].decode('utf-8', errors='replace') for token_id in ids]

    @property
    def vocab_size(self):
        """Number of ordinary tokens plus number of special tokens."""
        return self.n_words

    @property
    def eos_token_id(self):
        """Id of the ``<|endoftext|>`` special token."""
        return self.special_tokens["<|endoftext|>"]

    def convert_token_to_id(self, token):
        """Convert a token (str) to its id using the vocab.

        Special tokens are looked up first, then the ordinary vocabulary,
        which is keyed by the string form of the token bytes.

        Raises:
            RuntimeError: if ``token`` is in neither table.
        """
        if token in self.special_tokens:
            return self.special_tokens[token]
        if token in self.b64_vocab:
            return self.b64_vocab[token]
        raise RuntimeError(f"{token} is not a single token")

    def _convert_token_to_id(self, token):
        """Alias of ``convert_token_to_id``."""
        return self.convert_token_to_id(token)

    def convert_id_to_token(self, index):
        """Convert an id to its token.

        Special ids map to their token name; any other id maps to the string
        form of its ``bytes`` entry in ``self.decoder`` (``"b'...'"``), the
        same representation ``convert_token_to_id`` accepts.

        Raises:
            KeyError: if ``index`` is in neither table.
        """
        if index in self.special_token_ids:
            return self.special_token_ids[index]
        return '%s' % self.decoder[index]

    def _convert_id_to_token(self, index):
        """Alias of ``convert_id_to_token``."""
        return self.convert_id_to_token(index)

    def get_command(self, token):
        """Return the id of the special token ``token``.

        Raises:
            KeyError: if ``token`` is not a special token.
        """
        return self.special_tokens[token]

    def get_vocab(self):
        """Return a ``{token: id}`` dict for every id below ``vocab_size``."""
        vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
        return vocab