samkeet committed
Commit 3d90a2e · verified · 1 Parent(s): 987585c

First Commit

.streamlit/config.toml ADDED
@@ -0,0 +1,9 @@
+ [theme]
+ primaryColor = "#101276"
+ backgroundColor = "#FFFFFF"
+ textColor = "#006EC0"
+ secondaryBackgroundColor = "#D2E5F2"
+ font = "sans serif"
+
+ [server]
+ maxUploadSize = 10000
Functions/__pycache__/image_augmentation_functions.cpython-311.pyc ADDED
Binary file (41.7 kB)
 
Functions/__pycache__/image_processing_functions.cpython-311.pyc ADDED
Binary file (18.3 kB)
 
Functions/__pycache__/model_training_functions.cpython-311.pyc ADDED
Binary file (73.9 kB)
 
Functions/image_augmentation_functions.py ADDED
@@ -0,0 +1,1027 @@
+ # Importing necessary libraries
+ import io
+ import os
+ import cv2
+ import utils
+ import random
+ import zipfile
+ import numpy as np
+ import pandas as pd
+ from PIL import Image
+ import streamlit as st
+ import albumentations as A
+ import matplotlib.pyplot as plt
+
+
+ # Function to fetch the names of augmentation techniques applicable to the selected option
+ @st.cache_resource(show_spinner=False)
+ def get_applicable_techniques(df, option):
+     return df[df[option] == "Applicable"]["Name"]
+
+
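For illustration, a minimal sketch of the lookup this performs (the table below is hypothetical; only the "Name" column and one applicability column per label type are assumed):

    import pandas as pd

    techniques_db = pd.DataFrame(
        {
            "Name": ["HorizontalFlip", "RandomCrop", "GaussNoise"],
            "Bboxes": ["Applicable", "Applicable", "Not Applicable"],
        }
    )
    print(get_applicable_techniques(techniques_db, "Bboxes").tolist())
    # ['HorizontalFlip', 'RandomCrop']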
+ # Function to generate unique colors for each given list of unique numbers
+ def generate_unique_colors(unique_numbers):
+     # Get a color map
+     cmap = plt.get_cmap("hsv")
+
+     # Generate unique colors using the colormap
+     colors = {
+         number: cmap(i / len(unique_numbers))[:3]
+         for i, number in enumerate(unique_numbers)
+     }
+
+     # Convert colors from RGB to BGR format and scale to 0-255
+     colors_bgr = {
+         k: (int(v[2] * 255), int(v[1] * 255), int(v[0] * 255))
+         for k, v in colors.items()
+     }
+
+     return colors_bgr
+
+
+ # Function to adjust zero values to a small positive number
+ def adjust_zero_value(value):
+     return value if value != 0 else 1e-9
+
+
+ # Function to parse a YOLO label file and convert it into Albumentations-compatible Bboxes format
+ def bboxes_label(label_file, class_dict):
+     bboxes_data = {"bboxes": [], "class_labels": []}
+     for line in label_file:
+         try:
+             # Extracting class_id and bounding box coordinates from each line
+             class_id, x_center, y_center, width, height = map(
+                 float, line.decode().strip().split()
+             )
+             class_id += 1  # Shifting its starting value to 1, since 0 is reserved for the background
+
+             # Adjusting bounding box coordinates to avoid zero values
+             # Albumentations does not accept zero values, but they are acceptable in YOLO format
+             x_center, y_center, width, height = map(
+                 adjust_zero_value, [x_center, y_center, width, height]
+             )
+
+             # Check if values are within the expected range and class_id exists in class_dict
+             if (
+                 0 <= x_center <= 1
+                 and 0 <= y_center <= 1
+                 and 0 <= width <= 1
+                 and 0 <= height <= 1
+                 and class_id in class_dict.keys()
+             ):
+                 bboxes_data["bboxes"].append([x_center, y_center, width, height])
+                 bboxes_data["class_labels"].append(class_id)
+             else:
+                 return None  # Return None if any value is out of range or class_id is invalid
+
+         except Exception:
+             # Return None if any exception is encountered
+             return None
+
+     # Return None if the file is empty or no valid data found
+     return bboxes_data if bboxes_data["bboxes"] else None
+
+
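A quick sketch of the expected input and output (class names are made up): the function consumes the byte lines returned by `readlines()` on an uploaded YOLO label file:

    label_lines = [b"0 0.50 0.50 0.20 0.30", b"1 0.25 0.75 0.10 0.10"]
    class_dict = {1: "person", 2: "car"}  # keys already shifted by +1

    print(bboxes_label(label_lines, class_dict))
    # {'bboxes': [[0.5, 0.5, 0.2, 0.3], [0.25, 0.75, 0.1, 0.1]], 'class_labels': [1.0, 2.0]}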
+ # Function to parse a YOLO label file and convert it into compatible Mask format
+ def masks_label(label_file, class_dict):
+     mask_data = {"masks": [], "class_labels": []}
+     for line in label_file:
+         try:
+             # Clean up the line and split into parts
+             parts = line.decode().strip().split()
+             class_id = (
+                 int(parts[0]) + 1
+             )  # Shifting its starting value to 1, since 0 is reserved for the background
+             points = [float(p) for p in parts[1:]]
+
+             # Check if class_id exists in class_dict and coordinates are within the expected range
+             if class_id in class_dict.keys() and all(0 <= p <= 1 for p in points):
+                 # Group points into (x, y) tuples
+                 polygon = [(points[i], points[i + 1]) for i in range(0, len(points), 2)]
+
+                 # Append class label and polygon to the mask data
+                 mask_data["class_labels"].append(class_id)
+                 mask_data["masks"].append(polygon)
+             else:
+                 return None  # Return None if class_id is invalid or coordinates are out of range
+
+         except Exception:
+             # Return None if any exception is encountered
+             return None
+
+     # Return None if the file is empty or no valid data found
+     return mask_data if mask_data["masks"] else None
+
+
+ # Function to generate mask for albumentations format
+ def generate_mask(masks, class_ids, image_height, image_width):
+     # Create an empty mask of the same size as the image, filled with 0 for background
+     mask = np.full((image_height, image_width), 0, dtype=np.int32)
+
+     # Iterate over each polygon and its corresponding class_id
+     for polygon, class_id in zip(masks, class_ids):
+         # Scale the polygon points to the image size
+         scaled_polygon = [
+             (int(x * image_width), int(y * image_height)) for x, y in polygon
+         ]
+
+         # Draw the polygon on the mask
+         cv2.fillPoly(mask, [np.array(scaled_polygon, dtype=np.int32)], color=class_id)
+
+     return mask
+
+
+ # Function to convert a single-channel mask back to YOLO format
+ def mask_to_yolo(mask):
+     yolo_data = {"masks": [], "class_labels": []}
+     unique_values = np.unique(mask)
+
+     for value in unique_values:
+         if value == 0:  # Skip the background
+             continue
+
+         # Extract mask for individual object
+         single_object_mask = np.uint8(mask == value)
+
+         # Find contours
+         contours, _ = cv2.findContours(
+             single_object_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
+         )
+
+         for contour in contours:
+             # Normalize and flatten contour points
+             normalized_contour = [
+                 (point[0][0] / mask.shape[1], point[0][1] / mask.shape[0])
+                 for point in contour
+             ]
+             yolo_data["masks"].append(normalized_contour)
+             yolo_data["class_labels"].append(value)
+
+     return yolo_data
+
+
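As a sanity check, `generate_mask` and `mask_to_yolo` should roughly round-trip a polygon (only approximately, since the contour is re-traced from the rasterized mask):

    polygon = [(0.2, 0.2), (0.8, 0.2), (0.8, 0.8), (0.2, 0.8)]
    mask = generate_mask([polygon], [1], image_height=100, image_width=100)
    recovered = mask_to_yolo(mask)
    print(recovered["class_labels"])  # [1] (as a numpy integer)
    print(recovered["masks"][0][0])   # a normalized point near one of the corners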
+ # Function to create a user interface for adjusting bounding box parameters
+ def bbox_params():
+     with st.expander("Bounding Box Parameters"):
+         # Create two columns for the input widgets
+         col1, col2 = st.columns(2)
+
+         with col1:
+             min_area = st.number_input(
+                 "Minimum Area",
+                 min_value=0,
+                 max_value=1000,
+                 value=0,
+                 step=1,
+                 help="Minimum area of a bounding box. Boxes smaller than this will be removed.",
+                 on_change=utils.reset_validation_trigger,
+                 key=st.session_state["bbox1_key"],
+             )
+             min_visibility = st.number_input(
+                 "Minimum Visibility",
+                 min_value=0,
+                 max_value=1000,
+                 value=0,
+                 step=1,
+                 help="Minimum fraction of area for a bounding box to remain in the list.",
+                 on_change=utils.reset_validation_trigger,
+                 key=st.session_state["bbox2_key"],
+             )
+
+         with col2:
+             min_width = st.number_input(
+                 "Minimum Width",
+                 min_value=0,
+                 max_value=1000,
+                 value=0,
+                 step=1,
+                 help="Minimum width of a bounding box. Boxes narrower than this will be removed.",
+                 on_change=utils.reset_validation_trigger,
+                 key=st.session_state["bbox3_key"],
+             )
+             min_height = st.number_input(
+                 "Minimum Height",
+                 min_value=0,
+                 max_value=1000,
+                 value=0,
+                 step=1,
+                 help="Minimum height of a bounding box. Boxes shorter than this will be removed.",
+                 on_change=utils.reset_validation_trigger,
+                 key=st.session_state["bbox4_key"],
+             )
+
+         check_each_transform = st.checkbox(
+             "Check Each Transform",
+             help="If checked, bounding boxes will be checked after each dual transform.",
+             value=True,
+             on_change=utils.reset_validation_trigger,
+             key=st.session_state["bbox5_key"],
+         )
+
+     # Return the collected parameters as a dictionary
+     return {
+         "min_area": min_area,
+         "min_visibility": min_visibility,
+         "min_width": min_width,
+         "min_height": min_height,
+         "check_each_transform": check_each_transform,
+     }
+
+
+ # Function to check if the uploaded images and labels are valid
+ @st.cache_resource(show_spinner=False)
+ def check_valid_labels(uploaded_files, selected_option, class_dict):
+     # Early exit if no files are uploaded
+     if len(uploaded_files) == 0:
+         st.warning("Please upload at least one image to apply augmentation.", icon="⚠️")
+         return False, {}, {}, None, None
+
+     # Initialize dictionaries to hold images and labels
+     image_files, label_files = {}, {}
+
+     # Extracting the name of the first file
+     first_file_name = os.path.splitext(uploaded_files[0].name)[0]
+
+     # Counters for images and labels
+     image_count, label_count = 0, 0
+
+     # Initialize a progress bar and progress text
+     progress_bar = st.progress(0)
+     progress_text = st.empty()
+     total_files = len(uploaded_files)
+
+     # Categorize and prepare uploaded files
+     for index, file in enumerate(uploaded_files):
+         file.seek(0)  # Reset file pointer to ensure proper file reading
+         file_name_without_extension = os.path.splitext(file.name)[0]
+
+         # Distribute files into image or label categories based on their file type
+         if file.type in ["image/jpeg", "image/png"]:
+             image_files[file_name_without_extension] = file
+             image_count += 1
+         elif file.type == "text/plain":
+             file_content = file.readlines()
+             label_data = None  # Stays None if the selected label type is unrecognized
+
+             if selected_option == "Bboxes":
+                 label_data = bboxes_label(file_content, class_dict)
+             elif selected_option == "Masks":
+                 label_data = masks_label(file_content, class_dict)
+
+             # Check for valid label data
+             if label_data is None:
+                 st.warning(
+                     f"Invalid label format or data in file: {file.name}",
+                     icon="⚠️",
+                 )
+                 return False, {}, {}, None, None
+
+             label_files[file_name_without_extension] = label_data
+             label_count += 1
+
+         # Update progress bar and display current progress
+         progress_percentage = (index + 1) / total_files
+         progress_bar.progress(progress_percentage)
+         progress_text.text(f"Validating file {index + 1} of {total_files}")
+
+     # Extract sets of unique file names for images and labels
+     unique_image_names = set(image_files.keys())
+     unique_label_names = set(label_files.keys())
+
+     # Remove progress bar and progress text after processing
+     progress_bar.empty()
+     progress_text.empty()
+
+     if (len(unique_image_names) != image_count) or (
+         len(unique_label_names) != label_count
+     ):
+         # Warn the user about the presence of duplicate file names
+         st.warning(
+             "Duplicate file names detected. Please ensure each image and label has a unique name.",
+             icon="⚠️",
+         )
+         return False, {}, {}, None, None
+
+     # Perform validation checks
+     if (len(image_files) > 0) and (len(label_files) > 0):
+         # Check if the number of images and labels match and each pair has corresponding files
+         if (len(image_files) == len(label_files)) and (
+             unique_image_names == unique_label_names
+         ):
+             st.info(
+                 f"Validated: {len(image_files)} images and labels successfully matched.",
+                 icon="✅",
+             )
+             return (
+                 True,
+                 image_files,
+                 label_files,
+                 image_files[first_file_name],
+                 label_files[first_file_name],
+             )
+
+         elif len(image_files) != len(label_files):
+             # Warn if the count of images and labels does not match
+             st.warning(
+                 "Count Mismatch: The number of uploaded images and labels does not match.",
+                 icon="⚠️",
+             )
+             return False, {}, {}, None, None
+
+         else:
+             # Warn if there is a mismatch in file names between images and labels
+             st.warning(
+                 "Mismatch detected: Some images do not have corresponding label files.",
+                 icon="⚠️",
+             )
+             return False, {}, {}, None, None
+
+     elif len(image_files) > 0:
+         # Inform the user if only images are uploaded without labels
+         st.info(
+             f"Note: {len(image_files)} images uploaded without labels. Label type and class labels will be ignored in this case.",
+             icon="✅",
+         )
+         return True, image_files, {}, image_files[first_file_name], None
+
+     else:
+         # Warn if no images are uploaded
+         st.warning("Please upload an image to apply augmentation.", icon="⚠️")
+         return False, {}, {}, None, None
+
+
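For context, a minimal sketch of how this validator would be wired to an uploader in the app pages (the widget label and the `class_dict` value are assumptions, not the app's actual UI code):

    uploaded_files = st.file_uploader(
        "Upload images and YOLO label files",
        type=["jpg", "jpeg", "png", "txt"],
        accept_multiple_files=True,
    )
    is_valid, image_files, label_files, first_image, first_label = check_valid_labels(
        uploaded_files, selected_option="Bboxes", class_dict={1: "person"}
    )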
+ # Function to apply an augmentation technique to an image and return any errors along with the processed image
+ def apply_and_test_augmentation(
+     augmentation,
+     params,
+     image,
+     label,
+     label_type,
+     label_input_parameters,
+     allowed_image_types,
+ ):
+     try:
+         # Check the data type and number of channels of the input image
+         input_image_type = image.dtype
+         num_channels = (
+             image.shape[2] if len(image.shape) == 3 else 1
+         )  # Assuming 1 for single-channel images
+
+         # Validate if the input image type is among the allowed types
+         if not utils.is_image_type_allowed(
+             input_image_type, num_channels, allowed_image_types
+         ):
+             # Format the allowed types for display in the warning message
+             allowed_types_formatted = ", ".join(map(str, allowed_image_types))
+
+             # Display a warning message specifying the acceptable image types
+             st.warning(
+                 f"Error applying {augmentation}: Incompatible image type. The input image should be one of the following types: {allowed_types_formatted}",
+                 icon="⚠️",
+             )
+             return True, None  # Error occurred
+
+         # Set the seed for reproducibility
+         random.seed(0)
+
+         if label is None:
+             # Apply augmentation technique for no label input
+             transform = A.Compose([utils.apply_albumentation(params, augmentation)])
+             processed_image = transform(image=image)["image"]
+             return False, processed_image
+
+         elif label_type == "Bboxes":
+             # Apply augmentation technique for Bboxes label format
+             transform = A.Compose(
+                 [utils.apply_albumentation(params, augmentation)],
+                 bbox_params=A.BboxParams(
+                     format="yolo",
+                     label_fields=["class_labels"],
+                     min_area=label_input_parameters["min_area"],
+                     min_visibility=label_input_parameters["min_visibility"],
+                     min_width=label_input_parameters["min_width"],
+                     min_height=label_input_parameters["min_height"],
+                     check_each_transform=label_input_parameters["check_each_transform"],
+                 ),
+             )
+             processed_image = transform(
+                 image=image,
+                 bboxes=label["bboxes"],
+                 class_labels=label["class_labels"],
+             )["image"]
+
+         elif label_type == "Masks":
+             # Apply augmentation technique for Masks label format
+             transform = A.Compose([utils.apply_albumentation(params, augmentation)])
+             processed_image = transform(
+                 image=image,
+                 mask=generate_mask(
+                     label["masks"],
+                     label["class_labels"],
+                     image.shape[0],
+                     image.shape[1],
+                 ),
+             )["image"]
+
+         return False, processed_image  # No error
+
+     except Exception as e:
+         st.warning(f"Error applying {augmentation}: {e}", icon="⚠️")
+         return True, None  # Error occurred
+
+
+ # Generates a DataFrame detailing augmentation technique parameters and descriptions
+ def create_augmentations_dataframe(augmentations_params, augmentation_params_db):
+     data = []
+     for aug_name, params in augmentations_params.items():
+         for param_name, param_value in params.items():
+             # Retrieve relevant augmentation information from the database
+             augmentation_info = augmentation_params_db[
+                 augmentation_params_db["Name"] == aug_name
+             ]
+             param_info = augmentation_info[
+                 augmentation_info["Parameter Name"] == param_name
+             ]
+
+             # Check if the parameter information exists in the database
+             if not param_info.empty:
+                 # Get the description of the current parameter
+                 param_description = param_info["Parameter Description"].iloc[0]
+             else:
+                 param_description = "Description not available"
+
+             # Append augmentation name, parameter name, its value, and description to the data list
+             data.append([aug_name, param_name, param_value, param_description])
+
+     # Create the DataFrame from the accumulated data
+     augmentations_df = pd.DataFrame(
+         data, columns=["augmentation", "Parameter", "Value", "Description"]
+     )
+     return augmentations_df
+
+
+ # Function to Generate Python Code for Augmentation with Bounding Box Labels
+ def generate_python_code_bboxes(
+     augmentations_params,
+     label_input_parameters,
+     num_variations=1,
+     include_original=False,
+ ):
+     # Start with necessary library imports
+     code_str = "# Importing necessary libraries\n"
+     code_str += "import os\nimport cv2\nimport shutil\nimport albumentations as A\n\n"
+
+     # Paths for input and output directories
+     code_str += "# Define the paths for input and output directories\n"
+     code_str += "input_directory = 'path/to/input'\n"
+     code_str += "output_directory = 'path/to/output'\n\n"
+
+     # Function to read YOLO format labels
+     code_str += "# Function to read YOLO format labels\n"
+     code_str += "def read_yolo_label(label_path):\n"
+     code_str += "    bboxes = []\n"
+     code_str += "    class_ids = []\n"
+     code_str += "    with open(label_path, 'r') as file:\n"
+     code_str += "        for line in file:\n"
+     code_str += "            class_id, x_center, y_center, width, height = map(float, line.split())\n"
+     code_str += "            bboxes.append([x_center, y_center, width, height])\n"
+     code_str += "            class_ids.append(int(class_id))\n"
+     code_str += "    return bboxes, class_ids\n\n"
+
+     # Function to create an augmentation pipeline
+     code_str += "# Function to create an augmentation pipeline using Albumentations\n"
+     code_str += "def process_image(image, bboxes, class_ids):\n"
+     code_str += "    # Define the sequence of augmentation techniques\n"
+     code_str += "    pipeline = A.Compose([\n"
+     for technique, params in augmentations_params.items():
+         code_str += f"        A.{technique}({', '.join(f'{k}={v}' for k, v in params.items())}),\n"
+     code_str += "    ], bbox_params=A.BboxParams(\n"
+     code_str += "        format='yolo',\n"
+     code_str += "        label_fields=['class_labels'],\n"
+     code_str += f"        min_area={label_input_parameters['min_area']},\n"
+     code_str += f"        min_visibility={label_input_parameters['min_visibility']},\n"
+     code_str += f"        min_width={label_input_parameters['min_width']},\n"
+     code_str += f"        min_height={label_input_parameters['min_height']},\n"
+     code_str += f"        check_each_transform={label_input_parameters['check_each_transform']}\n"
+     code_str += "    ))\n"
+     code_str += "    # Apply the augmentation pipeline\n"
+     code_str += "    return pipeline(image=image, bboxes=bboxes, class_labels=class_ids)\n\n"
+
+     # Function to process a batch of images
+     code_str += "# Function to process a batch of images\n"
+     code_str += "def process_batch(input_directory, output_directory):\n"
+     code_str += "    for filename in os.listdir(input_directory):\n"
+     code_str += "        if filename.lower().endswith(('.png', '.jpg', '.jpeg')):\n"
+     code_str += "            image_path = os.path.join(input_directory, filename)\n"
+     code_str += "            label_path = os.path.splitext(image_path)[0] + '.txt'\n\n"
+
+     code_str += "            # Read the image and label\n"
+     code_str += "            image = cv2.imread(image_path)\n"
+     code_str += "            bboxes, class_ids = read_yolo_label(label_path)\n\n"
+
+     # Include original image and label logic
+     if include_original:
+         code_str += "            # Include original image and label\n"
+         code_str += "            shutil.copy2(image_path, output_directory)\n"
+         code_str += "            shutil.copy2(label_path, output_directory)\n\n"
+
+     # Generate variations for each image and process them
+     code_str += "            # Generate variations for each image and process them\n"
+     code_str += f"            for variation in range({num_variations}):\n"
+     code_str += "                processed_data = process_image(image, bboxes, class_ids)\n"
+     code_str += "                processed_image = processed_data['image']\n"
+     code_str += "                processed_bboxes = processed_data['bboxes']\n"
+     code_str += "                processed_class_ids = processed_data['class_labels']\n\n"
+     code_str += "                # Save the processed image\n"
+     code_str += "                output_filename = f'processed_{os.path.splitext(filename)[0]}_{variation}{os.path.splitext(filename)[1]}'\n"
+     code_str += "                cv2.imwrite(os.path.join(output_directory, output_filename), processed_image)\n\n"
+     code_str += "                with open(os.path.join(output_directory, os.path.splitext(output_filename)[0] + '.txt'), 'w') as label_file:\n"
+     code_str += "                    for bbox, class_id in zip(processed_bboxes, processed_class_ids):\n"
+     code_str += "                        label_line = ' '.join(map(str, [class_id] + list(bbox)))\n"
+     code_str += "                        label_file.write(label_line + '\\n')\n\n"
+
+     # Execute the batch processing function
+     code_str += "# Execute the batch processing function with the specified parameters\n"
+     code_str += "process_batch(input_directory, output_directory)\n"
+
+     return code_str
+
+
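A short sketch of invoking the generator (all parameter values are illustrative):

    script = generate_python_code_bboxes(
        augmentations_params={"HorizontalFlip": {"p": 1.0}},
        label_input_parameters={
            "min_area": 0,
            "min_visibility": 0,
            "min_width": 0,
            "min_height": 0,
            "check_each_transform": True,
        },
        num_variations=2,
        include_original=True,
    )
    print(script)  # a self-contained Albumentations batch script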
+ # Function to Generate Python Code for Augmentation with Mask Labels
+ def generate_python_code_masks(
+     augmentations_params,
+     label_input_parameters,
+     num_variations=1,
+     include_original=False,
+ ):
+     # Start with necessary library imports
+     code_str = "# Importing necessary libraries\n"
+     code_str += "import os\nimport cv2\nimport shutil\nimport numpy as np\nimport albumentations as A\n\n"
+
+     # Paths for input and output directories
+     code_str += "# Define the paths for input and output directories\n"
+     code_str += "input_directory = 'path/to/input'\n"
+     code_str += "output_directory = 'path/to/output'\n\n"
+
+     # Function to read YOLO mask format and convert to mask
+     code_str += "# Function to read YOLO mask format and convert to mask\n"
+     code_str += "def read_yolo_label(label_path, image_shape):\n"
+     code_str += "    mask = np.full(image_shape, -1, dtype=np.int32)\n"
+     code_str += "    with open(label_path, 'r') as file:\n"
+     code_str += "        for line in file:\n"
+     code_str += "            parts = line.strip().split()\n"
+     code_str += "            class_id = int(parts[0])\n"
+     code_str += "            points = np.array([float(p) for p in parts[1:]]).reshape(-1, 2)\n"
+     code_str += "            points = (points * [image_shape[1], image_shape[0]]).astype(np.int32)\n"
+     code_str += "            cv2.fillPoly(mask, [points], class_id)\n"
+     code_str += "    return mask\n\n"
+
+     # Function to convert mask to YOLO format
+     code_str += "# Function to convert mask to YOLO format\n"
+     code_str += "def mask_to_yolo(mask):\n"
+     code_str += "    yolo_format = ''\n"
+     code_str += "    for class_id in np.unique(mask):\n"
+     code_str += "        if class_id == -1:\n"
+     code_str += "            continue\n"
+     code_str += "        contours, _ = cv2.findContours(\n"
+     code_str += "            np.uint8(mask == class_id), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE\n"
+     code_str += "        )\n"
+     code_str += "        for contour in contours:\n"
+     code_str += "            contour = contour.flatten().tolist()\n"
+     code_str += "            # Normalize x by the width (shape[1]) and y by the height (shape[0])\n"
+     code_str += "            normalized_contour = [\n"
+     code_str += "                str(coord / mask.shape[(i + 1) % 2]) for i, coord in enumerate(contour)\n"
+     code_str += "            ]\n"
+     code_str += "            yolo_format += f'{class_id} ' + ' '.join(normalized_contour) + '\\n'\n"
+     code_str += "    return yolo_format\n\n"
+
+     # Function to create an augmentation pipeline
+     code_str += "# Function to create an augmentation pipeline using Albumentations\n"
+     code_str += "def process_image(image, mask):\n"
+     code_str += "    # Define the sequence of augmentation techniques\n"
+     code_str += "    pipeline = A.Compose([\n"
+     for technique, params in augmentations_params.items():
+         code_str += f"        A.{technique}({', '.join(f'{k}={v}' for k, v in params.items())}),\n"
+     code_str += "    ])\n"
+     code_str += "    # Apply the augmentation pipeline\n"
+     code_str += "    return pipeline(image=image, mask=mask)\n\n"
+
+     # Function to process a batch of images
+     code_str += "# Function to process a batch of images\n"
+     code_str += "def process_batch(input_directory, output_directory):\n"
+     code_str += "    for filename in os.listdir(input_directory):\n"
+     code_str += "        if filename.lower().endswith(('.png', '.jpg', '.jpeg')):\n"
+     code_str += "            image_path = os.path.join(input_directory, filename)\n"
+     code_str += "            label_path = os.path.splitext(image_path)[0] + '.txt'\n\n"
+     code_str += "            # Read the image\n"
+     code_str += "            image = cv2.imread(image_path)\n\n"
+
+     # Include original image and label logic
+     if include_original:
+         code_str += "            # Include original image and label\n"
+         code_str += "            shutil.copy2(image_path, output_directory)\n"
+         code_str += "            shutil.copy2(label_path, output_directory)\n\n"
+
+     code_str += "            # Check if label file exists and read mask\n"
+     code_str += "            mask = None\n"
+     code_str += "            if os.path.exists(label_path):\n"
+     code_str += "                mask = read_yolo_label(label_path, image.shape[:2])\n\n"
+
+     # Generate variations for each image and process them
+     code_str += "            # Generate variations for each image and process them\n"
+     code_str += f"            for variation in range({num_variations}):\n"
+     code_str += "                processed_image, processed_mask = image, mask\n"
+     code_str += "                if mask is not None:\n"
+     code_str += "                    processed_data = process_image(image, mask)\n"
+     code_str += "                    processed_image, processed_mask = processed_data['image'], processed_data['mask']\n\n"
+     code_str += "                # Save the processed image\n"
+     code_str += "                output_filename = f'processed_{os.path.splitext(filename)[0]}_{variation}.jpg'\n"
+     code_str += "                cv2.imwrite(os.path.join(output_directory, output_filename), processed_image)\n\n"
+     code_str += "                # Save the processed label in YOLO format\n"
+     code_str += "                if processed_mask is not None:\n"
+     code_str += "                    processed_label_str = mask_to_yolo(processed_mask)\n"
+     code_str += "                    with open(os.path.join(output_directory, os.path.splitext(output_filename)[0] + '.txt'), 'w') as label_file:\n"
+     code_str += "                        label_file.write(processed_label_str)\n\n"
+
+     # Execute the batch processing function with the specified parameters
+     code_str += "# Execute the batch processing function with the specified parameters\n"
+     code_str += "process_batch(input_directory, output_directory)\n"
+
+     return code_str
+
+
+ # Function to create an augmentation pipeline based on the selected techniques and their parameters
+ def create_augmentation_pipeline(
+     selected_augmentations, augmentation_params, label_type, label_input_parameters=None
+ ):
+     pipeline = []
+     for aug_name in selected_augmentations:
+         # Append the function call with its parameters to the pipeline
+         pipeline.append(
+             utils.apply_albumentation(augmentation_params[aug_name], aug_name)
+         )
+
+     # Compose all the augmentations into one transformation
+     try:
+         # Set the seed for reproducibility
+         random.seed(0)
+
+         if label_type is None:
+             # Compose the pipeline for no label input
+             transform = A.Compose(pipeline)
+             return transform
+
+         elif label_type == "Bboxes":
+             # Compose the pipeline for the Bboxes label format
+             transform = A.Compose(
+                 pipeline,
+                 bbox_params=A.BboxParams(
+                     format="yolo",
+                     label_fields=["class_labels"],
+                     min_area=label_input_parameters["min_area"],
+                     min_visibility=label_input_parameters["min_visibility"],
+                     min_width=label_input_parameters["min_width"],
+                     min_height=label_input_parameters["min_height"],
+                     check_each_transform=label_input_parameters["check_each_transform"],
+                 ),
+             )
+
+         elif label_type == "Masks":
+             # Compose the pipeline for the Masks label format
+             transform = A.Compose(pipeline)
+
+         return transform  # No error
+
+     except Exception as e:
+         st.warning(f"Error composing the augmentation pipeline: {e}", icon="⚠️")
+         return None  # Error occurred
+
+
+ # Function to convert label data from dictionary format to YOLO format
+ def convert_labels_to_yolo_format(label_data, class_dict):
+     yolo_label_str = ""
+
+     # Convert bounding boxes to YOLO format
+     if "bboxes" in label_data:
+         for bbox, class_label in zip(label_data["bboxes"], label_data["class_labels"]):
+             class_id = class_label - 1  # Revert to the original value
+             x_center, y_center, width, height = bbox
+             yolo_label_str += f"{class_id} {x_center} {y_center} {width} {height}\n"
+
+     # Convert masks to YOLO format
+     if "masks" in label_data:
+         for mask, class_label in zip(label_data["masks"], label_data["class_labels"]):
+             class_id = class_label - 1  # Revert to the original value
+             # Flatten the mask array into a single line of coordinates
+             mask_flattened = [coord for point in mask for coord in point]
+             mask_str = " ".join(map(str, mask_flattened))
+             yolo_label_str += f"{class_id} {mask_str}\n"
+
+     return yolo_label_str
+
+
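For instance, converting one shifted bbox label back to YOLO text (the class id is shifted back down by one):

    label_data = {"bboxes": [[0.5, 0.5, 0.2, 0.3]], "class_labels": [1]}
    print(convert_labels_to_yolo_format(label_data, class_dict={1: "person"}), end="")
    # 0 0.5 0.5 0.2 0.3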
+ # Function to apply the augmentation pipeline to the image based on the label type
+ def apply_augmentation_pipeline(image, label_file, label_type, transform):
+     # Initialize an empty dictionary to store processed labels
+     processed_label = {}
+
+     # Apply the transformation based on the label type
+     if label_type is None:
+         processed_output = transform(image=image)
+         processed_label = None
+
+     elif label_type == "Bboxes":
+         processed_output = transform(
+             image=image,
+             bboxes=label_file["bboxes"],
+             class_labels=label_file["class_labels"],
+         )
+         processed_label = {
+             "bboxes": processed_output["bboxes"],
+             "class_labels": processed_output["class_labels"],
+         }
+
+     elif label_type == "Masks":
+         mask = generate_mask(
+             label_file["masks"],
+             label_file["class_labels"],
+             image.shape[0],
+             image.shape[1],
+         )
+         processed_output = transform(image=image, mask=mask)
+         mask_yolo = mask_to_yolo(processed_output["mask"])
+         processed_label = mask_yolo
+
+     # Extract the processed image
+     processed_image = processed_output["image"]
+
+     return processed_image, processed_label
+
+
+ # Function to process images and labels, apply augmentations, and create a zip file with the results
+ @st.cache_resource(show_spinner=False)
+ def process_images_and_labels(
+     image_files,
+     label_files,
+     selected_augmentations,
+     _augmentations_params,
+     label_type,
+     label_input_parameters,
+     num_variations,
+     include_original,
+     class_dict,
+ ):
+     zip_buffer = io.BytesIO()  # Create an in-memory buffer for the zip file
+     # Initialize a repository to store processed image data
+     st.session_state["image_repository_augmentation"] = {}
+     # Map original images to their processed versions
+     st.session_state["processed_image_mapping_augmentation"] = {}
+     st.session_state["unique_images_names"] = []  # List to store unique image names
+
+     # Create progress bar and text elements in Streamlit
+     progress_bar = st.progress(0)
+     progress_text = st.empty()
+
+     with zipfile.ZipFile(
+         zip_buffer, mode="a", compression=zipfile.ZIP_DEFLATED, allowZip64=True
+     ) as zip_file:
+         # Determine the label type for augmentation, if label files are present
+         effective_label_type = None if len(label_files) == 0 else label_type
+
+         # Create an augmentation pipeline based on selected augmentations and parameters
+         transform = create_augmentation_pipeline(
+             selected_augmentations,
+             _augmentations_params,
+             effective_label_type,
+             label_input_parameters,
+         )
+
+         total_images = len(image_files) * num_variations
+         processed_count = 0  # Counter for processed images
+
+         # Iterate over each uploaded file
+         for image_name, image_file in image_files.items():
+             image_file.seek(0)  # Reset file pointer to start
+             file_bytes = np.asarray(bytearray(image_file.read()), dtype=np.uint8)
+             original_image = cv2.cvtColor(
+                 cv2.imdecode(file_bytes, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB
+             )
+             original_image_resized = utils.resize_image(original_image)
+
+             # Include original images and labels in the output if selected
+             if include_original:
+                 original_img_buffer = io.BytesIO()
+                 Image.fromarray(original_image).save(original_img_buffer, format="JPEG")
+                 zip_file.writestr(image_file.name, original_img_buffer.getvalue())
+
+                 # Convert and save original labels to YOLO format if they exist
+                 label_file = label_files.get(image_name)
+                 if label_file is not None:
+                     yolo_label_str = convert_labels_to_yolo_format(
+                         label_file, class_dict
+                     )
+                     zip_file.writestr(f"{image_name}.txt", yolo_label_str)
+
+             original_file_name = image_file.name
+             st.session_state["unique_images_names"].append(original_file_name)
+             st.session_state["processed_image_mapping_augmentation"][
+                 original_file_name
+             ] = []
+             st.session_state["image_repository_augmentation"][image_file.name] = {
+                 "image": original_image_resized,
+                 "label": label_files.get(image_name),
+             }
+
+             # Apply augmentations and generate variations
+             for i in range(num_variations):
+                 random.seed(i)
+                 processed_image, processed_label = apply_augmentation_pipeline(
+                     original_image,
+                     label_files.get(image_name),
+                     effective_label_type,
+                     transform,
+                 )
+
+                 img_buffer = io.BytesIO()
+                 Image.fromarray(processed_image).save(img_buffer, format="JPEG")
+                 processed_filename = f"processed_{image_name.split('.')[0]}_{i}.jpg"
+                 zip_file.writestr(processed_filename, img_buffer.getvalue())
+                 processed_image_resized = utils.resize_image(processed_image)
+
+                 st.session_state["processed_image_mapping_augmentation"][
+                     image_file.name
+                 ].append(processed_filename)
+                 st.session_state["image_repository_augmentation"][
+                     processed_filename
+                 ] = {
+                     "image": processed_image_resized,
+                     "label": processed_label,
+                 }
+
+                 # Convert and save processed labels to YOLO format if they exist
+                 label_file = label_files.get(image_name)
+                 if label_file is not None:
+                     processed_label_str = convert_labels_to_yolo_format(
+                         processed_label, class_dict
+                     )
+                     zip_file.writestr(
+                         f"processed_{image_name.split('.')[0]}_{i}.txt",
+                         processed_label_str,
+                     )
+
+                 processed_count += 1
+                 # Update progress bar and text
+                 progress_bar.progress(processed_count / total_images)
+                 progress_text.text(
+                     f"Processing image {processed_count} of {total_images}"
+                 )
+
+     # Remove the progress bar and text after processing is complete
+     progress_bar.empty()
+     progress_text.empty()
+
+     zip_buffer.seek(0)  # Reset buffer to start for download
+
+     st.session_state["zip_data_augmentation"] = zip_buffer.getvalue()
+
+
+ # Function to overlay labels on images
+ def overlay_labels(image, labels_to_plot, label_file, label_type, colors, class_dict):
+     # Ensure the image is in the correct format (RGB)
+     if len(image.shape) == 2 or image.shape[2] == 1:
+         image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
+
+     # Overlay Bounding Boxes
+     if label_type == "Bboxes":
+         for bbox, label in zip(label_file["bboxes"], label_file["class_labels"]):
+             if label in labels_to_plot:
+                 # Convert bbox from yolo format to xmin, ymin, xmax, ymax
+                 x_center, y_center, width, height = bbox
+                 xmin = int((x_center - width / 2) * image.shape[1])
+                 xmax = int((x_center + width / 2) * image.shape[1])
+                 ymin = int((y_center - height / 2) * image.shape[0])
+                 ymax = int((y_center + height / 2) * image.shape[0])
+
+                 # Get color for the class
+                 color = colors[label]
+
+                 # Draw rectangle and label
+                 image = cv2.rectangle(
+                     image, (xmin, ymin), (xmax, ymax), color, thickness=2
+                 )
+
+                 # Put class label text
+                 label_text = class_dict.get(label, "Unknown")
+                 cv2.putText(
+                     image,
+                     label_text,
+                     (xmin, ymin - 5),
+                     cv2.FONT_HERSHEY_SIMPLEX,
+                     0.5,
+                     (0, 0, 0),  # Black color for text
+                     2,
+                 )
+
+     # Overlay Mask
+     elif label_type == "Masks":
+         for polygon, label in zip(label_file["masks"], label_file["class_labels"]):
+             if label in labels_to_plot:
+                 # Convert polygon points from yolo format to image coordinates
+                 polygon_points = [
+                     (int(x * image.shape[1]), int(y * image.shape[0]))
+                     for x, y in polygon
+                 ]
+
+                 # Get color for the class
+                 color = colors[label]
+
+                 # Create a temporary image to draw the polygon
+                 temp_image = image.copy()
+                 cv2.fillPoly(
+                     temp_image, [np.array(polygon_points, dtype=np.int32)], color
+                 )
+
+                 # Blend the temporary image with the original image
+                 cv2.addWeighted(temp_image, 0.5, image, 0.5, 0, image)
+
+                 # Optional: Put class label text near the first point of the polygon
+                 label_text = class_dict.get(label, "Unknown")
+                 cv2.putText(
+                     image,
+                     label_text,
+                     polygon_points[0],
+                     cv2.FONT_HERSHEY_SIMPLEX,
+                     0.5,
+                     (0, 0, 0),  # Black color for text
+                     2,
+                 )
+
+     return image
+
+
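A minimal sketch pairing this with generate_unique_colors (the blank test image and the class name are placeholders):

    labels = {"bboxes": [[0.5, 0.5, 0.2, 0.3]], "class_labels": [1]}
    colors = generate_unique_colors([1])
    canvas = np.zeros((200, 200, 3), dtype=np.uint8)
    annotated = overlay_labels(
        canvas,
        labels_to_plot=[1],
        label_file=labels,
        label_type="Bboxes",
        colors=colors,
        class_dict={1: "person"},
    )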
+ # Function to display the generated code and a download button for it
+ def display_code_and_download_button(generated_code):
+     def generate_downloadable_file(code_str):
+         return code_str.encode("utf-8")
+
+     # Display the generated code in Streamlit with description and download button in columns
+     with st.expander("Plug and Play Code"):
+         col1, col2 = st.columns([7, 3])
+
+         with col1:
+             st.markdown("### Description of the Code Pipeline")
+
+             st.markdown(
+                 """
+                 <div style='text-align: justify;'>
+                 This code is a ready-to-use Python script for batch augmentation. It applies the selected augmentation techniques to all images in a specified input directory and saves the processed images in an output directory.
+
+                 **To use this script:**
+                 - Ensure you have the necessary dependencies installed.
+                 - Specify the input and output paths: Replace `'path/to/input'` with the path to your input images and `'path/to/output'` with the desired path for the processed images.
+                 - The number of augmented variations per image, the inclusion of the original images in the output, and the augmentation techniques with their parameters will be automatically set based on your selections.
+
+                 ### Python Code
+                 </div>
+                 """,
+                 unsafe_allow_html=True,
+             )
+
+             # Display python code
+             st.code(generated_code, language="python")
+
+         with col2:
+             # Create a button for downloading the Python file
+             st.download_button(
+                 label="Download Python File",
+                 data=generate_downloadable_file(generated_code),
+                 file_name="augmentation_script.py",
+                 mime="text/plain",
+                 use_container_width=True,
+             )
Functions/image_processing_functions.py ADDED
@@ -0,0 +1,430 @@
+ # Importing necessary libraries
+ import io
+ import os
+ import cv2
+ import utils
+ import random
+ import zipfile
+ import numpy as np
+ import pandas as pd
+ from PIL import Image
+ import streamlit as st
+ import albumentations as A
+
+
+ # Function to check if the uploaded images and labels are valid
+ @st.cache_resource(show_spinner=False)
+ def check_valid_labels(uploaded_files):
+     # Early exit if no files are uploaded
+     if len(uploaded_files) == 0:
+         st.warning(
+             "Please upload at least one image to apply image processing.", icon="⚠️"
+         )
+         return False, {}, {}, None, None
+
+     # Initialize dictionaries to hold images and labels
+     image_files, label_files = {}, {}
+
+     # Extracting the name of the first file
+     first_file_name = os.path.splitext(uploaded_files[0].name)[0]
+
+     # Counters for images and labels
+     image_count, label_count = 0, 0
+
+     # Initialize a progress bar and progress text
+     progress_bar = st.progress(0)
+     progress_text = st.empty()
+     total_files = len(uploaded_files)
+
+     # Categorize and prepare uploaded files
+     for index, file in enumerate(uploaded_files):
+         file.seek(0)  # Reset file pointer to ensure proper file reading
+         file_name_without_extension = os.path.splitext(file.name)[0]
+
+         # Distribute files into image or label categories based on their file type
+         if file.type in ["image/jpeg", "image/png"]:
+             image_files[file_name_without_extension] = file
+             image_count += 1
+         elif file.type == "text/plain":
+             label_files[file_name_without_extension] = file
+             label_count += 1
+
+         # Update progress bar and display current progress
+         progress_percentage = (index + 1) / total_files
+         progress_bar.progress(progress_percentage)
+         progress_text.text(f"Validating file {index + 1} of {total_files}")
+
+     # Extract sets of unique file names for images and labels
+     unique_image_names = set(image_files.keys())
+     unique_label_names = set(label_files.keys())
+
+     # Remove progress bar and progress text after processing
+     progress_bar.empty()
+     progress_text.empty()
+
+     if (len(unique_image_names) != image_count) or (
+         len(unique_label_names) != label_count
+     ):
+         # Warn the user about the presence of duplicate file names
+         st.warning(
+             "Duplicate file names detected. Please ensure each image and label has a unique name.",
+             icon="⚠️",
+         )
+         return False, {}, {}, None, None
+
+     # Perform validation checks
+     if (len(image_files) > 0) and (len(label_files) > 0):
+         # Check if the number of images and labels match and each pair has corresponding files
+         if (len(image_files) == len(label_files)) and (
+             unique_image_names == unique_label_names
+         ):
+             st.info(
+                 f"Validated: {len(image_files)} images and labels successfully matched.",
+                 icon="✅",
+             )
+             return (
+                 True,
+                 image_files,
+                 label_files,
+                 image_files[first_file_name],
+                 label_files[first_file_name],
+             )
+
+         elif len(image_files) != len(label_files):
+             # Warn if the count of images and labels does not match
+             st.warning(
+                 "Count Mismatch: The number of uploaded images and labels does not match.",
+                 icon="⚠️",
+             )
+             return False, {}, {}, None, None
+
+         else:
+             # Warn if there is a mismatch in file names between images and labels
+             st.warning(
+                 "Mismatch detected: Some images do not have corresponding label files.",
+                 icon="⚠️",
+             )
+             return False, {}, {}, None, None
+
+     elif len(image_files) > 0:
+         # Inform the user if only images are uploaded without labels
+         st.info(
+             f"Note: {len(image_files)} images uploaded without labels. Label type and class labels will be ignored in this case.",
+             icon="✅",
+         )
+         return True, image_files, {}, image_files[first_file_name], None
+
+     else:
+         # Warn if no images are uploaded
+         st.warning("Please upload an image to apply image processing.", icon="⚠️")
+         return False, {}, {}, None, None
+
+
+ # Function to apply an image processing technique to an image and return any errors along with the processed image
+ def apply_and_test_image_processing(
+     image_processing, params, image, allowed_image_types
+ ):
+     try:
+         # Check the data type and number of channels of the input image
+         input_image_type = image.dtype
+         num_channels = (
+             image.shape[2] if len(image.shape) == 3 else 1
+         )  # Assuming 1 for single-channel images
+
+         # Validate if the input image type is among the allowed types
+         if not utils.is_image_type_allowed(
+             input_image_type, num_channels, allowed_image_types
+         ):
+             # Format the allowed types for display in the warning message
+             allowed_types_formatted = ", ".join(map(str, allowed_image_types))
+
+             # Display a warning message specifying the acceptable image types
+             st.warning(
+                 f"Error applying {image_processing}: Incompatible image type. The input image should be one of the following types: {allowed_types_formatted}",
+                 icon="⚠️",
+             )
+             return True, None  # Error occurred
+
+         # Set the seed for reproducibility
+         random.seed(0)
+
+         # Apply image processing technique
+         transform = A.Compose([utils.apply_albumentation(params, image_processing)])
+         processed_image = transform(image=image)["image"]
+
+         return False, processed_image  # No error
+     except Exception as e:
+         st.warning(f"Error applying {image_processing}: {e}", icon="⚠️")
+         return True, None  # Error occurred
+
+
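A rough usage sketch on a dummy image. It assumes `utils.apply_albumentation` resolves the technique name to an Albumentations transform and that `allowed_image_types` lists dtype names; neither is confirmed by this file alone:

    image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
    error, processed = apply_and_test_image_processing(
        "Blur", {"blur_limit": 3, "p": 1.0}, image, allowed_image_types=["uint8"]
    )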
+ # Function to generate a DataFrame detailing image processing technique parameters and descriptions
+ def create_image_processings_dataframe(
+     image_processings_params, image_processing_params_db
+ ):
+     data = []
+     for aug_name, params in image_processings_params.items():
+         for param_name, param_value in params.items():
+             # Retrieve relevant image_processing information from the database
+             image_processing_info = image_processing_params_db[
+                 image_processing_params_db["Name"] == aug_name
+             ]
+             param_info = image_processing_info[
+                 image_processing_info["Parameter Name"] == param_name
+             ]
+
+             # Check if the parameter information exists in the database
+             if not param_info.empty:
+                 # Get the description of the current parameter
+                 param_description = param_info["Parameter Description"].iloc[0]
+             else:
+                 param_description = "Description not available"
+
+             # Append image_processing name, parameter name, its value, and description to the data list
+             data.append([aug_name, param_name, param_value, param_description])
+
+     # Create the DataFrame from the accumulated data
+     image_processings_df = pd.DataFrame(
+         data, columns=["image_processing", "Parameter", "Value", "Description"]
+     )
+     return image_processings_df
+
+
+ # Function to generate python code for images and labels
+ def generate_python_code_images_labels(
+     augmentations_params,
+     num_variations=1,
+     include_original=False,
+ ):
+     # Start with necessary library imports
+     code_str = "# Importing necessary libraries\n"
+     code_str += "import os\nimport cv2\nimport shutil\nimport albumentations as A\n\n"
+
+     # Paths for input and output directories
+     code_str += "# Define the paths for input and output directories\n"
+     code_str += "input_directory = 'path/to/input'\n"
+     code_str += "output_directory = 'path/to/output'\n\n"
+
+     # Function to create an augmentation pipeline
+     code_str += "# Function to create an augmentation pipeline using Albumentations\n"
+     code_str += "def process_image(image):\n"
+     code_str += "    # Define the sequence of augmentation techniques\n"
+     code_str += "    pipeline = A.Compose([\n"
+     for technique, params in augmentations_params.items():
+         code_str += f"        A.{technique}({', '.join(f'{k}={v}' for k, v in params.items())}),\n"
+     code_str += "    ])\n"
+     code_str += "    # Apply the augmentation pipeline\n"
+     code_str += "    return pipeline(image=image)['image']\n\n"
+
+     # Function to process a batch of images
+     code_str += "# Function to process a batch of images\n"
+     code_str += "def process_batch(input_directory, output_directory):\n"
+     code_str += "    for filename in os.listdir(input_directory):\n"
+     code_str += "        if filename.lower().endswith(('.png', '.jpg', '.jpeg')):\n"
+     code_str += "            image_path = os.path.join(input_directory, filename)\n"
+     code_str += "            label_path = os.path.splitext(image_path)[0] + '.txt'\n\n"
+
+     code_str += "            # Read the image\n"
+     code_str += "            image = cv2.imread(image_path)\n\n"
+
+     # Include original image and label logic
+     if include_original:
+         code_str += "            # Include original image and label\n"
+         code_str += "            shutil.copy2(image_path, output_directory)\n"
+         code_str += "            shutil.copy2(label_path, output_directory)\n\n"
+
+     # Generate variations for each image and process them
+     code_str += "            # Generate variations for each image and process them\n"
+     code_str += f"            for variation in range({num_variations}):\n"
+     code_str += "                processed_image = process_image(image)\n\n"
+     code_str += "                # Save the processed image\n"
+     code_str += "                output_filename = f'processed_{os.path.splitext(filename)[0]}_{variation}{os.path.splitext(filename)[1]}'\n"
+     code_str += "                cv2.imwrite(os.path.join(output_directory, output_filename), processed_image)\n\n"
+     code_str += "                # Save the original label file for the processed image\n"
+     code_str += "                if os.path.exists(label_path):\n"
+     code_str += "                    shutil.copy2(label_path, os.path.join(output_directory, os.path.splitext(output_filename)[0] + '.txt'))\n\n"
+
+     # Execute the batch processing function
+     code_str += "# Execute the batch processing function with the specified parameters\n"
+     code_str += "process_batch(input_directory, output_directory)\n"
+
+     return code_str
+
+
+ # Function to create an image processing pipeline based on the selected techniques and their parameters
+ def create_image_processing_pipeline(
+     selected_image_processings, image_processing_params
+ ):
+     pipeline = []
+     for aug_name in selected_image_processings:
+         # Append the function call with its parameters to the pipeline
+         pipeline.append(
+             utils.apply_albumentation(image_processing_params[aug_name], aug_name)
+         )
+
+     # Compose all the image processings into one transformation
+     return A.Compose(pipeline)
+
+
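As a usage sketch (the technique name and parameters are illustrative and depend on what `utils.apply_albumentation` supports):

    transform = create_image_processing_pipeline(
        ["Blur"], {"Blur": {"blur_limit": 3, "p": 1.0}}
    )
    image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
    processed = transform(image=image)["image"]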
+ # Function to process images and labels, apply image processing techniques, and create a zip file with the results
+ @st.cache_resource(show_spinner=False)
+ def process_images_and_labels(
+     image_files,
+     label_files,
+     selected_image_processings,
+     _image_processings_params,
+     num_variations,
+     include_original,
+ ):
+     zip_buffer = io.BytesIO()  # Create an in-memory buffer for the zip file
+     # Initialize a repository to store processed image data
+     st.session_state["image_repository_preprocessing"] = {}
+     # Map original images to their processed versions
+     st.session_state["processed_image_mapping_procesing"] = {}
+     st.session_state["unique_images_names"] = []  # List to store unique image names
+
+     # Create progress bar and text elements in Streamlit
+     progress_bar = st.progress(0)
+     progress_text = st.empty()
+
+     with zipfile.ZipFile(
+         zip_buffer, mode="a", compression=zipfile.ZIP_DEFLATED, allowZip64=True
+     ) as zip_file:
+         # Compose all the image processings into one transformation
+         transform = create_image_processing_pipeline(
+             selected_image_processings, _image_processings_params
+         )
+
+         total_images = len(image_files) * num_variations
+         processed_count = 0  # Counter for processed images
+
+         # Iterate over each uploaded file
+         for image_name, image_file in image_files.items():
+             image_file.seek(0)  # Reset file pointer to start
+             file_bytes = np.asarray(bytearray(image_file.read()), dtype=np.uint8)
+             original_image = cv2.cvtColor(
+                 cv2.imdecode(file_bytes, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB
+             )
+             original_image_resized = utils.resize_image(original_image)
+
+             # Include original images and labels in the output if selected
+             if include_original:
+                 original_img_buffer = io.BytesIO()
+                 Image.fromarray(original_image).save(original_img_buffer, format="JPEG")
+                 zip_file.writestr(image_file.name, original_img_buffer.getvalue())
+
+                 # Save corresponding label file to zip if it exists
+                 label_file = label_files.get(image_name)
+                 if label_file is not None:
+                     label_file.seek(0)  # Reset the file pointer
+                     zip_file.writestr(f"{image_name}.txt", label_file.read())
+
+             original_file_name = image_file.name
+             st.session_state["unique_images_names"].append(original_file_name)
+             st.session_state["processed_image_mapping_procesing"][
+                 original_file_name
+             ] = []
+             st.session_state["image_repository_preprocessing"][image_file.name] = {
+                 "image": original_image_resized,
+                 "label": label_files.get(image_name),
+             }
+
+             # Apply image processing techniques and generate variations
+             for i in range(num_variations):
+                 random.seed(i)
+
+                 # Apply the image processing pipeline to the image
+                 processed_image = transform(image=original_image)["image"]
+
+                 img_buffer = io.BytesIO()
+                 Image.fromarray(processed_image).save(img_buffer, format="JPEG")
+                 processed_filename = f"processed_{image_name.split('.')[0]}_{i}.jpg"
+                 zip_file.writestr(processed_filename, img_buffer.getvalue())
+                 processed_image_resized = utils.resize_image(processed_image)
+
+                 # Save corresponding label file to zip if it exists
+                 label_file = label_files.get(image_name)
+                 if label_file is not None:
+                     label_file.seek(0)  # Reset the file pointer
+                     zip_file.writestr(
+                         f"processed_{image_name}_{i}.txt", label_file.read()
+                     )
+
+                 st.session_state["processed_image_mapping_procesing"][
+                     image_file.name
+                 ].append(processed_filename)
+                 st.session_state["image_repository_preprocessing"][
+                     processed_filename
+                 ] = {
+                     "image": processed_image_resized,
+                     "label": label_file,
+                 }
+
+                 processed_count += 1
+                 # Update progress bar and text
+                 progress_bar.progress(processed_count / total_images)
+                 progress_text.text(
+                     f"Processing image {processed_count} of {total_images}"
+                 )
+
+     # Remove the progress bar and text after processing is complete
+     progress_bar.empty()
+     progress_text.empty()
+
+     zip_buffer.seek(0)  # Reset buffer to start for download
+
+     st.session_state["zip_data_processing"] = zip_buffer.getvalue()
+
+
387
+ # Function to generate a downloadable file
388
+ def display_code_and_download_button(generated_code):
389
+ def generate_downloadable_file(code_str):
390
+ return code_str.encode("utf-8")
391
+
392
+ # Display the generated code in Streamlit with description and download button in columns
393
+ with st.expander("Plug and Play Code"):
394
+ col1, col2 = st.columns([7, 3])
395
+
396
+ with col1:
397
+ st.markdown(
398
+ """
399
+ ### Description of the Code Pipeline
400
+ """
401
+ )
402
+
403
+ st.markdown(
404
+ """
405
+ <div style='text-align: justify;'>
406
+ This code is a ready-to-use Python script for batch augmentation. It applies selected augmentation techniques to all images in a specified input directory and saves the processed images in an output directory.
407
+
408
+ **To use this script:**
409
+ - Ensure you have the necessary dependencies installed.
410
+ - Specify the input and output paths: Replace `'path/to/input'` with the path to your input images and `'path/to/output'` with the desired path for the processed images.
411
+ - The number of processed variations per image, the inclusion of the original images in the output, and the processing techniques with their parameters will be automatically set based on your selections.
412
+
413
+ ### Python Code
414
+ </div>
415
+ """,
416
+ unsafe_allow_html=True,
417
+ )
418
+
419
+ # Display python code
420
+ st.code(generated_code, language="python")
421
+
422
+ with col2:
423
+ # Create a button for downloading the Python file
424
+ st.download_button(
425
+ label="Download Python File",
426
+ data=generate_downloadable_file(generated_code),
427
+ file_name="image_processing_script.py",
428
+ mime="text/plain",
429
+ use_container_width=True,
430
+ )
Functions/model_training_functions.py ADDED
@@ -0,0 +1,1896 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Importing necessary libraries
2
+ import io
3
+ import os
4
+ import utils
5
+ import random
6
+ import shutil
7
+ import zipfile
8
+ import numpy as np
9
+ import pandas as pd
10
+ import streamlit as st
11
+ from ultralytics import YOLO
12
+ import plotly.graph_objs as go
13
+ from onnx.defs import onnx_opset_version
14
+ from plotly.subplots import make_subplots
15
+
16
+
17
+ # Function to get the dataset directory path based on the specified path type
18
+ def get_path(path_type):
19
+ main_directory_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
20
+
21
+ if path_type == "train":
22
+ return os.path.join(
23
+ main_directory_path,
24
+ "model_data",
25
+ "input_files",
26
+ "datasets",
27
+ "train",
28
+ )
29
+ elif path_type == "val":
30
+ return os.path.join(
31
+ main_directory_path,
32
+ "model_data",
33
+ "input_files",
34
+ "datasets",
35
+ "val",
36
+ )
37
+ elif path_type == "test":
38
+ return os.path.join(
39
+ main_directory_path,
40
+ "model_data",
41
+ "input_files",
42
+ "datasets",
43
+ "test",
44
+ )
45
+ elif path_type == "config":
46
+ return os.path.join(main_directory_path, "model_data", "input_files")
47
+ elif path_type == "models":
48
+ return os.path.join(main_directory_path, "model_data", "models")
49
+ elif path_type == "output":
50
+ return os.path.join(main_directory_path, "model_data", "output_files")
51
+ else:
52
+ raise ValueError(f"Invalid path_type: {path_type}")
53
+
54
+
55
+ # Function to check minimum images in training and validation set
56
+ def check_min_images(total_files, train_pct, val_pct, test_pct):
57
+ # Calculate raw counts based on percentages
58
+ train_count = int(total_files * train_pct / 100)
59
+ val_count = int(total_files * val_pct / 100)
60
+ test_count = int(total_files * test_pct / 100)
61
+
62
+ # Ensure that both train and validation have at least one file
63
+ if train_count < 1 or val_count < 1:
64
+ return False
65
+
66
+ return True
67
+
68
+
69
+ # Function to clear data a folders
70
+ def clear_data_folders():
71
+ base_path = "./model_data/input_files/datasets"
72
+ for folder in ["train", "test", "val"]:
73
+ for subfolder in ["images", "labels"]:
74
+ folder_path = os.path.join(base_path, folder, subfolder)
75
+ if os.path.exists(folder_path):
76
+ shutil.rmtree(folder_path)
77
+ os.makedirs(folder_path, exist_ok=True)
78
+
79
+
80
+ # Function to pairs image and label files based on their filenames
81
+ def pair_files(files):
82
+ paired_files = {}
83
+ for file in files:
84
+ # Split the filename into name and extension
85
+ file_name, file_ext = os.path.splitext(file.name)
86
+
87
+ # Initialize a dict for each unique file name
88
+ if file_name not in paired_files:
89
+ paired_files[file_name] = {"image": None, "label": None}
90
+
91
+ # Assign the file to its corresponding type (image or label) based on extension
92
+ if file_ext.lower() in [".jpg", ".png"]:
93
+ paired_files[file_name]["image"] = file
94
+ elif file_ext.lower() == ".txt":
95
+ paired_files[file_name]["label"] = file
96
+
97
+ return paired_files
98
+
99
+
100
+ # Function to split the paired files into training, testing, and validation sets based on specified percentages and saves them in corresponding folders
101
+ def split_and_save_files(paired_files, train_pct, test_pct):
102
+ base_path = "./model_data/input_files/datasets"
103
+ all_keys = list(paired_files.keys())
104
+ random.shuffle(all_keys)
105
+
106
+ # Determine the size of each dataset split
107
+ total_files = len(all_keys)
108
+ train_size = int(total_files * train_pct / 100)
109
+ test_size = int(total_files * test_pct / 100)
110
+
111
+ # Split the file keys into training, testing, and validation sets
112
+ train_keys = all_keys[:train_size]
113
+ test_keys = all_keys[train_size : train_size + test_size]
114
+ val_keys = all_keys[train_size + test_size :]
115
+
116
+ # Iterate through each split and save the files to their respective directories
117
+ for folder_name, keys in zip(
118
+ ["train", "test", "val"], [train_keys, test_keys, val_keys]
119
+ ):
120
+ for key in keys:
121
+ image_file = paired_files[key]["image"]
122
+ label_file = paired_files[key]["label"]
123
+ # Save the image and label files if they exist
124
+ if image_file:
125
+ save_file_to_folder(
126
+ image_file, os.path.join(base_path, folder_name, "images")
127
+ )
128
+ if label_file:
129
+ save_file_to_folder(
130
+ label_file, os.path.join(base_path, folder_name, "labels")
131
+ )
132
+
133
+
134
+ # Function to save an individual file to a specified folder
135
+ def save_file_to_folder(file, folder_path):
136
+ os.makedirs(folder_path, exist_ok=True)
137
+ file_path = os.path.join(folder_path, file.name)
138
+ with open(file_path, "wb") as f:
139
+ f.write(file.getbuffer())
140
+
141
+
142
+ # Function to save uploaded files to a specific folder within the base path
143
+ def save_files_to_folder(uploaded_files, folder_name):
144
+ # Define the base path for saving the files
145
+ base_path = "./model_data/input_files/datasets"
146
+
147
+ # Iterate through each uploaded file
148
+ for file in uploaded_files:
149
+ if file:
150
+ # Determine the file type based on file extension
151
+ file_type = (
152
+ "images"
153
+ if os.path.splitext(file.name)[1].lower() in [".jpg", ".png"]
154
+ else "labels"
155
+ )
156
+
157
+ # Save the file to the appropriate subfolder (images or labels)
158
+ save_file_to_folder(file, os.path.join(base_path, folder_name, file_type))
159
+
160
+
161
+ # Function to validate each line in the label file for bounding box data
162
+ def check_bboxes_label(label_file, class_dict):
163
+ for line in label_file:
164
+ try:
165
+ # Decode the line, strip whitespace, split into parts, and convert each part to float
166
+ class_id, x_center, y_center, width, height = map(
167
+ float, line.decode().strip().split()
168
+ )
169
+
170
+ # Check if bounding box coordinates and class ID are valid
171
+ if not (
172
+ 0 <= x_center <= 1
173
+ and 0 <= y_center <= 1
174
+ and 0 <= width <= 1
175
+ and 0 <= height <= 1
176
+ and class_id in class_dict.keys()
177
+ ):
178
+ # Return False if any condition is not met (invalid data)
179
+ return False
180
+
181
+ except Exception as e:
182
+ # Return False in case of any exception (e.g., parsing error)
183
+ return False
184
+
185
+ # Return True if all lines in the label file pass the validation
186
+ return True
187
+
188
+
189
+ # Function to validate each line in the label file for mask data
190
+ def check_masks_label(label_file, class_dict):
191
+ for line in label_file:
192
+ try:
193
+ # Decode the line and split into parts: class ID and points
194
+ parts = line.decode().strip().split()
195
+ class_id = int(
196
+ parts[0]
197
+ ) # Convert the first part to an integer for class ID
198
+ points = [
199
+ float(p) for p in parts[1:]
200
+ ] # Convert the remaining parts to float for coordinates
201
+
202
+ # Check if class ID exists in the class dictionary and all points are within [0, 1]
203
+ if not (class_id in class_dict.keys() and all(0 <= p <= 1 for p in points)):
204
+ return False # Return False if validation fails
205
+
206
+ except Exception as e:
207
+ # Return False in case of any exception (e.g., parsing error)
208
+ return False
209
+
210
+ return True # Return True if all lines in the label file pass the validation
211
+
212
+
213
+ # Function to read label from YOLO format
214
+ def read_label(file, selected_option, class_dict):
215
+ # Read the content of the file
216
+ file_content = file.readlines()
217
+
218
+ # Check and validate bounding box labels if the selected option is 'Bboxes'
219
+ if selected_option == "Bboxes":
220
+ return check_bboxes_label(file_content, class_dict) # Validate bbox labels
221
+
222
+ # Check and validate mask labels if the selected option is 'Masks'
223
+ elif selected_option == "Masks":
224
+ return check_masks_label(file_content, class_dict) # Validate mask labels
225
+
226
+ # Return False if the selected option is neither 'Bboxes' nor 'Masks'
227
+ return False
228
+
229
+
230
+ # Function to check for duplicates
231
+ def check_file_duplicates(file_names):
232
+ unique_names = set(file_names)
233
+ return len(unique_names) == len(file_names)
234
+
235
+
236
+ # Function to validates the uploaded image and label files
237
+ def validate_files(image_names, label_names):
238
+ # Check for duplicate filenames in both images and labels
239
+ if not check_file_duplicates(image_names) or not check_file_duplicates(label_names):
240
+ # Show warning if duplicates are found
241
+ st.warning(
242
+ "Duplicate file names detected. Please ensure each image and label has a unique name.",
243
+ icon="⚠️",
244
+ )
245
+ return False # Return False indicating validation failed
246
+
247
+ # Check if the number of images matches the number of labels
248
+ if len(image_names) != len(label_names):
249
+ # Show warning if counts don't match
250
+ st.warning(
251
+ "Count Mismatch: The number of uploaded images and labels does not match.",
252
+ icon="⚠️",
253
+ )
254
+ return False # Return False indicating validation failed
255
+
256
+ # Display a success message if the above checks pass
257
+ st.info(
258
+ f"Validated: {len(image_names)} images and labels successfully matched.",
259
+ icon="✅",
260
+ )
261
+ return True # Return True indicating successful validation
262
+
263
+
264
+ # Function to check labels format
265
+ @st.cache_resource(show_spinner=False)
266
+ def check_valid_labels(uploaded_files, selected_option, class_dict):
267
+ # Check if no files were uploaded
268
+ if len(uploaded_files) == 0:
269
+ st.warning("Please upload images and labels.", icon="⚠️")
270
+ return False
271
+
272
+ # Initialize lists to store names of image and label files
273
+ image_names, label_names = [], []
274
+
275
+ # Initialize a progress bar and progress text
276
+ progress_bar = st.progress(0)
277
+ progress_text = st.empty()
278
+ total_files = len(uploaded_files)
279
+
280
+ # Iterate over each uploaded file
281
+ for index, file in enumerate(uploaded_files):
282
+ # Reset the file pointer to the beginning
283
+ file.seek(0)
284
+
285
+ # Check file type and categorize as image or label
286
+ if file.type in ["image/jpeg", "image/png"]:
287
+ # Add to image names list if file is an image
288
+ image_names.append(file.name)
289
+ elif file.type == "text/plain":
290
+ # Read and validate label file
291
+ if not read_label(file, selected_option, class_dict):
292
+ # Show warning if label format or data is invalid
293
+ st.warning(
294
+ f"Invalid label format or data in file: {file.name}", icon="⚠️"
295
+ )
296
+ return False
297
+ # Add to label names list if file is a valid label
298
+ label_names.append(file.name)
299
+
300
+ # Update progress bar and display current progress
301
+ progress_percentage = (index + 1) / total_files
302
+ progress_bar.progress(progress_percentage)
303
+ progress_text.text(f"Validating file {index + 1} of {total_files}")
304
+
305
+ # Remove progress bar and progress text after processing
306
+ progress_bar.empty()
307
+ progress_text.empty()
308
+
309
+ # Validate if all images have corresponding labels and vice versa
310
+ return validate_files(image_names, label_names)
311
+
312
+
313
+ # Function to get training, validation and export configurations
314
+ def get_training_validation_export_configuration(selected_training):
315
+ with st.expander("Training Configuration"):
316
+ # User Instruction for Default Values
317
+ st.markdown(
318
+ """
319
+ <div style='text-align: justify;'>
320
+ <b>User Instructions:</b> If you are unsure about the specific values to use for training parameters, it is
321
+ recommended to stick with the default values provided. These defaults are carefully chosen to provide a good balance
322
+ between performance and resource utilization for most scenarios. You can always come back and tweak these settings
323
+ once you have more experience or specific requirements for your model training.
324
+ </div>
325
+ """,
326
+ unsafe_allow_html=True,
327
+ )
328
+
329
+ # Padding
330
+ utils.top_padding(2)
331
+
332
+ # Training Configuration
333
+ st.markdown("### Training Configuration")
334
+
335
+ # Model Selection
336
+ st.write("**Model Selection**")
337
+ selected_model = st.selectbox(
338
+ "Choose a YOLOv8 model variant", list(utils.models_info.keys())
339
+ )
340
+ model_spec = utils.models_info[selected_model]
341
+ spec_string = (
342
+ "<div style='text-align: justify;'>"
343
+ f"The selected model, <b>{selected_model}</b>, is benchmarked on an image size of 640x640 pixels. It has a Mean Average Precision (mAPval) of <b>{model_spec['mAPval']}</b>, "
344
+ f"operates with a speed of <b>{model_spec['speed_cpu']} ms</b> on CPU (ONNX) and <b>{model_spec['speed_gpu']} ms</b> on GPU (TensorRT). "
345
+ f"It consists of approximately <b>{model_spec['params']} million</b> parameters and requires about <b>{model_spec['flops']} billion</b> Floating Point Operations (FLOPs)."
346
+ "</div>"
347
+ )
348
+ st.markdown(spec_string, unsafe_allow_html=True)
349
+
350
+ # Spacer
351
+ st.markdown("---")
352
+
353
+ # Time Configuration
354
+ st.write("**Time Configuration**")
355
+ col1_time, col2_time = st.columns([1, 3])
356
+ with col1_time:
357
+ top_padding_time = st.container()
358
+ time_allow = st.checkbox("Enable Time", value=False)
359
+ if time_allow:
360
+ with top_padding_time:
361
+ utils.top_padding(2)
362
+ time = col2_time.number_input(
363
+ "Time (hours)", min_value=1, max_value=100, value=1, step=1
364
+ )
365
+ else:
366
+ time = None
367
+ st.markdown(
368
+ "<div style='text-align: justify;'>Set the training duration in hours. This option overrides the epochs setting. Useful for limiting training time in scenarios with constrained resources.</div>",
369
+ unsafe_allow_html=True,
370
+ )
371
+
372
+ # Spacer
373
+ st.markdown("---")
374
+
375
+ # Epochs Configuration
376
+ st.write("**Epochs Configuration**")
377
+ epochs = st.number_input(
378
+ "Epochs", min_value=1, max_value=1000, value=50, step=10
379
+ )
380
+ st.markdown(
381
+ "<div style='text-align: justify;'>Define the number of epochs for the training process. An epoch represents a complete pass over the entire dataset. More epochs can improve accuracy but increase training time.</div>",
382
+ unsafe_allow_html=True,
383
+ )
384
+
385
+ # Spacer
386
+ st.markdown("---")
387
+
388
+ # Patience Configuration
389
+ st.write("**Patience Configuration**")
390
+ col1_patience, col2_patience = st.columns([1, 3])
391
+ with col1_patience:
392
+ top_padding_patience = st.container()
393
+ patience_allow = st.checkbox("Enable Patience", value=False)
394
+ if patience_allow:
395
+ with top_padding_patience:
396
+ utils.top_padding(2)
397
+ patience = col2_patience.number_input(
398
+ "Patience (epochs)", min_value=5, max_value=50, value=5, step=1
399
+ )
400
+ else:
401
+ patience = None
402
+ st.markdown(
403
+ "<div style='text-align: justify;'>Configure the early stopping mechanism. Patience denotes the number of epochs to wait for improvement in performance before stopping the training, helping to avoid overfitting.</div>",
404
+ unsafe_allow_html=True,
405
+ )
406
+
407
+ # Spacer
408
+ st.markdown("---")
409
+
410
+ # Batch Size Configuration
411
+ st.write("**Batch Size Configuration**")
412
+ batch = st.number_input(
413
+ "Batch Size", min_value=-1, max_value=128, value=-1, step=1
414
+ )
415
+ st.markdown(
416
+ "<div style='text-align: justify;'>Determine the number of images processed together in one pass (batch). A larger batch size can lead to faster training but requires more memory. Use -1 for automatic batch sizing.</div>",
417
+ unsafe_allow_html=True,
418
+ )
419
+
420
+ # Spacer
421
+ st.markdown("---")
422
+
423
+ # Image Size Configuration
424
+ st.write("**Image Size Configuration**")
425
+ imgsz = st.number_input(
426
+ "Image Size (pixels)", min_value=64, max_value=4096, value=640, step=32
427
+ )
428
+ st.markdown(
429
+ "<div style='text-align: justify;'>Specify the size of the input images. Larger images can capture more details but require more computational resources. The size is typically a square dimension, like 640x640 pixels.</div>",
430
+ unsafe_allow_html=True,
431
+ )
432
+
433
+ # Spacer
434
+ st.markdown("---")
435
+
436
+ # Cache Configuration
437
+ st.write("**Cache Configuration**")
438
+ cache = st.selectbox("Cache Option", ["False", "True/ram", "disk"])
439
+ st.markdown(
440
+ "<div style='text-align: justify;'>Choose a caching method for data loading to speed up training. 'True/ram' caches data in RAM, 'disk' caches on disk, and 'False' disables caching.</div>",
441
+ unsafe_allow_html=True,
442
+ )
443
+
444
+ # Spacer
445
+ st.markdown("---")
446
+
447
+ # Optimizer Configuration
448
+ st.write("**Optimizer Configuration**")
449
+ optimizer = st.selectbox(
450
+ "Optimizer",
451
+ ["SGD", "Adam", "Adamax", "AdamW", "NAdam", "RAdam", "RMSProp", "auto"],
452
+ index=7,
453
+ )
454
+ st.markdown(
455
+ "<div style='text-align: justify;'>Select the optimizer for training. The optimizer adjusts weights to minimize the loss function. Choices include SGD, Adam, and others, with 'auto' selecting automatically based on the model.</div>",
456
+ unsafe_allow_html=True,
457
+ )
458
+
459
+ # Spacer
460
+ st.markdown("---")
461
+
462
+ # AMP Configuration
463
+ st.write("**AMP Configuration**")
464
+ amp = st.checkbox("Enable AMP", value=True)
465
+ st.markdown(
466
+ "<div style='text-align: justify;'>Enable Automatic Mixed Precision (AMP) to accelerate training on compatible hardware. AMP uses lower precision to reduce memory usage and speed up computations.</div>",
467
+ unsafe_allow_html=True,
468
+ )
469
+
470
+ # Spacer
471
+ st.markdown("---")
472
+
473
+ # Deterministic Mode Configuration
474
+ st.write("**Deterministic Mode Configuration**")
475
+ deterministic = st.checkbox("Enable Deterministic Mode", value=False)
476
+ st.markdown(
477
+ "<div style='text-align: justify;'>Activate deterministic mode to ensure reproducible results. This mode might slow down the training but is useful for experimentation and debugging.</div>",
478
+ unsafe_allow_html=True,
479
+ )
480
+
481
+ # Spacer
482
+ st.markdown("---")
483
+
484
+ # Rectangular Training Configuration
485
+ st.write("**Rectangular Training Configuration**")
486
+ rect = st.checkbox("Enable Rectangular Training", value=False)
487
+ st.markdown(
488
+ "<div style='text-align: justify;'>Enable rectangular training to process batches with minimal padding by reshaping images. This can lead to performance improvements but may affect accuracy.</div>",
489
+ unsafe_allow_html=True,
490
+ )
491
+
492
+ # Spacer
493
+ st.markdown("---")
494
+
495
+ # Cosine Learning Rate Scheduler Configuration
496
+ st.write("**Cosine Learning Rate Scheduler**")
497
+ cos_lr = st.checkbox("Use Cosine LR Scheduler", value=False)
498
+ st.markdown(
499
+ "<div style='text-align: justify;'>Use a cosine learning rate scheduler to adjust the learning rate following a cosine curve, potentially leading to better convergence during training.</div>",
500
+ unsafe_allow_html=True,
501
+ )
502
+
503
+ # Spacer
504
+ st.markdown("---")
505
+
506
+ # Freeze Layer Configuration
507
+ st.write("**Freeze Layer Configuration**")
508
+ col1_freeze, col2_freeze = st.columns([1, 3])
509
+ with col1_freeze:
510
+ top_padding_freeze = st.container()
511
+ freeze_allow = st.checkbox("Enable Freeze Layers", value=False)
512
+ if freeze_allow:
513
+ with top_padding_freeze:
514
+ utils.top_padding(2)
515
+ freeze = col2_freeze.number_input(
516
+ "Freeze Layers",
517
+ min_value=1,
518
+ max_value=1000,
519
+ value=10,
520
+ placeholder="Enter number of layers",
521
+ )
522
+ else:
523
+ freeze = None
524
+ st.markdown(
525
+ "<div style='text-align: justify;'>Enable freezing the initial layers of the model during training. Specify the number of layers to freeze or a comma-separated list of specific layer indices. Useful for fine-tuning pre-trained models without modifying early layers.</div>",
526
+ unsafe_allow_html=True,
527
+ )
528
+
529
+ # Spacer
530
+ st.markdown("---")
531
+
532
+ # Initial Learning Rate Configuration
533
+ st.write("**Initial Learning Rate (lr0)**")
534
+ lr0 = st.number_input(
535
+ "Initial Learning Rate (lr0)",
536
+ min_value=0.00001,
537
+ max_value=1.0,
538
+ value=0.01,
539
+ format="%.5f",
540
+ )
541
+ st.markdown(
542
+ "<div style='text-align: justify;'>Specify the initial learning rate (lr0) for the training process. The initial rate is crucial as it determines the starting step size for weight updates. A well-chosen initial rate helps in achieving a balance between fast convergence and overshooting the optimal solution.</div>",
543
+ unsafe_allow_html=True,
544
+ )
545
+
546
+ # Spacer
547
+ st.markdown("---")
548
+
549
+ # Final Learning Rate Configuration
550
+ st.write("**Final Learning Rate (lrf)**")
551
+ lrf = st.number_input(
552
+ "Final Learning Rate (lrf)",
553
+ min_value=0.00001,
554
+ max_value=1.0,
555
+ value=0.01,
556
+ format="%.5f",
557
+ )
558
+ st.markdown(
559
+ "<div style='text-align: justify;'>Determine the final learning rate, which is a factor (lrf) of the initial learning rate (lr0). This parameter is used to adjust the learning rate over the course of training, gradually decreasing it to fine-tune model weights and stabilize training as it approaches the minimum of the loss function.</div>",
560
+ unsafe_allow_html=True,
561
+ )
562
+
563
+ # Spacer
564
+ st.markdown("---")
565
+
566
+ # Momentum Configuration
567
+ st.write("**Momentum Configuration**")
568
+ momentum = st.number_input(
569
+ "Momentum", min_value=0.0, max_value=1.0, value=0.937, format="%.3f"
570
+ )
571
+ st.markdown(
572
+ "<div style='text-align: justify;'>Set the momentum value for the optimizer. Momentum helps in accelerating the optimizer in the relevant direction and dampens oscillations, facilitating faster convergence.</div>",
573
+ unsafe_allow_html=True,
574
+ )
575
+
576
+ # Spacer
577
+ st.markdown("---")
578
+
579
+ # Weight Decay Configuration
580
+ st.write("**Weight Decay Configuration**")
581
+ weight_decay = st.number_input(
582
+ "Weight Decay", min_value=0.0, max_value=0.1, value=0.0005, format="%.5f"
583
+ )
584
+ st.markdown(
585
+ "<div style='text-align: justify;'>Specify the weight decay, a regularization technique that adds a small penalty to the loss function for larger weights. It helps in preventing overfitting by encouraging simpler models.</div>",
586
+ unsafe_allow_html=True,
587
+ )
588
+
589
+ # Spacer
590
+ st.markdown("---")
591
+
592
+ # Warmup Epochs Configuration
593
+ st.write("**Warmup Epochs Configuration**")
594
+ warmup_epochs = st.number_input(
595
+ "Warmup Epochs", min_value=0.0, max_value=10.0, value=3.0, step=0.1
596
+ )
597
+ st.markdown(
598
+ "<div style='text-align: justify;'>Define the number of warmup epochs. During warmup, the learning rate gradually increases to its initial value, which helps in stabilizing the training process in its early stages.</div>",
599
+ unsafe_allow_html=True,
600
+ )
601
+
602
+ # Spacer
603
+ st.markdown("---")
604
+
605
+ # Warmup Momentum Configuration
606
+ st.write("**Warmup Momentum Configuration**")
607
+ warmup_momentum = st.number_input(
608
+ "Warmup Momentum", min_value=0.0, max_value=1.0, value=0.8, format="%.1f"
609
+ )
610
+ st.markdown(
611
+ "<div style='text-align: justify;'>Configure the momentum during the warmup phase. A lower momentum at the start can help in stabilizing the optimization process before reaching the specified momentum for the remaining epochs.</div>",
612
+ unsafe_allow_html=True,
613
+ )
614
+
615
+ # Spacer
616
+ st.markdown("---")
617
+
618
+ # Warmup Bias Learning Rate Configuration
619
+ st.write("**Warmup Bias Learning Rate Configuration**")
620
+ warmup_bias_lr = st.number_input(
621
+ "Warmup Bias Learning Rate",
622
+ min_value=0.0,
623
+ max_value=1.0,
624
+ value=0.1,
625
+ format="%.1f",
626
+ )
627
+ st.markdown(
628
+ "<div style='text-align: justify;'>Adjust the bias learning rate during the warmup period. This parameter can be tuned to manage the initial learning rate specifically for the bias parameters in the early training phase.</div>",
629
+ unsafe_allow_html=True,
630
+ )
631
+
632
+ # Spacer
633
+ st.markdown("---")
634
+
635
+ # Box Loss Gain Configuration
636
+ st.write("**Box Loss Gain Configuration**")
637
+ box = st.number_input(
638
+ "Box Loss Gain", min_value=0.0, max_value=10.0, value=7.5, step=0.1
639
+ )
640
+ st.markdown(
641
+ "<div style='text-align: justify;'>Configure the gain factor for the box loss. This gain helps in adjusting the importance of the box size and location accuracy in the loss function, affecting how the model prioritizes bounding box precision.</div>",
642
+ unsafe_allow_html=True,
643
+ )
644
+
645
+ # Spacer
646
+ st.markdown("---")
647
+
648
+ # Class Loss Gain Configuration
649
+ st.write("**Class Loss Gain Configuration**")
650
+ cls = st.number_input(
651
+ "Class Loss Gain", min_value=0.0, max_value=10.0, value=0.5, step=0.1
652
+ )
653
+ st.markdown(
654
+ "<div style='text-align: justify;'>Set the gain factor for the class loss. This parameter scales the contribution of class prediction accuracy in the total loss, influencing how the model prioritizes correct class identification.</div>",
655
+ unsafe_allow_html=True,
656
+ )
657
+
658
+ # Spacer
659
+ st.markdown("---")
660
+
661
+ # DFL Loss Gain Configuration
662
+ st.write("**DFL Loss Gain Configuration**")
663
+ dfl = st.number_input(
664
+ "DFL Loss Gain", min_value=0.0, max_value=10.0, value=1.5, step=0.1
665
+ )
666
+ st.markdown(
667
+ "<div style='text-align: justify;'>Determine the gain factor for the DFL loss. Adjusting this gain influences the model's focus on the Directional Focal Loss component, which is critical for precise object localization and classification.</div>",
668
+ unsafe_allow_html=True,
669
+ )
670
+
671
+ # Spacer
672
+ st.markdown("---")
673
+
674
+ # Label Smoothing Configuration
675
+ st.write("**Label Smoothing Configuration**")
676
+ label_smoothing = st.number_input(
677
+ "Label Smoothing (fraction)",
678
+ min_value=0.0,
679
+ max_value=1.0,
680
+ value=0.0,
681
+ format="%.1f",
682
+ )
683
+ st.markdown(
684
+ "<div style='text-align: justify;'>Specify the label smoothing value, a technique that introduces softening to the target labels. It promotes model generalization and reduces the impact of noisy labels on the training process.</div>",
685
+ unsafe_allow_html=True,
686
+ )
687
+
688
+ # Spacer
689
+ st.markdown("---")
690
+
691
+ # Nominal Batch Size Configuration
692
+ st.write("**Nominal Batch Size Configuration**")
693
+ nbs = st.number_input(
694
+ "Nominal Batch Size", min_value=1, max_value=128, value=64, step=1
695
+ )
696
+ st.markdown(
697
+ "<div style='text-align: justify;'>Set the nominal batch size, which is used for normalizing the loss. This size does not affect the actual batch size but is used to scale the loss to a standard reference batch size.</div>",
698
+ unsafe_allow_html=True,
699
+ )
700
+
701
+ # Spacer
702
+ st.markdown("---")
703
+
704
+ # Overlap Mask Configuration
705
+ st.write("**Overlap Mask Configuration**")
706
+ overlap_mask = st.checkbox("Masks Overlap during Training", value=True)
707
+ st.markdown(
708
+ "<div style='text-align: justify;'>Choose whether to allow masks to overlap during instance segmentation training. Overlapping can lead to more precise segmentation but may increase complexity.</div>",
709
+ unsafe_allow_html=True,
710
+ )
711
+
712
+ # Spacer
713
+ st.markdown("---")
714
+
715
+ # Mask Ratio Configuration
716
+ st.write("**Mask Ratio Configuration**")
717
+ mask_ratio = st.number_input(
718
+ "Mask Downsample Ratio", min_value=1, max_value=10, value=4, step=1
719
+ )
720
+ st.markdown(
721
+ "<div style='text-align: justify;'>Set the downsample ratio for masks in instance segmentation. A higher ratio reduces the mask resolution, which can speed up computations but might decrease segmentation accuracy.</div>",
722
+ unsafe_allow_html=True,
723
+ )
724
+
725
+ # Spacer
726
+ st.markdown("---")
727
+
728
+ # Dropout Configuration
729
+ st.write("**Dropout Configuration**")
730
+ dropout = st.number_input(
731
+ "Dropout Regularization",
732
+ min_value=0.0,
733
+ max_value=1.0,
734
+ value=0.0,
735
+ format="%.1f",
736
+ )
737
+ st.markdown(
738
+ "<div style='text-align: justify;'>Configure the dropout rate, which randomly disables a proportion of neurons during training. This prevents the model from relying too much on certain features and promotes better generalization.</div>",
739
+ unsafe_allow_html=True,
740
+ )
741
+
742
+ # Spacer
743
+ st.markdown("---")
744
+
745
+ # Validation/Test Configuration
746
+ st.write("**Validation/Test Configuration**")
747
+ val = st.checkbox("Validate/Test during Training", value=True)
748
+ st.markdown(
749
+ "<div style='text-align: justify;'>Decide whether to perform validation and testing during the training process. Regular validation helps monitor model performance and adjust training accordingly.</div>",
750
+ unsafe_allow_html=True,
751
+ )
752
+
753
+ # Spacer
754
+ st.markdown("---")
755
+
756
+ # Save Plots Configuration
757
+ st.write("**Save Plots Configuration**")
758
+ plots = st.checkbox("Save Plots and Images during Training", value=True)
759
+ st.markdown(
760
+ "<div style='text-align: justify;'>Enable saving of plots and images during training. This feature provides visual insights into the training progress and helps in diagnosing model performance across epochs.</div>",
761
+ unsafe_allow_html=True,
762
+ )
763
+
764
+ # Padding
765
+ utils.top_padding(2)
766
+
767
+ with st.expander("Validation Configuration"):
768
+ # User Instruction for Default Values
769
+ st.markdown(
770
+ """
771
+ <div style='text-align: justify;'>
772
+ <b>User Instructions:</b> If you are unsure about the specific values to use for validation parameters, it is
773
+ recommended to stick with the default values provided. These defaults are carefully chosen to provide a good balance
774
+ between performance and resource utilization for most scenarios. You can always come back and tweak these settings
775
+ once you have more experience or specific requirements for your model validation.
776
+ </div>
777
+ """,
778
+ unsafe_allow_html=True,
779
+ )
780
+
781
+ # Padding
782
+ utils.top_padding(2)
783
+
784
+ # Validation Configuration
785
+ st.markdown("### Validation Configuration")
786
+
787
+ # Object Confidence Threshold
788
+ st.write("**Object Confidence Threshold**")
789
+ conf = st.number_input(
790
+ "Confidence Threshold",
791
+ min_value=0.0,
792
+ max_value=1.0,
793
+ value=0.001,
794
+ format="%.3f",
795
+ )
796
+ st.markdown(
797
+ "<div style='text-align: justify;'>Set the confidence threshold for object detection. This threshold filters out detections with lower confidence, reducing false positives and focusing on more likely object detections.</div>",
798
+ unsafe_allow_html=True,
799
+ )
800
+
801
+ # Spacer
802
+ st.markdown("---")
803
+
804
+ # Intersection Over Union (IoU) Threshold
805
+ st.write("**IoU Threshold for NMS**")
806
+ iou = st.number_input(
807
+ "IoU Threshold", min_value=0.0, max_value=1.0, value=0.6, format="%.1f"
808
+ )
809
+ st.markdown(
810
+ "<div style='text-align: justify;'>Define the IoU threshold for Non-Maximum Suppression. NMS is used to refine the bounding boxes by eliminating redundancies and retaining the most probable ones.</div>",
811
+ unsafe_allow_html=True,
812
+ )
813
+
814
+ # Spacer
815
+ st.markdown("---")
816
+
817
+ # Maximum Number of Detections
818
+ st.write("**Maximum Number of Detections**")
819
+ max_det = st.number_input(
820
+ "Max Detections", min_value=1, max_value=1000, value=300, step=1
821
+ )
822
+ st.markdown(
823
+ "<div style='text-align: justify;'>Limit the maximum number of detections per image. This setting is crucial for controlling the computational load and focusing the model on the most confident and relevant detections.</div>",
824
+ unsafe_allow_html=True,
825
+ )
826
+
827
+ # Spacer
828
+ st.markdown("---")
829
+
830
+ # Use Half Precision
831
+ st.write("**Use Half Precision (FP16)**")
832
+ half = st.checkbox("Enable Half Precision", value=True)
833
+ st.markdown(
834
+ "<div style='text-align: justify;'>Enable half precision (FP16) training for enhanced performance on compatible GPUs. It reduces memory requirements and accelerates computation, beneficial for larger models and datasets.</div>",
835
+ unsafe_allow_html=True,
836
+ )
837
+
838
+ # Padding
839
+ utils.top_padding(2)
840
+
841
+ with st.expander("Export Configuration"):
842
+ # User Instruction for Default Values
843
+ st.markdown(
844
+ """
845
+ <div style='text-align: justify;'>
846
+ <b>User Instructions:</b> If you are unsure about the specific values to use for export parameters, it is
847
+ recommended to stick with the default values provided. These defaults are carefully chosen to provide a good balance
848
+ between performance and resource utilization for most scenarios. You can always come back and tweak these settings
849
+ once you have more experience or specific requirements for your model export.
850
+ </div>
851
+ """,
852
+ unsafe_allow_html=True,
853
+ )
854
+
855
+ # Padding
856
+ utils.top_padding(2)
857
+
858
+ # Validation Configuration
859
+ st.markdown("### Export Configuration")
860
+
861
+ # Select Export Format
862
+ st.write("**Export Format**")
863
+ export_format = st.selectbox(
864
+ "Select Export Format",
865
+ [
866
+ "Only PyTorch",
867
+ "TorchScript",
868
+ "ONNX",
869
+ "OpenVINO",
870
+ "TensorRT",
871
+ "CoreML",
872
+ "TF SavedModel",
873
+ "TF GraphDef",
874
+ "TF Lite",
875
+ "TF Edge TPU",
876
+ "TF.js",
877
+ "PaddlePaddle",
878
+ "ncnn",
879
+ ],
880
+ )
881
+
882
+ # Dynamically generate description
883
+ if export_format == "Only PyTorch":
884
+ st.markdown(
885
+ """
886
+ <div style='text-align: justify;'>
887
+ You have selected <b>PyTorch</b> as the export format.
888
+ This will export the model in the standard PyTorch <code>.pt</code> format.
889
+ There are no additional format-specific parameters to consider for this selection.
890
+ The exported model will be the same as selected during training.
891
+ </div>
892
+ """,
893
+ unsafe_allow_html=True,
894
+ )
895
+ else:
896
+ format_info = utils.export_formats[export_format]
897
+
898
+ # Handling additional arguments
899
+ if len(format_info["arguments"]) > 0:
900
+ additional_arguments = ", ".join(format_info["arguments"])
901
+ arguments_info = f"Consider the following arguments for the <b>{export_format}</b> format: {additional_arguments}."
902
+ else:
903
+ arguments_info = (
904
+ "No additional parameters need to be considered for this format."
905
+ )
906
+
907
+ st.markdown(
908
+ f"""
909
+ <div style='text-align: justify;'>
910
+ You have selected <b>{export_format}</b> as the export format. Along with the PyTorch model,
911
+ this selection will also export the model in the <b>{export_format}</b> format. The image size of
912
+ the exported model will be the same as selected during training. {arguments_info}
913
+ </div>
914
+ """,
915
+ unsafe_allow_html=True,
916
+ )
917
+
918
+ # Spacer
919
+ st.markdown("---")
920
+
921
+ # Use Keras for TF SavedModel export
922
+ st.write("**Use Keras for TF SavedModel Export**")
923
+ keras = st.checkbox("Enable Keras", value=False)
924
+ st.markdown(
925
+ "<div style='text-align: justify;'>Enabling Keras optimizes the TensorFlow SavedModel export for compatibility with the Keras API, making it easier to work with in Keras-centric workflows.</div>",
926
+ unsafe_allow_html=True,
927
+ )
928
+
929
+ # Spacer
930
+ st.markdown("---")
931
+
932
+ # Optimize for mobile (TorchScript)
933
+ st.write("**Optimize TorchScript for Mobile**")
934
+ optimize = st.checkbox("Enable Optimization", value=False)
935
+ st.markdown(
936
+ "<div style='text-align: justify;'>Optimizing for mobile reduces the model size and computational needs, enhancing performance on mobile devices with limited resources.</div>",
937
+ unsafe_allow_html=True,
938
+ )
939
+
940
+ # Spacer
941
+ st.markdown("---")
942
+
943
+ # FP16 quantization
944
+ st.write("**FP16 Quantization**")
945
+ half = st.checkbox("Enable FP16 Quantization", value=False)
946
+ st.markdown(
947
+ "<div style='text-align: justify;'>FP16 quantization reduces model size and speeds up inference, especially on GPUs with Tensor Cores, while maintaining model accuracy.</div>",
948
+ unsafe_allow_html=True,
949
+ )
950
+
951
+ # Spacer
952
+ st.markdown("---")
953
+
954
+ # INT8 quantization
955
+ st.write("**INT8 Quantization**")
956
+ int8 = st.checkbox("Enable INT8 Quantization", value=False)
957
+ st.markdown(
958
+ "<div style='text-align: justify;'>INT8 quantization further reduces model size and inference time, ideal for edge devices, at the cost of a slight decrease in accuracy.</div>",
959
+ unsafe_allow_html=True,
960
+ )
961
+
962
+ # Spacer
963
+ st.markdown("---")
964
+
965
+ # Dynamic axes for ONNX/TensorRT
966
+ st.write("**Dynamic Axes for ONNX/TensorRT**")
967
+ dynamic = st.checkbox("Enable Dynamic Axes", value=False)
968
+ st.markdown(
969
+ "<div style='text-align: justify;'>Dynamic axes allow the ONNX/TensorRT models to handle variable input sizes, increasing the model's flexibility in deployment.</div>",
970
+ unsafe_allow_html=True,
971
+ )
972
+
973
+ # Spacer
974
+ st.markdown("---")
975
+
976
+ # Simplify model for ONNX/TensorRT
977
+ st.write("**Simplify Model for ONNX/TensorRT**")
978
+ simplify = st.checkbox("Enable Model Simplification", value=False)
979
+ st.markdown(
980
+ "<div style='text-align: justify;'>Simplification optimizes the ONNX/TensorRT models by removing redundant operations, improving efficiency without impacting accuracy.</div>",
981
+ unsafe_allow_html=True,
982
+ )
983
+
984
+ # Spacer
985
+ st.markdown("---")
986
+
987
+ # ONNX Opset Version Configuration
988
+ st.write("**ONNX Opset Version Configuration**")
989
+ col1_opset, col2_opset = st.columns([1, 3])
990
+
991
+ with col1_opset:
992
+ top_padding_opset = st.container()
993
+ opset_allow = st.checkbox("Specify Opset Version", value=False)
994
+
995
+ if opset_allow:
996
+ with top_padding_opset:
997
+ utils.top_padding(2)
998
+ # Create a range of opset versions for the dropdown
999
+ opset_versions = list(range(1, onnx_opset_version() + 1))
1000
+
1001
+ with col2_opset:
1002
+ opset = st.selectbox(
1003
+ "Select Opset Version",
1004
+ opset_versions,
1005
+ index=len(opset_versions) - 1,
1006
+ )
1007
+ else:
1008
+ opset = None
1009
+
1010
+ st.markdown(
1011
+ "<div style='text-align: justify;'>Select the ONNX opset version for the export. "
1012
+ "Specifying an opset version can ensure compatibility with specific ONNX versions. "
1013
+ "The latest version is recommended to ensure the most up-to-date features and optimizations. "
1014
+ "If unsure, leave the checkbox unchecked to use the default opset version.</div>",
1015
+ unsafe_allow_html=True,
1016
+ )
1017
+
1018
+ # Spacer
1019
+ st.markdown("---")
1020
+
1021
+ # TensorRT workspace size
1022
+ st.write("**TensorRT Workspace Size (GB)**")
1023
+ workspace = st.number_input(
1024
+ "Workspace Size", min_value=1, max_value=32, value=4, step=1
1025
+ )
1026
+ st.markdown(
1027
+ "<div style='text-align: justify;'>Set the TensorRT workspace size in GB. A larger workspace can lead to more optimized models but requires more memory.</div>",
1028
+ unsafe_allow_html=True,
1029
+ )
1030
+
1031
+ # Spacer
1032
+ st.markdown("---")
1033
+
1034
+ # Add NMS for CoreML
1035
+ st.write("**Add NMS for CoreML**")
1036
+ nms = st.checkbox("Enable NMS", value=False)
1037
+ st.markdown(
1038
+ "<div style='text-align: justify;'>Enabling NMS (Non-Maximum Suppression) for CoreML models helps in reducing overlapping bounding boxes and improves the clarity of object detection results.</div>",
1039
+ unsafe_allow_html=True,
1040
+ )
1041
+
1042
+ # Padding
1043
+ utils.top_padding(2)
1044
+
1045
+ if selected_training == "Object Detection":
1046
+ model_path = os.path.join(
1047
+ get_path("models"), selected_model.lower() + ".pt"
1048
+ )
1049
+ task = "detect"
1050
+ elif selected_training == "Instance Segmentation":
1051
+ model_path = os.path.join(
1052
+ get_path("models"), selected_model.lower() + "-seg.pt"
1053
+ )
1054
+ task = "segment"
1055
+
1056
+ export_settings = {
1057
+ "format": None if export_format == "Only PyTorch" else export_format,
1058
+ "keras": keras,
1059
+ "optimize": optimize,
1060
+ "half": half,
1061
+ "int8": int8,
1062
+ "dynamic": dynamic,
1063
+ "simplify": simplify,
1064
+ "opset": opset,
1065
+ "workspace": workspace,
1066
+ "nms": nms,
1067
+ }
1068
+
1069
+ return {
1070
+ "model_path": model_path,
1071
+ "task": task,
1072
+ "model": selected_model,
1073
+ "time": time,
1074
+ "epochs": epochs,
1075
+ "patience": patience,
1076
+ "batch": batch,
1077
+ "imgsz": imgsz,
1078
+ "cache": cache,
1079
+ "optimizer": optimizer,
1080
+ "amp": amp,
1081
+ "deterministic": deterministic,
1082
+ "rect": rect,
1083
+ "cos_lr": cos_lr,
1084
+ "freeze": freeze,
1085
+ "lr0": lr0,
1086
+ "lrf": lrf,
1087
+ "momentum": momentum,
1088
+ "weight_decay": weight_decay,
1089
+ "warmup_epochs": warmup_epochs,
1090
+ "warmup_momentum": warmup_momentum,
1091
+ "warmup_bias_lr": warmup_bias_lr,
1092
+ "box": box,
1093
+ "cls": cls,
1094
+ "dfl": dfl,
1095
+ "label_smoothing": label_smoothing,
1096
+ "nbs": nbs,
1097
+ "overlap_mask": overlap_mask,
1098
+ "mask_ratio": mask_ratio,
1099
+ "dropout": dropout,
1100
+ "val": val,
1101
+ "plots": plots,
1102
+ "conf": conf,
1103
+ "iou": iou,
1104
+ "max_det": max_det,
1105
+ "half": half,
1106
+ "export_settings": export_settings,
1107
+ }
1108
+
1109
+
1110
+ # Function to generate python code for model training
1111
+ def generate_python_code_model_training(training_configuration):
1112
+ # Copy the original configuration and update with additional parameters
1113
+ training_configuration_code = training_configuration.copy()
1114
+ training_configuration_code["data"] = r".\config.yaml" # Path to config file
1115
+ training_configuration_code["save_dir"] = r".\output\train" # Output directory
1116
+ training_configuration_code["pretrained"] = True # Use a pretrained model
1117
+ training_configuration_code["save"] = True # Save the trained model
1118
+ training_configuration_code["save_period"] = -1 # Save period configuration
1119
+ training_configuration_code["augment"] = False # Augmentation setting
1120
+ training_configuration_code["seed"] = 0 # Seed for reproducibility
1121
+ training_configuration_code["verbose"] = True # Verbose output
1122
+ training_configuration_code["single_cls"] = False # Single class setting
1123
+ training_configuration_code["resume"] = False # Resume training setting
1124
+ training_configuration_code["exist_ok"] = True # Overwrite existing files
1125
+ training_configuration_code["project"] = r".\output" # Project directory
1126
+ training_configuration_code["name"] = "train" # Project name
1127
+
1128
+ # Extract the model name from the model path
1129
+ model_name = training_configuration_code["model_path"].split("\\")[-1]
1130
+
1131
+ # Start with necessary library imports and model initialization
1132
+ code_str = "# Importing necessary libraries\n"
1133
+ code_str += "from ultralytics import YOLO\n\n"
1134
+
1135
+ # Initialize the YOLO model
1136
+ code_str += f"# Initialize the YOLO model '{model_name}'\n"
1137
+ code_str += f"model = YOLO('{model_name}')\n"
1138
+
1139
+ # Add the model training code
1140
+ code_str += "\n# Start the training process\n"
1141
+ code_str += "model.train(\n"
1142
+ for key, value in training_configuration_code.items():
1143
+ if key not in [
1144
+ "model_path",
1145
+ "model",
1146
+ "export_settings",
1147
+ ]: # Exclude specific keys
1148
+ code_str += f" {key}={value},\n"
1149
+
1150
+ code_str = code_str.rstrip(",\n") + "\n)\n"
1151
+
1152
+ # Add model export code
1153
+ code_str += "\n# Model export process\n"
1154
+ code_str += "model.export(\n"
1155
+ for key, value in training_configuration_code["export_settings"].items():
1156
+ if key == "format" and value is None:
1157
+ continue # Skip format if it's None
1158
+ code_str += f" {key}={value},\n"
1159
+ code_str = code_str.rstrip(",\n") + "\n)\n"
1160
+
1161
+ return code_str
1162
+
1163
+
1164
+ # Function to overwrites a Python file with new code
1165
+ def overwrite_python_file(code_str, file_path):
1166
+ # Open the file in write mode, which automatically deletes old content
1167
+ with open(file_path, "w") as file:
1168
+ file.write(code_str)
1169
+
1170
+
1171
+ # Function to generate a downloadable file
1172
+ def display_code_and_download_button(generated_code):
1173
+ # Display the generated code in Streamlit with description and download button in columns
1174
+ with st.expander("Plug and Play Code"):
1175
+ col1, col2 = st.columns([7, 3])
1176
+
1177
+ with col1:
1178
+ st.markdown(
1179
+ """
1180
+ ### Description of the Code Pipeline
1181
+ """
1182
+ )
1183
+
1184
+ st.markdown(
1185
+ """
1186
+ <div style='text-align: justify;'>
1187
+ This Python script is configured for training a YOLO model. It includes necessary configurations and parameters for a custom YOLO model training session.
1188
+
1189
+ **To use this script:**
1190
+ - Ensure you have the necessary dependencies installed.
1191
+ - Place your image and label files in the `'datasets/train'`, `'datasets/val'`, and `'datasets/test'` folders respectively.
1192
+ - The `'config.yaml'` file and the training script are set up based on your provided configurations.
1193
+
1194
+ ### Python Code
1195
+ </div>
1196
+ """,
1197
+ unsafe_allow_html=True,
1198
+ )
1199
+
1200
+ # Display python code
1201
+ st.code(generated_code, language="python")
1202
+
1203
+ # Determine the main directory path
1204
+ main_directory_path = os.path.dirname(
1205
+ os.path.dirname(os.path.abspath(__file__))
1206
+ )
1207
+
1208
+ # Overwrites a Python file with new code
1209
+ overwrite_python_file(
1210
+ generated_code,
1211
+ os.path.join(
1212
+ main_directory_path,
1213
+ "model_data",
1214
+ "model_training_code_pipline",
1215
+ "model_training.py",
1216
+ ),
1217
+ )
1218
+
1219
+ # Determine the main directory path
1220
+ main_directory_path = os.path.dirname(
1221
+ os.path.dirname(os.path.abspath(__file__))
1222
+ )
1223
+
1224
+ # Prepare a ZIP file of the training output folder in memory for download
1225
+ zip_bytes_io = zip_folder_to_bytesio(
1226
+ os.path.join(
1227
+ main_directory_path, "model_data", "model_training_code_pipline"
1228
+ )
1229
+ )
1230
+
1231
+ with col2:
1232
+ # Create a button for downloading the training pipeline
1233
+ st.download_button(
1234
+ label="Download Training Pipeline",
1235
+ data=zip_bytes_io,
1236
+ file_name="model_training_code.zip",
1237
+ mime="application/zip",
1238
+ use_container_width=True,
1239
+ )
1240
+
1241
+
1242
+ # Function to generates a YOLO model training code snippet and displays it with a download button
1243
+ def generate_and_display_yolo_training_code(class_labels, training_configuration):
1244
+ # Determine the main directory path
1245
+ main_directory_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
1246
+
1247
+ # Construct the path to the config file directory
1248
+ config_file_path = os.path.join(
1249
+ main_directory_path, "model_data", "model_training_code_pipline"
1250
+ )
1251
+
1252
+ # Define the path to the dataset directory
1253
+ dataset_directory_path = "./datasets"
1254
+
1255
+ # Create YOLO config file using provided class labels and dataset directory
1256
+ create_yolo_config_file(config_file_path, class_labels, dataset_directory_path)
1257
+
1258
+ # Generate the Python code for YOLO model training
1259
+ generated_code = generate_python_code_model_training(training_configuration)
1260
+
1261
+ # Display the generated code and a download button
1262
+ display_code_and_download_button(generated_code)
1263
+
1264
+
1265
+ # Function to create a yolo config file
1266
+ def create_yolo_config_file(
1267
+ config_file_path, class_labels, dataset_directory_path=None
1268
+ ):
1269
+ if dataset_directory_path is None:
1270
+ dataset_directory_path = os.path.join(config_file_path, "datasets")
1271
+
1272
+ # Number of classes
1273
+ num_classes = len(class_labels)
1274
+
1275
+ # Create the configuration content
1276
+ config_content = f"""path: {dataset_directory_path} # Path to the dataset directory
1277
+ train: train # Path to the training set directory
1278
+ val: val # Path to the validation set directory
1279
+ test: test # Path to the testing set directory
1280
+ nc: {num_classes} # Number of classes
1281
+ names: {class_labels} # List of class names
1282
+ """
1283
+
1284
+ # Write the configuration to a file
1285
+ with open(os.path.join(config_file_path, "config.yaml"), "w") as file:
1286
+ file.write(config_content)
1287
+
1288
+
1289
+ # Function to delete and recreate a folder
1290
+ def delete_and_recreate_folder(folder_path):
1291
+ try:
1292
+ # Use shutil.rmtree to delete the folder and its contents
1293
+ shutil.rmtree(folder_path)
1294
+ # Recreate the folder at the same location
1295
+ os.makedirs(folder_path)
1296
+ except Exception as e:
1297
+ print(f"Error deleting or recreating folder {folder_path}: {e}")
1298
+
1299
+
1300
+ # Function to read a CSV file and return its columns as NumPy arrays
1301
+ def read_csv_and_get_values(csv_file_path):
1302
+ # Read the CSV file into a pandas DataFrame
1303
+ df = pd.read_csv(csv_file_path)
1304
+
1305
+ # Initialize an empty dictionary to store the results
1306
+ result_dict = {}
1307
+
1308
+ # Iterate through the columns of the DataFrame
1309
+ for column in df.columns:
1310
+ # Remove leading and trailing spaces from the column name
1311
+ clean_column_name = column.strip()
1312
+
1313
+ # Get the values in the column
1314
+ column_values = df[column].astype(float)
1315
+
1316
+ # Add the cleaned column name and values to the result dictionary
1317
+ result_dict[clean_column_name] = np.array(column_values)
1318
+
1319
+ return result_dict
1320
+
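+ # Note: Ultralytics typically writes results.csv with space-padded headers, hence the
+ # strip() above; the returned dict maps e.g. "train/box_loss" to a NumPy array of values.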
1321
+
1322
+ # Global variables
1323
+ plot_container = None
1324
+ val_dataframe_container = None
1325
+ progress_bar = None
1326
+ progress_text = None
1327
+
1328
+
1329
+ # Function to define a custom callback function for on_pretrain_routine_start
1330
+ def on_pretrain_routine_start(trainer):
1331
+ global progress_text, progress_bar
1332
+ progress_bar = st.empty()
1333
+ progress_text = st.empty()
1334
+ progress_text.info(
1335
+ "Loading selected model...",
1336
+ icon="✅",
1337
+ )
1338
+
1339
+
1340
+ # Function to define a custom callback function for on_train_start
1341
+ def on_train_start(trainer):
1342
+ global progress_bar, progress_text
1343
+ progress_bar = st.progress(0)
1344
+ progress_text.info(
1345
+ "Training Started...",
1346
+ icon="✅",
1347
+ )
1348
+
1349
+
1350
+ # Function to display metrics plot
1351
+
1354
+ def display_metrics_plot(output_data):
1355
+ global plot_container
1356
+
1357
+ # Extract data for each metric
1358
+ epoch_history = output_data.get("epoch")
1359
+
1360
+ # Extract loss histories
1361
+ train_box_loss_history = output_data.get("train/box_loss")
1362
+ train_cls_loss_history = output_data.get("train/cls_loss")
1363
+ train_dfl_loss_history = output_data.get("train/dfl_loss")
1364
+ train_seg_loss_history = output_data.get("train/seg_loss")
1365
+ val_box_loss_history = output_data.get("val/box_loss")
1366
+ val_cls_loss_history = output_data.get("val/cls_loss")
1367
+ val_dfl_loss_history = output_data.get("val/dfl_loss")
1368
+ val_seg_loss_history = output_data.get("val/seg_loss")
1369
+
1370
+ if train_seg_loss_history is None:
1371
+ train_seg_loss_history = epoch_history * 0
1372
+ val_seg_loss_history = epoch_history * 0
1373
+
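+ # (Detection-only runs have no seg-loss columns in results.csv, so zero arrays of
+ # matching length keep the subplot grid below uniform.)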
1374
+ # Extract precision, recall, and mAP histories for B and M box/mask
1375
+ precision_B_history = output_data.get("metrics/precision(B)")
1376
+ recall_B_history = output_data.get("metrics/recall(B)")
1377
+ mAP50_B_history = output_data.get("metrics/mAP50(B)")
1378
+ mAP50_95_B_history = output_data.get("metrics/mAP50-95(B)")
1379
+ precision_M_history = output_data.get("metrics/precision(M)")
1380
+ recall_M_history = output_data.get("metrics/recall(M)")
1381
+ mAP50_M_history = output_data.get("metrics/mAP50(M)")
1382
+ mAP50_95_M_history = output_data.get("metrics/mAP50-95(M)")
1383
+
1384
+ # Check for 'None' data and adjust the number of rows in the grid
1385
+ num_rows = 4
1386
+ subplot_titles = [
1387
+ "Precision B",
1388
+ "Recall B",
1389
+ "mAP50 B",
1390
+ "mAP50-95 B",
1391
+ "Precision R",
1392
+ "Recall R",
1393
+ "mAP50 R",
1394
+ "mAP50-95 R",
1395
+ "Train Box Loss",
1396
+ "Train Class Loss",
1397
+ "Train DFL Loss",
1398
+ "Train Seg Loss",
1399
+ "Val Box Loss",
1400
+ "Val Class Loss",
1401
+ "Val DFL Loss",
1402
+ "Val Seg Loss",
1403
+ ]
1404
+ if precision_M_history is None:
1405
+ num_rows = 3
1406
+ subplot_titles = subplot_titles[0:4] + subplot_titles[8:]
1407
+
1408
+ # Create a subplot grid
1409
+ fig = make_subplots(
1410
+ rows=num_rows,
1411
+ cols=4,
1412
+ subplot_titles=subplot_titles,
1413
+ vertical_spacing=0.05,
1414
+ )
1415
+
1416
+ # Initialize row number
1417
+ row_number = 1
1418
+
1419
+ # Add precision, recall, mAP plots for B and M box/mask
1420
+ fig.add_trace(
1421
+ go.Scatter(
1422
+ x=epoch_history, y=precision_B_history, mode="lines", name="Precision B"
1423
+ ),
1424
+ row=row_number,
1425
+ col=1,
1426
+ )
1427
+ fig.add_trace(
1428
+ go.Scatter(x=epoch_history, y=recall_B_history, mode="lines", name="Recall B"),
1429
+ row=row_number,
1430
+ col=2,
1431
+ )
1432
+ fig.add_trace(
1433
+ go.Scatter(x=epoch_history, y=mAP50_B_history, mode="lines", name="mAP50 B"),
1434
+ row=row_number,
1435
+ col=3,
1436
+ )
1437
+ fig.add_trace(
1438
+ go.Scatter(
1439
+ x=epoch_history, y=mAP50_95_B_history, mode="lines", name="mAP50-95 B"
1440
+ ),
1441
+ row=row_number,
1442
+ col=4,
1443
+ )
1444
+
1445
+ if precision_M_history is not None:
1446
+ # Increment row number
1447
+ row_number += 1
1448
+
1449
+ fig.add_trace(
1450
+ go.Scatter(
1451
+ x=epoch_history, y=precision_M_history, mode="lines", name="Precision M"
1452
+ ),
1453
+ row=row_number,
1454
+ col=1,
1455
+ )
1456
+ fig.add_trace(
1457
+ go.Scatter(
1458
+ x=epoch_history, y=recall_M_history, mode="lines", name="Recall M"
1459
+ ),
1460
+ row=row_number,
1461
+ col=2,
1462
+ )
1463
+ fig.add_trace(
1464
+ go.Scatter(
1465
+ x=epoch_history, y=mAP50_M_history, mode="lines", name="mAP50 M"
1466
+ ),
1467
+ row=row_number,
1468
+ col=3,
1469
+ )
1470
+ fig.add_trace(
1471
+ go.Scatter(
1472
+ x=epoch_history, y=mAP50_95_M_history, mode="lines", name="mAP50-95 M"
1473
+ ),
1474
+ row=row_number,
1475
+ col=4,
1476
+ )
1477
+
1478
+ # Increment row number
1479
+ row_number += 1
1480
+
1481
+ # Add loss plots
1482
+ fig.add_trace(
1483
+ go.Scatter(
1484
+ x=epoch_history,
1485
+ y=train_box_loss_history,
1486
+ mode="lines",
1487
+ name="Train Box Loss",
1488
+ ),
1489
+ row=row_number,
1490
+ col=1,
1491
+ )
1492
+ fig.add_trace(
1493
+ go.Scatter(
1494
+ x=epoch_history,
1495
+ y=train_cls_loss_history,
1496
+ mode="lines",
1497
+ name="Train Class Loss",
1498
+ ),
1499
+ row=row_number,
1500
+ col=2,
1501
+ )
1502
+ fig.add_trace(
1503
+ go.Scatter(
1504
+ x=epoch_history,
1505
+ y=train_dfl_loss_history,
1506
+ mode="lines",
1507
+ name="Train DFL Loss",
1508
+ ),
1509
+ row=row_number,
1510
+ col=3,
1511
+ )
1512
+ fig.add_trace(
1513
+ go.Scatter(
1514
+ x=epoch_history,
1515
+ y=train_seg_loss_history,
1516
+ mode="lines",
1517
+ name="Train Seg Loss",
1518
+ ),
1519
+ row=row_number,
1520
+ col=4,
1521
+ )
1522
+
1523
+ # Increment row number
1524
+ row_number += 1
1525
+
1526
+ fig.add_trace(
1527
+ go.Scatter(
1528
+ x=epoch_history, y=val_box_loss_history, mode="lines", name="Val Box Loss"
1529
+ ),
1530
+ row=row_number,
1531
+ col=1,
1532
+ )
1533
+ fig.add_trace(
1534
+ go.Scatter(
1535
+ x=epoch_history, y=val_cls_loss_history, mode="lines", name="Val Class Loss"
1536
+ ),
1537
+ row=row_number,
1538
+ col=2,
1539
+ )
1540
+ fig.add_trace(
1541
+ go.Scatter(
1542
+ x=epoch_history, y=val_dfl_loss_history, mode="lines", name="Val DFL Loss"
1543
+ ),
1544
+ row=row_number,
1545
+ col=3,
1546
+ )
1547
+ fig.add_trace(
1548
+ go.Scatter(
1549
+ x=epoch_history,
1550
+ y=val_seg_loss_history,
1551
+ mode="lines",
1552
+ name="Val Seg Loss",
1553
+ ),
1554
+ row=row_number,
1555
+ col=4,
1556
+ )
1557
+
1558
+ # Check if the plot container is already initialized
1559
+ if plot_container is None:
1560
+ plot_container = st.empty()
1561
+
1562
+ # Update layout
1563
+ fig.update_layout(
1564
+ height=1200,
1565
+ width=1600,
1566
+ title_text="Metrics",
1567
+ legend=dict(orientation="h", yanchor="bottom", xanchor="left"),
1568
+ )
1569
+
1570
+ # Display the updated plot in the same container
1571
+ plot_container.plotly_chart(fig, use_container_width=True)
1572
+
1573
+
1574
+ # Function to define a custom callback function for on_fit_epoch_end
1575
+ def on_fit_epoch_end(trainer):
1576
+ current_epoch = int(trainer.epoch)
1577
+ total_epochs = int(trainer.epochs)
1578
+
1579
+ # Define the path to the output CSV
1580
+ output_csv_path = os.path.join(get_path("output"), "train", "results.csv")
1581
+
1582
+ # Read the CSV data
1583
+ st.session_state["plot_data"] = read_csv_and_get_values(output_csv_path)
1584
+
1585
+ # Call a function to update the plot using this data
1586
+ display_metrics_plot(st.session_state["plot_data"])
1587
+
1588
+ # Update progress bar and text
1589
+ progress_bar.progress((current_epoch + 1) / total_epochs)
1590
+ progress_text.write(f"Epoch {(current_epoch + 1)}/{total_epochs}")
1591
+
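+ # (Ultralytics triggers this callback after every epoch; re-reading results.csv and
+ # redrawing the figure turns the Streamlit containers above into a live dashboard.)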
1592
+
1593
+ # Function to define a custom callback function for on_train_end
1594
+ def on_train_end(trainer):
1595
+ global progress_bar, progress_text
1596
+ progress_bar.empty()
1597
+ progress_text.info(
1598
+ "Best and last model save completed successfully.",
1599
+ icon="✅",
1600
+ )
1601
+
1602
+
1603
+ # Function to add various callbacks to the YOLO model for different stages of the training process
1604
+ def callback_add(model):
1605
+ # Add a callback to be triggered at the start of the pre-training routine
1606
+ model.add_callback("on_pretrain_routine_start", on_pretrain_routine_start)
1607
+
1608
+ # Add a callback to be triggered at the start of the training
1609
+ model.add_callback("on_train_start", on_train_start)
1610
+
1611
+ # Add a callback to be triggered at the end of each training epoch
1612
+ model.add_callback("on_fit_epoch_end", on_fit_epoch_end)
1613
+
1614
+ # Add a callback to be triggered at the end of the training process
1615
+ model.add_callback("on_train_end", on_train_end)
1616
+
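+ # Example usage: call callback_add(model) before model.train(...); Ultralytics then
+ # invokes each registered function with the trainer instance at the matching event.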
1617
+
1618
+ # Function to zip a folder and all its subfolders and return a BytesIO object
1619
+ def zip_folder_to_bytesio(folder_path):
1620
+ bytes_io = io.BytesIO()
1621
+ with zipfile.ZipFile(bytes_io, "w", zipfile.ZIP_DEFLATED) as zipf:
1622
+ folder_path_abs = os.path.abspath(folder_path)
1623
+
1624
+ for root, dirs, files in os.walk(folder_path):
1625
+ # Calculate the relative path from the folder_path
1626
+ folder_rel_path = os.path.relpath(root, folder_path_abs)
1627
+
1628
+ # If the directory is empty, add the directory itself
1629
+ if not dirs and not files:
1630
+ # ZIP format requires a trailing slash for empty directories
1631
+ zip_dir_path = f"{folder_rel_path}/" if folder_rel_path != "." else ""
1632
+ zipf.write(root, zip_dir_path)
1633
+
1634
+ for file in files:
1635
+ file_path = os.path.join(root, file)
1636
+ # Construct the path within the zip file
1637
+ zip_file_path = (
1638
+ os.path.join(folder_rel_path, file)
1639
+ if folder_rel_path != "."
1640
+ else file
1641
+ )
1642
+ zipf.write(file_path, zip_file_path)
1643
+
1644
+ bytes_io.seek(0) # Go to the start of the BytesIO buffer
1645
+ return bytes_io
1646
+
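+ # Example usage: st.download_button("Download", data=zip_folder_to_bytesio(folder),
+ # file_name="output.zip", mime="application/zip")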
1647
+
1648
+ # Function to display the metrics table
1649
+
1652
+ def display_val_dataframe(val_dataframe):
1653
+ global val_dataframe_container
1654
+
1655
+ # Check if the dataframe container is already initialized
1656
+ if val_dataframe_container is None:
1657
+ val_dataframe_container = st.container()
1658
+
1659
+ # Display the updated dataframe in the same container
1660
+ with val_dataframe_container:
1661
+ # Display the message to indicate that the metrics table is ready
1662
+ st.markdown("**Metrics Table**", unsafe_allow_html=True)
1663
+
1664
+ # Display the DataFrame
1665
+ st.dataframe(val_dataframe)
1666
+
1667
+
1668
+ # Function to compute validation metrics and display them as a DataFrame
1669
+ def val_dataframe(model):
1670
+ # Placeholder for the initial message
1671
+ message = st.empty()
1672
+ message.markdown("**Generating Metrics Table...**", unsafe_allow_html=True)
1673
+
1674
+ # Extract the metrics from the model
1675
+ metrics = model.val()
1676
+
1677
+ # Extract the class indices and names
1678
+ class_index = metrics.ap_class_index
1679
+ class_names = metrics.names
1680
+
1681
+ # Extract precision, recall, and mAP values for the box (B) metrics
1682
+ precision_B_values = metrics.box.p
1683
+ recall_B_values = metrics.box.r
1684
+ mAP50_95_B_values = [metrics.box.maps[i] for i in class_index]
1685
+
1686
+ # Check if segmentation (mask) metrics exist
1687
+ try:
1688
+ metrics_mask = metrics.seg
1689
+ except AttributeError:
1690
+ metrics_mask = False
1691
+ if metrics_mask:
1692
+ precision_M_values = metrics_mask.p
1693
+ recall_M_values = metrics_mask.r
1694
+ mAP50_95_M_values = [metrics_mask.maps[i] for i in class_index]
1695
+
1696
+ # Extract aggregated metrics from the results dictionary
1697
+ results_dict = metrics.results_dict
1698
+
1699
+ # Initialize lists for overall precision, recall, and mAP for box (B)
1700
+ precision_B = [results_dict.get("metrics/precision(B)")]
1701
+ recall_B = [results_dict.get("metrics/recall(B)")]
1702
+ mAP50_95_B = [results_dict.get("metrics/mAP50-95(B)")]
1703
+
1704
+ # Initialize lists for overall precision, recall, and mAP for mask (M) if available
1705
+ precision_M = [results_dict.get("metrics/precision(M)")] if metrics_mask else None
1706
+ recall_M = [results_dict.get("metrics/recall(M)")] if metrics_mask else None
1707
+ mAP50_95_M = [results_dict.get("metrics/mAP50-95(M)")] if metrics_mask else None
1708
+
1709
+ # Create a list of class names starting with "All" for the overall metrics
1710
+ name_list = ["All"] + [str(class_names[i]) for i in class_index]
1711
+
1712
+ # Extend the metrics lists with values for each class
1713
+ precision_B.extend(precision_B_values)
1714
+ recall_B.extend(recall_B_values)
1715
+ mAP50_95_B.extend(mAP50_95_B_values)
1716
+
1717
+ # If mask metrics are available, extend their lists with values for each class
1718
+ if metrics_mask:
1719
+ precision_M.extend(precision_M_values)
1720
+ recall_M.extend(recall_M_values)
1721
+ mAP50_95_M.extend(mAP50_95_M_values)
1722
+
1723
+ # Create a DataFrame with the computed metrics
1724
+ if metrics_mask:
1725
+ st.session_state["val_dataframe"] = pd.DataFrame(
1726
+ {
1727
+ "Class Name": name_list,
1728
+ "Precision (B)": precision_B,
1729
+ "Recall (B)": recall_B,
1730
+ "mAP50-95 (B)": mAP50_95_B,
1731
+ "Precision (M)": precision_M,
1732
+ "Recall (M)": recall_M,
1733
+ "mAP50-95 (M)": mAP50_95_M,
1734
+ }
1735
+ )
1736
+ else:
1737
+ st.session_state["val_dataframe"] = pd.DataFrame(
1738
+ {
1739
+ "Class Name": name_list,
1740
+ "Precision (B)": precision_B,
1741
+ "Recall (B)": recall_B,
1742
+ "mAP50-95 (B)": mAP50_95_B,
1743
+ }
1744
+ )
1745
+
1746
+ # Clear the initial message
1747
+ message.empty()
1748
+
1749
+ # Indicate that the metrics table is ready and display the DataFrame
1750
+ display_val_dataframe(st.session_state["val_dataframe"])
1751
+
1752
+
1753
+ # Function to train the YOLO model
1754
+ def train_yolo_model(training_configuration):
1755
+ # Clear and recreate the output folder to ensure a fresh start
1756
+ delete_and_recreate_folder(get_path("output"))
1757
+
1758
+ # Initialize the YOLO model with the specified path from the training configuration
1759
+ model = YOLO(training_configuration["model_path"])
1760
+
1761
+ # Add any callbacks or additional configuration to the model
1762
+ callback_add(model)
1763
+
1764
+ # Train the model with the specified parameters
1765
+ model.train(
1766
+ task=training_configuration["task"],
1767
+ data=os.path.join(get_path("config"), "config.yaml"),
1768
+ epochs=training_configuration["epochs"],
1769
+ time=training_configuration["time"],
1770
+ patience=training_configuration["patience"],
1771
+ batch=training_configuration["batch"],
1772
+ imgsz=training_configuration["imgsz"],
1773
+ save=True,
1774
+ save_period=-1,
1775
+ cache=training_configuration["cache"],
1776
+ pretrained=True,
1777
+ optimizer=training_configuration["optimizer"],
1778
+ verbose=True,
1779
+ seed=0,
1780
+ deterministic=training_configuration["deterministic"],
1781
+ single_cls=False,
1782
+ rect=training_configuration["rect"],
1783
+ cos_lr=training_configuration["cos_lr"],
1784
+ resume=False,
1785
+ amp=training_configuration["amp"],
1786
+ fraction=1.0,
1787
+ freeze=training_configuration["freeze"],
1788
+ lr0=training_configuration["lr0"],
1789
+ lrf=training_configuration["lrf"],
1790
+ momentum=training_configuration["momentum"],
1791
+ weight_decay=training_configuration["weight_decay"],
1792
+ warmup_epochs=training_configuration["warmup_epochs"],
1793
+ warmup_momentum=training_configuration["warmup_momentum"],
1794
+ warmup_bias_lr=training_configuration["warmup_bias_lr"],
1795
+ box=training_configuration["box"],
1796
+ cls=training_configuration["cls"],
1797
+ dfl=training_configuration["dfl"],
1798
+ label_smoothing=training_configuration["label_smoothing"],
1799
+ nbs=training_configuration["nbs"],
1800
+ overlap_mask=training_configuration["overlap_mask"],
1801
+ mask_ratio=training_configuration["mask_ratio"],
1802
+ dropout=training_configuration["dropout"],
1803
+ val=training_configuration["val"],
1804
+ plots=training_configuration["plots"],
1805
+ save_dir=os.path.join(get_path("output"), "train"),
1806
+ project=get_path("output"),
1807
+ name="train",
1808
+ augment=False,
1809
+ exist_ok=True,
1810
+ )
1811
+
1812
+ return model
1813
+
1814
+
1815
+ # Function to export the model with the given parameters
1816
+ def export_model_with_parameters(model, export_params):
1817
+ global progress_text
1818
+
1819
+ if export_params["format"] is not None:
1820
+ # Informing the user that the export process has started
1821
+ progress_text.info(
1822
+ "Starting the export process with the specified settings.",
1823
+ icon="✅",
1824
+ )
1825
+
1826
+ # Perform the model export
1827
+ model.export(
1828
+ format=export_params["format"],
1829
+ keras=export_params["keras"],
1830
+ optimize=export_params["optimize"],
1831
+ half=export_params["half"],
1832
+ int8=export_params["int8"],
1833
+ dynamic=export_params["dynamic"],
1834
+ simplify=export_params["simplify"],
1835
+ opset=export_params["opset"],
1836
+ workspace=export_params["workspace"],
1837
+ nms=export_params["nms"],
1838
+ )
1839
+
1840
+ # Informing the user that the export process has completed successfully
1841
+ progress_text.info(
1842
+ "The model has been successfully saved using the specified export settings.",
1843
+ icon="✅",
1844
+ )
1845
+
1846
+
1847
+ # Function to start the YOLO model training process
1848
+ def start_yolo_training(selected_training, class_labels):
1849
+ global plot_container, val_dataframe_container
1850
+
1851
+ # Retrieve the training configuration based on the user's selection
1852
+ training_configuration = get_training_validation_export_configuration(
1853
+ selected_training
1854
+ )
1855
+
1856
+ # Generates a YOLO model training code snippet and displays it with a download button
1857
+ generate_and_display_yolo_training_code(class_labels, training_configuration)
1858
+
1859
+ # Create two columns
1860
+ col1, col2 = st.columns(2)
1861
+ # When the "Start Training" button is clicked in the first column
1862
+ if col1.button("Start Training", use_container_width=True):
1863
+ plot_container = None
1864
+ val_dataframe_container = None
1865
+
1866
+ with st.spinner("Training in Progress..."):
1867
+ # Train the YOLO model using the provided configuration
1868
+ trained_model = train_yolo_model(training_configuration)
1869
+
1870
+ # Export the model with the given parameters
1871
+ export_model_with_parameters(
1872
+ trained_model, training_configuration["export_settings"]
1873
+ )
1874
+
1875
+ # Display the validation results in a DataFrame after training
1876
+ val_dataframe(trained_model)
1877
+
1878
+ elif "plot_data" in st.session_state and "val_dataframe" in st.session_state:
1879
+ plot_container = None
1880
+ val_dataframe_container = None
1881
+
1882
+ # Display metrics plot and table if already exist
1883
+ display_metrics_plot(st.session_state["plot_data"])
1884
+ display_val_dataframe(st.session_state["val_dataframe"])
1885
+
1886
+ # Prepare a ZIP file of the training output folder in memory for download
1887
+ zip_bytes_io = zip_folder_to_bytesio(os.path.join(get_path("output"), "train"))
1888
+
1889
+ # Provide a button in the second column to download the ZIP file
1890
+ col2.download_button(
1891
+ label="Download",
1892
+ data=zip_bytes_io,
1893
+ file_name="model_training_output.zip",
1894
+ mime="application/zip",
1895
+ use_container_width=True,
1896
+ )
README.md CHANGED
@@ -1,14 +1,12 @@
1
- ---
2
- title: CV Accelerator
3
- emoji: 🐠
4
- colorFrom: yellow
5
- colorTo: blue
6
- sdk: streamlit
7
- sdk_version: 1.41.1
8
- app_file: app.py
9
- pinned: false
10
- license: unlicense
11
- short_description: A no-code CV Accelerator
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ ---
2
+ title: CV Accelerator
3
+ emoji: 📷
4
+ colorFrom: gray
5
+ colorTo: indigo
6
+ sdk: streamlit
7
+ sdk_version: 1.32.0
8
+ app_file: Welcome.py
9
+ pinned: false
10
+ ---
11
+
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
Welcome.py ADDED
@@ -0,0 +1,8 @@
1
+ # Set the page config
2
+ import streamlit as st
3
+
4
+ st.set_page_config(
5
+ page_title="Welcome",
6
+ page_icon=":open_file_folder:",
7
+ layout="wide",
8
+ )
config.toml ADDED
@@ -0,0 +1,9 @@
1
+ [theme]
2
+ primaryColor = "#101276"
3
+ backgroundColor = "#FFFFFF"
4
+ textColor = "#006EC0"
5
+ secondaryBackgroundColor = "#D2E5F2"
6
+ font = "sans serif"
7
+
8
+ [server]
9
+ maxUploadSize = 10000
model_data/input_files/config.yaml ADDED
@@ -0,0 +1,6 @@
1
+ path: C:\Users\SamkeetSangai\OneDrive - Blend 360\Work\CoEs\Computer Vision\cv_accelerator\model_data\input_files\datasets # Path to the dataset directory
2
+ train: train # Path to the training set directory
3
+ val: val # Path to the validation set directory
4
+ test: test # Path to the testing set directory
5
+ nc: 80 # Number of classes
6
+ names: ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'] # List of class names
model_data/model_training_code_pipline/config.yaml ADDED
@@ -0,0 +1,6 @@
1
+ path: ./datasets # Path to the dataset directory
2
+ train: train # Path to the training set directory
3
+ val: val # Path to the validation set directory
4
+ test: test # Path to the testing set directory
5
+ nc: 80 # Number of classes
6
+ names: ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'] # List of class names
model_data/model_training_code_pipline/model_training.py ADDED
@@ -0,0 +1,69 @@
1
+ # Importing necessary libraries
2
+ from ultralytics import YOLO
3
+
4
+ # Initialize the YOLO model 'yolov8n.pt'
5
+ model = YOLO('yolov8n.pt')
6
+
7
+ # Start the training process
8
+ model.train(
9
+ task="detect",
10
+ time=None,
11
+ epochs=2,
12
+ patience=None,
13
+ batch=4,
14
+ imgsz=640,
15
+ cache=False,
16
+ optimizer="auto",
17
+ amp=True,
18
+ deterministic=False,
19
+ rect=False,
20
+ cos_lr=False,
21
+ freeze=None,
22
+ lr0=0.01,
23
+ lrf=0.01,
24
+ momentum=0.937,
25
+ weight_decay=0.0005,
26
+ warmup_epochs=3.0,
27
+ warmup_momentum=0.8,
28
+ warmup_bias_lr=0.1,
29
+ box=7.5,
30
+ cls=0.5,
31
+ dfl=1.5,
32
+ label_smoothing=0.0,
33
+ nbs=64,
34
+ overlap_mask=True,
35
+ mask_ratio=4,
36
+ dropout=0.0,
37
+ val=True,
38
+ plots=True,
39
+ conf=0.001,
40
+ iou=0.6,
41
+ max_det=300,
42
+ half=False,
43
+ data="./config.yaml",
44
+ save_dir="./output/train",
45
+ pretrained=True,
46
+ save=True,
47
+ save_period=-1,
48
+ augment=False,
49
+ seed=0,
50
+ verbose=True,
51
+ single_cls=False,
52
+ resume=False,
53
+ exist_ok=True,
54
+ project="./output",
55
+ name="train"
56
+ )
57
+
58
+ # Model export process
59
+ model.export(
60
+ keras=False,
61
+ optimize=False,
62
+ half=False,
63
+ int8=False,
64
+ dynamic=False,
65
+ simplify=False,
66
+ opset=None,
67
+ workspace=4,
68
+ nms=False
69
+ )
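+ # (No format argument is given, so Ultralytics falls back to its default export
+ # format, TorchScript.)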
model_data/models/yolov8l-seg.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dfa891dab47031c12ac82cb1b883af4ab9efbc732118cc04604bfbf107f9dfa8
3
+ size 92391859
model_data/models/yolov8l.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:18218ea4798da042d9862e6029ca9531adbd40ace19b6c9a75e2e28f1adf30cc
3
+ size 87769683
model_data/models/yolov8m-seg.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f9fc740ca0824e14b44681d491dc601efa664ec6ecea9a870acf876053826448
3
+ size 54899779
model_data/models/yolov8m.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6c25b0b63b1a433843f06d821a9ac1deb8d5805f74f0f38772c7308c5adc55a5
3
+ size 52117635
model_data/models/yolov8n-seg.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d39e867b2c3a5dbc1aa764411544b475cb14727bf6af1ec46c238f8bb1351ab9
3
+ size 7054355
model_data/models/yolov8n.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:31e20dde3def09e2cf938c7be6fe23d9150bbbe503982af13345706515f2ef95
3
+ size 6534387
model_data/models/yolov8s-seg.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c035c6c5f9c48ee962518ef648854c2cc5e60fa404fac443e17d306fdda16543
3
+ size 23897299
model_data/models/yolov8s.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:268e5bb54c640c96c3510224833bc2eeacab4135c6deb41502156e39986b562d
3
+ size 22573363
model_data/models/yolov8x-seg.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d63cbfa5764867c0066bedfa43cf2dcd90a412a1de44b2e238c43978a9d28ea6
3
+ size 144076467
model_data/models/yolov8x.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c4d5a3f000d771762f03fc8b57ebd0aae324aeaefdd6e68492a9c4470f2d1e8b
3
+ size 136867539
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ python3-opencv
packages_db.xlsx ADDED
Binary file (54.9 kB). View file
 
pages/2_Image_Processing.py ADDED
@@ -0,0 +1,512 @@
1
+ # Set the page config
2
+ import streamlit as st
3
+
4
+ st.set_page_config(
5
+ page_title="Image Processing",
6
+ page_icon=":open_file_folder:",
7
+ layout="wide",
8
+ initial_sidebar_state="collapsed",
9
+ )
10
+
11
+ # Importing necessary libraries
12
+ import cv2
13
+ import utils
14
+ import numpy as np
15
+ import Functions.image_processing_functions as image_processing_functions
16
+
17
+ # Load image processing technique parameters and details from an Excel file
18
+ image_processing_params_df = utils.load_data_from_excel(
19
+ "packages_db.xlsx", "image_processing_parameters"
20
+ )
21
+ image_processing_details_df = utils.load_data_from_excel(
22
+ "packages_db.xlsx", "image_processing_details"
23
+ )
24
+
25
+ # Display the page title
26
+ st.title("Image Processing")
27
+
28
+ # # Clear the Streamlit session state on the first load of the page
29
+ # utils.clear_session_state_on_first_load("image_processing_clear")
30
+
31
+ # List of session state keys to initialize if they are not already present
32
+ session_state_keys = [
33
+ "file_uploader_key_processing",
34
+ "select_processing_technique_key_processing",
35
+ ]
36
+
37
+ # Iterate through each session state key
38
+ for key in session_state_keys:
39
+ # Check if the key is not already in the session state
40
+ if key not in st.session_state:
41
+ # Initialize the key with a dictionary containing itself set to True
42
+ st.session_state[key] = {key: True}
43
+
44
+ # Initialize session state variables if not present
45
+ if "validation_triggered" not in st.session_state:
46
+ st.session_state["validation_triggered"] = False
47
+
48
+ if "uploaded_files_cache_processing" not in st.session_state:
49
+ st.session_state["uploaded_files_cache_processing"] = False
50
+
51
+ if "zip_data_processing" not in st.session_state:
52
+ st.session_state["zip_data_processing"] = ""
53
+
54
+ if "widget_states" not in st.session_state:
55
+ st.session_state["widget_states"] = {}
56
+
57
+ # Interface for uploading images and labels
58
+ utils.display_file_uploader(
59
+ "uploaded_files",
60
+ "Choose images and labels...",
61
+ st.session_state["file_uploader_key_processing"],
62
+ st.session_state["uploaded_files_cache_processing"],
63
+ )
64
+
65
+ # Note to users
66
+ st.markdown(
67
+ """
68
+ <div style='text-align: justify;'>
69
+ <b>Note to Users:</b>
70
+ <ul>
71
+ <li>The <i>first uploaded image</i> will be used for demonstration purposes and to validate parameters for image processing techniques.</li>
72
+ <li>Uploading <i>labels is optional</i>. If no labels are uploaded, the output will consist solely of processed images.</li>
73
+ <li>When moving to another page or if you wish to upload a new set of images and labels, don't forget to hit the <b>Reset</b> button. This speeds up computation and frees unused memory, ensuring smoother operation.</li>
74
+ </ul>
75
+ </div>
76
+ """,
77
+ unsafe_allow_html=True,
78
+ )
79
+
80
+ # List of session state variables to initialize
81
+ session_vars = [
82
+ "is_valid",
83
+ "image_files",
84
+ "label_files",
85
+ "first_image_file",
86
+ "first_label_file",
87
+ ]
88
+
89
+ # Initialize each variable as None if it doesn't exist in the session state
90
+ for var in session_vars:
91
+ if var not in st.session_state:
92
+ st.session_state[var] = None
93
+
94
+ # Create two columns
95
+ col1, col2 = st.columns(2)
96
+
97
+ # Button to trigger validation
98
+ if (
99
+ col1.button("Validate Input", use_container_width=True)
100
+ or st.session_state["widget_states"].get("validate_input_button", False)
101
+ ) and not st.session_state["validation_triggered"]:
102
+ st.session_state["validation_triggered"] = True
103
+ st.session_state["uploaded_files_cache_processing"] = True
104
+
105
+ (
106
+ st.session_state["is_valid"],
107
+ st.session_state["image_files"],
108
+ st.session_state["label_files"],
109
+ st.session_state["first_image_file"],
110
+ st.session_state["first_label_file"],
111
+ ) = image_processing_functions.check_valid_labels(
112
+ st.session_state["uploaded_files"]
113
+ )
114
+
115
+ elif st.session_state["validation_triggered"]:
116
+ pass
117
+
118
+ else:
119
+ st.session_state["is_valid"] = False
120
+ st.warning(
121
+ "Please upload images and labels and click **Validate Input**.", icon="⚠️"
122
+ )
123
+
124
+ with col2:
125
+ # Check if the 'Reset' button is pressed
126
+ if st.button("Reset", use_container_width=True):
127
+ # Toggle the keys for file uploader and processing technique to reset their states
128
+ current_value = st.session_state["file_uploader_key_processing"][
129
+ "file_uploader_key_processing"
130
+ ]
131
+ updated_value = not current_value # Invert the current value
132
+
133
+ # List of session state keys that need to be reset
134
+ session_state_keys = [
135
+ "file_uploader_key_processing",
136
+ "select_processing_technique_key_processing",
137
+ ]
138
+
139
+ # Iterate through each session state key
140
+ for session_state_key in session_state_keys:
141
+ # Update each key in the session state with the toggled value
142
+ st.session_state[session_state_key] = {session_state_key: updated_value}
143
+
144
+ # Clear all other session state keys except for widget_state_keys
145
+ for key in list(st.session_state.keys()):
146
+ if key not in session_state_keys:
147
+ del st.session_state[key]
148
+
149
+ # Clear global variables except for protected and Streamlit module
150
+ global_vars = list(globals().keys())
151
+ vars_to_delete = [
152
+ var for var in global_vars if not var.startswith("_") and var != "st"
153
+ ]
154
+ for var in vars_to_delete:
155
+ del globals()[var]
156
+
157
+ # Clear the Streamlit caches
158
+ st.cache_resource.clear()
159
+ st.cache_data.clear()
160
+
161
+ # Rerun the app to reflect the reset state
162
+ st.rerun()
163
+
164
+ # Interface to select image processing techniques
165
+ available_image_processings = image_processing_details_df["Name"]
166
+
167
+
168
+ # Mapping each image processing technique to its corresponding image types
169
+ input_mapping_dict = utils.technique_image_input_mapping(
170
+ available_image_processings, image_processing_details_df
171
+ )
172
+
173
+ # Present the option to select image processing techniques only if the uploaded files are validated successfully
174
+ if st.session_state["is_valid"]:
175
+ selected_image_processings = st.multiselect(
176
+ "Select image processing technique(s)",
177
+ available_image_processings,
178
+ key=st.session_state["select_processing_technique_key_processing"],
179
+ )
180
+
181
+ # Read the first uploaded image into a NumPy array
182
+ st.session_state["first_image_file"].seek(0) # Reset file pointer to start
183
+ file_bytes_first_image = np.frombuffer(
184
+ st.session_state["first_image_file"].read(), dtype=np.uint8
185
+ )
186
+ uploaded_first_image = cv2.cvtColor(
187
+ cv2.imdecode(file_bytes_first_image, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB
188
+ )
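+ # (cv2.imdecode returns BGR channel order; converting to RGB so st.image renders colors correctly)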
189
+
190
+ # Resize the image
191
+ uploaded_first_image = cv2.resize(uploaded_first_image, (256, 256))
192
+
193
+ else:
194
+ # Reset selected techniques to empty if input validation fails
195
+ selected_image_processings = []
196
+
197
+
198
+ #######################################################################################################
199
+ # Build custom image processing pipeline
200
+ #######################################################################################################
201
+
202
+
203
+ # Store parameters for each selected image processing technique
204
+ image_processings_params = {}
205
+
206
+ # Initialize a flag to track if any error exists
207
+ error_in_parameters = False
208
+
209
+ # Loop through each selected image processing technique to set up parameters
210
+ for image_processing in selected_image_processings:
211
+ with st.expander(f"{image_processing}"):
212
+ # Retrieve image processing details from the database
213
+ image_processing_info = image_processing_details_df[
214
+ image_processing_details_df["Name"] == image_processing
215
+ ]
216
+
217
+ # Set up columns for displaying details and image placeholders
218
+ details_col, image_col = st.columns([7, 3])
219
+
220
+ with details_col:
221
+ # Display the description for the image processing technique
222
+ image_processing_description = (
223
+ image_processing_info["Description"].iloc[0]
224
+ if not image_processing_info.empty
225
+ else "No description available."
226
+ )
227
+ st.markdown(
228
+ f"<div style='text-align: justify;'><b>Description:</b> {image_processing_description}</div>",
229
+ unsafe_allow_html=True,
230
+ )
231
+
232
+ # Display the category for the image processing
233
+ image_processing_category = (
234
+ image_processing_info["Category"].iloc[0]
235
+ if not image_processing_info.empty
236
+ else "Unknown"
237
+ )
238
+ st.write("Category:", image_processing_category)
239
+
240
+ # Retrieve the source code link for the image processing
241
+ image_processing_source_code = (
242
+ image_processing_info["Source Code Link"].iloc[0]
243
+ if not image_processing_info.empty
244
+ else "www.google.com"
245
+ )
246
+ # Set up columns for displaying source code button and custom settings checkbox
247
+ source_code_col, custom_setting_col = st.columns(2)
248
+
249
+ source_code_col.link_button("Source Code", image_processing_source_code)
250
+
251
+ # Toggle for custom settings
252
+ custom_settings = custom_setting_col.checkbox(
253
+ f"Customize {image_processing}", key=f"toggle_{image_processing}"
254
+ )
255
+
256
+ with image_col:
257
+ # Create two columns
258
+ col1, col2 = st.columns(2)
259
+ original_image_placeholder = col1.container(height=200, border=False)
260
+ processed_image_placeholder = col2.container(height=200, border=False)
261
+
262
+ # Apply custom settings
263
+ if custom_settings:
264
+ # Retrieve parameters for the image processing
265
+ params_df = image_processing_params_df[
266
+ image_processing_params_df["Name"] == image_processing
267
+ ]
268
+
269
+ # Process parameters for each image processing technique and store in a dictionary
270
+ image_processings_params[image_processing] = utils.process_image_parameters(
271
+ params_df, image_processing
272
+ )
273
+
274
+ else:
275
+ # Use default settings if customization is not selected
276
+ image_processings_params[image_processing] = utils.get_default_params(
277
+ image_processing
278
+ )
279
+
280
+ # Check for errors in the selected parameters by applying them to a sample image
281
+ (
282
+ error_flag,
283
+ processed_first_image,
284
+ ) = image_processing_functions.apply_and_test_image_processing(
285
+ image_processing,
286
+ image_processings_params[image_processing],
287
+ uploaded_first_image,
288
+ input_mapping_dict[image_processing],
289
+ )
290
+
291
+ # If there is an error in the parameters, set the global error flag
292
+ if error_flag:
293
+ error_in_parameters = True
294
+ else:
295
+ # If no error, display the original and processed images side by side
296
+ # Display the original and processed images in their respective placeholders
297
+ with original_image_placeholder:
298
+ st.image(
299
+ uploaded_first_image,
300
+ caption="Original Image",
301
+ use_column_width=True,
302
+ clamp=True,
303
+ )
304
+ with processed_image_placeholder:
305
+ st.image(
306
+ processed_first_image,
307
+ caption="Processed Image",
308
+ use_column_width=True,
309
+ clamp=True,
310
+ )
311
+
312
+ # Update the base image with the previously processed image output
313
+ uploaded_first_image = processed_first_image
314
+
315
+
316
+ #######################################################################################################
317
+ # Display selected image processing technique parameters as DataFrame
318
+ #######################################################################################################
319
+
320
+
321
+ # Check if any image processing parameters have been defined
322
+ if (image_processings_params.keys()) and (not error_in_parameters):
323
+ # Create a dropdown for selecting an image processing technique or 'All'
324
+ selected_image_processing = st.selectbox(
325
+ "Select image processing technique",
326
+ options=["All"] + list(image_processings_params.keys()),
327
+ )
328
+ else:
329
+ selected_image_processing = None
330
+
331
+ # Create the DataFrame from the accumulated data
332
+ image_processings_df = image_processing_functions.create_image_processings_dataframe(
333
+ image_processings_params, image_processing_params_df
334
+ )
335
+ image_processings_df["Value"] = image_processings_df["Value"].astype(
336
+ str
337
+ ) # Ensure consistent data types and handle potential serialization issues
338
+
339
+ # Filter the DataFrame based on the selected image processing
340
+ if selected_image_processing != "All":
341
+ filtered_image_processings_df = image_processings_df[
342
+ image_processings_df["image_processing"] == selected_image_processing
343
+ ]
344
+ else:
345
+ filtered_image_processings_df = image_processings_df
346
+
347
+ # Check if the filtered dataframe is not empty and the selected configurations are valid
348
+ if (not filtered_image_processings_df.empty) and (not error_in_parameters):
349
+ # Display the DataFrame in Streamlit and use the full width of the container
350
+ st.dataframe(filtered_image_processings_df, use_container_width=False)
351
+
352
+ # Display code and description
353
+ code_placeholder = st.empty()
354
+
355
+
356
+ #######################################################################################################
357
+ # Process images and download processed images
358
+ #######################################################################################################
359
+
360
+
361
+ # Proceed if inputs are valid, techniques selected, and no errors in configurations
362
+ if (
363
+ st.session_state["is_valid"]
364
+ and (len(selected_image_processings) > 0)
365
+ and not error_in_parameters
366
+ ):
367
+ # Create two columns
368
+ col1, col2 = st.columns(2)
369
+
370
+ # Allow user to specify the number of variations to be generated
371
+ num_variations = col1.number_input(
372
+ "Set the number of variations to be generated",
373
+ min_value=1,
374
+ max_value=3,
375
+ step=1,
376
+ )
377
+
378
+ # Checkbox to include original images and labels in the output
379
+ with col2:
380
+ for top_padding in range(2): # Top padding
381
+ st.write("")
382
+
383
+ include_original = st.checkbox(
384
+ "Include original images and labels in output", value=False
385
+ )
386
+
387
+ # Display code and download once all inputs are available
388
+ with code_placeholder:
389
+ # Generate the code with the function
390
+ if len(st.session_state["label_files"]) == 0:
391
+ generated_code = utils.generate_python_code_images(
392
+ image_processings_params,
393
+ num_variations,
394
+ include_original,
395
+ )
396
+ else:
397
+ generated_code = (
398
+ image_processing_functions.generate_python_code_images_labels(
399
+ image_processings_params,
400
+ num_variations,
401
+ include_original,
402
+ )
403
+ )
404
+
405
+ # Display the generated Python code with a description and provide a download button in the Streamlit app
406
+ image_processing_functions.display_code_and_download_button(generated_code)
407
+
408
+ # Create two columns
409
+ col1, col2 = st.columns(2)
410
+
411
+ # Add a button for the user to confirm their selections and proceed with processing
412
+ if col1.button("Accept and Process", use_container_width=True):
413
+ # Call the function and store the results
414
+ image_processing_functions.process_images_and_labels(
415
+ st.session_state["image_files"],
416
+ st.session_state["label_files"],
417
+ selected_image_processings,
418
+ image_processings_params,
419
+ num_variations,
420
+ include_original,
421
+ )
422
+
423
+ # Download button
424
+ col2.download_button(
425
+ label="Download",
426
+ data=st.session_state["zip_data_processing"],
427
+ file_name="processed_images.zip",
428
+ mime="application/zip",
429
+ use_container_width=True,
430
+ disabled=False,
431
+ )
432
+
433
+
434
+ else:
435
+ if (len(selected_image_processings) == 0) and st.session_state["is_valid"]:
436
+ # Inform the user that no image processing techniques have been selected
437
+ st.warning("Please select at least one image processing technique.", icon="⚠️")
438
+
439
+ if error_in_parameters and st.session_state["is_valid"]:
440
+ # Inform the user that there are errors in parameters
441
+ st.warning(
442
+ "There are errors in the image processing parameters. Please review your selections.",
443
+ icon="⚠️",
444
+ )
445
+
446
+
447
+ #######################################################################################################
448
+ # Display original and processed images
449
+ #######################################################################################################
450
+
451
+
452
+ if (
453
+ "image_repository_preprocessing" in st.session_state
454
+ and "processed_image_mapping_procesing" in st.session_state
455
+ ):
456
+ # Number of unique images
457
+ num_unique_images = len(st.session_state["unique_images_names"])
458
+
459
+ if num_unique_images > 1:
460
+ # Create a slider to select an image index from the processed image mapping
461
+ selected_image_index = st.slider(
462
+ "Select an Image",
463
+ min_value=1,
464
+ max_value=num_unique_images, # Set the maximum to the number of unique images
465
+ step=1,
466
+ )
467
+ else:
468
+ selected_image_index = 1
469
+
470
+ # Retrieve the name of the selected original image using the slider index
471
+ selected_original_image_name = st.session_state["unique_images_names"][
472
+ selected_image_index - 1
473
+ ]
474
+
475
+ # Retrieve the names of all processed variants for the selected original image
476
+ processed_variant_names = st.session_state["processed_image_mapping_procesing"].get(
477
+ selected_original_image_name, []
478
+ )
479
+
480
+ # Combine the original image name with its processed variants
481
+ all_image_names = [selected_original_image_name] + processed_variant_names
482
+
483
+ # Number of images and columns
484
+ num_images = len(all_image_names)
485
+ num_columns = 4
486
+
487
+ # Display images in a grid of 4 columns and dynamic number of rows
488
+ for i in range(0, num_images, num_columns):
489
+ # Create a row of columns
490
+ cols = st.columns(num_columns)
491
+ for j in range(num_columns):
492
+ # Calculate the current image index
493
+ image_index = i + j
494
+ if image_index < num_images:
495
+ # Get the image name and data from the repository
496
+ image_name = all_image_names[image_index]
497
+ image_data = st.session_state["image_repository_preprocessing"][
498
+ image_name
499
+ ]["image"]
500
+
501
+ # Display the image in the respective column with caption
502
+ with cols[j]:
503
+ st.image(
504
+ image_data,
505
+ clamp=True,
506
+ caption=image_name,
507
+ use_column_width=True,
508
+ )
509
+
510
+
511
+ # if st.button("Run"):
512
+ # utils.button_click(on_click=None)
pages/3_Image_Augmentation.py ADDED
@@ -0,0 +1,612 @@
1
+ # Set the page config
2
+ import streamlit as st
3
+
4
+ st.set_page_config(
5
+ page_title="Image Augmentation",
6
+ page_icon=":open_file_folder:",
7
+ layout="wide",
8
+ initial_sidebar_state="collapsed",
9
+ )
10
+
11
+ # Importing necessary libraries
12
+ import cv2
13
+ import utils
14
+ import numpy as np
15
+ import Functions.image_augmentation_functions as augmentation_functions
16
+
17
+
18
+ # Load augmentation technique parameters and details from an Excel file
19
+ augmentation_params_df = utils.load_data_from_excel(
20
+ "packages_db.xlsx", "augmentation_parameters"
21
+ )
22
+ augmentation_details_df = utils.load_data_from_excel(
23
+ "packages_db.xlsx", "augmentation_details"
24
+ )
25
+
26
+ # Display the page title
27
+ st.title("Image Augmentation")
28
+
29
+ # # Clear the Streamlit session state on the first load of the page
30
+ # utils.clear_session_state_on_first_load("image_augmentation_clear")
31
+
32
+ # List of session state keys to initialize if they are not already present
33
+ session_state_keys = [
34
+ "file_uploader_key_augmentation",
35
+ "select_processing_technique_key_augmentation",
36
+ "selected_option_key_augmentation",
37
+ "class_labels_input_key_augmentation",
38
+ "bbox1_key",
39
+ "bbox2_key",
40
+ "bbox3_key",
41
+ "bbox4_key",
42
+ "bbox5_key",
43
+ ]
44
+
45
+ # Iterate through each session state key
46
+ for key in session_state_keys:
47
+ # Check if the key is not already in the session state
48
+ if key not in st.session_state:
49
+ # Initialize the key with a dictionary containing itself set to True
50
+ st.session_state[key] = {key: True}
51
+
52
+ # Initialize session state variables if not present
53
+ if "validation_triggered" not in st.session_state:
54
+ st.session_state["validation_triggered"] = False
55
+
56
+ if "uploaded_files_cache_augmentation" not in st.session_state:
57
+ st.session_state["uploaded_files_cache_augmentation"] = False
58
+
59
+ if "zip_data_augmentation" not in st.session_state:
60
+ st.session_state["zip_data_augmentation"] = ""
61
+
62
+ # Interface for uploading images and labels
63
+ utils.display_file_uploader(
64
+ "uploaded_files",
65
+ "Choose images and labels...",
66
+ st.session_state["file_uploader_key_augmentation"],
67
+ st.session_state["uploaded_files_cache_augmentation"],
68
+ )
69
+
70
+ # Dropdown for selecting label type
71
+ label_type = st.selectbox(
72
+ "Choose the label type for your augmentation process:",
73
+ ["Masks", "Bboxes"],
74
+ index=1,
75
+ on_change=utils.reset_validation_trigger,
76
+ key=st.session_state["selected_option_key_augmentation"],
77
+ )
78
+
79
+ # Choosing parameters based on the label type selected by the user
80
+ if label_type == "Bboxes":
81
+ # If the selected label type is Bboxes, call the bbox_params function
82
+ label_input_parameters = augmentation_functions.bbox_params()
83
+ elif label_type == "Masks":
84
+ # If the selected label type is Masks
85
+ label_input_parameters = None
86
+
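+ # (bbox_params() gathers the bounding-box label settings from the user; these presumably
+ # feed Albumentations-style BboxParams such as the box format and min_visibility.)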
87
+
88
+ # Text area for user to input class labels
89
+ class_labels_input = st.text_area(
90
+ "Enter class labels, separated by commas:",
91
+ utils.sample_class_labels,
92
+ on_change=utils.reset_validation_trigger,
93
+ key=st.session_state["class_labels_input_key_augmentation"],
94
+ ) # Example default values
95
+ class_labels_input = (
96
+ class_labels_input.strip()
97
+ ) # Remove unecessary space form start and end
98
+
99
+
100
+ # Generating a dictionary mapping class IDs to their respective labels
101
+ try:
102
+ class_labels = [
103
+ label.strip() for label in class_labels_input.split(",") if label.strip()
104
+ ]
105
+ class_dict = {
106
+ i + 1: label for i, label in enumerate(class_labels)
107
+ } # Shifting class labels (keys) by 1, since 0 is reserved for the background
108
+ # Invert the class_dict to map class names to class IDs
109
+ class_names_to_ids = {v: k for k, v in class_dict.items()}
110
+
111
+ colors = augmentation_functions.generate_unique_colors(class_dict.keys())
112
+ except Exception as e:
113
+ st.warning(
114
+ "Invalid format for class labels. Please enter labels separated by commas.",
115
+ icon="⚠️",
116
+ )
117
+ class_dict, class_names_to_ids = (
118
+ {},
119
+ {},
120
+ ) # Keeping class_dict and class_names_to_ids as an empty
121
+
122
+
123
+ # Note to users
124
+ st.markdown(
125
+ """
126
+ <div style='text-align: justify;'>
127
+ <b>Note to Users:</b>
128
+ <ul>
129
+ <li>The <i>first uploaded image</i> will be used for demonstration purposes and to validate parameters for augmentation techniques.</li>
130
+ <li>Uploading <i>labels is optional</i>. If no labels are uploaded, the output will consist solely of processed images.</li>
131
+ <li>When moving to another page or if you wish to upload a new set of images and labels, don't forget to hit the <b>Reset</b> button. This helps in faster computation and frees up unused memory, ensuring smoother operation.</li>
132
+ <li>Select the class labels, label type and label parameters before uploading large data for faster computation and more efficient processing.</li>
133
+ </ul>
134
+ </div>
135
+ """,
136
+ unsafe_allow_html=True,
137
+ )
138
+
139
+ # List of session state variables to initialize
140
+ session_vars = [
141
+ "is_valid",
142
+ "image_files",
143
+ "label_files",
144
+ "first_image_file",
145
+ "first_label_file",
146
+ ]
147
+
148
+ # Initialize each variable as None if it doesn't exist in the session state
149
+ for var in session_vars:
150
+ if var not in st.session_state:
151
+ st.session_state[var] = None
152
+
153
+ # Create two columns
154
+ col1, col2 = st.columns(2)
155
+
156
+ # Button to trigger validation
157
+ if (
158
+ col1.button("Validate Input", use_container_width=True)
159
+ and not st.session_state["validation_triggered"]
160
+ ):
161
+ st.session_state["validation_triggered"] = True
162
+ st.session_state["uploaded_files_cache_augmentation"] = True
163
+
164
+ (
165
+ st.session_state["is_valid"],
166
+ st.session_state["image_files"],
167
+ st.session_state["label_files"],
168
+ st.session_state["first_image_file"],
169
+ st.session_state["first_label_file"],
170
+ ) = augmentation_functions.check_valid_labels(
171
+ st.session_state["uploaded_files"], label_type, class_dict
172
+ )
173
+
174
+ elif st.session_state["validation_triggered"]:
175
+ pass
176
+
177
+ else:
178
+ st.session_state["is_valid"] = False
179
+ st.warning(
180
+ "Please upload images and labels and click **Validate Input**.", icon="⚠️"
181
+ )
182
+
183
+ with col2:
184
+ # Check if the 'Reset' button is pressed
185
+ if st.button("Reset", use_container_width=True):
186
+ # Toggle the keys for file uploader and processing technique to reset their states
187
+ current_value = st.session_state["file_uploader_key_augmentation"][
188
+ "file_uploader_key_augmentation"
189
+ ]
190
+ updated_value = not current_value # Invert the current value
191
+
192
+ # Iterate through each session state key
193
+ for session_state_key in session_state_keys:
194
+ # Update each key in the session state with the toggled value
195
+ st.session_state[session_state_key] = {session_state_key: updated_value}
196
+
+         # Clear all other session state keys except for widget_state_keys
+         for key in list(st.session_state.keys()):
+             if key not in session_state_keys:
+                 del st.session_state[key]
+
+         # Clear global variables except for protected names and the Streamlit module
+         global_vars = list(globals().keys())
+         vars_to_delete = [
+             var for var in global_vars if not var.startswith("_") and var != "st"
+         ]
+         for var in vars_to_delete:
+             del globals()[var]
+
+         # Clear the Streamlit caches
+         st.cache_resource.clear()
+         st.cache_data.clear()
+
+         # Rerun the app to reflect the reset state
+         st.rerun()
+
+ # Fetching the names of techniques applicable to the selected option
+ available_augmentations = augmentation_functions.get_applicable_techniques(
+     augmentation_details_df, label_type
+ )
+
+ # Mapping each image processing technique to its corresponding image types
+ input_mapping_dict = utils.technique_image_input_mapping(
+     available_augmentations, augmentation_details_df
+ )
+
+ # Present the option to select augmentation techniques only if the uploaded files are validated successfully
+ if st.session_state["is_valid"]:
+     selected_augmentations = st.multiselect(
+         "Select augmentation technique(s)",
+         available_augmentations,
+         key=st.session_state["select_processing_technique_key_augmentation"],
+     )
+
+     # Read the first uploaded image into a NumPy array
+     st.session_state["first_image_file"].seek(0)  # Reset file pointer to start
+     file_bytes_first_image = np.frombuffer(
+         st.session_state["first_image_file"].read(), dtype=np.uint8
+     )
+     uploaded_first_image = cv2.cvtColor(
+         cv2.imdecode(file_bytes_first_image, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB
+     )
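+
+     # cv2.imdecode returns the image in BGR channel order; it is converted to
+     # RGB above so that colors render correctly in Streamlit's st.image preview.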
243
+
+     # # Resize the image
+     # uploaded_first_image = cv2.resize(uploaded_first_image, (256, 256))
+
+ else:
+     # Reset selected techniques to empty if input validation fails
+     selected_augmentations = []
+
+
+ #######################################################################################################
+ # Build custom augmentation pipeline
+ #######################################################################################################
+
+
+ # Store parameters for each selected augmentation technique
+ augmentations_params = {}
+
+ # Initialize a flag to track if any error exists
+ error_in_parameters = False
+
+ # Loop through each selected augmentation technique to set up parameters
+ for augmentation in selected_augmentations:
+     with st.expander(f"{augmentation}"):
+         # Retrieve augmentation details from the database
+         augmentation_info = augmentation_details_df[
+             augmentation_details_df["Name"] == augmentation
+         ]
+
+         # Set up columns for displaying details and image placeholders
+         details_col, image_col = st.columns([7, 3])
+
+         with details_col:
+             # Display the description for the augmentation technique
+             augmentation_description = (
+                 augmentation_info["Description"].iloc[0]
+                 if not augmentation_info.empty
+                 else "No description available."
+             )
+             st.markdown(
+                 f"<div style='text-align: justify;'><b>Description:</b> {augmentation_description}</div>",
+                 unsafe_allow_html=True,
+             )
+
+             # Display the category for the augmentation
+             augmentation_category = (
+                 augmentation_info["Category"].iloc[0]
+                 if not augmentation_info.empty
+                 else "Unknown"
+             )
+             st.write("Category:", augmentation_category)
+
+             # Retrieve the source code link for the augmentation
+             augmentation_source_code = (
+                 augmentation_info["Source Code Link"].iloc[0]
+                 if not augmentation_info.empty
+                 else "www.google.com"
+             )
+
+             # Set up columns for displaying the source code button and custom settings checkbox
+             source_code_col, custom_setting_col = st.columns(2)
+
+             source_code_col.link_button("Source Code", augmentation_source_code)
+
+             # Toggle for custom settings
+             custom_settings = custom_setting_col.checkbox(
+                 f"Customize {augmentation}", key=f"toggle_{augmentation}"
+             )
+
+         with image_col:
+             # Create two columns
+             col1, col2 = st.columns(2)
+             original_image_placeholder = col1.container(height=150, border=False)
+             processed_image_placeholder = col2.container(height=150, border=False)
+
+         # Apply custom settings
+         if custom_settings:
+             # Retrieve parameters for the augmentation
+             params_df = augmentation_params_df[
+                 augmentation_params_df["Name"] == augmentation
+             ]
+
+             # Process parameters for each augmentation technique and store in a dictionary
+             augmentations_params[augmentation] = utils.process_image_parameters(
+                 params_df, augmentation
+             )
+
+         else:
+             # Use default settings if customization is not selected
+             augmentations_params[augmentation] = utils.get_default_params(augmentation)
+
+         # Check for errors in the selected parameters by applying them to a sample image
+         (
+             error_flag,
+             processed_first_image,
+         ) = augmentation_functions.apply_and_test_augmentation(
+             augmentation,
+             augmentations_params[augmentation],
+             uploaded_first_image,
+             st.session_state["first_label_file"],
+             label_type,
+             label_input_parameters,
+             input_mapping_dict[augmentation],
+         )
346
+
+         # If there is an error in the parameters, set the global error flag
+         if error_flag:
+             error_in_parameters = True
+         else:
+             # If no error, display the original and processed images side by side
+             # in their respective placeholders
+             with original_image_placeholder:
+                 st.image(
+                     uploaded_first_image,
+                     caption="Original Image",
+                     use_column_width=True,
+                     clamp=True,
+                 )
+
+             with processed_image_placeholder:
+                 st.image(
+                     processed_first_image,
+                     caption="Processed Image",
+                     use_column_width=True,
+                     clamp=True,
+                 )
+
+             # Update the base image with the previously processed image output
+             uploaded_first_image = processed_first_image
+
+
+ #######################################################################################################
+ # Display selected augmentation technique parameters as DataFrame
+ #######################################################################################################
+
+
+ # Check if any augmentations have been defined
+ if (augmentations_params.keys()) and (not error_in_parameters):
+     # Create a dropdown for selecting an augmentation technique or 'All'
+     selected_augmentation = st.selectbox(
+         "Select augmentation technique",
+         options=["All"] + list(augmentations_params.keys()),
+     )
+ else:
+     selected_augmentation = None
+
+ # Create the DataFrame from the accumulated data
+ augmentations_df = augmentation_functions.create_augmentations_dataframe(
+     augmentations_params, augmentation_params_df
+ )
+ augmentations_df["Value"] = augmentations_df["Value"].astype(
+     str
+ )  # Ensure consistent data types and handle potential serialization issues
+
+ # Filter the DataFrame based on the selected augmentation
+ if selected_augmentation != "All":
+     filtered_augmentations_df = augmentations_df[
+         augmentations_df["augmentation"] == selected_augmentation
+     ]
+ else:
+     filtered_augmentations_df = augmentations_df
+
+ # Check if the filtered dataframe is not empty and the selected configurations are valid
+ if (not filtered_augmentations_df.empty) and (not error_in_parameters):
+     # Display the DataFrame in Streamlit
+     st.dataframe(filtered_augmentations_df, use_container_width=False)
+
+     # Placeholder for the generated code and its description shown below
+     code_placeholder = st.empty()
+
412
+
+ #######################################################################################################
+ # Process images and download processed images
+ #######################################################################################################
+
+
+ # Proceed if inputs are valid, techniques are selected, and there are no errors in the configurations
+ if (
+     st.session_state["is_valid"]
+     and (len(selected_augmentations) > 0)
+     and not error_in_parameters
+ ):
+     # Create two columns
+     col1, col2 = st.columns(2)
+
+     # Allow the user to specify the number of variations to be generated
+     num_variations = col1.number_input(
+         "Set the number of variations to be generated",
+         min_value=1,
+         max_value=3,
+         step=1,
+     )
+
+     # Checkbox to include original images and labels in the output
+     with col2:
+         for top_padding in range(2):  # Top padding
+             st.write("")
+
+         include_original = st.checkbox(
+             "Include original images and labels in output", value=False
+         )
+
+     # Display code and download once all inputs are available
+     with code_placeholder:
+         # Generate the code with the appropriate function
+         if len(st.session_state["label_files"]) == 0:  # No labels uploaded
+             generated_code = utils.generate_python_code_images(
+                 augmentations_params,
+                 num_variations,
+                 include_original,
+             )
+         elif label_type == "Bboxes":  # Selected label type is Bboxes
+             generated_code = augmentation_functions.generate_python_code_bboxes(
+                 augmentations_params,
+                 label_input_parameters,
+                 num_variations,
+                 include_original,
+             )
+         elif label_type == "Masks":  # Selected label type is Masks
+             generated_code = augmentation_functions.generate_python_code_masks(
+                 augmentations_params,
+                 label_input_parameters,
+                 num_variations,
+                 include_original,
+             )
+
+         # Display the generated Python code with a description and provide a download button in the Streamlit app
+         augmentation_functions.display_code_and_download_button(generated_code)
+
+     # Create two columns
+     col1, col2 = st.columns(2)
+
+     # Add a button for the user to confirm their selections and proceed with processing
+     if col1.button("Accept and Process", use_container_width=True):
+         # Call the function and store the results
+         augmentation_functions.process_images_and_labels(
+             st.session_state["image_files"],
+             st.session_state["label_files"],
+             selected_augmentations,
+             augmentations_params,
+             label_type,
+             label_input_parameters,
+             num_variations,
+             include_original,
+             class_dict,
+         )
+
+     # Download button
+     col2.download_button(
+         label="Download",
+         data=st.session_state["zip_data_augmentation"],
+         file_name="augmented_images.zip",
+         mime="application/zip",
+         use_container_width=True,
+         disabled=False,
+     )
+
+ else:
+     if (len(selected_augmentations) == 0) and st.session_state["is_valid"]:
+         # Inform the user that no augmentation techniques have been selected
+         st.warning("Please select at least one augmentation technique.", icon="⚠️")
+
+     if error_in_parameters and st.session_state["is_valid"]:
+         # Inform the user that there are errors in the parameters
+         st.warning(
+             "There are errors in the augmentation parameters. Please review your selections.",
+             icon="⚠️",
+         )
510
+
+
+ #######################################################################################################
+ # Display original and processed images
+ #######################################################################################################
+
+
+ # Check if image_repository and processed_image_mapping exist in session_state
+ if (
+     "image_repository_augmentation" in st.session_state
+     and "processed_image_mapping_augmentation" in st.session_state
+ ):
+     # Number of unique images
+     num_unique_images = len(st.session_state["unique_images_names"])
+
+     if num_unique_images > 1:
+         # Create a slider to select an image index from the processed image mapping
+         selected_image_index = st.slider(
+             "Select an Image",
+             min_value=1,
+             max_value=num_unique_images,  # Set the maximum to the number of unique images
+             step=1,
+         )
+     else:
+         selected_image_index = 1
+
+     # Retrieve the name of the selected original image using the slider index
+     selected_original_image_name = st.session_state["unique_images_names"][
+         selected_image_index - 1
+     ]
+
+     # Retrieve the names of all processed variants for the selected original image
+     processed_variant_names = st.session_state[
+         "processed_image_mapping_augmentation"
+     ].get(selected_original_image_name, [])
+
+     # Combine the original image name with its processed variants
+     all_image_names = [selected_original_image_name] + processed_variant_names
+
+     if len(st.session_state["label_files"]) > 0:
+         # Options for displaying labels on the images
+         label_display_options = ["No Label", "All Labels", "Specific Labels"]
+
+         # Select box for the user to choose how labels should be displayed on the images
+         selected_label_display_option = st.selectbox(
+             "Choose how to display labels:",
+             label_display_options,
+             index=0,  # Default option is 'No Label'
+         )
+
+         # If 'All Labels' option is selected, include all class IDs
+         if selected_label_display_option == "All Labels":
+             labels_to_plot = list(class_dict.keys())
+         # If 'Specific Labels' option is selected, allow user to select specific class IDs
+         elif selected_label_display_option == "Specific Labels":
+             selected_class_names = st.multiselect(
+                 "Select specific labels to display",
+                 list(class_names_to_ids.keys()),
+                 class_dict[1],
+             )
+             labels_to_plot = [class_names_to_ids[name] for name in selected_class_names]
+     else:
+         selected_label_display_option = "No Label"
+
+     # Display images in a grid
+     num_images = len(all_image_names)
+     num_columns = 4
+     for i in range(0, num_images, num_columns):
+         cols = st.columns(num_columns)
+         for j in range(num_columns):
+             image_index = i + j
+             if image_index < num_images:
+                 image_name = all_image_names[image_index]
+                 image_data = st.session_state["image_repository_augmentation"][
+                     image_name
+                 ]["image"]
+                 label_file = st.session_state["image_repository_augmentation"][
+                     image_name
+                 ]["label"]
+
+                 # Overlay labels on the image based on the selected option
+                 if selected_label_display_option in ["All Labels", "Specific Labels"]:
+                     # Overlay labels if selected
+                     modified_image = augmentation_functions.overlay_labels(
+                         image=image_data.copy(),
+                         labels_to_plot=labels_to_plot,
+                         label_file=label_file,
+                         label_type=label_type,
+                         colors=colors,
+                         class_dict=class_dict,
+                     )
+                 else:
+                     # Use the original image without overlay if 'No Label' is selected
+                     modified_image = image_data
+
+                 # Display the image in the respective column with a caption
+                 with cols[j]:
+                     st.image(
+                         modified_image,
+                         clamp=True,
+                         caption=image_name,
+                         use_column_width=True,
+                     )
pages/4_Model_Training.py ADDED
@@ -0,0 +1,364 @@
+ # Set the page config
+ import streamlit as st
+
+ st.set_page_config(
+     page_title="Model_Training",
+     page_icon=":open_file_folder:",
+     layout="wide",
+     initial_sidebar_state="collapsed",
+ )
+
+ # Importing necessary libraries
+ import utils
+ import streamlit as st
+ import Functions.model_training_functions as model_training_functions
+
+
+ # Display the page title
+ st.title("Model Training")
+
+ # # Clear the Streamlit session state on the first load of the page
+ # utils.clear_session_state_on_first_load("model_training_clear")
+
+ # List of session state keys to initialize if they are not already present
+ session_state_keys = [
+     "file_uploader_split_key_training",
+     "file_uploader_train_key_training",
+     "file_uploader_val_key_training",
+     "file_uploader_test_key_training",
+     "number_input_train_key",
+     "number_input_val_key",
+     "number_input_test_key",
+     "split_method_key",
+     "training_type_key",
+     "class_labels_input_key_training",
+ ]
+
+ # Iterate through each session state key
+ for key in session_state_keys:
+     # Check if the key is not already in the session state
+     if key not in st.session_state:
+         # Initialize the key with a dictionary containing itself set to True
+         st.session_state[key] = {key: True}
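+
+ # Each widget key is wrapped in a dictionary so the Reset button below can flip
+ # its boolean value; giving a widget a changed key value forces Streamlit to
+ # rebuild (and thus clear) that widget on the next rerun.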
43
+
+ # Initialize session state variables if not present
+ if "validation_triggered" not in st.session_state:
+     st.session_state["validation_triggered"] = False
+
+ if "uploaded_files_cache_processing" not in st.session_state:
+     st.session_state["uploaded_files_cache_processing"] = False
+
+ # Track the overall validity of the uploaded input
+ if "is_valid" not in st.session_state:
+     st.session_state["is_valid"] = False
+
+ # Container for file uploaders
+ file_uploader_container = st.container()
+
+ # Dictionary for mapping the user-friendly terms to technical label types
+ label_type_mapping = {"Object Detection": "Bboxes", "Instance Segmentation": "Masks"}
+
+ # Create two columns for widgets
+ column_select_training, column_split_method = st.columns(2)
+
+ # Dropdown for selecting the training type
+ with column_select_training:
+     selected_training = st.selectbox(
+         "Select the training type:",
+         list(label_type_mapping.keys()),
+         index=0,
+         on_change=utils.reset_validation_trigger,
+         key=st.session_state["training_type_key"],
+     )
+
+ # Getting the corresponding label type
+ label_type = label_type_mapping[selected_training]
+
+ # Toggle for choosing the split method
+ with column_split_method:
+     split_method = st.radio(
+         "Select the dataset split method:",
+         ["Percentage Split", "Direct Upload"],
+         horizontal=True,
+         on_change=utils.reset_validation_trigger,
+         key=st.session_state["split_method_key"],
+     )
+
+ # Text area for the user to input class labels
+ class_labels_input = st.text_area(
+     "Enter class labels, separated by commas:",
+     utils.sample_class_labels,
+     on_change=utils.reset_validation_trigger,
+     key=st.session_state["class_labels_input_key_training"],
+ )  # Example default values
+ class_labels_input = (
+     class_labels_input.strip()
+ )  # Remove unnecessary spaces from the start and end
+
+ # Generating a dictionary mapping class IDs to their respective labels
+ try:
+     class_labels = [
+         label.strip() for label in class_labels_input.split(",") if label.strip()
+     ]
+     class_dict = {i: label for i, label in enumerate(class_labels)}
+     # Invert the class_dict to map class names to class IDs
+     class_names_to_ids = {v: k for k, v in class_dict.items()}
+
+ except Exception as e:
+     st.warning(
+         "Invalid format for class labels. Please enter labels separated by commas.",
+         icon="⚠️",
+     )
+     class_dict, class_names_to_ids = (
+         {},
+         {},
+     )  # Keep class_dict and class_names_to_ids as empty dictionaries
+
+ # Note to users
+ st.markdown(
+     """
+     <div style='text-align: justify;'>
+     <b>Note to Users:</b>
+     <ul>
+         <li>When moving to another page or if you wish to upload a new set of images and labels, don't forget to hit the <b>Reset</b> button. This helps in faster computation and frees up unused memory, ensuring smoother operation.</li>
+         <li>Select the training type, class labels, dataset split method, and their parameters before uploading large data for faster computation and more efficient processing.</li>
+     </ul>
+     </div>
+     """,
+     unsafe_allow_html=True,
+ )
+
+ # Create two columns for the Validate and Reset buttons
+ validate_button_col, reset_button_col = st.columns(2)
133
+
+ with reset_button_col:
+     # Check if the 'Reset' button is pressed
+     if st.button("Reset", use_container_width=True):
+         # Clear folders
+         model_training_functions.delete_and_recreate_folder(
+             model_training_functions.get_path("output")
+         )
+         model_training_functions.clear_data_folders()
+
+         # List of session state keys that need to be reset
+         session_state_keys = [
+             "file_uploader_split_key_training",
+             "file_uploader_train_key_training",
+             "file_uploader_val_key_training",
+             "file_uploader_test_key_training",
+             "number_input_train_key",
+             "number_input_val_key",
+             "number_input_test_key",
+             "split_method_key",
+             "training_type_key",
+             "class_labels_input_key_training",
+         ]
+
+         # Iterate through each session state key
+         for session_state_key in session_state_keys:
+             # Toggle the keys to reset their states
+             current_value = st.session_state[session_state_key][session_state_key]
+             updated_value = not current_value  # Invert the current value
+
+             # Update each key in the session state with the toggled value
+             st.session_state[session_state_key] = {session_state_key: updated_value}
+
+         # Clear all other session state keys except for widget_state_keys
+         for key in list(st.session_state.keys()):
+             if key not in session_state_keys:
+                 del st.session_state[key]
+
+         # Clear global variables except for protected names and the Streamlit module
+         global_vars = list(globals().keys())
+         vars_to_delete = [
+             var for var in global_vars if not var.startswith("_") and var != "st"
+         ]
+         for var in vars_to_delete:
+             del globals()[var]
+
+         # Clear the Streamlit caches
+         st.cache_resource.clear()
+         st.cache_data.clear()
+
+         # Rerun the app to reflect the reset state
+         st.rerun()
+
+ # Code for the "Percentage Split" method
+ if split_method == "Percentage Split":
+     with file_uploader_container:
+         # User uploads images and labels
+         utils.display_file_uploader(
+             "uploaded_files",
+             "Choose images and labels...",
+             st.session_state["file_uploader_split_key_training"],
+             st.session_state["uploaded_files_cache_processing"],
+         )
+
+     # Create three columns for input percentages
+     col1, col2, col3 = st.columns(3)
+
+     # User specifies split percentages (note: the test and validation widget keys
+     # were originally swapped; they are matched to their widgets here)
+     train_pct = col1.number_input(
+         "Train Set Percentage",
+         0,
+         100,
+         70,
+         1,
+         on_change=utils.reset_validation_trigger,
+         key=st.session_state["number_input_train_key"],
+     )
+     test_pct = col2.number_input(
+         "Test Set Percentage",
+         0,
+         100,
+         15,
+         1,
+         on_change=utils.reset_validation_trigger,
+         key=st.session_state["number_input_test_key"],
+     )
+     val_pct = col3.number_input(
+         "Validation Set Percentage",
+         0,
+         100,
+         15,
+         1,
+         on_change=utils.reset_validation_trigger,
+         key=st.session_state["number_input_val_key"],
+     )
+
+     # Check if the total percentage equals 100%
+     pct_check = train_pct + test_pct + val_pct
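+     # Worked example with the defaults: 70 (train) + 15 (test) + 15 (val) = 100,
+     # so the check below passes; 70/20/15, by contrast, sums to 105 and fails.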
231
+
+     # Validating the input percentages
+     pct_condition_check = (
+         pct_check == 100
+         and train_pct > 0
+         and val_pct > 0
+         and model_training_functions.check_min_images(
+             len(st.session_state["uploaded_files"]), train_pct, val_pct, test_pct
+         )
+     )
+
+     if not pct_condition_check:
+         file_uploader_container.warning(
+             "The percentages for the train, test, and validation sets should add up to 100%, and the train and validation sets should not be empty.",
+             icon="⚠️",
+         )
+
+     # Button to trigger validation
+     if validate_button_col.button("Validate Input", use_container_width=True):
+         st.session_state["validation_triggered"] = True
+
+         st.session_state["is_valid"] = model_training_functions.check_valid_labels(
+             st.session_state["uploaded_files"], label_type, class_dict
+         )
+
+         if st.session_state["is_valid"]:
+             model_training_functions.create_yolo_config_file(
+                 model_training_functions.get_path("config"),
+                 class_labels,
+             )
+             model_training_functions.clear_data_folders()
+             paired_files = model_training_functions.pair_files(
+                 st.session_state["uploaded_files"]
+             )
+             model_training_functions.split_and_save_files(
+                 paired_files, train_pct, test_pct
+             )
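+             # Note: split_and_save_files receives only the train and test
+             # percentages; the validation share is presumably inferred as the
+             # remainder (100 - train_pct - test_pct) inside the helper.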
268
+
+     # Process files if the input is valid
+     if st.session_state["validation_triggered"] and (
+         pct_condition_check and st.session_state["is_valid"]
+     ):
+         model_training_functions.start_yolo_training(selected_training, class_labels)
+     else:
+         # Display a warning message if the validation is not successful or conditions are not met
+         st.warning(
+             "Please upload valid input, select valid parameters, and click **Validate Input**.",
+             icon="⚠️",
+         )
+
+ # Code for the "Direct Upload" method
+ elif split_method == "Direct Upload":
+     with file_uploader_container:
+         # Create three columns for uploading train, val, and test files
+         col1, col2, col3 = st.columns(3)
+
+         with col1:
+             utils.display_file_uploader(
+                 "uploaded_train_files",
+                 "Upload Training Images and Labels",
+                 st.session_state["file_uploader_train_key_training"],
+                 st.session_state["uploaded_files_cache_processing"],
+             )
+
+         with col2:
+             utils.display_file_uploader(
+                 "uploaded_val_files",
+                 "Upload Validation Images and Labels",
+                 st.session_state["file_uploader_val_key_training"],
+                 st.session_state["uploaded_files_cache_processing"],
+             )
+
+         with col3:
+             utils.display_file_uploader(
+                 "uploaded_test_files",
+                 "Upload Test Images and Labels",
+                 st.session_state["file_uploader_test_key_training"],
+                 st.session_state["uploaded_files_cache_processing"],
+             )
+
+     # Check for valid input
+     pct_condition_check = (
+         len(st.session_state["uploaded_train_files"]) > 0
+         and len(st.session_state["uploaded_val_files"]) > 0
+     )
+
+     if not pct_condition_check:
+         file_uploader_container.warning(
+             "The train and validation sets should not be empty.",
+             icon="⚠️",
+         )
+
+     # Button to trigger validation
+     if validate_button_col.button("Validate Input", use_container_width=True):
+         st.session_state["validation_triggered"] = True
+
+         st.session_state["is_valid"] = model_training_functions.check_valid_labels(
+             st.session_state["uploaded_train_files"]
+             + st.session_state["uploaded_val_files"]
+             + st.session_state["uploaded_test_files"],
+             label_type,
+             class_dict,
+         )
+
+         if st.session_state["is_valid"]:
+             model_training_functions.create_yolo_config_file(
+                 model_training_functions.get_path("config"),
+                 class_labels,
+             )
+             model_training_functions.clear_data_folders()
+             model_training_functions.save_files_to_folder(
+                 st.session_state["uploaded_train_files"], "train"
+             )
+             model_training_functions.save_files_to_folder(
+                 st.session_state["uploaded_val_files"], "val"
+             )
+
+             # Only save test files if they are uploaded
+             if len(st.session_state["uploaded_test_files"]) > 0:
+                 model_training_functions.save_files_to_folder(
+                     st.session_state["uploaded_test_files"], "test"
+                 )
+
+     # Process files if the input is valid
+     if st.session_state["validation_triggered"] and (
+         pct_condition_check and st.session_state["is_valid"]
+     ):
+         model_training_functions.start_yolo_training(selected_training, class_labels)
+     else:
+         # Display a warning message if the validation is not successful or conditions are not met
+         st.warning(
+             "Please upload valid input, select valid parameters, and click **Validate Input**.",
+             icon="⚠️",
+         )
requirements.txt ADDED
Binary file (2.99 kB). View file
 
utils.py ADDED
@@ -0,0 +1,446 @@
+ # Importing necessary libraries
+ import cv2
+ import utils
+ import inspect
+ import numpy as np
+ import pandas as pd
+ import streamlit as st
+ import albumentations as A
+
+ # Please retain all the imported libraries, even if they appear to be unused, as they are necessary for the input part of the code (dynamic_input function)
+
+
+ # Sample class labels
+ sample_class_labels = "person, bicycle, car, motorcycle, airplane, bus, train, truck, boat, traffic light, fire hydrant, stop sign, parking meter, bench, bird, cat, dog, horse, sheep, cow, elephant, bear, zebra, giraffe, backpack, umbrella, handbag, tie, suitcase, frisbee, skis, snowboard, sports ball, kite, baseball bat, baseball glove, skateboard, surfboard, tennis racket, bottle, wine glass, cup, fork, knife, spoon, bowl, banana, apple, sandwich, orange, broccoli, carrot, hot dog, pizza, donut, cake, chair, couch, potted plant, bed, dining table, toilet, tv, laptop, mouse, remote, keyboard, cell phone, microwave, oven, toaster, sink, refrigerator, book, clock, vase, scissors, teddy bear, hair drier, toothbrush"
+
+
+ # Model info
+ models_info = {
+     "YOLOv8n": {
+         "size": 640,
+         "mAPval": 37.3,
+         "speed_cpu": 80.4,
+         "speed_gpu": 0.99,
+         "params": 3.2,
+         "flops": 8.7,
+     },
+     "YOLOv8s": {
+         "size": 640,
+         "mAPval": 44.9,
+         "speed_cpu": 128.4,
+         "speed_gpu": 1.20,
+         "params": 11.2,
+         "flops": 28.6,
+     },
+     "YOLOv8m": {
+         "size": 640,
+         "mAPval": 50.2,
+         "speed_cpu": 234.7,
+         "speed_gpu": 1.83,
+         "params": 25.9,
+         "flops": 78.9,
+     },
+     "YOLOv8l": {
+         "size": 640,
+         "mAPval": 52.9,
+         "speed_cpu": 375.2,
+         "speed_gpu": 2.39,
+         "params": 43.7,
+         "flops": 165.2,
+     },
+     "YOLOv8x": {
+         "size": 640,
+         "mAPval": 53.9,
+         "speed_cpu": 479.1,
+         "speed_gpu": 3.53,
+         "params": 68.2,
+         "flops": 257.8,
+     },
+ }
+
+
+ # Export formats info
+ export_formats = {
+     "TorchScript": {"format_argument": "torchscript", "arguments": ["optimize"]},
+     "ONNX": {
+         "format_argument": "onnx",
+         "arguments": ["half", "dynamic", "simplify", "opset"],
+     },
+     "OpenVINO": {"format_argument": "openvino", "arguments": ["half", "int8"]},
+     "TensorRT": {
+         "format_argument": "engine",
+         "arguments": ["half", "dynamic", "simplify", "workspace"],
+     },
+     "CoreML": {"format_argument": "coreml", "arguments": ["half", "int8", "nms"]},
+     "TF SavedModel": {"format_argument": "saved_model", "arguments": ["keras", "int8"]},
+     "TF GraphDef": {"format_argument": "pb", "arguments": []},
+     "TF Lite": {"format_argument": "tflite", "arguments": ["half", "int8"]},
+     "TF Edge TPU": {"format_argument": "edgetpu", "arguments": []},
+     "TF.js": {"format_argument": "tfjs", "arguments": ["half", "int8"]},
+     "PaddlePaddle": {"format_argument": "paddle", "arguments": []},
+     "ncnn": {"format_argument": "ncnn", "arguments": ["half"]},
+ }
+
+
+ # Function to add top padding
+ def top_padding(padding_amount=0):
+     for i in range(padding_amount):
+         st.write("")
89
+
+
+ # Function to load data from an Excel file
+ @st.cache_resource(show_spinner=False)
+ def load_data_from_excel(filename, sheet_name):
+     """Load data from a specified sheet in an Excel file."""
+     return pd.read_excel(filename, sheet_name=sheet_name)
+
+
+ # Function to reset the validation trigger in the session state
+ def reset_validation_trigger():
+     st.session_state["validation_triggered"] = False
+
+
+ # Function to clear the Streamlit session state on the first load of the page
+ def clear_session_state_on_first_load(key):
+     # Check if the function has already been executed
+     if key not in st.session_state:
+         # Clear the session state
+         for k in list(st.session_state.keys()):
+             if k != key:  # Preserve the key itself
+                 del st.session_state[k]
+         # Mark this function as executed
+         st.session_state[key] = True
+
+
+ # Function to display a file uploader based on the cached state
+ def display_file_uploader(uploaded_files_type, label_value, key_value, is_cached):
+     if not is_cached:
+         st.session_state[uploaded_files_type] = st.file_uploader(
+             label=label_value,
+             type=["jpg", "png", "txt"],
+             accept_multiple_files=True,
+             on_change=utils.reset_validation_trigger,
+             key=key_value,
+         )
+
+     else:
+         st.file_uploader(
+             label_value,
+             type=["jpg", "png", "txt"],
+             accept_multiple_files=True,
+             disabled=True,
+         )
+
+
+ # Function to map each technique to its corresponding image types
+ @st.cache_resource(show_spinner=False)
+ def technique_image_input_mapping(available_techniques, details_df):
+     input_mapping_dict = {}
+     for technique in available_techniques:
+         # Filter the DataFrame for rows where the Name matches the technique
+         filtered_df = details_df[details_df["Name"] == technique]
+
+         # Extract the Image Types for these rows, split by comma, and strip spaces
+         image_types = (
+             filtered_df["Image Types"]
+             .apply(lambda x: [t.strip() for t in str(x).split(",")])
+             .tolist()
+         )
+
+         # Add to the dictionary
+         input_mapping_dict[technique] = image_types[0]
+
+     return input_mapping_dict
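+
+ # Example of the resulting mapping (illustrative names and values only):
+ # {"Rotate": ["Any"], "Equalize": ["3-channel uint8 images only"]}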
154
+
+
+ # Function to check if the image type matches the allowed types
+ def is_image_type_allowed(input_image_type, num_channels, allowed_image_types):
+     for allowed_type in allowed_image_types:
+         # Handle general types like 'Any' or 'nan'
+         if allowed_type in ["Any", "nan", input_image_type]:
+             return True
+
+         # Handle specific multi-channel types (e.g., '3-channel uint8 images only')
+         if "channel" in allowed_type:
+             parts = allowed_type.split()
+
+             try:
+                 allowed_channels = int(parts[0][0])
+                 allowed_dtype = parts[1]
+
+                 if (
+                     num_channels == allowed_channels
+                     and input_image_type == allowed_dtype
+                 ):
+                     return True
+             except ValueError:
+                 # Continue to the next allowed type
+                 continue
+
+         # Add additional specific checks if needed
+
+     return False  # If none of the conditions are met
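+
+ # Example: is_image_type_allowed("uint8", 3, ["3-channel uint8 images only"])
+ # returns True, and is_image_type_allowed("float32", 1, ["Any"]) also returns
+ # True because "Any" matches every input type.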
183
+
+
+ # Function to create a dynamic input widget based on the specified data type
+ def dynamic_input(
+     image_processing, data_type, label, input_key, input_data, default_val
+ ):
+     if data_type == "bool":
+         # Create a checkbox for bool input
+         return st.checkbox(label=label, value=eval(default_val), key=input_key)
+
+     elif data_type == "int" or data_type == "float":
+         # Create a slider for integer / float input
+         return st.select_slider(
+             label=label,
+             options=eval(input_data),
+             value=eval(default_val),
+             key=input_key,
+         )
+
+     elif data_type == "tuple":
+         # Create a slider for tuple input
+         min_val, max_val = st.select_slider(
+             label=label,
+             options=eval(input_data),
+             value=(eval(default_val)[0], eval(default_val)[1]),
+             key=input_key,
+         )
+         return (min_val, max_val)
+
+     elif data_type == "single_select":
+         # Create a selectbox for single select input
+         return st.selectbox(
+             label=label,
+             options=eval(input_data),
+             index=eval(input_data).index(eval(default_val)),
+         )
+
+     elif data_type == "none":
+         # Return None without creating any input widget
+         return None
+
+     elif data_type == "code":
+         # Text input whose contents are evaluated as Python code
+         try:
+             # Evaluate the code entered by the user in the text input
+             input_code = eval(st.text_input(label=label, value=default_val))
+         except Exception:
+             # If the user input is not valid Python code, catch the exception
+             st.warning(f"Invalid input for {label}. Using default value.", icon="⚠️")
+             # Return the default parameter value for this image processing technique
+             input_code = get_default_params(image_processing)[label]
+
+         # Show an example of the expected input format to the user
+         st.write(f"Sample Input for {label}:")
+         st.code(input_data, language="python", line_numbers=False)
+
+         return input_code
+
+     # Handle unsupported data types
+     else:
+         # Display an error message and return None for unsupported data types
+         st.error(f"Unsupported data type: {data_type}")
245
+
+
+ # Returns a dictionary of default parameters for a given Albumentations augmentation class
+ def get_default_params(augmentation_name):
+     # Retrieve the augmentation class from the Albumentations module
+     augmentation_class = getattr(A, augmentation_name, None)
+
+     # Inspect the constructor (__init__) of the augmentation class to get its parameters
+     params = inspect.signature(augmentation_class.__init__).parameters
+
+     # Create a dictionary of parameter names and their default values
+     default_params = {
+         name: param.default
+         for name, param in params.items()
+         if param.default is not inspect.Parameter.empty
+     }
+
+     return default_params
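+
+ # Example (hypothetical output): get_default_params("HorizontalFlip") may return
+ # {"always_apply": False, "p": 0.5}, depending on the installed Albumentations version.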
263
+
+
+ # Function to resize the image to a maximum dimension of 512 pixels while maintaining aspect ratio
+ def resize_image(image, max_size=512):
+     height, width = image.shape[:2]
+     # Calculate the ratio to maintain aspect ratio
+     ratio = min(max_size / height, max_size / width)
+     new_dimensions = (int(width * ratio), int(height * ratio))
+     resized_image = cv2.resize(image, new_dimensions, interpolation=cv2.INTER_AREA)
+     return resized_image
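+
+ # Example: a 1024x768 (width x height) image gets ratio min(512/768, 512/1024) = 0.5,
+ # so it is resized to 512x384; smaller images are scaled up to fit the 512px bound.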
273
+
+
+ # Function to apply the specified transformation with optional custom parameters
+ def apply_albumentation(custom_params, albumentation_name):
+     # Get the Albumentations transformation class from the albumentation_name string
+     albumentation_class = getattr(A, albumentation_name, None)
+
+     # Check if the transformation class exists in the Albumentations library
+     if albumentation_class is not None:
+         # Apply the transformation with or without custom parameters
+         if custom_params is None:
+             return albumentation_class()
+         else:
+             return albumentation_class(**custom_params)
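+
+ # Example: apply_albumentation({"p": 1.0}, "HorizontalFlip") builds A.HorizontalFlip(p=1.0),
+ # while apply_albumentation(None, "HorizontalFlip") uses the library defaults. Note that
+ # the function implicitly returns None when the named class does not exist.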
287
+
+
+ # Function to process image parameters from a DataFrame and create an interactive UI
+ def process_image_parameters(params_df, image_processing):
+     # Dictionary to map UI input types to database values
+     type_mapping = {
+         "Boolean": "bool",
+         "Integer": "int",
+         "Float": "float",
+         "Tuple": "tuple",
+         "Single Select": "single_select",
+         "Code": "code",
+         "None": "none",
+     }
+
+     # Create a reverse mapping from database values to UI input types
+     reverse_type_mapping = {v: k for k, v in type_mapping.items()}
+
+     # Initialize a dictionary to store user-selected parameter values
+     image_processing_params = {}
+
+     if not params_df.empty:
+         for _, row in params_df.iterrows():
+             # Spacer
+             st.markdown("---")
+
+             param_name = row["Parameter Name"]
+             input_types = [
+                 input_type.strip() for input_type in str(row["Input Type"]).split(",")
+             ]
+
+             if len(input_types) > 1:
+                 col1, col2 = st.columns([3, 7])
+                 selected_type_ui = col1.selectbox(
+                     "Choose input type",
+                     [reverse_type_mapping[ip] for ip in input_types],
+                     key=f"{image_processing}_{param_name}_selectbox",
+                 )
+                 # Map the selected UI input type to its database value
+                 selected_type = type_mapping[selected_type_ui]
+                 selected_index = input_types.index(selected_type)
+
+                 with col2:
+                     user_input = utils.dynamic_input(
+                         image_processing=image_processing,
+                         data_type=selected_type,
+                         label=param_name,
+                         input_data=str(row[f"Input Values {selected_index+1}"]),
+                         default_val=str(row[f"Default Value {selected_index+1}"]),
+                         input_key=f"{image_processing}_{param_name}_{selected_type}",
+                     )
+             else:
+                 user_input = utils.dynamic_input(
+                     image_processing=image_processing,
+                     data_type=input_types[0],
+                     label=param_name,
+                     input_data=str(row["Input Values 1"]),
+                     default_val=str(row["Default Value 1"]),
+                     input_key=f"{image_processing}_{param_name}_{input_types[0]}",
+                 )
+
+             st.markdown(
+                 f"<div style='text-align: justify;'><b>Description:</b> {row['Parameter Description']}</div>",
+                 unsafe_allow_html=True,
+             )
+
+             image_processing_params[param_name] = user_input
+
+         # Top padding
+         top_padding(2)
+
+     return image_processing_params
359
+
+
+ # Function to generate a DataFrame detailing image processing technique parameters and descriptions
+ @st.cache_resource(show_spinner=False)
+ def create_image_processings_dataframe(
+     image_processings_params, image_processing_params_db
+ ):
+     data = []
+     for aug_name, params in image_processings_params.items():
+         for param_name, param_value in params.items():
+             # Retrieve relevant image_processing information from the database
+             image_processing_info = image_processing_params_db[
+                 image_processing_params_db["Name"] == aug_name
+             ]
+             param_info = image_processing_info[
+                 image_processing_info["Parameter Name"] == param_name
+             ]
+
+             # Check if the parameter information exists in the database
+             if not param_info.empty:
+                 # Get the description of the current parameter
+                 param_description = param_info["Parameter Description"].iloc[0]
+             else:
+                 param_description = "Description not available"
+
+             # Append the image_processing name, parameter name, its value, and description to the data list
+             data.append([aug_name, param_name, param_value, param_description])
+
+     # Create the DataFrame from the accumulated data
+     image_processings_df = pd.DataFrame(
+         data, columns=["image_processing", "Parameter", "Value", "Description"]
+     )
+     return image_processings_df
392
+
+
+ def generate_python_code_images(
+     augmentations_params,
+     num_variations=1,
+     include_original=False,
+ ):
+     # Start with the necessary library imports
+     code_str = "# Importing necessary libraries\n"
+     code_str += "import os\nimport cv2\nimport shutil\nimport albumentations as A\n\n"
+
+     # Paths for input and output directories
+     code_str += "# Define the paths for input and output directories\n"
+     code_str += "input_directory = 'path/to/input'\n"
+     code_str += "output_directory = 'path/to/output'\n\n"
+
+     # Function to create an augmentation pipeline
+     code_str += "# Function to create an augmentation pipeline using Albumentations\n"
+     code_str += "def process_image(image):\n"
+     code_str += "    # Define the sequence of augmentation techniques\n"
+     code_str += "    pipeline = A.Compose([\n"
+     for technique, params in augmentations_params.items():
+         code_str += f"        A.{technique}({', '.join(f'{k}={v}' for k, v in params.items())}),\n"
+     code_str += "    ])\n"
+     code_str += "    # Apply the augmentation pipeline\n"
+     code_str += "    return pipeline(image=image)['image']\n\n"
+
+     # Function to process a batch of images
+     code_str += "# Function to process a batch of images\n"
+     code_str += "def process_batch(input_directory, output_directory, include_original=False, num_variations=1):\n"
+     code_str += "    for filename in os.listdir(input_directory):\n"
+     code_str += "        if filename.lower().endswith(('.png', '.jpg', '.jpeg')):\n"
+     code_str += "            image_path = os.path.join(input_directory, filename)\n\n"
+     code_str += "            # Read the image\n"
+     code_str += "            image = cv2.imread(image_path)\n\n"
+
+     # Include the original image only when requested
+     if include_original:
+         code_str += "            # Include original image\n"
+         code_str += "            shutil.copy2(image_path, output_directory)\n\n"
+
+     code_str += "            # Generate variations for each image and process them\n"
+     code_str += f"            for variation in range({num_variations}):\n"
+     code_str += "                processed_image = process_image(image)\n\n"
+     code_str += "                # Save the processed image\n"
+     code_str += "                output_filename = f'processed_{os.path.splitext(filename)[0]}_{variation}{os.path.splitext(filename)[1]}'\n"
+     code_str += "                cv2.imwrite(os.path.join(output_directory, output_filename), processed_image)\n\n"
+
+     # Execute the batch processing function
+     code_str += (
+         "# Execute the batch processing function with the specified parameters\n"
+     )
+     code_str += "process_batch(input_directory, output_directory)\n"
+
+     return code_str
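+
+ # Example usage (hypothetical parameters): generate_python_code_images(
+ #     {"HorizontalFlip": {"p": 0.5}}, num_variations=2, include_original=True
+ # ) returns a standalone script string that copies each original image and
+ # writes two augmented variants per image found in the input directory.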