Eric P. Nusbaum committed on
Commit
f61c335
·
1 Parent(s): b5e07e6

Update Space

Browse files
Files changed (2) hide show
  1. app.py +106 -103
  2. requirements.txt +5 -4
app.py CHANGED
@@ -1,149 +1,152 @@
1
  import os
2
  import numpy as np
 
3
  import onnxruntime
4
  from PIL import Image, ImageDraw, ImageFont
5
  import gradio as gr
6
 
7
- # Define paths
 
8
  MODEL_PATH = os.path.join("onnx", "model.onnx")
9
  LABELS_PATH = os.path.join("onnx", "labels.txt")
10
 
11
  # Load labels
12
  with open(LABELS_PATH, "r") as f:
13
- LABELS = [line.strip() for line in f.readlines()]
14
 
15
- # Initialize ONNX Runtime session
16
  class Model:
17
  def __init__(self, model_filepath):
18
- # Initialize the InferenceSession
19
  self.session = onnxruntime.InferenceSession(model_filepath)
20
-
21
- # Ensure the model has exactly one input
22
- assert len(self.session.get_inputs()) == 1, "Model should have exactly one input."
23
-
24
- # Extract input details
25
  self.input_shape = self.session.get_inputs()[0].shape[2:] # (H, W)
26
  self.input_name = self.session.get_inputs()[0].name
27
- self.input_type = {
28
- 'tensor(float)': np.float32,
29
- 'tensor(float16)': np.float16
30
- }.get(self.session.get_inputs()[0].type, np.float32)
31
-
32
- # Extract output names
33
- self.output_names = [output.name for output in self.session.get_outputs()]
34
-
35
- # Default preprocessing flags
36
  self.is_bgr = False
37
  self.is_range255 = False
38
-
39
- # Retrieve metadata from the model
40
- metadata_map = self.session.get_modelmeta().custom_metadata_map
41
- for key, value in metadata_map.items():
42
- if key == 'Image.BitmapPixelFormat' and value == 'Bgr8':
43
  self.is_bgr = True
44
- elif key == 'Image.NominalPixelRange' and value == 'NominalRange_0_255':
45
  self.is_range255 = True
46
 
47
- def predict(self, image):
48
  # Preprocess image
49
  image_resized = image.resize(self.input_shape)
50
  input_array = np.array(image_resized, dtype=np.float32)[np.newaxis, :, :, :]
51
  input_array = input_array.transpose((0, 3, 1, 2)) # (N, C, H, W)
52
-
53
  if self.is_bgr:
54
- input_array = input_array[:, (2, 1, 0), :, :] # Convert RGB to BGR
55
-
56
  if not self.is_range255:
57
  input_array = input_array / 255.0 # Normalize to [0,1]
58
-
59
- # Prepare input tensor
60
- input_tensor = input_array.astype(self.input_type)
61
-
62
  # Run inference
63
- outputs = self.session.run(self.output_names, {self.input_name: input_tensor})
64
-
65
- # Process outputs
66
- # Extract the first (and only) batch element
67
- if len(outputs) >= 3:
68
- boxes = outputs[0][0] # shape: [num_detections, 4]
69
- labels = outputs[1][0].astype(int) # shape: [num_detections]
70
- scores = outputs[2][0] # shape: [num_detections]
71
- return boxes, labels, scores
72
- else:
73
- raise ValueError("Unexpected number of outputs from the model.")
74
-
75
- # Load the model
76
- model = Model(MODEL_PATH)
77
 
78
- # Function to draw bounding boxes
79
- def draw_boxes(image, boxes, labels, scores, threshold=0.5):
80
- draw = ImageDraw.Draw(image)
81
  try:
82
- font = ImageFont.truetype("arial.ttf", 15)
 
83
  except IOError:
 
84
  font = ImageFont.load_default()
85
 
86
- for box, label, score in zip(boxes, labels, scores):
87
- if score < threshold:
88
- continue
89
- if len(box) != 4:
90
- print(f"Invalid box format: {box}")
 
91
  continue
92
- xmin, ymin, xmax, ymax = box
93
- width, height = image.size
94
- xmin = int(xmin * width)
95
- ymin = int(ymin * height)
96
- xmax = int(xmax * width)
97
- ymax = int(ymax * height)
98
-
99
- # Draw rectangle
100
- draw.rectangle([(xmin, ymin), (xmax, ymax)], outline="red", width=2)
101
-
102
- # Draw label
103
- label_text = f"{LABELS[label]}: {score:.2f}"
104
- text_size = font.getsize(label_text) # Updated line
105
- draw.rectangle([(xmin, ymin - text_size[1]), (xmin + text_size[0], ymin)], fill="red")
106
- draw.text((xmin, ymin - text_size[1]), label_text, fill="white", font=font)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
  return image
109
 
110
- # Prediction function for Gradio
111
- def predict_image(input_image):
112
- try:
113
- boxes, labels, scores = model.predict(input_image)
114
- output_image = input_image.copy()
115
- output_image = draw_boxes(output_image, boxes, labels, scores, threshold=0.5)
116
- return output_image
117
- except Exception as e:
118
- print(f"Error during prediction: {e}")
119
- return input_image # Return the original image if prediction fails
120
-
121
- # Define Gradio Interface
122
- def get_example_images():
123
- examples_dir = "examples"
124
- return [
125
- os.path.join(examples_dir, img)
126
- for img in os.listdir(examples_dir)
127
- if img.lower().endswith(('.png', '.jpg', '.jpeg'))
128
- ]
129
-
130
- example_images = get_example_images()
131
-
132
- title = "JunkWaxHero: Object Detection for Junk Wax Baseball Cards"
133
- description = """
134
- Upload an image of a Junk Wax Baseball Card, and the model will identify the card by its set (1980-1999).
135
- """
136
 
 
137
  iface = gr.Interface(
138
- fn=predict_image,
139
  inputs=gr.Image(type="pil"),
140
- outputs=gr.Image(type="pil"),
141
- examples=example_images,
142
- title=title,
143
- description=description,
144
- flagging_mode="never" # Updated from allow_flagging="never"
 
 
 
 
 
145
  )
146
 
147
- # Launch the interface
148
  if __name__ == "__main__":
149
  iface.launch()
 
1
  import os
2
  import numpy as np
3
+ import onnx
4
  import onnxruntime
5
  from PIL import Image, ImageDraw, ImageFont
6
  import gradio as gr
7
 
# Constants
PROB_THRESHOLD = 0.5  # Detections scoring below this are neither drawn nor listed
MODEL_PATH = os.path.join("onnx", "model.onnx")
LABELS_PATH = os.path.join("onnx", "labels.txt")

# Load labels: one class name per line. The position of a line is the class
# index used to look the name up (``LABELS[int(cls)]``), so interior lines are
# kept in place and only surrounding whitespace is trimmed.
# An explicit encoding keeps the result independent of the host locale, and
# ``splitlines`` makes an empty file produce an empty list instead of [""].
with open(LABELS_PATH, "r", encoding="utf-8") as f:
    LABELS = [line.strip() for line in f.read().strip().splitlines()]
16
 
 
class Model:
    """Thin wrapper around an ONNX Runtime session for a single-input image model.

    Handles the image preprocessing the model expects (resize, channel order,
    pixel range, dtype) and returns the raw model outputs keyed by output name.
    """

    def __init__(self, model_filepath):
        """Create the inference session and read preprocessing hints from the model.

        Args:
            model_filepath: Path to the ``.onnx`` model file.

        Raises:
            ValueError: If the model does not declare exactly one input.
        """
        self.session = onnxruntime.InferenceSession(model_filepath)

        inputs = self.session.get_inputs()
        # Raise instead of ``assert`` so the check survives ``python -O``.
        if len(inputs) != 1:
            raise ValueError(
                f"Model should have exactly one input, found {len(inputs)}."
            )
        input_meta = inputs[0]

        # Trailing two dimensions of the declared input shape, i.e. (H, W)
        # for the (N, C, H, W) layout that predict() feeds the model.
        self.input_shape = input_meta.shape[2:]
        self.input_name = input_meta.name
        # Unknown tensor types fall back to float32.
        self.input_type = {
            'tensor(float)': np.float32,
            'tensor(float16)': np.float16,
        }.get(input_meta.type, np.float32)
        self.output_names = [o.name for o in self.session.get_outputs()]

        # Preprocessing flags, overridden by model metadata when present.
        self.is_bgr = False
        self.is_range255 = False
        onnx_model = onnx.load(model_filepath)
        for metadata in onnx_model.metadata_props:
            if metadata.key == 'Image.BitmapPixelFormat' and metadata.value == 'Bgr8':
                self.is_bgr = True
            elif metadata.key == 'Image.NominalPixelRange' and metadata.value == 'NominalRange_0_255':
                self.is_range255 = True

    def predict(self, image: "Image.Image"):
        """Run the model on one image.

        Args:
            image: PIL image of any mode; it is converted to RGB before use.

        Returns:
            Dict mapping each model output name to its raw output array.
        """
        # input_shape is (H, W) but PIL's resize() takes (width, height);
        # passing it unswapped transposes the tensor for non-square models.
        height, width = self.input_shape
        # Force three channels so RGBA / grayscale / palette images produce
        # the (H, W, 3) array the transpose below relies on.
        image_resized = image.convert("RGB").resize((width, height))

        input_array = np.array(image_resized, dtype=np.float32)[np.newaxis, :, :, :]
        input_array = input_array.transpose((0, 3, 1, 2))  # (N, C, H, W)
        if self.is_bgr:
            input_array = input_array[:, (2, 1, 0), :, :]  # RGB -> BGR
        if not self.is_range255:
            input_array = input_array / 255.0  # Normalize to [0, 1]

        # Run inference
        outputs = self.session.run(
            self.output_names,
            {self.input_name: input_array.astype(self.input_type)},
        )
        return dict(zip(self.output_names, outputs))
50
+
def _first_batch(outputs, name):
    """Return the first batch element of output ``name``, or () if absent/empty."""
    value = outputs.get(name)
    if value is None or len(value) == 0:
        return ()
    return value[0]


def _load_label_font(size):
    """Return a label font of roughly ``size`` pixels, degrading gracefully."""
    try:
        # Attempt to load a truetype font; adjust the path if necessary
        return ImageFont.truetype("arial.ttf", size=size)
    except IOError:
        pass
    try:
        # Newer Pillow releases can scale the built-in font; older ones
        # reject the ``size`` argument with TypeError.
        return ImageFont.load_default(size=size)
    except TypeError:
        return ImageFont.load_default()


def draw_boxes(image: "Image.Image", outputs: dict):
    """Draw labelled bounding boxes for confident detections onto ``image``.

    Args:
        image: PIL image to annotate. It is modified in place.
        outputs: Mapping of model output name to array. The keys
            ``detected_boxes``, ``detected_classes`` and ``detected_scores``
            are read, and only the first batch element of each is used.

    Returns:
        The same ``image`` object. It is returned untouched when any of the
        three outputs is missing or empty.
    """
    boxes = _first_batch(outputs, 'detected_boxes')
    classes = _first_batch(outputs, 'detected_classes')
    scores = _first_batch(outputs, 'detected_scores')
    if len(boxes) == 0 or len(classes) == 0 or len(scores) == 0:
        return image

    draw = ImageDraw.Draw(image, "RGBA")  # RGBA drawing mode enables the translucent fill

    # Dynamic font size based on image dimensions
    image_width, image_height = image.size
    font = _load_label_font(max(20, image_width // 50))

    for box, cls, score in zip(boxes, classes, scores):
        if score < PROB_THRESHOLD:
            continue

        class_id = int(cls)
        # Fall back to the numeric id rather than crashing (or silently
        # wrapping around on a negative index) when labels and model disagree.
        label = LABELS[class_id] if 0 <= class_id < len(LABELS) else f"class {class_id}"

        # NOTE(review): assumes box format [ymin, xmin, ymax, xmax] normalized
        # to [0, 1] - confirm against the model's export documentation.
        # Values are clamped and ordered so slightly out-of-range or inverted
        # coordinates cannot make draw.rectangle() raise.
        ymin, xmin, ymax, xmax = (min(max(float(v), 0.0), 1.0) for v in box)
        left, right = sorted((xmin * image_width, xmax * image_width))
        top, bottom = sorted((ymin * image_height, ymax * image_height))

        # Draw bounding box
        draw.rectangle([left, top, right, bottom], outline="red", width=3)

        # Measure the label so the background fits it with 5px padding
        text = f"{label}: {score:.2f}"
        text_bbox = draw.textbbox((0, 0), text, font=font)
        text_width = text_bbox[2] - text_bbox[0]
        text_height = text_bbox[3] - text_bbox[1]

        # Place the label above the box, but never above the image edge
        label_top = max(top - text_height - 10, 0)
        label_left = left

        # Draw semi-transparent rectangle behind text
        draw.rectangle(
            [label_left, label_top, label_left + text_width + 10, label_top + text_height + 10],
            fill=(255, 0, 0, 160),  # Semi-transparent red
        )

        # Draw text
        draw.text(
            (label_left + 5, label_top + 5),
            text,
            fill="white",
            font=font,
        )

    return image
111
 
# Build the detector once at import time; detect_objects() reuses this instance.
model = Model(model_filepath=MODEL_PATH)
114
+
def detect_objects(image):
    """Run detection on ``image`` and return an annotated copy plus a text summary.

    Args:
        image: PIL image supplied by the Gradio input, or ``None`` when the
            user submits without providing one.

    Returns:
        Tuple ``(annotated_image, detection_summary)``. ``annotated_image`` is
        a copy of the input with boxes drawn (``None`` if no image was given);
        ``detection_summary`` has one ``"label: score"`` line per detection at
        or above ``PROB_THRESHOLD``.
    """
    # Submitting with no image would otherwise fail inside predict().
    if image is None:
        return None, "No image provided."

    outputs = model.predict(image)
    annotated_image = draw_boxes(image.copy(), outputs)

    # Prepare detection summary
    boxes = outputs.get('detected_boxes')
    classes = outputs.get('detected_classes')
    scores = outputs.get('detected_scores')

    detections = []
    # Only index the batch dimension when every output is actually present.
    have_all = all(
        value is not None and len(value) > 0
        for value in (boxes, classes, scores)
    )
    if have_all:
        for _box, cls, score in zip(boxes[0], classes[0], scores[0]):
            if score < PROB_THRESHOLD:
                continue
            class_id = int(cls)
            # Fall back to the numeric id if the labels file is shorter than
            # the model's class list.
            label = LABELS[class_id] if 0 <= class_id < len(LABELS) else f"class {class_id}"
            detections.append(f"{label}: {score:.2f}")

    detection_summary = "\n".join(detections) if detections else "No objects detected."

    return annotated_image, detection_summary
 
 
 
 
134
 
# Gradio UI: one image in, the annotated image plus a text list of detections out.
iface = gr.Interface(
    title="Object Detection with ONNX Model",
    description="Upload an image to detect objects using the ONNX model.",
    fn=detect_objects,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil", label="Detected Objects"),
        gr.Textbox(label="Detections"),
    ],
    examples=[
        "examples/card1.jpg",
        "examples/card2.jpg",
        "examples/card3.jpg",
    ],
    theme="default",
    allow_flagging="never",  # No flagging controls in the UI
)

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    iface.launch()
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
- gradio
2
- numpy
3
- onnxruntime
4
- pillow
 
 
1
+ gradio==3.32.0
2
+ onnx==1.14.0
3
+ onnxruntime==1.15.1
4
+ Pillow>=10.0.0
5
+ numpy==1.25.0