# ---------------------------------------------------------------------------
# Page setup: configuration, dependencies, reference data, session defaults
# ---------------------------------------------------------------------------
import streamlit as st

# The page configuration is issued before any other Streamlit call on the page.
st.set_page_config(
    page_title="Image Augmentation",
    page_icon=":open_file_folder:",
    layout="wide",
    initial_sidebar_state="collapsed",
)

# Remaining libraries are imported only after the page has been configured.
import cv2
import utils
import numpy as np
import Functions.image_augmentation_functions as augmentation_functions

# Augmentation catalogue: per-technique parameter definitions and descriptive
# details, read from two sheets of the same Excel workbook.
augmentation_params_df = utils.load_data_from_excel(
    "packages_db.xlsx", "augmentation_parameters"
)
augmentation_details_df = utils.load_data_from_excel(
    "packages_db.xlsx", "augmentation_details"
)

st.title("Image Augmentation")

# NOTE: clearing session state on first load is currently disabled:
# utils.clear_session_state_on_first_load("image_augmentation_clear")

# Widget-key entries. Each one is stored as a one-entry dict {name: flag}.
# The Reset handler later on this page flips the flag, so every widget keyed on
# one of these entries receives a new key (presumably forcing Streamlit to
# rebuild it from scratch).
session_state_keys = [
    "file_uploader_key_augmentation",
    "select_processing_technique_key_augmentation",
    "selected_option_key_augmentation",
    "class_labels_input_key_augmentation",
    "bbox1_key",
    "bbox2_key",
    "bbox3_key",
    "bbox4_key",
    "bbox5_key",
]
for state_key in session_state_keys:
    if state_key not in st.session_state:
        st.session_state[state_key] = {state_key: True}

# Page-level flags and the cached ZIP payload; created only when absent.
for state_name, initial_value in (
    ("validation_triggered", False),
    ("uploaded_files_cache_augmentation", False),
    ("zip_data_augmentation", ""),
):
    if state_name not in st.session_state:
        st.session_state[state_name] = initial_value

# Uploader for the image and label files. The toggled key entry and the cache
# flag are handed to the helper (presumably used as the widget key and upload
# cache switch respectively - see utils.display_file_uploader).
utils.display_file_uploader(
    "uploaded_files",
    "Choose images and labels...",
    st.session_state["file_uploader_key_augmentation"],
    st.session_state["uploaded_files_cache_augmentation"],
)
# ---------------------------------------------------------------------------
# Label configuration
# ---------------------------------------------------------------------------

# Dropdown for selecting the label type; changing it resets the validation trigger.
label_type = st.selectbox(
    "Choose the label type for your augmentation process:",
    ["Masks", "Bboxes"],
    index=1,
    on_change=utils.reset_validation_trigger,
    key=st.session_state["selected_option_key_augmentation"],
)

# Extra label parameters: bounding boxes need them, masks do not.
if label_type == "Bboxes":
    label_input_parameters = augmentation_functions.bbox_params()
else:  # "Masks" is the only other option offered above
    label_input_parameters = None

# Text area for the user's class labels, pre-filled with example values.
class_labels_input = st.text_area(
    "Enter class labels, separated by commas:",
    utils.sample_class_labels,
    on_change=utils.reset_validation_trigger,
    key=st.session_state["class_labels_input_key_augmentation"],
).strip()  # Remove unnecessary whitespace from the start and end

# Map class IDs to labels. IDs start at 1 because 0 is reserved for the background.
class_labels = [
    label.strip() for label in class_labels_input.split(",") if label.strip()
]
class_dict = {class_id: label for class_id, label in enumerate(class_labels, start=1)}
# Inverse mapping: class name -> class ID
class_names_to_ids = {label: class_id for class_id, label in class_dict.items()}

# Parsing above works on a plain string and cannot raise; only color generation
# can fail. The fallback also binds `colors`, which the label overlay later on
# this page reads, so that path never hits an undefined name.
try:
    colors = augmentation_functions.generate_unique_colors(class_dict.keys())
except Exception:
    st.warning(
        "Invalid format for class labels. Please enter labels separated by commas.",
        icon="⚠️",
    )
    # Fall back to empty mappings so downstream code sees "no classes"
    class_dict, class_names_to_ids = {}, {}
    colors = {}  # NOTE(review): assumes consumers accept an empty mapping - confirm

# Note to users
st.markdown(
    """
Note to Users:
""",
    unsafe_allow_html=True,
)

# ---------------------------------------------------------------------------
# Input validation and reset
# ---------------------------------------------------------------------------

# Validation results live in session state; start them as None on first run.
session_vars = [
    "is_valid",
    "image_files",
    "label_files",
    "first_image_file",
    "first_label_file",
]
for var in session_vars:
    if var not in st.session_state:
        st.session_state[var] = None

col1, col2 = st.columns(2)

# The button is always rendered; validation runs at most once per trigger cycle.
validate_clicked = col1.button("Validate Input", use_container_width=True)
if validate_clicked and not st.session_state["validation_triggered"]:
    st.session_state["validation_triggered"] = True
    st.session_state["uploaded_files_cache_augmentation"] = True
    (
        st.session_state["is_valid"],
        st.session_state["image_files"],
        st.session_state["label_files"],
        st.session_state["first_image_file"],
        st.session_state["first_label_file"],
    ) = augmentation_functions.check_valid_labels(
        st.session_state["uploaded_files"], label_type, class_dict
    )
elif not st.session_state["validation_triggered"]:
    # Not validated yet (and not clicked this run): block the rest of the page
    st.session_state["is_valid"] = False
    st.warning(
        "Please upload images and labels and click **Validate Input**.", icon="⚠️"
    )
# Otherwise a previous validation is still in effect; keep its results.

with col2:
    if st.button("Reset", use_container_width=True):
        # Flip the flag stored inside every widget-key entry so those widgets get new keys
        current_value = st.session_state["file_uploader_key_augmentation"][
            "file_uploader_key_augmentation"
        ]
        updated_value = not current_value
        for session_state_key in session_state_keys:
            st.session_state[session_state_key] = {session_state_key: updated_value}

        # Drop every other session-state entry (everything not in session_state_keys)
        for key in list(st.session_state.keys()):
            if key not in session_state_keys:
                del st.session_state[key]

        # Delete all public module globals except `st`.
        # NOTE(review): st.rerun() below re-executes the script, so this is
        # presumably redundant - confirm before removing.
        global_vars = list(globals().keys())
        vars_to_delete = [
            var for var in global_vars if not var.startswith("_") and var != "st"
        ]
        for var in vars_to_delete:
            del globals()[var]

        # Clear the Streamlit caches and restart the script run
        st.cache_resource.clear()
        st.cache_data.clear()
        st.rerun()

# Techniques applicable to the selected label type
available_augmentations = augmentation_functions.get_applicable_techniques(
    augmentation_details_df, label_type
)

# Mapping of each technique to its corresponding image input type
input_mapping_dict = utils.technique_image_input_mapping(
    available_augmentations, augmentation_details_df
)

# Technique selection is only offered once the uploaded files have been validated
if st.session_state["is_valid"]:
    selected_augmentations = st.multiselect(
        "Select augmentation technique(s)",
        available_augmentations,
        key=st.session_state["select_processing_technique_key_augmentation"],
    )

    # Decode the first uploaded image into an RGB NumPy array for previews.
    # NOTE(review): cv2.imdecode returns None for undecodable bytes; this
    # presumably relies on check_valid_labels having rejected such files.
    st.session_state["first_image_file"].seek(0)  # Reset file pointer to start
    file_bytes_first_image = np.frombuffer(
        st.session_state["first_image_file"].read(), dtype=np.uint8
    )
    uploaded_first_image = cv2.cvtColor(
        cv2.imdecode(file_bytes_first_image, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB
    )
else:
    # No techniques can be chosen until validation succeeds
    selected_augmentations = []

#######################################################################################################
# Build custom augmentation pipeline
#######################################################################################################

# Parameters chosen for each selected technique, keyed by technique name
augmentations_params = {}

# Set when any technique's parameters fail the preview test
error_in_parameters = False

for augmentation in selected_augmentations:
    with st.expander(f"{augmentation}"):
        # Catalogue row(s) describing this technique
        augmentation_info = augmentation_details_df[
            augmentation_details_df["Name"] == augmentation
        ]
        has_info = not augmentation_info.empty

        # Details on the left, before/after preview on the right
        details_col, image_col = st.columns([7, 3])

        with details_col:
            augmentation_description = (
                augmentation_info["Description"].iloc[0]
                if has_info
                else "No description available."
            )
            # A single-quoted f-string cannot span physical lines (SyntaxError),
            # so the surrounding line breaks are written as escapes.
            st.markdown(
                f"\nDescription: {augmentation_description}\n",
                unsafe_allow_html=True,
            )

            augmentation_category = (
                augmentation_info["Category"].iloc[0] if has_info else "Unknown"
            )
            st.write("Category:", augmentation_category)

            augmentation_source_code = (
                augmentation_info["Source Code Link"].iloc[0]
                if has_info
                else "www.google.com"
            )

            # Source-code link next to the "customize" toggle
            source_code_col, custom_setting_col = st.columns(2)
            source_code_col.link_button("Source Code", augmentation_source_code)
            custom_settings = custom_setting_col.checkbox(
                f"Customize {augmentation}", key=f"toggle_{augmentation}"
            )

        with image_col:
            # Fixed-height slots for the before/after preview images
            original_col, processed_col = st.columns(2)
            original_image_placeholder = original_col.container(height=150, border=False)
            processed_image_placeholder = processed_col.container(height=150, border=False)

        if custom_settings:
            # User-tuned parameters built from this technique's parameter rows
            params_df = augmentation_params_df[
                augmentation_params_df["Name"] == augmentation
            ]
            augmentations_params[augmentation] = utils.process_image_parameters(
                params_df, augmentation
            )
        else:
            augmentations_params[augmentation] = utils.get_default_params(augmentation)

        # Test the parameters on the preview image, which already carries the
        # output of the previously selected techniques.
        (
            error_flag,
            processed_first_image,
        ) = augmentation_functions.apply_and_test_augmentation(
            augmentation,
            augmentations_params[augmentation],
            uploaded_first_image,
            st.session_state["first_label_file"],
            label_type,
            label_input_parameters,
            input_mapping_dict[augmentation],
        )

        if error_flag:
            error_in_parameters = True
        else:
            with original_image_placeholder:
                st.image(
                    uploaded_first_image,
                    caption="Original Image",
                    use_column_width=True,
                    clamp=True,
                )
            with processed_image_placeholder:
                st.image(
                    processed_first_image,
                    caption="Processed Image",
                    use_column_width=True,
                    clamp=True,
                )
            # Chain: the next technique previews on top of this output
            uploaded_first_image = processed_first_image

#######################################################################################################
# Display selected augmentation technique parameters as DataFrame
#######################################################################################################

# Offer a per-technique filter only when there is an error-free configuration
if augmentations_params and not error_in_parameters:
    selected_augmentation = st.selectbox(
        "Select augmentation technique",
        options=["All"] + list(augmentations_params.keys()),
    )
else:
    selected_augmentation = None

augmentations_df = augmentation_functions.create_augmentations_dataframe(
    augmentations_params, augmentation_params_df
)
# Values are mixed types; cast to str for consistent display/serialization
augmentations_df["Value"] = augmentations_df["Value"].astype(str)

if selected_augmentation != "All":
    filtered_augmentations_df = augmentations_df[
        augmentations_df["augmentation"] == selected_augmentation
    ]
else:
    filtered_augmentations_df = augmentations_df

if not filtered_augmentations_df.empty and not error_in_parameters:
    # Shown at its natural width rather than stretched to the container
    st.dataframe(filtered_augmentations_df, use_container_width=False)

# Slot reserved for the generated code and its description
code_placeholder = st.empty()
#######################################################################################################
# Process images and download processed images
#######################################################################################################

# Proceed only if inputs are valid, techniques are selected, and the configuration has no errors
if (
    st.session_state["is_valid"]
    and selected_augmentations
    and not error_in_parameters
):
    col1, col2 = st.columns(2)

    # How many augmented variants to generate per input image
    num_variations = col1.number_input(
        "Set the number of variations to be generated",
        min_value=1,
        max_value=3,
        step=1,
    )

    with col2:
        # Top padding: two blank writes push the checkbox down in its column
        for _ in range(2):
            st.write("")
        include_original = st.checkbox(
            "Include original images and labels in output", value=False
        )

    # NOTE(review): code_placeholder is an st.empty() slot, which holds a single
    # element; if display_code_and_download_button renders several elements,
    # only the last may stay visible - confirm.
    with code_placeholder:
        if len(st.session_state["label_files"]) == 0:
            # No label files uploaded: image-only pipeline
            generated_code = utils.generate_python_code_images(
                augmentations_params,
                num_variations,
                include_original,
            )
        elif label_type == "Bboxes":
            generated_code = augmentation_functions.generate_python_code_bboxes(
                augmentations_params,
                label_input_parameters,
                num_variations,
                include_original,
            )
        else:  # label_type == "Masks", the only other selectable option
            generated_code = augmentation_functions.generate_python_code_masks(
                augmentations_params,
                label_input_parameters,
                num_variations,
                include_original,
            )

        augmentation_functions.display_code_and_download_button(generated_code)

    col1, col2 = st.columns(2)

    # Run the full augmentation over every uploaded image/label pair
    if col1.button("Accept and Process", use_container_width=True):
        augmentation_functions.process_images_and_labels(
            st.session_state["image_files"],
            st.session_state["label_files"],
            selected_augmentations,
            augmentations_params,
            label_type,
            label_input_parameters,
            num_variations,
            include_original,
            class_dict,
        )

    # The payload starts as "" (nothing processed yet); keep the button
    # disabled until an archive exists instead of serving an empty "zip".
    zip_data = st.session_state["zip_data_augmentation"]
    col2.download_button(
        label="Download",
        data=zip_data,
        file_name="augmented_images.zip",
        mime="application/zip",
        use_container_width=True,
        disabled=not zip_data,
    )
else:
    if not selected_augmentations and st.session_state["is_valid"]:
        st.warning("Please select at least one augmentation technique.", icon="⚠️")

    if error_in_parameters and st.session_state["is_valid"]:
        st.warning(
            "There are errors in the augmentation parameters. Please review your selections.",
            icon="⚠️",
        )

#######################################################################################################
# Display original and processed images
#######################################################################################################

# The gallery needs the repository, the variant mapping and at least one
# original image name; otherwise indexing below would fail.
if (
    "image_repository_augmentation" in st.session_state
    and "processed_image_mapping_augmentation" in st.session_state
    and st.session_state.get("unique_images_names")
):
    unique_image_names = st.session_state["unique_images_names"]
    num_unique_images = len(unique_image_names)

    # 1-based image picker; a slider is only useful with more than one image
    if num_unique_images > 1:
        selected_image_index = st.slider(
            "Select an Image",
            min_value=1,
            max_value=num_unique_images,
            step=1,
        )
    else:
        selected_image_index = 1

    selected_original_image_name = unique_image_names[selected_image_index - 1]

    # Original first, then all of its processed variants
    processed_variant_names = st.session_state[
        "processed_image_mapping_augmentation"
    ].get(selected_original_image_name, [])
    all_image_names = [selected_original_image_name, *processed_variant_names]

    # Default: no overlay (also used when no label files were uploaded)
    selected_label_display_option = "No Label"
    labels_to_plot = []

    if len(st.session_state["label_files"]) > 0:
        label_display_options = ["No Label", "All Labels", "Specific Labels"]
        selected_label_display_option = st.selectbox(
            "Choose how to display labels:",
            label_display_options,
            index=0,  # Default option is 'No Label'
        )

        if selected_label_display_option == "All Labels":
            labels_to_plot = list(class_dict.keys())
        elif selected_label_display_option == "Specific Labels":
            selected_class_names = st.multiselect(
                "Select specific labels to display",
                list(class_names_to_ids.keys()),
                # Pre-select the first class when one exists; class_dict is
                # empty when no (valid) class labels were entered.
                class_dict.get(1),
            )
            labels_to_plot = [class_names_to_ids[name] for name in selected_class_names]

    # Images in a 4-wide grid, one row of columns per chunk of names
    num_columns = 4
    image_repository = st.session_state["image_repository_augmentation"]
    for row_start in range(0, len(all_image_names), num_columns):
        row_names = all_image_names[row_start : row_start + num_columns]
        for column, image_name in zip(st.columns(num_columns), row_names):
            entry = image_repository[image_name]
            image_data = entry["image"]

            if selected_label_display_option in ("All Labels", "Specific Labels"):
                # Draw on a copy so the stored image stays untouched
                modified_image = augmentation_functions.overlay_labels(
                    image=image_data.copy(),
                    labels_to_plot=labels_to_plot,
                    label_file=entry["label"],
                    label_type=label_type,
                    colors=colors,
                    class_dict=class_dict,
                )
            else:
                modified_image = image_data

            with column:
                st.image(
                    modified_image,
                    clamp=True,
                    caption=image_name,
                    use_column_width=True,
                )