PeteBleackley committed
Commit 13f1508 · 1 Parent(s): 37a581e

Converted QaracDecoderModel to use PyTorch

Files changed (1)
  1. qarac/models/QaracDecoderModel.py +17 -35
qarac/models/QaracDecoderModel.py CHANGED
@@ -6,11 +6,10 @@ Created on Tue Sep 5 10:29:03 2023
 @author: peter
 """
 
-import keras
-import tensorflow
+import torch
 import transformers
 
-class QaracDecoderHead(keras.layers.Layer):
+class QaracDecoderHead(torch.nn.Module):
 
     def __init__(self,config,input_embeddings):
         """
@@ -27,32 +26,16 @@ class QaracDecoderHead(keras.layers.Layer):
 
         """
         super(QaracDecoderHead,self).__init__()
-        self.concat = keras.layers.Concatenate(axis=1)
-        self.layer_0 = transformers.models.roberta.modeling_tf_roberta.TFRobertaLayer(config)
-        self.layer_1 = transformers.models.roberta.modeling_tf_roberta.TFRobertaLayer(config)
-        self.head = transformers.models.roberta.modeling_tf_roberta.TFRobertaLMHead(config,
-                                                                                    input_embeddings)
+        self.layer_0 = transformers.models.roberta.modeling_roberta.RobertaLayer(config)
+        self.layer_1 = transformers.models.roberta.modeling_roberta.RobertaLayer(config)
+        self.head = transformers.models.roberta.modeling_roberta.RobertaLMHead(config,
+                                                                               input_embeddings)
 
-    def build(self,input_shape):
-        """
-
-
-        Parameters
-        ----------
-        input_shape : tuple
-            Input shape.
-
-        Returns
-        -------
-        None.
 
-        """
-        self.built = True
-
 
 
 
-    def call(self,
+    def forward(self,
              vector,
              hidden_states,
              attention_mask=None,training=False):
@@ -66,12 +49,13 @@ class QaracDecoderHead(keras.layers.Layer):
 
         Returns
         -------
-        transformers.modeling_tf_outputs.TFCausalLMOutputWithCrossAttentions
+        transformers.modeling_outputs.CausalLMOutputWithCrossAttentions
            Predicted text
 
        """
-        vectors = self.concat([vector, hidden_states])
-        attentions = attention_mask if attention_mask is None else self.concat([tensorflow.ones((hidden_states.shape(0),
+        vectors = torch.cat([vector, hidden_states],
+                            dim=1)
+        attentions = attention_mask if attention_mask is None else torch.cat([torch.ones((hidden_states.shape(0),
                                                                                           1)),
                                                                               attention_mask])
        l0 = self.layer_0(vectors,
@@ -91,7 +75,7 @@ class QaracDecoderHead(keras.layers.Layer):
                          False,
                          training)[0])
 
-class QaracDecoderModel(transformers.TFPreTrainedModel,transformers.generation_tf_utils.TFGenerationMixin):
+class QaracDecoderModel(transformers.PreTrainedModel,transformers.generation_utils.TFGenerationMixin):
 
     def __init__(self,base_model,tokenizer):
        """
@@ -112,11 +96,9 @@ class QaracDecoderModel(transformers.TFPreTrainedModel,transformers.generation_t
        self.decoder_head = QaracDecoderHead(self.base_model.config,
                                             self.base_model.roberta.get_input_embeddings())
        self.tokenizer = tokenizer
-        self.start=None
-        self.end=None
-        self.pad=None
+
 
-    def call(self,inputs,**kwargs):
+    def forward(self,inputs,**kwargs):
        """
        Predicts text from inputs
 
@@ -130,13 +112,13 @@ class QaracDecoderModel(transformers.TFPreTrainedModel,transformers.generation_t
 
        Returns
        -------
-        transformers.modeling_tf_outputs.TFCausalLMOutputWithCrossAttentions
+        transformers.modeling_outputs.CausalLMOutputWithCrossAttentions
            Predicted text
 
        """
        (v,s) = (kwargs['vector'],inputs) if 'vector' in kwargs else inputs
 
-        return self.decoder_head(tensorflow.expand_dims(v,1),
+        return self.decoder_head(torch.unsqueeze(v,1),
                                 self.base_model(s).last_hidden_state,
                                 training = kwargs.get('training',False))
 
@@ -145,7 +127,7 @@ class QaracDecoderModel(transformers.TFPreTrainedModel,transformers.generation_t
                                      attention_mask=None,
                                      **kwargs):
        if attention_mask is None:
-            attention_mask = tensorflow.ones_like(input_ids)
+            attention_mask = torch.ones_like(input_ids)
        return {'input_ids':input_ids,
                'attention_mask':attention_mask}
 
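
For context, a minimal sketch (not part of the commit) of the Keras-to-PyTorch mapping this conversion applies: keras.layers.Layer becomes torch.nn.Module, call() becomes forward(), keras.layers.Concatenate(axis=1) becomes torch.cat(..., dim=1), and the explicit build() step is dropped because PyTorch modules register their parameters in __init__. The ConcatHead class and the sizes below are hypothetical, chosen only to illustrate the pattern, not the project's actual API.

import torch

class ConcatHead(torch.nn.Module):
    """Toy stand-in for QaracDecoderHead's concatenate-then-project step."""

    def __init__(self, hidden_size, vocab_size):
        super().__init__()
        # Parameters are registered here, so no separate build() is needed.
        self.proj = torch.nn.Linear(hidden_size, vocab_size)

    def forward(self, vector, hidden_states):
        # vector: (batch, 1, hidden); hidden_states: (batch, seq, hidden).
        # Prepend the summary vector along the sequence axis, as the commit
        # does with torch.cat(..., dim=1) in place of keras.layers.Concatenate(axis=1).
        vectors = torch.cat([vector, hidden_states], dim=1)
        return self.proj(vectors)

head = ConcatHead(hidden_size=8, vocab_size=16)
logits = head(torch.randn(2, 1, 8), torch.randn(2, 5, 8))
print(logits.shape)  # torch.Size([2, 6, 16])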