plaidam commited on
Commit
1f785e4
·
verified ·
1 Parent(s): be40477

Upload 44 files

Browse files
Files changed (44) hide show
  1. LICENSE +674 -0
  2. app/__init__.py +0 -0
  3. app/app_settings.py +59 -0
  4. app/custom_node_manager.py +34 -0
  5. app/frontend_management.py +204 -0
  6. app/logger.py +84 -0
  7. app/model_manager.py +184 -0
  8. app/user_manager.py +330 -0
  9. execution.py +994 -0
  10. nodes.py +2258 -0
  11. server.py +847 -0
  12. tests-unit/README.md +8 -0
  13. tests-unit/app_test/__init__.py +0 -0
  14. tests-unit/app_test/custom_node_manager_test.py +40 -0
  15. tests-unit/app_test/frontend_manager_test.py +130 -0
  16. tests-unit/app_test/model_manager_test.py +62 -0
  17. tests-unit/comfy_test/folder_path_test.py +98 -0
  18. tests-unit/execution_test/validate_node_input_test.py +119 -0
  19. tests-unit/folder_paths_test/__init__.py +0 -0
  20. tests-unit/folder_paths_test/filter_by_content_types_test.py +52 -0
  21. tests-unit/prompt_server_test/__init__.py +0 -0
  22. tests-unit/prompt_server_test/user_manager_test.py +231 -0
  23. tests-unit/requirements.txt +3 -0
  24. tests-unit/server/routes/internal_routes_test.py +115 -0
  25. tests-unit/server/services/file_service_test.py +54 -0
  26. tests-unit/server/utils/file_operations_test.py +42 -0
  27. tests-unit/utils/extra_config_test.py +303 -0
  28. tests/README.md +29 -0
  29. tests/__init__.py +0 -0
  30. tests/compare/conftest.py +41 -0
  31. tests/compare/test_quality.py +195 -0
  32. tests/conftest.py +36 -0
  33. tests/inference/__init__.py +0 -0
  34. tests/inference/extra_model_paths.yaml +4 -0
  35. tests/inference/graphs/default_graph_sdxl1_0.json +144 -0
  36. tests/inference/test_execution.py +524 -0
  37. tests/inference/test_inference.py +237 -0
  38. tests/inference/testing_nodes/testing-pack/__init__.py +23 -0
  39. tests/inference/testing_nodes/testing-pack/conditions.py +194 -0
  40. tests/inference/testing_nodes/testing-pack/flow_control.py +173 -0
  41. tests/inference/testing_nodes/testing-pack/specific_tests.py +362 -0
  42. tests/inference/testing_nodes/testing-pack/stubs.py +129 -0
  43. tests/inference/testing_nodes/testing-pack/tools.py +53 -0
  44. tests/inference/testing_nodes/testing-pack/util.py +364 -0
LICENSE ADDED
@@ -0,0 +1,674 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ GNU GENERAL PUBLIC LICENSE
2
+ Version 3, 29 June 2007
3
+
4
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5
+ Everyone is permitted to copy and distribute verbatim copies
6
+ of this license document, but changing it is not allowed.
7
+
8
+ Preamble
9
+
10
+ The GNU General Public License is a free, copyleft license for
11
+ software and other kinds of works.
12
+
13
+ The licenses for most software and other practical works are designed
14
+ to take away your freedom to share and change the works. By contrast,
15
+ the GNU General Public License is intended to guarantee your freedom to
16
+ share and change all versions of a program--to make sure it remains free
17
+ software for all its users. We, the Free Software Foundation, use the
18
+ GNU General Public License for most of our software; it applies also to
19
+ any other work released this way by its authors. You can apply it to
20
+ your programs, too.
21
+
22
+ When we speak of free software, we are referring to freedom, not
23
+ price. Our General Public Licenses are designed to make sure that you
24
+ have the freedom to distribute copies of free software (and charge for
25
+ them if you wish), that you receive source code or can get it if you
26
+ want it, that you can change the software or use pieces of it in new
27
+ free programs, and that you know you can do these things.
28
+
29
+ To protect your rights, we need to prevent others from denying you
30
+ these rights or asking you to surrender the rights. Therefore, you have
31
+ certain responsibilities if you distribute copies of the software, or if
32
+ you modify it: responsibilities to respect the freedom of others.
33
+
34
+ For example, if you distribute copies of such a program, whether
35
+ gratis or for a fee, you must pass on to the recipients the same
36
+ freedoms that you received. You must make sure that they, too, receive
37
+ or can get the source code. And you must show them these terms so they
38
+ know their rights.
39
+
40
+ Developers that use the GNU GPL protect your rights with two steps:
41
+ (1) assert copyright on the software, and (2) offer you this License
42
+ giving you legal permission to copy, distribute and/or modify it.
43
+
44
+ For the developers' and authors' protection, the GPL clearly explains
45
+ that there is no warranty for this free software. For both users' and
46
+ authors' sake, the GPL requires that modified versions be marked as
47
+ changed, so that their problems will not be attributed erroneously to
48
+ authors of previous versions.
49
+
50
+ Some devices are designed to deny users access to install or run
51
+ modified versions of the software inside them, although the manufacturer
52
+ can do so. This is fundamentally incompatible with the aim of
53
+ protecting users' freedom to change the software. The systematic
54
+ pattern of such abuse occurs in the area of products for individuals to
55
+ use, which is precisely where it is most unacceptable. Therefore, we
56
+ have designed this version of the GPL to prohibit the practice for those
57
+ products. If such problems arise substantially in other domains, we
58
+ stand ready to extend this provision to those domains in future versions
59
+ of the GPL, as needed to protect the freedom of users.
60
+
61
+ Finally, every program is threatened constantly by software patents.
62
+ States should not allow patents to restrict development and use of
63
+ software on general-purpose computers, but in those that do, we wish to
64
+ avoid the special danger that patents applied to a free program could
65
+ make it effectively proprietary. To prevent this, the GPL assures that
66
+ patents cannot be used to render the program non-free.
67
+
68
+ The precise terms and conditions for copying, distribution and
69
+ modification follow.
70
+
71
+ TERMS AND CONDITIONS
72
+
73
+ 0. Definitions.
74
+
75
+ "This License" refers to version 3 of the GNU General Public License.
76
+
77
+ "Copyright" also means copyright-like laws that apply to other kinds of
78
+ works, such as semiconductor masks.
79
+
80
+ "The Program" refers to any copyrightable work licensed under this
81
+ License. Each licensee is addressed as "you". "Licensees" and
82
+ "recipients" may be individuals or organizations.
83
+
84
+ To "modify" a work means to copy from or adapt all or part of the work
85
+ in a fashion requiring copyright permission, other than the making of an
86
+ exact copy. The resulting work is called a "modified version" of the
87
+ earlier work or a work "based on" the earlier work.
88
+
89
+ A "covered work" means either the unmodified Program or a work based
90
+ on the Program.
91
+
92
+ To "propagate" a work means to do anything with it that, without
93
+ permission, would make you directly or secondarily liable for
94
+ infringement under applicable copyright law, except executing it on a
95
+ computer or modifying a private copy. Propagation includes copying,
96
+ distribution (with or without modification), making available to the
97
+ public, and in some countries other activities as well.
98
+
99
+ To "convey" a work means any kind of propagation that enables other
100
+ parties to make or receive copies. Mere interaction with a user through
101
+ a computer network, with no transfer of a copy, is not conveying.
102
+
103
+ An interactive user interface displays "Appropriate Legal Notices"
104
+ to the extent that it includes a convenient and prominently visible
105
+ feature that (1) displays an appropriate copyright notice, and (2)
106
+ tells the user that there is no warranty for the work (except to the
107
+ extent that warranties are provided), that licensees may convey the
108
+ work under this License, and how to view a copy of this License. If
109
+ the interface presents a list of user commands or options, such as a
110
+ menu, a prominent item in the list meets this criterion.
111
+
112
+ 1. Source Code.
113
+
114
+ The "source code" for a work means the preferred form of the work
115
+ for making modifications to it. "Object code" means any non-source
116
+ form of a work.
117
+
118
+ A "Standard Interface" means an interface that either is an official
119
+ standard defined by a recognized standards body, or, in the case of
120
+ interfaces specified for a particular programming language, one that
121
+ is widely used among developers working in that language.
122
+
123
+ The "System Libraries" of an executable work include anything, other
124
+ than the work as a whole, that (a) is included in the normal form of
125
+ packaging a Major Component, but which is not part of that Major
126
+ Component, and (b) serves only to enable use of the work with that
127
+ Major Component, or to implement a Standard Interface for which an
128
+ implementation is available to the public in source code form. A
129
+ "Major Component", in this context, means a major essential component
130
+ (kernel, window system, and so on) of the specific operating system
131
+ (if any) on which the executable work runs, or a compiler used to
132
+ produce the work, or an object code interpreter used to run it.
133
+
134
+ The "Corresponding Source" for a work in object code form means all
135
+ the source code needed to generate, install, and (for an executable
136
+ work) run the object code and to modify the work, including scripts to
137
+ control those activities. However, it does not include the work's
138
+ System Libraries, or general-purpose tools or generally available free
139
+ programs which are used unmodified in performing those activities but
140
+ which are not part of the work. For example, Corresponding Source
141
+ includes interface definition files associated with source files for
142
+ the work, and the source code for shared libraries and dynamically
143
+ linked subprograms that the work is specifically designed to require,
144
+ such as by intimate data communication or control flow between those
145
+ subprograms and other parts of the work.
146
+
147
+ The Corresponding Source need not include anything that users
148
+ can regenerate automatically from other parts of the Corresponding
149
+ Source.
150
+
151
+ The Corresponding Source for a work in source code form is that
152
+ same work.
153
+
154
+ 2. Basic Permissions.
155
+
156
+ All rights granted under this License are granted for the term of
157
+ copyright on the Program, and are irrevocable provided the stated
158
+ conditions are met. This License explicitly affirms your unlimited
159
+ permission to run the unmodified Program. The output from running a
160
+ covered work is covered by this License only if the output, given its
161
+ content, constitutes a covered work. This License acknowledges your
162
+ rights of fair use or other equivalent, as provided by copyright law.
163
+
164
+ You may make, run and propagate covered works that you do not
165
+ convey, without conditions so long as your license otherwise remains
166
+ in force. You may convey covered works to others for the sole purpose
167
+ of having them make modifications exclusively for you, or provide you
168
+ with facilities for running those works, provided that you comply with
169
+ the terms of this License in conveying all material for which you do
170
+ not control copyright. Those thus making or running the covered works
171
+ for you must do so exclusively on your behalf, under your direction
172
+ and control, on terms that prohibit them from making any copies of
173
+ your copyrighted material outside their relationship with you.
174
+
175
+ Conveying under any other circumstances is permitted solely under
176
+ the conditions stated below. Sublicensing is not allowed; section 10
177
+ makes it unnecessary.
178
+
179
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180
+
181
+ No covered work shall be deemed part of an effective technological
182
+ measure under any applicable law fulfilling obligations under article
183
+ 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184
+ similar laws prohibiting or restricting circumvention of such
185
+ measures.
186
+
187
+ When you convey a covered work, you waive any legal power to forbid
188
+ circumvention of technological measures to the extent such circumvention
189
+ is effected by exercising rights under this License with respect to
190
+ the covered work, and you disclaim any intention to limit operation or
191
+ modification of the work as a means of enforcing, against the work's
192
+ users, your or third parties' legal rights to forbid circumvention of
193
+ technological measures.
194
+
195
+ 4. Conveying Verbatim Copies.
196
+
197
+ You may convey verbatim copies of the Program's source code as you
198
+ receive it, in any medium, provided that you conspicuously and
199
+ appropriately publish on each copy an appropriate copyright notice;
200
+ keep intact all notices stating that this License and any
201
+ non-permissive terms added in accord with section 7 apply to the code;
202
+ keep intact all notices of the absence of any warranty; and give all
203
+ recipients a copy of this License along with the Program.
204
+
205
+ You may charge any price or no price for each copy that you convey,
206
+ and you may offer support or warranty protection for a fee.
207
+
208
+ 5. Conveying Modified Source Versions.
209
+
210
+ You may convey a work based on the Program, or the modifications to
211
+ produce it from the Program, in the form of source code under the
212
+ terms of section 4, provided that you also meet all of these conditions:
213
+
214
+ a) The work must carry prominent notices stating that you modified
215
+ it, and giving a relevant date.
216
+
217
+ b) The work must carry prominent notices stating that it is
218
+ released under this License and any conditions added under section
219
+ 7. This requirement modifies the requirement in section 4 to
220
+ "keep intact all notices".
221
+
222
+ c) You must license the entire work, as a whole, under this
223
+ License to anyone who comes into possession of a copy. This
224
+ License will therefore apply, along with any applicable section 7
225
+ additional terms, to the whole of the work, and all its parts,
226
+ regardless of how they are packaged. This License gives no
227
+ permission to license the work in any other way, but it does not
228
+ invalidate such permission if you have separately received it.
229
+
230
+ d) If the work has interactive user interfaces, each must display
231
+ Appropriate Legal Notices; however, if the Program has interactive
232
+ interfaces that do not display Appropriate Legal Notices, your
233
+ work need not make them do so.
234
+
235
+ A compilation of a covered work with other separate and independent
236
+ works, which are not by their nature extensions of the covered work,
237
+ and which are not combined with it such as to form a larger program,
238
+ in or on a volume of a storage or distribution medium, is called an
239
+ "aggregate" if the compilation and its resulting copyright are not
240
+ used to limit the access or legal rights of the compilation's users
241
+ beyond what the individual works permit. Inclusion of a covered work
242
+ in an aggregate does not cause this License to apply to the other
243
+ parts of the aggregate.
244
+
245
+ 6. Conveying Non-Source Forms.
246
+
247
+ You may convey a covered work in object code form under the terms
248
+ of sections 4 and 5, provided that you also convey the
249
+ machine-readable Corresponding Source under the terms of this License,
250
+ in one of these ways:
251
+
252
+ a) Convey the object code in, or embodied in, a physical product
253
+ (including a physical distribution medium), accompanied by the
254
+ Corresponding Source fixed on a durable physical medium
255
+ customarily used for software interchange.
256
+
257
+ b) Convey the object code in, or embodied in, a physical product
258
+ (including a physical distribution medium), accompanied by a
259
+ written offer, valid for at least three years and valid for as
260
+ long as you offer spare parts or customer support for that product
261
+ model, to give anyone who possesses the object code either (1) a
262
+ copy of the Corresponding Source for all the software in the
263
+ product that is covered by this License, on a durable physical
264
+ medium customarily used for software interchange, for a price no
265
+ more than your reasonable cost of physically performing this
266
+ conveying of source, or (2) access to copy the
267
+ Corresponding Source from a network server at no charge.
268
+
269
+ c) Convey individual copies of the object code with a copy of the
270
+ written offer to provide the Corresponding Source. This
271
+ alternative is allowed only occasionally and noncommercially, and
272
+ only if you received the object code with such an offer, in accord
273
+ with subsection 6b.
274
+
275
+ d) Convey the object code by offering access from a designated
276
+ place (gratis or for a charge), and offer equivalent access to the
277
+ Corresponding Source in the same way through the same place at no
278
+ further charge. You need not require recipients to copy the
279
+ Corresponding Source along with the object code. If the place to
280
+ copy the object code is a network server, the Corresponding Source
281
+ may be on a different server (operated by you or a third party)
282
+ that supports equivalent copying facilities, provided you maintain
283
+ clear directions next to the object code saying where to find the
284
+ Corresponding Source. Regardless of what server hosts the
285
+ Corresponding Source, you remain obligated to ensure that it is
286
+ available for as long as needed to satisfy these requirements.
287
+
288
+ e) Convey the object code using peer-to-peer transmission, provided
289
+ you inform other peers where the object code and Corresponding
290
+ Source of the work are being offered to the general public at no
291
+ charge under subsection 6d.
292
+
293
+ A separable portion of the object code, whose source code is excluded
294
+ from the Corresponding Source as a System Library, need not be
295
+ included in conveying the object code work.
296
+
297
+ A "User Product" is either (1) a "consumer product", which means any
298
+ tangible personal property which is normally used for personal, family,
299
+ or household purposes, or (2) anything designed or sold for incorporation
300
+ into a dwelling. In determining whether a product is a consumer product,
301
+ doubtful cases shall be resolved in favor of coverage. For a particular
302
+ product received by a particular user, "normally used" refers to a
303
+ typical or common use of that class of product, regardless of the status
304
+ of the particular user or of the way in which the particular user
305
+ actually uses, or expects or is expected to use, the product. A product
306
+ is a consumer product regardless of whether the product has substantial
307
+ commercial, industrial or non-consumer uses, unless such uses represent
308
+ the only significant mode of use of the product.
309
+
310
+ "Installation Information" for a User Product means any methods,
311
+ procedures, authorization keys, or other information required to install
312
+ and execute modified versions of a covered work in that User Product from
313
+ a modified version of its Corresponding Source. The information must
314
+ suffice to ensure that the continued functioning of the modified object
315
+ code is in no case prevented or interfered with solely because
316
+ modification has been made.
317
+
318
+ If you convey an object code work under this section in, or with, or
319
+ specifically for use in, a User Product, and the conveying occurs as
320
+ part of a transaction in which the right of possession and use of the
321
+ User Product is transferred to the recipient in perpetuity or for a
322
+ fixed term (regardless of how the transaction is characterized), the
323
+ Corresponding Source conveyed under this section must be accompanied
324
+ by the Installation Information. But this requirement does not apply
325
+ if neither you nor any third party retains the ability to install
326
+ modified object code on the User Product (for example, the work has
327
+ been installed in ROM).
328
+
329
+ The requirement to provide Installation Information does not include a
330
+ requirement to continue to provide support service, warranty, or updates
331
+ for a work that has been modified or installed by the recipient, or for
332
+ the User Product in which it has been modified or installed. Access to a
333
+ network may be denied when the modification itself materially and
334
+ adversely affects the operation of the network or violates the rules and
335
+ protocols for communication across the network.
336
+
337
+ Corresponding Source conveyed, and Installation Information provided,
338
+ in accord with this section must be in a format that is publicly
339
+ documented (and with an implementation available to the public in
340
+ source code form), and must require no special password or key for
341
+ unpacking, reading or copying.
342
+
343
+ 7. Additional Terms.
344
+
345
+ "Additional permissions" are terms that supplement the terms of this
346
+ License by making exceptions from one or more of its conditions.
347
+ Additional permissions that are applicable to the entire Program shall
348
+ be treated as though they were included in this License, to the extent
349
+ that they are valid under applicable law. If additional permissions
350
+ apply only to part of the Program, that part may be used separately
351
+ under those permissions, but the entire Program remains governed by
352
+ this License without regard to the additional permissions.
353
+
354
+ When you convey a copy of a covered work, you may at your option
355
+ remove any additional permissions from that copy, or from any part of
356
+ it. (Additional permissions may be written to require their own
357
+ removal in certain cases when you modify the work.) You may place
358
+ additional permissions on material, added by you to a covered work,
359
+ for which you have or can give appropriate copyright permission.
360
+
361
+ Notwithstanding any other provision of this License, for material you
362
+ add to a covered work, you may (if authorized by the copyright holders of
363
+ that material) supplement the terms of this License with terms:
364
+
365
+ a) Disclaiming warranty or limiting liability differently from the
366
+ terms of sections 15 and 16 of this License; or
367
+
368
+ b) Requiring preservation of specified reasonable legal notices or
369
+ author attributions in that material or in the Appropriate Legal
370
+ Notices displayed by works containing it; or
371
+
372
+ c) Prohibiting misrepresentation of the origin of that material, or
373
+ requiring that modified versions of such material be marked in
374
+ reasonable ways as different from the original version; or
375
+
376
+ d) Limiting the use for publicity purposes of names of licensors or
377
+ authors of the material; or
378
+
379
+ e) Declining to grant rights under trademark law for use of some
380
+ trade names, trademarks, or service marks; or
381
+
382
+ f) Requiring indemnification of licensors and authors of that
383
+ material by anyone who conveys the material (or modified versions of
384
+ it) with contractual assumptions of liability to the recipient, for
385
+ any liability that these contractual assumptions directly impose on
386
+ those licensors and authors.
387
+
388
+ All other non-permissive additional terms are considered "further
389
+ restrictions" within the meaning of section 10. If the Program as you
390
+ received it, or any part of it, contains a notice stating that it is
391
+ governed by this License along with a term that is a further
392
+ restriction, you may remove that term. If a license document contains
393
+ a further restriction but permits relicensing or conveying under this
394
+ License, you may add to a covered work material governed by the terms
395
+ of that license document, provided that the further restriction does
396
+ not survive such relicensing or conveying.
397
+
398
+ If you add terms to a covered work in accord with this section, you
399
+ must place, in the relevant source files, a statement of the
400
+ additional terms that apply to those files, or a notice indicating
401
+ where to find the applicable terms.
402
+
403
+ Additional terms, permissive or non-permissive, may be stated in the
404
+ form of a separately written license, or stated as exceptions;
405
+ the above requirements apply either way.
406
+
407
+ 8. Termination.
408
+
409
+ You may not propagate or modify a covered work except as expressly
410
+ provided under this License. Any attempt otherwise to propagate or
411
+ modify it is void, and will automatically terminate your rights under
412
+ this License (including any patent licenses granted under the third
413
+ paragraph of section 11).
414
+
415
+ However, if you cease all violation of this License, then your
416
+ license from a particular copyright holder is reinstated (a)
417
+ provisionally, unless and until the copyright holder explicitly and
418
+ finally terminates your license, and (b) permanently, if the copyright
419
+ holder fails to notify you of the violation by some reasonable means
420
+ prior to 60 days after the cessation.
421
+
422
+ Moreover, your license from a particular copyright holder is
423
+ reinstated permanently if the copyright holder notifies you of the
424
+ violation by some reasonable means, this is the first time you have
425
+ received notice of violation of this License (for any work) from that
426
+ copyright holder, and you cure the violation prior to 30 days after
427
+ your receipt of the notice.
428
+
429
+ Termination of your rights under this section does not terminate the
430
+ licenses of parties who have received copies or rights from you under
431
+ this License. If your rights have been terminated and not permanently
432
+ reinstated, you do not qualify to receive new licenses for the same
433
+ material under section 10.
434
+
435
+ 9. Acceptance Not Required for Having Copies.
436
+
437
+ You are not required to accept this License in order to receive or
438
+ run a copy of the Program. Ancillary propagation of a covered work
439
+ occurring solely as a consequence of using peer-to-peer transmission
440
+ to receive a copy likewise does not require acceptance. However,
441
+ nothing other than this License grants you permission to propagate or
442
+ modify any covered work. These actions infringe copyright if you do
443
+ not accept this License. Therefore, by modifying or propagating a
444
+ covered work, you indicate your acceptance of this License to do so.
445
+
446
+ 10. Automatic Licensing of Downstream Recipients.
447
+
448
+ Each time you convey a covered work, the recipient automatically
449
+ receives a license from the original licensors, to run, modify and
450
+ propagate that work, subject to this License. You are not responsible
451
+ for enforcing compliance by third parties with this License.
452
+
453
+ An "entity transaction" is a transaction transferring control of an
454
+ organization, or substantially all assets of one, or subdividing an
455
+ organization, or merging organizations. If propagation of a covered
456
+ work results from an entity transaction, each party to that
457
+ transaction who receives a copy of the work also receives whatever
458
+ licenses to the work the party's predecessor in interest had or could
459
+ give under the previous paragraph, plus a right to possession of the
460
+ Corresponding Source of the work from the predecessor in interest, if
461
+ the predecessor has it or can get it with reasonable efforts.
462
+
463
+ You may not impose any further restrictions on the exercise of the
464
+ rights granted or affirmed under this License. For example, you may
465
+ not impose a license fee, royalty, or other charge for exercise of
466
+ rights granted under this License, and you may not initiate litigation
467
+ (including a cross-claim or counterclaim in a lawsuit) alleging that
468
+ any patent claim is infringed by making, using, selling, offering for
469
+ sale, or importing the Program or any portion of it.
470
+
471
+ 11. Patents.
472
+
473
+ A "contributor" is a copyright holder who authorizes use under this
474
+ License of the Program or a work on which the Program is based. The
475
+ work thus licensed is called the contributor's "contributor version".
476
+
477
+ A contributor's "essential patent claims" are all patent claims
478
+ owned or controlled by the contributor, whether already acquired or
479
+ hereafter acquired, that would be infringed by some manner, permitted
480
+ by this License, of making, using, or selling its contributor version,
481
+ but do not include claims that would be infringed only as a
482
+ consequence of further modification of the contributor version. For
483
+ purposes of this definition, "control" includes the right to grant
484
+ patent sublicenses in a manner consistent with the requirements of
485
+ this License.
486
+
487
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
488
+ patent license under the contributor's essential patent claims, to
489
+ make, use, sell, offer for sale, import and otherwise run, modify and
490
+ propagate the contents of its contributor version.
491
+
492
+ In the following three paragraphs, a "patent license" is any express
493
+ agreement or commitment, however denominated, not to enforce a patent
494
+ (such as an express permission to practice a patent or covenant not to
495
+ sue for patent infringement). To "grant" such a patent license to a
496
+ party means to make such an agreement or commitment not to enforce a
497
+ patent against the party.
498
+
499
+ If you convey a covered work, knowingly relying on a patent license,
500
+ and the Corresponding Source of the work is not available for anyone
501
+ to copy, free of charge and under the terms of this License, through a
502
+ publicly available network server or other readily accessible means,
503
+ then you must either (1) cause the Corresponding Source to be so
504
+ available, or (2) arrange to deprive yourself of the benefit of the
505
+ patent license for this particular work, or (3) arrange, in a manner
506
+ consistent with the requirements of this License, to extend the patent
507
+ license to downstream recipients. "Knowingly relying" means you have
508
+ actual knowledge that, but for the patent license, your conveying the
509
+ covered work in a country, or your recipient's use of the covered work
510
+ in a country, would infringe one or more identifiable patents in that
511
+ country that you have reason to believe are valid.
512
+
513
+ If, pursuant to or in connection with a single transaction or
514
+ arrangement, you convey, or propagate by procuring conveyance of, a
515
+ covered work, and grant a patent license to some of the parties
516
+ receiving the covered work authorizing them to use, propagate, modify
517
+ or convey a specific copy of the covered work, then the patent license
518
+ you grant is automatically extended to all recipients of the covered
519
+ work and works based on it.
520
+
521
+ A patent license is "discriminatory" if it does not include within
522
+ the scope of its coverage, prohibits the exercise of, or is
523
+ conditioned on the non-exercise of one or more of the rights that are
524
+ specifically granted under this License. You may not convey a covered
525
+ work if you are a party to an arrangement with a third party that is
526
+ in the business of distributing software, under which you make payment
527
+ to the third party based on the extent of your activity of conveying
528
+ the work, and under which the third party grants, to any of the
529
+ parties who would receive the covered work from you, a discriminatory
530
+ patent license (a) in connection with copies of the covered work
531
+ conveyed by you (or copies made from those copies), or (b) primarily
532
+ for and in connection with specific products or compilations that
533
+ contain the covered work, unless you entered into that arrangement,
534
+ or that patent license was granted, prior to 28 March 2007.
535
+
536
+ Nothing in this License shall be construed as excluding or limiting
537
+ any implied license or other defenses to infringement that may
538
+ otherwise be available to you under applicable patent law.
539
+
540
+ 12. No Surrender of Others' Freedom.
541
+
542
+ If conditions are imposed on you (whether by court order, agreement or
543
+ otherwise) that contradict the conditions of this License, they do not
544
+ excuse you from the conditions of this License. If you cannot convey a
545
+ covered work so as to satisfy simultaneously your obligations under this
546
+ License and any other pertinent obligations, then as a consequence you may
547
+ not convey it at all. For example, if you agree to terms that obligate you
548
+ to collect a royalty for further conveying from those to whom you convey
549
+ the Program, the only way you could satisfy both those terms and this
550
+ License would be to refrain entirely from conveying the Program.
551
+
552
+ 13. Use with the GNU Affero General Public License.
553
+
554
+ Notwithstanding any other provision of this License, you have
555
+ permission to link or combine any covered work with a work licensed
556
+ under version 3 of the GNU Affero General Public License into a single
557
+ combined work, and to convey the resulting work. The terms of this
558
+ License will continue to apply to the part which is the covered work,
559
+ but the special requirements of the GNU Affero General Public License,
560
+ section 13, concerning interaction through a network will apply to the
561
+ combination as such.
562
+
563
+ 14. Revised Versions of this License.
564
+
565
+ The Free Software Foundation may publish revised and/or new versions of
566
+ the GNU General Public License from time to time. Such new versions will
567
+ be similar in spirit to the present version, but may differ in detail to
568
+ address new problems or concerns.
569
+
570
+ Each version is given a distinguishing version number. If the
571
+ Program specifies that a certain numbered version of the GNU General
572
+ Public License "or any later version" applies to it, you have the
573
+ option of following the terms and conditions either of that numbered
574
+ version or of any later version published by the Free Software
575
+ Foundation. If the Program does not specify a version number of the
576
+ GNU General Public License, you may choose any version ever published
577
+ by the Free Software Foundation.
578
+
579
+ If the Program specifies that a proxy can decide which future
580
+ versions of the GNU General Public License can be used, that proxy's
581
+ public statement of acceptance of a version permanently authorizes you
582
+ to choose that version for the Program.
583
+
584
+ Later license versions may give you additional or different
585
+ permissions. However, no additional obligations are imposed on any
586
+ author or copyright holder as a result of your choosing to follow a
587
+ later version.
588
+
589
+ 15. Disclaimer of Warranty.
590
+
591
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592
+ APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593
+ HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594
+ OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595
+ THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596
+ PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597
+ IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598
+ ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599
+
600
+ 16. Limitation of Liability.
601
+
602
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603
+ WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604
+ THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605
+ GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606
+ USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607
+ DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608
+ PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609
+ EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610
+ SUCH DAMAGES.
611
+
612
+ 17. Interpretation of Sections 15 and 16.
613
+
614
+ If the disclaimer of warranty and limitation of liability provided
615
+ above cannot be given local legal effect according to their terms,
616
+ reviewing courts shall apply local law that most closely approximates
617
+ an absolute waiver of all civil liability in connection with the
618
+ Program, unless a warranty or assumption of liability accompanies a
619
+ copy of the Program in return for a fee.
620
+
621
+ END OF TERMS AND CONDITIONS
622
+
623
+ How to Apply These Terms to Your New Programs
624
+
625
+ If you develop a new program, and you want it to be of the greatest
626
+ possible use to the public, the best way to achieve this is to make it
627
+ free software which everyone can redistribute and change under these terms.
628
+
629
+ To do so, attach the following notices to the program. It is safest
630
+ to attach them to the start of each source file to most effectively
631
+ state the exclusion of warranty; and each file should have at least
632
+ the "copyright" line and a pointer to where the full notice is found.
633
+
634
+ <one line to give the program's name and a brief idea of what it does.>
635
+ Copyright (C) <year> <name of author>
636
+
637
+ This program is free software: you can redistribute it and/or modify
638
+ it under the terms of the GNU General Public License as published by
639
+ the Free Software Foundation, either version 3 of the License, or
640
+ (at your option) any later version.
641
+
642
+ This program is distributed in the hope that it will be useful,
643
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
644
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645
+ GNU General Public License for more details.
646
+
647
+ You should have received a copy of the GNU General Public License
648
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
649
+
650
+ Also add information on how to contact you by electronic and paper mail.
651
+
652
+ If the program does terminal interaction, make it output a short
653
+ notice like this when it starts in an interactive mode:
654
+
655
+ <program> Copyright (C) <year> <name of author>
656
+ This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657
+ This is free software, and you are welcome to redistribute it
658
+ under certain conditions; type `show c' for details.
659
+
660
+ The hypothetical commands `show w' and `show c' should show the appropriate
661
+ parts of the General Public License. Of course, your program's commands
662
+ might be different; for a GUI interface, you would use an "about box".
663
+
664
+ You should also get your employer (if you work as a programmer) or school,
665
+ if any, to sign a "copyright disclaimer" for the program, if necessary.
666
+ For more information on this, and how to apply and follow the GNU GPL, see
667
+ <https://www.gnu.org/licenses/>.
668
+
669
+ The GNU General Public License does not permit incorporating your program
670
+ into proprietary programs. If your program is a subroutine library, you
671
+ may consider it more useful to permit linking proprietary applications with
672
+ the library. If this is what you want to do, use the GNU Lesser General
673
+ Public License instead of this License. But first, please read
674
+ <https://www.gnu.org/licenses/why-not-lgpl.html>.
app/__init__.py ADDED
File without changes
app/app_settings.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ from aiohttp import web
4
+ import logging
5
+
6
+
7
+ class AppSettings():
8
+ def __init__(self, user_manager):
9
+ self.user_manager = user_manager
10
+
11
+ def get_settings(self, request):
12
+ file = self.user_manager.get_request_user_filepath(
13
+ request, "comfy.settings.json")
14
+ if os.path.isfile(file):
15
+ try:
16
+ with open(file) as f:
17
+ return json.load(f)
18
+ except:
19
+ logging.error(f"The user settings file is corrupted: {file}")
20
+ return {}
21
+ else:
22
+ return {}
23
+
24
+ def save_settings(self, request, settings):
25
+ file = self.user_manager.get_request_user_filepath(
26
+ request, "comfy.settings.json")
27
+ with open(file, "w") as f:
28
+ f.write(json.dumps(settings, indent=4))
29
+
30
+ def add_routes(self, routes):
31
+ @routes.get("/settings")
32
+ async def get_settings(request):
33
+ return web.json_response(self.get_settings(request))
34
+
35
+ @routes.get("/settings/{id}")
36
+ async def get_setting(request):
37
+ value = None
38
+ settings = self.get_settings(request)
39
+ setting_id = request.match_info.get("id", None)
40
+ if setting_id and setting_id in settings:
41
+ value = settings[setting_id]
42
+ return web.json_response(value)
43
+
44
+ @routes.post("/settings")
45
+ async def post_settings(request):
46
+ settings = self.get_settings(request)
47
+ new_settings = await request.json()
48
+ self.save_settings(request, {**settings, **new_settings})
49
+ return web.Response(status=200)
50
+
51
+ @routes.post("/settings/{id}")
52
+ async def post_setting(request):
53
+ setting_id = request.match_info.get("id", None)
54
+ if not setting_id:
55
+ return web.Response(status=400)
56
+ settings = self.get_settings(request)
57
+ settings[setting_id] = await request.json()
58
+ self.save_settings(request, settings)
59
+ return web.Response(status=200)
app/custom_node_manager.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import folder_paths
5
+ import glob
6
+ from aiohttp import web
7
+
8
+ class CustomNodeManager:
9
+ """
10
+ Placeholder to refactor the custom node management features from ComfyUI-Manager.
11
+ Currently it only contains the custom workflow templates feature.
12
+ """
13
+ def add_routes(self, routes, webapp, loadedModules):
14
+
15
+ @routes.get("/workflow_templates")
16
+ async def get_workflow_templates(request):
17
+ """Returns a web response that contains the map of custom_nodes names and their associated workflow templates. The ones without templates are omitted."""
18
+ files = [
19
+ file
20
+ for folder in folder_paths.get_folder_paths("custom_nodes")
21
+ for file in glob.glob(os.path.join(folder, '*/example_workflows/*.json'))
22
+ ]
23
+ workflow_templates_dict = {} # custom_nodes folder name -> example workflow names
24
+ for file in files:
25
+ custom_nodes_name = os.path.basename(os.path.dirname(os.path.dirname(file)))
26
+ workflow_name = os.path.splitext(os.path.basename(file))[0]
27
+ workflow_templates_dict.setdefault(custom_nodes_name, []).append(workflow_name)
28
+ return web.json_response(workflow_templates_dict)
29
+
30
+ # Serve workflow templates from custom nodes.
31
+ for module_name, module_dir in loadedModules:
32
+ workflows_dir = os.path.join(module_dir, 'example_workflows')
33
+ if os.path.exists(workflows_dir):
34
+ webapp.add_routes([web.static('/api/workflow_templates/' + module_name, workflows_dir)])
app/frontend_management.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import argparse
3
+ import logging
4
+ import os
5
+ import re
6
+ import tempfile
7
+ import zipfile
8
+ from dataclasses import dataclass
9
+ from functools import cached_property
10
+ from pathlib import Path
11
+ from typing import TypedDict, Optional
12
+
13
+ import requests
14
+ from typing_extensions import NotRequired
15
+ from comfy.cli_args import DEFAULT_VERSION_STRING
16
+
17
+
18
+ REQUEST_TIMEOUT = 10 # seconds
19
+
20
+
21
+ class Asset(TypedDict):
22
+ url: str
23
+
24
+
25
+ class Release(TypedDict):
26
+ id: int
27
+ tag_name: str
28
+ name: str
29
+ prerelease: bool
30
+ created_at: str
31
+ published_at: str
32
+ body: str
33
+ assets: NotRequired[list[Asset]]
34
+
35
+
36
+ @dataclass
37
+ class FrontEndProvider:
38
+ owner: str
39
+ repo: str
40
+
41
+ @property
42
+ def folder_name(self) -> str:
43
+ return f"{self.owner}_{self.repo}"
44
+
45
+ @property
46
+ def release_url(self) -> str:
47
+ return f"https://api.github.com/repos/{self.owner}/{self.repo}/releases"
48
+
49
+ @cached_property
50
+ def all_releases(self) -> list[Release]:
51
+ releases = []
52
+ api_url = self.release_url
53
+ while api_url:
54
+ response = requests.get(api_url, timeout=REQUEST_TIMEOUT)
55
+ response.raise_for_status() # Raises an HTTPError if the response was an error
56
+ releases.extend(response.json())
57
+ # GitHub uses the Link header to provide pagination links. Check if it exists and update api_url accordingly.
58
+ if "next" in response.links:
59
+ api_url = response.links["next"]["url"]
60
+ else:
61
+ api_url = None
62
+ return releases
63
+
64
+ @cached_property
65
+ def latest_release(self) -> Release:
66
+ latest_release_url = f"{self.release_url}/latest"
67
+ response = requests.get(latest_release_url, timeout=REQUEST_TIMEOUT)
68
+ response.raise_for_status() # Raises an HTTPError if the response was an error
69
+ return response.json()
70
+
71
+ def get_release(self, version: str) -> Release:
72
+ if version == "latest":
73
+ return self.latest_release
74
+ else:
75
+ for release in self.all_releases:
76
+ if release["tag_name"] in [version, f"v{version}"]:
77
+ return release
78
+ raise ValueError(f"Version {version} not found in releases")
79
+
80
+
81
+ def download_release_asset_zip(release: Release, destination_path: str) -> None:
82
+ """Download dist.zip from github release."""
83
+ asset_url = None
84
+ for asset in release.get("assets", []):
85
+ if asset["name"] == "dist.zip":
86
+ asset_url = asset["url"]
87
+ break
88
+
89
+ if not asset_url:
90
+ raise ValueError("dist.zip not found in the release assets")
91
+
92
+ # Use a temporary file to download the zip content
93
+ with tempfile.TemporaryFile() as tmp_file:
94
+ headers = {"Accept": "application/octet-stream"}
95
+ response = requests.get(
96
+ asset_url, headers=headers, allow_redirects=True, timeout=REQUEST_TIMEOUT
97
+ )
98
+ response.raise_for_status() # Ensure we got a successful response
99
+
100
+ # Write the content to the temporary file
101
+ tmp_file.write(response.content)
102
+
103
+ # Go back to the beginning of the temporary file
104
+ tmp_file.seek(0)
105
+
106
+ # Extract the zip file content to the destination path
107
+ with zipfile.ZipFile(tmp_file, "r") as zip_ref:
108
+ zip_ref.extractall(destination_path)
109
+
110
+
111
+ class FrontendManager:
112
+ DEFAULT_FRONTEND_PATH = str(Path(__file__).parents[1] / "web")
113
+ CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
114
+
115
+ @classmethod
116
+ def parse_version_string(cls, value: str) -> tuple[str, str, str]:
117
+ """
118
+ Args:
119
+ value (str): The version string to parse.
120
+
121
+ Returns:
122
+ tuple[str, str]: A tuple containing provider name and version.
123
+
124
+ Raises:
125
+ argparse.ArgumentTypeError: If the version string is invalid.
126
+ """
127
+ VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(v?\d+\.\d+\.\d+|latest)$"
128
+ match_result = re.match(VERSION_PATTERN, value)
129
+ if match_result is None:
130
+ raise argparse.ArgumentTypeError(f"Invalid version string: {value}")
131
+
132
+ return match_result.group(1), match_result.group(2), match_result.group(3)
133
+
134
+ @classmethod
135
+ def init_frontend_unsafe(cls, version_string: str, provider: Optional[FrontEndProvider] = None) -> str:
136
+ """
137
+ Initializes the frontend for the specified version.
138
+
139
+ Args:
140
+ version_string (str): The version string.
141
+ provider (FrontEndProvider, optional): The provider to use. Defaults to None.
142
+
143
+ Returns:
144
+ str: The path to the initialized frontend.
145
+
146
+ Raises:
147
+ Exception: If there is an error during the initialization process.
148
+ main error source might be request timeout or invalid URL.
149
+ """
150
+ if version_string == DEFAULT_VERSION_STRING:
151
+ return cls.DEFAULT_FRONTEND_PATH
152
+
153
+ repo_owner, repo_name, version = cls.parse_version_string(version_string)
154
+
155
+ if version.startswith("v"):
156
+ expected_path = str(Path(cls.CUSTOM_FRONTENDS_ROOT) / f"{repo_owner}_{repo_name}" / version.lstrip("v"))
157
+ if os.path.exists(expected_path):
158
+ logging.info(f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}")
159
+ return expected_path
160
+
161
+ logging.info(f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub...")
162
+
163
+ provider = provider or FrontEndProvider(repo_owner, repo_name)
164
+ release = provider.get_release(version)
165
+
166
+ semantic_version = release["tag_name"].lstrip("v")
167
+ web_root = str(
168
+ Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version
169
+ )
170
+ if not os.path.exists(web_root):
171
+ try:
172
+ os.makedirs(web_root, exist_ok=True)
173
+ logging.info(
174
+ "Downloading frontend(%s) version(%s) to (%s)",
175
+ provider.folder_name,
176
+ semantic_version,
177
+ web_root,
178
+ )
179
+ logging.debug(release)
180
+ download_release_asset_zip(release, destination_path=web_root)
181
+ finally:
182
+ # Clean up the directory if it is empty, i.e. the download failed
183
+ if not os.listdir(web_root):
184
+ os.rmdir(web_root)
185
+
186
+ return web_root
187
+
188
+ @classmethod
189
+ def init_frontend(cls, version_string: str) -> str:
190
+ """
191
+ Initializes the frontend with the specified version string.
192
+
193
+ Args:
194
+ version_string (str): The version string to initialize the frontend with.
195
+
196
+ Returns:
197
+ str: The path of the initialized frontend.
198
+ """
199
+ try:
200
+ return cls.init_frontend_unsafe(version_string)
201
+ except Exception as e:
202
+ logging.error("Failed to initialize frontend: %s", e)
203
+ logging.info("Falling back to the default frontend.")
204
+ return cls.DEFAULT_FRONTEND_PATH
app/logger.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import deque
2
+ from datetime import datetime
3
+ import io
4
+ import logging
5
+ import sys
6
+ import threading
7
+
8
+ logs = None
9
+ stdout_interceptor = None
10
+ stderr_interceptor = None
11
+
12
+
13
+ class LogInterceptor(io.TextIOWrapper):
14
+ def __init__(self, stream, *args, **kwargs):
15
+ buffer = stream.buffer
16
+ encoding = stream.encoding
17
+ super().__init__(buffer, *args, **kwargs, encoding=encoding, line_buffering=stream.line_buffering)
18
+ self._lock = threading.Lock()
19
+ self._flush_callbacks = []
20
+ self._logs_since_flush = []
21
+
22
+ def write(self, data):
23
+ entry = {"t": datetime.now().isoformat(), "m": data}
24
+ with self._lock:
25
+ self._logs_since_flush.append(entry)
26
+
27
+ # Simple handling for cr to overwrite the last output if it isnt a full line
28
+ # else logs just get full of progress messages
29
+ if isinstance(data, str) and data.startswith("\r") and not logs[-1]["m"].endswith("\n"):
30
+ logs.pop()
31
+ logs.append(entry)
32
+ super().write(data)
33
+
34
+ def flush(self):
35
+ super().flush()
36
+ for cb in self._flush_callbacks:
37
+ cb(self._logs_since_flush)
38
+ self._logs_since_flush = []
39
+
40
+ def on_flush(self, callback):
41
+ self._flush_callbacks.append(callback)
42
+
43
+
44
+ def get_logs():
45
+ return logs
46
+
47
+
48
+ def on_flush(callback):
49
+ if stdout_interceptor is not None:
50
+ stdout_interceptor.on_flush(callback)
51
+ if stderr_interceptor is not None:
52
+ stderr_interceptor.on_flush(callback)
53
+
54
+ def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool = False):
55
+ global logs
56
+ if logs:
57
+ return
58
+
59
+ # Override output streams and log to buffer
60
+ logs = deque(maxlen=capacity)
61
+
62
+ global stdout_interceptor
63
+ global stderr_interceptor
64
+ stdout_interceptor = sys.stdout = LogInterceptor(sys.stdout)
65
+ stderr_interceptor = sys.stderr = LogInterceptor(sys.stderr)
66
+
67
+ # Setup default global logger
68
+ logger = logging.getLogger()
69
+ logger.setLevel(log_level)
70
+
71
+ stream_handler = logging.StreamHandler()
72
+ stream_handler.setFormatter(logging.Formatter("%(message)s"))
73
+
74
+ if use_stdout:
75
+ # Only errors and critical to stderr
76
+ stream_handler.addFilter(lambda record: not record.levelno < logging.ERROR)
77
+
78
+ # Lesser to stdout
79
+ stdout_handler = logging.StreamHandler(sys.stdout)
80
+ stdout_handler.setFormatter(logging.Formatter("%(message)s"))
81
+ stdout_handler.addFilter(lambda record: record.levelno < logging.ERROR)
82
+ logger.addHandler(stdout_handler)
83
+
84
+ logger.addHandler(stream_handler)
app/model_manager.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import base64
5
+ import json
6
+ import time
7
+ import logging
8
+ import folder_paths
9
+ import glob
10
+ import comfy.utils
11
+ from aiohttp import web
12
+ from PIL import Image
13
+ from io import BytesIO
14
+ from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types
15
+
16
+
17
+ class ModelFileManager:
18
+ def __init__(self) -> None:
19
+ self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}
20
+
21
+ def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None:
22
+ return self.cache.get(key, default)
23
+
24
+ def set_cache(self, key: str, value: tuple[list[dict], dict[str, float], float]):
25
+ self.cache[key] = value
26
+
27
+ def clear_cache(self):
28
+ self.cache.clear()
29
+
30
+ def add_routes(self, routes):
31
+ # NOTE: This is an experiment to replace `/models`
32
+ @routes.get("/experiment/models")
33
+ async def get_model_folders(request):
34
+ model_types = list(folder_paths.folder_names_and_paths.keys())
35
+ folder_black_list = ["configs", "custom_nodes"]
36
+ output_folders: list[dict] = []
37
+ for folder in model_types:
38
+ if folder in folder_black_list:
39
+ continue
40
+ output_folders.append({"name": folder, "folders": folder_paths.get_folder_paths(folder)})
41
+ return web.json_response(output_folders)
42
+
43
+ # NOTE: This is an experiment to replace `/models/{folder}`
44
+ @routes.get("/experiment/models/{folder}")
45
+ async def get_all_models(request):
46
+ folder = request.match_info.get("folder", None)
47
+ if not folder in folder_paths.folder_names_and_paths:
48
+ return web.Response(status=404)
49
+ files = self.get_model_file_list(folder)
50
+ return web.json_response(files)
51
+
52
+ @routes.get("/experiment/models/preview/{folder}/{path_index}/{filename:.*}")
53
+ async def get_model_preview(request):
54
+ folder_name = request.match_info.get("folder", None)
55
+ path_index = int(request.match_info.get("path_index", None))
56
+ filename = request.match_info.get("filename", None)
57
+
58
+ if not folder_name in folder_paths.folder_names_and_paths:
59
+ return web.Response(status=404)
60
+
61
+ folders = folder_paths.folder_names_and_paths[folder_name]
62
+ folder = folders[0][path_index]
63
+ full_filename = os.path.join(folder, filename)
64
+
65
+ previews = self.get_model_previews(full_filename)
66
+ default_preview = previews[0] if len(previews) > 0 else None
67
+ if default_preview is None or (isinstance(default_preview, str) and not os.path.isfile(default_preview)):
68
+ return web.Response(status=404)
69
+
70
+ try:
71
+ with Image.open(default_preview) as img:
72
+ img_bytes = BytesIO()
73
+ img.save(img_bytes, format="WEBP")
74
+ img_bytes.seek(0)
75
+ return web.Response(body=img_bytes.getvalue(), content_type="image/webp")
76
+ except:
77
+ return web.Response(status=404)
78
+
79
+ def get_model_file_list(self, folder_name: str):
80
+ folder_name = map_legacy(folder_name)
81
+ folders = folder_paths.folder_names_and_paths[folder_name]
82
+ output_list: list[dict] = []
83
+
84
+ for index, folder in enumerate(folders[0]):
85
+ if not os.path.isdir(folder):
86
+ continue
87
+ out = self.cache_model_file_list_(folder)
88
+ if out is None:
89
+ out = self.recursive_search_models_(folder, index)
90
+ self.set_cache(folder, out)
91
+ output_list.extend(out[0])
92
+
93
+ return output_list
94
+
95
+ def cache_model_file_list_(self, folder: str):
96
+ model_file_list_cache = self.get_cache(folder)
97
+
98
+ if model_file_list_cache is None:
99
+ return None
100
+ if not os.path.isdir(folder):
101
+ return None
102
+ if os.path.getmtime(folder) != model_file_list_cache[1]:
103
+ return None
104
+ for x in model_file_list_cache[1]:
105
+ time_modified = model_file_list_cache[1][x]
106
+ folder = x
107
+ if os.path.getmtime(folder) != time_modified:
108
+ return None
109
+
110
+ return model_file_list_cache
111
+
112
+ def recursive_search_models_(self, directory: str, pathIndex: int) -> tuple[list[str], dict[str, float], float]:
113
+ if not os.path.isdir(directory):
114
+ return [], {}, time.perf_counter()
115
+
116
+ excluded_dir_names = [".git"]
117
+ # TODO use settings
118
+ include_hidden_files = False
119
+
120
+ result: list[str] = []
121
+ dirs: dict[str, float] = {}
122
+
123
+ for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
124
+ subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
125
+ if not include_hidden_files:
126
+ subdirs[:] = [d for d in subdirs if not d.startswith(".")]
127
+ filenames = [f for f in filenames if not f.startswith(".")]
128
+
129
+ filenames = filter_files_extensions(filenames, folder_paths.supported_pt_extensions)
130
+
131
+ for file_name in filenames:
132
+ try:
133
+ relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
134
+ result.append(relative_path)
135
+ except:
136
+ logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.")
137
+ continue
138
+
139
+ for d in subdirs:
140
+ path: str = os.path.join(dirpath, d)
141
+ try:
142
+ dirs[path] = os.path.getmtime(path)
143
+ except FileNotFoundError:
144
+ logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
145
+ continue
146
+
147
+ return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter()
148
+
149
+ def get_model_previews(self, filepath: str) -> list[str | BytesIO]:
150
+ dirname = os.path.dirname(filepath)
151
+
152
+ if not os.path.exists(dirname):
153
+ return []
154
+
155
+ basename = os.path.splitext(filepath)[0]
156
+ match_files = glob.glob(f"{basename}.*", recursive=False)
157
+ image_files = filter_files_content_types(match_files, "image")
158
+ safetensors_file = next(filter(lambda x: x.endswith(".safetensors"), match_files), None)
159
+ safetensors_metadata = {}
160
+
161
+ result: list[str | BytesIO] = []
162
+
163
+ for filename in image_files:
164
+ _basename = os.path.splitext(filename)[0]
165
+ if _basename == basename:
166
+ result.append(filename)
167
+ if _basename == f"{basename}.preview":
168
+ result.append(filename)
169
+
170
+ if safetensors_file:
171
+ safetensors_filepath = os.path.join(dirname, safetensors_file)
172
+ header = comfy.utils.safetensors_header(safetensors_filepath, max_size=8*1024*1024)
173
+ if header:
174
+ safetensors_metadata = json.loads(header)
175
+ safetensors_images = safetensors_metadata.get("__metadata__", {}).get("ssmd_cover_images", None)
176
+ if safetensors_images:
177
+ safetensors_images = json.loads(safetensors_images)
178
+ for image in safetensors_images:
179
+ result.append(BytesIO(base64.b64decode(image)))
180
+
181
+ return result
182
+
183
+ def __exit__(self, exc_type, exc_value, traceback):
184
+ self.clear_cache()
app/user_manager.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import json
3
+ import os
4
+ import re
5
+ import uuid
6
+ import glob
7
+ import shutil
8
+ import logging
9
+ from aiohttp import web
10
+ from urllib import parse
11
+ from comfy.cli_args import args
12
+ import folder_paths
13
+ from .app_settings import AppSettings
14
+ from typing import TypedDict
15
+
16
+ default_user = "default"
17
+
18
+
19
+ class FileInfo(TypedDict):
20
+ path: str
21
+ size: int
22
+ modified: int
23
+
24
+
25
+ def get_file_info(path: str, relative_to: str) -> FileInfo:
26
+ return {
27
+ "path": os.path.relpath(path, relative_to).replace(os.sep, '/'),
28
+ "size": os.path.getsize(path),
29
+ "modified": os.path.getmtime(path)
30
+ }
31
+
32
+
33
+ class UserManager():
34
+ def __init__(self):
35
+ user_directory = folder_paths.get_user_directory()
36
+
37
+ self.settings = AppSettings(self)
38
+ if not os.path.exists(user_directory):
39
+ os.makedirs(user_directory, exist_ok=True)
40
+ if not args.multi_user:
41
+ logging.warning("****** User settings have been changed to be stored on the server instead of browser storage. ******")
42
+ logging.warning("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
43
+
44
+ if args.multi_user:
45
+ if os.path.isfile(self.get_users_file()):
46
+ with open(self.get_users_file()) as f:
47
+ self.users = json.load(f)
48
+ else:
49
+ self.users = {}
50
+ else:
51
+ self.users = {"default": "default"}
52
+
53
+ def get_users_file(self):
54
+ return os.path.join(folder_paths.get_user_directory(), "users.json")
55
+
56
+ def get_request_user_id(self, request):
57
+ user = "default"
58
+ if args.multi_user and "comfy-user" in request.headers:
59
+ user = request.headers["comfy-user"]
60
+
61
+ if user not in self.users:
62
+ raise KeyError("Unknown user: " + user)
63
+
64
+ return user
65
+
66
+ def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
67
+ user_directory = folder_paths.get_user_directory()
68
+
69
+ if type == "userdata":
70
+ root_dir = user_directory
71
+ else:
72
+ raise KeyError("Unknown filepath type:" + type)
73
+
74
+ user = self.get_request_user_id(request)
75
+ path = user_root = os.path.abspath(os.path.join(root_dir, user))
76
+
77
+ # prevent leaving /{type}
78
+ if os.path.commonpath((root_dir, user_root)) != root_dir:
79
+ return None
80
+
81
+ if file is not None:
82
+ # Check if filename is url encoded
83
+ if "%" in file:
84
+ file = parse.unquote(file)
85
+
86
+ # prevent leaving /{type}/{user}
87
+ path = os.path.abspath(os.path.join(user_root, file))
88
+ if os.path.commonpath((user_root, path)) != user_root:
89
+ return None
90
+
91
+ parent = os.path.split(path)[0]
92
+
93
+ if create_dir and not os.path.exists(parent):
94
+ os.makedirs(parent, exist_ok=True)
95
+
96
+ return path
97
+
98
+ def add_user(self, name):
99
+ name = name.strip()
100
+ if not name:
101
+ raise ValueError("username not provided")
102
+ user_id = re.sub("[^a-zA-Z0-9-_]+", '-', name)
103
+ user_id = user_id + "_" + str(uuid.uuid4())
104
+
105
+ self.users[user_id] = name
106
+
107
+ with open(self.get_users_file(), "w") as f:
108
+ json.dump(self.users, f)
109
+
110
+ return user_id
111
+
112
+ def add_routes(self, routes):
113
+ self.settings.add_routes(routes)
114
+
115
+ @routes.get("/users")
116
+ async def get_users(request):
117
+ if args.multi_user:
118
+ return web.json_response({"storage": "server", "users": self.users})
119
+ else:
120
+ user_dir = self.get_request_user_filepath(request, None, create_dir=False)
121
+ return web.json_response({
122
+ "storage": "server",
123
+ "migrated": os.path.exists(user_dir)
124
+ })
125
+
126
+ @routes.post("/users")
127
+ async def post_users(request):
128
+ body = await request.json()
129
+ username = body["username"]
130
+ if username in self.users.values():
131
+ return web.json_response({"error": "Duplicate username."}, status=400)
132
+
133
+ user_id = self.add_user(username)
134
+ return web.json_response(user_id)
135
+
136
+ @routes.get("/userdata")
137
+ async def listuserdata(request):
138
+ """
139
+ List user data files in a specified directory.
140
+
141
+ This endpoint allows listing files in a user's data directory, with options for recursion,
142
+ full file information, and path splitting.
143
+
144
+ Query Parameters:
145
+ - dir (required): The directory to list files from.
146
+ - recurse (optional): If "true", recursively list files in subdirectories.
147
+ - full_info (optional): If "true", return detailed file information (path, size, modified time).
148
+ - split (optional): If "true", split file paths into components (only applies when full_info is false).
149
+
150
+ Returns:
151
+ - 400: If 'dir' parameter is missing.
152
+ - 403: If the requested path is not allowed.
153
+ - 404: If the requested directory does not exist.
154
+ - 200: JSON response with the list of files or file information.
155
+
156
+ The response format depends on the query parameters:
157
+ - Default: List of relative file paths.
158
+ - full_info=true: List of dictionaries with file details.
159
+ - split=true (and full_info=false): List of lists, each containing path components.
160
+ """
161
+ directory = request.rel_url.query.get('dir', '')
162
+ if not directory:
163
+ return web.Response(status=400, text="Directory not provided")
164
+
165
+ path = self.get_request_user_filepath(request, directory)
166
+ if not path:
167
+ return web.Response(status=403, text="Invalid directory")
168
+
169
+ if not os.path.exists(path):
170
+ return web.Response(status=404, text="Directory not found")
171
+
172
+ recurse = request.rel_url.query.get('recurse', '').lower() == "true"
173
+ full_info = request.rel_url.query.get('full_info', '').lower() == "true"
174
+ split_path = request.rel_url.query.get('split', '').lower() == "true"
175
+
176
+ # Use different patterns based on whether we're recursing or not
177
+ if recurse:
178
+ pattern = os.path.join(glob.escape(path), '**', '*')
179
+ else:
180
+ pattern = os.path.join(glob.escape(path), '*')
181
+
182
+ def process_full_path(full_path: str) -> FileInfo | str | list[str]:
183
+ if full_info:
184
+ return get_file_info(full_path, path)
185
+
186
+ rel_path = os.path.relpath(full_path, path).replace(os.sep, '/')
187
+ if split_path:
188
+ return [rel_path] + rel_path.split('/')
189
+
190
+ return rel_path
191
+
192
+ results = [
193
+ process_full_path(full_path)
194
+ for full_path in glob.glob(pattern, recursive=recurse)
195
+ if os.path.isfile(full_path)
196
+ ]
197
+
198
+ return web.json_response(results)
199
+
200
+ def get_user_data_path(request, check_exists = False, param = "file"):
201
+ file = request.match_info.get(param, None)
202
+ if not file:
203
+ return web.Response(status=400)
204
+
205
+ path = self.get_request_user_filepath(request, file)
206
+ if not path:
207
+ return web.Response(status=403)
208
+
209
+ if check_exists and not os.path.exists(path):
210
+ return web.Response(status=404)
211
+
212
+ return path
213
+
214
+ @routes.get("/userdata/{file}")
215
+ async def getuserdata(request):
216
+ path = get_user_data_path(request, check_exists=True)
217
+ if not isinstance(path, str):
218
+ return path
219
+
220
+ return web.FileResponse(path)
221
+
222
+ @routes.post("/userdata/{file}")
223
+ async def post_userdata(request):
224
+ """
225
+ Upload or update a user data file.
226
+
227
+ This endpoint handles file uploads to a user's data directory, with options for
228
+ controlling overwrite behavior and response format.
229
+
230
+ Query Parameters:
231
+ - overwrite (optional): If "false", prevents overwriting existing files. Defaults to "true".
232
+ - full_info (optional): If "true", returns detailed file information (path, size, modified time).
233
+ If "false", returns only the relative file path.
234
+
235
+ Path Parameters:
236
+ - file: The target file path (URL encoded if necessary).
237
+
238
+ Returns:
239
+ - 400: If 'file' parameter is missing.
240
+ - 403: If the requested path is not allowed.
241
+ - 409: If overwrite=false and the file already exists.
242
+ - 200: JSON response with either:
243
+ - Full file information (if full_info=true)
244
+ - Relative file path (if full_info=false)
245
+
246
+ The request body should contain the raw file content to be written.
247
+ """
248
+ path = get_user_data_path(request)
249
+ if not isinstance(path, str):
250
+ return path
251
+
252
+ overwrite = request.query.get("overwrite", 'true') != "false"
253
+ full_info = request.query.get('full_info', 'false').lower() == "true"
254
+
255
+ if not overwrite and os.path.exists(path):
256
+ return web.Response(status=409, text="File already exists")
257
+
258
+ body = await request.read()
259
+
260
+ with open(path, "wb") as f:
261
+ f.write(body)
262
+
263
+ user_path = self.get_request_user_filepath(request, None)
264
+ if full_info:
265
+ resp = get_file_info(path, user_path)
266
+ else:
267
+ resp = os.path.relpath(path, user_path)
268
+
269
+ return web.json_response(resp)
270
+
271
+ @routes.delete("/userdata/{file}")
272
+ async def delete_userdata(request):
273
+ path = get_user_data_path(request, check_exists=True)
274
+ if not isinstance(path, str):
275
+ return path
276
+
277
+ os.remove(path)
278
+
279
+ return web.Response(status=204)
280
+
281
+ @routes.post("/userdata/{file}/move/{dest}")
282
+ async def move_userdata(request):
283
+ """
284
+ Move or rename a user data file.
285
+
286
+ This endpoint handles moving or renaming files within a user's data directory, with options for
287
+ controlling overwrite behavior and response format.
288
+
289
+ Path Parameters:
290
+ - file: The source file path (URL encoded if necessary)
291
+ - dest: The destination file path (URL encoded if necessary)
292
+
293
+ Query Parameters:
294
+ - overwrite (optional): If "false", prevents overwriting existing files. Defaults to "true".
295
+ - full_info (optional): If "true", returns detailed file information (path, size, modified time).
296
+ If "false", returns only the relative file path.
297
+
298
+ Returns:
299
+ - 400: If either 'file' or 'dest' parameter is missing
300
+ - 403: If either requested path is not allowed
301
+ - 404: If the source file does not exist
302
+ - 409: If overwrite=false and the destination file already exists
303
+ - 200: JSON response with either:
304
+ - Full file information (if full_info=true)
305
+ - Relative file path (if full_info=false)
306
+ """
307
+ source = get_user_data_path(request, check_exists=True)
308
+ if not isinstance(source, str):
309
+ return source
310
+
311
+ dest = get_user_data_path(request, check_exists=False, param="dest")
312
+ if not isinstance(source, str):
313
+ return dest
314
+
315
+ overwrite = request.query.get("overwrite", 'true') != "false"
316
+ full_info = request.query.get('full_info', 'false').lower() == "true"
317
+
318
+ if not overwrite and os.path.exists(dest):
319
+ return web.Response(status=409, text="File already exists")
320
+
321
+ logging.info(f"moving '{source}' -> '{dest}'")
322
+ shutil.move(source, dest)
323
+
324
+ user_path = self.get_request_user_filepath(request, None)
325
+ if full_info:
326
+ resp = get_file_info(dest, user_path)
327
+ else:
328
+ resp = os.path.relpath(dest, user_path)
329
+
330
+ return web.json_response(resp)
execution.py ADDED
@@ -0,0 +1,994 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import copy
3
+ import logging
4
+ import threading
5
+ import heapq
6
+ import time
7
+ import traceback
8
+ from enum import Enum
9
+ import inspect
10
+ from typing import List, Literal, NamedTuple, Optional
11
+
12
+ import torch
13
+ import nodes
14
+
15
+ import comfy.model_management
16
+ from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
17
+ from comfy_execution.graph_utils import is_link, GraphBuilder
18
+ from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
19
+ from comfy_execution.validation import validate_node_input
20
+
21
+ class ExecutionResult(Enum):
22
+ SUCCESS = 0
23
+ FAILURE = 1
24
+ PENDING = 2
25
+
26
+ class DuplicateNodeError(Exception):
27
+ pass
28
+
29
+ class IsChangedCache:
30
+ def __init__(self, dynprompt, outputs_cache):
31
+ self.dynprompt = dynprompt
32
+ self.outputs_cache = outputs_cache
33
+ self.is_changed = {}
34
+
35
+ def get(self, node_id):
36
+ if node_id in self.is_changed:
37
+ return self.is_changed[node_id]
38
+
39
+ node = self.dynprompt.get_node(node_id)
40
+ class_type = node["class_type"]
41
+ class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
42
+ if not hasattr(class_def, "IS_CHANGED"):
43
+ self.is_changed[node_id] = False
44
+ return self.is_changed[node_id]
45
+
46
+ if "is_changed" in node:
47
+ self.is_changed[node_id] = node["is_changed"]
48
+ return self.is_changed[node_id]
49
+
50
+ # Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
51
+ input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, None)
52
+ try:
53
+ is_changed = _map_node_over_list(class_def, input_data_all, "IS_CHANGED")
54
+ node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
55
+ except Exception as e:
56
+ logging.warning("WARNING: {}".format(e))
57
+ node["is_changed"] = float("NaN")
58
+ finally:
59
+ self.is_changed[node_id] = node["is_changed"]
60
+ return self.is_changed[node_id]
61
+
62
+ class CacheSet:
63
+ def __init__(self, lru_size=None):
64
+ if lru_size is None or lru_size == 0:
65
+ self.init_classic_cache()
66
+ else:
67
+ self.init_lru_cache(lru_size)
68
+ self.all = [self.outputs, self.ui, self.objects]
69
+
70
+ # Useful for those with ample RAM/VRAM -- allows experimenting without
71
+ # blowing away the cache every time
72
+ def init_lru_cache(self, cache_size):
73
+ self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
74
+ self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
75
+ self.objects = HierarchicalCache(CacheKeySetID)
76
+
77
+ # Performs like the old cache -- dump data ASAP
78
+ def init_classic_cache(self):
79
+ self.outputs = HierarchicalCache(CacheKeySetInputSignature)
80
+ self.ui = HierarchicalCache(CacheKeySetInputSignature)
81
+ self.objects = HierarchicalCache(CacheKeySetID)
82
+
83
+ def recursive_debug_dump(self):
84
+ result = {
85
+ "outputs": self.outputs.recursive_debug_dump(),
86
+ "ui": self.ui.recursive_debug_dump(),
87
+ }
88
+ return result
89
+
90
+ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}):
91
+ valid_inputs = class_def.INPUT_TYPES()
92
+ input_data_all = {}
93
+ missing_keys = {}
94
+ for x in inputs:
95
+ input_data = inputs[x]
96
+ input_type, input_category, input_info = get_input_info(class_def, x, valid_inputs)
97
+ def mark_missing():
98
+ missing_keys[x] = True
99
+ input_data_all[x] = (None,)
100
+ if is_link(input_data) and (not input_info or not input_info.get("rawLink", False)):
101
+ input_unique_id = input_data[0]
102
+ output_index = input_data[1]
103
+ if outputs is None:
104
+ mark_missing()
105
+ continue # This might be a lazily-evaluated input
106
+ cached_output = outputs.get(input_unique_id)
107
+ if cached_output is None:
108
+ mark_missing()
109
+ continue
110
+ if output_index >= len(cached_output):
111
+ mark_missing()
112
+ continue
113
+ obj = cached_output[output_index]
114
+ input_data_all[x] = obj
115
+ elif input_category is not None:
116
+ input_data_all[x] = [input_data]
117
+
118
+ if "hidden" in valid_inputs:
119
+ h = valid_inputs["hidden"]
120
+ for x in h:
121
+ if h[x] == "PROMPT":
122
+ input_data_all[x] = [dynprompt.get_original_prompt() if dynprompt is not None else {}]
123
+ if h[x] == "DYNPROMPT":
124
+ input_data_all[x] = [dynprompt]
125
+ if h[x] == "EXTRA_PNGINFO":
126
+ input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
127
+ if h[x] == "UNIQUE_ID":
128
+ input_data_all[x] = [unique_id]
129
+ return input_data_all, missing_keys
130
+
131
+ map_node_over_list = None #Don't hook this please
132
+
133
+ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
134
+ # check if node wants the lists
135
+ input_is_list = getattr(obj, "INPUT_IS_LIST", False)
136
+
137
+ if len(input_data_all) == 0:
138
+ max_len_input = 0
139
+ else:
140
+ max_len_input = max(len(x) for x in input_data_all.values())
141
+
142
+ # get a slice of inputs, repeat last input when list isn't long enough
143
+ def slice_dict(d, i):
144
+ return {k: v[i if len(v) > i else -1] for k, v in d.items()}
145
+
146
+ results = []
147
+ def process_inputs(inputs, index=None, input_is_list=False):
148
+ if allow_interrupt:
149
+ nodes.before_node_execution()
150
+ execution_block = None
151
+ for k, v in inputs.items():
152
+ if input_is_list:
153
+ for e in v:
154
+ if isinstance(e, ExecutionBlocker):
155
+ v = e
156
+ break
157
+ if isinstance(v, ExecutionBlocker):
158
+ execution_block = execution_block_cb(v) if execution_block_cb else v
159
+ break
160
+ if execution_block is None:
161
+ if pre_execute_cb is not None and index is not None:
162
+ pre_execute_cb(index)
163
+ results.append(getattr(obj, func)(**inputs))
164
+ else:
165
+ results.append(execution_block)
166
+
167
+ if input_is_list:
168
+ process_inputs(input_data_all, 0, input_is_list=input_is_list)
169
+ elif max_len_input == 0:
170
+ process_inputs({})
171
+ else:
172
+ for i in range(max_len_input):
173
+ input_dict = slice_dict(input_data_all, i)
174
+ process_inputs(input_dict, i)
175
+ return results
176
+
177
+ def merge_result_data(results, obj):
178
+ # check which outputs need concatenating
179
+ output = []
180
+ output_is_list = [False] * len(results[0])
181
+ if hasattr(obj, "OUTPUT_IS_LIST"):
182
+ output_is_list = obj.OUTPUT_IS_LIST
183
+
184
+ # merge node execution results
185
+ for i, is_list in zip(range(len(results[0])), output_is_list):
186
+ if is_list:
187
+ value = []
188
+ for o in results:
189
+ if isinstance(o[i], ExecutionBlocker):
190
+ value.append(o[i])
191
+ else:
192
+ value.extend(o[i])
193
+ output.append(value)
194
+ else:
195
+ output.append([o[i] for o in results])
196
+ return output
197
+
198
+ def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
199
+ results = []
200
+ uis = []
201
+ subgraph_results = []
202
+ return_values = _map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
203
+ has_subgraph = False
204
+ for i in range(len(return_values)):
205
+ r = return_values[i]
206
+ if isinstance(r, dict):
207
+ if 'ui' in r:
208
+ uis.append(r['ui'])
209
+ if 'expand' in r:
210
+ # Perform an expansion, but do not append results
211
+ has_subgraph = True
212
+ new_graph = r['expand']
213
+ result = r.get("result", None)
214
+ if isinstance(result, ExecutionBlocker):
215
+ result = tuple([result] * len(obj.RETURN_TYPES))
216
+ subgraph_results.append((new_graph, result))
217
+ elif 'result' in r:
218
+ result = r.get("result", None)
219
+ if isinstance(result, ExecutionBlocker):
220
+ result = tuple([result] * len(obj.RETURN_TYPES))
221
+ results.append(result)
222
+ subgraph_results.append((None, result))
223
+ else:
224
+ if isinstance(r, ExecutionBlocker):
225
+ r = tuple([r] * len(obj.RETURN_TYPES))
226
+ results.append(r)
227
+ subgraph_results.append((None, r))
228
+
229
+ if has_subgraph:
230
+ output = subgraph_results
231
+ elif len(results) > 0:
232
+ output = merge_result_data(results, obj)
233
+ else:
234
+ output = []
235
+ ui = dict()
236
+ if len(uis) > 0:
237
+ ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
238
+ return output, ui, has_subgraph
239
+
240
+ def format_value(x):
241
+ if x is None:
242
+ return None
243
+ elif isinstance(x, (int, float, bool, str)):
244
+ return x
245
+ else:
246
+ return str(x)
247
+
248
+ def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results):
249
+ unique_id = current_item
250
+ real_node_id = dynprompt.get_real_node_id(unique_id)
251
+ display_node_id = dynprompt.get_display_node_id(unique_id)
252
+ parent_node_id = dynprompt.get_parent_node_id(unique_id)
253
+ inputs = dynprompt.get_node(unique_id)['inputs']
254
+ class_type = dynprompt.get_node(unique_id)['class_type']
255
+ class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
256
+ if caches.outputs.get(unique_id) is not None:
257
+ if server.client_id is not None:
258
+ cached_output = caches.ui.get(unique_id) or {}
259
+ server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
260
+ return (ExecutionResult.SUCCESS, None, None)
261
+
262
+ input_data_all = None
263
+ try:
264
+ if unique_id in pending_subgraph_results:
265
+ cached_results = pending_subgraph_results[unique_id]
266
+ resolved_outputs = []
267
+ for is_subgraph, result in cached_results:
268
+ if not is_subgraph:
269
+ resolved_outputs.append(result)
270
+ else:
271
+ resolved_output = []
272
+ for r in result:
273
+ if is_link(r):
274
+ source_node, source_output = r[0], r[1]
275
+ node_output = caches.outputs.get(source_node)[source_output]
276
+ for o in node_output:
277
+ resolved_output.append(o)
278
+
279
+ else:
280
+ resolved_output.append(r)
281
+ resolved_outputs.append(tuple(resolved_output))
282
+ output_data = merge_result_data(resolved_outputs, class_def)
283
+ output_ui = []
284
+ has_subgraph = False
285
+ else:
286
+ input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
287
+ if server.client_id is not None:
288
+ server.last_node_id = display_node_id
289
+ server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
290
+
291
+ obj = caches.objects.get(unique_id)
292
+ if obj is None:
293
+ obj = class_def()
294
+ caches.objects.set(unique_id, obj)
295
+
296
+ if hasattr(obj, "check_lazy_status"):
297
+ required_inputs = _map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True)
298
+ required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
299
+ required_inputs = [x for x in required_inputs if isinstance(x,str) and (
300
+ x not in input_data_all or x in missing_keys
301
+ )]
302
+ if len(required_inputs) > 0:
303
+ for i in required_inputs:
304
+ execution_list.make_input_strong_link(unique_id, i)
305
+ return (ExecutionResult.PENDING, None, None)
306
+
307
+ def execution_block_cb(block):
308
+ if block.message is not None:
309
+ mes = {
310
+ "prompt_id": prompt_id,
311
+ "node_id": unique_id,
312
+ "node_type": class_type,
313
+ "executed": list(executed),
314
+
315
+ "exception_message": f"Execution Blocked: {block.message}",
316
+ "exception_type": "ExecutionBlocked",
317
+ "traceback": [],
318
+ "current_inputs": [],
319
+ "current_outputs": [],
320
+ }
321
+ server.send_sync("execution_error", mes, server.client_id)
322
+ return ExecutionBlocker(None)
323
+ else:
324
+ return block
325
+ def pre_execute_cb(call_index):
326
+ GraphBuilder.set_default_prefix(unique_id, call_index, 0)
327
+ output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
328
+ if len(output_ui) > 0:
329
+ caches.ui.set(unique_id, {
330
+ "meta": {
331
+ "node_id": unique_id,
332
+ "display_node": display_node_id,
333
+ "parent_node": parent_node_id,
334
+ "real_node_id": real_node_id,
335
+ },
336
+ "output": output_ui
337
+ })
338
+ if server.client_id is not None:
339
+ server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
340
+ if has_subgraph:
341
+ cached_outputs = []
342
+ new_node_ids = []
343
+ new_output_ids = []
344
+ new_output_links = []
345
+ for i in range(len(output_data)):
346
+ new_graph, node_outputs = output_data[i]
347
+ if new_graph is None:
348
+ cached_outputs.append((False, node_outputs))
349
+ else:
350
+ # Check for conflicts
351
+ for node_id in new_graph.keys():
352
+ if dynprompt.has_node(node_id):
353
+ raise DuplicateNodeError(f"Attempt to add duplicate node {node_id}. Ensure node ids are unique and deterministic or use graph_utils.GraphBuilder.")
354
+ for node_id, node_info in new_graph.items():
355
+ new_node_ids.append(node_id)
356
+ display_id = node_info.get("override_display_id", unique_id)
357
+ dynprompt.add_ephemeral_node(node_id, node_info, unique_id, display_id)
358
+ # Figure out if the newly created node is an output node
359
+ class_type = node_info["class_type"]
360
+ class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
361
+ if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
362
+ new_output_ids.append(node_id)
363
+ for i in range(len(node_outputs)):
364
+ if is_link(node_outputs[i]):
365
+ from_node_id, from_socket = node_outputs[i][0], node_outputs[i][1]
366
+ new_output_links.append((from_node_id, from_socket))
367
+ cached_outputs.append((True, node_outputs))
368
+ new_node_ids = set(new_node_ids)
369
+ for cache in caches.all:
370
+ cache.ensure_subcache_for(unique_id, new_node_ids).clean_unused()
371
+ for node_id in new_output_ids:
372
+ execution_list.add_node(node_id)
373
+ for link in new_output_links:
374
+ execution_list.add_strong_link(link[0], link[1], unique_id)
375
+ pending_subgraph_results[unique_id] = cached_outputs
376
+ return (ExecutionResult.PENDING, None, None)
377
+ caches.outputs.set(unique_id, output_data)
378
+ except comfy.model_management.InterruptProcessingException as iex:
379
+ logging.info("Processing interrupted")
380
+
381
+ # skip formatting inputs/outputs
382
+ error_details = {
383
+ "node_id": real_node_id,
384
+ }
385
+
386
+ return (ExecutionResult.FAILURE, error_details, iex)
387
+ except Exception as ex:
388
+ typ, _, tb = sys.exc_info()
389
+ exception_type = full_type_name(typ)
390
+ input_data_formatted = {}
391
+ if input_data_all is not None:
392
+ input_data_formatted = {}
393
+ for name, inputs in input_data_all.items():
394
+ input_data_formatted[name] = [format_value(x) for x in inputs]
395
+
396
+ logging.error(f"!!! Exception during processing !!! {ex}")
397
+ logging.error(traceback.format_exc())
398
+
399
+ error_details = {
400
+ "node_id": real_node_id,
401
+ "exception_message": str(ex),
402
+ "exception_type": exception_type,
403
+ "traceback": traceback.format_tb(tb),
404
+ "current_inputs": input_data_formatted
405
+ }
406
+ if isinstance(ex, comfy.model_management.OOM_EXCEPTION):
407
+ logging.error("Got an OOM, unloading all loaded models.")
408
+ comfy.model_management.unload_all_models()
409
+
410
+ return (ExecutionResult.FAILURE, error_details, ex)
411
+
412
+ executed.add(unique_id)
413
+
414
+ return (ExecutionResult.SUCCESS, None, None)
415
+
416
+ class PromptExecutor:
417
+ def __init__(self, server, lru_size=None):
418
+ self.lru_size = lru_size
419
+ self.server = server
420
+ self.reset()
421
+
422
+ def reset(self):
423
+ self.caches = CacheSet(self.lru_size)
424
+ self.status_messages = []
425
+ self.success = True
426
+
427
+ def add_message(self, event, data: dict, broadcast: bool):
428
+ data = {
429
+ **data,
430
+ "timestamp": int(time.time() * 1000),
431
+ }
432
+ self.status_messages.append((event, data))
433
+ if self.server.client_id is not None or broadcast:
434
+ self.server.send_sync(event, data, self.server.client_id)
435
+
436
+ def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
437
+ node_id = error["node_id"]
438
+ class_type = prompt[node_id]["class_type"]
439
+
440
+ # First, send back the status to the frontend depending
441
+ # on the exception type
442
+ if isinstance(ex, comfy.model_management.InterruptProcessingException):
443
+ mes = {
444
+ "prompt_id": prompt_id,
445
+ "node_id": node_id,
446
+ "node_type": class_type,
447
+ "executed": list(executed),
448
+ }
449
+ self.add_message("execution_interrupted", mes, broadcast=True)
450
+ else:
451
+ mes = {
452
+ "prompt_id": prompt_id,
453
+ "node_id": node_id,
454
+ "node_type": class_type,
455
+ "executed": list(executed),
456
+ "exception_message": error["exception_message"],
457
+ "exception_type": error["exception_type"],
458
+ "traceback": error["traceback"],
459
+ "current_inputs": error["current_inputs"],
460
+ "current_outputs": list(current_outputs),
461
+ }
462
+ self.add_message("execution_error", mes, broadcast=False)
463
+
464
+ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
465
+ nodes.interrupt_processing(False)
466
+
467
+ if "client_id" in extra_data:
468
+ self.server.client_id = extra_data["client_id"]
469
+ else:
470
+ self.server.client_id = None
471
+
472
+ self.status_messages = []
473
+ self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
474
+
475
+ with torch.inference_mode():
476
+ dynamic_prompt = DynamicPrompt(prompt)
477
+ is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs)
478
+ for cache in self.caches.all:
479
+ cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
480
+ cache.clean_unused()
481
+
482
+ cached_nodes = []
483
+ for node_id in prompt:
484
+ if self.caches.outputs.get(node_id) is not None:
485
+ cached_nodes.append(node_id)
486
+
487
+ comfy.model_management.cleanup_models_gc()
488
+ self.add_message("execution_cached",
489
+ { "nodes": cached_nodes, "prompt_id": prompt_id},
490
+ broadcast=False)
491
+ pending_subgraph_results = {}
492
+ executed = set()
493
+ execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
494
+ current_outputs = self.caches.outputs.all_node_ids()
495
+ for node_id in list(execute_outputs):
496
+ execution_list.add_node(node_id)
497
+
498
+ while not execution_list.is_empty():
499
+ node_id, error, ex = execution_list.stage_node_execution()
500
+ if error is not None:
501
+ self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
502
+ break
503
+
504
+ result, error, ex = execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results)
505
+ self.success = result != ExecutionResult.FAILURE
506
+ if result == ExecutionResult.FAILURE:
507
+ self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
508
+ break
509
+ elif result == ExecutionResult.PENDING:
510
+ execution_list.unstage_node_execution()
511
+ else: # result == ExecutionResult.SUCCESS:
512
+ execution_list.complete_node_execution()
513
+ else:
514
+ # Only execute when the while-loop ends without break
515
+ self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
516
+
517
+ ui_outputs = {}
518
+ meta_outputs = {}
519
+ all_node_ids = self.caches.ui.all_node_ids()
520
+ for node_id in all_node_ids:
521
+ ui_info = self.caches.ui.get(node_id)
522
+ if ui_info is not None:
523
+ ui_outputs[node_id] = ui_info["output"]
524
+ meta_outputs[node_id] = ui_info["meta"]
525
+ self.history_result = {
526
+ "outputs": ui_outputs,
527
+ "meta": meta_outputs,
528
+ }
529
+ self.server.last_node_id = None
530
+ if comfy.model_management.DISABLE_SMART_MEMORY:
531
+ comfy.model_management.unload_all_models()
532
+
533
+
534
+ def validate_inputs(prompt, item, validated):
535
+ unique_id = item
536
+ if unique_id in validated:
537
+ return validated[unique_id]
538
+
539
+ inputs = prompt[unique_id]['inputs']
540
+ class_type = prompt[unique_id]['class_type']
541
+ obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
542
+
543
+ class_inputs = obj_class.INPUT_TYPES()
544
+ valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{})))
545
+
546
+ errors = []
547
+ valid = True
548
+
549
+ validate_function_inputs = []
550
+ validate_has_kwargs = False
551
+ if hasattr(obj_class, "VALIDATE_INPUTS"):
552
+ argspec = inspect.getfullargspec(obj_class.VALIDATE_INPUTS)
553
+ validate_function_inputs = argspec.args
554
+ validate_has_kwargs = argspec.varkw is not None
555
+ received_types = {}
556
+
557
+ for x in valid_inputs:
558
+ type_input, input_category, extra_info = get_input_info(obj_class, x, class_inputs)
559
+ assert extra_info is not None
560
+ if x not in inputs:
561
+ if input_category == "required":
562
+ error = {
563
+ "type": "required_input_missing",
564
+ "message": "Required input is missing",
565
+ "details": f"{x}",
566
+ "extra_info": {
567
+ "input_name": x
568
+ }
569
+ }
570
+ errors.append(error)
571
+ continue
572
+
573
+ val = inputs[x]
574
+ info = (type_input, extra_info)
575
+ if isinstance(val, list):
576
+ if len(val) != 2:
577
+ error = {
578
+ "type": "bad_linked_input",
579
+ "message": "Bad linked input, must be a length-2 list of [node_id, slot_index]",
580
+ "details": f"{x}",
581
+ "extra_info": {
582
+ "input_name": x,
583
+ "input_config": info,
584
+ "received_value": val
585
+ }
586
+ }
587
+ errors.append(error)
588
+ continue
589
+
590
+ o_id = val[0]
591
+ o_class_type = prompt[o_id]['class_type']
592
+ r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
593
+ received_type = r[val[1]]
594
+ received_types[x] = received_type
595
+ if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, type_input):
596
+ details = f"{x}, received_type({received_type}) mismatch input_type({type_input})"
597
+ error = {
598
+ "type": "return_type_mismatch",
599
+ "message": "Return type mismatch between linked nodes",
600
+ "details": details,
601
+ "extra_info": {
602
+ "input_name": x,
603
+ "input_config": info,
604
+ "received_type": received_type,
605
+ "linked_node": val
606
+ }
607
+ }
608
+ errors.append(error)
609
+ continue
610
+ try:
611
+ r = validate_inputs(prompt, o_id, validated)
612
+ if r[0] is False:
613
+ # `r` will be set in `validated[o_id]` already
614
+ valid = False
615
+ continue
616
+ except Exception as ex:
617
+ typ, _, tb = sys.exc_info()
618
+ valid = False
619
+ exception_type = full_type_name(typ)
620
+ reasons = [{
621
+ "type": "exception_during_inner_validation",
622
+ "message": "Exception when validating inner node",
623
+ "details": str(ex),
624
+ "extra_info": {
625
+ "input_name": x,
626
+ "input_config": info,
627
+ "exception_message": str(ex),
628
+ "exception_type": exception_type,
629
+ "traceback": traceback.format_tb(tb),
630
+ "linked_node": val
631
+ }
632
+ }]
633
+ validated[o_id] = (False, reasons, o_id)
634
+ continue
635
+ else:
636
+ try:
637
+ if type_input == "INT":
638
+ val = int(val)
639
+ inputs[x] = val
640
+ if type_input == "FLOAT":
641
+ val = float(val)
642
+ inputs[x] = val
643
+ if type_input == "STRING":
644
+ val = str(val)
645
+ inputs[x] = val
646
+ if type_input == "BOOLEAN":
647
+ val = bool(val)
648
+ inputs[x] = val
649
+ except Exception as ex:
650
+ error = {
651
+ "type": "invalid_input_type",
652
+ "message": f"Failed to convert an input value to a {type_input} value",
653
+ "details": f"{x}, {val}, {ex}",
654
+ "extra_info": {
655
+ "input_name": x,
656
+ "input_config": info,
657
+ "received_value": val,
658
+ "exception_message": str(ex)
659
+ }
660
+ }
661
+ errors.append(error)
662
+ continue
663
+
664
+ if x not in validate_function_inputs and not validate_has_kwargs:
665
+ if "min" in extra_info and val < extra_info["min"]:
666
+ error = {
667
+ "type": "value_smaller_than_min",
668
+ "message": "Value {} smaller than min of {}".format(val, extra_info["min"]),
669
+ "details": f"{x}",
670
+ "extra_info": {
671
+ "input_name": x,
672
+ "input_config": info,
673
+ "received_value": val,
674
+ }
675
+ }
676
+ errors.append(error)
677
+ continue
678
+ if "max" in extra_info and val > extra_info["max"]:
679
+ error = {
680
+ "type": "value_bigger_than_max",
681
+ "message": "Value {} bigger than max of {}".format(val, extra_info["max"]),
682
+ "details": f"{x}",
683
+ "extra_info": {
684
+ "input_name": x,
685
+ "input_config": info,
686
+ "received_value": val,
687
+ }
688
+ }
689
+ errors.append(error)
690
+ continue
691
+
692
+ if isinstance(type_input, list):
693
+ if val not in type_input:
694
+ input_config = info
695
+ list_info = ""
696
+
697
+ # Don't send back gigantic lists like if they're lots of
698
+ # scanned model filepaths
699
+ if len(type_input) > 20:
700
+ list_info = f"(list of length {len(type_input)})"
701
+ input_config = None
702
+ else:
703
+ list_info = str(type_input)
704
+
705
+ error = {
706
+ "type": "value_not_in_list",
707
+ "message": "Value not in list",
708
+ "details": f"{x}: '{val}' not in {list_info}",
709
+ "extra_info": {
710
+ "input_name": x,
711
+ "input_config": input_config,
712
+ "received_value": val,
713
+ }
714
+ }
715
+ errors.append(error)
716
+ continue
717
+
718
+ if len(validate_function_inputs) > 0 or validate_has_kwargs:
719
+ input_data_all, _ = get_input_data(inputs, obj_class, unique_id)
720
+ input_filtered = {}
721
+ for x in input_data_all:
722
+ if x in validate_function_inputs or validate_has_kwargs:
723
+ input_filtered[x] = input_data_all[x]
724
+ if 'input_types' in validate_function_inputs:
725
+ input_filtered['input_types'] = [received_types]
726
+
727
+ #ret = obj_class.VALIDATE_INPUTS(**input_filtered)
728
+ ret = _map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")
729
+ for x in input_filtered:
730
+ for i, r in enumerate(ret):
731
+ if r is not True and not isinstance(r, ExecutionBlocker):
732
+ details = f"{x}"
733
+ if r is not False:
734
+ details += f" - {str(r)}"
735
+
736
+ error = {
737
+ "type": "custom_validation_failed",
738
+ "message": "Custom validation failed for node",
739
+ "details": details,
740
+ "extra_info": {
741
+ "input_name": x,
742
+ }
743
+ }
744
+ errors.append(error)
745
+ continue
746
+
747
+ if len(errors) > 0 or valid is not True:
748
+ ret = (False, errors, unique_id)
749
+ else:
750
+ ret = (True, [], unique_id)
751
+
752
+ validated[unique_id] = ret
753
+ return ret
754
+
755
+ def full_type_name(klass):
756
+ module = klass.__module__
757
+ if module == 'builtins':
758
+ return klass.__qualname__
759
+ return module + '.' + klass.__qualname__
760
+
761
+ def validate_prompt(prompt):
762
+ outputs = set()
763
+ for x in prompt:
764
+ if 'class_type' not in prompt[x]:
765
+ error = {
766
+ "type": "invalid_prompt",
767
+ "message": "Cannot execute because a node is missing the class_type property.",
768
+ "details": f"Node ID '#{x}'",
769
+ "extra_info": {}
770
+ }
771
+ return (False, error, [], [])
772
+
773
+ class_type = prompt[x]['class_type']
774
+ class_ = nodes.NODE_CLASS_MAPPINGS.get(class_type, None)
775
+ if class_ is None:
776
+ error = {
777
+ "type": "invalid_prompt",
778
+ "message": f"Cannot execute because node {class_type} does not exist.",
779
+ "details": f"Node ID '#{x}'",
780
+ "extra_info": {}
781
+ }
782
+ return (False, error, [], [])
783
+
784
+ if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
785
+ outputs.add(x)
786
+
787
+ if len(outputs) == 0:
788
+ error = {
789
+ "type": "prompt_no_outputs",
790
+ "message": "Prompt has no outputs",
791
+ "details": "",
792
+ "extra_info": {}
793
+ }
794
+ return (False, error, [], [])
795
+
796
+ good_outputs = set()
797
+ errors = []
798
+ node_errors = {}
799
+ validated = {}
800
+ for o in outputs:
801
+ valid = False
802
+ reasons = []
803
+ try:
804
+ m = validate_inputs(prompt, o, validated)
805
+ valid = m[0]
806
+ reasons = m[1]
807
+ except Exception as ex:
808
+ typ, _, tb = sys.exc_info()
809
+ valid = False
810
+ exception_type = full_type_name(typ)
811
+ reasons = [{
812
+ "type": "exception_during_validation",
813
+ "message": "Exception when validating node",
814
+ "details": str(ex),
815
+ "extra_info": {
816
+ "exception_type": exception_type,
817
+ "traceback": traceback.format_tb(tb)
818
+ }
819
+ }]
820
+ validated[o] = (False, reasons, o)
821
+
822
+ if valid is True:
823
+ good_outputs.add(o)
824
+ else:
825
+ logging.error(f"Failed to validate prompt for output {o}:")
826
+ if len(reasons) > 0:
827
+ logging.error("* (prompt):")
828
+ for reason in reasons:
829
+ logging.error(f" - {reason['message']}: {reason['details']}")
830
+ errors += [(o, reasons)]
831
+ for node_id, result in validated.items():
832
+ valid = result[0]
833
+ reasons = result[1]
834
+ # If a node upstream has errors, the nodes downstream will also
835
+ # be reported as invalid, but there will be no errors attached.
836
+ # So don't return those nodes as having errors in the response.
837
+ if valid is not True and len(reasons) > 0:
838
+ if node_id not in node_errors:
839
+ class_type = prompt[node_id]['class_type']
840
+ node_errors[node_id] = {
841
+ "errors": reasons,
842
+ "dependent_outputs": [],
843
+ "class_type": class_type
844
+ }
845
+ logging.error(f"* {class_type} {node_id}:")
846
+ for reason in reasons:
847
+ logging.error(f" - {reason['message']}: {reason['details']}")
848
+ node_errors[node_id]["dependent_outputs"].append(o)
849
+ logging.error("Output will be ignored")
850
+
851
+ if len(good_outputs) == 0:
852
+ errors_list = []
853
+ for o, errors in errors:
854
+ for error in errors:
855
+ errors_list.append(f"{error['message']}: {error['details']}")
856
+ errors_list = "\n".join(errors_list)
857
+
858
+ error = {
859
+ "type": "prompt_outputs_failed_validation",
860
+ "message": "Prompt outputs failed validation",
861
+ "details": errors_list,
862
+ "extra_info": {}
863
+ }
864
+
865
+ return (False, error, list(good_outputs), node_errors)
866
+
867
+ return (True, None, list(good_outputs), node_errors)
868
+
869
+ MAXIMUM_HISTORY_SIZE = 10000
870
+
871
+ class PromptQueue:
872
+ def __init__(self, server):
873
+ self.server = server
874
+ self.mutex = threading.RLock()
875
+ self.not_empty = threading.Condition(self.mutex)
876
+ self.task_counter = 0
877
+ self.queue = []
878
+ self.currently_running = {}
879
+ self.history = {}
880
+ self.flags = {}
881
+ server.prompt_queue = self
882
+
883
+ def put(self, item):
884
+ with self.mutex:
885
+ heapq.heappush(self.queue, item)
886
+ self.server.queue_updated()
887
+ self.not_empty.notify()
888
+
889
+ def get(self, timeout=None):
890
+ with self.not_empty:
891
+ while len(self.queue) == 0:
892
+ self.not_empty.wait(timeout=timeout)
893
+ if timeout is not None and len(self.queue) == 0:
894
+ return None
895
+ item = heapq.heappop(self.queue)
896
+ i = self.task_counter
897
+ self.currently_running[i] = copy.deepcopy(item)
898
+ self.task_counter += 1
899
+ self.server.queue_updated()
900
+ return (item, i)
901
+
902
+ class ExecutionStatus(NamedTuple):
903
+ status_str: Literal['success', 'error']
904
+ completed: bool
905
+ messages: List[str]
906
+
907
+ def task_done(self, item_id, history_result,
908
+ status: Optional['PromptQueue.ExecutionStatus']):
909
+ with self.mutex:
910
+ prompt = self.currently_running.pop(item_id)
911
+ if len(self.history) > MAXIMUM_HISTORY_SIZE:
912
+ self.history.pop(next(iter(self.history)))
913
+
914
+ status_dict: Optional[dict] = None
915
+ if status is not None:
916
+ status_dict = copy.deepcopy(status._asdict())
917
+
918
+ self.history[prompt[1]] = {
919
+ "prompt": prompt,
920
+ "outputs": {},
921
+ 'status': status_dict,
922
+ }
923
+ self.history[prompt[1]].update(history_result)
924
+ self.server.queue_updated()
925
+
926
+ def get_current_queue(self):
927
+ with self.mutex:
928
+ out = []
929
+ for x in self.currently_running.values():
930
+ out += [x]
931
+ return (out, copy.deepcopy(self.queue))
932
+
933
+ def get_tasks_remaining(self):
934
+ with self.mutex:
935
+ return len(self.queue) + len(self.currently_running)
936
+
937
+ def wipe_queue(self):
938
+ with self.mutex:
939
+ self.queue = []
940
+ self.server.queue_updated()
941
+
942
+ def delete_queue_item(self, function):
943
+ with self.mutex:
944
+ for x in range(len(self.queue)):
945
+ if function(self.queue[x]):
946
+ if len(self.queue) == 1:
947
+ self.wipe_queue()
948
+ else:
949
+ self.queue.pop(x)
950
+ heapq.heapify(self.queue)
951
+ self.server.queue_updated()
952
+ return True
953
+ return False
954
+
955
+ def get_history(self, prompt_id=None, max_items=None, offset=-1):
956
+ with self.mutex:
957
+ if prompt_id is None:
958
+ out = {}
959
+ i = 0
960
+ if offset < 0 and max_items is not None:
961
+ offset = len(self.history) - max_items
962
+ for k in self.history:
963
+ if i >= offset:
964
+ out[k] = self.history[k]
965
+ if max_items is not None and len(out) >= max_items:
966
+ break
967
+ i += 1
968
+ return out
969
+ elif prompt_id in self.history:
970
+ return {prompt_id: copy.deepcopy(self.history[prompt_id])}
971
+ else:
972
+ return {}
973
+
974
+ def wipe_history(self):
975
+ with self.mutex:
976
+ self.history = {}
977
+
978
+ def delete_history_item(self, id_to_delete):
979
+ with self.mutex:
980
+ self.history.pop(id_to_delete, None)
981
+
982
+ def set_flag(self, name, data):
983
+ with self.mutex:
984
+ self.flags[name] = data
985
+ self.not_empty.notify()
986
+
987
+ def get_flags(self, reset=True):
988
+ with self.mutex:
989
+ if reset:
990
+ ret = self.flags
991
+ self.flags = {}
992
+ return ret
993
+ else:
994
+ return self.flags.copy()
nodes.py ADDED
@@ -0,0 +1,2258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import torch
3
+
4
+ import os
5
+ import sys
6
+ import json
7
+ import hashlib
8
+ import traceback
9
+ import math
10
+ import time
11
+ import random
12
+ import logging
13
+
14
+ from PIL import Image, ImageOps, ImageSequence
15
+ from PIL.PngImagePlugin import PngInfo
16
+
17
+ import numpy as np
18
+ import safetensors.torch
19
+
20
+ sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
21
+
22
+ import comfy.diffusers_load
23
+ import comfy.samplers
24
+ import comfy.sample
25
+ import comfy.sd
26
+ import comfy.utils
27
+ import comfy.controlnet
28
+ from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
29
+
30
+ import comfy.clip_vision
31
+
32
+ import comfy.model_management
33
+ from comfy.cli_args import args
34
+
35
+ import importlib
36
+
37
+ import folder_paths
38
+ import latent_preview
39
+ import node_helpers
40
+
41
+ def before_node_execution():
42
+ comfy.model_management.throw_exception_if_processing_interrupted()
43
+
44
+ def interrupt_processing(value=True):
45
+ comfy.model_management.interrupt_current_processing(value)
46
+
47
+ MAX_RESOLUTION=16384
48
+
49
+ class CLIPTextEncode(ComfyNodeABC):
50
+ @classmethod
51
+ def INPUT_TYPES(s) -> InputTypeDict:
52
+ return {
53
+ "required": {
54
+ "text": (IO.STRING, {"multiline": True, "dynamicPrompts": True, "tooltip": "The text to be encoded."}),
55
+ "clip": (IO.CLIP, {"tooltip": "The CLIP model used for encoding the text."})
56
+ }
57
+ }
58
+ RETURN_TYPES = (IO.CONDITIONING,)
59
+ OUTPUT_TOOLTIPS = ("A conditioning containing the embedded text used to guide the diffusion model.",)
60
+ FUNCTION = "encode"
61
+
62
+ CATEGORY = "conditioning"
63
+ DESCRIPTION = "Encodes a text prompt using a CLIP model into an embedding that can be used to guide the diffusion model towards generating specific images."
64
+
65
+ def encode(self, clip, text):
66
+ tokens = clip.tokenize(text)
67
+ return (clip.encode_from_tokens_scheduled(tokens), )
68
+
69
+
70
+ class ConditioningCombine:
71
+ @classmethod
72
+ def INPUT_TYPES(s):
73
+ return {"required": {"conditioning_1": ("CONDITIONING", ), "conditioning_2": ("CONDITIONING", )}}
74
+ RETURN_TYPES = ("CONDITIONING",)
75
+ FUNCTION = "combine"
76
+
77
+ CATEGORY = "conditioning"
78
+
79
+ def combine(self, conditioning_1, conditioning_2):
80
+ return (conditioning_1 + conditioning_2, )
81
+
82
+ class ConditioningAverage :
83
+ @classmethod
84
+ def INPUT_TYPES(s):
85
+ return {"required": {"conditioning_to": ("CONDITIONING", ), "conditioning_from": ("CONDITIONING", ),
86
+ "conditioning_to_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
87
+ }}
88
+ RETURN_TYPES = ("CONDITIONING",)
89
+ FUNCTION = "addWeighted"
90
+
91
+ CATEGORY = "conditioning"
92
+
93
+ def addWeighted(self, conditioning_to, conditioning_from, conditioning_to_strength):
94
+ out = []
95
+
96
+ if len(conditioning_from) > 1:
97
+ logging.warning("Warning: ConditioningAverage conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")
98
+
99
+ cond_from = conditioning_from[0][0]
100
+ pooled_output_from = conditioning_from[0][1].get("pooled_output", None)
101
+
102
+ for i in range(len(conditioning_to)):
103
+ t1 = conditioning_to[i][0]
104
+ pooled_output_to = conditioning_to[i][1].get("pooled_output", pooled_output_from)
105
+ t0 = cond_from[:,:t1.shape[1]]
106
+ if t0.shape[1] < t1.shape[1]:
107
+ t0 = torch.cat([t0] + [torch.zeros((1, (t1.shape[1] - t0.shape[1]), t1.shape[2]))], dim=1)
108
+
109
+ tw = torch.mul(t1, conditioning_to_strength) + torch.mul(t0, (1.0 - conditioning_to_strength))
110
+ t_to = conditioning_to[i][1].copy()
111
+ if pooled_output_from is not None and pooled_output_to is not None:
112
+ t_to["pooled_output"] = torch.mul(pooled_output_to, conditioning_to_strength) + torch.mul(pooled_output_from, (1.0 - conditioning_to_strength))
113
+ elif pooled_output_from is not None:
114
+ t_to["pooled_output"] = pooled_output_from
115
+
116
+ n = [tw, t_to]
117
+ out.append(n)
118
+ return (out, )
119
+
120
+ class ConditioningConcat:
121
+ @classmethod
122
+ def INPUT_TYPES(s):
123
+ return {"required": {
124
+ "conditioning_to": ("CONDITIONING",),
125
+ "conditioning_from": ("CONDITIONING",),
126
+ }}
127
+ RETURN_TYPES = ("CONDITIONING",)
128
+ FUNCTION = "concat"
129
+
130
+ CATEGORY = "conditioning"
131
+
132
+ def concat(self, conditioning_to, conditioning_from):
133
+ out = []
134
+
135
+ if len(conditioning_from) > 1:
136
+ logging.warning("Warning: ConditioningConcat conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")
137
+
138
+ cond_from = conditioning_from[0][0]
139
+
140
+ for i in range(len(conditioning_to)):
141
+ t1 = conditioning_to[i][0]
142
+ tw = torch.cat((t1, cond_from),1)
143
+ n = [tw, conditioning_to[i][1].copy()]
144
+ out.append(n)
145
+
146
+ return (out, )
147
+
148
+ class ConditioningSetArea:
149
+ @classmethod
150
+ def INPUT_TYPES(s):
151
+ return {"required": {"conditioning": ("CONDITIONING", ),
152
+ "width": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
153
+ "height": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
154
+ "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
155
+ "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
156
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
157
+ }}
158
+ RETURN_TYPES = ("CONDITIONING",)
159
+ FUNCTION = "append"
160
+
161
+ CATEGORY = "conditioning"
162
+
163
+ def append(self, conditioning, width, height, x, y, strength):
164
+ c = node_helpers.conditioning_set_values(conditioning, {"area": (height // 8, width // 8, y // 8, x // 8),
165
+ "strength": strength,
166
+ "set_area_to_bounds": False})
167
+ return (c, )
168
+
169
+ class ConditioningSetAreaPercentage:
170
+ @classmethod
171
+ def INPUT_TYPES(s):
172
+ return {"required": {"conditioning": ("CONDITIONING", ),
173
+ "width": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
174
+ "height": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
175
+ "x": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
176
+ "y": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
177
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
178
+ }}
179
+ RETURN_TYPES = ("CONDITIONING",)
180
+ FUNCTION = "append"
181
+
182
+ CATEGORY = "conditioning"
183
+
184
+ def append(self, conditioning, width, height, x, y, strength):
185
+ c = node_helpers.conditioning_set_values(conditioning, {"area": ("percentage", height, width, y, x),
186
+ "strength": strength,
187
+ "set_area_to_bounds": False})
188
+ return (c, )
189
+
190
+ class ConditioningSetAreaStrength:
191
+ @classmethod
192
+ def INPUT_TYPES(s):
193
+ return {"required": {"conditioning": ("CONDITIONING", ),
194
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
195
+ }}
196
+ RETURN_TYPES = ("CONDITIONING",)
197
+ FUNCTION = "append"
198
+
199
+ CATEGORY = "conditioning"
200
+
201
+ def append(self, conditioning, strength):
202
+ c = node_helpers.conditioning_set_values(conditioning, {"strength": strength})
203
+ return (c, )
204
+
205
+
206
+ class ConditioningSetMask:
207
+ @classmethod
208
+ def INPUT_TYPES(s):
209
+ return {"required": {"conditioning": ("CONDITIONING", ),
210
+ "mask": ("MASK", ),
211
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
212
+ "set_cond_area": (["default", "mask bounds"],),
213
+ }}
214
+ RETURN_TYPES = ("CONDITIONING",)
215
+ FUNCTION = "append"
216
+
217
+ CATEGORY = "conditioning"
218
+
219
+ def append(self, conditioning, mask, set_cond_area, strength):
220
+ set_area_to_bounds = False
221
+ if set_cond_area != "default":
222
+ set_area_to_bounds = True
223
+ if len(mask.shape) < 3:
224
+ mask = mask.unsqueeze(0)
225
+
226
+ c = node_helpers.conditioning_set_values(conditioning, {"mask": mask,
227
+ "set_area_to_bounds": set_area_to_bounds,
228
+ "mask_strength": strength})
229
+ return (c, )
230
+
231
+ class ConditioningZeroOut:
232
+ @classmethod
233
+ def INPUT_TYPES(s):
234
+ return {"required": {"conditioning": ("CONDITIONING", )}}
235
+ RETURN_TYPES = ("CONDITIONING",)
236
+ FUNCTION = "zero_out"
237
+
238
+ CATEGORY = "advanced/conditioning"
239
+
240
+ def zero_out(self, conditioning):
241
+ c = []
242
+ for t in conditioning:
243
+ d = t[1].copy()
244
+ pooled_output = d.get("pooled_output", None)
245
+ if pooled_output is not None:
246
+ d["pooled_output"] = torch.zeros_like(pooled_output)
247
+ n = [torch.zeros_like(t[0]), d]
248
+ c.append(n)
249
+ return (c, )
250
+
251
+ class ConditioningSetTimestepRange:
252
+ @classmethod
253
+ def INPUT_TYPES(s):
254
+ return {"required": {"conditioning": ("CONDITIONING", ),
255
+ "start": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
256
+ "end": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
257
+ }}
258
+ RETURN_TYPES = ("CONDITIONING",)
259
+ FUNCTION = "set_range"
260
+
261
+ CATEGORY = "advanced/conditioning"
262
+
263
+ def set_range(self, conditioning, start, end):
264
+ c = node_helpers.conditioning_set_values(conditioning, {"start_percent": start,
265
+ "end_percent": end})
266
+ return (c, )
267
+
268
+ class VAEDecode:
269
+ @classmethod
270
+ def INPUT_TYPES(s):
271
+ return {
272
+ "required": {
273
+ "samples": ("LATENT", {"tooltip": "The latent to be decoded."}),
274
+ "vae": ("VAE", {"tooltip": "The VAE model used for decoding the latent."})
275
+ }
276
+ }
277
+ RETURN_TYPES = ("IMAGE",)
278
+ OUTPUT_TOOLTIPS = ("The decoded image.",)
279
+ FUNCTION = "decode"
280
+
281
+ CATEGORY = "latent"
282
+ DESCRIPTION = "Decodes latent images back into pixel space images."
283
+
284
+ def decode(self, vae, samples):
285
+ images = vae.decode(samples["samples"])
286
+ if len(images.shape) == 5: #Combine batches
287
+ images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
288
+ return (images, )
289
+
290
+ class VAEDecodeTiled:
291
+ @classmethod
292
+ def INPUT_TYPES(s):
293
+ return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ),
294
+ "tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 32}),
295
+ "overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}),
296
+ "temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to decode at a time."}),
297
+ "temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap."}),
298
+ }}
299
+ RETURN_TYPES = ("IMAGE",)
300
+ FUNCTION = "decode"
301
+
302
+ CATEGORY = "_for_testing"
303
+
304
+ def decode(self, vae, samples, tile_size, overlap=64, temporal_size=64, temporal_overlap=8):
305
+ if tile_size < overlap * 4:
306
+ overlap = tile_size // 4
307
+ if temporal_size < temporal_overlap * 2:
308
+ temporal_overlap = temporal_overlap // 2
309
+ temporal_compression = vae.temporal_compression_decode()
310
+ if temporal_compression is not None:
311
+ temporal_size = max(2, temporal_size // temporal_compression)
312
+ temporal_overlap = max(1, min(temporal_size // 2, temporal_overlap // temporal_compression))
313
+ else:
314
+ temporal_size = None
315
+ temporal_overlap = None
316
+
317
+ compression = vae.spacial_compression_decode()
318
+ images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression, tile_t=temporal_size, overlap_t=temporal_overlap)
319
+ if len(images.shape) == 5: #Combine batches
320
+ images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
321
+ return (images, )
322
+
323
+ class VAEEncode:
324
+ @classmethod
325
+ def INPUT_TYPES(s):
326
+ return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}}
327
+ RETURN_TYPES = ("LATENT",)
328
+ FUNCTION = "encode"
329
+
330
+ CATEGORY = "latent"
331
+
332
+ def encode(self, vae, pixels):
333
+ t = vae.encode(pixels[:,:,:,:3])
334
+ return ({"samples":t}, )
335
+
336
+ class VAEEncodeTiled:
337
+ @classmethod
338
+ def INPUT_TYPES(s):
339
+ return {"required": {"pixels": ("IMAGE", ), "vae": ("VAE", ),
340
+ "tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64}),
341
+ "overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}),
342
+ "temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to encode at a time."}),
343
+ "temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap."}),
344
+ }}
345
+ RETURN_TYPES = ("LATENT",)
346
+ FUNCTION = "encode"
347
+
348
+ CATEGORY = "_for_testing"
349
+
350
+ def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8):
351
+ t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
352
+ return ({"samples": t}, )
353
+
354
+ class VAEEncodeForInpaint:
355
+ @classmethod
356
+ def INPUT_TYPES(s):
357
+ return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", ), "grow_mask_by": ("INT", {"default": 6, "min": 0, "max": 64, "step": 1}),}}
358
+ RETURN_TYPES = ("LATENT",)
359
+ FUNCTION = "encode"
360
+
361
+ CATEGORY = "latent/inpaint"
362
+
363
+ def encode(self, vae, pixels, mask, grow_mask_by=6):
364
+ x = (pixels.shape[1] // vae.downscale_ratio) * vae.downscale_ratio
365
+ y = (pixels.shape[2] // vae.downscale_ratio) * vae.downscale_ratio
366
+ mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
367
+
368
+ pixels = pixels.clone()
369
+ if pixels.shape[1] != x or pixels.shape[2] != y:
370
+ x_offset = (pixels.shape[1] % vae.downscale_ratio) // 2
371
+ y_offset = (pixels.shape[2] % vae.downscale_ratio) // 2
372
+ pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
373
+ mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset]
374
+
375
+ #grow mask by a few pixels to keep things seamless in latent space
376
+ if grow_mask_by == 0:
377
+ mask_erosion = mask
378
+ else:
379
+ kernel_tensor = torch.ones((1, 1, grow_mask_by, grow_mask_by))
380
+ padding = math.ceil((grow_mask_by - 1) / 2)
381
+
382
+ mask_erosion = torch.clamp(torch.nn.functional.conv2d(mask.round(), kernel_tensor, padding=padding), 0, 1)
383
+
384
+ m = (1.0 - mask.round()).squeeze(1)
385
+ for i in range(3):
386
+ pixels[:,:,:,i] -= 0.5
387
+ pixels[:,:,:,i] *= m
388
+ pixels[:,:,:,i] += 0.5
389
+ t = vae.encode(pixels)
390
+
391
+ return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )
392
+
393
+
394
+ class InpaintModelConditioning:
395
+ @classmethod
396
+ def INPUT_TYPES(s):
397
+ return {"required": {"positive": ("CONDITIONING", ),
398
+ "negative": ("CONDITIONING", ),
399
+ "vae": ("VAE", ),
400
+ "pixels": ("IMAGE", ),
401
+ "mask": ("MASK", ),
402
+ "noise_mask": ("BOOLEAN", {"default": True, "tooltip": "Add a noise mask to the latent so sampling will only happen within the mask. Might improve results or completely break things depending on the model."}),
403
+ }}
404
+
405
+ RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT")
406
+ RETURN_NAMES = ("positive", "negative", "latent")
407
+ FUNCTION = "encode"
408
+
409
+ CATEGORY = "conditioning/inpaint"
410
+
411
+ def encode(self, positive, negative, pixels, vae, mask, noise_mask=True):
412
+ x = (pixels.shape[1] // 8) * 8
413
+ y = (pixels.shape[2] // 8) * 8
414
+ mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
415
+
416
+ orig_pixels = pixels
417
+ pixels = orig_pixels.clone()
418
+ if pixels.shape[1] != x or pixels.shape[2] != y:
419
+ x_offset = (pixels.shape[1] % 8) // 2
420
+ y_offset = (pixels.shape[2] % 8) // 2
421
+ pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
422
+ mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset]
423
+
424
+ m = (1.0 - mask.round()).squeeze(1)
425
+ for i in range(3):
426
+ pixels[:,:,:,i] -= 0.5
427
+ pixels[:,:,:,i] *= m
428
+ pixels[:,:,:,i] += 0.5
429
+ concat_latent = vae.encode(pixels)
430
+ orig_latent = vae.encode(orig_pixels)
431
+
432
+ out_latent = {}
433
+
434
+ out_latent["samples"] = orig_latent
435
+ if noise_mask:
436
+ out_latent["noise_mask"] = mask
437
+
438
+ out = []
439
+ for conditioning in [positive, negative]:
440
+ c = node_helpers.conditioning_set_values(conditioning, {"concat_latent_image": concat_latent,
441
+ "concat_mask": mask})
442
+ out.append(c)
443
+ return (out[0], out[1], out_latent)
444
+
445
+
446
+ class SaveLatent:
447
+ def __init__(self):
448
+ self.output_dir = folder_paths.get_output_directory()
449
+
450
+ @classmethod
451
+ def INPUT_TYPES(s):
452
+ return {"required": { "samples": ("LATENT", ),
453
+ "filename_prefix": ("STRING", {"default": "latents/ComfyUI"})},
454
+ "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
455
+ }
456
+ RETURN_TYPES = ()
457
+ FUNCTION = "save"
458
+
459
+ OUTPUT_NODE = True
460
+
461
+ CATEGORY = "_for_testing"
462
+
463
+ def save(self, samples, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
464
+ full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
465
+
466
+ # support save metadata for latent sharing
467
+ prompt_info = ""
468
+ if prompt is not None:
469
+ prompt_info = json.dumps(prompt)
470
+
471
+ metadata = None
472
+ if not args.disable_metadata:
473
+ metadata = {"prompt": prompt_info}
474
+ if extra_pnginfo is not None:
475
+ for x in extra_pnginfo:
476
+ metadata[x] = json.dumps(extra_pnginfo[x])
477
+
478
+ file = f"{filename}_{counter:05}_.latent"
479
+
480
+ results = list()
481
+ results.append({
482
+ "filename": file,
483
+ "subfolder": subfolder,
484
+ "type": "output"
485
+ })
486
+
487
+ file = os.path.join(full_output_folder, file)
488
+
489
+ output = {}
490
+ output["latent_tensor"] = samples["samples"]
491
+ output["latent_format_version_0"] = torch.tensor([])
492
+
493
+ comfy.utils.save_torch_file(output, file, metadata=metadata)
494
+ return { "ui": { "latents": results } }
495
+
496
+
497
+ class LoadLatent:
498
+ @classmethod
499
+ def INPUT_TYPES(s):
500
+ input_dir = folder_paths.get_input_directory()
501
+ files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")]
502
+ return {"required": {"latent": [sorted(files), ]}, }
503
+
504
+ CATEGORY = "_for_testing"
505
+
506
+ RETURN_TYPES = ("LATENT", )
507
+ FUNCTION = "load"
508
+
509
+ def load(self, latent):
510
+ latent_path = folder_paths.get_annotated_filepath(latent)
511
+ latent = safetensors.torch.load_file(latent_path, device="cpu")
512
+ multiplier = 1.0
513
+ if "latent_format_version_0" not in latent:
514
+ multiplier = 1.0 / 0.18215
515
+ samples = {"samples": latent["latent_tensor"].float() * multiplier}
516
+ return (samples, )
517
+
518
+ @classmethod
519
+ def IS_CHANGED(s, latent):
520
+ image_path = folder_paths.get_annotated_filepath(latent)
521
+ m = hashlib.sha256()
522
+ with open(image_path, 'rb') as f:
523
+ m.update(f.read())
524
+ return m.digest().hex()
525
+
526
+ @classmethod
527
+ def VALIDATE_INPUTS(s, latent):
528
+ if not folder_paths.exists_annotated_filepath(latent):
529
+ return "Invalid latent file: {}".format(latent)
530
+ return True
531
+
532
+
533
+ class CheckpointLoader:
534
+ @classmethod
535
+ def INPUT_TYPES(s):
536
+ return {"required": { "config_name": (folder_paths.get_filename_list("configs"), ),
537
+ "ckpt_name": (folder_paths.get_filename_list("checkpoints"), )}}
538
+ RETURN_TYPES = ("MODEL", "CLIP", "VAE")
539
+ FUNCTION = "load_checkpoint"
540
+
541
+ CATEGORY = "advanced/loaders"
542
+ DEPRECATED = True
543
+
544
+ def load_checkpoint(self, config_name, ckpt_name):
545
+ config_path = folder_paths.get_full_path("configs", config_name)
546
+ ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
547
+ return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
548
+
549
+ class CheckpointLoaderSimple:
550
+ @classmethod
551
+ def INPUT_TYPES(s):
552
+ return {
553
+ "required": {
554
+ "ckpt_name": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}),
555
+ }
556
+ }
557
+ RETURN_TYPES = ("MODEL", "CLIP", "VAE")
558
+ OUTPUT_TOOLTIPS = ("The model used for denoising latents.",
559
+ "The CLIP model used for encoding text prompts.",
560
+ "The VAE model used for encoding and decoding images to and from latent space.")
561
+ FUNCTION = "load_checkpoint"
562
+
563
+ CATEGORY = "loaders"
564
+ DESCRIPTION = "Loads a diffusion model checkpoint, diffusion models are used to denoise latents."
565
+
566
+ def load_checkpoint(self, ckpt_name):
567
+ ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
568
+ out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
569
+ return out[:3]
570
+
571
+ class DiffusersLoader:
572
+ @classmethod
573
+ def INPUT_TYPES(cls):
574
+ paths = []
575
+ for search_path in folder_paths.get_folder_paths("diffusers"):
576
+ if os.path.exists(search_path):
577
+ for root, subdir, files in os.walk(search_path, followlinks=True):
578
+ if "model_index.json" in files:
579
+ paths.append(os.path.relpath(root, start=search_path))
580
+
581
+ return {"required": {"model_path": (paths,), }}
582
+ RETURN_TYPES = ("MODEL", "CLIP", "VAE")
583
+ FUNCTION = "load_checkpoint"
584
+
585
+ CATEGORY = "advanced/loaders/deprecated"
586
+
587
+ def load_checkpoint(self, model_path, output_vae=True, output_clip=True):
588
+ for search_path in folder_paths.get_folder_paths("diffusers"):
589
+ if os.path.exists(search_path):
590
+ path = os.path.join(search_path, model_path)
591
+ if os.path.exists(path):
592
+ model_path = path
593
+ break
594
+
595
+ return comfy.diffusers_load.load_diffusers(model_path, output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
596
+
597
+
598
+ class unCLIPCheckpointLoader:
599
+ @classmethod
600
+ def INPUT_TYPES(s):
601
+ return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
602
+ }}
603
+ RETURN_TYPES = ("MODEL", "CLIP", "VAE", "CLIP_VISION")
604
+ FUNCTION = "load_checkpoint"
605
+
606
+ CATEGORY = "loaders"
607
+
608
+ def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
609
+ ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
610
+ out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
611
+ return out
612
+
613
+ class CLIPSetLastLayer:
614
+ @classmethod
615
+ def INPUT_TYPES(s):
616
+ return {"required": { "clip": ("CLIP", ),
617
+ "stop_at_clip_layer": ("INT", {"default": -1, "min": -24, "max": -1, "step": 1}),
618
+ }}
619
+ RETURN_TYPES = ("CLIP",)
620
+ FUNCTION = "set_last_layer"
621
+
622
+ CATEGORY = "conditioning"
623
+
624
+ def set_last_layer(self, clip, stop_at_clip_layer):
625
+ clip = clip.clone()
626
+ clip.clip_layer(stop_at_clip_layer)
627
+ return (clip,)
628
+
629
+ class LoraLoader:
630
+ def __init__(self):
631
+ self.loaded_lora = None
632
+
633
+ @classmethod
634
+ def INPUT_TYPES(s):
635
+ return {
636
+ "required": {
637
+ "model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}),
638
+ "clip": ("CLIP", {"tooltip": "The CLIP model the LoRA will be applied to."}),
639
+ "lora_name": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}),
640
+ "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}),
641
+ "strength_clip": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the CLIP model. This value can be negative."}),
642
+ }
643
+ }
644
+
645
+ RETURN_TYPES = ("MODEL", "CLIP")
646
+ OUTPUT_TOOLTIPS = ("The modified diffusion model.", "The modified CLIP model.")
647
+ FUNCTION = "load_lora"
648
+
649
+ CATEGORY = "loaders"
650
+ DESCRIPTION = "LoRAs are used to modify diffusion and CLIP models, altering the way in which latents are denoised such as applying styles. Multiple LoRA nodes can be linked together."
651
+
652
+ def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
653
+ if strength_model == 0 and strength_clip == 0:
654
+ return (model, clip)
655
+
656
+ lora_path = folder_paths.get_full_path_or_raise("loras", lora_name)
657
+ lora = None
658
+ if self.loaded_lora is not None:
659
+ if self.loaded_lora[0] == lora_path:
660
+ lora = self.loaded_lora[1]
661
+ else:
662
+ self.loaded_lora = None
663
+
664
+ if lora is None:
665
+ lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
666
+ self.loaded_lora = (lora_path, lora)
667
+
668
+ model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip)
669
+ return (model_lora, clip_lora)
670
+
671
+ class LoraLoaderModelOnly(LoraLoader):
672
+ @classmethod
673
+ def INPUT_TYPES(s):
674
+ return {"required": { "model": ("MODEL",),
675
+ "lora_name": (folder_paths.get_filename_list("loras"), ),
676
+ "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}),
677
+ }}
678
+ RETURN_TYPES = ("MODEL",)
679
+ FUNCTION = "load_lora_model_only"
680
+
681
+ def load_lora_model_only(self, model, lora_name, strength_model):
682
+ return (self.load_lora(model, None, lora_name, strength_model, 0)[0],)
683
+
684
+ class VAELoader:
685
+ @staticmethod
686
+ def vae_list():
687
+ vaes = folder_paths.get_filename_list("vae")
688
+ approx_vaes = folder_paths.get_filename_list("vae_approx")
689
+ sdxl_taesd_enc = False
690
+ sdxl_taesd_dec = False
691
+ sd1_taesd_enc = False
692
+ sd1_taesd_dec = False
693
+ sd3_taesd_enc = False
694
+ sd3_taesd_dec = False
695
+ f1_taesd_enc = False
696
+ f1_taesd_dec = False
697
+
698
+ for v in approx_vaes:
699
+ if v.startswith("taesd_decoder."):
700
+ sd1_taesd_dec = True
701
+ elif v.startswith("taesd_encoder."):
702
+ sd1_taesd_enc = True
703
+ elif v.startswith("taesdxl_decoder."):
704
+ sdxl_taesd_dec = True
705
+ elif v.startswith("taesdxl_encoder."):
706
+ sdxl_taesd_enc = True
707
+ elif v.startswith("taesd3_decoder."):
708
+ sd3_taesd_dec = True
709
+ elif v.startswith("taesd3_encoder."):
710
+ sd3_taesd_enc = True
711
+ elif v.startswith("taef1_encoder."):
712
+ f1_taesd_dec = True
713
+ elif v.startswith("taef1_decoder."):
714
+ f1_taesd_enc = True
715
+ if sd1_taesd_dec and sd1_taesd_enc:
716
+ vaes.append("taesd")
717
+ if sdxl_taesd_dec and sdxl_taesd_enc:
718
+ vaes.append("taesdxl")
719
+ if sd3_taesd_dec and sd3_taesd_enc:
720
+ vaes.append("taesd3")
721
+ if f1_taesd_dec and f1_taesd_enc:
722
+ vaes.append("taef1")
723
+ return vaes
724
+
725
+ @staticmethod
726
+ def load_taesd(name):
727
+ sd = {}
728
+ approx_vaes = folder_paths.get_filename_list("vae_approx")
729
+
730
+ encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes))
731
+ decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes))
732
+
733
+ enc = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", encoder))
734
+ for k in enc:
735
+ sd["taesd_encoder.{}".format(k)] = enc[k]
736
+
737
+ dec = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", decoder))
738
+ for k in dec:
739
+ sd["taesd_decoder.{}".format(k)] = dec[k]
740
+
741
+ if name == "taesd":
742
+ sd["vae_scale"] = torch.tensor(0.18215)
743
+ sd["vae_shift"] = torch.tensor(0.0)
744
+ elif name == "taesdxl":
745
+ sd["vae_scale"] = torch.tensor(0.13025)
746
+ sd["vae_shift"] = torch.tensor(0.0)
747
+ elif name == "taesd3":
748
+ sd["vae_scale"] = torch.tensor(1.5305)
749
+ sd["vae_shift"] = torch.tensor(0.0609)
750
+ elif name == "taef1":
751
+ sd["vae_scale"] = torch.tensor(0.3611)
752
+ sd["vae_shift"] = torch.tensor(0.1159)
753
+ return sd
754
+
755
+ @classmethod
756
+ def INPUT_TYPES(s):
757
+ return {"required": { "vae_name": (s.vae_list(), )}}
758
+ RETURN_TYPES = ("VAE",)
759
+ FUNCTION = "load_vae"
760
+
761
+ CATEGORY = "loaders"
762
+
763
+ #TODO: scale factor?
764
+ def load_vae(self, vae_name):
765
+ if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
766
+ sd = self.load_taesd(vae_name)
767
+ else:
768
+ vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
769
+ sd = comfy.utils.load_torch_file(vae_path)
770
+ vae = comfy.sd.VAE(sd=sd)
771
+ return (vae,)
772
+
773
+ class ControlNetLoader:
774
+ @classmethod
775
+ def INPUT_TYPES(s):
776
+ return {"required": { "control_net_name": (folder_paths.get_filename_list("controlnet"), )}}
777
+
778
+ RETURN_TYPES = ("CONTROL_NET",)
779
+ FUNCTION = "load_controlnet"
780
+
781
+ CATEGORY = "loaders"
782
+
783
+ def load_controlnet(self, control_net_name):
784
+ controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name)
785
+ controlnet = comfy.controlnet.load_controlnet(controlnet_path)
786
+ return (controlnet,)
787
+
788
+ class DiffControlNetLoader:
789
+ @classmethod
790
+ def INPUT_TYPES(s):
791
+ return {"required": { "model": ("MODEL",),
792
+ "control_net_name": (folder_paths.get_filename_list("controlnet"), )}}
793
+
794
+ RETURN_TYPES = ("CONTROL_NET",)
795
+ FUNCTION = "load_controlnet"
796
+
797
+ CATEGORY = "loaders"
798
+
799
+ def load_controlnet(self, model, control_net_name):
800
+ controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name)
801
+ controlnet = comfy.controlnet.load_controlnet(controlnet_path, model)
802
+ return (controlnet,)
803
+
804
+
805
+ class ControlNetApply:
806
+ @classmethod
807
+ def INPUT_TYPES(s):
808
+ return {"required": {"conditioning": ("CONDITIONING", ),
809
+ "control_net": ("CONTROL_NET", ),
810
+ "image": ("IMAGE", ),
811
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01})
812
+ }}
813
+ RETURN_TYPES = ("CONDITIONING",)
814
+ FUNCTION = "apply_controlnet"
815
+
816
+ DEPRECATED = True
817
+ CATEGORY = "conditioning/controlnet"
818
+
819
+ def apply_controlnet(self, conditioning, control_net, image, strength):
820
+ if strength == 0:
821
+ return (conditioning, )
822
+
823
+ c = []
824
+ control_hint = image.movedim(-1,1)
825
+ for t in conditioning:
826
+ n = [t[0], t[1].copy()]
827
+ c_net = control_net.copy().set_cond_hint(control_hint, strength)
828
+ if 'control' in t[1]:
829
+ c_net.set_previous_controlnet(t[1]['control'])
830
+ n[1]['control'] = c_net
831
+ n[1]['control_apply_to_uncond'] = True
832
+ c.append(n)
833
+ return (c, )
834
+
835
+
836
+ class ControlNetApplyAdvanced:
837
+ @classmethod
838
+ def INPUT_TYPES(s):
839
+ return {"required": {"positive": ("CONDITIONING", ),
840
+ "negative": ("CONDITIONING", ),
841
+ "control_net": ("CONTROL_NET", ),
842
+ "image": ("IMAGE", ),
843
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
844
+ "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
845
+ "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
846
+ },
847
+ "optional": {"vae": ("VAE", ),
848
+ }
849
+ }
850
+
851
+ RETURN_TYPES = ("CONDITIONING","CONDITIONING")
852
+ RETURN_NAMES = ("positive", "negative")
853
+ FUNCTION = "apply_controlnet"
854
+
855
+ CATEGORY = "conditioning/controlnet"
856
+
857
+ def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None, extra_concat=[]):
858
+ if strength == 0:
859
+ return (positive, negative)
860
+
861
+ control_hint = image.movedim(-1,1)
862
+ cnets = {}
863
+
864
+ out = []
865
+ for conditioning in [positive, negative]:
866
+ c = []
867
+ for t in conditioning:
868
+ d = t[1].copy()
869
+
870
+ prev_cnet = d.get('control', None)
871
+ if prev_cnet in cnets:
872
+ c_net = cnets[prev_cnet]
873
+ else:
874
+ c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae=vae, extra_concat=extra_concat)
875
+ c_net.set_previous_controlnet(prev_cnet)
876
+ cnets[prev_cnet] = c_net
877
+
878
+ d['control'] = c_net
879
+ d['control_apply_to_uncond'] = False
880
+ n = [t[0], d]
881
+ c.append(n)
882
+ out.append(c)
883
+ return (out[0], out[1])
884
+
885
+
886
+ class UNETLoader:
887
+ @classmethod
888
+ def INPUT_TYPES(s):
889
+ return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
890
+ "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],)
891
+ }}
892
+ RETURN_TYPES = ("MODEL",)
893
+ FUNCTION = "load_unet"
894
+
895
+ CATEGORY = "advanced/loaders"
896
+
897
+ def load_unet(self, unet_name, weight_dtype):
898
+ model_options = {}
899
+ if weight_dtype == "fp8_e4m3fn":
900
+ model_options["dtype"] = torch.float8_e4m3fn
901
+ elif weight_dtype == "fp8_e4m3fn_fast":
902
+ model_options["dtype"] = torch.float8_e4m3fn
903
+ model_options["fp8_optimizations"] = True
904
+ elif weight_dtype == "fp8_e5m2":
905
+ model_options["dtype"] = torch.float8_e5m2
906
+
907
+ unet_path = folder_paths.get_full_path_or_raise("diffusion_models", unet_name)
908
+ model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
909
+ return (model,)
910
+
911
+ class CLIPLoader:
912
+ @classmethod
913
+ def INPUT_TYPES(s):
914
+ return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
915
+ "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos"], ),
916
+ },
917
+ "optional": {
918
+ "device": (["default", "cpu"], {"advanced": True}),
919
+ }}
920
+ RETURN_TYPES = ("CLIP",)
921
+ FUNCTION = "load_clip"
922
+
923
+ CATEGORY = "advanced/loaders"
924
+
925
+ DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 / clip-g / clip-l\nstable_audio: t5\nmochi: t5\ncosmos: old t5 xxl"
926
+
927
+ def load_clip(self, clip_name, type="stable_diffusion", device="default"):
928
+ if type == "stable_cascade":
929
+ clip_type = comfy.sd.CLIPType.STABLE_CASCADE
930
+ elif type == "sd3":
931
+ clip_type = comfy.sd.CLIPType.SD3
932
+ elif type == "stable_audio":
933
+ clip_type = comfy.sd.CLIPType.STABLE_AUDIO
934
+ elif type == "mochi":
935
+ clip_type = comfy.sd.CLIPType.MOCHI
936
+ elif type == "ltxv":
937
+ clip_type = comfy.sd.CLIPType.LTXV
938
+ elif type == "pixart":
939
+ clip_type = comfy.sd.CLIPType.PIXART
940
+ else:
941
+ clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
942
+
943
+ model_options = {}
944
+ if device == "cpu":
945
+ model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")
946
+
947
+ clip_path = folder_paths.get_full_path_or_raise("text_encoders", clip_name)
948
+ clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
949
+ return (clip,)
950
+
951
+ class DualCLIPLoader:
952
+ @classmethod
953
+ def INPUT_TYPES(s):
954
+ return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
955
+ "clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
956
+ "type": (["sdxl", "sd3", "flux", "hunyuan_video"], ),
957
+ },
958
+ "optional": {
959
+ "device": (["default", "cpu"], {"advanced": True}),
960
+ }}
961
+ RETURN_TYPES = ("CLIP",)
962
+ FUNCTION = "load_clip"
963
+
964
+ CATEGORY = "advanced/loaders"
965
+
966
+ DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5"
967
+
968
+ def load_clip(self, clip_name1, clip_name2, type, device="default"):
969
+ clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
970
+ clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
971
+ if type == "sdxl":
972
+ clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
973
+ elif type == "sd3":
974
+ clip_type = comfy.sd.CLIPType.SD3
975
+ elif type == "flux":
976
+ clip_type = comfy.sd.CLIPType.FLUX
977
+ elif type == "hunyuan_video":
978
+ clip_type = comfy.sd.CLIPType.HUNYUAN_VIDEO
979
+
980
+ model_options = {}
981
+ if device == "cpu":
982
+ model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")
983
+
984
+ clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
985
+ return (clip,)
986
+
987
+ class CLIPVisionLoader:
988
+ @classmethod
989
+ def INPUT_TYPES(s):
990
+ return {"required": { "clip_name": (folder_paths.get_filename_list("clip_vision"), ),
991
+ }}
992
+ RETURN_TYPES = ("CLIP_VISION",)
993
+ FUNCTION = "load_clip"
994
+
995
+ CATEGORY = "loaders"
996
+
997
+ def load_clip(self, clip_name):
998
+ clip_path = folder_paths.get_full_path_or_raise("clip_vision", clip_name)
999
+ clip_vision = comfy.clip_vision.load(clip_path)
1000
+ return (clip_vision,)
1001
+
1002
+ class CLIPVisionEncode:
1003
+ @classmethod
1004
+ def INPUT_TYPES(s):
1005
+ return {"required": { "clip_vision": ("CLIP_VISION",),
1006
+ "image": ("IMAGE",),
1007
+ "crop": (["center", "none"],)
1008
+ }}
1009
+ RETURN_TYPES = ("CLIP_VISION_OUTPUT",)
1010
+ FUNCTION = "encode"
1011
+
1012
+ CATEGORY = "conditioning"
1013
+
1014
+ def encode(self, clip_vision, image, crop):
1015
+ crop_image = True
1016
+ if crop != "center":
1017
+ crop_image = False
1018
+ output = clip_vision.encode_image(image, crop=crop_image)
1019
+ return (output,)
1020
+
1021
+ class StyleModelLoader:
1022
+ @classmethod
1023
+ def INPUT_TYPES(s):
1024
+ return {"required": { "style_model_name": (folder_paths.get_filename_list("style_models"), )}}
1025
+
1026
+ RETURN_TYPES = ("STYLE_MODEL",)
1027
+ FUNCTION = "load_style_model"
1028
+
1029
+ CATEGORY = "loaders"
1030
+
1031
+ def load_style_model(self, style_model_name):
1032
+ style_model_path = folder_paths.get_full_path_or_raise("style_models", style_model_name)
1033
+ style_model = comfy.sd.load_style_model(style_model_path)
1034
+ return (style_model,)
1035
+
1036
+
1037
+ class StyleModelApply:
1038
+ @classmethod
1039
+ def INPUT_TYPES(s):
1040
+ return {"required": {"conditioning": ("CONDITIONING", ),
1041
+ "style_model": ("STYLE_MODEL", ),
1042
+ "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
1043
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}),
1044
+ "strength_type": (["multiply", "attn_bias"], ),
1045
+ }}
1046
+ RETURN_TYPES = ("CONDITIONING",)
1047
+ FUNCTION = "apply_stylemodel"
1048
+
1049
+ CATEGORY = "conditioning/style_model"
1050
+
1051
+ def apply_stylemodel(self, conditioning, style_model, clip_vision_output, strength, strength_type):
1052
+ cond = style_model.get_cond(clip_vision_output).flatten(start_dim=0, end_dim=1).unsqueeze(dim=0)
1053
+ if strength_type == "multiply":
1054
+ cond *= strength
1055
+
1056
+ n = cond.shape[1]
1057
+ c_out = []
1058
+ for t in conditioning:
1059
+ (txt, keys) = t
1060
+ keys = keys.copy()
1061
+ if strength_type == "attn_bias" and strength != 1.0:
1062
+ # math.log raises an error if the argument is zero
1063
+ # torch.log returns -inf, which is what we want
1064
+ attn_bias = torch.log(torch.Tensor([strength]))
1065
+ # get the size of the mask image
1066
+ mask_ref_size = keys.get("attention_mask_img_shape", (1, 1))
1067
+ n_ref = mask_ref_size[0] * mask_ref_size[1]
1068
+ n_txt = txt.shape[1]
1069
+ # grab the existing mask
1070
+ mask = keys.get("attention_mask", None)
1071
+ # create a default mask if it doesn't exist
1072
+ if mask is None:
1073
+ mask = torch.zeros((txt.shape[0], n_txt + n_ref, n_txt + n_ref), dtype=torch.float16)
1074
+ # convert the mask dtype, because it might be boolean
1075
+ # we want it to be interpreted as a bias
1076
+ if mask.dtype == torch.bool:
1077
+ # log(True) = log(1) = 0
1078
+ # log(False) = log(0) = -inf
1079
+ mask = torch.log(mask.to(dtype=torch.float16))
1080
+ # now we make the mask bigger to add space for our new tokens
1081
+ new_mask = torch.zeros((txt.shape[0], n_txt + n + n_ref, n_txt + n + n_ref), dtype=torch.float16)
1082
+ # copy over the old mask, in quandrants
1083
+ new_mask[:, :n_txt, :n_txt] = mask[:, :n_txt, :n_txt]
1084
+ new_mask[:, :n_txt, n_txt+n:] = mask[:, :n_txt, n_txt:]
1085
+ new_mask[:, n_txt+n:, :n_txt] = mask[:, n_txt:, :n_txt]
1086
+ new_mask[:, n_txt+n:, n_txt+n:] = mask[:, n_txt:, n_txt:]
1087
+ # now fill in the attention bias to our redux tokens
1088
+ new_mask[:, :n_txt, n_txt:n_txt+n] = attn_bias
1089
+ new_mask[:, n_txt+n:, n_txt:n_txt+n] = attn_bias
1090
+ keys["attention_mask"] = new_mask.to(txt.device)
1091
+ keys["attention_mask_img_shape"] = mask_ref_size
1092
+
1093
+ c_out.append([torch.cat((txt, cond), dim=1), keys])
1094
+
1095
+ return (c_out,)
1096
+
1097
+ class unCLIPConditioning:
1098
+ @classmethod
1099
+ def INPUT_TYPES(s):
1100
+ return {"required": {"conditioning": ("CONDITIONING", ),
1101
+ "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
1102
+ "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
1103
+ "noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
1104
+ }}
1105
+ RETURN_TYPES = ("CONDITIONING",)
1106
+ FUNCTION = "apply_adm"
1107
+
1108
+ CATEGORY = "conditioning"
1109
+
1110
+ def apply_adm(self, conditioning, clip_vision_output, strength, noise_augmentation):
1111
+ if strength == 0:
1112
+ return (conditioning, )
1113
+
1114
+ c = []
1115
+ for t in conditioning:
1116
+ o = t[1].copy()
1117
+ x = {"clip_vision_output": clip_vision_output, "strength": strength, "noise_augmentation": noise_augmentation}
1118
+ if "unclip_conditioning" in o:
1119
+ o["unclip_conditioning"] = o["unclip_conditioning"][:] + [x]
1120
+ else:
1121
+ o["unclip_conditioning"] = [x]
1122
+ n = [t[0], o]
1123
+ c.append(n)
1124
+ return (c, )
1125
+
1126
+ class GLIGENLoader:
1127
+ @classmethod
1128
+ def INPUT_TYPES(s):
1129
+ return {"required": { "gligen_name": (folder_paths.get_filename_list("gligen"), )}}
1130
+
1131
+ RETURN_TYPES = ("GLIGEN",)
1132
+ FUNCTION = "load_gligen"
1133
+
1134
+ CATEGORY = "loaders"
1135
+
1136
+ def load_gligen(self, gligen_name):
1137
+ gligen_path = folder_paths.get_full_path_or_raise("gligen", gligen_name)
1138
+ gligen = comfy.sd.load_gligen(gligen_path)
1139
+ return (gligen,)
1140
+
1141
+ class GLIGENTextBoxApply:
1142
+ @classmethod
1143
+ def INPUT_TYPES(s):
1144
+ return {"required": {"conditioning_to": ("CONDITIONING", ),
1145
+ "clip": ("CLIP", ),
1146
+ "gligen_textbox_model": ("GLIGEN", ),
1147
+ "text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
1148
+ "width": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}),
1149
+ "height": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}),
1150
+ "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1151
+ "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1152
+ }}
1153
+ RETURN_TYPES = ("CONDITIONING",)
1154
+ FUNCTION = "append"
1155
+
1156
+ CATEGORY = "conditioning/gligen"
1157
+
1158
+ def append(self, conditioning_to, clip, gligen_textbox_model, text, width, height, x, y):
1159
+ c = []
1160
+ cond, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled="unprojected")
1161
+ for t in conditioning_to:
1162
+ n = [t[0], t[1].copy()]
1163
+ position_params = [(cond_pooled, height // 8, width // 8, y // 8, x // 8)]
1164
+ prev = []
1165
+ if "gligen" in n[1]:
1166
+ prev = n[1]['gligen'][2]
1167
+
1168
+ n[1]['gligen'] = ("position", gligen_textbox_model, prev + position_params)
1169
+ c.append(n)
1170
+ return (c, )
1171
+
1172
+ class EmptyLatentImage:
1173
+ def __init__(self):
1174
+ self.device = comfy.model_management.intermediate_device()
1175
+
1176
+ @classmethod
1177
+ def INPUT_TYPES(s):
1178
+ return {
1179
+ "required": {
1180
+ "width": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8, "tooltip": "The width of the latent images in pixels."}),
1181
+ "height": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8, "tooltip": "The height of the latent images in pixels."}),
1182
+ "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."})
1183
+ }
1184
+ }
1185
+ RETURN_TYPES = ("LATENT",)
1186
+ OUTPUT_TOOLTIPS = ("The empty latent image batch.",)
1187
+ FUNCTION = "generate"
1188
+
1189
+ CATEGORY = "latent"
1190
+ DESCRIPTION = "Create a new batch of empty latent images to be denoised via sampling."
1191
+
1192
+ def generate(self, width, height, batch_size=1):
1193
+ latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device)
1194
+ return ({"samples":latent}, )
1195
+
1196
+
1197
+ class LatentFromBatch:
1198
+ @classmethod
1199
+ def INPUT_TYPES(s):
1200
+ return {"required": { "samples": ("LATENT",),
1201
+ "batch_index": ("INT", {"default": 0, "min": 0, "max": 63}),
1202
+ "length": ("INT", {"default": 1, "min": 1, "max": 64}),
1203
+ }}
1204
+ RETURN_TYPES = ("LATENT",)
1205
+ FUNCTION = "frombatch"
1206
+
1207
+ CATEGORY = "latent/batch"
1208
+
1209
+ def frombatch(self, samples, batch_index, length):
1210
+ s = samples.copy()
1211
+ s_in = samples["samples"]
1212
+ batch_index = min(s_in.shape[0] - 1, batch_index)
1213
+ length = min(s_in.shape[0] - batch_index, length)
1214
+ s["samples"] = s_in[batch_index:batch_index + length].clone()
1215
+ if "noise_mask" in samples:
1216
+ masks = samples["noise_mask"]
1217
+ if masks.shape[0] == 1:
1218
+ s["noise_mask"] = masks.clone()
1219
+ else:
1220
+ if masks.shape[0] < s_in.shape[0]:
1221
+ masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]]
1222
+ s["noise_mask"] = masks[batch_index:batch_index + length].clone()
1223
+ if "batch_index" not in s:
1224
+ s["batch_index"] = [x for x in range(batch_index, batch_index+length)]
1225
+ else:
1226
+ s["batch_index"] = samples["batch_index"][batch_index:batch_index + length]
1227
+ return (s,)
1228
+
1229
+ class RepeatLatentBatch:
1230
+ @classmethod
1231
+ def INPUT_TYPES(s):
1232
+ return {"required": { "samples": ("LATENT",),
1233
+ "amount": ("INT", {"default": 1, "min": 1, "max": 64}),
1234
+ }}
1235
+ RETURN_TYPES = ("LATENT",)
1236
+ FUNCTION = "repeat"
1237
+
1238
+ CATEGORY = "latent/batch"
1239
+
1240
+ def repeat(self, samples, amount):
1241
+ s = samples.copy()
1242
+ s_in = samples["samples"]
1243
+
1244
+ s["samples"] = s_in.repeat((amount, 1,1,1))
1245
+ if "noise_mask" in samples and samples["noise_mask"].shape[0] > 1:
1246
+ masks = samples["noise_mask"]
1247
+ if masks.shape[0] < s_in.shape[0]:
1248
+ masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]]
1249
+ s["noise_mask"] = samples["noise_mask"].repeat((amount, 1,1,1))
1250
+ if "batch_index" in s:
1251
+ offset = max(s["batch_index"]) - min(s["batch_index"]) + 1
1252
+ s["batch_index"] = s["batch_index"] + [x + (i * offset) for i in range(1, amount) for x in s["batch_index"]]
1253
+ return (s,)
1254
+
1255
+ class LatentUpscale:
1256
+ upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "bislerp"]
1257
+ crop_methods = ["disabled", "center"]
1258
+
1259
+ @classmethod
1260
+ def INPUT_TYPES(s):
1261
+ return {"required": { "samples": ("LATENT",), "upscale_method": (s.upscale_methods,),
1262
+ "width": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1263
+ "height": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1264
+ "crop": (s.crop_methods,)}}
1265
+ RETURN_TYPES = ("LATENT",)
1266
+ FUNCTION = "upscale"
1267
+
1268
+ CATEGORY = "latent"
1269
+
1270
+ def upscale(self, samples, upscale_method, width, height, crop):
1271
+ if width == 0 and height == 0:
1272
+ s = samples
1273
+ else:
1274
+ s = samples.copy()
1275
+
1276
+ if width == 0:
1277
+ height = max(64, height)
1278
+ width = max(64, round(samples["samples"].shape[-1] * height / samples["samples"].shape[-2]))
1279
+ elif height == 0:
1280
+ width = max(64, width)
1281
+ height = max(64, round(samples["samples"].shape[-2] * width / samples["samples"].shape[-1]))
1282
+ else:
1283
+ width = max(64, width)
1284
+ height = max(64, height)
1285
+
1286
+ s["samples"] = comfy.utils.common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop)
1287
+ return (s,)
1288
+
1289
+ class LatentUpscaleBy:
1290
+ upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "bislerp"]
1291
+
1292
+ @classmethod
1293
+ def INPUT_TYPES(s):
1294
+ return {"required": { "samples": ("LATENT",), "upscale_method": (s.upscale_methods,),
1295
+ "scale_by": ("FLOAT", {"default": 1.5, "min": 0.01, "max": 8.0, "step": 0.01}),}}
1296
+ RETURN_TYPES = ("LATENT",)
1297
+ FUNCTION = "upscale"
1298
+
1299
+ CATEGORY = "latent"
1300
+
1301
+ def upscale(self, samples, upscale_method, scale_by):
1302
+ s = samples.copy()
1303
+ width = round(samples["samples"].shape[-1] * scale_by)
1304
+ height = round(samples["samples"].shape[-2] * scale_by)
1305
+ s["samples"] = comfy.utils.common_upscale(samples["samples"], width, height, upscale_method, "disabled")
1306
+ return (s,)
1307
+
1308
+ class LatentRotate:
1309
+ @classmethod
1310
+ def INPUT_TYPES(s):
1311
+ return {"required": { "samples": ("LATENT",),
1312
+ "rotation": (["none", "90 degrees", "180 degrees", "270 degrees"],),
1313
+ }}
1314
+ RETURN_TYPES = ("LATENT",)
1315
+ FUNCTION = "rotate"
1316
+
1317
+ CATEGORY = "latent/transform"
1318
+
1319
+ def rotate(self, samples, rotation):
1320
+ s = samples.copy()
1321
+ rotate_by = 0
1322
+ if rotation.startswith("90"):
1323
+ rotate_by = 1
1324
+ elif rotation.startswith("180"):
1325
+ rotate_by = 2
1326
+ elif rotation.startswith("270"):
1327
+ rotate_by = 3
1328
+
1329
+ s["samples"] = torch.rot90(samples["samples"], k=rotate_by, dims=[3, 2])
1330
+ return (s,)
1331
+
1332
+ class LatentFlip:
1333
+ @classmethod
1334
+ def INPUT_TYPES(s):
1335
+ return {"required": { "samples": ("LATENT",),
1336
+ "flip_method": (["x-axis: vertically", "y-axis: horizontally"],),
1337
+ }}
1338
+ RETURN_TYPES = ("LATENT",)
1339
+ FUNCTION = "flip"
1340
+
1341
+ CATEGORY = "latent/transform"
1342
+
1343
+ def flip(self, samples, flip_method):
1344
+ s = samples.copy()
1345
+ if flip_method.startswith("x"):
1346
+ s["samples"] = torch.flip(samples["samples"], dims=[2])
1347
+ elif flip_method.startswith("y"):
1348
+ s["samples"] = torch.flip(samples["samples"], dims=[3])
1349
+
1350
+ return (s,)
1351
+
1352
+ class LatentComposite:
1353
+ @classmethod
1354
+ def INPUT_TYPES(s):
1355
+ return {"required": { "samples_to": ("LATENT",),
1356
+ "samples_from": ("LATENT",),
1357
+ "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1358
+ "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1359
+ "feather": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1360
+ }}
1361
+ RETURN_TYPES = ("LATENT",)
1362
+ FUNCTION = "composite"
1363
+
1364
+ CATEGORY = "latent"
1365
+
1366
+ def composite(self, samples_to, samples_from, x, y, composite_method="normal", feather=0):
1367
+ x = x // 8
1368
+ y = y // 8
1369
+ feather = feather // 8
1370
+ samples_out = samples_to.copy()
1371
+ s = samples_to["samples"].clone()
1372
+ samples_to = samples_to["samples"]
1373
+ samples_from = samples_from["samples"]
1374
+ if feather == 0:
1375
+ s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x]
1376
+ else:
1377
+ samples_from = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x]
1378
+ mask = torch.ones_like(samples_from)
1379
+ for t in range(feather):
1380
+ if y != 0:
1381
+ mask[:,:,t:1+t,:] *= ((1.0/feather) * (t + 1))
1382
+
1383
+ if y + samples_from.shape[2] < samples_to.shape[2]:
1384
+ mask[:,:,mask.shape[2] -1 -t: mask.shape[2]-t,:] *= ((1.0/feather) * (t + 1))
1385
+ if x != 0:
1386
+ mask[:,:,:,t:1+t] *= ((1.0/feather) * (t + 1))
1387
+ if x + samples_from.shape[3] < samples_to.shape[3]:
1388
+ mask[:,:,:,mask.shape[3]- 1 - t: mask.shape[3]- t] *= ((1.0/feather) * (t + 1))
1389
+ rev_mask = torch.ones_like(mask) - mask
1390
+ s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x] * mask + s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] * rev_mask
1391
+ samples_out["samples"] = s
1392
+ return (samples_out,)
1393
+
1394
+ class LatentBlend:
1395
+ @classmethod
1396
+ def INPUT_TYPES(s):
1397
+ return {"required": {
1398
+ "samples1": ("LATENT",),
1399
+ "samples2": ("LATENT",),
1400
+ "blend_factor": ("FLOAT", {
1401
+ "default": 0.5,
1402
+ "min": 0,
1403
+ "max": 1,
1404
+ "step": 0.01
1405
+ }),
1406
+ }}
1407
+
1408
+ RETURN_TYPES = ("LATENT",)
1409
+ FUNCTION = "blend"
1410
+
1411
+ CATEGORY = "_for_testing"
1412
+
1413
+ def blend(self, samples1, samples2, blend_factor:float, blend_mode: str="normal"):
1414
+
1415
+ samples_out = samples1.copy()
1416
+ samples1 = samples1["samples"]
1417
+ samples2 = samples2["samples"]
1418
+
1419
+ if samples1.shape != samples2.shape:
1420
+ samples2.permute(0, 3, 1, 2)
1421
+ samples2 = comfy.utils.common_upscale(samples2, samples1.shape[3], samples1.shape[2], 'bicubic', crop='center')
1422
+ samples2.permute(0, 2, 3, 1)
1423
+
1424
+ samples_blended = self.blend_mode(samples1, samples2, blend_mode)
1425
+ samples_blended = samples1 * blend_factor + samples_blended * (1 - blend_factor)
1426
+ samples_out["samples"] = samples_blended
1427
+ return (samples_out,)
1428
+
1429
+ def blend_mode(self, img1, img2, mode):
1430
+ if mode == "normal":
1431
+ return img2
1432
+ else:
1433
+ raise ValueError(f"Unsupported blend mode: {mode}")
1434
+
1435
+ class LatentCrop:
1436
+ @classmethod
1437
+ def INPUT_TYPES(s):
1438
+ return {"required": { "samples": ("LATENT",),
1439
+ "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
1440
+ "height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
1441
+ "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1442
+ "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1443
+ }}
1444
+ RETURN_TYPES = ("LATENT",)
1445
+ FUNCTION = "crop"
1446
+
1447
+ CATEGORY = "latent/transform"
1448
+
1449
+ def crop(self, samples, width, height, x, y):
1450
+ s = samples.copy()
1451
+ samples = samples['samples']
1452
+ x = x // 8
1453
+ y = y // 8
1454
+
1455
+ #enfonce minimum size of 64
1456
+ if x > (samples.shape[3] - 8):
1457
+ x = samples.shape[3] - 8
1458
+ if y > (samples.shape[2] - 8):
1459
+ y = samples.shape[2] - 8
1460
+
1461
+ new_height = height // 8
1462
+ new_width = width // 8
1463
+ to_x = new_width + x
1464
+ to_y = new_height + y
1465
+ s['samples'] = samples[:,:,y:to_y, x:to_x]
1466
+ return (s,)
1467
+
1468
+ class SetLatentNoiseMask:
1469
+ @classmethod
1470
+ def INPUT_TYPES(s):
1471
+ return {"required": { "samples": ("LATENT",),
1472
+ "mask": ("MASK",),
1473
+ }}
1474
+ RETURN_TYPES = ("LATENT",)
1475
+ FUNCTION = "set_mask"
1476
+
1477
+ CATEGORY = "latent/inpaint"
1478
+
1479
+ def set_mask(self, samples, mask):
1480
+ s = samples.copy()
1481
+ s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
1482
+ return (s,)
1483
+
1484
+ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
1485
+ latent_image = latent["samples"]
1486
+ latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image)
1487
+
1488
+ if disable_noise:
1489
+ noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
1490
+ else:
1491
+ batch_inds = latent["batch_index"] if "batch_index" in latent else None
1492
+ noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds)
1493
+
1494
+ noise_mask = None
1495
+ if "noise_mask" in latent:
1496
+ noise_mask = latent["noise_mask"]
1497
+
1498
+ callback = latent_preview.prepare_callback(model, steps)
1499
+ disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
1500
+ samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
1501
+ denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
1502
+ force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
1503
+ out = latent.copy()
1504
+ out["samples"] = samples
1505
+ return (out, )
1506
+
1507
+ class KSampler:
1508
+ @classmethod
1509
+ def INPUT_TYPES(s):
1510
+ return {
1511
+ "required": {
1512
+ "model": ("MODEL", {"tooltip": "The model used for denoising the input latent."}),
1513
+ "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "tooltip": "The random seed used for creating the noise."}),
1514
+ "steps": ("INT", {"default": 20, "min": 1, "max": 10000, "tooltip": "The number of steps used in the denoising process."}),
1515
+ "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01, "tooltip": "The Classifier-Free Guidance scale balances creativity and adherence to the prompt. Higher values result in images more closely matching the prompt however too high values will negatively impact quality."}),
1516
+ "sampler_name": (comfy.samplers.KSampler.SAMPLERS, {"tooltip": "The algorithm used when sampling, this can affect the quality, speed, and style of the generated output."}),
1517
+ "scheduler": (comfy.samplers.KSampler.SCHEDULERS, {"tooltip": "The scheduler controls how noise is gradually removed to form the image."}),
1518
+ "positive": ("CONDITIONING", {"tooltip": "The conditioning describing the attributes you want to include in the image."}),
1519
+ "negative": ("CONDITIONING", {"tooltip": "The conditioning describing the attributes you want to exclude from the image."}),
1520
+ "latent_image": ("LATENT", {"tooltip": "The latent image to denoise."}),
1521
+ "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The amount of denoising applied, lower values will maintain the structure of the initial image allowing for image to image sampling."}),
1522
+ }
1523
+ }
1524
+
1525
+ RETURN_TYPES = ("LATENT",)
1526
+ OUTPUT_TOOLTIPS = ("The denoised latent.",)
1527
+ FUNCTION = "sample"
1528
+
1529
+ CATEGORY = "sampling"
1530
+ DESCRIPTION = "Uses the provided model, positive and negative conditioning to denoise the latent image."
1531
+
1532
+ def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0):
1533
+ return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise)
1534
+
1535
+ class KSamplerAdvanced:
1536
+ @classmethod
1537
+ def INPUT_TYPES(s):
1538
+ return {"required":
1539
+ {"model": ("MODEL",),
1540
+ "add_noise": (["enable", "disable"], ),
1541
+ "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
1542
+ "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
1543
+ "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
1544
+ "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ),
1545
+ "scheduler": (comfy.samplers.KSampler.SCHEDULERS, ),
1546
+ "positive": ("CONDITIONING", ),
1547
+ "negative": ("CONDITIONING", ),
1548
+ "latent_image": ("LATENT", ),
1549
+ "start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}),
1550
+ "end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}),
1551
+ "return_with_leftover_noise": (["disable", "enable"], ),
1552
+ }
1553
+ }
1554
+
1555
+ RETURN_TYPES = ("LATENT",)
1556
+ FUNCTION = "sample"
1557
+
1558
+ CATEGORY = "sampling"
1559
+
1560
+ def sample(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0):
1561
+ force_full_denoise = True
1562
+ if return_with_leftover_noise == "enable":
1563
+ force_full_denoise = False
1564
+ disable_noise = False
1565
+ if add_noise == "disable":
1566
+ disable_noise = True
1567
+ return common_ksampler(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise)
1568
+
1569
+ class SaveImage:
1570
+ def __init__(self):
1571
+ self.output_dir = folder_paths.get_output_directory()
1572
+ self.type = "output"
1573
+ self.prefix_append = ""
1574
+ self.compress_level = 4
1575
+
1576
+ @classmethod
1577
+ def INPUT_TYPES(s):
1578
+ return {
1579
+ "required": {
1580
+ "images": ("IMAGE", {"tooltip": "The images to save."}),
1581
+ "filename_prefix": ("STRING", {"default": "ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."})
1582
+ },
1583
+ "hidden": {
1584
+ "prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"
1585
+ },
1586
+ }
1587
+
1588
+ RETURN_TYPES = ()
1589
+ FUNCTION = "save_images"
1590
+
1591
+ OUTPUT_NODE = True
1592
+
1593
+ CATEGORY = "image"
1594
+ DESCRIPTION = "Saves the input images to your ComfyUI output directory."
1595
+
1596
+ def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
1597
+ filename_prefix += self.prefix_append
1598
+ full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
1599
+ results = list()
1600
+ for (batch_number, image) in enumerate(images):
1601
+ i = 255. * image.cpu().numpy()
1602
+ img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
1603
+ metadata = None
1604
+ if not args.disable_metadata:
1605
+ metadata = PngInfo()
1606
+ if prompt is not None:
1607
+ metadata.add_text("prompt", json.dumps(prompt))
1608
+ if extra_pnginfo is not None:
1609
+ for x in extra_pnginfo:
1610
+ metadata.add_text(x, json.dumps(extra_pnginfo[x]))
1611
+
1612
+ filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
1613
+ file = f"{filename_with_batch_num}_{counter:05}_.png"
1614
+ img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=self.compress_level)
1615
+ results.append({
1616
+ "filename": file,
1617
+ "subfolder": subfolder,
1618
+ "type": self.type
1619
+ })
1620
+ counter += 1
1621
+
1622
+ return { "ui": { "images": results } }
1623
+
1624
+ class PreviewImage(SaveImage):
1625
+ def __init__(self):
1626
+ self.output_dir = folder_paths.get_temp_directory()
1627
+ self.type = "temp"
1628
+ self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
1629
+ self.compress_level = 1
1630
+
1631
+ @classmethod
1632
+ def INPUT_TYPES(s):
1633
+ return {"required":
1634
+ {"images": ("IMAGE", ), },
1635
+ "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
1636
+ }
1637
+
1638
+ class LoadImage:
1639
+ @classmethod
1640
+ def INPUT_TYPES(s):
1641
+ input_dir = folder_paths.get_input_directory()
1642
+ files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
1643
+ return {"required":
1644
+ {"image": (sorted(files), {"image_upload": True})},
1645
+ }
1646
+
1647
+ CATEGORY = "image"
1648
+
1649
+ RETURN_TYPES = ("IMAGE", "MASK")
1650
+ FUNCTION = "load_image"
1651
+ def load_image(self, image):
1652
+ image_path = folder_paths.get_annotated_filepath(image)
1653
+
1654
+ img = node_helpers.pillow(Image.open, image_path)
1655
+
1656
+ output_images = []
1657
+ output_masks = []
1658
+ w, h = None, None
1659
+
1660
+ excluded_formats = ['MPO']
1661
+
1662
+ for i in ImageSequence.Iterator(img):
1663
+ i = node_helpers.pillow(ImageOps.exif_transpose, i)
1664
+
1665
+ if i.mode == 'I':
1666
+ i = i.point(lambda i: i * (1 / 255))
1667
+ image = i.convert("RGB")
1668
+
1669
+ if len(output_images) == 0:
1670
+ w = image.size[0]
1671
+ h = image.size[1]
1672
+
1673
+ if image.size[0] != w or image.size[1] != h:
1674
+ continue
1675
+
1676
+ image = np.array(image).astype(np.float32) / 255.0
1677
+ image = torch.from_numpy(image)[None,]
1678
+ if 'A' in i.getbands():
1679
+ mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
1680
+ mask = 1. - torch.from_numpy(mask)
1681
+ else:
1682
+ mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
1683
+ output_images.append(image)
1684
+ output_masks.append(mask.unsqueeze(0))
1685
+
1686
+ if len(output_images) > 1 and img.format not in excluded_formats:
1687
+ output_image = torch.cat(output_images, dim=0)
1688
+ output_mask = torch.cat(output_masks, dim=0)
1689
+ else:
1690
+ output_image = output_images[0]
1691
+ output_mask = output_masks[0]
1692
+
1693
+ return (output_image, output_mask)
1694
+
1695
+ @classmethod
1696
+ def IS_CHANGED(s, image):
1697
+ image_path = folder_paths.get_annotated_filepath(image)
1698
+ m = hashlib.sha256()
1699
+ with open(image_path, 'rb') as f:
1700
+ m.update(f.read())
1701
+ return m.digest().hex()
1702
+
1703
+ @classmethod
1704
+ def VALIDATE_INPUTS(s, image):
1705
+ if not folder_paths.exists_annotated_filepath(image):
1706
+ return "Invalid image file: {}".format(image)
1707
+
1708
+ return True
1709
+
1710
+ class LoadImageMask:
1711
+ _color_channels = ["alpha", "red", "green", "blue"]
1712
+ @classmethod
1713
+ def INPUT_TYPES(s):
1714
+ input_dir = folder_paths.get_input_directory()
1715
+ files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
1716
+ return {"required":
1717
+ {"image": (sorted(files), {"image_upload": True}),
1718
+ "channel": (s._color_channels, ), }
1719
+ }
1720
+
1721
+ CATEGORY = "mask"
1722
+
1723
+ RETURN_TYPES = ("MASK",)
1724
+ FUNCTION = "load_image"
1725
+ def load_image(self, image, channel):
1726
+ image_path = folder_paths.get_annotated_filepath(image)
1727
+ i = node_helpers.pillow(Image.open, image_path)
1728
+ i = node_helpers.pillow(ImageOps.exif_transpose, i)
1729
+ if i.getbands() != ("R", "G", "B", "A"):
1730
+ if i.mode == 'I':
1731
+ i = i.point(lambda i: i * (1 / 255))
1732
+ i = i.convert("RGBA")
1733
+ mask = None
1734
+ c = channel[0].upper()
1735
+ if c in i.getbands():
1736
+ mask = np.array(i.getchannel(c)).astype(np.float32) / 255.0
1737
+ mask = torch.from_numpy(mask)
1738
+ if c == 'A':
1739
+ mask = 1. - mask
1740
+ else:
1741
+ mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
1742
+ return (mask.unsqueeze(0),)
1743
+
1744
+ @classmethod
1745
+ def IS_CHANGED(s, image, channel):
1746
+ image_path = folder_paths.get_annotated_filepath(image)
1747
+ m = hashlib.sha256()
1748
+ with open(image_path, 'rb') as f:
1749
+ m.update(f.read())
1750
+ return m.digest().hex()
1751
+
1752
+ @classmethod
1753
+ def VALIDATE_INPUTS(s, image):
1754
+ if not folder_paths.exists_annotated_filepath(image):
1755
+ return "Invalid image file: {}".format(image)
1756
+
1757
+ return True
1758
+
1759
+ class ImageScale:
1760
+ upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
1761
+ crop_methods = ["disabled", "center"]
1762
+
1763
+ @classmethod
1764
+ def INPUT_TYPES(s):
1765
+ return {"required": { "image": ("IMAGE",), "upscale_method": (s.upscale_methods,),
1766
+ "width": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
1767
+ "height": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
1768
+ "crop": (s.crop_methods,)}}
1769
+ RETURN_TYPES = ("IMAGE",)
1770
+ FUNCTION = "upscale"
1771
+
1772
+ CATEGORY = "image/upscaling"
1773
+
1774
+ def upscale(self, image, upscale_method, width, height, crop):
1775
+ if width == 0 and height == 0:
1776
+ s = image
1777
+ else:
1778
+ samples = image.movedim(-1,1)
1779
+
1780
+ if width == 0:
1781
+ width = max(1, round(samples.shape[3] * height / samples.shape[2]))
1782
+ elif height == 0:
1783
+ height = max(1, round(samples.shape[2] * width / samples.shape[3]))
1784
+
1785
+ s = comfy.utils.common_upscale(samples, width, height, upscale_method, crop)
1786
+ s = s.movedim(1,-1)
1787
+ return (s,)
1788
+
1789
+ class ImageScaleBy:
1790
+ upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
1791
+
1792
+ @classmethod
1793
+ def INPUT_TYPES(s):
1794
+ return {"required": { "image": ("IMAGE",), "upscale_method": (s.upscale_methods,),
1795
+ "scale_by": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 8.0, "step": 0.01}),}}
1796
+ RETURN_TYPES = ("IMAGE",)
1797
+ FUNCTION = "upscale"
1798
+
1799
+ CATEGORY = "image/upscaling"
1800
+
1801
+ def upscale(self, image, upscale_method, scale_by):
1802
+ samples = image.movedim(-1,1)
1803
+ width = round(samples.shape[3] * scale_by)
1804
+ height = round(samples.shape[2] * scale_by)
1805
+ s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
1806
+ s = s.movedim(1,-1)
1807
+ return (s,)
1808
+
1809
+ class ImageInvert:
1810
+
1811
+ @classmethod
1812
+ def INPUT_TYPES(s):
1813
+ return {"required": { "image": ("IMAGE",)}}
1814
+
1815
+ RETURN_TYPES = ("IMAGE",)
1816
+ FUNCTION = "invert"
1817
+
1818
+ CATEGORY = "image"
1819
+
1820
+ def invert(self, image):
1821
+ s = 1.0 - image
1822
+ return (s,)
1823
+
1824
+ class ImageBatch:
1825
+
1826
+ @classmethod
1827
+ def INPUT_TYPES(s):
1828
+ return {"required": { "image1": ("IMAGE",), "image2": ("IMAGE",)}}
1829
+
1830
+ RETURN_TYPES = ("IMAGE",)
1831
+ FUNCTION = "batch"
1832
+
1833
+ CATEGORY = "image"
1834
+
1835
+ def batch(self, image1, image2):
1836
+ if image1.shape[1:] != image2.shape[1:]:
1837
+ image2 = comfy.utils.common_upscale(image2.movedim(-1,1), image1.shape[2], image1.shape[1], "bilinear", "center").movedim(1,-1)
1838
+ s = torch.cat((image1, image2), dim=0)
1839
+ return (s,)
1840
+
1841
+ class EmptyImage:
1842
+ def __init__(self, device="cpu"):
1843
+ self.device = device
1844
+
1845
+ @classmethod
1846
+ def INPUT_TYPES(s):
1847
+ return {"required": { "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
1848
+ "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
1849
+ "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
1850
+ "color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}),
1851
+ }}
1852
+ RETURN_TYPES = ("IMAGE",)
1853
+ FUNCTION = "generate"
1854
+
1855
+ CATEGORY = "image"
1856
+
1857
+ def generate(self, width, height, batch_size=1, color=0):
1858
+ r = torch.full([batch_size, height, width, 1], ((color >> 16) & 0xFF) / 0xFF)
1859
+ g = torch.full([batch_size, height, width, 1], ((color >> 8) & 0xFF) / 0xFF)
1860
+ b = torch.full([batch_size, height, width, 1], ((color) & 0xFF) / 0xFF)
1861
+ return (torch.cat((r, g, b), dim=-1), )
1862
+
1863
+ class ImagePadForOutpaint:
1864
+
1865
+ @classmethod
1866
+ def INPUT_TYPES(s):
1867
+ return {
1868
+ "required": {
1869
+ "image": ("IMAGE",),
1870
+ "left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1871
+ "top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1872
+ "right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1873
+ "bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1874
+ "feathering": ("INT", {"default": 40, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
1875
+ }
1876
+ }
1877
+
1878
+ RETURN_TYPES = ("IMAGE", "MASK")
1879
+ FUNCTION = "expand_image"
1880
+
1881
+ CATEGORY = "image"
1882
+
1883
+ def expand_image(self, image, left, top, right, bottom, feathering):
1884
+ d1, d2, d3, d4 = image.size()
1885
+
1886
+ new_image = torch.ones(
1887
+ (d1, d2 + top + bottom, d3 + left + right, d4),
1888
+ dtype=torch.float32,
1889
+ ) * 0.5
1890
+
1891
+ new_image[:, top:top + d2, left:left + d3, :] = image
1892
+
1893
+ mask = torch.ones(
1894
+ (d2 + top + bottom, d3 + left + right),
1895
+ dtype=torch.float32,
1896
+ )
1897
+
1898
+ t = torch.zeros(
1899
+ (d2, d3),
1900
+ dtype=torch.float32
1901
+ )
1902
+
1903
+ if feathering > 0 and feathering * 2 < d2 and feathering * 2 < d3:
1904
+
1905
+ for i in range(d2):
1906
+ for j in range(d3):
1907
+ dt = i if top != 0 else d2
1908
+ db = d2 - i if bottom != 0 else d2
1909
+
1910
+ dl = j if left != 0 else d3
1911
+ dr = d3 - j if right != 0 else d3
1912
+
1913
+ d = min(dt, db, dl, dr)
1914
+
1915
+ if d >= feathering:
1916
+ continue
1917
+
1918
+ v = (feathering - d) / feathering
1919
+
1920
+ t[i, j] = v * v
1921
+
1922
+ mask[top:top + d2, left:left + d3] = t
1923
+
1924
+ return (new_image, mask)
1925
+
1926
+
1927
+ NODE_CLASS_MAPPINGS = {
1928
+ "KSampler": KSampler,
1929
+ "CheckpointLoaderSimple": CheckpointLoaderSimple,
1930
+ "CLIPTextEncode": CLIPTextEncode,
1931
+ "CLIPSetLastLayer": CLIPSetLastLayer,
1932
+ "VAEDecode": VAEDecode,
1933
+ "VAEEncode": VAEEncode,
1934
+ "VAEEncodeForInpaint": VAEEncodeForInpaint,
1935
+ "VAELoader": VAELoader,
1936
+ "EmptyLatentImage": EmptyLatentImage,
1937
+ "LatentUpscale": LatentUpscale,
1938
+ "LatentUpscaleBy": LatentUpscaleBy,
1939
+ "LatentFromBatch": LatentFromBatch,
1940
+ "RepeatLatentBatch": RepeatLatentBatch,
1941
+ "SaveImage": SaveImage,
1942
+ "PreviewImage": PreviewImage,
1943
+ "LoadImage": LoadImage,
1944
+ "LoadImageMask": LoadImageMask,
1945
+ "ImageScale": ImageScale,
1946
+ "ImageScaleBy": ImageScaleBy,
1947
+ "ImageInvert": ImageInvert,
1948
+ "ImageBatch": ImageBatch,
1949
+ "ImagePadForOutpaint": ImagePadForOutpaint,
1950
+ "EmptyImage": EmptyImage,
1951
+ "ConditioningAverage": ConditioningAverage ,
1952
+ "ConditioningCombine": ConditioningCombine,
1953
+ "ConditioningConcat": ConditioningConcat,
1954
+ "ConditioningSetArea": ConditioningSetArea,
1955
+ "ConditioningSetAreaPercentage": ConditioningSetAreaPercentage,
1956
+ "ConditioningSetAreaStrength": ConditioningSetAreaStrength,
1957
+ "ConditioningSetMask": ConditioningSetMask,
1958
+ "KSamplerAdvanced": KSamplerAdvanced,
1959
+ "SetLatentNoiseMask": SetLatentNoiseMask,
1960
+ "LatentComposite": LatentComposite,
1961
+ "LatentBlend": LatentBlend,
1962
+ "LatentRotate": LatentRotate,
1963
+ "LatentFlip": LatentFlip,
1964
+ "LatentCrop": LatentCrop,
1965
+ "LoraLoader": LoraLoader,
1966
+ "CLIPLoader": CLIPLoader,
1967
+ "UNETLoader": UNETLoader,
1968
+ "DualCLIPLoader": DualCLIPLoader,
1969
+ "CLIPVisionEncode": CLIPVisionEncode,
1970
+ "StyleModelApply": StyleModelApply,
1971
+ "unCLIPConditioning": unCLIPConditioning,
1972
+ "ControlNetApply": ControlNetApply,
1973
+ "ControlNetApplyAdvanced": ControlNetApplyAdvanced,
1974
+ "ControlNetLoader": ControlNetLoader,
1975
+ "DiffControlNetLoader": DiffControlNetLoader,
1976
+ "StyleModelLoader": StyleModelLoader,
1977
+ "CLIPVisionLoader": CLIPVisionLoader,
1978
+ "VAEDecodeTiled": VAEDecodeTiled,
1979
+ "VAEEncodeTiled": VAEEncodeTiled,
1980
+ "unCLIPCheckpointLoader": unCLIPCheckpointLoader,
1981
+ "GLIGENLoader": GLIGENLoader,
1982
+ "GLIGENTextBoxApply": GLIGENTextBoxApply,
1983
+ "InpaintModelConditioning": InpaintModelConditioning,
1984
+
1985
+ "CheckpointLoader": CheckpointLoader,
1986
+ "DiffusersLoader": DiffusersLoader,
1987
+
1988
+ "LoadLatent": LoadLatent,
1989
+ "SaveLatent": SaveLatent,
1990
+
1991
+ "ConditioningZeroOut": ConditioningZeroOut,
1992
+ "ConditioningSetTimestepRange": ConditioningSetTimestepRange,
1993
+ "LoraLoaderModelOnly": LoraLoaderModelOnly,
1994
+ }
1995
+
1996
+ NODE_DISPLAY_NAME_MAPPINGS = {
1997
+ # Sampling
1998
+ "KSampler": "KSampler",
1999
+ "KSamplerAdvanced": "KSampler (Advanced)",
2000
+ # Loaders
2001
+ "CheckpointLoader": "Load Checkpoint With Config (DEPRECATED)",
2002
+ "CheckpointLoaderSimple": "Load Checkpoint",
2003
+ "VAELoader": "Load VAE",
2004
+ "LoraLoader": "Load LoRA",
2005
+ "CLIPLoader": "Load CLIP",
2006
+ "ControlNetLoader": "Load ControlNet Model",
2007
+ "DiffControlNetLoader": "Load ControlNet Model (diff)",
2008
+ "StyleModelLoader": "Load Style Model",
2009
+ "CLIPVisionLoader": "Load CLIP Vision",
2010
+ "UpscaleModelLoader": "Load Upscale Model",
2011
+ "UNETLoader": "Load Diffusion Model",
2012
+ # Conditioning
2013
+ "CLIPVisionEncode": "CLIP Vision Encode",
2014
+ "StyleModelApply": "Apply Style Model",
2015
+ "CLIPTextEncode": "CLIP Text Encode (Prompt)",
2016
+ "CLIPSetLastLayer": "CLIP Set Last Layer",
2017
+ "ConditioningCombine": "Conditioning (Combine)",
2018
+ "ConditioningAverage ": "Conditioning (Average)",
2019
+ "ConditioningConcat": "Conditioning (Concat)",
2020
+ "ConditioningSetArea": "Conditioning (Set Area)",
2021
+ "ConditioningSetAreaPercentage": "Conditioning (Set Area with Percentage)",
2022
+ "ConditioningSetMask": "Conditioning (Set Mask)",
2023
+ "ControlNetApply": "Apply ControlNet (OLD)",
2024
+ "ControlNetApplyAdvanced": "Apply ControlNet",
2025
+ # Latent
2026
+ "VAEEncodeForInpaint": "VAE Encode (for Inpainting)",
2027
+ "SetLatentNoiseMask": "Set Latent Noise Mask",
2028
+ "VAEDecode": "VAE Decode",
2029
+ "VAEEncode": "VAE Encode",
2030
+ "LatentRotate": "Rotate Latent",
2031
+ "LatentFlip": "Flip Latent",
2032
+ "LatentCrop": "Crop Latent",
2033
+ "EmptyLatentImage": "Empty Latent Image",
2034
+ "LatentUpscale": "Upscale Latent",
2035
+ "LatentUpscaleBy": "Upscale Latent By",
2036
+ "LatentComposite": "Latent Composite",
2037
+ "LatentBlend": "Latent Blend",
2038
+ "LatentFromBatch" : "Latent From Batch",
2039
+ "RepeatLatentBatch": "Repeat Latent Batch",
2040
+ # Image
2041
+ "SaveImage": "Save Image",
2042
+ "PreviewImage": "Preview Image",
2043
+ "LoadImage": "Load Image",
2044
+ "LoadImageMask": "Load Image (as Mask)",
2045
+ "ImageScale": "Upscale Image",
2046
+ "ImageScaleBy": "Upscale Image By",
2047
+ "ImageUpscaleWithModel": "Upscale Image (using Model)",
2048
+ "ImageInvert": "Invert Image",
2049
+ "ImagePadForOutpaint": "Pad Image for Outpainting",
2050
+ "ImageBatch": "Batch Images",
2051
+ "ImageCrop": "Image Crop",
2052
+ "ImageBlend": "Image Blend",
2053
+ "ImageBlur": "Image Blur",
2054
+ "ImageQuantize": "Image Quantize",
2055
+ "ImageSharpen": "Image Sharpen",
2056
+ "ImageScaleToTotalPixels": "Scale Image to Total Pixels",
2057
+ # _for_testing
2058
+ "VAEDecodeTiled": "VAE Decode (Tiled)",
2059
+ "VAEEncodeTiled": "VAE Encode (Tiled)",
2060
+ }
2061
+
2062
+ EXTENSION_WEB_DIRS = {}
2063
+
2064
+ # Dictionary of successfully loaded module names and associated directories.
2065
+ LOADED_MODULE_DIRS = {}
2066
+
2067
+
2068
+ def get_module_name(module_path: str) -> str:
2069
+ """
2070
+ Returns the module name based on the given module path.
2071
+ Examples:
2072
+ get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node.py") -> "my_custom_node"
2073
+ get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node") -> "my_custom_node"
2074
+ get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node/") -> "my_custom_node"
2075
+ get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node/__init__.py") -> "my_custom_node"
2076
+ get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node/__init__") -> "my_custom_node"
2077
+ get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node/__init__/") -> "my_custom_node"
2078
+ get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node.disabled") -> "custom_nodes
2079
+ Args:
2080
+ module_path (str): The path of the module.
2081
+ Returns:
2082
+ str: The module name.
2083
+ """
2084
+ base_path = os.path.basename(module_path)
2085
+ if os.path.isfile(module_path):
2086
+ base_path = os.path.splitext(base_path)[0]
2087
+ return base_path
2088
+
2089
+
2090
+ def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes") -> bool:
2091
+ module_name = os.path.basename(module_path)
2092
+ if os.path.isfile(module_path):
2093
+ sp = os.path.splitext(module_path)
2094
+ module_name = sp[0]
2095
+ try:
2096
+ logging.debug("Trying to load custom node {}".format(module_path))
2097
+ if os.path.isfile(module_path):
2098
+ module_spec = importlib.util.spec_from_file_location(module_name, module_path)
2099
+ module_dir = os.path.split(module_path)[0]
2100
+ else:
2101
+ module_spec = importlib.util.spec_from_file_location(module_name, os.path.join(module_path, "__init__.py"))
2102
+ module_dir = module_path
2103
+
2104
+ module = importlib.util.module_from_spec(module_spec)
2105
+ sys.modules[module_name] = module
2106
+ module_spec.loader.exec_module(module)
2107
+
2108
+ LOADED_MODULE_DIRS[module_name] = os.path.abspath(module_dir)
2109
+
2110
+ if hasattr(module, "WEB_DIRECTORY") and getattr(module, "WEB_DIRECTORY") is not None:
2111
+ web_dir = os.path.abspath(os.path.join(module_dir, getattr(module, "WEB_DIRECTORY")))
2112
+ if os.path.isdir(web_dir):
2113
+ EXTENSION_WEB_DIRS[module_name] = web_dir
2114
+
2115
+ if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None:
2116
+ for name, node_cls in module.NODE_CLASS_MAPPINGS.items():
2117
+ if name not in ignore:
2118
+ NODE_CLASS_MAPPINGS[name] = node_cls
2119
+ node_cls.RELATIVE_PYTHON_MODULE = "{}.{}".format(module_parent, get_module_name(module_path))
2120
+ if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
2121
+ NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
2122
+ return True
2123
+ else:
2124
+ logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.")
2125
+ return False
2126
+ except Exception as e:
2127
+ logging.warning(traceback.format_exc())
2128
+ logging.warning(f"Cannot import {module_path} module for custom nodes: {e}")
2129
+ return False
2130
+
2131
+ def init_external_custom_nodes():
2132
+ """
2133
+ Initializes the external custom nodes.
2134
+
2135
+ This function loads custom nodes from the specified folder paths and imports them into the application.
2136
+ It measures the import times for each custom node and logs the results.
2137
+
2138
+ Returns:
2139
+ None
2140
+ """
2141
+ base_node_names = set(NODE_CLASS_MAPPINGS.keys())
2142
+ node_paths = folder_paths.get_folder_paths("custom_nodes")
2143
+ node_import_times = []
2144
+ for custom_node_path in node_paths:
2145
+ possible_modules = os.listdir(os.path.realpath(custom_node_path))
2146
+ if "__pycache__" in possible_modules:
2147
+ possible_modules.remove("__pycache__")
2148
+
2149
+ for possible_module in possible_modules:
2150
+ module_path = os.path.join(custom_node_path, possible_module)
2151
+ if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
2152
+ if module_path.endswith(".disabled"): continue
2153
+ time_before = time.perf_counter()
2154
+ success = load_custom_node(module_path, base_node_names, module_parent="custom_nodes")
2155
+ node_import_times.append((time.perf_counter() - time_before, module_path, success))
2156
+
2157
+ if len(node_import_times) > 0:
2158
+ logging.info("\nImport times for custom nodes:")
2159
+ for n in sorted(node_import_times):
2160
+ if n[2]:
2161
+ import_message = ""
2162
+ else:
2163
+ import_message = " (IMPORT FAILED)"
2164
+ logging.info("{:6.1f} seconds{}: {}".format(n[0], import_message, n[1]))
2165
+ logging.info("")
2166
+
2167
+ def init_builtin_extra_nodes():
2168
+ """
2169
+ Initializes the built-in extra nodes in ComfyUI.
2170
+
2171
+ This function loads the extra node files located in the "comfy_extras" directory and imports them into ComfyUI.
2172
+ If any of the extra node files fail to import, a warning message is logged.
2173
+
2174
+ Returns:
2175
+ None
2176
+ """
2177
+ extras_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras")
2178
+ extras_files = [
2179
+ "nodes_latent.py",
2180
+ "nodes_hypernetwork.py",
2181
+ "nodes_upscale_model.py",
2182
+ "nodes_post_processing.py",
2183
+ "nodes_mask.py",
2184
+ "nodes_compositing.py",
2185
+ "nodes_rebatch.py",
2186
+ "nodes_model_merging.py",
2187
+ "nodes_tomesd.py",
2188
+ "nodes_clip_sdxl.py",
2189
+ "nodes_canny.py",
2190
+ "nodes_freelunch.py",
2191
+ "nodes_custom_sampler.py",
2192
+ "nodes_hypertile.py",
2193
+ "nodes_model_advanced.py",
2194
+ "nodes_model_downscale.py",
2195
+ "nodes_images.py",
2196
+ "nodes_video_model.py",
2197
+ "nodes_sag.py",
2198
+ "nodes_perpneg.py",
2199
+ "nodes_stable3d.py",
2200
+ "nodes_sdupscale.py",
2201
+ "nodes_photomaker.py",
2202
+ "nodes_pixart.py",
2203
+ "nodes_cond.py",
2204
+ "nodes_morphology.py",
2205
+ "nodes_stable_cascade.py",
2206
+ "nodes_differential_diffusion.py",
2207
+ "nodes_ip2p.py",
2208
+ "nodes_model_merging_model_specific.py",
2209
+ "nodes_pag.py",
2210
+ "nodes_align_your_steps.py",
2211
+ "nodes_attention_multiply.py",
2212
+ "nodes_advanced_samplers.py",
2213
+ "nodes_webcam.py",
2214
+ "nodes_audio.py",
2215
+ "nodes_sd3.py",
2216
+ "nodes_gits.py",
2217
+ "nodes_controlnet.py",
2218
+ "nodes_hunyuan.py",
2219
+ "nodes_flux.py",
2220
+ "nodes_lora_extract.py",
2221
+ "nodes_torch_compile.py",
2222
+ "nodes_mochi.py",
2223
+ "nodes_slg.py",
2224
+ "nodes_mahiro.py",
2225
+ "nodes_lt.py",
2226
+ "nodes_hooks.py",
2227
+ "nodes_load_3d.py",
2228
+ "nodes_cosmos.py",
2229
+ ]
2230
+
2231
+ import_failed = []
2232
+ for node_file in extras_files:
2233
+ if not load_custom_node(os.path.join(extras_dir, node_file), module_parent="comfy_extras"):
2234
+ import_failed.append(node_file)
2235
+
2236
+ return import_failed
2237
+
2238
+
2239
+ def init_extra_nodes(init_custom_nodes=True):
2240
+ import_failed = init_builtin_extra_nodes()
2241
+
2242
+ if init_custom_nodes:
2243
+ init_external_custom_nodes()
2244
+ else:
2245
+ logging.info("Skipping loading of custom nodes")
2246
+
2247
+ if len(import_failed) > 0:
2248
+ logging.warning("WARNING: some comfy_extras/ nodes did not import correctly. This may be because they are missing some dependencies.\n")
2249
+ for node in import_failed:
2250
+ logging.warning("IMPORT FAILED: {}".format(node))
2251
+ logging.warning("\nThis issue might be caused by new missing dependencies added the last time you updated ComfyUI.")
2252
+ if args.windows_standalone_build:
2253
+ logging.warning("Please run the update script: update/update_comfyui.bat")
2254
+ else:
2255
+ logging.warning("Please do a: pip install -r requirements.txt")
2256
+ logging.warning("")
2257
+
2258
+ return import_failed
server.py ADDED
@@ -0,0 +1,847 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import asyncio
4
+ import traceback
5
+
6
+ import nodes
7
+ import folder_paths
8
+ import execution
9
+ import uuid
10
+ import urllib
11
+ import json
12
+ import glob
13
+ import struct
14
+ import ssl
15
+ import socket
16
+ import ipaddress
17
+ from PIL import Image, ImageOps
18
+ from PIL.PngImagePlugin import PngInfo
19
+ from io import BytesIO
20
+
21
+ import aiohttp
22
+ from aiohttp import web
23
+ import logging
24
+
25
+ import mimetypes
26
+ from comfy.cli_args import args
27
+ import comfy.utils
28
+ import comfy.model_management
29
+ import node_helpers
30
+ from comfyui_version import __version__
31
+ from app.frontend_management import FrontendManager
32
+ from app.user_manager import UserManager
33
+ from app.model_manager import ModelFileManager
34
+ from app.custom_node_manager import CustomNodeManager
35
+ from typing import Optional
36
+ from api_server.routes.internal.internal_routes import InternalRoutes
37
+
38
+ class BinaryEventTypes:
39
+ PREVIEW_IMAGE = 1
40
+ UNENCODED_PREVIEW_IMAGE = 2
41
+
42
+ async def send_socket_catch_exception(function, message):
43
+ try:
44
+ await function(message)
45
+ except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError, BrokenPipeError, ConnectionError) as err:
46
+ logging.warning("send error: {}".format(err))
47
+
48
+ @web.middleware
49
+ async def cache_control(request: web.Request, handler):
50
+ response: web.Response = await handler(request)
51
+ if request.path.endswith('.js') or request.path.endswith('.css'):
52
+ response.headers.setdefault('Cache-Control', 'no-cache')
53
+ return response
54
+
55
+ def create_cors_middleware(allowed_origin: str):
56
+ @web.middleware
57
+ async def cors_middleware(request: web.Request, handler):
58
+ if request.method == "OPTIONS":
59
+ # Pre-flight request. Reply successfully:
60
+ response = web.Response()
61
+ else:
62
+ response = await handler(request)
63
+
64
+ response.headers['Access-Control-Allow-Origin'] = allowed_origin
65
+ response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS'
66
+ response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization'
67
+ response.headers['Access-Control-Allow-Credentials'] = 'true'
68
+ return response
69
+
70
+ return cors_middleware
71
+
72
+ def is_loopback(host):
73
+ if host is None:
74
+ return False
75
+ try:
76
+ if ipaddress.ip_address(host).is_loopback:
77
+ return True
78
+ else:
79
+ return False
80
+ except:
81
+ pass
82
+
83
+ loopback = False
84
+ for family in (socket.AF_INET, socket.AF_INET6):
85
+ try:
86
+ r = socket.getaddrinfo(host, None, family, socket.SOCK_STREAM)
87
+ for family, _, _, _, sockaddr in r:
88
+ if not ipaddress.ip_address(sockaddr[0]).is_loopback:
89
+ return loopback
90
+ else:
91
+ loopback = True
92
+ except socket.gaierror:
93
+ pass
94
+
95
+ return loopback
96
+
97
+
98
+ def create_origin_only_middleware():
99
+ @web.middleware
100
+ async def origin_only_middleware(request: web.Request, handler):
101
+ #this code is used to prevent the case where a random website can queue comfy workflows by making a POST to 127.0.0.1 which browsers don't prevent for some dumb reason.
102
+ #in that case the Host and Origin hostnames won't match
103
+ #I know the proper fix would be to add a cookie but this should take care of the problem in the meantime
104
+ if 'Host' in request.headers and 'Origin' in request.headers:
105
+ host = request.headers['Host']
106
+ origin = request.headers['Origin']
107
+ host_domain = host.lower()
108
+ parsed = urllib.parse.urlparse(origin)
109
+ origin_domain = parsed.netloc.lower()
110
+ host_domain_parsed = urllib.parse.urlsplit('//' + host_domain)
111
+
112
+ #limit the check to when the host domain is localhost, this makes it slightly less safe but should still prevent the exploit
113
+ loopback = is_loopback(host_domain_parsed.hostname)
114
+
115
+ if parsed.port is None: #if origin doesn't have a port strip it from the host to handle weird browsers, same for host
116
+ host_domain = host_domain_parsed.hostname
117
+ if host_domain_parsed.port is None:
118
+ origin_domain = parsed.hostname
119
+
120
+ if loopback and host_domain is not None and origin_domain is not None and len(host_domain) > 0 and len(origin_domain) > 0:
121
+ if host_domain != origin_domain:
122
+ logging.warning("WARNING: request with non matching host and origin {} != {}, returning 403".format(host_domain, origin_domain))
123
+ return web.Response(status=403)
124
+
125
+ if request.method == "OPTIONS":
126
+ response = web.Response()
127
+ else:
128
+ response = await handler(request)
129
+
130
+ return response
131
+
132
+ return origin_only_middleware
133
+
134
+ class PromptServer():
135
+ def __init__(self, loop):
136
+ PromptServer.instance = self
137
+
138
+ mimetypes.init()
139
+ mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8'
140
+
141
+ self.user_manager = UserManager()
142
+ self.model_file_manager = ModelFileManager()
143
+ self.custom_node_manager = CustomNodeManager()
144
+ self.internal_routes = InternalRoutes(self)
145
+ self.supports = ["custom_nodes_from_web"]
146
+ self.prompt_queue = None
147
+ self.loop = loop
148
+ self.messages = asyncio.Queue()
149
+ self.client_session:Optional[aiohttp.ClientSession] = None
150
+ self.number = 0
151
+
152
+ middlewares = [cache_control]
153
+ if args.enable_cors_header:
154
+ middlewares.append(create_cors_middleware(args.enable_cors_header))
155
+ else:
156
+ middlewares.append(create_origin_only_middleware())
157
+
158
+ max_upload_size = round(args.max_upload_size * 1024 * 1024)
159
+ self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
160
+ self.sockets = dict()
161
+ self.web_root = (
162
+ FrontendManager.init_frontend(args.front_end_version)
163
+ if args.front_end_root is None
164
+ else args.front_end_root
165
+ )
166
+ logging.info(f"[Prompt Server] web root: {self.web_root}")
167
+ routes = web.RouteTableDef()
168
+ self.routes = routes
169
+ self.last_node_id = None
170
+ self.client_id = None
171
+
172
+ self.on_prompt_handlers = []
173
+
174
+ @routes.get('/ws')
175
+ async def websocket_handler(request):
176
+ ws = web.WebSocketResponse()
177
+ await ws.prepare(request)
178
+ sid = request.rel_url.query.get('clientId', '')
179
+ if sid:
180
+ # Reusing existing session, remove old
181
+ self.sockets.pop(sid, None)
182
+ else:
183
+ sid = uuid.uuid4().hex
184
+
185
+ self.sockets[sid] = ws
186
+
187
+ try:
188
+ # Send initial state to the new client
189
+ await self.send("status", { "status": self.get_queue_info(), 'sid': sid }, sid)
190
+ # On reconnect if we are the currently executing client send the current node
191
+ if self.client_id == sid and self.last_node_id is not None:
192
+ await self.send("executing", { "node": self.last_node_id }, sid)
193
+
194
+ async for msg in ws:
195
+ if msg.type == aiohttp.WSMsgType.ERROR:
196
+ logging.warning('ws connection closed with exception %s' % ws.exception())
197
+ finally:
198
+ self.sockets.pop(sid, None)
199
+ return ws
200
+
201
+ @routes.get("/")
202
+ async def get_root(request):
203
+ response = web.FileResponse(os.path.join(self.web_root, "index.html"))
204
+ response.headers['Cache-Control'] = 'no-cache'
205
+ response.headers["Pragma"] = "no-cache"
206
+ response.headers["Expires"] = "0"
207
+ return response
208
+
209
+ @routes.get("/embeddings")
210
+ def get_embeddings(self):
211
+ embeddings = folder_paths.get_filename_list("embeddings")
212
+ return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings)))
213
+
214
+ @routes.get("/models")
215
+ def list_model_types(request):
216
+ model_types = list(folder_paths.folder_names_and_paths.keys())
217
+
218
+ return web.json_response(model_types)
219
+
220
+ @routes.get("/models/{folder}")
221
+ async def get_models(request):
222
+ folder = request.match_info.get("folder", None)
223
+ if not folder in folder_paths.folder_names_and_paths:
224
+ return web.Response(status=404)
225
+ files = folder_paths.get_filename_list(folder)
226
+ return web.json_response(files)
227
+
228
+ @routes.get("/extensions")
229
+ async def get_extensions(request):
230
+ files = glob.glob(os.path.join(
231
+ glob.escape(self.web_root), 'extensions/**/*.js'), recursive=True)
232
+
233
+ extensions = list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files))
234
+
235
+ for name, dir in nodes.EXTENSION_WEB_DIRS.items():
236
+ files = glob.glob(os.path.join(glob.escape(dir), '**/*.js'), recursive=True)
237
+ extensions.extend(list(map(lambda f: "/extensions/" + urllib.parse.quote(
238
+ name) + "/" + os.path.relpath(f, dir).replace("\\", "/"), files)))
239
+
240
+ return web.json_response(extensions)
241
+
242
+ def get_dir_by_type(dir_type):
243
+ if dir_type is None:
244
+ dir_type = "input"
245
+
246
+ if dir_type == "input":
247
+ type_dir = folder_paths.get_input_directory()
248
+ elif dir_type == "temp":
249
+ type_dir = folder_paths.get_temp_directory()
250
+ elif dir_type == "output":
251
+ type_dir = folder_paths.get_output_directory()
252
+
253
+ return type_dir, dir_type
254
+
255
+ def compare_image_hash(filepath, image):
256
+ hasher = node_helpers.hasher()
257
+
258
+ # function to compare hashes of two images to see if it already exists, fix to #3465
259
+ if os.path.exists(filepath):
260
+ a = hasher()
261
+ b = hasher()
262
+ with open(filepath, "rb") as f:
263
+ a.update(f.read())
264
+ b.update(image.file.read())
265
+ image.file.seek(0)
266
+ f.close()
267
+ return a.hexdigest() == b.hexdigest()
268
+ return False
269
+
270
+ def image_upload(post, image_save_function=None):
271
+ image = post.get("image")
272
+ overwrite = post.get("overwrite")
273
+ image_is_duplicate = False
274
+
275
+ image_upload_type = post.get("type")
276
+ upload_dir, image_upload_type = get_dir_by_type(image_upload_type)
277
+
278
+ if image and image.file:
279
+ filename = image.filename
280
+ if not filename:
281
+ return web.Response(status=400)
282
+
283
+ subfolder = post.get("subfolder", "")
284
+ full_output_folder = os.path.join(upload_dir, os.path.normpath(subfolder))
285
+ filepath = os.path.abspath(os.path.join(full_output_folder, filename))
286
+
287
+ if os.path.commonpath((upload_dir, filepath)) != upload_dir:
288
+ return web.Response(status=400)
289
+
290
+ if not os.path.exists(full_output_folder):
291
+ os.makedirs(full_output_folder)
292
+
293
+ split = os.path.splitext(filename)
294
+
295
+ if overwrite is not None and (overwrite == "true" or overwrite == "1"):
296
+ pass
297
+ else:
298
+ i = 1
299
+ while os.path.exists(filepath):
300
+ if compare_image_hash(filepath, image): #compare hash to prevent saving of duplicates with same name, fix for #3465
301
+ image_is_duplicate = True
302
+ break
303
+ filename = f"{split[0]} ({i}){split[1]}"
304
+ filepath = os.path.join(full_output_folder, filename)
305
+ i += 1
306
+
307
+ if not image_is_duplicate:
308
+ if image_save_function is not None:
309
+ image_save_function(image, post, filepath)
310
+ else:
311
+ with open(filepath, "wb") as f:
312
+ f.write(image.file.read())
313
+
314
+ return web.json_response({"name" : filename, "subfolder": subfolder, "type": image_upload_type})
315
+ else:
316
+ return web.Response(status=400)
317
+
318
+ @routes.post("/upload/image")
319
+ async def upload_image(request):
320
+ post = await request.post()
321
+ return image_upload(post)
322
+
323
+
324
+ @routes.post("/upload/mask")
325
+ async def upload_mask(request):
326
+ post = await request.post()
327
+
328
+ def image_save_function(image, post, filepath):
329
+ original_ref = json.loads(post.get("original_ref"))
330
+ filename, output_dir = folder_paths.annotated_filepath(original_ref['filename'])
331
+
332
+ if not filename:
333
+ return web.Response(status=400)
334
+
335
+ # validation for security: prevent accessing arbitrary path
336
+ if filename[0] == '/' or '..' in filename:
337
+ return web.Response(status=400)
338
+
339
+ if output_dir is None:
340
+ type = original_ref.get("type", "output")
341
+ output_dir = folder_paths.get_directory_by_type(type)
342
+
343
+ if output_dir is None:
344
+ return web.Response(status=400)
345
+
346
+ if original_ref.get("subfolder", "") != "":
347
+ full_output_dir = os.path.join(output_dir, original_ref["subfolder"])
348
+ if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
349
+ return web.Response(status=403)
350
+ output_dir = full_output_dir
351
+
352
+ file = os.path.join(output_dir, filename)
353
+
354
+ if os.path.isfile(file):
355
+ with Image.open(file) as original_pil:
356
+ metadata = PngInfo()
357
+ if hasattr(original_pil,'text'):
358
+ for key in original_pil.text:
359
+ metadata.add_text(key, original_pil.text[key])
360
+ original_pil = original_pil.convert('RGBA')
361
+ mask_pil = Image.open(image.file).convert('RGBA')
362
+
363
+ # alpha copy
364
+ new_alpha = mask_pil.getchannel('A')
365
+ original_pil.putalpha(new_alpha)
366
+ original_pil.save(filepath, compress_level=4, pnginfo=metadata)
367
+
368
+ return image_upload(post, image_save_function)
369
+
370
+ @routes.get("/view")
371
+ async def view_image(request):
372
+ if "filename" in request.rel_url.query:
373
+ filename = request.rel_url.query["filename"]
374
+ filename,output_dir = folder_paths.annotated_filepath(filename)
375
+
376
+ if not filename:
377
+ return web.Response(status=400)
378
+
379
+ # validation for security: prevent accessing arbitrary path
380
+ if filename[0] == '/' or '..' in filename:
381
+ return web.Response(status=400)
382
+
383
+ if output_dir is None:
384
+ type = request.rel_url.query.get("type", "output")
385
+ output_dir = folder_paths.get_directory_by_type(type)
386
+
387
+ if output_dir is None:
388
+ return web.Response(status=400)
389
+
390
+ if "subfolder" in request.rel_url.query:
391
+ full_output_dir = os.path.join(output_dir, request.rel_url.query["subfolder"])
392
+ if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
393
+ return web.Response(status=403)
394
+ output_dir = full_output_dir
395
+
396
+ filename = os.path.basename(filename)
397
+ file = os.path.join(output_dir, filename)
398
+
399
+ if os.path.isfile(file):
400
+ if 'preview' in request.rel_url.query:
401
+ with Image.open(file) as img:
402
+ preview_info = request.rel_url.query['preview'].split(';')
403
+ image_format = preview_info[0]
404
+ if image_format not in ['webp', 'jpeg'] or 'a' in request.rel_url.query.get('channel', ''):
405
+ image_format = 'webp'
406
+
407
+ quality = 90
408
+ if preview_info[-1].isdigit():
409
+ quality = int(preview_info[-1])
410
+
411
+ buffer = BytesIO()
412
+ if image_format in ['jpeg'] or request.rel_url.query.get('channel', '') == 'rgb':
413
+ img = img.convert("RGB")
414
+ img.save(buffer, format=image_format, quality=quality)
415
+ buffer.seek(0)
416
+
417
+ return web.Response(body=buffer.read(), content_type=f'image/{image_format}',
418
+ headers={"Content-Disposition": f"filename=\"{filename}\""})
419
+
420
+ if 'channel' not in request.rel_url.query:
421
+ channel = 'rgba'
422
+ else:
423
+ channel = request.rel_url.query["channel"]
424
+
425
+ if channel == 'rgb':
426
+ with Image.open(file) as img:
427
+ if img.mode == "RGBA":
428
+ r, g, b, a = img.split()
429
+ new_img = Image.merge('RGB', (r, g, b))
430
+ else:
431
+ new_img = img.convert("RGB")
432
+
433
+ buffer = BytesIO()
434
+ new_img.save(buffer, format='PNG')
435
+ buffer.seek(0)
436
+
437
+ return web.Response(body=buffer.read(), content_type='image/png',
438
+ headers={"Content-Disposition": f"filename=\"{filename}\""})
439
+
440
+ elif channel == 'a':
441
+ with Image.open(file) as img:
442
+ if img.mode == "RGBA":
443
+ _, _, _, a = img.split()
444
+ else:
445
+ a = Image.new('L', img.size, 255)
446
+
447
+ # alpha img
448
+ alpha_img = Image.new('RGBA', img.size)
449
+ alpha_img.putalpha(a)
450
+ alpha_buffer = BytesIO()
451
+ alpha_img.save(alpha_buffer, format='PNG')
452
+ alpha_buffer.seek(0)
453
+
454
+ return web.Response(body=alpha_buffer.read(), content_type='image/png',
455
+ headers={"Content-Disposition": f"filename=\"{filename}\""})
456
+ else:
457
+ # Get content type from mimetype, defaulting to 'application/octet-stream'
458
+ content_type = mimetypes.guess_type(filename)[0] or 'application/octet-stream'
459
+
460
+ # For security, force certain extensions to download instead of display
461
+ file_extension = os.path.splitext(filename)[1].lower()
462
+ if file_extension in {'.html', '.htm', '.js', '.css'}:
463
+ content_type = 'application/octet-stream' # Forces download
464
+
465
+ return web.FileResponse(
466
+ file,
467
+ headers={
468
+ "Content-Disposition": f"filename=\"{filename}\"",
469
+ "Content-Type": content_type
470
+ }
471
+ )
472
+
473
+ return web.Response(status=404)
474
+
475
+ @routes.get("/view_metadata/{folder_name}")
476
+ async def view_metadata(request):
477
+ folder_name = request.match_info.get("folder_name", None)
478
+ if folder_name is None:
479
+ return web.Response(status=404)
480
+ if not "filename" in request.rel_url.query:
481
+ return web.Response(status=404)
482
+
483
+ filename = request.rel_url.query["filename"]
484
+ if not filename.endswith(".safetensors"):
485
+ return web.Response(status=404)
486
+
487
+ safetensors_path = folder_paths.get_full_path(folder_name, filename)
488
+ if safetensors_path is None:
489
+ return web.Response(status=404)
490
+ out = comfy.utils.safetensors_header(safetensors_path, max_size=1024*1024)
491
+ if out is None:
492
+ return web.Response(status=404)
493
+ dt = json.loads(out)
494
+ if not "__metadata__" in dt:
495
+ return web.Response(status=404)
496
+ return web.json_response(dt["__metadata__"])
497
+
498
+ @routes.get("/system_stats")
499
+ async def system_stats(request):
500
+ device = comfy.model_management.get_torch_device()
501
+ device_name = comfy.model_management.get_torch_device_name(device)
502
+ cpu_device = comfy.model_management.torch.device("cpu")
503
+ ram_total = comfy.model_management.get_total_memory(cpu_device)
504
+ ram_free = comfy.model_management.get_free_memory(cpu_device)
505
+ vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
506
+ vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
507
+
508
+ system_stats = {
509
+ "system": {
510
+ "os": os.name,
511
+ "ram_total": ram_total,
512
+ "ram_free": ram_free,
513
+ "comfyui_version": __version__,
514
+ "python_version": sys.version,
515
+ "pytorch_version": comfy.model_management.torch_version,
516
+ "embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded",
517
+ "argv": sys.argv
518
+ },
519
+ "devices": [
520
+ {
521
+ "name": device_name,
522
+ "type": device.type,
523
+ "index": device.index,
524
+ "vram_total": vram_total,
525
+ "vram_free": vram_free,
526
+ "torch_vram_total": torch_vram_total,
527
+ "torch_vram_free": torch_vram_free,
528
+ }
529
+ ]
530
+ }
531
+ return web.json_response(system_stats)
532
+
533
+ @routes.get("/prompt")
534
+ async def get_prompt(request):
535
+ return web.json_response(self.get_queue_info())
536
+
537
+ def node_info(node_class):
538
+ obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
539
+ info = {}
540
+ info['input'] = obj_class.INPUT_TYPES()
541
+ info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()}
542
+ info['output'] = obj_class.RETURN_TYPES
543
+ info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
544
+ info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
545
+ info['name'] = node_class
546
+ info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class
547
+ info['description'] = obj_class.DESCRIPTION if hasattr(obj_class,'DESCRIPTION') else ''
548
+ info['python_module'] = getattr(obj_class, "RELATIVE_PYTHON_MODULE", "nodes")
549
+ info['category'] = 'sd'
550
+ if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True:
551
+ info['output_node'] = True
552
+ else:
553
+ info['output_node'] = False
554
+
555
+ if hasattr(obj_class, 'CATEGORY'):
556
+ info['category'] = obj_class.CATEGORY
557
+
558
+ if hasattr(obj_class, 'OUTPUT_TOOLTIPS'):
559
+ info['output_tooltips'] = obj_class.OUTPUT_TOOLTIPS
560
+
561
+ if getattr(obj_class, "DEPRECATED", False):
562
+ info['deprecated'] = True
563
+ if getattr(obj_class, "EXPERIMENTAL", False):
564
+ info['experimental'] = True
565
+ return info
566
+
567
+ @routes.get("/object_info")
568
+ async def get_object_info(request):
569
+ with folder_paths.cache_helper:
570
+ out = {}
571
+ for x in nodes.NODE_CLASS_MAPPINGS:
572
+ try:
573
+ out[x] = node_info(x)
574
+ except Exception:
575
+ logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
576
+ logging.error(traceback.format_exc())
577
+ return web.json_response(out)
578
+
579
+ @routes.get("/object_info/{node_class}")
580
+ async def get_object_info_node(request):
581
+ node_class = request.match_info.get("node_class", None)
582
+ out = {}
583
+ if (node_class is not None) and (node_class in nodes.NODE_CLASS_MAPPINGS):
584
+ out[node_class] = node_info(node_class)
585
+ return web.json_response(out)
586
+
587
+ @routes.get("/history")
588
+ async def get_history(request):
589
+ max_items = request.rel_url.query.get("max_items", None)
590
+ if max_items is not None:
591
+ max_items = int(max_items)
592
+ return web.json_response(self.prompt_queue.get_history(max_items=max_items))
593
+
594
+ @routes.get("/history/{prompt_id}")
595
+ async def get_history_prompt_id(request):
596
+ prompt_id = request.match_info.get("prompt_id", None)
597
+ return web.json_response(self.prompt_queue.get_history(prompt_id=prompt_id))
598
+
599
+ @routes.get("/queue")
600
+ async def get_queue(request):
601
+ queue_info = {}
602
+ current_queue = self.prompt_queue.get_current_queue()
603
+ queue_info['queue_running'] = current_queue[0]
604
+ queue_info['queue_pending'] = current_queue[1]
605
+ return web.json_response(queue_info)
606
+
607
+ @routes.post("/prompt")
608
+ async def post_prompt(request):
609
+ logging.info("got prompt")
610
+ json_data = await request.json()
611
+ json_data = self.trigger_on_prompt(json_data)
612
+
613
+ if "number" in json_data:
614
+ number = float(json_data['number'])
615
+ else:
616
+ number = self.number
617
+ if "front" in json_data:
618
+ if json_data['front']:
619
+ number = -number
620
+
621
+ self.number += 1
622
+
623
+ if "prompt" in json_data:
624
+ prompt = json_data["prompt"]
625
+ valid = execution.validate_prompt(prompt)
626
+ extra_data = {}
627
+ if "extra_data" in json_data:
628
+ extra_data = json_data["extra_data"]
629
+
630
+ if "client_id" in json_data:
631
+ extra_data["client_id"] = json_data["client_id"]
632
+ if valid[0]:
633
+ prompt_id = str(uuid.uuid4())
634
+ outputs_to_execute = valid[2]
635
+ self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
636
+ response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
637
+ return web.json_response(response)
638
+ else:
639
+ logging.warning("invalid prompt: {}".format(valid[1]))
640
+ return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400)
641
+ else:
642
+ return web.json_response({"error": "no prompt", "node_errors": []}, status=400)
643
+
644
+ @routes.post("/queue")
645
+ async def post_queue(request):
646
+ json_data = await request.json()
647
+ if "clear" in json_data:
648
+ if json_data["clear"]:
649
+ self.prompt_queue.wipe_queue()
650
+ if "delete" in json_data:
651
+ to_delete = json_data['delete']
652
+ for id_to_delete in to_delete:
653
+ delete_func = lambda a: a[1] == id_to_delete
654
+ self.prompt_queue.delete_queue_item(delete_func)
655
+
656
+ return web.Response(status=200)
657
+
658
+ @routes.post("/interrupt")
659
+ async def post_interrupt(request):
660
+ nodes.interrupt_processing()
661
+ return web.Response(status=200)
662
+
663
+ @routes.post("/free")
664
+ async def post_free(request):
665
+ json_data = await request.json()
666
+ unload_models = json_data.get("unload_models", False)
667
+ free_memory = json_data.get("free_memory", False)
668
+ if unload_models:
669
+ self.prompt_queue.set_flag("unload_models", unload_models)
670
+ if free_memory:
671
+ self.prompt_queue.set_flag("free_memory", free_memory)
672
+ return web.Response(status=200)
673
+
674
+ @routes.post("/history")
675
+ async def post_history(request):
676
+ json_data = await request.json()
677
+ if "clear" in json_data:
678
+ if json_data["clear"]:
679
+ self.prompt_queue.wipe_history()
680
+ if "delete" in json_data:
681
+ to_delete = json_data['delete']
682
+ for id_to_delete in to_delete:
683
+ self.prompt_queue.delete_history_item(id_to_delete)
684
+
685
+ return web.Response(status=200)
686
+
687
+ async def setup(self):
688
+ timeout = aiohttp.ClientTimeout(total=None) # no timeout
689
+ self.client_session = aiohttp.ClientSession(timeout=timeout)
690
+
691
+ def add_routes(self):
692
+ self.user_manager.add_routes(self.routes)
693
+ self.model_file_manager.add_routes(self.routes)
694
+ self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items())
695
+ self.app.add_subapp('/internal', self.internal_routes.get_app())
696
+
697
+ # Prefix every route with /api for easier matching for delegation.
698
+ # This is very useful for frontend dev server, which need to forward
699
+ # everything except serving of static files.
700
+ # Currently both the old endpoints without prefix and new endpoints with
701
+ # prefix are supported.
702
+ api_routes = web.RouteTableDef()
703
+ for route in self.routes:
704
+ # Custom nodes might add extra static routes. Only process non-static
705
+ # routes to add /api prefix.
706
+ if isinstance(route, web.RouteDef):
707
+ api_routes.route(route.method, "/api" + route.path)(route.handler, **route.kwargs)
708
+ self.app.add_routes(api_routes)
709
+ self.app.add_routes(self.routes)
710
+
711
+ # Add routes from web extensions.
712
+ for name, dir in nodes.EXTENSION_WEB_DIRS.items():
713
+ self.app.add_routes([web.static('/extensions/' + name, dir)])
714
+
715
+ self.app.add_routes([
716
+ web.static('/', self.web_root),
717
+ ])
718
+
719
+ def get_queue_info(self):
720
+ prompt_info = {}
721
+ exec_info = {}
722
+ exec_info['queue_remaining'] = self.prompt_queue.get_tasks_remaining()
723
+ prompt_info['exec_info'] = exec_info
724
+ return prompt_info
725
+
726
+ async def send(self, event, data, sid=None):
727
+ if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
728
+ await self.send_image(data, sid=sid)
729
+ elif isinstance(data, (bytes, bytearray)):
730
+ await self.send_bytes(event, data, sid)
731
+ else:
732
+ await self.send_json(event, data, sid)
733
+
734
+ def encode_bytes(self, event, data):
735
+ if not isinstance(event, int):
736
+ raise RuntimeError(f"Binary event types must be integers, got {event}")
737
+
738
+ packed = struct.pack(">I", event)
739
+ message = bytearray(packed)
740
+ message.extend(data)
741
+ return message
742
+
743
+ async def send_image(self, image_data, sid=None):
744
+ image_type = image_data[0]
745
+ image = image_data[1]
746
+ max_size = image_data[2]
747
+ if max_size is not None:
748
+ if hasattr(Image, 'Resampling'):
749
+ resampling = Image.Resampling.BILINEAR
750
+ else:
751
+ resampling = Image.ANTIALIAS
752
+
753
+ image = ImageOps.contain(image, (max_size, max_size), resampling)
754
+ type_num = 1
755
+ if image_type == "JPEG":
756
+ type_num = 1
757
+ elif image_type == "PNG":
758
+ type_num = 2
759
+
760
+ bytesIO = BytesIO()
761
+ header = struct.pack(">I", type_num)
762
+ bytesIO.write(header)
763
+ image.save(bytesIO, format=image_type, quality=95, compress_level=1)
764
+ preview_bytes = bytesIO.getvalue()
765
+ await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)
766
+
767
+ async def send_bytes(self, event, data, sid=None):
768
+ message = self.encode_bytes(event, data)
769
+
770
+ if sid is None:
771
+ sockets = list(self.sockets.values())
772
+ for ws in sockets:
773
+ await send_socket_catch_exception(ws.send_bytes, message)
774
+ elif sid in self.sockets:
775
+ await send_socket_catch_exception(self.sockets[sid].send_bytes, message)
776
+
777
+ async def send_json(self, event, data, sid=None):
778
+ message = {"type": event, "data": data}
779
+
780
+ if sid is None:
781
+ sockets = list(self.sockets.values())
782
+ for ws in sockets:
783
+ await send_socket_catch_exception(ws.send_json, message)
784
+ elif sid in self.sockets:
785
+ await send_socket_catch_exception(self.sockets[sid].send_json, message)
786
+
787
+ def send_sync(self, event, data, sid=None):
788
+ self.loop.call_soon_threadsafe(
789
+ self.messages.put_nowait, (event, data, sid))
790
+
791
+ def queue_updated(self):
792
+ self.send_sync("status", { "status": self.get_queue_info() })
793
+
794
+ async def publish_loop(self):
795
+ while True:
796
+ msg = await self.messages.get()
797
+ await self.send(*msg)
798
+
799
+ async def start(self, address, port, verbose=True, call_on_start=None):
800
+ await self.start_multi_address([(address, port)], call_on_start=call_on_start)
801
+
802
+ async def start_multi_address(self, addresses, call_on_start=None, verbose=True):
803
+ runner = web.AppRunner(self.app, access_log=None)
804
+ await runner.setup()
805
+ ssl_ctx = None
806
+ scheme = "http"
807
+ if args.tls_keyfile and args.tls_certfile:
808
+ ssl_ctx = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_SERVER, verify_mode=ssl.CERT_NONE)
809
+ ssl_ctx.load_cert_chain(certfile=args.tls_certfile,
810
+ keyfile=args.tls_keyfile)
811
+ scheme = "https"
812
+
813
+ if verbose:
814
+ logging.info("Starting server\n")
815
+ for addr in addresses:
816
+ address = addr[0]
817
+ port = addr[1]
818
+ site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx)
819
+ await site.start()
820
+
821
+ if not hasattr(self, 'address'):
822
+ self.address = address #TODO: remove this
823
+ self.port = port
824
+
825
+ if ':' in address:
826
+ address_print = "[{}]".format(address)
827
+ else:
828
+ address_print = address
829
+
830
+ if verbose:
831
+ logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address_print, port))
832
+
833
+ if call_on_start is not None:
834
+ call_on_start(scheme, self.address, self.port)
835
+
836
+ def add_on_prompt_handler(self, handler):
837
+ self.on_prompt_handlers.append(handler)
838
+
839
+ def trigger_on_prompt(self, json_data):
840
+ for handler in self.on_prompt_handlers:
841
+ try:
842
+ json_data = handler(json_data)
843
+ except Exception:
844
+ logging.warning("[ERROR] An error occurred during the on_prompt_handler processing")
845
+ logging.warning(traceback.format_exc())
846
+
847
+ return json_data
tests-unit/README.md ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Pytest Unit Tests
2
+
3
+ ## Install test dependencies
4
+
5
+ `pip install -r tests-unit/requirements.txt`
6
+
7
+ ## Run tests
8
+ `pytest tests-unit/`
tests-unit/app_test/__init__.py ADDED
File without changes
tests-unit/app_test/custom_node_manager_test.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ from aiohttp import web
3
+ from unittest.mock import patch
4
+ from app.custom_node_manager import CustomNodeManager
5
+
6
+ pytestmark = (
7
+ pytest.mark.asyncio
8
+ ) # This applies the asyncio mark to all test functions in the module
9
+
10
+ @pytest.fixture
11
+ def custom_node_manager():
12
+ return CustomNodeManager()
13
+
14
+ @pytest.fixture
15
+ def app(custom_node_manager):
16
+ app = web.Application()
17
+ routes = web.RouteTableDef()
18
+ custom_node_manager.add_routes(routes, app, [("ComfyUI-TestExtension1", "ComfyUI-TestExtension1")])
19
+ app.add_routes(routes)
20
+ return app
21
+
22
+ async def test_get_workflow_templates(aiohttp_client, app, tmp_path):
23
+ client = await aiohttp_client(app)
24
+ # Setup temporary custom nodes file structure with 1 workflow file
25
+ custom_nodes_dir = tmp_path / "custom_nodes"
26
+ example_workflows_dir = custom_nodes_dir / "ComfyUI-TestExtension1" / "example_workflows"
27
+ example_workflows_dir.mkdir(parents=True)
28
+ template_file = example_workflows_dir / "workflow1.json"
29
+ template_file.write_text('')
30
+
31
+ with patch('folder_paths.folder_names_and_paths', {
32
+ 'custom_nodes': ([str(custom_nodes_dir)], None)
33
+ }):
34
+ response = await client.get('/workflow_templates')
35
+ assert response.status == 200
36
+ workflows_dict = await response.json()
37
+ assert isinstance(workflows_dict, dict)
38
+ assert "ComfyUI-TestExtension1" in workflows_dict
39
+ assert isinstance(workflows_dict["ComfyUI-TestExtension1"], list)
40
+ assert workflows_dict["ComfyUI-TestExtension1"][0] == "workflow1"
tests-unit/app_test/frontend_manager_test.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import pytest
3
+ from requests.exceptions import HTTPError
4
+ from unittest.mock import patch
5
+
6
+ from app.frontend_management import (
7
+ FrontendManager,
8
+ FrontEndProvider,
9
+ Release,
10
+ )
11
+ from comfy.cli_args import DEFAULT_VERSION_STRING
12
+
13
+
14
+ @pytest.fixture
15
+ def mock_releases():
16
+ return [
17
+ Release(
18
+ id=1,
19
+ tag_name="1.0.0",
20
+ name="Release 1.0.0",
21
+ prerelease=False,
22
+ created_at="2022-01-01T00:00:00Z",
23
+ published_at="2022-01-01T00:00:00Z",
24
+ body="Release notes for 1.0.0",
25
+ assets=[{"name": "dist.zip", "url": "https://example.com/dist.zip"}],
26
+ ),
27
+ Release(
28
+ id=2,
29
+ tag_name="2.0.0",
30
+ name="Release 2.0.0",
31
+ prerelease=False,
32
+ created_at="2022-02-01T00:00:00Z",
33
+ published_at="2022-02-01T00:00:00Z",
34
+ body="Release notes for 2.0.0",
35
+ assets=[{"name": "dist.zip", "url": "https://example.com/dist.zip"}],
36
+ ),
37
+ ]
38
+
39
+
40
+ @pytest.fixture
41
+ def mock_provider(mock_releases):
42
+ provider = FrontEndProvider(
43
+ owner="test-owner",
44
+ repo="test-repo",
45
+ )
46
+ provider.all_releases = mock_releases
47
+ provider.latest_release = mock_releases[1]
48
+ FrontendManager.PROVIDERS = [provider]
49
+ return provider
50
+
51
+
52
+ def test_get_release(mock_provider, mock_releases):
53
+ version = "1.0.0"
54
+ release = mock_provider.get_release(version)
55
+ assert release == mock_releases[0]
56
+
57
+
58
+ def test_get_release_latest(mock_provider, mock_releases):
59
+ version = "latest"
60
+ release = mock_provider.get_release(version)
61
+ assert release == mock_releases[1]
62
+
63
+
64
+ def test_get_release_invalid_version(mock_provider):
65
+ version = "invalid"
66
+ with pytest.raises(ValueError):
67
+ mock_provider.get_release(version)
68
+
69
+
70
+ def test_init_frontend_default():
71
+ version_string = DEFAULT_VERSION_STRING
72
+ frontend_path = FrontendManager.init_frontend(version_string)
73
+ assert frontend_path == FrontendManager.DEFAULT_FRONTEND_PATH
74
+
75
+
76
+ def test_init_frontend_invalid_version():
77
+ version_string = "test-owner/[email protected]"
78
+ with pytest.raises(HTTPError):
79
+ FrontendManager.init_frontend_unsafe(version_string)
80
+
81
+
82
+ def test_init_frontend_invalid_provider():
83
+ version_string = "invalid/invalid@latest"
84
+ with pytest.raises(HTTPError):
85
+ FrontendManager.init_frontend_unsafe(version_string)
86
+
87
+ @pytest.fixture
88
+ def mock_os_functions():
89
+ with patch('app.frontend_management.os.makedirs') as mock_makedirs, \
90
+ patch('app.frontend_management.os.listdir') as mock_listdir, \
91
+ patch('app.frontend_management.os.rmdir') as mock_rmdir:
92
+ mock_listdir.return_value = [] # Simulate empty directory
93
+ yield mock_makedirs, mock_listdir, mock_rmdir
94
+
95
+ @pytest.fixture
96
+ def mock_download():
97
+ with patch('app.frontend_management.download_release_asset_zip') as mock:
98
+ mock.side_effect = Exception("Download failed") # Simulate download failure
99
+ yield mock
100
+
101
+ def test_finally_block(mock_os_functions, mock_download, mock_provider):
102
+ # Arrange
103
+ mock_makedirs, mock_listdir, mock_rmdir = mock_os_functions
104
+ version_string = 'test-owner/[email protected]'
105
+
106
+ # Act & Assert
107
+ with pytest.raises(Exception):
108
+ FrontendManager.init_frontend_unsafe(version_string, mock_provider)
109
+
110
+ # Assert
111
+ mock_makedirs.assert_called_once()
112
+ mock_download.assert_called_once()
113
+ mock_listdir.assert_called_once()
114
+ mock_rmdir.assert_called_once()
115
+
116
+
117
+ def test_parse_version_string():
118
+ version_string = "owner/[email protected]"
119
+ repo_owner, repo_name, version = FrontendManager.parse_version_string(
120
+ version_string
121
+ )
122
+ assert repo_owner == "owner"
123
+ assert repo_name == "repo"
124
+ assert version == "1.0.0"
125
+
126
+
127
+ def test_parse_version_string_invalid():
128
+ version_string = "invalid"
129
+ with pytest.raises(argparse.ArgumentTypeError):
130
+ FrontendManager.parse_version_string(version_string)
tests-unit/app_test/model_manager_test.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ import base64
3
+ import json
4
+ import struct
5
+ from io import BytesIO
6
+ from PIL import Image
7
+ from aiohttp import web
8
+ from unittest.mock import patch
9
+ from app.model_manager import ModelFileManager
10
+
11
+ pytestmark = (
12
+ pytest.mark.asyncio
13
+ ) # This applies the asyncio mark to all test functions in the module
14
+
15
+ @pytest.fixture
16
+ def model_manager():
17
+ return ModelFileManager()
18
+
19
+ @pytest.fixture
20
+ def app(model_manager):
21
+ app = web.Application()
22
+ routes = web.RouteTableDef()
23
+ model_manager.add_routes(routes)
24
+ app.add_routes(routes)
25
+ return app
26
+
27
+ async def test_get_model_preview_safetensors(aiohttp_client, app, tmp_path):
28
+ img = Image.new('RGB', (100, 100), 'white')
29
+ img_byte_arr = BytesIO()
30
+ img.save(img_byte_arr, format='PNG')
31
+ img_byte_arr.seek(0)
32
+ img_b64 = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8')
33
+
34
+ safetensors_file = tmp_path / "test_model.safetensors"
35
+ header_bytes = json.dumps({
36
+ "__metadata__": {
37
+ "ssmd_cover_images": json.dumps([img_b64])
38
+ }
39
+ }).encode('utf-8')
40
+ length_bytes = struct.pack('<Q', len(header_bytes))
41
+ with open(safetensors_file, 'wb') as f:
42
+ f.write(length_bytes)
43
+ f.write(header_bytes)
44
+
45
+ with patch('folder_paths.folder_names_and_paths', {
46
+ 'test_folder': ([str(tmp_path)], None)
47
+ }):
48
+ client = await aiohttp_client(app)
49
+ response = await client.get('/experiment/models/preview/test_folder/0/test_model.safetensors')
50
+
51
+ # Verify response
52
+ assert response.status == 200
53
+ assert response.content_type == 'image/webp'
54
+
55
+ # Verify the response contains valid image data
56
+ img_bytes = BytesIO(await response.read())
57
+ img = Image.open(img_bytes)
58
+ assert img.format
59
+ assert img.format.lower() == 'webp'
60
+
61
+ # Clean up
62
+ img.close()
tests-unit/comfy_test/folder_path_test.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### 🗻 This file is created through the spirit of Mount Fuji at its peak
2
+ # TODO(yoland): clean up this after I get back down
3
+ import pytest
4
+ import os
5
+ import tempfile
6
+ from unittest.mock import patch
7
+
8
+ import folder_paths
9
+
10
+ @pytest.fixture()
11
+ def clear_folder_paths():
12
+ # Clear the global dictionary before each test to ensure isolation
13
+ original = folder_paths.folder_names_and_paths.copy()
14
+ folder_paths.folder_names_and_paths.clear()
15
+ yield
16
+ folder_paths.folder_names_and_paths = original
17
+
18
+ @pytest.fixture
19
+ def temp_dir():
20
+ with tempfile.TemporaryDirectory() as tmpdirname:
21
+ yield tmpdirname
22
+
23
+
24
+ def test_get_directory_by_type():
25
+ test_dir = "/test/dir"
26
+ folder_paths.set_output_directory(test_dir)
27
+ assert folder_paths.get_directory_by_type("output") == test_dir
28
+ assert folder_paths.get_directory_by_type("invalid") is None
29
+
30
+ def test_annotated_filepath():
31
+ assert folder_paths.annotated_filepath("test.txt") == ("test.txt", None)
32
+ assert folder_paths.annotated_filepath("test.txt [output]") == ("test.txt", folder_paths.get_output_directory())
33
+ assert folder_paths.annotated_filepath("test.txt [input]") == ("test.txt", folder_paths.get_input_directory())
34
+ assert folder_paths.annotated_filepath("test.txt [temp]") == ("test.txt", folder_paths.get_temp_directory())
35
+
36
+ def test_get_annotated_filepath():
37
+ default_dir = "/default/dir"
38
+ assert folder_paths.get_annotated_filepath("test.txt", default_dir) == os.path.join(default_dir, "test.txt")
39
+ assert folder_paths.get_annotated_filepath("test.txt [output]") == os.path.join(folder_paths.get_output_directory(), "test.txt")
40
+
41
+ def test_add_model_folder_path_append(clear_folder_paths):
42
+ folder_paths.add_model_folder_path("test_folder", "/default/path", is_default=True)
43
+ folder_paths.add_model_folder_path("test_folder", "/test/path", is_default=False)
44
+ assert folder_paths.get_folder_paths("test_folder") == ["/default/path", "/test/path"]
45
+
46
+
47
+ def test_add_model_folder_path_insert(clear_folder_paths):
48
+ folder_paths.add_model_folder_path("test_folder", "/test/path", is_default=False)
49
+ folder_paths.add_model_folder_path("test_folder", "/default/path", is_default=True)
50
+ assert folder_paths.get_folder_paths("test_folder") == ["/default/path", "/test/path"]
51
+
52
+
53
+ def test_add_model_folder_path_re_add_existing_default(clear_folder_paths):
54
+ folder_paths.add_model_folder_path("test_folder", "/test/path", is_default=False)
55
+ folder_paths.add_model_folder_path("test_folder", "/old_default/path", is_default=True)
56
+ assert folder_paths.get_folder_paths("test_folder") == ["/old_default/path", "/test/path"]
57
+ folder_paths.add_model_folder_path("test_folder", "/test/path", is_default=True)
58
+ assert folder_paths.get_folder_paths("test_folder") == ["/test/path", "/old_default/path"]
59
+
60
+
61
+ def test_add_model_folder_path_re_add_existing_non_default(clear_folder_paths):
62
+ folder_paths.add_model_folder_path("test_folder", "/test/path", is_default=False)
63
+ folder_paths.add_model_folder_path("test_folder", "/default/path", is_default=True)
64
+ assert folder_paths.get_folder_paths("test_folder") == ["/default/path", "/test/path"]
65
+ folder_paths.add_model_folder_path("test_folder", "/test/path", is_default=False)
66
+ assert folder_paths.get_folder_paths("test_folder") == ["/default/path", "/test/path"]
67
+
68
+
69
+ def test_recursive_search(temp_dir):
70
+ os.makedirs(os.path.join(temp_dir, "subdir"))
71
+ open(os.path.join(temp_dir, "file1.txt"), "w").close()
72
+ open(os.path.join(temp_dir, "subdir", "file2.txt"), "w").close()
73
+
74
+ files, dirs = folder_paths.recursive_search(temp_dir)
75
+ assert set(files) == {"file1.txt", os.path.join("subdir", "file2.txt")}
76
+ assert len(dirs) == 2 # temp_dir and subdir
77
+
78
+ def test_filter_files_extensions():
79
+ files = ["file1.txt", "file2.jpg", "file3.png", "file4.txt"]
80
+ assert folder_paths.filter_files_extensions(files, [".txt"]) == ["file1.txt", "file4.txt"]
81
+ assert folder_paths.filter_files_extensions(files, [".jpg", ".png"]) == ["file2.jpg", "file3.png"]
82
+ assert folder_paths.filter_files_extensions(files, []) == files
83
+
84
+ @patch("folder_paths.recursive_search")
85
+ @patch("folder_paths.folder_names_and_paths")
86
+ def test_get_filename_list(mock_folder_names_and_paths, mock_recursive_search):
87
+ mock_folder_names_and_paths.__getitem__.return_value = (["/test/path"], {".txt"})
88
+ mock_recursive_search.return_value = (["file1.txt", "file2.jpg"], {})
89
+ assert folder_paths.get_filename_list("test_folder") == ["file1.txt"]
90
+
91
+ def test_get_save_image_path(temp_dir):
92
+ with patch("folder_paths.output_directory", temp_dir):
93
+ full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path("test", temp_dir, 100, 100)
94
+ assert os.path.samefile(full_output_folder, temp_dir)
95
+ assert filename == "test"
96
+ assert counter == 1
97
+ assert subfolder == ""
98
+ assert filename_prefix == "test"
tests-unit/execution_test/validate_node_input_test.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ from comfy_execution.validation import validate_node_input
3
+
4
+
5
+ def test_exact_match():
6
+ """Test cases where types match exactly"""
7
+ assert validate_node_input("STRING", "STRING")
8
+ assert validate_node_input("STRING,INT", "STRING,INT")
9
+ assert validate_node_input("INT,STRING", "STRING,INT") # Order shouldn't matter
10
+
11
+
12
+ def test_strict_mode():
13
+ """Test strict mode validation"""
14
+ # Should pass - received type is subset of input type
15
+ assert validate_node_input("STRING", "STRING,INT", strict=True)
16
+ assert validate_node_input("INT", "STRING,INT", strict=True)
17
+ assert validate_node_input("STRING,INT", "STRING,INT,BOOLEAN", strict=True)
18
+
19
+ # Should fail - received type is not subset of input type
20
+ assert not validate_node_input("STRING,INT", "STRING", strict=True)
21
+ assert not validate_node_input("STRING,BOOLEAN", "STRING", strict=True)
22
+ assert not validate_node_input("INT,BOOLEAN", "STRING,INT", strict=True)
23
+
24
+
25
+ def test_non_strict_mode():
26
+ """Test non-strict mode validation (default behavior)"""
27
+ # Should pass - types have overlap
28
+ assert validate_node_input("STRING,BOOLEAN", "STRING,INT")
29
+ assert validate_node_input("STRING,INT", "INT,BOOLEAN")
30
+ assert validate_node_input("STRING", "STRING,INT")
31
+
32
+ # Should fail - no overlap in types
33
+ assert not validate_node_input("BOOLEAN", "STRING,INT")
34
+ assert not validate_node_input("FLOAT", "STRING,INT")
35
+ assert not validate_node_input("FLOAT,BOOLEAN", "STRING,INT")
36
+
37
+
38
+ def test_whitespace_handling():
39
+ """Test that whitespace is handled correctly"""
40
+ assert validate_node_input("STRING, INT", "STRING,INT")
41
+ assert validate_node_input("STRING,INT", "STRING, INT")
42
+ assert validate_node_input(" STRING , INT ", "STRING,INT")
43
+ assert validate_node_input("STRING,INT", " STRING , INT ")
44
+
45
+
46
+ def test_empty_strings():
47
+ """Test behavior with empty strings"""
48
+ assert validate_node_input("", "")
49
+ assert not validate_node_input("STRING", "")
50
+ assert not validate_node_input("", "STRING")
51
+
52
+
53
+ def test_single_vs_multiple():
54
+ """Test single type against multiple types"""
55
+ assert validate_node_input("STRING", "STRING,INT,BOOLEAN")
56
+ assert validate_node_input("STRING,INT,BOOLEAN", "STRING", strict=False)
57
+ assert not validate_node_input("STRING,INT,BOOLEAN", "STRING", strict=True)
58
+
59
+
60
+ def test_non_string():
61
+ """Test non-string types"""
62
+ obj1 = object()
63
+ obj2 = object()
64
+ assert validate_node_input(obj1, obj1)
65
+ assert not validate_node_input(obj1, obj2)
66
+
67
+
68
+ class NotEqualsOverrideTest(str):
69
+ """Test class for ``__ne__`` override."""
70
+
71
+ def __ne__(self, value: object) -> bool:
72
+ if self == "*" or value == "*":
73
+ return False
74
+ if self == "LONGER_THAN_2":
75
+ return not len(value) > 2
76
+ raise TypeError("This is a class for unit tests only.")
77
+
78
+
79
+ def test_ne_override():
80
+ """Test ``__ne__`` any override"""
81
+ any = NotEqualsOverrideTest("*")
82
+ invalid_type = "INVALID_TYPE"
83
+ obj = object()
84
+ assert validate_node_input(any, any)
85
+ assert validate_node_input(any, invalid_type)
86
+ assert validate_node_input(any, obj)
87
+ assert validate_node_input(any, {})
88
+ assert validate_node_input(any, [])
89
+ assert validate_node_input(any, [1, 2, 3])
90
+
91
+
92
+ def test_ne_custom_override():
93
+ """Test ``__ne__`` custom override"""
94
+ special = NotEqualsOverrideTest("LONGER_THAN_2")
95
+
96
+ assert validate_node_input(special, special)
97
+ assert validate_node_input(special, "*")
98
+ assert validate_node_input(special, "INVALID_TYPE")
99
+ assert validate_node_input(special, [1, 2, 3])
100
+
101
+ # Should fail
102
+ assert not validate_node_input(special, [1, 2])
103
+ assert not validate_node_input(special, "TY")
104
+
105
+
106
+ @pytest.mark.parametrize(
107
+ "received,input_type,strict,expected",
108
+ [
109
+ ("STRING", "STRING", False, True),
110
+ ("STRING,INT", "STRING,INT", False, True),
111
+ ("STRING", "STRING,INT", True, True),
112
+ ("STRING,INT", "STRING", True, False),
113
+ ("BOOLEAN", "STRING,INT", False, False),
114
+ ("STRING,BOOLEAN", "STRING,INT", False, True),
115
+ ],
116
+ )
117
+ def test_parametrized_cases(received, input_type, strict, expected):
118
+ """Parametrized test cases for various scenarios"""
119
+ assert validate_node_input(received, input_type, strict) == expected
tests-unit/folder_paths_test/__init__.py ADDED
File without changes
tests-unit/folder_paths_test/filter_by_content_types_test.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ import os
3
+ import tempfile
4
+ from folder_paths import filter_files_content_types
5
+
6
+ @pytest.fixture(scope="module")
7
+ def file_extensions():
8
+ return {
9
+ 'image': ['gif', 'heif', 'ico', 'jpeg', 'jpg', 'png', 'pnm', 'ppm', 'svg', 'tiff', 'webp', 'xbm', 'xpm'],
10
+ 'audio': ['aif', 'aifc', 'aiff', 'au', 'flac', 'm4a', 'mp2', 'mp3', 'ogg', 'snd', 'wav'],
11
+ 'video': ['avi', 'm2v', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ogv', 'qt', 'webm', 'wmv']
12
+ }
13
+
14
+
15
+ @pytest.fixture(scope="module")
16
+ def mock_dir(file_extensions):
17
+ with tempfile.TemporaryDirectory() as directory:
18
+ for content_type, extensions in file_extensions.items():
19
+ for extension in extensions:
20
+ with open(f"{directory}/sample_{content_type}.{extension}", "w") as f:
21
+ f.write(f"Sample {content_type} file in {extension} format")
22
+ yield directory
23
+
24
+
25
+ def test_categorizes_all_correctly(mock_dir, file_extensions):
26
+ files = os.listdir(mock_dir)
27
+ for content_type, extensions in file_extensions.items():
28
+ filtered_files = filter_files_content_types(files, [content_type])
29
+ for extension in extensions:
30
+ assert f"sample_{content_type}.{extension}" in filtered_files
31
+
32
+
33
+ def test_categorizes_all_uniquely(mock_dir, file_extensions):
34
+ files = os.listdir(mock_dir)
35
+ for content_type, extensions in file_extensions.items():
36
+ filtered_files = filter_files_content_types(files, [content_type])
37
+ assert len(filtered_files) == len(extensions)
38
+
39
+
40
+ def test_handles_bad_extensions():
41
+ files = ["file1.txt", "file2.py", "file3.example", "file4.pdf", "file5.ini", "file6.doc", "file7.md"]
42
+ assert filter_files_content_types(files, ["image", "audio", "video"]) == []
43
+
44
+
45
+ def test_handles_no_extension():
46
+ files = ["file1", "file2", "file3", "file4", "file5", "file6", "file7"]
47
+ assert filter_files_content_types(files, ["image", "audio", "video"]) == []
48
+
49
+
50
+ def test_handles_no_files():
51
+ files = []
52
+ assert filter_files_content_types(files, ["image", "audio", "video"]) == []
tests-unit/prompt_server_test/__init__.py ADDED
File without changes
tests-unit/prompt_server_test/user_manager_test.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ import os
3
+ from aiohttp import web
4
+ from app.user_manager import UserManager
5
+ from unittest.mock import patch
6
+
7
+ pytestmark = (
8
+ pytest.mark.asyncio
9
+ ) # This applies the asyncio mark to all test functions in the module
10
+
11
+
12
+ @pytest.fixture
13
+ def user_manager(tmp_path):
14
+ um = UserManager()
15
+ um.get_request_user_filepath = lambda req, file, **kwargs: os.path.join(
16
+ tmp_path, file
17
+ ) if file else tmp_path
18
+ return um
19
+
20
+
21
+ @pytest.fixture
22
+ def app(user_manager):
23
+ app = web.Application()
24
+ routes = web.RouteTableDef()
25
+ user_manager.add_routes(routes)
26
+ app.add_routes(routes)
27
+ return app
28
+
29
+
30
+ async def test_listuserdata_empty_directory(aiohttp_client, app, tmp_path):
31
+ client = await aiohttp_client(app)
32
+ resp = await client.get("/userdata?dir=test_dir")
33
+ assert resp.status == 404
34
+
35
+
36
+ async def test_listuserdata_with_files(aiohttp_client, app, tmp_path):
37
+ os.makedirs(tmp_path / "test_dir")
38
+ with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
39
+ f.write("test content")
40
+
41
+ client = await aiohttp_client(app)
42
+ resp = await client.get("/userdata?dir=test_dir")
43
+ assert resp.status == 200
44
+ assert await resp.json() == ["file1.txt"]
45
+
46
+
47
+ async def test_listuserdata_recursive(aiohttp_client, app, tmp_path):
48
+ os.makedirs(tmp_path / "test_dir" / "subdir")
49
+ with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
50
+ f.write("test content")
51
+ with open(tmp_path / "test_dir" / "subdir" / "file2.txt", "w") as f:
52
+ f.write("test content")
53
+
54
+ client = await aiohttp_client(app)
55
+ resp = await client.get("/userdata?dir=test_dir&recurse=true")
56
+ assert resp.status == 200
57
+ assert set(await resp.json()) == {"file1.txt", "subdir/file2.txt"}
58
+
59
+
60
+ async def test_listuserdata_full_info(aiohttp_client, app, tmp_path):
61
+ os.makedirs(tmp_path / "test_dir")
62
+ with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
63
+ f.write("test content")
64
+
65
+ client = await aiohttp_client(app)
66
+ resp = await client.get("/userdata?dir=test_dir&full_info=true")
67
+ assert resp.status == 200
68
+ result = await resp.json()
69
+ assert len(result) == 1
70
+ assert result[0]["path"] == "file1.txt"
71
+ assert "size" in result[0]
72
+ assert "modified" in result[0]
73
+
74
+
75
+ async def test_listuserdata_split_path(aiohttp_client, app, tmp_path):
76
+ os.makedirs(tmp_path / "test_dir" / "subdir")
77
+ with open(tmp_path / "test_dir" / "subdir" / "file1.txt", "w") as f:
78
+ f.write("test content")
79
+
80
+ client = await aiohttp_client(app)
81
+ resp = await client.get("/userdata?dir=test_dir&recurse=true&split=true")
82
+ assert resp.status == 200
83
+ assert await resp.json() == [["subdir/file1.txt", "subdir", "file1.txt"]]
84
+
85
+
86
+ async def test_listuserdata_invalid_directory(aiohttp_client, app):
87
+ client = await aiohttp_client(app)
88
+ resp = await client.get("/userdata?dir=")
89
+ assert resp.status == 400
90
+
91
+
92
+ async def test_listuserdata_normalized_separator(aiohttp_client, app, tmp_path):
93
+ os_sep = "\\"
94
+ with patch("os.sep", os_sep):
95
+ with patch("os.path.sep", os_sep):
96
+ os.makedirs(tmp_path / "test_dir" / "subdir")
97
+ with open(tmp_path / "test_dir" / "subdir" / "file1.txt", "w") as f:
98
+ f.write("test content")
99
+
100
+ client = await aiohttp_client(app)
101
+ resp = await client.get("/userdata?dir=test_dir&recurse=true")
102
+ assert resp.status == 200
103
+ result = await resp.json()
104
+ assert len(result) == 1
105
+ assert "/" in result[0] # Ensure forward slash is used
106
+ assert "\\" not in result[0] # Ensure backslash is not present
107
+ assert result[0] == "subdir/file1.txt"
108
+
109
+ # Test with full_info
110
+ resp = await client.get(
111
+ "/userdata?dir=test_dir&recurse=true&full_info=true"
112
+ )
113
+ assert resp.status == 200
114
+ result = await resp.json()
115
+ assert len(result) == 1
116
+ assert "/" in result[0]["path"] # Ensure forward slash is used
117
+ assert "\\" not in result[0]["path"] # Ensure backslash is not present
118
+ assert result[0]["path"] == "subdir/file1.txt"
119
+
120
+
121
+ async def test_post_userdata_new_file(aiohttp_client, app, tmp_path):
122
+ client = await aiohttp_client(app)
123
+ content = b"test content"
124
+ resp = await client.post("/userdata/test.txt", data=content)
125
+
126
+ assert resp.status == 200
127
+ assert await resp.text() == '"test.txt"'
128
+
129
+ # Verify file was created with correct content
130
+ with open(tmp_path / "test.txt", "rb") as f:
131
+ assert f.read() == content
132
+
133
+
134
+ async def test_post_userdata_overwrite_existing(aiohttp_client, app, tmp_path):
135
+ # Create initial file
136
+ with open(tmp_path / "test.txt", "w") as f:
137
+ f.write("initial content")
138
+
139
+ client = await aiohttp_client(app)
140
+ new_content = b"updated content"
141
+ resp = await client.post("/userdata/test.txt", data=new_content)
142
+
143
+ assert resp.status == 200
144
+ assert await resp.text() == '"test.txt"'
145
+
146
+ # Verify file was overwritten
147
+ with open(tmp_path / "test.txt", "rb") as f:
148
+ assert f.read() == new_content
149
+
150
+
151
+ async def test_post_userdata_no_overwrite(aiohttp_client, app, tmp_path):
152
+ # Create initial file
153
+ with open(tmp_path / "test.txt", "w") as f:
154
+ f.write("initial content")
155
+
156
+ client = await aiohttp_client(app)
157
+ resp = await client.post("/userdata/test.txt?overwrite=false", data=b"new content")
158
+
159
+ assert resp.status == 409
160
+
161
+ # Verify original content unchanged
162
+ with open(tmp_path / "test.txt", "r") as f:
163
+ assert f.read() == "initial content"
164
+
165
+
166
+ async def test_post_userdata_full_info(aiohttp_client, app, tmp_path):
167
+ client = await aiohttp_client(app)
168
+ content = b"test content"
169
+ resp = await client.post("/userdata/test.txt?full_info=true", data=content)
170
+
171
+ assert resp.status == 200
172
+ result = await resp.json()
173
+ assert result["path"] == "test.txt"
174
+ assert result["size"] == len(content)
175
+ assert "modified" in result
176
+
177
+
178
+ async def test_move_userdata(aiohttp_client, app, tmp_path):
179
+ # Create initial file
180
+ with open(tmp_path / "source.txt", "w") as f:
181
+ f.write("test content")
182
+
183
+ client = await aiohttp_client(app)
184
+ resp = await client.post("/userdata/source.txt/move/dest.txt")
185
+
186
+ assert resp.status == 200
187
+ assert await resp.text() == '"dest.txt"'
188
+
189
+ # Verify file was moved
190
+ assert not os.path.exists(tmp_path / "source.txt")
191
+ with open(tmp_path / "dest.txt", "r") as f:
192
+ assert f.read() == "test content"
193
+
194
+
195
+ async def test_move_userdata_no_overwrite(aiohttp_client, app, tmp_path):
196
+ # Create source and destination files
197
+ with open(tmp_path / "source.txt", "w") as f:
198
+ f.write("source content")
199
+ with open(tmp_path / "dest.txt", "w") as f:
200
+ f.write("destination content")
201
+
202
+ client = await aiohttp_client(app)
203
+ resp = await client.post("/userdata/source.txt/move/dest.txt?overwrite=false")
204
+
205
+ assert resp.status == 409
206
+
207
+ # Verify files remain unchanged
208
+ with open(tmp_path / "source.txt", "r") as f:
209
+ assert f.read() == "source content"
210
+ with open(tmp_path / "dest.txt", "r") as f:
211
+ assert f.read() == "destination content"
212
+
213
+
214
+ async def test_move_userdata_full_info(aiohttp_client, app, tmp_path):
215
+ # Create initial file
216
+ with open(tmp_path / "source.txt", "w") as f:
217
+ f.write("test content")
218
+
219
+ client = await aiohttp_client(app)
220
+ resp = await client.post("/userdata/source.txt/move/dest.txt?full_info=true")
221
+
222
+ assert resp.status == 200
223
+ result = await resp.json()
224
+ assert result["path"] == "dest.txt"
225
+ assert result["size"] == len("test content")
226
+ assert "modified" in result
227
+
228
+ # Verify file was moved
229
+ assert not os.path.exists(tmp_path / "source.txt")
230
+ with open(tmp_path / "dest.txt", "r") as f:
231
+ assert f.read() == "test content"
tests-unit/requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ pytest>=7.8.0
2
+ pytest-aiohttp
3
+ pytest-asyncio
tests-unit/server/routes/internal_routes_test.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ from aiohttp import web
3
+ from unittest.mock import MagicMock, patch
4
+ from api_server.routes.internal.internal_routes import InternalRoutes
5
+ from api_server.services.file_service import FileService
6
+ from folder_paths import models_dir, user_directory, output_directory
7
+
8
+
9
+ @pytest.fixture
10
+ def internal_routes():
11
+ return InternalRoutes(None)
12
+
13
+ @pytest.fixture
14
+ def aiohttp_client_factory(aiohttp_client, internal_routes):
15
+ async def _get_client():
16
+ app = internal_routes.get_app()
17
+ return await aiohttp_client(app)
18
+ return _get_client
19
+
20
+ @pytest.mark.asyncio
21
+ async def test_list_files_valid_directory(aiohttp_client_factory, internal_routes):
22
+ mock_file_list = [
23
+ {"name": "file1.txt", "path": "file1.txt", "type": "file", "size": 100},
24
+ {"name": "dir1", "path": "dir1", "type": "directory"}
25
+ ]
26
+ internal_routes.file_service.list_files = MagicMock(return_value=mock_file_list)
27
+ client = await aiohttp_client_factory()
28
+ resp = await client.get('/files?directory=models')
29
+ assert resp.status == 200
30
+ data = await resp.json()
31
+ assert 'files' in data
32
+ assert len(data['files']) == 2
33
+ assert data['files'] == mock_file_list
34
+
35
+ # Check other valid directories
36
+ resp = await client.get('/files?directory=user')
37
+ assert resp.status == 200
38
+ resp = await client.get('/files?directory=output')
39
+ assert resp.status == 200
40
+
41
+ @pytest.mark.asyncio
42
+ async def test_list_files_invalid_directory(aiohttp_client_factory, internal_routes):
43
+ internal_routes.file_service.list_files = MagicMock(side_effect=ValueError("Invalid directory key"))
44
+ client = await aiohttp_client_factory()
45
+ resp = await client.get('/files?directory=invalid')
46
+ assert resp.status == 400
47
+ data = await resp.json()
48
+ assert 'error' in data
49
+ assert data['error'] == "Invalid directory key"
50
+
51
+ @pytest.mark.asyncio
52
+ async def test_list_files_exception(aiohttp_client_factory, internal_routes):
53
+ internal_routes.file_service.list_files = MagicMock(side_effect=Exception("Unexpected error"))
54
+ client = await aiohttp_client_factory()
55
+ resp = await client.get('/files?directory=models')
56
+ assert resp.status == 500
57
+ data = await resp.json()
58
+ assert 'error' in data
59
+ assert data['error'] == "Unexpected error"
60
+
61
+ @pytest.mark.asyncio
62
+ async def test_list_files_no_directory_param(aiohttp_client_factory, internal_routes):
63
+ mock_file_list = []
64
+ internal_routes.file_service.list_files = MagicMock(return_value=mock_file_list)
65
+ client = await aiohttp_client_factory()
66
+ resp = await client.get('/files')
67
+ assert resp.status == 200
68
+ data = await resp.json()
69
+ assert 'files' in data
70
+ assert len(data['files']) == 0
71
+
72
+ def test_setup_routes(internal_routes):
73
+ internal_routes.setup_routes()
74
+ routes = internal_routes.routes
75
+ assert any(route.method == 'GET' and str(route.path) == '/files' for route in routes)
76
+
77
+ def test_get_app(internal_routes):
78
+ app = internal_routes.get_app()
79
+ assert isinstance(app, web.Application)
80
+ assert internal_routes._app is not None
81
+
82
+ def test_get_app_reuse(internal_routes):
83
+ app1 = internal_routes.get_app()
84
+ app2 = internal_routes.get_app()
85
+ assert app1 is app2
86
+
87
+ @pytest.mark.asyncio
88
+ async def test_routes_added_to_app(aiohttp_client_factory, internal_routes):
89
+ client = await aiohttp_client_factory()
90
+ try:
91
+ resp = await client.get('/files')
92
+ print(f"Response received: status {resp.status}") # noqa: T201
93
+ except Exception as e:
94
+ print(f"Exception occurred during GET request: {e}") # noqa: T201
95
+ raise
96
+
97
+ assert resp.status != 404, "Route /files does not exist"
98
+
99
+ @pytest.mark.asyncio
100
+ async def test_file_service_initialization():
101
+ with patch('api_server.routes.internal.internal_routes.FileService') as MockFileService:
102
+ # Create a mock instance
103
+ mock_file_service_instance = MagicMock(spec=FileService)
104
+ MockFileService.return_value = mock_file_service_instance
105
+ internal_routes = InternalRoutes(None)
106
+
107
+ # Check if FileService was initialized with the correct parameters
108
+ MockFileService.assert_called_once_with({
109
+ "models": models_dir,
110
+ "user": user_directory,
111
+ "output": output_directory
112
+ })
113
+
114
+ # Verify that the file_service attribute of InternalRoutes is set
115
+ assert internal_routes.file_service == mock_file_service_instance
tests-unit/server/services/file_service_test.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ from unittest.mock import MagicMock
3
+ from api_server.services.file_service import FileService
4
+
5
+ @pytest.fixture
6
+ def mock_file_system_ops():
7
+ return MagicMock()
8
+
9
+ @pytest.fixture
10
+ def file_service(mock_file_system_ops):
11
+ allowed_directories = {
12
+ "models": "/path/to/models",
13
+ "user": "/path/to/user",
14
+ "output": "/path/to/output"
15
+ }
16
+ return FileService(allowed_directories, file_system_ops=mock_file_system_ops)
17
+
18
+ def test_list_files_valid_directory(file_service, mock_file_system_ops):
19
+ mock_file_system_ops.walk_directory.return_value = [
20
+ {"name": "file1.txt", "path": "file1.txt", "type": "file", "size": 100},
21
+ {"name": "dir1", "path": "dir1", "type": "directory"}
22
+ ]
23
+
24
+ result = file_service.list_files("models")
25
+
26
+ assert len(result) == 2
27
+ assert result[0]["name"] == "file1.txt"
28
+ assert result[1]["name"] == "dir1"
29
+ mock_file_system_ops.walk_directory.assert_called_once_with("/path/to/models")
30
+
31
+ def test_list_files_invalid_directory(file_service):
32
+ # Does not support walking directories outside of the allowed directories
33
+ with pytest.raises(ValueError, match="Invalid directory key"):
34
+ file_service.list_files("invalid_key")
35
+
36
+ def test_list_files_empty_directory(file_service, mock_file_system_ops):
37
+ mock_file_system_ops.walk_directory.return_value = []
38
+
39
+ result = file_service.list_files("models")
40
+
41
+ assert len(result) == 0
42
+ mock_file_system_ops.walk_directory.assert_called_once_with("/path/to/models")
43
+
44
+ @pytest.mark.parametrize("directory_key", ["models", "user", "output"])
45
+ def test_list_files_all_allowed_directories(file_service, mock_file_system_ops, directory_key):
46
+ mock_file_system_ops.walk_directory.return_value = [
47
+ {"name": f"file_{directory_key}.txt", "path": f"file_{directory_key}.txt", "type": "file", "size": 100}
48
+ ]
49
+
50
+ result = file_service.list_files(directory_key)
51
+
52
+ assert len(result) == 1
53
+ assert result[0]["name"] == f"file_{directory_key}.txt"
54
+ mock_file_system_ops.walk_directory.assert_called_once_with(f"/path/to/{directory_key}")
tests-unit/server/utils/file_operations_test.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ from typing import List
3
+ from api_server.utils.file_operations import FileSystemOperations, FileSystemItem, is_file_info
4
+
5
+ @pytest.fixture
6
+ def temp_directory(tmp_path):
7
+ # Create a temporary directory structure
8
+ dir1 = tmp_path / "dir1"
9
+ dir2 = tmp_path / "dir2"
10
+ dir1.mkdir()
11
+ dir2.mkdir()
12
+ (dir1 / "file1.txt").write_text("content1")
13
+ (dir2 / "file2.txt").write_text("content2")
14
+ (tmp_path / "file3.txt").write_text("content3")
15
+ return tmp_path
16
+
17
+ def test_walk_directory(temp_directory):
18
+ result: List[FileSystemItem] = FileSystemOperations.walk_directory(str(temp_directory))
19
+
20
+ assert len(result) == 5 # 2 directories and 3 files
21
+
22
+ files = [item for item in result if item['type'] == 'file']
23
+ dirs = [item for item in result if item['type'] == 'directory']
24
+
25
+ assert len(files) == 3
26
+ assert len(dirs) == 2
27
+
28
+ file_names = {file['name'] for file in files}
29
+ assert file_names == {'file1.txt', 'file2.txt', 'file3.txt'}
30
+
31
+ dir_names = {dir['name'] for dir in dirs}
32
+ assert dir_names == {'dir1', 'dir2'}
33
+
34
+ def test_walk_directory_empty(tmp_path):
35
+ result = FileSystemOperations.walk_directory(str(tmp_path))
36
+ assert len(result) == 0
37
+
38
+ def test_walk_directory_file_size(temp_directory):
39
+ result: List[FileSystemItem] = FileSystemOperations.walk_directory(str(temp_directory))
40
+ files = [item for item in result if is_file_info(item)]
41
+ for file in files:
42
+ assert file['size'] > 0 # Assuming all files have some content
tests-unit/utils/extra_config_test.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ import yaml
3
+ import os
4
+ import sys
5
+ from unittest.mock import Mock, patch, mock_open
6
+
7
+ from utils.extra_config import load_extra_path_config
8
+ import folder_paths
9
+
10
+
11
+ @pytest.fixture()
12
+ def clear_folder_paths():
13
+ # Clear the global dictionary before each test to ensure isolation
14
+ original = folder_paths.folder_names_and_paths.copy()
15
+ folder_paths.folder_names_and_paths.clear()
16
+ yield
17
+ folder_paths.folder_names_and_paths = original
18
+
19
+
20
+ @pytest.fixture
21
+ def mock_yaml_content():
22
+ return {
23
+ 'test_config': {
24
+ 'base_path': '~/App/',
25
+ 'checkpoints': 'subfolder1',
26
+ }
27
+ }
28
+
29
+
30
+ @pytest.fixture
31
+ def mock_expanded_home():
32
+ return '/home/user'
33
+
34
+
35
+ @pytest.fixture
36
+ def yaml_config_with_appdata():
37
+ return """
38
+ test_config:
39
+ base_path: '%APPDATA%/ComfyUI'
40
+ checkpoints: 'models/checkpoints'
41
+ """
42
+
43
+
44
+ @pytest.fixture
45
+ def mock_yaml_content_appdata(yaml_config_with_appdata):
46
+ return yaml.safe_load(yaml_config_with_appdata)
47
+
48
+
49
+ @pytest.fixture
50
+ def mock_expandvars_appdata():
51
+ mock = Mock()
52
+
53
+ def expandvars(path):
54
+ if '%APPDATA%' in path:
55
+ if sys.platform == 'win32':
56
+ return path.replace('%APPDATA%', 'C:/Users/TestUser/AppData/Roaming')
57
+ else:
58
+ return path.replace('%APPDATA%', '/Users/TestUser/AppData/Roaming')
59
+ return path
60
+
61
+ mock.side_effect = expandvars
62
+ return mock
63
+
64
+
65
+ @pytest.fixture
66
+ def mock_add_model_folder_path():
67
+ return Mock()
68
+
69
+
70
+ @pytest.fixture
71
+ def mock_expanduser(mock_expanded_home):
72
+ def _expanduser(path):
73
+ if path.startswith('~/'):
74
+ return os.path.join(mock_expanded_home, path[2:])
75
+ return path
76
+ return _expanduser
77
+
78
+
79
+ @pytest.fixture
80
+ def mock_yaml_safe_load(mock_yaml_content):
81
+ return Mock(return_value=mock_yaml_content)
82
+
83
+
84
+ @patch('builtins.open', new_callable=mock_open, read_data="dummy file content")
85
+ def test_load_extra_model_paths_expands_userpath(
86
+ mock_file,
87
+ monkeypatch,
88
+ mock_add_model_folder_path,
89
+ mock_expanduser,
90
+ mock_yaml_safe_load,
91
+ mock_expanded_home
92
+ ):
93
+ # Attach mocks used by load_extra_path_config
94
+ monkeypatch.setattr(folder_paths, 'add_model_folder_path', mock_add_model_folder_path)
95
+ monkeypatch.setattr(os.path, 'expanduser', mock_expanduser)
96
+ monkeypatch.setattr(yaml, 'safe_load', mock_yaml_safe_load)
97
+
98
+ dummy_yaml_file_name = 'dummy_path.yaml'
99
+ load_extra_path_config(dummy_yaml_file_name)
100
+
101
+ expected_calls = [
102
+ ('checkpoints', os.path.join(mock_expanded_home, 'App', 'subfolder1'), False),
103
+ ]
104
+
105
+ assert mock_add_model_folder_path.call_count == len(expected_calls)
106
+
107
+ # Check if add_model_folder_path was called with the correct arguments
108
+ for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls):
109
+ assert actual_call.args[0] == expected_call[0]
110
+ assert os.path.normpath(actual_call.args[1]) == os.path.normpath(expected_call[1]) # Normalize and check the path to check on multiple OS.
111
+ assert actual_call.args[2] == expected_call[2]
112
+
113
+ # Check if yaml.safe_load was called
114
+ mock_yaml_safe_load.assert_called_once()
115
+
116
+ # Check if open was called with the correct file path
117
+ mock_file.assert_called_once_with(dummy_yaml_file_name, 'r')
118
+
119
+
120
+ @patch('builtins.open', new_callable=mock_open)
121
+ def test_load_extra_model_paths_expands_appdata(
122
+ mock_file,
123
+ monkeypatch,
124
+ mock_add_model_folder_path,
125
+ mock_expandvars_appdata,
126
+ yaml_config_with_appdata,
127
+ mock_yaml_content_appdata
128
+ ):
129
+ # Set the mock_file to return yaml with appdata as a variable
130
+ mock_file.return_value.read.return_value = yaml_config_with_appdata
131
+
132
+ # Attach mocks
133
+ monkeypatch.setattr(folder_paths, 'add_model_folder_path', mock_add_model_folder_path)
134
+ monkeypatch.setattr(os.path, 'expandvars', mock_expandvars_appdata)
135
+ monkeypatch.setattr(yaml, 'safe_load', Mock(return_value=mock_yaml_content_appdata))
136
+
137
+ # Mock expanduser to do nothing (since we're not testing it here)
138
+ monkeypatch.setattr(os.path, 'expanduser', lambda x: x)
139
+
140
+ dummy_yaml_file_name = 'dummy_path.yaml'
141
+ load_extra_path_config(dummy_yaml_file_name)
142
+
143
+ if sys.platform == "win32":
144
+ expected_base_path = 'C:/Users/TestUser/AppData/Roaming/ComfyUI'
145
+ else:
146
+ expected_base_path = '/Users/TestUser/AppData/Roaming/ComfyUI'
147
+ expected_calls = [
148
+ ('checkpoints', os.path.join(expected_base_path, 'models/checkpoints'), False),
149
+ ]
150
+
151
+ assert mock_add_model_folder_path.call_count == len(expected_calls)
152
+
153
+ # Check the base path variable was expanded
154
+ for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls):
155
+ assert actual_call.args == expected_call
156
+
157
+ # Verify that expandvars was called
158
+ assert mock_expandvars_appdata.called
159
+
160
+
161
+ @patch("builtins.open", new_callable=mock_open, read_data="dummy yaml content")
162
+ @patch("yaml.safe_load")
163
+ def test_load_extra_path_config_relative_base_path(
164
+ mock_yaml_load, _mock_file, clear_folder_paths, monkeypatch, tmp_path
165
+ ):
166
+ """
167
+ Test that when 'base_path' is a relative path in the YAML, it is joined to the YAML file directory, and then
168
+ the items in the config are correctly converted to absolute paths.
169
+ """
170
+ sub_folder = "./my_rel_base"
171
+ config_data = {
172
+ "some_model_folder": {
173
+ "base_path": sub_folder,
174
+ "is_default": True,
175
+ "checkpoints": "checkpoints",
176
+ "some_key": "some_value"
177
+ }
178
+ }
179
+ mock_yaml_load.return_value = config_data
180
+
181
+ dummy_yaml_name = "dummy_file.yaml"
182
+
183
+ def fake_abspath(path):
184
+ if path == dummy_yaml_name:
185
+ # If it's the YAML path, treat it like it lives in tmp_path
186
+ return os.path.join(str(tmp_path), dummy_yaml_name)
187
+ return os.path.join(str(tmp_path), path) # Otherwise, do a normal join relative to tmp_path
188
+
189
+ def fake_dirname(path):
190
+ # We expect path to be the result of fake_abspath(dummy_yaml_name)
191
+ if path.endswith(dummy_yaml_name):
192
+ return str(tmp_path)
193
+ return os.path.dirname(path)
194
+
195
+ monkeypatch.setattr(os.path, "abspath", fake_abspath)
196
+ monkeypatch.setattr(os.path, "dirname", fake_dirname)
197
+
198
+ load_extra_path_config(dummy_yaml_name)
199
+
200
+ expected_checkpoints = os.path.abspath(os.path.join(str(tmp_path), sub_folder, "checkpoints"))
201
+ expected_some_value = os.path.abspath(os.path.join(str(tmp_path), sub_folder, "some_value"))
202
+
203
+ actual_paths = folder_paths.folder_names_and_paths["checkpoints"][0]
204
+ assert len(actual_paths) == 1, "Should have one path added for 'checkpoints'."
205
+ assert actual_paths[0] == expected_checkpoints
206
+
207
+ actual_paths = folder_paths.folder_names_and_paths["some_key"][0]
208
+ assert len(actual_paths) == 1, "Should have one path added for 'some_key'."
209
+ assert actual_paths[0] == expected_some_value
210
+
211
+
212
+ @patch("builtins.open", new_callable=mock_open, read_data="dummy yaml content")
213
+ @patch("yaml.safe_load")
214
+ def test_load_extra_path_config_absolute_base_path(
215
+ mock_yaml_load, _mock_file, clear_folder_paths, monkeypatch, tmp_path
216
+ ):
217
+ """
218
+ Test that when 'base_path' is an absolute path, each subdirectory is joined with that absolute path,
219
+ rather than being relative to the YAML's directory.
220
+ """
221
+ abs_base = os.path.join(str(tmp_path), "abs_base")
222
+ config_data = {
223
+ "some_absolute_folder": {
224
+ "base_path": abs_base, # <-- absolute
225
+ "is_default": True,
226
+ "loras": "loras_folder",
227
+ "embeddings": "embeddings_folder"
228
+ }
229
+ }
230
+ mock_yaml_load.return_value = config_data
231
+
232
+ dummy_yaml_name = "dummy_abs.yaml"
233
+
234
+ def fake_abspath(path):
235
+ if path == dummy_yaml_name:
236
+ # If it's the YAML path, treat it like it is in tmp_path
237
+ return os.path.join(str(tmp_path), dummy_yaml_name)
238
+ return path # For absolute base, we just return path directly
239
+
240
+ def fake_dirname(path):
241
+ return str(tmp_path) if path.endswith(dummy_yaml_name) else os.path.dirname(path)
242
+
243
+ monkeypatch.setattr(os.path, "abspath", fake_abspath)
244
+ monkeypatch.setattr(os.path, "dirname", fake_dirname)
245
+
246
+ load_extra_path_config(dummy_yaml_name)
247
+
248
+ # Expect the final paths to be <abs_base>/loras_folder and <abs_base>/embeddings_folder
249
+ expected_loras = os.path.join(abs_base, "loras_folder")
250
+ expected_embeddings = os.path.join(abs_base, "embeddings_folder")
251
+
252
+ actual_loras = folder_paths.folder_names_and_paths["loras"][0]
253
+ assert len(actual_loras) == 1, "Should have one path for 'loras'."
254
+ assert actual_loras[0] == os.path.abspath(expected_loras)
255
+
256
+ actual_embeddings = folder_paths.folder_names_and_paths["embeddings"][0]
257
+ assert len(actual_embeddings) == 1, "Should have one path for 'embeddings'."
258
+ assert actual_embeddings[0] == os.path.abspath(expected_embeddings)
259
+
260
+
261
+ @patch("builtins.open", new_callable=mock_open, read_data="dummy yaml content")
262
+ @patch("yaml.safe_load")
263
+ def test_load_extra_path_config_no_base_path(
264
+ mock_yaml_load, _mock_file, clear_folder_paths, monkeypatch, tmp_path
265
+ ):
266
+ """
267
+ Test that if 'base_path' is not present, each path is joined
268
+ with the directory of the YAML file (unless it's already absolute).
269
+ """
270
+ config_data = {
271
+ "some_folder_without_base": {
272
+ "is_default": True,
273
+ "text_encoders": "clip",
274
+ "diffusion_models": "unet"
275
+ }
276
+ }
277
+ mock_yaml_load.return_value = config_data
278
+
279
+ dummy_yaml_name = "dummy_no_base.yaml"
280
+
281
+ def fake_abspath(path):
282
+ if path == dummy_yaml_name:
283
+ return os.path.join(str(tmp_path), dummy_yaml_name)
284
+ return os.path.join(str(tmp_path), path)
285
+
286
+ def fake_dirname(path):
287
+ return str(tmp_path) if path.endswith(dummy_yaml_name) else os.path.dirname(path)
288
+
289
+ monkeypatch.setattr(os.path, "abspath", fake_abspath)
290
+ monkeypatch.setattr(os.path, "dirname", fake_dirname)
291
+
292
+ load_extra_path_config(dummy_yaml_name)
293
+
294
+ expected_clip = os.path.join(str(tmp_path), "clip")
295
+ expected_unet = os.path.join(str(tmp_path), "unet")
296
+
297
+ actual_text_encoders = folder_paths.folder_names_and_paths["text_encoders"][0]
298
+ assert len(actual_text_encoders) == 1, "Should have one path for 'text_encoders'."
299
+ assert actual_text_encoders[0] == os.path.abspath(expected_clip)
300
+
301
+ actual_diffusion = folder_paths.folder_names_and_paths["diffusion_models"][0]
302
+ assert len(actual_diffusion) == 1, "Should have one path for 'diffusion_models'."
303
+ assert actual_diffusion[0] == os.path.abspath(expected_unet)
tests/README.md ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Automated Testing
2
+
3
+ ## Running tests locally
4
+
5
+ Additional requirements for running tests:
6
+ ```
7
+ pip install pytest
8
+ pip install websocket-client==1.6.1
9
+ opencv-python==4.6.0.66
10
+ scikit-image==0.21.0
11
+ ```
12
+ Run inference tests:
13
+ ```
14
+ pytest tests/inference
15
+ ```
16
+
17
+ ## Quality regression test
18
+ Compares images in 2 directories to ensure they are the same
19
+
20
+ 1) Run an inference test to save a directory of "ground truth" images
21
+ ```
22
+ pytest tests/inference --output_dir tests/inference/baseline
23
+ ```
24
+ 2) Make code edits
25
+
26
+ 3) Run inference and quality comparison tests
27
+ ```
28
+ pytest
29
+ ```
tests/__init__.py ADDED
File without changes
tests/compare/conftest.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pytest
3
+
4
+ # Command line arguments for pytest
5
+ def pytest_addoption(parser):
6
+ parser.addoption('--baseline_dir', action="store", default='tests/inference/baseline', help='Directory for ground-truth images')
7
+ parser.addoption('--test_dir', action="store", default='tests/inference/samples', help='Directory for images to test')
8
+ parser.addoption('--metrics_file', action="store", default='tests/metrics.md', help='Output file for metrics')
9
+ parser.addoption('--img_output_dir', action="store", default='tests/compare/samples', help='Output directory for diff metric images')
10
+
11
+ # This initializes args at the beginning of the test session
12
+ @pytest.fixture(scope="session", autouse=True)
13
+ def args_pytest(pytestconfig):
14
+ args = {}
15
+ args['baseline_dir'] = pytestconfig.getoption('baseline_dir')
16
+ args['test_dir'] = pytestconfig.getoption('test_dir')
17
+ args['metrics_file'] = pytestconfig.getoption('metrics_file')
18
+ args['img_output_dir'] = pytestconfig.getoption('img_output_dir')
19
+
20
+ # Initialize metrics file
21
+ with open(args['metrics_file'], 'a') as f:
22
+ # if file is empty, write header
23
+ if os.stat(args['metrics_file']).st_size == 0:
24
+ f.write("| date | run | file | status | value | \n")
25
+ f.write("| --- | --- | --- | --- | --- | \n")
26
+
27
+ return args
28
+
29
+
30
+ def gather_file_basenames(directory: str):
31
+ files = []
32
+ for file in os.listdir(directory):
33
+ if file.endswith(".png"):
34
+ files.append(file)
35
+ return files
36
+
37
+ # Creates the list of baseline file names to use as a fixture
38
+ def pytest_generate_tests(metafunc):
39
+ if "baseline_fname" in metafunc.fixturenames:
40
+ baseline_fnames = gather_file_basenames(metafunc.config.getoption("baseline_dir"))
41
+ metafunc.parametrize("baseline_fname", baseline_fnames)
tests/compare/test_quality.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import numpy as np
3
+ import os
4
+ from PIL import Image
5
+ import pytest
6
+ from pytest import fixture
7
+ from typing import Tuple, List
8
+
9
+ from cv2 import imread, cvtColor, COLOR_BGR2RGB
10
+ from skimage.metrics import structural_similarity as ssim
11
+
12
+
13
+ """
14
+ This test suite compares images in 2 directories by file name
15
+ The directories are specified by the command line arguments --baseline_dir and --test_dir
16
+
17
+ """
18
+ # ssim: Structural Similarity Index
19
+ # Returns a tuple of (ssim, diff_image)
20
+ def ssim_score(img0: np.ndarray, img1: np.ndarray) -> Tuple[float, np.ndarray]:
21
+ score, diff = ssim(img0, img1, channel_axis=-1, full=True)
22
+ # rescale the difference image to 0-255 range
23
+ diff = (diff * 255).astype("uint8")
24
+ return score, diff
25
+
26
+ # Metrics must return a tuple of (score, diff_image)
27
+ METRICS = {"ssim": ssim_score}
28
+ METRICS_PASS_THRESHOLD = {"ssim": 0.95}
29
+
30
+
31
+ class TestCompareImageMetrics:
32
+ @fixture(scope="class")
33
+ def test_file_names(self, args_pytest):
34
+ test_dir = args_pytest['test_dir']
35
+ fnames = self.gather_file_basenames(test_dir)
36
+ yield fnames
37
+ del fnames
38
+
39
+ @fixture(scope="class", autouse=True)
40
+ def teardown(self, args_pytest):
41
+ yield
42
+ # Runs after all tests are complete
43
+ # Aggregate output files into a grid of images
44
+ baseline_dir = args_pytest['baseline_dir']
45
+ test_dir = args_pytest['test_dir']
46
+ img_output_dir = args_pytest['img_output_dir']
47
+ metrics_file = args_pytest['metrics_file']
48
+
49
+ grid_dir = os.path.join(img_output_dir, "grid")
50
+ os.makedirs(grid_dir, exist_ok=True)
51
+
52
+ for metric_dir in METRICS.keys():
53
+ metric_path = os.path.join(img_output_dir, metric_dir)
54
+ for file in os.listdir(metric_path):
55
+ if file.endswith(".png"):
56
+ score = self.lookup_score_from_fname(file, metrics_file)
57
+ image_file_list = []
58
+ image_file_list.append([
59
+ os.path.join(baseline_dir, file),
60
+ os.path.join(test_dir, file),
61
+ os.path.join(metric_path, file)
62
+ ])
63
+ # Create grid
64
+ image_list = [[Image.open(file) for file in files] for files in image_file_list]
65
+ grid = self.image_grid(image_list)
66
+ grid.save(os.path.join(grid_dir, f"{metric_dir}_{score:.3f}_{file}"))
67
+
68
+ # Tests run for each baseline file name
69
+ @fixture()
70
+ def fname(self, baseline_fname):
71
+ yield baseline_fname
72
+ del baseline_fname
73
+
74
+ def test_directories_not_empty(self, args_pytest):
75
+ baseline_dir = args_pytest['baseline_dir']
76
+ test_dir = args_pytest['test_dir']
77
+ assert len(os.listdir(baseline_dir)) != 0, f"Baseline directory {baseline_dir} is empty"
78
+ assert len(os.listdir(test_dir)) != 0, f"Test directory {test_dir} is empty"
79
+
80
+ def test_dir_has_all_matching_metadata(self, fname, test_file_names, args_pytest):
81
+ # Check that all files in baseline_dir have a file in test_dir with matching metadata
82
+ baseline_file_path = os.path.join(args_pytest['baseline_dir'], fname)
83
+ file_paths = [os.path.join(args_pytest['test_dir'], f) for f in test_file_names]
84
+ file_match = self.find_file_match(baseline_file_path, file_paths)
85
+ assert file_match is not None, f"Could not find a file in {args_pytest['test_dir']} with matching metadata to {baseline_file_path}"
86
+
87
+ # For a baseline image file, finds the corresponding file name in test_dir and
88
+ # compares the images using the metrics in METRICS
89
+ @pytest.mark.parametrize("metric", METRICS.keys())
90
+ def test_pipeline_compare(
91
+ self,
92
+ args_pytest,
93
+ fname,
94
+ test_file_names,
95
+ metric,
96
+ ):
97
+ baseline_dir = args_pytest['baseline_dir']
98
+ test_dir = args_pytest['test_dir']
99
+ metrics_output_file = args_pytest['metrics_file']
100
+ img_output_dir = args_pytest['img_output_dir']
101
+
102
+ baseline_file_path = os.path.join(baseline_dir, fname)
103
+
104
+ # Find file match
105
+ file_paths = [os.path.join(test_dir, f) for f in test_file_names]
106
+ test_file = self.find_file_match(baseline_file_path, file_paths)
107
+
108
+ # Run metrics
109
+ sample_baseline = self.read_img(baseline_file_path)
110
+ sample_secondary = self.read_img(test_file)
111
+
112
+ score, metric_img = METRICS[metric](sample_baseline, sample_secondary)
113
+ metric_status = score > METRICS_PASS_THRESHOLD[metric]
114
+
115
+ # Save metric values
116
+ with open(metrics_output_file, 'a') as f:
117
+ run_info = os.path.splitext(fname)[0]
118
+ metric_status_str = "PASS ✅" if metric_status else "FAIL ❌"
119
+ date_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
120
+ f.write(f"| {date_str} | {run_info} | {metric} | {metric_status_str} | {score} | \n")
121
+
122
+ # Save metric image
123
+ metric_img_dir = os.path.join(img_output_dir, metric)
124
+ os.makedirs(metric_img_dir, exist_ok=True)
125
+ output_filename = f'{fname}'
126
+ Image.fromarray(metric_img).save(os.path.join(metric_img_dir, output_filename))
127
+
128
+ assert score > METRICS_PASS_THRESHOLD[metric]
129
+
130
+ def read_img(self, filename: str) -> np.ndarray:
131
+ cvImg = imread(filename)
132
+ cvImg = cvtColor(cvImg, COLOR_BGR2RGB)
133
+ return cvImg
134
+
135
+ def image_grid(self, img_list: list[list[Image.Image]]):
136
+ # imgs is a 2D list of images
137
+ # Assumes the input images are a rectangular grid of equal sized images
138
+ rows = len(img_list)
139
+ cols = len(img_list[0])
140
+
141
+ w, h = img_list[0][0].size
142
+ grid = Image.new('RGB', size=(cols*w, rows*h))
143
+
144
+ for i, row in enumerate(img_list):
145
+ for j, img in enumerate(row):
146
+ grid.paste(img, box=(j*w, i*h))
147
+ return grid
148
+
149
+ def lookup_score_from_fname(self,
150
+ fname: str,
151
+ metrics_output_file: str
152
+ ) -> float:
153
+ fname_basestr = os.path.splitext(fname)[0]
154
+ with open(metrics_output_file, 'r') as f:
155
+ for line in f:
156
+ if fname_basestr in line:
157
+ score = float(line.split('|')[5])
158
+ return score
159
+ raise ValueError(f"Could not find score for {fname} in {metrics_output_file}")
160
+
161
+ def gather_file_basenames(self, directory: str):
162
+ files = []
163
+ for file in os.listdir(directory):
164
+ if file.endswith(".png"):
165
+ files.append(file)
166
+ return files
167
+
168
+ def read_file_prompt(self, fname:str) -> str:
169
+ # Read prompt from image file metadata
170
+ img = Image.open(fname)
171
+ img.load()
172
+ return img.info['prompt']
173
+
174
+ def find_file_match(self, baseline_file: str, file_paths: List[str]):
175
+ # Find a file in file_paths with matching metadata to baseline_file
176
+ baseline_prompt = self.read_file_prompt(baseline_file)
177
+
178
+ # Do not match empty prompts
179
+ if baseline_prompt is None or baseline_prompt == "":
180
+ return None
181
+
182
+ # Find file match
183
+ # Reorder test_file_names so that the file with matching name is first
184
+ # This is an optimization because matching file names are more likely
185
+ # to have matching metadata if they were generated with the same script
186
+ basename = os.path.basename(baseline_file)
187
+ file_path_basenames = [os.path.basename(f) for f in file_paths]
188
+ if basename in file_path_basenames:
189
+ match_index = file_path_basenames.index(basename)
190
+ file_paths.insert(0, file_paths.pop(match_index))
191
+
192
+ for f in file_paths:
193
+ test_file_prompt = self.read_file_prompt(f)
194
+ if baseline_prompt == test_file_prompt:
195
+ return f
tests/conftest.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pytest
3
+
4
+ # Command line arguments for pytest
5
+ def pytest_addoption(parser):
6
+ parser.addoption('--output_dir', action="store", default='tests/inference/samples', help='Output directory for generated images')
7
+ parser.addoption("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
8
+ parser.addoption("--port", type=int, default=8188, help="Set the listen port.")
9
+
10
+ # This initializes args at the beginning of the test session
11
+ @pytest.fixture(scope="session", autouse=True)
12
+ def args_pytest(pytestconfig):
13
+ args = {}
14
+ args['output_dir'] = pytestconfig.getoption('output_dir')
15
+ args['listen'] = pytestconfig.getoption('listen')
16
+ args['port'] = pytestconfig.getoption('port')
17
+
18
+ os.makedirs(args['output_dir'], exist_ok=True)
19
+
20
+ return args
21
+
22
+ def pytest_collection_modifyitems(items):
23
+ # Modifies items so tests run in the correct order
24
+
25
+ LAST_TESTS = ['test_quality']
26
+
27
+ # Move the last items to the end
28
+ last_items = []
29
+ for test_name in LAST_TESTS:
30
+ for item in items.copy():
31
+ print(item.module.__name__, item) # noqa: T201
32
+ if item.module.__name__ == test_name:
33
+ last_items.append(item)
34
+ items.remove(item)
35
+
36
+ items.extend(last_items)
tests/inference/__init__.py ADDED
File without changes
tests/inference/extra_model_paths.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Config for testing nodes
2
+ testing:
3
+ custom_nodes: tests/inference/testing_nodes
4
+
tests/inference/graphs/default_graph_sdxl1_0.json ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "4": {
3
+ "inputs": {
4
+ "ckpt_name": "sd_xl_base_1.0.safetensors"
5
+ },
6
+ "class_type": "CheckpointLoaderSimple"
7
+ },
8
+ "5": {
9
+ "inputs": {
10
+ "width": 1024,
11
+ "height": 1024,
12
+ "batch_size": 1
13
+ },
14
+ "class_type": "EmptyLatentImage"
15
+ },
16
+ "6": {
17
+ "inputs": {
18
+ "text": "a photo of a cat",
19
+ "clip": [
20
+ "4",
21
+ 1
22
+ ]
23
+ },
24
+ "class_type": "CLIPTextEncode"
25
+ },
26
+ "10": {
27
+ "inputs": {
28
+ "add_noise": "enable",
29
+ "noise_seed": 42,
30
+ "steps": 20,
31
+ "cfg": 7.5,
32
+ "sampler_name": "euler",
33
+ "scheduler": "normal",
34
+ "start_at_step": 0,
35
+ "end_at_step": 32,
36
+ "return_with_leftover_noise": "enable",
37
+ "model": [
38
+ "4",
39
+ 0
40
+ ],
41
+ "positive": [
42
+ "6",
43
+ 0
44
+ ],
45
+ "negative": [
46
+ "15",
47
+ 0
48
+ ],
49
+ "latent_image": [
50
+ "5",
51
+ 0
52
+ ]
53
+ },
54
+ "class_type": "KSamplerAdvanced"
55
+ },
56
+ "12": {
57
+ "inputs": {
58
+ "samples": [
59
+ "14",
60
+ 0
61
+ ],
62
+ "vae": [
63
+ "4",
64
+ 2
65
+ ]
66
+ },
67
+ "class_type": "VAEDecode"
68
+ },
69
+ "13": {
70
+ "inputs": {
71
+ "filename_prefix": "test_inference",
72
+ "images": [
73
+ "12",
74
+ 0
75
+ ]
76
+ },
77
+ "class_type": "SaveImage"
78
+ },
79
+ "14": {
80
+ "inputs": {
81
+ "add_noise": "disable",
82
+ "noise_seed": 42,
83
+ "steps": 20,
84
+ "cfg": 7.5,
85
+ "sampler_name": "euler",
86
+ "scheduler": "normal",
87
+ "start_at_step": 32,
88
+ "end_at_step": 10000,
89
+ "return_with_leftover_noise": "disable",
90
+ "model": [
91
+ "16",
92
+ 0
93
+ ],
94
+ "positive": [
95
+ "17",
96
+ 0
97
+ ],
98
+ "negative": [
99
+ "20",
100
+ 0
101
+ ],
102
+ "latent_image": [
103
+ "10",
104
+ 0
105
+ ]
106
+ },
107
+ "class_type": "KSamplerAdvanced"
108
+ },
109
+ "15": {
110
+ "inputs": {
111
+ "conditioning": [
112
+ "6",
113
+ 0
114
+ ]
115
+ },
116
+ "class_type": "ConditioningZeroOut"
117
+ },
118
+ "16": {
119
+ "inputs": {
120
+ "ckpt_name": "sd_xl_refiner_1.0.safetensors"
121
+ },
122
+ "class_type": "CheckpointLoaderSimple"
123
+ },
124
+ "17": {
125
+ "inputs": {
126
+ "text": "a photo of a cat",
127
+ "clip": [
128
+ "16",
129
+ 1
130
+ ]
131
+ },
132
+ "class_type": "CLIPTextEncode"
133
+ },
134
+ "20": {
135
+ "inputs": {
136
+ "text": "",
137
+ "clip": [
138
+ "16",
139
+ 1
140
+ ]
141
+ },
142
+ "class_type": "CLIPTextEncode"
143
+ }
144
+ }
tests/inference/test_execution.py ADDED
@@ -0,0 +1,524 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from io import BytesIO
2
+ import numpy
3
+ from PIL import Image
4
+ import pytest
5
+ from pytest import fixture
6
+ import time
7
+ import torch
8
+ from typing import Union, Dict
9
+ import json
10
+ import subprocess
11
+ import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
12
+ import uuid
13
+ import urllib.request
14
+ import urllib.parse
15
+ import urllib.error
16
+ from comfy_execution.graph_utils import GraphBuilder, Node
17
+
18
+ class RunResult:
19
+ def __init__(self, prompt_id: str):
20
+ self.outputs: Dict[str,Dict] = {}
21
+ self.runs: Dict[str,bool] = {}
22
+ self.prompt_id: str = prompt_id
23
+
24
+ def get_output(self, node: Node):
25
+ return self.outputs.get(node.id, None)
26
+
27
+ def did_run(self, node: Node):
28
+ return self.runs.get(node.id, False)
29
+
30
+ def get_images(self, node: Node):
31
+ output = self.get_output(node)
32
+ if output is None:
33
+ return []
34
+ return output.get('image_objects', [])
35
+
36
+ def get_prompt_id(self):
37
+ return self.prompt_id
38
+
39
+ class ComfyClient:
40
+ def __init__(self):
41
+ self.test_name = ""
42
+
43
+ def connect(self,
44
+ listen:str = '127.0.0.1',
45
+ port:Union[str,int] = 8188,
46
+ client_id: str = str(uuid.uuid4())
47
+ ):
48
+ self.client_id = client_id
49
+ self.server_address = f"{listen}:{port}"
50
+ ws = websocket.WebSocket()
51
+ ws.connect("ws://{}/ws?clientId={}".format(self.server_address, self.client_id))
52
+ self.ws = ws
53
+
54
+ def queue_prompt(self, prompt):
55
+ p = {"prompt": prompt, "client_id": self.client_id}
56
+ data = json.dumps(p).encode('utf-8')
57
+ req = urllib.request.Request("http://{}/prompt".format(self.server_address), data=data)
58
+ return json.loads(urllib.request.urlopen(req).read())
59
+
60
+ def get_image(self, filename, subfolder, folder_type):
61
+ data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
62
+ url_values = urllib.parse.urlencode(data)
63
+ with urllib.request.urlopen("http://{}/view?{}".format(self.server_address, url_values)) as response:
64
+ return response.read()
65
+
66
+ def get_history(self, prompt_id):
67
+ with urllib.request.urlopen("http://{}/history/{}".format(self.server_address, prompt_id)) as response:
68
+ return json.loads(response.read())
69
+
70
+ def set_test_name(self, name):
71
+ self.test_name = name
72
+
73
+ def run(self, graph):
74
+ prompt = graph.finalize()
75
+ for node in graph.nodes.values():
76
+ if node.class_type == 'SaveImage':
77
+ node.inputs['filename_prefix'] = self.test_name
78
+
79
+ prompt_id = self.queue_prompt(prompt)['prompt_id']
80
+ result = RunResult(prompt_id)
81
+ while True:
82
+ out = self.ws.recv()
83
+ if isinstance(out, str):
84
+ message = json.loads(out)
85
+ if message['type'] == 'executing':
86
+ data = message['data']
87
+ if data['prompt_id'] != prompt_id:
88
+ continue
89
+ if data['node'] is None:
90
+ break
91
+ result.runs[data['node']] = True
92
+ elif message['type'] == 'execution_error':
93
+ raise Exception(message['data'])
94
+ elif message['type'] == 'execution_cached':
95
+ pass # Probably want to store this off for testing
96
+
97
+ history = self.get_history(prompt_id)[prompt_id]
98
+ for node_id in history['outputs']:
99
+ node_output = history['outputs'][node_id]
100
+ result.outputs[node_id] = node_output
101
+ images_output = []
102
+ if 'images' in node_output:
103
+ for image in node_output['images']:
104
+ image_data = self.get_image(image['filename'], image['subfolder'], image['type'])
105
+ image_obj = Image.open(BytesIO(image_data))
106
+ images_output.append(image_obj)
107
+ node_output['image_objects'] = images_output
108
+
109
+ return result
110
+
111
+ #
112
+ # Loop through these variables
113
+ #
114
+ @pytest.mark.execution
115
+ class TestExecution:
116
+ #
117
+ # Initialize server and client
118
+ #
119
+ @fixture(scope="class", autouse=True, params=[
120
+ # (use_lru, lru_size)
121
+ (False, 0),
122
+ (True, 0),
123
+ (True, 100),
124
+ ])
125
+ def _server(self, args_pytest, request):
126
+ # Start server
127
+ pargs = [
128
+ 'python','main.py',
129
+ '--output-directory', args_pytest["output_dir"],
130
+ '--listen', args_pytest["listen"],
131
+ '--port', str(args_pytest["port"]),
132
+ '--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
133
+ ]
134
+ use_lru, lru_size = request.param
135
+ if use_lru:
136
+ pargs += ['--cache-lru', str(lru_size)]
137
+ print("Running server with args:", pargs) # noqa: T201
138
+ p = subprocess.Popen(pargs)
139
+ yield
140
+ p.kill()
141
+ torch.cuda.empty_cache()
142
+
143
+ def start_client(self, listen:str, port:int):
144
+ # Start client
145
+ comfy_client = ComfyClient()
146
+ # Connect to server (with retries)
147
+ n_tries = 5
148
+ for i in range(n_tries):
149
+ time.sleep(4)
150
+ try:
151
+ comfy_client.connect(listen=listen, port=port)
152
+ except ConnectionRefusedError as e:
153
+ print(e) # noqa: T201
154
+ print(f"({i+1}/{n_tries}) Retrying...") # noqa: T201
155
+ else:
156
+ break
157
+ return comfy_client
158
+
159
+ @fixture(scope="class", autouse=True)
160
+ def shared_client(self, args_pytest, _server):
161
+ client = self.start_client(args_pytest["listen"], args_pytest["port"])
162
+ yield client
163
+ del client
164
+ torch.cuda.empty_cache()
165
+
166
+ @fixture
167
+ def client(self, shared_client, request):
168
+ shared_client.set_test_name(f"execution[{request.node.name}]")
169
+ yield shared_client
170
+
171
+ @fixture
172
+ def builder(self, request):
173
+ yield GraphBuilder(prefix=request.node.name)
174
+
175
+ def test_lazy_input(self, client: ComfyClient, builder: GraphBuilder):
176
+ g = builder
177
+ input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
178
+ input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
179
+ mask = g.node("StubMask", value=0.0, height=512, width=512, batch_size=1)
180
+
181
+ lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0))
182
+ output = g.node("SaveImage", images=lazy_mix.out(0))
183
+ result = client.run(g)
184
+
185
+ result_image = result.get_images(output)[0]
186
+ assert numpy.array(result_image).any() == 0, "Image should be black"
187
+ assert result.did_run(input1)
188
+ assert not result.did_run(input2)
189
+ assert result.did_run(mask)
190
+ assert result.did_run(lazy_mix)
191
+
192
+ def test_full_cache(self, client: ComfyClient, builder: GraphBuilder):
193
+ g = builder
194
+ input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
195
+ input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
196
+ mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)
197
+
198
+ lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0))
199
+ g.node("SaveImage", images=lazy_mix.out(0))
200
+
201
+ client.run(g)
202
+ result2 = client.run(g)
203
+ for node_id, node in g.nodes.items():
204
+ assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached"
205
+
206
+ def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder):
207
+ g = builder
208
+ input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
209
+ input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
210
+ mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)
211
+
212
+ lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0))
213
+ g.node("SaveImage", images=lazy_mix.out(0))
214
+
215
+ client.run(g)
216
+ mask.inputs['value'] = 0.4
217
+ result2 = client.run(g)
218
+ assert not result2.did_run(input1), "Input1 should have been cached"
219
+ assert not result2.did_run(input2), "Input2 should have been cached"
220
+
221
+ def test_error(self, client: ComfyClient, builder: GraphBuilder):
222
+ g = builder
223
+ input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
224
+ # Different size of the two images
225
+ input2 = g.node("StubImage", content="NOISE", height=256, width=256, batch_size=1)
226
+ mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)
227
+
228
+ lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0))
229
+ g.node("SaveImage", images=lazy_mix.out(0))
230
+
231
+ try:
232
+ client.run(g)
233
+ assert False, "Should have raised an error"
234
+ except Exception as e:
235
+ assert 'prompt_id' in e.args[0], f"Did not get back a proper error message: {e}"
236
+
237
+ @pytest.mark.parametrize("test_value, expect_error", [
238
+ (5, True),
239
+ ("foo", True),
240
+ (5.0, False),
241
+ ])
242
+ def test_validation_error_literal(self, test_value, expect_error, client: ComfyClient, builder: GraphBuilder):
243
+ g = builder
244
+ validation1 = g.node("TestCustomValidation1", input1=test_value, input2=3.0)
245
+ g.node("SaveImage", images=validation1.out(0))
246
+
247
+ if expect_error:
248
+ with pytest.raises(urllib.error.HTTPError):
249
+ client.run(g)
250
+ else:
251
+ client.run(g)
252
+
253
+ @pytest.mark.parametrize("test_type, test_value", [
254
+ ("StubInt", 5),
255
+ ("StubFloat", 5.0)
256
+ ])
257
+ def test_validation_error_edge1(self, test_type, test_value, client: ComfyClient, builder: GraphBuilder):
258
+ g = builder
259
+ stub = g.node(test_type, value=test_value)
260
+ validation1 = g.node("TestCustomValidation1", input1=stub.out(0), input2=3.0)
261
+ g.node("SaveImage", images=validation1.out(0))
262
+
263
+ with pytest.raises(urllib.error.HTTPError):
264
+ client.run(g)
265
+
266
+ @pytest.mark.parametrize("test_type, test_value, expect_error", [
267
+ ("StubInt", 5, True),
268
+ ("StubFloat", 5.0, False)
269
+ ])
270
+ def test_validation_error_edge2(self, test_type, test_value, expect_error, client: ComfyClient, builder: GraphBuilder):
271
+ g = builder
272
+ stub = g.node(test_type, value=test_value)
273
+ validation2 = g.node("TestCustomValidation2", input1=stub.out(0), input2=3.0)
274
+ g.node("SaveImage", images=validation2.out(0))
275
+
276
+ if expect_error:
277
+ with pytest.raises(urllib.error.HTTPError):
278
+ client.run(g)
279
+ else:
280
+ client.run(g)
281
+
282
+ @pytest.mark.parametrize("test_type, test_value, expect_error", [
283
+ ("StubInt", 5, True),
284
+ ("StubFloat", 5.0, False)
285
+ ])
286
+ def test_validation_error_edge3(self, test_type, test_value, expect_error, client: ComfyClient, builder: GraphBuilder):
287
+ g = builder
288
+ stub = g.node(test_type, value=test_value)
289
+ validation3 = g.node("TestCustomValidation3", input1=stub.out(0), input2=3.0)
290
+ g.node("SaveImage", images=validation3.out(0))
291
+
292
+ if expect_error:
293
+ with pytest.raises(urllib.error.HTTPError):
294
+ client.run(g)
295
+ else:
296
+ client.run(g)
297
+
298
+ @pytest.mark.parametrize("test_type, test_value, expect_error", [
299
+ ("StubInt", 5, True),
300
+ ("StubFloat", 5.0, False)
301
+ ])
302
+ def test_validation_error_edge4(self, test_type, test_value, expect_error, client: ComfyClient, builder: GraphBuilder):
303
+ g = builder
304
+ stub = g.node(test_type, value=test_value)
305
+ validation4 = g.node("TestCustomValidation4", input1=stub.out(0), input2=3.0)
306
+ g.node("SaveImage", images=validation4.out(0))
307
+
308
+ if expect_error:
309
+ with pytest.raises(urllib.error.HTTPError):
310
+ client.run(g)
311
+ else:
312
+ client.run(g)
313
+
314
+ @pytest.mark.parametrize("test_value1, test_value2, expect_error", [
315
+ (0.0, 0.5, False),
316
+ (0.0, 5.0, False),
317
+ (0.0, 7.0, True)
318
+ ])
319
+ def test_validation_error_kwargs(self, test_value1, test_value2, expect_error, client: ComfyClient, builder: GraphBuilder):
320
+ g = builder
321
+ validation5 = g.node("TestCustomValidation5", input1=test_value1, input2=test_value2)
322
+ g.node("SaveImage", images=validation5.out(0))
323
+
324
+ if expect_error:
325
+ with pytest.raises(urllib.error.HTTPError):
326
+ client.run(g)
327
+ else:
328
+ client.run(g)
329
+
330
+ def test_cycle_error(self, client: ComfyClient, builder: GraphBuilder):
331
+ g = builder
332
+ input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
333
+ input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
334
+ mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)
335
+
336
+ lazy_mix1 = g.node("TestLazyMixImages", image1=input1.out(0), mask=mask.out(0))
337
+ lazy_mix2 = g.node("TestLazyMixImages", image1=lazy_mix1.out(0), image2=input2.out(0), mask=mask.out(0))
338
+ g.node("SaveImage", images=lazy_mix2.out(0))
339
+
340
+ # When the cycle exists on initial submission, it should raise a validation error
341
+ with pytest.raises(urllib.error.HTTPError):
342
+ client.run(g)
343
+
344
+ def test_dynamic_cycle_error(self, client: ComfyClient, builder: GraphBuilder):
345
+ g = builder
346
+ input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
347
+ input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
348
+ generator = g.node("TestDynamicDependencyCycle", input1=input1.out(0), input2=input2.out(0))
349
+ g.node("SaveImage", images=generator.out(0))
350
+
351
+ # When the cycle is in a graph that is generated dynamically, it should raise a runtime error
352
+ try:
353
+ client.run(g)
354
+ assert False, "Should have raised an error"
355
+ except Exception as e:
356
+ assert 'prompt_id' in e.args[0], f"Did not get back a proper error message: {e}"
357
+ assert e.args[0]['node_id'] == generator.id, "Error should have been on the generator node"
358
+
359
+ def test_missing_node_error(self, client: ComfyClient, builder: GraphBuilder):
360
+ g = builder
361
+ input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
362
+ input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1)
363
+ input3 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
364
+ mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)
365
+ mix1 = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0))
366
+ mix2 = g.node("TestLazyMixImages", image1=input1.out(0), image2=input3.out(0), mask=mask.out(0))
367
+ # We have multiple outputs. The first is invalid, but the second is valid
368
+ g.node("SaveImage", images=mix1.out(0))
369
+ g.node("SaveImage", images=mix2.out(0))
370
+ g.remove_node("removeme")
371
+
372
+ client.run(g)
373
+
374
+ # Add back in the missing node to make sure the error doesn't break the server
375
+ input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1)
376
+ client.run(g)
377
+
378
+ def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder):
379
+ g = builder
380
+ # Creating the nodes in this specific order previously caused a bug
381
+ save = g.node("SaveImage")
382
+ is_changed = g.node("TestCustomIsChanged", should_change=False)
383
+ input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
384
+
385
+ save.set_input('images', is_changed.out(0))
386
+ is_changed.set_input('image', input1.out(0))
387
+
388
+ result1 = client.run(g)
389
+ result2 = client.run(g)
390
+ is_changed.set_input('should_change', True)
391
+ result3 = client.run(g)
392
+ result4 = client.run(g)
393
+ assert result1.did_run(is_changed), "is_changed should have been run"
394
+ assert not result2.did_run(is_changed), "is_changed should have been cached"
395
+ assert result3.did_run(is_changed), "is_changed should have been re-run"
396
+ assert result4.did_run(is_changed), "is_changed should not have been cached"
397
+
398
+ def test_undeclared_inputs(self, client: ComfyClient, builder: GraphBuilder):
399
+ g = builder
400
+ input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
401
+ input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
402
+ input3 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
403
+ input4 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
404
+ average = g.node("TestVariadicAverage", input1=input1.out(0), input2=input2.out(0), input3=input3.out(0), input4=input4.out(0))
405
+ output = g.node("SaveImage", images=average.out(0))
406
+
407
+ result = client.run(g)
408
+ result_image = result.get_images(output)[0]
409
+ expected = 255 // 4
410
+ assert numpy.array(result_image).min() == expected and numpy.array(result_image).max() == expected, "Image should be grey"
411
+
412
+ def test_for_loop(self, client: ComfyClient, builder: GraphBuilder):
413
+ g = builder
414
+ iterations = 4
415
+ input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
416
+ input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
417
+ is_changed = g.node("TestCustomIsChanged", should_change=True, image=input2.out(0))
418
+ for_open = g.node("TestForLoopOpen", remaining=iterations, initial_value1=is_changed.out(0))
419
+ average = g.node("TestVariadicAverage", input1=input1.out(0), input2=for_open.out(2))
420
+ for_close = g.node("TestForLoopClose", flow_control=for_open.out(0), initial_value1=average.out(0))
421
+ output = g.node("SaveImage", images=for_close.out(0))
422
+
423
+ for iterations in range(1, 5):
424
+ for_open.set_input('remaining', iterations)
425
+ result = client.run(g)
426
+ result_image = result.get_images(output)[0]
427
+ expected = 255 // (2 ** iterations)
428
+ assert numpy.array(result_image).min() == expected and numpy.array(result_image).max() == expected, "Image should be grey"
429
+ assert result.did_run(is_changed)
430
+
431
+ def test_mixed_expansion_returns(self, client: ComfyClient, builder: GraphBuilder):
432
+ g = builder
433
+ val_list = g.node("TestMakeListNode", value1=0.1, value2=0.2, value3=0.3)
434
+ mixed = g.node("TestMixedExpansionReturns", input1=val_list.out(0))
435
+ output_dynamic = g.node("SaveImage", images=mixed.out(0))
436
+ output_literal = g.node("SaveImage", images=mixed.out(1))
437
+
438
+ result = client.run(g)
439
+ images_dynamic = result.get_images(output_dynamic)
440
+ assert len(images_dynamic) == 3, "Should have 2 images"
441
+ assert numpy.array(images_dynamic[0]).min() == 25 and numpy.array(images_dynamic[0]).max() == 25, "First image should be 0.1"
442
+ assert numpy.array(images_dynamic[1]).min() == 51 and numpy.array(images_dynamic[1]).max() == 51, "Second image should be 0.2"
443
+ assert numpy.array(images_dynamic[2]).min() == 76 and numpy.array(images_dynamic[2]).max() == 76, "Third image should be 0.3"
444
+
445
+ images_literal = result.get_images(output_literal)
446
+ assert len(images_literal) == 3, "Should have 2 images"
447
+ for i in range(3):
448
+ assert numpy.array(images_literal[i]).min() == 255 and numpy.array(images_literal[i]).max() == 255, "All images should be white"
449
+
450
+ def test_mixed_lazy_results(self, client: ComfyClient, builder: GraphBuilder):
451
+ g = builder
452
+ val_list = g.node("TestMakeListNode", value1=0.0, value2=0.5, value3=1.0)
453
+ mask = g.node("StubMask", value=val_list.out(0), height=512, width=512, batch_size=1)
454
+ input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
455
+ input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
456
+ mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0))
457
+ rebatch = g.node("RebatchImages", images=mix.out(0), batch_size=3)
458
+ output = g.node("SaveImage", images=rebatch.out(0))
459
+
460
+ result = client.run(g)
461
+ images = result.get_images(output)
462
+ assert len(images) == 3, "Should have 3 image"
463
+ assert numpy.array(images[0]).min() == 0 and numpy.array(images[0]).max() == 0, "First image should be 0.0"
464
+ assert numpy.array(images[1]).min() == 127 and numpy.array(images[1]).max() == 127, "Second image should be 0.5"
465
+ assert numpy.array(images[2]).min() == 255 and numpy.array(images[2]).max() == 255, "Third image should be 1.0"
466
+
467
+ def test_output_reuse(self, client: ComfyClient, builder: GraphBuilder):
468
+ g = builder
469
+ input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
470
+
471
+ output1 = g.node("SaveImage", images=input1.out(0))
472
+ output2 = g.node("SaveImage", images=input1.out(0))
473
+
474
+ result = client.run(g)
475
+ images1 = result.get_images(output1)
476
+ images2 = result.get_images(output2)
477
+ assert len(images1) == 1, "Should have 1 image"
478
+ assert len(images2) == 1, "Should have 1 image"
479
+
480
+
481
+ # This tests that only constant outputs are used in the call to `IS_CHANGED`
482
+ def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder):
483
+ g = builder
484
+ input1 = g.node("StubConstantImage", value=0.5, height=512, width=512, batch_size=1)
485
+ test_node = g.node("TestIsChangedWithConstants", image=input1.out(0), value=0.5)
486
+
487
+ output = g.node("PreviewImage", images=test_node.out(0))
488
+
489
+ result = client.run(g)
490
+ images = result.get_images(output)
491
+ assert len(images) == 1, "Should have 1 image"
492
+ assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
493
+
494
+ result = client.run(g)
495
+ images = result.get_images(output)
496
+ assert len(images) == 1, "Should have 1 image"
497
+ assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
498
+ assert not result.did_run(test_node), "The execution should have been cached"
499
+
500
+ # This tests that nodes with OUTPUT_IS_LIST function correctly when they receive an ExecutionBlocker
501
+ # as input. We also test that when that list (containing an ExecutionBlocker) is passed to a node,
502
+ # only that one entry in the list is blocked.
503
+ def test_execution_block_list_output(self, client: ComfyClient, builder: GraphBuilder):
504
+ g = builder
505
+ image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
506
+ image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
507
+ image3 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
508
+ image_list = g.node("TestMakeListNode", value1=image1.out(0), value2=image2.out(0), value3=image3.out(0))
509
+ int1 = g.node("StubInt", value=1)
510
+ int2 = g.node("StubInt", value=2)
511
+ int3 = g.node("StubInt", value=3)
512
+ int_list = g.node("TestMakeListNode", value1=int1.out(0), value2=int2.out(0), value3=int3.out(0))
513
+ compare = g.node("TestIntConditions", a=int_list.out(0), b=2, operation="==")
514
+ blocker = g.node("TestExecutionBlocker", input=image_list.out(0), block=compare.out(0), verbose=False)
515
+
516
+ list_output = g.node("TestMakeListNode", value1=blocker.out(0))
517
+ output = g.node("PreviewImage", images=list_output.out(0))
518
+
519
+ result = client.run(g)
520
+ assert result.did_run(output), "The execution should have run"
521
+ images = result.get_images(output)
522
+ assert len(images) == 2, "Should have 2 images"
523
+ assert numpy.array(images[0]).min() == 0 and numpy.array(images[0]).max() == 0, "First image should be black"
524
+ assert numpy.array(images[1]).min() == 0 and numpy.array(images[1]).max() == 0, "Second image should also be black"
tests/inference/test_inference.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from copy import deepcopy
2
+ from io import BytesIO
3
+ import numpy
4
+ import os
5
+ from PIL import Image
6
+ import pytest
7
+ from pytest import fixture
8
+ import time
9
+ import torch
10
+ from typing import Union
11
+ import json
12
+ import subprocess
13
+ import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
14
+ import uuid
15
+ import urllib.request
16
+ import urllib.parse
17
+
18
+
19
+ from comfy.samplers import KSampler
20
+
21
+ """
22
+ These tests generate and save images through a range of parameters
23
+ """
24
+
25
+ class ComfyGraph:
26
+ def __init__(self,
27
+ graph: dict,
28
+ sampler_nodes: list[str],
29
+ ):
30
+ self.graph = graph
31
+ self.sampler_nodes = sampler_nodes
32
+
33
+ def set_prompt(self, prompt, negative_prompt=None):
34
+ # Sets the prompt for the sampler nodes (eg. base and refiner)
35
+ for node in self.sampler_nodes:
36
+ prompt_node = self.graph[node]['inputs']['positive'][0]
37
+ self.graph[prompt_node]['inputs']['text'] = prompt
38
+ if negative_prompt:
39
+ negative_prompt_node = self.graph[node]['inputs']['negative'][0]
40
+ self.graph[negative_prompt_node]['inputs']['text'] = negative_prompt
41
+
42
+ def set_sampler_name(self, sampler_name:str, ):
43
+ # sets the sampler name for the sampler nodes (eg. base and refiner)
44
+ for node in self.sampler_nodes:
45
+ self.graph[node]['inputs']['sampler_name'] = sampler_name
46
+
47
+ def set_scheduler(self, scheduler:str):
48
+ # sets the sampler name for the sampler nodes (eg. base and refiner)
49
+ for node in self.sampler_nodes:
50
+ self.graph[node]['inputs']['scheduler'] = scheduler
51
+
52
+ def set_filename_prefix(self, prefix:str):
53
+ # sets the filename prefix for the save nodes
54
+ for node in self.graph:
55
+ if self.graph[node]['class_type'] == 'SaveImage':
56
+ self.graph[node]['inputs']['filename_prefix'] = prefix
57
+
58
+
59
+ class ComfyClient:
60
+ # From examples/websockets_api_example.py
61
+
62
+ def connect(self,
63
+ listen:str = '127.0.0.1',
64
+ port:Union[str,int] = 8188,
65
+ client_id: str = str(uuid.uuid4())
66
+ ):
67
+ self.client_id = client_id
68
+ self.server_address = f"{listen}:{port}"
69
+ ws = websocket.WebSocket()
70
+ ws.connect("ws://{}/ws?clientId={}".format(self.server_address, self.client_id))
71
+ self.ws = ws
72
+
73
+ def queue_prompt(self, prompt):
74
+ p = {"prompt": prompt, "client_id": self.client_id}
75
+ data = json.dumps(p).encode('utf-8')
76
+ req = urllib.request.Request("http://{}/prompt".format(self.server_address), data=data)
77
+ return json.loads(urllib.request.urlopen(req).read())
78
+
79
+ def get_image(self, filename, subfolder, folder_type):
80
+ data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
81
+ url_values = urllib.parse.urlencode(data)
82
+ with urllib.request.urlopen("http://{}/view?{}".format(self.server_address, url_values)) as response:
83
+ return response.read()
84
+
85
+ def get_history(self, prompt_id):
86
+ with urllib.request.urlopen("http://{}/history/{}".format(self.server_address, prompt_id)) as response:
87
+ return json.loads(response.read())
88
+
89
+ def get_images(self, graph, save=True):
90
+ prompt = graph
91
+ if not save:
92
+ # Replace save nodes with preview nodes
93
+ prompt_str = json.dumps(prompt)
94
+ prompt_str = prompt_str.replace('SaveImage', 'PreviewImage')
95
+ prompt = json.loads(prompt_str)
96
+
97
+ prompt_id = self.queue_prompt(prompt)['prompt_id']
98
+ output_images = {}
99
+ while True:
100
+ out = self.ws.recv()
101
+ if isinstance(out, str):
102
+ message = json.loads(out)
103
+ if message['type'] == 'executing':
104
+ data = message['data']
105
+ if data['node'] is None and data['prompt_id'] == prompt_id:
106
+ break #Execution is done
107
+ else:
108
+ continue #previews are binary data
109
+
110
+ history = self.get_history(prompt_id)[prompt_id]
111
+ for node_id in history['outputs']:
112
+ node_output = history['outputs'][node_id]
113
+ images_output = []
114
+ if 'images' in node_output:
115
+ for image in node_output['images']:
116
+ image_data = self.get_image(image['filename'], image['subfolder'], image['type'])
117
+ images_output.append(image_data)
118
+ output_images[node_id] = images_output
119
+
120
+ return output_images
121
+
122
+ #
123
+ # Initialize graphs
124
+ #
125
+ default_graph_file = 'tests/inference/graphs/default_graph_sdxl1_0.json'
126
+ with open(default_graph_file, 'r') as file:
127
+ default_graph = json.loads(file.read())
128
+ DEFAULT_COMFY_GRAPH = ComfyGraph(graph=default_graph, sampler_nodes=['10','14'])
129
+ DEFAULT_COMFY_GRAPH_ID = os.path.splitext(os.path.basename(default_graph_file))[0]
130
+
131
+ #
132
+ # Loop through these variables
133
+ #
134
+ comfy_graph_list = [DEFAULT_COMFY_GRAPH]
135
+ comfy_graph_ids = [DEFAULT_COMFY_GRAPH_ID]
136
+ prompt_list = [
137
+ 'a painting of a cat',
138
+ ]
139
+
140
+ sampler_list = KSampler.SAMPLERS
141
+ scheduler_list = KSampler.SCHEDULERS
142
+
143
+ @pytest.mark.inference
144
+ @pytest.mark.parametrize("sampler", sampler_list)
145
+ @pytest.mark.parametrize("scheduler", scheduler_list)
146
+ @pytest.mark.parametrize("prompt", prompt_list)
147
+ class TestInference:
148
+ #
149
+ # Initialize server and client
150
+ #
151
+ @fixture(scope="class", autouse=True)
152
+ def _server(self, args_pytest):
153
+ # Start server
154
+ p = subprocess.Popen([
155
+ 'python','main.py',
156
+ '--output-directory', args_pytest["output_dir"],
157
+ '--listen', args_pytest["listen"],
158
+ '--port', str(args_pytest["port"]),
159
+ ])
160
+ yield
161
+ p.kill()
162
+ torch.cuda.empty_cache()
163
+
164
+ def start_client(self, listen:str, port:int):
165
+ # Start client
166
+ comfy_client = ComfyClient()
167
+ # Connect to server (with retries)
168
+ n_tries = 5
169
+ for i in range(n_tries):
170
+ time.sleep(4)
171
+ try:
172
+ comfy_client.connect(listen=listen, port=port)
173
+ except ConnectionRefusedError as e:
174
+ print(e) # noqa: T201
175
+ print(f"({i+1}/{n_tries}) Retrying...") # noqa: T201
176
+ else:
177
+ break
178
+ return comfy_client
179
+
180
+ #
181
+ # Client and graph fixtures with server warmup
182
+ #
183
+ # Returns a "_client_graph", which is client-graph pair corresponding to an initialized server
184
+ # The "graph" is the default graph
185
+ @fixture(scope="class", params=comfy_graph_list, ids=comfy_graph_ids, autouse=True)
186
+ def _client_graph(self, request, args_pytest, _server) -> (ComfyClient, ComfyGraph):
187
+ comfy_graph = request.param
188
+
189
+ # Start client
190
+ comfy_client = self.start_client(args_pytest["listen"], args_pytest["port"])
191
+
192
+ # Warm up pipeline
193
+ comfy_client.get_images(graph=comfy_graph.graph, save=False)
194
+
195
+ yield comfy_client, comfy_graph
196
+ del comfy_client
197
+ del comfy_graph
198
+ torch.cuda.empty_cache()
199
+
200
+ @fixture
201
+ def client(self, _client_graph):
202
+ client = _client_graph[0]
203
+ yield client
204
+
205
+ @fixture
206
+ def comfy_graph(self, _client_graph):
207
+ # avoid mutating the graph
208
+ graph = deepcopy(_client_graph[1])
209
+ yield graph
210
+
211
+ def test_comfy(
212
+ self,
213
+ client,
214
+ comfy_graph,
215
+ sampler,
216
+ scheduler,
217
+ prompt,
218
+ request
219
+ ):
220
+ test_info = request.node.name
221
+ comfy_graph.set_filename_prefix(test_info)
222
+ # Settings for comfy graph
223
+ comfy_graph.set_sampler_name(sampler)
224
+ comfy_graph.set_scheduler(scheduler)
225
+ comfy_graph.set_prompt(prompt)
226
+
227
+ # Generate
228
+ images = client.get_images(comfy_graph.graph)
229
+
230
+ assert len(images) != 0, "No images generated"
231
+ # assert all images are not blank
232
+ for images_output in images.values():
233
+ for image_data in images_output:
234
+ pil_image = Image.open(BytesIO(image_data))
235
+ assert numpy.array(pil_image).any() != 0, "Image is blank"
236
+
237
+
tests/inference/testing_nodes/testing-pack/__init__.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .specific_tests import TEST_NODE_CLASS_MAPPINGS, TEST_NODE_DISPLAY_NAME_MAPPINGS
2
+ from .flow_control import FLOW_CONTROL_NODE_CLASS_MAPPINGS, FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS
3
+ from .util import UTILITY_NODE_CLASS_MAPPINGS, UTILITY_NODE_DISPLAY_NAME_MAPPINGS
4
+ from .conditions import CONDITION_NODE_CLASS_MAPPINGS, CONDITION_NODE_DISPLAY_NAME_MAPPINGS
5
+ from .stubs import TEST_STUB_NODE_CLASS_MAPPINGS, TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS
6
+
7
+ # NODE_CLASS_MAPPINGS = GENERAL_NODE_CLASS_MAPPINGS.update(COMPONENT_NODE_CLASS_MAPPINGS)
8
+ # NODE_DISPLAY_NAME_MAPPINGS = GENERAL_NODE_DISPLAY_NAME_MAPPINGS.update(COMPONENT_NODE_DISPLAY_NAME_MAPPINGS)
9
+
10
+ NODE_CLASS_MAPPINGS = {}
11
+ NODE_CLASS_MAPPINGS.update(TEST_NODE_CLASS_MAPPINGS)
12
+ NODE_CLASS_MAPPINGS.update(FLOW_CONTROL_NODE_CLASS_MAPPINGS)
13
+ NODE_CLASS_MAPPINGS.update(UTILITY_NODE_CLASS_MAPPINGS)
14
+ NODE_CLASS_MAPPINGS.update(CONDITION_NODE_CLASS_MAPPINGS)
15
+ NODE_CLASS_MAPPINGS.update(TEST_STUB_NODE_CLASS_MAPPINGS)
16
+
17
+ NODE_DISPLAY_NAME_MAPPINGS = {}
18
+ NODE_DISPLAY_NAME_MAPPINGS.update(TEST_NODE_DISPLAY_NAME_MAPPINGS)
19
+ NODE_DISPLAY_NAME_MAPPINGS.update(FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS)
20
+ NODE_DISPLAY_NAME_MAPPINGS.update(UTILITY_NODE_DISPLAY_NAME_MAPPINGS)
21
+ NODE_DISPLAY_NAME_MAPPINGS.update(CONDITION_NODE_DISPLAY_NAME_MAPPINGS)
22
+ NODE_DISPLAY_NAME_MAPPINGS.update(TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS)
23
+
tests/inference/testing_nodes/testing-pack/conditions.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import torch
3
+
4
+ class TestIntConditions:
5
+ def __init__(self):
6
+ pass
7
+
8
+ @classmethod
9
+ def INPUT_TYPES(cls):
10
+ return {
11
+ "required": {
12
+ "a": ("INT", {"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1}),
13
+ "b": ("INT", {"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1}),
14
+ "operation": (["==", "!=", "<", ">", "<=", ">="],),
15
+ },
16
+ }
17
+
18
+ RETURN_TYPES = ("BOOLEAN",)
19
+ FUNCTION = "int_condition"
20
+
21
+ CATEGORY = "Testing/Logic"
22
+
23
+ def int_condition(self, a, b, operation):
24
+ if operation == "==":
25
+ return (a == b,)
26
+ elif operation == "!=":
27
+ return (a != b,)
28
+ elif operation == "<":
29
+ return (a < b,)
30
+ elif operation == ">":
31
+ return (a > b,)
32
+ elif operation == "<=":
33
+ return (a <= b,)
34
+ elif operation == ">=":
35
+ return (a >= b,)
36
+
37
+
38
+ class TestFloatConditions:
39
+ def __init__(self):
40
+ pass
41
+
42
+ @classmethod
43
+ def INPUT_TYPES(cls):
44
+ return {
45
+ "required": {
46
+ "a": ("FLOAT", {"default": 0, "min": -999999999999.0, "max": 999999999999.0, "step": 1}),
47
+ "b": ("FLOAT", {"default": 0, "min": -999999999999.0, "max": 999999999999.0, "step": 1}),
48
+ "operation": (["==", "!=", "<", ">", "<=", ">="],),
49
+ },
50
+ }
51
+
52
+ RETURN_TYPES = ("BOOLEAN",)
53
+ FUNCTION = "float_condition"
54
+
55
+ CATEGORY = "Testing/Logic"
56
+
57
+ def float_condition(self, a, b, operation):
58
+ if operation == "==":
59
+ return (a == b,)
60
+ elif operation == "!=":
61
+ return (a != b,)
62
+ elif operation == "<":
63
+ return (a < b,)
64
+ elif operation == ">":
65
+ return (a > b,)
66
+ elif operation == "<=":
67
+ return (a <= b,)
68
+ elif operation == ">=":
69
+ return (a >= b,)
70
+
71
+ class TestStringConditions:
72
+ def __init__(self):
73
+ pass
74
+
75
+ @classmethod
76
+ def INPUT_TYPES(cls):
77
+ return {
78
+ "required": {
79
+ "a": ("STRING", {"multiline": False}),
80
+ "b": ("STRING", {"multiline": False}),
81
+ "operation": (["a == b", "a != b", "a IN b", "a MATCH REGEX(b)", "a BEGINSWITH b", "a ENDSWITH b"],),
82
+ "case_sensitive": ("BOOLEAN", {"default": True}),
83
+ },
84
+ }
85
+
86
+ RETURN_TYPES = ("BOOLEAN",)
87
+ FUNCTION = "string_condition"
88
+
89
+ CATEGORY = "Testing/Logic"
90
+
91
+ def string_condition(self, a, b, operation, case_sensitive):
92
+ if not case_sensitive:
93
+ a = a.lower()
94
+ b = b.lower()
95
+
96
+ if operation == "a == b":
97
+ return (a == b,)
98
+ elif operation == "a != b":
99
+ return (a != b,)
100
+ elif operation == "a IN b":
101
+ return (a in b,)
102
+ elif operation == "a MATCH REGEX(b)":
103
+ try:
104
+ return (re.match(b, a) is not None,)
105
+ except:
106
+ return (False,)
107
+ elif operation == "a BEGINSWITH b":
108
+ return (a.startswith(b),)
109
+ elif operation == "a ENDSWITH b":
110
+ return (a.endswith(b),)
111
+
112
+ class TestToBoolNode:
113
+ def __init__(self):
114
+ pass
115
+
116
+ @classmethod
117
+ def INPUT_TYPES(cls):
118
+ return {
119
+ "required": {
120
+ "value": ("*",),
121
+ },
122
+ "optional": {
123
+ "invert": ("BOOLEAN", {"default": False}),
124
+ },
125
+ }
126
+
127
+ RETURN_TYPES = ("BOOLEAN",)
128
+ FUNCTION = "to_bool"
129
+
130
+ CATEGORY = "Testing/Logic"
131
+
132
+ def to_bool(self, value, invert = False):
133
+ if isinstance(value, torch.Tensor):
134
+ if value.max().item() == 0 and value.min().item() == 0:
135
+ result = False
136
+ else:
137
+ result = True
138
+ else:
139
+ try:
140
+ result = bool(value)
141
+ except:
142
+ # Can't convert it? Well then it's something or other. I dunno, I'm not a Python programmer.
143
+ result = True
144
+
145
+ if invert:
146
+ result = not result
147
+
148
+ return (result,)
149
+
150
+ class TestBoolOperationNode:
151
+ def __init__(self):
152
+ pass
153
+
154
+ @classmethod
155
+ def INPUT_TYPES(cls):
156
+ return {
157
+ "required": {
158
+ "a": ("BOOLEAN",),
159
+ "b": ("BOOLEAN",),
160
+ "op": (["a AND b", "a OR b", "a XOR b", "NOT a"],),
161
+ },
162
+ }
163
+
164
+ RETURN_TYPES = ("BOOLEAN",)
165
+ FUNCTION = "bool_operation"
166
+
167
+ CATEGORY = "Testing/Logic"
168
+
169
+ def bool_operation(self, a, b, op):
170
+ if op == "a AND b":
171
+ return (a and b,)
172
+ elif op == "a OR b":
173
+ return (a or b,)
174
+ elif op == "a XOR b":
175
+ return (a ^ b,)
176
+ elif op == "NOT a":
177
+ return (not a,)
178
+
179
+
180
+ CONDITION_NODE_CLASS_MAPPINGS = {
181
+ "TestIntConditions": TestIntConditions,
182
+ "TestFloatConditions": TestFloatConditions,
183
+ "TestStringConditions": TestStringConditions,
184
+ "TestToBoolNode": TestToBoolNode,
185
+ "TestBoolOperationNode": TestBoolOperationNode,
186
+ }
187
+
188
+ CONDITION_NODE_DISPLAY_NAME_MAPPINGS = {
189
+ "TestIntConditions": "Int Condition",
190
+ "TestFloatConditions": "Float Condition",
191
+ "TestStringConditions": "String Condition",
192
+ "TestToBoolNode": "To Bool",
193
+ "TestBoolOperationNode": "Bool Operation",
194
+ }
tests/inference/testing_nodes/testing-pack/flow_control.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from comfy_execution.graph_utils import GraphBuilder, is_link
2
+ from comfy_execution.graph import ExecutionBlocker
3
+ from .tools import VariantSupport
4
+
5
+ NUM_FLOW_SOCKETS = 5
6
+ @VariantSupport()
7
+ class TestWhileLoopOpen:
8
+ def __init__(self):
9
+ pass
10
+
11
+ @classmethod
12
+ def INPUT_TYPES(cls):
13
+ inputs = {
14
+ "required": {
15
+ "condition": ("BOOLEAN", {"default": True}),
16
+ },
17
+ "optional": {
18
+ },
19
+ }
20
+ for i in range(NUM_FLOW_SOCKETS):
21
+ inputs["optional"][f"initial_value{i}"] = ("*",)
22
+ return inputs
23
+
24
+ RETURN_TYPES = tuple(["FLOW_CONTROL"] + ["*"] * NUM_FLOW_SOCKETS)
25
+ RETURN_NAMES = tuple(["FLOW_CONTROL"] + [f"value{i}" for i in range(NUM_FLOW_SOCKETS)])
26
+ FUNCTION = "while_loop_open"
27
+
28
+ CATEGORY = "Testing/Flow"
29
+
30
+ def while_loop_open(self, condition, **kwargs):
31
+ values = []
32
+ for i in range(NUM_FLOW_SOCKETS):
33
+ values.append(kwargs.get(f"initial_value{i}", None))
34
+ return tuple(["stub"] + values)
35
+
36
+ @VariantSupport()
37
+ class TestWhileLoopClose:
38
+ def __init__(self):
39
+ pass
40
+
41
+ @classmethod
42
+ def INPUT_TYPES(cls):
43
+ inputs = {
44
+ "required": {
45
+ "flow_control": ("FLOW_CONTROL", {"rawLink": True}),
46
+ "condition": ("BOOLEAN", {"forceInput": True}),
47
+ },
48
+ "optional": {
49
+ },
50
+ "hidden": {
51
+ "dynprompt": "DYNPROMPT",
52
+ "unique_id": "UNIQUE_ID",
53
+ }
54
+ }
55
+ for i in range(NUM_FLOW_SOCKETS):
56
+ inputs["optional"][f"initial_value{i}"] = ("*",)
57
+ return inputs
58
+
59
+ RETURN_TYPES = tuple(["*"] * NUM_FLOW_SOCKETS)
60
+ RETURN_NAMES = tuple([f"value{i}" for i in range(NUM_FLOW_SOCKETS)])
61
+ FUNCTION = "while_loop_close"
62
+
63
+ CATEGORY = "Testing/Flow"
64
+
65
+ def explore_dependencies(self, node_id, dynprompt, upstream):
66
+ node_info = dynprompt.get_node(node_id)
67
+ if "inputs" not in node_info:
68
+ return
69
+ for k, v in node_info["inputs"].items():
70
+ if is_link(v):
71
+ parent_id = v[0]
72
+ if parent_id not in upstream:
73
+ upstream[parent_id] = []
74
+ self.explore_dependencies(parent_id, dynprompt, upstream)
75
+ upstream[parent_id].append(node_id)
76
+
77
+ def collect_contained(self, node_id, upstream, contained):
78
+ if node_id not in upstream:
79
+ return
80
+ for child_id in upstream[node_id]:
81
+ if child_id not in contained:
82
+ contained[child_id] = True
83
+ self.collect_contained(child_id, upstream, contained)
84
+
85
+
86
+ def while_loop_close(self, flow_control, condition, dynprompt=None, unique_id=None, **kwargs):
87
+ assert dynprompt is not None
88
+ if not condition:
89
+ # We're done with the loop
90
+ values = []
91
+ for i in range(NUM_FLOW_SOCKETS):
92
+ values.append(kwargs.get(f"initial_value{i}", None))
93
+ return tuple(values)
94
+
95
+ # We want to loop
96
+ upstream = {}
97
+ # Get the list of all nodes between the open and close nodes
98
+ self.explore_dependencies(unique_id, dynprompt, upstream)
99
+
100
+ contained = {}
101
+ open_node = flow_control[0]
102
+ self.collect_contained(open_node, upstream, contained)
103
+ contained[unique_id] = True
104
+ contained[open_node] = True
105
+
106
+ # We'll use the default prefix, but to avoid having node names grow exponentially in size,
107
+ # we'll use "Recurse" for the name of the recursively-generated copy of this node.
108
+ graph = GraphBuilder()
109
+ for node_id in contained:
110
+ original_node = dynprompt.get_node(node_id)
111
+ node = graph.node(original_node["class_type"], "Recurse" if node_id == unique_id else node_id)
112
+ node.set_override_display_id(node_id)
113
+ for node_id in contained:
114
+ original_node = dynprompt.get_node(node_id)
115
+ node = graph.lookup_node("Recurse" if node_id == unique_id else node_id)
116
+ assert node is not None
117
+ for k, v in original_node["inputs"].items():
118
+ if is_link(v) and v[0] in contained:
119
+ parent = graph.lookup_node(v[0])
120
+ assert parent is not None
121
+ node.set_input(k, parent.out(v[1]))
122
+ else:
123
+ node.set_input(k, v)
124
+ new_open = graph.lookup_node(open_node)
125
+ assert new_open is not None
126
+ for i in range(NUM_FLOW_SOCKETS):
127
+ key = f"initial_value{i}"
128
+ new_open.set_input(key, kwargs.get(key, None))
129
+ my_clone = graph.lookup_node("Recurse")
130
+ assert my_clone is not None
131
+ result = map(lambda x: my_clone.out(x), range(NUM_FLOW_SOCKETS))
132
+ return {
133
+ "result": tuple(result),
134
+ "expand": graph.finalize(),
135
+ }
136
+
137
+ @VariantSupport()
138
+ class TestExecutionBlockerNode:
139
+ def __init__(self):
140
+ pass
141
+
142
+ @classmethod
143
+ def INPUT_TYPES(cls):
144
+ inputs = {
145
+ "required": {
146
+ "input": ("*",),
147
+ "block": ("BOOLEAN",),
148
+ "verbose": ("BOOLEAN", {"default": False}),
149
+ },
150
+ }
151
+ return inputs
152
+
153
+ RETURN_TYPES = ("*",)
154
+ RETURN_NAMES = ("output",)
155
+ FUNCTION = "execution_blocker"
156
+
157
+ CATEGORY = "Testing/Flow"
158
+
159
+ def execution_blocker(self, input, block, verbose):
160
+ if block:
161
+ return (ExecutionBlocker("Blocked Execution" if verbose else None),)
162
+ return (input,)
163
+
164
+ FLOW_CONTROL_NODE_CLASS_MAPPINGS = {
165
+ "TestWhileLoopOpen": TestWhileLoopOpen,
166
+ "TestWhileLoopClose": TestWhileLoopClose,
167
+ "TestExecutionBlocker": TestExecutionBlockerNode,
168
+ }
169
+ FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS = {
170
+ "TestWhileLoopOpen": "While Loop Open",
171
+ "TestWhileLoopClose": "While Loop Close",
172
+ "TestExecutionBlocker": "Execution Blocker",
173
+ }
tests/inference/testing_nodes/testing-pack/specific_tests.py ADDED
@@ -0,0 +1,362 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from .tools import VariantSupport
3
+ from comfy_execution.graph_utils import GraphBuilder
4
+
5
+ class TestLazyMixImages:
6
+ @classmethod
7
+ def INPUT_TYPES(cls):
8
+ return {
9
+ "required": {
10
+ "image1": ("IMAGE",{"lazy": True}),
11
+ "image2": ("IMAGE",{"lazy": True}),
12
+ "mask": ("MASK",),
13
+ },
14
+ }
15
+
16
+ RETURN_TYPES = ("IMAGE",)
17
+ FUNCTION = "mix"
18
+
19
+ CATEGORY = "Testing/Nodes"
20
+
21
+ def check_lazy_status(self, mask, image1, image2):
22
+ mask_min = mask.min()
23
+ mask_max = mask.max()
24
+ needed = []
25
+ if image1 is None and (mask_min != 1.0 or mask_max != 1.0):
26
+ needed.append("image1")
27
+ if image2 is None and (mask_min != 0.0 or mask_max != 0.0):
28
+ needed.append("image2")
29
+ return needed
30
+
31
+ # Not trying to handle different batch sizes here just to keep the demo simple
32
+ def mix(self, mask, image1, image2):
33
+ mask_min = mask.min()
34
+ mask_max = mask.max()
35
+ if mask_min == 0.0 and mask_max == 0.0:
36
+ return (image1,)
37
+ elif mask_min == 1.0 and mask_max == 1.0:
38
+ return (image2,)
39
+
40
+ if len(mask.shape) == 2:
41
+ mask = mask.unsqueeze(0)
42
+ if len(mask.shape) == 3:
43
+ mask = mask.unsqueeze(3)
44
+ if mask.shape[3] < image1.shape[3]:
45
+ mask = mask.repeat(1, 1, 1, image1.shape[3])
46
+
47
+ result = image1 * (1. - mask) + image2 * mask,
48
+ return (result[0],)
49
+
50
+ class TestVariadicAverage:
51
+ @classmethod
52
+ def INPUT_TYPES(cls):
53
+ return {
54
+ "required": {
55
+ "input1": ("IMAGE",),
56
+ },
57
+ }
58
+
59
+ RETURN_TYPES = ("IMAGE",)
60
+ FUNCTION = "variadic_average"
61
+
62
+ CATEGORY = "Testing/Nodes"
63
+
64
+ def variadic_average(self, input1, **kwargs):
65
+ inputs = [input1]
66
+ while 'input' + str(len(inputs) + 1) in kwargs:
67
+ inputs.append(kwargs['input' + str(len(inputs) + 1)])
68
+ return (torch.stack(inputs).mean(dim=0),)
69
+
70
+
71
+ class TestCustomIsChanged:
72
+ @classmethod
73
+ def INPUT_TYPES(cls):
74
+ return {
75
+ "required": {
76
+ "image": ("IMAGE",),
77
+ },
78
+ "optional": {
79
+ "should_change": ("BOOL", {"default": False}),
80
+ },
81
+ }
82
+
83
+ RETURN_TYPES = ("IMAGE",)
84
+ FUNCTION = "custom_is_changed"
85
+
86
+ CATEGORY = "Testing/Nodes"
87
+
88
+ def custom_is_changed(self, image, should_change=False):
89
+ return (image,)
90
+
91
+ @classmethod
92
+ def IS_CHANGED(cls, should_change=False, *args, **kwargs):
93
+ if should_change:
94
+ return float("NaN")
95
+ else:
96
+ return False
97
+
98
+ class TestIsChangedWithConstants:
99
+ @classmethod
100
+ def INPUT_TYPES(cls):
101
+ return {
102
+ "required": {
103
+ "image": ("IMAGE",),
104
+ "value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
105
+ },
106
+ }
107
+
108
+ RETURN_TYPES = ("IMAGE",)
109
+ FUNCTION = "custom_is_changed"
110
+
111
+ CATEGORY = "Testing/Nodes"
112
+
113
+ def custom_is_changed(self, image, value):
114
+ return (image * value,)
115
+
116
+ @classmethod
117
+ def IS_CHANGED(cls, image, value):
118
+ if image is None:
119
+ return value
120
+ else:
121
+ return image.mean().item() * value
122
+
123
+ class TestCustomValidation1:
124
+ @classmethod
125
+ def INPUT_TYPES(cls):
126
+ return {
127
+ "required": {
128
+ "input1": ("IMAGE,FLOAT",),
129
+ "input2": ("IMAGE,FLOAT",),
130
+ },
131
+ }
132
+
133
+ RETURN_TYPES = ("IMAGE",)
134
+ FUNCTION = "custom_validation1"
135
+
136
+ CATEGORY = "Testing/Nodes"
137
+
138
+ def custom_validation1(self, input1, input2):
139
+ if isinstance(input1, float) and isinstance(input2, float):
140
+ result = torch.ones([1, 512, 512, 3]) * input1 * input2
141
+ else:
142
+ result = input1 * input2
143
+ return (result,)
144
+
145
+ @classmethod
146
+ def VALIDATE_INPUTS(cls, input1=None, input2=None):
147
+ if input1 is not None:
148
+ if not isinstance(input1, (torch.Tensor, float)):
149
+ return f"Invalid type of input1: {type(input1)}"
150
+ if input2 is not None:
151
+ if not isinstance(input2, (torch.Tensor, float)):
152
+ return f"Invalid type of input2: {type(input2)}"
153
+
154
+ return True
155
+
156
+ class TestCustomValidation2:
157
+ @classmethod
158
+ def INPUT_TYPES(cls):
159
+ return {
160
+ "required": {
161
+ "input1": ("IMAGE,FLOAT",),
162
+ "input2": ("IMAGE,FLOAT",),
163
+ },
164
+ }
165
+
166
+ RETURN_TYPES = ("IMAGE",)
167
+ FUNCTION = "custom_validation2"
168
+
169
+ CATEGORY = "Testing/Nodes"
170
+
171
+ def custom_validation2(self, input1, input2):
172
+ if isinstance(input1, float) and isinstance(input2, float):
173
+ result = torch.ones([1, 512, 512, 3]) * input1 * input2
174
+ else:
175
+ result = input1 * input2
176
+ return (result,)
177
+
178
+ @classmethod
179
+ def VALIDATE_INPUTS(cls, input_types, input1=None, input2=None):
180
+ if input1 is not None:
181
+ if not isinstance(input1, (torch.Tensor, float)):
182
+ return f"Invalid type of input1: {type(input1)}"
183
+ if input2 is not None:
184
+ if not isinstance(input2, (torch.Tensor, float)):
185
+ return f"Invalid type of input2: {type(input2)}"
186
+
187
+ if 'input1' in input_types:
188
+ if input_types['input1'] not in ["IMAGE", "FLOAT"]:
189
+ return f"Invalid type of input1: {input_types['input1']}"
190
+ if 'input2' in input_types:
191
+ if input_types['input2'] not in ["IMAGE", "FLOAT"]:
192
+ return f"Invalid type of input2: {input_types['input2']}"
193
+
194
+ return True
195
+
196
+ @VariantSupport()
197
+ class TestCustomValidation3:
198
+ @classmethod
199
+ def INPUT_TYPES(cls):
200
+ return {
201
+ "required": {
202
+ "input1": ("IMAGE,FLOAT",),
203
+ "input2": ("IMAGE,FLOAT",),
204
+ },
205
+ }
206
+
207
+ RETURN_TYPES = ("IMAGE",)
208
+ FUNCTION = "custom_validation3"
209
+
210
+ CATEGORY = "Testing/Nodes"
211
+
212
+ def custom_validation3(self, input1, input2):
213
+ if isinstance(input1, float) and isinstance(input2, float):
214
+ result = torch.ones([1, 512, 512, 3]) * input1 * input2
215
+ else:
216
+ result = input1 * input2
217
+ return (result,)
218
+
219
+ class TestCustomValidation4:
220
+ @classmethod
221
+ def INPUT_TYPES(cls):
222
+ return {
223
+ "required": {
224
+ "input1": ("FLOAT",),
225
+ "input2": ("FLOAT",),
226
+ },
227
+ }
228
+
229
+ RETURN_TYPES = ("IMAGE",)
230
+ FUNCTION = "custom_validation4"
231
+
232
+ CATEGORY = "Testing/Nodes"
233
+
234
+ def custom_validation4(self, input1, input2):
235
+ result = torch.ones([1, 512, 512, 3]) * input1 * input2
236
+ return (result,)
237
+
238
+ @classmethod
239
+ def VALIDATE_INPUTS(cls, input1, input2):
240
+ if input1 is not None:
241
+ if not isinstance(input1, float):
242
+ return f"Invalid type of input1: {type(input1)}"
243
+ if input2 is not None:
244
+ if not isinstance(input2, float):
245
+ return f"Invalid type of input2: {type(input2)}"
246
+
247
+ return True
248
+
249
+ class TestCustomValidation5:
250
+ @classmethod
251
+ def INPUT_TYPES(cls):
252
+ return {
253
+ "required": {
254
+ "input1": ("FLOAT", {"min": 0.0, "max": 1.0}),
255
+ "input2": ("FLOAT", {"min": 0.0, "max": 1.0}),
256
+ },
257
+ }
258
+
259
+ RETURN_TYPES = ("IMAGE",)
260
+ FUNCTION = "custom_validation5"
261
+
262
+ CATEGORY = "Testing/Nodes"
263
+
264
+ def custom_validation5(self, input1, input2):
265
+ value = input1 * input2
266
+ return (torch.ones([1, 512, 512, 3]) * value,)
267
+
268
+ @classmethod
269
+ def VALIDATE_INPUTS(cls, **kwargs):
270
+ if kwargs['input2'] == 7.0:
271
+ return "7s are not allowed. I've never liked 7s."
272
+ return True
273
+
274
+ class TestDynamicDependencyCycle:
275
+ @classmethod
276
+ def INPUT_TYPES(cls):
277
+ return {
278
+ "required": {
279
+ "input1": ("IMAGE",),
280
+ "input2": ("IMAGE",),
281
+ },
282
+ }
283
+
284
+ RETURN_TYPES = ("IMAGE",)
285
+ FUNCTION = "dynamic_dependency_cycle"
286
+
287
+ CATEGORY = "Testing/Nodes"
288
+
289
+ def dynamic_dependency_cycle(self, input1, input2):
290
+ g = GraphBuilder()
291
+ mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)
292
+ mix1 = g.node("TestLazyMixImages", image1=input1, mask=mask.out(0))
293
+ mix2 = g.node("TestLazyMixImages", image1=mix1.out(0), image2=input2, mask=mask.out(0))
294
+
295
+ # Create the cyle
296
+ mix1.set_input("image2", mix2.out(0))
297
+
298
+ return {
299
+ "result": (mix2.out(0),),
300
+ "expand": g.finalize(),
301
+ }
302
+
303
+ class TestMixedExpansionReturns:
304
+ @classmethod
305
+ def INPUT_TYPES(cls):
306
+ return {
307
+ "required": {
308
+ "input1": ("FLOAT",),
309
+ },
310
+ }
311
+
312
+ RETURN_TYPES = ("IMAGE","IMAGE")
313
+ FUNCTION = "mixed_expansion_returns"
314
+
315
+ CATEGORY = "Testing/Nodes"
316
+
317
+ def mixed_expansion_returns(self, input1):
318
+ white_image = torch.ones([1, 512, 512, 3])
319
+ if input1 <= 0.1:
320
+ return (torch.ones([1, 512, 512, 3]) * 0.1, white_image)
321
+ elif input1 <= 0.2:
322
+ return {
323
+ "result": (torch.ones([1, 512, 512, 3]) * 0.2, white_image),
324
+ }
325
+ else:
326
+ g = GraphBuilder()
327
+ mask = g.node("StubMask", value=0.3, height=512, width=512, batch_size=1)
328
+ black = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
329
+ white = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
330
+ mix = g.node("TestLazyMixImages", image1=black.out(0), image2=white.out(0), mask=mask.out(0))
331
+ return {
332
+ "result": (mix.out(0), white_image),
333
+ "expand": g.finalize(),
334
+ }
335
+
336
+ TEST_NODE_CLASS_MAPPINGS = {
337
+ "TestLazyMixImages": TestLazyMixImages,
338
+ "TestVariadicAverage": TestVariadicAverage,
339
+ "TestCustomIsChanged": TestCustomIsChanged,
340
+ "TestIsChangedWithConstants": TestIsChangedWithConstants,
341
+ "TestCustomValidation1": TestCustomValidation1,
342
+ "TestCustomValidation2": TestCustomValidation2,
343
+ "TestCustomValidation3": TestCustomValidation3,
344
+ "TestCustomValidation4": TestCustomValidation4,
345
+ "TestCustomValidation5": TestCustomValidation5,
346
+ "TestDynamicDependencyCycle": TestDynamicDependencyCycle,
347
+ "TestMixedExpansionReturns": TestMixedExpansionReturns,
348
+ }
349
+
350
+ TEST_NODE_DISPLAY_NAME_MAPPINGS = {
351
+ "TestLazyMixImages": "Lazy Mix Images",
352
+ "TestVariadicAverage": "Variadic Average",
353
+ "TestCustomIsChanged": "Custom IsChanged",
354
+ "TestIsChangedWithConstants": "IsChanged With Constants",
355
+ "TestCustomValidation1": "Custom Validation 1",
356
+ "TestCustomValidation2": "Custom Validation 2",
357
+ "TestCustomValidation3": "Custom Validation 3",
358
+ "TestCustomValidation4": "Custom Validation 4",
359
+ "TestCustomValidation5": "Custom Validation 5",
360
+ "TestDynamicDependencyCycle": "Dynamic Dependency Cycle",
361
+ "TestMixedExpansionReturns": "Mixed Expansion Returns",
362
+ }
tests/inference/testing_nodes/testing-pack/stubs.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ class StubImage:
4
+ def __init__(self):
5
+ pass
6
+
7
+ @classmethod
8
+ def INPUT_TYPES(cls):
9
+ return {
10
+ "required": {
11
+ "content": (['WHITE', 'BLACK', 'NOISE'],),
12
+ "height": ("INT", {"default": 512, "min": 1, "max": 1024 ** 3, "step": 1}),
13
+ "width": ("INT", {"default": 512, "min": 1, "max": 4096 ** 3, "step": 1}),
14
+ "batch_size": ("INT", {"default": 1, "min": 1, "max": 1024 ** 3, "step": 1}),
15
+ },
16
+ }
17
+
18
+ RETURN_TYPES = ("IMAGE",)
19
+ FUNCTION = "stub_image"
20
+
21
+ CATEGORY = "Testing/Stub Nodes"
22
+
23
+ def stub_image(self, content, height, width, batch_size):
24
+ if content == "WHITE":
25
+ return (torch.ones(batch_size, height, width, 3),)
26
+ elif content == "BLACK":
27
+ return (torch.zeros(batch_size, height, width, 3),)
28
+ elif content == "NOISE":
29
+ return (torch.rand(batch_size, height, width, 3),)
30
+
31
+ class StubConstantImage:
32
+ def __init__(self):
33
+ pass
34
+ @classmethod
35
+ def INPUT_TYPES(cls):
36
+ return {
37
+ "required": {
38
+ "value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
39
+ "height": ("INT", {"default": 512, "min": 1, "max": 1024 ** 3, "step": 1}),
40
+ "width": ("INT", {"default": 512, "min": 1, "max": 4096 ** 3, "step": 1}),
41
+ "batch_size": ("INT", {"default": 1, "min": 1, "max": 1024 ** 3, "step": 1}),
42
+ },
43
+ }
44
+
45
+ RETURN_TYPES = ("IMAGE",)
46
+ FUNCTION = "stub_constant_image"
47
+
48
+ CATEGORY = "Testing/Stub Nodes"
49
+
50
+ def stub_constant_image(self, value, height, width, batch_size):
51
+ return (torch.ones(batch_size, height, width, 3) * value,)
52
+
53
+ class StubMask:
54
+ def __init__(self):
55
+ pass
56
+
57
+ @classmethod
58
+ def INPUT_TYPES(cls):
59
+ return {
60
+ "required": {
61
+ "value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
62
+ "height": ("INT", {"default": 512, "min": 1, "max": 1024 ** 3, "step": 1}),
63
+ "width": ("INT", {"default": 512, "min": 1, "max": 4096 ** 3, "step": 1}),
64
+ "batch_size": ("INT", {"default": 1, "min": 1, "max": 1024 ** 3, "step": 1}),
65
+ },
66
+ }
67
+
68
+ RETURN_TYPES = ("MASK",)
69
+ FUNCTION = "stub_mask"
70
+
71
+ CATEGORY = "Testing/Stub Nodes"
72
+
73
+ def stub_mask(self, value, height, width, batch_size):
74
+ return (torch.ones(batch_size, height, width) * value,)
75
+
76
+ class StubInt:
77
+ def __init__(self):
78
+ pass
79
+
80
+ @classmethod
81
+ def INPUT_TYPES(cls):
82
+ return {
83
+ "required": {
84
+ "value": ("INT", {"default": 0, "min": -0xffffffff, "max": 0xffffffff, "step": 1}),
85
+ },
86
+ }
87
+
88
+ RETURN_TYPES = ("INT",)
89
+ FUNCTION = "stub_int"
90
+
91
+ CATEGORY = "Testing/Stub Nodes"
92
+
93
+ def stub_int(self, value):
94
+ return (value,)
95
+
96
+ class StubFloat:
97
+ def __init__(self):
98
+ pass
99
+
100
+ @classmethod
101
+ def INPUT_TYPES(cls):
102
+ return {
103
+ "required": {
104
+ "value": ("FLOAT", {"default": 0.0, "min": -1.0e38, "max": 1.0e38, "step": 0.01}),
105
+ },
106
+ }
107
+
108
+ RETURN_TYPES = ("FLOAT",)
109
+ FUNCTION = "stub_float"
110
+
111
+ CATEGORY = "Testing/Stub Nodes"
112
+
113
+ def stub_float(self, value):
114
+ return (value,)
115
+
116
+ TEST_STUB_NODE_CLASS_MAPPINGS = {
117
+ "StubImage": StubImage,
118
+ "StubConstantImage": StubConstantImage,
119
+ "StubMask": StubMask,
120
+ "StubInt": StubInt,
121
+ "StubFloat": StubFloat,
122
+ }
123
+ TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = {
124
+ "StubImage": "Stub Image",
125
+ "StubConstantImage": "Stub Constant Image",
126
+ "StubMask": "Stub Mask",
127
+ "StubInt": "Stub Int",
128
+ "StubFloat": "Stub Float",
129
+ }
tests/inference/testing_nodes/testing-pack/tools.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ def MakeSmartType(t):
3
+ if isinstance(t, str):
4
+ return SmartType(t)
5
+ return t
6
+
7
+ class SmartType(str):
8
+ def __ne__(self, other):
9
+ if self == "*" or other == "*":
10
+ return False
11
+ selfset = set(self.split(','))
12
+ otherset = set(other.split(','))
13
+ return not selfset.issubset(otherset)
14
+
15
+ def VariantSupport():
16
+ def decorator(cls):
17
+ if hasattr(cls, "INPUT_TYPES"):
18
+ old_input_types = getattr(cls, "INPUT_TYPES")
19
+ def new_input_types(*args, **kwargs):
20
+ types = old_input_types(*args, **kwargs)
21
+ for category in ["required", "optional"]:
22
+ if category not in types:
23
+ continue
24
+ for key, value in types[category].items():
25
+ if isinstance(value, tuple):
26
+ types[category][key] = (MakeSmartType(value[0]),) + value[1:]
27
+ return types
28
+ setattr(cls, "INPUT_TYPES", new_input_types)
29
+ if hasattr(cls, "RETURN_TYPES"):
30
+ old_return_types = cls.RETURN_TYPES
31
+ setattr(cls, "RETURN_TYPES", tuple(MakeSmartType(x) for x in old_return_types))
32
+ if hasattr(cls, "VALIDATE_INPUTS"):
33
+ # Reflection is used to determine what the function signature is, so we can't just change the function signature
34
+ raise NotImplementedError("VariantSupport does not support VALIDATE_INPUTS yet")
35
+ else:
36
+ def validate_inputs(input_types):
37
+ inputs = cls.INPUT_TYPES()
38
+ for key, value in input_types.items():
39
+ if isinstance(value, SmartType):
40
+ continue
41
+ if "required" in inputs and key in inputs["required"]:
42
+ expected_type = inputs["required"][key][0]
43
+ elif "optional" in inputs and key in inputs["optional"]:
44
+ expected_type = inputs["optional"][key][0]
45
+ else:
46
+ expected_type = None
47
+ if expected_type is not None and MakeSmartType(value) != expected_type:
48
+ return f"Invalid type of {key}: {value} (expected {expected_type})"
49
+ return True
50
+ setattr(cls, "VALIDATE_INPUTS", validate_inputs)
51
+ return cls
52
+ return decorator
53
+
tests/inference/testing_nodes/testing-pack/util.py ADDED
@@ -0,0 +1,364 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from comfy_execution.graph_utils import GraphBuilder
2
+ from .tools import VariantSupport
3
+
4
+ @VariantSupport()
5
+ class TestAccumulateNode:
6
+ def __init__(self):
7
+ pass
8
+
9
+ @classmethod
10
+ def INPUT_TYPES(cls):
11
+ return {
12
+ "required": {
13
+ "to_add": ("*",),
14
+ },
15
+ "optional": {
16
+ "accumulation": ("ACCUMULATION",),
17
+ },
18
+ }
19
+
20
+ RETURN_TYPES = ("ACCUMULATION",)
21
+ FUNCTION = "accumulate"
22
+
23
+ CATEGORY = "Testing/Lists"
24
+
25
+ def accumulate(self, to_add, accumulation = None):
26
+ if accumulation is None:
27
+ value = [to_add]
28
+ else:
29
+ value = accumulation["accum"] + [to_add]
30
+ return ({"accum": value},)
31
+
32
+ @VariantSupport()
33
+ class TestAccumulationHeadNode:
34
+ def __init__(self):
35
+ pass
36
+
37
+ @classmethod
38
+ def INPUT_TYPES(cls):
39
+ return {
40
+ "required": {
41
+ "accumulation": ("ACCUMULATION",),
42
+ },
43
+ }
44
+
45
+ RETURN_TYPES = ("ACCUMULATION", "*",)
46
+ FUNCTION = "accumulation_head"
47
+
48
+ CATEGORY = "Testing/Lists"
49
+
50
+ def accumulation_head(self, accumulation):
51
+ accum = accumulation["accum"]
52
+ if len(accum) == 0:
53
+ return (accumulation, None)
54
+ else:
55
+ return ({"accum": accum[1:]}, accum[0])
56
+
57
+ class TestAccumulationTailNode:
58
+ def __init__(self):
59
+ pass
60
+
61
+ @classmethod
62
+ def INPUT_TYPES(cls):
63
+ return {
64
+ "required": {
65
+ "accumulation": ("ACCUMULATION",),
66
+ },
67
+ }
68
+
69
+ RETURN_TYPES = ("ACCUMULATION", "*",)
70
+ FUNCTION = "accumulation_tail"
71
+
72
+ CATEGORY = "Testing/Lists"
73
+
74
+ def accumulation_tail(self, accumulation):
75
+ accum = accumulation["accum"]
76
+ if len(accum) == 0:
77
+ return (None, accumulation)
78
+ else:
79
+ return ({"accum": accum[:-1]}, accum[-1])
80
+
81
+ @VariantSupport()
82
+ class TestAccumulationToListNode:
83
+ def __init__(self):
84
+ pass
85
+
86
+ @classmethod
87
+ def INPUT_TYPES(cls):
88
+ return {
89
+ "required": {
90
+ "accumulation": ("ACCUMULATION",),
91
+ },
92
+ }
93
+
94
+ RETURN_TYPES = ("*",)
95
+ OUTPUT_IS_LIST = (True,)
96
+
97
+ FUNCTION = "accumulation_to_list"
98
+
99
+ CATEGORY = "Testing/Lists"
100
+
101
+ def accumulation_to_list(self, accumulation):
102
+ return (accumulation["accum"],)
103
+
104
+ @VariantSupport()
105
+ class TestListToAccumulationNode:
106
+ def __init__(self):
107
+ pass
108
+
109
+ @classmethod
110
+ def INPUT_TYPES(cls):
111
+ return {
112
+ "required": {
113
+ "list": ("*",),
114
+ },
115
+ }
116
+
117
+ RETURN_TYPES = ("ACCUMULATION",)
118
+ INPUT_IS_LIST = (True,)
119
+
120
+ FUNCTION = "list_to_accumulation"
121
+
122
+ CATEGORY = "Testing/Lists"
123
+
124
+ def list_to_accumulation(self, list):
125
+ return ({"accum": list},)
126
+
127
+ @VariantSupport()
128
+ class TestAccumulationGetLengthNode:
129
+ def __init__(self):
130
+ pass
131
+
132
+ @classmethod
133
+ def INPUT_TYPES(cls):
134
+ return {
135
+ "required": {
136
+ "accumulation": ("ACCUMULATION",),
137
+ },
138
+ }
139
+
140
+ RETURN_TYPES = ("INT",)
141
+
142
+ FUNCTION = "accumlength"
143
+
144
+ CATEGORY = "Testing/Lists"
145
+
146
+ def accumlength(self, accumulation):
147
+ return (len(accumulation['accum']),)
148
+
149
+ @VariantSupport()
150
+ class TestAccumulationGetItemNode:
151
+ def __init__(self):
152
+ pass
153
+
154
+ @classmethod
155
+ def INPUT_TYPES(cls):
156
+ return {
157
+ "required": {
158
+ "accumulation": ("ACCUMULATION",),
159
+ "index": ("INT", {"default":0, "step":1})
160
+ },
161
+ }
162
+
163
+ RETURN_TYPES = ("*",)
164
+
165
+ FUNCTION = "get_item"
166
+
167
+ CATEGORY = "Testing/Lists"
168
+
169
+ def get_item(self, accumulation, index):
170
+ return (accumulation['accum'][index],)
171
+
172
+ @VariantSupport()
173
+ class TestAccumulationSetItemNode:
174
+ def __init__(self):
175
+ pass
176
+
177
+ @classmethod
178
+ def INPUT_TYPES(cls):
179
+ return {
180
+ "required": {
181
+ "accumulation": ("ACCUMULATION",),
182
+ "index": ("INT", {"default":0, "step":1}),
183
+ "value": ("*",),
184
+ },
185
+ }
186
+
187
+ RETURN_TYPES = ("ACCUMULATION",)
188
+
189
+ FUNCTION = "set_item"
190
+
191
+ CATEGORY = "Testing/Lists"
192
+
193
+ def set_item(self, accumulation, index, value):
194
+ new_accum = accumulation['accum'][:]
195
+ new_accum[index] = value
196
+ return ({"accum": new_accum},)
197
+
198
+ class TestIntMathOperation:
199
+ def __init__(self):
200
+ pass
201
+
202
+ @classmethod
203
+ def INPUT_TYPES(cls):
204
+ return {
205
+ "required": {
206
+ "a": ("INT", {"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1}),
207
+ "b": ("INT", {"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1}),
208
+ "operation": (["add", "subtract", "multiply", "divide", "modulo", "power"],),
209
+ },
210
+ }
211
+
212
+ RETURN_TYPES = ("INT",)
213
+ FUNCTION = "int_math_operation"
214
+
215
+ CATEGORY = "Testing/Logic"
216
+
217
+ def int_math_operation(self, a, b, operation):
218
+ if operation == "add":
219
+ return (a + b,)
220
+ elif operation == "subtract":
221
+ return (a - b,)
222
+ elif operation == "multiply":
223
+ return (a * b,)
224
+ elif operation == "divide":
225
+ return (a // b,)
226
+ elif operation == "modulo":
227
+ return (a % b,)
228
+ elif operation == "power":
229
+ return (a ** b,)
230
+
231
+
232
+ from .flow_control import NUM_FLOW_SOCKETS
233
+ @VariantSupport()
234
+ class TestForLoopOpen:
235
+ def __init__(self):
236
+ pass
237
+
238
+ @classmethod
239
+ def INPUT_TYPES(cls):
240
+ return {
241
+ "required": {
242
+ "remaining": ("INT", {"default": 1, "min": 0, "max": 100000, "step": 1}),
243
+ },
244
+ "optional": {
245
+ f"initial_value{i}": ("*",) for i in range(1, NUM_FLOW_SOCKETS)
246
+ },
247
+ "hidden": {
248
+ "initial_value0": ("*",)
249
+ }
250
+ }
251
+
252
+ RETURN_TYPES = tuple(["FLOW_CONTROL", "INT",] + ["*"] * (NUM_FLOW_SOCKETS-1))
253
+ RETURN_NAMES = tuple(["flow_control", "remaining"] + [f"value{i}" for i in range(1, NUM_FLOW_SOCKETS)])
254
+ FUNCTION = "for_loop_open"
255
+
256
+ CATEGORY = "Testing/Flow"
257
+
258
+ def for_loop_open(self, remaining, **kwargs):
259
+ graph = GraphBuilder()
260
+ if "initial_value0" in kwargs:
261
+ remaining = kwargs["initial_value0"]
262
+ graph.node("TestWhileLoopOpen", condition=remaining, initial_value0=remaining, **{(f"initial_value{i}"): kwargs.get(f"initial_value{i}", None) for i in range(1, NUM_FLOW_SOCKETS)})
263
+ outputs = [kwargs.get(f"initial_value{i}", None) for i in range(1, NUM_FLOW_SOCKETS)]
264
+ return {
265
+ "result": tuple(["stub", remaining] + outputs),
266
+ "expand": graph.finalize(),
267
+ }
268
+
269
+ @VariantSupport()
270
+ class TestForLoopClose:
271
+ def __init__(self):
272
+ pass
273
+
274
+ @classmethod
275
+ def INPUT_TYPES(cls):
276
+ return {
277
+ "required": {
278
+ "flow_control": ("FLOW_CONTROL", {"rawLink": True}),
279
+ },
280
+ "optional": {
281
+ f"initial_value{i}": ("*",{"rawLink": True}) for i in range(1, NUM_FLOW_SOCKETS)
282
+ },
283
+ }
284
+
285
+ RETURN_TYPES = tuple(["*"] * (NUM_FLOW_SOCKETS-1))
286
+ RETURN_NAMES = tuple([f"value{i}" for i in range(1, NUM_FLOW_SOCKETS)])
287
+ FUNCTION = "for_loop_close"
288
+
289
+ CATEGORY = "Testing/Flow"
290
+
291
+ def for_loop_close(self, flow_control, **kwargs):
292
+ graph = GraphBuilder()
293
+ while_open = flow_control[0]
294
+ sub = graph.node("TestIntMathOperation", operation="subtract", a=[while_open,1], b=1)
295
+ cond = graph.node("TestToBoolNode", value=sub.out(0))
296
+ input_values = {f"initial_value{i}": kwargs.get(f"initial_value{i}", None) for i in range(1, NUM_FLOW_SOCKETS)}
297
+ while_close = graph.node("TestWhileLoopClose",
298
+ flow_control=flow_control,
299
+ condition=cond.out(0),
300
+ initial_value0=sub.out(0),
301
+ **input_values)
302
+ return {
303
+ "result": tuple([while_close.out(i) for i in range(1, NUM_FLOW_SOCKETS)]),
304
+ "expand": graph.finalize(),
305
+ }
306
+
307
+ NUM_LIST_SOCKETS = 10
308
+ @VariantSupport()
309
+ class TestMakeListNode:
310
+ def __init__(self):
311
+ pass
312
+
313
+ @classmethod
314
+ def INPUT_TYPES(cls):
315
+ return {
316
+ "required": {
317
+ "value1": ("*",),
318
+ },
319
+ "optional": {
320
+ f"value{i}": ("*",) for i in range(1, NUM_LIST_SOCKETS)
321
+ },
322
+ }
323
+
324
+ RETURN_TYPES = ("*",)
325
+ FUNCTION = "make_list"
326
+ OUTPUT_IS_LIST = (True,)
327
+
328
+ CATEGORY = "Testing/Lists"
329
+
330
+ def make_list(self, **kwargs):
331
+ result = []
332
+ for i in range(NUM_LIST_SOCKETS):
333
+ if f"value{i}" in kwargs:
334
+ result.append(kwargs[f"value{i}"])
335
+ return (result,)
336
+
337
+ UTILITY_NODE_CLASS_MAPPINGS = {
338
+ "TestAccumulateNode": TestAccumulateNode,
339
+ "TestAccumulationHeadNode": TestAccumulationHeadNode,
340
+ "TestAccumulationTailNode": TestAccumulationTailNode,
341
+ "TestAccumulationToListNode": TestAccumulationToListNode,
342
+ "TestListToAccumulationNode": TestListToAccumulationNode,
343
+ "TestAccumulationGetLengthNode": TestAccumulationGetLengthNode,
344
+ "TestAccumulationGetItemNode": TestAccumulationGetItemNode,
345
+ "TestAccumulationSetItemNode": TestAccumulationSetItemNode,
346
+ "TestForLoopOpen": TestForLoopOpen,
347
+ "TestForLoopClose": TestForLoopClose,
348
+ "TestIntMathOperation": TestIntMathOperation,
349
+ "TestMakeListNode": TestMakeListNode,
350
+ }
351
+ UTILITY_NODE_DISPLAY_NAME_MAPPINGS = {
352
+ "TestAccumulateNode": "Accumulate",
353
+ "TestAccumulationHeadNode": "Accumulation Head",
354
+ "TestAccumulationTailNode": "Accumulation Tail",
355
+ "TestAccumulationToListNode": "Accumulation to List",
356
+ "TestListToAccumulationNode": "List to Accumulation",
357
+ "TestAccumulationGetLengthNode": "Accumulation Get Length",
358
+ "TestAccumulationGetItemNode": "Accumulation Get Item",
359
+ "TestAccumulationSetItemNode": "Accumulation Set Item",
360
+ "TestForLoopOpen": "For Loop Open",
361
+ "TestForLoopClose": "For Loop Close",
362
+ "TestIntMathOperation": "Int Math Operation",
363
+ "TestMakeListNode": "Make List",
364
+ }