Spaces:
Sleeping
Sleeping
File size: 14,606 Bytes
41aae2b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 |
import logging
import numpy as np
import torch
from collections import defaultdict
from utils.utils import MergeLayer
from modules.memory import Memory
from modules.message_aggregator import get_message_aggregator
from modules.message_function import get_message_function
from modules.memory_updater import get_memory_updater
from modules.embedding_module import get_embedding_module
from model.time_encoding import TimeEncode
class TGN(torch.nn.Module):
    """
    Temporal graph model made of an optional per-node memory, a temporal embedding module and an
    MLP decoder (`affinity_score`) that scores an edge given two node embeddings.

    When `use_memory` is True the model also owns a message aggregator, a message function and a
    memory updater, which together turn the raw messages produced by each batch of interactions
    into memory updates.
    """

    def __init__(self, neighbor_finder, node_features, edge_features, device, n_layers=2,
                 n_heads=2, dropout=0.1, use_memory=False,
                 memory_update_at_start=True, message_dimension=100,
                 memory_dimension=500, embedding_module_type="graph_attention",
                 message_function="mlp",
                 mean_time_shift_src=0, std_time_shift_src=1, mean_time_shift_dst=0,
                 std_time_shift_dst=1, n_neighbors=None, aggregator_type="last",
                 memory_updater_type="gru",
                 use_destination_embedding_in_message=False,
                 use_source_embedding_in_message=False,
                 dyrep=False):
        """
        Build the model and all of its sub-modules.

        :param neighbor_finder: object handed to the embedding module to look up neighbors
        :param node_features: numpy array with one row per node; converted to a float32 tensor
        :param edge_features: numpy array with one row per edge; converted to a float32 tensor
        :param device: torch device on which features and sub-modules are placed
        :param n_layers: number of layers forwarded to the embedding module
        :param n_heads: forwarded to the embedding module
        :param dropout: forwarded to the embedding module
        :param use_memory: whether to create and use the node memory and its helper modules
        :param memory_update_at_start: if True, memory is brought up to date from stored messages
        before embeddings are computed and the new messages are only stored; if False, memory is
        updated with the new messages right after embeddings are computed
        :param message_dimension: size of the processed messages; ignored (replaced by the raw
        message size) when `message_function` is "identity"
        :param memory_dimension: size of each node's memory vector
        :param embedding_module_type: key passed to `get_embedding_module`
        :param message_function: key passed to `get_message_function`
        :param mean_time_shift_src: subtracted from source time differences before scaling
        :param std_time_shift_src: divisor applied to shifted source time differences
        :param mean_time_shift_dst: subtracted from destination/negative time differences
        :param std_time_shift_dst: divisor applied to shifted destination/negative differences
        :param n_neighbors: stored on the model and forwarded to the embedding module
        :param aggregator_type: key passed to `get_message_aggregator`
        :param memory_updater_type: key passed to `get_memory_updater`
        :param use_destination_embedding_in_message: use the freshly computed destination
        embedding instead of the destination memory when building raw messages
        :param use_source_embedding_in_message: use the freshly computed source embedding instead
        of the source memory when building raw messages
        :param dyrep: if True (and memory is used), return the memory rows of the nodes instead of
        the embeddings computed by the embedding module
        """
        super().__init__()

        self.n_layers = n_layers
        self.neighbor_finder = neighbor_finder
        self.device = device
        self.logger = logging.getLogger(__name__)

        self.node_raw_features = torch.from_numpy(node_features.astype(np.float32)).to(device)
        self.edge_raw_features = torch.from_numpy(edge_features.astype(np.float32)).to(device)

        self.n_node_features = self.node_raw_features.shape[1]
        self.n_nodes = self.node_raw_features.shape[0]
        self.n_edge_features = self.edge_raw_features.shape[1]
        # Embeddings have the same width as the raw node features.
        self.embedding_dimension = self.n_node_features
        self.n_neighbors = n_neighbors
        self.embedding_module_type = embedding_module_type
        self.use_destination_embedding_in_message = use_destination_embedding_in_message
        self.use_source_embedding_in_message = use_source_embedding_in_message
        self.dyrep = dyrep

        self.use_memory = use_memory
        self.time_encoder = TimeEncode(dimension=self.n_node_features)
        self.memory = None

        self.mean_time_shift_src = mean_time_shift_src
        self.std_time_shift_src = std_time_shift_src
        self.mean_time_shift_dst = mean_time_shift_dst
        self.std_time_shift_dst = std_time_shift_dst

        if self.use_memory:
            self.memory_dimension = memory_dimension
            self.memory_update_at_start = memory_update_at_start
            # A raw message is [source memory | destination memory | edge features | time
            # encoding], see get_raw_messages.
            raw_message_dimension = 2 * self.memory_dimension + self.n_edge_features + \
                self.time_encoder.dimension
            # The identity message function passes raw messages through unchanged, so the
            # processed message size must equal the raw one.
            if message_function == "identity":
                message_dimension = raw_message_dimension
            self.memory = Memory(n_nodes=self.n_nodes,
                                 memory_dimension=self.memory_dimension,
                                 input_dimension=message_dimension,
                                 message_dimension=message_dimension,
                                 device=device)
            self.message_aggregator = get_message_aggregator(aggregator_type=aggregator_type,
                                                             device=device)
            self.message_function = get_message_function(
                module_type=message_function,
                raw_message_dimension=raw_message_dimension,
                message_dimension=message_dimension)
            self.memory_updater = get_memory_updater(module_type=memory_updater_type,
                                                     memory=self.memory,
                                                     message_dimension=message_dimension,
                                                     memory_dimension=self.memory_dimension,
                                                     device=device)

        self.embedding_module = get_embedding_module(module_type=embedding_module_type,
                                                     node_features=self.node_raw_features,
                                                     edge_features=self.edge_raw_features,
                                                     memory=self.memory,
                                                     neighbor_finder=self.neighbor_finder,
                                                     time_encoder=self.time_encoder,
                                                     n_layers=self.n_layers,
                                                     n_node_features=self.n_node_features,
                                                     n_edge_features=self.n_edge_features,
                                                     n_time_features=self.n_node_features,
                                                     embedding_dimension=self.embedding_dimension,
                                                     device=self.device,
                                                     n_heads=n_heads, dropout=dropout,
                                                     use_memory=use_memory,
                                                     n_neighbors=self.n_neighbors)

        # MLP to compute probability on an edge given two node embeddings
        self.affinity_score = MergeLayer(self.n_node_features, self.n_node_features,
                                         self.n_node_features,
                                         1)

    def compute_temporal_embeddings(self, source_nodes, destination_nodes, negative_nodes,
                                    edge_times, edge_idxs, n_neighbors=20):
        """
        Compute temporal embeddings for sources, destinations, and negatively sampled
        destinations. When memory is enabled this also stores/applies the messages generated by
        the batch, so calling it has side effects on `self.memory`.

        :param source_nodes [batch_size]: source ids
        :param destination_nodes [batch_size]: destination ids
        :param negative_nodes [batch_size]: ids of negative sampled destination
        :param edge_times [batch_size]: timestamp of interaction
        :param edge_idxs [batch_size]: index of interaction
        :param n_neighbors [scalar]: number of temporal neighbor to consider in each convolutional
        layer
        :return: Temporal embeddings for sources, destinations and negatives
        """
        n_samples = len(source_nodes)
        nodes = np.concatenate([source_nodes, destination_nodes, negative_nodes])
        positives = np.concatenate([source_nodes, destination_nodes])
        # Every node group (source, destination, negative) is embedded at the edge time.
        timestamps = np.concatenate([edge_times, edge_times, edge_times])

        memory = None
        time_diffs = None
        if self.use_memory:
            if self.memory_update_at_start:
                # Update memory for all nodes with messages stored in previous batches
                memory, last_update = self.get_updated_memory(list(range(self.n_nodes)),
                                                              self.memory.messages)
            else:
                memory = self.memory.get_memory(list(range(self.n_nodes)))
                last_update = self.memory.last_update

            # Compute differences between the time the memory of a node was last updated,
            # and the time for which we want to compute the embedding of a node. The edge-time
            # tensor is identical for the three groups, so it is built once.
            edge_times_long = torch.LongTensor(edge_times).to(self.device)

            source_time_diffs = edge_times_long - last_update[source_nodes].long()
            source_time_diffs = (source_time_diffs - self.mean_time_shift_src) / \
                self.std_time_shift_src

            destination_time_diffs = edge_times_long - last_update[destination_nodes].long()
            destination_time_diffs = (destination_time_diffs - self.mean_time_shift_dst) / \
                self.std_time_shift_dst

            # Negatives are candidate destinations, hence the destination statistics.
            negative_time_diffs = edge_times_long - last_update[negative_nodes].long()
            negative_time_diffs = (negative_time_diffs - self.mean_time_shift_dst) / \
                self.std_time_shift_dst

            time_diffs = torch.cat(
                [source_time_diffs, destination_time_diffs, negative_time_diffs], dim=0)

        # Compute the embeddings using the embedding module
        node_embedding = self.embedding_module.compute_embedding(memory=memory,
                                                                 source_nodes=nodes,
                                                                 timestamps=timestamps,
                                                                 n_layers=self.n_layers,
                                                                 n_neighbors=n_neighbors,
                                                                 time_diffs=time_diffs)

        # Undo the concatenation order used to build `nodes`.
        source_node_embedding = node_embedding[:n_samples]
        destination_node_embedding = node_embedding[n_samples: 2 * n_samples]
        negative_node_embedding = node_embedding[2 * n_samples:]

        if self.use_memory:
            if self.memory_update_at_start:
                # Persist the updates to the memory only for sources and destinations (since now
                # we have new messages for them)
                self.update_memory(positives, self.memory.messages)

                # Sanity check: the persisted memory must match the one computed above.
                assert torch.allclose(memory[positives], self.memory.get_memory(positives),
                                      atol=1e-5), \
                    "Something wrong in how the memory was updated"

                # Remove messages for the positives since we have already updated the memory
                # using them
                self.memory.clear_messages(positives)

            unique_sources, source_id_to_messages = self.get_raw_messages(
                source_nodes, source_node_embedding,
                destination_nodes, destination_node_embedding,
                edge_times, edge_idxs)
            unique_destinations, destination_id_to_messages = self.get_raw_messages(
                destination_nodes, destination_node_embedding,
                source_nodes, source_node_embedding,
                edge_times, edge_idxs)

            if self.memory_update_at_start:
                # Defer: the messages are consumed at the start of a later call.
                self.memory.store_raw_messages(unique_sources, source_id_to_messages)
                self.memory.store_raw_messages(unique_destinations, destination_id_to_messages)
            else:
                self.update_memory(unique_sources, source_id_to_messages)
                self.update_memory(unique_destinations, destination_id_to_messages)

            if self.dyrep:
                # Return the memory rows read before this batch's messages were applied.
                source_node_embedding = memory[source_nodes]
                destination_node_embedding = memory[destination_nodes]
                negative_node_embedding = memory[negative_nodes]

        return source_node_embedding, destination_node_embedding, negative_node_embedding

    def compute_edge_probabilities(self, source_nodes, destination_nodes, negative_nodes,
                                   edge_times, edge_idxs, n_neighbors=20):
        """
        Compute probabilities for edges between sources and destination and between sources and
        negatives by first computing temporal embeddings using the TGN encoder and then feeding
        them into the MLP decoder.

        :param source_nodes [batch_size]: source ids
        :param destination_nodes [batch_size]: destination ids
        :param negative_nodes [batch_size]: ids of negative sampled destination
        :param edge_times [batch_size]: timestamp of interaction
        :param edge_idxs [batch_size]: index of interaction
        :param n_neighbors [scalar]: number of temporal neighbor to consider in each convolutional
        layer
        :return: Probabilities for both the positive and negative edges
        """
        n_samples = len(source_nodes)
        source_node_embedding, destination_node_embedding, negative_node_embedding = \
            self.compute_temporal_embeddings(source_nodes, destination_nodes, negative_nodes,
                                             edge_times, edge_idxs, n_neighbors)

        # Score (source, destination) and (source, negative) pairs in a single decoder call: the
        # sources are repeated so both halves line up row by row.
        paired_sources = torch.cat([source_node_embedding, source_node_embedding], dim=0)
        candidates = torch.cat([destination_node_embedding, negative_node_embedding])
        # NOTE(review): if affinity_score returns (2 * batch_size, 1) this squeeze on dim 0 is a
        # no-op; kept unchanged so the returned shape stays what callers already receive.
        score = self.affinity_score(paired_sources, candidates).squeeze(dim=0)

        pos_score = score[:n_samples]
        neg_score = score[n_samples:]
        return pos_score.sigmoid(), neg_score.sigmoid()

    def _aggregate_messages(self, nodes, messages):
        """
        Aggregate the raw messages of `nodes` and run them through the message function.

        :param nodes: ids of the nodes whose messages should be aggregated
        :param messages: mapping from node id to its list of raw messages
        :return: (unique node ids, processed messages, timestamps) as returned by the aggregator,
        with the messages transformed by the message function when there is at least one node
        """
        # Aggregate messages for the same nodes
        unique_nodes, unique_messages, unique_timestamps = \
            self.message_aggregator.aggregate(nodes, messages)

        # Skip the message function when nothing was aggregated.
        if len(unique_nodes) > 0:
            unique_messages = self.message_function.compute_message(unique_messages)

        return unique_nodes, unique_messages, unique_timestamps

    def update_memory(self, nodes, messages):
        """
        Aggregate the messages of `nodes` and persist the resulting update into the memory.

        :param nodes: ids of the nodes to update
        :param messages: mapping from node id to its list of raw messages
        """
        unique_nodes, unique_messages, unique_timestamps = \
            self._aggregate_messages(nodes, messages)

        # Update the memory with the aggregated messages
        self.memory_updater.update_memory(unique_nodes, unique_messages,
                                          timestamps=unique_timestamps)

    def get_updated_memory(self, nodes, messages):
        """
        Aggregate the messages of `nodes` and return what the memory updater's
        `get_updated_memory` yields for them.

        :param nodes: ids of the nodes to consider
        :param messages: mapping from node id to its list of raw messages
        :return: tuple (updated memory, updated last-update timestamps)
        """
        unique_nodes, unique_messages, unique_timestamps = \
            self._aggregate_messages(nodes, messages)

        updated_memory, updated_last_update = self.memory_updater.get_updated_memory(
            unique_nodes, unique_messages, timestamps=unique_timestamps)

        return updated_memory, updated_last_update

    def get_raw_messages(self, source_nodes, source_node_embedding, destination_nodes,
                         destination_node_embedding, edge_times, edge_idxs):
        """
        Build one raw message per interaction, addressed to the interaction's source node.

        Each message is the concatenation of the source representation, the destination
        representation, the edge features and the encoding of the time elapsed since the source's
        memory was last updated. The representation is the node's memory unless the matching
        `use_*_embedding_in_message` flag is set, in which case the given embedding is used.

        :param source_nodes [batch_size]: ids of the nodes receiving the messages
        :param source_node_embedding: embeddings of `source_nodes`
        :param destination_nodes [batch_size]: ids of the other endpoint of each interaction
        :param destination_node_embedding: embeddings of `destination_nodes`
        :param edge_times [batch_size]: numpy array with the timestamp of each interaction
        :param edge_idxs [batch_size]: index of each interaction into the edge features
        :return: tuple (unique source ids, dict mapping node id to a list of
        (message, timestamp) pairs in batch order)
        """
        edge_times = torch.from_numpy(edge_times).float().to(self.device)
        edge_features = self.edge_raw_features[edge_idxs]

        if self.use_source_embedding_in_message:
            source_memory = source_node_embedding
        else:
            source_memory = self.memory.get_memory(source_nodes)

        if self.use_destination_embedding_in_message:
            destination_memory = destination_node_embedding
        else:
            destination_memory = self.memory.get_memory(destination_nodes)

        # Time elapsed since each source's memory was last updated.
        source_time_delta = edge_times - self.memory.last_update[source_nodes]
        source_time_delta_encoding = self.time_encoder(
            source_time_delta.unsqueeze(dim=1)).view(len(source_nodes), -1)

        source_message = torch.cat([source_memory, destination_memory, edge_features,
                                    source_time_delta_encoding],
                                   dim=1)

        # Group the messages by receiving node, preserving batch order within each node.
        messages = defaultdict(list)
        for node_id, message, timestamp in zip(source_nodes, source_message, edge_times):
            messages[node_id].append((message, timestamp))

        unique_sources = np.unique(source_nodes)
        return unique_sources, messages

    def set_neighbor_finder(self, neighbor_finder):
        """
        Replace the neighbor finder on both the model and its embedding module.

        :param neighbor_finder: the new neighbor finder object
        """
        self.neighbor_finder = neighbor_finder
        self.embedding_module.neighbor_finder = neighbor_finder
|