import torch |
|
|
def find_head_to_mask(heads_mask) -> int:
    """Return the index of the attention head selected for masking."""
    # argmax returns a 0-dim tensor; .item() converts it to a plain Python int
    head_idx = torch.argmax(heads_mask)
    return head_idx.item()
|
|
|
def commonsense_attention_mask_update(bsz, n_tokens, commonsense_matrix, attn_weights,
                                      num_heads=16, specific_head=0):
    """Zero out one attention head and re-add its weights scaled by a commonsense matrix."""
    commonsense_mask = torch.zeros((bsz, num_heads, n_tokens, n_tokens))
    attn_weights_helper = attn_weights.reshape((num_heads, bsz, n_tokens, n_tokens))
    zeros = torch.zeros((bsz, n_tokens, n_tokens))
    # Clone the selected head's weights before clearing them: plain indexing returns a
    # view, which would otherwise be zeroed by the assignment below.
    head_previous_attention_weights = attn_weights_helper[specific_head].clone()
    attn_weights_helper[specific_head] = zeros
    attn_weights_helper = attn_weights_helper.reshape((bsz, num_heads, n_tokens, n_tokens))
    if commonsense_matrix is None:
        # No commonsense matrix passed: use ones, which are neutral under multiplication.
        commonsense_matrix = torch.ones((bsz, n_tokens, n_tokens))
    commonsense_mask = commonsense_mask.reshape((num_heads, bsz, n_tokens, n_tokens))
    commonsense_mask[specific_head] = head_previous_attention_weights * commonsense_matrix
    # Restore the (bsz, num_heads, n_tokens, n_tokens) layout on the same device as the weights.
    commonsense_mask = commonsense_mask.reshape((bsz, num_heads, n_tokens, n_tokens)).to(attn_weights.device)
    return attn_weights_helper + commonsense_mask
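

# Illustrative usage sketch (not part of the original module): the shapes and values
# below are assumptions chosen only to show how commonsense_attention_mask_update
# could be called on small dummy tensors.
def _example_commonsense_attention_mask_update():
    bsz, num_heads, n_tokens = 2, 16, 4
    attn_weights = torch.rand(bsz, num_heads, n_tokens, n_tokens)
    # Passing commonsense_matrix=None falls back to an all-ones (neutral) matrix.
    updated = commonsense_attention_mask_update(bsz, n_tokens, None, attn_weights,
                                                num_heads=num_heads, specific_head=0)
    assert updated.shape == (bsz, num_heads, n_tokens, n_tokens)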
|
|
|
def convert_relations_to_binary_mask(input_relations, should_clone=True):
    """Clamp a relation matrix to a binary {0, 1} mask."""
    relations_binary_mask = input_relations
    if should_clone:
        relations_binary_mask = input_relations.clone()
    relations_binary_mask[relations_binary_mask > 1] = 1
    return relations_binary_mask
|
|
|
def relation_binary_2d_to_1d(relations_binary_mask):
    """Reduce a 2-D binary relation mask to 1-D: 1 where a token has at least one relation."""
    relations_binary_mask = relations_binary_mask.sum(dim=1)
    relations_binary_mask[relations_binary_mask > 1] = 1
    return relations_binary_mask
|
|
|
def create_layer_with_commonsense_on_specific_head(relation_binary_mask, bsz, num_heads, specific_head=0):
    """Build a (bsz, num_heads, n_tokens, n_tokens) mask carrying the relation mask on one head only."""
    n_tokens = relation_binary_mask.size()[-1]
    relations_mask = torch.zeros((bsz, num_heads, n_tokens, n_tokens))
    layer = relations_mask.reshape((num_heads, bsz, n_tokens, n_tokens))
    layer[specific_head] = relation_binary_mask
    layer = layer.reshape((bsz, num_heads, n_tokens, n_tokens))
    return layer
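

# Illustrative usage sketch (not part of the original module): shows how a dummy
# relation matrix could be binarised and placed on a single head with the helpers
# above; all shapes and values here are assumptions for demonstration only.
def _example_relation_mask_on_specific_head():
    bsz, num_heads, n_tokens = 2, 16, 4
    # Relation ids greater than 1 are clamped to 1 by convert_relations_to_binary_mask.
    relations = torch.randint(0, 3, (bsz, n_tokens, n_tokens)).float()
    binary_mask = convert_relations_to_binary_mask(relations)
    layer = create_layer_with_commonsense_on_specific_head(binary_mask, bsz, num_heads, specific_head=0)
    assert layer.shape == (bsz, num_heads, n_tokens, n_tokens)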
|
|
|
def update_weights_regarding_relations_on_specific_head(layer_head_mask, attn_weights, relation_inputs,
                                                        bsz, num_heads, tgt_len, src_len, verbose=True):
    """Blend relation-aware weights into the heads selected by `layer_head_mask`.

    Heads with mask value 0 keep their original attention weights; heads with mask
    value 1 have their weights multiplied element-wise by the binary relation mask.
    """
    # Invert the head mask (1 -> 0, 0 -> 1) so it selects the heads that stay unchanged.
    inverse_layer_head_mask = (layer_head_mask.view(num_heads, 1, 1) - 1) * -1

    if verbose:
        print("==============================")
        print('layer_head_mask.shape:', layer_head_mask.shape)
        print('inverse_layer_head_mask.shape:', inverse_layer_head_mask.shape)
        print('attn_weights.shape:', attn_weights.shape)
        print('relation_inputs.shape:', relation_inputs.shape)
        print("==============================")

    # Original weights for the heads that are left untouched.
    intermediate_weights = inverse_layer_head_mask * attn_weights.view(bsz, num_heads, tgt_len, src_len)
    relation_inputs = convert_relations_to_binary_mask(relation_inputs, should_clone=False)
    # Relation-scaled weights for the selected heads.
    relation_weights = (layer_head_mask.view(num_heads, 1, 1)
                        * relation_inputs.view(bsz, 1, tgt_len, src_len)
                        * attn_weights.view(bsz, num_heads, tgt_len, src_len))
    attn_weights = intermediate_weights + relation_weights

    if verbose:
        print('attn_weights.shape:', attn_weights.shape)
    return attn_weights
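

# Illustrative usage sketch (not part of the original module): applies the relation-aware
# update to dummy attention weights with a single selected head; shapes are assumptions.
def _example_update_weights_on_specific_head():
    bsz, num_heads, tgt_len, src_len = 2, 16, 4, 4
    attn_weights = torch.rand(bsz * num_heads, tgt_len, src_len)
    relation_inputs = torch.randint(0, 2, (bsz, tgt_len, src_len)).float()
    # Head mask selecting only head 0 for the relation-based update.
    layer_head_mask = torch.zeros(num_heads)
    layer_head_mask[0] = 1.0
    updated = update_weights_regarding_relations_on_specific_head(
        layer_head_mask, attn_weights, relation_inputs, bsz, num_heads,
        tgt_len, src_len, verbose=False)
    assert updated.shape == (bsz, num_heads, tgt_len, src_len)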
|
|

"""
Method (self-based) variants kept here for reference:

def create_commonsense_mask(self, bsz, n_tokens, commonsense_matrix, num_heads=16, specific_head=0):
    commonsense_mask = torch.zeros((bsz, num_heads, n_tokens, n_tokens))
    if commonsense_matrix is None:
        commonsense_matrix = torch.zeros((bsz, n_tokens, n_tokens))
    commonsense_mask = commonsense_mask.reshape((num_heads, bsz, n_tokens, n_tokens))
    commonsense_mask[specific_head] = commonsense_matrix
    commonsense_mask = commonsense_mask.reshape((bsz, num_heads, n_tokens, n_tokens))
    return commonsense_mask


def commonsense_attention_mask_update(self, bsz, n_tokens, commonsense_matrix, attn_weights,
                                      specific_head=0):
    num_heads = self.num_heads
    commonsense_mask = torch.zeros((bsz, num_heads, n_tokens, n_tokens))
    attn_weights_helper = attn_weights.reshape((num_heads, bsz, n_tokens, n_tokens))
    zeros = torch.zeros((bsz, n_tokens, n_tokens))
    head_previous_attention_weights = attn_weights_helper[specific_head]
    attn_weights_helper[specific_head] = zeros
    attn_weights_helper = attn_weights_helper.reshape((bsz, num_heads, n_tokens, n_tokens))
    if commonsense_matrix is None:
        # ignore if not passed (ones -> neutral since multiplication is used)
        commonsense_matrix = torch.ones((bsz, n_tokens, n_tokens))
    commonsense_mask = commonsense_mask.reshape((num_heads, bsz, n_tokens, n_tokens))
    commonsense_mask[specific_head] = head_previous_attention_weights * commonsense_matrix
    # TODO Stupid conversion
    commonsense_mask = commonsense_mask.reshape((bsz, num_heads, n_tokens, n_tokens)).to('cuda')
    return attn_weights_helper + commonsense_mask
"""