Afnan214's picture
distraction detection
e897273 unverified
import os
import tempfile

import cv2
import streamlit as st

from face_detection import FaceDetector
from mark_detection import MarkDetector
from pose_estimation import PoseEstimator
from utils import refine
def _persist_upload(video_file):
    """Write an uploaded video to a temporary file and return its path.

    ``cv2.VideoCapture`` opens a source by device index or by filename/URL; it
    cannot read the in-memory file object returned by ``st.file_uploader``.
    The upload is therefore spilled to disk first. The original file
    extension is kept as the suffix of the temporary file.

    The caller is responsible for deleting the returned file.
    """
    suffix = os.path.splitext(video_file.name)[1]
    # delete=False: the file must outlive this handle so OpenCV can reopen it
    # by name (reopening a still-open NamedTemporaryFile fails on Windows).
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        # getvalue() returns the whole buffer regardless of the read position.
        tmp.write(video_file.getvalue())
    return tmp.name


def _annotate_frame(frame, face_detector, mark_detector, pose_estimator,
                    frame_width, frame_height):
    """Draw the "Focused"/"Distracted" status on *frame* in place.

    Does nothing when no face is detected or when the face crop is empty.
    """
    faces, _ = face_detector.detect(frame, 0.7)
    # Test the length explicitly: the detector result may be an array, whose
    # truthiness would be ambiguous.
    if len(faces) == 0:
        return

    # Only the first refined box is used.
    face = refine(faces, frame_width, frame_height, 0.15)[0]
    x1, y1, x2, y2 = face[:4].astype(int)
    patch = frame[y1:y2, x1:x2]
    if patch.size == 0:
        # A degenerate or out-of-frame box yields an empty crop; skip it
        # rather than pass an empty image to the landmark model.
        return

    # 68 (x, y) landmarks, mapped from patch coordinates back into frame
    # coordinates.
    # NOTE(review): both axes are scaled by the box width, which presumes
    # normalised landmarks and a square box from refine() -- confirm.
    marks = mark_detector.detect([patch])[0].reshape([68, 2])
    marks *= (x2 - x1)
    marks[:, 0] += x1
    marks[:, 1] += y1

    distraction_status, _ = pose_estimator.detect_distraction(marks)
    status_text = "Distracted" if distraction_status else "Focused"
    # The frame is BGR here: green when focused, red when distracted.
    colour = (0, 0, 255) if distraction_status else (0, 255, 0)
    cv2.putText(frame, f"Status: {status_text}", (10, 50),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, colour)


def main():
    """Run the Streamlit distraction-detection app.

    Lets the user pick the webcam or an uploaded video file, then reads
    frames, overlays the focus status for the detected face, and shows each
    frame in the page until the source is exhausted. The capture device and
    any temporary copy of an uploaded video are always released on exit.
    """
    # Streamlit title and sidebar for inputs
    st.title("Distraction Detection App")
    video_src = st.sidebar.selectbox("Select Video Source", ("Webcam", "Video File"))

    temp_path = None
    if video_src == "Video File":
        video_file = st.sidebar.file_uploader(
            "Upload a Video File", type=["mp4", "avi", "mov"]
        )
        if video_file is None:
            st.warning("Please upload a video file.")
            return
        # OpenCV needs a real path, not the uploaded in-memory object.
        temp_path = _persist_upload(video_file)
        capture_source = temp_path
    else:
        capture_source = 0  # Webcam index

    is_webcam = temp_path is None
    cap = None
    try:
        cap = cv2.VideoCapture(capture_source)
        if not cap.isOpened():
            st.error("Could not open the selected video source.")
            return

        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        face_detector = FaceDetector("assets/face_detector.onnx")
        mark_detector = MarkDetector("assets/face_landmarks.onnx")
        pose_estimator = PoseEstimator(frame_width, frame_height)

        # Single placeholder that is overwritten with every new frame
        frame_placeholder = st.empty()

        while cap.isOpened():
            frame_got, frame = cap.read()
            if not frame_got:
                break

            if is_webcam:
                # Mirror the webcam image (positive flip code = horizontal flip).
                frame = cv2.flip(frame, 1)

            _annotate_frame(frame, face_detector, mark_detector, pose_estimator,
                            frame_width, frame_height)

            # Streamlit expects RGB; OpenCV frames are BGR.
            frame_placeholder.image(
                cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), channels="RGB"
            )
    finally:
        if cap is not None:
            cap.release()
        if temp_path is not None:
            try:
                os.remove(temp_path)
            except OSError:
                # Best-effort cleanup: a leftover temp file is not fatal.
                pass
# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()