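"""Checkpoint-saving helpers for a Stable Diffusion WebUI merge extension.

Saves the currently loaded (merged) model to disk as .safetensors or .ckpt,
optionally fp16-converted and pruned, then restores the in-memory model.
"""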
import os
import torch
import safetensors.torch
import threading
from modules import shared, sd_hijack, sd_models
from modules.sd_models import load_model
import json
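
# Probe for SDXL support: modules.sd_models_xl only exists in newer WebUI builds,
# so `xl` records whether this install can handle SDXL checkpoints.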
try:
    from modules import sd_models_xl
    xl = True
except ImportError:
    xl = False
def prune_model(model, isxl=False):
    """Drop every tensor that does not belong to the UNet, the VAE, or the text encoder."""
    keys = list(model.keys())
    base_prefix = "conditioner." if isxl else "cond_stage_model."
    for k in keys:
        if "diffusion_model." not in k and "first_stage_model." not in k and base_prefix not in k:
            model.pop(k, None)
    return model
def to_half(sd):
    """Convert all float32/float64/bfloat16 model tensors to float16 in place."""
    for key in sd.keys():
        if 'model' in key and sd[key].dtype in {torch.float32, torch.float64, torch.bfloat16}:
            sd[key] = sd[key].half()
    return sd
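
# Illustrative sketch, not part of the original tool: the two helpers above can
# also slim a checkpoint on disk outside the WebUI. The function name and the
# output-path convention are assumptions made for this example only.
def example_prune_to_fp16(path, isxl=False):
    sd = safetensors.torch.load_file(path)           # load raw state dict
    sd = prune_model(to_half(sd), isxl)              # fp16-cast, then prune extras
    out = path.replace(".safetensors", ".fp16.safetensors")
    safetensors.torch.save_file(sd, out, metadata={"format": "pt"})
    return out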
def savemodel(state_dict, currentmodel, fname, savesets, metadata=None):
    # use None as the default to avoid Python's shared-mutable-default pitfall
    if metadata is None:
        metadata = {}
    other_dict = {}
    if state_dict is None:
        if shared.sd_model and shared.sd_model.sd_checkpoint_info:
            metadata = shared.sd_model.sd_checkpoint_info.metadata.copy()
        else:
            return "Current model is not a valid merged model"

        checkpoint_info = shared.sd_model.sd_checkpoint_info
        # check whether the current model carries a fake (merge-generated) checkpoint_info
        if checkpoint_info is not None:
            filename = checkpoint_info.filename
            name = os.path.basename(filename)
            info = sd_models.get_closet_checkpoint_match(name)
            if info == checkpoint_info:
                # this is a real checkpoint_info, so there is nothing to save
                return "Current model is not a merged model, or you have already saved it"

        # prepare metadata
        save_metadata = "save metadata" in savesets
        if save_metadata:
            # guard: the copied metadata may not contain merge information
            if "sd_merge_models" in metadata:
                metadata["sd_merge_models"] = json.dumps(metadata["sd_merge_models"])
        else:
            metadata = {"format": "pt"}

        if shared.sd_model is not None:
            print("load from shared.sd_model..")
            # restore the text encoder and any LoRA/network-patched weights
            sd_hijack.model_hijack.undo_hijack(shared.sd_model)
            for name, module in shared.sd_model.named_modules():
                if hasattr(module, "network_weights_backup"):
                    module = network_restore_weights_from_backup(module)
            state_dict = shared.sd_model.state_dict()
            # set aside the scheduler buffers (POPKEYS) so they are not written to disk;
            # the keys are snapshotted because entries are deleted while iterating
            for key in list(state_dict.keys()):
                if key in POPKEYS:
                    other_dict[key] = state_dict[key]
                    del state_dict[key]
            sd_hijack.model_hijack.hijack(shared.sd_model)
        else:
            return "No current loaded model found"

        # name_for_extra was set from currentmodel when the merged model was loaded
        currentmodel = checkpoint_info.name_for_extra
if "fp16" in savesets:
pre = ".fp16"
else:pre = ""
ext = ".safetensors" if "safetensors" in savesets else ".ckpt"
# is it a inpainting or instruct-pix2pix2 model?
if "model.diffusion_model.input_blocks.0.0.weight" in state_dict.keys():
shape = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape
if shape[1] == 9:
pre += "-inpainting"
if shape[1] == 8:
pre += "-instruct-pix2pix"
if not fname or fname == "":
fname = currentmodel.replace(" ","").replace(",","_").replace("(","_").replace(")","_")+pre+ext
if fname[0]=="_":fname = fname[1:]
else:
fname = fname if ext in fname else fname +pre+ext
fname = os.path.join(sd_models.model_path, fname)
fname = fname.replace("ProgramFiles_x86_","Program Files (x86)")
if len(fname) > 255:
fname.replace(ext,"")
fname=fname[:240]+ext
# check if output file already exists
if os.path.isfile(fname) and not "overwrite" in savesets:
_err_msg = f"Output file ({fname}) existed and was not saved]"
print(_err_msg)
return _err_msg
print("Saving...")
isxl = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight" in state_dict
if isxl:
# prune share memory tensors, "cond_stage_model." prefixed base tensors are share memory with "conditioner." prefixed tensors
for i, key in enumerate(state_dict.keys()):
if "cond_stage_model." in key:
del state_dict[key]
if "fp16" in savesets:
state_dict = to_half(state_dict)
if "prune" in savesets:
state_dict = prune_model(state_dict, isxl)
# for safetensors contiguous error
print("Check contiguous...")
for key in state_dict.keys():
v = state_dict[key]
v = v.contiguous()
state_dict[key] = v
try:
if ext == ".safetensors":
safetensors.torch.save_file(state_dict, fname, metadata=metadata)
else:
torch.save(state_dict, fname)
except Exception as e:
print(f"ERROR: Couldn't saved:{fname},ERROR is {e}")
return f"ERROR: Couldn't saved:{fname},ERROR is {e}"
print("Done!")
if other_dict:
for key in other_dict.keys():
state_dict[key] = other_dict[key]
del other_dict
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
return "Merged model saved in "+fname
def filenamecutter(name, model_a=False):
    if name == "" or name == []:
        return
    checkpoint_info = sd_models.get_closet_checkpoint_match(name)
    name = os.path.splitext(checkpoint_info.filename)[0]
    if not model_a:
        name = os.path.basename(name)
    return name
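
# e.g. filenamecutter("v1-5-pruned.safetensors") -> "v1-5-pruned", assuming the
# name resolves to a registered checkpoint; with model_a=True the directory part
# of the path is kept.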
from typing import Union
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
    """Restore a module's original weights from the backup taken before LoRA/network patching."""
    weights_backup = getattr(self, "network_weights_backup", None)
    bias_backup = getattr(self, "network_bias_backup", None)

    if weights_backup is None and bias_backup is None:
        return self

    with torch.no_grad():
        if weights_backup is not None:
            if isinstance(self, torch.nn.MultiheadAttention):
                # MultiheadAttention stores its backup as an (in_proj, out_proj) pair
                self.in_proj_weight = torch.nn.Parameter(weights_backup[0].detach().requires_grad_(self.in_proj_weight.requires_grad))
                self.out_proj.weight = torch.nn.Parameter(weights_backup[1].detach().requires_grad_(self.out_proj.weight.requires_grad))
            else:
                self.weight = torch.nn.Parameter(weights_backup.detach().requires_grad_(self.weight.requires_grad))

        if bias_backup is not None:
            if isinstance(self, torch.nn.MultiheadAttention):
                self.out_proj.bias = torch.nn.Parameter(bias_backup.detach().requires_grad_(self.out_proj.bias.requires_grad))
            else:
                self.bias = torch.nn.Parameter(bias_backup.detach().requires_grad_(self.bias.requires_grad))
        else:
            if isinstance(self, torch.nn.MultiheadAttention):
                self.out_proj.bias = None
            else:
                self.bias = None
    return self
def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
    self.network_current_names = ()
    self.network_weights_backup = None
    self.network_bias_backup = None
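
# DDPM schedule buffers. These are derived from the noise schedule and rebuilt
# when a model is constructed, so savemodel sets them aside in other_dict rather
# than writing them into the checkpoint, then restores them after saving.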
POPKEYS = [
    "betas",
    "alphas_cumprod",
    "alphas_cumprod_prev",
    "sqrt_alphas_cumprod",
    "sqrt_one_minus_alphas_cumprod",
    "log_one_minus_alphas_cumprod",
    "sqrt_recip_alphas_cumprod",
    "sqrt_recipm1_alphas_cumprod",
    "posterior_variance",
    "posterior_log_variance_clipped",
    "posterior_mean_coef1",
    "posterior_mean_coef2",
]