qarac/models/layers/GlobalAttentionPoolingHead.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Sep 5 07:32:55 2023
@author: peter
"""
import torch
import qarac.models.layers.FactorizedMatrixMultiplication
class GlobalAttentionPoolingHead(torch.nn.Module):

    def __init__(self, config):
        """
        Creates the layer

        Parameters
        ----------
        config : transformers.RobertaConfig
            the configuration of the model

        Returns
        -------
        None.

        """
        size = config.hidden_size
        super(GlobalAttentionPoolingHead, self).__init__()
        self.global_projection = qarac.models.layers.FactorizedMatrixMultiplication.FactorizedMatrixMultiplication(size)
        self.local_projection = qarac.models.layers.FactorizedMatrixMultiplication.FactorizedMatrixMultiplication(size)
        self.cosine = torch.nn.CosineSimilarity(dim=2, eps=1.0e-12)
    def forward(self, X, attention_mask=None):
        """
        Parameters
        ----------
        X : torch.Tensor
            Base model vectors to apply pooling to.
        attention_mask : torch.Tensor, optional
            mask for pad values

        Returns
        -------
        torch.Tensor
            The pooled value.

        """
        if attention_mask is None:
            attention_mask = torch.ones_like(X)
        else:
            # (batch, seq_len) -> (batch, seq_len, 1) so the mask broadcasts
            # over the hidden dimension
            attention_mask = attention_mask.unsqueeze(2)
        # Zero out the vectors at padded positions
        Xa = X * attention_mask
        # Global summary vector: sum of token vectors over the sequence dimension
        sigma = torch.sum(Xa, dim=1, keepdim=True)
        # Project the global summary and the individual token vectors
        gp = self.global_projection(sigma)
        lp = self.local_projection(Xa)
        # Attention weight for each token: cosine similarity between its
        # projected vector and the projected global summary
        attention = self.cosine(lp, gp)
        # Pooled output: attention-weighted sum of the masked token vectors
        return torch.einsum('ij,ijk->ik', attention, Xa)
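

# --- Minimal usage sketch (illustrative, not part of the original module) ---
# Assumes the `transformers` package is installed for RobertaConfig and that
# qarac.models.layers.FactorizedMatrixMultiplication is importable from this
# repo. Batch size, sequence length and the padding pattern below are
# arbitrary example values.
if __name__ == '__main__':
    import transformers

    config = transformers.RobertaConfig(hidden_size=768)
    head = GlobalAttentionPoolingHead(config)

    X = torch.randn(2, 5, config.hidden_size)   # (batch, seq_len, hidden)
    mask = torch.ones(2, 5)                     # 1 = real token, 0 = padding
    mask[1, 3:] = 0.0                           # pretend sequence 1 is padded

    pooled = head(X, attention_mask=mask)
    print(pooled.shape)                         # expected: torch.Size([2, 768])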