happyme531 commited on
Commit
ca6c51e
·
verified ·
1 Parent(s): 3b37ef5

Upload convert-onnx-to-rknn.py

Browse files
Files changed (1) hide show
  1. convert-onnx-to-rknn.py +120 -0
convert-onnx-to-rknn.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ from typing import List
5
+ from rknn.api import RKNN
6
+ from math import exp
7
+ from sys import exit
8
+ import argparse
9
+
10
+
11
+ def convert_pipeline_component(onnx_path: str, resolution_list: List[List[int]], target_platform: str = 'rk3588'):
12
+ print(f'Converting {onnx_path} to RKNN model')
13
+ print(f'with target platform {target_platform}')
14
+ print(f'with resolutions:')
15
+ for res in resolution_list:
16
+ print(f'- {res[0]}x{res[1]}')
17
+ use_dynamic_shape = False
18
+ if(len(resolution_list) > 1):
19
+ print("Warning: RKNN dynamic shape support is probably broken, may throw errors")
20
+ use_dynamic_shape = True
21
+
22
+ batch_size = 1
23
+ LATENT_RESIZE_FACTOR = 8
24
+ # build shape list
25
+ if "text_encoder" in onnx_path:
26
+ input_size_list = [[[1,77]]]
27
+ inputs=['input_ids']
28
+ use_dynamic_shape = False
29
+ elif "unet" in onnx_path:
30
+ # batch_size = 2 # for classifier free guidance # broken for rknn python api
31
+
32
+ input_size_list = []
33
+ for res in resolution_list:
34
+ input_size_list.append(
35
+ [[1,4, res[0]//LATENT_RESIZE_FACTOR, res[1]//LATENT_RESIZE_FACTOR],
36
+ [1],
37
+ [1, 77, 768],
38
+ [1, 256]]
39
+ )
40
+ inputs=['sample','timestep','encoder_hidden_states','timestep_cond']
41
+ elif "vae_decoder" in onnx_path:
42
+ input_size_list = []
43
+ for res in resolution_list:
44
+ input_size_list.append(
45
+ [[1,4, res[0]//LATENT_RESIZE_FACTOR, res[1]//LATENT_RESIZE_FACTOR]]
46
+ )
47
+ inputs=['latent_sample']
48
+ else:
49
+ print("Unknown component: ", onnx_path)
50
+ exit(1)
51
+
52
+ rknn = RKNN(verbose=True)
53
+
54
+ # pre-process config
55
+ print('--> Config model')
56
+ rknn.config(target_platform='rk3588', optimization_level=3, single_core_mode=True,
57
+ dynamic_input= input_size_list if use_dynamic_shape else None)
58
+ print('done')
59
+
60
+ # Load ONNX model
61
+ print('--> Loading model')
62
+ ret = rknn.load_onnx(model=onnx_path,
63
+ inputs=None if use_dynamic_shape else inputs,
64
+ input_size_list= None if use_dynamic_shape else input_size_list[0])
65
+ if ret != 0:
66
+ print('Load model failed!')
67
+ exit(ret)
68
+ print('done')
69
+
70
+ # Build model
71
+ print('--> Building model')
72
+ ret = rknn.build(do_quantization=False, rknn_batch_size=batch_size)
73
+ if ret != 0:
74
+ print('Build model failed!')
75
+ exit(ret)
76
+ print('done')
77
+
78
+ #export
79
+ print('--> Export RKNN model')
80
+ ret = rknn.export_rknn(onnx_path.replace('.onnx', '.rknn'))
81
+ if ret != 0:
82
+ print('Export RKNN model failed!')
83
+ exit(ret)
84
+ print('done')
85
+
86
+ rknn.release()
87
+ print('RKNN model is converted successfully!')
88
+
89
+
90
+ def parse_resolution_list(resolution: str) -> List[List[int]]:
91
+ resolution_pairs = resolution.split(',')
92
+ parsed_resolutions = []
93
+ for pair in resolution_pairs:
94
+ width, height = map(int, pair.split('x'))
95
+ parsed_resolutions.append([width, height])
96
+
97
+ return parsed_resolutions
98
+
99
+
100
+ if __name__ == '__main__':
101
+ parser = argparse.ArgumentParser(description='Convert Stable Diffusion ONNX models to RKNN models')
102
+ parser.add_argument('-m','--model-dir', type=str, help='Directory containing the Stable Diffusion ONNX models', required=True)
103
+ parser.add_argument('-c','--components', type=str, help='Name of the components to convert, e.g. "text_encoder,unet,vae_decoder"', default='text_encoder, unet, vae_decoder')
104
+ parser.add_argument('-r','--resolutions', type=str, help='Comma-separated list of resolutions for the model, e.g. "256x256,512x512"', default='256x256')
105
+ parser.add_argument('--target_platform', type=str, help='Target platform for the RKNN model, default is "rk3588"', default='rk3588')
106
+ args = parser.parse_args()
107
+
108
+ components = args.components.split(',')
109
+
110
+ for component in components:
111
+ onnx_path = f'{args.model_dir}/{component.strip()}/model.onnx'
112
+ resolution_list = parse_resolution_list(args.resolutions)
113
+ if(len(resolution_list) == 0):
114
+ print("Error: No resolutions specified")
115
+ exit(1)
116
+
117
+ convert_pipeline_component(onnx_path, resolution_list, args.target_platform)
118
+
119
+
120
+