Spaces:
Sleeping
Sleeping
File size: 6,860 Bytes
0409d7f 5550d19 0409d7f 5550d19 0409d7f 34de998 0409d7f 5550d19 0409d7f 5550d19 0409d7f 5550d19 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 |
import streamlit as st
import re
def generate_script_v8(dataset_code, task, model_size, epochs, batch_size):
    """Build a standalone YOLOv8 training script from a pasted Roboflow snippet.

    The snippet is scanned for the API key, workspace, project name and
    version number. These are baked into a Python script that downloads the
    dataset, rewrites the split paths in its ``data.yaml`` and trains an
    Ultralytics YOLOv8 model.

    Args:
        dataset_code: Roboflow snippet text. It must contain
            ``api_key="..."``, ``workspace("...")``, ``project("...")`` and
            ``version(<int>)``. ``None`` is treated as an empty string.
        task: ``"Segmentation"`` selects a ``-seg`` checkpoint. Any other
            value selects the plain (detection) checkpoint.
        model_size: YOLOv8 size letter, e.g. ``"n"``, ``"s"``, ``"m"``.
        epochs: Number of training epochs written into the script.
        batch_size: Training batch size written into the script.

    Returns:
        The generated script as a string, or an ``"Error: ..."`` message if
        any of the four required values could not be found.
    """
    dataset_code = dataset_code or ""
    api_key_match = re.search(r'api_key="(.*?)"', dataset_code)
    workspace_match = re.search(r'workspace\("([^"]+)"\)', dataset_code)
    project_name_match = re.search(r'project\("([^"]+)"\)', dataset_code)
    version_number_match = re.search(r'version\((\d+)\)', dataset_code)
    if not (api_key_match and workspace_match and project_name_match and version_number_match):
        return "Error: Could not extract necessary information from the dataset code."

    api_key = api_key_match.group(1)
    workspace = workspace_match.group(1)
    project_name = project_name_match.group(1)
    version_number = int(version_number_match.group(1))

    # Resolve the checkpoint here, at generation time. The emitted script
    # therefore never refers to names (task / model_type) it does not define.
    # "Object Detection" must map to the un-suffixed detection weights;
    # a "-cls" suffix would select a classification model instead.
    suffix = "-seg" if task == "Segmentation" else ""
    model_name = f"yolov8{model_size}{suffix}.pt"

    # In the template below, doubled braces are literal braces in the output.
    # User-derived strings are embedded with !r so that quotes and
    # backslashes still yield valid Python literals.
    script = f"""
import logging
import threading
from io import StringIO

import yaml
from roboflow import Roboflow
from ultralytics import YOLO

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def auto_train():
    log_stream = StringIO()
    log_handler = logging.StreamHandler(log_stream)
    log_handler.setLevel(logging.INFO)
    logger.addHandler(log_handler)
    stop_event = threading.Event()
    log_thread = None
    try:
        api_key = {api_key!r}
        workspace = {workspace!r}
        project_name = {project_name!r}
        version_number = {version_number}

        # Load the Roboflow dataset
        rf = Roboflow(api_key=api_key)
        project = rf.workspace(workspace).project(project_name)
        version = project.version(version_number)
        dataset = version.download("yolov8")

        # Point the dataset's split paths at the downloaded image folders
        yaml_file_path = f'{{dataset.location}}/data.yaml'
        with open(yaml_file_path, 'r') as file:
            data = yaml.safe_load(file)
        data['val'] = '../valid/images'
        data['test'] = '../test/images'
        data['train'] = '../train/images'
        with open(yaml_file_path, 'w') as file:
            yaml.safe_dump(data, file)

        # Load the model
        model = YOLO({model_name!r})
        model.info()

        # Print newly captured log text about once per second until stopped
        def update_logs():
            printed = 0
            while not stop_event.wait(1):
                text = log_stream.getvalue()
                if len(text) > printed:
                    print(text[printed:], end="")
                    printed = len(text)

        log_thread = threading.Thread(target=update_logs, daemon=True)
        log_thread.start()

        model.train(data=yaml_file_path, epochs={epochs}, imgsz=640, batch={batch_size})

        print("Results Directory:", model.trainer.save_dir)
        print("Final Training Logs:", log_stream.getvalue())
    except Exception as e:
        logger.error(f"An error occurred: {{e}}")
        print(f"Error: {{e}}")
        print(log_stream.getvalue())
    finally:
        # Stop the log thread even when training raised, so the process can exit
        stop_event.set()
        if log_thread is not None:
            log_thread.join()
        logger.removeHandler(log_handler)


if __name__ == "__main__":
    auto_train()
"""
    return script
def generate_script_v9(dataset_code, task, model_size, epochs, batch_size):
    """Build a notebook snippet that trains YOLOv9 on a Roboflow dataset.

    The snippet is scanned for the API key, workspace, project name and
    version number. The output is notebook-style text using ``!`` shell
    lines. ``{dataset.location}`` and ``{HOME}`` are emitted verbatim for the
    notebook to expand; ``HOME`` is not defined by this snippet, so the
    surrounding notebook must define it.

    Args:
        dataset_code: Roboflow snippet text. It must contain
            ``api_key="..."``, ``workspace("...")``, ``project("...")`` and
            ``version(<int>)``. ``None`` is treated as an empty string.
        task: ``"Segmentation"`` selects segmentation weights, config
            directory and trainer. Any other value selects detection.
        model_size: YOLOv9 size letter, e.g. ``"t"``, ``"s"``, ``"m"``,
            ``"c"``, ``"e"``.
        epochs: Number of training epochs passed to the trainer.
        batch_size: Batch size passed to the trainer.

    Returns:
        The generated snippet as a string, or an ``"Error: ..."`` message if
        any of the four required values could not be found.
    """
    dataset_code = dataset_code or ""
    api_key_match = re.search(r'api_key="(.*?)"', dataset_code)
    workspace_match = re.search(r'workspace\("([^"]+)"\)', dataset_code)
    project_name_match = re.search(r'project\("([^"]+)"\)', dataset_code)
    version_number_match = re.search(r'version\((\d+)\)', dataset_code)
    if not (api_key_match and workspace_match and project_name_match and version_number_match):
        return "Error: Could not extract necessary information from the dataset code."

    api_key = api_key_match.group(1)
    workspace = workspace_match.group(1)
    project_name = project_name_match.group(1)
    version_number = int(version_number_match.group(1))

    # Pick weights, config directory and training entry point per task.
    # A segmentation checkpoint must not be paired with a detection config.
    # NOTE(review): assumes the YOLOv9 repo keeps seg configs under
    # models/segment/ and the seg trainer at segment/train.py — confirm
    # against the checked-out repo.
    if task == "Segmentation":
        model_name = "gelan-c-seg.pt" if model_size == "c" else f"yolov9-{model_size}-seg.pt"
        cfg_dir = "models/segment"
        train_entry = "segment/train.py"
    else:
        model_name = f"yolov9-{model_size}.pt"
        cfg_dir = "models/detect"
        train_entry = "train.py"
    cfg_name = model_name.rsplit(".", 1)[0]  # checkpoint name without ".pt"

    # In the template below, doubled braces are literal braces that the
    # notebook expands at run time. User-derived strings are embedded with
    # !r so they stay valid Python literals.
    script = f"""
!pip install roboflow
from roboflow import Roboflow
rf = Roboflow(api_key={api_key!r})
project = rf.workspace({workspace!r}).project({project_name!r})
version = project.version({version_number})
dataset = version.download("yolov9")
!python {train_entry} \\
--batch {batch_size} --epochs {epochs} --img 640 --device 0 --min-items 0 --close-mosaic 15 \\
--data {{dataset.location}}/data.yaml \\
--weights {{HOME}}/weights/{model_name} \\
--cfg {cfg_dir}/{cfg_name}.yaml \\
--hyp hyp.scratch-high.yaml
"""
    return script
# Streamlit interface
def _show_result(result):
    """Render a generator result: error messages as an alert, scripts as code.

    Both generators signal failure by returning a string that starts with
    "Error:". Anything else is generated script text.
    """
    if result.startswith("Error:"):
        st.error(result)
    else:
        st.code(result, language="python")


st.title("Auto Train Script Generator")
st.write("Generate a YOLO training script using a Roboflow dataset")
tab1, tab2 = st.tabs(["YOLOv8", "YOLOv9"])

with tab1:
    st.subheader("YOLOv8 Script Generator")
    dataset_code = st.text_input("Roboflow Dataset Code", key="dataset_code_v8", placeholder="Paste your Roboflow dataset code here")
    task = st.selectbox("Task", ["Object Detection", "Segmentation"], index=0, key="task_v8")
    model_size = st.selectbox("Model Size", ["n", "s", "m", "l", "x"], index=0, key="model_size_v8")
    epochs = st.selectbox("Epochs", [50, 100, 200, 300, 400, 500], index=3, key="epochs_v8")
    batch_size = st.selectbox("Batch Size", [1, 2, 4, 8, 16, 32], index=0, key="batch_size_v8")
    if st.button("Generate YOLOv8 Script"):
        _show_result(generate_script_v8(dataset_code, task, model_size, epochs, batch_size))

with tab2:
    st.subheader("YOLOv9 Script Generator")
    dataset_code = st.text_input("Roboflow Dataset Code", key="dataset_code_v9", placeholder="Paste your Roboflow dataset code here")
    task = st.selectbox("Task", ["Object Detection", "Segmentation"], index=0, key="task_v9")
    model_size = st.selectbox("Model Size", ["t", "s", "m", "c", "e"], index=0, key="model_size_v9")
    epochs = st.selectbox("Epochs", [50, 100, 200, 300, 400, 500], index=3, key="epochs_v9")
    batch_size = st.selectbox("Batch Size", [1, 2, 4, 8, 16, 32], index=0, key="batch_size_v9")
    if st.button("Generate YOLOv9 Script"):
        _show_result(generate_script_v9(dataset_code, task, model_size, epochs, batch_size))