ResNet50_replicate / checkpoint.py
Ubuntu
Added checkpoint and early stopping
41b8141
raw
history blame
787 Bytes
import os

import torch
def save_checkpoint(model, optimizer, epoch, loss, checkpoint_path="checkpoint.pth"):
    """Save model/optimizer state, epoch and loss to ``checkpoint_path``.

    The checkpoint is a dict with keys ``'epoch'``, ``'model_state_dict'``,
    ``'optimizer_state_dict'`` and ``'loss'``; these are the keys that
    ``load_checkpoint`` reads back.

    When ``checkpoint_path`` is a filesystem path (``str`` or ``os.PathLike``),
    the data is first written to ``<checkpoint_path>.tmp`` and then moved over
    the target with ``os.replace``. Writing directly over the previous
    checkpoint would leave a truncated, unloadable file if saving failed or
    the process died mid-write. With temp-then-rename, the old checkpoint
    stays intact until the new one is complete. Any other target accepted by
    ``torch.save`` (e.g. an open binary buffer) is written directly.

    Args:
        model: Object exposing ``state_dict()`` (e.g. an ``nn.Module``).
        optimizer: Object exposing ``state_dict()`` (e.g. a ``torch.optim``
            optimizer).
        epoch: Value stored under ``'epoch'`` and echoed in the log line.
        loss: Value stored under ``'loss'``.
        checkpoint_path: Destination path or file-like object.

    Raises:
        Whatever ``torch.save`` / ``os.replace`` raise; on failure the
        temporary file is removed and an existing checkpoint is untouched.
    """
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }

    target = checkpoint_path
    if isinstance(target, os.PathLike):
        target = os.fspath(target)

    if isinstance(target, str):
        # Sibling temp file => same directory => same filesystem, which
        # os.replace needs for an atomic rename (guaranteed on POSIX).
        tmp_path = target + ".tmp"
        try:
            torch.save(checkpoint, tmp_path)
            os.replace(tmp_path, target)
        except BaseException:
            # Best-effort cleanup of the partial temp file; the original
            # error is re-raised below either way.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise
    else:
        # File-like (or other non-path) target: nothing to rename, write as-is.
        torch.save(checkpoint, checkpoint_path)
    print(f"Checkpoint saved at epoch {epoch}")
def load_checkpoint(model, optimizer, checkpoint_path="checkpoint.pth", *, map_location=None):
    """Restore model and optimizer state from a checkpoint file.

    Expects the dict layout written by ``save_checkpoint``: keys ``'epoch'``,
    ``'model_state_dict'``, ``'optimizer_state_dict'`` and ``'loss'``.

    Args:
        model: Object whose ``load_state_dict`` receives the saved model state
            (e.g. an ``nn.Module``); updated in place.
        optimizer: Object whose ``load_state_dict`` receives the saved
            optimizer state; updated in place.
        checkpoint_path: Path or file-like object accepted by ``torch.load``.
        map_location: Forwarded to ``torch.load``. The default ``None`` keeps
            the previous behaviour of restoring tensors to the device they were
            saved from. Pass e.g. ``"cpu"`` to resume a GPU-saved checkpoint
            on a machine without that device.

    Returns:
        ``(model, optimizer, loss)``: the same model and optimizer objects that
        were passed in, plus the value stored under ``'loss'``.

    Raises:
        KeyError: if the checkpoint is missing one of the expected keys.
    """
    # weights_only=True limits unpickling to tensors and plain containers /
    # primitives, so loading a checkpoint does not run arbitrary pickled code.
    checkpoint = torch.load(checkpoint_path, map_location=map_location, weights_only=True)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    # NOTE(review): ``epoch`` is only reported here, not returned. Whether
    # training should restart at ``epoch`` or ``epoch + 1`` depends on what the
    # caller passed when saving. Confirm against the training loop.
    print(f"Checkpoint loaded, resuming from epoch {epoch}")
    return model, optimizer, loss