File size: 4,158 Bytes
37b9e99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223340a
37b9e99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223340a
 
 
 
 
 
 
37b9e99
 
 
 
 
 
 
 
 
 
 
223340a
 
 
 
 
 
37b9e99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223340a
 
 
 
 
 
37b9e99
223340a
 
 
 
37b9e99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
'''
Author: Qiguang Chen
Date: 2023-01-11 10:39:26
LastEditors: Qiguang Chen
LastEditTime: 2023-01-31 18:22:36
Description: 

'''
from torch import nn

from common.utils import HiddenData, OutputData, InputData


class BaseDecoder(nn.Module):
    """Base class for all decoder modules.

    Notice: It is often only necessary to change this module and its sub-modules.
    """

    def __init__(self, intent_classifier=None, slot_classifier=None, interaction=None):
        """Store the optional sub-modules used by the decoder.

        Args:
            intent_classifier (optional): callable applied to the hidden data to predict intents.
                It must also provide `decode`, `compute_loss` and a `config` object with `get`.
                Defaults to None.
            slot_classifier (optional): callable applied to the hidden data to predict slots.
                It must provide the same methods as `intent_classifier`. Defaults to None.
            interaction (optional): callable applied to the hidden data before classification.
                Defaults to None.
        """
        super().__init__()
        self.intent_classifier = intent_classifier
        self.slot_classifier = slot_classifier
        self.interaction = interaction

    def forward(self, hidden: HiddenData):
        """forward

        Args:
            hidden (HiddenData): encoded data

        Returns:
            OutputData: prediction logits; a field is None when its classifier is not configured.
        """
        if self.interaction is not None:
            hidden = self.interaction(hidden)
        intent = None
        slot = None
        if self.intent_classifier is not None:
            intent = self.intent_classifier(hidden)
        if self.slot_classifier is not None:
            slot = self.slot_classifier(hidden)
        return OutputData(intent, slot)

    def decode(self, output: OutputData, target: InputData = None):
        """decode output logits

        Args:
            output (OutputData): output logits data
            target (InputData, optional): input data with attention mask. Defaults to None.

        Returns:
            OutputData: decoded intent and slot results; a field is None when its
                classifier is not configured.
        """
        intent, slot = None, None
        if self.intent_classifier is not None:
            intent = self.intent_classifier.decode(output, target)
        if self.slot_classifier is not None:
            slot = self.slot_classifier.decode(output, target)
        return OutputData(intent, slot)

    @staticmethod
    def _loss_weight(classifier):
        """Return the 'weight' config item of `classifier`, or 1.0 when it is not set."""
        weight = classifier.config.get("weight")
        return 1. if weight is None else weight

    def compute_loss(self, pred: OutputData, target: InputData, compute_intent_loss=True, compute_slot_loss=True):
        """compute loss.
        Notice: can set intent and slot loss weight by adding 'weight' config item in corresponding classifier configuration.

        Args:
            pred (OutputData): output logits data
            target (InputData): input golden data
            compute_intent_loss (bool, optional): whether to compute intent loss. Defaults to True.
            compute_slot_loss (bool, optional): whether to compute slot loss. Defaults to True.

        Returns:
            tuple: `(loss, intent_loss, slot_loss)`, where `loss` is the weighted sum of the
                computed losses (0 when none is computed) and each individual loss is None
                when its classifier is missing or its computation is disabled.
        """
        loss = 0
        intent_loss = None
        slot_loss = None
        # A disabled loss must be skipped entirely: multiplying None by the
        # weight would raise a TypeError.
        if self.intent_classifier is not None and compute_intent_loss:
            intent_loss = self.intent_classifier.compute_loss(pred, target)
            loss += intent_loss * self._loss_weight(self.intent_classifier)
        if self.slot_classifier is not None and compute_slot_loss:
            slot_loss = self.slot_classifier.compute_loss(pred, target)
            loss += slot_loss * self._loss_weight(self.slot_classifier)
        return loss, intent_loss, slot_loss


class StackPropagationDecoder(BaseDecoder):
    """Decoder that passes the intent prediction through the interaction module before slot prediction."""

    def forward(self, hidden: HiddenData):
        """Predict intent first, then predict slots from the interaction output.

        The intent classifier, the interaction module and the slot classifier
        are all called unconditionally.

        Args:
            hidden (HiddenData): encoded data

        Returns:
            OutputData: intent and slot predictions
        """
        intent_output = self.intent_classifier(hidden)
        # The interaction module receives the intent prediction together with
        # the hidden data; its result is the input of the slot classifier.
        slot_input = self.interaction(intent_output, hidden)
        slot_output = self.slot_classifier(slot_input)
        return OutputData(intent_output, slot_output)

class DCANetDecoder(BaseDecoder):
    """Decoder whose interaction module is also given both classifiers."""

    def forward(self, hidden: HiddenData):
        """Apply the optional interaction module, then run both classifiers.

        Args:
            hidden (HiddenData): encoded data

        Returns:
            OutputData: intent and slot predictions
        """
        if self.interaction is not None:
            # The classifiers themselves are handed to the interaction module
            # through the `intent_emb` and `slot_emb` keyword arguments.
            hidden = self.interaction(
                hidden,
                intent_emb=self.intent_classifier,
                slot_emb=self.slot_classifier,
            )
        intent_output = self.intent_classifier(hidden)
        slot_output = self.slot_classifier(hidden)
        return OutputData(intent_output, slot_output)