UOT / utils.py
4kasha
update demo
94f5fd3
import numpy as np
import torch
# Target device for model inputs: CUDA when torch reports it is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def encode_sentence(sent, pair, tokenizer, model, layer: int):
    """Tokenize pre-split words (optionally with a paired sentence) and run the model.

    Args:
        sent: the first sentence, already split into words
            (passed with ``is_split_into_words=True``).
        pair: the second sentence split into words, or ``None`` to encode ``sent`` alone.
        tokenizer: tokenizer able to return an offset mapping
            (``return_offsets_mapping=True``).
        model: model whose output exposes ``hidden_states``; presumably it was created
            with ``output_hidden_states=True`` -- TODO confirm at call sites.
        layer: index into ``outputs.hidden_states``.

    Returns:
        Tuple ``(hidden_states, input_ids, offset_mapping)`` for the single batch item:
        the hidden states of ``layer`` (on ``device``), and the token ids and offset
        mapping as returned by the tokenizer.
    """
    if pair is None:
        # Single sentence: no truncation, so every input word is encoded.
        inputs = tokenizer(sent, padding=False, truncation=False, is_split_into_words=True,
                           return_offsets_mapping=True, return_tensors="pt")
    else:
        inputs = tokenizer(text=sent, text_pair=pair, padding=False, truncation=True,
                           is_split_into_words=True, return_offsets_mapping=True,
                           return_tensors="pt")

    model_args = [inputs['input_ids'].to(device), inputs['attention_mask'].to(device)]
    # Some tokenizers (e.g. RoBERTa / XLM-R) do not produce token_type_ids; indexing them
    # unconditionally raised KeyError, so pass them only when the tokenizer returned them.
    if 'token_type_ids' in inputs:
        model_args.append(inputs['token_type_ids'].to(device))

    with torch.no_grad():
        outputs = model(*model_args)
    return outputs.hidden_states[layer][0], inputs['input_ids'][0], inputs['offset_mapping'][0]
def centering(hidden_outputs):
    """Subtract the mean token embedding from every token embedding.

    Args:
        hidden_outputs: tensor of shape ``[tokens, hidden_size]``.

    Returns:
        A new tensor of the same shape in which each hidden dimension has zero mean
        across tokens. The input tensor is not modified.
    """
    # Sum the embeddings of all tokens and divide by the token count to get the mean
    # vector. Explicit sum/len (rather than Tensor.mean) keeps integer inputs working.
    mean_vec = torch.sum(hidden_outputs, dim=0) / hidden_outputs.shape[0]
    # The former debug print of the shape was removed: it wrote to stdout on every call.
    return hidden_outputs - mean_vec
def convert_to_word_embeddings(offset_mapping, token_ids, hidden_tensors, tokenizer, pair):
    """Pool subword hidden states into one embedding per word.

    A subword whose start offset is 0 begins a new word; the following subwords with a
    non-zero start offset belong to the same word. This relies on word-relative offsets,
    as produced by ``encode_sentence``'s ``is_split_into_words=True``. Subwords with an
    empty span (start == end, i.e. special tokens) and stray metaspace tokens are ignored.
    Each word embedding is the mean of its subwords' hidden states.

    Args:
        offset_mapping: per-subword ``(start, end)`` offsets.
        token_ids: subword ids aligned with ``offset_mapping``.
        hidden_tensors: per-subword hidden states; row ``i`` belongs to subword ``i``.
        tokenizer: the tokenizer that produced ``token_ids``.
        pair: truthy if the input was a sentence pair; the words are then split at the
            first real word following the first ``tokenizer.sep_token_id``.

    Returns:
        A ``[num_words, ...]`` tensor, or for pairs a tuple
        ``(first_sentence_words, second_sentence_words)``. A segment with no words
        yields an empty (zero-row) tensor.

    Raises:
        IndexError: for pairs, if ``token_ids`` contains no separator token.
    """
    # Hugging Face tokenizer quirk: a lone metaspace token is sometimes inserted and
    # belongs to no word. Metaspace decoders expose the symbol as ``replacement``,
    # WordPiece decoders as ``prefix``; decoders with neither simply skip nothing.
    decoder = tokenizer.decoder
    metaspace = getattr(decoder, "replacement", None)
    if metaspace is None:
        metaspace = getattr(decoder, "prefix", None)
    tokens = tokenizer.convert_ids_to_tokens(token_ids)
    # A set gives O(1) membership tests inside the loop below.
    skip_idxes = {i for i, tok in enumerate(tokens) if tok == metaspace}

    # -1 marks subwords that belong to no word (special / metaspace tokens).
    subword_to_word = np.full((hidden_tensors.shape[0]), -1)
    word_idx = -1
    for subw_idx, offset in enumerate(offset_mapping):
        if subw_idx in skip_idxes or offset[0] == offset[1]:
            continue
        if offset[0] == 0:
            word_idx += 1  # a zero start offset opens a new word
        subword_to_word[subw_idx] = word_idx
    num_words = word_idx + 1

    if num_words == 0:
        # torch.vstack([]) raises; return an empty, correctly shaped tensor instead.
        word_embeddings = hidden_tensors.new_zeros((0, *hidden_tensors.shape[1:]))
    else:
        word_ids = torch.as_tensor(subword_to_word, device=hidden_tensors.device)
        word_embeddings = torch.vstack(
            [torch.mean(hidden_tensors[word_ids == w], dim=0) for w in range(num_words)])

    if not pair:
        return word_embeddings

    sep_tok_indices = [i for i, x in enumerate(token_ids) if x == tokenizer.sep_token_id]
    # The second sentence starts at the first real word after the first separator. If
    # there is none, the second segment is empty; the old argmax lookup returned -1 here
    # and moved the last word of sentence 1 into sentence 2.
    following = subword_to_word[sep_tok_indices[0]:]
    following = following[following > -1]
    s2_start_idx = int(following[0]) if following.size else num_words
    s1_word_embeddings = word_embeddings[:s2_start_idx]
    s2_word_embeddings = word_embeddings[s2_start_idx:]
    return s1_word_embeddings, s2_word_embeddings