# Author: alexandreteles
# feat: initial commit
# Commit: cec228d
import os
import sys
import onnx
from onnxslim import slim
from onnxconverter_common.float16 import convert_float_to_float16
from onnxconverter_common.optimizer import optimize_onnx_model
import logging
# Configure the root logger once at import time: INFO and above, each record
# prefixed with a timestamp and level name.
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
def convert_to_fp16(input_onnx_model):
    """Convert an ONNX model file to FP16 and save it next to the input.

    The output path is derived from the input path: the extension (and a
    trailing ``.fp32`` tag, if present) is replaced with ``.fp16.onnx``,
    e.g. ``model.fp32.onnx`` -> ``model.fp16.onnx`` and
    ``model.onnx`` -> ``model.fp16.onnx``.

    Pipeline: ``onnxslim.slim`` -> ``optimize_onnx_model`` ->
    ``convert_float_to_float16(keep_io_types=True)`` -> ``onnx.save``.

    Nothing is written (the function just logs and returns) when the input
    already is an FP16 output (``*.fp16.onnx``) or when the derived output
    file already exists.

    Args:
        input_onnx_model: Path to the source ``.onnx`` file.
    """
    base_name, _ = os.path.splitext(input_onnx_model)
    # Never re-convert a file this function produced: its derived output name
    # would be ``*.fp16.fp16.onnx``, which never exists, so a second run over
    # the same directory would otherwise convert every FP16 model again.
    if base_name.endswith(".fp16"):
        logging.info(
            f"{input_onnx_model} is already an FP16 model. Skipping conversion."
        )
        return
    if base_name.endswith(".fp32"):
        base_name = base_name[: -len(".fp32")]
    output_onnx_model = base_name + ".fp16.onnx"
    if os.path.exists(output_onnx_model):
        logging.info(
            f"FP16 version {output_onnx_model} already exists. Skipping conversion."
        )
        return

    logging.info(f"Starting conversion for {input_onnx_model}")
    # The author raised the recursion limit to 10000, presumably because the
    # graph helpers recurse deeply on large models. Apply it before any graph
    # work, and only ever raise it, never lower a limit set elsewhere.
    sys.setrecursionlimit(max(sys.getrecursionlimit(), 10000))

    logging.info(f"Simplifying model {input_onnx_model}")
    model = slim(input_onnx_model)
    logging.info(
        f"Performing shape inference and quant pre-process for {input_onnx_model}"
    )
    model_optimized = optimize_onnx_model(model)

    logging.info(f"Converting {input_onnx_model} to FP16")
    # keep_io_types=True keeps graph inputs/outputs as float32 so existing
    # callers can keep feeding and reading FP32 tensors.
    model_fp16 = convert_float_to_float16(model_optimized, keep_io_types=True)

    logging.info(f"Saving FP16 model to {output_onnx_model}")
    onnx.save(model_fp16, output_onnx_model)
    logging.info(f"FP16 model saved to {output_onnx_model}")
def main():
    """Convert every ``.onnx`` file in the current working directory to FP16.

    Files are processed in sorted order so runs are reproducible. Entries
    that are not regular files (e.g. a directory named ``x.onnx``) and
    ``*.fp16.onnx`` files (outputs of a previous run) are ignored. Logs a
    warning and returns when nothing is left to convert.
    """
    onnx_files = sorted(
        name
        for name in os.listdir(".")
        if name.endswith(".onnx")
        and not name.endswith(".fp16.onnx")
        and os.path.isfile(name)
    )
    if not onnx_files:
        logging.warning("No .onnx files found in the current directory.")
        return
    for onnx_file in onnx_files:
        convert_to_fp16(onnx_file)


if __name__ == "__main__":
    main()