First Commit

- .streamlit/config.toml +9 -0
- Functions/__pycache__/image_augmentation_functions.cpython-311.pyc +0 -0
- Functions/__pycache__/image_processing_functions.cpython-311.pyc +0 -0
- Functions/__pycache__/model_training_functions.cpython-311.pyc +0 -0
- Functions/image_augmentation_functions.py +1027 -0
- Functions/image_processing_functions.py +430 -0
- Functions/model_training_functions.py +1896 -0
- README.md +12 -14
- Welcome.py +8 -0
- config.toml +9 -0
- model_data/input_files/config.yaml +6 -0
- model_data/model_training_code_pipline/config.yaml +6 -0
- model_data/model_training_code_pipline/model_training.py +69 -0
- model_data/models/yolov8l-seg.pt +3 -0
- model_data/models/yolov8l.pt +3 -0
- model_data/models/yolov8m-seg.pt +3 -0
- model_data/models/yolov8m.pt +3 -0
- model_data/models/yolov8n-seg.pt +3 -0
- model_data/models/yolov8n.pt +3 -0
- model_data/models/yolov8s-seg.pt +3 -0
- model_data/models/yolov8s.pt +3 -0
- model_data/models/yolov8x-seg.pt +3 -0
- model_data/models/yolov8x.pt +3 -0
- packages.txt +1 -0
- packages_db.xlsx +0 -0
- pages/2_Image_Processing.py +512 -0
- pages/3_Image_Augmentation.py +612 -0
- pages/4_Model_Training.py +364 -0
- requirements.txt +0 -0
- utils.py +446 -0
.streamlit/config.toml
ADDED
@@ -0,0 +1,9 @@
[theme]
primaryColor = "#101276"
backgroundColor = "#FFFFFF"
textColor = "#006EC0"
secondaryBackgroundColor = "#D2E5F2"
font = "sans serif"

[server]
maxUploadSize = 10000
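
A note on this config (an assumption about Streamlit's standard behavior, not part of the commit itself): Streamlit loads `.streamlit/config.toml` automatically at startup, so the `[server]` upload cap of 10000 MB applies to every file-upload widget in the app with no extra code. A minimal sketch, with illustrative labels:

# Minimal sketch (assumes Streamlit's automatic config loading; the uploader
# label and accepted types here are illustrative, not taken from this repo).
import streamlit as st

# server.maxUploadSize = 10000 in .streamlit/config.toml raises the per-file
# upload limit from the 200 MB default to 10000 MB for widgets like this one.
uploaded_files = st.file_uploader(
    "Upload images and YOLO label files",
    type=["jpg", "jpeg", "png", "txt"],
    accept_multiple_files=True,
)
if uploaded_files:
    st.write(f"Received {len(uploaded_files)} file(s)")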
Functions/__pycache__/image_augmentation_functions.cpython-311.pyc
ADDED
Binary file (41.7 kB)
Functions/__pycache__/image_processing_functions.cpython-311.pyc
ADDED
Binary file (18.3 kB)
Functions/__pycache__/model_training_functions.cpython-311.pyc
ADDED
Binary file (73.9 kB)
Functions/image_augmentation_functions.py
ADDED
@@ -0,0 +1,1027 @@
# Importing necessary libraries
import io
import os
import cv2
import utils
import random
import zipfile
import numpy as np
import pandas as pd
from PIL import Image
import streamlit as st
import albumentations as A
import matplotlib.pyplot as plt


# Function to fetch the names of augmentation techniques applicable to the selected option
@st.cache_resource(show_spinner=False)
def get_applicable_techniques(df, option):
    return df[df[option] == "Applicable"]["Name"]


# Function to generate unique colors for each given list of unique numbers
def generate_unique_colors(unique_numbers):
    # Get a color map
    cmap = plt.get_cmap("hsv")

    # Generate unique colors using the colormap
    colors = {
        number: cmap(i / len(unique_numbers))[:3]
        for i, number in enumerate(unique_numbers)
    }

    # Convert colors from RGB to BGR format and scale to 0-255
    colors_bgr = {
        k: (int(v[2] * 255), int(v[1] * 255), int(v[0] * 255))
        for k, v in colors.items()
    }

    return colors_bgr


# Function to adjust zero values to a small positive number
def adjust_zero_value(value):
    return value if value != 0 else 1e-9


# Function to parse a YOLO label file and convert it into Albumentations-compatible Bboxes format
def bboxes_label(label_file, class_dict):
    bboxes_data = {"bboxes": [], "class_labels": []}
    for line in label_file:
        try:
            # Extracting class_id and bounding box coordinates from each line
            class_id, x_center, y_center, width, height = map(
                float, line.decode().strip().split()
            )
            class_id += 1  # Shifting its starting value to 1, since 0 is reserved for the background

            # Adjusting bounding box coordinates to avoid zero values
            # Albumentations does not accept zero values, but they are acceptable in YOLO format
            x_center, y_center, width, height = map(
                adjust_zero_value, [x_center, y_center, width, height]
            )

            # Check if values are within the expected range and class_id exists in class_dict
            if (
                0 <= x_center <= 1
                and 0 <= y_center <= 1
                and 0 <= width <= 1
                and 0 <= height <= 1
                and class_id in class_dict.keys()
            ):
                bboxes_data["bboxes"].append([x_center, y_center, width, height])
                bboxes_data["class_labels"].append(class_id)
            else:
                return None  # Return None if any value is out of range or class_id is invalid

        except Exception as e:
            # Return None if any exception is encountered
            return None

    # Return None if the file is empty or no valid data found
    return bboxes_data if bboxes_data["bboxes"] else None


# Function to parse a YOLO label file and convert it into compatible Mask format
def masks_label(label_file, class_dict):
    mask_data = {"masks": [], "class_labels": []}
    for line in label_file:
        try:
            # Clean up the line and split into parts
            parts = line.decode().strip().split()
            class_id = (
                int(parts[0]) + 1
            )  # Shifting its starting value to 1, since 0 is reserved for the background
            points = [float(p) for p in parts[1:]]

            # Check if class_id exists in class_dict and coordinates are within the expected range
            if class_id in class_dict.keys() and all(0 <= p <= 1 for p in points):
                # Group points into (x, y) tuples
                polygon = [(points[i], points[i + 1]) for i in range(0, len(points), 2)]

                # Append class label and polygon to the mask data
                mask_data["class_labels"].append(class_id)
                mask_data["masks"].append(polygon)
            else:
                return None  # Return None if class_id is invalid or coordinates are out of range

        except Exception as e:
            # Return None if any exception is encountered
            return None

    # Return None if the file is empty or no valid data found
    return mask_data if mask_data["masks"] else None


# Function to generate mask for albumentations format
def generate_mask(masks, class_ids, image_height, image_width):
    # Create an empty mask of the same size as the image, filled with 0 for background
    mask = np.full((image_height, image_width), 0, dtype=np.int32)

    # Iterate over each polygon and its corresponding class_id
    for polygon, class_id in zip(masks, class_ids):
        # Scale the polygon points to the image size
        scaled_polygon = [
            (int(x * image_width), int(y * image_height)) for x, y in polygon
        ]

        # Draw the polygon on the mask
        cv2.fillPoly(mask, [np.array(scaled_polygon, dtype=np.int32)], color=class_id)

    return mask


# Function to convert a single-channel mask back to YOLO format
def mask_to_yolo(mask):
    yolo_data = {"masks": [], "class_labels": []}
    unique_values = np.unique(mask)

    for value in unique_values:
        if value == 0:  # Skip the background
            continue

        # Extract mask for individual object
        single_object_mask = np.uint8(mask == value)

        # Find contours
        contours, _ = cv2.findContours(
            single_object_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )

        for contour in contours:
            # Normalize and flatten contour points
            normalized_contour = [
                (point[0][0] / mask.shape[1], point[0][1] / mask.shape[0])
                for point in contour
            ]
            yolo_data["masks"].append(normalized_contour)
            yolo_data["class_labels"].append(value)

    return yolo_data


# Function to create a user interface for adjusting bounding box parameters
def bbox_params():
    with st.expander("Bounding Box Parameters"):
        # Create two columns for the input widgets
        col1, col2 = st.columns(2)

        with col1:
            min_area = st.number_input(
                "Minimum Area",
                min_value=0,
                max_value=1000,
                value=0,
                step=1,
                help="Minimum area of a bounding box. Boxes smaller than this will be removed.",
                on_change=utils.reset_validation_trigger,
                key=st.session_state["bbox1_key"],
            )
            min_visibility = st.number_input(
                "Minimum Visibility",
                min_value=0,
                max_value=1000,
                value=0,
                step=1,
                help="Minimum fraction of area for a bounding box to remain in the list.",
                on_change=utils.reset_validation_trigger,
                key=st.session_state["bbox2_key"],
            )

        with col2:
            min_width = st.number_input(
                "Minimum Width",
                min_value=0,
                max_value=1000,
                value=0,
                step=1,
                help="Minimum width of a bounding box. Boxes narrower than this will be removed.",
                on_change=utils.reset_validation_trigger,
                key=st.session_state["bbox3_key"],
            )
            min_height = st.number_input(
                "Minimum Height",
                min_value=0,
                max_value=1000,
                value=0,
                step=1,
                help="Minimum height of a bounding box. Boxes shorter than this will be removed.",
                on_change=utils.reset_validation_trigger,
                key=st.session_state["bbox4_key"],
            )

        check_each_transform = st.checkbox(
            "Check Each Transform",
            help="If checked, bounding boxes will be checked after each dual transform.",
            value=True,
            on_change=utils.reset_validation_trigger,
            key=st.session_state["bbox5_key"],
        )

        # Return the collected parameters as a dictionary
        return {
            "min_area": min_area,
            "min_visibility": min_visibility,
            "min_width": min_width,
            "min_height": min_height,
            "check_each_transform": check_each_transform,
        }


# Function to check if the uploaded images and labels are valid
@st.cache_resource(show_spinner=False)
def check_valid_labels(uploaded_files, selected_option, class_dict):
    # Early exit if no files are uploaded
    if len(uploaded_files) == 0:
        st.warning("Please upload at least one image to apply augmentation.", icon="⚠️")
        return False, {}, {}, None, None

    # Initialize dictionaries to hold images and labels
    image_files, label_files = {}, {}

    # Extracting the name of the first file
    first_file_name = os.path.splitext(uploaded_files[0].name)[0]

    # Counters for images and labels
    image_count, label_count = 0, 0

    # Initialize a progress bar and progress text
    progress_bar = st.progress(0)
    progress_text = st.empty()
    total_files = len(uploaded_files)

    # Categorize and prepare uploaded files
    for index, file in enumerate(uploaded_files):
        file.seek(0)  # Reset file pointer to ensure proper file reading
        file_name_without_extension = os.path.splitext(file.name)[0]

        # Distribute files into image or label categories based on their file type
        if file.type in ["image/jpeg", "image/png"]:
            image_files[file_name_without_extension] = file
            image_count += 1
        elif file.type == "text/plain":
            file_content = file.readlines()

            if selected_option == "Bboxes":
                label_data = bboxes_label(file_content, class_dict)
            elif selected_option == "Masks":
                label_data = masks_label(file_content, class_dict)

            # Check for valid label data
            if label_data is None:
                st.warning(
                    f"Invalid label format or data in file: {file.name}",
                    icon="⚠️",
                )
                return False, {}, {}, None, None

            label_files[file_name_without_extension] = label_data
            label_count += 1

        # Update progress bar and display current progress
        progress_percentage = (index + 1) / total_files
        progress_bar.progress(progress_percentage)
        progress_text.text(f"Validating file {index + 1} of {total_files}")

    # Extract sets of unique file names for images and labels
    unique_image_names = set(image_files.keys())
    unique_label_names = set(label_files.keys())

    # Remove progress bar and progress text after processing
    progress_bar.empty()
    progress_text.empty()

    if (len(unique_image_names) != image_count) or (
        len(unique_label_names) != label_count
    ):
        # Warn the user about the presence of duplicate file names
        st.warning(
            "Duplicate file names detected. Please ensure each image and label has a unique name.",
            icon="⚠️",
        )
        return False, {}, {}, None, None

    # Perform validation checks
    if (len(image_files) > 0) and (len(label_files) > 0):
        # Check if the number of images and labels match and each pair has corresponding files
        if (len(image_files) == len(label_files)) and (
            unique_image_names == unique_label_names
        ):
            st.info(
                f"Validated: {len(image_files)} images and labels successfully matched.",
                icon="✅",
            )
            return (
                True,
                image_files,
                label_files,
                image_files[first_file_name],
                label_files[first_file_name],
            )

        elif len(image_files) != len(label_files):
            # Warn if the count of images and labels does not match
            st.warning(
                "Count Mismatch: The number of uploaded images and labels does not match.",
                icon="⚠️",
            )
            return False, {}, {}, None, None

        else:
            # Warn if there is a mismatch in file names between images and labels
            st.warning(
                "Mismatch detected: Some images do not have corresponding label files.",
                icon="⚠️",
            )
            return False, {}, {}, None, None

    elif len(image_files) > 0:
        # Inform the user if only images are uploaded without labels
        st.info(
            f"Note: {len(image_files)} images uploaded without labels. Label type and class labels will be ignored in this case.",
            icon="✅",
        )
        return True, image_files, {}, image_files[first_file_name], None

    else:
        # Warn if no images are uploaded
        st.warning("Please upload an image to apply augmentation.", icon="⚠️")
        return False, {}, {}, None, None


# Function to apply an augmentation technique to an image and return any errors along with the processed image
def apply_and_test_augmentation(
    augmentation,
    params,
    image,
    label,
    label_type,
    label_input_parameters,
    allowed_image_types,
):
    try:
        # Check the data type and number of channels of the input image
        input_image_type = image.dtype
        num_channels = (
            image.shape[2] if len(image.shape) == 3 else 1
        )  # Assuming 1 for single-channel images

        # Validate if the input image type is among the allowed types
        if not utils.is_image_type_allowed(
            input_image_type, num_channels, allowed_image_types
        ):
            # Format the allowed types for display in the warning message
            allowed_types_formatted = ", ".join(map(str, allowed_image_types))

            # Display a warning message specifying the acceptable image types
            st.warning(
                f"Error applying {augmentation}: Incompatible image type. The input image should be one of the following types: {allowed_types_formatted}",
                icon="⚠️",
            )
            return True, None  # Error occurred

        # Set the seed for reproducibility
        random.seed(0)

        if label is None:
            # Apply augmentation technique for no label input
            transform = A.Compose([utils.apply_albumentation(params, augmentation)])
            processed_image = transform(image=image)["image"]
            return False, processed_image

        elif label_type == "Bboxes":
            # Apply augmentation technique for Bboxes label format
            transform = A.Compose(
                [utils.apply_albumentation(params, augmentation)],
                bbox_params=A.BboxParams(
                    format="yolo",
                    label_fields=["class_labels"],
                    min_area=label_input_parameters["min_area"],
                    min_visibility=label_input_parameters["min_visibility"],
                    min_width=label_input_parameters["min_width"],
                    min_height=label_input_parameters["min_height"],
                    check_each_transform=label_input_parameters["check_each_transform"],
                ),
            )
            processed_image = transform(
                image=image,
                bboxes=label["bboxes"],
                class_labels=label["class_labels"],
            )["image"]

        elif label_type == "Masks":
            # Apply augmentation technique for Masks label format
            transform = A.Compose([utils.apply_albumentation(params, augmentation)])
            processed_image = transform(
                image=image,
                mask=generate_mask(
                    label["masks"],
                    label["class_labels"],
                    image.shape[0],
                    image.shape[1],
                ),
            )["image"]

        return False, processed_image  # No error

    except Exception as e:
        st.warning(f"Error applying {augmentation}: {e}", icon="⚠️")
        return True, None  # Error occurred


# Generates a DataFrame detailing augmentation technique parameters and descriptions
def create_augmentations_dataframe(augmentations_params, augmentation_params_db):
    data = []
    for aug_name, params in augmentations_params.items():
        for param_name, param_value in params.items():
            # Retrieve relevant augmentation information from the database
            augmentation_info = augmentation_params_db[
                augmentation_params_db["Name"] == aug_name
            ]
            param_info = augmentation_info[
                augmentation_info["Parameter Name"] == param_name
            ]

            # Check if the parameter information exists in the database
            if not param_info.empty:
                # Get the description of the current parameter
                param_description = param_info["Parameter Description"].iloc[0]
            else:
                param_description = "Description not available"

            # Append augmentation name, parameter name, its value, and description to the data list
            data.append([aug_name, param_name, param_value, param_description])

    # Create the DataFrame from the accumulated data
    augmentations_df = pd.DataFrame(
        data, columns=["augmentation", "Parameter", "Value", "Description"]
    )
    return augmentations_df


# Function to Generate Python Code for Augmentation with Bounding Box Labels
def generate_python_code_bboxes(
    augmentations_params,
    label_input_parameters,
    num_variations=1,
    include_original=False,
):
    # Start with necessary library imports
    code_str = "# Importing necessary libraries\n"
    code_str += "import os\nimport cv2\nimport shutil\nimport albumentations as A\n\n"

    # Paths for input and output directories
    code_str += "# Define the paths for input and output directories\n"
    code_str += "input_directory = 'path/to/input'\n"
    code_str += "output_directory = 'path/to/output'\n\n"

    # Function to read YOLO format labels
    code_str += "# Function to read YOLO format labels\n"
    code_str += "def read_yolo_label(label_path):\n"
    code_str += "    bboxes = []\n"
    code_str += "    class_ids = []\n"
    code_str += "    with open(label_path, 'r') as file:\n"
    code_str += "        for line in file:\n"
    code_str += "            class_id, x_center, y_center, width, height = map(float, line.split())\n"
    code_str += "            bboxes.append([x_center, y_center, width, height])\n"
    code_str += "            class_ids.append(int(class_id))\n"
    code_str += "    return bboxes, class_ids\n\n"

    # Function to create an augmentation pipeline
    code_str += "# Function to create an augmentation pipeline using Albumentations\n"
    code_str += "def process_image(image, bboxes, class_ids):\n"
    code_str += "    # Define the sequence of augmentation techniques\n"
    code_str += "    pipeline = A.Compose([\n"
    for technique, params in augmentations_params.items():
        code_str += f"        A.{technique}({', '.join(f'{k}={v}' for k, v in params.items())}),\n"
    code_str += "    ], bbox_params=A.BboxParams(\n"
    code_str += "        format='yolo',\n"
    code_str += "        label_fields=['class_labels'],\n"
    code_str += f"        min_area={label_input_parameters['min_area']},\n"
    code_str += f"        min_visibility={label_input_parameters['min_visibility']},\n"
    code_str += f"        min_width={label_input_parameters['min_width']},\n"
    code_str += f"        min_height={label_input_parameters['min_height']},\n"
    code_str += f"        check_each_transform={label_input_parameters['check_each_transform']}\n"
    code_str += "    ))\n"
    code_str += "    # Apply the augmentation pipeline\n"
    code_str += (
        "    return pipeline(image=image, bboxes=bboxes, class_labels=class_ids)\n\n"
    )

    # Function to process a batch of images
    code_str += "# Function to process a batch of images\n"
    code_str += "def process_batch(input_directory, output_directory):\n"
    code_str += "    for filename in os.listdir(input_directory):\n"
    code_str += "        if filename.lower().endswith(('.png', '.jpg', '.jpeg')):\n"
    code_str += "            image_path = os.path.join(input_directory, filename)\n"
    code_str += "            label_path = os.path.splitext(image_path)[0] + '.txt'\n\n"

    code_str += "            # Read the image and label\n"
    code_str += "            image = cv2.imread(image_path)\n"
    code_str += "            bboxes, class_ids = read_yolo_label(label_path)\n\n"

    # Include original image and label logic
    code_str += "            # Include original image and label\n"
    if include_original:
        code_str += "            shutil.copy2(image_path, output_directory)\n"
        code_str += "            shutil.copy2(label_path, output_directory)\n\n"

    # Generate variations for each image and process them
    code_str += "            # Generate variations for each image and process them\n"
    code_str += f"            for variation in range({num_variations}):\n"
    code_str += (
        "                processed_data = process_image(image, bboxes, class_ids)\n"
    )
    code_str += "                processed_image = processed_data['image']\n"
    code_str += "                processed_bboxes = processed_data['bboxes']\n"
    code_str += (
        "                processed_class_ids = processed_data['class_labels']\n\n"
    )
    code_str += "                # Save the processed image\n"
    code_str += "                output_filename = f'processed_{os.path.splitext(filename)[0]}_{variation}{os.path.splitext(filename)[1]}'\n"
    code_str += "                cv2.imwrite(os.path.join(output_directory, output_filename), processed_image)\n\n"
    code_str += "                with open(os.path.join(output_directory, os.path.splitext(output_filename)[0] + '.txt'), 'w') as label_file:\n"
    code_str += "                    for bbox, class_id in zip(processed_bboxes, processed_class_ids):\n"
    code_str += "                        label_line = ' '.join(map(str, [class_id] + list(bbox)))\n"
    code_str += "                        label_file.write(label_line + '\\n')\n\n"

    # Execute the batch processing function
    code_str += (
        "# Execute the batch processing function with the specified parameters\n"
    )
    code_str += "process_batch(input_directory, output_directory)\n"

    return code_str


def generate_python_code_masks(
    augmentations_params,
    label_input_parameters,
    num_variations=1,
    include_original=False,
):
    # Start with necessary library imports
    code_str = "# Importing necessary libraries\n"
    code_str += "import os\nimport cv2\nimport shutil\nimport numpy as np\nimport albumentations as A\n\n"

    # Paths for input and output directories
    code_str += "# Define the paths for input and output directories\n"
    code_str += "input_directory = 'path/to/input'\n"
    code_str += "output_directory = 'path/to/output'\n\n"

    # Function to read YOLO mask format and convert to mask
    code_str += "# Function to read YOLO mask format and convert to mask\n"
    code_str += "def read_yolo_label(label_path, image_shape):\n"
    code_str += "    mask = np.full(image_shape, -1, dtype=np.int32)\n"
    code_str += "    with open(label_path, 'r') as file:\n"
    code_str += "        for line in file:\n"
    code_str += "            parts = line.strip().split()\n"
    code_str += "            class_id = int(parts[0])\n"
    code_str += (
        "            points = np.array([float(p) for p in parts[1:]]).reshape(-1, 2)\n"
    )
    code_str += "            points = (points * [image_shape[1], image_shape[0]]).astype(np.int32)\n"
    code_str += "            cv2.fillPoly(mask, [points], class_id)\n"
    code_str += "    return mask\n\n"

    # Function to convert mask to YOLO format
    code_str += "# Function to convert mask to YOLO format\n"
    code_str += "def mask_to_yolo(mask):\n"
    code_str += "    yolo_format = ''\n"
    code_str += "    for class_id in np.unique(mask):\n"
    code_str += "        if class_id == -1:\n"
    code_str += "            continue\n"
    code_str += "        contours, _ = cv2.findContours(\n"
    code_str += "            np.uint8(mask == class_id), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE\n"
    code_str += "        )\n"
    code_str += "        for contour in contours:\n"
    code_str += "            contour = contour.flatten().tolist()\n"
    code_str += "            normalized_contour = [\n"
    code_str += "                str(coord / mask.shape[i % 2]) for i, coord in enumerate(contour)\n"
    code_str += "            ]\n"
    code_str += "            yolo_format += f'{class_id} ' + ' '.join(normalized_contour) + '\\n'\n"
    code_str += "    return yolo_format\n\n"

    # Function to create an augmentation pipeline
    code_str += "# Function to create an augmentation pipeline using Albumentations\n"
    code_str += "def process_image(image, mask):\n"
    code_str += "    # Define the sequence of augmentation techniques\n"
    code_str += "    pipeline = A.Compose([\n"
    for technique, params in augmentations_params.items():
        code_str += f"        A.{technique}({', '.join(f'{k}={v}' for k, v in params.items())}),\n"
    code_str += "    ])\n"
    code_str += "    # Apply the augmentation pipeline\n"
    code_str += "    return pipeline(image=image, mask=mask)\n\n"

    # Function to process a batch of images
    code_str += "# Function to process a batch of images\n"
    code_str += "def process_batch(input_directory, output_directory, include_original=False, num_variations=1):\n"
    code_str += "    for filename in os.listdir(input_directory):\n"
    code_str += "        if filename.lower().endswith(('.png', '.jpg', '.jpeg')):\n"
    code_str += "            image_path = os.path.join(input_directory, filename)\n"
    code_str += "            label_path = os.path.splitext(image_path)[0] + '.txt'\n\n"
    code_str += "            # Read the image\n"
    code_str += "            image = cv2.imread(image_path)\n\n"

    # Include original image and label logic
    code_str += "            # Include original image\n"
    if include_original:
        code_str += "            shutil.copy2(image_path, output_directory)\n"
        code_str += "            shutil.copy2(label_path, output_directory)\n\n"

    code_str += "            # Check if label file exists and read mask\n"
    code_str += "            mask = None\n"
    code_str += "            if os.path.exists(label_path):\n"
    code_str += (
        "                mask = read_yolo_label(label_path, image.shape[:2])\n\n"
    )

    # Generate variations for each image and process them
    code_str += "            # Generate variations for each image and process them\n"
    code_str += f"            for variation in range({num_variations}):\n"
    code_str += "                processed_image, processed_mask = image, mask\n"
    code_str += "                if mask is not None:\n"
    code_str += "                    processed_data = process_image(image, mask)\n"
    code_str += "                    processed_image, processed_mask = processed_data['image'], processed_data['mask']\n\n"
    code_str += "                # Save the processed image\n"
    code_str += "                output_filename = f'processed_{os.path.splitext(filename)[0]}_{variation}.jpg'\n"
    code_str += "                cv2.imwrite(os.path.join(output_directory, output_filename), processed_image)\n\n"
    code_str += "                # Save the processed label in YOLO format\n"
    code_str += "                if processed_mask is not None:\n"
    code_str += (
        "                    processed_label_str = mask_to_yolo(processed_mask)\n"
    )
    code_str += "                    with open(os.path.join(output_directory, os.path.splitext(output_filename)[0] + '.txt'), 'w') as label_file:\n"
    code_str += "                        label_file.write(processed_label_str)\n\n"

    # Execute the batch processing function with the specified parameters
    code_str += (
        "# Execute the batch processing function with the specified parameters\n"
    )
    code_str += "process_batch(input_directory, output_directory)\n"

    return code_str


# Function to create an augmentation pipeline based on the selected techniques and their parameters
def create_augmentation_pipeline(
    selected_augmentations, augmentation_params, label_type, label_input_parameters=None
):
    pipeline = []
    for aug_name in selected_augmentations:
        # Append the function call with its parameters to the pipeline
        pipeline.append(
            utils.apply_albumentation(augmentation_params[aug_name], aug_name)
        )

    # Compose all the augmentations into one transformation
    try:
        # Set the seed for reproducibility
        random.seed(0)

        if label_type is None:
            # Apply augmentation technique for no label input
            transform = A.Compose(pipeline)
            return transform

        elif label_type == "Bboxes":
            # Apply augmentation technique for Bboxes label format
            transform = A.Compose(
                pipeline,
                bbox_params=A.BboxParams(
                    format="yolo",
                    label_fields=["class_labels"],
                    min_area=label_input_parameters["min_area"],
                    min_visibility=label_input_parameters["min_visibility"],
                    min_width=label_input_parameters["min_width"],
                    min_height=label_input_parameters["min_height"],
                    check_each_transform=label_input_parameters["check_each_transform"],
                ),
            )

        elif label_type == "Masks":
            # Apply augmentation technique for Masks label format
            transform = A.Compose(pipeline)

        return transform  # No error

    except Exception as e:
        st.warning(f"Error applying augmentation: {e}")
        return None  # Error occurred


# Function to convert label data from dictionary format to YOLO format
def convert_labels_to_yolo_format(label_data, class_dict):
    yolo_label_str = ""

    # Convert bounding boxes to YOLO format
    if "bboxes" in label_data:
        for bbox, class_label in zip(label_data["bboxes"], label_data["class_labels"]):
            class_id = class_label - 1  # Revert to the original value
            x_center, y_center, width, height = bbox
            yolo_label_str += f"{class_id} {x_center} {y_center} {width} {height}\n"

    # Convert masks to YOLO format
    if "masks" in label_data:
        for mask, class_label in zip(label_data["masks"], label_data["class_labels"]):
            class_id = class_label - 1  # Revert to the original value
            # Flatten the mask array into a single line of coordinates
            mask_flattened = [coord for point in mask for coord in point]
            mask_str = " ".join(map(str, mask_flattened))
            yolo_label_str += f"{class_id} {mask_str}\n"

    return yolo_label_str


# Function to apply the augmentation pipeline to the image based on the label type
def apply_augmentation_pipeline(image, label_file, label_type, transform):
    # Initialize an empty dictionary to store processed labels
    processed_label = {}

    # Apply the transformation based on the label type
    if label_type is None:
        processed_output = transform(image=image)
        processed_label = None

    elif label_type == "Bboxes":
        processed_output = transform(
            image=image,
            bboxes=label_file["bboxes"],
            class_labels=label_file["class_labels"],
        )
        processed_label = {
            "bboxes": processed_output["bboxes"],
            "class_labels": processed_output["class_labels"],
        }

    elif label_type == "Masks":
        mask = generate_mask(
            label_file["masks"],
            label_file["class_labels"],
            image.shape[0],
            image.shape[1],
        )
        processed_output = transform(image=image, mask=mask)
        mask_yolo = mask_to_yolo(processed_output["mask"])
        processed_label = mask_yolo

    # Extract the processed image
    processed_image = processed_output["image"]

    return processed_image, processed_label


# Function to process images and labels, apply augmentations, and create a zip file with the results
@st.cache_resource(show_spinner=False)
def process_images_and_labels(
    image_files,
    label_files,
    selected_augmentations,
    _augmentations_params,
    label_type,
    label_input_parameters,
    num_variations,
    include_original,
    class_dict,
):
    zip_buffer = io.BytesIO()  # Create an in-memory buffer for the zip file
    st.session_state[
        "image_repository_augmentation"
    ] = {}  # Initialize a repository to store processed image data
    st.session_state[
        "processed_image_mapping_augmentation"
    ] = {}  # Map original images to their processed versions
    st.session_state["unique_images_names"] = []  # List to store unique image names

    # Create progress bar and text elements in Streamlit
    progress_bar = st.progress(0)
    progress_text = st.empty()

    with zipfile.ZipFile(
        zip_buffer, mode="a", compression=zipfile.ZIP_DEFLATED, allowZip64=True
    ) as zip_file:
        # Determine the label type for augmentation, if label files are present
        effective_label_type = None if len(label_files) == 0 else label_type

        # Create an augmentation pipeline based on selected augmentations and parameters
        transform = create_augmentation_pipeline(
            selected_augmentations,
            _augmentations_params,
            effective_label_type,
            label_input_parameters,
        )

        total_images = len(image_files) * num_variations
        processed_count = 0  # Counter for processed images

        # Iterate over each uploaded file
        for image_name, image_file in image_files.items():
            image_file.seek(0)  # Reset file pointer to start
            file_bytes = np.asarray(bytearray(image_file.read()), dtype=np.uint8)
            original_image = cv2.cvtColor(
                cv2.imdecode(file_bytes, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB
            )
            original_image_resized = utils.resize_image(original_image)

            # Include original images and labels in the output if selected
            if include_original:
                original_img_buffer = io.BytesIO()
                Image.fromarray(original_image).save(original_img_buffer, format="JPEG")
                zip_file.writestr(image_file.name, original_img_buffer.getvalue())

                # Convert and save original labels to YOLO format if they exist
                label_file = label_files.get(image_name)
                if label_file is not None:
                    yolo_label_str = convert_labels_to_yolo_format(
                        label_file, class_dict
                    )
                    zip_file.writestr(f"{image_name}.txt", yolo_label_str)

            original_file_name = image_file.name
            st.session_state["unique_images_names"].append(original_file_name)
            st.session_state["processed_image_mapping_augmentation"][
                original_file_name
            ] = []
            st.session_state["image_repository_augmentation"][image_file.name] = {
                "image": original_image_resized,
                "label": label_files.get(image_name),
            }

            # Apply augmentations and generate variations
            for i in range(num_variations):
                random.seed(i)
                (
                    processed_image,
                    processed_label,
                ) = apply_augmentation_pipeline(
                    original_image,
                    label_files.get(image_name),
                    effective_label_type,
                    transform,
                )

                img_buffer = io.BytesIO()
                Image.fromarray(processed_image).save(img_buffer, format="JPEG")
                processed_filename = f"processed_{image_name.split('.')[0]}_{i}.jpg"
                zip_file.writestr(processed_filename, img_buffer.getvalue())
                processed_image_resized = utils.resize_image(processed_image)

                st.session_state["processed_image_mapping_augmentation"][
                    image_file.name
                ].append(processed_filename)
                st.session_state["image_repository_augmentation"][
                    processed_filename
                ] = {
                    "image": processed_image_resized,
                    "label": processed_label,
                }

                # Convert and save processed labels to YOLO format if they exist
                label_file = label_files.get(image_name)
                if label_file is not None:
                    processed_label_str = convert_labels_to_yolo_format(
                        processed_label, class_dict
                    )
                    zip_file.writestr(
                        f"processed_{image_name.split('.')[0]}_{i}.txt",
                        processed_label_str,
                    )

                processed_count += 1
                # Update progress bar and text
                progress_bar.progress(processed_count / total_images)
                progress_text.text(
                    f"Processing image {processed_count} of {total_images}"
                )

    # Remove the progress bar and text after processing is complete
    progress_bar.empty()
    progress_text.empty()

    zip_buffer.seek(0)  # Reset buffer to start for download

    st.session_state["zip_data_augmentation"] = zip_buffer.getvalue()


# Function to overlay labels on images
def overlay_labels(image, labels_to_plot, label_file, label_type, colors, class_dict):
    # Ensure the image is in the correct format (RGB)
    if len(image.shape) == 2 or image.shape[2] == 1:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

    # Overlay Bounding Boxes
    if label_type == "Bboxes":
        for bbox, label in zip(label_file["bboxes"], label_file["class_labels"]):
            if label in labels_to_plot:
                # Convert bbox from yolo format to xmin, ymin, xmax, ymax
                x_center, y_center, width, height = bbox
                xmin = int((x_center - width / 2) * image.shape[1])
                xmax = int((x_center + width / 2) * image.shape[1])
                ymin = int((y_center - height / 2) * image.shape[0])
                ymax = int((y_center + height / 2) * image.shape[0])

                # Get color for the class
                color = colors[label]

                # Draw rectangle and label
                image = cv2.rectangle(
                    image, (xmin, ymin), (xmax, ymax), color, thickness=2
                )

                # Put class label text
                label_text = class_dict.get(label, "Unknown")
                cv2.putText(
                    image,
                    label_text,
                    (xmin, ymin - 5),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.5,
                    (0, 0, 0),  # Black color for text
                    2,
                )

    # Overlay Mask
    elif label_type == "Masks":
        for polygon, label in zip(label_file["masks"], label_file["class_labels"]):
            if label in labels_to_plot:
                # Convert polygon points from yolo format to image coordinates
                polygon_points = [
                    (int(x * image.shape[1]), int(y * image.shape[0]))
                    for x, y in polygon
                ]

                # Get color for the class
                color = colors[label]

                # Create a temporary image to draw the polygon
                temp_image = image.copy()
                cv2.fillPoly(
                    temp_image, [np.array(polygon_points, dtype=np.int32)], color
                )

                # Blend the temporary image with the original image
                cv2.addWeighted(temp_image, 0.5, image, 0.5, 0, image)

                # Optional: Put class label text near the first point of the polygon
                label_text = class_dict.get(label, "Unknown")
                cv2.putText(
                    image,
                    label_text,
                    polygon_points[0],
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.5,
                    (0, 0, 0),  # Black color for text
                    2,
                )

    return image


# Function to generate a downloadable file
def display_code_and_download_button(generated_code):
    def generate_downloadable_file(code_str):
        return code_str.encode("utf-8")

    # Display the generated code in Streamlit with description and download button in columns
    with st.expander("Plug and Play Code"):
        col1, col2 = st.columns([7, 3])

        with col1:
            st.markdown(
                """
                ### Description of the Code Pipeline
                """
            )

            st.markdown(
                """
                <div style='text-align: justify;'>
                This code is a ready-to-use Python script for batch augmentation. It applies selected augmentation techniques to all images in a specified input directory and saves the processed images in an output directory.

                **To use this script:**
                - Ensure you have the necessary dependencies installed.
                - Specify the input and output paths: Replace `'path/to/input'` with the path to your input images and `'path/to/output'` with the desired path for the processed images.
                - The number of augmented variations per image, the inclusion of the original images in the output, and the augmentation techniques with their parameters will be automatically set based on your selections.

                ### Python Code
                </div>
                """,
                unsafe_allow_html=True,
            )

            # Display python code
            st.code(generated_code, language="python")

        with col2:
            # Create a button for downloading the Python file
            st.download_button(
                label="Download Python File",
                data=generate_downloadable_file(generated_code),
                file_name="augmentation_script.py",
                mime="text/plain",
                use_container_width=True,
            )
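
The generate_mask() and mask_to_yolo() helpers above are inverses of each other: normalized YOLO polygons are rasterized onto an integer class mask so Albumentations can transform them, then traced back into normalized polygons. A standalone sketch of that round trip (the image size and polygon values below are illustrative, not taken from the repo):

# Standalone sketch of the polygon -> mask -> polygon round trip performed by
# generate_mask() and mask_to_yolo(). All values here are illustrative.
import cv2
import numpy as np

h, w = 100, 200
# One normalized YOLO-style polygon with class_id 1 (0 is reserved for background)
polygon = [(0.1, 0.1), (0.6, 0.1), (0.6, 0.5), (0.1, 0.5)]

# Rasterize: scale to pixel coordinates and fill, as generate_mask() does
mask = np.zeros((h, w), dtype=np.int32)
pts = np.array([(int(x * w), int(y * h)) for x, y in polygon], dtype=np.int32)
cv2.fillPoly(mask, [pts], color=1)

# Vectorize back: find contours and re-normalize, as mask_to_yolo() does
contours, _ = cv2.findContours(
    np.uint8(mask == 1), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
)
recovered = [(pt[0][0] / w, pt[0][1] / h) for pt in contours[0]]
print(recovered)  # approximately the original four corners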
Functions/image_processing_functions.py
ADDED
@@ -0,0 +1,430 @@
1 |
+
# Importing necessary libraries
|
2 |
+
import io
|
3 |
+
import os
|
4 |
+
import cv2
|
5 |
+
import utils
|
6 |
+
import random
|
7 |
+
import zipfile
|
8 |
+
import numpy as np
|
9 |
+
import pandas as pd
|
10 |
+
from PIL import Image
|
11 |
+
import streamlit as st
|
12 |
+
import albumentations as A
|
13 |
+
|
14 |
+
|
15 |
+
# Function to check if the uploaded images and labels are valid
|
16 |
+
st.cache_resource(show_spinner=False)
|
17 |
+
|
18 |
+
|
19 |
+
def check_valid_labels(uploaded_files):
|
20 |
+
# Early exit if no files are uploaded
|
21 |
+
if len(uploaded_files) == 0:
|
22 |
+
st.warning(
|
23 |
+
"Please upload at least one image to apply image processing.", icon="⚠️"
|
24 |
+
)
|
25 |
+
return False, {}, {}, None, None
|
26 |
+
|
27 |
+
# Initialize dictionaries to hold images and labels
|
28 |
+
image_files, label_files = {}, {}
|
29 |
+
|
30 |
+
# Extracting the name of the first file
|
31 |
+
first_file_name = os.path.splitext(uploaded_files[0].name)[0]
|
32 |
+
|
33 |
+
# Counters for images and labels
|
34 |
+
image_count, label_count = 0, 0
|
35 |
+
|
36 |
+
# Initialize a progress bar and progress text
|
37 |
+
progress_bar = st.progress(0)
|
38 |
+
progress_text = st.empty()
|
39 |
+
total_files = len(uploaded_files)
|
40 |
+
|
41 |
+
# Categorize and prepare uploaded files
|
42 |
+
for index, file in enumerate(uploaded_files):
|
43 |
+
file.seek(0) # Reset file pointer to ensure proper file reading
|
44 |
+
file_name_without_extension = os.path.splitext(file.name)[0]
|
45 |
+
|
46 |
+
# Distribute files into image or label categories based on their file type
|
47 |
+
if file.type in ["image/jpeg", "image/png"]:
|
48 |
+
image_files[file_name_without_extension] = file
|
49 |
+
image_count += 1
|
50 |
+
elif file.type == "text/plain":
|
51 |
+
label_files[file_name_without_extension] = file
|
52 |
+
label_count += 1
|
53 |
+
|
54 |
+
# Update progress bar and display current progress
|
55 |
+
progress_percentage = (index + 1) / total_files
|
56 |
+
progress_bar.progress(progress_percentage)
|
57 |
+
progress_text.text(f"Validating file {index + 1} of {total_files}")
|
58 |
+
|
59 |
+
# Extract sets of unique file names for images and labels
|
60 |
+
unique_image_names = set(image_files.keys())
|
61 |
+
unique_label_names = set(label_files.keys())
|
62 |
+
|
63 |
+
# Remove progress bar and progress text after processing
|
64 |
+
progress_bar.empty()
|
65 |
+
progress_text.empty()
|
66 |
+
|
67 |
+
if (len(unique_image_names) != image_count) or (
|
68 |
+
len(unique_label_names) != label_count
|
69 |
+
):
|
70 |
+
# Warn the user about the presence of duplicate file names
|
71 |
+
st.warning(
|
72 |
+
"Duplicate file names detected. Please ensure each image and label has a unique name.",
|
73 |
+
icon="⚠️",
|
74 |
+
)
|
75 |
+
return False, {}, {}, None, None
|
76 |
+
|
77 |
+
# Perform validation checks
|
78 |
+
if (len(image_files) > 0) and (len(label_files) > 0):
|
79 |
+
# Check if the number of images and labels match and each pair has corresponding files
|
80 |
+
if (len(image_files) == len(label_files)) and (
|
81 |
+
unique_image_names == unique_label_names
|
82 |
+
):
|
83 |
+
st.info(
|
84 |
+
f"Validated: {len(image_files)} images and labels successfully matched.",
|
85 |
+
icon="✅",
|
86 |
+
)
|
87 |
+
return (
|
88 |
+
True,
|
89 |
+
image_files,
|
90 |
+
label_files,
|
91 |
+
image_files[first_file_name],
|
92 |
+
label_files[first_file_name],
|
93 |
+
)
|
94 |
+
|
95 |
+
elif len(image_files) != len(label_files):
|
96 |
+
# Warn if the count of images and labels does not match
|
97 |
+
st.warning(
|
98 |
+
"Count Mismatch: The number of uploaded images and labels does not match.",
|
99 |
+
icon="⚠️",
|
100 |
+
)
|
101 |
+
return False, {}, {}, None, None
|
102 |
+
|
103 |
+
else:
|
104 |
+
# Warn if there is a mismatch in file names between images and labels
|
105 |
+
st.warning(
|
106 |
+
"Mismatch detected: Some images do not have corresponding label files.",
|
107 |
+
icon="⚠️",
|
108 |
+
)
|
109 |
+
return False, {}, {}, None, None
|
110 |
+
|
111 |
+
elif len(image_files) > 0:
|
112 |
+
# Inform the user if only images are uploaded without labels
|
113 |
+
st.info(
|
114 |
+
f"Note: {len(image_files)} images uploaded without labels. Label type and class labels will be ignored in this case.",
|
115 |
+
icon="✅",
|
116 |
+
)
|
117 |
+
return True, image_files, {}, image_files[first_file_name], None
|
118 |
+
|
119 |
+
else:
|
120 |
+
# Warn if no images are uploaded
|
121 |
+
st.warning("Please upload an image to apply image processing.", icon="⚠️")
|
122 |
+
return False, {}, {}, None, None
|
123 |
+
|
124 |
+
|
# Function to apply an image processing technique to an image and return any errors along with the processed image
def apply_and_test_image_processing(
    image_processing, params, image, allowed_image_types
):
    try:
        # Check the data type and number of channels of the input image
        input_image_type = image.dtype
        num_channels = (
            image.shape[2] if len(image.shape) == 3 else 1
        )  # Assuming 1 for single-channel images

        # Validate if the input image type is among the allowed types
        if not utils.is_image_type_allowed(
            input_image_type, num_channels, allowed_image_types
        ):
            # Format the allowed types for display in the warning message
            allowed_types_formatted = ", ".join(map(str, allowed_image_types))

            # Display a warning message specifying the acceptable image types
            st.warning(
                f"Error applying {image_processing}: Incompatible image type. The input image should be one of the following types: {allowed_types_formatted}",
                icon="⚠️",
            )
            return True, None  # Error occurred

        # Set a fixed seed so repeated previews of the same technique are reproducible
        random.seed(0)

        # Apply the image processing technique
        transform = A.Compose([utils.apply_albumentation(params, image_processing)])
        processed_image = transform(image=image)["image"]

        return False, processed_image  # No error
    except Exception as e:
        st.warning(f"Error applying {image_processing}: {e}", icon="⚠️")
        return True, None  # Error occurred

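# Illustrative usage sketch (hypothetical parameter values, not part of the
# original app flow; assumes `image` is a NumPy uint8 RGB array and
# `allowed_image_types` comes from the technique's database entry):
#   error, preview = apply_and_test_image_processing(
#       "Blur", {"blur_limit": (3, 7), "p": 1.0}, image, ["uint8"]
#   )
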
# Function to generate a DataFrame detailing image processing technique parameters and descriptions
def create_image_processings_dataframe(
    image_processings_params, image_processing_params_db
):
    data = []
    for aug_name, params in image_processings_params.items():
        for param_name, param_value in params.items():
            # Retrieve relevant image_processing information from the database
            image_processing_info = image_processing_params_db[
                image_processing_params_db["Name"] == aug_name
            ]
            param_info = image_processing_info[
                image_processing_info["Parameter Name"] == param_name
            ]

            # Check if the parameter information exists in the database
            if not param_info.empty:
                # Get the description of the current parameter
                param_description = param_info["Parameter Description"].iloc[0]
            else:
                param_description = "Description not available"

            # Append image_processing name, parameter name, its value, and description to the data list
            data.append([aug_name, param_name, param_value, param_description])

    # Create the DataFrame from the accumulated data
    image_processings_df = pd.DataFrame(
        data, columns=["image_processing", "Parameter", "Value", "Description"]
    )
    return image_processings_df

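# Illustrative usage sketch (hypothetical parameter dict; the database argument
# is the parameter spreadsheet loaded elsewhere in the app):
#   df = create_image_processings_dataframe(
#       {"Blur": {"blur_limit": (3, 7)}}, image_processing_params_db
#   )
#   df.head()  # one row per (technique, parameter) pair
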
# Function to generate plug-and-play Python code for processing images and labels
def generate_python_code_images_labels(
    augmentations_params,
    num_variations=1,
    include_original=False,
):
    # Start with necessary library imports
    code_str = "# Importing necessary libraries\n"
    code_str += "import os\nimport cv2\nimport shutil\nimport albumentations as A\n\n"

    # Paths for input and output directories
    code_str += "# Define the paths for input and output directories\n"
    code_str += "input_directory = 'path/to/input'\n"
    code_str += "output_directory = 'path/to/output'\n\n"

    # Function to create an augmentation pipeline
    code_str += "# Function to create an augmentation pipeline using Albumentations\n"
    code_str += "def process_image(image):\n"
    code_str += "    # Define the sequence of augmentation techniques\n"
    code_str += "    pipeline = A.Compose([\n"
    for technique, params in augmentations_params.items():
        code_str += f"        A.{technique}({', '.join(f'{k}={v}' for k, v in params.items())}),\n"
    code_str += "    ])\n"
    code_str += "    # Apply the augmentation pipeline\n"
    code_str += "    return pipeline(image=image)['image']\n\n"

    # Function to process a batch of images
    code_str += "# Function to process a batch of images\n"
    code_str += "def process_batch(input_directory, output_directory):\n"
    code_str += "    for filename in os.listdir(input_directory):\n"
    code_str += "        if filename.lower().endswith(('.png', '.jpg', '.jpeg')):\n"
    code_str += "            image_path = os.path.join(input_directory, filename)\n"
    code_str += "            label_path = os.path.splitext(image_path)[0] + '.txt'\n\n"

    code_str += "            # Read the image\n"
    code_str += "            image = cv2.imread(image_path)\n\n"

    # Include original image and label logic (guard the label copy so the
    # generated script does not crash on images that have no label file)
    if include_original:
        code_str += "            # Include original image and label\n"
        code_str += "            shutil.copy2(image_path, output_directory)\n"
        code_str += "            if os.path.exists(label_path):\n"
        code_str += "                shutil.copy2(label_path, output_directory)\n\n"

    # Generate variations for each image and process them
    code_str += "            # Generate variations for each image and process them\n"
    code_str += f"            for variation in range({num_variations}):\n"
    code_str += "                processed_image = process_image(image)\n\n"
    code_str += "                # Save the processed image\n"
    code_str += "                output_filename = f'processed_{os.path.splitext(filename)[0]}_{variation}{os.path.splitext(filename)[1]}'\n"
    code_str += "                cv2.imwrite(os.path.join(output_directory, output_filename), processed_image)\n\n"
    code_str += (
        "                # Save the original label file for the processed image\n"
    )
    code_str += "                if os.path.exists(label_path):\n"
    code_str += "                    shutil.copy2(label_path, os.path.join(output_directory, os.path.splitext(output_filename)[0] + '.txt'))\n\n"

    # Execute the batch processing function
    code_str += (
        "# Execute the batch processing function with the specified parameters\n"
    )
    code_str += "process_batch(input_directory, output_directory)\n"

    return code_str

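# Illustrative usage sketch (hypothetical technique/parameter choices):
#   script = generate_python_code_images_labels(
#       {"HorizontalFlip": {"p": 0.5}, "Blur": {"blur_limit": 7}},
#       num_variations=3,
#       include_original=True,
#   )
#   print(script)  # ready-to-run batch augmentation script
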
# Function to create an image processing pipeline based on the selected techniques and their parameters
def create_image_processing_pipeline(
    selected_image_processings, image_processing_params
):
    pipeline = []
    for aug_name in selected_image_processings:
        # Append the function call with its parameters to the pipeline
        pipeline.append(
            utils.apply_albumentation(image_processing_params[aug_name], aug_name)
        )

    # Compose all the image processings into one transformation
    return A.Compose(pipeline)

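# Illustrative usage sketch (hypothetical selections):
#   transform = create_image_processing_pipeline(
#       ["Blur", "HorizontalFlip"],
#       {"Blur": {"blur_limit": 7}, "HorizontalFlip": {"p": 0.5}},
#   )
#   augmented = transform(image=image)["image"]
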
# Function to process images and labels, apply image processing techniques, and create a zip file with the results
@st.cache_resource(show_spinner=False)
def process_images_and_labels(
    image_files,
    label_files,
    selected_image_processings,
    _image_processings_params,
    num_variations,
    include_original,
):
    zip_buffer = io.BytesIO()  # Create an in-memory buffer for the zip file
    st.session_state[
        "image_repository_preprocessing"
    ] = {}  # Initialize a repository to store processed image data
    st.session_state[
        "processed_image_mapping_procesing"
    ] = {}  # Map original images to their processed versions
    st.session_state["unique_images_names"] = []  # List to store unique image names

    # Create progress bar and text elements in Streamlit
    progress_bar = st.progress(0)
    progress_text = st.empty()

    with zipfile.ZipFile(
        zip_buffer, mode="a", compression=zipfile.ZIP_DEFLATED, allowZip64=True
    ) as zip_file:
        # Compose all the image processings into one transformation
        transform = create_image_processing_pipeline(
            selected_image_processings, _image_processings_params
        )

        total_images = len(image_files) * num_variations
        processed_count = 0  # Counter for processed images

        # Iterate over each uploaded file
        for image_name, image_file in image_files.items():
            image_file.seek(0)  # Reset file pointer to start
            file_bytes = np.asarray(bytearray(image_file.read()), dtype=np.uint8)
            original_image = cv2.cvtColor(
                cv2.imdecode(file_bytes, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB
            )
            original_image_resized = utils.resize_image(original_image)

            # Include original images and labels in the output if selected
            if include_original:
                original_img_buffer = io.BytesIO()
                Image.fromarray(original_image).save(original_img_buffer, format="JPEG")
                zip_file.writestr(image_file.name, original_img_buffer.getvalue())

                # Save corresponding label file to zip if it exists
                label_file = label_files.get(image_name)
                if label_file is not None:
                    label_file.seek(0)  # Reset the file pointer
                    zip_file.writestr(f"{image_name}.txt", label_file.read())

            original_file_name = image_file.name
            st.session_state["unique_images_names"].append(original_file_name)
            st.session_state["processed_image_mapping_procesing"][
                original_file_name
            ] = []
            st.session_state["image_repository_preprocessing"][image_file.name] = {
                "image": original_image_resized,
                "label": label_files.get(image_name),
            }

            # Apply image processing techniques and generate variations
            for i in range(num_variations):
                # Seed with the variation index so each variation is reproducible
                random.seed(i)

                # Apply the image processing pipeline to the image
                processed_image = transform(image=original_image)["image"]

                img_buffer = io.BytesIO()
                Image.fromarray(processed_image).save(img_buffer, format="JPEG")
                processed_filename = f"processed_{image_name.split('.')[0]}_{i}.jpg"
                zip_file.writestr(processed_filename, img_buffer.getvalue())
                processed_image_resized = utils.resize_image(processed_image)

                # Save corresponding label file to zip if it exists
                label_file = label_files.get(image_name)
                if label_file is not None:
                    label_file.seek(0)  # Reset the file pointer
                    zip_file.writestr(
                        f"processed_{image_name}_{i}.txt", label_file.read()
                    )

                st.session_state["processed_image_mapping_procesing"][
                    image_file.name
                ].append(processed_filename)
                st.session_state["image_repository_preprocessing"][
                    processed_filename
                ] = {
                    "image": processed_image_resized,
                    "label": label_file,
                }

                processed_count += 1
                # Update progress bar and text
                progress_bar.progress(processed_count / total_images)
                progress_text.text(
                    f"Processing image {processed_count} of {total_images}"
                )

    # Remove the progress bar and text after processing is complete
    progress_bar.empty()
    progress_text.empty()

    zip_buffer.seek(0)  # Reset buffer to start for download

    st.session_state["zip_data_processing"] = zip_buffer.getvalue()

# Function to display the generated code and offer it as a downloadable file
def display_code_and_download_button(generated_code):
    def generate_downloadable_file(code_str):
        return code_str.encode("utf-8")

    # Display the generated code in Streamlit with description and download button in columns
    with st.expander("Plug and Play Code"):
        col1, col2 = st.columns([7, 3])

        with col1:
            st.markdown(
                """
                ### Description of the Code Pipeline
                """
            )

            st.markdown(
                """
                <div style='text-align: justify;'>
                This code is a ready-to-use Python script for batch augmentation. It applies selected augmentation techniques to all images in a specified input directory and saves the processed images in an output directory.

                **To use this script:**
                - Ensure you have the necessary dependencies installed.
                - Specify the input and output paths: Replace `'path/to/input'` with the path to your input images and `'path/to/output'` with the desired path for the processed images.
                - The number of processed variations per image, the inclusion of the original images in the output, and the processing techniques with their parameters will be automatically set based on your selections.

                ### Python Code
                </div>
                """,
                unsafe_allow_html=True,
            )

            # Display the Python code
            st.code(generated_code, language="python")

        with col2:
            # Create a button for downloading the Python file
            st.download_button(
                label="Download Python File",
                data=generate_downloadable_file(generated_code),
                file_name="image_processing_script.py",
                mime="text/plain",
                use_container_width=True,
            )
Functions/model_training_functions.py
ADDED
@@ -0,0 +1,1896 @@
# Importing necessary libraries
import io
import os
import utils
import random
import shutil
import zipfile
import numpy as np
import pandas as pd
import streamlit as st
from ultralytics import YOLO
import plotly.graph_objs as go
from onnx.defs import onnx_opset_version
from plotly.subplots import make_subplots

# Function to get the dataset directory path based on the specified path type
def get_path(path_type):
    main_directory_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    if path_type == "train":
        return os.path.join(
            main_directory_path,
            "model_data",
            "input_files",
            "datasets",
            "train",
        )
    elif path_type == "val":
        return os.path.join(
            main_directory_path,
            "model_data",
            "input_files",
            "datasets",
            "val",
        )
    elif path_type == "test":
        return os.path.join(
            main_directory_path,
            "model_data",
            "input_files",
            "datasets",
            "test",
        )
    elif path_type == "config":
        return os.path.join(main_directory_path, "model_data", "input_files")
    elif path_type == "models":
        return os.path.join(main_directory_path, "model_data", "models")
    elif path_type == "output":
        return os.path.join(main_directory_path, "model_data", "output_files")
    else:
        raise ValueError(f"Invalid path_type: {path_type}")

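# Illustrative usage sketch:
#   train_dir = get_path("train")    # .../model_data/input_files/datasets/train
#   config_dir = get_path("config")  # .../model_data/input_files
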
# Function to check for the minimum number of images in the training and validation sets
def check_min_images(total_files, train_pct, val_pct, test_pct):
    # Calculate raw counts based on percentages
    train_count = int(total_files * train_pct / 100)
    val_count = int(total_files * val_pct / 100)
    test_count = int(total_files * test_pct / 100)

    # Ensure that both train and validation have at least one file
    # (the test split is allowed to be empty)
    if train_count < 1 or val_count < 1:
        return False

    return True

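# Illustrative check: an 80/10/10 split of 25 files keeps at least one image
# in both train and val, so this returns True.
#   check_min_images(25, 80, 10, 10)  # -> True
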
# Function to clear the dataset folders
def clear_data_folders():
    base_path = "./model_data/input_files/datasets"
    for folder in ["train", "test", "val"]:
        for subfolder in ["images", "labels"]:
            folder_path = os.path.join(base_path, folder, subfolder)
            if os.path.exists(folder_path):
                shutil.rmtree(folder_path)
            os.makedirs(folder_path, exist_ok=True)

# Function to pair image and label files based on their filenames
def pair_files(files):
    paired_files = {}
    for file in files:
        # Split the filename into name and extension
        file_name, file_ext = os.path.splitext(file.name)

        # Initialize a dict for each unique file name
        if file_name not in paired_files:
            paired_files[file_name] = {"image": None, "label": None}

        # Assign the file to its corresponding type (image or label) based on extension
        if file_ext.lower() in [".jpg", ".png"]:
            paired_files[file_name]["image"] = file
        elif file_ext.lower() == ".txt":
            paired_files[file_name]["label"] = file

    return paired_files

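# Illustrative example: uploads named "cat.jpg" and "cat.txt" are grouped
# under the key "cat":
#   pair_files(uploads)  # -> {"cat": {"image": <jpg file>, "label": <txt file>}}
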
# Function to split the paired files into training, testing, and validation sets based on specified percentages and save them in corresponding folders
def split_and_save_files(paired_files, train_pct, test_pct):
    base_path = "./model_data/input_files/datasets"
    all_keys = list(paired_files.keys())
    random.shuffle(all_keys)

    # Determine the size of each dataset split
    total_files = len(all_keys)
    train_size = int(total_files * train_pct / 100)
    test_size = int(total_files * test_pct / 100)

    # Split the file keys into training, testing, and validation sets
    train_keys = all_keys[:train_size]
    test_keys = all_keys[train_size : train_size + test_size]
    val_keys = all_keys[train_size + test_size :]

    # Iterate through each split and save the files to their respective directories
    for folder_name, keys in zip(
        ["train", "test", "val"], [train_keys, test_keys, val_keys]
    ):
        for key in keys:
            image_file = paired_files[key]["image"]
            label_file = paired_files[key]["label"]
            # Save the image and label files if they exist
            if image_file:
                save_file_to_folder(
                    image_file, os.path.join(base_path, folder_name, "images")
                )
            if label_file:
                save_file_to_folder(
                    label_file, os.path.join(base_path, folder_name, "labels")
                )

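# Illustrative note: with train_pct=70 and test_pct=20, the validation split
# receives the remaining ~10% of pairs (val is always the remainder).
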
# Function to save an individual file to a specified folder
def save_file_to_folder(file, folder_path):
    os.makedirs(folder_path, exist_ok=True)
    file_path = os.path.join(folder_path, file.name)
    with open(file_path, "wb") as f:
        f.write(file.getbuffer())

# Function to save uploaded files to a specific folder within the base path
def save_files_to_folder(uploaded_files, folder_name):
    # Define the base path for saving the files
    base_path = "./model_data/input_files/datasets"

    # Iterate through each uploaded file
    for file in uploaded_files:
        if file:
            # Determine the file type based on file extension
            file_type = (
                "images"
                if os.path.splitext(file.name)[1].lower() in [".jpg", ".png"]
                else "labels"
            )

            # Save the file to the appropriate subfolder (images or labels)
            save_file_to_folder(file, os.path.join(base_path, folder_name, file_type))

# Function to validate each line in the label file for bounding box data
def check_bboxes_label(label_file, class_dict):
    for line in label_file:
        try:
            # Decode the line, strip whitespace, split into parts, and convert each part to float
            class_id, x_center, y_center, width, height = map(
                float, line.decode().strip().split()
            )

            # Check if bounding box coordinates and class ID are valid
            if not (
                0 <= x_center <= 1
                and 0 <= y_center <= 1
                and 0 <= width <= 1
                and 0 <= height <= 1
                and class_id in class_dict.keys()
            ):
                # Return False if any condition is not met (invalid data)
                return False

        except Exception:
            # Return False in case of any exception (e.g., parsing error)
            return False

    # Return True if all lines in the label file pass the validation
    return True

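# Illustrative label line: a valid YOLO bbox entry has the form
#   "<class_id> <x_center> <y_center> <width> <height>"
# with all coordinates normalized to [0, 1], e.g. b"0 0.5 0.5 0.25 0.25".
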
# Function to validate each line in the label file for mask data
def check_masks_label(label_file, class_dict):
    for line in label_file:
        try:
            # Decode the line and split into parts: class ID and points
            parts = line.decode().strip().split()
            class_id = int(
                parts[0]
            )  # Convert the first part to an integer for class ID
            points = [
                float(p) for p in parts[1:]
            ]  # Convert the remaining parts to float for coordinates

            # Check if class ID exists in the class dictionary and all points are within [0, 1]
            if not (class_id in class_dict.keys() and all(0 <= p <= 1 for p in points)):
                return False  # Return False if validation fails

        except Exception:
            # Return False in case of any exception (e.g., parsing error)
            return False

    return True  # Return True if all lines in the label file pass the validation

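# Illustrative label line: a valid YOLO segmentation entry has the form
#   "<class_id> x1 y1 x2 y2 ... xn yn"
# with every polygon coordinate normalized to [0, 1].
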
# Function to read and validate a label file in YOLO format
def read_label(file, selected_option, class_dict):
    # Read the content of the file
    file_content = file.readlines()

    # Check and validate bounding box labels if the selected option is 'Bboxes'
    if selected_option == "Bboxes":
        return check_bboxes_label(file_content, class_dict)  # Validate bbox labels

    # Check and validate mask labels if the selected option is 'Masks'
    elif selected_option == "Masks":
        return check_masks_label(file_content, class_dict)  # Validate mask labels

    # Return False if the selected option is neither 'Bboxes' nor 'Masks'
    return False

# Function to check for duplicates
def check_file_duplicates(file_names):
    unique_names = set(file_names)
    return len(unique_names) == len(file_names)

# Function to validate the uploaded image and label files
def validate_files(image_names, label_names):
    # Check for duplicate filenames in both images and labels
    if not check_file_duplicates(image_names) or not check_file_duplicates(label_names):
        # Show warning if duplicates are found
        st.warning(
            "Duplicate file names detected. Please ensure each image and label has a unique name.",
            icon="⚠️",
        )
        return False  # Return False indicating validation failed

    # Check if the number of images matches the number of labels
    if len(image_names) != len(label_names):
        # Show warning if counts don't match
        st.warning(
            "Count Mismatch: The number of uploaded images and labels does not match.",
            icon="⚠️",
        )
        return False  # Return False indicating validation failed

    # Display a success message if the above checks pass
    st.info(
        f"Validated: {len(image_names)} images and labels successfully matched.",
        icon="✅",
    )
    return True  # Return True indicating successful validation

# Function to check label format and image/label consistency
@st.cache_resource(show_spinner=False)
def check_valid_labels(uploaded_files, selected_option, class_dict):
    # Check if no files were uploaded
    if len(uploaded_files) == 0:
        st.warning("Please upload images and labels.", icon="⚠️")
        return False

    # Initialize lists to store names of image and label files
    image_names, label_names = [], []

    # Initialize a progress bar and progress text
    progress_bar = st.progress(0)
    progress_text = st.empty()
    total_files = len(uploaded_files)

    # Iterate over each uploaded file
    for index, file in enumerate(uploaded_files):
        # Reset the file pointer to the beginning
        file.seek(0)

        # Check file type and categorize as image or label
        if file.type in ["image/jpeg", "image/png"]:
            # Add to image names list if file is an image
            image_names.append(file.name)
        elif file.type == "text/plain":
            # Read and validate label file
            if not read_label(file, selected_option, class_dict):
                # Show warning if label format or data is invalid
                st.warning(
                    f"Invalid label format or data in file: {file.name}", icon="⚠️"
                )
                return False
            # Add to label names list if file is a valid label
            label_names.append(file.name)

        # Update progress bar and display current progress
        progress_percentage = (index + 1) / total_files
        progress_bar.progress(progress_percentage)
        progress_text.text(f"Validating file {index + 1} of {total_files}")

    # Remove progress bar and progress text after processing
    progress_bar.empty()
    progress_text.empty()

    # Validate if all images have corresponding labels and vice versa
    return validate_files(image_names, label_names)

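# Illustrative usage sketch (hypothetical two-class bbox task):
#   class_dict = {0: "car", 1: "person"}
#   ok = check_valid_labels(uploaded_files, "Bboxes", class_dict)
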
# Function to get training, validation and export configurations
def get_training_validation_export_configuration(selected_training):
    with st.expander("Training Configuration"):
        # User Instruction for Default Values
        st.markdown(
            """
            <div style='text-align: justify;'>
            <b>User Instructions:</b> If you are unsure about the specific values to use for training parameters, it is
            recommended to stick with the default values provided. These defaults are carefully chosen to provide a good balance
            between performance and resource utilization for most scenarios. You can always come back and tweak these settings
            once you have more experience or specific requirements for your model training.
            </div>
            """,
            unsafe_allow_html=True,
        )

        # Padding
        utils.top_padding(2)

        # Training Configuration
        st.markdown("### Training Configuration")

        # Model Selection
        st.write("**Model Selection**")
        selected_model = st.selectbox(
            "Choose a YOLOv8 model variant", list(utils.models_info.keys())
        )
        model_spec = utils.models_info[selected_model]
        spec_string = (
            "<div style='text-align: justify;'>"
            f"The selected model, <b>{selected_model}</b>, is benchmarked on an image size of 640x640 pixels. It has a Mean Average Precision (mAPval) of <b>{model_spec['mAPval']}</b>, "
            f"operates with a speed of <b>{model_spec['speed_cpu']} ms</b> on CPU (ONNX) and <b>{model_spec['speed_gpu']} ms</b> on GPU (TensorRT). "
            f"It consists of approximately <b>{model_spec['params']} million</b> parameters and requires about <b>{model_spec['flops']} billion</b> Floating Point Operations (FLOPs)."
            "</div>"
        )
        st.markdown(spec_string, unsafe_allow_html=True)

        # Spacer
        st.markdown("---")

        # Time Configuration
        st.write("**Time Configuration**")
        col1_time, col2_time = st.columns([1, 3])
        with col1_time:
            top_padding_time = st.container()
            time_allow = st.checkbox("Enable Time", value=False)
        if time_allow:
            with top_padding_time:
                utils.top_padding(2)
            time = col2_time.number_input(
                "Time (hours)", min_value=1, max_value=100, value=1, step=1
            )
        else:
            time = None
        st.markdown(
            "<div style='text-align: justify;'>Set the training duration in hours. This option overrides the epochs setting. Useful for limiting training time in scenarios with constrained resources.</div>",
            unsafe_allow_html=True,
        )

        # Spacer
        st.markdown("---")

        # Epochs Configuration
        st.write("**Epochs Configuration**")
        epochs = st.number_input(
            "Epochs", min_value=1, max_value=1000, value=50, step=10
        )
        st.markdown(
            "<div style='text-align: justify;'>Define the number of epochs for the training process. An epoch represents a complete pass over the entire dataset. More epochs can improve accuracy but increase training time.</div>",
            unsafe_allow_html=True,
        )

        # Spacer
        st.markdown("---")

        # Patience Configuration
        st.write("**Patience Configuration**")
        col1_patience, col2_patience = st.columns([1, 3])
        with col1_patience:
            top_padding_patience = st.container()
            patience_allow = st.checkbox("Enable Patience", value=False)
        if patience_allow:
            with top_padding_patience:
                utils.top_padding(2)
            patience = col2_patience.number_input(
                "Patience (epochs)", min_value=5, max_value=50, value=5, step=1
            )
        else:
            patience = None
        st.markdown(
            "<div style='text-align: justify;'>Configure the early stopping mechanism. Patience denotes the number of epochs to wait for improvement in performance before stopping the training, helping to avoid overfitting.</div>",
            unsafe_allow_html=True,
        )

        # Spacer
        st.markdown("---")

        # Batch Size Configuration
        st.write("**Batch Size Configuration**")
        batch = st.number_input(
            "Batch Size", min_value=-1, max_value=128, value=-1, step=1
        )
        st.markdown(
            "<div style='text-align: justify;'>Determine the number of images processed together in one pass (batch). A larger batch size can lead to faster training but requires more memory. Use -1 for automatic batch sizing.</div>",
            unsafe_allow_html=True,
        )

        # Spacer
        st.markdown("---")

        # Image Size Configuration
        st.write("**Image Size Configuration**")
        imgsz = st.number_input(
            "Image Size (pixels)", min_value=64, max_value=4096, value=640, step=32
        )
        st.markdown(
            "<div style='text-align: justify;'>Specify the size of the input images. Larger images can capture more details but require more computational resources. The size is typically a square dimension, like 640x640 pixels.</div>",
            unsafe_allow_html=True,
        )

        # Spacer
        st.markdown("---")

        # Cache Configuration
        st.write("**Cache Configuration**")
        cache = st.selectbox("Cache Option", ["False", "True/ram", "disk"])
        st.markdown(
            "<div style='text-align: justify;'>Choose a caching method for data loading to speed up training. 'True/ram' caches data in RAM, 'disk' caches on disk, and 'False' disables caching.</div>",
            unsafe_allow_html=True,
        )

        # Spacer
        st.markdown("---")

        # Optimizer Configuration
        st.write("**Optimizer Configuration**")
        optimizer = st.selectbox(
            "Optimizer",
            ["SGD", "Adam", "Adamax", "AdamW", "NAdam", "RAdam", "RMSProp", "auto"],
            index=7,
        )
        st.markdown(
            "<div style='text-align: justify;'>Select the optimizer for training. The optimizer adjusts weights to minimize the loss function. Choices include SGD, Adam, and others, with 'auto' selecting automatically based on the model.</div>",
            unsafe_allow_html=True,
        )

        # Spacer
        st.markdown("---")

        # AMP Configuration
        st.write("**AMP Configuration**")
        amp = st.checkbox("Enable AMP", value=True)
        st.markdown(
            "<div style='text-align: justify;'>Enable Automatic Mixed Precision (AMP) to accelerate training on compatible hardware. AMP uses lower precision to reduce memory usage and speed up computations.</div>",
            unsafe_allow_html=True,
        )

        # Spacer
        st.markdown("---")

        # Deterministic Mode Configuration
        st.write("**Deterministic Mode Configuration**")
        deterministic = st.checkbox("Enable Deterministic Mode", value=False)
        st.markdown(
            "<div style='text-align: justify;'>Activate deterministic mode to ensure reproducible results. This mode might slow down the training but is useful for experimentation and debugging.</div>",
            unsafe_allow_html=True,
        )

        # Spacer
        st.markdown("---")

        # Rectangular Training Configuration
        st.write("**Rectangular Training Configuration**")
        rect = st.checkbox("Enable Rectangular Training", value=False)
        st.markdown(
            "<div style='text-align: justify;'>Enable rectangular training to process batches with minimal padding by reshaping images. This can lead to performance improvements but may affect accuracy.</div>",
            unsafe_allow_html=True,
        )

        # Spacer
        st.markdown("---")

        # Cosine Learning Rate Scheduler Configuration
        st.write("**Cosine Learning Rate Scheduler**")
        cos_lr = st.checkbox("Use Cosine LR Scheduler", value=False)
        st.markdown(
            "<div style='text-align: justify;'>Use a cosine learning rate scheduler to adjust the learning rate following a cosine curve, potentially leading to better convergence during training.</div>",
            unsafe_allow_html=True,
        )

        # Spacer
        st.markdown("---")

        # Freeze Layer Configuration
        st.write("**Freeze Layer Configuration**")
        col1_freeze, col2_freeze = st.columns([1, 3])
        with col1_freeze:
            top_padding_freeze = st.container()
            freeze_allow = st.checkbox("Enable Freeze Layers", value=False)
        if freeze_allow:
            with top_padding_freeze:
                utils.top_padding(2)
            freeze = col2_freeze.number_input(
                "Freeze Layers",
                min_value=1,
                max_value=1000,
                value=10,
                placeholder="Enter number of layers",
            )
        else:
            freeze = None
        st.markdown(
            "<div style='text-align: justify;'>Enable freezing the initial layers of the model during training. Specify the number of layers to freeze or a comma-separated list of specific layer indices. Useful for fine-tuning pre-trained models without modifying early layers.</div>",
            unsafe_allow_html=True,
        )

        # Spacer
        st.markdown("---")

        # Initial Learning Rate Configuration
        st.write("**Initial Learning Rate (lr0)**")
        lr0 = st.number_input(
            "Initial Learning Rate (lr0)",
            min_value=0.00001,
            max_value=1.0,
            value=0.01,
            format="%.5f",
        )
        st.markdown(
            "<div style='text-align: justify;'>Specify the initial learning rate (lr0) for the training process. The initial rate is crucial as it determines the starting step size for weight updates. A well-chosen initial rate helps in achieving a balance between fast convergence and overshooting the optimal solution.</div>",
            unsafe_allow_html=True,
        )

        # Spacer
        st.markdown("---")

        # Final Learning Rate Configuration
        st.write("**Final Learning Rate (lrf)**")
        lrf = st.number_input(
            "Final Learning Rate (lrf)",
            min_value=0.00001,
            max_value=1.0,
            value=0.01,
            format="%.5f",
        )
        st.markdown(
            "<div style='text-align: justify;'>Determine the final learning rate, which is a factor (lrf) of the initial learning rate (lr0). This parameter is used to adjust the learning rate over the course of training, gradually decreasing it to fine-tune model weights and stabilize training as it approaches the minimum of the loss function.</div>",
            unsafe_allow_html=True,
        )

        # Spacer
        st.markdown("---")

        # Momentum Configuration
        st.write("**Momentum Configuration**")
        momentum = st.number_input(
            "Momentum", min_value=0.0, max_value=1.0, value=0.937, format="%.3f"
        )
        st.markdown(
            "<div style='text-align: justify;'>Set the momentum value for the optimizer. Momentum helps in accelerating the optimizer in the relevant direction and dampens oscillations, facilitating faster convergence.</div>",
            unsafe_allow_html=True,
        )

        # Spacer
        st.markdown("---")

        # Weight Decay Configuration
        st.write("**Weight Decay Configuration**")
        weight_decay = st.number_input(
            "Weight Decay", min_value=0.0, max_value=0.1, value=0.0005, format="%.5f"
        )
        st.markdown(
            "<div style='text-align: justify;'>Specify the weight decay, a regularization technique that adds a small penalty to the loss function for larger weights. It helps in preventing overfitting by encouraging simpler models.</div>",
            unsafe_allow_html=True,
        )

        # Spacer
        st.markdown("---")

        # Warmup Epochs Configuration
        st.write("**Warmup Epochs Configuration**")
        warmup_epochs = st.number_input(
            "Warmup Epochs", min_value=0.0, max_value=10.0, value=3.0, step=0.1
        )
        st.markdown(
            "<div style='text-align: justify;'>Define the number of warmup epochs. During warmup, the learning rate gradually increases to its initial value, which helps in stabilizing the training process in its early stages.</div>",
            unsafe_allow_html=True,
        )

        # Spacer
        st.markdown("---")

        # Warmup Momentum Configuration
        st.write("**Warmup Momentum Configuration**")
        warmup_momentum = st.number_input(
            "Warmup Momentum", min_value=0.0, max_value=1.0, value=0.8, format="%.1f"
        )
        st.markdown(
            "<div style='text-align: justify;'>Configure the momentum during the warmup phase. A lower momentum at the start can help in stabilizing the optimization process before reaching the specified momentum for the remaining epochs.</div>",
            unsafe_allow_html=True,
        )

        # Spacer
        st.markdown("---")

        # Warmup Bias Learning Rate Configuration
        st.write("**Warmup Bias Learning Rate Configuration**")
        warmup_bias_lr = st.number_input(
            "Warmup Bias Learning Rate",
            min_value=0.0,
            max_value=1.0,
            value=0.1,
            format="%.1f",
        )
        st.markdown(
            "<div style='text-align: justify;'>Adjust the bias learning rate during the warmup period. This parameter can be tuned to manage the initial learning rate specifically for the bias parameters in the early training phase.</div>",
            unsafe_allow_html=True,
        )

        # Spacer
        st.markdown("---")

        # Box Loss Gain Configuration
        st.write("**Box Loss Gain Configuration**")
        box = st.number_input(
            "Box Loss Gain", min_value=0.0, max_value=10.0, value=7.5, step=0.1
        )
        st.markdown(
            "<div style='text-align: justify;'>Configure the gain factor for the box loss. This gain helps in adjusting the importance of the box size and location accuracy in the loss function, affecting how the model prioritizes bounding box precision.</div>",
            unsafe_allow_html=True,
        )

        # Spacer
        st.markdown("---")

        # Class Loss Gain Configuration
        st.write("**Class Loss Gain Configuration**")
        cls = st.number_input(
            "Class Loss Gain", min_value=0.0, max_value=10.0, value=0.5, step=0.1
        )
        st.markdown(
            "<div style='text-align: justify;'>Set the gain factor for the class loss. This parameter scales the contribution of class prediction accuracy in the total loss, influencing how the model prioritizes correct class identification.</div>",
            unsafe_allow_html=True,
        )

        # Spacer
        st.markdown("---")

        # DFL Loss Gain Configuration
        st.write("**DFL Loss Gain Configuration**")
        dfl = st.number_input(
            "DFL Loss Gain", min_value=0.0, max_value=10.0, value=1.5, step=0.1
        )
        st.markdown(
            "<div style='text-align: justify;'>Determine the gain factor for the DFL loss. Adjusting this gain influences the model's focus on the Distribution Focal Loss component, which is critical for precise object localization and classification.</div>",
            unsafe_allow_html=True,
        )

        # Spacer
        st.markdown("---")

        # Label Smoothing Configuration
        st.write("**Label Smoothing Configuration**")
        label_smoothing = st.number_input(
            "Label Smoothing (fraction)",
            min_value=0.0,
            max_value=1.0,
            value=0.0,
            format="%.1f",
        )
        st.markdown(
            "<div style='text-align: justify;'>Specify the label smoothing value, a technique that introduces softening to the target labels. It promotes model generalization and reduces the impact of noisy labels on the training process.</div>",
            unsafe_allow_html=True,
        )

        # Spacer
        st.markdown("---")

        # Nominal Batch Size Configuration
        st.write("**Nominal Batch Size Configuration**")
        nbs = st.number_input(
            "Nominal Batch Size", min_value=1, max_value=128, value=64, step=1
        )
        st.markdown(
            "<div style='text-align: justify;'>Set the nominal batch size, which is used for normalizing the loss. This size does not affect the actual batch size but is used to scale the loss to a standard reference batch size.</div>",
            unsafe_allow_html=True,
        )

        # Spacer
        st.markdown("---")

        # Overlap Mask Configuration
        st.write("**Overlap Mask Configuration**")
        overlap_mask = st.checkbox("Masks Overlap during Training", value=True)
        st.markdown(
            "<div style='text-align: justify;'>Choose whether to allow masks to overlap during instance segmentation training. Overlapping can lead to more precise segmentation but may increase complexity.</div>",
            unsafe_allow_html=True,
        )

        # Spacer
        st.markdown("---")

        # Mask Ratio Configuration
        st.write("**Mask Ratio Configuration**")
        mask_ratio = st.number_input(
            "Mask Downsample Ratio", min_value=1, max_value=10, value=4, step=1
        )
        st.markdown(
            "<div style='text-align: justify;'>Set the downsample ratio for masks in instance segmentation. A higher ratio reduces the mask resolution, which can speed up computations but might decrease segmentation accuracy.</div>",
            unsafe_allow_html=True,
        )

        # Spacer
        st.markdown("---")

        # Dropout Configuration
        st.write("**Dropout Configuration**")
        dropout = st.number_input(
            "Dropout Regularization",
            min_value=0.0,
            max_value=1.0,
            value=0.0,
            format="%.1f",
        )
        st.markdown(
            "<div style='text-align: justify;'>Configure the dropout rate, which randomly disables a proportion of neurons during training. This prevents the model from relying too much on certain features and promotes better generalization.</div>",
            unsafe_allow_html=True,
        )

        # Spacer
        st.markdown("---")

        # Validation/Test Configuration
        st.write("**Validation/Test Configuration**")
        val = st.checkbox("Validate/Test during Training", value=True)
        st.markdown(
            "<div style='text-align: justify;'>Decide whether to perform validation and testing during the training process. Regular validation helps monitor model performance and adjust training accordingly.</div>",
            unsafe_allow_html=True,
        )

        # Spacer
        st.markdown("---")

        # Save Plots Configuration
        st.write("**Save Plots Configuration**")
        plots = st.checkbox("Save Plots and Images during Training", value=True)
        st.markdown(
            "<div style='text-align: justify;'>Enable saving of plots and images during training. This feature provides visual insights into the training progress and helps in diagnosing model performance across epochs.</div>",
            unsafe_allow_html=True,
        )

    # Padding
    utils.top_padding(2)

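    # Illustrative note (assumption, not original code): the values gathered in
    # the expander above map onto Ultralytics training kwargs, so a later
    # training call could look roughly like:
    #   YOLO(model_path).train(data=yaml_path, epochs=epochs, imgsz=imgsz,
    #                          batch=batch, optimizer=optimizer, amp=amp, ...)
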
    with st.expander("Validation Configuration"):
        # User Instruction for Default Values
        st.markdown(
            """
            <div style='text-align: justify;'>
            <b>User Instructions:</b> If you are unsure about the specific values to use for validation parameters, it is
            recommended to stick with the default values provided. These defaults are carefully chosen to provide a good balance
            between performance and resource utilization for most scenarios. You can always come back and tweak these settings
            once you have more experience or specific requirements for your model validation.
            </div>
            """,
            unsafe_allow_html=True,
        )

        # Padding
        utils.top_padding(2)

        # Validation Configuration
        st.markdown("### Validation Configuration")

        # Object Confidence Threshold
        st.write("**Object Confidence Threshold**")
        conf = st.number_input(
            "Confidence Threshold",
            min_value=0.0,
            max_value=1.0,
            value=0.001,
            format="%.3f",
        )
        st.markdown(
            "<div style='text-align: justify;'>Set the confidence threshold for object detection. This threshold filters out detections with lower confidence, reducing false positives and focusing on more likely object detections.</div>",
            unsafe_allow_html=True,
        )

        # Spacer
        st.markdown("---")

        # Intersection Over Union (IoU) Threshold
        st.write("**IoU Threshold for NMS**")
        iou = st.number_input(
            "IoU Threshold", min_value=0.0, max_value=1.0, value=0.6, format="%.1f"
        )
        st.markdown(
            "<div style='text-align: justify;'>Define the IoU threshold for Non-Maximum Suppression. NMS is used to refine the bounding boxes by eliminating redundancies and retaining the most probable ones.</div>",
            unsafe_allow_html=True,
        )

        # Spacer
        st.markdown("---")

        # Maximum Number of Detections
        st.write("**Maximum Number of Detections**")
        max_det = st.number_input(
            "Max Detections", min_value=1, max_value=1000, value=300, step=1
        )
        st.markdown(
            "<div style='text-align: justify;'>Limit the maximum number of detections per image. This setting is crucial for controlling the computational load and focusing the model on the most confident and relevant detections.</div>",
            unsafe_allow_html=True,
        )

        # Spacer
        st.markdown("---")

        # Use Half Precision
        st.write("**Use Half Precision (FP16)**")
        half = st.checkbox("Enable Half Precision", value=True)
        st.markdown(
            "<div style='text-align: justify;'>Enable half precision (FP16) training for enhanced performance on compatible GPUs. It reduces memory requirements and accelerates computation, beneficial for larger models and datasets.</div>",
            unsafe_allow_html=True,
        )

    # Padding
    utils.top_padding(2)

841 |
+
with st.expander("Export Configuration"):
|
842 |
+
# User Instruction for Default Values
|
843 |
+
st.markdown(
|
844 |
+
"""
|
845 |
+
<div style='text-align: justify;'>
|
846 |
+
<b>User Instructions:</b> If you are unsure about the specific values to use for export parameters, it is
|
847 |
+
recommended to stick with the default values provided. These defaults are carefully chosen to provide a good balance
|
848 |
+
between performance and resource utilization for most scenarios. You can always come back and tweak these settings
|
849 |
+
once you have more experience or specific requirements for your model export.
|
850 |
+
</div>
|
851 |
+
""",
|
852 |
+
unsafe_allow_html=True,
|
853 |
+
)
|
854 |
+
|
855 |
+
# Padding
|
856 |
+
utils.top_padding(2)
|
857 |
+
|
858 |
+
# Validation Configuration
|
859 |
+
st.markdown("### Export Configuration")
|
860 |
+
|
861 |
+
# Select Export Format
|
862 |
+
st.write("**Export Format**")
|
863 |
+
export_format = st.selectbox(
|
864 |
+
"Select Export Format",
|
865 |
+
[
|
866 |
+
"Only PyTorch",
|
867 |
+
"TorchScript",
|
868 |
+
"ONNX",
|
869 |
+
"OpenVINO",
|
870 |
+
"TensorRT",
|
871 |
+
"CoreML",
|
872 |
+
"TF SavedModel",
|
873 |
+
"TF GraphDef",
|
874 |
+
"TF Lite",
|
875 |
+
"TF Edge TPU",
|
876 |
+
"TF.js",
|
877 |
+
"PaddlePaddle",
|
878 |
+
"ncnn",
|
879 |
+
],
|
880 |
+
)
|
881 |
+
|
882 |
+
# Dynamically generate description
|
883 |
+
if export_format == "Only PyTorch":
|
884 |
+
st.markdown(
|
885 |
+
"""
|
886 |
+
<div style='text-align: justify;'>
|
887 |
+
You have selected <b>PyTorch</b> as the export format.
|
888 |
+
This will export the model in the standard PyTorch <code>.pt</code> format.
|
889 |
+
There are no additional format-specific parameters to consider for this selection.
|
890 |
+
The exported model will be the same as selected during training.
|
891 |
+
</div>
|
892 |
+
""",
|
893 |
+
unsafe_allow_html=True,
|
894 |
+
)
|
895 |
+
else:
|
896 |
+
format_info = utils.export_formats[export_format]
|
897 |
+
|
898 |
+
# Handling additional arguments
|
899 |
+
if len(format_info["arguments"]) > 0:
|
900 |
+
additional_arguments = ", ".join(format_info["arguments"])
|
901 |
+
arguments_info = f"Consider the following arguments for the <b>{export_format}</b> format: {additional_arguments}."
|
902 |
+
else:
|
903 |
+
arguments_info = (
|
904 |
+
"No additional parameters need to be considered for this format."
|
905 |
+
)
|
906 |
+
|
907 |
+
st.markdown(
|
908 |
+
f"""
|
909 |
+
<div style='text-align: justify;'>
|
910 |
+
You have selected <b>{export_format}</b> as the export format. Along with the PyTorch model,
|
911 |
+
this selection will also export the model in the <b>{export_format}</b> format. The image size of
|
912 |
+
the exported model will be the same as selected during training. {arguments_info}
|
913 |
+
</div>
|
914 |
+
""",
|
915 |
+
unsafe_allow_html=True,
|
916 |
+
)
|
917 |
+
|
918 |
+
# Spacer
|
919 |
+
st.markdown("---")
|
920 |
+
|
921 |
+
# Use Keras for TF SavedModel export
|
922 |
+
st.write("**Use Keras for TF SavedModel Export**")
|
923 |
+
keras = st.checkbox("Enable Keras", value=False)
|
924 |
+
st.markdown(
|
925 |
+
"<div style='text-align: justify;'>Enabling Keras optimizes the TensorFlow SavedModel export for compatibility with the Keras API, making it easier to work with in Keras-centric workflows.</div>",
|
926 |
+
unsafe_allow_html=True,
|
927 |
+
)
|
928 |
+
|
929 |
+
# Spacer
|
930 |
+
st.markdown("---")
|
931 |
+
|
932 |
+
# Optimize for mobile (TorchScript)
|
933 |
+
st.write("**Optimize TorchScript for Mobile**")
|
934 |
+
optimize = st.checkbox("Enable Optimization", value=False)
|
935 |
+
st.markdown(
|
936 |
+
"<div style='text-align: justify;'>Optimizing for mobile reduces the model size and computational needs, enhancing performance on mobile devices with limited resources.</div>",
|
937 |
+
unsafe_allow_html=True,
|
938 |
+
)
|
939 |
+
|
940 |
+
# Spacer
|
941 |
+
st.markdown("---")
|
942 |
+
|
943 |
+
# FP16 quantization
|
944 |
+
st.write("**FP16 Quantization**")
|
945 |
+
half = st.checkbox("Enable FP16 Quantization", value=False)
|
946 |
+
st.markdown(
|
947 |
+
"<div style='text-align: justify;'>FP16 quantization reduces model size and speeds up inference, especially on GPUs with Tensor Cores, while maintaining model accuracy.</div>",
|
948 |
+
unsafe_allow_html=True,
|
949 |
+
)
|
950 |
+
|
951 |
+
# Spacer
|
952 |
+
st.markdown("---")
|
953 |
+
|
954 |
+
# INT8 quantization
|
955 |
+
st.write("**INT8 Quantization**")
|
956 |
+
int8 = st.checkbox("Enable INT8 Quantization", value=False)
|
957 |
+
st.markdown(
|
958 |
+
"<div style='text-align: justify;'>INT8 quantization further reduces model size and inference time, ideal for edge devices, at the cost of a slight decrease in accuracy.</div>",
|
959 |
+
unsafe_allow_html=True,
|
960 |
+
)
|
961 |
+
|
962 |
+
# Spacer
|
963 |
+
st.markdown("---")
|
964 |
+
|
965 |
+
# Dynamic axes for ONNX/TensorRT
|
966 |
+
st.write("**Dynamic Axes for ONNX/TensorRT**")
|
967 |
+
dynamic = st.checkbox("Enable Dynamic Axes", value=False)
|
968 |
+
st.markdown(
|
969 |
+
"<div style='text-align: justify;'>Dynamic axes allow the ONNX/TensorRT models to handle variable input sizes, increasing the model's flexibility in deployment.</div>",
|
970 |
+
unsafe_allow_html=True,
|
971 |
+
)
|
972 |
+
|
973 |
+
# Spacer
|
974 |
+
st.markdown("---")
|
975 |
+
|
976 |
+
# Simplify model for ONNX/TensorRT
|
977 |
+
st.write("**Simplify Model for ONNX/TensorRT**")
|
978 |
+
simplify = st.checkbox("Enable Model Simplification", value=False)
|
979 |
+
st.markdown(
|
980 |
+
"<div style='text-align: justify;'>Simplification optimizes the ONNX/TensorRT models by removing redundant operations, improving efficiency without impacting accuracy.</div>",
|
981 |
+
unsafe_allow_html=True,
|
982 |
+
)
|
983 |
+
|
984 |
+
# Spacer
|
985 |
+
st.markdown("---")
|
986 |
+
|
987 |
+
# ONNX Opset Version Configuration
|
988 |
+
st.write("**ONNX Opset Version Configuration**")
|
989 |
+
col1_opset, col2_opset = st.columns([1, 3])
|
990 |
+
|
991 |
+
with col1_opset:
|
992 |
+
top_padding_opset = st.container()
|
993 |
+
opset_allow = st.checkbox("Specify Opset Version", value=False)
|
994 |
+
|
995 |
+
if opset_allow:
|
996 |
+
with top_padding_opset:
|
997 |
+
utils.top_padding(2)
|
998 |
+
# Create a range of opset versions for the dropdown
|
999 |
+
opset_versions = list(range(1, onnx_opset_version() + 1))
|
1000 |
+
|
1001 |
+
with col2_opset:
|
1002 |
+
opset = st.selectbox(
|
1003 |
+
"Select Opset Version",
|
1004 |
+
opset_versions,
|
1005 |
+
index=len(opset_versions) - 1,
|
1006 |
+
)
|
1007 |
+
else:
|
1008 |
+
opset = None
|
1009 |
+
|
1010 |
+
st.markdown(
|
1011 |
+
"<div style='text-align: justify;'>Select the ONNX opset version for the export. "
|
1012 |
+
"Specifying an opset version can ensure compatibility with specific ONNX versions. "
|
1013 |
+
"The latest version is recommended to ensure the most up-to-date features and optimizations. "
|
1014 |
+
"If unsure, leave the checkbox unchecked to use the default opset version.</div>",
|
1015 |
+
unsafe_allow_html=True,
|
1016 |
+
)
|
1017 |
+
|
1018 |
+
# Spacer
|
1019 |
+
st.markdown("---")
|
1020 |
+
|
1021 |
+
# TensorRT workspace size
|
1022 |
+
st.write("**TensorRT Workspace Size (GB)**")
|
1023 |
+
workspace = st.number_input(
|
1024 |
+
"Workspace Size", min_value=1, max_value=32, value=4, step=1
|
1025 |
+
)
|
1026 |
+
st.markdown(
|
1027 |
+
"<div style='text-align: justify;'>Set the TensorRT workspace size in GB. A larger workspace can lead to more optimized models but requires more memory.</div>",
|
1028 |
+
unsafe_allow_html=True,
|
1029 |
+
)
|
1030 |
+
|
1031 |
+
# Spacer
|
1032 |
+
st.markdown("---")
|
1033 |
+
|
1034 |
+
# Add NMS for CoreML
|
1035 |
+
st.write("**Add NMS for CoreML**")
|
1036 |
+
nms = st.checkbox("Enable NMS", value=False)
|
1037 |
+
st.markdown(
|
1038 |
+
"<div style='text-align: justify;'>Enabling NMS (Non-Maximum Suppression) for CoreML models helps in reducing overlapping bounding boxes and improves the clarity of object detection results.</div>",
|
1039 |
+
unsafe_allow_html=True,
|
1040 |
+
)
|
1041 |
+
|
1042 |
+
# Padding
|
1043 |
+
utils.top_padding(2)
|
1044 |
+
|
1045 |
+
if selected_training == "Object Detection":
|
1046 |
+
model_path = os.path.join(
|
1047 |
+
get_path("models"), selected_model.lower() + ".pt"
|
1048 |
+
)
|
1049 |
+
task = "detect"
|
1050 |
+
elif selected_training == "Instance Segmentation":
|
1051 |
+
model_path = os.path.join(
|
1052 |
+
get_path("models"), selected_model.lower() + "-seg.pt"
|
1053 |
+
)
|
1054 |
+
task = "segment"
|
1055 |
+
|
1056 |
+
export_settings = {
|
1057 |
+
"format": None if export_format == "Only PyTorch" else export_format,
|
1058 |
+
"keras": keras,
|
1059 |
+
"optimize": optimize,
|
1060 |
+
"half": half,
|
1061 |
+
"int8": int8,
|
1062 |
+
"dynamic": dynamic,
|
1063 |
+
"simplify": simplify,
|
1064 |
+
"opset": opset,
|
1065 |
+
"workspace": workspace,
|
1066 |
+
"nms": nms,
|
1067 |
+
}
|
1068 |
+
|
1069 |
+
return {
|
1070 |
+
"model_path": model_path,
|
1071 |
+
"task": task,
|
1072 |
+
"model": selected_model,
|
1073 |
+
"time": time,
|
1074 |
+
"epochs": epochs,
|
1075 |
+
"patience": patience,
|
1076 |
+
"batch": batch,
|
1077 |
+
"imgsz": imgsz,
|
1078 |
+
"cache": cache,
|
1079 |
+
"optimizer": optimizer,
|
1080 |
+
"amp": amp,
|
1081 |
+
"deterministic": deterministic,
|
1082 |
+
"rect": rect,
|
1083 |
+
"cos_lr": cos_lr,
|
1084 |
+
"freeze": freeze,
|
1085 |
+
"lr0": lr0,
|
1086 |
+
"lrf": lrf,
|
1087 |
+
"momentum": momentum,
|
1088 |
+
"weight_decay": weight_decay,
|
1089 |
+
"warmup_epochs": warmup_epochs,
|
1090 |
+
"warmup_momentum": warmup_momentum,
|
1091 |
+
"warmup_bias_lr": warmup_bias_lr,
|
1092 |
+
"box": box,
|
1093 |
+
"cls": cls,
|
1094 |
+
"dfl": dfl,
|
1095 |
+
"label_smoothing": label_smoothing,
|
1096 |
+
"nbs": nbs,
|
1097 |
+
"overlap_mask": overlap_mask,
|
1098 |
+
"mask_ratio": mask_ratio,
|
1099 |
+
"dropout": dropout,
|
1100 |
+
"val": val,
|
1101 |
+
"plots": plots,
|
1102 |
+
"conf": conf,
|
1103 |
+
"iou": iou,
|
1104 |
+
"max_det": max_det,
|
1105 |
+
"half": half,
|
1106 |
+
"export_settings": export_settings,
|
1107 |
+
}
|
1108 |
+
|
1109 |
+
|
1110 |
+
# Function to generate Python code for model training
def generate_python_code_model_training(training_configuration):
    # Copy the original configuration and update with additional parameters
    training_configuration_code = training_configuration.copy()
    training_configuration_code["data"] = r".\config.yaml"  # Path to config file
    training_configuration_code["save_dir"] = r".\output\train"  # Output directory
    training_configuration_code["pretrained"] = True  # Use a pretrained model
    training_configuration_code["save"] = True  # Save the trained model
    training_configuration_code["save_period"] = -1  # Save period configuration
    training_configuration_code["augment"] = False  # Augmentation setting
    training_configuration_code["seed"] = 0  # Seed for reproducibility
    training_configuration_code["verbose"] = True  # Verbose output
    training_configuration_code["single_cls"] = False  # Single class setting
    training_configuration_code["resume"] = False  # Resume training setting
    training_configuration_code["exist_ok"] = True  # Overwrite existing files
    training_configuration_code["project"] = r".\output"  # Project directory
    training_configuration_code["name"] = "train"  # Project name

    # Extract the model name from the model path (basename works on any OS,
    # unlike splitting on a Windows-only backslash)
    model_name = os.path.basename(training_configuration_code["model_path"])

    # Start with necessary library imports and model initialization
    code_str = "# Importing necessary libraries\n"
    code_str += "from ultralytics import YOLO\n\n"

    # Initialize the YOLO model
    code_str += f"# Initialize the YOLO model '{model_name}'\n"
    code_str += f"model = YOLO('{model_name}')\n"

    # Add the model training code; !r emits each value as a valid Python literal
    code_str += "\n# Start the training process\n"
    code_str += "model.train(\n"
    for key, value in training_configuration_code.items():
        if key not in [
            "model_path",
            "model",
            "export_settings",
        ]:  # Exclude specific keys
            code_str += f"    {key}={value!r},\n"

    code_str = code_str.rstrip(",\n") + "\n)\n"

    # Add model export code
    code_str += "\n# Model export process\n"
    code_str += "model.export(\n"
    for key, value in training_configuration_code["export_settings"].items():
        if key == "format" and value is None:
            continue  # Skip format if it's None
        code_str += f"    {key}={value!r},\n"
    code_str = code_str.rstrip(",\n") + "\n)\n"

    return code_str

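# Illustrative note on the generator above: because each setting is interpolated
# with !r, string values come out as quoted Python literals in the generated
# script (e.g. task='detect', optimizer='auto', data='.\\config.yaml'), while
# numbers and booleans such as epochs=2 or amp=True are emitted unchanged, so
# the generated script parses cleanly.
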
# Function to overwrite a Python file with new code
def overwrite_python_file(code_str, file_path):
    # Open the file in write mode, which automatically deletes old content
    with open(file_path, "w") as file:
        file.write(code_str)


# Function to display the generated code and provide a download button
def display_code_and_download_button(generated_code):
    # Display the generated code in Streamlit with description and download button in columns
    with st.expander("Plug and Play Code"):
        col1, col2 = st.columns([7, 3])

        with col1:
            st.markdown(
                """
                ### Description of the Code Pipeline
                """
            )

            st.markdown(
                """
                <div style='text-align: justify;'>
                This Python script is configured for training a YOLO model. It includes the necessary configurations and parameters for a custom YOLO model training session.

                **To use this script:**
                - Ensure you have the necessary dependencies installed.
                - Place your image and label files in the `'datasets/train'`, `'datasets/val'`, and `'datasets/test'` folders respectively.
                - The `'config.yaml'` file and the training script are set up based on your provided configurations.

                ### Python Code
                </div>
                """,
                unsafe_allow_html=True,
            )

            # Display the Python code
            st.code(generated_code, language="python")

            # Determine the main directory path
            main_directory_path = os.path.dirname(
                os.path.dirname(os.path.abspath(__file__))
            )

            # Overwrite the pipeline script with the newly generated code
            overwrite_python_file(
                generated_code,
                os.path.join(
                    main_directory_path,
                    "model_data",
                    "model_training_code_pipline",
                    "model_training.py",
                ),
            )

            # Prepare a ZIP file of the code pipeline folder in memory for download
            zip_bytes_io = zip_folder_to_bytesio(
                os.path.join(
                    main_directory_path, "model_data", "model_training_code_pipline"
                )
            )

        with col2:
            # Create a button for downloading the training pipeline
            st.download_button(
                label="Download Training Pipeline",
                data=zip_bytes_io,
                file_name="model_training_code.zip",
                mime="application/zip",
                use_container_width=True,
            )


# Function to generate a YOLO model training code snippet and display it with a download button
def generate_and_display_yolo_training_code(class_labels, training_configuration):
    # Determine the main directory path
    main_directory_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    # Construct the path to the config file directory
    config_file_path = os.path.join(
        main_directory_path, "model_data", "model_training_code_pipline"
    )

    # Define the path to the dataset directory
    dataset_directory_path = "./datasets"

    # Create the YOLO config file using the provided class labels and dataset directory
    create_yolo_config_file(config_file_path, class_labels, dataset_directory_path)

    # Generate the Python code for YOLO model training
    generated_code = generate_python_code_model_training(training_configuration)

    # Display the generated code and a download button
    display_code_and_download_button(generated_code)


# Function to create a YOLO config file
def create_yolo_config_file(
    config_file_path, class_labels, dataset_directory_path=None
):
    if dataset_directory_path is None:
        dataset_directory_path = os.path.join(config_file_path, "datasets")

    # Number of classes
    num_classes = len(class_labels)

    # Create the configuration content
    config_content = f"""path: {dataset_directory_path} # Path to the dataset directory
train: train # Path to the training set directory
val: val # Path to the validation set directory
test: test # Path to the testing set directory
nc: {num_classes} # Number of classes
names: {class_labels} # List of class names
"""

    # Write the configuration to a file
    with open(os.path.join(config_file_path, "config.yaml"), "w") as file:
        file.write(config_content)


# Function to delete and recreate a folder
def delete_and_recreate_folder(folder_path):
    try:
        # Use shutil.rmtree to delete the folder and its contents
        shutil.rmtree(folder_path)
        # Recreate the folder at the same location
        os.makedirs(folder_path)
    except Exception as e:
        print(f"Error deleting or recreating folder {folder_path}: {e}")


# Function to read a CSV file and return its columns as arrays
def read_csv_and_get_values(csv_file_path):
    # Read the CSV file into a pandas DataFrame
    df = pd.read_csv(csv_file_path)

    # Initialize an empty dictionary to store the results
    result_dict = {}

    # Iterate through the columns of the DataFrame
    for column in df.columns:
        # Remove leading and trailing spaces from the column name
        clean_column_name = column.strip()

        # Get the values in the column
        column_values = df[column].astype(float)

        # Add the cleaned column name and values to the result dictionary
        result_dict[clean_column_name] = np.array(column_values)

    return result_dict

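# Illustration (hypothetical numbers): Ultralytics writes results.csv with
# space-padded column headers, which is why the names are stripped above.
# A file beginning
#
#     epoch, train/box_loss, metrics/precision(B), ...
#         0,         1.2345,               0.5012, ...
#
# would be returned as
#     {"epoch": np.array([0.0]), "train/box_loss": np.array([1.2345]), ...}
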
# Global variables
plot_container = None
val_dataframe_container = None
progress_bar = None
progress_text = None


# Function to define a custom callback function for on_pretrain_routine_start
def on_pretrain_routine_start(trainer):
    global progress_text, progress_bar
    progress_bar = st.empty()
    progress_text = st.empty()
    progress_text.info(
        "Loading selected model...",
        icon="✅",
    )


# Function to define a custom callback function for on_train_start
def on_train_start(trainer):
    global progress_bar, progress_text
    progress_bar = st.progress(0)
    progress_text.info(
        "Training Started...",
        icon="✅",
    )


# Function to display the metrics plot
@st.cache_resource(show_spinner=False)
def display_metrics_plot(output_data):
    global plot_container

    # Extract data for each metric
    epoch_history = output_data.get("epoch")

    # Extract loss histories
    train_box_loss_history = output_data.get("train/box_loss")
    train_cls_loss_history = output_data.get("train/cls_loss")
    train_dfl_loss_history = output_data.get("train/dfl_loss")
    train_seg_loss_history = output_data.get("train/seg_loss")
    val_box_loss_history = output_data.get("val/box_loss")
    val_cls_loss_history = output_data.get("val/cls_loss")
    val_dfl_loss_history = output_data.get("val/dfl_loss")
    val_seg_loss_history = output_data.get("val/seg_loss")

    # Detection runs have no segmentation losses, so plot flat zero lines instead
    if train_seg_loss_history is None:
        train_seg_loss_history = epoch_history * 0
        val_seg_loss_history = epoch_history * 0

    # Extract precision, recall, and mAP histories for box (B) and mask (M) metrics
    precision_B_history = output_data.get("metrics/precision(B)")
    recall_B_history = output_data.get("metrics/recall(B)")
    mAP50_B_history = output_data.get("metrics/mAP50(B)")
    mAP50_95_B_history = output_data.get("metrics/mAP50-95(B)")
    precision_M_history = output_data.get("metrics/precision(M)")
    recall_M_history = output_data.get("metrics/recall(M)")
    mAP50_M_history = output_data.get("metrics/mAP50(M)")
    mAP50_95_M_history = output_data.get("metrics/mAP50-95(M)")

    # Check for missing mask data and adjust the number of rows in the grid
    num_rows = 4
    subplot_titles = [
        "Precision B",
        "Recall B",
        "mAP50 B",
        "mAP50-95 B",
        "Precision M",
        "Recall M",
        "mAP50 M",
        "mAP50-95 M",
        "Train Box Loss",
        "Train Class Loss",
        "Train DFL Loss",
        "Train Seg Loss",
        "Val Box Loss",
        "Val Class Loss",
        "Val DFL Loss",
        "Val Seg Loss",
    ]
    if precision_M_history is None:
        num_rows = 3
        subplot_titles = subplot_titles[0:4] + subplot_titles[8:]

    # Create a subplot grid
    fig = make_subplots(
        rows=num_rows,
        cols=4,
        subplot_titles=subplot_titles,
        vertical_spacing=0.05,
    )

    # Initialize row number
    row_number = 1

    # Add precision, recall, and mAP plots for the box (B) metrics
    fig.add_trace(
        go.Scatter(
            x=epoch_history, y=precision_B_history, mode="lines", name="Precision B"
        ),
        row=row_number,
        col=1,
    )
    fig.add_trace(
        go.Scatter(x=epoch_history, y=recall_B_history, mode="lines", name="Recall B"),
        row=row_number,
        col=2,
    )
    fig.add_trace(
        go.Scatter(x=epoch_history, y=mAP50_B_history, mode="lines", name="mAP50 B"),
        row=row_number,
        col=3,
    )
    fig.add_trace(
        go.Scatter(
            x=epoch_history, y=mAP50_95_B_history, mode="lines", name="mAP50-95 B"
        ),
        row=row_number,
        col=4,
    )

    if precision_M_history is not None:
        # Increment row number
        row_number += 1

        # Add precision, recall, and mAP plots for the mask (M) metrics
        fig.add_trace(
            go.Scatter(
                x=epoch_history, y=precision_M_history, mode="lines", name="Precision M"
            ),
            row=row_number,
            col=1,
        )
        fig.add_trace(
            go.Scatter(
                x=epoch_history, y=recall_M_history, mode="lines", name="Recall M"
            ),
            row=row_number,
            col=2,
        )
        fig.add_trace(
            go.Scatter(
                x=epoch_history, y=mAP50_M_history, mode="lines", name="mAP50 M"
            ),
            row=row_number,
            col=3,
        )
        fig.add_trace(
            go.Scatter(
                x=epoch_history, y=mAP50_95_M_history, mode="lines", name="mAP50-95 M"
            ),
            row=row_number,
            col=4,
        )

    # Increment row number
    row_number += 1

    # Add training loss plots
    fig.add_trace(
        go.Scatter(
            x=epoch_history,
            y=train_box_loss_history,
            mode="lines",
            name="Train Box Loss",
        ),
        row=row_number,
        col=1,
    )
    fig.add_trace(
        go.Scatter(
            x=epoch_history,
            y=train_cls_loss_history,
            mode="lines",
            name="Train Class Loss",
        ),
        row=row_number,
        col=2,
    )
    fig.add_trace(
        go.Scatter(
            x=epoch_history,
            y=train_dfl_loss_history,
            mode="lines",
            name="Train DFL Loss",
        ),
        row=row_number,
        col=3,
    )
    fig.add_trace(
        go.Scatter(
            x=epoch_history,
            y=train_seg_loss_history,
            mode="lines",
            name="Train Seg Loss",
        ),
        row=row_number,
        col=4,
    )

    # Increment row number
    row_number += 1

    # Add validation loss plots
    fig.add_trace(
        go.Scatter(
            x=epoch_history, y=val_box_loss_history, mode="lines", name="Val Box Loss"
        ),
        row=row_number,
        col=1,
    )
    fig.add_trace(
        go.Scatter(
            x=epoch_history, y=val_cls_loss_history, mode="lines", name="Val Class Loss"
        ),
        row=row_number,
        col=2,
    )
    fig.add_trace(
        go.Scatter(
            x=epoch_history, y=val_dfl_loss_history, mode="lines", name="Val DFL Loss"
        ),
        row=row_number,
        col=3,
    )
    fig.add_trace(
        go.Scatter(
            x=epoch_history,
            y=val_seg_loss_history,
            mode="lines",
            name="Val Seg Loss",
        ),
        row=row_number,
        col=4,
    )

    # Check if the plot container is already initialized
    if plot_container is None:
        plot_container = st.empty()

    # Update layout
    fig.update_layout(
        height=1200,
        width=1600,
        title_text="Metrics",
        legend=dict(orientation="h", yanchor="bottom", xanchor="left"),
    )

    # Display the updated plot in the same container
    plot_container.plotly_chart(fig, use_container_width=True)

# Function to define a custom callback function for on_fit_epoch_end
def on_fit_epoch_end(trainer):
    current_epoch = int(trainer.epoch)
    total_epochs = int(trainer.epochs)

    # Define the path to the output CSV
    output_csv_path = os.path.join(get_path("output"), "train", "results.csv")

    # Read the CSV data
    st.session_state["plot_data"] = read_csv_and_get_values(output_csv_path)

    # Call a function to update the plot using this data
    display_metrics_plot(st.session_state["plot_data"])

    # Update progress bar and text
    progress_bar.progress((current_epoch + 1) / total_epochs)
    progress_text.write(f"Epoch {(current_epoch + 1)}/{total_epochs}")


# Function to define a custom callback function for on_train_end
def on_train_end(trainer):
    global progress_bar, progress_text
    progress_bar.empty()
    progress_text.info(
        "Best and last model save completed successfully.",
        icon="✅",
    )


# Function to add various callbacks to the YOLO model for different stages of the training process
def callback_add(model):
    # Add a callback to be triggered at the start of the pre-training routine
    model.add_callback("on_pretrain_routine_start", on_pretrain_routine_start)

    # Add a callback to be triggered at the start of the training
    model.add_callback("on_train_start", on_train_start)

    # Add a callback to be triggered at the end of each training epoch
    model.add_callback("on_fit_epoch_end", on_fit_epoch_end)

    # Add a callback to be triggered at the end of the training process
    model.add_callback("on_train_end", on_train_end)


# Function to zip a folder and all its subfolders and return a BytesIO object
def zip_folder_to_bytesio(folder_path):
    bytes_io = io.BytesIO()
    with zipfile.ZipFile(bytes_io, "w", zipfile.ZIP_DEFLATED) as zipf:
        folder_path_abs = os.path.abspath(folder_path)

        for root, dirs, files in os.walk(folder_path):
            # Calculate the relative path from the folder_path
            folder_rel_path = os.path.relpath(root, folder_path_abs)

            # If the directory is empty, add the directory itself
            if not dirs and not files:
                # ZIP format requires a trailing slash for empty directories
                zip_dir_path = f"{folder_rel_path}/" if folder_rel_path != "." else ""
                zipf.write(root, zip_dir_path)

            for file in files:
                file_path = os.path.join(root, file)
                # Construct the path within the zip file
                zip_file_path = (
                    os.path.join(folder_rel_path, file)
                    if folder_rel_path != "."
                    else file
                )
                zipf.write(file_path, zip_file_path)

    bytes_io.seek(0)  # Go to the start of the BytesIO buffer
    return bytes_io

# Function to display the metrics table
@st.cache_resource(show_spinner=False)
def display_val_dataframe(val_dataframe):
    global val_dataframe_container

    # Check if the dataframe container is already initialized
    if val_dataframe_container is None:
        val_dataframe_container = st.container()

    # Display the updated dataframe in the same container
    with val_dataframe_container:
        # Display the message to indicate that the metrics table is ready
        st.markdown("**Metrics Table**", unsafe_allow_html=True)

        # Display the DataFrame
        st.dataframe(val_dataframe)


# Function to compute validation metrics and display them as a DataFrame
def val_dataframe(model):
    # Placeholder for the initial message
    message = st.empty()
    message.markdown("**Generating Metrics Table...**", unsafe_allow_html=True)

    # Extract the metrics from the model
    metrics = model.val()

    # Extract the class indices and names
    class_index = metrics.ap_class_index
    class_names = metrics.names

    # Extract precision, recall, and mAP values for the box (B) metrics
    precision_B_values = metrics.box.p
    recall_B_values = metrics.box.r
    mAP50_95_B_values = [metrics.box.maps[i] for i in class_index]

    # Check if segmentation (mask) metrics exist
    try:
        metrics_mask = metrics.seg
    except AttributeError:
        metrics_mask = False
    if metrics_mask:
        precision_M_values = metrics_mask.p
        recall_M_values = metrics_mask.r
        mAP50_95_M_values = [metrics_mask.maps[i] for i in class_index]

    # Extract aggregated metrics from the results dictionary
    results_dict = metrics.results_dict

    # Initialize lists for overall precision, recall, and mAP for box (B)
    precision_B = [results_dict.get("metrics/precision(B)")]
    recall_B = [results_dict.get("metrics/recall(B)")]
    mAP50_95_B = [results_dict.get("metrics/mAP50-95(B)")]

    # Initialize lists for overall precision, recall, and mAP for mask (M) if available
    precision_M = [results_dict.get("metrics/precision(M)")] if metrics_mask else None
    recall_M = [results_dict.get("metrics/recall(M)")] if metrics_mask else None
    mAP50_95_M = [results_dict.get("metrics/mAP50-95(M)")] if metrics_mask else None

    # Create a list of class names starting with "All" for the overall metrics
    name_list = ["All"] + [str(class_names[i]) for i in class_index]

    # Extend the metrics lists with values for each class
    precision_B.extend(precision_B_values)
    recall_B.extend(recall_B_values)
    mAP50_95_B.extend(mAP50_95_B_values)

    # If mask metrics are available, extend their lists with values for each class
    if metrics_mask:
        precision_M.extend(precision_M_values)
        recall_M.extend(recall_M_values)
        mAP50_95_M.extend(mAP50_95_M_values)

    # Create a DataFrame with the computed metrics
    if metrics_mask:
        st.session_state["val_dataframe"] = pd.DataFrame(
            {
                "Class Name": name_list,
                "Precision (B)": precision_B,
                "Recall (B)": recall_B,
                "mAP50-95 (B)": mAP50_95_B,
                "Precision (M)": precision_M,
                "Recall (M)": recall_M,
                "mAP50-95 (M)": mAP50_95_M,
            }
        )
    else:
        st.session_state["val_dataframe"] = pd.DataFrame(
            {
                "Class Name": name_list,
                "Precision (B)": precision_B,
                "Recall (B)": recall_B,
                "mAP50-95 (B)": mAP50_95_B,
            }
        )

    # Clear the initial message
    message.empty()

    # Display the metrics table now that it is ready
    display_val_dataframe(st.session_state["val_dataframe"])

# Function to train the YOLO model
def train_yolo_model(training_configuration):
    # Clear and recreate the output folder to ensure a fresh start
    delete_and_recreate_folder(get_path("output"))

    # Initialize the YOLO model with the specified path from the training configuration
    model = YOLO(training_configuration["model_path"])

    # Add any callbacks or additional configuration to the model
    callback_add(model)

    # Train the model with the specified parameters
    model.train(
        task=training_configuration["task"],
        data=os.path.join(get_path("config"), "config.yaml"),
        epochs=training_configuration["epochs"],
        time=training_configuration["time"],
        patience=training_configuration["patience"],
        batch=training_configuration["batch"],
        imgsz=training_configuration["imgsz"],
        save=True,
        save_period=-1,
        cache=training_configuration["cache"],
        pretrained=True,
        optimizer=training_configuration["optimizer"],
        verbose=True,
        seed=0,
        deterministic=training_configuration["deterministic"],
        single_cls=False,
        rect=training_configuration["rect"],
        cos_lr=training_configuration["cos_lr"],
        resume=False,
        amp=training_configuration["amp"],
        fraction=1.0,
        freeze=training_configuration["freeze"],
        lr0=training_configuration["lr0"],
        lrf=training_configuration["lrf"],
        momentum=training_configuration["momentum"],
        weight_decay=training_configuration["weight_decay"],
        warmup_epochs=training_configuration["warmup_epochs"],
        warmup_momentum=training_configuration["warmup_momentum"],
        warmup_bias_lr=training_configuration["warmup_bias_lr"],
        box=training_configuration["box"],
        cls=training_configuration["cls"],
        dfl=training_configuration["dfl"],
        label_smoothing=training_configuration["label_smoothing"],
        nbs=training_configuration["nbs"],
        overlap_mask=training_configuration["overlap_mask"],
        mask_ratio=training_configuration["mask_ratio"],
        dropout=training_configuration["dropout"],
        val=training_configuration["val"],
        plots=training_configuration["plots"],
        save_dir=os.path.join(get_path("output"), "train"),
        project=get_path("output"),
        name="train",
        augment=False,
        exist_ok=True,
    )

    return model


# Function to export the model with the given parameters
def export_model_with_parameters(model, export_params):
    global progress_text

    if export_params["format"] is not None:
        # Inform the user that the export process has started
        progress_text.info(
            "Starting the export process with the specified settings.",
            icon="✅",
        )

        # Perform the model export
        model.export(
            format=export_params["format"],
            keras=export_params["keras"],
            optimize=export_params["optimize"],
            half=export_params["half"],
            int8=export_params["int8"],
            dynamic=export_params["dynamic"],
            simplify=export_params["simplify"],
            opset=export_params["opset"],
            workspace=export_params["workspace"],
            nms=export_params["nms"],
        )

        # Inform the user that the export process has completed successfully
        progress_text.info(
            "The model has been successfully saved using the specified export settings.",
            icon="✅",
        )


# Function to start the YOLO model training process
def start_yolo_training(selected_training, class_labels):
    global plot_container, val_dataframe_container

    # Retrieve the training configuration based on the user's selection
    training_configuration = get_training_validation_export_configuration(
        selected_training
    )

    # Generate a YOLO model training code snippet and display it with a download button
    generate_and_display_yolo_training_code(class_labels, training_configuration)

    # Create two columns
    col1, col2 = st.columns(2)
    # When the "Start Training" button is clicked in the first column
    if col1.button("Start Training", use_container_width=True):
        plot_container = None
        val_dataframe_container = None

        with st.spinner("Training in Progress..."):
            # Train the YOLO model using the provided configuration
            trained_model = train_yolo_model(training_configuration)

            # Export the model with the given parameters
            export_model_with_parameters(
                trained_model, training_configuration["export_settings"]
            )

            # Display the validation results in a DataFrame after training
            val_dataframe(trained_model)

    elif "plot_data" in st.session_state and "val_dataframe" in st.session_state:
        plot_container = None
        val_dataframe_container = None

        # Display the metrics plot and table if they already exist
        display_metrics_plot(st.session_state["plot_data"])
        display_val_dataframe(st.session_state["val_dataframe"])

    # Prepare a ZIP file of the training output folder in memory for download
    zip_bytes_io = zip_folder_to_bytesio(os.path.join(get_path("output"), "train"))

    # Provide a button in the second column to download the ZIP file
    col2.download_button(
        label="Download",
        data=zip_bytes_io,
        file_name="model_training_output.zip",
        mime="application/zip",
        use_container_width=True,
    )
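The page that drives this module (pages/4_Model_Training.py, added in this commit but not reproduced here) presumably only needs to call `start_yolo_training` once the task type and class labels are known. A minimal sketch of such a call site; the widget labels and the class-label source are assumptions, not the actual page code:

```python
# Hypothetical call site; labels and the class list are placeholders
import streamlit as st
import Functions.model_training_functions as model_training_functions

selected_training = st.selectbox(
    "Training type", ["Object Detection", "Instance Segmentation"]
)
class_labels = ["person", "car"]  # in the app these come from the uploaded labels

# Renders the configuration expanders, the plug-and-play code, and the training UI
model_training_functions.start_yolo_training(selected_training, class_labels)
```
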
README.md
CHANGED
@@ -1,14 +1,12 @@
----
-title: CV Accelerator
-emoji:
-colorFrom:
-colorTo:
-sdk: streamlit
-sdk_version: 1.
-app_file:
-pinned: false
-
-
-
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+---
+title: CV Accelerator
+emoji: 📷
+colorFrom: gray
+colorTo: indigo
+sdk: streamlit
+sdk_version: 1.32.0
+app_file: Welcome.py
+pinned: false
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

Welcome.py
ADDED
@@ -0,0 +1,8 @@
# Set the page config
import streamlit as st

st.set_page_config(
    page_title="Welcome",
    page_icon=":open_file_folder:",
    layout="wide",
)

config.toml
ADDED
@@ -0,0 +1,9 @@
[theme]
primaryColor = "#101276"
backgroundColor = "#FFFFFF"
textColor = "#006EC0"
secondaryBackgroundColor = "#D2E5F2"
font = "sans serif"

[server]
maxUploadSize = 10000

model_data/input_files/config.yaml
ADDED
@@ -0,0 +1,6 @@
path: C:\Users\SamkeetSangai\OneDrive - Blend 360\Work\CoEs\Computer Vision\cv_accelerator\model_data\input_files\datasets # Path to the dataset directory
train: train # Path to the training set directory
val: val # Path to the validation set directory
test: test # Path to the testing set directory
nc: 80 # Number of classes
names: ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'] # List of class names

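When a dataset YAML like this is passed through the `data` argument, Ultralytics resolves the `train`, `val`, and `test` entries relative to `path`. A minimal sketch of that usage; note the `path` above is machine-specific, and the model file and epoch count here are placeholders:

```python
from ultralytics import YOLO

# Placeholder model and settings; the app supplies its own configuration
model = YOLO("model_data/models/yolov8n.pt")
model.train(data="model_data/input_files/config.yaml", epochs=1, imgsz=640)
```
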
model_data/model_training_code_pipline/config.yaml
ADDED
@@ -0,0 +1,6 @@
path: ./datasets # Path to the dataset directory
train: train # Path to the training set directory
val: val # Path to the validation set directory
test: test # Path to the testing set directory
nc: 80 # Number of classes
names: ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'] # List of class names

model_data/model_training_code_pipline/model_training.py
ADDED
@@ -0,0 +1,69 @@
# Importing necessary libraries
from ultralytics import YOLO

# Initialize the YOLO model 'yolov8n.pt'
model = YOLO('yolov8n.pt')

# Start the training process
model.train(
    task='detect',
    time=None,
    epochs=2,
    patience=None,
    batch=4,
    imgsz=640,
    cache=False,
    optimizer='auto',
    amp=True,
    deterministic=False,
    rect=False,
    cos_lr=False,
    freeze=None,
    lr0=0.01,
    lrf=0.01,
    momentum=0.937,
    weight_decay=0.0005,
    warmup_epochs=3.0,
    warmup_momentum=0.8,
    warmup_bias_lr=0.1,
    box=7.5,
    cls=0.5,
    dfl=1.5,
    label_smoothing=0.0,
    nbs=64,
    overlap_mask=True,
    mask_ratio=4,
    dropout=0.0,
    val=True,
    plots=True,
    conf=0.001,
    iou=0.6,
    max_det=300,
    half=False,
    data='.\\config.yaml',
    save_dir='.\\output\\train',
    pretrained=True,
    save=True,
    save_period=-1,
    augment=False,
    seed=0,
    verbose=True,
    single_cls=False,
    resume=False,
    exist_ok=True,
    project='.\\output',
    name='train'
)

# Model export process
model.export(
    keras=False,
    optimize=False,
    half=False,
    int8=False,
    dynamic=False,
    simplify=False,
    opset=None,
    workspace=4,
    nms=False
)

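After this script runs, Ultralytics saves checkpoints under the `{project}/{name}/weights/` layout, so with the settings above the best weights should land at `./output/train/weights/best.pt`. A minimal inference sketch; the image path is a placeholder:

```python
from ultralytics import YOLO

# Load the best checkpoint produced by the run configured above
model = YOLO("./output/train/weights/best.pt")

# Run a prediction on a sample image (placeholder path) and display the result
results = model.predict("sample.jpg", conf=0.25)
results[0].show()
```
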
model_data/models/yolov8l-seg.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dfa891dab47031c12ac82cb1b883af4ab9efbc732118cc04604bfbf107f9dfa8
size 92391859

model_data/models/yolov8l.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:18218ea4798da042d9862e6029ca9531adbd40ace19b6c9a75e2e28f1adf30cc
size 87769683

model_data/models/yolov8m-seg.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f9fc740ca0824e14b44681d491dc601efa664ec6ecea9a870acf876053826448
size 54899779

model_data/models/yolov8m.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6c25b0b63b1a433843f06d821a9ac1deb8d5805f74f0f38772c7308c5adc55a5
size 52117635

model_data/models/yolov8n-seg.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d39e867b2c3a5dbc1aa764411544b475cb14727bf6af1ec46c238f8bb1351ab9
size 7054355

model_data/models/yolov8n.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:31e20dde3def09e2cf938c7be6fe23d9150bbbe503982af13345706515f2ef95
size 6534387

model_data/models/yolov8s-seg.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c035c6c5f9c48ee962518ef648854c2cc5e60fa404fac443e17d306fdda16543
size 23897299

model_data/models/yolov8s.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:268e5bb54c640c96c3510224833bc2eeacab4135c6deb41502156e39986b562d
size 22573363

model_data/models/yolov8x-seg.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d63cbfa5764867c0066bedfa43cf2dcd90a412a1de44b2e238c43978a9d28ea6
size 144076467

model_data/models/yolov8x.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c4d5a3f000d771762f03fc8b57ebd0aae324aeaefdd6e68492a9c4470f2d1e8b
size 136867539

packages.txt
ADDED
@@ -0,0 +1 @@
python3-opencv

packages_db.xlsx
ADDED
Binary file (54.9 kB).

pages/2_Image_Processing.py
ADDED
@@ -0,0 +1,512 @@
# Set the page config
import streamlit as st

st.set_page_config(
    page_title="Image Processing",
    page_icon=":open_file_folder:",
    layout="wide",
    initial_sidebar_state="collapsed",
)

# Importing necessary libraries
import cv2
import utils
import numpy as np
import Functions.image_processing_functions as image_processing_functions

# Load image processing technique parameters and details from an Excel file
image_processing_params_df = utils.load_data_from_excel(
    "packages_db.xlsx", "image_processing_parameters"
)
image_processing_details_df = utils.load_data_from_excel(
    "packages_db.xlsx", "image_processing_details"
)

# Display the page title
st.title("Image Processing")

# # Clear the Streamlit session state on the first load of the page
# utils.clear_session_state_on_first_load("image_processing_clear")

# List of session state keys to initialize if they are not already present
session_state_keys = [
    "file_uploader_key_processing",
    "select_processing_technique_key_processing",
]

# Iterate through each session state key
for key in session_state_keys:
    # Check if the key is not already in the session state
    if key not in st.session_state:
        # Initialize the key with a dictionary containing itself set to True
        st.session_state[key] = {key: True}

# Initialize session state variables if not present
if "validation_triggered" not in st.session_state:
    st.session_state["validation_triggered"] = False

if "uploaded_files_cache_processing" not in st.session_state:
    st.session_state["uploaded_files_cache_processing"] = False

if "zip_data_processing" not in st.session_state:
    st.session_state["zip_data_processing"] = ""

if "widget_states" not in st.session_state:
    st.session_state["widget_states"] = {}

# Interface for uploading images and labels
utils.display_file_uploader(
    "uploaded_files",
    "Choose images and labels...",
    st.session_state["file_uploader_key_processing"],
    st.session_state["uploaded_files_cache_processing"],
)

# Note to users
st.markdown(
    """
    <div style='text-align: justify;'>
    <b>Note to Users:</b>
    <ul>
    <li>The <i>first uploaded image</i> will be used for demonstration purposes and to validate parameters for image processing techniques.</li>
    <li>Uploading <i>labels is optional</i>. If no labels are uploaded, the output will consist solely of processed images.</li>
    <li>When moving to another page or if you wish to upload a new set of images and labels, don't forget to hit the <b>Reset</b> button. This helps in faster computation and frees up unused memory, ensuring smoother operation.</li>
    </ul>
    </div>
    """,
    unsafe_allow_html=True,
)

# List of session state variables to initialize
session_vars = [
    "is_valid",
    "image_files",
    "label_files",
    "first_image_file",
    "first_label_file",
]

# Initialize each variable as None if it doesn't exist in the session state
for var in session_vars:
    if var not in st.session_state:
        st.session_state[var] = None

# Create two columns
col1, col2 = st.columns(2)

# Button to trigger validation
if (
    col1.button("Validate Input", use_container_width=True)
    or st.session_state["widget_states"].get("validate_input_button", False)
) and not st.session_state["validation_triggered"]:
    st.session_state["validation_triggered"] = True
    st.session_state["uploaded_files_cache_processing"] = True

    (
        st.session_state["is_valid"],
        st.session_state["image_files"],
        st.session_state["label_files"],
        st.session_state["first_image_file"],
        st.session_state["first_label_file"],
    ) = image_processing_functions.check_valid_labels(
        st.session_state["uploaded_files"]
    )

elif st.session_state["validation_triggered"]:
    pass

else:
    st.session_state["is_valid"] = False
    st.warning(
        "Please upload images and labels and click **Validate Input**.", icon="⚠️"
    )

with col2:
    # Check if the 'Reset' button is pressed
    if st.button("Reset", use_container_width=True):
        # Toggle the keys for file uploader and processing technique to reset their states
        current_value = st.session_state["file_uploader_key_processing"][
            "file_uploader_key_processing"
        ]
        updated_value = not current_value  # Invert the current value

        # List of session state keys that need to be reset
        session_state_keys = [
            "file_uploader_key_processing",
            "select_processing_technique_key_processing",
        ]

        # Iterate through each session state key
        for session_state_key in session_state_keys:
            # Update each key in the session state with the toggled value
            st.session_state[session_state_key] = {session_state_key: updated_value}

        # Clear all other session state keys except for the widget state keys
        for key in list(st.session_state.keys()):
            if key not in session_state_keys:
                del st.session_state[key]

        # Clear global variables except for protected names and the Streamlit module
        global_vars = list(globals().keys())
        vars_to_delete = [
            var for var in global_vars if not var.startswith("_") and var != "st"
        ]
        for var in vars_to_delete:
            del globals()[var]

        # Clear the Streamlit caches
        st.cache_resource.clear()
        st.cache_data.clear()

        # Rerun the app to reflect the reset state
        st.rerun()

# Interface to select image processing techniques
available_image_processings = image_processing_details_df["Name"]


# Mapping each image processing technique to its corresponding image types
input_mapping_dict = utils.technique_image_input_mapping(
    available_image_processings, image_processing_details_df
)

# Present the option to select image processing techniques only if the uploaded files are validated successfully
if st.session_state["is_valid"]:
    selected_image_processings = st.multiselect(
        "Select image processing technique(s)",
        available_image_processings,
        key=st.session_state["select_processing_technique_key_processing"],
    )

    # Read the first uploaded image into a NumPy array
    st.session_state["first_image_file"].seek(0)  # Reset file pointer to start
    file_bytes_first_image = np.frombuffer(
        st.session_state["first_image_file"].read(), dtype=np.uint8
    )
    uploaded_first_image = cv2.cvtColor(
        cv2.imdecode(file_bytes_first_image, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB
    )

    # Resize the image
    uploaded_first_image = cv2.resize(uploaded_first_image, (256, 256))

else:
    # Reset selected techniques to empty if input validation fails
    selected_image_processings = []


#######################################################################################################
# Build custom image processing pipeline
#######################################################################################################


# Store parameters for each selected image processing technique
|
204 |
+
image_processings_params = {}
|
205 |
+
|
206 |
+
# Initialize a flag to track if any error exists
|
207 |
+
error_in_parameters = False
|
208 |
+
|
209 |
+
# Loop through each selected image processing techniques to set up parameters
|
210 |
+
for image_processing in selected_image_processings:
|
211 |
+
with st.expander(f"{image_processing}"):
|
212 |
+
# Retrieve image processing details from the database
|
213 |
+
image_processing_info = image_processing_details_df[
|
214 |
+
image_processing_details_df["Name"] == image_processing
|
215 |
+
]
|
216 |
+
|
217 |
+
# Set up columns for displaying details and image placeholders
|
218 |
+
details_col, image_col = st.columns([7, 3])
|
219 |
+
|
220 |
+
with details_col:
|
221 |
+
# Display the description for the image processing technique
|
222 |
+
image_processing_description = (
|
223 |
+
image_processing_info["Description"].iloc[0]
|
224 |
+
if not image_processing_info.empty
|
225 |
+
else "No description available."
|
226 |
+
)
|
227 |
+
st.markdown(
|
228 |
+
f"<div style='text-align: justify;'><b>Description:</b> {image_processing_description}</div>",
|
229 |
+
unsafe_allow_html=True,
|
230 |
+
)
|
231 |
+
|
232 |
+
# Display the category for the image processing
|
233 |
+
image_processing_category = (
|
234 |
+
image_processing_info["Category"].iloc[0]
|
235 |
+
if not image_processing_info.empty
|
236 |
+
else "Unknown"
|
237 |
+
)
|
238 |
+
st.write("Category:", image_processing_category)
|
239 |
+
|
240 |
+
# Retrieve the source code link for the image processing
|
241 |
+
image_processing_source_code = (
|
242 |
+
image_processing_info["Source Code Link"].iloc[0]
|
243 |
+
if not image_processing_info.empty
|
244 |
+
else "www.google.com"
|
245 |
+
)
|
246 |
+
# Set up columns for displaying source code button and custom settings checkbox
|
247 |
+
source_code_col, custo_setting_col = st.columns(2)
|
248 |
+
|
249 |
+
source_code_col.link_button("Source Code", image_processing_source_code)
|
250 |
+
|
251 |
+
# Toggle for custom settings
|
252 |
+
custom_settings = custo_setting_col.checkbox(
|
253 |
+
f"Customize {image_processing}", key=f"toggle_{image_processing}"
|
254 |
+
)
|
255 |
+
|
256 |
+
with image_col:
|
257 |
+
# Create two columns
|
258 |
+
col1, col2 = st.columns(2)
|
259 |
+
original_image_placeholder = col1.container(height=200, border=False)
|
260 |
+
processed_image_placeholder = col2.container(height=200, border=False)
|
261 |
+
|
262 |
+
# Apply custom settings
|
263 |
+
if custom_settings:
|
264 |
+
# Retrieve parameters for the image processing
|
265 |
+
params_df = image_processing_params_df[
|
266 |
+
image_processing_params_df["Name"] == image_processing
|
267 |
+
]
|
268 |
+
|
269 |
+
# Process parameters for each image processing technique and store in a dictionary
|
270 |
+
image_processings_params[image_processing] = utils.process_image_parameters(
|
271 |
+
params_df, image_processing
|
272 |
+
)
|
273 |
+
|
274 |
+
else:
|
275 |
+
# Use default settings if customization is not selected
|
276 |
+
image_processings_params[image_processing] = utils.get_default_params(
|
277 |
+
image_processing
|
278 |
+
)
|
279 |
+
|
280 |
+
# Check for errors in the selected parameters by applying them to a sample image
|
281 |
+
(
|
282 |
+
error_flag,
|
283 |
+
processed_first_image,
|
284 |
+
) = image_processing_functions.apply_and_test_image_processing(
|
285 |
+
image_processing,
|
286 |
+
image_processings_params[image_processing],
|
287 |
+
uploaded_first_image,
|
288 |
+
input_mapping_dict[image_processing],
|
289 |
+
)
|
290 |
+
|
291 |
+
# If there is an error in the parameters, set the global error flag
|
292 |
+
if error_flag:
|
293 |
+
error_in_parameters = True
|
294 |
+
else:
|
295 |
+
# If no error, display the original and processed images side by side
|
296 |
+
# Display the original and processed images in their respective placeholders
|
297 |
+
with original_image_placeholder:
|
298 |
+
st.image(
|
299 |
+
uploaded_first_image,
|
300 |
+
caption="Original Image",
|
301 |
+
use_column_width=True,
|
302 |
+
clamp=True,
|
303 |
+
)
|
304 |
+
with processed_image_placeholder:
|
305 |
+
st.image(
|
306 |
+
processed_first_image,
|
307 |
+
caption="Processed Image",
|
308 |
+
use_column_width=True,
|
309 |
+
clamp=True,
|
310 |
+
)
|
311 |
+
|
312 |
+
# Update the base image with the previously processed image output
|
313 |
+
uploaded_first_image = processed_first_image
|
314 |
+
|
315 |
+
|
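# Note that the expander loop above previews the pipeline cumulatively: each
# technique is tested on the output of the previous one, so the "Processed
# Image" in the last expander shows the combined effect of every selected
# technique in selection order.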
#######################################################################################################
# Display selected image processing technique parameters as DataFrame
#######################################################################################################


# Check if any image processings have been defined
if (image_processings_params.keys()) and (not error_in_parameters):
    # Create a dropdown for selecting an image processing technique or 'All'
    selected_image_processing = st.selectbox(
        "Select image processing technique",
        options=["All"] + list(image_processings_params.keys()),
    )
else:
    selected_image_processing = None

# Create the DataFrame from the accumulated data
image_processings_df = image_processing_functions.create_image_processings_dataframe(
    image_processings_params, image_processing_params_df
)
image_processings_df["Value"] = image_processings_df["Value"].astype(
    str
)  # Ensure consistent data types and handle potential serialization issues

# Filter the DataFrame based on the selected image processing
if selected_image_processing != "All":
    filtered_image_processings_df = image_processings_df[
        image_processings_df["image_processing"] == selected_image_processing
    ]
else:
    filtered_image_processings_df = image_processings_df

# Check if the filtered dataframe is not empty and the selected configurations are valid
if (not filtered_image_processings_df.empty) and (not error_in_parameters):
    # Display the DataFrame in Streamlit
    st.dataframe(filtered_image_processings_df, use_container_width=False)

# Display code and description
code_placeholder = st.empty()


#######################################################################################################
# Process images and download processed images
#######################################################################################################


# Proceed if inputs are valid, techniques selected, and no errors in configurations
if (
    st.session_state["is_valid"]
    and (len(selected_image_processings) > 0)
    and not error_in_parameters
):
    # Create two columns
    col1, col2 = st.columns(2)

    # Allow the user to specify the number of variations to be generated
    num_variations = col1.number_input(
        "Set the number of variations to be generated",
        min_value=1,
        max_value=3,
        step=1,
    )

    # Checkbox to include original images and labels in the output
    with col2:
        for top_padding in range(2):  # Top padding
            st.write("")

        include_original = st.checkbox(
            "Include original images and labels in output", value=False
        )

    # Display code and download once all inputs are available
    with code_placeholder:
        # Generate the code with the function
        if len(st.session_state["label_files"]) == 0:
            generated_code = utils.generate_python_code_images(
                image_processings_params,
                num_variations,
                include_original,
            )
        else:
            generated_code = (
                image_processing_functions.generate_python_code_images_labels(
                    image_processings_params,
                    num_variations,
                    include_original,
                )
            )

        # Display the generated Python code with a description and provide a download button in the Streamlit app
        image_processing_functions.display_code_and_download_button(generated_code)

    # Create two columns
    col1, col2 = st.columns(2)

    # Add a button for the user to confirm their selections and proceed with processing
    if col1.button("Accept and Process", use_container_width=True):
        # Call the function and store the results
        image_processing_functions.process_images_and_labels(
            st.session_state["image_files"],
            st.session_state["label_files"],
            selected_image_processings,
            image_processings_params,
            num_variations,
            include_original,
        )

    # Download button
    col2.download_button(
        label="Download",
        data=st.session_state["zip_data_processing"],
        file_name="processed_images.zip",
        mime="application/zip",
        use_container_width=True,
        disabled=False,
    )


else:
    if (len(selected_image_processings) == 0) and st.session_state["is_valid"]:
        # Inform the user that no image processing techniques have been selected
        st.warning("Please select at least one image processing technique.", icon="⚠️")

    if error_in_parameters and st.session_state["is_valid"]:
        # Inform the user that there are errors in the parameters
        st.warning(
            "There are errors in the image processing parameters. Please review your selections.",
            icon="⚠️",
        )
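# The "Download" button above serves whatever bytes process_images_and_labels
# stores in st.session_state["zip_data_processing"]. Since that slot is
# initialized to an empty string, the download only becomes meaningful after
# "Accept and Process" has populated it with the ZIP archive.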
#######################################################################################################
# Display original and processed images
#######################################################################################################


if (
    "image_repository_preprocessing" in st.session_state
    and "processed_image_mapping_procesing" in st.session_state
):
    # Number of unique images
    num_unique_images = len(st.session_state["unique_images_names"])

    if num_unique_images > 1:
        # Create a slider to select an image index from the processed image mapping
        selected_image_index = st.slider(
            "Select an Image",
            min_value=1,
            max_value=num_unique_images,  # Set the maximum to the number of unique images
            step=1,
        )
    else:
        selected_image_index = 1

    # Retrieve the name of the selected original image using the slider index
    selected_original_image_name = st.session_state["unique_images_names"][
        selected_image_index - 1
    ]

    # Retrieve the names of all processed variants for the selected original image
    processed_variant_names = st.session_state["processed_image_mapping_procesing"].get(
        selected_original_image_name, []
    )

    # Combine the original image name with its processed variants
    all_image_names = [selected_original_image_name] + processed_variant_names

    # Number of images and columns
    num_images = len(all_image_names)
    num_columns = 4

    # Display images in a grid of 4 columns and a dynamic number of rows
    for i in range(0, num_images, num_columns):
        # Create a row of columns
        cols = st.columns(num_columns)
        for j in range(num_columns):
            # Calculate the current image index
            image_index = i + j
            if image_index < num_images:
                # Get the image name and data from the repository
                image_name = all_image_names[image_index]
                image_data = st.session_state["image_repository_preprocessing"][
                    image_name
                ]["image"]

                # Display the image in the respective column with a caption
                with cols[j]:
                    st.image(
                        image_data,
                        clamp=True,
                        caption=image_name,
                        use_column_width=True,
                    )


# if st.button("Run"):
#     utils.button_click(on_click=None)
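A detail worth calling out in the page above: Streamlit has no built-in way to clear a st.file_uploader, so the app resets it by changing the widget's key, which makes Streamlit treat it as a brand-new widget on the next rerun. A minimal, self-contained sketch of that pattern (the names here are illustrative, not taken from the app's utils.py helper):

import streamlit as st

# A dict holds the key material; flipping the boolean inside it yields a
# different key string, so Streamlit mounts a fresh uploader (and drops the
# previously uploaded files) on the next rerun.
if "uploader_key" not in st.session_state:
    st.session_state["uploader_key"] = {"uploader_key": True}

st.file_uploader(
    "Choose files...",
    accept_multiple_files=True,
    key=str(st.session_state["uploader_key"]),  # widget keys must be strings
)

if st.button("Reset"):
    # Toggle the boolean so the next rerun creates a new uploader widget
    current = st.session_state["uploader_key"]["uploader_key"]
    st.session_state["uploader_key"] = {"uploader_key": not current}
    st.rerun()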
pages/3_Image_Augmentation.py
ADDED
@@ -0,0 +1,612 @@
# Set the page config
import streamlit as st

st.set_page_config(
    page_title="Image Augmentation",
    page_icon=":open_file_folder:",
    layout="wide",
    initial_sidebar_state="collapsed",
)

# Importing necessary libraries
import cv2
import utils
import numpy as np
import Functions.image_augmentation_functions as augmentation_functions


# Load augmentation technique parameters and details from an Excel file
augmentation_params_df = utils.load_data_from_excel(
    "packages_db.xlsx", "augmentation_parameters"
)
augmentation_details_df = utils.load_data_from_excel(
    "packages_db.xlsx", "augmentation_details"
)

# Display the page title
st.title("Image Augmentation")

# # Clear the Streamlit session state on the first load of the page
# utils.clear_session_state_on_first_load("image_augmentation_clear")

# List of session state keys to initialize if they are not already present
session_state_keys = [
    "file_uploader_key_augmentation",
    "select_processing_technique_key_augmentation",
    "selected_option_key_augmentation",
    "class_labels_input_key_augmentation",
    "bbox1_key",
    "bbox2_key",
    "bbox3_key",
    "bbox4_key",
    "bbox5_key",
]

# Iterate through each session state key
for key in session_state_keys:
    # Check if the key is not already in the session state
    if key not in st.session_state:
        # Initialize the key with a dictionary containing itself set to True
        st.session_state[key] = {key: True}

# Initialize session state variables if not present
if "validation_triggered" not in st.session_state:
    st.session_state["validation_triggered"] = False

if "uploaded_files_cache_augmentation" not in st.session_state:
    st.session_state["uploaded_files_cache_augmentation"] = False

if "zip_data_augmentation" not in st.session_state:
    st.session_state["zip_data_augmentation"] = ""

# Interface for uploading images and labels
utils.display_file_uploader(
    "uploaded_files",
    "Choose images and labels...",
    st.session_state["file_uploader_key_augmentation"],
    st.session_state["uploaded_files_cache_augmentation"],
)

# Dropdown for selecting the label type
label_type = st.selectbox(
    "Choose the label type for your augmentation process:",
    ["Masks", "Bboxes"],
    index=1,
    on_change=utils.reset_validation_trigger,
    key=st.session_state["selected_option_key_augmentation"],
)

# Choose parameters based on the label type selected by the user
if label_type == "Bboxes":
    # If the selected label type is Bboxes, call the bbox_params function
    label_input_parameters = augmentation_functions.bbox_params()
elif label_type == "Masks":
    # If the selected label type is Masks, no extra parameters are needed
    label_input_parameters = None
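# (The bbox1_key ... bbox5_key entries initialized above presumably back the
# widgets rendered by bbox_params; the exact bounding-box options they expose
# live in Functions/image_augmentation_functions.py.)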
# Text area for user to input class labels
class_labels_input = st.text_area(
    "Enter class labels, separated by commas:",
    utils.sample_class_labels,
    on_change=utils.reset_validation_trigger,
    key=st.session_state["class_labels_input_key_augmentation"],
)  # Example default values
class_labels_input = (
    class_labels_input.strip()
)  # Remove unnecessary whitespace from the start and end


# Generate a dictionary mapping class IDs to their respective labels
try:
    class_labels = [
        label.strip() for label in class_labels_input.split(",") if label.strip()
    ]
    class_dict = {
        i + 1: label for i, label in enumerate(class_labels)
    }  # Shift class labels (keys) by 1, since 0 is reserved for the background
    # Invert the class_dict to map class names to class IDs
    class_names_to_ids = {v: k for k, v in class_dict.items()}

    colors = augmentation_functions.generate_unique_colors(class_dict.keys())
except Exception as e:
    st.warning(
        "Invalid format for class labels. Please enter labels separated by commas.",
        icon="⚠️",
    )
    class_dict, class_names_to_ids = (
        {},
        {},
    )  # Keep class_dict and class_names_to_ids empty


# Note to users
st.markdown(
    """
    <div style='text-align: justify;'>
    <b>Note to Users:</b>
    <ul>
    <li>The <i>first uploaded image</i> will be used for demonstration purposes and to validate parameters for augmentation techniques.</li>
    <li>Uploading <i>labels is optional</i>. If no labels are uploaded, the output will consist solely of processed images.</li>
    <li>When moving to another page or if you wish to upload a new set of images and labels, don't forget to hit the <b>Reset</b> button. This helps in faster computation and frees up unused memory, ensuring smoother operation.</li>
    <li>Select the class labels, label type and label parameters before uploading large data for faster computation and more efficient processing.</li>
    </ul>
    </div>
    """,
    unsafe_allow_html=True,
)

# List of session state variables to initialize
session_vars = [
    "is_valid",
    "image_files",
    "label_files",
    "first_image_file",
    "first_label_file",
]

# Initialize each variable as None if it doesn't exist in the session state
for var in session_vars:
    if var not in st.session_state:
        st.session_state[var] = None

# Create two columns
col1, col2 = st.columns(2)

# Button to trigger validation
if (
    col1.button("Validate Input", use_container_width=True)
    and not st.session_state["validation_triggered"]
):
    st.session_state["validation_triggered"] = True
    st.session_state["uploaded_files_cache_augmentation"] = True

    (
        st.session_state["is_valid"],
        st.session_state["image_files"],
        st.session_state["label_files"],
        st.session_state["first_image_file"],
        st.session_state["first_label_file"],
    ) = augmentation_functions.check_valid_labels(
        st.session_state["uploaded_files"], label_type, class_dict
    )

elif st.session_state["validation_triggered"]:
    pass

else:
    st.session_state["is_valid"] = False
    st.warning(
        "Please upload images and labels and click **Validate Input**.", icon="⚠️"
    )

with col2:
    # Check if the 'Reset' button is pressed
    if st.button("Reset", use_container_width=True):
        # Toggle the keys for the file uploader and processing technique to reset their states
        current_value = st.session_state["file_uploader_key_augmentation"][
            "file_uploader_key_augmentation"
        ]
        updated_value = not current_value  # Invert the current value

        # Iterate through each session state key
        for session_state_key in session_state_keys:
            # Update each key in the session state with the toggled value
            st.session_state[session_state_key] = {session_state_key: updated_value}

        # Clear all other session state keys except for those in session_state_keys
        for key in list(st.session_state.keys()):
            if key not in session_state_keys:
                del st.session_state[key]

        # Clear global variables except for protected names and the Streamlit module
        global_vars = list(globals().keys())
        vars_to_delete = [
            var for var in global_vars if not var.startswith("_") and var != "st"
        ]
        for var in vars_to_delete:
            del globals()[var]

        # Clear the Streamlit caches
        st.cache_resource.clear()
        st.cache_data.clear()

        # Rerun the app to reflect the reset state
        st.rerun()

# Fetch the names of techniques applicable to the selected option
available_augmentations = augmentation_functions.get_applicable_techniques(
    augmentation_details_df, label_type
)

# Map each augmentation technique to its corresponding image types
input_mapping_dict = utils.technique_image_input_mapping(
    available_augmentations, augmentation_details_df
)

# Present the option to select augmentation techniques only if the uploaded files are validated successfully
if st.session_state["is_valid"]:
    selected_augmentations = st.multiselect(
        "Select augmentation technique(s)",
        available_augmentations,
        key=st.session_state["select_processing_technique_key_augmentation"],
    )

    # Read the first uploaded image into a NumPy array
    st.session_state["first_image_file"].seek(0)  # Reset file pointer to start
    file_bytes_first_image = np.frombuffer(
        st.session_state["first_image_file"].read(), dtype=np.uint8
    )
    uploaded_first_image = cv2.cvtColor(
        cv2.imdecode(file_bytes_first_image, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB
    )

    # # Resize the image
    # uploaded_first_image = cv2.resize(uploaded_first_image, (256, 256))

else:
    # Reset selected techniques to empty if input validation fails
    selected_augmentations = []

#######################################################################################################
# Build custom augmentation pipeline
#######################################################################################################


# Store parameters for each selected augmentation technique
augmentations_params = {}

# Initialize a flag to track if any error exists
error_in_parameters = False

# Loop through each selected augmentation technique to set up parameters
for augmentation in selected_augmentations:
    with st.expander(f"{augmentation}"):
        # Retrieve augmentation details from the database
        augmentation_info = augmentation_details_df[
            augmentation_details_df["Name"] == augmentation
        ]

        # Set up columns for displaying details and image placeholders
        details_col, image_col = st.columns([7, 3])

        with details_col:
            # Display the description for the augmentation technique
            augmentation_description = (
                augmentation_info["Description"].iloc[0]
                if not augmentation_info.empty
                else "No description available."
            )
            st.markdown(
                f"<div style='text-align: justify;'><b>Description:</b> {augmentation_description}</div>",
                unsafe_allow_html=True,
            )

            # Display the category for the augmentation
            augmentation_category = (
                augmentation_info["Category"].iloc[0]
                if not augmentation_info.empty
                else "Unknown"
            )
            st.write("Category:", augmentation_category)

            # Retrieve the source code link for the augmentation
            augmentation_source_code = (
                augmentation_info["Source Code Link"].iloc[0]
                if not augmentation_info.empty
                else "www.google.com"
            )

            # Set up columns for displaying the source code button and custom settings checkbox
            source_code_col, custom_setting_col = st.columns(2)

            source_code_col.link_button("Source Code", augmentation_source_code)

            # Toggle for custom settings
            custom_settings = custom_setting_col.checkbox(
                f"Customize {augmentation}", key=f"toggle_{augmentation}"
            )

        with image_col:
            # Create two columns
            col1, col2 = st.columns(2)
            original_image_placeholder = col1.container(height=150, border=False)
            processed_image_placeholder = col2.container(height=150, border=False)

        # Apply custom settings
        if custom_settings:
            # Retrieve parameters for the augmentation
            params_df = augmentation_params_df[
                augmentation_params_df["Name"] == augmentation
            ]

            # Process parameters for each augmentation technique and store them in a dictionary
            augmentations_params[augmentation] = utils.process_image_parameters(
                params_df, augmentation
            )

        else:
            # Use default settings if customization is not selected
            augmentations_params[augmentation] = utils.get_default_params(augmentation)

        # Check for errors in the selected parameters by applying them to a sample image
        (
            error_flag,
            processed_first_image,
        ) = augmentation_functions.apply_and_test_augmentation(
            augmentation,
            augmentations_params[augmentation],
            uploaded_first_image,
            st.session_state["first_label_file"],
            label_type,
            label_input_parameters,
            input_mapping_dict[augmentation],
        )

        # If there is an error in the parameters, set the global error flag
        if error_flag:
            error_in_parameters = True
        else:
            # If no error, display the original and processed images in their respective placeholders
            with original_image_placeholder:
                st.image(
                    uploaded_first_image,
                    caption="Original Image",
                    use_column_width=True,
                    clamp=True,
                )

            with processed_image_placeholder:
                st.image(
                    processed_first_image,
                    caption="Processed Image",
                    use_column_width=True,
                    clamp=True,
                )

            # Update the base image with the previously processed image output
            uploaded_first_image = processed_first_image


#######################################################################################################
# Display selected augmentation technique parameters as DataFrame
#######################################################################################################


# Check if any augmentations have been defined
if (augmentations_params.keys()) and (not error_in_parameters):
    # Create a dropdown for selecting an augmentation technique or 'All'
    selected_augmentation = st.selectbox(
        "Select augmentation technique",
        options=["All"] + list(augmentations_params.keys()),
    )
else:
    selected_augmentation = None

# Create the DataFrame from the accumulated data
augmentations_df = augmentation_functions.create_augmentations_dataframe(
    augmentations_params, augmentation_params_df
)
augmentations_df["Value"] = augmentations_df["Value"].astype(
    str
)  # Ensure consistent data types and handle potential serialization issues

# Filter the DataFrame based on the selected augmentation
if selected_augmentation != "All":
    filtered_augmentations_df = augmentations_df[
        augmentations_df["augmentation"] == selected_augmentation
    ]
else:
    filtered_augmentations_df = augmentations_df

# Check if the filtered dataframe is not empty and the selected configurations are valid
if (not filtered_augmentations_df.empty) and (not error_in_parameters):
    # Display the DataFrame in Streamlit
    st.dataframe(filtered_augmentations_df, use_container_width=False)

# Display code and description
code_placeholder = st.empty()

#######################################################################################################
# Process images and download processed images
#######################################################################################################


# Proceed if inputs are valid, techniques selected, and no errors in configurations
if (
    st.session_state["is_valid"]
    and (len(selected_augmentations) > 0)
    and not error_in_parameters
):
    # Create two columns
    col1, col2 = st.columns(2)

    # Allow the user to specify the number of variations to be generated
    num_variations = col1.number_input(
        "Set the number of variations to be generated",
        min_value=1,
        max_value=3,
        step=1,
    )

    # Checkbox to include original images and labels in the output
    with col2:
        for top_padding in range(2):  # Top padding
            st.write("")

        include_original = st.checkbox(
            "Include original images and labels in output", value=False
        )

    # Display code and download once all inputs are available
    with code_placeholder:
        # Generate the code with the function
        if len(st.session_state["label_files"]) == 0:
            generated_code = utils.generate_python_code_images(
                augmentations_params,
                num_variations,
                include_original,
            )
        elif label_type == "Bboxes":  # Selected label type is Bboxes
            generated_code = augmentation_functions.generate_python_code_bboxes(
                augmentations_params,
                label_input_parameters,
                num_variations,
                include_original,
            )
        elif label_type == "Masks":  # Selected label type is Masks
            generated_code = augmentation_functions.generate_python_code_masks(
                augmentations_params,
                label_input_parameters,
                num_variations,
                include_original,
            )

        # Display the generated Python code with a description and provide a download button in the Streamlit app
        augmentation_functions.display_code_and_download_button(generated_code)

    # Create two columns
    col1, col2 = st.columns(2)

    # Add a button for the user to confirm their selections and proceed with processing
    if col1.button("Accept and Process", use_container_width=True):
        # Call the function and store the results
        augmentation_functions.process_images_and_labels(
            st.session_state["image_files"],
            st.session_state["label_files"],
            selected_augmentations,
            augmentations_params,
            label_type,
            label_input_parameters,
            num_variations,
            include_original,
            class_dict,
        )

    # Download button
    col2.download_button(
        label="Download",
        data=st.session_state["zip_data_augmentation"],
        file_name="augmented_images.zip",
        mime="application/zip",
        use_container_width=True,
        disabled=False,
    )

else:
    if (len(selected_augmentations) == 0) and st.session_state["is_valid"]:
        # Inform the user that no augmentation techniques have been selected
        st.warning("Please select at least one augmentation technique.", icon="⚠️")

    if error_in_parameters and st.session_state["is_valid"]:
        # Inform the user that there are errors in the parameters
        st.warning(
            "There are errors in the augmentation parameters. Please review your selections.",
            icon="⚠️",
        )


#######################################################################################################
# Display original and processed images
#######################################################################################################


# Check if image_repository and processed_image_mapping exist in session_state
if (
    "image_repository_augmentation" in st.session_state
    and "processed_image_mapping_augmentation" in st.session_state
):
    # Number of unique images
    num_unique_images = len(st.session_state["unique_images_names"])

    if num_unique_images > 1:
        # Create a slider to select an image index from the processed image mapping
        selected_image_index = st.slider(
            "Select an Image",
            min_value=1,
            max_value=num_unique_images,  # Set the maximum to the number of unique images
            step=1,
        )
    else:
        selected_image_index = 1

    # Retrieve the name of the selected original image using the slider index
    selected_original_image_name = st.session_state["unique_images_names"][
        selected_image_index - 1
    ]

    # Retrieve the names of all processed variants for the selected original image
    processed_variant_names = st.session_state[
        "processed_image_mapping_augmentation"
    ].get(selected_original_image_name, [])

    # Combine the original image name with its processed variants
    all_image_names = [selected_original_image_name] + processed_variant_names

    if len(st.session_state["label_files"]) > 0:
        # Options for displaying labels on the images
        label_display_options = ["No Label", "All Labels", "Specific Labels"]

        # Select box for the user to choose how labels should be displayed on the images
        selected_label_display_option = st.selectbox(
            "Choose how to display labels:",
            label_display_options,
            index=0,  # Default option is 'No Label'
        )

        # If the 'All Labels' option is selected, include all class IDs
        if selected_label_display_option == "All Labels":
            labels_to_plot = list(class_dict.keys())
        # If the 'Specific Labels' option is selected, allow the user to select specific class IDs
        elif selected_label_display_option == "Specific Labels":
            selected_class_names = st.multiselect(
                "Select specific labels to display",
                list(class_names_to_ids.keys()),
                class_dict[1],
            )
            labels_to_plot = [class_names_to_ids[name] for name in selected_class_names]
    else:
        selected_label_display_option = "No Label"

    # Display images in a grid
    num_images = len(all_image_names)
    num_columns = 4
    for i in range(0, num_images, num_columns):
        cols = st.columns(num_columns)
        for j in range(num_columns):
            image_index = i + j
            if image_index < num_images:
                image_name = all_image_names[image_index]
                image_data = st.session_state["image_repository_augmentation"][
                    image_name
                ]["image"]
                label_file = st.session_state["image_repository_augmentation"][
                    image_name
                ]["label"]

                # Overlay labels on the image based on the selected option
                if selected_label_display_option in ["All Labels", "Specific Labels"]:
                    # Overlay labels if selected
                    modified_image = augmentation_functions.overlay_labels(
                        image=image_data.copy(),
                        labels_to_plot=labels_to_plot,
                        label_file=label_file,
                        label_type=label_type,
                        colors=colors,
                        class_dict=class_dict,
                    )
                else:
                    # Use the original image without overlay if 'No Label' is selected
                    modified_image = image_data

                # Display the image in the respective column with a caption
                with cols[j]:
                    st.image(
                        modified_image,
                        clamp=True,
                        caption=image_name,
                        use_column_width=True,
                    )
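The label plumbing on this page (a Bboxes/Masks toggle, bbox parameters, per-class label fields) follows the shape of a label-aware augmentation pipeline as built with Albumentations. Purely as a point of reference — this is a hedged sketch, not the app's actual generate_python_code_bboxes output, and "sample.jpg" plus all parameters are illustrative — such a bbox-aware pipeline is typically assembled like this:

import albumentations as A
import cv2

# A sketch of a bbox-aware augmentation pipeline (illustrative parameters only)
transform = A.Compose(
    [
        A.HorizontalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.5),
    ],
    bbox_params=A.BboxParams(
        format="yolo",              # normalized cx, cy, w, h boxes
        min_visibility=0.3,         # drop boxes that are mostly cropped away
        label_fields=["class_ids"],
    ),
)

image = cv2.cvtColor(cv2.imread("sample.jpg"), cv2.COLOR_BGR2RGB)
augmented = transform(
    image=image,
    bboxes=[(0.5, 0.5, 0.2, 0.3)],  # one YOLO-format box
    class_ids=[1],
)
aug_image, aug_boxes = augmented["image"], augmented["bboxes"]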
pages/4_Model_Training.py
ADDED
@@ -0,0 +1,364 @@
1 |
+
# Set the page config
|
2 |
+
import streamlit as st
|
3 |
+
|
4 |
+
st.set_page_config(
|
5 |
+
page_title="Model_Training",
|
6 |
+
page_icon=":open_file_folder:",
|
7 |
+
layout="wide",
|
8 |
+
initial_sidebar_state="collapsed",
|
9 |
+
)
|
10 |
+
|
11 |
+
# Importing necessary libraries
|
12 |
+
import utils
|
13 |
+
import streamlit as st
|
14 |
+
import Functions.model_training_functions as model_training_functions
|
15 |
+
|
16 |
+
|
17 |
+
# Display the page title
|
18 |
+
st.title("Model Training")
|
19 |
+
|
20 |
+
# # Clear the Streamlit session state on the first load of the page
|
21 |
+
# utils.clear_session_state_on_first_load("model_training_clear")
|
22 |
+
|
23 |
+
# List of session state keys to initialize if they are not already present
|
24 |
+
session_state_keys = [
|
25 |
+
"file_uploader_split_key_training",
|
26 |
+
"file_uploader_train_key_training",
|
27 |
+
"file_uploader_val_key_training",
|
28 |
+
"file_uploader_test_key_training",
|
29 |
+
"number_input_train_key",
|
30 |
+
"number_input_val_key",
|
31 |
+
"number_input_test_key",
|
32 |
+
"split_method_key",
|
33 |
+
"training_type_key",
|
34 |
+
"class_labels_input_key_training",
|
35 |
+
]
|
36 |
+
|
37 |
+
# Iterate through each session state key
|
38 |
+
for key in session_state_keys:
|
39 |
+
# Check if the key is not already in the session state
|
40 |
+
if key not in st.session_state:
|
41 |
+
# Initialize the key with a dictionary containing itself set to True
|
42 |
+
st.session_state[key] = {key: True}
|
43 |
+
|
44 |
+
# Initialize session state variables if not present
|
45 |
+
if "validation_triggered" not in st.session_state:
|
46 |
+
st.session_state["validation_triggered"] = False
|
47 |
+
|
48 |
+
if "uploaded_files_cache_processing" not in st.session_state:
|
49 |
+
st.session_state["uploaded_files_cache_processing"] = False
|
50 |
+
|
51 |
+
# Initialize session state variables if not present
|
52 |
+
if "is_valid" not in st.session_state:
|
53 |
+
st.session_state["is_valid"] = False
|
54 |
+
|
55 |
+
# Container for file uploaders
|
56 |
+
file_uploader_container = st.container()
|
57 |
+
|
58 |
+
# Dictionary for mapping the user-friendly terms to technical label types
|
59 |
+
label_type_mapping = {"Object Detection": "Bboxes", "Instance Segmentation": "Masks"}
|
60 |
+
|
61 |
+
# Create two columns for widgets
|
62 |
+
column_select_training, column_split_method = st.columns(2)
|
63 |
+
|
64 |
+
# Dropdown for selecting the training type
|
65 |
+
with column_select_training:
|
66 |
+
selected_training = st.selectbox(
|
67 |
+
"Select the training type:",
|
68 |
+
list(label_type_mapping.keys()),
|
69 |
+
index=0,
|
70 |
+
on_change=utils.reset_validation_trigger,
|
71 |
+
key=st.session_state["training_type_key"],
|
72 |
+
)
|
73 |
+
|
74 |
+
# Getting the corresponding label type
|
75 |
+
label_type = label_type_mapping[selected_training]
|
76 |
+
|
77 |
+
# Toggle for choosing the split method
|
78 |
+
with column_split_method:
|
79 |
+
split_method = st.radio(
|
80 |
+
"Select the dataset split method:",
|
81 |
+
["Percentage Split", "Direct Upload"],
|
82 |
+
horizontal=True,
|
83 |
+
on_change=utils.reset_validation_trigger,
|
84 |
+
key=st.session_state["split_method_key"],
|
85 |
+
)
|
86 |
+
|
87 |
+
# Text area for user to input class labels
|
88 |
+
class_labels_input = st.text_area(
|
89 |
+
"Enter class labels, separated by commas:",
|
90 |
+
utils.sample_class_labels,
|
91 |
+
on_change=utils.reset_validation_trigger,
|
92 |
+
key=st.session_state["class_labels_input_key_training"],
|
93 |
+
) # Example default values
|
94 |
+
class_labels_input = (
|
95 |
+
class_labels_input.strip()
|
96 |
+
) # Remove unecessary space form start and end
|
97 |
+
|
98 |
+
# Generating a dictionary mapping class IDs to their respective labels
|
99 |
+
try:
|
100 |
+
class_labels = [
|
101 |
+
label.strip() for label in class_labels_input.split(",") if label.strip()
|
102 |
+
]
|
103 |
+
class_dict = {i: label for i, label in enumerate(class_labels)}
|
104 |
+
# Invert the class_dict to map class names to class IDs
|
105 |
+
class_names_to_ids = {v: k for k, v in class_dict.items()}
|
106 |
+
|
107 |
+
except Exception as e:
|
108 |
+
st.warning(
|
109 |
+
"Invalid format for class labels. Please enter labels separated by commas.",
|
110 |
+
icon="⚠️",
|
111 |
+
)
|
112 |
+
class_dict, class_names_to_ids = (
|
113 |
+
{},
|
114 |
+
{},
|
115 |
+
) # Keeping class_dict and class_names_to_ids as an empty
|
116 |
+
|
117 |
+
# Note to users
|
118 |
+
st.markdown(
|
119 |
+
"""
|
120 |
+
<div style='text-align: justify;'>
|
121 |
+
<b>Note to Users:</b>
|
122 |
+
<ul>
|
123 |
+
<li>When moving to another page or if you wish to upload a new set of images and labels, don't forget to hit the <b>Reset</b> button. This helps in faster computation and frees up unused memory, ensuring smoother operation.</li>
|
124 |
+
<li>Select the training type, class labels, dataset split method and its parameters before uploading large data for faster computation and more efficient processing.</li>
|
125 |
+
</ul>
|
126 |
+
</div>
|
127 |
+
""",
|
128 |
+
unsafe_allow_html=True,
|
129 |
+
)
|
130 |
+
|
131 |
+
# Create two columns for input percentages
|
132 |
+
validate_button_col, reset_button_col = st.columns(2)
|
133 |
+
|
134 |
+
with reset_button_col:
|
135 |
+
# Check if the 'Reset' button is pressed
|
136 |
+
if st.button("Reset", use_container_width=True):
|
137 |
+
# Clear folders
|
138 |
+
model_training_functions.delete_and_recreate_folder(
|
139 |
+
model_training_functions.get_path("output")
|
140 |
+
)
|
141 |
+
model_training_functions.clear_data_folders()
|
142 |
+
|
143 |
+
# List of session state keys that need to be reset
|
144 |
+
session_state_keys = [
|
145 |
+
"file_uploader_split_key_training",
|
146 |
+
"file_uploader_train_key_training",
|
147 |
+
"file_uploader_val_key_training",
|
148 |
+
"file_uploader_test_key_training",
|
149 |
+
"number_input_train_key",
|
150 |
+
"number_input_val_key",
|
151 |
+
"number_input_test_key",
|
152 |
+
"split_method_key",
|
153 |
+
"training_type_key",
|
154 |
+
"class_labels_input_key_training",
|
155 |
+
]
|
156 |
+
|
157 |
+
# Iterate through each session state key
|
158 |
+
for session_state_key in session_state_keys:
|
159 |
+
# Toggle the keys to reset their states
|
160 |
+
current_value = st.session_state[session_state_key][session_state_key]
|
161 |
+
updated_value = not current_value # Invert the current value
|
162 |
+
|
163 |
+
# Update each key in the session state with the toggled value
|
164 |
+
st.session_state[session_state_key] = {session_state_key: updated_value}
|
165 |
+
|
166 |
+
# Clear all other session state keys except for widget_state_keys
|
167 |
+
for key in list(st.session_state.keys()):
|
168 |
+
if key not in session_state_keys:
|
169 |
+
del st.session_state[key]
|
170 |
+
|
171 |
+
# Clear global variables except for protected and Streamlit module
|
172 |
+
global_vars = list(globals().keys())
|
173 |
+
vars_to_delete = [
|
174 |
+
var for var in global_vars if not var.startswith("_") and var != "st"
|
175 |
+
]
|
176 |
+
for var in vars_to_delete:
|
177 |
+
del globals()[var]
|
178 |
+
|
179 |
+
# Clear the Streamlit caches
|
180 |
+
st.cache_resource.clear()
|
181 |
+
st.cache_data.clear()
|
182 |
+
|
183 |
+
# Rerun the app to reflect the reset state
|
184 |
+
st.rerun()
|
185 |
+
|
186 |
+
# Code for "Percentage Split" method
|
187 |
+
if split_method == "Percentage Split":
|
188 |
+
with file_uploader_container:
|
189 |
+
# User uploads images and labels
|
190 |
+
utils.display_file_uploader(
|
191 |
+
"uploaded_files",
|
192 |
+
"Choose images and labels...",
|
193 |
+
st.session_state["file_uploader_split_key_training"],
|
194 |
+
st.session_state["uploaded_files_cache_processing"],
|
195 |
+
)
|
196 |
+
|
    # Create three columns for input percentages
    col1, col2, col3 = st.columns(3)

    # User specifies split percentages
    train_pct = col1.number_input(
        "Train Set Percentage",
        0,
        100,
        70,
        1,
        on_change=utils.reset_validation_trigger,
        key=st.session_state["number_input_train_key"],
    )
    test_pct = col2.number_input(
        "Test Set Percentage",
        0,
        100,
        15,
        1,
        on_change=utils.reset_validation_trigger,
        key=st.session_state["number_input_test_key"],
    )
    val_pct = col3.number_input(
        "Validation Set Percentage",
        0,
        100,
        15,
        1,
        on_change=utils.reset_validation_trigger,
        key=st.session_state["number_input_val_key"],
    )

    # Check if the total percentage equals 100%
    pct_check = train_pct + test_pct + val_pct

    # Validating the input percentages
    pct_condition_check = (
        pct_check == 100
        and train_pct > 0
        and val_pct > 0
        and model_training_functions.check_min_images(
            len(st.session_state["uploaded_files"]), train_pct, val_pct, test_pct
        )
    )

    if not pct_condition_check:
        file_uploader_container.warning(
            "The percentages for the train, test, and validation sets should add up to 100%, and the train and validation sets should not be empty.",
            icon="⚠️",
        )

    # Button to trigger validation
    if validate_button_col.button("Validate Input", use_container_width=True):
        st.session_state["validation_triggered"] = True

        st.session_state["is_valid"] = model_training_functions.check_valid_labels(
            st.session_state["uploaded_files"], label_type, class_dict
        )

        if st.session_state["is_valid"]:
            model_training_functions.create_yolo_config_file(
                model_training_functions.get_path("config"),
                class_labels,
            )
            model_training_functions.clear_data_folders()
            paired_files = model_training_functions.pair_files(
                st.session_state["uploaded_files"]
            )
            model_training_functions.split_and_save_files(
                paired_files, train_pct, test_pct
            )

    # Process files if input is valid
    if st.session_state["validation_triggered"] and (
        pct_condition_check and st.session_state["is_valid"]
    ):
        model_training_functions.start_yolo_training(selected_training, class_labels)
    else:
        # Display a warning message if the validation is not successful or conditions are not met
        st.warning(
            "Please upload valid input, select valid parameters, and click **Validate Input**.",
            icon="⚠️",
        )

# Code for the "Direct Upload" method
elif split_method == "Direct Upload":
    with file_uploader_container:
        # Create three columns for uploading train, val, and test files
        col1, col2, col3 = st.columns(3)

        with col1:
            utils.display_file_uploader(
                "uploaded_train_files",
                "Upload Training Images and Labels",
                st.session_state["file_uploader_train_key_training"],
                st.session_state["uploaded_files_cache_processing"],
            )

        with col2:
            utils.display_file_uploader(
                "uploaded_val_files",
                "Upload Validation Images and Labels",
                st.session_state["file_uploader_val_key_training"],
                st.session_state["uploaded_files_cache_processing"],
            )

        with col3:
            utils.display_file_uploader(
                "uploaded_test_files",
                "Upload Test Images and Labels",
                st.session_state["file_uploader_test_key_training"],
                st.session_state["uploaded_files_cache_processing"],
            )

    # Check for valid input
    pct_condition_check = (
        len(st.session_state["uploaded_train_files"]) > 0
        and len(st.session_state["uploaded_val_files"]) > 0
    )

    if not pct_condition_check:
        file_uploader_container.warning(
            "The train and validation sets should not be empty.",
            icon="⚠️",
        )

    # Button to trigger validation
    if validate_button_col.button("Validate Input", use_container_width=True):
        st.session_state["validation_triggered"] = True

        st.session_state["is_valid"] = model_training_functions.check_valid_labels(
            st.session_state["uploaded_train_files"]
            + st.session_state["uploaded_val_files"]
            + st.session_state["uploaded_test_files"],
            label_type,
            class_dict,
        )

        if st.session_state["is_valid"]:
            model_training_functions.create_yolo_config_file(
                model_training_functions.get_path("config"),
                class_labels,
            )
            model_training_functions.clear_data_folders()
            model_training_functions.save_files_to_folder(
                st.session_state["uploaded_train_files"], "train"
            )
            model_training_functions.save_files_to_folder(
                st.session_state["uploaded_val_files"], "val"
            )

            # Only save test files if they are uploaded
            if len(st.session_state["uploaded_test_files"]) > 0:
                model_training_functions.save_files_to_folder(
                    st.session_state["uploaded_test_files"], "test"
                )

    # Process files if input is valid
    if st.session_state["validation_triggered"] and (
        pct_condition_check and st.session_state["is_valid"]
    ):
        model_training_functions.start_yolo_training(selected_training, class_labels)
    else:
        # Display a warning message if the validation is not successful or conditions are not met
        st.warning(
            "Please upload valid input, select valid parameters, and click **Validate Input**.",
            icon="⚠️",
        )
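Note: check_min_images is defined in Functions/image_processing_functions.py's sibling module Functions/model_training_functions.py, whose body is not reproduced in this excerpt. A minimal sketch of the guarantee it needs to provide — assuming uploads arrive as image/label pairs (see pair_files above), so the image count is half the file count, and every split with a non-zero percentage must receive at least one image — could look like:

def check_min_images(num_files, train_pct, val_pct, test_pct):
    # Assumption: uploads are image/label pairs, so halve the file count
    num_images = num_files // 2
    for pct in (train_pct, val_pct, test_pct):
        # A split with a non-zero share must end up with at least one image
        if pct > 0 and int(num_images * pct / 100) < 1:
            return False
    return num_images > 0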
requirements.txt ADDED
Binary file (2.99 kB).
utils.py ADDED
@@ -0,0 +1,446 @@
# Importing necessary libraries
import cv2
import utils
import inspect
import numpy as np
import pandas as pd
import streamlit as st
import albumentations as A

# Please retain all the imported libraries, even if they appear to be unused, as they are necessary for the input part of the code (dynamic_input function)


# Sample class labels
sample_class_labels = "person, bicycle, car, motorcycle, airplane, bus, train, truck, boat, traffic light, fire hydrant, stop sign, parking meter, bench, bird, cat, dog, horse, sheep, cow, elephant, bear, zebra, giraffe, backpack, umbrella, handbag, tie, suitcase, frisbee, skis, snowboard, sports ball, kite, baseball bat, baseball glove, skateboard, surfboard, tennis racket, bottle, wine glass, cup, fork, knife, spoon, bowl, banana, apple, sandwich, orange, broccoli, carrot, hot dog, pizza, donut, cake, chair, couch, potted plant, bed, dining table, toilet, tv, laptop, mouse, remote, keyboard, cell phone, microwave, oven, toaster, sink, refrigerator, book, clock, vase, scissors, teddy bear, hair drier, toothbrush"


# Model info
models_info = {
    "YOLOv8n": {
        "size": 640,
        "mAPval": 37.3,
        "speed_cpu": 80.4,
        "speed_gpu": 0.99,
        "params": 3.2,
        "flops": 8.7,
    },
    "YOLOv8s": {
        "size": 640,
        "mAPval": 44.9,
        "speed_cpu": 128.4,
        "speed_gpu": 1.20,
        "params": 11.2,
        "flops": 28.6,
    },
    "YOLOv8m": {
        "size": 640,
        "mAPval": 50.2,
        "speed_cpu": 234.7,
        "speed_gpu": 1.83,
        "params": 25.9,
        "flops": 78.9,
    },
    "YOLOv8l": {
        "size": 640,
        "mAPval": 52.9,
        "speed_cpu": 375.2,
        "speed_gpu": 2.39,
        "params": 43.7,
        "flops": 165.2,
    },
    "YOLOv8x": {
        "size": 640,
        "mAPval": 53.9,
        "speed_cpu": 479.1,
        "speed_gpu": 3.53,
        "params": 68.2,
        "flops": 257.8,
    },
}

+
# Export formats info
|
63 |
+
export_formats = {
|
64 |
+
"TorchScript": {"format_argument": "torchscript", "arguments": ["optimize"]},
|
65 |
+
"ONNX": {
|
66 |
+
"format_argument": "onnx",
|
67 |
+
"arguments": ["half", "dynamic", "simplify", "opset"],
|
68 |
+
},
|
69 |
+
"OpenVINO": {"format_argument": "openvino", "arguments": ["half", "int8"]},
|
70 |
+
"TensorRT": {
|
71 |
+
"format_argument": "engine",
|
72 |
+
"arguments": ["half", "dynamic", "simplify", "workspace"],
|
73 |
+
},
|
74 |
+
"CoreML": {"format_argument": "coreml", "arguments": ["half", "int8", "nms"]},
|
75 |
+
"TF SavedModel": {"format_argument": "saved_model", "arguments": ["keras", "int8"]},
|
76 |
+
"TF GraphDef": {"format_argument": "pb", "arguments": []},
|
77 |
+
"TF Lite": {"format_argument": "tflite", "arguments": ["half", "int8"]},
|
78 |
+
"TF Edge TPU": {"format_argument": "edgetpu", "arguments": []},
|
79 |
+
"TF.js": {"format_argument": "tfjs", "arguments": ["half", "int8"]},
|
80 |
+
"PaddlePaddle": {"format_argument": "paddle", "arguments": []},
|
81 |
+
"ncnn": {"format_argument": "ncnn", "arguments": ["half"]},
|
82 |
+
}
|
83 |
+
|
84 |
+
|
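# Illustrative usage (an assumption, not shown in this commit): "format_argument"
# feeds Ultralytics' model.export(), while "arguments" lists the optional flags
# the UI should expose for that format, e.g.
#
#   from ultralytics import YOLO
#   model = YOLO("model_data/models/yolov8n.pt")
#   model.export(format=export_formats["ONNX"]["format_argument"], half=True)
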
# Function to add top padding
def top_padding(padding_amount=0):
    for i in range(padding_amount):
        st.write("")


# Function to load data from an Excel file
@st.cache_resource(show_spinner=False)
def load_data_from_excel(filename, sheet_name):
    """Load data from a specified sheet in an Excel file."""
    return pd.read_excel(filename, sheet_name=sheet_name)


# Function to reset the validation trigger in the session state
def reset_validation_trigger():
    st.session_state["validation_triggered"] = False


# Function to clear the Streamlit session state on the first load of the page
def clear_session_state_on_first_load(key):
    # Check if the function has already been executed
    if key not in st.session_state:
        # Clear the session state
        for k in list(st.session_state.keys()):
            if k != key:  # Preserve the key itself
                del st.session_state[k]
        # Mark this function as executed
        st.session_state[key] = True

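# Illustrative usage (the key name is an assumption): each page calls this once
# at the top with a page-specific key, so switching pages starts from a clean
# state:
#
#   utils.clear_session_state_on_first_load("model_training_first_load")
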
# Function to display a file uploader based on the cached state
def display_file_uploader(uploaded_files_type, label_value, key_value, is_cached):
    if not is_cached:
        st.session_state[uploaded_files_type] = st.file_uploader(
            label=label_value,
            type=["jpg", "png", "txt"],
            accept_multiple_files=True,
            on_change=utils.reset_validation_trigger,
            key=key_value,
        )

    else:
        st.file_uploader(
            label_value,
            type=["jpg", "png", "txt"],
            accept_multiple_files=True,
            disabled=True,
        )

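# Illustrative call (mirrors pages/4_Model_Training.py; is_cached hard-coded
# here for clarity): when is_cached is True, a disabled placeholder uploader is
# drawn instead, so cached files cannot be replaced mid-run:
#
#   utils.display_file_uploader(
#       "uploaded_train_files",
#       "Upload Training Images and Labels",
#       st.session_state["file_uploader_train_key_training"],
#       False,
#   )
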
# Function to map each technique to its corresponding image types
@st.cache_resource(show_spinner=False)
def technique_image_input_mapping(available_techniques, details_df):
    input_mapping_dict = {}
    for technique in available_techniques:
        # Filter the DataFrame for rows where the Name matches the technique
        filtered_df = details_df[details_df["Name"] == technique]

        # Extract the Image Types for these rows, split by comma, and strip spaces
        image_types = (
            filtered_df["Image Types"]
            .apply(lambda x: [image_type.strip() for image_type in str(x).split(",")])
            .tolist()
        )

        # Add to the dictionary
        input_mapping_dict[technique] = image_types[0]

    return input_mapping_dict


# Function to check if the image type matches the allowed types
def is_image_type_allowed(input_image_type, num_channels, allowed_image_types):
    for allowed_type in allowed_image_types:
        # Handle general types like 'Any' or 'nan'
        if allowed_type in ["Any", "nan", input_image_type]:
            return True

        # Handle specific multi-channel types (e.g., '3-channel uint8 images only')
        if "channel" in allowed_type:
            parts = allowed_type.split()

            try:
                allowed_channels = int(parts[0][0])
                allowed_dtype = parts[1]

                if (
                    num_channels == allowed_channels
                    and input_image_type == allowed_dtype
                ):
                    return True
            except ValueError:
                # Continue to the next allowed type
                continue

        # Add additional specific checks if needed

    return False  # If none of the conditions are met

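# Example (illustrative): a 3-channel uint8 image passes a technique restricted
# to "3-channel uint8 images only", while a 1-channel float32 image does not:
#
#   is_image_type_allowed("uint8", 3, ["3-channel uint8 images only"])    # True
#   is_image_type_allowed("float32", 1, ["3-channel uint8 images only"])  # False
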
# Function to create a dynamic input widget based on the specified data type
def dynamic_input(
    image_processing, data_type, label, input_key, input_data, default_val
):
    if data_type == "bool":
        # Create a checkbox for bool input
        return st.checkbox(label=label, value=eval(default_val), key=input_key)

    elif data_type == "int" or data_type == "float":
        # Create a slider for integer / float input
        return st.select_slider(
            label=label,
            options=eval(input_data),
            value=eval(default_val),
            key=input_key,
        )

    elif data_type == "tuple":
        # Create a slider for tuple input
        min_val, max_val = st.select_slider(
            label=label,
            options=eval(input_data),
            value=(eval(default_val)[0], eval(default_val)[1]),
            key=input_key,
        )
        return (min_val, max_val)

    elif data_type == "single_select":
        # Create a selectbox for single select input
        return st.selectbox(
            label=label,
            options=eval(input_data),
            index=eval(input_data).index(eval(default_val)),
        )

    elif data_type == "none":
        # Return None without creating any input widget
        return None

    elif data_type == "code":
        # Create a text input for code and evaluate it
        try:
            # Evaluate the code entered by the user in the text input
            input_code = eval(st.text_input(label=label, value=default_val))
        except Exception:
            # If the user input is not valid Python code, catch the exception
            st.warning(f"Invalid input for {label}. Using default value.", icon="⚠️")
            # Return the default parameter value for this image processing technique
            input_code = get_default_params(image_processing)[label]

        # Show an example of the expected input format to the user
        st.write(f"Sample Input for {label}:")
        st.code(input_data, language="python", line_numbers=False)

        return input_code

    # Handle unsupported data types
    else:
        # Display an error message and return None for unsupported data types
        st.error(f"Unsupported data type: {data_type}")

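# Example (illustrative; the parameter spec is an assumption): an integer
# parameter rendered as a slider. input_data and default_val arrive as strings
# from the parameter spreadsheet and are eval()'d, e.g. for Albumentations' Blur:
#
#   dynamic_input("Blur", "int", "blur_limit",
#                 "Blur_blur_limit_int", "list(range(3, 101))", "7")
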
# Returns a dictionary of default parameters for a given Albumentations augmentation class
def get_default_params(augmentation_name):
    # Retrieve the augmentation class from the Albumentations module
    augmentation_class = getattr(A, augmentation_name, None)

    # Inspect the constructor (__init__) of the augmentation class to get its parameters
    params = inspect.signature(augmentation_class.__init__).parameters

    # Create a dictionary of parameter names and their default values
    default_params = {
        name: param.default
        for name, param in params.items()
        if param.default is not inspect.Parameter.empty
    }

    return default_params

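# Example (illustrative; exact defaults depend on the installed Albumentations
# version):
#
#   get_default_params("HorizontalFlip")  # -> {'always_apply': False, 'p': 0.5}
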
# Function to resize the image to a maximum dimension of 512 pixels while maintaining aspect ratio
def resize_image(image, max_size=512):
    height, width = image.shape[:2]
    # Calculate the ratio to maintain aspect ratio
    ratio = min(max_size / height, max_size / width)
    new_dimensions = (int(width * ratio), int(height * ratio))
    resized_image = cv2.resize(image, new_dimensions, interpolation=cv2.INTER_AREA)
    return resized_image

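# Example (illustrative): a 768x1024 image scales by min(512/768, 512/1024) = 0.5,
# so the longer side lands exactly on max_size:
#
#   resize_image(np.zeros((768, 1024, 3), dtype=np.uint8)).shape  # -> (384, 512, 3)
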
# Function to apply the specified transformation with optional custom parameters
def apply_albumentation(custom_params, albumentation_name):
    # Get the Albumentations transformation class from the albumentation_name string
    albumentation_class = getattr(A, albumentation_name, None)

    # Check if the transformation class exists in the Albumentations library
    if albumentation_class is not None:
        # Apply the transformation with or without custom parameters
        if custom_params is None:
            return albumentation_class()
        else:
            return albumentation_class(**custom_params)

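# Example (illustrative): instantiating a transform with custom parameters and
# applying it to an image array:
#
#   transform = apply_albumentation({"p": 1.0}, "HorizontalFlip")
#   flipped = transform(image=image)["image"]
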
# Function to process image parameters from a DataFrame and create an interactive UI
def process_image_parameters(params_df, image_processing):
    # Dictionary to map UI input types to database values
    type_mapping = {
        "Boolean": "bool",
        "Integer": "int",
        "Float": "float",
        "Tuple": "tuple",
        "Single Select": "single_select",
        "Code": "code",
        "None": "none",
    }

    # Create a reverse mapping from database values to UI input types
    reverse_type_mapping = {v: k for k, v in type_mapping.items()}

    # Initialize a dictionary to store user-selected parameter values
    image_processing_params = {}

    if not params_df.empty:
        for _, row in params_df.iterrows():
            # Spacer
            st.markdown("---")

            param_name = row["Parameter Name"]
            input_types = [
                input_type.strip() for input_type in str(row["Input Type"]).split(",")
            ]

            if len(input_types) > 1:
                col1, col2 = st.columns([3, 7])
                selected_type_ui = col1.selectbox(
                    "Choose input type",
                    [reverse_type_mapping[ip] for ip in input_types],
                    key=f"{image_processing}_{param_name}_selectbox",
                )
                # Map the selected UI input type to its database value
                selected_type = type_mapping[selected_type_ui]
                selected_index = input_types.index(selected_type)

                with col2:
                    user_input = utils.dynamic_input(
                        image_processing=image_processing,
                        data_type=selected_type,
                        label=param_name,
                        input_data=str(row[f"Input Values {selected_index + 1}"]),
                        default_val=str(row[f"Default Value {selected_index + 1}"]),
                        input_key=f"{image_processing}_{param_name}_{selected_type}",
                    )
            else:
                user_input = utils.dynamic_input(
                    image_processing=image_processing,
                    data_type=input_types[0],
                    label=param_name,
                    input_data=str(row["Input Values 1"]),
                    default_val=str(row["Default Value 1"]),
                    input_key=f"{image_processing}_{param_name}_{input_types[0]}",
                )

            st.markdown(
                f"<div style='text-align: justify;'><b>Description:</b> {row['Parameter Description']}</div>",
                unsafe_allow_html=True,
            )

            image_processing_params[param_name] = user_input

    # Top padding
    top_padding(2)

    return image_processing_params


# Function to generate a DataFrame detailing image processing technique parameters and descriptions
@st.cache_resource(show_spinner=False)
def create_image_processings_dataframe(
    image_processings_params, image_processing_params_db
):
    data = []
    for aug_name, params in image_processings_params.items():
        for param_name, param_value in params.items():
            # Retrieve relevant image_processing information from the database
            image_processing_info = image_processing_params_db[
                image_processing_params_db["Name"] == aug_name
            ]
            param_info = image_processing_info[
                image_processing_info["Parameter Name"] == param_name
            ]

            # Check if the parameter information exists in the database
            if not param_info.empty:
                # Get the description of the current parameter
                param_description = param_info["Parameter Description"].iloc[0]
            else:
                param_description = "Description not available"

            # Append image_processing name, parameter name, its value, and description to the data list
            data.append([aug_name, param_name, param_value, param_description])

    # Create the DataFrame from the accumulated data
    image_processings_df = pd.DataFrame(
        data, columns=["image_processing", "Parameter", "Value", "Description"]
    )
    return image_processings_df


def generate_python_code_images(
    augmentations_params,
    num_variations=1,
    include_original=False,
):
    # Start with necessary library imports
    code_str = "# Importing necessary libraries\n"
    code_str += "import os\nimport cv2\nimport shutil\nimport albumentations as A\n\n"

    # Paths for input and output directories
    code_str += "# Define the paths for input and output directories\n"
    code_str += "input_directory = 'path/to/input'\n"
    code_str += "output_directory = 'path/to/output'\n\n"

    # Function to create an augmentation pipeline
    code_str += "# Function to create an augmentation pipeline using Albumentations\n"
    code_str += "def process_image(image):\n"
    code_str += "    # Define the sequence of augmentation techniques\n"
    code_str += "    pipeline = A.Compose([\n"
    for technique, params in augmentations_params.items():
        code_str += f"        A.{technique}({', '.join(f'{k}={v}' for k, v in params.items())}),\n"
    code_str += "    ])\n"
    code_str += "    # Apply the augmentation pipeline\n"
    code_str += "    return pipeline(image=image)['image']\n\n"

    # Function to process a batch of images
    code_str += "# Function to process a batch of images\n"
    code_str += "def process_batch(input_directory, output_directory, include_original=False, num_variations=1):\n"
    code_str += "    for filename in os.listdir(input_directory):\n"
    code_str += "        if filename.lower().endswith(('.png', '.jpg', '.jpeg')):\n"
    code_str += "            image_path = os.path.join(input_directory, filename)\n\n"
    code_str += "            # Read the image\n"
    code_str += "            image = cv2.imread(image_path)\n\n"

    # Include original image
    code_str += "            # Include original image\n"
    if include_original:
        code_str += "            shutil.copy2(image_path, output_directory)\n\n"

    code_str += "            # Generate variations for each image and process them\n"
    code_str += f"            for variation in range({num_variations}):\n"
    code_str += "                processed_image = process_image(image)\n\n"
    code_str += "                # Save the processed image\n"
    code_str += "                output_filename = f'processed_{os.path.splitext(filename)[0]}_{variation}{os.path.splitext(filename)[1]}'\n"
    code_str += "                cv2.imwrite(os.path.join(output_directory, output_filename), processed_image)\n\n"

    # Execute the batch processing function
    code_str += (
        "# Execute the batch processing function with the specified parameters\n"
    )
    code_str += "process_batch(input_directory, output_directory)\n"

    return code_str
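# Example (illustrative): generating a standalone augmentation script for a
# single HorizontalFlip step with two variations per image:
#
#   script = generate_python_code_images(
#       {"HorizontalFlip": {"p": 1.0}}, num_variations=2, include_original=True
#   )
#   print(script)  # a ready-to-run os/cv2/shutil/albumentations batch script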