# Spaces: Sleeping  (web-page header residue from extraction; not program text)
import sys | |
import argparse | |
import configparser | |
import pickle | |
import gradio as gr | |
import numpy as np | |
import torch | |
import clip | |
import annoy | |
# ini file consulted when the script is started without command-line arguments.
CONFIG_PATH = "app.ini"

# Prefer the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def parse_args(argv=None):
    """Parse the command-line options of the image search app.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read the process arguments, which is
            what the original zero-argument call did.

    Returns:
        argparse.Namespace with ``pkl`` and ``url`` attributes. Each one is
        ``None`` when the corresponding option was not supplied.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--pkl', type=str, help='input pickle produced by create_embedding.py')
    parser.add_argument('--url', type=str, help='the base URL for the images')
    return parser.parse_args(argv)
def parse_config_file(path=None):
    """Load settings from the ``[DEFAULT]`` section of an ini file.

    Args:
        path: ini file to read. ``None`` (the default) means ``CONFIG_PATH``.

    Returns:
        argparse.Namespace with one string attribute per ``[DEFAULT]`` key.

    Raises:
        FileNotFoundError: if the ini file could not be read.
    """
    ini_path = CONFIG_PATH if path is None else path
    config = configparser.ConfigParser()
    # ConfigParser.read() skips unreadable files without raising and returns
    # the list of files it actually parsed, so an empty result means the
    # configuration is missing. Fail here with a clear message instead of
    # letting an empty namespace trip an assertion later.
    if not config.read(ini_path):
        raise FileNotFoundError(f"could not read config file: {ini_path}")
    return argparse.Namespace(**config['DEFAULT'])
# Settings come from the command line whenever arguments were passed,
# and from the ini file when the script was started bare.
if len(sys.argv) != 1:
    print("using command line arguments, ignoring ini file")
    args = parse_args()
else:
    print(f"no command line arguments, using {CONFIG_PATH}")
    args = parse_config_file()

# Both settings are mandatory; the base URL must end in "/" because file
# names are appended to it directly.
assert all(
    name in args and getattr(args, name) is not None
    for name in ("pkl", "url")
)
assert args.url.endswith("/")

print("arguments:", args)
pickle_filename = args.pkl
base_url = args.url
# NOTE(security): pickle.load can execute arbitrary code while loading;
# only use pickle files you produced yourself.
# The context manager closes the file handle as soon as loading finishes.
with open(pickle_filename, "rb") as pickle_file:
    data = pickle.load(pickle_file)

# the data might be float16 so that the pkl is small,
# but we use float32 in-memory to avoid numerical issues.
# tbh i'm not sure there are any such issues.
embeddings = data["embeddings"].astype(np.float32)
# Scale every row to unit L2 norm.
embeddings /= np.linalg.norm(embeddings, axis=-1)[:, None]
# d (the embedding width) is checked against query features further down.
n, d = embeddings.shape
def build_ann_index(embeddings, n_trees=10):
    """Build an Annoy nearest-neighbour index over the rows of ``embeddings``.

    Args:
        embeddings: 2-D array; row ``i`` is stored under item id ``i``.
        n_trees: value passed to ``AnnoyIndex.build``. Defaults to 10, the
            value that used to be hard-coded.

    Returns:
        The built ``annoy.AnnoyIndex`` using the "angular" metric.
    """
    print("annoy indexing")
    # Only the vector width is needed; the row count is implied by the loop.
    _, dim = embeddings.shape
    annoy_index = annoy.AnnoyIndex(dim, "angular")
    for i, vec in enumerate(embeddings):
        annoy_index.add_item(i, vec)
    annoy_index.build(n_trees)
    print("done")
    return annoy_index
# Image file names stored in the pickle (presumably one per embedding row
# -- TODO confirm against create_embedding.py).
filenames = data["filenames"]
def thumb_patch(filename, prefix="PhotoLibrary"):
    """Rewrite a library path so that it points into the thumbnail tree.

    ``"PhotoLibrary/a/b.jpg"`` becomes ``"PhotoLibrary.thumbs/a/b.jpg"``.

    Args:
        filename: path that must start with ``prefix``.
        prefix: leading part of the path that receives the ``.thumbs``
            suffix. Defaults to ``"PhotoLibrary"``, the value that used to
            be hard-coded.

    Returns:
        The patched path as a string.

    Raises:
        ValueError: if ``filename`` does not start with ``prefix``.
    """
    # An explicit raise, unlike an assert, still runs under ``python -O``.
    if not filename.startswith(prefix):
        raise ValueError(
            f"filename {filename!r} does not start with {prefix!r}"
        )
    return prefix + ".thumbs" + filename[len(prefix):]
print("patching filenames")
filenames = [thumb_patch(name) for name in filenames]

# Folder part of every file name (everything before the last "/").
# Kept as numpy arrays so they can be indexed with a list of positions.
folders = np.array([name.rpartition("/")[0] for name in filenames])
urls = np.array([base_url + name for name in filenames])

annoy_index = build_ann_index(embeddings)
model, preprocess = clip.load('RN50', device=device)
def embed_text(text):
    """Encode ``text`` with the CLIP model and return a unit-length vector.

    Args:
        text: the query string.

    Returns:
        1-D float32 numpy array of length ``d``, scaled to unit L2 norm.
    """
    tokens = clip.tokenize([text]).to(device)
    with torch.no_grad():
        # Cast to float32 like query_uploaded_image does, so the
        # normalisation below never runs in reduced precision (the model
        # may produce half-precision features on GPU -- TODO confirm).
        text_features = model.encode_text(tokens).float()
    # Sanity check: one query in, one vector of the index width out.
    assert text_features.shape == (1, d)
    text_features = text_features.cpu().numpy()[0]
    text_features /= np.linalg.norm(text_features)
    return text_features
def drop_same_folder(indices):
    """Keep only the first index seen for each folder, preserving order.

    Args:
        indices: list of positions into the module-level ``folders`` array.

    Returns:
        List with the entries of ``indices`` whose folder had not appeared
        earlier in the list.
    """
    seen_folders = set()
    survivors = []
    for position, folder_name in zip(indices, folders[indices]):
        if folder_name in seen_folders:
            continue
        seen_folders.add(folder_name)
        survivors.append(position)
    return survivors
def features_to_gallery(features):
    """Look up the nearest images for a feature vector.

    Asks the Annoy index for 500 neighbours, keeps one hit per folder and
    truncates the result to 50 entries.

    Returns:
        Tuple ``(urls, indices)``: the image URLs as a list, and the
        matching positions into the dataset.
    """
    candidates = annoy_index.get_nns_by_vector(features, n=500)
    shortlist = drop_same_folder(candidates)[:50]
    return urls[shortlist].tolist(), shortlist
def image_retrieval_from_text(text):
    """Return the gallery results (URLs, indices) for a text query."""
    return features_to_gallery(embed_text(text))
def image_retrieval_from_image(state, selected_locally):
    """Return images similar to the one selected in the current gallery.

    Args:
        state: dataset indices of the images currently shown, or ``None``.
        selected_locally: position of the selected image within ``state``.

    Returns:
        Tuple ``(urls, indices)``; two empty lists when there is no state.
    """
    has_results = state is not None and len(state) > 0
    if not has_results:
        return [], []
    dataset_index = state[int(selected_locally)]
    return features_to_gallery(embeddings[dataset_index])
def query_uploaded_image(uploaded_image):
    """Return images similar to an uploaded picture.

    Args:
        uploaded_image: the image delivered by the upload component, or
            ``None`` when nothing has been uploaded.

    Returns:
        Tuple ``(urls, indices)``; two empty lists when no image was given,
        mirroring the other handlers' behaviour for empty input.
    """
    if uploaded_image is None:
        return [], []
    # Add the batch dimension directly on the preprocessed tensor instead of
    # round-tripping through numpy.
    image_batch = preprocess(uploaded_image).unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = model.encode_image(image_batch).float()
    image_features = image_features.cpu().numpy()
    # Sanity checks: one image in, one vector of the index width out.
    assert len(image_features) == 1
    image_features = image_features[0]
    assert len(image_features) == d
    return features_to_gallery(image_features)
def show_folder(state, selected_locally):
    """Return every image that shares a folder with the selected one.

    Args:
        state: dataset indices of the images currently shown, or ``None``.
        selected_locally: position of the selected image within ``state``.

    Returns:
        Tuple ``(urls, indices)``; two empty lists when there is no state.
    """
    if state is None or len(state) == 0:
        return [], []
    selected = state[int(selected_locally)]
    target_folder = folders[selected]
    # Vectorised equality scan over the folders array; .tolist() yields
    # plain Python ints in ascending order, like the loop it replaces.
    indices = np.flatnonzero(folders == target_folder).tolist()
    top_urls = urls[indices]
    return top_urls.tolist(), indices
# Gradio UI: a text query row, an image upload row, the result gallery and
# the buttons acting on the selected gallery item.
# NOTE(review): indentation was lost in this copy of the file; the Row
# nesting below is reconstructed -- confirm against the intended layout.
with gr.Blocks(css="footer {visibility: hidden}") as demo:
    # Holds the dataset indices of the images currently in the gallery.
    state = gr.State()
    with gr.Row(variant="compact"):
        text = gr.Textbox(
            label="Enter search query",
            show_label=False,
            max_lines=1,
            placeholder="Enter your prompt",
        ).style(container=False)
        text_query_button = gr.Button("Search").style(full_width=False)
    with gr.Row(variant="compact"):
        uploaded_image = gr.Image(tool="select", type="pil", show_label=False)
        # Label typo fixed ("similiar" -> "similar").
        query_uploaded_image_button = gr.Button("Show similar to uploaded")
    gallery = gr.Gallery(label="Images", show_label=False, elem_id="gallery"
                         ).style(columns=5, container=False)
    with gr.Row():
        filename_textbox = gr.Textbox("", show_label=False).style(container=False)
    with gr.Row():
        show_folder_button = gr.Button("Show folder of selected")
        image_query_button = gr.Button("Show similar to selected")
        # Hidden field carrying the gallery position of the selected image.
        selected = gr.Number(0, show_label=False, visible=False)

    text_query_button.click(image_retrieval_from_text, [text], [gallery, state])
    image_query_button.click(image_retrieval_from_image, [state, selected], [gallery, state])
    show_folder_button.click(show_folder, [state, selected], [gallery, state])
    query_uploaded_image_button.click(query_uploaded_image, [uploaded_image], [gallery, state])

    def get_select_index(evt: gr.SelectData, state):
        """Record the clicked gallery position and show its file name."""
        selected_locally = evt.index
        # Same empty-state guard as the other handlers, so a selection
        # event without results does not raise.
        if state is None or len(state) == 0:
            return selected_locally, ""
        selected = state[int(selected_locally)]
        return selected_locally, filenames[selected]

    gallery.select(get_select_index, [state], [selected, filename_textbox])

if __name__ == "__main__":
    demo.launch(share=False)