import torch.nn as nn
import torch
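# Toy FP8 (e4m3fnuz) quantization demo: quantize/dequantize helpers plus QuantLinear
# and QuantConv2d modules that apply a SmoothQuant input multiplier and expose two
# forward paths, a quantize-dequantize ("fake quant") path and an accelerated
# quantized-op path that folds the SmoothQuant multiplier into the input scale.
# The two paths are checked against each other at the bottom of the file.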
def quantize_fp8(tensor: torch.Tensor, scale: torch.Tensor):
    # Scale, clamp to the e4m3fnuz representable range [-240, 240], round to FP8,
    # then cast back to the original dtype so the result holds the FP8 code values.
    dtype = tensor.dtype
    clamp_min, clamp_max = torch.tensor(-240., dtype=dtype), torch.tensor(240., dtype=dtype)
    quant_tensor = torch.clamp((tensor / scale), clamp_min, clamp_max).to(torch.float8_e4m3fnuz).to(dtype)
    return quant_tensor
def dequantize_fp8(tensor: torch.Tensor, scale: torch.Tensor):
    # Dequantization is just multiplication by the scale (the zero-point is 0 here).
    return tensor * scale
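# Round-trip sanity check with arbitrary probe values: e4m3fnuz has 3 mantissa bits,
# so quantize -> dequantize reproduces the input only up to the FP8 step size at each
# magnitude (e.g. a step of 8 around 100).
_probe = torch.tensor([0.1234, 1.5, -100.0])
_unit_scale = torch.tensor(1.0)
assert torch.allclose(dequantize_fp8(quantize_fp8(_probe, _unit_scale), _unit_scale), _probe, atol=8.0)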
class QuantLinear(nn.Module):
    def __init__(self, in_ch, out_ch, quant_param):
        super().__init__()
        # Per-input-channel SmoothQuant multiplier applied to the activations.
        mul_factor = torch.tensor(quant_param['smoothquant_mul']).view(quant_param['smoothquant_mul_shape'])
        self.register_buffer('mul_factor', mul_factor)
        self.linear = nn.Linear(in_ch, out_ch)
        # Per-output-channel weight scale, per-tensor input scale and zero-point.
        weight_scale = torch.tensor(quant_param['weight_scale']).view(quant_param['weight_scale_shape'])
        # weight_zp = torch.tensor(quant_param['weight_zp']).view(quant_param['weight_zp_shape'])
        # assert quant_param['weight_zp_dtype'] == 'torch.float8_e4m3fnuz', f"Weight Zero-Point dtype should be 'torch.float8_e4m3fnuz', found: {quant_param['weight_zp_dtype']}"
        input_scale = torch.tensor(quant_param['input_scale']).view(quant_param['input_scale_shape'])
        input_zp = torch.tensor(quant_param['input_zp']).view(quant_param['input_zp_shape'])
        assert quant_param['input_zp_dtype'] == 'torch.float8_e4m3fnuz', f"Input Zero-Point dtype should be 'torch.float8_e4m3fnuz', found: {quant_param['input_zp_dtype']}"
        self.register_buffer('weight_scale', weight_scale)
        # self.register_buffer('weight_zp', weight_zp)
        self.register_buffer('input_scale', input_scale)
        self.register_buffer('input_zp', input_zp)
    # I.e., "fake quantization": quantize and immediately dequantize both operands,
    # then run the layer in the original dtype.
    def qdq_forward(self, x):
        scaled_x = x * self.mul_factor
        quant_weight = quantize_fp8(self.linear.weight, self.weight_scale)
        quant_input = quantize_fp8(scaled_x, self.input_scale)
        dequantized_weight = dequantize_fp8(quant_weight, self.weight_scale)
        dequantized_input = dequantize_fp8(quant_input, self.input_scale)
        out = torch.nn.functional.linear(dequantized_input, dequantized_weight, self.linear.bias)
        return out
    # Accelerated version: compute directly on the FP8 codes and dequantize once at the end.
    def qop_forward(self, x):
        quant_weight = quantize_fp8(self.linear.weight, self.weight_scale).to(torch.float8_e4m3fnuz)
        # Fuse the SmoothQuant multiplier into the input scale; this can be computed offline.
        fused_input_scale = self.input_scale / self.mul_factor
        quant_input = quantize_fp8(x, fused_input_scale).to(torch.float8_e4m3fnuz)
        # Run the matmul on the FP8 codes in FP32 (emulating an FP8 GEMM with FP32 accumulation).
        quant_output = torch.nn.functional.linear(quant_input.to(torch.float32), quant_weight.to(torch.float32), None)
        # Dequantize the accumulator with the combined (weight * input) scale,
        # broadcast over the last (output-feature) dimension, then add the bias.
        combined_scale = self.weight_scale * self.input_scale
        output = dequantize_fp8(quant_output, combined_scale.view([1] * (quant_output.ndim - 1) + [combined_scale.nelement()]))
        output += self.linear.bias
        return output
    def forward(self, x, qop=False):
        if qop:
            return self.qop_forward(x)
        else:
            return self.qdq_forward(x)
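# Why the scale fusion in qop_forward matches the QDQ path (illustrative check with
# arbitrary hand-picked values): quantizing x * mul_factor with input_scale rounds
# (x * mul_factor) / input_scale to FP8, and dividing x by the fused scale
# input_scale / mul_factor rounds the same quantity, so both paths produce the same
# FP8 codes (up to FP32 rounding at quantization boundaries).
_x = torch.tensor([[0.3, -0.7, 1.2, 0.05]])
_mul = torch.tensor([0.5, 1.0, 2.0, 4.0])
_scale = torch.tensor(0.05)
assert torch.allclose(quantize_fp8(_x * _mul, _scale), quantize_fp8(_x, _scale / _mul))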
class QuantConv2d(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size, quant_param):
        super().__init__()
        # Per-input-channel SmoothQuant multiplier applied to the activations.
        mul_factor = torch.tensor(quant_param['smoothquant_mul']).view(quant_param['smoothquant_mul_shape'])
        self.register_buffer('mul_factor', mul_factor)
        self.conv2d = nn.Conv2d(in_ch, out_ch, kernel_size)
        # Per-output-channel weight scale, per-tensor input scale and zero-point.
        weight_scale = torch.tensor(quant_param['weight_scale']).view(quant_param['weight_scale_shape'])
        input_scale = torch.tensor(quant_param['input_scale']).view(quant_param['input_scale_shape'])
        input_zp = torch.tensor(quant_param['input_zp']).view(quant_param['input_zp_shape'])
        assert quant_param['input_zp_dtype'] == 'torch.float8_e4m3fnuz', f"Input Zero-Point dtype should be 'torch.float8_e4m3fnuz', found: {quant_param['input_zp_dtype']}"
        self.register_buffer('weight_scale', weight_scale)
        self.register_buffer('input_scale', input_scale)
        self.register_buffer('input_zp', input_zp)
    # I.e., "fake quantization"
    def qdq_forward(self, x):
        scaled_x = x * self.mul_factor
        quant_weight = quantize_fp8(self.conv2d.weight, self.weight_scale)
        quant_input = quantize_fp8(scaled_x, self.input_scale)
        dequantized_weight = dequantize_fp8(quant_weight, self.weight_scale)
        dequantized_input = dequantize_fp8(quant_input, self.input_scale)
        out = torch.nn.functional.conv2d(dequantized_input, dequantized_weight, self.conv2d.bias)
        return out
    # Accelerated version
    def qop_forward(self, x):
        quant_weight = quantize_fp8(self.conv2d.weight, self.weight_scale).to(torch.float8_e4m3fnuz)
        # Fuse the SmoothQuant multiplier into the input scale; this can be computed offline.
        fused_input_scale = self.input_scale / self.mul_factor
        quant_input = quantize_fp8(x, fused_input_scale).to(torch.float8_e4m3fnuz)
        # Run the convolution on the FP8 codes in FP32 (emulating an FP8 conv kernel with FP32 accumulation).
        quant_output = torch.nn.functional.conv2d(quant_input.to(torch.float32), quant_weight.to(torch.float32), None)
        # Dequantize with the combined (weight * input) scale; for NCHW outputs the
        # per-output-channel scale and the bias both broadcast over dim 1.
        combined_scale = self.weight_scale * self.input_scale
        output = dequantize_fp8(quant_output, combined_scale.view([1, combined_scale.nelement()] + [1] * (quant_output.ndim - 2)))
        output += self.conv2d.bias.view([1, self.conv2d.bias.nelement()] + [1] * (quant_output.ndim - 2))
        return output
    def forward(self, x, qop=False):
        if qop:
            return self.qop_forward(x)
        else:
            return self.qdq_forward(x)
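# Smoke test: build a QuantLinear and a QuantConv2d from toy, randomly generated
# quantization parameters and check that the QDQ and quantized-op paths agree.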
torch.manual_seed(0)
batch_size = 1
seq_len = 11
hidden_size = 21
output_size = 36
shape = 5
query = 2.*torch.rand((batch_size, seq_len, hidden_size)) - 1.
conv_input = 2.*torch.rand((batch_size, hidden_size, shape, shape)) - 1.
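# Example quantization parameters, shaped as a calibration step might produce them:
# per-input-channel SmoothQuant multipliers, a per-tensor input scale and zero-point
# (zero-point 0, declared as float8_e4m3fnuz), and per-output-channel weight scales.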
quant_params = {
    "quant_linear": {
        "smoothquant_mul": torch.randn(hidden_size).abs(),
        "smoothquant_mul_shape": [1, 1, hidden_size],
        "input_scale": torch.max(torch.abs(query)) / 240.,
        "input_scale_shape": [],
        "input_zp": 0.0,
        "input_zp_shape": [],
        "input_zp_dtype": "torch.float8_e4m3fnuz",
        "weight_scale": torch.randn(output_size).abs(),
        "weight_scale_shape": [output_size, 1]
    },
    "quant_conv": {
        "smoothquant_mul": torch.randn(hidden_size).abs(),
        "smoothquant_mul_shape": [1, hidden_size, 1, 1],
        "input_scale": torch.max(torch.abs(query)) / 240.,
        "input_scale_shape": [],
        "input_zp": 0.0,
        "input_zp_shape": [],
        "input_zp_dtype": "torch.float8_e4m3fnuz",
        "weight_scale": torch.randn(output_size).abs(),
        "weight_scale_shape": [output_size, 1, 1, 1]
    }
}
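# The two paths should agree: the Linear check uses default allclose tolerances, while
# the Conv2d check allows a small absolute tolerance since each output element
# accumulates many more FP32 products, so rounding differences between the paths grow.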
qlinear = QuantLinear(hidden_size, output_size, quant_params['quant_linear'])
o = qlinear(query)
q_o = qlinear(query, qop=True)
assert torch.allclose(o, q_o)
qconv = QuantConv2d(hidden_size, output_size, shape, quant_params['quant_conv'])
o = qconv(conv_input)
q_o = qconv(conv_input, qop=True)
assert torch.allclose(o, q_o, atol=1e-6)