import glob
import os

import streamlit as st
from datastore import ChromaStore
from embeddings import Embedding
from PIL import Image
from tqdm import tqdm
from utils import base64_to_image, image_to_base64

# Size (width, height) in pixels of the thumbnails shown in the sample and
# database galleries.
THUMBNAIL_SIZE = (64, 64)
# File extensions treated as images; mirrors the uploader's `type=` list.
IMAGE_EXTENSIONS = (".jpg", ".png", ".jpeg")


@st.cache_resource(show_spinner=False)
def _load_thumbnails(paths):
    """Return a THUMBNAIL_SIZE copy of every image in *paths*, in order.

    Streamlit re-executes this whole script on every widget interaction, so
    without caching every image would be re-decoded on each click or slider
    move. *paths* must be a tuple (hashable) because it is the cache key; a
    different file list therefore triggers a fresh load. The spinner is
    disabled because this is first called at import time, before
    ``st.set_page_config``, which must be the first Streamlit command.
    """
    thumbnails = []
    for path in tqdm(paths, desc="Loading thumbnails"):
        # The context manager closes the source file; resize() returns a new
        # in-memory image that no longer depends on it.
        with Image.open(path) as img:
            thumbnails.append(img.resize(THUMBNAIL_SIZE))
    return thumbnails


##### Image database
root_dir = os.path.join(os.getcwd(), "data")
# Sorted so the gallery order (and therefore the cache key) is stable.
jpg_files = sorted(glob.glob(os.path.join(root_dir, "**", "*.jpg"), recursive=True))
IMAGE_DATABASE = _load_thumbnails(tuple(jpg_files))


def display_image_database():
    """Show every database thumbnail inside a collapsible "Image Database" expander."""
    with st.expander(label="Image Database"):
        st.image(IMAGE_DATABASE)


def display_sample_images():
    """Show thumbnails of the images found in ``./sample_imgs``.

    Only files with an image extension are loaded, so stray entries such as
    ``.DS_Store`` cannot make ``Image.open`` raise. Nothing is shown when the
    folder is missing or contains no images.
    """
    sample_img_path = os.path.join(os.getcwd(), "sample_imgs")
    try:
        names = sorted(os.listdir(sample_img_path))
    except FileNotFoundError:
        # Samples are optional decoration; a missing folder must not crash
        # the search panel.
        return
    paths = tuple(
        os.path.join(sample_img_path, name)
        for name in names
        if name.lower().endswith(IMAGE_EXTENSIONS)
    )
    if paths:
        st.image(_load_thumbnails(paths))


# NOTE(review): the heading strings below contain only text although they are
# rendered with unsafe_allow_html=True; the original HTML styling may have
# been lost — confirm the intended markup.


def _render_header():
    """Render the page title and subtitle."""
    st.markdown("\n\n🔍️ Image Search Engine\n\n", unsafe_allow_html=True)
    st.markdown(
        "\n\nImage to Image search using transformer embeddings\n\n",
        unsafe_allow_html=True,
    )


def _render_search_panel():
    """Render the query panel and return ``(uploaded_file_or_None, submit_clicked)``."""
    with st.container(border=True, height=550):
        st.markdown("\n\nSearch\n\n", unsafe_allow_html=True)
        upload_img = st.file_uploader(
            label="Query Image",
            accept_multiple_files=False,
            type=["jpg", "png", "jpeg"],
        )
        submit = st.button(label="Submit")
        display_sample_images()
    return upload_img, submit


def _render_results_panel(upload_img, submit):
    """Render the results panel; run the similarity search only on Submit.

    Args:
        upload_img: the file returned by ``st.file_uploader`` (None if empty).
        submit: True only on the rerun triggered by clicking Submit.
    """
    with st.container(border=True, height=550):
        st.markdown("\n\nResults\n\n", unsafe_allow_html=True)
        top_k = st.slider(label="Search top k results", min_value=3, max_value=10)

        if upload_img is None:
            st.warning("Please upload an image")
            return
        if not submit:
            # An image is present but this rerun was not triggered by
            # Submit (e.g. the slider moved), so don't claim it is missing.
            st.info("Press Submit to search for similar images")
            return

        # Encode the uploaded query image; the file is closed afterwards.
        with Image.open(upload_img) as query_img:
            query_embedding = Embedding.encode_image(query_img)

        vectorstore = ChromaStore(collection_name="image_store")
        collection = vectorstore.create()
        st.toast("Vectorstore loaded successfully", icon="✅")

        results = vectorstore.query(
            collection,
            query_embedding,
            top_k=top_k,
        )
        # Presumably each result's first element is the image payload —
        # confirm against ChromaStore.query.
        st.image([res[0] for res in results])


def main():
    """Render the app: header, query panel, results panel and database preview."""
    st.set_page_config(page_icon="🖼️", page_title="image-search-engine", layout="wide")
    _render_header()

    search_col, results_col = st.columns(2)
    with search_col:
        upload_img, submit = _render_search_panel()
    with results_col:
        _render_results_panel(upload_img, submit)

    display_image_database()


if __name__ == "__main__":
    main()