# Set the page config
import streamlit as st

# Page-level settings, collected in one mapping and applied in a single call
page_settings = {
    "page_title": "Model_Training",
    "page_icon": ":open_file_folder:",
    "layout": "wide",
    "initial_sidebar_state": "collapsed",
}
st.set_page_config(**page_settings)

# Importing necessary libraries
import utils
import streamlit as st
import Functions.model_training_functions as model_training_functions


# Page heading
st.title("Model Training")

# NOTE: clearing the session state on the first load of the page is disabled:
# utils.clear_session_state_on_first_load("model_training_clear")

# Session-state entries whose stored value is later passed to a widget as its `key`
session_state_keys = [
    "file_uploader_split_key_training",
    "file_uploader_train_key_training",
    "file_uploader_val_key_training",
    "file_uploader_test_key_training",
    "number_input_train_key",
    "number_input_val_key",
    "number_input_test_key",
    "split_method_key",
    "training_type_key",
    "class_labels_input_key_training",
]

# Seed each missing entry with a one-item dict mapping its own name to True;
# entries that already exist are left untouched
for widget_key in session_state_keys:
    if widget_key in st.session_state:
        continue
    st.session_state[widget_key] = {widget_key: True}

# Boolean flags and the value each one receives when it is not yet present
flag_defaults = {
    "validation_triggered": False,
    "uploaded_files_cache_processing": False,
    "is_valid": False,
}
for flag_name, flag_default in flag_defaults.items():
    if flag_name not in st.session_state:
        st.session_state[flag_name] = flag_default

# Container that the file uploaders are rendered into further down the script
file_uploader_container = st.container()

# User-facing training names mapped to the label type each one works with
label_type_mapping = {"Object Detection": "Bboxes", "Instance Segmentation": "Masks"}

# Side-by-side columns: training type on the left, split method on the right
training_col, split_col = st.columns(2)

# Dropdown listing the training types; the first entry is preselected
training_options = list(label_type_mapping)
selected_training = training_col.selectbox(
    "Select the training type:",
    training_options,
    index=0,
    on_change=utils.reset_validation_trigger,
    key=st.session_state["training_type_key"],
)

# Label type that belongs to the chosen training type
label_type = label_type_mapping[selected_training]

# Horizontal radio buttons for choosing how the dataset is split
split_options = ["Percentage Split", "Direct Upload"]
split_method = split_col.radio(
    "Select the dataset split method:",
    split_options,
    horizontal=True,
    on_change=utils.reset_validation_trigger,
    key=st.session_state["split_method_key"],
)

# Text area for the class labels, pre-filled with the sample labels from utils
class_labels_input = st.text_area(
    "Enter class labels, separated by commas:",
    utils.sample_class_labels,
    on_change=utils.reset_validation_trigger,
    key=st.session_state["class_labels_input_key_training"],
)
# Remove unnecessary whitespace from the start and end
class_labels_input = class_labels_input.strip()

# Split on commas, trim every piece and drop the empty ones. `class_labels` is
# always bound (possibly to an empty list) because the validation and training
# steps further down read it.
class_labels = [
    label.strip() for label in class_labels_input.split(",") if label.strip()
]

# Class ID -> label, numbered in input order starting at 0
class_dict = dict(enumerate(class_labels))
# Label -> class ID (inverse of class_dict); a repeated label keeps its last ID
class_names_to_ids = {label: class_id for class_id, label in class_dict.items()}

# Plain string splitting cannot raise, so the only way the input can be unusable
# is when it yields no label at all; warn in that case. Both mappings are then
# empty dicts.
if not class_labels:
    st.warning(
        "Invalid format for class labels. Please enter labels separated by commas.",
        icon="⚠️",
    )

# Usage hints shown to the user, written as raw HTML
usage_note_html = """
    <div style='text-align: justify;'>
        <b>Note to Users:</b>
        <ul>
            <li>When moving to another page or if you wish to upload a new set of images and labels, don't forget to hit the <b>Reset</b> button. This helps in faster computation and frees up unused memory, ensuring smoother operation.</li>
            <li>Select the training type, class labels, dataset split method and its parameters before uploading large data for faster computation and more efficient processing.</li>
        </ul>
    </div>
    """
st.markdown(usage_note_html, unsafe_allow_html=True)

# Two equal columns holding the 'Validate Input' and 'Reset' buttons
button_columns = st.columns(2)
validate_button_col, reset_button_col = button_columns

with reset_button_col:
    # The body below runs only on the script run in which 'Reset' was clicked
    if st.button("Reset", use_container_width=True):
        # Recreate the output folder, then clear the data folders
        output_dir = model_training_functions.get_path("output")
        model_training_functions.delete_and_recreate_folder(output_dir)
        model_training_functions.clear_data_folders()

        # Session-state entries whose values are used as widget keys
        session_state_keys = [
            "file_uploader_split_key_training",
            "file_uploader_train_key_training",
            "file_uploader_val_key_training",
            "file_uploader_test_key_training",
            "number_input_train_key",
            "number_input_val_key",
            "number_input_test_key",
            "split_method_key",
            "training_type_key",
            "class_labels_input_key_training",
        ]

        # Flip the boolean stored under each widget key, so the value handed
        # to the widget as its `key` differs from the previous one
        for widget_key in session_state_keys:
            flipped = not st.session_state[widget_key][widget_key]
            st.session_state[widget_key] = {widget_key: flipped}

        # Remove every session-state entry that is not one of the widget keys
        stale_entries = [
            entry
            for entry in list(st.session_state.keys())
            if entry not in session_state_keys
        ]
        for entry in stale_entries:
            del st.session_state[entry]

        # Delete every module-level name that does not start with an
        # underscore, apart from `st`, which is still needed below. The list
        # of names is taken before the loop, so the loop's own helpers are
        # not part of it.
        doomed_globals = [
            name
            for name in list(globals())
            if not name.startswith("_") and name != "st"
        ]
        for doomed in doomed_globals:
            del globals()[doomed]

        # Empty both Streamlit caches
        st.cache_resource.clear()
        st.cache_data.clear()

        # Rerun the script so the page reflects the reset state
        st.rerun()

# "Percentage Split": a single upload that is divided into train/test/validation
# sets according to user-chosen percentages
if split_method == "Percentage Split":
    with file_uploader_container:
        # Uploader for the whole dataset (images and labels together).
        # NOTE(review): the code below reads st.session_state["uploaded_files"],
        # presumably populated by this helper under the name passed as its
        # first argument — confirm in utils.
        utils.display_file_uploader(
            "uploaded_files",
            "Choose images and labels...",
            st.session_state["file_uploader_split_key_training"],
            st.session_state["uploaded_files_cache_processing"],
        )

        # Three columns, one per percentage input
        col1, col2, col3 = st.columns(3)

        # Integer inputs in [0, 100] with step 1; defaults are 70 / 15 / 15.
        # Each input uses the session-state entry that matches its own name.
        train_pct = col1.number_input(
            "Train Set Percentage",
            0,
            100,
            70,
            1,
            on_change=utils.reset_validation_trigger,
            key=st.session_state["number_input_train_key"],
        )
        test_pct = col2.number_input(
            "Test Set Percentage",
            0,
            100,
            15,
            1,
            on_change=utils.reset_validation_trigger,
            key=st.session_state["number_input_test_key"],
        )
        val_pct = col3.number_input(
            "Validation Set Percentage",
            0,
            100,
            15,
            1,
            on_change=utils.reset_validation_trigger,
            key=st.session_state["number_input_val_key"],
        )

    # Sum of the three percentages; must be exactly 100
    pct_check = train_pct + test_pct + val_pct

    # The split is acceptable when the percentages add up to 100, the train and
    # validation shares are non-zero, and the project's minimum-image check
    # passes for the number of uploaded files
    pct_condition_check = (
        pct_check == 100
        and train_pct > 0
        and val_pct > 0
        and model_training_functions.check_min_images(
            len(st.session_state["uploaded_files"]), train_pct, val_pct, test_pct
        )
    )

    if not pct_condition_check:
        file_uploader_container.warning(
            "The percentages for train, test, and validation sets should add up to 100%, and train and validation set should not be empty.",
            icon="⚠️",
        )

    # 'Validate Input' button: check the labels and, if they are valid,
    # write the config file and the split dataset to disk
    if validate_button_col.button("Validate Input", use_container_width=True):
        st.session_state["validation_triggered"] = True

        st.session_state["is_valid"] = model_training_functions.check_valid_labels(
            st.session_state["uploaded_files"], label_type, class_dict
        )

        if st.session_state["is_valid"]:
            model_training_functions.create_yolo_config_file(
                model_training_functions.get_path("config"),
                class_labels,
            )
            # Start from empty data folders before saving the new split
            model_training_functions.clear_data_folders()
            paired_files = model_training_functions.pair_files(
                st.session_state["uploaded_files"]
            )
            # NOTE(review): only train_pct and test_pct are passed; presumably
            # the helper derives the validation share from the remainder —
            # confirm against its signature.
            model_training_functions.split_and_save_files(
                paired_files, train_pct, test_pct
            )

    # Start training only after a successful validation with acceptable
    # split parameters
    if st.session_state["validation_triggered"] and (
        pct_condition_check and st.session_state["is_valid"]
    ):
        model_training_functions.start_yolo_training(selected_training, class_labels)
    else:
        # Validation has not run, has failed, or the parameters are not acceptable
        st.warning(
            "Please upload valid input, select valid parameters, and click **Validate Input**.",
            icon="⚠️",
        )

# "Direct Upload": the user supplies the train, validation and test sets separately
elif split_method == "Direct Upload":
    with file_uploader_container:
        # Three columns, one uploader each for the train, val and test files
        col1, col2, col3 = st.columns(3)

        with col1:
            utils.display_file_uploader(
                "uploaded_train_files",
                "Upload Training Images and Labels",
                st.session_state["file_uploader_train_key_training"],
                st.session_state["uploaded_files_cache_processing"],
            )

        with col2:
            utils.display_file_uploader(
                "uploaded_val_files",
                "Upload Validation Images and Labels",
                st.session_state["file_uploader_val_key_training"],
                st.session_state["uploaded_files_cache_processing"],
            )

        with col3:
            utils.display_file_uploader(
                "uploaded_test_files",
                "Upload Test Images and Labels",
                st.session_state["file_uploader_test_key_training"],
                st.session_state["uploaded_files_cache_processing"],
            )

    # Train and validation sets are mandatory; the test set is optional
    pct_condition_check = (
        len(st.session_state["uploaded_train_files"]) > 0
        and len(st.session_state["uploaded_val_files"]) > 0
    )

    if not pct_condition_check:
        file_uploader_container.warning(
            "The train and validation set should not be empty.",
            icon="⚠️",
        )

    # 'Validate Input' button: check the labels of all three sets together and,
    # if they are valid, write the config file and the datasets to disk
    if validate_button_col.button("Validate Input", use_container_width=True):
        st.session_state["validation_triggered"] = True

        st.session_state["is_valid"] = model_training_functions.check_valid_labels(
            st.session_state["uploaded_train_files"]
            + st.session_state["uploaded_val_files"]
            + st.session_state["uploaded_test_files"],
            label_type,
            class_dict,
        )

        if st.session_state["is_valid"]:
            model_training_functions.create_yolo_config_file(
                model_training_functions.get_path("config"),
                class_labels,
            )
            # Start from empty data folders before saving the new files
            model_training_functions.clear_data_folders()
            model_training_functions.save_files_to_folder(
                st.session_state["uploaded_train_files"], "train"
            )
            model_training_functions.save_files_to_folder(
                st.session_state["uploaded_val_files"], "val"
            )

            # Test files are optional: save them only if some were uploaded.
            # This stays inside the valid branch so that files which failed
            # validation are never written, and nothing is written without
            # the data folders having been cleared first.
            if len(st.session_state["uploaded_test_files"]) > 0:
                model_training_functions.save_files_to_folder(
                    st.session_state["uploaded_test_files"], "test"
                )

    # Start training only after a successful validation with the mandatory
    # sets present
    if st.session_state["validation_triggered"] and (
        pct_condition_check and st.session_state["is_valid"]
    ):
        model_training_functions.start_yolo_training(selected_training, class_labels)
    else:
        # Validation has not run, has failed, or a mandatory set is missing
        st.warning(
            "Please upload valid input, select valid parameters, and click **Validate Input**.",
            icon="⚠️",
        )