|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
|
|
import torch |
|
import torch.nn as nn |
|
from torch.distributions import Categorical |
|
|
|
import gym |
|
|
|
def check(input): |
|
output = torch.from_numpy(input) if type(input) == np.ndarray else input |
|
return output |
|
|
|
class FcEncoder(nn.Module): |
|
def __init__(self, fc_num, input_size, output_size): |
|
super(FcEncoder, self).__init__() |
|
self.first_mlp = nn.Sequential( |
|
nn.Linear(input_size, output_size), nn.ReLU(), nn.LayerNorm(output_size) |
|
) |
|
self.mlp = nn.Sequential() |
|
for _ in range(fc_num - 1): |
|
self.mlp.append(nn.Sequential( |
|
nn.Linear(output_size, output_size), nn.ReLU(), nn.LayerNorm(output_size) |
|
)) |
|
|
|
def forward(self, x): |
|
output = self.first_mlp(x) |
|
return self.mlp(output) |
|
|
|
def init(module, weight_init, bias_init, gain=1): |
|
weight_init(module.weight.data, gain=gain) |
|
if module.bias is not None: |
|
bias_init(module.bias.data) |
|
return module |
|
|
|
|
|
class FixedCategorical(torch.distributions.Categorical): |
|
def sample(self): |
|
return super().sample().unsqueeze(-1) |
|
|
|
def log_probs(self, actions): |
|
return ( |
|
super() |
|
.log_prob(actions.squeeze(-1)) |
|
.view(actions.size(0), -1) |
|
.sum(-1) |
|
.unsqueeze(-1) |
|
) |
|
|
|
def mode(self): |
|
return self.probs.argmax(dim=-1, keepdim=True) |
|
|
|
class Categorical(nn.Module): |
|
def __init__(self, num_inputs, num_outputs, use_orthogonal=True, gain=0.01): |
|
super(Categorical, self).__init__() |
|
init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal] |
|
def init_(m): |
|
return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain) |
|
|
|
self.linear = init_(nn.Linear(num_inputs, num_outputs)) |
|
|
|
def forward(self, x, available_actions=None): |
|
x = self.linear(x) |
|
if available_actions is not None: |
|
x[available_actions == 0] = -1e10 |
|
return FixedCategorical(logits=x) |
|
|
|
|
|
class AddBias(nn.Module): |
|
def __init__(self, bias): |
|
super(AddBias, self).__init__() |
|
self._bias = nn.Parameter(bias.unsqueeze(1)) |
|
|
|
def forward(self, x): |
|
if x.dim() == 2: |
|
bias = self._bias.t().view(1, -1) |
|
else: |
|
bias = self._bias.t().view(1, -1, 1, 1) |
|
|
|
return x + bias |
|
|
|
class ACTLayer(nn.Module): |
|
def __init__(self, action_space, inputs_dim, use_orthogonal, gain): |
|
super(ACTLayer, self).__init__() |
|
self.multidiscrete_action = False |
|
self.continuous_action = False |
|
self.mixed_action = False |
|
|
|
action_dim = action_space.n |
|
self.action_out = Categorical(inputs_dim, action_dim, use_orthogonal, gain) |
|
|
|
|
|
|
|
def forward(self, x, available_actions=None, deterministic=False): |
|
if self.mixed_action : |
|
actions = [] |
|
action_log_probs = [] |
|
for action_out in self.action_outs: |
|
action_logit = action_out(x) |
|
action = action_logit.mode() if deterministic else action_logit.sample() |
|
action_log_prob = action_logit.log_probs(action) |
|
actions.append(action.float()) |
|
action_log_probs.append(action_log_prob) |
|
|
|
actions = torch.cat(actions, -1) |
|
action_log_probs = torch.sum(torch.cat(action_log_probs, -1), -1, keepdim=True) |
|
|
|
elif self.multidiscrete_action: |
|
actions = [] |
|
action_log_probs = [] |
|
for action_out in self.action_outs: |
|
action_logit = action_out(x) |
|
action = action_logit.mode() if deterministic else action_logit.sample() |
|
action_log_prob = action_logit.log_probs(action) |
|
actions.append(action) |
|
action_log_probs.append(action_log_prob) |
|
|
|
actions = torch.cat(actions, -1) |
|
action_log_probs = torch.cat(action_log_probs, -1) |
|
|
|
elif self.continuous_action: |
|
action_logits = self.action_out(x) |
|
actions = action_logits.mode() if deterministic else action_logits.sample() |
|
action_log_probs = action_logits.log_probs(actions) |
|
|
|
else: |
|
action_logits = self.action_out(x, available_actions) |
|
actions = action_logits.mode() if deterministic else action_logits.sample() |
|
action_log_probs = action_logits.log_probs(actions) |
|
|
|
return actions, action_log_probs |
|
|
|
def get_probs(self, x, available_actions=None): |
|
if self.mixed_action or self.multidiscrete_action: |
|
action_probs = [] |
|
for action_out in self.action_outs: |
|
action_logit = action_out(x) |
|
action_prob = action_logit.probs |
|
action_probs.append(action_prob) |
|
action_probs = torch.cat(action_probs, -1) |
|
elif self.continuous_action: |
|
action_logits = self.action_out(x) |
|
action_probs = action_logits.probs |
|
else: |
|
action_logits = self.action_out(x, available_actions) |
|
action_probs = action_logits.probs |
|
|
|
return action_probs |
|
|
|
def evaluate_actions(self, x, action, available_actions=None, active_masks=None, get_probs=False): |
|
if self.mixed_action: |
|
a, b = action.split((2, 1), -1) |
|
b = b.long() |
|
action = [a, b] |
|
action_log_probs = [] |
|
dist_entropy = [] |
|
for action_out, act in zip(self.action_outs, action): |
|
action_logit = action_out(x) |
|
action_log_probs.append(action_logit.log_probs(act)) |
|
if active_masks is not None: |
|
if len(action_logit.entropy().shape) == len(active_masks.shape): |
|
dist_entropy.append((action_logit.entropy() * active_masks).sum()/active_masks.sum()) |
|
else: |
|
dist_entropy.append((action_logit.entropy() * active_masks.squeeze(-1)).sum()/active_masks.sum()) |
|
else: |
|
dist_entropy.append(action_logit.entropy().mean()) |
|
|
|
action_log_probs = torch.sum(torch.cat(action_log_probs, -1), -1, keepdim=True) |
|
dist_entropy = dist_entropy[0] * 0.0025 + dist_entropy[1] * 0.01 |
|
|
|
elif self.multidiscrete_action: |
|
action = torch.transpose(action, 0, 1) |
|
action_log_probs = [] |
|
dist_entropy = [] |
|
for action_out, act in zip(self.action_outs, action): |
|
action_logit = action_out(x) |
|
action_log_probs.append(action_logit.log_probs(act)) |
|
if active_masks is not None: |
|
dist_entropy.append((action_logit.entropy()*active_masks.squeeze(-1)).sum()/active_masks.sum()) |
|
else: |
|
dist_entropy.append(action_logit.entropy().mean()) |
|
|
|
action_log_probs = torch.cat(action_log_probs, -1) |
|
dist_entropy = torch.tensor(dist_entropy).mean() |
|
|
|
elif self.continuous_action: |
|
action_logits = self.action_out(x) |
|
action_log_probs = action_logits.log_probs(action) |
|
act_entropy = action_logits.entropy() |
|
|
|
if active_masks is not None: |
|
dist_entropy = (act_entropy*active_masks).sum()/active_masks.sum() |
|
else: |
|
dist_entropy = act_entropy.mean() |
|
|
|
else: |
|
action_logits = self.action_out(x, available_actions) |
|
action_log_probs = action_logits.log_probs(action) |
|
if active_masks is not None: |
|
dist_entropy = (action_logits.entropy()*active_masks.squeeze(-1)).sum()/active_masks.sum() |
|
else: |
|
dist_entropy = action_logits.entropy().mean() |
|
if not get_probs: |
|
return action_log_probs, dist_entropy |
|
else: |
|
return action_log_probs, dist_entropy, action_logits |
|
|
|
class RNNLayer(nn.Module): |
|
def __init__(self, inputs_dim, outputs_dim, recurrent_N, use_orthogonal,rnn_type='gru'): |
|
super(RNNLayer, self).__init__() |
|
self._recurrent_N = recurrent_N |
|
self._use_orthogonal = use_orthogonal |
|
self.rnn_type = rnn_type |
|
if rnn_type == 'gru': |
|
self.rnn = nn.GRU(inputs_dim, outputs_dim, num_layers=self._recurrent_N) |
|
elif rnn_type == 'lstm': |
|
self.rnn = nn.LSTM(inputs_dim, outputs_dim, num_layers=self._recurrent_N) |
|
else: |
|
raise NotImplementedError(f'RNN type {rnn_type} has not been implemented.') |
|
|
|
for name, param in self.rnn.named_parameters(): |
|
if 'bias' in name: |
|
nn.init.constant_(param, 0) |
|
elif 'weight' in name: |
|
if self._use_orthogonal: |
|
nn.init.orthogonal_(param) |
|
else: |
|
nn.init.xavier_uniform_(param) |
|
self.norm = nn.LayerNorm(outputs_dim) |
|
|
|
def rnn_forward(self, x, h): |
|
if self.rnn_type == 'lstm': |
|
h = torch.split(h, h.shape[-1] // 2, dim=-1) |
|
h = (h[0].contiguous(), h[1].contiguous()) |
|
x_, h_ = self.rnn(x, h) |
|
if self.rnn_type == 'lstm': |
|
h_ = torch.cat(h_, -1) |
|
return x_, h_ |
|
|
|
def forward(self, x, hxs, masks): |
|
if x.size(0) == hxs.size(0): |
|
x, hxs = self.rnn_forward(x.unsqueeze(0), (hxs * masks.repeat(1, self._recurrent_N).unsqueeze(-1)).transpose(0, 1).contiguous()) |
|
|
|
x = x.squeeze(0) |
|
hxs = hxs.transpose(0, 1) |
|
else: |
|
|
|
N = hxs.size(0) |
|
T = int(x.size(0) / N) |
|
|
|
|
|
x = x.view(T, N, x.size(1)) |
|
|
|
|
|
masks = masks.view(T, N) |
|
|
|
|
|
|
|
has_zeros = ((masks[1:] == 0.0) |
|
.any(dim=-1) |
|
.nonzero() |
|
.squeeze() |
|
.cpu()) |
|
|
|
|
|
if has_zeros.dim() == 0: |
|
|
|
has_zeros = [has_zeros.item() + 1] |
|
else: |
|
has_zeros = (has_zeros + 1).numpy().tolist() |
|
|
|
|
|
has_zeros = [0] + has_zeros + [T] |
|
|
|
hxs = hxs.transpose(0, 1) |
|
|
|
outputs = [] |
|
for i in range(len(has_zeros) - 1): |
|
|
|
|
|
start_idx = has_zeros[i] |
|
end_idx = has_zeros[i + 1] |
|
temp = (hxs * masks[start_idx].view(1, -1, 1).repeat(self._recurrent_N, 1, 1)).contiguous() |
|
rnn_scores, hxs = self.rnn_forward(x[start_idx:end_idx], temp) |
|
outputs.append(rnn_scores) |
|
|
|
|
|
|
|
x = torch.cat(outputs, dim=0) |
|
|
|
|
|
x = x.reshape(T * N, -1) |
|
hxs = hxs.transpose(0, 1) |
|
|
|
x = self.norm(x) |
|
return x, hxs |
|
|
|
|
|
class InputEncoder(nn.Module): |
|
def __init__(self): |
|
super(InputEncoder, self).__init__() |
|
fc_layer_num = 2 |
|
fc_output_num = 64 |
|
self.active_input_num = 87 |
|
self.ball_owner_input_num = 57 |
|
self.left_input_num = 88 |
|
self.right_input_num = 88 |
|
self.match_state_input_num = 9 |
|
|
|
self.active_encoder = FcEncoder(fc_layer_num, self.active_input_num, fc_output_num) |
|
self.ball_owner_encoder = FcEncoder(fc_layer_num, self.ball_owner_input_num, fc_output_num) |
|
self.left_encoder = FcEncoder(fc_layer_num, self.left_input_num, fc_output_num) |
|
self.right_encoder = FcEncoder(fc_layer_num, self.right_input_num, fc_output_num) |
|
self.match_state_encoder = FcEncoder(fc_layer_num, self.match_state_input_num, self.match_state_input_num) |
|
|
|
def forward(self, x): |
|
active_vec = x[:, :self.active_input_num] |
|
ball_owner_vec = x[:, self.active_input_num : self.active_input_num + self.ball_owner_input_num] |
|
left_vec = x[:, self.active_input_num + self.ball_owner_input_num : self.active_input_num + self.ball_owner_input_num + self.left_input_num] |
|
right_vec = x[:, self.active_input_num + self.ball_owner_input_num + self.left_input_num : \ |
|
self.active_input_num + self.ball_owner_input_num + self.left_input_num + self.right_input_num] |
|
match_state_vec = x[:, self.active_input_num + self.ball_owner_input_num + self.left_input_num + self.right_input_num:] |
|
|
|
active_output = self.active_encoder(active_vec) |
|
ball_owner_output = self.ball_owner_encoder(ball_owner_vec) |
|
left_output = self.left_encoder(left_vec) |
|
right_output = self.right_encoder(right_vec) |
|
match_state_output = self.match_state_encoder(match_state_vec) |
|
|
|
return torch.cat([ |
|
active_output, |
|
ball_owner_output, |
|
left_output, |
|
right_output, |
|
match_state_output |
|
], 1) |
|
|
|
def get_fc(input_size, output_size): |
|
return nn.Sequential(nn.Linear(input_size, output_size), nn.ReLU(), nn.LayerNorm(output_size)) |
|
|
|
class ObsEncoder(nn.Module): |
|
def __init__(self, input_embedding_size, hidden_size, _recurrent_N, _use_orthogonal, rnn_type): |
|
super(ObsEncoder, self).__init__() |
|
self.input_encoder = InputEncoder() |
|
self.input_embedding = get_fc(input_embedding_size, hidden_size) |
|
self.rnn = RNNLayer(hidden_size, hidden_size, _recurrent_N, _use_orthogonal, rnn_type=rnn_type) |
|
self.after_rnn_mlp = get_fc(hidden_size, hidden_size) |
|
|
|
def forward(self, obs, rnn_states, masks): |
|
actor_features = self.input_encoder(obs) |
|
actor_features = self.input_embedding(actor_features) |
|
output, rnn_states = self.rnn(actor_features, rnn_states, masks) |
|
return self.after_rnn_mlp(output), rnn_states |
|
|
|
|
|
class PolicyNetwork(nn.Module): |
|
def __init__(self, device=torch.device("cpu")): |
|
super(PolicyNetwork, self).__init__() |
|
self.tpdv = dict(dtype=torch.float32, device=device) |
|
self.device = device |
|
self.hidden_size = 256 |
|
self._use_policy_active_masks = True |
|
recurrent_N = 1 |
|
use_orthogonal = True |
|
rnn_type = 'lstm' |
|
gain = 0.01 |
|
action_space = gym.spaces.Discrete(20) |
|
self.action_dim = 19 |
|
input_embedding_size = 64 * 4 + 9 |
|
self.active_id_size = 1 |
|
self.id_max = 11 |
|
|
|
self.obs_encoder = ObsEncoder(input_embedding_size, self.hidden_size, recurrent_N, use_orthogonal, rnn_type) |
|
|
|
self.predict_id = get_fc(self.hidden_size + self.action_dim, self.id_max) |
|
self.id_embedding = get_fc(self.id_max, self.id_max) |
|
|
|
self.before_act_wrapper = FcEncoder(2, self.hidden_size + self.id_max, self.hidden_size) |
|
self.act = ACTLayer(action_space, self.hidden_size, use_orthogonal, gain) |
|
|
|
self.to(device) |
|
|
|
|
|
def forward(self, obs, rnn_states, masks=np.concatenate(np.ones((1, 1, 1), dtype=np.float32)), available_actions=None, deterministic=False): |
|
obs = check(obs).to(**self.tpdv) |
|
if available_actions is not None: |
|
available_actions = check(available_actions).to(**self.tpdv) |
|
masks = check(masks).to(**self.tpdv) |
|
rnn_states = check(rnn_states).to(**self.tpdv) |
|
|
|
active_id = obs[:,:self.active_id_size].squeeze(1).long() |
|
id_onehot = torch.eye(self.id_max)[active_id].to(self.device) |
|
obs = obs[:,self.active_id_size:] |
|
|
|
obs_output, rnn_states = self.obs_encoder(obs, rnn_states, masks) |
|
id_output = self.id_embedding(id_onehot) |
|
output = torch.cat([id_output, obs_output], 1) |
|
|
|
output = self.before_act_wrapper(output) |
|
|
|
actions, action_log_probs = self.act(output, available_actions, deterministic) |
|
return actions, rnn_states |
|
|
|
def eval_actions(self, obs, rnn_states, action, masks, available_actions=None, active_masks=None): |
|
obs = check(obs).to(**self.tpdv) |
|
if available_actions is not None: |
|
available_actions = check(available_actions).to(**self.tpdv) |
|
if active_masks is not None: |
|
active_masks = check(active_masks).to(**self.tpdv) |
|
masks = check(masks).to(**self.tpdv) |
|
rnn_states = check(rnn_states).to(**self.tpdv) |
|
action = check(action).to(**self.tpdv) |
|
|
|
id_groundtruth = obs[:,:self.active_id_size].squeeze(1).long() |
|
id_onehot = torch.eye(self.id_max)[id_groundtruth].to(self.device) |
|
obs = obs[:,self.active_id_size:] |
|
|
|
obs_output, rnn_states = self.obs_encoder(obs, rnn_states, masks) |
|
id_output = self.id_embedding(id_onehot) |
|
|
|
action_onehot = torch.eye(self.action_dim)[action.squeeze(1).long()].to(self.device) |
|
|
|
id_prediction = self.predict_id(torch.cat([obs_output, action_onehot], 1)) |
|
output = torch.cat([id_output, obs_output], 1) |
|
|
|
output = self.before_act_wrapper(output) |
|
action_log_probs, dist_entropy = self.act.evaluate_actions(output, action, available_actions, |
|
active_masks=active_masks if self._use_policy_active_masks else None) |
|
values = None |
|
return action_log_probs, dist_entropy, values, id_prediction, id_groundtruth |
|
|
|
|