Spaces:
Build error
Build error
File size: 1,708 Bytes
8f1745b 32df2f1 4a7707c 8f1745b 210f1cb 32df2f1 c284c9a 32df2f1 f16a715 32df2f1 f16a715 32df2f1 4a7707c 798488e 32df2f1 8f1745b 32df2f1 f16a715 32df2f1 f16a715 b2593fa 32df2f1 f16a715 32df2f1 f16a715 b2593fa 32df2f1 bc77ce5 32df2f1 798488e 32df2f1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Sep 5 07:32:55 2023
@author: peter
"""
import torch
import qarac.models.layers.FactorizedMatrixMultiplication
class GlobalAttentionPoolingHead(torch.nn.Module):
    """
    Pool a sequence of vectors into a single vector per batch item.

    Each position is weighted by the cosine similarity between its local
    projection and a global projection of the (masked) sum of all
    positions; the weighted positions are then summed.
    """

    def __init__(self, config):
        """
        Creates the layer

        Parameters
        ----------
        config : transformers.RobertaConfig
            the configuration of the model; only ``hidden_size`` is read,
            and it is used as the size of both projections.

        Returns
        -------
        None.
        """
        super().__init__()
        size = config.hidden_size
        self.global_projection = qarac.models.layers.FactorizedMatrixMultiplication.FactorizedMatrixMultiplication(size)
        self.local_projection = qarac.models.layers.FactorizedMatrixMultiplication.FactorizedMatrixMultiplication(size)
        # dim=2: compare vectors along the last (feature) axis of the
        # rank-3 projections used in forward().
        self.cosine = torch.nn.CosineSimilarity(dim=2, eps=1.0e-12)

    def forward(self, X, attention_mask=None):
        """
        Apply global attention pooling.

        Parameters
        ----------
        X : torch.Tensor
            Base model vectors to apply pooling to. Must be rank 3; axis 1
            is the axis pooled over (presumably (batch, sequence, hidden)
            -- confirm against caller).
        attention_mask : torch.Tensor, optional
            Rank-2 mask matching the first two axes of ``X``; it multiplies
            ``X`` position-wise, so positions where it is 0 are zeroed
            before pooling. If None, every position is used unchanged.

        Returns
        -------
        torch.Tensor
            The pooled value, of shape ``(X.shape[0], X.shape[2])``.
        """
        if attention_mask is None:
            # Without a mask the multiply would be by ones; skip it instead
            # of allocating a full-size ones tensor.
            Xa = X
        else:
            # Add a trailing axis so the mask broadcasts over features.
            Xa = X * attention_mask.unsqueeze(2)
        # Global summary: sum over the pooled axis, keeping it as a
        # length-1 axis so it broadcasts against every position below.
        sigma = torch.sum(Xa, dim=1, keepdim=True)
        gp = self.global_projection(sigma)
        lp = self.local_projection(Xa)
        # One weight per position: cosine similarity between that
        # position's local projection and the global projection.
        attention = self.cosine(lp, gp)
        # Weighted sum of the (masked) positions over axis 1.
        return torch.einsum('ij,ijk->ik', attention, Xa)