dmedhi's picture
modify: display image size 224 to 64
80175e6 verified
import glob
import os
import streamlit as st
from datastore import ChromaStore
from embeddings import Embedding
from PIL import Image
from tqdm import tqdm
from utils import base64_to_image, image_to_base64
##### Image database
# Thumbnails of every ``*.jpg`` found (recursively) under ``<cwd>/data``.
root_dir = os.path.join(os.getcwd(), "data")
# glob yields matches in arbitrary, filesystem-dependent order; sorting keeps
# the order of the displayed gallery stable from one run to the next.
jpg_files = sorted(glob.glob(os.path.join(root_dir, "**", "*.jpg"), recursive=True))


def _load_thumbnail(path, size=(64, 64)):
    """Return a resized in-memory copy of the image stored at ``path``.

    The file is opened in a ``with`` block so its handle is released
    deterministically; ``resize`` returns a new image that does not depend on
    the open file.
    """
    with Image.open(path) as img:
        return img.resize(size)


IMAGE_DATABASE = [_load_thumbnail(f) for f in jpg_files]
def display_image_database():
    """Show all thumbnails from ``IMAGE_DATABASE`` inside an "Image Database" expander."""
    with st.expander(label="Image Database"):
        st.image(IMAGE_DATABASE)
def display_sample_images():
    """Render 64x64 thumbnails of the images found in ``<cwd>/sample_imgs``.

    Entries are processed in sorted name order so the gallery is stable
    between runs. Anything Pillow cannot open as an image (sub-directories,
    hidden files such as ``.DS_Store``, corrupt files) is skipped instead of
    aborting the page. If the folder is missing or contains no usable image,
    nothing is rendered.
    """
    sample_img_path = os.path.join(os.getcwd(), "sample_imgs")
    if not os.path.isdir(sample_img_path):
        # The sample gallery is optional decoration; its absence should not
        # take the whole app down.
        return

    images = []
    for name in sorted(os.listdir(sample_img_path)):
        try:
            # The context manager closes the source file; ``resize`` returns
            # an independent in-memory copy.
            with Image.open(os.path.join(sample_img_path, name)) as img:
                images.append(img.resize((64, 64)))
        except OSError:
            # Covers PIL.UnidentifiedImageError (an OSError subclass) as well
            # as directories and unreadable/truncated files.
            continue

    if images:
        st.image(images)
def main():
    """Build the Streamlit page for the image-to-image search engine.

    Layout, top to bottom:
      * page configuration and two centred headings;
      * two bordered, 550px-high columns: "Search" (query image uploader,
        Submit button, sample images) and "Results" (top-k slider and the
        retrieved images);
      * the image database expander.

    A search is executed only in the script run where Submit was clicked while
    a query image is uploaded. When no image is uploaded, a warning asking for
    one is shown in the "Results" column.
    """
    # Emoji are spelled as escape sequences so they cannot be corrupted by a
    # file-encoding mix-up; st.toast in particular requires a real emoji.
    st.set_page_config(
        page_icon="\U0001F5BC\uFE0F",  # framed picture
        page_title="image-search-engine",
        layout="wide",
    )
    st.markdown(
        # \U0001F50D\uFE0F is the magnifying-glass emoji.
        """<h1 style="text-align: center;">\U0001F50D\uFE0F Image Search Engine</h1>""",
        unsafe_allow_html=True,
    )
    st.markdown(
        """<h3 style="text-align: center;">Image to Image search using transformer embeddings</h3>""",
        unsafe_allow_html=True,
    )

    main_layout = st.columns(2)

    with main_layout[0]:
        with st.container(border=True, height=550):
            st.markdown(
                """<h3 style="text-align: center;">Search</h3>""",
                unsafe_allow_html=True,
            )
            upload_img = st.file_uploader(
                label="Query Image",
                accept_multiple_files=False,
                type=["jpg", "png", "jpeg"],
            )
            submit = st.button(label="Submit")
            display_sample_images()

    with main_layout[1]:
        with st.container(border=True, height=550):
            st.markdown(
                """<h3 style="text-align: center;">Results</h3>""",
                unsafe_allow_html=True,
            )
            top_k = st.slider(label="Search top k results", min_value=3, max_value=10)

            if upload_img is None:
                # Only ask for an image when there really is none; previously
                # this warning also appeared after a successful upload that
                # had not been submitted yet.
                st.warning("Please upload an image")
            elif submit:
                # Encode the uploaded image.
                query_embedding = Embedding.encode_image(Image.open(upload_img))

                # Query the vector store.
                vectorstore = ChromaStore(collection_name="image_store")
                collection = vectorstore.create()
                st.toast("Vectorstore loaded successfully", icon="\u2705")  # check mark
                results = vectorstore.query(
                    collection,
                    query_embedding,
                    top_k=top_k,
                )

                # Show results: the first element of every result is handed
                # to st.image (presumably the matched image - confirm against
                # ChromaStore.query).
                res_images = [res[0] for res in tqdm(results, desc="Results")]
                st.image(res_images)

    display_image_database()


if __name__ == "__main__":
    main()