import torch
import torch.nn as nn
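
# Reference example of FP8 (float8_e4m3fnuz) quantization with SmoothQuant scaling:
# each layer exposes a QDQ ("fake quantization") path and a QOp path that mimics an
# accelerated FP8 backend; both paths should produce matching outputs.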

def quantize_fp8(tensor: torch.Tensor, scale: torch.Tensor):
    # Scale, clamp to the float8_e4m3fnuz representable range [-240, 240], and
    # round-trip through float8 to simulate FP8 precision loss in the original dtype.
    dtype = tensor.dtype
    clamp_min, clamp_max = torch.tensor(-240., dtype=dtype), torch.tensor(240., dtype=dtype)
    quant_tensor = torch.clamp((tensor / scale), clamp_min, clamp_max).to(torch.float8_e4m3fnuz).to(dtype)
    return quant_tensor

def dequantize_fp8(tensor: torch.Tensor, scale: torch.Tensor):
    return tensor * scale


class QuantLinear(nn.Module):
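    # Linear layer with SmoothQuant scaling and FP8 (e4m3fnuz) quantization parameters
    # stored as buffers; provides a QDQ ("fake quantization") path and an accelerated
    # QOp path that should produce matching outputs.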
    def __init__(self, in_ch, out_ch, quant_param):
        super().__init__()
        mul_factor = torch.tensor(quant_param['smoothquant_mul']).view(quant_param['smoothquant_mul_shape'])
        self.register_buffer('mul_factor', mul_factor)
        self.linear = nn.Linear(in_ch, out_ch)
        weight_scale = torch.tensor(quant_param['weight_scale']).view(quant_param['weight_scale_shape'])
        # weight_zp = torch.tensor(quant_param['weight_zp']).view(quant_param['weight_zp_shape'])
        # assert quant_param['weight_zp_dtype'] == 'torch.float8_e4m3fnuz', f"Weight Zero-Point dtype should be 'torch.float8_e4m3fnuz', found: {quant_param['weight_zp_dtype']}"
        input_scale = torch.tensor(quant_param['input_scale']).view(quant_param['input_scale_shape'])
        input_zp = torch.tensor(quant_param['input_zp']).view(quant_param['input_zp_shape'])
        assert quant_param['input_zp_dtype'] == 'torch.float8_e4m3fnuz', f"Input Zero-Point dtype should be 'torch.float8_e4m3fnuz', found: {quant_param['input_zp_dtype']}"
        self.register_buffer('weight_scale', weight_scale)
        # self.register_buffer('weight_zp', weight_zp)
        self.register_buffer('input_scale', input_scale)
        self.register_buffer('input_zp', input_zp)

    # I.e., "fake quantization"
    def qdq_forward(self, x):
        print(self.mul_factor.shape)
        scaled_x = x * self.mul_factor
        quant_weight = quantize_fp8(self.linear.weight, self.weight_scale)
        quant_input = quantize_fp8(scaled_x, self.input_scale)
        dequantized_weight = dequantize_fp8(quant_weight, self.weight_scale)
        dequantized_input = dequantize_fp8(quant_input, self.input_scale)
        out = torch.nn.functional.linear(dequantized_input, dequantized_weight, self.linear.bias)
        return out

    # QOp path: run the computation directly on quantized tensors and dequantize only
    # the output, mimicking what an accelerated FP8 backend would execute.
    def qop_forward(self, x):
        quant_weight = quantize_fp8(self.linear.weight, self.weight_scale).to(torch.float8_e4m3fnuz)
        fused_input_scale = self.input_scale / self.mul_factor # Fuse SmoothQuant and input scales, can be computed offline
        quant_input = quantize_fp8(x, fused_input_scale).to(torch.float8_e4m3fnuz)
        quant_output = torch.nn.functional.linear(quant_input.to(torch.float32), quant_weight.to(torch.float32), None) # Cast to FP32 for the matmul: F.linear has no float8 kernel, so the FP8 GEMM is only simulated here
        output_scale = self.weight_scale * self.input_scale # Per-output-channel scale, broadcast over the last dimension
        output = dequantize_fp8(quant_output, output_scale.view([1] * (quant_output.ndim - 1) + [-1]))
        output += self.linear.bias
        return output

    def forward(self, x, qop=False):
        if qop:
            return self.qop_forward(x)
        else:
            return self.qdq_forward(x)

class QuantConv2d(nn.Module):
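    # Conv2d counterpart of QuantLinear: same buffers and QDQ/QOp paths, with scales
    # broadcast over the channel dimension instead of the last dimension.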
    def __init__(self, in_ch, out_ch, kernel_size, quant_param):
        super().__init__()
        mul_factor = torch.tensor(quant_param['smoothquant_mul']).view(quant_param['smoothquant_mul_shape'])
        self.register_buffer('mul_factor', mul_factor)
        self.conv2d = nn.Conv2d(in_ch, out_ch, kernel_size)
        weight_scale = torch.tensor(quant_param['weight_scale']).view(quant_param['weight_scale_shape'])

        input_scale = torch.tensor(quant_param['input_scale']).view(quant_param['input_scale_shape'])
        input_zp = torch.tensor(quant_param['input_zp']).view(quant_param['input_zp_shape'])
        assert quant_param['input_zp_dtype'] == 'torch.float8_e4m3fnuz', f"Input Zero-Point dtype should be 'torch.float8_e4m3fnuz', found: {quant_param['input_zp_dtype']}"
        self.register_buffer('weight_scale', weight_scale)
        self.register_buffer('input_scale', input_scale)
        self.register_buffer('input_zp', input_zp)

    # I.e., "fake quantization"
    def qdq_forward(self, x):
        scaled_x = x * self.mul_factor
        quant_weight = quantize_fp8(self.conv2d.weight, self.weight_scale)
        quant_input = quantize_fp8(scaled_x, self.input_scale)
        dequantized_weight = dequantize_fp8(quant_weight, self.weight_scale)
        dequantized_input = dequantize_fp8(quant_input, self.input_scale)
        out = torch.nn.functional.conv2d(dequantized_input, dequantized_weight, self.conv2d.bias)
        return out

    # QOp path: run the computation directly on quantized tensors and dequantize only
    # the output, mimicking what an accelerated FP8 backend would execute.
    def qop_forward(self, x):
        quant_weight = quantize_fp8(self.conv2d.weight, self.weight_scale).to(torch.float8_e4m3fnuz)
        fused_input_scale = self.input_scale / self.mul_factor # Fuse SmoothQuant and input scales, can be computed offline
        quant_input = quantize_fp8(x, fused_input_scale).to(torch.float8_e4m3fnuz)
        quant_output = torch.nn.functional.conv2d(quant_input.to(torch.float32), quant_weight.to(torch.float32), None) # Cast to FP32 for the convolution: F.conv2d has no float8 kernel, so the FP8 conv is only simulated here
        output_scale = self.weight_scale * self.input_scale # Per-output-channel scale, broadcast over the channel dimension
        output = dequantize_fp8(quant_output, output_scale.view([1, -1] + [1] * (quant_output.ndim - 2)))
        output += self.conv2d.bias.view([1, -1] + [1] * (quant_output.ndim - 2))
        return output

    def forward(self, x, qop=False):
        if qop:
            return self.qop_forward(x)
        else:
            return self.qdq_forward(x)


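# Sanity check: build random inputs and quantization parameters, then verify that
# the QDQ and QOp forward paths of each layer produce matching outputs.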
torch.manual_seed(0)

batch_size = 1
seq_len = 11
hidden_size = 21
output_size = 36
shape = 5
query = 2.*torch.rand((batch_size, seq_len, hidden_size)) - 1.
conv_input = 2.*torch.rand((batch_size, hidden_size, shape, shape)) - 1.

quant_params = {
  "quant_linear": {
    "smoothquant_mul": torch.randn(hidden_size).abs(),
    "smoothquant_mul_shape": [1, 1, hidden_size],
    "input_scale": torch.max(torch.abs(query)) / 240.,
    "input_scale_shape": [],
    "input_zp": 0.0,
    "input_zp_shape": [],
    "input_zp_dtype": "torch.float8_e4m3fnuz",
    "weight_scale":torch.randn(output_size).abs(),
    "weight_scale_shape": [output_size, 1]
  },
  "quant_conv": {
    "smoothquant_mul": torch.randn(hidden_size).abs(),
    "smoothquant_mul_shape": [1, hidden_size, 1, 1],
    "input_scale": torch.max(torch.abs(query)) / 240.,
    "input_scale_shape": [],
    "input_zp": 0.0,
    "input_zp_shape": [],
    "input_zp_dtype": "torch.float8_e4m3fnuz",
    "weight_scale":torch.randn(output_size).abs(),
    "weight_scale_shape": [output_size, 1, 1, 1]

}
}

qlinear = QuantLinear(hidden_size, output_size, quant_params['quant_linear'])
o = qlinear(query)
q_o = qlinear(query, qop=True)
assert torch.allclose(o, q_o)
qconv = QuantConv2d(hidden_size, output_size, shape, quant_params['quant_conv'])
o = qconv(conv_input)
q_o = qconv(conv_input, qop=True)
assert torch.allclose(o, q_o, atol=1e-6)