import json
import os
import urllib

import streamlit as st
from annoy import AnnoyIndex
from PIL import Image
from sentence_transformers import SentenceTransformer

st.set_page_config(
    page_title="BHL Flickr Image Search",
    page_icon="🖼️",
    layout="wide"
)


@st.cache_resource
def load_clip_model():
    """Load the CLIP sentence-transformer model (cached by Streamlit)."""
    return SentenceTransformer('clip-ViT-B-32')


@st.cache_resource
def load_annoy_index():
    """Load the prebuilt Annoy index from ``bhl_index.annoy`` (cached)."""
    # The vector length and metric must match the ones the saved index was
    # built with.
    annoy_index = AnnoyIndex(512, metric='angular')
    annoy_index.load('bhl_index.annoy')
    return annoy_index


@st.cache_data
def load_flickr_data():
    """Load the list of Flickr image records from ``bhl_flickr_list.json``.

    The position of each record in the list is used as its item number in
    the Annoy index (see how search results are looked up below).
    """
    with open('bhl_flickr_list.json') as json_in:
        bhl_flickr_ids = json.load(json_in)
    return bhl_flickr_ids


def bhl_annoy_search(mode, query, k=5):
    """Return the ``k`` nearest neighbours of ``query`` in the Annoy index.

    Relies on the module-level ``model``, ``bhl_index`` and
    ``bhl_flickr_ids`` objects that the script body creates before calling
    this function.

    Args:
        mode: ``'id'`` (``query`` is a Flickr ID), ``'text'`` (``query`` is a
            text prompt) or ``'image'`` (``query`` is an image object that
            the model can encode).
        query: The search input, interpreted according to ``mode``.
        k: Number of neighbours to request.

    Returns:
        A ``(indices, distances)`` pair as produced by Annoy with
        ``include_distances=True``. For ``'id'`` mode an unknown ID yields
        ``([], [])`` instead of raising.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """
    if mode == 'id':
        wanted = str(query).strip()
        # Stop at the first hit; Flickr IDs are presumed unique in the list.
        matching_row = next(
            (idx for idx, row in enumerate(bhl_flickr_ids)
             if str(row['flickr_id']) == wanted),
            None,
        )
        if matching_row is None:
            # Previously this fell through with `matching_row` unbound.
            return [], []
        return bhl_index.get_nns_by_item(matching_row, k,
                                         include_distances=True)

    if mode in ('text', 'image'):
        # The model embeds text and images through the same encode() call,
        # so both modes share one code path.
        query_emb = model.encode([query], show_progress_bar=False)
        return bhl_index.get_nns_by_vector(query_emb[0], k,
                                           include_distances=True)

    raise ValueError(f"Unsupported search mode: {mode!r}")


def _prefill_from_url(query_params, url_mode, default):
    """Pick the initial value of a search box from the page's URL params.

    Returns the URL's ``query`` value when the URL's ``mode`` equals
    ``url_mode``; otherwise returns ``default``. When the URL names a
    different mode, the stale params are cleared from the address bar.
    """
    requested = query_params.get('mode', [None])[0]
    if requested is None:
        return default
    if requested != url_mode:
        st.experimental_set_query_params()
        return default
    return query_params.get('query', [default])[0]


# Public address of the app for each deployment target; used to build the
# "Neighbors" links. Switch DEPLOY_MODE to match where the app is hosted.
_BASE_URLS = {
    'localhost': 'http://localhost:8501/',
    'streamlit_share': 'https://share.streamlit.io/miketrizna/bhl_flickr_search',
    'hf_spaces': 'https://huggingface.co/spaces/MikeTrizna/bhl_flickr_search',
}
DEPLOY_MODE = 'hf_spaces'
if DEPLOY_MODE not in _BASE_URLS:
    # Fail at startup rather than with a NameError when the first link is built.
    raise ValueError(f"Unknown DEPLOY_MODE: {DEPLOY_MODE!r}")
BASE_URL = _BASE_URLS[DEPLOY_MODE]

RESULTS_PER_QUERY = 100
RESULT_COLUMNS = 5

if __name__ == "__main__":
    st.markdown("# BHL Flickr Image Search")
    with st.expander("How does this work?", expanded=False):
        st.write('placeholder')

    st.sidebar.markdown('### Search Mode')

    # Read the URL params once; a "Neighbors" link arrives with
    # ?mode=flickr_id&query=<id> and should preselect the ID search mode.
    query_params = st.experimental_get_query_params()
    url_mode = query_params.get('mode', [None])[0]
    mode_index = 2 if url_mode == 'flickr_id' else 0

    app_mode = st.sidebar.radio(
        "How would you like to search?",
        ['Text search', 'Upload Image', 'BHL Flickr ID'],
        index=mode_index,
    )

    # Module-level on purpose: bhl_annoy_search() reads these names.
    model = load_clip_model()
    bhl_index = load_annoy_index()
    bhl_flickr_ids = load_flickr_data()

    query = None
    search_mode = None

    if app_mode == 'Text search':
        search_text = _prefill_from_url(
            query_params,
            'text_search',
            'a watercolor illustration of an insect with flowers',
        )
        query = st.text_input('Text query', search_text)
        search_mode = 'text'
    elif app_mode == 'BHL Flickr ID':
        search_id = _prefill_from_url(query_params, 'flickr_id', '5974846748')
        query = st.text_input('Query ID', search_id)
        search_mode = 'id'
    elif app_mode == 'Upload Image':
        # URL params never apply to an uploaded image, so always clear them.
        st.experimental_set_query_params()
        image_file = st.file_uploader("Upload Image",
                                      type=["png", "jpg", "jpeg"])
        search_mode = 'image'
        if image_file is not None:
            query = Image.open(image_file)
            st.image(query, width=100, caption='Query image')

    if query:
        closest_k_idx, closest_k_dist = bhl_annoy_search(
            search_mode, query, RESULTS_PER_QUERY)

        col_list = st.columns(RESULT_COLUMNS)
        if len(closest_k_idx):
            for idx, annoy_idx in enumerate(closest_k_idx):
                # Fill the grid left to right, wrapping onto the next row.
                column = col_list[idx % RESULT_COLUMNS]
                bhl_ids = bhl_flickr_ids[annoy_idx]
                bhl_url = f"https://live.staticflickr.com/{bhl_ids['server']}/{bhl_ids['flickr_id']}_{bhl_ids['secret']}.jpg"
                column.image(bhl_url, use_column_width=True)
                flickr_url = f"https://www.flickr.com/photos/biodivlibrary/{bhl_ids['flickr_id']}/"
                neighbors_url = f"{BASE_URL}?mode=flickr_id&query={bhl_ids['flickr_id']}"
                # Both URLs were built but never emitted; wrap the labels in
                # anchors. target="_blank" keeps the current session open.
                link_html = (
                    f'<a href="{flickr_url}" target="_blank">Flickr Link</a>'
                    f' | <a href="{neighbors_url}" target="_blank">Neighbors</a>'
                )
                column.markdown(link_html, unsafe_allow_html=True)
                column.markdown("---")
        else:
            st.warning('No matching images found.')