Spaces:
Running
Running
import glob | |
import os | |
import streamlit as st | |
from datastore import ChromaStore | |
from embeddings import Embedding | |
from PIL import Image | |
from tqdm import tqdm | |
from utils import base64_to_image, image_to_base64 | |
##### Image database
# Every .jpg found recursively under ./data (relative to the current working
# directory) is loaded once at import time and shrunk to a 64x64 thumbnail.
root_dir = os.path.join(os.getcwd(), "data")
# glob yields paths in arbitrary filesystem order; sort them so the gallery
# is laid out the same way on every run.
jpg_files = sorted(glob.glob(os.path.join(root_dir, "**", "*.jpg"), recursive=True))


def _load_thumbnail(path, size=(64, 64)):
    """Return a resized copy of the image stored at ``path``.

    ``resize`` produces a new image, so the source file is closed as soon as
    the thumbnail exists instead of being left to implicit cleanup.
    """
    with Image.open(path) as source:
        return source.resize(size)


IMAGE_DATABASE = [_load_thumbnail(path) for path in jpg_files]
def display_image_database():
    """Render the module-level ``IMAGE_DATABASE`` thumbnails inside an expander.

    The expander is labelled "Image Database"; the images are handed to
    ``st.image`` as a single list.
    """
    with st.expander(label="Image Database"):
        st.image(IMAGE_DATABASE)
def display_sample_images():
    """Show a 64x64 thumbnail for every file in ``./sample_imgs``.

    The directory is resolved against the current working directory. Every
    entry in it is opened as an image, so the directory is expected to
    contain image files only.
    """
    sample_img_path = os.path.join(os.getcwd(), "sample_imgs")
    images = []
    # os.listdir returns entries in arbitrary order; sort so the thumbnails
    # keep the same position across reruns.
    for file_name in sorted(os.listdir(sample_img_path)):
        # resize() returns a new image, so the source file can be closed
        # immediately rather than kept open for the life of the thumbnail.
        with Image.open(os.path.join(sample_img_path, file_name)) as source:
            images.append(source.resize((64, 64)))
    st.image(images)
def _centered_heading(tag, text):
    """Render ``text`` as a horizontally centered HTML heading.

    Args:
        tag: HTML heading element name, e.g. ``"h1"`` or ``"h3"``.
        text: Heading text placed inside the element (constant, trusted text).
    """
    st.markdown(
        f"""<{tag} style="text-align: center;">{text}</{tag}>""",
        unsafe_allow_html=True,
    )


def main():
    """Build the image-search page.

    Layout: a page title and subtitle, then two columns. The left column
    holds the query-image uploader, the Submit button and the sample images;
    the right column holds the ``top_k`` slider and the search results. The
    image database expander is rendered below the columns.
    """
    # NOTE(review): the emoji literals in this function (page_icon, the title
    # prefix and the toast icon) look mis-encoded in this copy of the file.
    # Confirm them against the original source and restore the intended emoji;
    # Streamlit may reject an icon that is not a valid emoji.
    st.set_page_config(page_icon="๐ผ๏ธ", page_title="image-search-engine", layout="wide")
    _centered_heading("h1", "๐๏ธ Image Search Engine")
    _centered_heading("h3", "Image to Image search using transformer embeddings")

    search_col, results_col = st.columns(2)

    with search_col, st.container(border=True, height=550):
        _centered_heading("h3", "Search")
        upload_img = st.file_uploader(
            label="Query Image",
            accept_multiple_files=False,
            type=["jpg", "png", "jpeg"],
        )
        submit = st.button(label="Submit")
        display_sample_images()

    with results_col, st.container(border=True, height=550):
        _centered_heading("h3", "Results")
        top_k = st.slider(label="Search top k results", min_value=3, max_value=10)

        if not upload_img:
            st.warning("Please upload an image")
        elif not submit:
            # The button is only truthy on the rerun triggered by the click,
            # so an uploaded image without a click must not be reported as
            # "missing image".
            st.info("Click Submit to search for similar images")
        else:
            # Encode the uploaded image into the query embedding.
            query_embedding = Embedding.encode_image(Image.open(upload_img))

            # Query the vector store.
            # NOTE(review): create() presumably opens or creates the
            # "image_store" collection; confirm in datastore.ChromaStore.
            vectorstore = ChromaStore(collection_name="image_store")
            collection = vectorstore.create()
            st.toast("Vectorstore loaded successfully", icon="โ ")
            results = vectorstore.query(
                collection,
                query_embedding,
                top_k=top_k,
            )

            # Show results: the first element of each result is displayed
            # (presumably the matched image; verify against ChromaStore.query).
            st.image([res[0] for res in tqdm(results, desc="Results")])

    display_image_database()
# Build the page only when this file is executed as the entry script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()