import cv2
import numpy as np
import gradio as gr


# Utility functions for lane detection on a straight-lane image.
def draw_lines(img, lines, color=[255, 0, 0], thickness=2):
    """Utility for drawing lines."""
    if lines is not None:
        for line in lines:
            for x1, y1, x2, y2 in line:
                cv2.line(img, (x1, y1), (x2, y2), color, thickness)


def hough_lines(img, rho, theta, threshold, min_line_len, max_line_gap):
    """Utility for defining Line Segments."""
    lines = cv2.HoughLinesP(
        img, rho, theta, threshold, np.array([]), minLineLength=min_line_len, maxLineGap=max_line_gap
    )
    line_img = np.zeros((img.shape[0], img.shape[1], 3), dtype=np.uint8)
    draw_lines(line_img, lines)
    return line_img, lines


def separate_left_right_lines(lines):
    """Separate left and right lines depending on the slope."""
    left_lines = []
    right_lines = []
    if lines is not None:
        for line in lines:
            for x1, y1, x2, y2 in line:
                if x1 == x2:
                    continue  # Avoid division by zero
                slope = (y2 - y1) / (x2 - x1)
                if slope < 0:  # Negative slope = left lane.
                    left_lines.append([x1, y1, x2, y2])
                elif slope > 0:  # Positive slope = right lane.
                    right_lines.append([x1, y1, x2, y2])
    return left_lines, right_lines


def cal_avg(values):
    """Calculate the average of a list of values; return 0.0 for an empty list."""
    if not values:
        return 0.0
    return sum(values) / len(values)


def extrapolate_lines(lines, upper_border, lower_border):
    """Extrapolate lines keeping in mind the lower and upper border intersections."""
    slopes = []
    consts = []
    if lines:
        for x1, y1, x2, y2 in lines:
            if x1 == x2:
                continue  # Avoid division by zero
            slope = (y2 - y1) / (x2 - x1)
            slopes.append(slope)
            c = y1 - slope * x1
            consts.append(c)
        avg_slope = cal_avg(slopes)
        avg_consts = cal_avg(consts)

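        # A zero average slope (or no usable segments) would make the border
        # intersections undefined.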
        if avg_slope == 0:
            return None

        # Calculate average intersection at lower_border.
        x_lane_lower_point = int((lower_border - avg_consts) / avg_slope)

        # Calculate average intersection at upper_border.
        x_lane_upper_point = int((upper_border - avg_consts) / avg_slope)

        return [x_lane_lower_point, lower_border, x_lane_upper_point, upper_border]
    else:
        return None


def draw_con(img, lines):
    """Fill in lane area."""
    points = []
    if lines is not None:
        for x1, y1, x2, y2 in lines[0]:
            points.append([x1, y1])
            points.append([x2, y2])
        for x1, y1, x2, y2 in lines[1]:
            points.append([x2, y2])
            points.append([x1, y1])
    if points:
        points = np.array([points], dtype="int32")
        cv2.fillPoly(img, points, (0, 255, 0))


def extrapolated_lane_image(img, lines, roi_upper_border, roi_lower_border):
    """Main function called to get the final lane lines."""
    lanes_img = np.zeros((img.shape[0], img.shape[1], 3), dtype=np.uint8)
    # Extract each lane.
    lines_left, lines_right = separate_left_right_lines(lines)
    lane_left = extrapolate_lines(lines_left, roi_upper_border, roi_lower_border)
    lane_right = extrapolate_lines(lines_right, roi_upper_border, roi_lower_border)
    if lane_left is not None and lane_right is not None:
        draw_con(lanes_img, [[lane_left], [lane_right]])
    return lanes_img


def process_image(image, points):
    """Detect lane lines inside the polygon ROI and overlay them on the frame."""
    # Frames from cv2.VideoCapture are BGR, so convert accordingly.
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    # Keep only bright pixels; lane markings are close to white.
    gray_select = cv2.inRange(gray, 150, 255)
    # Create the ROI mask from the selected polygon points.
    roi_mask = np.zeros_like(gray_select)
    points_array = np.array([points], dtype=np.int32)

    # gray_select is single-channel (output of cv2.inRange), so fill the mask with a scalar.
    cv2.fillPoly(roi_mask, points_array, 255)
    roi_mask = cv2.bitwise_and(gray_select, roi_mask)

    # Canny edge detection.
    low_threshold = 50
    high_threshold = 100
    img_canny = cv2.Canny(roi_mask, low_threshold, high_threshold)

    # Smooth the edge map with a Gaussian blur before the Hough transform.
    kernel_size = 3
    canny_blur = cv2.GaussianBlur(img_canny, (kernel_size, kernel_size), 0)

    # Hough transform parameters set according to the input image.
    rho = 1
    theta = np.pi / 180
    threshold = 100
    min_line_len = 50
    max_line_gap = 300
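    # A large max_line_gap lets the dashes of a broken lane marking join into
    # longer segments instead of many short ones.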
    hough, lines = hough_lines(canny_blur, rho, theta, threshold, min_line_len, max_line_gap)

    # Extrapolate lanes between the vertical extents of the masked region.
    ys, xs = np.where(roi_mask > 0)
    if len(ys) == 0:
        # The masked region is empty; return the frame unchanged.
        return image
    roi_upper_border = np.min(ys)
    roi_lower_border = np.max(ys)
    lane_img = extrapolated_lane_image(image, lines, roi_upper_border, roi_lower_border)

    # Blend the lane overlay onto the frame at 40% opacity.
    image_result = cv2.addWeighted(image, 1, lane_img, 0.4, 0.0)
    return image_result


def extract_first_frame_interface(video_file):
    # Read the video file.
    cap = cv2.VideoCapture(video_file)
    if not cap.isOpened():
        print("Error opening video stream or file")
        return None, None
    # Read the first frame.
    ret, frame = cap.read()
    cap.release()
    if not ret:
        print("Cannot read the first frame")
        return None, None
    # Convert the frame to RGB (since OpenCV uses BGR).
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    # Return the frame for display and as the original frame.
    return frame_rgb, frame_rgb  # Return frame twice, once for display, once for state


def get_point_interface(original_frame, points, evt: gr.SelectData):
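    # For image components, evt.index holds the (x, y) pixel coordinates of the click.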
    x, y = evt.index
    # Ensure points is a list
    if points is None:
        points = []
    points = points.copy()  # Make a copy to avoid modifying in-place
    points.append((x, y))
    # Draw the point and lines on the image
    image = original_frame.copy()
    # Draw the points
    for pt in points:
        cv2.circle(image, pt, 5, (255, 0, 0), -1)
    # Draw the lines
    if len(points) > 1:
        for i in range(len(points) - 1):
            cv2.line(image, points[i], points[i + 1], (255, 0, 0), 2)
    # Optionally, draw line from last to first to close the polygon
    # cv2.line(image, points[-1], points[0], (255, 0, 0), 2)
    # Return the updated image and points
    # print("selected points")
    # print(points)
    return image, points


def process_video_interface(video_file, points):
    # A polygon ROI needs at least three points; check before converting so a
    # None value cannot crash list().
    if points is None or len(points) < 3:
        print("Not enough points to define a polygon")
        return None
    points = list(points)
    # Open the input video and read its frame size and rate.
    cap = cv2.VideoCapture(video_file)
    if not cap.isOpened():
        print("Error opening video stream or file")
        return None
    frame_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    frame_fps = int(cap.get(cv2.CAP_PROP_FPS))
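    if frame_fps <= 0:
        frame_fps = 30  # Assumption: fall back to 30 FPS when the container reports none.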
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")  # For mp4 output.
    output_filename = "processed_output.mp4"
    out = cv2.VideoWriter(output_filename, fourcc, frame_fps, (frame_w, frame_h))
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # Process the frame using the selected ROI polygon.
        result = process_image(frame, points)
        out.write(result)
    cap.release()
    out.release()
    return output_filename


# Gradio Interface.
with gr.Blocks(title="Lane Detection using OpenCV", theme=gr.themes.Soft()) as demo:
    gr.HTML(
        """
    <h1 style='text-align: center'>
    Lane Detection using OpenCV
    </h1>
    """
    )
    gr.HTML(
        """
        <h3 style='text-align: center'>
        <a href='https://opencv.org/university/' target='_blank'>OpenCV Courses</a> | <a href='https://github.com/OpenCV-University' target='_blank'>Github</a>
        </h3>
        """
    )
    gr.Markdown(
        "Upload your video, select the four region points to make the ROI yo want to track. Click process video to get the output."
    )
    with gr.Row():
        with gr.Column(scale=1, min_width=300):
            video_input = gr.Video(label="Input Video")
            extract_frame_button = gr.Button("Extract First Frame")
        with gr.Column(scale=1, min_width=300):
            first_frame_image = gr.Image(label="Click to select ROI points")
            original_frame_state = gr.State(None)
            points_state = gr.State([])
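            # gr.State keeps per-session values (the clean first frame and the
            # clicked points) between event callbacks.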
            process_button = gr.Button("Process Video")
            clear_points_button = gr.Button("Clear Points")
        with gr.Column(scale=1, min_width=300):
            output_video = gr.Video(label="Processed Video")

    # Extract the first frame and store it
    extract_frame_button.click(
        fn=extract_first_frame_interface, inputs=video_input, outputs=[first_frame_image, original_frame_state]
    )

    # Handle point selection on the image
    first_frame_image.select(
        fn=get_point_interface, inputs=[original_frame_state, points_state], outputs=[first_frame_image, points_state]
    )

    # Clear the selected points
    clear_points_button.click(
        fn=lambda original_frame: (original_frame, []),
        inputs=original_frame_state,
        outputs=[first_frame_image, points_state],
    )

    # Process the video using the selected ROI
    process_button.click(fn=process_video_interface, inputs=[video_input, points_state], outputs=output_video)

    # Adding examples
    gr.Examples(examples=["./lane.mp4"], inputs=video_input)
    gr.HTML(
        """
        <h3 style='text-align: center'>
        Developed with ❤️ by OpenCV
        </h3>
        """
    )


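# queue() enables request queuing for long-running jobs such as video processing;
# ssr_mode=False disables Gradio's server-side rendering.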
if __name__ == "__main__":
    demo.queue().launch(ssr_mode=False)