# Set the page config
import streamlit as st

# Page-level settings; kept ahead of every other Streamlit call on this page.
_PAGE_CONFIG = {
    "page_title": "Image Processing",
    "page_icon": ":open_file_folder:",
    "layout": "wide",
    "initial_sidebar_state": "collapsed",
}
st.set_page_config(**_PAGE_CONFIG)

# Importing necessary libraries
import cv2
import utils
import numpy as np
import Functions.image_processing_functions as image_processing_functions

# Both lookup tables live as separate sheets of the same workbook:
#   - "image_processing_parameters": per-technique parameter definitions
#   - "image_processing_details":    name, description, category, source link
_PACKAGES_DB = "packages_db.xlsx"
image_processing_params_df, image_processing_details_df = (
    utils.load_data_from_excel(_PACKAGES_DB, sheet_name)
    for sheet_name in ("image_processing_parameters", "image_processing_details")
)

# Page heading
st.title("Image Processing")

# # Clear the Streamlit session state on the first load of the page
# utils.clear_session_state_on_first_load("image_processing_clear")

# Widget-key holders: each is a one-entry dict {name: bool} that is passed as a
# widget `key`. The Reset button flips the bool, so the widget gets a new key
# and starts over empty.
session_state_keys = [
    "file_uploader_key_processing",
    "select_processing_technique_key_processing",
]
for key in session_state_keys:
    if key not in st.session_state:
        st.session_state[key] = {key: True}

# Remaining page state and its value on first load; existing values are kept.
_initial_session_values = {
    "validation_triggered": False,
    "uploaded_files_cache_processing": False,
    "zip_data_processing": "",
    "widget_states": {},
}
for state_name, initial_value in _initial_session_values.items():
    if state_name not in st.session_state:
        st.session_state[state_name] = initial_value

# Interface for uploading images and labels. The helper receives the session
# key to store the upload under, the prompt, the current widget key (flipped by
# Reset) and the "cached" flag that validation sets to True.
# NOTE(review): how the helper uses the cached flag is defined in utils — confirm.
_uploader_widget_key = st.session_state["file_uploader_key_processing"]
_uploader_cached = st.session_state["uploaded_files_cache_processing"]
utils.display_file_uploader(
    "uploaded_files",
    "Choose images and labels...",
    _uploader_widget_key,
    _uploader_cached,
)

# Usage notes rendered as raw HTML below the uploader.
_users_note_html = """
    <div style='text-align: justify;'>
        <b>Note to Users:</b>
        <ul>
            <li>The <i>first uploaded image</i> will be used for demonstration purposes and to validate parameters for image processing techniques.</li>
            <li>Uploading <i>labels is optional</i>. If no labels are uploaded, the output will consist solely of processed images.</li>
            <li>When moving to another page or if you wish to upload a new set of images and labels, don't forget to hit the <b>Reset</b> button. This helps in faster computation and frees up unused memory, ensuring smoother operation.</li>
        </ul>
    </div>
    """
st.markdown(_users_note_html, unsafe_allow_html=True)

# Validation results; each stays None until "Validate Input" has run.
for var in ("is_valid", "image_files", "label_files", "first_image_file", "first_label_file"):
    if var not in st.session_state:
        st.session_state[var] = None

# Two side-by-side columns: Validate (left) and Reset (right).
col1, col2 = st.columns(2)

# Render the button on every run so it stays visible. A validation request can
# also come from a stored "validate_input_button" widget state.
validate_requested = col1.button(
    "Validate Input", use_container_width=True
) or st.session_state["widget_states"].get("validate_input_button", False)

# Validation runs at most once; afterwards the stored results are reused until
# Reset clears the session state.
if not st.session_state["validation_triggered"]:
    if validate_requested:
        st.session_state["validation_triggered"] = True
        st.session_state["uploaded_files_cache_processing"] = True

        # Split the upload into images/labels and pick the demo image/label.
        validation_result = image_processing_functions.check_valid_labels(
            st.session_state["uploaded_files"]
        )
        (
            st.session_state["is_valid"],
            st.session_state["image_files"],
            st.session_state["label_files"],
            st.session_state["first_image_file"],
            st.session_state["first_label_file"],
        ) = validation_result
    else:
        st.session_state["is_valid"] = False
        st.warning(
            "Please upload images and labels and click **Validate Input**.", icon="⚠️"
        )

with col2:
    # Reset does the following, in order:
    #   1. flips the bool inside both widget-key holders, so the uploader and the
    #      technique multiselect are re-created empty;
    #   2. deletes every other session-state entry;
    #   3. deletes all public module globals except `st`;
    #   4. clears the Streamlit caches and reruns the script.
    if st.button("Reset", use_container_width=True):
        widget_key_names = [
            "file_uploader_key_processing",
            "select_processing_technique_key_processing",
        ]
        flipped_flag = not st.session_state["file_uploader_key_processing"][
            "file_uploader_key_processing"
        ]
        for holder_name in widget_key_names:
            st.session_state[holder_name] = {holder_name: flipped_flag}

        # Only the two widget-key holders survive in session state.
        stale_keys = [k for k in st.session_state.keys() if k not in widget_key_names]
        for stale_key in stale_keys:
            del st.session_state[stale_key]

        # Names starting with "_" and the Streamlit module itself are kept.
        for global_name in [
            g for g in list(globals()) if not g.startswith("_") and g != "st"
        ]:
            del globals()[global_name]

        st.cache_resource.clear()
        st.cache_data.clear()

        st.rerun()

# All technique names listed in the details sheet.
available_image_processings = image_processing_details_df["Name"]

# Technique name -> the image type it expects as input.
input_mapping_dict = utils.technique_image_input_mapping(
    available_image_processings, image_processing_details_df
)

if not st.session_state["is_valid"]:
    # Without a validated upload no technique can be chosen.
    selected_image_processings = []
else:
    selected_image_processings = st.multiselect(
        "Select image processing technique(s)",
        available_image_processings,
        key=st.session_state["select_processing_technique_key_processing"],
    )

    # Decode the first uploaded image (the demo image) to RGB and shrink it
    # to 256x256 for fast previews.
    demo_image_file = st.session_state["first_image_file"]
    demo_image_file.seek(0)  # the file may already have been read
    demo_bytes = np.frombuffer(demo_image_file.read(), dtype=np.uint8)
    demo_bgr = cv2.imdecode(demo_bytes, cv2.IMREAD_COLOR)
    demo_rgb = cv2.cvtColor(demo_bgr, cv2.COLOR_BGR2RGB)
    uploaded_first_image = cv2.resize(demo_rgb, (256, 256))


#######################################################################################################
# Build custom image processing pipeline
#######################################################################################################


# Parameters chosen for each selected technique, keyed by technique name. The
# parameter table, the generated code and the batch processing below all read
# from this dict.
image_processings_params = {}

# Set to True as soon as any technique fails its preview test. It is never
# reset inside the loop, so one failure disables the later sections.
error_in_parameters = False

# One expander per selected technique, in the order the user picked them.
# Previews are chained: each technique that passes its test becomes the input
# for the next technique's preview (see the end of the loop body).
for image_processing in selected_image_processings:
    with st.expander(f"{image_processing}"):
        # Details-sheet row(s) for this technique; may be empty, hence the
        # fallbacks below.
        image_processing_info = image_processing_details_df[
            image_processing_details_df["Name"] == image_processing
        ]

        # Left: text details and controls; right: before/after previews (7:3 width).
        details_col, image_col = st.columns([7, 3])

        with details_col:
            # Description, or a placeholder text when there is no details row
            image_processing_description = (
                image_processing_info["Description"].iloc[0]
                if not image_processing_info.empty
                else "No description available."
            )
            st.markdown(
                f"<div style='text-align: justify;'><b>Description:</b> {image_processing_description}</div>",
                unsafe_allow_html=True,
            )

            # Category, or "Unknown" when there is no details row
            image_processing_category = (
                image_processing_info["Category"].iloc[0]
                if not image_processing_info.empty
                else "Unknown"
            )
            st.write("Category:", image_processing_category)

            # Link to the technique's source code.
            # NOTE(review): the fallback "www.google.com" has no URL scheme, so a
            # browser may treat it as a relative link — confirm this is intended.
            image_processing_source_code = (
                image_processing_info["Source Code Link"].iloc[0]
                if not image_processing_info.empty
                else "www.google.com"
            )
            # Set up columns for displaying source code button and custom settings checkbox
            source_code_col, custo_setting_col = st.columns(2)

            source_code_col.link_button("Source Code", image_processing_source_code)

            # Per-technique checkbox (keyed by technique name) that switches from
            # default to user-customized parameters
            custom_settings = custo_setting_col.checkbox(
                f"Customize {image_processing}", key=f"toggle_{image_processing}"
            )

        with image_col:
            # Two fixed-height (200px), borderless slots that receive the
            # previews once the parameters have passed the test below
            col1, col2 = st.columns(2)
            original_image_placeholder = col1.container(height=200, border=False)
            processed_image_placeholder = col2.container(height=200, border=False)

        # Apply custom settings
        if custom_settings:
            # Parameter rows for this technique from the parameters sheet
            params_df = image_processing_params_df[
                image_processing_params_df["Name"] == image_processing
            ]

            # Turn the rows into this technique's parameter dict (presumably by
            # rendering an input widget per parameter — confirm in utils)
            image_processings_params[image_processing] = utils.process_image_parameters(
                params_df, image_processing
            )

        else:
            # Use default settings if customization is not selected
            image_processings_params[image_processing] = utils.get_default_params(
                image_processing
            )

        # Test the parameters on the current preview image. input_mapping_dict
        # supplies the image type this technique expects. The helper returns
        # (error_flag, processed_image).
        (
            error_flag,
            processed_first_image,
        ) = image_processing_functions.apply_and_test_image_processing(
            image_processing,
            image_processings_params[image_processing],
            uploaded_first_image,
            input_mapping_dict[image_processing],
        )

        # If there is an error in the parameters, set the global error flag.
        # The previews are then left empty and the chain is not advanced.
        if error_flag:
            error_in_parameters = True
        else:
            # Show the before/after previews in their placeholders. Because of the
            # chaining below, for the second and later techniques the "Original
            # Image" is actually the previous technique's output.
            with original_image_placeholder:
                st.image(
                    uploaded_first_image,
                    caption="Original Image",
                    use_column_width=True,
                    clamp=True,
                )
            with processed_image_placeholder:
                st.image(
                    processed_first_image,
                    caption="Processed Image",
                    use_column_width=True,
                    clamp=True,
                )

            # Chain the pipeline: the next technique previews on this output
            uploaded_first_image = processed_first_image


#######################################################################################################
# Display selected image processing technique parameters as DataFrame
#######################################################################################################


# Offer a technique filter for the parameter table only when at least one
# technique is configured and every configuration passed its preview test.
if image_processings_params and not error_in_parameters:
    selected_image_processing = st.selectbox(
        "Select image processing technique",
        options=["All"] + list(image_processings_params.keys()),
    )
else:
    selected_image_processing = None

# One table holding the parameters of every configured technique.
image_processings_df = image_processing_functions.create_image_processings_dataframe(
    image_processings_params, image_processing_params_df
)
# Cast values to str so the column has a single dtype. This avoids
# serialization problems when the mixed-type column is displayed.
image_processings_df["Value"] = image_processings_df["Value"].astype(str)

# Narrow the table to one technique unless "All" is chosen. A None selection
# matches no rows, so the table stays hidden.
if selected_image_processing == "All":
    filtered_image_processings_df = image_processings_df
else:
    filtered_image_processings_df = image_processings_df[
        image_processings_df["image_processing"] == selected_image_processing
    ]

if not filtered_image_processings_df.empty and not error_in_parameters:
    # Displayed at its natural width (not stretched to the container).
    st.dataframe(filtered_image_processings_df, use_container_width=False)

# Slot for the generated code and its download button, filled in by the
# processing section. It is created unconditionally: the processing section
# always writes into it. Previously a technique without parameter rows (an
# empty table) left the name undefined and raised NameError there.
code_placeholder = st.empty()


#######################################################################################################
# Process images and download processed images
#######################################################################################################


# Final stage: runs only when the upload is valid, at least one technique is
# selected, and every technique's parameters passed the preview test.
if (
    st.session_state["is_valid"]
    and selected_image_processings
    and not error_in_parameters
):
    col1, col2 = st.columns(2)

    # Number of variations to generate (1-3).
    num_variations = col1.number_input(
        "Set the number of variations to be generated",
        min_value=1,
        max_value=3,
        step=1,
    )

    with col2:
        # Two empty lines push the checkbox down to line up with the number input.
        for _ in range(2):
            st.write("")

        include_original = st.checkbox(
            "Include original images and labels in output", value=False
        )

    # Fill the placeholder reserved above the processing controls with the
    # generated script and its download button.
    with code_placeholder:
        # Labels are optional, so `label_files` may be empty or None. Test its
        # truthiness instead of calling len(), which raises on None.
        if not st.session_state["label_files"]:
            generated_code = utils.generate_python_code_images(
                image_processings_params,
                num_variations,
                include_original,
            )
        else:
            generated_code = (
                image_processing_functions.generate_python_code_images_labels(
                    image_processings_params,
                    num_variations,
                    include_original,
                )
            )

        image_processing_functions.display_code_and_download_button(generated_code)

    col1, col2 = st.columns(2)

    if col1.button("Accept and Process", use_container_width=True):
        # The return value is not used. Presumably the helper writes its output
        # (e.g. "zip_data_processing") into st.session_state — confirm.
        image_processing_functions.process_images_and_labels(
            st.session_state["image_files"],
            st.session_state["label_files"],
            selected_image_processings,
            image_processings_params,
            num_variations,
            include_original,
        )

    # Keep Download disabled until archive data exists. The initial value ""
    # would otherwise produce an empty, invalid processed_images.zip.
    zip_data = st.session_state["zip_data_processing"]
    col2.download_button(
        label="Download",
        data=zip_data,
        file_name="processed_images.zip",
        mime="application/zip",
        use_container_width=True,
        disabled=not zip_data,
    )


else:
    if st.session_state["is_valid"] and not selected_image_processings:
        # Inform the user that no image processing techniques have been selected
        st.warning("Please select at least one image processing technique.", icon="⚠️")

    if st.session_state["is_valid"] and error_in_parameters:
        # Inform the user that there are errors in parameters
        st.warning(
            "There are errors in the image processing parameters. Please review your selections.",
            icon="⚠️",
        )


#######################################################################################################
# Display original and processed images
#######################################################################################################


# Gallery of processing results. It needs three session-state entries:
#   - the image repository
#   - the original -> variants mapping
#   - the list of original image names
# These are presumably written by the processing step. Nothing is rendered
# unless all three exist and at least one original image is listed; otherwise
# the name lookup below would raise KeyError or IndexError.
# NOTE: the spelling "procesing" must match the key used where it is written.
_gallery_state_keys = (
    "image_repository_preprocessing",
    "processed_image_mapping_procesing",
    "unique_images_names",
)
if all(k in st.session_state for k in _gallery_state_keys) and (
    len(st.session_state["unique_images_names"]) > 0
):
    unique_image_names = st.session_state["unique_images_names"]
    num_unique_images = len(unique_image_names)

    # Show the slider only when there is more than one image to choose from.
    if num_unique_images > 1:
        selected_image_index = st.slider(
            "Select an Image",
            min_value=1,
            max_value=num_unique_images,
            step=1,
        )
    else:
        selected_image_index = 1

    # The slider is 1-based; the name list is 0-based.
    selected_original_image_name = unique_image_names[selected_image_index - 1]

    # Processed variants of the selected image (none if it has no entry).
    processed_variant_names = st.session_state["processed_image_mapping_procesing"].get(
        selected_original_image_name, []
    )

    # Original first, then its variants. Unpacking accepts any iterable of names.
    all_image_names = [selected_original_image_name, *processed_variant_names]

    image_repository = st.session_state["image_repository_preprocessing"]
    num_columns = 4

    # Lay the images out in rows of `num_columns`; the last row may be partial.
    for row_start in range(0, len(all_image_names), num_columns):
        row_names = all_image_names[row_start : row_start + num_columns]
        cols = st.columns(num_columns)
        for col, image_name in zip(cols, row_names):
            image_data = image_repository[image_name]["image"]
            with col:
                st.image(
                    image_data,
                    clamp=True,
                    caption=image_name,
                    use_column_width=True,
                )


# if st.button("Run"):
#     utils.button_click(on_click=None)