#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Sep  5 07:32:55 2023

@author: peter
"""

import torch
import qarac.models.layers.FactorizedMatrixMultiplication


class GlobalAttentionPoolingHead(torch.nn.Module):
    """
    Attention-based pooling head.

    Weights each token vector by the cosine similarity between its local
    projection and a projection of the summed (global) sequence vector,
    then returns the weighted sum as the pooled output.
    """

    def __init__(self, config):
        """
        Creates the layer
        Parameters
        ----------
        config : transformers.RobertaConfig
                 the configuration of the model

        Returns
        -------
        None.

        """
        super().__init__()
        size = config.hidden_size
        self.global_projection = qarac.models.layers.FactorizedMatrixMultiplication.FactorizedMatrixMultiplication(size)
        self.local_projection = qarac.models.layers.FactorizedMatrixMultiplication.FactorizedMatrixMultiplication(size)
        self.cosine = torch.nn.CosineSimilarity(dim=2, eps=1.0e-12)

    def forward(self, X, attention_mask=None):
        """
        Pools the base model vectors into a single vector per sequence.

        Parameters
        ----------
        X : torch.Tensor
            Base model vectors to apply pooling to.
        attention_mask : torch.Tensor, optional
            Mask for pad values.

        Returns
        -------
        torch.Tensor
            The pooled value.

        """
        if attention_mask is None:
            # No mask supplied: treat every position as valid.
            attention_mask = torch.ones_like(X)
        else:
            # Broadcast the (batch, seq) mask over the hidden dimension.
            attention_mask = attention_mask.unsqueeze(2)
        # Zero out padded positions before pooling.
        Xa = X * attention_mask
        # Global summary vector: sum of the unmasked token vectors.
        sigma = torch.sum(Xa, dim=1, keepdim=True)
        gp = self.global_projection(sigma)
        lp = self.local_projection(Xa)
        # Attention weight for each token: cosine similarity between its
        # projected vector and the projected global summary.
        attention = self.cosine(lp, gp)
        # Weighted sum of the token vectors gives the pooled representation.
        return torch.einsum('ij,ijk->ik', attention, Xa)
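

# The block below is a minimal usage sketch, not part of the original module.
# It assumes a config object exposing `hidden_size` (as transformers.RobertaConfig
# does; a SimpleNamespace stands in for it here), and it assumes the qarac
# FactorizedMatrixMultiplication layer preserves the (..., hidden) shape, as the
# forward() code above implies. Shapes follow the forward() contract:
# X is (batch, seq, hidden) and attention_mask is (batch, seq).
if __name__ == '__main__':
    import types

    config = types.SimpleNamespace(hidden_size=768)   # stand-in for a RobertaConfig
    head = GlobalAttentionPoolingHead(config)
    X = torch.randn(2, 16, config.hidden_size)        # fake base-model outputs
    attention_mask = torch.ones(2, 16)                # no padding in this example
    pooled = head(X, attention_mask)
    print(pooled.shape)                               # expected: torch.Size([2, 768])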