wybxc commited on
Commit
defb54c
·
verified ·
1 Parent(s): 2a54612

Upload inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +45 -2
inference.py CHANGED
@@ -59,9 +59,36 @@ def resize_and_pad_image(image, new_shape, stride=32):
59
  return image
60
 
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  class YoloResult:
63
  def __init__(self, boxes, names):
64
  self.boxes = [YoloBox(data=d) for d in boxes]
 
65
  self.names = names
66
 
67
 
@@ -84,23 +111,28 @@ def inference(image):
84
  """
85
 
86
  # Preprocess image
 
87
  image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
88
  pix = resize_and_pad_image(image, new_shape=int(image.shape[0] / stride) * stride)
89
  pix = np.transpose(pix, (2, 0, 1)) # CHW
90
  pix = np.expand_dims(pix, axis=0) # BCHW
91
  pix = pix.astype(np.float32) / 255.0 # Normalize to [0, 1]
 
92
 
93
  # Run inference
94
  preds = session.run(None, {"images": pix})[0]
95
 
96
  # Postprocess predictions
97
  preds = preds[preds[..., 4] > 0.25]
 
98
  return YoloResult(boxes=preds, names=names)
99
 
100
 
101
  if __name__ == "__main__":
102
  import sys
 
103
  import matplotlib.pyplot as plt
 
104
 
105
  image = sys.argv[1]
106
  image = cv2.imread(image)
@@ -125,7 +157,18 @@ if __name__ == "__main__":
125
  bitmap[y0:y1, x0:x1] = i + 2
126
  bitmap = bitmap[::-1, :]
127
 
128
- fig, ax = plt.subplots(1, 2, figsize=(10, 6))
 
 
 
 
 
 
 
 
 
 
129
  ax[0].imshow(image)
130
- ax[1].imshow(bitmap)
 
131
  plt.show()
 
59
  return image
60
 
61
 
62
+ def scale_boxes(img1_shape, boxes, img0_shape):
63
+ """
64
+ Rescales bounding boxes (in the format of xyxy by default) from the shape of the image they were originally
65
+ specified in (img1_shape) to the shape of a different image (img0_shape).
66
+
67
+ Args:
68
+ img1_shape (tuple): The shape of the image that the bounding boxes are for, in the format of (height, width).
69
+ boxes (torch.Tensor): the bounding boxes of the objects in the image, in the format of (x1, y1, x2, y2)
70
+ img0_shape (tuple): the shape of the target image, in the format of (height, width).
71
+
72
+ Returns:
73
+ boxes (torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2)
74
+ """
75
+
76
+ # Calculate scaling ratio
77
+ gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])
78
+
79
+ # Calculate padding size
80
+ pad_x = round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1)
81
+ pad_y = round((img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1)
82
+
83
+ # Remove padding and scale boxes
84
+ boxes[..., :4] = (boxes[..., :4] - [pad_x, pad_y, pad_x, pad_y]) / gain
85
+ return boxes
86
+
87
+
88
  class YoloResult:
89
  def __init__(self, boxes, names):
90
  self.boxes = [YoloBox(data=d) for d in boxes]
91
+ self.boxes = sorted(self.boxes, key=lambda x: x.conf, reverse=True)
92
  self.names = names
93
 
94
 
 
111
  """
112
 
113
  # Preprocess image
114
+ orig_h, orig_w = image.shape[:2]
115
  image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
116
  pix = resize_and_pad_image(image, new_shape=int(image.shape[0] / stride) * stride)
117
  pix = np.transpose(pix, (2, 0, 1)) # CHW
118
  pix = np.expand_dims(pix, axis=0) # BCHW
119
  pix = pix.astype(np.float32) / 255.0 # Normalize to [0, 1]
120
+ new_h, new_w = pix.shape[2:]
121
 
122
  # Run inference
123
  preds = session.run(None, {"images": pix})[0]
124
 
125
  # Postprocess predictions
126
  preds = preds[preds[..., 4] > 0.25]
127
+ preds[..., :4] = scale_boxes((new_h, new_w), preds[..., :4], (orig_h, orig_w))
128
  return YoloResult(boxes=preds, names=names)
129
 
130
 
131
  if __name__ == "__main__":
132
  import sys
133
+ import matplotlib
134
  import matplotlib.pyplot as plt
135
+ import matplotlib.colors as colors
136
 
137
  image = sys.argv[1]
138
  image = cv2.imread(image)
 
157
  bitmap[y0:y1, x0:x1] = i + 2
158
  bitmap = bitmap[::-1, :]
159
 
160
+ # map bitmap to color
161
+ colormap = matplotlib.colormaps["Pastel1"]
162
+ norm = colors.Normalize(vmin=bitmap.min(), vmax=bitmap.max())
163
+ colored_bitmap = colormap(norm(bitmap))
164
+ colored_bitmap = (colored_bitmap[:, :, :3] * 255).astype(np.uint8)
165
+
166
+ # overlay bitmap on image
167
+ image_with_bitmap = cv2.multiply(image, colored_bitmap, scale=1 / 255)
168
+
169
+ # show the results
170
+ fig, ax = plt.subplots(1, 3, figsize=(15, 6))
171
  ax[0].imshow(image)
172
+ ax[1].imshow(bitmap, cmap="Pastel1")
173
+ ax[2].imshow(image_with_bitmap)
174
  plt.show()