|
|
|
|
|
from torch import Tensor |
|
import torch |
|
from PIL import Image |
|
import numpy as np |
|
|
|
import os |
|
import sys |
|
import json |
|
import folder_paths |
|
|
|
|
|
my_dir = os.path.dirname(os.path.abspath(__file__)) |
|
|
|
|
|
sys.path.append(my_dir) |
|
|
|
|
|
comfy_dir = os.path.abspath(os.path.join(my_dir, '..', '..')) |
|
|
|
|
|
sys.path.append(comfy_dir) |
|
|
|
|
|
import comfy.sd |
|
|
|
|
|
from tsc_sd import * |
|
|
|
|
|
# Global in-memory model cache shared by all loader helpers in this module.
# Entry layouts (last element is always the list of node ids using the entry):
#   "ckpt": (ckpt_name, model, clip, vae, [node_ids])
#   "vae":  (vae_name, vae, [node_ids])
#   "lora": (lora_params, base_ckpt_name, lora_model, lora_clip, [node_ids])
loaded_objects: dict[str, list] = {
    "ckpt": [],
    "vae": [],
    "lora": []
}

# Per-node outputs retained between workflow runs; each list holds
# (value, node_id) pairs and is pruned by globals_cleanup().
last_helds: dict[str, list] = {
    "results": [],
    "latent": [],
    "images": [],
    "vae_decode": [],
}
|
|
|
|
|
def tensor2pil(image: torch.Tensor) -> Image.Image:
    """Convert a float tensor with values in [0, 1] into a PIL image.

    The tensor is moved to CPU, squeezed of singleton dims (dropping the
    batch dim), scaled to 0-255 and clamped before conversion to uint8.
    """
    array = image.cpu().numpy().squeeze()
    pixels = np.clip(255. * array, 0, 255).astype(np.uint8)
    return Image.fromarray(pixels)
|
|
|
|
|
def pil2tensor(image: Image.Image) -> torch.Tensor:
    """Convert a PIL image into a float32 tensor scaled to [0, 1].

    A leading batch dimension of size 1 is added, matching ComfyUI's
    IMAGE convention.
    """
    array = np.array(image).astype(np.float32) / 255.0
    return torch.from_numpy(array).unsqueeze(0)
|
|
|
def extract_node_info(prompt, id, indirect_key=None):
    """Look up a node's class_type in a workflow prompt dict.

    Without indirect_key, returns (class_type of node `id`, None).
    With indirect_key, follows the link stored at
    prompt[id]['inputs'][indirect_key][0] and returns
    (class_type of the linked node, linked node id); returns (None, None)
    whenever the chain cannot be resolved.
    """
    key = str(id)

    if not indirect_key:
        # Direct lookup: node_id is only populated on the indirect path.
        return prompt.get(key, {}).get('class_type', None), None

    inputs = prompt.get(key, {}).get('inputs', {})
    if indirect_key in inputs:
        # Workflow links are stored as [source_node_id, output_slot].
        linked_id = inputs[indirect_key][0]
        if linked_id in prompt:
            return prompt[linked_id].get('class_type', None), linked_id
    return None, None
|
|
|
def extract_node_value(prompt, id, key):
    """Return prompt[str(id)]['inputs'][key], or None if any level is absent."""
    node = prompt.get(str(id), {})
    inputs = node.get('inputs', {})
    return inputs.get(key, None)
|
|
|
def print_loaded_objects_entries(id=None, prompt=None, show_id=False):
    """Print a colored summary of the global loaded_objects model cache.

    Parameters:
    - id (optional): node id; when given, only cache entries associated with
      that id are listed.
    - prompt (optional): workflow prompt dict; together with id, used to show
      the node's class name in the header (via extract_node_info).
    - show_id: when True (and both prompt and id are given), append the node
      id to the header line.
    """
    print("-" * 40)
    if id is not None:
        id = str(id)
    if prompt is not None and id is not None:
        # Resolve the node's class name from the prompt for the header.
        node_name, _ = extract_node_info(prompt, id)
        if show_id:
            # NOTE(review): int(id) assumes node ids are numeric strings — confirm against callers.
            print(f"\033[36m{node_name} Models Cache: (node_id:{int(id)})\033[0m")
        else:
            print(f"\033[36m{node_name} Models Cache:\033[0m")
    elif id is None:
        print(f"\033[36mGlobal Models Cache:\033[0m")
    else:
        print(f"\033[36mModels Cache: \nnode_id:{int(id)}\033[0m")
    entries_found = False
    for key in ["ckpt", "vae", "lora"]:
        # The last element of every cache entry is the list of node ids using it.
        # NOTE(review): `id` is coerced to str above, but ids stored in the cache
        # may not be strings — confirm the membership test matches in practice.
        entries_with_id = loaded_objects[key] if id is None else [entry for entry in loaded_objects[key] if id in entry[-1]]
        if not entries_with_id:
            continue
        entries_found = True
        print(f"{key.capitalize()}:")
        for i, entry in enumerate(entries_with_id, 1):
            if key == "lora":
                # lora entries: (lora_params, base_ckpt_name, model, clip, ids);
                # lora_params is a list of (name, strength_model, strength_clip).
                lora_models_info = ', '.join(f"{os.path.splitext(os.path.basename(name))[0]}({round(strength_model, 2)},{round(strength_clip, 2)})" for name, strength_model, strength_clip in entry[0])
                base_ckpt_name = os.path.splitext(os.path.basename(entry[1]))[0]
                if id is None:
                    # Global view: also show which node ids hold this entry.
                    associated_ids = ', '.join(map(str, entry[-1]))
                    print(f" [{i}] base_ckpt: {base_ckpt_name}, lora(mod,clip): {lora_models_info} (ids: {associated_ids})")
                else:
                    print(f" [{i}] base_ckpt: {base_ckpt_name}, lora(mod,clip): {lora_models_info}")
            else:
                # ckpt/vae entries: entry[0] is the model file name.
                name_without_ext = os.path.splitext(os.path.basename(entry[0]))[0]
                if id is None:
                    associated_ids = ', '.join(map(str, entry[-1]))
                    print(f" [{i}] {name_without_ext} (ids: {associated_ids})")
                else:
                    print(f" [{i}] {name_without_ext}")
    if not entries_found:
        print("-")
|
|
|
|
|
|
|
def globals_cleanup(prompt):
    """Prune cached state belonging to node ids no longer present in the prompt.

    - last_helds: keep only (value, node_id) pairs whose node id still exists.
    - loaded_objects: remove stale node ids from each entry's id list; entries
      whose id list becomes empty are evicted entirely.

    Parameters:
    - prompt: the current workflow prompt dict, keyed by node id strings.
    """
    global loaded_objects
    global last_helds

    # Keep only held outputs whose node id still exists in the prompt.
    for key in list(last_helds.keys()):
        last_helds[key] = [(value, node_id) for value, node_id in last_helds[key]
                           if str(node_id) in prompt]

    # Rebuild each cache list rather than mutating it during iteration:
    # the previous implementation removed entries while assigning by index
    # into the live list, so after any removal the index-based update wrote
    # to the wrong entry.
    for key in list(loaded_objects.keys()):
        pruned = []
        for tup in loaded_objects[key]:
            id_array = [node_id for node_id in tup[-1] if str(node_id) in prompt]
            if not id_array:
                # No live node ids reference this entry anymore: evict it.
                continue
            if len(id_array) != len(tup[-1]):
                # Some ids went stale: replace the id list, keep the rest.
                tup = tup[:-1] + (id_array,)
            pruned.append(tup)
        loaded_objects[key] = pruned
|
|
|
|
|
def load_checkpoint(ckpt_name, id, output_vae=True, cache=None, cache_overwrite=False):
    """
    Searches for tuple index that contains ckpt_name in "ckpt" array of loaded_objects.
    If found, extracts the model, clip, and vae from the loaded_objects.
    If not found, loads the checkpoint, extracts the model, clip, and vae.
    The id parameter represents the node ID and is used for caching models for the XY Plot node.
    If the cache limit is reached for a specific id, clears the cache and returns the loaded model, clip, and vae without adding a new entry.
    If there is cache space, adds the id to the ids list if it's not already there.
    If there is cache space and the checkpoint was not found in loaded_objects, adds a new entry to loaded_objects.

    Parameters:
    - ckpt_name: name of the checkpoint to load.
    - id: an identifier for caching models for specific nodes.
    - output_vae: boolean, if True loads the VAE too.
    - cache (optional): an integer that specifies how many checkpoint entries with a given id can exist in loaded_objects. Defaults to None.
    - cache_overwrite: if True and the cache is full, evict this id's oldest entry and store the new checkpoint in its place.
    """
    global loaded_objects

    # Cache hit: entries are (ckpt_name, model, clip, vae, ids).
    for entry in loaded_objects["ckpt"]:
        if entry[0] == ckpt_name:
            _, model, clip, vae, ids = entry
            # "Full" means this node already owns `cache` ckpt entries.
            cache_full = cache and len([entry for entry in loaded_objects["ckpt"] if id in entry[-1]]) >= cache

            if cache_full:
                clear_cache(id, cache, "ckpt")
            elif id not in ids:
                # Claim the existing entry for this node.
                ids.append(id)

            return model, clip, vae

    # Cache miss: load the checkpoint from disk.
    ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
    out = load_checkpoint_guess_config_tsc(ckpt_path, output_vae, output_clip=True,
                                           embedding_directory=folder_paths.get_folder_paths("embeddings"))
    model = out[0]
    clip = out[1]
    # NOTE(review): out[2] is taken as the vae even when output_vae is False — confirm
    # load_checkpoint_guess_config_tsc still returns a 3-tuple in that case.
    vae = out[2]

    if cache:
        if len([entry for entry in loaded_objects["ckpt"] if id in entry[-1]]) < cache:
            # Room left under this node's cache budget: record a new entry.
            loaded_objects["ckpt"].append((ckpt_name, model, clip, vae, [id]))
        else:
            clear_cache(id, cache, "ckpt")
            if cache_overwrite:
                # Evict this node's oldest surviving entry and replace it.
                for e in loaded_objects["ckpt"]:
                    if id in e[-1]:
                        e[-1].remove(id)
                        # Entry is unused by any other node: drop it entirely.
                        if not e[-1]:
                            loaded_objects["ckpt"].remove(e)
                        break
                loaded_objects["ckpt"].append((ckpt_name, model, clip, vae, [id]))

    return model, clip, vae
|
|
|
def get_bvae_by_ckpt_name(ckpt_name):
    """Return the baked-in VAE of the cached checkpoint named ckpt_name, or None.

    Checkpoint cache entries have the shape (name, model, clip, vae, ids);
    index 3 is the VAE extracted when the checkpoint was loaded.
    """
    matches = (entry[3] for entry in loaded_objects["ckpt"] if entry[0] == ckpt_name)
    return next(matches, None)
|
|
|
def load_vae(vae_name, id, cache=None, cache_overwrite=False):
    """
    Extracts the vae with a given name from the "vae" array in loaded_objects.
    If the vae is not found, creates a new VAE object with the given name and adds it to the "vae" array.
    Also stores the id parameter, which is used for caching models specifically for nodes with the given ID.
    If the cache limit is reached for a specific id, returns the loaded vae without adding id or making a new entry in loaded_objects.
    If there is cache space, and the id is not in the ids list, adds the id to the ids list.
    If there is cache space, and the vae was not found in loaded_objects, adds a new entry to the loaded_objects.

    Parameters:
    - vae_name: name of the VAE to load.
    - id: an identifier for caching models for specific nodes.
    - cache (optional): an integer that specifies how many vae entries with a given id can exist in loaded_objects. Defaults to None.
    - cache_overwrite: if True and the cache is full, evict this id's oldest entry and store the new VAE in its place.
    """
    global loaded_objects

    # Cache hit: entries are (vae_name, vae, ids).
    for i, entry in enumerate(loaded_objects["vae"]):
        if entry[0] == vae_name:
            vae, ids = entry[1], entry[2]
            if id not in ids:
                # Refuse to claim the entry when this node's vae budget is already full.
                if cache and len([entry for entry in loaded_objects["vae"] if id in entry[-1]]) >= cache:
                    return vae
                ids.append(id)
            if cache:
                # Trim any excess entries this node may still hold.
                clear_cache(id, cache, "vae")
            return vae

    # Cache miss: load the VAE from disk.
    vae_path = folder_paths.get_full_path("vae", vae_name)
    vae = comfy.sd.VAE(ckpt_path=vae_path)

    if cache:
        if len([entry for entry in loaded_objects["vae"] if id in entry[-1]]) < cache:
            # Room left under this node's cache budget: record a new entry.
            loaded_objects["vae"].append((vae_name, vae, [id]))
        else:
            clear_cache(id, cache, "vae")
            if cache_overwrite:
                # Evict this node's oldest surviving entry and replace it.
                for e in loaded_objects["vae"]:
                    if id in e[-1]:
                        e[-1].remove(id)
                        # Entry is unused by any other node: drop it entirely.
                        if not e[-1]:
                            loaded_objects["vae"].remove(e)
                        break
                loaded_objects["vae"].append((vae_name, vae, [id]))

    return vae
|
|
|
def load_lora(lora_params, ckpt_name, id, cache=None, ckpt_cache=None, cache_overwrite=False):
    """
    Extracts the Lora model with a given name from the "lora" array in loaded_objects.
    If the Lora model is not found or strength values changed or model changed, creates a new Lora object with the given name and adds it to the "lora" array.
    Also stores the id parameter, which is used for caching models specifically for nodes with the given ID.
    If the cache limit is reached for a specific id, clears the cache and returns the loaded Lora model and clip without adding a new entry.
    If there is cache space, adds the id to the ids list if it's not already there.
    If there is cache space and the Lora model was not found in loaded_objects, adds a new entry to loaded_objects.

    Parameters:
    - lora_params: A list of tuples, where each tuple contains lora_name, strength_model, strength_clip.
    - ckpt_name: name of the checkpoint from which the Lora model is created.
    - id: an identifier for caching models for specific nodes.
    - cache (optional): an integer that specifies how many Lora entries with a given id can exist in loaded_objects. Defaults to None.
    - ckpt_cache (optional): cache budget forwarded to load_checkpoint for the base checkpoint.
    - cache_overwrite: if True and the cache is full, evict this id's oldest entry and store the new Lora in its place.
    """
    global loaded_objects

    # Cache hit requires the same set of (name, strength) tuples AND the same base checkpoint.
    for entry in loaded_objects["lora"]:

        if set(entry[0]) == set(lora_params) and entry[1] == ckpt_name:

            # Entries are (lora_params, ckpt_name, lora_model, lora_clip, ids).
            _, _, lora_model, lora_clip, ids = entry
            cache_full = cache and len([entry for entry in loaded_objects["lora"] if id in entry[-1]]) >= cache

            if cache_full:
                clear_cache(id, cache, "lora")
            elif id not in ids:
                ids.append(id)

            # Also refresh the base checkpoint's id bookkeeping so the ckpt
            # entry this lora depends on stays cached for this node.
            for ckpt_entry in loaded_objects["ckpt"]:
                if ckpt_entry[0] == ckpt_name:
                    _, _, _, _, ckpt_ids = ckpt_entry
                    ckpt_cache_full = ckpt_cache and len(
                        [ckpt_entry for ckpt_entry in loaded_objects["ckpt"] if id in ckpt_entry[-1]]) >= ckpt_cache

                    if ckpt_cache_full:
                        clear_cache(id, ckpt_cache, "ckpt")
                    elif id not in ckpt_ids:
                        ckpt_ids.append(id)

            return lora_model, lora_clip

    def recursive_load_lora(lora_params, ckpt, clip, id, ckpt_cache, cache_overwrite, folder_paths):
        # Apply each lora in sequence, threading the patched model/clip through.
        if len(lora_params) == 0:
            return ckpt, clip

        lora_name, strength_model, strength_clip = lora_params[0]
        lora_path = folder_paths.get_full_path("loras", lora_name)
        lora_model, lora_clip = load_lora_for_models_tsc(ckpt, clip, lora_path, strength_model, strength_clip)

        # Recurse on the remaining loras with the freshly patched model/clip.
        return recursive_load_lora(lora_params[1:], lora_model, lora_clip, id, ckpt_cache, cache_overwrite, folder_paths)

    # NOTE(review): this unpacking is unused — recursive_load_lora re-reads
    # lora_params[0] itself; candidate for removal.
    lora_name, strength_model, strength_clip = lora_params[0]
    # Load (or fetch from cache) the base checkpoint the loras are applied to.
    ckpt, clip, _ = load_checkpoint(ckpt_name, id, cache=ckpt_cache, cache_overwrite=cache_overwrite)

    lora_model, lora_clip = recursive_load_lora(lora_params, ckpt, clip, id, ckpt_cache, cache_overwrite, folder_paths)

    if cache:
        if len([entry for entry in loaded_objects["lora"] if id in entry[-1]]) < cache:
            # Room left under this node's cache budget: record a new entry.
            loaded_objects["lora"].append((lora_params, ckpt_name, lora_model, lora_clip, [id]))
        else:
            clear_cache(id, cache, "lora")
            if cache_overwrite:
                # Evict this node's oldest surviving entry and replace it.
                for e in loaded_objects["lora"]:
                    if id in e[-1]:
                        e[-1].remove(id)
                        # Entry is unused by any other node: drop it entirely.
                        if not e[-1]:
                            loaded_objects["lora"].remove(e)
                        break
                loaded_objects["lora"].append((lora_params, ckpt_name, lora_model, lora_clip, [id]))

    return lora_model, lora_clip
|
|
|
def clear_cache(id, cache, dict_name):
    """Trim a node id's associations in loaded_objects[dict_name] to at most
    `cache` entries.

    The id is removed from the oldest associated entries first (entries are
    kept in insertion order, so index 0 is the oldest); an entry whose id
    list becomes empty is evicted from the cache entirely.
    """
    def entries_owned_by_id():
        return [item for item in loaded_objects[dict_name] if id in item[-1]]

    owned = entries_owned_by_id()
    while len(owned) > cache:
        oldest = owned[0]
        oldest[-1].remove(id)
        # No node references this entry anymore: drop it from the cache.
        if not oldest[-1]:
            loaded_objects[dict_name].remove(oldest)
        owned = entries_owned_by_id()
|
|
|
def clear_cache_by_exception(node_id, vae_dict=None, ckpt_dict=None, lora_dict=None):
    """Disassociate node_id from every cached entry NOT named in the given
    exception lists; entries left with no associated ids are evicted.

    Parameters:
    - node_id: id of the node whose cache associations should be pruned.
    - vae_dict: VAE names to KEEP for this node (None = leave vae cache untouched).
    - ckpt_dict: checkpoint names to KEEP (None = leave ckpt cache untouched).
    - lora_dict: (lora_params, ckpt_name) pairs to KEEP (None = leave lora cache untouched).
    """
    global loaded_objects

    # Maps argument names to the corresponding loaded_objects keys.
    dict_mapping = {
        "vae_dict": "vae",
        "ckpt_dict": "ckpt",
        "lora_dict": "lora"
    }

    for arg_name, arg_val in {"vae_dict": vae_dict, "ckpt_dict": ckpt_dict, "lora_dict": lora_dict}.items():
        if arg_val is None:
            continue

        dict_name = dict_mapping[arg_name]

        # Iterate over a copy so entries can be removed from the live list mid-scan.
        # (The previous enumerate() produced an index that was never used.)
        for tuple_item in loaded_objects[dict_name].copy():
            if arg_name == "lora_dict":
                # for/else: the else clause runs only when NO exception pair matched,
                # i.e. this lora entry is not protected and should be pruned.
                for lora_params, ckpt_name in arg_val:
                    # Lora entries match on both the lora set and the base checkpoint name.
                    if set(lora_params) == set(tuple_item[0]) and ckpt_name == tuple_item[1]:
                        break
                else:
                    if node_id in tuple_item[-1]:
                        tuple_item[-1].remove(node_id)
                        # Entry no longer used by any node: drop it entirely.
                        if not tuple_item[-1]:
                            loaded_objects[dict_name].remove(tuple_item)
                continue
            elif tuple_item[0] not in arg_val:
                # ckpt/vae entries match on name (entry[0]) alone.
                if node_id in tuple_item[-1]:
                    tuple_item[-1].remove(node_id)
                    # Entry no longer used by any node: drop it entirely.
                    if not tuple_item[-1]:
                        loaded_objects[dict_name].remove(tuple_item)
|
|
|
|
|
|
|
def get_cache_numbers(node_name):
    """Read the per-node model cache sizes from node_settings.json.

    The settings file lives next to this module. Returns a tuple
    (vae_cache, ckpt_cache, lora_cache); each size defaults to 1 when the
    node or the specific cache key is absent.
    """
    settings_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                 'node_settings.json')

    with open(settings_path, 'r') as settings_fh:
        node_settings = json.load(settings_fh)

    cache_cfg = node_settings.get(node_name, {}).get('model_cache', {})
    return (int(cache_cfg.get('vae', 1)),
            int(cache_cfg.get('ckpt', 1)),
            int(cache_cfg.get('lora', 1)))
|
|
|
def print_last_helds(id=None):
    """Print a summary of the global last_helds cache.

    When id is given, only entries whose stored id equals str(id) are shown.
    Boolean payloads are printed as-is; container payloads are summarized
    by their length.
    """
    print("\n" + "-" * 40)
    if id is None:
        print("Global Last Helds:")
    else:
        id = str(id)
        print(f"Node-specific Last Helds (node_id:{int(id)})")
    for category in ["results", "latent", "images", "vae_decode"]:
        if id is None:
            matches = last_helds[category]
        else:
            matches = [item for item in last_helds[category] if item[-1] == id]
        if not matches:
            continue
        print(f"{category.capitalize()}:")
        for idx, item in enumerate(matches, 1):
            payload = item[0]
            output = payload if isinstance(payload, bool) else len(payload)
            # In the global view, also show which node id held the entry.
            suffix = f" (id: {item[-1]})" if id is None else ""
            print(f" [{idx}] Output: {output}{suffix}")
    print("-" * 40)
    print("\n")