Update README.md
Browse files
README.md
CHANGED
@@ -10,305 +10,7 @@ tags:
|
|
10 |
Mixtral-8x7B-Instruct-v0.1 quantized to FP8 weights and activations, ready for inference with vLLM >= 0.5.0.
|
11 |
|
12 |
## Usage and Creation
|
13 |
-
Produced using [AutoFP8 with calibration samples from ultrachat](https://github.com/neuralmagic/AutoFP8/blob/147fa4d9e1a90ef8a93f96fc7d9c33056ddc017a/
|
14 |
-
|
15 |
-
Quantized using the script below:
|
16 |
-
|
17 |
-
Command:
|
18 |
-
```bash
|
19 |
-
python quantize.py --model-id mistralai/Mixtral-8x7B-Instruct-v0.1 --save-dir Mixtral-8x7B-Instruct-v0.1-FP8 --num-samples 512
|
20 |
-
```
|
21 |
-
|
22 |
-
Script:
|
23 |
-
```python
|
24 |
-
import argparse
|
25 |
-
import gc
|
26 |
-
import re
|
27 |
-
from typing import Tuple
|
28 |
-
|
29 |
-
import torch
|
30 |
-
import torch.functional as F
|
31 |
-
import transformers
|
32 |
-
from datasets import load_dataset
|
33 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
34 |
-
|
35 |
-
|
36 |
-
# HACK: override the dtype_byte_size function in transformers to support float8 types
|
37 |
-
# Fix is posted upstream https://github.com/huggingface/transformers/pull/30488
|
38 |
-
def new_dtype_byte_size(dtype):
|
39 |
-
if dtype == torch.bool:
|
40 |
-
return 1 / 8
|
41 |
-
bit_search = re.search(r"[^\d](\d+)_?", str(dtype))
|
42 |
-
if bit_search is None:
|
43 |
-
raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
|
44 |
-
bit_size = int(bit_search.groups()[0])
|
45 |
-
return bit_size // 8
|
46 |
-
|
47 |
-
|
48 |
-
transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size
|
49 |
-
|
50 |
-
|
51 |
-
def cleanup_memory():
|
52 |
-
gc.collect()
|
53 |
-
torch.cuda.empty_cache()
|
54 |
-
|
55 |
-
|
56 |
-
def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
|
57 |
-
"""Quantize a tensor using per-tensor static scaling factor.
|
58 |
-
|
59 |
-
Args:
|
60 |
-
tensor: The input tensor.
|
61 |
-
"""
|
62 |
-
finfo = torch.finfo(torch.float8_e4m3fn)
|
63 |
-
# Calculate the scale as dtype max divided by absmax.
|
64 |
-
# Since .abs() creates a new tensor, we use aminmax to get
|
65 |
-
# the min and max first and then calculate the absmax.
|
66 |
-
if tensor.numel() == 0:
|
67 |
-
# Deal with empty tensors (triggered by empty MoE experts)
|
68 |
-
min_val, max_val = (
|
69 |
-
torch.tensor(0.0, dtype=tensor.dtype),
|
70 |
-
torch.tensor(1.0, dtype=tensor.dtype),
|
71 |
-
)
|
72 |
-
else:
|
73 |
-
min_val, max_val = tensor.aminmax()
|
74 |
-
amax = min_val.abs().max(max_val.abs())
|
75 |
-
scale = finfo.max / amax.clamp(min=1e-12)
|
76 |
-
# scale and clamp the tensor to bring it to
|
77 |
-
# the representative range of float8 data type
|
78 |
-
# (as default cast is unsaturated)
|
79 |
-
qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max)
|
80 |
-
# Return both float8 data and the inverse scale (as float),
|
81 |
-
# as both required as inputs to torch._scaled_mm
|
82 |
-
qweight = qweight.to(torch.float8_e4m3fn)
|
83 |
-
scale = scale.float().reciprocal()
|
84 |
-
return qweight, scale
|
85 |
-
|
86 |
-
|
87 |
-
def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
|
88 |
-
cuda_compute_capability = torch.cuda.get_device_capability()
|
89 |
-
if cuda_compute_capability >= (9, 0):
|
90 |
-
output, _ = torch._scaled_mm(
|
91 |
-
A,
|
92 |
-
B.t(),
|
93 |
-
out_dtype=out_dtype,
|
94 |
-
scale_a=A_scale,
|
95 |
-
scale_b=B_scale,
|
96 |
-
bias=bias,
|
97 |
-
)
|
98 |
-
else:
|
99 |
-
output = torch.nn.functional.linear(
|
100 |
-
A.to(out_dtype) * A_scale,
|
101 |
-
B.to(out_dtype) * B_scale.to(out_dtype),
|
102 |
-
bias=bias,
|
103 |
-
)
|
104 |
-
return output
|
105 |
-
|
106 |
-
|
107 |
-
class FP8StaticLinearQuantizer(torch.nn.Module):
|
108 |
-
def __init__(self, qweight, weight_scale):
|
109 |
-
super().__init__()
|
110 |
-
self.weight = torch.nn.Parameter(qweight, requires_grad=False)
|
111 |
-
self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
|
112 |
-
self.act_scale = None
|
113 |
-
|
114 |
-
def forward(self, x):
|
115 |
-
# Dynamically quantize
|
116 |
-
qinput, x_act_scale = per_tensor_quantize(x)
|
117 |
-
|
118 |
-
# Update scale if needed.
|
119 |
-
if self.act_scale is None:
|
120 |
-
self.act_scale = torch.nn.Parameter(x_act_scale)
|
121 |
-
elif x_act_scale > self.act_scale:
|
122 |
-
self.act_scale = torch.nn.Parameter(x_act_scale)
|
123 |
-
|
124 |
-
# Pass quantized to next layer so it has realistic data.
|
125 |
-
output = fp8_gemm(
|
126 |
-
A=qinput,
|
127 |
-
A_scale=self.act_scale,
|
128 |
-
B=self.weight,
|
129 |
-
B_scale=self.weight_scale,
|
130 |
-
bias=None,
|
131 |
-
out_dtype=x.dtype,
|
132 |
-
)
|
133 |
-
return output
|
134 |
-
|
135 |
-
|
136 |
-
class FP8StaticLinear(torch.nn.Module):
|
137 |
-
def __init__(self, qweight, weight_scale, act_scale=0.0):
|
138 |
-
super().__init__()
|
139 |
-
self.weight = torch.nn.Parameter(qweight, requires_grad=False)
|
140 |
-
self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
|
141 |
-
self.act_scale = torch.nn.Parameter(act_scale, requires_grad=False)
|
142 |
-
|
143 |
-
def per_tensor_quantize(
|
144 |
-
self, tensor: torch.Tensor, inv_scale: float
|
145 |
-
) -> torch.Tensor:
|
146 |
-
# Scale and clamp the tensor to bring it to
|
147 |
-
# the representative range of float8 data type
|
148 |
-
# (as default cast is unsaturated)
|
149 |
-
finfo = torch.finfo(torch.float8_e4m3fn)
|
150 |
-
qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
|
151 |
-
return qweight.to(torch.float8_e4m3fn)
|
152 |
-
|
153 |
-
def forward(self, x):
|
154 |
-
qinput = self.per_tensor_quantize(x, inv_scale=self.act_scale)
|
155 |
-
output = fp8_gemm(
|
156 |
-
A=qinput,
|
157 |
-
A_scale=self.act_scale,
|
158 |
-
B=self.weight,
|
159 |
-
B_scale=self.weight_scale,
|
160 |
-
bias=None,
|
161 |
-
out_dtype=x.dtype,
|
162 |
-
)
|
163 |
-
return output
|
164 |
-
|
165 |
-
|
166 |
-
class FP8DynamicLinear(torch.nn.Module):
|
167 |
-
def __init__(self, qweight, scale):
|
168 |
-
super().__init__()
|
169 |
-
self.weight = torch.nn.Parameter(qweight, requires_grad=False)
|
170 |
-
self.weight_scale = torch.nn.Parameter(scale, requires_grad=False)
|
171 |
-
|
172 |
-
def forward(self, x):
|
173 |
-
qinput, x_scale = per_tensor_quantize(x)
|
174 |
-
output = fp8_gemm(
|
175 |
-
A=qinput,
|
176 |
-
A_scale=x_scale,
|
177 |
-
B=self.weight,
|
178 |
-
B_scale=self.weight_scale,
|
179 |
-
bias=None,
|
180 |
-
out_dtype=x.dtype,
|
181 |
-
)
|
182 |
-
return output
|
183 |
-
|
184 |
-
|
185 |
-
def replace_module(model, name, new_module):
|
186 |
-
if "." in name:
|
187 |
-
parent_name = name.rsplit(".", 1)[0]
|
188 |
-
child_name = name[len(parent_name) + 1 :]
|
189 |
-
parent = model.model.get_submodule(parent_name)
|
190 |
-
else:
|
191 |
-
parent_name = ""
|
192 |
-
parent = model.model
|
193 |
-
child_name = name
|
194 |
-
setattr(parent, child_name, new_module)
|
195 |
-
|
196 |
-
|
197 |
-
def quantize_weights(model):
|
198 |
-
for name, linear in model.model.named_modules():
|
199 |
-
if "gate" in name or not isinstance(linear, torch.nn.Linear):
|
200 |
-
continue
|
201 |
-
quant_weight, quant_scale = per_tensor_quantize(linear.weight)
|
202 |
-
quant_linear = FP8DynamicLinear(quant_weight, quant_scale)
|
203 |
-
replace_module(model, name, quant_linear)
|
204 |
-
del linear
|
205 |
-
cleanup_memory()
|
206 |
-
|
207 |
-
|
208 |
-
def quantize_activations(model, calibration_tokens):
|
209 |
-
# Replace layers with quantizer.
|
210 |
-
for name, dynamic_quant_linear in model.model.named_modules():
|
211 |
-
if "gate" in name or not isinstance(dynamic_quant_linear, FP8DynamicLinear):
|
212 |
-
continue
|
213 |
-
quantizer = FP8StaticLinearQuantizer(
|
214 |
-
dynamic_quant_linear.weight, dynamic_quant_linear.weight_scale
|
215 |
-
)
|
216 |
-
replace_module(model, name, quantizer)
|
217 |
-
del dynamic_quant_linear
|
218 |
-
cleanup_memory()
|
219 |
-
|
220 |
-
# Calibration.
|
221 |
-
for row_idx in range(calibration_tokens.shape[0]):
|
222 |
-
_ = model(calibration_tokens[row_idx].reshape(1, -1))
|
223 |
-
|
224 |
-
# Replace quantizer with StaticLayer.
|
225 |
-
for name, quantizer in model.model.named_modules():
|
226 |
-
if "gate" in name or not isinstance(quantizer, FP8StaticLinearQuantizer):
|
227 |
-
continue
|
228 |
-
static_proj = FP8StaticLinear(
|
229 |
-
quantizer.weight, quantizer.weight_scale, quantizer.act_scale
|
230 |
-
)
|
231 |
-
replace_module(model, name, static_proj)
|
232 |
-
del quantizer
|
233 |
-
cleanup_memory()
|
234 |
-
|
235 |
-
|
236 |
-
def save_quantized_model(model, activation_scheme, save_dir):
|
237 |
-
print(f"Saving the model to {save_dir}")
|
238 |
-
static_q_dict = {
|
239 |
-
"quantization_config": {
|
240 |
-
"quant_method": "fp8",
|
241 |
-
"activation_scheme": activation_scheme,
|
242 |
-
}
|
243 |
-
}
|
244 |
-
model.config.update(static_q_dict)
|
245 |
-
model.save_pretrained(save_dir)
|
246 |
-
tokenizer.save_pretrained(save_dir)
|
247 |
-
|
248 |
-
|
249 |
-
if __name__ == "__main__":
|
250 |
-
parser = argparse.ArgumentParser()
|
251 |
-
parser.add_argument("--model-id", type=str)
|
252 |
-
parser.add_argument("--save-dir", type=str)
|
253 |
-
parser.add_argument(
|
254 |
-
"--activation-scheme", type=str, default="static", choices=["static", "dynamic"]
|
255 |
-
)
|
256 |
-
parser.add_argument("--num-samples", type=int, default=512)
|
257 |
-
parser.add_argument("--max-seq-len", type=int, default=512)
|
258 |
-
args = parser.parse_args()
|
259 |
-
|
260 |
-
tokenizer = AutoTokenizer.from_pretrained(args.model_id)
|
261 |
-
sample_input_tokens = tokenizer.apply_chat_template(
|
262 |
-
[{"role": "user", "content": "What is your name?"}],
|
263 |
-
add_generation_prompt=True,
|
264 |
-
return_tensors="pt",
|
265 |
-
).to("cuda")
|
266 |
-
|
267 |
-
ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
|
268 |
-
ds = ds.shuffle(seed=42).select(range(args.num_samples))
|
269 |
-
ds = ds.map(
|
270 |
-
lambda batch: {
|
271 |
-
"text": tokenizer.apply_chat_template(batch["messages"], tokenize=False)
|
272 |
-
}
|
273 |
-
)
|
274 |
-
tokenizer.pad_token_id = tokenizer.eos_token_id
|
275 |
-
calibration_tokens = tokenizer(
|
276 |
-
ds["text"],
|
277 |
-
return_tensors="pt",
|
278 |
-
truncation=True,
|
279 |
-
padding="max_length",
|
280 |
-
max_length=args.max_seq_len,
|
281 |
-
add_special_tokens=False,
|
282 |
-
).input_ids.to("cuda")
|
283 |
-
print("Calibration tokens:", calibration_tokens.shape)
|
284 |
-
|
285 |
-
# Load and test the model
|
286 |
-
model = AutoModelForCausalLM.from_pretrained(
|
287 |
-
args.model_id, torch_dtype="auto", device_map="auto"
|
288 |
-
)
|
289 |
-
print(model)
|
290 |
-
output = model.generate(input_ids=sample_input_tokens, max_new_tokens=20)
|
291 |
-
print("ORIGINAL:\n", tokenizer.decode(output[0]), "\n\n")
|
292 |
-
|
293 |
-
# Quantize weights.
|
294 |
-
quantize_weights(model)
|
295 |
-
print(model)
|
296 |
-
output = model.generate(input_ids=sample_input_tokens, max_new_tokens=20)
|
297 |
-
print("WEIGHT QUANT:\n", tokenizer.decode(output[0]), "\n\n")
|
298 |
-
|
299 |
-
if args.activation_scheme in "dynamic":
|
300 |
-
print("Exporting model with static weights and dynamic activations")
|
301 |
-
save_quantized_model(model, args.activation_scheme, args.save_dir)
|
302 |
-
else:
|
303 |
-
assert args.activation_scheme in "static"
|
304 |
-
# Quantize activations.
|
305 |
-
quantize_activations(model, calibration_tokens=calibration_tokens)
|
306 |
-
output = model.generate(input_ids=sample_input_tokens, max_new_tokens=20)
|
307 |
-
print("ACT QUANT:\n", tokenizer.decode(output[0]), "\n\n")
|
308 |
-
|
309 |
-
print("Exporting model with static weights and static activations")
|
310 |
-
save_quantized_model(model, args.activation_scheme, args.save_dir)
|
311 |
-
```
|
312 |
|
313 |
## Evaluation
|
314 |
|
|
|
10 |
Mixtral-8x7B-Instruct-v0.1 quantized to FP8 weights and activations, ready for inference with vLLM >= 0.5.0.
|
11 |
|
12 |
## Usage and Creation
|
13 |
+
Produced using [AutoFP8 with calibration samples from ultrachat](https://github.com/neuralmagic/AutoFP8/blob/147fa4d9e1a90ef8a93f96fc7d9c33056ddc017a/examples/example_mixtral.py) with `block_sparse_moe.gate` layers kept at original precision.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
## Evaluation
|
16 |
|