|
|
|
|
|
import time |
|
from threading import Thread |
|
|
|
import pandas as pd |
|
|
|
from ultralytics import Explorer |
|
from ultralytics.utils import ROOT, SETTINGS |
|
from ultralytics.utils.checks import check_requirements |
|
|
|
check_requirements(("streamlit>=1.29.0", "streamlit-select>=0.3")) |
|
|
|
import streamlit as st |
|
from streamlit_select import image_select |
|
|
|
|
|
def _get_explorer():
    """Build an Explorer from the form selections and create its embeddings table.

    Reads ``dataset``, ``model`` and ``force_recreate_embeddings`` from ``st.session_state``, runs
    ``exp.create_embeddings_table`` on a background thread, and polls ``exp.progress`` to drive a
    Streamlit progress bar. On completion the explorer is stored in ``st.session_state["explorer"]``.
    If the worker thread exits before progress reaches 1, an error is shown instead and no explorer
    is stored, so the selection form stays visible.
    """
    exp = Explorer(data=st.session_state.get("dataset"), model=st.session_state.get("model"))
    thread = Thread(
        target=exp.create_embeddings_table, kwargs={"force": st.session_state.get("force_recreate_embeddings")}
    )
    thread.start()
    progress_bar = st.progress(0, text="Creating embeddings table...")
    # Also stop polling once the worker is gone (e.g. it raised). Otherwise a progress value that
    # never reaches 1 would keep this loop spinning forever.
    while thread.is_alive() and exp.progress < 1:
        time.sleep(0.1)
        # Read the shared value once per tick; st.progress only accepts floats in [0, 1].
        progress = min(max(exp.progress, 0.0), 1.0)
        progress_bar.progress(progress, text=f"Progress: {progress * 100:.1f}%")
    thread.join()
    progress_bar.empty()
    if exp.progress < 1:
        # The worker stopped without finishing the table; don't hand a half-built explorer to the UI.
        st.error("Failed to create the embeddings table. Check the console logs for details.")
        return
    st.session_state["explorer"] = exp
|
|
|
|
|
def init_explorer_form():
    """Render the form used to choose a dataset config and a model before exploring.

    Dataset options are the ``*.yaml`` files under ``ROOT/cfg/datasets``. Submitting the form calls
    ``_get_explorer``, which builds the Explorer and its embeddings table from the selections stored
    under the ``dataset``, ``model`` and ``force_recreate_embeddings`` session-state keys.
    """
    datasets = ROOT / "cfg" / "datasets"
    # Sort so the dropdown order is stable; glob order depends on the filesystem.
    ds = sorted(d.name for d in datasets.glob("*.yaml"))
    # yolov8{n,s,m,l,x}.pt, then the -seg variants, then the -pose variants.
    models = [f"yolov8{size}{suffix}.pt" for suffix in ("", "-seg", "-pose") for size in "nsmlx"]
    # Prefer coco128 as the default, but don't crash if that config is not available.
    default_ds = ds.index("coco128.yaml") if "coco128.yaml" in ds else 0

    with st.form(key="explorer_init_form"):
        col1, col2 = st.columns(2)
        with col1:
            st.selectbox("Select dataset", ds, key="dataset", index=default_ds)
        with col2:
            st.selectbox("Select model", models, key="model")
        st.checkbox("Force recreate embeddings", key="force_recreate_embeddings")

        st.form_submit_button("Explore", on_click=_get_explorer)
|
|
|
|
|
def query_form():
    """Render the SQL query form.

    The text input is bound to the ``query`` session-state key and pre-filled with an example
    ``WHERE`` clause. The submit button triggers ``run_sql_query``.
    """
    with st.form("query_form"):
        input_col, button_col = st.columns([0.8, 0.2])
        with input_col:
            st.text_input(
                "Query",
                value="WHERE labels LIKE '%person%' AND labels LIKE '%dog%'",
                key="query",
                label_visibility="collapsed",
            )
        with button_col:
            st.form_submit_button("Query", on_click=run_sql_query)
|
|
|
|
|
def ai_query_form():
    """Render the natural-language ("Ask AI") query form.

    The text input is bound to the ``ai_query`` session-state key. The submit button triggers
    ``run_ai_query``.
    """
    with st.form("ai_query_form"):
        prompt_col, button_col = st.columns([0.8, 0.2])
        with prompt_col:
            st.text_input(
                "Query",
                value="Show images with 1 person and 1 dog",
                key="ai_query",
                label_visibility="collapsed",
            )
        with button_col:
            st.form_submit_button("Ask AI", on_click=run_ai_query)
|
|
|
|
|
def find_similar_imgs(imgs):
    """Run a similarity search for the selected images and store the results in session state.

    Args:
        imgs: Images selected in the grid, forwarded as ``img`` to the explorer's ``get_similar``.

    The result limit is read from the ``limit`` session-state key. Afterwards
    ``st.session_state["imgs"]`` holds the matching ``im_file`` values and
    ``st.session_state["res"]`` holds the full arrow result.
    """
    explorer = st.session_state["explorer"]
    result = explorer.get_similar(
        img=imgs,
        limit=st.session_state.get("limit"),
        return_type="arrow",
    )
    st.session_state["imgs"] = result.to_pydict()["im_file"]
    st.session_state["res"] = result
|
|
|
|
|
def similarity_form(selected_imgs):
    """Render the similarity-search form for the images currently selected in the grid.

    Args:
        selected_imgs: Images selected in the grid; passed to ``find_similar_imgs`` when "Search"
            is submitted.

    The result limit is bound to the ``limit`` session-state key. The "Search" button is disabled,
    and an error is shown, while no image is selected.
    """
    st.write("Similarity Search")
    with st.form("similarity_form"):
        subcol1, subcol2 = st.columns([1, 1])
        with subcol1:
            # A limit below 1 cannot return any results, so only allow positive values.
            st.number_input(
                "limit", min_value=1, max_value=None, value=25, label_visibility="collapsed", key="limit"
            )

        with subcol2:
            n_selected = len(selected_imgs)
            disabled = n_selected == 0
            st.write("Selected: ", n_selected)
            st.form_submit_button(
                "Search",
                disabled=disabled,
                on_click=find_similar_imgs,
                args=(selected_imgs,),
            )
        if disabled:
            st.error("Select at least one image to search.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_sql_query():
    """Run the SQL text in ``st.session_state["query"]`` against the explorer and store the results.

    On success, ``st.session_state["imgs"]`` receives the matching ``im_file`` values and
    ``st.session_state["res"]`` the arrow result. A blank or missing query does nothing. A failing
    query or an empty result is reported through ``st.session_state["error"]`` rather than raised
    from the Streamlit callback.
    """
    st.session_state["error"] = None
    query = (st.session_state.get("query") or "").strip()
    if not query:
        return
    exp = st.session_state["explorer"]
    try:
        res = exp.sql_query(query, return_type="arrow")
    except Exception as e:  # UI boundary: show the problem to the user instead of crashing the callback
        st.session_state["error"] = f"SQL query failed: {e}"
        return
    imgs = res.to_pydict()["im_file"]
    if not imgs:
        # An empty list is falsy, so layout() would fall back to showing the whole table and hide
        # the fact that nothing matched.
        st.session_state["error"] = "No results found using the SQL query. Try another query."
        return
    st.session_state["imgs"] = imgs
    st.session_state["res"] = res
|
|
|
|
|
def run_ai_query():
    """Answer the natural-language query in ``st.session_state["ai_query"]`` via ``exp.ask_ai``.

    Requires ``SETTINGS["openai_api_key"]`` to be set; otherwise an error message is stored and
    nothing else happens. A blank or missing query does nothing. On success,
    ``st.session_state["imgs"]`` receives the ``im_file`` column as a list and
    ``st.session_state["res"]`` the returned pandas DataFrame. Failures and empty or non-DataFrame
    results are reported through ``st.session_state["error"]``.
    """
    if not SETTINGS["openai_api_key"]:
        st.session_state["error"] = (
            'OpenAI API key not found in settings. Please run yolo settings openai_api_key="..."'
        )
        return
    st.session_state["error"] = None
    query = (st.session_state.get("ai_query") or "").strip()
    if not query:
        return
    exp = st.session_state["explorer"]
    try:
        res = exp.ask_ai(query)
    except Exception as e:  # UI boundary: show the problem to the user instead of crashing the callback
        st.session_state["error"] = f"AI query failed: {e}"
        return
    if not isinstance(res, pd.DataFrame) or res.empty:
        st.session_state["error"] = "No results found using AI generated query. Try another query or rerun it."
        return
    st.session_state["imgs"] = res["im_file"].to_list()
    st.session_state["res"] = res
|
|
|
|
|
def reset_explorer():
    """Return to the dataset-selection screen by clearing the explorer, image list and error state."""
    for state_key in ("explorer", "imgs", "error"):
        st.session_state[state_key] = None
|
|
|
|
|
def utralytics_explorer_docs_callback():
    """Render a bordered panel with the Ultralytics logo, a short note and a link to the Explorer API docs."""
    with st.container(border=True):
        st.image(
            "https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg",
            width=100,
        )
        st.markdown(
            "<p>This demo is built using Ultralytics Explorer API. Visit <a href='https://docs.ultralytics.com/datasets/explorer/'>API docs</a> to try examples & learn more</p>",
            unsafe_allow_html=True,
            help=None,
        )
        st.link_button("Ultralytics Explorer API", "https://docs.ultralytics.com/datasets/explorer/")
|
|
|
|
|
def _as_column_dict(res):
    """Return a stored query result as a ``{column_name: list_of_values}`` mapping.

    ``st.session_state["res"]`` may hold an arrow result (full table, SQL and similarity queries) or
    a pandas DataFrame (AI queries). Both are normalised here so callers can index columns uniformly.
    """
    if isinstance(res, pd.DataFrame):
        return res.to_dict(orient="list")
    return res.to_pydict()


def layout():
    """Render the Ultralytics Explorer dashboard page.

    Until an explorer exists in session state, only the dataset/model selection form is shown.
    Afterwards the page shows:
    - an image grid with the latest query results, or the whole table when there are none;
    - paging controls;
    - the SQL and AI query forms;
    - the similarity-search form;
    - a "Labels" toggle;
    - a documentation panel.
    """
    st.set_page_config(layout="wide", initial_sidebar_state="collapsed")
    st.markdown("<h1 style='text-align: center;'>Ultralytics Explorer Demo</h1>", unsafe_allow_html=True)

    if st.session_state.get("explorer") is None:
        init_explorer_form()
        return

    st.button(":arrow_backward: Select Dataset", on_click=reset_explorer)
    exp = st.session_state.get("explorer")
    col1, col2 = st.columns([0.75, 0.25], gap="small")
    imgs = []
    if st.session_state.get("error"):
        st.error(st.session_state["error"])
    elif st.session_state.get("imgs"):
        imgs = st.session_state.get("imgs")
    else:
        # No query results to show: fall back to every image in the embeddings table.
        imgs = exp.table.to_lance().to_table(columns=["im_file"]).to_pydict()["im_file"]
        st.session_state["res"] = exp.table.to_arrow()
    total_imgs, selected_imgs = len(imgs), []
    with col1:
        subcol1, subcol2, subcol3, subcol4, subcol5 = st.columns(5)
        with subcol1:
            st.write("Max Images Displayed:")
        with subcol2:
            num = st.number_input(
                "Max Images Displayed",
                min_value=0,
                max_value=total_imgs,
                value=min(500, total_imgs),
                key="num_imgs_displayed",
                label_visibility="collapsed",
            )
        with subcol3:
            st.write("Start Index:")
        with subcol4:
            start_idx = st.number_input(
                "Start Index",
                min_value=0,
                max_value=total_imgs,
                value=0,
                key="start_index",
                label_visibility="collapsed",
            )
        with subcol5:
            reset = st.button("Reset", use_container_width=False, key="reset")
            if reset:
                st.session_state["imgs"] = None
                # Clear any previous error too; otherwise it keeps hiding the grid after the reset.
                st.session_state["error"] = None
                # st.experimental_rerun is deprecated; st.rerun is available in streamlit>=1.27.
                st.rerun()

        query_form()
        ai_query_form()
        if total_imgs:
            labels, boxes, masks, kpts, classes = None, None, None, None, None
            task = exp.model.task
            window = slice(start_idx, start_idx + num)  # rows shown on the current page
            if st.session_state.get("display_labels"):
                # Convert the result once instead of once per column. This also handles the pandas
                # DataFrame stored by AI queries, which has no .to_pydict().
                columns = _as_column_dict(st.session_state.get("res"))
                labels = columns["labels"][window]
                boxes = columns["bboxes"][window]
                masks = columns["masks"][window]
                kpts = columns["keypoints"][window]
                classes = columns["cls"][window]
            selected_imgs = image_select(
                f"Total samples: {total_imgs}",
                images=imgs[window],
                use_container_width=False,
                labels=labels,
                classes=classes,
                bboxes=boxes,
                # Only draw masks / keypoints when the model's task produces them.
                masks=masks if task == "segment" else None,
                kpts=kpts if task == "pose" else None,
            )

    with col2:
        similarity_form(selected_imgs)
        # The value is consumed on the next rerun through st.session_state["display_labels"].
        st.checkbox("Labels", value=False, key="display_labels")
        utralytics_explorer_docs_callback()
|
|
|
|
|
# Render the dashboard when this module is executed as a script (presumably via `streamlit run`).
if __name__ == "__main__":
    layout()
|
|