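# NOTE(review): the three lines below look like hosting-page residue rather than
# part of the module; kept verbatim as comments so the file can be parsed.
# ashu316's picture
# Upload 14 files
# 41aae2b verified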
import logging
import numpy as np
import torch
from collections import defaultdict
from utils.utils import MergeLayer
from modules.memory import Memory
from modules.message_aggregator import get_message_aggregator
from modules.message_function import get_message_function
from modules.memory_updater import get_memory_updater
from modules.embedding_module import get_embedding_module
from model.time_encoding import TimeEncode
class TGN(torch.nn.Module):
    """Temporal graph model: optional node memory, an embedding module and an edge decoder.

    The constructor wires together the collaborators built by the project
    factories (`get_message_aggregator`, `get_message_function`,
    `get_memory_updater`, `get_embedding_module`) plus a `MergeLayer` that scores
    a pair of node embeddings. The memory-related collaborators only exist when
    `use_memory` is true.
    """

    def __init__(self, neighbor_finder, node_features, edge_features, device, n_layers=2,
                 n_heads=2, dropout=0.1, use_memory=False,
                 memory_update_at_start=True, message_dimension=100,
                 memory_dimension=500, embedding_module_type="graph_attention",
                 message_function="mlp",
                 mean_time_shift_src=0, std_time_shift_src=1, mean_time_shift_dst=0,
                 std_time_shift_dst=1, n_neighbors=None, aggregator_type="last",
                 memory_updater_type="gru",
                 use_destination_embedding_in_message=False,
                 use_source_embedding_in_message=False,
                 dyrep=False):
        """
        Build the model and its sub-modules.

        :param neighbor_finder: object handed to the embedding module for neighbor lookups
        :param node_features: numpy array of raw node features, one row per node
        :param edge_features: numpy array of raw edge features, one row per edge
        :param device: torch device on which the feature tensors are placed
        :param n_layers: number of layers passed to the embedding module
        :param n_heads: forwarded to the embedding module
        :param dropout: forwarded to the embedding module
        :param use_memory: whether to create and use the node memory
        :param memory_update_at_start: if true, stored messages are applied to the memory
          before embeddings are computed; otherwise the memory is updated with the
          messages of the current batch after embeddings are computed
        :param message_dimension: size of a processed message (replaced by the raw message
          size when `message_function` is "identity")
        :param memory_dimension: size of one memory row
        :param embedding_module_type: key given to `get_embedding_module`
        :param message_function: key given to `get_message_function`
        :param mean_time_shift_src: mean used to normalise source time differences
        :param std_time_shift_src: std used to normalise source time differences
        :param mean_time_shift_dst: mean used to normalise destination/negative time differences
        :param std_time_shift_dst: std used to normalise destination/negative time differences
        :param n_neighbors: forwarded to the embedding module
        :param aggregator_type: key given to `get_message_aggregator`
        :param memory_updater_type: key given to `get_memory_updater`
        :param use_destination_embedding_in_message: use the destination embedding instead of
          its memory row when building raw messages
        :param use_source_embedding_in_message: use the source embedding instead of its
          memory row when building raw messages
        :param dyrep: if true (and memory is used), return memory rows instead of the
          computed embeddings from `compute_temporal_embeddings`
        """
        super().__init__()

        self.n_layers = n_layers
        self.neighbor_finder = neighbor_finder
        self.device = device
        self.logger = logging.getLogger(__name__)

        self.node_raw_features = torch.from_numpy(node_features.astype(np.float32)).to(device)
        self.edge_raw_features = torch.from_numpy(edge_features.astype(np.float32)).to(device)

        self.n_node_features = self.node_raw_features.shape[1]
        self.n_nodes = self.node_raw_features.shape[0]
        self.n_edge_features = self.edge_raw_features.shape[1]
        # Embeddings have the same width as the raw node features.
        self.embedding_dimension = self.n_node_features
        self.n_neighbors = n_neighbors
        self.embedding_module_type = embedding_module_type
        self.use_destination_embedding_in_message = use_destination_embedding_in_message
        self.use_source_embedding_in_message = use_source_embedding_in_message
        self.dyrep = dyrep

        self.use_memory = use_memory
        self.time_encoder = TimeEncode(dimension=self.n_node_features)
        self.memory = None

        self.mean_time_shift_src = mean_time_shift_src
        self.std_time_shift_src = std_time_shift_src
        self.mean_time_shift_dst = mean_time_shift_dst
        self.std_time_shift_dst = std_time_shift_dst

        if self.use_memory:
            self.memory_dimension = memory_dimension
            self.memory_update_at_start = memory_update_at_start
            # A raw message is [source memory | destination memory | edge features | time encoding].
            raw_message_dimension = (2 * self.memory_dimension + self.n_edge_features
                                     + self.time_encoder.dimension)
            # The identity message function leaves raw messages untouched, so the
            # processed size must equal the raw size in that case.
            if message_function == "identity":
                message_dimension = raw_message_dimension
            self.memory = Memory(n_nodes=self.n_nodes,
                                 memory_dimension=self.memory_dimension,
                                 input_dimension=message_dimension,
                                 message_dimension=message_dimension,
                                 device=device)
            self.message_aggregator = get_message_aggregator(aggregator_type=aggregator_type,
                                                             device=device)
            self.message_function = get_message_function(
                module_type=message_function,
                raw_message_dimension=raw_message_dimension,
                message_dimension=message_dimension)
            self.memory_updater = get_memory_updater(module_type=memory_updater_type,
                                                     memory=self.memory,
                                                     message_dimension=message_dimension,
                                                     memory_dimension=self.memory_dimension,
                                                     device=device)

        self.embedding_module = get_embedding_module(module_type=embedding_module_type,
                                                     node_features=self.node_raw_features,
                                                     edge_features=self.edge_raw_features,
                                                     memory=self.memory,
                                                     neighbor_finder=self.neighbor_finder,
                                                     time_encoder=self.time_encoder,
                                                     n_layers=self.n_layers,
                                                     n_node_features=self.n_node_features,
                                                     n_edge_features=self.n_edge_features,
                                                     n_time_features=self.n_node_features,
                                                     embedding_dimension=self.embedding_dimension,
                                                     device=self.device,
                                                     n_heads=n_heads, dropout=dropout,
                                                     use_memory=use_memory,
                                                     n_neighbors=self.n_neighbors)

        # MLP to compute probability on an edge given two node embeddings
        self.affinity_score = MergeLayer(self.n_node_features, self.n_node_features,
                                         self.n_node_features,
                                         1)

    def compute_temporal_embeddings(self, source_nodes, destination_nodes, negative_nodes,
                                    edge_times, edge_idxs, n_neighbors=20):
        """
        Compute temporal embeddings for sources, destinations, and negatively sampled destinations.

        When memory is enabled this also has side effects on `self.memory`: stored
        messages are applied and/or new raw messages for the batch are stored.

        :param source_nodes [batch_size]: source ids
        :param destination_nodes [batch_size]: destination ids
        :param negative_nodes [batch_size]: ids of negative sampled destination
        :param edge_times [batch_size]: timestamp of interaction
        :param edge_idxs [batch_size]: index of interaction
        :param n_neighbors [scalar]: number of temporal neighbor to consider in each convolutional
          layer
        :return: Temporal embeddings for sources, destinations and negatives
        """
        n_samples = len(source_nodes)
        nodes = np.concatenate([source_nodes, destination_nodes, negative_nodes])
        positives = np.concatenate([source_nodes, destination_nodes])
        timestamps = np.concatenate([edge_times, edge_times, edge_times])

        memory = None
        time_diffs = None
        if self.use_memory:
            if self.memory_update_at_start:
                # Update memory for all nodes with messages stored in previous batches
                memory, last_update = self.get_updated_memory(list(range(self.n_nodes)),
                                                              self.memory.messages)
            else:
                memory = self.memory.get_memory(list(range(self.n_nodes)))
                last_update = self.memory.last_update

            # Differences between the time a node's memory was last updated and the
            # time at which its embedding is requested. The batch timestamps are
            # converted once and shared by the three node groups.
            batch_times = torch.LongTensor(edge_times).to(self.device)
            source_time_diffs = self._normalized_time_diffs(
                batch_times, last_update, source_nodes,
                self.mean_time_shift_src, self.std_time_shift_src)
            destination_time_diffs = self._normalized_time_diffs(
                batch_times, last_update, destination_nodes,
                self.mean_time_shift_dst, self.std_time_shift_dst)
            # Negatives play the role of destinations, so they use the destination statistics.
            negative_time_diffs = self._normalized_time_diffs(
                batch_times, last_update, negative_nodes,
                self.mean_time_shift_dst, self.std_time_shift_dst)
            time_diffs = torch.cat([source_time_diffs, destination_time_diffs,
                                    negative_time_diffs], dim=0)

        # Compute the embeddings using the embedding module
        node_embedding = self.embedding_module.compute_embedding(memory=memory,
                                                                 source_nodes=nodes,
                                                                 timestamps=timestamps,
                                                                 n_layers=self.n_layers,
                                                                 n_neighbors=n_neighbors,
                                                                 time_diffs=time_diffs)

        # `nodes` was laid out as [sources | destinations | negatives].
        source_node_embedding = node_embedding[:n_samples]
        destination_node_embedding = node_embedding[n_samples: 2 * n_samples]
        negative_node_embedding = node_embedding[2 * n_samples:]

        if self.use_memory:
            if self.memory_update_at_start:
                # Persist the updates to the memory only for sources and destinations
                # (since now we have new messages for them)
                self.update_memory(positives, self.memory.messages)

                assert torch.allclose(memory[positives], self.memory.get_memory(positives),
                                      atol=1e-5), \
                    "Something wrong in how the memory was updated"

                # Remove messages for the positives since we have already updated the
                # memory using them
                self.memory.clear_messages(positives)

            unique_sources, source_id_to_messages = self.get_raw_messages(
                source_nodes, source_node_embedding,
                destination_nodes, destination_node_embedding,
                edge_times, edge_idxs)
            unique_destinations, destination_id_to_messages = self.get_raw_messages(
                destination_nodes, destination_node_embedding,
                source_nodes, source_node_embedding,
                edge_times, edge_idxs)

            if self.memory_update_at_start:
                # Messages of this batch are only stored; they are applied on a later call.
                self.memory.store_raw_messages(unique_sources, source_id_to_messages)
                self.memory.store_raw_messages(unique_destinations, destination_id_to_messages)
            else:
                self.update_memory(unique_sources, source_id_to_messages)
                self.update_memory(unique_destinations, destination_id_to_messages)

            # `memory` is only populated in the memory branch, so the DyRep
            # substitution of memory rows for embeddings lives here.
            if self.dyrep:
                source_node_embedding = memory[source_nodes]
                destination_node_embedding = memory[destination_nodes]
                negative_node_embedding = memory[negative_nodes]

        return source_node_embedding, destination_node_embedding, negative_node_embedding

    def compute_edge_probabilities(self, source_nodes, destination_nodes, negative_nodes,
                                   edge_times, edge_idxs, n_neighbors=20):
        """
        Compute probabilities for edges between sources and destination and between sources and
        negatives by first computing temporal embeddings using the TGN encoder and then feeding
        them into the MLP decoder.

        :param source_nodes [batch_size]: source ids
        :param destination_nodes [batch_size]: destination ids
        :param negative_nodes [batch_size]: ids of negative sampled destination
        :param edge_times [batch_size]: timestamp of interaction
        :param edge_idxs [batch_size]: index of interaction
        :param n_neighbors [scalar]: number of temporal neighbor to consider in each convolutional
          layer
        :return: Probabilities for both the positive and negative edges
        """
        n_samples = len(source_nodes)
        source_node_embedding, destination_node_embedding, negative_node_embedding = \
            self.compute_temporal_embeddings(source_nodes, destination_nodes, negative_nodes,
                                             edge_times, edge_idxs, n_neighbors)

        # Score (source, destination) and (source, negative) pairs in a single decoder
        # call: the sources are repeated so both halves line up row by row.
        repeated_sources = torch.cat([source_node_embedding, source_node_embedding], dim=0)
        candidates = torch.cat([destination_node_embedding, negative_node_embedding])
        # squeeze(dim=0) is kept as in the original so the returned shape is unchanged.
        score = self.affinity_score(repeated_sources, candidates).squeeze(dim=0)

        pos_score = score[:n_samples]
        neg_score = score[n_samples:]
        return pos_score.sigmoid(), neg_score.sigmoid()

    def update_memory(self, nodes, messages):
        """
        Aggregate the messages of `nodes` and write the result into the memory.

        :param nodes: ids of the nodes whose messages should be applied
        :param messages: mapping from node id to its list of stored messages
        """
        unique_nodes, unique_messages, unique_timestamps = \
            self._aggregate_and_encode_messages(nodes, messages)

        # Update the memory with the aggregated messages
        self.memory_updater.update_memory(unique_nodes, unique_messages,
                                          timestamps=unique_timestamps)

    def get_updated_memory(self, nodes, messages):
        """
        Aggregate the messages of `nodes` and ask the memory updater for the resulting memory.

        :param nodes: ids of the nodes whose messages should be applied
        :param messages: mapping from node id to its list of stored messages
        :return: the pair (updated memory, updated last-update times) produced by
          `self.memory_updater.get_updated_memory`
        """
        unique_nodes, unique_messages, unique_timestamps = \
            self._aggregate_and_encode_messages(nodes, messages)

        updated_memory, updated_last_update = self.memory_updater.get_updated_memory(
            unique_nodes, unique_messages, timestamps=unique_timestamps)
        return updated_memory, updated_last_update

    def get_raw_messages(self, source_nodes, source_node_embedding, destination_nodes,
                         destination_node_embedding, edge_times, edge_idxs):
        """
        Build one raw message per interaction, grouped by source node.

        Each message is the concatenation of the source representation, the destination
        representation, the edge features and the encoding of the time elapsed since the
        source's last memory update.

        :param source_nodes [batch_size]: ids of the nodes receiving the messages
        :param source_node_embedding: embeddings used for the sources when
          `use_source_embedding_in_message` is set (otherwise their memory rows are used)
        :param destination_nodes [batch_size]: ids of the interaction partners
        :param destination_node_embedding: embeddings used for the destinations when
          `use_destination_embedding_in_message` is set (otherwise their memory rows are used)
        :param edge_times [batch_size]: numpy array of interaction timestamps
        :param edge_idxs [batch_size]: indices into the raw edge features
        :return: (unique source ids, dict mapping source id -> list of (message, timestamp))
        """
        edge_times = torch.from_numpy(edge_times).float().to(self.device)
        edge_features = self.edge_raw_features[edge_idxs]

        if self.use_source_embedding_in_message:
            source_memory = source_node_embedding
        else:
            source_memory = self.memory.get_memory(source_nodes)

        if self.use_destination_embedding_in_message:
            destination_memory = destination_node_embedding
        else:
            destination_memory = self.memory.get_memory(destination_nodes)

        source_time_delta = edge_times - self.memory.last_update[source_nodes]
        # Flatten whatever the encoder returns to one row per interaction.
        source_time_delta_encoding = self.time_encoder(
            source_time_delta.unsqueeze(dim=1)).view(len(source_nodes), -1)

        source_message = torch.cat([source_memory, destination_memory, edge_features,
                                    source_time_delta_encoding],
                                   dim=1)

        messages = defaultdict(list)
        unique_sources = np.unique(source_nodes)
        for i, node in enumerate(source_nodes):
            messages[node].append((source_message[i], edge_times[i]))

        return unique_sources, messages

    def set_neighbor_finder(self, neighbor_finder):
        """Replace the neighbor finder on both the model and its embedding module."""
        self.neighbor_finder = neighbor_finder
        self.embedding_module.neighbor_finder = neighbor_finder

    def _normalized_time_diffs(self, batch_times, last_update, nodes, mean, std):
        """Return (batch_times - last_update[nodes] - mean) / std for one group of nodes."""
        return (batch_times - last_update[nodes].long() - mean) / std

    def _aggregate_and_encode_messages(self, nodes, messages):
        """Aggregate messages per node and run them through the message function.

        The message function is skipped when the aggregator returns no nodes.
        """
        unique_nodes, unique_messages, unique_timestamps = \
            self.message_aggregator.aggregate(nodes, messages)

        if len(unique_nodes) > 0:
            unique_messages = self.message_function.compute_message(unique_messages)

        return unique_nodes, unique_messages, unique_timestamps