Hector Lopez commited on
Commit
f890c24
·
1 Parent(s): ab0c2de

refactor: Using streamlit again

Browse files
Files changed (4) hide show
  1. app.py +56 -45
  2. model.py +5 -1
  3. requirements.txt +0 -1
  4. utils.py +1 -2
app.py CHANGED
@@ -1,64 +1,75 @@
1
- import gradio as gr
2
- from gradio.networking import get_first_available_port
3
  import PIL
4
  import torch
5
- import os
6
 
7
  from utils import plot_img_no_mask, get_models
8
  from classifier import CustomEfficientNet, CustomViT
9
  from model import get_model, predict, prepare_prediction, predict_class
10
 
11
- os.system('pkill -9 python')
12
-
13
  DET_CKPT = 'efficientDet_icevision.ckpt'
14
  CLASS_CKPT = 'class_ViT_taco_7_class.pth'
15
 
16
- def waste_detector_interface(
17
- image,
18
- detection_threshold,
19
- nms_threshold
20
- ):
21
- det_model, classifier = get_models(DET_CKPT, CLASS_CKPT)
22
- print('Getting predictions')
23
- pred_dict = predict(det_model, image, detection_threshold)
24
- print('Fixing the preds')
25
- boxes, image = prepare_prediction(pred_dict, nms_threshold)
26
 
27
- print('Predicting classes')
28
- labels = predict_class(classifier, image, boxes)
29
- print('Plotting')
30
 
31
- return plot_img_no_mask(image, boxes, labels)
32
 
33
- inputs = [
34
- gr.inputs.Image(type="pil", label="Original Image"),
35
- gr.inputs.Number(default=0.5, label="detection_threshold"),
36
- gr.inputs.Number(default=0.5, label="nms_threshold"),
37
- ]
38
 
39
- outputs = [
40
- gr.outputs.Image(type="plot", label="Prediction"),
41
- ]
42
 
43
- title = 'Waste Detection'
44
- description = 'Demo for waste object detection. It detects and classify waste in images according to which rubbish bin the waste should be thrown. Upload an image or click an image to use.'
45
- examples = [
46
- ['example_imgs/basura_4_2.jpg', 0.5, 0.5],
47
- ['example_imgs/basura_1.jpg', 0.5, 0.5],
48
- ['example_imgs/basura_3.jpg', 0.5, 0.5]
49
  ]
50
 
51
- gr.close_all()
52
- #port = get_first_available_port(7682, 9000)
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
- gr.Interface(
55
- waste_detector_interface,
56
- inputs,
57
- outputs,
58
- title=title,
59
- description=description,
60
- examples=examples,
61
- theme="huggingface"
62
- ).launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
- os.system('python3 app.py')
 
 
1
+ import streamlit as st
 
2
  import PIL
3
  import torch
 
4
 
5
  from utils import plot_img_no_mask, get_models
6
  from classifier import CustomEfficientNet, CustomViT
7
  from model import get_model, predict, prepare_prediction, predict_class
8
 
 
 
9
  DET_CKPT = 'efficientDet_icevision.ckpt'
10
  CLASS_CKPT = 'class_ViT_taco_7_class.pth'
11
 
 
 
 
 
 
 
 
 
 
 
12
 
 
 
 
13
 
14
+ st.subheader('Upload Custom Image')
15
 
16
+ image_file = st.file_uploader("Upload Images", type=["png","jpg","jpeg"])
 
 
 
 
17
 
18
+ st.subheader('Example Images')
 
 
19
 
20
+ example_imgs = [
21
+ 'example_imgs/basura_4_2.jpg',
22
+ 'example_imgs/basura_1.jpg',
23
+ 'example_imgs/basura_3.jpg'
 
 
24
  ]
25
 
26
+ with st.container() as cont:
27
+ st.image(example_imgs[0], width=150, caption='1')
28
+ if st.button('Select Image', key='Image_1'):
29
+ image_file = example_imgs[0]
30
+
31
+ with st.container() as cont:
32
+ st.image(example_imgs[1], width=150, caption='2')
33
+ if st.button('Select Image', key='Image_2'):
34
+ image_file = example_imgs[1]
35
+
36
+ with st.container() as cont:
37
+ st.image(example_imgs[2], width=150, caption='2')
38
+ if st.button('Select Image', key='Image_3'):
39
+ image_file = example_imgs[2]
40
 
41
+ st.subheader('Detection parameters')
42
+
43
+ detection_threshold = st.slider('Detection threshold',
44
+ min_value=0.0,
45
+ max_value=1.0,
46
+ value=0.5,
47
+ step=0.1)
48
+
49
+ nms_threshold = st.slider('NMS threshold',
50
+ min_value=0.0,
51
+ max_value=1.0,
52
+ value=0.3,
53
+ step=0.1)
54
+
55
+ st.subheader('Prediction')
56
+
57
+ if image_file is not None:
58
+ det_model, classifier = get_models(DET_CKPT, CLASS_CKPT)
59
+
60
+ print('Getting predictions')
61
+ if isinstance(image_file, str):
62
+ data = image_file
63
+ else:
64
+ data = image_file.read()
65
+ pred_dict = predict(det_model, data, detection_threshold)
66
+ print('Fixing the preds')
67
+ boxes, image = prepare_prediction(pred_dict, nms_threshold)
68
+
69
+ print('Predicting classes')
70
+ labels = predict_class(classifier, image, boxes)
71
+ print('Plotting')
72
+ plot_img_no_mask(image, boxes, labels)
73
 
74
+ img = PIL.Image.open('img.png')
75
+ st.image(img,width=750)
model.py CHANGED
@@ -39,7 +39,11 @@ def get_checkpoint(checkpoint_path : str):
39
 
40
  return fixed_state_dict
41
 
42
- def predict(model : object, img : Union[str, BytesIO], detection_threshold : float):
 
 
 
 
43
  class_map = ClassMap(classes=['Waste'])
44
  transforms = tfms.A.Adapter([
45
  *tfms.A.resize_and_pad(512),
 
39
 
40
  return fixed_state_dict
41
 
42
+ def predict(model : object, image : Union[str, BytesIO], detection_threshold : float):
43
+ img = PIL.Image.open(image)
44
+ #img = PIL.Image.open(BytesIO(image))
45
+ img = np.array(img)
46
+ img = PIL.Image.fromarray(img)
47
  class_map = ClassMap(classes=['Waste'])
48
  transforms = tfms.A.Adapter([
49
  *tfms.A.resize_and_pad(512),
requirements.txt CHANGED
@@ -1,5 +1,4 @@
1
  icevision[all]
2
  matplotlib
3
  effdet
4
- gradio
5
  Pillow==8.4.0
 
1
  icevision[all]
2
  matplotlib
3
  effdet
 
4
  Pillow==8.4.0
utils.py CHANGED
@@ -45,11 +45,10 @@ def plot_img_no_mask(image : np.ndarray, boxes : torch.Tensor, labels):
45
  cv2.putText(image, texts[labels[i]], (x1, y1-10),
46
  cv2.FONT_HERSHEY_SIMPLEX, 4, thickness=5, color=color)
47
 
48
-
49
  plt.axis('off')
50
  ax.imshow(image)
51
 
52
- return fig
53
 
54
  def get_models(
55
  detection_ckpt : str,
 
45
  cv2.putText(image, texts[labels[i]], (x1, y1-10),
46
  cv2.FONT_HERSHEY_SIMPLEX, 4, thickness=5, color=color)
47
 
 
48
  plt.axis('off')
49
  ax.imshow(image)
50
 
51
+ fig.savefig("img.png", bbox_inches='tight')
52
 
53
  def get_models(
54
  detection_ckpt : str,