Eric P. Nusbaum committed on
Commit 0b444ec · 1 Parent(s): f60a7c0

Update Space

Files changed (1): app.py (+98 -106)
app.py CHANGED
@@ -1,152 +1,144 @@

Before (app.py at parent f60a7c0; removed lines are marked "-"):

 import os
 import numpy as np
-import onnx
 import onnxruntime
 from PIL import Image, ImageDraw, ImageFont
 import gradio as gr

-# Constants
-PROB_THRESHOLD = 0.5  # Minimum probability to show results
 MODEL_PATH = os.path.join("onnx", "model.onnx")
 LABELS_PATH = os.path.join("onnx", "labels.txt")

 # Load labels
 with open(LABELS_PATH, "r") as f:
-    LABELS = f.read().strip().split("\n")

 class Model:
     def __init__(self, model_filepath):
         self.session = onnxruntime.InferenceSession(model_filepath)
-        assert len(self.session.get_inputs()) == 1
         self.input_shape = self.session.get_inputs()[0].shape[2:]  # (H, W)
         self.input_name = self.session.get_inputs()[0].name
-        self.input_type = {'tensor(float)': np.float32, 'tensor(float16)': np.float16}.get(
-            self.session.get_inputs()[0].type, np.float32
-        )
-        self.output_names = [o.name for o in self.session.get_outputs()]
-
         self.is_bgr = False
         self.is_range255 = False
-        onnx_model = onnx.load(model_filepath)
-        for metadata in onnx_model.metadata_props:
-            if metadata.key == 'Image.BitmapPixelFormat' and metadata.value == 'Bgr8':
                 self.is_bgr = True
-            elif metadata.key == 'Image.NominalPixelRange' and metadata.value == 'NominalRange_0_255':
                 self.is_range255 = True

-    def predict(self, image: Image.Image):
         # Preprocess image
         image_resized = image.resize(self.input_shape)
         input_array = np.array(image_resized, dtype=np.float32)[np.newaxis, :, :, :]
         input_array = input_array.transpose((0, 3, 1, 2))  # (N, C, H, W)
         if self.is_bgr:
-            input_array = input_array[:, (2, 1, 0), :, :]
         if not self.is_range255:
             input_array = input_array / 255.0  # Normalize to [0,1]
-
         # Run inference
-        outputs = self.session.run(self.output_names, {self.input_name: input_array.astype(self.input_type)})
-        return {name: outputs[i] for i, name in enumerate(self.output_names)}
-
-def draw_boxes(image: Image.Image, outputs: dict):
-    draw = ImageDraw.Draw(image, "RGBA")  # Use RGBA for transparency

-    # Dynamic font size based on image dimensions
-    image_width, image_height = image.size
-    font_size = max(20, image_width // 50)  # Increased minimum font size
     try:
-        # Attempt to load a truetype font; adjust the path if necessary
-        font = ImageFont.truetype("arial.ttf", size=font_size)
     except IOError:
-        # Fallback to default font if truetype font is not found
         font = ImageFont.load_default()

-    boxes = outputs.get('detected_boxes', [])
-    classes = outputs.get('detected_classes', [])
-    scores = outputs.get('detected_scores', [])
-
-    for box, cls, score in zip(boxes[0], classes[0], scores[0]):
-        if score < PROB_THRESHOLD:
             continue
-        label = LABELS[int(cls)]
-
-        # Assuming box format: [ymin, xmin, ymax, xmax] normalized [0,1]
-        ymin, xmin, ymax, xmax = box
-        left = xmin * image_width
-        right = xmax * image_width
-        top = ymin * image_height
-        bottom = ymax * image_height
-
-        # Draw bounding box
-        draw.rectangle([left, top, right, bottom], outline="red", width=3)
-
-        # Prepare label text
-        text = f"{label}: {score:.2f}"
-
-        # Calculate text size using textbbox
-        text_bbox = draw.textbbox((0, 0), text, font=font)
-        text_width = text_bbox[2] - text_bbox[0]
-        text_height = text_bbox[3] - text_bbox[1]
-
-        # Calculate label background position
-        # Ensure the label box does not go above the image
-        label_top = max(top - text_height - 10, 0)
-        label_left = left
-
-        # Draw semi-transparent rectangle behind text
-        draw.rectangle(
-            [label_left, label_top, label_left + text_width + 10, label_top + text_height + 10],
-            fill=(255, 0, 0, 160)  # Semi-transparent red
-        )
-
-        # Draw text
-        draw.text(
-            (label_left + 5, label_top + 5),
-            text,
-            fill="white",
-            font=font
-        )

     return image

-# Initialize model
-model = Model(MODEL_PATH)
-
-def detect_objects(image):
-    outputs = model.predict(image)
-    annotated_image = draw_boxes(image.copy(), outputs)
-
-    # Prepare detection summary
-    detections = []
-    boxes = outputs.get('detected_boxes', [])
-    classes = outputs.get('detected_classes', [])
-    scores = outputs.get('detected_scores', [])
-
-    for box, cls, score in zip(boxes[0], classes[0], scores[0]):
-        if score < PROB_THRESHOLD:
-            continue
-        label = LABELS[int(cls)]
-        detections.append(f"{label}: {score:.2f}")
-
-    detection_summary = "\n".join(detections) if detections else "No objects detected."
-
-    return annotated_image, detection_summary

-# Gradio Interface
 iface = gr.Interface(
-    fn=detect_objects,
     inputs=gr.Image(type="pil"),
-    outputs=[
-        gr.Image(type="pil", label="Detected Objects"),
-        gr.Textbox(label="Detections")
-    ],
-    title="Object Detection with ONNX Model",
-    description="Upload an image to detect objects using the ONNX model.",
-    examples=["examples/card1.jpg", "examples/card2.jpg", "examples/card3.jpg"],
-    theme="default",  # You can choose other themes if desired
-    allow_flagging="never"  # Disable flagging if not needed
-    # Removed 'layout' parameter
 )

 if __name__ == "__main__":
     iface.launch()
 
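The rewritten app.py below drops the separate onnx.load() pass and reads the pixel-format metadata straight from the ONNX Runtime session; it also assumes the session returns its outputs in the order [boxes, labels, scores]. Both assumptions can be checked against the actual onnx/model.onnx with a short inspection script along these lines (a minimal sketch using only onnxruntime; the model path matches the Space, and what gets printed depends entirely on the exported model):

import onnxruntime

# Inspect the same model the Space loads, without running the app.
session = onnxruntime.InferenceSession("onnx/model.onnx")

# Input name, shape, and element type drive the preprocessing in predict().
for inp in session.get_inputs():
    print("input :", inp.name, inp.shape, inp.type)

# The new predict() assumes outputs arrive as [boxes, labels, scores];
# the names and order reported here are the ground truth for that assumption.
for out in session.get_outputs():
    print("output:", out.name, out.shape, out.type)

# Custom metadata is where Image.BitmapPixelFormat / Image.NominalPixelRange live, if present.
for key, value in session.get_modelmeta().custom_metadata_map.items():
    print("meta  :", key, "=", value)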
After (app.py at 0b444ec; added lines are marked "+"):

 import os
 import numpy as np
 import onnxruntime
 from PIL import Image, ImageDraw, ImageFont
 import gradio as gr

+# Define paths
 MODEL_PATH = os.path.join("onnx", "model.onnx")
 LABELS_PATH = os.path.join("onnx", "labels.txt")

 # Load labels
 with open(LABELS_PATH, "r") as f:
+    LABELS = [line.strip() for line in f.readlines()]

+# Initialize ONNX Runtime session
 class Model:
     def __init__(self, model_filepath):
+        # Initialize the InferenceSession
         self.session = onnxruntime.InferenceSession(model_filepath)
+
+        # Ensure the model has exactly one input
+        assert len(self.session.get_inputs()) == 1, "Model should have exactly one input."
+
+        # Extract input details
         self.input_shape = self.session.get_inputs()[0].shape[2:]  # (H, W)
         self.input_name = self.session.get_inputs()[0].name
+        self.input_type = {
+            'tensor(float)': np.float32,
+            'tensor(float16)': np.float16
+        }.get(self.session.get_inputs()[0].type, np.float32)
+
+        # Extract output names
+        self.output_names = [output.name for output in self.session.get_outputs()]
+
+        # Default preprocessing flags
         self.is_bgr = False
         self.is_range255 = False
+
+        # Retrieve metadata from the model
+        metadata_map = self.session.get_modelmeta().custom_metadata_map
+        for key, value in metadata_map.items():
+            if key == 'Image.BitmapPixelFormat' and value == 'Bgr8':
                 self.is_bgr = True
+            elif key == 'Image.NominalPixelRange' and value == 'NominalRange_0_255':
                 self.is_range255 = True

+    def predict(self, image):
         # Preprocess image
         image_resized = image.resize(self.input_shape)
         input_array = np.array(image_resized, dtype=np.float32)[np.newaxis, :, :, :]
         input_array = input_array.transpose((0, 3, 1, 2))  # (N, C, H, W)
+
         if self.is_bgr:
+            input_array = input_array[:, (2, 1, 0), :, :]  # Convert RGB to BGR
+
         if not self.is_range255:
             input_array = input_array / 255.0  # Normalize to [0,1]
+
+        # Prepare input tensor
+        input_tensor = input_array.astype(self.input_type)
+
         # Run inference
+        outputs = self.session.run(self.output_names, {self.input_name: input_tensor})
+
+        # Process outputs
+        # Assuming outputs are in the format: [boxes, labels, scores]
+        # Adjust based on your actual model's output format
+        if len(outputs) >= 3:
+            boxes = outputs[0]  # shape: [num_detections, 4]
+            labels = outputs[1].astype(int)  # shape: [num_detections]
+            scores = outputs[2]  # shape: [num_detections]
+            return boxes, labels, scores
+        else:
+            raise ValueError("Unexpected number of outputs from the model.")
+
+# Load the model
+model = Model(MODEL_PATH)

+# Function to draw bounding boxes
+def draw_boxes(image, boxes, labels, scores, threshold=0.5):
+    draw = ImageDraw.Draw(image)
     try:
+        font = ImageFont.truetype("arial.ttf", 15)
     except IOError:
         font = ImageFont.load_default()

+    for box, label, score in zip(boxes, labels, scores):
+        if score < threshold:
             continue
+        # Assuming box format is [xmin, ymin, xmax, ymax] normalized [0,1]
+        xmin, ymin, xmax, ymax = box
+        width, height = image.size
+        xmin = int(xmin * width)
+        ymin = int(ymin * height)
+        xmax = int(xmax * width)
+        ymax = int(ymax * height)
+
+        # Draw rectangle
+        draw.rectangle([(xmin, ymin), (xmax, ymax)], outline="red", width=2)
+
+        # Draw label
+        label_text = f"{LABELS[label]}: {score:.2f}"
+        text_size = draw.textsize(label_text, font=font)
+        draw.rectangle([(xmin, ymin - text_size[1]), (xmin + text_size[0], ymin)], fill="red")
+        draw.text((xmin, ymin - text_size[1]), label_text, fill="white", font=font)

     return image

+# Prediction function for Gradio
+def predict_image(input_image):
+    boxes, labels, scores = model.predict(input_image)
+    output_image = input_image.copy()
+    output_image = draw_boxes(output_image, boxes, labels, scores, threshold=0.5)
+    return output_image
+
+# Define Gradio Interface
+def get_example_images():
+    examples_dir = "examples"
+    return [
+        os.path.join(examples_dir, img)
+        for img in os.listdir(examples_dir)
+        if img.lower().endswith(('.png', '.jpg', '.jpeg'))
+    ]
+
+example_images = get_example_images()
+
+title = "JunkWaxHero: Object Detection for Junk Wax Baseball Cards"
+description = """
+Upload an image of a Junk Wax Baseball Card, and the model will identify the card by its set (1980-1999).
+"""

 iface = gr.Interface(
+    fn=predict_image,
     inputs=gr.Image(type="pil"),
+    outputs=gr.Image(type="pil"),
+    examples=example_images,
+    title=title,
+    description=description,
+    allow_flagging="never"
 )

+# Launch the interface
 if __name__ == "__main__":
     iface.launch()
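One caveat on the new draw_boxes: ImageDraw.textsize() was deprecated in Pillow 9.2 and removed in Pillow 10, so on a current Pillow the label drawing will fail. A minimal sketch of the same measurement using textbbox(), available since Pillow 8.0 (illustrative only, not part of this commit; the helper name text_dimensions is hypothetical):

from PIL import Image, ImageDraw, ImageFont

def text_dimensions(draw, text, font):
    # textbbox() returns (left, top, right, bottom) for text anchored at (0, 0).
    left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
    return right - left, bottom - top

# Quick self-contained check on a blank canvas.
canvas = Image.new("RGB", (300, 60), "white")
draw = ImageDraw.Draw(canvas)
font = ImageFont.load_default()
text_w, text_h = text_dimensions(draw, "label: 0.97", font)
print(text_w, text_h)

# Inside draw_boxes, the textsize() call could then become:
#   text_w, text_h = text_dimensions(draw, label_text, font)
#   draw.rectangle([(xmin, ymin - text_h), (xmin + text_w, ymin)], fill="red")
#   draw.text((xmin, ymin - text_h), label_text, fill="white", font=font)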
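Separately, a quick way to exercise the new pipeline locally without launching the Gradio UI, assuming app.py above is importable from the working directory and an image exists under examples/ (the filename below is only an example):

from PIL import Image

import app  # importing app.py builds the session and the Interface, but does not launch it

# Run one image through the new Model.predict() / draw_boxes() path.
image = Image.open("examples/card1.jpg").convert("RGB")
boxes, labels, scores = app.model.predict(image)
annotated = app.draw_boxes(image.copy(), boxes, labels, scores, threshold=0.5)
annotated.save("annotated.jpg")

# Print the detections that cleared the threshold.
for label, score in zip(labels, scores):
    if score >= 0.5:
        print(app.LABELS[int(label)], f"{score:.2f}")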