|
import os |
|
import sys |
|
import onnx |
|
from onnxslim import slim |
|
from onnxconverter_common.float16 import convert_float_to_float16 |
|
from onnxconverter_common.optimizer import optimize_onnx_model |
|
import logging |
|
|
|
# Configure root logging once at import time so every logging.* call below
# emits timestamped, leveled messages to stderr.
logging.basicConfig(

    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"

)
|
|
|
|
|
def convert_to_fp16(input_onnx_model):
    """Convert an FP32 ONNX model file to FP16 and save it next to the input.

    The output name is derived by stripping a trailing ``.fp32`` tag from the
    input's basename (if present) and appending ``.fp16.onnx``. The conversion
    is skipped when the input already looks like an FP16 artifact or when the
    output file already exists, so the function is safe to re-run.

    Args:
        input_onnx_model: Path to the source ``.onnx`` file.

    Returns:
        None. The converted model is written to disk as a side effect.
    """
    base_name, _ = os.path.splitext(input_onnx_model)

    # Don't re-convert a model this script already produced
    # ("model.fp16.onnx" would otherwise become "model.fp16.fp16.onnx").
    if base_name.endswith(".fp16"):
        logging.info("%s is already FP16. Skipping conversion.", input_onnx_model)
        return

    # "model.fp32.onnx" -> "model.fp16.onnx" rather than "model.fp32.fp16.onnx".
    if base_name.endswith(".fp32"):
        base_name = base_name[: -len(".fp32")]

    output_onnx_model = base_name + ".fp16.onnx"

    if os.path.exists(output_onnx_model):
        logging.info(
            "FP16 version %s already exists. Skipping conversion.",
            output_onnx_model,
        )
        return

    logging.info("Starting conversion for %s", input_onnx_model)

    logging.info("Simplifying model %s", input_onnx_model)
    model = slim(input_onnx_model)

    # Deep graphs can exceed the default recursion limit during the
    # optimization passes below.
    sys.setrecursionlimit(10000)

    logging.info(
        "Performing shape inference and quant pre-process for %s",
        input_onnx_model,
    )
    model_optimized = optimize_onnx_model(model)

    logging.info("Converting %s to FP16", input_onnx_model)
    # keep_io_types=True keeps graph inputs/outputs in FP32 so callers do not
    # have to change the dtypes they feed and read.
    model_fp16 = convert_float_to_float16(model_optimized, keep_io_types=True)

    logging.info("Saving FP16 model to %s", output_onnx_model)
    onnx.save(model_fp16, output_onnx_model)
    logging.info("FP16 model saved to %s", output_onnx_model)
|
|
|
|
|
def main():
    """Convert every FP32 ``.onnx`` model in the current directory to FP16.

    Files that already carry the ``.fp16.onnx`` suffix are excluded so that
    re-running the script does not feed its own outputs back into conversion.
    """
    onnx_files = [
        f
        for f in os.listdir(".")
        if f.endswith(".onnx") and not f.endswith(".fp16.onnx")
    ]

    if not onnx_files:
        logging.warning("No .onnx files found in the current directory.")
        return

    for onnx_file in onnx_files:
        convert_to_fp16(onnx_file)
|
|
|
|
|
# Standard script entry guard: only run the batch conversion when executed
# directly, not when imported as a module.
if __name__ == "__main__":

    main()
|
|