File size: 941 Bytes
ddadf19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# coding: utf-8

__author__ = 'cleardusk'

import sys

sys.path.append('..')

import torch
import tddfa.models
from tddfa.utils.tddfa_util import load_model


def convert_to_onnx(**kvs):
    # 1. load model
    size = kvs.get('size', 120)
    model = getattr(models, kvs.get('arch'))(
        num_classes=kvs.get('num_params', 62),
        widen_factor=kvs.get('widen_factor', 1),
        size=size,
        mode=kvs.get('mode', 'small')
    )
    checkpoint_fp = kvs.get('checkpoint_fp')
    model = load_model(model, checkpoint_fp)
    model.eval()

    # 2. convert
    batch_size = 1
    dummy_input = torch.randn(batch_size, 3, size, size)
    wfp = checkpoint_fp.replace('.pth', '.onnx')
    torch.onnx.export(
        model,
        (dummy_input, ),
        wfp,
        input_names=['input'],
        output_names=['output'],
        do_constant_folding=True
    )
    print(f'Convert {checkpoint_fp} to {wfp} done.')
    return wfp