MaykaGR commited on
Commit
bebd44f
·
verified ·
1 Parent(s): 91820dd

Upload 220 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +2 -0
  2. comfy_execution/caching.py +318 -0
  3. comfy_execution/graph.py +270 -0
  4. comfy_execution/graph_utils.py +139 -0
  5. comfy_execution/validation.py +39 -0
  6. comfy_extras/chainner_models/model_loading.py +6 -0
  7. comfy_extras/nodes_advanced_samplers.py +111 -0
  8. comfy_extras/nodes_align_your_steps.py +53 -0
  9. comfy_extras/nodes_attention_multiply.py +120 -0
  10. comfy_extras/nodes_audio.py +251 -0
  11. comfy_extras/nodes_canny.py +25 -0
  12. comfy_extras/nodes_clip_sdxl.py +54 -0
  13. comfy_extras/nodes_compositing.py +214 -0
  14. comfy_extras/nodes_cond.py +25 -0
  15. comfy_extras/nodes_controlnet.py +60 -0
  16. comfy_extras/nodes_cosmos.py +82 -0
  17. comfy_extras/nodes_custom_sampler.py +744 -0
  18. comfy_extras/nodes_differential_diffusion.py +42 -0
  19. comfy_extras/nodes_flux.py +63 -0
  20. comfy_extras/nodes_freelunch.py +113 -0
  21. comfy_extras/nodes_gits.py +369 -0
  22. comfy_extras/nodes_hooks.py +745 -0
  23. comfy_extras/nodes_hunyuan.py +44 -0
  24. comfy_extras/nodes_hypernetwork.py +120 -0
  25. comfy_extras/nodes_hypertile.py +81 -0
  26. comfy_extras/nodes_images.py +195 -0
  27. comfy_extras/nodes_ip2p.py +45 -0
  28. comfy_extras/nodes_latent.py +288 -0
  29. comfy_extras/nodes_load_3d.py +154 -0
  30. comfy_extras/nodes_lora_extract.py +119 -0
  31. comfy_extras/nodes_lt.py +184 -0
  32. comfy_extras/nodes_mahiro.py +41 -0
  33. comfy_extras/nodes_mask.py +382 -0
  34. comfy_extras/nodes_mochi.py +23 -0
  35. comfy_extras/nodes_model_advanced.py +306 -0
  36. comfy_extras/nodes_model_downscale.py +53 -0
  37. comfy_extras/nodes_model_merging.py +371 -0
  38. comfy_extras/nodes_model_merging_model_specific.py +209 -0
  39. comfy_extras/nodes_morphology.py +49 -0
  40. comfy_extras/nodes_pag.py +56 -0
  41. comfy_extras/nodes_perpneg.py +129 -0
  42. comfy_extras/nodes_photomaker.py +188 -0
  43. comfy_extras/nodes_pixart.py +24 -0
  44. comfy_extras/nodes_post_processing.py +279 -0
  45. comfy_extras/nodes_rebatch.py +138 -0
  46. comfy_extras/nodes_sag.py +181 -0
  47. comfy_extras/nodes_sd3.py +138 -0
  48. comfy_extras/nodes_sdupscale.py +46 -0
  49. comfy_extras/nodes_slg.py +84 -0
  50. comfy_extras/nodes_stable3d.py +143 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ web/assets/images/sad_girl.png filter=lfs diff=lfs merge=lfs -text
37
+ web/fonts/materialdesignicons-webfont.woff2 filter=lfs diff=lfs merge=lfs -text
comfy_execution/caching.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ from typing import Sequence, Mapping, Dict
3
+ from comfy_execution.graph import DynamicPrompt
4
+
5
+ import nodes
6
+
7
+ from comfy_execution.graph_utils import is_link
8
+
9
+ NODE_CLASS_CONTAINS_UNIQUE_ID: Dict[str, bool] = {}
10
+
11
+
12
+ def include_unique_id_in_input(class_type: str) -> bool:
13
+ if class_type in NODE_CLASS_CONTAINS_UNIQUE_ID:
14
+ return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
15
+ class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
16
+ NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] = "UNIQUE_ID" in class_def.INPUT_TYPES().get("hidden", {}).values()
17
+ return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
18
+
19
+ class CacheKeySet:
20
+ def __init__(self, dynprompt, node_ids, is_changed_cache):
21
+ self.keys = {}
22
+ self.subcache_keys = {}
23
+
24
+ def add_keys(self, node_ids):
25
+ raise NotImplementedError()
26
+
27
+ def all_node_ids(self):
28
+ return set(self.keys.keys())
29
+
30
+ def get_used_keys(self):
31
+ return self.keys.values()
32
+
33
+ def get_used_subcache_keys(self):
34
+ return self.subcache_keys.values()
35
+
36
+ def get_data_key(self, node_id):
37
+ return self.keys.get(node_id, None)
38
+
39
+ def get_subcache_key(self, node_id):
40
+ return self.subcache_keys.get(node_id, None)
41
+
42
+ class Unhashable:
43
+ def __init__(self):
44
+ self.value = float("NaN")
45
+
46
+ def to_hashable(obj):
47
+ # So that we don't infinitely recurse since frozenset and tuples
48
+ # are Sequences.
49
+ if isinstance(obj, (int, float, str, bool, type(None))):
50
+ return obj
51
+ elif isinstance(obj, Mapping):
52
+ return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())])
53
+ elif isinstance(obj, Sequence):
54
+ return frozenset(zip(itertools.count(), [to_hashable(i) for i in obj]))
55
+ else:
56
+ # TODO - Support other objects like tensors?
57
+ return Unhashable()
58
+
59
+ class CacheKeySetID(CacheKeySet):
60
+ def __init__(self, dynprompt, node_ids, is_changed_cache):
61
+ super().__init__(dynprompt, node_ids, is_changed_cache)
62
+ self.dynprompt = dynprompt
63
+ self.add_keys(node_ids)
64
+
65
+ def add_keys(self, node_ids):
66
+ for node_id in node_ids:
67
+ if node_id in self.keys:
68
+ continue
69
+ if not self.dynprompt.has_node(node_id):
70
+ continue
71
+ node = self.dynprompt.get_node(node_id)
72
+ self.keys[node_id] = (node_id, node["class_type"])
73
+ self.subcache_keys[node_id] = (node_id, node["class_type"])
74
+
75
+ class CacheKeySetInputSignature(CacheKeySet):
76
+ def __init__(self, dynprompt, node_ids, is_changed_cache):
77
+ super().__init__(dynprompt, node_ids, is_changed_cache)
78
+ self.dynprompt = dynprompt
79
+ self.is_changed_cache = is_changed_cache
80
+ self.add_keys(node_ids)
81
+
82
+ def include_node_id_in_input(self) -> bool:
83
+ return False
84
+
85
+ def add_keys(self, node_ids):
86
+ for node_id in node_ids:
87
+ if node_id in self.keys:
88
+ continue
89
+ if not self.dynprompt.has_node(node_id):
90
+ continue
91
+ node = self.dynprompt.get_node(node_id)
92
+ self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id)
93
+ self.subcache_keys[node_id] = (node_id, node["class_type"])
94
+
95
+ def get_node_signature(self, dynprompt, node_id):
96
+ signature = []
97
+ ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id)
98
+ signature.append(self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
99
+ for ancestor_id in ancestors:
100
+ signature.append(self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
101
+ return to_hashable(signature)
102
+
103
+ def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
104
+ if not dynprompt.has_node(node_id):
105
+ # This node doesn't exist -- we can't cache it.
106
+ return [float("NaN")]
107
+ node = dynprompt.get_node(node_id)
108
+ class_type = node["class_type"]
109
+ class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
110
+ signature = [class_type, self.is_changed_cache.get(node_id)]
111
+ if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
112
+ signature.append(node_id)
113
+ inputs = node["inputs"]
114
+ for key in sorted(inputs.keys()):
115
+ if is_link(inputs[key]):
116
+ (ancestor_id, ancestor_socket) = inputs[key]
117
+ ancestor_index = ancestor_order_mapping[ancestor_id]
118
+ signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket)))
119
+ else:
120
+ signature.append((key, inputs[key]))
121
+ return signature
122
+
123
+ # This function returns a list of all ancestors of the given node. The order of the list is
124
+ # deterministic based on which specific inputs the ancestor is connected by.
125
+ def get_ordered_ancestry(self, dynprompt, node_id):
126
+ ancestors = []
127
+ order_mapping = {}
128
+ self.get_ordered_ancestry_internal(dynprompt, node_id, ancestors, order_mapping)
129
+ return ancestors, order_mapping
130
+
131
+ def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping):
132
+ if not dynprompt.has_node(node_id):
133
+ return
134
+ inputs = dynprompt.get_node(node_id)["inputs"]
135
+ input_keys = sorted(inputs.keys())
136
+ for key in input_keys:
137
+ if is_link(inputs[key]):
138
+ ancestor_id = inputs[key][0]
139
+ if ancestor_id not in order_mapping:
140
+ ancestors.append(ancestor_id)
141
+ order_mapping[ancestor_id] = len(ancestors) - 1
142
+ self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping)
143
+
144
+ class BasicCache:
145
+ def __init__(self, key_class):
146
+ self.key_class = key_class
147
+ self.initialized = False
148
+ self.dynprompt: DynamicPrompt
149
+ self.cache_key_set: CacheKeySet
150
+ self.cache = {}
151
+ self.subcaches = {}
152
+
153
+ def set_prompt(self, dynprompt, node_ids, is_changed_cache):
154
+ self.dynprompt = dynprompt
155
+ self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
156
+ self.is_changed_cache = is_changed_cache
157
+ self.initialized = True
158
+
159
+ def all_node_ids(self):
160
+ assert self.initialized
161
+ node_ids = self.cache_key_set.all_node_ids()
162
+ for subcache in self.subcaches.values():
163
+ node_ids = node_ids.union(subcache.all_node_ids())
164
+ return node_ids
165
+
166
+ def _clean_cache(self):
167
+ preserve_keys = set(self.cache_key_set.get_used_keys())
168
+ to_remove = []
169
+ for key in self.cache:
170
+ if key not in preserve_keys:
171
+ to_remove.append(key)
172
+ for key in to_remove:
173
+ del self.cache[key]
174
+
175
+ def _clean_subcaches(self):
176
+ preserve_subcaches = set(self.cache_key_set.get_used_subcache_keys())
177
+
178
+ to_remove = []
179
+ for key in self.subcaches:
180
+ if key not in preserve_subcaches:
181
+ to_remove.append(key)
182
+ for key in to_remove:
183
+ del self.subcaches[key]
184
+
185
+ def clean_unused(self):
186
+ assert self.initialized
187
+ self._clean_cache()
188
+ self._clean_subcaches()
189
+
190
+ def _set_immediate(self, node_id, value):
191
+ assert self.initialized
192
+ cache_key = self.cache_key_set.get_data_key(node_id)
193
+ self.cache[cache_key] = value
194
+
195
+ def _get_immediate(self, node_id):
196
+ if not self.initialized:
197
+ return None
198
+ cache_key = self.cache_key_set.get_data_key(node_id)
199
+ if cache_key in self.cache:
200
+ return self.cache[cache_key]
201
+ else:
202
+ return None
203
+
204
+ def _ensure_subcache(self, node_id, children_ids):
205
+ subcache_key = self.cache_key_set.get_subcache_key(node_id)
206
+ subcache = self.subcaches.get(subcache_key, None)
207
+ if subcache is None:
208
+ subcache = BasicCache(self.key_class)
209
+ self.subcaches[subcache_key] = subcache
210
+ subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
211
+ return subcache
212
+
213
+ def _get_subcache(self, node_id):
214
+ assert self.initialized
215
+ subcache_key = self.cache_key_set.get_subcache_key(node_id)
216
+ if subcache_key in self.subcaches:
217
+ return self.subcaches[subcache_key]
218
+ else:
219
+ return None
220
+
221
+ def recursive_debug_dump(self):
222
+ result = []
223
+ for key in self.cache:
224
+ result.append({"key": key, "value": self.cache[key]})
225
+ for key in self.subcaches:
226
+ result.append({"subcache_key": key, "subcache": self.subcaches[key].recursive_debug_dump()})
227
+ return result
228
+
229
+ class HierarchicalCache(BasicCache):
230
+ def __init__(self, key_class):
231
+ super().__init__(key_class)
232
+
233
+ def _get_cache_for(self, node_id):
234
+ assert self.dynprompt is not None
235
+ parent_id = self.dynprompt.get_parent_node_id(node_id)
236
+ if parent_id is None:
237
+ return self
238
+
239
+ hierarchy = []
240
+ while parent_id is not None:
241
+ hierarchy.append(parent_id)
242
+ parent_id = self.dynprompt.get_parent_node_id(parent_id)
243
+
244
+ cache = self
245
+ for parent_id in reversed(hierarchy):
246
+ cache = cache._get_subcache(parent_id)
247
+ if cache is None:
248
+ return None
249
+ return cache
250
+
251
+ def get(self, node_id):
252
+ cache = self._get_cache_for(node_id)
253
+ if cache is None:
254
+ return None
255
+ return cache._get_immediate(node_id)
256
+
257
+ def set(self, node_id, value):
258
+ cache = self._get_cache_for(node_id)
259
+ assert cache is not None
260
+ cache._set_immediate(node_id, value)
261
+
262
+ def ensure_subcache_for(self, node_id, children_ids):
263
+ cache = self._get_cache_for(node_id)
264
+ assert cache is not None
265
+ return cache._ensure_subcache(node_id, children_ids)
266
+
267
+ class LRUCache(BasicCache):
268
+ def __init__(self, key_class, max_size=100):
269
+ super().__init__(key_class)
270
+ self.max_size = max_size
271
+ self.min_generation = 0
272
+ self.generation = 0
273
+ self.used_generation = {}
274
+ self.children = {}
275
+
276
+ def set_prompt(self, dynprompt, node_ids, is_changed_cache):
277
+ super().set_prompt(dynprompt, node_ids, is_changed_cache)
278
+ self.generation += 1
279
+ for node_id in node_ids:
280
+ self._mark_used(node_id)
281
+
282
+ def clean_unused(self):
283
+ while len(self.cache) > self.max_size and self.min_generation < self.generation:
284
+ self.min_generation += 1
285
+ to_remove = [key for key in self.cache if self.used_generation[key] < self.min_generation]
286
+ for key in to_remove:
287
+ del self.cache[key]
288
+ del self.used_generation[key]
289
+ if key in self.children:
290
+ del self.children[key]
291
+ self._clean_subcaches()
292
+
293
+ def get(self, node_id):
294
+ self._mark_used(node_id)
295
+ return self._get_immediate(node_id)
296
+
297
+ def _mark_used(self, node_id):
298
+ cache_key = self.cache_key_set.get_data_key(node_id)
299
+ if cache_key is not None:
300
+ self.used_generation[cache_key] = self.generation
301
+
302
+ def set(self, node_id, value):
303
+ self._mark_used(node_id)
304
+ return self._set_immediate(node_id, value)
305
+
306
+ def ensure_subcache_for(self, node_id, children_ids):
307
+ # Just uses subcaches for tracking 'live' nodes
308
+ super()._ensure_subcache(node_id, children_ids)
309
+
310
+ self.cache_key_set.add_keys(children_ids)
311
+ self._mark_used(node_id)
312
+ cache_key = self.cache_key_set.get_data_key(node_id)
313
+ self.children[cache_key] = []
314
+ for child_id in children_ids:
315
+ self._mark_used(child_id)
316
+ self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
317
+ return self
318
+
comfy_execution/graph.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import nodes
2
+
3
+ from comfy_execution.graph_utils import is_link
4
+
5
+ class DependencyCycleError(Exception):
6
+ pass
7
+
8
+ class NodeInputError(Exception):
9
+ pass
10
+
11
+ class NodeNotFoundError(Exception):
12
+ pass
13
+
14
+ class DynamicPrompt:
15
+ def __init__(self, original_prompt):
16
+ # The original prompt provided by the user
17
+ self.original_prompt = original_prompt
18
+ # Any extra pieces of the graph created during execution
19
+ self.ephemeral_prompt = {}
20
+ self.ephemeral_parents = {}
21
+ self.ephemeral_display = {}
22
+
23
+ def get_node(self, node_id):
24
+ if node_id in self.ephemeral_prompt:
25
+ return self.ephemeral_prompt[node_id]
26
+ if node_id in self.original_prompt:
27
+ return self.original_prompt[node_id]
28
+ raise NodeNotFoundError(f"Node {node_id} not found")
29
+
30
+ def has_node(self, node_id):
31
+ return node_id in self.original_prompt or node_id in self.ephemeral_prompt
32
+
33
+ def add_ephemeral_node(self, node_id, node_info, parent_id, display_id):
34
+ self.ephemeral_prompt[node_id] = node_info
35
+ self.ephemeral_parents[node_id] = parent_id
36
+ self.ephemeral_display[node_id] = display_id
37
+
38
+ def get_real_node_id(self, node_id):
39
+ while node_id in self.ephemeral_parents:
40
+ node_id = self.ephemeral_parents[node_id]
41
+ return node_id
42
+
43
+ def get_parent_node_id(self, node_id):
44
+ return self.ephemeral_parents.get(node_id, None)
45
+
46
+ def get_display_node_id(self, node_id):
47
+ while node_id in self.ephemeral_display:
48
+ node_id = self.ephemeral_display[node_id]
49
+ return node_id
50
+
51
+ def all_node_ids(self):
52
+ return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys()))
53
+
54
+ def get_original_prompt(self):
55
+ return self.original_prompt
56
+
57
+ def get_input_info(class_def, input_name, valid_inputs=None):
58
+ valid_inputs = valid_inputs or class_def.INPUT_TYPES()
59
+ input_info = None
60
+ input_category = None
61
+ if "required" in valid_inputs and input_name in valid_inputs["required"]:
62
+ input_category = "required"
63
+ input_info = valid_inputs["required"][input_name]
64
+ elif "optional" in valid_inputs and input_name in valid_inputs["optional"]:
65
+ input_category = "optional"
66
+ input_info = valid_inputs["optional"][input_name]
67
+ elif "hidden" in valid_inputs and input_name in valid_inputs["hidden"]:
68
+ input_category = "hidden"
69
+ input_info = valid_inputs["hidden"][input_name]
70
+ if input_info is None:
71
+ return None, None, None
72
+ input_type = input_info[0]
73
+ if len(input_info) > 1:
74
+ extra_info = input_info[1]
75
+ else:
76
+ extra_info = {}
77
+ return input_type, input_category, extra_info
78
+
79
+ class TopologicalSort:
80
+ def __init__(self, dynprompt):
81
+ self.dynprompt = dynprompt
82
+ self.pendingNodes = {}
83
+ self.blockCount = {} # Number of nodes this node is directly blocked by
84
+ self.blocking = {} # Which nodes are blocked by this node
85
+
86
+ def get_input_info(self, unique_id, input_name):
87
+ class_type = self.dynprompt.get_node(unique_id)["class_type"]
88
+ class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
89
+ return get_input_info(class_def, input_name)
90
+
91
+ def make_input_strong_link(self, to_node_id, to_input):
92
+ inputs = self.dynprompt.get_node(to_node_id)["inputs"]
93
+ if to_input not in inputs:
94
+ raise NodeInputError(f"Node {to_node_id} says it needs input {to_input}, but there is no input to that node at all")
95
+ value = inputs[to_input]
96
+ if not is_link(value):
97
+ raise NodeInputError(f"Node {to_node_id} says it needs input {to_input}, but that value is a constant")
98
+ from_node_id, from_socket = value
99
+ self.add_strong_link(from_node_id, from_socket, to_node_id)
100
+
101
+ def add_strong_link(self, from_node_id, from_socket, to_node_id):
102
+ if not self.is_cached(from_node_id):
103
+ self.add_node(from_node_id)
104
+ if to_node_id not in self.blocking[from_node_id]:
105
+ self.blocking[from_node_id][to_node_id] = {}
106
+ self.blockCount[to_node_id] += 1
107
+ self.blocking[from_node_id][to_node_id][from_socket] = True
108
+
109
+ def add_node(self, node_unique_id, include_lazy=False, subgraph_nodes=None):
110
+ node_ids = [node_unique_id]
111
+ links = []
112
+
113
+ while len(node_ids) > 0:
114
+ unique_id = node_ids.pop()
115
+ if unique_id in self.pendingNodes:
116
+ continue
117
+
118
+ self.pendingNodes[unique_id] = True
119
+ self.blockCount[unique_id] = 0
120
+ self.blocking[unique_id] = {}
121
+
122
+ inputs = self.dynprompt.get_node(unique_id)["inputs"]
123
+ for input_name in inputs:
124
+ value = inputs[input_name]
125
+ if is_link(value):
126
+ from_node_id, from_socket = value
127
+ if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
128
+ continue
129
+ input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
130
+ is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
131
+ if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
132
+ node_ids.append(from_node_id)
133
+ links.append((from_node_id, from_socket, unique_id))
134
+
135
+ for link in links:
136
+ self.add_strong_link(*link)
137
+
138
+ def is_cached(self, node_id):
139
+ return False
140
+
141
+ def get_ready_nodes(self):
142
+ return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0]
143
+
144
+ def pop_node(self, unique_id):
145
+ del self.pendingNodes[unique_id]
146
+ for blocked_node_id in self.blocking[unique_id]:
147
+ self.blockCount[blocked_node_id] -= 1
148
+ del self.blocking[unique_id]
149
+
150
+ def is_empty(self):
151
+ return len(self.pendingNodes) == 0
152
+
153
+ class ExecutionList(TopologicalSort):
154
+ """
155
+ ExecutionList implements a topological dissolve of the graph. After a node is staged for execution,
156
+ it can still be returned to the graph after having further dependencies added.
157
+ """
158
+ def __init__(self, dynprompt, output_cache):
159
+ super().__init__(dynprompt)
160
+ self.output_cache = output_cache
161
+ self.staged_node_id = None
162
+
163
+ def is_cached(self, node_id):
164
+ return self.output_cache.get(node_id) is not None
165
+
166
+ def stage_node_execution(self):
167
+ assert self.staged_node_id is None
168
+ if self.is_empty():
169
+ return None, None, None
170
+ available = self.get_ready_nodes()
171
+ if len(available) == 0:
172
+ cycled_nodes = self.get_nodes_in_cycle()
173
+ # Because cycles composed entirely of static nodes are caught during initial validation,
174
+ # we will 'blame' the first node in the cycle that is not a static node.
175
+ blamed_node = cycled_nodes[0]
176
+ for node_id in cycled_nodes:
177
+ display_node_id = self.dynprompt.get_display_node_id(node_id)
178
+ if display_node_id != node_id:
179
+ blamed_node = display_node_id
180
+ break
181
+ ex = DependencyCycleError("Dependency cycle detected")
182
+ error_details = {
183
+ "node_id": blamed_node,
184
+ "exception_message": str(ex),
185
+ "exception_type": "graph.DependencyCycleError",
186
+ "traceback": [],
187
+ "current_inputs": []
188
+ }
189
+ return None, error_details, ex
190
+
191
+ self.staged_node_id = self.ux_friendly_pick_node(available)
192
+ return self.staged_node_id, None, None
193
+
194
+ def ux_friendly_pick_node(self, node_list):
195
+ # If an output node is available, do that first.
196
+ # Technically this has no effect on the overall length of execution, but it feels better as a user
197
+ # for a PreviewImage to display a result as soon as it can
198
+ # Some other heuristics could probably be used here to improve the UX further.
199
+ def is_output(node_id):
200
+ class_type = self.dynprompt.get_node(node_id)["class_type"]
201
+ class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
202
+ if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
203
+ return True
204
+ return False
205
+
206
+ for node_id in node_list:
207
+ if is_output(node_id):
208
+ return node_id
209
+
210
+ #This should handle the VAEDecode -> preview case
211
+ for node_id in node_list:
212
+ for blocked_node_id in self.blocking[node_id]:
213
+ if is_output(blocked_node_id):
214
+ return node_id
215
+
216
+ #This should handle the VAELoader -> VAEDecode -> preview case
217
+ for node_id in node_list:
218
+ for blocked_node_id in self.blocking[node_id]:
219
+ for blocked_node_id1 in self.blocking[blocked_node_id]:
220
+ if is_output(blocked_node_id1):
221
+ return node_id
222
+
223
+ #TODO: this function should be improved
224
+ return node_list[0]
225
+
226
+ def unstage_node_execution(self):
227
+ assert self.staged_node_id is not None
228
+ self.staged_node_id = None
229
+
230
+ def complete_node_execution(self):
231
+ node_id = self.staged_node_id
232
+ self.pop_node(node_id)
233
+ self.staged_node_id = None
234
+
235
+ def get_nodes_in_cycle(self):
236
+ # We'll dissolve the graph in reverse topological order to leave only the nodes in the cycle.
237
+ # We're skipping some of the performance optimizations from the original TopologicalSort to keep
238
+ # the code simple (and because having a cycle in the first place is a catastrophic error)
239
+ blocked_by = { node_id: {} for node_id in self.pendingNodes }
240
+ for from_node_id in self.blocking:
241
+ for to_node_id in self.blocking[from_node_id]:
242
+ if True in self.blocking[from_node_id][to_node_id].values():
243
+ blocked_by[to_node_id][from_node_id] = True
244
+ to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0]
245
+ while len(to_remove) > 0:
246
+ for node_id in to_remove:
247
+ for to_node_id in blocked_by:
248
+ if node_id in blocked_by[to_node_id]:
249
+ del blocked_by[to_node_id][node_id]
250
+ del blocked_by[node_id]
251
+ to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0]
252
+ return list(blocked_by.keys())
253
+
254
+ class ExecutionBlocker:
255
+ """
256
+ Return this from a node and any users will be blocked with the given error message.
257
+ If the message is None, execution will be blocked silently instead.
258
+ Generally, you should avoid using this functionality unless absolutely necessary. Whenever it's
259
+ possible, a lazy input will be more efficient and have a better user experience.
260
+ This functionality is useful in two cases:
261
+ 1. You want to conditionally prevent an output node from executing. (Particularly a built-in node
262
+ like SaveImage. For your own output nodes, I would recommend just adding a BOOL input and using
263
+ lazy evaluation to let it conditionally disable itself.)
264
+ 2. You have a node with multiple possible outputs, some of which are invalid and should not be used.
265
+ (I would recommend not making nodes like this in the future -- instead, make multiple nodes with
266
+ different outputs. Unfortunately, there are several popular existing nodes using this pattern.)
267
+ """
268
+ def __init__(self, message):
269
+ self.message = message
270
+
comfy_execution/graph_utils.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def is_link(obj):
2
+ if not isinstance(obj, list):
3
+ return False
4
+ if len(obj) != 2:
5
+ return False
6
+ if not isinstance(obj[0], str):
7
+ return False
8
+ if not isinstance(obj[1], int) and not isinstance(obj[1], float):
9
+ return False
10
+ return True
11
+
12
+ # The GraphBuilder is just a utility class that outputs graphs in the form expected by the ComfyUI back-end
13
+ class GraphBuilder:
14
+ _default_prefix_root = ""
15
+ _default_prefix_call_index = 0
16
+ _default_prefix_graph_index = 0
17
+
18
+ def __init__(self, prefix = None):
19
+ if prefix is None:
20
+ self.prefix = GraphBuilder.alloc_prefix()
21
+ else:
22
+ self.prefix = prefix
23
+ self.nodes = {}
24
+ self.id_gen = 1
25
+
26
+ @classmethod
27
+ def set_default_prefix(cls, prefix_root, call_index, graph_index = 0):
28
+ cls._default_prefix_root = prefix_root
29
+ cls._default_prefix_call_index = call_index
30
+ cls._default_prefix_graph_index = graph_index
31
+
32
+ @classmethod
33
+ def alloc_prefix(cls, root=None, call_index=None, graph_index=None):
34
+ if root is None:
35
+ root = GraphBuilder._default_prefix_root
36
+ if call_index is None:
37
+ call_index = GraphBuilder._default_prefix_call_index
38
+ if graph_index is None:
39
+ graph_index = GraphBuilder._default_prefix_graph_index
40
+ result = f"{root}.{call_index}.{graph_index}."
41
+ GraphBuilder._default_prefix_graph_index += 1
42
+ return result
43
+
44
+ def node(self, class_type, id=None, **kwargs):
45
+ if id is None:
46
+ id = str(self.id_gen)
47
+ self.id_gen += 1
48
+ id = self.prefix + id
49
+ if id in self.nodes:
50
+ return self.nodes[id]
51
+
52
+ node = Node(id, class_type, kwargs)
53
+ self.nodes[id] = node
54
+ return node
55
+
56
+ def lookup_node(self, id):
57
+ id = self.prefix + id
58
+ return self.nodes.get(id)
59
+
60
+ def finalize(self):
61
+ output = {}
62
+ for node_id, node in self.nodes.items():
63
+ output[node_id] = node.serialize()
64
+ return output
65
+
66
+ def replace_node_output(self, node_id, index, new_value):
67
+ node_id = self.prefix + node_id
68
+ to_remove = []
69
+ for node in self.nodes.values():
70
+ for key, value in node.inputs.items():
71
+ if is_link(value) and value[0] == node_id and value[1] == index:
72
+ if new_value is None:
73
+ to_remove.append((node, key))
74
+ else:
75
+ node.inputs[key] = new_value
76
+ for node, key in to_remove:
77
+ del node.inputs[key]
78
+
79
+ def remove_node(self, id):
80
+ id = self.prefix + id
81
+ del self.nodes[id]
82
+
83
+ class Node:
84
+ def __init__(self, id, class_type, inputs):
85
+ self.id = id
86
+ self.class_type = class_type
87
+ self.inputs = inputs
88
+ self.override_display_id = None
89
+
90
+ def out(self, index):
91
+ return [self.id, index]
92
+
93
+ def set_input(self, key, value):
94
+ if value is None:
95
+ if key in self.inputs:
96
+ del self.inputs[key]
97
+ else:
98
+ self.inputs[key] = value
99
+
100
+ def get_input(self, key):
101
+ return self.inputs.get(key)
102
+
103
+ def set_override_display_id(self, override_display_id):
104
+ self.override_display_id = override_display_id
105
+
106
+ def serialize(self):
107
+ serialized = {
108
+ "class_type": self.class_type,
109
+ "inputs": self.inputs
110
+ }
111
+ if self.override_display_id is not None:
112
+ serialized["override_display_id"] = self.override_display_id
113
+ return serialized
114
+
115
+ def add_graph_prefix(graph, outputs, prefix):
116
+ # Change the node IDs and any internal links
117
+ new_graph = {}
118
+ for node_id, node_info in graph.items():
119
+ # Make sure the added nodes have unique IDs
120
+ new_node_id = prefix + node_id
121
+ new_node = { "class_type": node_info["class_type"], "inputs": {} }
122
+ for input_name, input_value in node_info.get("inputs", {}).items():
123
+ if is_link(input_value):
124
+ new_node["inputs"][input_name] = [prefix + input_value[0], input_value[1]]
125
+ else:
126
+ new_node["inputs"][input_name] = input_value
127
+ new_graph[new_node_id] = new_node
128
+
129
+ # Change the node IDs in the outputs
130
+ new_outputs = []
131
+ for n in range(len(outputs)):
132
+ output = outputs[n]
133
+ if is_link(output):
134
+ new_outputs.append([prefix + output[0], output[1]])
135
+ else:
136
+ new_outputs.append(output)
137
+
138
+ return new_graph, tuple(new_outputs)
139
+
comfy_execution/validation.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+
4
+ def validate_node_input(
5
+ received_type: str, input_type: str, strict: bool = False
6
+ ) -> bool:
7
+ """
8
+ received_type and input_type are both strings of the form "T1,T2,...".
9
+
10
+ If strict is True, the input_type must contain the received_type.
11
+ For example, if received_type is "STRING" and input_type is "STRING,INT",
12
+ this will return True. But if received_type is "STRING,INT" and input_type is
13
+ "INT", this will return False.
14
+
15
+ If strict is False, the input_type must have overlap with the received_type.
16
+ For example, if received_type is "STRING,BOOLEAN" and input_type is "STRING,INT",
17
+ this will return True.
18
+
19
+ Supports pre-union type extension behaviour of ``__ne__`` overrides.
20
+ """
21
+ # If the types are exactly the same, we can return immediately
22
+ # Use pre-union behaviour: inverse of `__ne__`
23
+ if not received_type != input_type:
24
+ return True
25
+
26
+ # Not equal, and not strings
27
+ if not isinstance(received_type, str) or not isinstance(input_type, str):
28
+ return False
29
+
30
+ # Split the type strings into sets for comparison
31
+ received_types = set(t.strip() for t in received_type.split(","))
32
+ input_types = set(t.strip() for t in input_type.split(","))
33
+
34
+ if strict:
35
+ # In strict mode, all received types must be in the input types
36
+ return received_types.issubset(input_types)
37
+ else:
38
+ # In non-strict mode, there must be at least one type in common
39
+ return len(received_types.intersection(input_types)) > 0
comfy_extras/chainner_models/model_loading.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import logging
2
+ from spandrel import ModelLoader
3
+
4
+ def load_state_dict(state_dict):
5
+ logging.warning("comfy_extras.chainner_models is deprecated and has been replaced by the spandrel library.")
6
+ return ModelLoader().load_from_state_dict(state_dict).eval()
comfy_extras/nodes_advanced_samplers.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import comfy.samplers
2
+ import comfy.utils
3
+ import torch
4
+ import numpy as np
5
+ from tqdm.auto import trange
6
+
7
+
8
+ @torch.no_grad()
9
+ def sample_lcm_upscale(model, x, sigmas, extra_args=None, callback=None, disable=None, total_upscale=2.0, upscale_method="bislerp", upscale_steps=None):
10
+ extra_args = {} if extra_args is None else extra_args
11
+
12
+ if upscale_steps is None:
13
+ upscale_steps = max(len(sigmas) // 2 + 1, 2)
14
+ else:
15
+ upscale_steps += 1
16
+ upscale_steps = min(upscale_steps, len(sigmas) + 1)
17
+
18
+ upscales = np.linspace(1.0, total_upscale, upscale_steps)[1:]
19
+
20
+ orig_shape = x.size()
21
+ s_in = x.new_ones([x.shape[0]])
22
+ for i in trange(len(sigmas) - 1, disable=disable):
23
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
24
+ if callback is not None:
25
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
26
+
27
+ x = denoised
28
+ if i < len(upscales):
29
+ x = comfy.utils.common_upscale(x, round(orig_shape[-1] * upscales[i]), round(orig_shape[-2] * upscales[i]), upscale_method, "disabled")
30
+
31
+ if sigmas[i + 1] > 0:
32
+ x += sigmas[i + 1] * torch.randn_like(x)
33
+ return x
34
+
35
+
36
+ class SamplerLCMUpscale:
37
+ upscale_methods = ["bislerp", "nearest-exact", "bilinear", "area", "bicubic"]
38
+
39
+ @classmethod
40
+ def INPUT_TYPES(s):
41
+ return {"required":
42
+ {"scale_ratio": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 20.0, "step": 0.01}),
43
+ "scale_steps": ("INT", {"default": -1, "min": -1, "max": 1000, "step": 1}),
44
+ "upscale_method": (s.upscale_methods,),
45
+ }
46
+ }
47
+ RETURN_TYPES = ("SAMPLER",)
48
+ CATEGORY = "sampling/custom_sampling/samplers"
49
+
50
+ FUNCTION = "get_sampler"
51
+
52
+ def get_sampler(self, scale_ratio, scale_steps, upscale_method):
53
+ if scale_steps < 0:
54
+ scale_steps = None
55
+ sampler = comfy.samplers.KSAMPLER(sample_lcm_upscale, extra_options={"total_upscale": scale_ratio, "upscale_steps": scale_steps, "upscale_method": upscale_method})
56
+ return (sampler, )
57
+
58
+ from comfy.k_diffusion.sampling import to_d
59
+ import comfy.model_patcher
60
+
61
+ @torch.no_grad()
62
+ def sample_euler_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
63
+ extra_args = {} if extra_args is None else extra_args
64
+
65
+ temp = [0]
66
+ def post_cfg_function(args):
67
+ temp[0] = args["uncond_denoised"]
68
+ return args["denoised"]
69
+
70
+ model_options = extra_args.get("model_options", {}).copy()
71
+ extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
72
+
73
+ s_in = x.new_ones([x.shape[0]])
74
+ for i in trange(len(sigmas) - 1, disable=disable):
75
+ sigma_hat = sigmas[i]
76
+ denoised = model(x, sigma_hat * s_in, **extra_args)
77
+ d = to_d(x - denoised + temp[0], sigmas[i], denoised)
78
+ if callback is not None:
79
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
80
+ dt = sigmas[i + 1] - sigma_hat
81
+ x = x + d * dt
82
+ return x
83
+
84
+
85
+ class SamplerEulerCFGpp:
86
+ @classmethod
87
+ def INPUT_TYPES(s):
88
+ return {"required":
89
+ {"version": (["regular", "alternative"],),}
90
+ }
91
+ RETURN_TYPES = ("SAMPLER",)
92
+ # CATEGORY = "sampling/custom_sampling/samplers"
93
+ CATEGORY = "_for_testing"
94
+
95
+ FUNCTION = "get_sampler"
96
+
97
+ def get_sampler(self, version):
98
+ if version == "alternative":
99
+ sampler = comfy.samplers.KSAMPLER(sample_euler_pp)
100
+ else:
101
+ sampler = comfy.samplers.ksampler("euler_cfg_pp")
102
+ return (sampler, )
103
+
104
+ NODE_CLASS_MAPPINGS = {
105
+ "SamplerLCMUpscale": SamplerLCMUpscale,
106
+ "SamplerEulerCFGpp": SamplerEulerCFGpp,
107
+ }
108
+
109
+ NODE_DISPLAY_NAME_MAPPINGS = {
110
+ "SamplerEulerCFGpp": "SamplerEulerCFG++",
111
+ }
comfy_extras/nodes_align_your_steps.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #from: https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html
2
+ import numpy as np
3
+ import torch
4
+
5
+ def loglinear_interp(t_steps, num_steps):
6
+ """
7
+ Performs log-linear interpolation of a given array of decreasing numbers.
8
+ """
9
+ xs = np.linspace(0, 1, len(t_steps))
10
+ ys = np.log(t_steps[::-1])
11
+
12
+ new_xs = np.linspace(0, 1, num_steps)
13
+ new_ys = np.interp(new_xs, xs, ys)
14
+
15
+ interped_ys = np.exp(new_ys)[::-1].copy()
16
+ return interped_ys
17
+
18
+ NOISE_LEVELS = {"SD1": [14.6146412293, 6.4745760956, 3.8636745985, 2.6946151520, 1.8841921177, 1.3943805092, 0.9642583904, 0.6523686016, 0.3977456272, 0.1515232662, 0.0291671582],
19
+ "SDXL":[14.6146412293, 6.3184485287, 3.7681790315, 2.1811480769, 1.3405244945, 0.8620721141, 0.5550693289, 0.3798540708, 0.2332364134, 0.1114188177, 0.0291671582],
20
+ "SVD": [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.002]}
21
+
22
+ class AlignYourStepsScheduler:
23
+ @classmethod
24
+ def INPUT_TYPES(s):
25
+ return {"required":
26
+ {"model_type": (["SD1", "SDXL", "SVD"], ),
27
+ "steps": ("INT", {"default": 10, "min": 1, "max": 10000}),
28
+ "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
29
+ }
30
+ }
31
+ RETURN_TYPES = ("SIGMAS",)
32
+ CATEGORY = "sampling/custom_sampling/schedulers"
33
+
34
+ FUNCTION = "get_sigmas"
35
+
36
+ def get_sigmas(self, model_type, steps, denoise):
37
+ total_steps = steps
38
+ if denoise < 1.0:
39
+ if denoise <= 0.0:
40
+ return (torch.FloatTensor([]),)
41
+ total_steps = round(steps * denoise)
42
+
43
+ sigmas = NOISE_LEVELS[model_type][:]
44
+ if (steps + 1) != len(sigmas):
45
+ sigmas = loglinear_interp(sigmas, steps + 1)
46
+
47
+ sigmas = sigmas[-(total_steps + 1):]
48
+ sigmas[-1] = 0
49
+ return (torch.FloatTensor(sigmas), )
50
+
51
+ NODE_CLASS_MAPPINGS = {
52
+ "AlignYourStepsScheduler": AlignYourStepsScheduler,
53
+ }
comfy_extras/nodes_attention_multiply.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ def attention_multiply(attn, model, q, k, v, out):
3
+ m = model.clone()
4
+ sd = model.model_state_dict()
5
+
6
+ for key in sd:
7
+ if key.endswith("{}.to_q.bias".format(attn)) or key.endswith("{}.to_q.weight".format(attn)):
8
+ m.add_patches({key: (None,)}, 0.0, q)
9
+ if key.endswith("{}.to_k.bias".format(attn)) or key.endswith("{}.to_k.weight".format(attn)):
10
+ m.add_patches({key: (None,)}, 0.0, k)
11
+ if key.endswith("{}.to_v.bias".format(attn)) or key.endswith("{}.to_v.weight".format(attn)):
12
+ m.add_patches({key: (None,)}, 0.0, v)
13
+ if key.endswith("{}.to_out.0.bias".format(attn)) or key.endswith("{}.to_out.0.weight".format(attn)):
14
+ m.add_patches({key: (None,)}, 0.0, out)
15
+
16
+ return m
17
+
18
+
19
+ class UNetSelfAttentionMultiply:
20
+ @classmethod
21
+ def INPUT_TYPES(s):
22
+ return {"required": { "model": ("MODEL",),
23
+ "q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
24
+ "k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
25
+ "v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
26
+ "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
27
+ }}
28
+ RETURN_TYPES = ("MODEL",)
29
+ FUNCTION = "patch"
30
+
31
+ CATEGORY = "_for_testing/attention_experiments"
32
+
33
+ def patch(self, model, q, k, v, out):
34
+ m = attention_multiply("attn1", model, q, k, v, out)
35
+ return (m, )
36
+
37
+ class UNetCrossAttentionMultiply:
38
+ @classmethod
39
+ def INPUT_TYPES(s):
40
+ return {"required": { "model": ("MODEL",),
41
+ "q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
42
+ "k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
43
+ "v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
44
+ "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
45
+ }}
46
+ RETURN_TYPES = ("MODEL",)
47
+ FUNCTION = "patch"
48
+
49
+ CATEGORY = "_for_testing/attention_experiments"
50
+
51
+ def patch(self, model, q, k, v, out):
52
+ m = attention_multiply("attn2", model, q, k, v, out)
53
+ return (m, )
54
+
55
+ class CLIPAttentionMultiply:
56
+ @classmethod
57
+ def INPUT_TYPES(s):
58
+ return {"required": { "clip": ("CLIP",),
59
+ "q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
60
+ "k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
61
+ "v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
62
+ "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
63
+ }}
64
+ RETURN_TYPES = ("CLIP",)
65
+ FUNCTION = "patch"
66
+
67
+ CATEGORY = "_for_testing/attention_experiments"
68
+
69
+ def patch(self, clip, q, k, v, out):
70
+ m = clip.clone()
71
+ sd = m.patcher.model_state_dict()
72
+
73
+ for key in sd:
74
+ if key.endswith("self_attn.q_proj.weight") or key.endswith("self_attn.q_proj.bias"):
75
+ m.add_patches({key: (None,)}, 0.0, q)
76
+ if key.endswith("self_attn.k_proj.weight") or key.endswith("self_attn.k_proj.bias"):
77
+ m.add_patches({key: (None,)}, 0.0, k)
78
+ if key.endswith("self_attn.v_proj.weight") or key.endswith("self_attn.v_proj.bias"):
79
+ m.add_patches({key: (None,)}, 0.0, v)
80
+ if key.endswith("self_attn.out_proj.weight") or key.endswith("self_attn.out_proj.bias"):
81
+ m.add_patches({key: (None,)}, 0.0, out)
82
+ return (m, )
83
+
84
+ class UNetTemporalAttentionMultiply:
85
+ @classmethod
86
+ def INPUT_TYPES(s):
87
+ return {"required": { "model": ("MODEL",),
88
+ "self_structural": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
89
+ "self_temporal": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
90
+ "cross_structural": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
91
+ "cross_temporal": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
92
+ }}
93
+ RETURN_TYPES = ("MODEL",)
94
+ FUNCTION = "patch"
95
+
96
+ CATEGORY = "_for_testing/attention_experiments"
97
+
98
+ def patch(self, model, self_structural, self_temporal, cross_structural, cross_temporal):
99
+ m = model.clone()
100
+ sd = model.model_state_dict()
101
+
102
+ for k in sd:
103
+ if (k.endswith("attn1.to_out.0.bias") or k.endswith("attn1.to_out.0.weight")):
104
+ if '.time_stack.' in k:
105
+ m.add_patches({k: (None,)}, 0.0, self_temporal)
106
+ else:
107
+ m.add_patches({k: (None,)}, 0.0, self_structural)
108
+ elif (k.endswith("attn2.to_out.0.bias") or k.endswith("attn2.to_out.0.weight")):
109
+ if '.time_stack.' in k:
110
+ m.add_patches({k: (None,)}, 0.0, cross_temporal)
111
+ else:
112
+ m.add_patches({k: (None,)}, 0.0, cross_structural)
113
+ return (m, )
114
+
115
+ NODE_CLASS_MAPPINGS = {
116
+ "UNetSelfAttentionMultiply": UNetSelfAttentionMultiply,
117
+ "UNetCrossAttentionMultiply": UNetCrossAttentionMultiply,
118
+ "CLIPAttentionMultiply": CLIPAttentionMultiply,
119
+ "UNetTemporalAttentionMultiply": UNetTemporalAttentionMultiply,
120
+ }
comfy_extras/nodes_audio.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchaudio
2
+ import torch
3
+ import comfy.model_management
4
+ import folder_paths
5
+ import os
6
+ import io
7
+ import json
8
+ import struct
9
+ import random
10
+ import hashlib
11
+ import node_helpers
12
+ from comfy.cli_args import args
13
+
14
+ class EmptyLatentAudio:
15
+ def __init__(self):
16
+ self.device = comfy.model_management.intermediate_device()
17
+
18
+ @classmethod
19
+ def INPUT_TYPES(s):
20
+ return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1}),
21
+ "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
22
+ }}
23
+ RETURN_TYPES = ("LATENT",)
24
+ FUNCTION = "generate"
25
+
26
+ CATEGORY = "latent/audio"
27
+
28
+ def generate(self, seconds, batch_size):
29
+ length = round((seconds * 44100 / 2048) / 2) * 2
30
+ latent = torch.zeros([batch_size, 64, length], device=self.device)
31
+ return ({"samples":latent, "type": "audio"}, )
32
+
33
+ class ConditioningStableAudio:
34
+ @classmethod
35
+ def INPUT_TYPES(s):
36
+ return {"required": {"positive": ("CONDITIONING", ),
37
+ "negative": ("CONDITIONING", ),
38
+ "seconds_start": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.1}),
39
+ "seconds_total": ("FLOAT", {"default": 47.0, "min": 0.0, "max": 1000.0, "step": 0.1}),
40
+ }}
41
+
42
+ RETURN_TYPES = ("CONDITIONING","CONDITIONING")
43
+ RETURN_NAMES = ("positive", "negative")
44
+
45
+ FUNCTION = "append"
46
+
47
+ CATEGORY = "conditioning"
48
+
49
+ def append(self, positive, negative, seconds_start, seconds_total):
50
+ positive = node_helpers.conditioning_set_values(positive, {"seconds_start": seconds_start, "seconds_total": seconds_total})
51
+ negative = node_helpers.conditioning_set_values(negative, {"seconds_start": seconds_start, "seconds_total": seconds_total})
52
+ return (positive, negative)
53
+
54
+ class VAEEncodeAudio:
55
+ @classmethod
56
+ def INPUT_TYPES(s):
57
+ return {"required": { "audio": ("AUDIO", ), "vae": ("VAE", )}}
58
+ RETURN_TYPES = ("LATENT",)
59
+ FUNCTION = "encode"
60
+
61
+ CATEGORY = "latent/audio"
62
+
63
+ def encode(self, vae, audio):
64
+ sample_rate = audio["sample_rate"]
65
+ if 44100 != sample_rate:
66
+ waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, 44100)
67
+ else:
68
+ waveform = audio["waveform"]
69
+
70
+ t = vae.encode(waveform.movedim(1, -1))
71
+ return ({"samples":t}, )
72
+
73
+ class VAEDecodeAudio:
74
+ @classmethod
75
+ def INPUT_TYPES(s):
76
+ return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
77
+ RETURN_TYPES = ("AUDIO",)
78
+ FUNCTION = "decode"
79
+
80
+ CATEGORY = "latent/audio"
81
+
82
+ def decode(self, vae, samples):
83
+ audio = vae.decode(samples["samples"]).movedim(-1, 1)
84
+ std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
85
+ std[std < 1.0] = 1.0
86
+ audio /= std
87
+ return ({"waveform": audio, "sample_rate": 44100}, )
88
+
89
+
90
+ def create_vorbis_comment_block(comment_dict, last_block):
91
+ vendor_string = b'ComfyUI'
92
+ vendor_length = len(vendor_string)
93
+
94
+ comments = []
95
+ for key, value in comment_dict.items():
96
+ comment = f"{key}={value}".encode('utf-8')
97
+ comments.append(struct.pack('<I', len(comment)) + comment)
98
+
99
+ user_comment_list_length = len(comments)
100
+ user_comments = b''.join(comments)
101
+
102
+ comment_data = struct.pack('<I', vendor_length) + vendor_string + struct.pack('<I', user_comment_list_length) + user_comments
103
+ if last_block:
104
+ id = b'\x84'
105
+ else:
106
+ id = b'\x04'
107
+ comment_block = id + struct.pack('>I', len(comment_data))[1:] + comment_data
108
+
109
+ return comment_block
110
+
111
+ def insert_or_replace_vorbis_comment(flac_io, comment_dict):
112
+ if len(comment_dict) == 0:
113
+ return flac_io
114
+
115
+ flac_io.seek(4)
116
+
117
+ blocks = []
118
+ last_block = False
119
+
120
+ while not last_block:
121
+ header = flac_io.read(4)
122
+ last_block = (header[0] & 0x80) != 0
123
+ block_type = header[0] & 0x7F
124
+ block_length = struct.unpack('>I', b'\x00' + header[1:])[0]
125
+ block_data = flac_io.read(block_length)
126
+
127
+ if block_type == 4 or block_type == 1:
128
+ pass
129
+ else:
130
+ header = bytes([(header[0] & (~0x80))]) + header[1:]
131
+ blocks.append(header + block_data)
132
+
133
+ blocks.append(create_vorbis_comment_block(comment_dict, last_block=True))
134
+
135
+ new_flac_io = io.BytesIO()
136
+ new_flac_io.write(b'fLaC')
137
+ for block in blocks:
138
+ new_flac_io.write(block)
139
+
140
+ new_flac_io.write(flac_io.read())
141
+ return new_flac_io
142
+
143
+
144
+ class SaveAudio:
145
+ def __init__(self):
146
+ self.output_dir = folder_paths.get_output_directory()
147
+ self.type = "output"
148
+ self.prefix_append = ""
149
+
150
+ @classmethod
151
+ def INPUT_TYPES(s):
152
+ return {"required": { "audio": ("AUDIO", ),
153
+ "filename_prefix": ("STRING", {"default": "audio/ComfyUI"})},
154
+ "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
155
+ }
156
+
157
+ RETURN_TYPES = ()
158
+ FUNCTION = "save_audio"
159
+
160
+ OUTPUT_NODE = True
161
+
162
+ CATEGORY = "audio"
163
+
164
+ def save_audio(self, audio, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
165
+ filename_prefix += self.prefix_append
166
+ full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
167
+ results = list()
168
+
169
+ metadata = {}
170
+ if not args.disable_metadata:
171
+ if prompt is not None:
172
+ metadata["prompt"] = json.dumps(prompt)
173
+ if extra_pnginfo is not None:
174
+ for x in extra_pnginfo:
175
+ metadata[x] = json.dumps(extra_pnginfo[x])
176
+
177
+ for (batch_number, waveform) in enumerate(audio["waveform"].cpu()):
178
+ filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
179
+ file = f"{filename_with_batch_num}_{counter:05}_.flac"
180
+
181
+ buff = io.BytesIO()
182
+ torchaudio.save(buff, waveform, audio["sample_rate"], format="FLAC")
183
+
184
+ buff = insert_or_replace_vorbis_comment(buff, metadata)
185
+
186
+ with open(os.path.join(full_output_folder, file), 'wb') as f:
187
+ f.write(buff.getbuffer())
188
+
189
+ results.append({
190
+ "filename": file,
191
+ "subfolder": subfolder,
192
+ "type": self.type
193
+ })
194
+ counter += 1
195
+
196
+ return { "ui": { "audio": results } }
197
+
198
+ class PreviewAudio(SaveAudio):
199
+ def __init__(self):
200
+ self.output_dir = folder_paths.get_temp_directory()
201
+ self.type = "temp"
202
+ self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
203
+
204
+ @classmethod
205
+ def INPUT_TYPES(s):
206
+ return {"required":
207
+ {"audio": ("AUDIO", ), },
208
+ "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
209
+ }
210
+
211
+ class LoadAudio:
212
+ @classmethod
213
+ def INPUT_TYPES(s):
214
+ input_dir = folder_paths.get_input_directory()
215
+ files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"])
216
+ return {"required": {"audio": (sorted(files), {"audio_upload": True})}}
217
+
218
+ CATEGORY = "audio"
219
+
220
+ RETURN_TYPES = ("AUDIO", )
221
+ FUNCTION = "load"
222
+
223
+ def load(self, audio):
224
+ audio_path = folder_paths.get_annotated_filepath(audio)
225
+ waveform, sample_rate = torchaudio.load(audio_path)
226
+ audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
227
+ return (audio, )
228
+
229
+ @classmethod
230
+ def IS_CHANGED(s, audio):
231
+ image_path = folder_paths.get_annotated_filepath(audio)
232
+ m = hashlib.sha256()
233
+ with open(image_path, 'rb') as f:
234
+ m.update(f.read())
235
+ return m.digest().hex()
236
+
237
+ @classmethod
238
+ def VALIDATE_INPUTS(s, audio):
239
+ if not folder_paths.exists_annotated_filepath(audio):
240
+ return "Invalid audio file: {}".format(audio)
241
+ return True
242
+
243
+ NODE_CLASS_MAPPINGS = {
244
+ "EmptyLatentAudio": EmptyLatentAudio,
245
+ "VAEEncodeAudio": VAEEncodeAudio,
246
+ "VAEDecodeAudio": VAEDecodeAudio,
247
+ "SaveAudio": SaveAudio,
248
+ "LoadAudio": LoadAudio,
249
+ "PreviewAudio": PreviewAudio,
250
+ "ConditioningStableAudio": ConditioningStableAudio,
251
+ }
comfy_extras/nodes_canny.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from kornia.filters import canny
2
+ import comfy.model_management
3
+
4
+
5
+ class Canny:
6
+ @classmethod
7
+ def INPUT_TYPES(s):
8
+ return {"required": {"image": ("IMAGE",),
9
+ "low_threshold": ("FLOAT", {"default": 0.4, "min": 0.01, "max": 0.99, "step": 0.01}),
10
+ "high_threshold": ("FLOAT", {"default": 0.8, "min": 0.01, "max": 0.99, "step": 0.01})
11
+ }}
12
+
13
+ RETURN_TYPES = ("IMAGE",)
14
+ FUNCTION = "detect_edge"
15
+
16
+ CATEGORY = "image/preprocessors"
17
+
18
+ def detect_edge(self, image, low_threshold, high_threshold):
19
+ output = canny(image.to(comfy.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold)
20
+ img_out = output[1].to(comfy.model_management.intermediate_device()).repeat(1, 3, 1, 1).movedim(1, -1)
21
+ return (img_out,)
22
+
23
+ NODE_CLASS_MAPPINGS = {
24
+ "Canny": Canny,
25
+ }
comfy_extras/nodes_clip_sdxl.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from nodes import MAX_RESOLUTION
2
+
3
+ class CLIPTextEncodeSDXLRefiner:
4
+ @classmethod
5
+ def INPUT_TYPES(s):
6
+ return {"required": {
7
+ "ascore": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
8
+ "width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
9
+ "height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
10
+ "text": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", ),
11
+ }}
12
+ RETURN_TYPES = ("CONDITIONING",)
13
+ FUNCTION = "encode"
14
+
15
+ CATEGORY = "advanced/conditioning"
16
+
17
+ def encode(self, clip, ascore, width, height, text):
18
+ tokens = clip.tokenize(text)
19
+ return (clip.encode_from_tokens_scheduled(tokens, add_dict={"aesthetic_score": ascore, "width": width, "height": height}), )
20
+
21
+ class CLIPTextEncodeSDXL:
22
+ @classmethod
23
+ def INPUT_TYPES(s):
24
+ return {"required": {
25
+ "clip": ("CLIP", ),
26
+ "width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
27
+ "height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
28
+ "crop_w": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION}),
29
+ "crop_h": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION}),
30
+ "target_width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
31
+ "target_height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
32
+ "text_g": ("STRING", {"multiline": True, "dynamicPrompts": True}),
33
+ "text_l": ("STRING", {"multiline": True, "dynamicPrompts": True}),
34
+ }}
35
+ RETURN_TYPES = ("CONDITIONING",)
36
+ FUNCTION = "encode"
37
+
38
+ CATEGORY = "advanced/conditioning"
39
+
40
+ def encode(self, clip, width, height, crop_w, crop_h, target_width, target_height, text_g, text_l):
41
+ tokens = clip.tokenize(text_g)
42
+ tokens["l"] = clip.tokenize(text_l)["l"]
43
+ if len(tokens["l"]) != len(tokens["g"]):
44
+ empty = clip.tokenize("")
45
+ while len(tokens["l"]) < len(tokens["g"]):
46
+ tokens["l"] += empty["l"]
47
+ while len(tokens["l"]) > len(tokens["g"]):
48
+ tokens["g"] += empty["g"]
49
+ return (clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height}), )
50
+
51
+ NODE_CLASS_MAPPINGS = {
52
+ "CLIPTextEncodeSDXLRefiner": CLIPTextEncodeSDXLRefiner,
53
+ "CLIPTextEncodeSDXL": CLIPTextEncodeSDXL,
54
+ }
comfy_extras/nodes_compositing.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import comfy.utils
3
+ from enum import Enum
4
+
5
+ def resize_mask(mask, shape):
6
+ return torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[0], shape[1]), mode="bilinear").squeeze(1)
7
+
8
+ class PorterDuffMode(Enum):
9
+ ADD = 0
10
+ CLEAR = 1
11
+ DARKEN = 2
12
+ DST = 3
13
+ DST_ATOP = 4
14
+ DST_IN = 5
15
+ DST_OUT = 6
16
+ DST_OVER = 7
17
+ LIGHTEN = 8
18
+ MULTIPLY = 9
19
+ OVERLAY = 10
20
+ SCREEN = 11
21
+ SRC = 12
22
+ SRC_ATOP = 13
23
+ SRC_IN = 14
24
+ SRC_OUT = 15
25
+ SRC_OVER = 16
26
+ XOR = 17
27
+
28
+
29
+ def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_image: torch.Tensor, dst_alpha: torch.Tensor, mode: PorterDuffMode):
30
+ # convert mask to alpha
31
+ src_alpha = 1 - src_alpha
32
+ dst_alpha = 1 - dst_alpha
33
+ # premultiply alpha
34
+ src_image = src_image * src_alpha
35
+ dst_image = dst_image * dst_alpha
36
+
37
+ # composite ops below assume alpha-premultiplied images
38
+ if mode == PorterDuffMode.ADD:
39
+ out_alpha = torch.clamp(src_alpha + dst_alpha, 0, 1)
40
+ out_image = torch.clamp(src_image + dst_image, 0, 1)
41
+ elif mode == PorterDuffMode.CLEAR:
42
+ out_alpha = torch.zeros_like(dst_alpha)
43
+ out_image = torch.zeros_like(dst_image)
44
+ elif mode == PorterDuffMode.DARKEN:
45
+ out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
46
+ out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image + torch.min(src_image, dst_image)
47
+ elif mode == PorterDuffMode.DST:
48
+ out_alpha = dst_alpha
49
+ out_image = dst_image
50
+ elif mode == PorterDuffMode.DST_ATOP:
51
+ out_alpha = src_alpha
52
+ out_image = src_alpha * dst_image + (1 - dst_alpha) * src_image
53
+ elif mode == PorterDuffMode.DST_IN:
54
+ out_alpha = src_alpha * dst_alpha
55
+ out_image = dst_image * src_alpha
56
+ elif mode == PorterDuffMode.DST_OUT:
57
+ out_alpha = (1 - src_alpha) * dst_alpha
58
+ out_image = (1 - src_alpha) * dst_image
59
+ elif mode == PorterDuffMode.DST_OVER:
60
+ out_alpha = dst_alpha + (1 - dst_alpha) * src_alpha
61
+ out_image = dst_image + (1 - dst_alpha) * src_image
62
+ elif mode == PorterDuffMode.LIGHTEN:
63
+ out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
64
+ out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image + torch.max(src_image, dst_image)
65
+ elif mode == PorterDuffMode.MULTIPLY:
66
+ out_alpha = src_alpha * dst_alpha
67
+ out_image = src_image * dst_image
68
+ elif mode == PorterDuffMode.OVERLAY:
69
+ out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
70
+ out_image = torch.where(2 * dst_image < dst_alpha, 2 * src_image * dst_image,
71
+ src_alpha * dst_alpha - 2 * (dst_alpha - src_image) * (src_alpha - dst_image))
72
+ elif mode == PorterDuffMode.SCREEN:
73
+ out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
74
+ out_image = src_image + dst_image - src_image * dst_image
75
+ elif mode == PorterDuffMode.SRC:
76
+ out_alpha = src_alpha
77
+ out_image = src_image
78
+ elif mode == PorterDuffMode.SRC_ATOP:
79
+ out_alpha = dst_alpha
80
+ out_image = dst_alpha * src_image + (1 - src_alpha) * dst_image
81
+ elif mode == PorterDuffMode.SRC_IN:
82
+ out_alpha = src_alpha * dst_alpha
83
+ out_image = src_image * dst_alpha
84
+ elif mode == PorterDuffMode.SRC_OUT:
85
+ out_alpha = (1 - dst_alpha) * src_alpha
86
+ out_image = (1 - dst_alpha) * src_image
87
+ elif mode == PorterDuffMode.SRC_OVER:
88
+ out_alpha = src_alpha + (1 - src_alpha) * dst_alpha
89
+ out_image = src_image + (1 - src_alpha) * dst_image
90
+ elif mode == PorterDuffMode.XOR:
91
+ out_alpha = (1 - dst_alpha) * src_alpha + (1 - src_alpha) * dst_alpha
92
+ out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image
93
+ else:
94
+ return None, None
95
+
96
+ # back to non-premultiplied alpha
97
+ out_image = torch.where(out_alpha > 1e-5, out_image / out_alpha, torch.zeros_like(out_image))
98
+ out_image = torch.clamp(out_image, 0, 1)
99
+ # convert alpha to mask
100
+ out_alpha = 1 - out_alpha
101
+ return out_image, out_alpha
102
+
103
+
104
+ class PorterDuffImageComposite:
105
+ @classmethod
106
+ def INPUT_TYPES(s):
107
+ return {
108
+ "required": {
109
+ "source": ("IMAGE",),
110
+ "source_alpha": ("MASK",),
111
+ "destination": ("IMAGE",),
112
+ "destination_alpha": ("MASK",),
113
+ "mode": ([mode.name for mode in PorterDuffMode], {"default": PorterDuffMode.DST.name}),
114
+ },
115
+ }
116
+
117
+ RETURN_TYPES = ("IMAGE", "MASK")
118
+ FUNCTION = "composite"
119
+ CATEGORY = "mask/compositing"
120
+
121
+ def composite(self, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode):
122
+ batch_size = min(len(source), len(source_alpha), len(destination), len(destination_alpha))
123
+ out_images = []
124
+ out_alphas = []
125
+
126
+ for i in range(batch_size):
127
+ src_image = source[i]
128
+ dst_image = destination[i]
129
+
130
+ assert src_image.shape[2] == dst_image.shape[2] # inputs need to have same number of channels
131
+
132
+ src_alpha = source_alpha[i].unsqueeze(2)
133
+ dst_alpha = destination_alpha[i].unsqueeze(2)
134
+
135
+ if dst_alpha.shape[:2] != dst_image.shape[:2]:
136
+ upscale_input = dst_alpha.unsqueeze(0).permute(0, 3, 1, 2)
137
+ upscale_output = comfy.utils.common_upscale(upscale_input, dst_image.shape[1], dst_image.shape[0], upscale_method='bicubic', crop='center')
138
+ dst_alpha = upscale_output.permute(0, 2, 3, 1).squeeze(0)
139
+ if src_image.shape != dst_image.shape:
140
+ upscale_input = src_image.unsqueeze(0).permute(0, 3, 1, 2)
141
+ upscale_output = comfy.utils.common_upscale(upscale_input, dst_image.shape[1], dst_image.shape[0], upscale_method='bicubic', crop='center')
142
+ src_image = upscale_output.permute(0, 2, 3, 1).squeeze(0)
143
+ if src_alpha.shape != dst_alpha.shape:
144
+ upscale_input = src_alpha.unsqueeze(0).permute(0, 3, 1, 2)
145
+ upscale_output = comfy.utils.common_upscale(upscale_input, dst_alpha.shape[1], dst_alpha.shape[0], upscale_method='bicubic', crop='center')
146
+ src_alpha = upscale_output.permute(0, 2, 3, 1).squeeze(0)
147
+
148
+ out_image, out_alpha = porter_duff_composite(src_image, src_alpha, dst_image, dst_alpha, PorterDuffMode[mode])
149
+
150
+ out_images.append(out_image)
151
+ out_alphas.append(out_alpha.squeeze(2))
152
+
153
+ result = (torch.stack(out_images), torch.stack(out_alphas))
154
+ return result
155
+
156
+
157
+ class SplitImageWithAlpha:
158
+ @classmethod
159
+ def INPUT_TYPES(s):
160
+ return {
161
+ "required": {
162
+ "image": ("IMAGE",),
163
+ }
164
+ }
165
+
166
+ CATEGORY = "mask/compositing"
167
+ RETURN_TYPES = ("IMAGE", "MASK")
168
+ FUNCTION = "split_image_with_alpha"
169
+
170
+ def split_image_with_alpha(self, image: torch.Tensor):
171
+ out_images = [i[:,:,:3] for i in image]
172
+ out_alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image]
173
+ result = (torch.stack(out_images), 1.0 - torch.stack(out_alphas))
174
+ return result
175
+
176
+
177
+ class JoinImageWithAlpha:
178
+ @classmethod
179
+ def INPUT_TYPES(s):
180
+ return {
181
+ "required": {
182
+ "image": ("IMAGE",),
183
+ "alpha": ("MASK",),
184
+ }
185
+ }
186
+
187
+ CATEGORY = "mask/compositing"
188
+ RETURN_TYPES = ("IMAGE",)
189
+ FUNCTION = "join_image_with_alpha"
190
+
191
+ def join_image_with_alpha(self, image: torch.Tensor, alpha: torch.Tensor):
192
+ batch_size = min(len(image), len(alpha))
193
+ out_images = []
194
+
195
+ alpha = 1.0 - resize_mask(alpha, image.shape[1:])
196
+ for i in range(batch_size):
197
+ out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))
198
+
199
+ result = (torch.stack(out_images),)
200
+ return result
201
+
202
+
203
+ NODE_CLASS_MAPPINGS = {
204
+ "PorterDuffImageComposite": PorterDuffImageComposite,
205
+ "SplitImageWithAlpha": SplitImageWithAlpha,
206
+ "JoinImageWithAlpha": JoinImageWithAlpha,
207
+ }
208
+
209
+
210
+ NODE_DISPLAY_NAME_MAPPINGS = {
211
+ "PorterDuffImageComposite": "Porter-Duff Image Composite",
212
+ "SplitImageWithAlpha": "Split Image with Alpha",
213
+ "JoinImageWithAlpha": "Join Image with Alpha",
214
+ }
comfy_extras/nodes_cond.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ class CLIPTextEncodeControlnet:
4
+ @classmethod
5
+ def INPUT_TYPES(s):
6
+ return {"required": {"clip": ("CLIP", ), "conditioning": ("CONDITIONING", ), "text": ("STRING", {"multiline": True, "dynamicPrompts": True})}}
7
+ RETURN_TYPES = ("CONDITIONING",)
8
+ FUNCTION = "encode"
9
+
10
+ CATEGORY = "_for_testing/conditioning"
11
+
12
+ def encode(self, clip, conditioning, text):
13
+ tokens = clip.tokenize(text)
14
+ cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
15
+ c = []
16
+ for t in conditioning:
17
+ n = [t[0], t[1].copy()]
18
+ n[1]['cross_attn_controlnet'] = cond
19
+ n[1]['pooled_output_controlnet'] = pooled
20
+ c.append(n)
21
+ return (c, )
22
+
23
+ NODE_CLASS_MAPPINGS = {
24
+ "CLIPTextEncodeControlnet": CLIPTextEncodeControlnet
25
+ }
comfy_extras/nodes_controlnet.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from comfy.cldm.control_types import UNION_CONTROLNET_TYPES
2
+ import nodes
3
+ import comfy.utils
4
+
5
+ class SetUnionControlNetType:
6
+ @classmethod
7
+ def INPUT_TYPES(s):
8
+ return {"required": {"control_net": ("CONTROL_NET", ),
9
+ "type": (["auto"] + list(UNION_CONTROLNET_TYPES.keys()),)
10
+ }}
11
+
12
+ CATEGORY = "conditioning/controlnet"
13
+ RETURN_TYPES = ("CONTROL_NET",)
14
+
15
+ FUNCTION = "set_controlnet_type"
16
+
17
+ def set_controlnet_type(self, control_net, type):
18
+ control_net = control_net.copy()
19
+ type_number = UNION_CONTROLNET_TYPES.get(type, -1)
20
+ if type_number >= 0:
21
+ control_net.set_extra_arg("control_type", [type_number])
22
+ else:
23
+ control_net.set_extra_arg("control_type", [])
24
+
25
+ return (control_net,)
26
+
27
+ class ControlNetInpaintingAliMamaApply(nodes.ControlNetApplyAdvanced):
28
+ @classmethod
29
+ def INPUT_TYPES(s):
30
+ return {"required": {"positive": ("CONDITIONING", ),
31
+ "negative": ("CONDITIONING", ),
32
+ "control_net": ("CONTROL_NET", ),
33
+ "vae": ("VAE", ),
34
+ "image": ("IMAGE", ),
35
+ "mask": ("MASK", ),
36
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
37
+ "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
38
+ "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
39
+ }}
40
+
41
+ FUNCTION = "apply_inpaint_controlnet"
42
+
43
+ CATEGORY = "conditioning/controlnet"
44
+
45
+ def apply_inpaint_controlnet(self, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent):
46
+ extra_concat = []
47
+ if control_net.concat_mask:
48
+ mask = 1.0 - mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
49
+ mask_apply = comfy.utils.common_upscale(mask, image.shape[2], image.shape[1], "bilinear", "center").round()
50
+ image = image * mask_apply.movedim(1, -1).repeat(1, 1, 1, image.shape[3])
51
+ extra_concat = [mask]
52
+
53
+ return self.apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent, vae=vae, extra_concat=extra_concat)
54
+
55
+
56
+
57
+ NODE_CLASS_MAPPINGS = {
58
+ "SetUnionControlNetType": SetUnionControlNetType,
59
+ "ControlNetInpaintingAliMamaApply": ControlNetInpaintingAliMamaApply,
60
+ }
comfy_extras/nodes_cosmos.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import nodes
2
+ import torch
3
+ import comfy.model_management
4
+ import comfy.utils
5
+
6
+
7
+ class EmptyCosmosLatentVideo:
8
+ @classmethod
9
+ def INPUT_TYPES(s):
10
+ return {"required": { "width": ("INT", {"default": 1280, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
11
+ "height": ("INT", {"default": 704, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
12
+ "length": ("INT", {"default": 121, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}),
13
+ "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
14
+ RETURN_TYPES = ("LATENT",)
15
+ FUNCTION = "generate"
16
+
17
+ CATEGORY = "latent/video"
18
+
19
+ def generate(self, width, height, length, batch_size=1):
20
+ latent = torch.zeros([batch_size, 16, ((length - 1) // 8) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
21
+ return ({"samples": latent}, )
22
+
23
+
24
+ def vae_encode_with_padding(vae, image, width, height, length, padding=0):
25
+ pixels = comfy.utils.common_upscale(image[..., :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
26
+ pixel_len = min(pixels.shape[0], length)
27
+ padded_length = min(length, (((pixel_len - 1) // 8) + 1 + padding) * 8 - 7)
28
+ padded_pixels = torch.ones((padded_length, height, width, 3)) * 0.5
29
+ padded_pixels[:pixel_len] = pixels[:pixel_len]
30
+ latent_len = ((pixel_len - 1) // 8) + 1
31
+ latent_temp = vae.encode(padded_pixels)
32
+ return latent_temp[:, :, :latent_len]
33
+
34
+
35
+ class CosmosImageToVideoLatent:
36
+ @classmethod
37
+ def INPUT_TYPES(s):
38
+ return {"required": {"vae": ("VAE", ),
39
+ "width": ("INT", {"default": 1280, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
40
+ "height": ("INT", {"default": 704, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
41
+ "length": ("INT", {"default": 121, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}),
42
+ "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
43
+ },
44
+ "optional": {"start_image": ("IMAGE", ),
45
+ "end_image": ("IMAGE", ),
46
+ }}
47
+
48
+
49
+ RETURN_TYPES = ("LATENT",)
50
+ FUNCTION = "encode"
51
+
52
+ CATEGORY = "conditioning/inpaint"
53
+
54
+ def encode(self, vae, width, height, length, batch_size, start_image=None, end_image=None):
55
+ latent = torch.zeros([1, 16, ((length - 1) // 8) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
56
+ if start_image is None and end_image is None:
57
+ out_latent = {}
58
+ out_latent["samples"] = latent
59
+ return (out_latent,)
60
+
61
+ mask = torch.ones([latent.shape[0], 1, ((length - 1) // 8) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
62
+
63
+ if start_image is not None:
64
+ latent_temp = vae_encode_with_padding(vae, start_image, width, height, length, padding=1)
65
+ latent[:, :, :latent_temp.shape[-3]] = latent_temp
66
+ mask[:, :, :latent_temp.shape[-3]] *= 0.0
67
+
68
+ if end_image is not None:
69
+ latent_temp = vae_encode_with_padding(vae, end_image, width, height, length, padding=0)
70
+ latent[:, :, -latent_temp.shape[-3]:] = latent_temp
71
+ mask[:, :, -latent_temp.shape[-3]:] *= 0.0
72
+
73
+ out_latent = {}
74
+ out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1))
75
+ out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
76
+ return (out_latent,)
77
+
78
+
79
+ NODE_CLASS_MAPPINGS = {
80
+ "EmptyCosmosLatentVideo": EmptyCosmosLatentVideo,
81
+ "CosmosImageToVideoLatent": CosmosImageToVideoLatent,
82
+ }
comfy_extras/nodes_custom_sampler.py ADDED
@@ -0,0 +1,744 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import comfy.samplers
2
+ import comfy.sample
3
+ from comfy.k_diffusion import sampling as k_diffusion_sampling
4
+ import latent_preview
5
+ import torch
6
+ import comfy.utils
7
+ import node_helpers
8
+
9
+
10
+ class BasicScheduler:
11
+ @classmethod
12
+ def INPUT_TYPES(s):
13
+ return {"required":
14
+ {"model": ("MODEL",),
15
+ "scheduler": (comfy.samplers.SCHEDULER_NAMES, ),
16
+ "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
17
+ "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
18
+ }
19
+ }
20
+ RETURN_TYPES = ("SIGMAS",)
21
+ CATEGORY = "sampling/custom_sampling/schedulers"
22
+
23
+ FUNCTION = "get_sigmas"
24
+
25
+ def get_sigmas(self, model, scheduler, steps, denoise):
26
+ total_steps = steps
27
+ if denoise < 1.0:
28
+ if denoise <= 0.0:
29
+ return (torch.FloatTensor([]),)
30
+ total_steps = int(steps/denoise)
31
+
32
+ sigmas = comfy.samplers.calculate_sigmas(model.get_model_object("model_sampling"), scheduler, total_steps).cpu()
33
+ sigmas = sigmas[-(steps + 1):]
34
+ return (sigmas, )
35
+
36
+
37
+ class KarrasScheduler:
38
+ @classmethod
39
+ def INPUT_TYPES(s):
40
+ return {"required":
41
+ {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
42
+ "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
43
+ "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
44
+ "rho": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
45
+ }
46
+ }
47
+ RETURN_TYPES = ("SIGMAS",)
48
+ CATEGORY = "sampling/custom_sampling/schedulers"
49
+
50
+ FUNCTION = "get_sigmas"
51
+
52
+ def get_sigmas(self, steps, sigma_max, sigma_min, rho):
53
+ sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho)
54
+ return (sigmas, )
55
+
56
+ class ExponentialScheduler:
57
+ @classmethod
58
+ def INPUT_TYPES(s):
59
+ return {"required":
60
+ {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
61
+ "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
62
+ "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
63
+ }
64
+ }
65
+ RETURN_TYPES = ("SIGMAS",)
66
+ CATEGORY = "sampling/custom_sampling/schedulers"
67
+
68
+ FUNCTION = "get_sigmas"
69
+
70
+ def get_sigmas(self, steps, sigma_max, sigma_min):
71
+ sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max)
72
+ return (sigmas, )
73
+
74
+ class PolyexponentialScheduler:
75
+ @classmethod
76
+ def INPUT_TYPES(s):
77
+ return {"required":
78
+ {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
79
+ "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
80
+ "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
81
+ "rho": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
82
+ }
83
+ }
84
+ RETURN_TYPES = ("SIGMAS",)
85
+ CATEGORY = "sampling/custom_sampling/schedulers"
86
+
87
+ FUNCTION = "get_sigmas"
88
+
89
+ def get_sigmas(self, steps, sigma_max, sigma_min, rho):
90
+ sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho)
91
+ return (sigmas, )
92
+
93
+ class LaplaceScheduler:
94
+ @classmethod
95
+ def INPUT_TYPES(s):
96
+ return {"required":
97
+ {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
98
+ "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
99
+ "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
100
+ "mu": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step":0.1, "round": False}),
101
+ "beta": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step":0.1, "round": False}),
102
+ }
103
+ }
104
+ RETURN_TYPES = ("SIGMAS",)
105
+ CATEGORY = "sampling/custom_sampling/schedulers"
106
+
107
+ FUNCTION = "get_sigmas"
108
+
109
+ def get_sigmas(self, steps, sigma_max, sigma_min, mu, beta):
110
+ sigmas = k_diffusion_sampling.get_sigmas_laplace(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, mu=mu, beta=beta)
111
+ return (sigmas, )
112
+
113
+
114
+ class SDTurboScheduler:
115
+ @classmethod
116
+ def INPUT_TYPES(s):
117
+ return {"required":
118
+ {"model": ("MODEL",),
119
+ "steps": ("INT", {"default": 1, "min": 1, "max": 10}),
120
+ "denoise": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
121
+ }
122
+ }
123
+ RETURN_TYPES = ("SIGMAS",)
124
+ CATEGORY = "sampling/custom_sampling/schedulers"
125
+
126
+ FUNCTION = "get_sigmas"
127
+
128
+ def get_sigmas(self, model, steps, denoise):
129
+ start_step = 10 - int(10 * denoise)
130
+ timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps]
131
+ sigmas = model.get_model_object("model_sampling").sigma(timesteps)
132
+ sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
133
+ return (sigmas, )
134
+
135
+ class BetaSamplingScheduler:
136
+ @classmethod
137
+ def INPUT_TYPES(s):
138
+ return {"required":
139
+ {"model": ("MODEL",),
140
+ "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
141
+ "alpha": ("FLOAT", {"default": 0.6, "min": 0.0, "max": 50.0, "step":0.01, "round": False}),
142
+ "beta": ("FLOAT", {"default": 0.6, "min": 0.0, "max": 50.0, "step":0.01, "round": False}),
143
+ }
144
+ }
145
+ RETURN_TYPES = ("SIGMAS",)
146
+ CATEGORY = "sampling/custom_sampling/schedulers"
147
+
148
+ FUNCTION = "get_sigmas"
149
+
150
+ def get_sigmas(self, model, steps, alpha, beta):
151
+ sigmas = comfy.samplers.beta_scheduler(model.get_model_object("model_sampling"), steps, alpha=alpha, beta=beta)
152
+ return (sigmas, )
153
+
154
+ class VPScheduler:
155
+ @classmethod
156
+ def INPUT_TYPES(s):
157
+ return {"required":
158
+ {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
159
+ "beta_d": ("FLOAT", {"default": 19.9, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), #TODO: fix default values
160
+ "beta_min": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
161
+ "eps_s": ("FLOAT", {"default": 0.001, "min": 0.0, "max": 1.0, "step":0.0001, "round": False}),
162
+ }
163
+ }
164
+ RETURN_TYPES = ("SIGMAS",)
165
+ CATEGORY = "sampling/custom_sampling/schedulers"
166
+
167
+ FUNCTION = "get_sigmas"
168
+
169
+ def get_sigmas(self, steps, beta_d, beta_min, eps_s):
170
+ sigmas = k_diffusion_sampling.get_sigmas_vp(n=steps, beta_d=beta_d, beta_min=beta_min, eps_s=eps_s)
171
+ return (sigmas, )
172
+
173
+ class SplitSigmas:
174
+ @classmethod
175
+ def INPUT_TYPES(s):
176
+ return {"required":
177
+ {"sigmas": ("SIGMAS", ),
178
+ "step": ("INT", {"default": 0, "min": 0, "max": 10000}),
179
+ }
180
+ }
181
+ RETURN_TYPES = ("SIGMAS","SIGMAS")
182
+ RETURN_NAMES = ("high_sigmas", "low_sigmas")
183
+ CATEGORY = "sampling/custom_sampling/sigmas"
184
+
185
+ FUNCTION = "get_sigmas"
186
+
187
+ def get_sigmas(self, sigmas, step):
188
+ sigmas1 = sigmas[:step + 1]
189
+ sigmas2 = sigmas[step:]
190
+ return (sigmas1, sigmas2)
191
+
192
+ class SplitSigmasDenoise:
193
+ @classmethod
194
+ def INPUT_TYPES(s):
195
+ return {"required":
196
+ {"sigmas": ("SIGMAS", ),
197
+ "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
198
+ }
199
+ }
200
+ RETURN_TYPES = ("SIGMAS","SIGMAS")
201
+ RETURN_NAMES = ("high_sigmas", "low_sigmas")
202
+ CATEGORY = "sampling/custom_sampling/sigmas"
203
+
204
+ FUNCTION = "get_sigmas"
205
+
206
+ def get_sigmas(self, sigmas, denoise):
207
+ steps = max(sigmas.shape[-1] - 1, 0)
208
+ total_steps = round(steps * denoise)
209
+ sigmas1 = sigmas[:-(total_steps)]
210
+ sigmas2 = sigmas[-(total_steps + 1):]
211
+ return (sigmas1, sigmas2)
212
+
213
+ class FlipSigmas:
214
+ @classmethod
215
+ def INPUT_TYPES(s):
216
+ return {"required":
217
+ {"sigmas": ("SIGMAS", ),
218
+ }
219
+ }
220
+ RETURN_TYPES = ("SIGMAS",)
221
+ CATEGORY = "sampling/custom_sampling/sigmas"
222
+
223
+ FUNCTION = "get_sigmas"
224
+
225
+ def get_sigmas(self, sigmas):
226
+ if len(sigmas) == 0:
227
+ return (sigmas,)
228
+
229
+ sigmas = sigmas.flip(0)
230
+ if sigmas[0] == 0:
231
+ sigmas[0] = 0.0001
232
+ return (sigmas,)
233
+
234
+ class SetFirstSigma:
235
+ @classmethod
236
+ def INPUT_TYPES(s):
237
+ return {"required":
238
+ {"sigmas": ("SIGMAS", ),
239
+ "sigma": ("FLOAT", {"default": 136.0, "min": 0.0, "max": 20000.0, "step": 0.001, "round": False}),
240
+ }
241
+ }
242
+ RETURN_TYPES = ("SIGMAS",)
243
+ CATEGORY = "sampling/custom_sampling/sigmas"
244
+
245
+ FUNCTION = "set_first_sigma"
246
+
247
+ def set_first_sigma(self, sigmas, sigma):
248
+ sigmas = sigmas.clone()
249
+ sigmas[0] = sigma
250
+ return (sigmas, )
251
+
252
+ class KSamplerSelect:
253
+ @classmethod
254
+ def INPUT_TYPES(s):
255
+ return {"required":
256
+ {"sampler_name": (comfy.samplers.SAMPLER_NAMES, ),
257
+ }
258
+ }
259
+ RETURN_TYPES = ("SAMPLER",)
260
+ CATEGORY = "sampling/custom_sampling/samplers"
261
+
262
+ FUNCTION = "get_sampler"
263
+
264
+ def get_sampler(self, sampler_name):
265
+ sampler = comfy.samplers.sampler_object(sampler_name)
266
+ return (sampler, )
267
+
268
+ class SamplerDPMPP_3M_SDE:
269
+ @classmethod
270
+ def INPUT_TYPES(s):
271
+ return {"required":
272
+ {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
273
+ "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
274
+ "noise_device": (['gpu', 'cpu'], ),
275
+ }
276
+ }
277
+ RETURN_TYPES = ("SAMPLER",)
278
+ CATEGORY = "sampling/custom_sampling/samplers"
279
+
280
+ FUNCTION = "get_sampler"
281
+
282
+ def get_sampler(self, eta, s_noise, noise_device):
283
+ if noise_device == 'cpu':
284
+ sampler_name = "dpmpp_3m_sde"
285
+ else:
286
+ sampler_name = "dpmpp_3m_sde_gpu"
287
+ sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise})
288
+ return (sampler, )
289
+
290
+ class SamplerDPMPP_2M_SDE:
291
+ @classmethod
292
+ def INPUT_TYPES(s):
293
+ return {"required":
294
+ {"solver_type": (['midpoint', 'heun'], ),
295
+ "eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
296
+ "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
297
+ "noise_device": (['gpu', 'cpu'], ),
298
+ }
299
+ }
300
+ RETURN_TYPES = ("SAMPLER",)
301
+ CATEGORY = "sampling/custom_sampling/samplers"
302
+
303
+ FUNCTION = "get_sampler"
304
+
305
+ def get_sampler(self, solver_type, eta, s_noise, noise_device):
306
+ if noise_device == 'cpu':
307
+ sampler_name = "dpmpp_2m_sde"
308
+ else:
309
+ sampler_name = "dpmpp_2m_sde_gpu"
310
+ sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type})
311
+ return (sampler, )
312
+
313
+
314
+ class SamplerDPMPP_SDE:
315
+ @classmethod
316
+ def INPUT_TYPES(s):
317
+ return {"required":
318
+ {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
319
+ "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
320
+ "r": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
321
+ "noise_device": (['gpu', 'cpu'], ),
322
+ }
323
+ }
324
+ RETURN_TYPES = ("SAMPLER",)
325
+ CATEGORY = "sampling/custom_sampling/samplers"
326
+
327
+ FUNCTION = "get_sampler"
328
+
329
+ def get_sampler(self, eta, s_noise, r, noise_device):
330
+ if noise_device == 'cpu':
331
+ sampler_name = "dpmpp_sde"
332
+ else:
333
+ sampler_name = "dpmpp_sde_gpu"
334
+ sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})
335
+ return (sampler, )
336
+
337
+ class SamplerDPMPP_2S_Ancestral:
338
+ @classmethod
339
+ def INPUT_TYPES(s):
340
+ return {"required":
341
+ {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
342
+ "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
343
+ }
344
+ }
345
+ RETURN_TYPES = ("SAMPLER",)
346
+ CATEGORY = "sampling/custom_sampling/samplers"
347
+
348
+ FUNCTION = "get_sampler"
349
+
350
+ def get_sampler(self, eta, s_noise):
351
+ sampler = comfy.samplers.ksampler("dpmpp_2s_ancestral", {"eta": eta, "s_noise": s_noise})
352
+ return (sampler, )
353
+
354
+ class SamplerEulerAncestral:
355
+ @classmethod
356
+ def INPUT_TYPES(s):
357
+ return {"required":
358
+ {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
359
+ "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
360
+ }
361
+ }
362
+ RETURN_TYPES = ("SAMPLER",)
363
+ CATEGORY = "sampling/custom_sampling/samplers"
364
+
365
+ FUNCTION = "get_sampler"
366
+
367
+ def get_sampler(self, eta, s_noise):
368
+ sampler = comfy.samplers.ksampler("euler_ancestral", {"eta": eta, "s_noise": s_noise})
369
+ return (sampler, )
370
+
371
+ class SamplerEulerAncestralCFGPP:
372
+ @classmethod
373
+ def INPUT_TYPES(s):
374
+ return {
375
+ "required": {
376
+ "eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step":0.01, "round": False}),
377
+ "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step":0.01, "round": False}),
378
+ }}
379
+ RETURN_TYPES = ("SAMPLER",)
380
+ CATEGORY = "sampling/custom_sampling/samplers"
381
+
382
+ FUNCTION = "get_sampler"
383
+
384
+ def get_sampler(self, eta, s_noise):
385
+ sampler = comfy.samplers.ksampler(
386
+ "euler_ancestral_cfg_pp",
387
+ {"eta": eta, "s_noise": s_noise})
388
+ return (sampler, )
389
+
390
+ class SamplerLMS:
391
+ @classmethod
392
+ def INPUT_TYPES(s):
393
+ return {"required":
394
+ {"order": ("INT", {"default": 4, "min": 1, "max": 100}),
395
+ }
396
+ }
397
+ RETURN_TYPES = ("SAMPLER",)
398
+ CATEGORY = "sampling/custom_sampling/samplers"
399
+
400
+ FUNCTION = "get_sampler"
401
+
402
+ def get_sampler(self, order):
403
+ sampler = comfy.samplers.ksampler("lms", {"order": order})
404
+ return (sampler, )
405
+
406
+ class SamplerDPMAdaptative:
407
+ @classmethod
408
+ def INPUT_TYPES(s):
409
+ return {"required":
410
+ {"order": ("INT", {"default": 3, "min": 2, "max": 3}),
411
+ "rtol": ("FLOAT", {"default": 0.05, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
412
+ "atol": ("FLOAT", {"default": 0.0078, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
413
+ "h_init": ("FLOAT", {"default": 0.05, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
414
+ "pcoeff": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
415
+ "icoeff": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
416
+ "dcoeff": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
417
+ "accept_safety": ("FLOAT", {"default": 0.81, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
418
+ "eta": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
419
+ "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
420
+ }
421
+ }
422
+ RETURN_TYPES = ("SAMPLER",)
423
+ CATEGORY = "sampling/custom_sampling/samplers"
424
+
425
+ FUNCTION = "get_sampler"
426
+
427
+ def get_sampler(self, order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise):
428
+ sampler = comfy.samplers.ksampler("dpm_adaptive", {"order": order, "rtol": rtol, "atol": atol, "h_init": h_init, "pcoeff": pcoeff,
429
+ "icoeff": icoeff, "dcoeff": dcoeff, "accept_safety": accept_safety, "eta": eta,
430
+ "s_noise":s_noise })
431
+ return (sampler, )
432
+
433
+ class Noise_EmptyNoise:
434
+ def __init__(self):
435
+ self.seed = 0
436
+
437
+ def generate_noise(self, input_latent):
438
+ latent_image = input_latent["samples"]
439
+ return torch.zeros(latent_image.shape, dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
440
+
441
+
442
+ class Noise_RandomNoise:
443
+ def __init__(self, seed):
444
+ self.seed = seed
445
+
446
+ def generate_noise(self, input_latent):
447
+ latent_image = input_latent["samples"]
448
+ batch_inds = input_latent["batch_index"] if "batch_index" in input_latent else None
449
+ return comfy.sample.prepare_noise(latent_image, self.seed, batch_inds)
450
+
451
+ class SamplerCustom:
452
+ @classmethod
453
+ def INPUT_TYPES(s):
454
+ return {"required":
455
+ {"model": ("MODEL",),
456
+ "add_noise": ("BOOLEAN", {"default": True}),
457
+ "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
458
+ "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
459
+ "positive": ("CONDITIONING", ),
460
+ "negative": ("CONDITIONING", ),
461
+ "sampler": ("SAMPLER", ),
462
+ "sigmas": ("SIGMAS", ),
463
+ "latent_image": ("LATENT", ),
464
+ }
465
+ }
466
+
467
+ RETURN_TYPES = ("LATENT","LATENT")
468
+ RETURN_NAMES = ("output", "denoised_output")
469
+
470
+ FUNCTION = "sample"
471
+
472
+ CATEGORY = "sampling/custom_sampling"
473
+
474
+ def sample(self, model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, latent_image):
475
+ latent = latent_image
476
+ latent_image = latent["samples"]
477
+ latent = latent.copy()
478
+ latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image)
479
+ latent["samples"] = latent_image
480
+
481
+ if not add_noise:
482
+ noise = Noise_EmptyNoise().generate_noise(latent)
483
+ else:
484
+ noise = Noise_RandomNoise(noise_seed).generate_noise(latent)
485
+
486
+ noise_mask = None
487
+ if "noise_mask" in latent:
488
+ noise_mask = latent["noise_mask"]
489
+
490
+ x0_output = {}
491
+ callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output)
492
+
493
+ disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
494
+ samples = comfy.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed)
495
+
496
+ out = latent.copy()
497
+ out["samples"] = samples
498
+ if "x0" in x0_output:
499
+ out_denoised = latent.copy()
500
+ out_denoised["samples"] = model.model.process_latent_out(x0_output["x0"].cpu())
501
+ else:
502
+ out_denoised = out
503
+ return (out, out_denoised)
504
+
505
+ class Guider_Basic(comfy.samplers.CFGGuider):
506
+ def set_conds(self, positive):
507
+ self.inner_set_conds({"positive": positive})
508
+
509
+ class BasicGuider:
510
+ @classmethod
511
+ def INPUT_TYPES(s):
512
+ return {"required":
513
+ {"model": ("MODEL",),
514
+ "conditioning": ("CONDITIONING", ),
515
+ }
516
+ }
517
+
518
+ RETURN_TYPES = ("GUIDER",)
519
+
520
+ FUNCTION = "get_guider"
521
+ CATEGORY = "sampling/custom_sampling/guiders"
522
+
523
+ def get_guider(self, model, conditioning):
524
+ guider = Guider_Basic(model)
525
+ guider.set_conds(conditioning)
526
+ return (guider,)
527
+
528
+ class CFGGuider:
529
+ @classmethod
530
+ def INPUT_TYPES(s):
531
+ return {"required":
532
+ {"model": ("MODEL",),
533
+ "positive": ("CONDITIONING", ),
534
+ "negative": ("CONDITIONING", ),
535
+ "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
536
+ }
537
+ }
538
+
539
+ RETURN_TYPES = ("GUIDER",)
540
+
541
+ FUNCTION = "get_guider"
542
+ CATEGORY = "sampling/custom_sampling/guiders"
543
+
544
+ def get_guider(self, model, positive, negative, cfg):
545
+ guider = comfy.samplers.CFGGuider(model)
546
+ guider.set_conds(positive, negative)
547
+ guider.set_cfg(cfg)
548
+ return (guider,)
549
+
550
+ class Guider_DualCFG(comfy.samplers.CFGGuider):
551
+ def set_cfg(self, cfg1, cfg2):
552
+ self.cfg1 = cfg1
553
+ self.cfg2 = cfg2
554
+
555
+ def set_conds(self, positive, middle, negative):
556
+ middle = node_helpers.conditioning_set_values(middle, {"prompt_type": "negative"})
557
+ self.inner_set_conds({"positive": positive, "middle": middle, "negative": negative})
558
+
559
+ def predict_noise(self, x, timestep, model_options={}, seed=None):
560
+ negative_cond = self.conds.get("negative", None)
561
+ middle_cond = self.conds.get("middle", None)
562
+
563
+ out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, self.conds.get("positive", None)], x, timestep, model_options)
564
+ return comfy.samplers.cfg_function(self.inner_model, out[1], out[0], self.cfg2, x, timestep, model_options=model_options, cond=middle_cond, uncond=negative_cond) + (out[2] - out[1]) * self.cfg1
565
+
566
+ class DualCFGGuider:
567
+ @classmethod
568
+ def INPUT_TYPES(s):
569
+ return {"required":
570
+ {"model": ("MODEL",),
571
+ "cond1": ("CONDITIONING", ),
572
+ "cond2": ("CONDITIONING", ),
573
+ "negative": ("CONDITIONING", ),
574
+ "cfg_conds": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
575
+ "cfg_cond2_negative": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
576
+ }
577
+ }
578
+
579
+ RETURN_TYPES = ("GUIDER",)
580
+
581
+ FUNCTION = "get_guider"
582
+ CATEGORY = "sampling/custom_sampling/guiders"
583
+
584
+ def get_guider(self, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative):
585
+ guider = Guider_DualCFG(model)
586
+ guider.set_conds(cond1, cond2, negative)
587
+ guider.set_cfg(cfg_conds, cfg_cond2_negative)
588
+ return (guider,)
589
+
590
+ class DisableNoise:
591
+ @classmethod
592
+ def INPUT_TYPES(s):
593
+ return {"required":{
594
+ }
595
+ }
596
+
597
+ RETURN_TYPES = ("NOISE",)
598
+ FUNCTION = "get_noise"
599
+ CATEGORY = "sampling/custom_sampling/noise"
600
+
601
+ def get_noise(self):
602
+ return (Noise_EmptyNoise(),)
603
+
604
+
605
+ class RandomNoise(DisableNoise):
606
+ @classmethod
607
+ def INPUT_TYPES(s):
608
+ return {"required":{
609
+ "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
610
+ }
611
+ }
612
+
613
+ def get_noise(self, noise_seed):
614
+ return (Noise_RandomNoise(noise_seed),)
615
+
616
+
617
+ class SamplerCustomAdvanced:
618
+ @classmethod
619
+ def INPUT_TYPES(s):
620
+ return {"required":
621
+ {"noise": ("NOISE", ),
622
+ "guider": ("GUIDER", ),
623
+ "sampler": ("SAMPLER", ),
624
+ "sigmas": ("SIGMAS", ),
625
+ "latent_image": ("LATENT", ),
626
+ }
627
+ }
628
+
629
+ RETURN_TYPES = ("LATENT","LATENT")
630
+ RETURN_NAMES = ("output", "denoised_output")
631
+
632
+ FUNCTION = "sample"
633
+
634
+ CATEGORY = "sampling/custom_sampling"
635
+
636
+ def sample(self, noise, guider, sampler, sigmas, latent_image):
637
+ latent = latent_image
638
+ latent_image = latent["samples"]
639
+ latent = latent.copy()
640
+ latent_image = comfy.sample.fix_empty_latent_channels(guider.model_patcher, latent_image)
641
+ latent["samples"] = latent_image
642
+
643
+ noise_mask = None
644
+ if "noise_mask" in latent:
645
+ noise_mask = latent["noise_mask"]
646
+
647
+ x0_output = {}
648
+ callback = latent_preview.prepare_callback(guider.model_patcher, sigmas.shape[-1] - 1, x0_output)
649
+
650
+ disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
651
+ samples = guider.sample(noise.generate_noise(latent), latent_image, sampler, sigmas, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise.seed)
652
+ samples = samples.to(comfy.model_management.intermediate_device())
653
+
654
+ out = latent.copy()
655
+ out["samples"] = samples
656
+ if "x0" in x0_output:
657
+ out_denoised = latent.copy()
658
+ out_denoised["samples"] = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu())
659
+ else:
660
+ out_denoised = out
661
+ return (out, out_denoised)
662
+
663
+ class AddNoise:
664
+ @classmethod
665
+ def INPUT_TYPES(s):
666
+ return {"required":
667
+ {"model": ("MODEL",),
668
+ "noise": ("NOISE", ),
669
+ "sigmas": ("SIGMAS", ),
670
+ "latent_image": ("LATENT", ),
671
+ }
672
+ }
673
+
674
+ RETURN_TYPES = ("LATENT",)
675
+
676
+ FUNCTION = "add_noise"
677
+
678
+ CATEGORY = "_for_testing/custom_sampling/noise"
679
+
680
+ def add_noise(self, model, noise, sigmas, latent_image):
681
+ if len(sigmas) == 0:
682
+ return latent_image
683
+
684
+ latent = latent_image
685
+ latent_image = latent["samples"]
686
+
687
+ noisy = noise.generate_noise(latent)
688
+
689
+ model_sampling = model.get_model_object("model_sampling")
690
+ process_latent_out = model.get_model_object("process_latent_out")
691
+ process_latent_in = model.get_model_object("process_latent_in")
692
+
693
+ if len(sigmas) > 1:
694
+ scale = torch.abs(sigmas[0] - sigmas[-1])
695
+ else:
696
+ scale = sigmas[0]
697
+
698
+ if torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
699
+ latent_image = process_latent_in(latent_image)
700
+ noisy = model_sampling.noise_scaling(scale, noisy, latent_image)
701
+ noisy = process_latent_out(noisy)
702
+ noisy = torch.nan_to_num(noisy, nan=0.0, posinf=0.0, neginf=0.0)
703
+
704
+ out = latent.copy()
705
+ out["samples"] = noisy
706
+ return (out,)
707
+
708
+
709
+ NODE_CLASS_MAPPINGS = {
710
+ "SamplerCustom": SamplerCustom,
711
+ "BasicScheduler": BasicScheduler,
712
+ "KarrasScheduler": KarrasScheduler,
713
+ "ExponentialScheduler": ExponentialScheduler,
714
+ "PolyexponentialScheduler": PolyexponentialScheduler,
715
+ "LaplaceScheduler": LaplaceScheduler,
716
+ "VPScheduler": VPScheduler,
717
+ "BetaSamplingScheduler": BetaSamplingScheduler,
718
+ "SDTurboScheduler": SDTurboScheduler,
719
+ "KSamplerSelect": KSamplerSelect,
720
+ "SamplerEulerAncestral": SamplerEulerAncestral,
721
+ "SamplerEulerAncestralCFGPP": SamplerEulerAncestralCFGPP,
722
+ "SamplerLMS": SamplerLMS,
723
+ "SamplerDPMPP_3M_SDE": SamplerDPMPP_3M_SDE,
724
+ "SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE,
725
+ "SamplerDPMPP_SDE": SamplerDPMPP_SDE,
726
+ "SamplerDPMPP_2S_Ancestral": SamplerDPMPP_2S_Ancestral,
727
+ "SamplerDPMAdaptative": SamplerDPMAdaptative,
728
+ "SplitSigmas": SplitSigmas,
729
+ "SplitSigmasDenoise": SplitSigmasDenoise,
730
+ "FlipSigmas": FlipSigmas,
731
+ "SetFirstSigma": SetFirstSigma,
732
+
733
+ "CFGGuider": CFGGuider,
734
+ "DualCFGGuider": DualCFGGuider,
735
+ "BasicGuider": BasicGuider,
736
+ "RandomNoise": RandomNoise,
737
+ "DisableNoise": DisableNoise,
738
+ "AddNoise": AddNoise,
739
+ "SamplerCustomAdvanced": SamplerCustomAdvanced,
740
+ }
741
+
742
+ NODE_DISPLAY_NAME_MAPPINGS = {
743
+ "SamplerEulerAncestralCFGPP": "SamplerEulerAncestralCFG++",
744
+ }
comfy_extras/nodes_differential_diffusion.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # code adapted from https://github.com/exx8/differential-diffusion
2
+
3
+ import torch
4
+
5
+ class DifferentialDiffusion():
6
+ @classmethod
7
+ def INPUT_TYPES(s):
8
+ return {"required": {"model": ("MODEL", ),
9
+ }}
10
+ RETURN_TYPES = ("MODEL",)
11
+ FUNCTION = "apply"
12
+ CATEGORY = "_for_testing"
13
+ INIT = False
14
+
15
+ def apply(self, model):
16
+ model = model.clone()
17
+ model.set_model_denoise_mask_function(self.forward)
18
+ return (model,)
19
+
20
+ def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict):
21
+ model = extra_options["model"]
22
+ step_sigmas = extra_options["sigmas"]
23
+ sigma_to = model.inner_model.model_sampling.sigma_min
24
+ if step_sigmas[-1] > sigma_to:
25
+ sigma_to = step_sigmas[-1]
26
+ sigma_from = step_sigmas[0]
27
+
28
+ ts_from = model.inner_model.model_sampling.timestep(sigma_from)
29
+ ts_to = model.inner_model.model_sampling.timestep(sigma_to)
30
+ current_ts = model.inner_model.model_sampling.timestep(sigma[0])
31
+
32
+ threshold = (current_ts - ts_to) / (ts_from - ts_to)
33
+
34
+ return (denoise_mask >= threshold).to(denoise_mask.dtype)
35
+
36
+
37
+ NODE_CLASS_MAPPINGS = {
38
+ "DifferentialDiffusion": DifferentialDiffusion,
39
+ }
40
+ NODE_DISPLAY_NAME_MAPPINGS = {
41
+ "DifferentialDiffusion": "Differential Diffusion",
42
+ }
comfy_extras/nodes_flux.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import node_helpers
2
+
3
+ class CLIPTextEncodeFlux:
4
+ @classmethod
5
+ def INPUT_TYPES(s):
6
+ return {"required": {
7
+ "clip": ("CLIP", ),
8
+ "clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}),
9
+ "t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
10
+ "guidance": ("FLOAT", {"default": 3.5, "min": 0.0, "max": 100.0, "step": 0.1}),
11
+ }}
12
+ RETURN_TYPES = ("CONDITIONING",)
13
+ FUNCTION = "encode"
14
+
15
+ CATEGORY = "advanced/conditioning/flux"
16
+
17
+ def encode(self, clip, clip_l, t5xxl, guidance):
18
+ tokens = clip.tokenize(clip_l)
19
+ tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
20
+
21
+ return (clip.encode_from_tokens_scheduled(tokens, add_dict={"guidance": guidance}), )
22
+
23
+ class FluxGuidance:
24
+ @classmethod
25
+ def INPUT_TYPES(s):
26
+ return {"required": {
27
+ "conditioning": ("CONDITIONING", ),
28
+ "guidance": ("FLOAT", {"default": 3.5, "min": 0.0, "max": 100.0, "step": 0.1}),
29
+ }}
30
+
31
+ RETURN_TYPES = ("CONDITIONING",)
32
+ FUNCTION = "append"
33
+
34
+ CATEGORY = "advanced/conditioning/flux"
35
+
36
+ def append(self, conditioning, guidance):
37
+ c = node_helpers.conditioning_set_values(conditioning, {"guidance": guidance})
38
+ return (c, )
39
+
40
+
41
+ class FluxDisableGuidance:
42
+ @classmethod
43
+ def INPUT_TYPES(s):
44
+ return {"required": {
45
+ "conditioning": ("CONDITIONING", ),
46
+ }}
47
+
48
+ RETURN_TYPES = ("CONDITIONING",)
49
+ FUNCTION = "append"
50
+
51
+ CATEGORY = "advanced/conditioning/flux"
52
+ DESCRIPTION = "This node completely disables the guidance embed on Flux and Flux like models"
53
+
54
+ def append(self, conditioning):
55
+ c = node_helpers.conditioning_set_values(conditioning, {"guidance": None})
56
+ return (c, )
57
+
58
+
59
+ NODE_CLASS_MAPPINGS = {
60
+ "CLIPTextEncodeFlux": CLIPTextEncodeFlux,
61
+ "FluxGuidance": FluxGuidance,
62
+ "FluxDisableGuidance": FluxDisableGuidance,
63
+ }
comfy_extras/nodes_freelunch.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #code originally taken from: https://github.com/ChenyangSi/FreeU (under MIT License)
2
+
3
+ import torch
4
+ import logging
5
+
6
+ def Fourier_filter(x, threshold, scale):
7
+ # FFT
8
+ x_freq = torch.fft.fftn(x.float(), dim=(-2, -1))
9
+ x_freq = torch.fft.fftshift(x_freq, dim=(-2, -1))
10
+
11
+ B, C, H, W = x_freq.shape
12
+ mask = torch.ones((B, C, H, W), device=x.device)
13
+
14
+ crow, ccol = H // 2, W //2
15
+ mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale
16
+ x_freq = x_freq * mask
17
+
18
+ # IFFT
19
+ x_freq = torch.fft.ifftshift(x_freq, dim=(-2, -1))
20
+ x_filtered = torch.fft.ifftn(x_freq, dim=(-2, -1)).real
21
+
22
+ return x_filtered.to(x.dtype)
23
+
24
+
25
+ class FreeU:
26
+ @classmethod
27
+ def INPUT_TYPES(s):
28
+ return {"required": { "model": ("MODEL",),
29
+ "b1": ("FLOAT", {"default": 1.1, "min": 0.0, "max": 10.0, "step": 0.01}),
30
+ "b2": ("FLOAT", {"default": 1.2, "min": 0.0, "max": 10.0, "step": 0.01}),
31
+ "s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.01}),
32
+ "s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.01}),
33
+ }}
34
+ RETURN_TYPES = ("MODEL",)
35
+ FUNCTION = "patch"
36
+
37
+ CATEGORY = "model_patches/unet"
38
+
39
+ def patch(self, model, b1, b2, s1, s2):
40
+ model_channels = model.model.model_config.unet_config["model_channels"]
41
+ scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
42
+ on_cpu_devices = {}
43
+
44
+ def output_block_patch(h, hsp, transformer_options):
45
+ scale = scale_dict.get(int(h.shape[1]), None)
46
+ if scale is not None:
47
+ h[:,:h.shape[1] // 2] = h[:,:h.shape[1] // 2] * scale[0]
48
+ if hsp.device not in on_cpu_devices:
49
+ try:
50
+ hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
51
+ except:
52
+ logging.warning("Device {} does not support the torch.fft functions used in the FreeU node, switching to CPU.".format(hsp.device))
53
+ on_cpu_devices[hsp.device] = True
54
+ hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
55
+ else:
56
+ hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
57
+
58
+ return h, hsp
59
+
60
+ m = model.clone()
61
+ m.set_model_output_block_patch(output_block_patch)
62
+ return (m, )
63
+
64
+ class FreeU_V2:
65
+ @classmethod
66
+ def INPUT_TYPES(s):
67
+ return {"required": { "model": ("MODEL",),
68
+ "b1": ("FLOAT", {"default": 1.3, "min": 0.0, "max": 10.0, "step": 0.01}),
69
+ "b2": ("FLOAT", {"default": 1.4, "min": 0.0, "max": 10.0, "step": 0.01}),
70
+ "s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.01}),
71
+ "s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.01}),
72
+ }}
73
+ RETURN_TYPES = ("MODEL",)
74
+ FUNCTION = "patch"
75
+
76
+ CATEGORY = "model_patches/unet"
77
+
78
+ def patch(self, model, b1, b2, s1, s2):
79
+ model_channels = model.model.model_config.unet_config["model_channels"]
80
+ scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
81
+ on_cpu_devices = {}
82
+
83
+ def output_block_patch(h, hsp, transformer_options):
84
+ scale = scale_dict.get(int(h.shape[1]), None)
85
+ if scale is not None:
86
+ hidden_mean = h.mean(1).unsqueeze(1)
87
+ B = hidden_mean.shape[0]
88
+ hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True)
89
+ hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
90
+ hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)
91
+
92
+ h[:,:h.shape[1] // 2] = h[:,:h.shape[1] // 2] * ((scale[0] - 1 ) * hidden_mean + 1)
93
+
94
+ if hsp.device not in on_cpu_devices:
95
+ try:
96
+ hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
97
+ except:
98
+ logging.warning("Device {} does not support the torch.fft functions used in the FreeU node, switching to CPU.".format(hsp.device))
99
+ on_cpu_devices[hsp.device] = True
100
+ hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
101
+ else:
102
+ hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
103
+
104
+ return h, hsp
105
+
106
+ m = model.clone()
107
+ m.set_model_output_block_patch(output_block_patch)
108
+ return (m, )
109
+
110
+ NODE_CLASS_MAPPINGS = {
111
+ "FreeU": FreeU,
112
+ "FreeU_V2": FreeU_V2,
113
+ }
comfy_extras/nodes_gits.py ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from https://github.com/zju-pi/diff-sampler/tree/main/gits-main
2
+ import numpy as np
3
+ import torch
4
+
5
+ def loglinear_interp(t_steps, num_steps):
6
+ """
7
+ Performs log-linear interpolation of a given array of decreasing numbers.
8
+ """
9
+ xs = np.linspace(0, 1, len(t_steps))
10
+ ys = np.log(t_steps[::-1])
11
+
12
+ new_xs = np.linspace(0, 1, num_steps)
13
+ new_ys = np.interp(new_xs, xs, ys)
14
+
15
+ interped_ys = np.exp(new_ys)[::-1].copy()
16
+ return interped_ys
17
+
18
+ NOISE_LEVELS = {
19
+ 0.80: [
20
+ [14.61464119, 7.49001646, 0.02916753],
21
+ [14.61464119, 11.54541874, 6.77309084, 0.02916753],
22
+ [14.61464119, 11.54541874, 7.49001646, 3.07277966, 0.02916753],
23
+ [14.61464119, 11.54541874, 7.49001646, 5.85520077, 2.05039096, 0.02916753],
24
+ [14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.85520077, 2.05039096, 0.02916753],
25
+ [14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.85520077, 3.07277966, 1.56271636, 0.02916753],
26
+ [14.61464119, 12.96784878, 11.54541874, 8.75849152, 7.49001646, 5.85520077, 3.07277966, 1.56271636, 0.02916753],
27
+ [14.61464119, 13.76078796, 12.2308979, 10.90732002, 8.75849152, 7.49001646, 5.85520077, 3.07277966, 1.56271636, 0.02916753],
28
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 10.90732002, 8.75849152, 7.49001646, 5.85520077, 3.07277966, 1.56271636, 0.02916753],
29
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 5.85520077, 3.07277966, 1.56271636, 0.02916753],
30
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.07277966, 1.56271636, 0.02916753],
31
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.07277966, 1.56271636, 0.02916753],
32
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.07277966, 1.56271636, 0.02916753],
33
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.07277966, 1.56271636, 0.02916753],
34
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.1956799, 1.98035145, 0.86115354, 0.02916753],
35
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.75859547, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.1956799, 1.98035145, 0.86115354, 0.02916753],
36
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.75859547, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.65472794, 3.07277966, 1.84880662, 0.83188516, 0.02916753],
37
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.75859547, 9.24142551, 8.75849152, 8.30717278, 7.88507891, 7.49001646, 6.77309084, 5.85520077, 4.65472794, 3.07277966, 1.84880662, 0.83188516, 0.02916753],
38
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.75859547, 9.24142551, 8.75849152, 8.30717278, 7.88507891, 7.49001646, 6.77309084, 5.85520077, 4.86714602, 3.75677586, 2.84484982, 1.78698075, 0.803307, 0.02916753],
39
+ ],
40
+ 0.85: [
41
+ [14.61464119, 7.49001646, 0.02916753],
42
+ [14.61464119, 7.49001646, 1.84880662, 0.02916753],
43
+ [14.61464119, 11.54541874, 6.77309084, 1.56271636, 0.02916753],
44
+ [14.61464119, 11.54541874, 7.11996698, 3.07277966, 1.24153244, 0.02916753],
45
+ [14.61464119, 11.54541874, 7.49001646, 5.09240818, 2.84484982, 0.95350921, 0.02916753],
46
+ [14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.09240818, 2.84484982, 0.95350921, 0.02916753],
47
+ [14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.58536053, 3.1956799, 1.84880662, 0.803307, 0.02916753],
48
+ [14.61464119, 12.96784878, 11.54541874, 8.75849152, 7.49001646, 5.58536053, 3.1956799, 1.84880662, 0.803307, 0.02916753],
49
+ [14.61464119, 12.96784878, 11.54541874, 8.75849152, 7.49001646, 6.14220476, 4.65472794, 3.07277966, 1.84880662, 0.803307, 0.02916753],
50
+ [14.61464119, 13.76078796, 12.2308979, 10.90732002, 8.75849152, 7.49001646, 6.14220476, 4.65472794, 3.07277966, 1.84880662, 0.803307, 0.02916753],
51
+ [14.61464119, 13.76078796, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.65472794, 3.07277966, 1.84880662, 0.803307, 0.02916753],
52
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.65472794, 3.07277966, 1.84880662, 0.803307, 0.02916753],
53
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.65472794, 3.07277966, 1.84880662, 0.803307, 0.02916753],
54
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.60512662, 2.6383388, 1.56271636, 0.72133851, 0.02916753],
55
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.65472794, 3.46139455, 2.45070267, 1.56271636, 0.72133851, 0.02916753],
56
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.65472794, 3.46139455, 2.45070267, 1.56271636, 0.72133851, 0.02916753],
57
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.65472794, 3.46139455, 2.45070267, 1.56271636, 0.72133851, 0.02916753],
58
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.75859547, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.65472794, 3.46139455, 2.45070267, 1.56271636, 0.72133851, 0.02916753],
59
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.75859547, 9.24142551, 8.75849152, 8.30717278, 7.88507891, 7.49001646, 6.77309084, 5.85520077, 4.65472794, 3.46139455, 2.45070267, 1.56271636, 0.72133851, 0.02916753],
60
+ ],
61
+ 0.90: [
62
+ [14.61464119, 6.77309084, 0.02916753],
63
+ [14.61464119, 7.49001646, 1.56271636, 0.02916753],
64
+ [14.61464119, 7.49001646, 3.07277966, 0.95350921, 0.02916753],
65
+ [14.61464119, 7.49001646, 4.86714602, 2.54230714, 0.89115214, 0.02916753],
66
+ [14.61464119, 11.54541874, 7.49001646, 4.86714602, 2.54230714, 0.89115214, 0.02916753],
67
+ [14.61464119, 11.54541874, 7.49001646, 5.09240818, 3.07277966, 1.61558151, 0.69515091, 0.02916753],
68
+ [14.61464119, 12.2308979, 8.75849152, 7.11996698, 4.86714602, 3.07277966, 1.61558151, 0.69515091, 0.02916753],
69
+ [14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.85520077, 4.45427561, 2.95596409, 1.61558151, 0.69515091, 0.02916753],
70
+ [14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.19988537, 1.24153244, 0.57119018, 0.02916753],
71
+ [14.61464119, 12.96784878, 10.90732002, 8.75849152, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.19988537, 1.24153244, 0.57119018, 0.02916753],
72
+ [14.61464119, 12.96784878, 11.54541874, 9.24142551, 8.30717278, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.19988537, 1.24153244, 0.57119018, 0.02916753],
73
+ [14.61464119, 12.96784878, 11.54541874, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.75677586, 2.84484982, 1.84880662, 1.08895338, 0.52423614, 0.02916753],
74
+ [14.61464119, 13.76078796, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.75677586, 2.84484982, 1.84880662, 1.08895338, 0.52423614, 0.02916753],
75
+ [14.61464119, 13.76078796, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.44769001, 5.58536053, 4.45427561, 3.32507086, 2.45070267, 1.61558151, 0.95350921, 0.45573691, 0.02916753],
76
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.44769001, 5.58536053, 4.45427561, 3.32507086, 2.45070267, 1.61558151, 0.95350921, 0.45573691, 0.02916753],
77
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.86714602, 3.91689563, 3.07277966, 2.27973175, 1.56271636, 0.95350921, 0.45573691, 0.02916753],
78
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.86714602, 3.91689563, 3.07277966, 2.27973175, 1.56271636, 0.95350921, 0.45573691, 0.02916753],
79
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.86714602, 3.91689563, 3.07277966, 2.27973175, 1.56271636, 0.95350921, 0.45573691, 0.02916753],
80
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 5.09240818, 4.45427561, 3.60512662, 2.95596409, 2.19988537, 1.51179266, 0.89115214, 0.43325692, 0.02916753],
81
+ ],
82
+ 0.95: [
83
+ [14.61464119, 6.77309084, 0.02916753],
84
+ [14.61464119, 6.77309084, 1.56271636, 0.02916753],
85
+ [14.61464119, 7.49001646, 2.84484982, 0.89115214, 0.02916753],
86
+ [14.61464119, 7.49001646, 4.86714602, 2.36326075, 0.803307, 0.02916753],
87
+ [14.61464119, 7.49001646, 4.86714602, 2.95596409, 1.56271636, 0.64427125, 0.02916753],
88
+ [14.61464119, 11.54541874, 7.49001646, 4.86714602, 2.95596409, 1.56271636, 0.64427125, 0.02916753],
89
+ [14.61464119, 11.54541874, 7.49001646, 4.86714602, 3.07277966, 1.91321158, 1.08895338, 0.50118381, 0.02916753],
90
+ [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.07277966, 1.91321158, 1.08895338, 0.50118381, 0.02916753],
91
+ [14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.85520077, 4.45427561, 3.07277966, 1.91321158, 1.08895338, 0.50118381, 0.02916753],
92
+ [14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.19988537, 1.41535246, 0.803307, 0.38853383, 0.02916753],
93
+ [14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.85520077, 4.65472794, 3.46139455, 2.6383388, 1.84880662, 1.24153244, 0.72133851, 0.34370604, 0.02916753],
94
+ [14.61464119, 12.96784878, 10.90732002, 8.75849152, 7.49001646, 5.85520077, 4.65472794, 3.46139455, 2.6383388, 1.84880662, 1.24153244, 0.72133851, 0.34370604, 0.02916753],
95
+ [14.61464119, 12.96784878, 10.90732002, 8.75849152, 7.49001646, 6.14220476, 4.86714602, 3.75677586, 2.95596409, 2.19988537, 1.56271636, 1.05362725, 0.64427125, 0.32104823, 0.02916753],
96
+ [14.61464119, 12.96784878, 10.90732002, 8.75849152, 7.49001646, 6.44769001, 5.58536053, 4.65472794, 3.60512662, 2.95596409, 2.19988537, 1.56271636, 1.05362725, 0.64427125, 0.32104823, 0.02916753],
97
+ [14.61464119, 12.96784878, 11.54541874, 9.24142551, 8.30717278, 7.49001646, 6.44769001, 5.58536053, 4.65472794, 3.60512662, 2.95596409, 2.19988537, 1.56271636, 1.05362725, 0.64427125, 0.32104823, 0.02916753],
98
+ [14.61464119, 12.96784878, 11.54541874, 9.24142551, 8.30717278, 7.49001646, 6.44769001, 5.58536053, 4.65472794, 3.75677586, 3.07277966, 2.45070267, 1.78698075, 1.24153244, 0.83188516, 0.50118381, 0.22545385, 0.02916753],
99
+ [14.61464119, 12.96784878, 11.54541874, 9.24142551, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 5.09240818, 4.45427561, 3.60512662, 2.95596409, 2.36326075, 1.72759056, 1.24153244, 0.83188516, 0.50118381, 0.22545385, 0.02916753],
100
+ [14.61464119, 13.76078796, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 5.09240818, 4.45427561, 3.60512662, 2.95596409, 2.36326075, 1.72759056, 1.24153244, 0.83188516, 0.50118381, 0.22545385, 0.02916753],
101
+ [14.61464119, 13.76078796, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 5.09240818, 4.45427561, 3.75677586, 3.07277966, 2.45070267, 1.91321158, 1.46270394, 1.05362725, 0.72133851, 0.43325692, 0.19894916, 0.02916753],
102
+ ],
103
+ 1.00: [
104
+ [14.61464119, 1.56271636, 0.02916753],
105
+ [14.61464119, 6.77309084, 0.95350921, 0.02916753],
106
+ [14.61464119, 6.77309084, 2.36326075, 0.803307, 0.02916753],
107
+ [14.61464119, 7.11996698, 3.07277966, 1.56271636, 0.59516323, 0.02916753],
108
+ [14.61464119, 7.49001646, 4.86714602, 2.84484982, 1.41535246, 0.57119018, 0.02916753],
109
+ [14.61464119, 7.49001646, 4.86714602, 2.84484982, 1.61558151, 0.86115354, 0.38853383, 0.02916753],
110
+ [14.61464119, 11.54541874, 7.49001646, 4.86714602, 2.84484982, 1.61558151, 0.86115354, 0.38853383, 0.02916753],
111
+ [14.61464119, 11.54541874, 7.49001646, 4.86714602, 3.07277966, 1.98035145, 1.24153244, 0.72133851, 0.34370604, 0.02916753],
112
+ [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.07277966, 1.98035145, 1.24153244, 0.72133851, 0.34370604, 0.02916753],
113
+ [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.27973175, 1.51179266, 0.95350921, 0.54755926, 0.25053367, 0.02916753],
114
+ [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.36326075, 1.61558151, 1.08895338, 0.72133851, 0.41087446, 0.17026083, 0.02916753],
115
+ [14.61464119, 11.54541874, 8.75849152, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.36326075, 1.61558151, 1.08895338, 0.72133851, 0.41087446, 0.17026083, 0.02916753],
116
+ [14.61464119, 11.54541874, 8.75849152, 7.49001646, 5.85520077, 4.65472794, 3.60512662, 2.84484982, 2.12350607, 1.56271636, 1.08895338, 0.72133851, 0.41087446, 0.17026083, 0.02916753],
117
+ [14.61464119, 11.54541874, 8.75849152, 7.49001646, 5.85520077, 4.65472794, 3.60512662, 2.84484982, 2.19988537, 1.61558151, 1.162866, 0.803307, 0.50118381, 0.27464288, 0.09824532, 0.02916753],
118
+ [14.61464119, 11.54541874, 8.75849152, 7.49001646, 5.85520077, 4.65472794, 3.75677586, 3.07277966, 2.45070267, 1.84880662, 1.36964464, 1.01931262, 0.72133851, 0.45573691, 0.25053367, 0.09824532, 0.02916753],
119
+ [14.61464119, 11.54541874, 8.75849152, 7.49001646, 6.14220476, 5.09240818, 4.26497746, 3.46139455, 2.84484982, 2.19988537, 1.67050016, 1.24153244, 0.92192322, 0.64427125, 0.43325692, 0.25053367, 0.09824532, 0.02916753],
120
+ [14.61464119, 11.54541874, 8.75849152, 7.49001646, 6.14220476, 5.09240818, 4.26497746, 3.60512662, 2.95596409, 2.45070267, 1.91321158, 1.51179266, 1.12534678, 0.83188516, 0.59516323, 0.38853383, 0.22545385, 0.09824532, 0.02916753],
121
+ [14.61464119, 12.2308979, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 5.09240818, 4.26497746, 3.60512662, 2.95596409, 2.45070267, 1.91321158, 1.51179266, 1.12534678, 0.83188516, 0.59516323, 0.38853383, 0.22545385, 0.09824532, 0.02916753],
122
+ [14.61464119, 12.2308979, 9.24142551, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 5.09240818, 4.26497746, 3.60512662, 2.95596409, 2.45070267, 1.91321158, 1.51179266, 1.12534678, 0.83188516, 0.59516323, 0.38853383, 0.22545385, 0.09824532, 0.02916753],
123
+ ],
124
+ 1.05: [
125
+ [14.61464119, 0.95350921, 0.02916753],
126
+ [14.61464119, 6.77309084, 0.89115214, 0.02916753],
127
+ [14.61464119, 6.77309084, 2.05039096, 0.72133851, 0.02916753],
128
+ [14.61464119, 6.77309084, 2.84484982, 1.28281462, 0.52423614, 0.02916753],
129
+ [14.61464119, 6.77309084, 3.07277966, 1.61558151, 0.803307, 0.34370604, 0.02916753],
130
+ [14.61464119, 7.49001646, 4.86714602, 2.84484982, 1.56271636, 0.803307, 0.34370604, 0.02916753],
131
+ [14.61464119, 7.49001646, 4.86714602, 2.84484982, 1.61558151, 0.95350921, 0.52423614, 0.22545385, 0.02916753],
132
+ [14.61464119, 7.49001646, 4.86714602, 3.07277966, 1.98035145, 1.24153244, 0.74807048, 0.41087446, 0.17026083, 0.02916753],
133
+ [14.61464119, 7.49001646, 4.86714602, 3.1956799, 2.27973175, 1.51179266, 0.95350921, 0.59516323, 0.34370604, 0.13792117, 0.02916753],
134
+ [14.61464119, 7.49001646, 5.09240818, 3.46139455, 2.45070267, 1.61558151, 1.08895338, 0.72133851, 0.45573691, 0.25053367, 0.09824532, 0.02916753],
135
+ [14.61464119, 11.54541874, 7.49001646, 5.09240818, 3.46139455, 2.45070267, 1.61558151, 1.08895338, 0.72133851, 0.45573691, 0.25053367, 0.09824532, 0.02916753],
136
+ [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.36326075, 1.61558151, 1.08895338, 0.72133851, 0.45573691, 0.25053367, 0.09824532, 0.02916753],
137
+ [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.45070267, 1.72759056, 1.24153244, 0.86115354, 0.59516323, 0.38853383, 0.22545385, 0.09824532, 0.02916753],
138
+ [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.65472794, 3.60512662, 2.84484982, 2.19988537, 1.61558151, 1.162866, 0.83188516, 0.59516323, 0.38853383, 0.22545385, 0.09824532, 0.02916753],
139
+ [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.65472794, 3.60512662, 2.84484982, 2.19988537, 1.67050016, 1.28281462, 0.95350921, 0.72133851, 0.52423614, 0.34370604, 0.19894916, 0.09824532, 0.02916753],
140
+ [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.65472794, 3.60512662, 2.95596409, 2.36326075, 1.84880662, 1.41535246, 1.08895338, 0.83188516, 0.61951244, 0.45573691, 0.32104823, 0.19894916, 0.09824532, 0.02916753],
141
+ [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.65472794, 3.60512662, 2.95596409, 2.45070267, 1.91321158, 1.51179266, 1.20157266, 0.95350921, 0.74807048, 0.57119018, 0.43325692, 0.29807833, 0.19894916, 0.09824532, 0.02916753],
142
+ [14.61464119, 11.54541874, 8.30717278, 7.11996698, 5.85520077, 4.65472794, 3.60512662, 2.95596409, 2.45070267, 1.91321158, 1.51179266, 1.20157266, 0.95350921, 0.74807048, 0.57119018, 0.43325692, 0.29807833, 0.19894916, 0.09824532, 0.02916753],
143
+ [14.61464119, 11.54541874, 8.30717278, 7.11996698, 5.85520077, 4.65472794, 3.60512662, 2.95596409, 2.45070267, 1.98035145, 1.61558151, 1.32549286, 1.08895338, 0.86115354, 0.69515091, 0.54755926, 0.41087446, 0.29807833, 0.19894916, 0.09824532, 0.02916753],
144
+ ],
145
+ 1.10: [
146
+ [14.61464119, 0.89115214, 0.02916753],
147
+ [14.61464119, 2.36326075, 0.72133851, 0.02916753],
148
+ [14.61464119, 5.85520077, 1.61558151, 0.57119018, 0.02916753],
149
+ [14.61464119, 6.77309084, 2.45070267, 1.08895338, 0.45573691, 0.02916753],
150
+ [14.61464119, 6.77309084, 2.95596409, 1.56271636, 0.803307, 0.34370604, 0.02916753],
151
+ [14.61464119, 6.77309084, 3.07277966, 1.61558151, 0.89115214, 0.4783645, 0.19894916, 0.02916753],
152
+ [14.61464119, 6.77309084, 3.07277966, 1.84880662, 1.08895338, 0.64427125, 0.34370604, 0.13792117, 0.02916753],
153
+ [14.61464119, 7.49001646, 4.86714602, 2.84484982, 1.61558151, 0.95350921, 0.54755926, 0.27464288, 0.09824532, 0.02916753],
154
+ [14.61464119, 7.49001646, 4.86714602, 2.95596409, 1.91321158, 1.24153244, 0.803307, 0.4783645, 0.25053367, 0.09824532, 0.02916753],
155
+ [14.61464119, 7.49001646, 4.86714602, 3.07277966, 2.05039096, 1.41535246, 0.95350921, 0.64427125, 0.41087446, 0.22545385, 0.09824532, 0.02916753],
156
+ [14.61464119, 7.49001646, 4.86714602, 3.1956799, 2.27973175, 1.61558151, 1.12534678, 0.803307, 0.54755926, 0.36617002, 0.22545385, 0.09824532, 0.02916753],
157
+ [14.61464119, 7.49001646, 4.86714602, 3.32507086, 2.45070267, 1.72759056, 1.24153244, 0.89115214, 0.64427125, 0.45573691, 0.32104823, 0.19894916, 0.09824532, 0.02916753],
158
+ [14.61464119, 7.49001646, 5.09240818, 3.60512662, 2.84484982, 2.05039096, 1.51179266, 1.08895338, 0.803307, 0.59516323, 0.43325692, 0.29807833, 0.19894916, 0.09824532, 0.02916753],
159
+ [14.61464119, 7.49001646, 5.09240818, 3.60512662, 2.84484982, 2.12350607, 1.61558151, 1.24153244, 0.95350921, 0.72133851, 0.54755926, 0.41087446, 0.29807833, 0.19894916, 0.09824532, 0.02916753],
160
+ [14.61464119, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.45070267, 1.84880662, 1.41535246, 1.08895338, 0.83188516, 0.64427125, 0.50118381, 0.36617002, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
161
+ [14.61464119, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.45070267, 1.91321158, 1.51179266, 1.20157266, 0.95350921, 0.74807048, 0.59516323, 0.45573691, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
162
+ [14.61464119, 7.49001646, 5.85520077, 4.45427561, 3.46139455, 2.84484982, 2.19988537, 1.72759056, 1.36964464, 1.08895338, 0.86115354, 0.69515091, 0.54755926, 0.43325692, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
163
+ [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.46139455, 2.84484982, 2.19988537, 1.72759056, 1.36964464, 1.08895338, 0.86115354, 0.69515091, 0.54755926, 0.43325692, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
164
+ [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.46139455, 2.84484982, 2.19988537, 1.72759056, 1.36964464, 1.08895338, 0.89115214, 0.72133851, 0.59516323, 0.4783645, 0.38853383, 0.29807833, 0.22545385, 0.17026083, 0.09824532, 0.02916753],
165
+ ],
166
+ 1.15: [
167
+ [14.61464119, 0.83188516, 0.02916753],
168
+ [14.61464119, 1.84880662, 0.59516323, 0.02916753],
169
+ [14.61464119, 5.85520077, 1.56271636, 0.52423614, 0.02916753],
170
+ [14.61464119, 5.85520077, 1.91321158, 0.83188516, 0.34370604, 0.02916753],
171
+ [14.61464119, 5.85520077, 2.45070267, 1.24153244, 0.59516323, 0.25053367, 0.02916753],
172
+ [14.61464119, 5.85520077, 2.84484982, 1.51179266, 0.803307, 0.41087446, 0.17026083, 0.02916753],
173
+ [14.61464119, 5.85520077, 2.84484982, 1.56271636, 0.89115214, 0.50118381, 0.25053367, 0.09824532, 0.02916753],
174
+ [14.61464119, 6.77309084, 3.07277966, 1.84880662, 1.12534678, 0.72133851, 0.43325692, 0.22545385, 0.09824532, 0.02916753],
175
+ [14.61464119, 6.77309084, 3.07277966, 1.91321158, 1.24153244, 0.803307, 0.52423614, 0.34370604, 0.19894916, 0.09824532, 0.02916753],
176
+ [14.61464119, 7.49001646, 4.86714602, 2.95596409, 1.91321158, 1.24153244, 0.803307, 0.52423614, 0.34370604, 0.19894916, 0.09824532, 0.02916753],
177
+ [14.61464119, 7.49001646, 4.86714602, 3.07277966, 2.05039096, 1.36964464, 0.95350921, 0.69515091, 0.4783645, 0.32104823, 0.19894916, 0.09824532, 0.02916753],
178
+ [14.61464119, 7.49001646, 4.86714602, 3.07277966, 2.12350607, 1.51179266, 1.08895338, 0.803307, 0.59516323, 0.43325692, 0.29807833, 0.19894916, 0.09824532, 0.02916753],
179
+ [14.61464119, 7.49001646, 4.86714602, 3.07277966, 2.12350607, 1.51179266, 1.08895338, 0.803307, 0.59516323, 0.45573691, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
180
+ [14.61464119, 7.49001646, 4.86714602, 3.07277966, 2.19988537, 1.61558151, 1.24153244, 0.95350921, 0.74807048, 0.59516323, 0.45573691, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
181
+ [14.61464119, 7.49001646, 4.86714602, 3.1956799, 2.45070267, 1.78698075, 1.32549286, 1.01931262, 0.803307, 0.64427125, 0.50118381, 0.38853383, 0.29807833, 0.22545385, 0.17026083, 0.09824532, 0.02916753],
182
+ [14.61464119, 7.49001646, 4.86714602, 3.1956799, 2.45070267, 1.78698075, 1.32549286, 1.01931262, 0.803307, 0.64427125, 0.52423614, 0.41087446, 0.32104823, 0.25053367, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
183
+ [14.61464119, 7.49001646, 4.86714602, 3.1956799, 2.45070267, 1.84880662, 1.41535246, 1.12534678, 0.89115214, 0.72133851, 0.59516323, 0.4783645, 0.38853383, 0.32104823, 0.25053367, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
184
+ [14.61464119, 7.49001646, 4.86714602, 3.1956799, 2.45070267, 1.84880662, 1.41535246, 1.12534678, 0.89115214, 0.72133851, 0.59516323, 0.50118381, 0.41087446, 0.34370604, 0.27464288, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
185
+ [14.61464119, 7.49001646, 4.86714602, 3.1956799, 2.45070267, 1.84880662, 1.41535246, 1.12534678, 0.89115214, 0.72133851, 0.59516323, 0.50118381, 0.41087446, 0.34370604, 0.29807833, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
186
+ ],
187
+ 1.20: [
188
+ [14.61464119, 0.803307, 0.02916753],
189
+ [14.61464119, 1.56271636, 0.52423614, 0.02916753],
190
+ [14.61464119, 2.36326075, 0.92192322, 0.36617002, 0.02916753],
191
+ [14.61464119, 2.84484982, 1.24153244, 0.59516323, 0.25053367, 0.02916753],
192
+ [14.61464119, 5.85520077, 2.05039096, 0.95350921, 0.45573691, 0.17026083, 0.02916753],
193
+ [14.61464119, 5.85520077, 2.45070267, 1.24153244, 0.64427125, 0.29807833, 0.09824532, 0.02916753],
194
+ [14.61464119, 5.85520077, 2.45070267, 1.36964464, 0.803307, 0.45573691, 0.25053367, 0.09824532, 0.02916753],
195
+ [14.61464119, 5.85520077, 2.84484982, 1.61558151, 0.95350921, 0.59516323, 0.36617002, 0.19894916, 0.09824532, 0.02916753],
196
+ [14.61464119, 5.85520077, 2.84484982, 1.67050016, 1.08895338, 0.74807048, 0.50118381, 0.32104823, 0.19894916, 0.09824532, 0.02916753],
197
+ [14.61464119, 5.85520077, 2.95596409, 1.84880662, 1.24153244, 0.83188516, 0.59516323, 0.41087446, 0.27464288, 0.17026083, 0.09824532, 0.02916753],
198
+ [14.61464119, 5.85520077, 3.07277966, 1.98035145, 1.36964464, 0.95350921, 0.69515091, 0.50118381, 0.36617002, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
199
+ [14.61464119, 6.77309084, 3.46139455, 2.36326075, 1.56271636, 1.08895338, 0.803307, 0.59516323, 0.45573691, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
200
+ [14.61464119, 6.77309084, 3.46139455, 2.45070267, 1.61558151, 1.162866, 0.86115354, 0.64427125, 0.50118381, 0.38853383, 0.29807833, 0.22545385, 0.17026083, 0.09824532, 0.02916753],
201
+ [14.61464119, 7.49001646, 4.65472794, 3.07277966, 2.12350607, 1.51179266, 1.08895338, 0.83188516, 0.64427125, 0.50118381, 0.38853383, 0.29807833, 0.22545385, 0.17026083, 0.09824532, 0.02916753],
202
+ [14.61464119, 7.49001646, 4.65472794, 3.07277966, 2.12350607, 1.51179266, 1.08895338, 0.83188516, 0.64427125, 0.50118381, 0.41087446, 0.32104823, 0.25053367, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
203
+ [14.61464119, 7.49001646, 4.65472794, 3.07277966, 2.12350607, 1.51179266, 1.08895338, 0.83188516, 0.64427125, 0.50118381, 0.41087446, 0.34370604, 0.27464288, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
204
+ [14.61464119, 7.49001646, 4.65472794, 3.07277966, 2.19988537, 1.61558151, 1.20157266, 0.92192322, 0.72133851, 0.57119018, 0.45573691, 0.36617002, 0.29807833, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
205
+ [14.61464119, 7.49001646, 4.65472794, 3.07277966, 2.19988537, 1.61558151, 1.24153244, 0.95350921, 0.74807048, 0.59516323, 0.4783645, 0.38853383, 0.32104823, 0.27464288, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
206
+ [14.61464119, 7.49001646, 4.65472794, 3.07277966, 2.19988537, 1.61558151, 1.24153244, 0.95350921, 0.74807048, 0.59516323, 0.50118381, 0.41087446, 0.34370604, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
207
+ ],
208
+ 1.25: [
209
+ [14.61464119, 0.72133851, 0.02916753],
210
+ [14.61464119, 1.56271636, 0.50118381, 0.02916753],
211
+ [14.61464119, 2.05039096, 0.803307, 0.32104823, 0.02916753],
212
+ [14.61464119, 2.36326075, 0.95350921, 0.43325692, 0.17026083, 0.02916753],
213
+ [14.61464119, 2.84484982, 1.24153244, 0.59516323, 0.27464288, 0.09824532, 0.02916753],
214
+ [14.61464119, 3.07277966, 1.51179266, 0.803307, 0.43325692, 0.22545385, 0.09824532, 0.02916753],
215
+ [14.61464119, 5.85520077, 2.36326075, 1.24153244, 0.72133851, 0.41087446, 0.22545385, 0.09824532, 0.02916753],
216
+ [14.61464119, 5.85520077, 2.45070267, 1.36964464, 0.83188516, 0.52423614, 0.34370604, 0.19894916, 0.09824532, 0.02916753],
217
+ [14.61464119, 5.85520077, 2.84484982, 1.61558151, 0.98595673, 0.64427125, 0.43325692, 0.27464288, 0.17026083, 0.09824532, 0.02916753],
218
+ [14.61464119, 5.85520077, 2.84484982, 1.67050016, 1.08895338, 0.74807048, 0.52423614, 0.36617002, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
219
+ [14.61464119, 5.85520077, 2.84484982, 1.72759056, 1.162866, 0.803307, 0.59516323, 0.45573691, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
220
+ [14.61464119, 5.85520077, 2.95596409, 1.84880662, 1.24153244, 0.86115354, 0.64427125, 0.4783645, 0.36617002, 0.27464288, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
221
+ [14.61464119, 5.85520077, 2.95596409, 1.84880662, 1.28281462, 0.92192322, 0.69515091, 0.52423614, 0.41087446, 0.32104823, 0.25053367, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
222
+ [14.61464119, 5.85520077, 2.95596409, 1.91321158, 1.32549286, 0.95350921, 0.72133851, 0.54755926, 0.43325692, 0.34370604, 0.27464288, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
223
+ [14.61464119, 5.85520077, 2.95596409, 1.91321158, 1.32549286, 0.95350921, 0.72133851, 0.57119018, 0.45573691, 0.36617002, 0.29807833, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
224
+ [14.61464119, 5.85520077, 2.95596409, 1.91321158, 1.32549286, 0.95350921, 0.74807048, 0.59516323, 0.4783645, 0.38853383, 0.32104823, 0.27464288, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
225
+ [14.61464119, 5.85520077, 3.07277966, 2.05039096, 1.41535246, 1.05362725, 0.803307, 0.61951244, 0.50118381, 0.41087446, 0.34370604, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
226
+ [14.61464119, 5.85520077, 3.07277966, 2.05039096, 1.41535246, 1.05362725, 0.803307, 0.64427125, 0.52423614, 0.43325692, 0.36617002, 0.32104823, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
227
+ [14.61464119, 5.85520077, 3.07277966, 2.05039096, 1.46270394, 1.08895338, 0.83188516, 0.66947293, 0.54755926, 0.45573691, 0.38853383, 0.34370604, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
228
+ ],
229
+ 1.30: [
230
+ [14.61464119, 0.72133851, 0.02916753],
231
+ [14.61464119, 1.24153244, 0.43325692, 0.02916753],
232
+ [14.61464119, 1.56271636, 0.59516323, 0.22545385, 0.02916753],
233
+ [14.61464119, 1.84880662, 0.803307, 0.36617002, 0.13792117, 0.02916753],
234
+ [14.61464119, 2.36326075, 1.01931262, 0.52423614, 0.25053367, 0.09824532, 0.02916753],
235
+ [14.61464119, 2.84484982, 1.36964464, 0.74807048, 0.41087446, 0.22545385, 0.09824532, 0.02916753],
236
+ [14.61464119, 3.07277966, 1.56271636, 0.89115214, 0.54755926, 0.34370604, 0.19894916, 0.09824532, 0.02916753],
237
+ [14.61464119, 3.07277966, 1.61558151, 0.95350921, 0.61951244, 0.41087446, 0.27464288, 0.17026083, 0.09824532, 0.02916753],
238
+ [14.61464119, 5.85520077, 2.45070267, 1.36964464, 0.83188516, 0.54755926, 0.36617002, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
239
+ [14.61464119, 5.85520077, 2.45070267, 1.41535246, 0.92192322, 0.64427125, 0.45573691, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
240
+ [14.61464119, 5.85520077, 2.6383388, 1.56271636, 1.01931262, 0.72133851, 0.50118381, 0.36617002, 0.27464288, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
241
+ [14.61464119, 5.85520077, 2.84484982, 1.61558151, 1.05362725, 0.74807048, 0.54755926, 0.41087446, 0.32104823, 0.25053367, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
242
+ [14.61464119, 5.85520077, 2.84484982, 1.61558151, 1.08895338, 0.77538133, 0.57119018, 0.43325692, 0.34370604, 0.27464288, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
243
+ [14.61464119, 5.85520077, 2.84484982, 1.61558151, 1.08895338, 0.803307, 0.59516323, 0.45573691, 0.36617002, 0.29807833, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
244
+ [14.61464119, 5.85520077, 2.84484982, 1.61558151, 1.08895338, 0.803307, 0.59516323, 0.4783645, 0.38853383, 0.32104823, 0.27464288, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
245
+ [14.61464119, 5.85520077, 2.84484982, 1.72759056, 1.162866, 0.83188516, 0.64427125, 0.50118381, 0.41087446, 0.34370604, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
246
+ [14.61464119, 5.85520077, 2.84484982, 1.72759056, 1.162866, 0.83188516, 0.64427125, 0.52423614, 0.43325692, 0.36617002, 0.32104823, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
247
+ [14.61464119, 5.85520077, 2.84484982, 1.78698075, 1.24153244, 0.92192322, 0.72133851, 0.57119018, 0.45573691, 0.38853383, 0.34370604, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
248
+ [14.61464119, 5.85520077, 2.84484982, 1.78698075, 1.24153244, 0.92192322, 0.72133851, 0.57119018, 0.4783645, 0.41087446, 0.36617002, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
249
+ ],
250
+ 1.35: [
251
+ [14.61464119, 0.69515091, 0.02916753],
252
+ [14.61464119, 0.95350921, 0.34370604, 0.02916753],
253
+ [14.61464119, 1.56271636, 0.57119018, 0.19894916, 0.02916753],
254
+ [14.61464119, 1.61558151, 0.69515091, 0.29807833, 0.09824532, 0.02916753],
255
+ [14.61464119, 1.84880662, 0.83188516, 0.43325692, 0.22545385, 0.09824532, 0.02916753],
256
+ [14.61464119, 2.45070267, 1.162866, 0.64427125, 0.36617002, 0.19894916, 0.09824532, 0.02916753],
257
+ [14.61464119, 2.84484982, 1.36964464, 0.803307, 0.50118381, 0.32104823, 0.19894916, 0.09824532, 0.02916753],
258
+ [14.61464119, 2.84484982, 1.41535246, 0.83188516, 0.54755926, 0.36617002, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
259
+ [14.61464119, 2.84484982, 1.56271636, 0.95350921, 0.64427125, 0.45573691, 0.32104823, 0.22545385, 0.17026083, 0.09824532, 0.02916753],
260
+ [14.61464119, 2.84484982, 1.56271636, 0.95350921, 0.64427125, 0.45573691, 0.34370604, 0.25053367, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
261
+ [14.61464119, 3.07277966, 1.61558151, 1.01931262, 0.72133851, 0.52423614, 0.38853383, 0.29807833, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
262
+ [14.61464119, 3.07277966, 1.61558151, 1.01931262, 0.72133851, 0.52423614, 0.41087446, 0.32104823, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
263
+ [14.61464119, 3.07277966, 1.61558151, 1.05362725, 0.74807048, 0.54755926, 0.43325692, 0.34370604, 0.27464288, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
264
+ [14.61464119, 3.07277966, 1.72759056, 1.12534678, 0.803307, 0.59516323, 0.45573691, 0.36617002, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
265
+ [14.61464119, 3.07277966, 1.72759056, 1.12534678, 0.803307, 0.59516323, 0.4783645, 0.38853383, 0.32104823, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
266
+ [14.61464119, 5.85520077, 2.45070267, 1.51179266, 1.01931262, 0.74807048, 0.57119018, 0.45573691, 0.36617002, 0.32104823, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
267
+ [14.61464119, 5.85520077, 2.6383388, 1.61558151, 1.08895338, 0.803307, 0.61951244, 0.50118381, 0.41087446, 0.34370604, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
268
+ [14.61464119, 5.85520077, 2.6383388, 1.61558151, 1.08895338, 0.803307, 0.64427125, 0.52423614, 0.43325692, 0.36617002, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
269
+ [14.61464119, 5.85520077, 2.6383388, 1.61558151, 1.08895338, 0.803307, 0.64427125, 0.52423614, 0.45573691, 0.38853383, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
270
+ ],
271
+ 1.40: [
272
+ [14.61464119, 0.59516323, 0.02916753],
273
+ [14.61464119, 0.95350921, 0.34370604, 0.02916753],
274
+ [14.61464119, 1.08895338, 0.43325692, 0.13792117, 0.02916753],
275
+ [14.61464119, 1.56271636, 0.64427125, 0.27464288, 0.09824532, 0.02916753],
276
+ [14.61464119, 1.61558151, 0.803307, 0.43325692, 0.22545385, 0.09824532, 0.02916753],
277
+ [14.61464119, 2.05039096, 0.95350921, 0.54755926, 0.34370604, 0.19894916, 0.09824532, 0.02916753],
278
+ [14.61464119, 2.45070267, 1.24153244, 0.72133851, 0.43325692, 0.27464288, 0.17026083, 0.09824532, 0.02916753],
279
+ [14.61464119, 2.45070267, 1.24153244, 0.74807048, 0.50118381, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
280
+ [14.61464119, 2.45070267, 1.28281462, 0.803307, 0.52423614, 0.36617002, 0.27464288, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
281
+ [14.61464119, 2.45070267, 1.28281462, 0.803307, 0.54755926, 0.38853383, 0.29807833, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
282
+ [14.61464119, 2.84484982, 1.41535246, 0.86115354, 0.59516323, 0.43325692, 0.32104823, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
283
+ [14.61464119, 2.84484982, 1.51179266, 0.95350921, 0.64427125, 0.45573691, 0.34370604, 0.27464288, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
284
+ [14.61464119, 2.84484982, 1.51179266, 0.95350921, 0.64427125, 0.4783645, 0.36617002, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
285
+ [14.61464119, 2.84484982, 1.56271636, 0.98595673, 0.69515091, 0.52423614, 0.41087446, 0.34370604, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
286
+ [14.61464119, 2.84484982, 1.56271636, 1.01931262, 0.72133851, 0.54755926, 0.43325692, 0.36617002, 0.32104823, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
287
+ [14.61464119, 2.84484982, 1.61558151, 1.05362725, 0.74807048, 0.57119018, 0.45573691, 0.38853383, 0.34370604, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
288
+ [14.61464119, 2.84484982, 1.61558151, 1.08895338, 0.803307, 0.61951244, 0.50118381, 0.41087446, 0.36617002, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
289
+ [14.61464119, 2.84484982, 1.61558151, 1.08895338, 0.803307, 0.61951244, 0.50118381, 0.43325692, 0.38853383, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
290
+ [14.61464119, 2.84484982, 1.61558151, 1.08895338, 0.803307, 0.64427125, 0.52423614, 0.45573691, 0.41087446, 0.36617002, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
291
+ ],
292
+ 1.45: [
293
+ [14.61464119, 0.59516323, 0.02916753],
294
+ [14.61464119, 0.803307, 0.25053367, 0.02916753],
295
+ [14.61464119, 0.95350921, 0.34370604, 0.09824532, 0.02916753],
296
+ [14.61464119, 1.24153244, 0.54755926, 0.25053367, 0.09824532, 0.02916753],
297
+ [14.61464119, 1.56271636, 0.72133851, 0.36617002, 0.19894916, 0.09824532, 0.02916753],
298
+ [14.61464119, 1.61558151, 0.803307, 0.45573691, 0.27464288, 0.17026083, 0.09824532, 0.02916753],
299
+ [14.61464119, 1.91321158, 0.95350921, 0.57119018, 0.36617002, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
300
+ [14.61464119, 2.19988537, 1.08895338, 0.64427125, 0.41087446, 0.27464288, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
301
+ [14.61464119, 2.45070267, 1.24153244, 0.74807048, 0.50118381, 0.34370604, 0.25053367, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
302
+ [14.61464119, 2.45070267, 1.24153244, 0.74807048, 0.50118381, 0.36617002, 0.27464288, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
303
+ [14.61464119, 2.45070267, 1.28281462, 0.803307, 0.54755926, 0.41087446, 0.32104823, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
304
+ [14.61464119, 2.45070267, 1.28281462, 0.803307, 0.57119018, 0.43325692, 0.34370604, 0.27464288, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
305
+ [14.61464119, 2.45070267, 1.28281462, 0.83188516, 0.59516323, 0.45573691, 0.36617002, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
306
+ [14.61464119, 2.45070267, 1.28281462, 0.83188516, 0.59516323, 0.45573691, 0.36617002, 0.32104823, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
307
+ [14.61464119, 2.84484982, 1.51179266, 0.95350921, 0.69515091, 0.52423614, 0.41087446, 0.34370604, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
308
+ [14.61464119, 2.84484982, 1.51179266, 0.95350921, 0.69515091, 0.52423614, 0.43325692, 0.36617002, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
309
+ [14.61464119, 2.84484982, 1.56271636, 0.98595673, 0.72133851, 0.54755926, 0.45573691, 0.38853383, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
310
+ [14.61464119, 2.84484982, 1.56271636, 1.01931262, 0.74807048, 0.57119018, 0.4783645, 0.41087446, 0.36617002, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
311
+ [14.61464119, 2.84484982, 1.56271636, 1.01931262, 0.74807048, 0.59516323, 0.50118381, 0.43325692, 0.38853383, 0.36617002, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
312
+ ],
313
+ 1.50: [
314
+ [14.61464119, 0.54755926, 0.02916753],
315
+ [14.61464119, 0.803307, 0.25053367, 0.02916753],
316
+ [14.61464119, 0.86115354, 0.32104823, 0.09824532, 0.02916753],
317
+ [14.61464119, 1.24153244, 0.54755926, 0.25053367, 0.09824532, 0.02916753],
318
+ [14.61464119, 1.56271636, 0.72133851, 0.36617002, 0.19894916, 0.09824532, 0.02916753],
319
+ [14.61464119, 1.61558151, 0.803307, 0.45573691, 0.27464288, 0.17026083, 0.09824532, 0.02916753],
320
+ [14.61464119, 1.61558151, 0.83188516, 0.52423614, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
321
+ [14.61464119, 1.84880662, 0.95350921, 0.59516323, 0.38853383, 0.27464288, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
322
+ [14.61464119, 1.84880662, 0.95350921, 0.59516323, 0.41087446, 0.29807833, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
323
+ [14.61464119, 1.84880662, 0.95350921, 0.61951244, 0.43325692, 0.32104823, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
324
+ [14.61464119, 2.19988537, 1.12534678, 0.72133851, 0.50118381, 0.36617002, 0.27464288, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
325
+ [14.61464119, 2.19988537, 1.12534678, 0.72133851, 0.50118381, 0.36617002, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
326
+ [14.61464119, 2.36326075, 1.24153244, 0.803307, 0.57119018, 0.43325692, 0.34370604, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
327
+ [14.61464119, 2.36326075, 1.24153244, 0.803307, 0.57119018, 0.43325692, 0.34370604, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
328
+ [14.61464119, 2.36326075, 1.24153244, 0.803307, 0.59516323, 0.45573691, 0.36617002, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
329
+ [14.61464119, 2.36326075, 1.24153244, 0.803307, 0.59516323, 0.45573691, 0.38853383, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
330
+ [14.61464119, 2.45070267, 1.32549286, 0.86115354, 0.64427125, 0.50118381, 0.41087446, 0.36617002, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
331
+ [14.61464119, 2.45070267, 1.36964464, 0.92192322, 0.69515091, 0.54755926, 0.45573691, 0.41087446, 0.36617002, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
332
+ [14.61464119, 2.45070267, 1.41535246, 0.95350921, 0.72133851, 0.57119018, 0.4783645, 0.43325692, 0.38853383, 0.36617002, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
333
+ ],
334
+ }
335
+
336
+ class GITSScheduler:
337
+ @classmethod
338
+ def INPUT_TYPES(s):
339
+ return {"required":
340
+ {"coeff": ("FLOAT", {"default": 1.20, "min": 0.80, "max": 1.50, "step": 0.05}),
341
+ "steps": ("INT", {"default": 10, "min": 2, "max": 1000}),
342
+ "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
343
+ }
344
+ }
345
+ RETURN_TYPES = ("SIGMAS",)
346
+ CATEGORY = "sampling/custom_sampling/schedulers"
347
+
348
+ FUNCTION = "get_sigmas"
349
+
350
+ def get_sigmas(self, coeff, steps, denoise):
351
+ total_steps = steps
352
+ if denoise < 1.0:
353
+ if denoise <= 0.0:
354
+ return (torch.FloatTensor([]),)
355
+ total_steps = round(steps * denoise)
356
+
357
+ if steps <= 20:
358
+ sigmas = NOISE_LEVELS[round(coeff, 2)][steps-2][:]
359
+ else:
360
+ sigmas = NOISE_LEVELS[round(coeff, 2)][-1][:]
361
+ sigmas = loglinear_interp(sigmas, steps + 1)
362
+
363
+ sigmas = sigmas[-(total_steps + 1):]
364
+ sigmas[-1] = 0
365
+ return (torch.FloatTensor(sigmas), )
366
+
367
+ NODE_CLASS_MAPPINGS = {
368
+ "GITSScheduler": GITSScheduler,
369
+ }
comfy_extras/nodes_hooks.py ADDED
@@ -0,0 +1,745 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from typing import TYPE_CHECKING, Union
3
+ import logging
4
+ import torch
5
+ from collections.abc import Iterable
6
+
7
+ if TYPE_CHECKING:
8
+ from comfy.sd import CLIP
9
+
10
+ import comfy.hooks
11
+ import comfy.sd
12
+ import comfy.utils
13
+ import folder_paths
14
+
15
+ ###########################################
16
+ # Mask, Combine, and Hook Conditioning
17
+ #------------------------------------------
18
+ class PairConditioningSetProperties:
19
+ NodeId = 'PairConditioningSetProperties'
20
+ NodeName = 'Cond Pair Set Props'
21
+ @classmethod
22
+ def INPUT_TYPES(s):
23
+ return {
24
+ "required": {
25
+ "positive_NEW": ("CONDITIONING", ),
26
+ "negative_NEW": ("CONDITIONING", ),
27
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
28
+ "set_cond_area": (["default", "mask bounds"],),
29
+ },
30
+ "optional": {
31
+ "mask": ("MASK", ),
32
+ "hooks": ("HOOKS",),
33
+ "timesteps": ("TIMESTEPS_RANGE",),
34
+ }
35
+ }
36
+
37
+ EXPERIMENTAL = True
38
+ RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
39
+ RETURN_NAMES = ("positive", "negative")
40
+ CATEGORY = "advanced/hooks/cond pair"
41
+ FUNCTION = "set_properties"
42
+
43
+ def set_properties(self, positive_NEW, negative_NEW,
44
+ strength: float, set_cond_area: str,
45
+ mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None):
46
+ final_positive, final_negative = comfy.hooks.set_conds_props(conds=[positive_NEW, negative_NEW],
47
+ strength=strength, set_cond_area=set_cond_area,
48
+ mask=mask, hooks=hooks, timesteps_range=timesteps)
49
+ return (final_positive, final_negative)
50
+
51
+ class PairConditioningSetPropertiesAndCombine:
52
+ NodeId = 'PairConditioningSetPropertiesAndCombine'
53
+ NodeName = 'Cond Pair Set Props Combine'
54
+ @classmethod
55
+ def INPUT_TYPES(s):
56
+ return {
57
+ "required": {
58
+ "positive": ("CONDITIONING", ),
59
+ "negative": ("CONDITIONING", ),
60
+ "positive_NEW": ("CONDITIONING", ),
61
+ "negative_NEW": ("CONDITIONING", ),
62
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
63
+ "set_cond_area": (["default", "mask bounds"],),
64
+ },
65
+ "optional": {
66
+ "mask": ("MASK", ),
67
+ "hooks": ("HOOKS",),
68
+ "timesteps": ("TIMESTEPS_RANGE",),
69
+ }
70
+ }
71
+
72
+ EXPERIMENTAL = True
73
+ RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
74
+ RETURN_NAMES = ("positive", "negative")
75
+ CATEGORY = "advanced/hooks/cond pair"
76
+ FUNCTION = "set_properties"
77
+
78
+ def set_properties(self, positive, negative, positive_NEW, negative_NEW,
79
+ strength: float, set_cond_area: str,
80
+ mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None):
81
+ final_positive, final_negative = comfy.hooks.set_conds_props_and_combine(conds=[positive, negative], new_conds=[positive_NEW, negative_NEW],
82
+ strength=strength, set_cond_area=set_cond_area,
83
+ mask=mask, hooks=hooks, timesteps_range=timesteps)
84
+ return (final_positive, final_negative)
85
+
86
+ class ConditioningSetProperties:
87
+ NodeId = 'ConditioningSetProperties'
88
+ NodeName = 'Cond Set Props'
89
+ @classmethod
90
+ def INPUT_TYPES(s):
91
+ return {
92
+ "required": {
93
+ "cond_NEW": ("CONDITIONING", ),
94
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
95
+ "set_cond_area": (["default", "mask bounds"],),
96
+ },
97
+ "optional": {
98
+ "mask": ("MASK", ),
99
+ "hooks": ("HOOKS",),
100
+ "timesteps": ("TIMESTEPS_RANGE",),
101
+ }
102
+ }
103
+
104
+ EXPERIMENTAL = True
105
+ RETURN_TYPES = ("CONDITIONING",)
106
+ CATEGORY = "advanced/hooks/cond single"
107
+ FUNCTION = "set_properties"
108
+
109
+ def set_properties(self, cond_NEW,
110
+ strength: float, set_cond_area: str,
111
+ mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None):
112
+ (final_cond,) = comfy.hooks.set_conds_props(conds=[cond_NEW],
113
+ strength=strength, set_cond_area=set_cond_area,
114
+ mask=mask, hooks=hooks, timesteps_range=timesteps)
115
+ return (final_cond,)
116
+
117
+ class ConditioningSetPropertiesAndCombine:
118
+ NodeId = 'ConditioningSetPropertiesAndCombine'
119
+ NodeName = 'Cond Set Props Combine'
120
+ @classmethod
121
+ def INPUT_TYPES(s):
122
+ return {
123
+ "required": {
124
+ "cond": ("CONDITIONING", ),
125
+ "cond_NEW": ("CONDITIONING", ),
126
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
127
+ "set_cond_area": (["default", "mask bounds"],),
128
+ },
129
+ "optional": {
130
+ "mask": ("MASK", ),
131
+ "hooks": ("HOOKS",),
132
+ "timesteps": ("TIMESTEPS_RANGE",),
133
+ }
134
+ }
135
+
136
+ EXPERIMENTAL = True
137
+ RETURN_TYPES = ("CONDITIONING",)
138
+ CATEGORY = "advanced/hooks/cond single"
139
+ FUNCTION = "set_properties"
140
+
141
+ def set_properties(self, cond, cond_NEW,
142
+ strength: float, set_cond_area: str,
143
+ mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None):
144
+ (final_cond,) = comfy.hooks.set_conds_props_and_combine(conds=[cond], new_conds=[cond_NEW],
145
+ strength=strength, set_cond_area=set_cond_area,
146
+ mask=mask, hooks=hooks, timesteps_range=timesteps)
147
+ return (final_cond,)
148
+
149
+ class PairConditioningCombine:
150
+ NodeId = 'PairConditioningCombine'
151
+ NodeName = 'Cond Pair Combine'
152
+ @classmethod
153
+ def INPUT_TYPES(s):
154
+ return {
155
+ "required": {
156
+ "positive_A": ("CONDITIONING",),
157
+ "negative_A": ("CONDITIONING",),
158
+ "positive_B": ("CONDITIONING",),
159
+ "negative_B": ("CONDITIONING",),
160
+ },
161
+ }
162
+
163
+ EXPERIMENTAL = True
164
+ RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
165
+ RETURN_NAMES = ("positive", "negative")
166
+ CATEGORY = "advanced/hooks/cond pair"
167
+ FUNCTION = "combine"
168
+
169
+ def combine(self, positive_A, negative_A, positive_B, negative_B):
170
+ final_positive, final_negative = comfy.hooks.set_conds_props_and_combine(conds=[positive_A, negative_A], new_conds=[positive_B, negative_B],)
171
+ return (final_positive, final_negative,)
172
+
173
+ class PairConditioningSetDefaultAndCombine:
174
+ NodeId = 'PairConditioningSetDefaultCombine'
175
+ NodeName = 'Cond Pair Set Default Combine'
176
+ @classmethod
177
+ def INPUT_TYPES(s):
178
+ return {
179
+ "required": {
180
+ "positive": ("CONDITIONING",),
181
+ "negative": ("CONDITIONING",),
182
+ "positive_DEFAULT": ("CONDITIONING",),
183
+ "negative_DEFAULT": ("CONDITIONING",),
184
+ },
185
+ "optional": {
186
+ "hooks": ("HOOKS",),
187
+ }
188
+ }
189
+
190
+ EXPERIMENTAL = True
191
+ RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
192
+ RETURN_NAMES = ("positive", "negative")
193
+ CATEGORY = "advanced/hooks/cond pair"
194
+ FUNCTION = "set_default_and_combine"
195
+
196
+ def set_default_and_combine(self, positive, negative, positive_DEFAULT, negative_DEFAULT,
197
+ hooks: comfy.hooks.HookGroup=None):
198
+ final_positive, final_negative = comfy.hooks.set_default_conds_and_combine(conds=[positive, negative], new_conds=[positive_DEFAULT, negative_DEFAULT],
199
+ hooks=hooks)
200
+ return (final_positive, final_negative)
201
+
202
+ class ConditioningSetDefaultAndCombine:
203
+ NodeId = 'ConditioningSetDefaultCombine'
204
+ NodeName = 'Cond Set Default Combine'
205
+ @classmethod
206
+ def INPUT_TYPES(s):
207
+ return {
208
+ "required": {
209
+ "cond": ("CONDITIONING",),
210
+ "cond_DEFAULT": ("CONDITIONING",),
211
+ },
212
+ "optional": {
213
+ "hooks": ("HOOKS",),
214
+ }
215
+ }
216
+
217
+ EXPERIMENTAL = True
218
+ RETURN_TYPES = ("CONDITIONING",)
219
+ CATEGORY = "advanced/hooks/cond single"
220
+ FUNCTION = "set_default_and_combine"
221
+
222
+ def set_default_and_combine(self, cond, cond_DEFAULT,
223
+ hooks: comfy.hooks.HookGroup=None):
224
+ (final_conditioning,) = comfy.hooks.set_default_conds_and_combine(conds=[cond], new_conds=[cond_DEFAULT],
225
+ hooks=hooks)
226
+ return (final_conditioning,)
227
+
228
+ class SetClipHooks:
229
+ NodeId = 'SetClipHooks'
230
+ NodeName = 'Set CLIP Hooks'
231
+ @classmethod
232
+ def INPUT_TYPES(s):
233
+ return {
234
+ "required": {
235
+ "clip": ("CLIP",),
236
+ "apply_to_conds": ("BOOLEAN", {"default": True}),
237
+ "schedule_clip": ("BOOLEAN", {"default": False})
238
+ },
239
+ "optional": {
240
+ "hooks": ("HOOKS",)
241
+ }
242
+ }
243
+
244
+ EXPERIMENTAL = True
245
+ RETURN_TYPES = ("CLIP",)
246
+ CATEGORY = "advanced/hooks/clip"
247
+ FUNCTION = "apply_hooks"
248
+
249
+ def apply_hooks(self, clip: CLIP, schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None):
250
+ if hooks is not None:
251
+ clip = clip.clone()
252
+ if apply_to_conds:
253
+ clip.apply_hooks_to_conds = hooks
254
+ clip.patcher.forced_hooks = hooks.clone()
255
+ clip.use_clip_schedule = schedule_clip
256
+ if not clip.use_clip_schedule:
257
+ clip.patcher.forced_hooks.set_keyframes_on_hooks(None)
258
+ clip.patcher.register_all_hook_patches(hooks, comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Clip))
259
+ return (clip,)
260
+
261
+ class ConditioningTimestepsRange:
262
+ NodeId = 'ConditioningTimestepsRange'
263
+ NodeName = 'Timesteps Range'
264
+ @classmethod
265
+ def INPUT_TYPES(s):
266
+ return {
267
+ "required": {
268
+ "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
269
+ "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
270
+ },
271
+ }
272
+
273
+ EXPERIMENTAL = True
274
+ RETURN_TYPES = ("TIMESTEPS_RANGE", "TIMESTEPS_RANGE", "TIMESTEPS_RANGE")
275
+ RETURN_NAMES = ("TIMESTEPS_RANGE", "BEFORE_RANGE", "AFTER_RANGE")
276
+ CATEGORY = "advanced/hooks"
277
+ FUNCTION = "create_range"
278
+
279
+ def create_range(self, start_percent: float, end_percent: float):
280
+ return ((start_percent, end_percent), (0.0, start_percent), (end_percent, 1.0))
281
+ #------------------------------------------
282
+ ###########################################
283
+
284
+
285
+ ###########################################
286
+ # Create Hooks
287
+ #------------------------------------------
288
+ class CreateHookLora:
289
+ NodeId = 'CreateHookLora'
290
+ NodeName = 'Create Hook LoRA'
291
+ def __init__(self):
292
+ self.loaded_lora = None
293
+
294
+ @classmethod
295
+ def INPUT_TYPES(s):
296
+ return {
297
+ "required": {
298
+ "lora_name": (folder_paths.get_filename_list("loras"), ),
299
+ "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
300
+ "strength_clip": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
301
+ },
302
+ "optional": {
303
+ "prev_hooks": ("HOOKS",)
304
+ }
305
+ }
306
+
307
+ EXPERIMENTAL = True
308
+ RETURN_TYPES = ("HOOKS",)
309
+ CATEGORY = "advanced/hooks/create"
310
+ FUNCTION = "create_hook"
311
+
312
+ def create_hook(self, lora_name: str, strength_model: float, strength_clip: float, prev_hooks: comfy.hooks.HookGroup=None):
313
+ if prev_hooks is None:
314
+ prev_hooks = comfy.hooks.HookGroup()
315
+ prev_hooks.clone()
316
+
317
+ if strength_model == 0 and strength_clip == 0:
318
+ return (prev_hooks,)
319
+
320
+ lora_path = folder_paths.get_full_path("loras", lora_name)
321
+ lora = None
322
+ if self.loaded_lora is not None:
323
+ if self.loaded_lora[0] == lora_path:
324
+ lora = self.loaded_lora[1]
325
+ else:
326
+ temp = self.loaded_lora
327
+ self.loaded_lora = None
328
+ del temp
329
+
330
+ if lora is None:
331
+ lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
332
+ self.loaded_lora = (lora_path, lora)
333
+
334
+ hooks = comfy.hooks.create_hook_lora(lora=lora, strength_model=strength_model, strength_clip=strength_clip)
335
+ return (prev_hooks.clone_and_combine(hooks),)
336
+
337
+ class CreateHookLoraModelOnly(CreateHookLora):
338
+ NodeId = 'CreateHookLoraModelOnly'
339
+ NodeName = 'Create Hook LoRA (MO)'
340
+ @classmethod
341
+ def INPUT_TYPES(s):
342
+ return {
343
+ "required": {
344
+ "lora_name": (folder_paths.get_filename_list("loras"), ),
345
+ "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
346
+ },
347
+ "optional": {
348
+ "prev_hooks": ("HOOKS",)
349
+ }
350
+ }
351
+
352
+ EXPERIMENTAL = True
353
+ RETURN_TYPES = ("HOOKS",)
354
+ CATEGORY = "advanced/hooks/create"
355
+ FUNCTION = "create_hook_model_only"
356
+
357
+ def create_hook_model_only(self, lora_name: str, strength_model: float, prev_hooks: comfy.hooks.HookGroup=None):
358
+ return self.create_hook(lora_name=lora_name, strength_model=strength_model, strength_clip=0, prev_hooks=prev_hooks)
359
+
360
+ class CreateHookModelAsLora:
361
+ NodeId = 'CreateHookModelAsLora'
362
+ NodeName = 'Create Hook Model as LoRA'
363
+
364
+ def __init__(self):
365
+ # when not None, will be in following format:
366
+ # (ckpt_path: str, weights_model: dict, weights_clip: dict)
367
+ self.loaded_weights = None
368
+
369
+ @classmethod
370
+ def INPUT_TYPES(s):
371
+ return {
372
+ "required": {
373
+ "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
374
+ "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
375
+ "strength_clip": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
376
+ },
377
+ "optional": {
378
+ "prev_hooks": ("HOOKS",)
379
+ }
380
+ }
381
+
382
+ EXPERIMENTAL = True
383
+ RETURN_TYPES = ("HOOKS",)
384
+ CATEGORY = "advanced/hooks/create"
385
+ FUNCTION = "create_hook"
386
+
387
+ def create_hook(self, ckpt_name: str, strength_model: float, strength_clip: float,
388
+ prev_hooks: comfy.hooks.HookGroup=None):
389
+ if prev_hooks is None:
390
+ prev_hooks = comfy.hooks.HookGroup()
391
+ prev_hooks.clone()
392
+
393
+ ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
394
+ weights_model = None
395
+ weights_clip = None
396
+ if self.loaded_weights is not None:
397
+ if self.loaded_weights[0] == ckpt_path:
398
+ weights_model = self.loaded_weights[1]
399
+ weights_clip = self.loaded_weights[2]
400
+ else:
401
+ temp = self.loaded_weights
402
+ self.loaded_weights = None
403
+ del temp
404
+
405
+ if weights_model is None:
406
+ out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
407
+ weights_model = comfy.hooks.get_patch_weights_from_model(out[0])
408
+ weights_clip = comfy.hooks.get_patch_weights_from_model(out[1].patcher if out[1] else out[1])
409
+ self.loaded_weights = (ckpt_path, weights_model, weights_clip)
410
+
411
+ hooks = comfy.hooks.create_hook_model_as_lora(weights_model=weights_model, weights_clip=weights_clip,
412
+ strength_model=strength_model, strength_clip=strength_clip)
413
+ return (prev_hooks.clone_and_combine(hooks),)
414
+
415
+ class CreateHookModelAsLoraModelOnly(CreateHookModelAsLora):
416
+ NodeId = 'CreateHookModelAsLoraModelOnly'
417
+ NodeName = 'Create Hook Model as LoRA (MO)'
418
+ @classmethod
419
+ def INPUT_TYPES(s):
420
+ return {
421
+ "required": {
422
+ "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
423
+ "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
424
+ },
425
+ "optional": {
426
+ "prev_hooks": ("HOOKS",)
427
+ }
428
+ }
429
+
430
+ EXPERIMENTAL = True
431
+ RETURN_TYPES = ("HOOKS",)
432
+ CATEGORY = "advanced/hooks/create"
433
+ FUNCTION = "create_hook_model_only"
434
+
435
+ def create_hook_model_only(self, ckpt_name: str, strength_model: float,
436
+ prev_hooks: comfy.hooks.HookGroup=None):
437
+ return self.create_hook(ckpt_name=ckpt_name, strength_model=strength_model, strength_clip=0.0, prev_hooks=prev_hooks)
438
+ #------------------------------------------
439
+ ###########################################
440
+
441
+
442
+ ###########################################
443
+ # Schedule Hooks
444
+ #------------------------------------------
445
+ class SetHookKeyframes:
446
+ NodeId = 'SetHookKeyframes'
447
+ NodeName = 'Set Hook Keyframes'
448
+ @classmethod
449
+ def INPUT_TYPES(s):
450
+ return {
451
+ "required": {
452
+ "hooks": ("HOOKS",),
453
+ },
454
+ "optional": {
455
+ "hook_kf": ("HOOK_KEYFRAMES",),
456
+ }
457
+ }
458
+
459
+ EXPERIMENTAL = True
460
+ RETURN_TYPES = ("HOOKS",)
461
+ CATEGORY = "advanced/hooks/scheduling"
462
+ FUNCTION = "set_hook_keyframes"
463
+
464
+ def set_hook_keyframes(self, hooks: comfy.hooks.HookGroup, hook_kf: comfy.hooks.HookKeyframeGroup=None):
465
+ if hook_kf is not None:
466
+ hooks = hooks.clone()
467
+ hooks.set_keyframes_on_hooks(hook_kf=hook_kf)
468
+ return (hooks,)
469
+
470
+ class CreateHookKeyframe:
471
+ NodeId = 'CreateHookKeyframe'
472
+ NodeName = 'Create Hook Keyframe'
473
+ @classmethod
474
+ def INPUT_TYPES(s):
475
+ return {
476
+ "required": {
477
+ "strength_mult": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
478
+ "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
479
+ },
480
+ "optional": {
481
+ "prev_hook_kf": ("HOOK_KEYFRAMES",),
482
+ }
483
+ }
484
+
485
+ EXPERIMENTAL = True
486
+ RETURN_TYPES = ("HOOK_KEYFRAMES",)
487
+ RETURN_NAMES = ("HOOK_KF",)
488
+ CATEGORY = "advanced/hooks/scheduling"
489
+ FUNCTION = "create_hook_keyframe"
490
+
491
+ def create_hook_keyframe(self, strength_mult: float, start_percent: float, prev_hook_kf: comfy.hooks.HookKeyframeGroup=None):
492
+ if prev_hook_kf is None:
493
+ prev_hook_kf = comfy.hooks.HookKeyframeGroup()
494
+ prev_hook_kf = prev_hook_kf.clone()
495
+ keyframe = comfy.hooks.HookKeyframe(strength=strength_mult, start_percent=start_percent)
496
+ prev_hook_kf.add(keyframe)
497
+ return (prev_hook_kf,)
498
+
499
+ class CreateHookKeyframesInterpolated:
500
+ NodeId = 'CreateHookKeyframesInterpolated'
501
+ NodeName = 'Create Hook Keyframes Interp.'
502
+ @classmethod
503
+ def INPUT_TYPES(s):
504
+ return {
505
+ "required": {
506
+ "strength_start": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
507
+ "strength_end": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
508
+ "interpolation": (comfy.hooks.InterpolationMethod._LIST, ),
509
+ "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
510
+ "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
511
+ "keyframes_count": ("INT", {"default": 5, "min": 2, "max": 100, "step": 1}),
512
+ "print_keyframes": ("BOOLEAN", {"default": False}),
513
+ },
514
+ "optional": {
515
+ "prev_hook_kf": ("HOOK_KEYFRAMES",),
516
+ },
517
+ }
518
+
519
+ EXPERIMENTAL = True
520
+ RETURN_TYPES = ("HOOK_KEYFRAMES",)
521
+ RETURN_NAMES = ("HOOK_KF",)
522
+ CATEGORY = "advanced/hooks/scheduling"
523
+ FUNCTION = "create_hook_keyframes"
524
+
525
+ def create_hook_keyframes(self, strength_start: float, strength_end: float, interpolation: str,
526
+ start_percent: float, end_percent: float, keyframes_count: int,
527
+ print_keyframes=False, prev_hook_kf: comfy.hooks.HookKeyframeGroup=None):
528
+ if prev_hook_kf is None:
529
+ prev_hook_kf = comfy.hooks.HookKeyframeGroup()
530
+ prev_hook_kf = prev_hook_kf.clone()
531
+ percents = comfy.hooks.InterpolationMethod.get_weights(num_from=start_percent, num_to=end_percent, length=keyframes_count,
532
+ method=comfy.hooks.InterpolationMethod.LINEAR)
533
+ strengths = comfy.hooks.InterpolationMethod.get_weights(num_from=strength_start, num_to=strength_end, length=keyframes_count, method=interpolation)
534
+
535
+ is_first = True
536
+ for percent, strength in zip(percents, strengths):
537
+ guarantee_steps = 0
538
+ if is_first:
539
+ guarantee_steps = 1
540
+ is_first = False
541
+ prev_hook_kf.add(comfy.hooks.HookKeyframe(strength=strength, start_percent=percent, guarantee_steps=guarantee_steps))
542
+ if print_keyframes:
543
+ logging.info(f"Hook Keyframe - start_percent:{percent} = {strength}")
544
+ return (prev_hook_kf,)
545
+
546
+ class CreateHookKeyframesFromFloats:
547
+ NodeId = 'CreateHookKeyframesFromFloats'
548
+ NodeName = 'Create Hook Keyframes From Floats'
549
+ @classmethod
550
+ def INPUT_TYPES(s):
551
+ return {
552
+ "required": {
553
+ "floats_strength": ("FLOATS", {"default": -1, "min": -1, "step": 0.001, "forceInput": True}),
554
+ "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
555
+ "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
556
+ "print_keyframes": ("BOOLEAN", {"default": False}),
557
+ },
558
+ "optional": {
559
+ "prev_hook_kf": ("HOOK_KEYFRAMES",),
560
+ }
561
+ }
562
+
563
+ EXPERIMENTAL = True
564
+ RETURN_TYPES = ("HOOK_KEYFRAMES",)
565
+ RETURN_NAMES = ("HOOK_KF",)
566
+ CATEGORY = "advanced/hooks/scheduling"
567
+ FUNCTION = "create_hook_keyframes"
568
+
569
+ def create_hook_keyframes(self, floats_strength: Union[float, list[float]],
570
+ start_percent: float, end_percent: float,
571
+ prev_hook_kf: comfy.hooks.HookKeyframeGroup=None, print_keyframes=False):
572
+ if prev_hook_kf is None:
573
+ prev_hook_kf = comfy.hooks.HookKeyframeGroup()
574
+ prev_hook_kf = prev_hook_kf.clone()
575
+ if type(floats_strength) in (float, int):
576
+ floats_strength = [float(floats_strength)]
577
+ elif isinstance(floats_strength, Iterable):
578
+ pass
579
+ else:
580
+ raise Exception(f"floats_strength must be either an iterable input or a float, but was{type(floats_strength).__repr__}.")
581
+ percents = comfy.hooks.InterpolationMethod.get_weights(num_from=start_percent, num_to=end_percent, length=len(floats_strength),
582
+ method=comfy.hooks.InterpolationMethod.LINEAR)
583
+
584
+ is_first = True
585
+ for percent, strength in zip(percents, floats_strength):
586
+ guarantee_steps = 0
587
+ if is_first:
588
+ guarantee_steps = 1
589
+ is_first = False
590
+ prev_hook_kf.add(comfy.hooks.HookKeyframe(strength=strength, start_percent=percent, guarantee_steps=guarantee_steps))
591
+ if print_keyframes:
592
+ logging.info(f"Hook Keyframe - start_percent:{percent} = {strength}")
593
+ return (prev_hook_kf,)
594
+ #------------------------------------------
595
+ ###########################################
596
+
597
+
598
+ class SetModelHooksOnCond:
599
+ @classmethod
600
+ def INPUT_TYPES(s):
601
+ return {
602
+ "required": {
603
+ "conditioning": ("CONDITIONING",),
604
+ "hooks": ("HOOKS",),
605
+ },
606
+ }
607
+
608
+ EXPERIMENTAL = True
609
+ RETURN_TYPES = ("CONDITIONING",)
610
+ CATEGORY = "advanced/hooks/manual"
611
+ FUNCTION = "attach_hook"
612
+
613
+ def attach_hook(self, conditioning, hooks: comfy.hooks.HookGroup):
614
+ return (comfy.hooks.set_hooks_for_conditioning(conditioning, hooks),)
615
+
616
+
617
+ ###########################################
618
+ # Combine Hooks
619
+ #------------------------------------------
620
+ class CombineHooks:
621
+ NodeId = 'CombineHooks2'
622
+ NodeName = 'Combine Hooks [2]'
623
+ @classmethod
624
+ def INPUT_TYPES(s):
625
+ return {
626
+ "required": {
627
+ },
628
+ "optional": {
629
+ "hooks_A": ("HOOKS",),
630
+ "hooks_B": ("HOOKS",),
631
+ }
632
+ }
633
+
634
+ EXPERIMENTAL = True
635
+ RETURN_TYPES = ("HOOKS",)
636
+ CATEGORY = "advanced/hooks/combine"
637
+ FUNCTION = "combine_hooks"
638
+
639
+ def combine_hooks(self,
640
+ hooks_A: comfy.hooks.HookGroup=None,
641
+ hooks_B: comfy.hooks.HookGroup=None):
642
+ candidates = [hooks_A, hooks_B]
643
+ return (comfy.hooks.HookGroup.combine_all_hooks(candidates),)
644
+
645
+ class CombineHooksFour:
646
+ NodeId = 'CombineHooks4'
647
+ NodeName = 'Combine Hooks [4]'
648
+ @classmethod
649
+ def INPUT_TYPES(s):
650
+ return {
651
+ "required": {
652
+ },
653
+ "optional": {
654
+ "hooks_A": ("HOOKS",),
655
+ "hooks_B": ("HOOKS",),
656
+ "hooks_C": ("HOOKS",),
657
+ "hooks_D": ("HOOKS",),
658
+ }
659
+ }
660
+
661
+ EXPERIMENTAL = True
662
+ RETURN_TYPES = ("HOOKS",)
663
+ CATEGORY = "advanced/hooks/combine"
664
+ FUNCTION = "combine_hooks"
665
+
666
+ def combine_hooks(self,
667
+ hooks_A: comfy.hooks.HookGroup=None,
668
+ hooks_B: comfy.hooks.HookGroup=None,
669
+ hooks_C: comfy.hooks.HookGroup=None,
670
+ hooks_D: comfy.hooks.HookGroup=None):
671
+ candidates = [hooks_A, hooks_B, hooks_C, hooks_D]
672
+ return (comfy.hooks.HookGroup.combine_all_hooks(candidates),)
673
+
674
+ class CombineHooksEight:
675
+ NodeId = 'CombineHooks8'
676
+ NodeName = 'Combine Hooks [8]'
677
+ @classmethod
678
+ def INPUT_TYPES(s):
679
+ return {
680
+ "required": {
681
+ },
682
+ "optional": {
683
+ "hooks_A": ("HOOKS",),
684
+ "hooks_B": ("HOOKS",),
685
+ "hooks_C": ("HOOKS",),
686
+ "hooks_D": ("HOOKS",),
687
+ "hooks_E": ("HOOKS",),
688
+ "hooks_F": ("HOOKS",),
689
+ "hooks_G": ("HOOKS",),
690
+ "hooks_H": ("HOOKS",),
691
+ }
692
+ }
693
+
694
+ EXPERIMENTAL = True
695
+ RETURN_TYPES = ("HOOKS",)
696
+ CATEGORY = "advanced/hooks/combine"
697
+ FUNCTION = "combine_hooks"
698
+
699
+ def combine_hooks(self,
700
+ hooks_A: comfy.hooks.HookGroup=None,
701
+ hooks_B: comfy.hooks.HookGroup=None,
702
+ hooks_C: comfy.hooks.HookGroup=None,
703
+ hooks_D: comfy.hooks.HookGroup=None,
704
+ hooks_E: comfy.hooks.HookGroup=None,
705
+ hooks_F: comfy.hooks.HookGroup=None,
706
+ hooks_G: comfy.hooks.HookGroup=None,
707
+ hooks_H: comfy.hooks.HookGroup=None):
708
+ candidates = [hooks_A, hooks_B, hooks_C, hooks_D, hooks_E, hooks_F, hooks_G, hooks_H]
709
+ return (comfy.hooks.HookGroup.combine_all_hooks(candidates),)
710
+ #------------------------------------------
711
+ ###########################################
712
+
713
+ node_list = [
714
+ # Create
715
+ CreateHookLora,
716
+ CreateHookLoraModelOnly,
717
+ CreateHookModelAsLora,
718
+ CreateHookModelAsLoraModelOnly,
719
+ # Scheduling
720
+ SetHookKeyframes,
721
+ CreateHookKeyframe,
722
+ CreateHookKeyframesInterpolated,
723
+ CreateHookKeyframesFromFloats,
724
+ # Combine
725
+ CombineHooks,
726
+ CombineHooksFour,
727
+ CombineHooksEight,
728
+ # Attach
729
+ ConditioningSetProperties,
730
+ ConditioningSetPropertiesAndCombine,
731
+ PairConditioningSetProperties,
732
+ PairConditioningSetPropertiesAndCombine,
733
+ ConditioningSetDefaultAndCombine,
734
+ PairConditioningSetDefaultAndCombine,
735
+ PairConditioningCombine,
736
+ SetClipHooks,
737
+ # Other
738
+ ConditioningTimestepsRange,
739
+ ]
740
+ NODE_CLASS_MAPPINGS = {}
741
+ NODE_DISPLAY_NAME_MAPPINGS = {}
742
+
743
+ for node in node_list:
744
+ NODE_CLASS_MAPPINGS[node.NodeId] = node
745
+ NODE_DISPLAY_NAME_MAPPINGS[node.NodeId] = node.NodeName
comfy_extras/nodes_hunyuan.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import nodes
2
+ import torch
3
+ import comfy.model_management
4
+
5
+
6
+ class CLIPTextEncodeHunyuanDiT:
7
+ @classmethod
8
+ def INPUT_TYPES(s):
9
+ return {"required": {
10
+ "clip": ("CLIP", ),
11
+ "bert": ("STRING", {"multiline": True, "dynamicPrompts": True}),
12
+ "mt5xl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
13
+ }}
14
+ RETURN_TYPES = ("CONDITIONING",)
15
+ FUNCTION = "encode"
16
+
17
+ CATEGORY = "advanced/conditioning"
18
+
19
+ def encode(self, clip, bert, mt5xl):
20
+ tokens = clip.tokenize(bert)
21
+ tokens["mt5xl"] = clip.tokenize(mt5xl)["mt5xl"]
22
+
23
+ return (clip.encode_from_tokens_scheduled(tokens), )
24
+
25
+ class EmptyHunyuanLatentVideo:
26
+ @classmethod
27
+ def INPUT_TYPES(s):
28
+ return {"required": { "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
29
+ "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
30
+ "length": ("INT", {"default": 25, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
31
+ "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
32
+ RETURN_TYPES = ("LATENT",)
33
+ FUNCTION = "generate"
34
+
35
+ CATEGORY = "latent/video"
36
+
37
+ def generate(self, width, height, length, batch_size=1):
38
+ latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
39
+ return ({"samples":latent}, )
40
+
41
+ NODE_CLASS_MAPPINGS = {
42
+ "CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
43
+ "EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo,
44
+ }
comfy_extras/nodes_hypernetwork.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import comfy.utils
2
+ import folder_paths
3
+ import torch
4
+ import logging
5
+
6
+ def load_hypernetwork_patch(path, strength):
7
+ sd = comfy.utils.load_torch_file(path, safe_load=True)
8
+ activation_func = sd.get('activation_func', 'linear')
9
+ is_layer_norm = sd.get('is_layer_norm', False)
10
+ use_dropout = sd.get('use_dropout', False)
11
+ activate_output = sd.get('activate_output', False)
12
+ last_layer_dropout = sd.get('last_layer_dropout', False)
13
+
14
+ valid_activation = {
15
+ "linear": torch.nn.Identity,
16
+ "relu": torch.nn.ReLU,
17
+ "leakyrelu": torch.nn.LeakyReLU,
18
+ "elu": torch.nn.ELU,
19
+ "swish": torch.nn.Hardswish,
20
+ "tanh": torch.nn.Tanh,
21
+ "sigmoid": torch.nn.Sigmoid,
22
+ "softsign": torch.nn.Softsign,
23
+ "mish": torch.nn.Mish,
24
+ }
25
+
26
+ if activation_func not in valid_activation:
27
+ logging.error("Unsupported Hypernetwork format, if you report it I might implement it. {} {} {} {} {} {}".format(path, activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout))
28
+ return None
29
+
30
+ out = {}
31
+
32
+ for d in sd:
33
+ try:
34
+ dim = int(d)
35
+ except:
36
+ continue
37
+
38
+ output = []
39
+ for index in [0, 1]:
40
+ attn_weights = sd[dim][index]
41
+ keys = attn_weights.keys()
42
+
43
+ linears = filter(lambda a: a.endswith(".weight"), keys)
44
+ linears = list(map(lambda a: a[:-len(".weight")], linears))
45
+ layers = []
46
+
47
+ i = 0
48
+ while i < len(linears):
49
+ lin_name = linears[i]
50
+ last_layer = (i == (len(linears) - 1))
51
+ penultimate_layer = (i == (len(linears) - 2))
52
+
53
+ lin_weight = attn_weights['{}.weight'.format(lin_name)]
54
+ lin_bias = attn_weights['{}.bias'.format(lin_name)]
55
+ layer = torch.nn.Linear(lin_weight.shape[1], lin_weight.shape[0])
56
+ layer.load_state_dict({"weight": lin_weight, "bias": lin_bias})
57
+ layers.append(layer)
58
+ if activation_func != "linear":
59
+ if (not last_layer) or (activate_output):
60
+ layers.append(valid_activation[activation_func]())
61
+ if is_layer_norm:
62
+ i += 1
63
+ ln_name = linears[i]
64
+ ln_weight = attn_weights['{}.weight'.format(ln_name)]
65
+ ln_bias = attn_weights['{}.bias'.format(ln_name)]
66
+ ln = torch.nn.LayerNorm(ln_weight.shape[0])
67
+ ln.load_state_dict({"weight": ln_weight, "bias": ln_bias})
68
+ layers.append(ln)
69
+ if use_dropout:
70
+ if (not last_layer) and (not penultimate_layer or last_layer_dropout):
71
+ layers.append(torch.nn.Dropout(p=0.3))
72
+ i += 1
73
+
74
+ output.append(torch.nn.Sequential(*layers))
75
+ out[dim] = torch.nn.ModuleList(output)
76
+
77
+ class hypernetwork_patch:
78
+ def __init__(self, hypernet, strength):
79
+ self.hypernet = hypernet
80
+ self.strength = strength
81
+ def __call__(self, q, k, v, extra_options):
82
+ dim = k.shape[-1]
83
+ if dim in self.hypernet:
84
+ hn = self.hypernet[dim]
85
+ k = k + hn[0](k) * self.strength
86
+ v = v + hn[1](v) * self.strength
87
+
88
+ return q, k, v
89
+
90
+ def to(self, device):
91
+ for d in self.hypernet.keys():
92
+ self.hypernet[d] = self.hypernet[d].to(device)
93
+ return self
94
+
95
+ return hypernetwork_patch(out, strength)
96
+
97
+ class HypernetworkLoader:
98
+ @classmethod
99
+ def INPUT_TYPES(s):
100
+ return {"required": { "model": ("MODEL",),
101
+ "hypernetwork_name": (folder_paths.get_filename_list("hypernetworks"), ),
102
+ "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
103
+ }}
104
+ RETURN_TYPES = ("MODEL",)
105
+ FUNCTION = "load_hypernetwork"
106
+
107
+ CATEGORY = "loaders"
108
+
109
+ def load_hypernetwork(self, model, hypernetwork_name, strength):
110
+ hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name)
111
+ model_hypernetwork = model.clone()
112
+ patch = load_hypernetwork_patch(hypernetwork_path, strength)
113
+ if patch is not None:
114
+ model_hypernetwork.set_model_attn1_patch(patch)
115
+ model_hypernetwork.set_model_attn2_patch(patch)
116
+ return (model_hypernetwork,)
117
+
118
+ NODE_CLASS_MAPPINGS = {
119
+ "HypernetworkLoader": HypernetworkLoader
120
+ }
comfy_extras/nodes_hypertile.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #Taken from: https://github.com/tfernd/HyperTile/
2
+
3
+ import math
4
+ from einops import rearrange
5
+ # Use torch rng for consistency across generations
6
+ from torch import randint
7
+
8
+ def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
9
+ min_value = min(min_value, value)
10
+
11
+ # All big divisors of value (inclusive)
12
+ divisors = [i for i in range(min_value, value + 1) if value % i == 0]
13
+
14
+ ns = [value // i for i in divisors[:max_options]] # has at least 1 element
15
+
16
+ if len(ns) - 1 > 0:
17
+ idx = randint(low=0, high=len(ns) - 1, size=(1,)).item()
18
+ else:
19
+ idx = 0
20
+
21
+ return ns[idx]
22
+
23
+ class HyperTile:
24
+ @classmethod
25
+ def INPUT_TYPES(s):
26
+ return {"required": { "model": ("MODEL",),
27
+ "tile_size": ("INT", {"default": 256, "min": 1, "max": 2048}),
28
+ "swap_size": ("INT", {"default": 2, "min": 1, "max": 128}),
29
+ "max_depth": ("INT", {"default": 0, "min": 0, "max": 10}),
30
+ "scale_depth": ("BOOLEAN", {"default": False}),
31
+ }}
32
+ RETURN_TYPES = ("MODEL",)
33
+ FUNCTION = "patch"
34
+
35
+ CATEGORY = "model_patches/unet"
36
+
37
+ def patch(self, model, tile_size, swap_size, max_depth, scale_depth):
38
+ latent_tile_size = max(32, tile_size) // 8
39
+ self.temp = None
40
+
41
+ def hypertile_in(q, k, v, extra_options):
42
+ model_chans = q.shape[-2]
43
+ orig_shape = extra_options['original_shape']
44
+ apply_to = []
45
+ for i in range(max_depth + 1):
46
+ apply_to.append((orig_shape[-2] / (2 ** i)) * (orig_shape[-1] / (2 ** i)))
47
+
48
+ if model_chans in apply_to:
49
+ shape = extra_options["original_shape"]
50
+ aspect_ratio = shape[-1] / shape[-2]
51
+
52
+ hw = q.size(1)
53
+ h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))
54
+
55
+ factor = (2 ** apply_to.index(model_chans)) if scale_depth else 1
56
+ nh = random_divisor(h, latent_tile_size * factor, swap_size)
57
+ nw = random_divisor(w, latent_tile_size * factor, swap_size)
58
+
59
+ if nh * nw > 1:
60
+ q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
61
+ self.temp = (nh, nw, h, w)
62
+ return q, k, v
63
+
64
+ return q, k, v
65
+ def hypertile_out(out, extra_options):
66
+ if self.temp is not None:
67
+ nh, nw, h, w = self.temp
68
+ self.temp = None
69
+ out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
70
+ out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)
71
+ return out
72
+
73
+
74
+ m = model.clone()
75
+ m.set_model_attn1_patch(hypertile_in)
76
+ m.set_model_attn1_output_patch(hypertile_out)
77
+ return (m, )
78
+
79
+ NODE_CLASS_MAPPINGS = {
80
+ "HyperTile": HyperTile,
81
+ }
comfy_extras/nodes_images.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import nodes
2
+ import folder_paths
3
+ from comfy.cli_args import args
4
+
5
+ from PIL import Image
6
+ from PIL.PngImagePlugin import PngInfo
7
+
8
+ import numpy as np
9
+ import json
10
+ import os
11
+
12
+ MAX_RESOLUTION = nodes.MAX_RESOLUTION
13
+
14
+ class ImageCrop:
15
+ @classmethod
16
+ def INPUT_TYPES(s):
17
+ return {"required": { "image": ("IMAGE",),
18
+ "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
19
+ "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
20
+ "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
21
+ "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
22
+ }}
23
+ RETURN_TYPES = ("IMAGE",)
24
+ FUNCTION = "crop"
25
+
26
+ CATEGORY = "image/transform"
27
+
28
+ def crop(self, image, width, height, x, y):
29
+ x = min(x, image.shape[2] - 1)
30
+ y = min(y, image.shape[1] - 1)
31
+ to_x = width + x
32
+ to_y = height + y
33
+ img = image[:,y:to_y, x:to_x, :]
34
+ return (img,)
35
+
36
+ class RepeatImageBatch:
37
+ @classmethod
38
+ def INPUT_TYPES(s):
39
+ return {"required": { "image": ("IMAGE",),
40
+ "amount": ("INT", {"default": 1, "min": 1, "max": 4096}),
41
+ }}
42
+ RETURN_TYPES = ("IMAGE",)
43
+ FUNCTION = "repeat"
44
+
45
+ CATEGORY = "image/batch"
46
+
47
+ def repeat(self, image, amount):
48
+ s = image.repeat((amount, 1,1,1))
49
+ return (s,)
50
+
51
+ class ImageFromBatch:
52
+ @classmethod
53
+ def INPUT_TYPES(s):
54
+ return {"required": { "image": ("IMAGE",),
55
+ "batch_index": ("INT", {"default": 0, "min": 0, "max": 4095}),
56
+ "length": ("INT", {"default": 1, "min": 1, "max": 4096}),
57
+ }}
58
+ RETURN_TYPES = ("IMAGE",)
59
+ FUNCTION = "frombatch"
60
+
61
+ CATEGORY = "image/batch"
62
+
63
+ def frombatch(self, image, batch_index, length):
64
+ s_in = image
65
+ batch_index = min(s_in.shape[0] - 1, batch_index)
66
+ length = min(s_in.shape[0] - batch_index, length)
67
+ s = s_in[batch_index:batch_index + length].clone()
68
+ return (s,)
69
+
70
+ class SaveAnimatedWEBP:
71
+ def __init__(self):
72
+ self.output_dir = folder_paths.get_output_directory()
73
+ self.type = "output"
74
+ self.prefix_append = ""
75
+
76
+ methods = {"default": 4, "fastest": 0, "slowest": 6}
77
+ @classmethod
78
+ def INPUT_TYPES(s):
79
+ return {"required":
80
+ {"images": ("IMAGE", ),
81
+ "filename_prefix": ("STRING", {"default": "ComfyUI"}),
82
+ "fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
83
+ "lossless": ("BOOLEAN", {"default": True}),
84
+ "quality": ("INT", {"default": 80, "min": 0, "max": 100}),
85
+ "method": (list(s.methods.keys()),),
86
+ # "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}),
87
+ },
88
+ "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
89
+ }
90
+
91
+ RETURN_TYPES = ()
92
+ FUNCTION = "save_images"
93
+
94
+ OUTPUT_NODE = True
95
+
96
+ CATEGORY = "image/animation"
97
+
98
+ def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None):
99
+ method = self.methods.get(method)
100
+ filename_prefix += self.prefix_append
101
+ full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
102
+ results = list()
103
+ pil_images = []
104
+ for image in images:
105
+ i = 255. * image.cpu().numpy()
106
+ img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
107
+ pil_images.append(img)
108
+
109
+ metadata = pil_images[0].getexif()
110
+ if not args.disable_metadata:
111
+ if prompt is not None:
112
+ metadata[0x0110] = "prompt:{}".format(json.dumps(prompt))
113
+ if extra_pnginfo is not None:
114
+ inital_exif = 0x010f
115
+ for x in extra_pnginfo:
116
+ metadata[inital_exif] = "{}:{}".format(x, json.dumps(extra_pnginfo[x]))
117
+ inital_exif -= 1
118
+
119
+ if num_frames == 0:
120
+ num_frames = len(pil_images)
121
+
122
+ c = len(pil_images)
123
+ for i in range(0, c, num_frames):
124
+ file = f"{filename}_{counter:05}_.webp"
125
+ pil_images[i].save(os.path.join(full_output_folder, file), save_all=True, duration=int(1000.0/fps), append_images=pil_images[i + 1:i + num_frames], exif=metadata, lossless=lossless, quality=quality, method=method)
126
+ results.append({
127
+ "filename": file,
128
+ "subfolder": subfolder,
129
+ "type": self.type
130
+ })
131
+ counter += 1
132
+
133
+ animated = num_frames != 1
134
+ return { "ui": { "images": results, "animated": (animated,) } }
135
+
136
+ class SaveAnimatedPNG:
137
+ def __init__(self):
138
+ self.output_dir = folder_paths.get_output_directory()
139
+ self.type = "output"
140
+ self.prefix_append = ""
141
+
142
+ @classmethod
143
+ def INPUT_TYPES(s):
144
+ return {"required":
145
+ {"images": ("IMAGE", ),
146
+ "filename_prefix": ("STRING", {"default": "ComfyUI"}),
147
+ "fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
148
+ "compress_level": ("INT", {"default": 4, "min": 0, "max": 9})
149
+ },
150
+ "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
151
+ }
152
+
153
+ RETURN_TYPES = ()
154
+ FUNCTION = "save_images"
155
+
156
+ OUTPUT_NODE = True
157
+
158
+ CATEGORY = "image/animation"
159
+
160
+ def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
161
+ filename_prefix += self.prefix_append
162
+ full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
163
+ results = list()
164
+ pil_images = []
165
+ for image in images:
166
+ i = 255. * image.cpu().numpy()
167
+ img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
168
+ pil_images.append(img)
169
+
170
+ metadata = None
171
+ if not args.disable_metadata:
172
+ metadata = PngInfo()
173
+ if prompt is not None:
174
+ metadata.add(b"comf", "prompt".encode("latin-1", "strict") + b"\0" + json.dumps(prompt).encode("latin-1", "strict"), after_idat=True)
175
+ if extra_pnginfo is not None:
176
+ for x in extra_pnginfo:
177
+ metadata.add(b"comf", x.encode("latin-1", "strict") + b"\0" + json.dumps(extra_pnginfo[x]).encode("latin-1", "strict"), after_idat=True)
178
+
179
+ file = f"{filename}_{counter:05}_.png"
180
+ pil_images[0].save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level, save_all=True, duration=int(1000.0/fps), append_images=pil_images[1:])
181
+ results.append({
182
+ "filename": file,
183
+ "subfolder": subfolder,
184
+ "type": self.type
185
+ })
186
+
187
+ return { "ui": { "images": results, "animated": (True,)} }
188
+
189
+ NODE_CLASS_MAPPINGS = {
190
+ "ImageCrop": ImageCrop,
191
+ "RepeatImageBatch": RepeatImageBatch,
192
+ "ImageFromBatch": ImageFromBatch,
193
+ "SaveAnimatedWEBP": SaveAnimatedWEBP,
194
+ "SaveAnimatedPNG": SaveAnimatedPNG,
195
+ }
comfy_extras/nodes_ip2p.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ class InstructPixToPixConditioning:
4
+ @classmethod
5
+ def INPUT_TYPES(s):
6
+ return {"required": {"positive": ("CONDITIONING", ),
7
+ "negative": ("CONDITIONING", ),
8
+ "vae": ("VAE", ),
9
+ "pixels": ("IMAGE", ),
10
+ }}
11
+
12
+ RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT")
13
+ RETURN_NAMES = ("positive", "negative", "latent")
14
+ FUNCTION = "encode"
15
+
16
+ CATEGORY = "conditioning/instructpix2pix"
17
+
18
+ def encode(self, positive, negative, pixels, vae):
19
+ x = (pixels.shape[1] // 8) * 8
20
+ y = (pixels.shape[2] // 8) * 8
21
+
22
+ if pixels.shape[1] != x or pixels.shape[2] != y:
23
+ x_offset = (pixels.shape[1] % 8) // 2
24
+ y_offset = (pixels.shape[2] % 8) // 2
25
+ pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
26
+
27
+ concat_latent = vae.encode(pixels)
28
+
29
+ out_latent = {}
30
+ out_latent["samples"] = torch.zeros_like(concat_latent)
31
+
32
+ out = []
33
+ for conditioning in [positive, negative]:
34
+ c = []
35
+ for t in conditioning:
36
+ d = t[1].copy()
37
+ d["concat_latent_image"] = concat_latent
38
+ n = [t[0], d]
39
+ c.append(n)
40
+ out.append(c)
41
+ return (out[0], out[1], out_latent)
42
+
43
+ NODE_CLASS_MAPPINGS = {
44
+ "InstructPixToPixConditioning": InstructPixToPixConditioning,
45
+ }
comfy_extras/nodes_latent.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import comfy.utils
2
+ import comfy_extras.nodes_post_processing
3
+ import torch
4
+
5
+
6
+ def reshape_latent_to(target_shape, latent, repeat_batch=True):
7
+ if latent.shape[1:] != target_shape[1:]:
8
+ latent = comfy.utils.common_upscale(latent, target_shape[-1], target_shape[-2], "bilinear", "center")
9
+ if repeat_batch:
10
+ return comfy.utils.repeat_to_batch_size(latent, target_shape[0])
11
+ else:
12
+ return latent
13
+
14
+
15
+ class LatentAdd:
16
+ @classmethod
17
+ def INPUT_TYPES(s):
18
+ return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}}
19
+
20
+ RETURN_TYPES = ("LATENT",)
21
+ FUNCTION = "op"
22
+
23
+ CATEGORY = "latent/advanced"
24
+
25
+ def op(self, samples1, samples2):
26
+ samples_out = samples1.copy()
27
+
28
+ s1 = samples1["samples"]
29
+ s2 = samples2["samples"]
30
+
31
+ s2 = reshape_latent_to(s1.shape, s2)
32
+ samples_out["samples"] = s1 + s2
33
+ return (samples_out,)
34
+
35
+ class LatentSubtract:
36
+ @classmethod
37
+ def INPUT_TYPES(s):
38
+ return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}}
39
+
40
+ RETURN_TYPES = ("LATENT",)
41
+ FUNCTION = "op"
42
+
43
+ CATEGORY = "latent/advanced"
44
+
45
+ def op(self, samples1, samples2):
46
+ samples_out = samples1.copy()
47
+
48
+ s1 = samples1["samples"]
49
+ s2 = samples2["samples"]
50
+
51
+ s2 = reshape_latent_to(s1.shape, s2)
52
+ samples_out["samples"] = s1 - s2
53
+ return (samples_out,)
54
+
55
+ class LatentMultiply:
56
+ @classmethod
57
+ def INPUT_TYPES(s):
58
+ return {"required": { "samples": ("LATENT",),
59
+ "multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
60
+ }}
61
+
62
+ RETURN_TYPES = ("LATENT",)
63
+ FUNCTION = "op"
64
+
65
+ CATEGORY = "latent/advanced"
66
+
67
+ def op(self, samples, multiplier):
68
+ samples_out = samples.copy()
69
+
70
+ s1 = samples["samples"]
71
+ samples_out["samples"] = s1 * multiplier
72
+ return (samples_out,)
73
+
74
+ class LatentInterpolate:
75
+ @classmethod
76
+ def INPUT_TYPES(s):
77
+ return {"required": { "samples1": ("LATENT",),
78
+ "samples2": ("LATENT",),
79
+ "ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
80
+ }}
81
+
82
+ RETURN_TYPES = ("LATENT",)
83
+ FUNCTION = "op"
84
+
85
+ CATEGORY = "latent/advanced"
86
+
87
+ def op(self, samples1, samples2, ratio):
88
+ samples_out = samples1.copy()
89
+
90
+ s1 = samples1["samples"]
91
+ s2 = samples2["samples"]
92
+
93
+ s2 = reshape_latent_to(s1.shape, s2)
94
+
95
+ m1 = torch.linalg.vector_norm(s1, dim=(1))
96
+ m2 = torch.linalg.vector_norm(s2, dim=(1))
97
+
98
+ s1 = torch.nan_to_num(s1 / m1)
99
+ s2 = torch.nan_to_num(s2 / m2)
100
+
101
+ t = (s1 * ratio + s2 * (1.0 - ratio))
102
+ mt = torch.linalg.vector_norm(t, dim=(1))
103
+ st = torch.nan_to_num(t / mt)
104
+
105
+ samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
106
+ return (samples_out,)
107
+
108
+ class LatentBatch:
109
+ @classmethod
110
+ def INPUT_TYPES(s):
111
+ return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}}
112
+
113
+ RETURN_TYPES = ("LATENT",)
114
+ FUNCTION = "batch"
115
+
116
+ CATEGORY = "latent/batch"
117
+
118
+ def batch(self, samples1, samples2):
119
+ samples_out = samples1.copy()
120
+ s1 = samples1["samples"]
121
+ s2 = samples2["samples"]
122
+
123
+ s2 = reshape_latent_to(s1.shape, s2, repeat_batch=False)
124
+ s = torch.cat((s1, s2), dim=0)
125
+ samples_out["samples"] = s
126
+ samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])])
127
+ return (samples_out,)
128
+
129
+ class LatentBatchSeedBehavior:
130
+ @classmethod
131
+ def INPUT_TYPES(s):
132
+ return {"required": { "samples": ("LATENT",),
133
+ "seed_behavior": (["random", "fixed"],{"default": "fixed"}),}}
134
+
135
+ RETURN_TYPES = ("LATENT",)
136
+ FUNCTION = "op"
137
+
138
+ CATEGORY = "latent/advanced"
139
+
140
+ def op(self, samples, seed_behavior):
141
+ samples_out = samples.copy()
142
+ latent = samples["samples"]
143
+ if seed_behavior == "random":
144
+ if 'batch_index' in samples_out:
145
+ samples_out.pop('batch_index')
146
+ elif seed_behavior == "fixed":
147
+ batch_number = samples_out.get("batch_index", [0])[0]
148
+ samples_out["batch_index"] = [batch_number] * latent.shape[0]
149
+
150
+ return (samples_out,)
151
+
152
+ class LatentApplyOperation:
153
+ @classmethod
154
+ def INPUT_TYPES(s):
155
+ return {"required": { "samples": ("LATENT",),
156
+ "operation": ("LATENT_OPERATION",),
157
+ }}
158
+
159
+ RETURN_TYPES = ("LATENT",)
160
+ FUNCTION = "op"
161
+
162
+ CATEGORY = "latent/advanced/operations"
163
+ EXPERIMENTAL = True
164
+
165
+ def op(self, samples, operation):
166
+ samples_out = samples.copy()
167
+
168
+ s1 = samples["samples"]
169
+ samples_out["samples"] = operation(latent=s1)
170
+ return (samples_out,)
171
+
172
+ class LatentApplyOperationCFG:
173
+ @classmethod
174
+ def INPUT_TYPES(s):
175
+ return {"required": { "model": ("MODEL",),
176
+ "operation": ("LATENT_OPERATION",),
177
+ }}
178
+ RETURN_TYPES = ("MODEL",)
179
+ FUNCTION = "patch"
180
+
181
+ CATEGORY = "latent/advanced/operations"
182
+ EXPERIMENTAL = True
183
+
184
+ def patch(self, model, operation):
185
+ m = model.clone()
186
+
187
+ def pre_cfg_function(args):
188
+ conds_out = args["conds_out"]
189
+ if len(conds_out) == 2:
190
+ conds_out[0] = operation(latent=(conds_out[0] - conds_out[1])) + conds_out[1]
191
+ else:
192
+ conds_out[0] = operation(latent=conds_out[0])
193
+ return conds_out
194
+
195
+ m.set_model_sampler_pre_cfg_function(pre_cfg_function)
196
+ return (m, )
197
+
198
+ class LatentOperationTonemapReinhard:
199
+ @classmethod
200
+ def INPUT_TYPES(s):
201
+ return {"required": { "multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
202
+ }}
203
+
204
+ RETURN_TYPES = ("LATENT_OPERATION",)
205
+ FUNCTION = "op"
206
+
207
+ CATEGORY = "latent/advanced/operations"
208
+ EXPERIMENTAL = True
209
+
210
+ def op(self, multiplier):
211
+ def tonemap_reinhard(latent, **kwargs):
212
+ latent_vector_magnitude = (torch.linalg.vector_norm(latent, dim=(1)) + 0.0000000001)[:,None]
213
+ normalized_latent = latent / latent_vector_magnitude
214
+
215
+ mean = torch.mean(latent_vector_magnitude, dim=(1,2,3), keepdim=True)
216
+ std = torch.std(latent_vector_magnitude, dim=(1,2,3), keepdim=True)
217
+
218
+ top = (std * 5 + mean) * multiplier
219
+
220
+ #reinhard
221
+ latent_vector_magnitude *= (1.0 / top)
222
+ new_magnitude = latent_vector_magnitude / (latent_vector_magnitude + 1.0)
223
+ new_magnitude *= top
224
+
225
+ return normalized_latent * new_magnitude
226
+ return (tonemap_reinhard,)
227
+
228
+ class LatentOperationSharpen:
229
+ @classmethod
230
+ def INPUT_TYPES(s):
231
+ return {"required": {
232
+ "sharpen_radius": ("INT", {
233
+ "default": 9,
234
+ "min": 1,
235
+ "max": 31,
236
+ "step": 1
237
+ }),
238
+ "sigma": ("FLOAT", {
239
+ "default": 1.0,
240
+ "min": 0.1,
241
+ "max": 10.0,
242
+ "step": 0.1
243
+ }),
244
+ "alpha": ("FLOAT", {
245
+ "default": 0.1,
246
+ "min": 0.0,
247
+ "max": 5.0,
248
+ "step": 0.01
249
+ }),
250
+ }}
251
+
252
+ RETURN_TYPES = ("LATENT_OPERATION",)
253
+ FUNCTION = "op"
254
+
255
+ CATEGORY = "latent/advanced/operations"
256
+ EXPERIMENTAL = True
257
+
258
+ def op(self, sharpen_radius, sigma, alpha):
259
+ def sharpen(latent, **kwargs):
260
+ luminance = (torch.linalg.vector_norm(latent, dim=(1)) + 1e-6)[:,None]
261
+ normalized_latent = latent / luminance
262
+ channels = latent.shape[1]
263
+
264
+ kernel_size = sharpen_radius * 2 + 1
265
+ kernel = comfy_extras.nodes_post_processing.gaussian_kernel(kernel_size, sigma, device=luminance.device)
266
+ center = kernel_size // 2
267
+
268
+ kernel *= alpha * -10
269
+ kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
270
+
271
+ padded_image = torch.nn.functional.pad(normalized_latent, (sharpen_radius,sharpen_radius,sharpen_radius,sharpen_radius), 'reflect')
272
+ sharpened = torch.nn.functional.conv2d(padded_image, kernel.repeat(channels, 1, 1).unsqueeze(1), padding=kernel_size // 2, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius]
273
+
274
+ return luminance * sharpened
275
+ return (sharpen,)
276
+
277
+ NODE_CLASS_MAPPINGS = {
278
+ "LatentAdd": LatentAdd,
279
+ "LatentSubtract": LatentSubtract,
280
+ "LatentMultiply": LatentMultiply,
281
+ "LatentInterpolate": LatentInterpolate,
282
+ "LatentBatch": LatentBatch,
283
+ "LatentBatchSeedBehavior": LatentBatchSeedBehavior,
284
+ "LatentApplyOperation": LatentApplyOperation,
285
+ "LatentApplyOperationCFG": LatentApplyOperationCFG,
286
+ "LatentOperationTonemapReinhard": LatentOperationTonemapReinhard,
287
+ "LatentOperationSharpen": LatentOperationSharpen,
288
+ }
comfy_extras/nodes_load_3d.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import nodes
2
+ import folder_paths
3
+ import os
4
+
5
+ def normalize_path(path):
6
+ return path.replace('\\', '/')
7
+
8
+ class Load3D():
9
+ @classmethod
10
+ def INPUT_TYPES(s):
11
+ input_dir = os.path.join(folder_paths.get_input_directory(), "3d")
12
+
13
+ os.makedirs(input_dir, exist_ok=True)
14
+
15
+ files = [normalize_path(os.path.join("3d", f)) for f in os.listdir(input_dir) if f.endswith(('.gltf', '.glb', '.obj', '.mtl', '.fbx', '.stl'))]
16
+
17
+ return {"required": {
18
+ "model_file": (sorted(files), {"file_upload": True}),
19
+ "image": ("LOAD_3D", {}),
20
+ "width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
21
+ "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
22
+ "material": (["original", "normal", "wireframe", "depth"],),
23
+ "light_intensity": ("INT", {"default": 10, "min": 1, "max": 20, "step": 1}),
24
+ "up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
25
+ "fov": ("INT", {"default": 75, "min": 10, "max": 150, "step": 1}),
26
+ }}
27
+
28
+ RETURN_TYPES = ("IMAGE", "MASK", "STRING")
29
+ RETURN_NAMES = ("image", "mask", "mesh_path")
30
+
31
+ FUNCTION = "process"
32
+ EXPERIMENTAL = True
33
+
34
+ CATEGORY = "3d"
35
+
36
+ def process(self, model_file, image, **kwargs):
37
+ if isinstance(image, dict):
38
+ image_path = folder_paths.get_annotated_filepath(image['image'])
39
+ mask_path = folder_paths.get_annotated_filepath(image['mask'])
40
+
41
+ load_image_node = nodes.LoadImage()
42
+ output_image, ignore_mask = load_image_node.load_image(image=image_path)
43
+ ignore_image, output_mask = load_image_node.load_image(image=mask_path)
44
+
45
+ return output_image, output_mask, model_file,
46
+ else:
47
+ # to avoid the format is not dict which will happen the FE code is not compatibility to core,
48
+ # we need to this to double-check, it can be removed after merged FE into the core
49
+ image_path = folder_paths.get_annotated_filepath(image)
50
+ load_image_node = nodes.LoadImage()
51
+ output_image, output_mask = load_image_node.load_image(image=image_path)
52
+ return output_image, output_mask, model_file,
53
+
54
+ class Load3DAnimation():
55
+ @classmethod
56
+ def INPUT_TYPES(s):
57
+ input_dir = os.path.join(folder_paths.get_input_directory(), "3d")
58
+
59
+ os.makedirs(input_dir, exist_ok=True)
60
+
61
+ files = [normalize_path(os.path.join("3d", f)) for f in os.listdir(input_dir) if f.endswith(('.gltf', '.glb', '.fbx'))]
62
+
63
+ return {"required": {
64
+ "model_file": (sorted(files), {"file_upload": True}),
65
+ "image": ("LOAD_3D_ANIMATION", {}),
66
+ "width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
67
+ "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
68
+ "material": (["original", "normal", "wireframe", "depth"],),
69
+ "light_intensity": ("INT", {"default": 10, "min": 1, "max": 20, "step": 1}),
70
+ "up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
71
+ "fov": ("INT", {"default": 75, "min": 10, "max": 150, "step": 1}),
72
+ }}
73
+
74
+ RETURN_TYPES = ("IMAGE", "MASK", "STRING")
75
+ RETURN_NAMES = ("image", "mask", "mesh_path")
76
+
77
+ FUNCTION = "process"
78
+ EXPERIMENTAL = True
79
+
80
+ CATEGORY = "3d"
81
+
82
+ def process(self, model_file, image, **kwargs):
83
+ if isinstance(image, dict):
84
+ image_path = folder_paths.get_annotated_filepath(image['image'])
85
+ mask_path = folder_paths.get_annotated_filepath(image['mask'])
86
+
87
+ load_image_node = nodes.LoadImage()
88
+ output_image, ignore_mask = load_image_node.load_image(image=image_path)
89
+ ignore_image, output_mask = load_image_node.load_image(image=mask_path)
90
+
91
+ return output_image, output_mask, model_file,
92
+ else:
93
+ image_path = folder_paths.get_annotated_filepath(image)
94
+ load_image_node = nodes.LoadImage()
95
+ output_image, output_mask = load_image_node.load_image(image=image_path)
96
+ return output_image, output_mask, model_file,
97
+
98
+ class Preview3D():
99
+ @classmethod
100
+ def INPUT_TYPES(s):
101
+ return {"required": {
102
+ "model_file": ("STRING", {"default": "", "multiline": False}),
103
+ "material": (["original", "normal", "wireframe", "depth"],),
104
+ "light_intensity": ("INT", {"default": 10, "min": 1, "max": 20, "step": 1}),
105
+ "up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
106
+ "fov": ("INT", {"default": 75, "min": 10, "max": 150, "step": 1}),
107
+ }}
108
+
109
+ OUTPUT_NODE = True
110
+ RETURN_TYPES = ()
111
+
112
+ CATEGORY = "3d"
113
+
114
+ FUNCTION = "process"
115
+ EXPERIMENTAL = True
116
+
117
+ def process(self, model_file, **kwargs):
118
+ return {"ui": {"model_file": [model_file]}, "result": ()}
119
+
120
+ class Preview3DAnimation():
121
+ @classmethod
122
+ def INPUT_TYPES(s):
123
+ return {"required": {
124
+ "model_file": ("STRING", {"default": "", "multiline": False}),
125
+ "material": (["original", "normal", "wireframe", "depth"],),
126
+ "light_intensity": ("INT", {"default": 10, "min": 1, "max": 20, "step": 1}),
127
+ "up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
128
+ "fov": ("INT", {"default": 75, "min": 10, "max": 150, "step": 1}),
129
+ }}
130
+
131
+ OUTPUT_NODE = True
132
+ RETURN_TYPES = ()
133
+
134
+ CATEGORY = "3d"
135
+
136
+ FUNCTION = "process"
137
+ EXPERIMENTAL = True
138
+
139
+ def process(self, model_file, **kwargs):
140
+ return {"ui": {"model_file": [model_file]}, "result": ()}
141
+
142
+ NODE_CLASS_MAPPINGS = {
143
+ "Load3D": Load3D,
144
+ "Load3DAnimation": Load3DAnimation,
145
+ "Preview3D": Preview3D,
146
+ "Preview3DAnimation": Preview3DAnimation
147
+ }
148
+
149
+ NODE_DISPLAY_NAME_MAPPINGS = {
150
+ "Load3D": "Load 3D",
151
+ "Load3DAnimation": "Load 3D - Animation",
152
+ "Preview3D": "Preview 3D",
153
+ "Preview3DAnimation": "Preview 3D - Animation"
154
+ }
comfy_extras/nodes_lora_extract.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import comfy.model_management
3
+ import comfy.utils
4
+ import folder_paths
5
+ import os
6
+ import logging
7
+ from enum import Enum
8
+
9
+ CLAMP_QUANTILE = 0.99
10
+
11
+ def extract_lora(diff, rank):
12
+ conv2d = (len(diff.shape) == 4)
13
+ kernel_size = None if not conv2d else diff.size()[2:4]
14
+ conv2d_3x3 = conv2d and kernel_size != (1, 1)
15
+ out_dim, in_dim = diff.size()[0:2]
16
+ rank = min(rank, in_dim, out_dim)
17
+
18
+ if conv2d:
19
+ if conv2d_3x3:
20
+ diff = diff.flatten(start_dim=1)
21
+ else:
22
+ diff = diff.squeeze()
23
+
24
+
25
+ U, S, Vh = torch.linalg.svd(diff.float())
26
+ U = U[:, :rank]
27
+ S = S[:rank]
28
+ U = U @ torch.diag(S)
29
+ Vh = Vh[:rank, :]
30
+
31
+ dist = torch.cat([U.flatten(), Vh.flatten()])
32
+ hi_val = torch.quantile(dist, CLAMP_QUANTILE)
33
+ low_val = -hi_val
34
+
35
+ U = U.clamp(low_val, hi_val)
36
+ Vh = Vh.clamp(low_val, hi_val)
37
+ if conv2d:
38
+ U = U.reshape(out_dim, rank, 1, 1)
39
+ Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
40
+ return (U, Vh)
41
+
42
+ class LORAType(Enum):
43
+ STANDARD = 0
44
+ FULL_DIFF = 1
45
+
46
+ LORA_TYPES = {"standard": LORAType.STANDARD,
47
+ "full_diff": LORAType.FULL_DIFF}
48
+
49
+ def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, bias_diff=False):
50
+ comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True)
51
+ sd = model_diff.model_state_dict(filter_prefix=prefix_model)
52
+
53
+ for k in sd:
54
+ if k.endswith(".weight"):
55
+ weight_diff = sd[k]
56
+ if lora_type == LORAType.STANDARD:
57
+ if weight_diff.ndim < 2:
58
+ if bias_diff:
59
+ output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
60
+ continue
61
+ try:
62
+ out = extract_lora(weight_diff, rank)
63
+ output_sd["{}{}.lora_up.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[0].contiguous().half().cpu()
64
+ output_sd["{}{}.lora_down.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[1].contiguous().half().cpu()
65
+ except:
66
+ logging.warning("Could not generate lora weights for key {}, is the weight difference a zero?".format(k))
67
+ elif lora_type == LORAType.FULL_DIFF:
68
+ output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
69
+
70
+ elif bias_diff and k.endswith(".bias"):
71
+ output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu()
72
+ return output_sd
73
+
74
+ class LoraSave:
75
+ def __init__(self):
76
+ self.output_dir = folder_paths.get_output_directory()
77
+
78
+ @classmethod
79
+ def INPUT_TYPES(s):
80
+ return {"required": {"filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}),
81
+ "rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1}),
82
+ "lora_type": (tuple(LORA_TYPES.keys()),),
83
+ "bias_diff": ("BOOLEAN", {"default": True}),
84
+ },
85
+ "optional": {"model_diff": ("MODEL", {"tooltip": "The ModelSubtract output to be converted to a lora."}),
86
+ "text_encoder_diff": ("CLIP", {"tooltip": "The CLIPSubtract output to be converted to a lora."})},
87
+ }
88
+ RETURN_TYPES = ()
89
+ FUNCTION = "save"
90
+ OUTPUT_NODE = True
91
+
92
+ CATEGORY = "_for_testing"
93
+
94
+ def save(self, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None):
95
+ if model_diff is None and text_encoder_diff is None:
96
+ return {}
97
+
98
+ lora_type = LORA_TYPES.get(lora_type)
99
+ full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
100
+
101
+ output_sd = {}
102
+ if model_diff is not None:
103
+ output_sd = calc_lora_model(model_diff, rank, "diffusion_model.", "diffusion_model.", output_sd, lora_type, bias_diff=bias_diff)
104
+ if text_encoder_diff is not None:
105
+ output_sd = calc_lora_model(text_encoder_diff.patcher, rank, "", "text_encoders.", output_sd, lora_type, bias_diff=bias_diff)
106
+
107
+ output_checkpoint = f"{filename}_{counter:05}_.safetensors"
108
+ output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
109
+
110
+ comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
111
+ return {}
112
+
113
+ NODE_CLASS_MAPPINGS = {
114
+ "LoraSave": LoraSave
115
+ }
116
+
117
+ NODE_DISPLAY_NAME_MAPPINGS = {
118
+ "LoraSave": "Extract and Save Lora"
119
+ }
comfy_extras/nodes_lt.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import nodes
2
+ import node_helpers
3
+ import torch
4
+ import comfy.model_management
5
+ import comfy.model_sampling
6
+ import math
7
+
8
+ class EmptyLTXVLatentVideo:
9
+ @classmethod
10
+ def INPUT_TYPES(s):
11
+ return {"required": { "width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
12
+ "height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
13
+ "length": ("INT", {"default": 97, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}),
14
+ "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
15
+ RETURN_TYPES = ("LATENT",)
16
+ FUNCTION = "generate"
17
+
18
+ CATEGORY = "latent/video/ltxv"
19
+
20
+ def generate(self, width, height, length, batch_size=1):
21
+ latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device())
22
+ return ({"samples": latent}, )
23
+
24
+
25
+ class LTXVImgToVideo:
26
+ @classmethod
27
+ def INPUT_TYPES(s):
28
+ return {"required": {"positive": ("CONDITIONING", ),
29
+ "negative": ("CONDITIONING", ),
30
+ "vae": ("VAE",),
31
+ "image": ("IMAGE",),
32
+ "width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
33
+ "height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
34
+ "length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}),
35
+ "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
36
+ "image_noise_scale": ("FLOAT", {"default": 0.15, "min": 0, "max": 1.0, "step": 0.01, "tooltip": "Amount of noise to apply on conditioning image latent."})
37
+ }}
38
+
39
+ RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
40
+ RETURN_NAMES = ("positive", "negative", "latent")
41
+
42
+ CATEGORY = "conditioning/video_models"
43
+ FUNCTION = "generate"
44
+
45
+ def generate(self, positive, negative, image, vae, width, height, length, batch_size, image_noise_scale):
46
+ pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
47
+ encode_pixels = pixels[:, :, :, :3]
48
+ t = vae.encode(encode_pixels)
49
+ positive = node_helpers.conditioning_set_values(positive, {"guiding_latent": t, "guiding_latent_noise_scale": image_noise_scale})
50
+ negative = node_helpers.conditioning_set_values(negative, {"guiding_latent": t, "guiding_latent_noise_scale": image_noise_scale})
51
+
52
+ latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device())
53
+ latent[:, :, :t.shape[2]] = t
54
+ return (positive, negative, {"samples": latent}, )
55
+
56
+
57
+ class LTXVConditioning:
58
+ @classmethod
59
+ def INPUT_TYPES(s):
60
+ return {"required": {"positive": ("CONDITIONING", ),
61
+ "negative": ("CONDITIONING", ),
62
+ "frame_rate": ("FLOAT", {"default": 25.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
63
+ }}
64
+ RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
65
+ RETURN_NAMES = ("positive", "negative")
66
+ FUNCTION = "append"
67
+
68
+ CATEGORY = "conditioning/video_models"
69
+
70
+ def append(self, positive, negative, frame_rate):
71
+ positive = node_helpers.conditioning_set_values(positive, {"frame_rate": frame_rate})
72
+ negative = node_helpers.conditioning_set_values(negative, {"frame_rate": frame_rate})
73
+ return (positive, negative)
74
+
75
+
76
+ class ModelSamplingLTXV:
77
+ @classmethod
78
+ def INPUT_TYPES(s):
79
+ return {"required": { "model": ("MODEL",),
80
+ "max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step":0.01}),
81
+ "base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step":0.01}),
82
+ },
83
+ "optional": {"latent": ("LATENT",), }
84
+ }
85
+
86
+ RETURN_TYPES = ("MODEL",)
87
+ FUNCTION = "patch"
88
+
89
+ CATEGORY = "advanced/model"
90
+
91
+ def patch(self, model, max_shift, base_shift, latent=None):
92
+ m = model.clone()
93
+
94
+ if latent is None:
95
+ tokens = 4096
96
+ else:
97
+ tokens = math.prod(latent["samples"].shape[2:])
98
+
99
+ x1 = 1024
100
+ x2 = 4096
101
+ mm = (max_shift - base_shift) / (x2 - x1)
102
+ b = base_shift - mm * x1
103
+ shift = (tokens) * mm + b
104
+
105
+ sampling_base = comfy.model_sampling.ModelSamplingFlux
106
+ sampling_type = comfy.model_sampling.CONST
107
+
108
+ class ModelSamplingAdvanced(sampling_base, sampling_type):
109
+ pass
110
+
111
+ model_sampling = ModelSamplingAdvanced(model.model.model_config)
112
+ model_sampling.set_parameters(shift=shift)
113
+ m.add_object_patch("model_sampling", model_sampling)
114
+
115
+ return (m, )
116
+
117
+
118
+ class LTXVScheduler:
119
+ @classmethod
120
+ def INPUT_TYPES(s):
121
+ return {"required":
122
+ {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
123
+ "max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step":0.01}),
124
+ "base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step":0.01}),
125
+ "stretch": ("BOOLEAN", {
126
+ "default": True,
127
+ "tooltip": "Stretch the sigmas to be in the range [terminal, 1]."
128
+ }),
129
+ "terminal": (
130
+ "FLOAT",
131
+ {
132
+ "default": 0.1, "min": 0.0, "max": 0.99, "step": 0.01,
133
+ "tooltip": "The terminal value of the sigmas after stretching."
134
+ },
135
+ ),
136
+ },
137
+ "optional": {"latent": ("LATENT",), }
138
+ }
139
+
140
+ RETURN_TYPES = ("SIGMAS",)
141
+ CATEGORY = "sampling/custom_sampling/schedulers"
142
+
143
+ FUNCTION = "get_sigmas"
144
+
145
+ def get_sigmas(self, steps, max_shift, base_shift, stretch, terminal, latent=None):
146
+ if latent is None:
147
+ tokens = 4096
148
+ else:
149
+ tokens = math.prod(latent["samples"].shape[2:])
150
+
151
+ sigmas = torch.linspace(1.0, 0.0, steps + 1)
152
+
153
+ x1 = 1024
154
+ x2 = 4096
155
+ mm = (max_shift - base_shift) / (x2 - x1)
156
+ b = base_shift - mm * x1
157
+ sigma_shift = (tokens) * mm + b
158
+
159
+ power = 1
160
+ sigmas = torch.where(
161
+ sigmas != 0,
162
+ math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1) ** power),
163
+ 0,
164
+ )
165
+
166
+ # Stretch sigmas so that its final value matches the given terminal value.
167
+ if stretch:
168
+ non_zero_mask = sigmas != 0
169
+ non_zero_sigmas = sigmas[non_zero_mask]
170
+ one_minus_z = 1.0 - non_zero_sigmas
171
+ scale_factor = one_minus_z[-1] / (1.0 - terminal)
172
+ stretched = 1.0 - (one_minus_z / scale_factor)
173
+ sigmas[non_zero_mask] = stretched
174
+
175
+ return (sigmas,)
176
+
177
+
178
+ NODE_CLASS_MAPPINGS = {
179
+ "EmptyLTXVLatentVideo": EmptyLTXVLatentVideo,
180
+ "LTXVImgToVideo": LTXVImgToVideo,
181
+ "ModelSamplingLTXV": ModelSamplingLTXV,
182
+ "LTXVConditioning": LTXVConditioning,
183
+ "LTXVScheduler": LTXVScheduler,
184
+ }
comfy_extras/nodes_mahiro.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ class Mahiro:
5
+ @classmethod
6
+ def INPUT_TYPES(s):
7
+ return {"required": {"model": ("MODEL",),
8
+ }}
9
+ RETURN_TYPES = ("MODEL",)
10
+ RETURN_NAMES = ("patched_model",)
11
+ FUNCTION = "patch"
12
+ CATEGORY = "_for_testing"
13
+ DESCRIPTION = "Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt."
14
+ def patch(self, model):
15
+ m = model.clone()
16
+ def mahiro_normd(args):
17
+ scale: float = args['cond_scale']
18
+ cond_p: torch.Tensor = args['cond_denoised']
19
+ uncond_p: torch.Tensor = args['uncond_denoised']
20
+ #naive leap
21
+ leap = cond_p * scale
22
+ #sim with uncond leap
23
+ u_leap = uncond_p * scale
24
+ cfg = args["denoised"]
25
+ merge = (leap + cfg) / 2
26
+ normu = torch.sqrt(u_leap.abs()) * u_leap.sign()
27
+ normm = torch.sqrt(merge.abs()) * merge.sign()
28
+ sim = F.cosine_similarity(normu, normm).mean()
29
+ simsc = 2 * (sim+1)
30
+ wm = (simsc*cfg + (4-simsc)*leap) / 4
31
+ return wm
32
+ m.set_model_sampler_post_cfg_function(mahiro_normd)
33
+ return (m, )
34
+
35
+ NODE_CLASS_MAPPINGS = {
36
+ "Mahiro": Mahiro
37
+ }
38
+
39
+ NODE_DISPLAY_NAME_MAPPINGS = {
40
+ "Mahiro": "Mahiro is so cute that she deserves a better guidance function!! (。・ω・。)",
41
+ }
comfy_extras/nodes_mask.py ADDED
@@ -0,0 +1,382 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import scipy.ndimage
3
+ import torch
4
+ import comfy.utils
5
+
6
+ from nodes import MAX_RESOLUTION
7
+
8
+ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):
9
+ source = source.to(destination.device)
10
+ if resize_source:
11
+ source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear")
12
+
13
+ source = comfy.utils.repeat_to_batch_size(source, destination.shape[0])
14
+
15
+ x = max(-source.shape[3] * multiplier, min(x, destination.shape[3] * multiplier))
16
+ y = max(-source.shape[2] * multiplier, min(y, destination.shape[2] * multiplier))
17
+
18
+ left, top = (x // multiplier, y // multiplier)
19
+ right, bottom = (left + source.shape[3], top + source.shape[2],)
20
+
21
+ if mask is None:
22
+ mask = torch.ones_like(source)
23
+ else:
24
+ mask = mask.to(destination.device, copy=True)
25
+ mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear")
26
+ mask = comfy.utils.repeat_to_batch_size(mask, source.shape[0])
27
+
28
+ # calculate the bounds of the source that will be overlapping the destination
29
+ # this prevents the source trying to overwrite latent pixels that are out of bounds
30
+ # of the destination
31
+ visible_width, visible_height = (destination.shape[3] - left + min(0, x), destination.shape[2] - top + min(0, y),)
32
+
33
+ mask = mask[:, :, :visible_height, :visible_width]
34
+ inverse_mask = torch.ones_like(mask) - mask
35
+
36
+ source_portion = mask * source[:, :, :visible_height, :visible_width]
37
+ destination_portion = inverse_mask * destination[:, :, top:bottom, left:right]
38
+
39
+ destination[:, :, top:bottom, left:right] = source_portion + destination_portion
40
+ return destination
41
+
42
+ class LatentCompositeMasked:
43
+ @classmethod
44
+ def INPUT_TYPES(s):
45
+ return {
46
+ "required": {
47
+ "destination": ("LATENT",),
48
+ "source": ("LATENT",),
49
+ "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
50
+ "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
51
+ "resize_source": ("BOOLEAN", {"default": False}),
52
+ },
53
+ "optional": {
54
+ "mask": ("MASK",),
55
+ }
56
+ }
57
+ RETURN_TYPES = ("LATENT",)
58
+ FUNCTION = "composite"
59
+
60
+ CATEGORY = "latent"
61
+
62
+ def composite(self, destination, source, x, y, resize_source, mask = None):
63
+ output = destination.copy()
64
+ destination = destination["samples"].clone()
65
+ source = source["samples"]
66
+ output["samples"] = composite(destination, source, x, y, mask, 8, resize_source)
67
+ return (output,)
68
+
69
+ class ImageCompositeMasked:
70
+ @classmethod
71
+ def INPUT_TYPES(s):
72
+ return {
73
+ "required": {
74
+ "destination": ("IMAGE",),
75
+ "source": ("IMAGE",),
76
+ "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
77
+ "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
78
+ "resize_source": ("BOOLEAN", {"default": False}),
79
+ },
80
+ "optional": {
81
+ "mask": ("MASK",),
82
+ }
83
+ }
84
+ RETURN_TYPES = ("IMAGE",)
85
+ FUNCTION = "composite"
86
+
87
+ CATEGORY = "image"
88
+
89
+ def composite(self, destination, source, x, y, resize_source, mask = None):
90
+ destination = destination.clone().movedim(-1, 1)
91
+ output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1)
92
+ return (output,)
93
+
94
+ class MaskToImage:
95
+ @classmethod
96
+ def INPUT_TYPES(s):
97
+ return {
98
+ "required": {
99
+ "mask": ("MASK",),
100
+ }
101
+ }
102
+
103
+ CATEGORY = "mask"
104
+
105
+ RETURN_TYPES = ("IMAGE",)
106
+ FUNCTION = "mask_to_image"
107
+
108
+ def mask_to_image(self, mask):
109
+ result = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
110
+ return (result,)
111
+
112
+ class ImageToMask:
113
+ @classmethod
114
+ def INPUT_TYPES(s):
115
+ return {
116
+ "required": {
117
+ "image": ("IMAGE",),
118
+ "channel": (["red", "green", "blue", "alpha"],),
119
+ }
120
+ }
121
+
122
+ CATEGORY = "mask"
123
+
124
+ RETURN_TYPES = ("MASK",)
125
+ FUNCTION = "image_to_mask"
126
+
127
+ def image_to_mask(self, image, channel):
128
+ channels = ["red", "green", "blue", "alpha"]
129
+ mask = image[:, :, :, channels.index(channel)]
130
+ return (mask,)
131
+
132
+ class ImageColorToMask:
133
+ @classmethod
134
+ def INPUT_TYPES(s):
135
+ return {
136
+ "required": {
137
+ "image": ("IMAGE",),
138
+ "color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}),
139
+ }
140
+ }
141
+
142
+ CATEGORY = "mask"
143
+
144
+ RETURN_TYPES = ("MASK",)
145
+ FUNCTION = "image_to_mask"
146
+
147
+ def image_to_mask(self, image, color):
148
+ temp = (torch.clamp(image, 0, 1.0) * 255.0).round().to(torch.int)
149
+ temp = torch.bitwise_left_shift(temp[:,:,:,0], 16) + torch.bitwise_left_shift(temp[:,:,:,1], 8) + temp[:,:,:,2]
150
+ mask = torch.where(temp == color, 255, 0).float()
151
+ return (mask,)
152
+
153
+ class SolidMask:
154
+ @classmethod
155
+ def INPUT_TYPES(cls):
156
+ return {
157
+ "required": {
158
+ "value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
159
+ "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
160
+ "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
161
+ }
162
+ }
163
+
164
+ CATEGORY = "mask"
165
+
166
+ RETURN_TYPES = ("MASK",)
167
+
168
+ FUNCTION = "solid"
169
+
170
+ def solid(self, value, width, height):
171
+ out = torch.full((1, height, width), value, dtype=torch.float32, device="cpu")
172
+ return (out,)
173
+
174
+ class InvertMask:
175
+ @classmethod
176
+ def INPUT_TYPES(cls):
177
+ return {
178
+ "required": {
179
+ "mask": ("MASK",),
180
+ }
181
+ }
182
+
183
+ CATEGORY = "mask"
184
+
185
+ RETURN_TYPES = ("MASK",)
186
+
187
+ FUNCTION = "invert"
188
+
189
+ def invert(self, mask):
190
+ out = 1.0 - mask
191
+ return (out,)
192
+
193
+ class CropMask:
194
+ @classmethod
195
+ def INPUT_TYPES(cls):
196
+ return {
197
+ "required": {
198
+ "mask": ("MASK",),
199
+ "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
200
+ "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
201
+ "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
202
+ "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
203
+ }
204
+ }
205
+
206
+ CATEGORY = "mask"
207
+
208
+ RETURN_TYPES = ("MASK",)
209
+
210
+ FUNCTION = "crop"
211
+
212
+ def crop(self, mask, x, y, width, height):
213
+ mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1]))
214
+ out = mask[:, y:y + height, x:x + width]
215
+ return (out,)
216
+
217
+ class MaskComposite:
218
+ @classmethod
219
+ def INPUT_TYPES(cls):
220
+ return {
221
+ "required": {
222
+ "destination": ("MASK",),
223
+ "source": ("MASK",),
224
+ "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
225
+ "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
226
+ "operation": (["multiply", "add", "subtract", "and", "or", "xor"],),
227
+ }
228
+ }
229
+
230
+ CATEGORY = "mask"
231
+
232
+ RETURN_TYPES = ("MASK",)
233
+
234
+ FUNCTION = "combine"
235
+
236
+ def combine(self, destination, source, x, y, operation):
237
+ output = destination.reshape((-1, destination.shape[-2], destination.shape[-1])).clone()
238
+ source = source.reshape((-1, source.shape[-2], source.shape[-1]))
239
+
240
+ left, top = (x, y,)
241
+ right, bottom = (min(left + source.shape[-1], destination.shape[-1]), min(top + source.shape[-2], destination.shape[-2]))
242
+ visible_width, visible_height = (right - left, bottom - top,)
243
+
244
+ source_portion = source[:, :visible_height, :visible_width]
245
+ destination_portion = destination[:, top:bottom, left:right]
246
+
247
+ if operation == "multiply":
248
+ output[:, top:bottom, left:right] = destination_portion * source_portion
249
+ elif operation == "add":
250
+ output[:, top:bottom, left:right] = destination_portion + source_portion
251
+ elif operation == "subtract":
252
+ output[:, top:bottom, left:right] = destination_portion - source_portion
253
+ elif operation == "and":
254
+ output[:, top:bottom, left:right] = torch.bitwise_and(destination_portion.round().bool(), source_portion.round().bool()).float()
255
+ elif operation == "or":
256
+ output[:, top:bottom, left:right] = torch.bitwise_or(destination_portion.round().bool(), source_portion.round().bool()).float()
257
+ elif operation == "xor":
258
+ output[:, top:bottom, left:right] = torch.bitwise_xor(destination_portion.round().bool(), source_portion.round().bool()).float()
259
+
260
+ output = torch.clamp(output, 0.0, 1.0)
261
+
262
+ return (output,)
263
+
264
+ class FeatherMask:
265
+ @classmethod
266
+ def INPUT_TYPES(cls):
267
+ return {
268
+ "required": {
269
+ "mask": ("MASK",),
270
+ "left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
271
+ "top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
272
+ "right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
273
+ "bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
274
+ }
275
+ }
276
+
277
+ CATEGORY = "mask"
278
+
279
+ RETURN_TYPES = ("MASK",)
280
+
281
+ FUNCTION = "feather"
282
+
283
+ def feather(self, mask, left, top, right, bottom):
284
+ output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone()
285
+
286
+ left = min(left, output.shape[-1])
287
+ right = min(right, output.shape[-1])
288
+ top = min(top, output.shape[-2])
289
+ bottom = min(bottom, output.shape[-2])
290
+
291
+ for x in range(left):
292
+ feather_rate = (x + 1.0) / left
293
+ output[:, :, x] *= feather_rate
294
+
295
+ for x in range(right):
296
+ feather_rate = (x + 1) / right
297
+ output[:, :, -x] *= feather_rate
298
+
299
+ for y in range(top):
300
+ feather_rate = (y + 1) / top
301
+ output[:, y, :] *= feather_rate
302
+
303
+ for y in range(bottom):
304
+ feather_rate = (y + 1) / bottom
305
+ output[:, -y, :] *= feather_rate
306
+
307
+ return (output,)
308
+
309
+ class GrowMask:
310
+ @classmethod
311
+ def INPUT_TYPES(cls):
312
+ return {
313
+ "required": {
314
+ "mask": ("MASK",),
315
+ "expand": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION, "step": 1}),
316
+ "tapered_corners": ("BOOLEAN", {"default": True}),
317
+ },
318
+ }
319
+
320
+ CATEGORY = "mask"
321
+
322
+ RETURN_TYPES = ("MASK",)
323
+
324
+ FUNCTION = "expand_mask"
325
+
326
+ def expand_mask(self, mask, expand, tapered_corners):
327
+ c = 0 if tapered_corners else 1
328
+ kernel = np.array([[c, 1, c],
329
+ [1, 1, 1],
330
+ [c, 1, c]])
331
+ mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1]))
332
+ out = []
333
+ for m in mask:
334
+ output = m.numpy()
335
+ for _ in range(abs(expand)):
336
+ if expand < 0:
337
+ output = scipy.ndimage.grey_erosion(output, footprint=kernel)
338
+ else:
339
+ output = scipy.ndimage.grey_dilation(output, footprint=kernel)
340
+ output = torch.from_numpy(output)
341
+ out.append(output)
342
+ return (torch.stack(out, dim=0),)
343
+
344
+ class ThresholdMask:
345
+ @classmethod
346
+ def INPUT_TYPES(s):
347
+ return {
348
+ "required": {
349
+ "mask": ("MASK",),
350
+ "value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
351
+ }
352
+ }
353
+
354
+ CATEGORY = "mask"
355
+
356
+ RETURN_TYPES = ("MASK",)
357
+ FUNCTION = "image_to_mask"
358
+
359
+ def image_to_mask(self, mask, value):
360
+ mask = (mask > value).float()
361
+ return (mask,)
362
+
363
+
364
+ NODE_CLASS_MAPPINGS = {
365
+ "LatentCompositeMasked": LatentCompositeMasked,
366
+ "ImageCompositeMasked": ImageCompositeMasked,
367
+ "MaskToImage": MaskToImage,
368
+ "ImageToMask": ImageToMask,
369
+ "ImageColorToMask": ImageColorToMask,
370
+ "SolidMask": SolidMask,
371
+ "InvertMask": InvertMask,
372
+ "CropMask": CropMask,
373
+ "MaskComposite": MaskComposite,
374
+ "FeatherMask": FeatherMask,
375
+ "GrowMask": GrowMask,
376
+ "ThresholdMask": ThresholdMask,
377
+ }
378
+
379
+ NODE_DISPLAY_NAME_MAPPINGS = {
380
+ "ImageToMask": "Convert Image to Mask",
381
+ "MaskToImage": "Convert Mask to Image",
382
+ }
comfy_extras/nodes_mochi.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import nodes
2
+ import torch
3
+ import comfy.model_management
4
+
5
+ class EmptyMochiLatentVideo:
6
+ @classmethod
7
+ def INPUT_TYPES(s):
8
+ return {"required": { "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
9
+ "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
10
+ "length": ("INT", {"default": 25, "min": 7, "max": nodes.MAX_RESOLUTION, "step": 6}),
11
+ "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
12
+ RETURN_TYPES = ("LATENT",)
13
+ FUNCTION = "generate"
14
+
15
+ CATEGORY = "latent/video"
16
+
17
+ def generate(self, width, height, length, batch_size=1):
18
+ latent = torch.zeros([batch_size, 12, ((length - 1) // 6) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
19
+ return ({"samples":latent}, )
20
+
21
+ NODE_CLASS_MAPPINGS = {
22
+ "EmptyMochiLatentVideo": EmptyMochiLatentVideo,
23
+ }
comfy_extras/nodes_model_advanced.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import comfy.sd
2
+ import comfy.model_sampling
3
+ import comfy.latent_formats
4
+ import nodes
5
+ import torch
6
+
7
+ class LCM(comfy.model_sampling.EPS):
8
+ def calculate_denoised(self, sigma, model_output, model_input):
9
+ timestep = self.timestep(sigma).view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
10
+ sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
11
+ x0 = model_input - model_output * sigma
12
+
13
+ sigma_data = 0.5
14
+ scaled_timestep = timestep * 10.0 #timestep_scaling
15
+
16
+ c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
17
+ c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
18
+
19
+ return c_out * x0 + c_skip * model_input
20
+
21
+ class X0(comfy.model_sampling.EPS):
22
+ def calculate_denoised(self, sigma, model_output, model_input):
23
+ return model_output
24
+
25
+ class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete):
26
+ original_timesteps = 50
27
+
28
+ def __init__(self, model_config=None, zsnr=None):
29
+ super().__init__(model_config, zsnr=zsnr)
30
+
31
+ self.skip_steps = self.num_timesteps // self.original_timesteps
32
+
33
+ sigmas_valid = torch.zeros((self.original_timesteps), dtype=torch.float32)
34
+ for x in range(self.original_timesteps):
35
+ sigmas_valid[self.original_timesteps - 1 - x] = self.sigmas[self.num_timesteps - 1 - x * self.skip_steps]
36
+
37
+ self.set_sigmas(sigmas_valid)
38
+
39
+ def timestep(self, sigma):
40
+ log_sigma = sigma.log()
41
+ dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
42
+ return (dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)).to(sigma.device)
43
+
44
+ def sigma(self, timestep):
45
+ t = torch.clamp(((timestep.float().to(self.log_sigmas.device) - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1))
46
+ low_idx = t.floor().long()
47
+ high_idx = t.ceil().long()
48
+ w = t.frac()
49
+ log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
50
+ return log_sigma.exp().to(timestep.device)
51
+
52
+
53
+ class ModelSamplingDiscrete:
54
+ @classmethod
55
+ def INPUT_TYPES(s):
56
+ return {"required": { "model": ("MODEL",),
57
+ "sampling": (["eps", "v_prediction", "lcm", "x0"],),
58
+ "zsnr": ("BOOLEAN", {"default": False}),
59
+ }}
60
+
61
+ RETURN_TYPES = ("MODEL",)
62
+ FUNCTION = "patch"
63
+
64
+ CATEGORY = "advanced/model"
65
+
66
+ def patch(self, model, sampling, zsnr):
67
+ m = model.clone()
68
+
69
+ sampling_base = comfy.model_sampling.ModelSamplingDiscrete
70
+ if sampling == "eps":
71
+ sampling_type = comfy.model_sampling.EPS
72
+ elif sampling == "v_prediction":
73
+ sampling_type = comfy.model_sampling.V_PREDICTION
74
+ elif sampling == "lcm":
75
+ sampling_type = LCM
76
+ sampling_base = ModelSamplingDiscreteDistilled
77
+ elif sampling == "x0":
78
+ sampling_type = X0
79
+
80
+ class ModelSamplingAdvanced(sampling_base, sampling_type):
81
+ pass
82
+
83
+ model_sampling = ModelSamplingAdvanced(model.model.model_config, zsnr=zsnr)
84
+
85
+ m.add_object_patch("model_sampling", model_sampling)
86
+ return (m, )
87
+
88
+ class ModelSamplingStableCascade:
89
+ @classmethod
90
+ def INPUT_TYPES(s):
91
+ return {"required": { "model": ("MODEL",),
92
+ "shift": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 100.0, "step":0.01}),
93
+ }}
94
+
95
+ RETURN_TYPES = ("MODEL",)
96
+ FUNCTION = "patch"
97
+
98
+ CATEGORY = "advanced/model"
99
+
100
+ def patch(self, model, shift):
101
+ m = model.clone()
102
+
103
+ sampling_base = comfy.model_sampling.StableCascadeSampling
104
+ sampling_type = comfy.model_sampling.EPS
105
+
106
+ class ModelSamplingAdvanced(sampling_base, sampling_type):
107
+ pass
108
+
109
+ model_sampling = ModelSamplingAdvanced(model.model.model_config)
110
+ model_sampling.set_parameters(shift)
111
+ m.add_object_patch("model_sampling", model_sampling)
112
+ return (m, )
113
+
114
+ class ModelSamplingSD3:
115
+ @classmethod
116
+ def INPUT_TYPES(s):
117
+ return {"required": { "model": ("MODEL",),
118
+ "shift": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step":0.01}),
119
+ }}
120
+
121
+ RETURN_TYPES = ("MODEL",)
122
+ FUNCTION = "patch"
123
+
124
+ CATEGORY = "advanced/model"
125
+
126
+ def patch(self, model, shift, multiplier=1000):
127
+ m = model.clone()
128
+
129
+ sampling_base = comfy.model_sampling.ModelSamplingDiscreteFlow
130
+ sampling_type = comfy.model_sampling.CONST
131
+
132
+ class ModelSamplingAdvanced(sampling_base, sampling_type):
133
+ pass
134
+
135
+ model_sampling = ModelSamplingAdvanced(model.model.model_config)
136
+ model_sampling.set_parameters(shift=shift, multiplier=multiplier)
137
+ m.add_object_patch("model_sampling", model_sampling)
138
+ return (m, )
139
+
140
+ class ModelSamplingAuraFlow(ModelSamplingSD3):
141
+ @classmethod
142
+ def INPUT_TYPES(s):
143
+ return {"required": { "model": ("MODEL",),
144
+ "shift": ("FLOAT", {"default": 1.73, "min": 0.0, "max": 100.0, "step":0.01}),
145
+ }}
146
+
147
+ FUNCTION = "patch_aura"
148
+
149
+ def patch_aura(self, model, shift):
150
+ return self.patch(model, shift, multiplier=1.0)
151
+
152
+ class ModelSamplingFlux:
153
+ @classmethod
154
+ def INPUT_TYPES(s):
155
+ return {"required": { "model": ("MODEL",),
156
+ "max_shift": ("FLOAT", {"default": 1.15, "min": 0.0, "max": 100.0, "step":0.01}),
157
+ "base_shift": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 100.0, "step":0.01}),
158
+ "width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
159
+ "height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
160
+ }}
161
+
162
+ RETURN_TYPES = ("MODEL",)
163
+ FUNCTION = "patch"
164
+
165
+ CATEGORY = "advanced/model"
166
+
167
+ def patch(self, model, max_shift, base_shift, width, height):
168
+ m = model.clone()
169
+
170
+ x1 = 256
171
+ x2 = 4096
172
+ mm = (max_shift - base_shift) / (x2 - x1)
173
+ b = base_shift - mm * x1
174
+ shift = (width * height / (8 * 8 * 2 * 2)) * mm + b
175
+
176
+ sampling_base = comfy.model_sampling.ModelSamplingFlux
177
+ sampling_type = comfy.model_sampling.CONST
178
+
179
+ class ModelSamplingAdvanced(sampling_base, sampling_type):
180
+ pass
181
+
182
+ model_sampling = ModelSamplingAdvanced(model.model.model_config)
183
+ model_sampling.set_parameters(shift=shift)
184
+ m.add_object_patch("model_sampling", model_sampling)
185
+ return (m, )
186
+
187
+
188
+ class ModelSamplingContinuousEDM:
189
+ @classmethod
190
+ def INPUT_TYPES(s):
191
+ return {"required": { "model": ("MODEL",),
192
+ "sampling": (["v_prediction", "edm", "edm_playground_v2.5", "eps"],),
193
+ "sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
194
+ "sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
195
+ }}
196
+
197
+ RETURN_TYPES = ("MODEL",)
198
+ FUNCTION = "patch"
199
+
200
+ CATEGORY = "advanced/model"
201
+
202
+ def patch(self, model, sampling, sigma_max, sigma_min):
203
+ m = model.clone()
204
+
205
+ latent_format = None
206
+ sigma_data = 1.0
207
+ if sampling == "eps":
208
+ sampling_type = comfy.model_sampling.EPS
209
+ elif sampling == "edm":
210
+ sampling_type = comfy.model_sampling.EDM
211
+ sigma_data = 0.5
212
+ elif sampling == "v_prediction":
213
+ sampling_type = comfy.model_sampling.V_PREDICTION
214
+ elif sampling == "edm_playground_v2.5":
215
+ sampling_type = comfy.model_sampling.EDM
216
+ sigma_data = 0.5
217
+ latent_format = comfy.latent_formats.SDXL_Playground_2_5()
218
+
219
+ class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type):
220
+ pass
221
+
222
+ model_sampling = ModelSamplingAdvanced(model.model.model_config)
223
+ model_sampling.set_parameters(sigma_min, sigma_max, sigma_data)
224
+ m.add_object_patch("model_sampling", model_sampling)
225
+ if latent_format is not None:
226
+ m.add_object_patch("latent_format", latent_format)
227
+ return (m, )
228
+
229
+ class ModelSamplingContinuousV:
230
+ @classmethod
231
+ def INPUT_TYPES(s):
232
+ return {"required": { "model": ("MODEL",),
233
+ "sampling": (["v_prediction"],),
234
+ "sigma_max": ("FLOAT", {"default": 500.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
235
+ "sigma_min": ("FLOAT", {"default": 0.03, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
236
+ }}
237
+
238
+ RETURN_TYPES = ("MODEL",)
239
+ FUNCTION = "patch"
240
+
241
+ CATEGORY = "advanced/model"
242
+
243
+ def patch(self, model, sampling, sigma_max, sigma_min):
244
+ m = model.clone()
245
+
246
+ sigma_data = 1.0
247
+ if sampling == "v_prediction":
248
+ sampling_type = comfy.model_sampling.V_PREDICTION
249
+
250
+ class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousV, sampling_type):
251
+ pass
252
+
253
+ model_sampling = ModelSamplingAdvanced(model.model.model_config)
254
+ model_sampling.set_parameters(sigma_min, sigma_max, sigma_data)
255
+ m.add_object_patch("model_sampling", model_sampling)
256
+ return (m, )
257
+
258
+ class RescaleCFG:
259
+ @classmethod
260
+ def INPUT_TYPES(s):
261
+ return {"required": { "model": ("MODEL",),
262
+ "multiplier": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.01}),
263
+ }}
264
+ RETURN_TYPES = ("MODEL",)
265
+ FUNCTION = "patch"
266
+
267
+ CATEGORY = "advanced/model"
268
+
269
+ def patch(self, model, multiplier):
270
+ def rescale_cfg(args):
271
+ cond = args["cond"]
272
+ uncond = args["uncond"]
273
+ cond_scale = args["cond_scale"]
274
+ sigma = args["sigma"]
275
+ sigma = sigma.view(sigma.shape[:1] + (1,) * (cond.ndim - 1))
276
+ x_orig = args["input"]
277
+
278
+ #rescale cfg has to be done on v-pred model output
279
+ x = x_orig / (sigma * sigma + 1.0)
280
+ cond = ((x - (x_orig - cond)) * (sigma ** 2 + 1.0) ** 0.5) / (sigma)
281
+ uncond = ((x - (x_orig - uncond)) * (sigma ** 2 + 1.0) ** 0.5) / (sigma)
282
+
283
+ #rescalecfg
284
+ x_cfg = uncond + cond_scale * (cond - uncond)
285
+ ro_pos = torch.std(cond, dim=(1,2,3), keepdim=True)
286
+ ro_cfg = torch.std(x_cfg, dim=(1,2,3), keepdim=True)
287
+
288
+ x_rescaled = x_cfg * (ro_pos / ro_cfg)
289
+ x_final = multiplier * x_rescaled + (1.0 - multiplier) * x_cfg
290
+
291
+ return x_orig - (x - x_final * sigma / (sigma * sigma + 1.0) ** 0.5)
292
+
293
+ m = model.clone()
294
+ m.set_model_sampler_cfg_function(rescale_cfg)
295
+ return (m, )
296
+
297
+ NODE_CLASS_MAPPINGS = {
298
+ "ModelSamplingDiscrete": ModelSamplingDiscrete,
299
+ "ModelSamplingContinuousEDM": ModelSamplingContinuousEDM,
300
+ "ModelSamplingContinuousV": ModelSamplingContinuousV,
301
+ "ModelSamplingStableCascade": ModelSamplingStableCascade,
302
+ "ModelSamplingSD3": ModelSamplingSD3,
303
+ "ModelSamplingAuraFlow": ModelSamplingAuraFlow,
304
+ "ModelSamplingFlux": ModelSamplingFlux,
305
+ "RescaleCFG": RescaleCFG,
306
+ }
comfy_extras/nodes_model_downscale.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import comfy.utils
2
+
3
+ class PatchModelAddDownscale:
4
+ upscale_methods = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]
5
+ @classmethod
6
+ def INPUT_TYPES(s):
7
+ return {"required": { "model": ("MODEL",),
8
+ "block_number": ("INT", {"default": 3, "min": 1, "max": 32, "step": 1}),
9
+ "downscale_factor": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 9.0, "step": 0.001}),
10
+ "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
11
+ "end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}),
12
+ "downscale_after_skip": ("BOOLEAN", {"default": True}),
13
+ "downscale_method": (s.upscale_methods,),
14
+ "upscale_method": (s.upscale_methods,),
15
+ }}
16
+ RETURN_TYPES = ("MODEL",)
17
+ FUNCTION = "patch"
18
+
19
+ CATEGORY = "model_patches/unet"
20
+
21
+ def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method):
22
+ model_sampling = model.get_model_object("model_sampling")
23
+ sigma_start = model_sampling.percent_to_sigma(start_percent)
24
+ sigma_end = model_sampling.percent_to_sigma(end_percent)
25
+
26
+ def input_block_patch(h, transformer_options):
27
+ if transformer_options["block"][1] == block_number:
28
+ sigma = transformer_options["sigmas"][0].item()
29
+ if sigma <= sigma_start and sigma >= sigma_end:
30
+ h = comfy.utils.common_upscale(h, round(h.shape[-1] * (1.0 / downscale_factor)), round(h.shape[-2] * (1.0 / downscale_factor)), downscale_method, "disabled")
31
+ return h
32
+
33
+ def output_block_patch(h, hsp, transformer_options):
34
+ if h.shape[2] != hsp.shape[2]:
35
+ h = comfy.utils.common_upscale(h, hsp.shape[-1], hsp.shape[-2], upscale_method, "disabled")
36
+ return h, hsp
37
+
38
+ m = model.clone()
39
+ if downscale_after_skip:
40
+ m.set_model_input_block_patch_after_skip(input_block_patch)
41
+ else:
42
+ m.set_model_input_block_patch(input_block_patch)
43
+ m.set_model_output_block_patch(output_block_patch)
44
+ return (m, )
45
+
46
+ NODE_CLASS_MAPPINGS = {
47
+ "PatchModelAddDownscale": PatchModelAddDownscale,
48
+ }
49
+
50
+ NODE_DISPLAY_NAME_MAPPINGS = {
51
+ # Sampling
52
+ "PatchModelAddDownscale": "PatchModelAddDownscale (Kohya Deep Shrink)",
53
+ }
comfy_extras/nodes_model_merging.py ADDED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import comfy.sd
2
+ import comfy.utils
3
+ import comfy.model_base
4
+ import comfy.model_management
5
+ import comfy.model_sampling
6
+
7
+ import torch
8
+ import folder_paths
9
+ import json
10
+ import os
11
+
12
+ from comfy.cli_args import args
13
+
14
+ class ModelMergeSimple:
15
+ @classmethod
16
+ def INPUT_TYPES(s):
17
+ return {"required": { "model1": ("MODEL",),
18
+ "model2": ("MODEL",),
19
+ "ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
20
+ }}
21
+ RETURN_TYPES = ("MODEL",)
22
+ FUNCTION = "merge"
23
+
24
+ CATEGORY = "advanced/model_merging"
25
+
26
+ def merge(self, model1, model2, ratio):
27
+ m = model1.clone()
28
+ kp = model2.get_key_patches("diffusion_model.")
29
+ for k in kp:
30
+ m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
31
+ return (m, )
32
+
33
+ class ModelSubtract:
34
+ @classmethod
35
+ def INPUT_TYPES(s):
36
+ return {"required": { "model1": ("MODEL",),
37
+ "model2": ("MODEL",),
38
+ "multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
39
+ }}
40
+ RETURN_TYPES = ("MODEL",)
41
+ FUNCTION = "merge"
42
+
43
+ CATEGORY = "advanced/model_merging"
44
+
45
+ def merge(self, model1, model2, multiplier):
46
+ m = model1.clone()
47
+ kp = model2.get_key_patches("diffusion_model.")
48
+ for k in kp:
49
+ m.add_patches({k: kp[k]}, - multiplier, multiplier)
50
+ return (m, )
51
+
52
+ class ModelAdd:
53
+ @classmethod
54
+ def INPUT_TYPES(s):
55
+ return {"required": { "model1": ("MODEL",),
56
+ "model2": ("MODEL",),
57
+ }}
58
+ RETURN_TYPES = ("MODEL",)
59
+ FUNCTION = "merge"
60
+
61
+ CATEGORY = "advanced/model_merging"
62
+
63
+ def merge(self, model1, model2):
64
+ m = model1.clone()
65
+ kp = model2.get_key_patches("diffusion_model.")
66
+ for k in kp:
67
+ m.add_patches({k: kp[k]}, 1.0, 1.0)
68
+ return (m, )
69
+
70
+
71
+ class CLIPMergeSimple:
72
+ @classmethod
73
+ def INPUT_TYPES(s):
74
+ return {"required": { "clip1": ("CLIP",),
75
+ "clip2": ("CLIP",),
76
+ "ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
77
+ }}
78
+ RETURN_TYPES = ("CLIP",)
79
+ FUNCTION = "merge"
80
+
81
+ CATEGORY = "advanced/model_merging"
82
+
83
+ def merge(self, clip1, clip2, ratio):
84
+ m = clip1.clone()
85
+ kp = clip2.get_key_patches()
86
+ for k in kp:
87
+ if k.endswith(".position_ids") or k.endswith(".logit_scale"):
88
+ continue
89
+ m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
90
+ return (m, )
91
+
92
+
93
+ class CLIPSubtract:
94
+ @classmethod
95
+ def INPUT_TYPES(s):
96
+ return {"required": { "clip1": ("CLIP",),
97
+ "clip2": ("CLIP",),
98
+ "multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
99
+ }}
100
+ RETURN_TYPES = ("CLIP",)
101
+ FUNCTION = "merge"
102
+
103
+ CATEGORY = "advanced/model_merging"
104
+
105
+ def merge(self, clip1, clip2, multiplier):
106
+ m = clip1.clone()
107
+ kp = clip2.get_key_patches()
108
+ for k in kp:
109
+ if k.endswith(".position_ids") or k.endswith(".logit_scale"):
110
+ continue
111
+ m.add_patches({k: kp[k]}, - multiplier, multiplier)
112
+ return (m, )
113
+
114
+
115
+ class CLIPAdd:
116
+ @classmethod
117
+ def INPUT_TYPES(s):
118
+ return {"required": { "clip1": ("CLIP",),
119
+ "clip2": ("CLIP",),
120
+ }}
121
+ RETURN_TYPES = ("CLIP",)
122
+ FUNCTION = "merge"
123
+
124
+ CATEGORY = "advanced/model_merging"
125
+
126
+ def merge(self, clip1, clip2):
127
+ m = clip1.clone()
128
+ kp = clip2.get_key_patches()
129
+ for k in kp:
130
+ if k.endswith(".position_ids") or k.endswith(".logit_scale"):
131
+ continue
132
+ m.add_patches({k: kp[k]}, 1.0, 1.0)
133
+ return (m, )
134
+
135
+
136
+ class ModelMergeBlocks:
137
+ @classmethod
138
+ def INPUT_TYPES(s):
139
+ return {"required": { "model1": ("MODEL",),
140
+ "model2": ("MODEL",),
141
+ "input": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
142
+ "middle": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
143
+ "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
144
+ }}
145
+ RETURN_TYPES = ("MODEL",)
146
+ FUNCTION = "merge"
147
+
148
+ CATEGORY = "advanced/model_merging"
149
+
150
+ def merge(self, model1, model2, **kwargs):
151
+ m = model1.clone()
152
+ kp = model2.get_key_patches("diffusion_model.")
153
+ default_ratio = next(iter(kwargs.values()))
154
+
155
+ for k in kp:
156
+ ratio = default_ratio
157
+ k_unet = k[len("diffusion_model."):]
158
+
159
+ last_arg_size = 0
160
+ for arg in kwargs:
161
+ if k_unet.startswith(arg) and last_arg_size < len(arg):
162
+ ratio = kwargs[arg]
163
+ last_arg_size = len(arg)
164
+
165
+ m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
166
+ return (m, )
167
+
168
+ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefix=None, output_dir=None, prompt=None, extra_pnginfo=None):
169
+ full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, output_dir)
170
+ prompt_info = ""
171
+ if prompt is not None:
172
+ prompt_info = json.dumps(prompt)
173
+
174
+ metadata = {}
175
+
176
+ enable_modelspec = True
177
+ if isinstance(model.model, comfy.model_base.SDXL):
178
+ if isinstance(model.model, comfy.model_base.SDXL_instructpix2pix):
179
+ metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-edit"
180
+ else:
181
+ metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base"
182
+ elif isinstance(model.model, comfy.model_base.SDXLRefiner):
183
+ metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner"
184
+ elif isinstance(model.model, comfy.model_base.SVD_img2vid):
185
+ metadata["modelspec.architecture"] = "stable-video-diffusion-img2vid-v1"
186
+ elif isinstance(model.model, comfy.model_base.SD3):
187
+ metadata["modelspec.architecture"] = "stable-diffusion-v3-medium" #TODO: other SD3 variants
188
+ else:
189
+ enable_modelspec = False
190
+
191
+ if enable_modelspec:
192
+ metadata["modelspec.sai_model_spec"] = "1.0.0"
193
+ metadata["modelspec.implementation"] = "sgm"
194
+ metadata["modelspec.title"] = "{} {}".format(filename, counter)
195
+
196
+ #TODO:
197
+ # "stable-diffusion-v1", "stable-diffusion-v1-inpainting", "stable-diffusion-v2-512",
198
+ # "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h",
199
+ # "v2-inpainting"
200
+
201
+ extra_keys = {}
202
+ model_sampling = model.get_model_object("model_sampling")
203
+ if isinstance(model_sampling, comfy.model_sampling.ModelSamplingContinuousEDM):
204
+ if isinstance(model_sampling, comfy.model_sampling.V_PREDICTION):
205
+ extra_keys["edm_vpred.sigma_max"] = torch.tensor(model_sampling.sigma_max).float()
206
+ extra_keys["edm_vpred.sigma_min"] = torch.tensor(model_sampling.sigma_min).float()
207
+
208
+ if model.model.model_type == comfy.model_base.ModelType.EPS:
209
+ metadata["modelspec.predict_key"] = "epsilon"
210
+ elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION:
211
+ metadata["modelspec.predict_key"] = "v"
212
+
213
+ if not args.disable_metadata:
214
+ metadata["prompt"] = prompt_info
215
+ if extra_pnginfo is not None:
216
+ for x in extra_pnginfo:
217
+ metadata[x] = json.dumps(extra_pnginfo[x])
218
+
219
+ output_checkpoint = f"{filename}_{counter:05}_.safetensors"
220
+ output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
221
+
222
+ comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata, extra_keys=extra_keys)
223
+
224
+ class CheckpointSave:
225
+ def __init__(self):
226
+ self.output_dir = folder_paths.get_output_directory()
227
+
228
+ @classmethod
229
+ def INPUT_TYPES(s):
230
+ return {"required": { "model": ("MODEL",),
231
+ "clip": ("CLIP",),
232
+ "vae": ("VAE",),
233
+ "filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),},
234
+ "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
235
+ RETURN_TYPES = ()
236
+ FUNCTION = "save"
237
+ OUTPUT_NODE = True
238
+
239
+ CATEGORY = "advanced/model_merging"
240
+
241
+ def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None):
242
+ save_checkpoint(model, clip=clip, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
243
+ return {}
244
+
245
+ class CLIPSave:
246
+ def __init__(self):
247
+ self.output_dir = folder_paths.get_output_directory()
248
+
249
+ @classmethod
250
+ def INPUT_TYPES(s):
251
+ return {"required": { "clip": ("CLIP",),
252
+ "filename_prefix": ("STRING", {"default": "clip/ComfyUI"}),},
253
+ "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
254
+ RETURN_TYPES = ()
255
+ FUNCTION = "save"
256
+ OUTPUT_NODE = True
257
+
258
+ CATEGORY = "advanced/model_merging"
259
+
260
+ def save(self, clip, filename_prefix, prompt=None, extra_pnginfo=None):
261
+ prompt_info = ""
262
+ if prompt is not None:
263
+ prompt_info = json.dumps(prompt)
264
+
265
+ metadata = {}
266
+ if not args.disable_metadata:
267
+ metadata["format"] = "pt"
268
+ metadata["prompt"] = prompt_info
269
+ if extra_pnginfo is not None:
270
+ for x in extra_pnginfo:
271
+ metadata[x] = json.dumps(extra_pnginfo[x])
272
+
273
+ comfy.model_management.load_models_gpu([clip.load_model()], force_patch_weights=True)
274
+ clip_sd = clip.get_sd()
275
+
276
+ for prefix in ["clip_l.", "clip_g.", ""]:
277
+ k = list(filter(lambda a: a.startswith(prefix), clip_sd.keys()))
278
+ current_clip_sd = {}
279
+ for x in k:
280
+ current_clip_sd[x] = clip_sd.pop(x)
281
+ if len(current_clip_sd) == 0:
282
+ continue
283
+
284
+ p = prefix[:-1]
285
+ replace_prefix = {}
286
+ filename_prefix_ = filename_prefix
287
+ if len(p) > 0:
288
+ filename_prefix_ = "{}_{}".format(filename_prefix_, p)
289
+ replace_prefix[prefix] = ""
290
+ replace_prefix["transformer."] = ""
291
+
292
+ full_output_folder, filename, counter, subfolder, filename_prefix_ = folder_paths.get_save_image_path(filename_prefix_, self.output_dir)
293
+
294
+ output_checkpoint = f"{filename}_{counter:05}_.safetensors"
295
+ output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
296
+
297
+ current_clip_sd = comfy.utils.state_dict_prefix_replace(current_clip_sd, replace_prefix)
298
+
299
+ comfy.utils.save_torch_file(current_clip_sd, output_checkpoint, metadata=metadata)
300
+ return {}
301
+
302
+ class VAESave:
303
+ def __init__(self):
304
+ self.output_dir = folder_paths.get_output_directory()
305
+
306
+ @classmethod
307
+ def INPUT_TYPES(s):
308
+ return {"required": { "vae": ("VAE",),
309
+ "filename_prefix": ("STRING", {"default": "vae/ComfyUI_vae"}),},
310
+ "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
311
+ RETURN_TYPES = ()
312
+ FUNCTION = "save"
313
+ OUTPUT_NODE = True
314
+
315
+ CATEGORY = "advanced/model_merging"
316
+
317
+ def save(self, vae, filename_prefix, prompt=None, extra_pnginfo=None):
318
+ full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
319
+ prompt_info = ""
320
+ if prompt is not None:
321
+ prompt_info = json.dumps(prompt)
322
+
323
+ metadata = {}
324
+ if not args.disable_metadata:
325
+ metadata["prompt"] = prompt_info
326
+ if extra_pnginfo is not None:
327
+ for x in extra_pnginfo:
328
+ metadata[x] = json.dumps(extra_pnginfo[x])
329
+
330
+ output_checkpoint = f"{filename}_{counter:05}_.safetensors"
331
+ output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
332
+
333
+ comfy.utils.save_torch_file(vae.get_sd(), output_checkpoint, metadata=metadata)
334
+ return {}
335
+
336
+ class ModelSave:
337
+ def __init__(self):
338
+ self.output_dir = folder_paths.get_output_directory()
339
+
340
+ @classmethod
341
+ def INPUT_TYPES(s):
342
+ return {"required": { "model": ("MODEL",),
343
+ "filename_prefix": ("STRING", {"default": "diffusion_models/ComfyUI"}),},
344
+ "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
345
+ RETURN_TYPES = ()
346
+ FUNCTION = "save"
347
+ OUTPUT_NODE = True
348
+
349
+ CATEGORY = "advanced/model_merging"
350
+
351
+ def save(self, model, filename_prefix, prompt=None, extra_pnginfo=None):
352
+ save_checkpoint(model, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
353
+ return {}
354
+
355
+ NODE_CLASS_MAPPINGS = {
356
+ "ModelMergeSimple": ModelMergeSimple,
357
+ "ModelMergeBlocks": ModelMergeBlocks,
358
+ "ModelMergeSubtract": ModelSubtract,
359
+ "ModelMergeAdd": ModelAdd,
360
+ "CheckpointSave": CheckpointSave,
361
+ "CLIPMergeSimple": CLIPMergeSimple,
362
+ "CLIPMergeSubtract": CLIPSubtract,
363
+ "CLIPMergeAdd": CLIPAdd,
364
+ "CLIPSave": CLIPSave,
365
+ "VAESave": VAESave,
366
+ "ModelSave": ModelSave,
367
+ }
368
+
369
+ NODE_DISPLAY_NAME_MAPPINGS = {
370
+ "CheckpointSave": "Save Checkpoint",
371
+ }
comfy_extras/nodes_model_merging_model_specific.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import comfy_extras.nodes_model_merging
2
+
3
+ class ModelMergeSD1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
4
+ CATEGORY = "advanced/model_merging/model_specific"
5
+ @classmethod
6
+ def INPUT_TYPES(s):
7
+ arg_dict = { "model1": ("MODEL",),
8
+ "model2": ("MODEL",)}
9
+
10
+ argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
11
+
12
+ arg_dict["time_embed."] = argument
13
+ arg_dict["label_emb."] = argument
14
+
15
+ for i in range(12):
16
+ arg_dict["input_blocks.{}.".format(i)] = argument
17
+
18
+ for i in range(3):
19
+ arg_dict["middle_block.{}.".format(i)] = argument
20
+
21
+ for i in range(12):
22
+ arg_dict["output_blocks.{}.".format(i)] = argument
23
+
24
+ arg_dict["out."] = argument
25
+
26
+ return {"required": arg_dict}
27
+
28
+
29
+ class ModelMergeSDXL(comfy_extras.nodes_model_merging.ModelMergeBlocks):
30
+ CATEGORY = "advanced/model_merging/model_specific"
31
+
32
+ @classmethod
33
+ def INPUT_TYPES(s):
34
+ arg_dict = { "model1": ("MODEL",),
35
+ "model2": ("MODEL",)}
36
+
37
+ argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
38
+
39
+ arg_dict["time_embed."] = argument
40
+ arg_dict["label_emb."] = argument
41
+
42
+ for i in range(9):
43
+ arg_dict["input_blocks.{}".format(i)] = argument
44
+
45
+ for i in range(3):
46
+ arg_dict["middle_block.{}".format(i)] = argument
47
+
48
+ for i in range(9):
49
+ arg_dict["output_blocks.{}".format(i)] = argument
50
+
51
+ arg_dict["out."] = argument
52
+
53
+ return {"required": arg_dict}
54
+
55
+ class ModelMergeSD3_2B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
56
+ CATEGORY = "advanced/model_merging/model_specific"
57
+
58
+ @classmethod
59
+ def INPUT_TYPES(s):
60
+ arg_dict = { "model1": ("MODEL",),
61
+ "model2": ("MODEL",)}
62
+
63
+ argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
64
+
65
+ arg_dict["pos_embed."] = argument
66
+ arg_dict["x_embedder."] = argument
67
+ arg_dict["context_embedder."] = argument
68
+ arg_dict["y_embedder."] = argument
69
+ arg_dict["t_embedder."] = argument
70
+
71
+ for i in range(24):
72
+ arg_dict["joint_blocks.{}.".format(i)] = argument
73
+
74
+ arg_dict["final_layer."] = argument
75
+
76
+ return {"required": arg_dict}
77
+
78
+
79
+ class ModelMergeAuraflow(comfy_extras.nodes_model_merging.ModelMergeBlocks):
80
+ CATEGORY = "advanced/model_merging/model_specific"
81
+
82
+ @classmethod
83
+ def INPUT_TYPES(s):
84
+ arg_dict = { "model1": ("MODEL",),
85
+ "model2": ("MODEL",)}
86
+
87
+ argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
88
+
89
+ arg_dict["init_x_linear."] = argument
90
+ arg_dict["positional_encoding"] = argument
91
+ arg_dict["cond_seq_linear."] = argument
92
+ arg_dict["register_tokens"] = argument
93
+ arg_dict["t_embedder."] = argument
94
+
95
+ for i in range(4):
96
+ arg_dict["double_layers.{}.".format(i)] = argument
97
+
98
+ for i in range(32):
99
+ arg_dict["single_layers.{}.".format(i)] = argument
100
+
101
+ arg_dict["modF."] = argument
102
+ arg_dict["final_linear."] = argument
103
+
104
+ return {"required": arg_dict}
105
+
106
+ class ModelMergeFlux1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
107
+ CATEGORY = "advanced/model_merging/model_specific"
108
+
109
+ @classmethod
110
+ def INPUT_TYPES(s):
111
+ arg_dict = { "model1": ("MODEL",),
112
+ "model2": ("MODEL",)}
113
+
114
+ argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
115
+
116
+ arg_dict["img_in."] = argument
117
+ arg_dict["time_in."] = argument
118
+ arg_dict["guidance_in"] = argument
119
+ arg_dict["vector_in."] = argument
120
+ arg_dict["txt_in."] = argument
121
+
122
+ for i in range(19):
123
+ arg_dict["double_blocks.{}.".format(i)] = argument
124
+
125
+ for i in range(38):
126
+ arg_dict["single_blocks.{}.".format(i)] = argument
127
+
128
+ arg_dict["final_layer."] = argument
129
+
130
+ return {"required": arg_dict}
131
+
132
+ class ModelMergeSD35_Large(comfy_extras.nodes_model_merging.ModelMergeBlocks):
133
+ CATEGORY = "advanced/model_merging/model_specific"
134
+
135
+ @classmethod
136
+ def INPUT_TYPES(s):
137
+ arg_dict = { "model1": ("MODEL",),
138
+ "model2": ("MODEL",)}
139
+
140
+ argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
141
+
142
+ arg_dict["pos_embed."] = argument
143
+ arg_dict["x_embedder."] = argument
144
+ arg_dict["context_embedder."] = argument
145
+ arg_dict["y_embedder."] = argument
146
+ arg_dict["t_embedder."] = argument
147
+
148
+ for i in range(38):
149
+ arg_dict["joint_blocks.{}.".format(i)] = argument
150
+
151
+ arg_dict["final_layer."] = argument
152
+
153
+ return {"required": arg_dict}
154
+
155
+ class ModelMergeMochiPreview(comfy_extras.nodes_model_merging.ModelMergeBlocks):
156
+ CATEGORY = "advanced/model_merging/model_specific"
157
+
158
+ @classmethod
159
+ def INPUT_TYPES(s):
160
+ arg_dict = { "model1": ("MODEL",),
161
+ "model2": ("MODEL",)}
162
+
163
+ argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
164
+
165
+ arg_dict["pos_frequencies."] = argument
166
+ arg_dict["t_embedder."] = argument
167
+ arg_dict["t5_y_embedder."] = argument
168
+ arg_dict["t5_yproj."] = argument
169
+
170
+ for i in range(48):
171
+ arg_dict["blocks.{}.".format(i)] = argument
172
+
173
+ arg_dict["final_layer."] = argument
174
+
175
+ return {"required": arg_dict}
176
+
177
+ class ModelMergeLTXV(comfy_extras.nodes_model_merging.ModelMergeBlocks):
178
+ CATEGORY = "advanced/model_merging/model_specific"
179
+
180
+ @classmethod
181
+ def INPUT_TYPES(s):
182
+ arg_dict = { "model1": ("MODEL",),
183
+ "model2": ("MODEL",)}
184
+
185
+ argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
186
+
187
+ arg_dict["patchify_proj."] = argument
188
+ arg_dict["adaln_single."] = argument
189
+ arg_dict["caption_projection."] = argument
190
+
191
+ for i in range(28):
192
+ arg_dict["transformer_blocks.{}.".format(i)] = argument
193
+
194
+ arg_dict["scale_shift_table"] = argument
195
+ arg_dict["proj_out."] = argument
196
+
197
+ return {"required": arg_dict}
198
+
199
+ NODE_CLASS_MAPPINGS = {
200
+ "ModelMergeSD1": ModelMergeSD1,
201
+ "ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
202
+ "ModelMergeSDXL": ModelMergeSDXL,
203
+ "ModelMergeSD3_2B": ModelMergeSD3_2B,
204
+ "ModelMergeAuraflow": ModelMergeAuraflow,
205
+ "ModelMergeFlux1": ModelMergeFlux1,
206
+ "ModelMergeSD35_Large": ModelMergeSD35_Large,
207
+ "ModelMergeMochiPreview": ModelMergeMochiPreview,
208
+ "ModelMergeLTXV": ModelMergeLTXV,
209
+ }
comfy_extras/nodes_morphology.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import comfy.model_management
3
+
4
+ from kornia.morphology import dilation, erosion, opening, closing, gradient, top_hat, bottom_hat
5
+
6
+
7
+ class Morphology:
8
+ @classmethod
9
+ def INPUT_TYPES(s):
10
+ return {"required": {"image": ("IMAGE",),
11
+ "operation": (["erode", "dilate", "open", "close", "gradient", "bottom_hat", "top_hat"],),
12
+ "kernel_size": ("INT", {"default": 3, "min": 3, "max": 999, "step": 1}),
13
+ }}
14
+
15
+ RETURN_TYPES = ("IMAGE",)
16
+ FUNCTION = "process"
17
+
18
+ CATEGORY = "image/postprocessing"
19
+
20
+ def process(self, image, operation, kernel_size):
21
+ device = comfy.model_management.get_torch_device()
22
+ kernel = torch.ones(kernel_size, kernel_size, device=device)
23
+ image_k = image.to(device).movedim(-1, 1)
24
+ if operation == "erode":
25
+ output = erosion(image_k, kernel)
26
+ elif operation == "dilate":
27
+ output = dilation(image_k, kernel)
28
+ elif operation == "open":
29
+ output = opening(image_k, kernel)
30
+ elif operation == "close":
31
+ output = closing(image_k, kernel)
32
+ elif operation == "gradient":
33
+ output = gradient(image_k, kernel)
34
+ elif operation == "top_hat":
35
+ output = top_hat(image_k, kernel)
36
+ elif operation == "bottom_hat":
37
+ output = bottom_hat(image_k, kernel)
38
+ else:
39
+ raise ValueError(f"Invalid operation {operation} for morphology. Must be one of 'erode', 'dilate', 'open', 'close', 'gradient', 'tophat', 'bottomhat'")
40
+ img_out = output.to(comfy.model_management.intermediate_device()).movedim(1, -1)
41
+ return (img_out,)
42
+
43
+ NODE_CLASS_MAPPINGS = {
44
+ "Morphology": Morphology,
45
+ }
46
+
47
+ NODE_DISPLAY_NAME_MAPPINGS = {
48
+ "Morphology": "ImageMorphology",
49
+ }
comfy_extras/nodes_pag.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #Modified/simplified version of the node from: https://github.com/pamparamm/sd-perturbed-attention
2
+ #If you want the one with more options see the above repo.
3
+
4
+ #My modified one here is more basic but has less chances of breaking with ComfyUI updates.
5
+
6
+ import comfy.model_patcher
7
+ import comfy.samplers
8
+
9
+ class PerturbedAttentionGuidance:
10
+ @classmethod
11
+ def INPUT_TYPES(s):
12
+ return {
13
+ "required": {
14
+ "model": ("MODEL",),
15
+ "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": 0.01}),
16
+ }
17
+ }
18
+
19
+ RETURN_TYPES = ("MODEL",)
20
+ FUNCTION = "patch"
21
+
22
+ CATEGORY = "model_patches/unet"
23
+
24
+ def patch(self, model, scale):
25
+ unet_block = "middle"
26
+ unet_block_id = 0
27
+ m = model.clone()
28
+
29
+ def perturbed_attention(q, k, v, extra_options, mask=None):
30
+ return v
31
+
32
+ def post_cfg_function(args):
33
+ model = args["model"]
34
+ cond_pred = args["cond_denoised"]
35
+ cond = args["cond"]
36
+ cfg_result = args["denoised"]
37
+ sigma = args["sigma"]
38
+ model_options = args["model_options"].copy()
39
+ x = args["input"]
40
+
41
+ if scale == 0:
42
+ return cfg_result
43
+
44
+ # Replace Self-attention with PAG
45
+ model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, perturbed_attention, "attn1", unet_block, unet_block_id)
46
+ (pag,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options)
47
+
48
+ return cfg_result + (cond_pred - pag) * scale
49
+
50
+ m.set_model_sampler_post_cfg_function(post_cfg_function)
51
+
52
+ return (m,)
53
+
54
+ NODE_CLASS_MAPPINGS = {
55
+ "PerturbedAttentionGuidance": PerturbedAttentionGuidance,
56
+ }
comfy_extras/nodes_perpneg.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import comfy.model_management
3
+ import comfy.sampler_helpers
4
+ import comfy.samplers
5
+ import comfy.utils
6
+ import node_helpers
7
+
8
+ def perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, cond_scale):
9
+ pos = noise_pred_pos - noise_pred_nocond
10
+ neg = noise_pred_neg - noise_pred_nocond
11
+
12
+ perp = neg - ((torch.mul(neg, pos).sum())/(torch.norm(pos)**2)) * pos
13
+ perp_neg = perp * neg_scale
14
+ cfg_result = noise_pred_nocond + cond_scale*(pos - perp_neg)
15
+ return cfg_result
16
+
17
+ #TODO: This node should be removed, it has been replaced with PerpNegGuider
18
+ class PerpNeg:
19
+ @classmethod
20
+ def INPUT_TYPES(s):
21
+ return {"required": {"model": ("MODEL", ),
22
+ "empty_conditioning": ("CONDITIONING", ),
23
+ "neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
24
+ }}
25
+ RETURN_TYPES = ("MODEL",)
26
+ FUNCTION = "patch"
27
+
28
+ CATEGORY = "_for_testing"
29
+ DEPRECATED = True
30
+
31
+ def patch(self, model, empty_conditioning, neg_scale):
32
+ m = model.clone()
33
+ nocond = comfy.sampler_helpers.convert_cond(empty_conditioning)
34
+
35
+ def cfg_function(args):
36
+ model = args["model"]
37
+ noise_pred_pos = args["cond_denoised"]
38
+ noise_pred_neg = args["uncond_denoised"]
39
+ cond_scale = args["cond_scale"]
40
+ x = args["input"]
41
+ sigma = args["sigma"]
42
+ model_options = args["model_options"]
43
+ nocond_processed = comfy.samplers.encode_model_conds(model.extra_conds, nocond, x, x.device, "negative")
44
+
45
+ (noise_pred_nocond,) = comfy.samplers.calc_cond_batch(model, [nocond_processed], x, sigma, model_options)
46
+
47
+ cfg_result = x - perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, cond_scale)
48
+ return cfg_result
49
+
50
+ m.set_model_sampler_cfg_function(cfg_function)
51
+
52
+ return (m, )
53
+
54
+
55
+ class Guider_PerpNeg(comfy.samplers.CFGGuider):
56
+ def set_conds(self, positive, negative, empty_negative_prompt):
57
+ empty_negative_prompt = node_helpers.conditioning_set_values(empty_negative_prompt, {"prompt_type": "negative"})
58
+ self.inner_set_conds({"positive": positive, "empty_negative_prompt": empty_negative_prompt, "negative": negative})
59
+
60
+ def set_cfg(self, cfg, neg_scale):
61
+ self.cfg = cfg
62
+ self.neg_scale = neg_scale
63
+
64
+ def predict_noise(self, x, timestep, model_options={}, seed=None):
65
+ # in CFGGuider.predict_noise, we call sampling_function(), which uses cfg_function() to compute pos & neg
66
+ # but we'd rather do a single batch of sampling pos, neg, and empty, so we call calc_cond_batch([pos,neg,empty]) directly
67
+
68
+ positive_cond = self.conds.get("positive", None)
69
+ negative_cond = self.conds.get("negative", None)
70
+ empty_cond = self.conds.get("empty_negative_prompt", None)
71
+
72
+ (noise_pred_pos, noise_pred_neg, noise_pred_empty) = \
73
+ comfy.samplers.calc_cond_batch(self.inner_model, [positive_cond, negative_cond, empty_cond], x, timestep, model_options)
74
+ cfg_result = perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_empty, self.neg_scale, self.cfg)
75
+
76
+ # normally this would be done in cfg_function, but we skipped
77
+ # that for efficiency: we can compute the noise predictions in
78
+ # a single call to calc_cond_batch() (rather than two)
79
+ # so we replicate the hook here
80
+ for fn in model_options.get("sampler_post_cfg_function", []):
81
+ args = {
82
+ "denoised": cfg_result,
83
+ "cond": positive_cond,
84
+ "uncond": negative_cond,
85
+ "model": self.inner_model,
86
+ "uncond_denoised": noise_pred_neg,
87
+ "cond_denoised": noise_pred_pos,
88
+ "sigma": timestep,
89
+ "model_options": model_options,
90
+ "input": x,
91
+ # not in the original call in samplers.py:cfg_function, but made available for future hooks
92
+ "empty_cond": empty_cond,
93
+ "empty_cond_denoised": noise_pred_empty,}
94
+ cfg_result = fn(args)
95
+
96
+ return cfg_result
97
+
98
+ class PerpNegGuider:
99
+ @classmethod
100
+ def INPUT_TYPES(s):
101
+ return {"required":
102
+ {"model": ("MODEL",),
103
+ "positive": ("CONDITIONING", ),
104
+ "negative": ("CONDITIONING", ),
105
+ "empty_conditioning": ("CONDITIONING", ),
106
+ "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
107
+ "neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
108
+ }
109
+ }
110
+
111
+ RETURN_TYPES = ("GUIDER",)
112
+
113
+ FUNCTION = "get_guider"
114
+ CATEGORY = "_for_testing"
115
+
116
+ def get_guider(self, model, positive, negative, empty_conditioning, cfg, neg_scale):
117
+ guider = Guider_PerpNeg(model)
118
+ guider.set_conds(positive, negative, empty_conditioning)
119
+ guider.set_cfg(cfg, neg_scale)
120
+ return (guider,)
121
+
122
+ NODE_CLASS_MAPPINGS = {
123
+ "PerpNeg": PerpNeg,
124
+ "PerpNegGuider": PerpNegGuider,
125
+ }
126
+
127
+ NODE_DISPLAY_NAME_MAPPINGS = {
128
+ "PerpNeg": "Perp-Neg (DEPRECATED by PerpNegGuider)",
129
+ }
comfy_extras/nodes_photomaker.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import folder_paths
4
+ import comfy.clip_model
5
+ import comfy.clip_vision
6
+ import comfy.ops
7
+
8
+ # code for model from: https://github.com/TencentARC/PhotoMaker/blob/main/photomaker/model.py under Apache License Version 2.0
9
+ VISION_CONFIG_DICT = {
10
+ "hidden_size": 1024,
11
+ "image_size": 224,
12
+ "intermediate_size": 4096,
13
+ "num_attention_heads": 16,
14
+ "num_channels": 3,
15
+ "num_hidden_layers": 24,
16
+ "patch_size": 14,
17
+ "projection_dim": 768,
18
+ "hidden_act": "quick_gelu",
19
+ "model_type": "clip_vision_model",
20
+ }
21
+
22
+ class MLP(nn.Module):
23
+ def __init__(self, in_dim, out_dim, hidden_dim, use_residual=True, operations=comfy.ops):
24
+ super().__init__()
25
+ if use_residual:
26
+ assert in_dim == out_dim
27
+ self.layernorm = operations.LayerNorm(in_dim)
28
+ self.fc1 = operations.Linear(in_dim, hidden_dim)
29
+ self.fc2 = operations.Linear(hidden_dim, out_dim)
30
+ self.use_residual = use_residual
31
+ self.act_fn = nn.GELU()
32
+
33
+ def forward(self, x):
34
+ residual = x
35
+ x = self.layernorm(x)
36
+ x = self.fc1(x)
37
+ x = self.act_fn(x)
38
+ x = self.fc2(x)
39
+ if self.use_residual:
40
+ x = x + residual
41
+ return x
42
+
43
+
44
+ class FuseModule(nn.Module):
45
+ def __init__(self, embed_dim, operations):
46
+ super().__init__()
47
+ self.mlp1 = MLP(embed_dim * 2, embed_dim, embed_dim, use_residual=False, operations=operations)
48
+ self.mlp2 = MLP(embed_dim, embed_dim, embed_dim, use_residual=True, operations=operations)
49
+ self.layer_norm = operations.LayerNorm(embed_dim)
50
+
51
+ def fuse_fn(self, prompt_embeds, id_embeds):
52
+ stacked_id_embeds = torch.cat([prompt_embeds, id_embeds], dim=-1)
53
+ stacked_id_embeds = self.mlp1(stacked_id_embeds) + prompt_embeds
54
+ stacked_id_embeds = self.mlp2(stacked_id_embeds)
55
+ stacked_id_embeds = self.layer_norm(stacked_id_embeds)
56
+ return stacked_id_embeds
57
+
58
+ def forward(
59
+ self,
60
+ prompt_embeds,
61
+ id_embeds,
62
+ class_tokens_mask,
63
+ ) -> torch.Tensor:
64
+ # id_embeds shape: [b, max_num_inputs, 1, 2048]
65
+ id_embeds = id_embeds.to(prompt_embeds.dtype)
66
+ num_inputs = class_tokens_mask.sum().unsqueeze(0) # TODO: check for training case
67
+ batch_size, max_num_inputs = id_embeds.shape[:2]
68
+ # seq_length: 77
69
+ seq_length = prompt_embeds.shape[1]
70
+ # flat_id_embeds shape: [b*max_num_inputs, 1, 2048]
71
+ flat_id_embeds = id_embeds.view(
72
+ -1, id_embeds.shape[-2], id_embeds.shape[-1]
73
+ )
74
+ # valid_id_mask [b*max_num_inputs]
75
+ valid_id_mask = (
76
+ torch.arange(max_num_inputs, device=flat_id_embeds.device)[None, :]
77
+ < num_inputs[:, None]
78
+ )
79
+ valid_id_embeds = flat_id_embeds[valid_id_mask.flatten()]
80
+
81
+ prompt_embeds = prompt_embeds.view(-1, prompt_embeds.shape[-1])
82
+ class_tokens_mask = class_tokens_mask.view(-1)
83
+ valid_id_embeds = valid_id_embeds.view(-1, valid_id_embeds.shape[-1])
84
+ # slice out the image token embeddings
85
+ image_token_embeds = prompt_embeds[class_tokens_mask]
86
+ stacked_id_embeds = self.fuse_fn(image_token_embeds, valid_id_embeds)
87
+ assert class_tokens_mask.sum() == stacked_id_embeds.shape[0], f"{class_tokens_mask.sum()} != {stacked_id_embeds.shape[0]}"
88
+ prompt_embeds.masked_scatter_(class_tokens_mask[:, None], stacked_id_embeds.to(prompt_embeds.dtype))
89
+ updated_prompt_embeds = prompt_embeds.view(batch_size, seq_length, -1)
90
+ return updated_prompt_embeds
91
+
92
+ class PhotoMakerIDEncoder(comfy.clip_model.CLIPVisionModelProjection):
93
+ def __init__(self):
94
+ self.load_device = comfy.model_management.text_encoder_device()
95
+ offload_device = comfy.model_management.text_encoder_offload_device()
96
+ dtype = comfy.model_management.text_encoder_dtype(self.load_device)
97
+
98
+ super().__init__(VISION_CONFIG_DICT, dtype, offload_device, comfy.ops.manual_cast)
99
+ self.visual_projection_2 = comfy.ops.manual_cast.Linear(1024, 1280, bias=False)
100
+ self.fuse_module = FuseModule(2048, comfy.ops.manual_cast)
101
+
102
+ def forward(self, id_pixel_values, prompt_embeds, class_tokens_mask):
103
+ b, num_inputs, c, h, w = id_pixel_values.shape
104
+ id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w)
105
+
106
+ shared_id_embeds = self.vision_model(id_pixel_values)[2]
107
+ id_embeds = self.visual_projection(shared_id_embeds)
108
+ id_embeds_2 = self.visual_projection_2(shared_id_embeds)
109
+
110
+ id_embeds = id_embeds.view(b, num_inputs, 1, -1)
111
+ id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1)
112
+
113
+ id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1)
114
+ updated_prompt_embeds = self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask)
115
+
116
+ return updated_prompt_embeds
117
+
118
+
119
+ class PhotoMakerLoader:
120
+ @classmethod
121
+ def INPUT_TYPES(s):
122
+ return {"required": { "photomaker_model_name": (folder_paths.get_filename_list("photomaker"), )}}
123
+
124
+ RETURN_TYPES = ("PHOTOMAKER",)
125
+ FUNCTION = "load_photomaker_model"
126
+
127
+ CATEGORY = "_for_testing/photomaker"
128
+
129
+ def load_photomaker_model(self, photomaker_model_name):
130
+ photomaker_model_path = folder_paths.get_full_path_or_raise("photomaker", photomaker_model_name)
131
+ photomaker_model = PhotoMakerIDEncoder()
132
+ data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True)
133
+ if "id_encoder" in data:
134
+ data = data["id_encoder"]
135
+ photomaker_model.load_state_dict(data)
136
+ return (photomaker_model,)
137
+
138
+
139
+ class PhotoMakerEncode:
140
+ @classmethod
141
+ def INPUT_TYPES(s):
142
+ return {"required": { "photomaker": ("PHOTOMAKER",),
143
+ "image": ("IMAGE",),
144
+ "clip": ("CLIP", ),
145
+ "text": ("STRING", {"multiline": True, "dynamicPrompts": True, "default": "photograph of photomaker"}),
146
+ }}
147
+
148
+ RETURN_TYPES = ("CONDITIONING",)
149
+ FUNCTION = "apply_photomaker"
150
+
151
+ CATEGORY = "_for_testing/photomaker"
152
+
153
+ def apply_photomaker(self, photomaker, image, clip, text):
154
+ special_token = "photomaker"
155
+ pixel_values = comfy.clip_vision.clip_preprocess(image.to(photomaker.load_device)).float()
156
+ try:
157
+ index = text.split(" ").index(special_token) + 1
158
+ except ValueError:
159
+ index = -1
160
+ tokens = clip.tokenize(text, return_word_ids=True)
161
+ out_tokens = {}
162
+ for k in tokens:
163
+ out_tokens[k] = []
164
+ for t in tokens[k]:
165
+ f = list(filter(lambda x: x[2] != index, t))
166
+ while len(f) < len(t):
167
+ f.append(t[-1])
168
+ out_tokens[k].append(f)
169
+
170
+ cond, pooled = clip.encode_from_tokens(out_tokens, return_pooled=True)
171
+
172
+ if index > 0:
173
+ token_index = index - 1
174
+ num_id_images = 1
175
+ class_tokens_mask = [True if token_index <= i < token_index+num_id_images else False for i in range(77)]
176
+ out = photomaker(id_pixel_values=pixel_values.unsqueeze(0), prompt_embeds=cond.to(photomaker.load_device),
177
+ class_tokens_mask=torch.tensor(class_tokens_mask, dtype=torch.bool, device=photomaker.load_device).unsqueeze(0))
178
+ else:
179
+ out = cond
180
+
181
+ return ([[out, {"pooled_output": pooled}]], )
182
+
183
+
184
+ NODE_CLASS_MAPPINGS = {
185
+ "PhotoMakerLoader": PhotoMakerLoader,
186
+ "PhotoMakerEncode": PhotoMakerEncode,
187
+ }
188
+
comfy_extras/nodes_pixart.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from nodes import MAX_RESOLUTION
2
+
3
+ class CLIPTextEncodePixArtAlpha:
4
+ @classmethod
5
+ def INPUT_TYPES(s):
6
+ return {"required": {
7
+ "width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
8
+ "height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
9
+ # "aspect_ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
10
+ "text": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", ),
11
+ }}
12
+
13
+ RETURN_TYPES = ("CONDITIONING",)
14
+ FUNCTION = "encode"
15
+ CATEGORY = "advanced/conditioning"
16
+ DESCRIPTION = "Encodes text and sets the resolution conditioning for PixArt Alpha. Does not apply to PixArt Sigma."
17
+
18
+ def encode(self, clip, width, height, text):
19
+ tokens = clip.tokenize(text)
20
+ return (clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height}),)
21
+
22
+ NODE_CLASS_MAPPINGS = {
23
+ "CLIPTextEncodePixArtAlpha": CLIPTextEncodePixArtAlpha,
24
+ }
comfy_extras/nodes_post_processing.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from PIL import Image
5
+ import math
6
+
7
+ import comfy.utils
8
+ import comfy.model_management
9
+
10
+
11
+ class Blend:
12
+ def __init__(self):
13
+ pass
14
+
15
+ @classmethod
16
+ def INPUT_TYPES(s):
17
+ return {
18
+ "required": {
19
+ "image1": ("IMAGE",),
20
+ "image2": ("IMAGE",),
21
+ "blend_factor": ("FLOAT", {
22
+ "default": 0.5,
23
+ "min": 0.0,
24
+ "max": 1.0,
25
+ "step": 0.01
26
+ }),
27
+ "blend_mode": (["normal", "multiply", "screen", "overlay", "soft_light", "difference"],),
28
+ },
29
+ }
30
+
31
+ RETURN_TYPES = ("IMAGE",)
32
+ FUNCTION = "blend_images"
33
+
34
+ CATEGORY = "image/postprocessing"
35
+
36
+ def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str):
37
+ image2 = image2.to(image1.device)
38
+ if image1.shape != image2.shape:
39
+ image2 = image2.permute(0, 3, 1, 2)
40
+ image2 = comfy.utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic', crop='center')
41
+ image2 = image2.permute(0, 2, 3, 1)
42
+
43
+ blended_image = self.blend_mode(image1, image2, blend_mode)
44
+ blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor
45
+ blended_image = torch.clamp(blended_image, 0, 1)
46
+ return (blended_image,)
47
+
48
+ def blend_mode(self, img1, img2, mode):
49
+ if mode == "normal":
50
+ return img2
51
+ elif mode == "multiply":
52
+ return img1 * img2
53
+ elif mode == "screen":
54
+ return 1 - (1 - img1) * (1 - img2)
55
+ elif mode == "overlay":
56
+ return torch.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2))
57
+ elif mode == "soft_light":
58
+ return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (self.g(img1) - img1))
59
+ elif mode == "difference":
60
+ return img1 - img2
61
+ else:
62
+ raise ValueError(f"Unsupported blend mode: {mode}")
63
+
64
+ def g(self, x):
65
+ return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
66
+
67
+ def gaussian_kernel(kernel_size: int, sigma: float, device=None):
68
+ x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size, device=device), torch.linspace(-1, 1, kernel_size, device=device), indexing="ij")
69
+ d = torch.sqrt(x * x + y * y)
70
+ g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
71
+ return g / g.sum()
72
+
73
+ class Blur:
74
+ def __init__(self):
75
+ pass
76
+
77
+ @classmethod
78
+ def INPUT_TYPES(s):
79
+ return {
80
+ "required": {
81
+ "image": ("IMAGE",),
82
+ "blur_radius": ("INT", {
83
+ "default": 1,
84
+ "min": 1,
85
+ "max": 31,
86
+ "step": 1
87
+ }),
88
+ "sigma": ("FLOAT", {
89
+ "default": 1.0,
90
+ "min": 0.1,
91
+ "max": 10.0,
92
+ "step": 0.1
93
+ }),
94
+ },
95
+ }
96
+
97
+ RETURN_TYPES = ("IMAGE",)
98
+ FUNCTION = "blur"
99
+
100
+ CATEGORY = "image/postprocessing"
101
+
102
+ def blur(self, image: torch.Tensor, blur_radius: int, sigma: float):
103
+ if blur_radius == 0:
104
+ return (image,)
105
+
106
+ image = image.to(comfy.model_management.get_torch_device())
107
+ batch_size, height, width, channels = image.shape
108
+
109
+ kernel_size = blur_radius * 2 + 1
110
+ kernel = gaussian_kernel(kernel_size, sigma, device=image.device).repeat(channels, 1, 1).unsqueeze(1)
111
+
112
+ image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
113
+ padded_image = F.pad(image, (blur_radius,blur_radius,blur_radius,blur_radius), 'reflect')
114
+ blurred = F.conv2d(padded_image, kernel, padding=kernel_size // 2, groups=channels)[:,:,blur_radius:-blur_radius, blur_radius:-blur_radius]
115
+ blurred = blurred.permute(0, 2, 3, 1)
116
+
117
+ return (blurred.to(comfy.model_management.intermediate_device()),)
118
+
119
+ class Quantize:
120
+ def __init__(self):
121
+ pass
122
+
123
+ @classmethod
124
+ def INPUT_TYPES(s):
125
+ return {
126
+ "required": {
127
+ "image": ("IMAGE",),
128
+ "colors": ("INT", {
129
+ "default": 256,
130
+ "min": 1,
131
+ "max": 256,
132
+ "step": 1
133
+ }),
134
+ "dither": (["none", "floyd-steinberg", "bayer-2", "bayer-4", "bayer-8", "bayer-16"],),
135
+ },
136
+ }
137
+
138
+ RETURN_TYPES = ("IMAGE",)
139
+ FUNCTION = "quantize"
140
+
141
+ CATEGORY = "image/postprocessing"
142
+
143
+ def bayer(im, pal_im, order):
144
+ def normalized_bayer_matrix(n):
145
+ if n == 0:
146
+ return np.zeros((1,1), "float32")
147
+ else:
148
+ q = 4 ** n
149
+ m = q * normalized_bayer_matrix(n - 1)
150
+ return np.bmat(((m-1.5, m+0.5), (m+1.5, m-0.5))) / q
151
+
152
+ num_colors = len(pal_im.getpalette()) // 3
153
+ spread = 2 * 256 / num_colors
154
+ bayer_n = int(math.log2(order))
155
+ bayer_matrix = torch.from_numpy(spread * normalized_bayer_matrix(bayer_n) + 0.5)
156
+
157
+ result = torch.from_numpy(np.array(im).astype(np.float32))
158
+ tw = math.ceil(result.shape[0] / bayer_matrix.shape[0])
159
+ th = math.ceil(result.shape[1] / bayer_matrix.shape[1])
160
+ tiled_matrix = bayer_matrix.tile(tw, th).unsqueeze(-1)
161
+ result.add_(tiled_matrix[:result.shape[0],:result.shape[1]]).clamp_(0, 255)
162
+ result = result.to(dtype=torch.uint8)
163
+
164
+ im = Image.fromarray(result.cpu().numpy())
165
+ im = im.quantize(palette=pal_im, dither=Image.Dither.NONE)
166
+ return im
167
+
168
+ def quantize(self, image: torch.Tensor, colors: int, dither: str):
169
+ batch_size, height, width, _ = image.shape
170
+ result = torch.zeros_like(image)
171
+
172
+ for b in range(batch_size):
173
+ im = Image.fromarray((image[b] * 255).to(torch.uint8).numpy(), mode='RGB')
174
+
175
+ pal_im = im.quantize(colors=colors) # Required as described in https://github.com/python-pillow/Pillow/issues/5836
176
+
177
+ if dither == "none":
178
+ quantized_image = im.quantize(palette=pal_im, dither=Image.Dither.NONE)
179
+ elif dither == "floyd-steinberg":
180
+ quantized_image = im.quantize(palette=pal_im, dither=Image.Dither.FLOYDSTEINBERG)
181
+ elif dither.startswith("bayer"):
182
+ order = int(dither.split('-')[-1])
183
+ quantized_image = Quantize.bayer(im, pal_im, order)
184
+
185
+ quantized_array = torch.tensor(np.array(quantized_image.convert("RGB"))).float() / 255
186
+ result[b] = quantized_array
187
+
188
+ return (result,)
189
+
190
+ class Sharpen:
191
+ def __init__(self):
192
+ pass
193
+
194
+ @classmethod
195
+ def INPUT_TYPES(s):
196
+ return {
197
+ "required": {
198
+ "image": ("IMAGE",),
199
+ "sharpen_radius": ("INT", {
200
+ "default": 1,
201
+ "min": 1,
202
+ "max": 31,
203
+ "step": 1
204
+ }),
205
+ "sigma": ("FLOAT", {
206
+ "default": 1.0,
207
+ "min": 0.1,
208
+ "max": 10.0,
209
+ "step": 0.01
210
+ }),
211
+ "alpha": ("FLOAT", {
212
+ "default": 1.0,
213
+ "min": 0.0,
214
+ "max": 5.0,
215
+ "step": 0.01
216
+ }),
217
+ },
218
+ }
219
+
220
+ RETURN_TYPES = ("IMAGE",)
221
+ FUNCTION = "sharpen"
222
+
223
+ CATEGORY = "image/postprocessing"
224
+
225
+ def sharpen(self, image: torch.Tensor, sharpen_radius: int, sigma:float, alpha: float):
226
+ if sharpen_radius == 0:
227
+ return (image,)
228
+
229
+ batch_size, height, width, channels = image.shape
230
+ image = image.to(comfy.model_management.get_torch_device())
231
+
232
+ kernel_size = sharpen_radius * 2 + 1
233
+ kernel = gaussian_kernel(kernel_size, sigma, device=image.device) * -(alpha*10)
234
+ center = kernel_size // 2
235
+ kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
236
+ kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)
237
+
238
+ tensor_image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
239
+ tensor_image = F.pad(tensor_image, (sharpen_radius,sharpen_radius,sharpen_radius,sharpen_radius), 'reflect')
240
+ sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius]
241
+ sharpened = sharpened.permute(0, 2, 3, 1)
242
+
243
+ result = torch.clamp(sharpened, 0, 1)
244
+
245
+ return (result.to(comfy.model_management.intermediate_device()),)
246
+
247
+ class ImageScaleToTotalPixels:
248
+ upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
249
+ crop_methods = ["disabled", "center"]
250
+
251
+ @classmethod
252
+ def INPUT_TYPES(s):
253
+ return {"required": { "image": ("IMAGE",), "upscale_method": (s.upscale_methods,),
254
+ "megapixels": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 16.0, "step": 0.01}),
255
+ }}
256
+ RETURN_TYPES = ("IMAGE",)
257
+ FUNCTION = "upscale"
258
+
259
+ CATEGORY = "image/upscaling"
260
+
261
+ def upscale(self, image, upscale_method, megapixels):
262
+ samples = image.movedim(-1,1)
263
+ total = int(megapixels * 1024 * 1024)
264
+
265
+ scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
266
+ width = round(samples.shape[3] * scale_by)
267
+ height = round(samples.shape[2] * scale_by)
268
+
269
+ s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
270
+ s = s.movedim(1,-1)
271
+ return (s,)
272
+
273
+ NODE_CLASS_MAPPINGS = {
274
+ "ImageBlend": Blend,
275
+ "ImageBlur": Blur,
276
+ "ImageQuantize": Quantize,
277
+ "ImageSharpen": Sharpen,
278
+ "ImageScaleToTotalPixels": ImageScaleToTotalPixels,
279
+ }
comfy_extras/nodes_rebatch.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ class LatentRebatch:
4
+ @classmethod
5
+ def INPUT_TYPES(s):
6
+ return {"required": { "latents": ("LATENT",),
7
+ "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
8
+ }}
9
+ RETURN_TYPES = ("LATENT",)
10
+ INPUT_IS_LIST = True
11
+ OUTPUT_IS_LIST = (True, )
12
+
13
+ FUNCTION = "rebatch"
14
+
15
+ CATEGORY = "latent/batch"
16
+
17
+ @staticmethod
18
+ def get_batch(latents, list_ind, offset):
19
+ '''prepare a batch out of the list of latents'''
20
+ samples = latents[list_ind]['samples']
21
+ shape = samples.shape
22
+ mask = latents[list_ind]['noise_mask'] if 'noise_mask' in latents[list_ind] else torch.ones((shape[0], 1, shape[2]*8, shape[3]*8), device='cpu')
23
+ if mask.shape[-1] != shape[-1] * 8 or mask.shape[-2] != shape[-2]:
24
+ torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[-2]*8, shape[-1]*8), mode="bilinear")
25
+ if mask.shape[0] < samples.shape[0]:
26
+ mask = mask.repeat((shape[0] - 1) // mask.shape[0] + 1, 1, 1, 1)[:shape[0]]
27
+ if 'batch_index' in latents[list_ind]:
28
+ batch_inds = latents[list_ind]['batch_index']
29
+ else:
30
+ batch_inds = [x+offset for x in range(shape[0])]
31
+ return samples, mask, batch_inds
32
+
33
+ @staticmethod
34
+ def get_slices(indexable, num, batch_size):
35
+ '''divides an indexable object into num slices of length batch_size, and a remainder'''
36
+ slices = []
37
+ for i in range(num):
38
+ slices.append(indexable[i*batch_size:(i+1)*batch_size])
39
+ if num * batch_size < len(indexable):
40
+ return slices, indexable[num * batch_size:]
41
+ else:
42
+ return slices, None
43
+
44
+ @staticmethod
45
+ def slice_batch(batch, num, batch_size):
46
+ result = [LatentRebatch.get_slices(x, num, batch_size) for x in batch]
47
+ return list(zip(*result))
48
+
49
+ @staticmethod
50
+ def cat_batch(batch1, batch2):
51
+ if batch1[0] is None:
52
+ return batch2
53
+ result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)]
54
+ return result
55
+
56
+ def rebatch(self, latents, batch_size):
57
+ batch_size = batch_size[0]
58
+
59
+ output_list = []
60
+ current_batch = (None, None, None)
61
+ processed = 0
62
+
63
+ for i in range(len(latents)):
64
+ # fetch new entry of list
65
+ #samples, masks, indices = self.get_batch(latents, i)
66
+ next_batch = self.get_batch(latents, i, processed)
67
+ processed += len(next_batch[2])
68
+ # set to current if current is None
69
+ if current_batch[0] is None:
70
+ current_batch = next_batch
71
+ # add previous to list if dimensions do not match
72
+ elif next_batch[0].shape[-1] != current_batch[0].shape[-1] or next_batch[0].shape[-2] != current_batch[0].shape[-2]:
73
+ sliced, _ = self.slice_batch(current_batch, 1, batch_size)
74
+ output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
75
+ current_batch = next_batch
76
+ # cat if everything checks out
77
+ else:
78
+ current_batch = self.cat_batch(current_batch, next_batch)
79
+
80
+ # add to list if dimensions gone above target batch size
81
+ if current_batch[0].shape[0] > batch_size:
82
+ num = current_batch[0].shape[0] // batch_size
83
+ sliced, remainder = self.slice_batch(current_batch, num, batch_size)
84
+
85
+ for i in range(num):
86
+ output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]})
87
+
88
+ current_batch = remainder
89
+
90
+ #add remainder
91
+ if current_batch[0] is not None:
92
+ sliced, _ = self.slice_batch(current_batch, 1, batch_size)
93
+ output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
94
+
95
+ #get rid of empty masks
96
+ for s in output_list:
97
+ if s['noise_mask'].mean() == 1.0:
98
+ del s['noise_mask']
99
+
100
+ return (output_list,)
101
+
102
+ class ImageRebatch:
103
+ @classmethod
104
+ def INPUT_TYPES(s):
105
+ return {"required": { "images": ("IMAGE",),
106
+ "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
107
+ }}
108
+ RETURN_TYPES = ("IMAGE",)
109
+ INPUT_IS_LIST = True
110
+ OUTPUT_IS_LIST = (True, )
111
+
112
+ FUNCTION = "rebatch"
113
+
114
+ CATEGORY = "image/batch"
115
+
116
+ def rebatch(self, images, batch_size):
117
+ batch_size = batch_size[0]
118
+
119
+ output_list = []
120
+ all_images = []
121
+ for img in images:
122
+ for i in range(img.shape[0]):
123
+ all_images.append(img[i:i+1])
124
+
125
+ for i in range(0, len(all_images), batch_size):
126
+ output_list.append(torch.cat(all_images[i:i+batch_size], dim=0))
127
+
128
+ return (output_list,)
129
+
130
+ NODE_CLASS_MAPPINGS = {
131
+ "RebatchLatents": LatentRebatch,
132
+ "RebatchImages": ImageRebatch,
133
+ }
134
+
135
+ NODE_DISPLAY_NAME_MAPPINGS = {
136
+ "RebatchLatents": "Rebatch Latents",
137
+ "RebatchImages": "Rebatch Images",
138
+ }
comfy_extras/nodes_sag.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import einsum
3
+ import torch.nn.functional as F
4
+ import math
5
+
6
+ from einops import rearrange, repeat
7
+ from comfy.ldm.modules.attention import optimized_attention
8
+ import comfy.samplers
9
+
10
+ # from comfy/ldm/modules/attention.py
11
+ # but modified to return attention scores as well as output
12
+ def attention_basic_with_sim(q, k, v, heads, mask=None, attn_precision=None):
13
+ b, _, dim_head = q.shape
14
+ dim_head //= heads
15
+ scale = dim_head ** -0.5
16
+
17
+ h = heads
18
+ q, k, v = map(
19
+ lambda t: t.unsqueeze(3)
20
+ .reshape(b, -1, heads, dim_head)
21
+ .permute(0, 2, 1, 3)
22
+ .reshape(b * heads, -1, dim_head)
23
+ .contiguous(),
24
+ (q, k, v),
25
+ )
26
+
27
+ # force cast to fp32 to avoid overflowing
28
+ if attn_precision == torch.float32:
29
+ sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
30
+ else:
31
+ sim = einsum('b i d, b j d -> b i j', q, k) * scale
32
+
33
+ del q, k
34
+
35
+ if mask is not None:
36
+ mask = rearrange(mask, 'b ... -> b (...)')
37
+ max_neg_value = -torch.finfo(sim.dtype).max
38
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
39
+ sim.masked_fill_(~mask, max_neg_value)
40
+
41
+ # attention, what we cannot get enough of
42
+ sim = sim.softmax(dim=-1)
43
+
44
+ out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
45
+ out = (
46
+ out.unsqueeze(0)
47
+ .reshape(b, heads, -1, dim_head)
48
+ .permute(0, 2, 1, 3)
49
+ .reshape(b, -1, heads * dim_head)
50
+ )
51
+ return (out, sim)
52
+
53
+ def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
54
+ # reshape and GAP the attention map
55
+ _, hw1, hw2 = attn.shape
56
+ b, _, lh, lw = x0.shape
57
+ attn = attn.reshape(b, -1, hw1, hw2)
58
+ # Global Average Pool
59
+ mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
60
+
61
+ total = mask.shape[-1]
62
+ x = round(math.sqrt((lh / lw) * total))
63
+ xx = None
64
+ for i in range(0, math.floor(math.sqrt(total) / 2)):
65
+ for j in [(x + i), max(1, x - i)]:
66
+ if total % j == 0:
67
+ xx = j
68
+ break
69
+ if xx is not None:
70
+ break
71
+
72
+ x = xx
73
+ y = total // x
74
+
75
+ # Reshape
76
+ mask = (
77
+ mask.reshape(b, x, y)
78
+ .unsqueeze(1)
79
+ .type(attn.dtype)
80
+ )
81
+ # Upsample
82
+ mask = F.interpolate(mask, (lh, lw))
83
+
84
+ blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)
85
+ blurred = blurred * mask + x0 * (1 - mask)
86
+ return blurred
87
+
88
+ def gaussian_blur_2d(img, kernel_size, sigma):
89
+ ksize_half = (kernel_size - 1) * 0.5
90
+
91
+ x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
92
+
93
+ pdf = torch.exp(-0.5 * (x / sigma).pow(2))
94
+
95
+ x_kernel = pdf / pdf.sum()
96
+ x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)
97
+
98
+ kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
99
+ kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
100
+
101
+ padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
102
+
103
+ img = F.pad(img, padding, mode="reflect")
104
+ img = F.conv2d(img, kernel2d, groups=img.shape[-3])
105
+ return img
106
+
107
+ class SelfAttentionGuidance:
108
+ @classmethod
109
+ def INPUT_TYPES(s):
110
+ return {"required": { "model": ("MODEL",),
111
+ "scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.01}),
112
+ "blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.1}),
113
+ }}
114
+ RETURN_TYPES = ("MODEL",)
115
+ FUNCTION = "patch"
116
+
117
+ CATEGORY = "_for_testing"
118
+
119
+ def patch(self, model, scale, blur_sigma):
120
+ m = model.clone()
121
+
122
+ attn_scores = None
123
+
124
+ # TODO: make this work properly with chunked batches
125
+ # currently, we can only save the attn from one UNet call
126
+ def attn_and_record(q, k, v, extra_options):
127
+ nonlocal attn_scores
128
+ # if uncond, save the attention scores
129
+ heads = extra_options["n_heads"]
130
+ cond_or_uncond = extra_options["cond_or_uncond"]
131
+ b = q.shape[0] // len(cond_or_uncond)
132
+ if 1 in cond_or_uncond:
133
+ uncond_index = cond_or_uncond.index(1)
134
+ # do the entire attention operation, but save the attention scores to attn_scores
135
+ (out, sim) = attention_basic_with_sim(q, k, v, heads=heads, attn_precision=extra_options["attn_precision"])
136
+ # when using a higher batch size, I BELIEVE the result batch dimension is [uc1, ... ucn, c1, ... cn]
137
+ n_slices = heads * b
138
+ attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index+1)]
139
+ return out
140
+ else:
141
+ return optimized_attention(q, k, v, heads=heads, attn_precision=extra_options["attn_precision"])
142
+
143
+ def post_cfg_function(args):
144
+ nonlocal attn_scores
145
+ uncond_attn = attn_scores
146
+
147
+ sag_scale = scale
148
+ sag_sigma = blur_sigma
149
+ sag_threshold = 1.0
150
+ model = args["model"]
151
+ uncond_pred = args["uncond_denoised"]
152
+ uncond = args["uncond"]
153
+ cfg_result = args["denoised"]
154
+ sigma = args["sigma"]
155
+ model_options = args["model_options"]
156
+ x = args["input"]
157
+ if min(cfg_result.shape[2:]) <= 4: #skip when too small to add padding
158
+ return cfg_result
159
+
160
+ # create the adversarially blurred image
161
+ degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold)
162
+ degraded_noised = degraded + x - uncond_pred
163
+ # call into the UNet
164
+ (sag,) = comfy.samplers.calc_cond_batch(model, [uncond], degraded_noised, sigma, model_options)
165
+ return cfg_result + (degraded - sag) * sag_scale
166
+
167
+ m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=True)
168
+
169
+ # from diffusers:
170
+ # unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch
171
+ m.set_model_attn1_replace(attn_and_record, "middle", 0, 0)
172
+
173
+ return (m, )
174
+
175
+ NODE_CLASS_MAPPINGS = {
176
+ "SelfAttentionGuidance": SelfAttentionGuidance,
177
+ }
178
+
179
+ NODE_DISPLAY_NAME_MAPPINGS = {
180
+ "SelfAttentionGuidance": "Self-Attention Guidance",
181
+ }
comfy_extras/nodes_sd3.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import folder_paths
2
+ import comfy.sd
3
+ import comfy.model_management
4
+ import nodes
5
+ import torch
6
+ import comfy_extras.nodes_slg
7
+
8
+
9
+ class TripleCLIPLoader:
10
+ @classmethod
11
+ def INPUT_TYPES(s):
12
+ return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), "clip_name3": (folder_paths.get_filename_list("text_encoders"), )
13
+ }}
14
+ RETURN_TYPES = ("CLIP",)
15
+ FUNCTION = "load_clip"
16
+
17
+ CATEGORY = "advanced/loaders"
18
+
19
+ DESCRIPTION = "[Recipes]\n\nsd3: clip-l, clip-g, t5"
20
+
21
+ def load_clip(self, clip_name1, clip_name2, clip_name3):
22
+ clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
23
+ clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
24
+ clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3)
25
+ clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
26
+ return (clip,)
27
+
28
+
29
+ class EmptySD3LatentImage:
30
+ def __init__(self):
31
+ self.device = comfy.model_management.intermediate_device()
32
+
33
+ @classmethod
34
+ def INPUT_TYPES(s):
35
+ return {"required": { "width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
36
+ "height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
37
+ "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
38
+ RETURN_TYPES = ("LATENT",)
39
+ FUNCTION = "generate"
40
+
41
+ CATEGORY = "latent/sd3"
42
+
43
+ def generate(self, width, height, batch_size=1):
44
+ latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=self.device)
45
+ return ({"samples":latent}, )
46
+
47
+
48
+ class CLIPTextEncodeSD3:
49
+ @classmethod
50
+ def INPUT_TYPES(s):
51
+ return {"required": {
52
+ "clip": ("CLIP", ),
53
+ "clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}),
54
+ "clip_g": ("STRING", {"multiline": True, "dynamicPrompts": True}),
55
+ "t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
56
+ "empty_padding": (["none", "empty_prompt"], )
57
+ }}
58
+ RETURN_TYPES = ("CONDITIONING",)
59
+ FUNCTION = "encode"
60
+
61
+ CATEGORY = "advanced/conditioning"
62
+
63
+ def encode(self, clip, clip_l, clip_g, t5xxl, empty_padding):
64
+ no_padding = empty_padding == "none"
65
+
66
+ tokens = clip.tokenize(clip_g)
67
+ if len(clip_g) == 0 and no_padding:
68
+ tokens["g"] = []
69
+
70
+ if len(clip_l) == 0 and no_padding:
71
+ tokens["l"] = []
72
+ else:
73
+ tokens["l"] = clip.tokenize(clip_l)["l"]
74
+
75
+ if len(t5xxl) == 0 and no_padding:
76
+ tokens["t5xxl"] = []
77
+ else:
78
+ tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
79
+ if len(tokens["l"]) != len(tokens["g"]):
80
+ empty = clip.tokenize("")
81
+ while len(tokens["l"]) < len(tokens["g"]):
82
+ tokens["l"] += empty["l"]
83
+ while len(tokens["l"]) > len(tokens["g"]):
84
+ tokens["g"] += empty["g"]
85
+ return (clip.encode_from_tokens_scheduled(tokens), )
86
+
87
+
88
+ class ControlNetApplySD3(nodes.ControlNetApplyAdvanced):
89
+ @classmethod
90
+ def INPUT_TYPES(s):
91
+ return {"required": {"positive": ("CONDITIONING", ),
92
+ "negative": ("CONDITIONING", ),
93
+ "control_net": ("CONTROL_NET", ),
94
+ "vae": ("VAE", ),
95
+ "image": ("IMAGE", ),
96
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
97
+ "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
98
+ "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
99
+ }}
100
+ CATEGORY = "conditioning/controlnet"
101
+ DEPRECATED = True
102
+
103
+
104
+ class SkipLayerGuidanceSD3(comfy_extras.nodes_slg.SkipLayerGuidanceDiT):
105
+ '''
106
+ Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers.
107
+ Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
108
+ Experimental implementation by Dango233@StabilityAI.
109
+ '''
110
+ @classmethod
111
+ def INPUT_TYPES(s):
112
+ return {"required": {"model": ("MODEL", ),
113
+ "layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
114
+ "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}),
115
+ "start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}),
116
+ "end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001})
117
+ }}
118
+ RETURN_TYPES = ("MODEL",)
119
+ FUNCTION = "skip_guidance_sd3"
120
+
121
+ CATEGORY = "advanced/guidance"
122
+
123
+ def skip_guidance_sd3(self, model, layers, scale, start_percent, end_percent):
124
+ return self.skip_guidance(model=model, scale=scale, start_percent=start_percent, end_percent=end_percent, double_layers=layers)
125
+
126
+
127
+ NODE_CLASS_MAPPINGS = {
128
+ "TripleCLIPLoader": TripleCLIPLoader,
129
+ "EmptySD3LatentImage": EmptySD3LatentImage,
130
+ "CLIPTextEncodeSD3": CLIPTextEncodeSD3,
131
+ "ControlNetApplySD3": ControlNetApplySD3,
132
+ "SkipLayerGuidanceSD3": SkipLayerGuidanceSD3,
133
+ }
134
+
135
+ NODE_DISPLAY_NAME_MAPPINGS = {
136
+ # Sampling
137
+ "ControlNetApplySD3": "Apply Controlnet with VAE",
138
+ }
comfy_extras/nodes_sdupscale.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import comfy.utils
3
+
4
+ class SD_4XUpscale_Conditioning:
5
+ @classmethod
6
+ def INPUT_TYPES(s):
7
+ return {"required": { "images": ("IMAGE",),
8
+ "positive": ("CONDITIONING",),
9
+ "negative": ("CONDITIONING",),
10
+ "scale_ratio": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 10.0, "step": 0.01}),
11
+ "noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
12
+ }}
13
+ RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
14
+ RETURN_NAMES = ("positive", "negative", "latent")
15
+
16
+ FUNCTION = "encode"
17
+
18
+ CATEGORY = "conditioning/upscale_diffusion"
19
+
20
+ def encode(self, images, positive, negative, scale_ratio, noise_augmentation):
21
+ width = max(1, round(images.shape[-2] * scale_ratio))
22
+ height = max(1, round(images.shape[-3] * scale_ratio))
23
+
24
+ pixels = comfy.utils.common_upscale((images.movedim(-1,1) * 2.0) - 1.0, width // 4, height // 4, "bilinear", "center")
25
+
26
+ out_cp = []
27
+ out_cn = []
28
+
29
+ for t in positive:
30
+ n = [t[0], t[1].copy()]
31
+ n[1]['concat_image'] = pixels
32
+ n[1]['noise_augmentation'] = noise_augmentation
33
+ out_cp.append(n)
34
+
35
+ for t in negative:
36
+ n = [t[0], t[1].copy()]
37
+ n[1]['concat_image'] = pixels
38
+ n[1]['noise_augmentation'] = noise_augmentation
39
+ out_cn.append(n)
40
+
41
+ latent = torch.zeros([images.shape[0], 4, height // 4, width // 4])
42
+ return (out_cp, out_cn, {"samples":latent})
43
+
44
+ NODE_CLASS_MAPPINGS = {
45
+ "SD_4XUpscale_Conditioning": SD_4XUpscale_Conditioning,
46
+ }
comfy_extras/nodes_slg.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import comfy.model_patcher
2
+ import comfy.samplers
3
+ import re
4
+
5
+
6
+ class SkipLayerGuidanceDiT:
7
+ '''
8
+ Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers.
9
+ Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
10
+ Original experimental implementation for SD3 by Dango233@StabilityAI.
11
+ '''
12
+ @classmethod
13
+ def INPUT_TYPES(s):
14
+ return {"required": {"model": ("MODEL", ),
15
+ "double_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
16
+ "single_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
17
+ "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}),
18
+ "start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}),
19
+ "end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}),
20
+ "rescaling_scale": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01}),
21
+ }}
22
+ RETURN_TYPES = ("MODEL",)
23
+ FUNCTION = "skip_guidance"
24
+ EXPERIMENTAL = True
25
+
26
+ DESCRIPTION = "Generic version of SkipLayerGuidance node that can be used on every DiT model."
27
+
28
+ CATEGORY = "advanced/guidance"
29
+
30
+ def skip_guidance(self, model, scale, start_percent, end_percent, double_layers="", single_layers="", rescaling_scale=0):
31
+ # check if layer is comma separated integers
32
+ def skip(args, extra_args):
33
+ return args
34
+
35
+ model_sampling = model.get_model_object("model_sampling")
36
+ sigma_start = model_sampling.percent_to_sigma(start_percent)
37
+ sigma_end = model_sampling.percent_to_sigma(end_percent)
38
+
39
+ double_layers = re.findall(r'\d+', double_layers)
40
+ double_layers = [int(i) for i in double_layers]
41
+
42
+ single_layers = re.findall(r'\d+', single_layers)
43
+ single_layers = [int(i) for i in single_layers]
44
+
45
+ if len(double_layers) == 0 and len(single_layers) == 0:
46
+ return (model, )
47
+
48
+ def post_cfg_function(args):
49
+ model = args["model"]
50
+ cond_pred = args["cond_denoised"]
51
+ cond = args["cond"]
52
+ cfg_result = args["denoised"]
53
+ sigma = args["sigma"]
54
+ x = args["input"]
55
+ model_options = args["model_options"].copy()
56
+
57
+ for layer in double_layers:
58
+ model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, skip, "dit", "double_block", layer)
59
+
60
+ for layer in single_layers:
61
+ model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, skip, "dit", "single_block", layer)
62
+
63
+ model_sampling.percent_to_sigma(start_percent)
64
+
65
+ sigma_ = sigma[0].item()
66
+ if scale > 0 and sigma_ >= sigma_end and sigma_ <= sigma_start:
67
+ (slg,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options)
68
+ cfg_result = cfg_result + (cond_pred - slg) * scale
69
+ if rescaling_scale != 0:
70
+ factor = cond_pred.std() / cfg_result.std()
71
+ factor = rescaling_scale * factor + (1 - rescaling_scale)
72
+ cfg_result *= factor
73
+
74
+ return cfg_result
75
+
76
+ m = model.clone()
77
+ m.set_model_sampler_post_cfg_function(post_cfg_function)
78
+
79
+ return (m, )
80
+
81
+
82
+ NODE_CLASS_MAPPINGS = {
83
+ "SkipLayerGuidanceDiT": SkipLayerGuidanceDiT,
84
+ }
comfy_extras/nodes_stable3d.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import nodes
3
+ import comfy.utils
4
+
5
+ def camera_embeddings(elevation, azimuth):
6
+ elevation = torch.as_tensor([elevation])
7
+ azimuth = torch.as_tensor([azimuth])
8
+ embeddings = torch.stack(
9
+ [
10
+ torch.deg2rad(
11
+ (90 - elevation) - (90)
12
+ ), # Zero123 polar is 90-elevation
13
+ torch.sin(torch.deg2rad(azimuth)),
14
+ torch.cos(torch.deg2rad(azimuth)),
15
+ torch.deg2rad(
16
+ 90 - torch.full_like(elevation, 0)
17
+ ),
18
+ ], dim=-1).unsqueeze(1)
19
+
20
+ return embeddings
21
+
22
+
23
+ class StableZero123_Conditioning:
24
+ @classmethod
25
+ def INPUT_TYPES(s):
26
+ return {"required": { "clip_vision": ("CLIP_VISION",),
27
+ "init_image": ("IMAGE",),
28
+ "vae": ("VAE",),
29
+ "width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
30
+ "height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
31
+ "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
32
+ "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
33
+ "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
34
+ }}
35
+ RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
36
+ RETURN_NAMES = ("positive", "negative", "latent")
37
+
38
+ FUNCTION = "encode"
39
+
40
+ CATEGORY = "conditioning/3d_models"
41
+
42
+ def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth):
43
+ output = clip_vision.encode_image(init_image)
44
+ pooled = output.image_embeds.unsqueeze(0)
45
+ pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
46
+ encode_pixels = pixels[:,:,:,:3]
47
+ t = vae.encode(encode_pixels)
48
+ cam_embeds = camera_embeddings(elevation, azimuth)
49
+ cond = torch.cat([pooled, cam_embeds.to(pooled.device).repeat((pooled.shape[0], 1, 1))], dim=-1)
50
+
51
+ positive = [[cond, {"concat_latent_image": t}]]
52
+ negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]]
53
+ latent = torch.zeros([batch_size, 4, height // 8, width // 8])
54
+ return (positive, negative, {"samples":latent})
55
+
56
+ class StableZero123_Conditioning_Batched:
57
+ @classmethod
58
+ def INPUT_TYPES(s):
59
+ return {"required": { "clip_vision": ("CLIP_VISION",),
60
+ "init_image": ("IMAGE",),
61
+ "vae": ("VAE",),
62
+ "width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
63
+ "height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
64
+ "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
65
+ "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
66
+ "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
67
+ "elevation_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
68
+ "azimuth_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
69
+ }}
70
+ RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
71
+ RETURN_NAMES = ("positive", "negative", "latent")
72
+
73
+ FUNCTION = "encode"
74
+
75
+ CATEGORY = "conditioning/3d_models"
76
+
77
+ def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth, elevation_batch_increment, azimuth_batch_increment):
78
+ output = clip_vision.encode_image(init_image)
79
+ pooled = output.image_embeds.unsqueeze(0)
80
+ pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
81
+ encode_pixels = pixels[:,:,:,:3]
82
+ t = vae.encode(encode_pixels)
83
+
84
+ cam_embeds = []
85
+ for i in range(batch_size):
86
+ cam_embeds.append(camera_embeddings(elevation, azimuth))
87
+ elevation += elevation_batch_increment
88
+ azimuth += azimuth_batch_increment
89
+
90
+ cam_embeds = torch.cat(cam_embeds, dim=0)
91
+ cond = torch.cat([comfy.utils.repeat_to_batch_size(pooled, batch_size), cam_embeds], dim=-1)
92
+
93
+ positive = [[cond, {"concat_latent_image": t}]]
94
+ negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]]
95
+ latent = torch.zeros([batch_size, 4, height // 8, width // 8])
96
+ return (positive, negative, {"samples":latent, "batch_index": [0] * batch_size})
97
+
98
+ class SV3D_Conditioning:
99
+ @classmethod
100
+ def INPUT_TYPES(s):
101
+ return {"required": { "clip_vision": ("CLIP_VISION",),
102
+ "init_image": ("IMAGE",),
103
+ "vae": ("VAE",),
104
+ "width": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
105
+ "height": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
106
+ "video_frames": ("INT", {"default": 21, "min": 1, "max": 4096}),
107
+ "elevation": ("FLOAT", {"default": 0.0, "min": -90.0, "max": 90.0, "step": 0.1, "round": False}),
108
+ }}
109
+ RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
110
+ RETURN_NAMES = ("positive", "negative", "latent")
111
+
112
+ FUNCTION = "encode"
113
+
114
+ CATEGORY = "conditioning/3d_models"
115
+
116
+ def encode(self, clip_vision, init_image, vae, width, height, video_frames, elevation):
117
+ output = clip_vision.encode_image(init_image)
118
+ pooled = output.image_embeds.unsqueeze(0)
119
+ pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
120
+ encode_pixels = pixels[:,:,:,:3]
121
+ t = vae.encode(encode_pixels)
122
+
123
+ azimuth = 0
124
+ azimuth_increment = 360 / (max(video_frames, 2) - 1)
125
+
126
+ elevations = []
127
+ azimuths = []
128
+ for i in range(video_frames):
129
+ elevations.append(elevation)
130
+ azimuths.append(azimuth)
131
+ azimuth += azimuth_increment
132
+
133
+ positive = [[pooled, {"concat_latent_image": t, "elevation": elevations, "azimuth": azimuths}]]
134
+ negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t), "elevation": elevations, "azimuth": azimuths}]]
135
+ latent = torch.zeros([video_frames, 4, height // 8, width // 8])
136
+ return (positive, negative, {"samples":latent})
137
+
138
+
139
+ NODE_CLASS_MAPPINGS = {
140
+ "StableZero123_Conditioning": StableZero123_Conditioning,
141
+ "StableZero123_Conditioning_Batched": StableZero123_Conditioning_Batched,
142
+ "SV3D_Conditioning": SV3D_Conditioning,
143
+ }