Spaces:
Sleeping
Sleeping
trial
Browse files- .DS_Store +0 -0
- Gomoku_MCTS/.DS_Store +0 -0
- Gomoku_MCTS/__init__.py +142 -0
- Gomoku_MCTS/__pycache__/__init__.cpython-310.pyc +0 -0
- Gomoku_MCTS/__pycache__/dueling_net.cpython-310.pyc +0 -0
- Gomoku_MCTS/__pycache__/game.cpython-310.pyc +0 -0
- Gomoku_MCTS/__pycache__/mcts_alphaZero.cpython-310.pyc +0 -0
- Gomoku_MCTS/__pycache__/mcts_pure.cpython-310.pyc +0 -0
- Gomoku_MCTS/checkpoints/best_policy_8_8_5_2torch.pth +3 -0
- Gomoku_MCTS/config/config.yaml +10 -0
- Gomoku_MCTS/config/options.py +74 -0
- Gomoku_MCTS/config/utils.py +54 -0
- Gomoku_MCTS/dueling_net.py +155 -0
- Gomoku_MCTS/game.py +281 -0
- Gomoku_MCTS/main_worker.py +334 -0
- Gomoku_MCTS/mcts_alphaZero.py +250 -0
- Gomoku_MCTS/mcts_pure.py +246 -0
- Gomoku_MCTS/policy_value_net_pytorch.py +159 -0
- Gomoku_MCTS/visualization/_blip_uni_cross_mu_bs512_lr0.002/events.out.tfevents.1700183498.LAPTOP-5AN2UHOO +3 -0
- Gomoku_MCTS/visualization/_blip_uni_cross_mu_bs512_lr0.002/events.out.tfevents.1700183516.LAPTOP-5AN2UHOO +3 -0
- Gomoku_MCTS/visualization/_blip_uni_cross_mu_bs512_lr0.002/events.out.tfevents.1700183568.LAPTOP-5AN2UHOO +3 -0
- Gomoku_MCTS/visualization/_blip_uni_cross_mu_bs512_lr0.002/events.out.tfevents.1700183629.LAPTOP-5AN2UHOO +3 -0
- Gomoku_MCTS/visualization/_blip_uni_cross_mu_bs512_lr0.002/events.out.tfevents.1700183640.LAPTOP-5AN2UHOO +3 -0
- Gomoku_MCTS/visualization/_blip_uni_cross_mu_bs512_lr0.002/events.out.tfevents.1700183667.LAPTOP-5AN2UHOO +3 -0
- Gomoku_MCTS/visualization/_blip_uni_cross_mu_bs512_lr0.002/events.out.tfevents.1700183756.LAPTOP-5AN2UHOO +3 -0
- Gomoku_MCTS/visualization/_blip_uni_cross_mu_bs512_lr0.002/events.out.tfevents.1700183820.LAPTOP-5AN2UHOO +3 -0
- Gomoku_MCTS/visualization/_blip_uni_cross_mu_bs512_lr0.002/events.out.tfevents.1700184097.LAPTOP-5AN2UHOO +3 -0
- README.md +3 -3
- app.py +56 -0
- assets/favicon_circle.png +0 -0
- const.py +58 -0
- pages/Player_VS_AI.py +409 -0
- requirements.txt +7 -0
.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
Gomoku_MCTS/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
Gomoku_MCTS/__init__.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .mcts_pure import MCTSPlayer as MCTSpure
|
2 |
+
from .mcts_alphaZero import MCTSPlayer as alphazero
|
3 |
+
from .dueling_net import PolicyValueNet
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
|
7 |
+
class Board(object):
|
8 |
+
"""board for the game"""
|
9 |
+
|
10 |
+
def __init__(self, **kwargs):
|
11 |
+
self.last_move = None
|
12 |
+
self.availables = None
|
13 |
+
self.current_player = None
|
14 |
+
self.width = int(kwargs.get('width', 8)) # if no width, default 8
|
15 |
+
self.height = int(kwargs.get('height', 8))
|
16 |
+
self.board_map = np.zeros(shape=(self.width, self.height), dtype=int)
|
17 |
+
# board states stored as a dict,
|
18 |
+
# key: move as location on the board,
|
19 |
+
# value: player as pieces type
|
20 |
+
self.states = {}
|
21 |
+
# need how many pieces in a row to win
|
22 |
+
self.n_in_row = int(kwargs.get('n_in_row', 5))
|
23 |
+
self.players = kwargs.get('players', [1, 2]) # player1 and player2
|
24 |
+
self.init_board(0)
|
25 |
+
|
26 |
+
def init_board(self, start_player=0):
|
27 |
+
if self.width < self.n_in_row or self.height < self.n_in_row:
|
28 |
+
raise Exception('board width and height can not be '
|
29 |
+
'less than {}'.format(self.n_in_row))
|
30 |
+
self.current_player = self.players[start_player] # start player
|
31 |
+
# keep available moves in a list
|
32 |
+
self.availables = list(range(self.width * self.height))
|
33 |
+
self.states = {}
|
34 |
+
self.last_move = -1
|
35 |
+
|
36 |
+
def move_to_location(self, move: int):
|
37 |
+
"""
|
38 |
+
3*3 board's moves like:
|
39 |
+
6 7 8
|
40 |
+
3 4 5
|
41 |
+
0 1 2
|
42 |
+
and move 5's location is (1,2)
|
43 |
+
"""
|
44 |
+
h = move // self.width
|
45 |
+
w = move % self.width
|
46 |
+
return [h, w]
|
47 |
+
|
48 |
+
def location_to_move(self, location):
|
49 |
+
if len(location) != 2:
|
50 |
+
return -1
|
51 |
+
h = location[0]
|
52 |
+
w = location[1]
|
53 |
+
move = h * self.width + w
|
54 |
+
if move not in range(self.width * self.height):
|
55 |
+
return -1
|
56 |
+
return move
|
57 |
+
|
58 |
+
def current_state(self):
|
59 |
+
"""
|
60 |
+
return the board state from the perspective of the current player.
|
61 |
+
state shape: 4*width*height
|
62 |
+
这个状态数组具有四个通道:
|
63 |
+
第一个通道表示当前玩家的棋子位置,第二个通道表示对手的棋子位置,第三个通道表示最后一步移动的位置。
|
64 |
+
第四个通道是一个指示符,用于表示当前轮到哪个玩家(如果棋盘上的总移动次数是偶数,那么这个通道的所有元素都为1,表示是第一个玩家的回合;否则,所有元素都为0,表示是第二个玩家的回合)。
|
65 |
+
每个通道都是一个 width x height 的二维数组,代表着棋盘的布局。对于第一个和第二个通道,如果一个位置上有当前玩家或对手的棋子,那么该位置的值为 1,否则为0。
|
66 |
+
对于第三个通道,只有最后一步移动的位置是1,其余位置都为0。对于第四个通道,如果是第一个玩家的回合,那么所有的位置都是1,否则都是0。
|
67 |
+
最后,状态数组在垂直方向上翻转,以匹配棋盘的实际布局。
|
68 |
+
"""
|
69 |
+
|
70 |
+
square_state = np.zeros((4, self.width, self.height))
|
71 |
+
if self.states:
|
72 |
+
moves, players = np.array(list(zip(*self.states.items())))
|
73 |
+
move_curr = moves[players == self.current_player]
|
74 |
+
move_oppo = moves[players != self.current_player]
|
75 |
+
square_state[0][move_curr // self.width,
|
76 |
+
move_curr % self.height] = 1.0
|
77 |
+
square_state[1][move_oppo // self.width,
|
78 |
+
move_oppo % self.height] = 1.0
|
79 |
+
# indicate the last move location
|
80 |
+
square_state[2][self.last_move // self.width,
|
81 |
+
self.last_move % self.height] = 1.0
|
82 |
+
if len(self.states) % 2 == 0:
|
83 |
+
square_state[3][:, :] = 1.0 # indicate the colour to play
|
84 |
+
return square_state[:, ::-1, :]
|
85 |
+
|
86 |
+
def do_move(self, move):
|
87 |
+
self.states[move] = self.current_player
|
88 |
+
# get (x,y) of this move
|
89 |
+
x, y = self.move_to_location(move)
|
90 |
+
self.board_map[x][y] = self.current_player
|
91 |
+
|
92 |
+
self.availables.remove(move)
|
93 |
+
self.current_player = (
|
94 |
+
self.players[0] if self.current_player == self.players[1]
|
95 |
+
else self.players[1]
|
96 |
+
)
|
97 |
+
self.last_move = move
|
98 |
+
|
99 |
+
def has_a_winner(self):
|
100 |
+
width = self.width
|
101 |
+
height = self.height
|
102 |
+
states = self.states
|
103 |
+
n = self.n_in_row
|
104 |
+
|
105 |
+
moved = list(set(range(width * height)) - set(self.availables))
|
106 |
+
if len(moved) < self.n_in_row * 2 - 1:
|
107 |
+
return False, -1
|
108 |
+
|
109 |
+
for m in moved:
|
110 |
+
h = m // width
|
111 |
+
w = m % width
|
112 |
+
player = states[m]
|
113 |
+
|
114 |
+
if (w in range(width - n + 1) and
|
115 |
+
len(set(states.get(i, -1) for i in range(m, m + n))) == 1):
|
116 |
+
return True, player
|
117 |
+
|
118 |
+
if (h in range(height - n + 1) and
|
119 |
+
len(set(states.get(i, -1) for i in range(m, m + n * width, width))) == 1):
|
120 |
+
return True, player
|
121 |
+
|
122 |
+
if (w in range(width - n + 1) and h in range(height - n + 1) and
|
123 |
+
len(set(states.get(i, -1) for i in range(m, m + n * (width + 1), width + 1))) == 1):
|
124 |
+
return True, player
|
125 |
+
|
126 |
+
if (w in range(n - 1, width) and h in range(height - n + 1) and
|
127 |
+
len(set(states.get(i, -1) for i in range(m, m + n * (width - 1), width - 1))) == 1):
|
128 |
+
return True, player
|
129 |
+
|
130 |
+
return False, -1
|
131 |
+
|
132 |
+
def game_end(self):
|
133 |
+
"""Check whether the game is ended or not"""
|
134 |
+
win, winner = self.has_a_winner()
|
135 |
+
if win:
|
136 |
+
return True, winner
|
137 |
+
elif not len(self.availables):
|
138 |
+
return True, -1
|
139 |
+
return False, -1
|
140 |
+
|
141 |
+
def get_current_player(self):
|
142 |
+
return self.current_player
|
Gomoku_MCTS/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (5.41 kB). View file
|
|
Gomoku_MCTS/__pycache__/dueling_net.cpython-310.pyc
ADDED
Binary file (4.71 kB). View file
|
|
Gomoku_MCTS/__pycache__/game.cpython-310.pyc
ADDED
Binary file (8.97 kB). View file
|
|
Gomoku_MCTS/__pycache__/mcts_alphaZero.cpython-310.pyc
ADDED
Binary file (8.05 kB). View file
|
|
Gomoku_MCTS/__pycache__/mcts_pure.cpython-310.pyc
ADDED
Binary file (8.73 kB). View file
|
|
Gomoku_MCTS/checkpoints/best_policy_8_8_5_2torch.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:878aace7c41962e0817fe8298a1f260b3b83e71c24d7d8c3558ccd6c4996d4f8
|
3 |
+
size 481383
|
Gomoku_MCTS/config/config.yaml
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ckpt/logger options(dynamic)
|
2 |
+
checkpoint_base: checkpoint
|
3 |
+
visual_base: visualization
|
4 |
+
log_base: log
|
5 |
+
|
6 |
+
# dataset
|
7 |
+
data_base: dataset
|
8 |
+
|
9 |
+
|
10 |
+
|
Gomoku_MCTS/config/options.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import os
|
3 |
+
import argparse
|
4 |
+
import yaml
|
5 |
+
|
6 |
+
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
7 |
+
|
8 |
+
# basic settings
|
9 |
+
parser.add_argument('--seed', default=1234, type=int)
|
10 |
+
parser.add_argument('--savepath', type=str, default="blip_uni_cross_mu", help='')
|
11 |
+
|
12 |
+
|
13 |
+
# board settings
|
14 |
+
parser.add_argument("--board_width", type=int,default=9)
|
15 |
+
parser.add_argument("--board_height", type=int,default=9)
|
16 |
+
parser.add_argument("--n_in_row", type=int,default=5,help="the condition of winning")
|
17 |
+
|
18 |
+
|
19 |
+
# device settings
|
20 |
+
parser.add_argument('--config', type=str, default='config/config.yaml', help='Path to the config file.')
|
21 |
+
parser.add_argument('--gpu_num', type=int, default=1)
|
22 |
+
parser.add_argument('--gpu_id', type=str, default='5')
|
23 |
+
|
24 |
+
|
25 |
+
# save options
|
26 |
+
parser.add_argument('--clear_visualizer', dest='clear_visualizer', action='store_true')
|
27 |
+
parser.add_argument('--std_log', dest='std_log', action='store_true')
|
28 |
+
|
29 |
+
|
30 |
+
# mode settings
|
31 |
+
parser.add_argument("--split",type=str,default="train",help="the mode of woker")
|
32 |
+
|
33 |
+
|
34 |
+
# train settings
|
35 |
+
parser.add_argument("--expri",type=str, default="",help="the name of experiment")
|
36 |
+
parser.add_argument("--learn_rate", type=float,default=2e-3)
|
37 |
+
parser.add_argument("--l2_const",type=float,default=1e-4)
|
38 |
+
# ???
|
39 |
+
parser.add_argument("--lr_multiplier", type=float,default= 1.0 ,help="adaptively adjust the learning rate based on KL")
|
40 |
+
parser.add_argument("--buffer_size",type=int,default=10000,help="The size of collection of game data ")
|
41 |
+
parser.add_argument("--batch_size",type=int,default=512)
|
42 |
+
parser.add_argument("--play_batch_size",type=int, default=1,help="The time of selfplaying when collect the data")
|
43 |
+
parser.add_argument("--epochs",type=int,default=5,help="num of train_steps for each update")
|
44 |
+
parser.add_argument("--kl_targ",type=float,default=0.02,help="the target kl distance between the old decision function and the new decision function ")
|
45 |
+
parser.add_argument("--check_freq",type=int,default=50,help='the frequence of the checking the win ratio when training')
|
46 |
+
parser.add_argument("--game_batch_num",type=int,default=1500,help = "the total training times")
|
47 |
+
|
48 |
+
|
49 |
+
# parser.add_argument("--l2_const",type=float,default=1e-4,help=" coef of l2 penalty")
|
50 |
+
parser.add_argument("--distributed",type=bool,default=False)
|
51 |
+
|
52 |
+
# preload_model setting
|
53 |
+
parser.add_argument("--preload_model",type=str, default="")
|
54 |
+
|
55 |
+
|
56 |
+
# Alphazero agent setting
|
57 |
+
parser.add_argument("--temp", type=float,default= 1.0 ,help="the temperature parameter when calculate the decision function getting the next action")
|
58 |
+
parser.add_argument("--n_playout",type=int, default=200, help="num of simulations for each move ")
|
59 |
+
parser.add_argument("--c_puct",type=int, default=5, help= "the balance parameter between exploration and exploitative ")
|
60 |
+
|
61 |
+
# prue_mcts agent setting
|
62 |
+
parser.add_argument("--pure_mcts_playout_num",type=int, default=200)
|
63 |
+
|
64 |
+
# test settings
|
65 |
+
parser.add_argument('--test_ckpt', type=str, default=None, help='ckpt absolute path')
|
66 |
+
|
67 |
+
|
68 |
+
opts = parser.parse_args()
|
69 |
+
|
70 |
+
# additional parameters
|
71 |
+
current_path = os.path.abspath(__file__)
|
72 |
+
grandfather_path = os.path.abspath(os.path.dirname(os.path.dirname(current_path)) + os.path.sep + ".")
|
73 |
+
with open(os.path.join(grandfather_path, opts.config), 'r') as stream:
|
74 |
+
config = yaml.full_load(stream)
|
Gomoku_MCTS/config/utils.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, shutil
|
2 |
+
import torch
|
3 |
+
from tensorboardX import SummaryWriter
|
4 |
+
from config.options import *
|
5 |
+
import torch.distributed as dist
|
6 |
+
import time
|
7 |
+
|
8 |
+
""" ==================== Save ======================== """
|
9 |
+
|
10 |
+
def make_path():
|
11 |
+
return "{}_{}_bs{}_lr{}".format(opts.expri,opts.savepath,opts.batch_size,opts.learn_rate)
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
|
16 |
+
def save_model(model,name):
|
17 |
+
save_path = make_path()
|
18 |
+
if not os.path.isdir(os.path.join(config['checkpoint_base'], save_path)):
|
19 |
+
os.makedirs(os.path.join(config['checkpoint_base'], save_path), exist_ok=True)
|
20 |
+
model_name = os.path.join(config['checkpoint_base'], save_path, name)
|
21 |
+
torch.save(model.state_dict(), model_name)
|
22 |
+
|
23 |
+
|
24 |
+
|
25 |
+
|
26 |
+
""" ==================== Tools ======================== """
|
27 |
+
def is_dist_avail_and_initialized():
|
28 |
+
if not dist.is_available():
|
29 |
+
return False
|
30 |
+
if not dist.is_initialized():
|
31 |
+
return False
|
32 |
+
return True
|
33 |
+
|
34 |
+
def get_rank():
|
35 |
+
if not is_dist_avail_and_initialized():
|
36 |
+
return 0
|
37 |
+
return dist.get_rank()
|
38 |
+
|
39 |
+
|
40 |
+
def makedir(path):
|
41 |
+
if not os.path.exists(path):
|
42 |
+
os.makedirs(path, 0o777)
|
43 |
+
|
44 |
+
|
45 |
+
def visualizer():
|
46 |
+
if get_rank() == 0:
|
47 |
+
# filewriter_path = config['visual_base']+opts.savepath+'/'
|
48 |
+
save_path = make_path()
|
49 |
+
filewriter_path = os.path.join(config['visual_base'], save_path)
|
50 |
+
if opts.clear_visualizer and os.path.exists(filewriter_path): # 删掉以前的summary,以免重合
|
51 |
+
shutil.rmtree(filewriter_path)
|
52 |
+
makedir(filewriter_path)
|
53 |
+
writer = SummaryWriter(filewriter_path, comment='visualizer')
|
54 |
+
return writer
|
Gomoku_MCTS/dueling_net.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.optim as optim
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch.autograd import Variable
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
def set_learning_rate(optimizer, lr):
|
9 |
+
"""Sets the learning rate to the given value"""
|
10 |
+
for param_group in optimizer.param_groups:
|
11 |
+
param_group['lr'] = lr
|
12 |
+
|
13 |
+
class DuelingDQNNet(nn.Module):
|
14 |
+
"""Dueling DQN network module"""
|
15 |
+
def __init__(self, board_width, board_height):
|
16 |
+
super(DuelingDQNNet, self).__init__()
|
17 |
+
|
18 |
+
self.board_width = board_width
|
19 |
+
self.board_height = board_height
|
20 |
+
# common layers
|
21 |
+
self.conv1 = nn.Conv2d(4, 32, kernel_size=3, padding=1)
|
22 |
+
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
|
23 |
+
self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
|
24 |
+
# advantage layers
|
25 |
+
self.adv_conv1 = nn.Conv2d(128, 4, kernel_size=1)
|
26 |
+
self.adv_fc1 = nn.Linear(4*board_width*board_height,
|
27 |
+
board_width*board_height)
|
28 |
+
# value layers
|
29 |
+
self.val_conv1 = nn.Conv2d(128, 2, kernel_size=1)
|
30 |
+
self.val_fc1 = nn.Linear(2*board_width*board_height, 64)
|
31 |
+
self.val_fc2 = nn.Linear(64, 1)
|
32 |
+
|
33 |
+
def forward(self, state_input):
|
34 |
+
# common layers
|
35 |
+
x = F.relu(self.conv1(state_input))
|
36 |
+
x = F.relu(self.conv2(x))
|
37 |
+
x = F.relu(self.conv3(x))
|
38 |
+
|
39 |
+
# advantage stream
|
40 |
+
adv = F.relu(self.adv_conv1(x))
|
41 |
+
adv = adv.view(-1, 4*self.board_width*self.board_height)
|
42 |
+
adv = self.adv_fc1(adv)
|
43 |
+
|
44 |
+
# value stream
|
45 |
+
val = F.relu(self.val_conv1(x))
|
46 |
+
val = val.view(-1, 2*self.board_width*self.board_height)
|
47 |
+
val = F.relu(self.val_fc1(val))
|
48 |
+
val = self.val_fc2(val)
|
49 |
+
|
50 |
+
q_values = val + adv - adv.mean(dim=1, keepdim=True)
|
51 |
+
|
52 |
+
return F.log_softmax(q_values, dim=1), val
|
53 |
+
|
54 |
+
class PolicyValueNet():
|
55 |
+
"""policy-value network """
|
56 |
+
def __init__(self, board_width, board_height,
|
57 |
+
model_file=None, use_gpu=False):
|
58 |
+
self.use_gpu = use_gpu
|
59 |
+
self.board_width = board_width
|
60 |
+
self.board_height = board_height
|
61 |
+
self.l2_const = 1e-4 # coef of l2 penalty
|
62 |
+
# the policy value net module
|
63 |
+
if self.use_gpu:
|
64 |
+
self.policy_value_net = DuelingDQNNet(board_width, board_height).cuda()
|
65 |
+
else:
|
66 |
+
self.policy_value_net = DuelingDQNNet(board_width, board_height)
|
67 |
+
self.optimizer = optim.Adam(self.policy_value_net.parameters(),
|
68 |
+
weight_decay=self.l2_const)
|
69 |
+
|
70 |
+
if model_file:
|
71 |
+
net_params = torch.load(model_file)
|
72 |
+
self.policy_value_net.load_state_dict(net_params, strict=False)
|
73 |
+
|
74 |
+
def policy_value(self, state_batch):
|
75 |
+
"""
|
76 |
+
input: a batch of states
|
77 |
+
output: a batch of action probabilities and state values
|
78 |
+
"""
|
79 |
+
if self.use_gpu:
|
80 |
+
state_batch = Variable(torch.FloatTensor(state_batch).cuda())
|
81 |
+
log_act_probs, value = self.policy_value_net(state_batch)
|
82 |
+
act_probs = np.exp(log_act_probs.data.cpu().numpy())
|
83 |
+
return act_probs, value.data.cpu().numpy()
|
84 |
+
else:
|
85 |
+
state_batch = Variable(torch.FloatTensor(state_batch))
|
86 |
+
log_act_probs, value = self.policy_value_net(state_batch)
|
87 |
+
act_probs = np.exp(log_act_probs.data.numpy())
|
88 |
+
return act_probs, value.data.numpy()
|
89 |
+
|
90 |
+
def policy_value_fn(self, board):
|
91 |
+
"""
|
92 |
+
input: board
|
93 |
+
output: a list of (action, probability) tuples for each available
|
94 |
+
action and the score of the board state
|
95 |
+
"""
|
96 |
+
legal_positions = board.availables
|
97 |
+
current_state = np.ascontiguousarray(board.current_state().reshape(
|
98 |
+
-1, 4, self.board_width, self.board_height))
|
99 |
+
if self.use_gpu:
|
100 |
+
log_act_probs, value = self.policy_value_net(
|
101 |
+
Variable(torch.from_numpy(current_state)).cuda().float())
|
102 |
+
act_probs = np.exp(log_act_probs.data.cpu().numpy().flatten())
|
103 |
+
else:
|
104 |
+
log_act_probs, value = self.policy_value_net(
|
105 |
+
Variable(torch.from_numpy(current_state)).float())
|
106 |
+
act_probs = np.exp(log_act_probs.data.numpy().flatten())
|
107 |
+
act_probs = zip(legal_positions, act_probs[legal_positions])
|
108 |
+
value = value.data[0][0]
|
109 |
+
return act_probs, value
|
110 |
+
|
111 |
+
def train_step(self, state_batch, mcts_probs, winner_batch, lr):
|
112 |
+
"""perform a training step"""
|
113 |
+
|
114 |
+
# self.use_gpu = True
|
115 |
+
# wrap in Variable
|
116 |
+
if self.use_gpu:
|
117 |
+
state_batch = Variable(torch.FloatTensor(state_batch).cuda())
|
118 |
+
mcts_probs = Variable(torch.FloatTensor(mcts_probs).cuda())
|
119 |
+
winner_batch = Variable(torch.FloatTensor(winner_batch).cuda())
|
120 |
+
else:
|
121 |
+
state_batch = Variable(torch.FloatTensor(state_batch))
|
122 |
+
mcts_probs = Variable(torch.FloatTensor(mcts_probs))
|
123 |
+
winner_batch = Variable(torch.FloatTensor(winner_batch))
|
124 |
+
|
125 |
+
# zero the parameter gradients
|
126 |
+
self.optimizer.zero_grad()
|
127 |
+
# set learning rate
|
128 |
+
set_learning_rate(self.optimizer, lr)
|
129 |
+
|
130 |
+
# forward
|
131 |
+
log_act_probs, value = self.policy_value_net(state_batch)
|
132 |
+
# define the loss = (z - v)^2 - pi^T * log(p) + c||theta||^2
|
133 |
+
# Note: the L2 penalty is incorporated in optimizer
|
134 |
+
value_loss = F.mse_loss(value.view(-1), winner_batch)
|
135 |
+
policy_loss = -torch.mean(torch.sum(mcts_probs*log_act_probs, 1))
|
136 |
+
loss = value_loss + policy_loss
|
137 |
+
# backward and optimize
|
138 |
+
loss.backward()
|
139 |
+
self.optimizer.step()
|
140 |
+
# calc policy entropy, for monitoring only
|
141 |
+
entropy = -torch.mean(
|
142 |
+
torch.sum(torch.exp(log_act_probs) * log_act_probs, 1)
|
143 |
+
)
|
144 |
+
# return loss.data[0], entropy.data[0]
|
145 |
+
#for pytorch version >= 0.5 please use the following line instead.
|
146 |
+
return loss.item(), entropy.item()
|
147 |
+
|
148 |
+
def get_policy_param(self):
|
149 |
+
net_params = self.policy_value_net.state_dict()
|
150 |
+
return net_params
|
151 |
+
|
152 |
+
def save_model(self, model_file):
|
153 |
+
""" save model params to file """
|
154 |
+
net_params = self.get_policy_param() # get model params
|
155 |
+
torch.save(net_params, model_file)
|
Gomoku_MCTS/game.py
ADDED
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
FileName: game.py
|
3 |
+
Author: Jiaxin Li
|
4 |
+
Create Date: yyyy/mm/dd
|
5 |
+
Description: to be completed
|
6 |
+
Edit History:
|
7 |
+
- 2023/11/18, Sat, Edited by Hbh ([email protected])
|
8 |
+
- added some comments and optimize import and some structures
|
9 |
+
"""
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
from mcts_pure import MCTSPlayer as MCTS_Pure
|
13 |
+
from mcts_pure import Human_Player
|
14 |
+
from collections import defaultdict
|
15 |
+
from typing import Optional
|
16 |
+
|
17 |
+
|
18 |
+
class Board(object):
|
19 |
+
"""board for the game"""
|
20 |
+
|
21 |
+
def __init__(self, **kwargs):
|
22 |
+
self.last_move = None
|
23 |
+
self.availables = None
|
24 |
+
self.current_player = None
|
25 |
+
self.width = int(kwargs.get('width', 8)) # if no width, default 8
|
26 |
+
self.height = int(kwargs.get('height', 8))
|
27 |
+
# board states stored as a dict,
|
28 |
+
# key: move as location on the board,
|
29 |
+
# value: player as pieces type
|
30 |
+
self.states = {}
|
31 |
+
# need how many pieces in a row to win
|
32 |
+
self.n_in_row = int(kwargs.get('n_in_row', 5))
|
33 |
+
self.players = [1, 2] # player1 and player2
|
34 |
+
|
35 |
+
def init_board(self, start_player=0):
|
36 |
+
if self.width < self.n_in_row or self.height < self.n_in_row:
|
37 |
+
raise Exception('board width and height can not be '
|
38 |
+
'less than {}'.format(self.n_in_row))
|
39 |
+
self.current_player = self.players[start_player] # start player
|
40 |
+
# keep available moves in a list
|
41 |
+
self.availables = list(range(self.width * self.height))
|
42 |
+
self.states = {}
|
43 |
+
self.last_move = -1
|
44 |
+
|
45 |
+
def move_to_location(self, move: int):
|
46 |
+
"""
|
47 |
+
3*3 board's moves like:
|
48 |
+
6 7 8
|
49 |
+
3 4 5
|
50 |
+
0 1 2
|
51 |
+
and move 5's location is (1,2)
|
52 |
+
"""
|
53 |
+
h = move // self.width
|
54 |
+
w = move % self.width
|
55 |
+
return [h, w]
|
56 |
+
|
57 |
+
def location_to_move(self, location):
|
58 |
+
if len(location) != 2:
|
59 |
+
return -1
|
60 |
+
h = location[0]
|
61 |
+
w = location[1]
|
62 |
+
move = h * self.width + w
|
63 |
+
if move not in range(self.width * self.height):
|
64 |
+
return -1
|
65 |
+
return move
|
66 |
+
|
67 |
+
def current_state(self):
|
68 |
+
"""
|
69 |
+
return the board state from the perspective of the current player.
|
70 |
+
state shape: 4*width*height
|
71 |
+
这个状态数组具有四个通道:
|
72 |
+
第一个通道表示当前玩家的棋子位置,第二个通道表示对手的棋子位置,第三个通道表示最后一步移动的位置。
|
73 |
+
第四个通道是一个指示符,用于表示当前轮到哪个玩家(如果棋盘上的总移动次数是偶数,那么这个通道的所有元素都为1,表示是第一个玩家的回合;否则,所有元素都为0,表示是第二个玩家的回合)。
|
74 |
+
每个通道都是一个 width x height 的二维数组,代表着棋盘的布局。对于第一个和第二个通道,如果一个位置上有当前玩家或对手的棋子,那么该位置的值为 1,否则为0。
|
75 |
+
对于第三个通道,只有最后一步移动的位置是1,其余位置都为0。对于第四个通道,如果是第一个玩家的回合,那么所有的位置都是1,否则都是0。
|
76 |
+
最后,状态数组在垂直方向上翻转,以匹配棋盘的实际布局。
|
77 |
+
"""
|
78 |
+
|
79 |
+
square_state = np.zeros((4, self.width, self.height))
|
80 |
+
if self.states:
|
81 |
+
moves, players = np.array(list(zip(*self.states.items())))
|
82 |
+
move_curr = moves[players == self.current_player]
|
83 |
+
move_oppo = moves[players != self.current_player]
|
84 |
+
square_state[0][move_curr // self.width,
|
85 |
+
move_curr % self.height] = 1.0
|
86 |
+
square_state[1][move_oppo // self.width,
|
87 |
+
move_oppo % self.height] = 1.0
|
88 |
+
# indicate the last move location
|
89 |
+
square_state[2][self.last_move // self.width,
|
90 |
+
self.last_move % self.height] = 1.0
|
91 |
+
if len(self.states) % 2 == 0:
|
92 |
+
square_state[3][:, :] = 1.0 # indicate the colour to play
|
93 |
+
return square_state[:, ::-1, :]
|
94 |
+
|
95 |
+
def do_move(self, move):
|
96 |
+
self.states[move] = self.current_player
|
97 |
+
self.availables.remove(move)
|
98 |
+
self.current_player = (
|
99 |
+
self.players[0] if self.current_player == self.players[1]
|
100 |
+
else self.players[1]
|
101 |
+
)
|
102 |
+
self.last_move = move
|
103 |
+
|
104 |
+
def has_a_winner(self):
|
105 |
+
width = self.width
|
106 |
+
height = self.height
|
107 |
+
states = self.states
|
108 |
+
n = self.n_in_row
|
109 |
+
|
110 |
+
moved = list(set(range(width * height)) - set(self.availables))
|
111 |
+
if len(moved) < self.n_in_row * 2 - 1:
|
112 |
+
return False, -1
|
113 |
+
|
114 |
+
for m in moved:
|
115 |
+
h = m // width
|
116 |
+
w = m % width
|
117 |
+
player = states[m]
|
118 |
+
|
119 |
+
if (w in range(width - n + 1) and
|
120 |
+
len(set(states.get(i, -1) for i in range(m, m + n))) == 1):
|
121 |
+
return True, player
|
122 |
+
|
123 |
+
if (h in range(height - n + 1) and
|
124 |
+
len(set(states.get(i, -1) for i in range(m, m + n * width, width))) == 1):
|
125 |
+
return True, player
|
126 |
+
|
127 |
+
if (w in range(width - n + 1) and h in range(height - n + 1) and
|
128 |
+
len(set(states.get(i, -1) for i in range(m, m + n * (width + 1), width + 1))) == 1):
|
129 |
+
return True, player
|
130 |
+
|
131 |
+
if (w in range(n - 1, width) and h in range(height - n + 1) and
|
132 |
+
len(set(states.get(i, -1) for i in range(m, m + n * (width - 1), width - 1))) == 1):
|
133 |
+
return True, player
|
134 |
+
|
135 |
+
return False, -1
|
136 |
+
|
137 |
+
def game_end(self):
|
138 |
+
"""Check whether the game is ended or not"""
|
139 |
+
win, winner = self.has_a_winner()
|
140 |
+
if win:
|
141 |
+
return True, winner
|
142 |
+
elif not len(self.availables):
|
143 |
+
return True, -1
|
144 |
+
return False, -1
|
145 |
+
|
146 |
+
def get_current_player(self):
|
147 |
+
return self.current_player
|
148 |
+
|
149 |
+
|
150 |
+
class Game(object):
|
151 |
+
"""game server"""
|
152 |
+
|
153 |
+
def __init__(self, board, **kwargs):
|
154 |
+
self.board = board
|
155 |
+
self.pure_mcts_playout_num = 100 # simulation time
|
156 |
+
|
157 |
+
def graphic(self, board, player1, player2):
|
158 |
+
"""Draw the board and show game info"""
|
159 |
+
width = board.width
|
160 |
+
height = board.height
|
161 |
+
|
162 |
+
print("Player", player1, "with X".rjust(3))
|
163 |
+
print("Player", player2, "with O".rjust(3))
|
164 |
+
print()
|
165 |
+
for x in range(width):
|
166 |
+
print("{0:8}".format(x), end='')
|
167 |
+
print('\r\n')
|
168 |
+
for i in range(height - 1, -1, -1):
|
169 |
+
print("{0:4d}".format(i), end='')
|
170 |
+
for j in range(width):
|
171 |
+
loc = i * width + j
|
172 |
+
p = board.states.get(loc, -1)
|
173 |
+
if p == player1:
|
174 |
+
print('X'.center(8), end='')
|
175 |
+
elif p == player2:
|
176 |
+
print('O'.center(8), end='')
|
177 |
+
else:
|
178 |
+
print('_'.center(8), end='')
|
179 |
+
print('\r\n\r\n')
|
180 |
+
|
181 |
+
def start_play(self, player1, player2, start_player=0, is_shown=1):
|
182 |
+
"""start a game between two players"""
|
183 |
+
if start_player not in (0, 1):
|
184 |
+
raise Exception('start_player should be either 0 (player1 first) '
|
185 |
+
'or 1 (player2 f1irst)')
|
186 |
+
self.board.init_board(start_player)
|
187 |
+
p1, p2 = self.board.players
|
188 |
+
player1.set_player_ind(p1)
|
189 |
+
player2.set_player_ind(p2)
|
190 |
+
players = {p1: player1, p2: player2}
|
191 |
+
if is_shown:
|
192 |
+
self.graphic(self.board, player1.player, player2.player)
|
193 |
+
while True:
|
194 |
+
current_player = self.board.get_current_player()
|
195 |
+
player_in_turn = players[current_player]
|
196 |
+
move = player_in_turn.get_action(self.board)
|
197 |
+
self.board.do_move(move)
|
198 |
+
if is_shown:
|
199 |
+
self.graphic(self.board, player1.player, player2.player)
|
200 |
+
end, winner = self.board.game_end()
|
201 |
+
if end:
|
202 |
+
if is_shown:
|
203 |
+
if winner != -1:
|
204 |
+
print("Game end. Winner is", players[winner])
|
205 |
+
else:
|
206 |
+
print("Game end. Tie")
|
207 |
+
return winner
|
208 |
+
|
209 |
+
def start_self_play(self, player, is_shown=0, temp=1e-3):
|
210 |
+
"""
|
211 |
+
start a self-play game using a MCTS player, reuse the search tree,
|
212 |
+
and store the self-play data: (state, mcts_probs, z) for training
|
213 |
+
"""
|
214 |
+
self.board.init_board()
|
215 |
+
p1, p2 = self.board.players
|
216 |
+
states, mcts_probs, current_players = [], [], []
|
217 |
+
while True:
|
218 |
+
move, move_probs = player.get_action(self.board,
|
219 |
+
temp=temp,
|
220 |
+
return_prob=1)
|
221 |
+
# store the data
|
222 |
+
states.append(self.board.current_state())
|
223 |
+
mcts_probs.append(move_probs)
|
224 |
+
current_players.append(self.board.current_player)
|
225 |
+
# perform a move
|
226 |
+
self.board.do_move(move)
|
227 |
+
if is_shown:
|
228 |
+
self.graphic(self.board, p1, p2)
|
229 |
+
end, winner = self.board.game_end()
|
230 |
+
if end:
|
231 |
+
# winner from the perspective of the current player of each state
|
232 |
+
winners_z = np.zeros(len(current_players))
|
233 |
+
if winner != -1:
|
234 |
+
winners_z[np.array(current_players) == winner] = 1.0
|
235 |
+
winners_z[np.array(current_players) != winner] = -1.0
|
236 |
+
# reset MCTS root node
|
237 |
+
player.reset_player()
|
238 |
+
if is_shown:
|
239 |
+
if winner != -1:
|
240 |
+
print("Game end. Winner is player:", winner)
|
241 |
+
else:
|
242 |
+
print("Game end. Tie")
|
243 |
+
return winner, zip(states, mcts_probs, winners_z)
|
244 |
+
|
245 |
+
# 多了下面这一串测试代码
|
246 |
+
|
247 |
+
def policy_evaluate(self, n_games=10):
|
248 |
+
"""
|
249 |
+
Evaluate the trained policy by playing against the pure MCTS player
|
250 |
+
Note: this is only for monitoring the progress of training
|
251 |
+
"""
|
252 |
+
current_mcts_player = MCTS_Pure(c_puct=5,
|
253 |
+
n_playout=self.pure_mcts_playout_num)
|
254 |
+
|
255 |
+
# pure_mcts_player = MCTS_Pure(c_puct=5,
|
256 |
+
# n_playout=self.pure_mcts_playout_num)
|
257 |
+
|
258 |
+
pure_mcts_player = Human_Player()
|
259 |
+
win_cnt = defaultdict(int)
|
260 |
+
for i in range(n_games):
|
261 |
+
winner = self.start_play(current_mcts_player,
|
262 |
+
pure_mcts_player,
|
263 |
+
start_player=i % 2,
|
264 |
+
is_shown=1)
|
265 |
+
win_cnt[winner] += 1
|
266 |
+
win_ratio = 1.0 * (win_cnt[1] + 0.5 * win_cnt[-1]) / n_games
|
267 |
+
print("num_playouts:{}, win: {}, lose: {}, tie:{}".format(
|
268 |
+
self.pure_mcts_playout_num,
|
269 |
+
win_cnt[1], win_cnt[2], win_cnt[-1]))
|
270 |
+
return win_ratio
|
271 |
+
|
272 |
+
|
273 |
+
if __name__ == '__main__':
|
274 |
+
board_width = 8
|
275 |
+
board_height = 8
|
276 |
+
n_in_row = 5
|
277 |
+
board = Board(width=board_width,
|
278 |
+
height=board_height,
|
279 |
+
n_in_row=n_in_row)
|
280 |
+
task = Game(board)
|
281 |
+
task.policy_evaluate(n_games=10)
|
Gomoku_MCTS/main_worker.py
ADDED
@@ -0,0 +1,334 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import print_function
|
2 |
+
import random
|
3 |
+
import numpy as np
|
4 |
+
from collections import defaultdict, deque
|
5 |
+
from game import Board, Game
|
6 |
+
from mcts_pure import MCTSPlayer as MCTS_Pure
|
7 |
+
from mcts_alphaZero import MCTSPlayer
|
8 |
+
import torch.optim as optim
|
9 |
+
# from policy_value_net import PolicyValueNet # Theano and Lasagne
|
10 |
+
# from policy_value_net_pytorch import PolicyValueNet # Pytorch
|
11 |
+
from dueling_net import PolicyValueNet
|
12 |
+
# from policy_value_net_tensorflow import PolicyValueNet # Tensorflow
|
13 |
+
# from policy_value_net_keras import PolicyValueNet # Keras
|
14 |
+
# import joblib
|
15 |
+
from torch.autograd import Variable
|
16 |
+
import torch.nn.functional as F
|
17 |
+
|
18 |
+
|
19 |
+
from config.options import *
|
20 |
+
import sys
|
21 |
+
from config.utils import *
|
22 |
+
from torch.backends import cudnn
|
23 |
+
|
24 |
+
import torch
|
25 |
+
|
26 |
+
from tqdm import *
|
27 |
+
from torch.utils.tensorboard import SummaryWriter
|
28 |
+
|
29 |
+
from multiprocessing import Pool
|
30 |
+
|
31 |
+
def set_learning_rate(optimizer, lr):
|
32 |
+
"""Sets the learning rate to the given value"""
|
33 |
+
for param_group in optimizer.param_groups:
|
34 |
+
param_group['lr'] = lr
|
35 |
+
|
36 |
+
def std_log():
|
37 |
+
if get_rank() == 0:
|
38 |
+
save_path = make_path()
|
39 |
+
makedir(config['log_base'])
|
40 |
+
sys.stdout = open(os.path.join(config['log_base'], "{}.txt".format(save_path)), "w")
|
41 |
+
|
42 |
+
|
43 |
+
def init_seeds(seed, cuda_deterministic=True):
|
44 |
+
torch.manual_seed(seed)
|
45 |
+
if cuda_deterministic: # slower, more reproducible
|
46 |
+
cudnn.deterministic = True
|
47 |
+
cudnn.benchmark = False
|
48 |
+
else: # faster, less reproducible
|
49 |
+
cudnn.deterministic = False
|
50 |
+
cudnn.benchmark = True
|
51 |
+
|
52 |
+
|
53 |
+
|
54 |
+
|
55 |
+
class MainWorker():
|
56 |
+
def __init__(self,device):
|
57 |
+
|
58 |
+
#--- init the set of pipeline -------
|
59 |
+
self.board_width = opts.board_width
|
60 |
+
self.board_height = opts.board_height
|
61 |
+
self.n_in_row = opts.n_in_row
|
62 |
+
self.learn_rate = opts.learn_rate
|
63 |
+
self.lr_multiplier = opts.lr_multiplier
|
64 |
+
self.temp = opts.temp
|
65 |
+
self.n_playout = opts.n_playout
|
66 |
+
self.c_puct = opts.c_puct
|
67 |
+
self.buffer_size = opts.buffer_size
|
68 |
+
self.batch_size = opts.batch_size
|
69 |
+
self.play_batch_size = opts.play_batch_size
|
70 |
+
self.epochs = opts.epochs
|
71 |
+
self.kl_targ = opts.kl_targ
|
72 |
+
self.check_freq = opts.check_freq
|
73 |
+
self.game_batch_num = opts.game_batch_num
|
74 |
+
self.pure_mcts_playout_num = opts.pure_mcts_playout_num
|
75 |
+
|
76 |
+
self.device = device
|
77 |
+
self.use_gpu = torch.device("cuda") == self.device
|
78 |
+
|
79 |
+
self.board = Board(width=self.board_width,
|
80 |
+
height=self.board_height,
|
81 |
+
n_in_row=self.n_in_row)
|
82 |
+
self.game = Game(self.board)
|
83 |
+
|
84 |
+
# The data collection of the history of games
|
85 |
+
self.data_buffer = deque(maxlen=self.buffer_size)
|
86 |
+
|
87 |
+
|
88 |
+
# The best win ratio of the training agent
|
89 |
+
self.best_win_ratio = 0.0
|
90 |
+
|
91 |
+
|
92 |
+
if opts.preload_model:
|
93 |
+
# start training from an initial policy-value net
|
94 |
+
self.policy_value_net = PolicyValueNet(self.board_width,
|
95 |
+
self.board_height,
|
96 |
+
model_file=opts.preload_model,
|
97 |
+
use_gpu=(self.device == "cuda"))
|
98 |
+
|
99 |
+
else:
|
100 |
+
# start training from a new policy-value net
|
101 |
+
self.policy_value_net = PolicyValueNet(self.board_width,
|
102 |
+
self.board_height,
|
103 |
+
use_gpu=(self.device == "cuda"))
|
104 |
+
self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
|
105 |
+
c_puct=self.c_puct,
|
106 |
+
n_playout=self.n_playout,
|
107 |
+
is_selfplay=1)
|
108 |
+
|
109 |
+
# The set of optimizer
|
110 |
+
self.optimizer = optim.Adam(self.policy_value_net.policy_value_net.parameters(),
|
111 |
+
weight_decay=opts.l2_const)
|
112 |
+
# set learning rate
|
113 |
+
set_learning_rate(self.optimizer, self.learn_rate*self.lr_multiplier)
|
114 |
+
|
115 |
+
|
116 |
+
|
117 |
+
|
118 |
+
def get_equi_data(self, play_data):
|
119 |
+
"""augment the data set by rotation and flipping
|
120 |
+
play_data: [(state, mcts_prob, winner_z), ..., ...]
|
121 |
+
"""
|
122 |
+
extend_data = []
|
123 |
+
for state, mcts_porb, winner in play_data:
|
124 |
+
for i in [1, 2, 3, 4]:
|
125 |
+
# rotate counterclockwise
|
126 |
+
equi_state = np.array([np.rot90(s, i) for s in state])
|
127 |
+
equi_mcts_prob = np.rot90(np.flipud(
|
128 |
+
mcts_porb.reshape(self.board_height, self.board_width)), i)
|
129 |
+
extend_data.append((equi_state,
|
130 |
+
np.flipud(equi_mcts_prob).flatten(),
|
131 |
+
winner))
|
132 |
+
# flip horizontally
|
133 |
+
equi_state = np.array([np.fliplr(s) for s in equi_state])
|
134 |
+
equi_mcts_prob = np.fliplr(equi_mcts_prob)
|
135 |
+
extend_data.append((equi_state,
|
136 |
+
np.flipud(equi_mcts_prob).flatten(),
|
137 |
+
winner))
|
138 |
+
return extend_data
|
139 |
+
|
140 |
+
def job(self, i):
|
141 |
+
game = self.game
|
142 |
+
player = self.mcts_player
|
143 |
+
winner, play_data = game.start_self_play(player,
|
144 |
+
temp=self.temp)
|
145 |
+
play_data = list(play_data)[:]
|
146 |
+
play_data = self.get_equi_data(play_data)
|
147 |
+
|
148 |
+
return play_data
|
149 |
+
|
150 |
+
def collect_selfplay_data(self, n_games=1):
|
151 |
+
"""collect self-play data for training"""
|
152 |
+
# print("[STAGE] Collecting self-play data for training")
|
153 |
+
|
154 |
+
# collection_bar = tqdm( range(n_games))
|
155 |
+
collection_bar = range(n_games)
|
156 |
+
with Pool(4) as p:
|
157 |
+
play_data = p.map(self.job, collection_bar, chunksize=1)
|
158 |
+
self.data_buffer.extend(play_data)
|
159 |
+
# print('\n', 'data buffer size:', len(self.data_buffer))
|
160 |
+
|
161 |
+
def policy_update(self):
|
162 |
+
"""update the policy-value net"""
|
163 |
+
mini_batch = random.sample(self.data_buffer, self.batch_size)
|
164 |
+
state_batch = [data[0] for data in mini_batch]
|
165 |
+
mcts_probs_batch = [data[1] for data in mini_batch]
|
166 |
+
winner_batch = [data[2] for data in mini_batch]
|
167 |
+
old_probs, old_v = self.policy_value_net.policy_value(state_batch)
|
168 |
+
|
169 |
+
epoch_bar = tqdm(range(self.epochs))
|
170 |
+
|
171 |
+
for i in epoch_bar:
|
172 |
+
"""perform a training step"""
|
173 |
+
# wrap in Variable
|
174 |
+
if self.use_gpu:
|
175 |
+
state_batch = Variable(torch.FloatTensor(state_batch).cuda())
|
176 |
+
mcts_probs = Variable(torch.FloatTensor(mcts_probs_batch).cuda())
|
177 |
+
winner_batch = Variable(torch.FloatTensor(winner_batch).cuda())
|
178 |
+
else:
|
179 |
+
state_batch = Variable(torch.FloatTensor(state_batch))
|
180 |
+
mcts_probs = Variable(torch.FloatTensor(mcts_probs_batch))
|
181 |
+
winner_batch = Variable(torch.FloatTensor(winner_batch))
|
182 |
+
|
183 |
+
# zero the parameter gradients
|
184 |
+
self.optimizer.zero_grad()
|
185 |
+
|
186 |
+
# forward
|
187 |
+
log_act_probs, value = self.policy_value_net.policy_value_net(state_batch)
|
188 |
+
# define the loss = (z - v)^2 - pi^T * log(p) + c||theta||^2
|
189 |
+
# Note: the L2 penalty is incorporated in optimizer
|
190 |
+
value_loss = F.mse_loss(value.view(-1), winner_batch)
|
191 |
+
policy_loss = -torch.mean(torch.sum(mcts_probs*log_act_probs, 1))
|
192 |
+
loss = value_loss + policy_loss
|
193 |
+
# backward and optimize
|
194 |
+
loss.backward()
|
195 |
+
self.optimizer.step()
|
196 |
+
# calc policy entropy, for monitoring only
|
197 |
+
entropy = -torch.mean(
|
198 |
+
torch.sum(torch.exp(log_act_probs) * log_act_probs, 1)
|
199 |
+
)
|
200 |
+
loss = loss.item()
|
201 |
+
entropy = entropy.item()
|
202 |
+
|
203 |
+
new_probs, new_v = self.policy_value_net.policy_value(state_batch)
|
204 |
+
kl = np.mean(np.sum(old_probs * (
|
205 |
+
np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
|
206 |
+
axis=1)
|
207 |
+
)
|
208 |
+
if kl > self.kl_targ * 4: # early stopping if D_KL diverges badly
|
209 |
+
break
|
210 |
+
|
211 |
+
epoch_bar.set_description(f"training epoch {i}")
|
212 |
+
epoch_bar.set_postfix( new_v =new_v, kl = kl)
|
213 |
+
|
214 |
+
# adaptively adjust the learning rate
|
215 |
+
if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
|
216 |
+
self.lr_multiplier /= 1.5
|
217 |
+
elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
|
218 |
+
self.lr_multiplier *= 1.5
|
219 |
+
|
220 |
+
|
221 |
+
|
222 |
+
explained_var_old = (1 -
|
223 |
+
np.var(np.array(winner_batch) - old_v.flatten()) /
|
224 |
+
np.var(np.array(winner_batch)))
|
225 |
+
explained_var_new = (1 -
|
226 |
+
np.var(np.array(winner_batch) - new_v.flatten()) /
|
227 |
+
np.var(np.array(winner_batch)))
|
228 |
+
|
229 |
+
|
230 |
+
|
231 |
+
|
232 |
+
return kl, loss, entropy,explained_var_old, explained_var_new
|
233 |
+
|
234 |
+
def policy_evaluate(self, n_games=10):
|
235 |
+
"""
|
236 |
+
Evaluate the trained policy by playing against the pure MCTS player
|
237 |
+
Note: this is only for monitoring the progress of training
|
238 |
+
"""
|
239 |
+
current_mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
|
240 |
+
c_puct=self.c_puct,
|
241 |
+
n_playout=self.n_playout)
|
242 |
+
pure_mcts_player = MCTS_Pure(c_puct=5,
|
243 |
+
n_playout=self.pure_mcts_playout_num)
|
244 |
+
win_cnt = defaultdict(int)
|
245 |
+
for i in range(n_games):
|
246 |
+
|
247 |
+
winner = self.game.start_play(
|
248 |
+
pure_mcts_player,current_mcts_player,
|
249 |
+
start_player=i % 2,
|
250 |
+
is_shown=0)
|
251 |
+
win_cnt[winner] += 1
|
252 |
+
print(f" {i}_th winner:" , winner)
|
253 |
+
win_ratio = 1.0*(win_cnt[1] + 0.5*win_cnt[-1]) / n_games
|
254 |
+
print("num_playouts:{}, win: {}, lose: {}, tie:{}".format(
|
255 |
+
self.pure_mcts_playout_num,
|
256 |
+
win_cnt[1], win_cnt[2], win_cnt[-1]))
|
257 |
+
return win_ratio
|
258 |
+
|
259 |
+
def run(self):
|
260 |
+
"""run the training pipeline"""
|
261 |
+
try:
|
262 |
+
|
263 |
+
batch_bar = tqdm(range(self.game_batch_num))
|
264 |
+
for i in batch_bar:
|
265 |
+
self.collect_selfplay_data(self.play_batch_size)
|
266 |
+
|
267 |
+
if len(self.data_buffer) > self.batch_size:
|
268 |
+
kl, loss, entropy,explained_var_old, explained_var_new = self.policy_update()
|
269 |
+
|
270 |
+
writer.add_scalar("policy_update/kl", kl ,i )
|
271 |
+
writer.add_scalar("policy_update/loss", loss ,i)
|
272 |
+
writer.add_scalar("policy_update/entropy", entropy ,i)
|
273 |
+
writer.add_scalar("policy_update/explained_var_old", explained_var_old,i)
|
274 |
+
writer.add_scalar("policy_update/explained_var_new ", explained_var_new ,i)
|
275 |
+
|
276 |
+
|
277 |
+
batch_bar.set_description(f"game batch num {i}")
|
278 |
+
|
279 |
+
# check the performance of the current model,
|
280 |
+
# and save the model params
|
281 |
+
if (i+1) % self.check_freq == 0:
|
282 |
+
win_ratio = self.policy_evaluate()
|
283 |
+
|
284 |
+
batch_bar.set_description(f"game batch num {i+1}")
|
285 |
+
writer.add_scalar("evaluate/explained_var_new ", win_ratio ,i)
|
286 |
+
batch_bar.set_postfix(loss= loss, entropy= entropy,win_ratio =win_ratio)
|
287 |
+
|
288 |
+
save_model(self.policy_value_net,"current_policy.model")
|
289 |
+
if win_ratio > self.best_win_ratio:
|
290 |
+
print("New best policy!!!!!!!!")
|
291 |
+
self.best_win_ratio = win_ratio
|
292 |
+
# update the best_policy
|
293 |
+
save_model(self.policy_value_net,"best_policy.model")
|
294 |
+
if (self.best_win_ratio == 1.0 and
|
295 |
+
self.pure_mcts_playout_num < 5000):
|
296 |
+
self.pure_mcts_playout_num += 1000
|
297 |
+
self.best_win_ratio = 0.0
|
298 |
+
except KeyboardInterrupt:
|
299 |
+
print('\n\rquit')
|
300 |
+
|
301 |
+
|
302 |
+
if __name__ == "__main__":
|
303 |
+
print("START train....")
|
304 |
+
|
305 |
+
# ------init set-----------
|
306 |
+
|
307 |
+
if opts.std_log:
|
308 |
+
std_log()
|
309 |
+
writer = visualizer()
|
310 |
+
|
311 |
+
|
312 |
+
if opts.distributed:
|
313 |
+
torch.distributed.init_process_group(backend="nccl")
|
314 |
+
local_rank = torch.distributed.get_rank()
|
315 |
+
torch.cuda.set_device(local_rank)
|
316 |
+
device = torch.device("cuda", local_rank)
|
317 |
+
init_seeds(opts.seed + local_rank)
|
318 |
+
|
319 |
+
else:
|
320 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
321 |
+
init_seeds(opts.seed)
|
322 |
+
|
323 |
+
print("seed: ",opts.seed )
|
324 |
+
print("device:" , device)
|
325 |
+
|
326 |
+
|
327 |
+
if opts.split == "train":
|
328 |
+
training_pipeline = MainWorker(device)
|
329 |
+
training_pipeline.run()
|
330 |
+
|
331 |
+
if get_rank() == 0 and opts.split == "test":
|
332 |
+
training_pipeline = MainWorker(device)
|
333 |
+
training_pipeline.policy_value_net()
|
334 |
+
|
Gomoku_MCTS/mcts_alphaZero.py
ADDED
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
Monte Carlo Tree Search in AlphaGo Zero style, which uses a policy-value
|
4 |
+
network to guide the tree search and evaluate the leaf nodes
|
5 |
+
|
6 |
+
@author: Junxiao Song
|
7 |
+
"""
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import copy
|
11 |
+
import time
|
12 |
+
from concurrent.futures import ThreadPoolExecutor
|
13 |
+
import threading
|
14 |
+
|
15 |
+
|
16 |
+
def softmax(x):
|
17 |
+
probs = np.exp(x - np.max(x))
|
18 |
+
probs /= np.sum(probs)
|
19 |
+
return probs
|
20 |
+
|
21 |
+
|
22 |
+
class TreeNode(object):
|
23 |
+
"""A node in the MCTS tree.
|
24 |
+
|
25 |
+
Each node keeps track of its own value Q, prior probability P, and
|
26 |
+
its visit-count-adjusted prior score u.
|
27 |
+
"""
|
28 |
+
|
29 |
+
def __init__(self, parent, prior_p):
|
30 |
+
self._parent = parent
|
31 |
+
self._children = {} # a map from action to TreeNode
|
32 |
+
self._n_visits = 0
|
33 |
+
self._Q = 0
|
34 |
+
self._u = 0
|
35 |
+
self._P = prior_p
|
36 |
+
|
37 |
+
def expand(self, action_priors):
|
38 |
+
"""Expand tree by creating new children.
|
39 |
+
action_priors: a list of tuples of actions and their prior probability
|
40 |
+
according to the policy function.
|
41 |
+
"""
|
42 |
+
for action, prob in action_priors:
|
43 |
+
if action not in self._children:
|
44 |
+
self._children[action] = TreeNode(self, prob)
|
45 |
+
|
46 |
+
def select(self, c_puct):
|
47 |
+
"""Select action among children that gives maximum action value Q
|
48 |
+
plus bonus u(P).
|
49 |
+
Return: A tuple of (action, next_node)
|
50 |
+
"""
|
51 |
+
return max(self._children.items(),
|
52 |
+
key=lambda act_node: act_node[1].get_value(c_puct))
|
53 |
+
|
54 |
+
def update(self, leaf_value):
|
55 |
+
"""Update node values from leaf evaluation.
|
56 |
+
leaf_value: the value of subtree evaluation from the current player's
|
57 |
+
perspective.
|
58 |
+
"""
|
59 |
+
# Count visit.
|
60 |
+
self._n_visits += 1
|
61 |
+
# Update Q, a running average of values for all visits.
|
62 |
+
self._Q += 1.0*(leaf_value - self._Q) / self._n_visits
|
63 |
+
|
64 |
+
def update_recursive(self, leaf_value):
|
65 |
+
"""Like a call to update(), but applied recursively for all ancestors.
|
66 |
+
"""
|
67 |
+
# If it is not root, this node's parent should be updated first.
|
68 |
+
if self._parent:
|
69 |
+
self._parent.update_recursive(-leaf_value)
|
70 |
+
self.update(leaf_value)
|
71 |
+
|
72 |
+
def get_value(self, c_puct):
|
73 |
+
"""Calculate and return the value for this node.
|
74 |
+
It is a combination of leaf evaluations Q, and this node's prior
|
75 |
+
adjusted for its visit count, u.
|
76 |
+
c_puct: a number in (0, inf) controlling the relative impact of
|
77 |
+
value Q, and prior probability P, on this node's score.
|
78 |
+
"""
|
79 |
+
self._u = (c_puct * self._P *
|
80 |
+
np.sqrt(self._parent._n_visits) / (1 + self._n_visits))
|
81 |
+
return self._Q + self._u
|
82 |
+
|
83 |
+
def is_leaf(self):
|
84 |
+
"""Check if leaf node (i.e. no nodes below this have been expanded)."""
|
85 |
+
return self._children == {}
|
86 |
+
|
87 |
+
def is_root(self):
|
88 |
+
return self._parent is None
|
89 |
+
|
90 |
+
|
91 |
+
class MCTS(object):
|
92 |
+
"""An implementation of Monte Carlo Tree Search."""
|
93 |
+
|
94 |
+
def __init__(self, policy_value_fn, c_puct=5, n_playout=10000):
|
95 |
+
"""
|
96 |
+
policy_value_fn: a function that takes in a board state and outputs
|
97 |
+
a list of (action, probability) tuples and also a score in [-1, 1]
|
98 |
+
(i.e. the expected value of the end game score from the current
|
99 |
+
player's perspective) for the current player.
|
100 |
+
c_puct: a number in (0, inf) that controls how quickly exploration
|
101 |
+
converges to the maximum-value policy. A higher value means
|
102 |
+
relying on the prior more.
|
103 |
+
"""
|
104 |
+
self._root = TreeNode(None, 1.0)
|
105 |
+
self._policy = policy_value_fn
|
106 |
+
self._c_puct = c_puct
|
107 |
+
self._n_playout = n_playout
|
108 |
+
|
109 |
+
def _playout(self, state, lock=None):
|
110 |
+
"""Run a single playout from the root to the leaf, getting a value at
|
111 |
+
the leaf and propagating it back through its parents.
|
112 |
+
State is modified in-place, so a copy must be provided.
|
113 |
+
"""
|
114 |
+
node = self._root
|
115 |
+
if lock is not None:
|
116 |
+
lock.acquire()
|
117 |
+
while(1):
|
118 |
+
if node.is_leaf():
|
119 |
+
break
|
120 |
+
# Greedily select next move.
|
121 |
+
action, node = node.select(self._c_puct)
|
122 |
+
state.do_move(action)
|
123 |
+
if lock is not None:
|
124 |
+
lock.release()
|
125 |
+
# Evaluate the leaf using a network which outputs a list of
|
126 |
+
# (action, probability) tuples p and also a score v in [-1, 1]
|
127 |
+
# for the current player.
|
128 |
+
action_probs, leaf_value = self._policy(state)
|
129 |
+
# Check for end of game.
|
130 |
+
end, winner = state.game_end()
|
131 |
+
if lock is not None:
|
132 |
+
lock.acquire()
|
133 |
+
if not end:
|
134 |
+
node.expand(action_probs)
|
135 |
+
else:
|
136 |
+
# for end state,return the "true" leaf_value
|
137 |
+
if winner == -1: # tie
|
138 |
+
leaf_value = 0.0
|
139 |
+
else:
|
140 |
+
leaf_value = (
|
141 |
+
1.0 if winner == state.get_current_player() else -1.0
|
142 |
+
)
|
143 |
+
|
144 |
+
# Update value and visit count of nodes in this traversal.
|
145 |
+
node.update_recursive(-leaf_value)
|
146 |
+
if lock is not None:
|
147 |
+
lock.release()
|
148 |
+
|
149 |
+
def get_move_probs(self, state, temp=1e-3):
|
150 |
+
"""Run all playouts sequentially and return the available actions and
|
151 |
+
their corresponding probabilities.
|
152 |
+
state: the current game state
|
153 |
+
temp: temperature parameter in (0, 1] controls the level of exploration
|
154 |
+
"""
|
155 |
+
|
156 |
+
start_time_averge = 0
|
157 |
+
|
158 |
+
### test multi-thread
|
159 |
+
lock = threading.Lock()
|
160 |
+
with ThreadPoolExecutor(max_workers=4) as executor:
|
161 |
+
for n in range(self._n_playout):
|
162 |
+
start_time = time.time()
|
163 |
+
|
164 |
+
state_copy = copy.deepcopy(state)
|
165 |
+
executor.submit(self._playout, state_copy, lock)
|
166 |
+
start_time_averge += (time.time() - start_time)
|
167 |
+
### end test multi-thread
|
168 |
+
|
169 |
+
# t = time.time()
|
170 |
+
# for n in range(self._n_playout):
|
171 |
+
# start_time = time.time()
|
172 |
+
|
173 |
+
# state_copy = copy.deepcopy(state)
|
174 |
+
# self._playout(state_copy)
|
175 |
+
# start_time_averge += (time.time() - start_time)
|
176 |
+
# print('!!time!!:', time.time() - t)
|
177 |
+
|
178 |
+
# print(f" My MCTS sum_time: {start_time_averge }, total_simulation: {self._n_playout}")
|
179 |
+
|
180 |
+
|
181 |
+
# calc the move probabilities based on visit counts at the root node
|
182 |
+
act_visits = [(act, node._n_visits)
|
183 |
+
for act, node in self._root._children.items()]
|
184 |
+
acts, visits = zip(*act_visits)
|
185 |
+
act_probs = softmax(1.0/temp * np.log(np.array(visits) + 1e-10))
|
186 |
+
|
187 |
+
return acts, act_probs
|
188 |
+
|
189 |
+
def update_with_move(self, last_move):
|
190 |
+
"""Step forward in the tree, keeping everything we already know
|
191 |
+
about the subtree.
|
192 |
+
"""
|
193 |
+
if last_move in self._root._children:
|
194 |
+
self._root = self._root._children[last_move]
|
195 |
+
self._root._parent = None
|
196 |
+
else:
|
197 |
+
self._root = TreeNode(None, 1.0)
|
198 |
+
|
199 |
+
def __str__(self):
|
200 |
+
return "MCTS"
|
201 |
+
|
202 |
+
|
203 |
+
class MCTSPlayer(object):
|
204 |
+
"""AI player based on MCTS"""
|
205 |
+
|
206 |
+
def __init__(self, policy_value_function,
|
207 |
+
c_puct=5, n_playout=2000, is_selfplay=0):
|
208 |
+
self.mcts = MCTS(policy_value_function, c_puct, n_playout)
|
209 |
+
self._is_selfplay = is_selfplay
|
210 |
+
|
211 |
+
def set_player_ind(self, p):
|
212 |
+
self.player = p
|
213 |
+
|
214 |
+
def reset_player(self):
|
215 |
+
self.mcts.update_with_move(-1)
|
216 |
+
|
217 |
+
def get_action(self, board, temp=1e-3, return_prob=0):
|
218 |
+
sensible_moves = board.availables
|
219 |
+
# the pi vector returned by MCTS as in the alphaGo Zero paper
|
220 |
+
move_probs = np.zeros(board.width*board.height)
|
221 |
+
if len(sensible_moves) > 0:
|
222 |
+
acts, probs = self.mcts.get_move_probs(board, temp)
|
223 |
+
move_probs[list(acts)] = probs
|
224 |
+
if self._is_selfplay:
|
225 |
+
# add Dirichlet Noise for exploration (needed for
|
226 |
+
# self-play training)
|
227 |
+
move = np.random.choice(
|
228 |
+
acts,
|
229 |
+
p=0.75*probs + 0.25*np.random.dirichlet(0.3*np.ones(len(probs)))
|
230 |
+
)
|
231 |
+
# update the root node and reuse the search tree
|
232 |
+
self.mcts.update_with_move(move)
|
233 |
+
else:
|
234 |
+
# with the default temp=1e-3, it is almost equivalent
|
235 |
+
# to choosing the move with the highest prob
|
236 |
+
move = np.random.choice(acts, p=probs)
|
237 |
+
# reset the root node
|
238 |
+
self.mcts.update_with_move(-1)
|
239 |
+
# location = board.move_to_location(move)
|
240 |
+
# print("AI move: %d,%d\n" % (location[0], location[1]))
|
241 |
+
|
242 |
+
if return_prob:
|
243 |
+
return move, move_probs
|
244 |
+
else:
|
245 |
+
return move
|
246 |
+
else:
|
247 |
+
print("WARNING: the board is full")
|
248 |
+
|
249 |
+
def __str__(self):
|
250 |
+
return "MCTS {}".format(self.player)
|
Gomoku_MCTS/mcts_pure.py
ADDED
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import copy
|
5 |
+
from operator import itemgetter
|
6 |
+
import time
|
7 |
+
|
8 |
+
|
9 |
+
def rollout_policy_fn(board):
|
10 |
+
"""a coarse, fast version of policy_fn used in the rollout phase."""
|
11 |
+
# rollout randomly
|
12 |
+
action_probs = np.random.rand(len(board.availables))
|
13 |
+
return zip(board.availables, action_probs)
|
14 |
+
|
15 |
+
# 决策价值函数
|
16 |
+
def policy_value_fn(board):
|
17 |
+
"""a function that takes in a state and outputs a list of (action, probability)
|
18 |
+
tuples and a score for the state"""
|
19 |
+
# return uniform probabilities and 0 score for pure MCTS
|
20 |
+
action_probs = np.ones(len(board.availables))/len(board.availables)
|
21 |
+
return zip(board.availables, action_probs), 0
|
22 |
+
|
23 |
+
|
24 |
+
class TreeNode(object):
|
25 |
+
"""A node in the MCTS tree. Each node keeps track of its own value Q,
|
26 |
+
prior probability P, and its visit-count-adjusted prior score u.
|
27 |
+
"""
|
28 |
+
|
29 |
+
def __init__(self, parent, prior_p):
|
30 |
+
self._parent = parent
|
31 |
+
self._children = {} # a map from action to TreeNode
|
32 |
+
self._n_visits = 0
|
33 |
+
self._Q = 0
|
34 |
+
self._u = 0
|
35 |
+
self._P = prior_p
|
36 |
+
|
37 |
+
def expand(self, action_priors):
|
38 |
+
"""Expand tree by creating new children.
|
39 |
+
action_priors: a list of tuples of actions and their prior probability
|
40 |
+
according to the policy function.
|
41 |
+
"""
|
42 |
+
for action, prob in action_priors:
|
43 |
+
if action not in self._children:
|
44 |
+
self._children[action] = TreeNode(self, prob)
|
45 |
+
|
46 |
+
def select(self, c_puct):
|
47 |
+
"""Select action among children that gives maximum action value Q
|
48 |
+
plus bonus u(P).
|
49 |
+
Return: A tuple of (action, next_node)
|
50 |
+
"""
|
51 |
+
return max(self._children.items(),
|
52 |
+
key=lambda act_node: act_node[1].get_value(c_puct))
|
53 |
+
|
54 |
+
def update(self, leaf_value):
|
55 |
+
"""Update node values from leaf evaluation.
|
56 |
+
leaf_value: the value of subtree evaluation from the current player's
|
57 |
+
perspective.
|
58 |
+
"""
|
59 |
+
# Count visit.
|
60 |
+
self._n_visits += 1
|
61 |
+
# Update Q, a running average of values for all visits.
|
62 |
+
# print("=====================================")
|
63 |
+
# print("Before, Q: {}, visits: {}, leaf_value: {}".format(self._Q, self._n_visits,leaf_value))
|
64 |
+
self._Q += 1.0*(leaf_value - self._Q) / self._n_visits
|
65 |
+
# print("After, Q: {}, visits: {}, leaf_value: {}".format(self._Q, self._n_visits,leaf_value))
|
66 |
+
|
67 |
+
|
68 |
+
def update_recursive(self, leaf_value):
|
69 |
+
"""Like a call to update(), but applied recursively for all ancestors.
|
70 |
+
"""
|
71 |
+
# If it is not root, this node's parent should be updated first.
|
72 |
+
if self._parent:
|
73 |
+
self._parent.update_recursive(-leaf_value)
|
74 |
+
self.update(leaf_value)
|
75 |
+
|
76 |
+
def get_value(self, c_puct):
|
77 |
+
"""Calculate and return the value for this node.
|
78 |
+
It is a combination of leaf evaluations Q, and this node's prior
|
79 |
+
adjusted for its visit count, u.
|
80 |
+
c_puct: a number in (0, inf) controlling the relative impact of
|
81 |
+
value Q, and prior probability P, on this node's score.
|
82 |
+
"""
|
83 |
+
self._u = (c_puct * self._P *
|
84 |
+
np.sqrt(self._parent._n_visits) / (1 + self._n_visits))
|
85 |
+
return self._Q + self._u
|
86 |
+
|
87 |
+
def is_leaf(self):
|
88 |
+
"""Check if leaf node (i.e. no nodes below this have been expanded).
|
89 |
+
"""
|
90 |
+
return self._children == {}
|
91 |
+
|
92 |
+
def is_root(self):
|
93 |
+
return self._parent is None
|
94 |
+
|
95 |
+
|
96 |
+
class MCTS(object):
|
97 |
+
"""A simple implementation of Monte Carlo Tree Search."""
|
98 |
+
|
99 |
+
def __init__(self, policy_value_fn, c_puct=5, n_playout=2000):
|
100 |
+
"""
|
101 |
+
policy_value_fn: a function that takes in a board state and outputs
|
102 |
+
a list of (action, probability) tuples and also a score in [-1, 1]
|
103 |
+
(i.e. the expected value of the end game score from the current
|
104 |
+
player's perspective) for the current player.
|
105 |
+
c_puct: a number in (0, inf) that controls how quickly exploration
|
106 |
+
converges to the maximum-value policy. A higher value means
|
107 |
+
relying on the prior more. ???
|
108 |
+
"""
|
109 |
+
self._root = TreeNode(None, 1.0)
|
110 |
+
self._policy = policy_value_fn
|
111 |
+
self._c_puct = c_puct
|
112 |
+
self._n_playout = n_playout
|
113 |
+
|
114 |
+
def _playout(self, state):
|
115 |
+
"""Run a single playout from the root to the leaf, getting a value at
|
116 |
+
the leaf and propagating it back through its parents.
|
117 |
+
State is modified in-place, so a copy must be provided.
|
118 |
+
"""
|
119 |
+
node = self._root
|
120 |
+
while(1):
|
121 |
+
if node.is_leaf():
|
122 |
+
|
123 |
+
break
|
124 |
+
# Greedily select next move.
|
125 |
+
action, node = node.select(self._c_puct)
|
126 |
+
state.do_move(action)
|
127 |
+
|
128 |
+
action_probs, _ = self._policy(state)
|
129 |
+
# Check for end of game
|
130 |
+
end, winner = state.game_end()
|
131 |
+
if not end:
|
132 |
+
node.expand(action_probs)
|
133 |
+
# Evaluate the leaf node by random rollout
|
134 |
+
leaf_value = self._evaluate_rollout(state)
|
135 |
+
# Update value and visit count of nodes in this traversal.
|
136 |
+
node.update_recursive(-leaf_value)
|
137 |
+
|
138 |
+
def _evaluate_rollout(self, state, limit=1000):
|
139 |
+
"""Use the rollout policy to play until the end of the game,
|
140 |
+
returning +1 if the current player wins, -1 if the opponent wins,
|
141 |
+
and 0 if it is a tie.
|
142 |
+
"""
|
143 |
+
player = state.get_current_player()
|
144 |
+
for i in range(limit):
|
145 |
+
end, winner = state.game_end()
|
146 |
+
if end:
|
147 |
+
break
|
148 |
+
action_probs = rollout_policy_fn(state)
|
149 |
+
max_action = max(action_probs, key=itemgetter(1))[0]
|
150 |
+
state.do_move(max_action)
|
151 |
+
else:
|
152 |
+
# If no break from the loop, issue a warning.
|
153 |
+
print("WARNING: rollout reached move limit")
|
154 |
+
if winner == -1: # tie
|
155 |
+
return 0
|
156 |
+
else:
|
157 |
+
return 1 if winner == player else -1
|
158 |
+
|
159 |
+
def get_move(self, state):
|
160 |
+
"""Runs all playouts sequentially and returns the most visited action.
|
161 |
+
state: the current game state
|
162 |
+
|
163 |
+
Return: the selected action
|
164 |
+
"""
|
165 |
+
start_time = time.time()
|
166 |
+
# n_playout 探索的次数
|
167 |
+
for n in range(self._n_playout):
|
168 |
+
state_copy = copy.deepcopy(state)
|
169 |
+
self._playout(state_copy)
|
170 |
+
|
171 |
+
need_time = time.time() - start_time
|
172 |
+
|
173 |
+
print(f" PureMCTS sum_time: {need_time / self._n_playout }, total_simulation: {self._n_playout}")
|
174 |
+
|
175 |
+
return max(self._root._children.items(),key=lambda act_node: act_node[1]._n_visits)[0], need_time / self._n_playout
|
176 |
+
|
177 |
+
def update_with_move(self, last_move):
|
178 |
+
"""Step forward in the tree, keeping everything we already know
|
179 |
+
about the subtree.
|
180 |
+
"""
|
181 |
+
if last_move in self._root._children:
|
182 |
+
self._root = self._root._children[last_move]
|
183 |
+
self._root._parent = None
|
184 |
+
else:
|
185 |
+
self._root = TreeNode(None, 1.0)
|
186 |
+
|
187 |
+
def __str__(self):
|
188 |
+
return "MCTS"
|
189 |
+
|
190 |
+
|
191 |
+
|
192 |
+
class MCTSPlayer(object):
|
193 |
+
"""AI player based on MCTS"""
|
194 |
+
def __init__(self, c_puct=5, n_playout=2000):
|
195 |
+
self.mcts = MCTS(policy_value_fn, c_puct, n_playout)
|
196 |
+
|
197 |
+
def set_player_ind(self, p):
|
198 |
+
self.player = p
|
199 |
+
|
200 |
+
def reset_player(self):
|
201 |
+
self.mcts.update_with_move(-1)
|
202 |
+
|
203 |
+
def get_action(self, board):
|
204 |
+
sensible_moves = board.availables
|
205 |
+
if len(sensible_moves) > 0:
|
206 |
+
move, simul_mean_time = self.mcts.get_move(board)
|
207 |
+
self.mcts.update_with_move(-1)
|
208 |
+
print("MCTS move:", move)
|
209 |
+
return move, simul_mean_time
|
210 |
+
else:
|
211 |
+
print("WARNING: the board is full")
|
212 |
+
|
213 |
+
|
214 |
+
def __str__(self):
|
215 |
+
return "MCTS {}".format(self.player)
|
216 |
+
|
217 |
+
|
218 |
+
# 多了下面这一串代码
|
219 |
+
|
220 |
+
class Human_Player(object):
|
221 |
+
def __init__(self):
|
222 |
+
pass
|
223 |
+
|
224 |
+
|
225 |
+
def set_player_ind(self, p):
|
226 |
+
self.player = p
|
227 |
+
|
228 |
+
|
229 |
+
def get_action(self, board):
|
230 |
+
|
231 |
+
|
232 |
+
sensible_moves = board.availables
|
233 |
+
if len(sensible_moves) > 0:
|
234 |
+
# print(sensible_moves)
|
235 |
+
|
236 |
+
move = int(input("Input the move:"))
|
237 |
+
while (move not in sensible_moves ):
|
238 |
+
print(sensible_moves)
|
239 |
+
move = int(input("Input the move again:"))
|
240 |
+
return move
|
241 |
+
else:
|
242 |
+
print("WARNING: the board is full")
|
243 |
+
|
244 |
+
def __str__(self):
|
245 |
+
return "Human {}".format(self.player)
|
246 |
+
|
Gomoku_MCTS/policy_value_net_pytorch.py
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
An implementation of the policyValueNet in PyTorch
|
4 |
+
Tested in PyTorch 0.2.0 and 0.3.0
|
5 |
+
|
6 |
+
@author: Junxiao Song
|
7 |
+
"""
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.optim as optim
|
12 |
+
import torch.nn.functional as F
|
13 |
+
from torch.autograd import Variable
|
14 |
+
import numpy as np
|
15 |
+
|
16 |
+
|
17 |
+
|
18 |
+
|
19 |
+
|
20 |
+
class Net(nn.Module):
|
21 |
+
"""policy-value network module"""
|
22 |
+
def __init__(self, board_width, board_height):
|
23 |
+
super(Net, self).__init__()
|
24 |
+
|
25 |
+
self.board_width = board_width
|
26 |
+
self.board_height = board_height
|
27 |
+
# common layers
|
28 |
+
self.conv1 = nn.Conv2d(4, 32, kernel_size=3, padding=1)
|
29 |
+
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
|
30 |
+
self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
|
31 |
+
# action policy layers
|
32 |
+
self.act_conv1 = nn.Conv2d(128, 4, kernel_size=1)
|
33 |
+
self.act_fc1 = nn.Linear(4*board_width*board_height,
|
34 |
+
board_width*board_height)
|
35 |
+
# state value layers
|
36 |
+
self.val_conv1 = nn.Conv2d(128, 2, kernel_size=1)
|
37 |
+
self.val_fc1 = nn.Linear(2*board_width*board_height, 64)
|
38 |
+
self.val_fc2 = nn.Linear(64, 1)
|
39 |
+
|
40 |
+
def forward(self, state_input):
|
41 |
+
# common layers
|
42 |
+
x = F.relu(self.conv1(state_input))
|
43 |
+
x = F.relu(self.conv2(x))
|
44 |
+
x = F.relu(self.conv3(x))
|
45 |
+
# action policy layers
|
46 |
+
x_act = F.relu(self.act_conv1(x))
|
47 |
+
x_act = x_act.view(-1, 4*self.board_width*self.board_height)
|
48 |
+
x_act = F.log_softmax(self.act_fc1(x_act))
|
49 |
+
# state value layers
|
50 |
+
x_val = F.relu(self.val_conv1(x))
|
51 |
+
x_val = x_val.view(-1, 2*self.board_width*self.board_height)
|
52 |
+
x_val = F.relu(self.val_fc1(x_val))
|
53 |
+
x_val = F.tanh(self.val_fc2(x_val))
|
54 |
+
return x_act, x_val
|
55 |
+
|
56 |
+
|
57 |
+
class PolicyValueNet():
|
58 |
+
"""policy-value network """
|
59 |
+
def __init__(self, board_width, board_height,
|
60 |
+
model_file=None, use_gpu=False):
|
61 |
+
self.use_gpu = use_gpu
|
62 |
+
self.board_width = board_width
|
63 |
+
self.board_height = board_height
|
64 |
+
|
65 |
+
# the policy value net module
|
66 |
+
if self.use_gpu:
|
67 |
+
self.policy_value_net = Net(board_width, board_height).cuda()
|
68 |
+
else:
|
69 |
+
self.policy_value_net = Net(board_width, board_height)
|
70 |
+
|
71 |
+
if model_file:
|
72 |
+
net_params = torch.load(model_file)
|
73 |
+
self.policy_value_net.load_state_dict(net_params)
|
74 |
+
|
75 |
+
def policy_value(self, state_batch):
|
76 |
+
"""
|
77 |
+
input: a batch of states
|
78 |
+
output: a batch of action probabilities and state values
|
79 |
+
"""
|
80 |
+
if self.use_gpu:
|
81 |
+
state_batch = Variable(torch.FloatTensor(state_batch).cuda())
|
82 |
+
log_act_probs, value = self.policy_value_net(state_batch)
|
83 |
+
act_probs = np.exp(log_act_probs.data.cpu().numpy())
|
84 |
+
return act_probs, value.data.cpu().numpy()
|
85 |
+
else:
|
86 |
+
state_batch = Variable(torch.FloatTensor(state_batch))
|
87 |
+
log_act_probs, value = self.policy_value_net(state_batch)
|
88 |
+
act_probs = np.exp(log_act_probs.data.numpy())
|
89 |
+
return act_probs, value.data.numpy()
|
90 |
+
|
91 |
+
def policy_value_fn(self, board):
|
92 |
+
"""
|
93 |
+
input: board
|
94 |
+
output: a list of (action, probability) tuples for each available
|
95 |
+
action and the score of the board state
|
96 |
+
"""
|
97 |
+
legal_positions = board.availables
|
98 |
+
current_state = np.ascontiguousarray(board.current_state().reshape(
|
99 |
+
-1, 4, self.board_width, self.board_height))
|
100 |
+
if self.use_gpu:
|
101 |
+
log_act_probs, value = self.policy_value_net(
|
102 |
+
Variable(torch.from_numpy(current_state)).cuda().float())
|
103 |
+
act_probs = np.exp(log_act_probs.data.cpu().numpy().flatten())
|
104 |
+
else:
|
105 |
+
log_act_probs, value = self.policy_value_net(
|
106 |
+
Variable(torch.from_numpy(current_state)).float())
|
107 |
+
act_probs = np.exp(log_act_probs.data.numpy().flatten())
|
108 |
+
act_probs = zip(legal_positions, act_probs[legal_positions])
|
109 |
+
value = value.data[0][0]
|
110 |
+
return act_probs, value
|
111 |
+
|
112 |
+
|
113 |
+
# 搬到main_worker
|
114 |
+
|
115 |
+
def train_step(self, state_batch, mcts_probs, winner_batch, lr):
|
116 |
+
"""perform a training step"""
|
117 |
+
|
118 |
+
# self.use_gpu = True
|
119 |
+
# wrap in Variable
|
120 |
+
if self.use_gpu:
|
121 |
+
state_batch = Variable(torch.FloatTensor(state_batch).cuda())
|
122 |
+
mcts_probs = Variable(torch.FloatTensor(mcts_probs).cuda())
|
123 |
+
winner_batch = Variable(torch.FloatTensor(winner_batch).cuda())
|
124 |
+
else:
|
125 |
+
state_batch = Variable(torch.FloatTensor(state_batch))
|
126 |
+
mcts_probs = Variable(torch.FloatTensor(mcts_probs))
|
127 |
+
winner_batch = Variable(torch.FloatTensor(winner_batch))
|
128 |
+
|
129 |
+
# zero the parameter gradients
|
130 |
+
self.optimizer.zero_grad()
|
131 |
+
# set learning rate
|
132 |
+
set_learning_rate(self.optimizer, lr)
|
133 |
+
|
134 |
+
# forward
|
135 |
+
log_act_probs, value = self.policy_value_net(state_batch)
|
136 |
+
# define the loss = (z - v)^2 - pi^T * log(p) + c||theta||^2
|
137 |
+
# Note: the L2 penalty is incorporated in optimizer
|
138 |
+
value_loss = F.mse_loss(value.view(-1), winner_batch)
|
139 |
+
policy_loss = -torch.mean(torch.sum(mcts_probs*log_act_probs, 1))
|
140 |
+
loss = value_loss + policy_loss
|
141 |
+
# backward and optimize
|
142 |
+
loss.backward()
|
143 |
+
self.optimizer.step()
|
144 |
+
# calc policy entropy, for monitoring only
|
145 |
+
entropy = -torch.mean(
|
146 |
+
torch.sum(torch.exp(log_act_probs) * log_act_probs, 1)
|
147 |
+
)
|
148 |
+
# return loss.data[0], entropy.data[0]
|
149 |
+
#for pytorch version >= 0.5 please use the following line instead.
|
150 |
+
return loss.item(), entropy.item()
|
151 |
+
|
152 |
+
# def get_policy_param(self):
|
153 |
+
# net_params = self.policy_value_net.state_dict()
|
154 |
+
# return net_params
|
155 |
+
|
156 |
+
# def save_model(self, model_file):
|
157 |
+
# """ save model params to file """
|
158 |
+
# net_params = self.get_policy_param() # get model params
|
159 |
+
# torch.save(net_params, model_file)
|
Gomoku_MCTS/visualization/_blip_uni_cross_mu_bs512_lr0.002/events.out.tfevents.1700183498.LAPTOP-5AN2UHOO
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f64431c679947fac92ef87e7f3d3b6a75c0cdf82e6fd0383451a98d778b7b21e
|
3 |
+
size 40
|
Gomoku_MCTS/visualization/_blip_uni_cross_mu_bs512_lr0.002/events.out.tfevents.1700183516.LAPTOP-5AN2UHOO
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:aaa7025d5d1daa88dce58231e0fba4d7a04391612c696e4c2e23292ad4169d80
|
3 |
+
size 40
|
Gomoku_MCTS/visualization/_blip_uni_cross_mu_bs512_lr0.002/events.out.tfevents.1700183568.LAPTOP-5AN2UHOO
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:440fac7c819d3368da1126c35e1a146b4ec3a3e614cb3c6e7e10063f9f0ced3c
|
3 |
+
size 40
|
Gomoku_MCTS/visualization/_blip_uni_cross_mu_bs512_lr0.002/events.out.tfevents.1700183629.LAPTOP-5AN2UHOO
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e45fb4fd64ac1d0a3ec9f5376d2122b48aa9c0a56e01ccfdc0a4ea0ed22188ed
|
3 |
+
size 40
|
Gomoku_MCTS/visualization/_blip_uni_cross_mu_bs512_lr0.002/events.out.tfevents.1700183640.LAPTOP-5AN2UHOO
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:59b89d395abdaa5c2b8cb0922f4a465a9f06c59a429697c4d138e58033e6e1a0
|
3 |
+
size 40
|
Gomoku_MCTS/visualization/_blip_uni_cross_mu_bs512_lr0.002/events.out.tfevents.1700183667.LAPTOP-5AN2UHOO
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e2ffadd87d44bba87f7bcd80fb424959536ff24e7a4e52a67238200c691befac
|
3 |
+
size 40
|
Gomoku_MCTS/visualization/_blip_uni_cross_mu_bs512_lr0.002/events.out.tfevents.1700183756.LAPTOP-5AN2UHOO
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a91a4478ed17986eaf70acf7a0fb3fe0db11cbcbe8eedf7655bfad9e6a4a9650
|
3 |
+
size 40
|
Gomoku_MCTS/visualization/_blip_uni_cross_mu_bs512_lr0.002/events.out.tfevents.1700183820.LAPTOP-5AN2UHOO
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2641a889839f3eaf83a9ee90b4bdf0073488a9416044507d321a7bfc8bbad83f
|
3 |
+
size 40
|
Gomoku_MCTS/visualization/_blip_uni_cross_mu_bs512_lr0.002/events.out.tfevents.1700184097.LAPTOP-5AN2UHOO
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c33c3e165e65aac3f1f75ddf5a1a4a3fc6e494e5be728a3a455ff453c7a40100
|
3 |
+
size 3726
|
README.md
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
colorFrom: green
|
5 |
-
colorTo:
|
6 |
sdk: streamlit
|
7 |
sdk_version: 1.28.2
|
8 |
app_file: app.py
|
|
|
1 |
---
|
2 |
+
title: Gomoku Zero
|
3 |
+
emoji: 📉
|
4 |
colorFrom: green
|
5 |
+
colorTo: blue
|
6 |
sdk: streamlit
|
7 |
sdk_version: 1.28.2
|
8 |
app_file: app.py
|
app.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
# 设置页面配置
|
3 |
+
st.set_page_config(
|
4 |
+
page_title="AI 3603 Gomoku Project",
|
5 |
+
page_icon="👋",
|
6 |
+
layout="wide",
|
7 |
+
initial_sidebar_state="collapsed"
|
8 |
+
)
|
9 |
+
# 大标题
|
10 |
+
st.write('<h1 style="text-align: center; color: black; font-weight: bold;">AI 3603 Gomoku Project 👋</h1>', unsafe_allow_html=True)
|
11 |
+
# 项目参与者
|
12 |
+
st.write('<p style="text-align: center; font-size: 20px;"><a href="https://github.com" style="color: blue; font-weight: normal; margin-right: 20px; text-decoration: none;">Jiaxin Li</a> \
|
13 |
+
<a href="https://github.com" style="color: blue; font-weight: normal; margin-right: 20px; text-decoration: none;">Junzhe Shen</a> \
|
14 |
+
<a href="https://github.com" style="color: blue; font-weight: normal; text-decoration: none;">Benhao Huang</a></p>', unsafe_allow_html=True)
|
15 |
+
# 标签
|
16 |
+
st.markdown("""
|
17 |
+
<div style="text-align: center;">
|
18 |
+
<a href="#" style="background-color: #343a40; color: white; font-size: 15px; padding: 10px 15px; margin: 5px; border-radius: 15px; text-decoration: none;">📄 Report</a>
|
19 |
+
<a href="#" style="background-color: #343a40; color: white; font-size: 15px; padding: 10px 15px; margin: 5px; border-radius: 15px; text-decoration: none;">💻 Code</a>
|
20 |
+
<a href="#" style="background-color: #343a40; color: white; font-size: 15px; padding: 10px 15px; margin: 5px; border-radius: 15px; text-decoration: none;">🌐 Space</a>
|
21 |
+
<a href="#" style="background-color: #343a40; color: white; font-size: 15px; padding: 10px 15px; margin: 5px; border-radius: 15px; text-decoration: none;">📊 PPT</a>
|
22 |
+
</div>
|
23 |
+
</br>
|
24 |
+
</br>
|
25 |
+
""", unsafe_allow_html=True)
|
26 |
+
# 项目介绍
|
27 |
+
st.markdown("""
|
28 |
+
<div style='color: black; font-size:18px'>Gomoku is an abstract strategy board game. Also called <span style='color:red;'>Gobang</span> or <span style='color:red;'>Five in a Row</span>,
|
29 |
+
it is traditionally played with Go pieces (black and white stones)
|
30 |
+
on a Go board. It is straightforward and fun, but also full of strategy and challenge.
|
31 |
+
Our project is aiming to apply Machine Learning techniques to build a powerful Gomoku AI.</div>
|
32 |
+
""",
|
33 |
+
unsafe_allow_html=True)
|
34 |
+
# 创新点和图片展示
|
35 |
+
st.write("<h2 style='text-align: center; color: black; font-weight: bold;'>Innovations We Made 👍</h2>", unsafe_allow_html=True)
|
36 |
+
col1, col2, col3 = st.columns(3)
|
37 |
+
with col1:
|
38 |
+
st.image("assets/favicon_circle.png", width=50) # 替换为你的图片 URL
|
39 |
+
st.caption("Innovation 1")
|
40 |
+
with col2:
|
41 |
+
st.image("assets/favicon_circle.png", width=50) # 替换为你的图片 URL
|
42 |
+
st.caption("Innovation 2")
|
43 |
+
with col3:
|
44 |
+
st.image("assets/favicon_circle.png", width=50) # 替换为你的图片 URL
|
45 |
+
st.caption("Innovation 3")
|
46 |
+
# 代码框架阐述和代码组件
|
47 |
+
st.write("<h2 style='text-align: center; color: black; font-weight: bold;'>Code Structure 🛠️</h2>", unsafe_allow_html=True)
|
48 |
+
st.code("""
|
49 |
+
import os
|
50 |
+
import streamlit as st
|
51 |
+
def main():
|
52 |
+
# your code here
|
53 |
+
if __name__ == "__main__":
|
54 |
+
main()
|
55 |
+
""", language="python")
|
56 |
+
st.markdown("---")
|
assets/favicon_circle.png
ADDED
const.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
_BOARD_SIZE = 8
|
4 |
+
_BOARD_SIZE_1D = _BOARD_SIZE * _BOARD_SIZE
|
5 |
+
_BLANK = 0
|
6 |
+
_BLACK = 1
|
7 |
+
_WHITE = 2
|
8 |
+
_PLAYER_SYMBOL = {
|
9 |
+
_WHITE: "⚪",
|
10 |
+
_BLANK: "➕",
|
11 |
+
_BLACK: "⚫",
|
12 |
+
}
|
13 |
+
_PLAYER_COLOR = {
|
14 |
+
_WHITE: "AI",
|
15 |
+
_BLANK: "Blank",
|
16 |
+
_BLACK: "YOU HUMAN",
|
17 |
+
}
|
18 |
+
_HORIZONTAL = np.array(
|
19 |
+
[
|
20 |
+
[0, 0, 0, 0, 0],
|
21 |
+
[0, 0, 0, 0, 0],
|
22 |
+
[1, 1, 1, 1, 1],
|
23 |
+
[0, 0, 0, 0, 0],
|
24 |
+
[0, 0, 0, 0, 0],
|
25 |
+
]
|
26 |
+
)
|
27 |
+
_VERTICAL = np.array(
|
28 |
+
[
|
29 |
+
[0, 0, 1, 0, 0],
|
30 |
+
[0, 0, 1, 0, 0],
|
31 |
+
[0, 0, 1, 0, 0],
|
32 |
+
[0, 0, 1, 0, 0],
|
33 |
+
[0, 0, 1, 0, 0],
|
34 |
+
]
|
35 |
+
)
|
36 |
+
_DIAGONAL_UP_LEFT = np.array(
|
37 |
+
[
|
38 |
+
[1, 0, 0, 0, 0],
|
39 |
+
[0, 1, 0, 0, 0],
|
40 |
+
[0, 0, 1, 0, 0],
|
41 |
+
[0, 0, 0, 1, 0],
|
42 |
+
[0, 0, 0, 0, 1],
|
43 |
+
]
|
44 |
+
)
|
45 |
+
_DIAGONAL_UP_RIGHT = np.array(
|
46 |
+
[
|
47 |
+
[0, 0, 0, 0, 1],
|
48 |
+
[0, 0, 0, 1, 0],
|
49 |
+
[0, 0, 1, 0, 0],
|
50 |
+
[0, 1, 0, 0, 0],
|
51 |
+
[1, 0, 0, 0, 0],
|
52 |
+
]
|
53 |
+
)
|
54 |
+
|
55 |
+
_ROOM_COLOR = {
|
56 |
+
True: _BLACK,
|
57 |
+
False: _WHITE,
|
58 |
+
}
|
pages/Player_VS_AI.py
ADDED
@@ -0,0 +1,409 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
FileName: app.py
|
3 |
+
Author: Benhao Huang
|
4 |
+
Create Date: 2023/11/18
|
5 |
+
Description: this file is used to display our project and add visualization elements to the game, using Streamlit
|
6 |
+
"""
|
7 |
+
|
8 |
+
import time
|
9 |
+
import pandas as pd
|
10 |
+
from copy import deepcopy
|
11 |
+
|
12 |
+
# import torch
|
13 |
+
import numpy as np
|
14 |
+
import streamlit as st
|
15 |
+
from scipy.signal import convolve # this is used to check if any player wins
|
16 |
+
from streamlit import session_state
|
17 |
+
from streamlit_server_state import server_state, server_state_lock
|
18 |
+
from Gomoku_MCTS import MCTSpure, alphazero, Board, PolicyValueNet
|
19 |
+
import matplotlib.pyplot as plt
|
20 |
+
|
21 |
+
from const import (
|
22 |
+
_BLACK, # 1, for human
|
23 |
+
_WHITE, # 2 , for AI
|
24 |
+
_BLANK,
|
25 |
+
_PLAYER_COLOR,
|
26 |
+
_PLAYER_SYMBOL,
|
27 |
+
_ROOM_COLOR,
|
28 |
+
_VERTICAL,
|
29 |
+
_HORIZONTAL,
|
30 |
+
_DIAGONAL_UP_LEFT,
|
31 |
+
_DIAGONAL_UP_RIGHT,
|
32 |
+
_BOARD_SIZE,
|
33 |
+
_BOARD_SIZE_1D
|
34 |
+
)
|
35 |
+
|
36 |
+
|
37 |
+
# Utils
|
38 |
+
class Room:
|
39 |
+
def __init__(self, room_id) -> None:
|
40 |
+
self.ROOM_ID = room_id
|
41 |
+
# self.BOARD = np.zeros(shape=(_BOARD_SIZE, _BOARD_SIZE), dtype=int)
|
42 |
+
self.BOARD = Board(width=_BOARD_SIZE, height=_BOARD_SIZE, n_in_row=5, players=[_BLACK, _WHITE])
|
43 |
+
self.PLAYER = _BLACK
|
44 |
+
self.TURN = self.PLAYER
|
45 |
+
self.HISTORY = (0, 0)
|
46 |
+
self.WINNER = _BLANK
|
47 |
+
self.TIME = time.time()
|
48 |
+
self.MCTS = MCTSpure(c_puct=5, n_playout=10)
|
49 |
+
self.COORDINATE_1D = [_BOARD_SIZE_1D + 1]
|
50 |
+
self.current_move = -1
|
51 |
+
self.simula_time_list = []
|
52 |
+
|
53 |
+
|
54 |
+
def change_turn(cur):
|
55 |
+
return cur % 2 + 1
|
56 |
+
|
57 |
+
|
58 |
+
# Initialize the game
|
59 |
+
if "ROOM" not in session_state:
|
60 |
+
session_state.ROOM = Room("local")
|
61 |
+
if "OWNER" not in session_state:
|
62 |
+
session_state.OWNER = False
|
63 |
+
|
64 |
+
# Check server health
|
65 |
+
if "ROOMS" not in server_state:
|
66 |
+
with server_state_lock["ROOMS"]:
|
67 |
+
server_state.ROOMS = {}
|
68 |
+
|
69 |
+
# # Layout
|
70 |
+
# Main
|
71 |
+
TITLE = st.empty()
|
72 |
+
TITLE.header("🤖 AI 3603 Gomoku")
|
73 |
+
ROUND_INFO = st.empty()
|
74 |
+
st.markdown("<br>", unsafe_allow_html=True)
|
75 |
+
BOARD_PLATE = [
|
76 |
+
[cell.empty() for cell in st.columns([1 for _ in range(_BOARD_SIZE)])] for _ in range(_BOARD_SIZE)
|
77 |
+
]
|
78 |
+
LOG = st.empty()
|
79 |
+
|
80 |
+
# Sidebar
|
81 |
+
SCORE_TAG = st.sidebar.empty()
|
82 |
+
SCORE_PLATE = st.sidebar.columns(2)
|
83 |
+
# History scores
|
84 |
+
SCORE_TAG.subheader("Scores")
|
85 |
+
|
86 |
+
PLAY_MODE_INFO = st.sidebar.container()
|
87 |
+
MULTIPLAYER_TAG = st.sidebar.empty()
|
88 |
+
with st.sidebar.container():
|
89 |
+
ANOTHER_ROUND = st.empty()
|
90 |
+
RESTART = st.empty()
|
91 |
+
EXIT = st.empty()
|
92 |
+
GAME_INFO = st.sidebar.container()
|
93 |
+
message = st.empty()
|
94 |
+
PLAY_MODE_INFO.write("---\n\n**You are Black, AI agent is White.**")
|
95 |
+
GAME_INFO.markdown(
|
96 |
+
"""
|
97 |
+
---
|
98 |
+
# <span style="color:black;">Freestyle Gomoku game. 🎲</span>
|
99 |
+
- no restrictions 🚫
|
100 |
+
- no regrets 😎
|
101 |
+
- swap players after one round is over 🔁
|
102 |
+
Powered by an AlphaZero approach with our own improvements! 🚀 For the specific details, please check out our <a href="insert_report_link_here" style="color:blue;">report</a>.
|
103 |
+
##### Adapted and improved by us! 🌟 <a href="https://github.com/Lijiaxin0111/AI_3603_BIGHOME" style="color:blue;">Our Github repo</a>
|
104 |
+
""",
|
105 |
+
unsafe_allow_html=True,
|
106 |
+
)
|
107 |
+
|
108 |
+
|
109 |
+
def restart() -> None:
|
110 |
+
"""
|
111 |
+
Restart the game.
|
112 |
+
"""
|
113 |
+
session_state.ROOM = Room(session_state.ROOM.ROOM_ID)
|
114 |
+
|
115 |
+
|
116 |
+
RESTART.button(
|
117 |
+
"Reset",
|
118 |
+
on_click=restart,
|
119 |
+
help="Clear the board as well as the scores",
|
120 |
+
)
|
121 |
+
|
122 |
+
|
123 |
+
# Draw the board
|
124 |
+
def gomoku():
|
125 |
+
"""
|
126 |
+
Draw the board.
|
127 |
+
Handle the main logic.
|
128 |
+
"""
|
129 |
+
|
130 |
+
# Restart the game
|
131 |
+
|
132 |
+
# Continue new round
|
133 |
+
def another_round() -> None:
|
134 |
+
"""
|
135 |
+
Continue new round.
|
136 |
+
"""
|
137 |
+
session_state.ROOM = deepcopy(session_state.ROOM)
|
138 |
+
session_state.ROOM.BOARD = Board(width=_BOARD_SIZE, height=_BOARD_SIZE, n_in_row=5)
|
139 |
+
session_state.ROOM.PLAYER = session_state.ROOM.PLAYER % 2 + 1
|
140 |
+
session_state.ROOM.TURN = session_state.ROOM.PLAYER
|
141 |
+
session_state.ROOM.WINNER = _BLANK # 0
|
142 |
+
session_state.ROOM.COORDINATE_1D = [_BOARD_SIZE_1D + 1]
|
143 |
+
|
144 |
+
# Room status sync
|
145 |
+
def sync_room() -> bool:
|
146 |
+
room_id = session_state.ROOM.ROOM_ID
|
147 |
+
if room_id not in server_state.ROOMS.keys():
|
148 |
+
session_state.ROOM = Room("local")
|
149 |
+
return False
|
150 |
+
elif server_state.ROOMS[room_id].TIME == session_state.ROOM.TIME:
|
151 |
+
return False
|
152 |
+
elif server_state.ROOMS[room_id].TIME < session_state.ROOM.TIME:
|
153 |
+
# Only acquire the lock when writing to the server state
|
154 |
+
with server_state_lock["ROOMS"]:
|
155 |
+
server_rooms = server_state.ROOMS
|
156 |
+
server_rooms[room_id] = session_state.ROOM
|
157 |
+
server_state.ROOMS = server_rooms
|
158 |
+
return True
|
159 |
+
else:
|
160 |
+
session_state.ROOM = server_state.ROOMS[room_id]
|
161 |
+
return True
|
162 |
+
|
163 |
+
# Check if winner emerge from move
|
164 |
+
def check_win() -> int:
|
165 |
+
"""
|
166 |
+
Use convolution to check if any player wins.
|
167 |
+
"""
|
168 |
+
vertical = convolve(
|
169 |
+
session_state.ROOM.BOARD.board_map,
|
170 |
+
_VERTICAL,
|
171 |
+
mode="same",
|
172 |
+
)
|
173 |
+
horizontal = convolve(
|
174 |
+
session_state.ROOM.BOARD.board_map,
|
175 |
+
_HORIZONTAL,
|
176 |
+
mode="same",
|
177 |
+
)
|
178 |
+
diagonal_up_left = convolve(
|
179 |
+
session_state.ROOM.BOARD.board_map,
|
180 |
+
_DIAGONAL_UP_LEFT,
|
181 |
+
mode="same",
|
182 |
+
)
|
183 |
+
diagonal_up_right = convolve(
|
184 |
+
session_state.ROOM.BOARD.board_map,
|
185 |
+
_DIAGONAL_UP_RIGHT,
|
186 |
+
mode="same",
|
187 |
+
)
|
188 |
+
if (
|
189 |
+
np.max(
|
190 |
+
[
|
191 |
+
np.max(vertical),
|
192 |
+
np.max(horizontal),
|
193 |
+
np.max(diagonal_up_left),
|
194 |
+
np.max(diagonal_up_right),
|
195 |
+
]
|
196 |
+
)
|
197 |
+
== 5 * _BLACK
|
198 |
+
):
|
199 |
+
winner = _BLACK
|
200 |
+
elif (
|
201 |
+
np.min(
|
202 |
+
[
|
203 |
+
np.min(vertical),
|
204 |
+
np.min(horizontal),
|
205 |
+
np.min(diagonal_up_left),
|
206 |
+
np.min(diagonal_up_right),
|
207 |
+
]
|
208 |
+
)
|
209 |
+
== 5 * _WHITE
|
210 |
+
):
|
211 |
+
winner = _WHITE
|
212 |
+
else:
|
213 |
+
winner = _BLANK
|
214 |
+
return winner
|
215 |
+
|
216 |
+
# Triggers the board response on click
|
217 |
+
def handle_click(x, y):
|
218 |
+
"""
|
219 |
+
Controls whether to pass on / continue current board / may start new round
|
220 |
+
"""
|
221 |
+
if session_state.ROOM.BOARD.board_map[x][y] != _BLANK:
|
222 |
+
pass
|
223 |
+
elif (
|
224 |
+
session_state.ROOM.ROOM_ID in server_state.ROOMS.keys()
|
225 |
+
and _ROOM_COLOR[session_state.OWNER]
|
226 |
+
!= server_state.ROOMS[session_state.ROOM.ROOM_ID].TURN
|
227 |
+
):
|
228 |
+
sync_room()
|
229 |
+
|
230 |
+
# normal play situation
|
231 |
+
elif session_state.ROOM.WINNER == _BLANK:
|
232 |
+
# session_state.ROOM = deepcopy(session_state.ROOM)
|
233 |
+
print("View of human player: ", session_state.ROOM.BOARD.board_map)
|
234 |
+
move = session_state.ROOM.BOARD.location_to_move((x, y))
|
235 |
+
session_state.ROOM.current_move = move
|
236 |
+
session_state.ROOM.BOARD.do_move(move)
|
237 |
+
session_state.ROOM.BOARD.board_map[x][y] = session_state.ROOM.TURN
|
238 |
+
session_state.ROOM.COORDINATE_1D.append(x * _BOARD_SIZE + y)
|
239 |
+
|
240 |
+
session_state.ROOM.TURN = change_turn(session_state.ROOM.TURN)
|
241 |
+
win, winner = session_state.ROOM.BOARD.game_end()
|
242 |
+
if win:
|
243 |
+
session_state.ROOM.WINNER = winner
|
244 |
+
session_state.ROOM.HISTORY = (
|
245 |
+
session_state.ROOM.HISTORY[0]
|
246 |
+
+ int(session_state.ROOM.WINNER == _WHITE),
|
247 |
+
session_state.ROOM.HISTORY[1]
|
248 |
+
+ int(session_state.ROOM.WINNER == _BLACK),
|
249 |
+
)
|
250 |
+
session_state.ROOM.TIME = time.time()
|
251 |
+
|
252 |
+
def forbid_click(x, y):
|
253 |
+
# st.warning('This posistion has been occupied!!!!', icon="⚠️")
|
254 |
+
st.error("({}, {}) has been occupied!!)".format(x, y), icon="🚨")
|
255 |
+
print("asdas")
|
256 |
+
|
257 |
+
# Draw board
|
258 |
+
def draw_board(response: bool):
|
259 |
+
"""construct each buttons for all cells of the board"""
|
260 |
+
|
261 |
+
if response and session_state.ROOM.TURN == _BLACK: # human turn
|
262 |
+
print("Your turn")
|
263 |
+
# construction of clickable buttons
|
264 |
+
for i, row in enumerate(session_state.ROOM.BOARD.board_map):
|
265 |
+
# print("row:", row)
|
266 |
+
for j, cell in enumerate(row):
|
267 |
+
if (
|
268 |
+
i * _BOARD_SIZE + j
|
269 |
+
in (session_state.ROOM.COORDINATE_1D)
|
270 |
+
):
|
271 |
+
# disable click for GPT choices
|
272 |
+
BOARD_PLATE[i][j].button(
|
273 |
+
_PLAYER_SYMBOL[cell],
|
274 |
+
key=f"{i}:{j}",
|
275 |
+
args=(i, j),
|
276 |
+
on_click=forbid_click
|
277 |
+
)
|
278 |
+
else:
|
279 |
+
# enable click for other cells available for human choices
|
280 |
+
BOARD_PLATE[i][j].button(
|
281 |
+
_PLAYER_SYMBOL[cell],
|
282 |
+
key=f"{i}:{j}",
|
283 |
+
on_click=handle_click,
|
284 |
+
args=(i, j),
|
285 |
+
)
|
286 |
+
|
287 |
+
|
288 |
+
elif response and session_state.ROOM.TURN == _WHITE: # AI turn
|
289 |
+
message.empty()
|
290 |
+
with st.spinner('🔮✨ Waiting for AI response... ⏳🚀'):
|
291 |
+
time.sleep(0.1)
|
292 |
+
print("AI's turn")
|
293 |
+
print("Below are current board under AI's view")
|
294 |
+
print(session_state.ROOM.BOARD.board_map)
|
295 |
+
move, simul_time = session_state.ROOM.MCTS.get_action(session_state.ROOM.BOARD)
|
296 |
+
session_state.ROOM.simula_time_list.append(simul_time)
|
297 |
+
print("AI takes move: ", move)
|
298 |
+
session_state.ROOM.current_move = move
|
299 |
+
gpt_response = move
|
300 |
+
gpt_i, gpt_j = gpt_response // _BOARD_SIZE, gpt_response % _BOARD_SIZE
|
301 |
+
print("AI's move is located at ({}, {}) :".format(gpt_i, gpt_j))
|
302 |
+
move = session_state.ROOM.BOARD.location_to_move((gpt_i, gpt_j))
|
303 |
+
print("Location to move: ", move)
|
304 |
+
session_state.ROOM.BOARD.do_move(move)
|
305 |
+
# session_state.ROOM.BOARD[gpt_i][gpt_j] = session_state.ROOM.TURN
|
306 |
+
session_state.ROOM.COORDINATE_1D.append(gpt_i * _BOARD_SIZE + gpt_j)
|
307 |
+
|
308 |
+
# construction of clickable buttons
|
309 |
+
for i, row in enumerate(session_state.ROOM.BOARD.board_map):
|
310 |
+
# print("row:", row)
|
311 |
+
for j, cell in enumerate(row):
|
312 |
+
if (
|
313 |
+
i * _BOARD_SIZE + j
|
314 |
+
in (session_state.ROOM.COORDINATE_1D)
|
315 |
+
):
|
316 |
+
# disable click for GPT choices
|
317 |
+
BOARD_PLATE[i][j].button(
|
318 |
+
_PLAYER_SYMBOL[cell],
|
319 |
+
key=f"{i}:{j}",
|
320 |
+
args=(i, j),
|
321 |
+
on_click=forbid_click
|
322 |
+
)
|
323 |
+
else:
|
324 |
+
# enable click for other cells available for human choices
|
325 |
+
BOARD_PLATE[i][j].button(
|
326 |
+
_PLAYER_SYMBOL[cell],
|
327 |
+
key=f"{i}:{j}",
|
328 |
+
on_click=handle_click,
|
329 |
+
args=(i, j),
|
330 |
+
)
|
331 |
+
|
332 |
+
message.markdown(
|
333 |
+
'AI agent has calculated its strategy, which takes <span style="color: blue; font-size: 20px;">{:.3e}</span>s per simulation.'.format(
|
334 |
+
simul_time),
|
335 |
+
unsafe_allow_html=True
|
336 |
+
)
|
337 |
+
LOG.subheader("Logs")
|
338 |
+
# change turn
|
339 |
+
session_state.ROOM.TURN = change_turn(session_state.ROOM.TURN)
|
340 |
+
# session_state.ROOM.WINNER = check_win()
|
341 |
+
|
342 |
+
win, winner = session_state.ROOM.BOARD.game_end()
|
343 |
+
if win:
|
344 |
+
session_state.ROOM.WINNER = winner
|
345 |
+
|
346 |
+
session_state.ROOM.HISTORY = (
|
347 |
+
session_state.ROOM.HISTORY[0]
|
348 |
+
+ int(session_state.ROOM.WINNER == _WHITE),
|
349 |
+
session_state.ROOM.HISTORY[1]
|
350 |
+
+ int(session_state.ROOM.WINNER == _BLACK),
|
351 |
+
)
|
352 |
+
session_state.ROOM.TIME = time.time()
|
353 |
+
|
354 |
+
if not response or session_state.ROOM.WINNER != _BLANK:
|
355 |
+
print("Game over")
|
356 |
+
for i, row in enumerate(session_state.ROOM.BOARD.board_map):
|
357 |
+
for j, cell in enumerate(row):
|
358 |
+
BOARD_PLATE[i][j].write(
|
359 |
+
_PLAYER_SYMBOL[cell],
|
360 |
+
key=f"{i}:{j}",
|
361 |
+
)
|
362 |
+
|
363 |
+
# Game process control
|
364 |
+
def game_control():
|
365 |
+
if session_state.ROOM.WINNER != _BLANK:
|
366 |
+
draw_board(False)
|
367 |
+
else:
|
368 |
+
draw_board(True)
|
369 |
+
if session_state.ROOM.WINNER != _BLANK or 0 not in session_state.ROOM.BOARD.board_map:
|
370 |
+
ANOTHER_ROUND.button(
|
371 |
+
"Play Next round!",
|
372 |
+
on_click=another_round,
|
373 |
+
help="Clear board and swap first player",
|
374 |
+
)
|
375 |
+
|
376 |
+
# Infos
|
377 |
+
def update_info() -> None:
|
378 |
+
# Additional information
|
379 |
+
SCORE_PLATE[0].metric("Gomoku-Agent", session_state.ROOM.HISTORY[0])
|
380 |
+
SCORE_PLATE[1].metric("Black", session_state.ROOM.HISTORY[1])
|
381 |
+
if session_state.ROOM.WINNER != _BLANK:
|
382 |
+
st.balloons()
|
383 |
+
ROUND_INFO.write(
|
384 |
+
f"#### **{_PLAYER_COLOR[session_state.ROOM.WINNER]} WIN!**\n**Click buttons on the left for more plays.**"
|
385 |
+
)
|
386 |
+
|
387 |
+
# elif 0 not in session_state.ROOM.BOARD.board_map:
|
388 |
+
# ROUND_INFO.write("#### **Tie**")
|
389 |
+
# else:
|
390 |
+
# ROUND_INFO.write(
|
391 |
+
# f"#### **{_PLAYER_SYMBOL[session_state.ROOM.TURN]} {_PLAYER_COLOR[session_state.ROOM.TURN]}'s turn...**"
|
392 |
+
# )
|
393 |
+
|
394 |
+
# draw the plot for simulation time
|
395 |
+
# 创建一个 DataFrame
|
396 |
+
|
397 |
+
print(session_state.ROOM.simula_time_list)
|
398 |
+
st.markdown("<br>", unsafe_allow_html=True)
|
399 |
+
st.markdown("<br>", unsafe_allow_html=True)
|
400 |
+
chart_data = pd.DataFrame(session_state.ROOM.simula_time_list, columns=["Simulation Time"])
|
401 |
+
st.line_chart(chart_data)
|
402 |
+
|
403 |
+
# The main game loop
|
404 |
+
game_control()
|
405 |
+
update_info()
|
406 |
+
|
407 |
+
|
408 |
+
if __name__ == "__main__":
|
409 |
+
gomoku()
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
pandas~=2.1.3
|
2 |
+
numpy~=1.26.2
|
3 |
+
streamlit~=1.28.2
|
4 |
+
matplotlib~=3.8.2
|
5 |
+
scipy~=1.11.3
|
6 |
+
torch~=2.1.1
|
7 |
+
streamlit-server-state==0.17.1
|