import warnings
from os import path as osp

import numpy as np
import onnx
import onnxruntime as ort
import torch
import torch.nn as nn

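# Path of mmcv's custom-op library for onnxruntime (custom ops such as
# RoIAlign and NMS); left empty when mmcv is unavailable or was built
# without ONNXRuntime support.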
ort_custom_op_path = ''
try:
    from mmcv.ops import get_onnxruntime_op_path
    ort_custom_op_path = get_onnxruntime_op_path()
except (ImportError, ModuleNotFoundError):
    warnings.warn('If input model has custom ops from mmcv, '
                  'you may have to build mmcv with ONNXRuntime from source.')


class WrapFunction(nn.Module):
    """Wrap the function to be tested for torch.onnx.export tracking."""

    def __init__(self, wrapped_function):
        super(WrapFunction, self).__init__()
        self.wrapped_function = wrapped_function

    def forward(self, *args, **kwargs):
        return self.wrapped_function(*args, **kwargs)
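
# Illustrative sketch: ``WrapFunction`` turns a bare function into an
# ``nn.Module`` that ``torch.onnx.export`` can trace. ``normalize`` below is
# a hypothetical stand-in for the kind of helper the tests wrap:
#
#     def normalize(x):
#         return x / (x.abs().max() + 1e-6)
#
#     torch.onnx.export(WrapFunction(normalize), torch.rand(2, 3),
#                       'normalize.onnx', opset_version=11)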


def ort_validate(model, feats, onnx_io='tmp.onnx'):
    """Validate the output of the onnxruntime backend is the same as the output
    generated by torch.

    Args:
        model (nn.Module): the model to be verified.
        feats (list[torch.Tensor] | torch.Tensor): the input of the model.
        onnx_io (str): the name of the onnx output file.
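
    Example::

        >>> # minimal sketch; any exportable ``nn.Module`` works here
        >>> ort_validate(nn.Conv2d(3, 8, 3, padding=1),
        ...              torch.rand(1, 3, 32, 32))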
    """
    model.cpu().eval()
    with torch.no_grad():
        torch.onnx.export(
            model,
            feats,
            onnx_io,
            export_params=True,
            keep_initializers_as_inputs=True,
            do_constant_folding=True,
            verbose=False,
            opset_version=11)

    onnx_outputs = verify_model(feats, onnx_io)

    torch_outputs = convert_result_list(model.forward(feats))
    torch_outputs = [
        torch_output.detach().numpy() for torch_output in torch_outputs
    ]

    # check that onnxruntime and torch produce the same number of outputs
    # and that corresponding outputs are numerically close
    assert len(torch_outputs) == len(onnx_outputs)
    for torch_output, onnx_output in zip(torch_outputs, onnx_outputs):
        np.testing.assert_allclose(
            torch_output, onnx_output, rtol=1e-03, atol=1e-05)


def verify_model(feat, onnx_io='tmp.onnx'):
    """Run the model in onnxruntime env.

    Args:
        feat (list[torch.Tensor] | torch.Tensor): the input of the model,
            either a single 4D tensor or a list of 4D tensors.
        onnx_io (str): the name of the onnx file to run.

    Returns:
        list[np.ndarray]: the onnxruntime inference results.
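
    Example::

        >>> # illustrative; assumes ``tmp.onnx`` was already written by a
        >>> # prior ``torch.onnx.export`` call with one 4D input
        >>> outs = verify_model(torch.rand(1, 3, 32, 32))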
    """

    onnx_model = onnx.load(onnx_io)
    onnx.checker.check_model(onnx_model)

    session_options = ort.SessionOptions()
    # register custom op for onnxruntime
    if osp.exists(ort_custom_op_path):
        session_options.register_custom_ops_library(ort_custom_op_path)
    sess = ort.InferenceSession(onnx_io, session_options)
    if isinstance(feat, torch.Tensor):
        onnx_outputs = sess.run(None,
                                {sess.get_inputs()[0].name: feat.numpy()})
    else:
        onnx_outputs = sess.run(None, {
            sess.get_inputs()[i].name: feat[i].numpy()
            for i in range(len(feat))
        })
    return onnx_outputs


def convert_result_list(outputs):
    """Convert the torch forward outputs containing tuple or list to a list
    only containing torch.Tensor.

    Args:
        outputs (torch.Tensor | list | tuple): the outputs from torch,
            possibly containing nested structures such as lists or tuples.

    Returns:
        list[torch.Tensor]: a flat list containing only torch.Tensor
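
    Example::

        >>> # illustrative: flattens to [torch.zeros(1), torch.ones(2)]
        >>> convert_result_list(([torch.zeros(1)], (torch.ones(2),)))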
    """
    # recursive end condition
    if isinstance(outputs, torch.Tensor):
        return [outputs]

    ret = []
    for sub in outputs:
        ret += convert_result_list(sub)
    return ret
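

if __name__ == '__main__':
    # Minimal self-check sketch (assumes direct invocation rather than
    # pytest): a wrapped elementwise function should round-trip through
    # onnxruntime with outputs matching torch.
    wrapped = WrapFunction(lambda x: x.clamp(min=0) * 2)
    ort_validate(wrapped, torch.rand(1, 3, 16, 16))
    print('onnxruntime and torch outputs match.')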