Spaces:
Runtime error
Runtime error
File size: 1,646 Bytes
37b9e99 223340a 37b9e99 223340a 37b9e99 223340a 37b9e99 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
'''
Author: Qiguang Chen
Date: 2023-01-11 10:39:26
LastEditors: Qiguang Chen
LastEditTime: 2023-02-18 17:38:30
Description: pretrained encoder model
'''
from transformers import AutoModel, AutoConfig
from common import utils
from common.utils import InputData, HiddenData
from model.encoder.base_encoder import BaseEncoder
class PretrainedEncoder(BaseEncoder):
    """Encoder that wraps a Hugging Face ``transformers`` backbone model."""

    def __init__(self, **config):
        """Init pretrained encoder.

        Args:
            config (dict):
                encoder_name (str): pretrained model name (or path) passed to
                    ``AutoModel.from_pretrained`` when not restoring from a
                    checkpoint.
                _is_check_point_ (bool, optional): when truthy, the backbone is
                    built from ``pretrained_model`` via ``utils.instantiate``
                    instead of ``AutoModel.from_pretrained``.
                pretrained_model (optional): object handed to
                    ``utils.instantiate`` when ``_is_check_point_`` is set.
                return_with_input (bool, optional): if truthy, ``forward``
                    attaches its ``inputs`` to the returned ``HiddenData``.
                return_sentence_level_hidden (bool, optional): if truthy,
                    ``forward`` also sets a sentence-level hidden state.
                padding_side (str, optional): ``"left"`` makes the last
                    position the sentence vector when no pooler output is
                    available; anything else uses the first position.
        """
        super().__init__(**config)
        # NOTE(review): ``self.config`` is presumably populated by
        # BaseEncoder.__init__ from **config — confirm.
        if self.config.get("_is_check_point_"):
            self.encoder = utils.instantiate(
                config["pretrained_model"], target="_pretrained_model_target_"
            )
        else:
            self.encoder = AutoModel.from_pretrained(config["encoder_name"])

    def forward(self, inputs: InputData) -> HiddenData:
        """Run the backbone and package its outputs.

        Args:
            inputs (InputData): batch whose ``get_inputs()`` result is passed
                as keyword arguments to the wrapped model.

        Returns:
            HiddenData: built with ``None`` as the first argument and the
            backbone's ``last_hidden_state`` as the second; optionally carries
            the inputs and a sentence-level hidden state, depending on config.
        """
        output = self.encoder(**inputs.get_inputs())
        hidden = HiddenData(None, output.last_hidden_state)
        if self.config.get("return_with_input"):
            hidden.add_input(inputs)
        if self.config.get("return_sentence_level_hidden"):
            hidden.update_intent_hidden_state(self._sentence_hidden(output))
        return hidden

    def _sentence_hidden(self, output):
        """Select one vector per sequence from the backbone output."""
        # Model outputs may define ``pooler_output`` yet leave it as None
        # (e.g. no pooling layer), so test the value, not the attribute.
        pooled = getattr(output, "pooler_output", None)
        if pooled is not None:
            return pooled
        # With left padding the final real token sits at the last position;
        # otherwise take the first position.
        if self.config.get("padding_side") == "left":
            return output.last_hidden_state[:, -1]
        return output.last_hidden_state[:, 0]
|