StableVITON / preprocess /detectron2 /tests /test_export_onnx.py
rlawjdghek's picture
prep (#1)
61c2d32 verified
raw
history blame
8.69 kB
# Copyright (c) Facebook, Inc. and its affiliates.
import io
import unittest
import warnings
import torch
from torch.hub import _check_module_exists
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.export import STABLE_ONNX_OPSET_VERSION
from detectron2.export.flatten import TracingAdapter
from detectron2.export.torchscript_patch import patch_builtin_len
from detectron2.layers import ShapeSpec
from detectron2.modeling import build_model
from detectron2.modeling.roi_heads import KRCNNConvDeconvUpsampleHead
from detectron2.structures import Boxes, Instances
from detectron2.utils.testing import (
_pytorch1111_symbolic_opset9_repeat_interleave,
_pytorch1111_symbolic_opset9_to,
get_sample_coco_image,
has_dynamic_axes,
random_boxes,
register_custom_op_onnx_export,
skipIfOnCPUCI,
skipIfUnsupportedMinOpsetVersion,
skipIfUnsupportedMinTorchVersion,
unregister_custom_op_onnx_export,
)
@unittest.skipIf(not _check_module_exists("onnx"), "ONNX not installed.")
@skipIfUnsupportedMinTorchVersion("1.10")
class TestONNXTracingExport(unittest.TestCase):
opset_version = STABLE_ONNX_OPSET_VERSION
def testMaskRCNNFPN(self):
def inference_func(model, images):
with warnings.catch_warnings(record=True):
inputs = [{"image": image} for image in images]
inst = model.inference(inputs, do_postprocess=False)[0]
return [{"instances": inst}]
self._test_model_zoo_from_config_path(
"COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml", inference_func
)
@skipIfOnCPUCI
def testMaskRCNNC4(self):
def inference_func(model, image):
inputs = [{"image": image}]
return model.inference(inputs, do_postprocess=False)[0]
self._test_model_zoo_from_config_path(
"COCO-InstanceSegmentation/mask_rcnn_R_50_C4_3x.yaml", inference_func
)
@skipIfOnCPUCI
def testCascadeRCNN(self):
def inference_func(model, image):
inputs = [{"image": image}]
return model.inference(inputs, do_postprocess=False)[0]
self._test_model_zoo_from_config_path(
"Misc/cascade_mask_rcnn_R_50_FPN_3x.yaml", inference_func
)
def testRetinaNet(self):
def inference_func(model, image):
return model.forward([{"image": image}])[0]["instances"]
self._test_model_zoo_from_config_path(
"COCO-Detection/retinanet_R_50_FPN_3x.yaml", inference_func
)
@skipIfOnCPUCI
def testMaskRCNNFPN_batched(self):
def inference_func(model, image1, image2):
inputs = [{"image": image1}, {"image": image2}]
return model.inference(inputs, do_postprocess=False)
self._test_model_zoo_from_config_path(
"COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml", inference_func, batch=2
)
@skipIfUnsupportedMinOpsetVersion(16, STABLE_ONNX_OPSET_VERSION)
@skipIfUnsupportedMinTorchVersion("1.11.1")
def testMaskRCNNFPN_with_postproc(self):
def inference_func(model, image):
inputs = [{"image": image, "height": image.shape[1], "width": image.shape[2]}]
return model.inference(inputs, do_postprocess=True)[0]["instances"]
self._test_model_zoo_from_config_path(
"COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml",
inference_func,
)
def testKeypointHead(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.model = KRCNNConvDeconvUpsampleHead(
ShapeSpec(channels=4, height=14, width=14), num_keypoints=17, conv_dims=(4,)
)
def forward(self, x, predbox1, predbox2):
inst = [
Instances((100, 100), pred_boxes=Boxes(predbox1)),
Instances((100, 100), pred_boxes=Boxes(predbox2)),
]
ret = self.model(x, inst)
return tuple(x.pred_keypoints for x in ret)
model = M()
model.eval()
def gen_input(num1, num2):
feat = torch.randn((num1 + num2, 4, 14, 14))
box1 = random_boxes(num1)
box2 = random_boxes(num2)
return feat, box1, box2
with patch_builtin_len():
onnx_model = self._test_model(
model,
gen_input(1, 2),
input_names=["features", "pred_boxes", "pred_classes"],
output_names=["box1", "box2"],
dynamic_axes={
"features": {0: "batch", 1: "static_four", 2: "height", 3: "width"},
"pred_boxes": {0: "batch", 1: "static_four"},
"pred_classes": {0: "batch", 1: "static_four"},
"box1": {0: "num_instance", 1: "K", 2: "static_three"},
"box2": {0: "num_instance", 1: "K", 2: "static_three"},
},
)
# Although ONNX models are not executable by PyTorch to verify
# support of batches with different sizes, we can verify model's IR
# does not hard-code input and/or output shapes.
# TODO: Add tests with different batch sizes when detectron2's CI
# support ONNX Runtime backend.
assert has_dynamic_axes(onnx_model)
################################################################################
# Testcase internals - DO NOT add tests below this point
################################################################################
def setUp(self):
register_custom_op_onnx_export("::to", _pytorch1111_symbolic_opset9_to, 9, "1.11.1")
register_custom_op_onnx_export(
"::repeat_interleave",
_pytorch1111_symbolic_opset9_repeat_interleave,
9,
"1.11.1",
)
def tearDown(self):
unregister_custom_op_onnx_export("::to", 9, "1.11.1")
unregister_custom_op_onnx_export("::repeat_interleave", 9, "1.11.1")
def _test_model(
self,
model,
inputs,
inference_func=None,
opset_version=STABLE_ONNX_OPSET_VERSION,
save_onnx_graph_path=None,
**export_kwargs,
):
# Not imported in the beginning of file to prevent runtime errors
# for environments without ONNX.
# This testcase checks dependencies before running
import onnx # isort:skip
f = io.BytesIO()
adapter_model = TracingAdapter(model, inputs, inference_func)
adapter_model.eval()
with torch.no_grad():
try:
torch.onnx.enable_log()
except AttributeError:
# Older ONNX versions does not have this API
pass
torch.onnx.export(
adapter_model,
adapter_model.flattened_inputs,
f,
training=torch.onnx.TrainingMode.EVAL,
opset_version=opset_version,
verbose=True,
**export_kwargs,
)
onnx_model = onnx.load_from_string(f.getvalue())
assert onnx_model is not None
if save_onnx_graph_path:
onnx.save(onnx_model, save_onnx_graph_path)
return onnx_model
def _test_model_zoo_from_config_path(
self,
config_path,
inference_func,
batch=1,
opset_version=STABLE_ONNX_OPSET_VERSION,
save_onnx_graph_path=None,
**export_kwargs,
):
model = model_zoo.get(config_path, trained=True)
image = get_sample_coco_image()
inputs = tuple(image.clone() for _ in range(batch))
return self._test_model(
model, inputs, inference_func, opset_version, save_onnx_graph_path, **export_kwargs
)
def _test_model_from_config_path(
self,
config_path,
inference_func,
batch=1,
opset_version=STABLE_ONNX_OPSET_VERSION,
save_onnx_graph_path=None,
**export_kwargs,
):
from projects.PointRend import point_rend # isort:skip
cfg = get_cfg()
cfg.DATALOADER.NUM_WORKERS = 0
point_rend.add_pointrend_config(cfg)
cfg.merge_from_file(config_path)
cfg.freeze()
model = build_model(cfg)
image = get_sample_coco_image()
inputs = tuple(image.clone() for _ in range(batch))
return self._test_model(
model, inputs, inference_func, opset_version, save_onnx_graph_path, **export_kwargs
)