mgoin committed on
Commit
d288bbc
1 Parent(s): 8b22b46

Update README.md

Files changed (1)
  1. README.md +1 -299
README.md CHANGED
@@ -10,305 +10,7 @@ tags:
10
  Mixtral-8x7B-Instruct-v0.1 quantized to FP8 weights and activations, ready for inference with vLLM >= 0.5.0.
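For inference, the checkpoint can be served with vLLM directly; a minimal sketch is below (the repository id is an assumption for illustration, and tensor parallelism depends on available GPU memory):

```python
# Minimal serving sketch with vLLM >= 0.5.0. The repository id below is an
# assumption for illustration; substitute the actual id of this checkpoint.
from vllm import LLM, SamplingParams

model_id = "neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8"  # hypothetical id

# vLLM picks up the fp8 quantization_config from the checkpoint automatically.
# tensor_parallel_size is hardware dependent and optional.
llm = LLM(model=model_id, tensor_parallel_size=2)
sampling = SamplingParams(temperature=0.0, max_tokens=64)

outputs = llm.generate(["[INST] What is your name? [/INST]"], sampling)
print(outputs[0].outputs[0].text)
```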
11
 
12
  ## Usage and Creation
13
- Produced using [AutoFP8 with calibration samples from ultrachat](https://github.com/neuralmagic/AutoFP8/blob/147fa4d9e1a90ef8a93f96fc7d9c33056ddc017a/example_dataset.py).
14
-
15
- Quantized using the script below:
16
-
17
- Command:
18
- ```bash
19
- python quantize.py --model-id mistralai/Mixtral-8x7B-Instruct-v0.1 --save-dir Mixtral-8x7B-Instruct-v0.1-FP8 --num-samples 512
20
- ```
21
-
22
- Script:
23
- ```python
24
- import argparse
25
- import gc
26
- import re
27
- from typing import Tuple
28
-
29
- import torch
30
- import torch.nn.functional as F
31
- import transformers
32
- from datasets import load_dataset
33
- from transformers import AutoModelForCausalLM, AutoTokenizer
34
-
35
-
36
- # HACK: override the dtype_byte_size function in transformers to support float8 types
37
- # Fix is posted upstream https://github.com/huggingface/transformers/pull/30488
38
- def new_dtype_byte_size(dtype):
39
- if dtype == torch.bool:
40
- return 1 / 8
41
- bit_search = re.search(r"[^\d](\d+)_?", str(dtype))
42
- if bit_search is None:
43
- raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
44
- bit_size = int(bit_search.groups()[0])
45
- return bit_size // 8
46
-
47
-
48
- transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size
49
-
50
-
51
- def cleanup_memory():
52
- gc.collect()
53
- torch.cuda.empty_cache()
54
-
55
-
56
- def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
57
- """Quantize a tensor using per-tensor static scaling factor.
58
-
59
- Args:
60
- tensor: The input tensor.
61
- """
62
- finfo = torch.finfo(torch.float8_e4m3fn)
63
- # Calculate the scale as dtype max divided by absmax.
64
- # Since .abs() creates a new tensor, we use aminmax to get
65
- # the min and max first and then calculate the absmax.
66
- if tensor.numel() == 0:
67
- # Deal with empty tensors (triggered by empty MoE experts)
68
- min_val, max_val = (
69
- torch.tensor(0.0, dtype=tensor.dtype),
70
- torch.tensor(1.0, dtype=tensor.dtype),
71
- )
72
- else:
73
- min_val, max_val = tensor.aminmax()
74
- amax = min_val.abs().max(max_val.abs())
75
- scale = finfo.max / amax.clamp(min=1e-12)
76
- # scale and clamp the tensor to bring it to
77
- # the representative range of float8 data type
78
- # (as default cast is unsaturated)
79
- qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max)
80
- # Return both float8 data and the inverse scale (as float),
81
- # as both are required as inputs to torch._scaled_mm
82
- qweight = qweight.to(torch.float8_e4m3fn)
83
- scale = scale.float().reciprocal()
84
- return qweight, scale
85
-
86
-
87
- def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
88
- cuda_compute_capability = torch.cuda.get_device_capability()
89
- if cuda_compute_capability >= (9, 0):
90
- output, _ = torch._scaled_mm(
91
- A,
92
- B.t(),
93
- out_dtype=out_dtype,
94
- scale_a=A_scale,
95
- scale_b=B_scale,
96
- bias=bias,
97
- )
98
- else:
99
- output = torch.nn.functional.linear(
100
- A.to(out_dtype) * A_scale,
101
- B.to(out_dtype) * B_scale.to(out_dtype),
102
- bias=bias,
103
- )
104
- return output
105
-
106
-
107
- class FP8StaticLinearQuantizer(torch.nn.Module):
108
- def __init__(self, qweight, weight_scale):
109
- super().__init__()
110
- self.weight = torch.nn.Parameter(qweight, requires_grad=False)
111
- self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
112
- self.act_scale = None
113
-
114
- def forward(self, x):
115
- # Dynamically quantize
116
- qinput, x_act_scale = per_tensor_quantize(x)
117
-
118
- # Update scale if needed.
119
- if self.act_scale is None:
120
- self.act_scale = torch.nn.Parameter(x_act_scale)
121
- elif x_act_scale > self.act_scale:
122
- self.act_scale = torch.nn.Parameter(x_act_scale)
123
-
124
- # Pass quantized to next layer so it has realistic data.
125
- output = fp8_gemm(
126
- A=qinput,
127
- A_scale=self.act_scale,
128
- B=self.weight,
129
- B_scale=self.weight_scale,
130
- bias=None,
131
- out_dtype=x.dtype,
132
- )
133
- return output
134
-
135
-
136
- class FP8StaticLinear(torch.nn.Module):
137
- def __init__(self, qweight, weight_scale, act_scale=0.0):
138
- super().__init__()
139
- self.weight = torch.nn.Parameter(qweight, requires_grad=False)
140
- self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
141
- self.act_scale = torch.nn.Parameter(act_scale, requires_grad=False)
142
-
143
- def per_tensor_quantize(
144
- self, tensor: torch.Tensor, inv_scale: float
145
- ) -> torch.Tensor:
146
- # Scale and clamp the tensor to bring it to
147
- # the representative range of float8 data type
148
- # (as default cast is unsaturated)
149
- finfo = torch.finfo(torch.float8_e4m3fn)
150
- qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
151
- return qweight.to(torch.float8_e4m3fn)
152
-
153
- def forward(self, x):
154
- qinput = self.per_tensor_quantize(x, inv_scale=self.act_scale)
155
- output = fp8_gemm(
156
- A=qinput,
157
- A_scale=self.act_scale,
158
- B=self.weight,
159
- B_scale=self.weight_scale,
160
- bias=None,
161
- out_dtype=x.dtype,
162
- )
163
- return output
164
-
165
-
166
- class FP8DynamicLinear(torch.nn.Module):
167
- def __init__(self, qweight, scale):
168
- super().__init__()
169
- self.weight = torch.nn.Parameter(qweight, requires_grad=False)
170
- self.weight_scale = torch.nn.Parameter(scale, requires_grad=False)
171
-
172
- def forward(self, x):
173
- qinput, x_scale = per_tensor_quantize(x)
174
- output = fp8_gemm(
175
- A=qinput,
176
- A_scale=x_scale,
177
- B=self.weight,
178
- B_scale=self.weight_scale,
179
- bias=None,
180
- out_dtype=x.dtype,
181
- )
182
- return output
183
-
184
-
185
- def replace_module(model, name, new_module):
186
- if "." in name:
187
- parent_name = name.rsplit(".", 1)[0]
188
- child_name = name[len(parent_name) + 1 :]
189
- parent = model.model.get_submodule(parent_name)
190
- else:
191
- parent_name = ""
192
- parent = model.model
193
- child_name = name
194
- setattr(parent, child_name, new_module)
195
-
196
-
197
- def quantize_weights(model):
198
- for name, linear in model.model.named_modules():
199
- if "gate" in name or not isinstance(linear, torch.nn.Linear):
200
- continue
201
- quant_weight, quant_scale = per_tensor_quantize(linear.weight)
202
- quant_linear = FP8DynamicLinear(quant_weight, quant_scale)
203
- replace_module(model, name, quant_linear)
204
- del linear
205
- cleanup_memory()
206
-
207
-
208
- def quantize_activations(model, calibration_tokens):
209
- # Replace layers with quantizer.
210
- for name, dynamic_quant_linear in model.model.named_modules():
211
- if "gate" in name or not isinstance(dynamic_quant_linear, FP8DynamicLinear):
212
- continue
213
- quantizer = FP8StaticLinearQuantizer(
214
- dynamic_quant_linear.weight, dynamic_quant_linear.weight_scale
215
- )
216
- replace_module(model, name, quantizer)
217
- del dynamic_quant_linear
218
- cleanup_memory()
219
-
220
- # Calibration.
221
- for row_idx in range(calibration_tokens.shape[0]):
222
- _ = model(calibration_tokens[row_idx].reshape(1, -1))
223
-
224
- # Replace quantizer with StaticLayer.
225
- for name, quantizer in model.model.named_modules():
226
- if "gate" in name or not isinstance(quantizer, FP8StaticLinearQuantizer):
227
- continue
228
- static_proj = FP8StaticLinear(
229
- quantizer.weight, quantizer.weight_scale, quantizer.act_scale
230
- )
231
- replace_module(model, name, static_proj)
232
- del quantizer
233
- cleanup_memory()
234
-
235
-
236
- def save_quantized_model(model, activation_scheme, save_dir):
237
- print(f"Saving the model to {save_dir}")
238
- static_q_dict = {
239
- "quantization_config": {
240
- "quant_method": "fp8",
241
- "activation_scheme": activation_scheme,
242
- }
243
- }
244
- model.config.update(static_q_dict)
245
- model.save_pretrained(save_dir)
246
- tokenizer.save_pretrained(save_dir)
247
-
248
-
249
- if __name__ == "__main__":
250
- parser = argparse.ArgumentParser()
251
- parser.add_argument("--model-id", type=str)
252
- parser.add_argument("--save-dir", type=str)
253
- parser.add_argument(
254
- "--activation-scheme", type=str, default="static", choices=["static", "dynamic"]
255
- )
256
- parser.add_argument("--num-samples", type=int, default=512)
257
- parser.add_argument("--max-seq-len", type=int, default=512)
258
- args = parser.parse_args()
259
-
260
- tokenizer = AutoTokenizer.from_pretrained(args.model_id)
261
- sample_input_tokens = tokenizer.apply_chat_template(
262
- [{"role": "user", "content": "What is your name?"}],
263
- add_generation_prompt=True,
264
- return_tensors="pt",
265
- ).to("cuda")
266
-
267
- ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
268
- ds = ds.shuffle(seed=42).select(range(args.num_samples))
269
- ds = ds.map(
270
- lambda batch: {
271
- "text": tokenizer.apply_chat_template(batch["messages"], tokenize=False)
272
- }
273
- )
274
- tokenizer.pad_token_id = tokenizer.eos_token_id
275
- calibration_tokens = tokenizer(
276
- ds["text"],
277
- return_tensors="pt",
278
- truncation=True,
279
- padding="max_length",
280
- max_length=args.max_seq_len,
281
- add_special_tokens=False,
282
- ).input_ids.to("cuda")
283
- print("Calibration tokens:", calibration_tokens.shape)
284
-
285
- # Load and test the model
286
- model = AutoModelForCausalLM.from_pretrained(
287
- args.model_id, torch_dtype="auto", device_map="auto"
288
- )
289
- print(model)
290
- output = model.generate(input_ids=sample_input_tokens, max_new_tokens=20)
291
- print("ORIGINAL:\n", tokenizer.decode(output[0]), "\n\n")
292
-
293
- # Quantize weights.
294
- quantize_weights(model)
295
- print(model)
296
- output = model.generate(input_ids=sample_input_tokens, max_new_tokens=20)
297
- print("WEIGHT QUANT:\n", tokenizer.decode(output[0]), "\n\n")
298
-
299
- if args.activation_scheme == "dynamic":
300
- print("Exporting model with static weights and dynamic activations")
301
- save_quantized_model(model, args.activation_scheme, args.save_dir)
302
- else:
303
- assert args.activation_scheme == "static"
304
- # Quantize activations.
305
- quantize_activations(model, calibration_tokens=calibration_tokens)
306
- output = model.generate(input_ids=sample_input_tokens, max_new_tokens=20)
307
- print("ACT QUANT:\n", tokenizer.decode(output[0]), "\n\n")
308
-
309
- print("Exporting model with static weights and static activations")
310
- save_quantized_model(model, args.activation_scheme, args.save_dir)
311
- ```
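The exported directory is a regular `save_pretrained` dump whose linear weights are FP8 with per-tensor `weight_scale`/`act_scale` parameters; the only extra metadata is the `quantization_config` block written by `save_quantized_model`, which is what vLLM keys on at load time. A quick sanity check of a saved directory (assuming the `--save-dir` from the command above):

```python
# Sanity-check the exported checkpoint's quantization metadata.
import json
from pathlib import Path

save_dir = Path("Mixtral-8x7B-Instruct-v0.1-FP8")  # --save-dir from the command above
config = json.loads((save_dir / "config.json").read_text())

# Written by save_quantized_model() in the script above.
print(config["quantization_config"])
# expected: {'quant_method': 'fp8', 'activation_scheme': 'static'}
```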
312
 
313
  ## Evaluation
314
 
 
10
  Mixtral-8x7B-Instruct-v0.1 quantized to FP8 weights and activations, ready for inference with vLLM >= 0.5.0.
11
 
12
  ## Usage and Creation
13
+ Produced using [AutoFP8 with calibration samples from ultrachat](https://github.com/neuralmagic/AutoFP8/blob/147fa4d9e1a90ef8a93f96fc7d9c33056ddc017a/examples/example_mixtral.py) with `block_sparse_moe.gate` layers kept at original precision.
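For orientation, the linked example boils down to the flow sketched below. This is not the exact script: class and argument names follow AutoFP8's documented interface at the time and may differ at the pinned commit, and the option for keeping the `block_sparse_moe.gate` layers at original precision is omitted because its exact spelling is version dependent.

```python
# Rough sketch of the AutoFP8 flow; see examples/example_mixtral.py at the
# pinned commit for the authoritative version.
from datasets import load_dataset
from transformers import AutoTokenizer
from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig

model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
save_dir = "Mixtral-8x7B-Instruct-v0.1-FP8"

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

# Calibration samples from ultrachat, rendered through the chat template.
ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
ds = ds.shuffle(seed=42).select(range(512))
texts = [tokenizer.apply_chat_template(row["messages"], tokenize=False) for row in ds]
examples = tokenizer(
    texts, padding=True, truncation=True, max_length=512, return_tensors="pt"
).to("cuda")

quantize_config = BaseQuantizeConfig(quant_method="fp8", activation_scheme="static")

model = AutoFP8ForCausalLM.from_pretrained(model_id, quantize_config=quantize_config)
model.quantize(examples)
model.save_quantized(save_dir)
```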
14
 
15
  ## Evaluation
16