# TrafficVision / app.py
# Author: shivvamm — last commit c9ec3dd ("Added Upload")
import streamlit as st
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import torch
from collections import Counter
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights
from torchvision.utils import draw_bounding_boxes
import io
# Streamlit page configuration. Kept as the first Streamlit command, which
# older Streamlit versions require.
st.set_page_config(
    page_title="Object Detector Dashboard",
    page_icon="🔍",
    layout="wide",
    initial_sidebar_state="expanded",
)
st.sidebar.title("Traffic Monitoring AI")

# Default pretrained weights for Faster R-CNN ResNet-50 FPN v2. The weights
# object also supplies the category names and the matching preprocessing.
weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
# List indexed by the model's integer label ids -> human-readable category name.
categories = weights.meta["categories"]
# Preprocessing transform paired with these weights (applied in make_prediction).
img_preprocess = weights.transforms()
# Streamlit caches one model per distinct ``threshold`` argument. The threshold
# slider can produce many distinct values, and each cache entry is a full copy
# of the network. An unbounded cache would therefore grow until the process
# runs out of memory. Keep only a couple of entries; older ones are evicted.
@st.cache_resource(max_entries=2)
def load_model(threshold):
    """Build the pretrained Faster R-CNN detector and switch it to eval mode.

    Args:
        threshold: Passed as ``box_score_thresh``. During inference the model
            only returns detections whose score exceeds this value.

    Returns:
        A ``fasterrcnn_resnet50_fpn_v2`` model loaded with the module-level
        ``weights``, in eval mode.
    """
    model = fasterrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=threshold)
    model.eval()
    return model
def make_prediction(img, model):
    """Run the detector on a single image and translate label ids to names.

    Args:
        img: Image accepted by the module-level ``img_preprocess`` transform.
            The caller in this file passes an RGB PIL image.
        model: Detection model that takes a batch tensor and returns a list
            with one prediction dict per image.

    Returns:
        The prediction dict for ``img``. Its ``"labels"`` entry is replaced by
        a list of category-name strings; the other entries (e.g. ``"boxes"``)
        are returned exactly as the model produced them.
    """
    img_processed = img_preprocess(img)
    # Inference only: disable autograd so the forward pass does not build and
    # retain a gradient graph for every uploaded image.
    with torch.no_grad():
        # unsqueeze(0) adds the batch dimension; [0] takes this image's result.
        prediction = model(img_processed.unsqueeze(0))[0]
    prediction["labels"] = [categories[label] for label in prediction["labels"].tolist()]
    return prediction
def create_image_with_bboxes(img, prediction):
    """Render predicted boxes onto an image and return it channel-last.

    Args:
        img: Pixel array in channel-first (C, H, W) layout. The caller in this
            file builds it with ``np.array(img).transpose(2, 0, 1)``. Presumably
            uint8 RGB, as ``draw_bounding_boxes`` expects — confirm.
        prediction: Dict with ``"boxes"`` and ``"labels"`` (category-name strings).

    Returns:
        A numpy array in (H, W, C) layout with the boxes and labels drawn on.
    """
    names = prediction["labels"]

    # Outline "person" detections in green and everything else in red.
    box_colors = []
    for name in names:
        if name == "person":
            box_colors.append("Green")
        else:
            box_colors.append("red")

    annotated = draw_bounding_boxes(
        torch.tensor(img),
        boxes=prediction["boxes"],
        labels=names,
        colors=box_colors,
        width=1,
    )
    # Back to channel-last so matplotlib can display it.
    as_array = annotated.detach().numpy()
    return as_array.transpose(1, 2, 0)
# Sidebar control for the detector's minimum detection score.
threshold = st.sidebar.slider("Confidence Threshold", min_value=0.0, max_value=1.0, value=0.1, step=0.01)
st.title("Vehicle Detection")
st.markdown("Upload your images for object detection:")
# Allow users to upload multiple images
uploaded_files = st.file_uploader("Choose images...", type=["jpg", "jpeg", "png"], accept_multiple_files=True)
if uploaded_files:
    st.markdown("**Processing uploaded images...**")
    # Load the model once for this run
    model = load_model(threshold)
    # One (image name, Counter of detected labels) pair per successfully
    # processed image, in upload order.
    per_image_counts = []
    # At most four columns; images beyond that wrap around.
    cols = st.columns(min(len(uploaded_files), 4))
    with st.spinner("Processing images, please wait..."):
        for i, uploaded_file in enumerate(uploaded_files):
            image_name = f"Image {i + 1}"
            try:
                with Image.open(uploaded_file) as pil_img:
                    # convert() returns a fully loaded copy, so it stays usable
                    # after the source image is closed.
                    img = pil_img.convert("RGB")  # Ensure the image is in RGB format
                prediction = make_prediction(img, model)
                img_with_bbox = create_image_with_bboxes(np.array(img).transpose(2, 0, 1), prediction)
                with cols[i % len(cols)]:
                    st.header(f"{image_name}: Object Detection Results")
                    fig, ax = plt.subplots(figsize=(5, 5))
                    try:
                        ax.imshow(img_with_bbox)
                        ax.set_xticks([])
                        ax.set_yticks([])
                        ax.spines[["top", "bottom", "right", "left"]].set_visible(True)
                        st.pyplot(fig, use_container_width=True)
                    finally:
                        # pyplot keeps every open figure alive; without closing,
                        # each rerun of the script leaks figures.
                        plt.close(fig)
                per_image_counts.append((image_name, Counter(prediction["labels"])))
            except Exception as e:
                # UI boundary: report the failing image and continue with the rest.
                st.error(f"Error processing image {i + 1}: {e}")

    # One row per (image, label) with how many times that label was detected.
    image_object_counts = [
        {"Image": image_name, "Label": label, "Count": count}
        for image_name, label_counts in per_image_counts
        for label, count in label_counts.items()
    ]
    if image_object_counts:
        df_summary = pd.DataFrame(image_object_counts)
        # NOTE(review): "scooter" may not be one of the model's category
        # names (see `categories`); confirm before relying on it.
        vehicle_categories = ['car', 'bus', 'motorcycle', 'truck', 'train', 'bicycle', 'scooter']
        df_vehicles = df_summary[df_summary['Label'].isin(vehicle_categories)]
        if not df_vehicles.empty:
            st.header("Combined Vehicle Detection Table for All Images")
            st.dataframe(df_vehicles, use_container_width=True)