import math

import torch
from torch import nn


def weight_quant(weight, num_bits=1):
    # Per-tensor ternary quantization: scale by the mean absolute value, round,
    # clamp to {-1, 0, +1}, then rescale ("fake quantization"). num_bits is kept
    # for API symmetry with activation_quant but is not used here.
    dtype = weight.dtype
    weight = weight.float()
    s = 1 / weight.abs().mean().clamp(min=1e-5)
    result = (weight * s).round().clamp(-1, 1) / s
    return result.type(dtype)
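
# Illustrative example (not part of the original source): for a weight tensor
# w = [0.3, -0.8, 0.05], the scale is s = 1 / mean(|w|) ≈ 2.609, so
# round(w * s).clamp(-1, 1) = [1, -1, 0] and the dequantized weight is roughly
# [0.383, -0.383, 0.0]. Every entry lands in {-1/s, 0, +1/s}, i.e. the ternary
# ("1.58-bit") weight representation this layer trains with.
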
def activation_quant(x, num_bits=8):
    # Per-token absmax quantization: scale each token so its largest magnitude
    # maps to Qp, round, clamp to the signed num_bits range, then rescale.
    dtype = x.dtype
    x = x.float()
    Qn = -(2 ** (num_bits - 1))
    Qp = 2 ** (num_bits - 1) - 1
    s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    result = (x * s).round().clamp(Qn, Qp) / s
    return result.type(dtype)
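
# Illustrative example (not part of the original source): for a token
# x = [0.4, -1.2, 2.0] with num_bits=8, Qn = -128, Qp = 127 and the per-token
# scale is s = 127 / 2.0 = 63.5; round(x * s) = [25, -76, 127], so the
# dequantized activation is roughly [0.394, -1.197, 2.0].
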
class BitLinear(nn.Linear):
    """
    RMSNorm is placed outside BitLinear.
    """

    def __init__(self, *args, weight_bits=1, input_bits=8, **kwargs):
        super().__init__(*args, **kwargs)
        self.weight_bits = weight_bits
        self.input_bits = input_bits

    def forward(self, input):
        # Straight-through estimator: the forward pass uses the quantized
        # input and weight, while gradients flow to the full-precision tensors.
        quant_input = input + (activation_quant(input, self.input_bits) - input).detach()
        quant_weight = self.weight + (weight_quant(self.weight, self.weight_bits) - self.weight).detach()

        out = nn.functional.linear(quant_input, quant_weight)
        if self.bias is not None:
            out += self.bias.view(1, -1).expand_as(out)

        return out
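

# Minimal usage sketch (illustrative, not part of the original file): the layer
# sizes and tensors below are arbitrary. It only shows that BitLinear is a
# drop-in replacement for nn.Linear and that gradients still reach the latent
# full-precision weights through the straight-through estimator.
if __name__ == "__main__":
    torch.manual_seed(0)
    layer = BitLinear(16, 4, bias=True)       # in_features=16, out_features=4
    x = torch.randn(2, 16, requires_grad=True)
    out = layer(x)
    out.sum().backward()
    print(out.shape)                          # torch.Size([2, 4])
    print(layer.weight.grad is not None)      # True
    print(x.grad is not None)                 # True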