Eric P. Nusbaum committed on
Commit
de40de9
·
1 Parent(s): 32eca4a

Update Space

Browse files
Files changed (1) hide show
  1. app.py +92 -0
app.py CHANGED
@@ -1,3 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  def draw_boxes(image: Image.Image, outputs: dict):
2
  draw = ImageDraw.Draw(image, "RGBA") # Use RGBA for transparency
3
 
@@ -60,3 +110,45 @@ def draw_boxes(image: Image.Image, outputs: dict):
60
  )
61
 
62
  return image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import onnx
4
+ import onnxruntime
5
+ from PIL import Image, ImageDraw, ImageFont
6
+ import gradio as gr
7
+
8
# Constants
PROB_THRESHOLD = 0.5  # Minimum probability to show results
MODEL_PATH = os.path.join("onnx", "model.onnx")
LABELS_PATH = os.path.join("onnx", "labels.txt")

# Load labels: one label per line, where the line position is the class index.
# The encoding is explicit so decoding does not depend on the platform default,
# and each line is stripped so stray whitespace never ends up in a label.
# Interior blank lines are kept (as empty labels) so class indices stay aligned.
with open(LABELS_PATH, "r", encoding="utf-8") as f:
    LABELS = [line.strip() for line in f.read().strip().splitlines()]
16
+
17
class Model:
    """Wrap an ONNX Runtime session and run single-image inference.

    On construction the model's single input (name, trailing spatial
    dimensions, element type) and its output names are read from the
    session. The ONNX metadata properties ``Image.BitmapPixelFormat`` and
    ``Image.NominalPixelRange`` are inspected to decide whether the input
    must be fed as BGR and whether pixel values stay in the 0-255 range.
    """

    def __init__(self, model_filepath):
        """Load the model at *model_filepath* and cache its I/O description."""
        self.session = onnxruntime.InferenceSession(model_filepath)
        session_inputs = self.session.get_inputs()
        assert len(session_inputs) == 1, "model must have exactly one input"
        model_input = session_inputs[0]

        self.input_shape = model_input.shape[2:]  # (H, W)
        self.input_name = model_input.name
        # Unknown element types fall back to float32.
        self.input_type = {
            'tensor(float)': np.float32,
            'tensor(float16)': np.float16,
        }.get(model_input.type, np.float32)
        self.output_names = [o.name for o in self.session.get_outputs()]

        # Defaults (RGB, pixels scaled to [0, 1]) apply unless the model
        # metadata says otherwise.
        self.is_bgr = False
        self.is_range255 = False
        onnx_model = onnx.load(model_filepath)
        for metadata in onnx_model.metadata_props:
            if metadata.key == 'Image.BitmapPixelFormat' and metadata.value == 'Bgr8':
                self.is_bgr = True
            elif metadata.key == 'Image.NominalPixelRange' and metadata.value == 'NominalRange_0_255':
                self.is_range255 = True

    def predict(self, image: Image.Image):
        """Run the model on *image* and return its outputs keyed by name.

        Args:
            image: PIL image in any mode; it is converted to RGB before
                being resized to the model's input size.

        Returns:
            A dict mapping each session output name to the corresponding
            value returned by ``session.run``.
        """
        # Force three channels so RGBA / grayscale / palette uploads produce
        # an (H, W, 3) array instead of a wrong channel count or rank.
        rgb_image = image.convert("RGB")

        # input_shape is (H, W) but PIL's resize expects (width, height).
        height, width = self.input_shape
        image_resized = rgb_image.resize((width, height))

        input_array = np.array(image_resized, dtype=np.float32)[np.newaxis, :, :, :]
        input_array = input_array.transpose((0, 3, 1, 2))  # (N, C, H, W)
        if self.is_bgr:
            input_array = input_array[:, (2, 1, 0), :, :]  # RGB -> BGR
        if not self.is_range255:
            input_array = input_array / 255.0  # Normalize to [0,1]

        # Run inference
        outputs = self.session.run(
            self.output_names,
            {self.input_name: input_array.astype(self.input_type)},
        )
        return dict(zip(self.output_names, outputs))
50
+
51
  def draw_boxes(image: Image.Image, outputs: dict):
52
  draw = ImageDraw.Draw(image, "RGBA") # Use RGBA for transparency
53
 
 
110
  )
111
 
112
  return image
113
+
114
# Build the shared model instance that detect_objects runs predictions on.
model = Model(model_filepath=MODEL_PATH)
116
+
117
+ def detect_objects(image):
118
+ outputs = model.predict(image)
119
+ annotated_image = draw_boxes(image.copy(), outputs)
120
+
121
+ # Prepare detection summary
122
+ detections = []
123
+ boxes = outputs.get('detected_boxes', [])
124
+ classes = outputs.get('detected_classes', [])
125
+ scores = outputs.get('detected_scores', [])
126
+
127
+ for box, cls, score in zip(boxes[0], classes[0], scores[0]):
128
+ if score < PROB_THRESHOLD:
129
+ continue
130
+ label = LABELS[int(cls)]
131
+ detections.append(f"{label}: {score:.2f}")
132
+
133
+ detection_summary = "\n".join(detections) if detections else "No objects detected."
134
+
135
+ return annotated_image, detection_summary
136
+
137
# Gradio interface: one PIL image in; the annotated image and a text summary out.
iface = gr.Interface(
    fn=detect_objects,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil", label="Detected Objects"),
        gr.Textbox(label="Detections"),
    ],
    title="JunkWaxHero - Baseball Card Set Detection (ONNX Model)",
    description="Upload an image to identify the set of the baseball card.",
    examples=["examples/card1.jpg", "examples/card2.jpg", "examples/card3.jpg"],
    theme="default",  # You can choose other themes if desired
    allow_flagging="never",  # Disable flagging if not needed
)

# Only start the web server when the file is executed as a script.
if __name__ == "__main__":
    iface.launch()