Spaces:
Runtime error
Runtime error
import math | |
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
from torch.nn import LayerNorm | |
from common.utils import HiddenData | |
from model.decoder.interaction import BaseInteraction | |
class DCANetInteraction(BaseInteraction):
    """DCA-Net co-interactive decoder: two rounds of label attention + I_S_Block.

    Each round re-expresses the current intent/slot features through label
    attention (``Label_Attention``), adds them residually, and feeds both
    streams through an ``I_S_Block``. The final intent representation is a
    max-pool over the sequence axis; the slot representation is per-token.
    """

    def __init__(self, **config):
        super().__init__(**config)
        self.I_S_Emb = Label_Attention()
        # Both co-attention blocks share the same hyper-parameters.
        block_args = (
            self.config["input_dim"],
            self.config["attention_dropout"],
            self.config["num_attention_heads"],
        )
        self.T_block1 = I_S_Block(*block_args)
        self.T_block2 = I_S_Block(*block_args)

    def forward(self, encode_hidden: HiddenData, **kwargs):
        """Run the two interaction rounds and write results back into ``encode_hidden``.

        Args:
            encode_hidden: encoder output; ``inputs.attention_mask`` and
                ``slot_hidden`` are read, and the intent/slot hidden states are updated.
            **kwargs: must contain ``"intent_emb"`` and ``"slot_emb"``, the objects
                whose classifier weights ``Label_Attention`` uses as label embeddings.

        Returns:
            The same ``encode_hidden`` object, updated in place.
        """
        attention_mask = encode_hidden.inputs.attention_mask
        base = encode_hidden.slot_hidden
        intent_emb = kwargs["intent_emb"]
        slot_emb = kwargs["slot_emb"]

        # Round 1: label attention on the raw encoder features.
        label_intent, label_slot = self.I_S_Emb(base, base, intent_emb, slot_emb)
        intent_h, slot_h = self.T_block1(label_intent + base, label_slot + base, attention_mask)

        # Round 2: label attention on the round-1 outputs.
        label_intent, label_slot = self.I_S_Emb(intent_h, slot_h, intent_emb, slot_emb)
        intent_h, slot_h = self.T_block2(intent_h + label_intent, slot_h + label_slot, attention_mask)

        # NOTE(review): the pool spans every position, padding included — the mask
        # is not applied here; confirm this is intended.
        pooled_intent = F.max_pool1d((intent_h + base).transpose(1, 2), intent_h.size(1)).squeeze(2)
        encode_hidden.update_intent_hidden_state(pooled_intent)
        encode_hidden.update_slot_hidden_state(slot_h + base)
        return encode_hidden
class Label_Attention(nn.Module):
    """Label-embedding attention used by DCA-Net.

    The weight matrices of the intent and slot classifiers serve as label
    embeddings. Each token scores every label, the scores are softmax-normalised,
    and the token is re-expressed as the probability-weighted sum of label
    embeddings. This module owns no parameters of its own.
    """

    def __init__(self):
        super(Label_Attention, self).__init__()

    def forward(self, input_intent, input_slot, intent_emb, slot_emb):
        """Attend over intent labels and slot labels.

        Args:
            input_intent: features whose last dim equals the intent classifier's
                input width.
            input_slot: features whose last dim equals the slot classifier's
                input width.
            intent_emb: object exposing ``intent_classifier.weight``
                (presumably an ``nn.Linear`` weight of shape (num_intents, dim) — confirm).
            slot_emb: object exposing ``slot_classifier.weight``
                (presumably (num_slots, dim) — confirm).

        Returns:
            ``(intent_res, slot_res)``, each with the same last dim as the
            classifier weights' second axis.
        """
        # Bind the classifier weights to locals. The previous version stored them
        # on ``self``; because they are nn.Parameters, nn.Module.__setattr__
        # registered them as parameters of this module on the first call, so
        # state_dict() grew extra W_intent_emb / W_slot_emb keys mid-training.
        # NOTE(review): checkpoints saved by that version may still carry
        # ``I_S_Emb.W_intent_emb`` / ``I_S_Emb.W_slot_emb`` keys — drop them or
        # load non-strictly.
        w_intent = intent_emb.intent_classifier.weight
        w_slot = slot_emb.slot_classifier.weight

        intent_probs = torch.softmax(torch.matmul(input_intent, w_intent.t()), dim=-1)
        slot_probs = torch.softmax(torch.matmul(input_slot, w_slot.t()), dim=-1)

        intent_res = torch.matmul(intent_probs, w_intent)
        slot_res = torch.matmul(slot_probs, w_slot)
        return intent_res, slot_res
class I_S_Block(nn.Module):
    """One co-interactive transformer block over the intent and slot streams.

    Cross-attention between the two streams, a residual + LayerNorm output
    layer per stream, then a shared windowed feed-forward layer.

    Args:
        hidden_size: feature size of both streams.
        attention_dropout: dropout probability used throughout the block.
        num_attention_heads: number of attention heads.
    """

    def __init__(self, hidden_size, attention_dropout, num_attention_heads):
        super(I_S_Block, self).__init__()
        # Query/key projections are twice the stream width; values/outputs keep it.
        self.I_S_Attention = I_S_SelfAttention(
            hidden_size, 2 * hidden_size, hidden_size, attention_dropout, num_attention_heads
        )
        self.I_Out = SelfOutput(hidden_size, attention_dropout)
        self.S_Out = SelfOutput(hidden_size, attention_dropout)
        self.I_S_Feed_forward = Intermediate_I_S(hidden_size, hidden_size, attention_dropout)

    def forward(self, H_intent_input, H_slot_input, mask):
        """Return ``(intent_hidden, slot_hidden)`` after one interaction block."""
        # The attention layer returns the slot-side context first.
        slot_ctx, intent_ctx = self.I_S_Attention(H_intent_input, H_slot_input, mask)
        # Slot output layer runs before the intent one (keeps dropout RNG order).
        slot_out = self.S_Out(slot_ctx, H_slot_input)
        intent_out = self.I_Out(intent_ctx, H_intent_input)
        new_intent, new_slot = self.I_S_Feed_forward(intent_out, slot_out)
        return new_intent, new_slot
class I_S_SelfAttention(nn.Module):
    """Bidirectional multi-head cross-attention between intent and slot streams.

    The intent stream queries the slot stream (``query`` / ``key`` / ``value``),
    and the slot stream queries the intent stream (``query_slot`` / ``key_slot`` /
    ``value_slot``). Both directions use the same additive padding mask over the
    key positions.

    Args:
        input_size: feature size of both input streams.
        hidden_size: total query/key projection width, split evenly across heads.
        out_size: total value projection width (the output width); must be
            divisible by ``num_attention_heads``.
        attention_dropout: dropout probability applied to the attention weights.
        num_attention_heads: number of heads.
    """

    def __init__(self, input_size, hidden_size, out_size, attention_dropout, num_attention_heads):
        super(I_S_SelfAttention, self).__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = int(hidden_size / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.out_size = out_size
        self.query = nn.Linear(input_size, self.all_head_size)
        self.query_slot = nn.Linear(input_size, self.all_head_size)
        self.key = nn.Linear(input_size, self.all_head_size)
        self.key_slot = nn.Linear(input_size, self.all_head_size)
        self.value = nn.Linear(input_size, self.out_size)
        self.value_slot = nn.Linear(input_size, self.out_size)
        self.dropout = nn.Dropout(attention_dropout)

    def transpose_for_scores(self, x):
        """Reshape ``(batch, seq, heads * d)`` into ``(batch, heads, seq, d)``."""
        last_dim = int(x.size()[-1] / self.num_attention_heads)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, last_dim)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def _attend(self, query, key, value, attention_mask):
        """Masked scaled dot-product attention for one direction -> ``(batch, seq, out_size)``."""
        query_layer = self.transpose_for_scores(query)
        key_layer = self.transpose_for_scores(key)
        value_layer = self.transpose_for_scores(value)

        scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        scores = scores / math.sqrt(self.attention_head_size)
        scores = scores + attention_mask
        probs = self.dropout(torch.softmax(scores, dim=-1))

        context = torch.matmul(probs, value_layer).permute(0, 2, 1, 3).contiguous()
        return context.view(*context.size()[:-2], self.out_size)

    def forward(self, intent, slot, mask):
        """Return ``(slot_context, intent_context)``.

        Args:
            intent: intent-stream features, last dim ``input_size``.
            slot: slot-stream features, last dim ``input_size``.
            mask: padding mask with 1 for real tokens and 0 for padding; it is
                broadcast over heads and query positions via two ``unsqueeze`` calls.
        """
        extended_attention_mask = mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        # Large negative bias on padded keys so softmax gives them ~0 weight.
        attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # Slot direction is computed first so dropout consumes RNG in the same
        # order as before the refactor.
        slot_context = self._attend(
            self.query_slot(slot), self.key_slot(intent), self.value_slot(intent), attention_mask
        )
        intent_context = self._attend(
            self.query(intent), self.key(slot), self.value(slot), attention_mask
        )
        return slot_context, intent_context
class SelfOutput(nn.Module):
    """Post-attention output layer: Linear -> dropout -> residual add -> LayerNorm.

    Args:
        hidden_size: feature size of both the input and the residual.
        hidden_dropout_prob: dropout probability applied after the projection.
    """

    def __init__(self, hidden_size, hidden_dropout_prob):
        super(SelfOutput, self).__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.LayerNorm = LayerNorm(hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Project ``hidden_states`` and normalise it together with the residual ``input_tensor``."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
class Intermediate_I_S(nn.Module):
    """Shared feed-forward layer over both streams with a one-token context window.

    At every position t the concatenated features ``[I_t; S_t]`` are joined with
    the same features at t-1 and t+1 (zeros past either sequence edge). The
    result goes through Linear -> ReLU -> Linear -> dropout and is added
    residually to each stream, and each stream then passes through its own
    LayerNorm.

    Args:
        intermediate_size: width of the hidden feed-forward layer.
        hidden_size: feature size of each input stream.
        attention_dropout: dropout probability applied to the feed-forward output.
    """

    def __init__(self, intermediate_size, hidden_size, attention_dropout):
        super(Intermediate_I_S, self).__init__()
        # 6x = (intent + slot) features for each of positions t, t-1 and t+1.
        self.dense_in = nn.Linear(hidden_size * 6, intermediate_size)
        self.intermediate_act_fn = nn.ReLU()
        self.dense_out = nn.Linear(intermediate_size, hidden_size)
        self.LayerNorm_I = LayerNorm(hidden_size, eps=1e-12)
        self.LayerNorm_S = LayerNorm(hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, hidden_states_I, hidden_states_S):
        """Return the updated ``(intent, slot)`` features.

        Args:
            hidden_states_I: intent features of shape (batch, seq, hidden_size).
            hidden_states_S: slot features of shape (batch, seq, hidden_size).

        Returns:
            Tuple of two tensors, each with the shape of its input.
        """
        features = torch.cat([hidden_states_I, hidden_states_S], dim=2)
        batch_size, _, feat_dim = features.size()
        # new_zeros inherits dtype and device from ``features``. The previous
        # torch.zeros(...).to(device) was always float32, which promoted the
        # window to float32 and made dense_in fail for half/bfloat16 weights.
        pad = features.new_zeros(batch_size, 1, feat_dim)
        left = torch.cat([pad, features[:, :-1, :]], dim=1)    # features at t-1
        right = torch.cat([features[:, 1:, :], pad], dim=1)    # features at t+1
        # NOTE(review): no padding mask is applied, so if sequences are padded the
        # last real token's right neighbour is a padding position — confirm intended.
        window = torch.cat([features, left, right], dim=2)

        hidden_states = self.dense_out(self.intermediate_act_fn(self.dense_in(window)))
        hidden_states = self.dropout(hidden_states)

        hidden_states_I_NEW = self.LayerNorm_I(hidden_states + hidden_states_I)
        hidden_states_S_NEW = self.LayerNorm_S(hidden_states + hidden_states_S)
        return hidden_states_I_NEW, hidden_states_S_NEW