Spaces:
Sleeping
Sleeping
added gomokubot
Browse files- Gomoku_Bot/HumanVSAI.py +54 -0
- Gomoku_Bot/__init__.py +2 -0
- Gomoku_Bot/board.py +216 -0
- Gomoku_Bot/board_manuls.py +117 -0
- Gomoku_Bot/cache.py +23 -0
- Gomoku_Bot/config.py +6 -0
- Gomoku_Bot/eval.py +588 -0
- Gomoku_Bot/gomoku_bot.py +23 -0
- Gomoku_Bot/minimax_Net.py +117 -0
- Gomoku_Bot/minmax.py +119 -0
- Gomoku_Bot/position.py +36 -0
- Gomoku_Bot/shape.py +288 -0
- Gomoku_Bot/zobrist.py +35 -0
- ai.py +39 -0
- const.py +3 -0
- pages/AI_VS_AI.py +575 -0
- pages/Player_VS_AI.py +107 -27
- try.py +42 -0
Gomoku_Bot/HumanVSAI.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
|
4 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
5 |
+
from .board import Board
|
6 |
+
from .minmax import vct, cache_hits, minmax
|
7 |
+
from .eval import FIVE, FOUR, performance
|
8 |
+
|
9 |
+
|
10 |
+
class Game():
|
11 |
+
def __init__(self, firstRole=1):
|
12 |
+
self.board = Board(8, firstRole)
|
13 |
+
self.steps = []
|
14 |
+
self.step = 0
|
15 |
+
self.enableVCT = True # 是否开启算杀, 算杀会在某些leaf节点加深搜索, 但是不一定会增加搜索时间
|
16 |
+
|
17 |
+
def human_input(self):
|
18 |
+
x, y = map(int, input('Your move: ').split())
|
19 |
+
return x, y
|
20 |
+
|
21 |
+
def start_play(self, human_first=False):
|
22 |
+
if not human_first:
|
23 |
+
while not self.board.isGameOver():
|
24 |
+
print(self.board.display())
|
25 |
+
if self.step % 2 == 1:
|
26 |
+
x, y = self.human_input()
|
27 |
+
while not self.board.put(x, y):
|
28 |
+
x, y = self.human_input()
|
29 |
+
else:
|
30 |
+
score = minmax(self.board, 1, 4, enableVCT=self.enableVCT)
|
31 |
+
print(score)
|
32 |
+
x, y = score[1]
|
33 |
+
print("move at", x, y)
|
34 |
+
self.board.put(x, y)
|
35 |
+
self.step += 1
|
36 |
+
else:
|
37 |
+
while not self.board.isGameOver():
|
38 |
+
print(self.board.display())
|
39 |
+
if self.step % 2 == 0:
|
40 |
+
x, y = self.human_input()
|
41 |
+
while not self.board.put(x, y):
|
42 |
+
x, y = self.human_input()
|
43 |
+
else:
|
44 |
+
score = minmax(self.board, -1, 4, enableVCT=self.enableVCT)
|
45 |
+
print(score)
|
46 |
+
x, y = score[1]
|
47 |
+
self.board.put(x, y)
|
48 |
+
self.step += 1
|
49 |
+
print(self.board.display())
|
50 |
+
|
51 |
+
|
52 |
+
if __name__ == '__main__':
|
53 |
+
game = Game()
|
54 |
+
game.start_play(True)
|
Gomoku_Bot/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .board import Board
|
2 |
+
from .gomoku_bot import Gomoku_bot
|
Gomoku_Bot/board.py
ADDED
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .zobrist import ZobristCache as Zobrist
|
2 |
+
from .cache import Cache
|
3 |
+
from .eval import Evaluate, FIVE
|
4 |
+
from scipy import signal
|
5 |
+
import pickle
|
6 |
+
import os
|
7 |
+
save_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'train_data/data', 'train_data.pkl')
|
8 |
+
|
9 |
+
if 'numpy' not in globals():
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
|
13 |
+
class Board:
|
14 |
+
def __init__(self, size=15, firstRole=1):
|
15 |
+
self.size = size
|
16 |
+
self.board = [[0] * self.size for _ in range(self.size)]
|
17 |
+
self.firstRole = firstRole # 1 for black, -1 for white
|
18 |
+
self.role = firstRole # 1 for black, -1 for white
|
19 |
+
self.history = []
|
20 |
+
self.zobrist = Zobrist(self.size)
|
21 |
+
self.winnerCache = Cache()
|
22 |
+
self.gameoverCache = Cache()
|
23 |
+
self.evaluateCache = Cache()
|
24 |
+
self.valuableMovesCache = Cache()
|
25 |
+
self.evaluateTime = 0
|
26 |
+
self.evaluator = Evaluate(self.size)
|
27 |
+
self.available = [(i, j) for i in range(self.size) for j in range(self.size)]
|
28 |
+
self.patterns = [np.ones((1, 5)), np.ones((5, 1)), np.eye(5), np.fliplr(np.eye(5))]
|
29 |
+
self.train_data = {1:[], -1: []}
|
30 |
+
if os.path.exists(save_path):
|
31 |
+
with open(save_path, 'rb') as f:
|
32 |
+
self.train_data = pickle.load(f)
|
33 |
+
|
34 |
+
def isGameOver(self):
|
35 |
+
# Checked
|
36 |
+
hash = self.hash()
|
37 |
+
if self.gameoverCache.get(hash):
|
38 |
+
return self.gameoverCache.get(hash)
|
39 |
+
if self.getWinner() != 0:
|
40 |
+
self.gameoverCache.put(hash, True)
|
41 |
+
# save train data
|
42 |
+
# with open(save_path, 'wb') as f:
|
43 |
+
# pickle.dump(self.train_data, f)
|
44 |
+
return True # Someone has won
|
45 |
+
# Game is over when there is no empty space on the board or someone has won
|
46 |
+
if len(self.history) == self.size ** 2:
|
47 |
+
self.gameoverCache.put(hash, True)
|
48 |
+
return True
|
49 |
+
else:
|
50 |
+
self.gameoverCache.put(hash, False)
|
51 |
+
return False
|
52 |
+
|
53 |
+
def getWinner(self):
|
54 |
+
# Checked
|
55 |
+
hash = self.hash()
|
56 |
+
flag = True
|
57 |
+
if self.winnerCache.get(hash):
|
58 |
+
return self.winnerCache.get(hash)
|
59 |
+
directions = [[1, 0], [0, 1], [1, 1], [1, -1]] # Horizontal, Vertical, Diagonal
|
60 |
+
for i in range(self.size):
|
61 |
+
for j in range(self.size):
|
62 |
+
if self.board[i][j] == 0:
|
63 |
+
flag = False
|
64 |
+
continue
|
65 |
+
for direction in directions:
|
66 |
+
count = 0
|
67 |
+
while (
|
68 |
+
0 <= i + direction[0] * count < self.size and
|
69 |
+
0 <= j + direction[1] * count < self.size and
|
70 |
+
self.board[i + direction[0] * count][j + direction[1] * count] == self.board[i][j]
|
71 |
+
):
|
72 |
+
count += 1
|
73 |
+
if count >= 5:
|
74 |
+
self.winnerCache.put(hash, self.board[i][j])
|
75 |
+
return self.board[i][j]
|
76 |
+
if flag:
|
77 |
+
print("tie!!!")
|
78 |
+
return 0
|
79 |
+
self.winnerCache.put(hash, 0)
|
80 |
+
return 0
|
81 |
+
|
82 |
+
def getValidMoves(self):
|
83 |
+
return self.available
|
84 |
+
|
85 |
+
def put(self, i, j, role=None):
|
86 |
+
# Checked
|
87 |
+
if role is None:
|
88 |
+
role = self.role
|
89 |
+
if not isinstance(i, int) or not isinstance(j, int):
|
90 |
+
print("Invalid move: Not Number!", i, j)
|
91 |
+
return False
|
92 |
+
if self.board[i][j] != 0:
|
93 |
+
print("Invalid move!", i, j)
|
94 |
+
return False
|
95 |
+
self.board[i][j] = role
|
96 |
+
self.available.remove((i, j))
|
97 |
+
self.history.append({"i": i, "j": j, "role": role})
|
98 |
+
self.zobrist.togglePiece(i, j, role)
|
99 |
+
self.evaluator.move(i, j, role)
|
100 |
+
self.role *= -1 # Switch role
|
101 |
+
return True
|
102 |
+
|
103 |
+
def undo(self):
|
104 |
+
# Checked
|
105 |
+
if len(self.history) == 0:
|
106 |
+
print("No moves to undo!")
|
107 |
+
return False
|
108 |
+
|
109 |
+
lastMove = self.history.pop()
|
110 |
+
self.board[lastMove['i']][lastMove['j']] = 0 # Remove the piece from the board
|
111 |
+
self.role = lastMove['role'] # Switch back to the previous player
|
112 |
+
self.zobrist.togglePiece(lastMove['i'], lastMove['j'], lastMove['role'])
|
113 |
+
self.evaluator.undo(lastMove['i'], lastMove['j'])
|
114 |
+
self.available.append((lastMove['i'], lastMove['j']))
|
115 |
+
return True
|
116 |
+
|
117 |
+
def position2coordinate(self, position):
|
118 |
+
# checked
|
119 |
+
row = position // self.size
|
120 |
+
col = position % self.size
|
121 |
+
return [row, col]
|
122 |
+
|
123 |
+
def coordinate2position(self, coordinate):
|
124 |
+
# Checked
|
125 |
+
return coordinate[0] * self.size + coordinate[1]
|
126 |
+
|
127 |
+
def getValuableMoves(self, role, depth=0, onlyThree=False, onlyFour=False):
|
128 |
+
# Checked
|
129 |
+
hash = self.hash()
|
130 |
+
prev = self.valuableMovesCache.get(hash)
|
131 |
+
if prev:
|
132 |
+
if (prev["role"] == role and
|
133 |
+
prev["depth"] == depth and
|
134 |
+
prev["onlyThree"] == onlyThree
|
135 |
+
and prev["onlyFour"] == onlyFour):
|
136 |
+
return prev["moves"]
|
137 |
+
|
138 |
+
moves, train_data = self.evaluator.getMoves(role, depth, onlyThree, onlyFour)
|
139 |
+
self.train_data[self.role].append(train_data)
|
140 |
+
# Handle a special case, if the center point is not occupied, add it by default
|
141 |
+
|
142 |
+
# 开局的时候随机走一步,增加开局的多样性
|
143 |
+
if not onlyThree and not onlyFour:
|
144 |
+
center = self.size // 2
|
145 |
+
if self.board[center][center] == 0:
|
146 |
+
moves.append((center, center))
|
147 |
+
|
148 |
+
# x_step = np.random.randint(-self.size // 2, self.size // 2)
|
149 |
+
# y_step = np.random.randint(-self.size // 2, self.size // 2)
|
150 |
+
# x = center + x_step
|
151 |
+
# y = center + y_step
|
152 |
+
# if 0 <= x < self.size and 0 <= y < self.size and self.board[x][y] == 0:
|
153 |
+
# moves.append((x, y))
|
154 |
+
|
155 |
+
self.valuableMovesCache.put(hash, {
|
156 |
+
"role": role,
|
157 |
+
"moves": moves,
|
158 |
+
"depth": depth,
|
159 |
+
"onlyThree": onlyThree,
|
160 |
+
"onlyFour": onlyFour
|
161 |
+
})
|
162 |
+
return moves
|
163 |
+
|
164 |
+
def display(self, extraPoints=[]):
|
165 |
+
# Checked
|
166 |
+
extraPosition = [self.coordinate2position(point) for point in extraPoints]
|
167 |
+
result = ""
|
168 |
+
for i in range(self.size):
|
169 |
+
for j in range(self.size):
|
170 |
+
position = self.coordinate2position([i, j])
|
171 |
+
if position in extraPosition:
|
172 |
+
result += "? "
|
173 |
+
continue
|
174 |
+
value = self.board[i][j]
|
175 |
+
if value == 1:
|
176 |
+
result += "B " # Black
|
177 |
+
elif value == -1:
|
178 |
+
result += "W " # White
|
179 |
+
else:
|
180 |
+
result += "- "
|
181 |
+
result += "\n"
|
182 |
+
return result
|
183 |
+
|
184 |
+
def hash(self):
|
185 |
+
# Checked
|
186 |
+
return self.zobrist.getHash() # Return the hash value of the current board, used for caching
|
187 |
+
|
188 |
+
def evaluate(self, role):
|
189 |
+
# Checked
|
190 |
+
hash_key = self.hash()
|
191 |
+
prev = self.evaluateCache.get(hash_key)
|
192 |
+
if prev:
|
193 |
+
if prev["role"] == role:
|
194 |
+
return prev["score"]
|
195 |
+
|
196 |
+
winner = self.getWinner()
|
197 |
+
score = 0
|
198 |
+
if winner != 0:
|
199 |
+
score = FIVE * winner * role
|
200 |
+
else:
|
201 |
+
score = self.evaluator.evaluate(role)
|
202 |
+
|
203 |
+
self.evaluateCache.put(hash_key, {"role": role, "score": score})
|
204 |
+
return score
|
205 |
+
|
206 |
+
def reverse(self):
|
207 |
+
# Checked
|
208 |
+
new_board = Board(self.size, -self.firstRole)
|
209 |
+
for move in self.history:
|
210 |
+
x, y, role = move['i'], move['j'], move['role']
|
211 |
+
new_board.put(x, y, -role)
|
212 |
+
return new_board
|
213 |
+
|
214 |
+
def toString(self):
|
215 |
+
# Checked
|
216 |
+
return ''.join([''.join(map(str, row)) for row in self.board])
|
Gomoku_Bot/board_manuls.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
wins = [
|
2 |
+
# O O O O O
|
3 |
+
# X X X X -
|
4 |
+
# - - - - -
|
5 |
+
# - - - - -
|
6 |
+
# - - - - -
|
7 |
+
[5, [0, 5, 1, 6, 2, 7, 3, 8, 4], 1], # 横向五
|
8 |
+
# O O O O -
|
9 |
+
# X X X X X
|
10 |
+
# O - - - -
|
11 |
+
# - - - - -
|
12 |
+
# - - - - -
|
13 |
+
[5, [0, 5, 1, 6, 2, 7, 3, 8, 10, 9], -1], # 白子横向五
|
14 |
+
# O O - O O
|
15 |
+
# X X X X -
|
16 |
+
# O - - - -
|
17 |
+
# - - - - -
|
18 |
+
# - - - - -
|
19 |
+
[5, [0, 5, 1, 6, 10, 7, 3, 8, 4], 0], # 有一个空位
|
20 |
+
# O O O X O
|
21 |
+
# X X X X -
|
22 |
+
# O - - - -
|
23 |
+
# - - - - -
|
24 |
+
# - - - - -
|
25 |
+
[5, [0, 5, 1, 6, 2, 7, 10, 8, 4], 0], # 有一个白子
|
26 |
+
# O X X X X
|
27 |
+
# O - - - -
|
28 |
+
# O - - - -
|
29 |
+
# O - - - -
|
30 |
+
# O - - - -
|
31 |
+
[5, [0, 1, 5, 2, 10, 3, 15, 4, 20], 1], # 纵向五
|
32 |
+
# O X X X X
|
33 |
+
# O - - - -
|
34 |
+
# O - - - -
|
35 |
+
# - O - - -
|
36 |
+
# O - - - -
|
37 |
+
[5, [0, 1, 5, 2, 10, 3, 16, 4, 20], 0], # 纵向五有一个空位
|
38 |
+
# O X X X X
|
39 |
+
# O - - - -
|
40 |
+
# O - - - -
|
41 |
+
# X O - - -
|
42 |
+
# O - - - -
|
43 |
+
[5, [0, 1, 5, 2, 10, 3, 16, 4, 20, 15], 0], # 纵向五有一个白子
|
44 |
+
# O X X X X
|
45 |
+
# - O - - -
|
46 |
+
# - - O - -
|
47 |
+
# - - - O -
|
48 |
+
# - - - - O
|
49 |
+
[5, [0, 1, 6, 2, 12, 3, 18, 4, 24], 1], # 斜线五
|
50 |
+
# O X X X X
|
51 |
+
# - O - - -
|
52 |
+
# - - - O -
|
53 |
+
# - - - O -
|
54 |
+
# - - - - O
|
55 |
+
[5, [0, 1, 6, 2, 12, 3, 19, 4, 24], 0], # 斜线五有一个空的
|
56 |
+
# O X X X X
|
57 |
+
# - O - - -
|
58 |
+
# - - X O -
|
59 |
+
# - - - O -
|
60 |
+
# - - - - O
|
61 |
+
[5, [0, 1, 6, 2, 12, 3, 19, 4, 24, 18], 0], # 斜线五有一个白子
|
62 |
+
# X X X X O
|
63 |
+
# - - - O -
|
64 |
+
# - - O - -
|
65 |
+
# - O - - -
|
66 |
+
# O - - - -
|
67 |
+
[5, [4, 0, 8, 1, 12, 2, 16, 3, 20], 1], # 反斜线五
|
68 |
+
# X X X X O
|
69 |
+
# - - - O -
|
70 |
+
# - - O - -
|
71 |
+
# O - - - -
|
72 |
+
# O - - - -
|
73 |
+
[5, [4, 0, 8, 1, 12, 2, 15, 3, 20], 0], # 反斜线五 有一个空位
|
74 |
+
# X X X - O
|
75 |
+
# - - - O -
|
76 |
+
# - - O - -
|
77 |
+
# - X - - -
|
78 |
+
# O - - - -
|
79 |
+
[5, [4, 0, 8, 1, 12, 2, 16, 20], 0], # 反斜线五 有一个空位
|
80 |
+
]
|
81 |
+
|
82 |
+
# valid moves
|
83 |
+
validMoves = [
|
84 |
+
# O - -
|
85 |
+
# - - -
|
86 |
+
# - - O
|
87 |
+
[3, [0, 8], [1, 2, 3, 4, 5, 6, 7]],
|
88 |
+
# O - - - -
|
89 |
+
# - - - - -
|
90 |
+
# - - - - -
|
91 |
+
# - - - - -
|
92 |
+
# - - - - -
|
93 |
+
[5, [0], [1, 2, 5, 6, 7, 8, 10, 11, 12, 13, 16, 17, 18]],
|
94 |
+
# - - - - -
|
95 |
+
# - - - - -
|
96 |
+
# - O - - -
|
97 |
+
# - - - - -
|
98 |
+
# - - - - -
|
99 |
+
[5, [11], [0, 1, 2, 3, 5, 6, 7, 8, 10, 12, 13, 15, 16, 17, 18, 20, 21, 22, 23]],
|
100 |
+
|
101 |
+
# - - - - - - - -
|
102 |
+
# - - - - - - - -
|
103 |
+
# - - - - - - - -
|
104 |
+
# - - O - X - - -
|
105 |
+
# - - - - - - - -
|
106 |
+
# - - - - - - - -
|
107 |
+
# - - - - - - - -
|
108 |
+
# - - - - - - - -
|
109 |
+
[8, [26, 28], [
|
110 |
+
8, 9, 10, 11, 12, 13, 14,
|
111 |
+
16, 17, 18, 19, 20, 21, 22,
|
112 |
+
24, 25, 27, 29, 30,
|
113 |
+
32, 33, 34, 35, 36, 37, 38,
|
114 |
+
40, 41, 42, 43, 44, 45, 46,
|
115 |
+
],
|
116 |
+
],
|
117 |
+
]
|
Gomoku_Bot/cache.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from cachetools import LRUCache
|
2 |
+
from .config import config
|
3 |
+
|
4 |
+
class Cache:
|
5 |
+
def __init__(self, capacity=1000000):
|
6 |
+
self.capacity = capacity
|
7 |
+
self.cache = LRUCache(maxsize=capacity)
|
8 |
+
self.enable_cache = config['enableCache']
|
9 |
+
|
10 |
+
def get(self, key):
|
11 |
+
if not self.enable_cache:
|
12 |
+
return None
|
13 |
+
return self.cache.get(key, None)
|
14 |
+
|
15 |
+
def put(self, key, value):
|
16 |
+
if not self.enable_cache:
|
17 |
+
return
|
18 |
+
self.cache[key] = value
|
19 |
+
|
20 |
+
def has(self, key):
|
21 |
+
if not self.enable_cache:
|
22 |
+
return False
|
23 |
+
return key in self.cache
|
Gomoku_Bot/config.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
config = {
|
2 |
+
"enableCache": True, # Whether to enable caching
|
3 |
+
"onlyInLine": False, # Whether to search only on a single line, an optimization option
|
4 |
+
"inlineCount": 4, # Number of recent points to consider for being on the same line
|
5 |
+
"inLineDistance": 5 # Maximum distance to determine if a point is on the same line
|
6 |
+
}
|
Gomoku_Bot/eval.py
ADDED
@@ -0,0 +1,588 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from enum import Enum
|
2 |
+
import math
|
3 |
+
from .shape import shapes, get_shape_fast, is_five, is_four, get_all_shapes_of_point
|
4 |
+
from .position import coordinate2Position, isLine, isAllInLine, hasInLine, position2Coordinate
|
5 |
+
from .config import config
|
6 |
+
from datetime import datetime
|
7 |
+
import os
|
8 |
+
|
9 |
+
dir_path = os.path.dirname(os.path.abspath(__file__))
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
# from .minimax_Net import BoardEvaluationNet as net
|
13 |
+
|
14 |
+
# mini_max_net = net(board_size=15)
|
15 |
+
# mini_max_net.load_state_dict(torch.load(os.path.join(dir_path, 'train_data/model', 'best_loss=609.3356355479785.pth')))
|
16 |
+
# mini_max_net.eval()
|
17 |
+
|
18 |
+
|
19 |
+
# Enum to represent different shapes
|
20 |
+
class Shapes(Enum):
|
21 |
+
FIVE = 0
|
22 |
+
BLOCK_FIVE = 1
|
23 |
+
FOUR = 2
|
24 |
+
FOUR_FOUR = 3
|
25 |
+
FOUR_THREE = 4
|
26 |
+
THREE_THREE = 5
|
27 |
+
BLOCK_FOUR = 6
|
28 |
+
THREE = 7
|
29 |
+
BLOCK_THREE = 8
|
30 |
+
TWO_TWO = 9
|
31 |
+
TWO = 10
|
32 |
+
NONE = 11
|
33 |
+
|
34 |
+
|
35 |
+
# Constants representing scores for each shape
|
36 |
+
FIVE = 10000000
|
37 |
+
BLOCK_FIVE = FIVE
|
38 |
+
FOUR = 100000
|
39 |
+
FOUR_FOUR = FOUR # 双冲四
|
40 |
+
FOUR_THREE = FOUR # 冲四活三
|
41 |
+
THREE_THREE = FOUR / 2 # 双活三
|
42 |
+
BLOCK_FOUR = 1500
|
43 |
+
THREE = 1000
|
44 |
+
BLOCK_THREE = 150
|
45 |
+
TWO_TWO = 200 # 双活二
|
46 |
+
TWO = 100
|
47 |
+
BLOCK_TWO = 15
|
48 |
+
ONE = 10
|
49 |
+
BLOCK_ONE = 1
|
50 |
+
|
51 |
+
|
52 |
+
# Function to calculate the real shape score based on the shape
|
53 |
+
def getRealShapeScore(shape: Shapes) -> int:
|
54 |
+
# Checked
|
55 |
+
if shape == shapes['FIVE']:
|
56 |
+
return FOUR
|
57 |
+
elif shape == shapes['BLOCK_FIVE']:
|
58 |
+
return BLOCK_FOUR
|
59 |
+
elif shape in [shapes['FOUR'], shapes['FOUR_FOUR'], shapes['FOUR_THREE']]:
|
60 |
+
return THREE
|
61 |
+
elif shape == shapes['BLOCK_FOUR']:
|
62 |
+
return BLOCK_THREE
|
63 |
+
elif shape == shapes['THREE']:
|
64 |
+
return TWO
|
65 |
+
elif shape == shapes['THREE_THREE']:
|
66 |
+
return math.floor(THREE_THREE / 10)
|
67 |
+
elif shape == shapes['BLOCK_THREE']:
|
68 |
+
return BLOCK_TWO
|
69 |
+
elif shape == shapes['TWO']:
|
70 |
+
return ONE
|
71 |
+
elif shape == shapes['TWO_TWO']:
|
72 |
+
return math.floor(TWO_TWO / 10)
|
73 |
+
else:
|
74 |
+
return 0
|
75 |
+
|
76 |
+
|
77 |
+
# List of all directions
|
78 |
+
allDirections = [
|
79 |
+
[0, 1], # Horizontal
|
80 |
+
[1, 0], # Vertical
|
81 |
+
[1, 1], # Diagonal \
|
82 |
+
[1, -1] # Diagonal /
|
83 |
+
]
|
84 |
+
|
85 |
+
|
86 |
+
# Function to get the index of a direction
|
87 |
+
def direction2index(ox: int, oy: int) -> int:
|
88 |
+
# Checked
|
89 |
+
if ox == 0:
|
90 |
+
return 0 # |
|
91 |
+
elif oy == 0:
|
92 |
+
return 1 # -
|
93 |
+
elif ox == oy:
|
94 |
+
return 2 # \
|
95 |
+
elif ox != oy:
|
96 |
+
return 3 # /
|
97 |
+
|
98 |
+
|
99 |
+
# Performance dictionary
|
100 |
+
performance = {
|
101 |
+
"updateTime": 0,
|
102 |
+
"getPointsTime": 0
|
103 |
+
}
|
104 |
+
|
105 |
+
|
106 |
+
class Evaluate:
|
107 |
+
def __init__(self, size=15):
|
108 |
+
# Checked
|
109 |
+
self.size = size
|
110 |
+
self.board = [[2] * (size + 2) for _ in range(size + 2)]
|
111 |
+
for i in range(size + 2):
|
112 |
+
for j in range(size + 2):
|
113 |
+
if i == 0 or j == 0 or i == self.size + 1 or j == self.size + 1:
|
114 |
+
self.board[i][j] = 2
|
115 |
+
else:
|
116 |
+
self.board[i][j] = 0
|
117 |
+
self.blackScores = [[0] * self.size for _ in range(size)]
|
118 |
+
self.whiteScores = [[0] * self.size for _ in range(size)]
|
119 |
+
self.initPoints()
|
120 |
+
self.history = [] # List of [position, role]
|
121 |
+
|
122 |
+
def move(self, x, y, role):
|
123 |
+
# Checked
|
124 |
+
# Clear the cache first
|
125 |
+
for d in [0, 1, 2, 3]:
|
126 |
+
self.shapeCache[role][d][x][y] = 0
|
127 |
+
self.shapeCache[-role][d][x][y] = 0
|
128 |
+
self.blackScores[x][y] = 0
|
129 |
+
self.whiteScores[x][y] = 0
|
130 |
+
# Update the board
|
131 |
+
self.board[x + 1][y + 1] = role ## Adjust for the added wall
|
132 |
+
self.updatePoint(x, y)
|
133 |
+
self.history.append([coordinate2Position(x, y, self.size), role])
|
134 |
+
|
135 |
+
def undo(self, x, y):
|
136 |
+
# Checked
|
137 |
+
self.board[x + 1][y + 1] = 0
|
138 |
+
self.updatePoint(x, y)
|
139 |
+
self.history.pop()
|
140 |
+
|
141 |
+
def initPoints(self):
|
142 |
+
# Checked
|
143 |
+
# Initialize the cache, avoid calculating the same points multiple times
|
144 |
+
self.shapeCache = {}
|
145 |
+
for role in [1, -1]:
|
146 |
+
self.shapeCache[role] = {}
|
147 |
+
for direction in [0, 1, 2, 3]:
|
148 |
+
self.shapeCache[role][direction] = [[0] * self.size for _ in range(self.size)]
|
149 |
+
|
150 |
+
self.pointsCache = {}
|
151 |
+
for role in [1, -1]:
|
152 |
+
self.pointsCache[role] = {}
|
153 |
+
for shape in shapes:
|
154 |
+
self.pointsCache[role][shape] = set()
|
155 |
+
|
156 |
+
def getPointsInLine(self, role):
|
157 |
+
# Checked
|
158 |
+
pointsInLine = {}
|
159 |
+
hasPointsInLine = False
|
160 |
+
for key in shapes:
|
161 |
+
pointsInLine[shapes[key]] = set()
|
162 |
+
|
163 |
+
last2Points = [position for position, role in self.history[-config['inlineCount']:]]
|
164 |
+
processed = {}
|
165 |
+
# 在last2Points中查找是否有点位在一条线上
|
166 |
+
for r in [role, -role]:
|
167 |
+
for point in last2Points:
|
168 |
+
x, y = position2Coordinate(point, self.size)
|
169 |
+
for ox, oy in allDirections:
|
170 |
+
for sign in [1, -1]:
|
171 |
+
for step in range(1, config['inLineDistance'] + 1):
|
172 |
+
nx = x + sign * step * ox
|
173 |
+
ny = y + sign * step * oy
|
174 |
+
position = coordinate2Position(nx, ny, self.size)
|
175 |
+
# 检测是否到达边界
|
176 |
+
if nx < 0 or nx >= self.size or ny < 0 or ny >= self.size:
|
177 |
+
break
|
178 |
+
if self.board[nx + 1][ny + 1] != 0:
|
179 |
+
continue
|
180 |
+
if processed.get(position) == r:
|
181 |
+
continue
|
182 |
+
processed[position] = r
|
183 |
+
for direction in [0, 1, 2, 3]:
|
184 |
+
shape = self.shapeCache[r][direction][nx][ny]
|
185 |
+
# 到达边界停止,但是注意到达对方棋子不能停止
|
186 |
+
if shape:
|
187 |
+
pointsInLine[shape].add(coordinate2Position(nx, ny, self.size))
|
188 |
+
hasPointsInLine = True
|
189 |
+
|
190 |
+
if hasPointsInLine:
|
191 |
+
return pointsInLine
|
192 |
+
return False
|
193 |
+
|
194 |
+
def getPoints(self, role, depth, vct, vcf):
|
195 |
+
first = role if depth % 2 == 0 else -role # 先手
|
196 |
+
start = datetime.now()
|
197 |
+
|
198 |
+
if config['onlyInLine'] and len(self.history) >= config['inlineCount']:
|
199 |
+
points_in_line = self.getPointsInLine(role)
|
200 |
+
if points_in_line:
|
201 |
+
performance['getPointsTime'] += (datetime.now() - start).total_seconds()
|
202 |
+
return points_in_line
|
203 |
+
|
204 |
+
points = {} # 全部点位
|
205 |
+
|
206 |
+
for key in shapes.keys():
|
207 |
+
points[shapes[key]] = set()
|
208 |
+
|
209 |
+
last_points = [position for position, _ in self.history[-4:]]
|
210 |
+
|
211 |
+
for r in [role, -role]:
|
212 |
+
# 这里是直接遍历了这个棋盘上的所有点位,如果棋盘很大,这里会有性能问题;可以用神经网络来预测
|
213 |
+
for i in range(self.size):
|
214 |
+
for j in range(self.size):
|
215 |
+
four_count = 0
|
216 |
+
block_four_count = 0
|
217 |
+
three_count = 0
|
218 |
+
|
219 |
+
for direction in [0, 1, 2, 3]:
|
220 |
+
if self.board[i + 1][j + 1] != 0:
|
221 |
+
continue
|
222 |
+
|
223 |
+
shape = self.shapeCache[r][direction][i][j]
|
224 |
+
|
225 |
+
if not shape:
|
226 |
+
continue
|
227 |
+
|
228 |
+
point = i * self.size + j
|
229 |
+
|
230 |
+
if vcf:
|
231 |
+
if r == first and not is_four(shape) and not is_five(shape):
|
232 |
+
continue
|
233 |
+
if r == -first and is_five(shape):
|
234 |
+
continue
|
235 |
+
|
236 |
+
if vct:
|
237 |
+
if depth % 2 == 0:
|
238 |
+
if depth == 0 and r != first:
|
239 |
+
continue
|
240 |
+
if shape != shapes['THREE'] and not is_four(shape) and not is_five(shape):
|
241 |
+
continue
|
242 |
+
if shape == shapes['THREE'] and r != first:
|
243 |
+
continue
|
244 |
+
if depth == 0 and r != first:
|
245 |
+
continue
|
246 |
+
if depth > 0:
|
247 |
+
if shape == shapes['THREE'] and len(
|
248 |
+
get_all_shapes_of_point(self.shapeCache, i, j, r)) == 1:
|
249 |
+
continue
|
250 |
+
if shape == shapes['BLOCK_FOUR'] and len(
|
251 |
+
get_all_shapes_of_point(self.shapeCache, i, j, r)) == 1:
|
252 |
+
continue
|
253 |
+
else:
|
254 |
+
if shape != shapes['THREE'] and not is_four(shape) and not is_five(shape):
|
255 |
+
continue
|
256 |
+
if shape == shapes['THREE'] and r == -first:
|
257 |
+
continue
|
258 |
+
if depth > 1:
|
259 |
+
if shape == shapes['BLOCK_FOUR'] and len(
|
260 |
+
get_all_shapes_of_point(self.shapeCache, i, j)) == 1:
|
261 |
+
continue
|
262 |
+
if shape == shapes['BLOCK_FOUR'] and not hasInLine(point, last_points, self.size):
|
263 |
+
continue
|
264 |
+
|
265 |
+
if vcf:
|
266 |
+
if not is_four(shape) and not is_five(shape):
|
267 |
+
continue
|
268 |
+
|
269 |
+
if depth > 2 and (shape == shapes['TWO'] or shape == shapes['TWO_TWO'] or shape == shapes[
|
270 |
+
'BLOCK_THREE']) and not hasInLine(point, last_points, self.size):
|
271 |
+
continue
|
272 |
+
|
273 |
+
points[shape].add(point)
|
274 |
+
|
275 |
+
if shape == shapes['FOUR']:
|
276 |
+
four_count += 1
|
277 |
+
elif shape == shapes['BLOCK_FOUR']:
|
278 |
+
block_four_count += 1
|
279 |
+
elif shape == shapes['THREE']:
|
280 |
+
three_count += 1
|
281 |
+
|
282 |
+
union_shape = None
|
283 |
+
|
284 |
+
if four_count >= 2:
|
285 |
+
union_shape = shapes['FOUR_FOUR']
|
286 |
+
elif block_four_count and three_count:
|
287 |
+
union_shape = shapes['FOUR_THREE']
|
288 |
+
elif three_count >= 2:
|
289 |
+
union_shape = shapes['THREE_THREE']
|
290 |
+
|
291 |
+
if union_shape:
|
292 |
+
points[union_shape].add(point)
|
293 |
+
|
294 |
+
performance['getPointsTime'] += (datetime.now() - start).total_seconds()
|
295 |
+
|
296 |
+
return points
|
297 |
+
|
298 |
+
"""
|
299 |
+
当一个位置发生变时候,要更新这个位置的四个方向上得分,更新规则是:
|
300 |
+
1. 如果这个位置是空的,那么就重新计算这个位置的得分
|
301 |
+
2. 如果碰到了边界或者对方的棋子,那么就停止计算
|
302 |
+
3. 如果超过2个空位,那么就停止计算
|
303 |
+
4. 要更新自己的和对方的得分
|
304 |
+
"""
|
305 |
+
|
306 |
+
def updatePoint(self, x, y):
|
307 |
+
# Checked
|
308 |
+
start = datetime.now()
|
309 |
+
self.updateSinglePoint(x, y, 1)
|
310 |
+
self.updateSinglePoint(x, y, -1)
|
311 |
+
|
312 |
+
for ox, oy in allDirections:
|
313 |
+
for sign in [1, -1]: # -1 for negative direction, 1 for positive direction
|
314 |
+
for step in range(1, 6):
|
315 |
+
reachEdge = False
|
316 |
+
for role in [1, -1]:
|
317 |
+
nx = x + sign * step * ox + 1 # +1 to adjust for wall
|
318 |
+
ny = y + sign * step * oy + 1 # +1 to adjust for wall
|
319 |
+
# Stop if wall or opponent's piece is found
|
320 |
+
if self.board[nx][ny] == 2:
|
321 |
+
reachEdge = True
|
322 |
+
break
|
323 |
+
elif self.board[nx][ny] == -role: # Change role if opponent's piece is found
|
324 |
+
continue
|
325 |
+
elif self.board[nx][ny] == 0:
|
326 |
+
self.updateSinglePoint(nx - 1, ny - 1, role,
|
327 |
+
[sign * ox, sign * oy]) # -1 to adjust back from wall
|
328 |
+
if reachEdge:
|
329 |
+
break
|
330 |
+
performance['updateTime'] += (datetime.now() - start).total_seconds()
|
331 |
+
|
332 |
+
"""
|
333 |
+
计算单个点的得分
|
334 |
+
计算原理是:
|
335 |
+
在当前位置放一个当前角色的棋子,遍历四个方向,生成四个方向上的字符串,用patters来匹配字符串, 匹配到的话,就将对应的得分加到scores上
|
336 |
+
四个方向的字符串生成规则是:向两边都延伸5个位置,如果遇到边界或者对方的棋子,就停止延伸
|
337 |
+
在更新周围棋子时,只有一个方向需要更新,因此可以传入direction参数,只更新一个方向
|
338 |
+
"""
|
339 |
+
|
340 |
+
def updateSinglePoint(self, x, y, role, direction=None):
|
341 |
+
# Checked
|
342 |
+
if self.board[x + 1][y + 1] != 0:
|
343 |
+
return # Not an empty spot
|
344 |
+
|
345 |
+
# Temporarily place the piece
|
346 |
+
self.board[x + 1][y + 1] = role
|
347 |
+
|
348 |
+
directions = []
|
349 |
+
if direction:
|
350 |
+
directions.append(direction)
|
351 |
+
else:
|
352 |
+
directions = allDirections
|
353 |
+
|
354 |
+
shapeCache = self.shapeCache[role]
|
355 |
+
|
356 |
+
# Clear the cache first
|
357 |
+
for ox, oy in directions:
|
358 |
+
shapeCache[direction2index(ox, oy)][x][y] = shapes['NONE']
|
359 |
+
|
360 |
+
score = 0
|
361 |
+
blockFourCount = 0
|
362 |
+
threeCount = 0
|
363 |
+
twoCount = 0
|
364 |
+
|
365 |
+
# Calculate existing score
|
366 |
+
for intDirection in [0, 1, 2, 3]:
|
367 |
+
shape = shapeCache[intDirection][x][y]
|
368 |
+
if shape > shapes['NONE']:
|
369 |
+
score += getRealShapeScore(shape)
|
370 |
+
if shape == shapes['BLOCK_FOUR']:
|
371 |
+
blockFourCount += 1
|
372 |
+
if shape == shapes['THREE']:
|
373 |
+
threeCount += 1
|
374 |
+
if shape == shapes['TWO']:
|
375 |
+
twoCount += 1
|
376 |
+
|
377 |
+
for ox, oy in directions:
|
378 |
+
intDirection = direction2index(ox, oy)
|
379 |
+
shape, selfCount = get_shape_fast(self.board, x, y, ox, oy, role)
|
380 |
+
if not shape:
|
381 |
+
continue
|
382 |
+
if shape:
|
383 |
+
# Note: Only cache single shapes, do not cache compound shapes like double threes, as they depend on two shapes
|
384 |
+
shapeCache[intDirection][x][y] = shape
|
385 |
+
if shape == shapes['BLOCK_FOUR']:
|
386 |
+
blockFourCount += 1
|
387 |
+
if shape == shapes['THREE']:
|
388 |
+
threeCount += 1
|
389 |
+
if shape == shapes['TWO']:
|
390 |
+
twoCount += 1
|
391 |
+
if blockFourCount >= 2:
|
392 |
+
shape = shapes['FOUR_FOUR']
|
393 |
+
elif blockFourCount and threeCount:
|
394 |
+
shape = shapes['FOUR_THREE']
|
395 |
+
elif threeCount >= 2:
|
396 |
+
shape = shapes['THREE_THREE']
|
397 |
+
elif twoCount >= 2:
|
398 |
+
shape = shapes['TWO_TWO']
|
399 |
+
score += getRealShapeScore(shape)
|
400 |
+
|
401 |
+
self.board[x + 1][y + 1] = 0 # Remove the temporary piece
|
402 |
+
|
403 |
+
if role == 1:
|
404 |
+
self.blackScores[x][y] = score
|
405 |
+
else:
|
406 |
+
self.whiteScores[x][y] = score
|
407 |
+
|
408 |
+
return score
|
409 |
+
|
410 |
+
def evaluate(self, role):
|
411 |
+
# Checked
|
412 |
+
blackScore = 0
|
413 |
+
whiteScore = 0
|
414 |
+
|
415 |
+
for i in range(len(self.blackScores)):
|
416 |
+
for j in range(len(self.blackScores[i])):
|
417 |
+
blackScore += self.blackScores[i][j]
|
418 |
+
|
419 |
+
for i in range(len(self.whiteScores)):
|
420 |
+
for j in range(len(self.whiteScores[i])):
|
421 |
+
whiteScore += self.whiteScores[i][j]
|
422 |
+
|
423 |
+
score = blackScore - whiteScore if role == 1 else whiteScore - blackScore
|
424 |
+
return score
|
425 |
+
|
426 |
+
def getMoves(self, role, depth, onThree=False, onlyFour=False, use_net = False):
|
427 |
+
# Checked
|
428 |
+
train_data = 0
|
429 |
+
if use_net and role == 1:
|
430 |
+
# value_move_num = 6
|
431 |
+
# input = torch.Tensor(np.array(self.board)[1:-1, 1:-1]).unsqueeze(0)
|
432 |
+
# scores = mini_max_net(input)
|
433 |
+
# flattened_scores = scores.flatten()
|
434 |
+
#
|
435 |
+
# moves = (flattened_scores.argsort(descending=True)[:value_move_num]).tolist()
|
436 |
+
moves = 0
|
437 |
+
# print(moves)
|
438 |
+
else:
|
439 |
+
moves, model_train_maxtrix = self._getMoves(role, depth, onThree, onlyFour)
|
440 |
+
train_data = {"state": np.array(self.board)[1:-1, 1:-1], "scores": model_train_maxtrix}
|
441 |
+
moves = [(move // self.size, move % self.size) for move in moves]
|
442 |
+
# cut the self.board into normal size
|
443 |
+
print("moves", moves)
|
444 |
+
|
445 |
+
return moves, train_data
|
446 |
+
|
447 |
+
def _getMoves(self, role, depth, only_three=False, only_four=False):
|
448 |
+
"""
|
449 |
+
Get possible moves based on the current game state.
|
450 |
+
"""
|
451 |
+
points = self.getPoints(role, depth, only_three, only_four)
|
452 |
+
fives = points[shapes['FIVE']]
|
453 |
+
block_fives = points[shapes['BLOCK_FIVE']]
|
454 |
+
|
455 |
+
# To train the model, we need to get all these points's score and store it to board size matrix
|
456 |
+
# Then we can use this matrix to train the model, given a state, we want it to output the score of each point, then we can choose the highest score point
|
457 |
+
model_train_matrix = [[0] * self.size for _ in range(self.size)]
|
458 |
+
|
459 |
+
if fives and len(fives) > 0 or block_fives and len(block_fives) > 0:
|
460 |
+
for point in fives:
|
461 |
+
x = point // self.size
|
462 |
+
y = point % self.size
|
463 |
+
model_train_matrix[x][y] = max(FIVE, model_train_matrix[x][y])
|
464 |
+
for point in block_fives:
|
465 |
+
x = point // self.size
|
466 |
+
y = point % self.size
|
467 |
+
model_train_matrix[x][y] = max(BLOCK_FIVE, model_train_matrix[x][y])
|
468 |
+
|
469 |
+
return set(list(fives) + list(block_fives)), model_train_matrix
|
470 |
+
|
471 |
+
fours = points[shapes['FOUR']]
|
472 |
+
block_fours = points[shapes['BLOCK_FOUR']] # Block four is special, consider it in both four and three
|
473 |
+
if only_four or (fours and len(fours) > 0):
|
474 |
+
for point in fours:
|
475 |
+
x = point // self.size
|
476 |
+
y = point % self.size
|
477 |
+
model_train_matrix[x][y] = max(FOUR, model_train_matrix[x][y])
|
478 |
+
|
479 |
+
for point in block_fours:
|
480 |
+
x = point // self.size
|
481 |
+
y = point % self.size
|
482 |
+
model_train_matrix[x][y] = max(BLOCK_FOUR, model_train_matrix[x][y])
|
483 |
+
|
484 |
+
return set(list(fours) + list(block_fours)), model_train_matrix
|
485 |
+
|
486 |
+
four_fours = points[shapes['FOUR_FOUR']]
|
487 |
+
if four_fours and len(four_fours) > 0:
|
488 |
+
for point in four_fours:
|
489 |
+
x = point // self.size
|
490 |
+
y = point % self.size
|
491 |
+
model_train_matrix[x][y] = max(FOUR_FOUR, model_train_matrix[x][y])
|
492 |
+
|
493 |
+
for point in block_fours:
|
494 |
+
x = point // self.size
|
495 |
+
y = point % self.size
|
496 |
+
model_train_matrix[x][y] = max(BLOCK_FOUR, model_train_matrix[x][y])
|
497 |
+
|
498 |
+
return set(list(four_fours) + list(block_fours)), model_train_matrix
|
499 |
+
|
500 |
+
# Double threes and active threes
|
501 |
+
threes = points[shapes['THREE']]
|
502 |
+
four_threes = points[shapes['FOUR_THREE']]
|
503 |
+
if four_threes and len(four_threes) > 0:
|
504 |
+
for point in four_threes:
|
505 |
+
x = point // self.size
|
506 |
+
y = point % self.size
|
507 |
+
model_train_matrix[x][y] = max(FOUR_THREE, model_train_matrix[x][y])
|
508 |
+
|
509 |
+
for point in block_fours:
|
510 |
+
x = point // self.size
|
511 |
+
y = point % self.size
|
512 |
+
model_train_matrix[x][y] = max(BLOCK_FOUR, model_train_matrix[x][y])
|
513 |
+
|
514 |
+
for point in threes:
|
515 |
+
x = point // self.size
|
516 |
+
y = point % self.size
|
517 |
+
model_train_matrix[x][y] = max(THREE, model_train_matrix[x][y])
|
518 |
+
|
519 |
+
return set(list(four_threes) + list(block_fours) + list(threes)), model_train_matrix
|
520 |
+
|
521 |
+
three_threes = points[shapes['THREE_THREE']]
|
522 |
+
if three_threes and len(three_threes) > 0:
|
523 |
+
|
524 |
+
for point in three_threes:
|
525 |
+
x = point // self.size
|
526 |
+
y = point % self.size
|
527 |
+
model_train_matrix[x][y] = max(THREE_THREE, model_train_matrix[x][y])
|
528 |
+
|
529 |
+
for point in block_fours:
|
530 |
+
x = point // self.size
|
531 |
+
y = point % self.size
|
532 |
+
model_train_matrix[x][y] = max(BLOCK_FOUR, model_train_matrix[x][y])
|
533 |
+
|
534 |
+
for point in threes:
|
535 |
+
x = point // self.size
|
536 |
+
y = point % self.size
|
537 |
+
model_train_matrix[x][y] = max(THREE, model_train_matrix[x][y])
|
538 |
+
|
539 |
+
return set(list(three_threes) + list(block_fours) + list(threes)), model_train_matrix
|
540 |
+
|
541 |
+
if only_three:
|
542 |
+
for point in threes:
|
543 |
+
x = point // self.size
|
544 |
+
y = point % self.size
|
545 |
+
model_train_matrix[x][y] = max(THREE, model_train_matrix[x][y])
|
546 |
+
|
547 |
+
for point in block_fours:
|
548 |
+
x = point // self.size
|
549 |
+
y = point % self.size
|
550 |
+
model_train_matrix[x][y] = max(BLOCK_FOUR, model_train_matrix[x][y])
|
551 |
+
return set(list(block_fours) + list(threes)), model_train_matrix
|
552 |
+
|
553 |
+
block_threes = points[shapes['BLOCK_THREE']]
|
554 |
+
two_twos = points[shapes['TWO_TWO']]
|
555 |
+
twos = points[shapes['TWO']]
|
556 |
+
|
557 |
+
for point in block_threes:
|
558 |
+
x = point // self.size
|
559 |
+
y = point % self.size
|
560 |
+
model_train_matrix[x][y] = max(BLOCK_THREE, model_train_matrix[x][y])
|
561 |
+
|
562 |
+
for point in two_twos:
|
563 |
+
x = point // self.size
|
564 |
+
y = point % self.size
|
565 |
+
model_train_matrix[x][y] = max(TWO_TWO, model_train_matrix[x][y])
|
566 |
+
|
567 |
+
for point in twos:
|
568 |
+
x = point // self.size
|
569 |
+
y = point % self.size
|
570 |
+
model_train_matrix[x][y] = max(TWO, model_train_matrix[x][y])
|
571 |
+
|
572 |
+
for point in block_fours:
|
573 |
+
x = point // self.size
|
574 |
+
y = point % self.size
|
575 |
+
model_train_matrix[x][y] = max(BLOCK_FOUR, model_train_matrix[x][y])
|
576 |
+
|
577 |
+
for point in threes:
|
578 |
+
x = point // self.size
|
579 |
+
y = point % self.size
|
580 |
+
model_train_matrix[x][y] = max(THREE, model_train_matrix[x][y])
|
581 |
+
|
582 |
+
mid = list(block_fours) + list(threes) + list(block_threes) + list(two_twos) + list(twos)
|
583 |
+
res = set(mid[:5])
|
584 |
+
for i in range(len(model_train_matrix)):
|
585 |
+
for j in range(len(model_train_matrix)):
|
586 |
+
if (i * len(model_train_matrix) + j) not in res:
|
587 |
+
model_train_matrix[i][j] = 0
|
588 |
+
return res, model_train_matrix
|
Gomoku_Bot/gomoku_bot.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .minmax import *
|
2 |
+
import time
|
3 |
+
|
4 |
+
|
5 |
+
class Gomoku_bot:
|
6 |
+
def __init__(self, board, role, depth=4, enableVCT=True):
|
7 |
+
self.board = board
|
8 |
+
self.role = role
|
9 |
+
self.depth = depth
|
10 |
+
self.enableVCT = enableVCT
|
11 |
+
|
12 |
+
def get_action(self, return_time=True):
|
13 |
+
start = time.time()
|
14 |
+
score = minmax(self.board, self.role, self.depth, self.enableVCT)
|
15 |
+
end = time.time()
|
16 |
+
sim_time = end - start
|
17 |
+
move = score[1]
|
18 |
+
# turn tuple into an int
|
19 |
+
move = move[0] * self.board.size + move[1]
|
20 |
+
if return_time:
|
21 |
+
return move, sim_time
|
22 |
+
else:
|
23 |
+
return move
|
Gomoku_Bot/minimax_Net.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import pickle
|
6 |
+
import os
|
7 |
+
|
8 |
+
dir_path = os.path.dirname(os.path.realpath(__file__))
|
9 |
+
# from tensorboardX import SummaryWriter
|
10 |
+
from tqdm import tqdm
|
11 |
+
import datetime
|
12 |
+
from torch.utils.data import DataLoader, TensorDataset
|
13 |
+
|
14 |
+
date = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
|
15 |
+
|
16 |
+
|
17 |
+
class BoardEvaluationNet(nn.Module):
|
18 |
+
def __init__(self, board_size):
|
19 |
+
super(BoardEvaluationNet, self).__init__()
|
20 |
+
self.board_size = board_size
|
21 |
+
|
22 |
+
self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
|
23 |
+
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
|
24 |
+
self.fc1 = nn.Linear(32 * board_size * board_size, 256)
|
25 |
+
self.fc2 = nn.Linear(256, board_size * board_size)
|
26 |
+
|
27 |
+
def forward(self, x):
|
28 |
+
x = x.unsqueeze(1) # Add a channel dimension
|
29 |
+
x = F.relu(self.conv1(x))
|
30 |
+
x = F.relu(self.conv2(x))
|
31 |
+
x = x.view(-1, 32 * self.board_size * self.board_size)
|
32 |
+
x = F.relu(self.fc1(x))
|
33 |
+
x = self.fc2(x)
|
34 |
+
return x.view(-1, self.board_size, self.board_size)
|
35 |
+
|
36 |
+
|
37 |
+
def normalize(t):
|
38 |
+
return t
|
39 |
+
|
40 |
+
|
41 |
+
if __name__ == "__main__":
|
42 |
+
|
43 |
+
writer = SummaryWriter(os.path.join(dir_path, 'train_data/log', date), comment='BoardEvaluationNet')
|
44 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
45 |
+
|
46 |
+
best = np.Inf
|
47 |
+
loss_fn = nn.CrossEntropyLoss()
|
48 |
+
|
49 |
+
# Example usage
|
50 |
+
BS = 15
|
51 |
+
|
52 |
+
net_for_black = BoardEvaluationNet(BS).to(device)
|
53 |
+
net_for_white = BoardEvaluationNet(BS).to(device)
|
54 |
+
|
55 |
+
net_for_black.load_state_dict(torch.load(os.path.join(dir_path, 'train_data/model', 'best_loss=680.5813717259707.pth')))
|
56 |
+
|
57 |
+
optimizer = torch.optim.Adam(net_for_black.parameters(), lr=1e-5, betas=(0.9, 0.99),
|
58 |
+
eps=1e-8)
|
59 |
+
|
60 |
+
data_path = os.path.join(dir_path, 'train_data/data', 'train_data.pkl')
|
61 |
+
with open(data_path, 'rb') as f:
|
62 |
+
datas = pickle.load(f)
|
63 |
+
|
64 |
+
train_data_for_black = datas[1][:int(len(datas[1]) * 1)]
|
65 |
+
test_data_for_black = datas[1][int(len(datas[1]) * 0.8):]
|
66 |
+
train_data_for_white = datas[-1]
|
67 |
+
epochs = 500
|
68 |
+
batch_size = 32
|
69 |
+
train_dataset = TensorDataset(torch.stack([torch.tensor(item['state'], dtype=torch.float) for item in train_data_for_black]),
|
70 |
+
torch.stack([normalize(torch.tensor(item['scores'], dtype=torch.float)) for item in train_data_for_black]))
|
71 |
+
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
72 |
+
|
73 |
+
for epoch in range(epochs):
|
74 |
+
epoch_loss = 0
|
75 |
+
print('Epoch:', epoch)
|
76 |
+
for i, (states, scores) in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
|
77 |
+
states = states.to(device)
|
78 |
+
scores = scores.to(device)
|
79 |
+
|
80 |
+
# print(input_tensor.shape)
|
81 |
+
infer_start = datetime.datetime.now()
|
82 |
+
output_tensor = net_for_black(states)
|
83 |
+
infer_end = datetime.datetime.now()
|
84 |
+
loss = loss_fn(output_tensor, scores)
|
85 |
+
print(loss.item())
|
86 |
+
exit(0)
|
87 |
+
loss.backward()
|
88 |
+
optimizer.step()
|
89 |
+
optimizer.zero_grad()
|
90 |
+
epoch_loss += loss.item()
|
91 |
+
|
92 |
+
writer.add_scalar('train/infer_time', (infer_end - infer_start).microseconds,
|
93 |
+
i + epoch * len(train_dataloader))
|
94 |
+
|
95 |
+
epoch_loss /= len(train_dataloader)
|
96 |
+
writer.add_scalar('train/epoch_loss', epoch_loss, epoch)
|
97 |
+
# test
|
98 |
+
with torch.no_grad():
|
99 |
+
test_loss = 0
|
100 |
+
net_for_black.eval()
|
101 |
+
for j, item in tqdm(enumerate(test_data_for_black), total=len(test_data_for_black)):
|
102 |
+
scores = normalize(torch.tensor(item['scores'], dtype=torch.float).to(device).unsqueeze(0)) # 将数据类型设为float
|
103 |
+
state = item['state']
|
104 |
+
input_tensor = torch.tensor(state, dtype=torch.float).to(device).unsqueeze(0) # 将数据类型设为float,并转移到设备上
|
105 |
+
output_tensor = net_for_black(input_tensor).to(device)
|
106 |
+
loss = loss_fn(output_tensor, scores)
|
107 |
+
test_loss += loss.item()
|
108 |
+
test_loss /=len(test_data_for_black)
|
109 |
+
writer.add_scalar('test/loss', test_loss, epoch)
|
110 |
+
if best > test_loss:
|
111 |
+
best = test_loss
|
112 |
+
model_path = os.path.join(dir_path, 'train_data/model')
|
113 |
+
if not os.path.exists(model_path):
|
114 |
+
os.makedirs(model_path)
|
115 |
+
torch.save(net_for_black.state_dict(),
|
116 |
+
os.path.join(model_path, f'best_loss={best}.pth'))
|
117 |
+
net_for_black.train()
|
Gomoku_Bot/minmax.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .cache import Cache
|
2 |
+
from .eval import FIVE
|
3 |
+
# Checked
|
4 |
+
|
5 |
+
|
6 |
+
MAX = 1000000000
|
7 |
+
cache_hits = {
|
8 |
+
"search": 0,
|
9 |
+
"total": 0,
|
10 |
+
"hit": 0
|
11 |
+
}
|
12 |
+
|
13 |
+
onlyThreeThreshold = 6
|
14 |
+
cache = Cache()
|
15 |
+
|
16 |
+
|
17 |
+
def factory(onlyThree=False, onlyFour=False):
|
18 |
+
|
19 |
+
def helper(board, role, depth, cDepth=0, path=(), alpha=-MAX, beta=MAX):
|
20 |
+
cache_hits["search"] += 1
|
21 |
+
if cDepth >= depth or board.isGameOver():
|
22 |
+
return [board.evaluate(role), None, (path)]
|
23 |
+
hash = board.hash()
|
24 |
+
prev = cache.get(hash)
|
25 |
+
if prev and prev["role"] == role:
|
26 |
+
if (
|
27 |
+
(abs(prev["value"]) >= FIVE or prev["depth"] >= depth - cDepth)
|
28 |
+
and prev["onlyThree"] == onlyThree
|
29 |
+
and prev["onlyFour"] == onlyFour
|
30 |
+
):
|
31 |
+
cache_hits["hit"] += 1
|
32 |
+
return [prev["value"], prev["move"], path + prev["path"]]
|
33 |
+
value = -MAX
|
34 |
+
move = None
|
35 |
+
bestPath = path # Copy the current path
|
36 |
+
bestDepth = 0
|
37 |
+
# points = board.getValuableMoves(role, cDepth, onlyThree or cDepth > onlyThreeThreshold, onlyFour)
|
38 |
+
points = board.getValuableMoves(role, cDepth, onlyThree or cDepth > onlyThreeThreshold, onlyFour)
|
39 |
+
if cDepth == 0:
|
40 |
+
print('points:', points)
|
41 |
+
if not len(points):
|
42 |
+
return [board.evaluate(role), None, path]
|
43 |
+
for d in range(cDepth + 1, depth + 1):
|
44 |
+
# 迭代加深过程中只找己方能赢的解,因此只搜索偶数层即可
|
45 |
+
if d % 2 != 0:
|
46 |
+
continue
|
47 |
+
breakAll = False
|
48 |
+
for point in points:
|
49 |
+
board.put(point[0], point[1], role)
|
50 |
+
newPath = tuple(list(path) + [point]) # Add current move to path
|
51 |
+
currentValue, currentMove, currentPath = helper(board, -role, d, cDepth + 1, tuple(newPath) , -beta, -alpha)
|
52 |
+
currentValue = -currentValue
|
53 |
+
board.undo()
|
54 |
+
## 迭代加深的过程中,除了能赢的棋,其他都不要
|
55 |
+
## 原因是:除了必胜的,其他评估不准。比如必输的棋,由于走的步数偏少,也会变成没有输,比如5
|
56 |
+
### 步之后输了,但是1步肯定不会输,这时候1步的分数是不准确的,显然不能选择。
|
57 |
+
if currentValue >= FIVE or d == depth:
|
58 |
+
# 必输的棋,也要挣扎一下,选择最长的路径
|
59 |
+
if (
|
60 |
+
currentValue > value
|
61 |
+
or (currentValue <= -FIVE and value <= -FIVE and len(currentPath) > bestDepth)
|
62 |
+
):
|
63 |
+
value = currentValue
|
64 |
+
move = point
|
65 |
+
bestPath = currentPath
|
66 |
+
bestDepth = len(currentPath)
|
67 |
+
alpha = max(alpha, value)
|
68 |
+
if alpha >= FIVE:
|
69 |
+
breakAll = True
|
70 |
+
break
|
71 |
+
if alpha >= beta:
|
72 |
+
break
|
73 |
+
if breakAll:
|
74 |
+
break
|
75 |
+
if (cDepth < onlyThreeThreshold or onlyThree or onlyFour) and (not prev or prev["depth"] < depth - cDepth):
|
76 |
+
cache_hits["total"] += 1
|
77 |
+
cache.put(hash, {
|
78 |
+
"depth": depth - cDepth,
|
79 |
+
"value": value,
|
80 |
+
"move": move,
|
81 |
+
"role": role,
|
82 |
+
"path": bestPath[cDepth:],
|
83 |
+
"onlyThree": onlyThree,
|
84 |
+
"onlyFour": onlyFour,
|
85 |
+
})
|
86 |
+
return [value, move, bestPath]
|
87 |
+
return helper
|
88 |
+
|
89 |
+
|
90 |
+
_minmax = factory()
|
91 |
+
vct = factory(True)
|
92 |
+
vcf = factory(False, True)
|
93 |
+
|
94 |
+
|
95 |
+
def minmax(board, role, depth=4, enableVCT=False):
|
96 |
+
|
97 |
+
if enableVCT:
|
98 |
+
vctDepth = depth + 8
|
99 |
+
value, move, bestPath = vct(board, role, vctDepth)
|
100 |
+
if value >= FIVE:
|
101 |
+
return [value, move, bestPath]
|
102 |
+
value, move, bestPath = _minmax(board, role, depth)
|
103 |
+
'''
|
104 |
+
// 假设对方有杀棋,先按自己的思路走,走完之后看对方是不是还有杀棋
|
105 |
+
// 如果对方没有了,那么就说明走的是对的
|
106 |
+
// 如果对方还是有,那么要对比对方的杀棋路径和自己没有走棋时的长短
|
107 |
+
// 如果走了棋之后路径变长了,说明走的是对的
|
108 |
+
// 如果走了棋之后,对方杀棋路径长度没变,甚至更短,说明走错了,此时就优先封堵对方
|
109 |
+
'''
|
110 |
+
board.put(move[0], move[1], role)
|
111 |
+
value2, move2, bestPath2 = vct(board.reverse(), role, vctDepth)
|
112 |
+
board.undo()
|
113 |
+
if value < FIVE and value2 == FIVE and len(bestPath2) > len(bestPath):
|
114 |
+
value3, move3, bestPath3 = vct(board.reverse(), role, vctDepth)
|
115 |
+
if len(bestPath2) <= len(bestPath3):
|
116 |
+
return [value, move2, bestPath2] # value2 是被挡住的,所以这里还是用value
|
117 |
+
return [value, move, bestPath]
|
118 |
+
else:
|
119 |
+
return _minmax(board, role, depth)
|
Gomoku_Bot/position.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
from .config import config
|
3 |
+
# Checked
|
4 |
+
|
5 |
+
# Function to convert position to coordinate
|
6 |
+
def position2Coordinate(position: int, size: int) -> List[int]:
|
7 |
+
return [position // size, position % size]
|
8 |
+
|
9 |
+
# Function to convert coordinate to position
|
10 |
+
def coordinate2Position(x: int, y: int, size: int) -> int:
|
11 |
+
return x * size + y
|
12 |
+
|
13 |
+
# Check if points a and b are on the same line and the distance is less than maxDistance
|
14 |
+
def isLine(a: int, b: int, size: int) -> bool:
|
15 |
+
maxDistance = config["inLineDistance"]
|
16 |
+
[x1, y1] = position2Coordinate(a, size)
|
17 |
+
[x2, y2] = position2Coordinate(b, size)
|
18 |
+
return (
|
19 |
+
(x1 == x2 and abs(y1 - y2) < maxDistance) or
|
20 |
+
(y1 == y2 and abs(x1 - x2) < maxDistance) or
|
21 |
+
(abs(x1 - x2) == abs(y1 - y2) and abs(x1 - x2) < maxDistance)
|
22 |
+
)
|
23 |
+
|
24 |
+
# Check if all points in the array are on the same line as point p
|
25 |
+
def isAllInLine(p: int, arr: List[int], size: int) -> bool:
|
26 |
+
for i in range(len(arr)):
|
27 |
+
if not isLine(p, arr[i], size):
|
28 |
+
return False
|
29 |
+
return True
|
30 |
+
|
31 |
+
# Check if any point in the array is on the same line as point p
|
32 |
+
def hasInLine(p: int, arr: List[int], size: int) -> bool:
|
33 |
+
for i in range(len(arr)):
|
34 |
+
if isLine(p, arr[i], size):
|
35 |
+
return True
|
36 |
+
return False
|
Gomoku_Bot/shape.py
ADDED
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import re
|
3 |
+
|
4 |
+
# Define patterns using regular expressions
|
5 |
+
patterns = {
|
6 |
+
'five': re.compile('11111'),
|
7 |
+
'block_five': re.compile('211111|111112'),
|
8 |
+
'four': re.compile('011110'),
|
9 |
+
'block_four': re.compile('10111|11011|11101|211110|211101|211011|210111|011112|101112|110112|111012'),
|
10 |
+
'three': re.compile('011100|011010|010110|001110'),
|
11 |
+
'block_three': re.compile('211100|211010|210110|001112|010112|011012'),
|
12 |
+
'two': re.compile('001100|011000|000110|010100|001010'),
|
13 |
+
}
|
14 |
+
|
15 |
+
# Define shapes with associated scores
|
16 |
+
shapes = {
|
17 |
+
'FIVE': 5,
|
18 |
+
'BLOCK_FIVE': 50,
|
19 |
+
'FOUR': 4,
|
20 |
+
'FOUR_FOUR': 44, # Double four
|
21 |
+
'FOUR_THREE': 43, # Four with an open three
|
22 |
+
'THREE_THREE': 33, # Double three
|
23 |
+
'BLOCK_FOUR': 40,
|
24 |
+
'THREE': 3,
|
25 |
+
'BLOCK_THREE': 30,
|
26 |
+
'TWO_TWO': 22, # Double two
|
27 |
+
'TWO': 2,
|
28 |
+
'NONE': 0
|
29 |
+
}
|
30 |
+
|
31 |
+
# Initialize a performance record
|
32 |
+
performance = {
|
33 |
+
'five': 0,
|
34 |
+
'block_five': 0,
|
35 |
+
'four': 0,
|
36 |
+
'block_four': 0,
|
37 |
+
'three': 0,
|
38 |
+
'block_three': 0,
|
39 |
+
'two': 0,
|
40 |
+
'none': 0,
|
41 |
+
'total': 0
|
42 |
+
}
|
43 |
+
|
44 |
+
# Function to detect shapes on the board
|
45 |
+
def get_shape(board, x, y, offset_x, offset_y, role):
|
46 |
+
"""
|
47 |
+
Detect shape at a given board position.
|
48 |
+
:param board: The game board.
|
49 |
+
:param x: X-coordinate.
|
50 |
+
:param y: Y-coordinate.
|
51 |
+
:param offset_x: X-direction offset for scanning.
|
52 |
+
:param offset_y: Y-direction offset for scanning.
|
53 |
+
:param role: Current player's role.
|
54 |
+
:return: A tuple of shape, self count, opponent count, and empty count.
|
55 |
+
"""
|
56 |
+
opponent = -role
|
57 |
+
empty_count = 0
|
58 |
+
self_count = 1
|
59 |
+
opponent_count = 0
|
60 |
+
shape = shapes['NONE']
|
61 |
+
|
62 |
+
# Skip empty nodes
|
63 |
+
if (
|
64 |
+
board[x + offset_x + 1][y + offset_y + 1] == 0
|
65 |
+
and board[x - offset_x + 1][y - offset_y + 1] == 0
|
66 |
+
and board[x + 2 * offset_x + 1][y + 2 * offset_y + 1] == 0
|
67 |
+
and board[x - 2 * offset_x + 1][y - 2 * offset_y + 1] == 0
|
68 |
+
):
|
69 |
+
return [0, self_count, opponent_count, empty_count]
|
70 |
+
# Check for 'two' pattern
|
71 |
+
for i in range(-3, 4):
|
72 |
+
if i == 0:
|
73 |
+
continue
|
74 |
+
nx, ny = x + i * offset_x, y + i * offset_y
|
75 |
+
current_role = board.get((nx, ny))
|
76 |
+
if current_role is None:
|
77 |
+
continue
|
78 |
+
if current_role == 2:
|
79 |
+
opponent_count += 1
|
80 |
+
elif current_role == role:
|
81 |
+
self_count += 1
|
82 |
+
elif current_role == 0:
|
83 |
+
empty_count += 1
|
84 |
+
|
85 |
+
if self_count == 2:
|
86 |
+
if opponent_count == 0:
|
87 |
+
return shapes['TWO'], self_count, opponent_count, empty_count
|
88 |
+
else:
|
89 |
+
return shapes['NONE'], self_count, opponent_count, empty_count
|
90 |
+
|
91 |
+
# Reset counts and prepare string for pattern matching
|
92 |
+
empty_count, self_count, opponent_count = 0, 1, 0
|
93 |
+
result_string = '1'
|
94 |
+
|
95 |
+
# Build result string for pattern matching
|
96 |
+
for i in range(1, 6):
|
97 |
+
nx = x + i * offset_x + 1
|
98 |
+
ny = y + i * offset_y + 1
|
99 |
+
currentRole = board[nx][ny]
|
100 |
+
if currentRole == 2:
|
101 |
+
result_string += '2'
|
102 |
+
elif currentRole == 0:
|
103 |
+
result_string += '0'
|
104 |
+
else:
|
105 |
+
result_string += '1' if currentRole == role else '2'
|
106 |
+
if currentRole == 2 or currentRole == opponent:
|
107 |
+
opponent_count += 1
|
108 |
+
break
|
109 |
+
if currentRole == 0:
|
110 |
+
empty_count += 1
|
111 |
+
if currentRole == role:
|
112 |
+
self_count += 1
|
113 |
+
|
114 |
+
for i in range(1, 6):
|
115 |
+
nx = x - i * offset_x + 1
|
116 |
+
ny = y - i * offset_y + 1
|
117 |
+
currentRole = board[nx][ny]
|
118 |
+
if currentRole == 2:
|
119 |
+
result_string = '2' + result_string
|
120 |
+
elif currentRole == 0:
|
121 |
+
result_string = '0' + result_string
|
122 |
+
else:
|
123 |
+
result_string = '1' if currentRole == role else '2' + result_string
|
124 |
+
if currentRole == 2 or currentRole == opponent:
|
125 |
+
opponent_count += 1
|
126 |
+
break
|
127 |
+
if currentRole == 0:
|
128 |
+
empty_count += 1
|
129 |
+
if currentRole == role:
|
130 |
+
self_count += 1
|
131 |
+
|
132 |
+
# Check patterns and update performance
|
133 |
+
for pattern_key, shape_key in [('five', 'FIVE'), ('four', 'FOUR'), ('block_four', 'BLOCK_FOUR'),
|
134 |
+
('three', 'THREE'), ('block_three', 'BLOCK_THREE'), ('two', 'TWO')]:
|
135 |
+
if patterns[pattern_key].search(result_string):
|
136 |
+
shape = shapes[shape_key]
|
137 |
+
performance[pattern_key] += 1
|
138 |
+
performance['total'] += 1
|
139 |
+
break
|
140 |
+
## 尽量减少多余字符串生成
|
141 |
+
if self_count <= 1 or len(result_string) < 5:
|
142 |
+
return shape, self_count, opponent_count, empty_count
|
143 |
+
|
144 |
+
return shape, self_count, opponent_count, empty_count
|
145 |
+
|
146 |
+
def count_shape(board, x, y, offset_x, offset_y, role):
|
147 |
+
opponent = - role
|
148 |
+
|
149 |
+
inner_empty_count = 0 # Number of empty positions inside the player's stones
|
150 |
+
temp_empty_count = 0
|
151 |
+
self_count = 0 # Number of the player's stones in the shape
|
152 |
+
total_length = 0
|
153 |
+
|
154 |
+
side_empty_count = 0 # Number of empty positions on the side of the shape
|
155 |
+
no_empty_self_count = 0
|
156 |
+
one_empty_self_count = 0
|
157 |
+
|
158 |
+
# Right direction
|
159 |
+
for i in range(1, 6):
|
160 |
+
nx = x + i * offset_x + 1
|
161 |
+
ny = y + i * offset_y + 1
|
162 |
+
current_role = board[nx][ny]
|
163 |
+
if current_role == 2 or current_role == opponent:
|
164 |
+
break
|
165 |
+
if current_role == role:
|
166 |
+
self_count += 1
|
167 |
+
side_empty_count = 0
|
168 |
+
if temp_empty_count:
|
169 |
+
inner_empty_count += temp_empty_count
|
170 |
+
temp_empty_count = 0
|
171 |
+
if inner_empty_count == 0:
|
172 |
+
no_empty_self_count += 1
|
173 |
+
one_empty_self_count += 1
|
174 |
+
elif inner_empty_count == 1:
|
175 |
+
one_empty_self_count += 1
|
176 |
+
total_length += 1
|
177 |
+
if current_role == 0:
|
178 |
+
temp_empty_count += 1
|
179 |
+
side_empty_count += 1
|
180 |
+
if side_empty_count >= 2:
|
181 |
+
break
|
182 |
+
|
183 |
+
if not inner_empty_count:
|
184 |
+
one_empty_self_count = 0
|
185 |
+
|
186 |
+
return {
|
187 |
+
'self_count': self_count,
|
188 |
+
'total_length': total_length,
|
189 |
+
'no_empty_self_count': no_empty_self_count,
|
190 |
+
'one_empty_self_count': one_empty_self_count,
|
191 |
+
'inner_empty_count': inner_empty_count,
|
192 |
+
'side_empty_count': side_empty_count
|
193 |
+
}
|
194 |
+
|
195 |
+
# Fast shape detection function
|
196 |
+
def get_shape_fast(board, x, y, offsetX, offsetY, role):
|
197 |
+
if (
|
198 |
+
board[x + offsetX + 1][y + offsetY + 1] == 0
|
199 |
+
and board[x - offsetX + 1][y - offsetY + 1] == 0
|
200 |
+
and board[x + 2 * offsetX + 1][y + 2 * offsetY + 1] == 0
|
201 |
+
and board[x - 2 * offsetX + 1][y - 2 * offsetY + 1] == 0
|
202 |
+
):
|
203 |
+
return [shapes['NONE'], 1]
|
204 |
+
|
205 |
+
selfCount = 1
|
206 |
+
totalLength = 1
|
207 |
+
shape = shapes['NONE']
|
208 |
+
|
209 |
+
leftEmpty = 0
|
210 |
+
rightEmpty = 0
|
211 |
+
noEmptySelfCount = 1
|
212 |
+
OneEmptySelfCount = 1
|
213 |
+
|
214 |
+
left = count_shape(board, x, y, -offsetX, -offsetY, role)
|
215 |
+
right = count_shape(board, x, y, offsetX, offsetY, role)
|
216 |
+
|
217 |
+
selfCount = left['self_count'] + right['self_count'] + 1
|
218 |
+
totalLength = left['total_length'] + right['total_length'] + 1
|
219 |
+
noEmptySelfCount = left['no_empty_self_count'] + right['no_empty_self_count'] + 1
|
220 |
+
OneEmptySelfCount = max(
|
221 |
+
left['one_empty_self_count'] + right['no_empty_self_count'],
|
222 |
+
left['no_empty_self_count'] + right['one_empty_self_count'],
|
223 |
+
) + 1
|
224 |
+
rightEmpty = right['side_empty_count']
|
225 |
+
leftEmpty = left['side_empty_count']
|
226 |
+
|
227 |
+
if totalLength < 5:
|
228 |
+
return [shape, selfCount]
|
229 |
+
|
230 |
+
if noEmptySelfCount >= 5:
|
231 |
+
if rightEmpty > 0 and leftEmpty > 0:
|
232 |
+
return [shapes['FIVE'], selfCount]
|
233 |
+
else:
|
234 |
+
return [shapes['BLOCK_FIVE'], selfCount]
|
235 |
+
|
236 |
+
if noEmptySelfCount == 4:
|
237 |
+
if (
|
238 |
+
(rightEmpty >= 1 or right['one_empty_self_count'] > right['no_empty_self_count'])
|
239 |
+
and (leftEmpty >= 1 or left['one_empty_self_count'] > left['no_empty_self_count'])
|
240 |
+
):
|
241 |
+
return [shapes['FOUR'], selfCount]
|
242 |
+
elif not (rightEmpty == 0 and leftEmpty == 0):
|
243 |
+
return [shapes['BLOCK_FOUR'], selfCount]
|
244 |
+
|
245 |
+
if OneEmptySelfCount == 4:
|
246 |
+
return [shapes['BLOCK_FOUR'], selfCount]
|
247 |
+
|
248 |
+
if noEmptySelfCount == 3:
|
249 |
+
if (rightEmpty >= 2 and leftEmpty >= 1) or (rightEmpty >= 1 and leftEmpty >= 2):
|
250 |
+
return [shapes['THREE'], selfCount]
|
251 |
+
else:
|
252 |
+
return [shapes['BLOCK_THREE'], selfCount]
|
253 |
+
|
254 |
+
if OneEmptySelfCount == 3:
|
255 |
+
if rightEmpty >= 1 and leftEmpty >= 1:
|
256 |
+
return [shapes['THREE'], selfCount]
|
257 |
+
else:
|
258 |
+
return [shapes['BLOCK_THREE'], selfCount]
|
259 |
+
|
260 |
+
if (noEmptySelfCount == 2 or OneEmptySelfCount == 2) and totalLength > 5:
|
261 |
+
shape = shapes['TWO']
|
262 |
+
|
263 |
+
return [shape, selfCount]
|
264 |
+
|
265 |
+
# Helper functions to check for specific shapes
|
266 |
+
def is_five(shape):
|
267 |
+
# Checked
|
268 |
+
return shape in [shapes['FIVE'], shapes['BLOCK_FIVE']]
|
269 |
+
|
270 |
+
def is_four(shape):
|
271 |
+
# Checked
|
272 |
+
return shape in [shapes['FOUR'], shapes['BLOCK_FOUR']]
|
273 |
+
|
274 |
+
# Function to get all shapes at a specific point
|
275 |
+
def get_all_shapes_of_point(shape_cache, x, y, role = None):
|
276 |
+
# Checked
|
277 |
+
roles = [role] if role else [1, -1]
|
278 |
+
result = []
|
279 |
+
for r in roles:
|
280 |
+
for d in range(4):
|
281 |
+
shape = shape_cache[r][d][x][y]
|
282 |
+
if shape > 0:
|
283 |
+
result.append(shape)
|
284 |
+
return result
|
285 |
+
|
286 |
+
|
287 |
+
if __name__ == "__main__":
|
288 |
+
pass
|
Gomoku_Bot/zobrist.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
# Checked
|
3 |
+
class ZobristCache:
|
4 |
+
def __init__(self, size):
|
5 |
+
self.size = size
|
6 |
+
self.zobristTable = self.initializeZobristTable(size)
|
7 |
+
self.hash = 0
|
8 |
+
|
9 |
+
def initializeZobristTable(self, size):
|
10 |
+
table = []
|
11 |
+
for i in range(size):
|
12 |
+
table.append([])
|
13 |
+
for j in range(size):
|
14 |
+
table[i].append({
|
15 |
+
1: random.getrandbits(64), # black
|
16 |
+
-1: random.getrandbits(64) # white
|
17 |
+
})
|
18 |
+
return table
|
19 |
+
|
20 |
+
def togglePiece(self, x, y, role):
|
21 |
+
self.hash ^= self.zobristTable[x][y][role]
|
22 |
+
|
23 |
+
def getHash(self):
|
24 |
+
return self.hash
|
25 |
+
|
26 |
+
if __name__ == '__main__':
|
27 |
+
# Example usage
|
28 |
+
size = 8
|
29 |
+
cache = ZobristCache(size)
|
30 |
+
x = 3
|
31 |
+
y = 4
|
32 |
+
role = 1
|
33 |
+
cache.togglePiece(x, y, role)
|
34 |
+
hash_value = cache.getHash()
|
35 |
+
print(hash_value)
|
ai.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from transformers import GPT2LMHeadModel
|
4 |
+
|
5 |
+
|
6 |
+
def load_model(model_name: str = "snoop2head/Gomoku-GPT2") -> GPT2LMHeadModel:
|
7 |
+
gpt2 = GPT2LMHeadModel.from_pretrained(model_name)
|
8 |
+
return gpt2
|
9 |
+
|
10 |
+
|
11 |
+
BOS_TOKEN_ID = 401
|
12 |
+
PAD_TOKEN_ID = 402
|
13 |
+
EOS_TOKEN_ID = 403
|
14 |
+
|
15 |
+
|
16 |
+
def generate_gpt2(model: GPT2LMHeadModel, input_ids: torch.LongTensor,temperature = 0.7) -> list:
|
17 |
+
"""
|
18 |
+
input_ids: [batch_size, seq_len] torch.LongTensor
|
19 |
+
output_ids: [seq_len] list
|
20 |
+
"""
|
21 |
+
output_ids = model.generate(
|
22 |
+
input_ids,
|
23 |
+
max_length=128,
|
24 |
+
num_beams=5,
|
25 |
+
temperature= temperature,
|
26 |
+
pad_token_id=PAD_TOKEN_ID,
|
27 |
+
eos_token_id=EOS_TOKEN_ID,
|
28 |
+
)
|
29 |
+
return output_ids.squeeze().tolist()
|
30 |
+
|
31 |
+
|
32 |
+
def change_to_1d_coordinate(board: np.ndarray, x: int, y: int) -> int:
|
33 |
+
"""change 2d coordinate to 1d coordinate"""
|
34 |
+
return x * board.shape[1] + y
|
35 |
+
|
36 |
+
|
37 |
+
def change_to_2d_coordinate(board: np.ndarray, coordinate: int) -> tuple:
|
38 |
+
"""change 1d coordinate to 2d coordinate"""
|
39 |
+
return (coordinate // board.shape[1], coordinate % board.shape[1])
|
const.py
CHANGED
@@ -14,10 +14,13 @@ _BOARD_SIZE_1D = _BOARD_SIZE * _BOARD_SIZE
|
|
14 |
_BLANK = 0
|
15 |
_BLACK = 1
|
16 |
_WHITE = 2
|
|
|
17 |
_PLAYER_SYMBOL = {
|
18 |
_WHITE: "⚪",
|
19 |
_BLANK: "➕",
|
20 |
_BLACK: "⚫",
|
|
|
|
|
21 |
}
|
22 |
_PLAYER_COLOR = {
|
23 |
_WHITE: "AI",
|
|
|
14 |
_BLANK = 0
|
15 |
_BLACK = 1
|
16 |
_WHITE = 2
|
17 |
+
_NEW = 3
|
18 |
_PLAYER_SYMBOL = {
|
19 |
_WHITE: "⚪",
|
20 |
_BLANK: "➕",
|
21 |
_BLACK: "⚫",
|
22 |
+
_NEW: "🔴",
|
23 |
+
|
24 |
}
|
25 |
_PLAYER_COLOR = {
|
26 |
_WHITE: "AI",
|
pages/AI_VS_AI.py
ADDED
@@ -0,0 +1,575 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
FileName: app.py
|
3 |
+
Author: Benhao Huang
|
4 |
+
Create Date: 2023/11/19
|
5 |
+
Description: this file is used to display our project and add visualization elements to the game, using Streamlit
|
6 |
+
"""
|
7 |
+
|
8 |
+
import time
|
9 |
+
import pandas as pd
|
10 |
+
from copy import deepcopy
|
11 |
+
import torch
|
12 |
+
|
13 |
+
# import torch
|
14 |
+
import numpy as np
|
15 |
+
import streamlit as st
|
16 |
+
from scipy.signal import convolve # this is used to check if any player wins
|
17 |
+
from streamlit import session_state
|
18 |
+
from streamlit_server_state import server_state, server_state_lock
|
19 |
+
from Gomoku_MCTS import MCTSpure, alphazero, Board, PolicyValueNet
|
20 |
+
from Gomoku_Bot import Gomoku_bot
|
21 |
+
from Gomoku_Bot import Board as Gomoku_bot_board
|
22 |
+
import matplotlib.pyplot as plt
|
23 |
+
|
24 |
+
|
25 |
+
|
26 |
+
from const import (
|
27 |
+
_BLACK, # 1, for human
|
28 |
+
_WHITE, # 2 , for AI
|
29 |
+
_BLANK,
|
30 |
+
_PLAYER_COLOR,
|
31 |
+
_PLAYER_SYMBOL,
|
32 |
+
_ROOM_COLOR,
|
33 |
+
_VERTICAL,
|
34 |
+
_NEW,
|
35 |
+
_HORIZONTAL,
|
36 |
+
_DIAGONAL_UP_LEFT,
|
37 |
+
_DIAGONAL_UP_RIGHT,
|
38 |
+
_BOARD_SIZE,
|
39 |
+
_BOARD_SIZE_1D,
|
40 |
+
_AI_AID_INFO
|
41 |
+
)
|
42 |
+
|
43 |
+
|
44 |
+
from ai import (
|
45 |
+
BOS_TOKEN_ID,
|
46 |
+
generate_gpt2,
|
47 |
+
load_model,
|
48 |
+
)
|
49 |
+
|
50 |
+
gpt2 = load_model()
|
51 |
+
|
52 |
+
|
53 |
+
# Utils
|
54 |
+
class Room:
|
55 |
+
def __init__(self, room_id) -> None:
|
56 |
+
self.ROOM_ID = room_id
|
57 |
+
# self.BOARD = np.zeros(shape=(_BOARD_SIZE, _BOARD_SIZE), dtype=int)
|
58 |
+
self.BOARD = Board(width=_BOARD_SIZE, height=_BOARD_SIZE, n_in_row=5, players=[_BLACK, _WHITE])
|
59 |
+
self.PLAYER = _BLACK
|
60 |
+
self.TURN = self.PLAYER
|
61 |
+
self.HISTORY = (0, 0)
|
62 |
+
self.WINNER = _BLANK
|
63 |
+
self.TIME = time.time()
|
64 |
+
self.gomoku_bot_board = Gomoku_bot_board(_BOARD_SIZE, 1)
|
65 |
+
self.MCTS_dict = {'Pure MCTS': MCTSpure(c_puct=5, n_playout=1000),
|
66 |
+
'AlphaZero': alphazero(PolicyValueNet(_BOARD_SIZE, _BOARD_SIZE, 'Gomoku_MCTS/checkpoints/best_policy_8_8_5_2torch.pth').policy_value_fn, c_puct=5, n_playout=100),
|
67 |
+
'Gomoku Bot': Gomoku_bot(self.gomoku_bot_board, -1)}
|
68 |
+
self.MCTS = self.MCTS_dict['AlphaZero']
|
69 |
+
self.last_mcts = self.MCTS
|
70 |
+
self.AID_MCTS = self.MCTS_dict['AlphaZero']
|
71 |
+
self.COORDINATE_1D = [BOS_TOKEN_ID]
|
72 |
+
self.current_move = -1
|
73 |
+
self.simula_time_list = []
|
74 |
+
|
75 |
+
|
76 |
+
def change_turn(cur):
|
77 |
+
return cur % 2 + 1
|
78 |
+
|
79 |
+
|
80 |
+
# Initialize the game
|
81 |
+
if "ROOM" not in session_state:
|
82 |
+
session_state.ROOM = Room("local")
|
83 |
+
if "OWNER" not in session_state:
|
84 |
+
session_state.OWNER = False
|
85 |
+
if "USE_AIAID" not in session_state:
|
86 |
+
session_state.USE_AIAID = False
|
87 |
+
|
88 |
+
# Check server health
|
89 |
+
if "ROOMS" not in server_state:
|
90 |
+
with server_state_lock["ROOMS"]:
|
91 |
+
server_state.ROOMS = {}
|
92 |
+
|
93 |
+
def handle_oppo_model_selection():
|
94 |
+
if st.session_state['selected_oppo_model'] == 'Gomoku Bot':
|
95 |
+
session_state.ROOM.last_mcts = session_state.ROOM.MCTS # since use different mechanism, store previous mcts first
|
96 |
+
session_state.ROOM.MCTS = session_state.ROOM.MCTS_dict['Gomoku Bot']
|
97 |
+
return
|
98 |
+
else:
|
99 |
+
TreeNode = session_state.ROOM.last_mcts.mcts._root
|
100 |
+
new_mct = session_state.ROOM.MCTS_dict[st.session_state['selected_oppo_model']]
|
101 |
+
new_mct.mcts._root = deepcopy(TreeNode)
|
102 |
+
session_state.ROOM.MCTS = new_mct
|
103 |
+
session_state.ROOM.last_mcts = new_mct
|
104 |
+
return
|
105 |
+
|
106 |
+
def handle_aid_model_selection():
|
107 |
+
if st.session_state['selected_aid_model'] == 'None':
|
108 |
+
session_state.USE_AIAID = False
|
109 |
+
return
|
110 |
+
session_state.USE_AIAID = True
|
111 |
+
TreeNode = session_state.ROOM.MCTS.mcts._root # use the same tree node
|
112 |
+
new_mct = session_state.ROOM.MCTS_dict[st.session_state['selected_aid_model']]
|
113 |
+
new_mct.mcts._root = deepcopy(TreeNode)
|
114 |
+
session_state.ROOM.AID_MCTS = new_mct
|
115 |
+
return
|
116 |
+
|
117 |
+
if 'selected_oppo_model' not in st.session_state:
|
118 |
+
st.session_state['selected_oppo_model'] = 'AlphaZero' # 默认值
|
119 |
+
|
120 |
+
if 'selected_aid_model' not in st.session_state:
|
121 |
+
st.session_state['selected_aid_model'] = 'AlphaZero' # 默认值
|
122 |
+
|
123 |
+
# Layout
|
124 |
+
TITLE = st.empty()
|
125 |
+
Model_Switch = st.empty()
|
126 |
+
|
127 |
+
TITLE.header("🤖 AI 3603 Gomoku")
|
128 |
+
selected_oppo_option = Model_Switch.selectbox('Select Opponent Model', ['Pure MCTS', 'AlphaZero','Gomoku Bot'], index=1, key='oppo_model')
|
129 |
+
|
130 |
+
if st.session_state['selected_oppo_model'] != selected_oppo_option:
|
131 |
+
st.session_state['selected_oppo_model'] = selected_oppo_option
|
132 |
+
handle_oppo_model_selection()
|
133 |
+
|
134 |
+
ROUND_INFO = st.empty()
|
135 |
+
st.markdown("<br>", unsafe_allow_html=True)
|
136 |
+
BOARD_PLATE = [
|
137 |
+
[cell.empty() for cell in st.columns([1 for _ in range(_BOARD_SIZE)])] for _ in range(_BOARD_SIZE)
|
138 |
+
]
|
139 |
+
LOG = st.empty()
|
140 |
+
|
141 |
+
# Sidebar
|
142 |
+
SCORE_TAG = st.sidebar.empty()
|
143 |
+
SCORE_PLATE = st.sidebar.columns(2)
|
144 |
+
# History scores
|
145 |
+
SCORE_TAG.subheader("Scores")
|
146 |
+
|
147 |
+
PLAY_MODE_INFO = st.sidebar.container()
|
148 |
+
MULTIPLAYER_TAG = st.sidebar.empty()
|
149 |
+
with st.sidebar.container():
|
150 |
+
ANOTHER_ROUND = st.empty()
|
151 |
+
RESTART = st.empty()
|
152 |
+
AIAID = st.empty()
|
153 |
+
EXIT = st.empty()
|
154 |
+
selected_aid_option = AIAID.selectbox('Select Assistant Model', ['None', 'Pure MCTS', 'AlphaZero'], index=0, key='aid_model')
|
155 |
+
if st.session_state['selected_aid_model'] != selected_aid_option:
|
156 |
+
st.session_state['selected_aid_model'] = selected_aid_option
|
157 |
+
handle_aid_model_selection()
|
158 |
+
|
159 |
+
GAME_INFO = st.sidebar.container()
|
160 |
+
message = st.empty()
|
161 |
+
PLAY_MODE_INFO.write("---\n\n**You are Black, AI agent is White.**")
|
162 |
+
GAME_INFO.markdown(
|
163 |
+
"""
|
164 |
+
---
|
165 |
+
# <span style="color:black;">Freestyle Gomoku game. 🎲</span>
|
166 |
+
- no restrictions 🚫
|
167 |
+
- no regrets 😎
|
168 |
+
- no regrets 😎
|
169 |
+
- swap players after one round is over 🔁
|
170 |
+
Powered by an AlphaZero approach with our own improvements! 🚀 For the specific details, please check out our <a href="insert_report_link_here" style="color:blue;">report</a>.
|
171 |
+
##### Adapted and improved by us! 🌟 <a href="https://github.com/Lijiaxin0111/AI_3603_BIGHOME" style="color:blue;">Our Github repo</a>
|
172 |
+
""",
|
173 |
+
unsafe_allow_html=True,
|
174 |
+
)
|
175 |
+
|
176 |
+
|
177 |
+
|
178 |
+
def restart() -> None:
|
179 |
+
"""
|
180 |
+
Restart the game.
|
181 |
+
"""
|
182 |
+
session_state.ROOM = Room(session_state.ROOM.ROOM_ID)
|
183 |
+
st.session_state['selected_oppo_model'] = 'AlphaZero'
|
184 |
+
|
185 |
+
RESTART.button(
|
186 |
+
"Reset",
|
187 |
+
on_click=restart,
|
188 |
+
help="Clear the board as well as the scores",
|
189 |
+
)
|
190 |
+
|
191 |
+
|
192 |
+
# Draw the board
|
193 |
+
def gomoku():
|
194 |
+
"""
|
195 |
+
Draw the board.
|
196 |
+
Handle the main logic.
|
197 |
+
"""
|
198 |
+
|
199 |
+
# Restart the game
|
200 |
+
|
201 |
+
# Continue new round
|
202 |
+
def another_round() -> None:
|
203 |
+
"""
|
204 |
+
Continue new round.
|
205 |
+
"""
|
206 |
+
session_state.ROOM = deepcopy(session_state.ROOM)
|
207 |
+
session_state.ROOM.BOARD = Board(width=_BOARD_SIZE, height=_BOARD_SIZE, n_in_row=5)
|
208 |
+
session_state.ROOM.gomoku_bot_board = Gomoku_bot_board(_BOARD_SIZE, 1)
|
209 |
+
session_state.ROOM.MCTS_dict = {'Pure MCTS': MCTSpure(c_puct=5, n_playout=1000),
|
210 |
+
'AlphaZero': alphazero(PolicyValueNet(_BOARD_SIZE, _BOARD_SIZE, 'Gomoku_MCTS/checkpoints/best_policy_8_8_5_2torch.pth').policy_value_fn, c_puct=5, n_playout=100),
|
211 |
+
'Gomoku Bot': Gomoku_bot(session_state.ROOM.gomoku_bot_board, -1)}
|
212 |
+
session_state.ROOM.MCTS = session_state.ROOM.MCTS_dict[st.session_state['selected_oppo_model']]
|
213 |
+
session_state.ROOM.last_mcts = session_state.ROOM.MCTS
|
214 |
+
session_state.ROOM.PLAYER = session_state.ROOM.PLAYER
|
215 |
+
session_state.ROOM.TURN = session_state.ROOM.PLAYER
|
216 |
+
session_state.ROOM.WINNER = _BLANK # 0
|
217 |
+
session_state.ROOM.COORDINATE_1D = [BOS_TOKEN_ID]
|
218 |
+
|
219 |
+
# Room status sync
|
220 |
+
def sync_room() -> bool:
|
221 |
+
room_id = session_state.ROOM.ROOM_ID
|
222 |
+
if room_id not in server_state.ROOMS.keys():
|
223 |
+
session_state.ROOM = Room("local")
|
224 |
+
return False
|
225 |
+
elif server_state.ROOMS[room_id].TIME == session_state.ROOM.TIME:
|
226 |
+
return False
|
227 |
+
elif server_state.ROOMS[room_id].TIME < session_state.ROOM.TIME:
|
228 |
+
# Only acquire the lock when writing to the server state
|
229 |
+
with server_state_lock["ROOMS"]:
|
230 |
+
server_rooms = server_state.ROOMS
|
231 |
+
server_rooms[room_id] = session_state.ROOM
|
232 |
+
server_state.ROOMS = server_rooms
|
233 |
+
return True
|
234 |
+
else:
|
235 |
+
session_state.ROOM = server_state.ROOMS[room_id]
|
236 |
+
return True
|
237 |
+
|
238 |
+
# Check if winner emerge from move
|
239 |
+
def check_win() -> int:
|
240 |
+
"""
|
241 |
+
Use convolution to check if any player wins.
|
242 |
+
"""
|
243 |
+
vertical = convolve(
|
244 |
+
session_state.ROOM.BOARD.board_map,
|
245 |
+
_VERTICAL,
|
246 |
+
mode="same",
|
247 |
+
)
|
248 |
+
horizontal = convolve(
|
249 |
+
session_state.ROOM.BOARD.board_map,
|
250 |
+
_HORIZONTAL,
|
251 |
+
mode="same",
|
252 |
+
)
|
253 |
+
diagonal_up_left = convolve(
|
254 |
+
session_state.ROOM.BOARD.board_map,
|
255 |
+
_DIAGONAL_UP_LEFT,
|
256 |
+
mode="same",
|
257 |
+
)
|
258 |
+
diagonal_up_right = convolve(
|
259 |
+
session_state.ROOM.BOARD.board_map,
|
260 |
+
_DIAGONAL_UP_RIGHT,
|
261 |
+
mode="same",
|
262 |
+
)
|
263 |
+
if (
|
264 |
+
np.max(
|
265 |
+
[
|
266 |
+
np.max(vertical),
|
267 |
+
np.max(horizontal),
|
268 |
+
np.max(diagonal_up_left),
|
269 |
+
np.max(diagonal_up_right),
|
270 |
+
]
|
271 |
+
)
|
272 |
+
== 5 * _BLACK
|
273 |
+
):
|
274 |
+
winner = _BLACK
|
275 |
+
elif (
|
276 |
+
np.min(
|
277 |
+
[
|
278 |
+
np.min(vertical),
|
279 |
+
np.min(horizontal),
|
280 |
+
np.min(diagonal_up_left),
|
281 |
+
np.min(diagonal_up_right),
|
282 |
+
]
|
283 |
+
)
|
284 |
+
== 5 * _WHITE
|
285 |
+
):
|
286 |
+
winner = _WHITE
|
287 |
+
else:
|
288 |
+
winner = _BLANK
|
289 |
+
return winner
|
290 |
+
|
291 |
+
# Triggers the board response on click
|
292 |
+
def handle_click(x, y):
|
293 |
+
"""
|
294 |
+
Controls whether to pass on / continue current board / may start new round
|
295 |
+
"""
|
296 |
+
if session_state.ROOM.BOARD.board_map[x][y] != _BLANK:
|
297 |
+
pass
|
298 |
+
elif (
|
299 |
+
session_state.ROOM.ROOM_ID in server_state.ROOMS.keys()
|
300 |
+
and _ROOM_COLOR[session_state.OWNER]
|
301 |
+
!= server_state.ROOMS[session_state.ROOM.ROOM_ID].TURN
|
302 |
+
):
|
303 |
+
sync_room()
|
304 |
+
|
305 |
+
# normal play situation
|
306 |
+
elif session_state.ROOM.WINNER == _BLANK:
|
307 |
+
# session_state.ROOM = deepcopy(session_state.ROOM)
|
308 |
+
# print("View of human player: ", session_state.ROOM.BOARD.board_map)
|
309 |
+
move = session_state.ROOM.BOARD.location_to_move((x, y))
|
310 |
+
session_state.ROOM.current_move = move
|
311 |
+
session_state.ROOM.BOARD.do_move(move)
|
312 |
+
# Gomoku Bot BOARD
|
313 |
+
session_state.ROOM.MCTS_dict["Gomoku Bot"].board.put(move // _BOARD_SIZE, move % _BOARD_SIZE)
|
314 |
+
session_state.ROOM.BOARD.board_map[x][y] = session_state.ROOM.TURN
|
315 |
+
session_state.ROOM.COORDINATE_1D.append(x * _BOARD_SIZE + y)
|
316 |
+
|
317 |
+
session_state.ROOM.TURN = change_turn(session_state.ROOM.TURN)
|
318 |
+
win, winner = session_state.ROOM.BOARD.game_end()
|
319 |
+
if win:
|
320 |
+
session_state.ROOM.WINNER = winner
|
321 |
+
session_state.ROOM.HISTORY = (
|
322 |
+
session_state.ROOM.HISTORY[0]
|
323 |
+
+ int(session_state.ROOM.WINNER == _WHITE),
|
324 |
+
session_state.ROOM.HISTORY[1]
|
325 |
+
+ int(session_state.ROOM.WINNER == _BLACK),
|
326 |
+
)
|
327 |
+
session_state.ROOM.TIME = time.time()
|
328 |
+
|
329 |
+
def forbid_click(x, y):
|
330 |
+
# st.warning('This posistion has been occupied!!!!', icon="⚠️")
|
331 |
+
st.error("({}, {}) has been occupied!!)".format(x, y), icon="🚨")
|
332 |
+
|
333 |
+
# Draw board
|
334 |
+
def draw_board(response: bool):
|
335 |
+
"""construct each buttons for all cells of the board"""
|
336 |
+
if session_state.USE_AIAID and session_state.ROOM.WINNER == _BLANK and session_state.ROOM.TURN == _BLACK:
|
337 |
+
if session_state.USE_AIAID:
|
338 |
+
copy_mcts = deepcopy(session_state.ROOM.AID_MCTS.mcts)
|
339 |
+
_, acts_aid, probs_aid, simul_mean_time_aid = copy_mcts.get_move_probs(session_state.ROOM.BOARD)
|
340 |
+
sorted_acts_probs = sorted(zip(acts_aid, probs_aid), key=lambda x: x[1], reverse=True)
|
341 |
+
top_five_acts = [act for act, prob in sorted_acts_probs[:5]]
|
342 |
+
top_five_probs = [prob for act, prob in sorted_acts_probs[:5]]
|
343 |
+
if response and session_state.ROOM.TURN == _BLACK: # human turn
|
344 |
+
print("Your turn")
|
345 |
+
# construction of clickable buttons
|
346 |
+
cur_move = (session_state.ROOM.current_move // _BOARD_SIZE, session_state.ROOM.current_move % _BOARD_SIZE)
|
347 |
+
for i, row in enumerate(session_state.ROOM.BOARD.board_map):
|
348 |
+
# print("row:", row)
|
349 |
+
for j, cell in enumerate(row):
|
350 |
+
if (
|
351 |
+
i * _BOARD_SIZE + j
|
352 |
+
in (session_state.ROOM.COORDINATE_1D)
|
353 |
+
):
|
354 |
+
if i == cur_move[0] and j == cur_move[1]:
|
355 |
+
BOARD_PLATE[i][j].button(
|
356 |
+
_PLAYER_SYMBOL[_NEW],
|
357 |
+
key=f"{i}:{j}",
|
358 |
+
args=(i, j),
|
359 |
+
on_click=handle_click,
|
360 |
+
)
|
361 |
+
else:
|
362 |
+
# disable click for GPT choices
|
363 |
+
BOARD_PLATE[i][j].button(
|
364 |
+
_PLAYER_SYMBOL[cell],
|
365 |
+
key=f"{i}:{j}",
|
366 |
+
args=(i, j),
|
367 |
+
on_click=forbid_click
|
368 |
+
)
|
369 |
+
else:
|
370 |
+
if session_state.USE_AIAID and i * _BOARD_SIZE + j in top_five_acts:
|
371 |
+
# enable click for other cells available for human choices
|
372 |
+
prob = top_five_probs[top_five_acts.index(i * _BOARD_SIZE + j)]
|
373 |
+
BOARD_PLATE[i][j].button(
|
374 |
+
_PLAYER_SYMBOL[cell] + f"({round(prob, 2)})",
|
375 |
+
key=f"{i}:{j}",
|
376 |
+
on_click=handle_click,
|
377 |
+
args=(i, j),
|
378 |
+
)
|
379 |
+
else:
|
380 |
+
# enable click for other cells available for human choices
|
381 |
+
BOARD_PLATE[i][j].button(
|
382 |
+
_PLAYER_SYMBOL[cell],
|
383 |
+
key=f"{i}:{j}",
|
384 |
+
on_click=handle_click,
|
385 |
+
args=(i, j),
|
386 |
+
)
|
387 |
+
|
388 |
+
|
389 |
+
elif response and session_state.ROOM.TURN == _WHITE: # AI turn
|
390 |
+
message.empty()
|
391 |
+
with st.spinner('🔮✨ Waiting for AI response... ⏳🚀'):
|
392 |
+
time.sleep(0.1)
|
393 |
+
print("AI's turn")
|
394 |
+
print("Below are current board under AI's view")
|
395 |
+
# print(session_state.ROOM.BOARD.board_map)
|
396 |
+
# move = _BOARD_SIZE * _BOARD_SIZE
|
397 |
+
# forbid = []
|
398 |
+
# step = 0.1
|
399 |
+
# tmp = 0.7
|
400 |
+
# while move >= _BOARD_SIZE * _BOARD_SIZE or move in session_state.ROOM.COORDINATE_1D:
|
401 |
+
#
|
402 |
+
# gpt_predictions = generate_gpt2(
|
403 |
+
# gpt2,
|
404 |
+
# torch.tensor(session_state.ROOM.COORDINATE_1D).unsqueeze(0),
|
405 |
+
# tmp
|
406 |
+
# )
|
407 |
+
# print(gpt_predictions)
|
408 |
+
# move = gpt_predictions[len(session_state.ROOM.COORDINATE_1D)]
|
409 |
+
# print(move)
|
410 |
+
# tmp += step
|
411 |
+
# # if move >= _BOARD_SIZE * _BOARD_SIZE:
|
412 |
+
# # forbid.append(move)
|
413 |
+
# # else:
|
414 |
+
# # break
|
415 |
+
#
|
416 |
+
#
|
417 |
+
# gpt_response = move
|
418 |
+
# gpt_i, gpt_j = gpt_response // _BOARD_SIZE, gpt_response % _BOARD_SIZE
|
419 |
+
# print(gpt_i, gpt_j)
|
420 |
+
# # session_state.ROOM.BOARD[gpt_i][gpt_j] = session_state.ROOM.TURN
|
421 |
+
#
|
422 |
+
# simul_time = 0
|
423 |
+
if st.session_state['selected_oppo_model'] != 'Gomoku Bot':
|
424 |
+
move, simul_time = session_state.ROOM.MCTS.get_action(session_state.ROOM.BOARD, return_time=True)
|
425 |
+
else:
|
426 |
+
move, simul_time = session_state.ROOM.MCTS.get_action(return_time=True)
|
427 |
+
session_state.ROOM.simula_time_list.append(simul_time)
|
428 |
+
print("AI takes move: ", move)
|
429 |
+
session_state.ROOM.current_move = move
|
430 |
+
gpt_response = move
|
431 |
+
gpt_i, gpt_j = gpt_response // _BOARD_SIZE, gpt_response % _BOARD_SIZE
|
432 |
+
print("AI's move is located at ({}, {}) :".format(gpt_i, gpt_j))
|
433 |
+
move = session_state.ROOM.BOARD.location_to_move((gpt_i, gpt_j))
|
434 |
+
print("Location to move: ", move)
|
435 |
+
# print("Location to move: ", move)
|
436 |
+
# MCTS BOARD
|
437 |
+
session_state.ROOM.BOARD.do_move(move)
|
438 |
+
# Gomoku Bot BOARD
|
439 |
+
session_state.ROOM.MCTS_dict["Gomoku Bot"].board.put(move // _BOARD_SIZE, move % _BOARD_SIZE)
|
440 |
+
# session_state.ROOM.BOARD[gpt_i][gpt_j] = session_state.ROOM.TURN
|
441 |
+
session_state.ROOM.COORDINATE_1D.append(gpt_i * _BOARD_SIZE + gpt_j)
|
442 |
+
|
443 |
+
if not session_state.ROOM.BOARD.game_end()[0]:
|
444 |
+
if session_state.USE_AIAID:
|
445 |
+
copy_mcts = deepcopy(session_state.ROOM.AID_MCTS.mcts)
|
446 |
+
_, acts_aid, probs_aid, simul_mean_time_aid = copy_mcts.get_move_probs(session_state.ROOM.BOARD)
|
447 |
+
sorted_acts_probs = sorted(zip(acts_aid, probs_aid), key=lambda x: x[1], reverse=True)
|
448 |
+
top_five_acts = [act for act, prob in sorted_acts_probs[:5]]
|
449 |
+
top_five_probs = [prob for act, prob in sorted_acts_probs[:5]]
|
450 |
+
else:
|
451 |
+
top_five_acts = []
|
452 |
+
top_five_probs = []
|
453 |
+
|
454 |
+
# construction of clickable buttons
|
455 |
+
for i, row in enumerate(session_state.ROOM.BOARD.board_map):
|
456 |
+
# print("row:", row)
|
457 |
+
for j, cell in enumerate(row):
|
458 |
+
if (
|
459 |
+
i * _BOARD_SIZE + j
|
460 |
+
in (session_state.ROOM.COORDINATE_1D)
|
461 |
+
):
|
462 |
+
if i == gpt_i and j == gpt_j:
|
463 |
+
BOARD_PLATE[i][j].button(
|
464 |
+
_PLAYER_SYMBOL[_NEW],
|
465 |
+
key=f"{i}:{j}",
|
466 |
+
args=(i, j),
|
467 |
+
on_click=handle_click,
|
468 |
+
)
|
469 |
+
else:
|
470 |
+
# disable click for GPT choices
|
471 |
+
BOARD_PLATE[i][j].button(
|
472 |
+
_PLAYER_SYMBOL[cell],
|
473 |
+
key=f"{i}:{j}",
|
474 |
+
args=(i, j),
|
475 |
+
on_click=forbid_click
|
476 |
+
)
|
477 |
+
else:
|
478 |
+
if session_state.USE_AIAID and i * _BOARD_SIZE + j in top_five_acts and not session_state.ROOM.BOARD.game_end()[0]:
|
479 |
+
# enable click for other cells available for human choices
|
480 |
+
prob = top_five_probs[top_five_acts.index(i * _BOARD_SIZE + j)]
|
481 |
+
BOARD_PLATE[i][j].button(
|
482 |
+
_PLAYER_SYMBOL[cell] + f"({round(prob, 2)})",
|
483 |
+
key=f"{i}:{j}",
|
484 |
+
on_click=handle_click,
|
485 |
+
args=(i, j),
|
486 |
+
)
|
487 |
+
else:
|
488 |
+
# enable click for other cells available for human choices
|
489 |
+
BOARD_PLATE[i][j].button(
|
490 |
+
_PLAYER_SYMBOL[cell],
|
491 |
+
key=f"{i}:{j}",
|
492 |
+
on_click=handle_click,
|
493 |
+
args=(i, j),
|
494 |
+
)
|
495 |
+
|
496 |
+
|
497 |
+
message.markdown(
|
498 |
+
'AI agent has calculated its strategy, which takes <span style="color: blue; font-size: 20px;">{:.3e}</span>s per simulation.'.format(
|
499 |
+
simul_time),
|
500 |
+
unsafe_allow_html=True
|
501 |
+
)
|
502 |
+
LOG.subheader("Logs")
|
503 |
+
# change turn
|
504 |
+
session_state.ROOM.TURN = change_turn(session_state.ROOM.TURN)
|
505 |
+
# session_state.ROOM.WINNER = check_win()
|
506 |
+
|
507 |
+
win, winner = session_state.ROOM.BOARD.game_end()
|
508 |
+
if win:
|
509 |
+
session_state.ROOM.WINNER = winner
|
510 |
+
|
511 |
+
session_state.ROOM.HISTORY = (
|
512 |
+
session_state.ROOM.HISTORY[0]
|
513 |
+
+ int(session_state.ROOM.WINNER == _WHITE),
|
514 |
+
session_state.ROOM.HISTORY[1]
|
515 |
+
+ int(session_state.ROOM.WINNER == _BLACK),
|
516 |
+
)
|
517 |
+
session_state.ROOM.TIME = time.time()
|
518 |
+
|
519 |
+
if not response or session_state.ROOM.WINNER != _BLANK:
|
520 |
+
if session_state.ROOM.WINNER != _BLANK:
|
521 |
+
print("Game over")
|
522 |
+
for i, row in enumerate(session_state.ROOM.BOARD.board_map):
|
523 |
+
for j, cell in enumerate(row):
|
524 |
+
BOARD_PLATE[i][j].write(
|
525 |
+
_PLAYER_SYMBOL[cell],
|
526 |
+
# key=f"{i}:{j}",
|
527 |
+
)
|
528 |
+
|
529 |
+
# Game process control
|
530 |
+
def game_control():
|
531 |
+
if session_state.ROOM.WINNER != _BLANK:
|
532 |
+
draw_board(False)
|
533 |
+
else:
|
534 |
+
draw_board(True)
|
535 |
+
if session_state.ROOM.WINNER != _BLANK or 0 not in session_state.ROOM.BOARD.board_map:
|
536 |
+
ANOTHER_ROUND.button(
|
537 |
+
"Play Next round!",
|
538 |
+
on_click=another_round,
|
539 |
+
help="Clear board and swap first player",
|
540 |
+
)
|
541 |
+
|
542 |
+
# Infos
|
543 |
+
def update_info() -> None:
|
544 |
+
# Additional information
|
545 |
+
SCORE_PLATE[0].metric("Gomoku-Agent", session_state.ROOM.HISTORY[0])
|
546 |
+
SCORE_PLATE[1].metric("Black", session_state.ROOM.HISTORY[1])
|
547 |
+
if session_state.ROOM.WINNER != _BLANK:
|
548 |
+
st.balloons()
|
549 |
+
ROUND_INFO.write(
|
550 |
+
f"#### **{_PLAYER_COLOR[session_state.ROOM.WINNER]} WIN!**\n**Click buttons on the left for more plays.**"
|
551 |
+
)
|
552 |
+
|
553 |
+
# elif 0 not in session_state.ROOM.BOARD.board_map:
|
554 |
+
# ROUND_INFO.write("#### **Tie**")
|
555 |
+
# else:
|
556 |
+
# ROUND_INFO.write(
|
557 |
+
# f"#### **{_PLAYER_SYMBOL[session_state.ROOM.TURN]} {_PLAYER_COLOR[session_state.ROOM.TURN]}'s turn...**"
|
558 |
+
# )
|
559 |
+
|
560 |
+
# draw the plot for simulation time
|
561 |
+
# 创建一个 DataFrame
|
562 |
+
|
563 |
+
# print(session_state.ROOM.simula_time_list)
|
564 |
+
st.markdown("<br>", unsafe_allow_html=True)
|
565 |
+
st.markdown("<br>", unsafe_allow_html=True)
|
566 |
+
chart_data = pd.DataFrame(session_state.ROOM.simula_time_list, columns=["Simulation Time"])
|
567 |
+
st.line_chart(chart_data)
|
568 |
+
|
569 |
+
|
570 |
+
game_control()
|
571 |
+
update_info()
|
572 |
+
|
573 |
+
|
574 |
+
if __name__ == "__main__":
|
575 |
+
gomoku()
|
pages/Player_VS_AI.py
CHANGED
@@ -8,6 +8,7 @@ Description: this file is used to display our project and add visualization elem
|
|
8 |
import time
|
9 |
import pandas as pd
|
10 |
from copy import deepcopy
|
|
|
11 |
|
12 |
# import torch
|
13 |
import numpy as np
|
@@ -16,8 +17,12 @@ from scipy.signal import convolve # this is used to check if any player wins
|
|
16 |
from streamlit import session_state
|
17 |
from streamlit_server_state import server_state, server_state_lock
|
18 |
from Gomoku_MCTS import MCTSpure, alphazero, Board, PolicyValueNet
|
|
|
|
|
19 |
import matplotlib.pyplot as plt
|
20 |
|
|
|
|
|
21 |
from const import (
|
22 |
_BLACK, # 1, for human
|
23 |
_WHITE, # 2 , for AI
|
@@ -26,6 +31,7 @@ from const import (
|
|
26 |
_PLAYER_SYMBOL,
|
27 |
_ROOM_COLOR,
|
28 |
_VERTICAL,
|
|
|
29 |
_HORIZONTAL,
|
30 |
_DIAGONAL_UP_LEFT,
|
31 |
_DIAGONAL_UP_RIGHT,
|
@@ -35,6 +41,13 @@ from const import (
|
|
35 |
)
|
36 |
|
37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
|
40 |
# Utils
|
@@ -48,11 +61,14 @@ class Room:
|
|
48 |
self.HISTORY = (0, 0)
|
49 |
self.WINNER = _BLANK
|
50 |
self.TIME = time.time()
|
51 |
-
self.
|
52 |
-
|
|
|
|
|
53 |
self.MCTS = self.MCTS_dict['AlphaZero']
|
|
|
54 |
self.AID_MCTS = self.MCTS_dict['AlphaZero']
|
55 |
-
self.COORDINATE_1D = [
|
56 |
self.current_move = -1
|
57 |
self.simula_time_list = []
|
58 |
|
@@ -75,10 +91,16 @@ if "ROOMS" not in server_state:
|
|
75 |
server_state.ROOMS = {}
|
76 |
|
77 |
def handle_oppo_model_selection():
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
return
|
83 |
|
84 |
def handle_aid_model_selection():
|
@@ -103,7 +125,7 @@ TITLE = st.empty()
|
|
103 |
Model_Switch = st.empty()
|
104 |
|
105 |
TITLE.header("🤖 AI 3603 Gomoku")
|
106 |
-
selected_oppo_option = Model_Switch.selectbox('Select Opponent Model', ['Pure MCTS', 'AlphaZero'], index=1, key='oppo_model')
|
107 |
|
108 |
if st.session_state['selected_oppo_model'] != selected_oppo_option:
|
109 |
st.session_state['selected_oppo_model'] = selected_oppo_option
|
@@ -158,7 +180,7 @@ def restart() -> None:
|
|
158 |
Restart the game.
|
159 |
"""
|
160 |
session_state.ROOM = Room(session_state.ROOM.ROOM_ID)
|
161 |
-
|
162 |
|
163 |
RESTART.button(
|
164 |
"Reset",
|
@@ -183,10 +205,16 @@ def gomoku():
|
|
183 |
"""
|
184 |
session_state.ROOM = deepcopy(session_state.ROOM)
|
185 |
session_state.ROOM.BOARD = Board(width=_BOARD_SIZE, height=_BOARD_SIZE, n_in_row=5)
|
|
|
|
|
|
|
|
|
|
|
|
|
186 |
session_state.ROOM.PLAYER = session_state.ROOM.PLAYER
|
187 |
session_state.ROOM.TURN = session_state.ROOM.PLAYER
|
188 |
session_state.ROOM.WINNER = _BLANK # 0
|
189 |
-
session_state.ROOM.COORDINATE_1D = [
|
190 |
|
191 |
# Room status sync
|
192 |
def sync_room() -> bool:
|
@@ -281,6 +309,8 @@ def gomoku():
|
|
281 |
move = session_state.ROOM.BOARD.location_to_move((x, y))
|
282 |
session_state.ROOM.current_move = move
|
283 |
session_state.ROOM.BOARD.do_move(move)
|
|
|
|
|
284 |
session_state.ROOM.BOARD.board_map[x][y] = session_state.ROOM.TURN
|
285 |
session_state.ROOM.COORDINATE_1D.append(x * _BOARD_SIZE + y)
|
286 |
|
@@ -299,7 +329,6 @@ def gomoku():
|
|
299 |
def forbid_click(x, y):
|
300 |
# st.warning('This posistion has been occupied!!!!', icon="⚠️")
|
301 |
st.error("({}, {}) has been occupied!!)".format(x, y), icon="🚨")
|
302 |
-
print("asdas")
|
303 |
|
304 |
# Draw board
|
305 |
def draw_board(response: bool):
|
@@ -314,6 +343,7 @@ def gomoku():
|
|
314 |
if response and session_state.ROOM.TURN == _BLACK: # human turn
|
315 |
print("Your turn")
|
316 |
# construction of clickable buttons
|
|
|
317 |
for i, row in enumerate(session_state.ROOM.BOARD.board_map):
|
318 |
# print("row:", row)
|
319 |
for j, cell in enumerate(row):
|
@@ -321,13 +351,21 @@ def gomoku():
|
|
321 |
i * _BOARD_SIZE + j
|
322 |
in (session_state.ROOM.COORDINATE_1D)
|
323 |
):
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
331 |
else:
|
332 |
if session_state.USE_AIAID and i * _BOARD_SIZE + j in top_five_acts:
|
333 |
# enable click for other cells available for human choices
|
@@ -355,7 +393,37 @@ def gomoku():
|
|
355 |
print("AI's turn")
|
356 |
print("Below are current board under AI's view")
|
357 |
# print(session_state.ROOM.BOARD.board_map)
|
358 |
-
move
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
359 |
session_state.ROOM.simula_time_list.append(simul_time)
|
360 |
print("AI takes move: ", move)
|
361 |
session_state.ROOM.current_move = move
|
@@ -364,7 +432,11 @@ def gomoku():
|
|
364 |
print("AI's move is located at ({}, {}) :".format(gpt_i, gpt_j))
|
365 |
move = session_state.ROOM.BOARD.location_to_move((gpt_i, gpt_j))
|
366 |
print("Location to move: ", move)
|
|
|
|
|
367 |
session_state.ROOM.BOARD.do_move(move)
|
|
|
|
|
368 |
# session_state.ROOM.BOARD[gpt_i][gpt_j] = session_state.ROOM.TURN
|
369 |
session_state.ROOM.COORDINATE_1D.append(gpt_i * _BOARD_SIZE + gpt_j)
|
370 |
|
@@ -387,13 +459,21 @@ def gomoku():
|
|
387 |
i * _BOARD_SIZE + j
|
388 |
in (session_state.ROOM.COORDINATE_1D)
|
389 |
):
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
397 |
else:
|
398 |
if session_state.USE_AIAID and i * _BOARD_SIZE + j in top_five_acts and not session_state.ROOM.BOARD.game_end()[0]:
|
399 |
# enable click for other cells available for human choices
|
@@ -480,7 +560,7 @@ def gomoku():
|
|
480 |
# draw the plot for simulation time
|
481 |
# 创建一个 DataFrame
|
482 |
|
483 |
-
print(session_state.ROOM.simula_time_list)
|
484 |
st.markdown("<br>", unsafe_allow_html=True)
|
485 |
st.markdown("<br>", unsafe_allow_html=True)
|
486 |
chart_data = pd.DataFrame(session_state.ROOM.simula_time_list, columns=["Simulation Time"])
|
|
|
8 |
import time
|
9 |
import pandas as pd
|
10 |
from copy import deepcopy
|
11 |
+
import torch
|
12 |
|
13 |
# import torch
|
14 |
import numpy as np
|
|
|
17 |
from streamlit import session_state
|
18 |
from streamlit_server_state import server_state, server_state_lock
|
19 |
from Gomoku_MCTS import MCTSpure, alphazero, Board, PolicyValueNet
|
20 |
+
from Gomoku_Bot import Gomoku_bot
|
21 |
+
from Gomoku_Bot import Board as Gomoku_bot_board
|
22 |
import matplotlib.pyplot as plt
|
23 |
|
24 |
+
|
25 |
+
|
26 |
from const import (
|
27 |
_BLACK, # 1, for human
|
28 |
_WHITE, # 2 , for AI
|
|
|
31 |
_PLAYER_SYMBOL,
|
32 |
_ROOM_COLOR,
|
33 |
_VERTICAL,
|
34 |
+
_NEW,
|
35 |
_HORIZONTAL,
|
36 |
_DIAGONAL_UP_LEFT,
|
37 |
_DIAGONAL_UP_RIGHT,
|
|
|
41 |
)
|
42 |
|
43 |
|
44 |
+
from ai import (
|
45 |
+
BOS_TOKEN_ID,
|
46 |
+
generate_gpt2,
|
47 |
+
load_model,
|
48 |
+
)
|
49 |
+
|
50 |
+
gpt2 = load_model()
|
51 |
|
52 |
|
53 |
# Utils
|
|
|
61 |
self.HISTORY = (0, 0)
|
62 |
self.WINNER = _BLANK
|
63 |
self.TIME = time.time()
|
64 |
+
self.gomoku_bot_board = Gomoku_bot_board(_BOARD_SIZE, 1)
|
65 |
+
self.MCTS_dict = {'Pure MCTS': MCTSpure(c_puct=5, n_playout=1000),
|
66 |
+
'AlphaZero': alphazero(PolicyValueNet(_BOARD_SIZE, _BOARD_SIZE, 'Gomoku_MCTS/checkpoints/best_policy_8_8_5_2torch.pth').policy_value_fn, c_puct=5, n_playout=100),
|
67 |
+
'Gomoku Bot': Gomoku_bot(self.gomoku_bot_board, -1)}
|
68 |
self.MCTS = self.MCTS_dict['AlphaZero']
|
69 |
+
self.last_mcts = self.MCTS
|
70 |
self.AID_MCTS = self.MCTS_dict['AlphaZero']
|
71 |
+
self.COORDINATE_1D = [BOS_TOKEN_ID]
|
72 |
self.current_move = -1
|
73 |
self.simula_time_list = []
|
74 |
|
|
|
91 |
server_state.ROOMS = {}
|
92 |
|
93 |
def handle_oppo_model_selection():
|
94 |
+
if st.session_state['selected_oppo_model'] == 'Gomoku Bot':
|
95 |
+
session_state.ROOM.last_mcts = session_state.ROOM.MCTS # since use different mechanism, store previous mcts first
|
96 |
+
session_state.ROOM.MCTS = session_state.ROOM.MCTS_dict['Gomoku Bot']
|
97 |
+
return
|
98 |
+
else:
|
99 |
+
TreeNode = session_state.ROOM.last_mcts.mcts._root
|
100 |
+
new_mct = session_state.ROOM.MCTS_dict[st.session_state['selected_oppo_model']]
|
101 |
+
new_mct.mcts._root = deepcopy(TreeNode)
|
102 |
+
session_state.ROOM.MCTS = new_mct
|
103 |
+
session_state.ROOM.last_mcts = new_mct
|
104 |
return
|
105 |
|
106 |
def handle_aid_model_selection():
|
|
|
125 |
Model_Switch = st.empty()
|
126 |
|
127 |
TITLE.header("🤖 AI 3603 Gomoku")
|
128 |
+
selected_oppo_option = Model_Switch.selectbox('Select Opponent Model', ['Pure MCTS', 'AlphaZero','Gomoku Bot'], index=1, key='oppo_model')
|
129 |
|
130 |
if st.session_state['selected_oppo_model'] != selected_oppo_option:
|
131 |
st.session_state['selected_oppo_model'] = selected_oppo_option
|
|
|
180 |
Restart the game.
|
181 |
"""
|
182 |
session_state.ROOM = Room(session_state.ROOM.ROOM_ID)
|
183 |
+
st.session_state['selected_oppo_model'] = 'AlphaZero'
|
184 |
|
185 |
RESTART.button(
|
186 |
"Reset",
|
|
|
205 |
"""
|
206 |
session_state.ROOM = deepcopy(session_state.ROOM)
|
207 |
session_state.ROOM.BOARD = Board(width=_BOARD_SIZE, height=_BOARD_SIZE, n_in_row=5)
|
208 |
+
session_state.ROOM.gomoku_bot_board = Gomoku_bot_board(_BOARD_SIZE, 1)
|
209 |
+
session_state.ROOM.MCTS_dict = {'Pure MCTS': MCTSpure(c_puct=5, n_playout=1000),
|
210 |
+
'AlphaZero': alphazero(PolicyValueNet(_BOARD_SIZE, _BOARD_SIZE, 'Gomoku_MCTS/checkpoints/best_policy_8_8_5_2torch.pth').policy_value_fn, c_puct=5, n_playout=100),
|
211 |
+
'Gomoku Bot': Gomoku_bot(session_state.ROOM.gomoku_bot_board, -1)}
|
212 |
+
session_state.ROOM.MCTS = session_state.ROOM.MCTS_dict[st.session_state['selected_oppo_model']]
|
213 |
+
session_state.ROOM.last_mcts = session_state.ROOM.MCTS
|
214 |
session_state.ROOM.PLAYER = session_state.ROOM.PLAYER
|
215 |
session_state.ROOM.TURN = session_state.ROOM.PLAYER
|
216 |
session_state.ROOM.WINNER = _BLANK # 0
|
217 |
+
session_state.ROOM.COORDINATE_1D = [BOS_TOKEN_ID]
|
218 |
|
219 |
# Room status sync
|
220 |
def sync_room() -> bool:
|
|
|
309 |
move = session_state.ROOM.BOARD.location_to_move((x, y))
|
310 |
session_state.ROOM.current_move = move
|
311 |
session_state.ROOM.BOARD.do_move(move)
|
312 |
+
# Gomoku Bot BOARD
|
313 |
+
session_state.ROOM.MCTS_dict["Gomoku Bot"].board.put(move // _BOARD_SIZE, move % _BOARD_SIZE)
|
314 |
session_state.ROOM.BOARD.board_map[x][y] = session_state.ROOM.TURN
|
315 |
session_state.ROOM.COORDINATE_1D.append(x * _BOARD_SIZE + y)
|
316 |
|
|
|
329 |
def forbid_click(x, y):
|
330 |
# st.warning('This posistion has been occupied!!!!', icon="⚠️")
|
331 |
st.error("({}, {}) has been occupied!!)".format(x, y), icon="🚨")
|
|
|
332 |
|
333 |
# Draw board
|
334 |
def draw_board(response: bool):
|
|
|
343 |
if response and session_state.ROOM.TURN == _BLACK: # human turn
|
344 |
print("Your turn")
|
345 |
# construction of clickable buttons
|
346 |
+
cur_move = (session_state.ROOM.current_move // _BOARD_SIZE, session_state.ROOM.current_move % _BOARD_SIZE)
|
347 |
for i, row in enumerate(session_state.ROOM.BOARD.board_map):
|
348 |
# print("row:", row)
|
349 |
for j, cell in enumerate(row):
|
|
|
351 |
i * _BOARD_SIZE + j
|
352 |
in (session_state.ROOM.COORDINATE_1D)
|
353 |
):
|
354 |
+
if i == cur_move[0] and j == cur_move[1]:
|
355 |
+
BOARD_PLATE[i][j].button(
|
356 |
+
_PLAYER_SYMBOL[_NEW],
|
357 |
+
key=f"{i}:{j}",
|
358 |
+
args=(i, j),
|
359 |
+
on_click=handle_click,
|
360 |
+
)
|
361 |
+
else:
|
362 |
+
# disable click for GPT choices
|
363 |
+
BOARD_PLATE[i][j].button(
|
364 |
+
_PLAYER_SYMBOL[cell],
|
365 |
+
key=f"{i}:{j}",
|
366 |
+
args=(i, j),
|
367 |
+
on_click=forbid_click
|
368 |
+
)
|
369 |
else:
|
370 |
if session_state.USE_AIAID and i * _BOARD_SIZE + j in top_five_acts:
|
371 |
# enable click for other cells available for human choices
|
|
|
393 |
print("AI's turn")
|
394 |
print("Below are current board under AI's view")
|
395 |
# print(session_state.ROOM.BOARD.board_map)
|
396 |
+
# move = _BOARD_SIZE * _BOARD_SIZE
|
397 |
+
# forbid = []
|
398 |
+
# step = 0.1
|
399 |
+
# tmp = 0.7
|
400 |
+
# while move >= _BOARD_SIZE * _BOARD_SIZE or move in session_state.ROOM.COORDINATE_1D:
|
401 |
+
#
|
402 |
+
# gpt_predictions = generate_gpt2(
|
403 |
+
# gpt2,
|
404 |
+
# torch.tensor(session_state.ROOM.COORDINATE_1D).unsqueeze(0),
|
405 |
+
# tmp
|
406 |
+
# )
|
407 |
+
# print(gpt_predictions)
|
408 |
+
# move = gpt_predictions[len(session_state.ROOM.COORDINATE_1D)]
|
409 |
+
# print(move)
|
410 |
+
# tmp += step
|
411 |
+
# # if move >= _BOARD_SIZE * _BOARD_SIZE:
|
412 |
+
# # forbid.append(move)
|
413 |
+
# # else:
|
414 |
+
# # break
|
415 |
+
#
|
416 |
+
#
|
417 |
+
# gpt_response = move
|
418 |
+
# gpt_i, gpt_j = gpt_response // _BOARD_SIZE, gpt_response % _BOARD_SIZE
|
419 |
+
# print(gpt_i, gpt_j)
|
420 |
+
# # session_state.ROOM.BOARD[gpt_i][gpt_j] = session_state.ROOM.TURN
|
421 |
+
#
|
422 |
+
# simul_time = 0
|
423 |
+
if st.session_state['selected_oppo_model'] != 'Gomoku Bot':
|
424 |
+
move, simul_time = session_state.ROOM.MCTS.get_action(session_state.ROOM.BOARD, return_time=True)
|
425 |
+
else:
|
426 |
+
move, simul_time = session_state.ROOM.MCTS.get_action(return_time=True)
|
427 |
session_state.ROOM.simula_time_list.append(simul_time)
|
428 |
print("AI takes move: ", move)
|
429 |
session_state.ROOM.current_move = move
|
|
|
432 |
print("AI's move is located at ({}, {}) :".format(gpt_i, gpt_j))
|
433 |
move = session_state.ROOM.BOARD.location_to_move((gpt_i, gpt_j))
|
434 |
print("Location to move: ", move)
|
435 |
+
# print("Location to move: ", move)
|
436 |
+
# MCTS BOARD
|
437 |
session_state.ROOM.BOARD.do_move(move)
|
438 |
+
# Gomoku Bot BOARD
|
439 |
+
session_state.ROOM.MCTS_dict["Gomoku Bot"].board.put(move // _BOARD_SIZE, move % _BOARD_SIZE)
|
440 |
# session_state.ROOM.BOARD[gpt_i][gpt_j] = session_state.ROOM.TURN
|
441 |
session_state.ROOM.COORDINATE_1D.append(gpt_i * _BOARD_SIZE + gpt_j)
|
442 |
|
|
|
459 |
i * _BOARD_SIZE + j
|
460 |
in (session_state.ROOM.COORDINATE_1D)
|
461 |
):
|
462 |
+
if i == gpt_i and j == gpt_j:
|
463 |
+
BOARD_PLATE[i][j].button(
|
464 |
+
_PLAYER_SYMBOL[_NEW],
|
465 |
+
key=f"{i}:{j}",
|
466 |
+
args=(i, j),
|
467 |
+
on_click=handle_click,
|
468 |
+
)
|
469 |
+
else:
|
470 |
+
# disable click for GPT choices
|
471 |
+
BOARD_PLATE[i][j].button(
|
472 |
+
_PLAYER_SYMBOL[cell],
|
473 |
+
key=f"{i}:{j}",
|
474 |
+
args=(i, j),
|
475 |
+
on_click=forbid_click
|
476 |
+
)
|
477 |
else:
|
478 |
if session_state.USE_AIAID and i * _BOARD_SIZE + j in top_five_acts and not session_state.ROOM.BOARD.game_end()[0]:
|
479 |
# enable click for other cells available for human choices
|
|
|
560 |
# draw the plot for simulation time
|
561 |
# 创建一个 DataFrame
|
562 |
|
563 |
+
# print(session_state.ROOM.simula_time_list)
|
564 |
st.markdown("<br>", unsafe_allow_html=True)
|
565 |
st.markdown("<br>", unsafe_allow_html=True)
|
566 |
chart_data = pd.DataFrame(session_state.ROOM.simula_time_list, columns=["Simulation Time"])
|
try.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
2 |
+
|
3 |
+
model_name_or_path = "TheBloke/Llama-2-7b-Chat-GPTQ"
|
4 |
+
# To use a different branch, change revision
|
5 |
+
# For example: revision="gptq-4bit-64g-actorder_True"
|
6 |
+
model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
|
7 |
+
device_map="auto",
|
8 |
+
trust_remote_code=False,
|
9 |
+
revision="main")
|
10 |
+
|
11 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
|
12 |
+
|
13 |
+
prompt = "Tell me about AI"
|
14 |
+
prompt_template=f'''[INST] <<SYS>>
|
15 |
+
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
|
16 |
+
<</SYS>>
|
17 |
+
{prompt}[/INST]
|
18 |
+
|
19 |
+
'''
|
20 |
+
|
21 |
+
print("\n\n*** Generate:")
|
22 |
+
|
23 |
+
input_ids = tokenizer(prompt_template, return_tensors='pt').input_ids.cuda()
|
24 |
+
output = model.generate(inputs=input_ids, temperature=0.7, do_sample=True, top_p=0.95, top_k=40, max_new_tokens=512)
|
25 |
+
print(tokenizer.decode(output[0]))
|
26 |
+
|
27 |
+
# Inference can also be done using transformers' pipeline
|
28 |
+
|
29 |
+
print("*** Pipeline:")
|
30 |
+
pipe = pipeline(
|
31 |
+
"text-generation",
|
32 |
+
model=model,
|
33 |
+
tokenizer=tokenizer,
|
34 |
+
max_new_tokens=512,
|
35 |
+
do_sample=True,
|
36 |
+
temperature=0.7,
|
37 |
+
top_p=0.95,
|
38 |
+
top_k=40,
|
39 |
+
repetition_penalty=1.1
|
40 |
+
)
|
41 |
+
|
42 |
+
print(pipe(prompt_template)[0]['generated_text'])
|