NagisaNao commited on
Commit
5763524
·
verified ·
1 Parent(s): a639f03

remove dir

Browse files
sagemaker/fixing/extensions/supermerger/scripts/mergers/mergers.py DELETED
@@ -1,1448 +0,0 @@
1
- import random
2
- import os
3
- import gc
4
- import hashlib
5
- import numpy as np
6
- import os.path
7
- import re
8
- import torch
9
- import tqdm
10
- import datetime
11
-
12
- import csv
13
- import json
14
- import launch
15
- import torch.nn as nn
16
- import scipy.ndimage
17
- from copy import deepcopy
18
- from PIL import Image, ImageFont, ImageDraw
19
- from tqdm import tqdm
20
- from functools import partial
21
- from torch import Tensor, lerp
22
- from torch.nn.functional import cosine_similarity, relu, softplus
23
- from modules import shared, processing, sd_models, sd_vae, images, sd_samplers, scripts,devices, extras
24
- from modules.ui import plaintext_to_html
25
- from modules.shared import opts
26
- from modules.processing import create_infotext,Processed
27
- from modules.sd_models import load_model,unload_model_weights
28
- from modules.generation_parameters_copypaste import create_override_settings_dict
29
- from scripts.mergers.model_util import filenamecutter,savemodel
30
- from math import ceil
31
- import sys
32
- from multiprocessing import cpu_count
33
- from threading import Lock
34
- from concurrent.futures import ThreadPoolExecutor, as_completed
35
- from scripts.mergers.bcolors import bcolors
36
- import collections
37
-
38
# Reduce the webui git tag to an integer (e.g. "v1.5.1-12-gabc" -> 151) so that
# version-dependent workarounds can be gated with a numeric comparison.
# Any ordinary failure (no tag, unexpected tag format) falls back to 100.
# Only Exception is caught so Ctrl-C / SystemExit during start-up still propagate.
try:
    ui_version = int(launch.git_tag().split("-",1)[0].replace("v","").replace(".",""))
except Exception:
    ui_version = 100

# NOTE(review): presumably used by cachedealer() to remember a previous cache
# setting - that helper is not visible in this part of the file; confirm.
orig_cache = 0

# Ordered mapping used as the module-level model cache.
modelcache = collections.OrderedDict()
46
-
47
- from inspect import currentframe
48
-
49
# Key fragments that the "self" and "plus random" calcmodes leave untouched.
SELFKEYS = ["to_out", "proj_out", "norm"]

# Directory holding this module, and the directory one level above it.
_this_module_file = os.path.abspath(sys.modules[__name__].__file__)
module_path = os.path.dirname(_this_module_file)
scriptpath = os.path.dirname(module_path)
53
-
54
def tryit(func):
    """Call ``func()`` on a best-effort basis.

    Any ordinary exception raised by ``func`` is discarded and ``None`` is
    returned. Only ``Exception`` subclasses are suppressed, so
    ``KeyboardInterrupt`` and ``SystemExit`` still reach the caller instead of
    being silently eaten.
    """
    try:
        func()
    except Exception:
        pass
59
-
60
# Cancellation flag: the merge code polls it and returns "STOPPED" when set.
stopmerge = False


def freezemtime():
    """Raise the module-level ``stopmerge`` flag to ask a running merge to stop."""
    global stopmerge
    stopmerge = True
65
-
66
# Settings of the most recent merge; smerge() replaces this list on every run.
mergedmodel = []

# NOTE(review): looks like the names of the "adjust" parameters (they match the
# list embedded in the "pinpoint adjust" entry of TYPESEG) - confirm.
FINETUNEX = ["IN", "OUT", "OUT2", "CONT", "BRI", "COL1", "COL2", "COL3"]

# Long-form axis type labels; TYPES below holds a short label for each position.
TYPESEG = [
    "none",
    "alpha",
    "beta (if Triple or Twice is not selected,Twice automatically enable)",
    "alpha and beta",
    "seed",
    "mbw alpha",
    "mbw beta",
    "mbw alpha and beta",
    "model_A",
    "model_B",
    "model_C",
    "pinpoint blocks (alpha or beta must be selected for another axis)",
    "include blocks",
    "exclude blocks",
    "add include",
    "add exclude",
    "elemental",
    "add elemental",
    "pinpoint element",
    "effective elemental checker",
    "adjust",
    "pinpoint adjust (IN,OUT,OUT2,CONT,BRI,COL1,COL2,COL3)",
    "calcmode",
    "prompt",
    "random",
]

# Short labels. The trailing space in "mbw alpha " is in the original data and
# is kept on purpose because other code may compare against it.
TYPES = [
    "none",
    "alpha",
    "beta",
    "alpha and beta",
    "seed",
    "mbw alpha ",
    "mbw beta",
    "mbw alpha and beta",
    "model_A",
    "model_B",
    "model_C",
    "pinpoint blocks",
    "include blocks",
    "exclude blocks",
    "add include",
    "add exclude",
    "elemental",
    "add elemental",
    "pinpoint element",
    "effective",
    "adjust",
    "pinpoint adjust",
    "calcmode",
    "prompt",
    "random",
]

# Merge modes; smerge() refers to them by index (0 Weight, 1 Add, 2 Triple, 3 Twice).
MODES = ["Weight", "Add", "Triple", "Twice"]
SAVEMODES = ["save model", "overwrite"]

# BASE, IN00..IN11, M00, OUT00..OUT11, then the two non-block choices.
EXCLUDE_CHOICES = (
    ["BASE"]
    + [f"IN{i:02d}" for i in range(12)]
    + ["M00"]
    + [f"OUT{i:02d}" for i in range(12)]
    + ["Adjust", "VAE"]
)

# State-dict keys that are never copied over from model B in merge stage 2.
CHCKPOINT_DICT_SKIP_ON_MERGE = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
81
-
82
- #type[0:aplha,1:beta,2:seed,3:mbw,4:model_A,5:model_B,6:model_C]
83
- #msettings=[0 weights_a,1 weights_b,2 model_a,3 model_b,4 model_c,5 base_alpha,6 base_beta,7 mode,8 useblocks,9 custom_name,10 save_sets,11 id_sets,12 wpresets]
84
- #id sets "image", "PNG info","XY grid"
85
-
86
# Flags passed as the second argument of caster() throughout the merge code.
# NOTE(review): presumably they switch debug printing on - caster() is not
# visible here; confirm.
hear = False
hearm = False

# Four None placeholders appended to early-return tuples ("ERROR"/"STOPPED").
NON4 = [None, None, None, None]

# Short alias for the webui checkpoint lookup function.
informer = sd_models.get_closet_checkpoint_match
91
-
92
- #msettings=[weights_a,weights_b,model_a,model_b,model_c,device,base_alpha,base_beta,mode,loranames,useblocks,custom_name,save_sets,id_sets,wpresets,deep]
93
-
94
def smergegen(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode,
              calcmode,useblocks,custom_name,save_sets,id_sets,wpresets,deep,tensor,bake_in_vae,opt_value,inex,ex_blocks,ex_elems,
              esettings,
              s_prompt,s_nprompt,s_steps,s_sampler,s_cfg,s_seed,s_w,s_h,s_batch_size,
              genoptions,s_hrupscaler,s_hr2ndsteps,s_denois_str,s_hr_scale,
              lmode,lsets,llimits_u,llimits_l,lseed,lserial,lcustom,lround,
              currentmodel,imggen,
              *txt2imgparams):
    """Merge the selected models, load the result and optionally generate images.

    The merge itself is delegated to smerge(); ``tensor`` is forwarded as its
    ``fine`` argument and the ``l*`` arguments are packed into the ``lucks``
    dict. The merged state dict is saved when "save model" is in ``save_sets``
    and is then loaded as the active model.

    Returns:
        * ``(result, "not loaded", None, None, None, None)`` when smerge()
          reports an "ERROR" or "STOPPED" result;
        * ``(result, currentmodel, *images[:4])`` when ``imggen`` is truthy;
        * ``(result, currentmodel)`` otherwise.
    """
    lucks = {"on":False, "mode":lmode,"set":lsets,"upp":llimits_u,"low":llimits_l,"seed":lseed,"num":lserial,"cust":lcustom,"round":int(lround)}
    deepprint = "print change" in esettings

    cachedealer(True)

    result,currentmodel,modelid,theta_0,metadata = smerge(
        weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode,calcmode,
        useblocks,custom_name,save_sets,id_sets,wpresets,deep,tensor,bake_in_vae,opt_value,inex,ex_blocks,ex_elems,deepprint,lucks
        )

    if "ERROR" in result or "STOPPED" in result:
        return result,"not loaded",*NON4

    checkpoint_info = sd_models.get_closet_checkpoint_match(model_a)

    # From ui_version 150 on, the merged model is registered under a copied
    # checkpoint entry instead of reusing model A's entry as-is.
    if ui_version >= 150: checkpoint_info = fake_checkpoint_info(checkpoint_info,metadata,currentmodel)

    save = SAVEMODES[0] in save_sets

    result = savemodel(theta_0,currentmodel,custom_name,save_sets,metadata) if save else "Merged model loaded:"+currentmodel

    sd_models.model_data.__init__()
    load_model(checkpoint_info, already_loaded_state_dict=theta_0)

    cachedealer(False)

    del theta_0
    devices.torch_gc()

    debug = "debug" in save_sets

    # Fix: the third source used to be informer(model_b) a second time, so the
    # config of model C was never taken into account when copying a config.
    if ("copy config" in save_sets) and ("(" not in result): extras.create_config(result.replace("Merged model saved in ",""), 0, informer(model_a), informer(model_b), informer(model_c))

    if imggen :
        images = simggen(s_prompt,s_nprompt,s_steps,s_sampler,s_cfg,s_seed,s_w,s_h,s_batch_size,
                         genoptions,s_hrupscaler,s_hr2ndsteps,s_denois_str,s_hr_scale,
                         currentmodel,id_sets,modelid,
                         *txt2imgparams,debug = debug)

        return result,currentmodel,*images[:4]
    else:
        return result,currentmodel
145
-
146
def checkpointer_infomer(name):
    """Return the result of ``sd_models.get_closet_checkpoint_match(name)``."""
    matched = sd_models.get_closet_checkpoint_match(name)
    return matched
148
-
149
- # XXX hack. fake checkpoint_info
150
def fake_checkpoint_info(checkpoint_info,metadata,currentmodel):
    """Build a checkpoint entry that describes the freshly merged model.

    A deep copy of ``checkpoint_info`` is renamed after ``currentmodel``, given
    a sha256 computed from the JSON-serialised ``metadata``, registered in
    ``sd_models.checkpoints_list`` and, when the webui hash cache is available,
    recorded there too. The copy is returned; the argument is not modified.
    """
    from modules import cache as webui_cache
    write_cache = webui_cache.dump_cache
    cache_getter = webui_cache.cache

    fake = deepcopy(checkpoint_info)

    # The hash is derived from the merge metadata, not from file contents.
    digest = hashlib.sha256(json.dumps(metadata).encode("utf-8")).hexdigest()
    fake.sha256 = digest
    fake.name_for_extra = currentmodel

    fake.name = fake.name_for_extra + ".safetensors"
    fake.model_name = fake.name_for_extra.replace("/", "_").replace("\\", "_")
    fake.title = f"{fake.name} [{digest[0:10]}]"
    fake.metadata = metadata

    # for sd-webui v1.5.x
    sd_models.checkpoints_list[fake.title] = fake

    # Force the new sha256 into the hash cache and persist it.
    if cache_getter is not None:
        hash_entries = cache_getter("hashes")
        hash_entries[f"checkpoint/{fake.name}"] = {
            "mtime": os.path.getmtime(fake.filename),
            "sha256": digest,
        }
        write_cache()

    # Identifiers under which the fake entry can be looked up.
    fake.ids = [fake.model_name, fake.name, fake.name_for_extra]
    return fake
182
-
183
NUM_INPUT_BLOCKS = 12
NUM_MID_BLOCK = 1
NUM_OUTPUT_BLOCKS = 12
NUM_TOTAL_BLOCKS = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + NUM_OUTPUT_BLOCKS

# 26 block ids: BASE, IN00..IN11, M00, OUT00..OUT11.
BLOCKID = (
    ["BASE"]
    + [f"IN{i:02d}" for i in range(12)]
    + ["M00"]
    + [f"OUT{i:02d}" for i in range(12)]
)

# 21 block ids used when the model is detected as XL (two-digit form).
BLOCKIDXLL = (
    ["BASE"]
    + [f"IN{i:02d}" for i in range(9)]
    + ["M00"]
    + [f"OUT{i:02d}" for i in range(9)]
    + ["VAE"]
)

# Same 21 positions as BLOCKIDXLL, in short form.
BLOCKIDXL = (
    ["BASE"]
    + [f"IN{i}" for i in range(9)]
    + ["M"]
    + [f"OUT{i}" for i in range(9)]
    + ["VAE"]
)

RANDMAP = [0, 50, 100]  # alpha,beta,elements

statistics = {"sum": {}, "mean": {}, "max": {}, "min": {}}
194
-
195
- ################################################
196
- ##### Main Merging Code
197
-
198
def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode,calcmode,
           useblocks,custom_name,save_sets,id_sets,wpresets,deep,fine,bake_in_vae,opt_value,inex,ex_blocks,ex_elems,deepprint,lucks,main = None):
    """Merge model A with model B (and optionally C) into a single state dict.

    Args:
        weights_a, weights_b: comma separated block weights ("base,w0,w1,...");
            only parsed when ``useblocks`` is truthy.
        model_a, model_b, model_c: checkpoint names or hashes.
        base_alpha, base_beta: global ratios (numbers or numeric strings).
        mode: string containing one of MODES ("Weight", "Add", "Triple", "Twice").
        calcmode: calculation method ("normal", "cosineA", "trainDifference",
            "smoothAdd", "smoothAdd MT", "tensor", "extract", "self",
            "plus random", ...).
        useblocks: enable per-block weights (bool or the string "True").
        save_sets: option strings ("use cuda", "use old calc method",
            "Reset CLIP ids", ...).
        deep: elemental merge text; ``fine``: adjust text.
        lucks: dict of random-merge settings; the key "ceed" is written here.
        main: three flags telling which of A/B/C may be cached. Defaults to
            all False.

    Returns:
        ``("", currentmodel, modelid, theta_0, metadata)`` on success, or
        ``(message, None, None, None, None)`` where ``message`` starts with
        "ERROR" or equals "STOPPED".

    Raises:
        RuntimeError: when a 4-channel model A is merged with an inpainting or
            instruct-pix2pix model B.
    """
    global hear,mergedmodel,stopmerge,statistics

    # A list literal as default would be one object shared by every call.
    if main is None:
        main = [False,False,False]

    caster("merge start",hearm)
    stopmerge = False

    uselerp = "use old calc method" not in save_sets
    device = "cuda" if "use cuda" in save_sets else "cpu"

    unload_model_weights(sd_models.model_data.sd_model)

    # Values read back from a file arrive as strings.
    if isinstance(useblocks, str):
        useblocks = useblocks == "True"
    if isinstance(base_alpha, str): base_alpha = float(base_alpha)
    if isinstance(base_beta, str): base_beta = float(base_beta)

    #random
    if lucks != {}:
        if lucks["seed"] == -1: lucks["ceed"] = str(random.randrange(4294967294))
        else: lucks["ceed"] = lucks["seed"]
    else: lucks["ceed"] = 0
    np.random.seed(int(lucks["ceed"]))
    randomer = np.random.rand(2500)

    cachetarget = [model for model,num in zip([model_a,model_b,model_c],main) if model != "" and num]

    weights_a,deep = randdealer(weights_a,randomer,0,lucks,deep)
    weights_b,_ = randdealer(weights_b,randomer,1,lucks,None)

    weights_a_orig = weights_a
    weights_b_orig = weights_b

    # preset to weights
    if wpresets != False and useblocks:
        weights_a = wpreseter(weights_a,wpresets)
        weights_b = wpreseter(weights_b,wpresets)

    # mode select booleans
    usebeta = MODES[2] in mode or MODES[3] in mode or "tensor" in calcmode
    metadata = {"format": "pt"}

    # These two calcmodes keep model C's tensors (theta_2) for the main loop.
    keep_theta_2 = calcmode == "trainDifference" or calcmode == "extract"

    if keep_theta_2 and "Add" not in mode:
        print(f"{bcolors.WARNING}Mode changed to add difference{bcolors.ENDC}")
        mode = "Add"
        if model_c == "" or model_c is None:
            #fallback to avoid crash
            model_c = model_a
            print(f"{bcolors.WARNING}Substituting empty model_c with model_a{bcolors.ENDC}")

    if not useblocks:
        weights_a = weights_b = ""
    #for save log and save current model
    mergedmodel =[weights_a,weights_b,
                  hashfromname(model_a),hashfromname(model_b),hashfromname(model_c),
                  base_alpha,base_beta,mode,useblocks,custom_name,save_sets,id_sets,deep,calcmode,lucks["ceed"],fine,opt_value,inex,ex_blocks,ex_elems].copy()

    model_a = namefromhash(model_a)
    model_b = namefromhash(model_b)
    model_c = namefromhash(model_c)

    caster(mergedmodel,False)

    #elementals
    if len(deep) > 0:
        deep = deep.replace("\n",",")
        deep = deep.replace(calcmode+",","")
        deep = deep.split(",")

    #format check
    if model_a =="" or model_b =="" or ((not MODES[0] in mode) and model_c=="") :
        return "ERROR: Necessary model is not selected",*NON4

    #for MBW text to list
    if useblocks:
        weights_a_t=weights_a.split(',',1)
        weights_b_t=weights_b.split(',',1)
        base_alpha = float(weights_a_t[0])
        weights_a = [float(w) for w in weights_a_t[1].split(',')]
        caster(f"from {weights_a_t}, alpha = {base_alpha},weights_a ={weights_a}",hearm)
        # 25 or 19 block weights, i.e. 26 or 20 values counting the base value.
        if not (len(weights_a) == 25 or len(weights_a) == 19):return f"ERROR: weights alpha value must be 20 or 26.",*NON4
        if usebeta:
            base_beta = float(weights_b_t[0])
            weights_b = [float(w) for w in weights_b_t[1].split(',')]
            caster(f"from {weights_b_t}, beta = {base_beta},weights_a ={weights_b}",hearm)
            if not(len(weights_b) == 25 or len(weights_b) == 19): return f"ERROR: weights beta value must be 20 or 26.",*NON4

    caster("model load start",hearm)
    printstart(model_a,model_b,model_c,base_alpha,base_beta,weights_a,weights_b,mode,useblocks,calcmode,deep,lucks['ceed'],fine,inex,ex_blocks,ex_elems)

    theta_1=load_model_weights_m(model_b,2,cachetarget,device).copy()

    # The model is treated as XL when this second-embedder key is present.
    isxl = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight" in theta_1.keys()

    #adjust
    if fine.rstrip(",0") != "":
        fine = fineman(fine,isxl)
    else:
        fine = ""

    if isxl and useblocks:
        if len(weights_a) == 25:
            weights_a = weighttoxl(weights_a)
            print(f"weight converted for XL{weights_a}")
        if usebeta:
            if len(weights_b) == 25:
                weights_b = weighttoxl(weights_b)
                print(f"weight converted for XL{weights_b}")
        # Pad to 20 entries; the extra slot lines up with "VAE" in BLOCKIDXLL.
        if len(weights_a) == 19: weights_a = weights_a + [0]
        if len(weights_b) == 19: weights_b = weights_b + [0]

    if MODES[1] in mode:#Add
        if stopmerge: return "STOPPED", *NON4
        theta_2 = load_model_weights_m(model_c,3,cachetarget,device).copy()
        if not keep_theta_2:
            # Turn theta_1 into the difference (B - C); C is not needed afterwards.
            for key in tqdm(theta_1.keys()):
                if 'model' in key:
                    if key in theta_2:
                        theta_1[key] = theta_1[key]- theta_2[key]
                    else:
                        theta_1[key] = torch.zeros_like(theta_1[key])
            del theta_2

    if stopmerge: return "STOPPED", *NON4

    if "tensor" in calcmode or "self" in calcmode:
        # Clone each tensor into a fresh dict instead of reusing the loaded ones.
        theta_t = load_model_weights_m(model_a,1,cachetarget,device).copy()
        theta_0 ={}
        for key in theta_t:
            theta_0[key] = theta_t[key].clone()
        del theta_t
    else:
        theta_0=load_model_weights_m(model_a,1,cachetarget,device).copy()

    if MODES[2] in mode or MODES[3] in mode:#Tripe or Twice
        theta_2 = load_model_weights_m(model_c,3,cachetarget,device).copy()
    elif not keep_theta_2:
        theta_2 = {}

    alpha = base_alpha
    beta = base_beta

    ex_elems = ex_elems.split(",")

    keyratio = []
    key_and_alpha = {}

    ##### Stage 0/2 in Cosine
    if "cosine" in calcmode:
        if stopmerge: return "STOPPED", *NON4
        cosine_prep = precosine("A" in calcmode,theta_0,theta_1)
        # precosine() reports a stop with a tuple whose first item is the
        # string "STOPPED"; unpacking that into two names would raise.
        if isinstance(cosine_prep[0], str):
            return "STOPPED", *NON4
        sim, sims = cosine_prep[0], cosine_prep[1]

    ##### Stage 1/2

    for num, key in enumerate(tqdm(theta_0.keys(), desc="Stage 1/2")):
        if stopmerge: return "STOPPED", *NON4
        if not ("model" in key and key in theta_1): continue
        if not ("weight" in key or "bias" in key): continue
        if keep_theta_2:
            if key not in theta_2:
                continue
        else:
            if usebeta and (not key in theta_2) and (not theta_2 == {}) :
                continue

        weight_index = -1
        current_alpha = alpha
        current_beta = beta

        a = list(theta_0[key].shape)
        b = list(theta_1[key].shape)

        # this enables merging an inpainting model (A) with another one (B);
        # where normal model would have 4 channels, for latent space, inpainting model would
        # have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
        channels_differ = a != b and a[0:1] + a[2:] == b[0:1] + b[2:]
        if channels_differ:
            if a[1] == 4 and b[1] == 9:
                raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.")
            if a[1] == 4 and b[1] == 8:
                raise RuntimeError("When merging instruct-pix2pix model with a normal one, A must be the instruct-pix2pix model.")

            if not (a[1] == 8 and b[1] == 4):
                # Not instruct-pix2pix, so it has to be inpainting (9 vs 4).
                # a and b are plain lists; the old a.shape[1] raised AttributeError.
                assert a[1] == 9 and b[1] == 4, f"Bad dimensions for merged layer {key}: A={a}, B={b}"

        block,blocks26 = blockfromkey(key,isxl)
        if block == "Not Merge": continue
        if inex != "Off" and (ex_blocks or (ex_elems != [""])) and excluder(blocks26,inex,ex_blocks,ex_elems,key): continue
        weight_index = BLOCKIDXLL.index(blocks26) if isxl else BLOCKID.index(blocks26)

        if useblocks:
            # Index 0 is BASE, which keeps the base ratios.
            if weight_index > 0:
                current_alpha = weights_a[weight_index - 1]
                if usebeta:
                    current_beta = weights_b[weight_index - 1]

        if len(deep) > 0:
            current_alpha = elementals(key,weight_index,deep,randomer,num,lucks,deepprint,current_alpha)

        keyratio.append([key,current_alpha, current_beta])

        if calcmode == "normal":
            if channels_differ:
                # Merge only the vectors the models have in common. Otherwise we get an error due to dimension mismatch.
                theta_0_a = theta_0[key][:, 0:4, :, :]
            else:
                theta_0_a = theta_0[key]

            if MODES[1] in mode:#Add
                caster(f"{num}, {block}, {model_a}+{current_alpha}+*({model_b}-{model_c}),{key}",hear)
                theta_0_a = theta_0_a + current_alpha * theta_1[key]

            elif MODES[2] in mode:#Triple
                caster(f"{num}, {block}, {model_a}+{1-current_alpha-current_beta}+{model_b}*{current_alpha}+ {model_c}*{current_beta}",hear)
                # The lerp form divides by alpha+beta, so it needs a non-zero sum.
                if uselerp and current_alpha + current_beta != 0:
                    theta_0_a =lerp(theta_0_a.to(torch.float32),lerp(theta_1[key].to(torch.float32),theta_2[key].to(torch.float32),current_beta/(current_alpha + current_beta)),current_alpha + current_beta).to(theta_0_a.dtype)
                else:
                    theta_0_a = (1 - current_alpha-current_beta) * theta_0_a + current_alpha * theta_1[key]+current_beta * theta_2[key]

            elif MODES[3] in mode:#Twice
                caster(f"{num}, {block}, {key},{model_a} + {1-current_alpha} + {model_b}*{current_alpha}",hear)
                caster(f"{num}, {block}, {key}({model_a}+{model_b}) +{1-current_beta}+{model_c}*{current_beta}",hear)
                if uselerp:
                    theta_0_a = torch.lerp(torch.lerp(theta_0_a.to(torch.float32), theta_1[key].to(torch.float32), current_alpha), theta_2[key].to(torch.float32), current_beta).to(theta_0_a.dtype)
                else:
                    theta_0_a = (1 - current_alpha) * theta_0_a + current_alpha * theta_1[key]
                    theta_0_a = (1 - current_beta) * theta_0_a + current_beta * theta_2[key]

            else:#Weight
                if current_alpha == 1:
                    caster(f"{num}, {block}, {key} alpha = 1,{model_a}={model_b}",hear)
                    theta_0_a = theta_1[key]
                elif current_alpha !=0:
                    caster(f"{num}, {block}, {key}, {model_a}*{1-current_alpha}+{model_b}*{current_alpha}",hear)
                    if uselerp:
                        theta_0_a = torch.lerp(theta_0_a.to(torch.float32), theta_1[key].to(torch.float32), current_alpha).to(theta_0_a.dtype)
                    else:
                        theta_0_a = (1 - current_alpha) * theta_0_a + current_alpha * theta_1[key]

            if channels_differ:
                theta_0[key][:, 0:4, :, :] = theta_0_a
            else:
                theta_0[key] = theta_0_a

            del theta_0_a, a, b

        elif "cosine" in calcmode:
            if "first_stage_model" in key: continue
            cosine(calcmode,key,sim,sims,current_alpha,theta_0,theta_1,num,block,uselerp)

        elif calcmode == "trainDifference":
            # B and C identical for this key: nothing to add, keep A as it is.
            if torch.allclose(theta_1[key].float(), theta_2[key].float(), rtol=0, atol=0):
                theta_2[key] = theta_0[key]
                continue
            traindiff(key,current_alpha,theta_0,theta_1,theta_2)

        elif calcmode == "smoothAdd":
            caster(f"{num}, {block}, model A[{key}] + {current_alpha} + * (model B - model C)[{key}]", hear)
            # Apply median filter to the weight differences
            filtered_diff = scipy.ndimage.median_filter(theta_1[key].to(torch.float32).cpu().numpy(), size=3)
            # Apply Gaussian filter to the filtered differences
            filtered_diff = scipy.ndimage.gaussian_filter(filtered_diff, sigma=1)
            # The filters run on CPU/numpy; move the result back next to
            # theta_0 so the addition also works with "use cuda".
            theta_1[key] = torch.tensor(filtered_diff).to(theta_0[key].device)
            # Add the filtered differences to the original weights
            theta_0[key] = theta_0[key] + current_alpha * theta_1[key]

        elif calcmode == "smoothAdd MT":
            # Only record the ratio; the filtering runs after the loop.
            key_and_alpha[key] = current_alpha

        elif "tensor" in calcmode:
            dim = theta_0[key].dim()
            if dim == 0 : continue
            tensormerge("2" not in calcmode,key,dim,theta_0,theta_1,current_alpha,current_beta)

        elif "extract" == calcmode:
            theta_0[key] = extract_super(theta_0[key],theta_1[key],theta_2[key],current_alpha,current_beta,opt_value)

        elif calcmode == "self":
            if any(selfkey in key for selfkey in SELFKEYS):continue
            if current_alpha == 0: continue
            theta_0[key] = (theta_0[key].clone()) * current_alpha

        elif calcmode == "plus random":
            if any(selfkey in key for selfkey in SELFKEYS):continue
            if current_alpha == 0: continue
            theta_0[key] += torch.randn_like(theta_0[key].clone()) * current_alpha

        ##### Adjust
        # NOTE(review): FINETUNES is not defined in the visible part of the
        # file - presumably declared further down; confirm.
        if any(item in key for item in FINETUNES) and fine:
            index = FINETUNES.index(key)
            if 5 > index :
                theta_0[key] =theta_0[key]* fine[index]
            else :theta_0[key] =theta_0[key] + torch.tensor(fine[5]).to(theta_0[key].device)

    if calcmode == "smoothAdd MT":
        # setting threads to higher than 8 doesn't significantly affect the time for merging
        threads = cpu_count()
        tasks_per_thread = 8

        theta_0, theta_1, stopped = multithread_smoothadd(key_and_alpha, theta_0, theta_1, threads, tasks_per_thread, hear)
        if stopped:
            return "STOPPED", *NON4

    currentmodel = makemodelname(weights_a,weights_b,model_a, model_b,model_c, base_alpha,base_beta,useblocks,mode,calcmode)

    # Stage 2: copy over keys that exist only in model B.
    for key in tqdm(theta_1.keys(), desc="Stage 2/2"):
        if key in CHCKPOINT_DICT_SKIP_ON_MERGE:
            continue
        if "model" in key and key not in theta_0:
            theta_0.update({key:theta_1[key]})

    del theta_1
    if keep_theta_2:
        del theta_2

    ##### BakeVAE
    bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None)
    if bake_in_vae_filename is not None:
        print(f"Baking in VAE from {bake_in_vae_filename}")
        vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu')

        for key in vae_dict.keys():
            theta_0_key = 'first_stage_model.' + key
            if theta_0_key in theta_0:
                theta_0[theta_0_key] = vae_dict[key]

        del vae_dict

    modelid = rwmergelog(currentmodel,mergedmodel)
    if "save E-list" in lucks["set"]: saveekeys(keyratio,modelid)

    caster(mergedmodel,False)
    if "Reset CLIP ids" in save_sets: resetclip(theta_0)

    # always set metadata. savemodel() will check save_sets later
    merge_recipe = {
        "type": "sd-webui-supermerger",
        "weights_alpha": weights_a if useblocks else None,
        "weights_beta": weights_b if useblocks else None,
        "weights_alpha_orig": weights_a_orig if useblocks else None,
        "weights_beta_orig": weights_b_orig if useblocks else None,
        "model_a": longhashfromname(model_a),
        "model_b": longhashfromname(model_b),
        "model_c": longhashfromname(model_c),
        "base_alpha": base_alpha,
        "base_beta": base_beta,
        "mode": mode,
        "mbw": useblocks,
        "elemental_merge": deep,
        "calcmode" : calcmode,
        f"{inex}":ex_blocks + ex_elems
    }
    metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
    metadata["sd_merge_models"] = {}

    def add_model_metadata(checkpoint_name):
        # Record name and hashes of one source model under its sha256.
        checkpoint_info = sd_models.get_closet_checkpoint_match(checkpoint_name)
        checkpoint_info.calculate_shorthash()
        metadata["sd_merge_models"][checkpoint_info.sha256] = {
            "name": checkpoint_name,
            "legacy_hash": checkpoint_info.hash
        }

    if model_a:
        add_model_metadata(model_a)
    if model_b:
        add_model_metadata(model_b)
    if model_c:
        add_model_metadata(model_c)

    metadata["sd_merge_models"] = json.dumps(metadata["sd_merge_models"])

    return "",currentmodel,modelid,theta_0,metadata
586
-
587
- ################################################
588
- ##### cosineA/B
589
def precosine(calcmode,theta_0,theta_1):
    """Collect the similarity statistics used by the cosine calcmodes.

    Args:
        calcmode: truthy selects the variant that L2-normalises both tensors
            along dim 0 before comparing them (cosineA); falsy selects the
            variant that averages the cosine similarity with a dot-product
            based similarity of the flattened tensors.
        theta_0, theta_1: state dicts of model A and model B.

    Returns:
        ``(sim, sims)`` - the ``CosineSimilarity(dim=0)`` module and a float64
        array of the collected similarities with NaNs removed and values below
        the 1st / above the 99th percentile dropped.
        When a stop was requested, ``("STOPPED", None)`` is returned. It has
        two items so that a caller unpacking ``sim, sims`` does not fail (the
        former 5-item tuple raised ValueError there).
    """
    if stopmerge:
        return "STOPPED", None

    sim = torch.nn.CosineSimilarity(dim=0)
    collected = []
    for key in tqdm(theta_0.keys(), desc="Stage 0/2"):
        # skip VAE model parameters to get better results
        if "first_stage_model" in key: continue
        if not ("model" in key and key in theta_1): continue

        t0 = theta_0[key].to(torch.float32)
        t1 = theta_1[key].to(torch.float32)
        if calcmode:  # favors modelA's structure with details from B
            score = sim(nn.functional.normalize(t0, p=2, dim=0),
                        nn.functional.normalize(t1, p=2, dim=0))
        else:  # favors modelB's structure with details from A
            simab = sim(t0, t1)
            dot_product = torch.dot(t0.reshape(-1), t1.reshape(-1))
            magnitude_similarity = dot_product / (torch.norm(t0) * torch.norm(t1))
            score = (simab + magnitude_similarity) / 2.0
        # Gather the pieces and join once; np.append inside the loop re-copied
        # the whole array on every key.
        collected.append(np.ravel(score.cpu().numpy()))

    if collected:
        sims = np.concatenate(collected).astype(np.float64)
    else:
        sims = np.array([], dtype=np.float64)

    sims = sims[~np.isnan(sims)]
    # Trim outliers: first the low tail, then the high tail of what is left.
    sims = sims[sims >= np.percentile(sims, 1, method='midpoint')]
    sims = sims[sims <= np.percentile(sims, 99, method='midpoint')]
    return sim, sims
622
-
623
def cosine(mode,key,sim,sims,current_alpha,theta_0,theta_1,num,block,uselerp):
    """Blend ``theta_0[key]`` with ``theta_1[key]`` using a similarity-derived ratio.

    The ratio ``k`` is the combined similarity of the two tensors rescaled by
    the min/max of ``sims``, reduced by the alpha value and clipped to [0, 1].
    ``theta_0[key]`` is replaced in place; nothing is returned. Keys without
    "model" in their name, or missing from ``theta_0``, are left untouched.
    """
    if not ("model" in key and key in theta_0):
        return

    normalized = "A" in mode
    if normalized:  # favors modelA's structure with details from B
        # Normalize the vectors before merging
        vec_0 = nn.functional.normalize(theta_0[key].to(torch.float32), p=2, dim=0)
        vec_1 = nn.functional.normalize(theta_1[key].to(torch.float32), p=2, dim=0)
        simab = sim(vec_0, vec_1)
        dot_product = torch.dot(vec_0.view(-1), vec_1.view(-1))
    else:  # favors modelB's structure with details from A
        vec_0 = theta_0[key].to(torch.float32)
        vec_1 = theta_1[key].to(torch.float32)
        simab = sim(vec_0, vec_1)
        dot_product = torch.dot(theta_0[key].view(-1).to(torch.float32), theta_1[key].view(-1).to(torch.float32))

    magnitude_similarity = dot_product / (torch.norm(vec_0) * torch.norm(vec_1))
    combined_similarity = (simab + magnitude_similarity) / 2.0

    k = (combined_similarity - sims.min()) / (sims.max() - sims.min())
    # The normalized variant subtracts |alpha|, the other one alpha as given.
    k = k - (abs(current_alpha) if normalized else current_alpha)
    k = k.clip(min=0,max=1.0)

    if normalized:
        caster(f"{num}, {block}, model A[{key}] {1-k} + (model B)[{key}]*{k}",hear)
    else:
        caster(f"{num}, {block}, model A[{key}] *{1-k} + (model B)[{key}]*{k}",hear)

    if uselerp:
        theta_0[key] = lerp(theta_1[key].to(torch.float32), theta_0[key].to(torch.float32),k).to(theta_0[key].dtype)
    else:
        theta_0[key] = theta_1[key] * (1 - k) + theta_0[key] * k
658
-
659
- ################################################
660
- ##### Traindiff
661
def traindiff(key,current_alpha,theta_0,theta_1,theta_2):
    """Add a rescaled (theta_1 - theta_2) difference to ``theta_0[key]`` in place.

    Per element, the difference is weighted by
    ``|t1 - t0| / (|t1 - t2| + |t1 - t0|)`` (0 where that denominator is 0),
    keeps the sign of ``t1 - t2`` and is multiplied by ``current_alpha * 1.8``.
    All intermediate math is done in float32. Returns None.
    """
    t0 = theta_0[key].float()
    t1 = theta_1[key].float()
    t2 = theta_2[key].float()

    delta = t1 - t2
    dist_to_c = torch.abs(delta)
    dist_to_a = torch.abs(t1 - t0)
    total = dist_to_c + dist_to_a

    # Where both distances are zero the ratio is undefined; use 0 there.
    ratio = torch.where(total != 0, dist_to_a / total, torch.tensor(0.).float())
    signed_ratio = torch.sign(delta) * torch.abs(ratio)

    theta_0[key] = theta_0[key] + ((signed_ratio * dist_to_c) * (current_alpha*1.8))
676
-
677
- ################################################
678
- ##### Extract
679
- # fix for python 3.9.16
680
- from typing import Union, Optional
681
def extract_super(base: Optional[Tensor], a: Tensor, b: Tensor, alpha: float, beta: float, gamma: float) -> Tensor:
    """Combine ``a`` and ``b`` relative to ``base``, weighted by their similarity.

    Both inputs are taken relative to ``base`` (0 when ``base`` is None). The
    cosine similarity along the last dimension is mapped to
    ``d = ((c + 1) / 2) ** gamma``; the result is
    ``base + lerp(a, b, alpha) * lerp(d, 1 - d, beta)`` in the dtype of
    ``base`` (or of ``a`` when ``base`` is None).

    Raises:
        AssertionError: on mismatching shapes, alpha/beta outside [0, 1] or
            negative gamma.
    """
    # Optional[...] is used instead of "Tensor | None" for python 3.9.16.
    assert base is None or base.shape == a.shape
    assert a.shape == b.shape
    assert 0 <= alpha <= 1
    assert 0 <= beta <= 1
    assert 0 <= gamma

    if base is None:
        out_dtype = a.dtype
        origin = 0
    else:
        out_dtype = base.dtype
        origin = base.float()

    rel_a = a.float() - origin
    rel_b = b.float() - origin
    similarity = cosine_similarity(rel_a, rel_b, -1).clamp(-1, 1).unsqueeze(-1)
    weight = ((similarity + 1) / 2) ** gamma
    blended = lerp(rel_a, rel_b, alpha) * lerp(weight, 1 - weight, beta)
    return (origin + blended).to(out_dtype)
696
-
697
def extract(a: Tensor, b: Tensor, p: float, smoothness: float) -> Tensor:
    """Scale ``a`` by a mask derived from its cosine similarity to ``b``.

    The similarity along the last dimension goes through ``relu`` (when
    ``smoothness`` is 0) or ``softplus`` with ``beta = 1 / smoothness``, is
    repeated across the last dimension, and is then interpolated towards its
    complement by ``p``. Returns ``a`` multiplied by that mask.

    Raises:
        AssertionError: on mismatching shapes or p/smoothness outside [0, 1].
    """
    assert a.shape == b.shape
    assert 0 <= p <= 1
    assert 0 <= smoothness <= 1

    if smoothness == 0:
        activation = relu
    else:
        activation = partial(softplus, beta=1 / smoothness)

    similarity = activation(cosine_similarity(a, b, dim=-1))
    mask = similarity.unsqueeze(dim=-1).repeat_interleave(b.shape[-1], -1)
    mask = torch.lerp(mask, torch.ones_like(mask) - mask, p)
    return a * mask
706
-
707
- ################################################
708
- ##### Tensor Merge
709
def tensormerge(mode, key, dim, theta_0, theta_1, current_alpha, current_beta):
    """Replace a contiguous slice of ``theta_0[key]`` with data from ``theta_1[key]``.

    The slice bounds are fractions of one axis length derived from
    ``current_alpha`` and ``current_beta``.

    * ``mode`` truthy: the slice always runs along axis 0.
    * ``mode`` falsy: 1-D tensors are sliced along axis 0; higher ranks are
      sliced along axis 1, with bounds scaled by ``shape[1]`` only when
      ``shape[1] > 100`` (otherwise by ``shape[0]``).

    When ``alpha + beta <= 1`` the window ``[beta, alpha + beta)`` of
    ``theta_0[key]`` is overwritten in place with ``theta_1``'s values.
    Otherwise ``theta_0[key]`` is rebound to a clone of ``theta_1[key]`` whose
    window ``[alpha + beta - 1, beta)`` keeps ``theta_0``'s values.

    ``dim`` is the tensor rank passed by the caller; only 1..4 trigger a copy.
    Nothing is returned; ``theta_0`` is mutated.
    """
    total = current_alpha + current_beta
    target = theta_0[key]

    # Which axis is sliced, and which axis length scales the bounds.
    axis = 0
    extent = target.shape[0]
    if not mode and dim > 1:
        axis = 1
        if target.shape[1] > 100:
            extent = target.shape[1]

    if total <= 1:
        start = int(extent * current_beta)
        stop = int(extent * total)
    else:
        start = int(extent * (total - 1))
        stop = int(extent * current_beta)

    # Build the same explicit index the per-rank branches used, e.g. [:, s:e, :].
    window = None
    if dim == 1 or dim == 2 or dim == 3 or dim == 4:
        rank = int(dim)
        window = (
            (slice(None),) * axis
            + (slice(start, stop),)
            + (slice(None),) * (rank - axis - 1)
        )

    if total <= 1:
        if window is not None:
            target[window] = theta_1[key][window].clone()
    else:
        merged = theta_1[key].clone()
        if window is not None:
            merged[window] = target[window].clone()
        theta_0[key] = merged
783
-
784
- ################################################
785
- ##### Multi Thread SmoothAdd
786
-
787
def multithread_smoothadd(key_and_alpha, theta_0, theta_1, threads, tasks_per_thread, hear):
    """Smooth ``theta_1`` and add it, scaled, onto ``theta_0`` using a thread pool.

    For every key in ``key_and_alpha`` the tensor ``theta_1[key]`` is passed
    through a median filter (size 3) and a Gaussian filter (sigma 1), stored
    back into ``theta_1``, and ``theta_0[key]`` becomes
    ``theta_0[key] + key_and_alpha[key] * theta_1[key]``.

    Args:
        key_and_alpha: Mapping of state-dict key to its multiplier.
        theta_0, theta_1: Mappings of key to tensor; both are mutated.
        threads: ``max_workers`` for the executor.
        tasks_per_thread: Number of keys handed to each submitted task.
        hear: Forwarded to ``caster`` to control per-key messages.

    Returns:
        ``(theta_0, theta_1, stopped)`` where ``stopped`` is ``True`` when a
        task saw the module-level ``stopmerge`` flag set.
    """
    guard_theta_0 = Lock()
    guard_theta_1 = Lock()
    guard_progress = Lock()

    def thread_callback(batch):
        # A False result tells the submitting loop that the merge was stopped.
        if stopmerge:
            return False

        for name in batch:
            caster(f"model A[{name}] + {key_and_alpha[name]} + * (model B - model C)[{name}]", hear)
            smoothed = theta_1[name].to(torch.float32).cpu().numpy()
            smoothed = scipy.ndimage.median_filter(smoothed, size=3)
            smoothed = scipy.ndimage.gaussian_filter(smoothed, sigma=1)
            with guard_theta_1:
                theta_1[name] = torch.tensor(smoothed)
            with guard_theta_0:
                theta_0[name] = theta_0[name] + key_and_alpha[name] * theta_1[name]

        with guard_progress:
            progress.update(len(batch))

        return True

    pending = list(key_and_alpha.keys())

    total_threads = ceil(len(pending) / int(tasks_per_thread))
    print(f"max threads = {threads}, total threads = {total_threads}, tasks per thread = {tasks_per_thread}")

    progress = tqdm(key_and_alpha.keys(), desc="smoothAdd MT")

    with ThreadPoolExecutor(max_workers=threads) as executor:
        futures = []
        for _ in range(total_threads):
            # Peel the next batch of keys off the front of the pending list.
            size = int(tasks_per_thread)
            batch, pending = pending[:size], pending[size:]
            futures.append(executor.submit(thread_callback, batch))
        for finished in as_completed(futures):
            if not finished.result():
                executor.shutdown()
                return theta_0, theta_1, True
    del progress

    return theta_0, theta_1, False
834
-
835
- ################################################
836
- ##### Elementals
837
def elementals(key,weight_index,deep,randomer,num,lucks,deepprint,current_alpha):
    """Apply elemental-merge rules to one key and return the resulting ratio.

    Each entry of ``deep`` has the form ``blocks:words:ratio``; entries without
    exactly two colons are ignored. ``blocks`` and ``words`` are space-separated
    terms, optionally prefixed with ``NOT`` to invert the match. A rule applies
    when both the block test and the word test succeed against
    ``key + BLOCKID[weight_index]``; the last applicable rule wins.

    Returns:
        ``current_alpha`` unchanged when no rule applies, otherwise the ratio
        produced by ``eratiodealer`` for the last matching rule.
    """
    haystack = key + BLOCKID[weight_index]
    for rule in deep:
        if rule.count(":") != 2:
            continue
        block_spec, word_spec, ratio = rule.split(":")

        block_terms = blocker(block_spec, BLOCKID).split(" ")
        word_terms = word_spec.split(" ")

        negate_blocks = block_terms[0] == "NOT"
        if negate_blocks:
            block_terms = block_terms[1:]
        negate_words = word_terms[0] == "NOT"
        if negate_words:
            word_terms = word_terms[1:]

        # A test passes when "some term matched" differs from "NOT was given".
        block_hit = any(term in haystack for term in block_terms)
        if block_hit == negate_blocks:
            continue
        word_hit = any(term in haystack for term in word_terms)
        if word_hit == negate_words:
            continue

        ratio = eratiodealer(ratio, randomer, weight_index, num, lucks)
        if deepprint:
            print(" ", block_terms, word_terms, key, ratio)
        current_alpha = ratio
    return current_alpha
860
-
861
def forkforker(filename,device):
    """Read a checkpoint state dict, preferring to map it onto ``device``.

    Falls back to calling ``sd_models.read_state_dict`` without
    ``map_location`` when the first attempt raises.

    Only ``Exception`` subclasses trigger the fallback, so
    ``KeyboardInterrupt`` and ``SystemExit`` propagate instead of silently
    starting a second load.
    """
    try:
        return sd_models.read_state_dict(filename, map_location=device)
    except Exception:
        return sd_models.read_state_dict(filename)
866
-
867
- ################################################
868
- ##### Load Model
869
-
870
def load_model_weights_m(model,abc,cachetarget,device):
    """Return the state dict for ``model``, using the module-level ``modelcache``.

    Args:
        model: Checkpoint name resolved via ``sd_models.get_closet_checkpoint_match``.
        abc: Threshold compared with ``orig_cache``; the loaded weights are
            cached only when ``orig_cache >= abc``.
        cachetarget: Checkpoint names whose cache entries must not be evicted.
        device: Device the returned tensors are loaded or moved onto.

    Returns:
        A mapping of key to tensor. Cache hits are returned as a new dict with
        every tensor moved to ``device``; cached copies are kept on the CPU.
    """
    checkpoint_info = sd_models.get_closet_checkpoint_match(model)
    sd_model_name = checkpoint_info.model_name

    if checkpoint_info in modelcache:
        print(f"Loading weights [{sd_model_name}] from cache")
        return {k: v.to(device) for k, v in modelcache[checkpoint_info].items()}

    print(f"Loading weights [{sd_model_name}] from file")
    state_dict = forkforker(checkpoint_info.filename, device)
    if orig_cache >= abc:
        modelcache[checkpoint_info] = {k: v.to("cpu") for k, v in state_dict.items()}
        protected = [sd_models.get_closet_checkpoint_match(name) for name in cachetarget]
        while len(modelcache) > orig_cache:
            # Evict the oldest entry that is not needed by the current merge.
            victim = next((cached for cached in modelcache if cached not in protected), None)
            if victim is None:
                # Every remaining entry is protected; stopping here avoids
                # spinning forever when the cache limit cannot be met.
                break
            modelcache.pop(victim)
    return state_dict
892
-
893
def makemodelname(weights_a,weights_b,model_a, model_b,model_c, alpha,beta,useblocks,mode,calc):
    """Build the human-readable name that describes a merge recipe.

    Model names are first shortened with ``filenamecutter``; ``alpha`` and
    ``beta`` are converted to ``float`` when given as strings. The formula text
    depends on which ``MODES`` entry is contained in ``mode`` and on whether
    block weights (``useblocks``) are in use. A ``calc`` other than
    ``"normal"`` is appended as a suffix, plus the beta value for ``"tensor"``.
    """
    model_a = filenamecutter(model_a)
    model_b = filenamecutter(model_b)
    model_c = filenamecutter(model_c)

    if type(alpha) == str:
        alpha = float(alpha)
    if type(beta) == str:
        beta = float(beta)

    def r3(value):
        return str(round(value, 3))

    def joined(weights):
        return ','.join(str(s) for s in weights)

    if MODES[1] in mode:
        kind = "add"
    elif MODES[2] in mode:
        kind = "triple"
    elif MODES[3] in mode:
        kind = "twice"
    else:
        kind = "sum"

    if useblocks and kind == "add":
        currentmodel = f"{model_a} + ({model_b} - {model_c}) x alpha ({r3(alpha)},{joined(weights_a)})"
    elif useblocks and kind == "triple":
        currentmodel = f"{model_a} x (1-alpha-beta) + {model_b} x alpha + {model_c} x beta (alpha = {r3(alpha)},{joined(weights_a)},beta = {beta},{joined(weights_b)})"
    elif useblocks and kind == "twice":
        currentmodel = f"({model_a} x (1-alpha) + {model_b} x alpha)x(1-beta)+ {model_c} x beta ({r3(alpha)},{joined(weights_a)})_({r3(beta)},{joined(weights_b)})"
    elif useblocks:
        currentmodel = f"{model_a} x (1-alpha) + {model_b} x alpha ({r3(alpha)},{joined(weights_a)})"
    elif kind == "add":
        currentmodel = f"{model_a} + ({model_b} - {model_c}) x {r3(alpha)}"
    elif kind == "triple":
        currentmodel = f"{model_a} x {r3(1-alpha-beta)} + {model_b} x {r3(alpha)} + {model_c} x {r3(beta)}"
    elif kind == "twice":
        currentmodel = f"({model_a} x {r3(1-alpha)} +{model_b} x {r3(alpha)}) x {r3(1-beta)} + {model_c} x {r3(beta)}"
    else:
        currentmodel = f"{model_a} x {r3(1-alpha)} + {model_b} x {r3(alpha)}"

    if calc != "normal":
        currentmodel = currentmodel + "_" + calc
        if calc == "tensor":
            currentmodel = currentmodel + f"_beta_{beta}"
    return currentmodel
924
-
925
# Directory returned by scripts.basedir(); rwmergelog joins it with "mergehistory.csv".
- path_root = scripts.basedir()
926
-
927
-
928
- ################################################
929
- ##### Logging
930
-
931
def rwmergelog(mergedname = "",settings= [],id = 0):
    """Append a merge to, or read a row from, ``mergehistory.csv`` under ``path_root``.

    The file is created with a header row when it does not exist yet.

    Args:
        mergedname: When non-empty, a new row is appended and its ID returned.
        settings: Merge settings for the new row (never mutated; a copy is
            used). Entry 7, the mode, is expanded to its long form for
            compatibility with older logs.
        id: Row index to read when ``mergedname`` is empty.

    Returns:
        The new merge ID (``int``) when writing; otherwise the stored row as a
        list of strings, or the string ``"ERROR: OUT of ID index"`` when ``id``
        does not name a row.
    """
    # Kept for compatibility: short mode names are stored in their long form.
    mode_info = {
        "Weight sum": "Weight sum:A*(1-alpha)+B*alpha",
        "Add difference": "Add difference:A+(B-C)*alpha",
        "Triple sum": "Triple sum:A*(1-alpha-beta)+B*alpha+C*beta",
        "sum Twice": "sum Twice:(A*(1-alpha)+B*alpha)*(1-beta)+C*beta",
    }
    setting = settings.copy()
    if len(setting) > 7 and setting[7] in mode_info:
        setting[7] = mode_info[setting[7]]

    filepath = os.path.join(path_root, "mergehistory.csv")
    if not os.path.isfile(filepath):
        with open(filepath, 'a') as f:
            # settings layout: 0 weights_a, 1 weights_b, 2 model_a, 3 model_b, 4 model_c,
            # 5 base_alpha, 6 base_beta, 7 mode, 8 useblocks, 9 custom_name,
            # 10 save_sets, 11 id_sets, 12 deep, 13 calcmode
            f.write('"ID","time","name","weights alpha","weights beta","model A","model B","model C","alpha","beta","mode","use MBW","plus lora","custum name","save setting","use ID"\n')

    with open(filepath, 'r+') as f:
        mlist = list(csv.reader(f))
        if mergedname != "":
            # The header occupies row 0, so the row count is the next free ID.
            mergeid = len(mlist)
            timestamp = datetime.datetime.now().strftime('%Y.%m.%d %H.%M.%S')
            row = [str(mergeid), timestamp, mergedname] + [str(x) for x in setting]
            # csv.writer quotes fields containing commas or newlines and doubles
            # embedded quotes, so every row reads back intact.
            csv.writer(f, lineterminator="\n").writerow(row)
            return mergeid

    try:
        return mlist[int(id)]
    except (IndexError, ValueError, TypeError):
        return "ERROR: OUT of ID index"
965
-
966
def saveekeys(keyratio,modelid):
    """Write ``keyratio`` (an iterable of rows) to ``<modelid>.csv`` in the data folder.

    The folder is ``extensions/sd-webui-supermerger/scripts/data`` below
    ``scripts.basedir()`` and is created when missing.
    """
    import csv
    path_root = scripts.basedir()
    dir_path = os.path.join(path_root,"extensions","sd-webui-supermerger","scripts", "data")

    # exist_ok avoids the check-then-create race between concurrent writers.
    os.makedirs(dir_path, exist_ok=True)

    filepath = os.path.join(dir_path,f"{modelid}.csv")

    with open(filepath, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerows(keyratio)
979
-
980
def savestatics(modelid):
    """Save every table of the module-level ``statistics`` mapping via ``saveekeys``.

    Each table becomes one CSV named ``<modelid>_<table key>``; every row is
    the inner key followed by that key's values.
    """
    for section in statistics.keys():
        table = statistics[section]
        rows = []
        for name in table.keys():
            rows.append([name] + list(table[name]))
        saveekeys(rows, f"{modelid}_{section}")
984
-
985
def get_font(fontsize):
    """Load a TrueType font at ``fontsize``.

    Uses ``opts.font`` when set, otherwise the ``Roboto-Regular.ttf`` file in
    ``scriptpath``; that file is also the fallback when loading fails.
    """
    bundled = os.path.join(scriptpath, "Roboto-Regular.ttf")
    try:
        preferred = opts.font or bundled
        return ImageFont.truetype(preferred, fontsize)
    except Exception:
        return ImageFont.truetype(bundled, fontsize)
991
-
992
def draw_origin(grid, text,width,height,width_one):
    """Return a copy of ``grid`` with ``text`` drawn in black at the top-left corner.

    Args:
        grid: Source PIL image; it is pasted onto a new white RGB canvas.
        text: Label to draw (may contain newlines).
        width, height: Used to choose the initial font size, ``(width + height) // 25``.
        width_one: Width of a single image. When ``grid`` is wider than this,
            the font is shrunk until the text fits in 75% of ``width_one``.

    Returns:
        The new annotated image; ``grid`` itself is not modified.
    """
    grid_d = Image.new("RGB", (grid.width, grid.height), "white")
    grid_d.paste(grid, (0, 0))

    d = ImageDraw.Draw(grid_d)
    color_active = (0, 0, 0)
    # Never request a zero-sized font for very small images.
    fontsize = max((width + height) // 25, 1)
    fnt = get_font(fontsize)

    def text_width(font):
        # multiline_textsize was removed in Pillow 10; prefer the bbox API.
        if hasattr(d, "multiline_textbbox"):
            left, _top, right, _bottom = d.multiline_textbbox((0, 0), text, font=font)
            return right - left
        return d.multiline_textsize(text, font=font)[0]

    if grid.width != width_one:
        # Stop at size 1 so get_font is never asked for size 0.
        while text_width(fnt) > width_one * 0.75 and fontsize > 1:
            fontsize -= 1
            fnt = get_font(fontsize)
    d.multiline_text((0, 0), text, font=fnt, fill=color_active, align="center")
    return grid_d
1007
-
1008
def wpreseter(w,presets):
    """Resolve a preset name to its weight string.

    ``w`` is returned unchanged when it is empty or already contains a comma.
    Otherwise ``presets`` is parsed line by line as ``name:weights`` and/or
    ``name<TAB>weights`` entries; when the stripped ``w`` names an entry, that
    entry's weights are returned and a message is printed.
    """
    if w == "" or "," in w:
        return w

    table = {}
    for line in presets.splitlines():
        if ":" in line:
            label, _, body = line.partition(":")
            table[label.strip()] = body
        if "\t" in line:
            label, _, body = line.partition("\t")
            table[label.strip()] = body

    lookup = w.strip()
    if lookup in table:
        name = w
        w = table[lookup]
        print(f"weights {name} imported from presets : {w}")
    return w
1024
-
1025
def fullpathfromname(name):
    """Return the checkpoint file path for ``name``, or ``""`` for an empty name.

    The guard checks the ``name`` argument, so an empty string or empty list
    returns ``""`` before any checkpoint lookup is attempted.
    """
    if name == "" or name == []:
        return ""
    checkpoint_info = sd_models.get_closet_checkpoint_match(name)
    return checkpoint_info.filename
1029
-
1030
def namefromhash(hash):
    """Return the model name of the checkpoint matching ``hash`` (``""`` when empty)."""
    is_blank = hash == "" or hash == []
    if is_blank:
        return ""
    return sd_models.get_closet_checkpoint_match(hash).model_name
1034
-
1035
def hashfromname(name):
    """Return the short hash of the checkpoint matching ``name`` (``""`` when empty).

    The stored ``shorthash`` is used when present; otherwise it is computed
    with ``calculate_shorthash()``.
    """
    from modules import sd_models
    if name == "" or name == []:
        return ""
    info = sd_models.get_closet_checkpoint_match(name)
    known = info.shorthash
    if known is None:
        return info.calculate_shorthash()
    return known
1042
-
1043
def longhashfromname(name):
    """Return the ``sha256`` attribute of the checkpoint matching ``name`` (``""`` when empty).

    When ``sha256`` is not set yet, ``calculate_shorthash()`` is called first
    and the attribute is read again afterwards.
    """
    from modules import sd_models
    if name == "" or name == []:
        return ""
    info = sd_models.get_closet_checkpoint_match(name)
    if info.sha256 is None:
        info.calculate_shorthash()
    return info.sha256
1051
-
1052
-
1053
- ################################################
1054
- ##### Random
1055
-
1056
# Marker characters for randomized ratios; eratiodealer treats a ratio containing any of them as random.
- RANCHA = ["R","U","X"]
1057
-
1058
def randdealer(w:str,randomer,ab,lucks,deep):
    """Replace random markers in a comma-separated weight string with numbers.

    Per entry of ``w`` (position ``i``, random value ``randomer[i + RANDMAP[ab]]``):

    * ``R`` -> the random value itself
    * ``U`` -> ``-2 * value + 1.5``
    * ``X`` -> a value between ``lucks["upp"][i]`` and ``lucks["low"][i]``
    * an entry containing ``E`` -> ``"0"``, and ``BLOCKID[i]`` is queued for an
      elemental rule keyed by the entry with ``E`` removed
    * anything else is kept as is

    Generated numbers are rounded to ``lucks["round"]`` digits.

    Returns:
        ``(weights, deep)`` where ``weights`` is the rewritten string and
        ``deep`` has the queued ``blocks::marker`` rules appended.
    """
    upper = lucks["upp"].split(",")
    lower = lucks["low"].split(",")
    resolved = []
    elemental = {"R": [], "U": [], "X": []}
    offset = RANDMAP[ab]

    for pos, raw in enumerate(w.split(",")):
        token = raw.strip()
        if token == "R":
            value = round(randomer[pos + offset], lucks["round"])
            resolved.append(str(value))
        elif token == "U":
            value = round(-2 * randomer[pos + offset] + 1.5, lucks["round"])
            resolved.append(str(value))
        elif token == "X":
            span = float(lower[pos]) - float(upper[pos])
            value = round(span * randomer[pos + offset] + float(upper[pos]), lucks["round"])
            resolved.append(str(value))
        elif "E" in raw:
            marker = token.replace("E", "")
            elemental[marker].append(BLOCKID[pos])
            resolved.append("0")
        else:
            resolved.append(raw)

    for marker, blocks in elemental.items():
        if blocks != []:
            rule = f"{' '.join(blocks)}::{marker}"
            deep = deep + "," + rule if deep else rule
    return ",".join(resolved), deep
1081
-
1082
def eratiodealer(dr,randomer,block,num,lucks):
    """Convert an elemental ratio string into a number.

    Args:
        dr: Ratio text; either a plain number or one of the markers ``R``,
            ``U`` or ``X`` (surrounding whitespace allowed).
        randomer: Sequence of random values, indexed with ``num + RANDMAP[2]``.
        block: Index into the comma-separated ``lucks["upp"]`` / ``lucks["low"]``.
        num: Position of the rule, used to pick the random value.
        lucks: Settings with keys ``"upp"``, ``"low"`` and ``"round"``.

    Returns:
        ``float(dr)`` for plain numbers, otherwise the generated value rounded
        to ``lucks["round"]`` digits.

    Raises:
        ValueError: if ``dr`` contains a marker character but is not exactly
            one marker (such input used to return ``None`` implicitly).
    """
    if not any(element in dr for element in RANCHA):
        return float(dr)

    up, low = lucks["upp"], lucks["low"]
    up, low = (up.split(","), low.split(","))
    add = RANDMAP[2]
    token = dr.strip()

    if token == "R":
        return round(randomer[num + add], lucks["round"])
    if token == "U":
        # NOTE(review): randdealer uses "+ 1.5" for the same marker; confirm
        # which offset is intended before changing either.
        return round(-2 * randomer[num + add] + 1, lucks["round"])
    if token == "X":
        return round((float(low[block]) - float(up[block])) * randomer[num + add] + float(up[block]), lucks["round"])

    raise ValueError(f"invalid elemental ratio: {dr!r}")
1095
-
1096
-
1097
- ################################################
1098
- ##### Generate Image
1099
-
1100
def simggen(s_prompt,s_nprompt,s_steps,s_sampler,s_cfg,s_seed,s_w,s_h,s_batch_size,
            genoptions,s_hrupscaler,s_hr2ndsteps,s_denois_str,s_hr_scale,
            mergeinfo,id_sets,modelid,
            *txt2imgparams,
            debug = False
            ):
    """Generate images with the currently loaded model and save them.

    Generation settings are taken from ``txt2imgparams`` (looked up by the
    component names in ``paramsnames``); truthy ``s_*`` arguments and the
    entries of ``genoptions`` override them.

    Args:
        s_prompt .. s_batch_size: Optional overrides for the txt2img settings.
        genoptions: Collection that may contain ``"Hires. fix"``,
            ``"Restore faces"`` and ``"Tiling"``.
        s_hrupscaler, s_hr2ndsteps, s_denois_str, s_hr_scale: Hires overrides,
            applied only when ``"Hires. fix"`` is in ``genoptions``.
        mergeinfo: Text written in place of the ``Model:`` infotext entry.
        id_sets: ``"image"`` draws ``modelid`` onto each image; ``"PNG info"``
            appends it to ``mergeinfo``.
        modelid: Identifier of the merged model.
        *txt2imgparams: Values of the txt2img UI components.
        debug: Print ``paramsnames`` when true.

    Returns:
        ``(images, infotext, info_html, comments_html, p)``.

    ``shared.state.end()`` is called even when generation raises.
    """
    shared.state.begin()
    try:
        from scripts.mergers.components import paramsnames
        if debug: print(paramsnames)

        #[None, 'Prompt', 'Negative prompt', 'Styles', 'Sampling steps', 'Sampling method', 'Batch count', 'Batch size', 'CFG Scale',
        # 'Height', 'Width', 'Hires. fix', 'Denoising strength', 'Upscale by', 'Upscaler', 'Hires steps', 'Resize width to', 'Resize height to',
        # 'Hires checkpoint', 'Hires sampling method', 'Hires prompt', 'Hires negative prompt', 'Override settings', 'Script', 'Refiner',
        # 'Checkpoint', 'Switch at', 'Seed', 'Extra', 'Variation seed', 'Variation strength', 'Resize seed from width', 'Resize seed from height', '', 'Active', 'Active', 'X Types', 'X Values', 'Y Types', 'Y Values']

        def g(wanted,wantedv=None):
            # Look a UI value up by component name, trying an alternative name second.
            if wanted in paramsnames:return txt2imgparams[paramsnames.index(wanted)]
            elif wantedv and wantedv in paramsnames:return txt2imgparams[paramsnames.index(wantedv)]
            else:return None

        sampler_index = g("Sampling method")
        if type(sampler_index) is str:
            sampler_name = sampler_index
        else:
            sampler_name = sd_samplers.samplers[sampler_index].name

        hr_sampler_index = g("Hires sampling method")
        if hr_sampler_index is None: hr_sampler_index = 0
        # Decide on the hires value's own type (the sampler's type was checked here before).
        if type(hr_sampler_index) is str:
            hr_sampler_name = hr_sampler_index
        else:
            hr_sampler_name = "Use same sampler" if hr_sampler_index == 0 else sd_samplers.samplers[hr_sampler_index+1].name

        p = processing.StableDiffusionProcessingTxt2Img(
            sd_model=shared.sd_model,
            outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
            outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids,
            prompt=g("Prompt"),
            styles=g("Styles"),
            negative_prompt=g('Negative prompt'),
            seed=g("Seed"),
            subseed=g("Variation seed"),
            subseed_strength=g("Variation strength"),
            seed_resize_from_h=g("Resize seed from height"),
            seed_resize_from_w=g("Resize seed from width"),
            seed_enable_extras=g("Extra"),
            sampler_name=sampler_name,
            batch_size=g("Batch size"),
            n_iter=g("Batch count"),
            steps=g("Sampling steps"),
            cfg_scale=g("CFG Scale"),
            width=g("Width"),
            height=g("Height"),
            restore_faces=g("Restore faces","Face restore"),
            tiling=g("Tiling"),
            enable_hr=g("Hires. fix","Second pass"),
            hr_scale=g("Upscale by"),
            hr_upscaler=g("Upscaler"),
            hr_second_pass_steps=g("Hires steps","Secondary steps"),
            hr_resize_x=g("Resize width to"),
            hr_resize_y=g("Resize height to"),
            override_settings=create_override_settings_dict(g("Override settings")),
            do_not_save_grid=True,
            do_not_save_samples=True,
            do_not_reload_embeddings=True,
        )
        p.hr_checkpoint_name=None if g("Hires checkpoint") == 'Use same checkpoint' else g("Hires checkpoint")
        p.hr_sampler_name=None if hr_sampler_name == 'Use same sampler' else hr_sampler_name

        if s_sampler is None: s_sampler = 0

        if s_batch_size != 1 :p.batch_size = int(s_batch_size)
        if s_prompt: p.prompt = s_prompt
        if s_nprompt: p.negative_prompt = s_nprompt
        if s_steps: p.steps = s_steps
        # NOTE(review): this re-assigns the UI sampler name, not s_sampler itself; confirm intent.
        if s_sampler: p.sampler_name = sampler_name
        if s_cfg: p.cfg_scale = s_cfg
        if s_seed: p.seed = s_seed
        if s_w: p.width = s_w
        if s_h: p.height = s_h

        if not p.cfg_scale: p.cfg_scale = 7

        p.scripts = scripts.scripts_txt2img
        p.script_args = txt2imgparams[paramsnames.index("Override settings")+1:]

        p.denoising_strength=g("Denoising strength") if p.enable_hr else None

        p.hr_prompt=g("Hires prompt","Secondary Prompt")
        p.hr_negative_prompt=g("Hires negative prompt","Secondary negative prompt")

        if "Hires. fix" in genoptions:
            p.enable_hr = True
            if s_hrupscaler: p.hr_upscaler = s_hrupscaler
            if s_hr2ndsteps:p.hr_second_pass_steps = s_hr2ndsteps
            if s_denois_str:p.denoising_strength = s_denois_str
            if s_hr_scale:p.hr_scale = s_hr_scale

        if "Restore faces" in genoptions:
            p.restore_faces = True

        if "Tiling" in genoptions:
            p.tiling = True

        p.cached_c = [None,None]
        p.cached_uc = [None,None]

        p.cached_hr_c = [None, None]
        p.cached_hr_uc = [None, None]

        if type(p.prompt) == list:
            p.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, p.styles) for x in p.prompt]
        else:
            p.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(p.prompt, p.styles)]

        if type(p.negative_prompt) == list:
            p.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, p.styles) for x in p.negative_prompt]
        else:
            p.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)]

        processed:Processed = processing.process_images(p)
        if "image" in id_sets:
            for i, image in enumerate(processed.images):
                processed.images[i] = draw_origin(image, str(modelid),p.width,p.height,p.width)

        if "PNG info" in id_sets:mergeinfo = mergeinfo + " ID " + str(modelid)

        infotext = create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds)
        # Keep only the part before the last "Steps" when the infotext has several.
        if infotext.count("Steps: ")>1:
            infotext = infotext[:infotext.rindex("Steps")]

        # Replace the "Model:" entry with the merge description.
        infotexts = infotext.split(",")
        for i,x in enumerate(infotexts):
            if "Model:"in x:
                infotexts[i] = " Model: "+mergeinfo.replace(","," ")
        infotext= ",".join(infotexts)

        for i, image in enumerate(processed.images):
            images.save_image(image, opts.outdir_txt2img_samples, "",p.seed, p.prompt,shared.opts.samples_format, p=p,info=infotext)

        if s_batch_size > 1:
            grid = images.image_grid(processed.images, p.batch_size)
            processed.images.insert(0, grid)
            images.save_image(grid, opts.outdir_txt2img_grids, "grid", p.seed, p.prompt, opts.grid_format, info=infotext, short_filename=not opts.grid_extended_filename, p=p, grid=True)
        return processed.images,infotext,plaintext_to_html(processed.info), plaintext_to_html(processed.comments),p
    finally:
        shared.state.end()
1246
-
1247
-
1248
- ################################################
1249
- ##### Block Ids
1250
-
1251
def blocker(blocks,blockids):
    """Expand ``A-B`` ranges in a space-separated block list.

    Each token containing ``-`` is replaced by every entry of ``blockids``
    between its two endpoints (inclusive, in ``blockids`` order, whichever
    endpoint comes first). Other tokens are passed through unchanged.

    Returns:
        The expanded tokens joined with single spaces.

    Raises:
        ValueError: if a range endpoint is not in ``blockids``.
    """
    pieces = []

    def emit(token):
        # Empty tokens are dropped only while nothing has been emitted yet.
        if pieces or token:
            pieces.append(token)

    for token in blocks.split(" "):
        if "-" not in token:
            emit(token)
            continue
        ends = [part.strip() for part in token.split('-')]
        last = blockids.index(ends[1])
        first = blockids.index(ends[0])
        if last > first:
            lo, hi = first, last
        else:
            lo, hi = last, first
        for name in blockids[lo:hi + 1]:
            emit(name)

    return " ".join(pieces)
1268
-
1269
-
1270
def blockfromkey(key,isxl):
    """Map a state-dict key to a pair of block identifiers.

    For non-XL models the key is matched against ``time_embed``, ``.out.``,
    ``input_blocks.N``, ``middle_block.N`` and ``output_blocks.N``; the
    resulting index selects one ``BLOCKID`` entry, returned twice.

    For XL models the pair is derived from the key text itself; keys that are
    not merged return ``("Not Merge", "Not Merge")``.
    """
    if not isxl:
        input_count = 12   # input blocks
        mid_count = 1      # middle block
        output_count = 12  # output blocks

        weight_index = -1
        if 'time_embed' in key:
            weight_index = -2  # before input blocks
        elif '.out.' in key:
            weight_index = input_count + mid_count + output_count - 1  # after output blocks
        else:
            in_hit = re.search(r'\.input_blocks\.(\d+)\.', key)
            mid_hit = re.search(r'\.middle_block\.(\d+)\.', key)
            out_hit = re.search(r'\.output_blocks\.(\d+)\.', key)
            if in_hit:
                weight_index = int(in_hit.groups()[0])
            elif mid_hit:
                weight_index = input_count
            elif out_hit:
                weight_index = input_count + mid_count + int(out_hit.groups()[0])
        label = BLOCKID[weight_index + 1]
        return label, label

    skipped = ("Not Merge", "Not Merge")
    if "weight" not in key and "bias" not in key:
        return skipped
    if "label_emb" in key or "time_embed" in key:
        return skipped
    if "conditioner.embedders" in key:
        return "BASE", "BASE"
    if "first_stage_model" in key:
        return "VAE", "BASE"
    if "model.diffusion_model" not in key:
        return skipped
    if "model.diffusion_model.out." in key:
        return "OUT8", "OUT08"

    found = re.findall(r'input|mid|output', key)
    block = found[0].upper().replace("PUT", "") if found else ""
    is_mid = "MID" in block
    digits = re.sub(r"\D", "", key)
    nums = digits[:1] + "0" if is_mid else digits[:2]
    add = re.findall(r"transformer_blocks\.(\d+)\.", key)[0] if "transformer" in key else ""
    detailed = block + nums + add
    coarse = "M00" if is_mid else block + "0" + nums[0]
    return detailed, coarse
1317
-
1318
- ################################################
1319
- ##### Adjust
1320
-
1321
def fineman(fine,isxl):
    """Turn a comma-separated adjust string into the list of adjustment factors.

    Returns ``None`` when ``fine`` contains no comma. Otherwise up to eight
    values are parsed (unparsable or missing ones count as ``0.0``) and
    converted into five scale factors followed by one list: ``value[3] * 0.02``
    plus the result of ``colorcalc`` on values 4..7.
    """
    if fine.find(",") == -1:
        return None

    values = [0.0] * 8
    pieces = [piece.strip() for piece in fine.split(",")]
    for slot, piece in enumerate(pieces[0:8]):
        try:
            values[slot] = float(piece)
        except Exception:
            # Unparsable entries keep their 0.0 default.
            pass

    first, second, third, fourth = values[0], values[1], values[2], values[3]
    return [
        1 - first * 0.01,
        1 + first * 0.02,
        1 - second * 0.01,
        1 + second * 0.02,
        1 - third * 0.01,
        [fourth * 0.02] + colorcalc(values[4:8], isxl),
    ]
1345
-
1346
def colorcalc(cols,isxl):
    """Combine four colour amounts into per-channel offsets.

    Row ``i`` of ``COLSXL`` (when ``isxl``) or ``COLS`` is scaled by
    ``cols[i] * 0.02``; the scaled rows are then summed column by column.
    """
    palette = COLSXL if isxl else COLS
    scaled_rows = []
    for idx, row in enumerate(palette):
        scaled_rows.append([channel * cols[idx] * 0.02 for channel in row])
    return [sum(column) for column in zip(*scaled_rows)]
1350
-
1351
# Coefficient rows read by colorcalc: COLS when isxl is false, COLSXL when it is true.
# Each row has three entries and is scaled by one of the four colour amounts.
- COLS = [[-1,1/3,2/3],[1,1,0],[0,-1,-1],[1,0,1]]
1352
- COLSXL = [[0,0,1],[1,0,0],[-1,-1,0],[-1,1,0]]
1353
-
1354
def weighttoxl(weight):
    """Return a weight list made of entries 0-8 and 12-21 of ``weight`` plus a trailing 0."""
    head = weight[:9]
    tail = weight[12:22]
    return head + tail + [0]
1357
-
1358
# State-dict keys that excluder matches against when "Adjust" is listed in ex_blocks.
- FINETUNES = [
1359
- "model.diffusion_model.input_blocks.0.0.weight",
1360
- "model.diffusion_model.input_blocks.0.0.bias",
1361
- "model.diffusion_model.out.0.weight",
1362
- "model.diffusion_model.out.0.bias",
1363
- "model.diffusion_model.out.2.weight",
1364
- "model.diffusion_model.out.2.bias",
1365
- ]
1366
-
1367
- ################################################
1368
- ##### Include/Exclude
1369
def excluder(block:str,inex:str,ex_blocks:list,ex_elems:list, key:str):
    """Decide whether ``key`` is excluded from the merge.

    Args:
        block: Block identifier of ``key``.
        inex: ``"Include"`` or any other value (treated as exclude mode).
        ex_blocks: Selected block names; may also contain the special entries
            ``"Adjust"``, ``"VAE"`` and ``"print"``.
        ex_elems: Substrings matched against ``key``; empty strings are ignored.
        key: State-dict key being tested.

    Returns:
        ``False`` when nothing is selected. Otherwise the result starts as
        ``True`` in include mode (``False`` in exclude mode) and is toggled
        once for every matching rule.
    """
    if ex_blocks == [] and ex_elems == [""]:
        return False

    including = inex == "Include"
    out = including
    if block in ex_blocks:
        out = not out
    if "Adjust" in ex_blocks and key in FINETUNES:
        out = not out
    for ke in ex_elems:
        if ke != "" and ke in key:
            out = not out
    if "VAE" in ex_blocks and "first_stage_model" in key:
        out = not out
    if "print" in ex_blocks and (out ^ including):
        # Label with the actual mode; ``inex`` is a non-empty string and
        # therefore always truthy.
        print("Include" if including else "Exclude", block, ex_blocks, ex_elems, key)
    return out
1381
-
1382
- ################################################
1383
- ##### Reset Broken CliP IDs
1384
-
1385
def resetclip(theta):
    """Rewrite the CLIP ``position_ids`` entry of ``theta`` to ``[[0, 1, ..., 76]]``.

    Does nothing when the key is absent. When present, the positions whose
    stored value (cast to int64) differs from its index are printed, and the
    entry is replaced by the int64 tensor of shape ``(1, 77)``.
    """
    idkey = "cond_stage_model.transformer.text_model.embeddings.position_ids"
    if idkey not in theta.keys():
        return

    expected = torch.arange(77, dtype=torch.int64).unsqueeze(0)
    stored = theta[idkey].to(torch.int64)

    mismatch = expected.ne(stored)
    broken = [pos for pos in range(77) if mismatch[0][pos]]
    if broken != []:
        print("Clip IDs broken and fixed: ", broken)

    theta[idkey] = expected
1398
-
1399
-
1400
- ################################################
1401
- ##### cache
1402
def cachedealer(start):
    """Switch the web UI checkpoint cache off (``start`` truthy) or restore it.

    On start the current ``shared.opts.sd_checkpoint_cache`` value is saved in
    the module-level ``orig_cache`` and the option is set to 0; otherwise the
    option is set back to ``orig_cache``.
    """
    global orig_cache
    if not start:
        shared.opts.sd_checkpoint_cache = orig_cache
        return
    orig_cache = shared.opts.sd_checkpoint_cache
    shared.opts.sd_checkpoint_cache = 0
1409
-
1410
def clearcache():
    """Drop every cached state dict and release memory.

    The module-level ``modelcache`` is emptied in place, so it keeps its type
    and identity for any code that holds a reference to it. Afterwards the
    garbage collector and ``devices.torch_gc()`` are run.
    """
    modelcache.clear()
    gc.collect()
    devices.torch_gc()
1416
-
1417
def getcachelist():
    """Return the comma-joined ``model_name`` of every ``modelcache`` key that has one."""
    names = [entry.model_name for entry in modelcache.keys() if hasattr(entry, "model_name")]
    return ",".join(names)
1423
-
1424
- ################################################
1425
- ##### print
1426
-
1427
def printstart(model_a,model_b,model_c,base_alpha,base_beta,weights_a,weights_b,mode,useblocks,calcmode,deep,lucks,fine,inex,ex_blocks,ex_elems):
    """Print a one-line-per-setting summary of the merge that is about to start."""
    rows = (
        ("model A \t", model_a),
        ("model B \t", model_b),
        ("model C \t", model_c),
        ("alpha,beta\t", (base_alpha, base_beta)),
        ("weights_alpha\t", weights_a),
        ("weights_beta\t", weights_b),
        ("mode\t\t", mode),
        ("MBW \t\t", useblocks),
        ("CalcMode \t", calcmode),
        ("Elemental \t", deep),
        ("Weights Seed\t", lucks),
        (f"{inex} \t", (ex_blocks, ex_elems)),
        ("Adjust \t", fine),
    )
    for label, value in rows:
        print(f" {label}: {value}")
1441
-
1442
def caster(news,hear):
    """Print ``news`` only when ``hear`` is truthy."""
    if not hear:
        return
    print(news)
1444
-
1445
def casterr(*args,hear=hear):
    """Debug helper: print ``name = repr(value)`` for each argument when ``hear`` is truthy.

    Names are recovered by matching each argument's ``id`` against the local
    variables of the calling frame; values that are not bound to a local name
    there are shown as ``???``.
    """
    if not hear:
        return
    caller_locals = currentframe().f_back.f_locals
    names = {}
    for label, value in caller_locals.items():
        names[id(value)] = label
    lines = [names.get(id(arg), '???') + ' = ' + repr(arg) for arg in args]
    print('\n'.join(lines))