File size: 5,838 Bytes
9ff47b2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import transformers
import torch
import torch.nn.functional as F
from torch import nn
from torch.cuda.amp import custom_fwd, custom_bwd
from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise
class FrozenBNBLinear(nn.Module):
    """Linear layer backed by frozen, blockwise-quantized 8-bit weights.

    The quantized ``weight`` and its quantization state (``absmax``, ``code``)
    are registered as buffers with gradients disabled. ``adapter`` starts out
    as ``None``; when a truthy adapter is assigned, its output for the same
    input is added to the layer output.
    """

    def __init__(self, weight, absmax, code, bias=None):
        """Store the quantized weight, its quantization state and an optional bias.

        Args:
            weight: 2-D quantized weight of shape ``(out_features, in_features)``.
            absmax: quantization state passed to ``dequantize_blockwise``.
            code: quantization state passed to ``dequantize_blockwise``.
            bias: ``nn.Parameter`` or ``None``.
        """
        assert bias is None or isinstance(bias, nn.Parameter)
        super().__init__()
        out_dim, in_dim = weight.shape
        self.out_features = out_dim
        self.in_features = in_dim
        # All three tensors are frozen: buffers, never trained.
        for buffer_name, tensor in (("weight", weight), ("absmax", absmax), ("code", code)):
            self.register_buffer(buffer_name, tensor.requires_grad_(False))
        self.adapter = None
        self.bias = bias

    def forward(self, input):
        """Apply the dequantized linear map and add the adapter output, if any."""
        result = DequantizeAndLinear.apply(input, self.weight, self.absmax, self.code, self.bias)
        if not self.adapter:
            return result
        result += self.adapter(input)
        return result

    @classmethod
    def from_linear(cls, linear: nn.Linear) -> "FrozenBNBLinear":
        """Build a frozen 8-bit layer by quantizing the weight of ``linear``."""
        quantized, (absmax, code) = quantize_blockise_lowmemory(linear.weight)
        return cls(quantized, absmax, code, linear.bias)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.in_features}, {self.out_features})"
class DequantizeAndLinear(torch.autograd.Function):
    """Autograd function that dequantizes 8-bit weights and applies ``F.linear``.

    Gradients flow only to ``input`` and ``bias``; the quantized weights and
    their quantization state (``absmax``, ``code``) are treated as constants.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, input: torch.Tensor, weights_quantized: torch.ByteTensor,
                absmax: torch.FloatTensor, code: torch.FloatTensor, bias: torch.FloatTensor):
        """Return ``F.linear(input, dequantized_weights, bias)``.

        Args:
            ctx: autograd context.
            input: activations, ``[*batch, in_features]``.
            weights_quantized: quantized weight matrix.
            absmax: quantization state for ``dequantize_blockwise``.
            code: quantization state for ``dequantize_blockwise``.
            bias: optional bias; may be ``None``.
        """
        weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
        # Backward never reads `input` (no weight gradient is computed), so it
        # is deliberately not saved: keeping it would pin every activation in
        # memory until the backward pass for no benefit.
        ctx.save_for_backward(weights_quantized, absmax, code)
        ctx._has_bias = bias is not None
        return F.linear(input, weights_deq, bias)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output: torch.Tensor):
        """Return gradients for ``(input, weights_quantized, absmax, code, bias)``.

        Only the first and last entries can be non-``None``.
        """
        # The quantized weights and their state are frozen by design.
        assert not ctx.needs_input_grad[1] and not ctx.needs_input_grad[2] and not ctx.needs_input_grad[3]
        weights_quantized, absmax, code = ctx.saved_tensors

        # grad_output: [*batch, out_features]
        grad_input = None
        if ctx.needs_input_grad[0]:
            # Dequantize again instead of caching the full-precision weights.
            weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
            grad_input = grad_output @ weights_deq

        grad_bias = None
        if ctx._has_bias and ctx.needs_input_grad[4]:
            # reshape (rather than flatten(0, -2)) also handles a 1-D grad_output,
            # which occurs when the layer was called on an unbatched vector.
            grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(dim=0)

        return grad_input, None, None, None, grad_bias
class FrozenBNBEmbedding(nn.Module):
    """Embedding layer backed by a frozen, blockwise-quantized 8-bit table.

    The quantized ``weight`` and its quantization state (``absmax``, ``code``)
    are registered as buffers with gradients disabled. ``adapter`` starts out
    as ``None``; when a truthy adapter is assigned, its output for the same
    indices is added to the looked-up embeddings.
    """

    def __init__(self, weight, absmax, code):
        """Store the quantized table of shape ``(num_embeddings, embedding_dim)``."""
        super().__init__()
        rows, cols = weight.shape
        self.num_embeddings = rows
        self.embedding_dim = cols
        # All three tensors are frozen: buffers, never trained.
        for buffer_name, tensor in (("weight", weight), ("absmax", absmax), ("code", code)):
            self.register_buffer(buffer_name, tensor.requires_grad_(False))
        self.adapter = None

    def forward(self, input, **kwargs):
        """Look up ``input`` in the dequantized table; extra kwargs go to ``F.embedding``."""
        # Neither the quantized weights nor the input indices are
        # differentiable, so the lookup runs without autograd tracking.
        with torch.no_grad():
            table = dequantize_blockwise(self.weight, absmax=self.absmax, code=self.code)
            embedded = F.embedding(input, table, **kwargs)
        if not self.adapter:
            return embedded
        # The adapter runs outside no_grad so that it can receive gradients.
        embedded += self.adapter(input)
        return embedded

    @classmethod
    def from_embedding(cls, embedding: nn.Embedding) -> "FrozenBNBEmbedding":
        """Build a frozen 8-bit embedding by quantizing the weight of ``embedding``."""
        quantized, (absmax, code) = quantize_blockise_lowmemory(embedding.weight)
        return cls(quantized, absmax, code)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.num_embeddings}, {self.embedding_dim})"
def quantize_blockise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2 ** 20):
    """Blockwise-quantize ``matrix`` one chunk at a time.

    The tensor is flattened and passed to ``quantize_blockwise`` in slices of
    ``chunk_size`` elements, so only one chunk is copied at any moment. The
    ``code`` returned for the first chunk is passed back in for every
    following chunk.

    Args:
        matrix: tensor to quantize; it may be non-contiguous.
        chunk_size: number of elements quantized per call. It must be a
            multiple of 4096 (presumably the quantization block size, so that
            chunk boundaries line up with block boundaries; confirm against
            the installed bitsandbytes version).

    Returns:
        ``(matrix_i8, (absmax, code))``: the quantized tensor reshaped like
        ``matrix``, the concatenated per-chunk ``absmax`` values, and the
        ``code`` tensor.
    """
    assert chunk_size % 4096 == 0
    code = None
    chunks = []
    absmaxes = []
    # reshape(-1) is a plain view for contiguous input (same as view(-1)), but
    # unlike view it does not raise for non-contiguous tensors.
    flat_tensor = matrix.reshape(-1)
    for start in range(0, matrix.numel(), chunk_size):
        input_chunk = flat_tensor[start: start + chunk_size].clone()
        quantized_chunk, (absmax_chunk, code) = quantize_blockwise(input_chunk, code=code)
        chunks.append(quantized_chunk)
        absmaxes.append(absmax_chunk)
    matrix_i8 = torch.cat(chunks).reshape_as(matrix)
    absmax = torch.cat(absmaxes)
    return matrix_i8, (absmax, code)
def convert_to_int8(model):
    """Convert linear and embedding modules to 8-bit with optional adapters.

    Every ``nn.Linear`` / ``nn.Embedding`` that is a child of some module in
    ``model`` is replaced in place by a ``FrozenBNBLinear`` /
    ``FrozenBNBEmbedding`` of matching shape. The replacements hold
    zero-filled tensors only; the original weights are not quantized here
    (presumably real values are loaded afterwards, e.g. from a checkpoint).
    A linear layer's bias parameter is carried over unchanged.
    """

    def _blank_state(layer):
        # One absmax entry per 4096 weight elements (rounded up) and a
        # 256-entry code table, both zero-filled.
        block_count = (layer.weight.numel() - 1) // 4096 + 1
        return torch.zeros(block_count), torch.zeros(256)

    # Snapshot the module list first: children are swapped while walking.
    for parent in list(model.modules()):
        for child_name, child in parent.named_children():
            if isinstance(child, nn.Linear):
                print(child_name, child)
                blank_weight = torch.zeros(child.out_features, child.in_features, dtype=torch.uint8)
                blank_absmax, blank_code = _blank_state(child)
                replacement = FrozenBNBLinear(
                    weight=blank_weight,
                    absmax=blank_absmax,
                    code=blank_code,
                    bias=child.bias,
                )
                setattr(parent, child_name, replacement)
            elif isinstance(child, nn.Embedding):
                blank_weight = torch.zeros(child.num_embeddings, child.embedding_dim, dtype=torch.uint8)
                blank_absmax, blank_code = _blank_state(child)
                replacement = FrozenBNBEmbedding(
                    weight=blank_weight,
                    absmax=blank_absmax,
                    code=blank_code,
                )
                setattr(parent, child_name, replacement)
class GPTJBlock(transformers.models.gptj.modeling_gptj.GPTJBlock):
    """GPT-J block whose attention and MLP submodules are converted to 8-bit layers."""

    def __init__(self, config):
        super().__init__(config)
        # Swap the linear/embedding layers of both submodules for 8-bit ones.
        for submodule in (self.attn, self.mlp):
            convert_to_int8(submodule)
class GPTJModel(transformers.models.gptj.modeling_gptj.GPTJModel):
    """GPT-J model whose linear and embedding layers are converted to 8-bit layers."""

    def __init__(self, config):
        # Build the regular model first, then swap its layers in place.
        super().__init__(config)
        convert_to_int8(self)
class GPTJForCausalLM(transformers.models.gptj.modeling_gptj.GPTJForCausalLM):
    """GPT-J causal LM whose linear and embedding layers are converted to 8-bit layers."""

    def __init__(self, config):
        # Build the regular model first, then swap its layers in place.
        super().__init__(config)
        convert_to_int8(self)