"""
FileName: mcts_Gumbel_Alphazero.py
Author: Jiaxin Li
Create Date: 2023/11/21
Description: The implement of  Gumbel MCST 
Edit History:
Debug: the dim of output: probs
"""

import numpy as np
import copy
import time

from .config.options import *
import sys
from .config.utils import *


def softmax(x):
    probs = np.exp(x - np.max(x))
    probs /= np.sum(probs)
    return probs


def _sigma_mano(y, Nb):
    return (50 + Nb) * 1.0 * y


class TreeNode(object):
    """A node in the MCTS tree.

    Each node keeps track of its own value Q, prior probability P, and
    its visit-count-adjusted prior score u.
    """

    def __init__(self, parent, prior_p):
        self._parent = parent
        self._children = {}  # a map from action to TreeNode
        self._n_visits = 0
        self._Q = 0
        self._u = 0
        self._v = 0
        self._p = prior_p

    def expand(self, action_priors):
        """Expand tree by creating new children.
        action_priors: a list of tuples of actions and their prior probability
            according to the policy function.
        """
        for action, prob in action_priors:
            if action not in self._children:
                self._children[action] = TreeNode(self, prob)

    def select(self, v_pi):
        """Select action among children that gives maximum 
        (pi'(a) - N(a) \ (1 + \sum_b N(b)))
        Return: A tuple of (action, next_node)
        """
        # if opts.split == "train":
        #     v_pi = v_pi.detach().numpy()
        # print(v_pi)

        max_N_b = np.max(np.array([act_node[1]._n_visits for act_node in self._children.items()]))

        if opts.split == "train":
            pi_ = softmax(np.array([act_node[1].get_pi(v_pi, max_N_b) for act_node in self._children.items()])).reshape(
                len(list(self._children.items())), -1)
        else:
            pi_ = softmax(np.array([act_node[1].get_pi(v_pi, max_N_b) for act_node in self._children.items()])).reshape(
                len(list(self._children.items())), -1)
        # print(pi_.shape)

        N_a = np.array([act_node[1]._n_visits / (1 + self._n_visits) for act_node in self._children.items()]).reshape(
            pi_.shape[0], -1)
        # print(N_a.shape)    

        max_index = np.argmax(pi_ - N_a)
        # print((pi_ - N_a).shape)

        return list(self._children.items())[max_index]

    def update(self, leaf_value):
        """Update node values from leaf evaluation.
        leaf_value: the value of subtree evaluation from the current player's
            perspective.
        """
        # Count visit.
        self._n_visits += 1
        # Update Q, a running average of values for all visits.
        if opts.split == "train":
            self._Q = self._Q + (1.0 * (leaf_value - self._Q) / self._n_visits)
        else:
            self._Q += (1.0 * (leaf_value - self._Q) / self._n_visits)

    def update_recursive(self, leaf_value):
        """Like a call to update(), but applied recursively for all ancestors.
        """
        # If it is not root, this node's parent should be updated first.
        if self._parent:
            self._parent.update_recursive(-leaf_value)
        self.update(leaf_value)

    def get_pi(self, v_pi, max_N_b):
        if self._n_visits == 0:
            Q_completed = v_pi
        else:
            Q_completed = self._Q

        return self._p + _sigma_mano(Q_completed, max_N_b)

    def get_value(self, c_puct):
        """Calculate and return the value for this node.
        It is a combination of leaf evaluations Q, and this node's prior
        adjusted for its visit count, u.
        c_puct: a number in (0, inf) controlling the relative impact of
            value Q, and prior probability P, on this node's score.
        """
        self._u = (c_puct * self._P *
                   np.sqrt(self._parent._n_visits) / (1 + self._n_visits))
        return self._Q + self._u

    def is_leaf(self):
        """Check if leaf node (i.e. no nodes below this have been expanded)."""
        return self._children == {}

    def is_root(self):
        return self._parent is None


class Gumbel_MCTS(object):
    """An implementation of Monte Carlo Tree Search."""

    def __init__(self, policy_value_fn, c_puct=5, n_playout=10000):
        """
        policy_value_fn: a function that takes in a board state and outputs
            a list of (action, probability) tuples and also a score in [-1, 1]
            (i.e. the expected value of the end game score from the current
            player's perspective) for the current player.
        c_puct: a number in (0, inf) that controls how quickly exploration
            converges to the maximum-value policy. A higher value means
            relying on the prior more.
        """
        self._root = TreeNode(None, 1.0)
        self._policy = policy_value_fn
        self._c_puct = c_puct
        self._n_playout = n_playout

    def Gumbel_playout(self, child_node, child_state):
        """Run a single playout from the child of the root to the leaf, getting a value at
        the leaf and propagating it back through its parents.
        State is modified in-place, so a copy must be provided.
        This mothod of select is a non-root selet.
        """
        node = child_node
        state = child_state

        while (1):
            if node.is_leaf():
                break
            # Greedily select next move.

            action, node = node.select(node._v)

            state.do_move(action)

        # Evaluate the leaf using a network which outputs a list of
        # (action, probability) tuples p and also a score v in [-1, 1]
        # for the current player.

        action_probs, leaf_value = self._policy(state)

        # leaf_value = leaf_value.detach().numpy()[0][0]
        leaf_value = leaf_value.detach().numpy()

        node._v = leaf_value

        # Check for end of game.
        end, winner = state.game_end()
        if not end:
            node.expand(action_probs)
        else:
            # for end state,return the "true" leaf_value
            if winner == -1:  # tie
                leaf_value = 0.0
            else:
                leaf_value = (
                    1.0 if winner == state.get_current_player() else -1.0
                )

        # Update value and visit count of nodes in this traversal.
        node.update_recursive(-leaf_value)

    def top_k(self, x, k):
        # print("x",x.shape)
        # print("k ", k)

        return np.argpartition(x, k)[..., -k:]

    def sample_k(self, logits, k):
        u = np.random.uniform(size=np.shape(logits))
        z = -np.log(-np.log(u))
        return self.top_k(logits + z, k), z

    def get_move_probs(self, state, temp=1e-3, m_action=16):
        """Run all playouts sequentially and return the available actions and
        their corresponding probabilities.
        state: the current game state
        temp: temperature parameter in (0, 1] controls the level of exploration
        """
        # 这里需要修改:1
        # logits 暂定为 p

        start_time = time.time()
        # 对根节点进行拓展

        act_probs, leaf_value = self._policy(state)

        act_probs = list(act_probs)

        # leaf_value = leaf_value.detach().numpy()[0][0]
        leaf_value = leaf_value.detach().numpy()
        # print(list(act_probs))
        porbs = [prob for act, prob in (act_probs)]

        self._root.expand(act_probs)

        n = self._n_playout
        m = min(m_action, int(len(porbs) / 2))

        # 先进行Gumbel 分布采样,不重复的采样前m个动作,对应选择公式 logits + g
        A_topm, g = self.sample_k(porbs, m)

        # 获得state选取每个action后对应的状态,保存到一个列表中
        root_childs = list(self._root._children.items())

        child_state_m = []
        for i in range(m):
            state_copy = copy.deepcopy(state)
            action, node = root_childs[A_topm[i]]
            state_copy.do_move(action)
            child_state_m.append(state_copy)
        print(porbs)

        print("depend on:", np.array(porbs)[A_topm])
        print(f"A_topm_{m}", A_topm)

        print("m ", m)

        if m > 1:
            # 每轮对选择的动作进行的仿真次数
            N = int(n / (np.log(m) * m))
        else:
            N = n

        # 进行sequential halving with Gumbel 
        while m >= 1:

            # 对每个选择的动作进行仿真
            for i in range(m):
                action_state = child_state_m[i]

                action, node = root_childs[A_topm[i]]

                for j in range(N):
                    action_state_copy = copy.deepcopy(action_state)

                    # 对选择动作进行仿真: 即找到这个子树的叶节点,然后再网络中预测v,然后往上回溯的过程
                    self.Gumbel_playout(node, action_state_copy)

            # 每轮不重复采样的动作个数减半
            m = m // 2

            # 不是最后一轮,单轮仿真次数加倍
            if (m != 1):
                n = n - N
                N *= 2
            # 当最后一轮时,只有一个动作,把所有仿真次数用完
            else:
                N = n

            # 进行新的一轮不重复采样, 采样在之前的动作前一半的动作, 对应公式 g + logits + \sigma( \hat{q} )
            # print([action_node[1]._Q for action_node in self._root._children.items()  ])

            q_hat = np.array([action_node[1]._Q for action_node in self._root._children.items()])
            assert (np.sum(q_hat[A_topm] == 0) == 0)

            print("depend on:", np.array(porbs)[A_topm] + np.array(g)[A_topm] + q_hat[A_topm])
            print(f"A_topm_{m}", A_topm)

            A_index = self.top_k(np.array(porbs)[A_topm] + np.array(g)[A_topm] + q_hat[A_topm], m)
            A_topm = np.array(A_topm)[A_index]
            child_state_m = np.array(child_state_m)[A_index]

        # 最后返回对应的决策函数, 即 pi' = softmax(logits + sigma(completed Q))

        max_N_b = np.max(np.array([act_node[1]._n_visits for act_node in self._root._children.items()]))

        final_act_probs = softmax(
            np.array([act_node[1].get_pi(leaf_value, max_N_b) for act_node in self._root._children.items()]))

        action = (np.array([act_node[0] for act_node in self._root._children.items()]))
        print("final_act_prbs", final_act_probs)
        print("move :", action)
        print("final_action", np.array(list(self._root._children.items()))[A_topm][0][0])
        print("argmax_prob", np.argmax(final_act_probs))
        need_time = time.time() - start_time
        print(f" Gumbel Alphazero sum_time: {need_time}, total_simulation: {self._n_playout}")

        return np.array(list(self._root._children.items()))[A_topm][0][0], action, final_act_probs, need_time

    def update_with_move(self, last_move):
        """Step forward in the tree, keeping everything we already know
        about the subtree.
        """
        if last_move in self._root._children:
            self._root = self._root._children[last_move]
            self._root._parent = None
        else:
            self._root = TreeNode(None, 1.0)

    def __str__(self):
        return "MCTS"


class Gumbel_MCTSPlayer(object):
    """AI player based on MCTS"""

    def __init__(self, policy_value_function,
                 c_puct=5, n_playout=2000, is_selfplay=0, m_action=16):
        self.mcts = Gumbel_MCTS(policy_value_function, c_puct, n_playout)
        self._is_selfplay = is_selfplay
        self.m_action = m_action

    def set_player_ind(self, p):
        self.player = p

    def reset_player(self):
        self.mcts.update_with_move(-1)

    def get_action(self, board, temp=1e-3, return_prob=0, return_time=False):
        sensible_moves = board.availables
        # the pi vector returned by MCTS as in the alphaGo Zero paper
        move_probs = np.zeros(board.width * board.height)

        if len(sensible_moves) > 0:

            # 在搜索树中利用sequential halving with Gumbel 来进行动作选择 并且返回对应的决策函数
            move, acts, probs, simul_mean_time = self.mcts.get_move_probs(board, temp, self.m_action)

            # 重置搜索树
            self.mcts.update_with_move(-1)

            move_probs[list(acts)] = probs
            move_probs = np.zeros(move_probs.shape[0])
            move_probs[move] = 1

            print("final prob:", move_probs)
            print("arg_max:", np.argmax(move_probs))
            print("max", np.max(move_probs))
            print("move", move)

            # 他通过训练能够使得最后move_probs 有一个位置趋近于1,即得到一个策略
            # 关键是他的策略,和MCTS得到move不一致,怀疑是分布策略计算的问题

            if return_time:

                if return_prob:

                    return move, move_probs, simul_mean_time
                else:
                    return move, simul_mean_time
            else:

                if return_prob:

                    return move, move_probs
                else:
                    return move
        else:
            print("WARNING: the board is full")

    def __str__(self):
        return "MCTS {}".format(self.player)