File size: 6,887 Bytes
5550d19
0409d7f
5550d19
0409d7f
 
 
 
 
 
 
 
 
 
 
 
 
 
617b1f3
 
 
0409d7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34de998
0409d7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5550d19
 
 
 
 
 
0409d7f
71f851a
5550d19
0409d7f
5550d19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
617b1f3
 
5550d19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import re

def generate_script_v8(dataset_code, task, model_size, epochs, batch_size):
    """Generate a standalone Ultralytics YOLOv8 training script for a Roboflow dataset.

    Args:
        dataset_code: The Roboflow download snippet. It must contain
            ``api_key="..."``, ``workspace("...")``, ``project("...")`` and
            ``version(<int>)``.
        task: ``"Segmentation"`` selects a ``-seg`` checkpoint; any other value
            (the UI offers ``"Object Detection"``) selects the detection checkpoint.
        model_size: YOLOv8 size letter inserted into the checkpoint name
            (the UI offers ``n``, ``s``, ``m``, ``l``, ``x``).
        epochs: Epoch count written into the generated ``model.train`` call.
        batch_size: Batch size written into the generated ``model.train`` call.

    Returns:
        The generated script as a string, or the string
        ``"Error: Could not extract necessary information from the dataset code."``
        when ``dataset_code`` is empty or lacks one of the required pieces.
    """
    if not dataset_code:
        return "Error: Could not extract necessary information from the dataset code."

    # Pull the Roboflow identifiers out of the pasted snippet.
    api_key_match = re.search(r'api_key="(.*?)"', dataset_code)
    workspace_match = re.search(r'workspace\("([^"]+)"\)', dataset_code)
    project_name_match = re.search(r'project\("([^"]+)"\)', dataset_code)
    version_number_match = re.search(r'version\((\d+)\)', dataset_code)

    if not (api_key_match and workspace_match and project_name_match and version_number_match):
        return "Error: Could not extract necessary information from the dataset code."

    api_key = api_key_match.group(1)
    workspace = workspace_match.group(1)
    project_name = project_name_match.group(1)
    version_number = int(version_number_match.group(1))

    # Ultralytics detection checkpoints carry no suffix ("yolov8n.pt") and
    # segmentation ones end in "-seg"; "-cls" would load a classification model.
    model_suffix = "-seg" if task == "Segmentation" else ""
    model_name = f"yolov8{model_size}{model_suffix}.pt"

    # Braces meant to appear literally in the generated script are doubled ({{ }}).
    script = f"""
import yaml
from ultralytics import YOLO
from roboflow import Roboflow
import logging
import threading
from io import StringIO

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def auto_train():
    log_stream = StringIO()
    log_handler = logging.StreamHandler(log_stream)
    log_handler.setLevel(logging.INFO)
    logger.addHandler(log_handler)

    stop_event = threading.Event()
    log_thread = None

    # Once per second, print whatever was logged since the previous print.
    # getvalue() does not move the stream position the handler writes at.
    def update_logs():
        printed = 0
        while not stop_event.wait(1):
            text = log_stream.getvalue()
            if len(text) > printed:
                print(text[printed:], end="")
                printed = len(text)

    def stop_log_thread():
        stop_event.set()
        if log_thread is not None:
            log_thread.join()

    try:
        api_key = "{api_key}"
        workspace = "{workspace}"
        project_name = "{project_name}"
        version_number = {version_number}

        # Load the Roboflow dataset
        rf = Roboflow(api_key=api_key)
        project = rf.workspace(workspace).project(project_name)
        version = project.version(version_number)
        dataset = version.download("yolov8")

        # Point the split paths in data.yaml at the downloaded image folders
        yaml_file_path = f'{{dataset.location}}/data.yaml'
        with open(yaml_file_path, 'r') as file:
            data = yaml.safe_load(file)
        data['val'] = '../valid/images'
        data['test'] = '../test/images'
        data['train'] = '../train/images'
        with open(yaml_file_path, 'w') as file:
            yaml.safe_dump(data, file)

        # Load and train the model
        model = YOLO("{model_name}")
        model.info()

        # Daemon thread, so it can never keep the process alive by itself
        log_thread = threading.Thread(target=update_logs, daemon=True)
        log_thread.start()
        results = model.train(data=yaml_file_path, epochs={epochs}, imgsz=640, batch={batch_size})
        stop_log_thread()

        # model.train() returns a metrics object (or None); fall back to the
        # trainer's save directory.
        results_dir = getattr(results, "save_dir", None) or getattr(getattr(model, "trainer", None), "save_dir", None)
        print("Results Directory:", results_dir)
        print("Final Training Logs:", log_stream.getvalue())
    except Exception as e:
        stop_log_thread()
        logger.error(f"An error occurred: {{e}}")
        print(f"Error: {{e}}")
        print(log_stream.getvalue())
    finally:
        # Also reached on KeyboardInterrupt: always stop the printer thread.
        stop_log_thread()
        logger.removeHandler(log_handler)

if __name__ == "__main__":
    auto_train()
"""
    return script

def generate_script_v9(dataset_code, task, model_size, epochs, batch_size):
    """Generate a notebook cell that downloads a Roboflow dataset and trains YOLOv9.

    The output uses IPython ``!`` shell escapes, so it is meant for a
    Jupyter/Colab cell, not a plain ``.py`` file.

    Args:
        dataset_code: The Roboflow download snippet. It must contain
            ``api_key="..."``, ``workspace("...")``, ``project("...")`` and
            ``version(<int>)``.
        task: ``"Segmentation"`` selects segmentation weights, config directory
            and training script; any other value selects detection.
        model_size: YOLOv9 size letter (the UI offers ``t``, ``s``, ``m``, ``c``, ``e``).
        epochs: Value passed as ``--epochs``.
        batch_size: Value passed as ``--batch``.

    Returns:
        The generated cell as a string, or the string
        ``"Error: Could not extract necessary information from the dataset code."``
        when ``dataset_code`` is empty or lacks one of the required pieces.
    """
    if not dataset_code:
        return "Error: Could not extract necessary information from the dataset code."

    # Pull the Roboflow identifiers out of the pasted snippet.
    api_key_match = re.search(r'api_key="(.*?)"', dataset_code)
    workspace_match = re.search(r'workspace\("([^"]+)"\)', dataset_code)
    project_name_match = re.search(r'project\("([^"]+)"\)', dataset_code)
    version_number_match = re.search(r'version\((\d+)\)', dataset_code)

    if not (api_key_match and workspace_match and project_name_match and version_number_match):
        return "Error: Could not extract necessary information from the dataset code."

    api_key = api_key_match.group(1)
    workspace = workspace_match.group(1)
    project_name = project_name_match.group(1)
    version_number = int(version_number_match.group(1))

    # Segmentation models are configured under models/segment/ and trained with
    # segment/train.py; detection uses models/detect/ and the top-level train.py.
    # NOTE(review): only the "c" size appears to have published seg weights, and
    # confirm segment/train.py accepts --min-items / --close-mosaic.
    if task == "Segmentation":
        model_name = "gelan-c-seg.pt" if model_size == "c" else f"yolov9-{model_size}-seg.pt"
        train_script = "segment/train.py"
        cfg_dir = "models/segment"
    else:
        model_name = f"yolov9-{model_size}.pt"
        train_script = "train.py"
        cfg_dir = "models/detect"
    cfg_name = model_name.rsplit(".", 1)[0]

    # {{dataset.location}} and {{HOME}} become {dataset.location} / {HOME}, which
    # IPython expands inside "!" commands. HOME is expected to be defined by an
    # earlier notebook cell, and the cell must run from the yolov9 repo root.
    script = f"""
!pip install roboflow
from roboflow import Roboflow
rf = Roboflow(api_key="{api_key}")
project = rf.workspace("{workspace}").project("{project_name}")
version = project.version({version_number})
dataset = version.download("yolov9")
!python {train_script} \\
--batch {batch_size} --epochs {epochs} --img 640 --device 0 --min-items 0 --close-mosaic 15 \\
--data {{dataset.location}}/data.yaml \\
--weights {{HOME}}/weights/{model_name} \\
--cfg {cfg_dir}/{cfg_name}.yaml \\
--hyp hyp.scratch-high.yaml
"""
    return script

import streamlit as st

st.title("Auto Train Script Generator")
st.write("Generate a YOLO training script using a Roboflow dataset")


def _render_generator_tab(version, generate, model_sizes):
    """Draw the inputs for one YOLO version and display the generated script.

    Args:
        version: Version tag such as ``"v8"``; used both in labels
            (``"YOLOv8 ..."``) and as the widget-key suffix (``"task_v8"``).
        generate: The matching ``generate_script_*`` function.
        model_sizes: Size letters offered in the "Model Size" select box.
    """
    st.subheader(f"YOLO{version} Script Generator")
    dataset_code = st.text_input("Roboflow Dataset Code", key=f"dataset_code_{version}", placeholder="Paste your Roboflow dataset code here")
    task = st.selectbox("Task", ["Object Detection", "Segmentation"], index=0, key=f"task_{version}")
    model_size = st.selectbox("Model Size", model_sizes, index=0, key=f"model_size_{version}")
    epochs = st.selectbox("Epochs", [50, 100, 200, 300, 400, 500], index=3, key=f"epochs_{version}")
    batch_size = st.selectbox("Batch Size", [1, 2, 4, 8, 16, 32], index=0, key=f"batch_size_{version}")

    if st.button(f"Generate YOLO{version} Script"):
        script = generate(dataset_code, task, model_size, epochs, batch_size)
        # The generators signal failure with an "Error: ..." string; show it as
        # an error rather than as Python source.
        if script.startswith("Error:"):
            st.error(script)
        else:
            st.code(script, language="python")


tab1, tab2 = st.tabs(["YOLOv8", "YOLOv9"])

with tab1:
    _render_generator_tab("v8", generate_script_v8, ["n", "s", "m", "l", "x"])

with tab2:
    _render_generator_tab("v9", generate_script_v9, ["t", "s", "m", "c", "e"])