|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""""" |
|
import os |
|
import sys |
|
from pathlib import Path |
|
import numpy as np |
|
import torch |
|
|
|
# Directory containing this file. It is put on sys.path so the bare-name
# sibling imports that follow (openrl_policy, openrl_utils, goal_keeper),
# which presumably live next to this file, can be resolved.
base_dir = Path(__file__).resolve().parent
# Append only once so re-importing / reloading this module does not keep
# growing sys.path with duplicate entries.
if str(base_dir) not in sys.path:
    sys.path.append(str(base_dir))
|
|
|
from openrl_policy import PolicyNetwork |
|
from openrl_utils import openrl_obs_deal, _t2n |
|
from goal_keeper import agent_get_action |
|
|
|
class OpenRLAgent:
    """Choose a one-hot discrete action for each of up to 11 player slots.

    Slot 0 is delegated to the scripted ``agent_get_action``; every other
    slot is answered by a ``PolicyNetwork`` restored from ``actor.pt`` in
    this file's directory. Each slot keeps its own recurrent hidden state,
    carried over from one ``get_action`` call to the next for that slot.
    """

    # Length of the one-hot action row returned to callers.
    NUM_ACTIONS = 19

    def __init__(self):
        """Create zeroed per-slot hidden states and load the policy weights.

        Propagates whatever ``torch.load`` / ``load_state_dict`` raise if
        ``actor.pt`` is missing or does not match ``PolicyNetwork``.
        """
        rnn_shape = [1, 1, 1, 512]
        # One hidden state per player slot (indices 0..10).
        self.rnn_hidden_state = [np.zeros(rnn_shape, dtype=np.float32) for _ in range(11)]

        self.model = PolicyNetwork()
        weights_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'actor.pt')
        # map_location keeps every loaded tensor on the CPU.
        self.model.load_state_dict(torch.load(weights_path, map_location=torch.device("cpu")))
        self.model.eval()

    @classmethod
    def _one_hot(cls, index):
        """Return ``[row]``, where *row* has NUM_ACTIONS zeros and a 1 at *index*."""
        row = [0] * cls.NUM_ACTIONS
        row[index] = 1
        return [row]

    def get_action(self, raw_obs, idx):
        """Pick an action for the player in slot *idx*.

        Args:
            raw_obs: raw observation dict for this player. It is read by
                ``openrl_obs_deal`` / ``agent_get_action`` and must contain
                ``"sticky_actions"`` for non-zero slots.
            idx: player slot in ``range(11)``; slot 0 uses the scripted
                goalkeeper instead of the network.

        Returns:
            A nested list ``[[...]]`` of NUM_ACTIONS ints containing a single 1.
        """
        if idx == 0:
            # Slot 0 skips the network: the first element returned by the
            # scripted goalkeeper is the action index. Slot 0's hidden state
            # is never read or updated.
            return self._one_hot(agent_get_action(raw_obs)[0])

        openrl_obs = openrl_obs_deal(raw_obs)

        # Reshape to (1, 1, 330) and concatenate along axis 0 -> (1, 330).
        obs = np.concatenate(openrl_obs['obs'].reshape(1, 1, 330))
        # Same trick on the stored state: (1, 1, 1, 512) -> (1, 1, 512).
        rnn_hidden_state = np.concatenate(self.rnn_hidden_state[idx])
        # The model gets a 20-wide mask: the first 19 entries come from the
        # observation and the last stays 0.
        # NOTE(review): presumably PolicyNetwork has one extra, never-legal
        # action slot; confirm against the network definition.
        avail_actions = np.zeros(20)
        avail_actions[:self.NUM_ACTIONS] = openrl_obs['available_action']
        avail_actions = np.concatenate(avail_actions.reshape([1, 1, 20]))

        with torch.no_grad():
            actions, rnn_hidden_state = self.model(
                obs, rnn_hidden_state, available_actions=avail_actions, deterministic=True
            )

        # Convert to a Python int so the list index below works whether the
        # model returns a tensor or an ndarray. NumPy refuses to use a
        # shape-(1,) array as a list index.
        action = int(actions[0][0])
        # Hand-written override: emit 15 instead of 17 while sticky_actions[8]
        # is set.
        # NOTE(review): in the GFootball action set this looks like
        # "dribble while sprinting -> release sprint"; confirm intent.
        if action == 17 and raw_obs["sticky_actions"][8] == 1:
            action = 15

        # np.split(x, 1) yields [x]; np.array re-adds a leading axis,
        # presumably restoring the (1, 1, 1, 512) layout stored in __init__.
        self.rnn_hidden_state[idx] = np.array(np.split(_t2n(rnn_hidden_state), 1))

        return self._one_hot(action)
|
|
|
# Shared agent, built once at import time (its __init__ loads actor.pt).
agent = OpenRLAgent()


def my_controller(obs_list, action_space_list, is_act_continuous=False):
    """Return the agent's action for the player described by *obs_list*.

    Args:
        obs_list: observation dict for one controlled player; it must contain
            ``'controlled_player_index'``. It is not modified.
        action_space_list: unused; kept for the controller call signature.
        is_act_continuous: unused; kept for the controller call signature.

    Returns:
        The nested one-hot list produced by ``agent.get_action``.

    Raises:
        KeyError: if ``'controlled_player_index'`` is missing.
    """
    # Reduce modulo 11 so the index selects one of the agent's 11 slots.
    idx = obs_list['controlled_player_index'] % 11
    # Give the agent a copy without the index key rather than deleting the
    # key from the caller's dict. The agent sees the same contents as
    # before, and the caller's observation stays intact.
    raw_obs = {key: value for key, value in obs_list.items()
               if key != 'controlled_player_index'}
    return agent.get_action(raw_obs, idx)
|
|
|
def jidi_controller(obs_list=None):
    """Run ``my_controller`` on *obs_list* and sanity-check the result.

    Returns None when no observation is given. Otherwise returns the action
    from ``my_controller``, after asserting that it is a list whose first
    element is also a list. These checks use ``assert`` and are skipped
    under ``python -O``.
    """
    if obs_list is not None:
        action = my_controller(obs_list, None)
        assert isinstance(action, list)
        assert isinstance(action[0], list)
        return action
    return None