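# NOTE: the next three lines are repository page-header text captured with this file, not code:
#   ashu316's picture
#   Upload 14 files
#   41aae2b verified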
import torch
from torch import nn
from collections import defaultdict
from copy import deepcopy
class Memory(nn.Module):
    """Per-node memory store with a buffer of raw messages.

    Holds three pieces of state:

    * ``memory``: an ``(n_nodes, memory_dimension)`` tensor, one row per node.
    * ``last_update``: a length-``n_nodes`` vector, one entry per node.
    * ``messages``: a ``defaultdict(list)`` mapping a node id to the list of
      raw messages stored for it. Each message is handled as a pair whose
      two elements are tensors (both are cloned by ``backup_memory``).

    ``input_dimension``, ``message_dimension`` and ``combination_method`` are
    only recorded as attributes; nothing in this class reads them.
    """

    def __init__(self, n_nodes, memory_dimension, input_dimension, message_dimension=None,
                 device="cpu", combination_method='sum'):
        """Record the configuration and create zero-initialised state.

        Args:
            n_nodes: Number of rows in the memory / entries in ``last_update``.
            memory_dimension: Width of each node's memory row.
            input_dimension: Stored on the instance, unused here.
            message_dimension: Stored on the instance, unused here.
            device: Device on which the state tensors are allocated.
            combination_method: Stored on the instance, unused here.
        """
        super().__init__()
        self.n_nodes = n_nodes
        self.memory_dimension = memory_dimension
        self.input_dimension = input_dimension
        self.message_dimension = message_dimension
        self.device = device
        self.combination_method = combination_method

        self.__init_memory__()

    def __init_memory__(self):
        """
        Initializes the memory to all zeros. It should be called at the start of each epoch.

        Also resets ``last_update`` to zeros and drops every stored message.
        """
        # Treat memory as parameter so that it is saved and loaded together with the model.
        # Allocate directly on the target device instead of creating on CPU and copying.
        self.memory = nn.Parameter(
            torch.zeros((self.n_nodes, self.memory_dimension), device=self.device),
            requires_grad=False)
        self.last_update = nn.Parameter(
            torch.zeros(self.n_nodes, device=self.device),
            requires_grad=False)

        self.messages = defaultdict(list)

    @staticmethod
    def _clone_messages(messages):
        """Return a plain dict copy of ``messages`` with both tensors of each pair cloned."""
        return {
            node: [(message[0].clone(), message[1].clone()) for message in node_messages]
            for node, node_messages in messages.items()
        }

    def store_raw_messages(self, nodes, node_id_to_messages):
        """Append ``node_id_to_messages[node]`` to the stored messages of each node in ``nodes``."""
        for node in nodes:
            self.messages[node].extend(node_id_to_messages[node])

    def get_memory(self, node_idxs):
        """Return the memory rows selected by ``node_idxs``."""
        return self.memory[node_idxs, :]

    def set_memory(self, node_idxs, values):
        """Overwrite, in place, the memory rows selected by ``node_idxs`` with ``values``."""
        self.memory[node_idxs, :] = values

    def get_last_update(self, node_idxs):
        """Return the ``last_update`` entries selected by ``node_idxs``."""
        return self.last_update[node_idxs]

    def backup_memory(self):
        """Snapshot the current state.

        Returns:
            A tuple ``(memory, last_update, messages)`` where the first two are
            cloned tensors and the third is a plain dict of cloned messages.
            The tuple is the format expected by ``restore_memory``.
        """
        return (self.memory.data.clone(),
                self.last_update.data.clone(),
                self._clone_messages(self.messages))

    def restore_memory(self, memory_backup):
        """Restore state from a tuple produced by ``backup_memory``.

        Everything is cloned again, so the backup stays usable afterwards.
        """
        self.memory.data = memory_backup[0].clone()
        self.last_update.data = memory_backup[1].clone()

        # Rebuild as a defaultdict so unseen nodes still yield an empty list.
        self.messages = defaultdict(list, self._clone_messages(memory_backup[2]))

    def detach_memory(self):
        """Detach the memory and the first element of every stored message from the autograd graph."""
        self.memory.detach_()

        # Detach all stored messages. Only the first element of each pair is
        # detached; the second is carried over unchanged. Re-assigning values
        # of existing keys while iterating does not change the dict's size.
        for node, node_messages in self.messages.items():
            self.messages[node] = [(message[0].detach(), message[1])
                                   for message in node_messages]

    def clear_messages(self, nodes):
        """Replace the stored messages of each node in ``nodes`` with an empty list."""
        for node in nodes:
            self.messages[node] = []