import sys
import argparse
import configparser
import pickle

import gradio as gr
import numpy as np
import torch
import clip
import annoy

CONFIG_PATH = "app.ini"

device = "cuda" if torch.cuda.is_available() else "cpu"


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--pkl', type=str, help='input pickle produced by create_embedding.py')
    parser.add_argument('--url', type=str, help='the base URL for the images')
    args = parser.parse_args()
    return args


def parse_config_file():
    config = configparser.ConfigParser()
    config.read(CONFIG_PATH)
    config_args = argparse.Namespace(**config['DEFAULT'])
    return config_args
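# A minimal app.ini sketch matching what parse_config_file() expects:
# a DEFAULT section with the same "pkl" and "url" keys as the CLI flags
# (the url must end with "/"). The values below are hypothetical.
#
#   [DEFAULT]
#   pkl = embeddings.pkl
#   url = http://localhost:8000/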
if len(sys.argv) == 1:
    print(f"no command line arguments, using {CONFIG_PATH}")
    args = parse_config_file()
else:
    print("using command line arguments, ignoring ini file")
    args = parse_args()

assert "pkl" in args and args.pkl is not None
assert "url" in args and args.url is not None
assert args.url.endswith("/")

print("arguments:", args)

pickle_filename, base_url = args.pkl, args.url

with open(pickle_filename, "rb") as f:
    data = pickle.load(f)

# the data might be stored as float16 so that the pkl is small, but we use
# float32 in-memory to be safe against numerical issues (likely unnecessary).
embeddings = data["embeddings"].astype(np.float32)
embeddings /= np.linalg.norm(embeddings, axis=-1)[:, None]
n, d = embeddings.shape


def build_ann_index(embeddings):
    print("annoy indexing")
    n, d = embeddings.shape
    # the embeddings are L2-normalized, so annoy's angular distance
    # is monotonic in cosine similarity
    annoy_index = annoy.AnnoyIndex(d, "angular")
    for i, vec in enumerate(embeddings):
        annoy_index.add_item(i, vec)
    annoy_index.build(10)  # 10 trees
    print("done")
    return annoy_index


filenames = data["filenames"]


def thumb_patch(filename):
    # rewrite "PhotoLibrary/..." to "PhotoLibrary.thumbs/..." so the app
    # links to thumbnails instead of full-size images
    prefix = "PhotoLibrary"
    assert filename.startswith(prefix)
    return prefix + ".thumbs" + filename[len(prefix):]


print("patching filenames")
filenames = [thumb_patch(filename) for filename in filenames]

folders = ["/".join(filename.split("/")[:-1]) for filename in filenames]
# as a numpy array to make fancy indexing possible:
folders = np.array(folders)

urls = [base_url + filename for filename in filenames]
urls = np.array(urls)

annoy_index = build_ann_index(embeddings)

model, preprocess = clip.load('RN50', device=device)


def embed_text(text):
    tokens = clip.tokenize([text]).to(device)
    with torch.no_grad():
        text_features = model.encode_text(tokens)
    assert text_features.shape == (1, d)
    text_features = text_features.cpu().numpy()[0]
    text_features /= np.linalg.norm(text_features)
    return text_features


def drop_same_folder(indices):
    # keep only the first hit from each folder so a single folder
    # cannot dominate the gallery
    folder_list = folders[indices]
    seen = set()
    kept = []
    for idx, folder in zip(indices, folder_list):
        if folder not in seen:
            seen.add(folder)
            kept.append(idx)
    return kept


def features_to_gallery(features):
    # over-fetch, then dedup by folder and keep the top 50
    indices = annoy_index.get_nns_by_vector(features, 500)
    indices = drop_same_folder(indices)[:50]
    top_urls = urls[indices]
    return top_urls.tolist(), indices


def image_retrieval_from_text(text):
    text_features = embed_text(text)
    return features_to_gallery(text_features)


def image_retrieval_from_image(state, selected_locally):
    if state is None or len(state) == 0:
        return [], []
    # state holds the global indices of the images currently shown
    selected = state[int(selected_locally)]
    return features_to_gallery(embeddings[selected])


def query_uploaded_image(uploaded_image):
    image = preprocess(uploaded_image)
    # preprocess returns a torch tensor, so batch it with torch rather than numpy
    image_batch = image.unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = model.encode_image(image_batch).float()
    image_features = image_features.cpu().numpy()
    assert len(image_features) == 1
    image_features = image_features[0]
    assert len(image_features) == d
    # unlike embed_text, this does not normalize the query vector;
    # annoy's angular metric is scale-invariant, so that is fine
    return features_to_gallery(image_features)


def show_folder(state, selected_locally):
    if state is None or len(state) == 0:
        return [], []
    selected = state[int(selected_locally)]
    target_folder = folders[selected]
    # linear scan for all images in the same folder
    indices = [i for i, folder in enumerate(folders) if folder == target_folder]
    top_urls = urls[indices]
    return top_urls.tolist(), indices


# note: the .style() calls and gr.Image(tool=...) follow the Gradio 3.x API
with gr.Blocks(css="footer {visibility: hidden}") as demo:
    state = gr.State()
    with gr.Row(variant="compact"):
        text = gr.Textbox(
            label="Enter search query",
            show_label=False,
            max_lines=1,
            placeholder="Enter your prompt",
        ).style(container=False)
        text_query_button = gr.Button("Search").style(full_width=False)
    with gr.Row(variant="compact"):
        uploaded_image = gr.Image(tool="select", type="pil", show_label=False)
        query_uploaded_image_button = gr.Button("Show similar to uploaded")
    gallery = gr.Gallery(
        label="Images", show_label=False, elem_id="gallery"
    ).style(columns=5, container=False)
    with gr.Row():
        filename_textbox = gr.Textbox("", show_label=False).style(container=False)
    with gr.Row():
        show_folder_button = gr.Button("Show folder of selected")
        image_query_button = gr.Button("Show similar to selected")
    # hidden field holding the gallery-local index of the selected image
    selected = gr.Number(0, show_label=False, visible=False)

    text_query_button.click(image_retrieval_from_text, [text], [gallery, state])
    image_query_button.click(image_retrieval_from_image, [state, selected], [gallery, state])
    show_folder_button.click(show_folder, [state, selected], [gallery, state])
    query_uploaded_image_button.click(query_uploaded_image, [uploaded_image], [gallery, state])

    def get_select_index(evt: gr.SelectData, state):
        selected_locally = evt.index
        selected = state[int(selected_locally)]
        return selected_locally, filenames[selected]

    gallery.select(get_select_index, [state], [selected, filename_textbox])

if __name__ == "__main__":
    demo.launch(share=False)
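# Example invocation, assuming a hypothetical script name and paths
# (see parse_args() above; --url must end with "/"):
#
#   python app.py --pkl embeddings.pkl --url http://localhost:8000/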