import glob
import os
import streamlit as st
from datastore import ChromaStore
from embeddings import Embedding
from PIL import Image
from tqdm import tqdm
from utils import base64_to_image, image_to_base64
# Image database: 64x64 thumbnails of every .jpg found (recursively) under
# ./data, rendered by the "Image Database" expander.
# NOTE(review): Streamlit re-executes this script on every interaction, so this
# loading runs on each rerun; consider st.cache_data if ./data grows large.
root_dir = os.path.join(os.getcwd(), "data")
# sorted() gives a stable gallery order; raw glob order is filesystem-dependent.
jpg_files = sorted(glob.glob(os.path.join(root_dir, "**", "*.jpg"), recursive=True))


def _load_thumbnail(path, size=(64, 64)):
    """Open the image at *path* and return a resized in-memory copy.

    The ``with`` block closes the underlying file handle; ``resize`` returns a
    new image object that does not depend on it.
    """
    with Image.open(path) as img:
        return img.resize(size)


IMAGE_DATABASE = [_load_thumbnail(f) for f in jpg_files]
def display_image_database():
    """Render every thumbnail in ``IMAGE_DATABASE`` inside a collapsible
    "Image Database" expander."""
    with st.expander(label="Image Database"):
        st.image(IMAGE_DATABASE)
def list_image_files(directory, extensions=(".jpg", ".jpeg", ".png")):
    """Return the sorted paths of image files directly inside *directory*.

    Only regular files whose name ends (case-insensitively) with one of
    *extensions* are returned, so stray entries such as ``.DS_Store`` or
    sub-directories never reach ``Image.open``. Sorting makes the order
    deterministic; ``os.listdir`` order is arbitrary.

    Raises:
        FileNotFoundError: if *directory* does not exist.
    """
    paths = []
    for name in sorted(os.listdir(directory)):
        path = os.path.join(directory, name)
        if os.path.isfile(path) and name.lower().endswith(extensions):
            paths.append(path)
    return paths


def display_sample_images():
    """Render 64x64 thumbnails of the example images in ``./sample_imgs``."""
    sample_img_path = os.path.join(os.getcwd(), "sample_imgs")
    images = []
    for path in list_image_files(sample_img_path):
        # The context manager closes the file handle; resize() returns a new
        # image object that stays usable after the file is closed.
        with Image.open(path) as img:
            images.append(img.resize((64, 64)))
    if images:
        st.image(images)
def _render_search_panel():
    """Draw the "Search" panel and return ``(uploaded_file, submitted)``.

    ``uploaded_file`` is the value of ``st.file_uploader`` (``None`` until the
    user picks a file); ``submitted`` is the state of the Submit button.
    """
    with st.container(border=True, height=550):
        st.markdown("Search\n", unsafe_allow_html=True)
        upload_img = st.file_uploader(
            label="Query Image",
            accept_multiple_files=False,
            type=["jpg", "png", "jpeg"],
        )
        submit = st.button(label="Submit")
        display_sample_images()
    return upload_img, submit


def _search_similar_images(upload_img, top_k):
    """Embed the uploaded query image and return the images of its top_k matches."""
    # Encode the uploaded image; the with-block releases Pillow's resources
    # once the embedding has been computed.
    with Image.open(upload_img) as query_img:
        query_embedding = Embedding.encode_image(query_img)
    # Query the vector store.
    vectorstore = ChromaStore(collection_name="image_store")
    collection = vectorstore.create()
    st.toast("Vectorstore loaded successfully", icon="✅")
    results = vectorstore.query(
        collection,
        query_embedding,
        top_k=top_k,
    )
    # NOTE(review): assumes each result is a sequence whose first element is
    # something st.image can render — confirm against ChromaStore.query.
    return [res[0] for res in results]


def _render_results_panel(upload_img, submit):
    """Draw the "Results" panel and run the search when the user submits."""
    with st.container(border=True, height=550):
        st.markdown("Results\n", unsafe_allow_html=True)
        top_k = st.slider(label="Search top k results", min_value=3, max_value=10)
        if submit and upload_img is not None:
            res_images = _search_similar_images(upload_img, top_k)
            if res_images:
                st.image(res_images)
            else:
                st.info("No matching images found")
        elif upload_img is None:
            st.warning("Please upload an image")
        else:
            # An image is uploaded but Submit has not been pressed yet; the
            # previous "Please upload an image" warning was misleading here.
            st.info("Press Submit to search for similar images")


def main():
    """Build the image-search page.

    Lays out a search panel (query upload, Submit button, sample images) and a
    results panel side by side, then the image-database expander.
    """
    st.set_page_config(page_icon="🖼️", page_title="image-search-engine", layout="wide")
    st.markdown("\n🔍️ Image Search Engine\n", unsafe_allow_html=True)
    st.markdown(
        "Image to Image search using transformer embeddings\n",
        unsafe_allow_html=True,
    )
    search_col, results_col = st.columns(2)
    with search_col:
        upload_img, submit = _render_search_panel()
    with results_col:
        _render_results_panel(upload_img, submit)
    display_image_database()


# Start the UI only when this file is executed as a script (e.g. via
# `streamlit run`), not when it is imported.
if __name__ == "__main__":
    main()