# RA-BART / model_utils.py
# Commit 6cf191b by MrVicente: "added demo base code"
#############################
# Imports
#############################
# Python modules
from typing import List
from random import randint
# Remote modules
import torch
# Local modules
from utils import Head_Mask
#############################
# Constants
#############################
#############################
# Head mask helpers
#############################
def create_layers_head_mask(config, head_mask_type: Head_Mask=Head_Mask.ALL, specific_heads: List[int] = None):
    """Build an encoder attention head mask and return it as a nested list.

    The mask starts as zeros with one row per encoder layer and one column
    per attention head, and is then filled according to ``head_mask_type``:

    * ``Head_Mask.RANDOM``   -- one randomly chosen head per layer is set to 1.
    * ``Head_Mask.NONE``     -- every entry is set to 1.
    * ``Head_Mask.ALL``      -- the mask is left as all zeros.
    * ``Head_Mask.SPECIFIC`` -- if ``specific_heads`` is non-empty, entry
      ``specific_heads[layer] - 1`` is set to 1 in each layer (the values are
      treated as 1-based head numbers); otherwise a fixed 12 x 16 table is
      returned, independent of the sizes in ``config``.

    NOTE(review): presumably 1 means "head kept" and 0 means "head masked",
    as in the Hugging Face ``head_mask`` convention -- confirm against the
    caller.

    Args:
        config: object exposing ``encoder_layers`` and
            ``encoder_attention_heads``.
        head_mask_type: which masking strategy to apply.
        specific_heads: one head number per encoder layer; only read when
            ``head_mask_type`` is ``Head_Mask.SPECIFIC``.

    Returns:
        The mask converted with ``Tensor.tolist()`` (a list of per-layer lists).

    Raises:
        NotImplementedError: if ``head_mask_type`` is none of the four modes.
    """
    num_layers = config.encoder_layers
    num_heads = config.encoder_attention_heads
    mask = torch.zeros((num_layers, num_heads))

    if head_mask_type == Head_Mask.RANDOM:
        # One draw per layer, in layer order.
        for layer_idx in range(num_layers):
            mask[layer_idx, randint(0, num_heads - 1)] = 1
        return mask.tolist()

    if head_mask_type == Head_Mask.NONE:
        mask[:] = 1
        return mask.tolist()

    if head_mask_type == Head_Mask.ALL:
        # Nothing to fill in: the all-zero mask is the result.
        return mask.tolist()

    if head_mask_type == Head_Mask.SPECIFIC:
        if not specific_heads:
            # Fixed fallback table (12 rows x 16 columns); it does not depend
            # on the layer/head counts in ``config``.
            fallback_rows = [
                [0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0],
                [1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0],
                [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0],
                [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
                [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1],
                [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
                [0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0],
                [0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1],
                [0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1],
                [0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1],
                [0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1],
            ]
            return torch.Tensor(fallback_rows).tolist()

        for layer_idx in range(num_layers):
            # Entries are shifted down by one before indexing.
            head_idx = specific_heads[layer_idx] - 1
            mask[layer_idx, head_idx] = 1
        return mask.tolist()

    raise NotImplementedError()