import os import diffusion from pytorch_lightning.callbacks import ( ModelCheckpoint, LearningRateMonitor ) class ModelCallback: def __init__( self, root_path: str, ckpt_monitor: str = "val_loss", ckpt_mode: str = "min", ): ckpt_path = os.path.join(os.path.join(root_path, "model/")) if not os.path.exists(root_path): os.makedirs(root_path) if not os.path.exists(ckpt_path): os.makedirs(ckpt_path) self.ckpt_callback = ModelCheckpoint( monitor=ckpt_monitor, dirpath=ckpt_path, filename="model", save_top_k=1, mode=ckpt_mode, save_weights_only=True ) self.lr_callback = LearningRateMonitor("step") self.ema_callback = diffusion.EMACallback(decay=0.995) def get_callback(self): return [ self.ckpt_callback, self.lr_callback, self.ema_callback ]