# Training and evaluation loop for a GPT-style language model.
import os

import torch
from torch import nn

from src.utils import get_batch
def estimate_loss(model: nn.Module, eval_iters, block_size, batch_size, device):
    """Estimate the mean loss on the train and val splits.

    Runs ``eval_iters`` batches per split and averages the per-batch loss.
    The model is put in eval mode for the duration and restored to train
    mode before returning.

    Args:
        model: Module whose forward returns ``(logits, loss)`` when given
            inputs and targets.
        eval_iters: Number of batches to average over per split.
        block_size: Context length passed to ``get_batch``.
        batch_size: Batch size passed to ``get_batch``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Dict mapping ``"train"`` and ``"val"`` to a scalar tensor with the
        mean loss for that split.
    """
    out = {}
    model.eval()
    # No gradients are needed for evaluation; without no_grad() every eval
    # batch would build an autograd graph, wasting time and peak memory.
    with torch.no_grad():
        for split in ["train", "val"]:
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, Y = get_batch(split, block_size, batch_size)
                X, Y = X.to(device), Y.to(device)
                _, loss = model(X, Y)  # logits are not needed here
                losses[k] = loss.item()
            out[split] = losses.mean()
    model.train()
    return out
def train(
    model,
    optimizer,
    max_iters,
    eval_interval,
    eval_iters,
    block_size,
    batch_size,
    device,
):
    """Run the training loop, checkpointing the model on best val loss.

    Every ``eval_interval`` steps the train/val losses are estimated and
    printed; whenever the val loss improves on the best seen so far, the
    model is saved to ``checkpoints/model.pth``.

    Args:
        model: Module whose forward returns ``(logits, loss)``.
        optimizer: Optimizer stepping ``model``'s parameters.
        max_iters: Total number of training steps.
        eval_interval: Steps between loss evaluations/checkpoints.
        eval_iters: Batches averaged per split in ``estimate_loss``.
        block_size: Context length passed to ``get_batch``.
        batch_size: Batch size passed to ``get_batch``.
        device: Device batches are moved to before the forward pass.
    """
    best_val_loss = None
    for step in range(max_iters):  # 'step' avoids shadowing builtin 'iter'
        if step % eval_interval == 0:
            losses = estimate_loss(model, eval_iters, block_size, batch_size, device)
            print(
                f"Step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}"
            )
            # Bug fix: the original set val_loss only once (first eval) and
            # never updated it, so it checkpointed on every eval that beat
            # the *first* loss rather than the best, and never saved on the
            # first eval at all. Track the running best instead.
            if best_val_loss is None or losses["val"] < best_val_loss:
                best_val_loss = losses["val"]
                os.makedirs("checkpoints", exist_ok=True)  # torch.save needs the dir
                torch.save(model, "checkpoints/model.pth")
        xb, yb = get_batch("train", block_size, batch_size)
        xb, yb = xb.to(device), yb.to(device)
        logits, loss = model(xb, yb)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()