mzltest committed
Commit
94191ef
1 Parent(s): 194d28d

Create tokenizations/bpe_tokenizer.py

Files changed (1)
  1. tokenizations/bpe_tokenizer.py +141 -0
tokenizations/bpe_tokenizer.py ADDED
@@ -0,0 +1,141 @@
"""
Adapted from https://github.com/openai/gpt-2/, modified for Chinese.
"""
import json
import os
import sentencepiece as spm
"""
SentencePiece is an unsupervised text tokenizer and detokenizer, mainly for neural-network-based
text generation systems where the vocabulary size is predetermined prior to the neural model
training. SentencePiece implements subword units (e.g., byte-pair encoding (BPE) [Sennrich et al.]
and the unigram language model [Kudo]) with the extension of direct training from raw sentences.
SentencePiece allows us to build a purely end-to-end system that does not depend on
language-specific pre/postprocessing.
https://github.com/google/sentencepiece

Install with:
    pip install sentencepiece

or build from source:
    git clone https://github.com/google/sentencepiece.git
    python setup.py install
"""

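# A SentencePiece model for the Encoder_SP class below can be trained directly
# from a raw corpus. A minimal sketch; corpus.txt, the "chinese" prefix, and the
# vocab size are hypothetical values, not fixed by this file:
#
#   spm.SentencePieceTrainer.Train(
#       '--input=corpus.txt --model_prefix=chinese --vocab_size=8000 --model_type=bpe'
#   )
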
def get_pairs(word):
    """Return the set of adjacent symbol pairs in a word, where `word` is a tuple
    of symbols, e.g. get_pairs(('a', 'b', 'c')) == {('a', 'b'), ('b', 'c')}."""
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


class Encoder:
    def __init__(self, encoder, bpe_merges):
        self.encoder = encoder  # token -> id
        self.decoder = {v: k for k, v in self.encoder.items()}  # id -> token
        # Earlier merges get lower ranks, i.e. higher merge priority.
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}
        self.max_len = 0

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)
        if not pairs:
            return token

        while True:
            # Merge the best-ranked pair; pairs not in the merge table rank as infinity.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word
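
    # Example with hypothetical merges [('a', 'b'), ('ab', 'c')]: bpe('abc')
    # rewrites ('a', 'b', 'c') -> ('ab', 'c') -> ('abc',) and returns 'abc'.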

    def encode(self, text):
        # Unknown tokens fall back to id 1.
        return [self.encoder.get(token, 1) for token in self.tokenize(text)]

    def decode(self, tokens):
        return ''.join([self.decoder[token] for token in tokens])

    def tokenize(self, text):
        return self.bpe(text).split(' ')

    def convert_tokens_to_ids(self, tokens):
        return [self.encoder.get(token, 1) for token in tokens]


class Encoder_SP:
    def __init__(self, model_path):
        self.sp = spm.SentencePieceProcessor()
        self.sp.Load(model_path)

    def encode(self, text):
        """Encode a raw string into a list of token ids."""
        return self.sp.EncodeAsIds(text)

    def decode(self, tokens):
        """Decode a list of token ids, e.g. [x1, x2, ...], back into text."""
        text = [int(token) for token in tokens]
        return self.sp.DecodeIds(text)

    def tokenize(self, text):
        return self.sp.EncodeAsPieces(text)

    def convert_tokens_to_ids(self, tokens):
        return [self.sp.PieceToId(token) for token in tokens]


def get_encoder(encoder_file, bpe_file):
    # The split below lets the same entry point also accept a sentencepiece model.
    filepath, filename = os.path.split(encoder_file)
    shortname, extension = os.path.splitext(filename)

    if extension == ".model" and bpe_file == "":
        return Encoder_SP(encoder_file)
    else:
        with open(encoder_file, 'r', encoding="utf-8") as f:
            encoder = json.load(f)
        with open(bpe_file, 'r', encoding="utf-8") as f:
            bpe_data = f.read()
        # Skip the version header (first line) and the trailing empty line.
        bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
        return Encoder(
            encoder=encoder,
            bpe_merges=bpe_merges,
        )
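
A minimal usage sketch of both dispatch paths in get_encoder; the file names
vocab.json, merges.txt, and chinese.model are hypothetical placeholders, not
part of this commit:

    from tokenizations.bpe_tokenizer import get_encoder

    # GPT-2-style BPE: a JSON vocabulary plus a merges file.
    enc = get_encoder("vocab.json", "merges.txt")
    ids = enc.encode("some text")
    print(enc.decode(ids))

    # SentencePiece: a .model file with an empty bpe_file selects Encoder_SP.
    sp_enc = get_encoder("chinese.model", "")
    print(sp_enc.tokenize("你好,世界"))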