Spaces:
Build error
Build error
File size: 1,718 Bytes
02f6666 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
import torch
from torch import nn
from data import BOARD_VECTOR_SIZE
class ChessModel(nn.Module):
    """Shared board encoder feeding popularity, evaluation and reconstruction heads.

    Args:
        embedding_dims: Width of the embedding produced by the encoder and
            consumed by every head.
    """

    def __init__(self, embedding_dims):
        super().__init__()

        # Encoder: BOARD_VECTOR_SIZE -> 512 -> 1024 -> 1024 -> embedding_dims,
        # with a SiLU after every Linear, including the final one. Layers are
        # created in the same order as an explicit Sequential listing, so the
        # state_dict keys ("encoder.0.*", "encoder.2.*", ...) and seeded
        # initialization are unchanged.
        widths = (BOARD_VECTOR_SIZE, 512, 1024, 1024, embedding_dims)
        encoder_layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            encoder_layers.extend((nn.Linear(fan_in, fan_out), nn.SiLU()))
        self.encoder = nn.Sequential(*encoder_layers)

        # Popularity head: a single value squashed into (-1, 1) by Tanh.
        self.popularity_head = self._two_layer_head(embedding_dims, 1, nn.Tanh)

        # Evaluation head: a single linear projection with no output
        # activation, so its output is unbounded. A deeper head
        # (Linear(embedding_dims, 512) -> SiLU -> Linear(512, 1) -> Tanh) was
        # planned. It was replaced by this single layer because it would have
        # taken too long before the jam was over. It stays wrapped in a
        # Sequential so the parameter keys remain "evaluation_head.0.*".
        self.evaluation_head = nn.Sequential(nn.Linear(embedding_dims, 1))

        # Reconstruction head: predicts a BOARD_VECTOR_SIZE vector (the encoder's
        # input width) squashed into (0, 1) by Sigmoid.
        self.reconstruction_head = self._two_layer_head(
            embedding_dims, BOARD_VECTOR_SIZE, nn.Sigmoid
        )

    @staticmethod
    def _two_layer_head(in_features, out_features, squash_cls):
        """Build Linear(in, 512) -> SiLU -> Linear(512, out) -> squash_cls().

        The activation class is instantiated last so module creation order
        matches a hand-written Sequential of the same layers.
        """
        return nn.Sequential(
            nn.Linear(in_features, 512),
            nn.SiLU(),
            nn.Linear(512, out_features),
            squash_cls(),
        )

    def forward(self, x):
        """Return the embedding, popularity, evaluation, and reconstruction.

        The encoder runs once and its output is shared by all three heads,
        which are applied in the order popularity, evaluation, reconstruction.

        Args:
            x: Board encoding; nn.Linear acts on the last dimension, which is
                expected to have BOARD_VECTOR_SIZE features.

        Returns:
            Tuple ``(embedding, popularity, evaluation, reconstruction)``:
            popularity lies in (-1, 1) (Tanh), evaluation is an unbounded
            linear output, and reconstruction lies in (0, 1) (Sigmoid) with
            BOARD_VECTOR_SIZE features.
        """
        embedding = self.encoder(x)
        heads = (self.popularity_head, self.evaluation_head, self.reconstruction_head)
        return (embedding, *(head(embedding) for head in heads))
|