import numpy as np
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

def encode_sentence(sent, pair, tokenizer, model, layer: int):
    """Encode a pre-tokenized sentence (optionally with a paired sentence) and
    return the hidden states of the requested layer, the input ids, and the
    offset mapping of the first example in the batch."""
    if pair is None:
        inputs = tokenizer(sent, padding=False, truncation=False, is_split_into_words=True,
                           return_offsets_mapping=True, return_tensors="pt")
    else:
        inputs = tokenizer(text=sent, text_pair=pair, padding=False, truncation=True,
                           is_split_into_words=True, return_offsets_mapping=True,
                           return_tensors="pt")

    with torch.no_grad():
        outputs = model(inputs['input_ids'].to(device), inputs['attention_mask'].to(device),
                        inputs['token_type_ids'].to(device))

    return outputs.hidden_states[layer][0], inputs['input_ids'][0], inputs['offset_mapping'][0]


def centering(hidden_outputs):
    """
    hidden_outputs : [tokens, hidden_size]
    """
    # Sum the embeddings over all tokens and divide by the token count to get the mean vector
    mean_vec = torch.sum(hidden_outputs, dim=0) / hidden_outputs.shape[0]
    hidden_outputs = hidden_outputs - mean_vec
    print(hidden_outputs.shape)
    return hidden_outputs


def convert_to_word_embeddings(offset_mapping, token_ids, hidden_tensors, tokenizer, pair):
    """Pool subword-level hidden states into word-level embeddings by averaging
    the subwords that belong to the same word (identified via the offset mapping)."""
    word_idx = -1
    subword_to_word_conv = np.full((hidden_tensors.shape[0]), -1)
    # Bug in hugging face tokenizer? Sometimes Metaspace is inserted
    metaspace = getattr(tokenizer.decoder, "replacement", None)
    metaspace = tokenizer.decoder.prefix if metaspace is None else metaspace
    tokenizer_bug_idxes = [i for i, x in enumerate(tokenizer.convert_ids_to_tokens(token_ids)) if
                           x == metaspace]

    for subw_idx, offset in enumerate(offset_mapping):
        if subw_idx in tokenizer_bug_idxes:
            continue
        elif offset[0] == offset[1]:  # Special token
            continue
        elif offset[0] == 0:
            word_idx += 1
            subword_to_word_conv[subw_idx] = word_idx
        else:
            subword_to_word_conv[subw_idx] = word_idx

    word_embeddings = torch.vstack(
        ([torch.mean(hidden_tensors[subword_to_word_conv == word_idx], dim=0) for word_idx in range(word_idx + 1)]))
    print(word_embeddings.shape)

    if pair:
        sep_tok_indices = [i for i, x in enumerate(token_ids) if x == tokenizer.sep_token_id]
        s2_start_idx = subword_to_word_conv[
            sep_tok_indices[0] + np.argmax(subword_to_word_conv[sep_tok_indices[0]:] > -1)]

        s1_word_embeddings = word_embeddings[0:s2_start_idx, :]
        s2_word_embeddings = word_embeddings[s2_start_idx:, :]

        return s1_word_embeddings, s2_word_embeddings
    else:
        return word_embeddings
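

# --- Minimal usage sketch (not part of the original module) -----------------
# The functions above expect a fast tokenizer (required for
# return_offsets_mapping) and a BERT-style model loaded with
# output_hidden_states=True so that outputs.hidden_states is populated.
# The checkpoint name below is an assumption for illustration only.
if __name__ == "__main__":
    from transformers import AutoModel, AutoTokenizer

    model_name = "bert-base-uncased"  # assumed checkpoint; any BERT-style model should work
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name, output_hidden_states=True).to(device).eval()

    # Sentences are passed pre-tokenized, because the tokenizer is called with
    # is_split_into_words=True.
    s1 = ["The", "quick", "brown", "fox", "jumps", "."]
    s2 = ["A", "fast", "auburn", "fox", "leaps", "."]

    hidden, token_ids, offsets = encode_sentence(s1, s2, tokenizer, model, layer=-1)
    s1_emb, s2_emb = convert_to_word_embeddings(offsets, token_ids, hidden, tokenizer, pair=True)
    print(s1_emb.shape, s2_emb.shape)  # one row per word in each sentence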