Spaces:
Runtime error
Runtime error
File size: 1,459 Bytes
37b9e99 223340a 37b9e99 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
'''
Author: Qiguang Chen
Date: 2023-01-11 10:39:26
LastEditors: Qiguang Chen
LastEditTime: 2023-02-18 19:33:34
Description:
'''
from common.utils import InputData
from model.encoder.base_encoder import BaseEncoder, BiEncoder
from model.encoder.pretrained_encoder import PretrainedEncoder
from model.encoder.non_pretrained_encoder import NonPretrainedEncoder
class AutoEncoder(BaseEncoder):
    """Select a concrete encoder by name and delegate encoding to it.

    The encoder implementation is chosen from ``config["encoder_name"]``
    (compared case-insensitively) at construction time; ``forward`` simply
    forwards its input to that encoder.
    """

    def __init__(self, **config):
        """Automatically load an encoder according to ``encoder_name``.

        Args:
            config (dict):
                encoder_name (str): "lstm" or "self-attention-lstm" build a
                    ``NonPretrainedEncoder``; "bi-encoder" builds a ``BiEncoder``
                    from two nested ``AutoEncoder`` configs; any other value is
                    passed to ``PretrainedEncoder`` (intended for pretrained
                    models on the Hugging Face hub).
                intent_encoder (dict): required when ``encoder_name`` is
                    "bi-encoder"; keyword config for the first sub-encoder.
                slot_encoder (dict, optional): used when ``encoder_name`` is
                    "bi-encoder"; keyword config for the second sub-encoder.
                    Falls back to ``intent_encoder`` when absent.
                **args (Any): other configuration items corresponding to each module.

        Raises:
            ValueError: if ``encoder_name`` is missing or empty.
            KeyError: if ``encoder_name`` is "bi-encoder" and ``intent_encoder``
                is missing from the config.
        """
        super().__init__()
        self.config = config
        encoder_name = config.get("encoder_name")
        if not encoder_name:
            raise ValueError("There is no Encoder Name in config.")

        encoder_name = encoder_name.lower()
        if encoder_name in ["lstm", "self-attention-lstm"]:
            self.__encoder = NonPretrainedEncoder(**config)
        elif encoder_name == "bi-encoder":
            intent_config = config["intent_encoder"]
            slot_config = config.get("slot_encoder", intent_config)
            # Each side needs its own, independently constructed encoder.
            # Calling self.__init__ here instead would return None (handing
            # BiEncoder two Nones) and re-initialise this very module.
            # NOTE(review): assumes BiEncoder(intent_encoder, slot_encoder)
            # argument order — confirm against model.encoder.base_encoder.
            self.__encoder = BiEncoder(
                AutoEncoder(**intent_config),
                AutoEncoder(**slot_config),
            )
        else:
            self.__encoder = PretrainedEncoder(**config)

    def forward(self, inputs: InputData):
        """Encode ``inputs`` with the selected encoder and return its output unchanged."""
        return self.__encoder(inputs)