import torch
import tensorflow as tf
import time
import os
import logging
import queue
from pathlib import Path
from typing import List, NamedTuple
import av
import cv2
import numpy as np
import streamlit as st
from streamlit_webrtc import WebRtcMode, webrtc_streamer
from utils.download import download_file
from utils.turn import get_ice_servers
from PIL import Image, ImageDraw  # Import PIL for image processing
from transformers import pipeline  # Import Hugging Face transformers pipeline
import requests
from io import BytesIO  # Import for handling byte streams
# Named tuple to store detection results
class Detection(NamedTuple):
    class_id: int
    label: str
    score: float
    box: np.ndarray
# Queue to store detection results
result_queue: "queue.Queue[List[Detection]]" = queue.Queue()
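# NOTE: streamlit-webrtc runs the video frame callback on a worker thread,
# separate from the thread executing this Streamlit script, so the queue is
# the thread-safe channel used to hand detection metadata back to the UI.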
# CHANGE THE CODE BELOW HERE TO REPLACE IT WITH YOUR OWN ANALYSIS.
# Update the string below to set the display title of the analysis
# Add any imports needed for the analysis here
MODEL_URL = "https://github.com/robmarkcole/object-detection-app/raw/master/model/MobileNetSSD_deploy.caffemodel"
MODEL_LOCAL_PATH = Path("./models/MobileNetSSD_deploy.caffemodel")
PROTOTXT_URL = "https://github.com/robmarkcole/object-detection-app/raw/master/model/MobileNetSSD_deploy.prototxt.txt"
PROTOTXT_LOCAL_PATH = Path("./models/MobileNetSSD_deploy.prototxt.txt")
CLASSES = [
    "background",
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "pottedplant",
    "sheep",
    "sofa",
    "train",
    "tvmonitor",
]
# Generate random colors for each class label
def generate_label_colors():
    return np.random.uniform(0, 255, size=(len(CLASSES), 3))
COLORS = generate_label_colors()
# Download the model and prototxt files if they are not already cached locally.
# Named download_model_file so it does not shadow the download_file imported
# from utils.download, which is used later for video URLs.
def download_model_file(url, local_path, expected_size=None):
    local_path.parent.mkdir(parents=True, exist_ok=True)  # Ensure ./models exists
    if not local_path.exists() or (expected_size and local_path.stat().st_size != expected_size):
        response = requests.get(url)
        with open(local_path, "wb") as f:
            f.write(response.content)
download_model_file(MODEL_URL, MODEL_LOCAL_PATH, expected_size=23147564)
download_model_file(PROTOTXT_URL, PROTOTXT_LOCAL_PATH, expected_size=29353)
# Load the model
net = cv2.dnn.readNetFromCaffe(str(PROTOTXT_LOCAL_PATH), str(MODEL_LOCAL_PATH))
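# readNetFromCaffe loads the network architecture from the prototxt file and the
# trained weights from the caffemodel file; OpenCV's DNN module runs inference on
# the CPU by default, so no GPU setup is required for this model.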
# Display title shown above the analysis column
ANALYSIS_TITLE = "Object Detection Analysis"
# CHANGE THE CONTENTS OF THIS FUNCTION TO REPLACE IT WITH YOUR OWN ANALYSIS.
#
# Set analysis results in img_container and result_queue for display
# img_container["input"] - holds the input frame contents - of type np.ndarray
# img_container["analyzed"] - holds the analyzed frame with any added annotations - of type np.ndarray
# img_container["analysis_time"] - holds how long the analysis has taken in milliseconds
# result_queue - holds the analysis metadata results - of type queue.Queue[List[Detection]]
def analyze_frame(frame: np.ndarray):
    start_time = time.time()  # Start timing the analysis
    img_container["input"] = frame  # Store the input frame
    frame = frame.copy()  # Create a copy of the frame to modify
    # Run inference
    blob = cv2.dnn.blobFromImage(
        cv2.resize(frame, (300, 300)), 0.007843, (300, 300), 127.5
    )
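    # blobFromImage arguments: scalefactor 0.007843 (~1/127.5) and mean 127.5
    # rescale pixel values to roughly [-1, 1], matching MobileNet SSD's
    # training-time preprocessing; the network expects 300x300 inputs.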
    net.setInput(blob)
    output = net.forward()
    h, w = frame.shape[:2]
    # Filter the detections based on the score threshold
    score_threshold = 0.5  # You can adjust the score threshold as needed
    output = output.squeeze()  # (1, 1, N, 7) -> (N, 7)
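    # Each detection row holds 7 values:
    # [image_id, class_id, confidence, x_min, y_min, x_max, y_max],
    # with the box coordinates normalized to the 0-1 range.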
    output = output[output[:, 2] >= score_threshold]
    detections = [
        Detection(
            class_id=int(detection[1]),
            label=CLASSES[int(detection[1])],
            score=float(detection[2]),
            box=(detection[3:7] * np.array([w, h, w, h])),
        )
        for detection in output
    ]
    # Render bounding boxes and captions
    for detection in detections:
        caption = f"{detection.label}: {round(detection.score * 100, 2)}%"
        color = COLORS[detection.class_id]
        xmin, ymin, xmax, ymax = detection.box.astype("int")
        cv2.rectangle(frame, (xmin, ymin), (xmax, ymax), color, 2)
        cv2.putText(
            frame,
            caption,
            (xmin, ymin - 15 if ymin - 15 > 15 else ymin + 15),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            color,
            2,
        )
    end_time = time.time()  # End timing the analysis
    # Calculate execution time in milliseconds
    execution_time_ms = round((end_time - start_time) * 1000, 2)
    # Store the execution time
    img_container["analysis_time"] = execution_time_ms
    result_queue.put(detections)  # Put the results in the result queue
    img_container["analyzed"] = frame  # Store the analyzed frame
    return  # End of the function
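# Example (sketch): analyze_frame can also be exercised outside the WebRTC flow,
# e.g. on a local RGB test image ("test.jpg" is a hypothetical path):
#   test_img = cv2.cvtColor(cv2.imread("test.jpg"), cv2.COLOR_BGR2RGB)
#   analyze_frame(test_img)
#   print(result_queue.get())  # list of Detection tuples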
#
#
# DO NOT TOUCH THE CODE BELOW (NO CHANGES NEEDED)
#
#
# Suppress FFmpeg logs
os.environ["FFMPEG_LOG_LEVEL"] = "quiet"
# Suppress TensorFlow or PyTorch progress bars
tf.get_logger().setLevel("ERROR")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
# Suppress PyTorch logs
logging.getLogger().setLevel(logging.WARNING)
torch.set_num_threads(1)
logging.getLogger("torch").setLevel(logging.ERROR)
# Suppress Streamlit logs using the logging module
logging.getLogger("streamlit").setLevel(logging.ERROR)
# Container to hold image data and analysis results
img_container = {"input": None, "analyzed": None, "analysis_time": None}
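# The WebRTC callback thread overwrites these entries on every frame, and the
# display loop below simply polls the latest values, so older frames are
# dropped rather than queued.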
# Logger for debugging and information
logger = logging.getLogger(__name__)
# Callback function to process video frames
# This function is called for each video frame in the WebRTC stream.
# It converts the frame to a numpy array in RGB format, analyzes the frame,
# and returns the original frame.
def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
    # Convert frame to numpy array in RGB format
    img = frame.to_ndarray(format="rgb24")
    analyze_frame(img)  # Analyze the frame
    return frame  # Return the original frame
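# Returning the unmodified frame means the "Input Stream" panel shows the raw
# webcam feed; the annotated frame is published separately to the Analysis
# column via img_container["analyzed"].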
# Get ICE servers for WebRTC
ice_servers = get_ice_servers()
# Streamlit UI configuration
st.set_page_config(layout="wide")
# Custom CSS for the Streamlit page
st.markdown(
    """
    <style>
    .main {
        padding: 2rem;
    }
    h1, h2, h3 {
        font-family: 'Arial', sans-serif;
    }
    h1 {
        font-weight: 700;
        font-size: 2.5rem;
    }
    h2 {
        font-weight: 600;
        font-size: 2rem;
    }
    h3 {
        font-weight: 500;
        font-size: 1.5rem;
    }
    </style>
    """,
    unsafe_allow_html=True,
)
# Streamlit page title and subtitle
st.title("Computer Vision Playground")
# Add a link to the README file
st.markdown(
    """
    <div style="text-align: left;">
        <p>See the <a href="https://huggingface.co/spaces/eusholli/sentiment-analyzer/blob/main/README.md"
        target="_blank">README</a> to learn how to use this code to help you start your computer vision exploration.</p>
    </div>
    """,
    unsafe_allow_html=True,
)
st.subheader(ANALYSIS_TITLE)
# Columns for input and output streams
col1, col2 = st.columns(2)
with col1:
    st.header("Input Stream")
    st.subheader("input")
    # WebRTC streamer to get video input from the webcam
    webrtc_ctx = webrtc_streamer(
        key="input-webcam",
        mode=WebRtcMode.SENDRECV,
        rtc_configuration=ice_servers,
        video_frame_callback=video_frame_callback,
        media_stream_constraints={"video": True, "audio": False},
        async_processing=True,
    )
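    # WebRtcMode.SENDRECV both sends the webcam stream to the app and plays it
    # back in the browser; rtc_configuration supplies STUN/TURN servers for NAT
    # traversal, and async_processing=True lets the frame callback run
    # asynchronously so heavy analysis does not block the stream.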
    # File uploader for images
    st.subheader("Upload an Image")
    uploaded_file = st.file_uploader(
        "Choose an image...", type=["jpg", "jpeg", "png"])
    # Text input for image URL
    st.subheader("Or Enter Image URL")
    image_url = st.text_input("Image URL")
    # File uploader for videos
    st.subheader("Upload a Video")
    uploaded_video = st.file_uploader(
        "Choose a video...", type=["mp4", "avi", "mov", "mkv"]
    )
    # Text input for video URL
    st.subheader("Or Enter Video Download URL")
    video_url = st.text_input("Video URL")
# Streamlit footer
st.markdown(
    """
    <div style="text-align: center; margin-top: 2rem;">
        <p>If you want to set up your own computer vision playground see <a href="https://huggingface.co/spaces/eusholli/computer-vision-playground/blob/main/README.md" target="_blank">here</a>.</p>
    </div>
    """,
    unsafe_allow_html=True,
)
# Function to initialize the analysis UI
# This function sets up the placeholders and UI elements in the analysis section.
# It creates placeholders for input and output frames, analysis time, and detected labels.
def analysis_init():
    global analysis_time, show_labels, labels_placeholder, input_placeholder, output_placeholder
    with col2:
        st.header("Analysis")
        st.subheader("Input Frame")
        input_placeholder = st.empty()  # Placeholder for input frame
        st.subheader("Output Frame")
        output_placeholder = st.empty()  # Placeholder for output frame
        analysis_time = st.empty()  # Placeholder for analysis time
        show_labels = st.checkbox(
            "Show the detected labels", value=True
        )  # Checkbox to show/hide labels
        labels_placeholder = st.empty()  # Placeholder for labels
# Function to publish frames and results to the Streamlit UI
# This function retrieves the latest frames and results from the global container and result queue,
# and updates the placeholders in the Streamlit UI with the current input frame, analyzed frame,
# analysis time, and detected labels.
def publish_frame():
    if not result_queue.empty():
        result = result_queue.get()
        if show_labels:
            labels_placeholder.table(
                result
            )  # Display labels if the checkbox is checked
    img = img_container["input"]
    if img is None:
        return
    input_placeholder.image(img, channels="RGB")  # Display the input frame
    analyzed = img_container["analyzed"]
    if analyzed is None:
        return
    # Display the analyzed frame
    output_placeholder.image(analyzed, channels="RGB")
    time_ms = img_container["analysis_time"]
    if time_ms is None:
        return
    # Display the analysis time
    analysis_time.text(f"Analysis Time: {time_ms} ms")
# If the WebRTC streamer is playing, initialize and publish frames
if webrtc_ctx.state.playing:
    analysis_init()  # Initialize the analysis UI
    while True:
        publish_frame()  # Publish the frames and results
        time.sleep(0.1)  # Delay to control frame rate
# If an image is uploaded or a URL is provided, process the image
if uploaded_file is not None or image_url:
    analysis_init()  # Initialize the analysis UI
    if uploaded_file is not None:
        image = Image.open(uploaded_file)  # Open the uploaded image
        img = np.array(image.convert("RGB"))  # Convert the image to RGB format
    else:
        response = requests.get(image_url)  # Download the image from the URL
        # Open the downloaded image
        image = Image.open(BytesIO(response.content))
        img = np.array(image.convert("RGB"))  # Convert the image to RGB format
    analyze_frame(img)  # Analyze the image
    publish_frame()  # Publish the results
# Function to process video files
# This function reads frames from a video file, runs the analysis on each frame,
# and updates the Streamlit UI with the current input frame, analyzed frame, and detected labels.
def process_video(video_path):
    cap = cv2.VideoCapture(video_path)  # Open the video file
    while cap.isOpened():
        ret, frame = cap.read()  # Read a frame from the video
        if not ret:
            break  # Exit the loop if no more frames are available
        # OpenCV returns frames in BGR order; convert to RGB so the analysis
        # and the Streamlit display (channels="RGB") stay consistent
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        # Display the current frame as the input frame
        input_placeholder.image(frame)
        analyze_frame(frame)  # Run the analysis on the frame
        publish_frame()  # Publish the results
        if not result_queue.empty():
            result = result_queue.get()
            if show_labels:
                labels_placeholder.table(
                    result
                )  # Display labels if the checkbox is checked
    cap.release()  # Release the video capture object
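    # Note: this loop blocks the Streamlit script until the whole video has been
    # processed. A possible refinement (sketch, not required) for long videos is
    # to analyze only every Nth frame, e.g.
    #   if frame_idx % 5 == 0:  # 5 is an arbitrary example stride
    #       analyze_frame(frame)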
# If a video is uploaded or a URL is provided, process the video
if uploaded_video is not None or video_url:
    analysis_init()  # Initialize the analysis UI
    if uploaded_video is not None:
        video_path = uploaded_video.name  # Get the name of the uploaded video
        with open(video_path, "wb") as f:
            # Save the uploaded video to a file
            f.write(uploaded_video.getbuffer())
    else:
        # Download the video from the URL
        video_path = download_file(video_url)
    process_video(video_path)  # Process the video