# Importing necessary libraries | |
import io | |
import os | |
import utils | |
import random | |
import shutil | |
import zipfile | |
import numpy as np | |
import pandas as pd | |
import streamlit as st | |
from ultralytics import YOLO | |
import plotly.graph_objs as go | |
from onnx.defs import onnx_opset_version | |
from plotly.subplots import make_subplots | |
# Function to get the dataset directory path based on the specified path type | |
def get_path(path_type): | |
main_directory_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) | |
if path_type == "train": | |
return os.path.join( | |
main_directory_path, | |
"model_data", | |
"input_files", | |
"datasets", | |
"train", | |
) | |
elif path_type == "val": | |
return os.path.join( | |
main_directory_path, | |
"model_data", | |
"input_files", | |
"datasets", | |
"val", | |
) | |
elif path_type == "test": | |
return os.path.join( | |
main_directory_path, | |
"model_data", | |
"input_files", | |
"datasets", | |
"test", | |
) | |
elif path_type == "config": | |
return os.path.join(main_directory_path, "model_data", "input_files") | |
elif path_type == "models": | |
return os.path.join(main_directory_path, "model_data", "models") | |
elif path_type == "output": | |
return os.path.join(main_directory_path, "model_data", "output_files") | |
else: | |
raise ValueError(f"Invalid path_type: {path_type}") | |
# Function to check minimum images in training and validation set | |
def check_min_images(total_files, train_pct, val_pct, test_pct): | |
# Calculate raw counts based on percentages | |
train_count = int(total_files * train_pct / 100) | |
val_count = int(total_files * val_pct / 100) | |
test_count = int(total_files * test_pct / 100) | |
# Ensure that both train and validation have at least one file | |
if train_count < 1 or val_count < 1: | |
return False | |
return True | |
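# Worked example of the integer truncation above: with 10 uploaded pairs and a 70/20/10
# split, train_count = 7 and val_count = 2, so the check passes; with 5 pairs and a
# 90/5/5 split, val_count = int(5 * 5 / 100) = 0, so the check fails.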
# Function to clear data folders
def clear_data_folders(): | |
base_path = "./model_data/input_files/datasets" | |
for folder in ["train", "test", "val"]: | |
for subfolder in ["images", "labels"]: | |
folder_path = os.path.join(base_path, folder, subfolder) | |
if os.path.exists(folder_path): | |
shutil.rmtree(folder_path) | |
os.makedirs(folder_path, exist_ok=True) | |
# Function to pair image and label files based on their filenames
def pair_files(files): | |
paired_files = {} | |
for file in files: | |
# Split the filename into name and extension | |
file_name, file_ext = os.path.splitext(file.name) | |
# Initialize a dict for each unique file name | |
if file_name not in paired_files: | |
paired_files[file_name] = {"image": None, "label": None} | |
# Assign the file to its corresponding type (image or label) based on extension | |
if file_ext.lower() in [".jpg", ".png"]: | |
paired_files[file_name]["image"] = file | |
elif file_ext.lower() == ".txt": | |
paired_files[file_name]["label"] = file | |
return paired_files | |
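# Illustrative result (a sketch): uploading "cat.jpg" and "cat.txt" yields
#   {"cat": {"image": <UploadedFile cat.jpg>, "label": <UploadedFile cat.txt>}}
# i.e. images and labels are matched purely by their shared base filename.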
# Function to split the paired files into training, testing, and validation sets based on specified percentages and save them in corresponding folders
def split_and_save_files(paired_files, train_pct, test_pct): | |
base_path = "./model_data/input_files/datasets" | |
all_keys = list(paired_files.keys()) | |
random.shuffle(all_keys) | |
# Determine the size of each dataset split | |
total_files = len(all_keys) | |
train_size = int(total_files * train_pct / 100) | |
test_size = int(total_files * test_pct / 100) | |
# Split the file keys into training, testing, and validation sets | |
train_keys = all_keys[:train_size] | |
test_keys = all_keys[train_size : train_size + test_size] | |
val_keys = all_keys[train_size + test_size :] | |
# Iterate through each split and save the files to their respective directories | |
for folder_name, keys in zip( | |
["train", "test", "val"], [train_keys, test_keys, val_keys] | |
): | |
for key in keys: | |
image_file = paired_files[key]["image"] | |
label_file = paired_files[key]["label"] | |
# Save the image and label files if they exist | |
if image_file: | |
save_file_to_folder( | |
image_file, os.path.join(base_path, folder_name, "images") | |
) | |
if label_file: | |
save_file_to_folder( | |
label_file, os.path.join(base_path, folder_name, "labels") | |
) | |
# Function to save an individual file to a specified folder | |
def save_file_to_folder(file, folder_path): | |
os.makedirs(folder_path, exist_ok=True) | |
file_path = os.path.join(folder_path, file.name) | |
with open(file_path, "wb") as f: | |
f.write(file.getbuffer()) | |
# Function to save uploaded files to a specific folder within the base path | |
def save_files_to_folder(uploaded_files, folder_name): | |
# Define the base path for saving the files | |
base_path = "./model_data/input_files/datasets" | |
# Iterate through each uploaded file | |
for file in uploaded_files: | |
if file: | |
# Determine the file type based on file extension | |
file_type = ( | |
"images" | |
if os.path.splitext(file.name)[1].lower() in [".jpg", ".png"] | |
else "labels" | |
) | |
# Save the file to the appropriate subfolder (images or labels) | |
save_file_to_folder(file, os.path.join(base_path, folder_name, file_type)) | |
# Function to validate each line in the label file for bounding box data | |
def check_bboxes_label(label_file, class_dict): | |
for line in label_file: | |
try: | |
# Decode the line, strip whitespace, split into parts, and convert each part to float | |
class_id, x_center, y_center, width, height = map( | |
float, line.decode().strip().split() | |
) | |
# Check if bounding box coordinates and class ID are valid | |
if not ( | |
0 <= x_center <= 1 | |
and 0 <= y_center <= 1 | |
and 0 <= width <= 1 | |
and 0 <= height <= 1 | |
and class_id in class_dict.keys() | |
): | |
# Return False if any condition is not met (invalid data) | |
return False | |
except Exception as e: | |
# Return False in case of any exception (e.g., parsing error) | |
return False | |
# Return True if all lines in the label file pass the validation | |
return True | |
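# Example of a line this validator accepts (YOLO detection format, values normalized to [0, 1]):
#   "0 0.50 0.50 0.25 0.25"  -> class_id=0, x_center=0.5, y_center=0.5, width=0.25, height=0.25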
# Function to validate each line in the label file for mask data | |
def check_masks_label(label_file, class_dict): | |
for line in label_file: | |
try: | |
# Decode the line and split into parts: class ID and points | |
parts = line.decode().strip().split() | |
class_id = int( | |
parts[0] | |
) # Convert the first part to an integer for class ID | |
points = [ | |
float(p) for p in parts[1:] | |
] # Convert the remaining parts to float for coordinates | |
# Check if class ID exists in the class dictionary and all points are within [0, 1] | |
if not (class_id in class_dict.keys() and all(0 <= p <= 1 for p in points)): | |
return False # Return False if validation fails | |
except Exception as e: | |
# Return False in case of any exception (e.g., parsing error) | |
return False | |
return True # Return True if all lines in the label file pass the validation | |
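# Example of a line this validator accepts (YOLO segmentation format: class ID followed by
# normalized polygon x/y pairs), assuming class ID 1 exists in class_dict:
#   "1 0.10 0.20 0.30 0.20 0.20 0.40"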
# Function to read label from YOLO format | |
def read_label(file, selected_option, class_dict): | |
# Read the content of the file | |
file_content = file.readlines() | |
# Check and validate bounding box labels if the selected option is 'Bboxes' | |
if selected_option == "Bboxes": | |
return check_bboxes_label(file_content, class_dict) # Validate bbox labels | |
# Check and validate mask labels if the selected option is 'Masks' | |
elif selected_option == "Masks": | |
return check_masks_label(file_content, class_dict) # Validate mask labels | |
# Return False if the selected option is neither 'Bboxes' nor 'Masks' | |
return False | |
# Function to check for duplicates | |
def check_file_duplicates(file_names): | |
unique_names = set(file_names) | |
return len(unique_names) == len(file_names) | |
# Function to validate the uploaded image and label files
def validate_files(image_names, label_names): | |
# Check for duplicate filenames in both images and labels | |
if not check_file_duplicates(image_names) or not check_file_duplicates(label_names): | |
# Show warning if duplicates are found | |
st.warning( | |
"Duplicate file names detected. Please ensure each image and label has a unique name.", | |
icon="⚠️", | |
) | |
return False # Return False indicating validation failed | |
# Check if the number of images matches the number of labels | |
if len(image_names) != len(label_names): | |
# Show warning if counts don't match | |
st.warning( | |
"Count Mismatch: The number of uploaded images and labels does not match.", | |
icon="⚠️", | |
) | |
return False # Return False indicating validation failed | |
# Display a success message if the above checks pass | |
st.info( | |
f"Validated: {len(image_names)} images and labels successfully matched.", | |
icon="✅", | |
) | |
return True # Return True indicating successful validation | |
# Function to check the format of uploaded labels
def check_valid_labels(uploaded_files, selected_option, class_dict): | |
# Check if no files were uploaded | |
if len(uploaded_files) == 0: | |
st.warning("Please upload images and labels.", icon="⚠️") | |
return False | |
# Initialize lists to store names of image and label files | |
image_names, label_names = [], [] | |
# Initialize a progress bar and progress text | |
progress_bar = st.progress(0) | |
progress_text = st.empty() | |
total_files = len(uploaded_files) | |
# Iterate over each uploaded file | |
for index, file in enumerate(uploaded_files): | |
# Reset the file pointer to the beginning | |
file.seek(0) | |
# Check file type and categorize as image or label | |
if file.type in ["image/jpeg", "image/png"]: | |
# Add to image names list if file is an image | |
image_names.append(file.name) | |
elif file.type == "text/plain": | |
# Read and validate label file | |
if not read_label(file, selected_option, class_dict): | |
# Show warning if label format or data is invalid | |
st.warning( | |
f"Invalid label format or data in file: {file.name}", icon="⚠️" | |
) | |
return False | |
# Add to label names list if file is a valid label | |
label_names.append(file.name) | |
# Update progress bar and display current progress | |
progress_percentage = (index + 1) / total_files | |
progress_bar.progress(progress_percentage) | |
progress_text.text(f"Validating file {index + 1} of {total_files}") | |
# Remove progress bar and progress text after processing | |
progress_bar.empty() | |
progress_text.empty() | |
# Validate if all images have corresponding labels and vice versa | |
return validate_files(image_names, label_names) | |
# Function to get training, validation and export configurations | |
def get_training_validation_export_configuration(selected_training): | |
with st.expander("Training Configuration"): | |
# User Instruction for Default Values | |
st.markdown( | |
""" | |
<div style='text-align: justify;'> | |
<b>User Instructions:</b> If you are unsure about the specific values to use for training parameters, it is | |
recommended to stick with the default values provided. These defaults are carefully chosen to provide a good balance | |
between performance and resource utilization for most scenarios. You can always come back and tweak these settings | |
once you have more experience or specific requirements for your model training. | |
</div> | |
""", | |
unsafe_allow_html=True, | |
) | |
# Padding | |
utils.top_padding(2) | |
# Training Configuration | |
st.markdown("### Training Configuration") | |
# Model Selection | |
st.write("**Model Selection**") | |
selected_model = st.selectbox( | |
"Choose a YOLOv8 model variant", list(utils.models_info.keys()) | |
) | |
model_spec = utils.models_info[selected_model] | |
spec_string = ( | |
"<div style='text-align: justify;'>" | |
f"The selected model, <b>{selected_model}</b>, is benchmarked on an image size of 640x640 pixels. It has a Mean Average Precision (mAPval) of <b>{model_spec['mAPval']}</b>, " | |
f"operates with a speed of <b>{model_spec['speed_cpu']} ms</b> on CPU (ONNX) and <b>{model_spec['speed_gpu']} ms</b> on GPU (TensorRT). " | |
f"It consists of approximately <b>{model_spec['params']} million</b> parameters and requires about <b>{model_spec['flops']} billion</b> Floating Point Operations (FLOPs)." | |
"</div>" | |
) | |
st.markdown(spec_string, unsafe_allow_html=True) | |
# Spacer | |
st.markdown("---") | |
# Time Configuration | |
st.write("**Time Configuration**") | |
col1_time, col2_time = st.columns([1, 3]) | |
with col1_time: | |
top_padding_time = st.container() | |
time_allow = st.checkbox("Enable Time", value=False) | |
if time_allow: | |
with top_padding_time: | |
utils.top_padding(2) | |
time = col2_time.number_input( | |
"Time (hours)", min_value=1, max_value=100, value=1, step=1 | |
) | |
else: | |
time = None | |
st.markdown( | |
"<div style='text-align: justify;'>Set the training duration in hours. This option overrides the epochs setting. Useful for limiting training time in scenarios with constrained resources.</div>", | |
unsafe_allow_html=True, | |
) | |
# Spacer | |
st.markdown("---") | |
# Epochs Configuration | |
st.write("**Epochs Configuration**") | |
epochs = st.number_input( | |
"Epochs", min_value=1, max_value=1000, value=50, step=10 | |
) | |
st.markdown( | |
"<div style='text-align: justify;'>Define the number of epochs for the training process. An epoch represents a complete pass over the entire dataset. More epochs can improve accuracy but increase training time.</div>", | |
unsafe_allow_html=True, | |
) | |
# Spacer | |
st.markdown("---") | |
# Patience Configuration | |
st.write("**Patience Configuration**") | |
col1_patience, col2_patience = st.columns([1, 3]) | |
with col1_patience: | |
top_padding_patience = st.container() | |
patience_allow = st.checkbox("Enable Patience", value=False) | |
if patience_allow: | |
with top_padding_patience: | |
utils.top_padding(2) | |
patience = col2_patience.number_input( | |
"Patience (epochs)", min_value=5, max_value=50, value=5, step=1 | |
) | |
else: | |
patience = None | |
st.markdown( | |
"<div style='text-align: justify;'>Configure the early stopping mechanism. Patience denotes the number of epochs to wait for improvement in performance before stopping the training, helping to avoid overfitting.</div>", | |
unsafe_allow_html=True, | |
) | |
# Spacer | |
st.markdown("---") | |
# Batch Size Configuration | |
st.write("**Batch Size Configuration**") | |
batch = st.number_input( | |
"Batch Size", min_value=-1, max_value=128, value=-1, step=1 | |
) | |
st.markdown( | |
"<div style='text-align: justify;'>Determine the number of images processed together in one pass (batch). A larger batch size can lead to faster training but requires more memory. Use -1 for automatic batch sizing.</div>", | |
unsafe_allow_html=True, | |
) | |
# Spacer | |
st.markdown("---") | |
# Image Size Configuration | |
st.write("**Image Size Configuration**") | |
imgsz = st.number_input( | |
"Image Size (pixels)", min_value=64, max_value=4096, value=640, step=32 | |
) | |
st.markdown( | |
"<div style='text-align: justify;'>Specify the size of the input images. Larger images can capture more details but require more computational resources. The size is typically a square dimension, like 640x640 pixels.</div>", | |
unsafe_allow_html=True, | |
) | |
# Spacer | |
st.markdown("---") | |
# Cache Configuration | |
st.write("**Cache Configuration**") | |
cache = st.selectbox("Cache Option", ["False", "True/ram", "disk"]) | |
st.markdown( | |
"<div style='text-align: justify;'>Choose a caching method for data loading to speed up training. 'True/ram' caches data in RAM, 'disk' caches on disk, and 'False' disables caching.</div>", | |
unsafe_allow_html=True, | |
) | |
# Spacer | |
st.markdown("---") | |
# Optimizer Configuration | |
st.write("**Optimizer Configuration**") | |
optimizer = st.selectbox( | |
"Optimizer", | |
["SGD", "Adam", "Adamax", "AdamW", "NAdam", "RAdam", "RMSProp", "auto"], | |
index=7, | |
) | |
st.markdown( | |
"<div style='text-align: justify;'>Select the optimizer for training. The optimizer adjusts weights to minimize the loss function. Choices include SGD, Adam, and others, with 'auto' selecting automatically based on the model.</div>", | |
unsafe_allow_html=True, | |
) | |
# Spacer | |
st.markdown("---") | |
# AMP Configuration | |
st.write("**AMP Configuration**") | |
amp = st.checkbox("Enable AMP", value=True) | |
st.markdown( | |
"<div style='text-align: justify;'>Enable Automatic Mixed Precision (AMP) to accelerate training on compatible hardware. AMP uses lower precision to reduce memory usage and speed up computations.</div>", | |
unsafe_allow_html=True, | |
) | |
# Spacer | |
st.markdown("---") | |
# Deterministic Mode Configuration | |
st.write("**Deterministic Mode Configuration**") | |
deterministic = st.checkbox("Enable Deterministic Mode", value=False) | |
st.markdown( | |
"<div style='text-align: justify;'>Activate deterministic mode to ensure reproducible results. This mode might slow down the training but is useful for experimentation and debugging.</div>", | |
unsafe_allow_html=True, | |
) | |
# Spacer | |
st.markdown("---") | |
# Rectangular Training Configuration | |
st.write("**Rectangular Training Configuration**") | |
rect = st.checkbox("Enable Rectangular Training", value=False) | |
st.markdown( | |
"<div style='text-align: justify;'>Enable rectangular training to process batches with minimal padding by reshaping images. This can lead to performance improvements but may affect accuracy.</div>", | |
unsafe_allow_html=True, | |
) | |
# Spacer | |
st.markdown("---") | |
# Cosine Learning Rate Scheduler Configuration | |
st.write("**Cosine Learning Rate Scheduler**") | |
cos_lr = st.checkbox("Use Cosine LR Scheduler", value=False) | |
st.markdown( | |
"<div style='text-align: justify;'>Use a cosine learning rate scheduler to adjust the learning rate following a cosine curve, potentially leading to better convergence during training.</div>", | |
unsafe_allow_html=True, | |
) | |
# Spacer | |
st.markdown("---") | |
# Freeze Layer Configuration | |
st.write("**Freeze Layer Configuration**") | |
col1_freeze, col2_freeze = st.columns([1, 3]) | |
with col1_freeze: | |
top_padding_freeze = st.container() | |
freeze_allow = st.checkbox("Enable Freeze Layers", value=False) | |
if freeze_allow: | |
with top_padding_freeze: | |
utils.top_padding(2) | |
freeze = col2_freeze.number_input( | |
"Freeze Layers", | |
min_value=1, | |
max_value=1000, | |
value=10, | |
placeholder="Enter number of layers", | |
) | |
else: | |
freeze = None | |
st.markdown( | |
"<div style='text-align: justify;'>Enable freezing the initial layers of the model during training. Specify the number of layers to freeze or a comma-separated list of specific layer indices. Useful for fine-tuning pre-trained models without modifying early layers.</div>", | |
unsafe_allow_html=True, | |
) | |
# Spacer | |
st.markdown("---") | |
# Initial Learning Rate Configuration | |
st.write("**Initial Learning Rate (lr0)**") | |
lr0 = st.number_input( | |
"Initial Learning Rate (lr0)", | |
min_value=0.00001, | |
max_value=1.0, | |
value=0.01, | |
format="%.5f", | |
) | |
st.markdown( | |
"<div style='text-align: justify;'>Specify the initial learning rate (lr0) for the training process. The initial rate is crucial as it determines the starting step size for weight updates. A well-chosen initial rate helps in achieving a balance between fast convergence and overshooting the optimal solution.</div>", | |
unsafe_allow_html=True, | |
) | |
# Spacer | |
st.markdown("---") | |
# Final Learning Rate Configuration | |
st.write("**Final Learning Rate (lrf)**") | |
lrf = st.number_input( | |
"Final Learning Rate (lrf)", | |
min_value=0.00001, | |
max_value=1.0, | |
value=0.01, | |
format="%.5f", | |
) | |
st.markdown( | |
"<div style='text-align: justify;'>Determine the final learning rate, which is a factor (lrf) of the initial learning rate (lr0). This parameter is used to adjust the learning rate over the course of training, gradually decreasing it to fine-tune model weights and stabilize training as it approaches the minimum of the loss function.</div>", | |
unsafe_allow_html=True, | |
) | |
# Spacer | |
st.markdown("---") | |
# Momentum Configuration | |
st.write("**Momentum Configuration**") | |
momentum = st.number_input( | |
"Momentum", min_value=0.0, max_value=1.0, value=0.937, format="%.3f" | |
) | |
st.markdown( | |
"<div style='text-align: justify;'>Set the momentum value for the optimizer. Momentum helps in accelerating the optimizer in the relevant direction and dampens oscillations, facilitating faster convergence.</div>", | |
unsafe_allow_html=True, | |
) | |
# Spacer | |
st.markdown("---") | |
# Weight Decay Configuration | |
st.write("**Weight Decay Configuration**") | |
weight_decay = st.number_input( | |
"Weight Decay", min_value=0.0, max_value=0.1, value=0.0005, format="%.5f" | |
) | |
st.markdown( | |
"<div style='text-align: justify;'>Specify the weight decay, a regularization technique that adds a small penalty to the loss function for larger weights. It helps in preventing overfitting by encouraging simpler models.</div>", | |
unsafe_allow_html=True, | |
) | |
# Spacer | |
st.markdown("---") | |
# Warmup Epochs Configuration | |
st.write("**Warmup Epochs Configuration**") | |
warmup_epochs = st.number_input( | |
"Warmup Epochs", min_value=0.0, max_value=10.0, value=3.0, step=0.1 | |
) | |
st.markdown( | |
"<div style='text-align: justify;'>Define the number of warmup epochs. During warmup, the learning rate gradually increases to its initial value, which helps in stabilizing the training process in its early stages.</div>", | |
unsafe_allow_html=True, | |
) | |
# Spacer | |
st.markdown("---") | |
# Warmup Momentum Configuration | |
st.write("**Warmup Momentum Configuration**") | |
warmup_momentum = st.number_input( | |
"Warmup Momentum", min_value=0.0, max_value=1.0, value=0.8, format="%.1f" | |
) | |
st.markdown( | |
"<div style='text-align: justify;'>Configure the momentum during the warmup phase. A lower momentum at the start can help in stabilizing the optimization process before reaching the specified momentum for the remaining epochs.</div>", | |
unsafe_allow_html=True, | |
) | |
# Spacer | |
st.markdown("---") | |
# Warmup Bias Learning Rate Configuration | |
st.write("**Warmup Bias Learning Rate Configuration**") | |
warmup_bias_lr = st.number_input( | |
"Warmup Bias Learning Rate", | |
min_value=0.0, | |
max_value=1.0, | |
value=0.1, | |
format="%.1f", | |
) | |
st.markdown( | |
"<div style='text-align: justify;'>Adjust the bias learning rate during the warmup period. This parameter can be tuned to manage the initial learning rate specifically for the bias parameters in the early training phase.</div>", | |
unsafe_allow_html=True, | |
) | |
# Spacer | |
st.markdown("---") | |
# Box Loss Gain Configuration | |
st.write("**Box Loss Gain Configuration**") | |
box = st.number_input( | |
"Box Loss Gain", min_value=0.0, max_value=10.0, value=7.5, step=0.1 | |
) | |
st.markdown( | |
"<div style='text-align: justify;'>Configure the gain factor for the box loss. This gain helps in adjusting the importance of the box size and location accuracy in the loss function, affecting how the model prioritizes bounding box precision.</div>", | |
unsafe_allow_html=True, | |
) | |
# Spacer | |
st.markdown("---") | |
# Class Loss Gain Configuration | |
st.write("**Class Loss Gain Configuration**") | |
cls = st.number_input( | |
"Class Loss Gain", min_value=0.0, max_value=10.0, value=0.5, step=0.1 | |
) | |
st.markdown( | |
"<div style='text-align: justify;'>Set the gain factor for the class loss. This parameter scales the contribution of class prediction accuracy in the total loss, influencing how the model prioritizes correct class identification.</div>", | |
unsafe_allow_html=True, | |
) | |
# Spacer | |
st.markdown("---") | |
# DFL Loss Gain Configuration | |
st.write("**DFL Loss Gain Configuration**") | |
dfl = st.number_input( | |
"DFL Loss Gain", min_value=0.0, max_value=10.0, value=1.5, step=0.1 | |
) | |
st.markdown( | |
"<div style='text-align: justify;'>Determine the gain factor for the DFL loss. Adjusting this gain influences the model's focus on the Directional Focal Loss component, which is critical for precise object localization and classification.</div>", | |
unsafe_allow_html=True, | |
) | |
# Spacer | |
st.markdown("---") | |
# Label Smoothing Configuration | |
st.write("**Label Smoothing Configuration**") | |
label_smoothing = st.number_input( | |
"Label Smoothing (fraction)", | |
min_value=0.0, | |
max_value=1.0, | |
value=0.0, | |
format="%.1f", | |
) | |
st.markdown( | |
"<div style='text-align: justify;'>Specify the label smoothing value, a technique that introduces softening to the target labels. It promotes model generalization and reduces the impact of noisy labels on the training process.</div>", | |
unsafe_allow_html=True, | |
) | |
# Spacer | |
st.markdown("---") | |
# Nominal Batch Size Configuration | |
st.write("**Nominal Batch Size Configuration**") | |
nbs = st.number_input( | |
"Nominal Batch Size", min_value=1, max_value=128, value=64, step=1 | |
) | |
st.markdown( | |
"<div style='text-align: justify;'>Set the nominal batch size, which is used for normalizing the loss. This size does not affect the actual batch size but is used to scale the loss to a standard reference batch size.</div>", | |
unsafe_allow_html=True, | |
) | |
# Spacer | |
st.markdown("---") | |
# Overlap Mask Configuration | |
st.write("**Overlap Mask Configuration**") | |
overlap_mask = st.checkbox("Masks Overlap during Training", value=True) | |
st.markdown( | |
"<div style='text-align: justify;'>Choose whether to allow masks to overlap during instance segmentation training. Overlapping can lead to more precise segmentation but may increase complexity.</div>", | |
unsafe_allow_html=True, | |
) | |
# Spacer | |
st.markdown("---") | |
# Mask Ratio Configuration | |
st.write("**Mask Ratio Configuration**") | |
mask_ratio = st.number_input( | |
"Mask Downsample Ratio", min_value=1, max_value=10, value=4, step=1 | |
) | |
st.markdown( | |
"<div style='text-align: justify;'>Set the downsample ratio for masks in instance segmentation. A higher ratio reduces the mask resolution, which can speed up computations but might decrease segmentation accuracy.</div>", | |
unsafe_allow_html=True, | |
) | |
# Spacer | |
st.markdown("---") | |
# Dropout Configuration | |
st.write("**Dropout Configuration**") | |
dropout = st.number_input( | |
"Dropout Regularization", | |
min_value=0.0, | |
max_value=1.0, | |
value=0.0, | |
format="%.1f", | |
) | |
st.markdown( | |
"<div style='text-align: justify;'>Configure the dropout rate, which randomly disables a proportion of neurons during training. This prevents the model from relying too much on certain features and promotes better generalization.</div>", | |
unsafe_allow_html=True, | |
) | |
# Spacer | |
st.markdown("---") | |
# Validation/Test Configuration | |
st.write("**Validation/Test Configuration**") | |
val = st.checkbox("Validate/Test during Training", value=True) | |
st.markdown( | |
"<div style='text-align: justify;'>Decide whether to perform validation and testing during the training process. Regular validation helps monitor model performance and adjust training accordingly.</div>", | |
unsafe_allow_html=True, | |
) | |
# Spacer | |
st.markdown("---") | |
# Save Plots Configuration | |
st.write("**Save Plots Configuration**") | |
plots = st.checkbox("Save Plots and Images during Training", value=True) | |
st.markdown( | |
"<div style='text-align: justify;'>Enable saving of plots and images during training. This feature provides visual insights into the training progress and helps in diagnosing model performance across epochs.</div>", | |
unsafe_allow_html=True, | |
) | |
# Padding | |
utils.top_padding(2) | |
with st.expander("Validation Configuration"): | |
# User Instruction for Default Values | |
st.markdown( | |
""" | |
<div style='text-align: justify;'> | |
<b>User Instructions:</b> If you are unsure about the specific values to use for validation parameters, it is | |
recommended to stick with the default values provided. These defaults are carefully chosen to provide a good balance | |
between performance and resource utilization for most scenarios. You can always come back and tweak these settings | |
once you have more experience or specific requirements for your model validation. | |
</div> | |
""", | |
unsafe_allow_html=True, | |
) | |
# Padding | |
utils.top_padding(2) | |
# Validation Configuration | |
st.markdown("### Validation Configuration") | |
# Object Confidence Threshold | |
st.write("**Object Confidence Threshold**") | |
conf = st.number_input( | |
"Confidence Threshold", | |
min_value=0.0, | |
max_value=1.0, | |
value=0.001, | |
format="%.3f", | |
) | |
st.markdown( | |
"<div style='text-align: justify;'>Set the confidence threshold for object detection. This threshold filters out detections with lower confidence, reducing false positives and focusing on more likely object detections.</div>", | |
unsafe_allow_html=True, | |
) | |
# Spacer | |
st.markdown("---") | |
# Intersection Over Union (IoU) Threshold | |
st.write("**IoU Threshold for NMS**") | |
iou = st.number_input( | |
"IoU Threshold", min_value=0.0, max_value=1.0, value=0.6, format="%.1f" | |
) | |
st.markdown( | |
"<div style='text-align: justify;'>Define the IoU threshold for Non-Maximum Suppression. NMS is used to refine the bounding boxes by eliminating redundancies and retaining the most probable ones.</div>", | |
unsafe_allow_html=True, | |
) | |
# Spacer | |
st.markdown("---") | |
# Maximum Number of Detections | |
st.write("**Maximum Number of Detections**") | |
max_det = st.number_input( | |
"Max Detections", min_value=1, max_value=1000, value=300, step=1 | |
) | |
st.markdown( | |
"<div style='text-align: justify;'>Limit the maximum number of detections per image. This setting is crucial for controlling the computational load and focusing the model on the most confident and relevant detections.</div>", | |
unsafe_allow_html=True, | |
) | |
# Spacer | |
st.markdown("---") | |
# Use Half Precision | |
st.write("**Use Half Precision (FP16)**") | |
half = st.checkbox("Enable Half Precision", value=True) | |
st.markdown( | |
"<div style='text-align: justify;'>Enable half precision (FP16) training for enhanced performance on compatible GPUs. It reduces memory requirements and accelerates computation, beneficial for larger models and datasets.</div>", | |
unsafe_allow_html=True, | |
) | |
# Padding | |
utils.top_padding(2) | |
with st.expander("Export Configuration"): | |
# User Instruction for Default Values | |
st.markdown( | |
""" | |
<div style='text-align: justify;'> | |
<b>User Instructions:</b> If you are unsure about the specific values to use for export parameters, it is | |
recommended to stick with the default values provided. These defaults are carefully chosen to provide a good balance | |
between performance and resource utilization for most scenarios. You can always come back and tweak these settings | |
once you have more experience or specific requirements for your model export. | |
</div> | |
""", | |
unsafe_allow_html=True, | |
) | |
# Padding | |
utils.top_padding(2) | |
# Export Configuration
st.markdown("### Export Configuration") | |
# Select Export Format | |
st.write("**Export Format**") | |
export_format = st.selectbox( | |
"Select Export Format", | |
[ | |
"Only PyTorch", | |
"TorchScript", | |
"ONNX", | |
"OpenVINO", | |
"TensorRT", | |
"CoreML", | |
"TF SavedModel", | |
"TF GraphDef", | |
"TF Lite", | |
"TF Edge TPU", | |
"TF.js", | |
"PaddlePaddle", | |
"ncnn", | |
], | |
) | |
# Dynamically generate description | |
if export_format == "Only PyTorch": | |
st.markdown( | |
""" | |
<div style='text-align: justify;'> | |
You have selected <b>PyTorch</b> as the export format. | |
This will export the model in the standard PyTorch <code>.pt</code> format. | |
There are no additional format-specific parameters to consider for this selection. | |
The exported model will be the same as selected during training. | |
</div> | |
""", | |
unsafe_allow_html=True, | |
) | |
else: | |
format_info = utils.export_formats[export_format] | |
# Handling additional arguments | |
if len(format_info["arguments"]) > 0: | |
additional_arguments = ", ".join(format_info["arguments"]) | |
arguments_info = f"Consider the following arguments for the <b>{export_format}</b> format: {additional_arguments}." | |
else: | |
arguments_info = ( | |
"No additional parameters need to be considered for this format." | |
) | |
st.markdown( | |
f""" | |
<div style='text-align: justify;'> | |
You have selected <b>{export_format}</b> as the export format. Along with the PyTorch model, | |
this selection will also export the model in the <b>{export_format}</b> format. The image size of | |
the exported model will be the same as selected during training. {arguments_info} | |
</div> | |
""", | |
unsafe_allow_html=True, | |
) | |
# Spacer | |
st.markdown("---") | |
# Use Keras for TF SavedModel export | |
st.write("**Use Keras for TF SavedModel Export**") | |
keras = st.checkbox("Enable Keras", value=False) | |
st.markdown( | |
"<div style='text-align: justify;'>Enabling Keras optimizes the TensorFlow SavedModel export for compatibility with the Keras API, making it easier to work with in Keras-centric workflows.</div>", | |
unsafe_allow_html=True, | |
) | |
# Spacer | |
st.markdown("---") | |
# Optimize for mobile (TorchScript) | |
st.write("**Optimize TorchScript for Mobile**") | |
optimize = st.checkbox("Enable Optimization", value=False) | |
st.markdown( | |
"<div style='text-align: justify;'>Optimizing for mobile reduces the model size and computational needs, enhancing performance on mobile devices with limited resources.</div>", | |
unsafe_allow_html=True, | |
) | |
# Spacer | |
st.markdown("---") | |
# FP16 quantization | |
st.write("**FP16 Quantization**") | |
half_export = st.checkbox("Enable FP16 Quantization", value=False)  # distinct name so the validation 'half' setting above is not overwritten
st.markdown( | |
"<div style='text-align: justify;'>FP16 quantization reduces model size and speeds up inference, especially on GPUs with Tensor Cores, while maintaining model accuracy.</div>", | |
unsafe_allow_html=True, | |
) | |
# Spacer | |
st.markdown("---") | |
# INT8 quantization | |
st.write("**INT8 Quantization**") | |
int8 = st.checkbox("Enable INT8 Quantization", value=False) | |
st.markdown( | |
"<div style='text-align: justify;'>INT8 quantization further reduces model size and inference time, ideal for edge devices, at the cost of a slight decrease in accuracy.</div>", | |
unsafe_allow_html=True, | |
) | |
# Spacer | |
st.markdown("---") | |
# Dynamic axes for ONNX/TensorRT | |
st.write("**Dynamic Axes for ONNX/TensorRT**") | |
dynamic = st.checkbox("Enable Dynamic Axes", value=False) | |
st.markdown( | |
"<div style='text-align: justify;'>Dynamic axes allow the ONNX/TensorRT models to handle variable input sizes, increasing the model's flexibility in deployment.</div>", | |
unsafe_allow_html=True, | |
) | |
# Spacer | |
st.markdown("---") | |
# Simplify model for ONNX/TensorRT | |
st.write("**Simplify Model for ONNX/TensorRT**") | |
simplify = st.checkbox("Enable Model Simplification", value=False) | |
st.markdown( | |
"<div style='text-align: justify;'>Simplification optimizes the ONNX/TensorRT models by removing redundant operations, improving efficiency without impacting accuracy.</div>", | |
unsafe_allow_html=True, | |
) | |
# Spacer | |
st.markdown("---") | |
# ONNX Opset Version Configuration | |
st.write("**ONNX Opset Version Configuration**") | |
col1_opset, col2_opset = st.columns([1, 3]) | |
with col1_opset: | |
top_padding_opset = st.container() | |
opset_allow = st.checkbox("Specify Opset Version", value=False) | |
if opset_allow: | |
with top_padding_opset: | |
utils.top_padding(2) | |
# Create a range of opset versions for the dropdown | |
opset_versions = list(range(1, onnx_opset_version() + 1)) | |
with col2_opset: | |
opset = st.selectbox( | |
"Select Opset Version", | |
opset_versions, | |
index=len(opset_versions) - 1, | |
) | |
else: | |
opset = None | |
st.markdown( | |
"<div style='text-align: justify;'>Select the ONNX opset version for the export. " | |
"Specifying an opset version can ensure compatibility with specific ONNX versions. " | |
"The latest version is recommended to ensure the most up-to-date features and optimizations. " | |
"If unsure, leave the checkbox unchecked to use the default opset version.</div>", | |
unsafe_allow_html=True, | |
) | |
# Spacer | |
st.markdown("---") | |
# TensorRT workspace size | |
st.write("**TensorRT Workspace Size (GB)**") | |
workspace = st.number_input( | |
"Workspace Size", min_value=1, max_value=32, value=4, step=1 | |
) | |
st.markdown( | |
"<div style='text-align: justify;'>Set the TensorRT workspace size in GB. A larger workspace can lead to more optimized models but requires more memory.</div>", | |
unsafe_allow_html=True, | |
) | |
# Spacer | |
st.markdown("---") | |
# Add NMS for CoreML | |
st.write("**Add NMS for CoreML**") | |
nms = st.checkbox("Enable NMS", value=False) | |
st.markdown( | |
"<div style='text-align: justify;'>Enabling NMS (Non-Maximum Suppression) for CoreML models helps in reducing overlapping bounding boxes and improves the clarity of object detection results.</div>", | |
unsafe_allow_html=True, | |
) | |
# Padding | |
utils.top_padding(2) | |
if selected_training == "Object Detection": | |
model_path = os.path.join( | |
get_path("models"), selected_model.lower() + ".pt" | |
) | |
task = "detect" | |
elif selected_training == "Instance Segmentation": | |
model_path = os.path.join( | |
get_path("models"), selected_model.lower() + "-seg.pt" | |
) | |
task = "segment" | |
export_settings = { | |
"format": None if export_format == "Only PyTorch" else export_format, | |
"keras": keras, | |
"optimize": optimize, | |
"half": half, | |
"int8": int8, | |
"dynamic": dynamic, | |
"simplify": simplify, | |
"opset": opset, | |
"workspace": workspace, | |
"nms": nms, | |
} | |
return { | |
"model_path": model_path, | |
"task": task, | |
"model": selected_model, | |
"time": time, | |
"epochs": epochs, | |
"patience": patience, | |
"batch": batch, | |
"imgsz": imgsz, | |
"cache": cache, | |
"optimizer": optimizer, | |
"amp": amp, | |
"deterministic": deterministic, | |
"rect": rect, | |
"cos_lr": cos_lr, | |
"freeze": freeze, | |
"lr0": lr0, | |
"lrf": lrf, | |
"momentum": momentum, | |
"weight_decay": weight_decay, | |
"warmup_epochs": warmup_epochs, | |
"warmup_momentum": warmup_momentum, | |
"warmup_bias_lr": warmup_bias_lr, | |
"box": box, | |
"cls": cls, | |
"dfl": dfl, | |
"label_smoothing": label_smoothing, | |
"nbs": nbs, | |
"overlap_mask": overlap_mask, | |
"mask_ratio": mask_ratio, | |
"dropout": dropout, | |
"val": val, | |
"plots": plots, | |
"conf": conf, | |
"iou": iou, | |
"max_det": max_det, | |
"half": half, | |
"export_settings": export_settings, | |
} | |
# Function to generate python code for model training | |
def generate_python_code_model_training(training_configuration): | |
# Copy the original configuration and update with additional parameters | |
training_configuration_code = training_configuration.copy() | |
training_configuration_code["data"] = r".\config.yaml" # Path to config file | |
training_configuration_code["save_dir"] = r".\output\train" # Output directory | |
training_configuration_code["pretrained"] = True # Use a pretrained model | |
training_configuration_code["save"] = True # Save the trained model | |
training_configuration_code["save_period"] = -1 # Save period configuration | |
training_configuration_code["augment"] = False # Augmentation setting | |
training_configuration_code["seed"] = 0 # Seed for reproducibility | |
training_configuration_code["verbose"] = True # Verbose output | |
training_configuration_code["single_cls"] = False # Single class setting | |
training_configuration_code["resume"] = False # Resume training setting | |
training_configuration_code["exist_ok"] = True # Overwrite existing files | |
training_configuration_code["project"] = r".\output" # Project directory | |
training_configuration_code["name"] = "train" # Project name | |
# Extract the model name from the model path | |
model_name = os.path.basename(training_configuration_code["model_path"])  # basename works regardless of OS path separator
# Start with necessary library imports and model initialization | |
code_str = "# Importing necessary libraries\n" | |
code_str += "from ultralytics import YOLO\n\n" | |
# Initialize the YOLO model | |
code_str += f"# Initialize the YOLO model '{model_name}'\n" | |
code_str += f"model = YOLO('{model_name}')\n" | |
# Add the model training code | |
code_str += "\n# Start the training process\n" | |
code_str += "model.train(\n" | |
for key, value in training_configuration_code.items(): | |
if key not in [ | |
"model_path", | |
"model", | |
"export_settings", | |
]: # Exclude specific keys | |
code_str += f" {key}={value},\n" | |
code_str = code_str.rstrip(",\n") + "\n)\n" | |
# Add model export code | |
code_str += "\n# Model export process\n" | |
code_str += "model.export(\n" | |
for key, value in training_configuration_code["export_settings"].items(): | |
if key == "format" and value is None: | |
continue # Skip format if it's None | |
code_str += f" {key}={value},\n" | |
code_str = code_str.rstrip(",\n") + "\n)\n" | |
return code_str | |
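# Illustrative shape of the generated snippet (a sketch, assuming the nano detection
# model and mostly default settings; the real output lists every configured argument):
#   from ultralytics import YOLO
#   model = YOLO('yolov8n.pt')
#   model.train(task='detect', epochs=50, imgsz=640, data='.\\config.yaml', ...)
#   model.export(keras=False, optimize=False, ...)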
# Function to overwrite a Python file with new code
def overwrite_python_file(code_str, file_path): | |
# Open the file in write mode, which automatically deletes old content | |
with open(file_path, "w") as file: | |
file.write(code_str) | |
# Function to generate a downloadable file | |
def display_code_and_download_button(generated_code): | |
# Display the generated code in Streamlit with description and download button in columns | |
with st.expander("Plug and Play Code"): | |
col1, col2 = st.columns([7, 3]) | |
with col1: | |
st.markdown( | |
""" | |
### Description of the Code Pipeline | |
""" | |
) | |
st.markdown( | |
""" | |
<div style='text-align: justify;'> | |
This Python script is configured for training a YOLO model. It includes necessary configurations and parameters for a custom YOLO model training session. | |
**To use this script:** | |
- Ensure you have the necessary dependencies installed. | |
- Place your image and label files in the `'datasets/train'`, `'datasets/val'`, and `'datasets/test'` folders respectively. | |
- The `'config.yaml'` file and the training script are set up based on your provided configurations. | |
### Python Code | |
</div> | |
""", | |
unsafe_allow_html=True, | |
) | |
# Display python code | |
st.code(generated_code, language="python") | |
# Determine the main directory path | |
main_directory_path = os.path.dirname( | |
os.path.dirname(os.path.abspath(__file__)) | |
) | |
# Overwrites a Python file with new code | |
overwrite_python_file( | |
generated_code, | |
os.path.join( | |
main_directory_path, | |
"model_data", | |
"model_training_code_pipline", | |
"model_training.py", | |
), | |
) | |
# Determine the main directory path | |
main_directory_path = os.path.dirname( | |
os.path.dirname(os.path.abspath(__file__)) | |
) | |
# Prepare a ZIP file of the training output folder in memory for download | |
zip_bytes_io = zip_folder_to_bytesio( | |
os.path.join( | |
main_directory_path, "model_data", "model_training_code_pipline" | |
) | |
) | |
with col2: | |
# Create a button for downloading the training pipeline | |
st.download_button( | |
label="Download Training Pipeline", | |
data=zip_bytes_io, | |
file_name="model_training_code.zip", | |
mime="application/zip", | |
use_container_width=True, | |
) | |
# Function to generate a YOLO model training code snippet and display it with a download button
def generate_and_display_yolo_training_code(class_labels, training_configuration): | |
# Determine the main directory path | |
main_directory_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) | |
# Construct the path to the config file directory | |
config_file_path = os.path.join( | |
main_directory_path, "model_data", "model_training_code_pipline" | |
) | |
# Define the path to the dataset directory | |
dataset_directory_path = "./datasets" | |
# Create YOLO config file using provided class labels and dataset directory | |
create_yolo_config_file(config_file_path, class_labels, dataset_directory_path) | |
# Generate the Python code for YOLO model training | |
generated_code = generate_python_code_model_training(training_configuration) | |
# Display the generated code and a download button | |
display_code_and_download_button(generated_code) | |
# Function to create a yolo config file | |
def create_yolo_config_file( | |
config_file_path, class_labels, dataset_directory_path=None | |
): | |
if dataset_directory_path is None: | |
dataset_directory_path = os.path.join(config_file_path, "datasets") | |
# Number of classes | |
num_classes = len(class_labels) | |
# Create the configuration content | |
config_content = f"""path: {dataset_directory_path} # Path to the dataset directory | |
train: train # Path to the training set directory | |
val: val # Path to the validation set directory | |
test: test # Path to the testing set directory | |
nc: {num_classes} # Number of classes | |
names: {class_labels} # List of class names | |
""" | |
# Write the configuration to a file | |
with open(os.path.join(config_file_path, "config.yaml"), "w") as file: | |
file.write(config_content) | |
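# Illustrative config.yaml for two hypothetical classes ["cat", "dog"], with
# dataset_directory_path="./datasets" as passed by generate_and_display_yolo_training_code above:
#   path: ./datasets   # Path to the dataset directory
#   train: train       # Path to the training set directory
#   val: val           # Path to the validation set directory
#   test: test         # Path to the testing set directory
#   nc: 2              # Number of classes
#   names: ['cat', 'dog']  # List of class names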
# Function to delete and recreate a folder | |
def delete_and_recreate_folder(folder_path): | |
try: | |
# Use shutil.rmtree to delete the folder and its contents | |
shutil.rmtree(folder_path) | |
# Recreate the folder at the same location | |
os.makedirs(folder_path) | |
except Exception as e: | |
print(f"Error deleting or recreating folder {folder_path}: {e}") | |
# Function to read csv and get values | |
def read_csv_and_get_values(csv_file_path): | |
# Read the CSV file into a pandas DataFrame | |
df = pd.read_csv(csv_file_path) | |
# Initialize an empty dictionary to store the results | |
result_dict = {} | |
# Iterate through the columns of the DataFrame | |
for column in df.columns: | |
# Remove leading and trailing spaces from the column name | |
clean_column_name = column.strip() | |
# Get the values in the column | |
column_values = df[column].astype(float) | |
# Add the cleaned column name and values to the result dictionary | |
result_dict[clean_column_name] = np.array(column_values) | |
return result_dict | |
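# Illustrative return value (a sketch; the exact columns come from Ultralytics' results.csv):
#   {"epoch": array([0., 1., ...]), "train/box_loss": array([...]),
#    "val/box_loss": array([...]), "metrics/precision(B)": array([...]), ...}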
# Global variables | |
plot_container = None | |
val_dataframe_container = None | |
progress_bar = None | |
progress_text = None | |
# Function to define a custom callback function for on_pretrain_routine_start | |
def on_pretrain_routine_start(trainer): | |
global progress_text, progress_bar | |
progress_bar = st.empty() | |
progress_text = st.empty() | |
progress_text.info( | |
"Loading selected model...", | |
icon="✅", | |
) | |
# Function to define a custom callback function for on_train_start | |
def on_train_start(trainer): | |
global progress_bar, progress_text | |
progress_bar = st.progress(0) | |
progress_text.info( | |
"Training Started...", | |
icon="✅", | |
) | |
# Function to display metrics plot | |
@st.cache_resource(show_spinner=False)
def display_metrics_plot(output_data): | |
global plot_container | |
# Extract data for each metric | |
epoch_history = output_data.get("epoch") | |
# Extract loss histories | |
train_box_loss_history = output_data.get("train/box_loss") | |
train_cls_loss_history = output_data.get("train/cls_loss") | |
train_dfl_loss_history = output_data.get("train/dfl_loss") | |
train_seg_loss_history = output_data.get("train/seg_loss") | |
val_box_loss_history = output_data.get("val/box_loss") | |
val_cls_loss_history = output_data.get("val/cls_loss") | |
val_dfl_loss_history = output_data.get("val/dfl_loss") | |
val_seg_loss_history = output_data.get("val/seg_loss") | |
if train_seg_loss_history is None: | |
train_seg_loss_history = epoch_history * 0 | |
val_seg_loss_history = epoch_history * 0 | |
# Extract precision, recall, and mAP histories for B and M box/mask | |
precision_B_history = output_data.get("metrics/precision(B)") | |
recall_B_history = output_data.get("metrics/recall(B)") | |
mAP50_B_history = output_data.get("metrics/mAP50(B)") | |
mAP50_95_B_history = output_data.get("metrics/mAP50-95(B)") | |
precision_M_history = output_data.get("metrics/precision(M)") | |
recall_M_history = output_data.get("metrics/recall(M)") | |
mAP50_M_history = output_data.get("metrics/mAP50(M)") | |
mAP50_95_M_history = output_data.get("metrics/mAP50-95(M)") | |
# Check for 'None' data and adjust the number of rows in the grid | |
num_rows = 4 | |
subplot_titles = [ | |
"Precision B", | |
"Recall B", | |
"mAP50 B", | |
"mAP50-95 B", | |
"Precision R", | |
"Recall R", | |
"mAP50 R", | |
"mAP50-95 R", | |
"Train Box Loss", | |
"Train Class Loss", | |
"Train DFL Loss", | |
"Train Seg Loss", | |
"Val Box Loss", | |
"Val Class Loss", | |
"Val DFL Loss", | |
"Val Seg Loss", | |
] | |
if precision_M_history is None: | |
num_rows = 3 | |
subplot_titles = subplot_titles[0:4] + subplot_titles[8:] | |
# Create a subplot grid | |
fig = make_subplots( | |
rows=num_rows, | |
cols=4, | |
subplot_titles=subplot_titles, | |
vertical_spacing=0.05, | |
) | |
# Initialize row number | |
row_number = 1 | |
# Add precision, recall, mAP plots for B and R box/mask | |
fig.add_trace( | |
go.Scatter( | |
x=epoch_history, y=precision_B_history, mode="lines", name="Precision B" | |
), | |
row=row_number, | |
col=1, | |
) | |
fig.add_trace( | |
go.Scatter(x=epoch_history, y=recall_B_history, mode="lines", name="Recall B"), | |
row=row_number, | |
col=2, | |
) | |
fig.add_trace( | |
go.Scatter(x=epoch_history, y=mAP50_B_history, mode="lines", name="mAP50 B"), | |
row=row_number, | |
col=3, | |
) | |
fig.add_trace( | |
go.Scatter( | |
x=epoch_history, y=mAP50_95_B_history, mode="lines", name="mAP50-95 B" | |
), | |
row=row_number, | |
col=4, | |
) | |
if precision_M_history is not None: | |
# Increment row number | |
row_number += 1 | |
fig.add_trace( | |
go.Scatter( | |
x=epoch_history, y=precision_M_history, mode="lines", name="Precision R" | |
), | |
row=row_number, | |
col=1, | |
) | |
fig.add_trace( | |
go.Scatter( | |
x=epoch_history, y=recall_M_history, mode="lines", name="Recall R" | |
), | |
row=row_number, | |
col=2, | |
) | |
fig.add_trace( | |
go.Scatter( | |
x=epoch_history, y=mAP50_M_history, mode="lines", name="mAP50 R" | |
), | |
row=row_number, | |
col=3, | |
) | |
fig.add_trace( | |
go.Scatter( | |
x=epoch_history, y=mAP50_95_M_history, mode="lines", name="mAP50-95 R" | |
), | |
row=row_number, | |
col=4, | |
) | |
# Increment row number | |
row_number += 1 | |
# Add loss plots | |
fig.add_trace( | |
go.Scatter( | |
x=epoch_history, | |
y=train_box_loss_history, | |
mode="lines", | |
name="Train Box Loss", | |
), | |
row=row_number, | |
col=1, | |
) | |
fig.add_trace( | |
go.Scatter( | |
x=epoch_history, | |
y=train_cls_loss_history, | |
mode="lines", | |
name="Train Class Loss", | |
), | |
row=row_number, | |
col=2, | |
) | |
fig.add_trace( | |
go.Scatter( | |
x=epoch_history, | |
y=train_dfl_loss_history, | |
mode="lines", | |
name="Train DFL Loss", | |
), | |
row=row_number, | |
col=3, | |
) | |
fig.add_trace( | |
go.Scatter( | |
x=epoch_history, | |
y=train_seg_loss_history, | |
mode="lines", | |
name="Train Seg Loss", | |
), | |
row=row_number, | |
col=4, | |
) | |
# Increment row number | |
row_number += 1 | |
fig.add_trace( | |
go.Scatter( | |
x=epoch_history, y=val_box_loss_history, mode="lines", name="Val Box Loss" | |
), | |
row=row_number, | |
col=1, | |
) | |
fig.add_trace( | |
go.Scatter( | |
x=epoch_history, y=val_cls_loss_history, mode="lines", name="Val Class Loss" | |
), | |
row=row_number, | |
col=2, | |
) | |
fig.add_trace( | |
go.Scatter( | |
x=epoch_history, y=val_dfl_loss_history, mode="lines", name="Val DFL Loss" | |
), | |
row=row_number, | |
col=3, | |
) | |
fig.add_trace( | |
go.Scatter( | |
x=epoch_history, | |
y=val_seg_loss_history, | |
mode="lines", | |
name="Val Seg Loss", | |
), | |
row=row_number, | |
col=4, | |
) | |
# Check if the plot container is already initialized | |
if plot_container is None: | |
plot_container = st.empty() | |
# Update layout | |
fig.update_layout( | |
height=1200, | |
width=1600, | |
title_text="Metrics", | |
legend=dict(orientation="h", yanchor="bottom", xanchor="left"), | |
) | |
# Display the updated plot in the same container | |
plot_container.plotly_chart(fig, use_container_width=True) | |
# Function to define a custom callback function for on_fit_epoch_end | |
def on_fit_epoch_end(trainer): | |
current_epoch = int(trainer.epoch) | |
total_epochs = int(trainer.epochs) | |
# Define the path to the output CSV | |
output_csv_path = os.path.join(get_path("output"), "train", "results.csv") | |
# Read the CSV data | |
st.session_state["plot_data"] = read_csv_and_get_values(output_csv_path) | |
# Call a function to update the plot using this data | |
display_metrics_plot(st.session_state["plot_data"]) | |
# Update progress bar and text | |
progress_bar.progress((current_epoch + 1) / total_epochs) | |
progress_text.write(f"Epoch {(current_epoch + 1)}/{total_epochs}") | |
# Function to define a custom callback function for on_train_end | |
def on_train_end(trainer): | |
global progress_bar, progress_text | |
progress_bar.empty() | |
progress_text.info( | |
"Best and last model save completed successfully.", | |
icon="✅", | |
) | |
# Function to add various callbacks to the YOLO model for different stages of the training process | |
def callback_add(model): | |
# Add a callback to be triggered at the start of the pre-training routine | |
model.add_callback("on_pretrain_routine_start", on_pretrain_routine_start) | |
# Add a callback to be triggered at the start of the training | |
model.add_callback("on_train_start", on_train_start) | |
# Add a callback to be triggered at the end of each training epoch | |
model.add_callback("on_fit_epoch_end", on_fit_epoch_end) | |
# Add a callback to be triggered at the end of the training process | |
model.add_callback("on_train_end", on_train_end) | |
# Function to zip a folder and all its subfolders and return a BytesIO object | |
def zip_folder_to_bytesio(folder_path): | |
bytes_io = io.BytesIO() | |
with zipfile.ZipFile(bytes_io, "w", zipfile.ZIP_DEFLATED) as zipf: | |
folder_path_abs = os.path.abspath(folder_path) | |
for root, dirs, files in os.walk(folder_path): | |
# Calculate the relative path from the folder_path | |
folder_rel_path = os.path.relpath(root, folder_path_abs) | |
# If the directory is empty, add the directory itself | |
if not dirs and not files: | |
# ZIP format requires a trailing slash for empty directories | |
zip_dir_path = f"{folder_rel_path}/" if folder_rel_path != "." else "" | |
zipf.write(root, zip_dir_path) | |
for file in files: | |
file_path = os.path.join(root, file) | |
# Construct the path within the zip file | |
zip_file_path = ( | |
os.path.join(folder_rel_path, file) | |
if folder_rel_path != "." | |
else file | |
) | |
zipf.write(file_path, zip_file_path) | |
bytes_io.seek(0) # Go to the start of the BytesIO buffer | |
return bytes_io | |
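# Typical usage (mirrors display_code_and_download_button above): zip a folder in memory
# and pass the buffer straight to st.download_button.
#   zip_bytes_io = zip_folder_to_bytesio("model_data/model_training_code_pipline")
#   st.download_button("Download", data=zip_bytes_io, file_name="bundle.zip", mime="application/zip")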
# Function to display Metrics Table | |
@st.cache_resource(show_spinner=False)
def display_val_dataframe(val_dataframe): | |
global val_dataframe_container | |
# Check if the dataframe container is already initialized | |
if val_dataframe_container is None: | |
val_dataframe_container = st.container() | |
# Display the updated dataframe in the same container | |
with val_dataframe_container: | |
# Display the message to indicate that the metrics table is ready | |
st.markdown("**Metrics Table**", unsafe_allow_html=True) | |
# Display the DataFrame | |
st.dataframe(val_dataframe) | |
# Function to run validation and display the metrics DataFrame
def val_dataframe(model): | |
# Placeholder for the initial message | |
message = st.empty() | |
message.markdown("**Generating Metrics Table...**", unsafe_allow_html=True) | |
# Extract the metrics from the model | |
metrics = model.val() | |
# Extract the class indices and names | |
class_index = metrics.ap_class_index | |
class_names = metrics.names | |
# Extract precision, recall, and mAP values for the box (B) metrics | |
precision_B_values = metrics.box.p | |
recall_B_values = metrics.box.r | |
mAP50_95_B_values = [metrics.box.maps[i] for i in class_index] | |
# Check if segmentation (mask) metrics exist | |
try: | |
metrics_mask = metrics.seg | |
except: | |
metrics_mask = False | |
if metrics_mask: | |
precision_M_values = metrics_mask.p | |
recall_M_values = metrics_mask.r | |
mAP50_95_M_values = [metrics_mask.maps[i] for i in class_index] | |
    # Extract aggregated metrics from the results dictionary
    results_dict = metrics.results_dict
    # Initialize lists for overall precision, recall, and mAP for box (B)
    precision_B = [results_dict.get("metrics/precision(B)")]
    recall_B = [results_dict.get("metrics/recall(B)")]
    mAP50_95_B = [results_dict.get("metrics/mAP50-95(B)")]
    # Initialize lists for overall precision, recall, and mAP for mask (M) if available
    precision_M = [results_dict.get("metrics/precision(M)")] if metrics_mask else None
    recall_M = [results_dict.get("metrics/recall(M)")] if metrics_mask else None
    mAP50_95_M = [results_dict.get("metrics/mAP50-95(M)")] if metrics_mask else None
    # Create a list of class names starting with "All" for the overall metrics
    name_list = ["All"] + [str(class_names[i]) for i in class_index]
    # Extend the metrics lists with values for each class
    precision_B.extend(precision_B_values)
    recall_B.extend(recall_B_values)
    mAP50_95_B.extend(mAP50_95_B_values)
    # If mask metrics are available, extend their lists with values for each class
    if metrics_mask:
        precision_M.extend(precision_M_values)
        recall_M.extend(recall_M_values)
        mAP50_95_M.extend(mAP50_95_M_values)
    # Create a DataFrame with the computed metrics
    if metrics_mask:
        st.session_state["val_dataframe"] = pd.DataFrame(
            {
                "Class Name": name_list,
                "Precision (B)": precision_B,
                "Recall (B)": recall_B,
                "mAP50-95 (B)": mAP50_95_B,
                "Precision (M)": precision_M,
                "Recall (M)": recall_M,
                "mAP50-95 (M)": mAP50_95_M,
            }
        )
    else:
        st.session_state["val_dataframe"] = pd.DataFrame(
            {
                "Class Name": name_list,
                "Precision (B)": precision_B,
                "Recall (B)": recall_B,
                "mAP50-95 (B)": mAP50_95_B,
            }
        )
    # Clear the initial message
    message.empty()
    # Display the metrics DataFrame in the shared container
    display_val_dataframe(st.session_state["val_dataframe"])
# Function to train the YOLO model
def train_yolo_model(training_configuration):
    # Clear and recreate the output folder to ensure a fresh start
    delete_and_recreate_folder(get_path("output"))
    # Initialize the YOLO model with the specified path from the training configuration
    model = YOLO(training_configuration["model_path"])
    # Add any callbacks or additional configuration to the model
    callback_add(model)
    # Train the model with the specified parameters
    model.train(
        task=training_configuration["task"],
        data=os.path.join(get_path("config"), "config.yaml"),
        epochs=training_configuration["epochs"],
        time=training_configuration["time"],
        patience=training_configuration["patience"],
        batch=training_configuration["batch"],
        imgsz=training_configuration["imgsz"],
        save=True,
        save_period=-1,
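        # save_period < 1 disables periodic epoch checkpoints; with save=True only the
        # last and best weights are written to the run folder.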
        cache=training_configuration["cache"],
        pretrained=True,
        optimizer=training_configuration["optimizer"],
        verbose=True,
        seed=0,
        deterministic=training_configuration["deterministic"],
        single_cls=False,
        rect=training_configuration["rect"],
        cos_lr=training_configuration["cos_lr"],
        resume=False,
        amp=training_configuration["amp"],
        fraction=1.0,
        freeze=training_configuration["freeze"],
        lr0=training_configuration["lr0"],
        lrf=training_configuration["lrf"],
        momentum=training_configuration["momentum"],
        weight_decay=training_configuration["weight_decay"],
        warmup_epochs=training_configuration["warmup_epochs"],
        warmup_momentum=training_configuration["warmup_momentum"],
        warmup_bias_lr=training_configuration["warmup_bias_lr"],
        box=training_configuration["box"],
        cls=training_configuration["cls"],
        dfl=training_configuration["dfl"],
        label_smoothing=training_configuration["label_smoothing"],
        nbs=training_configuration["nbs"],
        overlap_mask=training_configuration["overlap_mask"],
        mask_ratio=training_configuration["mask_ratio"],
        dropout=training_configuration["dropout"],
        val=training_configuration["val"],
        plots=training_configuration["plots"],
        save_dir=os.path.join(get_path("output"), "train"),
        project=get_path("output"),
        name="train",
        augment=False,
        exist_ok=True,
    )
    return model
# Function to export the model with the given parameters
def export_model_with_parameters(model, export_params):
    global progress_text
    if export_params["format"] is not None:
        # Informing the user that the export process has started
        progress_text.info(
            "Starting the export process with the specified settings.",
            icon="✅",
        )
        # Perform the model export
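        # Note: several of these arguments are format-specific (e.g. keras for TF
        # SavedModel, opset/simplify/dynamic for ONNX, workspace/int8 for TensorRT);
        # options that do not apply to the selected format are simply ignored by the
        # Ultralytics exporter.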
        model.export(
            format=export_params["format"],
            keras=export_params["keras"],
            optimize=export_params["optimize"],
            half=export_params["half"],
            int8=export_params["int8"],
            dynamic=export_params["dynamic"],
            simplify=export_params["simplify"],
            opset=export_params["opset"],
            workspace=export_params["workspace"],
            nms=export_params["nms"],
        )
        # Informing the user that the export process has completed successfully
        progress_text.info(
            "The model has been successfully saved using the specified export settings.",
            icon="✅",
        )
# Function to start the YOLO model training process
def start_yolo_training(selected_training, class_labels):
    global plot_container, val_dataframe_container
    # Retrieve the training configuration based on the user's selection
    training_configuration = get_training_validation_export_configuration(
        selected_training
    )
    # Generate a YOLO model training code snippet and display it with a download button
    generate_and_display_yolo_training_code(class_labels, training_configuration)
    # Create two columns
    col1, col2 = st.columns(2)
    # When the "Start Training" button is clicked in the first column
    if col1.button("Start Training", use_container_width=True):
        plot_container = None
        val_dataframe_container = None
        with st.spinner("Training in Progress..."):
            # Train the YOLO model using the provided configuration
            trained_model = train_yolo_model(training_configuration)
            # Export the model with the given parameters
            export_model_with_parameters(
                trained_model, training_configuration["export_settings"]
            )
            # Display the validation results in a DataFrame after training
            val_dataframe(trained_model)
    elif "plot_data" in st.session_state and "val_dataframe" in st.session_state:
        plot_container = None
        val_dataframe_container = None
        # Display the metrics plot and table if they already exist from a previous run
        display_metrics_plot(st.session_state["plot_data"])
        display_val_dataframe(st.session_state["val_dataframe"])
    # Prepare a ZIP file of the training output folder in memory for download
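    # The archive is rebuilt on every Streamlit rerun from whatever is currently in
    # the output/train folder, so the download always reflects the latest run.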
    zip_bytes_io = zip_folder_to_bytesio(os.path.join(get_path("output"), "train"))
    # Provide a button in the second column to download the ZIP file
    col2.download_button(
        label="Download",
        data=zip_bytes_io,
        file_name="model_training_output.zip",
        mime="application/zip",
        use_container_width=True,
    )