Upload 16 files
- api_server/__init__.py +0 -0
- api_server/routes/__init__.py +0 -0
- api_server/routes/internal/README.md +3 -0
- api_server/routes/internal/__init__.py +0 -0
- api_server/routes/internal/internal_routes.py +75 -0
- api_server/services/__init__.py +0 -0
- api_server/services/file_service.py +13 -0
- api_server/services/terminal_service.py +60 -0
- api_server/utils/file_operations.py +42 -0
- app/__init__.py +0 -0
- app/app_settings.py +59 -0
- app/custom_node_manager.py +134 -0
- app/frontend_management.py +204 -0
- app/logger.py +84 -0
- app/model_manager.py +184 -0
- app/user_manager.py +330 -0
api_server/__init__.py
ADDED
File without changes
api_server/routes/__init__.py
ADDED
File without changes
api_server/routes/internal/README.md
ADDED
@@ -0,0 +1,3 @@

```markdown
# ComfyUI Internal Routes

All routes under the `/internal` path are designated for **internal use by ComfyUI only**. These routes are not intended for use by external applications and may change at any time without notice.
```
api_server/routes/internal/__init__.py
ADDED
File without changes
api_server/routes/internal/internal_routes.py
ADDED
@@ -0,0 +1,75 @@

```python
from aiohttp import web
from typing import Optional
from folder_paths import models_dir, user_directory, output_directory, folder_names_and_paths
from api_server.services.file_service import FileService
from api_server.services.terminal_service import TerminalService
import app.logger

class InternalRoutes:
    '''
    The top level web router for internal routes: /internal/*
    The endpoints here should NOT be depended upon. It is for ComfyUI frontend use only.
    Check README.md for more information.
    '''

    def __init__(self, prompt_server):
        self.routes: web.RouteTableDef = web.RouteTableDef()
        self._app: Optional[web.Application] = None
        self.file_service = FileService({
            "models": models_dir,
            "user": user_directory,
            "output": output_directory
        })
        self.prompt_server = prompt_server
        self.terminal_service = TerminalService(prompt_server)

    def setup_routes(self):
        @self.routes.get('/files')
        async def list_files(request):
            directory_key = request.query.get('directory', '')
            try:
                file_list = self.file_service.list_files(directory_key)
                return web.json_response({"files": file_list})
            except ValueError as e:
                return web.json_response({"error": str(e)}, status=400)
            except Exception as e:
                return web.json_response({"error": str(e)}, status=500)

        @self.routes.get('/logs')
        async def get_logs(request):
            return web.json_response("".join([(l["t"] + " - " + l["m"]) for l in app.logger.get_logs()]))

        @self.routes.get('/logs/raw')
        async def get_raw_logs(request):
            self.terminal_service.update_size()
            return web.json_response({
                "entries": list(app.logger.get_logs()),
                "size": {"cols": self.terminal_service.cols, "rows": self.terminal_service.rows}
            })

        @self.routes.patch('/logs/subscribe')
        async def subscribe_logs(request):
            json_data = await request.json()
            client_id = json_data["clientId"]
            enabled = json_data["enabled"]
            if enabled:
                self.terminal_service.subscribe(client_id)
            else:
                self.terminal_service.unsubscribe(client_id)

            return web.Response(status=200)

        @self.routes.get('/folder_paths')
        async def get_folder_paths(request):
            response = {}
            for key in folder_names_and_paths:
                response[key] = folder_names_and_paths[key][0]
            return web.json_response(response)

    def get_app(self):
        if self._app is None:
            self._app = web.Application()
            self.setup_routes()
            self._app.add_routes(self.routes)
        return self._app
```
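For context, `get_app()` returns a sub-application that the main server is expected to mount under the `/internal` prefix (matching the README above); the exact mount point in the parent server is an assumption here. A minimal standalone sketch of that aiohttp mounting pattern, with a stub route body and an arbitrary example port:

```python
from aiohttp import web

routes = web.RouteTableDef()

@routes.get("/files")
async def list_files(request: web.Request) -> web.Response:
    # Stand-in for FileService.list_files(); returns a fixed payload.
    return web.json_response({"files": []})

def make_app() -> web.Application:
    internal = web.Application()
    internal.add_routes(routes)

    root = web.Application()
    root.add_subapp("/internal", internal)  # requests now hit /internal/files
    return root

if __name__ == "__main__":
    web.run_app(make_app(), port=8188)
```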
api_server/services/__init__.py
ADDED
File without changes
api_server/services/file_service.py
ADDED
@@ -0,0 +1,13 @@

```python
from typing import Dict, List, Optional
from api_server.utils.file_operations import FileSystemOperations, FileSystemItem

class FileService:
    def __init__(self, allowed_directories: Dict[str, str], file_system_ops: Optional[FileSystemOperations] = None):
        self.allowed_directories: Dict[str, str] = allowed_directories
        self.file_system_ops: FileSystemOperations = file_system_ops or FileSystemOperations()

    def list_files(self, directory_key: str) -> List[FileSystemItem]:
        if directory_key not in self.allowed_directories:
            raise ValueError("Invalid directory key")
        directory_path: str = self.allowed_directories[directory_key]
        return self.file_system_ops.walk_directory(directory_path)
```
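The `file_system_ops` parameter is an injection point: tests can substitute a fake for `FileSystemOperations`. A hedged usage sketch, assuming the repository root is on `sys.path`; `FakeOps` and the directory mapping are made up for the example:

```python
from api_server.services.file_service import FileService

class FakeOps:
    """Hypothetical stand-in for FileSystemOperations, handy in tests."""
    def walk_directory(self, path):
        return [{"name": "model.safetensors", "path": "model.safetensors",
                 "type": "file", "size": 1024}]

service = FileService({"models": "/tmp/models"}, file_system_ops=FakeOps())
print(service.list_files("models"))    # returns the fake listing
try:
    service.list_files("checkpoints")  # key not in allowed_directories
except ValueError as e:
    print(e)                           # "Invalid directory key"
```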
api_server/services/terminal_service.py
ADDED
@@ -0,0 +1,60 @@

```python
from app.logger import on_flush
import os
import shutil


class TerminalService:
    def __init__(self, server):
        self.server = server
        self.cols = None
        self.rows = None
        self.subscriptions = set()
        on_flush(self.send_messages)

    def get_terminal_size(self):
        try:
            size = os.get_terminal_size()
            return (size.columns, size.lines)
        except OSError:
            try:
                size = shutil.get_terminal_size()
                return (size.columns, size.lines)
            except OSError:
                return (80, 24)  # fallback to 80x24

    def update_size(self):
        columns, lines = self.get_terminal_size()
        changed = False

        if columns != self.cols:
            self.cols = columns
            changed = True

        if lines != self.rows:
            self.rows = lines
            changed = True

        if changed:
            return {"cols": self.cols, "rows": self.rows}

        return None

    def subscribe(self, client_id):
        self.subscriptions.add(client_id)

    def unsubscribe(self, client_id):
        self.subscriptions.discard(client_id)

    def send_messages(self, entries):
        if not len(entries) or not len(self.subscriptions):
            return

        new_size = self.update_size()

        for client_id in self.subscriptions.copy():  # prevent: Set changed size during iteration
            if client_id not in self.server.sockets:
                # Automatically unsub if the socket has disconnected
                self.unsubscribe(client_id)
                continue

            self.server.send_sync("logs", {"entries": entries, "size": new_size}, client_id)
```
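The service wires itself into the logger at construction time (`on_flush(self.send_messages)`), so log flushes fan out to subscribed websocket clients. A sketch of that flow with a hypothetical fake standing in for the real prompt server (which provides `.sockets` and `.send_sync`); assuming the repository root is on `sys.path`:

```python
from api_server.services.terminal_service import TerminalService

class FakeServer:
    """Hypothetical stand-in for the prompt server."""
    def __init__(self):
        self.sockets = {"client-a": object()}  # "client-b" is not connected
    def send_sync(self, event, data, sid):
        print(event, data["size"], sid)

svc = TerminalService(FakeServer())
svc.subscribe("client-a")
svc.subscribe("client-b")

# Normally invoked via the logger's flush callback; called directly here.
svc.send_messages([{"t": "2024-01-01T00:00:00", "m": "hello\n"}])
print(svc.subscriptions)  # {'client-a'}: the dead client was unsubscribed
```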
api_server/utils/file_operations.py
ADDED
@@ -0,0 +1,42 @@

```python
import os
from typing import List, Union, TypedDict, Literal
from typing_extensions import TypeGuard

class FileInfo(TypedDict):
    name: str
    path: str
    type: Literal["file"]
    size: int

class DirectoryInfo(TypedDict):
    name: str
    path: str
    type: Literal["directory"]

FileSystemItem = Union[FileInfo, DirectoryInfo]

def is_file_info(item: FileSystemItem) -> TypeGuard[FileInfo]:
    return item["type"] == "file"

class FileSystemOperations:
    @staticmethod
    def walk_directory(directory: str) -> List[FileSystemItem]:
        file_list: List[FileSystemItem] = []
        for root, dirs, files in os.walk(directory):
            for name in files:
                file_path = os.path.join(root, name)
                relative_path = os.path.relpath(file_path, directory)
                file_list.append({
                    "name": name,
                    "path": relative_path,
                    "type": "file",
                    "size": os.path.getsize(file_path)
                })
            for name in dirs:
                dir_path = os.path.join(root, name)
                relative_path = os.path.relpath(dir_path, directory)
                file_list.append({
                    "name": name,
                    "path": relative_path,
                    "type": "directory"
                })
        return file_list
```
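The `TypeGuard` is what lets static checkers narrow a `FileSystemItem` to `FileInfo`, so `item["size"]` type-checks after the guard. A small usage sketch:

```python
from api_server.utils.file_operations import FileSystemItem, is_file_info

def total_size(items: list[FileSystemItem]) -> int:
    # After is_file_info() returns True, checkers know "size" is present.
    return sum(item["size"] for item in items if is_file_info(item))

items: list[FileSystemItem] = [
    {"name": "a.txt", "path": "a.txt", "type": "file", "size": 10},
    {"name": "sub", "path": "sub", "type": "directory"},
]
print(total_size(items))  # 10
```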
app/__init__.py
ADDED
File without changes
app/app_settings.py
ADDED
@@ -0,0 +1,59 @@

```python
import os
import json
from aiohttp import web
import logging


class AppSettings():
    def __init__(self, user_manager):
        self.user_manager = user_manager

    def get_settings(self, request):
        file = self.user_manager.get_request_user_filepath(
            request, "comfy.settings.json")
        if os.path.isfile(file):
            try:
                with open(file) as f:
                    return json.load(f)
            except Exception:
                logging.error(f"The user settings file is corrupted: {file}")
                return {}
        else:
            return {}

    def save_settings(self, request, settings):
        file = self.user_manager.get_request_user_filepath(
            request, "comfy.settings.json")
        with open(file, "w") as f:
            f.write(json.dumps(settings, indent=4))

    def add_routes(self, routes):
        @routes.get("/settings")
        async def get_settings(request):
            return web.json_response(self.get_settings(request))

        @routes.get("/settings/{id}")
        async def get_setting(request):
            value = None
            settings = self.get_settings(request)
            setting_id = request.match_info.get("id", None)
            if setting_id and setting_id in settings:
                value = settings[setting_id]
            return web.json_response(value)

        @routes.post("/settings")
        async def post_settings(request):
            settings = self.get_settings(request)
            new_settings = await request.json()
            self.save_settings(request, {**settings, **new_settings})
            return web.Response(status=200)

        @routes.post("/settings/{id}")
        async def post_setting(request):
            setting_id = request.match_info.get("id", None)
            if not setting_id:
                return web.Response(status=400)
            settings = self.get_settings(request)
            settings[setting_id] = await request.json()
            self.save_settings(request, settings)
            return web.Response(status=200)
```
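Note that `POST /settings` performs a shallow merge (`{**settings, **new_settings}`): a nested setting object is replaced wholesale, not merged key by key. A standalone illustration (the setting names are made up for the example):

```python
settings = {"Comfy.Theme": "dark",
            "Comfy.Sidebar": {"left": True, "width": 300}}
new_settings = {"Comfy.Sidebar": {"left": False}}

merged = {**settings, **new_settings}
print(merged["Comfy.Sidebar"])  # {'left': False}: the 'width' key is gone
```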
app/custom_node_manager.py
ADDED
@@ -0,0 +1,134 @@

```python
from __future__ import annotations

import os
import folder_paths
import glob
from aiohttp import web
import json
import logging
from functools import lru_cache

from utils.json_util import merge_json_recursive


# Extra locale files to load into main.json
EXTRA_LOCALE_FILES = [
    "nodeDefs.json",
    "commands.json",
    "settings.json",
]


def safe_load_json_file(file_path: str) -> dict:
    if not os.path.exists(file_path):
        return {}

    try:
        with open(file_path, "r", encoding="utf-8") as f:
            return json.load(f)
    except json.JSONDecodeError:
        logging.error(f"Error loading {file_path}")
        return {}


class CustomNodeManager:
    @lru_cache(maxsize=1)
    def build_translations(self):
        """Load all custom node translations during initialization. Translations are
        expected to be loaded from the `locales/` folder.

        The folder structure is expected to be the following:
        - custom_nodes/
            - custom_node_1/
                - locales/
                    - en/
                        - main.json
                        - commands.json
                        - settings.json

        Returned translations are expected to be in the following format:
        {
            "en": {
                "nodeDefs": {...},
                "commands": {...},
                "settings": {...},
                ...{other main.json keys}
            }
        }
        """

        translations = {}

        for folder in folder_paths.get_folder_paths("custom_nodes"):
            # Sort glob results for deterministic ordering
            for custom_node_dir in sorted(glob.glob(os.path.join(folder, "*/"))):
                locales_dir = os.path.join(custom_node_dir, "locales")
                if not os.path.exists(locales_dir):
                    continue

                for lang_dir in glob.glob(os.path.join(locales_dir, "*/")):
                    lang_code = os.path.basename(os.path.dirname(lang_dir))

                    if lang_code not in translations:
                        translations[lang_code] = {}

                    # Load main.json
                    main_file = os.path.join(lang_dir, "main.json")
                    node_translations = safe_load_json_file(main_file)

                    # Load extra locale files
                    for extra_file in EXTRA_LOCALE_FILES:
                        extra_file_path = os.path.join(lang_dir, extra_file)
                        key = extra_file.split(".")[0]
                        json_data = safe_load_json_file(extra_file_path)
                        if json_data:
                            node_translations[key] = json_data

                    if node_translations:
                        translations[lang_code] = merge_json_recursive(
                            translations[lang_code], node_translations
                        )

        return translations

    def add_routes(self, routes, webapp, loadedModules):

        @routes.get("/workflow_templates")
        async def get_workflow_templates(request):
            """Returns a web response that contains the map of custom node names
            and their associated workflow templates. Nodes without templates are omitted."""
            files = [
                file
                for folder in folder_paths.get_folder_paths("custom_nodes")
                for file in glob.glob(
                    os.path.join(folder, "*/example_workflows/*.json")
                )
            ]
            workflow_templates_dict = (
                {}
            )  # custom_nodes folder name -> example workflow names
            for file in files:
                custom_nodes_name = os.path.basename(
                    os.path.dirname(os.path.dirname(file))
                )
                workflow_name = os.path.splitext(os.path.basename(file))[0]
                workflow_templates_dict.setdefault(custom_nodes_name, []).append(
                    workflow_name
                )
            return web.json_response(workflow_templates_dict)

        # Serve workflow templates from custom nodes.
        for module_name, module_dir in loadedModules:
            workflows_dir = os.path.join(module_dir, "example_workflows")
            if os.path.exists(workflows_dir):
                webapp.add_routes(
                    [
                        web.static(
                            "/api/workflow_templates/" + module_name, workflows_dir
                        )
                    ]
                )

        @routes.get("/i18n")
        async def get_i18n(request):
            """Returns translations from all custom nodes' locales folders."""
            return web.json_response(self.build_translations())
```
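`merge_json_recursive` is imported from `utils.json_util`, which is not part of this upload; `build_translations` relies on it doing a deep merge so that two custom nodes contributing to the same language combine rather than overwrite each other. A local sketch of the presumed behavior, not the actual helper:

```python
def deep_merge(base: dict, extra: dict) -> dict:
    """Presumed behavior of merge_json_recursive: nested dicts combine."""
    out = dict(base)
    for key, value in extra.items():
        if isinstance(out.get(key), dict) and isinstance(value, dict):
            out[key] = deep_merge(out[key], value)
        else:
            out[key] = value
    return out

node_a = {"en": {"commands": {"undo": "Undo"}}}
node_b = {"en": {"commands": {"redo": "Redo"}, "settings": {"x": "X"}}}
print(deep_merge(node_a, node_b))
# {'en': {'commands': {'undo': 'Undo', 'redo': 'Redo'}, 'settings': {'x': 'X'}}}
```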
app/frontend_management.py
ADDED
@@ -0,0 +1,204 @@

```python
from __future__ import annotations
import argparse
import logging
import os
import re
import tempfile
import zipfile
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import TypedDict, Optional

import requests
from typing_extensions import NotRequired
from comfy.cli_args import DEFAULT_VERSION_STRING


REQUEST_TIMEOUT = 10  # seconds


class Asset(TypedDict):
    name: str
    url: str


class Release(TypedDict):
    id: int
    tag_name: str
    name: str
    prerelease: bool
    created_at: str
    published_at: str
    body: str
    assets: NotRequired[list[Asset]]


@dataclass
class FrontEndProvider:
    owner: str
    repo: str

    @property
    def folder_name(self) -> str:
        return f"{self.owner}_{self.repo}"

    @property
    def release_url(self) -> str:
        return f"https://api.github.com/repos/{self.owner}/{self.repo}/releases"

    @cached_property
    def all_releases(self) -> list[Release]:
        releases = []
        api_url = self.release_url
        while api_url:
            response = requests.get(api_url, timeout=REQUEST_TIMEOUT)
            response.raise_for_status()  # Raises an HTTPError if the response was an error
            releases.extend(response.json())
            # GitHub uses the Link header to provide pagination links. Check if it exists and update api_url accordingly.
            if "next" in response.links:
                api_url = response.links["next"]["url"]
            else:
                api_url = None
        return releases

    @cached_property
    def latest_release(self) -> Release:
        latest_release_url = f"{self.release_url}/latest"
        response = requests.get(latest_release_url, timeout=REQUEST_TIMEOUT)
        response.raise_for_status()  # Raises an HTTPError if the response was an error
        return response.json()

    def get_release(self, version: str) -> Release:
        if version == "latest":
            return self.latest_release
        else:
            for release in self.all_releases:
                if release["tag_name"] in [version, f"v{version}"]:
                    return release
            raise ValueError(f"Version {version} not found in releases")


def download_release_asset_zip(release: Release, destination_path: str) -> None:
    """Download dist.zip from a GitHub release."""
    asset_url = None
    for asset in release.get("assets", []):
        if asset["name"] == "dist.zip":
            asset_url = asset["url"]
            break

    if not asset_url:
        raise ValueError("dist.zip not found in the release assets")

    # Use a temporary file to download the zip content
    with tempfile.TemporaryFile() as tmp_file:
        headers = {"Accept": "application/octet-stream"}
        response = requests.get(
            asset_url, headers=headers, allow_redirects=True, timeout=REQUEST_TIMEOUT
        )
        response.raise_for_status()  # Ensure we got a successful response

        # Write the content to the temporary file
        tmp_file.write(response.content)

        # Go back to the beginning of the temporary file
        tmp_file.seek(0)

        # Extract the zip file content to the destination path
        with zipfile.ZipFile(tmp_file, "r") as zip_ref:
            zip_ref.extractall(destination_path)


class FrontendManager:
    DEFAULT_FRONTEND_PATH = str(Path(__file__).parents[1] / "web")
    CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")

    @classmethod
    def parse_version_string(cls, value: str) -> tuple[str, str, str]:
        """
        Args:
            value (str): The version string to parse.

        Returns:
            tuple[str, str, str]: A tuple containing the repo owner, repo name, and version.

        Raises:
            argparse.ArgumentTypeError: If the version string is invalid.
        """
        VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(v?\d+\.\d+\.\d+|latest)$"
        match_result = re.match(VERSION_PATTERN, value)
        if match_result is None:
            raise argparse.ArgumentTypeError(f"Invalid version string: {value}")

        return match_result.group(1), match_result.group(2), match_result.group(3)

    @classmethod
    def init_frontend_unsafe(cls, version_string: str, provider: Optional[FrontEndProvider] = None) -> str:
        """
        Initializes the frontend for the specified version.

        Args:
            version_string (str): The version string.
            provider (FrontEndProvider, optional): The provider to use. Defaults to None.

        Returns:
            str: The path to the initialized frontend.

        Raises:
            Exception: If there is an error during the initialization process.
            The main error sources are request timeouts and invalid URLs.
        """
        if version_string == DEFAULT_VERSION_STRING:
            return cls.DEFAULT_FRONTEND_PATH

        repo_owner, repo_name, version = cls.parse_version_string(version_string)

        if version.startswith("v"):
            expected_path = str(Path(cls.CUSTOM_FRONTENDS_ROOT) / f"{repo_owner}_{repo_name}" / version.lstrip("v"))
            if os.path.exists(expected_path):
                logging.info(f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}")
                return expected_path

        logging.info(f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub...")

        provider = provider or FrontEndProvider(repo_owner, repo_name)
        release = provider.get_release(version)

        semantic_version = release["tag_name"].lstrip("v")
        web_root = str(
            Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version
        )
        if not os.path.exists(web_root):
            try:
                os.makedirs(web_root, exist_ok=True)
                logging.info(
                    "Downloading frontend(%s) version(%s) to (%s)",
                    provider.folder_name,
                    semantic_version,
                    web_root,
                )
                logging.debug(release)
                download_release_asset_zip(release, destination_path=web_root)
            finally:
                # Clean up the directory if it is empty, i.e. the download failed
                if not os.listdir(web_root):
                    os.rmdir(web_root)

        return web_root

    @classmethod
    def init_frontend(cls, version_string: str) -> str:
        """
        Initializes the frontend with the specified version string.

        Args:
            version_string (str): The version string to initialize the frontend with.

        Returns:
            str: The path of the initialized frontend.
        """
        try:
            return cls.init_frontend_unsafe(version_string)
        except Exception as e:
            logging.error("Failed to initialize frontend: %s", e)
            logging.info("Falling back to the default frontend.")
            return cls.DEFAULT_FRONTEND_PATH
```
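The accepted version strings have the form `owner/repo@semver` or `owner/repo@latest`. A quick standalone check of the regex used by `parse_version_string` (the example values are arbitrary):

```python
import re

VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(v?\d+\.\d+\.\d+|latest)$"

for value in ("Comfy-Org/ComfyUI_frontend@1.2.3",
              "owner/repo@latest",
              "not-a-version-string"):
    m = re.match(VERSION_PATTERN, value)
    print(value, "->", m.groups() if m else "rejected (ArgumentTypeError)")
```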
app/logger.py
ADDED
@@ -0,0 +1,84 @@

```python
from collections import deque
from datetime import datetime
import io
import logging
import sys
import threading

logs = None
stdout_interceptor = None
stderr_interceptor = None


class LogInterceptor(io.TextIOWrapper):
    def __init__(self, stream, *args, **kwargs):
        buffer = stream.buffer
        encoding = stream.encoding
        super().__init__(buffer, *args, **kwargs, encoding=encoding, line_buffering=stream.line_buffering)
        self._lock = threading.Lock()
        self._flush_callbacks = []
        self._logs_since_flush = []

    def write(self, data):
        entry = {"t": datetime.now().isoformat(), "m": data}
        with self._lock:
            self._logs_since_flush.append(entry)

            # Simple handling for cr to overwrite the last output if it isn't a full line,
            # else logs just get full of progress messages
            if isinstance(data, str) and data.startswith("\r") and not logs[-1]["m"].endswith("\n"):
                logs.pop()
            logs.append(entry)
        super().write(data)

    def flush(self):
        super().flush()
        for cb in self._flush_callbacks:
            cb(self._logs_since_flush)
        self._logs_since_flush = []

    def on_flush(self, callback):
        self._flush_callbacks.append(callback)


def get_logs():
    return logs


def on_flush(callback):
    if stdout_interceptor is not None:
        stdout_interceptor.on_flush(callback)
    if stderr_interceptor is not None:
        stderr_interceptor.on_flush(callback)


def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool = False):
    global logs
    if logs:
        return

    # Override output streams and log to buffer
    logs = deque(maxlen=capacity)

    global stdout_interceptor
    global stderr_interceptor
    stdout_interceptor = sys.stdout = LogInterceptor(sys.stdout)
    stderr_interceptor = sys.stderr = LogInterceptor(sys.stderr)

    # Setup default global logger
    logger = logging.getLogger()
    logger.setLevel(log_level)

    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(logging.Formatter("%(message)s"))

    if use_stdout:
        # Only errors and critical to stderr
        stream_handler.addFilter(lambda record: record.levelno >= logging.ERROR)

        # Lesser to stdout
        stdout_handler = logging.StreamHandler(sys.stdout)
        stdout_handler.setFormatter(logging.Formatter("%(message)s"))
        stdout_handler.addFilter(lambda record: record.levelno < logging.ERROR)
        logger.addHandler(stdout_handler)

    logger.addHandler(stream_handler)
```
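A minimal usage sketch, assuming the repository root is on `sys.path`: `setup_logger()` swaps `sys.stdout`/`sys.stderr` for interceptors, so anything printed afterwards also lands in the in-memory ring buffer (capped at `capacity` entries), and flush callbacks see the new entries:

```python
from app.logger import setup_logger, get_logs, on_flush

setup_logger(log_level="INFO", capacity=10)
on_flush(lambda entries: None)  # e.g. TerminalService.send_messages

print("hello")  # reaches the terminal AND the in-memory buffer

# Each entry is {"t": <ISO timestamp>, "m": <text written>}.
snapshot = [f'{e["t"]} {e["m"]!r}' for e in get_logs()]
print(snapshot)
```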
app/model_manager.py
ADDED
@@ -0,0 +1,184 @@

```python
from __future__ import annotations

import os
import base64
import json
import time
import logging
import folder_paths
import glob
import comfy.utils
from aiohttp import web
from PIL import Image
from io import BytesIO
from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types


class ModelFileManager:
    def __init__(self) -> None:
        self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}

    def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None:
        return self.cache.get(key, default)

    def set_cache(self, key: str, value: tuple[list[dict], dict[str, float], float]):
        self.cache[key] = value

    def clear_cache(self):
        self.cache.clear()

    def add_routes(self, routes):
        # NOTE: This is an experiment to replace `/models`
        @routes.get("/experiment/models")
        async def get_model_folders(request):
            model_types = list(folder_paths.folder_names_and_paths.keys())
            folder_black_list = ["configs", "custom_nodes"]
            output_folders: list[dict] = []
            for folder in model_types:
                if folder in folder_black_list:
                    continue
                output_folders.append({"name": folder, "folders": folder_paths.get_folder_paths(folder)})
            return web.json_response(output_folders)

        # NOTE: This is an experiment to replace `/models/{folder}`
        @routes.get("/experiment/models/{folder}")
        async def get_all_models(request):
            folder = request.match_info.get("folder", None)
            if folder not in folder_paths.folder_names_and_paths:
                return web.Response(status=404)
            files = self.get_model_file_list(folder)
            return web.json_response(files)

        @routes.get("/experiment/models/preview/{folder}/{path_index}/{filename:.*}")
        async def get_model_preview(request):
            folder_name = request.match_info.get("folder", None)
            path_index = int(request.match_info.get("path_index", None))
            filename = request.match_info.get("filename", None)

            if folder_name not in folder_paths.folder_names_and_paths:
                return web.Response(status=404)

            folders = folder_paths.folder_names_and_paths[folder_name]
            folder = folders[0][path_index]
            full_filename = os.path.join(folder, filename)

            previews = self.get_model_previews(full_filename)
            default_preview = previews[0] if len(previews) > 0 else None
            if default_preview is None or (isinstance(default_preview, str) and not os.path.isfile(default_preview)):
                return web.Response(status=404)

            try:
                with Image.open(default_preview) as img:
                    img_bytes = BytesIO()
                    img.save(img_bytes, format="WEBP")
                    img_bytes.seek(0)
                    return web.Response(body=img_bytes.getvalue(), content_type="image/webp")
            except Exception:
                return web.Response(status=404)

    def get_model_file_list(self, folder_name: str):
        folder_name = map_legacy(folder_name)
        folders = folder_paths.folder_names_and_paths[folder_name]
        output_list: list[dict] = []

        for index, folder in enumerate(folders[0]):
            if not os.path.isdir(folder):
                continue
            out = self.cache_model_file_list_(folder)
            if out is None:
                out = self.recursive_search_models_(folder, index)
                self.set_cache(folder, out)
            output_list.extend(out[0])

        return output_list

    def cache_model_file_list_(self, folder: str):
        model_file_list_cache = self.get_cache(folder)

        if model_file_list_cache is None:
            return None
        if not os.path.isdir(folder):
            return None
        # Invalidate if any tracked directory (the root folder is tracked
        # too, see recursive_search_models_) changed or disappeared.
        for x in model_file_list_cache[1]:
            time_modified = model_file_list_cache[1][x]
            if not os.path.isdir(x) or os.path.getmtime(x) != time_modified:
                return None

        return model_file_list_cache

    def recursive_search_models_(self, directory: str, pathIndex: int) -> tuple[list[dict], dict[str, float], float]:
        if not os.path.isdir(directory):
            return [], {}, time.perf_counter()

        excluded_dir_names = [".git"]
        # TODO use settings
        include_hidden_files = False

        result: list[str] = []
        # Track the root directory's mtime too, so the cache is invalidated
        # when files are added or removed directly inside it.
        dirs: dict[str, float] = {directory: os.path.getmtime(directory)}

        for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
            subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
            if not include_hidden_files:
                subdirs[:] = [d for d in subdirs if not d.startswith(".")]
                filenames = [f for f in filenames if not f.startswith(".")]

            filenames = filter_files_extensions(filenames, folder_paths.supported_pt_extensions)

            for file_name in filenames:
                try:
                    relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
                    result.append(relative_path)
                except Exception:
                    logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.")
                    continue

            for d in subdirs:
                path: str = os.path.join(dirpath, d)
                try:
                    dirs[path] = os.path.getmtime(path)
                except FileNotFoundError:
                    logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
                    continue

        return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter()

    def get_model_previews(self, filepath: str) -> list[str | BytesIO]:
        dirname = os.path.dirname(filepath)

        if not os.path.exists(dirname):
            return []

        basename = os.path.splitext(filepath)[0]
        match_files = glob.glob(f"{basename}.*", recursive=False)
        image_files = filter_files_content_types(match_files, "image")
        safetensors_file = next(filter(lambda x: x.endswith(".safetensors"), match_files), None)
        safetensors_metadata = {}

        result: list[str | BytesIO] = []

        for filename in image_files:
            _basename = os.path.splitext(filename)[0]
            if _basename == basename:
                result.append(filename)
            if _basename == f"{basename}.preview":
                result.append(filename)

        if safetensors_file:
            safetensors_filepath = os.path.join(dirname, safetensors_file)
            header = comfy.utils.safetensors_header(safetensors_filepath, max_size=8*1024*1024)
            if header:
                safetensors_metadata = json.loads(header)
            safetensors_images = safetensors_metadata.get("__metadata__", {}).get("ssmd_cover_images", None)
            if safetensors_images:
                safetensors_images = json.loads(safetensors_images)
                for image in safetensors_images:
                    result.append(BytesIO(base64.b64decode(image)))

        return result

    def __exit__(self, exc_type, exc_value, traceback):
        self.clear_cache()
```
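Preview resolution in `get_model_previews` is purely name-based: a sibling image counts if its stem is either `<model>` or `<model>.preview`. A standalone restatement of that matching rule (not the actual method, which additionally reads cover images from safetensors metadata):

```python
import os

def preview_candidates(model_file: str, siblings: list[str]) -> list[str]:
    # Mirrors the rule in get_model_previews: an image whose stem is either
    # "<model>" or "<model>.preview" is treated as that model's preview.
    base = os.path.splitext(model_file)[0]
    return [f for f in siblings
            if os.path.splitext(f)[0] in (base, f"{base}.preview")]

print(preview_candidates("foo.safetensors",
                         ["foo.png", "foo.preview.webp", "bar.png"]))
# ['foo.png', 'foo.preview.webp']
```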
app/user_manager.py
ADDED
@@ -0,0 +1,330 @@

```python
from __future__ import annotations
import json
import os
import re
import uuid
import glob
import shutil
import logging
from aiohttp import web
from urllib import parse
from comfy.cli_args import args
import folder_paths
from .app_settings import AppSettings
from typing import TypedDict

default_user = "default"


class FileInfo(TypedDict):
    path: str
    size: int
    modified: float


def get_file_info(path: str, relative_to: str) -> FileInfo:
    return {
        "path": os.path.relpath(path, relative_to).replace(os.sep, '/'),
        "size": os.path.getsize(path),
        "modified": os.path.getmtime(path)
    }


class UserManager():
    def __init__(self):
        user_directory = folder_paths.get_user_directory()

        self.settings = AppSettings(self)
        if not os.path.exists(user_directory):
            os.makedirs(user_directory, exist_ok=True)
            if not args.multi_user:
                logging.warning("****** User settings have been changed to be stored on the server instead of browser storage. ******")
                logging.warning("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")

        if args.multi_user:
            if os.path.isfile(self.get_users_file()):
                with open(self.get_users_file()) as f:
                    self.users = json.load(f)
            else:
                self.users = {}
        else:
            self.users = {"default": "default"}

    def get_users_file(self):
        return os.path.join(folder_paths.get_user_directory(), "users.json")

    def get_request_user_id(self, request):
        user = "default"
        if args.multi_user and "comfy-user" in request.headers:
            user = request.headers["comfy-user"]

        if user not in self.users:
            raise KeyError("Unknown user: " + user)

        return user

    def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
        user_directory = folder_paths.get_user_directory()

        if type == "userdata":
            root_dir = user_directory
        else:
            raise KeyError("Unknown filepath type: " + type)

        user = self.get_request_user_id(request)
        path = user_root = os.path.abspath(os.path.join(root_dir, user))

        # prevent leaving /{type}
        if os.path.commonpath((root_dir, user_root)) != root_dir:
            return None

        if file is not None:
            # Check if filename is url encoded
            if "%" in file:
                file = parse.unquote(file)

            # prevent leaving /{type}/{user}
            path = os.path.abspath(os.path.join(user_root, file))
            if os.path.commonpath((user_root, path)) != user_root:
                return None

        parent = os.path.split(path)[0]

        if create_dir and not os.path.exists(parent):
            os.makedirs(parent, exist_ok=True)

        return path

    def add_user(self, name):
        name = name.strip()
        if not name:
            raise ValueError("username not provided")
        user_id = re.sub("[^a-zA-Z0-9-_]+", '-', name)
        user_id = user_id + "_" + str(uuid.uuid4())

        self.users[user_id] = name

        with open(self.get_users_file(), "w") as f:
            json.dump(self.users, f)

        return user_id

    def add_routes(self, routes):
        self.settings.add_routes(routes)

        @routes.get("/users")
        async def get_users(request):
            if args.multi_user:
                return web.json_response({"storage": "server", "users": self.users})
            else:
                user_dir = self.get_request_user_filepath(request, None, create_dir=False)
                return web.json_response({
                    "storage": "server",
                    "migrated": os.path.exists(user_dir)
                })

        @routes.post("/users")
        async def post_users(request):
            body = await request.json()
            username = body["username"]
            if username in self.users.values():
                return web.json_response({"error": "Duplicate username."}, status=400)

            user_id = self.add_user(username)
            return web.json_response(user_id)

        @routes.get("/userdata")
        async def listuserdata(request):
            """
            List user data files in a specified directory.

            This endpoint allows listing files in a user's data directory, with options for recursion,
            full file information, and path splitting.

            Query Parameters:
            - dir (required): The directory to list files from.
            - recurse (optional): If "true", recursively list files in subdirectories.
            - full_info (optional): If "true", return detailed file information (path, size, modified time).
            - split (optional): If "true", split file paths into components (only applies when full_info is false).

            Returns:
            - 400: If 'dir' parameter is missing.
            - 403: If the requested path is not allowed.
            - 404: If the requested directory does not exist.
            - 200: JSON response with the list of files or file information.

            The response format depends on the query parameters:
            - Default: List of relative file paths.
            - full_info=true: List of dictionaries with file details.
            - split=true (and full_info=false): List of lists, each containing path components.
            """
            directory = request.rel_url.query.get('dir', '')
            if not directory:
                return web.Response(status=400, text="Directory not provided")

            path = self.get_request_user_filepath(request, directory)
            if not path:
                return web.Response(status=403, text="Invalid directory")

            if not os.path.exists(path):
                return web.Response(status=404, text="Directory not found")

            recurse = request.rel_url.query.get('recurse', '').lower() == "true"
            full_info = request.rel_url.query.get('full_info', '').lower() == "true"
            split_path = request.rel_url.query.get('split', '').lower() == "true"

            # Use different patterns based on whether we're recursing or not
            if recurse:
                pattern = os.path.join(glob.escape(path), '**', '*')
            else:
                pattern = os.path.join(glob.escape(path), '*')

            def process_full_path(full_path: str) -> FileInfo | str | list[str]:
                if full_info:
                    return get_file_info(full_path, path)

                rel_path = os.path.relpath(full_path, path).replace(os.sep, '/')
                if split_path:
                    return [rel_path] + rel_path.split('/')

                return rel_path

            results = [
                process_full_path(full_path)
                for full_path in glob.glob(pattern, recursive=recurse)
                if os.path.isfile(full_path)
            ]

            return web.json_response(results)

        def get_user_data_path(request, check_exists=False, param="file"):
            file = request.match_info.get(param, None)
            if not file:
                return web.Response(status=400)

            path = self.get_request_user_filepath(request, file)
            if not path:
                return web.Response(status=403)

            if check_exists and not os.path.exists(path):
                return web.Response(status=404)

            return path

        @routes.get("/userdata/{file}")
        async def getuserdata(request):
            path = get_user_data_path(request, check_exists=True)
            if not isinstance(path, str):
                return path

            return web.FileResponse(path)

        @routes.post("/userdata/{file}")
        async def post_userdata(request):
            """
            Upload or update a user data file.

            This endpoint handles file uploads to a user's data directory, with options for
            controlling overwrite behavior and response format.

            Query Parameters:
            - overwrite (optional): If "false", prevents overwriting existing files. Defaults to "true".
            - full_info (optional): If "true", returns detailed file information (path, size, modified time).
              If "false", returns only the relative file path.

            Path Parameters:
            - file: The target file path (URL encoded if necessary).

            Returns:
            - 400: If 'file' parameter is missing.
            - 403: If the requested path is not allowed.
            - 409: If overwrite=false and the file already exists.
            - 200: JSON response with either:
              - Full file information (if full_info=true)
              - Relative file path (if full_info=false)

            The request body should contain the raw file content to be written.
            """
            path = get_user_data_path(request)
            if not isinstance(path, str):
                return path

            overwrite = request.query.get("overwrite", 'true') != "false"
            full_info = request.query.get('full_info', 'false').lower() == "true"

            if not overwrite and os.path.exists(path):
                return web.Response(status=409, text="File already exists")

            body = await request.read()

            with open(path, "wb") as f:
                f.write(body)

            user_path = self.get_request_user_filepath(request, None)
            if full_info:
                resp = get_file_info(path, user_path)
            else:
                resp = os.path.relpath(path, user_path)

            return web.json_response(resp)

        @routes.delete("/userdata/{file}")
        async def delete_userdata(request):
            path = get_user_data_path(request, check_exists=True)
            if not isinstance(path, str):
                return path

            os.remove(path)

            return web.Response(status=204)

        @routes.post("/userdata/{file}/move/{dest}")
        async def move_userdata(request):
            """
            Move or rename a user data file.

            This endpoint handles moving or renaming files within a user's data directory, with options for
            controlling overwrite behavior and response format.

            Path Parameters:
            - file: The source file path (URL encoded if necessary)
            - dest: The destination file path (URL encoded if necessary)

            Query Parameters:
            - overwrite (optional): If "false", prevents overwriting existing files. Defaults to "true".
            - full_info (optional): If "true", returns detailed file information (path, size, modified time).
              If "false", returns only the relative file path.

            Returns:
            - 400: If either 'file' or 'dest' parameter is missing
            - 403: If either requested path is not allowed
            - 404: If the source file does not exist
            - 409: If overwrite=false and the destination file already exists
            - 200: JSON response with either:
              - Full file information (if full_info=true)
              - Relative file path (if full_info=false)
            """
            source = get_user_data_path(request, check_exists=True)
            if not isinstance(source, str):
                return source

            dest = get_user_data_path(request, check_exists=False, param="dest")
            if not isinstance(dest, str):
                return dest

            overwrite = request.query.get("overwrite", 'true') != "false"
            full_info = request.query.get('full_info', 'false').lower() == "true"

            if not overwrite and os.path.exists(dest):
                return web.Response(status=409, text="File already exists")

            logging.info(f"moving '{source}' -> '{dest}'")
            shutil.move(source, dest)

            user_path = self.get_request_user_filepath(request, None)
            if full_info:
                resp = get_file_info(dest, user_path)
            else:
                resp = os.path.relpath(dest, user_path)

            return web.json_response(resp)
```
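The traversal guard in `get_request_user_filepath` is the `os.path.commonpath` check: after resolving the requested file against the user's root, any path that escapes the root is rejected. A standalone illustration of the same check:

```python
import os

def resolve(user_root: str, file: str):
    path = os.path.abspath(os.path.join(user_root, file))
    if os.path.commonpath((user_root, path)) != user_root:
        return None  # the request tried to escape the user's directory
    return path

root = os.path.abspath("user/default")
print(resolve(root, "workflows/a.json"))  # stays inside the root: allowed
print(resolve(root, "../../etc/passwd"))  # escapes the root: None
```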