Spaces:
Build error
Build error
JosephCatrambone
committed on
Commit
·
6b8f356
1
Parent(s):
1cf8f2a
Changing a few parameters and training for much longer. Should have better outputs now.
Browse files
- .gitattributes +0 -33
- main.py +42 -6
- model.pth +2 -2
.gitattributes
CHANGED
@@ -1,34 +1 @@
|
|
1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
29 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
30 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
31 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
32 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
main.py
CHANGED
@@ -10,16 +10,39 @@ import data
|
|
10 |
from model import ChessModel
|
11 |
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
def train():
|
|
|
|
|
|
|
|
|
|
|
14 |
device_string = "cuda" if torch.cuda.is_available() else "cpu"
|
15 |
device = torch.device(device_string)
|
16 |
-
model = ChessModel(
|
17 |
-
opt = torch.optim.Adam(model.parameters())
|
18 |
reconstruction_loss_fn = nn.CrossEntropyLoss().to(torch.float32).to(device)
|
19 |
popularity_loss_fn = nn.L1Loss().to(torch.float32).to(device)
|
20 |
evaluation_loss_fn = nn.L1Loss().to(torch.float32).to(device)
|
21 |
-
data_loader = DataLoader(data.LichessPuzzleDataset(cap_data=
|
22 |
-
|
|
|
23 |
|
24 |
for epoch in range(num_epochs):
|
25 |
model.train()
|
@@ -38,7 +61,8 @@ def train():
|
|
38 |
reconstruction_loss = reconstruction_loss_fn(predicted_board_vec, board_vec)
|
39 |
popularity_loss = popularity_loss_fn(predicted_popularity, popularity)
|
40 |
evaluation_loss = evaluation_loss_fn(predicted_evaluation, evaluation)
|
41 |
-
total_loss = reconstruction_loss + popularity_loss + evaluation_loss
|
|
|
42 |
|
43 |
opt.zero_grad()
|
44 |
total_loss.backward()
|
@@ -54,7 +78,19 @@ def train():
|
|
54 |
print(f"Average evaluation loss: {total_evaluation_loss/num_batches}")
|
55 |
print(f"Average batch loss: {total_batch_loss/num_batches}")
|
56 |
|
57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
|
60 |
def infer(fen):
|
|
|
10 |
from model import ChessModel
|
11 |
|
12 |
|
13 |
+
# Experiment parameters:
|
14 |
+
RUN_CONFIGURATION = {
|
15 |
+
"learning_rate": 0.0004,
|
16 |
+
"dataset_cap": 100000,
|
17 |
+
"epochs": 1000,
|
18 |
+
"latent_size": 256,
|
19 |
+
}
|
20 |
+
|
21 |
+
# Logging:
|
22 |
+
wandb = None
|
23 |
+
try:
|
24 |
+
import wandb
|
25 |
+
wandb.init("assembly_ai_hackathon_2022", config=RUN_CONFIGURATION)
|
26 |
+
except ImportError:
|
27 |
+
print("Weights and Biases not found in packages.")
|
28 |
+
|
29 |
+
|
30 |
def train():
|
31 |
+
learning_rate = RUN_CONFIGURATION["learning_rate"]
|
32 |
+
latent_size = RUN_CONFIGURATION["latent_size"]
|
33 |
+
data_cap = RUN_CONFIGURATION["dataset_cap"]
|
34 |
+
num_epochs = RUN_CONFIGURATION["epochs"]
|
35 |
+
|
36 |
device_string = "cuda" if torch.cuda.is_available() else "cpu"
|
37 |
device = torch.device(device_string)
|
38 |
+
model = ChessModel(latent_size).to(torch.float32).to(device)
|
39 |
+
opt = torch.optim.Adam(model.parameters(), lr=learning_rate)
|
40 |
reconstruction_loss_fn = nn.CrossEntropyLoss().to(torch.float32).to(device)
|
41 |
popularity_loss_fn = nn.L1Loss().to(torch.float32).to(device)
|
42 |
evaluation_loss_fn = nn.L1Loss().to(torch.float32).to(device)
|
43 |
+
data_loader = DataLoader(data.LichessPuzzleDataset(cap_data=data_cap), batch_size=64, num_workers=1) # 1 to avoid threading madness.
|
44 |
+
save_every_nth_epoch = 50
|
45 |
+
upload_logs_every_nth_epoch = 1
|
46 |
|
47 |
for epoch in range(num_epochs):
|
48 |
model.train()
|
|
|
61 |
reconstruction_loss = reconstruction_loss_fn(predicted_board_vec, board_vec)
|
62 |
popularity_loss = popularity_loss_fn(predicted_popularity, popularity)
|
63 |
evaluation_loss = evaluation_loss_fn(predicted_evaluation, evaluation)
|
64 |
+
#total_loss = reconstruction_loss + popularity_loss + evaluation_loss
|
65 |
+
total_loss = popularity_loss
|
66 |
|
67 |
opt.zero_grad()
|
68 |
total_loss.backward()
|
|
|
78 |
print(f"Average evaluation loss: {total_evaluation_loss/num_batches}")
|
79 |
print(f"Average batch loss: {total_batch_loss/num_batches}")
|
80 |
|
81 |
+
if save_every_nth_epoch > 0 and (epoch % save_every_nth_epoch) == 0:
|
82 |
+
torch.save(model, f"checkpoints/epoch_{epoch}.pth")
|
83 |
+
|
84 |
+
if wandb:
|
85 |
+
wandb.log(
|
86 |
+
# For now, just log popularity.
|
87 |
+
{"popularity_loss": total_popularity_loss},
|
88 |
+
commit=(epoch+1) % upload_logs_every_nth_epoch == 0
|
89 |
+
)
|
90 |
+
|
91 |
+
torch.save(model, "checkpoints/final.pth")
|
92 |
+
if wandb:
|
93 |
+
wandb.finish()
|
94 |
|
95 |
|
96 |
def infer(fen):
|
model.pth
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ed7dc6d33fb3ac545f78b7be413b0bebd565fcca89e6662ed617a2640d99715b
|
3 |
+
size 12118255
|