Spaces: Running on Zero
import os.path | |
from typing import Optional, Dict, Any | |
import lightning.pytorch as pl | |
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint | |
from soupsieve.util import lower | |
class CheckpointHook(ModelCheckpoint):
    """Save checkpoint with only the incremental part of the model.

    A ``ModelCheckpoint`` that writes into ``trainer.default_root_dir``,
    turns off strict state-dict loading on the module being trained, and
    strips the ``"callbacks"`` entry from every checkpoint it sees.

    NOTE(review): this class does not itself filter
    ``checkpoint["state_dict"]``; presumably the "incremental part" selection
    happens in the LightningModule (e.g. an overridden ``state_dict``) —
    confirm against the module.
    """

    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
        """Direct checkpoint output to the trainer's root dir and relax loading.

        ``super().setup()`` is not called, so the parent's own ``dirpath``
        resolution is bypassed and ``dirpath`` is always
        ``trainer.default_root_dir``.

        Args:
            trainer: The running trainer; only ``default_root_dir`` is read.
            pl_module: The module being trained; its ``strict_loading`` flag
                is set to ``False``.
            stage: Lightning stage name (unused here).
        """
        self.dirpath = trainer.default_root_dir
        # Not read anywhere in this class.
        # NOTE(review): presumably used by an on-exception save elsewhere; confirm.
        self.exception_ckpt_path = os.path.join(self.dirpath, "on_exception.pt")
        # Saved checkpoints hold only part of the weights (see class docstring),
        # so loading them back must tolerate missing keys.
        pl_module.strict_loading = False

    def on_save_checkpoint(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        checkpoint: Dict[str, Any],
    ) -> None:
        """Remove the callback states from ``checkpoint`` before it is written.

        ``pop`` with a default is used instead of ``del`` so that a checkpoint
        without a ``"callbacks"`` entry does not raise ``KeyError``. That
        happens, for example, when two instances of this hook are registered
        or another callback removed the entry first.

        Args:
            trainer: The running trainer (unused).
            pl_module: The module being checkpointed (unused).
            checkpoint: The checkpoint dict; mutated in place.
        """
        checkpoint.pop("callbacks", None)