import torch


def save_checkpoint(model, optimizer, epoch, loss, checkpoint_path="checkpoint.pth"):
    """Serialize model and optimizer state together with the training progress."""
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }
    torch.save(checkpoint, checkpoint_path)
    print(f"Checkpoint saved at epoch {epoch}")


def load_checkpoint(model, optimizer, checkpoint_path="checkpoint.pth"):
    """Restore model and optimizer state; also return the saved epoch and loss so training can resume."""
    # weights_only=True restricts unpickling to tensors and a small allowlist of safe types.
    checkpoint = torch.load(checkpoint_path, weights_only=True)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    print(f"Checkpoint loaded, resuming from epoch {epoch}")
    return model, optimizer, epoch, loss
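

# --- Usage sketch (assumption, not part of the original helpers): a toy nn.Linear
# model with synthetic data shows how the functions might slot into a resumable
# training loop. Model, optimizer, batch shapes, and epoch count are illustrative. ---
if __name__ == "__main__":
    import os

    import torch.nn as nn

    model = nn.Linear(10, 1)                                   # hypothetical tiny model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    criterion = nn.MSELoss()
    checkpoint_path = "checkpoint.pth"
    start_epoch = 0

    # Resume from the last checkpoint if one exists; otherwise train from scratch.
    if os.path.exists(checkpoint_path):
        model, optimizer, start_epoch, _ = load_checkpoint(model, optimizer, checkpoint_path)
        start_epoch += 1                                       # continue with the next epoch

    for epoch in range(start_epoch, 5):
        inputs, targets = torch.randn(32, 10), torch.randn(32, 1)   # synthetic batch
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()

        # Persist progress at the end of every epoch.
        save_checkpoint(model, optimizer, epoch, loss.item(), checkpoint_path)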