amaye15 commited on
Commit
4cc7068
1 Parent(s): dffa7f9

Update export_to_onnx.py

Browse files
Files changed (1) hide show
  1. export_to_onnx.py +47 -0
export_to_onnx.py CHANGED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModel, AutoProcessor
3
+ from PIL import Image
4
+ import requests
5
+ import os
6
+ import onnxruntime as ort
7
+ import numpy as np
8
+
9
+ # Constants
10
+ MODEL_NAME = "amaye15/DaViT-Florence-2-large-ft"
11
+ CACHE_DIR = os.getcwd()
12
+ PROMPT = "<OCR>"
13
+ IMAGE_URL = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"
14
+ ONNX_MODEL_PATH = "model.onnx"
15
+
16
+ # Load the model and processor
17
+ model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True, cache_dir=CACHE_DIR)
18
+ processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True, cache_dir=CACHE_DIR)
19
+
20
+ # Prepare the input
21
+ image = Image.open(requests.get(IMAGE_URL, stream=True).raw)
22
+ inputs = processor(text=PROMPT, images=image, return_tensors="pt")
23
+
24
+ # Export the model to ONNX
25
+ input_names = ["pixel_values"]
26
+ output_names = ["output"]
27
+ torch.onnx.export(
28
+ model,
29
+ inputs["pixel_values"],
30
+ ONNX_MODEL_PATH,
31
+ input_names=input_names,
32
+ output_names=output_names,
33
+ dynamic_axes={"pixel_values": {0: "batch_size"}, "output": {0: "batch_size"}},
34
+ opset_version=11
35
+ )
36
+
37
+ # Load the ONNX model
38
+ ort_session = ort.InferenceSession(ONNX_MODEL_PATH)
39
+
40
+ # Prepare the inputs for ONNX model
41
+ ort_inputs = {"pixel_values": inputs["pixel_values"].numpy()}
42
+
43
+ # Run the ONNX model
44
+ ort_outs = ort_session.run(None, ort_inputs)
45
+
46
+ # Display the output
47
+ print(ort_outs)