Commit · 66a3123
1 Parent(s): 23bd3af
first commit

Files changed:
- NeuralTextGenerator.py +484 -0
- app.py +7 -0
- requirements +4 -0
- textprocessing.py +90 -0
- utils.py +38 -0
NeuralTextGenerator.py
ADDED
@@ -0,0 +1,484 @@
import math
import time
import torch
import numpy as np
from transformers import AutoModelForMaskedLM, AutoTokenizer
from torch.utils.data import TensorDataset, DataLoader, RandomSampler
from transformers import AdamW, get_linear_schedule_with_warmup
import torch.nn.functional as F
from textprocessing import *
from utils import *

try:
    from apex import amp

    APEX_AVAILABLE = True
except ModuleNotFoundError:
    APEX_AVAILABLE = False

DEFAULT_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


class BertTextGenerator:
    def __init__(self, model_version, device=DEFAULT_DEVICE, use_apex=APEX_AVAILABLE, use_fast=True,
                 do_basic_tokenize=True):
        """
        Wrapper around a BERT model loaded through Hugging Face's AutoModelForMaskedLM.
        This class implements methods to generate text with the BERT model.

        Parameters
        ----------
        model_version : str
            The name of the BERT model to initialize from AutoModelForMaskedLM
        device : str
            Type of pytorch device to adopt. By default it is set to DEFAULT_DEVICE,
            that is 'cuda' if cuda is available, otherwise 'cpu'
        use_apex : boolean
            Flag to adopt nvidia apex
        """
        self.device = device
        self.model_version = model_version
        self.model = AutoModelForMaskedLM.from_pretrained(model_version, output_attentions=True)
        self.model.to(self.device)
        self.use_apex = use_apex

        # Move to finetune
        if use_apex:
            optimizer = torch.optim.SGD(self.model.parameters(), lr=1e-3)
            self.model, optimizer = amp.initialize(self.model, optimizer, opt_level="O2", keep_batchnorm_fp32=True,
                                                   loss_scale="dynamic")

        self.tokenizer = AutoTokenizer.from_pretrained(model_version, do_lower_case="uncased" in model_version,
                                                       use_fast=use_fast,
                                                       do_basic_tokenize=do_basic_tokenize)  # added to avoid splitting of unused tokens
        self.num_attention_masks = len(self.model.base_model.base_model.encoder.layer)
        self.has_format_tokenizer = False

    def generate(self, save_to_path=None, n_sentences=100, seed_text="", batch_size=10, max_iter=500, verbose=False,
                 print_every=50, max_len=40, min_len=4, avg_len=20, std_len=4, init_mask_prob=1,
                 generation_method="parallel", masked_portion=1, temperature=1.0, sample=True, top_k=100, burnin=None):
        '''
        Principal method of the class, used to generate sentences. The methodology used to generate a batch of
        sentences can be decomposed into 3 main steps:
        1) Initialization: each batch is initialized as a matrix of tokens where each row represents a sentence
        2) Selection: at each iteration, one or more tokens are selected and masked in each sentence
        3) Sampling: at each iteration BERT computes the logits of the masked tokens, which are then used to sample
           the new tokens that replace the masked ones

        Parameters
        ==============================
        (General)
        ------------------------------
        save_to_path: str, default = None
            path of the txt file where the generated sentences are stored
        n_sentences: int, default = 100
            total number of sentences to generate
        seed_text: str, default = ""
            initial text used to generate the sentences
        batch_size: int, default = 10
            number of sentences in each batch
        max_iter: int, default = 500
            number of iterations
        verbose: boolean, default = False
        print_every: int, default = 50
            print a sample from the batch every print_every iterations. Used only if verbose is True
        (Length of the sentences)
        ------------------------------
        The method can generate sentences of different lengths. For each batch, the length of its sentences
        is sampled from a normal distribution N(avg_len, std_len) and then rounded to the closest int.
        max_len and min_len are used to clip the length.
        max_len: int, default = 40
            maximum length of each sentence
        min_len: int, default = 4
            minimum length of each sentence
        avg_len: float or int, default = 20
            average length of the sentences
        std_len: float or int, default = 4
            standard deviation of the sentence length
        (Initialization)
        ------------------------------
        Each batch is initialized as a matrix of tokens of dimension (batch_size x batch_len + 2), where batch_len is
        selected as described above. A cls_token is added at the beginning of each sentence and a sep_token at the end.
        Every other token is selected based on the value of init_mask_prob:
        - if init_mask_prob == 1 -> each token is [MASK] with probability 1 (the whole batch is [MASK]s)
        - if init_mask_prob == 0 -> each token is a random token from the tokenizer vocabulary (the batch is
          initialized as random sentences)
        - if init_mask_prob in (0, 1) -> each token is sampled as [MASK] with probability init_mask_prob or, with
          probability (1 - init_mask_prob), as any other token in the tokenizer vocabulary
        init_mask_prob: float in [0,1], default = 1
            probability of the mask token
        (Selection)
        ------------------------------
        generation_method: str, default = "parallel"
            method used to select the tokens to replace at each iteration
            - 'parallel': for each sentence one token, or a percentage of tokens, is selected randomly based on the
              value of masked_portion
            - 'sequential': the tokens are selected sequentially. At iteration i the token in position i % batch_len
              is selected
            - 'attention': at the first iteration one token is selected randomly for each sentence. In later
              iterations the token of each sentence is selected with a probability distribution based on the
              attention mask of the token sampled in the previous iteration
        masked_portion: int or float in [0, 1], default = 1
            percentage of tokens to mask in each sentence. Used only if generation_method is 'parallel'
        (Sampling)
        ------------------------------
        temperature: float, default = 1
            temperature for the logits (logits <- logits / temperature)
        sample: boolean, default = True
            when sample is True each masked token is replaced by sampling randomly according to the corresponding logits
        top_k: int or None, default = 100
            when top_k > 0 each masked token is replaced by sampling randomly according to the logits of the top_k
            tokens only. If set to None, all the tokens are considered
        burnin: int, default = None
            after burnin iterations the tokens are chosen deterministically, selecting the one with the maximum
            logit score

        Returns
        -------
        list
            a list of sentences (str) already detokenized and cleaned
        '''

        n_batches = math.ceil(n_sentences / batch_size)

        if burnin is None:
            burnin = max_iter

        sentences = []
        start_time = time.time()

        for batch_n in range(n_batches):
            batch_sentence_len = np.round(np.random.normal(avg_len, std_len))
            batch_sentence_len = int(np.clip(batch_sentence_len, min_len, max_len))

            # Generate and append batch of sentences
            sentences += self.generate_batch(seed_text, batch_size, max_iter, verbose=verbose, print_every=print_every,
                                             sent_len=batch_sentence_len, init_mask_prob=init_mask_prob,
                                             generation_method=generation_method,
                                             masked_portion=masked_portion, temperature=temperature, sample=sample,
                                             top_k=top_k, burnin=burnin)

            # Print if verbose
            if verbose and (batch_n + 1) % print_every == 0:
                print("Finished batch %d in %.3fs" % (batch_n + 1, time.time() - start_time))
                start_time = time.time()

        # Store results
        if save_to_path is not None:
            with open(save_to_path, 'w') as f:
                for sent in sentences:
                    f.write(sent + '\n')

        return sentences

    def generate_batch(self, seed_text, batch_size, max_iter, verbose, print_every, sent_len, init_mask_prob,
                       generation_method, masked_portion, temperature, sample, top_k, burnin):

        # Init batch
        seed_text = self.tokenizer.tokenize(
            self.tokenizer.cls_token + seed_text)  # add the [CLS] token at the beginning of the seed_text
        seed_len = len(seed_text)
        batch = self.get_init_text(seed_text, sent_len, batch_size, init_mask_prob)

        # Init sampling parameters
        if generation_method == "parallel":
            if type(masked_portion) is int:
                num_mask = masked_portion
            else:
                num_mask = int(np.round(sent_len * masked_portion))
            list_probs = None
        elif generation_method == "sequential":
            list_probs = None
            num_mask = 1
        else:
            # One probability distribution for each sentence in the batch (initially uniform among all tokens)
            num_mask = 1
            list_probs = [np.full(sent_len, 1.0 / sent_len)] * batch_size
            counter = np.zeros((batch_size, sent_len))

        with torch.no_grad():
            for ii in range(max_iter):

                # 1. Select indices to replace
                idx_to_replace = self.__select_tokens_to_replace(generation_method, sent_len, batch_size, num_mask,
                                                                 ii, seed_len, list_probs)

                # 2. Replace with mask
                self.__replace_tokens(batch, idx_to_replace, tokens=self.tokenizer.mask_token_id)

                # 3. Sample new tokens
                out = self.model(batch)
                logits = out['logits']

                if generation_method == 'attention':
                    counter[np.arange(batch_size), idx_to_replace.flatten() - seed_len] += 1
                    attentions = torch.stack(out['attentions'])
                    list_probs = self.__compute_probs(attentions, batch_size, idx_to_replace, seed_len, counter)

                sample = False if ii >= burnin else sample
                idxs = self.generate_step(logits, gen_idx=idx_to_replace, temperature=temperature, sample=sample,
                                          top_k=top_k)

                # 4. Replace tokens
                self.__replace_tokens(batch, idx_to_replace, tokens=idxs)

                if verbose and ii % print_every == 0:
                    print_batch(self.tokenizer, batch, 3)

        return self.tokenizer.batch_decode(batch, skip_special_tokens=True)

    def get_init_text(self, seed_text, sent_len, batch_size, init_mask_prob):
        """ Get the initial sentences by padding seed_text with either masks or random words up to sent_len """

        seed_text = self.tokenizer.convert_tokens_to_ids(seed_text)

        if init_mask_prob == 1:
            batch = [seed_text + [self.tokenizer.mask_token_id] * sent_len + [self.tokenizer.sep_token_id] for _ in
                     range(batch_size)]
        elif init_mask_prob == 0:
            batch = [seed_text + np.random.randint(0, self.tokenizer.vocab_size, sent_len).tolist() + [
                self.tokenizer.sep_token_id] for _ in range(batch_size)]
        else:
            p = [(1 - init_mask_prob) / (self.tokenizer.vocab_size - 1)] * self.tokenizer.vocab_size
            p[self.tokenizer.mask_token_id] = init_mask_prob

            batch = [seed_text + np.random.choice(np.arange(self.tokenizer.vocab_size), sent_len, p=p).tolist() + [
                self.tokenizer.sep_token_id] for _ in range(batch_size)]

        return torch.tensor(batch).to(self.device)

    def __select_tokens_to_replace(self, generation_method, sent_len, batch_size, num_mask, ii, seed_len, list_probs):
        if generation_method == "sequential":
            kk = [[ii % sent_len] for _ in range(batch_size)]
        elif generation_method == "attention":
            kk = [np.random.choice(range(sent_len), num_mask, p=p).tolist() for p in list_probs]
        elif generation_method == 'parallel':
            # kk = np.random.randint(0, sent_len, (batch_size, num_mask))
            x = np.random.randint(0, sent_len)
            kk = [[x] for _ in range(batch_size)]
        # elif generation_method == 'parallel original':
        #     x = np.random.randint(0, sent_len)
        #     kk = [[x] for _ in range(batch_size)]

        return np.array(kk) + seed_len

    def __replace_tokens(self, batch, idx_to_replace, tokens):
        rows_idx = np.repeat(range(len(batch)), idx_to_replace.shape[-1]).reshape(idx_to_replace.shape)

        if type(tokens) is not int:
            tokens = tokens.reshape(idx_to_replace.shape)

        batch[rows_idx, idx_to_replace] = tokens

    def __compute_probs(self, attentions, batch_size, idx, seed_len, counter):
        ''' Compute probabilities from the attention masks '''
        # list_probs = []
        #
        # # attentions has dimension (batch_size, num_attention_masks, sentence_len, sentence_len)
        # for i in range(batch_size):
        #     average_prob = attentions[i, :, idx[i], :].mean(axis=0).flatten().cpu().numpy()
        #     average_prob = average_prob[seed_len:-1]  # avoid seed_text and last token ([SEP])
        #     average_prob = average_prob / average_prob.sum()  # normalize
        #     list_probs.append(average_prob)
        #
        # return list_probs

        avg_attentions = attentions.mean(axis=(0, 2)).cpu().detach().numpy()  # mean over encoders and attention masks
        avg_attentions = avg_attentions[np.arange(batch_size), seed_len:-1, idx.flatten()]  # for each sentence extract
        # the attention corresponding to the masked token (avoiding the special tokens and the seed)

        c = counter + 1
        prob = avg_attentions / c

        return prob / prob.sum(axis=1)[:, np.newaxis]

    # def counter_penalization(attention, idx_mask, counter, **kwargs):
    #     a = attention.mean(
    #         axis=(0, 1)).cpu().detach().numpy()  # mean over ax0 that is encoders and ax1 that is attention_mask
    #     a = a[1:-1, idx_mask].reshape(-1, 1)
    #     c = np.array(counter) + 1
    #     prob = a.flatten() / c
    #     prob = prob / sum(prob)
    #     return prob

    def generate_step(self, out, gen_idx, temperature=1, sample=True, top_k=None):
        """ Generate a word from out[gen_idx]
        args:
            - out (torch.Tensor): tensor of logits of size batch_size x seq_len x vocab_size
            - gen_idx (int): location for which to generate
            - top_k (int): if > 0, only sample from the top k most probable words
            - sample (Bool): if True, sample from the full distribution. Overridden by top_k
        """
        if type(gen_idx) is int:
            gen_idx = np.array(gen_idx)

        rows_idx = np.repeat(range(len(out)), gen_idx.shape[-1]).reshape(gen_idx.shape)

        logits = out[rows_idx, gen_idx]

        if temperature is not None:
            logits = logits / temperature

        if sample:
            # general sampling
            if top_k is None:
                dist = torch.distributions.categorical.Categorical(logits=logits)
                idx = dist.sample().squeeze(-1)
            # top_k sampling
            else:
                kth_vals, kth_idx = logits.topk(top_k, dim=-1)
                dist = torch.distributions.categorical.Categorical(logits=kth_vals)
                idx = kth_idx.gather(dim=-1, index=dist.sample().unsqueeze(-1)).squeeze(-1)

        # burnin - deterministic
        else:
            idx = torch.argmax(logits, dim=-1)

        return idx

    def finetune(self, sentences, labels=None, encoded_dict=None, mask_percentage=0.15, epochs=4, batch_size=32,
                 optimizer=AdamW, optimizer_parameters=dict(lr=2e-5, eps=1e-8),
                 scheduler=get_linear_schedule_with_warmup, scheduler_parameters=dict(num_warmup_steps=0),
                 num_tokens_per_class=3
                 ):

        if encoded_dict is None:
            # set encoder
            if labels is None:
                self.encoder = Encoder(self.tokenizer)
                encoded_dict = self.encoder.encode(sentences)
            else:
                classes = np.unique(labels)
                self.encoder = LabelEncoder(self.model, self.tokenizer, classes=classes,
                                            num_tokens_per_class=num_tokens_per_class)
                encoded_dict = self.encoder.encode(sentences, labels)

        # Retrieve tokenized sentences and attention masks
        input_ids = encoded_dict['input_ids']
        attention_mask = encoded_dict['attention_mask']

        dataset = TensorDataset(input_ids, attention_mask)
        dataloader = DataLoader(dataset, sampler=RandomSampler(dataset), batch_size=batch_size)

        # Setting optimizer and scheduler
        optimizer = optimizer(self.model.parameters(), **optimizer_parameters)
        if self.use_apex:
            self.model, optimizer = amp.initialize(self.model, optimizer, opt_level="O2", keep_batchnorm_fp32=True,
                                                   loss_scale="dynamic")

        total_steps = len(dataloader) * epochs
        scheduler = scheduler(optimizer, num_training_steps=total_steps, **scheduler_parameters)

        # TODO add stats
        training_stats = []
        test_stats = []
        total_t0 = time.time()

        self.model.train()

        for epoch_i in range(0, epochs):

            print(f'\n======== Epoch {epoch_i + 1} / {epochs} ========')
            print('Training...')

            t0 = time.time()
            total_train_loss = 0

            for step, batch in enumerate(dataloader):

                if step % 25 == 0 and not step == 0:
                    elapsed = format_time(time.time() - t0)
                    print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(step, len(dataloader), elapsed))

                batch_input = batch[0].to(self.device)
                batch_attention = batch[1].to(self.device)

                # Truncate to BERT's maximum input length of 512 tokens
                if batch_input.shape[1] > 512:
                    batch_input = batch_input[:, :512]
                    batch_attention = batch_attention[:, :512]

                # Computing the number of tokens to mask based on mask_percentage
                num_sent, num_tokens = batch_input.shape
                num_tokens_to_mask = int(mask_percentage * num_tokens)

                # Randomly generating num_tokens_to_mask positions to mask for each sentence, considering only real
                # tokens (not [CLS] nor the label-tokens that are at the beginning of the sentence)
                start_id = 1 + num_tokens_per_class  # mask only
                batch_mask_ids = torch.randint(start_id, num_tokens - 1, size=(num_sent, num_tokens_to_mask))

                # Each sentence needs to be indexed num_tokens_to_mask times.
                # This array is of the type [0,0,0, ..., 1,1,1, ..., 2,2,2, ..., num_sentences - 1]
                sentence_ids = np.repeat(np.arange(len(batch_input)), num_tokens_to_mask)

                # Retrieve the original tokens to mask
                batch_masked_tokens = batch_input[sentence_ids, batch_mask_ids.flatten()]

                # Mask the tokens
                batch_input[sentence_ids, batch_mask_ids.flatten()] = self.tokenizer.mask_token_id

                # Forward pass
                self.model.zero_grad()
                result = self.model(batch_input, attention_mask=batch_attention, return_dict=True)
                logits = result['logits']

                # Retrieve logits only for the masked tokens. logits is a tensor of dim [batch_size, num_tokens, len_vocab]
                # logits = logits[np.concatenate([[i] * batch_mask_ids.shape[1] for i in range(len(batch_mask_ids))], 0),
                #                 batch_mask_ids.flatten(), :]
                logits = logits[sentence_ids, batch_mask_ids.flatten(), :]

                loss = F.cross_entropy(logits, batch_masked_tokens.flatten())
                total_train_loss += loss.item()

                # Backward pass
                if self.use_apex:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                # Clip the norm of the gradients to 1.0.
                # This is to help prevent the "exploding gradients" problem.
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)

                optimizer.step()
                scheduler.step()

                # Restoring the masked tokens
                batch_input[sentence_ids, batch_mask_ids.flatten()] = batch_masked_tokens.flatten()

            avg_train_loss = total_train_loss / len(dataloader)
            training_time = format_time(time.time() - t0)

            print("")
            print("  Average training loss: {0:.2f}".format(avg_train_loss))
            print("  Training epoch took: {:}".format(training_time))

        print("")
        print("Training complete!")

        print("Total training took {:} (h:mm:ss)".format(format_time(time.time() - total_t0)))


if __name__ == '__main__':

    # model initialization
    en_bert_model = BertTextGenerator('bert-base-uncased')

    # text generation
    parameters = {'n_sentences': 10,  # 1000
                  'seed_text': "",
                  'batch_size': 10,  # 50
                  'max_iter': 150,
                  'init_mask_prob': 1,
                  'generation_method': "attention",
                  'masked_portion': 1,
                  'temperature': 1,
                  'sample': True,
                  'top_k': 100,
                  }

    file_path = None
    print('\n\n ENGLISH TEXT GENERATION')
    en_bert_sents = en_bert_model.generate(save_to_path=file_path, **parameters)
    print("\nEnglish text generated: ")
    for sent in en_bert_sents:
        print(f"\t{sent}")
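The __main__ block above only exercises generation. As a quick orientation for the fine-tuning path, here is a minimal, hypothetical usage sketch (not part of this commit): the sentences and labels are made up, and it assumes textprocessing.py and utils.py are importable from the same directory.

# Hypothetical sketch of BertTextGenerator.finetune with labels, then conditioned generation.
from NeuralTextGenerator import BertTextGenerator

model = BertTextGenerator('bert-base-uncased')
train_sentences = ["the movie was great", "the plot was terrible"]  # made-up data
train_labels = ["pos", "neg"]                                       # made-up labels
model.finetune(train_sentences, labels=train_labels, epochs=1, batch_size=2, num_tokens_per_class=3)
generated = model.generate(n_sentences=5, batch_size=5, max_iter=50, generation_method="attention")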
app.py
ADDED
@@ -0,0 +1,7 @@
import gradio as gr

def greet(name):
    return "Hello " + name + "!!"

iface = gr.Interface(fn=greet, inputs="text", outputs="text")
iface.launch()
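app.py is, for now, just the stock Gradio "hello" demo. A hedged sketch of how the generator could be exposed through the same Interface pattern (hypothetical, not what this commit does; it assumes NeuralTextGenerator.py is on the path and that loading bert-base-uncased at startup is acceptable):

import gradio as gr
from NeuralTextGenerator import BertTextGenerator

model = BertTextGenerator('bert-base-uncased')  # loaded once at startup

def generate_text(seed_text):
    # generate a few short sentences from the optional seed text
    sents = model.generate(n_sentences=3, batch_size=3, seed_text=seed_text, max_iter=50)
    return "\n".join(sents)

iface = gr.Interface(fn=generate_text, inputs="text", outputs="text")
iface.launch()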
requirements
ADDED
@@ -0,0 +1,4 @@
transformers
datasets
torch
evaluate
textprocessing.py
ADDED
@@ -0,0 +1,90 @@
import re
import numpy as np


class Formatter():

    def __init__(self, replace_tokens, unused_type='[unusedi]'):
        self.dict_token_replace = {k: ' ' + unused_type.replace('i', str(i + 1)) + ' ' for i, k in
                                   enumerate(replace_tokens)}

    def format(self, path, pattern):
        lines = []

        re_line = re.compile(pattern)
        with open(path, 'r') as f:
            for match in re_line.finditer(''.join(f.readlines())):
                line = match[0]

                # Replace
                for k, v in self.dict_token_replace.items():
                    line = line.replace(k, v)

                lines.append(line)

        return lines

    def unformat(self, sentences):
        unformatted_sentences = []
        for sent in sentences:
            # Replace
            for k, v in self.dict_token_replace.items():
                sent = sent.replace(v.strip(), k)

            unformatted_sentences.append(sent)

        return unformatted_sentences


class Encoder():
    def __init__(self, tokenizer):
        self.set_tokenizer(tokenizer)

    def set_tokenizer(self, tokenizer):
        self.tokenizer = tokenizer

    def set_model(self, model):
        self.model = model

    def encode(self, lines):
        encoded_dict = self.tokenizer.batch_encode_plus(
            lines,  # Sentences to encode.
            padding=True,
            add_special_tokens=True,  # Add '[CLS]' and '[SEP]'
            return_attention_mask=True,  # Construct attn. masks.
            return_tensors='pt',  # Return pytorch tensors.
        )

        return encoded_dict


class LabelEncoder(Encoder):
    def __init__(self, model, tokenizer, classes=[], num_tokens_per_class=3):
        super().__init__(tokenizer)
        self.set_model(model)

        # Preparing the special tokens related to the labels.
        # Each token is of the type '[cls-k]' where cls is a class in classes and
        # k is an integer value in range(0, num_tokens_per_class)
        self.num_tokens_per_class = num_tokens_per_class
        self.label_special_tokens_dict = {cls: [f'[{cls}-{i}]' for i in range(num_tokens_per_class)] for cls in classes}
        self.label_special_tokens_list = np.concatenate([list(x) for x in self.label_special_tokens_dict.values()]).tolist()

        # Add the special tokens to the tokenizer and resize the model's embeddings accordingly
        self.tokenizer.add_special_tokens({'additional_special_tokens': self.label_special_tokens_list})
        self.model.resize_token_embeddings(len(self.tokenizer))

    def encode(self, lines, labels):
        labeled_lines = [' '.join(self.label_special_tokens_dict[label]) + ' ' + line for line, label in zip(lines, labels)]
        return super().encode(labeled_lines)
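For reference, a small hypothetical sketch of what LabelEncoder does to its input before tokenization (the model, tokenizer, and the class/label names here are made up and assumed to be already available):

# With classes=['pos', 'neg'] and num_tokens_per_class=3, the encoder registers the special tokens
# [pos-0] [pos-1] [pos-2] [neg-0] [neg-1] [neg-2], resizes the model embeddings, and prepends them:
#   "the movie was great" with label 'pos'  ->  "[pos-0] [pos-1] [pos-2] the movie was great"
encoder = LabelEncoder(model, tokenizer, classes=['pos', 'neg'], num_tokens_per_class=3)
encoded = encoder.encode(["the movie was great"], ["pos"])  # dict with 'input_ids' and 'attention_mask'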
utils.py
ADDED
@@ -0,0 +1,38 @@
import numpy as np
import datetime


def print_batch(tokenizer, batch, n, header=None):
    '''
    Print a batch of tokens. Used mainly for debugging

    Parameters
    ------------
    tokenizer : Tokenizer (https://huggingface.co/docs/tokenizers/python/latest/api/reference.html#tokenizers.Tokenizer)

    batch : List of List[int]

    n : int
        number of sentences to print from the batch
    header : str
        header of the batch, printed before the sentences
    '''
    print(f'=== {header or "Batch"} ===')
    print(tokenizer.batch_decode(batch[:n], skip_special_tokens=True))
    print('...\n' if n < len(batch) else '')


def flat_accuracy(preds, labels):
    pred_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    return np.sum(pred_flat == labels_flat) / len(labels_flat)


def format_time(elapsed):
    '''
    Takes a time in seconds and returns a string hh:mm:ss
    '''
    elapsed_rounded = int(round((elapsed)))
    return str(datetime.timedelta(seconds=elapsed_rounded))