"""Smoke test: compare QuantConv2d's default forward with its ``qop=True`` forward.

Builds a ``QuantConv2d`` from randomly drawn quantization parameters, feeds the
same random input through ``layer(x)`` and ``layer(x, qop=True)``, and prints
both output shapes, their elementwise difference, and the maximum absolute
difference between them.
"""
import torch

from math_model import QuantConv2d

# Fixed seed so the random quantization parameters, any random state consumed
# by QuantConv2d's constructor, and the input tensor are reproducible.
torch.manual_seed(0)

# Layer / input geometry.
batch_size = 1
out_ch = 8
in_ch = 4
k = 3  # passed as QuantConv2d's third argument (presumably the kernel size)
h = 5
w = 5

# Quantization parameters. Each tensor entry is paired with a "*_shape" entry.
# NOTE(review): presumably QuantConv2d reshapes each tensor to its paired shape
# so it broadcasts against the (N, C, H, W) activations / weights — confirm
# against math_model. Entry order matters: it fixes the RNG draw order.
quant_params = {
    'smoothquant_mul': torch.rand((in_ch,)),
    'smoothquant_mul_shape': (1, in_ch, 1, 1),
    'weight_scale': torch.rand((out_ch,)),
    'weight_scale_shape': (out_ch, 1, 1, 1),
    # torch.randint's upper bound is exclusive, so these are int64 values in
    # [-255, -1] (never 0). NOTE(review): confirm this matches the zero-point
    # sign convention math_model expects.
    'weight_zp': torch.randint(-255, 0, (out_ch,)),
    'weight_zp_shape': (out_ch, 1, 1, 1),
    'input_scale': torch.rand((1,)),
    'input_scale_shape': (1,),
    'input_zp': torch.zeros((1,)),
    'input_zp_shape': (1,),
}

print(quant_params)

layer = QuantConv2d(in_ch, out_ch, k, quant_params)
# Drawn after the layer is built, preserving the original RNG consumption order.
x = torch.rand((batch_size, in_ch, h, w))

o_qdq = layer(x)
o_qop = layer(x, qop=True)

print(o_qdq.shape)
print(o_qop.shape)

diff = o_qdq - o_qop
print(diff)
# Single-number summary of how far apart the two paths are; easier to read
# than the full difference tensor printed above.
print(diff.abs().max())
|