Build error
Commit 32df2f1 · PeteBleackley committed
Parent(s): ac98be7

Converted GlobalAttentionPoolingHead to use PyTorch
qarac/models/layers/GlobalAttentionPoolingHead.py
CHANGED
@@ -6,90 +6,59 @@ Created on Tue Sep 5 07:32:55 2023
 @author: peter
 """
 
-import keras
-import tensorflow
-
-
-def dot_prod(vectors):
-    (x,y) = vectors
-    return tensorflow.tensordot(x,y,axes=1)
-
-
-class GlobalAttentionPoolingHead(keras.layers.Layer):
-
-    def __init__(self):
+import torch
+
+EPSILON = 1.0e-12
+
+class GlobalAttentionPoolingHead(torch.nn.Module):
+
+    def __init__(self,config):
         """
         Creates the layer
-
-        Returns
-        -------
-        None.
-
-        """
-        super(GlobalAttentionPoolingHead,self).__init__()
-        self.global_projection = None
-        self.local_projection = None
-
-
-    def build(self,input_shape):
-        """
-        Initialises layer weights
 
         Parameters
         ----------
+        config : transformers.RobertaConfig
+            the configuration of the model
 
         Returns
        -------
         None.
 
         """
-        self.
-                                                 shape=(width,width),
-                                                 trainable=True)
-        self.built=True
+        size = config.hidden_size
+        super(GlobalAttentionPoolingHead,self).__init__()
+        self.global_projection = torch.nn.Linear(size,size,bias=False)
+        self.local_projection = torch.nn.Linear(size,size,bias=False)
 
-    @tensorflow.function
-    def project_local(self,X):
-        return tensorflow.tensordot(X,
-                                    self.local_projection,
-                                    axes=1)
 
-    def call(self,X,attention_mask=None):
+    def forward(self,X,attention_mask=None):
         """
 
 
         Parameters
         ----------
-        X : tensorflow.Tensor
+        X : torch.Tensor
             Base model vectors to apply pooling to.
         attention_mask: tensorflow.Tensor, optional
             mask for pad values
-
-            Not used. The default is None.
 
         Returns
         -------
+        torch.Tensor
             The pooled value.
 
         """
-        gp = tensorflow.linalg.l2_normalize(tensorflow.tensordot(tensorflow.reduce_sum(X,
-                                                                                        axis=1),
-                                                                  self.global_projection,
-                                                                  axes=1),
-                                            axis=1)
-        lp = tensorflow.linalg.l2_normalize(tensorflow.vectorized_map(self.project_local,
-                                                                      X),
-                                            axis=2)
-        attention = tensorflow.vectorized_map(dot_prod,(lp,gp))
         if attention_mask is None:
-            attention_mask =
+            attention_mask = torch.ones_like(X)
+        Xa = X*attention_mask
+        sigma = torch.sum(Xa,dim=1)
+        psigma = self.global_projection(sigma)
+        nsigma = torch.max(torch.linalg.vector_norm(psigma,dim=1),EPSILON)
+        gp = psigma/nsigma
+        loc = self.local_projection(Xa)
+        nloc = torch.max(torch.linalg.vector_norm(loc,dim=2),EPSILON)
+        lp = loc/nloc
+        attention = torch.einsum('ijk,k->ij',lp,gp)
+        return torch.einsum('ij,ijk->ik',attention,Xa)
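
For readers following the conversion, the sketch below restates what the new forward pass computes, with explicit tensor shapes: the masked token vectors are summed and projected into a single L2-normalised "global" vector, each token vector is projected and L2-normalised "locally", each token's attention weight is its dot product with the global vector, and the output is the attention-weighted sum of the masked token vectors. This is a hedged illustration, not the committed code: the class name, the (batch, seq_len) mask convention, the clamp/keepdim normalisation, and the batched einsum subscripts ('ijk,ik->ij') are assumptions chosen so the sketch runs end to end.

# A minimal, self-contained sketch of the pooling computation, assuming
# X has shape (batch, seq_len, hidden) and attention_mask, when given,
# has shape (batch, seq_len) with 1 for real tokens and 0 for padding.
# Class and variable names here are illustrative, not from the repository.
import torch
import transformers


class GlobalAttentionPoolingSketch(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        size = config.hidden_size
        self.global_projection = torch.nn.Linear(size, size, bias=False)
        self.local_projection = torch.nn.Linear(size, size, bias=False)
        self.epsilon = 1.0e-12

    def forward(self, X, attention_mask=None):
        if attention_mask is None:
            attention_mask = torch.ones(X.shape[:2], dtype=X.dtype, device=X.device)
        # Zero out padded positions before any pooling.
        Xa = X * attention_mask.unsqueeze(-1)                      # (batch, seq, hidden)
        # Global summary: project the sum of token vectors, then L2-normalise.
        sigma = Xa.sum(dim=1)                                      # (batch, hidden)
        psigma = self.global_projection(sigma)
        gp = psigma / torch.linalg.vector_norm(
            psigma, dim=1, keepdim=True).clamp(min=self.epsilon)
        # Local view: project and L2-normalise every token vector.
        loc = self.local_projection(Xa)                            # (batch, seq, hidden)
        lp = loc / torch.linalg.vector_norm(
            loc, dim=2, keepdim=True).clamp(min=self.epsilon)
        # Attention weight per token = dot product with the global vector;
        # the pooled output is the attention-weighted sum of masked vectors.
        attention = torch.einsum('ijk,ik->ij', lp, gp)             # (batch, seq)
        return torch.einsum('ij,ijk->ik', attention, Xa)           # (batch, hidden)


# Example usage with an arbitrary configuration (sizes are made up):
config = transformers.RobertaConfig(hidden_size=768)
head = GlobalAttentionPoolingSketch(config)
hidden_states = torch.randn(2, 16, config.hidden_size)
mask = torch.ones(2, 16)
print(head(hidden_states, mask).shape)  # torch.Size([2, 768])

The committed version instead normalises with torch.max against the scalar EPSILON and divides by norms taken without keepdim, and it contracts the global vector with 'ijk,k->ij'; if those lines do not broadcast as intended for a given PyTorch version or input shape, the clamp(min=...), keepdim=True, and batched-einsum pattern above is the usual safe alternative.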