'''
Author: Qiguang Chen
Date: 2023-01-11 10:39:26
LastEditors: Qiguang Chen
LastEditTime: 2023-02-18 17:38:30
Description: pretrained encoder model

'''
from transformers import AutoModel
from common import utils

from common.utils import InputData, HiddenData
from model.encoder.base_encoder import BaseEncoder


class PretrainedEncoder(BaseEncoder):
    def __init__(self, **config):
        """ init pretrained encoder
        
        Args:
            config (dict):
                encoder_name (str): pretrained model name in hugging face.
        """
        super().__init__(**config)
        if self.config.get("_is_check_point_"):
            # Restore the encoder from a saved checkpoint configuration.
            self.encoder = utils.instantiate(config["pretrained_model"], target="_pretrained_model_target_")
        else:
            # Load the model by name from the Hugging Face Hub (or local cache).
            self.encoder = AutoModel.from_pretrained(config["encoder_name"])

    def forward(self, inputs: InputData):
        """Encode a batch and wrap the hidden states in a HiddenData object.

        Honors the config flags return_with_input, return_sentence_level_hidden
        and padding_side.
        """
        output = self.encoder(**inputs.get_inputs())
        hidden = HiddenData(None, output.last_hidden_state)
        if self.config.get("return_with_input"):
            hidden.add_input(inputs)
        if self.config.get("return_sentence_level_hidden"):
            padding_side = self.config.get("padding_side")
            if hasattr(output, "pooler_output"):
                # Prefer the model's own pooled representation when it exists.
                hidden.update_intent_hidden_state(output.pooler_output)
            elif padding_side == "left":
                # With left padding, the last position holds the final real token.
                hidden.update_intent_hidden_state(output.last_hidden_state[:, -1])
            else:
                # Fall back to the first token (e.g. [CLS] for BERT-style models).
                hidden.update_intent_hidden_state(output.last_hidden_state[:, 0])
        return hidden