HuskyDoge commited on
Commit
7d23b62
1 Parent(s): e7a440c

added gomokubot

Browse files
Gomoku_Bot/HumanVSAI.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
5
+ from .board import Board
6
+ from .minmax import vct, cache_hits, minmax
7
+ from .eval import FIVE, FOUR, performance
8
+
9
+
10
+ class Game():
11
+ def __init__(self, firstRole=1):
12
+ self.board = Board(8, firstRole)
13
+ self.steps = []
14
+ self.step = 0
15
+ self.enableVCT = True # 是否开启算杀, 算杀会在某些leaf节点加深搜索, 但是不一定会增加搜索时间
16
+
17
+ def human_input(self):
18
+ x, y = map(int, input('Your move: ').split())
19
+ return x, y
20
+
21
+ def start_play(self, human_first=False):
22
+ if not human_first:
23
+ while not self.board.isGameOver():
24
+ print(self.board.display())
25
+ if self.step % 2 == 1:
26
+ x, y = self.human_input()
27
+ while not self.board.put(x, y):
28
+ x, y = self.human_input()
29
+ else:
30
+ score = minmax(self.board, 1, 4, enableVCT=self.enableVCT)
31
+ print(score)
32
+ x, y = score[1]
33
+ print("move at", x, y)
34
+ self.board.put(x, y)
35
+ self.step += 1
36
+ else:
37
+ while not self.board.isGameOver():
38
+ print(self.board.display())
39
+ if self.step % 2 == 0:
40
+ x, y = self.human_input()
41
+ while not self.board.put(x, y):
42
+ x, y = self.human_input()
43
+ else:
44
+ score = minmax(self.board, -1, 4, enableVCT=self.enableVCT)
45
+ print(score)
46
+ x, y = score[1]
47
+ self.board.put(x, y)
48
+ self.step += 1
49
+ print(self.board.display())
50
+
51
+
52
+ if __name__ == '__main__':
53
+ game = Game()
54
+ game.start_play(True)
Gomoku_Bot/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .board import Board
2
+ from .gomoku_bot import Gomoku_bot
Gomoku_Bot/board.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .zobrist import ZobristCache as Zobrist
2
+ from .cache import Cache
3
+ from .eval import Evaluate, FIVE
4
+ from scipy import signal
5
+ import pickle
6
+ import os
7
+ save_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'train_data/data', 'train_data.pkl')
8
+
9
+ if 'numpy' not in globals():
10
+ import numpy as np
11
+
12
+
13
+ class Board:
14
+ def __init__(self, size=15, firstRole=1):
15
+ self.size = size
16
+ self.board = [[0] * self.size for _ in range(self.size)]
17
+ self.firstRole = firstRole # 1 for black, -1 for white
18
+ self.role = firstRole # 1 for black, -1 for white
19
+ self.history = []
20
+ self.zobrist = Zobrist(self.size)
21
+ self.winnerCache = Cache()
22
+ self.gameoverCache = Cache()
23
+ self.evaluateCache = Cache()
24
+ self.valuableMovesCache = Cache()
25
+ self.evaluateTime = 0
26
+ self.evaluator = Evaluate(self.size)
27
+ self.available = [(i, j) for i in range(self.size) for j in range(self.size)]
28
+ self.patterns = [np.ones((1, 5)), np.ones((5, 1)), np.eye(5), np.fliplr(np.eye(5))]
29
+ self.train_data = {1:[], -1: []}
30
+ if os.path.exists(save_path):
31
+ with open(save_path, 'rb') as f:
32
+ self.train_data = pickle.load(f)
33
+
34
+ def isGameOver(self):
35
+ # Checked
36
+ hash = self.hash()
37
+ if self.gameoverCache.get(hash):
38
+ return self.gameoverCache.get(hash)
39
+ if self.getWinner() != 0:
40
+ self.gameoverCache.put(hash, True)
41
+ # save train data
42
+ # with open(save_path, 'wb') as f:
43
+ # pickle.dump(self.train_data, f)
44
+ return True # Someone has won
45
+ # Game is over when there is no empty space on the board or someone has won
46
+ if len(self.history) == self.size ** 2:
47
+ self.gameoverCache.put(hash, True)
48
+ return True
49
+ else:
50
+ self.gameoverCache.put(hash, False)
51
+ return False
52
+
53
+ def getWinner(self):
54
+ # Checked
55
+ hash = self.hash()
56
+ flag = True
57
+ if self.winnerCache.get(hash):
58
+ return self.winnerCache.get(hash)
59
+ directions = [[1, 0], [0, 1], [1, 1], [1, -1]] # Horizontal, Vertical, Diagonal
60
+ for i in range(self.size):
61
+ for j in range(self.size):
62
+ if self.board[i][j] == 0:
63
+ flag = False
64
+ continue
65
+ for direction in directions:
66
+ count = 0
67
+ while (
68
+ 0 <= i + direction[0] * count < self.size and
69
+ 0 <= j + direction[1] * count < self.size and
70
+ self.board[i + direction[0] * count][j + direction[1] * count] == self.board[i][j]
71
+ ):
72
+ count += 1
73
+ if count >= 5:
74
+ self.winnerCache.put(hash, self.board[i][j])
75
+ return self.board[i][j]
76
+ if flag:
77
+ print("tie!!!")
78
+ return 0
79
+ self.winnerCache.put(hash, 0)
80
+ return 0
81
+
82
+ def getValidMoves(self):
83
+ return self.available
84
+
85
+ def put(self, i, j, role=None):
86
+ # Checked
87
+ if role is None:
88
+ role = self.role
89
+ if not isinstance(i, int) or not isinstance(j, int):
90
+ print("Invalid move: Not Number!", i, j)
91
+ return False
92
+ if self.board[i][j] != 0:
93
+ print("Invalid move!", i, j)
94
+ return False
95
+ self.board[i][j] = role
96
+ self.available.remove((i, j))
97
+ self.history.append({"i": i, "j": j, "role": role})
98
+ self.zobrist.togglePiece(i, j, role)
99
+ self.evaluator.move(i, j, role)
100
+ self.role *= -1 # Switch role
101
+ return True
102
+
103
+ def undo(self):
104
+ # Checked
105
+ if len(self.history) == 0:
106
+ print("No moves to undo!")
107
+ return False
108
+
109
+ lastMove = self.history.pop()
110
+ self.board[lastMove['i']][lastMove['j']] = 0 # Remove the piece from the board
111
+ self.role = lastMove['role'] # Switch back to the previous player
112
+ self.zobrist.togglePiece(lastMove['i'], lastMove['j'], lastMove['role'])
113
+ self.evaluator.undo(lastMove['i'], lastMove['j'])
114
+ self.available.append((lastMove['i'], lastMove['j']))
115
+ return True
116
+
117
+ def position2coordinate(self, position):
118
+ # checked
119
+ row = position // self.size
120
+ col = position % self.size
121
+ return [row, col]
122
+
123
+ def coordinate2position(self, coordinate):
124
+ # Checked
125
+ return coordinate[0] * self.size + coordinate[1]
126
+
127
+ def getValuableMoves(self, role, depth=0, onlyThree=False, onlyFour=False):
128
+ # Checked
129
+ hash = self.hash()
130
+ prev = self.valuableMovesCache.get(hash)
131
+ if prev:
132
+ if (prev["role"] == role and
133
+ prev["depth"] == depth and
134
+ prev["onlyThree"] == onlyThree
135
+ and prev["onlyFour"] == onlyFour):
136
+ return prev["moves"]
137
+
138
+ moves, train_data = self.evaluator.getMoves(role, depth, onlyThree, onlyFour)
139
+ self.train_data[self.role].append(train_data)
140
+ # Handle a special case, if the center point is not occupied, add it by default
141
+
142
+ # 开局的时候随机走一步,增加开局的多样性
143
+ if not onlyThree and not onlyFour:
144
+ center = self.size // 2
145
+ if self.board[center][center] == 0:
146
+ moves.append((center, center))
147
+
148
+ # x_step = np.random.randint(-self.size // 2, self.size // 2)
149
+ # y_step = np.random.randint(-self.size // 2, self.size // 2)
150
+ # x = center + x_step
151
+ # y = center + y_step
152
+ # if 0 <= x < self.size and 0 <= y < self.size and self.board[x][y] == 0:
153
+ # moves.append((x, y))
154
+
155
+ self.valuableMovesCache.put(hash, {
156
+ "role": role,
157
+ "moves": moves,
158
+ "depth": depth,
159
+ "onlyThree": onlyThree,
160
+ "onlyFour": onlyFour
161
+ })
162
+ return moves
163
+
164
+ def display(self, extraPoints=[]):
165
+ # Checked
166
+ extraPosition = [self.coordinate2position(point) for point in extraPoints]
167
+ result = ""
168
+ for i in range(self.size):
169
+ for j in range(self.size):
170
+ position = self.coordinate2position([i, j])
171
+ if position in extraPosition:
172
+ result += "? "
173
+ continue
174
+ value = self.board[i][j]
175
+ if value == 1:
176
+ result += "B " # Black
177
+ elif value == -1:
178
+ result += "W " # White
179
+ else:
180
+ result += "- "
181
+ result += "\n"
182
+ return result
183
+
184
+ def hash(self):
185
+ # Checked
186
+ return self.zobrist.getHash() # Return the hash value of the current board, used for caching
187
+
188
+ def evaluate(self, role):
189
+ # Checked
190
+ hash_key = self.hash()
191
+ prev = self.evaluateCache.get(hash_key)
192
+ if prev:
193
+ if prev["role"] == role:
194
+ return prev["score"]
195
+
196
+ winner = self.getWinner()
197
+ score = 0
198
+ if winner != 0:
199
+ score = FIVE * winner * role
200
+ else:
201
+ score = self.evaluator.evaluate(role)
202
+
203
+ self.evaluateCache.put(hash_key, {"role": role, "score": score})
204
+ return score
205
+
206
+ def reverse(self):
207
+ # Checked
208
+ new_board = Board(self.size, -self.firstRole)
209
+ for move in self.history:
210
+ x, y, role = move['i'], move['j'], move['role']
211
+ new_board.put(x, y, -role)
212
+ return new_board
213
+
214
+ def toString(self):
215
+ # Checked
216
+ return ''.join([''.join(map(str, row)) for row in self.board])
Gomoku_Bot/board_manuls.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ wins = [
2
+ # O O O O O
3
+ # X X X X -
4
+ # - - - - -
5
+ # - - - - -
6
+ # - - - - -
7
+ [5, [0, 5, 1, 6, 2, 7, 3, 8, 4], 1], # 横向五
8
+ # O O O O -
9
+ # X X X X X
10
+ # O - - - -
11
+ # - - - - -
12
+ # - - - - -
13
+ [5, [0, 5, 1, 6, 2, 7, 3, 8, 10, 9], -1], # 白子横向五
14
+ # O O - O O
15
+ # X X X X -
16
+ # O - - - -
17
+ # - - - - -
18
+ # - - - - -
19
+ [5, [0, 5, 1, 6, 10, 7, 3, 8, 4], 0], # 有一个空位
20
+ # O O O X O
21
+ # X X X X -
22
+ # O - - - -
23
+ # - - - - -
24
+ # - - - - -
25
+ [5, [0, 5, 1, 6, 2, 7, 10, 8, 4], 0], # 有一个白子
26
+ # O X X X X
27
+ # O - - - -
28
+ # O - - - -
29
+ # O - - - -
30
+ # O - - - -
31
+ [5, [0, 1, 5, 2, 10, 3, 15, 4, 20], 1], # 纵向五
32
+ # O X X X X
33
+ # O - - - -
34
+ # O - - - -
35
+ # - O - - -
36
+ # O - - - -
37
+ [5, [0, 1, 5, 2, 10, 3, 16, 4, 20], 0], # 纵向五有一个空位
38
+ # O X X X X
39
+ # O - - - -
40
+ # O - - - -
41
+ # X O - - -
42
+ # O - - - -
43
+ [5, [0, 1, 5, 2, 10, 3, 16, 4, 20, 15], 0], # 纵向五有一个白子
44
+ # O X X X X
45
+ # - O - - -
46
+ # - - O - -
47
+ # - - - O -
48
+ # - - - - O
49
+ [5, [0, 1, 6, 2, 12, 3, 18, 4, 24], 1], # 斜线五
50
+ # O X X X X
51
+ # - O - - -
52
+ # - - - O -
53
+ # - - - O -
54
+ # - - - - O
55
+ [5, [0, 1, 6, 2, 12, 3, 19, 4, 24], 0], # 斜线五有一个空的
56
+ # O X X X X
57
+ # - O - - -
58
+ # - - X O -
59
+ # - - - O -
60
+ # - - - - O
61
+ [5, [0, 1, 6, 2, 12, 3, 19, 4, 24, 18], 0], # 斜线五有一个白子
62
+ # X X X X O
63
+ # - - - O -
64
+ # - - O - -
65
+ # - O - - -
66
+ # O - - - -
67
+ [5, [4, 0, 8, 1, 12, 2, 16, 3, 20], 1], # 反斜线五
68
+ # X X X X O
69
+ # - - - O -
70
+ # - - O - -
71
+ # O - - - -
72
+ # O - - - -
73
+ [5, [4, 0, 8, 1, 12, 2, 15, 3, 20], 0], # 反斜线五 有一个空位
74
+ # X X X - O
75
+ # - - - O -
76
+ # - - O - -
77
+ # - X - - -
78
+ # O - - - -
79
+ [5, [4, 0, 8, 1, 12, 2, 16, 20], 0], # 反斜线五 有一个空位
80
+ ]
81
+
82
+ # valid moves
83
+ validMoves = [
84
+ # O - -
85
+ # - - -
86
+ # - - O
87
+ [3, [0, 8], [1, 2, 3, 4, 5, 6, 7]],
88
+ # O - - - -
89
+ # - - - - -
90
+ # - - - - -
91
+ # - - - - -
92
+ # - - - - -
93
+ [5, [0], [1, 2, 5, 6, 7, 8, 10, 11, 12, 13, 16, 17, 18]],
94
+ # - - - - -
95
+ # - - - - -
96
+ # - O - - -
97
+ # - - - - -
98
+ # - - - - -
99
+ [5, [11], [0, 1, 2, 3, 5, 6, 7, 8, 10, 12, 13, 15, 16, 17, 18, 20, 21, 22, 23]],
100
+
101
+ # - - - - - - - -
102
+ # - - - - - - - -
103
+ # - - - - - - - -
104
+ # - - O - X - - -
105
+ # - - - - - - - -
106
+ # - - - - - - - -
107
+ # - - - - - - - -
108
+ # - - - - - - - -
109
+ [8, [26, 28], [
110
+ 8, 9, 10, 11, 12, 13, 14,
111
+ 16, 17, 18, 19, 20, 21, 22,
112
+ 24, 25, 27, 29, 30,
113
+ 32, 33, 34, 35, 36, 37, 38,
114
+ 40, 41, 42, 43, 44, 45, 46,
115
+ ],
116
+ ],
117
+ ]
Gomoku_Bot/cache.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from cachetools import LRUCache
2
+ from .config import config
3
+
4
+ class Cache:
5
+ def __init__(self, capacity=1000000):
6
+ self.capacity = capacity
7
+ self.cache = LRUCache(maxsize=capacity)
8
+ self.enable_cache = config['enableCache']
9
+
10
+ def get(self, key):
11
+ if not self.enable_cache:
12
+ return None
13
+ return self.cache.get(key, None)
14
+
15
+ def put(self, key, value):
16
+ if not self.enable_cache:
17
+ return
18
+ self.cache[key] = value
19
+
20
+ def has(self, key):
21
+ if not self.enable_cache:
22
+ return False
23
+ return key in self.cache
Gomoku_Bot/config.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ config = {
2
+ "enableCache": True, # Whether to enable caching
3
+ "onlyInLine": False, # Whether to search only on a single line, an optimization option
4
+ "inlineCount": 4, # Number of recent points to consider for being on the same line
5
+ "inLineDistance": 5 # Maximum distance to determine if a point is on the same line
6
+ }
Gomoku_Bot/eval.py ADDED
@@ -0,0 +1,588 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+ import math
3
+ from .shape import shapes, get_shape_fast, is_five, is_four, get_all_shapes_of_point
4
+ from .position import coordinate2Position, isLine, isAllInLine, hasInLine, position2Coordinate
5
+ from .config import config
6
+ from datetime import datetime
7
+ import os
8
+
9
+ dir_path = os.path.dirname(os.path.abspath(__file__))
10
+ import numpy as np
11
+ import torch
12
+ # from .minimax_Net import BoardEvaluationNet as net
13
+
14
+ # mini_max_net = net(board_size=15)
15
+ # mini_max_net.load_state_dict(torch.load(os.path.join(dir_path, 'train_data/model', 'best_loss=609.3356355479785.pth')))
16
+ # mini_max_net.eval()
17
+
18
+
19
+ # Enum to represent different shapes
20
+ class Shapes(Enum):
21
+ FIVE = 0
22
+ BLOCK_FIVE = 1
23
+ FOUR = 2
24
+ FOUR_FOUR = 3
25
+ FOUR_THREE = 4
26
+ THREE_THREE = 5
27
+ BLOCK_FOUR = 6
28
+ THREE = 7
29
+ BLOCK_THREE = 8
30
+ TWO_TWO = 9
31
+ TWO = 10
32
+ NONE = 11
33
+
34
+
35
+ # Constants representing scores for each shape
36
+ FIVE = 10000000
37
+ BLOCK_FIVE = FIVE
38
+ FOUR = 100000
39
+ FOUR_FOUR = FOUR # 双冲四
40
+ FOUR_THREE = FOUR # 冲四活三
41
+ THREE_THREE = FOUR / 2 # 双活三
42
+ BLOCK_FOUR = 1500
43
+ THREE = 1000
44
+ BLOCK_THREE = 150
45
+ TWO_TWO = 200 # 双活二
46
+ TWO = 100
47
+ BLOCK_TWO = 15
48
+ ONE = 10
49
+ BLOCK_ONE = 1
50
+
51
+
52
+ # Function to calculate the real shape score based on the shape
53
+ def getRealShapeScore(shape: Shapes) -> int:
54
+ # Checked
55
+ if shape == shapes['FIVE']:
56
+ return FOUR
57
+ elif shape == shapes['BLOCK_FIVE']:
58
+ return BLOCK_FOUR
59
+ elif shape in [shapes['FOUR'], shapes['FOUR_FOUR'], shapes['FOUR_THREE']]:
60
+ return THREE
61
+ elif shape == shapes['BLOCK_FOUR']:
62
+ return BLOCK_THREE
63
+ elif shape == shapes['THREE']:
64
+ return TWO
65
+ elif shape == shapes['THREE_THREE']:
66
+ return math.floor(THREE_THREE / 10)
67
+ elif shape == shapes['BLOCK_THREE']:
68
+ return BLOCK_TWO
69
+ elif shape == shapes['TWO']:
70
+ return ONE
71
+ elif shape == shapes['TWO_TWO']:
72
+ return math.floor(TWO_TWO / 10)
73
+ else:
74
+ return 0
75
+
76
+
77
+ # List of all directions
78
+ allDirections = [
79
+ [0, 1], # Horizontal
80
+ [1, 0], # Vertical
81
+ [1, 1], # Diagonal \
82
+ [1, -1] # Diagonal /
83
+ ]
84
+
85
+
86
+ # Function to get the index of a direction
87
+ def direction2index(ox: int, oy: int) -> int:
88
+ # Checked
89
+ if ox == 0:
90
+ return 0 # |
91
+ elif oy == 0:
92
+ return 1 # -
93
+ elif ox == oy:
94
+ return 2 # \
95
+ elif ox != oy:
96
+ return 3 # /
97
+
98
+
99
+ # Performance dictionary
100
+ performance = {
101
+ "updateTime": 0,
102
+ "getPointsTime": 0
103
+ }
104
+
105
+
106
+ class Evaluate:
107
+ def __init__(self, size=15):
108
+ # Checked
109
+ self.size = size
110
+ self.board = [[2] * (size + 2) for _ in range(size + 2)]
111
+ for i in range(size + 2):
112
+ for j in range(size + 2):
113
+ if i == 0 or j == 0 or i == self.size + 1 or j == self.size + 1:
114
+ self.board[i][j] = 2
115
+ else:
116
+ self.board[i][j] = 0
117
+ self.blackScores = [[0] * self.size for _ in range(size)]
118
+ self.whiteScores = [[0] * self.size for _ in range(size)]
119
+ self.initPoints()
120
+ self.history = [] # List of [position, role]
121
+
122
+ def move(self, x, y, role):
123
+ # Checked
124
+ # Clear the cache first
125
+ for d in [0, 1, 2, 3]:
126
+ self.shapeCache[role][d][x][y] = 0
127
+ self.shapeCache[-role][d][x][y] = 0
128
+ self.blackScores[x][y] = 0
129
+ self.whiteScores[x][y] = 0
130
+ # Update the board
131
+ self.board[x + 1][y + 1] = role ## Adjust for the added wall
132
+ self.updatePoint(x, y)
133
+ self.history.append([coordinate2Position(x, y, self.size), role])
134
+
135
+ def undo(self, x, y):
136
+ # Checked
137
+ self.board[x + 1][y + 1] = 0
138
+ self.updatePoint(x, y)
139
+ self.history.pop()
140
+
141
+ def initPoints(self):
142
+ # Checked
143
+ # Initialize the cache, avoid calculating the same points multiple times
144
+ self.shapeCache = {}
145
+ for role in [1, -1]:
146
+ self.shapeCache[role] = {}
147
+ for direction in [0, 1, 2, 3]:
148
+ self.shapeCache[role][direction] = [[0] * self.size for _ in range(self.size)]
149
+
150
+ self.pointsCache = {}
151
+ for role in [1, -1]:
152
+ self.pointsCache[role] = {}
153
+ for shape in shapes:
154
+ self.pointsCache[role][shape] = set()
155
+
156
+ def getPointsInLine(self, role):
157
+ # Checked
158
+ pointsInLine = {}
159
+ hasPointsInLine = False
160
+ for key in shapes:
161
+ pointsInLine[shapes[key]] = set()
162
+
163
+ last2Points = [position for position, role in self.history[-config['inlineCount']:]]
164
+ processed = {}
165
+ # 在last2Points中查找是否有点位在一条线上
166
+ for r in [role, -role]:
167
+ for point in last2Points:
168
+ x, y = position2Coordinate(point, self.size)
169
+ for ox, oy in allDirections:
170
+ for sign in [1, -1]:
171
+ for step in range(1, config['inLineDistance'] + 1):
172
+ nx = x + sign * step * ox
173
+ ny = y + sign * step * oy
174
+ position = coordinate2Position(nx, ny, self.size)
175
+ # 检测是否到达边界
176
+ if nx < 0 or nx >= self.size or ny < 0 or ny >= self.size:
177
+ break
178
+ if self.board[nx + 1][ny + 1] != 0:
179
+ continue
180
+ if processed.get(position) == r:
181
+ continue
182
+ processed[position] = r
183
+ for direction in [0, 1, 2, 3]:
184
+ shape = self.shapeCache[r][direction][nx][ny]
185
+ # 到达边界停止,但是注意到达对方棋子不能停止
186
+ if shape:
187
+ pointsInLine[shape].add(coordinate2Position(nx, ny, self.size))
188
+ hasPointsInLine = True
189
+
190
+ if hasPointsInLine:
191
+ return pointsInLine
192
+ return False
193
+
194
+ def getPoints(self, role, depth, vct, vcf):
195
+ first = role if depth % 2 == 0 else -role # 先手
196
+ start = datetime.now()
197
+
198
+ if config['onlyInLine'] and len(self.history) >= config['inlineCount']:
199
+ points_in_line = self.getPointsInLine(role)
200
+ if points_in_line:
201
+ performance['getPointsTime'] += (datetime.now() - start).total_seconds()
202
+ return points_in_line
203
+
204
+ points = {} # 全部点位
205
+
206
+ for key in shapes.keys():
207
+ points[shapes[key]] = set()
208
+
209
+ last_points = [position for position, _ in self.history[-4:]]
210
+
211
+ for r in [role, -role]:
212
+ # 这里是直接遍历了这个棋盘上的所有点位,如果棋盘很大,这里会有性能问题;可以用神经网络来预测
213
+ for i in range(self.size):
214
+ for j in range(self.size):
215
+ four_count = 0
216
+ block_four_count = 0
217
+ three_count = 0
218
+
219
+ for direction in [0, 1, 2, 3]:
220
+ if self.board[i + 1][j + 1] != 0:
221
+ continue
222
+
223
+ shape = self.shapeCache[r][direction][i][j]
224
+
225
+ if not shape:
226
+ continue
227
+
228
+ point = i * self.size + j
229
+
230
+ if vcf:
231
+ if r == first and not is_four(shape) and not is_five(shape):
232
+ continue
233
+ if r == -first and is_five(shape):
234
+ continue
235
+
236
+ if vct:
237
+ if depth % 2 == 0:
238
+ if depth == 0 and r != first:
239
+ continue
240
+ if shape != shapes['THREE'] and not is_four(shape) and not is_five(shape):
241
+ continue
242
+ if shape == shapes['THREE'] and r != first:
243
+ continue
244
+ if depth == 0 and r != first:
245
+ continue
246
+ if depth > 0:
247
+ if shape == shapes['THREE'] and len(
248
+ get_all_shapes_of_point(self.shapeCache, i, j, r)) == 1:
249
+ continue
250
+ if shape == shapes['BLOCK_FOUR'] and len(
251
+ get_all_shapes_of_point(self.shapeCache, i, j, r)) == 1:
252
+ continue
253
+ else:
254
+ if shape != shapes['THREE'] and not is_four(shape) and not is_five(shape):
255
+ continue
256
+ if shape == shapes['THREE'] and r == -first:
257
+ continue
258
+ if depth > 1:
259
+ if shape == shapes['BLOCK_FOUR'] and len(
260
+ get_all_shapes_of_point(self.shapeCache, i, j)) == 1:
261
+ continue
262
+ if shape == shapes['BLOCK_FOUR'] and not hasInLine(point, last_points, self.size):
263
+ continue
264
+
265
+ if vcf:
266
+ if not is_four(shape) and not is_five(shape):
267
+ continue
268
+
269
+ if depth > 2 and (shape == shapes['TWO'] or shape == shapes['TWO_TWO'] or shape == shapes[
270
+ 'BLOCK_THREE']) and not hasInLine(point, last_points, self.size):
271
+ continue
272
+
273
+ points[shape].add(point)
274
+
275
+ if shape == shapes['FOUR']:
276
+ four_count += 1
277
+ elif shape == shapes['BLOCK_FOUR']:
278
+ block_four_count += 1
279
+ elif shape == shapes['THREE']:
280
+ three_count += 1
281
+
282
+ union_shape = None
283
+
284
+ if four_count >= 2:
285
+ union_shape = shapes['FOUR_FOUR']
286
+ elif block_four_count and three_count:
287
+ union_shape = shapes['FOUR_THREE']
288
+ elif three_count >= 2:
289
+ union_shape = shapes['THREE_THREE']
290
+
291
+ if union_shape:
292
+ points[union_shape].add(point)
293
+
294
+ performance['getPointsTime'] += (datetime.now() - start).total_seconds()
295
+
296
+ return points
297
+
298
+ """
299
+ 当一个位置发生变时候,要更新这个位置的四个方向上得分,更新规则是:
300
+ 1. 如果这个位置是空的,那么就重新计算这个位置的得分
301
+ 2. 如果碰到了边界或者对方的棋子,那么就停止计算
302
+ 3. 如果超过2个空位,那么就停止计算
303
+ 4. 要更新自己的和对方的得分
304
+ """
305
+
306
+ def updatePoint(self, x, y):
307
+ # Checked
308
+ start = datetime.now()
309
+ self.updateSinglePoint(x, y, 1)
310
+ self.updateSinglePoint(x, y, -1)
311
+
312
+ for ox, oy in allDirections:
313
+ for sign in [1, -1]: # -1 for negative direction, 1 for positive direction
314
+ for step in range(1, 6):
315
+ reachEdge = False
316
+ for role in [1, -1]:
317
+ nx = x + sign * step * ox + 1 # +1 to adjust for wall
318
+ ny = y + sign * step * oy + 1 # +1 to adjust for wall
319
+ # Stop if wall or opponent's piece is found
320
+ if self.board[nx][ny] == 2:
321
+ reachEdge = True
322
+ break
323
+ elif self.board[nx][ny] == -role: # Change role if opponent's piece is found
324
+ continue
325
+ elif self.board[nx][ny] == 0:
326
+ self.updateSinglePoint(nx - 1, ny - 1, role,
327
+ [sign * ox, sign * oy]) # -1 to adjust back from wall
328
+ if reachEdge:
329
+ break
330
+ performance['updateTime'] += (datetime.now() - start).total_seconds()
331
+
332
+ """
333
+ 计算单个点的得分
334
+ 计算原理是:
335
+ 在当前位置放一个当前角色的棋子,遍历四个方向,生成四个方向上的字符串,用patters来匹配字符串, 匹配到的话,就将对应的得分加到scores上
336
+ 四个方向的字符串生成规则是:向两边都延伸5个位置,如果遇到边界或者对方的棋子,就停止延伸
337
+ 在更新周围棋子时,只有一个方向需要更新,因此可以传入direction参数,只更新一个方向
338
+ """
339
+
340
+ def updateSinglePoint(self, x, y, role, direction=None):
341
+ # Checked
342
+ if self.board[x + 1][y + 1] != 0:
343
+ return # Not an empty spot
344
+
345
+ # Temporarily place the piece
346
+ self.board[x + 1][y + 1] = role
347
+
348
+ directions = []
349
+ if direction:
350
+ directions.append(direction)
351
+ else:
352
+ directions = allDirections
353
+
354
+ shapeCache = self.shapeCache[role]
355
+
356
+ # Clear the cache first
357
+ for ox, oy in directions:
358
+ shapeCache[direction2index(ox, oy)][x][y] = shapes['NONE']
359
+
360
+ score = 0
361
+ blockFourCount = 0
362
+ threeCount = 0
363
+ twoCount = 0
364
+
365
+ # Calculate existing score
366
+ for intDirection in [0, 1, 2, 3]:
367
+ shape = shapeCache[intDirection][x][y]
368
+ if shape > shapes['NONE']:
369
+ score += getRealShapeScore(shape)
370
+ if shape == shapes['BLOCK_FOUR']:
371
+ blockFourCount += 1
372
+ if shape == shapes['THREE']:
373
+ threeCount += 1
374
+ if shape == shapes['TWO']:
375
+ twoCount += 1
376
+
377
+ for ox, oy in directions:
378
+ intDirection = direction2index(ox, oy)
379
+ shape, selfCount = get_shape_fast(self.board, x, y, ox, oy, role)
380
+ if not shape:
381
+ continue
382
+ if shape:
383
+ # Note: Only cache single shapes, do not cache compound shapes like double threes, as they depend on two shapes
384
+ shapeCache[intDirection][x][y] = shape
385
+ if shape == shapes['BLOCK_FOUR']:
386
+ blockFourCount += 1
387
+ if shape == shapes['THREE']:
388
+ threeCount += 1
389
+ if shape == shapes['TWO']:
390
+ twoCount += 1
391
+ if blockFourCount >= 2:
392
+ shape = shapes['FOUR_FOUR']
393
+ elif blockFourCount and threeCount:
394
+ shape = shapes['FOUR_THREE']
395
+ elif threeCount >= 2:
396
+ shape = shapes['THREE_THREE']
397
+ elif twoCount >= 2:
398
+ shape = shapes['TWO_TWO']
399
+ score += getRealShapeScore(shape)
400
+
401
+ self.board[x + 1][y + 1] = 0 # Remove the temporary piece
402
+
403
+ if role == 1:
404
+ self.blackScores[x][y] = score
405
+ else:
406
+ self.whiteScores[x][y] = score
407
+
408
+ return score
409
+
410
+ def evaluate(self, role):
411
+ # Checked
412
+ blackScore = 0
413
+ whiteScore = 0
414
+
415
+ for i in range(len(self.blackScores)):
416
+ for j in range(len(self.blackScores[i])):
417
+ blackScore += self.blackScores[i][j]
418
+
419
+ for i in range(len(self.whiteScores)):
420
+ for j in range(len(self.whiteScores[i])):
421
+ whiteScore += self.whiteScores[i][j]
422
+
423
+ score = blackScore - whiteScore if role == 1 else whiteScore - blackScore
424
+ return score
425
+
426
+ def getMoves(self, role, depth, onThree=False, onlyFour=False, use_net = False):
427
+ # Checked
428
+ train_data = 0
429
+ if use_net and role == 1:
430
+ # value_move_num = 6
431
+ # input = torch.Tensor(np.array(self.board)[1:-1, 1:-1]).unsqueeze(0)
432
+ # scores = mini_max_net(input)
433
+ # flattened_scores = scores.flatten()
434
+ #
435
+ # moves = (flattened_scores.argsort(descending=True)[:value_move_num]).tolist()
436
+ moves = 0
437
+ # print(moves)
438
+ else:
439
+ moves, model_train_maxtrix = self._getMoves(role, depth, onThree, onlyFour)
440
+ train_data = {"state": np.array(self.board)[1:-1, 1:-1], "scores": model_train_maxtrix}
441
+ moves = [(move // self.size, move % self.size) for move in moves]
442
+ # cut the self.board into normal size
443
+ print("moves", moves)
444
+
445
+ return moves, train_data
446
+
447
+ def _getMoves(self, role, depth, only_three=False, only_four=False):
448
+ """
449
+ Get possible moves based on the current game state.
450
+ """
451
+ points = self.getPoints(role, depth, only_three, only_four)
452
+ fives = points[shapes['FIVE']]
453
+ block_fives = points[shapes['BLOCK_FIVE']]
454
+
455
+ # To train the model, we need to get all these points's score and store it to board size matrix
456
+ # Then we can use this matrix to train the model, given a state, we want it to output the score of each point, then we can choose the highest score point
457
+ model_train_matrix = [[0] * self.size for _ in range(self.size)]
458
+
459
+ if fives and len(fives) > 0 or block_fives and len(block_fives) > 0:
460
+ for point in fives:
461
+ x = point // self.size
462
+ y = point % self.size
463
+ model_train_matrix[x][y] = max(FIVE, model_train_matrix[x][y])
464
+ for point in block_fives:
465
+ x = point // self.size
466
+ y = point % self.size
467
+ model_train_matrix[x][y] = max(BLOCK_FIVE, model_train_matrix[x][y])
468
+
469
+ return set(list(fives) + list(block_fives)), model_train_matrix
470
+
471
+ fours = points[shapes['FOUR']]
472
+ block_fours = points[shapes['BLOCK_FOUR']] # Block four is special, consider it in both four and three
473
+ if only_four or (fours and len(fours) > 0):
474
+ for point in fours:
475
+ x = point // self.size
476
+ y = point % self.size
477
+ model_train_matrix[x][y] = max(FOUR, model_train_matrix[x][y])
478
+
479
+ for point in block_fours:
480
+ x = point // self.size
481
+ y = point % self.size
482
+ model_train_matrix[x][y] = max(BLOCK_FOUR, model_train_matrix[x][y])
483
+
484
+ return set(list(fours) + list(block_fours)), model_train_matrix
485
+
486
+ four_fours = points[shapes['FOUR_FOUR']]
487
+ if four_fours and len(four_fours) > 0:
488
+ for point in four_fours:
489
+ x = point // self.size
490
+ y = point % self.size
491
+ model_train_matrix[x][y] = max(FOUR_FOUR, model_train_matrix[x][y])
492
+
493
+ for point in block_fours:
494
+ x = point // self.size
495
+ y = point % self.size
496
+ model_train_matrix[x][y] = max(BLOCK_FOUR, model_train_matrix[x][y])
497
+
498
+ return set(list(four_fours) + list(block_fours)), model_train_matrix
499
+
500
+ # Double threes and active threes
501
+ threes = points[shapes['THREE']]
502
+ four_threes = points[shapes['FOUR_THREE']]
503
+ if four_threes and len(four_threes) > 0:
504
+ for point in four_threes:
505
+ x = point // self.size
506
+ y = point % self.size
507
+ model_train_matrix[x][y] = max(FOUR_THREE, model_train_matrix[x][y])
508
+
509
+ for point in block_fours:
510
+ x = point // self.size
511
+ y = point % self.size
512
+ model_train_matrix[x][y] = max(BLOCK_FOUR, model_train_matrix[x][y])
513
+
514
+ for point in threes:
515
+ x = point // self.size
516
+ y = point % self.size
517
+ model_train_matrix[x][y] = max(THREE, model_train_matrix[x][y])
518
+
519
+ return set(list(four_threes) + list(block_fours) + list(threes)), model_train_matrix
520
+
521
+ three_threes = points[shapes['THREE_THREE']]
522
+ if three_threes and len(three_threes) > 0:
523
+
524
+ for point in three_threes:
525
+ x = point // self.size
526
+ y = point % self.size
527
+ model_train_matrix[x][y] = max(THREE_THREE, model_train_matrix[x][y])
528
+
529
+ for point in block_fours:
530
+ x = point // self.size
531
+ y = point % self.size
532
+ model_train_matrix[x][y] = max(BLOCK_FOUR, model_train_matrix[x][y])
533
+
534
+ for point in threes:
535
+ x = point // self.size
536
+ y = point % self.size
537
+ model_train_matrix[x][y] = max(THREE, model_train_matrix[x][y])
538
+
539
+ return set(list(three_threes) + list(block_fours) + list(threes)), model_train_matrix
540
+
541
+ if only_three:
542
+ for point in threes:
543
+ x = point // self.size
544
+ y = point % self.size
545
+ model_train_matrix[x][y] = max(THREE, model_train_matrix[x][y])
546
+
547
+ for point in block_fours:
548
+ x = point // self.size
549
+ y = point % self.size
550
+ model_train_matrix[x][y] = max(BLOCK_FOUR, model_train_matrix[x][y])
551
+ return set(list(block_fours) + list(threes)), model_train_matrix
552
+
553
+ block_threes = points[shapes['BLOCK_THREE']]
554
+ two_twos = points[shapes['TWO_TWO']]
555
+ twos = points[shapes['TWO']]
556
+
557
+ for point in block_threes:
558
+ x = point // self.size
559
+ y = point % self.size
560
+ model_train_matrix[x][y] = max(BLOCK_THREE, model_train_matrix[x][y])
561
+
562
+ for point in two_twos:
563
+ x = point // self.size
564
+ y = point % self.size
565
+ model_train_matrix[x][y] = max(TWO_TWO, model_train_matrix[x][y])
566
+
567
+ for point in twos:
568
+ x = point // self.size
569
+ y = point % self.size
570
+ model_train_matrix[x][y] = max(TWO, model_train_matrix[x][y])
571
+
572
+ for point in block_fours:
573
+ x = point // self.size
574
+ y = point % self.size
575
+ model_train_matrix[x][y] = max(BLOCK_FOUR, model_train_matrix[x][y])
576
+
577
+ for point in threes:
578
+ x = point // self.size
579
+ y = point % self.size
580
+ model_train_matrix[x][y] = max(THREE, model_train_matrix[x][y])
581
+
582
+ mid = list(block_fours) + list(threes) + list(block_threes) + list(two_twos) + list(twos)
583
+ res = set(mid[:5])
584
+ for i in range(len(model_train_matrix)):
585
+ for j in range(len(model_train_matrix)):
586
+ if (i * len(model_train_matrix) + j) not in res:
587
+ model_train_matrix[i][j] = 0
588
+ return res, model_train_matrix
Gomoku_Bot/gomoku_bot.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .minmax import *
2
+ import time
3
+
4
+
5
+ class Gomoku_bot:
6
+ def __init__(self, board, role, depth=4, enableVCT=True):
7
+ self.board = board
8
+ self.role = role
9
+ self.depth = depth
10
+ self.enableVCT = enableVCT
11
+
12
+ def get_action(self, return_time=True):
13
+ start = time.time()
14
+ score = minmax(self.board, self.role, self.depth, self.enableVCT)
15
+ end = time.time()
16
+ sim_time = end - start
17
+ move = score[1]
18
+ # turn tuple into an int
19
+ move = move[0] * self.board.size + move[1]
20
+ if return_time:
21
+ return move, sim_time
22
+ else:
23
+ return move
Gomoku_Bot/minimax_Net.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import pickle
6
+ import os
7
+
8
+ dir_path = os.path.dirname(os.path.realpath(__file__))
9
+ # from tensorboardX import SummaryWriter
10
+ from tqdm import tqdm
11
+ import datetime
12
+ from torch.utils.data import DataLoader, TensorDataset
13
+
14
+ date = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
15
+
16
+
17
+ class BoardEvaluationNet(nn.Module):
18
+ def __init__(self, board_size):
19
+ super(BoardEvaluationNet, self).__init__()
20
+ self.board_size = board_size
21
+
22
+ self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
23
+ self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
24
+ self.fc1 = nn.Linear(32 * board_size * board_size, 256)
25
+ self.fc2 = nn.Linear(256, board_size * board_size)
26
+
27
+ def forward(self, x):
28
+ x = x.unsqueeze(1) # Add a channel dimension
29
+ x = F.relu(self.conv1(x))
30
+ x = F.relu(self.conv2(x))
31
+ x = x.view(-1, 32 * self.board_size * self.board_size)
32
+ x = F.relu(self.fc1(x))
33
+ x = self.fc2(x)
34
+ return x.view(-1, self.board_size, self.board_size)
35
+
36
+
37
+ def normalize(t):
38
+ return t
39
+
40
+
41
+ if __name__ == "__main__":
42
+
43
+ writer = SummaryWriter(os.path.join(dir_path, 'train_data/log', date), comment='BoardEvaluationNet')
44
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
45
+
46
+ best = np.Inf
47
+ loss_fn = nn.CrossEntropyLoss()
48
+
49
+ # Example usage
50
+ BS = 15
51
+
52
+ net_for_black = BoardEvaluationNet(BS).to(device)
53
+ net_for_white = BoardEvaluationNet(BS).to(device)
54
+
55
+ net_for_black.load_state_dict(torch.load(os.path.join(dir_path, 'train_data/model', 'best_loss=680.5813717259707.pth')))
56
+
57
+ optimizer = torch.optim.Adam(net_for_black.parameters(), lr=1e-5, betas=(0.9, 0.99),
58
+ eps=1e-8)
59
+
60
+ data_path = os.path.join(dir_path, 'train_data/data', 'train_data.pkl')
61
+ with open(data_path, 'rb') as f:
62
+ datas = pickle.load(f)
63
+
64
+ train_data_for_black = datas[1][:int(len(datas[1]) * 1)]
65
+ test_data_for_black = datas[1][int(len(datas[1]) * 0.8):]
66
+ train_data_for_white = datas[-1]
67
+ epochs = 500
68
+ batch_size = 32
69
+ train_dataset = TensorDataset(torch.stack([torch.tensor(item['state'], dtype=torch.float) for item in train_data_for_black]),
70
+ torch.stack([normalize(torch.tensor(item['scores'], dtype=torch.float)) for item in train_data_for_black]))
71
+ train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
72
+
73
+ for epoch in range(epochs):
74
+ epoch_loss = 0
75
+ print('Epoch:', epoch)
76
+ for i, (states, scores) in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
77
+ states = states.to(device)
78
+ scores = scores.to(device)
79
+
80
+ # print(input_tensor.shape)
81
+ infer_start = datetime.datetime.now()
82
+ output_tensor = net_for_black(states)
83
+ infer_end = datetime.datetime.now()
84
+ loss = loss_fn(output_tensor, scores)
85
+ print(loss.item())
86
+ exit(0)
87
+ loss.backward()
88
+ optimizer.step()
89
+ optimizer.zero_grad()
90
+ epoch_loss += loss.item()
91
+
92
+ writer.add_scalar('train/infer_time', (infer_end - infer_start).microseconds,
93
+ i + epoch * len(train_dataloader))
94
+
95
+ epoch_loss /= len(train_dataloader)
96
+ writer.add_scalar('train/epoch_loss', epoch_loss, epoch)
97
+ # test
98
+ with torch.no_grad():
99
+ test_loss = 0
100
+ net_for_black.eval()
101
+ for j, item in tqdm(enumerate(test_data_for_black), total=len(test_data_for_black)):
102
+ scores = normalize(torch.tensor(item['scores'], dtype=torch.float).to(device).unsqueeze(0)) # 将数据类型设为float
103
+ state = item['state']
104
+ input_tensor = torch.tensor(state, dtype=torch.float).to(device).unsqueeze(0) # 将数据类型设为float,并转移到设备上
105
+ output_tensor = net_for_black(input_tensor).to(device)
106
+ loss = loss_fn(output_tensor, scores)
107
+ test_loss += loss.item()
108
+ test_loss /=len(test_data_for_black)
109
+ writer.add_scalar('test/loss', test_loss, epoch)
110
+ if best > test_loss:
111
+ best = test_loss
112
+ model_path = os.path.join(dir_path, 'train_data/model')
113
+ if not os.path.exists(model_path):
114
+ os.makedirs(model_path)
115
+ torch.save(net_for_black.state_dict(),
116
+ os.path.join(model_path, f'best_loss={best}.pth'))
117
+ net_for_black.train()
Gomoku_Bot/minmax.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .cache import Cache
2
+ from .eval import FIVE
3
+ # Checked
4
+
5
+
6
+ MAX = 1000000000
7
+ cache_hits = {
8
+ "search": 0,
9
+ "total": 0,
10
+ "hit": 0
11
+ }
12
+
13
+ onlyThreeThreshold = 6
14
+ cache = Cache()
15
+
16
+
17
+ def factory(onlyThree=False, onlyFour=False):
18
+
19
+ def helper(board, role, depth, cDepth=0, path=(), alpha=-MAX, beta=MAX):
20
+ cache_hits["search"] += 1
21
+ if cDepth >= depth or board.isGameOver():
22
+ return [board.evaluate(role), None, (path)]
23
+ hash = board.hash()
24
+ prev = cache.get(hash)
25
+ if prev and prev["role"] == role:
26
+ if (
27
+ (abs(prev["value"]) >= FIVE or prev["depth"] >= depth - cDepth)
28
+ and prev["onlyThree"] == onlyThree
29
+ and prev["onlyFour"] == onlyFour
30
+ ):
31
+ cache_hits["hit"] += 1
32
+ return [prev["value"], prev["move"], path + prev["path"]]
33
+ value = -MAX
34
+ move = None
35
+ bestPath = path # Copy the current path
36
+ bestDepth = 0
37
+ # points = board.getValuableMoves(role, cDepth, onlyThree or cDepth > onlyThreeThreshold, onlyFour)
38
+ points = board.getValuableMoves(role, cDepth, onlyThree or cDepth > onlyThreeThreshold, onlyFour)
39
+ if cDepth == 0:
40
+ print('points:', points)
41
+ if not len(points):
42
+ return [board.evaluate(role), None, path]
43
+ for d in range(cDepth + 1, depth + 1):
44
+ # 迭代加深过程中只找己方能赢的解,因此只搜索偶数层即可
45
+ if d % 2 != 0:
46
+ continue
47
+ breakAll = False
48
+ for point in points:
49
+ board.put(point[0], point[1], role)
50
+ newPath = tuple(list(path) + [point]) # Add current move to path
51
+ currentValue, currentMove, currentPath = helper(board, -role, d, cDepth + 1, tuple(newPath) , -beta, -alpha)
52
+ currentValue = -currentValue
53
+ board.undo()
54
+ ## 迭代加深的过程中,除了能赢的棋,其他都不要
55
+ ## 原因是:除了必胜的,其他评估不准。比如必输的棋,由于走的步数偏少,也会变成没有输,比如5
56
+ ### 步之后输了,但是1步肯定不会输,这时候1步的分数是不准确的,显然不能选择。
57
+ if currentValue >= FIVE or d == depth:
58
+ # 必输的棋,也要挣扎一下,选择最长的路径
59
+ if (
60
+ currentValue > value
61
+ or (currentValue <= -FIVE and value <= -FIVE and len(currentPath) > bestDepth)
62
+ ):
63
+ value = currentValue
64
+ move = point
65
+ bestPath = currentPath
66
+ bestDepth = len(currentPath)
67
+ alpha = max(alpha, value)
68
+ if alpha >= FIVE:
69
+ breakAll = True
70
+ break
71
+ if alpha >= beta:
72
+ break
73
+ if breakAll:
74
+ break
75
+ if (cDepth < onlyThreeThreshold or onlyThree or onlyFour) and (not prev or prev["depth"] < depth - cDepth):
76
+ cache_hits["total"] += 1
77
+ cache.put(hash, {
78
+ "depth": depth - cDepth,
79
+ "value": value,
80
+ "move": move,
81
+ "role": role,
82
+ "path": bestPath[cDepth:],
83
+ "onlyThree": onlyThree,
84
+ "onlyFour": onlyFour,
85
+ })
86
+ return [value, move, bestPath]
87
+ return helper
88
+
89
+
90
+ _minmax = factory()
91
+ vct = factory(True)
92
+ vcf = factory(False, True)
93
+
94
+
95
+ def minmax(board, role, depth=4, enableVCT=False):
96
+
97
+ if enableVCT:
98
+ vctDepth = depth + 8
99
+ value, move, bestPath = vct(board, role, vctDepth)
100
+ if value >= FIVE:
101
+ return [value, move, bestPath]
102
+ value, move, bestPath = _minmax(board, role, depth)
103
+ '''
104
+ // 假设对方有杀棋,先按自己的思路走,走完之后看对方是不是还有杀棋
105
+ // 如果对方没有了,那么就说明走的是对的
106
+ // 如果对方还是有,那么要对比对方的杀棋路径和自己没有走棋时的长短
107
+ // 如果走了棋之后路径变长了,说明走的是对的
108
+ // 如果走了棋之后,对方杀棋路径长度没变,甚至更短,说明走错了,此时就优先封堵对方
109
+ '''
110
+ board.put(move[0], move[1], role)
111
+ value2, move2, bestPath2 = vct(board.reverse(), role, vctDepth)
112
+ board.undo()
113
+ if value < FIVE and value2 == FIVE and len(bestPath2) > len(bestPath):
114
+ value3, move3, bestPath3 = vct(board.reverse(), role, vctDepth)
115
+ if len(bestPath2) <= len(bestPath3):
116
+ return [value, move2, bestPath2] # value2 是被挡住的,所以这里还是用value
117
+ return [value, move, bestPath]
118
+ else:
119
+ return _minmax(board, role, depth)
Gomoku_Bot/position.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ from .config import config
3
+ # Checked
4
+
5
+ # Function to convert position to coordinate
6
+ def position2Coordinate(position: int, size: int) -> List[int]:
7
+ return [position // size, position % size]
8
+
9
+ # Function to convert coordinate to position
10
+ def coordinate2Position(x: int, y: int, size: int) -> int:
11
+ return x * size + y
12
+
13
+ # Check if points a and b are on the same line and the distance is less than maxDistance
14
+ def isLine(a: int, b: int, size: int) -> bool:
15
+ maxDistance = config["inLineDistance"]
16
+ [x1, y1] = position2Coordinate(a, size)
17
+ [x2, y2] = position2Coordinate(b, size)
18
+ return (
19
+ (x1 == x2 and abs(y1 - y2) < maxDistance) or
20
+ (y1 == y2 and abs(x1 - x2) < maxDistance) or
21
+ (abs(x1 - x2) == abs(y1 - y2) and abs(x1 - x2) < maxDistance)
22
+ )
23
+
24
+ # Check if all points in the array are on the same line as point p
25
+ def isAllInLine(p: int, arr: List[int], size: int) -> bool:
26
+ for i in range(len(arr)):
27
+ if not isLine(p, arr[i], size):
28
+ return False
29
+ return True
30
+
31
+ # Check if any point in the array is on the same line as point p
32
+ def hasInLine(p: int, arr: List[int], size: int) -> bool:
33
+ for i in range(len(arr)):
34
+ if isLine(p, arr[i], size):
35
+ return True
36
+ return False
Gomoku_Bot/shape.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import re
3
+
4
+ # Define patterns using regular expressions
5
+ patterns = {
6
+ 'five': re.compile('11111'),
7
+ 'block_five': re.compile('211111|111112'),
8
+ 'four': re.compile('011110'),
9
+ 'block_four': re.compile('10111|11011|11101|211110|211101|211011|210111|011112|101112|110112|111012'),
10
+ 'three': re.compile('011100|011010|010110|001110'),
11
+ 'block_three': re.compile('211100|211010|210110|001112|010112|011012'),
12
+ 'two': re.compile('001100|011000|000110|010100|001010'),
13
+ }
14
+
15
+ # Define shapes with associated scores
16
+ shapes = {
17
+ 'FIVE': 5,
18
+ 'BLOCK_FIVE': 50,
19
+ 'FOUR': 4,
20
+ 'FOUR_FOUR': 44, # Double four
21
+ 'FOUR_THREE': 43, # Four with an open three
22
+ 'THREE_THREE': 33, # Double three
23
+ 'BLOCK_FOUR': 40,
24
+ 'THREE': 3,
25
+ 'BLOCK_THREE': 30,
26
+ 'TWO_TWO': 22, # Double two
27
+ 'TWO': 2,
28
+ 'NONE': 0
29
+ }
30
+
31
+ # Initialize a performance record
32
+ performance = {
33
+ 'five': 0,
34
+ 'block_five': 0,
35
+ 'four': 0,
36
+ 'block_four': 0,
37
+ 'three': 0,
38
+ 'block_three': 0,
39
+ 'two': 0,
40
+ 'none': 0,
41
+ 'total': 0
42
+ }
43
+
44
+ # Function to detect shapes on the board
45
+ def get_shape(board, x, y, offset_x, offset_y, role):
46
+ """
47
+ Detect shape at a given board position.
48
+ :param board: The game board.
49
+ :param x: X-coordinate.
50
+ :param y: Y-coordinate.
51
+ :param offset_x: X-direction offset for scanning.
52
+ :param offset_y: Y-direction offset for scanning.
53
+ :param role: Current player's role.
54
+ :return: A tuple of shape, self count, opponent count, and empty count.
55
+ """
56
+ opponent = -role
57
+ empty_count = 0
58
+ self_count = 1
59
+ opponent_count = 0
60
+ shape = shapes['NONE']
61
+
62
+ # Skip empty nodes
63
+ if (
64
+ board[x + offset_x + 1][y + offset_y + 1] == 0
65
+ and board[x - offset_x + 1][y - offset_y + 1] == 0
66
+ and board[x + 2 * offset_x + 1][y + 2 * offset_y + 1] == 0
67
+ and board[x - 2 * offset_x + 1][y - 2 * offset_y + 1] == 0
68
+ ):
69
+ return [0, self_count, opponent_count, empty_count]
70
+ # Check for 'two' pattern
71
+ for i in range(-3, 4):
72
+ if i == 0:
73
+ continue
74
+ nx, ny = x + i * offset_x, y + i * offset_y
75
+ current_role = board.get((nx, ny))
76
+ if current_role is None:
77
+ continue
78
+ if current_role == 2:
79
+ opponent_count += 1
80
+ elif current_role == role:
81
+ self_count += 1
82
+ elif current_role == 0:
83
+ empty_count += 1
84
+
85
+ if self_count == 2:
86
+ if opponent_count == 0:
87
+ return shapes['TWO'], self_count, opponent_count, empty_count
88
+ else:
89
+ return shapes['NONE'], self_count, opponent_count, empty_count
90
+
91
+ # Reset counts and prepare string for pattern matching
92
+ empty_count, self_count, opponent_count = 0, 1, 0
93
+ result_string = '1'
94
+
95
+ # Build result string for pattern matching
96
+ for i in range(1, 6):
97
+ nx = x + i * offset_x + 1
98
+ ny = y + i * offset_y + 1
99
+ currentRole = board[nx][ny]
100
+ if currentRole == 2:
101
+ result_string += '2'
102
+ elif currentRole == 0:
103
+ result_string += '0'
104
+ else:
105
+ result_string += '1' if currentRole == role else '2'
106
+ if currentRole == 2 or currentRole == opponent:
107
+ opponent_count += 1
108
+ break
109
+ if currentRole == 0:
110
+ empty_count += 1
111
+ if currentRole == role:
112
+ self_count += 1
113
+
114
+ for i in range(1, 6):
115
+ nx = x - i * offset_x + 1
116
+ ny = y - i * offset_y + 1
117
+ currentRole = board[nx][ny]
118
+ if currentRole == 2:
119
+ result_string = '2' + result_string
120
+ elif currentRole == 0:
121
+ result_string = '0' + result_string
122
+ else:
123
+ result_string = '1' if currentRole == role else '2' + result_string
124
+ if currentRole == 2 or currentRole == opponent:
125
+ opponent_count += 1
126
+ break
127
+ if currentRole == 0:
128
+ empty_count += 1
129
+ if currentRole == role:
130
+ self_count += 1
131
+
132
+ # Check patterns and update performance
133
+ for pattern_key, shape_key in [('five', 'FIVE'), ('four', 'FOUR'), ('block_four', 'BLOCK_FOUR'),
134
+ ('three', 'THREE'), ('block_three', 'BLOCK_THREE'), ('two', 'TWO')]:
135
+ if patterns[pattern_key].search(result_string):
136
+ shape = shapes[shape_key]
137
+ performance[pattern_key] += 1
138
+ performance['total'] += 1
139
+ break
140
+ ## 尽量减少多余字符串生成
141
+ if self_count <= 1 or len(result_string) < 5:
142
+ return shape, self_count, opponent_count, empty_count
143
+
144
+ return shape, self_count, opponent_count, empty_count
145
+
146
+ def count_shape(board, x, y, offset_x, offset_y, role):
147
+ opponent = - role
148
+
149
+ inner_empty_count = 0 # Number of empty positions inside the player's stones
150
+ temp_empty_count = 0
151
+ self_count = 0 # Number of the player's stones in the shape
152
+ total_length = 0
153
+
154
+ side_empty_count = 0 # Number of empty positions on the side of the shape
155
+ no_empty_self_count = 0
156
+ one_empty_self_count = 0
157
+
158
+ # Right direction
159
+ for i in range(1, 6):
160
+ nx = x + i * offset_x + 1
161
+ ny = y + i * offset_y + 1
162
+ current_role = board[nx][ny]
163
+ if current_role == 2 or current_role == opponent:
164
+ break
165
+ if current_role == role:
166
+ self_count += 1
167
+ side_empty_count = 0
168
+ if temp_empty_count:
169
+ inner_empty_count += temp_empty_count
170
+ temp_empty_count = 0
171
+ if inner_empty_count == 0:
172
+ no_empty_self_count += 1
173
+ one_empty_self_count += 1
174
+ elif inner_empty_count == 1:
175
+ one_empty_self_count += 1
176
+ total_length += 1
177
+ if current_role == 0:
178
+ temp_empty_count += 1
179
+ side_empty_count += 1
180
+ if side_empty_count >= 2:
181
+ break
182
+
183
+ if not inner_empty_count:
184
+ one_empty_self_count = 0
185
+
186
+ return {
187
+ 'self_count': self_count,
188
+ 'total_length': total_length,
189
+ 'no_empty_self_count': no_empty_self_count,
190
+ 'one_empty_self_count': one_empty_self_count,
191
+ 'inner_empty_count': inner_empty_count,
192
+ 'side_empty_count': side_empty_count
193
+ }
194
+
195
+ # Fast shape detection function
196
+ def get_shape_fast(board, x, y, offsetX, offsetY, role):
197
+ if (
198
+ board[x + offsetX + 1][y + offsetY + 1] == 0
199
+ and board[x - offsetX + 1][y - offsetY + 1] == 0
200
+ and board[x + 2 * offsetX + 1][y + 2 * offsetY + 1] == 0
201
+ and board[x - 2 * offsetX + 1][y - 2 * offsetY + 1] == 0
202
+ ):
203
+ return [shapes['NONE'], 1]
204
+
205
+ selfCount = 1
206
+ totalLength = 1
207
+ shape = shapes['NONE']
208
+
209
+ leftEmpty = 0
210
+ rightEmpty = 0
211
+ noEmptySelfCount = 1
212
+ OneEmptySelfCount = 1
213
+
214
+ left = count_shape(board, x, y, -offsetX, -offsetY, role)
215
+ right = count_shape(board, x, y, offsetX, offsetY, role)
216
+
217
+ selfCount = left['self_count'] + right['self_count'] + 1
218
+ totalLength = left['total_length'] + right['total_length'] + 1
219
+ noEmptySelfCount = left['no_empty_self_count'] + right['no_empty_self_count'] + 1
220
+ OneEmptySelfCount = max(
221
+ left['one_empty_self_count'] + right['no_empty_self_count'],
222
+ left['no_empty_self_count'] + right['one_empty_self_count'],
223
+ ) + 1
224
+ rightEmpty = right['side_empty_count']
225
+ leftEmpty = left['side_empty_count']
226
+
227
+ if totalLength < 5:
228
+ return [shape, selfCount]
229
+
230
+ if noEmptySelfCount >= 5:
231
+ if rightEmpty > 0 and leftEmpty > 0:
232
+ return [shapes['FIVE'], selfCount]
233
+ else:
234
+ return [shapes['BLOCK_FIVE'], selfCount]
235
+
236
+ if noEmptySelfCount == 4:
237
+ if (
238
+ (rightEmpty >= 1 or right['one_empty_self_count'] > right['no_empty_self_count'])
239
+ and (leftEmpty >= 1 or left['one_empty_self_count'] > left['no_empty_self_count'])
240
+ ):
241
+ return [shapes['FOUR'], selfCount]
242
+ elif not (rightEmpty == 0 and leftEmpty == 0):
243
+ return [shapes['BLOCK_FOUR'], selfCount]
244
+
245
+ if OneEmptySelfCount == 4:
246
+ return [shapes['BLOCK_FOUR'], selfCount]
247
+
248
+ if noEmptySelfCount == 3:
249
+ if (rightEmpty >= 2 and leftEmpty >= 1) or (rightEmpty >= 1 and leftEmpty >= 2):
250
+ return [shapes['THREE'], selfCount]
251
+ else:
252
+ return [shapes['BLOCK_THREE'], selfCount]
253
+
254
+ if OneEmptySelfCount == 3:
255
+ if rightEmpty >= 1 and leftEmpty >= 1:
256
+ return [shapes['THREE'], selfCount]
257
+ else:
258
+ return [shapes['BLOCK_THREE'], selfCount]
259
+
260
+ if (noEmptySelfCount == 2 or OneEmptySelfCount == 2) and totalLength > 5:
261
+ shape = shapes['TWO']
262
+
263
+ return [shape, selfCount]
264
+
265
+ # Helper functions to check for specific shapes
266
+ def is_five(shape):
267
+ # Checked
268
+ return shape in [shapes['FIVE'], shapes['BLOCK_FIVE']]
269
+
270
+ def is_four(shape):
271
+ # Checked
272
+ return shape in [shapes['FOUR'], shapes['BLOCK_FOUR']]
273
+
274
+ # Function to get all shapes at a specific point
275
+ def get_all_shapes_of_point(shape_cache, x, y, role = None):
276
+ # Checked
277
+ roles = [role] if role else [1, -1]
278
+ result = []
279
+ for r in roles:
280
+ for d in range(4):
281
+ shape = shape_cache[r][d][x][y]
282
+ if shape > 0:
283
+ result.append(shape)
284
+ return result
285
+
286
+
287
+ if __name__ == "__main__":
288
+ pass
Gomoku_Bot/zobrist.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ # Checked
3
+ class ZobristCache:
4
+ def __init__(self, size):
5
+ self.size = size
6
+ self.zobristTable = self.initializeZobristTable(size)
7
+ self.hash = 0
8
+
9
+ def initializeZobristTable(self, size):
10
+ table = []
11
+ for i in range(size):
12
+ table.append([])
13
+ for j in range(size):
14
+ table[i].append({
15
+ 1: random.getrandbits(64), # black
16
+ -1: random.getrandbits(64) # white
17
+ })
18
+ return table
19
+
20
+ def togglePiece(self, x, y, role):
21
+ self.hash ^= self.zobristTable[x][y][role]
22
+
23
+ def getHash(self):
24
+ return self.hash
25
+
26
+ if __name__ == '__main__':
27
+ # Example usage
28
+ size = 8
29
+ cache = ZobristCache(size)
30
+ x = 3
31
+ y = 4
32
+ role = 1
33
+ cache.togglePiece(x, y, role)
34
+ hash_value = cache.getHash()
35
+ print(hash_value)
ai.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from transformers import GPT2LMHeadModel
4
+
5
+
6
+ def load_model(model_name: str = "snoop2head/Gomoku-GPT2") -> GPT2LMHeadModel:
7
+ gpt2 = GPT2LMHeadModel.from_pretrained(model_name)
8
+ return gpt2
9
+
10
+
11
+ BOS_TOKEN_ID = 401
12
+ PAD_TOKEN_ID = 402
13
+ EOS_TOKEN_ID = 403
14
+
15
+
16
+ def generate_gpt2(model: GPT2LMHeadModel, input_ids: torch.LongTensor,temperature = 0.7) -> list:
17
+ """
18
+ input_ids: [batch_size, seq_len] torch.LongTensor
19
+ output_ids: [seq_len] list
20
+ """
21
+ output_ids = model.generate(
22
+ input_ids,
23
+ max_length=128,
24
+ num_beams=5,
25
+ temperature= temperature,
26
+ pad_token_id=PAD_TOKEN_ID,
27
+ eos_token_id=EOS_TOKEN_ID,
28
+ )
29
+ return output_ids.squeeze().tolist()
30
+
31
+
32
+ def change_to_1d_coordinate(board: np.ndarray, x: int, y: int) -> int:
33
+ """change 2d coordinate to 1d coordinate"""
34
+ return x * board.shape[1] + y
35
+
36
+
37
+ def change_to_2d_coordinate(board: np.ndarray, coordinate: int) -> tuple:
38
+ """change 1d coordinate to 2d coordinate"""
39
+ return (coordinate // board.shape[1], coordinate % board.shape[1])
const.py CHANGED
@@ -14,10 +14,13 @@ _BOARD_SIZE_1D = _BOARD_SIZE * _BOARD_SIZE
14
  _BLANK = 0
15
  _BLACK = 1
16
  _WHITE = 2
 
17
  _PLAYER_SYMBOL = {
18
  _WHITE: "⚪",
19
  _BLANK: "➕",
20
  _BLACK: "⚫",
 
 
21
  }
22
  _PLAYER_COLOR = {
23
  _WHITE: "AI",
 
14
  _BLANK = 0
15
  _BLACK = 1
16
  _WHITE = 2
17
+ _NEW = 3
18
  _PLAYER_SYMBOL = {
19
  _WHITE: "⚪",
20
  _BLANK: "➕",
21
  _BLACK: "⚫",
22
+ _NEW: "🔴",
23
+
24
  }
25
  _PLAYER_COLOR = {
26
  _WHITE: "AI",
pages/AI_VS_AI.py ADDED
@@ -0,0 +1,575 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FileName: app.py
3
+ Author: Benhao Huang
4
+ Create Date: 2023/11/19
5
+ Description: this file is used to display our project and add visualization elements to the game, using Streamlit
6
+ """
7
+
8
+ import time
9
+ import pandas as pd
10
+ from copy import deepcopy
11
+ import torch
12
+
13
+ # import torch
14
+ import numpy as np
15
+ import streamlit as st
16
+ from scipy.signal import convolve # this is used to check if any player wins
17
+ from streamlit import session_state
18
+ from streamlit_server_state import server_state, server_state_lock
19
+ from Gomoku_MCTS import MCTSpure, alphazero, Board, PolicyValueNet
20
+ from Gomoku_Bot import Gomoku_bot
21
+ from Gomoku_Bot import Board as Gomoku_bot_board
22
+ import matplotlib.pyplot as plt
23
+
24
+
25
+
26
+ from const import (
27
+ _BLACK, # 1, for human
28
+ _WHITE, # 2 , for AI
29
+ _BLANK,
30
+ _PLAYER_COLOR,
31
+ _PLAYER_SYMBOL,
32
+ _ROOM_COLOR,
33
+ _VERTICAL,
34
+ _NEW,
35
+ _HORIZONTAL,
36
+ _DIAGONAL_UP_LEFT,
37
+ _DIAGONAL_UP_RIGHT,
38
+ _BOARD_SIZE,
39
+ _BOARD_SIZE_1D,
40
+ _AI_AID_INFO
41
+ )
42
+
43
+
44
+ from ai import (
45
+ BOS_TOKEN_ID,
46
+ generate_gpt2,
47
+ load_model,
48
+ )
49
+
50
+ gpt2 = load_model()
51
+
52
+
53
+ # Utils
54
+ class Room:
55
+ def __init__(self, room_id) -> None:
56
+ self.ROOM_ID = room_id
57
+ # self.BOARD = np.zeros(shape=(_BOARD_SIZE, _BOARD_SIZE), dtype=int)
58
+ self.BOARD = Board(width=_BOARD_SIZE, height=_BOARD_SIZE, n_in_row=5, players=[_BLACK, _WHITE])
59
+ self.PLAYER = _BLACK
60
+ self.TURN = self.PLAYER
61
+ self.HISTORY = (0, 0)
62
+ self.WINNER = _BLANK
63
+ self.TIME = time.time()
64
+ self.gomoku_bot_board = Gomoku_bot_board(_BOARD_SIZE, 1)
65
+ self.MCTS_dict = {'Pure MCTS': MCTSpure(c_puct=5, n_playout=1000),
66
+ 'AlphaZero': alphazero(PolicyValueNet(_BOARD_SIZE, _BOARD_SIZE, 'Gomoku_MCTS/checkpoints/best_policy_8_8_5_2torch.pth').policy_value_fn, c_puct=5, n_playout=100),
67
+ 'Gomoku Bot': Gomoku_bot(self.gomoku_bot_board, -1)}
68
+ self.MCTS = self.MCTS_dict['AlphaZero']
69
+ self.last_mcts = self.MCTS
70
+ self.AID_MCTS = self.MCTS_dict['AlphaZero']
71
+ self.COORDINATE_1D = [BOS_TOKEN_ID]
72
+ self.current_move = -1
73
+ self.simula_time_list = []
74
+
75
+
76
+ def change_turn(cur):
77
+ return cur % 2 + 1
78
+
79
+
80
+ # Initialize the game
81
+ if "ROOM" not in session_state:
82
+ session_state.ROOM = Room("local")
83
+ if "OWNER" not in session_state:
84
+ session_state.OWNER = False
85
+ if "USE_AIAID" not in session_state:
86
+ session_state.USE_AIAID = False
87
+
88
+ # Check server health
89
+ if "ROOMS" not in server_state:
90
+ with server_state_lock["ROOMS"]:
91
+ server_state.ROOMS = {}
92
+
93
+ def handle_oppo_model_selection():
94
+ if st.session_state['selected_oppo_model'] == 'Gomoku Bot':
95
+ session_state.ROOM.last_mcts = session_state.ROOM.MCTS # since use different mechanism, store previous mcts first
96
+ session_state.ROOM.MCTS = session_state.ROOM.MCTS_dict['Gomoku Bot']
97
+ return
98
+ else:
99
+ TreeNode = session_state.ROOM.last_mcts.mcts._root
100
+ new_mct = session_state.ROOM.MCTS_dict[st.session_state['selected_oppo_model']]
101
+ new_mct.mcts._root = deepcopy(TreeNode)
102
+ session_state.ROOM.MCTS = new_mct
103
+ session_state.ROOM.last_mcts = new_mct
104
+ return
105
+
106
+ def handle_aid_model_selection():
107
+ if st.session_state['selected_aid_model'] == 'None':
108
+ session_state.USE_AIAID = False
109
+ return
110
+ session_state.USE_AIAID = True
111
+ TreeNode = session_state.ROOM.MCTS.mcts._root # use the same tree node
112
+ new_mct = session_state.ROOM.MCTS_dict[st.session_state['selected_aid_model']]
113
+ new_mct.mcts._root = deepcopy(TreeNode)
114
+ session_state.ROOM.AID_MCTS = new_mct
115
+ return
116
+
117
+ if 'selected_oppo_model' not in st.session_state:
118
+ st.session_state['selected_oppo_model'] = 'AlphaZero' # 默认值
119
+
120
+ if 'selected_aid_model' not in st.session_state:
121
+ st.session_state['selected_aid_model'] = 'AlphaZero' # 默认值
122
+
123
+ # Layout
124
+ TITLE = st.empty()
125
+ Model_Switch = st.empty()
126
+
127
+ TITLE.header("🤖 AI 3603 Gomoku")
128
+ selected_oppo_option = Model_Switch.selectbox('Select Opponent Model', ['Pure MCTS', 'AlphaZero','Gomoku Bot'], index=1, key='oppo_model')
129
+
130
+ if st.session_state['selected_oppo_model'] != selected_oppo_option:
131
+ st.session_state['selected_oppo_model'] = selected_oppo_option
132
+ handle_oppo_model_selection()
133
+
134
+ ROUND_INFO = st.empty()
135
+ st.markdown("<br>", unsafe_allow_html=True)
136
+ BOARD_PLATE = [
137
+ [cell.empty() for cell in st.columns([1 for _ in range(_BOARD_SIZE)])] for _ in range(_BOARD_SIZE)
138
+ ]
139
+ LOG = st.empty()
140
+
141
+ # Sidebar
142
+ SCORE_TAG = st.sidebar.empty()
143
+ SCORE_PLATE = st.sidebar.columns(2)
144
+ # History scores
145
+ SCORE_TAG.subheader("Scores")
146
+
147
+ PLAY_MODE_INFO = st.sidebar.container()
148
+ MULTIPLAYER_TAG = st.sidebar.empty()
149
+ with st.sidebar.container():
150
+ ANOTHER_ROUND = st.empty()
151
+ RESTART = st.empty()
152
+ AIAID = st.empty()
153
+ EXIT = st.empty()
154
+ selected_aid_option = AIAID.selectbox('Select Assistant Model', ['None', 'Pure MCTS', 'AlphaZero'], index=0, key='aid_model')
155
+ if st.session_state['selected_aid_model'] != selected_aid_option:
156
+ st.session_state['selected_aid_model'] = selected_aid_option
157
+ handle_aid_model_selection()
158
+
159
+ GAME_INFO = st.sidebar.container()
160
+ message = st.empty()
161
+ PLAY_MODE_INFO.write("---\n\n**You are Black, AI agent is White.**")
162
+ GAME_INFO.markdown(
163
+ """
164
+ ---
165
+ # <span style="color:black;">Freestyle Gomoku game. 🎲</span>
166
+ - no restrictions 🚫
167
+ - no regrets 😎
168
+ - no regrets 😎
169
+ - swap players after one round is over 🔁
170
+ Powered by an AlphaZero approach with our own improvements! 🚀 For the specific details, please check out our <a href="insert_report_link_here" style="color:blue;">report</a>.
171
+ ##### Adapted and improved by us! 🌟 <a href="https://github.com/Lijiaxin0111/AI_3603_BIGHOME" style="color:blue;">Our Github repo</a>
172
+ """,
173
+ unsafe_allow_html=True,
174
+ )
175
+
176
+
177
+
178
+ def restart() -> None:
179
+ """
180
+ Restart the game.
181
+ """
182
+ session_state.ROOM = Room(session_state.ROOM.ROOM_ID)
183
+ st.session_state['selected_oppo_model'] = 'AlphaZero'
184
+
185
+ RESTART.button(
186
+ "Reset",
187
+ on_click=restart,
188
+ help="Clear the board as well as the scores",
189
+ )
190
+
191
+
192
+ # Draw the board
193
+ def gomoku():
194
+ """
195
+ Draw the board.
196
+ Handle the main logic.
197
+ """
198
+
199
+ # Restart the game
200
+
201
+ # Continue new round
202
+ def another_round() -> None:
203
+ """
204
+ Continue new round.
205
+ """
206
+ session_state.ROOM = deepcopy(session_state.ROOM)
207
+ session_state.ROOM.BOARD = Board(width=_BOARD_SIZE, height=_BOARD_SIZE, n_in_row=5)
208
+ session_state.ROOM.gomoku_bot_board = Gomoku_bot_board(_BOARD_SIZE, 1)
209
+ session_state.ROOM.MCTS_dict = {'Pure MCTS': MCTSpure(c_puct=5, n_playout=1000),
210
+ 'AlphaZero': alphazero(PolicyValueNet(_BOARD_SIZE, _BOARD_SIZE, 'Gomoku_MCTS/checkpoints/best_policy_8_8_5_2torch.pth').policy_value_fn, c_puct=5, n_playout=100),
211
+ 'Gomoku Bot': Gomoku_bot(session_state.ROOM.gomoku_bot_board, -1)}
212
+ session_state.ROOM.MCTS = session_state.ROOM.MCTS_dict[st.session_state['selected_oppo_model']]
213
+ session_state.ROOM.last_mcts = session_state.ROOM.MCTS
214
+ session_state.ROOM.PLAYER = session_state.ROOM.PLAYER
215
+ session_state.ROOM.TURN = session_state.ROOM.PLAYER
216
+ session_state.ROOM.WINNER = _BLANK # 0
217
+ session_state.ROOM.COORDINATE_1D = [BOS_TOKEN_ID]
218
+
219
+ # Room status sync
220
+ def sync_room() -> bool:
221
+ room_id = session_state.ROOM.ROOM_ID
222
+ if room_id not in server_state.ROOMS.keys():
223
+ session_state.ROOM = Room("local")
224
+ return False
225
+ elif server_state.ROOMS[room_id].TIME == session_state.ROOM.TIME:
226
+ return False
227
+ elif server_state.ROOMS[room_id].TIME < session_state.ROOM.TIME:
228
+ # Only acquire the lock when writing to the server state
229
+ with server_state_lock["ROOMS"]:
230
+ server_rooms = server_state.ROOMS
231
+ server_rooms[room_id] = session_state.ROOM
232
+ server_state.ROOMS = server_rooms
233
+ return True
234
+ else:
235
+ session_state.ROOM = server_state.ROOMS[room_id]
236
+ return True
237
+
238
+ # Check if winner emerge from move
239
+ def check_win() -> int:
240
+ """
241
+ Use convolution to check if any player wins.
242
+ """
243
+ vertical = convolve(
244
+ session_state.ROOM.BOARD.board_map,
245
+ _VERTICAL,
246
+ mode="same",
247
+ )
248
+ horizontal = convolve(
249
+ session_state.ROOM.BOARD.board_map,
250
+ _HORIZONTAL,
251
+ mode="same",
252
+ )
253
+ diagonal_up_left = convolve(
254
+ session_state.ROOM.BOARD.board_map,
255
+ _DIAGONAL_UP_LEFT,
256
+ mode="same",
257
+ )
258
+ diagonal_up_right = convolve(
259
+ session_state.ROOM.BOARD.board_map,
260
+ _DIAGONAL_UP_RIGHT,
261
+ mode="same",
262
+ )
263
+ if (
264
+ np.max(
265
+ [
266
+ np.max(vertical),
267
+ np.max(horizontal),
268
+ np.max(diagonal_up_left),
269
+ np.max(diagonal_up_right),
270
+ ]
271
+ )
272
+ == 5 * _BLACK
273
+ ):
274
+ winner = _BLACK
275
+ elif (
276
+ np.min(
277
+ [
278
+ np.min(vertical),
279
+ np.min(horizontal),
280
+ np.min(diagonal_up_left),
281
+ np.min(diagonal_up_right),
282
+ ]
283
+ )
284
+ == 5 * _WHITE
285
+ ):
286
+ winner = _WHITE
287
+ else:
288
+ winner = _BLANK
289
+ return winner
290
+
291
+ # Triggers the board response on click
292
+ def handle_click(x, y):
293
+ """
294
+ Controls whether to pass on / continue current board / may start new round
295
+ """
296
+ if session_state.ROOM.BOARD.board_map[x][y] != _BLANK:
297
+ pass
298
+ elif (
299
+ session_state.ROOM.ROOM_ID in server_state.ROOMS.keys()
300
+ and _ROOM_COLOR[session_state.OWNER]
301
+ != server_state.ROOMS[session_state.ROOM.ROOM_ID].TURN
302
+ ):
303
+ sync_room()
304
+
305
+ # normal play situation
306
+ elif session_state.ROOM.WINNER == _BLANK:
307
+ # session_state.ROOM = deepcopy(session_state.ROOM)
308
+ # print("View of human player: ", session_state.ROOM.BOARD.board_map)
309
+ move = session_state.ROOM.BOARD.location_to_move((x, y))
310
+ session_state.ROOM.current_move = move
311
+ session_state.ROOM.BOARD.do_move(move)
312
+ # Gomoku Bot BOARD
313
+ session_state.ROOM.MCTS_dict["Gomoku Bot"].board.put(move // _BOARD_SIZE, move % _BOARD_SIZE)
314
+ session_state.ROOM.BOARD.board_map[x][y] = session_state.ROOM.TURN
315
+ session_state.ROOM.COORDINATE_1D.append(x * _BOARD_SIZE + y)
316
+
317
+ session_state.ROOM.TURN = change_turn(session_state.ROOM.TURN)
318
+ win, winner = session_state.ROOM.BOARD.game_end()
319
+ if win:
320
+ session_state.ROOM.WINNER = winner
321
+ session_state.ROOM.HISTORY = (
322
+ session_state.ROOM.HISTORY[0]
323
+ + int(session_state.ROOM.WINNER == _WHITE),
324
+ session_state.ROOM.HISTORY[1]
325
+ + int(session_state.ROOM.WINNER == _BLACK),
326
+ )
327
+ session_state.ROOM.TIME = time.time()
328
+
329
+ def forbid_click(x, y):
330
+ # st.warning('This posistion has been occupied!!!!', icon="⚠️")
331
+ st.error("({}, {}) has been occupied!!)".format(x, y), icon="🚨")
332
+
333
+ # Draw board
334
+ def draw_board(response: bool):
335
+ """construct each buttons for all cells of the board"""
336
+ if session_state.USE_AIAID and session_state.ROOM.WINNER == _BLANK and session_state.ROOM.TURN == _BLACK:
337
+ if session_state.USE_AIAID:
338
+ copy_mcts = deepcopy(session_state.ROOM.AID_MCTS.mcts)
339
+ _, acts_aid, probs_aid, simul_mean_time_aid = copy_mcts.get_move_probs(session_state.ROOM.BOARD)
340
+ sorted_acts_probs = sorted(zip(acts_aid, probs_aid), key=lambda x: x[1], reverse=True)
341
+ top_five_acts = [act for act, prob in sorted_acts_probs[:5]]
342
+ top_five_probs = [prob for act, prob in sorted_acts_probs[:5]]
343
+ if response and session_state.ROOM.TURN == _BLACK: # human turn
344
+ print("Your turn")
345
+ # construction of clickable buttons
346
+ cur_move = (session_state.ROOM.current_move // _BOARD_SIZE, session_state.ROOM.current_move % _BOARD_SIZE)
347
+ for i, row in enumerate(session_state.ROOM.BOARD.board_map):
348
+ # print("row:", row)
349
+ for j, cell in enumerate(row):
350
+ if (
351
+ i * _BOARD_SIZE + j
352
+ in (session_state.ROOM.COORDINATE_1D)
353
+ ):
354
+ if i == cur_move[0] and j == cur_move[1]:
355
+ BOARD_PLATE[i][j].button(
356
+ _PLAYER_SYMBOL[_NEW],
357
+ key=f"{i}:{j}",
358
+ args=(i, j),
359
+ on_click=handle_click,
360
+ )
361
+ else:
362
+ # disable click for GPT choices
363
+ BOARD_PLATE[i][j].button(
364
+ _PLAYER_SYMBOL[cell],
365
+ key=f"{i}:{j}",
366
+ args=(i, j),
367
+ on_click=forbid_click
368
+ )
369
+ else:
370
+ if session_state.USE_AIAID and i * _BOARD_SIZE + j in top_five_acts:
371
+ # enable click for other cells available for human choices
372
+ prob = top_five_probs[top_five_acts.index(i * _BOARD_SIZE + j)]
373
+ BOARD_PLATE[i][j].button(
374
+ _PLAYER_SYMBOL[cell] + f"({round(prob, 2)})",
375
+ key=f"{i}:{j}",
376
+ on_click=handle_click,
377
+ args=(i, j),
378
+ )
379
+ else:
380
+ # enable click for other cells available for human choices
381
+ BOARD_PLATE[i][j].button(
382
+ _PLAYER_SYMBOL[cell],
383
+ key=f"{i}:{j}",
384
+ on_click=handle_click,
385
+ args=(i, j),
386
+ )
387
+
388
+
389
+ elif response and session_state.ROOM.TURN == _WHITE: # AI turn
390
+ message.empty()
391
+ with st.spinner('🔮✨ Waiting for AI response... ⏳🚀'):
392
+ time.sleep(0.1)
393
+ print("AI's turn")
394
+ print("Below are current board under AI's view")
395
+ # print(session_state.ROOM.BOARD.board_map)
396
+ # move = _BOARD_SIZE * _BOARD_SIZE
397
+ # forbid = []
398
+ # step = 0.1
399
+ # tmp = 0.7
400
+ # while move >= _BOARD_SIZE * _BOARD_SIZE or move in session_state.ROOM.COORDINATE_1D:
401
+ #
402
+ # gpt_predictions = generate_gpt2(
403
+ # gpt2,
404
+ # torch.tensor(session_state.ROOM.COORDINATE_1D).unsqueeze(0),
405
+ # tmp
406
+ # )
407
+ # print(gpt_predictions)
408
+ # move = gpt_predictions[len(session_state.ROOM.COORDINATE_1D)]
409
+ # print(move)
410
+ # tmp += step
411
+ # # if move >= _BOARD_SIZE * _BOARD_SIZE:
412
+ # # forbid.append(move)
413
+ # # else:
414
+ # # break
415
+ #
416
+ #
417
+ # gpt_response = move
418
+ # gpt_i, gpt_j = gpt_response // _BOARD_SIZE, gpt_response % _BOARD_SIZE
419
+ # print(gpt_i, gpt_j)
420
+ # # session_state.ROOM.BOARD[gpt_i][gpt_j] = session_state.ROOM.TURN
421
+ #
422
+ # simul_time = 0
423
+ if st.session_state['selected_oppo_model'] != 'Gomoku Bot':
424
+ move, simul_time = session_state.ROOM.MCTS.get_action(session_state.ROOM.BOARD, return_time=True)
425
+ else:
426
+ move, simul_time = session_state.ROOM.MCTS.get_action(return_time=True)
427
+ session_state.ROOM.simula_time_list.append(simul_time)
428
+ print("AI takes move: ", move)
429
+ session_state.ROOM.current_move = move
430
+ gpt_response = move
431
+ gpt_i, gpt_j = gpt_response // _BOARD_SIZE, gpt_response % _BOARD_SIZE
432
+ print("AI's move is located at ({}, {}) :".format(gpt_i, gpt_j))
433
+ move = session_state.ROOM.BOARD.location_to_move((gpt_i, gpt_j))
434
+ print("Location to move: ", move)
435
+ # print("Location to move: ", move)
436
+ # MCTS BOARD
437
+ session_state.ROOM.BOARD.do_move(move)
438
+ # Gomoku Bot BOARD
439
+ session_state.ROOM.MCTS_dict["Gomoku Bot"].board.put(move // _BOARD_SIZE, move % _BOARD_SIZE)
440
+ # session_state.ROOM.BOARD[gpt_i][gpt_j] = session_state.ROOM.TURN
441
+ session_state.ROOM.COORDINATE_1D.append(gpt_i * _BOARD_SIZE + gpt_j)
442
+
443
+ if not session_state.ROOM.BOARD.game_end()[0]:
444
+ if session_state.USE_AIAID:
445
+ copy_mcts = deepcopy(session_state.ROOM.AID_MCTS.mcts)
446
+ _, acts_aid, probs_aid, simul_mean_time_aid = copy_mcts.get_move_probs(session_state.ROOM.BOARD)
447
+ sorted_acts_probs = sorted(zip(acts_aid, probs_aid), key=lambda x: x[1], reverse=True)
448
+ top_five_acts = [act for act, prob in sorted_acts_probs[:5]]
449
+ top_five_probs = [prob for act, prob in sorted_acts_probs[:5]]
450
+ else:
451
+ top_five_acts = []
452
+ top_five_probs = []
453
+
454
+ # construction of clickable buttons
455
+ for i, row in enumerate(session_state.ROOM.BOARD.board_map):
456
+ # print("row:", row)
457
+ for j, cell in enumerate(row):
458
+ if (
459
+ i * _BOARD_SIZE + j
460
+ in (session_state.ROOM.COORDINATE_1D)
461
+ ):
462
+ if i == gpt_i and j == gpt_j:
463
+ BOARD_PLATE[i][j].button(
464
+ _PLAYER_SYMBOL[_NEW],
465
+ key=f"{i}:{j}",
466
+ args=(i, j),
467
+ on_click=handle_click,
468
+ )
469
+ else:
470
+ # disable click for GPT choices
471
+ BOARD_PLATE[i][j].button(
472
+ _PLAYER_SYMBOL[cell],
473
+ key=f"{i}:{j}",
474
+ args=(i, j),
475
+ on_click=forbid_click
476
+ )
477
+ else:
478
+ if session_state.USE_AIAID and i * _BOARD_SIZE + j in top_five_acts and not session_state.ROOM.BOARD.game_end()[0]:
479
+ # enable click for other cells available for human choices
480
+ prob = top_five_probs[top_five_acts.index(i * _BOARD_SIZE + j)]
481
+ BOARD_PLATE[i][j].button(
482
+ _PLAYER_SYMBOL[cell] + f"({round(prob, 2)})",
483
+ key=f"{i}:{j}",
484
+ on_click=handle_click,
485
+ args=(i, j),
486
+ )
487
+ else:
488
+ # enable click for other cells available for human choices
489
+ BOARD_PLATE[i][j].button(
490
+ _PLAYER_SYMBOL[cell],
491
+ key=f"{i}:{j}",
492
+ on_click=handle_click,
493
+ args=(i, j),
494
+ )
495
+
496
+
497
+ message.markdown(
498
+ 'AI agent has calculated its strategy, which takes <span style="color: blue; font-size: 20px;">{:.3e}</span>s per simulation.'.format(
499
+ simul_time),
500
+ unsafe_allow_html=True
501
+ )
502
+ LOG.subheader("Logs")
503
+ # change turn
504
+ session_state.ROOM.TURN = change_turn(session_state.ROOM.TURN)
505
+ # session_state.ROOM.WINNER = check_win()
506
+
507
+ win, winner = session_state.ROOM.BOARD.game_end()
508
+ if win:
509
+ session_state.ROOM.WINNER = winner
510
+
511
+ session_state.ROOM.HISTORY = (
512
+ session_state.ROOM.HISTORY[0]
513
+ + int(session_state.ROOM.WINNER == _WHITE),
514
+ session_state.ROOM.HISTORY[1]
515
+ + int(session_state.ROOM.WINNER == _BLACK),
516
+ )
517
+ session_state.ROOM.TIME = time.time()
518
+
519
+ if not response or session_state.ROOM.WINNER != _BLANK:
520
+ if session_state.ROOM.WINNER != _BLANK:
521
+ print("Game over")
522
+ for i, row in enumerate(session_state.ROOM.BOARD.board_map):
523
+ for j, cell in enumerate(row):
524
+ BOARD_PLATE[i][j].write(
525
+ _PLAYER_SYMBOL[cell],
526
+ # key=f"{i}:{j}",
527
+ )
528
+
529
+ # Game process control
530
+ def game_control():
531
+ if session_state.ROOM.WINNER != _BLANK:
532
+ draw_board(False)
533
+ else:
534
+ draw_board(True)
535
+ if session_state.ROOM.WINNER != _BLANK or 0 not in session_state.ROOM.BOARD.board_map:
536
+ ANOTHER_ROUND.button(
537
+ "Play Next round!",
538
+ on_click=another_round,
539
+ help="Clear board and swap first player",
540
+ )
541
+
542
+ # Infos
543
+ def update_info() -> None:
544
+ # Additional information
545
+ SCORE_PLATE[0].metric("Gomoku-Agent", session_state.ROOM.HISTORY[0])
546
+ SCORE_PLATE[1].metric("Black", session_state.ROOM.HISTORY[1])
547
+ if session_state.ROOM.WINNER != _BLANK:
548
+ st.balloons()
549
+ ROUND_INFO.write(
550
+ f"#### **{_PLAYER_COLOR[session_state.ROOM.WINNER]} WIN!**\n**Click buttons on the left for more plays.**"
551
+ )
552
+
553
+ # elif 0 not in session_state.ROOM.BOARD.board_map:
554
+ # ROUND_INFO.write("#### **Tie**")
555
+ # else:
556
+ # ROUND_INFO.write(
557
+ # f"#### **{_PLAYER_SYMBOL[session_state.ROOM.TURN]} {_PLAYER_COLOR[session_state.ROOM.TURN]}'s turn...**"
558
+ # )
559
+
560
+ # draw the plot for simulation time
561
+ # 创建一个 DataFrame
562
+
563
+ # print(session_state.ROOM.simula_time_list)
564
+ st.markdown("<br>", unsafe_allow_html=True)
565
+ st.markdown("<br>", unsafe_allow_html=True)
566
+ chart_data = pd.DataFrame(session_state.ROOM.simula_time_list, columns=["Simulation Time"])
567
+ st.line_chart(chart_data)
568
+
569
+
570
+ game_control()
571
+ update_info()
572
+
573
+
574
+ if __name__ == "__main__":
575
+ gomoku()
pages/Player_VS_AI.py CHANGED
@@ -8,6 +8,7 @@ Description: this file is used to display our project and add visualization elem
8
  import time
9
  import pandas as pd
10
  from copy import deepcopy
 
11
 
12
  # import torch
13
  import numpy as np
@@ -16,8 +17,12 @@ from scipy.signal import convolve # this is used to check if any player wins
16
  from streamlit import session_state
17
  from streamlit_server_state import server_state, server_state_lock
18
  from Gomoku_MCTS import MCTSpure, alphazero, Board, PolicyValueNet
 
 
19
  import matplotlib.pyplot as plt
20
 
 
 
21
  from const import (
22
  _BLACK, # 1, for human
23
  _WHITE, # 2 , for AI
@@ -26,6 +31,7 @@ from const import (
26
  _PLAYER_SYMBOL,
27
  _ROOM_COLOR,
28
  _VERTICAL,
 
29
  _HORIZONTAL,
30
  _DIAGONAL_UP_LEFT,
31
  _DIAGONAL_UP_RIGHT,
@@ -35,6 +41,13 @@ from const import (
35
  )
36
 
37
 
 
 
 
 
 
 
 
38
 
39
 
40
  # Utils
@@ -48,11 +61,14 @@ class Room:
48
  self.HISTORY = (0, 0)
49
  self.WINNER = _BLANK
50
  self.TIME = time.time()
51
- self.MCTS_dict = {'Pure MCTS': MCTSpure(c_puct=5, n_playout=10),
52
- 'AlphaZero': alphazero(PolicyValueNet(_BOARD_SIZE, _BOARD_SIZE, 'Gomoku_MCTS/checkpoints/best_policy_8_8_5_2torch.pth').policy_value_fn, c_puct=5, n_playout=100)}
 
 
53
  self.MCTS = self.MCTS_dict['AlphaZero']
 
54
  self.AID_MCTS = self.MCTS_dict['AlphaZero']
55
- self.COORDINATE_1D = [_BOARD_SIZE_1D + 1]
56
  self.current_move = -1
57
  self.simula_time_list = []
58
 
@@ -75,10 +91,16 @@ if "ROOMS" not in server_state:
75
  server_state.ROOMS = {}
76
 
77
  def handle_oppo_model_selection():
78
- TreeNode = session_state.ROOM.MCTS.mcts._root
79
- new_mct = session_state.ROOM.MCTS_dict[st.session_state['selected_oppo_model']]
80
- new_mct.mcts._root = deepcopy(TreeNode)
81
- session_state.ROOM.MCTS = new_mct
 
 
 
 
 
 
82
  return
83
 
84
  def handle_aid_model_selection():
@@ -103,7 +125,7 @@ TITLE = st.empty()
103
  Model_Switch = st.empty()
104
 
105
  TITLE.header("🤖 AI 3603 Gomoku")
106
- selected_oppo_option = Model_Switch.selectbox('Select Opponent Model', ['Pure MCTS', 'AlphaZero'], index=1, key='oppo_model')
107
 
108
  if st.session_state['selected_oppo_model'] != selected_oppo_option:
109
  st.session_state['selected_oppo_model'] = selected_oppo_option
@@ -158,7 +180,7 @@ def restart() -> None:
158
  Restart the game.
159
  """
160
  session_state.ROOM = Room(session_state.ROOM.ROOM_ID)
161
-
162
 
163
  RESTART.button(
164
  "Reset",
@@ -183,10 +205,16 @@ def gomoku():
183
  """
184
  session_state.ROOM = deepcopy(session_state.ROOM)
185
  session_state.ROOM.BOARD = Board(width=_BOARD_SIZE, height=_BOARD_SIZE, n_in_row=5)
 
 
 
 
 
 
186
  session_state.ROOM.PLAYER = session_state.ROOM.PLAYER
187
  session_state.ROOM.TURN = session_state.ROOM.PLAYER
188
  session_state.ROOM.WINNER = _BLANK # 0
189
- session_state.ROOM.COORDINATE_1D = [_BOARD_SIZE_1D + 1]
190
 
191
  # Room status sync
192
  def sync_room() -> bool:
@@ -281,6 +309,8 @@ def gomoku():
281
  move = session_state.ROOM.BOARD.location_to_move((x, y))
282
  session_state.ROOM.current_move = move
283
  session_state.ROOM.BOARD.do_move(move)
 
 
284
  session_state.ROOM.BOARD.board_map[x][y] = session_state.ROOM.TURN
285
  session_state.ROOM.COORDINATE_1D.append(x * _BOARD_SIZE + y)
286
 
@@ -299,7 +329,6 @@ def gomoku():
299
  def forbid_click(x, y):
300
  # st.warning('This posistion has been occupied!!!!', icon="⚠️")
301
  st.error("({}, {}) has been occupied!!)".format(x, y), icon="🚨")
302
- print("asdas")
303
 
304
  # Draw board
305
  def draw_board(response: bool):
@@ -314,6 +343,7 @@ def gomoku():
314
  if response and session_state.ROOM.TURN == _BLACK: # human turn
315
  print("Your turn")
316
  # construction of clickable buttons
 
317
  for i, row in enumerate(session_state.ROOM.BOARD.board_map):
318
  # print("row:", row)
319
  for j, cell in enumerate(row):
@@ -321,13 +351,21 @@ def gomoku():
321
  i * _BOARD_SIZE + j
322
  in (session_state.ROOM.COORDINATE_1D)
323
  ):
324
- # disable click for GPT choices
325
- BOARD_PLATE[i][j].button(
326
- _PLAYER_SYMBOL[cell],
327
- key=f"{i}:{j}",
328
- args=(i, j),
329
- on_click=forbid_click
330
- )
 
 
 
 
 
 
 
 
331
  else:
332
  if session_state.USE_AIAID and i * _BOARD_SIZE + j in top_five_acts:
333
  # enable click for other cells available for human choices
@@ -355,7 +393,37 @@ def gomoku():
355
  print("AI's turn")
356
  print("Below are current board under AI's view")
357
  # print(session_state.ROOM.BOARD.board_map)
358
- move, simul_time = session_state.ROOM.MCTS.get_action(session_state.ROOM.BOARD, return_time=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
359
  session_state.ROOM.simula_time_list.append(simul_time)
360
  print("AI takes move: ", move)
361
  session_state.ROOM.current_move = move
@@ -364,7 +432,11 @@ def gomoku():
364
  print("AI's move is located at ({}, {}) :".format(gpt_i, gpt_j))
365
  move = session_state.ROOM.BOARD.location_to_move((gpt_i, gpt_j))
366
  print("Location to move: ", move)
 
 
367
  session_state.ROOM.BOARD.do_move(move)
 
 
368
  # session_state.ROOM.BOARD[gpt_i][gpt_j] = session_state.ROOM.TURN
369
  session_state.ROOM.COORDINATE_1D.append(gpt_i * _BOARD_SIZE + gpt_j)
370
 
@@ -387,13 +459,21 @@ def gomoku():
387
  i * _BOARD_SIZE + j
388
  in (session_state.ROOM.COORDINATE_1D)
389
  ):
390
- # disable click for GPT choices
391
- BOARD_PLATE[i][j].button(
392
- _PLAYER_SYMBOL[cell],
393
- key=f"{i}:{j}",
394
- args=(i, j),
395
- on_click=forbid_click
396
- )
 
 
 
 
 
 
 
 
397
  else:
398
  if session_state.USE_AIAID and i * _BOARD_SIZE + j in top_five_acts and not session_state.ROOM.BOARD.game_end()[0]:
399
  # enable click for other cells available for human choices
@@ -480,7 +560,7 @@ def gomoku():
480
  # draw the plot for simulation time
481
  # 创建一个 DataFrame
482
 
483
- print(session_state.ROOM.simula_time_list)
484
  st.markdown("<br>", unsafe_allow_html=True)
485
  st.markdown("<br>", unsafe_allow_html=True)
486
  chart_data = pd.DataFrame(session_state.ROOM.simula_time_list, columns=["Simulation Time"])
 
8
  import time
9
  import pandas as pd
10
  from copy import deepcopy
11
+ import torch
12
 
13
  # import torch
14
  import numpy as np
 
17
  from streamlit import session_state
18
  from streamlit_server_state import server_state, server_state_lock
19
  from Gomoku_MCTS import MCTSpure, alphazero, Board, PolicyValueNet
20
+ from Gomoku_Bot import Gomoku_bot
21
+ from Gomoku_Bot import Board as Gomoku_bot_board
22
  import matplotlib.pyplot as plt
23
 
24
+
25
+
26
  from const import (
27
  _BLACK, # 1, for human
28
  _WHITE, # 2 , for AI
 
31
  _PLAYER_SYMBOL,
32
  _ROOM_COLOR,
33
  _VERTICAL,
34
+ _NEW,
35
  _HORIZONTAL,
36
  _DIAGONAL_UP_LEFT,
37
  _DIAGONAL_UP_RIGHT,
 
41
  )
42
 
43
 
44
+ from ai import (
45
+ BOS_TOKEN_ID,
46
+ generate_gpt2,
47
+ load_model,
48
+ )
49
+
50
+ gpt2 = load_model()
51
 
52
 
53
  # Utils
 
61
  self.HISTORY = (0, 0)
62
  self.WINNER = _BLANK
63
  self.TIME = time.time()
64
+ self.gomoku_bot_board = Gomoku_bot_board(_BOARD_SIZE, 1)
65
+ self.MCTS_dict = {'Pure MCTS': MCTSpure(c_puct=5, n_playout=1000),
66
+ 'AlphaZero': alphazero(PolicyValueNet(_BOARD_SIZE, _BOARD_SIZE, 'Gomoku_MCTS/checkpoints/best_policy_8_8_5_2torch.pth').policy_value_fn, c_puct=5, n_playout=100),
67
+ 'Gomoku Bot': Gomoku_bot(self.gomoku_bot_board, -1)}
68
  self.MCTS = self.MCTS_dict['AlphaZero']
69
+ self.last_mcts = self.MCTS
70
  self.AID_MCTS = self.MCTS_dict['AlphaZero']
71
+ self.COORDINATE_1D = [BOS_TOKEN_ID]
72
  self.current_move = -1
73
  self.simula_time_list = []
74
 
 
91
  server_state.ROOMS = {}
92
 
93
  def handle_oppo_model_selection():
94
+ if st.session_state['selected_oppo_model'] == 'Gomoku Bot':
95
+ session_state.ROOM.last_mcts = session_state.ROOM.MCTS # since use different mechanism, store previous mcts first
96
+ session_state.ROOM.MCTS = session_state.ROOM.MCTS_dict['Gomoku Bot']
97
+ return
98
+ else:
99
+ TreeNode = session_state.ROOM.last_mcts.mcts._root
100
+ new_mct = session_state.ROOM.MCTS_dict[st.session_state['selected_oppo_model']]
101
+ new_mct.mcts._root = deepcopy(TreeNode)
102
+ session_state.ROOM.MCTS = new_mct
103
+ session_state.ROOM.last_mcts = new_mct
104
  return
105
 
106
  def handle_aid_model_selection():
 
125
  Model_Switch = st.empty()
126
 
127
  TITLE.header("🤖 AI 3603 Gomoku")
128
+ selected_oppo_option = Model_Switch.selectbox('Select Opponent Model', ['Pure MCTS', 'AlphaZero','Gomoku Bot'], index=1, key='oppo_model')
129
 
130
  if st.session_state['selected_oppo_model'] != selected_oppo_option:
131
  st.session_state['selected_oppo_model'] = selected_oppo_option
 
180
  Restart the game.
181
  """
182
  session_state.ROOM = Room(session_state.ROOM.ROOM_ID)
183
+ st.session_state['selected_oppo_model'] = 'AlphaZero'
184
 
185
  RESTART.button(
186
  "Reset",
 
205
  """
206
  session_state.ROOM = deepcopy(session_state.ROOM)
207
  session_state.ROOM.BOARD = Board(width=_BOARD_SIZE, height=_BOARD_SIZE, n_in_row=5)
208
+ session_state.ROOM.gomoku_bot_board = Gomoku_bot_board(_BOARD_SIZE, 1)
209
+ session_state.ROOM.MCTS_dict = {'Pure MCTS': MCTSpure(c_puct=5, n_playout=1000),
210
+ 'AlphaZero': alphazero(PolicyValueNet(_BOARD_SIZE, _BOARD_SIZE, 'Gomoku_MCTS/checkpoints/best_policy_8_8_5_2torch.pth').policy_value_fn, c_puct=5, n_playout=100),
211
+ 'Gomoku Bot': Gomoku_bot(session_state.ROOM.gomoku_bot_board, -1)}
212
+ session_state.ROOM.MCTS = session_state.ROOM.MCTS_dict[st.session_state['selected_oppo_model']]
213
+ session_state.ROOM.last_mcts = session_state.ROOM.MCTS
214
  session_state.ROOM.PLAYER = session_state.ROOM.PLAYER
215
  session_state.ROOM.TURN = session_state.ROOM.PLAYER
216
  session_state.ROOM.WINNER = _BLANK # 0
217
+ session_state.ROOM.COORDINATE_1D = [BOS_TOKEN_ID]
218
 
219
  # Room status sync
220
  def sync_room() -> bool:
 
309
  move = session_state.ROOM.BOARD.location_to_move((x, y))
310
  session_state.ROOM.current_move = move
311
  session_state.ROOM.BOARD.do_move(move)
312
+ # Gomoku Bot BOARD
313
+ session_state.ROOM.MCTS_dict["Gomoku Bot"].board.put(move // _BOARD_SIZE, move % _BOARD_SIZE)
314
  session_state.ROOM.BOARD.board_map[x][y] = session_state.ROOM.TURN
315
  session_state.ROOM.COORDINATE_1D.append(x * _BOARD_SIZE + y)
316
 
 
329
  def forbid_click(x, y):
330
  # st.warning('This posistion has been occupied!!!!', icon="⚠️")
331
  st.error("({}, {}) has been occupied!!)".format(x, y), icon="🚨")
 
332
 
333
  # Draw board
334
  def draw_board(response: bool):
 
343
  if response and session_state.ROOM.TURN == _BLACK: # human turn
344
  print("Your turn")
345
  # construction of clickable buttons
346
+ cur_move = (session_state.ROOM.current_move // _BOARD_SIZE, session_state.ROOM.current_move % _BOARD_SIZE)
347
  for i, row in enumerate(session_state.ROOM.BOARD.board_map):
348
  # print("row:", row)
349
  for j, cell in enumerate(row):
 
351
  i * _BOARD_SIZE + j
352
  in (session_state.ROOM.COORDINATE_1D)
353
  ):
354
+ if i == cur_move[0] and j == cur_move[1]:
355
+ BOARD_PLATE[i][j].button(
356
+ _PLAYER_SYMBOL[_NEW],
357
+ key=f"{i}:{j}",
358
+ args=(i, j),
359
+ on_click=handle_click,
360
+ )
361
+ else:
362
+ # disable click for GPT choices
363
+ BOARD_PLATE[i][j].button(
364
+ _PLAYER_SYMBOL[cell],
365
+ key=f"{i}:{j}",
366
+ args=(i, j),
367
+ on_click=forbid_click
368
+ )
369
  else:
370
  if session_state.USE_AIAID and i * _BOARD_SIZE + j in top_five_acts:
371
  # enable click for other cells available for human choices
 
393
  print("AI's turn")
394
  print("Below are current board under AI's view")
395
  # print(session_state.ROOM.BOARD.board_map)
396
+ # move = _BOARD_SIZE * _BOARD_SIZE
397
+ # forbid = []
398
+ # step = 0.1
399
+ # tmp = 0.7
400
+ # while move >= _BOARD_SIZE * _BOARD_SIZE or move in session_state.ROOM.COORDINATE_1D:
401
+ #
402
+ # gpt_predictions = generate_gpt2(
403
+ # gpt2,
404
+ # torch.tensor(session_state.ROOM.COORDINATE_1D).unsqueeze(0),
405
+ # tmp
406
+ # )
407
+ # print(gpt_predictions)
408
+ # move = gpt_predictions[len(session_state.ROOM.COORDINATE_1D)]
409
+ # print(move)
410
+ # tmp += step
411
+ # # if move >= _BOARD_SIZE * _BOARD_SIZE:
412
+ # # forbid.append(move)
413
+ # # else:
414
+ # # break
415
+ #
416
+ #
417
+ # gpt_response = move
418
+ # gpt_i, gpt_j = gpt_response // _BOARD_SIZE, gpt_response % _BOARD_SIZE
419
+ # print(gpt_i, gpt_j)
420
+ # # session_state.ROOM.BOARD[gpt_i][gpt_j] = session_state.ROOM.TURN
421
+ #
422
+ # simul_time = 0
423
+ if st.session_state['selected_oppo_model'] != 'Gomoku Bot':
424
+ move, simul_time = session_state.ROOM.MCTS.get_action(session_state.ROOM.BOARD, return_time=True)
425
+ else:
426
+ move, simul_time = session_state.ROOM.MCTS.get_action(return_time=True)
427
  session_state.ROOM.simula_time_list.append(simul_time)
428
  print("AI takes move: ", move)
429
  session_state.ROOM.current_move = move
 
432
  print("AI's move is located at ({}, {}) :".format(gpt_i, gpt_j))
433
  move = session_state.ROOM.BOARD.location_to_move((gpt_i, gpt_j))
434
  print("Location to move: ", move)
435
+ # print("Location to move: ", move)
436
+ # MCTS BOARD
437
  session_state.ROOM.BOARD.do_move(move)
438
+ # Gomoku Bot BOARD
439
+ session_state.ROOM.MCTS_dict["Gomoku Bot"].board.put(move // _BOARD_SIZE, move % _BOARD_SIZE)
440
  # session_state.ROOM.BOARD[gpt_i][gpt_j] = session_state.ROOM.TURN
441
  session_state.ROOM.COORDINATE_1D.append(gpt_i * _BOARD_SIZE + gpt_j)
442
 
 
459
  i * _BOARD_SIZE + j
460
  in (session_state.ROOM.COORDINATE_1D)
461
  ):
462
+ if i == gpt_i and j == gpt_j:
463
+ BOARD_PLATE[i][j].button(
464
+ _PLAYER_SYMBOL[_NEW],
465
+ key=f"{i}:{j}",
466
+ args=(i, j),
467
+ on_click=handle_click,
468
+ )
469
+ else:
470
+ # disable click for GPT choices
471
+ BOARD_PLATE[i][j].button(
472
+ _PLAYER_SYMBOL[cell],
473
+ key=f"{i}:{j}",
474
+ args=(i, j),
475
+ on_click=forbid_click
476
+ )
477
  else:
478
  if session_state.USE_AIAID and i * _BOARD_SIZE + j in top_five_acts and not session_state.ROOM.BOARD.game_end()[0]:
479
  # enable click for other cells available for human choices
 
560
  # draw the plot for simulation time
561
  # 创建一个 DataFrame
562
 
563
+ # print(session_state.ROOM.simula_time_list)
564
  st.markdown("<br>", unsafe_allow_html=True)
565
  st.markdown("<br>", unsafe_allow_html=True)
566
  chart_data = pd.DataFrame(session_state.ROOM.simula_time_list, columns=["Simulation Time"])
try.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
2
+
3
+ model_name_or_path = "TheBloke/Llama-2-7b-Chat-GPTQ"
4
+ # To use a different branch, change revision
5
+ # For example: revision="gptq-4bit-64g-actorder_True"
6
+ model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
7
+ device_map="auto",
8
+ trust_remote_code=False,
9
+ revision="main")
10
+
11
+ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
12
+
13
+ prompt = "Tell me about AI"
14
+ prompt_template=f'''[INST] <<SYS>>
15
+ You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
16
+ <</SYS>>
17
+ {prompt}[/INST]
18
+
19
+ '''
20
+
21
+ print("\n\n*** Generate:")
22
+
23
+ input_ids = tokenizer(prompt_template, return_tensors='pt').input_ids.cuda()
24
+ output = model.generate(inputs=input_ids, temperature=0.7, do_sample=True, top_p=0.95, top_k=40, max_new_tokens=512)
25
+ print(tokenizer.decode(output[0]))
26
+
27
+ # Inference can also be done using transformers' pipeline
28
+
29
+ print("*** Pipeline:")
30
+ pipe = pipeline(
31
+ "text-generation",
32
+ model=model,
33
+ tokenizer=tokenizer,
34
+ max_new_tokens=512,
35
+ do_sample=True,
36
+ temperature=0.7,
37
+ top_p=0.95,
38
+ top_k=40,
39
+ repetition_penalty=1.1
40
+ )
41
+
42
+ print(pipe(prompt_template)[0]['generated_text'])