Spaces:
Sleeping
Sleeping
import os | |
from time import time | |
os.environ['TOKENIZERS_PARALLELISM'] = 'false' | |
def get_bpe_groups(token_offsets, bpe_offsets, input_ids, max_bpe_pieces=5): | |
bpe_groups = [] | |
last_used_bpe = 0 | |
# find the size of offsets | |
if (0, 0) in bpe_offsets: | |
bpe_size = bpe_offsets.index((0, 0)) | |
else: | |
bpe_size = len(bpe_offsets) | |
saved_ids = [i for i in range(len(input_ids))] | |
redundant_ids = [] | |
for token_offset in token_offsets: | |
start_token, end_token = token_offset | |
bpe_group = [] | |
mapping_is_found = False | |
for i in range(last_used_bpe, bpe_size): | |
start_bpe, end_bpe = bpe_offsets[i] | |
if start_bpe >= start_token and end_bpe <= end_token: | |
# check if bpe_group is satisfy max_bpe_pieces constraint | |
if len(bpe_group) < max_bpe_pieces: | |
bpe_group.append(i) | |
else: | |
redundant_ids.append(i) | |
last_used_bpe = i + 1 | |
mapping_is_found = True | |
elif mapping_is_found: | |
# stop doing useless iterations | |
break | |
else: | |
continue | |
bpe_groups.append(bpe_group) | |
saved_ids = [i for i in saved_ids if i not in redundant_ids] | |
return bpe_groups, saved_ids | |
def reduce_input_ids(input_ids, bpe_groups, saved_ids, | |
max_bpe_length=80, max_bpe_pieces=5): | |
# check if sequence is satisfy max_bpe_length constraint | |
while len(saved_ids) > max_bpe_length: | |
max_bpe_pieces -= 1 | |
for token_id in range(len(bpe_groups)): | |
if len(bpe_groups[token_id]) > max_bpe_pieces: | |
redundant_ids = bpe_groups[token_id][max_bpe_pieces:] | |
bpe_groups[token_id] = bpe_groups[token_id][:max_bpe_pieces] | |
saved_ids = [i for i in saved_ids if i not in redundant_ids] | |
# get offsets | |
reduced_ids = [input_ids[i] for i in saved_ids] | |
correct_offsets = [] | |
idx = 0 | |
for i, bpe_group in enumerate(bpe_groups): | |
norm_idx = min(idx, len(reduced_ids) - 1) | |
correct_offsets.append(norm_idx) | |
idx += len(bpe_group) | |
return reduced_ids, correct_offsets | |
def get_offsets_and_reduce_input_ids(tokenizer_output, token_offset_list, | |
index_name="bert", max_bpe_length=80, | |
max_bpe_pieces=5): | |
timings = {"bpe": 0, "reduce": 0, "mask": 0} | |
output_ids, output_offsets, output_masks = [], [], [] | |
for i, token_offsets in enumerate(token_offset_list): | |
input_ids = tokenizer_output['input_ids'][i] | |
t0 = time() | |
# get bpe level offsets | |
bpe_offsets = tokenizer_output['offset_mapping'][i] | |
bpe_groups, saved_ids = get_bpe_groups(token_offsets, bpe_offsets, | |
input_ids, | |
max_bpe_pieces=max_bpe_pieces) | |
t1 = time() | |
timings["bpe"] += t1 - t0 | |
# reduce sequence length | |
reduced_ids, correct_offsets = reduce_input_ids(input_ids, bpe_groups, | |
saved_ids, | |
max_bpe_length=max_bpe_length, | |
max_bpe_pieces=max_bpe_pieces) | |
t2 = time() | |
timings["reduce"] += t2 - t1 | |
# get mask | |
bpe_mask = [1 for _ in correct_offsets] | |
output_ids.append(reduced_ids) | |
output_offsets.append(correct_offsets) | |
output_masks.append(bpe_mask) | |
t3 = time() | |
timings["mask"] += t3 - t2 | |
# tt = sum(timings.values()) | |
# timings = {k: f"{round(v * 100 / tt, 2)}%" for k, v in timings.items()} | |
# print(timings) | |
output = {index_name: output_ids, | |
f"{index_name}-offsets": output_offsets, | |
"mask": output_masks} | |
return output | |
def get_offset_for_tokens(tokens): | |
sentence = " ".join(tokens) | |
token_offsets = [] | |
end_idx = 0 | |
for token in tokens: | |
idx = sentence[end_idx:].index(token) + end_idx | |
end_idx = idx + len(token) | |
offset = (idx, end_idx) | |
token_offsets.append(offset) | |
return token_offsets | |
def get_token_offsets(batch): | |
token_offset_list = [] | |
for tokens in batch: | |
token_offsets = get_offset_for_tokens(tokens) | |
token_offset_list.append(token_offsets) | |
return token_offset_list | |
def pad_output(output, pad_idx=0): | |
padded_output = {} | |
for input_key in output.keys(): | |
indexes = output[input_key] | |
max_len = max([len(x) for x in indexes]) | |
padded_indexes = [] | |
for index_list in indexes: | |
cur_len = len(index_list) | |
pad_len = max_len - cur_len | |
padded_indexes.append(index_list + [pad_idx] * pad_len) | |
padded_output[input_key] = padded_indexes | |
return padded_output | |
def tokenize_batch(tokenizer, batch_tokens, index_name="bert", | |
max_bpe_length=80, max_bpe_pieces=5): | |
timings = {} | |
t0 = time() | |
# get batch with sentences | |
batch_sentences = [" ".join(x) for x in batch_tokens] | |
# get token level offsets | |
token_offset_list = get_token_offsets(batch_tokens) | |
# token_offset_list = get_token_offsets_multi(batch_tokens) | |
t1 = time() | |
timings["offset_time"] = t1 - t0 | |
# tokenize batch | |
tokenizer_output = tokenizer.batch_encode_plus(batch_sentences, | |
pad_to_max_length=False, | |
return_offsets_mapping=True, | |
add_special_tokens=False) | |
t2 = time() | |
timings["tokenize_time"] = t2 - t1 | |
# postprocess batch | |
output = get_offsets_and_reduce_input_ids(tokenizer_output, | |
token_offset_list, | |
index_name=index_name, | |
max_bpe_length=max_bpe_length, | |
max_bpe_pieces=max_bpe_pieces) | |
t3 = time() | |
timings["reduce_time"] = t3 - t2 | |
# pad output | |
output = pad_output(output) | |
t4 = time() | |
timings["pading_time"] = t4 - t3 | |
# tt = sum(timings.values()) | |
# timings = {k:f"{round(v*100/tt, 2)}%" for k,v in timings.items()} | |
# print(timings) | |
return output | |