Spaces:
Running
Running
""" | |
FileName: app.py | |
Author: Benhao Huang | |
Create Date: 2023/11/19 | |
Description: this file is used to display our project and add visualization elements to the game, using Streamlit | |
""" | |
import time | |
import pandas as pd | |
from copy import deepcopy | |
import numpy as np | |
import streamlit as st | |
from scipy.signal import convolve # this is used to check if any player wins | |
from streamlit import session_state | |
from streamlit_server_state import server_state, server_state_lock | |
from Gomoku_MCTS import MCTSpure, alphazero, Board, PolicyValueNet_old, PolicyValueNet_new, duel_PolicyValueNet, \ | |
Gumbel_MCTSPlayer | |
from Gomoku_Bot import Gomoku_bot | |
from Gomoku_Bot import Board as Gomoku_bot_board | |
import matplotlib.pyplot as plt | |
from const import ( | |
_BLACK, # 1, for human | |
_WHITE, # 2 , for AI | |
_BLANK, | |
_PLAYER_COLOR, | |
_PLAYER_SYMBOL1, | |
_PLAYER_SYMBOL2, | |
_ROOM_COLOR, | |
_VERTICAL, | |
_NEW, | |
_HORIZONTAL, | |
_DIAGONAL_UP_LEFT, | |
_DIAGONAL_UP_RIGHT, | |
_BOARD_SIZE, | |
_MODEL_PATH | |
) | |
# Symbol set for each player id; index 0 is a placeholder so ids index directly.
_PLAYER_SYMBOL = [0, _PLAYER_SYMBOL1, _PLAYER_SYMBOL2]

# Per-session player order, set once per browser session.
# By default Black (the human) moves first.
if "FirstPlayer" not in session_state:
    _first = _BLACK
    session_state.FirstPlayer = _first
    # Move order handed to Board(players=...); index 0 is unused.
    session_state.Player = ([], [_BLACK, _WHITE], [_WHITE, _BLACK])[_first]
    session_state.Symbol = _PLAYER_SYMBOL[_first]
# Utils
class Room:
    """
    All state of one game room.

    Holds the board, the turn/winner bookkeeping, the score history and every
    AI player that can act as opponent or assistant.
    """

    def __init__(self, room_id) -> None:
        self.ROOM_ID = room_id
        # The board's move order follows the session's current first player.
        self.BOARD = Board(width=_BOARD_SIZE, height=_BOARD_SIZE, n_in_row=5, players=session_state.Player)
        self.PLAYER = session_state.FirstPlayer
        self.TURN = self.PLAYER
        # (AI score, human score) as shown by the sidebar metrics.
        self.HISTORY = (0, 0)
        self.WINNER = _BLANK
        self.TIME = time.time()
        # The Gomoku Bot keeps its own board, so it must exist before the bot is built.
        self.gomoku_bot_board = Gomoku_bot_board(_BOARD_SIZE, 1)
        models = {
            'Pure MCTS': MCTSpure(c_puct=5, n_playout=1000),
            'AlphaZero': alphazero(
                PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE, _MODEL_PATH["AlphaZero"]).policy_value_fn,
                c_puct=5, n_playout=100,
            ),
            'duel': alphazero(
                duel_PolicyValueNet(_BOARD_SIZE, _BOARD_SIZE, _MODEL_PATH["duel"]).policy_value_fn,
                c_puct=5, n_playout=100,
            ),
            'Gumbel AlphaZero': Gumbel_MCTSPlayer(
                PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE, _MODEL_PATH["Gumbel AlphaZero"]).policy_value_fn,
                c_puct=5, n_playout=100, m_action=8,
            ),
            'Gomoku Bot': Gomoku_bot(self.gomoku_bot_board, -1),
        }
        self.MCTS_dict = models
        # Opponent, last MCTS-based opponent and assistant all start as AlphaZero.
        self.MCTS = self.last_mcts = self.AID_MCTS = models['AlphaZero']
        # 1-D indices (row * _BOARD_SIZE + col) of every stone played so far.
        self.COORDINATE_1D = []
        # Move index of the newest stone; -1 means none yet.
        self.current_move = -1
        self.ai_simula_time_list = []
        self.human_simula_time_list = []
def change_turn(cur):
    """Return the id of the other player (1 -> 2, 2 -> 1)."""
    return 1 + cur % 2
# Initialize the game (per-session defaults)
if "ROOM" not in session_state:
    session_state.ROOM = Room("local")
if "OWNER" not in session_state:
    session_state.OWNER = False
if "USE_AIAID" not in session_state:
    session_state.USE_AIAID = False
# Check server health: make sure the shared room registry exists.
if "ROOMS" not in server_state:
    with server_state_lock["ROOMS"]:
        # Re-check under the lock. Another session may have created the
        # registry (and added rooms) between the unlocked check and acquiring
        # the lock. Re-initialising it here would wipe those rooms.
        if "ROOMS" not in server_state:
            server_state.ROOMS = {}
def handle_oppo_model_selection():
    """
    Apply the opponent model chosen in the main-page selector.

    'Gomoku Bot' is swapped in directly and leaves ``last_mcts`` untouched.
    Any other model becomes both ``MCTS`` and ``last_mcts``. Its search-tree
    root is seeded with a copy of the previous MCTS player's root when one is
    available.
    """
    room = session_state.ROOM
    choice = st.session_state['selected_oppo_model']
    if choice == 'Gomoku Bot':
        room.MCTS = room.MCTS_dict['Gomoku Bot']
        return
    new_mct = room.MCTS_dict[choice]
    # last_mcts can be the Gomoku Bot (givein/swap/new round assign it
    # whatever model is selected). The bot is never used as an MCTS object in
    # this file and may not expose `.mcts` (TODO confirm). Only carry the tree
    # over when there is one.
    previous_search = getattr(room.last_mcts, "mcts", None)
    if previous_search is not None:
        new_mct.mcts._root = deepcopy(previous_search._root)
    room.MCTS = new_mct
    room.last_mcts = new_mct
def handle_aid_model_selection():
    """
    Apply the assistant (hint) model chosen in the sidebar.

    'None' turns hints off. Any other choice turns them on and installs that
    player as ``ROOM.AID_MCTS``. Its search-tree root is seeded with a copy of
    the opponent's root so both start from the same tree.
    """
    choice = st.session_state['selected_aid_model']
    if choice == 'None':
        session_state.USE_AIAID = False
        return
    session_state.USE_AIAID = True
    room = session_state.ROOM
    new_mct = room.MCTS_dict[choice]
    # When the opponent is the Gomoku Bot, ROOM.MCTS may have no `.mcts`
    # search tree (TODO confirm). Fall back to the last MCTS-based player, and
    # keep new_mct's own root if neither has a tree.
    source_search = getattr(room.MCTS, "mcts", None)
    if source_search is None:
        source_search = getattr(room.last_mcts, "mcts", None)
    if source_search is not None:
        new_mct.mcts._root = deepcopy(source_search._root)  # use the same tree node
    room.AID_MCTS = new_mct
# The last opponent / assistant model actually applied. It is compared with
# the selector widgets on every rerun to detect changes. Both default to AlphaZero.
for _model_key in ('selected_oppo_model', 'selected_aid_model'):
    if _model_key not in st.session_state:
        st.session_state[_model_key] = 'AlphaZero'  # default value
# Layout (top to bottom): title, opponent selector, round banner, board grid, log area
TITLE = st.empty()
Model_Switch = st.empty()
TITLE.header("🤖 AI 3603 Gomoku")
selected_oppo_option = Model_Switch.selectbox(
    'Select Opponent Model',
    ['Pure MCTS', 'AlphaZero', 'Gomoku Bot', 'duel', 'Gumbel AlphaZero'],
    index=1,
    key='oppo_model',
)
# Only re-apply the opponent model when the widget differs from what was last applied.
if selected_oppo_option != st.session_state['selected_oppo_model']:
    st.session_state['selected_oppo_model'] = selected_oppo_option
    handle_oppo_model_selection()
ROUND_INFO = st.empty()
st.markdown("<br>", unsafe_allow_html=True)
# One placeholder per board cell: BOARD_PLATE[row][col].
BOARD_PLATE = []
for _ in range(_BOARD_SIZE):
    BOARD_PLATE.append([column.empty() for column in st.columns([1] * _BOARD_SIZE)])
LOG = st.empty()
# Sidebar: score header with two metric columns (history scores)
SCORE_TAG = st.sidebar.empty()
SCORE_PLATE = st.sidebar.columns(2)
SCORE_TAG.subheader("Scores")
PLAY_MODE_INFO = st.sidebar.container()
MULTIPLAYER_TAG = st.sidebar.empty()
with st.sidebar.container():
    # One placeholder per control, in display order.
    ANOTHER_ROUND, RESTART, GIVEIN, CHANGE_PLAYER, AIAID, EXIT = [st.empty() for _ in range(6)]
selected_aid_option = AIAID.selectbox(
    'Select Assistant Model',
    ['None', 'Pure MCTS', 'AlphaZero'],
    index=0,
    key='aid_model',
)
# Only re-apply the assistant model when the widget differs from what was last applied.
if selected_aid_option != st.session_state['selected_aid_model']:
    st.session_state['selected_aid_model'] = selected_aid_option
    handle_aid_model_selection()
# Static description shown at the bottom of the sidebar.
_GAME_INFO_MARKDOWN = """
---
# <span style="color:black;">Freestyle Gomoku game. 🎲</span>
- FixedModel means you are not allowed to change model during a game
- LeaderBoard is still in development
- no restrictions 🚫
- no regrets 😎
Powered by an AlphaZero approach with our own improvements! 🚀 For the specific details, please check out our <a href="insert_report_link_here" style="color:blue;">report</a>.
##### Adapted and improved by us! 🌟 <a href="https://github.com/Lijiaxin0111/AI_3603_BIGHOME" style="color:blue;">Our Github repo</a>
"""

GAME_INFO = st.sidebar.container()
# Placeholder for the AI's per-move timing message.
message = st.empty()
PLAY_MODE_INFO.write("---\n\n**You are Black, AI agent is White.**")
GAME_INFO.markdown(_GAME_INFO_MARKDOWN, unsafe_allow_html=True)
def restart() -> None:
    """
    Restart the game: replace the room with a brand-new one (board and scores cleared).

    A new Room starts with AlphaZero as both opponent and assistant. Both
    "last applied" model keys are reset to match it. On the next rerun the
    selectors are compared against these keys again, so any different choice
    the user has selected is re-applied to the new Room.
    """
    session_state.ROOM = Room(session_state.ROOM.ROOM_ID)
    st.session_state['selected_oppo_model'] = 'AlphaZero'
    # Without this, a non-AlphaZero assistant choice would be silently replaced
    # by the new Room's AlphaZero AID_MCTS, because the key still matched the widget.
    st.session_state['selected_aid_model'] = 'AlphaZero'
def givein() -> None:
    """
    Give in to AI.

    Credits the AI with one win, then starts a fresh board with freshly built
    AI players. Score history and player order are kept.
    """
    # Work on a copy: sync_room() may store this same Room object in server_state.ROOMS.
    room = deepcopy(session_state.ROOM)
    session_state.ROOM = room
    # Conceding is a win for the AI (White); HISTORY is (AI score, human score).
    room.HISTORY = (room.HISTORY[0] + 1, room.HISTORY[1])
    # Pass the session's player order like Room.__init__ and swap_players do.
    # Without it the Board uses its own default order (not visible here),
    # which disagrees with room.TURN after "Swap players".
    room.BOARD = Board(width=_BOARD_SIZE, height=_BOARD_SIZE, n_in_row=5, players=session_state.Player)
    room.gomoku_bot_board = Gomoku_bot_board(_BOARD_SIZE, 1)
    room.MCTS_dict = {
        'Pure MCTS': MCTSpure(c_puct=5, n_playout=1000),
        'AlphaZero': alphazero(
            PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE, _MODEL_PATH["AlphaZero"]).policy_value_fn,
            c_puct=5, n_playout=100,
        ),
        'duel': alphazero(
            duel_PolicyValueNet(_BOARD_SIZE, _BOARD_SIZE, _MODEL_PATH["duel"]).policy_value_fn,
            c_puct=5, n_playout=100,
        ),
        'Gumbel AlphaZero': Gumbel_MCTSPlayer(
            PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE, _MODEL_PATH["Gumbel AlphaZero"]).policy_value_fn,
            c_puct=5, n_playout=100, m_action=8,
        ),
        'Gomoku Bot': Gomoku_bot(room.gomoku_bot_board, -1),
    }
    room.MCTS = room.MCTS_dict[st.session_state['selected_oppo_model']]
    room.last_mcts = room.MCTS
    room.TURN = room.PLAYER
    room.WINNER = _BLANK  # 0
    room.ai_simula_time_list = []
    room.human_simula_time_list = []
    room.COORDINATE_1D = []
    # Nothing has been played on the new board yet.
    room.current_move = -1
def swap_players() -> None:
    """
    Swap which side moves first and start a fresh board.

    Updates the session's FirstPlayer / Player order / Symbol set, then
    rebuilds the board and all AI players. The score history is kept.
    """
    first = change_turn(session_state.FirstPlayer)
    session_state.update(
        FirstPlayer=first,
        Player=[[], [_BLACK, _WHITE], [_WHITE, _BLACK]][first],
        Symbol=_PLAYER_SYMBOL[first],
    )
    room = session_state.ROOM
    room.BOARD = Board(width=_BOARD_SIZE, height=_BOARD_SIZE, n_in_row=5, players=session_state.Player)
    room.PLAYER = first
    room.gomoku_bot_board = Gomoku_bot_board(_BOARD_SIZE, 1)
    room.MCTS_dict = {
        'Pure MCTS': MCTSpure(c_puct=5, n_playout=1000),
        'AlphaZero': alphazero(
            PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE, _MODEL_PATH["AlphaZero"]).policy_value_fn,
            c_puct=5, n_playout=100,
        ),
        'duel': alphazero(
            duel_PolicyValueNet(_BOARD_SIZE, _BOARD_SIZE, _MODEL_PATH["duel"]).policy_value_fn,
            c_puct=5, n_playout=100,
        ),
        'Gumbel AlphaZero': Gumbel_MCTSPlayer(
            PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE, _MODEL_PATH["Gumbel AlphaZero"]).policy_value_fn,
            c_puct=5, n_playout=100, m_action=8,
        ),
        'Gomoku Bot': Gomoku_bot(room.gomoku_bot_board, -1),
    }
    room.MCTS = room.MCTS_dict[st.session_state['selected_oppo_model']]
    room.last_mcts = room.MCTS
    # The new first player is also the side to move.
    room.TURN = room.PLAYER
    room.WINNER = _BLANK  # 0
    room.ai_simula_time_list = []
    room.human_simula_time_list = []
    room.COORDINATE_1D = []
# Sidebar control buttons: (placeholder, label, callback, tooltip).
for _slot, _label, _callback, _help in (
    (RESTART, "Reset", restart, "Clear the board as well as the scores"),
    (GIVEIN, "Give in", givein, "Give in to AI"),
    (CHANGE_PLAYER, "Swap players", swap_players, "Swap players"),
):
    _slot.button(_label, on_click=_callback, help=_help)
# Draw the board
def gomoku():
    """
    Draw the board and handle the main game logic for one Streamlit rerun.

    The round/turn callbacks and rendering helpers are defined as closures
    here. The function ends by rendering the board (``game_control``) and the
    scores plus simulation-time chart (``update_info``).
    """

    # Continue new round
    def another_round() -> None:
        """
        Continue with a new round: fresh board and AI players, scores kept.
        """
        # Work on a copy: sync_room() may store this same Room object in server_state.ROOMS.
        room = deepcopy(session_state.ROOM)
        session_state.ROOM = room
        # Pass the session's player order like Room.__init__ and swap_players
        # do. Without it the Board uses its own default order (not visible
        # here), which disagrees with room.TURN after "Swap players".
        room.BOARD = Board(width=_BOARD_SIZE, height=_BOARD_SIZE, n_in_row=5, players=session_state.Player)
        room.gomoku_bot_board = Gomoku_bot_board(_BOARD_SIZE, 1)
        room.MCTS_dict = {
            'Pure MCTS': MCTSpure(c_puct=5, n_playout=1000),
            'AlphaZero': alphazero(
                PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE, _MODEL_PATH["AlphaZero"]).policy_value_fn,
                c_puct=5, n_playout=100,
            ),
            'duel': alphazero(
                duel_PolicyValueNet(_BOARD_SIZE, _BOARD_SIZE, _MODEL_PATH["duel"]).policy_value_fn,
                c_puct=5, n_playout=100,
            ),
            'Gumbel AlphaZero': Gumbel_MCTSPlayer(
                PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE, _MODEL_PATH["Gumbel AlphaZero"]).policy_value_fn,
                c_puct=5, n_playout=100, m_action=8,
            ),
            'Gomoku Bot': Gomoku_bot(room.gomoku_bot_board, -1),
        }
        room.MCTS = room.MCTS_dict[st.session_state['selected_oppo_model']]
        room.last_mcts = room.MCTS
        room.TURN = room.PLAYER
        room.WINNER = _BLANK  # 0
        room.ai_simula_time_list = []
        room.human_simula_time_list = []
        room.COORDINATE_1D = []
        # Nothing has been played on the new board yet.
        room.current_move = -1

    # Room status sync
    def sync_room() -> bool:
        """
        Reconcile the session's room with the shared copy in server_state.

        A room id unknown to the server drops the session back to a fresh
        local room. Otherwise the copy with the newer TIME wins. Returns True
        when either side was updated.
        """
        room_id = session_state.ROOM.ROOM_ID
        if room_id not in server_state.ROOMS:
            session_state.ROOM = Room("local")
            return False
        server_time = server_state.ROOMS[room_id].TIME
        if server_time == session_state.ROOM.TIME:
            return False
        if server_time < session_state.ROOM.TIME:
            # Only acquire the lock when writing to the server state
            with server_state_lock["ROOMS"]:
                server_rooms = server_state.ROOMS
                server_rooms[room_id] = session_state.ROOM
                server_state.ROOMS = server_rooms
            return True
        session_state.ROOM = server_state.ROOMS[room_id]
        return True

    # Check if winner emerge from move
    def check_win() -> int:
        """
        Use convolution to check if any player has five in a row.

        Not called at present; the round result comes from
        ``Board.game_end()``. Each colour is convolved on its own 0/1 stone
        mask, so stones of the other colour cannot add up to a false "five".
        The original sum over raw ids mixed 1s and 2s, and its white test
        compared a minimum to 10. The original comparison against
        ``5 * _BLACK`` implies the four kernels are five-cell 0/1 line masks.
        """
        board = np.asarray(session_state.ROOM.BOARD.board_map)
        kernels = (_VERTICAL, _HORIZONTAL, _DIAGONAL_UP_LEFT, _DIAGONAL_UP_RIGHT)
        for player in (_BLACK, _WHITE):
            stones = (board == player).astype(int)
            if any(np.max(convolve(stones, kernel, mode="same")) == 5 for kernel in kernels):
                return player
        return _BLANK

    def record_result(room) -> None:
        """If room.BOARD reports the game over, store the winner and update HISTORY (AI, human)."""
        win, winner = room.BOARD.game_end()
        if win:
            room.WINNER = winner
            room.HISTORY = (
                room.HISTORY[0] + int(winner == _WHITE),
                room.HISTORY[1] + int(winner == _BLACK),
            )

    # Triggers the board response on click
    def handle_click(x, y):
        """
        Button callback for cell (x, y): play the human move when legal.

        Occupied cells are ignored. In a shared room where it is not this
        session's turn, the room is synced instead. Nothing happens once the
        round has a winner.
        """
        room = session_state.ROOM
        if room.BOARD.board_map[x][y] != _BLANK:
            return
        if (
            room.ROOM_ID in server_state.ROOMS
            and _ROOM_COLOR[session_state.OWNER] != server_state.ROOMS[room.ROOM_ID].TURN
        ):
            sync_room()
            return
        if room.WINNER != _BLANK:
            return
        # normal play situation
        move = room.BOARD.location_to_move((x, y))
        room.current_move = move
        room.BOARD.do_move(move)
        # Gomoku Bot board: its (0, 0) is the top-left corner while the game
        # board's (0, 0) is bottom-left, so the row index is mirrored.
        room.MCTS_dict["Gomoku Bot"].board.put(_BOARD_SIZE - move // _BOARD_SIZE - 1, move % _BOARD_SIZE)
        room.BOARD.board_map[x][y] = room.TURN
        room.COORDINATE_1D.append(x * _BOARD_SIZE + y)
        room.TURN = change_turn(room.TURN)
        record_result(room)
        room.TIME = time.time()

    def forbid_click(x, y):
        """Button callback for an occupied cell: show an error instead of playing."""
        st.error("({}, {}) has been occupied!!".format(x, y), icon="🚨")

    def aid_hints(k: int = 5) -> dict:
        """
        Return {move index: probability} for the assistant model's top-k moves.

        The assistant's search object is deep-copied so computing hints does
        not alter ROOM.AID_MCTS itself.
        """
        search = deepcopy(session_state.ROOM.AID_MCTS.mcts)
        _, acts, probs, _ = search.get_move_probs(session_state.ROOM.BOARD)
        ranked = sorted(zip(acts, probs), key=lambda pair: pair[1], reverse=True)
        return dict(ranked[:k])

    def render_clickable_board(last_move, hints: dict) -> None:
        """
        Render every cell as a button.

        Played cells route to forbid_click, and the newest stone (last_move)
        uses the _NEW symbol. Empty cells route to handle_click and show a
        hint probability after the symbol when ``hints`` has an entry for them.
        """
        last_row, last_col = divmod(last_move, _BOARD_SIZE)
        played = set(session_state.ROOM.COORDINATE_1D)
        for i, row in enumerate(session_state.ROOM.BOARD.board_map):
            for j, cell in enumerate(row):
                idx = i * _BOARD_SIZE + j
                slot = BOARD_PLATE[i][j]
                key = f"{i}:{j}"
                if idx in played:
                    is_newest = i == last_row and j == last_col
                    symbol = session_state.Symbol[_NEW] if is_newest else session_state.Symbol[cell]
                    slot.button(symbol, key=key, args=(i, j), on_click=forbid_click)
                elif idx in hints:
                    slot.button(
                        session_state.Symbol[cell] + f"{round(hints[idx], 2)}",
                        key=key,
                        on_click=handle_click,
                        args=(i, j),
                    )
                else:
                    slot.button(session_state.Symbol[cell], key=key, on_click=handle_click, args=(i, j))

    def play_ai_move(room):
        """Ask the opponent for a move, apply it to both boards, and return its reported per-simulation time."""
        print("AI's turn")
        if st.session_state['selected_oppo_model'] != 'Gomoku Bot':
            move, simul_time = room.MCTS.get_action(room.BOARD, return_time=True)
        else:
            # The Gomoku Bot is not given the board; it is fed moves on its own board instead.
            move, simul_time = room.MCTS.get_action(return_time=True)
        room.ai_simula_time_list.append(simul_time)
        print("AI takes move: ", move)
        room.current_move = move
        ai_row, ai_col = divmod(move, _BOARD_SIZE)
        print("AI's move is located at ({}, {}) :".format(ai_row, ai_col))
        move = room.BOARD.location_to_move((ai_row, ai_col))
        print("Location to move: ", move)
        room.BOARD.do_move(move)
        # Gomoku Bot board uses a mirrored row index (see handle_click).
        room.MCTS_dict["Gomoku Bot"].board.put(_BOARD_SIZE - 1 - move // _BOARD_SIZE, move % _BOARD_SIZE)
        room.COORDINATE_1D.append(ai_row * _BOARD_SIZE + ai_col)
        return simul_time

    # Draw board
    def draw_board(response: bool):
        """
        Construct the board widgets for this rerun.

        With ``response`` True, a human turn (Black) renders clickable cells,
        while an AI turn (White) first plays the AI move and then renders.
        With ``response`` False, or once there is a winner, the board is
        rendered as static text.
        """
        room = session_state.ROOM
        if response and room.TURN == _BLACK:  # human turn
            print("Your turn")
            hints = aid_hints() if session_state.USE_AIAID and room.WINNER == _BLANK else {}
            start_time = time.time()
            render_clickable_board(room.current_move, hints)
            print("Time used for human move: ", time.time() - start_time)
        elif response and room.TURN == _WHITE:  # AI turn
            message.empty()
            with st.spinner('🔮✨ Waiting for AI response... ⏳🚀'):
                time.sleep(0.1)
                simul_time = play_ai_move(room)
            # Computed once here rather than once per empty cell while rendering.
            game_over = room.BOARD.game_end()[0]
            hints = aid_hints() if session_state.USE_AIAID and not game_over else {}
            render_clickable_board(room.current_move, hints)
            message.markdown(
                'AI agent has calculated its strategy, which takes <span style="color: blue; font-size: 20px;">{:.3e}</span>s per simulation.'.format(
                    simul_time),
                unsafe_allow_html=True
            )
            LOG.subheader("Logs")
            # change turn
            room.TURN = change_turn(room.TURN)
            record_result(room)
            room.TIME = time.time()
        if not response or room.WINNER != _BLANK:
            if room.WINNER != _BLANK:
                print("Game over")
            # Overwrite the cell placeholders with plain, non-clickable symbols.
            for i, row in enumerate(room.BOARD.board_map):
                for j, cell in enumerate(row):
                    BOARD_PLATE[i][j].write(session_state.Symbol[cell])

    # Game process control
    def game_control():
        """Draw the board, and offer the next-round button when the round is over or the board is full."""
        draw_board(session_state.ROOM.WINNER == _BLANK)
        if session_state.ROOM.WINNER != _BLANK or 0 not in session_state.ROOM.BOARD.board_map:
            GIVEIN.empty()
            # NOTE(review): the tooltip says the first player is swapped, but
            # another_round() keeps the current order. Confirm which is intended.
            ANOTHER_ROUND.button(
                "Play Next round!",
                on_click=another_round,
                help="Clear board and swap first player",
            )

    # Infos
    def update_info() -> None:
        """Show the scores, the round result, and the AI per-simulation time chart."""
        SCORE_PLATE[0].metric("Gomoku-Agent", session_state.ROOM.HISTORY[0])
        SCORE_PLATE[1].metric("Human", session_state.ROOM.HISTORY[1])
        winner = session_state.ROOM.WINNER
        if winner in (_BLACK, _WHITE):
            st.balloons()
            ROUND_INFO.write(
                f"#### **{_PLAYER_COLOR[winner]} WIN!**\n**Click buttons on the left for more plays.**"
            )
        elif winner != _BLANK:
            # A finished game whose winner is neither colour: presumably
            # game_end() reporting a draw (e.g. -1). TODO confirm. Indexing
            # _PLAYER_COLOR with it would name the wrong winner.
            ROUND_INFO.write("#### **Tie**\n**Click buttons on the left for more plays.**")
        st.markdown("<br>", unsafe_allow_html=True)
        st.markdown("<br>", unsafe_allow_html=True)
        # Line chart of the AI's per-simulation time for each of its moves.
        chart_data = pd.DataFrame(session_state.ROOM.ai_simula_time_list, columns=["Simulation Time"])
        st.line_chart(chart_data)

    game_control()
    update_info()
# Script entry point: render the game.
if "__main__" == __name__:
    gomoku()