Spaces:
Sleeping
Sleeping
import streamlit as st | |
from helper import load_hf_datasets, search, get_file_paths, get_images_from_s3_to_display | |
import os | |
import time | |
# AWS credentials are read from the process environment; each value is None
# when the corresponding variable is unset.
AWS_ACCESS_KEY_ID = os.environ.get("AWS_ACCESS_KEY_ID")
AWS_SECRET_ACCESS_KEY = os.environ.get("AWS_SECRET_ACCESS_KEY")

# Dataset names offered in the selection dropdown (passed to load_hf_datasets).
datasets = [
    "WayveScenes",
    "StopSign_test",
]

# S3 bucket the result images are fetched from.
bucket_name = "datasets-quasara-io"
# Streamlit App | |
def _load_dataset(dataset_name):
    """Load *dataset_name* with ``load_hf_datasets`` and report completion.

    A spinner is shown for exactly the duration of the real load (replacing a
    simulated, sleep-driven progress bar), so the UI only says the dataset is
    loaded once ``load_hf_datasets`` has actually returned.
    """
    # NOTE(review): Streamlit re-executes the script on every widget
    # interaction, so this load repeats on each rerun. Consider caching
    # load_hf_datasets (st.cache_data / st.cache_resource) if its result is
    # safe to reuse across reruns -- confirm against helper.load_hf_datasets.
    with st.spinner("Loading dataset..."):
        df = load_hf_datasets(dataset_name)
    st.text("Dataset loaded successfully!")
    return df


def _run_search(query, df, limit):
    """Run ``search`` for *query* over *df* and report completion.

    Uses the cosine metric, searches images only (not small objects), and
    passes *limit* plus a fixed ``0`` as the next positional argument
    (presumably an offset -- confirm against helper.search).
    """
    with st.spinner("Performing search..."):
        results = search(query, df, limit, 0, "cosine", search_in_images=True, search_in_small_objects=False)
    st.text("Search completed!")
    return results


def main():
    """Render the semantic-search page.

    The user picks a dataset, which is loaded immediately. On "Search", a
    non-blank query is run against it and the matching images are displayed
    from the S3 bucket ``bucket_name``.
    """
    st.title("Semantic Search and Image Display")

    # Select dataset from dropdown
    dataset_name = st.selectbox("Select Dataset", datasets)
    # Folder passed to the S3 image loader: 'WayveScenes/' for WayveScenes,
    # empty for every other dataset.
    folder_path = 'WayveScenes/' if dataset_name == 'WayveScenes' else ''

    df = _load_dataset(dataset_name)

    # Surrounding whitespace is dropped so a blank query is rejected below
    # instead of being sent to the search backend.
    query = st.text_input("Enter your search query").strip()
    # Number of results to display
    limit = st.number_input("Number of results to display", min_value=1, max_value=10, value=10)

    if not st.button("Search"):
        return
    if not query:
        st.warning("Please enter a search query.")
        return

    results = _run_search(query, df, limit)

    # Get the S3 file paths of the top results
    top_k_paths = get_file_paths(df, results)
    if not top_k_paths:
        st.write("No results found.")
        return

    st.write(f"Displaying top {len(top_k_paths)} results for query '{query}':")
    get_images_from_s3_to_display(bucket_name, top_k_paths, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, folder_path)
# Script entry point: render the app only when this file is executed
# directly (e.g. by `streamlit run`), not when it is imported.
if __name__ == "__main__":
    main()