import argparse
from functools import partial

import mmcv
import numpy as np
import onnxruntime as rt
import torch
import torch._C
import torch.serialization
from mmcv import DictAction
from mmcv.onnx import register_extra_symbolics
from mmcv.runner import load_checkpoint
from torch import nn

from mmseg.apis import show_result_pyplot
from mmseg.apis.inference import LoadImage
from mmseg.datasets.pipelines import Compose
from mmseg.models import build_segmentor

# make any torch-side randomness in this script deterministic
torch.manual_seed(3)


def _convert_batchnorm(module):
    """Recursively replace ``SyncBatchNorm`` layers by ``BatchNorm2d``.

    Each ``SyncBatchNorm`` is swapped for a ``BatchNorm2d`` built with the
    same ``num_features``/``eps``/``momentum``/``affine``/
    ``track_running_stats`` settings. It carries over copies of the affine
    parameters (with their ``requires_grad`` flags) and the running
    statistics.
    NOTE(review): the export presumably needs this because SyncBatchNorm
    expects a distributed setup - confirm.

    Args:
        module (nn.Module): Module tree to convert.

    Returns:
        nn.Module: The converted module. Converted children are
        re-registered on it under their original names.
    """
    module_output = module
    if isinstance(module, torch.nn.SyncBatchNorm):
        module_output = torch.nn.BatchNorm2d(module.num_features, module.eps,
                                             module.momentum, module.affine,
                                             module.track_running_stats)
        if module.affine:
            module_output.weight.data = module.weight.data.clone().detach()
            module_output.bias.data = module.bias.data.clone().detach()
            # keep requires_grad unchanged
            module_output.weight.requires_grad = module.weight.requires_grad
            module_output.bias.requires_grad = module.bias.requires_grad
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
    for name, child in module.named_children():
        module_output.add_module(name, _convert_batchnorm(child))
    del module
    return module_output


def _get_num_classes(segmentor):
    """Return ``num_classes`` of the segmentor's (last) decode head.

    When ``decode_head`` is an ``nn.ModuleList`` of several heads, the
    last head's ``num_classes`` is used.
    """
    decode_head = segmentor.decode_head
    if isinstance(decode_head, nn.ModuleList):
        decode_head = decode_head[-1]
    return decode_head.num_classes


def _demo_mm_inputs(input_shape, num_classes):
    """Create a superset of inputs needed to run test or train batches.

    Args:
        input_shape (tuple): input batch dimensions ``(N, C, H, W)``
        num_classes (int): number of semantic classes

    Returns:
        dict: ``imgs`` (float tensor of ``input_shape`` with values drawn
        from ``[0, 1)``), ``img_metas`` (N identical meta dicts) and
        ``gt_semantic_seg`` (long tensor of shape ``(N, 1, H, W)``).
    """
    (N, C, H, W) = input_shape
    rng = np.random.RandomState(0)
    imgs = rng.rand(*input_shape)
    # ``high`` is exclusive, so ``num_classes`` covers labels
    # 0..num_classes-1 (and still works when num_classes == 1)
    segs = rng.randint(
        low=0, high=num_classes, size=(N, 1, H, W)).astype(np.uint8)
    img_metas = [{
        'img_shape': (H, W, C),
        'ori_shape': (H, W, C),
        'pad_shape': (H, W, C),
        'filename': '<demo>.png',
        'scale_factor': 1.0,
        'flip': False,
    } for _ in range(N)]
    mm_inputs = {
        'imgs': torch.FloatTensor(imgs).requires_grad_(True),
        'img_metas': img_metas,
        'gt_semantic_seg': torch.LongTensor(segs)
    }
    return mm_inputs


def _prepare_input_img(img_path,
                       test_pipeline,
                       shape=None,
                       rescale_shape=None):
    """Load an image from disk and run it through the test pipeline.

    Args:
        img_path (str): Path of the input image.
        test_pipeline (list[dict]): Test pipeline config. Its first
            transform is replaced by ``LoadImage``. When ``shape`` is
            given, ``test_pipeline[1]`` is modified in place.
        shape (tuple[int, int], optional): ``(height, width)`` to resize
            to, with ``keep_ratio`` disabled. Default: None.
        rescale_shape (sequence[int], optional): ``(height, width)``
            written into every meta's ``ori_shape`` (with 3 channels
            appended). Default: None.

    Returns:
        dict: ``imgs`` and ``img_metas`` produced by the pipeline.
    """
    # build the data pipeline
    if shape is not None:
        # img_scale is (width, height) while shape is (height, width)
        test_pipeline[1]['img_scale'] = (shape[1], shape[0])
        test_pipeline[1]['transforms'][0]['keep_ratio'] = False
    test_pipeline = [LoadImage()] + test_pipeline[1:]
    test_pipeline = Compose(test_pipeline)
    # prepare data
    data = dict(img=img_path)
    data = test_pipeline(data)
    imgs = data['img']
    # metas come wrapped (accessed through ``.data``); unwrap them
    img_metas = [i.data for i in data['img_metas']]

    if rescale_shape is not None:
        for img_meta in img_metas:
            img_meta['ori_shape'] = tuple(rescale_shape) + (3, )

    mm_inputs = {'imgs': imgs, 'img_metas': img_metas}

    return mm_inputs


def _update_input_img(img_list, img_meta_list):
    """Rebuild the meta list so that it matches the images fed to the model.

    Every one of the ``N`` images in ``img_list[0]`` gets a fresh meta
    derived from the first original meta. ``scale_factor`` is recomputed
    as ``(w_ratio, h_ratio, w_ratio, h_ratio)`` of ``img_shape`` over
    ``ori_shape``, and ``flip`` is reset to False. If the tensor's
    spatial size no longer matches the meta's ``pad_shape`` (e.g. it was
    interpolated after the meta was built), ``img_shape`` and
    ``pad_shape`` are taken from the tensor instead.

    Args:
        img_list (list[Tensor]): Input images. Only ``img_list[0]``,
            shaped ``(N, C, H, W)``, is inspected.
        img_meta_list (list[list[dict]]): Input metas. Only
            ``img_meta_list[0][0]`` is read.

    Returns:
        tuple: ``img_list`` unchanged, and a new meta list holding a
        single list of ``N`` meta dicts.
    """
    N, C, H, W = img_list[0].shape
    img_meta = img_meta_list[0][0]
    img_shape = img_meta['img_shape']
    ori_shape = img_meta['ori_shape']
    pad_shape = img_meta['pad_shape']
    if tuple(pad_shape[:2]) != (H, W):
        # the image tensor was resized after its meta was created, so the
        # stored shapes are stale; describe the tensor actually used
        img_shape = pad_shape = (H, W, C)

    new_img_meta_list = [[{
        'img_shape': img_shape,
        'ori_shape': ori_shape,
        'pad_shape': pad_shape,
        'filename': img_meta['filename'],
        'scale_factor': (img_shape[1] / ori_shape[1],
                         img_shape[0] / ori_shape[0]) * 2,
        'flip': False,
    } for _ in range(N)]]

    return img_list, new_img_meta_list


def pytorch2onnx(model,
                 mm_inputs,
                 opset_version=11,
                 show=False,
                 output_file='tmp.onnx',
                 verify=False,
                 dynamic_export=False):
    """Export Pytorch model to ONNX model and verify the outputs are same
    between Pytorch and ONNX.

    Args:
        model (nn.Module): Pytorch model we want to export.
        mm_inputs (dict): Contain the input tensors and img_metas
            information. ``imgs`` and ``img_metas`` are popped from it.
        opset_version (int): The onnx op version. Default: 11.
        show (bool): Whether print the computation graph. When ``verify``
            is also set, the ONNX and PyTorch segmentation results are
            displayed as well. Default: False.
        output_file (string): The path to where we store the output ONNX
            model. Default: `tmp.onnx`.
        verify (bool): Whether compare the outputs between Pytorch and ONNX.
            Default: False.
        dynamic_export (bool): Whether to export ONNX with dynamic axis.
            Default: False.

    Raises:
        RuntimeError: If, during verification, the exported graph does not
            have exactly one non-initializer input.
        AssertionError: If the PyTorch and ONNX outputs differ.
    """
    model.cpu().eval()
    test_mode = model.test_cfg.mode
    num_classes = _get_num_classes(model)

    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')

    img_list = [img[None, :] for img in imgs]
    img_meta_list = [[img_meta] for img_meta in img_metas]
    # update img_meta
    img_list, img_meta_list = _update_input_img(img_list, img_meta_list)

    # replace original forward function so that tracing only needs the
    # image tensors; metas and test-mode flags are bound here
    origin_forward = model.forward
    model.forward = partial(
        model.forward,
        img_metas=img_meta_list,
        return_loss=False,
        rescale=True)

    dynamic_axes = None
    if dynamic_export:
        if test_mode == 'slide':
            dynamic_axes = {'input': {0: 'batch'}, 'output': {1: 'batch'}}
        else:
            dynamic_axes = {
                'input': {
                    0: 'batch',
                    2: 'height',
                    3: 'width'
                },
                'output': {
                    1: 'batch',
                    2: 'height',
                    3: 'width'
                }
            }

    register_extra_symbolics(opset_version)
    try:
        with torch.no_grad():
            torch.onnx.export(
                model, (img_list, ),
                output_file,
                input_names=['input'],
                output_names=['output'],
                export_params=True,
                keep_initializers_as_inputs=False,
                verbose=show,
                opset_version=opset_version,
                dynamic_axes=dynamic_axes)
    finally:
        # restore the original forward even if the export raised
        model.forward = origin_forward
    print(f'Successfully exported ONNX model: {output_file}')

    if verify:
        # check by onnx
        import onnx
        onnx_model = onnx.load(output_file)
        onnx.checker.check_model(onnx_model)

        if dynamic_export and test_mode == 'whole':
            # scale image for dynamic shape test
            img_list = [
                nn.functional.interpolate(_, scale_factor=1.5)
                for _ in img_list
            ]
            # concate flip image for batch test
            flip_img_list = [_.flip(-1) for _ in img_list]
            img_list = [
                torch.cat((ori_img, flip_img), 0)
                for ori_img, flip_img in zip(img_list, flip_img_list)
            ]

            # update img_meta (shapes now follow the rescaled tensors)
            img_list, img_meta_list = _update_input_img(
                img_list, img_meta_list)

        # check the numerical value
        # get pytorch output
        with torch.no_grad():
            pytorch_result = model(img_list, img_meta_list, return_loss=False)
            pytorch_result = np.stack(pytorch_result, 0)

        # get onnx output: the fed input is the graph input that is not
        # also stored as an initializer (weight)
        input_all = [node.name for node in onnx_model.graph.input]
        input_initializer = [
            node.name for node in onnx_model.graph.initializer
        ]
        net_feed_input = list(set(input_all) - set(input_initializer))
        if len(net_feed_input) != 1:
            raise RuntimeError(
                'Expected exactly one input in the exported ONNX graph, '
                f'got {net_feed_input}')
        # the model was moved to CPU above, so verify on the CPU provider
        sess = rt.InferenceSession(
            output_file, providers=['CPUExecutionProvider'])
        onnx_result = sess.run(
            None, {net_feed_input[0]: img_list[0].detach().numpy()})[0][0]
        # show segmentation results
        if show:
            import cv2
            import os.path as osp
            img = img_meta_list[0][0]['filename']
            if not osp.exists(img):
                # demo input: rebuild a displayable image from the tensor
                img = imgs[0][:3, ...].permute(1, 2, 0) * 255
                img = img.detach().numpy().astype(np.uint8)
                ori_shape = img.shape[:2]
            else:
                ori_shape = LoadImage()({'img': img})['ori_shape']

            # resize onnx_result to ori_shape
            onnx_result_ = cv2.resize(onnx_result[0].astype(np.uint8),
                                      (ori_shape[1], ori_shape[0]))
            show_result_pyplot(
                model,
                img, (onnx_result_, ),
                palette=model.PALETTE,
                block=False,
                title='ONNXRuntime',
                opacity=0.5)

            # resize pytorch_result to ori_shape
            pytorch_result_ = cv2.resize(pytorch_result[0].astype(np.uint8),
                                         (ori_shape[1], ori_shape[0]))
            show_result_pyplot(
                model,
                img, (pytorch_result_, ),
                title='PyTorch',
                palette=model.PALETTE,
                opacity=0.5)
        # compare results (label maps normalized by the number of classes)
        np.testing.assert_allclose(
            pytorch_result.astype(np.float32) / num_classes,
            onnx_result.astype(np.float32) / num_classes,
            rtol=1e-5,
            atol=1e-5,
            err_msg='The outputs are different between Pytorch and ONNX')
        print('The outputs are same between Pytorch and ONNX')


def parse_args():
    """Parse the command-line arguments of the conversion script."""
    parser = argparse.ArgumentParser(description='Convert MMSeg to ONNX')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('--checkpoint', help='checkpoint file', default=None)
    parser.add_argument(
        '--input-img', type=str, help='Images for input', default=None)
    parser.add_argument(
        '--show',
        action='store_true',
        help='show onnx graph and segmentation results')
    parser.add_argument(
        '--verify', action='store_true', help='verify the onnx model')
    parser.add_argument('--output-file', type=str, default='tmp.onnx')
    parser.add_argument('--opset-version', type=int, default=11)
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=None,
        help='input image height and width.')
    parser.add_argument(
        '--rescale_shape',
        type=int,
        nargs='+',
        default=None,
        help='output image rescale height and width, work for slide mode.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='Override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--dynamic-export',
        action='store_true',
        help='Whether to export onnx with dynamic axis.')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()

    cfg = mmcv.Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    cfg.model.pretrained = None

    # input_shape is (N, C, H, W); img_scale in the config is (W, H)
    if args.shape is None:
        img_scale = cfg.test_pipeline[1]['img_scale']
        input_shape = (1, 3, img_scale[1], img_scale[0])
    elif len(args.shape) == 1:
        input_shape = (1, 3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (
            1,
            3,
        ) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    # build the model and load checkpoint
    cfg.model.train_cfg = None
    segmentor = build_segmentor(
        cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg'))
    # convert SyncBN to BN
    segmentor = _convert_batchnorm(segmentor)

    if args.checkpoint:
        checkpoint = load_checkpoint(
            segmentor, args.checkpoint, map_location='cpu')
        segmentor.CLASSES = checkpoint['meta']['CLASSES']
        segmentor.PALETTE = checkpoint['meta']['PALETTE']

    # read input or create dummpy input
    if args.input_img is not None:
        preprocess_shape = (input_shape[2], input_shape[3])
        rescale_shape = None
        if args.rescale_shape is not None:
            if len(args.rescale_shape) == 1:
                # a single value means a square output, as for --shape
                rescale_shape = [args.rescale_shape[0]] * 2
            else:
                rescale_shape = [args.rescale_shape[0], args.rescale_shape[1]]
        mm_inputs = _prepare_input_img(
            args.input_img,
            cfg.data.test.pipeline,
            shape=preprocess_shape,
            rescale_shape=rescale_shape)
    else:
        num_classes = _get_num_classes(segmentor)
        mm_inputs = _demo_mm_inputs(input_shape, num_classes)

    # convert model to onnx file
    pytorch2onnx(
        segmentor,
        mm_inputs,
        opset_version=args.opset_version,
        show=args.show,
        output_file=args.output_file,
        verify=args.verify,
        dynamic_export=args.dynamic_export)