# https://github.com/MinusZoneAI/ComfyUI-CogVideoX-MZ/blob/9616415220fd09388622f40f6609e4ed81f048a5/mz_gguf_loader.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import gc

class quantize_lazy_load():
    """Context manager that routes module construction to the meta device,
    so a large model can be instantiated without allocating real memory."""

    def __init__(self):
        self.device = None

    def __enter__(self):
        self.device = torch.device("meta")
        self.device.__enter__()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.device.__exit__(exc_type, exc_value, traceback)
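
def _example_lazy_init():
    # A minimal sketch (not part of the loader API) showing the intended use:
    # modules built under quantize_lazy_load allocate their parameters on the
    # meta device, so no real memory is touched until to_empty()/load.
    with quantize_lazy_load():
        layer = nn.Linear(8, 4)
    assert layer.weight.is_meta  # parameters are shape-only meta tensors
    return layer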

def quantize_load_state_dict(model, state_dict, device="cpu"):
    # Collect the names of modules that have a packed GGUF weight in the
    # checkpoint. The whole checkpoint is assumed to use a single qtype.
    quant_keys = []
    qtype = None
    for key in state_dict.keys():
        if key.endswith(".Q4_0_qweight"):
            quant_keys.append(key.replace(".Q4_0_qweight", ""))
            qtype = "Q4_0"
        elif key.endswith(".Q8_0_qweight"):
            quant_keys.append(key.replace(".Q8_0_qweight", ""))
            qtype = "Q8_0"

    # Swap each matching nn.Linear for a WQLinear_GGUF that stores the
    # packed bytes instead of a dense weight.
    for name, module in model.named_modules():
        if name in quant_keys:
            q_linear = WQLinear_GGUF.from_linear(
                linear=module,
                device=device,
                qtype=qtype,
            )
            set_op_by_name(model, name, q_linear)

    # Materialize any remaining meta tensors, then fill in the real values.
    model.to_empty(device=device)
    model.load_state_dict(state_dict, strict=False)
    model.to(device)
    return model
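
def _example_quantized_load():
    # A minimal end-to-end sketch, with a fabricated in-memory checkpoint:
    # one Q8_0 weight (a per-row fp16 scale followed by 32 int8 values,
    # stored as raw bytes) is loaded through quantize_load_state_dict.
    # The _Tiny module and all values here are illustrative only.
    class _Tiny(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(32, 4)

        def forward(self, x):
            return self.fc(x)

    d = torch.full((4, 1), 0.5, dtype=torch.float16)      # one scale per row
    qs = torch.randint(-8, 8, (4, 32), dtype=torch.int8)  # quantized values
    packed = torch.cat([d.view(torch.uint8), qs.view(torch.uint8)], dim=1)
    state_dict = {
        "fc.Q8_0_qweight": packed,  # (4, 34) bytes: 2-byte scale + 32 int8
        "fc.bias": torch.zeros(4, dtype=torch.float16),
    }

    with quantize_lazy_load():
        model = _Tiny()
    model = quantize_load_state_dict(model, state_dict, device="cpu")
    out = model(torch.randn(2, 32))
    assert out.shape == (2, 4)
    return model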

def set_op_by_name(layer, name, new_module):
    """Replace the submodule addressed by a dotted name (e.g. "blocks.0.attn")
    with new_module, indexing numerically into containers like nn.Sequential."""
    levels = name.split(".")
    if len(levels) > 1:
        mod_ = layer
        for l_idx in range(len(levels) - 1):
            if levels[l_idx].isdigit():
                mod_ = mod_[int(levels[l_idx])]
            else:
                mod_ = getattr(mod_, levels[l_idx])
        setattr(mod_, levels[-1], new_module)
    else:
        setattr(layer, name, new_module)
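
def _example_set_op_by_name():
    # A minimal sketch: swap a submodule addressed by a dotted path whose
    # first component is a numeric index into an nn.Sequential.
    model = nn.Sequential(nn.Sequential(nn.Linear(4, 4), nn.ReLU()))
    set_op_by_name(model, "0.1", nn.Identity())
    assert isinstance(model[0][1], nn.Identity)
    return model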

class WQLinear_GGUF(nn.Module):
    """Linear layer whose weight is kept as packed GGUF bytes and
    dequantized on the fly in forward()."""

    def __init__(
        self, in_features, out_features, bias, dev, qtype="Q4_0"
    ):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.qtype = qtype

        # The packed weight is raw bytes; the last dim is the per-row byte
        # count for the chosen quantization type.
        qweight_shape = quant_shape_to_byte_shape(
            (out_features, in_features), qtype
        )
        self.register_buffer(
            f"{qtype}_qweight",
            torch.zeros(
                qweight_shape,
                dtype=torch.uint8,
                device=dev,
            ),
        )

        if bias:
            self.register_buffer(
                "bias",
                torch.zeros(
                    (out_features),
                    dtype=torch.float16,
                    device=dev,
                ),
            )
        else:
            self.bias = None

    @classmethod
    def from_linear(
        cls, linear,
        device="cpu",
        qtype="Q4_0",
    ):
        # Build a quantized shell with the same geometry as the dense layer;
        # the packed bytes are filled in later by load_state_dict.
        q_linear = cls(
            linear.in_features,
            linear.out_features,
            linear.bias is not None,
            device,
            qtype=qtype,
        )
        return q_linear

    def extra_repr(self) -> str:
        return (
            "in_features={}, out_features={}, bias={}, qtype={}".format(
                self.in_features,
                self.out_features,
                self.bias is not None,
                self.qtype,
            )
        )

    def forward(self, x):
        # Dequantize the packed weight to x.dtype, then run a normal linear.
        if self.qtype == "Q4_0":
            dequant = dequantize_blocks_Q4_0(self.Q4_0_qweight, x.dtype)
        elif self.qtype == "Q8_0":
            dequant = dequantize_blocks_Q8_0(self.Q8_0_qweight, x.dtype)
        else:
            raise ValueError(f"Unknown qtype: {self.qtype}")
        return F.linear(x, dequant, bias=self.bias.to(x.dtype) if self.bias is not None else None)
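
def _example_wqlinear_shapes():
    # A minimal sketch: a Q4_0 layer for a 64->8 projection stores its packed
    # weight as uint8 bytes, 18 bytes per 32-weight block, i.e. (8, 36).
    layer = WQLinear_GGUF(64, 8, bias=False, dev="cpu", qtype="Q4_0")
    assert layer.Q4_0_qweight.shape == (8, 64 // 32 * 18)
    assert "qtype=Q4_0" in repr(layer)
    return layer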

def split_block_dims(blocks, *args):
    """Split the trailing byte dim of (n_blocks, type_size) into the given
    sizes, with the remainder as the last chunk (e.g. 2-byte scale + data)."""
    n_max = blocks.shape[1]
    dims = list(args) + [n_max - sum(args)]
    return torch.split(blocks, dims, dim=1)

def quant_shape_to_byte_shape(shape, qtype) -> tuple[int, ...]:
    # shape = shape[::-1]
    block_size, type_size = GGML_QUANT_SIZES[qtype]
    if shape[-1] % block_size != 0:
        raise ValueError(
            f"Quantized tensor row size ({shape[-1]}) is not a multiple of {qtype} block size ({block_size})")
    return (*shape[:-1], shape[-1] // block_size * type_size)

def quant_shape_from_byte_shape(shape, qtype) -> tuple[int, ...]:
    # shape = shape[::-1]
    block_size, type_size = GGML_QUANT_SIZES[qtype]
    if shape[-1] % type_size != 0:
        raise ValueError(
            f"Quantized tensor bytes per row ({shape[-1]}) is not a multiple of {qtype} type size ({type_size})")
    return (*shape[:-1], shape[-1] // type_size * block_size)

# (block size in weights, bytes per block). Each block stores an fp16 scale
# (2 bytes) followed by the quantized values: 16 bytes of packed 4-bit
# nibbles for Q4_0, 32 int8 values for Q8_0.
GGML_QUANT_SIZES = {
    "Q4_0": (32, 2 + 16),
    "Q8_0": (32, 2 + 32),
}
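
def _example_shape_round_trip():
    # A minimal sketch: a (4096, 4096) Q4_0 weight packs to (4096, 2304)
    # bytes (4096 / 32 = 128 blocks per row, 18 bytes each) and round-trips.
    byte_shape = quant_shape_to_byte_shape((4096, 4096), "Q4_0")
    assert byte_shape == (4096, 2304)
    assert quant_shape_from_byte_shape(byte_shape, "Q4_0") == (4096, 4096)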

def dequantize_blocks_Q4_0(data, dtype=torch.float16):
    block_size, type_size = GGML_QUANT_SIZES["Q4_0"]

    data = data.to(torch.uint8)
    shape = data.shape

    rows = data.reshape(
        (-1, data.shape[-1])
    ).view(torch.uint8)

    n_blocks = rows.numel() // type_size
    blocks = data.reshape((n_blocks, type_size))
    n_blocks = blocks.shape[0]

    # Each 18-byte block is an fp16 scale d followed by 16 data bytes; the
    # low nibbles hold the block's first 16 weights, the high nibbles the
    # last 16.
    d, qs = split_block_dims(blocks, 2)
    d = d.view(torch.float16)

    qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
        [0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
    # Mask out each nibble and shift to the signed range [-8, 7].
    qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8

    out = (d * qs)

    out = out.reshape(quant_shape_from_byte_shape(
        shape,
        qtype="Q4_0",
    )).to(dtype)
    return out
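
def _example_q4_0_round_trip():
    # A minimal sketch: pack one hand-made Q4_0 block (fp16 scale d, then 16
    # bytes of two 4-bit values each) and check dequantization gives d*(q-8).
    d = torch.tensor([2.0], dtype=torch.float16)
    q = torch.arange(32, dtype=torch.uint8) % 16  # raw nibbles in [0, 15]
    lo, hi = q[:16], q[16:]                       # first/last 16 weights
    packed = (lo | (hi << 4)).to(torch.uint8)
    block = torch.cat([d.view(torch.uint8), packed])  # 18 bytes total
    out = dequantize_blocks_Q4_0(block.reshape(1, 18), dtype=torch.float32)
    expected = (q.to(torch.int8) - 8).to(torch.float32) * 2.0
    assert torch.equal(out.flatten(), expected)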

def dequantize_blocks_Q8_0(data, dtype=torch.float16):
    block_size, type_size = GGML_QUANT_SIZES["Q8_0"]

    data = data.to(torch.uint8)
    shape = data.shape

    rows = data.reshape(
        (-1, data.shape[-1])
    ).view(torch.uint8)

    n_blocks = rows.numel() // type_size
    blocks = data.reshape((n_blocks, type_size))
    n_blocks = blocks.shape[0]

    # Each 34-byte block is an fp16 scale d followed by 32 int8 values.
    d, qs = split_block_dims(blocks, 2)
    d = d.view(torch.float16).to(torch.float32)
    qs = qs.view(torch.int8).to(torch.float32)

    out = (d * qs)

    out = out.reshape(quant_shape_from_byte_shape(
        shape,
        qtype="Q8_0",
    )).to(dtype)
    return out
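
def _example_q8_0_round_trip():
    # A minimal sketch: one Q8_0 block is an fp16 scale followed by 32 int8
    # values; dequantization should reproduce d * qs exactly.
    d = torch.tensor([0.5], dtype=torch.float16)
    qs = torch.arange(-16, 16, dtype=torch.int8)
    block = torch.cat([d.view(torch.uint8), qs.view(torch.uint8)])  # 34 bytes
    out = dequantize_blocks_Q8_0(block.reshape(1, 34), dtype=torch.float32)
    assert torch.equal(out.flatten(), qs.to(torch.float32) * 0.5)

if __name__ == "__main__":
    # Run the illustrative sketches above as a quick smoke test.
    _example_lazy_init()
    _example_quantized_load()
    _example_set_op_by_name()
    _example_wqlinear_shapes()
    _example_shape_round_trip()
    _example_q4_0_round_trip()
    _example_q8_0_round_trip()
    print("all example sketches passed")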