Spaces:
Sleeping
Sleeping
Eric P. Nusbaum
commited on
Commit
·
de40de9
1
Parent(s):
32eca4a
Update Space
Browse files
app.py
CHANGED
@@ -1,3 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
def draw_boxes(image: Image.Image, outputs: dict):
|
2 |
draw = ImageDraw.Draw(image, "RGBA") # Use RGBA for transparency
|
3 |
|
@@ -60,3 +110,45 @@ def draw_boxes(image: Image.Image, outputs: dict):
|
|
60 |
)
|
61 |
|
62 |
return image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import onnx
|
4 |
+
import onnxruntime
|
5 |
+
from PIL import Image, ImageDraw, ImageFont
|
6 |
+
import gradio as gr
|
7 |
+
|
8 |
+
# Constants
|
9 |
+
PROB_THRESHOLD = 0.5 # Minimum probability to show results
|
10 |
+
MODEL_PATH = os.path.join("onnx", "model.onnx")
|
11 |
+
LABELS_PATH = os.path.join("onnx", "labels.txt")
|
12 |
+
|
13 |
+
# Load labels
|
14 |
+
with open(LABELS_PATH, "r") as f:
|
15 |
+
LABELS = f.read().strip().split("\n")
|
16 |
+
|
17 |
+
class Model:
|
18 |
+
def __init__(self, model_filepath):
|
19 |
+
self.session = onnxruntime.InferenceSession(model_filepath)
|
20 |
+
assert len(self.session.get_inputs()) == 1
|
21 |
+
self.input_shape = self.session.get_inputs()[0].shape[2:] # (H, W)
|
22 |
+
self.input_name = self.session.get_inputs()[0].name
|
23 |
+
self.input_type = {'tensor(float)': np.float32, 'tensor(float16)': np.float16}.get(
|
24 |
+
self.session.get_inputs()[0].type, np.float32
|
25 |
+
)
|
26 |
+
self.output_names = [o.name for o in self.session.get_outputs()]
|
27 |
+
|
28 |
+
self.is_bgr = False
|
29 |
+
self.is_range255 = False
|
30 |
+
onnx_model = onnx.load(model_filepath)
|
31 |
+
for metadata in onnx_model.metadata_props:
|
32 |
+
if metadata.key == 'Image.BitmapPixelFormat' and metadata.value == 'Bgr8':
|
33 |
+
self.is_bgr = True
|
34 |
+
elif metadata.key == 'Image.NominalPixelRange' and metadata.value == 'NominalRange_0_255':
|
35 |
+
self.is_range255 = True
|
36 |
+
|
37 |
+
def predict(self, image: Image.Image):
|
38 |
+
# Preprocess image
|
39 |
+
image_resized = image.resize(self.input_shape)
|
40 |
+
input_array = np.array(image_resized, dtype=np.float32)[np.newaxis, :, :, :]
|
41 |
+
input_array = input_array.transpose((0, 3, 1, 2)) # (N, C, H, W)
|
42 |
+
if self.is_bgr:
|
43 |
+
input_array = input_array[:, (2, 1, 0), :, :]
|
44 |
+
if not self.is_range255:
|
45 |
+
input_array = input_array / 255.0 # Normalize to [0,1]
|
46 |
+
|
47 |
+
# Run inference
|
48 |
+
outputs = self.session.run(self.output_names, {self.input_name: input_array.astype(self.input_type)})
|
49 |
+
return {name: outputs[i] for i, name in enumerate(self.output_names)}
|
50 |
+
|
51 |
def draw_boxes(image: Image.Image, outputs: dict):
|
52 |
draw = ImageDraw.Draw(image, "RGBA") # Use RGBA for transparency
|
53 |
|
|
|
110 |
)
|
111 |
|
112 |
return image
|
113 |
+
|
114 |
+
# Initialize model
|
115 |
+
model = Model(MODEL_PATH)
|
116 |
+
|
117 |
+
def detect_objects(image):
|
118 |
+
outputs = model.predict(image)
|
119 |
+
annotated_image = draw_boxes(image.copy(), outputs)
|
120 |
+
|
121 |
+
# Prepare detection summary
|
122 |
+
detections = []
|
123 |
+
boxes = outputs.get('detected_boxes', [])
|
124 |
+
classes = outputs.get('detected_classes', [])
|
125 |
+
scores = outputs.get('detected_scores', [])
|
126 |
+
|
127 |
+
for box, cls, score in zip(boxes[0], classes[0], scores[0]):
|
128 |
+
if score < PROB_THRESHOLD:
|
129 |
+
continue
|
130 |
+
label = LABELS[int(cls)]
|
131 |
+
detections.append(f"{label}: {score:.2f}")
|
132 |
+
|
133 |
+
detection_summary = "\n".join(detections) if detections else "No objects detected."
|
134 |
+
|
135 |
+
return annotated_image, detection_summary
|
136 |
+
|
137 |
+
# Gradio Interface
|
138 |
+
iface = gr.Interface(
|
139 |
+
fn=detect_objects,
|
140 |
+
inputs=gr.Image(type="pil"),
|
141 |
+
outputs=[
|
142 |
+
gr.Image(type="pil", label="Detected Objects"),
|
143 |
+
gr.Textbox(label="Detections")
|
144 |
+
],
|
145 |
+
title="JunkWaxHero - Baseball Card Set Detection (ONNX Model)",
|
146 |
+
description="Upload an image to itentify the set of the baseball card.",
|
147 |
+
examples=["examples/card1.jpg", "examples/card2.jpg", "examples/card3.jpg"],
|
148 |
+
theme="default", # You can choose other themes if desired
|
149 |
+
allow_flagging="never" # Disable flagging if not needed
|
150 |
+
# Removed 'layout' parameter
|
151 |
+
)
|
152 |
+
|
153 |
+
if __name__ == "__main__":
|
154 |
+
iface.launch()
|