# Captured from a Hugging Face Spaces page (Space status at capture time: Sleeping).
import functools
import os
import time  # Import time for benchmarking

import gradio as gr
import numpy as np
import onnx
import onnxruntime
from PIL import Image, ImageDraw, ImageFont
# Constants
PROB_THRESHOLD = 0.5  # Minimum probability to show results
MODEL_PATH = os.path.join("onnx", "model.onnx")
LABELS_PATH = os.path.join("onnx", "labels.txt")

# Load labels: one class name per line, looked up later by integer class id.
# An explicit encoding keeps the result independent of the platform default,
# and splitlines() plus a per-line strip() stops CRLF files from leaving a
# trailing "\r" on every label.
with open(LABELS_PATH, "r", encoding="utf-8") as f:
    LABELS = [line.strip() for line in f.read().strip().splitlines()]
class Model:
    """Wrap an ONNX Runtime session with image preprocessing and timed inference.

    Two model metadata properties control preprocessing:

    * ``Image.BitmapPixelFormat == 'Bgr8'`` -> channels are reordered to BGR.
    * ``Image.NominalPixelRange == 'NominalRange_0_255'`` -> pixel values stay
      in [0, 255]; otherwise they are scaled to [0, 1].
    """

    def __init__(self, model_filepath):
        """Create the inference session and read the preprocessing metadata.

        Args:
            model_filepath: Path to the ``.onnx`` model file.
        """
        self.session = onnxruntime.InferenceSession(model_filepath)
        model_inputs = self.session.get_inputs()
        assert len(model_inputs) == 1, "expected a model with exactly one input"
        model_input = model_inputs[0]

        self.input_shape = model_input.shape[2:]  # (H, W)
        self.input_name = model_input.name
        # Unknown tensor types fall back to float32.
        self.input_type = {'tensor(float)': np.float32, 'tensor(float16)': np.float16}.get(
            model_input.type, np.float32
        )
        self.output_names = [o.name for o in self.session.get_outputs()]

        self.is_bgr = False
        self.is_range255 = False
        onnx_model = onnx.load(model_filepath)
        for metadata in onnx_model.metadata_props:
            if metadata.key == 'Image.BitmapPixelFormat' and metadata.value == 'Bgr8':
                self.is_bgr = True
            elif metadata.key == 'Image.NominalPixelRange' and metadata.value == 'NominalRange_0_255':
                self.is_range255 = True

    def predict(self, image: Image.Image):
        """Run the model on ``image`` and time the inference call.

        Args:
            image: Input picture; it is converted to RGB if it is in another mode.

        Returns:
            A ``(outputs, execution_time)`` tuple where ``outputs`` maps each
            model output name to its array and ``execution_time`` is the
            duration of the session run in milliseconds.
        """
        # The tensor layout below needs exactly three channels.
        if image.mode != "RGB":
            image = image.convert("RGB")

        # input_shape is stored as (H, W) but Pillow's resize expects (width, height).
        height, width = self.input_shape
        image_resized = image.resize((width, height))

        input_array = np.array(image_resized, dtype=np.float32)[np.newaxis, :, :, :]
        input_array = input_array.transpose((0, 3, 1, 2))  # (N, C, H, W)
        if self.is_bgr:
            input_array = input_array[:, (2, 1, 0), :, :]
        if not self.is_range255:
            input_array = input_array / 255.0  # Normalize to [0,1]
        # The transpose yields a strided view; hand the runtime a contiguous
        # tensor already cast to the dtype the model declares.
        input_tensor = np.ascontiguousarray(input_array, dtype=self.input_type)

        # perf_counter is monotonic, so the measurement cannot be skewed by
        # wall-clock adjustments the way time.time() can.
        start_time = time.perf_counter()
        outputs = self.session.run(self.output_names, {self.input_name: input_tensor})
        execution_time = (time.perf_counter() - start_time) * 1000  # Convert to milliseconds
        print(f"Inference time: {execution_time:.2f} ms")

        return dict(zip(self.output_names, outputs)), execution_time
# Bold DejaVu Sans at its common location on Linux.
_LABEL_FONT_PATH = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"
_LABEL_FONT_MIN_SIZE = 10
_LABEL_FONT_MAX_SIZE = 200  # Cap font size so the fitting loop always terminates


@functools.lru_cache(maxsize=256)
def _load_label_font(size):
    """Return the label font at ``size``, or None if the font file cannot be loaded."""
    try:
        return ImageFont.truetype(_LABEL_FONT_PATH, size=size)
    except OSError:
        return None


def _fit_label_font(draw, text, target_height):
    """Return the smallest font whose rendered ``text`` is at least ``target_height`` tall."""
    font_size = _LABEL_FONT_MIN_SIZE
    while True:
        font = _load_label_font(font_size)
        if font is None:
            # Font file unavailable: use Pillow's built-in font instead of failing.
            return ImageFont.load_default()
        text_bbox = draw.textbbox((0, 0), text, font=font)
        text_pixel_height = text_bbox[3] - text_bbox[1]
        if text_pixel_height >= target_height or font_size > _LABEL_FONT_MAX_SIZE:
            return font
        font_size += 1


def draw_boxes(image: Image.Image, outputs: dict):
    """Draw labelled bounding boxes for confident detections onto ``image``.

    Args:
        image: Picture to annotate. It is drawn on in place.
        outputs: Mapping that may contain 'detected_boxes', 'detected_classes'
            and 'detected_scores'; only the first entry of each is used.
            Missing or empty entries mean nothing is drawn.

    Returns:
        The same ``image`` object, annotated.
    """
    draw = ImageDraw.Draw(image, "RGBA")  # Use RGBA for transparency
    image_width, image_height = image.size

    def first_batch(key):
        values = outputs.get(key)
        if values is None or len(values) == 0:
            return []
        return values[0]

    detections = zip(
        first_batch('detected_boxes'),
        first_batch('detected_classes'),
        first_batch('detected_scores'),
    )
    for box, cls, score in detections:
        if score < PROB_THRESHOLD:
            continue

        class_id = int(cls)
        # Guard the lookup so an unexpected id neither crashes nor wraps around.
        label = LABELS[class_id] if 0 <= class_id < len(LABELS) else f"class {class_id}"

        # Assuming box format: [ymin, xmin, ymax, xmax] normalized [0,1]
        # NOTE(review): some exporters emit [x1, y1, x2, y2] instead - confirm
        # the order against the model's documentation.
        ymin, xmin, ymax, xmax = (float(v) for v in box)
        left = xmin * image_width
        right = xmax * image_width
        top = ymin * image_height
        bottom = ymax * image_height

        # Draw bounding box
        draw.rectangle([left, top, right, bottom], outline="red", width=4)

        text = f"{label}: {score:.2f}"

        # The label background spans the box width and 5% of the box height
        # (plus padding), placed just above the box and clamped to the image top.
        text_width = right - left
        text_height = (bottom - top) // 20
        label_top = max(top - text_height - 10, 0)
        label_left = left

        # Draw semi-transparent rectangle behind text
        draw.rectangle(
            [label_left, label_top, label_left + text_width + 10, label_top + text_height + 10],
            fill=(255, 0, 0, 160),  # Semi-transparent red
        )

        # Draw text with a font scaled to the label height
        draw.text(
            (label_left + 5, label_top + 5),
            text,
            fill="black",
            font=_fit_label_font(draw, text, text_height),
        )
    return image
# Initialize model: built once at import time from MODEL_PATH; detect_objects
# reads this module-level instance on every call.
model = Model(MODEL_PATH)
def detect_objects(image):
    """Run detection on ``image`` and build the annotated picture and text summary.

    Args:
        image: PIL image to analyse, or None when nothing was provided.

    Returns:
        A ``(annotated_image, detection_summary)`` tuple. The summary lists
        each detection scoring at least PROB_THRESHOLD as "label: score",
        followed by the inference time in milliseconds. With no image the
        result is ``(None, "No image provided.")``.
    """
    if image is None:
        return None, "No image provided."

    outputs, execution_time = model.predict(image)
    # convert() returns a new image, so the caller's picture is left untouched
    # and the copy is in the RGB mode that draw_boxes' RGBA drawing requires.
    annotated_image = draw_boxes(image.convert("RGB"), outputs)

    def first_batch(key):
        values = outputs.get(key)
        if values is None or len(values) == 0:
            return []
        return values[0]

    # Prepare detection summary
    detections = []
    rows = zip(
        first_batch('detected_boxes'),
        first_batch('detected_classes'),
        first_batch('detected_scores'),
    )
    for _box, cls, score in rows:
        if score < PROB_THRESHOLD:
            continue
        class_id = int(cls)
        # Guard the lookup so an unexpected id neither crashes nor wraps around.
        label = LABELS[class_id] if 0 <= class_id < len(LABELS) else f"class {class_id}"
        detections.append(f"{label}: {score:.2f}")

    detection_summary = "\n".join(detections) if detections else "No objects detected."
    detection_summary += f"\n\nInference Time: {execution_time:.2f} ms"
    return annotated_image, detection_summary
# Enhanced Gradio Interface with Links to Model Card and Repository
iface = gr.Interface(
    fn=detect_objects,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil", label="Detected Objects"),
        gr.Textbox(label="Detections"),
    ],
    # The emoji in the title and description had been mis-encoded (UTF-8 bytes
    # decoded with a Thai code page); they are restored here.
    title="JunkWaxHero ⚾ - Baseball Card Set Detection (ONNX Model)",
    description=(
        "Upload an image to identify the set of the baseball card (1980-1999).\n\n"
        # NOTE(review): the original link icon could not be recovered from the
        # garbled text; the link emoji is a best guess - confirm.
        "[🔗 Model Card & Repository](https://huggingface.co/enusbaum/JunkWaxHero)"
    ),
    examples=["examples/card1.jpg", "examples/card2.jpg", "examples/card3.jpg"],
    theme="default",  # You can choose other themes if desired
    allow_flagging="never",  # Disable flagging if not needed
)

if __name__ == "__main__":
    iface.launch()