# tgn-playground / modules / embedding_module.py
# ashu316's picture
# Upload 14 files
# 41aae2b verified
import torch
from torch import nn
import numpy as np
import math
from model.temporal_attention import TemporalAttentionLayer
class EmbeddingModule(nn.Module):
    """Base class for modules that compute node embeddings.

    The constructor only records its arguments as attributes; the actual
    computation is supplied by subclasses overriding ``compute_embedding``.
    """

    def __init__(self, node_features, edge_features, memory, neighbor_finder, time_encoder, n_layers,
                 n_node_features, n_edge_features, n_time_features, embedding_dimension, device,
                 dropout):
        """Store the shared configuration on the instance.

        ``memory`` is accepted but deliberately not stored: every call to
        ``compute_embedding`` receives the memory as an explicit argument.
        """
        super().__init__()
        self.node_features = node_features
        self.edge_features = edge_features
        self.neighbor_finder = neighbor_finder
        self.time_encoder = time_encoder
        self.n_layers = n_layers
        self.n_node_features = n_node_features
        self.n_edge_features = n_edge_features
        self.n_time_features = n_time_features
        self.dropout = dropout
        self.embedding_dimension = embedding_dimension
        self.device = device

    def compute_embedding(self, memory, source_nodes, timestamps, n_layers, n_neighbors=20, time_diffs=None,
                          use_time_proj=True):
        """Compute embeddings for ``source_nodes``; must be overridden.

        Raises:
            NotImplementedError: always, in this base class. Raising (rather
                than returning the ``NotImplemented`` singleton) makes a
                missing override fail at the call site instead of letting a
                non-tensor value propagate into later computations.
        """
        raise NotImplementedError(
            "{} must implement compute_embedding()".format(type(self).__name__))
class IdentityEmbedding(EmbeddingModule):
    """Embedding module that hands back the memory rows of the queried nodes unchanged."""

    def compute_embedding(self, memory, source_nodes, timestamps, n_layers, n_neighbors=20, time_diffs=None,
                          use_time_proj=True):
        """Return ``memory[source_nodes, :]``; every other argument is ignored."""
        selected_rows = memory[source_nodes, :]
        return selected_rows
class TimeEmbedding(EmbeddingModule):
    """Embedding that rescales a node's memory by a learned linear function of a time difference.

    The result is ``memory[source_nodes, :] * (1 + embedding_layer(time_diffs))``.
    """

    # Defined at class level rather than inside __init__ so the class is
    # created once and is reachable by qualified name (classes defined inside
    # a function body cannot be pickled).
    class NormalLinear(nn.Linear):
        # From Jodie code
        def reset_parameters(self):
            """Draw weight and bias from a normal with mean 0 and std 1/sqrt(in_features)."""
            stdv = 1. / math.sqrt(self.weight.size(1))
            self.weight.data.normal_(0, stdv)
            if self.bias is not None:
                self.bias.data.normal_(0, stdv)

    def __init__(self, node_features, edge_features, memory, neighbor_finder, time_encoder, n_layers,
                 n_node_features, n_edge_features, n_time_features, embedding_dimension, device,
                 n_heads=2, dropout=0.1, use_memory=True, n_neighbors=1):
        """Store the shared configuration and build the 1 -> n_node_features time projection.

        ``n_heads``, ``use_memory`` and ``n_neighbors`` are accepted but not used
        by this class.
        """
        super().__init__(node_features, edge_features, memory,
                         neighbor_finder, time_encoder, n_layers,
                         n_node_features, n_edge_features, n_time_features,
                         embedding_dimension, device, dropout)

        self.embedding_layer = self.NormalLinear(1, self.n_node_features)

    def compute_embedding(self, memory, source_nodes, timestamps, n_layers, n_neighbors=20, time_diffs=None,
                          use_time_proj=True):
        """Return the memory rows of ``source_nodes`` scaled by the projected time differences.

        Args:
            memory: indexed as ``memory[source_nodes, :]``.
            source_nodes: index selecting rows of ``memory``.
            time_diffs: required despite the ``None`` default. It is
                unsqueezed at dim 1 and fed to a linear layer with one input
                feature, so it is expected to be a 1-D tensor with one entry
                per source node.
            timestamps, n_layers, n_neighbors, use_time_proj: unused here.

        Raises:
            ValueError: if ``time_diffs`` is ``None``.
        """
        if time_diffs is None:
            # Fail with a clear message instead of an AttributeError on None.unsqueeze.
            raise ValueError("TimeEmbedding.compute_embedding requires time_diffs")

        source_embeddings = memory[source_nodes, :] * (1 + self.embedding_layer(time_diffs.unsqueeze(1)))

        return source_embeddings
class GraphEmbedding(EmbeddingModule):
    """Base class for embeddings that recursively aggregate temporal neighbours.

    ``compute_embedding`` collects, for every query node, the embeddings of its
    temporal neighbours one layer down together with edge features and time
    encodings, and passes them to ``aggregate``, which subclasses must
    implement.
    """

    def __init__(self, node_features, edge_features, memory, neighbor_finder, time_encoder, n_layers,
                 n_node_features, n_edge_features, n_time_features, embedding_dimension, device,
                 n_heads=2, dropout=0.1, use_memory=True):
        """Store the shared configuration plus the ``use_memory`` switch.

        ``n_heads`` is accepted but not used by this class.
        """
        super().__init__(node_features, edge_features, memory,
                         neighbor_finder, time_encoder, n_layers,
                         n_node_features, n_edge_features, n_time_features,
                         embedding_dimension, device, dropout)

        # self.device is already set by EmbeddingModule.__init__.
        self.use_memory = use_memory

    def compute_embedding(self, memory, source_nodes, timestamps, n_layers, n_neighbors=20, time_diffs=None,
                          use_time_proj=True):
        """Recursively compute ``n_layers`` layers of temporal neighbourhood aggregation.

        Args:
            memory: indexed as ``memory[source_nodes, :]``; only read when
                ``self.use_memory`` is true.
            source_nodes: NumPy array of node ids to embed.
            timestamps: NumPy array holding the query time of each source node.
            n_layers: number of aggregation layers to stack; must be >= 0.
            n_neighbors: number of temporal neighbours requested per node at
                each layer.
            time_diffs, use_time_proj: unused here.

        Returns:
            For ``n_layers == 0`` the raw node features of ``source_nodes``
            (plus their memory rows when ``self.use_memory``); otherwise the
            result of ``self.aggregate`` for layer ``n_layers``.
        """
        assert (n_layers >= 0)

        source_nodes_torch = torch.from_numpy(source_nodes).long().to(self.device)
        timestamps_torch = torch.unsqueeze(torch.from_numpy(timestamps).float().to(self.device), dim=1)

        # query node always has the start time -> time span == 0
        source_nodes_time_embedding = self.time_encoder(torch.zeros_like(
            timestamps_torch))

        source_node_features = self.node_features[source_nodes_torch, :]

        if self.use_memory:
            source_node_features = memory[source_nodes, :] + source_node_features

        # Base case of the recursion: no aggregation, just the node's own features.
        if n_layers == 0:
            return source_node_features

        # Embedding of the query nodes themselves, one layer down.
        source_node_conv_embeddings = self.compute_embedding(memory,
                                                             source_nodes,
                                                             timestamps,
                                                             n_layers=n_layers - 1,
                                                             n_neighbors=n_neighbors)

        neighbors, edge_idxs, edge_times = self.neighbor_finder.get_temporal_neighbor(
            source_nodes,
            timestamps,
            n_neighbors=n_neighbors)

        neighbors_torch = torch.from_numpy(neighbors).long().to(self.device)

        edge_idxs = torch.from_numpy(edge_idxs).long().to(self.device)

        # Time elapsed between each query time and the corresponding neighbour edge.
        edge_deltas = timestamps[:, np.newaxis] - edge_times

        edge_deltas_torch = torch.from_numpy(edge_deltas).float().to(self.device)

        # Embed all neighbours in one flat batch, each at its source node's query time.
        neighbors = neighbors.flatten()
        neighbor_embeddings = self.compute_embedding(memory,
                                                     neighbors,
                                                     np.repeat(timestamps, n_neighbors),
                                                     n_layers=n_layers - 1,
                                                     n_neighbors=n_neighbors)

        # Restore the (source node, neighbour, feature) layout.
        effective_n_neighbors = n_neighbors if n_neighbors > 0 else 1
        neighbor_embeddings = neighbor_embeddings.view(len(source_nodes), effective_n_neighbors, -1)
        edge_time_embeddings = self.time_encoder(edge_deltas_torch)

        edge_features = self.edge_features[edge_idxs, :]

        # True where the neighbour id is 0 (presumably the padding id used by
        # the neighbor finder -- confirm against its implementation).
        mask = neighbors_torch == 0

        source_embedding = self.aggregate(n_layers, source_node_conv_embeddings,
                                          source_nodes_time_embedding,
                                          neighbor_embeddings,
                                          edge_time_embeddings,
                                          edge_features,
                                          mask)

        return source_embedding

    def aggregate(self, n_layers, source_node_features, source_nodes_time_embedding,
                  neighbor_embeddings,
                  edge_time_embeddings, edge_features, mask):
        """Combine a node's own features with those of its neighbours; must be overridden.

        Raises:
            NotImplementedError: always, in this base class. Raising (rather
                than returning the ``NotImplemented`` singleton) stops a
                missing override from being returned as if it were an embedding.
        """
        raise NotImplementedError(
            "{} must implement aggregate()".format(type(self).__name__))
class GraphSumEmbedding(GraphEmbedding):
    """Graph embedding that aggregates neighbours by a linear map, a sum and a ReLU."""

    def __init__(self, node_features, edge_features, memory, neighbor_finder, time_encoder, n_layers,
                 n_node_features, n_edge_features, n_time_features, embedding_dimension, device,
                 n_heads=2, dropout=0.1, use_memory=True):
        """Build one pair of linear layers per aggregation layer.

        ``linear_1[i]`` maps ``embedding_dimension + n_time_features +
        n_edge_features`` inputs to ``embedding_dimension``; ``linear_2[i]``
        maps ``embedding_dimension + n_node_features + n_time_features``
        inputs to ``embedding_dimension``.
        """
        super().__init__(node_features=node_features,
                         edge_features=edge_features,
                         memory=memory,
                         neighbor_finder=neighbor_finder,
                         time_encoder=time_encoder, n_layers=n_layers,
                         n_node_features=n_node_features,
                         n_edge_features=n_edge_features,
                         n_time_features=n_time_features,
                         embedding_dimension=embedding_dimension,
                         device=device,
                         n_heads=n_heads, dropout=dropout,
                         use_memory=use_memory)
        self.linear_1 = torch.nn.ModuleList([torch.nn.Linear(embedding_dimension + n_time_features +
                                                             n_edge_features, embedding_dimension)
                                             for _ in range(n_layers)])
        self.linear_2 = torch.nn.ModuleList(
            [torch.nn.Linear(embedding_dimension + n_node_features + n_time_features,
                             embedding_dimension) for _ in range(n_layers)])

    def aggregate(self, n_layer, source_node_features, source_nodes_time_embedding,
                  neighbor_embeddings,
                  edge_time_embeddings, edge_features, mask):
        """Return the layer-``n_layer`` embedding of the source nodes.

        Neighbour embeddings, edge time embeddings and edge features are
        concatenated along dim 2, projected by ``linear_1[n_layer - 1]``,
        summed over dim 1 and passed through a ReLU. That sum is concatenated
        with the source features and source time embedding and projected by
        ``linear_2[n_layer - 1]``.

        ``mask`` is accepted for interface compatibility but is not used, so
        every neighbour slot contributes to the sum.
        """
        neighbors_features = torch.cat([neighbor_embeddings, edge_time_embeddings, edge_features],
                                       dim=2)
        neighbor_embeddings = self.linear_1[n_layer - 1](neighbors_features)
        neighbors_sum = torch.nn.functional.relu(torch.sum(neighbor_embeddings, dim=1))

        # Squeeze only dim 1 (the time embedding is expected to be
        # (batch, 1, n_time_features)). A bare squeeze() would also drop the
        # batch dimension for a single-node batch and break the dim=1
        # concatenation below.
        source_features = torch.cat([source_node_features,
                                     source_nodes_time_embedding.squeeze(dim=1)], dim=1)
        source_embedding = torch.cat([neighbors_sum, source_features], dim=1)
        source_embedding = self.linear_2[n_layer - 1](source_embedding)

        return source_embedding
class GraphAttentionEmbedding(GraphEmbedding):
    """Graph embedding whose per-layer aggregation is a ``TemporalAttentionLayer``."""

    def __init__(self, node_features, edge_features, memory, neighbor_finder, time_encoder, n_layers,
                 n_node_features, n_edge_features, n_time_features, embedding_dimension, device,
                 n_heads=2, dropout=0.1, use_memory=True):
        """Initialise the base class and create one attention layer per aggregation layer."""
        super().__init__(node_features=node_features,
                         edge_features=edge_features,
                         memory=memory,
                         neighbor_finder=neighbor_finder,
                         time_encoder=time_encoder,
                         n_layers=n_layers,
                         n_node_features=n_node_features,
                         n_edge_features=n_edge_features,
                         n_time_features=n_time_features,
                         embedding_dimension=embedding_dimension,
                         device=device,
                         n_heads=n_heads,
                         dropout=dropout,
                         use_memory=use_memory)

        # Every layer is built with the same configuration; the output size
        # equals the node feature size.
        layer_config = dict(n_node_features=n_node_features,
                            n_neighbors_features=n_node_features,
                            n_edge_features=n_edge_features,
                            time_dim=n_time_features,
                            n_head=n_heads,
                            dropout=dropout,
                            output_dimension=n_node_features)
        per_layer_attention = []
        for _ in range(n_layers):
            per_layer_attention.append(TemporalAttentionLayer(**layer_config))
        self.attention_models = torch.nn.ModuleList(per_layer_attention)

    def aggregate(self, n_layer, source_node_features, source_nodes_time_embedding,
                  neighbor_embeddings,
                  edge_time_embeddings, edge_features, mask):
        """Run the attention layer for ``n_layer`` and return the first of its two outputs."""
        layer = self.attention_models[n_layer - 1]
        attended, _unused = layer(source_node_features,
                                  source_nodes_time_embedding,
                                  neighbor_embeddings,
                                  edge_time_embeddings,
                                  edge_features,
                                  mask)
        return attended
def get_embedding_module(module_type, node_features, edge_features, memory, neighbor_finder,
                         time_encoder, n_layers, n_node_features, n_edge_features, n_time_features,
                         embedding_dimension, device,
                         n_heads=2, dropout=0.1, n_neighbors=None,
                         use_memory=True):
    """Construct the embedding module selected by ``module_type``.

    Supported values are ``"graph_attention"``, ``"graph_sum"``,
    ``"identity"`` and ``"time"``. ``n_heads`` and ``use_memory`` are forwarded
    only to the two graph variants, and ``n_neighbors`` only to the ``"time"``
    variant.

    Raises:
        ValueError: if ``module_type`` is not one of the supported values.
    """
    # Keyword arguments every embedding class receives.
    shared = dict(node_features=node_features,
                  edge_features=edge_features,
                  memory=memory,
                  neighbor_finder=neighbor_finder,
                  time_encoder=time_encoder,
                  n_layers=n_layers,
                  n_node_features=n_node_features,
                  n_edge_features=n_edge_features,
                  n_time_features=n_time_features,
                  embedding_dimension=embedding_dimension,
                  device=device,
                  dropout=dropout)

    if module_type == "identity":
        return IdentityEmbedding(**shared)
    if module_type == "time":
        return TimeEmbedding(n_neighbors=n_neighbors, **shared)

    # The remaining supported types are the graph variants.
    if module_type == "graph_attention":
        graph_cls = GraphAttentionEmbedding
    elif module_type == "graph_sum":
        graph_cls = GraphSumEmbedding
    else:
        raise ValueError("Embedding Module {} not supported".format(module_type))

    return graph_cls(n_heads=n_heads, use_memory=use_memory, **shared)