# Eric P. Nusbaum
# Benchmarking
# cb58d33
import os
import numpy as np
import onnx
import onnxruntime
from PIL import Image, ImageDraw, ImageFont
import gradio as gr
import time # Import time for benchmarking
# Constants
PROB_THRESHOLD = 0.5  # Minimum detection score required to draw/report a result
MODEL_PATH = os.path.join("onnx", "model.onnx")
LABELS_PATH = os.path.join("onnx", "labels.txt")

# Load class labels, one per line; LABELS[i] is the display name for class index i.
# An explicit encoding keeps label decoding independent of the host's locale.
with open(LABELS_PATH, "r", encoding="utf-8") as f:
    LABELS = f.read().strip().split("\n")
class Model:
    """Wrapper around an ONNX Runtime session for a single-input detection model.

    Preprocessing options (BGR channel order, 0-255 pixel range) are read from
    the ONNX model's metadata properties.
    """

    def __init__(self, model_filepath):
        """Create the inference session and inspect the model's input/outputs.

        Args:
            model_filepath: Path to the ``.onnx`` model file.

        Raises:
            ValueError: If the model does not have exactly one input.
        """
        self.session = onnxruntime.InferenceSession(model_filepath)
        inputs = self.session.get_inputs()
        if len(inputs) != 1:
            raise ValueError(f"Expected exactly one model input, got {len(inputs)}")
        model_input = inputs[0]
        # The input is fed as (N, C, H, W); keep only the spatial part.
        self.input_shape = model_input.shape[2:]  # (H, W)
        self.input_name = model_input.name
        # Map the ONNX element type to a NumPy dtype, defaulting to float32.
        self.input_type = {'tensor(float)': np.float32, 'tensor(float16)': np.float16}.get(
            model_input.type, np.float32
        )
        self.output_names = [o.name for o in self.session.get_outputs()]

        # Preprocessing flags come from the exported model's metadata.
        self.is_bgr = False
        self.is_range255 = False
        onnx_model = onnx.load(model_filepath)
        for metadata in onnx_model.metadata_props:
            if metadata.key == 'Image.BitmapPixelFormat' and metadata.value == 'Bgr8':
                self.is_bgr = True
            elif metadata.key == 'Image.NominalPixelRange' and metadata.value == 'NominalRange_0_255':
                self.is_range255 = True

    def predict(self, image: Image.Image):
        """Run the model on *image* and time the inference call.

        Args:
            image: Input PIL image in any mode; it is converted to RGB first.

        Returns:
            A ``(outputs, execution_time)`` tuple. ``outputs`` maps each model
            output name to its array; ``execution_time`` is the duration of
            ``session.run`` in milliseconds.
        """
        height, width = self.input_shape
        # PIL's resize takes (width, height) while input_shape is (H, W).
        # Converting to RGB guarantees exactly 3 channels (RGBA/L/P uploads).
        image_resized = image.convert("RGB").resize((width, height))
        input_array = np.array(image_resized, dtype=np.float32)[np.newaxis, :, :, :]  # (N, H, W, C)
        input_array = input_array.transpose((0, 3, 1, 2))  # (N, C, H, W)
        if self.is_bgr:
            input_array = input_array[:, (2, 1, 0), :, :]
        if not self.is_range255:
            input_array = input_array / 255.0  # Normalize to [0, 1]
        feed = {self.input_name: input_array.astype(self.input_type)}

        # Time only session.run; perf_counter is monotonic and high-resolution.
        start_time = time.perf_counter()
        outputs = self.session.run(self.output_names, feed)
        execution_time = (time.perf_counter() - start_time) * 1000  # milliseconds
        print(f"Inference time: {execution_time:.2f} ms")
        return dict(zip(self.output_names, outputs)), execution_time
# Common DejaVu location on Debian/Ubuntu hosts; elsewhere _fit_font falls back
# to Pillow's built-in default font.
_LABEL_FONT_PATH = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"
_MIN_FONT_SIZE = 10
_MAX_FONT_SIZE = 201  # Upper bound of the font-size search.


def _first_batch(outputs, key):
    """Return ``outputs[key][0]``, or an empty list if the key is missing or empty."""
    values = outputs.get(key)
    if values is None or len(values) == 0:
        return []
    return values[0]


def _fit_font(draw, text, target_height, font_path=_LABEL_FONT_PATH):
    """Return the smallest font (size in [10, 201]) rendering *text* at least *target_height* px tall.

    If no size reaches the target, size 201 is used. Binary search assumes the
    rendered height grows with font size, so only a handful of font loads are
    needed. Returns Pillow's default font if *font_path* cannot be opened.
    """
    def load(size):
        return ImageFont.truetype(font_path, size=size)

    def height_of(font):
        _, top, _, bottom = draw.textbbox((0, 0), text, font=font)
        return bottom - top

    try:
        font = load(_MIN_FONT_SIZE)
    except OSError:
        # Font file not present on this host (e.g. Windows/macOS).
        return ImageFont.load_default()
    if height_of(font) >= target_height:
        return font

    lo, hi = _MIN_FONT_SIZE + 1, _MAX_FONT_SIZE
    while lo < hi:
        mid = (lo + hi) // 2
        if height_of(load(mid)) >= target_height:
            hi = mid
        else:
            lo = mid + 1
    return load(lo)


def draw_boxes(image: Image.Image, outputs: dict):
    """Draw a labelled bounding box for every detection scoring >= PROB_THRESHOLD.

    Args:
        image: PIL image to annotate; it is modified in place.
        outputs: Model outputs keyed by name. Reads ``detected_boxes``,
            ``detected_classes`` and ``detected_scores``, taking element [0] of
            each. Missing keys mean nothing is drawn.

    Returns:
        The same *image* object, annotated.
    """
    draw = ImageDraw.Draw(image, "RGBA")  # RGBA so the label background can be translucent
    image_width, image_height = image.size

    boxes = _first_batch(outputs, 'detected_boxes')
    classes = _first_batch(outputs, 'detected_classes')
    scores = _first_batch(outputs, 'detected_scores')

    for box, cls, score in zip(boxes, classes, scores):
        if score < PROB_THRESHOLD:
            continue
        label = LABELS[int(cls)]

        # NOTE(review): box format assumed to be [ymin, xmin, ymax, xmax]
        # normalized to [0, 1] — confirm against the exported model.
        ymin, xmin, ymax, xmax = box
        left, right = xmin * image_width, xmax * image_width
        top, bottom = ymin * image_height, ymax * image_height
        draw.rectangle([left, top, right, bottom], outline="red", width=4)

        text = f"{label}: {score:.2f}"
        # Target text height: 5% of the box height.
        text_height = (bottom - top) // 20
        font = _fit_font(draw, text, text_height)
        _, _, text_right, text_bottom = draw.textbbox((0, 0), text, font=font)

        # Label sits just above the box (clamped to the image top). The
        # background is at least box-wide and grows to contain the rendered text.
        label_top = max(top - text_height - 10, 0)
        label_left = left
        label_width = max(right - left, text_right)
        label_height = max(text_height, text_bottom)
        draw.rectangle(
            [label_left, label_top, label_left + label_width + 10, label_top + label_height + 10],
            fill=(255, 0, 0, 160),  # Semi-transparent red
        )
        draw.text((label_left + 5, label_top + 5), text, fill="black", font=font)
    return image
# Build the detector once at import time; every request reuses this instance.
model = Model(model_filepath=MODEL_PATH)
def detect_objects(image):
    """Gradio callback: detect objects in *image* and summarise the results.

    Args:
        image: Uploaded PIL image, or ``None`` when the input is empty.

    Returns:
        ``(annotated_image, detection_summary)``. The summary lists each
        detection scoring >= PROB_THRESHOLD followed by the inference time.
        For ``None`` input, returns ``(None, "No image provided.")``.
    """
    if image is None:
        # Submitting with no upload passes None; don't crash inside predict().
        return None, "No image provided."

    outputs, execution_time = model.predict(image)
    annotated_image = draw_boxes(image.copy(), outputs)

    def first_batch(key):
        # Element [0] of an output, or [] if that output is absent/empty.
        values = outputs.get(key)
        return [] if values is None or len(values) == 0 else values[0]

    detections = []
    for _box, cls, score in zip(
        first_batch('detected_boxes'),
        first_batch('detected_classes'),
        first_batch('detected_scores'),
    ):
        if score < PROB_THRESHOLD:
            continue
        detections.append(f"{LABELS[int(cls)]}: {score:.2f}")

    detection_summary = "\n".join(detections) if detections else "No objects detected."
    detection_summary += f"\n\nInference Time: {execution_time:.2f} ms"
    return annotated_image, detection_summary
# Gradio interface with a link to the model card and repository.
iface = gr.Interface(
    fn=detect_objects,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil", label="Detected Objects"),
        gr.Textbox(label="Detections"),
    ],
    title="JunkWaxHero ⚾ - Baseball Card Set Detection (ONNX Model)",
    description=(
        "Upload an image to identify the set of the baseball card (1980-1999).\n\n"
        "[🔗 Model Card & Repository](https://huggingface.co/enusbaum/JunkWaxHero)"
    ),
    examples=["examples/card1.jpg", "examples/card2.jpg", "examples/card3.jpg"],
    theme="default",  # You can choose other themes if desired
    # NOTE(review): newer Gradio releases rename this option to `flagging_mode`;
    # confirm against the pinned Gradio version.
    allow_flagging="never",  # Disable flagging
)

if __name__ == "__main__":
    iface.launch()