import glob
import os
import shutil

import accelerate
import torch


def restore_checkpoint(checkpoint_dir, accelerator: accelerate.Accelerator, logger=None):
    """Resume from the latest checkpoint in `checkpoint_dir` and return the step to start from."""
    dirs = glob.glob(os.path.join(checkpoint_dir, "*"))
    # Checkpoint folders are named with zero-padded step numbers, so a
    # lexicographic sort is also chronological.
    dirs.sort()
    path = dirs[-1] if len(dirs) > 0 else None
    if path is None:
        if logger is not None:
            logger.info("Checkpoint does not exist. Starting a new training run.")
        init_step = 0
    else:
        if logger is not None:
            logger.info(f"Resuming from checkpoint {path}")
        accelerator.load_state(path)
        init_step = int(os.path.basename(path))
    return init_step


def save_checkpoint(save_dir, accelerator: accelerate.Accelerator, step=0, total_limit=3):
    """Save the accelerator state for `step`, keeping at most `total_limit` checkpoints."""
    if total_limit > 0:
        folders = glob.glob(os.path.join(save_dir, "*"))
        folders.sort()
        # Delete the oldest folders so that, counting the checkpoint about to be
        # written, at most `total_limit` remain.
        for folder in folders[: len(folders) + 1 - total_limit]:
            shutil.rmtree(folder)
    accelerator.save_state(os.path.join(save_dir, f"{step:06d}"))
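

# Minimal usage sketch, not part of the utilities above: the toy model, optimizer,
# directory name, loop bounds, and save interval are illustrative assumptions
# showing how restore_checkpoint/save_checkpoint could slot into an Accelerate
# training loop.
if __name__ == "__main__":
    accelerator = accelerate.Accelerator()

    # Placeholder model and optimizer, chosen only to make the sketch runnable.
    model = torch.nn.Linear(8, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    model, optimizer = accelerator.prepare(model, optimizer)

    checkpoint_dir = "checkpoints"  # assumed layout: one zero-padded folder per saved step
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Pick up from the newest checkpoint if one exists, otherwise start at step 0.
    start_step = restore_checkpoint(checkpoint_dir, accelerator)

    for step in range(start_step, start_step + 100):
        optimizer.zero_grad()
        loss = model(torch.randn(4, 8)).pow(2).mean()
        accelerator.backward(loss)
        optimizer.step()

        # Checkpoint every 50 steps (interval is an arbitrary choice for this sketch).
        if (step + 1) % 50 == 0:
            save_checkpoint(checkpoint_dir, accelerator, step=step + 1)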