import os
import time  # For benchmarking inference

import numpy as np
import onnx
import onnxruntime
import gradio as gr
from PIL import Image, ImageDraw, ImageFont

# Constants
PROB_THRESHOLD = 0.5  # Minimum probability to show results
MODEL_PATH = os.path.join("onnx", "model.onnx")
LABELS_PATH = os.path.join("onnx", "labels.txt")

# Load labels
with open(LABELS_PATH, "r") as f:
    LABELS = f.read().strip().split("\n")


class Model:
    def __init__(self, model_filepath):
        self.session = onnxruntime.InferenceSession(model_filepath)
        assert len(self.session.get_inputs()) == 1
        self.input_shape = self.session.get_inputs()[0].shape[2:]  # (H, W)
        self.input_name = self.session.get_inputs()[0].name
        self.input_type = {
            'tensor(float)': np.float32,
            'tensor(float16)': np.float16,
        }.get(self.session.get_inputs()[0].type, np.float32)
        self.output_names = [o.name for o in self.session.get_outputs()]

        # Read preprocessing hints stored in the model's ONNX metadata
        self.is_bgr = False
        self.is_range255 = False
        onnx_model = onnx.load(model_filepath)
        for metadata in onnx_model.metadata_props:
            if metadata.key == 'Image.BitmapPixelFormat' and metadata.value == 'Bgr8':
                self.is_bgr = True
            elif metadata.key == 'Image.NominalPixelRange' and metadata.value == 'NominalRange_0_255':
                self.is_range255 = True

    def predict(self, image: Image.Image):
        # Preprocess image. PIL's resize() expects (width, height), so reverse
        # the (H, W) input shape; convert("RGB") guards against RGBA/grayscale uploads.
        image_resized = image.convert("RGB").resize(self.input_shape[::-1])
        input_array = np.array(image_resized, dtype=np.float32)[np.newaxis, :, :, :]
        input_array = input_array.transpose((0, 3, 1, 2))  # (N, C, H, W)
        if self.is_bgr:
            input_array = input_array[:, (2, 1, 0), :, :]
        if not self.is_range255:
            input_array = input_array / 255.0  # Normalize to [0, 1]

        # Run inference, timing just the session call
        start_time = time.time()
        outputs = self.session.run(
            self.output_names,
            {self.input_name: input_array.astype(self.input_type)},
        )
        end_time = time.time()
        execution_time = (end_time - start_time) * 1000  # Convert to milliseconds
        print(f"Inference time: {execution_time:.2f} ms")

        return {name: outputs[i] for i, name in enumerate(self.output_names)}, execution_time


def draw_boxes(image: Image.Image, outputs: dict):
    draw = ImageDraw.Draw(image, "RGBA")  # RGBA mode so fills can be semi-transparent
    image_width, image_height = image.size

    # Default to [[]] so indexing boxes[0] is safe when a key is missing
    boxes = outputs.get('detected_boxes', [[]])
    classes = outputs.get('detected_classes', [[]])
    scores = outputs.get('detected_scores', [[]])

    for box, cls, score in zip(boxes[0], classes[0], scores[0]):
        if score < PROB_THRESHOLD:
            continue
        label = LABELS[int(cls)]

        # Assuming box format: [ymin, xmin, ymax, xmax], normalized to [0, 1]
        ymin, xmin, ymax, xmax = box
        left = xmin * image_width
        right = xmax * image_width
        top = ymin * image_height
        bottom = ymax * image_height

        # Draw bounding box
        draw.rectangle([left, top, right, bottom], outline="red", width=4)

        # Prepare label text
        text = f"{label}: {score:.2f}"

        # Label background spans the box width; its height is 5% of the box height
        text_width = right - left
        text_height = (bottom - top) // 20

        # Place the label background just above the box, clamped to the image edge
        label_top = max(top - text_height - 10, 0)
        label_left = left

        # Draw semi-transparent rectangle behind the text
        draw.rectangle(
            [label_left, label_top, label_left + text_width + 10, label_top + text_height + 10],
            fill=(255, 0, 0, 160)  # Semi-transparent red
        )

        # Grow the font until the text roughly fills the label height,
        # capping the size to prevent an unbounded loop
        font_size = 10
        font_path = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"  # Common path on Linux
        while True:
            try:
                font = ImageFont.truetype(font_path, size=font_size)
            except OSError:
                # Fall back to PIL's built-in bitmap font if the TTF is missing
                font = ImageFont.load_default()
                break
            text_bbox = draw.textbbox((0, 0), text, font=font)
            text_pixel_height = text_bbox[3] - text_bbox[1]
            if text_pixel_height >= text_height or font_size > 200:
                break
            font_size += 1

        # Draw text with the scaled font
        draw.text(
            (label_left + 5, label_top + 5),
            text,
            fill="black",
            font=font
        )

    return image


# Initialize model
model = Model(MODEL_PATH)


def detect_objects(image):
    outputs, execution_time = model.predict(image)
    annotated_image = draw_boxes(image.copy(), outputs)

    # Prepare detection summary
    detections = []
    boxes = outputs.get('detected_boxes', [[]])
    classes = outputs.get('detected_classes', [[]])
    scores = outputs.get('detected_scores', [[]])
    for box, cls, score in zip(boxes[0], classes[0], scores[0]):
        if score < PROB_THRESHOLD:
            continue
        label = LABELS[int(cls)]
        detections.append(f"{label}: {score:.2f}")

    detection_summary = "\n".join(detections) if detections else "No objects detected."
    detection_summary += f"\n\nInference Time: {execution_time:.2f} ms"

    return annotated_image, detection_summary


# Gradio interface with links to the model card and repository
iface = gr.Interface(
    fn=detect_objects,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil", label="Detected Objects"),
        gr.Textbox(label="Detections"),
    ],
    title="JunkWaxHero ⚾ - Baseball Card Set Detection (ONNX Model)",
    description=(
        "Upload an image to identify the set of the baseball card (1980-1999).\n\n"
        "[🔗 Model Card & Repository](https://huggingface.co/enusbaum/JunkWaxHero)"
    ),
    examples=["examples/card1.jpg", "examples/card2.jpg", "examples/card3.jpg"],
    theme="default",  # Swap in another Gradio theme if desired
    allow_flagging="never",  # Disable flagging
)

if __name__ == "__main__":
    iface.launch()
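
# A minimal sketch of calling the model programmatically, without the Gradio UI.
# It assumes one of the bundled example images from the `examples` list above
# exists on disk; the output filename is arbitrary:
#
#     from PIL import Image
#     img = Image.open("examples/card1.jpg")
#     annotated, summary = detect_objects(img)
#     annotated.save("annotated_card1.jpg")
#     print(summary)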