# --- Hugging Face file-viewer header (page residue, not program code) ---
# inie2003's picture
# added loading/progress bar
# f6019ba verified
# raw
# history blame
# 2.87 kB
import streamlit as st
from helper import load_hf_datasets, search, get_file_paths, get_images_from_s3_to_display
import os
import time
# AWS credentials are read from the process environment; each is None when
# the corresponding variable is not set.
AWS_ACCESS_KEY_ID = os.environ.get("AWS_ACCESS_KEY_ID")
AWS_SECRET_ACCESS_KEY = os.environ.get("AWS_SECRET_ACCESS_KEY")

# Dataset names offered in the selection dropdown (example entries).
datasets = [
    "WayveScenes",
    "StopSign_test",
]

# Name of the S3 bucket passed to the image-display helper.
bucket_name = "datasets-quasara-io"
# Streamlit App
def main():
    """Render the semantic-search page: pick a dataset, query it, show images.

    Flow, top to bottom:
      1. Choose a dataset from ``datasets`` and derive its S3 folder prefix.
      2. Load the dataset with ``load_hf_datasets`` behind a progress bar.
      3. Collect the query text and the number of results to show (1-10).
      4. When "Search" is pressed, run ``search``, resolve the hits to file
         paths with ``get_file_paths`` and pass them to
         ``get_images_from_s3_to_display``.

    The progress bars are updated only around the real calls. Streamlit
    re-executes the script on widget interaction, so artificial ``sleep``
    based progress would delay every rerun and would show 100% before the
    work had even started.
    """
    st.title("Semantic Search and Image Display")

    # Select dataset from dropdown
    dataset_name = st.selectbox("Select Dataset", datasets)

    # Only WayveScenes uses a folder prefix; every other dataset uses ''.
    folder_path = 'WayveScenes/' if dataset_name == 'WayveScenes' else ''

    # Progress indicator for dataset loading: 0% before the load, 100% after.
    loading_text = st.empty()  # Placeholder for dynamic text
    loading_text.text("Loading dataset...")
    progress_bar = st.progress(0)

    df = load_hf_datasets(dataset_name)

    progress_bar.progress(100)
    loading_text.text("Dataset loaded successfully!")

    # Strip surrounding whitespace so a blank-looking query is treated as
    # empty instead of being sent to the search backend.
    query = st.text_input("Enter your search query").strip()

    # Number of results to display
    limit = st.number_input(
        "Number of results to display", min_value=1, max_value=10, value=10
    )

    # Nothing more to render until the user presses the button.
    if not st.button("Search"):
        return

    # Validate input
    if not query:
        st.warning("Please enter a search query.")
        return

    # Progress indicator for the search: 0% before the call, 100% after.
    search_loading_text = st.empty()
    search_loading_text.text("Performing search...")
    search_progress_bar = st.progress(0)

    # NOTE(review): 0 and "cosine" are passed positionally; presumably an
    # offset and the similarity metric -- confirm against helper.search.
    results = search(
        query,
        df,
        limit,
        0,
        "cosine",
        search_in_images=True,
        search_in_small_objects=False,
    )

    search_progress_bar.progress(100)
    search_loading_text.text("Search completed!")

    # Get the S3 file paths of the top results
    top_k_paths = get_file_paths(df, results)

    # Display images from S3
    if top_k_paths:
        st.write(f"Displaying top {len(top_k_paths)} results for query '{query}':")
        get_images_from_s3_to_display(
            bucket_name,
            top_k_paths,
            AWS_ACCESS_KEY_ID,
            AWS_SECRET_ACCESS_KEY,
            folder_path,
        )
    else:
        st.write("No results found.")
# Script entry point: build the page only when this file is run directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()