File size: 2,487 Bytes
cd4c90e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ab5b42b
cd4c90e
 
 
 
b80c100
 
cd4c90e
 
b80c100
 
 
 
1463eb9
b80c100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd4c90e
 
b80c100
 
 
 
 
cd4c90e
b80c100
cd4c90e
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import streamlit as st
import matplotlib.pyplot as plt
import numpy as np
import cv2
import PIL

from model import get_model, predict, prepare_prediction

def _load_model(checkpoint_path):
    """Build the detection model from ``checkpoint_path`` via ``get_model``."""
    print('Creating the model')
    return get_model(checkpoint_path)


# Streamlit re-executes this whole script on every widget interaction, so
# without caching the checkpoint would be reloaded every time a slider moves.
# ``st.cache_resource`` (newer Streamlit) / ``st.experimental_singleton``
# (older) keep a single instance shared across reruns and sessions; if neither
# exists, fall back to loading on every run exactly as before.
# NOTE(review): the cached model is shared across sessions -- this assumes
# ``predict`` does not mutate it; confirm in model.py.
_cache_resource = (getattr(st, 'cache_resource', None)
                   or getattr(st, 'experimental_singleton', None))
if _cache_resource is not None:
    _load_model = _cache_resource(_load_model)

model = _load_model('checkpoint.ckpt')

def plot_img_no_mask(image, boxes, path="img.png", color=(220, 0, 0), thickness=5):
    """Draw bounding boxes on ``image`` and save the rendered figure to ``path``.

    Args:
        image: Image to annotate; anything ``np.array`` can convert (e.g. a PIL
            image or a NumPy array). It is copied, so the caller's object is
            not modified.
        boxes: Boxes as ``[x1, y1, x2, y2]`` rows. ``.cpu().detach().numpy()``
            is called on it, so it is presumably a torch tensor -- TODO confirm
            against ``prepare_prediction``.
        path: File the figure is written to. Defaults to ``"img.png"``.
        color: Rectangle colour passed to ``cv2.rectangle``.
        thickness: Rectangle line thickness passed to ``cv2.rectangle``.
    """
    boxes = boxes.cpu().detach().numpy().astype(np.int32)

    # cv2.rectangle draws in place and errors without this copy (it needs a
    # writable, C-contiguous array). One copy is enough for all boxes.
    canvas = np.array(image).copy()
    for x1, y1, x2, y2 in boxes:
        # Convert to plain Python ints: cv2 can reject NumPy scalar points.
        pt1 = (int(x1), int(y1))
        pt2 = (int(x2), int(y2))
        cv2.rectangle(canvas, pt1, pt2, color, thickness=thickness)

    fig, ax = plt.subplots(1, 1, figsize=(12, 6))
    ax.axis('off')
    ax.imshow(canvas)
    fig.savefig(path, bbox_inches='tight')
    # Streamlit reruns call this repeatedly; close the figure so it is not
    # kept alive in pyplot's global figure registry.
    plt.close(fig)

st.subheader('Upload Custom Image')


def _clear_selected_example():
    """Forget the chosen example image so a newly uploaded file is shown."""
    st.session_state['selected_example'] = None


# st.button is only True on the rerun triggered by the click, and Streamlit
# reruns the script on every interaction. Remember the chosen example in
# session state so moving a slider does not discard the selection.
if 'selected_example' not in st.session_state:
    st.session_state['selected_example'] = None

image_file = st.file_uploader("Upload Images", type=["png","jpg","jpeg"],
                              on_change=_clear_selected_example)

st.subheader('Example Images')

example_imgs = [
    'example_imgs/basura_4_2.jpg',
    'example_imgs/basura_1.jpg',
    'example_imgs/basura_3.jpg'
]

for number, example_path in enumerate(example_imgs, start=1):
    with st.container():
        st.image(example_path, width=150, caption=str(number))
        if st.button('Select Image', key=f'Image_{number}'):
            st.session_state['selected_example'] = example_path

# A selected example takes precedence over an uploaded file (as when the
# button was clicked in the original); uploading a new file clears it.
if st.session_state['selected_example'] is not None:
    image_file = st.session_state['selected_example']

st.subheader('Detection parameters')

detection_threshold = st.slider('Detection threshold',
                                min_value=0.0,
                                max_value=1.0,
                                value=0.5,
                                step=0.1)

nms_threshold = st.slider('NMS threshold',
                        min_value=0.0,
                        max_value=1.0,
                        value=0.3,
                        step=0.1)

st.subheader('Prediction')

if image_file is not None:
    print('Getting predictions')
    if isinstance(image_file, str):
        # Example images are passed to predict() as file paths.
        data = image_file
    else:
        # getvalue() returns the whole upload regardless of stream position,
        # whereas read() yields b'' if the buffer was already consumed.
        data = image_file.getvalue()
    pred_dict = predict(model, data, detection_threshold)
    print('Fixing the preds')
    boxes, image = prepare_prediction(pred_dict, nms_threshold)
    print('Plotting')
    plot_img_no_mask(image, boxes)

    # Pass the path straight to st.image (as for the examples above). This
    # avoids depending on `import PIL` having loaded PIL.Image and avoids
    # leaving the Image.open file handle unclosed.
    st.image('img.png', width=750)