nguyenvulebinh commited on
Commit
90a61c9
·
1 Parent(s): 9fee0d7

Upload envibert_tokenizer.py

Browse files
Files changed (1) hide show
  1. envibert_tokenizer.py +317 -0
envibert_tokenizer.py ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # !pip install sentencepiece==0.1.96 transformers==4.10.0
2
+ import sentencepiece as spm
3
+ import os
4
+ from transformers import PreTrainedTokenizer
5
+ from collections import Counter
6
+ from typing import List, Optional
7
+
8
+
9
+ class RobertaTokenizer(PreTrainedTokenizer):
10
+ def __init__(
11
+ self,
12
+ pretrained_file,
13
+ bos_token="<s>",
14
+ eos_token="</s>",
15
+ sep_token="</s>",
16
+ cls_token="<s>",
17
+ unk_token="<unk>",
18
+ pad_token="<pad>",
19
+ mask_token="<mask>",
20
+ **kwargs
21
+ ):
22
+ super().__init__(
23
+ bos_token=bos_token,
24
+ eos_token=eos_token,
25
+ unk_token=unk_token,
26
+ sep_token=sep_token,
27
+ cls_token=cls_token,
28
+ pad_token=pad_token,
29
+ mask_token=mask_token,
30
+ **kwargs,
31
+ )
32
+
33
+ # load bpe model and vocab file
34
+ sentencepiece_model = os.path.join(pretrained_file, 'sentencepiece.bpe.model')
35
+ vocab_file = os.path.join(pretrained_file, 'dict.txt')
36
+ self.sp_model = spm.SentencePieceProcessor()
37
+ self.sp_model.Load(
38
+ sentencepiece_model) # please dont use anything from sp_model bcz it makes everything goes wrong
39
+
40
+ self.bpe_dict = Dictionary().load(vocab_file)
41
+
42
+ # Mimic fairseq token-to-id alignment for the first 4 token
43
+ self.fairseq_tokens_to_ids = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}
44
+
45
+ # The first "real" token "," has position 4 in the original fairseq vocab and position 3 in the spm vocab
46
+ self.fairseq_offset = 0
47
+
48
+ self.fairseq_tokens_to_ids["<mask>"] = len(self.bpe_dict) + self.fairseq_offset
49
+ self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
50
+
51
+ def _tokenize(self, text):
52
+ return self.sp_model.EncodeAsPieces(text)
53
+
54
+ def _convert_token_to_id(self, token):
55
+ """ Converts a token (str) in an id using the vocab. """
56
+ if token in self.fairseq_tokens_to_ids:
57
+ return self.fairseq_tokens_to_ids[token]
58
+ spm_id = self.bpe_dict.index(token)
59
+ return spm_id
60
+
61
+ def _convert_id_to_token(self, index):
62
+ """Converts an index (integer) in a token (str) using the vocab."""
63
+ if index in self.fairseq_ids_to_tokens:
64
+ return self.fairseq_ids_to_tokens[index]
65
+ return self.bpe_dict[index]
66
+
67
+ def build_inputs_with_special_tokens(
68
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
69
+ ) -> List[int]:
70
+ """
71
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
72
+ adding special tokens.
73
+
74
+ This implementation does not add special tokens and this method should be overridden in a subclass.
75
+
76
+ Args:
77
+ token_ids_0 (:obj:`List[int]`): The first tokenized sequence.
78
+ token_ids_1 (:obj:`List[int]`, `optional`): The second tokenized sequence.
79
+
80
+ Returns:
81
+ :obj:`List[int]`: The model input with special tokens.
82
+ """
83
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
84
+
85
+ def create_token_type_ids_from_sequences(
86
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
87
+ ) -> List[int]:
88
+ """
89
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. XLM-RoBERTa does
90
+ not make use of token type ids, therefore a list of zeros is returned.
91
+
92
+ Args:
93
+ token_ids_0 (:obj:`List[int]`):
94
+ List of IDs.
95
+ token_ids_1 (:obj:`List[int]`, `optional`):
96
+ Optional second list of IDs for sequence pairs.
97
+
98
+ Returns:
99
+ :obj:`List[int]`: List of zeros.
100
+
101
+ """
102
+
103
+ sep = [self.sep_token_id]
104
+ cls = [self.cls_token_id]
105
+
106
+ return len(cls + token_ids_0 + sep) * [0]
107
+
108
+ @property
109
+ def vocab_size(self):
110
+ return len(self.bpe_dict) + self.fairseq_offset + 1 # Add the <mask> token
111
+
112
+ def get_vocab(self):
113
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
114
+ vocab.update(self.added_tokens_encoder)
115
+ return vocab
116
+
117
+
118
+ class Dictionary(object):
119
+ """A mapping from symbols to consecutive integers"""
120
+
121
+ def __init__(
122
+ self,
123
+ pad='<pad>',
124
+ eos='</s>',
125
+ unk='<unk>',
126
+ bos='<s>',
127
+ extra_special_symbols=None,
128
+ ):
129
+ self.unk_word, self.pad_word, self.eos_word = unk, pad, eos
130
+ self.symbols = []
131
+ self.count = []
132
+ self.indices = {}
133
+ self.bos_index = self.add_symbol(bos)
134
+ self.pad_index = self.add_symbol(pad)
135
+ self.eos_index = self.add_symbol(eos)
136
+ self.unk_index = self.add_symbol(unk)
137
+ if extra_special_symbols:
138
+ for s in extra_special_symbols:
139
+ self.add_symbol(s)
140
+ self.nspecial = len(self.symbols)
141
+
142
+ def __eq__(self, other):
143
+ return self.indices == other.indices
144
+
145
+ def __getitem__(self, idx):
146
+ if idx < len(self.symbols):
147
+ return self.symbols[idx]
148
+ return self.unk_word
149
+
150
+ def __len__(self):
151
+ """Returns the number of symbols in the dictionary"""
152
+ return len(self.symbols)
153
+
154
+ def __contains__(self, sym):
155
+ return sym in self.indices
156
+
157
+ def index(self, sym):
158
+ """Returns the index of the specified symbol"""
159
+ assert isinstance(sym, str)
160
+ if sym in self.indices:
161
+ return self.indices[sym]
162
+ return self.unk_index
163
+
164
+ def unk_string(self, escape=False):
165
+ """Return unknown string, optionally escaped as: <<unk>>"""
166
+ if escape:
167
+ return '<{}>'.format(self.unk_word)
168
+ else:
169
+ return self.unk_word
170
+
171
+ def add_symbol(self, word, n=1):
172
+ """Adds a word to the dictionary"""
173
+ if word in self.indices:
174
+ idx = self.indices[word]
175
+ self.count[idx] = self.count[idx] + n
176
+ return idx
177
+ else:
178
+ idx = len(self.symbols)
179
+ self.indices[word] = idx
180
+ self.symbols.append(word)
181
+ self.count.append(n)
182
+ return idx
183
+
184
+ def update(self, new_dict):
185
+ """Updates counts from new dictionary."""
186
+ for word in new_dict.symbols:
187
+ idx2 = new_dict.indices[word]
188
+ if word in self.indices:
189
+ idx = self.indices[word]
190
+ self.count[idx] = self.count[idx] + new_dict.count[idx2]
191
+ else:
192
+ idx = len(self.symbols)
193
+ self.indices[word] = idx
194
+ self.symbols.append(word)
195
+ self.count.append(new_dict.count[idx2])
196
+
197
+ def finalize(self, threshold=-1, nwords=-1, padding_factor=8):
198
+ """Sort symbols by frequency in descending order, ignoring special ones.
199
+
200
+ Args:
201
+ - threshold defines the minimum word count
202
+ - nwords defines the total number of words in the final dictionary,
203
+ including special symbols
204
+ - padding_factor can be used to pad the dictionary size to be a
205
+ multiple of 8, which is important on some hardware (e.g., Nvidia
206
+ Tensor Cores).
207
+ """
208
+ if nwords <= 0:
209
+ nwords = len(self)
210
+
211
+ new_indices = dict(zip(self.symbols[:self.nspecial], range(self.nspecial)))
212
+ new_symbols = self.symbols[:self.nspecial]
213
+ new_count = self.count[:self.nspecial]
214
+
215
+ c = Counter(dict(sorted(zip(self.symbols[self.nspecial:], self.count[self.nspecial:]))))
216
+ for symbol, count in c.most_common(nwords - self.nspecial):
217
+ if count >= threshold:
218
+ new_indices[symbol] = len(new_symbols)
219
+ new_symbols.append(symbol)
220
+ new_count.append(count)
221
+ else:
222
+ break
223
+
224
+ threshold_nwords = len(new_symbols)
225
+ if padding_factor > 1:
226
+ i = 0
227
+ while threshold_nwords % padding_factor != 0:
228
+ symbol = 'madeupword{:04d}'.format(i)
229
+ new_indices[symbol] = len(new_symbols)
230
+ new_symbols.append(symbol)
231
+ new_count.append(0)
232
+ i += 1
233
+ threshold_nwords += 1
234
+
235
+ assert len(new_symbols) % padding_factor == 0
236
+ assert len(new_symbols) == len(new_indices)
237
+
238
+ self.count = list(new_count)
239
+ self.symbols = list(new_symbols)
240
+ self.indices = new_indices
241
+
242
+ def bos(self):
243
+ """Helper to get index of beginning-of-sentence symbol"""
244
+ return self.bos_index
245
+
246
+ def pad(self):
247
+ """Helper to get index of pad symbol"""
248
+ return self.pad_index
249
+
250
+ def eos(self):
251
+ """Helper to get index of end-of-sentence symbol"""
252
+ return self.eos_index
253
+
254
+ def unk(self):
255
+ """Helper to get index of unk symbol"""
256
+ return self.unk_index
257
+
258
+ @classmethod
259
+ def load(cls, f):
260
+ """Loads the dictionary from a text file with the format:
261
+
262
+ ```
263
+ <symbol0> <count0>
264
+ <symbol1> <count1>
265
+ ...
266
+ ```
267
+ """
268
+ d = cls()
269
+ d.add_from_file(f)
270
+ return d
271
+
272
+ def add_from_file(self, f):
273
+ """
274
+ Loads a pre-existing dictionary from a text file and adds its symbols
275
+ to this instance.
276
+ """
277
+ if isinstance(f, str):
278
+ try:
279
+ with open(f, 'r', encoding='utf-8') as fd:
280
+ self.add_from_file(fd)
281
+ except FileNotFoundError as fnfe:
282
+ raise fnfe
283
+ except UnicodeError:
284
+ raise Exception("Incorrect encoding detected in {}, please "
285
+ "rebuild the dataset".format(f))
286
+ return
287
+
288
+ lines = f.readlines()
289
+ indices_start_line = self._load_meta(lines)
290
+ for line in lines[indices_start_line:]:
291
+ idx = line.rfind(' ')
292
+ if idx == -1:
293
+ raise ValueError("Incorrect dictionary format, expected '<token> <cnt>'")
294
+ word = line[:idx]
295
+ count = int(line[idx + 1:])
296
+ self.indices[word] = len(self.symbols)
297
+ self.symbols.append(word)
298
+ self.count.append(count)
299
+
300
+ def _save(self, f, kv_iterator):
301
+ if isinstance(f, str):
302
+ os.makedirs(os.path.dirname(f), exist_ok=True)
303
+ with open(f, 'w', encoding='utf-8') as fd:
304
+ return self.save(fd)
305
+ for k, v in kv_iterator:
306
+ print('{} {}'.format(k, v), file=f)
307
+
308
+ def _get_meta(self):
309
+ return [], []
310
+
311
+ def _load_meta(self, lines):
312
+ return 0
313
+
314
+ def save(self, f):
315
+ """Stores dictionary into a text file"""
316
+ ex_keys, ex_vals = self._get_meta()
317
+ self._save(f, zip(ex_keys + self.symbols[self.nspecial:], ex_vals + self.count[self.nspecial:]))