Spaces:
Running
Running
# Set the page config | |
import streamlit as st | |
# Browser-tab title/icon, wide layout and a collapsed sidebar. This is issued
# before any other Streamlit call made by the page.
page_settings = dict(
    layout="wide",
    initial_sidebar_state="collapsed",
    page_title="Image Augmentation",
    page_icon=":open_file_folder:",
)
st.set_page_config(**page_settings)
# Importing necessary libraries | |
import cv2 | |
import utils | |
import numpy as np | |
import Functions.image_augmentation_functions as augmentation_functions | |
# Both metadata tables come from the same workbook. One sheet holds the tunable
# parameters of each technique (filtered by "Name" later on). The other holds
# per-technique details such as description, category and source-code link.
augmentation_params_df, augmentation_details_df = (
    utils.load_data_from_excel("packages_db.xlsx", sheet_name)
    for sheet_name in ("augmentation_parameters", "augmentation_details")
)
st.title("Image Augmentation")

# Session-state entries whose values are passed as widget keys on this page.
# Each value is a one-entry dict {entry_name: bool}. The Reset button flips the
# bool, which changes the key the widget is created with and so resets it.
# (The bbox*_key entries are presumably consumed inside the bbox parameter
# widgets; they are not read directly in this script.)
session_state_keys = [
    "file_uploader_key_augmentation",
    "select_processing_technique_key_augmentation",
    "selected_option_key_augmentation",
    "class_labels_input_key_augmentation",
    "bbox1_key",
    "bbox2_key",
    "bbox3_key",
    "bbox4_key",
    "bbox5_key",
]

# Initial values for every entry this page relies on. An entry is only seeded
# when it is missing, so values survive Streamlit reruns.
state_defaults = {state_key: {state_key: True} for state_key in session_state_keys}
state_defaults.update(
    validation_triggered=False,
    uploaded_files_cache_augmentation=False,
    zip_data_augmentation="",
)
for state_key, default_value in state_defaults.items():
    if state_key not in st.session_state:
        st.session_state[state_key] = default_value
# Upload widget for the image and label files; the widget's key and cache flag
# come from session state so Reset can recreate it.
utils.display_file_uploader(
    "uploaded_files",
    "Choose images and labels...",
    st.session_state["file_uploader_key_augmentation"],
    st.session_state["uploaded_files_cache_augmentation"],
)

# Label format of the uploaded annotations; "Bboxes" is preselected. Changing it
# invalidates any earlier validation.
label_type_options = ["Masks", "Bboxes"]
label_type = st.selectbox(
    "Choose the label type for your augmentation process:",
    label_type_options,
    index=label_type_options.index("Bboxes"),
    on_change=utils.reset_validation_trigger,
    key=st.session_state["selected_option_key_augmentation"],
)

# Bounding boxes need extra format parameters (collected by bbox_params);
# masks need none.
label_input_parameters = (
    augmentation_functions.bbox_params() if label_type == "Bboxes" else None
)
# Text area for the dataset's class names, comma-separated. The pre-filled value
# (utils.sample_class_labels) is an example; editing it invalidates validation.
class_labels_input = st.text_area(
    "Enter class labels, separated by commas:",
    utils.sample_class_labels,
    on_change=utils.reset_validation_trigger,
    key=st.session_state["class_labels_input_key_augmentation"],
)
# Remove unnecessary whitespace from the start and end of the input
class_labels_input = class_labels_input.strip()

# Map class IDs to labels. IDs start at 1 because 0 is reserved for the
# background. Blank entries (e.g. from "a,,b" or a trailing comma) are skipped.
class_labels = [
    label.strip() for label in class_labels_input.split(",") if label.strip()
]
class_dict = {class_id: label for class_id, label in enumerate(class_labels, start=1)}
# Inverse mapping (class name -> class ID), used by the label-display filter
class_names_to_ids = {name: class_id for class_id, name in class_dict.items()}

# Only the colour generation can fail here; string splitting cannot raise.
try:
    colors = augmentation_functions.generate_unique_colors(class_dict.keys())
except Exception:
    st.warning(
        "Invalid format for class labels. Please enter labels separated by commas.",
        icon="⚠️",
    )
    # Fall back to "no classes". `colors` must still be bound because the
    # label-overlay code passes it to overlay_labels. NOTE(review): an empty
    # dict assumes overlay_labels only indexes colors for classes it plots
    # (none here) — confirm against generate_unique_colors' return type.
    class_dict, class_names_to_ids = {}, {}
    colors = {}
# Usage notes for the user, rendered as raw HTML.
user_notes_html = """
<div style='text-align: justify;'>
<b>Note to Users:</b>
<ul>
<li>The <i>first uploaded image</i> will be used for demonstration purposes and to validate parameters for augmentation techniques.</li>
<li>Uploading <i>labels is optional</i>. If no labels are uploaded, the output will consist solely of processed images.</li>
<li>When moving to another page or if you wish to upload a new set of images and labels, don't forget to hit the <b>Reset</b> button. This helps in faster computation and frees up unused memory, ensuring smoother operation.</li>
<li>Select the class labels, label type and label parameters before uploading large data for faster computation and more efficient processing.</li>
</ul>
</div>
"""
st.markdown(user_notes_html, unsafe_allow_html=True)
# Outputs of input validation, kept across reruns. They stay None until the user
# clicks "Validate Input", which overwrites all five.
session_vars = [
    "is_valid",
    "image_files",
    "label_files",
    "first_image_file",
    "first_label_file",
]
missing_vars = [state_name for state_name in session_vars if state_name not in st.session_state]
for missing_var in missing_vars:
    st.session_state[missing_var] = None
# Control row: "Validate Input" on the left, "Reset" on the right.
col1, col2 = st.columns(2)

validate_clicked = col1.button("Validate Input", use_container_width=True)
already_validated = st.session_state["validation_triggered"]

if validate_clicked and not already_validated:
    # Validate once. The flag stays set until something clears it
    # (utils.reset_validation_trigger is wired to the inputs above).
    st.session_state["validation_triggered"] = True
    st.session_state["uploaded_files_cache_augmentation"] = True
    validation_outcome = augmentation_functions.check_valid_labels(
        st.session_state["uploaded_files"], label_type, class_dict
    )
    (
        st.session_state["is_valid"],
        st.session_state["image_files"],
        st.session_state["label_files"],
        st.session_state["first_image_file"],
        st.session_state["first_label_file"],
    ) = validation_outcome
elif not already_validated:
    # Not validated yet and the button was not pressed: keep everything locked.
    st.session_state["is_valid"] = False
    st.warning(
        "Please upload images and labels and click **Validate Input**.", icon="⚠️"
    )
# If validation already ran, the stored results are reused as-is.
with col2:
    # Reset discards every selection, upload and cached result, then reruns.
    if st.button("Reset", use_container_width=True):
        # Flip the flag stored under each widget-key entry so every keyed widget
        # is created with a new key (and therefore a fresh state) on the rerun.
        flipped_flag = not st.session_state["file_uploader_key_augmentation"][
            "file_uploader_key_augmentation"
        ]
        for widget_key in session_state_keys:
            st.session_state[widget_key] = {widget_key: flipped_flag}

        # Remove every other session-state entry; only the widget-key entries
        # listed in session_state_keys survive.
        stale_state_keys = [
            state_key
            for state_key in st.session_state.keys()
            if state_key not in session_state_keys
        ]
        for stale_key in stale_state_keys:
            del st.session_state[stale_key]

        # Drop this script's public module globals (everything except
        # underscore-prefixed names and the `st` module itself).
        doomed_globals = [
            global_name
            for global_name in list(globals().keys())
            if not global_name.startswith("_") and global_name != "st"
        ]
        for doomed in doomed_globals:
            del globals()[doomed]

        # Empty Streamlit's caches and start a fresh run of the page.
        st.cache_resource.clear()
        st.cache_data.clear()
        st.rerun()
# Techniques applicable to the chosen label type. For each technique, also record
# the image input type it expects (derived from the augmentation details sheet).
available_augmentations = augmentation_functions.get_applicable_techniques(
    augmentation_details_df, label_type
)
input_mapping_dict = utils.technique_image_input_mapping(
    available_augmentations, augmentation_details_df
)

if not st.session_state["is_valid"]:
    # Nothing can be configured until the upload has been validated.
    selected_augmentations = []
else:
    selected_augmentations = st.multiselect(
        "Select augmentation technique(s)",
        available_augmentations,
        key=st.session_state["select_processing_technique_key_augmentation"],
    )
    # Decode the first uploaded image into an RGB array used for the previews.
    # The file is rewound first because it may already have been read.
    first_upload = st.session_state["first_image_file"]
    first_upload.seek(0)
    encoded_pixels = np.frombuffer(first_upload.read(), dtype=np.uint8)
    bgr_image = cv2.imdecode(encoded_pixels, cv2.IMREAD_COLOR)
    uploaded_first_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
#######################################################################################################
# Build custom augmentation pipeline
#######################################################################################################


def _first_or_default(info_df, column, fallback):
    """Return the first value of ``column`` in ``info_df``, or ``fallback`` if it has no rows."""
    if info_df.empty:
        return fallback
    return info_df[column].iloc[0]


# Chosen parameters for each selected technique, keyed by technique name.
augmentations_params = {}
# Set to True as soon as any technique fails its trial run on the sample image.
error_in_parameters = False

# Techniques are chained. Each one is previewed on the output of the previous
# successful one, so the thumbnails show the cumulative effect of the pipeline.
for augmentation in selected_augmentations:
    with st.expander(f"{augmentation}"):
        augmentation_info = augmentation_details_df[
            augmentation_details_df["Name"] == augmentation
        ]
        details_col, image_col = st.columns([7, 3])

        with details_col:
            augmentation_description = _first_or_default(
                augmentation_info, "Description", "No description available."
            )
            st.markdown(
                f"<div style='text-align: justify;'><b>Description:</b> {augmentation_description}</div>",
                unsafe_allow_html=True,
            )
            st.write(
                "Category:",
                _first_or_default(augmentation_info, "Category", "Unknown"),
            )
            source_link = _first_or_default(
                augmentation_info, "Source Code Link", "www.google.com"
            )
            link_col, customize_col = st.columns(2)
            link_col.link_button("Source Code", source_link)
            custom_settings = customize_col.checkbox(
                f"Customize {augmentation}", key=f"toggle_{augmentation}"
            )

        with image_col:
            # Reserve fixed-height slots for the before/after thumbnails.
            before_col, after_col = st.columns(2)
            original_image_placeholder = before_col.container(height=150, border=False)
            processed_image_placeholder = after_col.container(height=150, border=False)

        if not custom_settings:
            augmentations_params[augmentation] = utils.get_default_params(augmentation)
        else:
            technique_params_df = augmentation_params_df[
                augmentation_params_df["Name"] == augmentation
            ]
            augmentations_params[augmentation] = utils.process_image_parameters(
                technique_params_df, augmentation
            )

        # Trial-run the technique on the current preview image to catch bad parameters.
        error_flag, processed_first_image = (
            augmentation_functions.apply_and_test_augmentation(
                augmentation,
                augmentations_params[augmentation],
                uploaded_first_image,
                st.session_state["first_label_file"],
                label_type,
                label_input_parameters,
                input_mapping_dict[augmentation],
            )
        )

        if error_flag:
            error_in_parameters = True
        else:
            with original_image_placeholder:
                st.image(
                    uploaded_first_image,
                    caption="Original Image",
                    use_column_width=True,
                    clamp=True,
                )
            with processed_image_placeholder:
                st.image(
                    processed_first_image,
                    caption="Processed Image",
                    use_column_width=True,
                    clamp=True,
                )
            # The next technique in the pipeline previews on this output.
            uploaded_first_image = processed_first_image
#######################################################################################################
# Display selected augmentation technique parameters as DataFrame
#######################################################################################################
# The technique picker is offered only when something is configured and every
# technique passed its trial run.
parameters_ok = bool(augmentations_params) and not error_in_parameters
selected_augmentation = (
    st.selectbox(
        "Select augmentation technique",
        options=["All", *augmentations_params],
    )
    if parameters_ok
    else None
)

augmentations_df = augmentation_functions.create_augmentations_dataframe(
    augmentations_params, augmentation_params_df
)
# Stringify the Value column so mixed-type parameter values serialize consistently.
augmentations_df["Value"] = augmentations_df["Value"].astype(str)

if selected_augmentation == "All":
    filtered_augmentations_df = augmentations_df
else:
    filtered_augmentations_df = augmentations_df[
        augmentations_df["augmentation"] == selected_augmentation
    ]

# Render the table at its natural width (not stretched to the container).
if not (filtered_augmentations_df.empty or error_in_parameters):
    st.dataframe(filtered_augmentations_df, use_container_width=False)

# Slot filled further down with the generated pipeline code.
code_placeholder = st.empty()
#######################################################################################################
# Process images and download processed images
#######################################################################################################
# Processing is offered only when the upload is validated, at least one technique
# is selected, and every technique passed its trial run on the sample image.
if (
    st.session_state["is_valid"]
    and (len(selected_augmentations) > 0)
    and not error_in_parameters
):
    col1, col2 = st.columns(2)
    # Number of augmented variants to generate per input image (1 to 3).
    num_variations = col1.number_input(
        "Set the number of variations to be generated",
        min_value=1,
        max_value=3,
        step=1,
    )
    with col2:
        # Two empty writes push the checkbox down so it lines up with the
        # number input in the neighbouring column.
        for _ in range(2):
            st.write("")
        include_original = st.checkbox(
            "Include original images and labels in output", value=False
        )

    # Fill the slot reserved above with the generated pipeline code and its
    # download button.
    with code_placeholder:
        if len(st.session_state["label_files"]) == 0:
            # No labels uploaded: the generated code handles images only.
            generated_code = utils.generate_python_code_images(
                augmentations_params,
                num_variations,
                include_original,
            )
        elif label_type == "Bboxes":  # Selected label type is Bboxes
            generated_code = augmentation_functions.generate_python_code_bboxes(
                augmentations_params,
                label_input_parameters,
                num_variations,
                include_original,
            )
        elif label_type == "Masks":  # Selected label type is Masks
            generated_code = augmentation_functions.generate_python_code_masks(
                augmentations_params,
                label_input_parameters,
                num_variations,
                include_original,
            )
        augmentation_functions.display_code_and_download_button(generated_code)

    col1, col2 = st.columns(2)
    if col1.button("Accept and Process", use_container_width=True):
        # Run the pipeline over every uploaded image. The result is not returned;
        # presumably it is stored in st.session_state["zip_data_augmentation"]
        # (the only source of the download payload below) — verify in
        # process_images_and_labels.
        augmentation_functions.process_images_and_labels(
            st.session_state["image_files"],
            st.session_state["label_files"],
            selected_augmentations,
            augmentations_params,
            label_type,
            label_input_parameters,
            num_variations,
            include_original,
            class_dict,
        )
    # Keep the button disabled while the payload is still the empty "" placeholder,
    # so users cannot download an empty "augmented_images.zip".
    col2.download_button(
        label="Download",
        data=st.session_state["zip_data_augmentation"],
        file_name="augmented_images.zip",
        mime="application/zip",
        use_container_width=True,
        disabled=not st.session_state["zip_data_augmentation"],
    )
else:
    if (len(selected_augmentations) == 0) and st.session_state["is_valid"]:
        # Inform the user that no augmentation techniques have been selected
        st.warning("Please select at least one augmentation technique.", icon="⚠️")
    if error_in_parameters and st.session_state["is_valid"]:
        # Inform the user that there are errors in parameters
        st.warning(
            "There are errors in the augmentation parameters. Please review your selections.",
            icon="⚠️",
        )
#######################################################################################################
# Display original and processed images
#######################################################################################################
# Shown only once processing has stored the image repository, the
# original -> variants mapping and at least one original image name. With an
# empty name list, index 1 - 1 - 1 style lookups would raise IndexError.
if (
    "image_repository_augmentation" in st.session_state
    and "processed_image_mapping_augmentation" in st.session_state
    and len(st.session_state["unique_images_names"]) > 0
):
    num_unique_images = len(st.session_state["unique_images_names"])
    # A slider is only needed when there is more than one original image.
    if num_unique_images > 1:
        selected_image_index = st.slider(
            "Select an Image",
            min_value=1,
            max_value=num_unique_images,
            step=1,
        )
    else:
        selected_image_index = 1
    # The slider is 1-based; the name list is 0-based.
    selected_original_image_name = st.session_state["unique_images_names"][
        selected_image_index - 1
    ]
    # The original image followed by all of its processed variants.
    processed_variant_names = st.session_state[
        "processed_image_mapping_augmentation"
    ].get(selected_original_image_name, [])
    all_image_names = [selected_original_image_name] + processed_variant_names

    if len(st.session_state["label_files"]) > 0:
        # How labels should be drawn on the displayed images.
        label_display_options = ["No Label", "All Labels", "Specific Labels"]
        selected_label_display_option = st.selectbox(
            "Choose how to display labels:",
            label_display_options,
            index=0,  # Default option is 'No Label'
        )
        if selected_label_display_option == "All Labels":
            labels_to_plot = list(class_dict.keys())
        elif selected_label_display_option == "Specific Labels":
            selected_class_names = st.multiselect(
                "Select specific labels to display",
                list(class_names_to_ids.keys()),
                # Preselect the first class when one exists. class_dict is empty
                # when no class labels were entered, and class_dict[1] would then
                # raise KeyError; None means "no default selection".
                class_dict.get(1),
            )
            labels_to_plot = [class_names_to_ids[name] for name in selected_class_names]
    else:
        selected_label_display_option = "No Label"

    # Lay the images out in rows of four columns.
    num_columns = 4
    image_repository = st.session_state["image_repository_augmentation"]
    for row_start in range(0, len(all_image_names), num_columns):
        cols = st.columns(num_columns)
        row_image_names = all_image_names[row_start : row_start + num_columns]
        for col, image_name in zip(cols, row_image_names):
            image_data = image_repository[image_name]["image"]
            label_file = image_repository[image_name]["label"]
            if selected_label_display_option in ["All Labels", "Specific Labels"]:
                # Draw on a copy so the stored image stays untouched.
                modified_image = augmentation_functions.overlay_labels(
                    image=image_data.copy(),
                    labels_to_plot=labels_to_plot,
                    label_file=label_file,
                    label_type=label_type,
                    colors=colors,
                    class_dict=class_dict,
                )
            else:
                modified_image = image_data
            with col:
                st.image(
                    modified_image,
                    clamp=True,
                    caption=image_name,
                    use_column_width=True,
                )