ZJF-Thunder
添加文件
e26e560
import warnings
from os import path as osp
import numpy as np
import onnx
import onnxruntime as ort
import torch
import torch.nn as nn
# Path of the mmcv custom-op library to register with onnxruntime.
# It stays '' when mmcv's onnxruntime ops cannot be imported.
ort_custom_op_path = ''
try:
    from mmcv.ops import get_onnxruntime_op_path
    ort_custom_op_path = get_onnxruntime_op_path()
except (ImportError, ModuleNotFoundError):
    # Adjacent string literals are used instead of a backslash continuation
    # inside the literal, so the message never absorbs the continuation
    # line's leading indentation.
    warnings.warn('If input model has custom op from mmcv, '
                  'you may have to build mmcv with ONNXRuntime from source.')
class WrapFunction(nn.Module):
    """Expose a plain callable as an ``nn.Module``.

    Used to wrap the function under test so that it can be handed to
    ``torch.onnx.export``.

    Args:
        wrapped_function (callable): the callable invoked by ``forward``.
    """

    def __init__(self, wrapped_function):
        super().__init__()
        self.wrapped_function = wrapped_function

    def forward(self, *args, **kwargs):
        """Pass every argument to the wrapped callable and return its result."""
        target = self.wrapped_function
        return target(*args, **kwargs)
def ort_validate(model, feats, onnx_io='tmp.onnx'):
    """Validate the output of the onnxruntime backend is the same as the output
    generated by torch.

    The model is exported to ``onnx_io``, that file is executed with
    onnxruntime, and every onnxruntime output is compared with the torch
    output at the same position.

    Args:
        model (nn.Module): the model to be verified
        feats (list(torch.Tensor) | torch.Tensor): the input of model
        onnx_io (str): the name of onnx output file

    Raises:
        AssertionError: if a pair of outputs differs beyond ``rtol=1e-03``
            and ``atol=1e-05``.
    """
    model.cpu().eval()
    with torch.no_grad():
        torch.onnx.export(
            model,
            feats,
            onnx_io,
            export_params=True,
            keep_initializers_as_inputs=True,
            do_constant_folding=True,
            verbose=False,
            opset_version=11)

    # Run the file that was just exported. Without forwarding ``onnx_io`` a
    # caller-supplied path would be ignored and verify_model would load its
    # own default file instead.
    onnx_outputs = verify_model(feats, onnx_io)

    torch_outputs = [
        torch_output.detach().numpy()
        for torch_output in convert_result_list(model.forward(feats))
    ]

    # match torch_outputs and onnx_outputs
    for i, onnx_output in enumerate(onnx_outputs):
        np.testing.assert_allclose(
            torch_outputs[i], onnx_output, rtol=1e-03, atol=1e-05)
def verify_model(feat, onnx_io='tmp.onnx'):
    """Execute an ONNX file with onnxruntime and return its outputs.

    Args:
        feat (torch.Tensor | list[Tensor]): the model input(s). A single
            tensor is fed to the first session input; a sequence is matched
            to the session inputs by position.
        onnx_io (str): path of the ONNX file to check and run.

    Returns:
        list[np.array]: onnxruntime infer result, each is a np.array
    """
    onnx.checker.check_model(onnx.load(onnx_io))

    options = ort.SessionOptions()
    # register custom op for onnxruntime
    if osp.exists(ort_custom_op_path):
        options.register_custom_ops_library(ort_custom_op_path)
    session = ort.InferenceSession(onnx_io, options)

    # Treat a lone tensor as a one-element sequence so both cases share the
    # same positional input mapping.
    tensors = [feat] if isinstance(feat, torch.Tensor) else feat
    session_inputs = session.get_inputs()
    feed = {
        session_inputs[idx].name: tensor.numpy()
        for idx, tensor in enumerate(tensors)
    }
    return session.run(None, feed)
def convert_result_list(outputs):
    """Flatten nested torch forward outputs into a flat list of tensors.

    Args:
        outputs (Tensor | list | tuple): the outputs in torch env, possibly
            nested lists or tuples whose leaves are tensors.

    Returns:
        list(Tensor): the tensors, in the order they are encountered.
    """
    # a tensor is a leaf: stop descending
    if isinstance(outputs, torch.Tensor):
        return [outputs]
    return [
        tensor for item in outputs for tensor in convert_result_list(item)
    ]