|
import os |
|
import torch |
|
import safetensors.torch |
|
import threading |
|
from modules import shared, sd_hijack, sd_models |
|
from modules.sd_models import load_model |
|
import json |
|
|
|
try: |
|
from modules import sd_models_xl |
|
xl = True |
|
except: |
|
xl = False |
|
|
|
def prune_model(model, isxl=False): |
|
keys = list(model.keys()) |
|
base_prefix = "conditioner." if isxl else "cond_stage_model." |
|
for k in keys: |
|
if "diffusion_model." not in k and "first_stage_model." not in k and base_prefix not in k: |
|
model.pop(k, None) |
|
return model |
|
|
|
def to_half(sd): |
|
for key in sd.keys(): |
|
if 'model' in key and sd[key].dtype in {torch.float32, torch.float64, torch.bfloat16}: |
|
sd[key] = sd[key].half() |
|
return sd |
|
|
|
def savemodel(state_dict,currentmodel,fname,savesets,metadata={}): |
|
other_dict = {} |
|
if state_dict is None: |
|
if shared.sd_model and shared.sd_model.sd_checkpoint_info: |
|
metadata = shared.sd_model.sd_checkpoint_info.metadata.copy() |
|
else: |
|
return "Current model is not a valid merged model" |
|
|
|
checkpoint_info = shared.sd_model.sd_checkpoint_info |
|
|
|
if checkpoint_info is not None: |
|
filename = checkpoint_info.filename |
|
name = os.path.basename(filename) |
|
info = sd_models.get_closet_checkpoint_match(name) |
|
if info == checkpoint_info: |
|
|
|
|
|
return "Current model is not a merged model or you've already saved model" |
|
|
|
|
|
save_metadata = "save metadata" in savesets |
|
if save_metadata: |
|
metadata["sd_merge_models"] = json.dumps(metadata["sd_merge_models"]) |
|
else: |
|
metadata = {"format": "pt"} |
|
|
|
if shared.sd_model is not None: |
|
print("load from shared.sd_model..") |
|
|
|
|
|
sd_hijack.model_hijack.undo_hijack(shared.sd_model) |
|
|
|
for name,module in shared.sd_model.named_modules(): |
|
if hasattr(module,"network_weights_backup"): |
|
module = network_restore_weights_from_backup(module) |
|
|
|
state_dict = shared.sd_model.state_dict() |
|
for key in list(state_dict.keys()): |
|
if key in POPKEYS: |
|
other_dict[key] = state_dict[key] |
|
del state_dict[key] |
|
|
|
sd_hijack.model_hijack.hijack(shared.sd_model) |
|
else: |
|
return "No current loaded model found" |
|
|
|
|
|
currentmodel = checkpoint_info.name_for_extra |
|
|
|
if "fp16" in savesets: |
|
pre = ".fp16" |
|
else:pre = "" |
|
ext = ".safetensors" if "safetensors" in savesets else ".ckpt" |
|
|
|
|
|
if "model.diffusion_model.input_blocks.0.0.weight" in state_dict.keys(): |
|
shape = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape |
|
if shape[1] == 9: |
|
pre += "-inpainting" |
|
if shape[1] == 8: |
|
pre += "-instruct-pix2pix" |
|
|
|
if not fname or fname == "": |
|
fname = currentmodel.replace(" ","").replace(",","_").replace("(","_").replace(")","_")+pre+ext |
|
if fname[0]=="_":fname = fname[1:] |
|
else: |
|
fname = fname if ext in fname else fname +pre+ext |
|
|
|
fname = os.path.join(sd_models.model_path, fname) |
|
fname = fname.replace("ProgramFiles_x86_","Program Files (x86)") |
|
|
|
if len(fname) > 255: |
|
fname.replace(ext,"") |
|
fname=fname[:240]+ext |
|
|
|
|
|
if os.path.isfile(fname) and not "overwrite" in savesets: |
|
_err_msg = f"Output file ({fname}) existed and was not saved]" |
|
print(_err_msg) |
|
return _err_msg |
|
|
|
print("Saving...") |
|
isxl = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight" in state_dict |
|
if isxl: |
|
|
|
for i, key in enumerate(state_dict.keys()): |
|
if "cond_stage_model." in key: |
|
del state_dict[key] |
|
|
|
if "fp16" in savesets: |
|
state_dict = to_half(state_dict) |
|
if "prune" in savesets: |
|
state_dict = prune_model(state_dict, isxl) |
|
|
|
|
|
print("Check contiguous...") |
|
for key in state_dict.keys(): |
|
v = state_dict[key] |
|
v = v.contiguous() |
|
state_dict[key] = v |
|
|
|
try: |
|
if ext == ".safetensors": |
|
safetensors.torch.save_file(state_dict, fname, metadata=metadata) |
|
else: |
|
torch.save(state_dict, fname) |
|
except Exception as e: |
|
print(f"ERROR: Couldn't saved:{fname},ERROR is {e}") |
|
return f"ERROR: Couldn't saved:{fname},ERROR is {e}" |
|
print("Done!") |
|
if other_dict: |
|
for key in other_dict.keys(): |
|
state_dict[key] = other_dict[key] |
|
del other_dict |
|
load_model(checkpoint_info, already_loaded_state_dict=state_dict) |
|
return "Merged model saved in "+fname |
|
|
|
def filenamecutter(name,model_a = False): |
|
if name =="" or name ==[]: return |
|
checkpoint_info = sd_models.get_closet_checkpoint_match(name) |
|
name= os.path.splitext(checkpoint_info.filename)[0] |
|
|
|
if not model_a: |
|
name = os.path.basename(name) |
|
return name |
|
|
|
from typing import Union |
|
|
|
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]): |
|
weights_backup = getattr(self, "network_weights_backup", None) |
|
bias_backup = getattr(self, "network_bias_backup", None) |
|
|
|
if weights_backup is None and bias_backup is None: |
|
return self |
|
|
|
with torch.no_grad(): |
|
if weights_backup is not None: |
|
if isinstance(self, torch.nn.MultiheadAttention): |
|
self.in_proj_weight = torch.nn.Parameter(weights_backup[0].detach().requires_grad_(self.in_proj_weight.requires_grad)) |
|
self.out_proj.weight = torch.nn.Parameter(weights_backup[1].detach().requires_grad_(self.out_proj.weight.requires_grad)) |
|
else: |
|
self.weight = torch.nn.Parameter(weights_backup.detach().requires_grad_(self.weight.requires_grad)) |
|
|
|
if bias_backup is not None: |
|
if isinstance(self, torch.nn.MultiheadAttention): |
|
self.out_proj.bias = torch.nn.Parameter(bias_backup.detach().requires_grad_(self.out_proj.bias.requires_grad)) |
|
else: |
|
self.bias = torch.nn.Parameter(bias_backup.detach().requires_grad_(self.bias.requires_grad)) |
|
else: |
|
if isinstance(self, torch.nn.MultiheadAttention): |
|
self.out_proj.bias = None |
|
else: |
|
self.bias = None |
|
return self |
|
|
|
def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]): |
|
self.network_current_names = () |
|
self.network_weights_backup = None |
|
self.network_bias_backup = None |
|
|
|
POPKEYS=[ |
|
"betas", |
|
"alphas_cumprod", |
|
"alphas_cumprod_prev", |
|
"sqrt_alphas_cumprod", |
|
"sqrt_one_minus_alphas_cumprod", |
|
"log_one_minus_alphas_cumprod", |
|
"sqrt_recip_alphas_cumprod", |
|
"sqrt_recipm1_alphas_cumprod", |
|
"posterior_variance", |
|
"posterior_log_variance_clipped", |
|
"posterior_mean_coef1", |
|
"posterior_mean_coef2", |
|
] |