# Streamlit page: configure, validate and launch model training.
#
# Flow of the page:
#   1. Pick the training type, the dataset split method and the class labels.
#   2. Upload images/labels (one pool that is split by percentage, or separate
#      train/val/test uploads).
#   3. "Validate Input" checks the labels and writes the data/config to disk.
#   4. Once validation has succeeded, training is started.
#
# Set the page config (kept ahead of every other Streamlit call in this script)
import streamlit as st

st.set_page_config(
    page_title="Model_Training",
    page_icon=":open_file_folder:",
    layout="wide",
    initial_sidebar_state="collapsed",
)

# Importing necessary libraries
import utils
import Functions.model_training_functions as model_training_functions


def _start_training_or_prompt(inputs_ok):
    """Start training when the inputs were validated, otherwise show a hint.

    Training starts only if the user pressed "Validate Input"
    (``validation_triggered``), the split/upload pre-check passed
    (``inputs_ok`` is truthy) and the label validation succeeded
    (``is_valid``). In every other case a warning asks the user to fix the
    input and validate again.

    Reads the module-level ``selected_training`` and ``class_labels`` at call
    time, so it must only be called after both have been assigned.
    """
    ready = (
        st.session_state["validation_triggered"]
        and inputs_ok
        and st.session_state["is_valid"]
    )
    if ready:
        model_training_functions.start_yolo_training(selected_training, class_labels)
    else:
        # Validation has not run, failed, or the pre-check conditions are not met
        st.warning(
            "Please upload valid input, select valid parameters, and click **Validate Input**.",
            icon="⚠️",
        )


# Display the page title
st.title("Model Training")

# # Clear the Streamlit session state on the first load of the page
# utils.clear_session_state_on_first_load("model_training_clear")

# Session state entries that hold the *widget keys* used on this page.
# Each entry stores a one-item dict ``{name: bool}`` which is passed as the
# widget's ``key``; the Reset button flips the bool so the widget receives a
# different key on the next run and therefore comes back in its default state.
session_state_keys = [
    "file_uploader_split_key_training",
    "file_uploader_train_key_training",
    "file_uploader_val_key_training",
    "file_uploader_test_key_training",
    "number_input_train_key",
    "number_input_val_key",
    "number_input_test_key",
    "split_method_key",
    "training_type_key",
    "class_labels_input_key_training",
]

# Initialize every widget-key entry that is not already present
for key in session_state_keys:
    if key not in st.session_state:
        st.session_state[key] = {key: True}

# Initialize the boolean flags used by this page if they are not present
for flag in ("validation_triggered", "uploaded_files_cache_processing", "is_valid"):
    if flag not in st.session_state:
        st.session_state[flag] = False

# Container for file uploaders (created first so it renders above the options)
file_uploader_container = st.container()

# Dictionary for mapping the user-friendly terms to technical label types
label_type_mapping = {"Object Detection": "Bboxes", "Instance Segmentation": "Masks"}

# Create two columns for widgets
column_select_training, column_split_method = st.columns(2)

# Dropdown for selecting the training type
with column_select_training:
    selected_training = st.selectbox(
        "Select the training type:",
        list(label_type_mapping.keys()),
        index=0,
        on_change=utils.reset_validation_trigger,
        key=st.session_state["training_type_key"],
    )

# Getting the corresponding label type
label_type = label_type_mapping[selected_training]

# Toggle for choosing the split method
with column_split_method:
    split_method = st.radio(
        "Select the dataset split method:",
        ["Percentage Split", "Direct Upload"],
        horizontal=True,
        on_change=utils.reset_validation_trigger,
        key=st.session_state["split_method_key"],
    )

# Text area for user to input class labels (pre-filled with example defaults)
class_labels_input = st.text_area(
    "Enter class labels, separated by commas:",
    utils.sample_class_labels,
    on_change=utils.reset_validation_trigger,
    key=st.session_state["class_labels_input_key_training"],
)

# Remove unnecessary whitespace from the start and end
class_labels_input = class_labels_input.strip()

# Generating a dictionary mapping class IDs to their respective labels
try:
    class_labels = [
        label.strip() for label in class_labels_input.split(",") if label.strip()
    ]
    class_dict = {i: label for i, label in enumerate(class_labels)}

    # Invert the class_dict to map class names to class IDs
    class_names_to_ids = {v: k for k, v in class_dict.items()}
except Exception:
    st.warning(
        "Invalid format for class labels. Please enter labels separated by commas.",
        icon="⚠️",
    )
    # Fall back to empty values for *all three* names: class_labels is used
    # further down (config file creation and training start), so leaving it
    # undefined here would raise a NameError later in the script.
    class_labels = []
    class_dict, class_names_to_ids = {}, {}

# Note to users
st.markdown(
    """
    <div style='text-align: justify;'>
    <b>Note to Users:</b>
    <ul>
    <li>When moving to another page or if you wish to upload a new set of images and labels, don't forget to hit the <b>Reset</b> button. This helps in faster computation and frees up unused memory, ensuring smoother operation.</li>
    <li>Select the training type, class labels, dataset split method and its parameters before uploading large data for faster computation and more efficient processing.</li>
    </ul>
    </div>
    """,
    unsafe_allow_html=True,
)

# Create two columns for the "Validate Input" and "Reset" buttons
validate_button_col, reset_button_col = st.columns(2)

with reset_button_col:
    # Check if the 'Reset' button is pressed
    if st.button("Reset", use_container_width=True):
        # Clear folders
        model_training_functions.delete_and_recreate_folder(
            model_training_functions.get_path("output")
        )
        model_training_functions.clear_data_folders()

        # Flip the stored flag of every widget key so that each widget is
        # created with a different key on the next run
        for session_state_key in session_state_keys:
            current_value = st.session_state[session_state_key][session_state_key]
            st.session_state[session_state_key] = {
                session_state_key: not current_value
            }

        # Clear all other session state entries (uploaded files, flags, the
        # state of the previous widgets), keeping only the widget-key entries
        for key in list(st.session_state.keys()):
            if key not in session_state_keys:
                del st.session_state[key]

        # Clear global variables except for protected names and the Streamlit
        # module. NOTE(review): presumably meant to release references to
        # large objects before the rerun - confirm this is still needed.
        global_vars = list(globals().keys())
        vars_to_delete = [
            var for var in global_vars if not var.startswith("_") and var != "st"
        ]
        for var in vars_to_delete:
            del globals()[var]

        # Clear the Streamlit caches
        st.cache_resource.clear()
        st.cache_data.clear()

        # Rerun the app to reflect the reset state
        st.rerun()

# Code for "Percentage Split" method
if split_method == "Percentage Split":
    with file_uploader_container:
        # User uploads images and labels
        utils.display_file_uploader(
            "uploaded_files",
            "Choose images and labels...",
            st.session_state["file_uploader_split_key_training"],
            st.session_state["uploaded_files_cache_processing"],
        )

        # Create three columns for input percentages
        col1, col2, col3 = st.columns(3)

        # User specifies split percentages (each widget uses its own key)
        train_pct = col1.number_input(
            "Train Set Percentage",
            0,
            100,
            70,
            1,
            on_change=utils.reset_validation_trigger,
            key=st.session_state["number_input_train_key"],
        )
        test_pct = col2.number_input(
            "Test Set Percentage",
            0,
            100,
            15,
            1,
            on_change=utils.reset_validation_trigger,
            key=st.session_state["number_input_test_key"],
        )
        val_pct = col3.number_input(
            "Validation Set Percentage",
            0,
            100,
            15,
            1,
            on_change=utils.reset_validation_trigger,
            key=st.session_state["number_input_val_key"],
        )

        # Check if the total percentage equals 100%
        pct_check = train_pct + test_pct + val_pct

        # Validating the input percentages; the image-count check is evaluated
        # last so it only runs when the percentages themselves are acceptable
        pct_condition_check = (
            pct_check == 100
            and train_pct > 0
            and val_pct > 0
            and model_training_functions.check_min_images(
                len(st.session_state["uploaded_files"]), train_pct, val_pct, test_pct
            )
        )

        if not pct_condition_check:
            file_uploader_container.warning(
                "The percentages for train, test, and validation sets should add up to 100%, and train and validation set should not be empty.",
                icon="⚠️",
            )

    # Button to trigger validation
    if validate_button_col.button("Validate Input", use_container_width=True):
        st.session_state["validation_triggered"] = True
        st.session_state["is_valid"] = model_training_functions.check_valid_labels(
            st.session_state["uploaded_files"], label_type, class_dict
        )

        if st.session_state["is_valid"]:
            model_training_functions.create_yolo_config_file(
                model_training_functions.get_path("config"),
                class_labels,
            )
            model_training_functions.clear_data_folders()
            paired_files = model_training_functions.pair_files(
                st.session_state["uploaded_files"]
            )
            # NOTE(review): test_pct is passed here while check_min_images
            # above receives (train, val, test) - verify against the
            # signature of split_and_save_files that this is the intended
            # argument.
            model_training_functions.split_and_save_files(
                paired_files, train_pct, test_pct
            )

    # Process files if input is valid
    _start_training_or_prompt(pct_condition_check)

# Code for "Direct Upload" method
elif split_method == "Direct Upload":
    with file_uploader_container:
        # Create three columns for uploading train, val, and test files
        col1, col2, col3 = st.columns(3)

        with col1:
            utils.display_file_uploader(
                "uploaded_train_files",
                "Upload Training Images and Labels",
                st.session_state["file_uploader_train_key_training"],
                st.session_state["uploaded_files_cache_processing"],
            )

        with col2:
            utils.display_file_uploader(
                "uploaded_val_files",
                "Upload Validation Images and Labels",
                st.session_state["file_uploader_val_key_training"],
                st.session_state["uploaded_files_cache_processing"],
            )

        with col3:
            utils.display_file_uploader(
                "uploaded_test_files",
                "Upload Test Images and Labels",
                st.session_state["file_uploader_test_key_training"],
                st.session_state["uploaded_files_cache_processing"],
            )

    # Check for valid input: train and validation uploads are mandatory,
    # the test upload is optional
    pct_condition_check = (
        len(st.session_state["uploaded_train_files"]) > 0
        and len(st.session_state["uploaded_val_files"]) > 0
    )

    if not pct_condition_check:
        file_uploader_container.warning(
            "The train and validation set should not be empty.",
            icon="⚠️",
        )

    # Button to trigger validation
    if validate_button_col.button("Validate Input", use_container_width=True):
        st.session_state["validation_triggered"] = True
        st.session_state["is_valid"] = model_training_functions.check_valid_labels(
            st.session_state["uploaded_train_files"]
            + st.session_state["uploaded_val_files"]
            + st.session_state["uploaded_test_files"],
            label_type,
            class_dict,
        )

        if st.session_state["is_valid"]:
            model_training_functions.create_yolo_config_file(
                model_training_functions.get_path("config"),
                class_labels,
            )
            model_training_functions.clear_data_folders()
            model_training_functions.save_files_to_folder(
                st.session_state["uploaded_train_files"], "train"
            )
            model_training_functions.save_files_to_folder(
                st.session_state["uploaded_val_files"], "val"
            )

            # Only save test files if they are uploaded
            if len(st.session_state["uploaded_test_files"]) > 0:
                model_training_functions.save_files_to_folder(
                    st.session_state["uploaded_test_files"], "test"
                )

    # Process files if input is valid
    _start_training_or_prompt(pct_condition_check)