Spaces:
Sleeping
Sleeping
sjz
commited on
Commit
·
9e90422
1
Parent(s):
ae94556
fix player vs AI bugs
Browse files- const.py +5 -3
- pages/test.py +592 -0
const.py
CHANGED
@@ -14,6 +14,8 @@ _BOARD_SIZE_1D = _BOARD_SIZE * _BOARD_SIZE
|
|
14 |
_BLANK = 0
|
15 |
_BLACK = 1
|
16 |
_WHITE = 2
|
|
|
|
|
17 |
_NEW = 3
|
18 |
_PLAYER_SYMBOL1 = {
|
19 |
_WHITE: "⚪",
|
@@ -31,10 +33,10 @@ _PLAYER_SYMBOL2 = {
|
|
31 |
|
32 |
|
33 |
|
34 |
-
|
35 |
-
|
36 |
_BLANK: "Blank",
|
37 |
-
|
38 |
}
|
39 |
_PLAYER_COLOR_AI_VS_AI = {
|
40 |
_WHITE: "WHITE",
|
|
|
14 |
_BLANK = 0
|
15 |
_BLACK = 1
|
16 |
_WHITE = 2
|
17 |
+
_HUMAN = 4
|
18 |
+
_AI = 5
|
19 |
_NEW = 3
|
20 |
_PLAYER_SYMBOL1 = {
|
21 |
_WHITE: "⚪",
|
|
|
33 |
|
34 |
|
35 |
|
36 |
+
_PLAYER_NAME = {
|
37 |
+
_AI: "AI",
|
38 |
_BLANK: "Blank",
|
39 |
+
_HUMAN: "YOU HUMAN",
|
40 |
}
|
41 |
_PLAYER_COLOR_AI_VS_AI = {
|
42 |
_WHITE: "WHITE",
|
pages/test.py
ADDED
@@ -0,0 +1,592 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
FileName: app.py
|
3 |
+
Author: Benhao Huang
|
4 |
+
Create Date: 2023/11/19
|
5 |
+
Description: this file is used to display our project and add visualization elements to the game, using Streamlit
|
6 |
+
"""
|
7 |
+
|
8 |
+
import time
|
9 |
+
import pandas as pd
|
10 |
+
from copy import deepcopy
|
11 |
+
import numpy as np
|
12 |
+
import streamlit as st
|
13 |
+
from scipy.signal import convolve # this is used to check if any player wins
|
14 |
+
from streamlit import session_state
|
15 |
+
from streamlit_server_state import server_state, server_state_lock
|
16 |
+
from Gomoku_MCTS import MCTSpure, alphazero, Board, PolicyValueNet_old, PolicyValueNet_new, duel_PolicyValueNet, \
|
17 |
+
Gumbel_MCTSPlayer
|
18 |
+
from Gomoku_Bot import Gomoku_bot
|
19 |
+
from Gomoku_Bot import Board as Gomoku_bot_board
|
20 |
+
|
21 |
+
import matplotlib.pyplot as plt
|
22 |
+
|
23 |
+
from const import (
|
24 |
+
_BLACK, # 1
|
25 |
+
_WHITE, # 2
|
26 |
+
_HUMAN,
|
27 |
+
_AI,
|
28 |
+
_BLANK,
|
29 |
+
_PLAYER_NAME,
|
30 |
+
_PLAYER_SYMBOL1,
|
31 |
+
_PLAYER_SYMBOL2,
|
32 |
+
_ROOM_COLOR,
|
33 |
+
_VERTICAL,
|
34 |
+
_NEW,
|
35 |
+
_HORIZONTAL,
|
36 |
+
_DIAGONAL_UP_LEFT,
|
37 |
+
_DIAGONAL_UP_RIGHT,
|
38 |
+
_BOARD_SIZE,
|
39 |
+
_MODEL_PATH
|
40 |
+
)
|
41 |
+
|
42 |
+
_PLAYER_SYMBOL = [0, _PLAYER_SYMBOL1, _PLAYER_SYMBOL2]
|
43 |
+
|
44 |
+
|
45 |
+
if "FirstPlayer" not in session_state:
|
46 |
+
session_state.FirstPlayer = _HUMAN
|
47 |
+
session_state.Players = [ _BLACK,_WHITE]
|
48 |
+
session_state.Symbols = _PLAYER_SYMBOL1
|
49 |
+
|
50 |
+
# Utils
|
51 |
+
class Room:
|
52 |
+
def __init__(self, room_id) -> None:
|
53 |
+
self.ROOM_ID = room_id
|
54 |
+
self.BOARD = Board(width=_BOARD_SIZE, height=_BOARD_SIZE, n_in_row=5, players=session_state.Players)
|
55 |
+
self.TURN = _BLACK
|
56 |
+
self.CURR_PLAYER = session_state.FirstPlayer
|
57 |
+
self.HISTORY = (0, 0)
|
58 |
+
self.WINNER = _BLANK
|
59 |
+
self.TIME = time.time()
|
60 |
+
self.gomoku_bot_board = Gomoku_bot_board(_BOARD_SIZE, 1)
|
61 |
+
self.MCTS_dict = {'Pure MCTS': MCTSpure(c_puct=5, n_playout=1000),
|
62 |
+
'AlphaZero': alphazero(PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE,
|
63 |
+
_MODEL_PATH["AlphaZero"]).policy_value_fn,
|
64 |
+
c_puct=5, n_playout=100),
|
65 |
+
'duel': alphazero(duel_PolicyValueNet(_BOARD_SIZE, _BOARD_SIZE,
|
66 |
+
_MODEL_PATH["duel"]).policy_value_fn,
|
67 |
+
c_puct=5, n_playout=100),
|
68 |
+
'Gumbel AlphaZero': Gumbel_MCTSPlayer(PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE,
|
69 |
+
_MODEL_PATH["Gumbel AlphaZero"]).policy_value_fn,
|
70 |
+
c_puct=5, n_playout=100, m_action=8),
|
71 |
+
'Gomoku Bot': Gomoku_bot(self.gomoku_bot_board, -1)}
|
72 |
+
self.MCTS = self.MCTS_dict['AlphaZero']
|
73 |
+
self.last_mcts = self.MCTS
|
74 |
+
self.AID_MCTS = self.MCTS_dict['AlphaZero']
|
75 |
+
self.COORDINATE_1D = []
|
76 |
+
self.current_move = -1
|
77 |
+
self.ai_simula_time_list = []
|
78 |
+
self.human_simula_time_list = []
|
79 |
+
|
80 |
+
|
81 |
+
def change_turn(cur):
|
82 |
+
if cur in [_HUMAN, _AI]:
|
83 |
+
return _HUMAN if cur == _AI else _AI
|
84 |
+
return cur % 2 + 1
|
85 |
+
|
86 |
+
|
87 |
+
# Initialize the game
|
88 |
+
if "ROOM" not in session_state:
|
89 |
+
session_state.ROOM = Room("local")
|
90 |
+
if "OWNER" not in session_state:
|
91 |
+
session_state.OWNER = False
|
92 |
+
if "USE_AIAID" not in session_state:
|
93 |
+
session_state.USE_AIAID = False
|
94 |
+
|
95 |
+
# Check server health
|
96 |
+
if "ROOMS" not in server_state:
|
97 |
+
with server_state_lock["ROOMS"]:
|
98 |
+
server_state.ROOMS = {}
|
99 |
+
|
100 |
+
|
101 |
+
def handle_oppo_model_selection():
|
102 |
+
new_mct = session_state.ROOM.MCTS_dict[st.session_state['selected_oppo_model']]
|
103 |
+
session_state.ROOM.MCTS = new_mct
|
104 |
+
session_state.ROOM.last_mcts = new_mct
|
105 |
+
session_state.ROOM.BOARD = Board(width=_BOARD_SIZE, height=_BOARD_SIZE, n_in_row=5)
|
106 |
+
session_state.ROOM.gomoku_bot_board = Gomoku_bot_board(_BOARD_SIZE, 1)
|
107 |
+
session_state.ROOM.CURR_PLAYER = session_state.FirstPlayer
|
108 |
+
session_state.ROOM.TURN = _BLACK
|
109 |
+
session_state.ROOM.WINNER = _BLANK # 0
|
110 |
+
session_state.ROOM.ai_simula_time_list = []
|
111 |
+
session_state.ROOM.human_simula_time_list = []
|
112 |
+
session_state.ROOM.COORDINATE_1D = []
|
113 |
+
return
|
114 |
+
|
115 |
+
|
116 |
+
def handle_aid_model_selection():
|
117 |
+
if st.session_state['selected_aid_model'] == 'None':
|
118 |
+
session_state.USE_AIAID = False
|
119 |
+
return
|
120 |
+
session_state.USE_AIAID = True
|
121 |
+
TreeNode = session_state.ROOM.MCTS.mcts._root # use the same tree node
|
122 |
+
new_mct = session_state.ROOM.MCTS_dict[st.session_state['selected_aid_model']]
|
123 |
+
new_mct.mcts._root = deepcopy(TreeNode)
|
124 |
+
session_state.ROOM.AID_MCTS = new_mct
|
125 |
+
return
|
126 |
+
|
127 |
+
|
128 |
+
if 'selected_oppo_model' not in st.session_state:
|
129 |
+
st.session_state['selected_oppo_model'] = 'AlphaZero' # 默认值
|
130 |
+
|
131 |
+
if 'selected_aid_model' not in st.session_state:
|
132 |
+
st.session_state['selected_aid_model'] = 'AlphaZero' # 默认值
|
133 |
+
|
134 |
+
# Layout
|
135 |
+
TITLE = st.empty()
|
136 |
+
Model_Switch = st.empty()
|
137 |
+
|
138 |
+
TITLE.header("🤖 AI 3603 Gomoku")
|
139 |
+
selected_oppo_option = Model_Switch.selectbox('Select Opponent Model',
|
140 |
+
['Pure MCTS', 'AlphaZero', 'Gomoku Bot', 'duel', 'Gumbel AlphaZero'],
|
141 |
+
index=1, key='oppo_model')
|
142 |
+
|
143 |
+
if st.session_state['selected_oppo_model'] != selected_oppo_option:
|
144 |
+
st.session_state['selected_oppo_model'] = selected_oppo_option
|
145 |
+
handle_oppo_model_selection()
|
146 |
+
|
147 |
+
ROUND_INFO = st.empty()
|
148 |
+
st.markdown("<br>", unsafe_allow_html=True)
|
149 |
+
BOARD_PLATE = [
|
150 |
+
[cell.empty() for cell in st.columns([1 for _ in range(_BOARD_SIZE)])] for _ in range(_BOARD_SIZE)
|
151 |
+
]
|
152 |
+
LOG = st.empty()
|
153 |
+
|
154 |
+
# Sidebar
|
155 |
+
SCORE_TAG = st.sidebar.empty()
|
156 |
+
SCORE_PLATE = st.sidebar.columns(2)
|
157 |
+
# History scores
|
158 |
+
SCORE_TAG.subheader("Scores")
|
159 |
+
|
160 |
+
PLAY_MODE_INFO = st.sidebar.container()
|
161 |
+
MULTIPLAYER_TAG = st.sidebar.empty()
|
162 |
+
with st.sidebar.container():
|
163 |
+
ANOTHER_ROUND = st.empty()
|
164 |
+
RESTART = st.empty()
|
165 |
+
GIVEIN = st.empty()
|
166 |
+
CHANGE_PLAYER = st.empty()
|
167 |
+
AIAID = st.empty()
|
168 |
+
EXIT = st.empty()
|
169 |
+
selected_aid_option = AIAID.selectbox('Select Assistant Model', ['None', 'Pure MCTS', 'AlphaZero'], index=0,
|
170 |
+
key='aid_model')
|
171 |
+
if st.session_state['selected_aid_model'] != selected_aid_option:
|
172 |
+
st.session_state['selected_aid_model'] = selected_aid_option
|
173 |
+
handle_aid_model_selection()
|
174 |
+
|
175 |
+
GAME_INFO = st.sidebar.container()
|
176 |
+
message = st.empty()
|
177 |
+
PLAY_MODE_INFO.write("---\n\n**You are Black, AI agent is White.**")
|
178 |
+
GAME_INFO.markdown(
|
179 |
+
"""
|
180 |
+
---
|
181 |
+
# <span style="color:black;">Freestyle Gomoku game. 🎲</span>
|
182 |
+
- no restrictions 🚫
|
183 |
+
- no regrets 😎
|
184 |
+
- no regrets 😎
|
185 |
+
- swap players after one round is over 🔁
|
186 |
+
Powered by an AlphaZero approach with our own improvements! 🚀 For the specific details, please check out our <a href="insert_report_link_here" style="color:blue;">report</a>.
|
187 |
+
##### Adapted and improved by us! 🌟 <a href="https://github.com/Lijiaxin0111/AI_3603_BIGHOME" style="color:blue;">Our Github repo</a>
|
188 |
+
""",
|
189 |
+
unsafe_allow_html=True,
|
190 |
+
)
|
191 |
+
|
192 |
+
|
193 |
+
def restart() -> None:
|
194 |
+
"""
|
195 |
+
Restart the game.
|
196 |
+
"""
|
197 |
+
session_state.ROOM = Room(session_state.ROOM.ROOM_ID)
|
198 |
+
st.session_state['selected_oppo_model'] = 'AlphaZero'
|
199 |
+
|
200 |
+
def givein() -> None:
|
201 |
+
"""
|
202 |
+
Give in to AI.
|
203 |
+
"""
|
204 |
+
session_state.ROOM = deepcopy(session_state.ROOM)
|
205 |
+
session_state.ROOM.WINNER = _AI
|
206 |
+
# add 1 score to AI
|
207 |
+
session_state.ROOM.HISTORY = (
|
208 |
+
session_state.ROOM.HISTORY[0]
|
209 |
+
+ int(session_state.ROOM.WINNER == _AI),
|
210 |
+
session_state.ROOM.HISTORY[1]
|
211 |
+
+ int(session_state.ROOM.WINNER == _HUMAN),
|
212 |
+
)
|
213 |
+
session_state.ROOM.BOARD = Board(width=_BOARD_SIZE, height=_BOARD_SIZE, n_in_row=5)
|
214 |
+
session_state.ROOM.gomoku_bot_board = Gomoku_bot_board(_BOARD_SIZE, 1)
|
215 |
+
session_state.ROOM.MCTS_dict = {'Pure MCTS': MCTSpure(c_puct=5, n_playout=1000),
|
216 |
+
'AlphaZero': alphazero(PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE,
|
217 |
+
_MODEL_PATH["AlphaZero"]).policy_value_fn,
|
218 |
+
c_puct=5, n_playout=100),
|
219 |
+
'duel': alphazero(duel_PolicyValueNet(_BOARD_SIZE, _BOARD_SIZE,
|
220 |
+
_MODEL_PATH["duel"]).policy_value_fn,
|
221 |
+
c_puct=5, n_playout=100),
|
222 |
+
'Gumbel AlphaZero': Gumbel_MCTSPlayer(PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE,
|
223 |
+
_MODEL_PATH[
|
224 |
+
"Gumbel AlphaZero"]).policy_value_fn,
|
225 |
+
c_puct=5, n_playout=100, m_action=8),
|
226 |
+
'Gomoku Bot': Gomoku_bot(session_state.ROOM.gomoku_bot_board, -1)}
|
227 |
+
|
228 |
+
session_state.ROOM.MCTS = session_state.ROOM.MCTS_dict[st.session_state['selected_oppo_model']]
|
229 |
+
session_state.ROOM.last_mcts = session_state.ROOM.MCTS
|
230 |
+
session_state.ROOM.TURN = _BLACK
|
231 |
+
session_state.ROOM.CURR_PLAYER = session_state.FirstPlayer
|
232 |
+
session_state.ROOM.WINNER = _BLANK # 0
|
233 |
+
session_state.ROOM.ai_simula_time_list = []
|
234 |
+
session_state.ROOM.human_simula_time_list = []
|
235 |
+
session_state.ROOM.COORDINATE_1D = []
|
236 |
+
|
237 |
+
def swap_players() -> None:
|
238 |
+
session_state.update(
|
239 |
+
FirstPlayer=change_turn(session_state.FirstPlayer),
|
240 |
+
)
|
241 |
+
"""
|
242 |
+
session_state.FirstPlayer = _HUMAN
|
243 |
+
session_state.Players = [ _BLACK,_WHITE]
|
244 |
+
session_state.Symbols = _PLAYER_SYMBOL1
|
245 |
+
"""
|
246 |
+
|
247 |
+
session_state.ROOM.BOARD = Board(width=_BOARD_SIZE, height=_BOARD_SIZE, n_in_row=5, players=session_state.Players)
|
248 |
+
session_state.ROOM.gomoku_bot_board = Gomoku_bot_board(_BOARD_SIZE, 1)
|
249 |
+
session_state.ROOM.MCTS_dict = {'Pure MCTS': MCTSpure(c_puct=5, n_playout=1000),
|
250 |
+
'AlphaZero': alphazero(PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE,
|
251 |
+
_MODEL_PATH["AlphaZero"]).policy_value_fn,
|
252 |
+
c_puct=5, n_playout=100),
|
253 |
+
'duel': alphazero(duel_PolicyValueNet(_BOARD_SIZE, _BOARD_SIZE,
|
254 |
+
_MODEL_PATH["duel"]).policy_value_fn,
|
255 |
+
c_puct=5, n_playout=100),
|
256 |
+
'Gumbel AlphaZero': Gumbel_MCTSPlayer(PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE,
|
257 |
+
_MODEL_PATH[
|
258 |
+
"Gumbel AlphaZero"]).policy_value_fn,
|
259 |
+
c_puct=5, n_playout=100, m_action=8),
|
260 |
+
'Gomoku Bot': Gomoku_bot(session_state.ROOM.gomoku_bot_board, -1)}
|
261 |
+
|
262 |
+
session_state.ROOM.MCTS = session_state.ROOM.MCTS_dict[st.session_state['selected_oppo_model']]
|
263 |
+
session_state.ROOM.last_mcts = session_state.ROOM.MCTS
|
264 |
+
session_state.ROOM.TURN = _BLACK
|
265 |
+
session_state.ROOM.CURR_PLAYER = session_state.FirstPlayer
|
266 |
+
session_state.ROOM.WINNER = _BLANK # 0
|
267 |
+
session_state.ROOM.ai_simula_time_list = []
|
268 |
+
session_state.ROOM.human_simula_time_list = []
|
269 |
+
session_state.ROOM.COORDINATE_1D = []
|
270 |
+
|
271 |
+
RESTART.button(
|
272 |
+
"Reset",
|
273 |
+
on_click=restart,
|
274 |
+
help="Clear the board as well as the scores",
|
275 |
+
)
|
276 |
+
|
277 |
+
GIVEIN.button(
|
278 |
+
"Give in",
|
279 |
+
on_click = givein,
|
280 |
+
help="Give in to AI",
|
281 |
+
)
|
282 |
+
|
283 |
+
CHANGE_PLAYER.button(
|
284 |
+
"Swap players",
|
285 |
+
on_click=swap_players,
|
286 |
+
help="Swap players",
|
287 |
+
)
|
288 |
+
|
289 |
+
|
290 |
+
# Draw the board
|
291 |
+
def gomoku():
|
292 |
+
"""
|
293 |
+
Draw the board.
|
294 |
+
Handle the main logic.
|
295 |
+
"""
|
296 |
+
|
297 |
+
# Restart the game
|
298 |
+
|
299 |
+
# Continue new round
|
300 |
+
def another_round() -> None:
|
301 |
+
"""
|
302 |
+
Continue new round.
|
303 |
+
"""
|
304 |
+
session_state.ROOM = deepcopy(session_state.ROOM)
|
305 |
+
session_state.ROOM.BOARD = Board(width=_BOARD_SIZE, height=_BOARD_SIZE, n_in_row=5)
|
306 |
+
session_state.ROOM.gomoku_bot_board = Gomoku_bot_board(_BOARD_SIZE, 1)
|
307 |
+
session_state.ROOM.MCTS_dict = {'Pure MCTS': MCTSpure(c_puct=5, n_playout=1000),
|
308 |
+
'AlphaZero': alphazero(PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE,
|
309 |
+
_MODEL_PATH["AlphaZero"]).policy_value_fn,
|
310 |
+
c_puct=5, n_playout=100),
|
311 |
+
'duel': alphazero(duel_PolicyValueNet(_BOARD_SIZE, _BOARD_SIZE,
|
312 |
+
_MODEL_PATH["duel"]).policy_value_fn,
|
313 |
+
c_puct=5, n_playout=100),
|
314 |
+
'Gumbel AlphaZero': Gumbel_MCTSPlayer(PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE,
|
315 |
+
_MODEL_PATH["Gumbel AlphaZero"]).policy_value_fn,
|
316 |
+
c_puct=5, n_playout=100, m_action=8),
|
317 |
+
'Gomoku Bot': Gomoku_bot(session_state.ROOM.gomoku_bot_board, -1)}
|
318 |
+
|
319 |
+
session_state.ROOM.MCTS = session_state.ROOM.MCTS_dict[st.session_state['selected_oppo_model']]
|
320 |
+
session_state.ROOM.last_mcts = session_state.ROOM.MCTS
|
321 |
+
session_state.ROOM.TURN = _BLACK
|
322 |
+
session_state.ROOM.CURR_PLAYER = session_state.FirstPlayer
|
323 |
+
session_state.ROOM.WINNER = _BLANK # 0
|
324 |
+
session_state.ROOM.ai_simula_time_list = []
|
325 |
+
session_state.ROOM.human_simula_time_list = []
|
326 |
+
session_state.ROOM.COORDINATE_1D = []
|
327 |
+
|
328 |
+
# Room status sync
|
329 |
+
def sync_room() -> bool:
|
330 |
+
room_id = session_state.ROOM.ROOM_ID
|
331 |
+
if room_id not in server_state.ROOMS.keys():
|
332 |
+
session_state.ROOM = Room("local")
|
333 |
+
return False
|
334 |
+
elif server_state.ROOMS[room_id].TIME == session_state.ROOM.TIME:
|
335 |
+
return False
|
336 |
+
elif server_state.ROOMS[room_id].TIME < session_state.ROOM.TIME:
|
337 |
+
# Only acquire the lock when writing to the server state
|
338 |
+
with server_state_lock["ROOMS"]:
|
339 |
+
server_rooms = server_state.ROOMS
|
340 |
+
server_rooms[room_id] = session_state.ROOM
|
341 |
+
server_state.ROOMS = server_rooms
|
342 |
+
return True
|
343 |
+
else:
|
344 |
+
session_state.ROOM = server_state.ROOMS[room_id]
|
345 |
+
return True
|
346 |
+
|
347 |
+
# Triggers the board response on click
|
348 |
+
def handle_click(x, y):
|
349 |
+
"""
|
350 |
+
Controls whether to pass on / continue current board / may start new round
|
351 |
+
"""
|
352 |
+
if session_state.ROOM.BOARD.board_map[x][y] != _BLANK:
|
353 |
+
pass
|
354 |
+
elif (
|
355 |
+
session_state.ROOM.ROOM_ID in server_state.ROOMS.keys()
|
356 |
+
and _ROOM_COLOR[session_state.OWNER]
|
357 |
+
!= server_state.ROOMS[session_state.ROOM.ROOM_ID].TURN
|
358 |
+
):
|
359 |
+
sync_room()
|
360 |
+
|
361 |
+
# normal play situation
|
362 |
+
elif session_state.ROOM.WINNER == _BLANK:
|
363 |
+
move = session_state.ROOM.BOARD.location_to_move((x, y))
|
364 |
+
session_state.ROOM.current_move = move
|
365 |
+
session_state.ROOM.BOARD.do_move(move)
|
366 |
+
# Gomoku Bot BOARD
|
367 |
+
session_state.ROOM.MCTS_dict["Gomoku Bot"].board.put(_BOARD_SIZE - move // _BOARD_SIZE - 1,
|
368 |
+
move % _BOARD_SIZE) # # this move starts from left up corner (0,0), however, the move in the game starts from left bottom corner (0,0)
|
369 |
+
session_state.ROOM.BOARD.board_map[x][y] = session_state.ROOM.TURN
|
370 |
+
session_state.ROOM.COORDINATE_1D.append(x * _BOARD_SIZE + y)
|
371 |
+
|
372 |
+
session_state.ROOM.TURN = change_turn(session_state.ROOM.TURN)
|
373 |
+
session_state.ROOM.CURR_PLAYER = change_turn(session_state.ROOM.CURR_PLAYER)
|
374 |
+
win, winner = session_state.ROOM.BOARD.game_end()
|
375 |
+
if win:
|
376 |
+
session_state.ROOM.WINNER = session_state.ROOM.CURR_PLAYER
|
377 |
+
session_state.ROOM.HISTORY = (
|
378 |
+
session_state.ROOM.HISTORY[0]
|
379 |
+
+ int(session_state.ROOM.WINNER == _AI),
|
380 |
+
session_state.ROOM.HISTORY[1]
|
381 |
+
+ int(session_state.ROOM.WINNER == _HUMAN),
|
382 |
+
)
|
383 |
+
session_state.ROOM.TIME = time.time()
|
384 |
+
|
385 |
+
def forbid_click(x, y):
|
386 |
+
# st.warning('This posistion has been occupied!!!!', icon="⚠️")
|
387 |
+
st.error("({}, {}) has been occupied!!)".format(x, y), icon="🚨")
|
388 |
+
|
389 |
+
# Draw board
|
390 |
+
def draw_board(response: bool):
|
391 |
+
"""construct each buttons for all cells of the board"""
|
392 |
+
if session_state.USE_AIAID and session_state.ROOM.WINNER == _BLANK and session_state.ROOM.CURR_PLAYER == _HUMAN:
|
393 |
+
if session_state.USE_AIAID:
|
394 |
+
copy_mcts = deepcopy(session_state.ROOM.AID_MCTS.mcts)
|
395 |
+
_, acts_aid, probs_aid, simul_mean_time_aid = copy_mcts.get_move_probs(session_state.ROOM.BOARD)
|
396 |
+
sorted_acts_probs = sorted(zip(acts_aid, probs_aid), key=lambda x: x[1], reverse=True)
|
397 |
+
top_five_acts = [act for act, prob in sorted_acts_probs[:5]]
|
398 |
+
top_five_probs = [prob for act, prob in sorted_acts_probs[:5]]
|
399 |
+
if response and session_state.ROOM.CURR_PLAYER == _HUMAN: # human turn
|
400 |
+
start_time = time.time()
|
401 |
+
print("Your turn")
|
402 |
+
# construction of clickable buttons
|
403 |
+
cur_move = (session_state.ROOM.current_move // _BOARD_SIZE, session_state.ROOM.current_move % _BOARD_SIZE)
|
404 |
+
for i, row in enumerate(session_state.ROOM.BOARD.board_map):
|
405 |
+
for j, cell in enumerate(row):
|
406 |
+
if (
|
407 |
+
i * _BOARD_SIZE + j
|
408 |
+
in (session_state.ROOM.COORDINATE_1D)
|
409 |
+
):
|
410 |
+
if i == cur_move[0] and j == cur_move[1]:
|
411 |
+
BOARD_PLATE[i][j].button(
|
412 |
+
session_state.Symbols[_NEW],
|
413 |
+
key=f"{i}:{j}",
|
414 |
+
args=(i, j),
|
415 |
+
on_click=forbid_click,
|
416 |
+
)
|
417 |
+
else:
|
418 |
+
# disable click for GPT choices
|
419 |
+
BOARD_PLATE[i][j].button(
|
420 |
+
session_state.Symbols[cell],
|
421 |
+
key=f"{i}:{j}",
|
422 |
+
args=(i, j),
|
423 |
+
on_click=forbid_click
|
424 |
+
)
|
425 |
+
else:
|
426 |
+
if session_state.USE_AIAID and i * _BOARD_SIZE + j in top_five_acts:
|
427 |
+
# enable click for other cells available for human choices
|
428 |
+
prob = top_five_probs[top_five_acts.index(i * _BOARD_SIZE + j)]
|
429 |
+
BOARD_PLATE[i][j].button(
|
430 |
+
session_state.Symbols[cell] + f"{round(prob, 2)}",
|
431 |
+
key=f"{i}:{j}",
|
432 |
+
on_click=handle_click,
|
433 |
+
args=(i, j),
|
434 |
+
)
|
435 |
+
else:
|
436 |
+
# enable click for other cells available for human choices
|
437 |
+
BOARD_PLATE[i][j].button(
|
438 |
+
session_state.Symbols[cell],
|
439 |
+
key=f"{i}:{j}",
|
440 |
+
on_click=handle_click,
|
441 |
+
args=(i, j),
|
442 |
+
)
|
443 |
+
end_time = time.time()
|
444 |
+
print("Time used for human move: ", end_time - start_time)
|
445 |
+
|
446 |
+
elif response and session_state.ROOM.CURR_PLAYER == _AI: # AI turn
|
447 |
+
message.empty()
|
448 |
+
with st.spinner('🔮✨ Waiting for AI response... ⏳🚀'):
|
449 |
+
time.sleep(0.05)
|
450 |
+
print("AI's turn")
|
451 |
+
|
452 |
+
if st.session_state['selected_oppo_model'] != 'Gomoku Bot':
|
453 |
+
move, simul_time = session_state.ROOM.MCTS.get_action(session_state.ROOM.BOARD, return_time=True)
|
454 |
+
else:
|
455 |
+
move, simul_time = session_state.ROOM.MCTS.get_action(return_time=True)
|
456 |
+
session_state.ROOM.ai_simula_time_list.append(simul_time)
|
457 |
+
print("AI takes move: ", move)
|
458 |
+
session_state.ROOM.current_move = move
|
459 |
+
gpt_response = move
|
460 |
+
gpt_i, gpt_j = gpt_response // _BOARD_SIZE, gpt_response % _BOARD_SIZE
|
461 |
+
print("AI's move is located at ({}, {}) :".format(gpt_i, gpt_j))
|
462 |
+
move = session_state.ROOM.BOARD.location_to_move((gpt_i, gpt_j))
|
463 |
+
print("Location to move: ", move)
|
464 |
+
# MCTS BOARD
|
465 |
+
session_state.ROOM.BOARD.do_move(move)
|
466 |
+
# Gomoku Bot BOARD
|
467 |
+
session_state.ROOM.MCTS_dict["Gomoku Bot"].board.put(_BOARD_SIZE - 1 - move // _BOARD_SIZE,
|
468 |
+
move % _BOARD_SIZE)
|
469 |
+
session_state.ROOM.COORDINATE_1D.append(gpt_i * _BOARD_SIZE + gpt_j)
|
470 |
+
|
471 |
+
if not session_state.ROOM.BOARD.game_end()[0]:
|
472 |
+
if session_state.USE_AIAID:
|
473 |
+
copy_mcts = deepcopy(session_state.ROOM.AID_MCTS.mcts)
|
474 |
+
_, acts_aid, probs_aid, simul_mean_time_aid = copy_mcts.get_move_probs(session_state.ROOM.BOARD)
|
475 |
+
sorted_acts_probs = sorted(zip(acts_aid, probs_aid), key=lambda x: x[1], reverse=True)
|
476 |
+
top_five_acts = [act for act, prob in sorted_acts_probs[:5]]
|
477 |
+
top_five_probs = [prob for act, prob in sorted_acts_probs[:5]]
|
478 |
+
else:
|
479 |
+
top_five_acts = []
|
480 |
+
top_five_probs = []
|
481 |
+
|
482 |
+
# construction of clickable buttons
|
483 |
+
for i, row in enumerate(session_state.ROOM.BOARD.board_map):
|
484 |
+
# print("row:", row)
|
485 |
+
for j, cell in enumerate(row):
|
486 |
+
if (
|
487 |
+
i * _BOARD_SIZE + j
|
488 |
+
in (session_state.ROOM.COORDINATE_1D)
|
489 |
+
):
|
490 |
+
if i == gpt_i and j == gpt_j:
|
491 |
+
BOARD_PLATE[i][j].button(
|
492 |
+
session_state.Symbols[_NEW],
|
493 |
+
key=f"{i}:{j}",
|
494 |
+
args=(i, j),
|
495 |
+
on_click=handle_click,
|
496 |
+
)
|
497 |
+
else:
|
498 |
+
# disable click for GPT choices
|
499 |
+
BOARD_PLATE[i][j].button(
|
500 |
+
session_state.Symbols[cell],
|
501 |
+
key=f"{i}:{j}",
|
502 |
+
args=(i, j),
|
503 |
+
on_click=forbid_click
|
504 |
+
)
|
505 |
+
else:
|
506 |
+
if session_state.USE_AIAID and i * _BOARD_SIZE + j in top_five_acts and not \
|
507 |
+
session_state.ROOM.BOARD.game_end()[0]:
|
508 |
+
# enable click for other cells available for human choices
|
509 |
+
prob = top_five_probs[top_five_acts.index(i * _BOARD_SIZE + j)]
|
510 |
+
BOARD_PLATE[i][j].button(
|
511 |
+
session_state.Symbols[cell] + f"{round(prob, 2)}",
|
512 |
+
key=f"{i}:{j}",
|
513 |
+
on_click=handle_click,
|
514 |
+
args=(i, j),
|
515 |
+
)
|
516 |
+
else:
|
517 |
+
# enable click for other cells available for human choices
|
518 |
+
BOARD_PLATE[i][j].button(
|
519 |
+
session_state.Symbols[cell],
|
520 |
+
key=f"{i}:{j}",
|
521 |
+
on_click=handle_click,
|
522 |
+
args=(i, j),
|
523 |
+
)
|
524 |
+
|
525 |
+
message.markdown(
|
526 |
+
'AI agent has calculated its strategy, which takes <span style="color: blue; font-size: 20px;">{:.3e}</span>s per simulation.'.format(
|
527 |
+
simul_time),
|
528 |
+
unsafe_allow_html=True
|
529 |
+
)
|
530 |
+
LOG.subheader("Logs")
|
531 |
+
# change turn
|
532 |
+
session_state.ROOM.TURN = change_turn(session_state.ROOM.TURN)
|
533 |
+
session_state.ROOM.CURR_PLAYER = change_turn(session_state.ROOM.CURR_PLAYER)
|
534 |
+
|
535 |
+
win, winner = session_state.ROOM.BOARD.game_end()
|
536 |
+
if win:
|
537 |
+
session_state.ROOM.WINNER = session_state.ROOM.CURR_PLAYER
|
538 |
+
|
539 |
+
session_state.ROOM.HISTORY = (
|
540 |
+
session_state.ROOM.HISTORY[0]
|
541 |
+
+ int(session_state.ROOM.WINNER == _AI),
|
542 |
+
session_state.ROOM.HISTORY[1]
|
543 |
+
+ int(session_state.ROOM.WINNER == _HUMAN),
|
544 |
+
)
|
545 |
+
session_state.ROOM.TIME = time.time()
|
546 |
+
|
547 |
+
if not response or session_state.ROOM.WINNER != _BLANK:
|
548 |
+
if session_state.ROOM.WINNER != _BLANK:
|
549 |
+
print("Game over")
|
550 |
+
for i, row in enumerate(session_state.ROOM.BOARD.board_map):
|
551 |
+
for j, cell in enumerate(row):
|
552 |
+
BOARD_PLATE[i][j].write(
|
553 |
+
session_state.Symbols[cell],
|
554 |
+
# key=f"{i}:{j}",
|
555 |
+
)
|
556 |
+
|
557 |
+
# Game process control
|
558 |
+
def game_control():
|
559 |
+
if session_state.ROOM.WINNER != _BLANK:
|
560 |
+
draw_board(False)
|
561 |
+
else:
|
562 |
+
draw_board(True)
|
563 |
+
if session_state.ROOM.WINNER != _BLANK or 0 not in session_state.ROOM.BOARD.board_map:
|
564 |
+
GIVEIN.empty()
|
565 |
+
ANOTHER_ROUND.button(
|
566 |
+
"Play Next round!",
|
567 |
+
on_click=another_round,
|
568 |
+
help="Clear board and swap first player",
|
569 |
+
)
|
570 |
+
|
571 |
+
# Infos
|
572 |
+
def update_info() -> None:
|
573 |
+
# Additional information
|
574 |
+
SCORE_PLATE[0].metric("Gomoku-Agent", session_state.ROOM.HISTORY[0])
|
575 |
+
SCORE_PLATE[1].metric("You", session_state.ROOM.HISTORY[1])
|
576 |
+
if session_state.ROOM.WINNER != _BLANK:
|
577 |
+
st.balloons()
|
578 |
+
ROUND_INFO.write(
|
579 |
+
f"#### **{_PLAYER_NAME[session_state.ROOM.WINNER]} WIN!**\n**Click buttons on the left for more plays.**"
|
580 |
+
)
|
581 |
+
|
582 |
+
st.markdown("<br>", unsafe_allow_html=True)
|
583 |
+
st.markdown("<br>", unsafe_allow_html=True)
|
584 |
+
chart_data = pd.DataFrame(session_state.ROOM.ai_simula_time_list, columns=["Simulation Time"])
|
585 |
+
st.line_chart(chart_data)
|
586 |
+
|
587 |
+
game_control()
|
588 |
+
update_info()
|
589 |
+
|
590 |
+
|
591 |
+
if __name__ == "__main__":
|
592 |
+
gomoku()
|