# CV_Accelerator/Functions/image_processing_functions.py
# Importing necessary libraries
import io
import os
import cv2
import utils
import random
import zipfile
import numpy as np
import pandas as pd
from PIL import Image
import streamlit as st
import albumentations as A


# Function to check if the uploaded images and labels are valid
@st.cache_resource(show_spinner=False)
def check_valid_labels(uploaded_files):
# Early exit if no files are uploaded
if len(uploaded_files) == 0:
st.warning(
"Please upload at least one image to apply image processing.", icon="⚠️"
)
return False, {}, {}, None, None
# Initialize dictionaries to hold images and labels
image_files, label_files = {}, {}
# Extracting the name of the first file
first_file_name = os.path.splitext(uploaded_files[0].name)[0]
# Counters for images and labels
image_count, label_count = 0, 0
# Initialize a progress bar and progress text
progress_bar = st.progress(0)
progress_text = st.empty()
total_files = len(uploaded_files)
# Categorize and prepare uploaded files
for index, file in enumerate(uploaded_files):
file.seek(0) # Reset file pointer to ensure proper file reading
file_name_without_extension = os.path.splitext(file.name)[0]
# Distribute files into image or label categories based on their file type
if file.type in ["image/jpeg", "image/png"]:
image_files[file_name_without_extension] = file
image_count += 1
elif file.type == "text/plain":
label_files[file_name_without_extension] = file
label_count += 1
# Update progress bar and display current progress
progress_percentage = (index + 1) / total_files
progress_bar.progress(progress_percentage)
progress_text.text(f"Validating file {index + 1} of {total_files}")
# Extract sets of unique file names for images and labels
unique_image_names = set(image_files.keys())
unique_label_names = set(label_files.keys())
# Remove progress bar and progress text after processing
progress_bar.empty()
progress_text.empty()
if (len(unique_image_names) != image_count) or (
len(unique_label_names) != label_count
):
# Warn the user about the presence of duplicate file names
st.warning(
"Duplicate file names detected. Please ensure each image and label has a unique name.",
icon="⚠️",
)
return False, {}, {}, None, None
# Perform validation checks
if (len(image_files) > 0) and (len(label_files) > 0):
# Check if the number of images and labels match and each pair has corresponding files
if (len(image_files) == len(label_files)) and (
unique_image_names == unique_label_names
):
st.info(
f"Validated: {len(image_files)} images and labels successfully matched.",
icon="✅",
)
return (
True,
image_files,
label_files,
image_files[first_file_name],
label_files[first_file_name],
)
elif len(image_files) != len(label_files):
# Warn if the count of images and labels does not match
st.warning(
"Count Mismatch: The number of uploaded images and labels does not match.",
icon="⚠️",
)
return False, {}, {}, None, None
else:
# Warn if there is a mismatch in file names between images and labels
st.warning(
"Mismatch detected: Some images do not have corresponding label files.",
icon="⚠️",
)
return False, {}, {}, None, None
elif len(image_files) > 0:
# Inform the user if only images are uploaded without labels
st.info(
f"Note: {len(image_files)} images uploaded without labels. Label type and class labels will be ignored in this case.",
icon="✅",
)
return True, image_files, {}, image_files[first_file_name], None
else:
# Warn if no images are uploaded
st.warning("Please upload an image to apply image processing.", icon="⚠️")
return False, {}, {}, None, None
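

# A minimal usage sketch (hypothetical, not part of this module): the validator is
# typically fed the list returned by st.file_uploader with accept_multiple_files=True
# and unpacked into its five return values.
#
#   uploaded = st.file_uploader(
#       "Upload images (.jpg/.png) and their .txt label files",
#       type=["jpg", "jpeg", "png", "txt"],
#       accept_multiple_files=True,
#   )
#   valid, images, labels, first_image, first_label = check_valid_labels(uploaded)
#   if valid:
#       st.write(f"{len(images)} image(s) ready for processing")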


# Function to apply a single image processing technique to an image and return an error flag along with the processed image
def apply_and_test_image_processing(
image_processing, params, image, allowed_image_types
):
try:
# Check the data type and number of channels of the input image
input_image_type = image.dtype
num_channels = (
image.shape[2] if len(image.shape) == 3 else 1
) # Assuming 1 for single-channel images
# Validate if the input image type is among the allowed types
if not utils.is_image_type_allowed(
input_image_type, num_channels, allowed_image_types
):
# Format the allowed types for display in the warning message
allowed_types_formatted = ", ".join(map(str, allowed_image_types))
# Display a warning message specifying the acceptable image types
st.warning(
f"Error applying {image_processing}: Incompatible image type. The input image should be one of the following types: {allowed_types_formatted}",
icon="⚠️",
)
return True, None # Error occurred
        # Set a fixed seed so the preview is reproducible
random.seed(0)
# Apply image processing technique
transform = A.Compose([utils.apply_albumentation(params, image_processing)])
processed_image = transform(image=image)["image"]
return False, processed_image # No error
except Exception as e:
st.warning(f"Error applying {image_processing}: {e}", icon="⚠️")
return True, None # Error occurred
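

# Hypothetical preview call (the image array and allowed-types list are placeholders;
# the exact allowed_image_types format is whatever utils.is_image_type_allowed expects):
#
#   error, preview = apply_and_test_image_processing(
#       "GaussianBlur",                  # technique name resolved by utils.apply_albumentation
#       {"blur_limit": (3, 7), "p": 1.0},
#       image_rgb,                       # H x W x 3 uint8 NumPy array
#       allowed_image_types,
#   )
#   if not error:
#       st.image(preview)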


# Function to generate a DataFrame detailing image processing technique parameters and descriptions
def create_image_processings_dataframe(
image_processings_params, image_processing_params_db
):
data = []
for aug_name, params in image_processings_params.items():
for param_name, param_value in params.items():
# Retrieve relevant image_processing information from the database
image_processing_info = image_processing_params_db[
image_processing_params_db["Name"] == aug_name
]
param_info = image_processing_info[
image_processing_info["Parameter Name"] == param_name
]
# Check if the parameter information exists in the database
if not param_info.empty:
# Get the description of the current parameter
param_description = param_info["Parameter Description"].iloc[0]
else:
param_description = "Description not available"
# Append image_processing name, parameter name, its value, and description to the data list
data.append([aug_name, param_name, param_value, param_description])
# Create the DataFrame from the accumulated data
image_processings_df = pd.DataFrame(
data, columns=["image_processing", "Parameter", "Value", "Description"]
)
return image_processings_df
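

# Illustrative (hypothetical) inputs: a dict mapping technique names to parameter dicts,
# plus the parameter database, which is assumed to be a DataFrame with "Name",
# "Parameter Name" and "Parameter Description" columns (as queried above).
#
#   params = {"Blur": {"blur_limit": 7, "p": 1.0}}
#   df = create_image_processings_dataframe(params, image_processing_params_db)
#   st.dataframe(df)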


# Function to generate a plug-and-play Python script for processing images and labels
def generate_python_code_images_labels(
augmentations_params,
num_variations=1,
include_original=False,
):
# Start with necessary library imports
code_str = "# Importing necessary libraries\n"
code_str += "import os\nimport cv2\nimport shutil\nimport albumentations as A\n\n"
# Paths for input and output directories
code_str += "# Define the paths for input and output directories\n"
code_str += "input_directory = 'path/to/input'\n"
code_str += "output_directory = 'path/to/output'\n\n"
# Function to create an augmentation pipeline
code_str += "# Function to create an augmentation pipeline using Albumentations\n"
code_str += "def process_image(image):\n"
code_str += " # Define the sequence of augmentation techniques\n"
code_str += " pipeline = A.Compose([\n"
for technique, params in augmentations_params.items():
code_str += f" A.{technique}({', '.join(f'{k}={v}' for k, v in params.items())}),\n"
code_str += " ])\n"
code_str += " # Apply the augmentation pipeline\n"
code_str += " return pipeline(image=image)['image']\n\n"
# Function to process a batch of images
code_str += "# Function to process a batch of images\n"
code_str += "def process_batch(input_directory, output_directory):\n"
code_str += " for filename in os.listdir(input_directory):\n"
code_str += " if filename.lower().endswith(('.png', '.jpg', '.jpeg')):\n"
code_str += " image_path = os.path.join(input_directory, filename)\n"
code_str += " label_path = os.path.splitext(image_path)[0] + '.txt'\n\n"
code_str += " # Read the image\n"
code_str += " image = cv2.imread(image_path)\n\n"
# Include original image and label logic
if include_original:
code_str += " # Include original image and label\n"
code_str += " shutil.copy2(image_path, output_directory)\n"
code_str += " shutil.copy2(label_path, output_directory)\n\n"
# Generate variations for each image and process them
code_str += " # Generate variations for each image and process them\n"
code_str += f" for variation in range({num_variations}):\n"
code_str += " processed_image = process_image(image)\n\n"
code_str += " # Save the processed image\n"
code_str += " output_filename = f'processed_{os.path.splitext(filename)[0]}_{variation}{os.path.splitext(filename)[1]}'\n"
code_str += " cv2.imwrite(os.path.join(output_directory, output_filename), processed_image)\n\n"
code_str += (
" # Save the original label file for the processed image\n"
)
code_str += " if os.path.exists(label_path):\n"
code_str += " shutil.copy2(label_path, os.path.join(output_directory, os.path.splitext(output_filename)[0] + '.txt'))\n\n"
# Execute the batch processing function
code_str += (
"# Execute the batch processing function with the specified parameters\n"
)
code_str += "process_batch(input_directory, output_directory)\n"
return code_str
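

# Hypothetical call: the dict mirrors the techniques and parameters selected in the UI,
# and the return value is the complete script as a string.
#
#   script = generate_python_code_images_labels(
#       {"HorizontalFlip": {"p": 0.5}, "RandomBrightnessContrast": {"p": 0.2}},
#       num_variations=3,
#       include_original=True,
#   )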


# Function to create an image processing pipeline based on the selected techniques and their parameters
def create_image_processing_pipeline(
selected_image_processings, image_processing_params
):
pipeline = []
for aug_name in selected_image_processings:
# Append the function call with its parameters to the pipeline
pipeline.append(
utils.apply_albumentation(image_processing_params[aug_name], aug_name)
)
# Compose all the image processings into one transformation
return A.Compose(pipeline)
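

# Sketch of direct use (hypothetical names): the returned A.Compose object is callable
# on a NumPy image like any other Albumentations transform.
#
#   pipeline = create_image_processing_pipeline(
#       ["HorizontalFlip"], {"HorizontalFlip": {"p": 1.0}}
#   )
#   augmented = pipeline(image=image_rgb)["image"]  # image_rgb: H x W x 3 uint8 array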


# Function to process images and labels, apply image processing techniques, and create a zip file with the results
@st.cache_resource(show_spinner=False)
def process_images_and_labels(
image_files,
label_files,
selected_image_processings,
_image_processings_params,
num_variations,
include_original,
):
zip_buffer = io.BytesIO() # Create an in-memory buffer for the zip file
st.session_state[
"image_repository_preprocessing"
] = {} # Initialize a repository to store processed image data
st.session_state[
"processed_image_mapping_procesing"
] = {} # Map original images to their processed versions
st.session_state["unique_images_names"] = [] # List to store unique images names
# Create progress bar and text elements in Streamlit
progress_bar = st.progress(0)
progress_text = st.empty()
with zipfile.ZipFile(
zip_buffer, mode="a", compression=zipfile.ZIP_DEFLATED, allowZip64=True
) as zip_file:
# Compose all the image processings into one transformation
transform = create_image_processing_pipeline(
selected_image_processings, _image_processings_params
)
total_images = len(image_files) * num_variations
processed_count = 0 # Counter for processed images
# Iterate over each uploaded file
for image_name, image_file in image_files.items():
image_file.seek(0) # Reset file pointer to start
file_bytes = np.asarray(bytearray(image_file.read()), dtype=np.uint8)
original_image = cv2.cvtColor(
cv2.imdecode(file_bytes, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB
)
original_image_resized = utils.resize_image(original_image)
# Include original images and labels in the output if selected
if include_original:
original_img_buffer = io.BytesIO()
Image.fromarray(original_image).save(original_img_buffer, format="JPEG")
zip_file.writestr(image_file.name, original_img_buffer.getvalue())
# Save corresponding label file to zip if it exists
label_file = label_files.get(image_name)
if label_file is not None:
label_file.seek(0) # Reset the file pointer
zip_file.writestr(f"{image_name}.txt", label_file.read())
original_file_name = image_file.name
st.session_state["unique_images_names"].append(original_file_name)
st.session_state["processed_image_mapping_procesing"][
original_file_name
] = []
st.session_state["image_repository_preprocessing"][image_file.name] = {
"image": original_image_resized,
"label": label_files.get(image_name),
}
# Apply image processing techniques and generate variations
for i in range(num_variations):
random.seed(i)
# Apply the image processing pipeline to the image
processed_image = transform(image=original_image)["image"]
img_buffer = io.BytesIO()
Image.fromarray(processed_image).save(img_buffer, format="JPEG")
processed_filename = f"processed_{image_name.split('.')[0]}_{i}.jpg"
zip_file.writestr(processed_filename, img_buffer.getvalue())
processed_image_resized = utils.resize_image(processed_image)
# Save corresponding label file to zip if it exists
label_file = label_files.get(image_name)
if label_file is not None:
label_file.seek(0) # Reset the file pointer
zip_file.writestr(
f"processed_{image_name}_{i}.txt", label_file.read()
)
st.session_state["processed_image_mapping_procesing"][
image_file.name
].append(processed_filename)
st.session_state["image_repository_preprocessing"][
processed_filename
] = {
"image": processed_image_resized,
"label": label_file,
}
processed_count += 1
# Update progress bar and text
progress_bar.progress(processed_count / total_images)
progress_text.text(
f"Processing image {processed_count} of {total_images}"
)
# Remove the progress bar and text after processing is complete
progress_bar.empty()
progress_text.empty()
zip_buffer.seek(0) # Reset buffer to start for download
st.session_state["zip_data_processing"] = zip_buffer.getvalue()


# Function to display the generated code and offer it as a downloadable Python file
def display_code_and_download_button(generated_code):
def generate_downloadable_file(code_str):
return code_str.encode("utf-8")
# Display the generated code in Streamlit with description and download button in columns
with st.expander("Plug and Play Code"):
col1, col2 = st.columns([7, 3])
with col1:
st.markdown(
"""
### Description of the Code Pipeline
"""
)
st.markdown(
"""
<div style='text-align: justify;'>
This code is a ready-to-use Python script for batch augmentation. It applies selected augmentation techniques to all images in a specified input directory and saves the processed images in an output directory.
**To use this script:**
- Ensure you have the necessary dependencies installed.
- Specify the input and output paths: Replace `'path/to/input'` with the path to your input images and `'path/to/output'` with the desired path for the processed images.
- The number of processed variations per image, the inclusion of the original images in the output, and the processing techniques with their parameters will be automatically set based on your selections.
### Python Code
</div>
""",
unsafe_allow_html=True,
)
# Display python code
st.code(generated_code, language="python")
with col2:
# Create a button for downloading the Python file
st.download_button(
label="Download Python File",
data=generate_downloadable_file(generated_code),
file_name="image_processing_script.py",
mime="text/plain",
use_container_width=True,
)
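

# Hypothetical wiring with the generator above: render the produced script and offer
# it for download in one call.
#
#   script = generate_python_code_images_labels({"Blur": {"blur_limit": 7, "p": 1.0}})
#   display_code_and_download_button(script)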