NagisaNao commited on
Commit
41ef973
·
verified ·
1 Parent(s): 75f4564

Upload mergers.py

Browse files
sagemaker/fixing/extensions/supermerger/scripts/mergers/mergers.py ADDED
@@ -0,0 +1,1448 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import os
3
+ import gc
4
+ import hashlib
5
+ import numpy as np
6
+ import os.path
7
+ import re
8
+ import torch
9
+ import tqdm
10
+ import datetime
11
+
12
+ import csv
13
+ import json
14
+ import launch
15
+ import torch.nn as nn
16
+ import scipy.ndimage
17
+ from copy import deepcopy
18
+ from PIL import Image, ImageFont, ImageDraw
19
+ from tqdm import tqdm
20
+ from functools import partial
21
+ from torch import Tensor, lerp
22
+ from torch.nn.functional import cosine_similarity, relu, softplus
23
+ from modules import shared, processing, sd_models, sd_vae, images, sd_samplers, scripts,devices, extras
24
+ from modules.ui import plaintext_to_html
25
+ from modules.shared import opts
26
+ from modules.processing import create_infotext,Processed
27
+ from modules.sd_models import load_model,unload_model_weights
28
+ from modules.generation_parameters_copypaste import create_override_settings_dict
29
+ from scripts.mergers.model_util import filenamecutter,savemodel
30
+ from math import ceil
31
+ import sys
32
+ from multiprocessing import cpu_count
33
+ from threading import Lock
34
+ from concurrent.futures import ThreadPoolExecutor, as_completed
35
+ from scripts.mergers.bcolors import bcolors
36
+ import collections
37
+
38
# Turn the web UI git tag into an integer: the part before the first "-" with
# "v" and "." removed (e.g. "v1.5.1-xyz" -> 151).
try:
    ui_version = int(launch.git_tag().split("-",1)[0].replace("v","").replace(".",""))
except Exception:
    # Tag unavailable or not numeric: fall back to 100. Only ordinary errors
    # are absorbed; KeyboardInterrupt/SystemExit still propagate.
    ui_version = 100
42
+
43
+ orig_cache = 0
44
+
45
+ modelcache = collections.OrderedDict()
46
+
47
+ from inspect import currentframe
48
+
49
+ SELFKEYS = ["to_out","proj_out","norm"]
50
+
51
+ module_path = os.path.dirname(os.path.abspath(sys.modules[__name__].__file__))
52
+ scriptpath = os.path.dirname(module_path)
53
+
54
def tryit(func):
    """Call ``func()`` on a best-effort basis, ignoring ordinary failures.

    Args:
        func: zero-argument callable to invoke.

    Returns:
        None. The return value of ``func`` is discarded.

    Any ``Exception`` raised by ``func`` is swallowed. ``KeyboardInterrupt``
    and ``SystemExit`` are deliberately NOT swallowed so that the user can
    still interrupt or shut down the process.
    """
    try:
        func()
    except Exception:
        pass
59
+
60
+ stopmerge = False
61
+
62
def freezemtime():
    """Raise the module-level ``stopmerge`` flag.

    The merge routines in this module poll ``stopmerge`` and return
    "STOPPED" when it is set.
    """
    global stopmerge
    stopmerge = True
65
+
66
+ mergedmodel=[]
67
+ FINETUNEX = ["IN","OUT","OUT2","CONT","BRI","COL1","COL2","COL3"]
68
+ TYPESEG = ["none","alpha","beta (if Triple or Twice is not selected,Twice automatically enable)","alpha and beta","seed",
69
+ "mbw alpha","mbw beta","mbw alpha and beta", "model_A","model_B","model_C","pinpoint blocks (alpha or beta must be selected for another axis)",
70
+ "include blocks", "exclude blocks","add include", "add exclude","elemental","add elemental","pinpoint element","effective elemental checker","adjust","pinpoint adjust (IN,OUT,OUT2,CONT,BRI,COL1,COL2,COL3)",
71
+ "calcmode","prompt","random"]
72
+ TYPES = ["none","alpha","beta","alpha and beta","seed", "mbw alpha ","mbw beta","mbw alpha and beta",
73
+ "model_A","model_B","model_C","pinpoint blocks","include blocks","exclude blocks","add include", "add exclude","elemental","add elemental","pinpoint element",
74
+ "effective","adjust","pinpoint adjust","calcmode","prompt","random"]
75
+ MODES=["Weight" ,"Add" ,"Triple","Twice"]
76
+ SAVEMODES=["save model", "overwrite"]
77
+ EXCLUDE_CHOICES = ["BASE","IN00","IN01","IN02","IN03","IN04","IN05","IN06","IN07","IN08","IN09","IN10","IN11",
78
+ "M00","OUT00","OUT01","OUT02","OUT03","OUT04","OUT05","OUT06","OUT07","OUT08","OUT09","OUT10","OUT11",
79
+ "Adjust","VAE"]
80
+ CHCKPOINT_DICT_SKIP_ON_MERGE = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
81
+
82
+ #type[0:alpha,1:beta,2:seed,3:mbw,4:model_A,5:model_B,6:model_C]
83
+ #msettings=[0 weights_a,1 weights_b,2 model_a,3 model_b,4 model_c,5 base_alpha,6 base_beta,7 mode,8 useblocks,9 custom_name,10 save_sets,11 id_sets,12 wpresets]
84
+ #id sets "image", "PNG info","XY grid"
85
+
86
+ hear = False
87
+ hearm = False
88
+ NON4 = [None]*4
89
+
90
+ informer = sd_models.get_closet_checkpoint_match
91
+
92
+ #msettings=[weights_a,weights_b,model_a,model_b,model_c,device,base_alpha,base_beta,mode,loranames,useblocks,custom_name,save_sets,id_sets,wpresets,deep]
93
+
94
def smergegen(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode,
              calcmode,useblocks,custom_name,save_sets,id_sets,wpresets,deep,tensor,bake_in_vae,opt_value,inex,ex_blocks,ex_elems,
              esettings,
              s_prompt,s_nprompt,s_steps,s_sampler,s_cfg,s_seed,s_w,s_h,s_batch_size,
              genoptions,s_hrupscaler,s_hr2ndsteps,s_denois_str,s_hr_scale,
              lmode,lsets,llimits_u,llimits_l,lseed,lserial,lcustom,lround,
              currentmodel,imggen,
              *txt2imgparams):
    """Merge models with ``smerge``, load the result as the active model and optionally generate images.

    The merge parameters are forwarded to ``smerge`` unchanged (``tensor`` is
    passed in the position of its ``fine`` argument). The ``l*`` arguments are
    packed into the ``lucks`` dict handed to ``smerge``. The ``s_*`` arguments,
    ``genoptions`` and ``*txt2imgparams`` are only forwarded to ``simggen``.

    Returns:
        * ``(result, "not loaded", None, None, None, None)`` when ``smerge``
          reports "ERROR" or "STOPPED";
        * ``(result, currentmodel, *images[:4])`` when ``imggen`` is truthy;
        * ``(result, currentmodel)`` otherwise.
    """
    lucks = {"on":False, "mode":lmode,"set":lsets,"upp":llimits_u,"low":llimits_l,"seed":lseed,"num":lserial,"cust":lcustom,"round":int(lround)}
    deepprint = "print change" in esettings

    cachedealer(True)

    result,currentmodel,modelid,theta_0,metadata = smerge(
        weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode,calcmode,
        useblocks,custom_name,save_sets,id_sets,wpresets,deep,tensor,bake_in_vae,opt_value,inex,ex_blocks,ex_elems,deepprint,lucks
    )

    if "ERROR" in result or "STOPPED" in result:
        # NOTE(review): cachedealer(False) is not called on this path - confirm
        # whether the cache state should be restored here as well.
        return result,"not loaded",*NON4

    checkpoint_info = sd_models.get_closet_checkpoint_match(model_a)

    if ui_version >= 150: checkpoint_info = fake_checkpoint_info(checkpoint_info,metadata,currentmodel)

    save = SAVEMODES[0] in save_sets

    result = savemodel(theta_0,currentmodel,custom_name,save_sets,metadata) if save else "Merged model loaded:"+currentmodel

    # Re-initialise the web UI model holder, then load the merged state dict.
    sd_models.model_data.__init__()
    load_model(checkpoint_info, already_loaded_state_dict=theta_0)

    cachedealer(False)

    del theta_0
    devices.torch_gc()

    debug = "debug" in save_sets

    if ("copy config" in save_sets) and ("(" not in result):
        # The last three arguments are the A, B and C checkpoints. The original
        # passed model B twice, so model C was never consulted. model_c may be
        # empty (e.g. plain weighted merge); pass None in that case.
        info_c = informer(model_c) if model_c else None
        extras.create_config(result.replace("Merged model saved in ",""), 0, informer(model_a), informer(model_b), info_c)

    if imggen:
        images = simggen(s_prompt,s_nprompt,s_steps,s_sampler,s_cfg,s_seed,s_w,s_h,s_batch_size,
                         genoptions,s_hrupscaler,s_hr2ndsteps,s_denois_str,s_hr_scale,
                         currentmodel,id_sets,modelid,
                         *txt2imgparams,debug = debug)

        return result,currentmodel,*images[:4]
    else:
        return result,currentmodel
145
+
146
def checkpointer_infomer(name):
    """Return whatever ``sd_models.get_closet_checkpoint_match`` yields for ``name``."""
    match = sd_models.get_closet_checkpoint_match(name)
    return match
148
+
149
+ # XXX hack. fake checkpoint_info
150
def fake_checkpoint_info(checkpoint_info,metadata,currentmodel):
    """Build a renamed copy of ``checkpoint_info`` that stands in for the merged model.

    The input object is deep-copied and not modified. The copy gets a sha256
    derived from the JSON-serialised ``metadata``, name/title fields derived
    from ``currentmodel``, and the given ``metadata``. It is registered in
    ``sd_models.checkpoints_list`` under its new title, and the hash cache
    entry for its name is overwritten and dumped.

    Returns:
        The modified copy.
    """
    from modules import cache
    write_cache = cache.dump_cache
    read_cache = cache.cache

    fake = deepcopy(checkpoint_info)

    # The hash depends only on the merge metadata, not on the tensor contents.
    digest = hashlib.sha256(json.dumps(metadata).encode("utf-8")).hexdigest()
    fake.sha256 = digest
    fake.name_for_extra = currentmodel

    fake.name = fake.name_for_extra + ".safetensors"
    fake.model_name = fake.name_for_extra.replace("/", "_").replace("\\", "_")
    fake.title = f"{fake.name} [{digest[0:10]}]"
    fake.metadata = metadata

    # for sd-webui v1.5.x
    sd_models.checkpoints_list[fake.title] = fake

    # Force the new sha256 into the hash cache and persist it.
    if read_cache is not None:
        hash_table = read_cache("hashes")
        hash_table[f"checkpoint/{fake.name}"] = {
            "mtime": os.path.getmtime(fake.filename),
            "sha256": digest,
        }
        write_cache()

    # Identifiers under which the fake checkpoint info can be looked up.
    fake.ids = [fake.model_name, fake.name, fake.name_for_extra]
    return fake
182
+
183
+ NUM_INPUT_BLOCKS = 12
184
+ NUM_MID_BLOCK = 1
185
+ NUM_OUTPUT_BLOCKS = 12
186
+ NUM_TOTAL_BLOCKS = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + NUM_OUTPUT_BLOCKS
187
+ BLOCKID=["BASE","IN00","IN01","IN02","IN03","IN04","IN05","IN06","IN07","IN08","IN09","IN10","IN11","M00","OUT00","OUT01","OUT02","OUT03","OUT04","OUT05","OUT06","OUT07","OUT08","OUT09","OUT10","OUT11"]
188
+ BLOCKIDXLL=["BASE","IN00","IN01","IN02","IN03","IN04","IN05","IN06","IN07","IN08","M00","OUT00","OUT01","OUT02","OUT03","OUT04","OUT05","OUT06","OUT07","OUT08","VAE"]
189
+ BLOCKIDXL=['BASE', 'IN0', 'IN1', 'IN2', 'IN3', 'IN4', 'IN5', 'IN6', 'IN7', 'IN8', 'M', 'OUT0', 'OUT1', 'OUT2', 'OUT3', 'OUT4', 'OUT5', 'OUT6', 'OUT7', 'OUT8', 'VAE']
190
+
191
+ RANDMAP = [0,50,100] #alpha,beta,elements
192
+
193
+ statistics = {"sum":{},"mean":{},"max":{},"min":{}}
194
+
195
+ ################################################
196
+ ##### Main Merging Code
197
+
198
def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode,calcmode,
           useblocks,custom_name,save_sets,id_sets,wpresets,deep,fine,bake_in_vae,opt_value,inex,ex_blocks,ex_elems,deepprint,lucks,main = (False,False,False)):
    """Merge up to three checkpoints into a new state dict.

    Outline of what the code does:
      1. Normalise inputs (string flags, random seed in ``lucks``, block
         weight presets) and record the settings in the global ``mergedmodel``.
      2. Load model B (``theta_1``), for "Add" mode subtract model C from it,
         then load model A (``theta_0``).
      3. Stage 1/2: for every "model" weight/bias key of A that B also has,
         pick the per-block alpha/beta and apply the selected ``calcmode``.
      4. Stage 2/2: copy keys that only B has into the result, optionally bake
         in a VAE, write the merge log and build the metadata.

    Returns:
        ``("", currentmodel, modelid, theta_0, metadata)`` on success, or
        ``(message, None, None, None, None)`` where ``message`` starts with
        "ERROR" or is "STOPPED".

    Raises:
        RuntimeError: when B is the inpainting / instruct-pix2pix model and A
            is the normal one.
        AssertionError: when the channel mismatch between A and B is not one
            of the supported inpainting / instruct-pix2pix layouts.
    """
    global hear,mergedmodel,stopmerge,statistics

    caster("merge start",hearm)
    stopmerge = False

    debug = "debug" in save_sets
    uselerp = "use old calc method" not in save_sets
    device = "cuda" if "use cuda" in save_sets else "cpu"

    unload_model_weights(sd_models.model_data.sd_model)

    # for from file: values may arrive as strings
    if type(useblocks) is str:
        useblocks = True if useblocks =="True" else False
    if type(base_alpha) == str:base_alpha = float(base_alpha)
    if type(base_beta) == str:base_beta = float(base_beta)

    #random
    if lucks != {}:
        if lucks["seed"] == -1: lucks["ceed"] = str(random.randrange(4294967294))
        else: lucks["ceed"] = lucks["seed"]
    else: lucks["ceed"] = 0
    np.random.seed(int(lucks["ceed"]))
    randomer = np.random.rand(2500)

    cachetarget =[]
    for model,num in zip([model_a,model_b,model_c],main):
        if model != "" and num:
            cachetarget.append(model)

    weights_a,deep = randdealer(weights_a,randomer,0,lucks,deep)
    weights_b,_ = randdealer(weights_b,randomer,1,lucks,None)

    weights_a_orig = weights_a
    weights_b_orig = weights_b

    # preset to weights
    if wpresets != False and useblocks:
        weights_a = wpreseter(weights_a,wpresets)
        weights_b = wpreseter(weights_b,wpresets)

    # mode select booleans
    usebeta = MODES[2] in mode or MODES[3] in mode or "tensor" in calcmode
    metadata = {"format": "pt"}

    if (calcmode == "trainDifference" or calcmode == "extract") and "Add" not in mode:
        print(f"{bcolors.WARNING}Mode changed to add difference{bcolors.ENDC}")
        mode = "Add"
        if model_c == "" or model_c is None:
            #fallback to avoid crash
            model_c = model_a
            print(f"{bcolors.WARNING}Substituting empty model_c with model_a{bcolors.ENDC}")

    if not useblocks:
        weights_a = weights_b = ""
    #for save log and save current model
    mergedmodel =[weights_a,weights_b,
                  hashfromname(model_a),hashfromname(model_b),hashfromname(model_c),
                  base_alpha,base_beta,mode,useblocks,custom_name,save_sets,id_sets,deep,calcmode,lucks["ceed"],fine,opt_value,inex,ex_blocks,ex_elems].copy()

    model_a = namefromhash(model_a)
    model_b = namefromhash(model_b)
    model_c = namefromhash(model_c)

    caster(mergedmodel,False)

    #elementals
    if len(deep) > 0:
        deep = deep.replace("\n",",")
        deep = deep.replace(calcmode+",","")
        deep = deep.split(",")

    #format check
    if model_a =="" or model_b =="" or ((not MODES[0] in mode) and model_c=="") :
        return "ERROR: Necessary model is not selected",*NON4

    #for MBW text to list: first value is the base ratio, the rest are per-block weights
    if useblocks:
        weights_a_t=weights_a.split(',',1)
        weights_b_t=weights_b.split(',',1)
        base_alpha = float(weights_a_t[0])
        weights_a = [float(w) for w in weights_a_t[1].split(',')]
        caster(f"from {weights_a_t}, alpha = {base_alpha},weights_a ={weights_a}",hearm)
        if not (len(weights_a) == 25 or len(weights_a) == 19):return f"ERROR: weights alpha value must be 20 or 26.",*NON4
        if usebeta:
            base_beta = float(weights_b_t[0])
            weights_b = [float(w) for w in weights_b_t[1].split(',')]
            caster(f"from {weights_b_t}, beta = {base_beta},weights_a ={weights_b}",hearm)
            if not(len(weights_b) == 25 or len(weights_b) == 19): return f"ERROR: weights beta value must be 20 or 26.",*NON4

    caster("model load start",hearm)
    printstart(model_a,model_b,model_c,base_alpha,base_beta,weights_a,weights_b,mode,useblocks,calcmode,deep,lucks['ceed'],fine,inex,ex_blocks,ex_elems)

    theta_1=load_model_weights_m(model_b,2,cachetarget,device).copy()

    # The presence of this key is used to detect an XL checkpoint.
    isxl = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight" in theta_1.keys()

    #adjust
    if fine.rstrip(",0") != "":
        fine = fineman(fine,isxl)
    else:
        fine = ""

    if isxl and useblocks:
        if len(weights_a) == 25:
            weights_a = weighttoxl(weights_a)
            print(f"weight converted for XL{weights_a}")
        if usebeta:
            if len(weights_b) == 25:
                weights_b = weighttoxl(weights_b)
                print(f"weight converted for XL{weights_b}")
        # NOTE(review): nesting of the next two lines was reconstructed from a
        # de-indented source - confirm they belong inside this branch.
        if len(weights_a) == 19: weights_a = weights_a + [0]
        if len(weights_b) == 19: weights_b = weights_b + [0]

    if MODES[1] in mode:#Add
        if stopmerge: return "STOPPED", *NON4
        theta_2 = load_model_weights_m(model_c,3,cachetarget,device).copy()
        if not (calcmode == "trainDifference" or calcmode == "extract"):
            # Plain add-difference: turn theta_1 into (B - C) up front and
            # drop C. trainDifference/extract keep theta_2 for stage 1.
            for key in tqdm(theta_1.keys()):
                if 'model' in key:
                    if key in theta_2:
                        theta_1[key] = theta_1[key]- theta_2[key]
                    else:
                        theta_1[key] = torch.zeros_like(theta_1[key])
            del theta_2

    if stopmerge: return "STOPPED", *NON4

    if "tensor" in calcmode or "self" in calcmode:
        # These modes get cloned tensors for A.
        theta_t = load_model_weights_m(model_a,1,cachetarget,device).copy()
        theta_0 ={}
        for key in theta_t:
            theta_0[key] = theta_t[key].clone()
        del theta_t
    else:
        theta_0=load_model_weights_m(model_a,1,cachetarget,device).copy()

    if MODES[2] in mode or MODES[3] in mode:#Tripe or Twice
        theta_2 = load_model_weights_m(model_c,3,cachetarget,device).copy()
    else:
        if not (calcmode == "trainDifference" or calcmode == "extract"):
            theta_2 = {}

    alpha = base_alpha
    beta = base_beta

    ex_elems = ex_elems.split(",")

    keyratio = []
    key_and_alpha = {}

    ##### Stage 0/2 in Cosine
    if "cosine" in calcmode:
        sim, sims = precosine("A" in calcmode,theta_0,theta_1)

    ##### Stage 1/2

    for num, key in enumerate(tqdm(theta_0.keys(), desc="Stage 1/2")):
        if stopmerge: return "STOPPED", *NON4
        if not ("model" in key and key in theta_1): continue
        if not ("weight" in key or "bias" in key): continue
        if calcmode == "trainDifference" or calcmode == "extract":
            if key not in theta_2:
                continue
        else:
            if usebeta and (not key in theta_2) and (not theta_2 == {}) :
                continue

        weight_index = -1
        current_alpha = alpha
        current_beta = beta

        a = list(theta_0[key].shape)
        b = list(theta_1[key].shape)

        # this enables merging an inpainting model (A) with another one (B);
        # where normal model would have 4 channels, for latent space, inpainting model would
        # have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
        if a != b and a[0:1] + a[2:] == b[0:1] + b[2:]:
            if a[1] == 4 and b[1] == 9:
                raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.")
            if a[1] == 4 and b[1] == 8:
                raise RuntimeError("When merging instruct-pix2pix model with a normal one, A must be the instruct-pix2pix model.")

            if a[1] == 8 and b[1] == 4:#If we have an Instruct-Pix2Pix model...
                result_is_instruct_pix2pix_model = True
            else:
                # a and b are plain lists here, so index them directly
                # (the original used a.shape/b.shape and raised AttributeError).
                assert a[1] == 9 and b[1] == 4, f"Bad dimensions for merged layer {key}: A={a}, B={b}"
                result_is_inpainting_model = True

        block,blocks26 = blockfromkey(key,isxl)
        if block == "Not Merge": continue
        if inex != "Off" and (ex_blocks or (ex_elems != [""])) and excluder(blocks26,inex,ex_blocks,ex_elems,key): continue
        weight_index = BLOCKIDXLL.index(blocks26) if isxl else BLOCKID.index(blocks26)

        if useblocks:
            # Index 0 is BASE, which keeps the base alpha/beta.
            if weight_index > 0:
                current_alpha = weights_a[weight_index - 1]
                if usebeta:
                    current_beta = weights_b[weight_index - 1]

        if len(deep) > 0:
            current_alpha = elementals(key,weight_index,deep,randomer,num,lucks,deepprint,current_alpha)

        keyratio.append([key,current_alpha, current_beta])

        if calcmode == "normal":
            if a != b and a[0:1] + a[2:] == b[0:1] + b[2:]:
                # Merge only the vectors the models have in common. Otherwise we get an error due to dimension mismatch.
                theta_0_a = theta_0[key][:, 0:4, :, :]
            else:
                theta_0_a = theta_0[key]

            if MODES[1] in mode:#Add
                caster(f"{num}, {block}, {model_a}+{current_alpha}+*({model_b}-{model_c}),{key}",hear)
                theta_0_a = theta_0_a + current_alpha * theta_1[key]

            elif MODES[2] in mode:#Triple
                caster(f"{num}, {block}, {model_a}+{1-current_alpha-current_beta}+{model_b}*{current_alpha}+ {model_c}*{current_beta}",hear)
                if uselerp and current_alpha + current_beta != 0:
                    theta_0_a =lerp(theta_0_a.to(torch.float32),lerp(theta_1[key].to(torch.float32),theta_2[key].to(torch.float32),current_beta/(current_alpha + current_beta)),current_alpha + current_beta).to(theta_0_a.dtype)
                else:
                    theta_0_a = (1 - current_alpha-current_beta) * theta_0_a + current_alpha * theta_1[key]+current_beta * theta_2[key]

            elif MODES[3] in mode:#Twice
                caster(f"{num}, {block}, {key},{model_a} + {1-current_alpha} + {model_b}*{current_alpha}",hear)
                caster(f"{num}, {block}, {key}({model_a}+{model_b}) +{1-current_beta}+{model_c}*{current_beta}",hear)
                if uselerp:
                    theta_0_a = torch.lerp(torch.lerp(theta_0_a.to(torch.float32), theta_1[key].to(torch.float32), current_alpha), theta_2[key].to(torch.float32), current_beta).to(theta_0_a.dtype)
                else:
                    theta_0_a = (1 - current_alpha) * theta_0_a + current_alpha * theta_1[key]
                    theta_0_a = (1 - current_beta) * theta_0_a + current_beta * theta_2[key]

            else:#Weight
                if current_alpha == 1:
                    caster(f"{num}, {block}, {key} alpha = 1,{model_a}={model_b}",hear)
                    theta_0_a = theta_1[key]
                elif current_alpha !=0:
                    caster(f"{num}, {block}, {key}, {model_a}*{1-current_alpha}+{model_b}*{current_alpha}",hear)
                    if uselerp:
                        theta_0_a = torch.lerp(theta_0_a.to(torch.float32), theta_1[key].to(torch.float32), current_alpha).to(theta_0_a.dtype)
                    else:
                        theta_0_a = (1 - current_alpha) * theta_0_a + current_alpha * theta_1[key]

            if a != b and a[0:1] + a[2:] == b[0:1] + b[2:]:
                theta_0[key][:, 0:4, :, :] = theta_0_a
            else:
                theta_0[key] = theta_0_a

            del theta_0_a, a, b

        elif "cosine" in calcmode:
            if "first_stage_model" in key: continue
            cosine(calcmode,key,sim,sims,current_alpha,theta_0,theta_1,num,block,uselerp)

        elif calcmode == "trainDifference":
            # rtol=0, atol=0 makes this an exact-equality test between B and C.
            if torch.allclose(theta_1[key].float(), theta_2[key].float(), rtol=0, atol=0):
                theta_2[key] = theta_0[key]
                continue
            traindiff(key,current_alpha,theta_0,theta_1,theta_2)

        elif calcmode == "smoothAdd":
            caster(f"{num}, {block}, model A[{key}] + {current_alpha} + * (model B - model C)[{key}]", hear)
            # Apply median filter to the weight differences
            filtered_diff = scipy.ndimage.median_filter(theta_1[key].to(torch.float32).cpu().numpy(), size=3)
            # Apply Gaussian filter to the filtered differences
            filtered_diff = scipy.ndimage.gaussian_filter(filtered_diff, sigma=1)
            theta_1[key] = torch.tensor(filtered_diff)
            # Add the filtered differences to the original weights
            theta_0[key] = theta_0[key] + current_alpha * theta_1[key]

        elif calcmode == "smoothAdd MT":
            # Only collect the ratios here; the filtering runs after the loop.
            key_and_alpha[key] = current_alpha

        elif "tensor" in calcmode:
            dim = theta_0[key].dim()
            if dim == 0 : continue
            tensormerge("2" not in calcmode,key,dim,theta_0,theta_1,current_alpha,current_beta)

        elif "extract" == calcmode:
            theta_0[key] = extract_super(theta_0[key],theta_1[key],theta_2[key],current_alpha,current_beta,opt_value)

        elif calcmode == "self":
            if any(selfkey in key for selfkey in SELFKEYS):continue
            if current_alpha == 0: continue
            theta_0[key] = (theta_0[key].clone()) * current_alpha

        elif calcmode == "plus random":
            if any(selfkey in key for selfkey in SELFKEYS):continue
            if current_alpha == 0: continue
            theta_0[key] += torch.randn_like(theta_0[key].clone()) * current_alpha

        ##### Adjust
        if any(item in key for item in FINETUNES) and fine:
            index = FINETUNES.index(key)
            if 5 > index :
                theta_0[key] =theta_0[key]* fine[index]
            else :theta_0[key] =theta_0[key] + torch.tensor(fine[5]).to(theta_0[key].device)

    if calcmode == "smoothAdd MT":
        # setting threads to higher than 8 doesn't significantly affect the time for merging
        threads = cpu_count()
        tasks_per_thread = 8

        theta_0, theta_1, stopped = multithread_smoothadd(key_and_alpha, theta_0, theta_1, threads, tasks_per_thread, hear)
        if stopped:
            return "STOPPED", *NON4

    currentmodel = makemodelname(weights_a,weights_b,model_a, model_b,model_c, base_alpha,base_beta,useblocks,mode,calcmode)

    # Copy over "model" keys that exist only in B.
    for key in tqdm(theta_1.keys(), desc="Stage 2/2"):
        if key in CHCKPOINT_DICT_SKIP_ON_MERGE:
            continue
        if "model" in key and key not in theta_0:
            theta_0.update({key:theta_1[key]})

    del theta_1
    if calcmode == "trainDifference" or calcmode == "extract":
        del theta_2

    ##### BakeVAE
    bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None)
    if bake_in_vae_filename is not None:
        print(f"Baking in VAE from {bake_in_vae_filename}")
        vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu')

        for key in vae_dict.keys():
            theta_0_key = 'first_stage_model.' + key
            if theta_0_key in theta_0:
                theta_0[theta_0_key] = vae_dict[key]

        del vae_dict

    modelid = rwmergelog(currentmodel,mergedmodel)
    # lucks may be {} plus "ceed" (see the random section above), so "set" can be absent.
    if "save E-list" in lucks.get("set", []): saveekeys(keyratio,modelid)

    caster(mergedmodel,False)
    if "Reset CLIP ids" in save_sets: resetclip(theta_0)

    # always set metadata. savemodel() will check save_sets later
    merge_recipe = {
        "type": "sd-webui-supermerger",
        "weights_alpha": weights_a if useblocks else None,
        "weights_beta": weights_b if useblocks else None,
        "weights_alpha_orig": weights_a_orig if useblocks else None,
        "weights_beta_orig": weights_b_orig if useblocks else None,
        "model_a": longhashfromname(model_a),
        "model_b": longhashfromname(model_b),
        "model_c": longhashfromname(model_c),
        "base_alpha": base_alpha,
        "base_beta": base_beta,
        "mode": mode,
        "mbw": useblocks,
        "elemental_merge": deep,
        "calcmode" : calcmode,
        f"{inex}":ex_blocks + ex_elems
    }
    metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
    metadata["sd_merge_models"] = {}

    def add_model_metadata(checkpoint_name):
        # Record name and hashes of one source checkpoint, keyed by its sha256.
        checkpoint_info = sd_models.get_closet_checkpoint_match(checkpoint_name)
        checkpoint_info.calculate_shorthash()
        metadata["sd_merge_models"][checkpoint_info.sha256] = {
            "name": checkpoint_name,
            "legacy_hash": checkpoint_info.hash
        }

    if model_a:
        add_model_metadata(model_a)
    if model_b:
        add_model_metadata(model_b)
    if model_c:
        add_model_metadata(model_c)

    metadata["sd_merge_models"] = json.dumps(metadata["sd_merge_models"])

    return "",currentmodel,modelid,theta_0,metadata
586
+
587
+ ################################################
588
+ ##### cosineA/B
589
def precosine(calcmode,theta_0,theta_1):
    """Collect per-key similarity values between two state dicts (cosine stage 0/2).

    Args:
        calcmode: truthy for the "A" variant (cosine similarity of the
            L2-normalised tensors); falsy for the other variant (mean of the
            cosine similarity and the normalised dot product of the raw tensors).
        theta_0, theta_1: mappings from key to tensor.

    Returns:
        ``(sim, sims)`` - the ``torch.nn.CosineSimilarity(dim=0)`` module and a
        1-D float64 array of the collected similarities with NaNs removed and
        values outside the 1st..99th percentile dropped.

    When ``stopmerge`` is set, the values collected so far are returned
    untrimmed. The result is always a 2-tuple so the caller's
    ``sim, sims = precosine(...)`` never fails to unpack.
    """
    sim = torch.nn.CosineSimilarity(dim=0)
    sims = np.array([], dtype=np.float64)

    for key in tqdm(theta_0.keys(), desc="Stage 0/2"):
        if stopmerge:
            return sim, sims
        # skip VAE model parameters to get better results
        if "first_stage_model" in key: continue
        if not ("model" in key and key in theta_1): continue

        t0 = theta_0[key].to(torch.float32)
        t1 = theta_1[key].to(torch.float32)
        if calcmode: #favors modelA's structure with details from B
            t0_norm = nn.functional.normalize(t0, p=2, dim=0)
            t1_norm = nn.functional.normalize(t1, p=2, dim=0)
            value = sim(t0_norm, t1_norm)
        else: #favors modelB's structure with details from A
            simab = sim(t0, t1)
            dot_product = torch.dot(t0.view(-1), t1.view(-1))
            magnitude_similarity = dot_product / (torch.norm(t0) * torch.norm(t1))
            value = (simab + magnitude_similarity) / 2.0
        sims = np.append(sims, value.cpu().numpy())

    sims = sims[~np.isnan(sims)]
    # np.percentile cannot handle an empty array, so only trim when there is data.
    if sims.size:
        sims = np.delete(sims, np.where(sims < np.percentile(sims, 1, method='midpoint')))
    if sims.size:
        sims = np.delete(sims, np.where(sims > np.percentile(sims, 99, method='midpoint')))
    return sim, sims
622
+
623
def cosine(mode,key,sim,sims,current_alpha,theta_0,theta_1,num,block,uselerp):
    """Blend ``theta_1[key]`` into ``theta_0[key]`` using a similarity-derived ratio.

    The combined similarity of the two tensors is rescaled against the
    min/max of ``sims``, reduced by alpha (``abs(current_alpha)`` when "A" is
    in ``mode``) and clipped to [0, 1] to give ``k``. ``theta_0[key]`` is then
    replaced by the interpolation from ``theta_1[key]`` (k = 0) to
    ``theta_0[key]`` (k = 1). Nothing happens unless "model" is in ``key`` and
    ``key`` is in ``theta_0``. Returns None.
    """
    favors_a = "A" in mode
    if not ("model" in key and key in theta_0):
        return

    if favors_a:
        # "A" variant: compare the L2-normalised tensors.
        lhs = nn.functional.normalize(theta_0[key].to(torch.float32), p=2, dim=0)
        rhs = nn.functional.normalize(theta_1[key].to(torch.float32), p=2, dim=0)
    else:
        # Other variant: compare the raw tensors.
        lhs = theta_0[key].to(torch.float32)
        rhs = theta_1[key].to(torch.float32)

    simab = sim(lhs, rhs)
    dot_product = torch.dot(lhs.view(-1), rhs.view(-1))
    magnitude_similarity = dot_product / (torch.norm(lhs) * torch.norm(rhs))
    combined_similarity = (simab + magnitude_similarity) / 2.0

    k = (combined_similarity - sims.min()) / (sims.max() - sims.min())
    k = k - (abs(current_alpha) if favors_a else current_alpha)
    k = k.clip(min=0,max=1.0)

    if favors_a:
        caster(f"{num}, {block}, model A[{key}] {1-k} + (model B)[{key}]*{k}",hear)
    else:
        caster(f"{num}, {block}, model A[{key}] *{1-k} + (model B)[{key}]*{k}",hear)

    if uselerp:
        theta_0[key] = lerp(theta_1[key].to(torch.float32), theta_0[key].to(torch.float32),k).to(theta_0[key].dtype)
    else:
        theta_0[key] = theta_1[key] * (1 - k) + theta_0[key] * k
658
+
659
+ ################################################
660
+ ##### Traindiff
661
def traindiff(key,current_alpha,theta_0,theta_1,theta_2):
    """Add a distance-weighted share of ``theta_1 - theta_2`` to ``theta_0[key]``.

    For every element the size of the added step is ``|theta_1 - theta_2|``
    scaled by ``|theta_1 - theta_0| / (|theta_1 - theta_2| + |theta_1 - theta_0|)``,
    carries the sign of ``theta_1 - theta_2`` and is finally multiplied by
    ``current_alpha * 1.8``.  The result is stored back into ``theta_0[key]``;
    nothing is returned.
    """
    b_val = theta_1[key].float()
    b_minus_c = b_val - theta_2[key].float()

    gap_bc = torch.abs(b_minus_c)
    gap_ba = torch.abs(b_val - theta_0[key].float())
    total_gap = gap_bc + gap_ba

    # Where both gaps are zero the ratio would be 0/0, so the weight is forced to 0 there.
    weight = torch.where(total_gap != 0, gap_ba / total_gap, torch.tensor(0.).float())
    weight = torch.sign(b_minus_c) * torch.abs(weight)

    step = weight * gap_bc
    theta_0[key] = theta_0[key] + (step * (current_alpha*1.8))
676
+
677
+ ################################################
678
+ ##### Extract
679
+ # fix for python 3.9.16
680
+ from typing import Union, Optional
681
def extract_super(base: Optional[Tensor], a: Tensor, b: Tensor, alpha: float, beta: float, gamma: float) -> Tensor:
    """Blend ``a`` and ``b`` (relative to ``base``) gated by their cosine similarity.

    ``base`` may be None, in which case 0 is used as the origin and the result
    takes ``a``'s dtype; otherwise the result takes ``base``'s dtype.  The
    cosine similarity is taken along the last dimension, mapped to [0, 1],
    raised to ``gamma`` and then interpolated towards its complement by
    ``beta``.  That gate multiplies the ``alpha``-interpolation of the two
    offsets before the origin is added back.
    """
    # ``Optional[Tensor]`` is used instead of ``Tensor | None`` so the signature parses on Python 3.9.
    assert base is None or base.shape == a.shape
    assert a.shape == b.shape
    assert 0 <= alpha <= 1
    assert 0 <= beta <= 1
    assert 0 <= gamma

    if base is None:
        out_dtype = a.dtype
        origin = 0
    else:
        out_dtype = base.dtype
        origin = base.float()

    delta_a = a.float() - origin
    delta_b = b.float() - origin
    cos = cosine_similarity(delta_a, delta_b, -1).clamp(-1, 1).unsqueeze(-1)
    agreement = ((cos + 1) / 2) ** gamma
    blended = lerp(delta_a, delta_b, alpha)
    gate = lerp(agreement, 1 - agreement, beta)
    result = origin + blended * gate
    return result.to(out_dtype)
696
+
697
def extract(a: Tensor, b: Tensor, p: float, smoothness: float) -> Tensor:
    """Scale ``a`` by a mask derived from its cosine similarity with ``b``.

    The similarity along the last dimension is passed through ``relu`` (when
    ``smoothness`` is 0) or ``softplus`` with ``beta = 1 / smoothness``,
    broadcast back over the last dimension, and interpolated towards its
    complement ``1 - mask`` by ``p``.  Returns ``a * mask``.
    """
    assert a.shape == b.shape
    assert 0 <= p <= 1
    assert 0 <= smoothness <= 1

    if smoothness == 0:
        squash = relu
    else:
        squash = partial(softplus, beta=1 / smoothness)

    similarity = squash(cosine_similarity(a, b, dim=-1))
    mask = similarity.unsqueeze(dim=-1).repeat_interleave(b.shape[-1], -1)
    mask = torch.lerp(mask, torch.ones_like(mask) - mask, p)
    return a * mask
706
+
707
+ ################################################
708
+ ##### Tensor Merge
709
def tensormerge(mode,key,dim, theta_0,theta_1,current_alpha,current_beta):
    """Swap a contiguous index window between ``theta_0[key]`` and ``theta_1[key]``.

    * ``mode`` truthy: the window runs along dimension 0.
    * ``mode`` falsy: the window runs along dimension 1 for ``dim`` 2-4 and
      along dimension 0 for ``dim`` 1.  The window bounds are computed from
      ``shape[1]`` only when ``dim > 1`` and ``shape[1] > 100``; otherwise
      they are computed from ``shape[0]``.

    When ``current_alpha + current_beta <= 1`` the window
    ``[beta, alpha + beta)`` of ``theta_0[key]`` is overwritten in place with
    ``theta_1``'s values.  Otherwise a clone of ``theta_1[key]`` receives
    ``theta_0``'s values in the window ``[alpha + beta - 1, beta)`` and
    replaces ``theta_0[key]``.  A ``dim`` outside 1-4 copies no window.
    """
    dest = theta_0[key]
    fits = current_alpha+current_beta <= 1

    along_second = (not mode) and dim > 1
    extent = dest.shape[0]
    if along_second and dest.shape[1] > 100:
        extent = dest.shape[1]

    if fits:
        start = int(extent*(current_beta))
        stop = int(extent*(current_alpha+current_beta))
    else:
        start = int(extent*(current_alpha+current_beta-1))
        stop = int(extent*(current_beta))

    # One slice tuple stands in for the per-rank ``[s:e, :, ...]`` / ``[:, s:e, ...]`` spellings.
    if dim not in (1, 2, 3, 4):
        window = None
    elif along_second:
        window = (slice(None), slice(start, stop))
    else:
        window = (slice(start, stop),)

    if fits:
        if window is not None:
            dest[window] = theta_1[key][window].clone()
        return

    merged = theta_1[key].clone()
    if window is not None:
        merged[window] = dest[window].clone()
    theta_0[key] = merged
783
+
784
+ ################################################
785
+ ##### Multi Thread SmoothAdd
786
+
787
def multithread_smoothadd(key_and_alpha, theta_0, theta_1, threads, tasks_per_thread, hear):
    """Apply the smoothed add step to many keys using a thread pool.

    For every key in ``key_and_alpha`` the tensor ``theta_1[key]`` is run
    through a size-3 median filter and a sigma-1 Gaussian filter, written
    back into ``theta_1``, and ``key_and_alpha[key] * theta_1[key]`` is added
    to ``theta_0[key]``.

    Args:
        key_and_alpha: mapping from key to the multiplier applied for that key.
        theta_0: mapping that receives the summed result.
        theta_1: mapping holding the tensors to filter (overwritten with the filtered values).
        threads: ``max_workers`` passed to the thread pool.
        tasks_per_thread: number of keys handed to each submitted task.
        hear: forwarded to ``caster`` to decide whether progress lines are printed.

    Returns:
        ``(theta_0, theta_1, stopped)`` where ``stopped`` is True when a task
        returned False because ``stopmerge`` was set.
    """
    lock_theta_0 = Lock()
    lock_theta_1 = Lock()
    lock_progress = Lock()

    def thread_callback(keys):
        # Process one chunk of keys; returns False without doing work when stopmerge is set.
        nonlocal theta_0, theta_1
        if stopmerge:
            return False

        for key in keys:
            caster(f"model A[{key}] + {key_and_alpha[key]} + * (model B - model C)[{key}]", hear)
            # Filtering is done by scipy on a float32 numpy copy taken on the CPU.
            filtered_diff = scipy.ndimage.median_filter(theta_1[key].to(torch.float32).cpu().numpy(), size=3)
            filtered_diff = scipy.ndimage.gaussian_filter(filtered_diff, sigma=1)
            with lock_theta_1:
                theta_1[key] = torch.tensor(filtered_diff)
            with lock_theta_0:
                theta_0[key] = theta_0[key] + key_and_alpha[key] * theta_1[key]

        # The progress bar is advanced once per chunk, by the chunk size.
        with lock_progress:
            progress.update(len(keys))

        return True

    def extract_and_remove(input_list, count):
        # Pop and return the first ``count`` items of ``input_list`` (mutates the list).
        extracted = input_list[:count]
        del input_list[:count]

        return extracted

    keys = list(key_and_alpha.keys())

    # Number of tasks submitted to the pool (not the number of worker threads).
    total_threads = ceil(len(keys) / int(tasks_per_thread))
    print(f"max threads = {threads}, total threads = {total_threads}, tasks per thread = {tasks_per_thread}")

    progress = tqdm(key_and_alpha.keys(), desc="smoothAdd MT")

    futures = []
    with ThreadPoolExecutor(max_workers=threads) as executor:
        futures = [executor.submit(thread_callback, extract_and_remove(keys, int(tasks_per_thread))) for i in range(total_threads)]
        for future in as_completed(futures):
            # A False result means the task saw stopmerge; report the merge as stopped.
            if not future.result():
                executor.shutdown()
                return theta_0, theta_1, True
    del progress

    return theta_0, theta_1, False
834
+
835
+ ################################################
836
+ ##### Elementals
837
def elementals(key,weight_index,deep,randomer,num,lucks,deepprint,current_alpha):
    """Return the ratio for ``key`` after applying the elemental rules in ``deep``.

    Each rule has the form ``blocks:words:ratio`` (exactly two colons; other
    entries are ignored).  ``blocks`` is expanded with ``blocker`` and both
    lists are split on spaces; a leading ``NOT`` inverts the respective match.
    A rule applies when both the block test and the word test succeed against
    ``key + BLOCKID[weight_index]``; its ratio (resolved by ``eratiodealer``)
    then replaces ``current_alpha``.  Later matching rules override earlier ones.
    """
    skey = key + BLOCKID[weight_index]
    for rule in deep:
        if rule.count(":") != 2:
            continue
        block_spec, word_spec, ratio = rule.split(":")

        block_names = blocker(block_spec,BLOCKID).split(" ")
        words = word_spec.split(" ")

        blocks_negated = block_names[0] == "NOT"
        if blocks_negated:
            block_names = block_names[1:]
        words_negated = words[0] == "NOT"
        if words_negated:
            words = words[1:]

        # A hit flips the negation flag: plain lists need a hit, NOT-lists need none.
        block_ok = any(name in skey for name in block_names) != blocks_negated
        if not block_ok:
            continue
        word_ok = any(word in skey for word in words) != words_negated
        if not word_ok:
            continue

        ratio = eratiodealer(ratio,randomer,weight_index,num,lucks)
        if deepprint:
            print(" ", block_names,words,key,ratio)
        current_alpha = ratio
    return current_alpha
860
+
861
def forkforker(filename,device):
    """Read a checkpoint's state dict, retrying without ``map_location`` on failure.

    The first attempt asks ``sd_models.read_state_dict`` to map tensors to
    ``device``; if that raises, the file is read again with the loader's
    default placement.
    """
    try:
        return sd_models.read_state_dict(filename,map_location = device)
    except Exception:
        # ``Exception`` rather than a bare ``except`` so Ctrl-C / SystemExit are not
        # swallowed and turned into a second (slow) load attempt.
        return sd_models.read_state_dict(filename)
866
+
867
+ ################################################
868
+ ##### Load Model
869
+
870
def load_model_weights_m(model,abc,cachetarget,device):
    """Return the state dict for ``model``, using and maintaining ``modelcache``.

    Args:
        model: name passed to ``sd_models.get_closet_checkpoint_match``.
        abc: threshold compared against ``orig_cache``; the loaded weights are
            cached only when ``orig_cache >= abc``.
        cachetarget: model names whose cache entries must not be evicted.
        device: device the returned tensors are moved to / loaded on.

    Returns:
        A mapping of parameter name to tensor.  Cache hits return copies moved
        to ``device``; cache misses return the freshly loaded state dict.
    """
    checkpoint_info = sd_models.get_closet_checkpoint_match(model)
    sd_model_name = checkpoint_info.model_name

    if checkpoint_info in modelcache:
        print(f"Loading weights [{sd_model_name}] from cache")
        return {k: v.to(device) for k, v in modelcache[checkpoint_info].items()}

    print(f"Loading weights [{sd_model_name}] from file")
    state_dict = forkforker(checkpoint_info.filename,device)
    if orig_cache >= abc:
        # The cache keeps CPU copies; the caller gets the dict as loaded.
        modelcache[checkpoint_info] = {k: v.to("cpu") for k, v in state_dict.items()}

    protected = [sd_models.get_closet_checkpoint_match(name) for name in cachetarget]
    while len(modelcache) > orig_cache:
        victim = next((cached for cached in modelcache if cached not in protected), None)
        if victim is None:
            # Every remaining entry is protected: stop instead of looping forever.
            break
        modelcache.pop(victim)
    return state_dict
892
+
893
def makemodelname(weights_a,weights_b,model_a, model_b,model_c, alpha,beta,useblocks,mode,calc):
    """Build the human-readable formula string that names a merged model.

    The three model names are shortened with ``filenamecutter``; string
    ``alpha``/``beta`` are converted to float.  The formula depends on which
    of ``MODES[1]`` (add), ``MODES[2]`` (triple) or ``MODES[3]`` (twice)
    occurs in ``mode`` (anything else is treated as a plain weighted sum) and
    on whether block weights are in use.  A ``calc`` other than ``"normal"``
    is appended, and ``"tensor"`` additionally appends the beta value.
    """
    model_a, model_b, model_c = [filenamecutter(m) for m in (model_a, model_b, model_c)]

    if type(alpha) is str:
        alpha = float(alpha)
    if type(beta) is str:
        beta = float(beta)

    if not useblocks:
        if MODES[1] in mode:#add
            label =f"{model_a} + ({model_b} - {model_c}) x {str(round(alpha,3))}"
        elif MODES[2] in mode:#triple
            label =f"{model_a} x {str(round(1-alpha-beta,3))} + {model_b} x {str(round(alpha,3))} + {model_c} x {str(round(beta,3))}"
        elif MODES[3] in mode:#twice
            label =f"({model_a} x {str(round(1-alpha,3))} +{model_b} x {str(round(alpha,3))}) x {str(round(1-beta,3))} + {model_c} x {str(round(beta,3))}"
        else:
            label =f"{model_a} x {str(round(1-alpha,3))} + {model_b} x {str(round(alpha,3))}"
    else:
        if MODES[1] in mode:#add
            label =f"{model_a} + ({model_b} - {model_c}) x alpha ({str(round(alpha,3))},{','.join(str(s) for s in weights_a)})"
        elif MODES[2] in mode:#triple
            label =f"{model_a} x (1-alpha-beta) + {model_b} x alpha + {model_c} x beta (alpha = {str(round(alpha,3))},{','.join(str(s) for s in weights_a)},beta = {beta},{','.join(str(s) for s in weights_b)})"
        elif MODES[3] in mode:#twice
            label =f"({model_a} x (1-alpha) + {model_b} x alpha)x(1-beta)+ {model_c} x beta ({str(round(alpha,3))},{','.join(str(s) for s in weights_a)})_({str(round(beta,3))},{','.join(str(s) for s in weights_b)})"
        else:
            label =f"{model_a} x (1-alpha) + {model_b} x alpha ({str(round(alpha,3))},{','.join(str(s) for s in weights_a)})"

    if calc != "normal":
        label = label + "_" + calc
        if calc == "tensor":
            label = label + f"_beta_{beta}"
    return label
924
+
925
# Directory returned by ``scripts.basedir()``; rwmergelog keeps mergehistory.csv directly under it.
path_root = scripts.basedir()
926
+
927
+
928
+ ################################################
929
+ ##### Logging
930
+
931
def rwmergelog(mergedname = "",settings= [],id = 0):
    """Append a merge record to, or read a record from, ``mergehistory.csv``.

    The file lives in ``path_root`` and is created with a header row when it
    does not exist yet.

    * ``mergedname != ""``: a row ``ID, time, mergedname, *settings`` is
      appended and the new ID (the number of rows already present, header
      included) is returned.
    * otherwise: the parsed row at index ``id`` is returned as a list of
      strings, or the string ``"ERROR: OUT of ID index"`` when ``id`` is not
      a usable index.

    ``settings`` is copied before use, so neither the caller's list nor the
    shared default list is modified.
    """
    # for compatible
    mode_info = {
        "Weight sum": "Weight sum:A*(1-alpha)+B*alpha",
        "Add difference": "Add difference:A+(B-C)*alpha",
        "Triple sum": "Triple sum:A*(1-alpha-beta)+B*alpha+C*beta",
        "sum Twice": "sum Twice:(A*(1-alpha)+B*alpha)*(1-beta)+C*beta",
    }
    setting = settings.copy()
    if len(setting) > 7 and setting[7] in mode_info:
        setting[7] = mode_info[setting[7]] # fix mode entry for compatible
    filepath = os.path.join(path_root, "mergehistory.csv")
    if not os.path.isfile(filepath):
        with open(filepath, 'a') as f:
            #msettings=[0 weights_a,1 weights_b,2 model_a,3 model_b,4 model_c,5 base_alpha,6 base_beta,7 mode,8 useblocks,9 custom_name,10 save_sets,11 id_sets, 12 deep 13 calcmode]
            f.writelines('"ID","time","name","weights alpha","weights beta","model A","model B","model C","alpha","beta","mode","use MBW","plus lora","custum name","save setting","use ID"\n')
    with open(filepath, 'r+') as f:
        mlist = list(csv.reader(f))
        if mergedname != "":
            # Reading to the end leaves the file position at EOF, so the write below appends.
            mergeid = len(mlist)
            setting.insert(0,mergedname)
            for i,x in enumerate(setting):
                field = str(x)
                if "," in field or "\n" in field or '"' in field:
                    # Embedded quotes are doubled so csv.reader can parse the row back.
                    setting[i] = '"' + field.replace('"', '""') + '"'
            text = ",".join(map(str, setting))
            text=str(mergeid)+","+datetime.datetime.now().strftime('%Y.%m.%d %H.%M.%S.%f')[:-7]+"," + text + "\n"
            f.writelines(text)
            return mergeid
    try:
        out = mlist[int(id)]
    except (IndexError, ValueError, TypeError):
        out = "ERROR: OUT of ID index"
    return out
965
+
966
def saveekeys(keyratio,modelid):
    """Write the rows in ``keyratio`` to ``<data dir>/<modelid>.csv``.

    The data directory is ``extensions/sd-webui-supermerger/scripts/data``
    under ``scripts.basedir()`` and is created when missing.  An existing
    file of the same name is overwritten.
    """
    import csv
    path_root = scripts.basedir()
    dir_path = os.path.join(path_root,"extensions","sd-webui-supermerger","scripts", "data")

    # exist_ok avoids the check-then-create race of a separate os.path.exists() test.
    os.makedirs(dir_path, exist_ok=True)

    filepath = os.path.join(dir_path,f"{modelid}.csv")

    with open(filepath, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerows(keyratio)
979
+
980
def savestatics(modelid):
    """Dump every group of the module-level ``statistics`` mapping to its own CSV.

    Each group becomes ``<modelid>_<group>.csv`` (via ``saveekeys``) with one
    row per entry: the entry's key followed by its values.
    """
    for group, entries in statistics.items():
        rows = [[name] + list(values) for name, values in entries.items()]
        saveekeys(rows,f"{modelid}_{group}")
984
+
985
def get_font(fontsize):
    """Load a TrueType font at ``fontsize``.

    ``opts.font`` is tried first when it is set; if loading it fails for any
    reason, the ``Roboto-Regular.ttf`` file under ``scriptpath`` is used.
    """
    bundled = os.path.join(scriptpath, "Roboto-Regular.ttf")
    try:
        chosen = opts.font or bundled
        return ImageFont.truetype(chosen, fontsize)
    except Exception:
        return ImageFont.truetype(bundled, fontsize)
991
+
992
def draw_origin(grid, text,width,height,width_one):
    """Return a copy of ``grid`` with ``text`` drawn in black at the top-left corner.

    The starting font size is ``(width + height) // 25``.  When the image is
    wider than a single cell (``grid.width != width_one``) the font is shrunk
    one point at a time until the text is no wider than 75% of ``width_one``
    or the size reaches 1.
    """
    canvas = Image.new("RGB", (grid.width,grid.height), "white")
    canvas.paste(grid,(0,0))

    d = ImageDraw.Draw(canvas)
    color_active = (0, 0, 0)
    fontsize = (width+height)//25
    fnt = get_font(fontsize)

    def text_width(font):
        # multiline_textsize was removed in Pillow 10; multiline_textbbox exists since 8.0.
        if hasattr(d, "multiline_textbbox"):
            left, _top, right, _bottom = d.multiline_textbbox((0, 0), text, font=font)
            return right - left
        return d.multiline_textsize(text, font=font)[0]

    if grid.width != width_one:
        # Stop at size 1: requesting a size-0 font is an error.
        while text_width(fnt) > width_one*0.75 and fontsize > 1:
            fontsize -=1
            fnt = get_font(fontsize)
    d.multiline_text((0,0), text, font=fnt, fill=color_active,align="center")
    return canvas
1007
+
1008
def wpreseter(w,presets):
    """Resolve a preset name to its weight string.

    ``w`` is returned unchanged when it is empty or already contains a comma.
    Otherwise ``presets`` is parsed line by line as ``name:value`` and/or
    ``name<TAB>value`` pairs (names are stripped, values are kept verbatim);
    if the stripped ``w`` is a known name, its value is printed and returned.
    """
    if "," in w or w == "":
        return w

    table = {}
    for line in presets.splitlines():
        # A line may match both separators; each match registers its own entry.
        for separator in (":", "\t"):
            if separator in line:
                name, value = line.split(separator, 1)
                table[name.strip()] = value

    wanted = w.strip()
    if wanted not in table:
        return w
    name = w
    w = table[wanted]
    print(f"weights {name} imported from presets : {w}")
    return w
1024
+
1025
def fullpathfromname(name):
    """Return the checkpoint file path for ``name``, or "" when ``name`` is empty.

    ``name`` is resolved with ``sd_models.get_closet_checkpoint_match``.
    """
    # The guard must test the argument; testing the builtin ``hash`` never matched.
    if name == "" or name == []: return ""
    checkpoint_info = sd_models.get_closet_checkpoint_match(name)
    return checkpoint_info.filename
1029
+
1030
def namefromhash(hash):
    """Return the model name of the checkpoint matching ``hash``.

    An empty string or empty list yields "" without any lookup.
    """
    if not (hash == "" or hash == []):
        return sd_models.get_closet_checkpoint_match(hash).model_name
    return ""
1034
+
1035
def hashfromname(name):
    """Return the short hash of the checkpoint matching ``name``.

    An empty string or empty list yields "".  When the checkpoint has no
    stored ``shorthash``, ``calculate_shorthash()`` is called and its result
    returned.
    """
    from modules import sd_models
    if name == "" or name == []:
        return ""
    info = sd_models.get_closet_checkpoint_match(name)
    if info.shorthash is None:
        return info.calculate_shorthash()
    return info.shorthash
1042
+
1043
def longhashfromname(name):
    """Return the ``sha256`` attribute of the checkpoint matching ``name``.

    An empty string or empty list yields "".  When ``sha256`` is not set yet,
    ``calculate_shorthash()`` is invoked first and ``sha256`` is read again
    afterwards.
    """
    from modules import sd_models
    if name == "" or name == []:
        return ""
    info = sd_models.get_closet_checkpoint_match(name)
    if info.sha256 is None:
        info.calculate_shorthash()
    return info.sha256
1051
+
1052
+
1053
+ ################################################
1054
+ ##### Random
1055
+
1056
# Marker letters that eratiodealer looks for to decide whether a ratio string requests a random value.
RANCHA = ["R","U","X"]
1057
+
1058
def randdealer(w:str,randomer,ab,lucks,deep):
    """Replace random markers in the comma-separated weight string ``w``.

    Per entry (index ``i``, offset ``RANDMAP[ab]`` into ``randomer``):

    * ``R`` -> the random number itself,
    * ``U`` -> ``-2 * random + 1.5``,
    * ``X`` -> a value between ``lucks["upp"][i]`` and ``lucks["low"][i]``,
    * anything containing ``E`` -> ``"0"``; the block id ``BLOCKID[i]`` is
      collected under the remaining letter and appended to ``deep`` as an
      elemental rule ``"<ids>::<letter>"``,
    * everything else is kept verbatim.

    Random values are rounded to ``lucks["round"]`` digits.

    Returns:
        ``(weights, deep)`` - the rewritten weight string and the possibly
        extended elemental string.
    """
    up = lucks["upp"].split(",")
    low = lucks["low"].split(",")
    out = []
    outd = {"R":[],"U":[],"X":[]}
    add = RANDMAP[ab]

    for i, r in enumerate(w.split(",")):
        token = r.strip()
        if token == "R":
            value = randomer[i+add]
        elif token == "U":
            value = -2 * randomer[i+add] + 1.5
        elif token == "X":
            value = (float(low[i])-float(up[i]))* randomer[i+add] + float(up[i])
        elif "E" in r:
            outd[token.replace("E","")].append(BLOCKID[i])
            out.append("0")
            continue
        else:
            out.append(r)
            continue
        out.append(str(round(value,lucks["round"])))

    for kind, ids in outd.items():
        if ids != []:
            rule = f"{' '.join(ids)}::{kind}"
            deep = deep + "," + rule if deep else rule
    return ",".join(out), deep
1081
+
1082
def eratiodealer(dr,randomer,block,num,lucks):
    """Resolve an elemental ratio string to a number.

    Strings without any ``RANCHA`` letter are converted with ``float``.
    Otherwise the stripped string selects the formula (offset ``RANDMAP[2]``
    into ``randomer``, rounded to ``lucks["round"]`` digits):

    * ``R`` -> the random number itself,
    * ``U`` -> ``-2 * random + 1``,
    * ``X`` -> a value between ``lucks["upp"][block]`` and ``lucks["low"][block]``.

    A string that contains a marker letter but is not exactly one of the
    three markers yields ``None``.
    """
    if not any(element in dr for element in RANCHA):
        return float(dr)

    up = lucks["upp"].split(",")
    low = lucks["low"].split(",")
    add = RANDMAP[2]
    marker = dr.strip()

    if marker == "R":
        return round(randomer[num+add],lucks["round"])
    if marker == "U":
        return round(-2 * randomer[num+add] + 1,lucks["round"])
    if marker == "X":
        return round((float(low[block])-float(up[block]))* randomer[num+add] + float(up[block]),lucks["round"])
    return None
1095
+
1096
+
1097
+ ################################################
1098
+ ##### Generate Image
1099
+
1100
def simggen(s_prompt,s_nprompt,s_steps,s_sampler,s_cfg,s_seed,s_w,s_h,s_batch_size,
            genoptions,s_hrupscaler,s_hr2ndsteps,s_denois_str,s_hr_scale,
            mergeinfo,id_sets,modelid,
            *txt2imgparams,
            debug = False
            ):
    """Generate images with the currently loaded model and save them.

    Generation settings are taken from ``txt2imgparams`` (looked up by label
    through ``paramsnames``) and then overridden by the truthy ``s_*``
    arguments and by the flags in ``genoptions`` ("Hires. fix",
    "Restore faces", "Tiling").

    Args:
        s_prompt .. s_batch_size: optional overrides for the txt2img values.
        genoptions: collection of option names that force the matching feature on.
        s_hrupscaler, s_hr2ndsteps, s_denois_str, s_hr_scale: hires overrides,
            used only when "Hires. fix" is in ``genoptions``.
        mergeinfo: text written into the "Model:" field of the info text.
        id_sets: "image" stamps ``modelid`` onto each image; "PNG info"
            appends the id to ``mergeinfo``.
        modelid: identifier of the merged model.
        *txt2imgparams: values of the txt2img UI components, ordered like ``paramsnames``.
        debug: print ``paramsnames`` when true.

    Returns:
        ``(images, infotext, info_html, comments_html, p)``.
    """
    shared.state.begin()
    try:
        from scripts.mergers.components import paramsnames
        if debug: print(paramsnames)

        #[None, 'Prompt', 'Negative prompt', 'Styles', 'Sampling steps', 'Sampling method', 'Batch count', 'Batch size', 'CFG Scale',
        # 'Height', 'Width', 'Hires. fix', 'Denoising strength', 'Upscale by', 'Upscaler', 'Hires steps', 'Resize width to', 'Resize height to',
        # 'Hires checkpoint', 'Hires sampling method', 'Hires prompt', 'Hires negative prompt', 'Override settings', 'Script', 'Refiner',
        # 'Checkpoint', 'Switch at', 'Seed', 'Extra', 'Variation seed', 'Variation strength', 'Resize seed from width', 'Resize seed from height', '', 'Active', 'Active', 'X Types', 'X Values', 'Y Types', 'Y Values']

        def g(wanted,wantedv=None):
            # Value of the component labelled ``wanted`` (or the alternative label ``wantedv``), else None.
            if wanted in paramsnames:return txt2imgparams[paramsnames.index(wanted)]
            elif wantedv and wantedv in paramsnames:return txt2imgparams[paramsnames.index(wantedv)]
            else:return None

        sampler_index = g("Sampling method")
        if type(sampler_index) is str:
            sampler_name = sampler_index
        else:
            sampler_name = sd_samplers.samplers[sampler_index].name

        hr_sampler_index = g("Hires sampling method")
        if hr_sampler_index is None: hr_sampler_index = 0
        # Decide on the hires value's own type; testing sampler_index here let the
        # integer fallback 0 leak through as a sampler "name".
        if type(hr_sampler_index) is str:
            hr_sampler_name = hr_sampler_index
        else:
            hr_sampler_name = "Use same sampler" if hr_sampler_index == 0 else sd_samplers.samplers[hr_sampler_index+1].name

        p = processing.StableDiffusionProcessingTxt2Img(
            sd_model=shared.sd_model,
            outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
            outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids,
            prompt=g("Prompt"),
            styles=g("Styles"),
            negative_prompt=g('Negative prompt'),
            seed=g("Seed"),
            subseed=g("Variation seed"),
            subseed_strength=g("Variation strength"),
            seed_resize_from_h=g("Resize seed from height"),
            seed_resize_from_w=g("Resize seed from width"),
            seed_enable_extras=g("Extra"),
            sampler_name=sampler_name,
            batch_size=g("Batch size"),
            n_iter=g("Batch count"),
            steps=g("Sampling steps"),
            cfg_scale=g("CFG Scale"),
            width=g("Width"),
            height=g("Height"),
            restore_faces=g("Restore faces","Face restore"),
            tiling=g("Tiling"),
            enable_hr=g("Hires. fix","Second pass"),
            hr_scale=g("Upscale by"),
            hr_upscaler=g("Upscaler"),
            hr_second_pass_steps=g("Hires steps","Secondary steps"),
            hr_resize_x=g("Resize width to"),
            hr_resize_y=g("Resize height to"),
            override_settings=create_override_settings_dict(g("Override settings")),
            do_not_save_grid=True,
            do_not_save_samples=True,
            do_not_reload_embeddings=True,
        )
        p.hr_checkpoint_name=None if g("Hires checkpoint") == 'Use same checkpoint' else g("Hires checkpoint")
        p.hr_sampler_name=None if hr_sampler_name == 'Use same sampler' else hr_sampler_name

        if s_sampler is None: s_sampler = 0

        if s_batch_size != 1 :p.batch_size = int(s_batch_size)
        if s_prompt: p.prompt = s_prompt
        if s_nprompt: p.negative_prompt = s_nprompt
        if s_steps: p.steps = s_steps
        if s_sampler: p.sampler_name = sampler_name
        if s_cfg: p.cfg_scale = s_cfg
        if s_seed: p.seed = s_seed
        if s_w: p.width = s_w
        if s_h: p.height = s_h

        if not p.cfg_scale: p.cfg_scale = 7

        p.scripts = scripts.scripts_txt2img
        # Everything after "Override settings" is handed to the scripts as their arguments.
        p.script_args = txt2imgparams[paramsnames.index("Override settings")+1:]

        p.denoising_strength=g("Denoising strength") if p.enable_hr else None

        p.hr_prompt=g("Hires prompt","Secondary Prompt")
        p.hr_negative_prompt=g("Hires negative prompt","Secondary negative prompt")

        if "Hires. fix" in genoptions:
            p.enable_hr = True
            if s_hrupscaler: p.hr_upscaler = s_hrupscaler
            if s_hr2ndsteps:p.hr_second_pass_steps = s_hr2ndsteps
            if s_denois_str:p.denoising_strength = s_denois_str
            if s_hr_scale:p.hr_scale = s_hr_scale

        if "Restore faces" in genoptions:
            p.restore_faces = True

        if "Tiling" in genoptions:
            p.tiling = True

        p.cached_c = [None,None]
        p.cached_uc = [None,None]

        p.cached_hr_c = [None, None]
        p.cached_hr_uc = [None, None]

        if type(p.prompt) == list:
            p.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, p.styles) for x in p.prompt]
        else:
            p.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(p.prompt, p.styles)]

        if type(p.negative_prompt) == list:
            p.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, p.styles) for x in p.negative_prompt]
        else:
            p.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)]

        processed:Processed = processing.process_images(p)
        if "image" in id_sets:
            for i, image in enumerate(processed.images):
                processed.images[i] = draw_origin(image, str(modelid),p.width,p.height,p.width)

        if "PNG info" in id_sets:mergeinfo = mergeinfo + " ID " + str(modelid)

        infotext = create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds)
        # Keep only the text before the last "Steps" when the info text contains it more than once.
        if infotext.count("Steps: ")>1:
            infotext = infotext[:infotext.rindex("Steps")]

        # Replace the "Model:" field with the merge description (commas removed so the field stays intact).
        infotexts = infotext.split(",")
        for i,x in enumerate(infotexts):
            if "Model:"in x:
                infotexts[i] = " Model: "+mergeinfo.replace(","," ")
        infotext= ",".join(infotexts)

        for i, image in enumerate(processed.images):
            images.save_image(image, opts.outdir_txt2img_samples, "",p.seed, p.prompt,shared.opts.samples_format, p=p,info=infotext)

        if s_batch_size > 1:
            grid = images.image_grid(processed.images, p.batch_size)
            processed.images.insert(0, grid)
            images.save_image(grid, opts.outdir_txt2img_grids, "grid", p.seed, p.prompt, opts.grid_format, info=infotext, short_filename=not opts.grid_extended_filename, p=p, grid=True)
        return processed.images,infotext,plaintext_to_html(processed.info), plaintext_to_html(processed.comments),p
    finally:
        # Always pair state.begin() with state.end(), also when generation fails.
        shared.state.end()
1246
+
1247
+
1248
+ ################################################
1249
+ ##### Block Ids
1250
+
1251
def blocker(blocks,blockids):
    """Expand ``A-B`` ranges in a space-separated block list.

    Every token containing ``-`` is replaced by all ids of ``blockids`` lying
    between its two end points (inclusive, in ``blockids`` order, whichever
    way round the end points were given).  Other tokens are passed through
    unchanged.  The result is a single space-separated string.

    Raises:
        ValueError: when a range end point is not in ``blockids``.
    """
    def glue(acc, item):
        # Same joining rule as string accumulation: no separator while the result is still empty.
        return acc + " " + item if acc else item

    expanded = ""
    for token in blocks.split(" "):
        if "-" not in token:
            expanded = glue(expanded, token)
            continue
        ends = [part.strip() for part in token.split('-')]
        second = blockids.index(ends[1])
        first = blockids.index(ends[0])
        low, high = (first, second) if second > first else (second, first)
        for name in blockids[low:high+1]:
            expanded = glue(expanded, name)
    return expanded
1268
+
1269
+
1270
def blockfromkey(key,isxl):
    """Map a state-dict key to a pair of block identifiers.

    Non-XL (``isxl`` false): an index is derived from the key
    (``time_embed`` -> -2, ``.out.`` -> 24, ``input_blocks.N`` -> N,
    ``middle_block`` -> 12, ``output_blocks.N`` -> 13 + N, otherwise -1) and
    ``BLOCKID[index + 1]`` is returned twice.

    XL (``isxl`` true): returns ``(fine_id, coarse_id)``.  Keys without
    "weight"/"bias", with "label_emb"/"time_embed", or outside the known
    prefixes give ``("Not Merge", "Not Merge")``; text-encoder keys give
    ``("BASE", "BASE")``; VAE keys give ``("VAE", "BASE")``; UNet keys give
    ids built from the block kind (IN/MID/OUT), the leading digits found in
    the key and, for transformer keys, the transformer block number.
    """
    if not isxl:
        n_in, n_mid, n_out = 12, 1, 12

        weight_index = -1
        if 'time_embed' in key:
            weight_index = -2 # before input blocks
        elif '.out.' in key:
            weight_index = n_in + n_mid + n_out - 1 # after output blocks
        else:
            hit = re.search(r'\.input_blocks\.(\d+)\.', key)
            if hit:
                weight_index = int(hit.group(1))
            elif re.search(r'\.middle_block\.(\d+)\.', key):
                weight_index = n_in
            else:
                hit = re.search(r'\.output_blocks\.(\d+)\.', key)
                if hit:
                    weight_index = n_in + n_mid + int(hit.group(1))
        label = BLOCKID[weight_index+1]
        return label, label

    if "weight" not in key and "bias" not in key:
        return "Not Merge","Not Merge"
    if "label_emb" in key or "time_embed" in key:
        return "Not Merge","Not Merge"
    if "conditioner.embedders" in key:
        return "BASE","BASE"
    if "first_stage_model" in key:
        return "VAE","BASE"
    if "model.diffusion_model" not in key:
        return "Not Merge", "Not Merge"
    if "model.diffusion_model.out." in key:
        return "OUT8","OUT08"

    found = re.findall(r'input|mid|output', key)
    block = found[0].upper().replace("PUT","") if found else ""
    is_mid = "MID" in block
    digits = re.sub(r"\D", "", key)
    # MID keeps one digit and pads with "0"; IN/OUT keep the first two digits.
    nums = digits[:1] + "0" if is_mid else digits[:2]
    add = re.findall(r"transformer_blocks\.(\d+)\.",key)[0] if "transformer" in key else ""
    coarse = "M00" if is_mid else block + "0" + nums[0]
    return block + nums + add, coarse
1317
+
1318
+ ################################################
1319
+ ##### Adjust
1320
+
1321
def fineman(fine,isxl):
    """Convert the comma-separated adjust string into multiplier/offset values.

    Returns ``None`` when ``fine`` contains no comma.  Otherwise up to eight
    numbers are parsed (unparsable or missing entries count as 0.0) and
    turned into a six-item list: five scalar factors derived from the first
    three numbers, followed by a list holding the fourth number times 0.02
    and the colour terms computed by ``colorcalc`` from numbers five to eight.
    """
    if fine.find(",") == -1:
        return None

    values = [0.0]*8
    for i, raw in enumerate([t.strip() for t in fine.split(",")][0:8]):
        try:
            values[i] = float(raw)
        except Exception:
            # Non-numeric entries keep their 0.0 default.
            pass

    return [
        1 - values[0] * 0.01,
        1+ values[0] * 0.02,
        1 - values[1] * 0.01,
        1+ values[1] * 0.02,
        1 - values[2] * 0.01,
        [values[3]*0.02] + colorcalc(values[4:8],isxl)
        ]
1345
+
1346
def colorcalc(cols,isxl):
    """Combine four colour amounts into per-column totals.

    Row ``i`` of ``COLSXL`` (when ``isxl``) or ``COLS`` is scaled by
    ``cols[i] * 0.02``; the scaled rows are then summed column by column.
    Returns a list with one total per column.
    """
    palette = COLSXL if isxl else COLS
    weighted = []
    for i, row in enumerate(palette):
        weighted.append([component * cols[i] * 0.02 for component in row])
    return [sum(column) for column in zip(*weighted)]
1350
+
1351
# Coefficient rows consumed by colorcalc: row i is scaled by the i-th colour amount and the rows are summed per column.
# COLS is used for non-XL models, COLSXL when isxl is set.
# NOTE(review): the three columns presumably map to three output channels - confirm where the result is applied.
COLS = [[-1,1/3,2/3],[1,1,0],[0,-1,-1],[1,0,1]]
COLSXL = [[0,0,1],[1,0,0],[-1,-1,0],[-1,1,0]]
1353
+
1354
def weighttoxl(weight):
    """Reduce a weight list to the XL layout.

    Keeps items 0-8 and 12-21 of ``weight`` and appends a trailing 0.
    The input list is not modified.
    """
    head = weight[:9]
    tail = weight[12:22]
    return head + tail + [0]
1357
+
1358
# State-dict keys that excluder matches when "Adjust" is listed in ex_blocks.
FINETUNES = [
    "model.diffusion_model.input_blocks.0.0.weight",
    "model.diffusion_model.input_blocks.0.0.bias",
    "model.diffusion_model.out.0.weight",
    "model.diffusion_model.out.0.bias",
    "model.diffusion_model.out.2.weight",
    "model.diffusion_model.out.2.bias",
]
1366
+
1367
+ ################################################
1368
+ ##### Include/Exclude
1369
def excluder(block:str,inex:str,ex_blocks:list,ex_elems:list, key:str):
    """Decide whether ``key`` must be left out of the merge.

    Args:
        block: block id of ``key``.
        inex: ``"Include"`` or anything else (treated as exclude mode).
        ex_blocks: selected block ids; may also contain the pseudo entries
            ``"Adjust"`` (matches the ``FINETUNES`` keys), ``"VAE"`` (matches
            keys containing ``first_stage_model``) and ``"print"`` (log the
            keys the selection affects).
        ex_elems: substrings matched against ``key``; empty strings are ignored.
        key: state-dict key being tested.

    Returns:
        False when nothing is selected.  Otherwise the result starts as True
        in include mode and False in exclude mode and is toggled once for
        every selection rule that matches.
    """
    if ex_blocks == [] and ex_elems == [""]:
        return False
    including = inex == "Include"
    out = including
    if block in ex_blocks:out = not out
    if "Adjust" in ex_blocks and key in FINETUNES:out = not out
    for ke in ex_elems:
        if ke != "" and ke in key:out = not out
    if "VAE" in ex_blocks and "first_stage_model"in key:out = not out
    if "print" in ex_blocks and (out ^ including):
        # ``inex`` is a non-empty string, so its truthiness cannot pick the label.
        print("Include" if including else "Exclude",block,ex_blocks,ex_elems,key)
    return out
1381
+
1382
+ ################################################
1383
+ ##### Reset Broken CliP IDs
1384
+
1385
def resetclip(theta):
    """Restore the CLIP position-id tensor in ``theta`` to ``[[0, 1, ..., 76]]``.

    Does nothing when the key
    ``cond_stage_model.transformer.text_model.embeddings.position_ids`` is
    absent.  Otherwise the stored ids are compared (as int64) with the
    expected sequence, the mismatching positions are printed, and the entry
    is replaced by the expected int64 tensor.
    """
    idkey = "cond_stage_model.transformer.text_model.embeddings.position_ids"
    if idkey not in theta.keys():
        return

    expected = torch.arange(77, dtype=torch.int64).unsqueeze(0)
    mismatch = expected.ne(theta[idkey].to(torch.int64))
    broken = [i for i in range(77) if mismatch[0][i]]

    if broken != []:
        print("Clip IDs broken and fixed: ",broken)

    theta[idkey] = expected
1398
+
1399
+
1400
+ ################################################
1401
+ ##### cache
1402
def cachedealer(start):
    """Suspend or restore the host's ``sd_checkpoint_cache`` setting.

    With ``start`` true the current value is remembered in the module-level
    ``orig_cache`` and the setting is forced to 0; with ``start`` false the
    remembered value is written back.
    """
    global orig_cache
    if not start:
        shared.opts.sd_checkpoint_cache = orig_cache
        return
    orig_cache = shared.opts.sd_checkpoint_cache
    shared.opts.sd_checkpoint_cache = 0
1409
+
1410
def clearcache():
    """Drop every cached state dict and ask Python and torch to free the memory."""
    global modelcache
    # Unbind the old dict first, then start over with an empty cache.
    del modelcache
    modelcache = {}

    gc.collect()
    devices.torch_gc()
1416
+
1417
def getcachelist():
    """Return the model names of all cached checkpoints as one comma-separated string.

    Cache keys without a ``model_name`` attribute are skipped.
    """
    return ",".join(
        entry.model_name for entry in modelcache.keys() if hasattr(entry, "model_name")
    )
1423
+
1424
+ ################################################
1425
+ ##### print
1426
+
1427
def printstart(model_a,model_b,model_c,base_alpha,base_beta,weights_a,weights_b,mode,useblocks,calcmode,deep,lucks,fine,inex,ex_blocks,ex_elems):
    """Print a 13-line summary of the merge settings to stdout."""
    report = [
        f" model A \t: {model_a}",
        f" model B \t: {model_b}",
        f" model C \t: {model_c}",
        f" alpha,beta\t: {base_alpha,base_beta}",
        f" weights_alpha\t: {weights_a}",
        f" weights_beta\t: {weights_b}",
        f" mode\t\t: {mode}",
        f" MBW \t\t: {useblocks}",
        f" CalcMode \t: {calcmode}",
        f" Elemental \t: {deep}",
        f" Weights Seed\t: {lucks}",
        f" {inex} \t: {ex_blocks,ex_elems}",
        f" Adjust \t: {fine}",
    ]
    for line in report:
        print(line)
1441
+
1442
def caster(news,hear):
    """Print ``news`` only when ``hear`` is truthy."""
    if not hear:
        return
    print(news)
1444
+
1445
def casterr(*args,hear=hear):
    """Debug helper: print ``name = repr(value)`` for each argument.

    Names are recovered by matching the arguments' identities against the
    caller's local variables; arguments that are not caller locals are shown
    as ``???``.  Nothing is printed when ``hear`` is falsy (its default is
    the value of the module-level ``hear`` at definition time).
    """
    if not hear:
        return
    caller_locals = currentframe().f_back.f_locals
    names = {id(value): label for label, value in caller_locals.items()}
    lines = [names.get(id(arg), '???') + ' = ' + repr(arg) for arg in args]
    print('\n'.join(lines))