File size: 13,438 Bytes
3d90a2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
# Set the page config
import streamlit as st

# Browser-tab title and icon, full-width layout, and a sidebar that starts
# collapsed.
st.set_page_config(
    layout="wide",
    initial_sidebar_state="collapsed",
    page_title="Model_Training",
    page_icon=":open_file_folder:",
)

# Importing necessary libraries
import utils
import streamlit as st
import Functions.model_training_functions as model_training_functions


# Page heading
st.title("Model Training")

# # Clear the Streamlit session state on the first load of the page
# utils.clear_session_state_on_first_load("model_training_clear")

# Session-state entries whose values are passed as widget ``key``s below.
# Each is seeded with a one-entry dict ``{name: True}``; the Reset handler
# later flips that boolean so the widgets receive new keys and start fresh.
session_state_keys = [
    "file_uploader_split_key_training",
    "file_uploader_train_key_training",
    "file_uploader_val_key_training",
    "file_uploader_test_key_training",
    "number_input_train_key",
    "number_input_val_key",
    "number_input_test_key",
    "split_method_key",
    "training_type_key",
    "class_labels_input_key_training",
]

# Seed the widget-key entries, leaving any existing value (from an earlier
# rerun) untouched.
for widget_state_name in session_state_keys:
    if widget_state_name in st.session_state:
        continue
    st.session_state[widget_state_name] = {widget_state_name: True}

# Boolean flags that start out False on a fresh session.
for flag_name in (
    "validation_triggered",
    "uploaded_files_cache_processing",
    "is_valid",
):
    if flag_name not in st.session_state:
        st.session_state[flag_name] = False

# Placeholder container for the file uploaders: elements added to it later in
# the script are displayed at this position on the page.
file_uploader_container = st.container()

# Map the user-facing training type to the label type passed to the label
# validation helper further down.
label_type_mapping = {"Object Detection": "Bboxes", "Instance Segmentation": "Masks"}

# Two side-by-side columns: training type selector and split method toggle
column_select_training, column_split_method = st.columns(2)

# Dropdown for selecting the training type
with column_select_training:
    selected_training = st.selectbox(
        "Select the training type:",
        list(label_type_mapping.keys()),
        index=0,
        on_change=utils.reset_validation_trigger,
        key=st.session_state["training_type_key"],
    )

# Label type ("Bboxes" or "Masks") matching the selected training type
label_type = label_type_mapping[selected_training]

# Toggle for choosing the split method
with column_split_method:
    split_method = st.radio(
        "Select the dataset split method:",
        ["Percentage Split", "Direct Upload"],
        horizontal=True,
        on_change=utils.reset_validation_trigger,
        key=st.session_state["split_method_key"],
    )

# Text area for the class labels, pre-filled with example values
class_labels_input = st.text_area(
    "Enter class labels, separated by commas:",
    utils.sample_class_labels,
    on_change=utils.reset_validation_trigger,
    key=st.session_state["class_labels_input_key_training"],
)

# Parse the comma-separated labels. Surrounding whitespace is removed and
# empty entries (e.g. from a trailing comma) are skipped; a label's class ID
# is its position in the resulting list.
try:
    # Strip inside the ``try``: st.text_area can return None (e.g. when its
    # default value is None), which would otherwise crash the page here.
    class_labels_input = class_labels_input.strip()
    class_labels = [
        label.strip() for label in class_labels_input.split(",") if label.strip()
    ]
except AttributeError:
    st.warning(
        "Invalid format for class labels. Please enter labels separated by commas.",
        icon="⚠️",
    )
    # Define every name used later in the script so the page keeps working
    # with an empty label set instead of raising NameError.
    class_labels = []

# Class ID -> label, and the inverse label -> class ID lookup (a repeated
# label maps to its last position).
class_dict = dict(enumerate(class_labels))
class_names_to_ids = {label: class_id for class_id, label in class_dict.items()}

# Note to users, rendered as raw HTML.
# The HTML deliberately contains no blank lines: in Markdown a blank line ends
# a raw-HTML block, and a following line indented by four or more spaces is
# parsed as an indented code block, which would show the tags as literal text.
st.markdown(
    """
    <div style='text-align: justify;'>
        <b>Note to Users:</b>
        <ul>
            <li>When moving to another page or if you wish to upload a new set of images and labels, don't forget to hit the <b>Reset</b> button. This helps in faster computation and frees up unused memory, ensuring smoother operation.</li>
            <li>Select the training type, class labels, dataset split method and its parameters before uploading large data for faster computation and more efficient processing.</li>
        </ul>
    </div>
    """,
    unsafe_allow_html=True,
)

# Two side-by-side columns for the "Validate Input" and "Reset" buttons
validate_button_col, reset_button_col = st.columns(2)

with reset_button_col:
    # Check if the 'Reset' button is pressed
    if st.button("Reset", use_container_width=True):
        # Recreate the output folder and clear the prepared data folders
        model_training_functions.delete_and_recreate_folder(
            model_training_functions.get_path("output")
        )
        model_training_functions.clear_data_folders()

        # Flip the boolean inside every widget-key entry. This reuses the list
        # defined at the top of the page so the two can never drift apart.
        # Widgets are keyed on these values, so a new value gives each widget
        # a fresh key and thereby resets it.
        for session_state_key in session_state_keys:
            previous_value = st.session_state[session_state_key][session_state_key]
            st.session_state[session_state_key] = {
                session_state_key: not previous_value
            }

        # Drop every other session-state entry
        for key in list(st.session_state.keys()):
            if key not in session_state_keys:
                del st.session_state[key]

        # Clear global variables except private/dunder names and the Streamlit
        # module, which is still needed for the calls below
        global_vars = list(globals().keys())
        vars_to_delete = [
            var for var in global_vars if not var.startswith("_") and var != "st"
        ]
        for var in vars_to_delete:
            del globals()[var]

        # Clear the Streamlit caches
        st.cache_resource.clear()
        st.cache_data.clear()

        # Rerun the app to reflect the reset state
        st.rerun()

# Code for "Percentage Split" method: one upload that is split into
# train/test/validation sets by the percentages entered below
if split_method == "Percentage Split":
    with file_uploader_container:
        # User uploads images and labels
        utils.display_file_uploader(
            "uploaded_files",
            "Choose images and labels...",
            st.session_state["file_uploader_split_key_training"],
            st.session_state["uploaded_files_cache_processing"],
        )

        # Create three columns for input percentages
        col1, col2, col3 = st.columns(3)

        # User specifies split percentages. Each input is keyed by the
        # session-state entry whose name matches its set.
        train_pct = col1.number_input(
            "Train Set Percentage",
            0,
            100,
            70,
            1,
            on_change=utils.reset_validation_trigger,
            key=st.session_state["number_input_train_key"],
        )
        test_pct = col2.number_input(
            "Test Set Percentage",
            0,
            100,
            15,
            1,
            on_change=utils.reset_validation_trigger,
            key=st.session_state["number_input_test_key"],
        )
        val_pct = col3.number_input(
            "Validation Set Percentage",
            0,
            100,
            15,
            1,
            on_change=utils.reset_validation_trigger,
            key=st.session_state["number_input_val_key"],
        )

    # Total of the three percentages; must be exactly 100
    pct_check = train_pct + test_pct + val_pct

    # Split parameters are valid when they sum to 100, train and validation are
    # non-empty, and check_min_images accepts the upload count for this split
    pct_condition_check = (
        pct_check == 100
        and train_pct > 0
        and val_pct > 0
        and model_training_functions.check_min_images(
            len(st.session_state["uploaded_files"]), train_pct, val_pct, test_pct
        )
    )

    if not pct_condition_check:
        file_uploader_container.warning(
            "The percentages for train, test, and validation sets should add up to 100%, and train and validation set should not be empty.",
            icon="⚠️",
        )

    # Button to trigger validation
    if validate_button_col.button("Validate Input", use_container_width=True):
        st.session_state["validation_triggered"] = True

        st.session_state["is_valid"] = model_training_functions.check_valid_labels(
            st.session_state["uploaded_files"], label_type, class_dict
        )

        # Write the config and split the files only when the split parameters
        # are valid too: training below is never started otherwise, and
        # splitting with invalid percentages would only produce a bogus dataset.
        if st.session_state["is_valid"] and pct_condition_check:
            model_training_functions.create_yolo_config_file(
                model_training_functions.get_path("config"),
                class_labels,
            )
            model_training_functions.clear_data_folders()
            paired_files = model_training_functions.pair_files(
                st.session_state["uploaded_files"]
            )
            model_training_functions.split_and_save_files(
                paired_files, train_pct, test_pct
            )

    # Start training once validation was triggered and all checks pass
    if st.session_state["validation_triggered"] and (
        pct_condition_check and st.session_state["is_valid"]
    ):
        model_training_functions.start_yolo_training(selected_training, class_labels)
    else:
        # Display a warning message if the validation is not successful or conditions are not met
        st.warning(
            "Please upload valid input, select valid parameters, and click **Validate Input**.",
            icon="⚠️",
        )

# Code for "Direct Upload" method: separate uploads for train, val and test
elif split_method == "Direct Upload":
    with file_uploader_container:
        # Create three columns for uploading train, val, and test files
        col1, col2, col3 = st.columns(3)

        with col1:
            utils.display_file_uploader(
                "uploaded_train_files",
                "Upload Training Images and Labels",
                st.session_state["file_uploader_train_key_training"],
                st.session_state["uploaded_files_cache_processing"],
            )

        with col2:
            utils.display_file_uploader(
                "uploaded_val_files",
                "Upload Validation Images and Labels",
                st.session_state["file_uploader_val_key_training"],
                st.session_state["uploaded_files_cache_processing"],
            )

        with col3:
            utils.display_file_uploader(
                "uploaded_test_files",
                "Upload Test Images and Labels",
                st.session_state["file_uploader_test_key_training"],
                st.session_state["uploaded_files_cache_processing"],
            )

    # Train and validation uploads are mandatory; test is optional
    pct_condition_check = bool(st.session_state["uploaded_train_files"]) and bool(
        st.session_state["uploaded_val_files"]
    )

    if not pct_condition_check:
        file_uploader_container.warning(
            "The train and validation set should not be empty.",
            icon="⚠️",
        )

    # Button to trigger validation
    if validate_button_col.button("Validate Input", use_container_width=True):
        st.session_state["validation_triggered"] = True

        st.session_state["is_valid"] = model_training_functions.check_valid_labels(
            st.session_state["uploaded_train_files"]
            + st.session_state["uploaded_val_files"]
            + st.session_state["uploaded_test_files"],
            label_type,
            class_dict,
        )

        # Write anything to disk only when the whole input is valid
        if st.session_state["is_valid"] and pct_condition_check:
            model_training_functions.create_yolo_config_file(
                model_training_functions.get_path("config"),
                class_labels,
            )
            model_training_functions.clear_data_folders()
            model_training_functions.save_files_to_folder(
                st.session_state["uploaded_train_files"], "train"
            )
            model_training_functions.save_files_to_folder(
                st.session_state["uploaded_val_files"], "val"
            )

            # The test set is optional; save it only when files were uploaded.
            # This must stay inside the validity guard, otherwise a failed
            # validation would still write test files into the data folders.
            if st.session_state["uploaded_test_files"]:
                model_training_functions.save_files_to_folder(
                    st.session_state["uploaded_test_files"], "test"
                )

    # Start training once validation was triggered and all checks pass
    if st.session_state["validation_triggered"] and (
        pct_condition_check and st.session_state["is_valid"]
    ):
        model_training_functions.start_yolo_training(selected_training, class_labels)
    else:
        # Display a warning message if the validation is not successful or conditions are not met
        st.warning(
            "Please upload valid input, select valid parameters, and click **Validate Input**.",
            icon="⚠️",
        )