dinhdat1110's picture
Upload folder using huggingface_hub
9457143 verified
raw
history blame contribute delete
981 Bytes
import os
import diffusion
from pytorch_lightning.callbacks import (
ModelCheckpoint,
LearningRateMonitor
)
class ModelCallback:
    """Build the set of PyTorch Lightning callbacks used for training.

    On construction this creates the directory ``<root_path>/model`` and
    instantiates three callbacks, exposed as attributes:

    * ``ckpt_callback`` -- a ``ModelCheckpoint`` that keeps only the single
      best checkpoint (``save_top_k=1``) according to ``ckpt_monitor`` /
      ``ckpt_mode``, saved with ``save_weights_only=True`` under the
      filename ``"model"`` inside ``<root_path>/model``;
    * ``lr_callback`` -- a ``LearningRateMonitor``;
    * ``ema_callback`` -- the project-local ``diffusion.EMACallback``
      (presumably maintains an exponential moving average of the model
      weights -- NOTE(review): confirm against ``diffusion``).
    """

    def __init__(
        self,
        root_path: str,
        ckpt_monitor: str = "val_loss",
        ckpt_mode: str = "min",
        *,
        ema_decay: float = 0.995,
        lr_logging_interval: str = "step",
    ):
        """Create the checkpoint directory and instantiate the callbacks.

        Args:
            root_path: Base output directory; checkpoints go to
                ``<root_path>/model``. Both directories are created if missing.
            ckpt_monitor: Name of the logged metric ``ModelCheckpoint`` ranks
                checkpoints by.
            ckpt_mode: ``"min"`` or ``"max"`` -- whether a lower or a higher
                ``ckpt_monitor`` value counts as better.
            ema_decay: ``decay`` passed to ``diffusion.EMACallback``
                (previously hard-coded to 0.995, which remains the default).
            lr_logging_interval: ``logging_interval`` passed to
                ``LearningRateMonitor`` (previously hard-coded to ``"step"``,
                which remains the default).
        """
        ckpt_path = os.path.join(root_path, "model")
        # makedirs also creates root_path as an intermediate directory, and
        # exist_ok avoids the race of a separate os.path.exists() check.
        os.makedirs(ckpt_path, exist_ok=True)

        self.ckpt_callback = ModelCheckpoint(
            monitor=ckpt_monitor,
            dirpath=ckpt_path,
            filename="model",
            save_top_k=1,
            mode=ckpt_mode,
            save_weights_only=True,
        )
        self.lr_callback = LearningRateMonitor(logging_interval=lr_logging_interval)
        self.ema_callback = diffusion.EMACallback(decay=ema_decay)

    def get_callback(self) -> list:
        """Return a new list ``[checkpoint, lr_monitor, ema]`` for ``Trainer(callbacks=...)``."""
        return [self.ckpt_callback, self.lr_callback, self.ema_callback]