from torch import Tensor |
import torch |
from PIL import Image |
import numpy as np |
import os |
import sys |
import json |
import folder_paths |
# Make this module's own directory and the directory two levels above it
# importable (presumably the ComfyUI root), so the `comfy.sd` and `tsc_sd`
# imports that follow can be resolved.
my_dir = os.path.dirname(os.path.abspath(__file__))
comfy_dir = os.path.abspath(os.path.join(my_dir, os.pardir, os.pardir))
sys.path.append(my_dir)
sys.path.append(comfy_dir)
import comfy.sd |
from tsc_sd import * |
# Model cache shared by the loader helpers in this module. Every list holds
# tuples whose LAST element is the list of node ids using that entry:
#   "ckpt": (ckpt_name, model, clip, vae, ids)
#   "vae":  (vae_name, vae, ids)
#   "lora": (lora_params, ckpt_name, lora_model, lora_clip, ids)
loaded_objects = {kind: [] for kind in ("ckpt", "vae", "lora")}

# Per-category "last held" outputs, stored as (value, node_id) pairs.
last_helds: dict[str, list] = {kind: [] for kind in ("results", "latent", "images", "vae_decode")}
def tensor2pil(image: torch.Tensor) -> Image.Image:
    """Convert a tensor to a PIL image.

    The tensor is moved to the CPU, all size-1 dimensions are squeezed away,
    values are multiplied by 255 and clipped to [0, 255], then cast to uint8.
    Inputs are presumably in [0, 1]; anything outside is clipped.
    """
    pixels = image.cpu().numpy().squeeze()
    scaled = np.clip(pixels * 255.0, 0, 255)
    return Image.fromarray(scaled.astype(np.uint8))
def pil2tensor(image: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a float32 tensor with a leading dimension of size 1.

    Pixel values are divided by 255, so 8-bit images map to [0, 1].
    """
    raw = np.array(image)
    normalized = raw.astype(np.float32) / 255.0
    batched = torch.from_numpy(normalized).unsqueeze(0)
    return batched
def extract_node_info(prompt, id, indirect_key=None):
    """Look up a node's ``class_type`` in a prompt dict.

    Without ``indirect_key``: returns ``(class_type of prompt[str(id)], None)``;
    the class_type is None when the node or its class_type is missing.

    With ``indirect_key``: follows ``prompt[str(id)]['inputs'][indirect_key][0]``
    to another node id and returns ``(class_type of that node, that node id)``.
    Returns ``(None, None)`` if any step of that lookup is missing.
    """
    key = str(id)

    if not indirect_key:
        # Direct lookup never reports a node id.
        return prompt.get(key, {}).get('class_type', None), None

    has_link = key in prompt and 'inputs' in prompt[key] and indirect_key in prompt[key]['inputs']
    if not has_link:
        return None, None

    linked_id = prompt[key]['inputs'][indirect_key][0]
    if linked_id not in prompt:
        return None, None

    return prompt[linked_id].get('class_type', None), linked_id
def extract_node_value(prompt, id, key):
    """Return ``prompt[str(id)]['inputs'][key]``, or None when any level is missing."""
    node = prompt.get(str(id), {})
    inputs = node.get('inputs', {})
    return inputs.get(key, None)
def print_loaded_objects_entries(id=None, prompt=None, show_id=False):
    """Print the contents of ``loaded_objects`` to stdout.

    Parameters:
    - id (optional): node id. When given, only entries whose id list contains
      ``str(id)`` are printed, without their id lists. When None, every entry
      is printed together with its ids.
    - prompt (optional): prompt dict. When given together with ``id``, the node's
      class_type (via extract_node_info) is used in the header.
    - show_id: when the header uses the node's class_type, also append the node id.

    Prints "-" if no entry matched.
    """
    print("-" * 40)
    node_key = None if id is None else str(id)

    # Header line.
    if node_key is None:
        print(f"\033[36mGlobal Models Cache:\033[0m")
    elif prompt is None:
        print(f"\033[36mModels Cache: \nnode_id:{int(node_key)}\033[0m")
    else:
        node_name, _ = extract_node_info(prompt, node_key)
        if show_id:
            print(f"\033[36m{node_name} Models Cache: (node_id:{int(node_key)})\033[0m")
        else:
            print(f"\033[36m{node_name} Models Cache:\033[0m")

    def _stem(path):
        # File name without directory and extension.
        return os.path.splitext(os.path.basename(path))[0]

    found_any = False
    for kind in ("ckpt", "vae", "lora"):
        entries = loaded_objects[kind]
        if node_key is not None:
            entries = [entry for entry in entries if node_key in entry[-1]]
        if not entries:
            continue
        found_any = True
        print(f"{kind.capitalize()}:")
        for position, entry in enumerate(entries, 1):
            if kind == "lora":
                # entry[0] is a list of (lora_name, strength_model, strength_clip).
                lora_models_info = ', '.join(
                    f"{_stem(name)}({round(strength_model, 2)},{round(strength_clip, 2)})"
                    for name, strength_model, strength_clip in entry[0]
                )
                label = f"base_ckpt: {_stem(entry[1])}, lora(mod,clip): {lora_models_info}"
            else:
                label = _stem(entry[0])
            if node_key is None:
                associated_ids = ', '.join(map(str, entry[-1]))
                print(f" [{position}] {label} (ids: {associated_ids})")
            else:
                print(f" [{position}] {label}")

    if not found_any:
        print("-")
def globals_cleanup(prompt):
    """Drop references to node ids that no longer appear in ``prompt``.

    Node ids are compared as strings against the keys of ``prompt``.

    - ``last_helds``: (value, node_id) pairs whose node id is gone are removed.
    - ``loaded_objects``: each entry's id list (its last element) is filtered to
      the ids still present. An entry that loses some ids is replaced by a copy
      holding only the live ids. An entry that loses all of them is removed.
      Entries with nothing stale are kept untouched.

    Each ``loaded_objects`` list is rebuilt and written back in place (slice
    assignment), so existing references to those lists stay valid.
    """
    for key in list(last_helds):
        last_helds[key] = [
            (value, node_id) for value, node_id in last_helds[key] if str(node_id) in prompt
        ]

    for entries in loaded_objects.values():
        # Rebuild instead of deleting while indexing: mixing list.remove() with
        # assignment by a snapshot's index shifts positions and overwrites the
        # wrong entry (or raises IndexError).
        kept = []
        for entry in entries:
            ids = entry[-1]
            live_ids = [node_id for node_id in ids if str(node_id) in prompt]
            if len(live_ids) == len(ids):
                kept.append(entry)
            elif live_ids:
                kept.append(entry[:-1] + (live_ids,))
            # else: every id was stale -> the entry is dropped
        entries[:] = kept
def load_checkpoint(ckpt_name, id, output_vae=True, cache=None, cache_overwrite=False):
    """
    Return ``(model, clip, vae)`` for ``ckpt_name``, reusing loaded_objects["ckpt"] when possible.

    Cached checkpoints are (ckpt_name, model, clip, vae, ids) tuples, where ``ids`` lists
    the node ids that currently use the entry.

    Cache hit:
    - If ``cache`` is set and ``id`` already appears in at least ``cache`` checkpoint
      entries, clear_cache() is called for that id and the id is NOT added to the entry.
    - Otherwise ``id`` is added to the entry's ids if it is missing.

    Cache miss: the checkpoint is loaded with load_checkpoint_guess_config_tsc from the path
    returned by folder_paths.get_full_path("checkpoints", ...).
    - If ``id`` holds fewer than ``cache`` entries, a new entry is appended.
    - If the limit is reached, clear_cache() is called. Then, only when ``cache_overwrite``
      is True, ``id`` is detached from the first (oldest) entry holding it, which is dropped
      if no ids remain, and the new entry is appended.
    - With ``cache`` unset (None/0) the result is never stored.

    Parameters:
    - ckpt_name: name of the checkpoint to load.
    - id: node id used for per-node cache accounting.
    - output_vae: forwarded to the loader; if True the VAE is loaded too.
    - cache (optional): max number of checkpoint entries a single id may hold.
    - cache_overwrite: evict the id's oldest entry when its cache limit is reached.
    """
    def _ids_in_use():
        # Number of cached checkpoints whose id list contains this id.
        return len([e for e in loaded_objects["ckpt"] if id in e[-1]])

    cached = next((e for e in loaded_objects["ckpt"] if e[0] == ckpt_name), None)
    if cached is not None:
        _, model, clip, vae, ids = cached
        if cache and _ids_in_use() >= cache:
            clear_cache(id, cache, "ckpt")
        elif id not in ids:
            ids.append(id)
        return model, clip, vae

    ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
    out = load_checkpoint_guess_config_tsc(ckpt_path, output_vae, output_clip=True,
                                           embedding_directory=folder_paths.get_folder_paths("embeddings"))
    model, clip, vae = out[0], out[1], out[2]

    if not cache:
        return model, clip, vae

    if _ids_in_use() < cache:
        loaded_objects["ckpt"].append((ckpt_name, model, clip, vae, [id]))
        return model, clip, vae

    clear_cache(id, cache, "ckpt")
    if cache_overwrite:
        oldest = next((e for e in loaded_objects["ckpt"] if id in e[-1]), None)
        if oldest is not None:
            oldest[-1].remove(id)
            if not oldest[-1]:
                loaded_objects["ckpt"].remove(oldest)
        loaded_objects["ckpt"].append((ckpt_name, model, clip, vae, [id]))
    return model, clip, vae
def get_bvae_by_ckpt_name(ckpt_name):
    """Return the VAE stored with cached checkpoint ``ckpt_name``, or None if that checkpoint is not cached."""
    return next(
        (vae for name, _model, _clip, vae, *_ids in loaded_objects["ckpt"] if name == ckpt_name),
        None,
    )
def load_vae(vae_name, id, cache=None, cache_overwrite=False):
    """
    Return the VAE named ``vae_name``, reusing loaded_objects["vae"] when possible.

    Cached VAEs are (vae_name, vae, ids) tuples, where ``ids`` lists the node ids using them.

    Cache hit:
    - If ``id`` is not yet listed and ``cache`` is set and the id already holds at least
      ``cache`` VAE entries, the VAE is returned without registering the id.
    - Otherwise the id is added (if missing) and, when ``cache`` is set, clear_cache() trims
      the id's entries back to the limit.

    Cache miss: a new comfy.sd.VAE is created from folder_paths.get_full_path("vae", ...).
    - If ``id`` holds fewer than ``cache`` entries, a new entry is appended.
    - If the limit is reached, clear_cache() is called. Then, only when ``cache_overwrite``
      is True, the id is detached from its first (oldest) entry, which is dropped if no ids
      remain, and the new entry is appended.
    - With ``cache`` unset (None/0) the VAE is never stored.

    Parameters:
    - vae_name: name of the VAE to load.
    - id: node id used for per-node cache accounting.
    - cache (optional): max number of VAE entries a single id may hold.
    - cache_overwrite: evict the id's oldest entry when its cache limit is reached.
    """
    def _held_by_id():
        # Number of cached VAEs whose id list contains this id.
        return len([e for e in loaded_objects["vae"] if id in e[-1]])

    hit = next((e for e in loaded_objects["vae"] if e[0] == vae_name), None)
    if hit is not None:
        vae, ids = hit[1], hit[2]
        newcomer = id not in ids
        if newcomer and cache and _held_by_id() >= cache:
            # This id's cache is full: hand out the VAE without registering the id.
            return vae
        if newcomer:
            ids.append(id)
        if cache:
            clear_cache(id, cache, "vae")
        return vae

    vae_path = folder_paths.get_full_path("vae", vae_name)
    vae = comfy.sd.VAE(ckpt_path=vae_path)

    if not cache:
        return vae

    if _held_by_id() < cache:
        loaded_objects["vae"].append((vae_name, vae, [id]))
        return vae

    clear_cache(id, cache, "vae")
    if cache_overwrite:
        oldest = next((e for e in loaded_objects["vae"] if id in e[-1]), None)
        if oldest is not None:
            oldest[-1].remove(id)
            if not oldest[-1]:
                loaded_objects["vae"].remove(oldest)
        loaded_objects["vae"].append((vae_name, vae, [id]))
    return vae
def load_lora(lora_params, ckpt_name, id, cache=None, ckpt_cache=None, cache_overwrite=False):
    """
    Return ``(lora_model, lora_clip)``: checkpoint ``ckpt_name`` with every LoRA in ``lora_params`` applied.

    Cached results live in loaded_objects["lora"] as
    (lora_params, ckpt_name, lora_model, lora_clip, ids) tuples.

    Matching a cached entry:
    - The checkpoint name must be equal.
    - The LoRA stacks are compared as multisets. Order does not matter, but a stack that
      lists the same (name, strengths) twice does not match a stack that lists it once.

    Cache hit:
    - If ``cache`` is set and ``id`` already holds at least ``cache`` LoRA entries,
      clear_cache() is called; otherwise ``id`` is added to the entry's ids.
    - The base checkpoint's entry gets the same treatment against ``ckpt_cache``.

    Cache miss:
    - The base checkpoint comes from load_checkpoint(). Each LoRA is applied in order with
      load_lora_for_models_tsc, so an empty ``lora_params`` returns the base model/clip.
    - The result is then stored following the same rules as load_checkpoint: append while
      under ``cache``; otherwise clear_cache() and, only with ``cache_overwrite``, evict the
      id's oldest entry and append.

    Parameters:
    - lora_params: list of (lora_name, strength_model, strength_clip) tuples (must be hashable).
    - ckpt_name: name of the base checkpoint.
    - id: node id used for per-node cache accounting.
    - cache (optional): max number of LoRA entries a single id may hold; None/0 disables storing.
    - ckpt_cache (optional): cache limit used for the base checkpoint.
    - cache_overwrite: evict the id's oldest entry when a cache limit is reached.
    """
    from collections import Counter

    wanted = Counter(lora_params)

    for entry in loaded_objects["lora"]:
        if entry[1] != ckpt_name or Counter(entry[0]) != wanted:
            continue
        _, _, lora_model, lora_clip, ids = entry

        lora_cache_full = cache and len([e for e in loaded_objects["lora"] if id in e[-1]]) >= cache
        if lora_cache_full:
            clear_cache(id, cache, "lora")
        elif id not in ids:
            ids.append(id)

        # Keep the base checkpoint's bookkeeping in step, as load_checkpoint() would.
        # Iterate a snapshot: clear_cache() may remove entries from the live list.
        for ckpt_entry in list(loaded_objects["ckpt"]):
            if ckpt_entry[0] != ckpt_name:
                continue
            ckpt_ids = ckpt_entry[-1]
            ckpt_cache_full = ckpt_cache and len(
                [e for e in loaded_objects["ckpt"] if id in e[-1]]) >= ckpt_cache
            if ckpt_cache_full:
                clear_cache(id, ckpt_cache, "ckpt")
            elif id not in ckpt_ids:
                ckpt_ids.append(id)

        return lora_model, lora_clip

    # Cache miss: start from the base checkpoint and apply each LoRA in order.
    lora_model, lora_clip, _ = load_checkpoint(ckpt_name, id, cache=ckpt_cache, cache_overwrite=cache_overwrite)
    for lora_name, strength_model, strength_clip in lora_params:
        lora_path = folder_paths.get_full_path("loras", lora_name)
        lora_model, lora_clip = load_lora_for_models_tsc(lora_model, lora_clip, lora_path,
                                                         strength_model, strength_clip)

    if cache:
        if len([e for e in loaded_objects["lora"] if id in e[-1]]) < cache:
            loaded_objects["lora"].append((lora_params, ckpt_name, lora_model, lora_clip, [id]))
        else:
            clear_cache(id, cache, "lora")
            if cache_overwrite:
                # Detach the id from its oldest entry (dropping the entry if unused) to make room.
                for e in loaded_objects["lora"]:
                    if id in e[-1]:
                        e[-1].remove(id)
                        if not e[-1]:
                            loaded_objects["lora"].remove(e)
                        break
                loaded_objects["lora"].append((lora_params, ckpt_name, lora_model, lora_clip, [id]))
    return lora_model, lora_clip
def clear_cache(id, cache, dict_name):
    """
    Trim the entries of ``loaded_objects[dict_name]`` ("ckpt", "vae" or "lora") held by ``id``.

    While more than ``cache`` entries list ``id``, the id is detached from the first
    (oldest) such entry. An entry whose id list becomes empty is removed from the cache.
    Nothing happens if the id holds ``cache`` entries or fewer.
    """
    bucket = loaded_objects[dict_name]
    while True:
        owned = [entry for entry in bucket if id in entry[-1]]
        if len(owned) <= cache:
            return
        oldest = owned[0]
        oldest[-1].remove(id)
        if not oldest[-1]:
            bucket.remove(oldest)
def clear_cache_by_exception(node_id, vae_dict=None, ckpt_dict=None, lora_dict=None):
    """
    Detach ``node_id`` from every cached model it holds, except those listed as exceptions.

    Parameters:
    - node_id: id to remove from the id lists of non-excepted entries.
    - vae_dict / ckpt_dict (optional): names to keep. An entry whose first element is
      not ``in`` the collection loses ``node_id``.
    - lora_dict (optional): iterable of (lora_params, ckpt_name) pairs to keep. A LoRA
      entry is kept when its params equal one pair's params as sets and the ckpt name
      is equal.

    A category whose argument is None is left untouched. Entries left with no ids are
    removed. Categories are processed in the order vae, ckpt, lora.
    """
    def _detach(bucket, item):
        # Remove node_id from the entry; drop the entry once nobody uses it.
        if node_id in item[-1]:
            item[-1].remove(node_id)
            if not item[-1]:
                bucket.remove(item)

    def _lora_excepted(item, exceptions):
        return any(
            set(lora_params) == set(item[0]) and ckpt_name == item[1]
            for lora_params, ckpt_name in exceptions
        )

    for dict_name, exceptions in (("vae", vae_dict), ("ckpt", ckpt_dict), ("lora", lora_dict)):
        if exceptions is None:
            continue
        bucket = loaded_objects[dict_name]
        # Iterate a snapshot because _detach may remove entries from the live list.
        for item in list(bucket):
            if dict_name == "lora":
                keep = _lora_excepted(item, exceptions)
            else:
                keep = item[0] in exceptions
            if not keep:
                _detach(bucket, item)
def get_cache_numbers(node_name):
    """
    Read the per-node model cache sizes from ``node_settings.json`` next to this module.

    The file is read as a mapping of node names to
    ``{"model_cache": {"vae": n, "ckpt": n, "lora": n}}``. Any missing level falls back
    to a cache size of 1, and values are converted with ``int()``.

    Returns:
    (vae_cache, ckpt_cache, lora_cache)

    Raises:
    OSError (e.g. FileNotFoundError) if the file cannot be opened; json.JSONDecodeError
    if it is not valid JSON.
    """
    settings_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'node_settings.json')
    # JSON text is UTF-8. Without an explicit encoding, open() uses the locale's
    # preferred encoding (e.g. cp1252 on Windows). "utf-8-sig" also tolerates a
    # leading BOM left by some Windows editors.
    with open(settings_file, 'r', encoding='utf-8-sig') as file:
        node_settings = json.load(file)
    model_cache_settings = node_settings.get(node_name, {}).get('model_cache', {})
    return tuple(int(model_cache_settings.get(kind, 1)) for kind in ('vae', 'ckpt', 'lora'))
def print_last_helds(id=None):
    """
    Print the contents of ``last_helds`` to stdout.

    When ``id`` is given, only entries whose node id equals ``str(id)`` are listed,
    without their id. For each entry the printed "Output" is the held value itself
    if it is a bool, otherwise its ``len()``.
    """
    print("\n" + "-" * 40)
    node_key = None if id is None else str(id)
    if node_key is None:
        print(f"Global Last Helds:")
    else:
        print(f"Node-specific Last Helds (node_id:{int(node_key)})")

    for key in ["results", "latent", "images", "vae_decode"]:
        held = last_helds[key]
        if node_key is not None:
            held = [entry for entry in held if node_key == entry[-1]]
        if not held:
            continue
        print(f"{key.capitalize()}:")
        for position, entry in enumerate(held, start=1):
            value = entry[0]
            output = value if isinstance(value, bool) else len(value)
            if node_key is None:
                print(f" [{position}] Output: {output} (id: {entry[-1]})")
            else:
                print(f" [{position}] Output: {output}")

    print("-" * 40)
    print("\n")