'''
Author: Qiguang Chen
Date: 2023-01-11 10:39:26
LastEditors: Qiguang Chen
LastEditTime: 2023-01-26 17:25:17
Description: Base encoder and bi encoder

'''
from torch import nn

from common.utils import InputData


class BaseEncoder(nn.Module):
    """Base class for all encoder modules.

    Subclasses are expected to override :meth:`forward`, or to assign a
    ``self.encoder`` submodule that the default :meth:`forward` delegates to.

    Args:
        **config: arbitrary encoder configuration, stored as ``self.config``.
            :class:`BiEncoder` reads the optional
            ``"return_sentence_level_hidden"`` key from it.
    """

    def __init__(self, **config):
        super().__init__()
        self.config = config
        # Deliberately no ``raise NotImplementedError`` here: subclasses call
        # ``super().__init__(**config)`` to populate ``self.config``.

    def forward(self, inputs: InputData):
        """Run ``self.encoder`` on ``inputs.input_ids`` and return its output.

        Args:
            inputs: batch object whose ``input_ids`` attribute is passed to
                ``self.encoder``.

        Returns:
            Whatever ``self.encoder`` returns.

        Raises:
            NotImplementedError: if the subclass neither overrides ``forward``
                nor defines a ``self.encoder`` module.
        """
        encoder = getattr(self, "encoder", None)
        if encoder is None:
            raise NotImplementedError(
                f"{type(self).__name__} must override forward() or define self.encoder"
            )
        return encoder(inputs.input_ids)


class BiEncoder(nn.Module):
    """Encode intent and slot information with two separate encoders.

    Both encoders receive the same ``inputs``. The slot encoder's output object
    is returned, after its intent hidden state has been overwritten with a
    state taken from the intent encoder's output.

    Args:
        intent_encoder: encoder whose output supplies the intent hidden state.
        slot_encoder: encoder whose output object is updated and returned.
        **config: extra configuration, stored as ``self.config``.
    """

    def __init__(self, intent_encoder: BaseEncoder, slot_encoder: BaseEncoder, **config):
        super().__init__()
        self.config = config
        self.intent_encoder = intent_encoder
        self.slot_encoder = slot_encoder

    def forward(self, inputs: InputData):
        """Encode ``inputs`` with both encoders and merge the results.

        Args:
            inputs: batch passed unchanged to both encoders.

        Returns:
            The slot encoder's output, with its intent hidden state replaced.
        """
        hidden_slot = self.slot_encoder(inputs)
        hidden_intent = self.intent_encoder(inputs)

        # A missing config or flag is treated as False (token-level output)
        # instead of failing with AttributeError / KeyError.
        intent_config = getattr(self.intent_encoder, "config", None) or {}
        if intent_config.get("return_sentence_level_hidden", False):
            intent_hidden = hidden_intent.get_intent_hidden_state()
        else:
            # Without a sentence-level output, the intent encoder's token-level
            # states (exposed via get_slot_hidden_state) are used as the intent
            # representation. NOTE(review): presumably pooled downstream; confirm.
            intent_hidden = hidden_intent.get_slot_hidden_state()
        hidden_slot.update_intent_hidden_state(intent_hidden)
        return hidden_slot