Spaces:
Sleeping
Sleeping
""" | |
FileName: main_worker.py | |
Author: Jiaxin Li | |
Create Date: 2023/11/21 | |
Description: The implementation of Gumbel MCTS
Edit History: | |
Debug: the dim of output: probs | |
""" | |
import numpy as np | |
import copy | |
import time | |
from config.options import * | |
import sys | |
from config.utils import * | |
def softmax(x):
    """Return the softmax of ``x`` taken over all of its elements.

    The global maximum is subtracted before exponentiating so that large
    inputs do not overflow; the result sums to 1.
    """
    shifted = x - np.max(x)
    exp_vals = np.exp(shifted)
    return exp_vals / np.sum(exp_vals)
def _sigma_mano(y ,Nb): | |
return (50 + Nb) * 1.0 * y | |
class TreeNode(object):
    """A node in the MCTS tree.

    Each node keeps track of its running mean value ``_Q``, its prior
    probability ``_p``, its visit count ``_n_visits`` and the
    visit-count-adjusted prior score ``_u`` computed by ``get_value``.
    ``_v`` is a value slot initialised to 0; this class never writes to it
    after construction, so callers may use it to store an evaluation.
    """

    def __init__(self, parent, prior_p):
        """Create a node.

        parent: the parent TreeNode, or None for a root.
        prior_p: prior probability of the action leading to this node.
        """
        self._parent = parent
        self._children = {}  # a map from action to TreeNode
        self._n_visits = 0
        self._Q = 0
        self._u = 0
        self._v = 0
        self._p = prior_p

    def expand(self, action_priors):
        """Expand tree by creating new children.

        action_priors: an iterable of (action, prior probability) tuples
            according to the policy function. Actions that already have a
            child node are left untouched.
        """
        for action, prob in action_priors:
            if action not in self._children:
                self._children[action] = TreeNode(self, prob)

    def select(self, v_pi):
        """Select the child that maximises ``pi'(a) - N(a) / (1 + N)``.

        ``pi'`` is the softmax over every child's ``get_pi`` score and ``N``
        is this node's own visit count.

        v_pi: value used in place of Q for children that have no visits yet.

        Return: A tuple of (action, next_node).
        Raises: ValueError if the node has no children.
        """
        children = list(self._children.items())
        if not children:
            raise ValueError("select() called on a node without children")

        visit_counts = np.array([node._n_visits for _, node in children])
        max_n_b = np.max(visit_counts)
        scores = np.array(
            [node.get_pi(v_pi, max_n_b) for _, node in children]
        )
        # One row per child so the policy and the visit ratios line up.
        pi_ = softmax(scores).reshape(len(children), -1)
        n_a = (visit_counts / (1 + self._n_visits)).reshape(pi_.shape[0], -1)
        return children[int(np.argmax(pi_ - n_a))]

    def update(self, leaf_value):
        """Update node values from leaf evaluation.

        leaf_value: the value of subtree evaluation from the current player's
            perspective.
        """
        # Count visit.
        self._n_visits += 1
        # Update Q, a running average of values for all visits. Rebinding
        # (rather than ``+=``) means an array-like Q is never mutated in place.
        self._Q = self._Q + 1.0 * (leaf_value - self._Q) / self._n_visits

    def update_recursive(self, leaf_value):
        """Like a call to update(), but applied recursively for all ancestors.

        The sign of ``leaf_value`` is flipped at every level on the way up.
        """
        # If it is not root, this node's parent should be updated first.
        if self._parent is not None:
            self._parent.update_recursive(-leaf_value)
        self.update(leaf_value)

    def get_pi(self, v_pi, max_N_b):
        """Return this node's score: prior plus the transformed completed Q.

        v_pi: value substituted for Q while the node is unvisited.
        max_N_b: visit count forwarded to ``_sigma_mano``.
        """
        q_completed = v_pi if self._n_visits == 0 else self._Q
        return self._p + _sigma_mano(q_completed, max_N_b)

    def get_value(self, c_puct):
        """Calculate and return the value for this node.

        It is a combination of leaf evaluations Q, and this node's prior
        adjusted for its visit count, u. The result is also cached in
        ``self._u``. Must not be called on a root (it reads the parent).

        c_puct: a number in (0, inf) controlling the relative impact of
            value Q, and prior probability P, on this node's score.
        """
        self._u = (c_puct * self._p *
                   np.sqrt(self._parent._n_visits) / (1 + self._n_visits))
        return self._Q + self._u

    def is_leaf(self):
        """Check if leaf node (i.e. no nodes below this have been expanded)."""
        return not self._children

    def is_root(self):
        """Return True when the node has no parent."""
        return self._parent is None
class Gumbel_MCTS(object):
    """Monte Carlo Tree Search using Gumbel sampling and sequential halving
    at the root."""

    def __init__(self, policy_value_fn, c_puct=5, n_playout=10000):
        """
        policy_value_fn: a function that takes in a board state and outputs
            an iterable of (action, probability) tuples and also a score for
            the current player. The score must support
            ``.detach().numpy()[0][0]`` (presumably a 1x1 torch tensor --
            confirm against the network code).
        c_puct: stored on the instance; none of the methods of this class
            read it.
        n_playout: simulation budget spent by one ``get_move_probs`` call.
        """
        self._root = TreeNode(None, 1.0)
        self._policy = policy_value_fn
        self._c_puct = c_puct
        self._n_playout = n_playout

    def Gumbel_playout(self, child_node, child_state):
        """Run a single playout from a child of the root down to a leaf,
        evaluate the leaf and propagate the value back through its parents.

        State is modified in-place, so a copy must be provided.
        Below the root, moves are chosen with ``TreeNode.select``.
        """
        node = child_node
        state = child_state
        while not node.is_leaf():
            # Greedily select next move.
            action, node = node.select(node._v)
            state.do_move(action)

        # Evaluate the leaf using a network which outputs a list of
        # (action, probability) tuples p and also a score v for the
        # current player.
        action_probs, leaf_value = self._policy(state)
        leaf_value = leaf_value.detach().numpy()[0][0]
        node._v = leaf_value

        # Check for end of game.
        end, winner = state.game_end()
        if not end:
            node.expand(action_probs)
        elif winner == -1:  # tie
            leaf_value = 0.0
        else:
            # For an end state, return the "true" leaf_value.
            leaf_value = (
                1.0 if winner == state.get_current_player() else -1.0
            )

        # Update value and visit count of nodes in this traversal.
        node.update_recursive(-leaf_value)

    def top_k(self, x, k):
        """Return the indices of the ``k`` largest entries along the last axis.

        The returned indices are not sorted by value. ``k == 0`` yields an
        empty index array.

        Raises: ValueError if ``k`` is negative or exceeds the axis length.
        """
        x = np.asarray(x)
        n = x.shape[-1]
        if k < 0 or k > n:
            raise ValueError(f"k must be in [0, {n}], got {k}")
        if k == 0:
            return np.empty(x.shape[:-1] + (0,), dtype=np.intp)
        # Partition so that position n - k holds its sorted value; everything
        # from there to the end is then one of the k largest entries.
        return np.argpartition(x, n - k, axis=-1)[..., n - k:]

    def sample_k(self, logits, k):
        """Draw ``k`` distinct indices by perturbing ``logits`` with Gumbel
        noise and keeping the ``k`` largest perturbed scores.

        Return: (indices, z) where ``z`` is the noise array, same shape as
        ``logits``, so callers can reuse the same noise later.
        """
        logits = np.asarray(logits, dtype=float)
        # A strictly positive lower bound keeps log(u) finite.
        u = np.random.uniform(
            low=np.finfo(float).tiny, high=1.0, size=np.shape(logits)
        )
        z = -np.log(-np.log(u))
        return self.top_k(logits + z, k), z

    def get_move_probs(self, state, temp=1e-3, m_action=16):
        """Run sequential halving with Gumbel noise from the root and return
        the chosen move together with the improved policy.

        state: the current game state (deep-copied before any move is made).
        temp: accepted for interface compatibility; not used by the search.
        m_action: maximum number of root actions sampled for the search; the
            actual number is also capped at half the number of root children
            and is never less than one.

        The ``n_playout`` budget is spread over the halving phases: with m
        sampled actions, each phase gives every surviving action
        ``n_playout / (log2(m) * survivors)`` simulations (never more than
        what is left), and the final single survivor receives all remaining
        simulations.

        Return: (move, actions, probs, elapsed_seconds) where ``actions`` and
        ``probs`` are aligned arrays covering every root child.
        Raises: ValueError if the root ends up with no children.
        """
        start_time = time.time()

        # Expand the root node with the network's priors.
        act_probs, leaf_value = self._policy(state)
        act_probs = list(act_probs)
        leaf_value = leaf_value.detach().numpy()[0][0]
        self._root.expand(act_probs)

        root_children = list(self._root._children.items())
        if not root_children:
            raise ValueError("get_move_probs() needs at least one legal action")

        # The prior probabilities are used directly as logits. Reading them
        # from the children keeps every index below aligned with
        # ``root_children`` even if the root had been expanded earlier.
        priors = np.array([node._p for _, node in root_children], dtype=float)

        # Sample m distinct actions: top-m of (logits + g).
        m = max(1, min(m_action, len(root_children) // 2))
        candidates, gumbel = self.sample_k(priors, m)

        # State reached after each sampled action.
        candidate_states = []
        for idx in candidates:
            child_state = copy.deepcopy(state)
            child_state.do_move(root_children[idx][0])
            candidate_states.append(child_state)

        remaining = self._n_playout
        # Simulations per action in the first phase; doubled after each
        # halving. A single candidate simply receives the whole budget.
        if m > 1:
            sims_per_action = int(remaining / (np.log2(m) * m))
        else:
            sims_per_action = remaining

        # Sequential halving with Gumbel.
        while True:
            survivors = len(candidates)
            if survivors == 1:
                # Last phase: spend everything that is left.
                sims = remaining
            else:
                sims = min(sims_per_action, remaining // survivors)

            for idx, child_state in zip(candidates, candidate_states):
                node = root_children[idx][1]
                for _ in range(sims):
                    # One simulation: descend to a leaf of this subtree,
                    # evaluate it with the network and back the value up.
                    self.Gumbel_playout(node, copy.deepcopy(child_state))
            remaining -= sims * survivors

            if survivors == 1:
                break

            # Keep the better half, ranked by g + logits + sigma(q_hat).
            keep = survivors // 2
            q_hat = np.array(
                [root_children[idx][1]._Q for idx in candidates], dtype=float
            )
            max_visits = max(node._n_visits for _, node in root_children)
            scores = (priors[candidates] + gumbel[candidates]
                      + _sigma_mano(q_hat, max_visits))
            kept = self.top_k(scores, keep)
            candidates = candidates[kept]
            candidate_states = [candidate_states[i] for i in kept]
            sims_per_action *= 2

        # Improved policy: pi' = softmax(logits + sigma(completed Q)).
        max_N_b = np.max(
            np.array([node._n_visits for _, node in root_children])
        )
        final_act_probs = softmax(np.array(
            [node.get_pi(leaf_value, max_N_b) for _, node in root_children]
        ))
        action = np.array([act for act, _ in root_children])

        need_time = time.time() - start_time
        print(f" Gumbel Alphazero sum_time: {need_time }, total_simulation: {self._n_playout}")
        move = root_children[int(candidates[0])][0]
        return move, action, final_act_probs, need_time

    def update_with_move(self, last_move):
        """Step forward in the tree, keeping everything we already know
        about the subtree.

        If ``last_move`` is not a child of the root, the tree is reset to a
        fresh root instead.
        """
        if last_move in self._root._children:
            self._root = self._root._children[last_move]
            self._root._parent = None
        else:
            self._root = TreeNode(None, 1.0)

    def __str__(self):
        return "MCTS"
class Gumbel_MCTSPlayer(object):
    """AI player based on Gumbel MCTS."""

    # Player index; stays None until set_player_ind() is called so that
    # __str__ works on a freshly constructed player.
    player = None

    def __init__(self, policy_value_function,
                 c_puct=5, n_playout=2000, is_selfplay=0, m_action=16):
        """
        policy_value_function, c_puct, n_playout: forwarded to Gumbel_MCTS.
        is_selfplay: stored on the instance; not read by this class.
        m_action: forwarded to ``get_move_probs`` on every ``get_action``.
        """
        self.mcts = Gumbel_MCTS(policy_value_function, c_puct, n_playout)
        self._is_selfplay = is_selfplay
        self.m_action = m_action

    def set_player_ind(self, p):
        """Record which player index this agent plays as."""
        self.player = p

    def reset_player(self):
        """Reset the search tree by stepping with the move ``-1``."""
        self.mcts.update_with_move(-1)

    def get_action(self, board, temp=1e-3, return_prob=0, return_time=False):
        """Choose a move for ``board`` with sequential halving with Gumbel.

        board: must expose ``availables``, ``width`` and ``height``.
        temp: forwarded to ``get_move_probs``.
        return_prob: when truthy, also return the policy vector of length
            ``board.width * board.height`` (zero for actions the search did
            not report).
        return_time: when truthy, also return the search time in seconds.

        Return: ``move`` alone, or a tuple starting with ``move`` followed by
        the requested extras in the order (move_probs, time). When the board
        has no available moves, a warning is printed and None is returned.
        """
        sensible_moves = board.availables
        if len(sensible_moves) == 0:
            print("WARNING: the board is full")
            return None

        # The pi vector returned by MCTS as in the AlphaGo Zero paper.
        move_probs = np.zeros(board.width * board.height)
        # Pick the action with sequential halving with Gumbel and get the
        # corresponding improved policy.
        move, acts, probs, simul_mean_time = self.mcts.get_move_probs(
            board, temp, self.m_action
        )
        # Reset the search tree.
        self.mcts.update_with_move(-1)
        move_probs[list(acts)] = probs

        result = (move,)
        if return_prob:
            result += (move_probs,)
        if return_time:
            result += (simul_mean_time,)
        return move if len(result) == 1 else result

    def __str__(self):
        return "MCTS {}".format(self.player)