|
|
|
import torch |
|
import gguf |
|
import copy |
|
import logging |
|
|
|
import comfy.sd |
|
import comfy.utils |
|
import comfy.model_management |
|
import comfy.model_patcher |
|
import folder_paths |
|
|
|
from .ops import GGMLTensor, GGMLOps, move_patch_to_device |
|
from .dequant import is_quantized, is_torch_compatible |
|
|
|
|
|
if "unet_gguf" not in folder_paths.folder_names_and_paths: |
|
orig = folder_paths.folder_names_and_paths.get("diffusion_models", folder_paths.folder_names_and_paths.get("unet", [[], set()])) |
|
folder_paths.folder_names_and_paths["unet_gguf"] = (orig[0], {".gguf"}) |
|
|
|
if "clip_gguf" not in folder_paths.folder_names_and_paths: |
|
orig = folder_paths.folder_names_and_paths.get("clip", [[], set()]) |
|
folder_paths.folder_names_and_paths["clip_gguf"] = (orig[0], {".gguf"}) |
|
|
|
def gguf_sd_loader_get_orig_shape(reader, tensor_name): |
|
field_key = f"comfy.gguf.orig_shape.{tensor_name}" |
|
field = reader.get_field(field_key) |
|
if field is None: |
|
return None |
|
|
|
if len(field.types) != 2 or field.types[0] != gguf.GGUFValueType.ARRAY or field.types[1] != gguf.GGUFValueType.INT32: |
|
raise TypeError(f"Bad original shape metadata for {field_key}: Expected ARRAY of INT32, got {field.types}") |
|
return torch.Size(tuple(int(field.parts[part_idx][0]) for part_idx in field.data)) |
|
|
|
def gguf_sd_loader(path, handle_prefix="model.diffusion_model."): |
|
""" |
|
Read state dict as fake tensors |
|
""" |
|
reader = gguf.GGUFReader(path) |
|
|
|
|
|
has_prefix = False |
|
if handle_prefix is not None: |
|
prefix_len = len(handle_prefix) |
|
tensor_names = set(tensor.name for tensor in reader.tensors) |
|
has_prefix = any(s.startswith(handle_prefix) for s in tensor_names) |
|
|
|
tensors = [] |
|
for tensor in reader.tensors: |
|
sd_key = tensor_name = tensor.name |
|
if has_prefix: |
|
if not tensor_name.startswith(handle_prefix): |
|
continue |
|
sd_key = tensor_name[prefix_len:] |
|
tensors.append((sd_key, tensor)) |
|
|
|
|
|
compat = None |
|
arch_str = None |
|
arch_field = reader.get_field("general.architecture") |
|
if arch_field is not None: |
|
if len(arch_field.types) != 1 or arch_field.types[0] != gguf.GGUFValueType.STRING: |
|
raise TypeError(f"Bad type for GGUF general.architecture key: expected string, got {arch_field.types!r}") |
|
arch_str = str(arch_field.parts[arch_field.data[-1]], encoding="utf-8") |
|
if arch_str not in {"flux", "sd1", "sdxl", "sd3", "t5", "t5encoder"}: |
|
raise ValueError(f"Unexpected architecture type in GGUF file, expected one of flux, sd1, sdxl, t5encoder but got {arch_str!r}") |
|
else: |
|
|
|
from .tools.convert import detect_arch |
|
arch_str = detect_arch(set(val[0] for val in tensors)).arch |
|
compat = "sd.cpp" |
|
|
|
|
|
state_dict = {} |
|
qtype_dict = {} |
|
for sd_key, tensor in tensors: |
|
tensor_name = tensor.name |
|
tensor_type_str = str(tensor.tensor_type) |
|
torch_tensor = torch.from_numpy(tensor.data) |
|
|
|
shape = gguf_sd_loader_get_orig_shape(reader, tensor_name) |
|
if shape is None: |
|
shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape))) |
|
|
|
if compat == "sd.cpp" and arch_str == "sdxl": |
|
if any([tensor_name.endswith(x) for x in (".proj_in.weight", ".proj_out.weight")]): |
|
while len(shape) > 2 and shape[-1] == 1: |
|
shape = shape[:-1] |
|
|
|
|
|
if tensor.tensor_type in {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}: |
|
torch_tensor = torch_tensor.view(*shape) |
|
state_dict[sd_key] = GGMLTensor(torch_tensor, tensor_type=tensor.tensor_type, tensor_shape=shape) |
|
qtype_dict[tensor_type_str] = qtype_dict.get(tensor_type_str, 0) + 1 |
|
|
|
|
|
print("\nggml_sd_loader:") |
|
for k,v in qtype_dict.items(): |
|
print(f" {k:30}{v:3}") |
|
|
|
return state_dict |
|
|
|
|
|
clip_sd_map = { |
|
"enc.": "encoder.", |
|
".blk.": ".block.", |
|
"token_embd": "shared", |
|
"output_norm": "final_layer_norm", |
|
"attn_q": "layer.0.SelfAttention.q", |
|
"attn_k": "layer.0.SelfAttention.k", |
|
"attn_v": "layer.0.SelfAttention.v", |
|
"attn_o": "layer.0.SelfAttention.o", |
|
"attn_norm": "layer.0.layer_norm", |
|
"attn_rel_b": "layer.0.SelfAttention.relative_attention_bias", |
|
"ffn_up": "layer.1.DenseReluDense.wi_1", |
|
"ffn_down": "layer.1.DenseReluDense.wo", |
|
"ffn_gate": "layer.1.DenseReluDense.wi_0", |
|
"ffn_norm": "layer.1.layer_norm", |
|
} |
|
|
|
def gguf_clip_loader(path): |
|
raw_sd = gguf_sd_loader(path) |
|
assert "enc.blk.23.ffn_up.weight" in raw_sd, "Invalid Text Encoder!" |
|
sd = {} |
|
for k,v in raw_sd.items(): |
|
for s,d in clip_sd_map.items(): |
|
k = k.replace(s,d) |
|
sd[k] = v |
|
return sd |
|
|
|
|
|
import collections |
|
class GGUFModelPatcher(comfy.model_patcher.ModelPatcher): |
|
patch_on_device = False |
|
|
|
def patch_weight_to_device(self, key, device_to=None, inplace_update=False): |
|
if key not in self.patches: |
|
return |
|
weight = comfy.utils.get_attr(self.model, key) |
|
|
|
try: |
|
from comfy.lora import calculate_weight |
|
except Exception: |
|
calculate_weight = self.calculate_weight |
|
|
|
patches = self.patches[key] |
|
if is_quantized(weight): |
|
out_weight = weight.to(device_to) |
|
patches = move_patch_to_device(patches, self.load_device if self.patch_on_device else self.offload_device) |
|
|
|
out_weight.patches = [(calculate_weight, patches, key)] |
|
else: |
|
inplace_update = self.weight_inplace_update or inplace_update |
|
if key not in self.backup: |
|
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])( |
|
weight.to(device=self.offload_device, copy=inplace_update), inplace_update |
|
) |
|
|
|
if device_to is not None: |
|
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True) |
|
else: |
|
temp_weight = weight.to(torch.float32, copy=True) |
|
|
|
out_weight = calculate_weight(patches, temp_weight, key) |
|
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype) |
|
|
|
if inplace_update: |
|
comfy.utils.copy_to_param(self.model, key, out_weight) |
|
else: |
|
comfy.utils.set_attr_param(self.model, key, out_weight) |
|
|
|
def unpatch_model(self, device_to=None, unpatch_weights=True): |
|
if unpatch_weights: |
|
for p in self.model.parameters(): |
|
if is_torch_compatible(p): |
|
continue |
|
patches = getattr(p, "patches", []) |
|
if len(patches) > 0: |
|
p.patches = [] |
|
|
|
return super().unpatch_model(device_to=device_to, unpatch_weights=unpatch_weights) |
|
|
|
mmap_released = False |
|
def load(self, *args, force_patch_weights=False, **kwargs): |
|
|
|
super().load(*args, force_patch_weights=True, **kwargs) |
|
|
|
|
|
if not self.mmap_released: |
|
linked = [] |
|
if kwargs.get("lowvram_model_memory", 0) > 0: |
|
for n, m in self.model.named_modules(): |
|
if hasattr(m, "weight"): |
|
device = getattr(m.weight, "device", None) |
|
if device == self.offload_device: |
|
linked.append((n, m)) |
|
continue |
|
if hasattr(m, "bias"): |
|
device = getattr(m.bias, "device", None) |
|
if device == self.offload_device: |
|
linked.append((n, m)) |
|
continue |
|
if linked: |
|
print(f"Attempting to release mmap ({len(linked)})") |
|
for n, m in linked: |
|
|
|
m.to(self.load_device).to(self.offload_device) |
|
self.mmap_released = True |
|
|
|
def clone(self, *args, **kwargs): |
|
n = GGUFModelPatcher(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update) |
|
n.patches = {} |
|
for k in self.patches: |
|
n.patches[k] = self.patches[k][:] |
|
n.patches_uuid = self.patches_uuid |
|
|
|
n.object_patches = self.object_patches.copy() |
|
n.model_options = copy.deepcopy(self.model_options) |
|
n.backup = self.backup |
|
n.object_patches_backup = self.object_patches_backup |
|
n.patch_on_device = getattr(self, "patch_on_device", False) |
|
return n |
|
|
|
class UnetLoaderGGUF: |
|
@classmethod |
|
def INPUT_TYPES(s): |
|
unet_names = [x for x in folder_paths.get_filename_list("unet_gguf")] |
|
return { |
|
"required": { |
|
"unet_name": (unet_names,), |
|
} |
|
} |
|
|
|
RETURN_TYPES = ("MODEL",) |
|
FUNCTION = "load_unet" |
|
CATEGORY = "bootleg" |
|
TITLE = "Unet Loader (GGUF)" |
|
|
|
def load_unet(self, unet_name, dequant_dtype=None, patch_dtype=None, patch_on_device=None): |
|
ops = GGMLOps() |
|
|
|
if dequant_dtype in ("default", None): |
|
ops.Linear.dequant_dtype = None |
|
elif dequant_dtype in ["target"]: |
|
ops.Linear.dequant_dtype = dequant_dtype |
|
else: |
|
ops.Linear.dequant_dtype = getattr(torch, dequant_dtype) |
|
|
|
if patch_dtype in ("default", None): |
|
ops.Linear.patch_dtype = None |
|
elif patch_dtype in ["target"]: |
|
ops.Linear.patch_dtype = patch_dtype |
|
else: |
|
ops.Linear.patch_dtype = getattr(torch, patch_dtype) |
|
|
|
|
|
unet_path = folder_paths.get_full_path("unet", unet_name) |
|
sd = gguf_sd_loader(unet_path) |
|
model = comfy.sd.load_diffusion_model_state_dict( |
|
sd, model_options={"custom_operations": ops} |
|
) |
|
if model is None: |
|
logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path)) |
|
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path)) |
|
model = GGUFModelPatcher.clone(model) |
|
model.patch_on_device = patch_on_device |
|
return (model,) |
|
|
|
class UnetLoaderGGUFAdvanced(UnetLoaderGGUF): |
|
@classmethod |
|
def INPUT_TYPES(s): |
|
unet_names = [x for x in folder_paths.get_filename_list("unet_gguf")] |
|
return { |
|
"required": { |
|
"unet_name": (unet_names,), |
|
"dequant_dtype": (["default", "target", "float32", "float16", "bfloat16"], {"default": "default"}), |
|
"patch_dtype": (["default", "target", "float32", "float16", "bfloat16"], {"default": "default"}), |
|
"patch_on_device": ("BOOLEAN", {"default": False}), |
|
} |
|
} |
|
TITLE = "Unet Loader (GGUF/Advanced)" |
|
|
|
clip_name_dict = { |
|
"stable_diffusion": comfy.sd.CLIPType.STABLE_DIFFUSION, |
|
"stable_cascade": comfy.sd.CLIPType.STABLE_CASCADE, |
|
"stable_audio": comfy.sd.CLIPType.STABLE_AUDIO, |
|
"sdxl": comfy.sd.CLIPType.STABLE_DIFFUSION, |
|
"sd3": comfy.sd.CLIPType.SD3, |
|
"flux": comfy.sd.CLIPType.FLUX, |
|
} |
|
|
|
class CLIPLoaderGGUF: |
|
@classmethod |
|
def INPUT_TYPES(s): |
|
return { |
|
"required": { |
|
"clip_name": (s.get_filename_list(),), |
|
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio"],), |
|
} |
|
} |
|
|
|
RETURN_TYPES = ("CLIP",) |
|
FUNCTION = "load_clip" |
|
CATEGORY = "bootleg" |
|
TITLE = "CLIPLoader (GGUF)" |
|
|
|
@classmethod |
|
def get_filename_list(s): |
|
files = [] |
|
files += folder_paths.get_filename_list("clip") |
|
files += folder_paths.get_filename_list("clip_gguf") |
|
return sorted(files) |
|
|
|
def load_data(self, ckpt_paths): |
|
clip_data = [] |
|
for p in ckpt_paths: |
|
if p.endswith(".gguf"): |
|
clip_data.append(gguf_clip_loader(p)) |
|
else: |
|
sd = comfy.utils.load_torch_file(p, safe_load=True) |
|
clip_data.append( |
|
{k:GGMLTensor(v, tensor_type=gguf.GGMLQuantizationType.F16, tensor_shape=v.shape) for k,v in sd.items()} |
|
) |
|
return clip_data |
|
|
|
def load_patcher(self, clip_paths, clip_type, clip_data): |
|
clip = comfy.sd.load_text_encoder_state_dicts( |
|
clip_type = clip_type, |
|
state_dicts = clip_data, |
|
model_options = { |
|
"custom_operations": GGMLOps, |
|
"initial_device": comfy.model_management.text_encoder_offload_device() |
|
}, |
|
embedding_directory = folder_paths.get_folder_paths("embeddings"), |
|
) |
|
clip.patcher = GGUFModelPatcher.clone(clip.patcher) |
|
|
|
|
|
if getattr(clip.cond_stage_model, "clip_l", None) is not None: |
|
if getattr(clip.cond_stage_model.clip_l.transformer.text_projection.weight, "tensor_shape", None) is None: |
|
clip.cond_stage_model.clip_l.transformer.text_projection = comfy.ops.manual_cast.Linear(768, 768) |
|
if getattr(clip.cond_stage_model, "clip_g", None) is not None: |
|
if getattr(clip.cond_stage_model.clip_g.transformer.text_projection.weight, "tensor_shape", None) is None: |
|
clip.cond_stage_model.clip_g.transformer.text_projection = comfy.ops.manual_cast.Linear(1280, 1280) |
|
|
|
return clip |
|
|
|
def load_clip(self, clip_name, type="stable_diffusion"): |
|
clip_path = folder_paths.get_full_path("clip", clip_name) |
|
clip_type = clip_name_dict.get(type, comfy.sd.CLIPType.STABLE_DIFFUSION) |
|
return (self.load_patcher([clip_path], clip_type, self.load_data([clip_path])),) |
|
|
|
class DualCLIPLoaderGGUF(CLIPLoaderGGUF): |
|
@classmethod |
|
def INPUT_TYPES(s): |
|
file_options = (s.get_filename_list(), ) |
|
return { |
|
"required": { |
|
"clip_name1": file_options, |
|
"clip_name2": file_options, |
|
"type": (("sdxl", "sd3", "flux"), ), |
|
} |
|
} |
|
|
|
TITLE = "DualCLIPLoader (GGUF)" |
|
|
|
def load_clip(self, clip_name1, clip_name2, type): |
|
clip_path1 = folder_paths.get_full_path("clip", clip_name1) |
|
clip_path2 = folder_paths.get_full_path("clip", clip_name2) |
|
clip_paths = (clip_path1, clip_path2) |
|
clip_type = clip_name_dict.get(type, comfy.sd.CLIPType.STABLE_DIFFUSION) |
|
return (self.load_patcher(clip_paths, clip_type, self.load_data(clip_paths)),) |
|
|
|
class TripleCLIPLoaderGGUF(CLIPLoaderGGUF): |
|
@classmethod |
|
def INPUT_TYPES(s): |
|
file_options = (s.get_filename_list(), ) |
|
return { |
|
"required": { |
|
"clip_name1": file_options, |
|
"clip_name2": file_options, |
|
"clip_name3": file_options, |
|
} |
|
} |
|
|
|
TITLE = "TripleCLIPLoader (GGUF)" |
|
|
|
def load_clip(self, clip_name1, clip_name2, clip_name3, type="sd3"): |
|
clip_path1 = folder_paths.get_full_path("clip", clip_name1) |
|
clip_path2 = folder_paths.get_full_path("clip", clip_name2) |
|
clip_path3 = folder_paths.get_full_path("clip", clip_name3) |
|
clip_paths = (clip_path1, clip_path2, clip_path3) |
|
clip_type = clip_name_dict.get(type, comfy.sd.CLIPType.STABLE_DIFFUSION) |
|
return (self.load_patcher(clip_paths, clip_type, self.load_data(clip_paths)),) |
|
|
|
NODE_CLASS_MAPPINGS = { |
|
"UnetLoaderGGUF": UnetLoaderGGUF, |
|
"CLIPLoaderGGUF": CLIPLoaderGGUF, |
|
"DualCLIPLoaderGGUF": DualCLIPLoaderGGUF, |
|
"TripleCLIPLoaderGGUF": TripleCLIPLoaderGGUF, |
|
"UnetLoaderGGUFAdvanced": UnetLoaderGGUFAdvanced, |
|
} |
|
|