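"""Tokenization utilities for aligning word-level tokens with sub-word (BPE) pieces.

The helpers below take a batch of pre-tokenized sentences, run a HuggingFace
fast tokenizer with offset mappings, group the resulting pieces by the original
token, cap the number of pieces per token (max_bpe_pieces) and the total
sequence length (max_bpe_length), and finally pad the batch.
"""
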
import os
from time import time

# disable thread parallelism in the HuggingFace tokenizers library
# to avoid fork-related warnings when used inside data loaders
os.environ['TOKENIZERS_PARALLELISM'] = 'false'


def get_bpe_groups(token_offsets, bpe_offsets, input_ids, max_bpe_pieces=5):
    """Group BPE-piece indices by the word-level token they belong to.

    Pieces beyond max_bpe_pieces per token are dropped from saved_ids.
    """
    bpe_groups = []
    last_used_bpe = 0
    # find the effective size of the offsets (padding entries are (0, 0))
    if (0, 0) in bpe_offsets:
        bpe_size = bpe_offsets.index((0, 0))
    else:
        bpe_size = len(bpe_offsets)
    saved_ids = list(range(len(input_ids)))
    redundant_ids = []
    for token_offset in token_offsets:
        start_token, end_token = token_offset
        bpe_group = []
        mapping_is_found = False
        for i in range(last_used_bpe, bpe_size):
            start_bpe, end_bpe = bpe_offsets[i]
            if start_bpe >= start_token and end_bpe <= end_token:
                # check that bpe_group satisfies the max_bpe_pieces constraint
                if len(bpe_group) < max_bpe_pieces:
                    bpe_group.append(i)
                else:
                    redundant_ids.append(i)
                last_used_bpe = i + 1
                mapping_is_found = True
            elif mapping_is_found:
                # the pieces of a token are contiguous, so stop early
                break
            else:
                continue
        bpe_groups.append(bpe_group)
    saved_ids = [i for i in saved_ids if i not in redundant_ids]
    return bpe_groups, saved_ids
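
# Illustrative sketch (hypothetical ids and offsets, not from a real tokenizer):
# for the tokens ["unbelievable", "!"] joined as "unbelievable !",
#   token_offsets = [(0, 12), (13, 14)]
#   bpe_offsets   = [(0, 2), (2, 9), (9, 12), (13, 14)]   # "un", "believa", "ble", "!"
#   input_ids     = [101, 102, 103, 104]
# get_bpe_groups(token_offsets, bpe_offsets, input_ids) would return
#   bpe_groups = [[0, 1, 2], [3]] and saved_ids = [0, 1, 2, 3],
# i.e. pieces 0-2 belong to the first token and piece 3 to the second.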


def reduce_input_ids(input_ids, bpe_groups, saved_ids,
                     max_bpe_length=80, max_bpe_pieces=5):
    """Drop pieces until the sequence fits max_bpe_length, then compute token offsets."""
    # check whether the sequence satisfies the max_bpe_length constraint;
    # if not, keep shrinking the per-token piece budget
    while len(saved_ids) > max_bpe_length:
        max_bpe_pieces -= 1
        for token_id in range(len(bpe_groups)):
            if len(bpe_groups[token_id]) > max_bpe_pieces:
                redundant_ids = bpe_groups[token_id][max_bpe_pieces:]
                bpe_groups[token_id] = bpe_groups[token_id][:max_bpe_pieces]
                saved_ids = [i for i in saved_ids if i not in redundant_ids]

    # get offsets: the position of each token's first surviving piece
    reduced_ids = [input_ids[i] for i in saved_ids]
    correct_offsets = []
    idx = 0
    for bpe_group in bpe_groups:
        norm_idx = min(idx, len(reduced_ids) - 1)
        correct_offsets.append(norm_idx)
        idx += len(bpe_group)
    return reduced_ids, correct_offsets
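
# Illustrative sketch (continuing the hypothetical example above):
# with bpe_groups = [[0, 1, 2], [3]], saved_ids = [0, 1, 2, 3] and
# input_ids = [101, 102, 103, 104], calling
#   reduce_input_ids(input_ids, bpe_groups, saved_ids, max_bpe_length=3)
# lowers the per-token budget until only 3 pieces remain, returning
#   reduced_ids = [101, 102, 104] and correct_offsets = [0, 2]
# (each offset points at the first surviving piece of its token).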


def get_offsets_and_reduce_input_ids(tokenizer_output, token_offset_list,
                                     index_name="bert", max_bpe_length=80,
                                     max_bpe_pieces=5):
    """Align BPE pieces with tokens for every sentence and build ids, offsets and masks."""
    timings = {"bpe": 0, "reduce": 0, "mask": 0}
    output_ids, output_offsets, output_masks = [], [], []
    for i, token_offsets in enumerate(token_offset_list):
        input_ids = tokenizer_output['input_ids'][i]
        t0 = time()
        # get bpe-level offsets
        bpe_offsets = tokenizer_output['offset_mapping'][i]
        bpe_groups, saved_ids = get_bpe_groups(token_offsets, bpe_offsets,
                                               input_ids,
                                               max_bpe_pieces=max_bpe_pieces)
        t1 = time()
        timings["bpe"] += t1 - t0
        # reduce sequence length
        reduced_ids, correct_offsets = reduce_input_ids(input_ids, bpe_groups,
                                                        saved_ids,
                                                        max_bpe_length=max_bpe_length,
                                                        max_bpe_pieces=max_bpe_pieces)
        t2 = time()
        timings["reduce"] += t2 - t1
        # get mask (one entry per original token)
        bpe_mask = [1 for _ in correct_offsets]
        output_ids.append(reduced_ids)
        output_offsets.append(correct_offsets)
        output_masks.append(bpe_mask)
        t3 = time()
        timings["mask"] += t3 - t2
    # tt = sum(timings.values())
    # timings = {k: f"{round(v * 100 / tt, 2)}%" for k, v in timings.items()}
    # print(timings)
    output = {index_name: output_ids,
              f"{index_name}-offsets": output_offsets,
              "mask": output_masks}
    return output
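
# For index_name="bert" the returned dict has the shape (sketch):
#   {"bert":         [...reduced piece ids per sentence...],
#    "bert-offsets": [...index of each token's first piece...],
#    "mask":         [...one 1 per original token...]}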


def get_offset_for_tokens(tokens):
    sentence = " ".join(tokens)
    token_offsets = []
    end_idx = 0
    for token in tokens:
        idx = sentence[end_idx:].index(token) + end_idx
        end_idx = idx + len(token)
        offset = (idx, end_idx)
        token_offsets.append(offset)
    return token_offsets
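
# Sketch: get_offset_for_tokens(["Hello", "world"]) -> [(0, 5), (6, 11)],
# i.e. the character span of each token inside the joined sentence "Hello world".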


def get_token_offsets(batch):
    token_offset_list = []
    for tokens in batch:
        token_offsets = get_offset_for_tokens(tokens)
        token_offset_list.append(token_offsets)
    return token_offset_list


def pad_output(output, pad_idx=0):
    padded_output = {}
    for input_key in output.keys():
        indexes = output[input_key]
        max_len = max(len(x) for x in indexes)
        padded_indexes = []
        for index_list in indexes:
            cur_len = len(index_list)
            pad_len = max_len - cur_len
            padded_indexes.append(index_list + [pad_idx] * pad_len)
        padded_output[input_key] = padded_indexes
    return padded_output
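
# Sketch: pad_output({"bert": [[1, 2, 3], [4]]}) -> {"bert": [[1, 2, 3], [4, 0, 0]]}
# (every list in the batch is right-padded with pad_idx to the batch maximum).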


def tokenize_batch(tokenizer, batch_tokens, index_name="bert",
                   max_bpe_length=80, max_bpe_pieces=5):
    """Tokenize a batch of pre-tokenized sentences and align the output with the tokens."""
    timings = {}
    t0 = time()
    # get batch with sentences
    batch_sentences = [" ".join(x) for x in batch_tokens]
    # get token-level offsets
    token_offset_list = get_token_offsets(batch_tokens)
    # token_offset_list = get_token_offsets_multi(batch_tokens)
    t1 = time()
    timings["offset_time"] = t1 - t0
    # tokenize batch
    tokenizer_output = tokenizer.batch_encode_plus(batch_sentences,
                                                   pad_to_max_length=False,
                                                   return_offsets_mapping=True,
                                                   add_special_tokens=False)
    t2 = time()
    timings["tokenize_time"] = t2 - t1
    # postprocess batch
    output = get_offsets_and_reduce_input_ids(tokenizer_output,
                                              token_offset_list,
                                              index_name=index_name,
                                              max_bpe_length=max_bpe_length,
                                              max_bpe_pieces=max_bpe_pieces)
    t3 = time()
    timings["reduce_time"] = t3 - t2
    # pad output
    output = pad_output(output)
    t4 = time()
    timings["padding_time"] = t4 - t3
    # tt = sum(timings.values())
    # timings = {k: f"{round(v * 100 / tt, 2)}%" for k, v in timings.items()}
    # print(timings)
    return output
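

if __name__ == "__main__":
    # Minimal usage sketch. The checkpoint name below is an assumption for
    # illustration; any HuggingFace *fast* tokenizer (offset mappings required) works.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased", use_fast=True)
    batch_tokens = [["The", "quick", "brown", "fox"],
                    ["Hello", "world", "!"]]
    output = tokenize_batch(tokenizer, batch_tokens, index_name="bert",
                            max_bpe_length=80, max_bpe_pieces=5)
    for key, value in output.items():
        print(key, value)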