import math

import torch
from transformers.utils import logging

logger = logging.get_logger(__name__)

try:
    import quant_cuda
except ImportError:
    logger.warning("CUDA extension 'quant_cuda' not installed; quantized kernels are unavailable.")


class QuantizedLinear(torch.nn.Module):
    def __init__(self, weight_bit_width: int, weight, bias=None, device="cpu", dtype=None,
                 empty_init=False, *args, **kwargs):
        super().__init__()
        self.weight_bit_width = weight_bit_width
        shape = weight.shape  # (out_features, in_features)
        self.shape = shape
        self.group_size = 128

        # GPTQ-style packed buffers: each int32 packs 32 // weight_bit_width quantized values,
        # with one row of scales/zero-points per group of `group_size` input channels.
        # Note: shape[0] // 256 * (weight_bit_width * 8) equals shape[0] * weight_bit_width // 32
        # when shape[0] is a multiple of 256. The buffers start zeroed and are filled when a
        # quantized checkpoint is loaded. `dtype` and `empty_init` are accepted for interface
        # compatibility but are not used here.
        self.register_buffer(
            'qzeros',
            torch.zeros(
                (math.ceil(shape[1] / self.group_size), shape[0] // 256 * (weight_bit_width * 8)),
                dtype=torch.int, device=device,
            ),
        )
        self.register_buffer(
            'scales',
            torch.zeros(
                (math.ceil(shape[1] / self.group_size), shape[0]),
                dtype=torch.float, device=device,
            ),
        )
        self.register_buffer(
            'qweight',
            torch.zeros(
                (shape[1] // 256 * (weight_bit_width * 8), shape[0]),
                dtype=torch.int, device=device,
            ),
        )
        # Keep the original bias (if any) so it can be applied in forward().
        if bias is not None:
            self.register_buffer('bias', bias.detach().to(device))
        else:
            self.bias = None

    def forward(self, x):
        intermediate_dtype = torch.float32
        outshape = list(x.shape)
        outshape[-1] = self.shape[0]
        x = x.reshape(-1, x.shape[-1])
        y = torch.zeros(x.shape[0], outshape[-1], dtype=intermediate_dtype, device=x.device)
        output_dtype = x.dtype
        x = x.to(intermediate_dtype)
        # Dispatch to the packed matmul kernel matching the configured bit width.
        if self.weight_bit_width == 2:
            quant_cuda.vecquant2matmul(x, self.qweight, y, self.scales, self.qzeros, self.group_size)
        elif self.weight_bit_width == 3:
            quant_cuda.vecquant3matmul(x, self.qweight, y, self.scales, self.qzeros, self.group_size)
        elif self.weight_bit_width == 4:
            quant_cuda.vecquant4matmul(x, self.qweight, y, self.scales, self.qzeros, self.group_size)
        elif self.weight_bit_width == 8:
            quant_cuda.vecquant8matmul(x, self.qweight, y, self.scales, self.qzeros, self.group_size)
        else:
            raise NotImplementedError("Only 2, 3, 4 and 8 bits are supported.")
        if self.bias is not None:
            y += self.bias
        y = y.to(output_dtype)
        return y.reshape(outshape)


def quantize(model, weight_bit_width, empty_init=False, device=None):
    """Replace each decoder layer's attention and MLP projections with QuantizedLinear modules."""
    target_projections = [
        ("self_attn", "q_proj"), ("self_attn", "k_proj"),
        ("self_attn", "v_proj"), ("self_attn", "o_proj"),
        ("mlp", "gate_proj"), ("mlp", "down_proj"), ("mlp", "up_proj"),
    ]
    for layer in model.layers:
        for parent_name, proj_name in target_projections:
            parent = getattr(layer, parent_name)
            linear = getattr(parent, proj_name)
            setattr(parent, proj_name, QuantizedLinear(
                weight_bit_width=weight_bit_width,
                weight=linear.weight,
                bias=linear.bias,
                dtype=linear.weight.dtype,
                device=linear.weight.device if device is None else device,
                empty_init=empty_init,
            ))
    return model
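

# Usage sketch (assumptions: the model is a LLaMA-style causal LM whose decoder exposes
# `.layers` with self_attn.{q,k,v,o}_proj and mlp.{gate,down,up}_proj submodules, and a
# matching GPTQ-packed state dict exists on disk; the file paths below are placeholders).
# The buffers created by `quantize` are zero-initialized, so a quantized checkpoint must
# be loaded afterwards to populate qweight/qzeros/scales.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("path/to/fp16-model", torch_dtype=torch.float16)
    model.model = quantize(model.model, weight_bit_width=4)  # swap Linear layers for QuantizedLinear
    quantized_state = torch.load("path/to/gptq-state-dict.pt", map_location="cpu")
    model.load_state_dict(quantized_state, strict=False)     # fill the packed buffers
    model = model.cuda().eval()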