# (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
import torch
import gguf

from .ops import GGMLTensor
from .dequant import is_quantized

IMG_ARCH_LIST = {"flux", "sd1", "sdxl", "sd3", "aura", "ltxv", "hyvid"}
TXT_ARCH_LIST = {"t5", "t5encoder", "llama"}

def get_orig_shape(reader, tensor_name):
    """Return the pre-quantization shape stored as custom GGUF metadata, or None if absent."""
    field_key = f"comfy.gguf.orig_shape.{tensor_name}"
    field = reader.get_field(field_key)
    if field is None:
        return None
    # Has original shape metadata, so we try to decode it.
    if len(field.types) != 2 or field.types[0] != gguf.GGUFValueType.ARRAY or field.types[1] != gguf.GGUFValueType.INT32:
        raise TypeError(f"Bad original shape metadata for {field_key}: Expected ARRAY of INT32, got {field.types}")
    return torch.Size(tuple(int(field.parts[part_idx][0]) for part_idx in field.data))
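
# Illustrative writer-side sketch only (hypothetical code, not part of this module):
# the metadata read above would be produced by something along the lines of
#
#   writer.add_array(f"comfy.gguf.orig_shape.{tensor_name}", [int(d) for d in original_shape])
#
# on a gguf.GGUFWriter when a tensor had to be reshaped before quantization,
# since the GGUF tensor header itself does not carry the original layout.
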
def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", return_arch=False):
    """
    Read state dict as fake tensors
    """
    reader = gguf.GGUFReader(path)

    # filter and strip prefix
    has_prefix = False
    if handle_prefix is not None:
        prefix_len = len(handle_prefix)
        tensor_names = set(tensor.name for tensor in reader.tensors)
        has_prefix = any(s.startswith(handle_prefix) for s in tensor_names)

    tensors = []
    for tensor in reader.tensors:
        sd_key = tensor_name = tensor.name
        if has_prefix:
            if not tensor_name.startswith(handle_prefix):
                continue
            sd_key = tensor_name[prefix_len:]
        tensors.append((sd_key, tensor))

    # detect and verify architecture
    compat = None
    arch_str = None
    arch_field = reader.get_field("general.architecture")
    if arch_field is not None:
        if len(arch_field.types) != 1 or arch_field.types[0] != gguf.GGUFValueType.STRING:
            raise TypeError(f"Bad type for GGUF general.architecture key: expected string, got {arch_field.types!r}")
        arch_str = str(arch_field.parts[arch_field.data[-1]], encoding="utf-8")
        if arch_str not in IMG_ARCH_LIST and arch_str not in TXT_ARCH_LIST:
            raise ValueError(f"Unexpected architecture type in GGUF file: expected one of {sorted(IMG_ARCH_LIST | TXT_ARCH_LIST)} but got {arch_str!r}")
    else: # stable-diffusion.cpp
        # import here to avoid changes to convert.py breaking regular models
        from .tools.convert import detect_arch
        arch_str = detect_arch(set(val[0] for val in tensors)).arch
        compat = "sd.cpp"

    # main loading loop
    state_dict = {}
    qtype_dict = {}
    for sd_key, tensor in tensors:
        tensor_name = tensor.name
        tensor_type_str = str(tensor.tensor_type)
        torch_tensor = torch.from_numpy(tensor.data) # mmap

        shape = get_orig_shape(reader, tensor_name)
        if shape is None:
            shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
            # Workaround for stable-diffusion.cpp SDXL detection.
            if compat == "sd.cpp" and arch_str == "sdxl":
                if any(tensor_name.endswith(x) for x in (".proj_in.weight", ".proj_out.weight")):
                    while len(shape) > 2 and shape[-1] == 1:
                        shape = shape[:-1]

        # add to state dict
        if tensor.tensor_type in {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}:
            torch_tensor = torch_tensor.view(*shape)
        state_dict[sd_key] = GGMLTensor(torch_tensor, tensor_type=tensor.tensor_type, tensor_shape=shape)
        qtype_dict[tensor_type_str] = qtype_dict.get(tensor_type_str, 0) + 1

    # mark largest tensor for vram estimation
    qsd = {k: v for k, v in state_dict.items() if is_quantized(v)}
    if len(qsd) > 0:
        max_key = max(qsd.keys(), key=lambda k: qsd[k].numel())
        state_dict[max_key].is_largest_weight = True

    # sanity check debug print
    print("\nggml_sd_loader:")
    for k, v in qtype_dict.items():
        print(f" {k:30}{v:3}")

    if return_arch:
        return (state_dict, arch_str)
    return state_dict
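
# Minimal usage sketch (hypothetical file name, not part of the module):
#
#   sd = gguf_sd_loader("flux1-dev-Q8_0.gguf")
#   sd, arch = gguf_sd_loader("flux1-dev-Q8_0.gguf", return_arch=True)
#
# Values are GGMLTensor instances: quantized weights keep their raw GGUF payload
# and only record the logical shape in tensor_shape, while F16/F32 tensors are
# viewed back to their original shape directly.
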
# for remapping llama.cpp -> original key names
T5_SD_MAP = {
    "enc.": "encoder.",
    ".blk.": ".block.",
    "token_embd": "shared",
    "output_norm": "final_layer_norm",
    "attn_q": "layer.0.SelfAttention.q",
    "attn_k": "layer.0.SelfAttention.k",
    "attn_v": "layer.0.SelfAttention.v",
    "attn_o": "layer.0.SelfAttention.o",
    "attn_norm": "layer.0.layer_norm",
    "attn_rel_b": "layer.0.SelfAttention.relative_attention_bias",
    "ffn_up": "layer.1.DenseReluDense.wi_1",
    "ffn_down": "layer.1.DenseReluDense.wo",
    "ffn_gate": "layer.1.DenseReluDense.wi_0",
    "ffn_norm": "layer.1.layer_norm",
}
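
# Example of the T5 mapping applied via sd_map_replace() below (illustrative key):
#   "enc.blk.0.attn_q.weight" -> "encoder.block.0.layer.0.SelfAttention.q.weight"
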
LLAMA_SD_MAP = {
    "blk.": "model.layers.",
    "attn_norm": "input_layernorm",
    "attn_q": "self_attn.q_proj",
    "attn_k": "self_attn.k_proj",
    "attn_v": "self_attn.v_proj",
    "attn_output": "self_attn.o_proj",
    "ffn_up": "mlp.up_proj",
    "ffn_down": "mlp.down_proj",
    "ffn_gate": "mlp.gate_proj",
    "ffn_norm": "post_attention_layernorm",
    "token_embd": "model.embed_tokens",
    "output_norm": "model.norm",
    "output.weight": "lm_head.weight",
}
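
# Example of the llama mapping (illustrative key):
#   "blk.0.attn_q.weight" -> "model.layers.0.self_attn.q_proj.weight"
# which matches the transformers-style key names.
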
def sd_map_replace(raw_sd, key_map):
    """Rename every key in raw_sd by applying the substring replacements in key_map."""
    sd = {}
    for k, v in raw_sd.items():
        for s, d in key_map.items():
            k = k.replace(s, d)
        sd[k] = v
    return sd

def llama_permute(raw_sd, n_head, n_head_kv):
    """Undo the Q/K weight permutation applied by the llama.cpp conversion script."""
    # Reverse version of LlamaModel.permute in llama.cpp convert script
    sd = {}
    permute = lambda x, h: x.reshape(h, x.shape[0] // h // 2, 2, *x.shape[1:]).swapaxes(1, 2).reshape(x.shape)
    for k, v in raw_sd.items():
        if k.endswith(("q_proj.weight", "q_proj.bias")):
            v.data = permute(v.data, n_head)
        if k.endswith(("k_proj.weight", "k_proj.bias")):
            v.data = permute(v.data, n_head_kv)
        sd[k] = v
    return sd
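
# Shape-level sketch of the permute above (hypothetical sizes): a llama-3-8B
# q_proj.weight of shape (4096, 4096) with n_head=32 is viewed as
# (32, 64, 2, 4096), the (64, 2) axes are swapped, and the result is flattened
# back to (4096, 4096), inverting the per-head row reordering that the
# llama.cpp convert script performs for its rotary embedding layout.
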
def gguf_clip_loader(path):
    """Load a GGUF text encoder and remap its keys to the original transformers naming."""
    sd, arch = gguf_sd_loader(path, return_arch=True)
    if arch in {"t5", "t5encoder"}:
        sd = sd_map_replace(sd, T5_SD_MAP)
    elif arch in {"llama"}:
        temb_key = "token_embd.weight"
        if temb_key in sd and sd[temb_key].shape != (128320, 4096):
            # This still works. Raise error?
            print("Warning! token_embd shape may be incorrect for llama 3 model!")
        sd = sd_map_replace(sd, LLAMA_SD_MAP)
        sd = llama_permute(sd, 32, 8) # L3
    else:
        pass # any other architecture is returned unchanged
    return sd
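
# Usage sketch (hypothetical file names, not part of the module):
#
#   t5_sd = gguf_clip_loader("t5-v1_1-xxl-encoder-Q5_K_M.gguf")
#   llama_sd = gguf_clip_loader("llama-3.1-8b-instruct-Q4_K_S.gguf")
#
# The returned dict uses the original transformers-style key names, so downstream
# code can treat it like a normally loaded state dict; quantized values remain
# GGMLTensor instances (see the .ops / .dequant imports above).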