Spaces: Build error
Converted QaracDecoderModel to use PyTorch
PeteBleackley committed · Commit 13f1508
1 Parent(s): 37a581e
qarac/models/QaracDecoderModel.py
CHANGED
@@ -6,11 +6,10 @@ Created on Tue Sep 5 10:29:03 2023
 @author: peter
 """

-import keras
-import tensorflow
+import torch
 import transformers

-class QaracDecoderHead(keras.layers.Layer):
+class QaracDecoderHead(torch.nn.Module):

     def __init__(self,config,input_embeddings):
         """
@@ -27,32 +26,16 @@ class QaracDecoderHead(keras.layers.Layer):

         """
         super(QaracDecoderHead,self).__init__()
-        self.
-        self.
-        self.
-
-                                                                               input_embeddings)
+        self.layer_0 = transformers.models.roberta.modeling_roberta.RobertaLayer(config)
+        self.layer_1 = transformers.models.roberta.modeling_roberta.RobertaLayer(config)
+        self.head = transformers.models.roberta.modeling_roberta.RobertaLMHead(config,
+                                                                               input_embeddings)

-    def build(self,input_shape):
-        """
-
-
-        Parameters
-        ----------
-        input_shape : tuple
-            Input shape.
-
-        Returns
-        -------
-        None.

-        """
-        self.built = True
-



-    def call(self,
+    def forward(self,
                 vector,
                 hidden_states,
                 attention_mask=None,training=False):
@@ -66,12 +49,13 @@ class QaracDecoderHead(keras.layers.Layer):

         Returns
         -------
-        transformers.
+        transformers.modeling_outputs.CausalLMOutputWithCrossAttentions
             Predicted text

         """
-        vectors =
-
+        vectors = torch.cat([vector, hidden_states],
+                            dim=1)
+        attentions = attention_mask if attention_mask is None else torch.cat([torch.ones((hidden_states.shape(0),
                                                                                           1)),
                                                                              attention_mask])
         l0 = self.layer_0(vectors,
@@ -91,7 +75,7 @@ class QaracDecoderHead(keras.layers.Layer):
                            False,
                            training)[0])

-class QaracDecoderModel(transformers.TFPreTrainedModel,transformers.generation_tf_utils.TFGenerationMixin):
+class QaracDecoderModel(transformers.PreTrainedModel,transformers.generation_utils.TFGenerationMixin):

     def __init__(self,base_model,tokenizer):
         """
@@ -112,11 +96,9 @@ class QaracDecoderModel(transformers.TFPreTrainedModel,transformers.generation_t
         self.decoder_head = QaracDecoderHead(self.base_model.config,
                                              self.base_model.roberta.get_input_embeddings())
         self.tokenizer = tokenizer
-
-        self.end=None
-        self.pad=None
+

-    def call(self,inputs,**kwargs):
+    def forward(self,inputs,**kwargs):
         """
         Predicts text from inputs

@@ -130,13 +112,13 @@ class QaracDecoderModel(transformers.TFPreTrainedModel,transformers.generation_t

         Returns
         -------
-        transformers.
+        transformers.modeling_outputs.CausalLMOutputWithCrossAttentions
             Predicted text

         """
         (v,s) = (kwargs['vector'],inputs) if 'vector' in kwargs else inputs

-        return self.decoder_head(
+        return self.decoder_head(torch.unsqueeze(v,1),
                                  self.base_model(s).last_hidden_state,
                                  training = kwargs.get('training',False))

@@ -145,7 +127,7 @@ class QaracDecoderModel(transformers.TFPreTrainedModel,transformers.generation_t
                                        attention_mask=None,
                                        **kwargs):
         if attention_mask is None:
-            attention_mask =
+            attention_mask = torch.ones_like(input_ids)
         return {'input_ids':input_ids,
                 'attention_mask':attention_mask}
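For readers skimming the hunks above: the new `forward` prepends a single summary vector to the token hidden states and, when a mask is supplied, extends it by one position to cover that extra token. Below is a minimal standalone sketch of that apparent intent; it is an assumption about intended behaviour rather than the committed code, which as written would raise a `TypeError` because `hidden_states.shape(0)` treats `Tensor.shape` (a `torch.Size`) as callable, and whose mask concatenation omits `dim=1`.

```python
import torch

def prepend_vector(vector, hidden_states, attention_mask=None):
    """Sketch of the concatenation QaracDecoderHead.forward appears to intend.

    vector:         (batch, 1, hidden) -- already unsqueezed, as in QaracDecoderModel.forward
    hidden_states:  (batch, seq, hidden)
    attention_mask: (batch, seq) or None
    """
    vectors = torch.cat([vector, hidden_states], dim=1)        # (batch, seq + 1, hidden)
    if attention_mask is None:
        attentions = None                                      # no mask supplied: pass None through
    else:
        ones = torch.ones((hidden_states.shape[0], 1),         # one extra mask position for the vector
                          dtype=attention_mask.dtype,
                          device=attention_mask.device)
        attentions = torch.cat([ones, attention_mask], dim=1)  # (batch, seq + 1)
    return vectors, attentions
```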
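As the new `forward` reads, the model accepts either a `(vector, input_ids)` tuple or `input_ids` plus a `vector=` keyword, via `(v,s) = (kwargs['vector'],inputs) if 'vector' in kwargs else inputs`. The sketch below illustrates those two calling conventions and the default mask built by `prepare_inputs_for_generation`; `decoder` and `question_vector` are hypothetical names for illustration only, not objects defined in this commit, so those calls are left commented out.

```python
import torch

# Hypothetical objects, not part of the commit:
# decoder = QaracDecoderModel(base_model, tokenizer)
# question_vector = encoder(question_ids)              # shape (batch, hidden)

input_ids = torch.tensor([[0, 713, 16, 10, 1296, 2]])  # an arbitrary batch of token ids

# The two calling conventions forward() accepts:
# out = decoder(input_ids, vector=question_vector)     # keyword form
# out = decoder((question_vector, input_ids))          # tuple form
# torch.unsqueeze(v, 1) then turns (batch, hidden) into (batch, 1, hidden)
# before it is prepended to the base model's hidden states.

# Default mask substituted by prepare_inputs_for_generation when none is passed:
attention_mask = torch.ones_like(input_ids)
print({'input_ids': input_ids, 'attention_mask': attention_mask})
```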