#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Sep  5 10:29:03 2023

@author: peter
"""

import transformers
import torch

class QaracDecoderHead(torch.nn.Module):
    
    def __init__(self,config):
        """
        Creates the Decoder head

        Parameters
        ----------
        config : transformers.RobertaConfig
            Config for the RobertaModel that this head will be attached to.

        Returns
        -------
        None.

        """
        super(QaracDecoderHead,self).__init__()
        self.layer_0 = transformers.models.roberta.modeling_roberta.RobertaLayer(config)
        self.layer_1 = transformers.models.roberta.modeling_roberta.RobertaLayer(config)
        self.head = transformers.models.roberta.modeling_roberta.RobertaLMHead(config)

    def forward(self,
             vector,
             hidden_states,
             attention_mask=None):
        """
        Predicts text from vector and hidden states of base model

        Parameters
        ----------
        vector : torch.Tensor
            Vector to be decoded
        hidden_states : torch.Tensor
            Last hidden states of the base model
        attention_mask : torch.Tensor, optional
            Attention mask for the seed text. The default is None.

        Returns
        -------
        torch.Tensor
            Prediction logits for the decoded text

        """
        # Prepend the vector to be decoded to the base model's hidden states
        vectors = torch.cat([vector, hidden_states],
                            dim=1)
        # Extend the attention mask with a column of ones to cover the
        # prepended vector position
        attentions = attention_mask if attention_mask is None else torch.cat([torch.ones((hidden_states.size(0),1),
                                                                                          dtype=attention_mask.dtype,
                                                                                          device=attention_mask.device),
                                                                              attention_mask],
                                                                             dim=1)
        l0 = self.layer_0(vectors,
                          attentions)
        # Drop the prepended vector position before the second layer and the
        # language modelling head
        return self.head(self.layer_1(l0[0][:,1:],
                                      attention_mask)[0])

class QaracDecoderModel(transformers.RobertaModel,
                        transformers.generation_utils.GenerationMixin):
    
    def __init__(self,model_path,config,tokenizer):
        """
        Creates decoder model from base model

        Parameters
        ----------
        model_path : str
            Path or name of the pretrained RoBERTa model to build the decoder on
        config : transformers.RobertaConfig
            Config for the base model
        tokenizer : transformers.RobertaTokenizer
            Tokenizer for the model

        Returns
        -------
        None.

        """
        super(QaracDecoderModel,self).__init__(config)
        self.decoder_base = transformers.RobertaModel.from_pretrained(model_path,
                                                                      config=config)
        self.decoder_head = QaracDecoderHead(self.config)
        self.tokenizer = tokenizer

        
    def forward(self,inputs,**kwargs):
        """
        Predicts text from inputs

        Parameters
        ----------
        inputs : tuple of torch.Tensors OR torch.Tensor
            Vector to be converted to text and seed text OR tokenized seed text
        kwargs : optional keyword arguments
            vector : torch.Tensor vector to be decoded. May be supplied
                     via a keyword argument when this is invoked by .generate

        Returns
        -------
        torch.Tensor
            Prediction logits for the generated text

        """
        # The vector may arrive either as part of a (vector, seed) tuple or
        # as a keyword argument (e.g. when invoked by .generate)
        (v,s) = (kwargs['vector'],inputs) if 'vector' in kwargs else inputs
        # The seed may be a plain tensor of input ids or a dict-like encoding
        # that also carries an attention mask
        (seed,attention_mask) = (s,None) if isinstance(s,torch.Tensor) else (s['input_ids'],s['attention_mask'])
        return self.decoder_head(torch.unsqueeze(v,1),
                                 self.decoder_base(seed,
                                                   attention_mask=attention_mask,
                                                   use_cache='vector' in kwargs).last_hidden_state)
    
    def prepare_inputs_for_generation(self, 
                                      input_ids, 
                                      attention_mask=None,
                                      **kwargs):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        return {'input_ids':input_ids,
                'attention_mask':attention_mask}
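

# The sketch below is a minimal, illustrative usage example rather than part
# of the original module: it assumes 'roberta-base' as the underlying
# checkpoint, a zero vector in place of a real encoder output, and a
# transformers version compatible with the imports above.
if __name__ == '__main__':
    tokenizer = transformers.RobertaTokenizer.from_pretrained('roberta-base')
    config = transformers.RobertaConfig.from_pretrained('roberta-base')
    decoder = QaracDecoderModel('roberta-base', config, tokenizer)

    # Seed text to continue, and a dummy vector of the model's hidden size
    # standing in for the encoder output that would normally be decoded
    seed = tokenizer('The answer is', return_tensors='pt')
    vector = torch.zeros((1, config.hidden_size))

    with torch.no_grad():
        logits = decoder(seed, vector=vector)
    # Prediction logits over the vocabulary for each seed position
    print(logits.shape)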