import math | |
from torch import nn | |
from transformers.models.llama.modeling_llama import * | |
def activation_quant(x, n_bits = 8):
    """Fake-quantize activations to an n_bits symmetric grid, per token.

    Each slice along the last dimension gets its own scale derived from its
    absolute maximum, so values land on the signed n_bits integer grid and
    are immediately rescaled back to the input's range/dtype.
    """
    lo = -(1 << (n_bits - 1))          # e.g. -128 for 8 bits
    hi = (1 << (n_bits - 1)) - 1       # e.g. +127 for 8 bits
    # Per-token absolute max, floored at 1e-5 so the division never blows up.
    row_max = x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    scale = hi / row_max
    return (x * scale).round().clamp(lo, hi) / scale
def weight_quant(w):
    """Fake-quantize weights to ternary {-1, 0, +1} times a scalar scale.

    The whole tensor shares one scale (reciprocal of the mean absolute
    weight, floored at 1e-5); rounding and clamping the scaled weights
    yields a ternary tensor that is then rescaled to the original magnitude.
    """
    s = 1 / w.abs().mean().clamp(min=1e-5)
    ternary = (w * s).round().clamp(-1, 1)
    return ternary / s
class BitLinear(nn.Linear):
    """Drop-in ``nn.Linear`` with BitNet-style fake quantization.

    On every forward pass the input is RMS-normalized, activations are
    fake-quantized to ``input_bits`` and weights to ternary values, with a
    Straight-Through Estimator (STE) so gradients flow through the
    non-differentiable round/clamp ops unchanged.

    Args:
        *kargs: positional args forwarded to ``nn.Linear`` (in/out features, ...).
        weight_bits: weight bit-width config; stored for inspection
            (the weight quantizer itself is ternary — see ``weight_quant``).
        input_bits: activation bit-width passed to ``activation_quant``.
        **kwargs: keyword args forwarded to ``nn.Linear`` (e.g. ``bias``).
    """

    def __init__(self,
            *kargs,
            weight_bits=1,
            input_bits=8,
            **kwargs
        ):
        super(BitLinear, self).__init__(*kargs, **kwargs)
        # Bug fix: the original accepted these and silently discarded them,
        # then hard-coded 8 bits in forward().
        self.weight_bits = weight_bits
        self.input_bits = input_bits
        # Bug fix: the original built a fresh LlamaRMSNorm on *every* forward
        # call — a per-call allocation whose weight could never be trained.
        # Registering it once here makes it a proper, trainable submodule.
        self.input_norm = LlamaRMSNorm(self.in_features)

    def forward(self, x):
        w = self.weight  # shape [out_features, in_features]
        x = x.to(w.device)
        x_norm = self.input_norm(x)
        # STE trick: forward sees the quantized value, backward sees the
        # identity, because the quantization delta is detached.
        x_quant = x_norm + (activation_quant(x_norm, self.input_bits) - x_norm).detach()
        w_quant = w + (weight_quant(w) - w).detach()
        # Bug fix: the original dropped self.bias; pass it through so a
        # BitLinear constructed with bias=True actually applies its bias.
        y = F.linear(x_quant, w_quant, self.bias)
        return y