# OpenSLU/model/decoder/agif_decoder.py
# Uploaded by LightChen2333 ("Upload 34 files", commit 37b9e99), 905 bytes.
from common.utils import HiddenData, OutputData
from model.decoder.base_decoder import BaseDecoder
class AGIFDecoder(BaseDecoder):
    """Decoder that predicts the intent first and conditions slot prediction on it.

    The intent classifier runs on the hidden data, its output is decoded into
    sentence-level intent indices, and those indices are forwarded to the slot
    classifier together with the interaction module.
    """

    def forward(self, hidden: HiddenData, **kwargs):
        """Run intent classification, then intent-guided slot classification.

        Args:
            hidden: Encoder output handed to both classifiers.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            An ``OutputData`` built from the intent prediction and the slot
            prediction.
        """
        # The interaction module is not applied to `hidden` here; it is handed
        # to the slot classifier through `internal_interaction` instead.
        intent_output = self.intent_classifier(hidden)

        # Decode sentence-level intent indices (not a Python list) from the
        # intent prediction alone; the slot part of OutputData is left empty.
        decoded_intents = self.intent_classifier.decode(
            OutputData(intent_output, None),
            return_list=False,
            return_sentence_level=True,
        )

        # First dimension of the intent classifier output is passed on as the
        # batch size; the label count comes from the intent classifier config.
        n_batch = intent_output.classifier_output.shape[0]
        n_intent_labels = self.intent_classifier.config["intent_label_num"]

        slot_output = self.slot_classifier(
            hidden,
            internal_interaction=self.interaction,
            intent_index=decoded_intents,
            batch_size=n_batch,
            intent_label_num=n_intent_labels,
        )
        return OutputData(intent_output, slot_output)