#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""""""
import os
import sys
from pathlib import Path

import numpy as np
import torch

# Make the sibling modules importable when this file is loaded by the evaluation platform.
base_dir = Path(__file__).resolve().parent
sys.path.append(str(base_dir))

from openrl_policy import PolicyNetwork
from openrl_utils import openrl_obs_deal, _t2n
from goal_keeper import agent_get_action


class OpenRLAgent:
    def __init__(self):
        # One recurrent hidden state per controllable player (11 players on the team).
        rnn_shape = [1, 1, 1, 512]
        self.rnn_hidden_state = [
            np.zeros(rnn_shape, dtype=np.float32) for _ in range(11)
        ]
        self.model = PolicyNetwork()
        self.model.load_state_dict(
            torch.load(str(base_dir / "actor.pt"), map_location=torch.device("cpu"))
        )
        self.model.eval()

    def get_action(self, raw_obs, idx):
        # Player 0 is the goalkeeper: use the scripted rule-based policy
        # instead of the learned network.
        if idx == 0:
            re_action = [[0] * 19]
            re_action_index = agent_get_action(raw_obs)[0]
            re_action[0][re_action_index] = 1
            return re_action

        # Convert the raw environment observation into the network's input format.
        openrl_obs = openrl_obs_deal(raw_obs)
        obs = openrl_obs["obs"]
        obs = np.concatenate(obs.reshape(1, 1, 330))
        rnn_hidden_state = np.concatenate(self.rnn_hidden_state[idx])

        # The network expects 20 action slots; only the first 19 are real actions.
        avail_actions = np.zeros(20)
        avail_actions[:19] = openrl_obs["available_action"]
        avail_actions = np.concatenate(avail_actions.reshape([1, 1, 20]))

        with torch.no_grad():
            actions, rnn_hidden_state = self.model(
                obs,
                rnn_hidden_state,
                available_actions=avail_actions,
                deterministic=True,
            )

        # If the network outputs action 17 while sticky action 8 (sprint) is
        # still active, replace it with action 15.
        if actions[0][0] == 17 and raw_obs["sticky_actions"][8] == 1:
            actions[0][0] = 15

        # Carry the recurrent state over to the next step for this player.
        self.rnn_hidden_state[idx] = np.array(np.split(_t2n(rnn_hidden_state), 1))

        # Return a one-hot vector over the 19 football actions.
        re_action = [[0] * 19]
        re_action[0][int(actions[0][0])] = 1
        return re_action


# Instantiate the agent once so the model weights are loaded a single time per process.
agent = OpenRLAgent()


def my_controller(obs_list, action_space_list, is_act_continuous=False):
    # Map the controlled player's global index onto this team's 0-10 range.
    idx = obs_list["controlled_player_index"] % 11
    del obs_list["controlled_player_index"]
    action = agent.get_action(obs_list, idx)
    return action


def jidi_controller(obs_list=None):
    if obs_list is None:
        return
    # Renamed wrapper to prevent loading errors.
    re = my_controller(obs_list, None)
    assert isinstance(re, list)
    assert isinstance(re[0], list)
    return re
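

# A minimal local smoke-test sketch, not part of the submission flow
# (assumption: the evaluation platform only imports this module and calls
# my_controller / jidi_controller itself). Running the file directly simply
# confirms that the policy weights in actor.pt load correctly.
if __name__ == "__main__":
    print("Policy loaded:", type(agent.model).__name__)
    print("Tracking hidden states for", len(agent.rnn_hidden_state), "players")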