# Segment-Anything-2.1-RKNN2 / convert_rknn.py
# Uploaded by happyme531 (commit 50704de, "Upload 12 files").
#!/usr/bin/env python
# coding: utf-8
import datetime
import argparse
from rknn.api import RKNN
from sys import exit
import os
import onnxslim
# Prompt sizes enumerated for the decoder's dynamic input shapes: every
# (num_labels, num_points) combination drawn from these lists produces one
# input-shape set in convert_to_rknn().
num_pointss = [1]
num_labelss = [1]


def _check_ret(ret, step, path):
    """Raise RuntimeError if an RKNN API call reported failure (non-zero ret)."""
    if ret != 0:
        raise RuntimeError(f"{step} 失败: {path} (ret={ret})")


def convert_to_rknn(onnx_model, model_part, dataset="/home/zt/rk3588-nn/rknn_model_zoo/datasets/COCO/coco_subset_20.txt", quantize=False):
    """Convert a single ONNX model to RKNN format for the rk3588 target.

    The output is written next to the input, with the file extension
    replaced by ``.rknn``.

    Args:
        onnx_model: Path of the source ONNX model.
        model_part: ``"encoder"`` or ``"decoder"``. The encoder gets
            ``std_values=[[255, 255, 255]]``; the decoder gets dynamic input
            shapes built from ``num_labelss`` x ``num_pointss``.
        dataset: Calibration dataset list handed to ``rknn.build``.
            The default is a machine-specific absolute path.
        quantize: Whether to build a quantized (w8a8) model.

    Returns:
        Path of the exported ``.rknn`` file.

    Raises:
        RuntimeError: If loading, building or exporting the model fails
            (the RKNN API signals failure with a non-zero return code).
    """
    # splitext rather than str.replace: a path without a literal ".onnx"
    # would otherwise map to itself and export_rknn would overwrite the input.
    rknn_model = os.path.splitext(onnx_model)[0] + ".rknn"
    timedate_iso = datetime.datetime.now().isoformat()
    print(f"\n开始转换 {onnx_model} -> {rknn_model}")

    input_shapes = None  # encoder (and anything else): static shapes from the ONNX graph
    if model_part == "decoder":
        # One shape set per (num_labels, num_points) combination; presumably
        # listed in the decoder ONNX graph's input order — confirm on re-export.
        input_shapes = [
            [
                [1, 256, 64, 64],  # image_embedding
                [1, 32, 256, 256],  # high_res_feats_0
                [1, 64, 128, 128],  # high_res_feats_1
                [num_labels, num_points, 2],  # point_coords
                [num_labels, num_points],  # point_labels
                [num_labels, 1, 256, 256],  # mask_input
                [num_labels],  # has_mask_input
            ]
            for num_labels in num_labelss
            for num_points in num_pointss
        ]

    rknn = RKNN(verbose=True)
    try:
        rknn.config(
            dynamic_input=input_shapes,
            std_values=[[255, 255, 255]] if model_part == "encoder" else None,
            quantized_dtype='w8a8',
            quantized_algorithm='normal',
            quantized_method='channel',
            quantized_hybrid_level=0,
            target_platform='rk3588',
            quant_img_RGB2BGR=False,
            float_dtype='float16',
            optimization_level=3,
            custom_string=f"converted at {timedate_iso}",
            remove_weight=False,
            compress_weight=False,
            inputs_yuv_fmt=None,
            single_core_mode=False,
            model_pruning=False,
            op_target=None,
            quantize_weight=False,
            remove_reshape=False,
            sparse_infer=False,
            enable_flash_attention=False,
        )
        _check_ret(rknn.load_onnx(model=onnx_model), "load_onnx", onnx_model)
        _check_ret(
            rknn.build(do_quantization=quantize, dataset=dataset, rknn_batch_size=None),
            "build",
            onnx_model,
        )
        _check_ret(rknn.export_rknn(rknn_model), "export_rknn", rknn_model)
    finally:
        # Free the toolkit's resources whether or not the conversion succeeded.
        rknn.release()

    print(f"完成转换 {rknn_model}\n")
    return rknn_model
def main():
    """CLI entry point: convert ``<model_name>_encoder.onnx`` to RKNN.

    The encoder is first simplified with onnxslim (with the
    ``EliminationSlice`` fusion pattern skipped), converted with
    ``convert_to_rknn``, and the result is renamed to
    ``<model_name>_encoder.rknn``. The intermediate slimmed ONNX file is
    always deleted. Decoder conversion is currently disabled (marked broken).
    Exits with status 1 if the encoder ONNX file does not exist.
    """
    parser = argparse.ArgumentParser(description='转换SAM模型从ONNX到RKNN格式')
    parser.add_argument('model_name', type=str, help='模型名称,例如: sam2.1_hiera_tiny')
    args = parser.parse_args()

    encoder_onnx = f"{args.model_name}_encoder.onnx"
    encoder_rknn = f"{args.model_name}_encoder.rknn"

    # Only the encoder is converted below, so only the encoder is required.
    # Add the decoder file back to this check if decoder conversion is re-enabled.
    if not os.path.exists(encoder_onnx):
        print(f"错误: 找不到文件 {encoder_onnx}")
        exit(1)

    # The encoder needs an onnxslim pass before RKNN conversion.
    print("开始转换encoder...")
    slim_onnx = "encoder_slim.onnx"
    slim_rknn = "encoder_slim.rknn"  # output name convert_to_rknn derives from slim_onnx
    try:
        onnxslim.slim(encoder_onnx, output_model=slim_onnx, skip_fusion_patterns=["EliminationSlice"])
        convert_to_rknn(slim_onnx, model_part="encoder")
        # os.replace overwrites an existing target on every platform, so re-runs succeed.
        os.replace(slim_rknn, encoder_rknn)
    finally:
        # Never leave the intermediate slimmed ONNX behind, even on failure.
        if os.path.exists(slim_onnx):
            os.remove(slim_onnx)

    # NOTE: decoder conversion is known to be broken:
    # convert_to_rknn(f"{args.model_name}_decoder.onnx", model_part="decoder")
    print("所有模型转换完成!")


if __name__ == "__main__":
    main()