File size: 1,249 Bytes
37b9e99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
'''
Author: Qiguang Chen
Date: 2023-01-11 10:39:26
LastEditors: Qiguang Chen
LastEditTime: 2023-01-26 17:25:17
Description: Base encoder and bi encoder

'''
from torch import nn

from common.utils import InputData


class BaseEncoder(nn.Module):
    """Base class for all encoder modules.

    Keeps the keyword configuration on ``self.config``. The default
    ``forward`` delegates to ``self.encoder``, which this class does not
    create itself; presumably subclasses assign it or override ``forward``.
    """

    def __init__(self, **config):
        """Initialise the module and remember its configuration.

        Args:
            **config: Arbitrary keyword settings, stored unchanged as the
                dict ``self.config``.
        """
        super().__init__()
        self.config = config
        # Intentionally no exception here: raising would make
        # ``super().__init__(**config)`` unusable from subclasses.

    def forward(self, inputs: InputData):
        """Encode ``inputs.input_ids`` and return the encoder output.

        Args:
            inputs: Object exposing an ``input_ids`` attribute.

        Returns:
            Whatever ``self.encoder`` returns for ``inputs.input_ids``.

        Raises:
            AttributeError: If no ``encoder`` attribute has been set on
                this instance.
        """
        # The result must be returned, otherwise callers always get None.
        return self.encoder(inputs.input_ids)


class BiEncoder(nn.Module):
    """Pair of encoders that encode intent and slot separately.

    Both encoders receive the same inputs; the slot encoder's output is
    returned after its intent hidden state has been replaced with a hidden
    state taken from the intent encoder's output.
    """

    def __init__(self, intent_encoder: BaseEncoder, slot_encoder: BaseEncoder, **config):
        """Store the two encoders.

        Args:
            intent_encoder: Encoder whose output supplies the intent hidden
                state; its ``config`` is read in ``forward``.
            slot_encoder: Encoder whose output is updated and returned.
            **config: Accepted for interface compatibility; not used here.
        """
        super().__init__()
        # Assignment order is kept as-is: it fixes submodule registration order.
        self.intent_encoder = intent_encoder
        self.slot_encoder = slot_encoder

    def forward(self, inputs: InputData):
        """Run both encoders and merge the intent state into the slot output.

        Args:
            inputs: Batch passed unchanged to both encoders.

        Returns:
            The slot encoder's output, after ``update_intent_hidden_state``
            has been called on it.

        Raises:
            KeyError: If the intent encoder's config has no
                ``"return_sentence_level_hidden"`` entry.
        """
        slot_output = self.slot_encoder(inputs)
        intent_output = self.intent_encoder(inputs)

        sentence_level = self.intent_encoder.config["return_sentence_level_hidden"]
        if sentence_level:
            # The intent encoder produced a sentence-level state: use it directly.
            intent_state = intent_output.get_intent_hidden_state()
        else:
            # Otherwise fall back to the intent encoder's slot hidden state.
            intent_state = intent_output.get_slot_hidden_state()

        slot_output.update_intent_hidden_state(intent_state)
        return slot_output