# Hugging Face Spaces page header captured with the source (Space status: Sleeping).
from super_gradients.training import models | |
from apd_utils import write_video, convert_video | |
import torch, PIL, os | |
import streamlit as st | |
# PPE categories; the detector is built with len(CLASSES) output classes.
CLASSES = [
    'Dust Mask',
    'Eye Wear',
    'Glove',
    'Protective Boots',
    'Protective Helmet',
    'Safety Vest',
    'Shield',
]

# Input kinds offered by the "Select Source" radio button in the sidebar.
SOURCES = ['Images', 'Videos']
# Setting page layout
st.set_page_config(
    page_title="PPE Object Detection using YOLO-NAS",
    # The icon had been mis-decoded to "π·", which appears to be the UTF-8
    # bytes of the construction-worker emoji read as ISO-8859-7; restored.
    page_icon="👷",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Main page heading
st.title("PPE Object Detection using YOLO-NAS")

# Sidebar
st.sidebar.header("YOLO-NAS Model Config")

# Model options: the slider works in whole percent (0-100, default 40) and is
# converted to a 0.0-1.0 fraction for the `conf` argument of `predict`.
confidence = float(st.sidebar.slider(
    "Select Model Confidence", 0, 100, 40)) / 100

st.sidebar.header("Image/Video Config")
source_radio = st.sidebar.radio("Select Source", SOURCES)

# Uploaded files; each stays None until the matching uploader returns a file.
source_img = None
source_vid = None
# NOTE: the checkpoint is read from a local path next to the app. Commented-out
# code that previously sat here passed this URL to a `download_model` helper:
# https://drive.google.com/file/d/1XOq3OkpQ3OgibjHmYOCMsQPBtqjdf2i3/view?usp=sharing


@st.cache_resource
def _load_model():
    """Build the YOLO-NAS-M detector and load the local checkpoint.

    Streamlit re-executes the whole script on every widget interaction, so the
    model is cached as a shared resource instead of being rebuilt from the
    checkpoint on each rerun.
    """
    return models.get(
        'yolo_nas_m',
        num_classes=len(CLASSES),
        checkpoint_path="./ckpt_best_yolonas.pth",
    )


model = _load_model()

# Inference is pinned to the CPU. The original computed a CUDA-if-available
# device and immediately overwrote it with 'cpu'; that dead auto-detection is
# dropped, so behaviour is unchanged.
device = 'cpu'
# Render the uploader, the preview column and the detection column for the
# source type chosen in the sidebar.
if source_radio == 'Images':
    source_img = st.sidebar.file_uploader(
        "Choose an image...", type=("jpg", "jpeg", "png", 'bmp', 'webp'))
    # Stays None when no file is uploaded or the upload cannot be decoded, so
    # the detection step below never touches an unbound name.
    uploaded_image = None
    col1, col2 = st.columns(2)

    # Left column: the input image (or a bundled placeholder).
    with col1:
        try:
            if source_img is None:
                st.image('default_img.png', caption="Default Image",
                         use_column_width=True)
            else:
                uploaded_image = PIL.Image.open(source_img)
                st.image(source_img, caption="Uploaded Image",
                         use_column_width=True)
        except Exception as ex:
            st.error("Error occurred while opening the image.")
            st.error(ex)

    # Right column: the detection result (or a bundled placeholder).
    with col2:
        if source_img is None:
            st.image('default_img_res.png', caption="Detected Objects",
                     use_column_width=True)
        # Only offer detection once the upload has actually been decoded; the
        # short-circuit hides the button when opening the image failed.
        elif uploaded_image is not None and st.sidebar.button('Detect Objects'):
            res = model.to(device).predict(uploaded_image,
                                           conf=confidence)
            st.image(res.draw(), caption='Detected Image',
                     use_column_width=True)

elif source_radio == 'Videos':
    source_vid = st.sidebar.file_uploader(
        "Choose a video ...", type=("mp4", "mov", "webM"))
    col1, col2 = st.columns(2)

    # Left column: the input video (or a bundled placeholder).
    with col1:
        if source_vid is None:
            st.image('default_img.png', caption="Default Image",
                     use_column_width=True)
        else:
            try:
                uploaded_video = source_vid.getvalue()
                st.video(uploaded_video)
            except Exception as ex:
                st.error("Error occurred while opening the video.")
                st.error(ex)

    # Right column: the annotated video (or a bundled placeholder).
    with col2:
        if source_vid is None:
            st.image('default_img_res.png', caption="Detected Objects",
                     use_column_width=True)
        elif st.sidebar.button('Detect Objects'):
            in_temp_res_path = "./temp/result.mp4"
            out_temp_res_path = "./temp/result2.mp4"
            # Where write_video stores the upload is not visible from here, so
            # make sure the hard-coded result directory exists before saving.
            os.makedirs(os.path.dirname(in_temp_res_path), exist_ok=True)

            temp_uploaded_path = write_video(source_vid)
            try:
                with st.spinner('Processing video ...'):
                    res = model.to(device).predict(temp_uploaded_path,
                                                   conf=confidence)
                    res.save(in_temp_res_path)
                    convert_video(in_temp_res_path, out_temp_res_path)
                    st.video(out_temp_res_path)
            finally:
                # Remove the temp files even if inference or conversion
                # failed; files that were never created are skipped.
                for temp_path in (temp_uploaded_path, in_temp_res_path,
                                  out_temp_res_path):
                    try:
                        os.remove(temp_path)
                    except FileNotFoundError:
                        pass

else:
    # Defensive: the radio button only offers values from SOURCES.
    st.error("Please select a valid source type!")