# Set the page config
import streamlit as st
# Page-wide settings, applied before any other element on this page is rendered
_PAGE_CONFIG = dict(
    page_title="Image Augmentation",
    page_icon=":open_file_folder:",
    layout="wide",
    initial_sidebar_state="collapsed",
)
st.set_page_config(**_PAGE_CONFIG)
# Importing necessary libraries
import cv2
import utils
import numpy as np
import Functions.image_augmentation_functions as augmentation_functions
# Both lookup tables are sheets of the same workbook: one lists the tunable
# parameters of each technique, the other its metadata (later read via the
# "Name", "Description", "Category" and "Source Code Link" columns).
_PACKAGES_WORKBOOK = "packages_db.xlsx"
augmentation_params_df, augmentation_details_df = (
    utils.load_data_from_excel(_PACKAGES_WORKBOOK, sheet_name)
    for sheet_name in ("augmentation_parameters", "augmentation_details")
)
# Page heading
st.title("Image Augmentation")
# # Clear the Streamlit session state on the first load of the page
# utils.clear_session_state_on_first_load("image_augmentation_clear")

# Widget-key holders. Each entry stores a one-item dict {name: bool} that is
# passed as a widget ``key``; the Reset button flips the bool so the widgets
# receive a different key on the next run.
session_state_keys = [
    "file_uploader_key_augmentation",
    "select_processing_technique_key_augmentation",
    "selected_option_key_augmentation",
    "class_labels_input_key_augmentation",
    "bbox1_key",
    "bbox2_key",
    "bbox3_key",
    "bbox4_key",
    "bbox5_key",
]
for holder_name in session_state_keys:
    if holder_name not in st.session_state:
        st.session_state[holder_name] = {holder_name: True}

# Remaining page-level flags and their first-run defaults
_page_state_defaults = {
    "validation_triggered": False,
    "uploaded_files_cache_augmentation": False,
    "zip_data_augmentation": "",
}
for state_name, default_value in _page_state_defaults.items():
    if state_name not in st.session_state:
        st.session_state[state_name] = default_value
# Uploader for images and (optional) label files. Its key comes from the
# toggleable holder so Reset can clear it; "uploaded_files" is presumably the
# session-state name the helper stores the files under (read again at validation).
utils.display_file_uploader(
    "uploaded_files",
    "Choose images and labels...",
    st.session_state["file_uploader_key_augmentation"],
    st.session_state["uploaded_files_cache_augmentation"],
)

# Which kind of labels accompany the images; "Bboxes" is preselected.
label_type = st.selectbox(
    "Choose the label type for your augmentation process:",
    ["Masks", "Bboxes"],
    index=1,
    on_change=utils.reset_validation_trigger,
    key=st.session_state["selected_option_key_augmentation"],
)

# Masks need no extra label-format settings; bounding boxes collect them via bbox_params().
if label_type == "Masks":
    label_input_parameters = None
elif label_type == "Bboxes":
    label_input_parameters = augmentation_functions.bbox_params()
# Free-text, comma-separated list of class names, pre-filled with example values
class_labels_input = st.text_area(
    "Enter class labels, separated by commas:",
    utils.sample_class_labels,
    on_change=utils.reset_validation_trigger,
    key=st.session_state["class_labels_input_key_augmentation"],
)
# Remove unnecessary whitespace from the start and end
class_labels_input = class_labels_input.strip()

# Build the class-ID -> name mapping, its inverse, and one colour per class.
# IDs start at 1 because ID 0 is reserved for the background.
try:
    class_labels = [
        label.strip() for label in class_labels_input.split(",") if label.strip()
    ]
    class_dict = {class_id: label for class_id, label in enumerate(class_labels, start=1)}
    # Invert class_dict to map class names to class IDs
    class_names_to_ids = {label: class_id for class_id, label in class_dict.items()}
    colors = augmentation_functions.generate_unique_colors(class_dict.keys())
except Exception:
    st.warning(
        "Invalid format for class labels. Please enter labels separated by commas.",
        icon="⚠️",
    )
    # Fall back to "no classes". ``colors`` must also be bound here: the results
    # gallery passes it to overlay_labels, which previously raised NameError on
    # this path. NOTE(review): an empty mapping is assumed to be an acceptable
    # "no colours" value for generate_unique_colors' return type — confirm.
    class_dict, class_names_to_ids, colors = {}, {}, {}
# Usage notes shown above the validation controls
_USAGE_NOTES = """
Note to Users:
- The first uploaded image will be used for demonstration purposes and to validate parameters for augmentation techniques.
- Uploading labels is optional. If no labels are uploaded, the output will consist solely of processed images.
- When moving to another page or if you wish to upload a new set of images and labels, don't forget to hit the Reset button. This helps in faster computation and frees up unused memory, ensuring smoother operation.
- Select the class labels, label type and label parameters before uploading large data for faster computation and more efficient processing.
"""
st.markdown(_USAGE_NOTES, unsafe_allow_html=True)
# Session-state slots filled by input validation; they hold None until validation runs.
session_vars = [
    "is_valid",
    "image_files",
    "label_files",
    "first_image_file",
    "first_label_file",
]
missing_session_vars = [name for name in session_vars if name not in st.session_state]
for missing_name in missing_session_vars:
    st.session_state[missing_name] = None
# Side-by-side slots: "Validate Input" on the left, "Reset" on the right
col1, col2 = st.columns(2)
# Render the button unconditionally, then decide what this run should do.
validate_clicked = col1.button("Validate Input", use_container_width=True)
already_validated = st.session_state["validation_triggered"]
if validate_clicked and not already_validated:
    # Latch validation so later reruns reuse the stored results
    st.session_state["validation_triggered"] = True
    st.session_state["uploaded_files_cache_augmentation"] = True
    (
        st.session_state["is_valid"],
        st.session_state["image_files"],
        st.session_state["label_files"],
        st.session_state["first_image_file"],
        st.session_state["first_label_file"],
    ) = augmentation_functions.check_valid_labels(
        st.session_state["uploaded_files"], label_type, class_dict
    )
elif not already_validated:
    # Nothing validated yet: mark input invalid and prompt the user.
    st.session_state["is_valid"] = False
    st.warning(
        "Please upload images and labels and click **Validate Input**.", icon="⚠️"
    )
# Otherwise validation already ran; its results stay in session state untouched.
with col2:
    # "Reset" returns the page to a pristine state: new widget keys, empty
    # session state, cleared caches, then an immediate rerun.
    if st.button("Reset", use_container_width=True):
        # Flip the boolean stored in every widget-key holder listed in
        # session_state_keys. These dicts are passed as widget ``key`` values, so
        # flipping them gives each widget a different key on the next run. All
        # holders take the inverse of the file uploader holder's current value.
        current_value = st.session_state["file_uploader_key_augmentation"][
            "file_uploader_key_augmentation"
        ]
        updated_value = not current_value  # Invert the current value
        for session_state_key in session_state_keys:
            # Store the toggled value under the holder's own name
            st.session_state[session_state_key] = {session_state_key: updated_value}
        # Delete every other session-state entry; only the holders in
        # session_state_keys survive.
        for key in list(st.session_state.keys()):
            if key not in session_state_keys:
                del st.session_state[key]
        # Drop the script's module-level names, except underscore-prefixed ones and
        # ``st``. Presumably this is how Reset "frees up unused memory", as the
        # user note promises.
        # NOTE(review): this also deletes helpers such as session_state_keys and the
        # imported modules. It works only because the branch touches nothing but
        # ``st`` afterwards and st.rerun() restarts the script.
        global_vars = list(globals().keys())
        vars_to_delete = [
            var for var in global_vars if not var.startswith("_") and var != "st"
        ]
        for var in vars_to_delete:
            del globals()[var]
        # Clear the Streamlit caches
        st.cache_resource.clear()
        st.cache_data.clear()
        # Rerun the app to reflect the reset state
        st.rerun()
# Techniques usable with the chosen label type...
available_augmentations = augmentation_functions.get_applicable_techniques(
    augmentation_details_df, label_type
)
# ...and, for each of them, the kind of image input it expects
input_mapping_dict = utils.technique_image_input_mapping(
    available_augmentations, augmentation_details_df
)

if not st.session_state["is_valid"]:
    # Without validated input there is nothing to augment
    selected_augmentations = []
else:
    # Technique picker, offered only after successful validation
    selected_augmentations = st.multiselect(
        "Select augmentation technique(s)",
        available_augmentations,
        key=st.session_state["select_processing_technique_key_augmentation"],
    )
    # Decode the first uploaded image (used as the live preview) into an RGB array
    first_upload = st.session_state["first_image_file"]
    first_upload.seek(0)  # the file may already have been read during validation
    encoded_pixels = np.frombuffer(first_upload.read(), dtype=np.uint8)
    decoded_bgr = cv2.imdecode(encoded_pixels, cv2.IMREAD_COLOR)
    uploaded_first_image = cv2.cvtColor(decoded_bgr, cv2.COLOR_BGR2RGB)
    # # Resize the image
    # uploaded_first_image = cv2.resize(uploaded_first_image, (256, 256))
#######################################################################################################
# Build custom augmentation pipeline
#######################################################################################################
# Parameters for each selected augmentation technique, keyed by technique name
augmentations_params = {}
# Becomes True as soon as any technique fails its trial run on the preview image
error_in_parameters = False
# Configure every selected technique in order. The preview is cumulative: after a
# successful trial run, the processed preview becomes the input of the next technique.
for augmentation in selected_augmentations:
    with st.expander(f"{augmentation}"):
        # Metadata row(s) for this technique; may be empty if it is missing from the sheet
        augmentation_info = augmentation_details_df[
            augmentation_details_df["Name"] == augmentation
        ]
        has_info = not augmentation_info.empty
        # Wide column for the details, narrow column for the before/after thumbnails
        details_col, image_col = st.columns([7, 3])
        with details_col:
            augmentation_description = (
                augmentation_info["Description"].iloc[0]
                if has_info
                else "No description available."
            )
            st.markdown(
                f"Description: {augmentation_description}",
                unsafe_allow_html=True,
            )
            augmentation_category = (
                augmentation_info["Category"].iloc[0] if has_info else "Unknown"
            )
            st.write("Category:", augmentation_category)
            # The fallback must be an absolute URL: a scheme-less "www..." link is
            # resolved by the browser relative to the app's own address.
            augmentation_source_code = (
                augmentation_info["Source Code Link"].iloc[0]
                if has_info
                else "https://www.google.com"
            )
            # Source-code link next to the "customize" toggle
            source_code_col, custom_setting_col = st.columns(2)
            source_code_col.link_button("Source Code", augmentation_source_code)
            custom_settings = custom_setting_col.checkbox(
                f"Customize {augmentation}", key=f"toggle_{augmentation}"
            )
        with image_col:
            # Fixed-height slots for the original and processed thumbnails
            original_col, processed_col = st.columns(2)
            original_image_placeholder = original_col.container(height=150, border=False)
            processed_image_placeholder = processed_col.container(height=150, border=False)
        if custom_settings:
            # User-tuned parameters, built from this technique's rows in the parameter sheet
            params_df = augmentation_params_df[
                augmentation_params_df["Name"] == augmentation
            ]
            augmentations_params[augmentation] = utils.process_image_parameters(
                params_df, augmentation
            )
        else:
            # Use default settings if customization is not selected
            augmentations_params[augmentation] = utils.get_default_params(augmentation)
        # Validate the parameters by running the technique on the preview image
        (
            error_flag,
            processed_first_image,
        ) = augmentation_functions.apply_and_test_augmentation(
            augmentation,
            augmentations_params[augmentation],
            uploaded_first_image,
            st.session_state["first_label_file"],
            label_type,
            label_input_parameters,
            input_mapping_dict[augmentation],
        )
        if error_flag:
            error_in_parameters = True
        else:
            # Show before/after thumbnails side by side
            with original_image_placeholder:
                st.image(
                    uploaded_first_image,
                    caption="Original Image",
                    use_column_width=True,
                    clamp=True,
                )
            with processed_image_placeholder:
                st.image(
                    processed_first_image,
                    caption="Processed Image",
                    use_column_width=True,
                    clamp=True,
                )
            # Chain: the next technique previews on this technique's output
            uploaded_first_image = processed_first_image
#######################################################################################################
# Display selected augmentation technique parameters as DataFrame
#######################################################################################################
# The technique filter is offered only when every configured technique passed its trial run.
if augmentations_params and not error_in_parameters:
    selected_augmentation = st.selectbox(
        "Select augmentation technique",
        options=["All"] + list(augmentations_params.keys()),
    )
else:
    selected_augmentation = None

# Tabular view of all chosen parameters
augmentations_df = augmentation_functions.create_augmentations_dataframe(
    augmentations_params, augmentation_params_df
)
# Cast values to str so mixed types display consistently and serialize without errors
augmentations_df["Value"] = augmentations_df["Value"].astype(str)

# "All" shows every row; otherwise keep only the chosen technique's rows
# (with no selection, i.e. None, this yields an empty frame).
filtered_augmentations_df = (
    augmentations_df
    if selected_augmentation == "All"
    else augmentations_df[augmentations_df["augmentation"] == selected_augmentation]
)

if not error_in_parameters and not filtered_augmentations_df.empty:
    # Show the table at its natural width (not stretched to the container)
    st.dataframe(filtered_augmentations_df, use_container_width=False)

# Reserved slot, filled further down with the generated code
code_placeholder = st.empty()
#######################################################################################################
# Process images and download processed images
#######################################################################################################
# Proceed only when inputs are valid, at least one technique is selected and
# every technique's parameters passed the trial run.
if (
    st.session_state["is_valid"]
    and selected_augmentations
    and not error_in_parameters
):
    col1, col2 = st.columns(2)
    # How many augmented variants to generate per input image
    num_variations = col1.number_input(
        "Set the number of variations to be generated",
        min_value=1,
        max_value=3,
        step=1,
    )
    with col2:
        # Two empty lines of top padding to line the checkbox up with the number input
        for _ in range(2):
            st.write("")
        include_original = st.checkbox(
            "Include original images and labels in output", value=False
        )

    # Generate reproducible Python code for the configured pipeline
    with code_placeholder:
        if not st.session_state["label_files"]:
            # No labels uploaded: images-only pipeline
            generated_code = utils.generate_python_code_images(
                augmentations_params,
                num_variations,
                include_original,
            )
        elif label_type == "Bboxes":
            generated_code = augmentation_functions.generate_python_code_bboxes(
                augmentations_params,
                label_input_parameters,
                num_variations,
                include_original,
            )
        elif label_type == "Masks":
            generated_code = augmentation_functions.generate_python_code_masks(
                augmentations_params,
                label_input_parameters,
                num_variations,
                include_original,
            )
        # Show the generated code with a description and a download button
        augmentation_functions.display_code_and_download_button(generated_code)

    col1, col2 = st.columns(2)
    # Run the full pipeline on all uploaded images once the user confirms
    if col1.button("Accept and Process", use_container_width=True):
        augmentation_functions.process_images_and_labels(
            st.session_state["image_files"],
            st.session_state["label_files"],
            selected_augmentations,
            augmentations_params,
            label_type,
            label_input_parameters,
            num_variations,
            include_original,
            class_dict,
        )
    # Enable the download only when an archive exists. Before processing, the
    # session value is still the initial empty string, which would be served as
    # an empty "augmented_images.zip".
    zip_data = st.session_state["zip_data_augmentation"]
    col2.download_button(
        label="Download",
        data=zip_data,
        file_name="augmented_images.zip",
        mime="application/zip",
        use_container_width=True,
        disabled=not zip_data,
    )
elif st.session_state["is_valid"]:
    if not selected_augmentations:
        # Inform the user that no augmentation techniques have been selected
        st.warning("Please select at least one augmentation technique.", icon="⚠️")
    if error_in_parameters:
        # Inform the user that there are errors in parameters
        st.warning(
            "There are errors in the augmentation parameters. Please review your selections.",
            icon="⚠️",
        )
#######################################################################################################
# Display original and processed images
#######################################################################################################
# Show results only when processing has stored them in session state, including a
# non-empty list of source image names (indexing an empty list would raise IndexError).
if (
    "image_repository_augmentation" in st.session_state
    and "processed_image_mapping_augmentation" in st.session_state
    and st.session_state.get("unique_images_names")
):
    unique_images_names = st.session_state["unique_images_names"]
    num_unique_images = len(unique_images_names)
    # A slider only makes sense when there is more than one source image
    if num_unique_images > 1:
        selected_image_index = st.slider(
            "Select an Image",
            min_value=1,
            max_value=num_unique_images,
            step=1,
        )
    else:
        selected_image_index = 1
    # The slider is 1-based; the name list is 0-based
    selected_original_image_name = unique_images_names[selected_image_index - 1]
    # The original image followed by all of its processed variants
    processed_variant_names = st.session_state[
        "processed_image_mapping_augmentation"
    ].get(selected_original_image_name, [])
    all_image_names = [selected_original_image_name] + processed_variant_names

    # label_files may be None (e.g. after a later failed validation) or empty
    if st.session_state.get("label_files"):
        label_display_options = ["No Label", "All Labels", "Specific Labels"]
        selected_label_display_option = st.selectbox(
            "Choose how to display labels:",
            label_display_options,
            index=0,  # Default option is 'No Label'
        )
        if selected_label_display_option == "All Labels":
            labels_to_plot = list(class_dict.keys())
        elif selected_label_display_option == "Specific Labels":
            # Preselect the first class when one exists; with no class labels
            # entered, class_dict[1] would raise KeyError.
            selected_class_names = st.multiselect(
                "Select specific labels to display",
                list(class_names_to_ids.keys()),
                list(class_dict.values())[:1],
            )
            labels_to_plot = [class_names_to_ids[name] for name in selected_class_names]
    else:
        selected_label_display_option = "No Label"

    # Lay the images out in rows of four
    num_columns = 4
    image_repository = st.session_state["image_repository_augmentation"]
    for row_start in range(0, len(all_image_names), num_columns):
        row_image_names = all_image_names[row_start : row_start + num_columns]
        cols = st.columns(num_columns)
        for col, image_name in zip(cols, row_image_names):
            image_entry = image_repository[image_name]
            image_data = image_entry["image"]
            label_file = image_entry["label"]
            if selected_label_display_option in ("All Labels", "Specific Labels"):
                # Draw the chosen labels on a copy so the stored image stays untouched
                modified_image = augmentation_functions.overlay_labels(
                    image=image_data.copy(),
                    labels_to_plot=labels_to_plot,
                    label_file=label_file,
                    label_type=label_type,
                    colors=colors,
                    class_dict=class_dict,
                )
            else:
                modified_image = image_data
            with col:
                st.image(
                    modified_image,
                    clamp=True,
                    caption=image_name,
                    use_column_width=True,
                )