Spaces:
Sleeping
Sleeping
import torch | |
from torch import nn | |
from collections import defaultdict | |
from copy import deepcopy | |
class Memory(nn.Module):
    """Per-node memory matrix plus a queue of pending raw messages per node.

    The memory matrix and the last-update vector are held as non-trainable
    ``nn.Parameter`` objects (``requires_grad=False``) so they are saved and
    loaded together with the model's ``state_dict``. Raw messages live outside
    the parameters in a ``defaultdict(list)`` keyed by node.

    Each stored message is a 2-tuple whose elements are both tensors (they are
    cloned with ``.clone()`` when backing up). Presumably element 0 is the
    message content and element 1 its timestamp; TODO confirm against callers.

    Args:
        n_nodes: Number of rows in the memory matrix and length of
            ``last_update``.
        memory_dimension: Width of each node's memory row.
        input_dimension: Stored on the instance; not used inside this class.
        message_dimension: Stored on the instance; not used inside this class.
        device: Device on which ``memory`` and ``last_update`` are allocated.
        combination_method: Stored on the instance; not used inside this
            class (presumably read by a message-aggregation component; TODO
            confirm).
    """

    def __init__(self, n_nodes, memory_dimension, input_dimension, message_dimension=None,
                 device="cpu", combination_method='sum'):
        super().__init__()
        self.n_nodes = n_nodes
        self.memory_dimension = memory_dimension
        self.input_dimension = input_dimension
        self.message_dimension = message_dimension
        self.device = device
        self.combination_method = combination_method

        self.__init_memory__()

    def __init_memory__(self):
        """
        Initializes the memory to all zeros. It should be called at the start of each epoch.

        Also resets ``last_update`` to zeros and discards every stored message.
        """
        # Treat memory as parameter so that it is saved and loaded together with the model.
        # Allocate directly on the target device instead of building on CPU and copying.
        self.memory = nn.Parameter(
            torch.zeros((self.n_nodes, self.memory_dimension), device=self.device),
            requires_grad=False)
        self.last_update = nn.Parameter(
            torch.zeros(self.n_nodes, device=self.device),
            requires_grad=False)
        self.messages = defaultdict(list)

    def store_raw_messages(self, nodes, node_id_to_messages):
        """Append each node's new messages to that node's pending queue.

        Args:
            nodes: Iterable of node keys to update.
            node_id_to_messages: Mapping indexed with every key in ``nodes``;
                each value is a list of messages to append. A missing key
                raises unless the mapping supplies a default.
        """
        for node in nodes:
            self.messages[node].extend(node_id_to_messages[node])

    def get_memory(self, node_idxs):
        """Return the memory rows selected by ``node_idxs``."""
        return self.memory[node_idxs, :]

    def set_memory(self, node_idxs, values):
        """Overwrite, in place, the memory rows selected by ``node_idxs`` with ``values``."""
        self.memory[node_idxs, :] = values

    def get_last_update(self, node_idxs):
        """Return the ``last_update`` entries selected by ``node_idxs``."""
        return self.last_update[node_idxs]

    @staticmethod
    def _clone_messages(messages):
        """Return a plain dict copy of ``messages`` with both tensors of every message cloned."""
        return {node: [(msg[0].clone(), msg[1].clone()) for msg in node_messages]
                for node, node_messages in messages.items()}

    def backup_memory(self):
        """Take a snapshot of the memory state.

        Returns:
            Tuple ``(memory, last_update, messages)``. The first two are clones
            of the underlying tensor data. ``messages`` is a plain ``dict``
            holding cloned message tensors. Later in-place changes to this
            module therefore do not alter the snapshot.
        """
        return (self.memory.data.clone(), self.last_update.data.clone(),
                self._clone_messages(self.messages))

    def restore_memory(self, memory_backup):
        """Restore the state captured by ``backup_memory``.

        Every tensor in the snapshot is cloned again, so the same snapshot can
        be restored more than once.
        """
        self.memory.data = memory_backup[0].clone()
        self.last_update.data = memory_backup[1].clone()
        self.messages = defaultdict(list, self._clone_messages(memory_backup[2]))

    def detach_memory(self):
        """Cut the autograd history of the memory and of the stored messages.

        Only element 0 of each stored message is detached; element 1 is kept
        as the same object.
        """
        self.memory.detach_()

        # Only existing keys are reassigned, so the dict's size is unchanged
        # while it is being iterated.
        for node, node_messages in self.messages.items():
            self.messages[node] = [(msg[0].detach(), msg[1]) for msg in node_messages]

    def clear_messages(self, nodes):
        """Empty the pending-message queue of every node in ``nodes``."""
        for node in nodes:
            self.messages[node] = []