File size: 4,670 Bytes
7d23b62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
import os

dir_path = os.path.dirname(os.path.realpath(__file__))
# from tensorboardX import SummaryWriter
from tqdm import tqdm
import datetime
from torch.utils.data import DataLoader, TensorDataset

date = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')


class BoardEvaluationNet(nn.Module):
    """Small CNN that produces one output value per cell of a square board.

    Two 3x3 convolutions (1 -> 16 -> 32 channels, stride 1, padding 1, so the
    spatial size is preserved) feed two fully connected layers
    (32 * board_size**2 -> 256 -> board_size**2). ReLU follows every layer
    except the last; no activation is applied to the final output.
    """

    def __init__(self, board_size):
        """Build the layers for a board with ``board_size`` cells per side."""
        super().__init__()
        self.board_size = board_size
        cells = board_size * board_size

        # Convolutional feature extractor; padding=1 keeps height and width.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        # Dense head mapping the flattened feature maps back to one value per cell.
        self.fc1 = nn.Linear(32 * cells, 256)
        self.fc2 = nn.Linear(256, cells)

    def forward(self, x):
        """Evaluate a batch of boards.

        Args:
            x: Board tensor without a channel axis, expected to be shaped
                (batch, board_size, board_size); a single channel axis is
                inserted at dim 1 before the convolutions.

        Returns:
            Tensor viewed as (-1, board_size, board_size).
        """
        side = self.board_size
        feature_maps = F.relu(self.conv1(x.unsqueeze(1)))
        feature_maps = F.relu(self.conv2(feature_maps))
        flattened = feature_maps.view(-1, 32 * side * side)
        hidden = F.relu(self.fc1(flattened))
        return self.fc2(hidden).view(-1, side, side)


def normalize(t):
    """Return ``t`` unchanged.

    Identity placeholder: the training script passes its score tensors
    through this function, but no scaling or shifting is applied here.
    """
    return t


if __name__ == "__main__":
    # Script entry point: fine-tune the black-side BoardEvaluationNet on the
    # pickled training data, log metrics to TensorBoard, and save a checkpoint
    # whenever the held-out loss improves.

    # Imported here because the module-level SummaryWriter import is commented
    # out; without this the script fails with a NameError on the next lines.
    from torch.utils.tensorboard import SummaryWriter

    BS = 15                # board side length
    TRAIN_FRACTION = 0.8   # leading share of the data used for training
    epochs = 500
    batch_size = 32

    writer = SummaryWriter(os.path.join(dir_path, 'train_data/log', date), comment='BoardEvaluationNet')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # float('inf') instead of np.Inf, which was removed in NumPy 2.0.
    best = float('inf')
    # NOTE(review): with (batch, BS, BS) logits and float targets of the same
    # shape, CrossEntropyLoss treats dim 1 as the class axis and the targets as
    # class probabilities -- confirm this is the intended objective.
    loss_fn = nn.CrossEntropyLoss()

    net_for_black = BoardEvaluationNet(BS).to(device)

    model_path = os.path.join(dir_path, 'train_data/model')
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    net_for_black.load_state_dict(
        torch.load(os.path.join(model_path, 'best_loss=680.5813717259707.pth'), map_location=device))

    optimizer = torch.optim.Adam(net_for_black.parameters(), lr=1e-5, betas=(0.9, 0.99),
                                 eps=1e-8)

    data_path = os.path.join(dir_path, 'train_data/data', 'train_data.pkl')
    # NOTE: pickle can execute arbitrary code while loading; only point this at
    # trusted data files.
    with open(data_path, 'rb') as f:
        datas = pickle.load(f)

    # Disjoint split: the original trained on 100% of the data, so the "test"
    # items were also training items and the best-checkpoint choice was biased.
    black_data = datas[1]
    split_index = int(len(black_data) * TRAIN_FRACTION)
    train_data_for_black = black_data[:split_index]
    test_data_for_black = black_data[split_index:]

    train_dataset = TensorDataset(
        torch.stack([torch.tensor(item['state'], dtype=torch.float) for item in train_data_for_black]),
        torch.stack([normalize(torch.tensor(item['scores'], dtype=torch.float)) for item in train_data_for_black]))
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    for epoch in range(epochs):
        net_for_black.train()
        epoch_loss = 0
        print('Epoch:', epoch)
        progress = tqdm(enumerate(train_dataloader), total=len(train_dataloader))
        for i, (states, scores) in progress:
            states = states.to(device)
            scores = scores.to(device)

            infer_start = datetime.datetime.now()
            output_tensor = net_for_black(states)
            infer_end = datetime.datetime.now()
            loss = loss_fn(output_tensor, scores)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            epoch_loss += batch_loss
            # Show the running loss on the progress bar instead of printing it
            # (the original print was followed by a debug exit(0) that stopped
            # training after the first batch).
            progress.set_postfix(loss=batch_loss)

            # total_seconds() gives the full elapsed time; .microseconds is only
            # the sub-second component of the timedelta.
            writer.add_scalar('train/infer_time', (infer_end - infer_start).total_seconds() * 1e6,
                              i + epoch * len(train_dataloader))

        epoch_loss /= len(train_dataloader)
        writer.add_scalar('train/epoch_loss', epoch_loss, epoch)

        # Evaluate on the held-out items, one board at a time.
        net_for_black.eval()
        test_loss = 0
        with torch.no_grad():
            for item in tqdm(test_data_for_black):
                # Convert to float tensors on the target device and add a batch axis.
                scores = normalize(torch.tensor(item['scores'], dtype=torch.float).to(device).unsqueeze(0))
                input_tensor = torch.tensor(item['state'], dtype=torch.float).to(device).unsqueeze(0)
                output_tensor = net_for_black(input_tensor)
                test_loss += loss_fn(output_tensor, scores).item()
        test_loss /= len(test_data_for_black)
        writer.add_scalar('test/loss', test_loss, epoch)

        # Keep a checkpoint each time the held-out loss reaches a new minimum.
        if best > test_loss:
            best = test_loss
            os.makedirs(model_path, exist_ok=True)
            torch.save(net_for_black.state_dict(),
                       os.path.join(model_path, f'best_loss={best}.pth'))

    # Flush pending events and release the log file.
    writer.close()