|
import os, sys |
|
|
|
# Put the current working directory at the very front of the import path so
# the `lycoris` package imported further down is resolved from there first
# (presumably so the script can be run from the repository root).
sys.path[0:0] = [os.getcwd()]
|
import argparse |
|
|
|
|
|
def get_args(): |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument( |
|
"base_model", help="The model you want to merge with loha", default="", type=str |
|
) |
|
parser.add_argument( |
|
"lycoris_model", |
|
help="the lyco model you want to merge into sd model", |
|
default="", |
|
type=str, |
|
) |
|
parser.add_argument( |
|
"output_name", help="the output model", default="./out.pt", type=str |
|
) |
|
parser.add_argument( |
|
"--is_v2", |
|
help="Your base model is sd v2 or not", |
|
default=False, |
|
action="store_true", |
|
) |
|
parser.add_argument( |
|
"--is_sdxl", |
|
help="Your base/db model is sdxl or not", |
|
default=False, |
|
action="store_true", |
|
) |
|
parser.add_argument( |
|
"--device", |
|
help="Which device you want to use to merge the weight", |
|
default="cpu", |
|
type=str, |
|
) |
|
parser.add_argument("--dtype", help="dtype to save", default="float", type=str) |
|
parser.add_argument( |
|
"--weight", help="weight for the lyco model to merge", default="1.0", type=float |
|
) |
|
return parser.parse_args() |
|
|
|
|
|
# Parse the command line once, before the heavier lycoris/torch imports below
# (presumably so `--help` and argument errors are reported quickly).
# `ARGS` and `args` are two names for the same namespace; both are used.
ARGS = get_args()
args = ARGS
|
|
|
|
|
from lycoris.utils import merge |
|
from lycoris.kohya.model_utils import ( |
|
load_models_from_stable_diffusion_checkpoint, |
|
save_stable_diffusion_checkpoint, |
|
load_file, |
|
) |
|
from lycoris.kohya.sdxl_model_util import ( |
|
load_models_from_sdxl_checkpoint, |
|
save_stable_diffusion_checkpoint as save_sdxl_checkpoint, |
|
) |
|
|
|
import torch |
|
|
|
|
|
@torch.no_grad() |
|
def main(): |
|
if args.is_sdxl: |
|
base = load_models_from_sdxl_checkpoint( |
|
None, args.base_model, map_location=args.device |
|
) |
|
else: |
|
base = load_models_from_stable_diffusion_checkpoint(args.is_v2, args.base_model) |
|
if ARGS.lycoris_model.rsplit(".", 1)[-1] == "safetensors": |
|
lyco = load_file(ARGS.lycoris_model) |
|
else: |
|
lyco = torch.load(ARGS.lycoris_model) |
|
|
|
dtype_str = ARGS.dtype.replace("fp", "float").replace("bf", "bfloat") |
|
dtype = { |
|
"float": torch.float, |
|
"float16": torch.float16, |
|
"float32": torch.float32, |
|
"float64": torch.float64, |
|
"bfloat": torch.bfloat16, |
|
"bfloat16": torch.bfloat16, |
|
}.get(dtype_str, None) |
|
if dtype is None: |
|
raise ValueError(f'Cannot Find the dtype "{dtype}"') |
|
|
|
if args.is_sdxl: |
|
base_tes = [base[0], base[1]] |
|
base_unet = base[3] |
|
else: |
|
base_tes = [base[0]] |
|
base_unet = base[2] |
|
|
|
merge(base_tes, base_unet, lyco, ARGS.weight, ARGS.device) |
|
|
|
if args.is_sdxl: |
|
save_sdxl_checkpoint( |
|
ARGS.output_name, |
|
base[0].cpu(), |
|
base[1].cpu(), |
|
base[3].cpu(), |
|
0, |
|
0, |
|
None, |
|
base[2], |
|
getattr(base[1], "logit_scale", None), |
|
dtype, |
|
) |
|
else: |
|
save_stable_diffusion_checkpoint( |
|
ARGS.is_v2, |
|
ARGS.output_name, |
|
base[0].cpu(), |
|
base[2].cpu(), |
|
None, |
|
0, |
|
0, |
|
dtype, |
|
base[1], |
|
) |
|
|
|
|
|
if __name__ == "__main__": |
|
main() |
|
|