"""Smoke test comparing QuantConv2d's default path against its ``qop=True`` path.

Builds a QuantConv2d from randomly generated quantization parameters, runs the
same random input through both code paths, and prints the output shapes and
their element-wise difference.

NOTE(review): the names ``o_qdq`` / ``o_qop`` presumably denote a
quantize-dequantize reference versus a quantized-operator implementation;
confirm against math_model.
"""
import torch

from math_model import QuantConv2d

# Fixed seed so the random parameters and the random input are reproducible.
torch.manual_seed(0)

batch_size = 1
out_ch = 8
in_ch = 4
k = 3  # presumably the kernel size (third positional arg of QuantConv2d)
h = 5
w = 5

# Each tensor is paired with a "<name>_shape" entry, presumably the shape it is
# reshaped to inside QuantConv2d for broadcasting (e.g. (out_ch, 1, 1, 1)
# against the weight) -- TODO confirm in math_model.
quant_params = {
    'smoothquant_mul': torch.rand((in_ch,)),
    'smoothquant_mul_shape': (1, in_ch, 1, 1),
    'weight_scale': torch.rand((out_ch,)),
    'weight_scale_shape': (out_ch, 1, 1, 1),
    # torch.randint's upper bound is exclusive, so these lie in [-255, -1].
    # NOTE(review): confirm this sign convention is what QuantConv2d expects.
    'weight_zp': torch.randint(-255, 0, (out_ch,)),
    'weight_zp_shape': (out_ch, 1, 1, 1),
    'input_scale': torch.rand((1,)),
    'input_scale_shape': (1,),
    'input_zp': torch.zeros((1,)),
    'input_zp_shape': (1,),
}
print(quant_params)

layer = QuantConv2d(in_ch, out_ch, k, quant_params)
x = torch.rand((batch_size, in_ch, h, w))

o_qdq = layer(x)
o_qop = layer(x, qop=True)
print(o_qdq.shape)
print(o_qop.shape)

diff = o_qdq - o_qop
print(diff)
# A single summary number makes agreement between the two paths easy to judge.
print(diff.abs().max())