amd
/

Image Classification
ONNX
RyzenAI
hangyang-amd commited on
Commit
3f2fcc7
·
1 Parent(s): 0b5f4ac

Update infer_onnx.py

Browse files
Files changed (1) hide show
  1. infer_onnx.py +3 -0
infer_onnx.py CHANGED
@@ -25,6 +25,7 @@ import torchvision.transforms as transforms
25
  parser = argparse.ArgumentParser()
26
  parser.add_argument('--onnx_path', type=str, default="EfficientNet_int.onnx", required=False)
27
  parser.add_argument('--image_path', type=str, required=True)
 
28
  parser.add_argument(
29
  "--ipu",
30
  action="store_true",
@@ -51,6 +52,8 @@ def read_image():
51
  normalize,
52
  ])
53
  img_tensor = transform(image).unsqueeze(0)
 
 
54
  return img_tensor.numpy()
55
 
56
 
 
25
  parser = argparse.ArgumentParser()
26
  parser.add_argument('--onnx_path', type=str, default="EfficientNet_int.onnx", required=False)
27
  parser.add_argument('--image_path', type=str, required=True)
28
+ parser.add_argument('--data_format', type=str, choices=["nchw", "nhwc"], default="nchw")
29
  parser.add_argument(
30
  "--ipu",
31
  action="store_true",
 
52
  normalize,
53
  ])
54
  img_tensor = transform(image).unsqueeze(0)
55
+ if args.data_format == "nhwc":
56
+ img_tensor = transform(image).unsqueeze(0).transpose(1, 3)
57
  return img_tensor.numpy()
58
 
59