import os
import tempfile

import cv2
import streamlit as st

from face_detection import FaceDetector
from mark_detection import MarkDetector
from pose_estimation import PoseEstimator
from utils import refine

def _save_upload_to_temp(uploaded_file):
    """Write a Streamlit upload to a temporary file and return its path.

    ``cv2.VideoCapture`` needs a filesystem path (or a device index); it cannot
    read from the in-memory file object returned by ``st.file_uploader``.
    The original extension is kept so the video backend can probe the format.
    The caller is responsible for deleting the returned file.
    """
    suffix = os.path.splitext(uploaded_file.name)[1]
    # delete=False so the file can be reopened by OpenCV (required on Windows).
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        tmp.write(uploaded_file.getvalue())
        return tmp.name


def _annotate_frame(frame, face_detector, mark_detector, pose_estimator,
                    frame_width, frame_height):
    """Detect a face in ``frame`` and draw its focus status onto it in place.

    Only the first box returned by ``refine`` is analysed. If no face is
    detected, the frame is left untouched.
    """
    faces, _ = face_detector.detect(frame, 0.7)
    # ``faces`` may be an array, so test its length rather than its truthiness.
    if len(faces) == 0:
        return

    face = refine(faces, frame_width, frame_height, 0.15)[0]
    x1, y1, x2, y2 = face[:4].astype(int)
    patch = frame[y1:y2, x1:x2]
    if patch.size == 0:
        # Degenerate or out-of-frame box; nothing to feed the landmark model.
        return

    marks = mark_detector.detect([patch])[0].reshape([68, 2])
    # NOTE(review): scaling both axes by the box width appears to assume the
    # landmark model returns coordinates normalised to a square patch —
    # confirm against MarkDetector / refine.
    marks *= (x2 - x1)
    marks[:, 0] += x1
    marks[:, 1] += y1

    distraction_status, _ = pose_estimator.detect_distraction(marks)
    status_text = "Distracted" if distraction_status else "Focused"
    # Colours are BGR: green when focused, red when distracted.
    colour = (0, 0, 255) if distraction_status else (0, 255, 0)
    cv2.putText(frame, f"Status: {status_text}", (10, 50),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, colour)


def main():
    """Run the Streamlit distraction-detection UI.

    Lets the user choose the webcam (device index 0) or an uploaded video file.
    Each frame is run through face detection, landmark detection and
    pose-based distraction detection, annotated with a Focused/Distracted
    label, and streamed into the page. Every frame is displayed, whether or
    not a face was found.
    """
    # Streamlit Title and Sidebar for inputs
    st.title("Distraction Detection App")
    source_choice = st.sidebar.selectbox("Select Video Source", ("Webcam", "Video File"))

    is_webcam = source_choice != "Video File"
    temp_path = None
    if is_webcam:
        capture_source = 0  # Webcam index
    else:
        video_file = st.sidebar.file_uploader("Upload a Video File", type=["mp4", "avi", "mov"])
        if video_file is None:
            st.warning("Please upload a video file.")
            return
        temp_path = _save_upload_to_temp(video_file)
        capture_source = temp_path

    cap = cv2.VideoCapture(capture_source)
    try:
        if not cap.isOpened():
            st.error("Could not open the selected video source.")
            return

        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        face_detector = FaceDetector("assets/face_detector.onnx")
        mark_detector = MarkDetector("assets/face_landmarks.onnx")
        pose_estimator = PoseEstimator(frame_width, frame_height)

        # Streamlit placeholder that is overwritten with each new frame.
        frame_placeholder = st.empty()

        while cap.isOpened():
            frame_got, frame = cap.read()
            if not frame_got:
                break

            # Mirror webcam frames horizontally (any positive flipCode does this).
            if is_webcam:
                frame = cv2.flip(frame, 1)

            _annotate_frame(frame, face_detector, mark_detector, pose_estimator,
                            frame_width, frame_height)

            # Display every frame, not only those in which a face was found.
            frame_placeholder.image(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), channels="RGB")
    finally:
        # Runs on normal exit, on exceptions, and when Streamlit interrupts the
        # script for a rerun, so the device/file handle is always released.
        cap.release()
        if temp_path is not None:
            try:
                os.remove(temp_path)
            except OSError:
                # Best-effort cleanup; don't mask an in-flight exception.
                pass

if __name__ == "__main__":
    # Script entry point: run the app only when this file is executed directly
    # (presumably via `streamlit run` on this file), not when it is imported.
    main()