import torch
import torch.nn as nn
import torch.nn.functional as F

from model.decoder.interaction.base_interaction import BaseInteraction


class GraphAttentionLayer(nn.Module):
    """
    Simple GAT layer, similar to https://arxiv.org/abs/1710.10903
    """

    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super(GraphAttentionLayer, self).__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat

        # Learnable projection W and shared attention vector a, initialized as
        # in the reference GAT implementation.
        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        self.a = nn.Parameter(torch.zeros(size=(2 * out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, input, adj):
        # input: [B, N, in_features]; adj: [N, N] (or broadcastable to [B, N, N])
        h = torch.matmul(input, self.W)
        B, N = h.size()[0], h.size()[1]

        # Build all N * N node-pair concatenations [h_i || h_j] and score each
        # pair with the attention vector a.
        a_input = torch.cat(
            [h.repeat(1, 1, N).view(B, N * N, -1), h.repeat(1, N, 1)], dim=2
        ).view(B, N, -1, 2 * self.out_features)
        e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(3))

        # Mask out non-edges with a large negative value before the softmax.
        zero_vec = -9e15 * torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)
        attention = F.softmax(attention, dim=2)
        attention = F.dropout(attention, self.dropout, training=self.training)
        h_prime = torch.matmul(attention, h)

        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime
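

# Illustrative sketch (not part of the original module): a forward pass through
# a single attention head on a toy fully-connected graph, documenting the
# expected tensor shapes. All sizes below are made-up example values.
def _demo_graph_attention_layer():
    layer = GraphAttentionLayer(in_features=8, out_features=4, dropout=0.1, alpha=0.2)
    x = torch.randn(2, 5, 8)  # [batch, nodes, in_features]
    adj = torch.ones(5, 5)    # fully connected: every node attends to every node
    out = layer(x, adj)
    assert out.shape == (2, 5, 4)  # [batch, nodes, out_features]
    return out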


class GAT(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads, nlayers=2):
        """Dense version of GAT."""
        super(GAT, self).__init__()
        self.dropout = dropout
        self.nlayers = nlayers
        self.nheads = nheads

        # First layer: nheads parallel attention heads over the raw features.
        self.attentions = [GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True)
                           for _ in range(nheads)]
        for i, attention in enumerate(self.attentions):
            self.add_module('attention_{}'.format(i), attention)

        # Optional middle layers, each again with nheads heads over the
        # concatenated head outputs (nhid * nheads features).
        if self.nlayers > 2:
            for i in range(self.nlayers - 2):
                for j in range(self.nheads):
                    self.add_module(
                        'attention_{}_{}'.format(i + 1, j),
                        GraphAttentionLayer(nhid * nheads, nhid, dropout=dropout, alpha=alpha, concat=True))

        # Output layer: a single head mapping back to nclass features.
        self.out_att = GraphAttentionLayer(nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False)

    def forward(self, x, adj):
        x = F.dropout(x, self.dropout, training=self.training)
        input = x
        x = torch.cat([att(x, adj) for att in self.attentions], dim=2)
        if self.nlayers > 2:
            for i in range(self.nlayers - 2):
                temp = []
                x = F.dropout(x, self.dropout, training=self.training)
                cur_input = x
                for j in range(self.nheads):
                    temp.append(self.__getattr__('attention_{}_{}'.format(i + 1, j))(x, adj))
                # Residual connection around each middle layer.
                x = torch.cat(temp, dim=2) + cur_input
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.elu(self.out_att(x, adj))
        # Residual connection from input to output (requires nfeat == nclass).
        return x + input


def normalize_adj(mx):
    """
    Row-normalize matrix D^{-1}A
    torch.diag_embed: https://github.com/pytorch/pytorch/pull/12447
    """
    mx = mx.float()
    rowsum = mx.sum(2)
    r_inv = torch.pow(rowsum, -1)
    # Rows with zero degree would produce inf; zero them out instead.
    r_inv[torch.isinf(r_inv)] = 0.
    r_mat_inv = torch.diag_embed(r_inv, 0)
    mx = r_mat_inv.matmul(mx)
    return mx
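

# Illustrative sketch (not part of the original module): row-normalizing a toy
# batched adjacency. Each row is divided by its degree, so row [1, 1, 0]
# becomes [0.5, 0.5, 0], and all-zero rows stay zero instead of producing inf.
def _demo_normalize_adj():
    adj = torch.tensor([[[1., 1., 0.],
                         [0., 1., 1.],
                         [0., 0., 0.]]])  # [batch=1, N=3, N=3]
    norm = normalize_adj(adj)
    assert torch.allclose(norm[0, 0], torch.tensor([0.5, 0.5, 0.]))
    return norm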


class AGIFInteraction(BaseInteraction):
    def __init__(self, **config):
        super().__init__(**config)
        # One embedding per intent label; these act as the intent nodes of the graph.
        self.intent_embedding = nn.Parameter(
            torch.FloatTensor(self.config["intent_label_num"], self.config["intent_embedding_dim"]))  # e.g. 191 x 32
        nn.init.normal_(self.intent_embedding.data)
        self.adj = None
        self.graph = GAT(
            config["output_dim"],
            config["hidden_dim"],
            config["output_dim"],
            config["dropout_rate"],
            config["alpha"],
            config["num_heads"],
            config["num_layers"])

    def generate_adj_gat(self, index, batch, intent_label_num):
        # index holds (sentence id, intent id) pairs; node 0 is the utterance
        # node, so intent ids are shifted by one.
        intent_idx_ = [[torch.tensor(0)] for i in range(batch)]
        for item in index:
            intent_idx_[item[0]].append(item[1] + 1)
        intent_idx = intent_idx_
        # Start from self-loops, then connect the utterance node and each
        # sentence's intents into a clique.
        self.adj = torch.cat([torch.eye(intent_label_num + 1).unsqueeze(0) for i in range(batch)])
        for i in range(batch):
            for j in intent_idx[i]:
                self.adj[i, j, intent_idx[i]] = 1.
        if self.config["row_normalized"]:
            self.adj = normalize_adj(self.adj)
        self.adj = self.adj.to(self.intent_embedding.device)

    def forward(self, encode_hidden, **interaction_args):
        # Rebuild the adjacency at the start of each batch.
        if self.adj is None or interaction_args["sent_id"] == 0:
            self.generate_adj_gat(interaction_args["intent_index"],
                                  interaction_args["batch_size"],
                                  interaction_args["intent_label_num"])
        # Concatenate the hidden states with the intent embeddings along the
        # node dimension, then run the GAT over this sentence's adjacency.
        lstm_out = torch.cat((encode_hidden,
                              self.intent_embedding.unsqueeze(0).repeat(encode_hidden.shape[0], 1, 1)), dim=1)
        return self.graph(lstm_out, self.adj[interaction_args["sent_id"]])
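

if __name__ == "__main__":
    # Hypothetical smoke test (not part of the original module). It assumes the
    # OpenSLU repo layout so that the BaseInteraction import above resolves, and
    # the config values below are assumptions chosen so the shapes line up:
    # output_dim must equal intent_embedding_dim for the concatenation in
    # forward(). Real values come from the project's configuration files.
    config = {
        "intent_label_num": 5,
        "intent_embedding_dim": 32,
        "output_dim": 32,
        "hidden_dim": 16,
        "dropout_rate": 0.1,
        "alpha": 0.2,
        "num_heads": 4,
        "num_layers": 2,
        "row_normalized": True,
    }
    module = AGIFInteraction(**config)
    # One token node per graph step: [num_tokens, 1, output_dim], so each token
    # plus the intent_label_num intent nodes match the (L + 1)-node adjacency.
    encode_hidden = torch.randn(7, 1, config["output_dim"])
    out = module(
        encode_hidden,
        sent_id=0,
        batch_size=1,
        intent_index=[(0, 2), (0, 4)],  # assumed (sentence id, intent id) pairs
        intent_label_num=config["intent_label_num"],
    )
    print(out.shape)  # expected: torch.Size([7, 6, 32])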