'''
Author: Qiguang Chen
Date: 2023-01-11 10:39:26
LastEditors: Qiguang Chen
LastEditTime: 2023-01-31 18:22:36
Description:
'''
from torch import nn

from common.utils import HiddenData, OutputData, InputData


class BaseDecoder(nn.Module):
    """Base class for all decoder modules.

    Notice: it is often only necessary to change this module and its sub-modules.
    """

    def __init__(self, intent_classifier=None, slot_classifier=None, interaction=None):
        super().__init__()
        self.intent_classifier = intent_classifier
        self.slot_classifier = slot_classifier
        self.interaction = interaction
    def forward(self, hidden: HiddenData):
        """forward

        Args:
            hidden (HiddenData): encoded data

        Returns:
            OutputData: prediction logits
        """
        if self.interaction is not None:
            hidden = self.interaction(hidden)
        intent = None
        slot = None
        if self.intent_classifier is not None:
            intent = self.intent_classifier(hidden)
        if self.slot_classifier is not None:
            slot = self.slot_classifier(hidden)
        return OutputData(intent, slot)
    def decode(self, output: OutputData, target: InputData = None):
        """decode output logits

        Args:
            output (OutputData): output logits data
            target (InputData, optional): input data with attention mask. Defaults to None.

        Returns:
            OutputData: decoded intent and slot sequence ids
        """
        intent, slot = None, None
        if self.intent_classifier is not None:
            intent = self.intent_classifier.decode(output, target)
        if self.slot_classifier is not None:
            slot = self.slot_classifier.decode(output, target)
        return OutputData(intent, slot)
    def compute_loss(self, pred: OutputData, target: InputData, compute_intent_loss=True, compute_slot_loss=True):
        """compute loss

        Notice: the intent and slot loss weights can be set by adding a 'weight' config item
        to the corresponding classifier configuration.

        Args:
            pred (OutputData): output logits data
            target (InputData): input golden data
            compute_intent_loss (bool, optional): whether to compute intent loss. Defaults to True.
            compute_slot_loss (bool, optional): whether to compute slot loss. Defaults to True.

        Returns:
            Tuple: total weighted loss, intent loss, slot loss
        """
        loss = 0
        intent_loss = None
        slot_loss = None
        if self.intent_classifier is not None and compute_intent_loss:
            intent_loss = self.intent_classifier.compute_loss(pred, target)
            intent_weight = self.intent_classifier.config.get("weight")
            intent_weight = intent_weight if intent_weight is not None else 1.
            loss += intent_loss * intent_weight
        if self.slot_classifier is not None and compute_slot_loss:
            slot_loss = self.slot_classifier.compute_loss(pred, target)
            slot_weight = self.slot_classifier.config.get("weight")
            slot_weight = slot_weight if slot_weight is not None else 1.
            loss += slot_loss * slot_weight
        return loss, intent_loss, slot_loss
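
# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how BaseDecoder composes its sub-modules. The classifier
# and batch names below are hypothetical placeholders; real classifiers are
# expected to expose forward/decode/compute_loss and a `config` dict
# (e.g. {"weight": 0.5}) as assumed by compute_loss above.
#
#   decoder = BaseDecoder(
#       intent_classifier=intent_classifier,   # assumed classifier module
#       slot_classifier=slot_classifier,       # assumed classifier module
#       interaction=None,
#   )
#   output = decoder(hidden)                                   # OutputData with logits
#   loss, intent_loss, slot_loss = decoder.compute_loss(output, batch)
#   predictions = decoder.decode(output, batch)                # decoded ids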


class StackPropagationDecoder(BaseDecoder):
    """Decoder for Stack-Propagation: the predicted intent is fed, via the
    interaction module, into the slot classifier."""

    def forward(self, hidden: HiddenData):
        pred_intent = self.intent_classifier(hidden)
        # The interaction module injects the intent prediction into the slot
        # hidden states before slot classification.
        hidden = self.interaction(pred_intent, hidden)
        pred_slot = self.slot_classifier(hidden)
        return OutputData(pred_intent, pred_slot)
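
# --- Note (illustrative) ---
# Unlike BaseDecoder, the interaction here is expected to accept the intent
# prediction as well, i.e. interaction(pred_intent, hidden), so the slot
# classifier can condition on the predicted intent. A hypothetical wiring
# sketch (module names are assumptions):
#
#   decoder = StackPropagationDecoder(
#       intent_classifier=token_level_intent_classifier,
#       slot_classifier=slot_classifier,
#       interaction=stack_propagation_interaction,  # takes (pred_intent, hidden)
#   )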


class DCANetDecoder(BaseDecoder):
    """Decoder for DCA-Net: the interaction module co-attends over intent and
    slot representations before both classifiers run."""

    def forward(self, hidden: HiddenData):
        if self.interaction is not None:
            hidden = self.interaction(hidden, intent_emb=self.intent_classifier, slot_emb=self.slot_classifier)
        return OutputData(self.intent_classifier(hidden), self.slot_classifier(hidden))
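
# --- Note (illustrative) ---
# DCANetDecoder passes the classifiers themselves to its interaction module
# (as `intent_emb` / `slot_emb`) so their label embeddings can be reused for
# co-attention before both classifiers run on the updated hidden states.
# A hypothetical wiring sketch (module names are assumptions):
#
#   decoder = DCANetDecoder(
#       intent_classifier=intent_classifier,
#       slot_classifier=slot_classifier,
#       interaction=dca_net_co_interaction,
#   )
#   output = decoder(hidden)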