sjz committed
Commit 2f21cdd
1 Parent(s): 70d6bef

fix AI Aid error when taking turn

.vscode/settings.json ADDED
@@ -0,0 +1,4 @@
+{
+    "editor.suggest.snippetsPreventQuickSuggestions": false,
+    "aiXcoder.showTrayIcon": true
+}
Gomoku_MCTS/__init__.py CHANGED
@@ -1,6 +1,7 @@
 from .mcts_pure import MCTSPlayer as MCTSpure
 from .mcts_alphaZero import MCTSPlayer as alphazero
 from .dueling_net import PolicyValueNet
+# from .policy_value_net_pytorch import PolicyValueNet
 import numpy as np
 
 
Gomoku_MCTS/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (5.51 kB)

Gomoku_MCTS/__pycache__/dueling_net.cpython-38.pyc ADDED
Binary file (4.72 kB)

Gomoku_MCTS/__pycache__/mcts_alphaZero.cpython-38.pyc ADDED
Binary file (8.09 kB)

Gomoku_MCTS/__pycache__/mcts_pure.cpython-38.pyc ADDED
Binary file (8.74 kB)

Gomoku_MCTS/__pycache__/policy_value_net_pytorch.cpython-38.pyc ADDED
Binary file (4.12 kB)

Gomoku_MCTS/dueling_net.py CHANGED
@@ -52,7 +52,7 @@ class DuelingDQNNet(nn.Module):
         return F.log_softmax(q_values, dim=1), val
 
 class PolicyValueNet():
-    """policy-value network """
+    """dueling policy-value network """
     def __init__(self, board_width, board_height,
                  model_file=None, use_gpu=False):
         self.use_gpu = use_gpu
@@ -70,6 +70,7 @@ class PolicyValueNet():
         if model_file:
            net_params = torch.load(model_file)
            self.policy_value_net.load_state_dict(net_params, strict=False)
+            print('loaded dueling model file')
 
     def policy_value(self, state_batch):
         """
Gomoku_MCTS/policy_value_net_pytorch.py CHANGED
@@ -55,7 +55,7 @@ class Net(nn.Module):
 
 
 class PolicyValueNet():
-    """policy-value network """
+    """alphazero policy-value network """
     def __init__(self, board_width, board_height,
                  model_file=None, use_gpu=False):
         self.use_gpu = use_gpu
@@ -71,6 +71,7 @@ class PolicyValueNet():
         if model_file:
            net_params = torch.load(model_file)
            self.policy_value_net.load_state_dict(net_params)
+            print('loaded model file')
 
     def policy_value(self, state_batch):
         """
__pycache__/const.cpython-38.pyc ADDED
Binary file (959 Bytes)
 
pages/Player_VS_AI.py CHANGED
@@ -46,7 +46,7 @@ class Room:
         self.WINNER = _BLANK
         self.TIME = time.time()
         self.MCTS = MCTSpure(c_puct=5, n_playout=10)
-        self.MCTS = alphazero(PolicyValueNet(_BOARD_SIZE, _BOARD_SIZE).policy_value_fn, c_puct=5, n_playout=10)
+        self.MCTS = alphazero(PolicyValueNet(_BOARD_SIZE, _BOARD_SIZE, 'Gomoku_MCTS/checkpoints/best_policy_8_8_5_2torch.pth').policy_value_fn, c_puct=5, n_playout=100)
         self.COORDINATE_1D = [_BOARD_SIZE_1D + 1]
         self.current_move = -1
         self.simula_time_list = []
@@ -242,7 +242,7 @@ def gomoku():
         # normal play situation
         elif session_state.ROOM.WINNER == _BLANK:
             # session_state.ROOM = deepcopy(session_state.ROOM)
-            print("View of human player: ", session_state.ROOM.BOARD.board_map)
+            # print("View of human player: ", session_state.ROOM.BOARD.board_map)
             move = session_state.ROOM.BOARD.location_to_move((x, y))
             session_state.ROOM.current_move = move
             session_state.ROOM.BOARD.do_move(move)
@@ -269,7 +269,7 @@ def gomoku():
     # Draw board
     def draw_board(response: bool):
         """construct each buttons for all cells of the board"""
-        if session_state.USE_AIAID and session_state.ROOM.WINNER == _BLANK:
+        if session_state.USE_AIAID and session_state.ROOM.WINNER == _BLANK and session_state.ROOM.TURN == _BLACK:
             copy_mcts = deepcopy(session_state.ROOM.MCTS.mcts)
             _, acts, probs, simul_mean_time = copy_mcts.get_move_probs(session_state.ROOM.BOARD)
             sorted_acts_probs = sorted(zip(acts, probs), key=lambda x: x[1], reverse=True)
@@ -318,7 +318,7 @@ def gomoku():
             time.sleep(0.1)
             print("AI's turn")
             print("Below are current board under AI's view")
-            print(session_state.ROOM.BOARD.board_map)
+            # print(session_state.ROOM.BOARD.board_map)
             move, simul_time = session_state.ROOM.MCTS.get_action(session_state.ROOM.BOARD, return_time=True)
             session_state.ROOM.simula_time_list.append(simul_time)
             print("AI takes move: ", move)
@@ -332,6 +332,12 @@ def gomoku():
             # session_state.ROOM.BOARD[gpt_i][gpt_j] = session_state.ROOM.TURN
             session_state.ROOM.COORDINATE_1D.append(gpt_i * _BOARD_SIZE + gpt_j)
 
+            copy_mcts = deepcopy(session_state.ROOM.MCTS.mcts)
+            _, acts, probs, simul_mean_time = copy_mcts.get_move_probs(session_state.ROOM.BOARD)
+            sorted_acts_probs = sorted(zip(acts, probs), key=lambda x: x[1], reverse=True)
+            top_five_acts = [act for act, prob in sorted_acts_probs[:5]]
+            top_five_probs = [prob for act, prob in sorted_acts_probs[:5]]
+
         # construction of clickable buttons
         for i, row in enumerate(session_state.ROOM.BOARD.board_map):
             # print("row:", row)
@@ -390,7 +396,8 @@ def gomoku():
         session_state.ROOM.TIME = time.time()
 
         if not response or session_state.ROOM.WINNER != _BLANK:
-            print("Game over")
+            if session_state.ROOM.WINNER != _BLANK:
+                print("Game over")
             for i, row in enumerate(session_state.ROOM.BOARD.board_map):
                 for j, cell in enumerate(row):
                     BOARD_PLATE[i][j].write(
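
The core of the "AI Aid" fix above is that the move-suggestion search now runs only on the human (black) turn and is recomputed after the AI has moved, always against a deep copy of the room's MCTS so the live search tree is left untouched. Below is a hedged sketch of that suggestion step, lifted from the changed lines; `room` is a hypothetical stand-in for `session_state.ROOM` and is assumed to expose the same `MCTS` and `BOARD` attributes.

# Hedged sketch of the AI-Aid suggestion step shown in the diff above;
# `room` is a hypothetical stand-in for session_state.ROOM.
from copy import deepcopy

def top_five_suggestions(room):
    # Search on a copy so the suggestion playouts do not mutate the tree
    # that the real AI player will reuse on its next turn.
    copy_mcts = deepcopy(room.MCTS.mcts)
    _, acts, probs, simul_mean_time = copy_mcts.get_move_probs(room.BOARD)
    sorted_acts_probs = sorted(zip(acts, probs), key=lambda x: x[1], reverse=True)
    top_five_acts = [act for act, prob in sorted_acts_probs[:5]]
    top_five_probs = [prob for act, prob in sorted_acts_probs[:5]]
    return top_five_acts, top_five_probs, simul_mean_time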