MikeTrizna's picture
Changed deploy mode for url params to work
fa4049f
raw
history blame
5.1 kB
import streamlit as st
from annoy import AnnoyIndex
from sentence_transformers import SentenceTransformer
import json
from PIL import Image
import os
import urllib
# Configure the browser tab (title and icon) and switch to the full-width layout.
st.set_page_config(
    page_title="BHL Flickr Image Search",
    page_icon="🖼️",
    layout="wide",
)
@st.cache_resource
def load_clip_model():
    """Return the ``clip-ViT-B-32`` SentenceTransformer, cached via ``st.cache_resource``.

    The returned model is the encoder ``bhl_annoy_search`` uses to embed both
    text queries and uploaded images.
    """
    model_name = 'clip-ViT-B-32'
    clip_model = SentenceTransformer(model_name)
    return clip_model
@st.cache_resource
def load_annoy_index():
    """Load the prebuilt Annoy index from ``bhl_index.annoy``.

    The index is opened with 512 dimensions and the ``'angular'`` metric
    (presumably matching the CLIP embedding size used to build it — confirm
    if the model changes). Cached via ``st.cache_resource``.
    """
    dimensions = 512
    index_path = 'bhl_index.annoy'
    index = AnnoyIndex(dimensions, metric='angular')
    index.load(index_path)
    return index
@st.cache_data
def load_flickr_data():
    """Load the Flickr metadata rows from ``bhl_flickr_list.json``.

    Each row is read by the app for its ``flickr_id``, ``server`` and
    ``secret`` keys (used to build Flickr URLs). Callers index this list with
    Annoy item numbers, so the rows are assumed to be in the same order as
    the items in the Annoy index — TODO confirm against the index build.
    Cached via ``st.cache_data``.
    """
    # Explicit UTF-8 so decoding does not depend on the host's locale
    # encoding (JSON interchange text is UTF-8).
    with open('bhl_flickr_list.json', encoding='utf-8') as json_in:
        return json.load(json_in)
def bhl_annoy_search(mode, query, k=5, *, index=None, encoder=None,
                     id_rows=None):
    """Return the *k* nearest BHL Flickr images to *query* from the Annoy index.

    Args:
        mode: ``'id'`` to find the neighbours of the existing item whose
            ``flickr_id`` equals *query*; ``'text'`` or ``'image'`` to embed
            *query* with the CLIP encoder and search by vector.
        query: A Flickr ID (compared as a stripped string) in ``'id'`` mode,
            a text string in ``'text'`` mode, or an image in ``'image'`` mode.
        k: Number of neighbours to return.
        index: Annoy index to search; defaults to the module-level
            ``bhl_index``.
        encoder: Object with an ``encode`` method (the CLIP model); defaults
            to the module-level ``model``.
        id_rows: Sequence of row dicts with a ``'flickr_id'`` key whose
            positions are the index's item numbers; defaults to the
            module-level ``bhl_flickr_ids``.

    Returns:
        An ``(item_indices, distances)`` pair. In ``'id'`` mode an unknown
        Flickr ID yields ``([], [])`` instead of raising.

    Raises:
        ValueError: If *mode* is not ``'id'``, ``'text'`` or ``'image'``.
    """
    if index is None:
        index = bhl_index

    if mode == 'id':
        rows = bhl_flickr_ids if id_rows is None else id_rows
        target = str(query).strip()
        # Position of the first row with this Flickr ID; Annoy item numbers
        # follow row order.
        matching_row = next(
            (i for i, row in enumerate(rows)
             if str(row['flickr_id']) == target),
            None,
        )
        if matching_row is None:
            # Unknown ID: report "no results" rather than crash.
            return [], []
        return index.get_nns_by_item(matching_row, k, include_distances=True)

    if mode in ('text', 'image'):
        # The CLIP encoder embeds text and images into the same space, so
        # both modes share one code path.
        clip_encoder = model if encoder is None else encoder
        query_emb = clip_encoder.encode([query], show_progress_bar=False)
        return index.get_nns_by_vector(query_emb[0], k,
                                       include_distances=True)

    raise ValueError(f"unknown search mode: {mode!r}")
# Where the app is hosted. This selects the base URL used to build the
# "Neighbors" links. Override with the DEPLOY_MODE environment variable
# ('hf_spaces', 'streamlit_share' or 'localhost'); defaults to 'hf_spaces'.
DEPLOY_MODE = os.environ.get('DEPLOY_MODE', 'hf_spaces')

BASE_URLS = {
    'localhost': 'http://localhost:8501/',
    'streamlit_share': 'https://share.streamlit.io/miketrizna/bhl_flickr_search',
    'hf_spaces': 'https://huggingface.co/spaces/MikeTrizna/bhl_flickr_search',
}

try:
    BASE_URL = BASE_URLS[DEPLOY_MODE]
except KeyError:
    # Fail at startup instead of with a NameError when the first link is built.
    raise ValueError(
        f"Unknown DEPLOY_MODE {DEPLOY_MODE!r}; expected one of {sorted(BASE_URLS)}"
    ) from None
# Map the ``?mode=`` URL parameter to the position of the matching option in
# the sidebar radio; any other value falls back to text search (index 0).
_MODE_PARAM_TO_INDEX = {'text_search': 0, 'flickr_id': 2}
# Number of columns in the results grid.
_NUM_RESULT_COLUMNS = 5


def _query_from_params(query_params, expected_mode, default):
    """Return the ``query`` URL parameter when ``mode`` equals *expected_mode*.

    Returns *default* if there is no ``mode`` parameter, if ``mode`` names a
    different search mode, or if ``query`` is missing. When ``mode`` names a
    different search mode, the page's query parameters are also cleared
    (presumably so parameters meant for another mode don't linger in the URL).
    """
    if 'mode' not in query_params:
        return default
    if query_params['mode'][0] != expected_mode:
        st.experimental_set_query_params()
        return default
    if 'query' in query_params:
        return query_params['query'][0]
    return default


if __name__ == "__main__":
    st.markdown("# BHL Flickr Image Search")
    with st.expander("How does this work?", expanded=False):
        st.write('placeholder')

    st.sidebar.markdown('### Search Mode')
    # Read the URL parameters once and reuse them in every branch below.
    query_params = st.experimental_get_query_params()
    mode_param = query_params.get('mode', [None])[0]
    app_mode = st.sidebar.radio(
        "How would you like to search?",
        ['Text search', 'Upload Image', 'BHL Flickr ID'],
        index=_MODE_PARAM_TO_INDEX.get(mode_param, 0),
    )

    # These module-level names are read as globals by bhl_annoy_search.
    model = load_clip_model()
    bhl_index = load_annoy_index()
    bhl_flickr_ids = load_flickr_data()

    if app_mode == 'Text search':
        search_text = _query_from_params(
            query_params, 'text_search',
            'a watercolor illustration of an insect with flowers')
        query = st.text_input('Text query', search_text)
        search_mode = 'text'
    elif app_mode == 'BHL Flickr ID':
        search_id = _query_from_params(query_params, 'flickr_id', '5974846748')
        query = st.text_input('Query ID', search_id)
        search_mode = 'id'
    elif app_mode == 'Upload Image':
        # URL parameters only drive the text and Flickr-ID modes; clear them.
        st.experimental_set_query_params()
        query = None
        image_file = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg"])
        search_mode = 'image'
        if image_file is not None:
            query = Image.open(image_file)
            st.image(query, width=100, caption='Query image')

    if query:
        closest_k_idx, _ = bhl_annoy_search(search_mode, query, 100)
        if len(closest_k_idx):
            col_list = st.columns(_NUM_RESULT_COLUMNS)
            for idx, annoy_idx in enumerate(closest_k_idx):
                # Fill the grid row by row, wrapping across the columns.
                column = col_list[idx % _NUM_RESULT_COLUMNS]
                bhl_ids = bhl_flickr_ids[annoy_idx]
                flickr_id = bhl_ids['flickr_id']
                bhl_url = (f"https://live.staticflickr.com/{bhl_ids['server']}/"
                           f"{flickr_id}_{bhl_ids['secret']}.jpg")
                column.image(bhl_url, use_column_width=True)
                flickr_url = f"https://www.flickr.com/photos/biodivlibrary/{flickr_id}/"
                # "Neighbors" reopens this app in Flickr-ID mode for this image.
                neighbors_url = f"{BASE_URL}?mode=flickr_id&query={flickr_id}"
                link_html = (f'<a href="{flickr_url}" target="_blank">Flickr Link</a>'
                             f' | <a href="{neighbors_url}">Neighbors</a>')
                column.markdown(link_html, unsafe_allow_html=True)
                column.markdown("---")
        else:
            st.info('No matching images found.')