# Provenance (page residue from the Hugging Face upload, kept as comments):
# svjack's picture
# Upload folder using huggingface_hub
# bce3e7c verified
import argparse
import logging
import os
import shutil
import accelerate
import torch
from utils import huggingface_utils
# Module-level logger for this file.
# NOTE: basicConfig runs at import time, so importing this module configures the
# root logger at INFO level (a no-op if the root logger already has handlers).
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
# Checkpoint file / directory name templates.
# The first "{}" is the model (output) name; epoch numbers are zero-padded to
# 6 digits and step numbers to 8 digits.
# Accelerate state directory saved at the end of an epoch.
EPOCH_STATE_NAME = "{}-{:06d}-state"
# Per-epoch checkpoint file stem (".safetensors" is appended by get_epoch_ckpt_name).
EPOCH_FILE_NAME = "{}-{:06d}"
# Per-epoch directory name; presumably for Diffusers-format saves (not used in the
# visible code — verify against callers).
EPOCH_DIFFUSERS_DIR_NAME = "{}-{:06d}"
# Accelerate state directory saved when training finishes.
LAST_STATE_NAME = "{}-state"
# Accelerate state directory saved every N steps.
STEP_STATE_NAME = "{}-step{:08d}-state"
# Per-step checkpoint file stem (".safetensors" is appended by get_step_ckpt_name).
STEP_FILE_NAME = "{}-step{:08d}"
# Per-step directory name; presumably for Diffusers-format saves (not used in the
# visible code — verify against callers).
STEP_DIFFUSERS_DIR_NAME = "{}-step{:08d}"
def get_sanitized_config_or_none(args: argparse.Namespace):
    """Return a loggable copy of ``args`` with sensitive entries removed, or None.

    If ``--log_config`` is disabled, ``None`` is returned. Otherwise API keys/tokens
    and local filesystem paths are dropped. Every remaining value is kept only if it
    is one of the types Accelerate accepts for logged config values (``bool``,
    ``str``, ``float``, ``int`` or ``None``). Anything else (lists, tuples, arbitrary
    objects) is converted to its string form.

    Filtering happens even when the log is not published (e.g. wandb disabled);
    this is purely a safety measure.

    Args:
        args: Parsed command-line arguments; must have a ``log_config`` attribute.

    Returns:
        A ``dict`` of filtered arguments, or ``None`` when config logging is off.
    """
    if not args.log_config:
        return None

    # Secrets are never logged; local paths are dropped to avoid leaking directory layout.
    sensitive_args = ("wandb_api_key", "huggingface_token")
    sensitive_path_args = (
        "dit",
        "vae",
        "text_encoder1",
        "text_encoder2",
        "base_weights",
        "network_weights",
        "output_dir",
        "logging_dir",
    )
    excluded = set(sensitive_args) | set(sensitive_path_args)

    # Types Accelerate accepts as-is (bool is listed explicitly for clarity even
    # though it is a subclass of int).
    passthrough_types = (bool, str, float, int)

    filtered_args = {}
    for k, v in vars(args).items():
        if k in excluded:
            continue
        if v is None or isinstance(v, passthrough_types):
            filtered_args[k] = v
        else:
            # Accelerate supports neither lists nor arbitrary objects: stringify them.
            filtered_args[k] = f"{v}"
    return filtered_args
class LossRecorder:
    """Track per-step losses and their running mean across epochs.

    During the first epoch (``epoch == 0``) every loss is appended in call order.
    In later epochs, the loss for index ``step`` overwrites the value stored for
    that step in the previous epoch. The average therefore covers roughly one
    epoch's worth of steps.
    """

    def __init__(self):
        self.loss_list: list[float] = []  # one slot per step index
        self.loss_total: float = 0.0  # running sum of loss_list

    def add(self, *, epoch: int, step: int, loss: float) -> None:
        """Record ``loss`` for ``step`` of ``epoch``.

        In epoch 0 the ``step`` argument is not used; losses are simply appended,
        which assumes steps arrive in order.
        """
        if epoch == 0:
            self.loss_list.append(loss)
        else:
            # A later epoch may run more steps than the first; pad with 0.0 so the
            # slot exists. Padded zeros count toward the average until overwritten.
            missing = step + 1 - len(self.loss_list)
            if missing > 0:
                self.loss_list.extend([0.0] * missing)
            self.loss_total -= self.loss_list[step]
            self.loss_list[step] = loss
        self.loss_total += loss

    @property
    def moving_average(self) -> float:
        """Mean of the recorded losses, or 0.0 if nothing has been recorded yet."""
        if not self.loss_list:
            # Avoid ZeroDivisionError when queried before the first add().
            return 0.0
        return self.loss_total / len(self.loss_list)
def get_epoch_ckpt_name(model_name, epoch_no: int):
    """Return the ``.safetensors`` file name for the checkpoint of ``epoch_no``."""
    stem = EPOCH_FILE_NAME.format(model_name, epoch_no)
    return f"{stem}.safetensors"
def get_step_ckpt_name(model_name, step_no: int):
    """Return the ``.safetensors`` file name for the checkpoint of ``step_no``."""
    stem = STEP_FILE_NAME.format(model_name, step_no)
    return f"{stem}.safetensors"
def get_last_ckpt_name(model_name):
    """Return the ``.safetensors`` file name for the final checkpoint (no epoch/step suffix)."""
    return "".join((model_name, ".safetensors"))
def get_remove_epoch_no(args: argparse.Namespace, epoch_no: int):
    """Return the epoch whose checkpoint falls out of the keep-last-N window, or None.

    With ``save_last_n_epochs`` unset, nothing is ever removed. Otherwise the
    candidate is the epoch ``save_every_n_epochs * save_last_n_epochs`` before
    ``epoch_no``. Negative candidates (too early in training) yield ``None``.
    """
    keep_n = args.save_last_n_epochs
    if keep_n is None:
        return None
    candidate = epoch_no - args.save_every_n_epochs * keep_n
    return None if candidate < 0 else candidate
def get_remove_step_no(args: argparse.Namespace, step_no: int):
    """Return the step whose checkpoint falls out of the keep-last-N-steps window, or None.

    The candidate is the step just before the window, rounded down to a multiple
    of ``save_every_n_steps``. For example, with save_every_n_steps=10 and
    save_last_n_steps=30, at step 50: 50 - 30 - 1 = 19 -> 10, so step 10 is removed.
    Negative candidates yield ``None``, as does an unset ``save_last_n_steps``.
    """
    keep_steps = args.save_last_n_steps
    if keep_steps is None:
        return None
    boundary = step_no - keep_steps - 1
    boundary -= boundary % args.save_every_n_steps
    if boundary >= 0:
        return boundary
    return None
def save_and_remove_state_on_epoch_end(args: argparse.Namespace, accelerator: accelerate.Accelerator, epoch_no: int):
    """Save the Accelerate training state for ``epoch_no`` and prune an older epoch state.

    The state is written with ``accelerator.save_state`` to
    ``<output_dir>/<output_name>-<epoch:06d>-state``. If ``--save_state_to_huggingface``
    is set, it is also uploaded via ``huggingface_utils.upload`` with the path
    ``/<state name>``.

    When a keep-last-N limit is configured (``save_last_n_epochs_state``, falling back
    to ``save_last_n_epochs``), the local state directory saved
    ``save_every_n_epochs * N`` epochs earlier is deleted if it exists. Only the local
    copy is removed; nothing is deleted from the Hub.
    """
    model_name = args.output_name

    logger.info("")
    logger.info(f"saving state at epoch {epoch_no}")
    os.makedirs(args.output_dir, exist_ok=True)

    state_name = EPOCH_STATE_NAME.format(model_name, epoch_no)
    state_dir = os.path.join(args.output_dir, state_name)
    accelerator.save_state(state_dir)

    if args.save_state_to_huggingface:
        logger.info("uploading state to huggingface.")
        huggingface_utils.upload(args, state_dir, "/" + state_name)

    # A falsy state-specific limit (None or 0) falls back to the general epoch limit.
    last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs
    if last_n_epochs is None:
        return

    remove_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epochs
    # Early in training there is nothing to prune yet. Without this guard a negative
    # epoch would be formatted into a nonsense directory name (e.g. "name--00002-state").
    # The check matches get_remove_epoch_no.
    if remove_epoch_no < 0:
        return

    state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no))
    if os.path.exists(state_dir_old):
        logger.info(f"removing old state: {state_dir_old}")
        shutil.rmtree(state_dir_old)
def save_and_remove_state_stepwise(args: argparse.Namespace, accelerator: accelerate.Accelerator, step_no: int):
    """Save the Accelerate training state for ``step_no`` and prune an older step state.

    The state is written with ``accelerator.save_state`` to
    ``<output_dir>/<output_name>-step<step:08d>-state``. If ``--save_state_to_huggingface``
    is set, it is also uploaded via ``huggingface_utils.upload``.

    When a keep-last-N-steps limit is configured (``save_last_n_steps_state``, falling
    back to ``save_last_n_steps``), the local state saved at the most recent multiple of
    ``save_every_n_steps`` before the keep window is deleted if it exists. Step 0 and
    negative steps are never removed.
    """
    model_name = args.output_name

    logger.info("")
    logger.info(f"saving state at step {step_no}")
    os.makedirs(args.output_dir, exist_ok=True)

    state_name = STEP_STATE_NAME.format(model_name, step_no)
    state_dir = os.path.join(args.output_dir, state_name)
    accelerator.save_state(state_dir)

    if args.save_state_to_huggingface:
        logger.info("uploading state to huggingface.")
        huggingface_utils.upload(args, state_dir, "/" + state_name)

    # A falsy state-specific limit (None or 0) falls back to the general step limit.
    keep_n = args.save_last_n_steps_state or args.save_last_n_steps
    if keep_n is None:
        return

    # Step just before the last `keep_n` steps, rounded down to a multiple of
    # save_every_n_steps: that is the state to delete.
    stale_step = step_no - keep_n - 1
    stale_step -= stale_step % args.save_every_n_steps
    if stale_step <= 0:
        return

    stale_dir = os.path.join(args.output_dir, STEP_STATE_NAME.format(model_name, stale_step))
    if not os.path.exists(stale_dir):
        return
    logger.info(f"removing old state: {stale_dir}")
    shutil.rmtree(stale_dir)
def save_state_on_train_end(args: argparse.Namespace, accelerator: accelerate.Accelerator):
    """Save the final Accelerate training state to ``<output_dir>/<output_name>-state``.

    If ``--save_state_to_huggingface`` is set, the saved directory is also uploaded via
    ``huggingface_utils.upload`` with the path ``/<output_name>-state``.
    """
    model_name = args.output_name

    logger.info("")
    logger.info("saving last state.")
    os.makedirs(args.output_dir, exist_ok=True)

    last_state_name = LAST_STATE_NAME.format(model_name)
    last_state_dir = os.path.join(args.output_dir, last_state_name)
    accelerator.save_state(last_state_dir)

    if not args.save_state_to_huggingface:
        return
    logger.info("uploading last state to huggingface.")
    huggingface_utils.upload(args, last_state_dir, "/" + last_state_name)