osmunphotography committed
Commit c723d2f · verified · 1 parent: 4a58f8f

Upload 10 files

default_pipeline 2.py ADDED
@@ -0,0 +1,420 @@
import modules.core as core
import os
import gc
import torch
import numpy as np
import modules.path
import modules.virtual_memory as virtual_memory
import comfy.model_management

from comfy.model_base import BaseModel, SDXL, SDXLRefiner
from modules.settings import default_settings
from modules.patch import set_comfy_adm_encoding, set_fooocus_adm_encoding, cfg_patched, patched_model_function
from modules.expansion import FooocusExpansion


xl_base: core.StableDiffusionModel = None
xl_base_hash = ''

xl_refiner: core.StableDiffusionModel = None
xl_refiner_hash = ''

xl_base_patched: core.StableDiffusionModel = None
xl_base_patched_hash = ''

clip_vision: core.StableDiffusionModel = None
clip_vision_hash = ''

controlnet_canny: core.StableDiffusionModel = None
controlnet_canny_hash = ''

controlnet_depth: core.StableDiffusionModel = None
controlnet_depth_hash = ''


@torch.no_grad()
@torch.inference_mode()
def refresh_base_model(name):
    global xl_base, xl_base_hash, xl_base_patched, xl_base_patched_hash

    filename = os.path.abspath(os.path.realpath(os.path.join(modules.path.modelfile_path, name)))
    model_hash = filename

    if xl_base_hash == model_hash:
        return

    if xl_base is not None:
        xl_base = None

    if xl_base_patched is not None:
        xl_base_patched = None

    xl_base = core.load_model(filename)
    if not isinstance(xl_base.unet.model, BaseModel):
        print(f'Model not supported: {name}, using default base model instead.')
        xl_base = None
        xl_base_hash = ''
        refresh_base_model(modules.path.default_base_model_name)
        xl_base_hash = model_hash
        xl_base_patched = xl_base
        xl_base_patched_hash = ''
        return

    if not isinstance(xl_base.unet.model, SDXL):
        print('WARNING: loading non-SDXL base model.')

    xl_base_hash = model_hash
    xl_base_patched = xl_base
    xl_base_patched_hash = ''
    print(f'Base model loaded: {model_hash}')
    return


def is_base_sdxl():
    assert xl_base is not None
    return isinstance(xl_base.unet.model, SDXL)


@torch.no_grad()
@torch.inference_mode()
def refresh_refiner_model(name):
    global xl_refiner, xl_refiner_hash

    filename = os.path.abspath(os.path.realpath(os.path.join(modules.path.modelfile_path, name)))
    model_hash = filename

    if xl_refiner_hash == model_hash:
        return

    if name == 'None':
        xl_refiner = None
        xl_refiner_hash = ''
        print(f'Refiner unloaded.')
        return

    if xl_refiner is not None:
        xl_refiner = None

    xl_refiner = core.load_model(filename)
    if not isinstance(xl_refiner.unet.model, SDXLRefiner):
        print('Model not supported. Fooocus only support SDXL refiner as the refiner.')
        xl_refiner = None
        xl_refiner_hash = ''
        print(f'Refiner unloaded.')
        return

    xl_refiner_hash = model_hash
    print(f'Refiner model loaded: {model_hash}')

    xl_refiner.vae = None
    return


@torch.no_grad()
@torch.inference_mode()
def patch_base(loras, freeu, b1, b2, s1, s2):
    global xl_base, xl_base_patched, xl_base_patched_hash
    if xl_base_patched_hash == str(loras + [freeu, b1, b2, s1, s2]):
        return

    model = xl_base
    for name, weight in loras:
        if name == 'None':
            continue

        if os.path.exists(name):
            filename = name
        else:
            filename = os.path.join(modules.path.lorafile_path, name)

        assert os.path.exists(filename), 'Lora file not found!'

        model = core.load_sd_lora(model, filename, strength_model=weight, strength_clip=weight)
    if freeu:
        xl_base_patched = core.freeu(model, b1, b2, s1, s2)
    else:
        xl_base_patched = model
    xl_base_patched_hash = str(loras + [freeu, b1, b2, s1, s2])
    print(f'LoRAs loaded: {loras}')
    if freeu:
        print(f'FreeU applied: {[b1, b2, s1, s2]}')

    return


@torch.no_grad()
@torch.inference_mode()
def refresh_clip_vision():
    global clip_vision, clip_vision_hash
    if clip_vision_hash == str(clip_vision):
        return

    model_name = modules.path.default_clip_vision_name
    filename = os.path.join(modules.path.clip_vision_path, model_name)
    clip_vision = core.load_clip_vision(filename)

    clip_vision_hash = model_name
    print(f'CLIP Vision model loaded: {clip_vision_hash}')

    return


@torch.no_grad()
@torch.inference_mode()
def refresh_controlnet_canny(name=None):
    global controlnet_canny, controlnet_canny_hash
    if controlnet_canny_hash == str(controlnet_canny):
        return

    model_name = modules.path.default_controlnet_canny_name if name == None else name
    filename = os.path.join(modules.path.controlnet_path, model_name)
    controlnet_canny = core.load_controlnet(filename)

    controlnet_canny_hash = model_name
    print(f'ControlNet model loaded: {controlnet_canny_hash}')

    return


@torch.no_grad()
@torch.inference_mode()
def refresh_controlnet_depth(name=None):
    global controlnet_depth, controlnet_depth_hash
    if controlnet_depth_hash == str(controlnet_depth):
        return

    model_name = modules.path.default_controlnet_depth_name if name == None else name
    filename = os.path.join(modules.path.controlnet_path, model_name)
    controlnet_depth = core.load_controlnet(filename)

    controlnet_depth_hash = model_name
    print(f'ControlNet model loaded: {controlnet_depth_hash}')

    return


@torch.no_grad()
@torch.inference_mode()
def set_clip_skips(base_clip_skip, refiner_clip_skip):
    xl_base_patched.clip.clip_layer(base_clip_skip)
    if xl_refiner is not None:
        xl_refiner.clip.clip_layer(refiner_clip_skip)
    return


@torch.no_grad()
@torch.inference_mode()
def apply_prompt_strength(base_cond, refiner_cond, prompt_strength=1.0):
    if prompt_strength >= 0 and prompt_strength < 1.0:
        base_cond = core.set_conditioning_strength(base_cond, prompt_strength)

    if xl_refiner is not None:
        if prompt_strength >= 0 and prompt_strength < 1.0:
            refiner_cond = core.set_conditioning_strength(refiner_cond, prompt_strength)
    else:
        refiner_cond = None
    return base_cond, refiner_cond


@torch.no_grad()
@torch.inference_mode()
def apply_revision(base_cond, revision=False, revision_strengths=[], clip_vision_outputs=[]):
    if revision:
        set_comfy_adm_encoding()
        for i in range(len(clip_vision_outputs)):
            if revision_strengths[i % 4] != 0:
                base_cond = core.apply_adm(base_cond, clip_vision_outputs[i % 4], revision_strengths[i % 4], 0)
    else:
        set_fooocus_adm_encoding()
    return base_cond


@torch.no_grad()
@torch.inference_mode()
def clip_encode_single(clip, text, verbose=False):
    cached = clip.fcs_cond_cache.get(text, None)
    if cached is not None:
        if verbose:
            print(f'[CLIP Cached] {text}')
        return cached
    tokens = clip.tokenize(text)
    result = clip.encode_from_tokens(tokens, return_pooled=True)
    clip.fcs_cond_cache[text] = result
    if verbose:
        print(f'[CLIP Encoded] {text}')
    return result


@torch.no_grad()
@torch.inference_mode()
def clip_encode(sd, texts, pool_top_k=1):
    if sd is None:
        return None
    if sd.clip is None:
        return None
    if not isinstance(texts, list):
        return None
    if len(texts) == 0:
        return None

    clip = sd.clip
    cond_list = []
    pooled_acc = 0

    for i, text in enumerate(texts):
        cond, pooled = clip_encode_single(clip, text)
        cond_list.append(cond)
        if i < pool_top_k:
            pooled_acc += pooled

    return [[torch.cat(cond_list, dim=1), {"pooled_output": pooled_acc}]]


@torch.no_grad()
@torch.inference_mode()
def clear_sd_cond_cache(sd):
    if sd is None:
        return None
    if sd.clip is None:
        return None
    sd.clip.fcs_cond_cache = {}
    return


@torch.no_grad()
@torch.inference_mode()
def clear_all_caches():
    clear_sd_cond_cache(xl_base_patched)
    clear_sd_cond_cache(xl_refiner)
    gc.collect()
    comfy.model_management.soft_empty_cache()


@torch.no_grad()
@torch.inference_mode()
def refresh_everything(refiner_model_name, base_model_name, loras, freeu, b1, b2, s1, s2):
    refresh_refiner_model(refiner_model_name)
    if xl_refiner is not None:
        virtual_memory.try_move_to_virtual_memory(xl_refiner.unet.model)
        virtual_memory.try_move_to_virtual_memory(xl_refiner.clip.cond_stage_model)

    refresh_base_model(base_model_name)
    virtual_memory.load_from_virtual_memory(xl_base.unet.model)

    patch_base(loras, freeu, b1, b2, s1, s2)
    clear_all_caches()
    return


refresh_everything(
    refiner_model_name=default_settings['refiner_model'],
    base_model_name=default_settings['base_model'],
    loras=[(default_settings['lora_1_model'], default_settings['lora_1_weight']),
           (default_settings['lora_2_model'], default_settings['lora_2_weight']),
           (default_settings['lora_3_model'], default_settings['lora_3_weight']),
           (default_settings['lora_4_model'], default_settings['lora_4_weight']),
           (default_settings['lora_5_model'], default_settings['lora_5_weight'])],
    freeu=default_settings['freeu'],
    b1=default_settings['freeu_b1'],
    b2=default_settings['freeu_b2'],
    s1=default_settings['freeu_s1'],
    s2=default_settings['freeu_s2']
)

expansion = FooocusExpansion()


@torch.no_grad()
@torch.inference_mode()
def patch_all_models():
    assert xl_base is not None
    assert xl_base_patched is not None

    xl_base.unet.model_options['sampler_cfg_function'] = cfg_patched
    xl_base.unet.model_options['model_function_wrapper'] = patched_model_function

    xl_base_patched.unet.model_options['sampler_cfg_function'] = cfg_patched
    xl_base_patched.unet.model_options['model_function_wrapper'] = patched_model_function

    if xl_refiner is not None:
        xl_refiner.unet.model_options['sampler_cfg_function'] = cfg_patched
        xl_refiner.unet.model_options['model_function_wrapper'] = patched_model_function

    return


@torch.no_grad()
@torch.inference_mode()
def process_diffusion(positive_cond, negative_cond, steps, switch, width, height, image_seed, sampler_name, scheduler, cfg, img2img, input_image, start_step,
                      control_lora_canny, canny_edge_low, canny_edge_high, canny_start, canny_stop, canny_strength,
                      control_lora_depth, depth_start, depth_stop, depth_strength, callback, latent=None, denoise=1.0, tiled=False):

    patch_all_models()

    if xl_refiner is not None:
        virtual_memory.try_move_to_virtual_memory(xl_refiner.unet.model)
    virtual_memory.load_from_virtual_memory(xl_base.unet.model)

    if img2img and input_image != None:
        initial_latent = core.encode_vae(vae=xl_base_patched.vae, pixels=input_image)
        force_full_denoise = False
    elif latent is None:
        initial_latent = core.generate_empty_latent(width=width, height=height, batch_size=1)
        force_full_denoise = True
    else:
        initial_latent = latent
        force_full_denoise = False

    positive_conditions = positive_cond[0]
    negative_conditions = negative_cond[0]

    if control_lora_canny and input_image != None:
        edges_image = core.detect_edge(input_image, canny_edge_low, canny_edge_high)
        positive_conditions, negative_conditions = core.apply_controlnet(positive_conditions, negative_conditions,
                                                                         controlnet_canny, edges_image, canny_strength, canny_start, canny_stop)

    if control_lora_depth and input_image != None:
        positive_conditions, negative_conditions = core.apply_controlnet(positive_conditions, negative_conditions,
                                                                         controlnet_depth, input_image, depth_strength, depth_start, depth_stop)

    if xl_refiner is not None and is_base_sdxl():
        positive_conditions_refiner = positive_cond[1]
        negative_conditions_refiner = negative_cond[1]

        sampled_latent = core.ksampler_with_refiner(
            model=xl_base_patched.unet,
            positive=positive_conditions,
            negative=negative_conditions,
            refiner=xl_refiner.unet,
            refiner_positive=positive_conditions_refiner,
            refiner_negative=negative_conditions_refiner,
            refiner_switch_step=switch,
            latent=initial_latent,
            steps=steps, start_step=start_step, last_step=steps,
            disable_noise=False, force_full_denoise=force_full_denoise, denoise=denoise,
            seed=image_seed,
            sampler_name=sampler_name,
            scheduler=scheduler,
            cfg=cfg,
            callback_function=callback
        )
    else:
        sampled_latent = core.ksampler(
            model=xl_base_patched.unet,
            positive=positive_conditions,
            negative=negative_conditions,
            latent=initial_latent,
            steps=steps, start_step=start_step, last_step=steps,
            disable_noise=False, force_full_denoise=force_full_denoise, denoise=denoise,
            seed=image_seed,
            sampler_name=sampler_name,
            scheduler=scheduler,
            cfg=cfg,
            callback_function=callback
        )

    decoded_latent = core.decode_vae(vae=xl_base_patched.vae, latent_image=sampled_latent, tiled=tiled)
    images = core.pytorch_to_numpy(decoded_latent)

    return images
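A minimal usage sketch for the pipeline above. The prompts, seed, resolution, and step counts are placeholders; it assumes the default models were already loaded by the refresh_everything() call at import time and that a None callback is acceptable to the samplers.

# Usage sketch (assumptions: defaults loaded at import time; placeholder prompts/seed/size; callback may be None).
base_positive = clip_encode(xl_base_patched, ['a photo of a cat'])
base_negative = clip_encode(xl_base_patched, ['blurry, low quality'])
refiner_positive = clip_encode(xl_refiner, ['a photo of a cat'])
refiner_negative = clip_encode(xl_refiner, ['blurry, low quality'])

images = process_diffusion(
    positive_cond=[base_positive, refiner_positive],
    negative_cond=[base_negative, refiner_negative],
    steps=30, switch=20, width=1024, height=1024,
    image_seed=12345, sampler_name='dpmpp_2m_sde_gpu', scheduler='karras', cfg=7.0,
    img2img=False, input_image=None, start_step=0,
    control_lora_canny=False, canny_edge_low=0.2, canny_edge_high=0.8,
    canny_start=0.0, canny_stop=0.4, canny_strength=0.8,
    control_lora_depth=False, depth_start=0.0, depth_stop=0.4, depth_strength=0.8,
    callback=None)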
external_hypernetwork.py ADDED
@@ -0,0 +1,121 @@
# https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py

import ldm_patched.modules.utils
import ldm_patched.utils.path_utils
import torch

def load_hypernetwork_patch(path, strength):
    sd = ldm_patched.modules.utils.load_torch_file(path, safe_load=True)
    activation_func = sd.get('activation_func', 'linear')
    is_layer_norm = sd.get('is_layer_norm', False)
    use_dropout = sd.get('use_dropout', False)
    activate_output = sd.get('activate_output', False)
    last_layer_dropout = sd.get('last_layer_dropout', False)

    valid_activation = {
        "linear": torch.nn.Identity,
        "relu": torch.nn.ReLU,
        "leakyrelu": torch.nn.LeakyReLU,
        "elu": torch.nn.ELU,
        "swish": torch.nn.Hardswish,
        "tanh": torch.nn.Tanh,
        "sigmoid": torch.nn.Sigmoid,
        "softsign": torch.nn.Softsign,
        "mish": torch.nn.Mish,
    }

    if activation_func not in valid_activation:
        print("Unsupported Hypernetwork format, if you report it I might implement it.", path, " ", activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout)
        return None

    out = {}

    for d in sd:
        try:
            dim = int(d)
        except:
            continue

        output = []
        for index in [0, 1]:
            attn_weights = sd[dim][index]
            keys = attn_weights.keys()

            linears = filter(lambda a: a.endswith(".weight"), keys)
            linears = list(map(lambda a: a[:-len(".weight")], linears))
            layers = []

            i = 0
            while i < len(linears):
                lin_name = linears[i]
                last_layer = (i == (len(linears) - 1))
                penultimate_layer = (i == (len(linears) - 2))

                lin_weight = attn_weights['{}.weight'.format(lin_name)]
                lin_bias = attn_weights['{}.bias'.format(lin_name)]
                layer = torch.nn.Linear(lin_weight.shape[1], lin_weight.shape[0])
                layer.load_state_dict({"weight": lin_weight, "bias": lin_bias})
                layers.append(layer)
                if activation_func != "linear":
                    if (not last_layer) or (activate_output):
                        layers.append(valid_activation[activation_func]())
                if is_layer_norm:
                    i += 1
                    ln_name = linears[i]
                    ln_weight = attn_weights['{}.weight'.format(ln_name)]
                    ln_bias = attn_weights['{}.bias'.format(ln_name)]
                    ln = torch.nn.LayerNorm(ln_weight.shape[0])
                    ln.load_state_dict({"weight": ln_weight, "bias": ln_bias})
                    layers.append(ln)
                if use_dropout:
                    if (not last_layer) and (not penultimate_layer or last_layer_dropout):
                        layers.append(torch.nn.Dropout(p=0.3))
                i += 1

            output.append(torch.nn.Sequential(*layers))
        out[dim] = torch.nn.ModuleList(output)

    class hypernetwork_patch:
        def __init__(self, hypernet, strength):
            self.hypernet = hypernet
            self.strength = strength
        def __call__(self, q, k, v, extra_options):
            dim = k.shape[-1]
            if dim in self.hypernet:
                hn = self.hypernet[dim]
                k = k + hn[0](k) * self.strength
                v = v + hn[1](v) * self.strength

            return q, k, v

        def to(self, device):
            for d in self.hypernet.keys():
                self.hypernet[d] = self.hypernet[d].to(device)
            return self

    return hypernetwork_patch(out, strength)

class HypernetworkLoader:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model": ("MODEL",),
                              "hypernetwork_name": (ldm_patched.utils.path_utils.get_filename_list("hypernetworks"), ),
                              "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
                              }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "load_hypernetwork"

    CATEGORY = "loaders"

    def load_hypernetwork(self, model, hypernetwork_name, strength):
        hypernetwork_path = ldm_patched.utils.path_utils.get_full_path("hypernetworks", hypernetwork_name)
        model_hypernetwork = model.clone()
        patch = load_hypernetwork_patch(hypernetwork_path, strength)
        if patch is not None:
            model_hypernetwork.set_model_attn1_patch(patch)
            model_hypernetwork.set_model_attn2_patch(patch)
        return (model_hypernetwork,)

NODE_CLASS_MAPPINGS = {
    "HypernetworkLoader": HypernetworkLoader
}
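A brief usage sketch for the loader node above. The model object and the hypernetwork filename are placeholders: `model` must be a clonable MODEL object (a ModelPatcher in ComfyUI-style code) and the file must exist in the configured hypernetworks directory.

# Usage sketch (assumption: `model` is a loaded MODEL and
# 'example_hypernetwork.pt' exists in the hypernetworks folder).
loader = HypernetworkLoader()
(patched_model,) = loader.load_hypernetwork(model, 'example_hypernetwork.pt', strength=0.8)
# patched_model now applies the hypernetwork to both attn1 and attn2 layers.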
external_mask.py ADDED
@@ -0,0 +1,365 @@
# https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py

import numpy as np
import scipy.ndimage
import torch
import ldm_patched.modules.utils

from ldm_patched.contrib.external import MAX_RESOLUTION

def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):
    source = source.to(destination.device)
    if resize_source:
        source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear")

    source = ldm_patched.modules.utils.repeat_to_batch_size(source, destination.shape[0])

    x = max(-source.shape[3] * multiplier, min(x, destination.shape[3] * multiplier))
    y = max(-source.shape[2] * multiplier, min(y, destination.shape[2] * multiplier))

    left, top = (x // multiplier, y // multiplier)
    right, bottom = (left + source.shape[3], top + source.shape[2],)

    if mask is None:
        mask = torch.ones_like(source)
    else:
        mask = mask.to(destination.device, copy=True)
        mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear")
        mask = ldm_patched.modules.utils.repeat_to_batch_size(mask, source.shape[0])

    # calculate the bounds of the source that will be overlapping the destination
    # this prevents the source trying to overwrite latent pixels that are out of bounds
    # of the destination
    visible_width, visible_height = (destination.shape[3] - left + min(0, x), destination.shape[2] - top + min(0, y),)

    mask = mask[:, :, :visible_height, :visible_width]
    inverse_mask = torch.ones_like(mask) - mask

    source_portion = mask * source[:, :, :visible_height, :visible_width]
    destination_portion = inverse_mask * destination[:, :, top:bottom, left:right]

    destination[:, :, top:bottom, left:right] = source_portion + destination_portion
    return destination

class LatentCompositeMasked:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "destination": ("LATENT",),
                "source": ("LATENT",),
                "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
                "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
                "resize_source": ("BOOLEAN", {"default": False}),
            },
            "optional": {
                "mask": ("MASK",),
            }
        }
    RETURN_TYPES = ("LATENT",)
    FUNCTION = "composite"

    CATEGORY = "latent"

    def composite(self, destination, source, x, y, resize_source, mask = None):
        output = destination.copy()
        destination = destination["samples"].clone()
        source = source["samples"]
        output["samples"] = composite(destination, source, x, y, mask, 8, resize_source)
        return (output,)

class ImageCompositeMasked:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "destination": ("IMAGE",),
                "source": ("IMAGE",),
                "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "resize_source": ("BOOLEAN", {"default": False}),
            },
            "optional": {
                "mask": ("MASK",),
            }
        }
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "composite"

    CATEGORY = "image"

    def composite(self, destination, source, x, y, resize_source, mask = None):
        destination = destination.clone().movedim(-1, 1)
        output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1)
        return (output,)

class MaskToImage:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "mask": ("MASK",),
            }
        }

    CATEGORY = "mask"

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "mask_to_image"

    def mask_to_image(self, mask):
        result = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
        return (result,)

class ImageToMask:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "channel": (["red", "green", "blue", "alpha"],),
            }
        }

    CATEGORY = "mask"

    RETURN_TYPES = ("MASK",)
    FUNCTION = "image_to_mask"

    def image_to_mask(self, image, channel):
        channels = ["red", "green", "blue", "alpha"]
        mask = image[:, :, :, channels.index(channel)]
        return (mask,)

class ImageColorToMask:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}),
            }
        }

    CATEGORY = "mask"

    RETURN_TYPES = ("MASK",)
    FUNCTION = "image_to_mask"

    def image_to_mask(self, image, color):
        temp = (torch.clamp(image, 0, 1.0) * 255.0).round().to(torch.int)
        temp = torch.bitwise_left_shift(temp[:,:,:,0], 16) + torch.bitwise_left_shift(temp[:,:,:,1], 8) + temp[:,:,:,2]
        mask = torch.where(temp == color, 255, 0).float()
        return (mask,)

class SolidMask:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
                "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
            }
        }

    CATEGORY = "mask"

    RETURN_TYPES = ("MASK",)

    FUNCTION = "solid"

    def solid(self, value, width, height):
        out = torch.full((1, height, width), value, dtype=torch.float32, device="cpu")
        return (out,)

class InvertMask:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "mask": ("MASK",),
            }
        }

    CATEGORY = "mask"

    RETURN_TYPES = ("MASK",)

    FUNCTION = "invert"

    def invert(self, mask):
        out = 1.0 - mask
        return (out,)

class CropMask:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "mask": ("MASK",),
                "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
                "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
            }
        }

    CATEGORY = "mask"

    RETURN_TYPES = ("MASK",)

    FUNCTION = "crop"

    def crop(self, mask, x, y, width, height):
        mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1]))
        out = mask[:, y:y + height, x:x + width]
        return (out,)

class MaskComposite:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "destination": ("MASK",),
                "source": ("MASK",),
                "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "operation": (["multiply", "add", "subtract", "and", "or", "xor"],),
            }
        }

    CATEGORY = "mask"

    RETURN_TYPES = ("MASK",)

    FUNCTION = "combine"

    def combine(self, destination, source, x, y, operation):
        output = destination.reshape((-1, destination.shape[-2], destination.shape[-1])).clone()
        source = source.reshape((-1, source.shape[-2], source.shape[-1]))

        left, top = (x, y,)
        right, bottom = (min(left + source.shape[-1], destination.shape[-1]), min(top + source.shape[-2], destination.shape[-2]))
        visible_width, visible_height = (right - left, bottom - top,)

        source_portion = source[:, :visible_height, :visible_width]
        destination_portion = destination[:, top:bottom, left:right]

        if operation == "multiply":
            output[:, top:bottom, left:right] = destination_portion * source_portion
        elif operation == "add":
            output[:, top:bottom, left:right] = destination_portion + source_portion
        elif operation == "subtract":
            output[:, top:bottom, left:right] = destination_portion - source_portion
        elif operation == "and":
            output[:, top:bottom, left:right] = torch.bitwise_and(destination_portion.round().bool(), source_portion.round().bool()).float()
        elif operation == "or":
            output[:, top:bottom, left:right] = torch.bitwise_or(destination_portion.round().bool(), source_portion.round().bool()).float()
        elif operation == "xor":
            output[:, top:bottom, left:right] = torch.bitwise_xor(destination_portion.round().bool(), source_portion.round().bool()).float()

        output = torch.clamp(output, 0.0, 1.0)

        return (output,)

class FeatherMask:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "mask": ("MASK",),
                "left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
            }
        }

    CATEGORY = "mask"

    RETURN_TYPES = ("MASK",)

    FUNCTION = "feather"

    def feather(self, mask, left, top, right, bottom):
        output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone()

        left = min(left, output.shape[-1])
        right = min(right, output.shape[-1])
        top = min(top, output.shape[-2])
        bottom = min(bottom, output.shape[-2])

        for x in range(left):
            feather_rate = (x + 1.0) / left
            output[:, :, x] *= feather_rate

        for x in range(right):
            feather_rate = (x + 1) / right
            output[:, :, -x] *= feather_rate

        for y in range(top):
            feather_rate = (y + 1) / top
            output[:, y, :] *= feather_rate

        for y in range(bottom):
            feather_rate = (y + 1) / bottom
            output[:, -y, :] *= feather_rate

        return (output,)

class GrowMask:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "mask": ("MASK",),
                "expand": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION, "step": 1}),
                "tapered_corners": ("BOOLEAN", {"default": True}),
            },
        }

    CATEGORY = "mask"

    RETURN_TYPES = ("MASK",)

    FUNCTION = "expand_mask"

    def expand_mask(self, mask, expand, tapered_corners):
        c = 0 if tapered_corners else 1
        kernel = np.array([[c, 1, c],
                           [1, 1, 1],
                           [c, 1, c]])
        mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1]))
        out = []
        for m in mask:
            output = m.numpy()
            for _ in range(abs(expand)):
                if expand < 0:
                    output = scipy.ndimage.grey_erosion(output, footprint=kernel)
                else:
                    output = scipy.ndimage.grey_dilation(output, footprint=kernel)
            output = torch.from_numpy(output)
            out.append(output)
        return (torch.stack(out, dim=0),)



NODE_CLASS_MAPPINGS = {
    "LatentCompositeMasked": LatentCompositeMasked,
    "ImageCompositeMasked": ImageCompositeMasked,
    "MaskToImage": MaskToImage,
    "ImageToMask": ImageToMask,
    "ImageColorToMask": ImageColorToMask,
    "SolidMask": SolidMask,
    "InvertMask": InvertMask,
    "CropMask": CropMask,
    "MaskComposite": MaskComposite,
    "FeatherMask": FeatherMask,
    "GrowMask": GrowMask,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "ImageToMask": "Convert Image to Mask",
    "MaskToImage": "Convert Mask to Image",
}
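A short sketch chaining a few of the mask nodes above; the sizes and amounts are arbitrary placeholders.

# Usage sketch: build a solid mask, grow it by 8 pixels, feather the edges, and preview it as an image.
(mask,) = SolidMask().solid(value=1.0, width=512, height=512)
(grown,) = GrowMask().expand_mask(mask, expand=8, tapered_corners=True)
(feathered,) = FeatherMask().feather(grown, left=16, top=16, right=16, bottom=16)
(preview,) = MaskToImage().mask_to_image(feathered)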
external_post_processing.py ADDED
@@ -0,0 +1,278 @@
# https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
import math

import ldm_patched.modules.utils


class Blend:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image1": ("IMAGE",),
                "image2": ("IMAGE",),
                "blend_factor": ("FLOAT", {
                    "default": 0.5,
                    "min": 0.0,
                    "max": 1.0,
                    "step": 0.01
                }),
                "blend_mode": (["normal", "multiply", "screen", "overlay", "soft_light", "difference"],),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "blend_images"

    CATEGORY = "image/postprocessing"

    def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str):
        image2 = image2.to(image1.device)
        if image1.shape != image2.shape:
            image2 = image2.permute(0, 3, 1, 2)
            image2 = ldm_patched.modules.utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic', crop='center')
            image2 = image2.permute(0, 2, 3, 1)

        blended_image = self.blend_mode(image1, image2, blend_mode)
        blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor
        blended_image = torch.clamp(blended_image, 0, 1)
        return (blended_image,)

    def blend_mode(self, img1, img2, mode):
        if mode == "normal":
            return img2
        elif mode == "multiply":
            return img1 * img2
        elif mode == "screen":
            return 1 - (1 - img1) * (1 - img2)
        elif mode == "overlay":
            return torch.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2))
        elif mode == "soft_light":
            return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (self.g(img1) - img1))
        elif mode == "difference":
            return img1 - img2
        else:
            raise ValueError(f"Unsupported blend mode: {mode}")

    def g(self, x):
        return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))

def gaussian_kernel(kernel_size: int, sigma: float, device=None):
    x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size, device=device), torch.linspace(-1, 1, kernel_size, device=device), indexing="ij")
    d = torch.sqrt(x * x + y * y)
    g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
    return g / g.sum()

class Blur:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "blur_radius": ("INT", {
                    "default": 1,
                    "min": 1,
                    "max": 31,
                    "step": 1
                }),
                "sigma": ("FLOAT", {
                    "default": 1.0,
                    "min": 0.1,
                    "max": 10.0,
                    "step": 0.1
                }),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "blur"

    CATEGORY = "image/postprocessing"

    def blur(self, image: torch.Tensor, blur_radius: int, sigma: float):
        if blur_radius == 0:
            return (image,)

        batch_size, height, width, channels = image.shape

        kernel_size = blur_radius * 2 + 1
        kernel = gaussian_kernel(kernel_size, sigma, device=image.device).repeat(channels, 1, 1).unsqueeze(1)

        image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
        padded_image = F.pad(image, (blur_radius,blur_radius,blur_radius,blur_radius), 'reflect')
        blurred = F.conv2d(padded_image, kernel, padding=kernel_size // 2, groups=channels)[:,:,blur_radius:-blur_radius, blur_radius:-blur_radius]
        blurred = blurred.permute(0, 2, 3, 1)

        return (blurred,)

class Quantize:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "colors": ("INT", {
                    "default": 256,
                    "min": 1,
                    "max": 256,
                    "step": 1
                }),
                "dither": (["none", "floyd-steinberg", "bayer-2", "bayer-4", "bayer-8", "bayer-16"],),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "quantize"

    CATEGORY = "image/postprocessing"

    def bayer(im, pal_im, order):
        def normalized_bayer_matrix(n):
            if n == 0:
                return np.zeros((1, 1), "float32")
            else:
                q = 4 ** n
                m = q * normalized_bayer_matrix(n - 1)
                return np.bmat(((m-1.5, m+0.5), (m+1.5, m-0.5))) / q

        num_colors = len(pal_im.getpalette()) // 3
        spread = 2 * 256 / num_colors
        bayer_n = int(math.log2(order))
        bayer_matrix = torch.from_numpy(spread * normalized_bayer_matrix(bayer_n) + 0.5)

        result = torch.from_numpy(np.array(im).astype(np.float32))
        tw = math.ceil(result.shape[0] / bayer_matrix.shape[0])
        th = math.ceil(result.shape[1] / bayer_matrix.shape[1])
        tiled_matrix = bayer_matrix.tile(tw, th).unsqueeze(-1)
        result.add_(tiled_matrix[:result.shape[0],:result.shape[1]]).clamp_(0, 255)
        result = result.to(dtype=torch.uint8)

        im = Image.fromarray(result.cpu().numpy())
        im = im.quantize(palette=pal_im, dither=Image.Dither.NONE)
        return im

    def quantize(self, image: torch.Tensor, colors: int, dither: str):
        batch_size, height, width, _ = image.shape
        result = torch.zeros_like(image)

        for b in range(batch_size):
            im = Image.fromarray((image[b] * 255).to(torch.uint8).numpy(), mode='RGB')

            pal_im = im.quantize(colors=colors) # Required as described in https://github.com/python-pillow/Pillow/issues/5836

            if dither == "none":
                quantized_image = im.quantize(palette=pal_im, dither=Image.Dither.NONE)
            elif dither == "floyd-steinberg":
                quantized_image = im.quantize(palette=pal_im, dither=Image.Dither.FLOYDSTEINBERG)
            elif dither.startswith("bayer"):
                order = int(dither.split('-')[-1])
                quantized_image = Quantize.bayer(im, pal_im, order)

            quantized_array = torch.tensor(np.array(quantized_image.convert("RGB"))).float() / 255
            result[b] = quantized_array

        return (result,)

class Sharpen:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "sharpen_radius": ("INT", {
                    "default": 1,
                    "min": 1,
                    "max": 31,
                    "step": 1
                }),
                "sigma": ("FLOAT", {
                    "default": 1.0,
                    "min": 0.1,
                    "max": 10.0,
                    "step": 0.1
                }),
                "alpha": ("FLOAT", {
                    "default": 1.0,
                    "min": 0.0,
                    "max": 5.0,
                    "step": 0.1
                }),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "sharpen"

    CATEGORY = "image/postprocessing"

    def sharpen(self, image: torch.Tensor, sharpen_radius: int, sigma:float, alpha: float):
        if sharpen_radius == 0:
            return (image,)

        batch_size, height, width, channels = image.shape

        kernel_size = sharpen_radius * 2 + 1
        kernel = gaussian_kernel(kernel_size, sigma, device=image.device) * -(alpha*10)
        center = kernel_size // 2
        kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
        kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)

        tensor_image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
        tensor_image = F.pad(tensor_image, (sharpen_radius,sharpen_radius,sharpen_radius,sharpen_radius), 'reflect')
        sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius]
        sharpened = sharpened.permute(0, 2, 3, 1)

        result = torch.clamp(sharpened, 0, 1)

        return (result,)

class ImageScaleToTotalPixels:
    upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
    crop_methods = ["disabled", "center"]

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "image": ("IMAGE",), "upscale_method": (s.upscale_methods,),
                              "megapixels": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 16.0, "step": 0.01}),
                             }}
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "upscale"

    CATEGORY = "image/upscaling"

    def upscale(self, image, upscale_method, megapixels):
        samples = image.movedim(-1,1)
        total = int(megapixels * 1024 * 1024)

        scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
        width = round(samples.shape[3] * scale_by)
        height = round(samples.shape[2] * scale_by)

        s = ldm_patched.modules.utils.common_upscale(samples, width, height, upscale_method, "disabled")
        s = s.movedim(1,-1)
        return (s,)

NODE_CLASS_MAPPINGS = {
    "ImageBlend": Blend,
    "ImageBlur": Blur,
    "ImageQuantize": Quantize,
    "ImageSharpen": Sharpen,
    "ImageScaleToTotalPixels": ImageScaleToTotalPixels,
}
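A short sketch for the post-processing nodes above, run on a random (B, H, W, C) tensor in [0, 1]; the shape and parameters are placeholders.

# Usage sketch: blur, then sharpen, then rescale to roughly one megapixel.
import torch

img = torch.rand(1, 768, 768, 3)  # placeholder (B, H, W, C) image, values in [0, 1]
(blurred,) = Blur().blur(img, blur_radius=3, sigma=1.5)
(sharpened,) = Sharpen().sharpen(blurred, sharpen_radius=1, sigma=1.0, alpha=1.0)
(resized,) = ImageScaleToTotalPixels().upscale(sharpened, upscale_method='bicubic', megapixels=1.0)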
external_sdupscale.py ADDED
@@ -0,0 +1,49 @@
# https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py

import torch
import ldm_patched.contrib.external
import ldm_patched.modules.utils

class SD_4XUpscale_Conditioning:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "images": ("IMAGE",),
                              "positive": ("CONDITIONING",),
                              "negative": ("CONDITIONING",),
                              "scale_ratio": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                              "noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
                              }}
    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
    RETURN_NAMES = ("positive", "negative", "latent")

    FUNCTION = "encode"

    CATEGORY = "conditioning/upscale_diffusion"

    def encode(self, images, positive, negative, scale_ratio, noise_augmentation):
        width = max(1, round(images.shape[-2] * scale_ratio))
        height = max(1, round(images.shape[-3] * scale_ratio))

        pixels = ldm_patched.modules.utils.common_upscale((images.movedim(-1,1) * 2.0) - 1.0, width // 4, height // 4, "bilinear", "center")

        out_cp = []
        out_cn = []

        for t in positive:
            n = [t[0], t[1].copy()]
            n[1]['concat_image'] = pixels
            n[1]['noise_augmentation'] = noise_augmentation
            out_cp.append(n)

        for t in negative:
            n = [t[0], t[1].copy()]
            n[1]['concat_image'] = pixels
            n[1]['noise_augmentation'] = noise_augmentation
            out_cn.append(n)

        latent = torch.zeros([images.shape[0], 4, height // 4, width // 4])
        return (out_cp, out_cn, {"samples":latent})

NODE_CLASS_MAPPINGS = {
    "SD_4XUpscale_Conditioning": SD_4XUpscale_Conditioning,
}
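A hedged sketch of how the conditioning node above might be wired; `images`, `positive`, and `negative` are placeholders (a (B, H, W, C) image tensor and standard CONDITIONING lists, e.g. from a CLIP text encode).

# Usage sketch (assumption: `positive`/`negative` are CONDITIONING lists, `images` is (B, H, W, C)).
node = SD_4XUpscale_Conditioning()
up_pos, up_neg, latent = node.encode(images, positive, negative,
                                     scale_ratio=4.0, noise_augmentation=0.0)
# latent["samples"] is an empty latent sized for the 4x upscale diffusion pass.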
external_upscale_model.py ADDED
@@ -0,0 +1,68 @@
# https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py

import os
from ldm_patched.pfn import model_loading
from ldm_patched.modules import model_management
import torch
import ldm_patched.modules.utils
import ldm_patched.utils.path_utils

class UpscaleModelLoader:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model_name": (ldm_patched.utils.path_utils.get_filename_list("upscale_models"), ),
                             }}
    RETURN_TYPES = ("UPSCALE_MODEL",)
    FUNCTION = "load_model"

    CATEGORY = "loaders"

    def load_model(self, model_name):
        model_path = ldm_patched.utils.path_utils.get_full_path("upscale_models", model_name)
        sd = ldm_patched.modules.utils.load_torch_file(model_path, safe_load=True)
        if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
            sd = ldm_patched.modules.utils.state_dict_prefix_replace(sd, {"module.":""})
        out = model_loading.load_state_dict(sd).eval()
        return (out, )


class ImageUpscaleWithModel:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "upscale_model": ("UPSCALE_MODEL",),
                              "image": ("IMAGE",),
                              }}
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "upscale"

    CATEGORY = "image/upscaling"

    def upscale(self, upscale_model, image):
        device = model_management.get_torch_device()
        upscale_model.to(device)
        in_img = image.movedim(-1,-3).to(device)
        free_memory = model_management.get_free_memory(device)

        tile = 512
        overlap = 32

        oom = True
        while oom:
            try:
                steps = in_img.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
                pbar = ldm_patched.modules.utils.ProgressBar(steps)
                s = ldm_patched.modules.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
                oom = False
            except model_management.OOM_EXCEPTION as e:
                tile //= 2
                if tile < 128:
                    raise e

        upscale_model.cpu()
        s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0)
        return (s,)

NODE_CLASS_MAPPINGS = {
    "UpscaleModelLoader": UpscaleModelLoader,
    "ImageUpscaleWithModel": ImageUpscaleWithModel
}
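A short usage sketch for the two nodes above; the model filename is a placeholder and must exist in the configured upscale_models folder, and `image` is a (B, H, W, C) tensor.

# Usage sketch (assumption: '4x_example_esrgan.pth' is present in the upscale_models directory).
(esrgan,) = UpscaleModelLoader().load_model('4x_example_esrgan.pth')
(upscaled,) = ImageUpscaleWithModel().upscale(esrgan, image)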
html.py ADDED
@@ -0,0 +1,111 @@
css = '''
.loader-container {
  display: flex; /* Use flex to align items horizontally */
  align-items: center; /* Center items vertically within the container */
  white-space: nowrap; /* Prevent line breaks within the container */
}

.loader {
  border: 8px solid #f3f3f3; /* Light grey */
  border-top: 8px solid #3498db; /* Blue */
  border-radius: 50%;
  width: 30px;
  height: 30px;
  animation: spin 2s linear infinite;
}

@keyframes spin {
  0% { transform: rotate(0deg); }
  100% { transform: rotate(360deg); }
}

/* Style the progress bar */
progress {
  appearance: none; /* Remove default styling */
  height: 20px; /* Set the height of the progress bar */
  border-radius: 5px; /* Round the corners of the progress bar */
  background-color: #f3f3f3; /* Light grey background */
  width: 100%;
}

/* Style the progress bar container */
.progress-container {
  margin-left: 20px;
  margin-right: 20px;
  flex-grow: 1; /* Allow the progress container to take up remaining space */
}

/* Set the color of the progress bar fill */
progress::-webkit-progress-value {
  background-color: #3498db; /* Blue color for the fill */
}

progress::-moz-progress-bar {
  background-color: #3498db; /* Blue color for the fill in Firefox */
}

/* Style the text on the progress bar */
progress::after {
  content: attr(value '%'); /* Display the progress value followed by '%' */
  position: absolute;
  top: 50%;
  left: 50%;
  transform: translate(-50%, -50%);
  color: white; /* Set text color */
  font-size: 14px; /* Set font size */
}

/* Style other texts */
.loader-container > span {
  margin-left: 5px; /* Add spacing between the progress bar and the text */
}

.progress-bar > .generating {
  display: none !important;
}

.progress-bar{
  height: 30px !important;
}

.type_row{
  height: 96px !important;
}

.type_small_row{
  height: 40px !important;
}

.scroll-hide{
  resize: none !important;
}

.refresh_button{
  border: none !important;
  background: none !important;
  font-size: none !important;
  box-shadow: none !important;
}

.advanced_check_row{
  width: 250px !important;
}

.min_check{
  min-width: min(1px, 100%) !important;
}

'''
progress_html = '''
<div class="loader-container">
  <div class="loader"></div>
  <div class="progress-container">
    <progress value="*number*" max="100"></progress>
  </div>
  <span>*text*</span>
</div>
'''


def make_progress_html(number, text):
    return progress_html.replace('*number*', str(number)).replace('*text*', text)
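A one-line usage sketch for the helper above; the status text is a placeholder.

# Usage sketch: render the spinner and a progress bar at 42% with a status message.
snippet = make_progress_html(42, 'Sampling step 10/24')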
paths-example.json.txt ADDED
@@ -0,0 +1,14 @@
{
    "path_checkpoints": "../models/checkpoints/",
    "path_loras": "../models/loras/",
    "path_embeddings": "../models/embeddings/",
    "path_clip_vision": "../models/clip_vision/",
    "path_controlnet": "../models/controlnet/",
    "path_vae_approx": "../models/vae_approx/",
    "path_fooocus_expansion": "../models/prompt_expansion/fooocus_expansion/",
    "path_upscale_models": "../models/upscale_models/",
    "path_inpaint_models": "../models/inpaint/",
    "path_styles": "../sdxl_styles/",
    "path_wildcards": "../wildcards/",
    "path_outputs": "../outputs/"
}
requirements_versions 2.txt ADDED
@@ -0,0 +1,17 @@
torchsde==0.2.5
einops==0.4.1
transformers==4.30.2
safetensors==0.3.1
accelerate==0.21.0
aiohttp==3.8.5
pyyaml==6.0
Pillow==9.4.0
scipy==1.9.3
tqdm==4.64.1
psutil==5.9.5
numpy==1.23.5
pytorch_lightning==1.9.4
omegaconf==2.2.3
gradio==3.41.2
pygit2==1.12.2
fastapi==0.94.0
settings-example.json.txt ADDED
@@ -0,0 +1,65 @@
{
    "advanced_mode": true,
    "image_number": 1,
    "save_metadata_json": true,
    "save_metadata_image": true,
    "output_format": "jpg",
    "seed_random": false,
    "same_seed_for_all": false,
    "seed": 0,
    "styles": ["Default (Slightly Cinematic)"],
    "prompt_expansion": true,
    "prompt": "",
    "negative_prompt": "",
    "performance": "Speed",
    "custom_steps": 24,
    "custom_switch": 0.75,
    "img2img_mode": false,
    "img2img_start_step": 0.06,
    "img2img_denoise": 0.94,
    "img2img_scale": 1.0,
    "control_lora_canny": false,
    "canny_edge_low": 0.2,
    "canny_edge_high": 0.8,
    "canny_start": 0.0,
    "canny_stop": 0.4,
    "canny_strength": 0.8,
    "canny_model": "control-lora-canny-rank128.safetensors",
    "control_lora_depth": false,
    "depth_start": 0.0,
    "depth_stop": 0.4,
    "depth_strength": 0.8,
    "depth_model": "control-lora-depth-rank128.safetensors",
    "keep_input_names": false,
    "revision": false,
    "positive_prompt_strength": 1.0,
    "negative_prompt_strength": 1.0,
    "revision_strength_1": 1.0,
    "revision_strength_2": 1.0,
    "revision_strength_3": 1.0,
    "revision_strength_4": 1.0,
    "resolution": "1152×896 (9:7)",
    "sampler": "dpmpp_2m_sde_gpu",
    "scheduler": "karras",
    "cfg": 7.0,
    "base_clip_skip": -2,
    "refiner_clip_skip": -2,
    "sharpness": 2.0,
    "base_model": "sd_xl_base_1.0_0.9vae.safetensors",
    "refiner_model": "sd_xl_refiner_1.0_0.9vae.safetensors",
    "lora_1_model": "sd_xl_offset_example-lora_1.0.safetensors",
    "lora_1_weight": 0.5,
    "lora_2_model": "None",
    "lora_2_weight": 0.5,
    "lora_3_model": "None",
    "lora_3_weight": 0.5,
    "lora_4_model": "None",
    "lora_4_weight": 0.5,
    "lora_5_model": "None",
    "lora_5_weight": 0.5,
    "freeu": false,
    "freeu_b1": 1.01,
    "freeu_b2": 1.02,
    "freeu_s1": 0.99,
    "freeu_s2": 0.95
}