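"""Helpers for saving a merged Stable Diffusion checkpoint from an AUTOMATIC1111-style webui.

Provides pruning and fp16-conversion utilities, the savemodel() entry point, and
helpers that restore LoRA-patched weights before the state dict is read.
"""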
import os
import json
import threading
from typing import Union

import torch
import safetensors.torch

from modules import shared, sd_hijack, sd_models
from modules.sd_models import load_model
# SDXL support is only present in newer versions of the webui
try:
    from modules import sd_models_xl
    xl = True
except ImportError:
    xl = False
def prune_model(model, isxl=False):
    """Remove every tensor that is not part of the UNet, the VAE, or the text encoder."""
    keys = list(model.keys())
    base_prefix = "conditioner." if isxl else "cond_stage_model."
    for k in keys:
        if "diffusion_model." not in k and "first_stage_model." not in k and base_prefix not in k:
            model.pop(k, None)
    return model
def to_half(sd):
    """Cast float32/float64/bfloat16 model tensors to float16."""
    for key in sd.keys():
        if 'model' in key and sd[key].dtype in {torch.float32, torch.float64, torch.bfloat16}:
            sd[key] = sd[key].half()
    return sd
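# A minimal usage sketch (assumes `merged_sd` is a merged state dict built elsewhere):
#
#     merged_sd = prune_model(merged_sd, isxl=False)  # keep UNet / VAE / text encoder only
#     merged_sd = to_half(merged_sd)                  # fp32/fp64/bf16 tensors -> fp16
#     safetensors.torch.save_file(merged_sd, "merged.fp16.safetensors")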
def savemodel(state_dict, currentmodel, fname, savesets, metadata={}):
    """Save a merged state dict (or the currently loaded merged model) to disk."""
    other_dict = {}
    if state_dict is None:
        # no state dict was passed in: extract one from the currently loaded model
        if shared.sd_model and shared.sd_model.sd_checkpoint_info:
            metadata = shared.sd_model.sd_checkpoint_info.metadata.copy()
        else:
            return "Current model is not a valid merged model"

        checkpoint_info = shared.sd_model.sd_checkpoint_info

        # check if the current merged model is a fake checkpoint_info
        if checkpoint_info is not None:
            filename = checkpoint_info.filename
            name = os.path.basename(filename)
            info = sd_models.get_closet_checkpoint_match(name)
            if info == checkpoint_info:
                # this is a valid checkpoint_info; no need to save
                return "Current model is not a merged model or you've already saved it"

        # prepare metadata
        save_metadata = "save metadata" in savesets
        if save_metadata:
            metadata["sd_merge_models"] = json.dumps(metadata["sd_merge_models"])
        else:
            metadata = {"format": "pt"}

        if shared.sd_model is not None:
            print("load from shared.sd_model..")
            # restore the text encoder and any LoRA-patched weights before reading them
            sd_hijack.model_hijack.undo_hijack(shared.sd_model)
            for name, module in shared.sd_model.named_modules():
                if hasattr(module, "network_weights_backup"):
                    module = network_restore_weights_from_backup(module)
            state_dict = shared.sd_model.state_dict()
            # set aside the scheduler buffers; they are restored after saving
            for key in list(state_dict.keys()):
                if key in POPKEYS:
                    other_dict[key] = state_dict[key]
                    del state_dict[key]
            sd_hijack.model_hijack.hijack(shared.sd_model)
        else:
            return "No current loaded model found"

        # name_for_extra was set to the currentmodel when the merge was loaded
        currentmodel = checkpoint_info.name_for_extra
if "fp16" in savesets:
pre = ".fp16"
else:pre = ""
ext = ".safetensors" if "safetensors" in savesets else ".ckpt"
# is it a inpainting or instruct-pix2pix2 model?
if "model.diffusion_model.input_blocks.0.0.weight" in state_dict.keys():
shape = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape
if shape[1] == 9:
pre += "-inpainting"
if shape[1] == 8:
pre += "-instruct-pix2pix"
    if not fname or fname == "":
        fname = currentmodel.replace(" ", "").replace(",", "_").replace("(", "_").replace(")", "_") + pre + ext
        if fname[0] == "_":
            fname = fname[1:]
    else:
        fname = fname if ext in fname else fname + pre + ext

    fname = os.path.join(sd_models.model_path, fname)
    fname = fname.replace("ProgramFiles_x86_", "Program Files (x86)")

    # keep the path below the common 255-character filesystem limit
    if len(fname) > 255:
        fname = fname.replace(ext, "")
        fname = fname[:240] + ext
    # check if the output file already exists
    if os.path.isfile(fname) and "overwrite" not in savesets:
        _err_msg = f"Output file ({fname}) already exists and was not saved"
        print(_err_msg)
        return _err_msg
print("Saving...")
isxl = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight" in state_dict
if isxl:
# prune share memory tensors, "cond_stage_model." prefixed base tensors are share memory with "conditioner." prefixed tensors
for i, key in enumerate(state_dict.keys()):
if "cond_stage_model." in key:
del state_dict[key]
if "fp16" in savesets:
state_dict = to_half(state_dict)
if "prune" in savesets:
state_dict = prune_model(state_dict, isxl)
# for safetensors contiguous error
print("Check contiguous...")
for key in state_dict.keys():
v = state_dict[key]
v = v.contiguous()
state_dict[key] = v
    try:
        if ext == ".safetensors":
            safetensors.torch.save_file(state_dict, fname, metadata=metadata)
        else:
            torch.save(state_dict, fname)
    except Exception as e:
        print(f"ERROR: Couldn't save {fname}: {e}")
        return f"ERROR: Couldn't save {fname}: {e}"
print("Done!")
if other_dict:
for key in other_dict.keys():
state_dict[key] = other_dict[key]
del other_dict
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
return "Merged model saved in "+fname
def filenamecutter(name, model_a=False):
    """Return the checkpoint's filename without extension; strip the directory unless model_a."""
    if name == "" or name == []:
        return
    checkpoint_info = sd_models.get_closet_checkpoint_match(name)
    name = os.path.splitext(checkpoint_info.filename)[0]
    if not model_a:
        name = os.path.basename(name)
    return name
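# Usage sketch (illustrative): filenamecutter("mymodel.safetensors") would look the
# name up via sd_models.get_closet_checkpoint_match and return "mymodel".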
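# The helpers below appear to mirror the network weight-restore logic of the webui's
# built-in Lora extension: they put a module's original (pre-LoRA) weights back from
# the backups kept in network_weights_backup / network_bias_backup.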
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
    weights_backup = getattr(self, "network_weights_backup", None)
    bias_backup = getattr(self, "network_bias_backup", None)

    if weights_backup is None and bias_backup is None:
        return self

    with torch.no_grad():
        if weights_backup is not None:
            if isinstance(self, torch.nn.MultiheadAttention):
                # MultiheadAttention keeps two backed-up weights: in_proj and out_proj
                self.in_proj_weight = torch.nn.Parameter(weights_backup[0].detach().requires_grad_(self.in_proj_weight.requires_grad))
                self.out_proj.weight = torch.nn.Parameter(weights_backup[1].detach().requires_grad_(self.out_proj.weight.requires_grad))
            else:
                self.weight = torch.nn.Parameter(weights_backup.detach().requires_grad_(self.weight.requires_grad))

        if bias_backup is not None:
            if isinstance(self, torch.nn.MultiheadAttention):
                self.out_proj.bias = torch.nn.Parameter(bias_backup.detach().requires_grad_(self.out_proj.bias.requires_grad))
            else:
                self.bias = torch.nn.Parameter(bias_backup.detach().requires_grad_(self.bias.requires_grad))
        else:
            if isinstance(self, torch.nn.MultiheadAttention):
                self.out_proj.bias = None
            else:
                self.bias = None
    return self
def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
    self.network_current_names = ()
    self.network_weights_backup = None
    self.network_bias_backup = None
# DDPM noise-schedule buffers that live in the checkpoint but are not model weights;
# savemodel() pops them before saving and restores them afterwards
POPKEYS = [
    "betas",
    "alphas_cumprod",
    "alphas_cumprod_prev",
    "sqrt_alphas_cumprod",
    "sqrt_one_minus_alphas_cumprod",
    "log_one_minus_alphas_cumprod",
    "sqrt_recip_alphas_cumprod",
    "sqrt_recipm1_alphas_cumprod",
    "posterior_variance",
    "posterior_log_variance_clipped",
    "posterior_mean_coef1",
    "posterior_mean_coef2",
]