MaykaGR committed
Commit dcda5cf · verified · 1 Parent(s): 8a90622

Upload 4 files

Files changed (4)
  1. execution.py +994 -0
  2. main.py +304 -0
  3. nodes.py +2262 -0
  4. server.py +863 -0
execution.py ADDED
@@ -0,0 +1,994 @@
+ import sys
+ import copy
+ import logging
+ import threading
+ import heapq
+ import time
+ import traceback
+ from enum import Enum
+ import inspect
+ from typing import List, Literal, NamedTuple, Optional
+
+ import torch
+ import nodes
+
+ import comfy.model_management
+ from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
+ from comfy_execution.graph_utils import is_link, GraphBuilder
+ from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
+ from comfy_execution.validation import validate_node_input
+
+ class ExecutionResult(Enum):
+     SUCCESS = 0
+     FAILURE = 1
+     PENDING = 2
+
+ class DuplicateNodeError(Exception):
+     pass
+
+ class IsChangedCache:
+     def __init__(self, dynprompt, outputs_cache):
+         self.dynprompt = dynprompt
+         self.outputs_cache = outputs_cache
+         self.is_changed = {}
+
+     def get(self, node_id):
+         if node_id in self.is_changed:
+             return self.is_changed[node_id]
+
+         node = self.dynprompt.get_node(node_id)
+         class_type = node["class_type"]
+         class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
+         if not hasattr(class_def, "IS_CHANGED"):
+             self.is_changed[node_id] = False
+             return self.is_changed[node_id]
+
+         if "is_changed" in node:
+             self.is_changed[node_id] = node["is_changed"]
+             return self.is_changed[node_id]
+
+         # Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
+         input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, None)
+         try:
+             is_changed = _map_node_over_list(class_def, input_data_all, "IS_CHANGED")
+             node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
+         except Exception as e:
+             logging.warning("WARNING: {}".format(e))
+             node["is_changed"] = float("NaN")
+         finally:
+             self.is_changed[node_id] = node["is_changed"]
+         return self.is_changed[node_id]
+
+ class CacheSet:
+     def __init__(self, lru_size=None):
+         if lru_size is None or lru_size == 0:
+             self.init_classic_cache()
+         else:
+             self.init_lru_cache(lru_size)
+         self.all = [self.outputs, self.ui, self.objects]
+
+     # Useful for those with ample RAM/VRAM -- allows experimenting without
+     # blowing away the cache every time
+     def init_lru_cache(self, cache_size):
+         self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
+         self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
+         self.objects = HierarchicalCache(CacheKeySetID)
+
+     # Performs like the old cache -- dump data ASAP
+     def init_classic_cache(self):
+         self.outputs = HierarchicalCache(CacheKeySetInputSignature)
+         self.ui = HierarchicalCache(CacheKeySetInputSignature)
+         self.objects = HierarchicalCache(CacheKeySetID)
+
+     def recursive_debug_dump(self):
+         result = {
+             "outputs": self.outputs.recursive_debug_dump(),
+             "ui": self.ui.recursive_debug_dump(),
+         }
+         return result
+
+ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}):
+     valid_inputs = class_def.INPUT_TYPES()
+     input_data_all = {}
+     missing_keys = {}
+     for x in inputs:
+         input_data = inputs[x]
+         input_type, input_category, input_info = get_input_info(class_def, x, valid_inputs)
+         def mark_missing():
+             missing_keys[x] = True
+             input_data_all[x] = (None,)
+         if is_link(input_data) and (not input_info or not input_info.get("rawLink", False)):
+             input_unique_id = input_data[0]
+             output_index = input_data[1]
+             if outputs is None:
+                 mark_missing()
+                 continue # This might be a lazily-evaluated input
+             cached_output = outputs.get(input_unique_id)
+             if cached_output is None:
+                 mark_missing()
+                 continue
+             if output_index >= len(cached_output):
+                 mark_missing()
+                 continue
+             obj = cached_output[output_index]
+             input_data_all[x] = obj
+         elif input_category is not None:
+             input_data_all[x] = [input_data]
+
+     if "hidden" in valid_inputs:
+         h = valid_inputs["hidden"]
+         for x in h:
+             if h[x] == "PROMPT":
+                 input_data_all[x] = [dynprompt.get_original_prompt() if dynprompt is not None else {}]
+             if h[x] == "DYNPROMPT":
+                 input_data_all[x] = [dynprompt]
+             if h[x] == "EXTRA_PNGINFO":
+                 input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
+             if h[x] == "UNIQUE_ID":
+                 input_data_all[x] = [unique_id]
+     return input_data_all, missing_keys
+
+ map_node_over_list = None #Don't hook this please
+
+ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
+     # check if node wants the lists
+     input_is_list = getattr(obj, "INPUT_IS_LIST", False)
+
+     if len(input_data_all) == 0:
+         max_len_input = 0
+     else:
+         max_len_input = max(len(x) for x in input_data_all.values())
+
+     # get a slice of inputs, repeat last input when list isn't long enough
+     def slice_dict(d, i):
+         return {k: v[i if len(v) > i else -1] for k, v in d.items()}
+
+     results = []
+     def process_inputs(inputs, index=None, input_is_list=False):
+         if allow_interrupt:
+             nodes.before_node_execution()
+         execution_block = None
+         for k, v in inputs.items():
+             if input_is_list:
+                 for e in v:
+                     if isinstance(e, ExecutionBlocker):
+                         v = e
+                         break
+             if isinstance(v, ExecutionBlocker):
+                 execution_block = execution_block_cb(v) if execution_block_cb else v
+                 break
+         if execution_block is None:
+             if pre_execute_cb is not None and index is not None:
+                 pre_execute_cb(index)
+             results.append(getattr(obj, func)(**inputs))
+         else:
+             results.append(execution_block)
+
+     if input_is_list:
+         process_inputs(input_data_all, 0, input_is_list=input_is_list)
+     elif max_len_input == 0:
+         process_inputs({})
+     else:
+         for i in range(max_len_input):
+             input_dict = slice_dict(input_data_all, i)
+             process_inputs(input_dict, i)
+     return results
+
+ def merge_result_data(results, obj):
+     # check which outputs need concatenating
+     output = []
+     output_is_list = [False] * len(results[0])
+     if hasattr(obj, "OUTPUT_IS_LIST"):
+         output_is_list = obj.OUTPUT_IS_LIST
+
+     # merge node execution results
+     for i, is_list in zip(range(len(results[0])), output_is_list):
+         if is_list:
+             value = []
+             for o in results:
+                 if isinstance(o[i], ExecutionBlocker):
+                     value.append(o[i])
+                 else:
+                     value.extend(o[i])
+             output.append(value)
+         else:
+             output.append([o[i] for o in results])
+     return output
+
+ def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
+     results = []
+     uis = []
+     subgraph_results = []
+     return_values = _map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
+     has_subgraph = False
+     for i in range(len(return_values)):
+         r = return_values[i]
+         if isinstance(r, dict):
+             if 'ui' in r:
+                 uis.append(r['ui'])
+             if 'expand' in r:
+                 # Perform an expansion, but do not append results
+                 has_subgraph = True
+                 new_graph = r['expand']
+                 result = r.get("result", None)
+                 if isinstance(result, ExecutionBlocker):
+                     result = tuple([result] * len(obj.RETURN_TYPES))
+                 subgraph_results.append((new_graph, result))
+             elif 'result' in r:
+                 result = r.get("result", None)
+                 if isinstance(result, ExecutionBlocker):
+                     result = tuple([result] * len(obj.RETURN_TYPES))
+                 results.append(result)
+                 subgraph_results.append((None, result))
+         else:
+             if isinstance(r, ExecutionBlocker):
+                 r = tuple([r] * len(obj.RETURN_TYPES))
+             results.append(r)
+             subgraph_results.append((None, r))
+
+     if has_subgraph:
+         output = subgraph_results
+     elif len(results) > 0:
+         output = merge_result_data(results, obj)
+     else:
+         output = []
+     ui = dict()
+     if len(uis) > 0:
+         ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
+     return output, ui, has_subgraph
+
+ def format_value(x):
+     if x is None:
+         return None
+     elif isinstance(x, (int, float, bool, str)):
+         return x
+     else:
+         return str(x)
+
+ def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results):
+     unique_id = current_item
+     real_node_id = dynprompt.get_real_node_id(unique_id)
+     display_node_id = dynprompt.get_display_node_id(unique_id)
+     parent_node_id = dynprompt.get_parent_node_id(unique_id)
+     inputs = dynprompt.get_node(unique_id)['inputs']
+     class_type = dynprompt.get_node(unique_id)['class_type']
+     class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
+     if caches.outputs.get(unique_id) is not None:
+         if server.client_id is not None:
+             cached_output = caches.ui.get(unique_id) or {}
+             server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
+         return (ExecutionResult.SUCCESS, None, None)
+
+     input_data_all = None
+     try:
+         if unique_id in pending_subgraph_results:
+             cached_results = pending_subgraph_results[unique_id]
+             resolved_outputs = []
+             for is_subgraph, result in cached_results:
+                 if not is_subgraph:
+                     resolved_outputs.append(result)
+                 else:
+                     resolved_output = []
+                     for r in result:
+                         if is_link(r):
+                             source_node, source_output = r[0], r[1]
+                             node_output = caches.outputs.get(source_node)[source_output]
+                             for o in node_output:
+                                 resolved_output.append(o)
+
+                         else:
+                             resolved_output.append(r)
+                     resolved_outputs.append(tuple(resolved_output))
+             output_data = merge_result_data(resolved_outputs, class_def)
+             output_ui = []
+             has_subgraph = False
+         else:
+             input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
+             if server.client_id is not None:
+                 server.last_node_id = display_node_id
+                 server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
+
+             obj = caches.objects.get(unique_id)
+             if obj is None:
+                 obj = class_def()
+                 caches.objects.set(unique_id, obj)
+
+             if hasattr(obj, "check_lazy_status"):
+                 required_inputs = _map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True)
+                 required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
+                 required_inputs = [x for x in required_inputs if isinstance(x,str) and (
+                     x not in input_data_all or x in missing_keys
+                 )]
+                 if len(required_inputs) > 0:
+                     for i in required_inputs:
+                         execution_list.make_input_strong_link(unique_id, i)
+                     return (ExecutionResult.PENDING, None, None)
+
+             def execution_block_cb(block):
+                 if block.message is not None:
+                     mes = {
+                         "prompt_id": prompt_id,
+                         "node_id": unique_id,
+                         "node_type": class_type,
+                         "executed": list(executed),
+
+                         "exception_message": f"Execution Blocked: {block.message}",
+                         "exception_type": "ExecutionBlocked",
+                         "traceback": [],
+                         "current_inputs": [],
+                         "current_outputs": [],
+                     }
+                     server.send_sync("execution_error", mes, server.client_id)
+                     return ExecutionBlocker(None)
+                 else:
+                     return block
+             def pre_execute_cb(call_index):
+                 GraphBuilder.set_default_prefix(unique_id, call_index, 0)
+             output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
+         if len(output_ui) > 0:
+             caches.ui.set(unique_id, {
+                 "meta": {
+                     "node_id": unique_id,
+                     "display_node": display_node_id,
+                     "parent_node": parent_node_id,
+                     "real_node_id": real_node_id,
+                 },
+                 "output": output_ui
+             })
+             if server.client_id is not None:
+                 server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
+         if has_subgraph:
+             cached_outputs = []
+             new_node_ids = []
+             new_output_ids = []
+             new_output_links = []
+             for i in range(len(output_data)):
+                 new_graph, node_outputs = output_data[i]
+                 if new_graph is None:
+                     cached_outputs.append((False, node_outputs))
+                 else:
+                     # Check for conflicts
+                     for node_id in new_graph.keys():
+                         if dynprompt.has_node(node_id):
+                             raise DuplicateNodeError(f"Attempt to add duplicate node {node_id}. Ensure node ids are unique and deterministic or use graph_utils.GraphBuilder.")
+                     for node_id, node_info in new_graph.items():
+                         new_node_ids.append(node_id)
+                         display_id = node_info.get("override_display_id", unique_id)
+                         dynprompt.add_ephemeral_node(node_id, node_info, unique_id, display_id)
+                         # Figure out if the newly created node is an output node
+                         class_type = node_info["class_type"]
+                         class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
+                         if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
+                             new_output_ids.append(node_id)
+                     for i in range(len(node_outputs)):
+                         if is_link(node_outputs[i]):
+                             from_node_id, from_socket = node_outputs[i][0], node_outputs[i][1]
+                             new_output_links.append((from_node_id, from_socket))
+                     cached_outputs.append((True, node_outputs))
+             new_node_ids = set(new_node_ids)
+             for cache in caches.all:
+                 cache.ensure_subcache_for(unique_id, new_node_ids).clean_unused()
+             for node_id in new_output_ids:
+                 execution_list.add_node(node_id)
+             for link in new_output_links:
+                 execution_list.add_strong_link(link[0], link[1], unique_id)
+             pending_subgraph_results[unique_id] = cached_outputs
+             return (ExecutionResult.PENDING, None, None)
+         caches.outputs.set(unique_id, output_data)
+     except comfy.model_management.InterruptProcessingException as iex:
+         logging.info("Processing interrupted")
+
+         # skip formatting inputs/outputs
+         error_details = {
+             "node_id": real_node_id,
+         }
+
+         return (ExecutionResult.FAILURE, error_details, iex)
+     except Exception as ex:
+         typ, _, tb = sys.exc_info()
+         exception_type = full_type_name(typ)
+         input_data_formatted = {}
+         if input_data_all is not None:
+             input_data_formatted = {}
+             for name, inputs in input_data_all.items():
+                 input_data_formatted[name] = [format_value(x) for x in inputs]
+
+         logging.error(f"!!! Exception during processing !!! {ex}")
+         logging.error(traceback.format_exc())
+
+         error_details = {
+             "node_id": real_node_id,
+             "exception_message": str(ex),
+             "exception_type": exception_type,
+             "traceback": traceback.format_tb(tb),
+             "current_inputs": input_data_formatted
+         }
+         if isinstance(ex, comfy.model_management.OOM_EXCEPTION):
+             logging.error("Got an OOM, unloading all loaded models.")
+             comfy.model_management.unload_all_models()
+
+         return (ExecutionResult.FAILURE, error_details, ex)
+
+     executed.add(unique_id)
+
+     return (ExecutionResult.SUCCESS, None, None)
+
+ class PromptExecutor:
+     def __init__(self, server, lru_size=None):
+         self.lru_size = lru_size
+         self.server = server
+         self.reset()
+
+     def reset(self):
+         self.caches = CacheSet(self.lru_size)
+         self.status_messages = []
+         self.success = True
+
+     def add_message(self, event, data: dict, broadcast: bool):
+         data = {
+             **data,
+             "timestamp": int(time.time() * 1000),
+         }
+         self.status_messages.append((event, data))
+         if self.server.client_id is not None or broadcast:
+             self.server.send_sync(event, data, self.server.client_id)
+
+     def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
+         node_id = error["node_id"]
+         class_type = prompt[node_id]["class_type"]
+
+         # First, send back the status to the frontend depending
+         # on the exception type
+         if isinstance(ex, comfy.model_management.InterruptProcessingException):
+             mes = {
+                 "prompt_id": prompt_id,
+                 "node_id": node_id,
+                 "node_type": class_type,
+                 "executed": list(executed),
+             }
+             self.add_message("execution_interrupted", mes, broadcast=True)
+         else:
+             mes = {
+                 "prompt_id": prompt_id,
+                 "node_id": node_id,
+                 "node_type": class_type,
+                 "executed": list(executed),
+                 "exception_message": error["exception_message"],
+                 "exception_type": error["exception_type"],
+                 "traceback": error["traceback"],
+                 "current_inputs": error["current_inputs"],
+                 "current_outputs": list(current_outputs),
+             }
+             self.add_message("execution_error", mes, broadcast=False)
+
+     def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
+         nodes.interrupt_processing(False)
+
+         if "client_id" in extra_data:
+             self.server.client_id = extra_data["client_id"]
+         else:
+             self.server.client_id = None
+
+         self.status_messages = []
+         self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
+
+         with torch.inference_mode():
+             dynamic_prompt = DynamicPrompt(prompt)
+             is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs)
+             for cache in self.caches.all:
+                 cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
+                 cache.clean_unused()
+
+             cached_nodes = []
+             for node_id in prompt:
+                 if self.caches.outputs.get(node_id) is not None:
+                     cached_nodes.append(node_id)
+
+             comfy.model_management.cleanup_models_gc()
+             self.add_message("execution_cached",
+                              { "nodes": cached_nodes, "prompt_id": prompt_id},
+                              broadcast=False)
+             pending_subgraph_results = {}
+             executed = set()
+             execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
+             current_outputs = self.caches.outputs.all_node_ids()
+             for node_id in list(execute_outputs):
+                 execution_list.add_node(node_id)
+
+             while not execution_list.is_empty():
+                 node_id, error, ex = execution_list.stage_node_execution()
+                 if error is not None:
+                     self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
+                     break
+
+                 result, error, ex = execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results)
+                 self.success = result != ExecutionResult.FAILURE
+                 if result == ExecutionResult.FAILURE:
+                     self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
+                     break
+                 elif result == ExecutionResult.PENDING:
+                     execution_list.unstage_node_execution()
+                 else: # result == ExecutionResult.SUCCESS:
+                     execution_list.complete_node_execution()
+             else:
+                 # Only execute when the while-loop ends without break
+                 self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
+
+             ui_outputs = {}
+             meta_outputs = {}
+             all_node_ids = self.caches.ui.all_node_ids()
+             for node_id in all_node_ids:
+                 ui_info = self.caches.ui.get(node_id)
+                 if ui_info is not None:
+                     ui_outputs[node_id] = ui_info["output"]
+                     meta_outputs[node_id] = ui_info["meta"]
+             self.history_result = {
+                 "outputs": ui_outputs,
+                 "meta": meta_outputs,
+             }
+             self.server.last_node_id = None
+             if comfy.model_management.DISABLE_SMART_MEMORY:
+                 comfy.model_management.unload_all_models()
+
+
+ def validate_inputs(prompt, item, validated):
+     unique_id = item
+     if unique_id in validated:
+         return validated[unique_id]
+
+     inputs = prompt[unique_id]['inputs']
+     class_type = prompt[unique_id]['class_type']
+     obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
+
+     class_inputs = obj_class.INPUT_TYPES()
+     valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{})))
+
+     errors = []
+     valid = True
+
+     validate_function_inputs = []
+     validate_has_kwargs = False
+     if hasattr(obj_class, "VALIDATE_INPUTS"):
+         argspec = inspect.getfullargspec(obj_class.VALIDATE_INPUTS)
+         validate_function_inputs = argspec.args
+         validate_has_kwargs = argspec.varkw is not None
+     received_types = {}
+
+     for x in valid_inputs:
+         type_input, input_category, extra_info = get_input_info(obj_class, x, class_inputs)
+         assert extra_info is not None
+         if x not in inputs:
+             if input_category == "required":
+                 error = {
+                     "type": "required_input_missing",
+                     "message": "Required input is missing",
+                     "details": f"{x}",
+                     "extra_info": {
+                         "input_name": x
+                     }
+                 }
+                 errors.append(error)
+             continue
+
+         val = inputs[x]
+         info = (type_input, extra_info)
+         if isinstance(val, list):
+             if len(val) != 2:
+                 error = {
+                     "type": "bad_linked_input",
+                     "message": "Bad linked input, must be a length-2 list of [node_id, slot_index]",
+                     "details": f"{x}",
+                     "extra_info": {
+                         "input_name": x,
+                         "input_config": info,
+                         "received_value": val
+                     }
+                 }
+                 errors.append(error)
+                 continue
+
+             o_id = val[0]
+             o_class_type = prompt[o_id]['class_type']
+             r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
+             received_type = r[val[1]]
+             received_types[x] = received_type
+             if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, type_input):
+                 details = f"{x}, received_type({received_type}) mismatch input_type({type_input})"
+                 error = {
+                     "type": "return_type_mismatch",
+                     "message": "Return type mismatch between linked nodes",
+                     "details": details,
+                     "extra_info": {
+                         "input_name": x,
+                         "input_config": info,
+                         "received_type": received_type,
+                         "linked_node": val
+                     }
+                 }
+                 errors.append(error)
+                 continue
+             try:
+                 r = validate_inputs(prompt, o_id, validated)
+                 if r[0] is False:
+                     # `r` will be set in `validated[o_id]` already
+                     valid = False
+                     continue
+             except Exception as ex:
+                 typ, _, tb = sys.exc_info()
+                 valid = False
+                 exception_type = full_type_name(typ)
+                 reasons = [{
+                     "type": "exception_during_inner_validation",
+                     "message": "Exception when validating inner node",
+                     "details": str(ex),
+                     "extra_info": {
+                         "input_name": x,
+                         "input_config": info,
+                         "exception_message": str(ex),
+                         "exception_type": exception_type,
+                         "traceback": traceback.format_tb(tb),
+                         "linked_node": val
+                     }
+                 }]
+                 validated[o_id] = (False, reasons, o_id)
+                 continue
+         else:
+             try:
+                 if type_input == "INT":
+                     val = int(val)
+                     inputs[x] = val
+                 if type_input == "FLOAT":
+                     val = float(val)
+                     inputs[x] = val
+                 if type_input == "STRING":
+                     val = str(val)
+                     inputs[x] = val
+                 if type_input == "BOOLEAN":
+                     val = bool(val)
+                     inputs[x] = val
+             except Exception as ex:
+                 error = {
+                     "type": "invalid_input_type",
+                     "message": f"Failed to convert an input value to a {type_input} value",
+                     "details": f"{x}, {val}, {ex}",
+                     "extra_info": {
+                         "input_name": x,
+                         "input_config": info,
+                         "received_value": val,
+                         "exception_message": str(ex)
+                     }
+                 }
+                 errors.append(error)
+                 continue
+
+             if x not in validate_function_inputs and not validate_has_kwargs:
+                 if "min" in extra_info and val < extra_info["min"]:
+                     error = {
+                         "type": "value_smaller_than_min",
+                         "message": "Value {} smaller than min of {}".format(val, extra_info["min"]),
+                         "details": f"{x}",
+                         "extra_info": {
+                             "input_name": x,
+                             "input_config": info,
+                             "received_value": val,
+                         }
+                     }
+                     errors.append(error)
+                     continue
+                 if "max" in extra_info and val > extra_info["max"]:
+                     error = {
+                         "type": "value_bigger_than_max",
+                         "message": "Value {} bigger than max of {}".format(val, extra_info["max"]),
+                         "details": f"{x}",
+                         "extra_info": {
+                             "input_name": x,
+                             "input_config": info,
+                             "received_value": val,
+                         }
+                     }
+                     errors.append(error)
+                     continue
+
+                 if isinstance(type_input, list):
+                     if val not in type_input:
+                         input_config = info
+                         list_info = ""
+
+                         # Don't send back gigantic lists like if they're lots of
+                         # scanned model filepaths
+                         if len(type_input) > 20:
+                             list_info = f"(list of length {len(type_input)})"
+                             input_config = None
+                         else:
+                             list_info = str(type_input)
+
+                         error = {
+                             "type": "value_not_in_list",
+                             "message": "Value not in list",
+                             "details": f"{x}: '{val}' not in {list_info}",
+                             "extra_info": {
+                                 "input_name": x,
+                                 "input_config": input_config,
+                                 "received_value": val,
+                             }
+                         }
+                         errors.append(error)
+                         continue
+
+     if len(validate_function_inputs) > 0 or validate_has_kwargs:
+         input_data_all, _ = get_input_data(inputs, obj_class, unique_id)
+         input_filtered = {}
+         for x in input_data_all:
+             if x in validate_function_inputs or validate_has_kwargs:
+                 input_filtered[x] = input_data_all[x]
+         if 'input_types' in validate_function_inputs:
+             input_filtered['input_types'] = [received_types]
+
+         #ret = obj_class.VALIDATE_INPUTS(**input_filtered)
+         ret = _map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")
+         for x in input_filtered:
+             for i, r in enumerate(ret):
+                 if r is not True and not isinstance(r, ExecutionBlocker):
+                     details = f"{x}"
+                     if r is not False:
+                         details += f" - {str(r)}"
+
+                     error = {
+                         "type": "custom_validation_failed",
+                         "message": "Custom validation failed for node",
+                         "details": details,
+                         "extra_info": {
+                             "input_name": x,
+                         }
+                     }
+                     errors.append(error)
+                     continue
+
+     if len(errors) > 0 or valid is not True:
+         ret = (False, errors, unique_id)
+     else:
+         ret = (True, [], unique_id)
+
+     validated[unique_id] = ret
+     return ret
+
+ def full_type_name(klass):
+     module = klass.__module__
+     if module == 'builtins':
+         return klass.__qualname__
+     return module + '.' + klass.__qualname__
+
+ def validate_prompt(prompt):
+     outputs = set()
+     for x in prompt:
+         if 'class_type' not in prompt[x]:
+             error = {
+                 "type": "invalid_prompt",
+                 "message": "Cannot execute because a node is missing the class_type property.",
+                 "details": f"Node ID '#{x}'",
+                 "extra_info": {}
+             }
+             return (False, error, [], [])
+
+         class_type = prompt[x]['class_type']
+         class_ = nodes.NODE_CLASS_MAPPINGS.get(class_type, None)
+         if class_ is None:
+             error = {
+                 "type": "invalid_prompt",
+                 "message": f"Cannot execute because node {class_type} does not exist.",
+                 "details": f"Node ID '#{x}'",
+                 "extra_info": {}
+             }
+             return (False, error, [], [])
+
+         if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
+             outputs.add(x)
+
+     if len(outputs) == 0:
+         error = {
+             "type": "prompt_no_outputs",
+             "message": "Prompt has no outputs",
+             "details": "",
+             "extra_info": {}
+         }
+         return (False, error, [], [])
+
+     good_outputs = set()
+     errors = []
+     node_errors = {}
+     validated = {}
+     for o in outputs:
+         valid = False
+         reasons = []
+         try:
+             m = validate_inputs(prompt, o, validated)
+             valid = m[0]
+             reasons = m[1]
+         except Exception as ex:
+             typ, _, tb = sys.exc_info()
+             valid = False
+             exception_type = full_type_name(typ)
+             reasons = [{
+                 "type": "exception_during_validation",
+                 "message": "Exception when validating node",
+                 "details": str(ex),
+                 "extra_info": {
+                     "exception_type": exception_type,
+                     "traceback": traceback.format_tb(tb)
+                 }
+             }]
+             validated[o] = (False, reasons, o)
+
+         if valid is True:
+             good_outputs.add(o)
+         else:
+             logging.error(f"Failed to validate prompt for output {o}:")
+             if len(reasons) > 0:
+                 logging.error("* (prompt):")
+                 for reason in reasons:
+                     logging.error(f" - {reason['message']}: {reason['details']}")
+             errors += [(o, reasons)]
+             for node_id, result in validated.items():
+                 valid = result[0]
+                 reasons = result[1]
+                 # If a node upstream has errors, the nodes downstream will also
+                 # be reported as invalid, but there will be no errors attached.
+                 # So don't return those nodes as having errors in the response.
+                 if valid is not True and len(reasons) > 0:
+                     if node_id not in node_errors:
+                         class_type = prompt[node_id]['class_type']
+                         node_errors[node_id] = {
+                             "errors": reasons,
+                             "dependent_outputs": [],
+                             "class_type": class_type
+                         }
+                         logging.error(f"* {class_type} {node_id}:")
+                         for reason in reasons:
+                             logging.error(f" - {reason['message']}: {reason['details']}")
+                     node_errors[node_id]["dependent_outputs"].append(o)
+             logging.error("Output will be ignored")
+
+     if len(good_outputs) == 0:
+         errors_list = []
+         for o, errors in errors:
+             for error in errors:
+                 errors_list.append(f"{error['message']}: {error['details']}")
+         errors_list = "\n".join(errors_list)
+
+         error = {
+             "type": "prompt_outputs_failed_validation",
+             "message": "Prompt outputs failed validation",
+             "details": errors_list,
+             "extra_info": {}
+         }
+
+         return (False, error, list(good_outputs), node_errors)
+
+     return (True, None, list(good_outputs), node_errors)
+
+ MAXIMUM_HISTORY_SIZE = 10000
+
+ class PromptQueue:
+     def __init__(self, server):
+         self.server = server
+         self.mutex = threading.RLock()
+         self.not_empty = threading.Condition(self.mutex)
+         self.task_counter = 0
+         self.queue = []
+         self.currently_running = {}
+         self.history = {}
+         self.flags = {}
+         server.prompt_queue = self
+
+     def put(self, item):
+         with self.mutex:
+             heapq.heappush(self.queue, item)
+             self.server.queue_updated()
+             self.not_empty.notify()
+
+     def get(self, timeout=None):
+         with self.not_empty:
+             while len(self.queue) == 0:
+                 self.not_empty.wait(timeout=timeout)
+                 if timeout is not None and len(self.queue) == 0:
+                     return None
+             item = heapq.heappop(self.queue)
+             i = self.task_counter
+             self.currently_running[i] = copy.deepcopy(item)
+             self.task_counter += 1
+             self.server.queue_updated()
+             return (item, i)
+
+     class ExecutionStatus(NamedTuple):
+         status_str: Literal['success', 'error']
+         completed: bool
+         messages: List[str]
+
+     def task_done(self, item_id, history_result,
+                   status: Optional['PromptQueue.ExecutionStatus']):
+         with self.mutex:
+             prompt = self.currently_running.pop(item_id)
+             if len(self.history) > MAXIMUM_HISTORY_SIZE:
+                 self.history.pop(next(iter(self.history)))
+
+             status_dict: Optional[dict] = None
+             if status is not None:
+                 status_dict = copy.deepcopy(status._asdict())
+
+             self.history[prompt[1]] = {
+                 "prompt": prompt,
+                 "outputs": {},
+                 'status': status_dict,
+             }
+             self.history[prompt[1]].update(history_result)
+             self.server.queue_updated()
+
+     def get_current_queue(self):
+         with self.mutex:
+             out = []
+             for x in self.currently_running.values():
+                 out += [x]
+             return (out, copy.deepcopy(self.queue))
+
+     def get_tasks_remaining(self):
+         with self.mutex:
+             return len(self.queue) + len(self.currently_running)
+
+     def wipe_queue(self):
+         with self.mutex:
+             self.queue = []
+             self.server.queue_updated()
+
+     def delete_queue_item(self, function):
+         with self.mutex:
+             for x in range(len(self.queue)):
+                 if function(self.queue[x]):
+                     if len(self.queue) == 1:
+                         self.wipe_queue()
+                     else:
+                         self.queue.pop(x)
+                         heapq.heapify(self.queue)
+                     self.server.queue_updated()
+                     return True
+         return False
+
+     def get_history(self, prompt_id=None, max_items=None, offset=-1):
+         with self.mutex:
+             if prompt_id is None:
+                 out = {}
+                 i = 0
+                 if offset < 0 and max_items is not None:
+                     offset = len(self.history) - max_items
+                 for k in self.history:
+                     if i >= offset:
+                         out[k] = self.history[k]
+                         if max_items is not None and len(out) >= max_items:
+                             break
+                     i += 1
+                 return out
+             elif prompt_id in self.history:
+                 return {prompt_id: copy.deepcopy(self.history[prompt_id])}
+             else:
+                 return {}
+
+     def wipe_history(self):
+         with self.mutex:
+             self.history = {}
+
+     def delete_history_item(self, id_to_delete):
+         with self.mutex:
+             self.history.pop(id_to_delete, None)
+
+     def set_flag(self, name, data):
+         with self.mutex:
+             self.flags[name] = data
+             self.not_empty.notify()
+
+     def get_flags(self, reset=True):
+         with self.mutex:
+             if reset:
+                 ret = self.flags
+                 self.flags = {}
+                 return ret
+             else:
+                 return self.flags.copy()
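
A minimal sketch of how the pieces above fit together when driven by hand (an annotation, not part of the commit; `_StubServer` and the two-node prompt are hypothetical, and this assumes a working ComfyUI checkout where `execution` and its dependencies import cleanly):

# sketch: validate a prompt dict, then run it through PromptExecutor
import execution

class _StubServer:  # hypothetical stand-in for server.PromptServer
    client_id = None       # no websocket client attached
    last_node_id = None
    last_prompt_id = None
    def send_sync(self, event, data, sid=None):
        pass               # drop status events instead of broadcasting them

# A prompt is a dict of node_id -> {"class_type", "inputs"};
# linked inputs are [source_node_id, output_slot] pairs.
prompt = {
    "1": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": "model.safetensors"}},
    "2": {"class_type": "CLIPTextEncode", "inputs": {"text": "a photo", "clip": ["1", 1]}},
    # ... more nodes, ending in at least one OUTPUT_NODE (e.g. SaveImage),
    # since validate_prompt() rejects prompts with no outputs.
}

valid, error, outputs_to_run, node_errors = execution.validate_prompt(prompt)
if not valid:
    print("rejected:", error["type"], "-", error["message"])
else:
    executor = execution.PromptExecutor(_StubServer(), lru_size=None)
    executor.execute(prompt, "prompt-0", extra_data={}, execute_outputs=outputs_to_run)
    print("success:", executor.success)
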
main.py ADDED
@@ -0,0 +1,304 @@
+ import comfy.options
+ comfy.options.enable_args_parsing()
+
+ import os
+ import importlib.util
+ import folder_paths
+ import time
+ from comfy.cli_args import args
+ from app.logger import setup_logger
+ import itertools
+ import utils.extra_config
+ import logging
+
+ if __name__ == "__main__":
+     #NOTE: These do not do anything on core ComfyUI which should already have no communication with the internet, they are for custom nodes.
+     os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
+     os.environ['DO_NOT_TRACK'] = '1'
+
+
+ setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
+
+ def apply_custom_paths():
+     # extra model paths
+     extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
+     if os.path.isfile(extra_model_paths_config_path):
+         utils.extra_config.load_extra_path_config(extra_model_paths_config_path)
+
+     if args.extra_model_paths_config:
+         for config_path in itertools.chain(*args.extra_model_paths_config):
+             utils.extra_config.load_extra_path_config(config_path)
+
+     # --output-directory, --input-directory, --user-directory
+     if args.output_directory:
+         output_dir = os.path.abspath(args.output_directory)
+         logging.info(f"Setting output directory to: {output_dir}")
+         folder_paths.set_output_directory(output_dir)
+
+         # These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes
+         folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints"))
+         folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
+         folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
+         folder_paths.add_model_folder_path("diffusion_models",
+                                            os.path.join(folder_paths.get_output_directory(), "diffusion_models"))
+         folder_paths.add_model_folder_path("loras", os.path.join(folder_paths.get_output_directory(), "loras"))
+
+     if args.input_directory:
+         input_dir = os.path.abspath(args.input_directory)
+         logging.info(f"Setting input directory to: {input_dir}")
+         folder_paths.set_input_directory(input_dir)
+
+     if args.user_directory:
+         user_dir = os.path.abspath(args.user_directory)
+         logging.info(f"Setting user directory to: {user_dir}")
+         folder_paths.set_user_directory(user_dir)
+
+
+ def execute_prestartup_script():
+     def execute_script(script_path):
+         module_name = os.path.splitext(script_path)[0]
+         try:
+             spec = importlib.util.spec_from_file_location(module_name, script_path)
+             module = importlib.util.module_from_spec(spec)
+             spec.loader.exec_module(module)
+             return True
+         except Exception as e:
+             logging.error(f"Failed to execute startup-script: {script_path} / {e}")
+         return False
+
+     if args.disable_all_custom_nodes:
+         return
+
+     node_paths = folder_paths.get_folder_paths("custom_nodes")
+     for custom_node_path in node_paths:
+         possible_modules = os.listdir(custom_node_path)
+         node_prestartup_times = []
+
+         for possible_module in possible_modules:
+             module_path = os.path.join(custom_node_path, possible_module)
+             if os.path.isfile(module_path) or module_path.endswith(".disabled") or module_path == "__pycache__":
+                 continue
+
+             script_path = os.path.join(module_path, "prestartup_script.py")
+             if os.path.exists(script_path):
+                 time_before = time.perf_counter()
+                 success = execute_script(script_path)
+                 node_prestartup_times.append((time.perf_counter() - time_before, module_path, success))
+     if len(node_prestartup_times) > 0:
+         logging.info("\nPrestartup times for custom nodes:")
+         for n in sorted(node_prestartup_times):
+             if n[2]:
+                 import_message = ""
+             else:
+                 import_message = " (PRESTARTUP FAILED)"
+             logging.info("{:6.1f} seconds{}: {}".format(n[0], import_message, n[1]))
+         logging.info("")
+
+ apply_custom_paths()
+ execute_prestartup_script()
+
+
+ # Main code
+ import asyncio
+ import shutil
+ import threading
+ import gc
+
+
+ if os.name == "nt":
+     logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
+
+ if __name__ == "__main__":
+     if args.cuda_device is not None:
+         os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
+         os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device)
+         logging.info("Set cuda device to: {}".format(args.cuda_device))
+
+     if args.oneapi_device_selector is not None:
+         os.environ['ONEAPI_DEVICE_SELECTOR'] = args.oneapi_device_selector
+         logging.info("Set oneapi device selector to: {}".format(args.oneapi_device_selector))
+
+     if args.deterministic:
+         if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
+             os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
+
+     import cuda_malloc
+
+ if args.windows_standalone_build:
+     try:
+         from fix_torch import fix_pytorch_libomp
+         fix_pytorch_libomp()
+     except:
+         pass
+
+ import comfy.utils
+
+ import execution
+ import server
+ from server import BinaryEventTypes
+ import nodes
+ import comfy.model_management
+ import comfyui_version
+
+
+ def cuda_malloc_warning():
+     device = comfy.model_management.get_torch_device()
+     device_name = comfy.model_management.get_torch_device_name(device)
+     cuda_malloc_warning = False
+     if "cudaMallocAsync" in device_name:
+         for b in cuda_malloc.blacklist:
+             if b in device_name:
+                 cuda_malloc_warning = True
+         if cuda_malloc_warning:
+             logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
+
+
+ def prompt_worker(q, server_instance):
+     current_time: float = 0.0
+     e = execution.PromptExecutor(server_instance, lru_size=args.cache_lru)
+     last_gc_collect = 0
+     need_gc = False
+     gc_collect_interval = 10.0
+
+     while True:
+         timeout = 1000.0
+         if need_gc:
+             timeout = max(gc_collect_interval - (current_time - last_gc_collect), 0.0)
+
+         queue_item = q.get(timeout=timeout)
+         if queue_item is not None:
+             item, item_id = queue_item
+             execution_start_time = time.perf_counter()
+             prompt_id = item[1]
+             server_instance.last_prompt_id = prompt_id
+
+             e.execute(item[2], prompt_id, item[3], item[4])
+             need_gc = True
+             q.task_done(item_id,
+                         e.history_result,
+                         status=execution.PromptQueue.ExecutionStatus(
+                             status_str='success' if e.success else 'error',
+                             completed=e.success,
+                             messages=e.status_messages))
+             if server_instance.client_id is not None:
+                 server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
+
+             current_time = time.perf_counter()
+             execution_time = current_time - execution_start_time
+             logging.info("Prompt executed in {:.2f} seconds".format(execution_time))
+
+         flags = q.get_flags()
+         free_memory = flags.get("free_memory", False)
+
+         if flags.get("unload_models", free_memory):
+             comfy.model_management.unload_all_models()
+             need_gc = True
+             last_gc_collect = 0
+
+         if free_memory:
+             e.reset()
+             need_gc = True
+             last_gc_collect = 0
+
+         if need_gc:
+             current_time = time.perf_counter()
+             if (current_time - last_gc_collect) > gc_collect_interval:
+                 gc.collect()
+                 comfy.model_management.soft_empty_cache()
+                 last_gc_collect = current_time
+                 need_gc = False
+
+
+ async def run(server_instance, address='', port=8188, verbose=True, call_on_start=None):
+     addresses = []
+     for addr in address.split(","):
+         addresses.append((addr, port))
+     await asyncio.gather(
+         server_instance.start_multi_address(addresses, call_on_start, verbose), server_instance.publish_loop()
+     )
+
+
+ def hijack_progress(server_instance):
+     def hook(value, total, preview_image):
+         comfy.model_management.throw_exception_if_processing_interrupted()
+         progress = {"value": value, "max": total, "prompt_id": server_instance.last_prompt_id, "node": server_instance.last_node_id}
+
+         server_instance.send_sync("progress", progress, server_instance.client_id)
+         if preview_image is not None:
+             server_instance.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server_instance.client_id)
+
+     comfy.utils.set_progress_bar_global_hook(hook)
+
+
+ def cleanup_temp():
+     temp_dir = folder_paths.get_temp_directory()
+     if os.path.exists(temp_dir):
+         shutil.rmtree(temp_dir, ignore_errors=True)
+
+
+ def start_comfyui(asyncio_loop=None):
+     """
+     Starts the ComfyUI server using the provided asyncio event loop or creates a new one.
+     Returns the event loop, server instance, and a function to start the server asynchronously.
+     """
+     if args.temp_directory:
+         temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
+         logging.info(f"Setting temp directory to: {temp_dir}")
+         folder_paths.set_temp_directory(temp_dir)
+     cleanup_temp()
+
+     if args.windows_standalone_build:
+         try:
+             import new_updater
+             new_updater.update_windows_updater()
+         except:
+             pass
+
+     if not asyncio_loop:
+         asyncio_loop = asyncio.new_event_loop()
+         asyncio.set_event_loop(asyncio_loop)
+     prompt_server = server.PromptServer(asyncio_loop)
+     q = execution.PromptQueue(prompt_server)
+
+     nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes)
+
+     cuda_malloc_warning()
+
+     prompt_server.add_routes()
+     hijack_progress(prompt_server)
+
+     threading.Thread(target=prompt_worker, daemon=True, args=(q, prompt_server,)).start()
+
+     if args.quick_test_for_ci:
+         exit(0)
+
+     os.makedirs(folder_paths.get_temp_directory(), exist_ok=True)
+     call_on_start = None
+     if args.auto_launch:
+         def startup_server(scheme, address, port):
+             import webbrowser
+             if os.name == 'nt' and address == '0.0.0.0':
+                 address = '127.0.0.1'
+             if ':' in address:
+                 address = "[{}]".format(address)
+             webbrowser.open(f"{scheme}://{address}:{port}")
+         call_on_start = startup_server
+
+     async def start_all():
+         await prompt_server.setup()
+         await run(prompt_server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)
+
+     # Returning these so that other code can integrate with the ComfyUI loop and server
+     return asyncio_loop, prompt_server, start_all
+
+
+ if __name__ == "__main__":
+     # Running directly, just start ComfyUI.
+     logging.info("ComfyUI version: {}".format(comfyui_version.__version__))
+     event_loop, _, start_all_func = start_comfyui()
+     try:
+         event_loop.run_until_complete(start_all_func())
+     except KeyboardInterrupt:
+         logging.info("\nStopped server")
+
+     cleanup_temp()
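
Since start_comfyui() returns the event loop, the PromptServer instance, and an async start function, the server can also be embedded instead of launched as a script. A minimal sketch under stated assumptions: ComfyUI's dependencies are installed and this file is importable as `main`; note that importing it runs apply_custom_paths() and execute_prestartup_script() at import time, while the `if __name__ == "__main__"` blocks (CUDA env setup, the cuda_malloc import) are skipped:

# sketch: embed ComfyUI in a host application's event loop
import asyncio
import main

loop = asyncio.new_event_loop()  # host-owned loop handed to ComfyUI
event_loop, prompt_server, start_all = main.start_comfyui(asyncio_loop=loop)

# Other coroutines can be scheduled on the same loop before starting;
# start_all() sets up the PromptServer and then serves until cancelled.
try:
    event_loop.run_until_complete(start_all())
except KeyboardInterrupt:
    main.cleanup_temp()
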
nodes.py ADDED
@@ -0,0 +1,2262 @@
1
+ from __future__ import annotations
2
+ import torch
3
+
4
+ import os
5
+ import sys
6
+ import json
7
+ import hashlib
8
+ import traceback
9
+ import math
10
+ import time
11
+ import random
12
+ import logging
13
+
14
+ from PIL import Image, ImageOps, ImageSequence
15
+ from PIL.PngImagePlugin import PngInfo
16
+
17
+ import numpy as np
18
+ import safetensors.torch
19
+
20
+ sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
21
+
22
+ import comfy.diffusers_load
23
+ import comfy.samplers
24
+ import comfy.sample
25
+ import comfy.sd
26
+ import comfy.utils
27
+ import comfy.controlnet
28
+ from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
29
+
30
+ import comfy.clip_vision
31
+
32
+ import comfy.model_management
33
+ from comfy.cli_args import args
34
+
35
+ import importlib
36
+
37
+ import folder_paths
38
+ import latent_preview
39
+ import node_helpers
40
+
41
+ def before_node_execution():
42
+ comfy.model_management.throw_exception_if_processing_interrupted()
43
+
44
+ def interrupt_processing(value=True):
45
+ comfy.model_management.interrupt_current_processing(value)
46
+
47
+ MAX_RESOLUTION=16384
48
+
49
+ class CLIPTextEncode(ComfyNodeABC):
50
+ @classmethod
51
+ def INPUT_TYPES(s) -> InputTypeDict:
52
+ return {
53
+ "required": {
54
+ "text": (IO.STRING, {"multiline": True, "dynamicPrompts": True, "tooltip": "The text to be encoded."}),
55
+ "clip": (IO.CLIP, {"tooltip": "The CLIP model used for encoding the text."})
56
+ }
57
+ }
58
+ RETURN_TYPES = (IO.CONDITIONING,)
59
+ OUTPUT_TOOLTIPS = ("A conditioning containing the embedded text used to guide the diffusion model.",)
60
+ FUNCTION = "encode"
61
+
62
+ CATEGORY = "conditioning"
63
+ DESCRIPTION = "Encodes a text prompt using a CLIP model into an embedding that can be used to guide the diffusion model towards generating specific images."
64
+
65
+ def encode(self, clip, text):
66
+ if clip is None:
67
+ raise RuntimeError("ERROR: clip input is invalid: None\n\nIf the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model.")
68
+ tokens = clip.tokenize(text)
69
+ return (clip.encode_from_tokens_scheduled(tokens), )
70
+
71
+
72
+ class ConditioningCombine:
73
+ @classmethod
74
+ def INPUT_TYPES(s):
75
+ return {"required": {"conditioning_1": ("CONDITIONING", ), "conditioning_2": ("CONDITIONING", )}}
76
+ RETURN_TYPES = ("CONDITIONING",)
77
+ FUNCTION = "combine"
78
+
79
+ CATEGORY = "conditioning"
80
+
81
+ def combine(self, conditioning_1, conditioning_2):
82
+ return (conditioning_1 + conditioning_2, )
83
+
84
+ class ConditioningAverage :
85
+ @classmethod
86
+ def INPUT_TYPES(s):
87
+ return {"required": {"conditioning_to": ("CONDITIONING", ), "conditioning_from": ("CONDITIONING", ),
88
+ "conditioning_to_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
89
+ }}
90
+ RETURN_TYPES = ("CONDITIONING",)
91
+ FUNCTION = "addWeighted"
92
+
93
+ CATEGORY = "conditioning"
94
+
95
+ def addWeighted(self, conditioning_to, conditioning_from, conditioning_to_strength):
96
+ out = []
97
+
98
+ if len(conditioning_from) > 1:
99
+ logging.warning("Warning: ConditioningAverage conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")
100
+
101
+ cond_from = conditioning_from[0][0]
102
+ pooled_output_from = conditioning_from[0][1].get("pooled_output", None)
103
+
104
+ for i in range(len(conditioning_to)):
105
+ t1 = conditioning_to[i][0]
106
+ pooled_output_to = conditioning_to[i][1].get("pooled_output", pooled_output_from)
107
+ t0 = cond_from[:,:t1.shape[1]]
108
+ if t0.shape[1] < t1.shape[1]:
109
+ t0 = torch.cat([t0] + [torch.zeros((1, (t1.shape[1] - t0.shape[1]), t1.shape[2]))], dim=1)
110
+
111
+ tw = torch.mul(t1, conditioning_to_strength) + torch.mul(t0, (1.0 - conditioning_to_strength))
112
+ t_to = conditioning_to[i][1].copy()
113
+ if pooled_output_from is not None and pooled_output_to is not None:
114
+ t_to["pooled_output"] = torch.mul(pooled_output_to, conditioning_to_strength) + torch.mul(pooled_output_from, (1.0 - conditioning_to_strength))
115
+ elif pooled_output_from is not None:
116
+ t_to["pooled_output"] = pooled_output_from
117
+
118
+ n = [tw, t_to]
119
+ out.append(n)
120
+ return (out, )
121
+
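+ # Editor's sketch (illustrative, not part of the original file): addWeighted is
+ # a linear interpolation, tw = s*t1 + (1 - s)*t0, taken after t0 is truncated or
+ # zero-padded to t1's token length. Tiny worked example with strength s = 0.75:
+ _t1 = torch.ones(1, 4, 8)       # 4 tokens from conditioning_to
+ _t0 = torch.zeros(1, 4, 8)      # conditioning_from, already padded to 4 tokens
+ _tw = torch.mul(_t1, 0.75) + torch.mul(_t0, 1.0 - 0.75)
+ assert torch.allclose(_tw, torch.full((1, 4, 8), 0.75))
+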
122
+ class ConditioningConcat:
123
+ @classmethod
124
+ def INPUT_TYPES(s):
125
+ return {"required": {
126
+ "conditioning_to": ("CONDITIONING",),
127
+ "conditioning_from": ("CONDITIONING",),
128
+ }}
129
+ RETURN_TYPES = ("CONDITIONING",)
130
+ FUNCTION = "concat"
131
+
132
+ CATEGORY = "conditioning"
133
+
134
+ def concat(self, conditioning_to, conditioning_from):
135
+ out = []
136
+
137
+ if len(conditioning_from) > 1:
138
+ logging.warning("Warning: ConditioningConcat conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")
139
+
140
+ cond_from = conditioning_from[0][0]
141
+
142
+ for i in range(len(conditioning_to)):
143
+ t1 = conditioning_to[i][0]
144
+ tw = torch.cat((t1, cond_from),1)
145
+ n = [tw, conditioning_to[i][1].copy()]
146
+ out.append(n)
147
+
148
+ return (out, )
149
+
150
+ class ConditioningSetArea:
151
+ @classmethod
152
+ def INPUT_TYPES(s):
153
+ return {"required": {"conditioning": ("CONDITIONING", ),
154
+ "width": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
155
+ "height": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
156
+ "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
157
+ "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
158
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
159
+ }}
160
+ RETURN_TYPES = ("CONDITIONING",)
161
+ FUNCTION = "append"
162
+
163
+ CATEGORY = "conditioning"
164
+
165
+ def append(self, conditioning, width, height, x, y, strength):
166
+ c = node_helpers.conditioning_set_values(conditioning, {"area": (height // 8, width // 8, y // 8, x // 8),
167
+ "strength": strength,
168
+ "set_area_to_bounds": False})
169
+ return (c, )
170
+
171
+ class ConditioningSetAreaPercentage:
172
+ @classmethod
173
+ def INPUT_TYPES(s):
174
+ return {"required": {"conditioning": ("CONDITIONING", ),
175
+ "width": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
176
+ "height": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
177
+ "x": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
178
+ "y": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
179
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
180
+ }}
181
+ RETURN_TYPES = ("CONDITIONING",)
182
+ FUNCTION = "append"
183
+
184
+ CATEGORY = "conditioning"
185
+
186
+ def append(self, conditioning, width, height, x, y, strength):
187
+ c = node_helpers.conditioning_set_values(conditioning, {"area": ("percentage", height, width, y, x),
188
+ "strength": strength,
189
+ "set_area_to_bounds": False})
190
+ return (c, )
191
+
192
+ class ConditioningSetAreaStrength:
193
+ @classmethod
194
+ def INPUT_TYPES(s):
195
+ return {"required": {"conditioning": ("CONDITIONING", ),
196
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
197
+ }}
198
+ RETURN_TYPES = ("CONDITIONING",)
199
+ FUNCTION = "append"
200
+
201
+ CATEGORY = "conditioning"
202
+
203
+ def append(self, conditioning, strength):
204
+ c = node_helpers.conditioning_set_values(conditioning, {"strength": strength})
205
+ return (c, )
206
+
207
+
208
+ class ConditioningSetMask:
209
+ @classmethod
210
+ def INPUT_TYPES(s):
211
+ return {"required": {"conditioning": ("CONDITIONING", ),
212
+ "mask": ("MASK", ),
213
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
214
+ "set_cond_area": (["default", "mask bounds"],),
215
+ }}
216
+ RETURN_TYPES = ("CONDITIONING",)
217
+ FUNCTION = "append"
218
+
219
+ CATEGORY = "conditioning"
220
+
221
+ def append(self, conditioning, mask, set_cond_area, strength):
222
+ set_area_to_bounds = False
223
+ if set_cond_area != "default":
224
+ set_area_to_bounds = True
225
+ if len(mask.shape) < 3:
226
+ mask = mask.unsqueeze(0)
227
+
228
+ c = node_helpers.conditioning_set_values(conditioning, {"mask": mask,
229
+ "set_area_to_bounds": set_area_to_bounds,
230
+ "mask_strength": strength})
231
+ return (c, )
232
+
233
+ class ConditioningZeroOut:
234
+ @classmethod
235
+ def INPUT_TYPES(s):
236
+ return {"required": {"conditioning": ("CONDITIONING", )}}
237
+ RETURN_TYPES = ("CONDITIONING",)
238
+ FUNCTION = "zero_out"
239
+
240
+ CATEGORY = "advanced/conditioning"
241
+
242
+ def zero_out(self, conditioning):
243
+ c = []
244
+ for t in conditioning:
245
+ d = t[1].copy()
246
+ pooled_output = d.get("pooled_output", None)
247
+ if pooled_output is not None:
248
+ d["pooled_output"] = torch.zeros_like(pooled_output)
249
+ n = [torch.zeros_like(t[0]), d]
250
+ c.append(n)
251
+ return (c, )
252
+
253
+ class ConditioningSetTimestepRange:
254
+ @classmethod
255
+ def INPUT_TYPES(s):
256
+ return {"required": {"conditioning": ("CONDITIONING", ),
257
+ "start": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
258
+ "end": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
259
+ }}
260
+ RETURN_TYPES = ("CONDITIONING",)
261
+ FUNCTION = "set_range"
262
+
263
+ CATEGORY = "advanced/conditioning"
264
+
265
+ def set_range(self, conditioning, start, end):
266
+ c = node_helpers.conditioning_set_values(conditioning, {"start_percent": start,
267
+ "end_percent": end})
268
+ return (c, )
269
+
270
+ class VAEDecode:
271
+ @classmethod
272
+ def INPUT_TYPES(s):
273
+ return {
274
+ "required": {
275
+ "samples": ("LATENT", {"tooltip": "The latent to be decoded."}),
276
+ "vae": ("VAE", {"tooltip": "The VAE model used for decoding the latent."})
277
+ }
278
+ }
279
+ RETURN_TYPES = ("IMAGE",)
280
+ OUTPUT_TOOLTIPS = ("The decoded image.",)
281
+ FUNCTION = "decode"
282
+
283
+ CATEGORY = "latent"
284
+ DESCRIPTION = "Decodes latent images back into pixel space images."
285
+
286
+ def decode(self, vae, samples):
287
+ images = vae.decode(samples["samples"])
288
+ if len(images.shape) == 5: #Combine batches
289
+ images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
290
+ return (images, )
291
+
292
+ class VAEDecodeTiled:
293
+ @classmethod
294
+ def INPUT_TYPES(s):
295
+ return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ),
296
+ "tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 32}),
297
+ "overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}),
298
+ "temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to decode at a time."}),
299
+ "temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap."}),
300
+ }}
301
+ RETURN_TYPES = ("IMAGE",)
302
+ FUNCTION = "decode"
303
+
304
+ CATEGORY = "_for_testing"
305
+
306
+ def decode(self, vae, samples, tile_size, overlap=64, temporal_size=64, temporal_overlap=8):
307
+ if tile_size < overlap * 4:
308
+ overlap = tile_size // 4
309
+ if temporal_size < temporal_overlap * 2:
310
+ temporal_overlap = temporal_size // 2
311
+ temporal_compression = vae.temporal_compression_decode()
312
+ if temporal_compression is not None:
313
+ temporal_size = max(2, temporal_size // temporal_compression)
314
+ temporal_overlap = max(1, min(temporal_size // 2, temporal_overlap // temporal_compression))
315
+ else:
316
+ temporal_size = None
317
+ temporal_overlap = None
318
+
319
+ compression = vae.spacial_compression_decode()
320
+ images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression, tile_t=temporal_size, overlap_t=temporal_overlap)
321
+ if len(images.shape) == 5: #Combine batches
322
+ images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
323
+ return (images, )
324
+
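+ # Editor's sketch (illustrative): tile_size and overlap are specified in pixels
+ # and divided by the VAE's spatial compression to get latent-space tile sizes.
+ # Assuming the usual SD compression factor of 8, the defaults map to:
+ assert 512 // 8 == 64    # 512 px tiles -> 64x64 latent tiles
+ assert 64 // 8 == 8      # 64 px overlap -> 8 latents of overlap between tiles
+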
325
+ class VAEEncode:
326
+ @classmethod
327
+ def INPUT_TYPES(s):
328
+ return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}}
329
+ RETURN_TYPES = ("LATENT",)
330
+ FUNCTION = "encode"
331
+
332
+ CATEGORY = "latent"
333
+
334
+ def encode(self, vae, pixels):
335
+ t = vae.encode(pixels[:,:,:,:3])
336
+ return ({"samples":t}, )
337
+
338
+ class VAEEncodeTiled:
339
+ @classmethod
340
+ def INPUT_TYPES(s):
341
+ return {"required": {"pixels": ("IMAGE", ), "vae": ("VAE", ),
342
+ "tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64}),
343
+ "overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}),
344
+ "temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to encode at a time."}),
345
+ "temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap."}),
346
+ }}
347
+ RETURN_TYPES = ("LATENT",)
348
+ FUNCTION = "encode"
349
+
350
+ CATEGORY = "_for_testing"
351
+
352
+ def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8):
353
+ t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
354
+ return ({"samples": t}, )
355
+
356
+ class VAEEncodeForInpaint:
357
+ @classmethod
358
+ def INPUT_TYPES(s):
359
+ return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", ), "grow_mask_by": ("INT", {"default": 6, "min": 0, "max": 64, "step": 1}),}}
360
+ RETURN_TYPES = ("LATENT",)
361
+ FUNCTION = "encode"
362
+
363
+ CATEGORY = "latent/inpaint"
364
+
365
+ def encode(self, vae, pixels, mask, grow_mask_by=6):
366
+ x = (pixels.shape[1] // vae.downscale_ratio) * vae.downscale_ratio
367
+ y = (pixels.shape[2] // vae.downscale_ratio) * vae.downscale_ratio
368
+ mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
369
+
370
+ pixels = pixels.clone()
371
+ if pixels.shape[1] != x or pixels.shape[2] != y:
372
+ x_offset = (pixels.shape[1] % vae.downscale_ratio) // 2
373
+ y_offset = (pixels.shape[2] % vae.downscale_ratio) // 2
374
+ pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
375
+ mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset]
376
+
377
+ #grow mask by a few pixels to keep things seamless in latent space
378
+ if grow_mask_by == 0:
379
+ mask_erosion = mask
380
+ else:
381
+ kernel_tensor = torch.ones((1, 1, grow_mask_by, grow_mask_by))
382
+ padding = math.ceil((grow_mask_by - 1) / 2)
383
+
384
+ mask_erosion = torch.clamp(torch.nn.functional.conv2d(mask.round(), kernel_tensor, padding=padding), 0, 1)
385
+
386
+ m = (1.0 - mask.round()).squeeze(1)
387
+ for i in range(3):
388
+ pixels[:,:,:,i] -= 0.5
389
+ pixels[:,:,:,i] *= m
390
+ pixels[:,:,:,i] += 0.5
391
+ t = vae.encode(pixels)
392
+
393
+ return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )
394
+
395
+
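+ # Editor's sketch (illustrative, not part of the original file): the
+ # grow_mask_by step above dilates the binary mask with an all-ones conv kernel
+ # and clamps the result back to [0, 1]. Minimal demo with grow_mask_by = 3:
+ _mask = torch.zeros(1, 1, 5, 5)
+ _mask[0, 0, 2, 2] = 1.0                  # a single masked pixel
+ _kernel = torch.ones(1, 1, 3, 3)
+ _grown = torch.clamp(torch.nn.functional.conv2d(_mask, _kernel, padding=1), 0, 1)
+ assert _grown.sum().item() == 9.0        # the pixel grew into a 3x3 block
+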
396
+ class InpaintModelConditioning:
397
+ @classmethod
398
+ def INPUT_TYPES(s):
399
+ return {"required": {"positive": ("CONDITIONING", ),
400
+ "negative": ("CONDITIONING", ),
401
+ "vae": ("VAE", ),
402
+ "pixels": ("IMAGE", ),
403
+ "mask": ("MASK", ),
404
+ "noise_mask": ("BOOLEAN", {"default": True, "tooltip": "Add a noise mask to the latent so sampling will only happen within the mask. Might improve results or completely break things depending on the model."}),
405
+ }}
406
+
407
+ RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT")
408
+ RETURN_NAMES = ("positive", "negative", "latent")
409
+ FUNCTION = "encode"
410
+
411
+ CATEGORY = "conditioning/inpaint"
412
+
413
+ def encode(self, positive, negative, pixels, vae, mask, noise_mask=True):
414
+ x = (pixels.shape[1] // 8) * 8
415
+ y = (pixels.shape[2] // 8) * 8
416
+ mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
417
+
418
+ orig_pixels = pixels
419
+ pixels = orig_pixels.clone()
420
+ if pixels.shape[1] != x or pixels.shape[2] != y:
421
+ x_offset = (pixels.shape[1] % 8) // 2
422
+ y_offset = (pixels.shape[2] % 8) // 2
423
+ pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
424
+ mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset]
425
+
426
+ m = (1.0 - mask.round()).squeeze(1)
427
+ for i in range(3):
428
+ pixels[:,:,:,i] -= 0.5
429
+ pixels[:,:,:,i] *= m
430
+ pixels[:,:,:,i] += 0.5
431
+ concat_latent = vae.encode(pixels)
432
+ orig_latent = vae.encode(orig_pixels)
433
+
434
+ out_latent = {}
435
+
436
+ out_latent["samples"] = orig_latent
437
+ if noise_mask:
438
+ out_latent["noise_mask"] = mask
439
+
440
+ out = []
441
+ for conditioning in [positive, negative]:
442
+ c = node_helpers.conditioning_set_values(conditioning, {"concat_latent_image": concat_latent,
443
+ "concat_mask": mask})
444
+ out.append(c)
445
+ return (out[0], out[1], out_latent)
446
+
447
+
448
+ class SaveLatent:
449
+ def __init__(self):
450
+ self.output_dir = folder_paths.get_output_directory()
451
+
452
+ @classmethod
453
+ def INPUT_TYPES(s):
454
+ return {"required": { "samples": ("LATENT", ),
455
+ "filename_prefix": ("STRING", {"default": "latents/ComfyUI"})},
456
+ "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
457
+ }
458
+ RETURN_TYPES = ()
459
+ FUNCTION = "save"
460
+
461
+ OUTPUT_NODE = True
462
+
463
+ CATEGORY = "_for_testing"
464
+
465
+ def save(self, samples, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
466
+ full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
467
+
468
+ # save metadata to support latent sharing
469
+ prompt_info = ""
470
+ if prompt is not None:
471
+ prompt_info = json.dumps(prompt)
472
+
473
+ metadata = None
474
+ if not args.disable_metadata:
475
+ metadata = {"prompt": prompt_info}
476
+ if extra_pnginfo is not None:
477
+ for x in extra_pnginfo:
478
+ metadata[x] = json.dumps(extra_pnginfo[x])
479
+
480
+ file = f"{filename}_{counter:05}_.latent"
481
+
482
+ results = list()
483
+ results.append({
484
+ "filename": file,
485
+ "subfolder": subfolder,
486
+ "type": "output"
487
+ })
488
+
489
+ file = os.path.join(full_output_folder, file)
490
+
491
+ output = {}
492
+ output["latent_tensor"] = samples["samples"]
493
+ output["latent_format_version_0"] = torch.tensor([])
494
+
495
+ comfy.utils.save_torch_file(output, file, metadata=metadata)
496
+ return { "ui": { "latents": results } }
497
+
498
+
499
+ class LoadLatent:
500
+ @classmethod
501
+ def INPUT_TYPES(s):
502
+ input_dir = folder_paths.get_input_directory()
503
+ files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")]
504
+ return {"required": {"latent": [sorted(files), ]}, }
505
+
506
+ CATEGORY = "_for_testing"
507
+
508
+ RETURN_TYPES = ("LATENT", )
509
+ FUNCTION = "load"
510
+
511
+ def load(self, latent):
512
+ latent_path = folder_paths.get_annotated_filepath(latent)
513
+ latent = safetensors.torch.load_file(latent_path, device="cpu")
514
+ multiplier = 1.0
515
+ if "latent_format_version_0" not in latent:
516
+ multiplier = 1.0 / 0.18215
517
+ samples = {"samples": latent["latent_tensor"].float() * multiplier}
518
+ return (samples, )
519
+
520
+ @classmethod
521
+ def IS_CHANGED(s, latent):
522
+ image_path = folder_paths.get_annotated_filepath(latent)
523
+ m = hashlib.sha256()
524
+ with open(image_path, 'rb') as f:
525
+ m.update(f.read())
526
+ return m.digest().hex()
527
+
528
+ @classmethod
529
+ def VALIDATE_INPUTS(s, latent):
530
+ if not folder_paths.exists_annotated_filepath(latent):
531
+ return "Invalid latent file: {}".format(latent)
532
+ return True
533
+
534
+
535
+ class CheckpointLoader:
536
+ @classmethod
537
+ def INPUT_TYPES(s):
538
+ return {"required": { "config_name": (folder_paths.get_filename_list("configs"), ),
539
+ "ckpt_name": (folder_paths.get_filename_list("checkpoints"), )}}
540
+ RETURN_TYPES = ("MODEL", "CLIP", "VAE")
541
+ FUNCTION = "load_checkpoint"
542
+
543
+ CATEGORY = "advanced/loaders"
544
+ DEPRECATED = True
545
+
546
+ def load_checkpoint(self, config_name, ckpt_name):
547
+ config_path = folder_paths.get_full_path("configs", config_name)
548
+ ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
549
+ return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
550
+
551
+ class CheckpointLoaderSimple:
552
+ @classmethod
553
+ def INPUT_TYPES(s):
554
+ return {
555
+ "required": {
556
+ "ckpt_name": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}),
557
+ }
558
+ }
559
+ RETURN_TYPES = ("MODEL", "CLIP", "VAE")
560
+ OUTPUT_TOOLTIPS = ("The model used for denoising latents.",
561
+ "The CLIP model used for encoding text prompts.",
562
+ "The VAE model used for encoding and decoding images to and from latent space.")
563
+ FUNCTION = "load_checkpoint"
564
+
565
+ CATEGORY = "loaders"
566
+ DESCRIPTION = "Loads a diffusion model checkpoint, diffusion models are used to denoise latents."
567
+
568
+ def load_checkpoint(self, ckpt_name):
569
+ ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
570
+ out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
571
+ return out[:3]
572
+
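+ # Editor's sketch (illustrative; the checkpoint name is hypothetical): the three
+ # outputs unpack into the MODEL / CLIP / VAE inputs used throughout this file.
+ # model, clip, vae = CheckpointLoaderSimple().load_checkpoint("v1-5-pruned-emaonly.safetensors")
+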
573
+ class DiffusersLoader:
574
+ @classmethod
575
+ def INPUT_TYPES(cls):
576
+ paths = []
577
+ for search_path in folder_paths.get_folder_paths("diffusers"):
578
+ if os.path.exists(search_path):
579
+ for root, subdir, files in os.walk(search_path, followlinks=True):
580
+ if "model_index.json" in files:
581
+ paths.append(os.path.relpath(root, start=search_path))
582
+
583
+ return {"required": {"model_path": (paths,), }}
584
+ RETURN_TYPES = ("MODEL", "CLIP", "VAE")
585
+ FUNCTION = "load_checkpoint"
586
+
587
+ CATEGORY = "advanced/loaders/deprecated"
588
+
589
+ def load_checkpoint(self, model_path, output_vae=True, output_clip=True):
590
+ for search_path in folder_paths.get_folder_paths("diffusers"):
591
+ if os.path.exists(search_path):
592
+ path = os.path.join(search_path, model_path)
593
+ if os.path.exists(path):
594
+ model_path = path
595
+ break
596
+
597
+ return comfy.diffusers_load.load_diffusers(model_path, output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
598
+
599
+
600
+ class unCLIPCheckpointLoader:
601
+ @classmethod
602
+ def INPUT_TYPES(s):
603
+ return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
604
+ }}
605
+ RETURN_TYPES = ("MODEL", "CLIP", "VAE", "CLIP_VISION")
606
+ FUNCTION = "load_checkpoint"
607
+
608
+ CATEGORY = "loaders"
609
+
610
+ def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
611
+ ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
612
+ out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
613
+ return out
614
+
615
+ class CLIPSetLastLayer:
616
+ @classmethod
617
+ def INPUT_TYPES(s):
618
+ return {"required": { "clip": ("CLIP", ),
619
+ "stop_at_clip_layer": ("INT", {"default": -1, "min": -24, "max": -1, "step": 1}),
620
+ }}
621
+ RETURN_TYPES = ("CLIP",)
622
+ FUNCTION = "set_last_layer"
623
+
624
+ CATEGORY = "conditioning"
625
+
626
+ def set_last_layer(self, clip, stop_at_clip_layer):
627
+ clip = clip.clone()
628
+ clip.clip_layer(stop_at_clip_layer)
629
+ return (clip,)
630
+
631
+ class LoraLoader:
632
+ def __init__(self):
633
+ self.loaded_lora = None
634
+
635
+ @classmethod
636
+ def INPUT_TYPES(s):
637
+ return {
638
+ "required": {
639
+ "model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}),
640
+ "clip": ("CLIP", {"tooltip": "The CLIP model the LoRA will be applied to."}),
641
+ "lora_name": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}),
642
+ "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}),
643
+ "strength_clip": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the CLIP model. This value can be negative."}),
644
+ }
645
+ }
646
+
647
+ RETURN_TYPES = ("MODEL", "CLIP")
648
+ OUTPUT_TOOLTIPS = ("The modified diffusion model.", "The modified CLIP model.")
649
+ FUNCTION = "load_lora"
650
+
651
+ CATEGORY = "loaders"
652
+ DESCRIPTION = "LoRAs are used to modify diffusion and CLIP models, altering the way in which latents are denoised such as applying styles. Multiple LoRA nodes can be linked together."
653
+
654
+ def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
655
+ if strength_model == 0 and strength_clip == 0:
656
+ return (model, clip)
657
+
658
+ lora_path = folder_paths.get_full_path_or_raise("loras", lora_name)
659
+ lora = None
660
+ if self.loaded_lora is not None:
661
+ if self.loaded_lora[0] == lora_path:
662
+ lora = self.loaded_lora[1]
663
+ else:
664
+ self.loaded_lora = None
665
+
666
+ if lora is None:
667
+ lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
668
+ self.loaded_lora = (lora_path, lora)
669
+
670
+ model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip)
671
+ return (model_lora, clip_lora)
672
+
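+ # Editor's sketch (illustrative; the LoRA names are hypothetical): because
+ # load_lora returns the same (MODEL, CLIP) pair it accepts, LoraLoader nodes
+ # can be chained, with each application stacking on the previous one:
+ # model, clip = LoraLoader().load_lora(model, clip, "style_a.safetensors", 0.8, 0.8)
+ # model, clip = LoraLoader().load_lora(model, clip, "style_b.safetensors", 0.5, 0.5)
+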
673
+ class LoraLoaderModelOnly(LoraLoader):
674
+ @classmethod
675
+ def INPUT_TYPES(s):
676
+ return {"required": { "model": ("MODEL",),
677
+ "lora_name": (folder_paths.get_filename_list("loras"), ),
678
+ "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}),
679
+ }}
680
+ RETURN_TYPES = ("MODEL",)
681
+ FUNCTION = "load_lora_model_only"
682
+
683
+ def load_lora_model_only(self, model, lora_name, strength_model):
684
+ return (self.load_lora(model, None, lora_name, strength_model, 0)[0],)
685
+
686
+ class VAELoader:
687
+ @staticmethod
688
+ def vae_list():
689
+ vaes = folder_paths.get_filename_list("vae")
690
+ approx_vaes = folder_paths.get_filename_list("vae_approx")
691
+ sdxl_taesd_enc = False
692
+ sdxl_taesd_dec = False
693
+ sd1_taesd_enc = False
694
+ sd1_taesd_dec = False
695
+ sd3_taesd_enc = False
696
+ sd3_taesd_dec = False
697
+ f1_taesd_enc = False
698
+ f1_taesd_dec = False
699
+
700
+ for v in approx_vaes:
701
+ if v.startswith("taesd_decoder."):
702
+ sd1_taesd_dec = True
703
+ elif v.startswith("taesd_encoder."):
704
+ sd1_taesd_enc = True
705
+ elif v.startswith("taesdxl_decoder."):
706
+ sdxl_taesd_dec = True
707
+ elif v.startswith("taesdxl_encoder."):
708
+ sdxl_taesd_enc = True
709
+ elif v.startswith("taesd3_decoder."):
710
+ sd3_taesd_dec = True
711
+ elif v.startswith("taesd3_encoder."):
712
+ sd3_taesd_enc = True
713
+ elif v.startswith("taef1_encoder."):
714
+ f1_taesd_dec = True
715
+ elif v.startswith("taef1_decoder."):
716
+ f1_taesd_enc = True
717
+ if sd1_taesd_dec and sd1_taesd_enc:
718
+ vaes.append("taesd")
719
+ if sdxl_taesd_dec and sdxl_taesd_enc:
720
+ vaes.append("taesdxl")
721
+ if sd3_taesd_dec and sd3_taesd_enc:
722
+ vaes.append("taesd3")
723
+ if f1_taesd_dec and f1_taesd_enc:
724
+ vaes.append("taef1")
725
+ return vaes
726
+
727
+ @staticmethod
728
+ def load_taesd(name):
729
+ sd = {}
730
+ approx_vaes = folder_paths.get_filename_list("vae_approx")
731
+
732
+ encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes))
733
+ decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes))
734
+
735
+ enc = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", encoder))
736
+ for k in enc:
737
+ sd["taesd_encoder.{}".format(k)] = enc[k]
738
+
739
+ dec = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", decoder))
740
+ for k in dec:
741
+ sd["taesd_decoder.{}".format(k)] = dec[k]
742
+
743
+ if name == "taesd":
744
+ sd["vae_scale"] = torch.tensor(0.18215)
745
+ sd["vae_shift"] = torch.tensor(0.0)
746
+ elif name == "taesdxl":
747
+ sd["vae_scale"] = torch.tensor(0.13025)
748
+ sd["vae_shift"] = torch.tensor(0.0)
749
+ elif name == "taesd3":
750
+ sd["vae_scale"] = torch.tensor(1.5305)
751
+ sd["vae_shift"] = torch.tensor(0.0609)
752
+ elif name == "taef1":
753
+ sd["vae_scale"] = torch.tensor(0.3611)
754
+ sd["vae_shift"] = torch.tensor(0.1159)
755
+ return sd
756
+
757
+ @classmethod
758
+ def INPUT_TYPES(s):
759
+ return {"required": { "vae_name": (s.vae_list(), )}}
760
+ RETURN_TYPES = ("VAE",)
761
+ FUNCTION = "load_vae"
762
+
763
+ CATEGORY = "loaders"
764
+
765
+ #TODO: scale factor?
766
+ def load_vae(self, vae_name):
767
+ if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
768
+ sd = self.load_taesd(vae_name)
769
+ else:
770
+ vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
771
+ sd = comfy.utils.load_torch_file(vae_path)
772
+ vae = comfy.sd.VAE(sd=sd)
773
+ return (vae,)
774
+
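+ # Editor's note (illustrative): vae_list() only offers a "taesd*" entry when a
+ # matching encoder/decoder pair exists under vae_approx, e.g. for "taesdxl":
+ #   vae_approx/taesdxl_encoder.safetensors
+ #   vae_approx/taesdxl_decoder.safetensors
+ # load_taesd() then namespaces both halves under the taesd_encoder. /
+ # taesd_decoder. key prefixes and adds the per-model vae_scale / vae_shift
+ # constants shown above.
+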
775
+ class ControlNetLoader:
776
+ @classmethod
777
+ def INPUT_TYPES(s):
778
+ return {"required": { "control_net_name": (folder_paths.get_filename_list("controlnet"), )}}
779
+
780
+ RETURN_TYPES = ("CONTROL_NET",)
781
+ FUNCTION = "load_controlnet"
782
+
783
+ CATEGORY = "loaders"
784
+
785
+ def load_controlnet(self, control_net_name):
786
+ controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name)
787
+ controlnet = comfy.controlnet.load_controlnet(controlnet_path)
788
+ return (controlnet,)
789
+
790
+ class DiffControlNetLoader:
791
+ @classmethod
792
+ def INPUT_TYPES(s):
793
+ return {"required": { "model": ("MODEL",),
794
+ "control_net_name": (folder_paths.get_filename_list("controlnet"), )}}
795
+
796
+ RETURN_TYPES = ("CONTROL_NET",)
797
+ FUNCTION = "load_controlnet"
798
+
799
+ CATEGORY = "loaders"
800
+
801
+ def load_controlnet(self, model, control_net_name):
802
+ controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name)
803
+ controlnet = comfy.controlnet.load_controlnet(controlnet_path, model)
804
+ return (controlnet,)
805
+
806
+
807
+ class ControlNetApply:
808
+ @classmethod
809
+ def INPUT_TYPES(s):
810
+ return {"required": {"conditioning": ("CONDITIONING", ),
811
+ "control_net": ("CONTROL_NET", ),
812
+ "image": ("IMAGE", ),
813
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01})
814
+ }}
815
+ RETURN_TYPES = ("CONDITIONING",)
816
+ FUNCTION = "apply_controlnet"
817
+
818
+ DEPRECATED = True
819
+ CATEGORY = "conditioning/controlnet"
820
+
821
+ def apply_controlnet(self, conditioning, control_net, image, strength):
822
+ if strength == 0:
823
+ return (conditioning, )
824
+
825
+ c = []
826
+ control_hint = image.movedim(-1,1)
827
+ for t in conditioning:
828
+ n = [t[0], t[1].copy()]
829
+ c_net = control_net.copy().set_cond_hint(control_hint, strength)
830
+ if 'control' in t[1]:
831
+ c_net.set_previous_controlnet(t[1]['control'])
832
+ n[1]['control'] = c_net
833
+ n[1]['control_apply_to_uncond'] = True
834
+ c.append(n)
835
+ return (c, )
836
+
837
+
838
+ class ControlNetApplyAdvanced:
839
+ @classmethod
840
+ def INPUT_TYPES(s):
841
+ return {"required": {"positive": ("CONDITIONING", ),
842
+ "negative": ("CONDITIONING", ),
843
+ "control_net": ("CONTROL_NET", ),
844
+ "image": ("IMAGE", ),
845
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
846
+ "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
847
+ "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
848
+ },
849
+ "optional": {"vae": ("VAE", ),
850
+ }
851
+ }
852
+
853
+ RETURN_TYPES = ("CONDITIONING","CONDITIONING")
854
+ RETURN_NAMES = ("positive", "negative")
855
+ FUNCTION = "apply_controlnet"
856
+
857
+ CATEGORY = "conditioning/controlnet"
858
+
859
+ def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None, extra_concat=[]):
860
+ if strength == 0:
861
+ return (positive, negative)
862
+
863
+ control_hint = image.movedim(-1,1)
864
+ cnets = {}
865
+
866
+ out = []
867
+ for conditioning in [positive, negative]:
868
+ c = []
869
+ for t in conditioning:
870
+ d = t[1].copy()
871
+
872
+ prev_cnet = d.get('control', None)
873
+ if prev_cnet in cnets:
874
+ c_net = cnets[prev_cnet]
875
+ else:
876
+ c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae=vae, extra_concat=extra_concat)
877
+ c_net.set_previous_controlnet(prev_cnet)
878
+ cnets[prev_cnet] = c_net
879
+
880
+ d['control'] = c_net
881
+ d['control_apply_to_uncond'] = False
882
+ n = [t[0], d]
883
+ c.append(n)
884
+ out.append(c)
885
+ return (out[0], out[1])
886
+
887
+
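+ # Editor's sketch (illustrative; the file name is hypothetical): typical wiring,
+ # where the same hint image steers both positive and negative conditioning:
+ # control_net = ControlNetLoader().load_controlnet("canny.safetensors")[0]
+ # positive, negative = ControlNetApplyAdvanced().apply_controlnet(
+ #     positive, negative, control_net, hint_image, strength=0.8,
+ #     start_percent=0.0, end_percent=1.0)
+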
888
+ class UNETLoader:
889
+ @classmethod
890
+ def INPUT_TYPES(s):
891
+ return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
892
+ "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],)
893
+ }}
894
+ RETURN_TYPES = ("MODEL",)
895
+ FUNCTION = "load_unet"
896
+
897
+ CATEGORY = "advanced/loaders"
898
+
899
+ def load_unet(self, unet_name, weight_dtype):
900
+ model_options = {}
901
+ if weight_dtype == "fp8_e4m3fn":
902
+ model_options["dtype"] = torch.float8_e4m3fn
903
+ elif weight_dtype == "fp8_e4m3fn_fast":
904
+ model_options["dtype"] = torch.float8_e4m3fn
905
+ model_options["fp8_optimizations"] = True
906
+ elif weight_dtype == "fp8_e5m2":
907
+ model_options["dtype"] = torch.float8_e5m2
908
+
909
+ unet_path = folder_paths.get_full_path_or_raise("diffusion_models", unet_name)
910
+ model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
911
+ return (model,)
912
+
913
+ class CLIPLoader:
914
+ @classmethod
915
+ def INPUT_TYPES(s):
916
+ return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
917
+ "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos"], ),
918
+ },
919
+ "optional": {
920
+ "device": (["default", "cpu"], {"advanced": True}),
921
+ }}
922
+ RETURN_TYPES = ("CLIP",)
923
+ FUNCTION = "load_clip"
924
+
925
+ CATEGORY = "advanced/loaders"
926
+
927
+ DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 / clip-g / clip-l\nstable_audio: t5\nmochi: t5\ncosmos: old t5 xxl"
928
+
929
+ def load_clip(self, clip_name, type="stable_diffusion", device="default"):
930
+ if type == "stable_cascade":
931
+ clip_type = comfy.sd.CLIPType.STABLE_CASCADE
932
+ elif type == "sd3":
933
+ clip_type = comfy.sd.CLIPType.SD3
934
+ elif type == "stable_audio":
935
+ clip_type = comfy.sd.CLIPType.STABLE_AUDIO
936
+ elif type == "mochi":
937
+ clip_type = comfy.sd.CLIPType.MOCHI
938
+ elif type == "ltxv":
939
+ clip_type = comfy.sd.CLIPType.LTXV
940
+ elif type == "pixart":
941
+ clip_type = comfy.sd.CLIPType.PIXART
942
+ elif type == "cosmos":
943
+ clip_type = comfy.sd.CLIPType.COSMOS
944
+ else:
945
+ clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
946
+
947
+ model_options = {}
948
+ if device == "cpu":
949
+ model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")
950
+
951
+ clip_path = folder_paths.get_full_path_or_raise("text_encoders", clip_name)
952
+ clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
953
+ return (clip,)
954
+
955
+ class DualCLIPLoader:
956
+ @classmethod
957
+ def INPUT_TYPES(s):
958
+ return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
959
+ "clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
960
+ "type": (["sdxl", "sd3", "flux", "hunyuan_video"], ),
961
+ },
962
+ "optional": {
963
+ "device": (["default", "cpu"], {"advanced": True}),
964
+ }}
965
+ RETURN_TYPES = ("CLIP",)
966
+ FUNCTION = "load_clip"
967
+
968
+ CATEGORY = "advanced/loaders"
969
+
970
+ DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5"
971
+
972
+ def load_clip(self, clip_name1, clip_name2, type, device="default"):
973
+ clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
974
+ clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
975
+ if type == "sdxl":
976
+ clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
977
+ elif type == "sd3":
978
+ clip_type = comfy.sd.CLIPType.SD3
979
+ elif type == "flux":
980
+ clip_type = comfy.sd.CLIPType.FLUX
981
+ elif type == "hunyuan_video":
982
+ clip_type = comfy.sd.CLIPType.HUNYUAN_VIDEO
983
+
984
+ model_options = {}
985
+ if device == "cpu":
986
+ model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")
987
+
988
+ clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
989
+ return (clip,)
990
+
991
+ class CLIPVisionLoader:
992
+ @classmethod
993
+ def INPUT_TYPES(s):
994
+ return {"required": { "clip_name": (folder_paths.get_filename_list("clip_vision"), ),
995
+ }}
996
+ RETURN_TYPES = ("CLIP_VISION",)
997
+ FUNCTION = "load_clip"
998
+
999
+ CATEGORY = "loaders"
1000
+
1001
+ def load_clip(self, clip_name):
1002
+ clip_path = folder_paths.get_full_path_or_raise("clip_vision", clip_name)
1003
+ clip_vision = comfy.clip_vision.load(clip_path)
1004
+ return (clip_vision,)
1005
+
1006
+ class CLIPVisionEncode:
1007
+ @classmethod
1008
+ def INPUT_TYPES(s):
1009
+ return {"required": { "clip_vision": ("CLIP_VISION",),
1010
+ "image": ("IMAGE",),
1011
+ "crop": (["center", "none"],)
1012
+ }}
1013
+ RETURN_TYPES = ("CLIP_VISION_OUTPUT",)
1014
+ FUNCTION = "encode"
1015
+
1016
+ CATEGORY = "conditioning"
1017
+
1018
+ def encode(self, clip_vision, image, crop):
1019
+ crop_image = True
1020
+ if crop != "center":
1021
+ crop_image = False
1022
+ output = clip_vision.encode_image(image, crop=crop_image)
1023
+ return (output,)
1024
+
1025
+ class StyleModelLoader:
1026
+ @classmethod
1027
+ def INPUT_TYPES(s):
1028
+ return {"required": { "style_model_name": (folder_paths.get_filename_list("style_models"), )}}
1029
+
1030
+ RETURN_TYPES = ("STYLE_MODEL",)
1031
+ FUNCTION = "load_style_model"
1032
+
1033
+ CATEGORY = "loaders"
1034
+
1035
+ def load_style_model(self, style_model_name):
1036
+ style_model_path = folder_paths.get_full_path_or_raise("style_models", style_model_name)
1037
+ style_model = comfy.sd.load_style_model(style_model_path)
1038
+ return (style_model,)
1039
+
1040
+
1041
+ class StyleModelApply:
1042
+ @classmethod
1043
+ def INPUT_TYPES(s):
1044
+ return {"required": {"conditioning": ("CONDITIONING", ),
1045
+ "style_model": ("STYLE_MODEL", ),
1046
+ "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
1047
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}),
1048
+ "strength_type": (["multiply", "attn_bias"], ),
1049
+ }}
1050
+ RETURN_TYPES = ("CONDITIONING",)
1051
+ FUNCTION = "apply_stylemodel"
1052
+
1053
+ CATEGORY = "conditioning/style_model"
1054
+
1055
+ def apply_stylemodel(self, conditioning, style_model, clip_vision_output, strength, strength_type):
1056
+ cond = style_model.get_cond(clip_vision_output).flatten(start_dim=0, end_dim=1).unsqueeze(dim=0)
1057
+ if strength_type == "multiply":
1058
+ cond *= strength
1059
+
1060
+ n = cond.shape[1]
1061
+ c_out = []
1062
+ for t in conditioning:
1063
+ (txt, keys) = t
1064
+ keys = keys.copy()
1065
+ if strength_type == "attn_bias" and strength != 1.0:
1066
+ # math.log raises an error if the argument is zero
1067
+ # torch.log returns -inf, which is what we want
1068
+ attn_bias = torch.log(torch.Tensor([strength]))
1069
+ # get the size of the mask image
1070
+ mask_ref_size = keys.get("attention_mask_img_shape", (1, 1))
1071
+ n_ref = mask_ref_size[0] * mask_ref_size[1]
1072
+ n_txt = txt.shape[1]
1073
+ # grab the existing mask
1074
+ mask = keys.get("attention_mask", None)
1075
+ # create a default mask if it doesn't exist
1076
+ if mask is None:
1077
+ mask = torch.zeros((txt.shape[0], n_txt + n_ref, n_txt + n_ref), dtype=torch.float16)
1078
+ # convert the mask dtype, because it might be boolean
1079
+ # we want it to be interpreted as a bias
1080
+ if mask.dtype == torch.bool:
1081
+ # log(True) = log(1) = 0
1082
+ # log(False) = log(0) = -inf
1083
+ mask = torch.log(mask.to(dtype=torch.float16))
1084
+ # now we make the mask bigger to add space for our new tokens
1085
+ new_mask = torch.zeros((txt.shape[0], n_txt + n + n_ref, n_txt + n + n_ref), dtype=torch.float16)
1086
+ # copy over the old mask, in quadrants
1087
+ new_mask[:, :n_txt, :n_txt] = mask[:, :n_txt, :n_txt]
1088
+ new_mask[:, :n_txt, n_txt+n:] = mask[:, :n_txt, n_txt:]
1089
+ new_mask[:, n_txt+n:, :n_txt] = mask[:, n_txt:, :n_txt]
1090
+ new_mask[:, n_txt+n:, n_txt+n:] = mask[:, n_txt:, n_txt:]
1091
+ # now fill in the attention bias to our redux tokens
1092
+ new_mask[:, :n_txt, n_txt:n_txt+n] = attn_bias
1093
+ new_mask[:, n_txt+n:, n_txt:n_txt+n] = attn_bias
1094
+ keys["attention_mask"] = new_mask.to(txt.device)
1095
+ keys["attention_mask_img_shape"] = mask_ref_size
1096
+
1097
+ c_out.append([torch.cat((txt, cond), dim=1), keys])
1098
+
1099
+ return (c_out,)
1100
+
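+ # Editor's sketch (illustrative, not part of the original file): the attn_bias
+ # branch grows the square attention mask by n redux tokens, copying the old mask
+ # into four quadrants. Demo with n_txt=2 text tokens, n=1 redux, n_ref=1 ref:
+ _old = torch.arange(9, dtype=torch.float16).reshape(1, 3, 3)
+ _new = torch.zeros(1, 4, 4, dtype=torch.float16)
+ _new[:, :2, :2] = _old[:, :2, :2]    # txt rows -> txt cols
+ _new[:, :2, 3:] = _old[:, :2, 2:]    # txt rows -> ref cols
+ _new[:, 3:, :2] = _old[:, 2:, :2]    # ref rows -> txt cols
+ _new[:, 3:, 3:] = _old[:, 2:, 2:]    # ref rows -> ref cols
+ # row/column index 2 (the new redux token) keeps the default bias of 0, which
+ # is log(1), i.e. neutral attention at strength 1.0.
+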
1101
+ class unCLIPConditioning:
1102
+ @classmethod
1103
+ def INPUT_TYPES(s):
1104
+ return {"required": {"conditioning": ("CONDITIONING", ),
1105
+ "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
1106
+ "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
1107
+ "noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
1108
+ }}
1109
+ RETURN_TYPES = ("CONDITIONING",)
1110
+ FUNCTION = "apply_adm"
1111
+
1112
+ CATEGORY = "conditioning"
1113
+
1114
+ def apply_adm(self, conditioning, clip_vision_output, strength, noise_augmentation):
1115
+ if strength == 0:
1116
+ return (conditioning, )
1117
+
1118
+ c = []
1119
+ for t in conditioning:
1120
+ o = t[1].copy()
1121
+ x = {"clip_vision_output": clip_vision_output, "strength": strength, "noise_augmentation": noise_augmentation}
1122
+ if "unclip_conditioning" in o:
1123
+ o["unclip_conditioning"] = o["unclip_conditioning"][:] + [x]
1124
+ else:
1125
+ o["unclip_conditioning"] = [x]
1126
+ n = [t[0], o]
1127
+ c.append(n)
1128
+ return (c, )
1129
+
1130
+ class GLIGENLoader:
1131
+ @classmethod
1132
+ def INPUT_TYPES(s):
1133
+ return {"required": { "gligen_name": (folder_paths.get_filename_list("gligen"), )}}
1134
+
1135
+ RETURN_TYPES = ("GLIGEN",)
1136
+ FUNCTION = "load_gligen"
1137
+
1138
+ CATEGORY = "loaders"
1139
+
1140
+ def load_gligen(self, gligen_name):
1141
+ gligen_path = folder_paths.get_full_path_or_raise("gligen", gligen_name)
1142
+ gligen = comfy.sd.load_gligen(gligen_path)
1143
+ return (gligen,)
1144
+
1145
+ class GLIGENTextBoxApply:
1146
+ @classmethod
1147
+ def INPUT_TYPES(s):
1148
+ return {"required": {"conditioning_to": ("CONDITIONING", ),
1149
+ "clip": ("CLIP", ),
1150
+ "gligen_textbox_model": ("GLIGEN", ),
1151
+ "text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
1152
+ "width": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}),
1153
+ "height": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}),
1154
+ "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1155
+ "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1156
+ }}
1157
+ RETURN_TYPES = ("CONDITIONING",)
1158
+ FUNCTION = "append"
1159
+
1160
+ CATEGORY = "conditioning/gligen"
1161
+
1162
+ def append(self, conditioning_to, clip, gligen_textbox_model, text, width, height, x, y):
1163
+ c = []
1164
+ cond, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled="unprojected")
1165
+ for t in conditioning_to:
1166
+ n = [t[0], t[1].copy()]
1167
+ position_params = [(cond_pooled, height // 8, width // 8, y // 8, x // 8)]
1168
+ prev = []
1169
+ if "gligen" in n[1]:
1170
+ prev = n[1]['gligen'][2]
1171
+
1172
+ n[1]['gligen'] = ("position", gligen_textbox_model, prev + position_params)
1173
+ c.append(n)
1174
+ return (c, )
1175
+
1176
+ class EmptyLatentImage:
1177
+ def __init__(self):
1178
+ self.device = comfy.model_management.intermediate_device()
1179
+
1180
+ @classmethod
1181
+ def INPUT_TYPES(s):
1182
+ return {
1183
+ "required": {
1184
+ "width": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8, "tooltip": "The width of the latent images in pixels."}),
1185
+ "height": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8, "tooltip": "The height of the latent images in pixels."}),
1186
+ "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."})
1187
+ }
1188
+ }
1189
+ RETURN_TYPES = ("LATENT",)
1190
+ OUTPUT_TOOLTIPS = ("The empty latent image batch.",)
1191
+ FUNCTION = "generate"
1192
+
1193
+ CATEGORY = "latent"
1194
+ DESCRIPTION = "Create a new batch of empty latent images to be denoised via sampling."
1195
+
1196
+ def generate(self, width, height, batch_size=1):
1197
+ latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device)
1198
+ return ({"samples":latent}, )
1199
+
1200
+
1201
+ class LatentFromBatch:
1202
+ @classmethod
1203
+ def INPUT_TYPES(s):
1204
+ return {"required": { "samples": ("LATENT",),
1205
+ "batch_index": ("INT", {"default": 0, "min": 0, "max": 63}),
1206
+ "length": ("INT", {"default": 1, "min": 1, "max": 64}),
1207
+ }}
1208
+ RETURN_TYPES = ("LATENT",)
1209
+ FUNCTION = "frombatch"
1210
+
1211
+ CATEGORY = "latent/batch"
1212
+
1213
+ def frombatch(self, samples, batch_index, length):
1214
+ s = samples.copy()
1215
+ s_in = samples["samples"]
1216
+ batch_index = min(s_in.shape[0] - 1, batch_index)
1217
+ length = min(s_in.shape[0] - batch_index, length)
1218
+ s["samples"] = s_in[batch_index:batch_index + length].clone()
1219
+ if "noise_mask" in samples:
1220
+ masks = samples["noise_mask"]
1221
+ if masks.shape[0] == 1:
1222
+ s["noise_mask"] = masks.clone()
1223
+ else:
1224
+ if masks.shape[0] < s_in.shape[0]:
1225
+ masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]]
1226
+ s["noise_mask"] = masks[batch_index:batch_index + length].clone()
1227
+ if "batch_index" not in s:
1228
+ s["batch_index"] = [x for x in range(batch_index, batch_index+length)]
1229
+ else:
1230
+ s["batch_index"] = samples["batch_index"][batch_index:batch_index + length]
1231
+ return (s,)
1232
+
1233
+ class RepeatLatentBatch:
1234
+ @classmethod
1235
+ def INPUT_TYPES(s):
1236
+ return {"required": { "samples": ("LATENT",),
1237
+ "amount": ("INT", {"default": 1, "min": 1, "max": 64}),
1238
+ }}
1239
+ RETURN_TYPES = ("LATENT",)
1240
+ FUNCTION = "repeat"
1241
+
1242
+ CATEGORY = "latent/batch"
1243
+
1244
+ def repeat(self, samples, amount):
1245
+ s = samples.copy()
1246
+ s_in = samples["samples"]
1247
+
1248
+ s["samples"] = s_in.repeat((amount, 1,1,1))
1249
+ if "noise_mask" in samples and samples["noise_mask"].shape[0] > 1:
1250
+ masks = samples["noise_mask"]
1251
+ if masks.shape[0] < s_in.shape[0]:
1252
+ masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]]
1253
+ s["noise_mask"] = samples["noise_mask"].repeat((amount, 1,1,1))
1254
+ if "batch_index" in s:
1255
+ offset = max(s["batch_index"]) - min(s["batch_index"]) + 1
1256
+ s["batch_index"] = s["batch_index"] + [x + (i * offset) for i in range(1, amount) for x in s["batch_index"]]
1257
+ return (s,)
1258
+
1259
+ class LatentUpscale:
1260
+ upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "bislerp"]
1261
+ crop_methods = ["disabled", "center"]
1262
+
1263
+ @classmethod
1264
+ def INPUT_TYPES(s):
1265
+ return {"required": { "samples": ("LATENT",), "upscale_method": (s.upscale_methods,),
1266
+ "width": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1267
+ "height": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1268
+ "crop": (s.crop_methods,)}}
1269
+ RETURN_TYPES = ("LATENT",)
1270
+ FUNCTION = "upscale"
1271
+
1272
+ CATEGORY = "latent"
1273
+
1274
+ def upscale(self, samples, upscale_method, width, height, crop):
1275
+ if width == 0 and height == 0:
1276
+ s = samples
1277
+ else:
1278
+ s = samples.copy()
1279
+
1280
+ if width == 0:
1281
+ height = max(64, height)
1282
+ width = max(64, round(samples["samples"].shape[-1] * height / samples["samples"].shape[-2]))
1283
+ elif height == 0:
1284
+ width = max(64, width)
1285
+ height = max(64, round(samples["samples"].shape[-2] * width / samples["samples"].shape[-1]))
1286
+ else:
1287
+ width = max(64, width)
1288
+ height = max(64, height)
1289
+
1290
+ s["samples"] = comfy.utils.common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop)
1291
+ return (s,)
1292
+
1293
+ class LatentUpscaleBy:
1294
+ upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "bislerp"]
1295
+
1296
+ @classmethod
1297
+ def INPUT_TYPES(s):
1298
+ return {"required": { "samples": ("LATENT",), "upscale_method": (s.upscale_methods,),
1299
+ "scale_by": ("FLOAT", {"default": 1.5, "min": 0.01, "max": 8.0, "step": 0.01}),}}
1300
+ RETURN_TYPES = ("LATENT",)
1301
+ FUNCTION = "upscale"
1302
+
1303
+ CATEGORY = "latent"
1304
+
1305
+ def upscale(self, samples, upscale_method, scale_by):
1306
+ s = samples.copy()
1307
+ width = round(samples["samples"].shape[-1] * scale_by)
1308
+ height = round(samples["samples"].shape[-2] * scale_by)
1309
+ s["samples"] = comfy.utils.common_upscale(samples["samples"], width, height, upscale_method, "disabled")
1310
+ return (s,)
1311
+
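+ # Editor's sketch (illustrative): scale_by acts on latent dimensions, so a
+ # 512x768 px image (a 64x96 latent) scaled by 1.5 becomes a 96x144 latent,
+ # i.e. 768x1152 px after decoding:
+ assert round(64 * 1.5) == 96 and round(96 * 1.5) == 144
+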
1312
+ class LatentRotate:
1313
+ @classmethod
1314
+ def INPUT_TYPES(s):
1315
+ return {"required": { "samples": ("LATENT",),
1316
+ "rotation": (["none", "90 degrees", "180 degrees", "270 degrees"],),
1317
+ }}
1318
+ RETURN_TYPES = ("LATENT",)
1319
+ FUNCTION = "rotate"
1320
+
1321
+ CATEGORY = "latent/transform"
1322
+
1323
+ def rotate(self, samples, rotation):
1324
+ s = samples.copy()
1325
+ rotate_by = 0
1326
+ if rotation.startswith("90"):
1327
+ rotate_by = 1
1328
+ elif rotation.startswith("180"):
1329
+ rotate_by = 2
1330
+ elif rotation.startswith("270"):
1331
+ rotate_by = 3
1332
+
1333
+ s["samples"] = torch.rot90(samples["samples"], k=rotate_by, dims=[3, 2])
1334
+ return (s,)
1335
+
1336
+ class LatentFlip:
1337
+ @classmethod
1338
+ def INPUT_TYPES(s):
1339
+ return {"required": { "samples": ("LATENT",),
1340
+ "flip_method": (["x-axis: vertically", "y-axis: horizontally"],),
1341
+ }}
1342
+ RETURN_TYPES = ("LATENT",)
1343
+ FUNCTION = "flip"
1344
+
1345
+ CATEGORY = "latent/transform"
1346
+
1347
+ def flip(self, samples, flip_method):
1348
+ s = samples.copy()
1349
+ if flip_method.startswith("x"):
1350
+ s["samples"] = torch.flip(samples["samples"], dims=[2])
1351
+ elif flip_method.startswith("y"):
1352
+ s["samples"] = torch.flip(samples["samples"], dims=[3])
1353
+
1354
+ return (s,)
1355
+
1356
+ class LatentComposite:
1357
+ @classmethod
1358
+ def INPUT_TYPES(s):
1359
+ return {"required": { "samples_to": ("LATENT",),
1360
+ "samples_from": ("LATENT",),
1361
+ "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1362
+ "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1363
+ "feather": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1364
+ }}
1365
+ RETURN_TYPES = ("LATENT",)
1366
+ FUNCTION = "composite"
1367
+
1368
+ CATEGORY = "latent"
1369
+
1370
+ def composite(self, samples_to, samples_from, x, y, composite_method="normal", feather=0):
1371
+ x = x // 8
1372
+ y = y // 8
1373
+ feather = feather // 8
1374
+ samples_out = samples_to.copy()
1375
+ s = samples_to["samples"].clone()
1376
+ samples_to = samples_to["samples"]
1377
+ samples_from = samples_from["samples"]
1378
+ if feather == 0:
1379
+ s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x]
1380
+ else:
1381
+ samples_from = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x]
1382
+ mask = torch.ones_like(samples_from)
1383
+ for t in range(feather):
1384
+ if y != 0:
1385
+ mask[:,:,t:1+t,:] *= ((1.0/feather) * (t + 1))
1386
+
1387
+ if y + samples_from.shape[2] < samples_to.shape[2]:
1388
+ mask[:,:,mask.shape[2] -1 -t: mask.shape[2]-t,:] *= ((1.0/feather) * (t + 1))
1389
+ if x != 0:
1390
+ mask[:,:,:,t:1+t] *= ((1.0/feather) * (t + 1))
1391
+ if x + samples_from.shape[3] < samples_to.shape[3]:
1392
+ mask[:,:,:,mask.shape[3]- 1 - t: mask.shape[3]- t] *= ((1.0/feather) * (t + 1))
1393
+ rev_mask = torch.ones_like(mask) - mask
1394
+ s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x] * mask + s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] * rev_mask
1395
+ samples_out["samples"] = s
1396
+ return (samples_out,)
1397
+
1398
+ class LatentBlend:
1399
+ @classmethod
1400
+ def INPUT_TYPES(s):
1401
+ return {"required": {
1402
+ "samples1": ("LATENT",),
1403
+ "samples2": ("LATENT",),
1404
+ "blend_factor": ("FLOAT", {
1405
+ "default": 0.5,
1406
+ "min": 0,
1407
+ "max": 1,
1408
+ "step": 0.01
1409
+ }),
1410
+ }}
1411
+
1412
+ RETURN_TYPES = ("LATENT",)
1413
+ FUNCTION = "blend"
1414
+
1415
+ CATEGORY = "_for_testing"
1416
+
1417
+ def blend(self, samples1, samples2, blend_factor: float, blend_mode: str = "normal"):
1419
+ samples_out = samples1.copy()
1420
+ samples1 = samples1["samples"]
1421
+ samples2 = samples2["samples"]
1422
+
1423
+ if samples1.shape != samples2.shape:
1424
+ # latent tensors are already NCHW, so they can be rescaled directly
1425
+ samples2 = comfy.utils.common_upscale(samples2, samples1.shape[3], samples1.shape[2], 'bicubic', crop='center')
1427
+
1428
+ samples_blended = self.blend_mode(samples1, samples2, blend_mode)
1429
+ samples_blended = samples1 * blend_factor + samples_blended * (1 - blend_factor)
1430
+ samples_out["samples"] = samples_blended
1431
+ return (samples_out,)
1432
+
1433
+ def blend_mode(self, img1, img2, mode):
1434
+ if mode == "normal":
1435
+ return img2
1436
+ else:
1437
+ raise ValueError(f"Unsupported blend mode: {mode}")
1438
+
1439
+ class LatentCrop:
1440
+ @classmethod
1441
+ def INPUT_TYPES(s):
1442
+ return {"required": { "samples": ("LATENT",),
1443
+ "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
1444
+ "height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
1445
+ "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1446
+ "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1447
+ }}
1448
+ RETURN_TYPES = ("LATENT",)
1449
+ FUNCTION = "crop"
1450
+
1451
+ CATEGORY = "latent/transform"
1452
+
1453
+ def crop(self, samples, width, height, x, y):
1454
+ s = samples.copy()
1455
+ samples = samples['samples']
1456
+ x = x // 8
1457
+ y = y // 8
1458
+
1459
+ # enforce a minimum crop size of 64 pixels (8 latent units)
1460
+ if x > (samples.shape[3] - 8):
1461
+ x = samples.shape[3] - 8
1462
+ if y > (samples.shape[2] - 8):
1463
+ y = samples.shape[2] - 8
1464
+
1465
+ new_height = height // 8
1466
+ new_width = width // 8
1467
+ to_x = new_width + x
1468
+ to_y = new_height + y
1469
+ s['samples'] = samples[:,:,y:to_y, x:to_x]
1470
+ return (s,)
1471
+
1472
+ class SetLatentNoiseMask:
1473
+ @classmethod
1474
+ def INPUT_TYPES(s):
1475
+ return {"required": { "samples": ("LATENT",),
1476
+ "mask": ("MASK",),
1477
+ }}
1478
+ RETURN_TYPES = ("LATENT",)
1479
+ FUNCTION = "set_mask"
1480
+
1481
+ CATEGORY = "latent/inpaint"
1482
+
1483
+ def set_mask(self, samples, mask):
1484
+ s = samples.copy()
1485
+ s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
1486
+ return (s,)
1487
+
1488
+ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
1489
+ latent_image = latent["samples"]
1490
+ latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image)
1491
+
1492
+ if disable_noise:
1493
+ noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
1494
+ else:
1495
+ batch_inds = latent["batch_index"] if "batch_index" in latent else None
1496
+ noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds)
1497
+
1498
+ noise_mask = None
1499
+ if "noise_mask" in latent:
1500
+ noise_mask = latent["noise_mask"]
1501
+
1502
+ callback = latent_preview.prepare_callback(model, steps)
1503
+ disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
1504
+ samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
1505
+ denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
1506
+ force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
1507
+ out = latent.copy()
1508
+ out["samples"] = samples
1509
+ return (out, )
1510
+
1511
+ class KSampler:
1512
+ @classmethod
1513
+ def INPUT_TYPES(s):
1514
+ return {
1515
+ "required": {
1516
+ "model": ("MODEL", {"tooltip": "The model used for denoising the input latent."}),
1517
+ "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "tooltip": "The random seed used for creating the noise."}),
1518
+ "steps": ("INT", {"default": 20, "min": 1, "max": 10000, "tooltip": "The number of steps used in the denoising process."}),
1519
+ "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01, "tooltip": "The Classifier-Free Guidance scale balances creativity and adherence to the prompt. Higher values result in images more closely matching the prompt however too high values will negatively impact quality."}),
1520
+ "sampler_name": (comfy.samplers.KSampler.SAMPLERS, {"tooltip": "The algorithm used when sampling, this can affect the quality, speed, and style of the generated output."}),
1521
+ "scheduler": (comfy.samplers.KSampler.SCHEDULERS, {"tooltip": "The scheduler controls how noise is gradually removed to form the image."}),
1522
+ "positive": ("CONDITIONING", {"tooltip": "The conditioning describing the attributes you want to include in the image."}),
1523
+ "negative": ("CONDITIONING", {"tooltip": "The conditioning describing the attributes you want to exclude from the image."}),
1524
+ "latent_image": ("LATENT", {"tooltip": "The latent image to denoise."}),
1525
+ "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The amount of denoising applied, lower values will maintain the structure of the initial image allowing for image to image sampling."}),
1526
+ }
1527
+ }
1528
+
1529
+ RETURN_TYPES = ("LATENT",)
1530
+ OUTPUT_TOOLTIPS = ("The denoised latent.",)
1531
+ FUNCTION = "sample"
1532
+
1533
+ CATEGORY = "sampling"
1534
+ DESCRIPTION = "Uses the provided model, positive and negative conditioning to denoise the latent image."
1535
+
1536
+ def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0):
1537
+ return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise)
1538
+
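+ # Editor's sketch (illustrative; the checkpoint name is hypothetical): the
+ # minimal txt2img pipeline these classes form when wired together:
+ # model, clip, vae = CheckpointLoaderSimple().load_checkpoint("sd15.safetensors")
+ # (pos,) = CLIPTextEncode().encode(clip, "a watercolor fox")
+ # (neg,) = CLIPTextEncode().encode(clip, "blurry, low quality")
+ # (latent,) = EmptyLatentImage().generate(512, 512, batch_size=1)
+ # (latent,) = KSampler().sample(model, seed=42, steps=20, cfg=8.0,
+ #     sampler_name="euler", scheduler="normal", positive=pos, negative=neg,
+ #     latent_image=latent, denoise=1.0)
+ # (images,) = VAEDecode().decode(vae, latent)
+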
1539
class KSamplerAdvanced:
    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {"model": ("MODEL",),
                     "add_noise": (["enable", "disable"], ),
                     "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                     "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
                     "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
                     "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ),
                     "scheduler": (comfy.samplers.KSampler.SCHEDULERS, ),
                     "positive": ("CONDITIONING", ),
                     "negative": ("CONDITIONING", ),
                     "latent_image": ("LATENT", ),
                     "start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}),
                     "end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}),
                     "return_with_leftover_noise": (["disable", "enable"], ),
                     }
                }

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "sample"

    CATEGORY = "sampling"

    def sample(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0):
        force_full_denoise = True
        if return_with_leftover_noise == "enable":
            force_full_denoise = False
        disable_noise = False
        if add_noise == "disable":
            disable_noise = True
        return common_ksampler(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise)

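# Illustrative sketch (not part of the upstream file): start_at_step/end_at_step
# allow one 20-step schedule to be split across two samplers, e.g. a base model
# for steps 0-12 (keeping the leftover noise) and a refiner for steps 12-20:
#
#     adv = KSamplerAdvanced()
#     (halfway,) = adv.sample(base_model, "enable", noise_seed=1, steps=20, cfg=8.0,
#                             sampler_name="euler", scheduler="normal",
#                             positive=pos, negative=neg, latent_image=latent,
#                             start_at_step=0, end_at_step=12,
#                             return_with_leftover_noise="enable")
#     (final,) = adv.sample(refiner_model, "disable", noise_seed=1, steps=20, cfg=8.0,
#                           sampler_name="euler", scheduler="normal",
#                           positive=pos_r, negative=neg_r, latent_image=halfway,
#                           start_at_step=12, end_at_step=20,
#                           return_with_leftover_noise="disable")
#
# base_model, refiner_model and the conditioning values are placeholders for
# loader/encoder outputs.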
class SaveImage:
    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()
        self.type = "output"
        self.prefix_append = ""
        self.compress_level = 4

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": ("IMAGE", {"tooltip": "The images to save."}),
                "filename_prefix": ("STRING", {"default": "ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."})
            },
            "hidden": {
                "prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"
            },
        }

    RETURN_TYPES = ()
    FUNCTION = "save_images"

    OUTPUT_NODE = True

    CATEGORY = "image"
    DESCRIPTION = "Saves the input images to your ComfyUI output directory."

    def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
        filename_prefix += self.prefix_append
        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
        results = list()
        for (batch_number, image) in enumerate(images):
            i = 255. * image.cpu().numpy()
            img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
            metadata = None
            if not args.disable_metadata:
                metadata = PngInfo()
                if prompt is not None:
                    metadata.add_text("prompt", json.dumps(prompt))
                if extra_pnginfo is not None:
                    for x in extra_pnginfo:
                        metadata.add_text(x, json.dumps(extra_pnginfo[x]))

            filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
            file = f"{filename_with_batch_num}_{counter:05}_.png"
            img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=self.compress_level)
            results.append({
                "filename": file,
                "subfolder": subfolder,
                "type": self.type
            })
            counter += 1

        return { "ui": { "images": results } }

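# Illustrative sketch (not part of the upstream file): because save_images embeds
# the prompt graph as a PNG text chunk, a saved image can be inspected later with
# plain Pillow, e.g.:
#
#     from PIL import Image
#     import json
#     with Image.open("output/ComfyUI_00001_.png") as im:   # example filename
#         workflow = json.loads(im.text["prompt"])  # KeyError if saved with --disable-metadata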
class PreviewImage(SaveImage):
    def __init__(self):
        self.output_dir = folder_paths.get_temp_directory()
        self.type = "temp"
        self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstuvwxyz") for x in range(5))
        self.compress_level = 1

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {"images": ("IMAGE", ), },
                "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
                }

class LoadImage:
    @classmethod
    def INPUT_TYPES(s):
        input_dir = folder_paths.get_input_directory()
        files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
        return {"required":
                    {"image": (sorted(files), {"image_upload": True})},
                }

    CATEGORY = "image"

    RETURN_TYPES = ("IMAGE", "MASK")
    FUNCTION = "load_image"
    def load_image(self, image):
        image_path = folder_paths.get_annotated_filepath(image)

        img = node_helpers.pillow(Image.open, image_path)

        output_images = []
        output_masks = []
        w, h = None, None

        excluded_formats = ['MPO']

        for i in ImageSequence.Iterator(img):
            i = node_helpers.pillow(ImageOps.exif_transpose, i)

            if i.mode == 'I':
                i = i.point(lambda i: i * (1 / 255))
            image = i.convert("RGB")

            if len(output_images) == 0:
                w = image.size[0]
                h = image.size[1]

            if image.size[0] != w or image.size[1] != h:
                continue

            image = np.array(image).astype(np.float32) / 255.0
            image = torch.from_numpy(image)[None,]
            if 'A' in i.getbands():
                mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
                mask = 1. - torch.from_numpy(mask)
            else:
                mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
            output_images.append(image)
            output_masks.append(mask.unsqueeze(0))

        if len(output_images) > 1 and img.format not in excluded_formats:
            output_image = torch.cat(output_images, dim=0)
            output_mask = torch.cat(output_masks, dim=0)
        else:
            output_image = output_images[0]
            output_mask = output_masks[0]

        return (output_image, output_mask)

    @classmethod
    def IS_CHANGED(s, image):
        image_path = folder_paths.get_annotated_filepath(image)
        m = hashlib.sha256()
        with open(image_path, 'rb') as f:
            m.update(f.read())
        return m.digest().hex()

    @classmethod
    def VALIDATE_INPUTS(s, image):
        if not folder_paths.exists_annotated_filepath(image):
            return "Invalid image file: {}".format(image)

        return True

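# Illustrative sketch (not part of the upstream file): MASK tensors use an
# inverted-alpha convention, which is why the alpha channel is flipped above --
# fully opaque pixels (alpha == 1.0) become mask == 0.0, and transparent pixels
# become 1.0, which downstream nodes such as SetLatentNoiseMask treat as the
# region to denoise:
#
#     alpha = torch.tensor([[1.0, 0.5, 0.0]])  # opaque, half, transparent
#     mask = 1. - alpha                        # -> [[0.0, 0.5, 1.0]]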
class LoadImageMask:
    _color_channels = ["alpha", "red", "green", "blue"]
    @classmethod
    def INPUT_TYPES(s):
        input_dir = folder_paths.get_input_directory()
        files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
        return {"required":
                    {"image": (sorted(files), {"image_upload": True}),
                     "channel": (s._color_channels, ), }
                }

    CATEGORY = "mask"

    RETURN_TYPES = ("MASK",)
    FUNCTION = "load_image"
    def load_image(self, image, channel):
        image_path = folder_paths.get_annotated_filepath(image)
        i = node_helpers.pillow(Image.open, image_path)
        i = node_helpers.pillow(ImageOps.exif_transpose, i)
        if i.getbands() != ("R", "G", "B", "A"):
            if i.mode == 'I':
                i = i.point(lambda i: i * (1 / 255))
            i = i.convert("RGBA")
        mask = None
        c = channel[0].upper()
        if c in i.getbands():
            mask = np.array(i.getchannel(c)).astype(np.float32) / 255.0
            mask = torch.from_numpy(mask)
            if c == 'A':
                mask = 1. - mask
        else:
            mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
        return (mask.unsqueeze(0),)

    @classmethod
    def IS_CHANGED(s, image, channel):
        image_path = folder_paths.get_annotated_filepath(image)
        m = hashlib.sha256()
        with open(image_path, 'rb') as f:
            m.update(f.read())
        return m.digest().hex()

    @classmethod
    def VALIDATE_INPUTS(s, image):
        if not folder_paths.exists_annotated_filepath(image):
            return "Invalid image file: {}".format(image)

        return True

class ImageScale:
    upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
    crop_methods = ["disabled", "center"]

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "image": ("IMAGE",), "upscale_method": (s.upscale_methods,),
                              "width": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                              "height": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                              "crop": (s.crop_methods,)}}
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "upscale"

    CATEGORY = "image/upscaling"

    def upscale(self, image, upscale_method, width, height, crop):
        if width == 0 and height == 0:
            s = image
        else:
            samples = image.movedim(-1,1)

            if width == 0:
                width = max(1, round(samples.shape[3] * height / samples.shape[2]))
            elif height == 0:
                height = max(1, round(samples.shape[2] * width / samples.shape[3]))

            s = comfy.utils.common_upscale(samples, width, height, upscale_method, crop)
            s = s.movedim(1,-1)
        return (s,)

class ImageScaleBy:
    upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "image": ("IMAGE",), "upscale_method": (s.upscale_methods,),
                              "scale_by": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 8.0, "step": 0.01}),}}
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "upscale"

    CATEGORY = "image/upscaling"

    def upscale(self, image, upscale_method, scale_by):
        samples = image.movedim(-1,1)
        width = round(samples.shape[3] * scale_by)
        height = round(samples.shape[2] * scale_by)
        s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
        s = s.movedim(1,-1)
        return (s,)

class ImageInvert:

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "image": ("IMAGE",)}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "invert"

    CATEGORY = "image"

    def invert(self, image):
        s = 1.0 - image
        return (s,)

class ImageBatch:

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "image1": ("IMAGE",), "image2": ("IMAGE",)}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "batch"

    CATEGORY = "image"

    def batch(self, image1, image2):
        if image1.shape[1:] != image2.shape[1:]:
            image2 = comfy.utils.common_upscale(image2.movedim(-1,1), image1.shape[2], image1.shape[1], "bilinear", "center").movedim(1,-1)
        s = torch.cat((image1, image2), dim=0)
        return (s,)

class EmptyImage:
    def __init__(self, device="cpu"):
        self.device = device

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
                              "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
                              "color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}),
                              }}
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "generate"

    CATEGORY = "image"

    def generate(self, width, height, batch_size=1, color=0):
        r = torch.full([batch_size, height, width, 1], ((color >> 16) & 0xFF) / 0xFF)
        g = torch.full([batch_size, height, width, 1], ((color >> 8) & 0xFF) / 0xFF)
        b = torch.full([batch_size, height, width, 1], ((color) & 0xFF) / 0xFF)
        return (torch.cat((r, g, b), dim=-1), )

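# Illustrative sketch (not part of the upstream file): generate() unpacks a
# packed 0xRRGGBB integer with shifts and masks. For color = 0x336699:
#
#     (0x336699 >> 16) & 0xFF == 0x33  -> r = 0x33 / 0xFF == 0.2
#     (0x336699 >> 8)  & 0xFF == 0x66  -> g = 0x66 / 0xFF == 0.4
#     (0x336699)       & 0xFF == 0x99  -> b = 0x99 / 0xFF == 0.6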
class ImagePadForOutpaint:

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
                "top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
                "right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
                "bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
                "feathering": ("INT", {"default": 40, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
            }
        }

    RETURN_TYPES = ("IMAGE", "MASK")
    FUNCTION = "expand_image"

    CATEGORY = "image"

    def expand_image(self, image, left, top, right, bottom, feathering):
        d1, d2, d3, d4 = image.size()

        new_image = torch.ones(
            (d1, d2 + top + bottom, d3 + left + right, d4),
            dtype=torch.float32,
        ) * 0.5

        new_image[:, top:top + d2, left:left + d3, :] = image

        mask = torch.ones(
            (d2 + top + bottom, d3 + left + right),
            dtype=torch.float32,
        )

        t = torch.zeros(
            (d2, d3),
            dtype=torch.float32
        )

        if feathering > 0 and feathering * 2 < d2 and feathering * 2 < d3:

            for i in range(d2):
                for j in range(d3):
                    dt = i if top != 0 else d2
                    db = d2 - i if bottom != 0 else d2

                    dl = j if left != 0 else d3
                    dr = d3 - j if right != 0 else d3

                    d = min(dt, db, dl, dr)

                    if d >= feathering:
                        continue

                    v = (feathering - d) / feathering

                    t[i, j] = v * v

        mask[top:top + d2, left:left + d3] = t

        return (new_image, mask)


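# Illustrative sketch (not part of the upstream file): the feather weight is a
# quadratic falloff over the distance d to the nearest padded edge, so with
# feathering = 40 a pixel 10 px from the edge gets
#
#     v = (40 - 10) / 40 = 0.75;  t = v * v = 0.5625
#
# i.e. still mostly masked, easing to 0.0 once d >= feathering.
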
NODE_CLASS_MAPPINGS = {
    "KSampler": KSampler,
    "CheckpointLoaderSimple": CheckpointLoaderSimple,
    "CLIPTextEncode": CLIPTextEncode,
    "CLIPSetLastLayer": CLIPSetLastLayer,
    "VAEDecode": VAEDecode,
    "VAEEncode": VAEEncode,
    "VAEEncodeForInpaint": VAEEncodeForInpaint,
    "VAELoader": VAELoader,
    "EmptyLatentImage": EmptyLatentImage,
    "LatentUpscale": LatentUpscale,
    "LatentUpscaleBy": LatentUpscaleBy,
    "LatentFromBatch": LatentFromBatch,
    "RepeatLatentBatch": RepeatLatentBatch,
    "SaveImage": SaveImage,
    "PreviewImage": PreviewImage,
    "LoadImage": LoadImage,
    "LoadImageMask": LoadImageMask,
    "ImageScale": ImageScale,
    "ImageScaleBy": ImageScaleBy,
    "ImageInvert": ImageInvert,
    "ImageBatch": ImageBatch,
    "ImagePadForOutpaint": ImagePadForOutpaint,
    "EmptyImage": EmptyImage,
    "ConditioningAverage": ConditioningAverage,
    "ConditioningCombine": ConditioningCombine,
    "ConditioningConcat": ConditioningConcat,
    "ConditioningSetArea": ConditioningSetArea,
    "ConditioningSetAreaPercentage": ConditioningSetAreaPercentage,
    "ConditioningSetAreaStrength": ConditioningSetAreaStrength,
    "ConditioningSetMask": ConditioningSetMask,
    "KSamplerAdvanced": KSamplerAdvanced,
    "SetLatentNoiseMask": SetLatentNoiseMask,
    "LatentComposite": LatentComposite,
    "LatentBlend": LatentBlend,
    "LatentRotate": LatentRotate,
    "LatentFlip": LatentFlip,
    "LatentCrop": LatentCrop,
    "LoraLoader": LoraLoader,
    "CLIPLoader": CLIPLoader,
    "UNETLoader": UNETLoader,
    "DualCLIPLoader": DualCLIPLoader,
    "CLIPVisionEncode": CLIPVisionEncode,
    "StyleModelApply": StyleModelApply,
    "unCLIPConditioning": unCLIPConditioning,
    "ControlNetApply": ControlNetApply,
    "ControlNetApplyAdvanced": ControlNetApplyAdvanced,
    "ControlNetLoader": ControlNetLoader,
    "DiffControlNetLoader": DiffControlNetLoader,
    "StyleModelLoader": StyleModelLoader,
    "CLIPVisionLoader": CLIPVisionLoader,
    "VAEDecodeTiled": VAEDecodeTiled,
    "VAEEncodeTiled": VAEEncodeTiled,
    "unCLIPCheckpointLoader": unCLIPCheckpointLoader,
    "GLIGENLoader": GLIGENLoader,
    "GLIGENTextBoxApply": GLIGENTextBoxApply,
    "InpaintModelConditioning": InpaintModelConditioning,

    "CheckpointLoader": CheckpointLoader,
    "DiffusersLoader": DiffusersLoader,

    "LoadLatent": LoadLatent,
    "SaveLatent": SaveLatent,

    "ConditioningZeroOut": ConditioningZeroOut,
    "ConditioningSetTimestepRange": ConditioningSetTimestepRange,
    "LoraLoaderModelOnly": LoraLoaderModelOnly,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    # Sampling
    "KSampler": "KSampler",
    "KSamplerAdvanced": "KSampler (Advanced)",
    # Loaders
    "CheckpointLoader": "Load Checkpoint With Config (DEPRECATED)",
    "CheckpointLoaderSimple": "Load Checkpoint",
    "VAELoader": "Load VAE",
    "LoraLoader": "Load LoRA",
    "CLIPLoader": "Load CLIP",
    "ControlNetLoader": "Load ControlNet Model",
    "DiffControlNetLoader": "Load ControlNet Model (diff)",
    "StyleModelLoader": "Load Style Model",
    "CLIPVisionLoader": "Load CLIP Vision",
    "UpscaleModelLoader": "Load Upscale Model",
    "UNETLoader": "Load Diffusion Model",
    # Conditioning
    "CLIPVisionEncode": "CLIP Vision Encode",
    "StyleModelApply": "Apply Style Model",
    "CLIPTextEncode": "CLIP Text Encode (Prompt)",
    "CLIPSetLastLayer": "CLIP Set Last Layer",
    "ConditioningCombine": "Conditioning (Combine)",
    "ConditioningAverage": "Conditioning (Average)",
    "ConditioningConcat": "Conditioning (Concat)",
    "ConditioningSetArea": "Conditioning (Set Area)",
    "ConditioningSetAreaPercentage": "Conditioning (Set Area with Percentage)",
    "ConditioningSetMask": "Conditioning (Set Mask)",
    "ControlNetApply": "Apply ControlNet (OLD)",
    "ControlNetApplyAdvanced": "Apply ControlNet",
    # Latent
    "VAEEncodeForInpaint": "VAE Encode (for Inpainting)",
    "SetLatentNoiseMask": "Set Latent Noise Mask",
    "VAEDecode": "VAE Decode",
    "VAEEncode": "VAE Encode",
    "LatentRotate": "Rotate Latent",
    "LatentFlip": "Flip Latent",
    "LatentCrop": "Crop Latent",
    "EmptyLatentImage": "Empty Latent Image",
    "LatentUpscale": "Upscale Latent",
    "LatentUpscaleBy": "Upscale Latent By",
    "LatentComposite": "Latent Composite",
    "LatentBlend": "Latent Blend",
    "LatentFromBatch": "Latent From Batch",
    "RepeatLatentBatch": "Repeat Latent Batch",
    # Image
    "SaveImage": "Save Image",
    "PreviewImage": "Preview Image",
    "LoadImage": "Load Image",
    "LoadImageMask": "Load Image (as Mask)",
    "ImageScale": "Upscale Image",
    "ImageScaleBy": "Upscale Image By",
    "ImageUpscaleWithModel": "Upscale Image (using Model)",
    "ImageInvert": "Invert Image",
    "ImagePadForOutpaint": "Pad Image for Outpainting",
    "ImageBatch": "Batch Images",
    "ImageCrop": "Image Crop",
    "ImageBlend": "Image Blend",
    "ImageBlur": "Image Blur",
    "ImageQuantize": "Image Quantize",
    "ImageSharpen": "Image Sharpen",
    "ImageScaleToTotalPixels": "Scale Image to Total Pixels",
    # _for_testing
    "VAEDecodeTiled": "VAE Decode (Tiled)",
    "VAEEncodeTiled": "VAE Encode (Tiled)",
}

EXTENSION_WEB_DIRS = {}

# Dictionary of successfully loaded module names and associated directories.
LOADED_MODULE_DIRS = {}


def get_module_name(module_path: str) -> str:
    """
    Returns the module name based on the given module path.
    Examples:
        get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node.py") -> "my_custom_node"
        get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node") -> "my_custom_node"
        get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node/") -> "my_custom_node"
        get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node/__init__.py") -> "my_custom_node"
        get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node/__init__") -> "my_custom_node"
        get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node/__init__/") -> "my_custom_node"
        get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node.disabled") -> "my_custom_node.disabled"
    Args:
        module_path (str): The path of the module.
    Returns:
        str: The module name.
    """
    base_path = os.path.basename(module_path)
    if os.path.isfile(module_path):
        base_path = os.path.splitext(base_path)[0]
    return base_path


def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes") -> bool:
    module_name = os.path.basename(module_path)
    if os.path.isfile(module_path):
        sp = os.path.splitext(module_path)
        module_name = sp[0]
    try:
        logging.debug("Trying to load custom node {}".format(module_path))
        if os.path.isfile(module_path):
            module_spec = importlib.util.spec_from_file_location(module_name, module_path)
            module_dir = os.path.split(module_path)[0]
        else:
            module_spec = importlib.util.spec_from_file_location(module_name, os.path.join(module_path, "__init__.py"))
            module_dir = module_path

        module = importlib.util.module_from_spec(module_spec)
        sys.modules[module_name] = module
        module_spec.loader.exec_module(module)

        LOADED_MODULE_DIRS[module_name] = os.path.abspath(module_dir)

        if hasattr(module, "WEB_DIRECTORY") and getattr(module, "WEB_DIRECTORY") is not None:
            web_dir = os.path.abspath(os.path.join(module_dir, getattr(module, "WEB_DIRECTORY")))
            if os.path.isdir(web_dir):
                EXTENSION_WEB_DIRS[module_name] = web_dir

        if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None:
            for name, node_cls in module.NODE_CLASS_MAPPINGS.items():
                if name not in ignore:
                    NODE_CLASS_MAPPINGS[name] = node_cls
                    node_cls.RELATIVE_PYTHON_MODULE = "{}.{}".format(module_parent, get_module_name(module_path))
            if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
                NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
            return True
        else:
            logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.")
            return False
    except Exception as e:
        logging.warning(traceback.format_exc())
        logging.warning(f"Cannot import {module_path} module for custom nodes: {e}")
        return False

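# Illustrative sketch (not part of the upstream file): the smallest package that
# load_custom_node() will accept is a directory (or a single .py file) whose
# module exposes NODE_CLASS_MAPPINGS, e.g. custom_nodes/my_custom_node/__init__.py:
#
#     class MyNode:
#         @classmethod
#         def INPUT_TYPES(s):
#             return {"required": {"text": ("STRING", {"default": ""})}}
#         RETURN_TYPES = ("STRING",)
#         FUNCTION = "run"
#         CATEGORY = "utils"
#         def run(self, text):
#             return (text.upper(),)
#
#     NODE_CLASS_MAPPINGS = {"MyNode": MyNode}
#     NODE_DISPLAY_NAME_MAPPINGS = {"MyNode": "My Node"}  # optional
#     WEB_DIRECTORY = "./js"                              # optional frontend assets
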
def init_external_custom_nodes():
    """
    Initializes the external custom nodes.

    This function loads custom nodes from the specified folder paths and imports them into the application.
    It measures the import times for each custom node and logs the results.

    Returns:
        None
    """
    base_node_names = set(NODE_CLASS_MAPPINGS.keys())
    node_paths = folder_paths.get_folder_paths("custom_nodes")
    node_import_times = []
    for custom_node_path in node_paths:
        possible_modules = os.listdir(os.path.realpath(custom_node_path))
        if "__pycache__" in possible_modules:
            possible_modules.remove("__pycache__")

        for possible_module in possible_modules:
            module_path = os.path.join(custom_node_path, possible_module)
            if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
            if module_path.endswith(".disabled"): continue
            time_before = time.perf_counter()
            success = load_custom_node(module_path, base_node_names, module_parent="custom_nodes")
            node_import_times.append((time.perf_counter() - time_before, module_path, success))

    if len(node_import_times) > 0:
        logging.info("\nImport times for custom nodes:")
        for n in sorted(node_import_times):
            if n[2]:
                import_message = ""
            else:
                import_message = " (IMPORT FAILED)"
            logging.info("{:6.1f} seconds{}: {}".format(n[0], import_message, n[1]))
        logging.info("")

def init_builtin_extra_nodes():
    """
    Initializes the built-in extra nodes in ComfyUI.

    This function loads the extra node files located in the "comfy_extras" directory and imports them into ComfyUI.
    If any of the extra node files fail to import, a warning message is logged.

    Returns:
        list: Names of the extra node files that failed to import.
    """
    extras_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras")
    extras_files = [
        "nodes_latent.py",
        "nodes_hypernetwork.py",
        "nodes_upscale_model.py",
        "nodes_post_processing.py",
        "nodes_mask.py",
        "nodes_compositing.py",
        "nodes_rebatch.py",
        "nodes_model_merging.py",
        "nodes_tomesd.py",
        "nodes_clip_sdxl.py",
        "nodes_canny.py",
        "nodes_freelunch.py",
        "nodes_custom_sampler.py",
        "nodes_hypertile.py",
        "nodes_model_advanced.py",
        "nodes_model_downscale.py",
        "nodes_images.py",
        "nodes_video_model.py",
        "nodes_sag.py",
        "nodes_perpneg.py",
        "nodes_stable3d.py",
        "nodes_sdupscale.py",
        "nodes_photomaker.py",
        "nodes_pixart.py",
        "nodes_cond.py",
        "nodes_morphology.py",
        "nodes_stable_cascade.py",
        "nodes_differential_diffusion.py",
        "nodes_ip2p.py",
        "nodes_model_merging_model_specific.py",
        "nodes_pag.py",
        "nodes_align_your_steps.py",
        "nodes_attention_multiply.py",
        "nodes_advanced_samplers.py",
        "nodes_webcam.py",
        "nodes_audio.py",
        "nodes_sd3.py",
        "nodes_gits.py",
        "nodes_controlnet.py",
        "nodes_hunyuan.py",
        "nodes_flux.py",
        "nodes_lora_extract.py",
        "nodes_torch_compile.py",
        "nodes_mochi.py",
        "nodes_slg.py",
        "nodes_mahiro.py",
        "nodes_lt.py",
        "nodes_hooks.py",
        "nodes_load_3d.py",
        "nodes_cosmos.py",
    ]

    import_failed = []
    for node_file in extras_files:
        if not load_custom_node(os.path.join(extras_dir, node_file), module_parent="comfy_extras"):
            import_failed.append(node_file)

    return import_failed


def init_extra_nodes(init_custom_nodes=True):
    import_failed = init_builtin_extra_nodes()

    if init_custom_nodes:
        init_external_custom_nodes()
    else:
        logging.info("Skipping loading of custom nodes")

    if len(import_failed) > 0:
        logging.warning("WARNING: some comfy_extras/ nodes did not import correctly. This may be because they are missing some dependencies.\n")
        for node in import_failed:
            logging.warning("IMPORT FAILED: {}".format(node))
        logging.warning("\nThis issue might be caused by new missing dependencies added the last time you updated ComfyUI.")
        if args.windows_standalone_build:
            logging.warning("Please run the update script: update/update_comfyui.bat")
        else:
            logging.warning("Please do a: pip install -r requirements.txt")
        logging.warning("")

    return import_failed

server.py ADDED
@@ -0,0 +1,863 @@
import os
import sys
import asyncio
import traceback

import nodes
import folder_paths
import execution
import uuid
import urllib
import json
import glob
import struct
import ssl
import socket
import ipaddress
from PIL import Image, ImageOps
from PIL.PngImagePlugin import PngInfo
from io import BytesIO

import aiohttp
from aiohttp import web
import logging

import mimetypes
from comfy.cli_args import args
import comfy.utils
import comfy.model_management
import node_helpers
from comfyui_version import __version__
from app.frontend_management import FrontendManager
from app.user_manager import UserManager
from app.model_manager import ModelFileManager
from app.custom_node_manager import CustomNodeManager
from typing import Optional
from api_server.routes.internal.internal_routes import InternalRoutes

class BinaryEventTypes:
    PREVIEW_IMAGE = 1
    UNENCODED_PREVIEW_IMAGE = 2

async def send_socket_catch_exception(function, message):
    try:
        await function(message)
    except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError, BrokenPipeError, ConnectionError) as err:
        logging.warning("send error: {}".format(err))

@web.middleware
async def cache_control(request: web.Request, handler):
    response: web.Response = await handler(request)
    if request.path.endswith('.js') or request.path.endswith('.css'):
        response.headers.setdefault('Cache-Control', 'no-cache')
    return response


@web.middleware
async def compress_body(request: web.Request, handler):
    accept_encoding = request.headers.get("Accept-Encoding", "")
    response: web.Response = await handler(request)
    if args.disable_compres_response_body:
        return response
    if not isinstance(response, web.Response):
        return response
    if response.content_type not in ["application/json", "text/plain"]:
        return response
    if response.body and "gzip" in accept_encoding:
        response.enable_compression()
    return response


def create_cors_middleware(allowed_origin: str):
    @web.middleware
    async def cors_middleware(request: web.Request, handler):
        if request.method == "OPTIONS":
            # Pre-flight request. Reply successfully:
            response = web.Response()
        else:
            response = await handler(request)

        response.headers['Access-Control-Allow-Origin'] = allowed_origin
        response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS'
        response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization'
        response.headers['Access-Control-Allow-Credentials'] = 'true'
        return response

    return cors_middleware

def is_loopback(host):
    if host is None:
        return False
    try:
        if ipaddress.ip_address(host).is_loopback:
            return True
        else:
            return False
    except ValueError:  # not a literal IP address; fall through to DNS resolution
        pass

    loopback = False
    for family in (socket.AF_INET, socket.AF_INET6):
        try:
            r = socket.getaddrinfo(host, None, family, socket.SOCK_STREAM)
            for family, _, _, _, sockaddr in r:
                if not ipaddress.ip_address(sockaddr[0]).is_loopback:
                    return loopback
                else:
                    loopback = True
        except socket.gaierror:
            pass

    return loopback

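# Illustrative sketch (not part of the upstream file): is_loopback accepts both
# literal addresses and names that resolve only to loopback interfaces:
#
#     is_loopback("127.0.0.1")    # True  (literal IPv4 loopback)
#     is_loopback("::1")          # True  (literal IPv6 loopback)
#     is_loopback("localhost")    # True on typical hosts-file setups
#     is_loopback("example.com")  # False (resolves to a public address)
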
def create_origin_only_middleware():
    @web.middleware
    async def origin_only_middleware(request: web.Request, handler):
        # This code prevents a random website from queueing comfy workflows by making a POST to
        # 127.0.0.1, which browsers do not block; in that case the Host and Origin hostnames won't match.
        # The proper fix would be a cookie, but this should take care of the problem in the meantime.
        if 'Host' in request.headers and 'Origin' in request.headers:
            host = request.headers['Host']
            origin = request.headers['Origin']
            host_domain = host.lower()
            parsed = urllib.parse.urlparse(origin)
            origin_domain = parsed.netloc.lower()
            host_domain_parsed = urllib.parse.urlsplit('//' + host_domain)

            # Limit the check to when the host domain is localhost; this makes it slightly less safe
            # but should still prevent the exploit.
            loopback = is_loopback(host_domain_parsed.hostname)

            if parsed.port is None: # if the origin doesn't have a port, strip it from the host to handle weird browsers; same for the host
                host_domain = host_domain_parsed.hostname
            if host_domain_parsed.port is None:
                origin_domain = parsed.hostname

            if loopback and host_domain is not None and origin_domain is not None and len(host_domain) > 0 and len(origin_domain) > 0:
                if host_domain != origin_domain:
                    logging.warning("WARNING: request with non matching host and origin {} != {}, returning 403".format(host_domain, origin_domain))
                    return web.Response(status=403)

        if request.method == "OPTIONS":
            response = web.Response()
        else:
            response = await handler(request)

        return response

    return origin_only_middleware

class PromptServer():
    def __init__(self, loop):
        PromptServer.instance = self

        mimetypes.init()
        mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8'

        self.user_manager = UserManager()
        self.model_file_manager = ModelFileManager()
        self.custom_node_manager = CustomNodeManager()
        self.internal_routes = InternalRoutes(self)
        self.supports = ["custom_nodes_from_web"]
        self.prompt_queue = None
        self.loop = loop
        self.messages = asyncio.Queue()
        self.client_session: Optional[aiohttp.ClientSession] = None
        self.number = 0

        middlewares = [cache_control, compress_body]
        if args.enable_cors_header:
            middlewares.append(create_cors_middleware(args.enable_cors_header))
        else:
            middlewares.append(create_origin_only_middleware())

        max_upload_size = round(args.max_upload_size * 1024 * 1024)
        self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
        self.sockets = dict()
        self.web_root = (
            FrontendManager.init_frontend(args.front_end_version)
            if args.front_end_root is None
            else args.front_end_root
        )
        logging.info(f"[Prompt Server] web root: {self.web_root}")
        routes = web.RouteTableDef()
        self.routes = routes
        self.last_node_id = None
        self.client_id = None

        self.on_prompt_handlers = []

        @routes.get('/ws')
        async def websocket_handler(request):
            ws = web.WebSocketResponse()
            await ws.prepare(request)
            sid = request.rel_url.query.get('clientId', '')
            if sid:
                # Reusing an existing session; remove the old socket
                self.sockets.pop(sid, None)
            else:
                sid = uuid.uuid4().hex

            self.sockets[sid] = ws

            try:
                # Send initial state to the new client
                await self.send("status", { "status": self.get_queue_info(), 'sid': sid }, sid)
                # On reconnect, if we are the currently executing client, send the current node
                if self.client_id == sid and self.last_node_id is not None:
                    await self.send("executing", { "node": self.last_node_id }, sid)

                async for msg in ws:
                    if msg.type == aiohttp.WSMsgType.ERROR:
                        logging.warning('ws connection closed with exception %s' % ws.exception())
            finally:
                self.sockets.pop(sid, None)
            return ws

        @routes.get("/")
        async def get_root(request):
            response = web.FileResponse(os.path.join(self.web_root, "index.html"))
            response.headers['Cache-Control'] = 'no-cache'
            response.headers["Pragma"] = "no-cache"
            response.headers["Expires"] = "0"
            return response

        @routes.get("/embeddings")
        def get_embeddings(request):
            embeddings = folder_paths.get_filename_list("embeddings")
            return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings)))

        @routes.get("/models")
        def list_model_types(request):
            model_types = list(folder_paths.folder_names_and_paths.keys())

            return web.json_response(model_types)

        @routes.get("/models/{folder}")
        async def get_models(request):
            folder = request.match_info.get("folder", None)
            if folder not in folder_paths.folder_names_and_paths:
                return web.Response(status=404)
            files = folder_paths.get_filename_list(folder)
            return web.json_response(files)

        @routes.get("/extensions")
        async def get_extensions(request):
            files = glob.glob(os.path.join(
                glob.escape(self.web_root), 'extensions/**/*.js'), recursive=True)

            extensions = list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files))

            for name, dir in nodes.EXTENSION_WEB_DIRS.items():
                files = glob.glob(os.path.join(glob.escape(dir), '**/*.js'), recursive=True)
                extensions.extend(list(map(lambda f: "/extensions/" + urllib.parse.quote(
                    name) + "/" + os.path.relpath(f, dir).replace("\\", "/"), files)))

            return web.json_response(extensions)

        def get_dir_by_type(dir_type):
            if dir_type is None:
                dir_type = "input"

            if dir_type == "input":
                type_dir = folder_paths.get_input_directory()
            elif dir_type == "temp":
                type_dir = folder_paths.get_temp_directory()
            elif dir_type == "output":
                type_dir = folder_paths.get_output_directory()

            return type_dir, dir_type

        def compare_image_hash(filepath, image):
            hasher = node_helpers.hasher()

            # compare the hashes of the two images to see if the upload already exists, fix for #3465
            if os.path.exists(filepath):
                a = hasher()
                b = hasher()
                with open(filepath, "rb") as f:
                    a.update(f.read())
                    b.update(image.file.read())
                    image.file.seek(0)
                return a.hexdigest() == b.hexdigest()
            return False

        def image_upload(post, image_save_function=None):
            image = post.get("image")
            overwrite = post.get("overwrite")
            image_is_duplicate = False

            image_upload_type = post.get("type")
            upload_dir, image_upload_type = get_dir_by_type(image_upload_type)

            if image and image.file:
                filename = image.filename
                if not filename:
                    return web.Response(status=400)

                subfolder = post.get("subfolder", "")
                full_output_folder = os.path.join(upload_dir, os.path.normpath(subfolder))
                filepath = os.path.abspath(os.path.join(full_output_folder, filename))

                if os.path.commonpath((upload_dir, filepath)) != upload_dir:
                    return web.Response(status=400)

                if not os.path.exists(full_output_folder):
                    os.makedirs(full_output_folder)

                split = os.path.splitext(filename)

                if overwrite is not None and (overwrite == "true" or overwrite == "1"):
                    pass
                else:
                    i = 1
                    while os.path.exists(filepath):
                        if compare_image_hash(filepath, image): # compare the hash to prevent saving duplicates with the same name, fix for #3465
                            image_is_duplicate = True
                            break
                        filename = f"{split[0]} ({i}){split[1]}"
                        filepath = os.path.join(full_output_folder, filename)
                        i += 1

                if not image_is_duplicate:
                    if image_save_function is not None:
                        image_save_function(image, post, filepath)
                    else:
                        with open(filepath, "wb") as f:
                            f.write(image.file.read())

                return web.json_response({"name" : filename, "subfolder": subfolder, "type": image_upload_type})
            else:
                return web.Response(status=400)

        @routes.post("/upload/image")
        async def upload_image(request):
            post = await request.post()
            return image_upload(post)


        @routes.post("/upload/mask")
        async def upload_mask(request):
            post = await request.post()

            def image_save_function(image, post, filepath):
                original_ref = json.loads(post.get("original_ref"))
                filename, output_dir = folder_paths.annotated_filepath(original_ref['filename'])

                if not filename:
                    return web.Response(status=400)

                # validation for security: prevent accessing an arbitrary path
                if filename[0] == '/' or '..' in filename:
                    return web.Response(status=400)

                if output_dir is None:
                    type = original_ref.get("type", "output")
                    output_dir = folder_paths.get_directory_by_type(type)

                if output_dir is None:
                    return web.Response(status=400)

                if original_ref.get("subfolder", "") != "":
                    full_output_dir = os.path.join(output_dir, original_ref["subfolder"])
                    if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
                        return web.Response(status=403)
                    output_dir = full_output_dir

                file = os.path.join(output_dir, filename)

                if os.path.isfile(file):
                    with Image.open(file) as original_pil:
                        metadata = PngInfo()
                        if hasattr(original_pil,'text'):
                            for key in original_pil.text:
                                metadata.add_text(key, original_pil.text[key])
                        original_pil = original_pil.convert('RGBA')
                        mask_pil = Image.open(image.file).convert('RGBA')

                        # copy the alpha channel of the uploaded mask onto the original image
                        new_alpha = mask_pil.getchannel('A')
                        original_pil.putalpha(new_alpha)
                        original_pil.save(filepath, compress_level=4, pnginfo=metadata)

            return image_upload(post, image_save_function)

        @routes.get("/view")
        async def view_image(request):
            if "filename" in request.rel_url.query:
                filename = request.rel_url.query["filename"]
                filename, output_dir = folder_paths.annotated_filepath(filename)

                if not filename:
                    return web.Response(status=400)

                # validation for security: prevent accessing an arbitrary path
                if filename[0] == '/' or '..' in filename:
                    return web.Response(status=400)

                if output_dir is None:
                    type = request.rel_url.query.get("type", "output")
                    output_dir = folder_paths.get_directory_by_type(type)

                if output_dir is None:
                    return web.Response(status=400)

                if "subfolder" in request.rel_url.query:
                    full_output_dir = os.path.join(output_dir, request.rel_url.query["subfolder"])
                    if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
                        return web.Response(status=403)
                    output_dir = full_output_dir

                filename = os.path.basename(filename)
                file = os.path.join(output_dir, filename)

                if os.path.isfile(file):
                    if 'preview' in request.rel_url.query:
                        with Image.open(file) as img:
                            preview_info = request.rel_url.query['preview'].split(';')
                            image_format = preview_info[0]
                            if image_format not in ['webp', 'jpeg'] or 'a' in request.rel_url.query.get('channel', ''):
                                image_format = 'webp'

                            quality = 90
                            if preview_info[-1].isdigit():
                                quality = int(preview_info[-1])

                            buffer = BytesIO()
                            if image_format in ['jpeg'] or request.rel_url.query.get('channel', '') == 'rgb':
                                img = img.convert("RGB")
                            img.save(buffer, format=image_format, quality=quality)
                            buffer.seek(0)

                            return web.Response(body=buffer.read(), content_type=f'image/{image_format}',
                                                headers={"Content-Disposition": f"filename=\"{filename}\""})

                    if 'channel' not in request.rel_url.query:
                        channel = 'rgba'
                    else:
                        channel = request.rel_url.query["channel"]

                    if channel == 'rgb':
                        with Image.open(file) as img:
                            if img.mode == "RGBA":
                                r, g, b, a = img.split()
                                new_img = Image.merge('RGB', (r, g, b))
                            else:
                                new_img = img.convert("RGB")

                            buffer = BytesIO()
                            new_img.save(buffer, format='PNG')
                            buffer.seek(0)

                            return web.Response(body=buffer.read(), content_type='image/png',
                                                headers={"Content-Disposition": f"filename=\"{filename}\""})

                    elif channel == 'a':
                        with Image.open(file) as img:
                            if img.mode == "RGBA":
                                _, _, _, a = img.split()
                            else:
                                a = Image.new('L', img.size, 255)

                            # alpha img
                            alpha_img = Image.new('RGBA', img.size)
                            alpha_img.putalpha(a)
                            alpha_buffer = BytesIO()
                            alpha_img.save(alpha_buffer, format='PNG')
                            alpha_buffer.seek(0)

                            return web.Response(body=alpha_buffer.read(), content_type='image/png',
                                                headers={"Content-Disposition": f"filename=\"{filename}\""})
                    else:
                        # Get the content type from the mimetype, defaulting to 'application/octet-stream'
                        content_type = mimetypes.guess_type(filename)[0] or 'application/octet-stream'

                        # For security, force certain extensions to download instead of display
                        file_extension = os.path.splitext(filename)[1].lower()
                        if file_extension in {'.html', '.htm', '.js', '.css'}:
                            content_type = 'application/octet-stream'  # Forces download

                        return web.FileResponse(
                            file,
                            headers={
                                "Content-Disposition": f"filename=\"{filename}\"",
                                "Content-Type": content_type
                            }
                        )

            return web.Response(status=404)

        @routes.get("/view_metadata/{folder_name}")
        async def view_metadata(request):
            folder_name = request.match_info.get("folder_name", None)
            if folder_name is None:
                return web.Response(status=404)
            if "filename" not in request.rel_url.query:
                return web.Response(status=404)

            filename = request.rel_url.query["filename"]
            if not filename.endswith(".safetensors"):
                return web.Response(status=404)

            safetensors_path = folder_paths.get_full_path(folder_name, filename)
            if safetensors_path is None:
                return web.Response(status=404)
            out = comfy.utils.safetensors_header(safetensors_path, max_size=1024*1024)
            if out is None:
                return web.Response(status=404)
            dt = json.loads(out)
            if "__metadata__" not in dt:
                return web.Response(status=404)
            return web.json_response(dt["__metadata__"])

        @routes.get("/system_stats")
        async def system_stats(request):
            device = comfy.model_management.get_torch_device()
            device_name = comfy.model_management.get_torch_device_name(device)
            cpu_device = comfy.model_management.torch.device("cpu")
            ram_total = comfy.model_management.get_total_memory(cpu_device)
            ram_free = comfy.model_management.get_free_memory(cpu_device)
            vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
            vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)

            system_stats = {
                "system": {
                    "os": os.name,
                    "ram_total": ram_total,
                    "ram_free": ram_free,
                    "comfyui_version": __version__,
                    "python_version": sys.version,
                    "pytorch_version": comfy.model_management.torch_version,
                    "embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded",
                    "argv": sys.argv
                },
                "devices": [
                    {
                        "name": device_name,
                        "type": device.type,
                        "index": device.index,
                        "vram_total": vram_total,
                        "vram_free": vram_free,
                        "torch_vram_total": torch_vram_total,
                        "torch_vram_free": torch_vram_free,
                    }
                ]
            }
            return web.json_response(system_stats)

        @routes.get("/prompt")
        async def get_prompt(request):
            return web.json_response(self.get_queue_info())

        def node_info(node_class):
            obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
            info = {}
            info['input'] = obj_class.INPUT_TYPES()
            info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()}
            info['output'] = obj_class.RETURN_TYPES
            info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
            info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
            info['name'] = node_class
            info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class
            info['description'] = obj_class.DESCRIPTION if hasattr(obj_class,'DESCRIPTION') else ''
            info['python_module'] = getattr(obj_class, "RELATIVE_PYTHON_MODULE", "nodes")
            info['category'] = 'sd'
            if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True:
                info['output_node'] = True
            else:
                info['output_node'] = False

            if hasattr(obj_class, 'CATEGORY'):
                info['category'] = obj_class.CATEGORY

            if hasattr(obj_class, 'OUTPUT_TOOLTIPS'):
                info['output_tooltips'] = obj_class.OUTPUT_TOOLTIPS

            if getattr(obj_class, "DEPRECATED", False):
                info['deprecated'] = True
            if getattr(obj_class, "EXPERIMENTAL", False):
                info['experimental'] = True
            return info

        @routes.get("/object_info")
        async def get_object_info(request):
            with folder_paths.cache_helper:
                out = {}
                for x in nodes.NODE_CLASS_MAPPINGS:
                    try:
                        out[x] = node_info(x)
                    except Exception:
                        logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
                        logging.error(traceback.format_exc())
                return web.json_response(out)

        @routes.get("/object_info/{node_class}")
        async def get_object_info_node(request):
            node_class = request.match_info.get("node_class", None)
            out = {}
            if (node_class is not None) and (node_class in nodes.NODE_CLASS_MAPPINGS):
                out[node_class] = node_info(node_class)
            return web.json_response(out)

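        # Illustrative sketch (not part of the upstream file): for the built-in
        # ImageInvert node, /object_info/ImageInvert returns roughly (tuples
        # serialize as JSON arrays):
        #
        #     {"ImageInvert": {
        #         "input": {"required": {"image": ["IMAGE"]}},
        #         "input_order": {"required": ["image"]},
        #         "output": ["IMAGE"], "output_is_list": [false],
        #         "output_name": ["IMAGE"], "name": "ImageInvert",
        #         "display_name": "Invert Image", "description": "",
        #         "python_module": "nodes", "category": "image",
        #         "output_node": false}}
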
        @routes.get("/history")
        async def get_history(request):
            max_items = request.rel_url.query.get("max_items", None)
            if max_items is not None:
                max_items = int(max_items)
            return web.json_response(self.prompt_queue.get_history(max_items=max_items))

        @routes.get("/history/{prompt_id}")
        async def get_history_prompt_id(request):
            prompt_id = request.match_info.get("prompt_id", None)
            return web.json_response(self.prompt_queue.get_history(prompt_id=prompt_id))

        @routes.get("/queue")
        async def get_queue(request):
            queue_info = {}
            current_queue = self.prompt_queue.get_current_queue()
            queue_info['queue_running'] = current_queue[0]
            queue_info['queue_pending'] = current_queue[1]
            return web.json_response(queue_info)

        @routes.post("/prompt")
        async def post_prompt(request):
            logging.info("got prompt")
            json_data = await request.json()
            json_data = self.trigger_on_prompt(json_data)

            if "number" in json_data:
                number = float(json_data['number'])
            else:
                number = self.number
                if "front" in json_data:
                    if json_data['front']:
                        number = -number

                self.number += 1

            if "prompt" in json_data:
                prompt = json_data["prompt"]
                valid = execution.validate_prompt(prompt)
                extra_data = {}
                if "extra_data" in json_data:
                    extra_data = json_data["extra_data"]

                if "client_id" in json_data:
                    extra_data["client_id"] = json_data["client_id"]
                if valid[0]:
                    prompt_id = str(uuid.uuid4())
                    outputs_to_execute = valid[2]
                    self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
                    response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
                    return web.json_response(response)
                else:
                    logging.warning("invalid prompt: {}".format(valid[1]))
                    return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400)
            else:
                return web.json_response({"error": "no prompt", "node_errors": []}, status=400)

        @routes.post("/queue")
        async def post_queue(request):
            json_data = await request.json()
            if "clear" in json_data:
                if json_data["clear"]:
                    self.prompt_queue.wipe_queue()
            if "delete" in json_data:
                to_delete = json_data['delete']
                for id_to_delete in to_delete:
                    delete_func = lambda a: a[1] == id_to_delete
                    self.prompt_queue.delete_queue_item(delete_func)

            return web.Response(status=200)

        @routes.post("/interrupt")
        async def post_interrupt(request):
            nodes.interrupt_processing()
            return web.Response(status=200)

        @routes.post("/free")
        async def post_free(request):
            json_data = await request.json()
            unload_models = json_data.get("unload_models", False)
            free_memory = json_data.get("free_memory", False)
            if unload_models:
                self.prompt_queue.set_flag("unload_models", unload_models)
            if free_memory:
                self.prompt_queue.set_flag("free_memory", free_memory)
            return web.Response(status=200)

        @routes.post("/history")
        async def post_history(request):
            json_data = await request.json()
            if "clear" in json_data:
                if json_data["clear"]:
                    self.prompt_queue.wipe_history()
            if "delete" in json_data:
                to_delete = json_data['delete']
                for id_to_delete in to_delete:
                    self.prompt_queue.delete_history_item(id_to_delete)

            return web.Response(status=200)

    async def setup(self):
        timeout = aiohttp.ClientTimeout(total=None) # no timeout
        self.client_session = aiohttp.ClientSession(timeout=timeout)

    def add_routes(self):
        self.user_manager.add_routes(self.routes)
        self.model_file_manager.add_routes(self.routes)
        self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items())
        self.app.add_subapp('/internal', self.internal_routes.get_app())

        # Prefix every route with /api for easier matching for delegation.
        # This is very useful for a frontend dev server, which needs to forward
        # everything except the serving of static files.
        # Currently both the old endpoints without the prefix and the new
        # endpoints with the prefix are supported.
        api_routes = web.RouteTableDef()
        for route in self.routes:
            # Custom nodes might add extra static routes. Only process non-static
            # routes to add the /api prefix.
            if isinstance(route, web.RouteDef):
                api_routes.route(route.method, "/api" + route.path)(route.handler, **route.kwargs)
        self.app.add_routes(api_routes)
        self.app.add_routes(self.routes)

        # Add routes from web extensions.
        for name, dir in nodes.EXTENSION_WEB_DIRS.items():
            self.app.add_routes([web.static('/extensions/' + name, dir)])

        self.app.add_routes([
            web.static('/', self.web_root),
        ])

    def get_queue_info(self):
        prompt_info = {}
        exec_info = {}
        exec_info['queue_remaining'] = self.prompt_queue.get_tasks_remaining()
        prompt_info['exec_info'] = exec_info
        return prompt_info

    async def send(self, event, data, sid=None):
        if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
            await self.send_image(data, sid=sid)
        elif isinstance(data, (bytes, bytearray)):
            await self.send_bytes(event, data, sid)
        else:
            await self.send_json(event, data, sid)

    def encode_bytes(self, event, data):
        if not isinstance(event, int):
            raise RuntimeError(f"Binary event types must be integers, got {event}")

        packed = struct.pack(">I", event)
        message = bytearray(packed)
        message.extend(data)
        return message

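    # Illustrative sketch (not part of the upstream file): the binary websocket
    # frame produced by encode_bytes is a 4-byte big-endian event id followed by
    # the payload; a client can split it like this:
    #
    #     import struct
    #     event = struct.unpack(">I", frame[:4])[0]  # e.g. 1 == PREVIEW_IMAGE
    #     payload = frame[4:]                        # for previews: another 4-byte
    #                                                # format id (1=JPEG, 2=PNG),
    #                                                # then the encoded image bytes
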
    async def send_image(self, image_data, sid=None):
        image_type = image_data[0]
        image = image_data[1]
        max_size = image_data[2]
        if max_size is not None:
            if hasattr(Image, 'Resampling'):
                resampling = Image.Resampling.BILINEAR
            else:
                resampling = Image.ANTIALIAS

            image = ImageOps.contain(image, (max_size, max_size), resampling)
        type_num = 1
        if image_type == "JPEG":
            type_num = 1
        elif image_type == "PNG":
            type_num = 2

        bytesIO = BytesIO()
        header = struct.pack(">I", type_num)
        bytesIO.write(header)
        image.save(bytesIO, format=image_type, quality=95, compress_level=1)
        preview_bytes = bytesIO.getvalue()
        await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)

    async def send_bytes(self, event, data, sid=None):
        message = self.encode_bytes(event, data)

        if sid is None:
            sockets = list(self.sockets.values())
            for ws in sockets:
                await send_socket_catch_exception(ws.send_bytes, message)
        elif sid in self.sockets:
            await send_socket_catch_exception(self.sockets[sid].send_bytes, message)

    async def send_json(self, event, data, sid=None):
        message = {"type": event, "data": data}

        if sid is None:
            sockets = list(self.sockets.values())
            for ws in sockets:
                await send_socket_catch_exception(ws.send_json, message)
        elif sid in self.sockets:
            await send_socket_catch_exception(self.sockets[sid].send_json, message)

    def send_sync(self, event, data, sid=None):
        self.loop.call_soon_threadsafe(
            self.messages.put_nowait, (event, data, sid))

    def queue_updated(self):
        self.send_sync("status", { "status": self.get_queue_info() })

    async def publish_loop(self):
        while True:
            msg = await self.messages.get()
            await self.send(*msg)

    async def start(self, address, port, verbose=True, call_on_start=None):
        await self.start_multi_address([(address, port)], call_on_start=call_on_start)

    async def start_multi_address(self, addresses, call_on_start=None, verbose=True):
        runner = web.AppRunner(self.app, access_log=None)
        await runner.setup()
        ssl_ctx = None
        scheme = "http"
        if args.tls_keyfile and args.tls_certfile:
            ssl_ctx = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_SERVER, verify_mode=ssl.CERT_NONE)
            ssl_ctx.load_cert_chain(certfile=args.tls_certfile,
                                    keyfile=args.tls_keyfile)
            scheme = "https"

        if verbose:
            logging.info("Starting server\n")
        for addr in addresses:
            address = addr[0]
            port = addr[1]
            site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx)
            await site.start()

            if not hasattr(self, 'address'):
                self.address = address #TODO: remove this
                self.port = port

            if ':' in address:
                address_print = "[{}]".format(address)
            else:
                address_print = address

            if verbose:
                logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address_print, port))

        if call_on_start is not None:
            call_on_start(scheme, self.address, self.port)

    def add_on_prompt_handler(self, handler):
        self.on_prompt_handlers.append(handler)

    def trigger_on_prompt(self, json_data):
        for handler in self.on_prompt_handlers:
            try:
                json_data = handler(json_data)
            except Exception:
                logging.warning("[ERROR] An error occurred during the on_prompt_handler processing")
                logging.warning(traceback.format_exc())

        return json_data
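

# Illustrative sketch (not part of the upstream files): a minimal client queues a
# workflow over HTTP and then watches progress over the websocket. The prompt
# body is the node-id -> {class_type, inputs} graph that execution.validate_prompt
# checks in post_prompt above; workflow_graph and the default port 8188 are
# assumptions for the example:
#
#     import json, uuid, urllib.request
#     client_id = uuid.uuid4().hex
#     body = json.dumps({"prompt": workflow_graph, "client_id": client_id}).encode()
#     req = urllib.request.Request("http://127.0.0.1:8188/prompt", data=body,
#                                  headers={"Content-Type": "application/json"})
#     prompt_id = json.loads(urllib.request.urlopen(req).read())["prompt_id"]
#     # then listen on ws://127.0.0.1:8188/ws?clientId=<client_id> for
#     # "status"/"executing" JSON events carrying that prompt_id.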