#############################
#   Imports
#############################

# Python modules

# Remote modules
import torch

# Local modules

#############################
#   Constants
#############################

#############################
#   Attention mask helpers
#############################

def find_head_to_mask(heads_mask) -> int:
    # Return the index of the head selected by a one-hot heads mask.
    head_idx = torch.argmax(heads_mask)
    return head_idx.item()
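
# Illustrative example (not part of the original module): for a one-hot
# head mask, find_head_to_mask simply recovers the selected head index,
# e.g. find_head_to_mask(torch.tensor([0., 0., 1., 0.])) returns 2.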

def commonsense_attention_mask_update(bsz, n_tokens, commonsense_matrix, attn_weights,
                                      num_heads=16, specific_head=0):
    # attn_weights is assumed to arrive as (bsz, num_heads, n_tokens, n_tokens).
    # Swap to a (num_heads, bsz, ...) layout with transpose (a plain reshape
    # would scramble batch and head entries for bsz > 1) so a single head can
    # be indexed directly.
    device = attn_weights.device
    attn_weights_helper = attn_weights.transpose(0, 1).clone()
    # Clone the selected head's weights before zeroing them; otherwise the
    # zero-fill below would wipe the saved view as well.
    head_previous_attention_weights = attn_weights_helper[specific_head].clone()
    attn_weights_helper[specific_head] = 0
    attn_weights_helper = attn_weights_helper.transpose(0, 1)
    if commonsense_matrix is None:
        # No matrix passed: fall back to ones, which is neutral because the
        # mask is applied multiplicatively.
        commonsense_matrix = torch.ones((bsz, n_tokens, n_tokens), device=device)
    commonsense_mask = torch.zeros((num_heads, bsz, n_tokens, n_tokens), device=device)
    commonsense_mask[specific_head] = head_previous_attention_weights * commonsense_matrix
    return attn_weights_helper + commonsense_mask.transpose(0, 1)
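
# Illustrative shapes (assumed for this sketch, not taken from the original
# code): with bsz=2, num_heads=16 and n_tokens=10, attn_weights and the
# returned tensor are both (2, 16, 10, 10); only head specific_head is
# modified, receiving its old weights scaled element-wise by a
# commonsense_matrix of shape (2, 10, 10).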

def convert_relations_to_binary_mask(input_relations, should_clone=True):
    # Collapse relation ids into a binary mask: any id greater than 1 becomes 1.
    relations_binary_mask = input_relations.clone() if should_clone else input_relations
    relations_binary_mask[relations_binary_mask > 1] = 1
    return relations_binary_mask
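
# Example (illustrative): torch.tensor([[0, 3], [1, 7]]) becomes
# [[0, 1], [1, 1]]; relation ids above 1 collapse to 1, while
# 0 ("no relation") and 1 are kept as-is.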

def relation_binary_2d_to_1d(relations_binary_mask):
    # Collapse a 2D relation mask to 1D: a token is marked if it takes part
    # in at least one relation.
    relations_binary_mask = relations_binary_mask.sum(dim=1)
    relations_binary_mask[relations_binary_mask > 1] = 1
    return relations_binary_mask
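
# Example (illustrative): for a (bsz, n, n) mask, summing over dim=1 and
# re-binarising yields a (bsz, n) vector; e.g. [[[0, 1], [0, 0]]] maps
# to [[0, 1]].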

def create_layer_with_commonsense_on_specific_head(relation_binary_mask, bsz, num_heads, specific_head=0):
    # Build a (bsz, num_heads, n_tokens, n_tokens) mask that is zero for every
    # head except specific_head, which carries the relation mask.
    n_tokens = relation_binary_mask.size()[-1]
    layer = torch.zeros((num_heads, bsz, n_tokens, n_tokens))
    layer[specific_head] = relation_binary_mask
    return layer.transpose(0, 1)
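
# Illustrative usage (assumed shapes): a (2, 5, 5) binary relation mask with
# bsz=2 and num_heads=16 produces a (2, 16, 5, 5) layer mask in which every
# head slice is zero except head specific_head.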

def update_weights_regarding_relations_on_specific_head(layer_head_mask, attn_weights, relation_inputs,
                                                        bsz, num_heads, tgt_len, src_len, verbose=True):
    # layer_head_mask selects (multiplicatively) which heads receive
    # relation-gated attention; the inverse mask keeps the original weights
    # on the remaining heads.
    layer_head_mask = layer_head_mask.view(num_heads, 1, 1)
    inverse_layer_head_mask = 1 - layer_head_mask
    if verbose:
        print("==============================")
        print('layer_head_mask.shape:', layer_head_mask.shape)
        print('inverse_layer_head_mask.shape:', inverse_layer_head_mask.shape)
        print('attn_weights.shape:', attn_weights.shape)
        print('relation_inputs.shape:', relation_inputs.shape)
        print("==============================")
    attn_weights = attn_weights.view(bsz, num_heads, tgt_len, src_len)
    intermediate_weights = inverse_layer_head_mask * attn_weights
    relation_inputs = convert_relations_to_binary_mask(relation_inputs, should_clone=False)
    relation_weights = (layer_head_mask
                        * relation_inputs.view(bsz, 1, tgt_len, src_len)
                        * attn_weights)
    # [batch, n_heads, seq_length, seq_length]
    attn_weights = intermediate_weights + relation_weights
    if verbose:
        print('attn_weights.shape:', attn_weights.shape)
    return attn_weights
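
# In effect (illustrative summary): per head h,
#   out[:, h] = (1 - m[h]) * A[:, h] + m[h] * (R_binary * A[:, h])
# where m is layer_head_mask, A the attention weights and R_binary the
# binarised relation matrix, so masked heads attend only where a relation
# exists.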

"""
    def create_commonsense_mask(self, bsz, n_tokens, commonsense_matrix, num_heads=16, specific_head=0):
        commonsense_mask = torch.zeros(
            ((bsz, num_heads, n_tokens, n_tokens))
        )
        if commonsense_matrix is None:
            commonsense_matrix = torch.zeros(
                ((bsz, n_tokens, n_tokens))
            )
        commonsense_mask = commonsense_mask.reshape((num_heads, bsz, n_tokens, n_tokens))
        commonsense_mask[specific_head] = commonsense_matrix
        commonsense_mask = commonsense_mask.reshape((bsz, num_heads, n_tokens, n_tokens))
        return commonsense_mask

"""