import torch
from torch import nn
import numpy as np
import math

from model.temporal_attention import TemporalAttentionLayer

class EmbeddingModule(nn.Module):
  def __init__(self, node_features, edge_features, memory, neighbor_finder, time_encoder, n_layers,
               n_node_features, n_edge_features, n_time_features, embedding_dimension, device,
               dropout):
    super(EmbeddingModule, self).__init__()
    self.node_features = node_features
    self.edge_features = edge_features
    # self.memory = memory
    self.neighbor_finder = neighbor_finder
    self.time_encoder = time_encoder
    self.n_layers = n_layers
    self.n_node_features = n_node_features
    self.n_edge_features = n_edge_features
    self.n_time_features = n_time_features
    self.dropout = dropout
    self.embedding_dimension = embedding_dimension
    self.device = device
  def compute_embedding(self, memory, source_nodes, timestamps, n_layers, n_neighbors=20, time_diffs=None,
                        use_time_proj=True):
    raise NotImplementedError("compute_embedding must be implemented by subclasses")

class IdentityEmbedding(EmbeddingModule):
  def compute_embedding(self, memory, source_nodes, timestamps, n_layers, n_neighbors=20, time_diffs=None,
                        use_time_proj=True):
    return memory[source_nodes, :]

class TimeEmbedding(EmbeddingModule):
  def __init__(self, node_features, edge_features, memory, neighbor_finder, time_encoder, n_layers,
               n_node_features, n_edge_features, n_time_features, embedding_dimension, device,
               n_heads=2, dropout=0.1, use_memory=True, n_neighbors=1):
    super(TimeEmbedding, self).__init__(node_features, edge_features, memory,
                                        neighbor_finder, time_encoder, n_layers,
                                        n_node_features, n_edge_features, n_time_features,
                                        embedding_dimension, device, dropout)

    class NormalLinear(nn.Linear):
      # From Jodie code
      def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.normal_(0, stdv)
        if self.bias is not None:
          self.bias.data.normal_(0, stdv)

    self.embedding_layer = NormalLinear(1, self.n_node_features)

  def compute_embedding(self, memory, source_nodes, timestamps, n_layers, n_neighbors=20, time_diffs=None,
                        use_time_proj=True):
    # JODIE-style time projection: scale the memory state by (1 + w * delta_t),
    # where delta_t is the (normalized) time since the node's last update.
    source_embeddings = memory[source_nodes, :] * (1 + self.embedding_layer(time_diffs.unsqueeze(1)))

    return source_embeddings

class GraphEmbedding(EmbeddingModule):
  def __init__(self, node_features, edge_features, memory, neighbor_finder, time_encoder, n_layers,
               n_node_features, n_edge_features, n_time_features, embedding_dimension, device,
               n_heads=2, dropout=0.1, use_memory=True):
    super(GraphEmbedding, self).__init__(node_features, edge_features, memory,
                                         neighbor_finder, time_encoder, n_layers,
                                         n_node_features, n_edge_features, n_time_features,
                                         embedding_dimension, device, dropout)

    self.use_memory = use_memory
    self.device = device
  def compute_embedding(self, memory, source_nodes, timestamps, n_layers, n_neighbors=20, time_diffs=None,
                        use_time_proj=True):
    """Recursive implementation of n_layers temporal graph attention layers.

    source_nodes [batch_size]: user / item input ids.
    timestamps [batch_size]: instant at which to extract each user / item representation.
    n_layers [scalar]: number of temporal convolutional layers to stack.
    n_neighbors [scalar]: number of temporal neighbors to consider in each convolutional layer.
    """
    assert (n_layers >= 0)

    source_nodes_torch = torch.from_numpy(source_nodes).long().to(self.device)
    timestamps_torch = torch.unsqueeze(torch.from_numpy(timestamps).float().to(self.device), dim=1)

    # query node always has the start time -> time span == 0
    source_nodes_time_embedding = self.time_encoder(torch.zeros_like(timestamps_torch))

    source_node_features = self.node_features[source_nodes_torch, :]

    if self.use_memory:
      source_node_features = memory[source_nodes, :] + source_node_features

    if n_layers == 0:
      return source_node_features
    else:
      source_node_conv_embeddings = self.compute_embedding(memory,
                                                           source_nodes,
                                                           timestamps,
                                                           n_layers=n_layers - 1,
                                                           n_neighbors=n_neighbors)

      neighbors, edge_idxs, edge_times = self.neighbor_finder.get_temporal_neighbor(
        source_nodes,
        timestamps,
        n_neighbors=n_neighbors)

      neighbors_torch = torch.from_numpy(neighbors).long().to(self.device)
      edge_idxs = torch.from_numpy(edge_idxs).long().to(self.device)

      edge_deltas = timestamps[:, np.newaxis] - edge_times
      edge_deltas_torch = torch.from_numpy(edge_deltas).float().to(self.device)

      neighbors = neighbors.flatten()
      neighbor_embeddings = self.compute_embedding(memory,
                                                   neighbors,
                                                   np.repeat(timestamps, n_neighbors),
                                                   n_layers=n_layers - 1,
                                                   n_neighbors=n_neighbors)

      effective_n_neighbors = n_neighbors if n_neighbors > 0 else 1
      neighbor_embeddings = neighbor_embeddings.view(len(source_nodes), effective_n_neighbors, -1)
      edge_time_embeddings = self.time_encoder(edge_deltas_torch)

      edge_features = self.edge_features[edge_idxs, :]

      mask = neighbors_torch == 0

      source_embedding = self.aggregate(n_layers, source_node_conv_embeddings,
                                        source_nodes_time_embedding,
                                        neighbor_embeddings,
                                        edge_time_embeddings,
                                        edge_features,
                                        mask)

      return source_embedding
  def aggregate(self, n_layers, source_node_features, source_nodes_time_embedding,
                neighbor_embeddings,
                edge_time_embeddings, edge_features, mask):
    raise NotImplementedError("aggregate must be implemented by subclasses")

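# Note on the aggregate() interface (added here as a reading aid; the shapes are
# inferred from the reshaping done in compute_embedding above, not stated by the
# original authors):
#   source_node_features          [batch, n_node_features]
#   source_nodes_time_embedding   [batch, 1, n_time_features]   (time span 0 for the query node)
#   neighbor_embeddings           [batch, n_neighbors, emb_dim]
#   edge_time_embeddings          [batch, n_neighbors, n_time_features]
#   edge_features                 [batch, n_neighbors, n_edge_features]
#   mask                          [batch, n_neighbors], True where the neighbor id is 0 (padding)
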
class GraphSumEmbedding(GraphEmbedding):
  def __init__(self, node_features, edge_features, memory, neighbor_finder, time_encoder, n_layers,
               n_node_features, n_edge_features, n_time_features, embedding_dimension, device,
               n_heads=2, dropout=0.1, use_memory=True):
    super(GraphSumEmbedding, self).__init__(node_features=node_features,
                                            edge_features=edge_features,
                                            memory=memory,
                                            neighbor_finder=neighbor_finder,
                                            time_encoder=time_encoder, n_layers=n_layers,
                                            n_node_features=n_node_features,
                                            n_edge_features=n_edge_features,
                                            n_time_features=n_time_features,
                                            embedding_dimension=embedding_dimension,
                                            device=device,
                                            n_heads=n_heads, dropout=dropout,
                                            use_memory=use_memory)
    self.linear_1 = torch.nn.ModuleList([torch.nn.Linear(embedding_dimension + n_time_features +
                                                          n_edge_features, embedding_dimension)
                                         for _ in range(n_layers)])
    self.linear_2 = torch.nn.ModuleList(
      [torch.nn.Linear(embedding_dimension + n_node_features + n_time_features,
                       embedding_dimension) for _ in range(n_layers)])
  def aggregate(self, n_layer, source_node_features, source_nodes_time_embedding,
                neighbor_embeddings,
                edge_time_embeddings, edge_features, mask):
    neighbors_features = torch.cat([neighbor_embeddings, edge_time_embeddings, edge_features],
                                   dim=2)
    neighbor_embeddings = self.linear_1[n_layer - 1](neighbors_features)
    neighbors_sum = torch.nn.functional.relu(torch.sum(neighbor_embeddings, dim=1))

    # squeeze only the singleton time dimension so a batch of size 1 is not flattened
    source_features = torch.cat([source_node_features,
                                 source_nodes_time_embedding.squeeze(dim=1)], dim=1)
    source_embedding = torch.cat([neighbors_sum, source_features], dim=1)
    source_embedding = self.linear_2[n_layer - 1](source_embedding)

    return source_embedding

class GraphAttentionEmbedding(GraphEmbedding):
  def __init__(self, node_features, edge_features, memory, neighbor_finder, time_encoder, n_layers,
               n_node_features, n_edge_features, n_time_features, embedding_dimension, device,
               n_heads=2, dropout=0.1, use_memory=True):
    super(GraphAttentionEmbedding, self).__init__(node_features, edge_features, memory,
                                                  neighbor_finder, time_encoder, n_layers,
                                                  n_node_features, n_edge_features,
                                                  n_time_features,
                                                  embedding_dimension, device,
                                                  n_heads, dropout,
                                                  use_memory)

    self.attention_models = torch.nn.ModuleList([TemporalAttentionLayer(
      n_node_features=n_node_features,
      n_neighbors_features=n_node_features,
      n_edge_features=n_edge_features,
      time_dim=n_time_features,
      n_head=n_heads,
      dropout=dropout,
      output_dimension=n_node_features)
      for _ in range(n_layers)])

  def aggregate(self, n_layer, source_node_features, source_nodes_time_embedding,
                neighbor_embeddings,
                edge_time_embeddings, edge_features, mask):
    attention_model = self.attention_models[n_layer - 1]

    source_embedding, _ = attention_model(source_node_features,
                                          source_nodes_time_embedding,
                                          neighbor_embeddings,
                                          edge_time_embeddings,
                                          edge_features,
                                          mask)

    return source_embedding

def get_embedding_module(module_type, node_features, edge_features, memory, neighbor_finder,
                         time_encoder, n_layers, n_node_features, n_edge_features, n_time_features,
                         embedding_dimension, device,
                         n_heads=2, dropout=0.1, n_neighbors=None,
                         use_memory=True):
  if module_type == "graph_attention":
    return GraphAttentionEmbedding(node_features=node_features,
                                   edge_features=edge_features,
                                   memory=memory,
                                   neighbor_finder=neighbor_finder,
                                   time_encoder=time_encoder,
                                   n_layers=n_layers,
                                   n_node_features=n_node_features,
                                   n_edge_features=n_edge_features,
                                   n_time_features=n_time_features,
                                   embedding_dimension=embedding_dimension,
                                   device=device,
                                   n_heads=n_heads, dropout=dropout, use_memory=use_memory)
  elif module_type == "graph_sum":
    return GraphSumEmbedding(node_features=node_features,
                             edge_features=edge_features,
                             memory=memory,
                             neighbor_finder=neighbor_finder,
                             time_encoder=time_encoder,
                             n_layers=n_layers,
                             n_node_features=n_node_features,
                             n_edge_features=n_edge_features,
                             n_time_features=n_time_features,
                             embedding_dimension=embedding_dimension,
                             device=device,
                             n_heads=n_heads, dropout=dropout, use_memory=use_memory)
  elif module_type == "identity":
    return IdentityEmbedding(node_features=node_features,
                             edge_features=edge_features,
                             memory=memory,
                             neighbor_finder=neighbor_finder,
                             time_encoder=time_encoder,
                             n_layers=n_layers,
                             n_node_features=n_node_features,
                             n_edge_features=n_edge_features,
                             n_time_features=n_time_features,
                             embedding_dimension=embedding_dimension,
                             device=device,
                             dropout=dropout)
  elif module_type == "time":
    return TimeEmbedding(node_features=node_features,
                         edge_features=edge_features,
                         memory=memory,
                         neighbor_finder=neighbor_finder,
                         time_encoder=time_encoder,
                         n_layers=n_layers,
                         n_node_features=n_node_features,
                         n_edge_features=n_edge_features,
                         n_time_features=n_time_features,
                         embedding_dimension=embedding_dimension,
                         device=device,
                         dropout=dropout,
                         n_neighbors=n_neighbors)
  else:
    raise ValueError("Embedding Module {} not supported".format(module_type))
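
if __name__ == "__main__":
  # Minimal usage sketch (not part of the original module). It exercises only the
  # "identity" embedding, which simply indexes into the memory tensor, so the
  # neighbor finder and time encoder can be left as None here. All tensor sizes
  # below are made up for illustration.
  n_nodes, feat_dim = 10, 4
  dummy_features = torch.zeros((n_nodes, feat_dim))
  memory = torch.randn(n_nodes, feat_dim)

  module = get_embedding_module("identity",
                                node_features=dummy_features,
                                edge_features=dummy_features,
                                memory=memory,
                                neighbor_finder=None,
                                time_encoder=None,
                                n_layers=1,
                                n_node_features=feat_dim,
                                n_edge_features=feat_dim,
                                n_time_features=feat_dim,
                                embedding_dimension=feat_dim,
                                device="cpu",
                                dropout=0.1)

  source_nodes = np.array([0, 3, 7])
  embeddings = module.compute_embedding(memory, source_nodes, timestamps=np.zeros(3), n_layers=1)
  print(embeddings.shape)  # expected: torch.Size([3, 4])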