File size: 1,994 Bytes
37b9e99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
'''
Author: Qiguang Chen
Date: 2023-01-11 10:39:26
LastEditors: Qiguang Chen
LastEditTime: 2023-01-26 17:18:22
Description: Root Model Module

'''
from typing import Optional

from torch import nn

from common.utils import OutputData, InputData
from model.decoder.base_decoder import BaseDecoder
from model.encoder.base_encoder import BaseEncoder


class OpenSLUModel(nn.Module):
    """Top-level SLU model that chains an encoder and a decoder.

    ``forward`` feeds the encoder's output directly into the decoder.
    ``decode`` and ``compute_loss`` are thin delegations to the decoder.
    """

    def __init__(self, encoder: BaseEncoder, decoder: BaseDecoder, **config):
        """Assemble the model from an already-built encoder and decoder.

        Args:
            encoder (BaseEncoder): encoder created from the config.
            decoder (BaseDecoder): decoder created from the config.
            **config: any other keyword arguments; kept unchanged as a dict
                on ``self.config``.
        """
        super().__init__()
        # NOTE(review): if BaseEncoder/BaseDecoder subclass nn.Module (presumed,
        # since both are called like modules below), assigning them here
        # registers them as submodules of this model — confirm.
        self.encoder = encoder
        self.decoder = decoder
        self.config = config

    def forward(self, inp: InputData) -> OutputData:
        """Encode ``inp`` and run the decoder on the encoder's output.

        Args:
            inp (InputData): input ids and other information.

        Returns:
            OutputData: prediction logits produced by the decoder.
        """
        encoded = self.encoder(inp)
        return self.decoder(encoded)

    def decode(self, output: OutputData, target: Optional[InputData] = None):
        """Decode prediction logits via the decoder.

        Args:
            output (OutputData): prediction logits data, e.g. the result of
                ``forward``.
            target (InputData, optional): gold data, passed through to the
                decoder unchanged. Defaults to None.

        Returns:
            Whatever ``self.decoder.decode`` returns (the decoded ids).
        """
        return self.decoder.decode(output, target)

    def compute_loss(self, pred: OutputData, target: InputData, compute_intent_loss=True, compute_slot_loss=True):
        """Compute the training loss via the decoder.

        Args:
            pred (OutputData): prediction logits data.
            target (InputData): gold data.
            compute_intent_loss (bool, optional): whether to compute the intent
                loss. Defaults to True.
            compute_slot_loss (bool, optional): whether to compute the slot
                loss. Defaults to True.

        Returns:
            Whatever ``self.decoder.compute_loss`` returns (the loss value).
        """
        return self.decoder.compute_loss(
            pred,
            target,
            compute_intent_loss=compute_intent_loss,
            compute_slot_loss=compute_slot_loss,
        )