OpenSLU / model /decoder /gl_gin_decoder.py
LightChen2333's picture
Upload 34 files
37b9e99
raw
history blame
2.05 kB
import torch
import torch.nn.functional as F
from torch import nn
from common.utils import HiddenData, OutputData, InputData
from model.decoder import BaseDecoder
from model.decoder.interaction.gl_gin_interaction import LSTMEncoder
class IntentEncoder(nn.Module):
def __init__(self,input_dim, dropout_rate):
super().__init__()
self.dropout_rate = dropout_rate
self.__intent_lstm = LSTMEncoder(
input_dim,
input_dim,
dropout_rate
)
def forward(self, g_hiddens, seq_lens):
intent_lstm_out = self.__intent_lstm(g_hiddens, seq_lens)
return F.dropout(intent_lstm_out, p=self.dropout_rate, training=self.training)
class GLGINDecoder(BaseDecoder):
def __init__(self, intent_classifier, slot_classifier, interaction=None, **config):
super().__init__(intent_classifier, slot_classifier, interaction)
self.config=config
self.intent_encoder = IntentEncoder(self.intent_classifier.config["input_dim"], self.config["dropout_rate"])
def forward(self, hidden: HiddenData, forced_slot=None, forced_intent=None, differentiable=None):
seq_lens = hidden.inputs.attention_mask.sum(-1)
intent_lstm_out = self.intent_encoder(hidden.slot_hidden, seq_lens)
hidden.update_intent_hidden_state(intent_lstm_out)
pred_intent = self.intent_classifier(hidden)
intent_index = self.intent_classifier.decode(OutputData(pred_intent, None),hidden.inputs,
return_list=False,
return_sentence_level=True)
slot_hidden = self.interaction(
hidden,
pred_intent=pred_intent,
intent_index=intent_index,
)
pred_slot = self.slot_classifier(slot_hidden)
num_intent = self.intent_classifier.config["intent_label_num"]
pred_slot = pred_slot.classifier_output[:, num_intent:]
return OutputData(pred_intent, F.log_softmax(pred_slot, dim=1))