# -*- coding: utf-8 -*-
"""Pure Monte Carlo Tree Search (random rollouts, uniform priors).

The board objects passed in are expected to expose ``availables`` (the legal
moves), ``do_move(action)``, ``game_end() -> (end, winner)`` and
``get_current_player()``; a ``winner`` of ``-1`` is treated as a tie.
"""
import numpy as np
import copy
from operator import itemgetter
import time


def rollout_policy_fn(board):
    """Return a coarse, fast policy used during the rollout phase.

    Every legal move gets an independent uniform random score, so taking the
    arg-max of the result picks a uniformly random legal move.

    Returns:
        An iterator of ``(action, score)`` pairs over ``board.availables``.
    """
    action_probs = np.random.rand(len(board.availables))
    return zip(board.availables, action_probs)


# Policy-value function.
def policy_value_fn(board):
    """Return uniform priors over the legal moves and a neutral score.

    Returns:
        A tuple ``(action_priors, value)`` where ``action_priors`` is an
        iterator of ``(action, probability)`` pairs and ``value`` is always
        ``0`` for pure MCTS.
    """
    n_moves = len(board.availables)
    action_probs = np.ones(n_moves) / n_moves
    return zip(board.availables, action_probs), 0


class TreeNode(object):
    """A node in the MCTS tree.

    Each node keeps track of its own value Q, prior probability P, and its
    visit-count-adjusted prior score u.
    """

    def __init__(self, parent, prior_p):
        self._parent = parent
        self._children = {}  # a map from action to TreeNode
        self._n_visits = 0
        self._Q = 0
        self._u = 0
        self._P = prior_p

    def expand(self, action_priors):
        """Expand the tree by creating new children.

        action_priors: an iterable of ``(action, prior probability)`` tuples
            produced by the policy function. Actions that already have a
            child node are left untouched.
        """
        for action, prob in action_priors:
            if action not in self._children:
                self._children[action] = TreeNode(self, prob)

    def select(self, c_puct):
        """Select the child that maximises action value Q plus bonus u(P).

        Return: A tuple of (action, next_node).
        """
        return max(self._children.items(),
                   key=lambda act_node: act_node[1].get_value(c_puct))

    def update(self, leaf_value):
        """Update this node's statistics from a leaf evaluation.

        leaf_value: the value of the subtree evaluation from the current
            player's perspective.
        """
        self._n_visits += 1
        # Q is the running average of the values seen over all visits.
        self._Q += 1.0 * (leaf_value - self._Q) / self._n_visits

    def update_recursive(self, leaf_value):
        """Like update(), but applied to this node and all its ancestors.

        The sign of ``leaf_value`` is flipped at every level because the
        player to move alternates between a node and its parent.
        """
        # Ancestors are updated first, then this node.
        if self._parent is not None:
            self._parent.update_recursive(-leaf_value)
        self.update(leaf_value)

    def get_value(self, c_puct):
        """Calculate and return the selection score of this node.

        It is the sum of the averaged leaf evaluations Q and the exploration
        bonus u, i.e. this node's prior scaled by the parent's visit count
        and damped by this node's own visit count.

        c_puct: a number in (0, inf) controlling the relative impact of the
            value Q and the prior probability P on this node's score.

        Must only be called on a node that has a parent.
        """
        self._u = (c_puct * self._P *
                   np.sqrt(self._parent._n_visits) / (1 + self._n_visits))
        return self._Q + self._u

    def is_leaf(self):
        """Return True if no children have been expanded below this node."""
        return not self._children

    def is_root(self):
        """Return True if this node has no parent."""
        return self._parent is None


class MCTS(object):
    """A simple implementation of Monte Carlo Tree Search."""

    def __init__(self, policy_value_fn, c_puct=5, n_playout=2000):
        """
        policy_value_fn: a function that takes in a board state and outputs
            a list of (action, probability) tuples and also a score in
            [-1, 1] (i.e. the expected value of the end game score from the
            current player's perspective) for the current player.
        c_puct: a number in (0, inf) that controls how quickly exploration
            converges to the maximum-value policy. A higher value means
            relying on the prior more.
        n_playout: number of simulations run for every call to get_move().
        """
        self._root = TreeNode(None, 1.0)
        self._policy = policy_value_fn
        self._c_puct = c_puct
        self._n_playout = n_playout

    def _playout(self, state):
        """Run a single playout from the root to a leaf.

        A value is obtained at the leaf by random rollout and propagated
        back through its parents. ``state`` is modified in-place, so a copy
        must be provided.
        """
        node = self._root
        # Greedily descend until a node without children is reached.
        while not node.is_leaf():
            action, node = node.select(self._c_puct)
            state.do_move(action)

        end, _ = state.game_end()
        if not end:
            # Only non-terminal leaves are expanded, so the policy is not
            # queried at all for finished games.
            action_probs, _ = self._policy(state)
            node.expand(action_probs)

        # Evaluate the leaf node by random rollout.
        leaf_value = self._evaluate_rollout(state)
        # The rollout value is from the perspective of the player to move at
        # the leaf; the leaf node stores the value for the player who moved
        # into it, hence the sign flip.
        node.update_recursive(-leaf_value)

    def _evaluate_rollout(self, state, limit=1000):
        """Play random moves until the game ends or ``limit`` moves are made.

        Returns +1 if the player to move at the start of the rollout wins,
        -1 if the opponent wins, and 0 if it is a tie. ``state`` is modified
        in-place.
        """
        player = state.get_current_player()
        for _ in range(limit):
            end, winner = state.game_end()
            if end:
                break
            action_probs = rollout_policy_fn(state)
            max_action = max(action_probs, key=itemgetter(1))[0]
            state.do_move(max_action)
        else:
            # The loop ran out of moves (or never ran): the last move has not
            # been checked yet, so look at the final position before
            # deciding the outcome.
            end, winner = state.game_end()
            if not end:
                print("WARNING: rollout reached move limit")

        if winner == -1:  # tie
            return 0
        return 1 if winner == player else -1

    def get_move(self, state):
        """Run all playouts sequentially and return the most visited action.

        state: the current game state; it is deep-copied for every playout
            and never modified.

        Return: a tuple ``(action, mean_time)`` where ``mean_time`` is the
            average wall-clock time per playout in seconds.

        Raises:
            ValueError: if the root has no children after the playouts
                (no playouts were run, or ``state`` has no legal move).
        """
        start_time = time.time()
        # n_playout: number of simulations to run.
        for _ in range(self._n_playout):
            state_copy = copy.deepcopy(state)
            self._playout(state_copy)
        need_time = time.time() - start_time
        mean_time = need_time / self._n_playout if self._n_playout > 0 else 0.0
        print(f" PureMCTS sum_time: {mean_time}, total_simulation: {self._n_playout}")

        if not self._root._children:
            raise ValueError(
                "MCTS.get_move: no move available (no playouts were run or "
                "the state has no legal moves)")
        best_action = max(self._root._children.items(),
                          key=lambda act_node: act_node[1]._n_visits)[0]
        return best_action, mean_time

    def update_with_move(self, last_move):
        """Step forward in the tree, keeping what is known about the subtree.

        If ``last_move`` is not a child of the current root (e.g. ``-1``),
        the tree is reset to a fresh root.
        """
        if last_move in self._root._children:
            self._root = self._root._children[last_move]
            self._root._parent = None
        else:
            self._root = TreeNode(None, 1.0)

    def __str__(self):
        return "MCTS"


class MCTSPlayer(object):
    """AI player based on MCTS."""

    def __init__(self, c_puct=5, n_playout=2000):
        self.mcts = MCTS(policy_value_fn, c_puct, n_playout)
        # Set through set_player_ind(); initialised so __str__ never fails.
        self.player = None

    def set_player_ind(self, p):
        self.player = p

    def reset_player(self):
        """Discard the search tree."""
        self.mcts.update_with_move(-1)

    def get_action(self, board):
        """Search from ``board`` and return ``(move, mean_playout_time)``.

        The tree is reset after every move. When the board has no legal
        move, a warning is printed and ``None`` is returned.
        """
        sensible_moves = board.availables
        if len(sensible_moves) > 0:
            move, simul_mean_time = self.mcts.get_move(board)
            self.mcts.update_with_move(-1)
            print("MCTS move:", move)
            return move, simul_mean_time
        print("WARNING: the board is full")
        return None

    def __str__(self):
        return "MCTS {}".format(self.player)


# Human player (added on top of the pure-MCTS player).
class Human_Player(object):
    """Player whose moves are typed in on standard input."""

    def __init__(self):
        # Set through set_player_ind(); initialised so __str__ never fails.
        self.player = None

    def set_player_ind(self, p):
        self.player = p

    def get_action(self, board):
        """Prompt until a legal move is entered and return it.

        Non-integer input is treated like an illegal move: the legal moves
        are printed and the user is asked again. When the board has no
        legal move, a warning is printed and ``None`` is returned.
        """
        sensible_moves = board.availables
        if len(sensible_moves) > 0:
            prompt = "Input the move:"
            while True:
                try:
                    move = int(input(prompt))
                except ValueError:
                    move = None  # not a number: fall through and ask again
                if move is not None and move in sensible_moves:
                    return move
                print(sensible_moves)
                prompt = "Input the move again:"
        print("WARNING: the board is full")
        return None

    def __str__(self):
        return "Human {}".format(self.player)