Hector Lopez commited on
Commit
b80c100
·
1 Parent(s): 5039ab0

feature: Implemented example images and cleaned the interface

Browse files
app.py CHANGED
@@ -27,15 +27,58 @@ def plot_img_no_mask(image, boxes):
27
  ax.imshow(image)
28
  fig.savefig("img.png", bbox_inches='tight')
29
 
 
 
30
  image_file = st.file_uploader("Upload Images", type=["png","jpg","jpeg"])
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  if image_file is not None:
33
- print(image_file)
34
  print('Getting predictions')
35
- data = image_file.read()
36
- pred_dict = predict(model, data)
 
 
 
37
  print('Fixing the preds')
38
- boxes, image = prepare_prediction(pred_dict)
39
  print('Plotting')
40
  plot_img_no_mask(image, boxes)
41
 
 
27
  ax.imshow(image)
28
  fig.savefig("img.png", bbox_inches='tight')
29
 
30
+ st.subheader('Upload Custom Image')
31
+
32
  image_file = st.file_uploader("Upload Images", type=["png","jpg","jpeg"])
33
 
34
+ st.subheader('Example Images')
35
+
36
+ example_imgs = [
37
+ 'example_imgs/basura_4_2.jpg',
38
+ 'example_imgs/basura_2.jpg',
39
+ 'example_imgs/basura_3.jpg'
40
+ ]
41
+
42
+ with st.container() as cont:
43
+ st.image(example_imgs[0], width=150, caption='1')
44
+ if st.button('Select Image', key='Image_1'):
45
+ image_file = example_imgs[0]
46
+
47
+ with st.container() as cont:
48
+ st.image(example_imgs[1], width=150, caption='2')
49
+ if st.button('Select Image', key='Image_2'):
50
+ image_file = example_imgs[1]
51
+
52
+ with st.container() as cont:
53
+ st.image(example_imgs[2], width=150, caption='2')
54
+ if st.button('Select Image', key='Image_3'):
55
+ image_file = example_imgs[2]
56
+
57
+ st.subheader('Detection parameters')
58
+
59
+ detection_threshold = st.slider('Detection threshold',
60
+ min_value=0.0,
61
+ max_value=1.0,
62
+ value=0.5,
63
+ step=0.1)
64
+
65
+ nms_threshold = st.slider('NMS threshold',
66
+ min_value=0.0,
67
+ max_value=1.0,
68
+ value=0.3,
69
+ step=0.1)
70
+
71
+ st.subheader('Prediction')
72
+
73
  if image_file is not None:
 
74
  print('Getting predictions')
75
+ if isinstance(image_file, str):
76
+ data = image_file
77
+ else:
78
+ data = image_file.read()
79
+ pred_dict = predict(model, data, detection_threshold)
80
  print('Fixing the preds')
81
+ boxes, image = prepare_prediction(pred_dict, nms_threshold)
82
  print('Plotting')
83
  plot_img_no_mask(image, boxes)
84
 
example_imgs/basura_1.jpg ADDED
example_imgs/basura_3.jpg ADDED
example_imgs/basura_4_2.jpg ADDED
model.py CHANGED
@@ -36,9 +36,9 @@ def get_checkpoint(checkpoint_path):
36
 
37
  return fixed_state_dict
38
 
39
- def predict(model, image):
40
- #img = PIL.Image.open(image)
41
- img = PIL.Image.open(BytesIO(image))
42
  img = np.array(img)
43
  img = PIL.Image.fromarray(img)
44
 
@@ -52,7 +52,7 @@ def predict(model, image):
52
  transforms,
53
  model,
54
  class_map=class_map,
55
- detection_threshold=0.5,
56
  return_as_pil_img=False,
57
  return_img=True,
58
  display_bbox=False,
@@ -61,7 +61,7 @@ def predict(model, image):
61
 
62
  return pred_dict
63
 
64
- def prepare_prediction(pred_dict):
65
  boxes = [box.to_tensor() for box in pred_dict['detection']['bboxes']]
66
  boxes = torch.stack(boxes)
67
 
@@ -69,7 +69,7 @@ def prepare_prediction(pred_dict):
69
  labels = torch.as_tensor(pred_dict['detection']['label_ids'])
70
  image = np.array(pred_dict['img'])
71
 
72
- fixed_boxes = torchvision.ops.batched_nms(boxes, scores, labels, 0.1)
73
  boxes = boxes[fixed_boxes, :]
74
 
75
  return boxes, image
 
36
 
37
  return fixed_state_dict
38
 
39
+ def predict(model, image, detection_threshold):
40
+ img = PIL.Image.open(image)
41
+ #img = PIL.Image.open(BytesIO(image))
42
  img = np.array(img)
43
  img = PIL.Image.fromarray(img)
44
 
 
52
  transforms,
53
  model,
54
  class_map=class_map,
55
+ detection_threshold=detection_threshold,
56
  return_as_pil_img=False,
57
  return_img=True,
58
  display_bbox=False,
 
61
 
62
  return pred_dict
63
 
64
+ def prepare_prediction(pred_dict, threshold):
65
  boxes = [box.to_tensor() for box in pred_dict['detection']['bboxes']]
66
  boxes = torch.stack(boxes)
67
 
 
69
  labels = torch.as_tensor(pred_dict['detection']['label_ids'])
70
  image = np.array(pred_dict['img'])
71
 
72
+ fixed_boxes = torchvision.ops.batched_nms(boxes, scores, labels, threshold)
73
  boxes = boxes[fixed_boxes, :]
74
 
75
  return boxes, image