Spaces:
Running
Running
# Set the page config | |
import streamlit as st | |
# Page-level settings for the "Model Training" page. The call is kept ahead of
# every other st.* command on the page.
page_settings = {
    "page_title": "Model_Training",
    "page_icon": ":open_file_folder:",
    "layout": "wide",
    "initial_sidebar_state": "collapsed",
}
st.set_page_config(**page_settings)
# Importing necessary libraries | |
import utils | |
import streamlit as st | |
import Functions.model_training_functions as model_training_functions | |
# Page heading
st.title("Model Training")

# NOTE: clearing the session state on the first load of the page is disabled.
# utils.clear_session_state_on_first_load("model_training_clear")

# Names of the session-state entries that must exist before any widget is built.
# The value stored under each name is later passed to a widget as its `key`.
session_state_keys = [
    "file_uploader_split_key_training",
    "file_uploader_train_key_training",
    "file_uploader_val_key_training",
    "file_uploader_test_key_training",
    "number_input_train_key",
    "number_input_val_key",
    "number_input_test_key",
    "split_method_key",
    "training_type_key",
    "class_labels_input_key_training",
]

# Seed every entry that is still missing with a one-item dict {name: True};
# entries that already exist are left untouched.
for widget_key_name in session_state_keys:
    if widget_key_name in st.session_state:
        continue
    st.session_state[widget_key_name] = {widget_key_name: True}

# Boolean status flags start out as False when they are not present yet.
for flag_name in (
    "validation_triggered",
    "uploaded_files_cache_processing",
    "is_valid",
):
    if flag_name not in st.session_state:
        st.session_state[flag_name] = False
# Reserve a container now; the file uploaders are added to it further down,
# once the split method is known.
file_uploader_container = st.container()

# User-facing training types mapped to the technical label type of each one
label_type_mapping = {
    "Object Detection": "Bboxes",
    "Instance Segmentation": "Masks",
}

# Two side-by-side columns: training type on the left, split method on the right
column_select_training, column_split_method = st.columns(2)

# Training type dropdown; changing it resets the validation trigger
training_type_options = list(label_type_mapping)
selected_training = column_select_training.selectbox(
    "Select the training type:",
    training_type_options,
    index=0,
    on_change=utils.reset_validation_trigger,
    key=st.session_state["training_type_key"],
)

# Technical label type that belongs to the chosen training type
label_type = label_type_mapping[selected_training]

# Split method selector; changing it resets the validation trigger as well
split_method = column_split_method.radio(
    "Select the dataset split method:",
    ["Percentage Split", "Direct Upload"],
    horizontal=True,
    on_change=utils.reset_validation_trigger,
    key=st.session_state["split_method_key"],
)
# Text area where the user enters the class labels (pre-filled with sample values);
# editing it resets the validation trigger.
class_labels_input = st.text_area(
    "Enter class labels, separated by commas:",
    utils.sample_class_labels,
    on_change=utils.reset_validation_trigger,
    key=st.session_state["class_labels_input_key_training"],
)

# Remove unnecessary whitespace from the start and end of the raw text
class_labels_input = class_labels_input.strip()

# Split on commas, trim every piece and drop the empty ones, so input such as
# "cat, ,dog," yields ["cat", "dog"].
class_labels = [
    stripped_label
    for stripped_label in (piece.strip() for piece in class_labels_input.split(","))
    if stripped_label
]

# Map class IDs (0, 1, 2, ...) to labels, and labels back to class IDs
class_dict = dict(enumerate(class_labels))
class_names_to_ids = {label: class_id for class_id, label in class_dict.items()}

# Splitting and stripping a string cannot raise, so the former `except` branch was
# unreachable and its warning could never be shown. Input that produces no labels
# at all is the case that needs reporting; `class_labels`, `class_dict` and
# `class_names_to_ids` are simply empty then, and always defined.
if not class_labels:
    st.warning(
        "Invalid format for class labels. Please enter labels separated by commas.",
        icon="⚠️",
    )
# Usage hints shown to the user. The text is raw HTML, so it is rendered with
# unsafe_allow_html enabled.
USER_NOTE_HTML = """
<div style='text-align: justify;'>
<b>Note to Users:</b>
<ul>
<li>When moving to another page or if you wish to upload a new set of images and labels, don't forget to hit the <b>Reset</b> button. This helps in faster computation and frees up unused memory, ensuring smoother operation.</li>
<li>Select the training type, class labels, dataset split method and its parameters before uploading large data for faster computation and more efficient processing.</li>
</ul>
</div>
"""
st.markdown(USER_NOTE_HTML, unsafe_allow_html=True)
# Two side-by-side columns: one for the validate button, one for the reset button
validate_button_col, reset_button_col = st.columns(2)

# Everything below runs only on the script run in which "Reset" was clicked
if reset_button_col.button("Reset", use_container_width=True):
    # Wipe the output folder and the data folders
    output_folder = model_training_functions.get_path("output")
    model_training_functions.delete_and_recreate_folder(output_folder)
    model_training_functions.clear_data_folders()

    # Session-state entries that survive the reset (with a flipped value)
    session_state_keys = [
        "file_uploader_split_key_training",
        "file_uploader_train_key_training",
        "file_uploader_val_key_training",
        "file_uploader_test_key_training",
        "number_input_train_key",
        "number_input_val_key",
        "number_input_test_key",
        "split_method_key",
        "training_type_key",
        "class_labels_input_key_training",
    ]

    # Flip the boolean held inside each entry. The entry is handed to a widget
    # as its `key`, so a different value gives that widget a different key.
    for session_state_key in session_state_keys:
        previous_flag = st.session_state[session_state_key][session_state_key]
        st.session_state[session_state_key] = {session_state_key: not previous_flag}

    # Remove every other entry from the session state
    stale_entries = [
        name for name in list(st.session_state.keys()) if name not in session_state_keys
    ]
    for stale_entry in stale_entries:
        del st.session_state[stale_entry]

    # Delete every module-level name except underscore-prefixed ones and `st`.
    # Only `st` is used after this point.
    # NOTE(review): presumably meant to release memory before the rerun - confirm
    # it is still needed.
    vars_to_delete = [
        name
        for name in list(globals())
        if not name.startswith("_") and name != "st"
    ]
    for var in vars_to_delete:
        del globals()[var]

    # Empty both Streamlit caches
    st.cache_resource.clear()
    st.cache_data.clear()

    # Start a fresh script run so the page reflects the reset state
    st.rerun()
def _validate_and_prepare(files_to_check, inputs_ok):
    """Handle a click on "Validate Input" and report whether files may be written.

    Marks validation as triggered, checks the labels of ``files_to_check`` against
    the selected label type and class dictionary, and stores the outcome in
    ``st.session_state["is_valid"]``. The outcome is only positive when the
    split/upload precondition ``inputs_ok`` holds too, so the data folders are
    never rewritten from inputs that were already reported as invalid.

    When the outcome is positive, the YOLO config file is written and the data
    folders are cleared, ready for the caller to save the dataset files.

    Args:
        files_to_check: Uploaded files whose labels are validated.
        inputs_ok: Result of the branch-specific precondition check.

    Returns:
        bool: True when the caller should go on and save the dataset files.
    """
    st.session_state["validation_triggered"] = True
    labels_ok = model_training_functions.check_valid_labels(
        files_to_check, label_type, class_dict
    )
    st.session_state["is_valid"] = bool(labels_ok) and bool(inputs_ok)
    if not st.session_state["is_valid"]:
        return False

    model_training_functions.create_yolo_config_file(
        model_training_functions.get_path("config"),
        class_labels,
    )
    model_training_functions.clear_data_folders()
    return True


def _train_or_prompt(inputs_ok):
    """Start YOLO training when everything is validated, otherwise show a hint.

    Args:
        inputs_ok: Result of the branch-specific precondition check.
    """
    ready = (
        st.session_state["validation_triggered"]
        and inputs_ok
        and st.session_state["is_valid"]
    )
    if ready:
        model_training_functions.start_yolo_training(selected_training, class_labels)
    else:
        # Validation has not succeeded yet or the preconditions are not met
        st.warning(
            "Please upload valid input, select valid parameters, and click **Validate Input**.",
            icon="⚠️",
        )


# "Percentage Split": one upload that is divided into train/test/validation sets
if split_method == "Percentage Split":
    with file_uploader_container:
        # Single uploader for all images and labels
        utils.display_file_uploader(
            "uploaded_files",
            "Choose images and labels...",
            st.session_state["file_uploader_split_key_training"],
            st.session_state["uploaded_files_cache_processing"],
        )

        # NOTE(review): the source's indentation was lost, so placing the
        # percentage inputs inside the uploader container (as the "Direct Upload"
        # branch does with its columns) is an assumption - confirm the layout.
        col1, col2, col3 = st.columns(3)

        # Split percentages. Each input uses the key that matches its own set
        # (the test and validation inputs previously had each other's key).
        train_pct = col1.number_input(
            "Train Set Percentage",
            min_value=0,
            max_value=100,
            value=70,
            step=1,
            on_change=utils.reset_validation_trigger,
            key=st.session_state["number_input_train_key"],
        )
        test_pct = col2.number_input(
            "Test Set Percentage",
            min_value=0,
            max_value=100,
            value=15,
            step=1,
            on_change=utils.reset_validation_trigger,
            key=st.session_state["number_input_test_key"],
        )
        val_pct = col3.number_input(
            "Validation Set Percentage",
            min_value=0,
            max_value=100,
            value=15,
            step=1,
            on_change=utils.reset_validation_trigger,
            key=st.session_state["number_input_val_key"],
        )

    # The three percentages must add up to 100, the train and validation sets
    # must be non-empty, and check_min_images must accept the number of uploads.
    pct_check = train_pct + test_pct + val_pct
    pct_condition_check = (
        pct_check == 100
        and train_pct > 0
        and val_pct > 0
        and model_training_functions.check_min_images(
            len(st.session_state["uploaded_files"]), train_pct, val_pct, test_pct
        )
    )
    if not pct_condition_check:
        file_uploader_container.warning(
            "The percentages for train, test, and validation sets should add up to 100%, and train and validation set should not be empty.",
            icon="⚠️",
        )

    # Validate on click; split and save the files only when validation passed
    if validate_button_col.button("Validate Input", use_container_width=True):
        if _validate_and_prepare(
            st.session_state["uploaded_files"], pct_condition_check
        ):
            paired_files = model_training_functions.pair_files(
                st.session_state["uploaded_files"]
            )
            # NOTE(review): passes train_pct and test_pct as in the original;
            # confirm split_and_save_files does not expect val_pct here.
            model_training_functions.split_and_save_files(
                paired_files, train_pct, test_pct
            )

    _train_or_prompt(pct_condition_check)

# "Direct Upload": separate uploads for the train, validation and test sets
elif split_method == "Direct Upload":
    with file_uploader_container:
        # One column and one uploader per dataset part
        col1, col2, col3 = st.columns(3)
        upload_slots = (
            (
                col1,
                "uploaded_train_files",
                "Upload Training Images and Labels",
                "file_uploader_train_key_training",
            ),
            (
                col2,
                "uploaded_val_files",
                "Upload Validation Images and Labels",
                "file_uploader_val_key_training",
            ),
            (
                col3,
                "uploaded_test_files",
                "Upload Test Images and Labels",
                "file_uploader_test_key_training",
            ),
        )
        for slot_column, state_name, prompt, key_name in upload_slots:
            with slot_column:
                utils.display_file_uploader(
                    state_name,
                    prompt,
                    st.session_state[key_name],
                    st.session_state["uploaded_files_cache_processing"],
                )

    # Train and validation uploads are mandatory; the test upload is optional
    pct_condition_check = (
        len(st.session_state["uploaded_train_files"]) > 0
        and len(st.session_state["uploaded_val_files"]) > 0
    )
    if not pct_condition_check:
        file_uploader_container.warning(
            "The train and validation set should not be empty.",
            icon="⚠️",
        )

    # Validate on click; save the files only when validation passed
    if validate_button_col.button("Validate Input", use_container_width=True):
        all_uploaded_files = (
            st.session_state["uploaded_train_files"]
            + st.session_state["uploaded_val_files"]
            + st.session_state["uploaded_test_files"]
        )
        if _validate_and_prepare(all_uploaded_files, pct_condition_check):
            model_training_functions.save_files_to_folder(
                st.session_state["uploaded_train_files"], "train"
            )
            model_training_functions.save_files_to_folder(
                st.session_state["uploaded_val_files"], "val"
            )
            # Only save test files if they are uploaded
            if len(st.session_state["uploaded_test_files"]) > 0:
                model_training_functions.save_files_to_folder(
                    st.session_state["uploaded_test_files"], "test"
                )

    _train_or_prompt(pct_condition_check)