DDT / src /callbacks /model_checkpoint.py
wangshuai6
init space
9e426da
raw
history blame contribute delete
774 Bytes
import os.path
from typing import Optional, Dict, Any
import lightning.pytorch as pl
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from soupsieve.util import lower
class CheckpointHook(ModelCheckpoint):
    """Save checkpoint with only the incremental part of the model.

    Visible behaviour of this ``ModelCheckpoint`` subclass:

    * checkpoint output is directed to ``trainer.default_root_dir``;
    * the LightningModule is switched to non-strict state-dict loading;
    * per-callback state is stripped from every saved checkpoint.

    NOTE(review): nothing in this class selects the "incremental part" of the
    model's weights; presumably the LightningModule's own ``state_dict``
    handles that — confirm against the module implementation.
    """

    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
        """Point checkpoint output at the trainer's root dir and relax loading.

        Args:
            trainer: Trainer whose ``default_root_dir`` becomes ``self.dirpath``.
            pl_module: Module whose ``strict_loading`` flag is set to ``False``.
            stage: Lightning stage name (not used by this override).
        """
        # NOTE(review): ``super().setup()`` is not called, so whatever
        # directory handling ``ModelCheckpoint.setup`` performs is replaced by
        # the assignment below — confirm this is intended.
        self.dirpath = trainer.default_root_dir
        self.exception_ckpt_path = os.path.join(self.dirpath, "on_exception.pt")
        # Presumably allows loading checkpoints that contain only a subset of
        # the model's weights without missing-key errors — verify at load sites.
        pl_module.strict_loading = False

    def on_save_checkpoint(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        checkpoint: Dict[str, Any],
    ) -> None:
        """Drop callback state from ``checkpoint`` before it is written.

        Mutates ``checkpoint`` in place. ``pop`` with a default is used so a
        checkpoint that carries no ``"callbacks"`` entry (e.g. because another
        hook already removed it) does not raise ``KeyError``.

        Args:
            trainer: Unused.
            pl_module: Unused.
            checkpoint: Checkpoint dictionary about to be saved.
        """
        checkpoint.pop("callbacks", None)