PeteBleackley committed on
Commit 32df2f1 · 1 Parent(s): ac98be7

Converted GlobalAttentionPoolingHead to use PyTorch

qarac/models/layers/GlobalAttentionPoolingHead.py CHANGED
@@ -6,90 +6,59 @@ Created on Tue Sep 5 07:32:55 2023
 @author: peter
 """
 
-import keras
-import tensorflow
+import torch
 
+EPSILON = 1.0e-12
 
-@tensorflow.function
-def dot_prod(vectors):
-    (x,y) = vectors
-    return tensorflow.tensordot(x,y,axes=1)
-
-
-class GlobalAttentionPoolingHead(keras.layers.Layer):
+class GlobalAttentionPoolingHead(torch.nn.Module):
     
-    def __init__(self):
+    def __init__(self,config):
         """
         Creates the layer
-
-        Returns
-        -------
-        None.
-
-        """
-        super(GlobalAttentionPoolingHead,self).__init__()
-        self.global_projection = None
-        self.local_projection = None
-        
-        
-    def build(self,input_shape):
-        """
-        Initialises layer weights
-
         Parameters
         ----------
-        input_shape : tuple
-            Shape of the input layer
+        config : transformers.RobertaConfig
+            the configuration of the model
 
         Returns
        -------
         None.
 
         """
-        width = input_shape[-1]
-        self.global_projection = self.add_weight('global projection',
-                                                 shape=(width,width),
-                                                 trainable=True)
-        self.local_projection = self.add_weight('local projection',
-                                                shape=(width,width),
-                                                trainable=True)
-        self.built=True
-        
-    @tensorflow.function
-    def project_local(self,X):
-        return tensorflow.tensordot(X,
-                                    self.local_projection,
-                                    axes=1)
+        size = config.hidden_size
+        super(GlobalAttentionPoolingHead,self).__init__()
+        self.global_projection = torch.nn.Linear(size,size,bias=False)
+        self.local_projection = torch.nn.Linear(size,size,bias=False)
+        
+        
     
-    def call(self,X,attention_mask=None,training=None):
+    def forward(self,X,attention_mask=None):
         """
         
 
         Parameters
         ----------
-        X : tensorflow.Tensor
+        X : torch.Tensor
             Base model vectors to apply pooling to.
         attention_mask: tensorflow.Tensor, optional
             mask for pad values
-        training : bool, optional
-            Not used. The default is None.
+
 
         Returns
         -------
-        tensorflow.Tensor
+        torch.Tensor
             The pooled value.
 
         """
-        gp = tensorflow.linalg.l2_normalize(tensorflow.tensordot(tensorflow.reduce_sum(X,
-                                                                                        axis=1),
-                                                                 self.global_projection,
-                                                                 axes=1),
-                                            axis=1)
-        lp = tensorflow.linalg.l2_normalize(tensorflow.vectorized_map(self.project_local,
-                                                                      X),
-                                            axis=2)
-        attention = tensorflow.vectorized_map(dot_prod,(lp,gp))
         if attention_mask is None:
-            attention_mask = tensorflow.ones_like(attention)
-        return tensorflow.vectorized_map(dot_prod,
-                                         (attention * attention_mask,X))
+            attention_mask = torch.ones_like(X)
+        Xa = X*attention_mask
+        sigma = torch.sum(Xa,dim=1)
+        psigma = self.global_projection(sigma)
+        nsigma = torch.max(torch.linalg.vector_norm(psigma,dim=1),EPSILON)
+        gp = psigma/nsigma
+        loc = self.local_projection(Xa)
+        nloc = torch.max(torch.linalg.vector_norm(loc,dim=2),EPSILON)
+        lp = loc/nloc
+        attention = torch.einsum('ijk,k->ij',lp,gp)
+        return torch.einsum('ij,ijk->ik',attention,Xa)
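The converted forward pass projects the summed ("global") sequence vector and each token's ("local") vector, L2-normalises both with an EPSILON floor, uses their dot products as attention weights, and returns the attention-weighted sum of the masked token vectors. For reference, a minimal standalone sketch of that computation follows, assuming hidden states of shape (batch, seq_len, hidden_size) and a 0/1 attention mask of shape (batch, seq_len); the class name AttentionPoolSketch, the plain hidden_size argument, and the keepdim/clamp shape handling are illustrative assumptions, not part of the commit.

import torch

class AttentionPoolSketch(torch.nn.Module):
    """Illustrative sketch of the global attention pooling above (not the committed code)."""

    def __init__(self, hidden_size, epsilon=1.0e-12):
        super().__init__()
        self.global_projection = torch.nn.Linear(hidden_size, hidden_size, bias=False)
        self.local_projection = torch.nn.Linear(hidden_size, hidden_size, bias=False)
        self.epsilon = epsilon

    def forward(self, X, attention_mask=None):
        # X: (batch, seq_len, hidden_size); attention_mask: (batch, seq_len) of 0/1
        if attention_mask is None:
            attention_mask = torch.ones(X.shape[:2], dtype=X.dtype, device=X.device)
        Xa = X * attention_mask.unsqueeze(-1)             # zero out padded positions
        gp = self.global_projection(Xa.sum(dim=1))        # (batch, hidden_size)
        gp = gp / torch.linalg.vector_norm(gp, dim=-1, keepdim=True).clamp(min=self.epsilon)
        lp = self.local_projection(Xa)                    # (batch, seq_len, hidden_size)
        lp = lp / torch.linalg.vector_norm(lp, dim=-1, keepdim=True).clamp(min=self.epsilon)
        attention = torch.einsum('bsh,bh->bs', lp, gp)    # per-token similarity to the global vector
        return torch.einsum('bs,bsh->bh', attention, Xa)  # attention-weighted sum over the sequence

Used as pooled = AttentionPoolSketch(config.hidden_size)(hidden_states, attention_mask), this yields one pooled vector per sequence, which is the role the committed head plays on top of the base model's hidden states.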