Upload 42 files
- .ci/update_windows/update.py +146 -0
- .ci/update_windows/update_comfyui.bat +8 -0
- .ci/update_windows/update_comfyui_stable.bat +8 -0
- .ci/windows_base_files/README_VERY_IMPORTANT.txt +31 -0
- .ci/windows_base_files/run_cpu.bat +2 -0
- .ci/windows_base_files/run_nvidia_gpu.bat +2 -0
- .ci/windows_nightly_base_files/run_nvidia_gpu_fast.bat +2 -0
- CODEOWNERS +23 -0
- CONTRIBUTING.md +41 -0
- api_server/__init__.py +0 -0
- api_server/routes/__init__.py +0 -0
- api_server/routes/internal/README.md +3 -0
- api_server/routes/internal/__init__.py +0 -0
- api_server/routes/internal/internal_routes.py +75 -0
- api_server/services/__init__.py +0 -0
- api_server/services/file_service.py +13 -0
- api_server/services/terminal_service.py +60 -0
- api_server/utils/file_operations.py +42 -0
- app.py +421 -0
- comfy_execution/caching.py +318 -0
- comfy_execution/graph.py +270 -0
- comfy_execution/graph_utils.py +139 -0
- comfy_execution/validation.py +39 -0
- comfyui_version.py +3 -0
- cuda_malloc.py +90 -0
- extra_model_paths.yaml.example +47 -0
- fix_torch.py +28 -0
- folder_paths.py +385 -0
- latent_preview.py +105 -0
- main.py +301 -0
- new_updater.py +35 -0
- node_helpers.py +37 -0
- notebooks/comfyui_colab.ipynb +322 -0
- output/_output_images_will_be_put_here +0 -0
- pyproject.toml +23 -0
- pytest.ini +9 -0
- requirements.txt +29 -0
- script_examples/basic_api_example.py +119 -0
- script_examples/websockets_api_example.py +166 -0
- script_examples/websockets_api_example_ws_images.py +159 -0
- utils/__init__.py +0 -0
- utils/extra_config.py +33 -0
.ci/update_windows/update.py
ADDED
@@ -0,0 +1,146 @@
import pygit2
from datetime import datetime
import sys
import os
import shutil
import filecmp

def pull(repo, remote_name='origin', branch='master'):
    for remote in repo.remotes:
        if remote.name == remote_name:
            remote.fetch()
            remote_master_id = repo.lookup_reference('refs/remotes/origin/%s' % (branch)).target
            merge_result, _ = repo.merge_analysis(remote_master_id)
            # Up to date, do nothing
            if merge_result & pygit2.GIT_MERGE_ANALYSIS_UP_TO_DATE:
                return
            # We can just fastforward
            elif merge_result & pygit2.GIT_MERGE_ANALYSIS_FASTFORWARD:
                repo.checkout_tree(repo.get(remote_master_id))
                try:
                    master_ref = repo.lookup_reference('refs/heads/%s' % (branch))
                    master_ref.set_target(remote_master_id)
                except KeyError:
                    repo.create_branch(branch, repo.get(remote_master_id))
                repo.head.set_target(remote_master_id)
            elif merge_result & pygit2.GIT_MERGE_ANALYSIS_NORMAL:
                repo.merge(remote_master_id)

                if repo.index.conflicts is not None:
                    for conflict in repo.index.conflicts:
                        print('Conflicts found in:', conflict[0].path) # noqa: T201
                    raise AssertionError('Conflicts, ahhhhh!!')

                user = repo.default_signature
                tree = repo.index.write_tree()
                repo.create_commit('HEAD',
                                   user,
                                   user,
                                   'Merge!',
                                   tree,
                                   [repo.head.target, remote_master_id])
                # We need to do this or git CLI will think we are still merging.
                repo.state_cleanup()
            else:
                raise AssertionError('Unknown merge analysis result')

pygit2.option(pygit2.GIT_OPT_SET_OWNER_VALIDATION, 0)
repo_path = str(sys.argv[1])
repo = pygit2.Repository(repo_path)
ident = pygit2.Signature('comfyui', 'comfy@ui')
try:
    print("stashing current changes") # noqa: T201
    repo.stash(ident)
except KeyError:
    print("nothing to stash") # noqa: T201
backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S'))
print("creating backup branch: {}".format(backup_branch_name)) # noqa: T201
try:
    repo.branches.local.create(backup_branch_name, repo.head.peel())
except:
    pass

print("checking out master branch") # noqa: T201
branch = repo.lookup_branch('master')
if branch is None:
    ref = repo.lookup_reference('refs/remotes/origin/master')
    repo.checkout(ref)
    branch = repo.lookup_branch('master')
    if branch is None:
        repo.create_branch('master', repo.get(ref.target))
else:
    ref = repo.lookup_reference(branch.name)
    repo.checkout(ref)

print("pulling latest changes") # noqa: T201
pull(repo)

if "--stable" in sys.argv:
    def latest_tag(repo):
        versions = []
        for k in repo.references:
            try:
                prefix = "refs/tags/v"
                if k.startswith(prefix):
                    version = list(map(int, k[len(prefix):].split(".")))
                    versions.append((version[0] * 10000000000 + version[1] * 100000 + version[2], k))
            except:
                pass
        versions.sort()
        if len(versions) > 0:
            return versions[-1][1]
        return None
    latest_tag = latest_tag(repo)
    if latest_tag is not None:
        repo.checkout(latest_tag)

print("Done!") # noqa: T201

self_update = True
if len(sys.argv) > 2:
    self_update = '--skip_self_update' not in sys.argv

update_py_path = os.path.realpath(__file__)
repo_update_py_path = os.path.join(repo_path, ".ci/update_windows/update.py")

cur_path = os.path.dirname(update_py_path)


req_path = os.path.join(cur_path, "current_requirements.txt")
repo_req_path = os.path.join(repo_path, "requirements.txt")


def files_equal(file1, file2):
    try:
        return filecmp.cmp(file1, file2, shallow=False)
    except:
        return False

def file_size(f):
    try:
        return os.path.getsize(f)
    except:
        return 0


if self_update and not files_equal(update_py_path, repo_update_py_path) and file_size(repo_update_py_path) > 10:
    shutil.copy(repo_update_py_path, os.path.join(cur_path, "update_new.py"))
    exit()

if not os.path.exists(req_path) or not files_equal(repo_req_path, req_path):
    import subprocess
    try:
        subprocess.check_call([sys.executable, '-s', '-m', 'pip', 'install', '-r', repo_req_path])
        shutil.copy(repo_req_path, req_path)
    except:
        pass


stable_update_script = os.path.join(repo_path, ".ci/update_windows/update_comfyui_stable.bat")
stable_update_script_to = os.path.join(cur_path, "update_comfyui_stable.bat")

try:
    if not file_size(stable_update_script_to) > 10:
        shutil.copy(stable_update_script, stable_update_script_to)
except:
    pass
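
The `--stable` path above ranks release tags by packing major/minor/patch into one integer (major * 10**10 + minor * 10**5 + patch), so a plain numeric sort puts the newest tag last. A minimal sketch of the same scheme; `pack_version` is a hypothetical helper, not part of the updater itself:

# Sketch of the tag-ranking scheme used by latest_tag above.
def pack_version(tag_ref: str) -> int:
    prefix = "refs/tags/v"
    major, minor, patch = map(int, tag_ref[len(prefix):].split("."))
    return major * 10**10 + minor * 10**5 + patch

refs = ["refs/tags/v0.2.10", "refs/tags/v0.10.2", "refs/tags/v1.0.0"]
print(max(refs, key=pack_version))  # refs/tags/v1.0.0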
.ci/update_windows/update_comfyui.bat
ADDED
@@ -0,0 +1,8 @@
@echo off
..\python_embeded\python.exe .\update.py ..\ComfyUI\
if exist update_new.py (
  move /y update_new.py update.py
  echo Running updater again since it got updated.
  ..\python_embeded\python.exe .\update.py ..\ComfyUI\ --skip_self_update
)
if "%~1"=="" pause
.ci/update_windows/update_comfyui_stable.bat
ADDED
@@ -0,0 +1,8 @@
@echo off
..\python_embeded\python.exe .\update.py ..\ComfyUI\ --stable
if exist update_new.py (
  move /y update_new.py update.py
  echo Running updater again since it got updated.
  ..\python_embeded\python.exe .\update.py ..\ComfyUI\ --skip_self_update --stable
)
if "%~1"=="" pause
.ci/windows_base_files/README_VERY_IMPORTANT.txt
ADDED
@@ -0,0 +1,31 @@
HOW TO RUN:

if you have a NVIDIA gpu:

run_nvidia_gpu.bat



To run it in slow CPU mode:

run_cpu.bat



IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints

You can download the stable diffusion 1.5 one from: https://huggingface.co/Comfy-Org/stable-diffusion-v1-5-archive/blob/main/v1-5-pruned-emaonly-fp16.safetensors


RECOMMENDED WAY TO UPDATE:
To update the ComfyUI code: update\update_comfyui.bat



To update ComfyUI with the python dependencies, note that you should ONLY run this if you have issues with python dependencies.
update\update_comfyui_and_python_dependencies.bat


TO SHARE MODELS BETWEEN COMFYUI AND ANOTHER UI:
In the ComfyUI directory you will find a file: extra_model_paths.yaml.example
Rename this file to: extra_model_paths.yaml and edit it with your favorite text editor.
.ci/windows_base_files/run_cpu.bat
ADDED
@@ -0,0 +1,2 @@
.\python_embeded\python.exe -s ComfyUI\main.py --cpu --windows-standalone-build
pause
.ci/windows_base_files/run_nvidia_gpu.bat
ADDED
@@ -0,0 +1,2 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
pause
.ci/windows_nightly_base_files/run_nvidia_gpu_fast.bat
ADDED
@@ -0,0 +1,2 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast
pause
CODEOWNERS
ADDED
@@ -0,0 +1,23 @@
# Admins
* @comfyanonymous

# Note: Github teams syntax cannot be used here as the repo is not owned by Comfy-Org.
# Inlined the team members for now.

# Maintainers
*.md @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/tests/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/tests-unit/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/.github/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink

# Python web server
/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata

# Frontend assets
/web/ @huchenlei @webfiltered @pythongosssss @yoland68 @robinjhuang

# Extra nodes
/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink
CONTRIBUTING.md
ADDED
@@ -0,0 +1,41 @@
# Contributing to ComfyUI

Welcome, and thank you for your interest in contributing to ComfyUI!

There are several ways in which you can contribute, beyond writing code. The goal of this document is to provide a high-level overview of how you can get involved.

## Asking Questions

Have a question? Instead of opening an issue, please ask on [Discord](https://comfy.org/discord) or [Matrix](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) channels. Our team and the community will help you.

## Providing Feedback

Your comments and feedback are welcome, and the development team is available via a handful of different channels.

See the `#bug-report`, `#feature-request` and `#feedback` channels on Discord.

## Reporting Issues

Have you identified a reproducible problem in ComfyUI? Do you have a feature request? We want to hear about it! Here's how you can report your issue as effectively as possible.


### Look For an Existing Issue

Before you create a new issue, please do a search in [open issues](https://github.com/comfyanonymous/ComfyUI/issues) to see if the issue or feature request has already been filed.

If you find your issue already exists, make relevant comments and add your [reaction](https://github.com/blog/2119-add-reactions-to-pull-requests-issues-and-comments). Use a reaction in place of a "+1" comment:

* 👍 - upvote
* 👎 - downvote

If you cannot find an existing issue that describes your bug or feature, create a new issue. We have an issue template in place to organize new issues.


### Creating Pull Requests

* Please refer to the article on [creating pull requests](https://github.com/comfyanonymous/ComfyUI/wiki/How-to-Contribute-Code) and contributing to this project.


## Thank You

Your contributions to open source, large or small, make great projects like this possible. Thank you for taking the time to contribute.
api_server/__init__.py
ADDED
File without changes
api_server/routes/__init__.py
ADDED
File without changes
api_server/routes/internal/README.md
ADDED
@@ -0,0 +1,3 @@
# ComfyUI Internal Routes

All routes under the `/internal` path are designated for **internal use by ComfyUI only**. These routes are not intended for use by external applications and may change at any time without notice.
api_server/routes/internal/__init__.py
ADDED
File without changes
api_server/routes/internal/internal_routes.py
ADDED
@@ -0,0 +1,75 @@
from aiohttp import web
from typing import Optional
from folder_paths import models_dir, user_directory, output_directory, folder_names_and_paths
from api_server.services.file_service import FileService
from api_server.services.terminal_service import TerminalService
import app.logger

class InternalRoutes:
    '''
    The top level web router for internal routes: /internal/*
    The endpoints here should NOT be depended upon. It is for ComfyUI frontend use only.
    Check README.md for more information.
    '''

    def __init__(self, prompt_server):
        self.routes: web.RouteTableDef = web.RouteTableDef()
        self._app: Optional[web.Application] = None
        self.file_service = FileService({
            "models": models_dir,
            "user": user_directory,
            "output": output_directory
        })
        self.prompt_server = prompt_server
        self.terminal_service = TerminalService(prompt_server)

    def setup_routes(self):
        @self.routes.get('/files')
        async def list_files(request):
            directory_key = request.query.get('directory', '')
            try:
                file_list = self.file_service.list_files(directory_key)
                return web.json_response({"files": file_list})
            except ValueError as e:
                return web.json_response({"error": str(e)}, status=400)
            except Exception as e:
                return web.json_response({"error": str(e)}, status=500)

        @self.routes.get('/logs')
        async def get_logs(request):
            return web.json_response("".join([(l["t"] + " - " + l["m"]) for l in app.logger.get_logs()]))

        @self.routes.get('/logs/raw')
        async def get_raw_logs(request):
            self.terminal_service.update_size()
            return web.json_response({
                "entries": list(app.logger.get_logs()),
                "size": {"cols": self.terminal_service.cols, "rows": self.terminal_service.rows}
            })

        @self.routes.patch('/logs/subscribe')
        async def subscribe_logs(request):
            json_data = await request.json()
            client_id = json_data["clientId"]
            enabled = json_data["enabled"]
            if enabled:
                self.terminal_service.subscribe(client_id)
            else:
                self.terminal_service.unsubscribe(client_id)

            return web.Response(status=200)


        @self.routes.get('/folder_paths')
        async def get_folder_paths(request):
            response = {}
            for key in folder_names_and_paths:
                response[key] = folder_names_and_paths[key][0]
            return web.json_response(response)

    def get_app(self):
        if self._app is None:
            self._app = web.Application()
            self.setup_routes()
            self._app.add_routes(self.routes)
        return self._app
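
For reference, a hedged sketch of how a client might exercise these routes once the app is mounted under `/internal`; it assumes a ComfyUI server listening on the default port 8188:

# Hypothetical client for the routes above; assumes a running server
# with the InternalRoutes app mounted at /internal on localhost:8188.
import asyncio
import aiohttp

async def main():
    async with aiohttp.ClientSession() as session:
        # Valid directory keys are "models", "user" and "output".
        async with session.get("http://127.0.0.1:8188/internal/files",
                               params={"directory": "models"}) as resp:
            print(await resp.json())
        async with session.get("http://127.0.0.1:8188/internal/folder_paths") as resp:
            print(await resp.json())

asyncio.run(main())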
api_server/services/__init__.py
ADDED
File without changes
api_server/services/file_service.py
ADDED
@@ -0,0 +1,13 @@
from typing import Dict, List, Optional
from api_server.utils.file_operations import FileSystemOperations, FileSystemItem

class FileService:
    def __init__(self, allowed_directories: Dict[str, str], file_system_ops: Optional[FileSystemOperations] = None):
        self.allowed_directories: Dict[str, str] = allowed_directories
        self.file_system_ops: FileSystemOperations = file_system_ops or FileSystemOperations()

    def list_files(self, directory_key: str) -> List[FileSystemItem]:
        if directory_key not in self.allowed_directories:
            raise ValueError("Invalid directory key")
        directory_path: str = self.allowed_directories[directory_key]
        return self.file_system_ops.walk_directory(directory_path)
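
Because the constructor accepts an optional `file_system_ops`, the service can be exercised without touching the real filesystem. A small sketch using a hypothetical stub (assumes the `api_server` package above is importable):

# StubOps is hypothetical; FileService only calls walk_directory on it.
from api_server.services.file_service import FileService

class StubOps:
    def walk_directory(self, directory):
        return [{"name": "a.txt", "path": "a.txt", "type": "file", "size": 3}]

service = FileService({"models": "/tmp/models"}, file_system_ops=StubOps())
print(service.list_files("models"))  # the stubbed listing
# service.list_files("nope") would raise ValueError("Invalid directory key")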
api_server/services/terminal_service.py
ADDED
@@ -0,0 +1,60 @@
from app.logger import on_flush
import os
import shutil


class TerminalService:
    def __init__(self, server):
        self.server = server
        self.cols = None
        self.rows = None
        self.subscriptions = set()
        on_flush(self.send_messages)

    def get_terminal_size(self):
        try:
            size = os.get_terminal_size()
            return (size.columns, size.lines)
        except OSError:
            try:
                size = shutil.get_terminal_size()
                return (size.columns, size.lines)
            except OSError:
                return (80, 24)  # fallback to 80x24

    def update_size(self):
        columns, lines = self.get_terminal_size()
        changed = False

        if columns != self.cols:
            self.cols = columns
            changed = True

        if lines != self.rows:
            self.rows = lines
            changed = True

        if changed:
            return {"cols": self.cols, "rows": self.rows}

        return None

    def subscribe(self, client_id):
        self.subscriptions.add(client_id)

    def unsubscribe(self, client_id):
        self.subscriptions.discard(client_id)

    def send_messages(self, entries):
        if not len(entries) or not len(self.subscriptions):
            return

        new_size = self.update_size()

        for client_id in self.subscriptions.copy(): # prevent: Set changed size during iteration
            if client_id not in self.server.sockets:
                # Automatically unsub if the socket has disconnected
                self.unsubscribe(client_id)
                continue

            self.server.send_sync("logs", {"entries": entries, "size": new_size}, client_id)
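
A sketch of the fan-out in `send_messages`, driving the service with a hypothetical stub server instead of a real PromptServer (assumes `app.logger` is importable so the `on_flush` hook in `__init__` resolves):

# StubServer is hypothetical; TerminalService only needs .sockets and
# .send_sync from the server object it wraps.
from api_server.services.terminal_service import TerminalService

class StubServer:
    sockets = {"client-a": object()}
    def send_sync(self, event, data, sid):
        print(event, sid, data["entries"])

svc = TerminalService(StubServer())
svc.subscribe("client-a")
svc.subscribe("client-gone")  # no matching socket -> auto-unsubscribed
svc.send_messages([{"t": "12:00", "m": "hello\n"}])  # reaches client-a only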
api_server/utils/file_operations.py
ADDED
@@ -0,0 +1,42 @@
import os
from typing import List, Union, TypedDict, Literal
from typing_extensions import TypeGuard
class FileInfo(TypedDict):
    name: str
    path: str
    type: Literal["file"]
    size: int

class DirectoryInfo(TypedDict):
    name: str
    path: str
    type: Literal["directory"]

FileSystemItem = Union[FileInfo, DirectoryInfo]

def is_file_info(item: FileSystemItem) -> TypeGuard[FileInfo]:
    return item["type"] == "file"

class FileSystemOperations:
    @staticmethod
    def walk_directory(directory: str) -> List[FileSystemItem]:
        file_list: List[FileSystemItem] = []
        for root, dirs, files in os.walk(directory):
            for name in files:
                file_path = os.path.join(root, name)
                relative_path = os.path.relpath(file_path, directory)
                file_list.append({
                    "name": name,
                    "path": relative_path,
                    "type": "file",
                    "size": os.path.getsize(file_path)
                })
            for name in dirs:
                dir_path = os.path.join(root, name)
                relative_path = os.path.relpath(dir_path, directory)
                file_list.append({
                    "name": name,
                    "path": relative_path,
                    "type": "directory"
                })
        return file_list
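
The `is_file_info` type guard lets callers narrow `FileSystemItem` entries before touching the file-only `size` field. For example, totalling bytes across a walk (assumes the module above is importable):

# Example use of walk_directory with the is_file_info type guard.
from api_server.utils.file_operations import FileSystemOperations, is_file_info

items = FileSystemOperations.walk_directory(".")
total = sum(item["size"] for item in items if is_file_info(item))
print(f"{total} bytes across {len(items)} entries")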
app.py
ADDED
@@ -0,0 +1,421 @@
#########################
# app.py for ZeroGPU #
#########################

import os
import sys
import random
import torch
import gradio as gr
import spaces # for ZeroGPU usage
from typing import Sequence, Mapping, Any, Union

# 1) Load your token from environment (make sure you set HF_TOKEN in Space settings)
token = os.environ["HF_TOKEN"]

# 2) We'll use huggingface_hub to download each gated model
from huggingface_hub import hf_hub_download

import shutil
import pathlib

# Create the directories we need (mirroring your ComfyUI structure)
pathlib.Path("ComfyUI/models/vae").mkdir(parents=True, exist_ok=True)
pathlib.Path("ComfyUI/models/clip").mkdir(parents=True, exist_ok=True)
pathlib.Path("ComfyUI/models/clip_vision").mkdir(parents=True, exist_ok=True)
pathlib.Path("ComfyUI/models/unet").mkdir(parents=True, exist_ok=True)
pathlib.Path("ComfyUI/models/loras").mkdir(parents=True, exist_ok=True)
pathlib.Path("ComfyUI/models/style_models").mkdir(parents=True, exist_ok=True)

# Download each gated model into the correct local folder

hf_hub_download(
    repo_id="black-forest-labs/FLUX.1-dev",
    filename="ae.safetensors",
    local_dir="ComfyUI/models/vae",
    use_auth_token=token
)

hf_hub_download(
    repo_id="comfyanonymous/flux_text_encoders",
    filename="t5xxl_fp16.safetensors",
    local_dir="ComfyUI/models/clip",
    use_auth_token=token
)

hf_hub_download(
    repo_id="comfyanonymous/flux_text_encoders",
    filename="clip_l.safetensors",
    local_dir="ComfyUI/models/clip",
    use_auth_token=token
)

hf_hub_download(
    repo_id="black-forest-labs/FLUX.1-Fill-dev",
    filename="flux1-fill-dev.safetensors",
    local_dir="ComfyUI/models/unet",
    use_auth_token=token
)

# hf_hub_download has no rename parameter, so download first and then
# move the file so it matches the name referenced in the pipeline below.
catvton_lora = hf_hub_download(
    repo_id="zhengchong/CatVTON",
    filename="flux-lora/pytorch_lora_weights.safetensors",
    local_dir="ComfyUI/models/loras",
    use_auth_token=token
)
shutil.move(catvton_lora, "ComfyUI/models/loras/catvton-flux-lora.safetensors")

turbo_lora = hf_hub_download(
    repo_id="alimama-creative/FLUX.1-Turbo-Alpha",
    filename="diffusion_pytorch_model.safetensors",
    local_dir="ComfyUI/models/loras",
    use_auth_token=token
)
shutil.move(turbo_lora, "ComfyUI/models/loras/alimama-flux-turbo-alpha.safetensors")

hf_hub_download(
    repo_id="Comfy-Org/sigclip_vision_384",
    filename="sigclip_vision_patch14_384.safetensors",
    local_dir="ComfyUI/models/clip_vision",
    use_auth_token=token
)

hf_hub_download(
    repo_id="black-forest-labs/FLUX.1-Redux-dev",
    filename="flux1-redux-dev.safetensors",
    local_dir="ComfyUI/models/style_models",
    use_auth_token=token
)

#############################
# ComfyUI support functions
#############################

def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Returns the value at the given index of a sequence or mapping."""
    try:
        return obj[index]
    except KeyError:
        return obj["result"][index]

def find_path(name: str, path: str = None) -> str:
    import os
    if path is None:
        path = os.getcwd()

    if name in os.listdir(path):
        path_name = os.path.join(path, name)
        print(f"{name} found: {path_name}")
        return path_name

    parent_directory = os.path.dirname(path)
    if parent_directory == path:
        return None

    return find_path(name, parent_directory)

def add_comfyui_directory_to_sys_path() -> None:
    comfyui_path = find_path("ComfyUI")
    if comfyui_path is not None and os.path.isdir(comfyui_path):
        sys.path.append(comfyui_path)
        print(f"'{comfyui_path}' added to sys.path")

def add_extra_model_paths() -> None:
    try:
        from main import load_extra_path_config
    except ImportError:
        print("Could not import load_extra_path_config from main.py. Looking in utils.extra_config instead.")
        from utils.extra_config import load_extra_path_config

    extra_model_paths = find_path("extra_model_paths.yaml")
    if extra_model_paths is not None:
        load_extra_path_config(extra_model_paths)
    else:
        print("Could not find the extra_model_paths config file.")

add_comfyui_directory_to_sys_path()
add_extra_model_paths()

def import_custom_nodes() -> None:
    import asyncio
    import execution
    from nodes import init_extra_nodes
    import server

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    server_instance = server.PromptServer(loop)
    execution.PromptQueue(server_instance)
    init_extra_nodes()

##########################
# Import node mappings
##########################
from nodes import NODE_CLASS_MAPPINGS


#############################################
# MAIN PIPELINE with ZeroGPU Decorator
#############################################
@spaces.GPU(duration=90) # 90s of GPU usage; adjust as needed
def generate_images(user_image_path):
    """
    This function runs your node-based pipeline,
    using user_image_path for loadimage_264 and
    returning the final saveimage_295 path.
    """
    import_custom_nodes()
    with torch.inference_mode():
        loadimage = NODE_CLASS_MAPPINGS["LoadImage"]()
        loadimage_116 = loadimage.load_image(image="assets_black_tshirt.png")

        loadautomasker = NODE_CLASS_MAPPINGS["LoadAutoMasker"]()
        loadautomasker_120 = loadautomasker.load(catvton_path="zhengchong/CatVTON")

        loadcatvtonpipeline = NODE_CLASS_MAPPINGS["LoadCatVTONPipeline"]()
        loadcatvtonpipeline_123 = loadcatvtonpipeline.load(
            sd15_inpaint_path="runwayml/stable-diffusion-inpainting",
            catvton_path="zhengchong/CatVTON",
            mixed_precision="bf16",
        )

        loadimage_264 = loadimage.load_image(
            image=user_image_path
        )

        randomnoise = NODE_CLASS_MAPPINGS["RandomNoise"]()
        randomnoise_273 = randomnoise.get_noise(noise_seed=random.randint(1, 2**64))

        downloadandloadflorence2model = NODE_CLASS_MAPPINGS["DownloadAndLoadFlorence2Model"]()
        downloadandloadflorence2model_274 = downloadandloadflorence2model.loadmodel(
            model="gokaygokay/Florence-2-Flux-Large", precision="fp16", attention="sdpa"
        )

        automasker = NODE_CLASS_MAPPINGS["AutoMasker"]()
        automasker_119 = automasker.generate(
            cloth_type="overall",
            pipe=get_value_at_index(loadautomasker_120, 0),
            target_image=get_value_at_index(loadimage_264, 0),
        )

        catvton = NODE_CLASS_MAPPINGS["CatVTON"]()
        catvton_121 = catvton.generate(
            seed=random.randint(1, 2**64),
            steps=50,
            cfg=2.5,
            pipe=get_value_at_index(loadcatvtonpipeline_123, 0),
            target_image=get_value_at_index(loadimage_264, 0),
            refer_image=get_value_at_index(loadimage_116, 0),
            mask_image=get_value_at_index(automasker_119, 0),
        )

        florence2run = NODE_CLASS_MAPPINGS["Florence2Run"]()
        florence2run_275 = florence2run.encode(
            text_input="Haircut",
            task="caption_to_phrase_grounding",
            fill_mask=True,
            keep_model_loaded=False,
            max_new_tokens=1024,
            num_beams=3,
            do_sample=True,
            output_mask_select="",
            seed=random.randint(1, 2**64),
            image=get_value_at_index(catvton_121, 0),
            florence2_model=get_value_at_index(downloadandloadflorence2model_274, 0),
        )

        downloadandloadsam2model = NODE_CLASS_MAPPINGS["DownloadAndLoadSAM2Model"]()
        downloadandloadsam2model_277 = downloadandloadsam2model.loadmodel(
            model="sam2.1_hiera_large.safetensors",
            segmentor="single_image",
            device="cuda",
            precision="fp16",
        )

        dualcliploadergguf = NODE_CLASS_MAPPINGS["DualCLIPLoaderGGUF"]()
        dualcliploadergguf_284 = dualcliploadergguf.load_clip(
            clip_name1="t5xxl_fp16.safetensors",
            clip_name2="clip_l.safetensors",
            type="flux",
        )

        cliptextencode = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
        cliptextencode_279 = cliptextencode.encode(
            text="Br0k0L8, Broccoli haircut with voluminous, textured curls on top resembling broccoli florets, contrasted by closely shaved tapered sides",
            clip=get_value_at_index(dualcliploadergguf_284, 0),
        )

        clipvisionloader = NODE_CLASS_MAPPINGS["CLIPVisionLoader"]()
        clipvisionloader_281 = clipvisionloader.load_clip(
            clip_name="sigclip_vision_patch14_384.safetensors"
        )

        loadimage_289 = loadimage.load_image(image="assets_broc_ref.jpg")

        clipvisionencode = NODE_CLASS_MAPPINGS["CLIPVisionEncode"]()
        clipvisionencode_282 = clipvisionencode.encode(
            crop="center",
            clip_vision=get_value_at_index(clipvisionloader_281, 0),
            image=get_value_at_index(loadimage_289, 0),
        )

        vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
        vaeloader_285 = vaeloader.load_vae(vae_name="ae.safetensors")

        stylemodelloader = NODE_CLASS_MAPPINGS["StyleModelLoader"]()
        stylemodelloader_292 = stylemodelloader.load_style_model(
            style_model_name="flux1-redux-dev.safetensors"
        )

        stylemodelapply = NODE_CLASS_MAPPINGS["StyleModelApply"]()
        stylemodelapply_280 = stylemodelapply.apply_stylemodel(
            strength=1,
            strength_type="multiply",
            conditioning=get_value_at_index(cliptextencode_279, 0),
            style_model=get_value_at_index(stylemodelloader_292, 0),
            clip_vision_output=get_value_at_index(clipvisionencode_282, 0),
        )

        fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
        fluxguidance_288 = fluxguidance.append(
            guidance=30, conditioning=get_value_at_index(stylemodelapply_280, 0)
        )

        conditioningzeroout = NODE_CLASS_MAPPINGS["ConditioningZeroOut"]()
        conditioningzeroout_287 = conditioningzeroout.zero_out(
            conditioning=get_value_at_index(fluxguidance_288, 0)
        )

        florence2tocoordinates = NODE_CLASS_MAPPINGS["Florence2toCoordinates"]()
        florence2tocoordinates_276 = florence2tocoordinates.segment(
            index="", batch=False, data=get_value_at_index(florence2run_275, 3)
        )

        sam2segmentation = NODE_CLASS_MAPPINGS["Sam2Segmentation"]()
        sam2segmentation_278 = sam2segmentation.segment(
            keep_model_loaded=False,
            individual_objects=False,
            sam2_model=get_value_at_index(downloadandloadsam2model_277, 0),
            image=get_value_at_index(florence2run_275, 0),
            bboxes=get_value_at_index(florence2tocoordinates_276, 1),
        )

        growmask = NODE_CLASS_MAPPINGS["GrowMask"]()
        growmask_299 = growmask.expand_mask(
            expand=35,
            tapered_corners=True,
            mask=get_value_at_index(sam2segmentation_278, 0),
        )

        layermask_segformerb2clothesultra = NODE_CLASS_MAPPINGS["LayerMask: SegformerB2ClothesUltra"]()
        layermask_segformerb2clothesultra_300 = layermask_segformerb2clothesultra.segformer_ultra(
            face=True,
            hair=False,
            hat=False,
            sunglass=False,
            left_arm=False,
            right_arm=False,
            left_leg=False,
            right_leg=False,
            upper_clothes=True,
            skirt=False,
            pants=False,
            dress=False,
            belt=False,
            shoe=False,
            bag=False,
            scarf=True,
            detail_method="VITMatte",
            detail_erode=12,
            detail_dilate=6,
            black_point=0.15,
            white_point=0.99,
            process_detail=True,
            device="cuda",
            max_megapixels=2,
            image=get_value_at_index(catvton_121, 0),
        )

        masks_subtract = NODE_CLASS_MAPPINGS["Masks Subtract"]()
        masks_subtract_296 = masks_subtract.subtract_masks(
            masks_a=get_value_at_index(growmask_299, 0),
            masks_b=get_value_at_index(layermask_segformerb2clothesultra_300, 1),
        )

        inpaintmodelconditioning = NODE_CLASS_MAPPINGS["InpaintModelConditioning"]()
        inpaintmodelconditioning_286 = inpaintmodelconditioning.encode(
            noise_mask=True,
            positive=get_value_at_index(fluxguidance_288, 0),
            negative=get_value_at_index(conditioningzeroout_287, 0),
            vae=get_value_at_index(vaeloader_285, 0),
            pixels=get_value_at_index(catvton_121, 0),
            mask=get_value_at_index(masks_subtract_296, 0),
        )

        unetloader = NODE_CLASS_MAPPINGS["UNETLoader"]()
        unetloader_291 = unetloader.load_unet(
            unet_name="flux1-fill-dev.safetensors", weight_dtype="default"
        )

        loraloadermodelonly = NODE_CLASS_MAPPINGS["LoraLoaderModelOnly"]()
        loraloadermodelonly_290 = loraloadermodelonly.load_lora_model_only(
            lora_name="alimama-flux-turbo-alpha.safetensors",
            strength_model=1,
            model=get_value_at_index(unetloader_291, 0),
        )

        ksampler = NODE_CLASS_MAPPINGS["KSampler"]()
        vaedecode = NODE_CLASS_MAPPINGS["VAEDecode"]()
        saveimage = NODE_CLASS_MAPPINGS["SaveImage"]()

        # We'll do a single pass
        for q in range(1):
            ksampler_283 = ksampler.sample(
                seed=random.randint(1, 2**64),
                steps=10,
                cfg=1,
                sampler_name="dpmpp_2m",
                scheduler="sgm_uniform",
                denoise=1,
                model=get_value_at_index(loraloadermodelonly_290, 0),
                positive=get_value_at_index(inpaintmodelconditioning_286, 0),
                negative=get_value_at_index(inpaintmodelconditioning_286, 1),
                latent_image=get_value_at_index(inpaintmodelconditioning_286, 2),
            )

            vaedecode_294 = vaedecode.decode(
                samples=get_value_at_index(ksampler_283, 0),
                vae=get_value_at_index(vaeloader_285, 0),
            )

            saveimage_295 = saveimage.save_images(
                filename_prefix="The_Broccolator_",
                images=get_value_at_index(vaedecode_294, 0),
            )

        # final_output_path is the only one we return
        final_output_path = f"output/{saveimage_295['ui']['images'][0]['filename']}"
        return final_output_path


###################################
# A simple Gradio interface
###################################
with gr.Blocks() as demo:
    gr.Markdown("## The Broccolator 🥦\nUpload an image for `loadimage_264` and see final output.")
    with gr.Row():
        with gr.Column():
            user_input_image = gr.Image(type="filepath", label="Input Image")
            btn_generate = gr.Button("Generate")
        with gr.Column():
            final_image = gr.Image(label="Final output (saveimage_295)")

    btn_generate.click(
        fn=generate_images,
        inputs=user_input_image,
        outputs=final_image
    )

demo.launch(debug=True)
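
`get_value_at_index` papers over the two shapes ComfyUI nodes return: plain sequences and dicts carrying a "result" tuple. A quick illustration using the helper defined above:

# Plain indexing works for tuple results; integer indexing into a dict
# raises KeyError, which triggers the fallback to obj["result"][index].
print(get_value_at_index(("latent", "image"), 1))              # image
print(get_value_at_index({"result": ("mask", "cropped")}, 0))  # mask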
comfy_execution/caching.py
ADDED
@@ -0,0 +1,318 @@
import itertools
from typing import Sequence, Mapping, Dict
from comfy_execution.graph import DynamicPrompt

import nodes

from comfy_execution.graph_utils import is_link

NODE_CLASS_CONTAINS_UNIQUE_ID: Dict[str, bool] = {}


def include_unique_id_in_input(class_type: str) -> bool:
    if class_type in NODE_CLASS_CONTAINS_UNIQUE_ID:
        return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
    class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
    NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] = "UNIQUE_ID" in class_def.INPUT_TYPES().get("hidden", {}).values()
    return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]

class CacheKeySet:
    def __init__(self, dynprompt, node_ids, is_changed_cache):
        self.keys = {}
        self.subcache_keys = {}

    def add_keys(self, node_ids):
        raise NotImplementedError()

    def all_node_ids(self):
        return set(self.keys.keys())

    def get_used_keys(self):
        return self.keys.values()

    def get_used_subcache_keys(self):
        return self.subcache_keys.values()

    def get_data_key(self, node_id):
        return self.keys.get(node_id, None)

    def get_subcache_key(self, node_id):
        return self.subcache_keys.get(node_id, None)

class Unhashable:
    def __init__(self):
        self.value = float("NaN")

def to_hashable(obj):
    # So that we don't infinitely recurse since frozenset and tuples
    # are Sequences.
    if isinstance(obj, (int, float, str, bool, type(None))):
        return obj
    elif isinstance(obj, Mapping):
        return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())])
    elif isinstance(obj, Sequence):
        return frozenset(zip(itertools.count(), [to_hashable(i) for i in obj]))
    else:
        # TODO - Support other objects like tensors?
        return Unhashable()

class CacheKeySetID(CacheKeySet):
    def __init__(self, dynprompt, node_ids, is_changed_cache):
        super().__init__(dynprompt, node_ids, is_changed_cache)
        self.dynprompt = dynprompt
        self.add_keys(node_ids)

    def add_keys(self, node_ids):
        for node_id in node_ids:
            if node_id in self.keys:
                continue
            if not self.dynprompt.has_node(node_id):
                continue
            node = self.dynprompt.get_node(node_id)
            self.keys[node_id] = (node_id, node["class_type"])
            self.subcache_keys[node_id] = (node_id, node["class_type"])

class CacheKeySetInputSignature(CacheKeySet):
    def __init__(self, dynprompt, node_ids, is_changed_cache):
        super().__init__(dynprompt, node_ids, is_changed_cache)
        self.dynprompt = dynprompt
        self.is_changed_cache = is_changed_cache
        self.add_keys(node_ids)

    def include_node_id_in_input(self) -> bool:
        return False

    def add_keys(self, node_ids):
        for node_id in node_ids:
            if node_id in self.keys:
                continue
            if not self.dynprompt.has_node(node_id):
                continue
            node = self.dynprompt.get_node(node_id)
            self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id)
            self.subcache_keys[node_id] = (node_id, node["class_type"])

    def get_node_signature(self, dynprompt, node_id):
        signature = []
        ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id)
        signature.append(self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
        for ancestor_id in ancestors:
            signature.append(self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
        return to_hashable(signature)

    def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
        if not dynprompt.has_node(node_id):
            # This node doesn't exist -- we can't cache it.
            return [float("NaN")]
        node = dynprompt.get_node(node_id)
        class_type = node["class_type"]
        class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
        signature = [class_type, self.is_changed_cache.get(node_id)]
        if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
            signature.append(node_id)
        inputs = node["inputs"]
        for key in sorted(inputs.keys()):
            if is_link(inputs[key]):
                (ancestor_id, ancestor_socket) = inputs[key]
                ancestor_index = ancestor_order_mapping[ancestor_id]
                signature.append((key, ("ANCESTOR", ancestor_index, ancestor_socket)))
            else:
                signature.append((key, inputs[key]))
        return signature

    # This function returns a list of all ancestors of the given node. The order of the list is
    # deterministic based on which specific inputs the ancestor is connected by.
    def get_ordered_ancestry(self, dynprompt, node_id):
        ancestors = []
        order_mapping = {}
        self.get_ordered_ancestry_internal(dynprompt, node_id, ancestors, order_mapping)
        return ancestors, order_mapping

    def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping):
        if not dynprompt.has_node(node_id):
            return
        inputs = dynprompt.get_node(node_id)["inputs"]
        input_keys = sorted(inputs.keys())
        for key in input_keys:
            if is_link(inputs[key]):
                ancestor_id = inputs[key][0]
                if ancestor_id not in order_mapping:
                    ancestors.append(ancestor_id)
                    order_mapping[ancestor_id] = len(ancestors) - 1
                    self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping)

class BasicCache:
    def __init__(self, key_class):
        self.key_class = key_class
        self.initialized = False
        self.dynprompt: DynamicPrompt
        self.cache_key_set: CacheKeySet
        self.cache = {}
        self.subcaches = {}

    def set_prompt(self, dynprompt, node_ids, is_changed_cache):
        self.dynprompt = dynprompt
        self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
        self.is_changed_cache = is_changed_cache
        self.initialized = True

    def all_node_ids(self):
        assert self.initialized
        node_ids = self.cache_key_set.all_node_ids()
        for subcache in self.subcaches.values():
            node_ids = node_ids.union(subcache.all_node_ids())
        return node_ids

    def _clean_cache(self):
        preserve_keys = set(self.cache_key_set.get_used_keys())
        to_remove = []
        for key in self.cache:
            if key not in preserve_keys:
                to_remove.append(key)
        for key in to_remove:
            del self.cache[key]

    def _clean_subcaches(self):
        preserve_subcaches = set(self.cache_key_set.get_used_subcache_keys())

        to_remove = []
        for key in self.subcaches:
            if key not in preserve_subcaches:
                to_remove.append(key)
        for key in to_remove:
            del self.subcaches[key]

    def clean_unused(self):
        assert self.initialized
        self._clean_cache()
        self._clean_subcaches()

    def _set_immediate(self, node_id, value):
        assert self.initialized
        cache_key = self.cache_key_set.get_data_key(node_id)
        self.cache[cache_key] = value

    def _get_immediate(self, node_id):
        if not self.initialized:
            return None
        cache_key = self.cache_key_set.get_data_key(node_id)
        if cache_key in self.cache:
            return self.cache[cache_key]
        else:
            return None

    def _ensure_subcache(self, node_id, children_ids):
        subcache_key = self.cache_key_set.get_subcache_key(node_id)
        subcache = self.subcaches.get(subcache_key, None)
        if subcache is None:
            subcache = BasicCache(self.key_class)
            self.subcaches[subcache_key] = subcache
        subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
        return subcache

    def _get_subcache(self, node_id):
        assert self.initialized
        subcache_key = self.cache_key_set.get_subcache_key(node_id)
        if subcache_key in self.subcaches:
            return self.subcaches[subcache_key]
        else:
            return None

    def recursive_debug_dump(self):
        result = []
        for key in self.cache:
            result.append({"key": key, "value": self.cache[key]})
        for key in self.subcaches:
            result.append({"subcache_key": key, "subcache": self.subcaches[key].recursive_debug_dump()})
        return result

class HierarchicalCache(BasicCache):
    def __init__(self, key_class):
        super().__init__(key_class)

    def _get_cache_for(self, node_id):
        assert self.dynprompt is not None
        parent_id = self.dynprompt.get_parent_node_id(node_id)
        if parent_id is None:
            return self

        hierarchy = []
        while parent_id is not None:
            hierarchy.append(parent_id)
            parent_id = self.dynprompt.get_parent_node_id(parent_id)

        cache = self
        for parent_id in reversed(hierarchy):
            cache = cache._get_subcache(parent_id)
            if cache is None:
                return None
        return cache

    def get(self, node_id):
        cache = self._get_cache_for(node_id)
        if cache is None:
            return None
        return cache._get_immediate(node_id)

    def set(self, node_id, value):
        cache = self._get_cache_for(node_id)
        assert cache is not None
        cache._set_immediate(node_id, value)

    def ensure_subcache_for(self, node_id, children_ids):
        cache = self._get_cache_for(node_id)
        assert cache is not None
        return cache._ensure_subcache(node_id, children_ids)

class LRUCache(BasicCache):
    def __init__(self, key_class, max_size=100):
        super().__init__(key_class)
        self.max_size = max_size
        self.min_generation = 0
        self.generation = 0
        self.used_generation = {}
        self.children = {}

    def set_prompt(self, dynprompt, node_ids, is_changed_cache):
        super().set_prompt(dynprompt, node_ids, is_changed_cache)
        self.generation += 1
        for node_id in node_ids:
            self._mark_used(node_id)

    def clean_unused(self):
        while len(self.cache) > self.max_size and self.min_generation < self.generation:
            self.min_generation += 1
            to_remove = [key for key in self.cache if self.used_generation[key] < self.min_generation]
            for key in to_remove:
                del self.cache[key]
                del self.used_generation[key]
                if key in self.children:
                    del self.children[key]
        self._clean_subcaches()

    def get(self, node_id):
        self._mark_used(node_id)
        return self._get_immediate(node_id)

    def _mark_used(self, node_id):
        cache_key = self.cache_key_set.get_data_key(node_id)
        if cache_key is not None:
            self.used_generation[cache_key] = self.generation

    def set(self, node_id, value):
        self._mark_used(node_id)
        return self._set_immediate(node_id, value)

    def ensure_subcache_for(self, node_id, children_ids):
        # Just uses subcaches for tracking 'live' nodes
        super()._ensure_subcache(node_id, children_ids)

        self.cache_key_set.add_keys(children_ids)
        self._mark_used(node_id)
        cache_key = self.cache_key_set.get_data_key(node_id)
        self.children[cache_key] = []
        for child_id in children_ids:
            self._mark_used(child_id)
            self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
        return self
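
`to_hashable` is what makes the input signatures usable as dictionary keys: dicts and lists collapse to frozensets, so equal inputs hash identically regardless of key order. A quick check (assumes the module above is importable):

# Equal inputs produce equal, hashable keys regardless of dict ordering.
from comfy_execution.caching import to_hashable

a = to_hashable({"steps": 20, "size": [512, 512]})
b = to_hashable({"size": [512, 512], "steps": 20})
print(a == b, hash(a) == hash(b))  # True True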
comfy_execution/graph.py
ADDED
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import nodes

from comfy_execution.graph_utils import is_link

class DependencyCycleError(Exception):
    pass

class NodeInputError(Exception):
    pass

class NodeNotFoundError(Exception):
    pass

class DynamicPrompt:
    def __init__(self, original_prompt):
        # The original prompt provided by the user
        self.original_prompt = original_prompt
        # Any extra pieces of the graph created during execution
        self.ephemeral_prompt = {}
        self.ephemeral_parents = {}
        self.ephemeral_display = {}

    def get_node(self, node_id):
        if node_id in self.ephemeral_prompt:
            return self.ephemeral_prompt[node_id]
        if node_id in self.original_prompt:
            return self.original_prompt[node_id]
        raise NodeNotFoundError(f"Node {node_id} not found")

    def has_node(self, node_id):
        return node_id in self.original_prompt or node_id in self.ephemeral_prompt

    def add_ephemeral_node(self, node_id, node_info, parent_id, display_id):
        self.ephemeral_prompt[node_id] = node_info
        self.ephemeral_parents[node_id] = parent_id
        self.ephemeral_display[node_id] = display_id

    def get_real_node_id(self, node_id):
        while node_id in self.ephemeral_parents:
            node_id = self.ephemeral_parents[node_id]
        return node_id

    def get_parent_node_id(self, node_id):
        return self.ephemeral_parents.get(node_id, None)

    def get_display_node_id(self, node_id):
        while node_id in self.ephemeral_display:
            node_id = self.ephemeral_display[node_id]
        return node_id

    def all_node_ids(self):
        return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys()))

    def get_original_prompt(self):
        return self.original_prompt

def get_input_info(class_def, input_name, valid_inputs=None):
    valid_inputs = valid_inputs or class_def.INPUT_TYPES()
    input_info = None
    input_category = None
    if "required" in valid_inputs and input_name in valid_inputs["required"]:
        input_category = "required"
        input_info = valid_inputs["required"][input_name]
    elif "optional" in valid_inputs and input_name in valid_inputs["optional"]:
        input_category = "optional"
        input_info = valid_inputs["optional"][input_name]
    elif "hidden" in valid_inputs and input_name in valid_inputs["hidden"]:
        input_category = "hidden"
        input_info = valid_inputs["hidden"][input_name]
    if input_info is None:
        return None, None, None
    input_type = input_info[0]
    if len(input_info) > 1:
        extra_info = input_info[1]
    else:
        extra_info = {}
    return input_type, input_category, extra_info

class TopologicalSort:
    def __init__(self, dynprompt):
        self.dynprompt = dynprompt
        self.pendingNodes = {}
        self.blockCount = {} # Number of nodes this node is directly blocked by
        self.blocking = {} # Which nodes are blocked by this node

    def get_input_info(self, unique_id, input_name):
        class_type = self.dynprompt.get_node(unique_id)["class_type"]
        class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
        return get_input_info(class_def, input_name)

    def make_input_strong_link(self, to_node_id, to_input):
        inputs = self.dynprompt.get_node(to_node_id)["inputs"]
        if to_input not in inputs:
            raise NodeInputError(f"Node {to_node_id} says it needs input {to_input}, but there is no input to that node at all")
        value = inputs[to_input]
        if not is_link(value):
            raise NodeInputError(f"Node {to_node_id} says it needs input {to_input}, but that value is a constant")
        from_node_id, from_socket = value
        self.add_strong_link(from_node_id, from_socket, to_node_id)

    def add_strong_link(self, from_node_id, from_socket, to_node_id):
        if not self.is_cached(from_node_id):
            self.add_node(from_node_id)
            if to_node_id not in self.blocking[from_node_id]:
                self.blocking[from_node_id][to_node_id] = {}
                self.blockCount[to_node_id] += 1
            self.blocking[from_node_id][to_node_id][from_socket] = True

    def add_node(self, node_unique_id, include_lazy=False, subgraph_nodes=None):
        node_ids = [node_unique_id]
        links = []

        while len(node_ids) > 0:
            unique_id = node_ids.pop()
            if unique_id in self.pendingNodes:
                continue

            self.pendingNodes[unique_id] = True
            self.blockCount[unique_id] = 0
            self.blocking[unique_id] = {}

            inputs = self.dynprompt.get_node(unique_id)["inputs"]
            for input_name in inputs:
                value = inputs[input_name]
                if is_link(value):
                    from_node_id, from_socket = value
                    if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
                        continue
                    input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
                    is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
                    if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
                        node_ids.append(from_node_id)
                        links.append((from_node_id, from_socket, unique_id))

        for link in links:
            self.add_strong_link(*link)

    def is_cached(self, node_id):
        return False

    def get_ready_nodes(self):
        return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0]

    def pop_node(self, unique_id):
        del self.pendingNodes[unique_id]
        for blocked_node_id in self.blocking[unique_id]:
            self.blockCount[blocked_node_id] -= 1
        del self.blocking[unique_id]

    def is_empty(self):
        return len(self.pendingNodes) == 0

class ExecutionList(TopologicalSort):
    """
    ExecutionList implements a topological dissolve of the graph. After a node is staged for execution,
    it can still be returned to the graph after having further dependencies added.
    """
    def __init__(self, dynprompt, output_cache):
        super().__init__(dynprompt)
        self.output_cache = output_cache
        self.staged_node_id = None

    def is_cached(self, node_id):
        return self.output_cache.get(node_id) is not None

    def stage_node_execution(self):
        assert self.staged_node_id is None
        if self.is_empty():
            return None, None, None
        available = self.get_ready_nodes()
        if len(available) == 0:
            cycled_nodes = self.get_nodes_in_cycle()
            # Because cycles composed entirely of static nodes are caught during initial validation,
            # we will 'blame' the first node in the cycle that is not a static node.
            blamed_node = cycled_nodes[0]
            for node_id in cycled_nodes:
                display_node_id = self.dynprompt.get_display_node_id(node_id)
                if display_node_id != node_id:
                    blamed_node = display_node_id
                    break
            ex = DependencyCycleError("Dependency cycle detected")
            error_details = {
                "node_id": blamed_node,
                "exception_message": str(ex),
                "exception_type": "graph.DependencyCycleError",
                "traceback": [],
                "current_inputs": []
            }
            return None, error_details, ex

        self.staged_node_id = self.ux_friendly_pick_node(available)
        return self.staged_node_id, None, None

    def ux_friendly_pick_node(self, node_list):
        # If an output node is available, do that first.
        # Technically this has no effect on the overall length of execution, but it feels better as a user
        # for a PreviewImage to display a result as soon as it can
        # Some other heuristics could probably be used here to improve the UX further.
        def is_output(node_id):
            class_type = self.dynprompt.get_node(node_id)["class_type"]
            class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
            if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
                return True
            return False

        for node_id in node_list:
            if is_output(node_id):
                return node_id

        #This should handle the VAEDecode -> preview case
        for node_id in node_list:
            for blocked_node_id in self.blocking[node_id]:
                if is_output(blocked_node_id):
                    return node_id

        #This should handle the VAELoader -> VAEDecode -> preview case
        for node_id in node_list:
            for blocked_node_id in self.blocking[node_id]:
                for blocked_node_id1 in self.blocking[blocked_node_id]:
                    if is_output(blocked_node_id1):
                        return node_id

        #TODO: this function should be improved
        return node_list[0]

    def unstage_node_execution(self):
        assert self.staged_node_id is not None
        self.staged_node_id = None

    def complete_node_execution(self):
        node_id = self.staged_node_id
        self.pop_node(node_id)
        self.staged_node_id = None

    def get_nodes_in_cycle(self):
        # We'll dissolve the graph in reverse topological order to leave only the nodes in the cycle.
        # We're skipping some of the performance optimizations from the original TopologicalSort to keep
        # the code simple (and because having a cycle in the first place is a catastrophic error)
        blocked_by = { node_id: {} for node_id in self.pendingNodes }
        for from_node_id in self.blocking:
            for to_node_id in self.blocking[from_node_id]:
                if True in self.blocking[from_node_id][to_node_id].values():
                    blocked_by[to_node_id][from_node_id] = True
        to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0]
        while len(to_remove) > 0:
            for node_id in to_remove:
                for to_node_id in blocked_by:
                    if node_id in blocked_by[to_node_id]:
                        del blocked_by[to_node_id][node_id]
                del blocked_by[node_id]
            to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0]
        return list(blocked_by.keys())

class ExecutionBlocker:
    """
    Return this from a node and any users will be blocked with the given error message.
    If the message is None, execution will be blocked silently instead.
    Generally, you should avoid using this functionality unless absolutely necessary. Whenever it's
    possible, a lazy input will be more efficient and have a better user experience.
    This functionality is useful in two cases:
    1. You want to conditionally prevent an output node from executing. (Particularly a built-in node
       like SaveImage. For your own output nodes, I would recommend just adding a BOOL input and using
       lazy evaluation to let it conditionally disable itself.)
    2. You have a node with multiple possible outputs, some of which are invalid and should not be used.
       (I would recommend not making nodes like this in the future -- instead, make multiple nodes with
       different outputs. Unfortunately, there are several popular existing nodes using this pattern.)
    """
    def __init__(self, message):
        self.message = message
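TopologicalSort above is a Kahn-style scheduler: blockCount counts a node's unmet dependencies, blocking records which downstream nodes each node holds up, and pop_node releases them. A standalone sketch of the same dissolve on a three-node chain (plain dicts only, no ComfyUI node classes; all names are illustrative):

# Dissolve a -> b -> c the way get_ready_nodes/pop_node do (illustrative sketch).
blocking = {"a": {"b": {}}, "b": {"c": {}}, "c": {}}   # which nodes each node blocks
block_count = {"a": 0, "b": 1, "c": 1}                 # unmet dependencies per node
pending = {"a": True, "b": True, "c": True}

order = []
while pending:
    ready = [n for n in pending if block_count[n] == 0]
    node = ready[0]
    order.append(node)
    del pending[node]
    for blocked in blocking[node]:
        block_count[blocked] -= 1
    del blocking[node]

print(order)  # ['a', 'b', 'c']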
comfy_execution/graph_utils.py
ADDED
@@ -0,0 +1,139 @@
def is_link(obj):
    if not isinstance(obj, list):
        return False
    if len(obj) != 2:
        return False
    if not isinstance(obj[0], str):
        return False
    if not isinstance(obj[1], int) and not isinstance(obj[1], float):
        return False
    return True

# The GraphBuilder is just a utility class that outputs graphs in the form expected by the ComfyUI back-end
class GraphBuilder:
    _default_prefix_root = ""
    _default_prefix_call_index = 0
    _default_prefix_graph_index = 0

    def __init__(self, prefix = None):
        if prefix is None:
            self.prefix = GraphBuilder.alloc_prefix()
        else:
            self.prefix = prefix
        self.nodes = {}
        self.id_gen = 1

    @classmethod
    def set_default_prefix(cls, prefix_root, call_index, graph_index = 0):
        cls._default_prefix_root = prefix_root
        cls._default_prefix_call_index = call_index
        cls._default_prefix_graph_index = graph_index

    @classmethod
    def alloc_prefix(cls, root=None, call_index=None, graph_index=None):
        if root is None:
            root = GraphBuilder._default_prefix_root
        if call_index is None:
            call_index = GraphBuilder._default_prefix_call_index
        if graph_index is None:
            graph_index = GraphBuilder._default_prefix_graph_index
        result = f"{root}.{call_index}.{graph_index}."
        GraphBuilder._default_prefix_graph_index += 1
        return result

    def node(self, class_type, id=None, **kwargs):
        if id is None:
            id = str(self.id_gen)
            self.id_gen += 1
        id = self.prefix + id
        if id in self.nodes:
            return self.nodes[id]

        node = Node(id, class_type, kwargs)
        self.nodes[id] = node
        return node

    def lookup_node(self, id):
        id = self.prefix + id
        return self.nodes.get(id)

    def finalize(self):
        output = {}
        for node_id, node in self.nodes.items():
            output[node_id] = node.serialize()
        return output

    def replace_node_output(self, node_id, index, new_value):
        node_id = self.prefix + node_id
        to_remove = []
        for node in self.nodes.values():
            for key, value in node.inputs.items():
                if is_link(value) and value[0] == node_id and value[1] == index:
                    if new_value is None:
                        to_remove.append((node, key))
                    else:
                        node.inputs[key] = new_value
        for node, key in to_remove:
            del node.inputs[key]

    def remove_node(self, id):
        id = self.prefix + id
        del self.nodes[id]

class Node:
    def __init__(self, id, class_type, inputs):
        self.id = id
        self.class_type = class_type
        self.inputs = inputs
        self.override_display_id = None

    def out(self, index):
        return [self.id, index]

    def set_input(self, key, value):
        if value is None:
            if key in self.inputs:
                del self.inputs[key]
        else:
            self.inputs[key] = value

    def get_input(self, key):
        return self.inputs.get(key)

    def set_override_display_id(self, override_display_id):
        self.override_display_id = override_display_id

    def serialize(self):
        serialized = {
            "class_type": self.class_type,
            "inputs": self.inputs
        }
        if self.override_display_id is not None:
            serialized["override_display_id"] = self.override_display_id
        return serialized

def add_graph_prefix(graph, outputs, prefix):
    # Change the node IDs and any internal links
    new_graph = {}
    for node_id, node_info in graph.items():
        # Make sure the added nodes have unique IDs
        new_node_id = prefix + node_id
        new_node = { "class_type": node_info["class_type"], "inputs": {} }
        for input_name, input_value in node_info.get("inputs", {}).items():
            if is_link(input_value):
                new_node["inputs"][input_name] = [prefix + input_value[0], input_value[1]]
            else:
                new_node["inputs"][input_name] = input_value
        new_graph[new_node_id] = new_node

    # Change the node IDs in the outputs
    new_outputs = []
    for n in range(len(outputs)):
        output = outputs[n]
        if is_link(output):
            new_outputs.append([prefix + output[0], output[1]])
        else:
            new_outputs.append(output)

    return new_graph, tuple(new_outputs)
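A short usage sketch of GraphBuilder as defined above: node IDs are auto-allocated under the prefix, out() produces the [node_id, socket] link form that is_link recognizes, and finalize() serializes to the prompt format. The class_type and input names ("LoadImage", "ImageScale", "example.png") are illustrative; GraphBuilder itself does not validate them:

from comfy_execution.graph_utils import GraphBuilder

g = GraphBuilder(prefix="demo.")
load = g.node("LoadImage", image="example.png")   # class/input names are illustrative
scale = g.node("ImageScale", image=load.out(0), width=512, height=512)
print(g.finalize())
# {'demo.1': {'class_type': 'LoadImage', 'inputs': {'image': 'example.png'}},
#  'demo.2': {'class_type': 'ImageScale', 'inputs': {'image': ['demo.1', 0], 'width': 512, 'height': 512}}}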
comfy_execution/validation.py
ADDED
@@ -0,0 +1,39 @@
from __future__ import annotations


def validate_node_input(
    received_type: str, input_type: str, strict: bool = False
) -> bool:
    """
    received_type and input_type are both strings of the form "T1,T2,...".

    If strict is True, the input_type must contain the received_type.
    For example, if received_type is "STRING" and input_type is "STRING,INT",
    this will return True. But if received_type is "STRING,INT" and input_type is
    "INT", this will return False.

    If strict is False, the input_type must have overlap with the received_type.
    For example, if received_type is "STRING,BOOLEAN" and input_type is "STRING,INT",
    this will return True.

    Supports pre-union type extension behaviour of ``__ne__`` overrides.
    """
    # If the types are exactly the same, we can return immediately
    # Use pre-union behaviour: inverse of `__ne__`
    if not received_type != input_type:
        return True

    # Not equal, and not strings
    if not isinstance(received_type, str) or not isinstance(input_type, str):
        return False

    # Split the type strings into sets for comparison
    received_types = set(t.strip() for t in received_type.split(","))
    input_types = set(t.strip() for t in input_type.split(","))

    if strict:
        # In strict mode, all received types must be in the input types
        return received_types.issubset(input_types)
    else:
        # In non-strict mode, there must be at least one type in common
        return len(received_types.intersection(input_types)) > 0
comfyui_version.py
ADDED
@@ -0,0 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.12"
cuda_malloc.py
ADDED
@@ -0,0 +1,90 @@
import os
import importlib.util
from comfy.cli_args import args
import subprocess

#Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import.
def get_gpu_names():
    if os.name == 'nt':
        import ctypes

        # Define necessary C structures and types
        class DISPLAY_DEVICEA(ctypes.Structure):
            _fields_ = [
                ('cb', ctypes.c_ulong),
                ('DeviceName', ctypes.c_char * 32),
                ('DeviceString', ctypes.c_char * 128),
                ('StateFlags', ctypes.c_ulong),
                ('DeviceID', ctypes.c_char * 128),
                ('DeviceKey', ctypes.c_char * 128)
            ]

        # Load user32.dll
        user32 = ctypes.windll.user32

        # Call EnumDisplayDevicesA
        def enum_display_devices():
            device_info = DISPLAY_DEVICEA()
            device_info.cb = ctypes.sizeof(device_info)
            device_index = 0
            gpu_names = set()

            while user32.EnumDisplayDevicesA(None, device_index, ctypes.byref(device_info), 0):
                device_index += 1
                gpu_names.add(device_info.DeviceString.decode('utf-8'))
            return gpu_names
        return enum_display_devices()
    else:
        gpu_names = set()
        out = subprocess.check_output(['nvidia-smi', '-L'])
        for l in out.split(b'\n'):
            if len(l) > 0:
                gpu_names.add(l.decode('utf-8').split(' (UUID')[0])
        return gpu_names

blacklist = {"GeForce GTX TITAN X", "GeForce GTX 980", "GeForce GTX 970", "GeForce GTX 960", "GeForce GTX 950", "GeForce 945M",
             "GeForce 940M", "GeForce 930M", "GeForce 920M", "GeForce 910M", "GeForce GTX 750", "GeForce GTX 745", "Quadro K620",
             "Quadro K1200", "Quadro K2200", "Quadro M500", "Quadro M520", "Quadro M600", "Quadro M620", "Quadro M1000",
             "Quadro M1200", "Quadro M2000", "Quadro M2200", "Quadro M3000", "Quadro M4000", "Quadro M5000", "Quadro M5500", "Quadro M6000",
             "GeForce MX110", "GeForce MX130", "GeForce 830M", "GeForce 840M", "GeForce GTX 850M", "GeForce GTX 860M",
             "GeForce GTX 1650", "GeForce GTX 1630", "Tesla M4", "Tesla M6", "Tesla M10", "Tesla M40", "Tesla M60"
             }

def cuda_malloc_supported():
    try:
        names = get_gpu_names()
    except:
        names = set()
    for x in names:
        if "NVIDIA" in x:
            for b in blacklist:
                if b in x:
                    return False
    return True


if not args.cuda_malloc:
    try:
        version = ""
        torch_spec = importlib.util.find_spec("torch")
        for folder in torch_spec.submodule_search_locations:
            ver_file = os.path.join(folder, "version.py")
            if os.path.isfile(ver_file):
                spec = importlib.util.spec_from_file_location("torch_version_import", ver_file)
                module = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(module)
                version = module.__version__
        if int(version[0]) >= 2: #enable by default for torch version 2.0 and up
            args.cuda_malloc = cuda_malloc_supported()
    except:
        pass


if args.cuda_malloc and not args.disable_cuda_malloc:
    env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
    if env_var is None:
        env_var = "backend:cudaMallocAsync"
    else:
        env_var += ",backend:cudaMallocAsync"

    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
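The final block appends to PYTORCH_CUDA_ALLOC_CONF rather than overwriting it, so allocator options a user has already exported survive. The same composition in isolation (the max_split_size_mb value is just an example of a pre-existing user setting):

import os

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"   # pre-existing user setting (example)
env_var = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", None)
if env_var is None:
    env_var = "backend:cudaMallocAsync"
else:
    env_var += ",backend:cudaMallocAsync"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = env_var
print(os.environ["PYTORCH_CUDA_ALLOC_CONF"])  # max_split_size_mb:128,backend:cudaMallocAsync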
extra_model_paths.yaml.example
ADDED
@@ -0,0 +1,47 @@
#Rename this to extra_model_paths.yaml and ComfyUI will load it


#config for a1111 ui
#all you have to do is change the base_path to where yours is installed
a111:
    base_path: path/to/stable-diffusion-webui/

    checkpoints: models/Stable-diffusion
    configs: models/Stable-diffusion
    vae: models/VAE
    loras: |
         models/Lora
         models/LyCORIS
    upscale_models: |
                  models/ESRGAN
                  models/RealESRGAN
                  models/SwinIR
    embeddings: embeddings
    hypernetworks: models/hypernetworks
    controlnet: models/ControlNet

#config for comfyui
#your base path should be either an existing comfy install or a central folder where you store all of your models, loras, etc.

#comfyui:
#     base_path: path/to/comfyui/
#     # You can use is_default to mark that these folders should be listed first, and used as the default dirs for eg downloads
#     #is_default: true
#     checkpoints: models/checkpoints/
#     clip: models/clip/
#     clip_vision: models/clip_vision/
#     configs: models/configs/
#     controlnet: models/controlnet/
#     diffusion_models: |
#                  models/diffusion_models
#                  models/unet
#     embeddings: models/embeddings/
#     loras: models/loras/
#     upscale_models: models/upscale_models/
#     vae: models/vae/

#other_ui:
#    base_path: path/to/ui
#    checkpoints: models/checkpoints
#    gligen: models/gligen
#    custom_nodes: path/custom_nodes
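For reference, this file takes effect through utils.extra_config.load_extra_path_config. A sketch of how main.py (later in this commit) wires it up when the renamed file sits next to the entry point; extra configs can also be loaded explicitly from other locations:

import os
import utils.extra_config

# Load an extra model-paths config if it exists (ComfyUI does this automatically
# for extra_model_paths.yaml placed next to main.py).
config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
if os.path.isfile(config_path):
    utils.extra_config.load_extra_path_config(config_path)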
fix_torch.py
ADDED
@@ -0,0 +1,28 @@
import importlib.util
import shutil
import os
import ctypes
import logging


def fix_pytorch_libomp():
    """
    Fix PyTorch libomp DLL issue on Windows by copying the correct DLL file if needed.
    """
    torch_spec = importlib.util.find_spec("torch")
    for folder in torch_spec.submodule_search_locations:
        lib_folder = os.path.join(folder, "lib")
        test_file = os.path.join(lib_folder, "fbgemm.dll")
        dest = os.path.join(lib_folder, "libomp140.x86_64.dll")
        if os.path.exists(dest):
            break

        with open(test_file, "rb") as f:
            contents = f.read()
            if b"libomp140.x86_64.dll" not in contents:
                break
        try:
            ctypes.cdll.LoadLibrary(test_file)
        except FileNotFoundError:
            logging.warning("Detected pytorch version with libomp issue, patching.")
            shutil.copyfile(os.path.join(lib_folder, "libiomp5md.dll"), dest)
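The patch is meant to run once at startup, guarded and best-effort; main.py (later in this commit) invokes it like this on Windows standalone builds:

if args.windows_standalone_build:
    try:
        from fix_torch import fix_pytorch_libomp
        fix_pytorch_libomp()
    except:
        pass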
folder_paths.py
ADDED
@@ -0,0 +1,385 @@
from __future__ import annotations

import os
import time
import mimetypes
import logging
from typing import Literal
from collections.abc import Collection

supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}

folder_names_and_paths: dict[str, tuple[list[str], set[str]]] = {}

base_path = os.path.dirname(os.path.realpath(__file__))
models_dir = os.path.join(base_path, "models")
folder_names_and_paths["checkpoints"] = ([os.path.join(models_dir, "checkpoints")], supported_pt_extensions)
folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".yaml"])

folder_names_and_paths["loras"] = ([os.path.join(models_dir, "loras")], supported_pt_extensions)
folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt_extensions)
folder_names_and_paths["text_encoders"] = ([os.path.join(models_dir, "text_encoders"), os.path.join(models_dir, "clip")], supported_pt_extensions)
folder_names_and_paths["diffusion_models"] = ([os.path.join(models_dir, "unet"), os.path.join(models_dir, "diffusion_models")], supported_pt_extensions)
folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions)
folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions)
folder_names_and_paths["diffusers"] = ([os.path.join(models_dir, "diffusers")], ["folder"])
folder_names_and_paths["vae_approx"] = ([os.path.join(models_dir, "vae_approx")], supported_pt_extensions)

folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions)
folder_names_and_paths["gligen"] = ([os.path.join(models_dir, "gligen")], supported_pt_extensions)

folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions)

folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes")], set())

folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions)

folder_names_and_paths["photomaker"] = ([os.path.join(models_dir, "photomaker")], supported_pt_extensions)

folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""})

output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user")

filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {}

class CacheHelper:
    """
    Helper class for managing file list cache data.
    """
    def __init__(self):
        self.cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
        self.active = False

    def get(self, key: str, default=None) -> tuple[list[str], dict[str, float], float]:
        if not self.active:
            return default
        return self.cache.get(key, default)

    def set(self, key: str, value: tuple[list[str], dict[str, float], float]) -> None:
        if self.active:
            self.cache[key] = value

    def clear(self):
        self.cache.clear()

    def __enter__(self):
        self.active = True
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.active = False
        self.clear()

cache_helper = CacheHelper()

extension_mimetypes_cache = {
    "webp" : "image",
}

def map_legacy(folder_name: str) -> str:
    legacy = {"unet": "diffusion_models",
              "clip": "text_encoders"}
    return legacy.get(folder_name, folder_name)

if not os.path.exists(input_directory):
    try:
        os.makedirs(input_directory)
    except:
        logging.error("Failed to create input directory")

def set_output_directory(output_dir: str) -> None:
    global output_directory
    output_directory = output_dir

def set_temp_directory(temp_dir: str) -> None:
    global temp_directory
    temp_directory = temp_dir

def set_input_directory(input_dir: str) -> None:
    global input_directory
    input_directory = input_dir

def get_output_directory() -> str:
    global output_directory
    return output_directory

def get_temp_directory() -> str:
    global temp_directory
    return temp_directory

def get_input_directory() -> str:
    global input_directory
    return input_directory

def get_user_directory() -> str:
    return user_directory

def set_user_directory(user_dir: str) -> None:
    global user_directory
    user_directory = user_dir


#NOTE: used in http server so don't put folders that should not be accessed remotely
def get_directory_by_type(type_name: str) -> str | None:
    if type_name == "output":
        return get_output_directory()
    if type_name == "temp":
        return get_temp_directory()
    if type_name == "input":
        return get_input_directory()
    return None

def filter_files_content_types(files: list[str], content_types: Literal["image", "video", "audio"]) -> list[str]:
    """
    Example:
        files = os.listdir(folder_paths.get_input_directory())
        filter_files_content_types(files, ["image", "audio", "video"])
    """
    global extension_mimetypes_cache
    result = []
    for file in files:
        extension = file.split('.')[-1]
        if extension not in extension_mimetypes_cache:
            mime_type, _ = mimetypes.guess_type(file, strict=False)
            if not mime_type:
                continue
            content_type = mime_type.split('/')[0]
            extension_mimetypes_cache[extension] = content_type
        else:
            content_type = extension_mimetypes_cache[extension]

        if content_type in content_types:
            result.append(file)
    return result

# determine base_dir rely on annotation if name is 'filename.ext [annotation]' format
# otherwise use default_path as base_dir
def annotated_filepath(name: str) -> tuple[str, str | None]:
    if name.endswith("[output]"):
        base_dir = get_output_directory()
        name = name[:-9]
    elif name.endswith("[input]"):
        base_dir = get_input_directory()
        name = name[:-8]
    elif name.endswith("[temp]"):
        base_dir = get_temp_directory()
        name = name[:-7]
    else:
        return name, None

    return name, base_dir


def get_annotated_filepath(name: str, default_dir: str | None=None) -> str:
    name, base_dir = annotated_filepath(name)

    if base_dir is None:
        if default_dir is not None:
            base_dir = default_dir
        else:
            base_dir = get_input_directory()  # fallback path

    return os.path.join(base_dir, name)


def exists_annotated_filepath(name) -> bool:
    name, base_dir = annotated_filepath(name)

    if base_dir is None:
        base_dir = get_input_directory()  # fallback path

    filepath = os.path.join(base_dir, name)
    return os.path.exists(filepath)


def add_model_folder_path(folder_name: str, full_folder_path: str, is_default: bool = False) -> None:
    global folder_names_and_paths
    folder_name = map_legacy(folder_name)
    if folder_name in folder_names_and_paths:
        paths, _exts = folder_names_and_paths[folder_name]
        if full_folder_path in paths:
            if is_default and paths[0] != full_folder_path:
                # If the path to the folder is not the first in the list, move it to the beginning.
                paths.remove(full_folder_path)
                paths.insert(0, full_folder_path)
        else:
            if is_default:
                paths.insert(0, full_folder_path)
            else:
                paths.append(full_folder_path)
    else:
        folder_names_and_paths[folder_name] = ([full_folder_path], set())

def get_folder_paths(folder_name: str) -> list[str]:
    folder_name = map_legacy(folder_name)
    return folder_names_and_paths[folder_name][0][:]

def recursive_search(directory: str, excluded_dir_names: list[str] | None=None) -> tuple[list[str], dict[str, float]]:
    if not os.path.isdir(directory):
        return [], {}

    if excluded_dir_names is None:
        excluded_dir_names = []

    result = []
    dirs = {}

    # Attempt to add the initial directory to dirs with error handling
    try:
        dirs[directory] = os.path.getmtime(directory)
    except FileNotFoundError:
        logging.warning(f"Warning: Unable to access {directory}. Skipping this path.")

    logging.debug("recursive file list on directory {}".format(directory))
    dirpath: str
    subdirs: list[str]
    filenames: list[str]

    for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
        subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
        for file_name in filenames:
            try:
                relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
                result.append(relative_path)
            except:
                logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.")
                continue

        for d in subdirs:
            path: str = os.path.join(dirpath, d)
            try:
                dirs[path] = os.path.getmtime(path)
            except FileNotFoundError:
                logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
                continue
    logging.debug("found {} files".format(len(result)))
    return result, dirs

def filter_files_extensions(files: Collection[str], extensions: Collection[str]) -> list[str]:
    return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions or len(extensions) == 0, files)))



def get_full_path(folder_name: str, filename: str) -> str | None:
    global folder_names_and_paths
    folder_name = map_legacy(folder_name)
    if folder_name not in folder_names_and_paths:
        return None
    folders = folder_names_and_paths[folder_name]
    filename = os.path.relpath(os.path.join("/", filename), "/")
    for x in folders[0]:
        full_path = os.path.join(x, filename)
        if os.path.isfile(full_path):
            return full_path
        elif os.path.islink(full_path):
            logging.warning("WARNING path {} exists but doesn't link anywhere, skipping.".format(full_path))

    return None


def get_full_path_or_raise(folder_name: str, filename: str) -> str:
    full_path = get_full_path(folder_name, filename)
    if full_path is None:
        raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found.")
    return full_path


def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:
    folder_name = map_legacy(folder_name)
    global folder_names_and_paths
    output_list = set()
    folders = folder_names_and_paths[folder_name]
    output_folders = {}
    for x in folders[0]:
        files, folders_all = recursive_search(x, excluded_dir_names=[".git"])
        output_list.update(filter_files_extensions(files, folders[1]))
        output_folders = {**output_folders, **folders_all}

    return sorted(list(output_list)), output_folders, time.perf_counter()

def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float] | None:
    strong_cache = cache_helper.get(folder_name)
    if strong_cache is not None:
        return strong_cache

    global filename_list_cache
    global folder_names_and_paths
    folder_name = map_legacy(folder_name)
    if folder_name not in filename_list_cache:
        return None
    out = filename_list_cache[folder_name]

    for x in out[1]:
        time_modified = out[1][x]
        folder = x
        if os.path.getmtime(folder) != time_modified:
            return None

    folders = folder_names_and_paths[folder_name]
    for x in folders[0]:
        if os.path.isdir(x):
            if x not in out[1]:
                return None

    return out

def get_filename_list(folder_name: str) -> list[str]:
    folder_name = map_legacy(folder_name)
    out = cached_filename_list_(folder_name)
    if out is None:
        out = get_filename_list_(folder_name)
        global filename_list_cache
        filename_list_cache[folder_name] = out
    cache_helper.set(folder_name, out)
    return list(out[0])

def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, image_height=0) -> tuple[str, str, int, str, str]:
    def map_filename(filename: str) -> tuple[int, str]:
        prefix_len = len(os.path.basename(filename_prefix))
        prefix = filename[:prefix_len + 1]
        try:
            digits = int(filename[prefix_len + 1:].split('_')[0])
        except:
            digits = 0
        return digits, prefix

    def compute_vars(input: str, image_width: int, image_height: int) -> str:
        input = input.replace("%width%", str(image_width))
        input = input.replace("%height%", str(image_height))
        now = time.localtime()
        input = input.replace("%year%", str(now.tm_year))
        input = input.replace("%month%", str(now.tm_mon).zfill(2))
        input = input.replace("%day%", str(now.tm_mday).zfill(2))
        input = input.replace("%hour%", str(now.tm_hour).zfill(2))
        input = input.replace("%minute%", str(now.tm_min).zfill(2))
        input = input.replace("%second%", str(now.tm_sec).zfill(2))
        return input

    if "%" in filename_prefix:
        filename_prefix = compute_vars(filename_prefix, image_width, image_height)

    subfolder = os.path.dirname(os.path.normpath(filename_prefix))
    filename = os.path.basename(os.path.normpath(filename_prefix))

    full_output_folder = os.path.join(output_dir, subfolder)

    if os.path.commonpath((output_dir, os.path.abspath(full_output_folder))) != output_dir:
        err = "**** ERROR: Saving image outside the output folder is not allowed." + \
              "\n full_output_folder: " + os.path.abspath(full_output_folder) + \
              "\n         output_dir: " + output_dir + \
              "\n         commonpath: " + os.path.commonpath((output_dir, os.path.abspath(full_output_folder)))
        logging.error(err)
        raise Exception(err)

    try:
        counter = max(filter(lambda a: os.path.normcase(a[1][:-1]) == os.path.normcase(filename) and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1
    except ValueError:
        counter = 1
    except FileNotFoundError:
        os.makedirs(full_output_folder, exist_ok=True)
        counter = 1
    return full_output_folder, filename, counter, subfolder, filename_prefix
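A brief usage sketch of the registry above; the directory and filenames are illustrative, not part of the module:

import folder_paths

# Register an extra checkpoints directory (path and filenames are illustrative):
folder_paths.add_model_folder_path("checkpoints", "/data/models/checkpoints")
print(folder_paths.get_folder_paths("checkpoints"))     # now includes the new path
print(folder_paths.get_filename_list("checkpoints"))    # e.g. ['sd15.safetensors', ...]
print(folder_paths.get_full_path("checkpoints", "sd15.safetensors"))

# Annotated names carry their base directory in the suffix:
print(folder_paths.get_annotated_filepath("example.png [output]"))  # resolves under the output dir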
latent_preview.py
ADDED
@@ -0,0 +1,105 @@
import torch
from PIL import Image
from comfy.cli_args import args, LatentPreviewMethod
from comfy.taesd.taesd import TAESD
import comfy.model_management
import folder_paths
import comfy.utils
import logging

MAX_PREVIEW_RESOLUTION = args.preview_size

def preview_to_image(latent_image):
    latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1)  # change scale from -1..1 to 0..1
                     .mul(0xFF)  # to 0..255
                     ).to(device="cpu", dtype=torch.uint8, non_blocking=comfy.model_management.device_supports_non_blocking(latent_image.device))

    return Image.fromarray(latents_ubyte.numpy())

class LatentPreviewer:
    def decode_latent_to_preview(self, x0):
        pass

    def decode_latent_to_preview_image(self, preview_format, x0):
        preview_image = self.decode_latent_to_preview(x0)
        return ("JPEG", preview_image, MAX_PREVIEW_RESOLUTION)

class TAESDPreviewerImpl(LatentPreviewer):
    def __init__(self, taesd):
        self.taesd = taesd

    def decode_latent_to_preview(self, x0):
        x_sample = self.taesd.decode(x0[:1])[0].movedim(0, 2)
        return preview_to_image(x_sample)


class Latent2RGBPreviewer(LatentPreviewer):
    def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None):
        self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1)
        self.latent_rgb_factors_bias = None
        if latent_rgb_factors_bias is not None:
            self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu")

    def decode_latent_to_preview(self, x0):
        self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
        if self.latent_rgb_factors_bias is not None:
            self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device)

        if x0.ndim == 5:
            x0 = x0[0, :, 0]
        else:
            x0 = x0[0]

        latent_image = torch.nn.functional.linear(x0.movedim(0, -1), self.latent_rgb_factors, bias=self.latent_rgb_factors_bias)
        # latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors

        return preview_to_image(latent_image)


def get_previewer(device, latent_format):
    previewer = None
    method = args.preview_method
    if method != LatentPreviewMethod.NoPreviews:
        # TODO previewer methods
        taesd_decoder_path = None
        if latent_format.taesd_decoder_name is not None:
            taesd_decoder_path = next(
                (fn for fn in folder_paths.get_filename_list("vae_approx")
                 if fn.startswith(latent_format.taesd_decoder_name)),
                ""
            )
            taesd_decoder_path = folder_paths.get_full_path("vae_approx", taesd_decoder_path)

        if method == LatentPreviewMethod.Auto:
            method = LatentPreviewMethod.Latent2RGB

        if method == LatentPreviewMethod.TAESD:
            if taesd_decoder_path:
                taesd = TAESD(None, taesd_decoder_path, latent_channels=latent_format.latent_channels).to(device)
                previewer = TAESDPreviewerImpl(taesd)
            else:
                logging.warning("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name))

        if previewer is None:
            if latent_format.latent_rgb_factors is not None:
                previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias)
    return previewer

def prepare_callback(model, steps, x0_output_dict=None):
    preview_format = "JPEG"
    if preview_format not in ["JPEG", "PNG"]:
        preview_format = "JPEG"

    previewer = get_previewer(model.load_device, model.model.latent_format)

    pbar = comfy.utils.ProgressBar(steps)
    def callback(step, x0, x, total_steps):
        if x0_output_dict is not None:
            x0_output_dict["x0"] = x0

        preview_bytes = None
        if previewer:
            preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)
        pbar.update_absolute(step + 1, total_steps, preview_bytes)
    return callback
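preview_to_image maps values from [-1, 1] into 8-bit RGB before handing them to PIL. The same arithmetic in isolation, with random data standing in for a decoded latent (the shape and output path are illustrative):

import torch
from PIL import Image

latent_image = torch.rand(64, 64, 3) * 2.0 - 1.0           # fake decoded latent in [-1, 1]
latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1)   # -1..1 -> 0..1
                 .mul(0xFF)                                 # 0..1 -> 0..255
                 ).to(dtype=torch.uint8)
img = Image.fromarray(latents_ubyte.numpy())
img.save("preview.jpg")                                     # illustrative output path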
main.py
ADDED
@@ -0,0 +1,301 @@
import comfy.options
comfy.options.enable_args_parsing()

import os
import importlib.util
import folder_paths
import time
from comfy.cli_args import args
from app.logger import setup_logger
import itertools
import utils.extra_config
import logging

if __name__ == "__main__":
    #NOTE: These do not do anything on core ComfyUI which should already have no communication with the internet, they are for custom nodes.
    os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
    os.environ['DO_NOT_TRACK'] = '1'


    setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)

def apply_custom_paths():
    # extra model paths
    extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
    if os.path.isfile(extra_model_paths_config_path):
        utils.extra_config.load_extra_path_config(extra_model_paths_config_path)

    if args.extra_model_paths_config:
        for config_path in itertools.chain(*args.extra_model_paths_config):
            utils.extra_config.load_extra_path_config(config_path)

    # --output-directory, --input-directory, --user-directory
    if args.output_directory:
        output_dir = os.path.abspath(args.output_directory)
        logging.info(f"Setting output directory to: {output_dir}")
        folder_paths.set_output_directory(output_dir)

    # These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes
    folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints"))
    folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
    folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
    folder_paths.add_model_folder_path("diffusion_models",
                                       os.path.join(folder_paths.get_output_directory(), "diffusion_models"))
    folder_paths.add_model_folder_path("loras", os.path.join(folder_paths.get_output_directory(), "loras"))

    if args.input_directory:
        input_dir = os.path.abspath(args.input_directory)
        logging.info(f"Setting input directory to: {input_dir}")
        folder_paths.set_input_directory(input_dir)

    if args.user_directory:
        user_dir = os.path.abspath(args.user_directory)
        logging.info(f"Setting user directory to: {user_dir}")
        folder_paths.set_user_directory(user_dir)


def execute_prestartup_script():
    def execute_script(script_path):
        module_name = os.path.splitext(script_path)[0]
        try:
            spec = importlib.util.spec_from_file_location(module_name, script_path)
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            return True
        except Exception as e:
            logging.error(f"Failed to execute startup-script: {script_path} / {e}")
        return False

    if args.disable_all_custom_nodes:
        return

    node_paths = folder_paths.get_folder_paths("custom_nodes")
    for custom_node_path in node_paths:
        possible_modules = os.listdir(custom_node_path)
        node_prestartup_times = []

        for possible_module in possible_modules:
            module_path = os.path.join(custom_node_path, possible_module)
            if os.path.isfile(module_path) or module_path.endswith(".disabled") or module_path == "__pycache__":
                continue

            script_path = os.path.join(module_path, "prestartup_script.py")
            if os.path.exists(script_path):
                time_before = time.perf_counter()
                success = execute_script(script_path)
                node_prestartup_times.append((time.perf_counter() - time_before, module_path, success))
    if len(node_prestartup_times) > 0:
        logging.info("\nPrestartup times for custom nodes:")
        for n in sorted(node_prestartup_times):
            if n[2]:
                import_message = ""
            else:
                import_message = " (PRESTARTUP FAILED)"
            logging.info("{:6.1f} seconds{}: {}".format(n[0], import_message, n[1]))
        logging.info("")

apply_custom_paths()
execute_prestartup_script()


# Main code
import asyncio
import shutil
import threading
import gc


if os.name == "nt":
    logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())

if __name__ == "__main__":
    if args.cuda_device is not None:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
        os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device)
        logging.info("Set cuda device to: {}".format(args.cuda_device))

    if args.oneapi_device_selector is not None:
        os.environ['ONEAPI_DEVICE_SELECTOR'] = args.oneapi_device_selector
        logging.info("Set oneapi device selector to: {}".format(args.oneapi_device_selector))

    if args.deterministic:
        if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
            os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"

    import cuda_malloc

    if args.windows_standalone_build:
        try:
            from fix_torch import fix_pytorch_libomp
            fix_pytorch_libomp()
        except:
            pass

import comfy.utils

import execution
import server
from server import BinaryEventTypes
import nodes
import comfy.model_management

def cuda_malloc_warning():
    device = comfy.model_management.get_torch_device()
    device_name = comfy.model_management.get_torch_device_name(device)
    cuda_malloc_warning = False
    if "cudaMallocAsync" in device_name:
        for b in cuda_malloc.blacklist:
            if b in device_name:
                cuda_malloc_warning = True
    if cuda_malloc_warning:
        logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")


def prompt_worker(q, server_instance):
    current_time: float = 0.0
    e = execution.PromptExecutor(server_instance, lru_size=args.cache_lru)
    last_gc_collect = 0
    need_gc = False
    gc_collect_interval = 10.0

    while True:
        timeout = 1000.0
        if need_gc:
            timeout = max(gc_collect_interval - (current_time - last_gc_collect), 0.0)

        queue_item = q.get(timeout=timeout)
        if queue_item is not None:
            item, item_id = queue_item
            execution_start_time = time.perf_counter()
            prompt_id = item[1]
            server_instance.last_prompt_id = prompt_id

            e.execute(item[2], prompt_id, item[3], item[4])
            need_gc = True
            q.task_done(item_id,
                        e.history_result,
                        status=execution.PromptQueue.ExecutionStatus(
                            status_str='success' if e.success else 'error',
                            completed=e.success,
                            messages=e.status_messages))
            if server_instance.client_id is not None:
                server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)

            current_time = time.perf_counter()
            execution_time = current_time - execution_start_time
            logging.info("Prompt executed in {:.2f} seconds".format(execution_time))

        flags = q.get_flags()
        free_memory = flags.get("free_memory", False)

        if flags.get("unload_models", free_memory):
            comfy.model_management.unload_all_models()
            need_gc = True
            last_gc_collect = 0

        if free_memory:
            e.reset()
            need_gc = True
            last_gc_collect = 0

        if need_gc:
            current_time = time.perf_counter()
            if (current_time - last_gc_collect) > gc_collect_interval:
                gc.collect()
                comfy.model_management.soft_empty_cache()
                last_gc_collect = current_time
                need_gc = False


async def run(server_instance, address='', port=8188, verbose=True, call_on_start=None):
    addresses = []
    for addr in address.split(","):
        addresses.append((addr, port))
    await asyncio.gather(
        server_instance.start_multi_address(addresses, call_on_start, verbose), server_instance.publish_loop()
    )


def hijack_progress(server_instance):
    def hook(value, total, preview_image):
        comfy.model_management.throw_exception_if_processing_interrupted()
        progress = {"value": value, "max": total, "prompt_id": server_instance.last_prompt_id, "node": server_instance.last_node_id}

        server_instance.send_sync("progress", progress, server_instance.client_id)
        if preview_image is not None:
            server_instance.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server_instance.client_id)

    comfy.utils.set_progress_bar_global_hook(hook)
|
229 |
+
|
230 |
+
|
231 |
+
def cleanup_temp():
|
232 |
+
temp_dir = folder_paths.get_temp_directory()
|
233 |
+
if os.path.exists(temp_dir):
|
234 |
+
shutil.rmtree(temp_dir, ignore_errors=True)
|
235 |
+
|
236 |
+
|
237 |
+
def start_comfyui(asyncio_loop=None):
|
238 |
+
"""
|
239 |
+
Starts the ComfyUI server using the provided asyncio event loop or creates a new one.
|
240 |
+
Returns the event loop, server instance, and a function to start the server asynchronously.
|
241 |
+
"""
|
242 |
+
if args.temp_directory:
|
243 |
+
temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
|
244 |
+
logging.info(f"Setting temp directory to: {temp_dir}")
|
245 |
+
folder_paths.set_temp_directory(temp_dir)
|
246 |
+
cleanup_temp()
|
247 |
+
|
248 |
+
if args.windows_standalone_build:
|
249 |
+
try:
|
250 |
+
import new_updater
|
251 |
+
new_updater.update_windows_updater()
|
252 |
+
except:
|
253 |
+
pass
|
254 |
+
|
255 |
+
if not asyncio_loop:
|
256 |
+
asyncio_loop = asyncio.new_event_loop()
|
257 |
+
asyncio.set_event_loop(asyncio_loop)
|
258 |
+
prompt_server = server.PromptServer(asyncio_loop)
|
259 |
+
q = execution.PromptQueue(prompt_server)
|
260 |
+
|
261 |
+
nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes)
|
262 |
+
|
263 |
+
cuda_malloc_warning()
|
264 |
+
|
265 |
+
prompt_server.add_routes()
|
266 |
+
hijack_progress(prompt_server)
|
267 |
+
|
268 |
+
threading.Thread(target=prompt_worker, daemon=True, args=(q, prompt_server,)).start()
|
269 |
+
|
270 |
+
if args.quick_test_for_ci:
|
271 |
+
exit(0)
|
272 |
+
|
273 |
+
os.makedirs(folder_paths.get_temp_directory(), exist_ok=True)
|
274 |
+
call_on_start = None
|
275 |
+
if args.auto_launch:
|
276 |
+
def startup_server(scheme, address, port):
|
277 |
+
import webbrowser
|
278 |
+
if os.name == 'nt' and address == '0.0.0.0':
|
279 |
+
address = '127.0.0.1'
|
280 |
+
if ':' in address:
|
281 |
+
address = "[{}]".format(address)
|
282 |
+
webbrowser.open(f"{scheme}://{address}:{port}")
|
283 |
+
call_on_start = startup_server
|
284 |
+
|
285 |
+
async def start_all():
|
286 |
+
await prompt_server.setup()
|
287 |
+
await run(prompt_server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)
|
288 |
+
|
289 |
+
# Returning these so that other code can integrate with the ComfyUI loop and server
|
290 |
+
return asyncio_loop, prompt_server, start_all
|
291 |
+
|
292 |
+
|
293 |
+
if __name__ == "__main__":
|
294 |
+
# Running directly, just start ComfyUI.
|
295 |
+
event_loop, _, start_all_func = start_comfyui()
|
296 |
+
try:
|
297 |
+
event_loop.run_until_complete(start_all_func())
|
298 |
+
except KeyboardInterrupt:
|
299 |
+
logging.info("\nStopped server")
|
300 |
+
|
301 |
+
cleanup_temp()
|
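
Note: start_comfyui() deliberately hands back the event loop, the PromptServer instance, and the start_all coroutine so that host code can integrate with the server instead of relying on the __main__ block. A minimal embedding sketch, assuming ComfyUI's root is the working directory so main.py imports cleanly (the sketch itself is an assumption, not part of this upload):

    import main  # importing main.py applies custom paths and runs prestartup scripts

    event_loop, prompt_server, start_all = main.start_comfyui()
    try:
        # Blocks until the server stops; a host app could instead schedule
        # start_all() on an asyncio loop it already owns.
        event_loop.run_until_complete(start_all())
    except KeyboardInterrupt:
        main.cleanup_temp()
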
new_updater.py
ADDED
@@ -0,0 +1,35 @@
import os
import shutil

base_path = os.path.dirname(os.path.realpath(__file__))


def update_windows_updater():
    top_path = os.path.dirname(base_path)
    updater_path = os.path.join(base_path, ".ci/update_windows/update.py")
    bat_path = os.path.join(base_path, ".ci/update_windows/update_comfyui.bat")

    dest_updater_path = os.path.join(top_path, "update/update.py")
    dest_bat_path = os.path.join(top_path, "update/update_comfyui.bat")
    dest_bat_deps_path = os.path.join(top_path, "update/update_comfyui_and_python_dependencies.bat")

    try:
        with open(dest_bat_path, 'rb') as f:
            contents = f.read()
    except:
        return

    if not contents.startswith(b"..\\python_embeded\\python.exe .\\update.py"):
        return

    shutil.copy(updater_path, dest_updater_path)
    try:
        with open(dest_bat_deps_path, 'rb') as f:
            contents = f.read()
        contents = contents.replace(b'..\\python_embeded\\python.exe .\\update.py ..\\ComfyUI\\', b'call update_comfyui.bat nopause')
        with open(dest_bat_deps_path, 'wb') as f:
            f.write(contents)
    except:
        pass
    shutil.copy(bat_path, dest_bat_path)
    print("Updated the windows standalone package updater.")  # noqa: T201
node_helpers.py
ADDED
@@ -0,0 +1,37 @@
import hashlib

from comfy.cli_args import args

from PIL import ImageFile, UnidentifiedImageError

def conditioning_set_values(conditioning, values={}):
    c = []
    for t in conditioning:
        n = [t[0], t[1].copy()]
        for k in values:
            n[1][k] = values[k]
        c.append(n)

    return c

def pillow(fn, arg):
    prev_value = None
    try:
        x = fn(arg)
    except (OSError, UnidentifiedImageError, ValueError): #PIL issues #4472 and #2445, also fixes ComfyUI issue #3416
        prev_value = ImageFile.LOAD_TRUNCATED_IMAGES
        ImageFile.LOAD_TRUNCATED_IMAGES = True
        x = fn(arg)
    finally:
        if prev_value is not None:
            ImageFile.LOAD_TRUNCATED_IMAGES = prev_value
    return x

def hasher():
    hashfuncs = {
        "md5": hashlib.md5,
        "sha1": hashlib.sha1,
        "sha256": hashlib.sha256,
        "sha512": hashlib.sha512
    }
    return hashfuncs[args.default_hashing_function]
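
Note: pillow() retries a failing PIL call with ImageFile.LOAD_TRUNCATED_IMAGES temporarily enabled, restoring the previous value afterwards. A hypothetical usage sketch (the file path is illustrative, not from this upload):

    from PIL import Image
    import node_helpers

    # Retried with LOAD_TRUNCATED_IMAGES=True if the first call raises
    # OSError, UnidentifiedImageError or ValueError.
    img = node_helpers.pillow(Image.open, "input/example.png")
    img = node_helpers.pillow(img.convert, "RGB")
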
notebooks/comfyui_colab.ipynb
ADDED
@@ -0,0 +1,322 @@
{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aaaaaaaaaa"
      },
      "source": [
        "Git clone the repo and install the requirements. (ignore the pip errors about protobuf)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bbbbbbbbbb"
      },
      "outputs": [],
      "source": [
        "#@title Environment Setup\n",
        "\n",
        "\n",
        "OPTIONS = {}\n",
        "\n",
        "USE_GOOGLE_DRIVE = False  #@param {type:\"boolean\"}\n",
        "UPDATE_COMFY_UI = True  #@param {type:\"boolean\"}\n",
        "WORKSPACE = 'ComfyUI'\n",
        "OPTIONS['USE_GOOGLE_DRIVE'] = USE_GOOGLE_DRIVE\n",
        "OPTIONS['UPDATE_COMFY_UI'] = UPDATE_COMFY_UI\n",
        "\n",
        "if OPTIONS['USE_GOOGLE_DRIVE']:\n",
        "    !echo \"Mounting Google Drive...\"\n",
        "    %cd /\n",
        "\n",
        "    from google.colab import drive\n",
        "    drive.mount('/content/drive')\n",
        "\n",
        "    WORKSPACE = \"/content/drive/MyDrive/ComfyUI\"\n",
        "    %cd /content/drive/MyDrive\n",
        "\n",
        "![ ! -d $WORKSPACE ] && echo -= Initial setup ComfyUI =- && git clone https://github.com/comfyanonymous/ComfyUI\n",
        "%cd $WORKSPACE\n",
        "\n",
        "if OPTIONS['UPDATE_COMFY_UI']:\n",
        "  !echo -= Updating ComfyUI =-\n",
        "  !git pull\n",
        "\n",
        "!echo -= Install dependencies =-\n",
        "!pip install xformers!=0.0.18 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu121 --extra-index-url https://download.pytorch.org/whl/cu118 --extra-index-url https://download.pytorch.org/whl/cu117"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cccccccccc"
      },
      "source": [
        "Download some models/checkpoints/vae or custom comfyui nodes (uncomment the commands for the ones you want)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dddddddddd"
      },
      "outputs": [],
      "source": [
        "# Checkpoints\n",
        "\n",
        "### SDXL\n",
        "### I recommend these workflow examples: https://comfyanonymous.github.io/ComfyUI_examples/sdxl/\n",
        "\n",
        "#!wget -c https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors -P ./models/checkpoints/\n",
        "#!wget -c https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/resolve/main/sd_xl_refiner_1.0.safetensors -P ./models/checkpoints/\n",
        "\n",
        "# SDXL ReVision\n",
        "#!wget -c https://huggingface.co/comfyanonymous/clip_vision_g/resolve/main/clip_vision_g.safetensors -P ./models/clip_vision/\n",
        "\n",
        "# SD1.5\n",
        "!wget -c https://huggingface.co/Comfy-Org/stable-diffusion-v1-5-archive/resolve/main/v1-5-pruned-emaonly-fp16.safetensors -P ./models/checkpoints/\n",
        "\n",
        "# SD2\n",
        "#!wget -c https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.safetensors -P ./models/checkpoints/\n",
        "#!wget -c https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.safetensors -P ./models/checkpoints/\n",
        "\n",
        "# Some SD1.5 anime style\n",
        "#!wget -c https://huggingface.co/WarriorMama777/OrangeMixs/resolve/main/Models/AbyssOrangeMix2/AbyssOrangeMix2_hard.safetensors -P ./models/checkpoints/\n",
        "#!wget -c https://huggingface.co/WarriorMama777/OrangeMixs/resolve/main/Models/AbyssOrangeMix3/AOM3A1_orangemixs.safetensors -P ./models/checkpoints/\n",
        "#!wget -c https://huggingface.co/WarriorMama777/OrangeMixs/resolve/main/Models/AbyssOrangeMix3/AOM3A3_orangemixs.safetensors -P ./models/checkpoints/\n",
        "#!wget -c https://huggingface.co/Linaqruf/anything-v3.0/resolve/main/anything-v3-fp16-pruned.safetensors -P ./models/checkpoints/\n",
        "\n",
        "# Waifu Diffusion 1.5 (anime style SD2.x 768-v)\n",
        "#!wget -c https://huggingface.co/waifu-diffusion/wd-1-5-beta3/resolve/main/wd-illusion-fp16.safetensors -P ./models/checkpoints/\n",
        "\n",
        "\n",
        "# unCLIP models\n",
        "#!wget -c https://huggingface.co/comfyanonymous/illuminatiDiffusionV1_v11_unCLIP/resolve/main/illuminatiDiffusionV1_v11-unclip-h-fp16.safetensors -P ./models/checkpoints/\n",
        "#!wget -c https://huggingface.co/comfyanonymous/wd-1.5-beta2_unCLIP/resolve/main/wd-1-5-beta2-aesthetic-unclip-h-fp16.safetensors -P ./models/checkpoints/\n",
        "\n",
        "\n",
        "# VAE\n",
        "!wget -c https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.safetensors -P ./models/vae/\n",
        "#!wget -c https://huggingface.co/WarriorMama777/OrangeMixs/resolve/main/VAEs/orangemix.vae.pt -P ./models/vae/\n",
        "#!wget -c https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/vae/kl-f8-anime2.ckpt -P ./models/vae/\n",
        "\n",
        "\n",
        "# Loras\n",
        "#!wget -c https://civitai.com/api/download/models/10350 -O ./models/loras/theovercomer8sContrastFix_sd21768.safetensors #theovercomer8sContrastFix SD2.x 768-v\n",
        "#!wget -c https://civitai.com/api/download/models/10638 -O ./models/loras/theovercomer8sContrastFix_sd15.safetensors #theovercomer8sContrastFix SD1.x\n",
        "#!wget -c https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_offset_example-lora_1.0.safetensors -P ./models/loras/ #SDXL offset noise lora\n",
        "\n",
        "\n",
        "# T2I-Adapter\n",
        "#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_depth_sd14v1.pth -P ./models/controlnet/\n",
        "#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_seg_sd14v1.pth -P ./models/controlnet/\n",
        "#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_sketch_sd14v1.pth -P ./models/controlnet/\n",
        "#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_keypose_sd14v1.pth -P ./models/controlnet/\n",
        "#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_openpose_sd14v1.pth -P ./models/controlnet/\n",
        "#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_color_sd14v1.pth -P ./models/controlnet/\n",
        "#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_canny_sd14v1.pth -P ./models/controlnet/\n",
        "\n",
        "# T2I Styles Model\n",
        "#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_style_sd14v1.pth -P ./models/style_models/\n",
        "\n",
        "# CLIPVision model (needed for styles model)\n",
        "#!wget -c https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/pytorch_model.bin -O ./models/clip_vision/clip_vit14.bin\n",
        "\n",
        "\n",
        "# ControlNet\n",
        "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11e_sd15_ip2p_fp16.safetensors -P ./models/controlnet/\n",
        "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11e_sd15_shuffle_fp16.safetensors -P ./models/controlnet/\n",
        "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_canny_fp16.safetensors -P ./models/controlnet/\n",
        "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11f1p_sd15_depth_fp16.safetensors -P ./models/controlnet/\n",
        "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_inpaint_fp16.safetensors -P ./models/controlnet/\n",
        "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_lineart_fp16.safetensors -P ./models/controlnet/\n",
        "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_mlsd_fp16.safetensors -P ./models/controlnet/\n",
        "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_normalbae_fp16.safetensors -P ./models/controlnet/\n",
        "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_openpose_fp16.safetensors -P ./models/controlnet/\n",
        "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_scribble_fp16.safetensors -P ./models/controlnet/\n",
        "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_seg_fp16.safetensors -P ./models/controlnet/\n",
        "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_softedge_fp16.safetensors -P ./models/controlnet/\n",
        "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15s2_lineart_anime_fp16.safetensors -P ./models/controlnet/\n",
        "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11u_sd15_tile_fp16.safetensors -P ./models/controlnet/\n",
        "\n",
        "# ControlNet SDXL\n",
        "#!wget -c https://huggingface.co/stabilityai/control-lora/resolve/main/control-LoRAs-rank256/control-lora-canny-rank256.safetensors -P ./models/controlnet/\n",
        "#!wget -c https://huggingface.co/stabilityai/control-lora/resolve/main/control-LoRAs-rank256/control-lora-depth-rank256.safetensors -P ./models/controlnet/\n",
        "#!wget -c https://huggingface.co/stabilityai/control-lora/resolve/main/control-LoRAs-rank256/control-lora-recolor-rank256.safetensors -P ./models/controlnet/\n",
        "#!wget -c https://huggingface.co/stabilityai/control-lora/resolve/main/control-LoRAs-rank256/control-lora-sketch-rank256.safetensors -P ./models/controlnet/\n",
        "\n",
        "# Controlnet Preprocessor nodes by Fannovel16\n",
        "#!cd custom_nodes && git clone https://github.com/Fannovel16/comfy_controlnet_preprocessors; cd comfy_controlnet_preprocessors && python install.py\n",
        "\n",
        "\n",
        "# GLIGEN\n",
        "#!wget -c https://huggingface.co/comfyanonymous/GLIGEN_pruned_safetensors/resolve/main/gligen_sd14_textbox_pruned_fp16.safetensors -P ./models/gligen/\n",
        "\n",
        "\n",
        "# ESRGAN upscale model\n",
        "#!wget -c https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P ./models/upscale_models/\n",
        "#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x2.pth -P ./models/upscale_models/\n",
        "#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x4.pth -P ./models/upscale_models/\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kkkkkkkkkkkkkkk"
      },
      "source": [
        "### Run ComfyUI with cloudflared (Recommended Way)\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jjjjjjjjjjjjjj"
      },
      "outputs": [],
      "source": [
        "!wget https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64.deb\n",
        "!dpkg -i cloudflared-linux-amd64.deb\n",
        "\n",
        "import subprocess\n",
        "import threading\n",
        "import time\n",
        "import socket\n",
        "import urllib.request\n",
        "\n",
        "def iframe_thread(port):\n",
        "  while True:\n",
        "      time.sleep(0.5)\n",
        "      sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n",
        "      result = sock.connect_ex(('127.0.0.1', port))\n",
        "      if result == 0:\n",
        "        break\n",
        "      sock.close()\n",
        "  print(\"\\nComfyUI finished loading, trying to launch cloudflared (if it gets stuck here cloudflared is having issues)\\n\")\n",
        "\n",
        "  p = subprocess.Popen([\"cloudflared\", \"tunnel\", \"--url\", \"http://127.0.0.1:{}\".format(port)], stdout=subprocess.PIPE, stderr=subprocess.PIPE)\n",
        "  for line in p.stderr:\n",
        "    l = line.decode()\n",
        "    if \"trycloudflare.com \" in l:\n",
        "      print(\"This is the URL to access ComfyUI:\", l[l.find(\"http\"):], end='')\n",
        "    #print(l, end='')\n",
        "\n",
        "\n",
        "threading.Thread(target=iframe_thread, daemon=True, args=(8188,)).start()\n",
        "\n",
        "!python main.py --dont-print-server"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kkkkkkkkkkkkkk"
      },
      "source": [
        "### Run ComfyUI with localtunnel\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jjjjjjjjjjjjj"
      },
      "outputs": [],
      "source": [
        "!npm install -g localtunnel\n",
        "\n",
        "import subprocess\n",
        "import threading\n",
        "import time\n",
        "import socket\n",
        "import urllib.request\n",
        "\n",
        "def iframe_thread(port):\n",
        "  while True:\n",
        "      time.sleep(0.5)\n",
        "      sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n",
        "      result = sock.connect_ex(('127.0.0.1', port))\n",
        "      if result == 0:\n",
        "        break\n",
        "      sock.close()\n",
        "  print(\"\\nComfyUI finished loading, trying to launch localtunnel (if it gets stuck here localtunnel is having issues)\\n\")\n",
        "\n",
        "  print(\"The password/endpoint ip for localtunnel is:\", urllib.request.urlopen('https://ipv4.icanhazip.com').read().decode('utf8').strip(\"\\n\"))\n",
        "  p = subprocess.Popen([\"lt\", \"--port\", \"{}\".format(port)], stdout=subprocess.PIPE)\n",
        "  for line in p.stdout:\n",
        "    print(line.decode(), end='')\n",
        "\n",
        "\n",
        "threading.Thread(target=iframe_thread, daemon=True, args=(8188,)).start()\n",
        "\n",
        "!python main.py --dont-print-server"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gggggggggg"
      },
      "source": [
        "### Run ComfyUI with colab iframe (use only in case the previous way with localtunnel doesn't work)\n",
        "\n",
        "You should see the ui appear in an iframe. If you get a 403 error, it's your firefox settings or an extension that's messing things up.\n",
        "\n",
        "If you want to open it in another window use the link.\n",
        "\n",
        "Note that some UI features like live image previews won't work because the colab iframe blocks websockets."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hhhhhhhhhh"
      },
      "outputs": [],
      "source": [
        "import threading\n",
        "import time\n",
        "import socket\n",
        "\n",
        "def iframe_thread(port):\n",
        "  while True:\n",
        "      time.sleep(0.5)\n",
        "      sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n",
        "      result = sock.connect_ex(('127.0.0.1', port))\n",
        "      if result == 0:\n",
        "        break\n",
        "      sock.close()\n",
        "  from google.colab import output\n",
        "  output.serve_kernel_port_as_iframe(port, height=1024)\n",
        "  print(\"to open it in a window you can open this link here:\")\n",
        "  output.serve_kernel_port_as_window(port)\n",
        "\n",
        "threading.Thread(target=iframe_thread, daemon=True, args=(8188,)).start()\n",
        "\n",
        "!python main.py --dont-print-server"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "provenance": []
    },
    "gpuClass": "standard",
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
output/_output_images_will_be_put_here
ADDED
File without changes
pyproject.toml
ADDED
@@ -0,0 +1,23 @@
[project]
name = "ComfyUI"
version = "0.3.12"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"

[project.urls]
homepage = "https://www.comfy.org/"
repository = "https://github.com/comfyanonymous/ComfyUI"
documentation = "https://docs.comfy.org/"

[tool.ruff]
lint.select = [
  "S307", # suspicious-eval-usage
  "S102", # exec
  "T",    # print-usage
  "W",
  # The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names.
  # See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f
  "F",
]
exclude = ["*.ipynb"]
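
Note: with this [tool.ruff] table in place, linting applies only the selected rule families (eval/exec usage, print statements, pycodestyle warnings, and Pyflakes) and skips notebooks. These are standard Ruff invocations, not anything added by this upload:

    ruff check .        # lint with the selected rules
    ruff check . --fix  # additionally apply the available autofixes
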
pytest.ini
ADDED
@@ -0,0 +1,9 @@
[pytest]
markers =
    inference: mark as inference test (deselect with '-m "not inference"')
    execution: mark as execution test (deselect with '-m "not execution"')
testpaths =
    tests
    tests-unit
addopts = -s
pythonpath = .
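
Note: per the marker hints above, a plain pytest run covers everything under tests/ and tests-unit/ with output capture disabled (addopts = -s), while the markers let you deselect the heavier suites:

    pytest                      # run all tests
    pytest -m "not inference"   # skip the inference-marked tests
    pytest -m "not execution"   # skip the execution-marked tests
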
requirements.txt
ADDED
@@ -0,0 +1,29 @@
torch
torchsde
torchvision
torchaudio
numpy>=1.25.0
einops
transformers>=4.28.1
tokenizers>=0.13.3
sentencepiece
safetensors>=0.4.2
aiohttp
pyyaml
Pillow>=9.5.0
scipy
tqdm
psutil
gradio
huggingface_hub
# Base Detectron2
git+https://github.com/facebookresearch/[email protected]
# DensePose (part of Detectron2 projects)
git+https://github.com/facebookresearch/[email protected]#subdirectory=projects/DensePose

# Other dependencies, except conflicting ones, should be manually added
# If you have a modified requirements file, include its contents here, excluding detectron2-related dependencies.
#non essential dependencies:
kornia>=0.7.1
spandrel
soundfile
script_examples/basic_api_example.py
ADDED
@@ -0,0 +1,119 @@
import json
from urllib import request

#This is the ComfyUI api prompt format.

#If you want it for a specific workflow you can "enable dev mode options"
#in the settings of the UI (gear beside the "Queue Size: ") this will enable
#a button on the UI to save workflows in api format.

#keep in mind ComfyUI is pre alpha software so this format will change a bit.

#this is the one for the default workflow
prompt_text = """
{
    "3": {
        "class_type": "KSampler",
        "inputs": {
            "cfg": 8,
            "denoise": 1,
            "latent_image": [
                "5",
                0
            ],
            "model": [
                "4",
                0
            ],
            "negative": [
                "7",
                0
            ],
            "positive": [
                "6",
                0
            ],
            "sampler_name": "euler",
            "scheduler": "normal",
            "seed": 8566257,
            "steps": 20
        }
    },
    "4": {
        "class_type": "CheckpointLoaderSimple",
        "inputs": {
            "ckpt_name": "v1-5-pruned-emaonly.safetensors"
        }
    },
    "5": {
        "class_type": "EmptyLatentImage",
        "inputs": {
            "batch_size": 1,
            "height": 512,
            "width": 512
        }
    },
    "6": {
        "class_type": "CLIPTextEncode",
        "inputs": {
            "clip": [
                "4",
                1
            ],
            "text": "masterpiece best quality girl"
        }
    },
    "7": {
        "class_type": "CLIPTextEncode",
        "inputs": {
            "clip": [
                "4",
                1
            ],
            "text": "bad hands"
        }
    },
    "8": {
        "class_type": "VAEDecode",
        "inputs": {
            "samples": [
                "3",
                0
            ],
            "vae": [
                "4",
                2
            ]
        }
    },
    "9": {
        "class_type": "SaveImage",
        "inputs": {
            "filename_prefix": "ComfyUI",
            "images": [
                "8",
                0
            ]
        }
    }
}
"""

def queue_prompt(prompt):
    p = {"prompt": prompt}
    data = json.dumps(p).encode('utf-8')
    req = request.Request("http://127.0.0.1:8188/prompt", data=data)
    request.urlopen(req)


prompt = json.loads(prompt_text)
#set the text prompt for our positive CLIPTextEncode
prompt["6"]["inputs"]["text"] = "masterpiece best quality man"

#set the seed for our KSampler node
prompt["3"]["inputs"]["seed"] = 5


queue_prompt(prompt)
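
Note: queue_prompt() above discards the HTTP response. The /prompt endpoint replies with JSON that includes the assigned prompt_id (the websocket example below relies on exactly this), so a hedged variation can capture it; the helper name below is hypothetical:

    def queue_prompt_with_id(prompt):
        # Same POST as queue_prompt(), but return the server's JSON reply
        # so the caller can track the assigned prompt_id.
        data = json.dumps({"prompt": prompt}).encode('utf-8')
        req = request.Request("http://127.0.0.1:8188/prompt", data=data)
        with request.urlopen(req) as response:
            return json.loads(response.read())["prompt_id"]
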
script_examples/websockets_api_example.py
ADDED
@@ -0,0 +1,166 @@
#This is an example that uses the websockets api to know when a prompt execution is done
#Once the prompt execution is done it downloads the images using the /history endpoint

import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
import uuid
import json
import urllib.request
import urllib.parse

server_address = "127.0.0.1:8188"
client_id = str(uuid.uuid4())

def queue_prompt(prompt):
    p = {"prompt": prompt, "client_id": client_id}
    data = json.dumps(p).encode('utf-8')
    req = urllib.request.Request("http://{}/prompt".format(server_address), data=data)
    return json.loads(urllib.request.urlopen(req).read())

def get_image(filename, subfolder, folder_type):
    data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
    url_values = urllib.parse.urlencode(data)
    with urllib.request.urlopen("http://{}/view?{}".format(server_address, url_values)) as response:
        return response.read()

def get_history(prompt_id):
    with urllib.request.urlopen("http://{}/history/{}".format(server_address, prompt_id)) as response:
        return json.loads(response.read())

def get_images(ws, prompt):
    prompt_id = queue_prompt(prompt)['prompt_id']
    output_images = {}
    while True:
        out = ws.recv()
        if isinstance(out, str):
            message = json.loads(out)
            if message['type'] == 'executing':
                data = message['data']
                if data['node'] is None and data['prompt_id'] == prompt_id:
                    break #Execution is done
        else:
            # If you want to be able to decode the binary stream for latent previews, here is how you can do it:
            # bytesIO = BytesIO(out[8:])
            # preview_image = Image.open(bytesIO) # This is your preview in PIL image format, store it in a global
            continue #previews are binary data

    history = get_history(prompt_id)[prompt_id]
    for node_id in history['outputs']:
        node_output = history['outputs'][node_id]
        images_output = []
        if 'images' in node_output:
            for image in node_output['images']:
                image_data = get_image(image['filename'], image['subfolder'], image['type'])
                images_output.append(image_data)
        output_images[node_id] = images_output

    return output_images

prompt_text = """
{
    "3": {
        "class_type": "KSampler",
        "inputs": {
            "cfg": 8,
            "denoise": 1,
            "latent_image": [
                "5",
                0
            ],
            "model": [
                "4",
                0
            ],
            "negative": [
                "7",
                0
            ],
            "positive": [
                "6",
                0
            ],
            "sampler_name": "euler",
            "scheduler": "normal",
            "seed": 8566257,
            "steps": 20
        }
    },
    "4": {
        "class_type": "CheckpointLoaderSimple",
        "inputs": {
            "ckpt_name": "v1-5-pruned-emaonly.safetensors"
        }
    },
    "5": {
        "class_type": "EmptyLatentImage",
        "inputs": {
            "batch_size": 1,
            "height": 512,
            "width": 512
        }
    },
    "6": {
        "class_type": "CLIPTextEncode",
        "inputs": {
            "clip": [
                "4",
                1
            ],
            "text": "masterpiece best quality girl"
        }
    },
    "7": {
        "class_type": "CLIPTextEncode",
        "inputs": {
            "clip": [
                "4",
                1
            ],
            "text": "bad hands"
        }
    },
    "8": {
        "class_type": "VAEDecode",
        "inputs": {
            "samples": [
                "3",
                0
            ],
            "vae": [
                "4",
                2
            ]
        }
    },
    "9": {
        "class_type": "SaveImage",
        "inputs": {
            "filename_prefix": "ComfyUI",
            "images": [
                "8",
                0
            ]
        }
    }
}
"""

prompt = json.loads(prompt_text)
#set the text prompt for our positive CLIPTextEncode
prompt["6"]["inputs"]["text"] = "masterpiece best quality man"

#set the seed for our KSampler node
prompt["3"]["inputs"]["seed"] = 5

ws = websocket.WebSocket()
ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
images = get_images(ws, prompt)
ws.close() # close the socket in case this example is used in an environment where it will be repeatedly called, like in a Gradio app; otherwise, you'll randomly receive connection timeouts
#Commented out code to display the output images:

# for node_id in images:
#     for image_data in images[node_id]:
#         from PIL import Image
#         import io
#         image = Image.open(io.BytesIO(image_data))
#         image.show()
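
Note: as the comment inside get_images() says, binary frames on this websocket are latent previews carrying an 8-byte header (event type plus image format) before the image payload. A hedged decode sketch based on that comment:

    from io import BytesIO
    from PIL import Image

    def decode_preview(binary_frame):
        # Strip the 8-byte header, then let PIL sniff the image format.
        return Image.open(BytesIO(binary_frame[8:]))
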
script_examples/websockets_api_example_ws_images.py
ADDED
@@ -0,0 +1,159 @@
#This is an example that uses the websockets api and the SaveImageWebsocket node to get images directly without
#them being saved to disk

import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
import uuid
import json
import urllib.request
import urllib.parse

server_address = "127.0.0.1:8188"
client_id = str(uuid.uuid4())

def queue_prompt(prompt):
    p = {"prompt": prompt, "client_id": client_id}
    data = json.dumps(p).encode('utf-8')
    req = urllib.request.Request("http://{}/prompt".format(server_address), data=data)
    return json.loads(urllib.request.urlopen(req).read())

def get_image(filename, subfolder, folder_type):
    data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
    url_values = urllib.parse.urlencode(data)
    with urllib.request.urlopen("http://{}/view?{}".format(server_address, url_values)) as response:
        return response.read()

def get_history(prompt_id):
    with urllib.request.urlopen("http://{}/history/{}".format(server_address, prompt_id)) as response:
        return json.loads(response.read())

def get_images(ws, prompt):
    prompt_id = queue_prompt(prompt)['prompt_id']
    output_images = {}
    current_node = ""
    while True:
        out = ws.recv()
        if isinstance(out, str):
            message = json.loads(out)
            if message['type'] == 'executing':
                data = message['data']
                if data['prompt_id'] == prompt_id:
                    if data['node'] is None:
                        break #Execution is done
                    else:
                        current_node = data['node']
        else:
            if current_node == 'save_image_websocket_node':
                images_output = output_images.get(current_node, [])
                images_output.append(out[8:])
                output_images[current_node] = images_output

    return output_images

prompt_text = """
{
    "3": {
        "class_type": "KSampler",
        "inputs": {
            "cfg": 8,
            "denoise": 1,
            "latent_image": [
                "5",
                0
            ],
            "model": [
                "4",
                0
            ],
            "negative": [
                "7",
                0
            ],
            "positive": [
                "6",
                0
            ],
            "sampler_name": "euler",
            "scheduler": "normal",
            "seed": 8566257,
            "steps": 20
        }
    },
    "4": {
        "class_type": "CheckpointLoaderSimple",
        "inputs": {
            "ckpt_name": "v1-5-pruned-emaonly.safetensors"
        }
    },
    "5": {
        "class_type": "EmptyLatentImage",
        "inputs": {
            "batch_size": 1,
            "height": 512,
            "width": 512
        }
    },
    "6": {
        "class_type": "CLIPTextEncode",
        "inputs": {
            "clip": [
                "4",
                1
            ],
            "text": "masterpiece best quality girl"
        }
    },
    "7": {
        "class_type": "CLIPTextEncode",
        "inputs": {
            "clip": [
                "4",
                1
            ],
            "text": "bad hands"
        }
    },
    "8": {
        "class_type": "VAEDecode",
        "inputs": {
            "samples": [
                "3",
                0
            ],
            "vae": [
                "4",
                2
            ]
        }
    },
    "save_image_websocket_node": {
        "class_type": "SaveImageWebsocket",
        "inputs": {
            "images": [
                "8",
                0
            ]
        }
    }
}
"""

prompt = json.loads(prompt_text)
#set the text prompt for our positive CLIPTextEncode
prompt["6"]["inputs"]["text"] = "masterpiece best quality man"

#set the seed for our KSampler node
prompt["3"]["inputs"]["seed"] = 5

ws = websocket.WebSocket()
ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
images = get_images(ws, prompt)
ws.close() # close the socket in case this example is used in an environment where it will be repeatedly called, like in a Gradio app; otherwise, you'll randomly receive connection timeouts
#Commented out code to display the output images:

# for node_id in images:
#     for image_data in images[node_id]:
#         from PIL import Image
#         import io
#         image = Image.open(io.BytesIO(image_data))
#         image.show()
utils/__init__.py
ADDED
File without changes
utils/extra_config.py
ADDED
@@ -0,0 +1,33 @@
import os
import yaml
import folder_paths
import logging

def load_extra_path_config(yaml_path):
    with open(yaml_path, 'r') as stream:
        config = yaml.safe_load(stream)
    yaml_dir = os.path.dirname(os.path.abspath(yaml_path))
    for c in config:
        conf = config[c]
        if conf is None:
            continue
        base_path = None
        if "base_path" in conf:
            base_path = conf.pop("base_path")
            base_path = os.path.expandvars(os.path.expanduser(base_path))
            if not os.path.isabs(base_path):
                base_path = os.path.abspath(os.path.join(yaml_dir, base_path))
        is_default = False
        if "is_default" in conf:
            is_default = conf.pop("is_default")
        for x in conf:
            for y in conf[x].split("\n"):
                if len(y) == 0:
                    continue
                full_path = y
                if base_path:
                    full_path = os.path.join(base_path, full_path)
                elif not os.path.isabs(full_path):
                    full_path = os.path.abspath(os.path.join(yaml_dir, y))
                logging.info("Adding extra search path {} {}".format(x, full_path))
                folder_paths.add_model_folder_path(x, full_path, is_default)
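
Note: the loader pops base_path and is_default and treats every other key as a newline-separated list of folders for that folder type. A hypothetical YAML file it would accept (names and paths are illustrative only; see extra_model_paths.yaml.example in this upload for the shipped template):

    my_models:
        base_path: /srv/stable-diffusion
        is_default: true
        checkpoints: |
            models/checkpoints
            extra_checkpoints
        loras: models/loras
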