from common.utils import HiddenData, OutputData
from model.decoder.base_decoder import BaseDecoder


class AGIFDecoder(BaseDecoder):
    """Decoder that predicts intents first, then conditions slot prediction on them.

    The intent classifier runs on the hidden states. Its decoded intent
    indices are then passed, together with the interaction module, to the
    slot classifier.
    """

    def forward(self, hidden: HiddenData, **kwargs):
        """Run intent classification followed by intent-conditioned slot filling.

        Args:
            hidden: Hidden states produced upstream; fed unchanged to both the
                intent classifier and the slot classifier.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            OutputData pairing the intent prediction with the slot prediction.
        """
        intent_out = self.intent_classifier(hidden)

        # Decode at sentence level, without list conversion, so the slot
        # classifier receives the intent indices directly.
        decoded_intents = self.intent_classifier.decode(
            OutputData(intent_out, None),
            return_list=False,
            return_sentence_level=True,
        )

        # Leading dimension of the intent logits is used as the batch size.
        n_batch = intent_out.classifier_output.shape[0]
        n_intent_labels = self.intent_classifier.config["intent_label_num"]

        # The interaction module is not applied to `hidden` here; it is handed
        # to the slot classifier, which presumably invokes it internally
        # (NOTE(review): confirm in the slot classifier implementation).
        slot_out = self.slot_classifier(
            hidden,
            internal_interaction=self.interaction,
            intent_index=decoded_intents,
            batch_size=n_batch,
            intent_label_num=n_intent_labels,
        )
        return OutputData(intent_out, slot_out)