'''
Author: Qiguang Chen
Date: 2023-01-11 10:39:26
LastEditors: Qiguang Chen
LastEditTime: 2023-02-18 19:33:34
Description: Encoder factory that selects a concrete encoder by ``encoder_name``.

'''
from common.utils import InputData
from model.encoder.base_encoder import BaseEncoder, BiEncoder
from model.encoder.pretrained_encoder import PretrainedEncoder
from model.encoder.non_pretrained_encoder import NonPretrainedEncoder


class AutoEncoder(BaseEncoder):
    """Encoder wrapper that builds and delegates to a concrete encoder."""

    def __init__(self, **config):
        """Automatically load an encoder according to ``encoder_name``.

        Args:
            config (dict):
                encoder_name (str): "lstm" or "self-attention-lstm" build a
                    ``NonPretrainedEncoder``; "bi-encoder" builds a
                    ``BiEncoder`` from two nested AutoEncoders; any other
                    name is handed to ``PretrainedEncoder`` (presumably a
                    Hugging Face model name). Matching is case-insensitive.
                intent_encoder (dict): required for "bi-encoder"; keyword
                    config for the intent-side sub-encoder.
                slot_encoder (dict): required for "bi-encoder"; keyword
                    config for the slot-side sub-encoder.
                **args (Any): other configuration items corresponding to
                    each module.

        Raises:
            ValueError: if ``encoder_name`` is missing or empty.
            KeyError: if "bi-encoder" is requested without the
                ``intent_encoder`` / ``slot_encoder`` sub-configs.
        """
        super().__init__()
        self.config = config
        encoder_name = config.get("encoder_name")
        if not encoder_name:
            raise ValueError("There is no Encoder Name in config.")
        encoder_name = encoder_name.lower()

        if encoder_name in ["lstm", "self-attention-lstm"]:
            self.__encoder = NonPretrainedEncoder(**config)
        elif encoder_name == "bi-encoder":
            # Each side must be its own AutoEncoder instance: calling
            # ``self.__init__`` here would re-initialise this module and
            # return None instead of an encoder.
            self.__encoder = BiEncoder(
                AutoEncoder(**config["intent_encoder"]),
                AutoEncoder(**config["slot_encoder"]),
            )
        else:
            self.__encoder = PretrainedEncoder(**config)

    def forward(self, inputs: InputData):
        """Run the selected encoder on ``inputs`` and return its output."""
        return self.__encoder(inputs)