|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import List |
|
from random import randint |
|
|
|
|
|
import torch |
|
|
|
|
|
from utils import Head_Mask |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_layers_head_mask(config, head_mask_type: Head_Mask = Head_Mask.ALL, specific_heads: List[int] = None):
    """Build an encoder head mask as a nested list of floats.

    The result is produced by ``Tensor.tolist()``, so it is a list with one
    row per encoder layer, each row holding one float (``0.0`` or ``1.0``)
    per attention head.

    NOTE(review): presumably ``1`` means "keep this head" and ``0`` means
    "mask it out" (the usual ``head_mask`` convention) -- confirm against
    the code that consumes the returned mask.

    Args:
        config: Object exposing ``encoder_layers`` and
            ``encoder_attention_heads``; these give the mask dimensions.
        head_mask_type: Selects how the mask is filled:

            * ``Head_Mask.RANDOM``   -- one randomly chosen head per layer is
              set to 1, every other entry stays 0.
            * ``Head_Mask.NONE``     -- every entry is 1.
            * ``Head_Mask.ALL``      -- every entry is 0.
            * ``Head_Mask.SPECIFIC`` -- see ``specific_heads``.
        specific_heads: Only used with ``Head_Mask.SPECIFIC``. One 1-based
            head number per encoder layer; that head is set to 1 in its
            layer. Entries beyond the number of layers are ignored. When
            omitted or empty, a fixed built-in 12 x 16 table is returned
            instead.

    Returns:
        The mask as ``list[list[float]]``.

    Raises:
        IndexError: With ``Head_Mask.SPECIFIC``, if ``specific_heads`` has
            fewer entries than there are encoder layers, or if a head number
            lies outside ``1..config.encoder_attention_heads``.
        NotImplementedError: If ``head_mask_type`` is not one of the four
            handled values.
    """
    num_layers = config.encoder_layers
    num_heads = config.encoder_attention_heads
    mask_heads = torch.zeros((num_layers, num_heads))

    if head_mask_type == Head_Mask.RANDOM:
        for layer_i in range(num_layers):
            # randint is inclusive on both ends, hence the -1.
            mask_heads[layer_i, randint(0, num_heads - 1)] = 1
    elif head_mask_type == Head_Mask.NONE:
        mask_heads[:, :] = 1
    elif head_mask_type == Head_Mask.ALL:
        # The mask is already all zeros.
        pass
    elif head_mask_type == Head_Mask.SPECIFIC:
        if specific_heads:
            if len(specific_heads) < num_layers:
                raise IndexError(
                    f"specific_heads has {len(specific_heads)} entries but "
                    f"{num_layers} encoder layers need one head each"
                )
            for layer_i in range(num_layers):
                head_number = specific_heads[layer_i]
                # Head numbers are 1-based. Without this check a value of 0
                # would become index -1 and silently select the last head.
                if not 1 <= head_number <= num_heads:
                    raise IndexError(
                        f"specific_heads[{layer_i}]={head_number} is outside "
                        f"the valid 1-based range 1..{num_heads}"
                    )
                mask_heads[layer_i][head_number - 1] = 1
        else:
            # Fixed fallback table (12 layers x 16 heads).
            # NOTE(review): this table is returned regardless of the
            # dimensions in `config` -- confirm it is only used with a
            # 12-layer / 16-head encoder.
            mask_heads = torch.Tensor([
                [0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0],
                [1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0],
                [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0],
                [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
                [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1],
                [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
                [0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0],
                [0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1],
                [0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1],
                [0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1],
                [0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1],
            ])
    else:
        raise NotImplementedError()

    return mask_heads.tolist()