hangyang-amd
commited on
Commit
·
3f2fcc7
1
Parent(s):
0b5f4ac
Update infer_onnx.py
Browse files- infer_onnx.py +3 -0
infer_onnx.py
CHANGED
@@ -25,6 +25,7 @@ import torchvision.transforms as transforms
|
|
25 |
parser = argparse.ArgumentParser()
|
26 |
parser.add_argument('--onnx_path', type=str, default="EfficientNet_int.onnx", required=False)
|
27 |
parser.add_argument('--image_path', type=str, required=True)
|
|
|
28 |
parser.add_argument(
|
29 |
"--ipu",
|
30 |
action="store_true",
|
@@ -51,6 +52,8 @@ def read_image():
|
|
51 |
normalize,
|
52 |
])
|
53 |
img_tensor = transform(image).unsqueeze(0)
|
|
|
|
|
54 |
return img_tensor.numpy()
|
55 |
|
56 |
|
|
|
25 |
parser = argparse.ArgumentParser()
|
26 |
parser.add_argument('--onnx_path', type=str, default="EfficientNet_int.onnx", required=False)
|
27 |
parser.add_argument('--image_path', type=str, required=True)
|
28 |
+
parser.add_argument('--data_format', type=str, choices=["nchw", "nhwc"], default="nchw")
|
29 |
parser.add_argument(
|
30 |
"--ipu",
|
31 |
action="store_true",
|
|
|
52 |
normalize,
|
53 |
])
|
54 |
img_tensor = transform(image).unsqueeze(0)
|
55 |
+
if args.data_format == "nhwc":
|
56 |
+
img_tensor = transform(image).unsqueeze(0).transpose(1, 3)
|
57 |
return img_tensor.numpy()
|
58 |
|
59 |
|