Spaces:
Sleeping
Sleeping
| import os | |
| import numpy as np | |
| import onnx | |
| import onnxruntime | |
| from PIL import Image, ImageDraw, ImageFont | |
| import gradio as gr | |
# Constants
PROB_THRESHOLD = 0.5  # Detections scoring below this are neither drawn nor listed
MODEL_PATH = os.path.join("onnx", "model.onnx")
LABELS_PATH = os.path.join("onnx", "labels.txt")

# Load class labels, one per line; a label's line index is the class id used
# to look it up. The encoding is explicit so that non-ASCII label names do not
# depend on the host's locale.
with open(LABELS_PATH, "r", encoding="utf-8") as f:
    LABELS = f.read().strip().split("\n")
class Model:
    """ONNX Runtime wrapper for a model that takes one image tensor as input.

    The input tensor is fed in (N, C, H, W) layout. Channel order and pixel
    range are taken from the model's ``Image.BitmapPixelFormat`` and
    ``Image.NominalPixelRange`` metadata entries when present; otherwise RGB
    in the [0, 1] range is used.
    """

    def __init__(self, model_filepath):
        """Load the model and read its input/output signature and metadata.

        Args:
            model_filepath: Path to the ``.onnx`` file.

        Raises:
            ValueError: If the model does not have exactly one input.
        """
        self.session = onnxruntime.InferenceSession(model_filepath)

        inputs = self.session.get_inputs()
        if len(inputs) != 1:
            # Raised explicitly (not asserted) so the check survives `python -O`.
            raise ValueError(
                f"Expected a model with exactly one input, got {len(inputs)}"
            )
        model_input = inputs[0]

        self.input_shape = model_input.shape[2:]  # (H, W)
        self.input_name = model_input.name
        # Unknown input types fall back to float32.
        self.input_type = {
            'tensor(float)': np.float32,
            'tensor(float16)': np.float16,
        }.get(model_input.type, np.float32)
        self.output_names = [o.name for o in self.session.get_outputs()]

        self.is_bgr = False
        self.is_range255 = False
        for metadata in onnx.load(model_filepath).metadata_props:
            if metadata.key == 'Image.BitmapPixelFormat' and metadata.value == 'Bgr8':
                self.is_bgr = True
            elif metadata.key == 'Image.NominalPixelRange' and metadata.value == 'NominalRange_0_255':
                self.is_range255 = True

    def predict(self, image: "Image.Image") -> dict:
        """Run inference on a PIL image.

        Args:
            image: Source image; it is not modified.

        Returns:
            A dict mapping each model output name to its output array.
        """
        # The tensor below is indexed as three channels, so normalise
        # grayscale / palette / RGBA inputs first.
        if image.mode != "RGB":
            image = image.convert("RGB")

        # input_shape is (H, W) but PIL's resize expects (width, height).
        height, width = self.input_shape
        resized = image.resize((width, height))

        input_array = np.array(resized, dtype=np.float32)[np.newaxis, :, :, :]
        input_array = input_array.transpose((0, 3, 1, 2))  # (N, C, H, W)
        if self.is_bgr:
            input_array = input_array[:, (2, 1, 0), :, :]
        if not self.is_range255:
            input_array = input_array / 255.0  # Normalize to [0, 1]

        outputs = self.session.run(
            self.output_names,
            {self.input_name: input_array.astype(self.input_type)},
        )
        return dict(zip(self.output_names, outputs))
def _text_height(font, text):
    """Return the rendered height of *text* in *font* across Pillow versions."""
    if hasattr(font, "getbbox"):
        return font.getbbox(text)[3]
    return font.getsize(text)[1]  # Older Pillow without getbbox


def _fit_default_font(text, target_height, max_size=200):
    """Return the default font, grown until *text* is at least *target_height* tall.

    Falls back to whatever default font is available when this Pillow build
    cannot produce a sized default font.
    """
    font = ImageFont.load_default()
    size = 1
    # max_size bounds the search so a fixed-size font can never loop forever.
    while size < max_size and _text_height(font, text) < target_height:
        size += 1
        try:
            font = ImageFont.load_default(size=size)
        except (TypeError, OSError, ImportError):
            # load_default() takes no size on older Pillow, or FreeType is missing.
            break
    return font


def draw_boxes(image: Image.Image, outputs: dict):
    """Draw labelled bounding boxes for confident detections onto *image*.

    Args:
        image: Image to annotate; it is drawn on in place.
        outputs: Model outputs keyed by 'detected_boxes', 'detected_classes'
            and 'detected_scores'. Only the first element of each is used.

    Returns:
        The same image object. It is returned unchanged when any of the three
        outputs is missing or empty.
    """
    boxes = outputs.get('detected_boxes')
    classes = outputs.get('detected_classes')
    scores = outputs.get('detected_scores')
    if any(o is None or len(o) == 0 for o in (boxes, classes, scores)):
        return image

    draw = ImageDraw.Draw(image, "RGBA")  # Use RGBA for transparency
    image_width, image_height = image.size

    for box, cls, score in zip(boxes[0], classes[0], scores[0]):
        if score < PROB_THRESHOLD:
            continue

        # Fall back to the numeric id when the labels file has no such entry.
        class_id = int(cls)
        label = LABELS[class_id] if 0 <= class_id < len(LABELS) else str(class_id)

        # Assuming box format: [ymin, xmin, ymax, xmax] normalized [0,1]
        ymin, xmin, ymax, xmax = (float(v) for v in box)
        # Order each pair so an inverted box cannot make rectangle() raise.
        left, right = sorted((xmin * image_width, xmax * image_width))
        top, bottom = sorted((ymin * image_height, ymax * image_height))

        # Draw bounding box
        draw.rectangle([left, top, right, bottom], outline="red", width=4)

        # Prepare label text
        text = f"{label}: {score:.2f}"

        # The label strip sits just above the box's top-left corner. It is as
        # wide as the box and 5% of the box height tall (plus padding).
        text_width = right - left
        text_height = (bottom - top) // 20

        # Ensure the label strip does not go above the image
        label_top = max(top - text_height - 10, 0)
        label_left = left

        # Draw semi-transparent rectangle behind text
        draw.rectangle(
            [label_left, label_top, label_left + text_width + 10, label_top + text_height + 10],
            fill=(255, 0, 0, 160)  # Semi-transparent red
        )

        # Font size should fill the label strip vertically
        font = _fit_default_font(text, text_height)

        # Draw text
        draw.text(
            (label_left + 5, label_top + 5),
            text,
            fill="black",
            font=font
        )
    return image
# Initialize model once at import time so every request reuses the same session
model = Model(MODEL_PATH)


def detect_objects(image):
    """Run detection on an uploaded image and summarise the results.

    Args:
        image: PIL image from the UI, or None when nothing was uploaded.

    Returns:
        A ``(annotated_image, detection_summary)`` tuple. The image is an
        annotated copy of the input, or None when no image was provided. The
        summary has one "label: score" line per detection at or above
        PROB_THRESHOLD.
    """
    # Submitting without an upload passes None; answer with a message
    # instead of raising.
    if image is None:
        return None, "No image provided."

    outputs = model.predict(image)
    annotated_image = draw_boxes(image.copy(), outputs)

    # Prepare detection summary
    boxes = outputs.get('detected_boxes')
    classes = outputs.get('detected_classes')
    scores = outputs.get('detected_scores')

    detections = []
    have_all_outputs = all(
        o is not None and len(o) > 0 for o in (boxes, classes, scores)
    )
    if have_all_outputs:
        for _box, cls, score in zip(boxes[0], classes[0], scores[0]):
            if score < PROB_THRESHOLD:
                continue
            # Fall back to the numeric id when the labels file has no such entry.
            class_id = int(cls)
            label = LABELS[class_id] if 0 <= class_id < len(LABELS) else str(class_id)
            detections.append(f"{label}: {score:.2f}")

    detection_summary = "\n".join(detections) if detections else "No objects detected."
    return annotated_image, detection_summary
# Gradio interface: one image in; annotated image plus text summary out
iface = gr.Interface(
    fn=detect_objects,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil", label="Detected Objects"),
        gr.Textbox(label="Detections"),
    ],
    title="JunkWaxHero - Baseball Card Set Detection (ONNX Model)",
    description="Upload an image to identify the set of the baseball card.",
    examples=["examples/card1.jpg", "examples/card2.jpg", "examples/card3.jpg"],
    theme="default",  # You can choose other themes if desired
    allow_flagging="never",  # Disable flagging if not needed
)

# Only start the server when run as a script, not when imported
if __name__ == "__main__":
    iface.launch()