# CV_Accelerator / pages/4_Model_Training.py
# Author: samkeet
# Commit: 3d90a2e (verified) - "First Commit"
# Set the page config
import streamlit as st
# Page-level settings; Streamlit requires this to be the first `st.*` command
_page_settings = {
    "page_title": "Model_Training",
    "page_icon": ":open_file_folder:",
    "layout": "wide",
    "initial_sidebar_state": "collapsed",
}
st.set_page_config(**_page_settings)
# Importing necessary libraries
import utils
import streamlit as st
import Functions.model_training_functions as model_training_functions
# Display the page title
st.title("Model Training")
# # Clear the Streamlit session state on the first load of the page
# utils.clear_session_state_on_first_load("model_training_clear")

# Session-state entries used as widget keys on this page. Each entry holds a
# one-item dict {name: bool} that is passed as a widget's `key`; the Reset
# button flips the bool so the widget receives a new key.
session_state_keys = [
    "file_uploader_split_key_training",
    "file_uploader_train_key_training",
    "file_uploader_val_key_training",
    "file_uploader_test_key_training",
    "number_input_train_key",
    "number_input_val_key",
    "number_input_test_key",
    "split_method_key",
    "training_type_key",
    "class_labels_input_key_training",
]
# Seed any widget-key entry that does not exist yet with {name: True}
for widget_state_name in session_state_keys:
    if widget_state_name not in st.session_state:
        st.session_state[widget_state_name] = {widget_state_name: True}

# Page-level flags, defaulting to False when not already present
for flag_name in (
    "validation_triggered",
    "uploaded_files_cache_processing",
    "is_valid",
):
    if flag_name not in st.session_state:
        st.session_state[flag_name] = False
# Container created here, ahead of the option widgets, and filled later via
# `with file_uploader_container:` so the uploaders occupy this spot in the layout
file_uploader_container = st.container()

# User-facing training names -> technical label type expected by the validators
label_type_mapping = {"Object Detection": "Bboxes", "Instance Segmentation": "Masks"}

# Side-by-side columns: training type on the left, split method on the right
training_type_col, split_method_col = st.columns(2)

# Dropdown for selecting the training type
selected_training = training_type_col.selectbox(
    "Select the training type:",
    list(label_type_mapping),
    index=0,
    on_change=utils.reset_validation_trigger,
    key=st.session_state["training_type_key"],
)
# Label type ("Bboxes" or "Masks") corresponding to the chosen training type
label_type = label_type_mapping[selected_training]

# Toggle for choosing how the dataset is split into train/val/test
split_method = split_method_col.radio(
    "Select the dataset split method:",
    ["Percentage Split", "Direct Upload"],
    horizontal=True,
    on_change=utils.reset_validation_trigger,
    key=st.session_state["split_method_key"],
)
# Text area for user to input class labels, pre-filled with example defaults
class_labels_input = st.text_area(
    "Enter class labels, separated by commas:",
    utils.sample_class_labels,
    on_change=utils.reset_validation_trigger,
    key=st.session_state["class_labels_input_key_training"],
)
# Build the class lookups from the comma-separated input. Each label is stripped
# of surrounding whitespace, empty entries are dropped, and class IDs are
# assigned in input order starting at 0.
try:
    # Remove unnecessary whitespace from the start and end of the input
    class_labels_input = class_labels_input.strip()
    class_labels = [
        label.strip() for label in class_labels_input.split(",") if label.strip()
    ]
    # Map class IDs to their labels, and invert it to map class names to IDs
    class_dict = dict(enumerate(class_labels))
    class_names_to_ids = {label: class_id for class_id, label in class_dict.items()}
except (AttributeError, TypeError):
    # NOTE(review): presumably only reachable if the widget value is not a
    # string (e.g. None); string input cannot raise in the block above.
    st.warning(
        "Invalid format for class labels. Please enter labels separated by commas.",
        icon="⚠️",
    )
    # Fall back to empty lookups. class_labels must be defined here too, because
    # the code below passes it to create_yolo_config_file / start_yolo_training.
    class_labels, class_dict, class_names_to_ids = [], {}, {}
# Usage notes shown to the user (rendered as raw HTML)
user_note_html = """
<div style='text-align: justify;'>
<b>Note to Users:</b>
<ul>
<li>When moving to another page or if you wish to upload a new set of images and labels, don't forget to hit the <b>Reset</b> button. This helps in faster computation and frees up unused memory, ensuring smoother operation.</li>
<li>Select the training type, class labels, dataset split method and its parameters before uploading large data for faster computation and more efficient processing.</li>
</ul>
</div>
"""
st.markdown(user_note_html, unsafe_allow_html=True)
# Two side-by-side columns for the "Validate Input" and "Reset" buttons;
# validate_button_col is filled by the split-method code further down.
validate_button_col, reset_button_col = st.columns(2)
with reset_button_col:
    # Check if the 'Reset' button is pressed
    if st.button("Reset", use_container_width=True):
        # Wipe the output folder and the prepared data folders on disk
        model_training_functions.delete_and_recreate_folder(
            model_training_functions.get_path("output")
        )
        model_training_functions.clear_data_folders()
        # Flip the bool inside each widget-key dict. Each dict is passed as a
        # widget's `key`, so a new value gives that widget a new key, which
        # presumably makes Streamlit recreate it with its default value.
        # Reuses the module-level `session_state_keys` list declared at the top
        # of the page rather than a separate copy that could drift out of sync.
        for session_state_key in session_state_keys:
            current_value = st.session_state[session_state_key][session_state_key]
            st.session_state[session_state_key] = {
                session_state_key: not current_value
            }
        # Clear every other session state entry, keeping only the widget-key dicts
        for key in list(st.session_state.keys()):
            if key not in session_state_keys:
                del st.session_state[key]
        # Clear global variables except private ones and the Streamlit module.
        # NOTE(review): Streamlit re-executes the script on rerun, so this may be
        # redundant; kept to preserve existing behavior.
        global_vars = list(globals().keys())
        vars_to_delete = [
            var for var in global_vars if not var.startswith("_") and var != "st"
        ]
        for var in vars_to_delete:
            del globals()[var]
        # Clear the Streamlit caches
        st.cache_resource.clear()
        st.cache_data.clear()
        # Rerun the app to reflect the reset state
        st.rerun()
# Inputs count as invalid until one of the split-method branches below has
# checked them; the shared training step at the end reads this flag.
pct_condition_check = False

# Code for "Percentage Split" method
if split_method == "Percentage Split":
    with file_uploader_container:
        # User uploads images and labels; presumably stored by the helper in
        # st.session_state["uploaded_files"], which is read below.
        utils.display_file_uploader(
            "uploaded_files",
            "Choose images and labels...",
            st.session_state["file_uploader_split_key_training"],
            st.session_state["uploaded_files_cache_processing"],
        )
    # Create three columns for input percentages
    col1, col2, col3 = st.columns(3)
    # User specifies split percentages; each widget uses the session-state key
    # named after its own set.
    train_pct = col1.number_input(
        "Train Set Percentage",
        0,
        100,
        70,
        1,
        on_change=utils.reset_validation_trigger,
        key=st.session_state["number_input_train_key"],
    )
    test_pct = col2.number_input(
        "Test Set Percentage",
        0,
        100,
        15,
        1,
        on_change=utils.reset_validation_trigger,
        key=st.session_state["number_input_test_key"],
    )
    val_pct = col3.number_input(
        "Validation Set Percentage",
        0,
        100,
        15,
        1,
        on_change=utils.reset_validation_trigger,
        key=st.session_state["number_input_val_key"],
    )
    # The split is usable only if the percentages sum to 100, the train and
    # validation sets are non-empty, and check_min_images accepts the split
    # for the number of uploaded files.
    pct_condition_check = (
        train_pct + test_pct + val_pct == 100
        and train_pct > 0
        and val_pct > 0
        and model_training_functions.check_min_images(
            len(st.session_state["uploaded_files"]), train_pct, val_pct, test_pct
        )
    )
    if not pct_condition_check:
        file_uploader_container.warning(
            "The percentages for train, test, and validation sets should add up to 100%, and train and validation set should not be empty.",
            icon="⚠️",
        )
    # Button to trigger validation
    if validate_button_col.button("Validate Input", use_container_width=True):
        st.session_state["validation_triggered"] = True
        st.session_state["is_valid"] = model_training_functions.check_valid_labels(
            st.session_state["uploaded_files"], label_type, class_dict
        )
        if st.session_state["is_valid"]:
            # Write the YOLO config, then rebuild the data folders from a fresh split
            model_training_functions.create_yolo_config_file(
                model_training_functions.get_path("config"),
                class_labels,
            )
            model_training_functions.clear_data_folders()
            paired_files = model_training_functions.pair_files(
                st.session_state["uploaded_files"]
            )
            # NOTE(review): only train/test percentages are passed; presumably the
            # validation share is the remainder - confirm in split_and_save_files.
            model_training_functions.split_and_save_files(
                paired_files, train_pct, test_pct
            )

# Code for "Direct Upload" method
elif split_method == "Direct Upload":
    with file_uploader_container:
        # Create three columns for uploading train, val, and test files
        col1, col2, col3 = st.columns(3)
        with col1:
            utils.display_file_uploader(
                "uploaded_train_files",
                "Upload Training Images and Labels",
                st.session_state["file_uploader_train_key_training"],
                st.session_state["uploaded_files_cache_processing"],
            )
        with col2:
            utils.display_file_uploader(
                "uploaded_val_files",
                "Upload Validation Images and Labels",
                st.session_state["file_uploader_val_key_training"],
                st.session_state["uploaded_files_cache_processing"],
            )
        with col3:
            utils.display_file_uploader(
                "uploaded_test_files",
                "Upload Test Images and Labels",
                st.session_state["file_uploader_test_key_training"],
                st.session_state["uploaded_files_cache_processing"],
            )
    # Input is usable only when the train and validation sets are non-empty
    # (the test set is optional)
    pct_condition_check = (
        len(st.session_state["uploaded_train_files"]) > 0
        and len(st.session_state["uploaded_val_files"]) > 0
    )
    if not pct_condition_check:
        file_uploader_container.warning(
            "The train and validation set should not be empty.",
            icon="⚠️",
        )
    # Button to trigger validation
    if validate_button_col.button("Validate Input", use_container_width=True):
        st.session_state["validation_triggered"] = True
        st.session_state["is_valid"] = model_training_functions.check_valid_labels(
            st.session_state["uploaded_train_files"]
            + st.session_state["uploaded_val_files"]
            + st.session_state["uploaded_test_files"],
            label_type,
            class_dict,
        )
        if st.session_state["is_valid"]:
            # Write the YOLO config, then save each uploaded set to its own folder
            model_training_functions.create_yolo_config_file(
                model_training_functions.get_path("config"),
                class_labels,
            )
            model_training_functions.clear_data_folders()
            model_training_functions.save_files_to_folder(
                st.session_state["uploaded_train_files"], "train"
            )
            model_training_functions.save_files_to_folder(
                st.session_state["uploaded_val_files"], "val"
            )
            # Only save test files if they are uploaded
            if st.session_state["uploaded_test_files"]:
                model_training_functions.save_files_to_folder(
                    st.session_state["uploaded_test_files"], "test"
                )

# Shared final step for both split methods: start training once validation was
# triggered and both the split checks and the label checks passed.
if (
    st.session_state["validation_triggered"]
    and pct_condition_check
    and st.session_state["is_valid"]
):
    model_training_functions.start_yolo_training(selected_training, class_labels)
else:
    # Display a warning message if the validation is not successful or conditions are not met
    st.warning(
        "Please upload valid input, select valid parameters, and click **Validate Input**.",
        icon="⚠️",
    )