import warnings
from os import path as osp

import numpy as np
import onnx
import onnxruntime as ort
import torch
import torch.nn as nn

# Path to mmcv's compiled ONNXRuntime custom-op library. It stays '' when
# mmcv (or its op-path helper) cannot be imported; ``verify_model`` only
# registers the library when this path exists on disk.
ort_custom_op_path = ''
try:
    from mmcv.ops import get_onnxruntime_op_path
    ort_custom_op_path = get_onnxruntime_op_path()
except ImportError:  # also covers ModuleNotFoundError (a subclass)
    # Adjacent literals instead of a backslash continuation, so the message
    # does not embed the source indentation as a run of spaces.
    warnings.warn('If input model has custom op from mmcv, '
                  'you may have to build mmcv with ONNXRuntime from source.')


class WrapFunction(nn.Module):
    """Wrap the function to be tested for torch.onnx.export tracking.

    Lets a plain callable be used where an ``nn.Module`` is expected by
    forwarding every ``forward`` call to it unchanged.

    Args:
        wrapped_function (callable): the function invoked by ``forward``.
    """

    def __init__(self, wrapped_function):
        super(WrapFunction, self).__init__()
        self.wrapped_function = wrapped_function

    def forward(self, *args, **kwargs):
        return self.wrapped_function(*args, **kwargs)


def ort_validate(model, feats, onnx_io='tmp.onnx'):
    """Validate that the onnxruntime backend reproduces the torch outputs.

    The model is moved to CPU, put in eval mode, exported to ``onnx_io``
    (opset 11), run through onnxruntime via ``verify_model``, and every
    onnxruntime output is compared against the matching flattened torch
    output.

    Args:
        model (nn.Module): the model to be verified.
        feats (list(torch.Tensor) | torch.Tensor): the input of model. It is
            passed as-is to both ``torch.onnx.export`` and ``model.forward``.
        onnx_io (str): the name of onnx output file.

    Raises:
        AssertionError: if any output differs beyond ``rtol=1e-03`` /
            ``atol=1e-05`` (raised by ``np.testing.assert_allclose``).
    """
    model.cpu().eval()
    with torch.no_grad():
        torch.onnx.export(
            model,
            feats,
            onnx_io,
            export_params=True,
            keep_initializers_as_inputs=True,
            do_constant_folding=True,
            verbose=False,
            opset_version=11)

    # Forward the path explicitly: running verify_model with its default
    # would check 'tmp.onnx' rather than the file just exported.
    onnx_outputs = verify_model(feats, onnx_io=onnx_io)

    # Reference outputs only; no autograd graph is needed.
    with torch.no_grad():
        torch_outputs = convert_result_list(model.forward(feats))
    torch_outputs = [output.detach().numpy() for output in torch_outputs]

    # match torch_outputs and onnx_outputs, position by position
    for i, onnx_output in enumerate(onnx_outputs):
        np.testing.assert_allclose(
            torch_outputs[i], onnx_output, rtol=1e-03, atol=1e-05)


def verify_model(feat, onnx_io='tmp.onnx'):
    """Run the exported model in an onnxruntime session.

    The ONNX file is loaded and checked with ``onnx.checker`` first. mmcv's
    custom-op library is registered when its path exists.

    Args:
        feat (torch.Tensor | list[torch.Tensor]): model input(s). A single
            tensor is fed to the first graph input. For a list, ``feat[i]``
            is fed to the i-th graph input (assumes the graph's data inputs
            come first and in the same order -- TODO confirm for models
            exported with ``keep_initializers_as_inputs=True``).
        onnx_io (str): path of the ONNX file to run.

    Returns:
        list[np.ndarray]: onnxruntime infer result, one entry per graph
            output.
    """
    onnx_model = onnx.load(onnx_io)
    onnx.checker.check_model(onnx_model)

    session_options = ort.SessionOptions()
    # register custom op for onnxruntime
    if osp.exists(ort_custom_op_path):
        session_options.register_custom_ops_library(ort_custom_op_path)
    sess = ort.InferenceSession(onnx_io, session_options)

    feats = [feat] if isinstance(feat, torch.Tensor) else feat
    graph_inputs = sess.get_inputs()
    feed = {
        graph_inputs[i].name: tensor.numpy()
        for i, tensor in enumerate(feats)
    }
    return sess.run(None, feed)


def convert_result_list(outputs):
    """Flatten torch forward outputs into a flat list of tensors.

    Nested lists/tuples are traversed depth-first, preserving order.

    Args:
        outputs (Tensor | list | tuple): the outputs in torch env, possibly
            containing nested structures such as list or tuple.

    Returns:
        list(Tensor): a list only containing torch.Tensor.
    """
    # recursive end condition
    if isinstance(outputs, torch.Tensor):
        return [outputs]
    flat = []
    for sub in outputs:
        flat.extend(convert_result_list(sub))
    return flat