"""Streamlit front-end for a two-stage detection + classification demo.

The user either uploads an image or picks one of the bundled example images,
tunes the detection and NMS thresholds, and the page then:

1. runs the detector (``DET_CKPT``) to propose boxes,
2. post-processes the boxes with the NMS threshold,
3. classifies each box with the classifier (``CLASS_CKPT``),
4. renders the annotated image produced by ``plot_img_no_mask``.

Streamlit reruns this whole script on every widget interaction, so any choice
that must survive a later interaction (e.g. moving a slider) is kept in
``st.session_state``.
"""
import streamlit as st
import PIL
import PIL.Image  # `import PIL` alone does not guarantee the Image submodule is loaded
import torch
from utils import plot_img_no_mask, get_models
from model import predict, prepare_prediction, predict_class

DET_CKPT = 'efficientDet_icevision.ckpt'
CLASS_CKPT = 'class_ViT_taco_7_class.pth'

# Session-state key holding the path of the currently selected example image
# (None when no example is selected).
_SELECTED_KEY = 'selected_example'


def _select_example(path):
    """Remember *path* as the chosen example so it survives later reruns."""
    st.session_state[_SELECTED_KEY] = path


def _clear_example():
    """Forget any chosen example so a newly uploaded file takes effect."""
    st.session_state[_SELECTED_KEY] = None


st.subheader('Upload Custom Image')
uploaded_file = st.file_uploader(
    "Upload Images",
    type=["png", "jpg", "jpeg"],
    # A new upload (or removal) overrides a previously selected example.
    on_change=_clear_example,
)

st.subheader('Example Images')
example_imgs = [
    'example_imgs/basura_4_2.jpg',
    'example_imgs/basura_1.jpg',
    'example_imgs/basura_3.jpg',
]
for idx, example_path in enumerate(example_imgs, start=1):
    with st.container():
        # Caption and widget key are derived from the 1-based position, which
        # keeps them in sync with the list (the third caption used to be '2').
        st.image(example_path, width=150, caption=str(idx))
        # st.button is True only on the rerun caused by its click, so the
        # selection is persisted through an on_click callback instead.
        st.button(
            'Select Image',
            key=f'Image_{idx}',
            on_click=_select_example,
            args=(example_path,),
        )

# The most recently selected example wins; otherwise fall back to the upload.
selected_example = st.session_state.get(_SELECTED_KEY)
image_file = selected_example if selected_example is not None else uploaded_file

st.subheader('Detection parameters')
detection_threshold = st.slider('Detection threshold', min_value=0.0, max_value=1.0, value=0.5, step=0.1)
nms_threshold = st.slider('NMS threshold', min_value=0.0, max_value=1.0, value=0.3, step=0.1)

st.subheader('Prediction')
if image_file is not None:
    # NOTE(review): models are requested on every rerun; confirm whether
    # get_models caches them internally.
    det_model, classifier = get_models(DET_CKPT, CLASS_CKPT)
    print('Getting predictions')
    pred_dict = predict(det_model, image_file, detection_threshold)
    print('Fixing the preds')
    boxes, image = prepare_prediction(pred_dict, nms_threshold)
    print('Predicting classes')
    labels = predict_class(classifier, image, boxes)
    print('Plotting')
    plot_img_no_mask(image, boxes, labels)
    # NOTE(review): assumes plot_img_no_mask writes its output to 'img.png' in
    # the working directory; that fixed path is shared by all concurrent
    # sessions — confirm against utils.
    with PIL.Image.open('img.png') as img:
        st.image(img, width=750)