"""Streamlit app that turns a Roboflow dataset snippet into a YOLO training script.

``generate_script_v8`` and ``generate_script_v9`` are plain string-to-string
functions; the Streamlit UI lives in ``main`` so this module can be imported
(for example by tests) without Streamlit being installed.
"""

import re

# Patterns for the values in a Roboflow download snippet such as:
#   rf = Roboflow(api_key="...")
#   project = rf.workspace("my-workspace").project("my-project")
#   version = project.version(3)
_API_KEY_RE = re.compile(r'api_key="(.*?)"')
_WORKSPACE_RE = re.compile(r'workspace\("([^"]+)"\)')
_PROJECT_RE = re.compile(r'project\("([^"]+)"\)')
_VERSION_RE = re.compile(r'version\((\d+)\)')

_EXTRACTION_ERROR = "Error: Could not extract necessary information from the dataset code."


def _extract_dataset_info(dataset_code):
    """Pull the Roboflow connection details out of a pasted dataset snippet.

    Args:
        dataset_code: Source text copied from Roboflow's dataset download
            dialog. ``None`` is treated as empty text.

    Returns:
        ``(api_key, workspace, project_name, version_number)`` with
        ``version_number`` converted to ``int``, or ``None`` if any of the
        four values cannot be found.
    """
    text = dataset_code or ""
    matches = [
        pattern.search(text)
        for pattern in (_API_KEY_RE, _WORKSPACE_RE, _PROJECT_RE, _VERSION_RE)
    ]
    if not all(matches):
        return None
    api_key, workspace, project_name, version = (m.group(1) for m in matches)
    return api_key, workspace, project_name, int(version)


def generate_script_v8(dataset_code, task, model_size, epochs, batch_size):
    """Build a standalone Python script that trains a YOLOv8 model with Ultralytics.

    Args:
        dataset_code: Roboflow download snippet providing the API key,
            workspace, project and version.
        task: ``"Segmentation"`` selects a ``-seg`` checkpoint; any other value
            (the UI offers ``"Object Detection"``) selects the plain detection
            checkpoint.
        model_size: YOLOv8 size letter, e.g. ``"n"`` or ``"x"``.
        epochs: Number of training epochs written into the script.
        batch_size: Training batch size written into the script.

    Returns:
        The generated script source, or an ``"Error: ..."`` message if the
        snippet could not be parsed.
    """
    info = _extract_dataset_info(dataset_code)
    if info is None:
        return _EXTRACTION_ERROR
    api_key, workspace, project_name, version_number = info

    # Detection checkpoints carry no suffix (yolov8n.pt). "-cls" would load the
    # classification checkpoint, which is wrong for the "Object Detection" task.
    suffix = "-seg" if task == "Segmentation" else ""
    model_name = f"yolov8{model_size}{suffix}.pt"

    script = f"""
import yaml
from ultralytics import YOLO
from roboflow import Roboflow
import logging
import threading
from io import StringIO

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def auto_train():
    log_stream = StringIO()
    log_handler = logging.StreamHandler(log_stream)
    log_handler.setLevel(logging.INFO)
    logger.addHandler(log_handler)

    stop_event = threading.Event()
    log_thread = None

    # Once per second, echo log text captured since the previous poll to stdout.
    def update_logs():
        printed = 0
        while not stop_event.wait(1):
            # getvalue() does not move the stream position, so it cannot
            # interfere with the handler appending to the stream.
            text = log_stream.getvalue()
            if len(text) > printed:
                print(text[printed:], end="")
                printed = len(text)

    try:
        api_key = "{api_key}"
        workspace = "{workspace}"
        project_name = "{project_name}"
        version_number = {version_number}

        # Load the Roboflow dataset
        rf = Roboflow(api_key=api_key)
        project = rf.workspace(workspace).project(project_name)
        version = project.version(version_number)
        dataset = version.download("yolov8")

        # Point the split paths in data.yaml at the downloaded folders
        yaml_file_path = f'{{dataset.location}}/data.yaml'
        with open(yaml_file_path, 'r') as file:
            data = yaml.safe_load(file)
        data['val'] = '../valid/images'
        data['test'] = '../test/images'
        data['train'] = '../train/images'
        with open(yaml_file_path, 'w') as file:
            yaml.safe_dump(data, file)

        # Load and train the model
        model_name = "{model_name}"
        model = YOLO(model_name)
        model.info()

        log_thread = threading.Thread(target=update_logs, daemon=True)
        log_thread.start()

        model.train(data=yaml_file_path, epochs={epochs}, imgsz=640, batch={batch_size})

        print("Results Directory:", model.trainer.save_dir)
        print("Final Training Logs:", log_stream.getvalue())
    except Exception as e:
        logger.error(f"An error occurred: {{e}}")
        print(f"Error: {{e}}")
        print(log_stream.getvalue())
    finally:
        # Stop the log thread on every path (also when training raised) so
        # the script can exit.
        stop_event.set()
        if log_thread is not None:
            log_thread.join()
        logger.removeHandler(log_handler)


if __name__ == "__main__":
    auto_train()
"""
    return script


def generate_script_v9(dataset_code, task, model_size, epochs, batch_size):
    """Build notebook cells that download a Roboflow dataset and train YOLOv9.

    The output uses IPython ``!`` shell lines, so it targets a Jupyter/Colab
    notebook rather than a plain Python interpreter.

    Args:
        dataset_code: Roboflow download snippet providing the API key,
            workspace, project and version.
        task: ``"Segmentation"`` selects segmentation weights, trainer and
            config folder; any other value selects detection.
        model_size: YOLOv9 size letter, e.g. ``"t"`` or ``"c"``.
        epochs: Number of training epochs written into the command.
        batch_size: Training batch size written into the command.

    Returns:
        The generated notebook text, or an ``"Error: ..."`` message if the
        snippet could not be parsed.
    """
    info = _extract_dataset_info(dataset_code)
    if info is None:
        return _EXTRACTION_ERROR
    api_key, workspace, project_name, version_number = info

    if task == "Segmentation":
        model_name = "gelan-c-seg.pt" if model_size == "c" else f"yolov9-{model_size}-seg.pt"
        # Segmentation has its own trainer and config folder in the YOLOv9
        # repository; the detection train.py / models/detect configs do not apply.
        train_script, cfg_dir = "segment/train.py", "models/segment"
    else:
        model_name = f"yolov9-{model_size}.pt"
        train_script, cfg_dir = "train.py", "models/detect"
    cfg_name = model_name.split(".")[0]

    # {dataset.location} and {HOME} are left for IPython to expand inside the
    # "!" command. HOME is not defined here: presumably the notebook sets it
    # earlier and downloaded the weights to HOME/weights. TODO confirm.
    script = f"""
!pip install roboflow

from roboflow import Roboflow
rf = Roboflow(api_key="{api_key}")
project = rf.workspace("{workspace}").project("{project_name}")
version = project.version({version_number})
dataset = version.download("yolov9")

!python {train_script} \\
--batch {batch_size} --epochs {epochs} --img 640 --device 0 --min-items 0 --close-mosaic 15 \\
--data {{dataset.location}}/data.yaml \\
--weights {{HOME}}/weights/{model_name} \\
--cfg {cfg_dir}/{cfg_name}.yaml \\
--hyp hyp.scratch-high.yaml
"""
    return script


def _render_generator_tab(st, label, key_suffix, model_sizes, generator):
    """Draw one generator tab's inputs and show the generated script on click.

    Args:
        st: The imported ``streamlit`` module.
        label: Display name such as ``"YOLOv8"``.
        key_suffix: Suffix that keeps widget keys unique per tab (``"v8"``/``"v9"``).
        model_sizes: Size letters offered in the "Model Size" select box.
        generator: ``generate_script_v8`` or ``generate_script_v9``.
    """
    st.subheader(f"{label} Script Generator")
    # Roboflow snippets span several lines, so use a multi-line text area.
    dataset_code = st.text_area(
        "Roboflow Dataset Code",
        key=f"dataset_code_{key_suffix}",
        placeholder="Paste your Roboflow dataset code here",
    )
    task = st.selectbox("Task", ["Object Detection", "Segmentation"], index=0, key=f"task_{key_suffix}")
    model_size = st.selectbox("Model Size", model_sizes, index=0, key=f"model_size_{key_suffix}")
    epochs = st.selectbox("Epochs", [50, 100, 200, 300, 400, 500], index=3, key=f"epochs_{key_suffix}")
    batch_size = st.selectbox("Batch Size", [1, 2, 4, 8, 16, 32], index=0, key=f"batch_size_{key_suffix}")

    if st.button(f"Generate {label} Script"):
        script = generator(dataset_code, task, model_size, epochs, batch_size)
        if script.startswith("Error:"):
            st.error(script)
        else:
            st.code(script, language="python")


def main():
    """Render the Streamlit UI (run with ``streamlit run <this file>``)."""
    # Imported here so the generator functions stay importable without Streamlit.
    import streamlit as st

    st.title("Auto Train Script Generator")
    st.write("Generate a YOLO training script using a Roboflow dataset")

    tab1, tab2 = st.tabs(["YOLOv8", "YOLOv9"])
    with tab1:
        _render_generator_tab(st, "YOLOv8", "v8", ["n", "s", "m", "l", "x"], generate_script_v8)
    with tab2:
        _render_generator_tab(st, "YOLOv9", "v9", ["t", "s", "m", "c", "e"], generate_script_v9)


# `streamlit run` executes the script as __main__, so the UI still renders there.
if __name__ == "__main__":
    main()