File size: 8,548 Bytes
d5dfd96 049c65f d5dfd96 eb5a5f6 d5dfd96 eb5a5f6 d5dfd96 eb5a5f6 d5dfd96 eb5a5f6 3f5851c 4024f9d dca9b6e eb5a5f6 161df88 eb5a5f6 d5dfd96 eb5a5f6 d5dfd96 eb5a5f6 d5dfd96 eb5a5f6 d5dfd96 eb5a5f6 d5dfd96 eb5a5f6 3f5851c 4024f9d eb5a5f6 4024f9d dca9b6e eb5a5f6 161df88 eb5a5f6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
import torch.nn as nn
import torch
def quantize(tensor, scale, zero_point, is_asym=False):
    """Quantize ``tensor`` onto an 8-bit integer grid (values kept as floats).

    Computes ``clamp(round(tensor / scale + zero_point), lo, hi)``, where the
    clamp range is ``[0, 255]`` when ``is_asym`` is true and ``[-128, 127]``
    otherwise. Because ``/`` is true division, the result has a floating-point
    dtype; callers cast to an integer dtype themselves when they need one.

    Args:
        tensor: Values to quantize.
        scale: Quantization step; must broadcast against ``tensor``.
        zero_point: Offset added after scaling; must broadcast against ``tensor``.
        is_asym: Use the unsigned ``[0, 255]`` range instead of the signed
            ``[-128, 127]`` range.

    Returns:
        The rounded and clamped tensor.
    """
    # Python-float bounds are device-agnostic: tensor-valued clamp bounds created
    # here would live on the CPU and clash with CUDA inputs. They also avoid two
    # tensor allocations per call.
    clamp_min, clamp_max = (0.0, 255.0) if is_asym else (-128.0, 127.0)
    return torch.clamp(torch.round(tensor / scale + zero_point), clamp_min, clamp_max)
def dequantize(tensor, scale, zero_point):
    """Map quantized values back to real values.

    Computes ``(tensor - zero_point) * scale``. ``scale`` and ``zero_point``
    may be Python numbers or tensors that broadcast against ``tensor``.
    """
    centered = tensor - zero_point
    return centered * scale
class QuantLinear(nn.Module):
    """Linear layer with a SmoothQuant multiplier and 8-bit quantization.

    ``quant_param`` supplies a flat value and a target shape for each buffer:
    ``smoothquant_mul``, ``weight_scale``, ``weight_zp``, ``input_scale`` and
    ``input_zp``, each paired with a ``<name>_shape`` entry used in ``.view``.

    Two forward paths are provided:

    * ``qdq_forward`` -- "fake quantization": weights are quantized to the
      unsigned range and inputs to the signed range, both are dequantized
      straight away, and a float ``linear`` is run.
    * ``qop_forward`` -- emulates an integer kernel: signed 8-bit weights and
      inputs, an int32 accumulator, and a weight-zero-point correction.
    """

    def __init__(self, in_ch, out_ch, quant_param):
        super().__init__()

        def _from_param(value_key, shape_key):
            # Build a tensor from the flat value and view it with its stored shape.
            return torch.tensor(quant_param[value_key]).view(quant_param[shape_key])

        self.register_buffer('mul_factor', _from_param('smoothquant_mul', 'smoothquant_mul_shape'))
        self.linear = nn.Linear(in_ch, out_ch)
        scale_w = _from_param('weight_scale', 'weight_scale_shape')
        zp_w = _from_param('weight_zp', 'weight_zp_shape')
        scale_in = _from_param('input_scale', 'input_scale_shape')
        zp_in = _from_param('input_zp', 'input_zp_shape')
        for buf_name, buf in (
            ('weight_scale', scale_w),
            ('weight_zp', zp_w),
            ('input_scale', scale_in),
            ('input_zp', zp_in),
        ):
            self.register_buffer(buf_name, buf)

    def qdq_forward(self, x):
        """Fake-quantized ("quantize-dequantize") forward pass."""
        smoothed = x * self.mul_factor
        w_q = quantize(self.linear.weight, self.weight_scale, self.weight_zp, is_asym=True)
        x_q = quantize(smoothed, self.input_scale, self.input_zp, is_asym=False)
        w_dq = dequantize(w_q, self.weight_scale, self.weight_zp)
        x_dq = dequantize(x_q, self.input_scale, self.input_zp)
        return torch.nn.functional.linear(x_dq, w_dq, self.linear.bias)

    def qop_forward(self, x):
        """Integer-kernel emulation of the forward pass.

        When the weight zero point ``zp_w`` is non-zero, the integer dot product
        needs a correction:
        ``sum(qx * (qw - zp_w)) = sum(qx * qw) - zp_w * sum(qx)``,
        where the sums run over the dot-product (last) dimension. The remaining
        reshaping only lines the terms up for broadcasting.

        NOTE(review): no input-zero-point correction is applied, so this can
        only agree with ``qdq_forward`` when ``input_zp`` is zero.
        """
        # Shift the unsigned weight zero point into the signed int8 range
        # (could be precomputed offline).
        zp_w_s8 = (self.weight_zp - 128).to(torch.int8).to(torch.float32)
        w_q = quantize(self.linear.weight, self.weight_scale, zp_w_s8, is_asym=False).to(torch.int8)
        # Fold the SmoothQuant multiplier into the input scale
        # (could be precomputed offline).
        x_scale = self.input_scale / self.mul_factor
        x_q = quantize(x, x_scale, self.input_zp, is_asym=False).to(torch.int8)
        # The matmul runs in float32 and is then cast to an int32 accumulator.
        # int8 products are exact in float32; sums stay exact while below 2**24.
        acc = torch.nn.functional.linear(x_q.to(torch.float32), w_q.to(torch.float32), None).to(torch.int32)
        lead = [1] * (x_q.ndim - 1)
        zp_row = zp_w_s8.to(torch.int8).view(lead + [self.weight_zp.nelement()])
        row_sums = torch.sum(x_q, dim=-1, keepdim=True).to(torch.int32)
        acc = acc - row_sums * zp_row
        out_scale = self.weight_scale * self.input_scale
        out = dequantize(acc, out_scale.view([1] * (acc.ndim - 1) + [out_scale.nelement()]), 0.0)
        out += self.linear.bias
        return out

    def forward(self, x, qop=False):
        """Run ``qop_forward`` when ``qop`` is true, otherwise ``qdq_forward``."""
        return self.qop_forward(x) if qop else self.qdq_forward(x)
class QuantConv2d(nn.Module):
    """Conv2d layer with a SmoothQuant multiplier and 8-bit quantization.

    ``quant_param`` supplies a flat value and a target shape for each buffer:
    ``smoothquant_mul``, ``weight_scale``, ``weight_zp``, ``input_scale`` and
    ``input_zp``, each paired with a ``<name>_shape`` entry used in ``.view``.
    The wrapped ``nn.Conv2d`` is built with only ``in_ch``, ``out_ch`` and
    ``kernel_size``. Both forward paths call ``conv2d`` without stride or padding
    arguments, so they use the functional defaults.

    * ``qdq_forward`` -- "fake quantization": quantize then dequantize
      weights (unsigned range) and inputs (signed range), then a float conv.
    * ``qop_forward`` -- emulates an integer kernel with an int32 accumulator
      and a weight-zero-point correction.
    """

    def __init__(self, in_ch, out_ch, kernel_size, quant_param):
        super().__init__()
        mul_factor = torch.tensor(quant_param['smoothquant_mul']).view(quant_param['smoothquant_mul_shape'])
        self.register_buffer('mul_factor', mul_factor)
        self.conv2d = nn.Conv2d(in_ch, out_ch, kernel_size)
        weight_scale = torch.tensor(quant_param['weight_scale']).view(quant_param['weight_scale_shape'])
        weight_zp = torch.tensor(quant_param['weight_zp']).view(quant_param['weight_zp_shape'])
        input_scale = torch.tensor(quant_param['input_scale']).view(quant_param['input_scale_shape'])
        input_zp = torch.tensor(quant_param['input_zp']).view(quant_param['input_zp_shape'])
        self.register_buffer('weight_scale', weight_scale)
        self.register_buffer('weight_zp', weight_zp)
        self.register_buffer('input_scale', input_scale)
        self.register_buffer('input_zp', input_zp)

    def qdq_forward(self, x):
        """Fake-quantized ("quantize-dequantize") forward pass."""
        scaled_x = x * self.mul_factor
        quant_weight = quantize(self.conv2d.weight, self.weight_scale, self.weight_zp, is_asym=True)
        quant_input = quantize(scaled_x, self.input_scale, self.input_zp, is_asym=False)
        dequantized_weight = dequantize(quant_weight, self.weight_scale, self.weight_zp)
        dequantized_input = dequantize(quant_input, self.input_scale, self.input_zp)
        return torch.nn.functional.conv2d(dequantized_input, dequantized_weight, self.conv2d.bias)

    def qop_forward(self, x):
        """Integer-kernel emulation of the forward pass.

        With weight zero points ``zp_w`` (one per output channel, or a single
        shared value), the corrected integer result is
        ``conv(qx, qw) - zp_w * S``. Here ``S`` is the sum of ``qx`` over each
        receptive field (all input channels times the kernel window).

        ``S`` is hard to reduce directly, so it is obtained by appending one
        extra output channel whose kernel is all ones. The same convolution then
        produces ``S``, and that channel is split off afterwards.

        NOTE(review): no input-zero-point correction is applied, so this can
        only agree with ``qdq_forward`` when ``input_zp`` is zero.
        """
        # Shift the unsigned weight zero point into the signed int8 range
        # (could be precomputed offline).
        weight_zp_int8 = (self.weight_zp - 128).to(torch.int8).to(torch.float32)
        quant_weight = quantize(self.conv2d.weight, self.weight_scale, weight_zp_int8, is_asym=False).to(torch.int8)
        # All-ones kernel for the extra "sum" channel. ones_like keeps it on the
        # weight's device and dtype; a plain torch.ones would be created on the CPU
        # and break torch.cat for CUDA weights.
        ones_kernel = torch.ones_like(quant_weight[:1])
        quant_weight = torch.cat((quant_weight, ones_kernel), dim=0)
        # Fold the SmoothQuant multiplier into the input scale
        # (could be precomputed offline).
        fused_input_scale = self.input_scale / self.mul_factor
        quant_input = quantize(x, fused_input_scale, self.input_zp, is_asym=False).to(torch.int8)
        # The conv runs in float32 and is then cast to an int32 accumulator.
        # int8 products are exact in float32; sums stay exact while below 2**24.
        quant_output = torch.nn.functional.conv2d(
            quant_input.to(torch.float32), quant_weight.to(torch.float32), None
        ).to(torch.int32)
        spatial_ones = [1] * (quant_output.ndim - 2)
        # Slice with `-1:` so the channel dim survives: (N, 1, H, W) broadcasts
        # against the (1, C_out, 1, 1) zero points. Indexing with `-1` would give
        # (N, H, W), which misaligns batch with channels.
        input_sums = quant_output[:, -1:, :, :]
        zp = weight_zp_int8.to(torch.int8).view([1, self.weight_zp.nelement()] + spatial_ones)
        quant_output = quant_output[:, :-1, :, :] - input_sums * zp
        out_scale = self.weight_scale * self.input_scale
        output = dequantize(quant_output, out_scale.view([1, out_scale.nelement()] + spatial_ones), 0.0)
        output += self.conv2d.bias.view([1, self.conv2d.bias.nelement()] + spatial_ones)
        return output

    def forward(self, x, qop=False):
        """Run ``qop_forward`` when ``qop`` is true, otherwise ``qdq_forward``."""
        if qop:
            return self.qop_forward(x)
        return self.qdq_forward(x)
|