"""Helpers for saving and restoring a model/optimizer training checkpoint."""

import contextlib
import os

import torch


def _atomic_torch_save(obj, path):
    """Serialize *obj* to the filesystem path *path* via a temp file + atomic rename.

    Writing to a sibling temp file and then ``os.replace``-ing it over *path*
    means a crash mid-write never leaves a truncated file at *path*: either the
    old checkpoint or the complete new one is there.
    """
    tmp_path = f"{path}.tmp"
    try:
        torch.save(obj, tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        # Best-effort cleanup of the partial temp file; the original error is
        # what the caller needs to see, so it is always re-raised.
        with contextlib.suppress(OSError):
            os.remove(tmp_path)
        raise


def save_checkpoint(model, optimizer, epoch, loss, checkpoint_path="checkpoint.pth"):
    """Save model and optimizer state plus epoch/loss metadata to a checkpoint.

    Args:
        model: Object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        optimizer: Object exposing ``state_dict()`` (e.g. a ``torch.optim.Optimizer``).
        epoch: Epoch number stored under the ``'epoch'`` key.
        loss: Loss value stored under the ``'loss'`` key.
        checkpoint_path: Destination path (``str`` or ``os.PathLike``) or a
            writable binary file-like object accepted by ``torch.save``.

    When *checkpoint_path* is a filesystem path, the write is atomic. An
    interrupted save leaves any previous checkpoint at that path intact.
    """
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }
    path = os.fspath(checkpoint_path) if isinstance(checkpoint_path, (str, os.PathLike)) else None
    if isinstance(path, str):
        _atomic_torch_save(checkpoint, path)
    else:
        # File-like object (or bytes path): hand it to torch.save unchanged.
        torch.save(checkpoint, checkpoint_path)
    print(f"Checkpoint saved at epoch {epoch}")


def load_checkpoint(model, optimizer, checkpoint_path="checkpoint.pth", map_location=None, *, return_epoch=False):
    """Restore model and optimizer state from a checkpoint written by ``save_checkpoint``.

    Args:
        model: Object exposing ``load_state_dict()``; updated in place.
        optimizer: Object exposing ``load_state_dict()``; updated in place.
        checkpoint_path: Path or readable binary file-like object for ``torch.load``.
        map_location: Forwarded to ``torch.load``. For example, pass ``"cpu"``
            to load a checkpoint saved on GPU onto a machine without CUDA.
            ``None`` keeps ``torch.load``'s default behavior.
        return_epoch: If True, also return the stored epoch as a fourth element.

    Returns:
        ``(model, optimizer, loss)``, or ``(model, optimizer, loss, epoch)``
        when *return_epoch* is True.

    ``weights_only=True`` restricts unpickling to tensors and primitive
    containers, so the load does not execute arbitrary pickled code.
    """
    checkpoint = torch.load(checkpoint_path, map_location=map_location, weights_only=True)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    print(f"Checkpoint loaded, resuming from epoch {epoch}")
    if return_epoch:
        return model, optimizer, loss, epoch
    return model, optimizer, loss