from torch import nn
import torch


class MemoryUpdater(nn.Module):
    """Interface for modules that write new states into a node-memory object."""

    def update_memory(self, unique_node_ids, unique_messages, timestamps):
        """Update the stored memory of ``unique_node_ids``.

        The base implementation does nothing; subclasses override it.
        """


class SequenceMemoryUpdater(MemoryUpdater):
    """Memory updater that runs per-node messages through a recurrent cell.

    Subclasses must set ``self.memory_updater`` to a module called as
    ``memory_updater(messages, previous_memory)`` that returns the new memory
    rows (``nn.GRUCell`` / ``nn.RNNCell`` in this file).

    The wrapped ``memory`` object must provide ``get_memory(ids)``,
    ``set_memory(ids, values)``, ``get_last_update(ids)`` and the tensor
    attributes ``memory`` and ``last_update``, both indexable by node id.
    """

    def __init__(self, memory, message_dimension, memory_dimension, device):
        super().__init__()
        self.memory = memory
        # Not referenced in this file. It is a registered submodule with
        # parameters, so it is kept to preserve existing state_dict keys.
        self.layer_norm = torch.nn.LayerNorm(memory_dimension)
        self.message_dimension = message_dimension
        # Stored for callers; this class does not use it itself.
        self.device = device

    def _check_not_in_past(self, unique_node_ids, timestamps):
        """Assert that no node in ``unique_node_ids`` would move back in time."""
        last_update = self.memory.get_last_update(unique_node_ids)
        assert (last_update <= timestamps).all().item(), \
            "Trying to update memory to time in the past"

    def update_memory(self, unique_node_ids, unique_messages, timestamps):
        """Apply ``memory_updater`` to the given nodes and store the result.

        Args:
            unique_node_ids: ids of the nodes to update; nothing happens if empty.
            unique_messages: one message per node, fed to ``memory_updater`` as
                its input (presumably ``(len(unique_node_ids),
                message_dimension)`` — confirm against callers).
            timestamps: new last-update time for each node.

        Raises:
            AssertionError: if any timestamp is earlier than the node's
                recorded last update.
        """
        if len(unique_node_ids) <= 0:
            return

        self._check_not_in_past(unique_node_ids, timestamps)

        memory = self.memory.get_memory(unique_node_ids)
        updated_memory = self.memory_updater(unique_messages, memory)

        # Write timestamps and memory together, after the cell has succeeded,
        # so a failing memory_updater cannot leave last_update advanced while
        # the memory rows are stale.
        with torch.no_grad():
            self.memory.last_update[unique_node_ids] = timestamps
            self.memory.set_memory(unique_node_ids, updated_memory)

    def get_updated_memory(self, unique_node_ids, unique_messages, timestamps):
        """Return updated copies of the full memory without modifying it.

        Returns:
            ``(updated_memory, updated_last_update)``: clones of
            ``memory.memory`` and ``memory.last_update`` with the rows for
            ``unique_node_ids`` replaced. The replaced memory rows are computed
            outside ``torch.no_grad()``, so gradients can flow back into
            ``memory_updater`` through them. With empty ``unique_node_ids``,
            plain clones are returned.

        Raises:
            AssertionError: if any timestamp is earlier than the node's
                recorded last update.
        """
        if len(unique_node_ids) <= 0:
            return self.memory.memory.data.clone(), self.memory.last_update.data.clone()

        self._check_not_in_past(unique_node_ids, timestamps)

        updated_memory = self.memory.memory.data.clone()
        updated_memory[unique_node_ids] = self.memory_updater(
            unique_messages, updated_memory[unique_node_ids])

        updated_last_update = self.memory.last_update.data.clone()
        updated_last_update[unique_node_ids] = timestamps

        return updated_memory, updated_last_update


class GRUMemoryUpdater(SequenceMemoryUpdater):
    """Sequence memory updater backed by ``nn.GRUCell``."""

    def __init__(self, memory, message_dimension, memory_dimension, device):
        super().__init__(memory, message_dimension, memory_dimension, device)
        self.memory_updater = nn.GRUCell(input_size=message_dimension,
                                         hidden_size=memory_dimension)


class RNNMemoryUpdater(SequenceMemoryUpdater):
    """Sequence memory updater backed by ``nn.RNNCell``."""

    def __init__(self, memory, message_dimension, memory_dimension, device):
        super().__init__(memory, message_dimension, memory_dimension, device)
        self.memory_updater = nn.RNNCell(input_size=message_dimension,
                                         hidden_size=memory_dimension)


# Maps the ``module_type`` strings accepted by get_memory_updater to classes.
_MEMORY_UPDATER_TYPES = {
    "gru": GRUMemoryUpdater,
    "rnn": RNNMemoryUpdater,
}


def get_memory_updater(module_type, memory, message_dimension, memory_dimension, device):
    """Build the memory updater selected by ``module_type`` ("gru" or "rnn").

    Raises:
        ValueError: if ``module_type`` is not a known updater type. Previously
            this returned ``None``, which only failed later and far from the cause.
    """
    try:
        updater_cls = _MEMORY_UPDATER_TYPES[module_type]
    except KeyError:
        raise ValueError(
            f"Unknown memory updater type {module_type!r}; "
            f"expected one of {sorted(_MEMORY_UPDATER_TYPES)}") from None
    return updater_cls(memory, message_dimension, memory_dimension, device)