# Spaces: Runtime error
import os | |
import diffusion | |
from pytorch_lightning.callbacks import ( | |
ModelCheckpoint, | |
LearningRateMonitor | |
) | |
class ModelCallback:
    """Build the set of Lightning callbacks used for training.

    Bundles three callbacks:

    * a ``ModelCheckpoint`` that keeps only the single best checkpoint
      (``save_top_k=1``, weights only) in ``<root_path>/model/`` using the
      filename stem ``"model"``;
    * a ``LearningRateMonitor`` constructed with logging interval ``"step"``;
    * an ``EMACallback`` from the project-local ``diffusion`` module.

    Args:
        root_path: Directory under which the ``model/`` checkpoint folder is
            created. Missing directories (including ``root_path`` itself)
            are created as needed.
        ckpt_monitor: Name of the logged metric the checkpoint callback
            monitors.
        ckpt_mode: Passed through as ``ModelCheckpoint(mode=...)``; Lightning
            accepts ``"min"`` or ``"max"``.
        ema_decay: Value passed as ``decay`` to ``diffusion.EMACallback``.
            Defaults to ``0.995``, the value previously hard-coded here.
    """

    def __init__(
        self,
        root_path: str,
        ckpt_monitor: str = "val_loss",
        ckpt_mode: str = "min",
        ema_decay: float = 0.995,
    ):
        ckpt_path = os.path.join(root_path, "model/")
        # makedirs creates every missing parent (so root_path needs no separate
        # call), and exist_ok=True avoids the check-then-create race.
        os.makedirs(ckpt_path, exist_ok=True)

        self.ckpt_callback = ModelCheckpoint(
            monitor=ckpt_monitor,
            dirpath=ckpt_path,
            filename="model",
            save_top_k=1,
            mode=ckpt_mode,
            save_weights_only=True,
        )
        self.lr_callback = LearningRateMonitor("step")
        self.ema_callback = diffusion.EMACallback(decay=ema_decay)

    def get_callback(self) -> list:
        """Return a new list of the callbacks, e.g. for ``Trainer(callbacks=...)``.

        Order: checkpoint, learning-rate monitor, EMA.
        """
        return [self.ckpt_callback, self.lr_callback, self.ema_callback]