PeteBleackley commited on
Commit
f9c0522
·
1 Parent(s): 8f1745b

Encoder, Decoder and Trainer models (assuming RoBERTa base models)

Browse files
qarac/models/QaracDecoderModel.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created on Tue Sep 5 10:29:03 2023
5
+
6
+ @author: peter
7
+ """
8
+
9
+ import keras
10
+ import transformers
11
+
12
+ class QaracDecoderHead(keras.layers.Layer):
13
+
14
+ def __init__(self,config):
15
+ super(QaracDecoderHead,self).__init__()
16
+ self.concat = keras.layers.Concatenate(axis=1)
17
+ self.layer_0 = transformers.TFRobertaLayer(config)
18
+ self.layer_1 = transformers.TFRobertalayer(config)
19
+ self.head = transformers.TFRobertaLMHead(config)
20
+
21
+ def call(self,inputs):
22
+ vectors = self.concat(inputs)
23
+ l0 = self.layer_0(vectors)
24
+ return self.head(self.layer1(l0.last_hidden_state[:,1:]))
25
+
26
+ class QaracDecoderModel(transformers.TFPretrainedModel):
27
+
28
+ def __init__(self,base_model):
29
+ super(QaracDecoderModel,self).__init__()
30
+ self.base_model = base_model
31
+ self.decoder_head = QaracDecoderHead(self.base_model.config)
32
+
33
+ def call(self,inputs):
34
+ (v,s) = inputs
35
+ return self.decoder_head((v,self.base_model(s)))
36
+
37
+
qarac/models/QaracEncoderModel.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created on Tue Sep 5 10:01:39 2023
5
+
6
+ @author: peter
7
+ """
8
+
9
+ import transformers
10
+ import qarac.layers.GlobalAttentionPoolingHead
11
+
12
+ class QaracEncoderModel(transformers.TFPretrainedModel):
13
+
14
+ def __init__(self,base_model):
15
+ super(QaracEncoderModel,self).__init__()
16
+ self.base_model = base_model
17
+ self.head = qarac.layers.GlobalAttentionPoolingHead.GlobalAttentionPoolingHead()
18
+
19
+ def call(self,inputs):
20
+ return self.head(self.base_model(inputs).last_hidden_state)
21
+
22
+
23
+
qarac/models/QaracTrainerModel.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created on Tue Sep 5 15:30:06 2023
5
+
6
+ @author: peter
7
+ """
8
+
9
+ import keras
10
+ import QaracEncoderModel
11
+ import QaracDecoderModel
12
+
13
+ class QuaracTrainerModel(keras.Model):
14
+
15
+ def __init__(self,base_encoder_model,base_decoder_model):
16
+
17
+ self.question_encoder = QaracEncoderModel.QaracEncoderModel(base_encoder_model)
18
+ self.answer_encoder = QaracEncoderModel.QaracEncoderModel(base_encoder_model)
19
+ self.decoder = QaracDecoderModel.QaracDecoderModel(base_decoder_model)
20
+ self.consistency = keras.layers.Dot(axes=1,normalize=True)
21
+
22
+ def call(self,inputs,training=None):
23
+ results = {}
24
+ results['encode_decode'] = self.decoder((self.answer_encoder(inputs['all_text']),
25
+ inputs['offset_text']))
26
+ results['question_answering'] = self.question_encoder(inputs['question']) - self.answer_encoder(inputs['answer'])
27
+ results['reasoning'] = self.decoder((self.answer_encoder(inputs['proposition0'])
28
+ +self.answer_encoder(inputs['proposition1']),
29
+ self.inputs['conclusion_offset']))
30
+ results['consistency'] = self.consistency((self.answer_encoder(inputs['statement0']),
31
+ self.answer_encoder(inputs['statement1'])))
32
+ return results