GiusFra committed
Commit 6f59b43 · verified · 1 Parent(s): d9e66a0

Create math_model.py

Files changed (1)
  1. math_model.py +143 -0
math_model.py ADDED
@@ -0,0 +1,143 @@
+ import torch.nn as nn
+ import torch
+
+ def quantize_fp8(tensor: torch.Tensor, scale: torch.Tensor):
+     # 240 is the largest representable value of torch.float8_e4m3fnuz
+     dtype = tensor.dtype
+     clamp_min, clamp_max = torch.tensor(-240., dtype=dtype), torch.tensor(240., dtype=dtype)
+     quant_tensor = torch.clamp((tensor / scale), clamp_min, clamp_max).to(torch.float8_e4m3fnuz).to(dtype)
+     return quant_tensor
+
+ def dequantize_fp8(tensor: torch.Tensor, scale: torch.Tensor):
+     return tensor * scale
+
+
+ class QuantLinear(nn.Module):
+     def __init__(self, in_ch, out_ch, quant_param):
+         super().__init__()
+         mul_factor = torch.tensor(quant_param['smoothquant_mul']).view(quant_param['smoothquant_mul_shape'])
+         self.register_buffer('mul_factor', mul_factor)
+         self.linear = nn.Linear(in_ch, out_ch)
+         weight_scale = torch.tensor(quant_param['weight_scale']).view(quant_param['weight_scale_shape'])
+         # weight_zp = torch.tensor(quant_param['weight_zp']).view(quant_param['weight_zp_shape'])
+         # assert quant_param['weight_zp_dtype'] == 'torch.float8_e4m3fnuz', f"Weight Zero-Point dtype should be 'torch.float8_e4m3fnuz', found: {quant_param['weight_zp_dtype']}"
+         input_scale = torch.tensor(quant_param['input_scale']).view(quant_param['input_scale_shape'])
+         input_zp = torch.tensor(quant_param['input_zp']).view(quant_param['input_zp_shape'])
+         assert quant_param['input_zp_dtype'] == 'torch.float8_e4m3fnuz', f"Input Zero-Point dtype should be 'torch.float8_e4m3fnuz', found: {quant_param['input_zp_dtype']}"
+         self.register_buffer('weight_scale', weight_scale)
+         # self.register_buffer('weight_zp', weight_zp)
+         self.register_buffer('input_scale', input_scale)
+         self.register_buffer('input_zp', input_zp)
+
+     # I.e., "fake quantization": quantize and immediately dequantize, then compute in high precision
+     def qdq_forward(self, x):
+         scaled_x = x * self.mul_factor
+         quant_weight = quantize_fp8(self.linear.weight, self.weight_scale)
+         quant_input = quantize_fp8(scaled_x, self.input_scale)
+         dequantized_weight = dequantize_fp8(quant_weight, self.weight_scale)
+         dequantized_input = dequantize_fp8(quant_input, self.input_scale)
+         out = torch.nn.functional.linear(dequantized_input, dequantized_weight, self.linear.bias)
+         return out
+
+     # Accelerated version: compute on quantized values, dequantize the output
+     def qop_forward(self, x):
+         quant_weight = quantize_fp8(self.linear.weight, self.weight_scale).to(torch.float8_e4m3fnuz)
+         fused_input_scale = self.input_scale / self.mul_factor  # Fuse SmoothQuant and input scales, can be computed offline
+         quant_input = quantize_fp8(x, fused_input_scale).to(torch.float8_e4m3fnuz)
+         quant_output = torch.nn.functional.linear(quant_input.to(torch.float32), quant_weight.to(torch.float32), None).to(torch.float32)  # Convert inputs to FP32 to avoid F.linear producing an FP8 output
+         output = dequantize_fp8(quant_output, (self.weight_scale * self.input_scale).view([1] * (quant_output.ndim - 1) + [(self.weight_scale * self.input_scale).nelement()]))
+         output += self.linear.bias
+         return output
+
+     def forward(self, x, qop=False):
+         if qop:
+             return self.qop_forward(x)
+         else:
+             return self.qdq_forward(x)
+
+ class QuantConv2d(nn.Module):
+     def __init__(self, in_ch, out_ch, kernel_size, quant_param):
+         super().__init__()
+         mul_factor = torch.tensor(quant_param['smoothquant_mul']).view(quant_param['smoothquant_mul_shape'])
+         self.register_buffer('mul_factor', mul_factor)
+         self.conv2d = nn.Conv2d(in_ch, out_ch, kernel_size)
+         weight_scale = torch.tensor(quant_param['weight_scale']).view(quant_param['weight_scale_shape'])
+
+         input_scale = torch.tensor(quant_param['input_scale']).view(quant_param['input_scale_shape'])
+         input_zp = torch.tensor(quant_param['input_zp']).view(quant_param['input_zp_shape'])
+         assert quant_param['input_zp_dtype'] == 'torch.float8_e4m3fnuz', f"Input Zero-Point dtype should be 'torch.float8_e4m3fnuz', found: {quant_param['input_zp_dtype']}"
+         self.register_buffer('weight_scale', weight_scale)
+         self.register_buffer('input_scale', input_scale)
+         self.register_buffer('input_zp', input_zp)
+
+     # I.e., "fake quantization": quantize and immediately dequantize, then compute in high precision
+     def qdq_forward(self, x):
+         scaled_x = x * self.mul_factor
+         quant_weight = quantize_fp8(self.conv2d.weight, self.weight_scale)
+         quant_input = quantize_fp8(scaled_x, self.input_scale)
+         dequantized_weight = dequantize_fp8(quant_weight, self.weight_scale)
+         dequantized_input = dequantize_fp8(quant_input, self.input_scale)
+         out = torch.nn.functional.conv2d(dequantized_input, dequantized_weight, self.conv2d.bias)
+         return out
+
+     # Accelerated version: compute on quantized values, dequantize the output
+     def qop_forward(self, x):
+         quant_weight = quantize_fp8(self.conv2d.weight, self.weight_scale).to(torch.float8_e4m3fnuz)
+         fused_input_scale = self.input_scale / self.mul_factor  # Fuse SmoothQuant and input scales, can be computed offline
+         quant_input = quantize_fp8(x, fused_input_scale).to(torch.float8_e4m3fnuz)
+         quant_output = torch.nn.functional.conv2d(quant_input.to(torch.float32), quant_weight.to(torch.float32), None).to(torch.float32)  # Convert inputs to FP32 to avoid F.conv2d producing an FP8 output
+         output = dequantize_fp8(quant_output, (self.weight_scale * self.input_scale).view([1, (self.weight_scale * self.input_scale).nelement()] + [1] * (quant_output.ndim - 2)))
+         output += self.conv2d.bias.view([1, self.conv2d.bias.nelement()] + [1] * (quant_output.ndim - 2))
+         return output
+
+     def forward(self, x, qop=False):
+         if qop:
+             return self.qop_forward(x)
+         else:
+             return self.qdq_forward(x)
+
+
+ torch.manual_seed(0)
+
+ batch_size = 1
+ seq_len = 11
+ hidden_size = 21
+ output_size = 36
+ shape = 5
+ query = 2. * torch.rand((batch_size, seq_len, hidden_size)) - 1.
+ conv_input = 2. * torch.rand((batch_size, hidden_size, shape, shape)) - 1.
+
+ quant_params = {
+     "quant_linear": {
+         "smoothquant_mul": torch.randn(hidden_size).abs(),
+         "smoothquant_mul_shape": [1, 1, hidden_size],
+         "input_scale": torch.max(torch.abs(query)) / 240.,
+         "input_scale_shape": [],
+         "input_zp": 0.0,
+         "input_zp_shape": [],
+         "input_zp_dtype": "torch.float8_e4m3fnuz",
+         "weight_scale": torch.randn(output_size).abs(),
+         "weight_scale_shape": [output_size, 1]
+     },
+     "quant_conv": {
+         "smoothquant_mul": torch.randn(hidden_size).abs(),
+         "smoothquant_mul_shape": [1, hidden_size, 1, 1],
+         "input_scale": torch.max(torch.abs(query)) / 240.,
+         "input_scale_shape": [],
+         "input_zp": 0.0,
+         "input_zp_shape": [],
+         "input_zp_dtype": "torch.float8_e4m3fnuz",
+         "weight_scale": torch.randn(output_size).abs(),
+         "weight_scale_shape": [output_size, 1, 1, 1]
+     }
+ }
+
+ # The fake-quantized (QDQ) and accelerated (QOp) paths should agree up to floating-point error
+ qlinear = QuantLinear(hidden_size, output_size, quant_params['quant_linear'])
+ o = qlinear(query)
+ q_o = qlinear(query, qop=True)
+ assert torch.allclose(o, q_o)
+ qconv = QuantConv2d(hidden_size, output_size, shape, quant_params['quant_conv'])
+ o = qconv(conv_input)
+ q_o = qconv(conv_input, qop=True)
+ assert torch.allclose(o, q_o, atol=1e-6)