joaoalvarenga commited on
Commit
2d481d1
1 Parent(s): 82449aa

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +163 -0
README.md CHANGED
@@ -55,5 +55,168 @@ pipeline_tag: text-generation
55
 
56
  Heavily inspired by [Hivemind's GPT-J-6B with 8-bit weights](https://huggingface.co/hivemind/gpt-j-6B-8bit), this is a version of [bigscience/bloom](https://huggingface.co/bigscience/bloom) a ~176 billions parameters language model that you run and fine-tune with less memory.
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
 
 
55
 
56
  Heavily inspired by [Hivemind's GPT-J-6B with 8-bit weights](https://huggingface.co/hivemind/gpt-j-6B-8bit), this is a version of [bigscience/bloom](https://huggingface.co/bigscience/bloom) a ~176 billions parameters language model that you run and fine-tune with less memory.
57
 
58
+ Here, we also apply [LoRA (Low Rank Adpatars](https://arxiv.org/abs/2106.09685) to reduce model size. The original version takes ~353GB memory, this version takes ~180GB.
59
+
60
+ ### How to use
61
+
62
+ This model can be used by adapting Bloom original implementation:
63
+
64
+ ```python
65
+ import transformers
66
+ import torch
67
+ import torch.nn as nn
68
+ import torch.nn.functional as F
69
+
70
+ from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise
71
+ from typing import Tuple
72
+ from torch.cuda.amp import custom_fwd, custom_bwd
73
+
74
+ class FrozenBNBLinear(nn.Module):
75
+ def __init__(self, weight, absmax, code, bias=None):
76
+ assert isinstance(bias, nn.Parameter) or bias is None
77
+ super().__init__()
78
+ self.out_features, self.in_features = weight.shape
79
+ self.register_buffer("weight", weight.requires_grad_(False))
80
+ self.register_buffer("absmax", absmax.requires_grad_(False))
81
+ self.register_buffer("code", code.requires_grad_(False))
82
+ self.adapter = None
83
+ self.bias = bias
84
+
85
+ def forward(self, input):
86
+ output = DequantizeAndLinear.apply(input, self.weight, self.absmax, self.code, self.bias)
87
+ if self.adapter:
88
+ output += self.adapter(input)
89
+ return output
90
+
91
+ @classmethod
92
+ def from_linear(cls, linear: nn.Linear) -> "FrozenBNBLinear":
93
+ weights_int8, state = quantize_blockise_lowmemory(linear.weight)
94
+ return cls(weights_int8, *state, linear.bias)
95
+
96
+ def __repr__(self):
97
+ return f"{self.__class__.__name__}({self.in_features}, {self.out_features})"
98
+
99
+
100
+ class DequantizeAndLinear(torch.autograd.Function):
101
+ @staticmethod
102
+ @custom_fwd
103
+ def forward(ctx, input: torch.Tensor, weights_quantized: torch.ByteTensor,
104
+ absmax: torch.FloatTensor, code: torch.FloatTensor, bias: torch.FloatTensor):
105
+ weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
106
+ ctx.save_for_backward(input, weights_quantized, absmax, code)
107
+ ctx._has_bias = bias is not None
108
+ return F.linear(input, weights_deq, bias)
109
+
110
+ @staticmethod
111
+ @custom_bwd
112
+ def backward(ctx, grad_output: torch.Tensor):
113
+ assert not ctx.needs_input_grad[1] and not ctx.needs_input_grad[2] and not ctx.needs_input_grad[3]
114
+ input, weights_quantized, absmax, code = ctx.saved_tensors
115
+ # grad_output: [*batch, out_features]
116
+ weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
117
+ grad_input = grad_output @ weights_deq
118
+ grad_bias = grad_output.flatten(0, -2).sum(dim=0) if ctx._has_bias else None
119
+ return grad_input, None, None, None, grad_bias
120
+
121
+
122
+ class FrozenBNBEmbedding(nn.Module):
123
+ def __init__(self, weight, absmax, code):
124
+ super().__init__()
125
+ self.num_embeddings, self.embedding_dim = weight.shape
126
+ self.register_buffer("weight", weight.requires_grad_(False))
127
+ self.register_buffer("absmax", absmax.requires_grad_(False))
128
+ self.register_buffer("code", code.requires_grad_(False))
129
+ self.adapter = None
130
+
131
+ def forward(self, input, **kwargs):
132
+ with torch.no_grad():
133
+ # note: both quantuized weights and input indices are *not* differentiable
134
+ weight_deq = dequantize_blockwise(self.weight, absmax=self.absmax, code=self.code)
135
+ output = F.embedding(input, weight_deq, **kwargs)
136
+ if self.adapter:
137
+ output += self.adapter(input)
138
+ return output
139
+
140
+ @classmethod
141
+ def from_embedding(cls, embedding: nn.Embedding) -> "FrozenBNBEmbedding":
142
+ weights_int8, state = quantize_blockise_lowmemory(embedding.weight)
143
+ return cls(weights_int8, *state)
144
+
145
+ def __repr__(self):
146
+ return f"{self.__class__.__name__}({self.num_embeddings}, {self.embedding_dim})"
147
+
148
+
149
+ def quantize_blockise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2 ** 20):
150
+ assert chunk_size % 4096 == 0
151
+ code = None
152
+ chunks = []
153
+ absmaxes = []
154
+ flat_tensor = matrix.view(-1)
155
+ for i in range((matrix.numel() - 1) // chunk_size + 1):
156
+ input_chunk = flat_tensor[i * chunk_size: (i + 1) * chunk_size].clone()
157
+ quantized_chunk, (absmax_chunk, code) = quantize_blockwise(input_chunk, code=code)
158
+ chunks.append(quantized_chunk)
159
+ absmaxes.append(absmax_chunk)
160
+
161
+ matrix_i8 = torch.cat(chunks).reshape_as(matrix)
162
+ absmax = torch.cat(absmaxes)
163
+ return matrix_i8, (absmax, code)
164
+
165
+
166
+ def convert_to_int8(model):
167
+ """Convert linear and embedding modules to 8-bit with optional adapters"""
168
+ for module in list(model.modules()):
169
+ for name, child in module.named_children():
170
+ if isinstance(child, nn.Linear):
171
+ print(name, child)
172
+ setattr(
173
+ module,
174
+ name,
175
+ FrozenBNBLinear(
176
+ weight=torch.zeros(child.out_features, child.in_features, dtype=torch.uint8),
177
+ absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
178
+ code=torch.zeros(256),
179
+ bias=child.bias,
180
+ ),
181
+ )
182
+ elif isinstance(child, nn.Embedding):
183
+ setattr(
184
+ module,
185
+ name,
186
+ FrozenBNBEmbedding(
187
+ weight=torch.zeros(child.num_embeddings, child.embedding_dim, dtype=torch.uint8),
188
+ absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
189
+ code=torch.zeros(256),
190
+ )
191
+ )
192
+
193
+ class BloomBlock(transformers.models.bloom.modeling_bloom.BloomBlock):
194
+ def __init__(self, config, layer_number=None):
195
+ super().__init__(config, layer_number)
196
+
197
+ convert_to_int8(self.self_attention)
198
+ convert_to_int8(self.mlp)
199
+
200
+
201
+ class BloomModel(transformers.models.bloom.modeling_bloom.BloomModel):
202
+ def __init__(self, config):
203
+ super().__init__(config)
204
+ convert_to_int8(self)
205
+
206
+
207
+ class BloomForCausalLM(transformers.models.bloom.modeling_bloom.BloomForCausalLM):
208
+ def __init__(self, config):
209
+ super().__init__(config)
210
+ convert_to_int8(self)
211
+
212
+ transformers.models.bloom.modeling_bloom.BloomBlock = BloomBlock
213
+
214
+ model = BloomForCausalLM.from_pretrained('joaoalvarenga/bloom-8bit', low_cpu_mem_usage=True)
215
+ tokenizer = BloomTokenizerFast.from_pretrained('joaoalvarenga/bloom-8bit')
216
+
217
+ prompt = tokenizer("Given a table named salaries and columns id, created_at, salary, age. Creates a SQL to answer What is the average salary for 22 years old:", return_tensors='pt')
218
+ out = model.generate(**prompt, min_length=10, do_sample=True)
219
+ tokenizer.decode(out[0])```
220
+
221
 
222