#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Sep  5 10:01:39 2023

@author: peter
"""

from collections.abc import Mapping

import transformers
from qarac.models.layers.GlobalAttentionPoolingHead import GlobalAttentionPoolingHead

class QaracEncoderModel(transformers.PreTrainedModel):
    
    def __init__(self, path):
        """
        Creates the encoder model

        Parameters
        ----------
        path : str
            Path or name of the pretrained RoBERTa checkpoint to load

        Returns
        -------
        None.

        """
        config = transformers.PretrainedConfig.from_pretrained(path)
        super().__init__(config)
        self.encoder = transformers.RobertaModel.from_pretrained(path)
        self.head = GlobalAttentionPoolingHead(config)
        
        
    def forward(self, input_ids,
                attention_mask=None):
        """
        Vectorizes a tokenised text

        Parameters
        ----------
        input_ids : torch.Tensor or Mapping
            Tokenized text to encode; may also be a mapping (such as a
            transformers.BatchEncoding) containing both 'input_ids' and
            'attention_mask'
        attention_mask : torch.Tensor, optional
            Mask distinguishing real tokens from padding. The default is None.

        Returns
        -------
        torch.Tensor
            Vector representing the document

        """
        # Unpack tokenizer output passed as a single mapping
        if attention_mask is None and isinstance(input_ids, Mapping):
            (input_ids, attention_mask) = (input_ids['input_ids'],
                                           input_ids['attention_mask'])
        return self.head(self.encoder(input_ids,
                                      attention_mask).last_hidden_state,
                         attention_mask)
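

# Minimal usage sketch, assuming the public 'roberta-base' checkpoint (an
# illustrative choice, not required by this module) and that the qarac
# package is importable. Shows both calling conventions forward() accepts.
if __name__ == '__main__':
    tokenizer = transformers.AutoTokenizer.from_pretrained('roberta-base')
    model = QaracEncoderModel('roberta-base')
    tokens = tokenizer('An example document to vectorise.',
                       return_tensors='pt')
    # Pass the whole BatchEncoding; forward() unpacks it itself
    vector = model(tokens)
    # Or pass the tensors explicitly
    vector = model(tokens['input_ids'], tokens['attention_mask'])
    print(vector.shape)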