plaidam commited on
Commit
b0b4e57
·
verified ·
1 Parent(s): 1f785e4

Upload 42 files

Browse files
Files changed (42) hide show
  1. .ci/update_windows/update.py +146 -0
  2. .ci/update_windows/update_comfyui.bat +8 -0
  3. .ci/update_windows/update_comfyui_stable.bat +8 -0
  4. .ci/windows_base_files/README_VERY_IMPORTANT.txt +31 -0
  5. .ci/windows_base_files/run_cpu.bat +2 -0
  6. .ci/windows_base_files/run_nvidia_gpu.bat +2 -0
  7. .ci/windows_nightly_base_files/run_nvidia_gpu_fast.bat +2 -0
  8. CODEOWNERS +23 -0
  9. CONTRIBUTING.md +41 -0
  10. api_server/__init__.py +0 -0
  11. api_server/routes/__init__.py +0 -0
  12. api_server/routes/internal/README.md +3 -0
  13. api_server/routes/internal/__init__.py +0 -0
  14. api_server/routes/internal/internal_routes.py +75 -0
  15. api_server/services/__init__.py +0 -0
  16. api_server/services/file_service.py +13 -0
  17. api_server/services/terminal_service.py +60 -0
  18. api_server/utils/file_operations.py +42 -0
  19. app.py +421 -0
  20. comfy_execution/caching.py +318 -0
  21. comfy_execution/graph.py +270 -0
  22. comfy_execution/graph_utils.py +139 -0
  23. comfy_execution/validation.py +39 -0
  24. comfyui_version.py +3 -0
  25. cuda_malloc.py +90 -0
  26. extra_model_paths.yaml.example +47 -0
  27. fix_torch.py +28 -0
  28. folder_paths.py +385 -0
  29. latent_preview.py +105 -0
  30. main.py +301 -0
  31. new_updater.py +35 -0
  32. node_helpers.py +37 -0
  33. notebooks/comfyui_colab.ipynb +322 -0
  34. output/_output_images_will_be_put_here +0 -0
  35. pyproject.toml +23 -0
  36. pytest.ini +9 -0
  37. requirements.txt +29 -0
  38. script_examples/basic_api_example.py +119 -0
  39. script_examples/websockets_api_example.py +166 -0
  40. script_examples/websockets_api_example_ws_images.py +159 -0
  41. utils/__init__.py +0 -0
  42. utils/extra_config.py +33 -0
.ci/update_windows/update.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pygit2
2
+ from datetime import datetime
3
+ import sys
4
+ import os
5
+ import shutil
6
+ import filecmp
7
+
8
+ def pull(repo, remote_name='origin', branch='master'):
9
+ for remote in repo.remotes:
10
+ if remote.name == remote_name:
11
+ remote.fetch()
12
+ remote_master_id = repo.lookup_reference('refs/remotes/origin/%s' % (branch)).target
13
+ merge_result, _ = repo.merge_analysis(remote_master_id)
14
+ # Up to date, do nothing
15
+ if merge_result & pygit2.GIT_MERGE_ANALYSIS_UP_TO_DATE:
16
+ return
17
+ # We can just fastforward
18
+ elif merge_result & pygit2.GIT_MERGE_ANALYSIS_FASTFORWARD:
19
+ repo.checkout_tree(repo.get(remote_master_id))
20
+ try:
21
+ master_ref = repo.lookup_reference('refs/heads/%s' % (branch))
22
+ master_ref.set_target(remote_master_id)
23
+ except KeyError:
24
+ repo.create_branch(branch, repo.get(remote_master_id))
25
+ repo.head.set_target(remote_master_id)
26
+ elif merge_result & pygit2.GIT_MERGE_ANALYSIS_NORMAL:
27
+ repo.merge(remote_master_id)
28
+
29
+ if repo.index.conflicts is not None:
30
+ for conflict in repo.index.conflicts:
31
+ print('Conflicts found in:', conflict[0].path) # noqa: T201
32
+ raise AssertionError('Conflicts, ahhhhh!!')
33
+
34
+ user = repo.default_signature
35
+ tree = repo.index.write_tree()
36
+ repo.create_commit('HEAD',
37
+ user,
38
+ user,
39
+ 'Merge!',
40
+ tree,
41
+ [repo.head.target, remote_master_id])
42
+ # We need to do this or git CLI will think we are still merging.
43
+ repo.state_cleanup()
44
+ else:
45
+ raise AssertionError('Unknown merge analysis result')
46
+
47
+ pygit2.option(pygit2.GIT_OPT_SET_OWNER_VALIDATION, 0)
48
+ repo_path = str(sys.argv[1])
49
+ repo = pygit2.Repository(repo_path)
50
+ ident = pygit2.Signature('comfyui', 'comfy@ui')
51
+ try:
52
+ print("stashing current changes") # noqa: T201
53
+ repo.stash(ident)
54
+ except KeyError:
55
+ print("nothing to stash") # noqa: T201
56
+ backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S'))
57
+ print("creating backup branch: {}".format(backup_branch_name)) # noqa: T201
58
+ try:
59
+ repo.branches.local.create(backup_branch_name, repo.head.peel())
60
+ except:
61
+ pass
62
+
63
+ print("checking out master branch") # noqa: T201
64
+ branch = repo.lookup_branch('master')
65
+ if branch is None:
66
+ ref = repo.lookup_reference('refs/remotes/origin/master')
67
+ repo.checkout(ref)
68
+ branch = repo.lookup_branch('master')
69
+ if branch is None:
70
+ repo.create_branch('master', repo.get(ref.target))
71
+ else:
72
+ ref = repo.lookup_reference(branch.name)
73
+ repo.checkout(ref)
74
+
75
+ print("pulling latest changes") # noqa: T201
76
+ pull(repo)
77
+
78
+ if "--stable" in sys.argv:
79
+ def latest_tag(repo):
80
+ versions = []
81
+ for k in repo.references:
82
+ try:
83
+ prefix = "refs/tags/v"
84
+ if k.startswith(prefix):
85
+ version = list(map(int, k[len(prefix):].split(".")))
86
+ versions.append((version[0] * 10000000000 + version[1] * 100000 + version[2], k))
87
+ except:
88
+ pass
89
+ versions.sort()
90
+ if len(versions) > 0:
91
+ return versions[-1][1]
92
+ return None
93
+ latest_tag = latest_tag(repo)
94
+ if latest_tag is not None:
95
+ repo.checkout(latest_tag)
96
+
97
+ print("Done!") # noqa: T201
98
+
99
+ self_update = True
100
+ if len(sys.argv) > 2:
101
+ self_update = '--skip_self_update' not in sys.argv
102
+
103
+ update_py_path = os.path.realpath(__file__)
104
+ repo_update_py_path = os.path.join(repo_path, ".ci/update_windows/update.py")
105
+
106
+ cur_path = os.path.dirname(update_py_path)
107
+
108
+
109
+ req_path = os.path.join(cur_path, "current_requirements.txt")
110
+ repo_req_path = os.path.join(repo_path, "requirements.txt")
111
+
112
+
113
+ def files_equal(file1, file2):
114
+ try:
115
+ return filecmp.cmp(file1, file2, shallow=False)
116
+ except:
117
+ return False
118
+
119
+ def file_size(f):
120
+ try:
121
+ return os.path.getsize(f)
122
+ except:
123
+ return 0
124
+
125
+
126
+ if self_update and not files_equal(update_py_path, repo_update_py_path) and file_size(repo_update_py_path) > 10:
127
+ shutil.copy(repo_update_py_path, os.path.join(cur_path, "update_new.py"))
128
+ exit()
129
+
130
+ if not os.path.exists(req_path) or not files_equal(repo_req_path, req_path):
131
+ import subprocess
132
+ try:
133
+ subprocess.check_call([sys.executable, '-s', '-m', 'pip', 'install', '-r', repo_req_path])
134
+ shutil.copy(repo_req_path, req_path)
135
+ except:
136
+ pass
137
+
138
+
139
+ stable_update_script = os.path.join(repo_path, ".ci/update_windows/update_comfyui_stable.bat")
140
+ stable_update_script_to = os.path.join(cur_path, "update_comfyui_stable.bat")
141
+
142
+ try:
143
+ if not file_size(stable_update_script_to) > 10:
144
+ shutil.copy(stable_update_script, stable_update_script_to)
145
+ except:
146
+ pass
.ci/update_windows/update_comfyui.bat ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ @echo off
2
+ ..\python_embeded\python.exe .\update.py ..\ComfyUI\
3
+ if exist update_new.py (
4
+ move /y update_new.py update.py
5
+ echo Running updater again since it got updated.
6
+ ..\python_embeded\python.exe .\update.py ..\ComfyUI\ --skip_self_update
7
+ )
8
+ if "%~1"=="" pause
.ci/update_windows/update_comfyui_stable.bat ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ @echo off
2
+ ..\python_embeded\python.exe .\update.py ..\ComfyUI\ --stable
3
+ if exist update_new.py (
4
+ move /y update_new.py update.py
5
+ echo Running updater again since it got updated.
6
+ ..\python_embeded\python.exe .\update.py ..\ComfyUI\ --skip_self_update --stable
7
+ )
8
+ if "%~1"=="" pause
.ci/windows_base_files/README_VERY_IMPORTANT.txt ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ HOW TO RUN:
2
+
3
+ if you have a NVIDIA gpu:
4
+
5
+ run_nvidia_gpu.bat
6
+
7
+
8
+
9
+ To run it in slow CPU mode:
10
+
11
+ run_cpu.bat
12
+
13
+
14
+
15
+ IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
16
+
17
+ You can download the stable diffusion 1.5 one from: https://huggingface.co/Comfy-Org/stable-diffusion-v1-5-archive/blob/main/v1-5-pruned-emaonly-fp16.safetensors
18
+
19
+
20
+ RECOMMENDED WAY TO UPDATE:
21
+ To update the ComfyUI code: update\update_comfyui.bat
22
+
23
+
24
+
25
+ To update ComfyUI with the python dependencies, note that you should ONLY run this if you have issues with python dependencies.
26
+ update\update_comfyui_and_python_dependencies.bat
27
+
28
+
29
+ TO SHARE MODELS BETWEEN COMFYUI AND ANOTHER UI:
30
+ In the ComfyUI directory you will find a file: extra_model_paths.yaml.example
31
+ Rename this file to: extra_model_paths.yaml and edit it with your favorite text editor.
.ci/windows_base_files/run_cpu.bat ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .\python_embeded\python.exe -s ComfyUI\main.py --cpu --windows-standalone-build
2
+ pause
.ci/windows_base_files/run_nvidia_gpu.bat ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
2
+ pause
.ci/windows_nightly_base_files/run_nvidia_gpu_fast.bat ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast
2
+ pause
CODEOWNERS ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Admins
2
+ * @comfyanonymous
3
+
4
+ # Note: Github teams syntax cannot be used here as the repo is not owned by Comfy-Org.
5
+ # Inlined the team members for now.
6
+
7
+ # Maintainers
8
+ *.md @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
9
+ /tests/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
10
+ /tests-unit/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
11
+ /notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
12
+ /script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
13
+ /.github/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
14
+
15
+ # Python web server
16
+ /api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
17
+ /app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
18
+
19
+ # Frontend assets
20
+ /web/ @huchenlei @webfiltered @pythongosssss @yoland68 @robinjhuang
21
+
22
+ # Extra nodes
23
+ /comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink
CONTRIBUTING.md ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Contributing to ComfyUI
2
+
3
+ Welcome, and thank you for your interest in contributing to ComfyUI!
4
+
5
+ There are several ways in which you can contribute, beyond writing code. The goal of this document is to provide a high-level overview of how you can get involved.
6
+
7
+ ## Asking Questions
8
+
9
+ Have a question? Instead of opening an issue, please ask on [Discord](https://comfy.org/discord) or [Matrix](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) channels. Our team and the community will help you.
10
+
11
+ ## Providing Feedback
12
+
13
+ Your comments and feedback are welcome, and the development team is available via a handful of different channels.
14
+
15
+ See the `#bug-report`, `#feature-request` and `#feedback` channels on Discord.
16
+
17
+ ## Reporting Issues
18
+
19
+ Have you identified a reproducible problem in ComfyUI? Do you have a feature request? We want to hear about it! Here's how you can report your issue as effectively as possible.
20
+
21
+
22
+ ### Look For an Existing Issue
23
+
24
+ Before you create a new issue, please do a search in [open issues](https://github.com/comfyanonymous/ComfyUI/issues) to see if the issue or feature request has already been filed.
25
+
26
+ If you find your issue already exists, make relevant comments and add your [reaction](https://github.com/blog/2119-add-reactions-to-pull-requests-issues-and-comments). Use a reaction in place of a "+1" comment:
27
+
28
+ * 👍 - upvote
29
+ * 👎 - downvote
30
+
31
+ If you cannot find an existing issue that describes your bug or feature, create a new issue. We have an issue template in place to organize new issues.
32
+
33
+
34
+ ### Creating Pull Requests
35
+
36
+ * Please refer to the article on [creating pull requests](https://github.com/comfyanonymous/ComfyUI/wiki/How-to-Contribute-Code) and contributing to this project.
37
+
38
+
39
+ ## Thank You
40
+
41
+ Your contributions to open source, large or small, make great projects like this possible. Thank you for taking the time to contribute.
api_server/__init__.py ADDED
File without changes
api_server/routes/__init__.py ADDED
File without changes
api_server/routes/internal/README.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # ComfyUI Internal Routes
2
+
3
+ All routes under the `/internal` path are designated for **internal use by ComfyUI only**. These routes are not intended for use by external applications may change at any time without notice.
api_server/routes/internal/__init__.py ADDED
File without changes
api_server/routes/internal/internal_routes.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from aiohttp import web
2
+ from typing import Optional
3
+ from folder_paths import models_dir, user_directory, output_directory, folder_names_and_paths
4
+ from api_server.services.file_service import FileService
5
+ from api_server.services.terminal_service import TerminalService
6
+ import app.logger
7
+
8
+ class InternalRoutes:
9
+ '''
10
+ The top level web router for internal routes: /internal/*
11
+ The endpoints here should NOT be depended upon. It is for ComfyUI frontend use only.
12
+ Check README.md for more information.
13
+ '''
14
+
15
+ def __init__(self, prompt_server):
16
+ self.routes: web.RouteTableDef = web.RouteTableDef()
17
+ self._app: Optional[web.Application] = None
18
+ self.file_service = FileService({
19
+ "models": models_dir,
20
+ "user": user_directory,
21
+ "output": output_directory
22
+ })
23
+ self.prompt_server = prompt_server
24
+ self.terminal_service = TerminalService(prompt_server)
25
+
26
+ def setup_routes(self):
27
+ @self.routes.get('/files')
28
+ async def list_files(request):
29
+ directory_key = request.query.get('directory', '')
30
+ try:
31
+ file_list = self.file_service.list_files(directory_key)
32
+ return web.json_response({"files": file_list})
33
+ except ValueError as e:
34
+ return web.json_response({"error": str(e)}, status=400)
35
+ except Exception as e:
36
+ return web.json_response({"error": str(e)}, status=500)
37
+
38
+ @self.routes.get('/logs')
39
+ async def get_logs(request):
40
+ return web.json_response("".join([(l["t"] + " - " + l["m"]) for l in app.logger.get_logs()]))
41
+
42
+ @self.routes.get('/logs/raw')
43
+ async def get_raw_logs(request):
44
+ self.terminal_service.update_size()
45
+ return web.json_response({
46
+ "entries": list(app.logger.get_logs()),
47
+ "size": {"cols": self.terminal_service.cols, "rows": self.terminal_service.rows}
48
+ })
49
+
50
+ @self.routes.patch('/logs/subscribe')
51
+ async def subscribe_logs(request):
52
+ json_data = await request.json()
53
+ client_id = json_data["clientId"]
54
+ enabled = json_data["enabled"]
55
+ if enabled:
56
+ self.terminal_service.subscribe(client_id)
57
+ else:
58
+ self.terminal_service.unsubscribe(client_id)
59
+
60
+ return web.Response(status=200)
61
+
62
+
63
+ @self.routes.get('/folder_paths')
64
+ async def get_folder_paths(request):
65
+ response = {}
66
+ for key in folder_names_and_paths:
67
+ response[key] = folder_names_and_paths[key][0]
68
+ return web.json_response(response)
69
+
70
+ def get_app(self):
71
+ if self._app is None:
72
+ self._app = web.Application()
73
+ self.setup_routes()
74
+ self._app.add_routes(self.routes)
75
+ return self._app
api_server/services/__init__.py ADDED
File without changes
api_server/services/file_service.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Optional
2
+ from api_server.utils.file_operations import FileSystemOperations, FileSystemItem
3
+
4
+ class FileService:
5
+ def __init__(self, allowed_directories: Dict[str, str], file_system_ops: Optional[FileSystemOperations] = None):
6
+ self.allowed_directories: Dict[str, str] = allowed_directories
7
+ self.file_system_ops: FileSystemOperations = file_system_ops or FileSystemOperations()
8
+
9
+ def list_files(self, directory_key: str) -> List[FileSystemItem]:
10
+ if directory_key not in self.allowed_directories:
11
+ raise ValueError("Invalid directory key")
12
+ directory_path: str = self.allowed_directories[directory_key]
13
+ return self.file_system_ops.walk_directory(directory_path)
api_server/services/terminal_service.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from app.logger import on_flush
2
+ import os
3
+ import shutil
4
+
5
+
6
+ class TerminalService:
7
+ def __init__(self, server):
8
+ self.server = server
9
+ self.cols = None
10
+ self.rows = None
11
+ self.subscriptions = set()
12
+ on_flush(self.send_messages)
13
+
14
+ def get_terminal_size(self):
15
+ try:
16
+ size = os.get_terminal_size()
17
+ return (size.columns, size.lines)
18
+ except OSError:
19
+ try:
20
+ size = shutil.get_terminal_size()
21
+ return (size.columns, size.lines)
22
+ except OSError:
23
+ return (80, 24) # fallback to 80x24
24
+
25
+ def update_size(self):
26
+ columns, lines = self.get_terminal_size()
27
+ changed = False
28
+
29
+ if columns != self.cols:
30
+ self.cols = columns
31
+ changed = True
32
+
33
+ if lines != self.rows:
34
+ self.rows = lines
35
+ changed = True
36
+
37
+ if changed:
38
+ return {"cols": self.cols, "rows": self.rows}
39
+
40
+ return None
41
+
42
+ def subscribe(self, client_id):
43
+ self.subscriptions.add(client_id)
44
+
45
+ def unsubscribe(self, client_id):
46
+ self.subscriptions.discard(client_id)
47
+
48
+ def send_messages(self, entries):
49
+ if not len(entries) or not len(self.subscriptions):
50
+ return
51
+
52
+ new_size = self.update_size()
53
+
54
+ for client_id in self.subscriptions.copy(): # prevent: Set changed size during iteration
55
+ if client_id not in self.server.sockets:
56
+ # Automatically unsub if the socket has disconnected
57
+ self.unsubscribe(client_id)
58
+ continue
59
+
60
+ self.server.send_sync("logs", {"entries": entries, "size": new_size}, client_id)
api_server/utils/file_operations.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List, Union, TypedDict, Literal
3
+ from typing_extensions import TypeGuard
4
+ class FileInfo(TypedDict):
5
+ name: str
6
+ path: str
7
+ type: Literal["file"]
8
+ size: int
9
+
10
+ class DirectoryInfo(TypedDict):
11
+ name: str
12
+ path: str
13
+ type: Literal["directory"]
14
+
15
+ FileSystemItem = Union[FileInfo, DirectoryInfo]
16
+
17
+ def is_file_info(item: FileSystemItem) -> TypeGuard[FileInfo]:
18
+ return item["type"] == "file"
19
+
20
+ class FileSystemOperations:
21
+ @staticmethod
22
+ def walk_directory(directory: str) -> List[FileSystemItem]:
23
+ file_list: List[FileSystemItem] = []
24
+ for root, dirs, files in os.walk(directory):
25
+ for name in files:
26
+ file_path = os.path.join(root, name)
27
+ relative_path = os.path.relpath(file_path, directory)
28
+ file_list.append({
29
+ "name": name,
30
+ "path": relative_path,
31
+ "type": "file",
32
+ "size": os.path.getsize(file_path)
33
+ })
34
+ for name in dirs:
35
+ dir_path = os.path.join(root, name)
36
+ relative_path = os.path.relpath(dir_path, directory)
37
+ file_list.append({
38
+ "name": name,
39
+ "path": relative_path,
40
+ "type": "directory"
41
+ })
42
+ return file_list
app.py ADDED
@@ -0,0 +1,421 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #########################
2
+ # app.py for ZeroGPU #
3
+ #########################
4
+
5
+ import os
6
+ import sys
7
+ import random
8
+ import torch
9
+ import gradio as gr
10
+ import spaces # for ZeroGPU usage
11
+ from typing import Sequence, Mapping, Any, Union
12
+
13
+ # 1) Load your token from environment (make sure you set HF_TOKEN in Space settings)
14
+ token = os.environ["HF_TOKEN"]
15
+
16
+ # 2) We'll use huggingface_hub to download each gated model
17
+ from huggingface_hub import hf_hub_download
18
+
19
+ import shutil
20
+ import pathlib
21
+
22
+ # Create the directories we need (mirroring your ComfyUI structure)
23
+ pathlib.Path("ComfyUI/models/vae").mkdir(parents=True, exist_ok=True)
24
+ pathlib.Path("ComfyUI/models/clip").mkdir(parents=True, exist_ok=True)
25
+ pathlib.Path("ComfyUI/models/clip_vision").mkdir(parents=True, exist_ok=True)
26
+ pathlib.Path("ComfyUI/models/unet").mkdir(parents=True, exist_ok=True)
27
+ pathlib.Path("ComfyUI/models/loras").mkdir(parents=True, exist_ok=True)
28
+ pathlib.Path("ComfyUI/models/style_models").mkdir(parents=True, exist_ok=True)
29
+
30
+ # Download each gated model into the correct local folder
31
+
32
+ hf_hub_download(
33
+ repo_id="black-forest-labs/FLUX.1-dev",
34
+ filename="ae.safetensors",
35
+ local_dir="ComfyUI/models/vae",
36
+ use_auth_token=token
37
+ )
38
+
39
+ hf_hub_download(
40
+ repo_id="comfyanonymous/flux_text_encoders",
41
+ filename="t5xxl_fp16.safetensors",
42
+ local_dir="ComfyUI/models/clip",
43
+ use_auth_token=token
44
+ )
45
+
46
+ hf_hub_download(
47
+ repo_id="comfyanonymous/flux_text_encoders",
48
+ filename="clip_l.safetensors",
49
+ local_dir="ComfyUI/models/clip",
50
+ use_auth_token=token
51
+ )
52
+
53
+ hf_hub_download(
54
+ repo_id="black-forest-labs/FLUX.1-Fill-dev",
55
+ filename="flux1-fill-dev.safetensors",
56
+ local_dir="ComfyUI/models/unet",
57
+ use_auth_token=token
58
+ )
59
+
60
+ hf_hub_download(
61
+ repo_id="zhengchong/CatVTON",
62
+ filename="flux-lora/pytorch_lora_weights.safetensors",
63
+ local_dir="ComfyUI/models/loras",
64
+ # rename so it matches your code reference:
65
+ local_fname="catvton-flux-lora.safetensors",
66
+ use_auth_token=token
67
+ )
68
+
69
+ hf_hub_download(
70
+ repo_id="alimama-creative/FLUX.1-Turbo-Alpha",
71
+ filename="diffusion_pytorch_model.safetensors",
72
+ local_dir="ComfyUI/models/loras",
73
+ # rename so it matches your code reference:
74
+ local_fname="alimama-flux-turbo-alpha.safetensors",
75
+ use_auth_token=token
76
+ )
77
+
78
+ hf_hub_download(
79
+ repo_id="Comfy-Org/sigclip_vision_384",
80
+ filename="sigclip_vision_patch14_384.safetensors",
81
+ local_dir="ComfyUI/models/clip_vision",
82
+ use_auth_token=token
83
+ )
84
+
85
+ hf_hub_download(
86
+ repo_id="black-forest-labs/FLUX.1-Redux-dev",
87
+ filename="flux1-redux-dev.safetensors",
88
+ local_dir="ComfyUI/models/style_models",
89
+ use_auth_token=token
90
+ )
91
+
92
+ #############################
93
+ # ComfyUI support functions
94
+ #############################
95
+
96
+ def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
97
+ """Returns the value at the given index of a sequence or mapping."""
98
+ try:
99
+ return obj[index]
100
+ except KeyError:
101
+ return obj["result"][index]
102
+
103
+ def find_path(name: str, path: str = None) -> str:
104
+ import os
105
+ if path is None:
106
+ path = os.getcwd()
107
+
108
+ if name in os.listdir(path):
109
+ path_name = os.path.join(path, name)
110
+ print(f"{name} found: {path_name}")
111
+ return path_name
112
+
113
+ parent_directory = os.path.dirname(path)
114
+ if parent_directory == path:
115
+ return None
116
+
117
+ return find_path(name, parent_directory)
118
+
119
+ def add_comfyui_directory_to_sys_path() -> None:
120
+ comfyui_path = find_path("ComfyUI")
121
+ if comfyui_path is not None and os.path.isdir(comfyui_path):
122
+ sys.path.append(comfyui_path)
123
+ print(f"'{comfyui_path}' added to sys.path")
124
+
125
+ def add_extra_model_paths() -> None:
126
+ try:
127
+ from main import load_extra_path_config
128
+ except ImportError:
129
+ print("Could not import load_extra_path_config from main.py. Looking in utils.extra_config instead.")
130
+ from utils.extra_config import load_extra_path_config
131
+
132
+ extra_model_paths = find_path("extra_model_paths.yaml")
133
+ if extra_model_paths is not None:
134
+ load_extra_path_config(extra_model_paths)
135
+ else:
136
+ print("Could not find the extra_model_paths config file.")
137
+
138
+ add_comfyui_directory_to_sys_path()
139
+ add_extra_model_paths()
140
+
141
+ def import_custom_nodes() -> None:
142
+ import asyncio
143
+ import execution
144
+ from nodes import init_extra_nodes
145
+ import server
146
+
147
+ loop = asyncio.new_event_loop()
148
+ asyncio.set_event_loop(loop)
149
+ server_instance = server.PromptServer(loop)
150
+ execution.PromptQueue(server_instance)
151
+ init_extra_nodes()
152
+
153
+ ##########################
154
+ # Import node mappings
155
+ ##########################
156
+ from nodes import NODE_CLASS_MAPPINGS
157
+
158
+
159
+ #############################################
160
+ # MAIN PIPELINE with ZeroGPU Decorator
161
+ #############################################
162
+ @spaces.GPU(duration=90) # 90s of GPU usage; adjust as needed
163
+ def generate_images(user_image_path):
164
+ """
165
+ This function runs your node-based pipeline,
166
+ using user_image_path for loadimage_264 and
167
+ returning the final saveimage_295 path.
168
+ """
169
+ import_custom_nodes()
170
+ with torch.inference_mode():
171
+ loadimage = NODE_CLASS_MAPPINGS["LoadImage"]()
172
+ loadimage_116 = loadimage.load_image(image="assets_black_tshirt.png")
173
+
174
+ loadautomasker = NODE_CLASS_MAPPINGS["LoadAutoMasker"]()
175
+ loadautomasker_120 = loadautomasker.load(catvton_path="zhengchong/CatVTON")
176
+
177
+ loadcatvtonpipeline = NODE_CLASS_MAPPINGS["LoadCatVTONPipeline"]()
178
+ loadcatvtonpipeline_123 = loadcatvtonpipeline.load(
179
+ sd15_inpaint_path="runwayml/stable-diffusion-inpainting",
180
+ catvton_path="zhengchong/CatVTON",
181
+ mixed_precision="bf16",
182
+ )
183
+
184
+ loadimage_264 = loadimage.load_image(
185
+ image=user_image_path
186
+ )
187
+
188
+ randomnoise = NODE_CLASS_MAPPINGS["RandomNoise"]()
189
+ randomnoise_273 = randomnoise.get_noise(noise_seed=random.randint(1, 2**64))
190
+
191
+ downloadandloadflorence2model = NODE_CLASS_MAPPINGS["DownloadAndLoadFlorence2Model"]()
192
+ downloadandloadflorence2model_274 = downloadandloadflorence2model.loadmodel(
193
+ model="gokaygokay/Florence-2-Flux-Large", precision="fp16", attention="sdpa"
194
+ )
195
+
196
+ automasker = NODE_CLASS_MAPPINGS["AutoMasker"]()
197
+ automasker_119 = automasker.generate(
198
+ cloth_type="overall",
199
+ pipe=get_value_at_index(loadautomasker_120, 0),
200
+ target_image=get_value_at_index(loadimage_264, 0),
201
+ )
202
+
203
+ catvton = NODE_CLASS_MAPPINGS["CatVTON"]()
204
+ catvton_121 = catvton.generate(
205
+ seed=random.randint(1, 2**64),
206
+ steps=50,
207
+ cfg=2.5,
208
+ pipe=get_value_at_index(loadcatvtonpipeline_123, 0),
209
+ target_image=get_value_at_index(loadimage_264, 0),
210
+ refer_image=get_value_at_index(loadimage_116, 0),
211
+ mask_image=get_value_at_index(automasker_119, 0),
212
+ )
213
+
214
+ florence2run = NODE_CLASS_MAPPINGS["Florence2Run"]()
215
+ florence2run_275 = florence2run.encode(
216
+ text_input="Haircut",
217
+ task="caption_to_phrase_grounding",
218
+ fill_mask=True,
219
+ keep_model_loaded=False,
220
+ max_new_tokens=1024,
221
+ num_beams=3,
222
+ do_sample=True,
223
+ output_mask_select="",
224
+ seed=random.randint(1, 2**64),
225
+ image=get_value_at_index(catvton_121, 0),
226
+ florence2_model=get_value_at_index(downloadandloadflorence2model_274, 0),
227
+ )
228
+
229
+ downloadandloadsam2model = NODE_CLASS_MAPPINGS["DownloadAndLoadSAM2Model"]()
230
+ downloadandloadsam2model_277 = downloadandloadsam2model.loadmodel(
231
+ model="sam2.1_hiera_large.safetensors",
232
+ segmentor="single_image",
233
+ device="cuda",
234
+ precision="fp16",
235
+ )
236
+
237
+ dualcliploadergguf = NODE_CLASS_MAPPINGS["DualCLIPLoaderGGUF"]()
238
+ dualcliploadergguf_284 = dualcliploadergguf.load_clip(
239
+ clip_name1="t5xxl_fp16.safetensors",
240
+ clip_name2="clip_l.safetensors",
241
+ type="flux",
242
+ )
243
+
244
+ cliptextencode = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
245
+ cliptextencode_279 = cliptextencode.encode(
246
+ text="Br0k0L8, Broccoli haircut with voluminous, textured curls on top resembling broccoli florets, contrasted by closely shaved tapered sides",
247
+ clip=get_value_at_index(dualcliploadergguf_284, 0),
248
+ )
249
+
250
+ clipvisionloader = NODE_CLASS_MAPPINGS["CLIPVisionLoader"]()
251
+ clipvisionloader_281 = clipvisionloader.load_clip(
252
+ clip_name="sigclip_vision_patch14_384.safetensors"
253
+ )
254
+
255
+ loadimage_289 = loadimage.load_image(image="assets_broc_ref.jpg")
256
+
257
+ clipvisionencode = NODE_CLASS_MAPPINGS["CLIPVisionEncode"]()
258
+ clipvisionencode_282 = clipvisionencode.encode(
259
+ crop="center",
260
+ clip_vision=get_value_at_index(clipvisionloader_281, 0),
261
+ image=get_value_at_index(loadimage_289, 0),
262
+ )
263
+
264
+ vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
265
+ vaeloader_285 = vaeloader.load_vae(vae_name="ae.safetensors")
266
+
267
+ stylemodelloader = NODE_CLASS_MAPPINGS["StyleModelLoader"]()
268
+ stylemodelloader_292 = stylemodelloader.load_style_model(
269
+ style_model_name="flux1-redux-dev.safetensors"
270
+ )
271
+
272
+ stylemodelapply = NODE_CLASS_MAPPINGS["StyleModelApply"]()
273
+ stylemodelapply_280 = stylemodelapply.apply_stylemodel(
274
+ strength=1,
275
+ strength_type="multiply",
276
+ conditioning=get_value_at_index(cliptextencode_279, 0),
277
+ style_model=get_value_at_index(stylemodelloader_292, 0),
278
+ clip_vision_output=get_value_at_index(clipvisionencode_282, 0),
279
+ )
280
+
281
+ fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
282
+ fluxguidance_288 = fluxguidance.append(
283
+ guidance=30, conditioning=get_value_at_index(stylemodelapply_280, 0)
284
+ )
285
+
286
+ conditioningzeroout = NODE_CLASS_MAPPINGS["ConditioningZeroOut"]()
287
+ conditioningzeroout_287 = conditioningzeroout.zero_out(
288
+ conditioning=get_value_at_index(fluxguidance_288, 0)
289
+ )
290
+
291
+ florence2tocoordinates = NODE_CLASS_MAPPINGS["Florence2toCoordinates"]()
292
+ florence2tocoordinates_276 = florence2tocoordinates.segment(
293
+ index="", batch=False, data=get_value_at_index(florence2run_275, 3)
294
+ )
295
+
296
+ sam2segmentation = NODE_CLASS_MAPPINGS["Sam2Segmentation"]()
297
+ sam2segmentation_278 = sam2segmentation.segment(
298
+ keep_model_loaded=False,
299
+ individual_objects=False,
300
+ sam2_model=get_value_at_index(downloadandloadsam2model_277, 0),
301
+ image=get_value_at_index(florence2run_275, 0),
302
+ bboxes=get_value_at_index(florence2tocoordinates_276, 1),
303
+ )
304
+
305
+ growmask = NODE_CLASS_MAPPINGS["GrowMask"]()
306
+ growmask_299 = growmask.expand_mask(
307
+ expand=35,
308
+ tapered_corners=True,
309
+ mask=get_value_at_index(sam2segmentation_278, 0),
310
+ )
311
+
312
+ layermask_segformerb2clothesultra = NODE_CLASS_MAPPINGS["LayerMask: SegformerB2ClothesUltra"]()
313
+ layermask_segformerb2clothesultra_300 = layermask_segformerb2clothesultra.segformer_ultra(
314
+ face=True,
315
+ hair=False,
316
+ hat=False,
317
+ sunglass=False,
318
+ left_arm=False,
319
+ right_arm=False,
320
+ left_leg=False,
321
+ right_leg=False,
322
+ upper_clothes=True,
323
+ skirt=False,
324
+ pants=False,
325
+ dress=False,
326
+ belt=False,
327
+ shoe=False,
328
+ bag=False,
329
+ scarf=True,
330
+ detail_method="VITMatte",
331
+ detail_erode=12,
332
+ detail_dilate=6,
333
+ black_point=0.15,
334
+ white_point=0.99,
335
+ process_detail=True,
336
+ device="cuda",
337
+ max_megapixels=2,
338
+ image=get_value_at_index(catvton_121, 0),
339
+ )
340
+
341
+ masks_subtract = NODE_CLASS_MAPPINGS["Masks Subtract"]()
342
+ masks_subtract_296 = masks_subtract.subtract_masks(
343
+ masks_a=get_value_at_index(growmask_299, 0),
344
+ masks_b=get_value_at_index(layermask_segformerb2clothesultra_300, 1),
345
+ )
346
+
347
+ inpaintmodelconditioning = NODE_CLASS_MAPPINGS["InpaintModelConditioning"]()
348
+ inpaintmodelconditioning_286 = inpaintmodelconditioning.encode(
349
+ noise_mask=True,
350
+ positive=get_value_at_index(fluxguidance_288, 0),
351
+ negative=get_value_at_index(conditioningzeroout_287, 0),
352
+ vae=get_value_at_index(vaeloader_285, 0),
353
+ pixels=get_value_at_index(catvton_121, 0),
354
+ mask=get_value_at_index(masks_subtract_296, 0),
355
+ )
356
+
357
+ unetloader = NODE_CLASS_MAPPINGS["UNETLoader"]()
358
+ unetloader_291 = unetloader.load_unet(
359
+ unet_name="flux1-fill-dev.safetensors", weight_dtype="default"
360
+ )
361
+
362
+ loraloadermodelonly = NODE_CLASS_MAPPINGS["LoraLoaderModelOnly"]()
363
+ loraloadermodelonly_290 = loraloadermodelonly.load_lora_model_only(
364
+ lora_name="alimama-flux-turbo-alpha.safetensors",
365
+ strength_model=1,
366
+ model=get_value_at_index(unetloader_291, 0),
367
+ )
368
+
369
+ ksampler = NODE_CLASS_MAPPINGS["KSampler"]()
370
+ vaedecode = NODE_CLASS_MAPPINGS["VAEDecode"]()
371
+ saveimage = NODE_CLASS_MAPPINGS["SaveImage"]()
372
+
373
+ # We'll do a single pass
374
+ for q in range(1):
375
+ ksampler_283 = ksampler.sample(
376
+ seed=random.randint(1, 2**64),
377
+ steps=10,
378
+ cfg=1,
379
+ sampler_name="dpmpp_2m",
380
+ scheduler="sgm_uniform",
381
+ denoise=1,
382
+ model=get_value_at_index(loraloadermodelonly_290, 0),
383
+ positive=get_value_at_index(inpaintmodelconditioning_286, 0),
384
+ negative=get_value_at_index(inpaintmodelconditioning_286, 1),
385
+ latent_image=get_value_at_index(inpaintmodelconditioning_286, 2),
386
+ )
387
+
388
+ vaedecode_294 = vaedecode.decode(
389
+ samples=get_value_at_index(ksampler_283, 0),
390
+ vae=get_value_at_index(vaeloader_285, 0),
391
+ )
392
+
393
+ saveimage_295 = saveimage.save_images(
394
+ filename_prefix="The_Broccolator_",
395
+ images=get_value_at_index(vaedecode_294, 0),
396
+ )
397
+
398
+ # final_output_path is the only one we return
399
+ final_output_path = f"output/{saveimage_295['ui']['images'][0]['filename']}"
400
+ return final_output_path
401
+
402
+
403
+ ###################################
404
+ # A simple Gradio interface
405
+ ###################################
406
+ with gr.Blocks() as demo:
407
+ gr.Markdown("## The Broccolator 🥦\nUpload an image for `loadimage_264` and see final output.")
408
+ with gr.Row():
409
+ with gr.Column():
410
+ user_input_image = gr.Image(type="filepath", label="Input Image")
411
+ btn_generate = gr.Button("Generate")
412
+ with gr.Column():
413
+ final_image = gr.Image(label="Final output (saveimage_295)")
414
+
415
+ btn_generate.click(
416
+ fn=generate_images,
417
+ inputs=user_input_image,
418
+ outputs=final_image
419
+ )
420
+
421
+ demo.launch(debug=True)
comfy_execution/caching.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ from typing import Sequence, Mapping, Dict
3
+ from comfy_execution.graph import DynamicPrompt
4
+
5
+ import nodes
6
+
7
+ from comfy_execution.graph_utils import is_link
8
+
9
+ NODE_CLASS_CONTAINS_UNIQUE_ID: Dict[str, bool] = {}
10
+
11
+
12
+ def include_unique_id_in_input(class_type: str) -> bool:
13
+ if class_type in NODE_CLASS_CONTAINS_UNIQUE_ID:
14
+ return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
15
+ class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
16
+ NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] = "UNIQUE_ID" in class_def.INPUT_TYPES().get("hidden", {}).values()
17
+ return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
18
+
19
+ class CacheKeySet:
20
+ def __init__(self, dynprompt, node_ids, is_changed_cache):
21
+ self.keys = {}
22
+ self.subcache_keys = {}
23
+
24
+ def add_keys(self, node_ids):
25
+ raise NotImplementedError()
26
+
27
+ def all_node_ids(self):
28
+ return set(self.keys.keys())
29
+
30
+ def get_used_keys(self):
31
+ return self.keys.values()
32
+
33
+ def get_used_subcache_keys(self):
34
+ return self.subcache_keys.values()
35
+
36
+ def get_data_key(self, node_id):
37
+ return self.keys.get(node_id, None)
38
+
39
+ def get_subcache_key(self, node_id):
40
+ return self.subcache_keys.get(node_id, None)
41
+
42
+ class Unhashable:
43
+ def __init__(self):
44
+ self.value = float("NaN")
45
+
46
+ def to_hashable(obj):
47
+ # So that we don't infinitely recurse since frozenset and tuples
48
+ # are Sequences.
49
+ if isinstance(obj, (int, float, str, bool, type(None))):
50
+ return obj
51
+ elif isinstance(obj, Mapping):
52
+ return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())])
53
+ elif isinstance(obj, Sequence):
54
+ return frozenset(zip(itertools.count(), [to_hashable(i) for i in obj]))
55
+ else:
56
+ # TODO - Support other objects like tensors?
57
+ return Unhashable()
58
+
59
+ class CacheKeySetID(CacheKeySet):
60
+ def __init__(self, dynprompt, node_ids, is_changed_cache):
61
+ super().__init__(dynprompt, node_ids, is_changed_cache)
62
+ self.dynprompt = dynprompt
63
+ self.add_keys(node_ids)
64
+
65
+ def add_keys(self, node_ids):
66
+ for node_id in node_ids:
67
+ if node_id in self.keys:
68
+ continue
69
+ if not self.dynprompt.has_node(node_id):
70
+ continue
71
+ node = self.dynprompt.get_node(node_id)
72
+ self.keys[node_id] = (node_id, node["class_type"])
73
+ self.subcache_keys[node_id] = (node_id, node["class_type"])
74
+
75
+ class CacheKeySetInputSignature(CacheKeySet):
76
+ def __init__(self, dynprompt, node_ids, is_changed_cache):
77
+ super().__init__(dynprompt, node_ids, is_changed_cache)
78
+ self.dynprompt = dynprompt
79
+ self.is_changed_cache = is_changed_cache
80
+ self.add_keys(node_ids)
81
+
82
+ def include_node_id_in_input(self) -> bool:
83
+ return False
84
+
85
+ def add_keys(self, node_ids):
86
+ for node_id in node_ids:
87
+ if node_id in self.keys:
88
+ continue
89
+ if not self.dynprompt.has_node(node_id):
90
+ continue
91
+ node = self.dynprompt.get_node(node_id)
92
+ self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id)
93
+ self.subcache_keys[node_id] = (node_id, node["class_type"])
94
+
95
+ def get_node_signature(self, dynprompt, node_id):
96
+ signature = []
97
+ ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id)
98
+ signature.append(self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
99
+ for ancestor_id in ancestors:
100
+ signature.append(self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
101
+ return to_hashable(signature)
102
+
103
+ def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
104
+ if not dynprompt.has_node(node_id):
105
+ # This node doesn't exist -- we can't cache it.
106
+ return [float("NaN")]
107
+ node = dynprompt.get_node(node_id)
108
+ class_type = node["class_type"]
109
+ class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
110
+ signature = [class_type, self.is_changed_cache.get(node_id)]
111
+ if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
112
+ signature.append(node_id)
113
+ inputs = node["inputs"]
114
+ for key in sorted(inputs.keys()):
115
+ if is_link(inputs[key]):
116
+ (ancestor_id, ancestor_socket) = inputs[key]
117
+ ancestor_index = ancestor_order_mapping[ancestor_id]
118
+ signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket)))
119
+ else:
120
+ signature.append((key, inputs[key]))
121
+ return signature
122
+
123
+ # This function returns a list of all ancestors of the given node. The order of the list is
124
+ # deterministic based on which specific inputs the ancestor is connected by.
125
+ def get_ordered_ancestry(self, dynprompt, node_id):
126
+ ancestors = []
127
+ order_mapping = {}
128
+ self.get_ordered_ancestry_internal(dynprompt, node_id, ancestors, order_mapping)
129
+ return ancestors, order_mapping
130
+
131
+ def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping):
132
+ if not dynprompt.has_node(node_id):
133
+ return
134
+ inputs = dynprompt.get_node(node_id)["inputs"]
135
+ input_keys = sorted(inputs.keys())
136
+ for key in input_keys:
137
+ if is_link(inputs[key]):
138
+ ancestor_id = inputs[key][0]
139
+ if ancestor_id not in order_mapping:
140
+ ancestors.append(ancestor_id)
141
+ order_mapping[ancestor_id] = len(ancestors) - 1
142
+ self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping)
143
+
144
+ class BasicCache:
145
+ def __init__(self, key_class):
146
+ self.key_class = key_class
147
+ self.initialized = False
148
+ self.dynprompt: DynamicPrompt
149
+ self.cache_key_set: CacheKeySet
150
+ self.cache = {}
151
+ self.subcaches = {}
152
+
153
+ def set_prompt(self, dynprompt, node_ids, is_changed_cache):
154
+ self.dynprompt = dynprompt
155
+ self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
156
+ self.is_changed_cache = is_changed_cache
157
+ self.initialized = True
158
+
159
+ def all_node_ids(self):
160
+ assert self.initialized
161
+ node_ids = self.cache_key_set.all_node_ids()
162
+ for subcache in self.subcaches.values():
163
+ node_ids = node_ids.union(subcache.all_node_ids())
164
+ return node_ids
165
+
166
+ def _clean_cache(self):
167
+ preserve_keys = set(self.cache_key_set.get_used_keys())
168
+ to_remove = []
169
+ for key in self.cache:
170
+ if key not in preserve_keys:
171
+ to_remove.append(key)
172
+ for key in to_remove:
173
+ del self.cache[key]
174
+
175
+ def _clean_subcaches(self):
176
+ preserve_subcaches = set(self.cache_key_set.get_used_subcache_keys())
177
+
178
+ to_remove = []
179
+ for key in self.subcaches:
180
+ if key not in preserve_subcaches:
181
+ to_remove.append(key)
182
+ for key in to_remove:
183
+ del self.subcaches[key]
184
+
185
+ def clean_unused(self):
186
+ assert self.initialized
187
+ self._clean_cache()
188
+ self._clean_subcaches()
189
+
190
+ def _set_immediate(self, node_id, value):
191
+ assert self.initialized
192
+ cache_key = self.cache_key_set.get_data_key(node_id)
193
+ self.cache[cache_key] = value
194
+
195
+ def _get_immediate(self, node_id):
196
+ if not self.initialized:
197
+ return None
198
+ cache_key = self.cache_key_set.get_data_key(node_id)
199
+ if cache_key in self.cache:
200
+ return self.cache[cache_key]
201
+ else:
202
+ return None
203
+
204
+ def _ensure_subcache(self, node_id, children_ids):
205
+ subcache_key = self.cache_key_set.get_subcache_key(node_id)
206
+ subcache = self.subcaches.get(subcache_key, None)
207
+ if subcache is None:
208
+ subcache = BasicCache(self.key_class)
209
+ self.subcaches[subcache_key] = subcache
210
+ subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
211
+ return subcache
212
+
213
+ def _get_subcache(self, node_id):
214
+ assert self.initialized
215
+ subcache_key = self.cache_key_set.get_subcache_key(node_id)
216
+ if subcache_key in self.subcaches:
217
+ return self.subcaches[subcache_key]
218
+ else:
219
+ return None
220
+
221
+ def recursive_debug_dump(self):
222
+ result = []
223
+ for key in self.cache:
224
+ result.append({"key": key, "value": self.cache[key]})
225
+ for key in self.subcaches:
226
+ result.append({"subcache_key": key, "subcache": self.subcaches[key].recursive_debug_dump()})
227
+ return result
228
+
229
+ class HierarchicalCache(BasicCache):
230
+ def __init__(self, key_class):
231
+ super().__init__(key_class)
232
+
233
+ def _get_cache_for(self, node_id):
234
+ assert self.dynprompt is not None
235
+ parent_id = self.dynprompt.get_parent_node_id(node_id)
236
+ if parent_id is None:
237
+ return self
238
+
239
+ hierarchy = []
240
+ while parent_id is not None:
241
+ hierarchy.append(parent_id)
242
+ parent_id = self.dynprompt.get_parent_node_id(parent_id)
243
+
244
+ cache = self
245
+ for parent_id in reversed(hierarchy):
246
+ cache = cache._get_subcache(parent_id)
247
+ if cache is None:
248
+ return None
249
+ return cache
250
+
251
+ def get(self, node_id):
252
+ cache = self._get_cache_for(node_id)
253
+ if cache is None:
254
+ return None
255
+ return cache._get_immediate(node_id)
256
+
257
+ def set(self, node_id, value):
258
+ cache = self._get_cache_for(node_id)
259
+ assert cache is not None
260
+ cache._set_immediate(node_id, value)
261
+
262
+ def ensure_subcache_for(self, node_id, children_ids):
263
+ cache = self._get_cache_for(node_id)
264
+ assert cache is not None
265
+ return cache._ensure_subcache(node_id, children_ids)
266
+
267
+ class LRUCache(BasicCache):
268
+ def __init__(self, key_class, max_size=100):
269
+ super().__init__(key_class)
270
+ self.max_size = max_size
271
+ self.min_generation = 0
272
+ self.generation = 0
273
+ self.used_generation = {}
274
+ self.children = {}
275
+
276
+ def set_prompt(self, dynprompt, node_ids, is_changed_cache):
277
+ super().set_prompt(dynprompt, node_ids, is_changed_cache)
278
+ self.generation += 1
279
+ for node_id in node_ids:
280
+ self._mark_used(node_id)
281
+
282
+ def clean_unused(self):
283
+ while len(self.cache) > self.max_size and self.min_generation < self.generation:
284
+ self.min_generation += 1
285
+ to_remove = [key for key in self.cache if self.used_generation[key] < self.min_generation]
286
+ for key in to_remove:
287
+ del self.cache[key]
288
+ del self.used_generation[key]
289
+ if key in self.children:
290
+ del self.children[key]
291
+ self._clean_subcaches()
292
+
293
+ def get(self, node_id):
294
+ self._mark_used(node_id)
295
+ return self._get_immediate(node_id)
296
+
297
+ def _mark_used(self, node_id):
298
+ cache_key = self.cache_key_set.get_data_key(node_id)
299
+ if cache_key is not None:
300
+ self.used_generation[cache_key] = self.generation
301
+
302
+ def set(self, node_id, value):
303
+ self._mark_used(node_id)
304
+ return self._set_immediate(node_id, value)
305
+
306
+ def ensure_subcache_for(self, node_id, children_ids):
307
+ # Just uses subcaches for tracking 'live' nodes
308
+ super()._ensure_subcache(node_id, children_ids)
309
+
310
+ self.cache_key_set.add_keys(children_ids)
311
+ self._mark_used(node_id)
312
+ cache_key = self.cache_key_set.get_data_key(node_id)
313
+ self.children[cache_key] = []
314
+ for child_id in children_ids:
315
+ self._mark_used(child_id)
316
+ self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
317
+ return self
318
+
comfy_execution/graph.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import nodes
2
+
3
+ from comfy_execution.graph_utils import is_link
4
+
5
+ class DependencyCycleError(Exception):
6
+ pass
7
+
8
+ class NodeInputError(Exception):
9
+ pass
10
+
11
+ class NodeNotFoundError(Exception):
12
+ pass
13
+
14
+ class DynamicPrompt:
15
+ def __init__(self, original_prompt):
16
+ # The original prompt provided by the user
17
+ self.original_prompt = original_prompt
18
+ # Any extra pieces of the graph created during execution
19
+ self.ephemeral_prompt = {}
20
+ self.ephemeral_parents = {}
21
+ self.ephemeral_display = {}
22
+
23
+ def get_node(self, node_id):
24
+ if node_id in self.ephemeral_prompt:
25
+ return self.ephemeral_prompt[node_id]
26
+ if node_id in self.original_prompt:
27
+ return self.original_prompt[node_id]
28
+ raise NodeNotFoundError(f"Node {node_id} not found")
29
+
30
+ def has_node(self, node_id):
31
+ return node_id in self.original_prompt or node_id in self.ephemeral_prompt
32
+
33
+ def add_ephemeral_node(self, node_id, node_info, parent_id, display_id):
34
+ self.ephemeral_prompt[node_id] = node_info
35
+ self.ephemeral_parents[node_id] = parent_id
36
+ self.ephemeral_display[node_id] = display_id
37
+
38
+ def get_real_node_id(self, node_id):
39
+ while node_id in self.ephemeral_parents:
40
+ node_id = self.ephemeral_parents[node_id]
41
+ return node_id
42
+
43
+ def get_parent_node_id(self, node_id):
44
+ return self.ephemeral_parents.get(node_id, None)
45
+
46
+ def get_display_node_id(self, node_id):
47
+ while node_id in self.ephemeral_display:
48
+ node_id = self.ephemeral_display[node_id]
49
+ return node_id
50
+
51
+ def all_node_ids(self):
52
+ return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys()))
53
+
54
+ def get_original_prompt(self):
55
+ return self.original_prompt
56
+
57
+ def get_input_info(class_def, input_name, valid_inputs=None):
58
+ valid_inputs = valid_inputs or class_def.INPUT_TYPES()
59
+ input_info = None
60
+ input_category = None
61
+ if "required" in valid_inputs and input_name in valid_inputs["required"]:
62
+ input_category = "required"
63
+ input_info = valid_inputs["required"][input_name]
64
+ elif "optional" in valid_inputs and input_name in valid_inputs["optional"]:
65
+ input_category = "optional"
66
+ input_info = valid_inputs["optional"][input_name]
67
+ elif "hidden" in valid_inputs and input_name in valid_inputs["hidden"]:
68
+ input_category = "hidden"
69
+ input_info = valid_inputs["hidden"][input_name]
70
+ if input_info is None:
71
+ return None, None, None
72
+ input_type = input_info[0]
73
+ if len(input_info) > 1:
74
+ extra_info = input_info[1]
75
+ else:
76
+ extra_info = {}
77
+ return input_type, input_category, extra_info
78
+
79
+ class TopologicalSort:
80
+ def __init__(self, dynprompt):
81
+ self.dynprompt = dynprompt
82
+ self.pendingNodes = {}
83
+ self.blockCount = {} # Number of nodes this node is directly blocked by
84
+ self.blocking = {} # Which nodes are blocked by this node
85
+
86
+ def get_input_info(self, unique_id, input_name):
87
+ class_type = self.dynprompt.get_node(unique_id)["class_type"]
88
+ class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
89
+ return get_input_info(class_def, input_name)
90
+
91
+ def make_input_strong_link(self, to_node_id, to_input):
92
+ inputs = self.dynprompt.get_node(to_node_id)["inputs"]
93
+ if to_input not in inputs:
94
+ raise NodeInputError(f"Node {to_node_id} says it needs input {to_input}, but there is no input to that node at all")
95
+ value = inputs[to_input]
96
+ if not is_link(value):
97
+ raise NodeInputError(f"Node {to_node_id} says it needs input {to_input}, but that value is a constant")
98
+ from_node_id, from_socket = value
99
+ self.add_strong_link(from_node_id, from_socket, to_node_id)
100
+
101
+ def add_strong_link(self, from_node_id, from_socket, to_node_id):
102
+ if not self.is_cached(from_node_id):
103
+ self.add_node(from_node_id)
104
+ if to_node_id not in self.blocking[from_node_id]:
105
+ self.blocking[from_node_id][to_node_id] = {}
106
+ self.blockCount[to_node_id] += 1
107
+ self.blocking[from_node_id][to_node_id][from_socket] = True
108
+
109
+ def add_node(self, node_unique_id, include_lazy=False, subgraph_nodes=None):
110
+ node_ids = [node_unique_id]
111
+ links = []
112
+
113
+ while len(node_ids) > 0:
114
+ unique_id = node_ids.pop()
115
+ if unique_id in self.pendingNodes:
116
+ continue
117
+
118
+ self.pendingNodes[unique_id] = True
119
+ self.blockCount[unique_id] = 0
120
+ self.blocking[unique_id] = {}
121
+
122
+ inputs = self.dynprompt.get_node(unique_id)["inputs"]
123
+ for input_name in inputs:
124
+ value = inputs[input_name]
125
+ if is_link(value):
126
+ from_node_id, from_socket = value
127
+ if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
128
+ continue
129
+ input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
130
+ is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
131
+ if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
132
+ node_ids.append(from_node_id)
133
+ links.append((from_node_id, from_socket, unique_id))
134
+
135
+ for link in links:
136
+ self.add_strong_link(*link)
137
+
138
+ def is_cached(self, node_id):
139
+ return False
140
+
141
+ def get_ready_nodes(self):
142
+ return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0]
143
+
144
+ def pop_node(self, unique_id):
145
+ del self.pendingNodes[unique_id]
146
+ for blocked_node_id in self.blocking[unique_id]:
147
+ self.blockCount[blocked_node_id] -= 1
148
+ del self.blocking[unique_id]
149
+
150
+ def is_empty(self):
151
+ return len(self.pendingNodes) == 0
152
+
153
+ class ExecutionList(TopologicalSort):
154
+ """
155
+ ExecutionList implements a topological dissolve of the graph. After a node is staged for execution,
156
+ it can still be returned to the graph after having further dependencies added.
157
+ """
158
+ def __init__(self, dynprompt, output_cache):
159
+ super().__init__(dynprompt)
160
+ self.output_cache = output_cache
161
+ self.staged_node_id = None
162
+
163
+ def is_cached(self, node_id):
164
+ return self.output_cache.get(node_id) is not None
165
+
166
+ def stage_node_execution(self):
167
+ assert self.staged_node_id is None
168
+ if self.is_empty():
169
+ return None, None, None
170
+ available = self.get_ready_nodes()
171
+ if len(available) == 0:
172
+ cycled_nodes = self.get_nodes_in_cycle()
173
+ # Because cycles composed entirely of static nodes are caught during initial validation,
174
+ # we will 'blame' the first node in the cycle that is not a static node.
175
+ blamed_node = cycled_nodes[0]
176
+ for node_id in cycled_nodes:
177
+ display_node_id = self.dynprompt.get_display_node_id(node_id)
178
+ if display_node_id != node_id:
179
+ blamed_node = display_node_id
180
+ break
181
+ ex = DependencyCycleError("Dependency cycle detected")
182
+ error_details = {
183
+ "node_id": blamed_node,
184
+ "exception_message": str(ex),
185
+ "exception_type": "graph.DependencyCycleError",
186
+ "traceback": [],
187
+ "current_inputs": []
188
+ }
189
+ return None, error_details, ex
190
+
191
+ self.staged_node_id = self.ux_friendly_pick_node(available)
192
+ return self.staged_node_id, None, None
193
+
194
+ def ux_friendly_pick_node(self, node_list):
195
+ # If an output node is available, do that first.
196
+ # Technically this has no effect on the overall length of execution, but it feels better as a user
197
+ # for a PreviewImage to display a result as soon as it can
198
+ # Some other heuristics could probably be used here to improve the UX further.
199
+ def is_output(node_id):
200
+ class_type = self.dynprompt.get_node(node_id)["class_type"]
201
+ class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
202
+ if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
203
+ return True
204
+ return False
205
+
206
+ for node_id in node_list:
207
+ if is_output(node_id):
208
+ return node_id
209
+
210
+ #This should handle the VAEDecode -> preview case
211
+ for node_id in node_list:
212
+ for blocked_node_id in self.blocking[node_id]:
213
+ if is_output(blocked_node_id):
214
+ return node_id
215
+
216
+ #This should handle the VAELoader -> VAEDecode -> preview case
217
+ for node_id in node_list:
218
+ for blocked_node_id in self.blocking[node_id]:
219
+ for blocked_node_id1 in self.blocking[blocked_node_id]:
220
+ if is_output(blocked_node_id1):
221
+ return node_id
222
+
223
+ #TODO: this function should be improved
224
+ return node_list[0]
225
+
226
+ def unstage_node_execution(self):
227
+ assert self.staged_node_id is not None
228
+ self.staged_node_id = None
229
+
230
+ def complete_node_execution(self):
231
+ node_id = self.staged_node_id
232
+ self.pop_node(node_id)
233
+ self.staged_node_id = None
234
+
235
+ def get_nodes_in_cycle(self):
236
+ # We'll dissolve the graph in reverse topological order to leave only the nodes in the cycle.
237
+ # We're skipping some of the performance optimizations from the original TopologicalSort to keep
238
+ # the code simple (and because having a cycle in the first place is a catastrophic error)
239
+ blocked_by = { node_id: {} for node_id in self.pendingNodes }
240
+ for from_node_id in self.blocking:
241
+ for to_node_id in self.blocking[from_node_id]:
242
+ if True in self.blocking[from_node_id][to_node_id].values():
243
+ blocked_by[to_node_id][from_node_id] = True
244
+ to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0]
245
+ while len(to_remove) > 0:
246
+ for node_id in to_remove:
247
+ for to_node_id in blocked_by:
248
+ if node_id in blocked_by[to_node_id]:
249
+ del blocked_by[to_node_id][node_id]
250
+ del blocked_by[node_id]
251
+ to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0]
252
+ return list(blocked_by.keys())
253
+
254
+ class ExecutionBlocker:
255
+ """
256
+ Return this from a node and any users will be blocked with the given error message.
257
+ If the message is None, execution will be blocked silently instead.
258
+ Generally, you should avoid using this functionality unless absolutely necessary. Whenever it's
259
+ possible, a lazy input will be more efficient and have a better user experience.
260
+ This functionality is useful in two cases:
261
+ 1. You want to conditionally prevent an output node from executing. (Particularly a built-in node
262
+ like SaveImage. For your own output nodes, I would recommend just adding a BOOL input and using
263
+ lazy evaluation to let it conditionally disable itself.)
264
+ 2. You have a node with multiple possible outputs, some of which are invalid and should not be used.
265
+ (I would recommend not making nodes like this in the future -- instead, make multiple nodes with
266
+ different outputs. Unfortunately, there are several popular existing nodes using this pattern.)
267
+ """
268
+ def __init__(self, message):
269
+ self.message = message
270
+
comfy_execution/graph_utils.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def is_link(obj):
2
+ if not isinstance(obj, list):
3
+ return False
4
+ if len(obj) != 2:
5
+ return False
6
+ if not isinstance(obj[0], str):
7
+ return False
8
+ if not isinstance(obj[1], int) and not isinstance(obj[1], float):
9
+ return False
10
+ return True
11
+
12
+ # The GraphBuilder is just a utility class that outputs graphs in the form expected by the ComfyUI back-end
13
+ class GraphBuilder:
14
+ _default_prefix_root = ""
15
+ _default_prefix_call_index = 0
16
+ _default_prefix_graph_index = 0
17
+
18
+ def __init__(self, prefix = None):
19
+ if prefix is None:
20
+ self.prefix = GraphBuilder.alloc_prefix()
21
+ else:
22
+ self.prefix = prefix
23
+ self.nodes = {}
24
+ self.id_gen = 1
25
+
26
+ @classmethod
27
+ def set_default_prefix(cls, prefix_root, call_index, graph_index = 0):
28
+ cls._default_prefix_root = prefix_root
29
+ cls._default_prefix_call_index = call_index
30
+ cls._default_prefix_graph_index = graph_index
31
+
32
+ @classmethod
33
+ def alloc_prefix(cls, root=None, call_index=None, graph_index=None):
34
+ if root is None:
35
+ root = GraphBuilder._default_prefix_root
36
+ if call_index is None:
37
+ call_index = GraphBuilder._default_prefix_call_index
38
+ if graph_index is None:
39
+ graph_index = GraphBuilder._default_prefix_graph_index
40
+ result = f"{root}.{call_index}.{graph_index}."
41
+ GraphBuilder._default_prefix_graph_index += 1
42
+ return result
43
+
44
+ def node(self, class_type, id=None, **kwargs):
45
+ if id is None:
46
+ id = str(self.id_gen)
47
+ self.id_gen += 1
48
+ id = self.prefix + id
49
+ if id in self.nodes:
50
+ return self.nodes[id]
51
+
52
+ node = Node(id, class_type, kwargs)
53
+ self.nodes[id] = node
54
+ return node
55
+
56
+ def lookup_node(self, id):
57
+ id = self.prefix + id
58
+ return self.nodes.get(id)
59
+
60
+ def finalize(self):
61
+ output = {}
62
+ for node_id, node in self.nodes.items():
63
+ output[node_id] = node.serialize()
64
+ return output
65
+
66
+ def replace_node_output(self, node_id, index, new_value):
67
+ node_id = self.prefix + node_id
68
+ to_remove = []
69
+ for node in self.nodes.values():
70
+ for key, value in node.inputs.items():
71
+ if is_link(value) and value[0] == node_id and value[1] == index:
72
+ if new_value is None:
73
+ to_remove.append((node, key))
74
+ else:
75
+ node.inputs[key] = new_value
76
+ for node, key in to_remove:
77
+ del node.inputs[key]
78
+
79
+ def remove_node(self, id):
80
+ id = self.prefix + id
81
+ del self.nodes[id]
82
+
83
+ class Node:
84
+ def __init__(self, id, class_type, inputs):
85
+ self.id = id
86
+ self.class_type = class_type
87
+ self.inputs = inputs
88
+ self.override_display_id = None
89
+
90
+ def out(self, index):
91
+ return [self.id, index]
92
+
93
+ def set_input(self, key, value):
94
+ if value is None:
95
+ if key in self.inputs:
96
+ del self.inputs[key]
97
+ else:
98
+ self.inputs[key] = value
99
+
100
+ def get_input(self, key):
101
+ return self.inputs.get(key)
102
+
103
+ def set_override_display_id(self, override_display_id):
104
+ self.override_display_id = override_display_id
105
+
106
+ def serialize(self):
107
+ serialized = {
108
+ "class_type": self.class_type,
109
+ "inputs": self.inputs
110
+ }
111
+ if self.override_display_id is not None:
112
+ serialized["override_display_id"] = self.override_display_id
113
+ return serialized
114
+
115
+ def add_graph_prefix(graph, outputs, prefix):
116
+ # Change the node IDs and any internal links
117
+ new_graph = {}
118
+ for node_id, node_info in graph.items():
119
+ # Make sure the added nodes have unique IDs
120
+ new_node_id = prefix + node_id
121
+ new_node = { "class_type": node_info["class_type"], "inputs": {} }
122
+ for input_name, input_value in node_info.get("inputs", {}).items():
123
+ if is_link(input_value):
124
+ new_node["inputs"][input_name] = [prefix + input_value[0], input_value[1]]
125
+ else:
126
+ new_node["inputs"][input_name] = input_value
127
+ new_graph[new_node_id] = new_node
128
+
129
+ # Change the node IDs in the outputs
130
+ new_outputs = []
131
+ for n in range(len(outputs)):
132
+ output = outputs[n]
133
+ if is_link(output):
134
+ new_outputs.append([prefix + output[0], output[1]])
135
+ else:
136
+ new_outputs.append(output)
137
+
138
+ return new_graph, tuple(new_outputs)
139
+
comfy_execution/validation.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+
4
+ def validate_node_input(
5
+ received_type: str, input_type: str, strict: bool = False
6
+ ) -> bool:
7
+ """
8
+ received_type and input_type are both strings of the form "T1,T2,...".
9
+
10
+ If strict is True, the input_type must contain the received_type.
11
+ For example, if received_type is "STRING" and input_type is "STRING,INT",
12
+ this will return True. But if received_type is "STRING,INT" and input_type is
13
+ "INT", this will return False.
14
+
15
+ If strict is False, the input_type must have overlap with the received_type.
16
+ For example, if received_type is "STRING,BOOLEAN" and input_type is "STRING,INT",
17
+ this will return True.
18
+
19
+ Supports pre-union type extension behaviour of ``__ne__`` overrides.
20
+ """
21
+ # If the types are exactly the same, we can return immediately
22
+ # Use pre-union behaviour: inverse of `__ne__`
23
+ if not received_type != input_type:
24
+ return True
25
+
26
+ # Not equal, and not strings
27
+ if not isinstance(received_type, str) or not isinstance(input_type, str):
28
+ return False
29
+
30
+ # Split the type strings into sets for comparison
31
+ received_types = set(t.strip() for t in received_type.split(","))
32
+ input_types = set(t.strip() for t in input_type.split(","))
33
+
34
+ if strict:
35
+ # In strict mode, all received types must be in the input types
36
+ return received_types.issubset(input_types)
37
+ else:
38
+ # In non-strict mode, there must be at least one type in common
39
+ return len(received_types.intersection(input_types)) > 0
comfyui_version.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # This file is automatically generated by the build process when version is
2
+ # updated in pyproject.toml.
3
+ __version__ = "0.3.12"
cuda_malloc.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import importlib.util
3
+ from comfy.cli_args import args
4
+ import subprocess
5
+
6
+ #Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import.
7
+ def get_gpu_names():
8
+ if os.name == 'nt':
9
+ import ctypes
10
+
11
+ # Define necessary C structures and types
12
+ class DISPLAY_DEVICEA(ctypes.Structure):
13
+ _fields_ = [
14
+ ('cb', ctypes.c_ulong),
15
+ ('DeviceName', ctypes.c_char * 32),
16
+ ('DeviceString', ctypes.c_char * 128),
17
+ ('StateFlags', ctypes.c_ulong),
18
+ ('DeviceID', ctypes.c_char * 128),
19
+ ('DeviceKey', ctypes.c_char * 128)
20
+ ]
21
+
22
+ # Load user32.dll
23
+ user32 = ctypes.windll.user32
24
+
25
+ # Call EnumDisplayDevicesA
26
+ def enum_display_devices():
27
+ device_info = DISPLAY_DEVICEA()
28
+ device_info.cb = ctypes.sizeof(device_info)
29
+ device_index = 0
30
+ gpu_names = set()
31
+
32
+ while user32.EnumDisplayDevicesA(None, device_index, ctypes.byref(device_info), 0):
33
+ device_index += 1
34
+ gpu_names.add(device_info.DeviceString.decode('utf-8'))
35
+ return gpu_names
36
+ return enum_display_devices()
37
+ else:
38
+ gpu_names = set()
39
+ out = subprocess.check_output(['nvidia-smi', '-L'])
40
+ for l in out.split(b'\n'):
41
+ if len(l) > 0:
42
+ gpu_names.add(l.decode('utf-8').split(' (UUID')[0])
43
+ return gpu_names
44
+
45
+ blacklist = {"GeForce GTX TITAN X", "GeForce GTX 980", "GeForce GTX 970", "GeForce GTX 960", "GeForce GTX 950", "GeForce 945M",
46
+ "GeForce 940M", "GeForce 930M", "GeForce 920M", "GeForce 910M", "GeForce GTX 750", "GeForce GTX 745", "Quadro K620",
47
+ "Quadro K1200", "Quadro K2200", "Quadro M500", "Quadro M520", "Quadro M600", "Quadro M620", "Quadro M1000",
48
+ "Quadro M1200", "Quadro M2000", "Quadro M2200", "Quadro M3000", "Quadro M4000", "Quadro M5000", "Quadro M5500", "Quadro M6000",
49
+ "GeForce MX110", "GeForce MX130", "GeForce 830M", "GeForce 840M", "GeForce GTX 850M", "GeForce GTX 860M",
50
+ "GeForce GTX 1650", "GeForce GTX 1630", "Tesla M4", "Tesla M6", "Tesla M10", "Tesla M40", "Tesla M60"
51
+ }
52
+
53
+ def cuda_malloc_supported():
54
+ try:
55
+ names = get_gpu_names()
56
+ except:
57
+ names = set()
58
+ for x in names:
59
+ if "NVIDIA" in x:
60
+ for b in blacklist:
61
+ if b in x:
62
+ return False
63
+ return True
64
+
65
+
66
+ if not args.cuda_malloc:
67
+ try:
68
+ version = ""
69
+ torch_spec = importlib.util.find_spec("torch")
70
+ for folder in torch_spec.submodule_search_locations:
71
+ ver_file = os.path.join(folder, "version.py")
72
+ if os.path.isfile(ver_file):
73
+ spec = importlib.util.spec_from_file_location("torch_version_import", ver_file)
74
+ module = importlib.util.module_from_spec(spec)
75
+ spec.loader.exec_module(module)
76
+ version = module.__version__
77
+ if int(version[0]) >= 2: #enable by default for torch version 2.0 and up
78
+ args.cuda_malloc = cuda_malloc_supported()
79
+ except:
80
+ pass
81
+
82
+
83
+ if args.cuda_malloc and not args.disable_cuda_malloc:
84
+ env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
85
+ if env_var is None:
86
+ env_var = "backend:cudaMallocAsync"
87
+ else:
88
+ env_var += ",backend:cudaMallocAsync"
89
+
90
+ os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
extra_model_paths.yaml.example ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #Rename this to extra_model_paths.yaml and ComfyUI will load it
2
+
3
+
4
+ #config for a1111 ui
5
+ #all you have to do is change the base_path to where yours is installed
6
+ a111:
7
+ base_path: path/to/stable-diffusion-webui/
8
+
9
+ checkpoints: models/Stable-diffusion
10
+ configs: models/Stable-diffusion
11
+ vae: models/VAE
12
+ loras: |
13
+ models/Lora
14
+ models/LyCORIS
15
+ upscale_models: |
16
+ models/ESRGAN
17
+ models/RealESRGAN
18
+ models/SwinIR
19
+ embeddings: embeddings
20
+ hypernetworks: models/hypernetworks
21
+ controlnet: models/ControlNet
22
+
23
+ #config for comfyui
24
+ #your base path should be either an existing comfy install or a central folder where you store all of your models, loras, etc.
25
+
26
+ #comfyui:
27
+ # base_path: path/to/comfyui/
28
+ # # You can use is_default to mark that these folders should be listed first, and used as the default dirs for eg downloads
29
+ # #is_default: true
30
+ # checkpoints: models/checkpoints/
31
+ # clip: models/clip/
32
+ # clip_vision: models/clip_vision/
33
+ # configs: models/configs/
34
+ # controlnet: models/controlnet/
35
+ # diffusion_models: |
36
+ # models/diffusion_models
37
+ # models/unet
38
+ # embeddings: models/embeddings/
39
+ # loras: models/loras/
40
+ # upscale_models: models/upscale_models/
41
+ # vae: models/vae/
42
+
43
+ #other_ui:
44
+ # base_path: path/to/ui
45
+ # checkpoints: models/checkpoints
46
+ # gligen: models/gligen
47
+ # custom_nodes: path/custom_nodes
fix_torch.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib.util
2
+ import shutil
3
+ import os
4
+ import ctypes
5
+ import logging
6
+
7
+
8
+ def fix_pytorch_libomp():
9
+ """
10
+ Fix PyTorch libomp DLL issue on Windows by copying the correct DLL file if needed.
11
+ """
12
+ torch_spec = importlib.util.find_spec("torch")
13
+ for folder in torch_spec.submodule_search_locations:
14
+ lib_folder = os.path.join(folder, "lib")
15
+ test_file = os.path.join(lib_folder, "fbgemm.dll")
16
+ dest = os.path.join(lib_folder, "libomp140.x86_64.dll")
17
+ if os.path.exists(dest):
18
+ break
19
+
20
+ with open(test_file, "rb") as f:
21
+ contents = f.read()
22
+ if b"libomp140.x86_64.dll" not in contents:
23
+ break
24
+ try:
25
+ ctypes.cdll.LoadLibrary(test_file)
26
+ except FileNotFoundError:
27
+ logging.warning("Detected pytorch version with libomp issue, patching.")
28
+ shutil.copyfile(os.path.join(lib_folder, "libiomp5md.dll"), dest)
folder_paths.py ADDED
@@ -0,0 +1,385 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import time
5
+ import mimetypes
6
+ import logging
7
+ from typing import Literal
8
+ from collections.abc import Collection
9
+
10
+ supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}
11
+
12
+ folder_names_and_paths: dict[str, tuple[list[str], set[str]]] = {}
13
+
14
+ base_path = os.path.dirname(os.path.realpath(__file__))
15
+ models_dir = os.path.join(base_path, "models")
16
+ folder_names_and_paths["checkpoints"] = ([os.path.join(models_dir, "checkpoints")], supported_pt_extensions)
17
+ folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".yaml"])
18
+
19
+ folder_names_and_paths["loras"] = ([os.path.join(models_dir, "loras")], supported_pt_extensions)
20
+ folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt_extensions)
21
+ folder_names_and_paths["text_encoders"] = ([os.path.join(models_dir, "text_encoders"), os.path.join(models_dir, "clip")], supported_pt_extensions)
22
+ folder_names_and_paths["diffusion_models"] = ([os.path.join(models_dir, "unet"), os.path.join(models_dir, "diffusion_models")], supported_pt_extensions)
23
+ folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions)
24
+ folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
25
+ folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions)
26
+ folder_names_and_paths["diffusers"] = ([os.path.join(models_dir, "diffusers")], ["folder"])
27
+ folder_names_and_paths["vae_approx"] = ([os.path.join(models_dir, "vae_approx")], supported_pt_extensions)
28
+
29
+ folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions)
30
+ folder_names_and_paths["gligen"] = ([os.path.join(models_dir, "gligen")], supported_pt_extensions)
31
+
32
+ folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions)
33
+
34
+ folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes")], set())
35
+
36
+ folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions)
37
+
38
+ folder_names_and_paths["photomaker"] = ([os.path.join(models_dir, "photomaker")], supported_pt_extensions)
39
+
40
+ folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""})
41
+
42
+ output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
43
+ temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
44
+ input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
45
+ user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user")
46
+
47
+ filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
48
+
49
+ class CacheHelper:
50
+ """
51
+ Helper class for managing file list cache data.
52
+ """
53
+ def __init__(self):
54
+ self.cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
55
+ self.active = False
56
+
57
+ def get(self, key: str, default=None) -> tuple[list[str], dict[str, float], float]:
58
+ if not self.active:
59
+ return default
60
+ return self.cache.get(key, default)
61
+
62
+ def set(self, key: str, value: tuple[list[str], dict[str, float], float]) -> None:
63
+ if self.active:
64
+ self.cache[key] = value
65
+
66
+ def clear(self):
67
+ self.cache.clear()
68
+
69
+ def __enter__(self):
70
+ self.active = True
71
+ return self
72
+
73
+ def __exit__(self, exc_type, exc_value, traceback):
74
+ self.active = False
75
+ self.clear()
76
+
77
+ cache_helper = CacheHelper()
78
+
79
+ extension_mimetypes_cache = {
80
+ "webp" : "image",
81
+ }
82
+
83
+ def map_legacy(folder_name: str) -> str:
84
+ legacy = {"unet": "diffusion_models",
85
+ "clip": "text_encoders"}
86
+ return legacy.get(folder_name, folder_name)
87
+
88
+ if not os.path.exists(input_directory):
89
+ try:
90
+ os.makedirs(input_directory)
91
+ except:
92
+ logging.error("Failed to create input directory")
93
+
94
+ def set_output_directory(output_dir: str) -> None:
95
+ global output_directory
96
+ output_directory = output_dir
97
+
98
+ def set_temp_directory(temp_dir: str) -> None:
99
+ global temp_directory
100
+ temp_directory = temp_dir
101
+
102
+ def set_input_directory(input_dir: str) -> None:
103
+ global input_directory
104
+ input_directory = input_dir
105
+
106
+ def get_output_directory() -> str:
107
+ global output_directory
108
+ return output_directory
109
+
110
+ def get_temp_directory() -> str:
111
+ global temp_directory
112
+ return temp_directory
113
+
114
+ def get_input_directory() -> str:
115
+ global input_directory
116
+ return input_directory
117
+
118
+ def get_user_directory() -> str:
119
+ return user_directory
120
+
121
+ def set_user_directory(user_dir: str) -> None:
122
+ global user_directory
123
+ user_directory = user_dir
124
+
125
+
126
+ #NOTE: used in http server so don't put folders that should not be accessed remotely
127
+ def get_directory_by_type(type_name: str) -> str | None:
128
+ if type_name == "output":
129
+ return get_output_directory()
130
+ if type_name == "temp":
131
+ return get_temp_directory()
132
+ if type_name == "input":
133
+ return get_input_directory()
134
+ return None
135
+
136
+ def filter_files_content_types(files: list[str], content_types: Literal["image", "video", "audio"]) -> list[str]:
137
+ """
138
+ Example:
139
+ files = os.listdir(folder_paths.get_input_directory())
140
+ filter_files_content_types(files, ["image", "audio", "video"])
141
+ """
142
+ global extension_mimetypes_cache
143
+ result = []
144
+ for file in files:
145
+ extension = file.split('.')[-1]
146
+ if extension not in extension_mimetypes_cache:
147
+ mime_type, _ = mimetypes.guess_type(file, strict=False)
148
+ if not mime_type:
149
+ continue
150
+ content_type = mime_type.split('/')[0]
151
+ extension_mimetypes_cache[extension] = content_type
152
+ else:
153
+ content_type = extension_mimetypes_cache[extension]
154
+
155
+ if content_type in content_types:
156
+ result.append(file)
157
+ return result
158
+
159
+ # determine base_dir rely on annotation if name is 'filename.ext [annotation]' format
160
+ # otherwise use default_path as base_dir
161
+ def annotated_filepath(name: str) -> tuple[str, str | None]:
162
+ if name.endswith("[output]"):
163
+ base_dir = get_output_directory()
164
+ name = name[:-9]
165
+ elif name.endswith("[input]"):
166
+ base_dir = get_input_directory()
167
+ name = name[:-8]
168
+ elif name.endswith("[temp]"):
169
+ base_dir = get_temp_directory()
170
+ name = name[:-7]
171
+ else:
172
+ return name, None
173
+
174
+ return name, base_dir
175
+
176
+
177
+ def get_annotated_filepath(name: str, default_dir: str | None=None) -> str:
178
+ name, base_dir = annotated_filepath(name)
179
+
180
+ if base_dir is None:
181
+ if default_dir is not None:
182
+ base_dir = default_dir
183
+ else:
184
+ base_dir = get_input_directory() # fallback path
185
+
186
+ return os.path.join(base_dir, name)
187
+
188
+
189
+ def exists_annotated_filepath(name) -> bool:
190
+ name, base_dir = annotated_filepath(name)
191
+
192
+ if base_dir is None:
193
+ base_dir = get_input_directory() # fallback path
194
+
195
+ filepath = os.path.join(base_dir, name)
196
+ return os.path.exists(filepath)
197
+
198
+
199
+ def add_model_folder_path(folder_name: str, full_folder_path: str, is_default: bool = False) -> None:
200
+ global folder_names_and_paths
201
+ folder_name = map_legacy(folder_name)
202
+ if folder_name in folder_names_and_paths:
203
+ paths, _exts = folder_names_and_paths[folder_name]
204
+ if full_folder_path in paths:
205
+ if is_default and paths[0] != full_folder_path:
206
+ # If the path to the folder is not the first in the list, move it to the beginning.
207
+ paths.remove(full_folder_path)
208
+ paths.insert(0, full_folder_path)
209
+ else:
210
+ if is_default:
211
+ paths.insert(0, full_folder_path)
212
+ else:
213
+ paths.append(full_folder_path)
214
+ else:
215
+ folder_names_and_paths[folder_name] = ([full_folder_path], set())
216
+
217
+ def get_folder_paths(folder_name: str) -> list[str]:
218
+ folder_name = map_legacy(folder_name)
219
+ return folder_names_and_paths[folder_name][0][:]
220
+
221
+ def recursive_search(directory: str, excluded_dir_names: list[str] | None=None) -> tuple[list[str], dict[str, float]]:
222
+ if not os.path.isdir(directory):
223
+ return [], {}
224
+
225
+ if excluded_dir_names is None:
226
+ excluded_dir_names = []
227
+
228
+ result = []
229
+ dirs = {}
230
+
231
+ # Attempt to add the initial directory to dirs with error handling
232
+ try:
233
+ dirs[directory] = os.path.getmtime(directory)
234
+ except FileNotFoundError:
235
+ logging.warning(f"Warning: Unable to access {directory}. Skipping this path.")
236
+
237
+ logging.debug("recursive file list on directory {}".format(directory))
238
+ dirpath: str
239
+ subdirs: list[str]
240
+ filenames: list[str]
241
+
242
+ for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
243
+ subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
244
+ for file_name in filenames:
245
+ try:
246
+ relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
247
+ result.append(relative_path)
248
+ except:
249
+ logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.")
250
+ continue
251
+
252
+ for d in subdirs:
253
+ path: str = os.path.join(dirpath, d)
254
+ try:
255
+ dirs[path] = os.path.getmtime(path)
256
+ except FileNotFoundError:
257
+ logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
258
+ continue
259
+ logging.debug("found {} files".format(len(result)))
260
+ return result, dirs
261
+
262
+ def filter_files_extensions(files: Collection[str], extensions: Collection[str]) -> list[str]:
263
+ return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions or len(extensions) == 0, files)))
264
+
265
+
266
+
267
+ def get_full_path(folder_name: str, filename: str) -> str | None:
268
+ global folder_names_and_paths
269
+ folder_name = map_legacy(folder_name)
270
+ if folder_name not in folder_names_and_paths:
271
+ return None
272
+ folders = folder_names_and_paths[folder_name]
273
+ filename = os.path.relpath(os.path.join("/", filename), "/")
274
+ for x in folders[0]:
275
+ full_path = os.path.join(x, filename)
276
+ if os.path.isfile(full_path):
277
+ return full_path
278
+ elif os.path.islink(full_path):
279
+ logging.warning("WARNING path {} exists but doesn't link anywhere, skipping.".format(full_path))
280
+
281
+ return None
282
+
283
+
284
+ def get_full_path_or_raise(folder_name: str, filename: str) -> str:
285
+ full_path = get_full_path(folder_name, filename)
286
+ if full_path is None:
287
+ raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found.")
288
+ return full_path
289
+
290
+
291
+ def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:
292
+ folder_name = map_legacy(folder_name)
293
+ global folder_names_and_paths
294
+ output_list = set()
295
+ folders = folder_names_and_paths[folder_name]
296
+ output_folders = {}
297
+ for x in folders[0]:
298
+ files, folders_all = recursive_search(x, excluded_dir_names=[".git"])
299
+ output_list.update(filter_files_extensions(files, folders[1]))
300
+ output_folders = {**output_folders, **folders_all}
301
+
302
+ return sorted(list(output_list)), output_folders, time.perf_counter()
303
+
304
+ def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float] | None:
305
+ strong_cache = cache_helper.get(folder_name)
306
+ if strong_cache is not None:
307
+ return strong_cache
308
+
309
+ global filename_list_cache
310
+ global folder_names_and_paths
311
+ folder_name = map_legacy(folder_name)
312
+ if folder_name not in filename_list_cache:
313
+ return None
314
+ out = filename_list_cache[folder_name]
315
+
316
+ for x in out[1]:
317
+ time_modified = out[1][x]
318
+ folder = x
319
+ if os.path.getmtime(folder) != time_modified:
320
+ return None
321
+
322
+ folders = folder_names_and_paths[folder_name]
323
+ for x in folders[0]:
324
+ if os.path.isdir(x):
325
+ if x not in out[1]:
326
+ return None
327
+
328
+ return out
329
+
330
+ def get_filename_list(folder_name: str) -> list[str]:
331
+ folder_name = map_legacy(folder_name)
332
+ out = cached_filename_list_(folder_name)
333
+ if out is None:
334
+ out = get_filename_list_(folder_name)
335
+ global filename_list_cache
336
+ filename_list_cache[folder_name] = out
337
+ cache_helper.set(folder_name, out)
338
+ return list(out[0])
339
+
340
+ def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, image_height=0) -> tuple[str, str, int, str, str]:
341
+ def map_filename(filename: str) -> tuple[int, str]:
342
+ prefix_len = len(os.path.basename(filename_prefix))
343
+ prefix = filename[:prefix_len + 1]
344
+ try:
345
+ digits = int(filename[prefix_len + 1:].split('_')[0])
346
+ except:
347
+ digits = 0
348
+ return digits, prefix
349
+
350
+ def compute_vars(input: str, image_width: int, image_height: int) -> str:
351
+ input = input.replace("%width%", str(image_width))
352
+ input = input.replace("%height%", str(image_height))
353
+ now = time.localtime()
354
+ input = input.replace("%year%", str(now.tm_year))
355
+ input = input.replace("%month%", str(now.tm_mon).zfill(2))
356
+ input = input.replace("%day%", str(now.tm_mday).zfill(2))
357
+ input = input.replace("%hour%", str(now.tm_hour).zfill(2))
358
+ input = input.replace("%minute%", str(now.tm_min).zfill(2))
359
+ input = input.replace("%second%", str(now.tm_sec).zfill(2))
360
+ return input
361
+
362
+ if "%" in filename_prefix:
363
+ filename_prefix = compute_vars(filename_prefix, image_width, image_height)
364
+
365
+ subfolder = os.path.dirname(os.path.normpath(filename_prefix))
366
+ filename = os.path.basename(os.path.normpath(filename_prefix))
367
+
368
+ full_output_folder = os.path.join(output_dir, subfolder)
369
+
370
+ if os.path.commonpath((output_dir, os.path.abspath(full_output_folder))) != output_dir:
371
+ err = "**** ERROR: Saving image outside the output folder is not allowed." + \
372
+ "\n full_output_folder: " + os.path.abspath(full_output_folder) + \
373
+ "\n output_dir: " + output_dir + \
374
+ "\n commonpath: " + os.path.commonpath((output_dir, os.path.abspath(full_output_folder)))
375
+ logging.error(err)
376
+ raise Exception(err)
377
+
378
+ try:
379
+ counter = max(filter(lambda a: os.path.normcase(a[1][:-1]) == os.path.normcase(filename) and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1
380
+ except ValueError:
381
+ counter = 1
382
+ except FileNotFoundError:
383
+ os.makedirs(full_output_folder, exist_ok=True)
384
+ counter = 1
385
+ return full_output_folder, filename, counter, subfolder, filename_prefix
latent_preview.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ from comfy.cli_args import args, LatentPreviewMethod
4
+ from comfy.taesd.taesd import TAESD
5
+ import comfy.model_management
6
+ import folder_paths
7
+ import comfy.utils
8
+ import logging
9
+
10
+ MAX_PREVIEW_RESOLUTION = args.preview_size
11
+
12
+ def preview_to_image(latent_image):
13
+ latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1
14
+ .mul(0xFF) # to 0..255
15
+ ).to(device="cpu", dtype=torch.uint8, non_blocking=comfy.model_management.device_supports_non_blocking(latent_image.device))
16
+
17
+ return Image.fromarray(latents_ubyte.numpy())
18
+
19
+ class LatentPreviewer:
20
+ def decode_latent_to_preview(self, x0):
21
+ pass
22
+
23
+ def decode_latent_to_preview_image(self, preview_format, x0):
24
+ preview_image = self.decode_latent_to_preview(x0)
25
+ return ("JPEG", preview_image, MAX_PREVIEW_RESOLUTION)
26
+
27
+ class TAESDPreviewerImpl(LatentPreviewer):
28
+ def __init__(self, taesd):
29
+ self.taesd = taesd
30
+
31
+ def decode_latent_to_preview(self, x0):
32
+ x_sample = self.taesd.decode(x0[:1])[0].movedim(0, 2)
33
+ return preview_to_image(x_sample)
34
+
35
+
36
+ class Latent2RGBPreviewer(LatentPreviewer):
37
+ def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None):
38
+ self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1)
39
+ self.latent_rgb_factors_bias = None
40
+ if latent_rgb_factors_bias is not None:
41
+ self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu")
42
+
43
+ def decode_latent_to_preview(self, x0):
44
+ self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
45
+ if self.latent_rgb_factors_bias is not None:
46
+ self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device)
47
+
48
+ if x0.ndim == 5:
49
+ x0 = x0[0, :, 0]
50
+ else:
51
+ x0 = x0[0]
52
+
53
+ latent_image = torch.nn.functional.linear(x0.movedim(0, -1), self.latent_rgb_factors, bias=self.latent_rgb_factors_bias)
54
+ # latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors
55
+
56
+ return preview_to_image(latent_image)
57
+
58
+
59
+ def get_previewer(device, latent_format):
60
+ previewer = None
61
+ method = args.preview_method
62
+ if method != LatentPreviewMethod.NoPreviews:
63
+ # TODO previewer methods
64
+ taesd_decoder_path = None
65
+ if latent_format.taesd_decoder_name is not None:
66
+ taesd_decoder_path = next(
67
+ (fn for fn in folder_paths.get_filename_list("vae_approx")
68
+ if fn.startswith(latent_format.taesd_decoder_name)),
69
+ ""
70
+ )
71
+ taesd_decoder_path = folder_paths.get_full_path("vae_approx", taesd_decoder_path)
72
+
73
+ if method == LatentPreviewMethod.Auto:
74
+ method = LatentPreviewMethod.Latent2RGB
75
+
76
+ if method == LatentPreviewMethod.TAESD:
77
+ if taesd_decoder_path:
78
+ taesd = TAESD(None, taesd_decoder_path, latent_channels=latent_format.latent_channels).to(device)
79
+ previewer = TAESDPreviewerImpl(taesd)
80
+ else:
81
+ logging.warning("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name))
82
+
83
+ if previewer is None:
84
+ if latent_format.latent_rgb_factors is not None:
85
+ previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias)
86
+ return previewer
87
+
88
+ def prepare_callback(model, steps, x0_output_dict=None):
89
+ preview_format = "JPEG"
90
+ if preview_format not in ["JPEG", "PNG"]:
91
+ preview_format = "JPEG"
92
+
93
+ previewer = get_previewer(model.load_device, model.model.latent_format)
94
+
95
+ pbar = comfy.utils.ProgressBar(steps)
96
+ def callback(step, x0, x, total_steps):
97
+ if x0_output_dict is not None:
98
+ x0_output_dict["x0"] = x0
99
+
100
+ preview_bytes = None
101
+ if previewer:
102
+ preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)
103
+ pbar.update_absolute(step + 1, total_steps, preview_bytes)
104
+ return callback
105
+
main.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import comfy.options
2
+ comfy.options.enable_args_parsing()
3
+
4
+ import os
5
+ import importlib.util
6
+ import folder_paths
7
+ import time
8
+ from comfy.cli_args import args
9
+ from app.logger import setup_logger
10
+ import itertools
11
+ import utils.extra_config
12
+ import logging
13
+
14
+ if __name__ == "__main__":
15
+ #NOTE: These do not do anything on core ComfyUI which should already have no communication with the internet, they are for custom nodes.
16
+ os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
17
+ os.environ['DO_NOT_TRACK'] = '1'
18
+
19
+
20
+ setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
21
+
22
+ def apply_custom_paths():
23
+ # extra model paths
24
+ extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
25
+ if os.path.isfile(extra_model_paths_config_path):
26
+ utils.extra_config.load_extra_path_config(extra_model_paths_config_path)
27
+
28
+ if args.extra_model_paths_config:
29
+ for config_path in itertools.chain(*args.extra_model_paths_config):
30
+ utils.extra_config.load_extra_path_config(config_path)
31
+
32
+ # --output-directory, --input-directory, --user-directory
33
+ if args.output_directory:
34
+ output_dir = os.path.abspath(args.output_directory)
35
+ logging.info(f"Setting output directory to: {output_dir}")
36
+ folder_paths.set_output_directory(output_dir)
37
+
38
+ # These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes
39
+ folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints"))
40
+ folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
41
+ folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
42
+ folder_paths.add_model_folder_path("diffusion_models",
43
+ os.path.join(folder_paths.get_output_directory(), "diffusion_models"))
44
+ folder_paths.add_model_folder_path("loras", os.path.join(folder_paths.get_output_directory(), "loras"))
45
+
46
+ if args.input_directory:
47
+ input_dir = os.path.abspath(args.input_directory)
48
+ logging.info(f"Setting input directory to: {input_dir}")
49
+ folder_paths.set_input_directory(input_dir)
50
+
51
+ if args.user_directory:
52
+ user_dir = os.path.abspath(args.user_directory)
53
+ logging.info(f"Setting user directory to: {user_dir}")
54
+ folder_paths.set_user_directory(user_dir)
55
+
56
+
57
+ def execute_prestartup_script():
58
+ def execute_script(script_path):
59
+ module_name = os.path.splitext(script_path)[0]
60
+ try:
61
+ spec = importlib.util.spec_from_file_location(module_name, script_path)
62
+ module = importlib.util.module_from_spec(spec)
63
+ spec.loader.exec_module(module)
64
+ return True
65
+ except Exception as e:
66
+ logging.error(f"Failed to execute startup-script: {script_path} / {e}")
67
+ return False
68
+
69
+ if args.disable_all_custom_nodes:
70
+ return
71
+
72
+ node_paths = folder_paths.get_folder_paths("custom_nodes")
73
+ for custom_node_path in node_paths:
74
+ possible_modules = os.listdir(custom_node_path)
75
+ node_prestartup_times = []
76
+
77
+ for possible_module in possible_modules:
78
+ module_path = os.path.join(custom_node_path, possible_module)
79
+ if os.path.isfile(module_path) or module_path.endswith(".disabled") or module_path == "__pycache__":
80
+ continue
81
+
82
+ script_path = os.path.join(module_path, "prestartup_script.py")
83
+ if os.path.exists(script_path):
84
+ time_before = time.perf_counter()
85
+ success = execute_script(script_path)
86
+ node_prestartup_times.append((time.perf_counter() - time_before, module_path, success))
87
+ if len(node_prestartup_times) > 0:
88
+ logging.info("\nPrestartup times for custom nodes:")
89
+ for n in sorted(node_prestartup_times):
90
+ if n[2]:
91
+ import_message = ""
92
+ else:
93
+ import_message = " (PRESTARTUP FAILED)"
94
+ logging.info("{:6.1f} seconds{}: {}".format(n[0], import_message, n[1]))
95
+ logging.info("")
96
+
97
+ apply_custom_paths()
98
+ execute_prestartup_script()
99
+
100
+
101
+ # Main code
102
+ import asyncio
103
+ import shutil
104
+ import threading
105
+ import gc
106
+
107
+
108
+ if os.name == "nt":
109
+ logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
110
+
111
+ if __name__ == "__main__":
112
+ if args.cuda_device is not None:
113
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
114
+ os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device)
115
+ logging.info("Set cuda device to: {}".format(args.cuda_device))
116
+
117
+ if args.oneapi_device_selector is not None:
118
+ os.environ['ONEAPI_DEVICE_SELECTOR'] = args.oneapi_device_selector
119
+ logging.info("Set oneapi device selector to: {}".format(args.oneapi_device_selector))
120
+
121
+ if args.deterministic:
122
+ if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
123
+ os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
124
+
125
+ import cuda_malloc
126
+
127
+ if args.windows_standalone_build:
128
+ try:
129
+ from fix_torch import fix_pytorch_libomp
130
+ fix_pytorch_libomp()
131
+ except:
132
+ pass
133
+
134
+ import comfy.utils
135
+
136
+ import execution
137
+ import server
138
+ from server import BinaryEventTypes
139
+ import nodes
140
+ import comfy.model_management
141
+
142
+ def cuda_malloc_warning():
143
+ device = comfy.model_management.get_torch_device()
144
+ device_name = comfy.model_management.get_torch_device_name(device)
145
+ cuda_malloc_warning = False
146
+ if "cudaMallocAsync" in device_name:
147
+ for b in cuda_malloc.blacklist:
148
+ if b in device_name:
149
+ cuda_malloc_warning = True
150
+ if cuda_malloc_warning:
151
+ logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
152
+
153
+
154
+ def prompt_worker(q, server_instance):
155
+ current_time: float = 0.0
156
+ e = execution.PromptExecutor(server_instance, lru_size=args.cache_lru)
157
+ last_gc_collect = 0
158
+ need_gc = False
159
+ gc_collect_interval = 10.0
160
+
161
+ while True:
162
+ timeout = 1000.0
163
+ if need_gc:
164
+ timeout = max(gc_collect_interval - (current_time - last_gc_collect), 0.0)
165
+
166
+ queue_item = q.get(timeout=timeout)
167
+ if queue_item is not None:
168
+ item, item_id = queue_item
169
+ execution_start_time = time.perf_counter()
170
+ prompt_id = item[1]
171
+ server_instance.last_prompt_id = prompt_id
172
+
173
+ e.execute(item[2], prompt_id, item[3], item[4])
174
+ need_gc = True
175
+ q.task_done(item_id,
176
+ e.history_result,
177
+ status=execution.PromptQueue.ExecutionStatus(
178
+ status_str='success' if e.success else 'error',
179
+ completed=e.success,
180
+ messages=e.status_messages))
181
+ if server_instance.client_id is not None:
182
+ server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
183
+
184
+ current_time = time.perf_counter()
185
+ execution_time = current_time - execution_start_time
186
+ logging.info("Prompt executed in {:.2f} seconds".format(execution_time))
187
+
188
+ flags = q.get_flags()
189
+ free_memory = flags.get("free_memory", False)
190
+
191
+ if flags.get("unload_models", free_memory):
192
+ comfy.model_management.unload_all_models()
193
+ need_gc = True
194
+ last_gc_collect = 0
195
+
196
+ if free_memory:
197
+ e.reset()
198
+ need_gc = True
199
+ last_gc_collect = 0
200
+
201
+ if need_gc:
202
+ current_time = time.perf_counter()
203
+ if (current_time - last_gc_collect) > gc_collect_interval:
204
+ gc.collect()
205
+ comfy.model_management.soft_empty_cache()
206
+ last_gc_collect = current_time
207
+ need_gc = False
208
+
209
+
210
+ async def run(server_instance, address='', port=8188, verbose=True, call_on_start=None):
211
+ addresses = []
212
+ for addr in address.split(","):
213
+ addresses.append((addr, port))
214
+ await asyncio.gather(
215
+ server_instance.start_multi_address(addresses, call_on_start, verbose), server_instance.publish_loop()
216
+ )
217
+
218
+
219
+ def hijack_progress(server_instance):
220
+ def hook(value, total, preview_image):
221
+ comfy.model_management.throw_exception_if_processing_interrupted()
222
+ progress = {"value": value, "max": total, "prompt_id": server_instance.last_prompt_id, "node": server_instance.last_node_id}
223
+
224
+ server_instance.send_sync("progress", progress, server_instance.client_id)
225
+ if preview_image is not None:
226
+ server_instance.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server_instance.client_id)
227
+
228
+ comfy.utils.set_progress_bar_global_hook(hook)
229
+
230
+
231
+ def cleanup_temp():
232
+ temp_dir = folder_paths.get_temp_directory()
233
+ if os.path.exists(temp_dir):
234
+ shutil.rmtree(temp_dir, ignore_errors=True)
235
+
236
+
237
+ def start_comfyui(asyncio_loop=None):
238
+ """
239
+ Starts the ComfyUI server using the provided asyncio event loop or creates a new one.
240
+ Returns the event loop, server instance, and a function to start the server asynchronously.
241
+ """
242
+ if args.temp_directory:
243
+ temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
244
+ logging.info(f"Setting temp directory to: {temp_dir}")
245
+ folder_paths.set_temp_directory(temp_dir)
246
+ cleanup_temp()
247
+
248
+ if args.windows_standalone_build:
249
+ try:
250
+ import new_updater
251
+ new_updater.update_windows_updater()
252
+ except:
253
+ pass
254
+
255
+ if not asyncio_loop:
256
+ asyncio_loop = asyncio.new_event_loop()
257
+ asyncio.set_event_loop(asyncio_loop)
258
+ prompt_server = server.PromptServer(asyncio_loop)
259
+ q = execution.PromptQueue(prompt_server)
260
+
261
+ nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes)
262
+
263
+ cuda_malloc_warning()
264
+
265
+ prompt_server.add_routes()
266
+ hijack_progress(prompt_server)
267
+
268
+ threading.Thread(target=prompt_worker, daemon=True, args=(q, prompt_server,)).start()
269
+
270
+ if args.quick_test_for_ci:
271
+ exit(0)
272
+
273
+ os.makedirs(folder_paths.get_temp_directory(), exist_ok=True)
274
+ call_on_start = None
275
+ if args.auto_launch:
276
+ def startup_server(scheme, address, port):
277
+ import webbrowser
278
+ if os.name == 'nt' and address == '0.0.0.0':
279
+ address = '127.0.0.1'
280
+ if ':' in address:
281
+ address = "[{}]".format(address)
282
+ webbrowser.open(f"{scheme}://{address}:{port}")
283
+ call_on_start = startup_server
284
+
285
+ async def start_all():
286
+ await prompt_server.setup()
287
+ await run(prompt_server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)
288
+
289
+ # Returning these so that other code can integrate with the ComfyUI loop and server
290
+ return asyncio_loop, prompt_server, start_all
291
+
292
+
293
+ if __name__ == "__main__":
294
+ # Running directly, just start ComfyUI.
295
+ event_loop, _, start_all_func = start_comfyui()
296
+ try:
297
+ event_loop.run_until_complete(start_all_func())
298
+ except KeyboardInterrupt:
299
+ logging.info("\nStopped server")
300
+
301
+ cleanup_temp()
new_updater.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+
4
+ base_path = os.path.dirname(os.path.realpath(__file__))
5
+
6
+
7
+ def update_windows_updater():
8
+ top_path = os.path.dirname(base_path)
9
+ updater_path = os.path.join(base_path, ".ci/update_windows/update.py")
10
+ bat_path = os.path.join(base_path, ".ci/update_windows/update_comfyui.bat")
11
+
12
+ dest_updater_path = os.path.join(top_path, "update/update.py")
13
+ dest_bat_path = os.path.join(top_path, "update/update_comfyui.bat")
14
+ dest_bat_deps_path = os.path.join(top_path, "update/update_comfyui_and_python_dependencies.bat")
15
+
16
+ try:
17
+ with open(dest_bat_path, 'rb') as f:
18
+ contents = f.read()
19
+ except:
20
+ return
21
+
22
+ if not contents.startswith(b"..\\python_embeded\\python.exe .\\update.py"):
23
+ return
24
+
25
+ shutil.copy(updater_path, dest_updater_path)
26
+ try:
27
+ with open(dest_bat_deps_path, 'rb') as f:
28
+ contents = f.read()
29
+ contents = contents.replace(b'..\\python_embeded\\python.exe .\\update.py ..\\ComfyUI\\', b'call update_comfyui.bat nopause')
30
+ with open(dest_bat_deps_path, 'wb') as f:
31
+ f.write(contents)
32
+ except:
33
+ pass
34
+ shutil.copy(bat_path, dest_bat_path)
35
+ print("Updated the windows standalone package updater.") # noqa: T201
node_helpers.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+
3
+ from comfy.cli_args import args
4
+
5
+ from PIL import ImageFile, UnidentifiedImageError
6
+
7
+ def conditioning_set_values(conditioning, values={}):
8
+ c = []
9
+ for t in conditioning:
10
+ n = [t[0], t[1].copy()]
11
+ for k in values:
12
+ n[1][k] = values[k]
13
+ c.append(n)
14
+
15
+ return c
16
+
17
+ def pillow(fn, arg):
18
+ prev_value = None
19
+ try:
20
+ x = fn(arg)
21
+ except (OSError, UnidentifiedImageError, ValueError): #PIL issues #4472 and #2445, also fixes ComfyUI issue #3416
22
+ prev_value = ImageFile.LOAD_TRUNCATED_IMAGES
23
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
24
+ x = fn(arg)
25
+ finally:
26
+ if prev_value is not None:
27
+ ImageFile.LOAD_TRUNCATED_IMAGES = prev_value
28
+ return x
29
+
30
+ def hasher():
31
+ hashfuncs = {
32
+ "md5": hashlib.md5,
33
+ "sha1": hashlib.sha1,
34
+ "sha256": hashlib.sha256,
35
+ "sha512": hashlib.sha512
36
+ }
37
+ return hashfuncs[args.default_hashing_function]
notebooks/comfyui_colab.ipynb ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "aaaaaaaaaa"
7
+ },
8
+ "source": [
9
+ "Git clone the repo and install the requirements. (ignore the pip errors about protobuf)"
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": null,
15
+ "metadata": {
16
+ "id": "bbbbbbbbbb"
17
+ },
18
+ "outputs": [],
19
+ "source": [
20
+ "#@title Environment Setup\n",
21
+ "\n",
22
+ "\n",
23
+ "OPTIONS = {}\n",
24
+ "\n",
25
+ "USE_GOOGLE_DRIVE = False #@param {type:\"boolean\"}\n",
26
+ "UPDATE_COMFY_UI = True #@param {type:\"boolean\"}\n",
27
+ "WORKSPACE = 'ComfyUI'\n",
28
+ "OPTIONS['USE_GOOGLE_DRIVE'] = USE_GOOGLE_DRIVE\n",
29
+ "OPTIONS['UPDATE_COMFY_UI'] = UPDATE_COMFY_UI\n",
30
+ "\n",
31
+ "if OPTIONS['USE_GOOGLE_DRIVE']:\n",
32
+ " !echo \"Mounting Google Drive...\"\n",
33
+ " %cd /\n",
34
+ " \n",
35
+ " from google.colab import drive\n",
36
+ " drive.mount('/content/drive')\n",
37
+ "\n",
38
+ " WORKSPACE = \"/content/drive/MyDrive/ComfyUI\"\n",
39
+ " %cd /content/drive/MyDrive\n",
40
+ "\n",
41
+ "![ ! -d $WORKSPACE ] && echo -= Initial setup ComfyUI =- && git clone https://github.com/comfyanonymous/ComfyUI\n",
42
+ "%cd $WORKSPACE\n",
43
+ "\n",
44
+ "if OPTIONS['UPDATE_COMFY_UI']:\n",
45
+ " !echo -= Updating ComfyUI =-\n",
46
+ " !git pull\n",
47
+ "\n",
48
+ "!echo -= Install dependencies =-\n",
49
+ "!pip install xformers!=0.0.18 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu121 --extra-index-url https://download.pytorch.org/whl/cu118 --extra-index-url https://download.pytorch.org/whl/cu117"
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "markdown",
54
+ "metadata": {
55
+ "id": "cccccccccc"
56
+ },
57
+ "source": [
58
+ "Download some models/checkpoints/vae or custom comfyui nodes (uncomment the commands for the ones you want)"
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "code",
63
+ "execution_count": null,
64
+ "metadata": {
65
+ "id": "dddddddddd"
66
+ },
67
+ "outputs": [],
68
+ "source": [
69
+ "# Checkpoints\n",
70
+ "\n",
71
+ "### SDXL\n",
72
+ "### I recommend these workflow examples: https://comfyanonymous.github.io/ComfyUI_examples/sdxl/\n",
73
+ "\n",
74
+ "#!wget -c https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors -P ./models/checkpoints/\n",
75
+ "#!wget -c https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/resolve/main/sd_xl_refiner_1.0.safetensors -P ./models/checkpoints/\n",
76
+ "\n",
77
+ "# SDXL ReVision\n",
78
+ "#!wget -c https://huggingface.co/comfyanonymous/clip_vision_g/resolve/main/clip_vision_g.safetensors -P ./models/clip_vision/\n",
79
+ "\n",
80
+ "# SD1.5\n",
81
+ "!wget -c https://huggingface.co/Comfy-Org/stable-diffusion-v1-5-archive/resolve/main/v1-5-pruned-emaonly-fp16.safetensors -P ./models/checkpoints/\n",
82
+ "\n",
83
+ "# SD2\n",
84
+ "#!wget -c https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.safetensors -P ./models/checkpoints/\n",
85
+ "#!wget -c https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.safetensors -P ./models/checkpoints/\n",
86
+ "\n",
87
+ "# Some SD1.5 anime style\n",
88
+ "#!wget -c https://huggingface.co/WarriorMama777/OrangeMixs/resolve/main/Models/AbyssOrangeMix2/AbyssOrangeMix2_hard.safetensors -P ./models/checkpoints/\n",
89
+ "#!wget -c https://huggingface.co/WarriorMama777/OrangeMixs/resolve/main/Models/AbyssOrangeMix3/AOM3A1_orangemixs.safetensors -P ./models/checkpoints/\n",
90
+ "#!wget -c https://huggingface.co/WarriorMama777/OrangeMixs/resolve/main/Models/AbyssOrangeMix3/AOM3A3_orangemixs.safetensors -P ./models/checkpoints/\n",
91
+ "#!wget -c https://huggingface.co/Linaqruf/anything-v3.0/resolve/main/anything-v3-fp16-pruned.safetensors -P ./models/checkpoints/\n",
92
+ "\n",
93
+ "# Waifu Diffusion 1.5 (anime style SD2.x 768-v)\n",
94
+ "#!wget -c https://huggingface.co/waifu-diffusion/wd-1-5-beta3/resolve/main/wd-illusion-fp16.safetensors -P ./models/checkpoints/\n",
95
+ "\n",
96
+ "\n",
97
+ "# unCLIP models\n",
98
+ "#!wget -c https://huggingface.co/comfyanonymous/illuminatiDiffusionV1_v11_unCLIP/resolve/main/illuminatiDiffusionV1_v11-unclip-h-fp16.safetensors -P ./models/checkpoints/\n",
99
+ "#!wget -c https://huggingface.co/comfyanonymous/wd-1.5-beta2_unCLIP/resolve/main/wd-1-5-beta2-aesthetic-unclip-h-fp16.safetensors -P ./models/checkpoints/\n",
100
+ "\n",
101
+ "\n",
102
+ "# VAE\n",
103
+ "!wget -c https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.safetensors -P ./models/vae/\n",
104
+ "#!wget -c https://huggingface.co/WarriorMama777/OrangeMixs/resolve/main/VAEs/orangemix.vae.pt -P ./models/vae/\n",
105
+ "#!wget -c https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/vae/kl-f8-anime2.ckpt -P ./models/vae/\n",
106
+ "\n",
107
+ "\n",
108
+ "# Loras\n",
109
+ "#!wget -c https://civitai.com/api/download/models/10350 -O ./models/loras/theovercomer8sContrastFix_sd21768.safetensors #theovercomer8sContrastFix SD2.x 768-v\n",
110
+ "#!wget -c https://civitai.com/api/download/models/10638 -O ./models/loras/theovercomer8sContrastFix_sd15.safetensors #theovercomer8sContrastFix SD1.x\n",
111
+ "#!wget -c https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_offset_example-lora_1.0.safetensors -P ./models/loras/ #SDXL offset noise lora\n",
112
+ "\n",
113
+ "\n",
114
+ "# T2I-Adapter\n",
115
+ "#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_depth_sd14v1.pth -P ./models/controlnet/\n",
116
+ "#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_seg_sd14v1.pth -P ./models/controlnet/\n",
117
+ "#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_sketch_sd14v1.pth -P ./models/controlnet/\n",
118
+ "#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_keypose_sd14v1.pth -P ./models/controlnet/\n",
119
+ "#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_openpose_sd14v1.pth -P ./models/controlnet/\n",
120
+ "#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_color_sd14v1.pth -P ./models/controlnet/\n",
121
+ "#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_canny_sd14v1.pth -P ./models/controlnet/\n",
122
+ "\n",
123
+ "# T2I Styles Model\n",
124
+ "#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_style_sd14v1.pth -P ./models/style_models/\n",
125
+ "\n",
126
+ "# CLIPVision model (needed for styles model)\n",
127
+ "#!wget -c https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/pytorch_model.bin -O ./models/clip_vision/clip_vit14.bin\n",
128
+ "\n",
129
+ "\n",
130
+ "# ControlNet\n",
131
+ "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11e_sd15_ip2p_fp16.safetensors -P ./models/controlnet/\n",
132
+ "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11e_sd15_shuffle_fp16.safetensors -P ./models/controlnet/\n",
133
+ "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_canny_fp16.safetensors -P ./models/controlnet/\n",
134
+ "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11f1p_sd15_depth_fp16.safetensors -P ./models/controlnet/\n",
135
+ "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_inpaint_fp16.safetensors -P ./models/controlnet/\n",
136
+ "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_lineart_fp16.safetensors -P ./models/controlnet/\n",
137
+ "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_mlsd_fp16.safetensors -P ./models/controlnet/\n",
138
+ "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_normalbae_fp16.safetensors -P ./models/controlnet/\n",
139
+ "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_openpose_fp16.safetensors -P ./models/controlnet/\n",
140
+ "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_scribble_fp16.safetensors -P ./models/controlnet/\n",
141
+ "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_seg_fp16.safetensors -P ./models/controlnet/\n",
142
+ "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_softedge_fp16.safetensors -P ./models/controlnet/\n",
143
+ "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15s2_lineart_anime_fp16.safetensors -P ./models/controlnet/\n",
144
+ "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11u_sd15_tile_fp16.safetensors -P ./models/controlnet/\n",
145
+ "\n",
146
+ "# ControlNet SDXL\n",
147
+ "#!wget -c https://huggingface.co/stabilityai/control-lora/resolve/main/control-LoRAs-rank256/control-lora-canny-rank256.safetensors -P ./models/controlnet/\n",
148
+ "#!wget -c https://huggingface.co/stabilityai/control-lora/resolve/main/control-LoRAs-rank256/control-lora-depth-rank256.safetensors -P ./models/controlnet/\n",
149
+ "#!wget -c https://huggingface.co/stabilityai/control-lora/resolve/main/control-LoRAs-rank256/control-lora-recolor-rank256.safetensors -P ./models/controlnet/\n",
150
+ "#!wget -c https://huggingface.co/stabilityai/control-lora/resolve/main/control-LoRAs-rank256/control-lora-sketch-rank256.safetensors -P ./models/controlnet/\n",
151
+ "\n",
152
+ "# Controlnet Preprocessor nodes by Fannovel16\n",
153
+ "#!cd custom_nodes && git clone https://github.com/Fannovel16/comfy_controlnet_preprocessors; cd comfy_controlnet_preprocessors && python install.py\n",
154
+ "\n",
155
+ "\n",
156
+ "# GLIGEN\n",
157
+ "#!wget -c https://huggingface.co/comfyanonymous/GLIGEN_pruned_safetensors/resolve/main/gligen_sd14_textbox_pruned_fp16.safetensors -P ./models/gligen/\n",
158
+ "\n",
159
+ "\n",
160
+ "# ESRGAN upscale model\n",
161
+ "#!wget -c https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P ./models/upscale_models/\n",
162
+ "#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x2.pth -P ./models/upscale_models/\n",
163
+ "#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x4.pth -P ./models/upscale_models/\n",
164
+ "\n",
165
+ "\n"
166
+ ]
167
+ },
168
+ {
169
+ "cell_type": "markdown",
170
+ "metadata": {
171
+ "id": "kkkkkkkkkkkkkkk"
172
+ },
173
+ "source": [
174
+ "### Run ComfyUI with cloudflared (Recommended Way)\n",
175
+ "\n",
176
+ "\n"
177
+ ]
178
+ },
179
+ {
180
+ "cell_type": "code",
181
+ "execution_count": null,
182
+ "metadata": {
183
+ "id": "jjjjjjjjjjjjjj"
184
+ },
185
+ "outputs": [],
186
+ "source": [
187
+ "!wget https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64.deb\n",
188
+ "!dpkg -i cloudflared-linux-amd64.deb\n",
189
+ "\n",
190
+ "import subprocess\n",
191
+ "import threading\n",
192
+ "import time\n",
193
+ "import socket\n",
194
+ "import urllib.request\n",
195
+ "\n",
196
+ "def iframe_thread(port):\n",
197
+ " while True:\n",
198
+ " time.sleep(0.5)\n",
199
+ " sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n",
200
+ " result = sock.connect_ex(('127.0.0.1', port))\n",
201
+ " if result == 0:\n",
202
+ " break\n",
203
+ " sock.close()\n",
204
+ " print(\"\\nComfyUI finished loading, trying to launch cloudflared (if it gets stuck here cloudflared is having issues)\\n\")\n",
205
+ "\n",
206
+ " p = subprocess.Popen([\"cloudflared\", \"tunnel\", \"--url\", \"http://127.0.0.1:{}\".format(port)], stdout=subprocess.PIPE, stderr=subprocess.PIPE)\n",
207
+ " for line in p.stderr:\n",
208
+ " l = line.decode()\n",
209
+ " if \"trycloudflare.com \" in l:\n",
210
+ " print(\"This is the URL to access ComfyUI:\", l[l.find(\"http\"):], end='')\n",
211
+ " #print(l, end='')\n",
212
+ "\n",
213
+ "\n",
214
+ "threading.Thread(target=iframe_thread, daemon=True, args=(8188,)).start()\n",
215
+ "\n",
216
+ "!python main.py --dont-print-server"
217
+ ]
218
+ },
219
+ {
220
+ "cell_type": "markdown",
221
+ "metadata": {
222
+ "id": "kkkkkkkkkkkkkk"
223
+ },
224
+ "source": [
225
+ "### Run ComfyUI with localtunnel\n",
226
+ "\n",
227
+ "\n"
228
+ ]
229
+ },
230
+ {
231
+ "cell_type": "code",
232
+ "execution_count": null,
233
+ "metadata": {
234
+ "id": "jjjjjjjjjjjjj"
235
+ },
236
+ "outputs": [],
237
+ "source": [
238
+ "!npm install -g localtunnel\n",
239
+ "\n",
240
+ "import threading\n",
241
+ "\n",
242
+ "def iframe_thread(port):\n",
243
+ " while True:\n",
244
+ " time.sleep(0.5)\n",
245
+ " sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n",
246
+ " result = sock.connect_ex(('127.0.0.1', port))\n",
247
+ " if result == 0:\n",
248
+ " break\n",
249
+ " sock.close()\n",
250
+ " print(\"\\nComfyUI finished loading, trying to launch localtunnel (if it gets stuck here localtunnel is having issues)\\n\")\n",
251
+ "\n",
252
+ " print(\"The password/enpoint ip for localtunnel is:\", urllib.request.urlopen('https://ipv4.icanhazip.com').read().decode('utf8').strip(\"\\n\"))\n",
253
+ " p = subprocess.Popen([\"lt\", \"--port\", \"{}\".format(port)], stdout=subprocess.PIPE)\n",
254
+ " for line in p.stdout:\n",
255
+ " print(line.decode(), end='')\n",
256
+ "\n",
257
+ "\n",
258
+ "threading.Thread(target=iframe_thread, daemon=True, args=(8188,)).start()\n",
259
+ "\n",
260
+ "!python main.py --dont-print-server"
261
+ ]
262
+ },
263
+ {
264
+ "cell_type": "markdown",
265
+ "metadata": {
266
+ "id": "gggggggggg"
267
+ },
268
+ "source": [
269
+ "### Run ComfyUI with colab iframe (use only in case the previous way with localtunnel doesn't work)\n",
270
+ "\n",
271
+ "You should see the ui appear in an iframe. If you get a 403 error, it's your firefox settings or an extension that's messing things up.\n",
272
+ "\n",
273
+ "If you want to open it in another window use the link.\n",
274
+ "\n",
275
+ "Note that some UI features like live image previews won't work because the colab iframe blocks websockets."
276
+ ]
277
+ },
278
+ {
279
+ "cell_type": "code",
280
+ "execution_count": null,
281
+ "metadata": {
282
+ "id": "hhhhhhhhhh"
283
+ },
284
+ "outputs": [],
285
+ "source": [
286
+ "import threading\n",
287
+ "def iframe_thread(port):\n",
288
+ " while True:\n",
289
+ " time.sleep(0.5)\n",
290
+ " sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n",
291
+ " result = sock.connect_ex(('127.0.0.1', port))\n",
292
+ " if result == 0:\n",
293
+ " break\n",
294
+ " sock.close()\n",
295
+ " from google.colab import output\n",
296
+ " output.serve_kernel_port_as_iframe(port, height=1024)\n",
297
+ " print(\"to open it in a window you can open this link here:\")\n",
298
+ " output.serve_kernel_port_as_window(port)\n",
299
+ "\n",
300
+ "threading.Thread(target=iframe_thread, daemon=True, args=(8188,)).start()\n",
301
+ "\n",
302
+ "!python main.py --dont-print-server"
303
+ ]
304
+ }
305
+ ],
306
+ "metadata": {
307
+ "accelerator": "GPU",
308
+ "colab": {
309
+ "provenance": []
310
+ },
311
+ "gpuClass": "standard",
312
+ "kernelspec": {
313
+ "display_name": "Python 3",
314
+ "name": "python3"
315
+ },
316
+ "language_info": {
317
+ "name": "python"
318
+ }
319
+ },
320
+ "nbformat": 4,
321
+ "nbformat_minor": 0
322
+ }
output/_output_images_will_be_put_here ADDED
File without changes
pyproject.toml ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "ComfyUI"
3
+ version = "0.3.12"
4
+ readme = "README.md"
5
+ license = { file = "LICENSE" }
6
+ requires-python = ">=3.9"
7
+
8
+ [project.urls]
9
+ homepage = "https://www.comfy.org/"
10
+ repository = "https://github.com/comfyanonymous/ComfyUI"
11
+ documentation = "https://docs.comfy.org/"
12
+
13
+ [tool.ruff]
14
+ lint.select = [
15
+ "S307", # suspicious-eval-usage
16
+ "S102", # exec
17
+ "T", # print-usage
18
+ "W",
19
+ # The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names.
20
+ # See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f
21
+ "F",
22
+ ]
23
+ exclude = ["*.ipynb"]
pytest.ini ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ [pytest]
2
+ markers =
3
+ inference: mark as inference test (deselect with '-m "not inference"')
4
+ execution: mark as execution test (deselect with '-m "not execution"')
5
+ testpaths =
6
+ tests
7
+ tests-unit
8
+ addopts = -s
9
+ pythonpath = .
requirements.txt ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchsde
3
+ torchvision
4
+ torchaudio
5
+ numpy>=1.25.0
6
+ einops
7
+ transformers>=4.28.1
8
+ tokenizers>=0.13.3
9
+ sentencepiece
10
+ safetensors>=0.4.2
11
+ aiohttp
12
+ pyyaml
13
+ Pillow>=9.5.0
14
+ scipy
15
+ tqdm
16
+ psutil
17
+ gradio
18
+ huggingface_hub
19
+ # Base Detectron2
20
+ git+https://github.com/facebookresearch/[email protected]
21
+ # DensePose (part of Detectron2 projects)
22
+ git+https://github.com/facebookresearch/[email protected]#subdirectory=projects/DensePose
23
+
24
+ # Other dependencies, except conflicting ones, should be manually added
25
+ # If you have a modified requirements file, include its contents here, excluding detectron2-related dependencies.
26
+ #non essential dependencies:
27
+ kornia>=0.7.1
28
+ spandrel
29
+ soundfile
script_examples/basic_api_example.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from urllib import request
3
+
4
+ #This is the ComfyUI api prompt format.
5
+
6
+ #If you want it for a specific workflow you can "enable dev mode options"
7
+ #in the settings of the UI (gear beside the "Queue Size: ") this will enable
8
+ #a button on the UI to save workflows in api format.
9
+
10
+ #keep in mind ComfyUI is pre alpha software so this format will change a bit.
11
+
12
+ #this is the one for the default workflow
13
+ prompt_text = """
14
+ {
15
+ "3": {
16
+ "class_type": "KSampler",
17
+ "inputs": {
18
+ "cfg": 8,
19
+ "denoise": 1,
20
+ "latent_image": [
21
+ "5",
22
+ 0
23
+ ],
24
+ "model": [
25
+ "4",
26
+ 0
27
+ ],
28
+ "negative": [
29
+ "7",
30
+ 0
31
+ ],
32
+ "positive": [
33
+ "6",
34
+ 0
35
+ ],
36
+ "sampler_name": "euler",
37
+ "scheduler": "normal",
38
+ "seed": 8566257,
39
+ "steps": 20
40
+ }
41
+ },
42
+ "4": {
43
+ "class_type": "CheckpointLoaderSimple",
44
+ "inputs": {
45
+ "ckpt_name": "v1-5-pruned-emaonly.safetensors"
46
+ }
47
+ },
48
+ "5": {
49
+ "class_type": "EmptyLatentImage",
50
+ "inputs": {
51
+ "batch_size": 1,
52
+ "height": 512,
53
+ "width": 512
54
+ }
55
+ },
56
+ "6": {
57
+ "class_type": "CLIPTextEncode",
58
+ "inputs": {
59
+ "clip": [
60
+ "4",
61
+ 1
62
+ ],
63
+ "text": "masterpiece best quality girl"
64
+ }
65
+ },
66
+ "7": {
67
+ "class_type": "CLIPTextEncode",
68
+ "inputs": {
69
+ "clip": [
70
+ "4",
71
+ 1
72
+ ],
73
+ "text": "bad hands"
74
+ }
75
+ },
76
+ "8": {
77
+ "class_type": "VAEDecode",
78
+ "inputs": {
79
+ "samples": [
80
+ "3",
81
+ 0
82
+ ],
83
+ "vae": [
84
+ "4",
85
+ 2
86
+ ]
87
+ }
88
+ },
89
+ "9": {
90
+ "class_type": "SaveImage",
91
+ "inputs": {
92
+ "filename_prefix": "ComfyUI",
93
+ "images": [
94
+ "8",
95
+ 0
96
+ ]
97
+ }
98
+ }
99
+ }
100
+ """
101
+
102
+ def queue_prompt(prompt):
103
+ p = {"prompt": prompt}
104
+ data = json.dumps(p).encode('utf-8')
105
+ req = request.Request("http://127.0.0.1:8188/prompt", data=data)
106
+ request.urlopen(req)
107
+
108
+
109
+ prompt = json.loads(prompt_text)
110
+ #set the text prompt for our positive CLIPTextEncode
111
+ prompt["6"]["inputs"]["text"] = "masterpiece best quality man"
112
+
113
+ #set the seed for our KSampler node
114
+ prompt["3"]["inputs"]["seed"] = 5
115
+
116
+
117
+ queue_prompt(prompt)
118
+
119
+
script_examples/websockets_api_example.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #This is an example that uses the websockets api to know when a prompt execution is done
2
+ #Once the prompt execution is done it downloads the images using the /history endpoint
3
+
4
+ import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
5
+ import uuid
6
+ import json
7
+ import urllib.request
8
+ import urllib.parse
9
+
10
+ server_address = "127.0.0.1:8188"
11
+ client_id = str(uuid.uuid4())
12
+
13
+ def queue_prompt(prompt):
14
+ p = {"prompt": prompt, "client_id": client_id}
15
+ data = json.dumps(p).encode('utf-8')
16
+ req = urllib.request.Request("http://{}/prompt".format(server_address), data=data)
17
+ return json.loads(urllib.request.urlopen(req).read())
18
+
19
+ def get_image(filename, subfolder, folder_type):
20
+ data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
21
+ url_values = urllib.parse.urlencode(data)
22
+ with urllib.request.urlopen("http://{}/view?{}".format(server_address, url_values)) as response:
23
+ return response.read()
24
+
25
+ def get_history(prompt_id):
26
+ with urllib.request.urlopen("http://{}/history/{}".format(server_address, prompt_id)) as response:
27
+ return json.loads(response.read())
28
+
29
+ def get_images(ws, prompt):
30
+ prompt_id = queue_prompt(prompt)['prompt_id']
31
+ output_images = {}
32
+ while True:
33
+ out = ws.recv()
34
+ if isinstance(out, str):
35
+ message = json.loads(out)
36
+ if message['type'] == 'executing':
37
+ data = message['data']
38
+ if data['node'] is None and data['prompt_id'] == prompt_id:
39
+ break #Execution is done
40
+ else:
41
+ # If you want to be able to decode the binary stream for latent previews, here is how you can do it:
42
+ # bytesIO = BytesIO(out[8:])
43
+ # preview_image = Image.open(bytesIO) # This is your preview in PIL image format, store it in a global
44
+ continue #previews are binary data
45
+
46
+ history = get_history(prompt_id)[prompt_id]
47
+ for node_id in history['outputs']:
48
+ node_output = history['outputs'][node_id]
49
+ images_output = []
50
+ if 'images' in node_output:
51
+ for image in node_output['images']:
52
+ image_data = get_image(image['filename'], image['subfolder'], image['type'])
53
+ images_output.append(image_data)
54
+ output_images[node_id] = images_output
55
+
56
+ return output_images
57
+
58
+ prompt_text = """
59
+ {
60
+ "3": {
61
+ "class_type": "KSampler",
62
+ "inputs": {
63
+ "cfg": 8,
64
+ "denoise": 1,
65
+ "latent_image": [
66
+ "5",
67
+ 0
68
+ ],
69
+ "model": [
70
+ "4",
71
+ 0
72
+ ],
73
+ "negative": [
74
+ "7",
75
+ 0
76
+ ],
77
+ "positive": [
78
+ "6",
79
+ 0
80
+ ],
81
+ "sampler_name": "euler",
82
+ "scheduler": "normal",
83
+ "seed": 8566257,
84
+ "steps": 20
85
+ }
86
+ },
87
+ "4": {
88
+ "class_type": "CheckpointLoaderSimple",
89
+ "inputs": {
90
+ "ckpt_name": "v1-5-pruned-emaonly.safetensors"
91
+ }
92
+ },
93
+ "5": {
94
+ "class_type": "EmptyLatentImage",
95
+ "inputs": {
96
+ "batch_size": 1,
97
+ "height": 512,
98
+ "width": 512
99
+ }
100
+ },
101
+ "6": {
102
+ "class_type": "CLIPTextEncode",
103
+ "inputs": {
104
+ "clip": [
105
+ "4",
106
+ 1
107
+ ],
108
+ "text": "masterpiece best quality girl"
109
+ }
110
+ },
111
+ "7": {
112
+ "class_type": "CLIPTextEncode",
113
+ "inputs": {
114
+ "clip": [
115
+ "4",
116
+ 1
117
+ ],
118
+ "text": "bad hands"
119
+ }
120
+ },
121
+ "8": {
122
+ "class_type": "VAEDecode",
123
+ "inputs": {
124
+ "samples": [
125
+ "3",
126
+ 0
127
+ ],
128
+ "vae": [
129
+ "4",
130
+ 2
131
+ ]
132
+ }
133
+ },
134
+ "9": {
135
+ "class_type": "SaveImage",
136
+ "inputs": {
137
+ "filename_prefix": "ComfyUI",
138
+ "images": [
139
+ "8",
140
+ 0
141
+ ]
142
+ }
143
+ }
144
+ }
145
+ """
146
+
147
+ prompt = json.loads(prompt_text)
148
+ #set the text prompt for our positive CLIPTextEncode
149
+ prompt["6"]["inputs"]["text"] = "masterpiece best quality man"
150
+
151
+ #set the seed for our KSampler node
152
+ prompt["3"]["inputs"]["seed"] = 5
153
+
154
+ ws = websocket.WebSocket()
155
+ ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
156
+ images = get_images(ws, prompt)
157
+ ws.close() # for in case this example is used in an environment where it will be repeatedly called, like in a Gradio app. otherwise, you'll randomly receive connection timeouts
158
+ #Commented out code to display the output images:
159
+
160
+ # for node_id in images:
161
+ # for image_data in images[node_id]:
162
+ # from PIL import Image
163
+ # import io
164
+ # image = Image.open(io.BytesIO(image_data))
165
+ # image.show()
166
+
script_examples/websockets_api_example_ws_images.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #This is an example that uses the websockets api and the SaveImageWebsocket node to get images directly without
2
+ #them being saved to disk
3
+
4
+ import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
5
+ import uuid
6
+ import json
7
+ import urllib.request
8
+ import urllib.parse
9
+
10
+ server_address = "127.0.0.1:8188"
11
+ client_id = str(uuid.uuid4())
12
+
13
+ def queue_prompt(prompt):
14
+ p = {"prompt": prompt, "client_id": client_id}
15
+ data = json.dumps(p).encode('utf-8')
16
+ req = urllib.request.Request("http://{}/prompt".format(server_address), data=data)
17
+ return json.loads(urllib.request.urlopen(req).read())
18
+
19
+ def get_image(filename, subfolder, folder_type):
20
+ data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
21
+ url_values = urllib.parse.urlencode(data)
22
+ with urllib.request.urlopen("http://{}/view?{}".format(server_address, url_values)) as response:
23
+ return response.read()
24
+
25
+ def get_history(prompt_id):
26
+ with urllib.request.urlopen("http://{}/history/{}".format(server_address, prompt_id)) as response:
27
+ return json.loads(response.read())
28
+
29
+ def get_images(ws, prompt):
30
+ prompt_id = queue_prompt(prompt)['prompt_id']
31
+ output_images = {}
32
+ current_node = ""
33
+ while True:
34
+ out = ws.recv()
35
+ if isinstance(out, str):
36
+ message = json.loads(out)
37
+ if message['type'] == 'executing':
38
+ data = message['data']
39
+ if data['prompt_id'] == prompt_id:
40
+ if data['node'] is None:
41
+ break #Execution is done
42
+ else:
43
+ current_node = data['node']
44
+ else:
45
+ if current_node == 'save_image_websocket_node':
46
+ images_output = output_images.get(current_node, [])
47
+ images_output.append(out[8:])
48
+ output_images[current_node] = images_output
49
+
50
+ return output_images
51
+
52
+ prompt_text = """
53
+ {
54
+ "3": {
55
+ "class_type": "KSampler",
56
+ "inputs": {
57
+ "cfg": 8,
58
+ "denoise": 1,
59
+ "latent_image": [
60
+ "5",
61
+ 0
62
+ ],
63
+ "model": [
64
+ "4",
65
+ 0
66
+ ],
67
+ "negative": [
68
+ "7",
69
+ 0
70
+ ],
71
+ "positive": [
72
+ "6",
73
+ 0
74
+ ],
75
+ "sampler_name": "euler",
76
+ "scheduler": "normal",
77
+ "seed": 8566257,
78
+ "steps": 20
79
+ }
80
+ },
81
+ "4": {
82
+ "class_type": "CheckpointLoaderSimple",
83
+ "inputs": {
84
+ "ckpt_name": "v1-5-pruned-emaonly.safetensors"
85
+ }
86
+ },
87
+ "5": {
88
+ "class_type": "EmptyLatentImage",
89
+ "inputs": {
90
+ "batch_size": 1,
91
+ "height": 512,
92
+ "width": 512
93
+ }
94
+ },
95
+ "6": {
96
+ "class_type": "CLIPTextEncode",
97
+ "inputs": {
98
+ "clip": [
99
+ "4",
100
+ 1
101
+ ],
102
+ "text": "masterpiece best quality girl"
103
+ }
104
+ },
105
+ "7": {
106
+ "class_type": "CLIPTextEncode",
107
+ "inputs": {
108
+ "clip": [
109
+ "4",
110
+ 1
111
+ ],
112
+ "text": "bad hands"
113
+ }
114
+ },
115
+ "8": {
116
+ "class_type": "VAEDecode",
117
+ "inputs": {
118
+ "samples": [
119
+ "3",
120
+ 0
121
+ ],
122
+ "vae": [
123
+ "4",
124
+ 2
125
+ ]
126
+ }
127
+ },
128
+ "save_image_websocket_node": {
129
+ "class_type": "SaveImageWebsocket",
130
+ "inputs": {
131
+ "images": [
132
+ "8",
133
+ 0
134
+ ]
135
+ }
136
+ }
137
+ }
138
+ """
139
+
140
+ prompt = json.loads(prompt_text)
141
+ #set the text prompt for our positive CLIPTextEncode
142
+ prompt["6"]["inputs"]["text"] = "masterpiece best quality man"
143
+
144
+ #set the seed for our KSampler node
145
+ prompt["3"]["inputs"]["seed"] = 5
146
+
147
+ ws = websocket.WebSocket()
148
+ ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
149
+ images = get_images(ws, prompt)
150
+ ws.close() # for in case this example is used in an environment where it will be repeatedly called, like in a Gradio app. otherwise, you'll randomly receive connection timeouts
151
+ #Commented out code to display the output images:
152
+
153
+ # for node_id in images:
154
+ # for image_data in images[node_id]:
155
+ # from PIL import Image
156
+ # import io
157
+ # image = Image.open(io.BytesIO(image_data))
158
+ # image.show()
159
+
utils/__init__.py ADDED
File without changes
utils/extra_config.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import yaml
3
+ import folder_paths
4
+ import logging
5
+
6
+ def load_extra_path_config(yaml_path):
7
+ with open(yaml_path, 'r') as stream:
8
+ config = yaml.safe_load(stream)
9
+ yaml_dir = os.path.dirname(os.path.abspath(yaml_path))
10
+ for c in config:
11
+ conf = config[c]
12
+ if conf is None:
13
+ continue
14
+ base_path = None
15
+ if "base_path" in conf:
16
+ base_path = conf.pop("base_path")
17
+ base_path = os.path.expandvars(os.path.expanduser(base_path))
18
+ if not os.path.isabs(base_path):
19
+ base_path = os.path.abspath(os.path.join(yaml_dir, base_path))
20
+ is_default = False
21
+ if "is_default" in conf:
22
+ is_default = conf.pop("is_default")
23
+ for x in conf:
24
+ for y in conf[x].split("\n"):
25
+ if len(y) == 0:
26
+ continue
27
+ full_path = y
28
+ if base_path:
29
+ full_path = os.path.join(base_path, full_path)
30
+ elif not os.path.isabs(full_path):
31
+ full_path = os.path.abspath(os.path.join(yaml_dir, y))
32
+ logging.info("Adding extra search path {} {}".format(x, full_path))
33
+ folder_paths.add_model_folder_path(x, full_path, is_default)