import einops
import torch
from torch import nn

from common.utils import HiddenData
from model.decoder.interaction import BaseInteraction


class SlotGatedInteraction(BaseInteraction):
    """Slot-gated interaction between intent and slot hidden states (Goo et al., 2018)."""

    def __init__(self, **config):
        super().__init__(**config)
        # Scoring layers that project hidden states down to scalar attention energies.
        self.intent_linear = nn.Linear(self.config["input_dim"], 1, bias=False)
        self.slot_linear1 = nn.Linear(self.config["input_dim"], 1, bias=False)
        self.slot_linear2 = nn.Linear(self.config["input_dim"], 1, bias=False)
        self.remove_slot_attn = self.config["remove_slot_attn"]
        self.slot_gate = SlotGate(**config)
    def forward(self, encode_hidden: HiddenData, **kwargs):
        input_hidden = encode_hidden.get_slot_hidden_state()

        # Use the last valid hidden state of each sequence as the intent representation.
        seq_lens = encode_hidden.inputs.attention_mask.sum(-1)
        output_list = []
        for index, slen in enumerate(seq_lens):
            output_list.append(input_hidden[index, slen - 1, :].unsqueeze(0))
        intent_input = torch.cat(output_list, dim=0)

        # Intent attention: a scalar energy per sequence scales the intent vector.
        e_I = torch.tanh(self.intent_linear(intent_input)).squeeze(1)
        alpha_I = einops.repeat(e_I, 'b -> b h', h=intent_input.shape[-1])
        c_I = alpha_I * intent_input
        intent_hidden = intent_input + c_I

        if not self.remove_slot_attn:
            # Slot attention: score every (query i, key j) pair of time steps,
            # then take a softmax-weighted sum of the hidden states over j.
            h_k = einops.repeat(self.slot_linear1(input_hidden), 'b l h -> b l c h', c=input_hidden.shape[1])
            h_i = einops.repeat(self.slot_linear2(input_hidden), 'b l h -> b l c h', c=input_hidden.shape[1]).transpose(1, 2)
            e_S = torch.tanh(h_k + h_i)
            alpha_S = torch.softmax(e_S, dim=2).squeeze(3)
            alpha_S = einops.repeat(alpha_S, 'b l1 l2 -> b l1 l2 h', h=input_hidden.shape[-1])
            # Broadcast the hidden states along the key dimension so that each
            # query position i forms a weighted sum over all positions j.
            map_input_hidden = einops.repeat(input_hidden, 'b l h -> b c l h', c=input_hidden.shape[1])
            c_S = torch.sum(alpha_S * map_input_hidden, dim=2)
        else:
            c_S = input_hidden

        # Fuse the slot context with the intent context through the slot gate.
        slot_hidden = input_hidden + c_S * self.slot_gate(c_S, c_I)
        encode_hidden.update_intent_hidden_state(intent_hidden)
        encode_hidden.update_slot_hidden_state(slot_hidden)
        return encode_hidden


class SlotGate(nn.Module):
    """Gate v * tanh(c_S + W * c_I) that lets the intent context modulate the slot context."""

    def __init__(self, **config):
        super().__init__()
        self.linear = nn.Linear(config["input_dim"], config["output_dim"], bias=False)
        self.v = nn.Parameter(torch.rand(size=[1]))

    def forward(self, slot_context, intent_context):
        # Project the intent context and broadcast it across all time steps.
        intent_gate = self.linear(intent_context)
        intent_gate = einops.repeat(intent_gate, 'b h -> b l h', l=slot_context.shape[1])
        return self.v * torch.tanh(slot_context + intent_gate)
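

if __name__ == "__main__":
    # Minimal shape check for SlotGate in isolation: a sketch assuming
    # input_dim == output_dim so the gate adds cleanly to the slot context.
    # The batch size, sequence length, and hidden size below (2, 5, 8) are
    # arbitrary illustration values, not part of the repository config.
    gate = SlotGate(input_dim=8, output_dim=8)
    slot_context = torch.randn(2, 5, 8)   # (batch, seq_len, hidden)
    intent_context = torch.randn(2, 8)    # (batch, hidden)
    print(gate(slot_context, intent_context).shape)  # expected: torch.Size([2, 5, 8])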