# SocialAISchool / models / multiheadedac.py
# Last commit be5548b (avatar: grg): "Cleaned old git history"
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
import torch_ac
from utils.other import init_params
class MultiHeadedACModel(nn.Module, torch_ac.RecurrentACModel):
    """Actor-critic model with one primitive-action head and one head per utterance dimension.

    The model embeds the image observation with a small convolutional stack,
    optionally passes it through an LSTM cell (``use_memory``) and optionally
    concatenates a GRU embedding of the dialogue history (``use_dialogue``).
    The resulting embedding feeds:

    * ``actor``  -- logits over the primitive actions plus one extra "talk" action,
    * ``talker`` -- one head per entry of ``action_space.nvec[1:]``,
    * ``critic`` -- a scalar state value.
    """

    def __init__(self, obs_space, action_space, use_memory=False, use_text=False, use_dialogue=False):
        """Build the network.

        Args:
            obs_space: Mapping describing the observations. ``obs_space["image"][0]``
                and ``obs_space["image"][1]`` are used as the two spatial sizes of
                the image; ``obs_space["text"]`` is used as the vocabulary size of
                the word embedding when dialogue is enabled.
            action_space: Multi-dimensional discrete space exposing ``shape`` and
                ``nvec``. ``nvec[0]`` is the number of primitive actions and
                ``nvec[1:]`` the sizes of the utterance dimensions.
            use_memory: If true, add an LSTM cell on top of the image embedding.
            use_text: Must be false; kept only for interface compatibility.
            use_dialogue: If true, embed ``obs.utterance_history`` with a GRU.

        Raises:
            ValueError: If ``use_text`` is true, or if ``action_space.shape`` is ``()``.
        """
        super().__init__()

        # Store config. Copy the mapping so the stored snapshot holds exactly the
        # constructor arguments (plus 'self' and '__class__') and cannot pick up
        # locals defined later in this method.
        self.config = dict(locals())

        # Decide which components are enabled
        self.use_text = use_text
        self.use_dialogue = use_dialogue
        self.use_memory = use_memory

        if self.use_text:
            raise ValueError("You should not use text but dialogue. --text is cheating.")

        # multi dim
        if action_space.shape == ():
            raise ValueError("The action space is not multi modal. Use ACModel instead.")

        # The model adds one extra primitive action, "talk", placed last.
        self.n_primitive_actions = action_space.nvec[0] + 1  # for talk
        self.talk_action = int(self.n_primitive_actions) - 1
        self.n_utterance_actions = action_space.nvec[1:]

        self.env_action_space = action_space
        # Build the model-side space with the same class as the environment's
        # space: the previous code referenced a `spaces` module that this file
        # never imports, which raised NameError here.
        # NOTE(review): assumes type(action_space) accepts a list of sizes, as a
        # MultiDiscrete space does -- confirm if custom space classes are used.
        self.model_raw_action_space = type(action_space)(
            [self.n_primitive_actions, *self.n_utterance_actions]
        )

        # Define image embedding
        self.image_conv = nn.Sequential(
            nn.Conv2d(3, 16, (2, 2)),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),
            nn.Conv2d(16, 32, (2, 2)),
            nn.ReLU(),
            nn.Conv2d(32, 64, (2, 2)),
            nn.ReLU()
        )
        height = obs_space["image"][0]
        width = obs_space["image"][1]
        # Spatial size after conv(2x2) -> maxpool(2x2) -> conv(2x2) -> conv(2x2),
        # times the 64 output channels of the last convolution.
        self.image_embedding_size = ((height - 1) // 2 - 2) * ((width - 1) // 2 - 2) * 64

        # Define memory
        if self.use_memory:
            self.memory_rnn = nn.LSTMCell(self.image_embedding_size, self.semi_memory_size)

        if self.use_text or self.use_dialogue:
            self.word_embedding_size = 32
            self.word_embedding = nn.Embedding(obs_space["text"], self.word_embedding_size)

        # Define text embedding
        if self.use_text:
            self.text_embedding_size = 128
            self.text_rnn = nn.GRU(self.word_embedding_size, self.text_embedding_size, batch_first=True)

        # Define dialogue embedding
        if self.use_dialogue:
            self.dialogue_embedding_size = 128
            self.dialogue_rnn = nn.GRU(self.word_embedding_size, self.dialogue_embedding_size, batch_first=True)

        # Total embedding size: image (or memory) embedding plus enabled extras
        self.embedding_size = self.semi_memory_size
        if self.use_text:
            self.embedding_size += self.text_embedding_size
        if self.use_dialogue:
            self.embedding_size += self.dialogue_embedding_size

        # Define actor's model
        self.actor = nn.Sequential(
            nn.Linear(self.embedding_size, 64),
            nn.Tanh(),
            nn.Linear(64, self.n_primitive_actions)
        )
        # One head per utterance dimension
        self.talker = nn.ModuleList([
            nn.Sequential(
                nn.Linear(self.embedding_size, 64),
                nn.Tanh(),
                nn.Linear(64, n_choices)
            ) for n_choices in self.n_utterance_actions])

        # Define critic's model
        self.critic = nn.Sequential(
            nn.Linear(self.embedding_size, 64),
            nn.Tanh(),
            nn.Linear(64, 1)
        )

        # Initialize parameters correctly
        self.apply(init_params)

    @property
    def memory_size(self):
        """Width of the memory tensor: the two LSTM states concatenated."""
        return 2*self.semi_memory_size

    @property
    def semi_memory_size(self):
        """Width of each of the two LSTM states (equal to the image embedding size)."""
        return self.image_embedding_size

    def forward(self, obs, memory):
        """Compute action distributions, state value and updated memory.

        Args:
            obs: Observation batch with an ``image`` tensor and, when dialogue is
                enabled, an ``utterance_history`` tensor of word indices.
                Assumes ``obs.image`` is laid out (batch, height, width, 3) --
                TODO confirm against the environment wrapper.
            memory: Tensor whose columns hold the two LSTM states side by side,
                each ``semi_memory_size`` wide. Only read when ``use_memory`` is set.

        Returns:
            Tuple ``(dist, value, memory)`` where ``dist`` is a list of
            ``Categorical`` distributions (primitive head first, then one per
            utterance head), ``value`` is the critic output with its last
            dimension squeezed, and ``memory`` is the updated memory (the input
            memory unchanged when ``use_memory`` is false).

        Raises:
            ValueError: If dialogue is enabled but ``obs`` has no
                ``utterance_history`` attribute.
        """
        # Move the channel axis (last) to position 1, as Conv2d expects.
        x = obs.image.transpose(1, 3).transpose(2, 3)
        x = self.image_conv(x)
        x = x.reshape(x.shape[0], -1)

        if self.use_memory:
            # Split the flat memory into the two LSTM states.
            hidden = (memory[:, :self.semi_memory_size], memory[:, self.semi_memory_size:])
            hidden = self.memory_rnn(x, hidden)
            embedding = hidden[0]
            memory = torch.cat(hidden, dim=1)
        else:
            embedding = x

        if self.use_text:
            embed_text = self._get_embed_text(obs.text)
            embedding = torch.cat((embedding, embed_text), dim=1)

        if self.use_dialogue:
            if not hasattr(obs, "utterance_history"):
                raise ValueError("The environment need's to be updated to 'utterance' and 'utterance_history' keys'")
            embed_dial = self._get_embed_dialogue(obs.utterance_history)
            embedding = torch.cat((embedding, embed_dial), dim=1)

        primitive_actions_dist = Categorical(logits=F.log_softmax(self.actor(embedding), dim=1))

        value = self.critic(embedding).squeeze(1)

        utterance_actions_dists = [
            Categorical(logits=F.log_softmax(head(embedding), dim=1))
            for head in self.talker
        ]

        dist = [primitive_actions_dist] + utterance_actions_dists

        return dist, value, memory

    def sample_action(self, dist):
        """Sample one value from every head and stack them as columns."""
        return torch.stack([d.sample() for d in dist], dim=1)

    def calculate_log_probs(self, dist, action):
        """Return per-head log-probabilities; column ``i`` of ``action`` is scored by ``dist[i]``."""
        return torch.stack([d.log_prob(action[:, i]) for i, d in enumerate(dist)], dim=1)

    def calculate_action_masks(self, action):
        """Return a boolean mask marking which columns of ``action`` are in effect.

        Column 0 (primitive action) is always active; every other column is
        active only in rows whose primitive action equals ``talk_action``.

        Args:
            action: 2-D tensor with the primitive action in column 0 and the
                utterance choices in the remaining columns.

        Returns:
            Detached tensor with the same shape as ``action``.
        """
        talk_mask = action[:, 0] == self.talk_action
        # One mask column per utterance column, whatever their number
        # (the mask therefore always matches ``action`` in shape).
        n_utterance_columns = action.shape[1] - 1
        columns = [torch.ones_like(talk_mask)] + [talk_mask] * n_utterance_columns
        return torch.stack(columns, dim=1).detach()

    def construct_final_action(self, action):
        """Blank out (with NaN) the parts of each action row that are not in effect.

        Rows whose primitive action is not ``talk_action`` keep column 0 and get
        NaN in every utterance column; rows whose primitive action is
        ``talk_action`` get NaN in column 0 and keep the utterance columns.

        Args:
            action: 2-D array with the primitive action in column 0.
                NOTE(review): presumably a NumPy array (the mask is built with
                NumPy) -- verify against the caller.

        Returns:
            Float array of the same shape as ``action``.
        """
        acts = np.asarray(action[:, 0] != self.talk_action)

        nan_mask = np.full((acts.shape[0], action.shape[1]), np.nan)
        nan_mask[acts, 0] = 1
        nan_mask[~acts, 1:] = 1

        return nan_mask * action

    def _get_embed_text(self, text):
        """Embed word indices with the text GRU and return its final hidden state."""
        _, hidden = self.text_rnn(self.word_embedding(text))
        return hidden[-1]

    def _get_embed_dialogue(self, dial):
        """Embed word indices with the dialogue GRU and return its final hidden state."""
        _, hidden = self.dialogue_rnn(self.word_embedding(dial))
        return hidden[-1]

    def get_config_dict(self):
        """Return the stored constructor config in a plain, printable form.

        ``self.config`` is sanitised in place: the ``'__class__'`` entry is
        removed, ``'self'`` is replaced by its string form and
        ``'action_space'`` by the list form of its ``nvec``. The method may be
        called repeatedly; later calls return the same dictionary.
        """
        # pop() instead of del so that a second call does not raise KeyError.
        self.config.pop('__class__', None)
        self.config['self'] = str(self.config['self'])

        # Convert only while the entry is still a space object (first call).
        space = self.config['action_space']
        if hasattr(space, 'nvec'):
            self.config['action_space'] = space.nvec.tolist()

        return self.config