Spaces:
Sleeping
Sleeping
sjz
committed on
Commit
•
2f21cdd
1
Parent(s):
70d6bef
fix AI Aid error when taking turn
Browse files
- .vscode/settings.json +4 -0
- Gomoku_MCTS/__init__.py +1 -0
- Gomoku_MCTS/__pycache__/__init__.cpython-38.pyc +0 -0
- Gomoku_MCTS/__pycache__/dueling_net.cpython-38.pyc +0 -0
- Gomoku_MCTS/__pycache__/mcts_alphaZero.cpython-38.pyc +0 -0
- Gomoku_MCTS/__pycache__/mcts_pure.cpython-38.pyc +0 -0
- Gomoku_MCTS/__pycache__/policy_value_net_pytorch.cpython-38.pyc +0 -0
- Gomoku_MCTS/dueling_net.py +2 -1
- Gomoku_MCTS/policy_value_net_pytorch.py +2 -1
- __pycache__/const.cpython-38.pyc +0 -0
- pages/Player_VS_AI.py +12 -5
.vscode/settings.json
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"editor.suggest.snippetsPreventQuickSuggestions": false,
|
3 |
+
"aiXcoder.showTrayIcon": true
|
4 |
+
}
|
Gomoku_MCTS/__init__.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
from .mcts_pure import MCTSPlayer as MCTSpure
|
2 |
from .mcts_alphaZero import MCTSPlayer as alphazero
|
3 |
from .dueling_net import PolicyValueNet
|
|
|
4 |
import numpy as np
|
5 |
|
6 |
|
|
|
1 |
from .mcts_pure import MCTSPlayer as MCTSpure
|
2 |
from .mcts_alphaZero import MCTSPlayer as alphazero
|
3 |
from .dueling_net import PolicyValueNet
|
4 |
+
# from .policy_value_net_pytorch import PolicyValueNet
|
5 |
import numpy as np
|
6 |
|
7 |
|
Gomoku_MCTS/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (5.51 kB). View file
|
|
Gomoku_MCTS/__pycache__/dueling_net.cpython-38.pyc
ADDED
Binary file (4.72 kB). View file
|
|
Gomoku_MCTS/__pycache__/mcts_alphaZero.cpython-38.pyc
ADDED
Binary file (8.09 kB). View file
|
|
Gomoku_MCTS/__pycache__/mcts_pure.cpython-38.pyc
ADDED
Binary file (8.74 kB). View file
|
|
Gomoku_MCTS/__pycache__/policy_value_net_pytorch.cpython-38.pyc
ADDED
Binary file (4.12 kB). View file
|
|
Gomoku_MCTS/dueling_net.py
CHANGED
@@ -52,7 +52,7 @@ class DuelingDQNNet(nn.Module):
|
|
52 |
return F.log_softmax(q_values, dim=1), val
|
53 |
|
54 |
class PolicyValueNet():
|
55 |
-
"""policy-value network """
|
56 |
def __init__(self, board_width, board_height,
|
57 |
model_file=None, use_gpu=False):
|
58 |
self.use_gpu = use_gpu
|
@@ -70,6 +70,7 @@ class PolicyValueNet():
|
|
70 |
if model_file:
|
71 |
net_params = torch.load(model_file)
|
72 |
self.policy_value_net.load_state_dict(net_params, strict=False)
|
|
|
73 |
|
74 |
def policy_value(self, state_batch):
|
75 |
"""
|
|
|
52 |
return F.log_softmax(q_values, dim=1), val
|
53 |
|
54 |
class PolicyValueNet():
|
55 |
+
"""dueling policy-value network """
|
56 |
def __init__(self, board_width, board_height,
|
57 |
model_file=None, use_gpu=False):
|
58 |
self.use_gpu = use_gpu
|
|
|
70 |
if model_file:
|
71 |
net_params = torch.load(model_file)
|
72 |
self.policy_value_net.load_state_dict(net_params, strict=False)
|
73 |
+
print('loaded dueling model file')
|
74 |
|
75 |
def policy_value(self, state_batch):
|
76 |
"""
|
Gomoku_MCTS/policy_value_net_pytorch.py
CHANGED
@@ -55,7 +55,7 @@ class Net(nn.Module):
|
|
55 |
|
56 |
|
57 |
class PolicyValueNet():
|
58 |
-
"""policy-value network """
|
59 |
def __init__(self, board_width, board_height,
|
60 |
model_file=None, use_gpu=False):
|
61 |
self.use_gpu = use_gpu
|
@@ -71,6 +71,7 @@ class PolicyValueNet():
|
|
71 |
if model_file:
|
72 |
net_params = torch.load(model_file)
|
73 |
self.policy_value_net.load_state_dict(net_params)
|
|
|
74 |
|
75 |
def policy_value(self, state_batch):
|
76 |
"""
|
|
|
55 |
|
56 |
|
57 |
class PolicyValueNet():
|
58 |
+
"""alphazero policy-value network """
|
59 |
def __init__(self, board_width, board_height,
|
60 |
model_file=None, use_gpu=False):
|
61 |
self.use_gpu = use_gpu
|
|
|
71 |
if model_file:
|
72 |
net_params = torch.load(model_file)
|
73 |
self.policy_value_net.load_state_dict(net_params)
|
74 |
+
print('loaded model file')
|
75 |
|
76 |
def policy_value(self, state_batch):
|
77 |
"""
|
__pycache__/const.cpython-38.pyc
ADDED
Binary file (959 Bytes). View file
|
|
pages/Player_VS_AI.py
CHANGED
@@ -46,7 +46,7 @@ class Room:
|
|
46 |
self.WINNER = _BLANK
|
47 |
self.TIME = time.time()
|
48 |
self.MCTS = MCTSpure(c_puct=5, n_playout=10)
|
49 |
-
self.MCTS = alphazero(PolicyValueNet(_BOARD_SIZE, _BOARD_SIZE).policy_value_fn, c_puct=5, n_playout=
|
50 |
self.COORDINATE_1D = [_BOARD_SIZE_1D + 1]
|
51 |
self.current_move = -1
|
52 |
self.simula_time_list = []
|
@@ -242,7 +242,7 @@ def gomoku():
|
|
242 |
# normal play situation
|
243 |
elif session_state.ROOM.WINNER == _BLANK:
|
244 |
# session_state.ROOM = deepcopy(session_state.ROOM)
|
245 |
-
print("View of human player: ", session_state.ROOM.BOARD.board_map)
|
246 |
move = session_state.ROOM.BOARD.location_to_move((x, y))
|
247 |
session_state.ROOM.current_move = move
|
248 |
session_state.ROOM.BOARD.do_move(move)
|
@@ -269,7 +269,7 @@ def gomoku():
|
|
269 |
# Draw board
|
270 |
def draw_board(response: bool):
|
271 |
"""construct each buttons for all cells of the board"""
|
272 |
-
if session_state.USE_AIAID and session_state.ROOM.WINNER == _BLANK:
|
273 |
copy_mcts = deepcopy(session_state.ROOM.MCTS.mcts)
|
274 |
_, acts, probs, simul_mean_time = copy_mcts.get_move_probs(session_state.ROOM.BOARD)
|
275 |
sorted_acts_probs = sorted(zip(acts, probs), key=lambda x: x[1], reverse=True)
|
@@ -318,7 +318,7 @@ def gomoku():
|
|
318 |
time.sleep(0.1)
|
319 |
print("AI's turn")
|
320 |
print("Below are current board under AI's view")
|
321 |
-
print(session_state.ROOM.BOARD.board_map)
|
322 |
move, simul_time = session_state.ROOM.MCTS.get_action(session_state.ROOM.BOARD, return_time=True)
|
323 |
session_state.ROOM.simula_time_list.append(simul_time)
|
324 |
print("AI takes move: ", move)
|
@@ -332,6 +332,12 @@ def gomoku():
|
|
332 |
# session_state.ROOM.BOARD[gpt_i][gpt_j] = session_state.ROOM.TURN
|
333 |
session_state.ROOM.COORDINATE_1D.append(gpt_i * _BOARD_SIZE + gpt_j)
|
334 |
|
|
|
|
|
|
|
|
|
|
|
|
|
335 |
# construction of clickable buttons
|
336 |
for i, row in enumerate(session_state.ROOM.BOARD.board_map):
|
337 |
# print("row:", row)
|
@@ -390,7 +396,8 @@ def gomoku():
|
|
390 |
session_state.ROOM.TIME = time.time()
|
391 |
|
392 |
if not response or session_state.ROOM.WINNER != _BLANK:
|
393 |
-
|
|
|
394 |
for i, row in enumerate(session_state.ROOM.BOARD.board_map):
|
395 |
for j, cell in enumerate(row):
|
396 |
BOARD_PLATE[i][j].write(
|
|
|
46 |
self.WINNER = _BLANK
|
47 |
self.TIME = time.time()
|
48 |
self.MCTS = MCTSpure(c_puct=5, n_playout=10)
|
49 |
+
self.MCTS = alphazero(PolicyValueNet(_BOARD_SIZE, _BOARD_SIZE, 'Gomoku_MCTS/checkpoints/best_policy_8_8_5_2torch.pth').policy_value_fn, c_puct=5, n_playout=100)
|
50 |
self.COORDINATE_1D = [_BOARD_SIZE_1D + 1]
|
51 |
self.current_move = -1
|
52 |
self.simula_time_list = []
|
|
|
242 |
# normal play situation
|
243 |
elif session_state.ROOM.WINNER == _BLANK:
|
244 |
# session_state.ROOM = deepcopy(session_state.ROOM)
|
245 |
+
# print("View of human player: ", session_state.ROOM.BOARD.board_map)
|
246 |
move = session_state.ROOM.BOARD.location_to_move((x, y))
|
247 |
session_state.ROOM.current_move = move
|
248 |
session_state.ROOM.BOARD.do_move(move)
|
|
|
269 |
# Draw board
|
270 |
def draw_board(response: bool):
|
271 |
"""construct each buttons for all cells of the board"""
|
272 |
+
if session_state.USE_AIAID and session_state.ROOM.WINNER == _BLANK and session_state.ROOM.TURN == _BLACK:
|
273 |
copy_mcts = deepcopy(session_state.ROOM.MCTS.mcts)
|
274 |
_, acts, probs, simul_mean_time = copy_mcts.get_move_probs(session_state.ROOM.BOARD)
|
275 |
sorted_acts_probs = sorted(zip(acts, probs), key=lambda x: x[1], reverse=True)
|
|
|
318 |
time.sleep(0.1)
|
319 |
print("AI's turn")
|
320 |
print("Below are current board under AI's view")
|
321 |
+
# print(session_state.ROOM.BOARD.board_map)
|
322 |
move, simul_time = session_state.ROOM.MCTS.get_action(session_state.ROOM.BOARD, return_time=True)
|
323 |
session_state.ROOM.simula_time_list.append(simul_time)
|
324 |
print("AI takes move: ", move)
|
|
|
332 |
# session_state.ROOM.BOARD[gpt_i][gpt_j] = session_state.ROOM.TURN
|
333 |
session_state.ROOM.COORDINATE_1D.append(gpt_i * _BOARD_SIZE + gpt_j)
|
334 |
|
335 |
+
copy_mcts = deepcopy(session_state.ROOM.MCTS.mcts)
|
336 |
+
_, acts, probs, simul_mean_time = copy_mcts.get_move_probs(session_state.ROOM.BOARD)
|
337 |
+
sorted_acts_probs = sorted(zip(acts, probs), key=lambda x: x[1], reverse=True)
|
338 |
+
top_five_acts = [act for act, prob in sorted_acts_probs[:5]]
|
339 |
+
top_five_probs = [prob for act, prob in sorted_acts_probs[:5]]
|
340 |
+
|
341 |
# construction of clickable buttons
|
342 |
for i, row in enumerate(session_state.ROOM.BOARD.board_map):
|
343 |
# print("row:", row)
|
|
|
396 |
session_state.ROOM.TIME = time.time()
|
397 |
|
398 |
if not response or session_state.ROOM.WINNER != _BLANK:
|
399 |
+
if session_state.ROOM.WINNER != _BLANK:
|
400 |
+
print("Game over")
|
401 |
for i, row in enumerate(session_state.ROOM.BOARD.board_map):
|
402 |
for j, cell in enumerate(row):
|
403 |
BOARD_PLATE[i][j].write(
|