nickfraser commited on
Commit
673c9f2
1 Parent(s): eb5a5f6

[test] Fixed shapes to match new `quant_param.json`

Browse files
Files changed (2) hide show
  1. test_quant_conv2d.py +2 -2
  2. test_quant_linear.py +3 -3
test_quant_conv2d.py CHANGED
@@ -18,9 +18,9 @@ quant_params = {
18
  'weight_zp': torch.randint(-255, 0, (out_ch,)),
19
  'weight_zp_shape': (out_ch,1,1,1),
20
  'input_scale': torch.rand((1,)),
21
- 'input_scale_shape': (1,),
22
  'input_zp': torch.zeros((1,)),
23
- 'input_zp_shape': (1,),
24
  }
25
 
26
  print(quant_params)
 
18
  'weight_zp': torch.randint(-255, 0, (out_ch,)),
19
  'weight_zp_shape': (out_ch,1,1,1),
20
  'input_scale': torch.rand((1,)),
21
+ 'input_scale_shape': tuple(),
22
  'input_zp': torch.zeros((1,)),
23
+ 'input_zp_shape': tuple(),
24
  }
25
 
26
  print(quant_params)
test_quant_linear.py CHANGED
@@ -9,15 +9,15 @@ in_ch = 64
9
 
10
  quant_params = {
11
  'smoothquant_mul': torch.rand((in_ch,)),
12
- 'smoothquant_mul_shape': (in_ch,),
13
  'weight_scale': torch.rand((out_ch,)),
14
  'weight_scale_shape': (out_ch,1),
15
  'weight_zp': torch.randint(-255, 0, (out_ch,)),
16
  'weight_zp_shape': (out_ch,1),
17
  'input_scale': torch.rand((1,)),
18
- 'input_scale_shape': (1,),
19
  'input_zp': torch.zeros((1,)),
20
- 'input_zp_shape': (1,),
21
  }
22
 
23
  print(quant_params)
 
9
 
10
  quant_params = {
11
  'smoothquant_mul': torch.rand((in_ch,)),
12
+ 'smoothquant_mul_shape': (1,in_ch),
13
  'weight_scale': torch.rand((out_ch,)),
14
  'weight_scale_shape': (out_ch,1),
15
  'weight_zp': torch.randint(-255, 0, (out_ch,)),
16
  'weight_zp_shape': (out_ch,1),
17
  'input_scale': torch.rand((1,)),
18
+ 'input_scale_shape': tuple(),
19
  'input_zp': torch.zeros((1,)),
20
+ 'input_zp_shape': tuple(),
21
  }
22
 
23
  print(quant_params)