Spaces:
Sleeping
Sleeping
Eric P. Nusbaum
commited on
Commit
·
0b444ec
1
Parent(s):
f60a7c0
Update Space
Browse files
app.py
CHANGED
@@ -1,152 +1,144 @@
|
|
1 |
import os
|
2 |
import numpy as np
|
3 |
-
import onnx
|
4 |
import onnxruntime
|
5 |
from PIL import Image, ImageDraw, ImageFont
|
6 |
import gradio as gr
|
7 |
|
8 |
-
#
|
9 |
-
PROB_THRESHOLD = 0.5 # Minimum probability to show results
|
10 |
MODEL_PATH = os.path.join("onnx", "model.onnx")
|
11 |
LABELS_PATH = os.path.join("onnx", "labels.txt")
|
12 |
|
13 |
# Load labels
|
14 |
with open(LABELS_PATH, "r") as f:
|
15 |
-
LABELS =
|
16 |
|
|
|
17 |
class Model:
|
18 |
def __init__(self, model_filepath):
|
|
|
19 |
self.session = onnxruntime.InferenceSession(model_filepath)
|
20 |
-
|
|
|
|
|
|
|
|
|
21 |
self.input_shape = self.session.get_inputs()[0].shape[2:] # (H, W)
|
22 |
self.input_name = self.session.get_inputs()[0].name
|
23 |
-
self.input_type = {
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
|
|
|
|
|
|
|
|
28 |
self.is_bgr = False
|
29 |
self.is_range255 = False
|
30 |
-
|
31 |
-
|
32 |
-
|
|
|
|
|
33 |
self.is_bgr = True
|
34 |
-
elif
|
35 |
self.is_range255 = True
|
36 |
|
37 |
-
def predict(self, image
|
38 |
# Preprocess image
|
39 |
image_resized = image.resize(self.input_shape)
|
40 |
input_array = np.array(image_resized, dtype=np.float32)[np.newaxis, :, :, :]
|
41 |
input_array = input_array.transpose((0, 3, 1, 2)) # (N, C, H, W)
|
|
|
42 |
if self.is_bgr:
|
43 |
-
input_array = input_array[:, (2, 1, 0), :, :]
|
|
|
44 |
if not self.is_range255:
|
45 |
input_array = input_array / 255.0 # Normalize to [0,1]
|
46 |
-
|
|
|
|
|
|
|
47 |
# Run inference
|
48 |
-
outputs = self.session.run(self.output_names, {self.input_name:
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
try:
|
58 |
-
|
59 |
-
font = ImageFont.truetype("arial.ttf", size=font_size)
|
60 |
except IOError:
|
61 |
-
# Fallback to default font if truetype font is not found
|
62 |
font = ImageFont.load_default()
|
63 |
|
64 |
-
|
65 |
-
|
66 |
-
scores = outputs.get('detected_scores', [])
|
67 |
-
|
68 |
-
for box, cls, score in zip(boxes[0], classes[0], scores[0]):
|
69 |
-
if score < PROB_THRESHOLD:
|
70 |
continue
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
text_bbox = draw.textbbox((0, 0), text, font=font)
|
88 |
-
text_width = text_bbox[2] - text_bbox[0]
|
89 |
-
text_height = text_bbox[3] - text_bbox[1]
|
90 |
-
|
91 |
-
# Calculate label background position
|
92 |
-
# Ensure the label box does not go above the image
|
93 |
-
label_top = max(top - text_height - 10, 0)
|
94 |
-
label_left = left
|
95 |
-
|
96 |
-
# Draw semi-transparent rectangle behind text
|
97 |
-
draw.rectangle(
|
98 |
-
[label_left, label_top, label_left + text_width + 10, label_top + text_height + 10],
|
99 |
-
fill=(255, 0, 0, 160) # Semi-transparent red
|
100 |
-
)
|
101 |
-
|
102 |
-
# Draw text
|
103 |
-
draw.text(
|
104 |
-
(label_left + 5, label_top + 5),
|
105 |
-
text,
|
106 |
-
fill="white",
|
107 |
-
font=font
|
108 |
-
)
|
109 |
|
110 |
return image
|
111 |
|
112 |
-
#
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
|
135 |
-
# Gradio Interface
|
136 |
iface = gr.Interface(
|
137 |
-
fn=
|
138 |
inputs=gr.Image(type="pil"),
|
139 |
-
outputs=
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
description="Upload an image to detect objects using the ONNX model.",
|
145 |
-
examples=["examples/card1.jpg", "examples/card2.jpg", "examples/card3.jpg"],
|
146 |
-
theme="default", # You can choose other themes if desired
|
147 |
-
allow_flagging="never" # Disable flagging if not needed
|
148 |
-
# Removed 'layout' parameter
|
149 |
)
|
150 |
|
|
|
151 |
if __name__ == "__main__":
|
152 |
iface.launch()
|
|
|
1 |
import os
|
2 |
import numpy as np
|
|
|
3 |
import onnxruntime
|
4 |
from PIL import Image, ImageDraw, ImageFont
|
5 |
import gradio as gr
|
6 |
|
7 |
+
# Define paths
|
|
|
8 |
MODEL_PATH = os.path.join("onnx", "model.onnx")
|
9 |
LABELS_PATH = os.path.join("onnx", "labels.txt")
|
10 |
|
11 |
# Load labels
|
12 |
with open(LABELS_PATH, "r") as f:
|
13 |
+
LABELS = [line.strip() for line in f.readlines()]
|
14 |
|
15 |
+
# Initialize ONNX Runtime session
|
16 |
class Model:
|
17 |
def __init__(self, model_filepath):
|
18 |
+
# Initialize the InferenceSession
|
19 |
self.session = onnxruntime.InferenceSession(model_filepath)
|
20 |
+
|
21 |
+
# Ensure the model has exactly one input
|
22 |
+
assert len(self.session.get_inputs()) == 1, "Model should have exactly one input."
|
23 |
+
|
24 |
+
# Extract input details
|
25 |
self.input_shape = self.session.get_inputs()[0].shape[2:] # (H, W)
|
26 |
self.input_name = self.session.get_inputs()[0].name
|
27 |
+
self.input_type = {
|
28 |
+
'tensor(float)': np.float32,
|
29 |
+
'tensor(float16)': np.float16
|
30 |
+
}.get(self.session.get_inputs()[0].type, np.float32)
|
31 |
+
|
32 |
+
# Extract output names
|
33 |
+
self.output_names = [output.name for output in self.session.get_outputs()]
|
34 |
+
|
35 |
+
# Default preprocessing flags
|
36 |
self.is_bgr = False
|
37 |
self.is_range255 = False
|
38 |
+
|
39 |
+
# Retrieve metadata from the model
|
40 |
+
metadata_map = self.session.get_modelmeta().custom_metadata_map
|
41 |
+
for key, value in metadata_map.items():
|
42 |
+
if key == 'Image.BitmapPixelFormat' and value == 'Bgr8':
|
43 |
self.is_bgr = True
|
44 |
+
elif key == 'Image.NominalPixelRange' and value == 'NominalRange_0_255':
|
45 |
self.is_range255 = True
|
46 |
|
47 |
+
def predict(self, image):
|
48 |
# Preprocess image
|
49 |
image_resized = image.resize(self.input_shape)
|
50 |
input_array = np.array(image_resized, dtype=np.float32)[np.newaxis, :, :, :]
|
51 |
input_array = input_array.transpose((0, 3, 1, 2)) # (N, C, H, W)
|
52 |
+
|
53 |
if self.is_bgr:
|
54 |
+
input_array = input_array[:, (2, 1, 0), :, :] # Convert RGB to BGR
|
55 |
+
|
56 |
if not self.is_range255:
|
57 |
input_array = input_array / 255.0 # Normalize to [0,1]
|
58 |
+
|
59 |
+
# Prepare input tensor
|
60 |
+
input_tensor = input_array.astype(self.input_type)
|
61 |
+
|
62 |
# Run inference
|
63 |
+
outputs = self.session.run(self.output_names, {self.input_name: input_tensor})
|
64 |
+
|
65 |
+
# Process outputs
|
66 |
+
# Assuming outputs are in the format: [boxes, labels, scores]
|
67 |
+
# Adjust based on your actual model's output format
|
68 |
+
if len(outputs) >= 3:
|
69 |
+
boxes = outputs[0] # shape: [num_detections, 4]
|
70 |
+
labels = outputs[1].astype(int) # shape: [num_detections]
|
71 |
+
scores = outputs[2] # shape: [num_detections]
|
72 |
+
return boxes, labels, scores
|
73 |
+
else:
|
74 |
+
raise ValueError("Unexpected number of outputs from the model.")
|
75 |
+
|
76 |
+
# Load the model
|
77 |
+
model = Model(MODEL_PATH)
|
78 |
|
79 |
+
# Function to draw bounding boxes
|
80 |
+
def draw_boxes(image, boxes, labels, scores, threshold=0.5):
|
81 |
+
draw = ImageDraw.Draw(image)
|
82 |
try:
|
83 |
+
font = ImageFont.truetype("arial.ttf", 15)
|
|
|
84 |
except IOError:
|
|
|
85 |
font = ImageFont.load_default()
|
86 |
|
87 |
+
for box, label, score in zip(boxes, labels, scores):
|
88 |
+
if score < threshold:
|
|
|
|
|
|
|
|
|
89 |
continue
|
90 |
+
# Assuming box format is [xmin, ymin, xmax, ymax] normalized [0,1]
|
91 |
+
xmin, ymin, xmax, ymax = box
|
92 |
+
width, height = image.size
|
93 |
+
xmin = int(xmin * width)
|
94 |
+
ymin = int(ymin * height)
|
95 |
+
xmax = int(xmax * width)
|
96 |
+
ymax = int(ymax * height)
|
97 |
+
|
98 |
+
# Draw rectangle
|
99 |
+
draw.rectangle([(xmin, ymin), (xmax, ymax)], outline="red", width=2)
|
100 |
+
|
101 |
+
# Draw label
|
102 |
+
label_text = f"{LABELS[label]}: {score:.2f}"
|
103 |
+
text_size = draw.textsize(label_text, font=font)
|
104 |
+
draw.rectangle([(xmin, ymin - text_size[1]), (xmin + text_size[0], ymin)], fill="red")
|
105 |
+
draw.text((xmin, ymin - text_size[1]), label_text, fill="white", font=font)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
|
107 |
return image
|
108 |
|
109 |
+
# Prediction function for Gradio
|
110 |
+
def predict_image(input_image):
|
111 |
+
boxes, labels, scores = model.predict(input_image)
|
112 |
+
output_image = input_image.copy()
|
113 |
+
output_image = draw_boxes(output_image, boxes, labels, scores, threshold=0.5)
|
114 |
+
return output_image
|
115 |
+
|
116 |
+
# Define Gradio Interface
|
117 |
+
def get_example_images():
|
118 |
+
examples_dir = "examples"
|
119 |
+
return [
|
120 |
+
os.path.join(examples_dir, img)
|
121 |
+
for img in os.listdir(examples_dir)
|
122 |
+
if img.lower().endswith(('.png', '.jpg', '.jpeg'))
|
123 |
+
]
|
124 |
+
|
125 |
+
example_images = get_example_images()
|
126 |
+
|
127 |
+
title = "JunkWaxHero: Object Detection for Junk Wax Baseball Cards"
|
128 |
+
description = """
|
129 |
+
Upload an image of a Junk Wax Baseball Card, and the model will identify the card by its set (1980-1999).
|
130 |
+
"""
|
131 |
|
|
|
132 |
iface = gr.Interface(
|
133 |
+
fn=predict_image,
|
134 |
inputs=gr.Image(type="pil"),
|
135 |
+
outputs=gr.Image(type="pil"),
|
136 |
+
examples=example_images,
|
137 |
+
title=title,
|
138 |
+
description=description,
|
139 |
+
allow_flagging="never"
|
|
|
|
|
|
|
|
|
|
|
140 |
)
|
141 |
|
142 |
+
# Launch the interface
|
143 |
if __name__ == "__main__":
|
144 |
iface.launch()
|