# Spaces:
# Sleeping
# Sleeping
import re | |
def generate_script_v8(dataset_code, task, model_size, epochs, batch_size):
    """Generate a YOLOv8 training script from a Roboflow download snippet.

    The API key, workspace, project name and version number are pulled out of
    ``dataset_code`` with regular expressions and embedded in a script that
    downloads the dataset, rewrites the split paths in its ``data.yaml`` and
    trains an Ultralytics ``YOLO`` model.

    Args:
        dataset_code: Text containing ``api_key="..."``, ``workspace("...")``,
            ``project("...")`` and ``version(N)``.
        task: ``"Segmentation"`` selects the ``-seg`` weights; any other value
            selects the plain (detection) weights.
        model_size: Size letter inserted into the weights file name.
        epochs: Inserted verbatim as the ``epochs`` training argument.
        batch_size: Inserted verbatim as the ``batch`` training argument.

    Returns:
        The generated script text, or an ``"Error: ..."`` message string when
        any of the four values cannot be found in ``dataset_code``.
    """
    # Extract the necessary information from the dataset code
    api_key_match = re.search(r'api_key="(.*?)"', dataset_code)
    workspace_match = re.search(r'workspace\("([^"]+)"\)', dataset_code)
    project_name_match = re.search(r'project\("([^"]+)"\)', dataset_code)
    version_number_match = re.search(r'version\((\d+)\)', dataset_code)
    if not (api_key_match and workspace_match and project_name_match and version_number_match):
        return "Error: Could not extract necessary information from the dataset code."

    api_key = api_key_match.group(1)
    workspace = workspace_match.group(1)
    project_name = project_name_match.group(1)
    version_number = int(version_number_match.group(1))

    # Detection weights carry no suffix (e.g. yolov8n.pt); only segmentation
    # adds "-seg". The "-cls" suffix names classification weights and must not
    # be used for the "Object Detection" task.
    model_suffix = "-seg" if task == "Segmentation" else ""

    # NOTE(review): the template prints `results.results_dir`; confirm that the
    # object returned by `model.train` exposes that attribute.
    # The template is kept at column 0 so the generated script is valid
    # top-level Python. Doubled braces survive as literal braces.
    script = f"""
import yaml
from ultralytics import YOLO
from roboflow import Roboflow
import logging
import re
import threading
import time
from io import StringIO
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def auto_train():
    log_stream = StringIO()
    log_handler = logging.StreamHandler(log_stream)
    log_handler.setLevel(logging.INFO)
    logger.addHandler(log_handler)
    try:
        api_key = "{api_key}"
        workspace = "{workspace}"
        project_name = "{project_name}"
        version_number = {version_number}
        # Load the Roboflow dataset
        rf = Roboflow(api_key=api_key)
        project = rf.workspace(workspace).project(project_name)
        version = project.version(version_number)
        dataset = version.download("yolov8")
        # Modify the data structure
        yaml_file_path = f'{{dataset.location}}/data.yaml'
        with open(yaml_file_path, 'r') as file:
            data = yaml.safe_load(file)
        data['val'] = '../valid/images'
        data['test'] = '../test/images'
        data['train'] = '../train/images'
        with open(yaml_file_path, 'w') as file:
            yaml.safe_dump(data, file)
        # Determine the model name based on the selected size and task
        model_name = "yolov8{model_size}{model_suffix}.pt"
        # Load and train the model
        model = YOLO(model_name)
        model.info()
        # Function to read logs in real-time and update the Streamlit textbox
        def update_logs():
            while getattr(threading.current_thread(), "do_run", True):
                time.sleep(1)
                log_stream.seek(0)
                print(log_stream.read())
        # Start a thread to update logs in real-time
        log_thread = threading.Thread(target=update_logs)
        log_thread.start()
        try:
            results = model.train(data=yaml_file_path, epochs={epochs}, imgsz=640, batch={batch_size})
        finally:
            # Stop the log update thread, even when training fails
            log_thread.do_run = False
            log_thread.join()
        logger.removeHandler(log_handler)
        # Return the result path and logs
        log_stream.seek(0)
        log_output = log_stream.read()
        print("Results Directory:", results.results_dir)
        print("Final Training Logs:", log_output)
    except Exception as e:
        logger.error(f"An error occurred: {{e}}")
        log_stream.seek(0)
        log_output = log_stream.read()
        print(f"Error: {{e}}")
        print(log_output)
    finally:
        logger.removeHandler(log_handler)
if __name__ == "__main__":
    auto_train()
"""
    return script
def generate_script_v9(dataset_code, task, model_size, epochs, batch_size):
    """Generate a YOLOv9 notebook cell from a Roboflow download snippet.

    The API key, workspace, project name and version number are pulled out of
    ``dataset_code`` with regular expressions. The returned text installs
    ``roboflow``, downloads the dataset in ``yolov9`` format and launches the
    training script through ``!`` shell lines, so it is meant for a notebook
    cell rather than a plain ``.py`` file. ``{dataset.location}`` and
    ``{HOME}`` are emitted literally and left for the notebook to resolve.

    Args:
        dataset_code: Text containing ``api_key="..."``, ``workspace("...")``,
            ``project("...")`` and ``version(N)``.
        task: ``"Segmentation"`` selects segmentation weights, config folder
            and training script; any other value selects detection.
        model_size: Size letter inserted into the weights/config file name.
        epochs: Inserted verbatim after ``--epochs``.
        batch_size: Inserted verbatim after ``--batch``.

    Returns:
        The generated cell text, or an ``"Error: ..."`` message string when
        any of the four values cannot be found in ``dataset_code``.
    """
    # Extract the necessary information from the dataset code
    api_key_match = re.search(r'api_key="(.*?)"', dataset_code)
    workspace_match = re.search(r'workspace\("([^"]+)"\)', dataset_code)
    project_name_match = re.search(r'project\("([^"]+)"\)', dataset_code)
    version_number_match = re.search(r'version\((\d+)\)', dataset_code)
    if not (api_key_match and workspace_match and project_name_match and version_number_match):
        return "Error: Could not extract necessary information from the dataset code."

    api_key = api_key_match.group(1)
    workspace = workspace_match.group(1)
    project_name = project_name_match.group(1)
    version_number = int(version_number_match.group(1))

    # Pick weights, training entry point and config folder together so a
    # segmentation run never points at the detection script or configs.
    if task == "Segmentation":
        model_name = "gelan-c-seg.pt" if model_size == "c" else f"yolov9-{model_size}-seg.pt"
        train_script = "segment/train.py"
        cfg_dir = "models/segment"
    else:
        model_name = f"yolov9-{model_size}.pt"
        train_script = "train.py"
        cfg_dir = "models/detect"

    # The config file shares the weights' base name (text before the first dot).
    cfg_name = model_name.split('.')[0]

    # Generate the script. "\\" becomes a single shell line-continuation
    # backslash; doubled braces survive as literal braces.
    script = f"""
!pip install roboflow
from roboflow import Roboflow
rf = Roboflow(api_key="{api_key}")
project = rf.workspace("{workspace}").project("{project_name}")
version = project.version({version_number})
dataset = version.download("yolov9")
!python {train_script} \\
--batch {batch_size} --epochs {epochs} --img 640 --device 0 --min-items 0 --close-mosaic 15 \\
--data {{dataset.location}}/data.yaml \\
--weights {{HOME}}/weights/{model_name} \\
--cfg {cfg_dir}/{cfg_name}.yaml \\
--hyp hyp.scratch-high.yaml
"""
    return script
import streamlit as st | |
def _render_generator_tab(version_label, key_suffix, model_sizes, generate_script):
    """Render the input form for one YOLO version and show the generated script.

    Args:
        version_label: Name shown in the subheader and on the button
            (e.g. ``"YOLOv8"``).
        key_suffix: Suffix appended to every widget key so the two tabs keep
            separate widget state.
        model_sizes: Options offered in the "Model Size" select box.
        generate_script: Callable taking
            ``(dataset_code, task, model_size, epochs, batch_size)`` and
            returning either the script text or an ``"Error: ..."`` string.
    """
    st.subheader(f"{version_label} Script Generator")
    dataset_code = st.text_input(
        "Roboflow Dataset Code",
        key=f"dataset_code_{key_suffix}",
        placeholder="Paste your Roboflow dataset code here",
    )
    task = st.selectbox(
        "Task", ["Object Detection", "Segmentation"], index=0, key=f"task_{key_suffix}"
    )
    model_size = st.selectbox(
        "Model Size", model_sizes, index=0, key=f"model_size_{key_suffix}"
    )
    epochs = st.selectbox(
        "Epochs", [50, 100, 200, 300, 400, 500], index=3, key=f"epochs_{key_suffix}"
    )
    batch_size = st.selectbox(
        "Batch Size", [1, 2, 4, 8, 16, 32], index=0, key=f"batch_size_{key_suffix}"
    )
    if st.button(f"Generate {version_label} Script"):
        script = generate_script(dataset_code, task, model_size, epochs, batch_size)
        # The generators signal failure by returning an "Error: ..." string;
        # show that as an error instead of presenting it as runnable code.
        if script.startswith("Error:"):
            st.error(script)
        else:
            st.code(script, language="python")


st.title("Auto Train Script Generator")
st.write("Generate a YOLO training script using a Roboflow dataset")

tab1, tab2 = st.tabs(["YOLOv8", "YOLOv9"])
with tab1:
    _render_generator_tab("YOLOv8", "v8", ["n", "s", "m", "l", "x"], generate_script_v8)
with tab2:
    _render_generator_tab("YOLOv9", "v9", ["t", "s", "m", "c", "e"], generate_script_v9)