# (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
import gguf
import torch
from tqdm import tqdm

# Tensor types that torch can consume directly, without block dequantization.
TORCH_COMPATIBLE_QTYPES = {None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}


def is_torch_compatible(tensor):
    """Return True if *tensor* is None or its ``tensor_type`` needs no dequantization.

    A missing ``tensor_type`` attribute is treated as ``None``, which is a
    member of ``TORCH_COMPATIBLE_QTYPES``.
    """
    return tensor is None or getattr(tensor, "tensor_type", None) in TORCH_COMPATIBLE_QTYPES


def is_quantized(tensor):
    """Return True if *tensor* is not torch compatible (see ``is_torch_compatible``)."""
    return not is_torch_compatible(tensor)


def dequantize_tensor(tensor, dtype=None, dequant_dtype=None):
    """Convert *tensor* to *dtype*, dequantizing it first when required.

    Args:
        tensor: Tensor to convert. Its optional ``tensor_type`` attribute
            selects the quantization type and its optional ``tensor_shape``
            attribute gives the output shape (falls back to ``tensor.shape``).
        dtype: Dtype of the returned tensor (passed to ``.to``).
        dequant_dtype: Dtype used inside the torch dequantization kernels.
            The string ``"target"`` means "use *dtype*".

    Returns:
        The converted tensor.
    """
    qtype = getattr(tensor, "tensor_type", None)
    oshape = getattr(tensor, "tensor_shape", tensor.shape)

    if qtype in TORCH_COMPATIBLE_QTYPES:
        return tensor.to(dtype)
    elif qtype in dequantize_functions:
        dequant_dtype = dtype if dequant_dtype == "target" else dequant_dtype
        return dequantize(tensor.data, qtype, oshape, dtype=dequant_dtype).to(dtype)
    else:
        # this is incredibly slow
        tqdm.write(f"Falling back to numpy dequant for qtype: {qtype}")
        new = gguf.quants.dequantize(tensor.cpu().numpy(), qtype)
        return torch.from_numpy(new).to(tensor.device, dtype=dtype)


def dequantize(data, qtype, oshape, dtype=None):
    """Dequantize raw block data back to a tensor of shape *oshape*.

    Args:
        data: Tensor holding the packed blocks; it is reinterpreted as bytes.
        qtype: Key into ``gguf.GGML_QUANT_SIZES`` and ``dequantize_functions``.
        oshape: Shape of the returned tensor.
        dtype: Dtype forwarded to the per-type block kernel.

    Returns:
        The dequantized values reshaped to *oshape*.
    """
    block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
    dequantize_blocks = dequantize_functions[qtype]

    # View everything as bytes, then regroup as one row of `type_size` bytes per block.
    rows = data.reshape((-1, data.shape[-1])).view(torch.uint8)

    n_blocks = rows.numel() // type_size
    blocks = rows.reshape((n_blocks, type_size))
    blocks = dequantize_blocks(blocks, block_size, type_size, dtype)
    return blocks.reshape(oshape)


def to_uint32(x):
    """Combine the first four bytes of each row (lowest byte first) into one value.

    The result is an int32 tensor of shape ``(rows, 1)``; it carries the
    32-bit pattern only, so rows whose fourth byte has its top bit set come
    out negative.
    """
    # no uint32 :(
    x = x.view(torch.uint8).to(torch.int32)
    return (x[:, 0] | x[:, 1] << 8 | x[:, 2] << 16 | x[:, 3] << 24).unsqueeze(1)


def split_block_dims(blocks, *args):
    """Split *blocks* along dim 1 into chunks of the given widths plus the remainder.

    Args:
        blocks: 2-D tensor with one block per row.
        *args: Widths of the leading chunks; the final chunk gets whatever
            columns are left.

    Returns:
        Tuple of ``len(args) + 1`` tensors.
    """
    n_max = blocks.shape[1]
    dims = list(args) + [n_max - sum(args)]
    return torch.split(blocks, dims, dim=1)


def _shift_table(shifts, device, shape):
    """Return the uint8 shift amounts *shifts* on *device*, reshaped for broadcasting."""
    return torch.tensor(shifts, device=device, dtype=torch.uint8).reshape(shape)


# Full weights #
def dequantize_blocks_BF16(blocks, block_size, type_size, dtype=None):
    """Expand 16-bit values to float32 by placing them in the upper 16 bits.

    *block_size*, *type_size* and *dtype* are accepted for signature parity
    with the other kernels and are not used.
    """
    return (blocks.view(torch.int16).to(torch.int32) << 16).view(torch.float32)


# Legacy Quants #
def dequantize_blocks_Q8_0(blocks, block_size, type_size, dtype=None):
    """Dequantize Q8_0 blocks: 2-byte float16 scale followed by int8 values."""
    d, x = split_block_dims(blocks, 2)
    d = d.view(torch.float16).to(dtype)
    x = x.view(torch.int8)
    return (d * x)


def dequantize_blocks_Q5_1(blocks, block_size, type_size, dtype=None):
    """Dequantize Q5_1 blocks: float16 scale, float16 offset, 4 bytes of high bits, packed nibbles."""
    n_blocks = blocks.shape[0]

    d, m, qh, qs = split_block_dims(blocks, 2, 2, 4)
    d = d.view(torch.float16).to(dtype)
    m = m.view(torch.float16).to(dtype)
    qh = to_uint32(qh)

    # One high bit per value: shift the 32-bit word by 0..31 and mask below.
    qh = qh.reshape((n_blocks, 1)) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
    # Low then high nibble of every byte.
    ql = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> _shift_table([0, 4], d.device, (1, 1, 2, 1))
    qh = (qh & 1).to(torch.uint8)
    ql = (ql & 0x0F).reshape((n_blocks, -1))

    qs = (ql | (qh << 4))
    return (d * qs) + m


def dequantize_blocks_Q5_0(blocks, block_size, type_size, dtype=None):
    """Dequantize Q5_0 blocks: float16 scale, 4 bytes of high bits, packed nibbles; values are offset by -16."""
    n_blocks = blocks.shape[0]

    d, qh, qs = split_block_dims(blocks, 2, 4)
    d = d.view(torch.float16).to(dtype)
    qh = to_uint32(qh)

    # One high bit per value: shift the 32-bit word by 0..31 and mask below.
    qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
    # Low then high nibble of every byte.
    ql = qs.reshape(n_blocks, -1, 1, block_size // 2) >> _shift_table([0, 4], d.device, (1, 1, 2, 1))

    qh = (qh & 1).to(torch.uint8)
    ql = (ql & 0x0F).reshape(n_blocks, -1)

    qs = (ql | (qh << 4)).to(torch.int8) - 16
    return (d * qs)


def dequantize_blocks_Q4_1(blocks, block_size, type_size, dtype=None):
    """Dequantize Q4_1 blocks: float16 scale, float16 offset, packed nibbles."""
    n_blocks = blocks.shape[0]

    d, m, qs = split_block_dims(blocks, 2, 2)
    d = d.view(torch.float16).to(dtype)
    m = m.view(torch.float16).to(dtype)

    # Low then high nibble of every byte.
    qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> _shift_table([0, 4], d.device, (1, 1, 2, 1))
    qs = (qs & 0x0F).reshape(n_blocks, -1)

    return (d * qs) + m


def dequantize_blocks_Q4_0(blocks, block_size, type_size, dtype=None):
    """Dequantize Q4_0 blocks: float16 scale, packed nibbles; values are offset by -8."""
    n_blocks = blocks.shape[0]

    d, qs = split_block_dims(blocks, 2)
    d = d.view(torch.float16).to(dtype)

    # Low then high nibble of every byte.
    qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> _shift_table([0, 4], d.device, (1, 1, 2, 1))
    qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8
    return (d * qs)


# K Quants #
QK_K = 256
K_SCALE_SIZE = 12


def get_scale_min(scales):
    """Unpack 12 bytes per block into eight 6-bit scales and eight 6-bit minimums.

    Args:
        scales: Tensor with 12 bytes per block (reshaped here to ``(n_blocks, 3, 4)``).

    Returns:
        Tuple ``(sc, mins)``, each a uint8 tensor of shape ``(n_blocks, 8)``.
    """
    n_blocks = scales.shape[0]
    scales = scales.view(torch.uint8)
    scales = scales.reshape((n_blocks, 3, 4))

    d, m, m_d = torch.split(scales, scales.shape[-2] // 3, dim=-2)

    # First four entries use the low 6 bits of `d` / `m`; the last four combine
    # a nibble of `m_d` with the top 2 bits of `d` / `m`.
    sc = torch.cat([d & 0x3F, (m_d & 0x0F) | ((d >> 2) & 0x30)], dim=-1)
    mins = torch.cat([m & 0x3F, (m_d >> 4) | ((m >> 2) & 0x30)], dim=-1)

    return (sc.reshape((n_blocks, 8)), mins.reshape((n_blocks, 8)))


def dequantize_blocks_Q6_K(blocks, block_size, type_size, dtype=None):
    """Dequantize Q6_K blocks: low nibbles, 2-bit high parts, int8 sub-scales, float16 scale."""
    n_blocks = blocks.shape[0]

    ql, qh, scales, d = split_block_dims(blocks, QK_K // 2, QK_K // 4, QK_K // 16)

    scales = scales.view(torch.int8).to(dtype)
    d = d.view(torch.float16).to(dtype)
    d = (d * scales).reshape((n_blocks, QK_K // 16, 1))

    ql = ql.reshape((n_blocks, -1, 1, 64)) >> _shift_table([0, 4], d.device, (1, 1, 2, 1))
    ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
    qh = qh.reshape((n_blocks, -1, 1, 32)) >> _shift_table([0, 2, 4, 6], d.device, (1, 1, 4, 1))
    qh = (qh & 0x03).reshape((n_blocks, -1, 32))
    # 4 low bits + 2 high bits form a 6-bit value, then offset by -32.
    q = (ql | (qh << 4)).to(torch.int8) - 32
    q = q.reshape((n_blocks, QK_K // 16, -1))

    return (d * q).reshape((n_blocks, QK_K))


def dequantize_blocks_Q5_K(blocks, block_size, type_size, dtype=None):
    """Dequantize Q5_K blocks: float16 scale and min, packed sub-scales, high bits, packed nibbles."""
    n_blocks = blocks.shape[0]

    d, dmin, scales, qh, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE, QK_K // 8)

    d = d.view(torch.float16).to(dtype)
    dmin = dmin.view(torch.float16).to(dtype)

    sc, m = get_scale_min(scales)

    d = (d * sc).reshape((n_blocks, -1, 1))
    dm = (dmin * m).reshape((n_blocks, -1, 1))

    ql = qs.reshape((n_blocks, -1, 1, 32)) >> _shift_table([0, 4], d.device, (1, 1, 2, 1))
    # Bit k of every high-bit byte, for k = 0..7.
    qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.arange(8, device=d.device, dtype=torch.uint8).reshape((1, 1, 8, 1))
    ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
    qh = (qh & 0x01).reshape((n_blocks, -1, 32))
    q = (ql | (qh << 4))

    return (d * q - dm).reshape((n_blocks, QK_K))


def dequantize_blocks_Q4_K(blocks, block_size, type_size, dtype=None):
    """Dequantize Q4_K blocks: float16 scale and min, packed sub-scales, packed nibbles."""
    n_blocks = blocks.shape[0]

    d, dmin, scales, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE)
    d = d.view(torch.float16).to(dtype)
    dmin = dmin.view(torch.float16).to(dtype)

    sc, m = get_scale_min(scales)

    d = (d * sc).reshape((n_blocks, -1, 1))
    dm = (dmin * m).reshape((n_blocks, -1, 1))

    qs = qs.reshape((n_blocks, -1, 1, 32)) >> _shift_table([0, 4], d.device, (1, 1, 2, 1))
    qs = (qs & 0x0F).reshape((n_blocks, -1, 32))

    return (d * qs - dm).reshape((n_blocks, QK_K))


def dequantize_blocks_Q3_K(blocks, block_size, type_size, dtype=None):
    """Dequantize Q3_K blocks: high-bit mask, 2-bit low parts, 12 bytes of packed scales, float16 scale."""
    n_blocks = blocks.shape[0]

    hmask, qs, scales, d = split_block_dims(blocks, QK_K // 8, QK_K // 4, 12)
    d = d.view(torch.float16).to(dtype)

    # 16 six-bit scales: low nibbles come from the first 8 bytes, the top
    # 2 bits from the last 4 bytes.
    lscales, hscales = scales[:, :8], scales[:, 8:]
    lscales = lscales.reshape((n_blocks, 1, 8)) >> _shift_table([0, 4], d.device, (1, 2, 1))
    lscales = lscales.reshape((n_blocks, 16))
    hscales = hscales.reshape((n_blocks, 1, 4)) >> _shift_table([0, 2, 4, 6], d.device, (1, 4, 1))
    hscales = hscales.reshape((n_blocks, 16))
    scales = (lscales & 0x0F) | ((hscales & 0x03) << 4)
    scales = (scales.to(torch.int8) - 32)

    dl = (d * scales).reshape((n_blocks, 16, 1))

    ql = qs.reshape((n_blocks, -1, 1, 32)) >> _shift_table([0, 2, 4, 6], d.device, (1, 1, 4, 1))
    # Bit k of every mask byte, for k = 0..7.
    qh = hmask.reshape(n_blocks, -1, 1, 32) >> torch.arange(8, device=d.device, dtype=torch.uint8).reshape((1, 1, 8, 1))
    ql = ql.reshape((n_blocks, 16, QK_K // 16)) & 3
    # The mask bit is inverted: a cleared bit subtracts 4 from the 2-bit value.
    qh = (qh.reshape((n_blocks, 16, QK_K // 16)) & 1) ^ 1
    q = (ql.to(torch.int8) - (qh << 2).to(torch.int8))

    return (dl * q).reshape((n_blocks, QK_K))


def dequantize_blocks_Q2_K(blocks, block_size, type_size, dtype=None):
    """Dequantize Q2_K blocks: packed 4-bit scale/min pairs, 2-bit values, float16 scale and min."""
    n_blocks = blocks.shape[0]

    scales, qs, d, dmin = split_block_dims(blocks, QK_K // 16, QK_K // 4, 2)
    d = d.view(torch.float16).to(dtype)
    dmin = dmin.view(torch.float16).to(dtype)

    # (n_blocks, 16, 1)
    # Low nibble of each scale byte scales `d`; high nibble scales `dmin`.
    dl = (d * (scales & 0xF)).reshape((n_blocks, QK_K // 16, 1))
    ml = (dmin * (scales >> 4)).reshape((n_blocks, QK_K // 16, 1))

    shift = _shift_table([0, 2, 4, 6], d.device, (1, 1, 4, 1))

    qs = (qs.reshape((n_blocks, -1, 1, 32)) >> shift) & 3
    qs = qs.reshape((n_blocks, QK_K // 16, 16))
    qs = dl * qs - ml

    return qs.reshape((n_blocks, -1))


# Maps each quantization type handled in torch to its block kernel.
dequantize_functions = {
    gguf.GGMLQuantizationType.BF16: dequantize_blocks_BF16,
    gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0,
    gguf.GGMLQuantizationType.Q5_1: dequantize_blocks_Q5_1,
    gguf.GGMLQuantizationType.Q5_0: dequantize_blocks_Q5_0,
    gguf.GGMLQuantizationType.Q4_1: dequantize_blocks_Q4_1,
    gguf.GGMLQuantizationType.Q4_0: dequantize_blocks_Q4_0,
    gguf.GGMLQuantizationType.Q6_K: dequantize_blocks_Q6_K,
    gguf.GGMLQuantizationType.Q5_K: dequantize_blocks_Q5_K,
    gguf.GGMLQuantizationType.Q4_K: dequantize_blocks_Q4_K,
    gguf.GGMLQuantizationType.Q3_K: dequantize_blocks_Q3_K,
    gguf.GGMLQuantizationType.Q2_K: dequantize_blocks_Q2_K,
}