Spaces:
Runtime error
Runtime error
File size: 4,158 Bytes
37b9e99 223340a 37b9e99 223340a 37b9e99 223340a 37b9e99 223340a 37b9e99 223340a 37b9e99 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
'''
Author: Qiguang Chen
Date: 2023-01-11 10:39:26
LastEditors: Qiguang Chen
LastEditTime: 2023-01-31 18:22:36
Description:
'''
from torch import nn
from common.utils import HiddenData, OutputData, InputData
class BaseDecoder(nn.Module):
    """Base class for all decoder modules.

    A decoder bundles an optional interaction module with optional intent and
    slot classifiers. ``forward`` applies the interaction (when present) and
    then every classifier that is present; ``decode`` and ``compute_loss``
    delegate to the classifiers' own ``decode`` / ``compute_loss`` methods.

    Notice: it is often only necessary to change this module and its sub-modules.
    """

    def __init__(self, intent_classifier=None, slot_classifier=None, interaction=None):
        """
        Args:
            intent_classifier (optional): module called on the hidden data to produce
                intent predictions; must also provide ``decode``, ``compute_loss``
                and a dict-like ``config``. Defaults to None.
            slot_classifier (optional): same contract as ``intent_classifier``, for
                slot predictions. Defaults to None.
            interaction (optional): module applied to the hidden data before the
                classifiers. Defaults to None.
        """
        super().__init__()
        self.intent_classifier = intent_classifier
        self.slot_classifier = slot_classifier
        self.interaction = interaction

    def forward(self, hidden: "HiddenData"):
        """forward

        Args:
            hidden (HiddenData): encoded data

        Returns:
            OutputData: prediction logits; a field is None when its classifier is absent.
        """
        if self.interaction is not None:
            hidden = self.interaction(hidden)
        intent = None
        slot = None
        if self.intent_classifier is not None:
            intent = self.intent_classifier(hidden)
        if self.slot_classifier is not None:
            slot = self.slot_classifier(hidden)
        return OutputData(intent, slot)

    def decode(self, output: "OutputData", target: "InputData" = None):
        """decode output logits

        Args:
            output (OutputData): output logits data
            target (InputData, optional): input data with attention mask. Defaults to None.

        Returns:
            OutputData: decoded intent/slot results; a field is None when its
            classifier is absent.
        """
        intent, slot = None, None
        if self.intent_classifier is not None:
            intent = self.intent_classifier.decode(output, target)
        if self.slot_classifier is not None:
            slot = self.slot_classifier.decode(output, target)
        return OutputData(intent, slot)

    @staticmethod
    def _loss_weight(classifier):
        """Return the classifier's configured 'weight', or 1.0 when it is unset/None."""
        weight = classifier.config.get("weight")
        return 1. if weight is None else weight

    def compute_loss(self, pred: "OutputData", target: "InputData", compute_intent_loss=True, compute_slot_loss=True):
        """compute loss.

        Notice: can set intent and slot loss weight by adding 'weight' config item
        in corresponding classifier configuration (defaults to 1.0).

        Args:
            pred (OutputData): output logits data
            target (InputData): input golden data
            compute_intent_loss (bool, optional): whether to compute intent loss. Defaults to True.
            compute_slot_loss (bool, optional): whether to compute slot loss. Defaults to True.

        Returns:
            tuple: ``(loss, intent_loss, slot_loss)`` where ``loss`` is the weighted
            sum of the computed losses (``0`` when none was computed), and
            ``intent_loss`` / ``slot_loss`` are the unweighted per-task losses,
            or None when that classifier is absent or its loss was disabled.
        """
        loss = 0
        intent_loss = None
        slot_loss = None
        # Only accumulate a task's loss when it was actually computed; a disabled
        # task must contribute nothing (multiplying None by the weight would raise).
        if self.intent_classifier is not None and compute_intent_loss:
            intent_loss = self.intent_classifier.compute_loss(pred, target)
            loss += intent_loss * self._loss_weight(self.intent_classifier)
        if self.slot_classifier is not None and compute_slot_loss:
            slot_loss = self.slot_classifier.compute_loss(pred, target)
            loss += slot_loss * self._loss_weight(self.slot_classifier)
        return loss, intent_loss, slot_loss
class StackPropagationDecoder(BaseDecoder):
    """Decoder that conditions slot prediction on the intent prediction.

    The intent classifier runs first on the raw hidden data. Its output is then
    handed, together with the hidden data, to the interaction module, and the
    slot classifier runs on whatever the interaction returns. All three
    components (intent classifier, interaction, slot classifier) are required.
    """

    def forward(self, hidden: HiddenData):
        """Run intent classification, interaction, then slot classification.

        Args:
            hidden (HiddenData): encoded data

        Returns:
            OutputData: intent prediction and slot prediction.
        """
        intent_output = self.intent_classifier(hidden)
        # The interaction receives the intent prediction first, then the hidden data.
        slot_input = self.interaction(intent_output, hidden)
        return OutputData(intent_output, self.slot_classifier(slot_input))
class DCANetDecoder(BaseDecoder):
    """Decoder whose interaction module is given direct access to both classifiers."""

    def forward(self, hidden: HiddenData):
        """Optionally apply the interaction, then run both classifiers on the result.

        Args:
            hidden (HiddenData): encoded data

        Returns:
            OutputData: intent prediction and slot prediction.
        """
        features = hidden
        if self.interaction is not None:
            # The classifier modules themselves are passed as intent_emb / slot_emb;
            # presumably the interaction reads embeddings from them — verify against it.
            features = self.interaction(
                hidden,
                intent_emb=self.intent_classifier,
                slot_emb=self.slot_classifier,
            )
        intent_output = self.intent_classifier(features)
        slot_output = self.slot_classifier(features)
        return OutputData(intent_output, slot_output)
|