# (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
import gguf
import torch

import comfy.ops
import comfy.model_management

from .dequant import dequantize_tensor, is_quantized


class GGMLTensor(torch.Tensor):
    """
    Main tensor-like class for storing quantized weights
    """
    def __init__(self, *args, tensor_type, tensor_shape, patches=[], **kwargs):
        super().__init__()
        self.tensor_type = tensor_type
        self.tensor_shape = tensor_shape
        self.patches = patches

    def __new__(cls, *args, tensor_type, tensor_shape, patches=[], **kwargs):
        return super().__new__(cls, *args, **kwargs)

    def to(self, *args, **kwargs):
        new = super().to(*args, **kwargs)
        new.tensor_type = getattr(self, "tensor_type", None)
        new.tensor_shape = getattr(self, "tensor_shape", new.data.shape)
        new.patches = getattr(self, "patches", []).copy()
        return new

    def clone(self, *args, **kwargs):
        return self

    def detach(self, *args, **kwargs):
        return self

    def copy_(self, *args, **kwargs):
        # fixes .weight.copy_ in comfy/clip_model/CLIPTextModel
        try:
            return super().copy_(*args, **kwargs)
        except Exception as e:
            print(f"ignoring 'copy_' on tensor: {e}")

    def new_empty(self, size, *args, **kwargs):
        # Intel Arc fix, ref#50
        new_tensor = super().new_empty(size, *args, **kwargs)
        return GGMLTensor(
            new_tensor,
            tensor_type = getattr(self, "tensor_type", None),
            tensor_shape = size,
            patches = getattr(self, "patches", []).copy()
        )

    @property
    def shape(self):
        if not hasattr(self, "tensor_shape"):
            self.tensor_shape = self.size()
        return self.tensor_shape
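
# Illustrative sketch (kept as a comment, not executed): GGMLTensor instances are
# normally built by a loader that reads a .gguf file and wraps each still-quantized
# tensor while recording its quantization type and logical shape. The reader fields
# below come from the `gguf` Python package; the loop body is an assumption about
# what such a loader could look like, not this project's actual loader.
#
#   reader = gguf.GGUFReader("/path/to/model.gguf")
#   state_dict = {}
#   for tensor in reader.tensors:
#       data = torch.from_numpy(tensor.data)  # raw quantized blocks, not floats
#       shape = torch.Size(tuple(int(d) for d in reversed(tensor.shape)))
#       state_dict[tensor.name] = GGMLTensor(
#           data, tensor_type=tensor.tensor_type, tensor_shape=shape
#       )
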
class GGMLLayer(torch.nn.Module):
    """
    This (should) be responsible for de-quantizing on the fly
    """
    comfy_cast_weights = True
    dequant_dtype = None
    patch_dtype = None
    largest_layer = False
    torch_compatible_tensor_types = {None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}

    def is_ggml_quantized(self, *, weight=None, bias=None):
        if weight is None:
            weight = self.weight
        if bias is None:
            bias = self.bias
        return is_quantized(weight) or is_quantized(bias)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        weight, bias = state_dict.get(f"{prefix}weight"), state_dict.get(f"{prefix}bias")
        # NOTE: using modified load for linear due to not initializing on creation, see GGMLOps todo
        if self.is_ggml_quantized(weight=weight, bias=bias) or isinstance(self, torch.nn.Linear):
            return self.ggml_load_from_state_dict(state_dict, prefix, *args, **kwargs)
        return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def ggml_load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        prefix_len = len(prefix)
        for k, v in state_dict.items():
            if k[prefix_len:] == "weight":
                self.weight = torch.nn.Parameter(v, requires_grad=False)
            elif k[prefix_len:] == "bias" and v is not None:
                self.bias = torch.nn.Parameter(v, requires_grad=False)
            else:
                unexpected_keys.append(k)

        # For Linear layer with missing weight
        if self.weight is None and isinstance(self, torch.nn.Linear):
            v = torch.zeros(self.in_features, self.out_features)
            self.weight = torch.nn.Parameter(v, requires_grad=False)
            missing_keys.append(prefix + "weight")

        # for vram estimation (TODO: less fragile logic?)
        if getattr(self.weight, "is_largest_weight", False):
            self.largest_layer = True

    def _save_to_state_dict(self, *args, **kwargs):
        if self.is_ggml_quantized():
            return self.ggml_save_to_state_dict(*args, **kwargs)
        return super()._save_to_state_dict(*args, **kwargs)

    def ggml_save_to_state_dict(self, destination, prefix, keep_vars):
        # This is a fake state dict for vram estimation
        weight = torch.zeros_like(self.weight, device=torch.device("meta"))
        destination[prefix + "weight"] = weight
        if self.bias is not None:
            bias = torch.zeros_like(self.bias, device=torch.device("meta"))
            destination[prefix + "bias"] = bias

        # Take into account space required for dequantizing the largest tensor
        if self.largest_layer:
            shape = getattr(self.weight, "tensor_shape", self.weight.shape)
            dtype = self.dequant_dtype or torch.float16
            temp = torch.empty(*shape, device=torch.device("meta"), dtype=dtype)
            destination[prefix + "temp.weight"] = temp

        return
        # This would return the dequantized state dict
        destination[prefix + "weight"] = self.get_weight(self.weight)
        if bias is not None:
            destination[prefix + "bias"] = self.get_weight(self.bias)

    def get_weight(self, tensor, dtype):
        if tensor is None:
            return

        # consolidate and load patches to GPU in async
        patch_list = []
        device = tensor.device
        for function, patches, key in getattr(tensor, "patches", []):
            patch_list += move_patch_to_device(patches, device)

        # dequantize tensor while patches load
        weight = dequantize_tensor(tensor, dtype, self.dequant_dtype)

        # prevent propagating custom tensor class
        if isinstance(weight, GGMLTensor):
            weight.__class__ = torch.Tensor

        # apply patches
        if patch_list:
            if self.patch_dtype is None:
                weight = function(patch_list, weight, key)
            else:
                # for testing, may degrade image quality
                patch_dtype = dtype if self.patch_dtype == "target" else self.patch_dtype
                weight = function(patch_list, weight, key, patch_dtype)
        return weight

    def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
        if input is not None:
            if dtype is None:
                dtype = getattr(input, "dtype", torch.float32)
            if bias_dtype is None:
                bias_dtype = dtype
            if device is None:
                device = input.device

        bias = None
        non_blocking = comfy.model_management.device_supports_non_blocking(device)
        if s.bias is not None:
            bias = s.get_weight(s.bias.to(device), dtype)
            bias = comfy.ops.cast_to(bias, bias_dtype, device, non_blocking=non_blocking, copy=False)

        weight = s.get_weight(s.weight.to(device), dtype)
        weight = comfy.ops.cast_to(weight, dtype, device, non_blocking=non_blocking, copy=False)
        return weight, bias

    def forward_comfy_cast_weights(self, input, *args, **kwargs):
        if self.is_ggml_quantized():
            out = self.forward_ggml_cast_weights(input, *args, **kwargs)
        else:
            out = super().forward_comfy_cast_weights(input, *args, **kwargs)

        # non-ggml forward might still propagate custom tensor class
        if isinstance(out, GGMLTensor):
            out.__class__ = torch.Tensor
        return out

    def forward_ggml_cast_weights(self, input):
        raise NotImplementedError


class GGMLOps(comfy.ops.manual_cast):
    """
    Dequantize weights on the fly before doing the compute
    """
    class Linear(GGMLLayer, comfy.ops.manual_cast.Linear):
        def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
            torch.nn.Module.__init__(self)
            # TODO: better workaround for reserved memory spike on windows
            # Issue is with `torch.empty` still reserving the full memory for the layer
            # Windows doesn't over-commit memory so without this 24GB+ of pagefile is used
            self.in_features = in_features
            self.out_features = out_features
            self.weight = None
            self.bias = None
        def forward_ggml_cast_weights(self, input):
            weight, bias = self.cast_bias_weight(input)
            return torch.nn.functional.linear(input, weight, bias)

    class Conv2d(GGMLLayer, comfy.ops.manual_cast.Conv2d):
        def forward_ggml_cast_weights(self, input):
            weight, bias = self.cast_bias_weight(input)
            return self._conv_forward(input, weight, bias)

    class Embedding(GGMLLayer, comfy.ops.manual_cast.Embedding):
        def forward_ggml_cast_weights(self, input, out_dtype=None):
            output_dtype = out_dtype
            if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
                out_dtype = None
            weight, _bias = self.cast_bias_weight(self, device=input.device, dtype=out_dtype)
            return torch.nn.functional.embedding(
                input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse
            ).to(dtype=output_dtype)

    class LayerNorm(GGMLLayer, comfy.ops.manual_cast.LayerNorm):
        def forward_ggml_cast_weights(self, input):
            if self.weight is None:
                return super().forward_comfy_cast_weights(input)
            weight, bias = self.cast_bias_weight(input)
            return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)

    class GroupNorm(GGMLLayer, comfy.ops.manual_cast.GroupNorm):
        def forward_ggml_cast_weights(self, input):
            weight, bias = self.cast_bias_weight(input)
            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)


def move_patch_to_device(item, device):
    if isinstance(item, torch.Tensor):
        return item.to(device, non_blocking=True)
    elif isinstance(item, tuple):
        return tuple(move_patch_to_device(x, device) for x in item)
    elif isinstance(item, list):
        return [move_patch_to_device(x, device) for x in item]
    else:
        return item
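
# Illustrative sketch (kept as a comment, not executed): GGMLOps is intended to be
# passed to a ComfyUI-style loader as a custom operations class, so layers are built
# without allocating real weights and dequantize on the fly in forward passes.
# `comfy.sd.load_diffusion_model_state_dict`, its `model_options` key, and the
# `state_dict` of GGMLTensor weights (see the loader sketch above) are assumptions
# about the surrounding application, not part of this module.
#
#   import comfy.sd
#   ops = GGMLOps()
#   ops.Linear.dequant_dtype = None   # or e.g. torch.float16 to force a dequant dtype
#   model = comfy.sd.load_diffusion_model_state_dict(
#       state_dict, model_options={"custom_operations": ops}
#   )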