import warnings
from os import path as osp

import numpy as np
import onnx
import onnxruntime as ort
import torch
import torch.nn as nn

# Path to the compiled mmcv custom-op library for onnxruntime (if built).
ort_custom_op_path = ''
try:
    from mmcv.ops import get_onnxruntime_op_path
    ort_custom_op_path = get_onnxruntime_op_path()
except (ImportError, ModuleNotFoundError):
    warnings.warn('If the input model has custom ops from mmcv, '
                  'you may have to build mmcv with ONNXRuntime from source.')


class WrapFunction(nn.Module):
    """Wrap the function to be tested for torch.onnx.export tracing."""

    def __init__(self, wrapped_function):
        super().__init__()
        self.wrapped_function = wrapped_function

    def forward(self, *args, **kwargs):
        return self.wrapped_function(*args, **kwargs)
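

# Usage sketch (hypothetical example): wrapping a plain function lets
# torch.onnx.export trace it like an ordinary nn.Module, e.g.
#   double = WrapFunction(lambda x: x * 2)
#   ort_validate(double, torch.rand(1, 2, 3, 4))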


def ort_validate(model, feats, onnx_io='tmp.onnx'):
    """Validate that the onnxruntime outputs match the torch outputs.

    Args:
        model (nn.Module): The model to be verified.
        feats (list[torch.Tensor] | torch.Tensor): The input of the model.
        onnx_io (str): The name of the exported onnx file.
    """
    model.cpu().eval()
    # Trace the model and export it to ONNX.
    with torch.no_grad():
        torch.onnx.export(
            model,
            feats,
            onnx_io,
            export_params=True,
            keep_initializers_as_inputs=True,
            do_constant_folding=True,
            verbose=False,
            opset_version=11)

    # Run the exported model under onnxruntime.
    onnx_outputs = verify_model(feats, onnx_io)

    # Flatten the torch outputs into a list of numpy arrays for comparison.
    torch_outputs = convert_result_list(model.forward(feats))
    torch_outputs = [
        torch_output.detach().numpy() for torch_output in torch_outputs
    ]

    assert len(torch_outputs) == len(onnx_outputs)
    for torch_output, onnx_output in zip(torch_outputs, onnx_outputs):
        np.testing.assert_allclose(
            torch_output, onnx_output, rtol=1e-03, atol=1e-05)
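

# Usage sketch (hypothetical shapes): ort_validate works on any traceable
# nn.Module, e.g.
#   conv = nn.Conv2d(3, 8, 3)
#   ort_validate(conv, torch.rand(1, 3, 16, 16))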


def verify_model(feat, onnx_io='tmp.onnx'):
    """Run the exported model in an onnxruntime session.

    Args:
        feat (list[torch.Tensor] | torch.Tensor): The model input; a single
            tensor or a list of 4D tensors from torch.rand.
        onnx_io (str): The name of the onnx file to load.

    Returns:
        list[np.ndarray]: The onnxruntime inference results.
    """
    onnx_model = onnx.load(onnx_io)
    # Make sure the exported graph is a well-formed ONNX model.
    onnx.checker.check_model(onnx_model)

    session_options = ort.SessionOptions()
    # Register the mmcv custom ops if the compiled library is available.
    if osp.exists(ort_custom_op_path):
        session_options.register_custom_ops_library(ort_custom_op_path)
    sess = ort.InferenceSession(onnx_io, session_options)
    if isinstance(feat, torch.Tensor):
        onnx_outputs = sess.run(None,
                                {sess.get_inputs()[0].name: feat.numpy()})
    else:
        # Bind each tensor to the graph input at the same position.
        onnx_outputs = sess.run(None, {
            sess.get_inputs()[i].name: feat[i].numpy()
            for i in range(len(feat))
        })
    return onnx_outputs
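

# Note that verify_model binds inputs positionally: with a list input,
# feat[i] feeds the i-th graph input, so the input order must match the
# order used when the model was exported.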


def convert_result_list(outputs):
    """Convert torch forward outputs that may contain tuples or lists into a
    flat list of torch.Tensor.

    Args:
        outputs (list[Tensor] | tuple[Tensor] | Tensor): The outputs in the
            torch env, possibly containing nested structures such as lists
            or tuples.

    Returns:
        list[Tensor]: A flat list only containing torch.Tensor.
    """
    if isinstance(outputs, torch.Tensor):
        return [outputs]

    # Recursively flatten any nested lists/tuples of tensors.
    ret = []
    for sub in outputs:
        ret += convert_result_list(sub)
    return ret
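

if __name__ == '__main__':
    # Minimal smoke test, a sketch assuming an elementwise clamp suffices to
    # exercise the torch vs. onnxruntime comparison end to end.
    demo_model = WrapFunction(lambda x: torch.clamp(x, min=0., max=1.))
    ort_validate(demo_model, torch.rand(1, 3, 8, 8))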