import numpy as np
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"


def encode_sentence(sent, pair, tokenizer, model, layer: int):
    """Run a pre-tokenized sentence (or sentence pair) through the model and return
    the hidden states of the requested layer, the input ids, and the offset mapping.

    The model is expected to be loaded with output_hidden_states=True, and a fast
    tokenizer is required for return_offsets_mapping.
    """
    if pair is None:
        inputs = tokenizer(sent, padding=False, truncation=False, is_split_into_words=True,
                           return_offsets_mapping=True, return_tensors="pt")
    else:
        inputs = tokenizer(text=sent, text_pair=pair, padding=False, truncation=True,
                           is_split_into_words=True, return_offsets_mapping=True,
                           return_tensors="pt")

    with torch.no_grad():
        outputs = model(inputs['input_ids'].to(device), inputs['attention_mask'].to(device),
                        inputs['token_type_ids'].to(device))

    # Drop the batch dimension (the batch size is always 1 here).
    return outputs.hidden_states[layer][0], inputs['input_ids'][0], inputs['offset_mapping'][0]
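
# Usage sketch (illustrative, not part of the original pipeline): encode_sentence assumes a
# fast tokenizer, pre-tokenized input, and a model that exposes hidden states, e.g. with a
# BERT-style checkpoint:
#
#   from transformers import AutoModel, AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
#   model = AutoModel.from_pretrained("bert-base-uncased", output_hidden_states=True).to(device).eval()
#   hidden, ids, offsets = encode_sentence(["a", "small", "example"], None, tokenizer, model, layer=12)

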
def centering(hidden_outputs):
    """Mean-center the token representations.

    hidden_outputs : [tokens, hidden_size]
    """
    # Subtract the mean vector over all tokens from every token representation.
    mean_vec = torch.sum(hidden_outputs, dim=0) / hidden_outputs.shape[0]
    hidden_outputs = hidden_outputs - mean_vec
    print(hidden_outputs.shape)
    return hidden_outputs
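
# For example (continuing the sketch above), the layer output returned by encode_sentence
# can be mean-centered before any similarity computation:
#
#   centered = centering(hidden)   # same shape as `hidden`; the mean over tokens is ~0

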
def convert_to_word_embeddings(offset_mapping, token_ids, hidden_tensors, tokenizer, pair):
    """Pool subword vectors into word vectors by averaging the subwords of each word."""
    word_idx = -1
    # Maps each subword position to its word index (-1 = special or skipped token).
    subword_to_word_conv = np.full((hidden_tensors.shape[0]), -1)

    # Some tokenizers emit a bare metaspace/prefix token; those positions are skipped below.
    metaspace = getattr(tokenizer.decoder, "replacement", None)
    metaspace = tokenizer.decoder.prefix if metaspace is None else metaspace
    tokenizer_bug_idxes = [i for i, x in enumerate(tokenizer.convert_ids_to_tokens(token_ids))
                           if x == metaspace]

    for subw_idx, offset in enumerate(offset_mapping):
        if subw_idx in tokenizer_bug_idxes:
            continue
        elif offset[0] == offset[1]:
            # Special tokens such as [CLS]/[SEP] have an empty (0, 0) offset.
            continue
        elif offset[0] == 0:
            # An offset starting at 0 marks the first subword of a new word.
            word_idx += 1
            subword_to_word_conv[subw_idx] = word_idx
        else:
            # Continuation subword: belongs to the current word.
            subword_to_word_conv[subw_idx] = word_idx

    # Average the subword vectors of each word.
    word_embeddings = torch.vstack(
        [torch.mean(hidden_tensors[subword_to_word_conv == w], dim=0) for w in range(word_idx + 1)])
    print(word_embeddings.shape)

    if pair:
        # Split the pooled words back into the two input sentences at the first [SEP].
        sep_tok_indices = [i for i, x in enumerate(token_ids) if x == tokenizer.sep_token_id]
        s2_start_idx = subword_to_word_conv[
            sep_tok_indices[0] + np.argmax(subword_to_word_conv[sep_tok_indices[0]:] > -1)]

        s1_word_embeddings = word_embeddings[0:s2_start_idx, :]
        s2_word_embeddings = word_embeddings[s2_start_idx:, :]

        return s1_word_embeddings, s2_word_embeddings
    else:
        return word_embeddings
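

if __name__ == "__main__":
    # End-to-end sketch (an illustration added here, not part of the original pipeline).
    # It assumes a BERT-style checkpoint; any fast tokenizer / model pair that returns
    # token_type_ids and hidden states should behave the same way.
    from transformers import AutoModel, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = AutoModel.from_pretrained("bert-base-uncased", output_hidden_states=True).to(device).eval()

    s1 = ["The", "quick", "brown", "fox"]
    s2 = ["A", "fast", "auburn", "fox"]

    hidden, ids, offsets = encode_sentence(s1, s2, tokenizer, model, layer=12)
    s1_vecs, s2_vecs = convert_to_word_embeddings(offsets, ids, hidden, tokenizer, pair=True)
    print(s1_vecs.shape, s2_vecs.shape)  # one vector per word in each sentence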